Last week's paper replaces softmax attention with an absolute threshold mechanism:
```
alpha = [max(1 - r * (1 - cosine_sim), 0)]^2
```
Keys below the threshold are zeroed out entirely: no global competition, no softmax denominator. The paper claims ~40% fewer parameters at comparable loss (in their full-scale iso-performance experiments up to 4B params, not iso-parameter comparisons) and 3.2x lower latency at 100K context.
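For concreteness, here is a minimal numpy sketch of the mechanism as I read the formula above. Function and argument names are mine, and single-head, unbatched shapes are assumed; this is not the paper's code.

```python
import numpy as np

def screening_attention(q, k, v, r=2.0):
    """Sketch of the threshold mechanism (my naming, my shapes).
    q, k: (T, d) queries/keys; v: (T, d) values; r: learned threshold scale."""
    qn = q / np.linalg.norm(q, axis=-1, keepdims=True)
    kn = k / np.linalg.norm(k, axis=-1, keepdims=True)
    sim = qn @ kn.T                                      # cosine similarity, (T, T)
    alpha = np.maximum(1.0 - r * (1.0 - sim), 0.0) ** 2  # screened weights
    return alpha @ v                                     # no softmax denominator
```

Note that with r=2 the weight is nonzero only when sim > 1 - 1/r = 0.5, and rows of `alpha` can be all-zero, which is exactly what happens at initialization.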
I built a PyTorch implementation:
https://github.com/ibusan100/PyTorch-implementation-of-Screening-Attention
---
Latency (torch.utils.benchmark, RTX 4060 Ti, 8GB VRAM)
| seq_len | Screening | nn.MHA | F.SDPA |
|---|---|---|---|
| 512 | 2.75ms | 0.92ms | 0.69ms |
| 2048 | 43ms | 8.4ms | 4.8ms |
| 4096 | 1956ms | 30ms | 15ms |
3–66x slower than MHA. At seq_len=4096 the alpha matrix alone is ~2GB (B=4, H=8, T²=16M floats), plus a separate softmask tensor of similar size. `relu(...)` in PyTorch is a dense op — it allocates and computes the full O(T²) matrix and then sets values to zero, which means no sparsity benefit and no FlashAttention-style memory tricks apply. A Triton kernel that fuses the threshold check and skips zero-alpha keys entirely would change this picture completely.
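The ~2GB figure follows from simple arithmetic (fp32 assumed):

```python
# Dense alpha matrix footprint at the setting quoted above: B=4, H=8, T=4096, fp32.
B, H, T, bytes_per_el = 4, 8, 4096, 4
alpha_bytes = B * H * T * T * bytes_per_el
print(alpha_bytes / 2**30, "GiB")  # 2.0 GiB; the softmask tensor doubles it
```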
Interesting finding:
100% of keys are screened out at initialization. In 64-dim head space, random unit vectors have a cosine-similarity std of ≈0.125, so P(sim > 0.5) ≈ 0.001%. The r=2 cutoff (sim > 0.5) is simply unreachable at init: the model starts with attention completely off and must learn to lower r during training. This appears to be by design, and it explains why the paper uses an unusually high LR (0.0625, written as 2^-4 in the paper).
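The dead-init claim is easy to check numerically; here is a quick Monte Carlo sketch (sample size and seed are mine):

```python
import numpy as np

# Cosine similarity between pairs of random 64-dim unit vectors,
# screened at the r=2 cutoff sim > 0.5.
rng = np.random.default_rng(0)
d, n = 64, 200_000
a = rng.standard_normal((n, d))
b = rng.standard_normal((n, d))
a /= np.linalg.norm(a, axis=1, keepdims=True)
b /= np.linalg.norm(b, axis=1, keepdims=True)
sim = (a * b).sum(axis=1)
print(sim.std())          # ≈ 0.125, i.e. 1/sqrt(64)
print((sim > 0.5).mean()) # vanishingly small survival rate
```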
WikiText-2 perplexity
(GPT-2 BPE, d_model=128, heads=4, layers=4, ~7.3M params, 10K steps)
| Model | test PPL | time |
|---|---|---|
| TransformerLM | 221.6 | 481s |
| MultiscreenLM | 191.3 | 608s |
Note: both models are deliberately matched at ~7.3M params (same d_model, heads, layers) for a fair architectural comparison — this is not a test of the paper's param-efficiency claim. Absolute PPL is high because d_model=128 is tiny relative to vocab_size=50K.
MultiscreenLM's test PPL is 14% lower (221.6 → 191.3), at the cost of 26% more training time and ~20% higher peak training memory (dense alpha matrix + softmask stored for backprop). The validation curves tell the same story — MultiscreenLM is already ahead at step 1K (valid PPL 402 vs 502), so attention opens up fast despite the dead-init.
I also tracked how r evolves during training. Spoiler: barely.
After 5K steps, mean r across all heads/layers drops from 2.0 → 1.93. Sparsity goes from 100% → ~95%. The attention never really "opens up" in the way you might expect — the model learns to selectively attend to maybe 5% of keys, and those few attended positions appear to be doing real work.
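To see why this counts as "barely": alpha is positive iff 1 - r * (1 - sim) > 0, i.e. sim > 1 - 1/r, so the observed drift in r moves the cutoff only slightly:

```python
# Cosine-similarity cutoff implied by the threshold formula:
# alpha > 0 iff sim > 1 - 1/r.
def cutoff(r):
    return 1.0 - 1.0 / r

print(cutoff(2.0))   # 0.5 at init
print(cutoff(1.93))  # ≈ 0.482 after 5K steps
```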

Honest caveats: this is a single-seed run — I haven't verified the 14% gap is stable across seeds, and at this scale the variance could be significant. Also, it's plausible that the threshold acts purely as a regularizer (sparsity-as-dropout) rather than anything architecturally meaningful. Distinguishing those two hypotheses requires larger-scale experiments.
---
Bottom line:
The mechanism works and the quality results are promising, but the paper's latency claims are entirely contingent on a custom sparse kernel that does not yet exist in PyTorch. Sparsity starts at 100% at init and settles around 95% after training, so the room for speedup is real. A Triton kernel version is currently in development.
English is not my first language, so I am using machine translation for this communication.
Happy to discuss the math or implementation.