Linear Attention

This blog records my understanding of linear attention. The key takeaways from my learning are: (1) always keep things hardware-efficient; (2) make data-independent things data-dependent; (3) make coarse data-dependency more fine-grained.

Thanks to Songlin Yang for this talk, which I learned a lot from.

Vanilla Linear Attention

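Softmax causal attention can be formulated as:

$$O = \mathrm{softmax}\left(QK^\top + M\right)V,$$

where $M$ is the causal attention mask ($M_{ij} = 0$ if $j \le i$ and $-\infty$ otherwise), so each position attends only to itself and earlier positions.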

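Linear attention is just dot-product attention without softmax. For each timestep $t$,

$$o_t = \sum_{i \le t} (q_t^\top k_i)\, v_i = \sum_{i \le t} v_i\, (k_i^\top q_t) = \Big(\sum_{i \le t} v_i k_i^\top\Big) q_t.$$

The first equality holds because $q_t^\top k_i = k_i^\top q_t$ is a scalar, so it can be moved around freely. The second equality follows from the associativity of matrix multiplication.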

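Define $S_t = \sum_{i \le t} v_i k_i^\top \in \mathbb{R}^{d_v \times d_k}$. Then $o_t = S_t q_t$. This gives a nice interpretation of dot-product attention without softmax: the output at timestep $t$ is the query $q_t$ pulling information out of a hidden state $S_t$, where

$$S_t = S_{t-1} + v_t k_t^\top, \qquad S_0 = 0.$$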
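To make the equivalence concrete, here is a minimal NumPy sketch (the function names are our own, not from any library) that computes causal linear attention twice: once in matrix form with a lower-triangular 0/1 mask (note: a multiplicative mask, unlike the additive $-\infty$ mask used with softmax), and once with the recurrence over $S_t$. Rows of `Q`, `K`, `V` are the per-timestep vectors $q_t$, $k_t$, $v_t$.

```python
import numpy as np

def linear_attention_parallel(Q, K, V):
    """Matrix form: O = (Q K^T ⊙ L) V, with L the lower-triangular 0/1 causal mask."""
    T = Q.shape[0]
    causal = np.tril(np.ones((T, T)))
    return ((Q @ K.T) * causal) @ V

def linear_attention_recurrent(Q, K, V):
    """Recurrent form: S_t = S_{t-1} + v_t k_t^T and o_t = S_t q_t."""
    T, d_k = K.shape
    d_v = V.shape[1]
    S = np.zeros((d_v, d_k))          # hidden state, starts empty
    O = np.zeros((T, d_v))
    for t in range(T):
        S = S + np.outer(V[t], K[t])  # write the pair (k_t, v_t)
        O[t] = S @ Q[t]               # the query reads from the state
    return O

rng = np.random.default_rng(0)
T, d_k, d_v = 6, 4, 3
Q = rng.standard_normal((T, d_k))
K = rng.standard_normal((T, d_k))
V = rng.standard_normal((T, d_v))
print(np.allclose(linear_attention_parallel(Q, K, V),
                  linear_attention_recurrent(Q, K, V)))  # True
```

The parallel form materializes a $T \times T$ score matrix, while the recurrent form only keeps a fixed-size $d_v \times d_k$ state per step.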
Data-Dependent Decay: From Coarse to Fine-Grained

One problem with the update rule $S_t = S_{t-1} + v_t k_t^\top$ is that it does not support forgetting: every new key-value pair is added to the hidden state with the same unit weight, regardless of the data. From a memory or information-aggregation perspective, this is too coarse. On top of that, during training the norm of $S_t$ can explode and cause instability.

We want to add a gating mechanism that allows the model to forget information stored in the hidden state.

The Most Coarse Method

Add a data-independent weight decay. The update rule becomes $S_t = \gamma S_{t-1} + v_t k_t^\top$, with a fixed constant $\gamma \in (0, 1)$.

This is proposed in RetNet.
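Unrolling the decayed recurrence gives $o_t = \sum_{i \le t} \gamma^{t-i} (q_t^\top k_i)\, v_i$, so the causal mask simply becomes a decay mask $D_{ti} = \gamma^{t-i}$ for $i \le t$. The sketch below checks this numerically; it is a simplified illustration of the fixed-decay idea (hypothetical function names, a single head), not RetNet's actual implementation.

```python
import numpy as np

def retnet_style_recurrent(Q, K, V, gamma):
    """S_t = gamma * S_{t-1} + v_t k_t^T, o_t = S_t q_t (fixed, data-independent decay)."""
    T, d_k = K.shape
    d_v = V.shape[1]
    S = np.zeros((d_v, d_k))
    O = np.zeros((T, d_v))
    for t in range(T):
        S = gamma * S + np.outer(V[t], K[t])
        O[t] = S @ Q[t]
    return O

def retnet_style_parallel(Q, K, V, gamma):
    """Unrolled: o_t = sum_{i<=t} gamma^(t-i) (q_t^T k_i) v_i, i.e. O = (Q K^T ⊙ D) V."""
    T = Q.shape[0]
    idx = np.arange(T)
    dist = idx[:, None] - idx[None, :]                 # t - i
    D = np.where(dist >= 0, gamma ** np.maximum(dist, 0), 0.0)
    return ((Q @ K.T) * D) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((8, 4)) for _ in range(3))
print(np.allclose(retnet_style_recurrent(Q, K, V, 0.9),
                  retnet_style_parallel(Q, K, V, 0.9)))  # True
```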

More Fine-Grained

Add a global data-dependent weight decay.

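For instance, in mLSTM:

$$S_t = f_t S_{t-1} + i_t v_t k_t^\top,$$

where the forget gate $f_t$ and the input gate $i_t$ are scalars computed from the current input $x_t$.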
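Another case is Mamba-2:

$$S_t = \alpha_t S_{t-1} + v_t k_t^\top, \qquad \alpha_t = \exp(\Delta_t A),$$

where $A < 0$ is a learned scalar and the step size $\Delta_t > 0$ is computed from $x_t$, so $\alpha_t \in (0, 1)$ is data-dependent.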
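A scalar data-dependent gate still admits a parallel form: the decay between positions $i \le t$ is $\prod_{j=i+1}^{t} \alpha_j$, which can be computed from differences of cumulative log-gates. The sketch below assumes a Mamba-2-style parameterization $\alpha_t = \exp(\Delta_t A)$ with $\Delta_t = \mathrm{softplus}(w_\Delta^\top x_t)$; `w_delta` and `A` are hypothetical stand-ins for learned parameters, and other details of Mamba-2 are omitted.

```python
import numpy as np

def softplus(x):
    return np.logaddexp(0.0, x)  # log(1 + exp(x)), numerically stable

def scalar_gated_recurrent(Q, K, V, alpha):
    """S_t = alpha_t * S_{t-1} + v_t k_t^T with a data-dependent scalar alpha_t."""
    T, d_k = K.shape
    S = np.zeros((V.shape[1], d_k))
    O = np.zeros((T, V.shape[1]))
    for t in range(T):
        S = alpha[t] * S + np.outer(V[t], K[t])
        O[t] = S @ Q[t]
    return O

def scalar_gated_parallel(Q, K, V, alpha):
    """O = (Q K^T ⊙ D) V with D[t, i] = alpha_{i+1} * ... * alpha_t for i <= t.
    The products come from differences of cumulative log-gates."""
    T = len(alpha)
    log_cum = np.cumsum(np.log(alpha))
    diff = log_cum[:, None] - log_cum[None, :]    # sum of log alpha over (i, t]
    causal = np.tril(np.ones((T, T), dtype=bool))
    D = np.where(causal, np.exp(np.minimum(diff, 0.0)), 0.0)
    return ((Q @ K.T) * D) @ V

rng = np.random.default_rng(0)
T, d_model, d_k, d_v = 8, 5, 4, 3
X = rng.standard_normal((T, d_model))
Q, K = rng.standard_normal((T, d_k)), rng.standard_normal((T, d_k))
V = rng.standard_normal((T, d_v))
w_delta = rng.standard_normal(d_model)   # hypothetical step-size projection
A = -0.5                                 # hypothetical learned negative scalar
delta = softplus(X @ w_delta)            # Delta_t > 0, depends on x_t
alpha = np.exp(delta * A)                # alpha_t in (0, 1)
print(np.allclose(scalar_gated_recurrent(Q, K, V, alpha),
                  scalar_gated_parallel(Q, K, V, alpha)))  # True
```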
In both cases, the data-dependent decay term is a scalar: $f_t$ for mLSTM and $\alpha_t$ for Mamba-2. This means two things, viewed from different angles:

More and More Fine-Grained

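Instead of using the same scalar for every element of $S_{t-1}$, what if we used a different decay for each matrix element?

$$S_t = G_t \odot S_{t-1} + v_t k_t^\top.$$

Here we use a gate matrix $G_t$ instead of a gate scalar. However, to enable training parallelism, this gate matrix should take an outer-product form, e.g. $G_t = \mathbf{1}\,\alpha_t^\top$ with $\alpha_t \in (0, 1)^{d_k}$ computed from $x_t$, which is equivalent to $S_t = S_{t-1}\,\mathrm{Diag}(\alpha_t) + v_t k_t^\top$. I will discuss this later in the section on the chunkwise parallel form.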

Papers like GLA and GSA use this mechanism.
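Why does the outer-product form help? With $G_t = \mathbf{1}\,\alpha_t^\top$, the accumulated decay between steps $i$ and $t$ factors as $b_t / b_i$ with $b_t = \prod_{j \le t} \alpha_j$ (elementwise), so it can be folded into the queries and keys and the whole computation again becomes masked matrix products. The sketch below checks this under our own assumptions (hypothetical function names, gates drawn at random rather than computed from $x_t$); this naive factorization is numerically fragile for long sequences because the cumulative products shrink toward zero, so it only illustrates the algebra.

```python
import numpy as np

def gla_style_recurrent(Q, K, V, alpha):
    """S_t = G_t ⊙ S_{t-1} + v_t k_t^T with the outer-product gate G_t = 1 alpha_t^T.
    alpha has shape (T, d_k): one decay per key dimension, per step."""
    T, d_k = K.shape
    d_v = V.shape[1]
    S = np.zeros((d_v, d_k))
    O = np.zeros((T, d_v))
    for t in range(T):
        G = np.outer(np.ones(d_v), alpha[t])   # rank-1 gate matrix
        S = G * S + np.outer(V[t], K[t])       # same as S @ np.diag(alpha[t]) + v_t k_t^T
        O[t] = S @ Q[t]
    return O

def gla_style_parallel(Q, K, V, alpha):
    """Because G_t is an outer product, the decay factors into q and k:
    o_t = sum_{i<=t} v_i (k_i / b_i)^T (q_t * b_t), with b_t = prod_{j<=t} alpha_j."""
    T = Q.shape[0]
    B = np.cumprod(alpha, axis=0)                 # (T, d_k) cumulative decays
    scores = (Q * B) @ (K / B).T                  # (T, T) decay-aware scores
    return (scores * np.tril(np.ones((T, T)))) @ V

rng = np.random.default_rng(0)
T, d_k, d_v = 6, 4, 3
Q, K = rng.standard_normal((T, d_k)), rng.standard_normal((T, d_k))
V = rng.standard_normal((T, d_v))
alpha = rng.uniform(0.8, 1.0, size=(T, d_k))
print(np.allclose(gla_style_recurrent(Q, K, V, alpha),
                  gla_style_parallel(Q, K, V, alpha)))  # True
```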

Perspective Shift: From Hidden State with Weight Decay to Test-Time Training

Slow weights and fast weights: fast weights provide a neurally plausible way of implementing the type of temporary storage required by working memory, while slow weights capture the more permanent associations learned over many experiences.

In a linear-attention architecture, the projection matrices $W_Q, W_K, W_V$ are slow weights: they are fixed during inference and encode knowledge learned during training on large amounts of data. The hidden state $S_t$ is a fast weight: it is updated at every sequence step (whenever a new token arrives) during inference, via $S_t = S_{t-1} + v_t k_t^\top$.

From this perspective, the update rule of the hidden state in linear-attention models can be interpreted as one step of test-time training on the fast weights. Test-time training (TTT) means updating the fast weights by gradient descent on a self-supervised objective during inference; here, one SGD step per token.

TTT objective for $S_t = \alpha_t S_{t-1} + v_t k_t^\top$

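Here $S$ is the parameter being optimized. Note that $k_t$ and $v_t$ depend only on the input $x_t$ (through the slow weights), so they are constants with respect to $S$.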

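Consider the per-step linear loss

$$\mathcal{L}_t(S) = -\langle S k_t, v_t \rangle = -v_t^\top S k_t,$$

whose gradient is $\nabla_S \mathcal{L}_t = -v_t k_t^\top$. Performing one step of SGD with (decoupled) weight decay $1 - \alpha_t$ and learning rate $1$,

$$S_t = S_{t-1} - (1 - \alpha_t)\, S_{t-1} - \nabla_S \mathcal{L}_t(S_{t-1}) = \alpha_t S_{t-1} + v_t k_t^\top,$$

recovers the recurrence exactly (with $\alpha_t = 1$ giving vanilla linear attention). Equivalently, the update solves the proximal/regularized objective

$$S_t = \arg\min_S \; -v_t^\top S k_t + \tfrac{1}{2}\,\lVert S - \alpha_t S_{t-1} \rVert_F^2.$$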
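The claim can be checked numerically: the sketch below estimates the gradient of the linear loss by central finite differences, confirms it equals $-v_t k_t^\top$, and verifies that one SGD step with decoupled weight decay $1 - \alpha_t$ and learning rate 1 reproduces $\alpha_t S_{t-1} + v_t k_t^\top$. The helper names and the value $\alpha_t = 0.9$ are illustrative choices.

```python
import numpy as np

def linear_ttt_loss(S, k, v):
    # L_t(S) = -v^T S k: rewards aligning the readout S k with v
    return -v @ S @ k

def numerical_grad(f, S, eps=1e-6):
    # Central finite differences, one matrix entry at a time
    G = np.zeros_like(S)
    for idx in np.ndindex(S.shape):
        E = np.zeros_like(S)
        E[idx] = eps
        G[idx] = (f(S + E) - f(S - E)) / (2 * eps)
    return G

rng = np.random.default_rng(0)
d_k, d_v, alpha = 4, 3, 0.9
S_prev = rng.standard_normal((d_v, d_k))
k, v = rng.standard_normal(d_k), rng.standard_normal(d_v)

grad = numerical_grad(lambda S: linear_ttt_loss(S, k, v), S_prev)
print(np.allclose(grad, -np.outer(v, k), atol=1e-6))  # True

# One SGD step: decoupled weight decay (1 - alpha), learning rate 1
S_sgd = S_prev - (1 - alpha) * S_prev - 1.0 * grad
S_rec = alpha * S_prev + np.outer(v, k)
print(np.allclose(S_sgd, S_rec, atol=1e-6))  # True
```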
Interpretation:

Another TTT objective

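The main term in the previous TTT objective, $-v_t^\top S k_t$, simultaneously aligns the direction of the readout $S k_t$ with $v_t$ and inflates its norm: the loss is unbounded below, so it can always be decreased further by scaling $S$ up. But for associative-memory readout we only need the retrieved value $S k_t$ to match $v_t$. So a more natural objective is the regression loss

$$\mathcal{L}_t(S) = \tfrac{1}{2}\,\lVert S k_t - v_t \rVert^2.$$

This way we align both direction and norm; if $k_t$ is well normalized (e.g., L2-normalized), the TTT objective itself contributes no source of gradient explosion.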

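The widely adopted DeltaNet and Gated DeltaNet are built on exactly this objective. Doing one step of SGD with learning rate $\beta_t$ gives the delta rule:

$$S_t = S_{t-1} - \beta_t (S_{t-1} k_t - v_t)\, k_t^\top = S_{t-1}\left(I - \beta_t k_t k_t^\top\right) + \beta_t v_t k_t^\top.$$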
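A minimal sketch of one delta-rule step (the function name is ours), checking that the gradient-step form matches the closed form above. Assuming a unit-norm key, the residual $S k_t - v_t$ shrinks by exactly a factor $1 - \beta_t$, so the loss is multiplied by $(1 - \beta_t)^2$.

```python
import numpy as np

def delta_rule_step(S, k, v, beta):
    """One SGD step (learning rate beta) on L(S) = 0.5 * ||S k - v||^2.
    The gradient is (S k - v) k^T."""
    return S - beta * np.outer(S @ k - v, k)

rng = np.random.default_rng(0)
d_k, d_v, beta = 4, 3, 0.7
S_prev = rng.standard_normal((d_v, d_k))
k = rng.standard_normal(d_k)
k = k / np.linalg.norm(k)            # assumption: L2-normalized key
v = rng.standard_normal(d_v)

S_new = delta_rule_step(S_prev, k, v, beta)
S_closed = S_prev @ (np.eye(d_k) - beta * np.outer(k, k)) + beta * np.outer(v, k)
print(np.allclose(S_new, S_closed))  # True

def loss(S):
    return 0.5 * np.sum((S @ k - v) ** 2)

# With ||k|| = 1 the residual S k - v shrinks by exactly (1 - beta).
print(np.isclose(loss(S_new), (1 - beta) ** 2 * loss(S_prev)))  # True
```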
The Gated DeltaNet paper offers a nice intuition for this update rule:

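The delta update rule dynamically erases the value ($v_t^{\text{old}} = S_{t-1} k_t$) associated with the current input key ($k_t$) and writes a new value ($v_t^{\text{new}} = \beta_t v_t + (1 - \beta_t)\, v_t^{\text{old}}$), which is a linear combination of the current input value and the old value, weighted by the “writing strength” $\beta_t$:

$$S_t = S_{t-1} - v_t^{\text{old}} k_t^\top + v_t^{\text{new}} k_t^\top.$$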
DeltaNet has shown impressive performance on synthetic in-context retrieval benchmarks. However, since this process only modifies a single key–value pair at a time, the model lacks the ability to rapidly clear outdated or irrelevant information, especially during context switches where previous data needs to be erased. Consequently, DeltaNet has been found to perform only moderately on real-world tasks, likely due to the absence of a robust memory-clearing mechanism.

We propose the gated delta rule, $S_t = S_{t-1}\left(\alpha_t (I - \beta_t k_t k_t^\top)\right) + \beta_t v_t k_t^\top$. This unified rule enables flexible memory control: it can promptly clear memory by setting $\alpha_t \to 0$, while selectively updating specific content without affecting other information by setting $\alpha_t \to 1$ (effectively switching to the pure delta rule).
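The two limiting behaviors described in the quote, and the erase-and-write reading of the delta rule, can be checked directly. The sketch below (hypothetical function name, unit-norm key assumed) verifies that $\alpha_t = 0$ keeps only the new association, that with $\alpha_t = 1$ the readout at $k_t$ becomes $(1 - \beta_t)\, v_t^{\text{old}} + \beta_t v_t$, and that keys orthogonal to $k_t$ are left untouched, i.e. only a single key-value pair is modified.

```python
import numpy as np

def gated_delta_step(S, k, v, alpha, beta):
    """Gated delta rule: S_t = alpha_t * S_{t-1} (I - beta_t k_t k_t^T) + beta_t v_t k_t^T."""
    d_k = k.shape[0]
    return alpha * (S @ (np.eye(d_k) - beta * np.outer(k, k))) + beta * np.outer(v, k)

rng = np.random.default_rng(0)
d_k, d_v, beta = 4, 3, 0.6
S_prev = rng.standard_normal((d_v, d_k))
k = rng.standard_normal(d_k)
k /= np.linalg.norm(k)               # assumption: unit-norm key
v = rng.standard_normal(d_v)

# alpha -> 0: the old memory is cleared; only the new association remains.
print(np.allclose(gated_delta_step(S_prev, k, v, 0.0, beta), beta * np.outer(v, k)))  # True

# alpha -> 1: pure delta rule; the value stored at key k becomes v_new.
S_new = gated_delta_step(S_prev, k, v, 1.0, beta)
v_old = S_prev @ k
print(np.allclose(S_new @ k, (1 - beta) * v_old + beta * v))  # True

# A key orthogonal to k reads out exactly what it did before the update.
u = rng.standard_normal(d_k)
u -= (u @ k) * k
print(np.allclose(S_new @ u, S_prev @ u))  # True
```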

A few personal comments on combining data-dependent decay with the delta rule:

Non-Linear TTT objective

[WIP]

Hardware-Efficient Implementation: Chunk-Wise Parallel Form (CPF)

Case for understanding CPF: vanilla linear attention

[WIP]
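The basic idea of the chunkwise form for vanilla linear attention can be sketched as follows, under our own simplifying assumptions (a single head, sequence length divisible by the chunk size `C`, hypothetical function names): split the sequence into chunks, carry the state $S$ recurrently across chunks, and use the masked parallel form inside each $C \times C$ block, so that most of the work is matrix multiplication.

```python
import numpy as np

def chunkwise_linear_attention(Q, K, V, C):
    """Causal linear attention, chunk by chunk (sketch; assumes T % C == 0).
    Inter-chunk: recurrent state S = sum of v_i k_i^T over all earlier chunks.
    Intra-chunk: masked (Q K^T) V, i.e. the parallel form on a C x C block."""
    T, d_k = Q.shape
    d_v = V.shape[1]
    assert T % C == 0, "this sketch assumes T is a multiple of C"
    S = np.zeros((d_v, d_k))
    causal = np.tril(np.ones((C, C)))
    O = np.zeros((T, d_v))
    for s in range(0, T, C):
        q, k, v = Q[s:s + C], K[s:s + C], V[s:s + C]
        inter = q @ S.T                    # row t is (S q_t)^T: reads earlier chunks
        intra = ((q @ k.T) * causal) @ v   # causal attention within the chunk
        O[s:s + C] = inter + intra
        S = S + v.T @ k                    # fold this chunk's v_i k_i^T into the state
    return O

def naive_linear_attention(Q, K, V):
    # Reference: full masked parallel form
    T = Q.shape[0]
    return ((Q @ K.T) * np.tril(np.ones((T, T)))) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((12, 4)) for _ in range(3))
print(np.allclose(chunkwise_linear_attention(Q, K, V, C=4),
                  naive_linear_attention(Q, K, V)))  # True
```

Setting `C` equal to the sequence length recovers the fully parallel form, and `C = 1` recovers the step-by-step recurrence.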

A more realistic case: DeltaNet

[WIP]