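A linear map applied to a concatenation of two vectors decomposes into separate maps plus addition:

$$W\,[x; y] = W_1 x + W_2 y$$

where $W_1$ and $W_2$ are the left and right column blocks of $W$ (the columns that multiply $x$ and $y$ respectively), so $W = [\,W_1 \;\; W_2\,]$. They are literal halves only when $x$ and $y$ have the same length.

This follows directly from how matrix multiplication works: each output element is a dot product of a row of $W$ with the full input. The input is $[x; y]$, so each dot product splits along the boundary between $x$ and $y$. Writing $d_x$ for the length of $x$:

$$(W\,[x; y])_i = \sum_{j \le d_x} W_{ij}\, x_j + \sum_{j > d_x} W_{ij}\, y_{j - d_x} = (W_1 x)_i + (W_2 y)_i$$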
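The identity can be checked numerically. What follows is a minimal sketch in plain Python, where lists of rows stand in for matrices. The helper names `matvec` and `split_columns` are illustrative, as are all the values.

```python
# Minimal sketch: check that W [x; y] == W1 x + W2 y, where W1 / W2 are
# the first len(x) / remaining len(y) columns of W. Values are arbitrary.

def matvec(W, v):
    """Multiply matrix W (list of rows) by vector v."""
    return [sum(w * a for w, a in zip(row, v)) for row in W]

def split_columns(W, k):
    """Return (first k columns, remaining columns) of W."""
    return [row[:k] for row in W], [row[k:] for row in W]

W = [[1, 2, 3, 4, 5],
     [6, 7, 8, 9, 10]]    # 2 x 5, so d_x + d_y = 5
x = [1, -1]               # d_x = 2
y = [2, 0, 3]             # d_y = 3

full = matvec(W, x + y)   # list + list concatenates: [x; y]
W1, W2 = split_columns(W, len(x))
split = [a + b for a, b in zip(matvec(W1, x), matvec(W2, y))]

print(full, split)        # → [20, 45] [20, 45]
```

The split happens purely on column indices. No values of $W$ change, so the two sides agree exactly.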

Efficiency for pairwise computation

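Often you need $W\,[x_i; y_j]$ not for one pair but for every pairing of $x_i$ with $y_j$ (e.g. scoring each node against each of its neighbors in a GNN). With $n$ nodes that's up to $n^2$ pairs.

Without the decomposition, you'd concatenate $[x_i; y_j]$ and multiply by $W$ separately for each pair: one matmul per pair.
With it, you multiply each side once ($W_1 X$ and $W_2 Y$, where the columns of $X$ are the $x_i$ and the columns of $Y$ are the $y_j$; two matmuls total regardless of how many pairs). You then combine any pair with a single vector addition, $W_1 x_i + W_2 y_j$, which is column $i$ of $W_1 X$ plus column $j$ of $W_2 Y$. Take $n$ vectors on each side, $P$ pairs, $d_{\text{out}}$ rows in $W$, and $d_x, d_y$ as the lengths of the $x_i$ and $y_j$. The cost then drops from $O(P\, d_{\text{out}}(d_x + d_y))$ to $O(n\, d_{\text{out}}(d_x + d_y) + P\, d_{\text{out}})$.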
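The two strategies can be compared side by side in a minimal pure-Python sketch. Everything here is illustrative. `pairwise_naive` and `pairwise_decomposed` are hypothetical helper names, and matrices are lists of rows. The per-vector `matvec` loops stand in for the two stacked products $W_1 X$ and $W_2 Y$ that a tensor library would compute in one call each.

```python
def matvec(W, v):
    """Multiply matrix W (list of rows) by vector v."""
    return [sum(w * a for w, a in zip(row, v)) for row in W]

def pairwise_naive(W, xs, ys):
    """Concatenate and multiply for every (i, j): len(xs) * len(ys) products with W."""
    return {(i, j): matvec(W, x + y)
            for i, x in enumerate(xs) for j, y in enumerate(ys)}

def pairwise_decomposed(W, xs, ys):
    """Project each side once, then combine any pair with one vector addition."""
    k = len(xs[0])                      # d_x: width of the left block W1
    W1 = [row[:k] for row in W]
    W2 = [row[k:] for row in W]
    px = [matvec(W1, x) for x in xs]    # stacked as columns, this is W1 X
    py = [matvec(W2, y) for y in ys]    # stacked as columns, this is W2 Y
    return {(i, j): [a + b for a, b in zip(px[i], py[j])]
            for i in range(len(xs)) for j in range(len(ys))}

W = [[1, 2, 3],
     [4, 5, 6]]                         # d_out = 2, d_x = 2, d_y = 1
xs = [[1, 0], [0, 1], [1, 1]]
ys = [[1], [2]]

naive = pairwise_naive(W, xs, ys)
fast = pairwise_decomposed(W, xs, ys)
print(naive == fast)                    # → True
print(fast[(2, 1)])                     # → [9, 21]
```

In a GNN only neighbor pairs matter. There, the precomputed `px` and `py` would be indexed per edge $(i, j)$ instead of looping over all pairs. The projections are still computed once per node, so only the cheap additions scale with the number of edges.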