TokenButler: Token Importance is Predictable
arXiv:2503.07518v2 Announce Type: replace-cross
Abstract: Large Language Models (LLMs) rely on the Key-Value (KV) Cache to store token history, enabling efficient decoding of tokens. As the KV-Cache grows, it becomes a major memory and computation bottleneck. However, there is an opportunity to alleviate this bottleneck: prior research has shown that only a small subset of tokens contributes meaningfully to each decoding step. A key challenge in finding these critical tokens is that they are dynamic and heavily dependent on the input query. Existing methods either risk quality by evicting tokens permanently, or retain the full KV-Cache but retrieve tokens in coarse chunks; moreover, many existing KV-Cache sparsity methods rely on inaccurate proxies for token importance. To address these limitations, we introduce TokenButler, a high-granularity, query-aware predictor that learns to identify these critical tokens. TokenButler predicts low-dimensional importance queries at a fixed depth stride and combines them with a learned projection of the real KV-Cache keys to score tokens cheaply, enabling dynamic per-token selection under a fixed budget while preserving the full KV-Cache. We train TokenButler by distilling the model's masked causal attention distributions, optimizing a lightweight predictor with minimal parameter overhead. We evaluate TokenButler on a novel synthetic small-context co-referential retrieval task, demonstrating near-oracle accuracy where existing methods fail. Furthermore, TokenButler achieves competitive or superior performance on long-context benchmarks (RULER, LongBench); delivers up to $\approx 1.6\times$ on-GPU speedup using our proposed *prediction interval with neighbor fetching*, which amortizes predictor cost while keeping accuracy within $\approx 1.1\%$; and reduces latency by up to $7.6\times$ compared to Dense Attention with CPU offloading. Code is available at: https://github.com/abdelfattah-lab/TokenButler
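The scoring, selection, and distillation steps described in the abstract can be sketched in plain Python. This is a minimal, hypothetical illustration, not the authors' implementation. The names `score_tokens`, `select_tokens`, and `distill_loss`, the projection matrix `proj`, and the choice of a KL divergence as the distillation objective are all assumptions. The sketch scores each cached token by the dot product of a low-dimensional importance query with a projected key, then keeps the top tokens under a fixed budget without evicting anything from the cache. It omits the fixed depth stride at which importance queries are predicted, as well as the prediction-interval/neighbor-fetching scheme.

```python
import math
from typing import List, Sequence


def project_key(key: Sequence[float], proj: Sequence[Sequence[float]]) -> List[float]:
    # Hypothetical learned projection: maps a d_model-wide key to d_small dims.
    # `proj` is a d_small x d_model matrix (an assumption for illustration).
    return [sum(w * k for w, k in zip(row, key)) for row in proj]


def score_tokens(query: Sequence[float], keys: Sequence[Sequence[float]],
                 proj: Sequence[Sequence[float]]) -> List[float]:
    # Importance score per cached token: low-dim query . projected key.
    return [sum(q * p for q, p in zip(query, project_key(k, proj))) for k in keys]


def select_tokens(scores: Sequence[float], budget: int) -> List[int]:
    # Indices of the `budget` highest-scoring tokens, in position order.
    # The full KV cache is kept; only the attended subset changes per step.
    if budget >= len(scores):
        return list(range(len(scores)))
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:budget])


def softmax(xs: Sequence[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def distill_loss(pred_scores: Sequence[float], target_attn: Sequence[float],
                 eps: float = 1e-12) -> float:
    # ASSUMED objective: KL(target || softmax(pred)) over the causal prefix,
    # where `target_attn` stands in for the model's masked attention row.
    pred = softmax(pred_scores)
    return sum(t * (math.log(t + eps) - math.log(p + eps))
               for t, p in zip(target_attn, pred))


# Toy example: 4 cached keys (d_model=4), importance space d_small=2.
keys = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
proj = [[1.0, 0.0, 0.5, 0.0], [0.0, 1.0, 0.0, 0.25]]
query = [2.0, 3.0]
scores = score_tokens(query, keys, proj)
print(scores)                    # [2.0, 3.0, 1.0, 0.75]
print(select_tokens(scores, 2))  # [0, 1]
```

Because every key stays in the cache, a token that scores low at one step can still be selected at a later step when the query changes. This is the property that distinguishes per-token selection from permanent eviction.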