Tree Training: Accelerating Agentic LLMs Training via Shared Prefix Reuse

arXiv:2511.00413v4 (Announce Type: replace)

Abstract: Agentic large language model (LLM) training often involves multi-turn interaction trajectories that branch into multiple execution paths because of concurrent tool use, think mode, sub-agents, context management, and other runtime designs. As a result, the tokens produced by a single task naturally form a tree-structured trajectory with shared prefixes rather than a linear sequence. Existing training pipelines linearize such trajectories and treat each branch independently, which leads to substantial redundant computation in both the forward and backward passes. We show that averaging the loss over all branches independently is algebraically identical to a per-token weighted loss in which each token's weight equals the fraction of branches passing through it. The problem therefore reduces to computing the log-probability of every token in the prefix tree exactly once, with no repeated computation across shared prefixes. To do this, we propose a depth-first search (DFS) serialization of the tree, which visits every token exactly once. We also adapt full-attention and state-space model (SSM) layers so that the resulting log-probabilities exactly match independent per-branch computation. In practice, a single trajectory tree can be too large to fit in GPU memory. We therefore propose Tree Partitioning, a memory-efficient strategy that splits the tree into subtrees that each fit in GPU memory while preserving high prefix reuse. Together, these contributions form Tree Training, an efficient framework for training LLMs on tree-structured trajectories. It achieves up to a 6.2x end-to-end training speedup on dense and mixture-of-experts (MoE) models for both supervised fine-tuning and reinforcement learning.
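
The weighted-loss identity can be written out explicitly. The notation below is ours, not the paper's. T is the set of token nodes in the prefix tree, B is the number of root-to-leaf branches, and anc(t) denotes the ancestors of node t. We assume that each branch's loss is the plain sum of its token losses; a per-branch length normalization or a loss mask would change the weights. The middle step swaps the order of summation. This is valid because the token loss depends only on t and its ancestors, which are identical in every branch that passes through t.

```latex
\begin{aligned}
\mathcal{L}_{\text{branch}}
  &= \frac{1}{B}\sum_{b=1}^{B}\sum_{t \in b} \ell_t
   = \sum_{t \in \mathcal{T}} \ell_t \cdot \frac{\lvert\{\, b : t \in b \,\}\rvert}{B}
   = \sum_{t \in \mathcal{T}} w_t\,\ell_t, \\
\ell_t &= -\log p_\theta\!\left(x_t \,\middle|\, x_{\mathrm{anc}(t)}\right),
\qquad
w_t = \frac{\lvert\{\, b : t \in b \,\}\rvert}{B}.
\end{aligned}
```

Both sides are the same function of the model parameters, so their gradients also agree. A shared-prefix token therefore needs one forward and one backward computation instead of one per branch.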

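A minimal sketch of DFS serialization with per-token weights follows. It rests on our own assumptions. Branches are given as lists of integer token ids, and tokens are merged whenever they share an identical prefix. All helper names (`build_prefix_tree`, `dfs_serialize`, `position_ids`, `ancestor_mask`, `toy_token_loss`, `path_to`) are hypothetical and are not the paper's code. Each token may attend only to itself and its ancestors. Its position id is its depth in the tree, which is our assumption for keeping positional encodings consistent with per-branch order. Under these rules, each token sees exactly the context it would see in its own linearized branch. The final check replaces the model's -log p with a toy function of the token's prefix, then confirms that the weighted tree loss equals the average of the per-branch losses.

```python
import math


def build_prefix_tree(branches):
    """Merge branches (lists of token ids) into a prefix tree.

    Each node holds its token, its children keyed by token id, and the
    number of branches passing through it. The root holds no token.
    """
    root = {"token": None, "children": {}, "count": 0}
    for branch in branches:
        node = root
        for tok in branch:
            if tok not in node["children"]:
                node["children"][tok] = {"token": tok, "children": {}, "count": 0}
            node = node["children"][tok]
            node["count"] += 1
    return root


def dfs_serialize(root, num_branches):
    """Pre-order DFS that emits every tree token exactly once.

    Returns parallel lists: token ids, the serialized index of each
    token's parent (-1 under the root), and weights n_t / B.
    """
    tokens, parents, weights = [], [], []
    # Children are pushed in reverse so the first child is popped first.
    stack = [(c, -1) for c in reversed(list(root["children"].values()))]
    while stack:
        node, parent_idx = stack.pop()
        idx = len(tokens)
        tokens.append(node["token"])
        parents.append(parent_idx)
        weights.append(node["count"] / num_branches)
        stack.extend((c, idx) for c in reversed(list(node["children"].values())))
    return tokens, parents, weights


def position_ids(parents):
    """Depth of each token, i.e. its position inside its own branch.

    Parents always precede children in pre-order, so one pass suffices.
    """
    depth = []
    for p in parents:
        depth.append(0 if p == -1 else depth[p] + 1)
    return depth


def ancestor_mask(parents):
    """mask[i][j] is True iff token j is token i or one of its ancestors."""
    n = len(parents)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:
            mask[i][j] = True
            j = parents[j]
    return mask


def toy_token_loss(prefix):
    """Hypothetical stand-in for -log p(x_t | ancestors)."""
    return 0.1 * (sum(prefix) % 7 + 1) + 0.01 * len(prefix)


def path_to(i, tokens, parents):
    """Tokens from the root down to serialized index i (inclusive)."""
    path = []
    while i != -1:
        path.append(tokens[i])
        i = parents[i]
    return tuple(reversed(path))


if __name__ == "__main__":
    branches = [[1, 2, 3, 4], [1, 2, 3, 5, 6], [1, 2, 7]]
    tree = build_prefix_tree(branches)
    tokens, parents, weights = dfs_serialize(tree, len(branches))
    # Baseline: mean over branches of each branch's summed token loss.
    branch_loss = sum(
        sum(toy_token_loss(tuple(b[: k + 1])) for k in range(len(b))) for b in branches
    ) / len(branches)
    # Tree view: each token's loss computed once, scaled by its weight.
    tree_loss = sum(
        w * toy_token_loss(path_to(i, tokens, parents)) for i, w in enumerate(weights)
    )
    print(tokens, parents, position_ids(parents))
    # [1, 2, 3, 4, 5, 6, 7] [-1, 0, 1, 2, 2, 4, 1] [0, 1, 2, 3, 3, 4, 2]
    print(math.isclose(branch_loss, tree_loss))  # True
```

For these three branches, the linearized form holds 4 + 5 + 3 = 12 tokens, while the serialized tree holds 7. The dense boolean mask is for illustration only. It grows quadratically with sequence length, and a practical implementation would more likely use a fused attention kernel. The sketch does not cover the SSM-layer adaptation or Tree Partitioning.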