GRU

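The Gated Recurrent Unit (GRU) is a gated RNN cell that combines the LSTM's forget and input gates into a single update gate and merges the cell state and hidden state into one state vector. A second gate, the reset gate, controls how much of the previous state feeds into the candidate state.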
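The gate merging can be written out as a sketch in the notation of the colah post linked below. Here $f_t$ and $i_t$ are the LSTM forget and input gates, $C_t$ is the LSTM cell state, and $\tilde{C}_t$, $\tilde{h}_t$ are the candidate states (these symbols are not defined elsewhere in these notes). Tying $f_t = 1 - z_t$ and $i_t = z_t$, and using the hidden state as the cell state, turns the LSTM cell update into the GRU state update:

```latex
\begin{aligned}
\text{LSTM:}\quad & C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \\
\text{GRU:}\quad  & h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t
\end{aligned}
```

Because the two coefficients sum to 1, the GRU can only keep more of the old state by taking in less of the candidate, whereas the LSTM's forget and input gates are independent.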

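$x_t$: input vector at time $t$
$h_{t-1}, h_t$: previous and current hidden states
$\tilde{h}_t$: candidate hidden state (h_hat in the code below)
$z_t, r_t$: update and reset gates (values between 0 and 1)
$\sigma, \tanh$: sigmoid, tanh
$\odot$: elementwise multiplication
$W_z, W_r, W$: input weight matrices
$U_z, U_r, U$: hidden-state (recurrent) weight matrices
$b_z, b_r, b$: bias vectors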
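With these symbols, one GRU step reads as follows. This is the convention of the colah post linked below, where $z_t$ weights the new candidate; other sources, including PyTorch's nn.GRU, swap the roles of $z_t$ and $1 - z_t$ in the last line.

```latex
\begin{aligned}
z_t &= \sigma(W_z x_t + U_z h_{t-1} + b_z) \\
r_t &= \sigma(W_r x_t + U_r h_{t-1} + b_r) \\
\tilde{h}_t &= \tanh\big(W x_t + U (r_t \odot h_{t-1}) + b\big) \\
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t
\end{aligned}
```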

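Following the notation of the post linked below, each $W$/$U$ pair is merged into a single weight matrix that acts on the concatenated vector $[h_{t-1}, x_t]$; for the update gate, $[U_z \; W_z]\,[h_{t-1}; x_t] = U_z h_{t-1} + W_z x_t$. Biases are, as so often, omitted from the notation for simplicity (the nn.Linear layers below do include them).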
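In that compact form (biases dropped, $[\cdot\,, \cdot]$ denoting concatenation, and $W_z$, $W_r$, $W$ now standing for the merged matrices $[U_z \; W_z]$, $[U_r \; W_r]$, $[U \; W]$), the GRU step becomes:

```latex
\begin{aligned}
z_t &= \sigma\big(W_z \cdot [h_{t-1}, x_t]\big) \\
r_t &= \sigma\big(W_r \cdot [h_{t-1}, x_t]\big) \\
\tilde{h}_t &= \tanh\big(W \cdot [r_t \odot h_{t-1}, x_t]\big) \\
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t
\end{aligned}
```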

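This is also how such a cell is typically implemented: stacking the gate matrices and concatenating the inputs takes 2 matmuls per step instead of 6 (one each for $W_z$, $U_z$, $W_r$, $U_r$, $W$, $U$). PyTorch's built-in nn.GRUCell fuses its weights too, but differently: one matrix over $x_t$ (weight_ih) and one over $h_{t-1}$ (weight_hh), each covering all three blocks, with $r_t$ applied after the hidden-state matmul rather than before it. The version below follows the concatenated notation: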

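import torch
import torch.nn as nn

class GRUCellDiagram(nn.Module):
	"""GRU cell in the concatenated [h, x] form: 2 matmuls per step."""
	def __init__(self, input_size, hidden_size):
		super().__init__()
		# [U_r W_r] and [U_z W_z] stacked: one matmul gives both gate pre-activations
		self.gates = nn.Linear(input_size + hidden_size, 2 * hidden_size)
		# [U W] applied to [r_t * h_prev, x_t] for the candidate state
		self.cand = nn.Linear(input_size + hidden_size, hidden_size)

	def forward(self, x_t, h_prev):  # (batch, input_size), (batch, hidden_size)
		hx = torch.cat([h_prev, x_t], dim=-1)
		r_t, z_t = self.gates(hx).chunk(2, dim=-1)  # first half reset, second half update
		r_t = torch.sigmoid(r_t)
		z_t = torch.sigmoid(z_t)

		# Reset gate decides how much of the old state enters the candidate
		rhx = torch.cat([r_t * h_prev, x_t], dim=-1)
		h_hat = torch.tanh(self.cand(rhx))

		# Update gate interpolates between the old state and the candidate
		h_t = (1.0 - z_t) * h_prev + z_t * h_hat
		return h_t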
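To check that the fused 2-matmul form computes the same thing as the 6-matrix form, here is a minimal sketch in NumPy (swapped in for PyTorch so it runs with NumPy alone). The helpers gru_step_separate, fuse and gru_step_fused are hypothetical names for illustration. It assumes unbatched column vectors (W @ x rather than the batched x @ W.T of nn.Linear) and stacks the reset gate first to mirror the chunk order in GRUCellDiagram.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def gru_step_separate(x, h, p):
    """Six matmuls: separate input (W) and hidden (U) matrices per gate and candidate."""
    z = sigmoid(p["W_z"] @ x + p["U_z"] @ h + p["b_z"])
    r = sigmoid(p["W_r"] @ x + p["U_r"] @ h + p["b_r"])
    h_hat = np.tanh(p["W"] @ x + p["U"] @ (r * h) + p["b"])
    return (1.0 - z) * h + z * h_hat


def fuse(p):
    """Merge each U/W pair into one matrix acting on [h, x]; stack reset over update."""
    W_gates = np.vstack([
        np.hstack([p["U_r"], p["W_r"]]),
        np.hstack([p["U_z"], p["W_z"]]),
    ])
    b_gates = np.concatenate([p["b_r"], p["b_z"]])
    W_cand = np.hstack([p["U"], p["W"]])
    return W_gates, b_gates, W_cand, p["b"]


def gru_step_fused(x, h, fused):
    """Two matmuls, mirroring self.gates and self.cand in GRUCellDiagram."""
    W_gates, b_gates, W_cand, b_cand = fused
    gates = sigmoid(W_gates @ np.concatenate([h, x]) + b_gates)
    r, z = np.split(gates, 2)  # first half reset, second half update
    h_hat = np.tanh(W_cand @ np.concatenate([r * h, x]) + b_cand)
    return (1.0 - z) * h + z * h_hat


# Random parameters for a toy cell (sizes are arbitrary).
rng = np.random.default_rng(0)
input_size, hidden_size = 3, 4
p = {
    "W_z": rng.standard_normal((hidden_size, input_size)),
    "W_r": rng.standard_normal((hidden_size, input_size)),
    "W": rng.standard_normal((hidden_size, input_size)),
    "U_z": rng.standard_normal((hidden_size, hidden_size)),
    "U_r": rng.standard_normal((hidden_size, hidden_size)),
    "U": rng.standard_normal((hidden_size, hidden_size)),
    "b_z": rng.standard_normal(hidden_size),
    "b_r": rng.standard_normal(hidden_size),
    "b": rng.standard_normal(hidden_size),
}

x = rng.standard_normal(input_size)
h = np.tanh(rng.standard_normal(hidden_size))  # a plausible previous state in (-1, 1)

h_sep = gru_step_separate(x, h, p)
h_fused = gru_step_fused(x, h, fuse(p))
print(np.allclose(h_sep, h_fused))  # True
```

Since $z_t$ lies in (0, 1) elementwise, each new state is a convex combination of the previous state and the tanh candidate, so it stays in (-1, 1) whenever $h_{t-1}$ does.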

https://colah.github.io/posts/2015-08-Understanding-LSTMs/