year: 2019
paper: https://arxiv.org/pdf/1901.02860
website:
code:
connections: transformer, long-context, linear attention
XL … extra long (context)
TLDR: Long input sequences are processed as segments: the newest segment is processed in parallel, while the keys and values of past segments are cached so that the new segment can cross-attend to them. [1]
In the standard approach, as soon as you're out of context (OOC), you a) need to re-process the entire context window if you shift by even 1 token, b) lose all context beyond the window, and c) have no way to give chunks information about each other (context fragmentation):

TrXL instead keeps the KV from past segments (context-window chunks): the new segment is processed with full attention, while also attending to a number of cached segments:

I like this visualization from the RMT paper better:

When computing layer $n$, each segment can look at the hidden states from layer $n-1$ of previous segments.
Since we progress in time, this creates a shifting effect, which allows for much longer effective context lengths while still being efficient (no re-processing of past segments, no quadratic attention over all past tokens, no gradients through the cache, etc.). [2]
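$h_\tau^{n-1} \in \mathbb{R}^{L \times d}$ … hidden state from current (new) segment $\tau$ at layer $n-1$, with segment length $L$
$m_\tau^{n-1} \in \mathbb{R}^{M \times d}$ … cached hidden state from previous segments at layer $n-1$, with memory length $M$
$\tilde{h}_\tau^{n-1} = [\mathrm{SG}(m_\tau^{n-1}) \circ h_\tau^{n-1}] \in \mathbb{R}^{(M+L) \times d}$ … concatenation of both ($\mathrm{SG}$ = stop-gradient)
Attention is then computed as usual, but with $\tilde{h}_\tau^{n-1}$ as input for K and V:
$$q_\tau^{n} = h_\tau^{n-1} W_q, \quad k_\tau^{n} = \tilde{h}_\tau^{n-1} W_k, \quad v_\tau^{n} = \tilde{h}_\tau^{n-1} W_v$$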
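Note that the resulting attention matrix isn't square: it has one row per new token and one column per memory + new token, i.e. $L \times (M + L)$.
Each token in the new segment can attend to all tokens in the memory + the new segment, but not vice versa: the cached states are fixed and never attend to the new segment.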
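A minimal NumPy sketch of this non-square attention, assuming a single head, no relative positional encodings, and causal masking inside the new segment. The function name `attend_with_memory` and the weight matrices are made up for illustration and are not taken from the paper's code. NumPy has no autograd, so the stop-gradient on the cache is only noted in a comment.

```python
import numpy as np

def attend_with_memory(h, mem, W_q, W_k, W_v):
    """Single-head attention of a new segment over [memory, segment].

    h:   (L, d) hidden states of the new segment
    mem: (M, d) cached hidden states (constants; in a real model they
         would be wrapped in a stop-gradient)
    """
    L, M = h.shape[0], mem.shape[0]
    h_tilde = np.concatenate([mem, h], axis=0)   # (M + L, d)

    q = h @ W_q                                  # queries: new segment only
    k = h_tilde @ W_k                            # keys: memory + new segment
    v = h_tilde @ W_v                            # values: memory + new segment

    scores = q @ k.T / np.sqrt(q.shape[-1])      # (L, M + L), not square

    # Query i sits at position M + i, so keys at positions > M + i are future.
    future = np.triu(np.ones((L, M + L), dtype=bool), k=M + 1)
    scores = np.where(future, -np.inf, scores)

    # Row-wise softmax; every row has at least one visible key.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

# Illustrative sizes only.
rng = np.random.default_rng(0)
L, M, d = 3, 4, 8
h, mem = rng.normal(size=(L, d)), rng.normal(size=(M, d))
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))
out, weights = attend_with_memory(h, mem, W_q, W_k, W_v)
print(out.shape, weights.shape)  # (3, 8) (3, 7)
```

The output keeps the segment's shape `(L, d)`, while the weight matrix is `(L, M + L)`: the memory only ever shows up on the key/value side.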
With causal masking, every query sees all $M$ memory columns, followed by a lower-triangular block over the $L$ new tokens
(rows are queries from new segment, columns are keys from memory + new segment).
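To make the mask shape concrete, here is a small pure-Python sketch. The helper `trxl_mask` and the sizes `M=2`, `L=3` are chosen purely for illustration; they are not the values from the original figure.

```python
def trxl_mask(M, L):
    """Return an L x (M + L) 0/1 mask: 1 means the query may attend to the key.

    Query i (row) sits at position M + i in the concatenated sequence,
    so it may see every key j (column) with j <= M + i.
    """
    return [[1 if j <= M + i else 0 for j in range(M + L)] for i in range(L)]

for row in trxl_mask(M=2, L=3):
    print(" ".join(map(str, row)))
# 1 1 1 0 0
# 1 1 1 1 0
# 1 1 1 1 1
```

The first two columns are the memory (always visible); the remaining three form the usual causal triangle.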
$M$ is a hyperparameter trading off compute/memory and context length.
$M = 0$ is a standard transformer.
$M = \infty$ is full attention over all past tokens (unbounded, lossless read-only memory). [3]
Models trained with smaller $M$ can generalize to bigger $M$.
In the paper, they train with a shorter memory and evaluate with a longer one, with good generalization.
Bigger $L$ → more parallelism and expressivity, but higher memory per step.
Smaller $L$ → more sequential chunking, less expressivity.
In the extreme case of $L = 1$ we generate one token at a time.
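The interplay of segment length and memory length can be sketched at the segment level with plain token indices instead of hidden states. This is a toy bookkeeping example, not the paper's implementation; `visible_context` is a hypothetical helper, and token-level causal masking inside a segment is ignored.

```python
def visible_context(num_tokens, L, M):
    """For each segment, list the token indices its queries can read.

    Tokens are integers 0..num_tokens-1; the "memory" holds the indices
    of the last M tokens seen before the current segment.
    """
    memory, out = [], []
    for start in range(0, num_tokens, L):
        segment = list(range(start, min(start + L, num_tokens)))
        out.append((segment, memory + segment))
        # Roll the cache: keep only the most recent M entries.
        # (A bare [-M:] slice would keep everything when M == 0.)
        memory = (memory + segment)[-M:] if M > 0 else []
    return out

for seg, ctx in visible_context(num_tokens=8, L=2, M=3):
    print(seg, "reads", ctx)
# [0, 1] reads [0, 1]
# [2, 3] reads [0, 1, 2, 3]
# [4, 5] reads [1, 2, 3, 4, 5]
# [6, 7] reads [3, 4, 5, 6, 7]
```

With `M=0` every segment reads only itself (independent chunks), and with `M` at least as large as the sequence, every segment reads the full past.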
TrXL's attention is linear with respect to the total sequence length $T$:
Compute: $O(T \cdot (M + L))$ vs. $O(T^2)$ for standard attention
Memory: $O(L \cdot (M + L))$ vs. $O(T^2)$ for standard attention
$M + L$ is the effective receptive field (per layer), i.e. the maximum number of tokens each new token can attend to.
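A quick way to sanity-check the scaling is to count attention-score entries. The sketch below does only that (no model involved); `attention_entries` is a hypothetical helper, and the sizes are arbitrary illustrations, not settings from the paper.

```python
def attention_entries(T, L, M):
    """Count attention-score entries when T tokens are processed in segments.

    Returns (total entries over all segments, largest single score matrix).
    Causal masking is ignored; it only changes constant factors.
    """
    total, peak, cached = 0, 0, 0
    for start in range(0, T, L):
        seg = min(L, T - start)
        entries = seg * (cached + seg)   # seg queries x (memory + seg) keys
        total += entries
        peak = max(peak, entries)
        cached = min(M, cached + seg)    # memory fills up to at most M
    return total, peak

print(attention_entries(T=4096, L=128, M=384))  # (1998848, 65536)
print(attention_entries(T=4096, L=4096, M=0))   # (16777216, 16777216)
```

The segmented total stays below `T * (M + L)` and the peak matrix is `L * (M + L)`, whereas a single full-attention pass needs `T * T` entries for both.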
Relative positional encodings are used because they allow for better generalization when reusing past segments (and absolute PEs wouldn't make sense when reusing past segments anyway, since every segment would carry the same position indices).
The paper introduces its own relative PE scheme, but any relative PE should work.
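Why relative positions fit cached segments can be seen from the query-key distance matrix alone. This sketch only computes distances; it does not implement the paper's relative PE scheme, and `relative_distances` with `M=2`, `L=3` is an illustrative assumption.

```python
def relative_distances(M, L):
    """Distance (query position - key position) for each query/key pair.

    Queries are the L new tokens; keys are the M cached tokens followed by
    the L new tokens. Negative entries are future keys (masked out).
    """
    return [[(M + i) - j for j in range(M + L)] for i in range(L)]

for row in relative_distances(M=2, L=3):
    print(row)
# [2, 1, 0, -1, -2]
# [3, 2, 1, 0, -1]
# [4, 3, 2, 1, 0]
```

The matrix depends only on `M` and `L`, not on which segment is being processed, so the same positional information applies to every segment and its cache.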
Visualization of attention patterns by Gemini: https://ai.studio/apps/drive/1ssW9IcHrVQiVhtUeVvnuct6Dw-0T5Eh5?fullscreenApplet=true
Footnotes
1. “Segment-level recurrence” … in other words, “processing the current segment based on previous segments”. That's also why $\tau$ is used to index segments: like a trajectory, each segment is a sequence of tokens. Nothing to do with RNNs.
2. In comparison to RMT, Yannic says TrXL only learns how to read memory, but that's just not true? There's no explicit, separate write mechanism, and due to stop-grad also no direct/long-range credit assignment to past segments, but over training the model has to shape its hidden states such that they are useful for future segments too, so it absolutely can learn to write useful information into the cached KV states.
3. But as context length increases, the attention scores get more diffuse, making it harder to focus on relevant tokens.