A Faster Generalized Two-Stage Approximate Top-K
arXiv:2506.04165v3 Announce Type: replace
Abstract: We consider the Top-$K$ selection problem, which aims to identify the largest $K$ elements in an array. Top-$K$ selection arises in many machine learning algorithms and often becomes a bottleneck on accelerators, which are optimized for dense matrix multiplications. To address this problem, Chern et al. (2022) proposed a fast two-stage approximate Top-$K$ algorithm that: (i) partitions the input array into equal-sized chunks and selects the top-$1$ element from each partition; and (ii) sorts the resulting smaller subset and returns the top $K$ elements. In this paper, we generalize the first stage so that each partition selects the top $K'$ elements (for $1 \leq K' \leq K$). Our contributions include: (i) an expression for the expected recall of this generalized algorithm under random partitioning, and a demonstration that choosing $K' > 1$ with fewer partitions in the first stage more effectively reduces the input size to the second stage while maintaining the same expected recall as the original algorithm; (ii) a bound on the expected recall of the original algorithm as a function of the algorithm parameters that is provably tighter by a factor of $2$ than the bound reported by Chern et al. (2022); and (iii) an implementation of our algorithm on Cloud TPUv5e that achieves approximately an order of magnitude speedup over the original algorithm without sacrificing recall.
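The two stages described in the abstract can be illustrated with a minimal pure-Python sketch. It is not the authors' TPU implementation. The function names (`approx_top_k`, `recall`), the parameter names, and the shuffle-then-stride method of forming random, near-equal partitions are all assumptions made for illustration. Setting `k_prime=1` corresponds to the original scheme of Chern et al. (2022) as summarized above, and `k_prime > 1` to the generalized first stage.

```python
import heapq
import random


def approx_top_k(values, k, num_partitions, k_prime=1, seed=None):
    """Illustrative two-stage approximate top-k (hypothetical helper).

    Stage 1: randomly partition `values` into `num_partitions` near-equal
             chunks and keep the `k_prime` largest elements of each chunk.
    Stage 2: select the `k` largest of the surviving candidates, in
             descending order.
    """
    if not 1 <= k_prime <= k:
        raise ValueError("expected 1 <= k_prime <= k")
    if num_partitions * k_prime < k:
        raise ValueError("stage 1 would keep fewer than k candidates")

    # Assumption: random partitioning is modelled by shuffling a copy and
    # taking strided slices, so chunk sizes differ by at most one.
    rng = random.Random(seed)
    shuffled = list(values)
    rng.shuffle(shuffled)

    candidates = []
    for p in range(num_partitions):
        chunk = shuffled[p::num_partitions]
        candidates.extend(heapq.nlargest(k_prime, chunk))

    # Stage 2 works on at most num_partitions * k_prime candidates.
    return heapq.nlargest(k, candidates)


def recall(approx, exact):
    """Fraction of the exact top-k recovered (assumes distinct values)."""
    return len(set(approx) & set(exact)) / len(exact)


if __name__ == "__main__":
    data = list(range(100_000))
    k = 64
    exact = heapq.nlargest(k, data)
    # Two configurations with different stage-2 input sizes (1024 vs 512).
    for parts, kp in [(1024, 1), (128, 4)]:
        out = approx_top_k(data, k, num_partitions=parts, k_prime=kp, seed=0)
        print(parts, kp, parts * kp, recall(out, exact))
```

The second stage receives at most `num_partitions * k_prime` candidates, which is the quantity the generalized algorithm aims to shrink at a fixed expected recall. When `k_prime == k` the result is exact for any number of partitions, because every true top-`k` element is necessarily among the `k` largest of its own chunk.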