FlashNorm: Fast Normalization for Transformers
arXiv:2407.09577v5
Abstract: Normalization layers are ubiquitous in large language models (LLMs) yet represent a compute bottleneck: on hardware with distinct vector and matrix execution units, the RMS calculation blocks the subsequent matrix multiplication, preventing parallel execution.
We present FlashNorm, an exact reformulation of RMSNorm followed by a linear layer that (i) eliminates the normalization weights by folding them into the subsequent linear layer, and (ii) defers the scalar RMS normalization to the output of the matrix multiplication, enabling the two operations to execute in parallel.
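A minimal NumPy sketch of the two steps in (i) and (ii) follows. The helper names (`rms`, `fold_norm_weights`, `flashnorm_linear`), the random shapes, and the `eps` value are illustrative assumptions, not the authors' code. It assumes RMSNorm is `x / rms(x) * g` over the feature axis and that the linear layer is `y = h @ W`. Because `(x / r * g) @ W = (x @ (diag(g) @ W)) / r`, the weights `g` can be folded into `W` offline. The matmul `x @ W_star` then no longer waits on the RMS, and the per-token scalar `1/r` is applied to its output.

```python
import numpy as np

def rms(x, eps=1e-6):
    # Root-mean-square over the feature axis, kept 2-D for broadcasting.
    return np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def rmsnorm_then_linear(x, g, W, eps=1e-6):
    # Reference path: normalize, scale by the norm weights g, then project.
    return (x / rms(x, eps) * g) @ W

def fold_norm_weights(g, W):
    # Offline step (i): W_star = diag(g) @ W, i.e. scale row i of W by g[i].
    return g[:, None] * W

def flashnorm_linear(x, W_star, eps=1e-6):
    # Step (ii): the matmul and the RMS are independent, so on hardware with
    # separate matrix and vector units they could run concurrently.
    z = x @ W_star          # matrix unit
    s = 1.0 / rms(x, eps)   # vector unit, one scalar per token
    return z * s            # deferred normalization of the matmul output

rng = np.random.default_rng(0)
d, n, batch = 64, 32, 4
x = rng.standard_normal((batch, d))
g = rng.standard_normal(d)
W = rng.standard_normal((d, n))

ref = rmsnorm_then_linear(x, g, W)
out = flashnorm_linear(x, fold_norm_weights(g, W))
print(np.allclose(ref, out))  # True
```

The folded model stores only `W_star`, so the separate norm-weight tensor disappears. The two paths differ only by floating-point rounding.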
Additionally, by the scale invariance of RMS, an RMSNorm followed by a linear layer followed by another RMSNorm allows the first RMSNorm to be eliminated entirely. This mathematically identical simplification removes the pre-attention RMSNorm in models using QKV-normalization (e.g., Gemma 4) and in MLA models with latent normalization (e.g., DeepSeek-V2, Mistral Small 4, and OpenMythos).
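The scale-invariance argument can be checked numerically with the NumPy sketch below. The names `rmsnorm`, `g1`, `g2`, and `W` are hypothetical, and `W` stands in for something like one query head or a latent down-projection. The first RMSNorm contributes a per-token factor `1/r1 > 0` plus elementwise weights `g1`. In this sketch, `g1` is folded into `W` and the factor `1/r1` is simply dropped, because the second RMSNorm satisfies `rmsnorm(c * y) = rmsnorm(y)` for any `c > 0`. The sketch assumes `eps = 0`. With a nonzero epsilon inside the square root, the identity holds only approximately.

```python
import numpy as np

def rmsnorm(x, g, eps=0.0):
    # RMSNorm over the feature axis; eps = 0 makes it exactly scale invariant.
    r = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / r * g

rng = np.random.default_rng(1)
d, n, batch = 64, 16, 3
x = rng.standard_normal((batch, d))
g1 = rng.standard_normal(d)        # weights of the first (removable) RMSNorm
W = rng.standard_normal((d, n))    # linear layer between the two norms
g2 = rng.standard_normal(n)        # weights of the second RMSNorm (kept)

# Reference: RMSNorm -> linear -> RMSNorm.
ref = rmsnorm(rmsnorm(x, g1) @ W, g2)

# Simplified: fold g1 into W and skip the first normalization altogether.
W_star = g1[:, None] * W
simplified = rmsnorm(x @ W_star, g2)

print(np.allclose(ref, simplified))  # True
```

The first norm's RMS computation disappears here, not just its weight tensor. This is why the pre-attention norm can be removed when every branch it feeds is renormalized.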
The same techniques extend to LayerNorm, Dynamic Tanh (DyT), feed-forward networks with GLU variants, and RoPE-based attention.
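The abstract does not spell out how the extensions work. The NumPy sketch below shows one way the same weight folding can apply to LayerNorm and to Dynamic Tanh (DyT), each followed by a linear layer with bias `c`. All function names are hypothetical, and the LayerNorm construction is an assumption of this sketch, not necessarily the paper's. For LayerNorm, the mean subtraction is folded into the weights through the centering matrix `C = I - 11^T/d`, the gain `g` and bias `b` are folded into `W` and `c`, and the `1/std` factor is deferred to the matmul output as in RMSNorm. For DyT, `gamma * tanh(alpha * x) + beta` is elementwise, so only `gamma` and `beta` fold into the linear layer. The `tanh` itself cannot be deferred.

```python
import numpy as np

def layernorm(x, g, b, eps=1e-5):
    # Standard LayerNorm over the feature axis (population variance).
    mu = np.mean(x, axis=-1, keepdims=True)
    var = np.var(x, axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps) * g + b

def fold_layernorm(g, b, W, c):
    # x @ C == x - mean(x), so centering can be absorbed into the weights.
    d = W.shape[0]
    C = np.eye(d) - np.full((d, d), 1.0 / d)
    W_star = C @ (g[:, None] * W)   # C @ diag(g) @ W
    c_star = b @ W + c              # LayerNorm bias pushed through W
    return W_star, c_star

def flash_layernorm_linear(x, W_star, c_star, eps=1e-5):
    z = x @ W_star                  # independent of the statistics below
    s = 1.0 / np.sqrt(np.var(x, axis=-1, keepdims=True) + eps)
    return z * s + c_star           # deferred per-token scaling

def dyt(x, alpha, g, b):
    # Dynamic Tanh: scalar alpha, per-channel scale g and shift b.
    return np.tanh(alpha * x) * g + b

def fold_dyt(g, b, W, c):
    # Only the affine parameters fold; tanh must still run before the matmul.
    return g[:, None] * W, b @ W + c

rng = np.random.default_rng(2)
d, n, batch = 48, 24, 5
x = rng.standard_normal((batch, d))
g, b = rng.standard_normal(d), rng.standard_normal(d)
W, c = rng.standard_normal((d, n)), rng.standard_normal(n)
alpha = 0.7

ref_ln = layernorm(x, g, b) @ W + c
out_ln = flash_layernorm_linear(x, *fold_layernorm(g, b, W, c))

W_dyt, c_dyt = fold_dyt(g, b, W, c)
ref_dyt = dyt(x, alpha, g, b) @ W + c
out_dyt = np.tanh(alpha * x) @ W_dyt + c_dyt

print(np.allclose(ref_ln, out_ln), np.allclose(ref_dyt, out_dyt))  # True True
```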
On an NVIDIA T4 GPU, FlashNorm achieves 33-35% lower latency on the norm-then-project operation in the compute-bound (prefill) regime at SmolLM2-135M scale, and 12-14% lower latency at Llama-7B scale. We verify on three models that weight folding is lossless. Beyond inference speed, FlashNorm simplifies model implementations by reducing the number of parameter tensors.
Watch our explainer video https://youtu.be/GEuJv34_XgU?si and see https://github.com/OpenMachine-ai/transformer-tricks for code.