An outer product $\mathbf{a}\mathbf{b}^\top$ results in a matrix whose $i$-th row is the second vector scaled by the $i$-th element of the first vector (entry $(i,j)$ is $a_i b_j$):
```python
torch.einsum("i, j -> ij", [torch.arange(4), torch.arange(4)])
```
```
tensor([[0, 0, 0, 0],
        [0, 1, 2, 3],
        [0, 2, 4, 6],
        [0, 3, 6, 9]])
```
→ an outer product produces a rank-1 matrix
Outer-product memory
The rank-1 matrix $\mathbf{W} = \mathbf{v}\mathbf{k}^\top$ produced by an outer product can be interpreted as storing an association from pattern $\mathbf{k}$ (key) to pattern $\mathbf{v}$ (value).
$\mathbf{W}$ is a linear map that takes a query $\mathbf{q}$ which, if similar to $\mathbf{k}$, returns something similar to $\mathbf{v}$:

$$\mathbf{W}\mathbf{q} = \mathbf{v}\mathbf{k}^\top\mathbf{q} = (\mathbf{k}^\top\mathbf{q})\,\mathbf{v}$$

If the keys are unit norm, $\mathbf{k}^\top\mathbf{q}$ is exactly the cosine similarity of $\mathbf{q}$ and $\mathbf{k}$ (times $\|\mathbf{q}\|$), so $\mathbf{W}\mathbf{q}$ is $\mathbf{v}$ scaled by that similarity.
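As a rough sketch of this single-pair case (the names and dimensions are illustrative, not from the original):

```python
import torch

d = 8
k = torch.randn(d); k = k / k.norm()              # unit-norm key
v = torch.randn(d)                                # value to associate with k

W = torch.einsum("i, j -> ij", v, k)              # rank-1 memory: W = v k^T

q = k + 0.1 * torch.randn(d); q = q / q.norm()    # noisy query, similar to k
readout = W @ q                                   # = (k . q) * v

print(torch.dot(k, q))                                           # cosine similarity of q and k
print(torch.allclose(readout, torch.dot(k, q) * v, atol=1e-6))   # True
```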
→ outer product stores association from key to value
→ inner product measures similarity of query to key
→ Together, they implement content-based addressing, aka associative memory.

For many pairs $(\mathbf{k}_i, \mathbf{v}_i)$, we stack the keys and values and build a weight matrix $\mathbf{W} = \sum_i \mathbf{v}_i\mathbf{k}_i^\top$ that stores all associations. Then

$$\mathbf{W}\mathbf{q} = \sum_i (\mathbf{k}_i^\top\mathbf{q})\,\mathbf{v}_i$$

So $\mathbf{W}\mathbf{q}$ computes the similarity of the query to all keys, and returns a similarity-weighted sum of all values.
If the keys are orthonormal and the query equals one of them, this exactly returns the corresponding value.
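A minimal sketch of the stacked case (orthonormal keys built via QR and the shapes are my assumption, just to make the exact-recall claim concrete):

```python
import torch

N, d = 4, 16
Q, _ = torch.linalg.qr(torch.randn(d, N))    # d x N matrix with orthonormal columns
K = Q.T                                      # N orthonormal keys (rows)
V = torch.randn(N, d)                        # N values (rows)

W = V.T @ K                                  # = sum_i v_i k_i^T

q = K[2]                                     # query exactly one of the stored keys
readout = W @ q                              # = sum_i (k_i . q) v_i
print(torch.allclose(readout, V[2], atol=1e-5))  # True: exact recall of the stored value
```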
Too many non-orthogonal keys clutter recall.
Mitigations include:
- normalization, scaling (e.g. NTM, SDP, layer normalization, RMS norm, qk-norm)
- applying a hard/soft threshold after reading, letting the correct item win even if noisy (Hopfield Network, SDP)
- increasing the memory capacity, i.e. the dimensionality $d$ of keys/queries (for each query, each non-matching key adds $\sim 1/\sqrt{d}$ noise to the readout → std of the total noise is $\sim\sqrt{N/d}$ for $N$ items; see the sketch after this list)
- using sparse keys/queries (e.g. locality-sensitive hashing, sparse attention, …)
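As a sanity check on the capacity argument (a rough sketch assuming random unit-norm keys; all names and numbers are illustrative):

```python
import torch

def crosstalk_std(N: int, d: int, trials: int = 200) -> float:
    """Std of the summed similarity between a query and the N-1 non-matching keys."""
    totals = []
    for _ in range(trials):
        K = torch.nn.functional.normalize(torch.randn(N, d), dim=1)  # random unit-norm keys
        q = K[0]                                 # query the first stored key
        totals.append((K[1:] @ q).sum().item())  # crosstalk from the other keys
    return torch.tensor(totals).std().item()

for d in (64, 256, 1024):
    # measured crosstalk std vs. the sqrt((N-1)/d) prediction
    print(d, round(crosstalk_std(N=32, d=d), 3), round((31 / d) ** 0.5, 3))
```

Increasing $d$ shrinks the crosstalk, which is why larger key/query dimensions (or fewer stored items) keep recall clean.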