The Transformer’s attention mechanism has barely changed since 2017. Most efficiency work has tried to replace softmax attention outright. A new paper takes a different route. It keeps softmax attention and bolts on a correction branch.
A team of researchers from Northwestern University, Tilde Research, and the University of Washington introduces Parallax, a parameterized Local Linear Attention that scales to LLM pretraining and is codesigned with the Muon optimizer.
Parallax does not chase efficiency by cutting compute. It adds compute deliberately, then makes that compute cheaper to run on modern GPUs.
What is Parallax
Parallax builds on Local Linear Attention (LLA). LLA comes from the test-time regression framework. That framework reads attention as a regression solver over key-value pairs.
In this view, keys are training data points. Values are labels. The query is the test point. Softmax attention is a nonparametric estimator called Nadaraya-Watson. It fits a local constant function for each query.
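The equivalence between softmax attention and a Nadaraya-Watson estimator can be checked numerically. The sketch below is illustrative only: the function names, the 1/sqrt(d) scaling, and the random shapes are assumptions made for the example, not details taken from the paper.

```python
import numpy as np

def softmax_attention(q, K, V):
    """Single-query softmax attention: softmax(K q / sqrt(d)) @ V."""
    d = q.shape[0]
    s = K @ q / np.sqrt(d)
    p = np.exp(s - s.max())          # subtract the max for numerical stability
    p /= p.sum()
    return p @ V

def nadaraya_watson(q, K, V):
    """Kernel-weighted average of labels V, with kernel exp(q . k / sqrt(d))."""
    d = q.shape[0]
    w = np.exp(K @ q / np.sqrt(d))   # unnormalized kernel weight per key
    return (w[:, None] * V).sum(axis=0) / w.sum()

rng = np.random.default_rng(0)
K = rng.normal(size=(16, 8))   # keys   = training inputs
V = rng.normal(size=(16, 4))   # values = training labels
q = rng.normal(size=8)         # query  = test point

out_attn = softmax_attention(q, K, V)
out_nw = nadaraya_watson(q, K, V)
```

Both functions compute the same weighted average of the values, which is the "local constant" fit the article refers to.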
LLA upgrades that local constant estimate to a local linear estimate. The research team proves this yields strictly smaller integrated mean squared error. The benefit is better bias-variance tradeoffs for associative memory.
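A one-dimensional toy example gives some intuition for why a local linear fit can beat a local constant one. This is classical kernel regression with a Gaussian kernel, not the paper's LLA layer; the bandwidth `h`, the target function, and the function names are all assumptions chosen for illustration.

```python
import numpy as np

def local_constant(x0, x, y, h):
    """Nadaraya-Watson: kernel-weighted mean of y around x0."""
    w = np.exp(-0.5 * ((x - x0) / h) ** 2)
    return np.sum(w * y) / np.sum(w)

def local_linear(x0, x, y, h):
    """Weighted least squares fit of y ~ a + b * (x - x0); returns the intercept a."""
    w = np.exp(-0.5 * ((x - x0) / h) ** 2)
    X = np.stack([np.ones_like(x), x - x0], axis=1)   # design matrix [1, x - x0]
    A = X.T @ (w[:, None] * X)                        # 2x2 weighted normal equations
    b = X.T @ (w * y)
    a, _slope = np.linalg.solve(A, b)
    return a

x = np.linspace(0.0, 1.0, 50)
y = 2.0 * x + 1.0      # noiseless linear target, true value at x = 0 is 1.0
h = 0.2
x0 = 0.0               # query at the boundary of the data

err_const = abs(local_constant(x0, x, y, h) - 1.0)
err_lin = abs(local_linear(x0, x, y, h) - 1.0)
```

At the boundary, all neighbors lie on one side of the query, so the weighted mean is pulled upward, while the local linear fit recovers the linear target up to floating-point error. Note that the local linear estimate needs a small linear solve per query, which is the cost the next paragraph describes.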
But LLA has a problem at scale. Its exact forward requires solving a linear system for every query. That uses a parallel conjugate gradient (CG) solver. The CG solver creates three issues: intensive I/O, a hard regularization-expressiveness tradeoff, and low-precision incompatibility.
Parallax removes the solver. Instead, it learns an extra projection matrix. The research team writes this as ρ_i = W_R x_i. Here W_R is a learnable matrix that probes the KV covariance directly from the layer input.
So Parallax keeps the local linear principle. It just replaces the per-query solve with a learned, query-like projector. That makes it simpler, more efficient, and easier to implement.
How the Mechanism Works
Parallax reformulates LLA as softmax attention plus an additive correction. The output equals the softmax attention output minus a projected covariance term. In the research paper’s notation, that term is the KV covariance multiplied by the learned probe ρi.
The research team also drops one piece of LLA, the boundary amplification factor, by setting it to zero. This is necessary for stability. Once the probe is parametric, the original geometric interpretation breaks. Leaving the factor in could cause the scaling to diverge or flip sign.
Parallax sits inside a family of attention mechanisms. The research team organizes them along three axes: the bandwidth, the probe construction, and the affine structure. At one extreme, Parallax degenerates exactly to softmax attention when the probe norm goes to zero.
Setting W_R = 0 makes a Parallax layer behave identically to softmax attention. So a pretrained Transformer checkpoint can be converted by adding W_R and fine-tuning.
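The additive-correction form described in this section can be sketched for a single query as follows. This is a minimal reading of the article's description, not the paper's implementation: weighting the KV covariance by the softmax probabilities, the 1/sqrt(d) scaling, and the names `parallax_sketch`, `W_R`, and `x` are assumptions, and the paper's exact normalization may differ.

```python
import numpy as np

def softmax_weights(q, K):
    s = K @ q / np.sqrt(q.shape[0])
    p = np.exp(s - s.max())
    return p / p.sum()

def parallax_sketch(x, q, K, V, W_R):
    """Softmax output minus (softmax-weighted KV covariance) @ probe."""
    p = softmax_weights(q, K)
    v_bar = p @ V                       # ordinary softmax attention output
    k_bar = p @ K                       # softmax-weighted mean key
    # cov[a, b] = sum_j p_j * (V[j, a] - v_bar[a]) * (K[j, b] - k_bar[b])
    cov = (V - v_bar).T @ (p[:, None] * (K - k_bar))
    rho = W_R @ x                       # learned probe: rho_i = W_R x_i
    return v_bar - cov @ rho

rng = np.random.default_rng(0)
d_model, d_k, d_v, n = 12, 8, 4, 16
x = rng.normal(size=d_model)            # layer input for this token
q = rng.normal(size=d_k)
K = rng.normal(size=(n, d_k))
V = rng.normal(size=(n, d_v))

softmax_out = softmax_weights(q, K) @ V
out_zero = parallax_sketch(x, q, K, V, np.zeros((d_k, d_model)))
out_probe = parallax_sketch(x, q, K, V, 0.1 * rng.normal(size=(d_k, d_model)))
```

With `W_R` set to zero the probe vanishes and the sketch returns the plain softmax output, which is the property that makes checkpoint conversion possible.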
The Hardware Argument
Parallax inherits the streaming structure of FlashAttention. It adds one covariance branch that reuses the same key-value stream.
The research team expands the forward into two parallel scoring branches. Both branches share the online maximum, the rescaling factor, and the K and V tiles. So Parallax needs no extra I/O per iteration.
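One way to picture the shared streaming structure is an online-softmax loop with extra accumulators. The sketch below assumes the correction is the softmax-weighted KV covariance applied to the probe, which expands to E[(k·ρ) v] − E[k·ρ] E[v]; the accumulator names and tile size are hypothetical, and the real kernel is organized differently.

```python
import numpy as np

def parallax_streaming(q, rho, K, V, tile=4):
    d = q.shape[0]
    QR = np.stack([q / np.sqrt(d), rho])   # row 0 scores keys, row 1 probes keys
    m = -np.inf                            # running max of the softmax logits
    z = 0.0                                # running sum of exp(logit - m)
    acc_v = np.zeros(V.shape[1])           # running sum of e * v
    acc_rv = np.zeros(V.shape[1])          # running sum of e * (k . rho) * v
    acc_r = 0.0                            # running sum of e * (k . rho)
    for start in range(0, K.shape[0], tile):
        Kt, Vt = K[start:start + tile], V[start:start + tile]
        s, r = QR @ Kt.T                   # both branches read the same K tile
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)          # shared rescaling factor
        e = np.exp(s - m_new)
        z = z * scale + e.sum()
        acc_v = acc_v * scale + e @ Vt
        acc_rv = acc_rv * scale + (e * r) @ Vt
        acc_r = acc_r * scale + e @ r
        m = m_new
    v_bar = acc_v / z                                  # softmax attention output
    correction = acc_rv / z - (acc_r / z) * v_bar      # covariance applied to rho
    return v_bar - correction

rng = np.random.default_rng(1)
K = rng.normal(size=(10, 8))
V = rng.normal(size=(10, 4))
q = rng.normal(size=8)
rho = 0.1 * rng.normal(size=8)
out_stream = parallax_streaming(q, rho, K, V)
```

Each K and V tile is loaded once and feeds every accumulator, so the extra branch adds arithmetic but no additional key-value reads in this sketch.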
The key property is higher arithmetic intensity (AI). AI is the ratio of floating point operations to high-bandwidth memory traffic. In the regime where KV work dominates, Parallax roughly doubles the arithmetic intensity. It adds compute while reusing the same memory stream.
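A back-of-envelope model shows where the "roughly doubles" figure can come from. The cost model here is an assumption for illustration: it counts only the score and value-accumulation multiply-adds in a single-query decode step, charges each K and V element one HBM read, and ignores the exponentials, the query, the probe, and the output.

```python
def decode_arithmetic_intensity(n_keys, head_dim, bytes_per_elem=2, branches=1):
    """FLOPs per byte of KV traffic for one decode step (crude model)."""
    # Per key and per branch: one score dot product (2*d FLOPs)
    # plus one value accumulation (2*d FLOPs).
    flops = branches * n_keys * (2 * head_dim + 2 * head_dim)
    # K and V are each read once from HBM, regardless of the branch count.
    bytes_moved = 2 * n_keys * head_dim * bytes_per_elem
    return flops / bytes_moved

# BF16 (2 bytes per element), 4,096 cached keys, head dimension 128.
ai_softmax = decode_arithmetic_intensity(4096, 128, branches=1)
ai_parallax = decode_arithmetic_intensity(4096, 128, branches=2)
ratio = ai_parallax / ai_softmax
```

Under these assumptions the ratio is 2 by construction, since the second branch doubles the counted FLOPs while the bytes moved stay fixed. Real kernels include terms this model leaves out, which is why the article says "roughly".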
This shifts attention toward a more compute-bound regime. That is exactly the regime where kernel optimization helps on modern hardware.
The research team prototyped a decode kernel in CuTeDSL on NVIDIA Hopper GPUs. Hopper’s tensor core matmul instructions operate on tiles of at least 64 rows. A decode step supplies only one query row. So the QK and RK products can be computed jointly, within instructions standard attention already issues.
They profiled against FlashAttention 2 and 3 on H200 GPUs at BF16 precision. They swept batch sizes from 1 to 2,048 and context lengths from 128 to 32,768. The prototype kernel matches or outperforms FlashAttention across all configurations. The paper's figure annotates speedups of 1.54× in the compute-matched setting and 1.14× in the I/O-matched setting.
What the Experiments Show
The research team validated Parallax on synthetic tasks and on LLM pretraining at 0.6B and 1.7B scales. Models used the Qwen-3 architecture in the torchtitan repository. They trained on the Ultra-FineWeb dataset with a 4,096-token context length. Baselines included softmax attention (Transformer), Mamba, Gated DeltaNet, MesaNet, and Kimi Delta Attention.
On the MAD-Benchmark, Parallax attained the highest overall accuracy at 0.716 average. It consistently improved recall-oriented tasks like In-Context-Recall and Selective-Copying. It stayed competitive on compression and memorization tasks.
On language modeling, Parallax with Muon achieved the best perplexity at both scales. It also posted the highest average downstream accuracy. At 1.7B, Parallax scored 62.45 average against the Transformer’s 61.43.
Two controls test where the gain comes from. A parameter-matched Transformer closed only a small fraction of the gap. A compute-matched Parallax still beat both baselines. The paper argues this points to the mechanism itself, not extra parameters or compute.
The Optimizer Twist
A core finding is an optimizer-architecture interaction. Parallax shows a large advantage under Muon. Under AdamW, the advantage shrinks markedly or even disappears.
Muon is a recent optimizer for matrix parameters in hidden layers. It uses the polar factor of the momentum buffer, so updates have condition number exactly one. Prior work shows this produces better-conditioned weight matrices.
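The "condition number exactly one" property follows from the definition of the polar factor. The sketch below computes it with an SVD for clarity; practical Muon implementations typically approximate the polar factor with Newton-Schulz iterations instead, and the matrix shape and the name `polar_factor` are assumptions for the example.

```python
import numpy as np

def polar_factor(M):
    """Orthogonal polar factor U @ V^T from the reduced SVD M = U S V^T."""
    U, _S, Vt = np.linalg.svd(M, full_matrices=False)
    return U @ Vt

rng = np.random.default_rng(0)
momentum = rng.normal(size=(64, 32))        # stand-in for a momentum buffer
update = polar_factor(momentum)

sv = np.linalg.svd(update, compute_uv=False)  # singular values of the update
cond = sv.max() / sv.min()                    # condition number of the update
```

Replacing the singular values with ones keeps the directions of the momentum buffer but equalizes their magnitudes, so every direction of the weight matrix receives an update of the same size.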
The research team traces the gap to the correction branch. They define a correction-to-output ratio (COR). Under Muon, COR exceeds 8 in the deepest layers. Under AdamW, it stays below 4.
The W_R projection is disproportionately affected. Its stable rank collapses under AdamW but stays high under Muon. A gating experiment confirms the pattern. Under AdamW, the model learns to suppress the correction branch rather than use it.
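The article does not define stable rank, so the sketch below assumes the standard definition: the squared Frobenius norm divided by the squared spectral norm. The two test matrices are illustrative extremes, not weights from the paper.

```python
import numpy as np

def stable_rank(W):
    """Stable rank ||W||_F^2 / ||W||_2^2, computed from singular values."""
    sv = np.linalg.svd(W, compute_uv=False)   # sorted in descending order
    return float(np.sum(sv ** 2) / sv[0] ** 2)

rng = np.random.default_rng(0)
u = rng.normal(size=32)
v = rng.normal(size=32)

sr_identity = stable_rank(np.eye(32))         # all directions equally used
sr_rank_one = stable_rank(np.outer(u, v))     # a single dominant direction
```

A matrix whose energy is spread evenly across directions has stable rank equal to its dimension, while a matrix dominated by one direction has stable rank near one. A collapse in this quantity means the projection is using very few of its available directions.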
The research team calls this the first empirical demonstration of strong architecture-optimizer codesign for attention mechanisms. They do not claim Muon with WSD is the optimal recipe. An appendix ablation shows the advantage shrinks during the decay phase.
How the Scores Differ
Parallax also produces different score distributions from softmax attention. Its per-token weights can take negative values and exceed one in magnitude. Standard softmax weights cannot do this.
The research team reports three effects. Parallax can actively subtract value components from irrelevant tokens. It substantially reduces the attention sink on the first token. Its base softmax entropy stays higher, giving more diffuse attention weights.
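A small hand-built case shows how per-token weights can go negative or exceed one. It assumes the output has the form softmax output minus the softmax-weighted KV covariance applied to the probe, in which case the effective weight on token j works out to p_j · (1 − (k_j − k̄)·ρ). That closed form and the names below are a reading of the article's description, not the paper's notation.

```python
import numpy as np

def effective_weights(q, rho, K):
    """Return softmax weights p and the effective per-token weights w."""
    s = K @ q / np.sqrt(q.shape[0])
    p = np.exp(s - s.max())
    p /= p.sum()
    k_bar = p @ K                              # softmax-weighted mean key
    w = p * (1.0 - (K - k_bar) @ rho)          # correction folded into the weights
    return p, w

K = np.array([[1.0, 0.0],
              [-1.0, 0.0]])    # two keys pointing in opposite directions
q = np.zeros(2)                # uniform softmax weights: p = [0.5, 0.5]
rho = np.array([2.0, 0.0])     # probe aligned with the first key

p, w = effective_weights(q, rho, K)
# p = [0.5, 0.5], w = [-0.5, 1.5]
```

The effective weights still sum to one because the centered keys average to zero under the softmax weights, but individual entries are no longer confined to the interval from zero to one. The first token's value is subtracted and the second is amplified.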
Strengths, Weaknesses, and Open Questions
Strengths
Keeps softmax attention intact, so a pretrained Transformer can convert by adding W_R and fine-tuning.
Adds no extra I/O per iteration by reusing the FlashAttention key-value stream.
Roughly doubles arithmetic intensity when KV work dominates, with a prototype kernel matching or beating FlashAttention 2/3 in decode.
Shows consistent perplexity and downstream gains under parameter-matched and compute-matched controls.
Weaknesses and Open Questions
Gains depend heavily on Muon; under AdamW the advantage largely disappears.
The precise cause of the optimizer dependence remains an open question.
Results stop at 1.7B scale, without MoE, longer context, or larger runs.
The advantage erodes during the WSD decay phase, only partially fixed by weight decay annealing.
Key Takeaways
Parallax keeps softmax attention and adds a learned covariance correction branch, replacing LLA’s per-query conjugate gradient solver.
It roughly doubles arithmetic intensity while reusing the same KV stream, with a prototype decode kernel matching or beating FlashAttention 2/3.
It shows consistent perplexity and downstream gains at 0.6B and 1.7B, holding under parameter-matched and compute-matched controls.
The gains depend heavily on Muon; under AdamW the advantage shrinks markedly or disappears.
Setting W_R = 0 recovers softmax attention exactly, so pretrained Transformers can convert by adding W_R and fine-tuning.
The post Parallax: A Parameterized Local Linear Attention That Keeps Softmax and Adds a Learned Covariance Correction Branch appeared first on MarkTechPost.