Message passing is a generalization of the convolution operation to graphs: We aggregate information from neighbouring nodes, but neighbours can vary in number and arrangement.
The difference to GCNs: a GCN is a specific instance of message passing in which the "message" is just a degree-normalized sum of the neighbours' linearly transformed features, whereas general message passing allows arbitrary learned message and update functions and can incorporate edge features.
Transformers are GNNs, communicating via message passing on a complete graph of tokens.
Message passing layer
A message passing layer computes, for each node $u$,

$$h_u = \phi\Big(x_u,\ \bigoplus_{v \in N(u)} \psi(x_u, x_v, e_{uv})\Big)$$

where
$G = (V, E)$ … some graph.
$N(u)$ … neighbourhood of some node $u$.
$x_u$ … features of node $u$.
$e_{uv}$ … features of edge $(u, v)$ (other features can also be used).
$\psi$ … message function (differentiable / neural network).
$\phi$ … readout function (differentiable / neural network).
$\bigoplus$ … permutation-invariant aggregation operator (like sum or max).
$h_u$ … output representation for node $u$.
The equation describes pairwise communication between nodes.
Node $u$ receives messages sent by all of its immediate neighbours $v \in N(u)$. Messages are computed via the message function $\psi$, which accounts for the features of both sender and receiver.
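A minimal NumPy sketch of one such layer (my own toy code, not from any of the references; the names, dimensions, and the choice of a single ReLU layer for $\psi$ and a tanh layer for $\phi$ are illustrative assumptions, with sum as the aggregation operator):

```python
import numpy as np

# Hypothetical toy setup: 4 nodes, directed edge list as (sender v, receiver u) pairs,
# with both directions listed for an undirected graph.
edges = [(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2), (0, 3), (3, 0)]
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))            # node features x_u (dim 8)
e = rng.normal(size=(len(edges), 3))   # edge features e_uv (dim 3)

# Random (untrained) parameters for the message function psi and update function phi.
W_msg = rng.normal(size=(8 + 8 + 3, 16))   # psi: [x_u, x_v, e_uv] -> 16-dim message
W_upd = rng.normal(size=(8 + 16, 8))       # phi: [x_u, aggregated messages] -> h_u

def message_passing_layer(x, edges, e):
    n = x.shape[0]
    agg = np.zeros((n, 16))
    for k, (v, u) in enumerate(edges):               # neighbour v sends a message to u
        m_in = np.concatenate([x[u], x[v], e[k]])    # inputs to psi
        agg[u] += np.maximum(m_in @ W_msg, 0.0)      # psi = single ReLU layer; sum as aggregation
    h = np.tanh(np.concatenate([x, agg], axis=1) @ W_upd)  # phi = single tanh layer
    return h

h = message_passing_layer(x, edges, e)
print(h.shape)  # (4, 8)
```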
A simple form of "message passing" in GNNs is to aggregate information from adjacent nodes and combine it with the node's own embedding to compute its next embedding.
For a GNN with $k$ layers, each node can aggregate information from nodes up to $k$ steps away.
This can also be viewed as collecting all possible subgraphs of size $k$ and learning vector representations from the vantage point of one node or edge.
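A tiny sketch of this receptive-field effect, assuming the simplest possible layer (mean over a node and its neighbours) on a hypothetical path graph:

```python
import numpy as np

# Hypothetical path graph 0 - 1 - 2 - 3 - 4 (adjacency matrix A).
A = np.zeros((5, 5))
for u, v in [(0, 1), (1, 2), (2, 3), (3, 4)]:
    A[u, v] = A[v, u] = 1
A_hat = A + np.eye(5)  # self-loops, so each node keeps its own embedding

# One "layer" of the simplest possible aggregation: mean over {node} union neighbours.
def propagate(h):
    return (A_hat @ h) / A_hat.sum(axis=1, keepdims=True)

# Start with a one-hot signal on node 0 and watch how far it spreads per layer.
h = np.eye(5)[:, [0]]          # shape (5, 1): only node 0 is "lit up"
for k in range(1, 4):
    h = propagate(h)
    reached = np.flatnonzero(h[:, 0] > 0)
    print(f"after {k} layers, node 0's information has reached nodes {reached}")
# after 1 layer -> [0 1], after 2 -> [0 1 2], after 3 -> [0 1 2 3]
```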
For large graphs, it may become infeasible (or even detrimental; see below) to add enough layers for every node to communicate with every other node. However, you can add a global context vector / master node / global representation that acts as a bridge for communication between nodes.
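A sketch of what adding such a master node could look like, assuming a plain adjacency-matrix representation; `add_master_node` is a hypothetical helper for illustration, not from any particular library:

```python
import numpy as np

def add_master_node(A, x):
    """Augment adjacency A and node features x with a global 'master' node
    connected to every other node."""
    n = A.shape[0]
    A_aug = np.zeros((n + 1, n + 1))
    A_aug[:n, :n] = A
    A_aug[n, :n] = 1.0   # master -> every node
    A_aug[:n, n] = 1.0   # every node -> master
    x_aug = np.vstack([x, np.zeros((1, x.shape[1]))])  # master starts as a zero vector
    return A_aug, x_aug

# With the master node, any two nodes can exchange information in two message
# passing steps (node -> master -> node), regardless of graph diameter.
A = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
x = np.random.default_rng(0).normal(size=(3, 4))
A_aug, x_aug = add_master_node(A, x)
print(A_aug.shape, x_aug.shape)  # (4, 4) (4, 4)
```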
While a GNN's receptive field grows with the number of layers, having many layers can even be detrimental, as representations get "diluted" (over-smoothed) by many successive rounds of aggregation.
GNNs are surprisingly parameter efficient: the same message and update functions are shared across all nodes and edges, so the parameter count does not grow with the size of the graph.
Aggregation functions
There is no operation that is uniformly the best choice. The mean operation can be useful when nodes have a highly-variable number of neighbors or you need a normalized view of the features of a local neighborhood. The max operation can be useful when you want to highlight single salient features in local neighborhoods. Sum provides a balance between these two, by providing a snapshot of the local distribution of features, but because it is not normalized, can also highlight outliers. In practice, sum is commonly used.
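A tiny illustration of these trade-offs on made-up neighbour features:

```python
import numpy as np

# Feature vectors of one node's neighbours (hypothetical values).
neighbour_feats = np.array([[1.0, 0.0],
                            [2.0, 0.0],
                            [9.0, 0.0]])   # one outlier in the first dimension

print("sum :", neighbour_feats.sum(axis=0))    # [12. 0.] -- unnormalized, outlier-sensitive
print("mean:", neighbour_feats.mean(axis=0))   # [ 4. 0.] -- normalized view of the neighbourhood
print("max :", neighbour_feats.max(axis=0))    # [ 9. 0.] -- highlights the single salient feature
```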
If you measure the similarity of every node $i$ to its neighbours $j$ via dot products, take the softmax, and compute a weighted sum of the node features, you've almost got self-attention. The difference is that self-attention uses separate projections for Q, K, and V in order to increase the expressivity of the scoring mechanism.
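A sketch of that correspondence in NumPy, with made-up features and random (untrained) projection matrices:

```python
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

rng = np.random.default_rng(0)
x_i = rng.normal(size=8)           # features of node i
x_neigh = rng.normal(size=(5, 8))  # features of its 5 neighbours j

# "Almost self-attention": raw dot-product scores, softmax, weighted sum of features.
alpha = softmax(x_neigh @ x_i)
h_i = alpha @ x_neigh

# Actual self-attention adds separate learned projections for Q, K, and V.
d = 8
W_q, W_k, W_v = (rng.normal(size=(8, d)) for _ in range(3))
q = x_i @ W_q
K = x_neigh @ W_k
V = x_neigh @ W_v
alpha = softmax(K @ q / np.sqrt(d))
h_i_attn = alpha @ V
print(h_i.shape, h_i_attn.shape)  # (8,) (8,)
```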
Parallels to convolutions
Aggregation = downsampling / pooling.
Number of message passing layers ≈ receptive field (the analogue of kernel size / stacked convolutions).
This is all very similar to self-attention: essentially the same mechanism, just with a different aggregation scheme in place of scaled dot-product attention. And the trade-off you run into with GNNs (for large graphs, not every node receives information from every other node, even though it might need to in order to classify correctly) is exactly what the more general attention mechanism solves.
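A sketch of that view: self-attention written as message passing on the complete graph of tokens, with softmaxed scaled dot-product scores as the aggregation weights (random, untrained weights for illustration):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

n, d = 6, 16
rng = np.random.default_rng(0)
tokens = rng.normal(size=(n, d))
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))

Q, K, V = tokens @ W_q, tokens @ W_k, tokens @ W_v
weights = softmax(Q @ K.T / np.sqrt(d))   # (n, n): every token attends to every token
out = weights @ V                         # each row aggregates "messages" from all n tokens
print(out.shape)  # (6, 16)
```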
MPGNN with registers?
References
Petar Veličković, "Message Passing All the Way Up" (2022)
https://en.wikipedia.org/wiki/Graph_neural_network
"A Gentle Introduction to Graph Neural Networks", Distill (2021), https://distill.pub/2021/gnn-intro/