Causal self-attention masks out the future tokens/elements, e.g. in LLMs: to predict the next word (aka decoding), the LLM shouldn't already see it.
Non-causal self-attention is when all elements can attend to all other elements in the sequence. This is used in the encoder part when doing machine translation or text classification.
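A minimal sketch (assuming PyTorch, single head, no learned projections): the only difference between the two is the mask applied to the attention scores before the softmax.

```python
import torch
import torch.nn.functional as F

def self_attention(x, causal: bool):
    # x: (seq_len, d) -- toy example without projections or heads
    scores = x @ x.T / x.shape[-1] ** 0.5                    # (seq_len, seq_len)
    if causal:
        seq_len = x.shape[0]
        mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
        scores = scores.masked_fill(~mask, float("-inf"))    # hide future tokens
    weights = F.softmax(scores, dim=-1)
    return weights @ x

x = torch.randn(5, 8)
decoder_style = self_attention(x, causal=True)   # token i only attends to tokens <= i
encoder_style = self_attention(x, causal=False)  # every token attends to every token
```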
Causal attention is also used for efficiency: a single training sequence can be trained on simultaneously, since every position predicts its own next token, giving len(train_seq) training examples from one sequence.
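A sketch of what that means in practice (assumed PyTorch; the logits tensor is a stand-in for a causal model's output): shifting the sequence by one gives a target for every position, and a single cross-entropy over all positions covers every prefix objective ("123 → 4", "1234 → 5", ...) in one forward pass.

```python
import torch
import torch.nn.functional as F

train_seq = torch.tensor([11, 5, 42, 7, 99])     # toy token ids
inputs, targets = train_seq[:-1], train_seq[1:]  # shifted next-token pairs

vocab_size = 128
logits = torch.randn(len(inputs), vocab_size)    # stand-in for the causal model's output

# one loss term per position = one training example per prefix of the sequence
loss = F.cross_entropy(logits, targets)
```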
However, this also restricts the early tokens during inference to only ever look backwards. As a result, what attention computes for the previous tokens stays constant!
→ You can cache all of the previously computed values. This breaks down, however, as soon as you run out of context length: the past is now different, so the cached computations/relations become invalid and you have to re-compute the entire attention matrix at every new step (like you would with unmasked attention).
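A hedged sketch of the caching idea (the `KVCache` class and its methods are illustrative, not a real library API): during decoding, each new token's key/value is appended and reused by later steps; the problem appears once the cache is full and entries have to be dropped.

```python
import torch

class KVCache:
    def __init__(self, max_len):
        self.max_len = max_len
        self.keys, self.values = [], []

    def append(self, k, v):
        if len(self.keys) == self.max_len:
            # Past the context length the naive fix is to evict the oldest entry,
            # but the remaining cached entries were computed against a context
            # that no longer exists -- this is where the cache "breaks down".
            self.keys.pop(0)
            self.values.pop(0)
        self.keys.append(k)
        self.values.append(v)

    def stacked(self):
        return torch.stack(self.keys), torch.stack(self.values)

cache = KVCache(max_len=4)
for step in range(6):
    k, v = torch.randn(8), torch.randn(8)  # per-token key/value from the model
    cache.append(k, v)
K, V = cache.stacked()                     # reused by the next decoding step
```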
If we were to keep the cache, the 2nd token would still see the first (it has at least some info about it). But once we step outside the context length, the newest token only sees the 2nd. This (window) approach does not work in practice without re-computation or attention sinks:
Transclude of Efficient-Streaming-Language-Models-with-Attention-Sinks#^2b5e75
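A rough sketch of the two eviction policies being compared (illustrative only, not the paper's code): a plain sliding window keeps only the most recent positions, while the attention-sink variant always keeps the first few "sink" tokens plus the recent window.

```python
def window_keep(cache_positions, window):
    # plain sliding window: keep only the most recent `window` positions
    return cache_positions[-window:]

def attention_sink_keep(cache_positions, window, num_sinks=4):
    # keep the first `num_sinks` positions plus the most recent positions
    return cache_positions[:num_sinks] + cache_positions[-(window - num_sinks):]

positions = list(range(12))
print(window_keep(positions, 8))          # [4, 5, ..., 11]
print(attention_sink_keep(positions, 8))  # [0, 1, 2, 3, 8, 9, 10, 11]
```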
Why not split training data into multiple batches rather than training with a causal mask?
AFAIK training RNNs with batched data requires all the samples (within the batch) to have the same length, so with an RNN you’d have to run “123 → 4” and “1234 → 5” on different batches and recompute the common states (corresponding to inputs 1, 2 and 3) whereas a transformer can optimize on both “123 → 4” and “1234 → 5” objectives within the same batch execution without having to recompute anything. At least, that is my understanding.
Autoregressive prediction is all about predicting the next token, but using those sub-sequence examples is what makes training so powerful!
→ More datapoints
→ Easy to parallelize (batch) with a causal mask