Dispatch-Aware Ragged Attention for Pruned Vision Transformers
arXiv:2604.15408v2 Announce Type: replace
Abstract: Token pruning methods for Vision Transformers (ViTs) promise quadratic reductions in attention FLOPs by dropping uninformative patches. Yet standard variable-length attention APIs -- including FlashAttention-2's varlen and PyTorch's NestedTensor SDPA -- fail to translate these savings into proportional wall-clock gains at the short post-pruning sequence lengths typical of ViTs ($\leq$197 tokens). We identify a dispatch-overhead bottleneck: at these lengths, host-side kernel dispatch consumes ${\sim}$50\,$\mu$s regardless of workload, exceeding the actual GPU compute time at moderate-to-high pruning rates. We present a lightweight bidirectional Triton attention kernel whose dispatch floor is ${\sim}$24\,$\mu$s -- roughly 2.17$\times$ lower than FlashAttention-2 varlen -- allowing pruning savings to become visible in wall-clock time. Integrated into a complete pack-attend-unpack pipeline and evaluated on an NVIDIA RTX 4000 Ada Generation GPU, our system achieves 1.88$\times$ end-to-end throughput over padded PyTorch SDPA at standard 224$\times$224 inputs, scaling to 2.51$\times$ at 384$\times$384. Against FlashAttention-2 varlen -- the strongest baseline -- our kernel delivers 9--12\% higher throughput at serving batch sizes (BS=1--4), and 2.17$\times$ lower kernel latency at 80\% token pruning. Numerical correctness is verified with max absolute logit difference $<$0.004 and identical top-1 predictions.
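The pack-attend-unpack pipeline named in the abstract can be illustrated with a minimal pure-Python sketch. This is not the authors' Triton kernel: the function names (`pack`, `attend`, `ragged_attention`, `unpack`), the use of the raw tokens as queries, keys, and values without learned projections, and the zero-padding convention in `unpack` are all assumptions made for illustration. It shows only the data flow: variable-length sequences are concatenated with cumulative offsets, bidirectional attention runs per segment, and the result is scattered back to a padded layout.

```python
import math

def pack(batch):
    """Concatenate variable-length sequences; return tokens and cumulative offsets."""
    packed, cu_seqlens = [], [0]
    for seq in batch:
        packed.extend(seq)
        cu_seqlens.append(cu_seqlens[-1] + len(seq))
    return packed, cu_seqlens

def attend(q, k, v):
    """Bidirectional (non-causal) softmax attention over a single sequence."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * vj[d] for w, vj in zip(weights, v)) / total
            for d in range(len(v[0]))
        ])
    return out

def ragged_attention(packed, cu_seqlens):
    """Attend within each segment only; tokens never see other sequences."""
    out = []
    for start, end in zip(cu_seqlens, cu_seqlens[1:]):
        seg = packed[start:end]
        # Assumption: q = k = v = tokens (no learned projections) for brevity.
        out.extend(attend(seg, seg, seg))
    return out

def unpack(packed_out, cu_seqlens, pad_len):
    """Scatter packed outputs back to a padded [batch][pad_len][dim] layout."""
    dim = len(packed_out[0])
    padded = []
    for start, end in zip(cu_seqlens, cu_seqlens[1:]):
        rows = packed_out[start:end]
        rows = rows + [[0.0] * dim for _ in range(pad_len - len(rows))]
        padded.append(rows)
    return padded

# Two "images" that kept 3 and 1 tokens after pruning (toy 2-dimensional tokens).
batch = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [[2.0, 3.0]]]
packed, cu_seqlens = pack(batch)
padded_out = unpack(ragged_attention(packed, cu_seqlens), cu_seqlens, pad_len=3)
```

In this sketch the attention work scales with the sum of squared per-sequence lengths rather than with the padded length, which is the saving that pruning is meant to expose; a real GPU implementation would fuse the per-segment loop into one kernel launch.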