A linear map applied to a concatenation of two vectors decomposes into separate maps plus addition:
$$W [x; y] = W_x x + W_y y$$
where $W_x$ and $W_y$ are the left and right halves of $W$ (the columns corresponding to $x$ and $y$ respectively).
This follows directly from how matrix multiplication works: each output element is a dot product of a row of $W$ with the full input. The input is $[x; y]$, so the dot product splits along the boundary between $x$ and $y$.
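As a quick numerical check of the identity, here is a minimal NumPy sketch. The sizes `d_x`, `d_y`, `d_out` and the random inputs are illustrative assumptions, not values from the text.

```python
import numpy as np

rng = np.random.default_rng(0)
d_x, d_y, d_out = 3, 4, 5           # illustrative sizes (assumed)

W = rng.standard_normal((d_out, d_x + d_y))
x = rng.standard_normal(d_x)
y = rng.standard_normal(d_y)

# Split W by columns: the first d_x columns act on x, the rest on y.
W_x, W_y = W[:, :d_x], W[:, d_x:]

full = W @ np.concatenate([x, y])   # map applied to the concatenation
split = W_x @ x + W_y @ y           # two smaller maps plus an addition

# The two results agree up to floating-point rounding.
assert np.allclose(full, split)
```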
Efficiency for pairwise computation
Often you need $W [x_i; y_j]$ not for one pair but for every pairing of $x_i$ with $y_j$ (e.g. scoring each node against each of its neighbors in a GNN). With $n$ nodes that’s up to $n^2$ pairs.
Without the decomposition, you’d concatenate and matmul each pair separately.
With it, you matmul each side once ($W_x x_i$ for every $i$ and $W_y y_j$ for every $j$, two batched matmuls total regardless of how many pairs), then combine any pair with a single addition $W_x x_i + W_y y_j$.
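The comparison can be sketched in NumPy as follows. This is one possible arrangement, assuming the vectors are stored as rows of two arrays `X` and `Y` (names and sizes chosen here for illustration); the naive double loop is included only to confirm that both routes give the same result.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, d_out = 6, 4, 5               # illustrative sizes (assumed)

X = rng.standard_normal((n, d))     # row i is x_i
Y = rng.standard_normal((n, d))     # row j is y_j
W = rng.standard_normal((d_out, 2 * d))
W_x, W_y = W[:, :d], W[:, d:]       # left and right column halves of W

# Naive: concatenate and multiply every pair separately (n * n matmuls).
naive = np.empty((n, n, d_out))
for i in range(n):
    for j in range(n):
        naive[i, j] = W @ np.concatenate([X[i], Y[j]])

# Decomposed: one matmul per side, then a broadcast addition over all pairs.
A = X @ W_x.T                       # shape (n, d_out), row i is W_x x_i
B = Y @ W_y.T                       # shape (n, d_out), row j is W_y y_j
pairwise = A[:, None, :] + B[None, :, :]   # shape (n, n, d_out)

assert np.allclose(naive, pairwise)
```

When only some pairs are needed, such as the edges of a sparse graph, the same precomputed `A` and `B` can be indexed per pair (for example `A[i] + B[j]`) instead of materializing the full `(n, n, d_out)` array.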