Every resource says "We scale by 1/√d_k to prevent softmax saturation." Almost none of them explain why saturation happens or why that specific scaling constant appears.
When you compute Q·Kᵀ without scaling, each element is a dot product of two d_k-dimensional vectors. If the components of Q and K are independent and approximately zero-mean with unit variance at initialization, each product q_i·k_i has mean 0 and variance 1, so the dot product (a sum of d_k such independent products) has variance d_k. So the pre-softmax attention scores have standard deviation √d_k.
At d_k = 64 (typical), the standard deviation is √64 = 8. Your softmax input can easily look like [-20, 3, 25, -15]. The softmax collapses to near-one-hot, and the gradients flowing back through it become extremely small for every position, not just the non-max ones. The model starts behaving like “pick one token aggressively” instead of learning nuanced attention distributions.
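If you want to check the variance claim yourself, here is a minimal simulation sketch. It assumes the components are i.i.d. standard normal, which is an idealization of real projected activations, and the function name `dot_product_std` is just something I made up for illustration.

```python
import math
import random
import statistics

def dot_product_std(d_k, n_samples=2000, seed=0):
    """Empirical std of q.k when q and k have i.i.d. N(0, 1) components (assumed)."""
    rng = random.Random(seed)
    dots = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        dots.append(sum(qi * ki for qi, ki in zip(q, k)))
    return statistics.pstdev(dots)

for d_k in (16, 64, 256):
    raw = dot_product_std(d_k)
    root = math.sqrt(d_k)
    # The raw std should track sqrt(d_k); dividing by sqrt(d_k) should land near 1.
    print(f"d_k={d_k:4d}  std(q.k)={raw:6.2f}  sqrt(d_k)={root:6.2f}  scaled={raw / root:.2f}")
```

The raw standard deviation should come out close to √d_k for each size, and the scaled one close to 1, up to sampling noise.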
Dividing by √d_k fixes this by normalizing the dot products back toward unit variance. Now the logits look more like [-2.5, 0.4, 3.1, -1.9], producing softer distributions the optimizer can actually learn from.
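To make the saturation concrete, here is a small sketch that runs the example logits through softmax with and without the 1/√d_k scaling. It also prints p·(1−p), the diagonal of the softmax Jacobian, as a rough proxy for how much gradient can flow to each logit. The `softmax` helper is my own plain-Python version, not any particular library's.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

d_k = 64
raw = [-20.0, 3.0, 25.0, -15.0]               # unscaled scores from the example
scaled = [x / math.sqrt(d_k) for x in raw]    # roughly [-2.5, 0.4, 3.1, -1.9]

p_raw = softmax(raw)
p_scaled = softmax(scaled)

# d softmax_i / d logit_i = p_i * (1 - p_i); near zero everywhere when saturated.
grad_raw = [p * (1.0 - p) for p in p_raw]
grad_scaled = [p * (1.0 - p) for p in p_scaled]

print("raw probs:    ", [f"{p:.2e}" for p in p_raw])
print("scaled probs: ", [f"{p:.2e}" for p in p_scaled])
print("raw p(1-p):   ", [f"{g:.2e}" for g in grad_raw])
print("scaled p(1-p):", [f"{g:.2e}" for g in grad_scaled])
```

Unscaled, essentially all of the probability mass sits on the third token and every p·(1−p) term is tiny. Scaled, the top token still dominates but the other positions keep non-negligible probability and gradient.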
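The detail most explanations skip: this is initialization-dependent. The right scaling depends on the variance of the projected Q/K activations at initialization. The standard 1/√d_k factor works because Transformer projections are typically initialized to keep activations near unit variance (Xavier initialization, for example). If the components had standard deviation σ instead of 1, the raw scores would have standard deviation σ²·√d_k, and dividing by √d_k would leave you at σ² rather than 1.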
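A quick way to see the initialization dependence is to rerun the variance experiment with component standard deviations other than 1. This is a sketch under the same i.i.d. Gaussian assumption; `sigma` and `scaled_score_std` are hypothetical names for illustration, and `sigma` stands in for whatever scale your projections actually produce.

```python
import math
import random
import statistics

def scaled_score_std(d_k, sigma, n_samples=2000, seed=0):
    """Empirical std of (q.k) / sqrt(d_k) for i.i.d. N(0, sigma^2) components (assumed)."""
    rng = random.Random(seed)
    scores = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, sigma) for _ in range(d_k)]
        k = [rng.gauss(0.0, sigma) for _ in range(d_k)]
        dot = sum(qi * ki for qi, ki in zip(q, k))
        scores.append(dot / math.sqrt(d_k))
    return statistics.pstdev(scores)

for sigma in (0.5, 1.0, 2.0):
    measured = scaled_score_std(64, sigma)
    # After the 1/sqrt(d_k) scaling, the std should be about sigma^2, not 1.
    print(f"sigma={sigma}: scaled std={measured:.2f}, sigma^2={sigma ** 2:.2f}")
```

Only the σ = 1 case should land near unit standard deviation; the others end up near σ², which is why the factor is tied to how the projections are initialized.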
So the scaling factor is not just a heuristic; it falls directly out of the variance growth of high-dimensional dot products.
I’ve written a deeper guide covering transformer internals here: https://www.calibreos.com/learn/ml-transformers
Curious what other “Everybody repeats this but rarely explains the actual derivation” concepts people have run into in machine learning.