Self-attention
The transformer accepts a set of observations $x_1, \dots, x_N \in \mathbb{R}^{d}$ ($x_i$ could be a word embedding, image patch, …). $N$ is the number of tokens and $d$ is the embedding dimension of each token.
To attend to one another, each token $x_i$ (a row vector) emits a query ($q_i = x_i W_Q$) and compares it to the keys of other tokens ($k_j = x_j W_K$), in order to decide what weight to assign to their values ($v_j = x_j W_V$):

$$y_i = \sum_j a_{ij}\, v_j, \qquad a_{ij} = \kappa(q_i, k_j)$$

Here, $\kappa$ is the kernel describing the similarity of $q_i$ to $k_j$.
The resulting output values $y_i$ are a weighted sum of the values that each token encodes. Tokens with large $a_{ij}$ contribute more to the output $y_i$.
Typically, the similarity measure is the dot product with a scaling factor and a softmax, i.e. scaled dot-product attention ($a_{ij} = \operatorname{softmax}_j\!\left(q_i k_j^\top / \sqrt{d_k}\right)$, where $d_k$ is the query/key dimension).
There are also other variants which remove the softmax, e.g. forms of linear attention.
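A minimal sketch of one such softmax-free variant (linear attention with the $\operatorname{elu}(x)+1$ feature map, as in Katharopoulos et al. 2020; names and shapes here are illustrative, not from this note):

```python
import torch
import torch.nn.functional as F

def linear_attention(q, k, v, eps=1e-6):
    """Softmax-free attention: y_i = phi(q_i) (sum_j phi(k_j)^T v_j) / (phi(q_i) sum_j phi(k_j))."""
    phi_q = F.elu(q) + 1                 # positive feature map, (N, d_k)
    phi_k = F.elu(k) + 1                 # (N, d_k)
    kv = phi_k.T @ v                     # (d_k, d_v), summed over all tokens once
    z = phi_q @ phi_k.sum(dim=0)         # (N,) normalizer replacing the softmax denominator
    return (phi_q @ kv) / (z.unsqueeze(-1) + eps)

N, d = 8, 16
q, k, v = torch.randn(N, d), torch.randn(N, d), torch.randn(N, d)
y = linear_attention(q, k, v)            # (N, d), computed without an N x N attention matrix
```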
These equations can be succinctly expressed in matrix form, with all elements updated simultaneously:

$$Y = \operatorname{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$

Here, $Q$, $K$, and $V$ are matrices with rows filled by $q_i$, $k_i$, and $v_i$, respectively.
The softmax is taken independently for each row.
The result of the $\operatorname{softmax}\!\left(Q K^\top / \sqrt{d_k}\right)$ operation is also called the attention matrix.
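A minimal sketch of the matrix form in PyTorch; with default arguments the explicit computation should match `F.scaled_dot_product_attention` (shapes and names are illustrative):

```python
import math
import torch
import torch.nn.functional as F

B, N, d_k = 1, 6, 16
Q, K, V = torch.randn(B, N, d_k), torch.randn(B, N, d_k), torch.randn(B, N, d_k)

# Explicit matrix form: Y = softmax(Q K^T / sqrt(d_k)) V
A = torch.softmax(Q @ K.transpose(-2, -1) / math.sqrt(d_k), dim=-1)  # attention matrix, each row sums to 1
Y = A @ V

# The fused PyTorch op computes the same thing (no mask, no dropout)
Y_ref = F.scaled_dot_product_attention(Q, K, V)
print(torch.allclose(Y, Y_ref, atol=1e-5))  # True
```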
QKV?
We project the input nodes to Q,K,V, in order to give the scoring mechanism more expressivity.
In principle we could do without them, see message passing.
Q: Here is what I’m interested in (like text in a search bar).
K: Here is what I have (like titles, description, etc.).
V: Here is what I will communicate to you (actual content, filtered / refined by value head).
For every query, all values are added in a weighted sum, weighted by similarity between that query and the keys.
Attention matrix: queries = rows, keys = columns.
Self-attention?
What makes it self-attention is that the Q, K and V all come from the same source, the input $X$.
Keys and values can also come from an entirely separate source / set of nodes: cross-attention.
The q/k dimension (and the v dimension, if it is the same size) is sometimes referred to as the “head size”, especially in the context of MHA.
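A minimal sketch contrasting the two (all names, the sizes and `head_size = 16` are illustrative assumptions): the only difference is which source the keys and values are projected from.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

d_model, head_size = 32, 16                  # head_size = q/k (and here also v) dimension
W_q = nn.Linear(d_model, head_size, bias=False)
W_k = nn.Linear(d_model, head_size, bias=False)
W_v = nn.Linear(d_model, head_size, bias=False)

x = torch.randn(1, 10, d_model)              # tokens of the "self" source
enc = torch.randn(1, 7, d_model)             # tokens from a separate source (e.g. an encoder)

# Self-attention: Q, K and V all come from the same input x
y_self = F.scaled_dot_product_attention(W_q(x), W_k(x), W_v(x))        # (1, 10, head_size)

# Cross-attention: queries from x, but keys and values from the other source
y_cross = F.scaled_dot_product_attention(W_q(x), W_k(enc), W_v(enc))   # (1, 10, head_size)
```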
Attention is a communication mechanism / message passing scheme between nodes.
Attention can in principle be applied to any arbitrary directed graph.
In this message passing view, every node (token) has some vector of information ($x_i$) at each point in time (data-dependent), and it gets to aggregate information via a weighted sum from all the nodes that point to it.
Imagine looping through each node (a sketch of this loop follows the list):
- The query, representing what this node is looking for, gathers the keys of all edges that are inputs to this node and then calculates the unnormalized “interestingness” / compatibility with the other nodes by taking the dot product between the queries and the keys.
- We then normalize these scores and multiply them with the values of the input nodes.
- This happens in every head in parallel and in every layer in sequence, with different weights (in both cases).
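A sketch of this loop over the nodes of a fully connected graph (single head, purely illustrative, and much slower than the matrix form):

```python
import math
import torch

N, d = 6, 16
q = torch.randn(N, d)   # what each node is looking for
k = torch.randn(N, d)   # what each node has
v = torch.randn(N, d)   # what each node will communicate

y = torch.zeros(N, d)
for i in range(N):                                           # loop over receiving nodes
    scores = torch.stack([q[i] @ k[j] for j in range(N)])    # unnormalized "interestingness"
    weights = torch.softmax(scores / math.sqrt(d), dim=0)    # normalize over incoming edges
    y[i] = sum(weights[j] * v[j] for j in range(N))          # weighted sum of incoming values

# Same result as the matrix form softmax(q k^T / sqrt(d)) v
y_ref = torch.softmax(q @ k.T / math.sqrt(d), dim=-1) @ v
print(torch.allclose(y, y_ref, atol=1e-5))  # True
```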
The attention graph of encoder-transformers is fully connected.
The attention graph of the decoder is fully connected to the encoder values, and tokens are fully connected to every token that came before them (triangular attention matrix structure).
Graph of autoregressive attention (self-loops are missing from the illustration):
Attention acts over a set of vectors in a graph: There is no notion of space.
Nodes have no notion of space, e.g. where they are relative to one another.
This is why we need to encode them positionally:
Transclude of inductive-bias#^a88119
Transformers and The Bitter Lesson (from The Bittersweet Lesson)
→ If it’s possible for a human domain expert to discover from the data the basis of a useful inductive bias, it should be possible for your model to learn it from the data too, so there is no need for the inductive bias.
→ Instead focus on building biases that improve either scale, learning or search.
→ In the case of sequence models, any bias towards short-term dependency is needless, and may inhibit learning (about long-term dependency).
→ Skip connections are good because they promote learning.
→ MHSA is good because it enables the Transformer to (learn to) perform an online, feed-forward, parallel ‘search’ over possible interpretations of a sequence.
The Hamiltonian of the Hopfield network is identical to that of the Ising model, except that the interaction strength is not constant, and it is very similar to attention!
For a Hopfield network:

$$E = -\tfrac{1}{2}\sum_{i,j} w_{ij}\, s_i s_j$$

$w_{ij}$ is the weight / interaction matrix, which encodes how patterns are stored, rather than being a simple interaction constant.
In a simplified form, the update rule of the HFN is:

$$s_i \leftarrow \operatorname{sgn}\!\left(\sum_j w_{ij}\, s_j\right)$$

And for attention:

$$y_i = \sum_j \operatorname{softmax}_j\!\left(\frac{q_i k_j^\top}{\sqrt{d_k}}\right) v_j$$
Tokens in self-attention are just like spins in the Ising model, but instead of spins they are vectors in a higher-dimensional space, with all-to-all communication, self-organizing to compute the final representation.
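A toy sketch of the analogy (binary patterns, Hebbian weights; everything here is an illustrative assumption): the classical Hopfield update multiplies the state by a fixed weight matrix and takes the sign, while the attention-style (modern Hopfield) update retrieves a softmax-weighted combination of the stored patterns.

```python
import torch

d = 32
patterns = torch.sign(torch.randn(2, d))       # stored patterns xi_p in {-1, +1}^d
W = patterns.T @ patterns                      # Hebbian weight / interaction matrix (d, d)
W.fill_diagonal_(0)

s = patterns[0].clone()
s[:5] *= -1                                    # corrupt a few "spins"

# Classical Hopfield update: s <- sgn(W s), a fixed linear map plus a nonlinearity
s_next = torch.sign(W @ s)
print((s_next == patterns[0]).float().mean())  # fraction of spins recovered (usually 1.0 here)

# Attention-style (modern Hopfield) update: softmax-weighted retrieval of the stored patterns
beta = 1.0
attn = torch.softmax(beta * (patterns @ s), dim=0)  # similarity of the current state to each stored pattern
retrieved = attn @ patterns                         # weighted sum, analogous to sum_j a_ij v_j
```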
Attention matrix values “visualization”.
NOTE: The full transformer is not permutation invariant - the Set Transformer changes this.
(from set transformer)
Transclude of Relating-transformers-to-models-and-neural-representations-of-the-hippocampal-formation#^97a897
Causal and padding masks
(todo: refactor into separate note(s), or for the padding mask maybe just put it here, under implementation details, together with the batch-dim note at the bottom)
Causal attention
The transformer is trained to predict the next token in the sequence at each step (every time with the correct previous tokens, called teacher forcing), so this can be done in one go by applying a causal mask, which masks out the future tokens at each step (`torch.ones(T, T).tril(diagonal=0).bool()` passed as `attn_mask` to `F.scaled_dot_product_attention`).
The causal mask is merged with the padding mask. The padding mask cuts the attention off as soon as we reach the first padding token.
Attention matrix: queries = rows, keys = columns.
Hence it is sufficient to cut off the columns, which is saying “I don’t want anything to pay attention to these keys (padded tokens)”.
Each example across the batch dimension is of course processed independently; tokens from different batch elements never talk to each other.
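A sketch of building and merging the two masks for `F.scaled_dot_product_attention` (batch size, lengths and padding positions are made up for illustration); in a boolean mask, `True` means “may attend”:

```python
import torch
import torch.nn.functional as F

B, T, d = 2, 5, 16
q, k, v = torch.randn(B, T, d), torch.randn(B, T, d), torch.randn(B, T, d)

# Causal mask: each position may only attend to itself and earlier positions
causal = torch.ones(T, T).tril(diagonal=0).bool()            # (T, T)

# Padding mask: cut off the columns of padded key positions
lengths = torch.tensor([5, 3])                               # real (non-pad) tokens per example
padding = torch.arange(T)[None, :] < lengths[:, None]        # (B, T), True = real token
padding = padding[:, None, :]                                # (B, 1, T): broadcast over query rows

mask = causal[None, :, :] & padding                          # (B, T, T) merged boolean mask
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # padded / future keys are never attended to
```

For the purely causal case, `F.scaled_dot_product_attention` also accepts `is_causal=True`, which avoids materializing the (T, T) mask.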
References
Relating transformers to models and neural representations of the hippocampal formation
Stanford CS25: V2 I Introduction to Transformers w/ Andrej Karpathy
Karpathy notebook and YT
Karpathy Attention (timestamped)
AssemblyAI YT intro.
Deep-learning bible reference
Implementation: PyTorch Lightning Notebook with explanations.