FlashNorm: Fast Normalization for Transformers
arXiv:2407.09577v4 Announce Type: replace
Abstract: Normalization layers are ubiquitous in large language models (LLMs) yet represent a compute bottleneck: on hardware with distinct vector and matrix execution units, the RMS calculation blocks the subsequent matrix multiplication, preventing parallel execution.
We present FlashNorm, an exact reformulation of RMSNorm followed by a linear layer that (i) eliminates the normalization weights by folding them into the subsequent linear layer, and (ii) defers the scalar RMS normalization to the output of the matrix multiplication, enabling the two operations to execute in parallel.
FlashNorm is mathematically identical to the original computation: it introduces no approximation and requires no retraining. The same technique extends to LayerNorm, Dynamic Tanh (DyT), feed-forward networks with GLU variants, and RoPE-based attention.
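The two steps described above can be checked numerically in a few lines. The sketch below (a minimal illustration, not the paper's implementation; tensor shapes and the epsilon value are assumptions) folds the RMSNorm weights into the columns of the following linear layer and applies the scalar 1/RMS factor after the matrix multiplication, reproducing the baseline output exactly up to floating-point precision:

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 8, 4                       # hidden size, output size (illustrative)
x = rng.standard_normal(d)        # input activation
g = rng.standard_normal(d)        # RMSNorm weights
W = rng.standard_normal((n, d))   # subsequent linear layer
eps = 1e-6                        # assumed epsilon for numerical stability

# Baseline: weighted RMSNorm followed by a linear projection
rms = np.sqrt(np.mean(x**2) + eps)
y_ref = W @ (g * (x / rms))

# FlashNorm: (i) fold g into W offline (scales column j of W by g[j]),
# (ii) defer the scalar 1/rms until after the matmul, so the matrix
# multiplication no longer waits on the RMS reduction
W_folded = W * g                  # one-time offline transformation
y_flash = (W_folded @ x) / rms

assert np.allclose(y_ref, y_flash)
```

The equivalence rests on linearity: `W @ (g * x / rms) = (1/rms) * ((W * g) @ x)`, so the diagonal weight multiply and the scalar division commute with the projection.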
On an NVIDIA T4 GPU, FlashNorm achieves 33 to 35% lower latency on the norm-then-project operation in the compute-bound (prefill) regime at SmolLM2-135M scale, and 12 to 14% at Llama-7B scale. We verify zero-loss weight folding on SmolLM2-135M, Llama-3.2-1B, and Llama-3.1-8B.
Beyond inference speed, FlashNorm simplifies model implementations by reducing the number of parameter tensors, analogous to the simplification PaLM achieved by removing bias parameters from all linear layers.
Watch our explainer video at https://youtu.be/GEuJv34_XgU and see https://github.com/OpenMachine-ai/transformer-tricks for code.