[P] I implemented "Screening Is Enough" (arXiv:2604.01178) in PyTorch and benchmarked it

Last week's paper replaces softmax attention with an absolute-threshold mechanism:

```
alpha = [max(1 - r * (1 - cosine_sim), 0)]^2
```

Keys below the threshold get zeroed out entirely — no global competition, no softmax denominator. Paper claims ~40% fewer params at comparable loss (in their full-scale iso-performance experiments up to 4B params — not iso-parameter comparisons) and 3.2x lower latency at 100K context.

I built a PyTorch implementation:
https://github.com/ibusan100/PyTorch-implementation-of-Screening-Attention
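For concreteness, here's a minimal sketch of the core mechanism (function and argument names are mine; the repo version adds projections, masking, and training code):

```python
import torch
import torch.nn.functional as F

def screening_attention(q, k, v, r: float = 2.0):
    # Sketch of the screening mechanism as I read the paper:
    # alpha = relu(1 - r * (1 - cos_sim))^2, applied per key with
    # no softmax denominator. q, k, v: (B, H, T, D).
    qn = F.normalize(q, dim=-1)
    kn = F.normalize(k, dim=-1)
    cos = qn @ kn.transpose(-2, -1)             # (B, H, T, T) cosine sims
    alpha = torch.relu(1 - r * (1 - cos)) ** 2  # zero whenever cos <= 1 - 1/r
    return alpha @ v                            # no global competition
```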

---

Latency (torch.utils.benchmark, RTX 4060 Ti, 8GB VRAM)

| seq_len | Screening | nn.MHA | F.SDPA |
|--------:|----------:|-------:|-------:|
| 512     | 2.75 ms   | 0.92 ms | 0.69 ms |
| 2048    | 43 ms     | 8.4 ms  | 4.8 ms  |
| 4096    | 1956 ms   | 30 ms   | 15 ms   |

3–66x slower than MHA. At seq_len=4096 the alpha matrix alone is ~2GB (B=4, H=8, T²=16M floats), plus a separate softmask tensor of similar size. `relu(...)` in PyTorch is a dense op — it allocates and computes the full O(T²) matrix and then sets values to zero, which means no sparsity benefit and no FlashAttention-style memory tricks apply. A Triton kernel that fuses the threshold check and skips zero-alpha keys entirely would change this picture completely.
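The ~2GB figure is just arithmetic on the fp32 tensor shape:

```python
# Back-of-envelope footprint of the dense alpha matrix at seq_len=4096
# (fp32 = 4 bytes/element; the softmask tensor roughly doubles this).
B, H, T = 4, 8, 4096
alpha_bytes = B * H * T * T * 4
print(f"{alpha_bytes / 1e9:.2f} GB")  # 2.15 GB
```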

Interesting finding:
100% of keys are screened out at initialization. In 64-dim head space, random unit vectors have cosine similarity std ≈ 0.125, so P(sim > 0.5) ≈ 0.001%. The r=2 threshold is simply unreachable at init: the model starts with attention completely off and must learn to lower r during training. This is by design, and it explains why the paper uses an unusually high LR (0.0625, written as 2^-4 in the paper).
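A quick simulation (mine, assuming i.i.d. Gaussian vectors normalized to unit length as the init distribution) reproduces the dead-init numbers:

```python
import numpy as np

# Cosine similarity of random unit-vector pairs in d=64 dims.
# The std should be ~1/sqrt(d) = 0.125, and sims above the r=2
# cutoff of 0.5 essentially never occur at init.
rng = np.random.default_rng(0)
d, n = 64, 200_000
a = rng.standard_normal((n, d))
b = rng.standard_normal((n, d))
a /= np.linalg.norm(a, axis=1, keepdims=True)
b /= np.linalg.norm(b, axis=1, keepdims=True)
sims = (a * b).sum(axis=1)
print(sims.std())           # ~0.125
print((sims > 0.5).mean())  # ~0: everything is screened out
```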

WikiText-2 perplexity
(GPT-2 BPE, d_model=128, heads=4, layers=4, ~7.3M params, 10K steps)

| Model | test PPL | time |
|-------|---------:|-----:|
| TransformerLM  | 221.6 | 481s |
| MultiscreenLM  | 191.3 | 608s |

Note: both models are deliberately matched at ~7.3M params (same d_model, heads, layers) for a fair architectural comparison — this is not a test of the paper's param-efficiency claim. Absolute PPL is high because d_model=128 is tiny relative to vocab_size=50K.

MultiscreenLM's test PPL is 14% lower (221.6 → 191.3), at the cost of 26% more training time and ~20% higher peak training memory (dense alpha matrix + softmask stored for backprop). The validation curves tell the same story — MultiscreenLM is already ahead at step 1K (valid PPL 402 vs 502), so attention opens up fast despite the dead-init.

I also tracked how r evolves during training. Spoiler: barely.

After 5K steps, mean r across all heads/layers drops from 2.0 → 1.93. Sparsity goes from 100% → ~95%. The attention never really "opens up" in the way you might expect — the model learns to selectively attend to maybe 5% of keys, and those few attended positions appear to be doing real work.
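To make the "barely opens up" point concrete: alpha is nonzero only when cosine similarity exceeds 1 - 1/r, so the drop from r=2.0 to r=1.93 barely moves the cutoff:

```python
def cosine_cutoff(r: float) -> float:
    # alpha = relu(1 - r * (1 - cos))^2 > 0  iff  cos > 1 - 1/r
    return 1 - 1 / r

print(cosine_cutoff(2.00))  # 0.5 at init
print(cosine_cutoff(1.93))  # ~0.482 after 5K steps
```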

![r evolution and attention maps](benchmarks/r_evolution.png)

Honest caveats: this is a single-seed run — I haven't verified the 14% gap is stable across seeds, and at this scale the variance could be significant. Also, it's plausible that the threshold acts purely as a regularizer (sparsity-as-dropout) rather than anything architecturally meaningful. Distinguishing those two hypotheses requires larger-scale experiments.

---

Bottom line:
The mechanism works and the quality results are promising, but the paper's latency claims are entirely contingent on a custom sparse kernel that doesn't exist yet in PyTorch. Sparsity starts at 100% at init and stays around 95% through training, so the room for speedup is real. A Triton kernel version is currently in development.

English is not my first language, so I'm using machine translation for this post.
Happy to discuss the math or implementation.

submitted by /u/Pleasant_Yard_8879
