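Goal: compute $\log \sum_i \exp(x_i)$
Problem: numerical stability for large positive or negative $x_i$
Solution: $\log \sum_i \exp(x_i) = a + \log \sum_i \exp(x_i - a)$
where $a = \max_i x_i$
This way:
- all exponents are $x_i - a \le 0$
- one term will have exponent $0$, where $\exp(0) = 1$
- all other terms satisfy $0 \le \exp(x_i - a) \le 1$ → small numbers, but no overflow, and the sum is at least $1$, so underflow in individual terms never produces $\log 0$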
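As a minimal sketch of the max-shift idea (plain Python, standard library only; the function name `logsumexp` and the sample inputs are illustrative assumptions, not part of the notes), subtracting the maximum `a = max(xs)` before exponentiating keeps every exponent at or below zero:

```python
import math

def logsumexp(xs):
    """Stable log(sum(exp(x) for x in xs)) via the max-shift trick."""
    a = max(xs)            # shift so the largest exponent becomes 0
    if math.isinf(a):      # all entries -inf, or some entry +inf: x - a would be nan
        return a
    # Every exponent x - a is <= 0, and the term with x == a contributes exp(0) = 1,
    # so the sum lies in [1, len(xs)] and the log is always finite.
    return a + math.log(sum(math.exp(x - a) for x in xs))

# Naive evaluation fails here: math.exp(1000.0) raises OverflowError.
print(logsumexp([1000.0, 1001.0, 1002.0]))

# Naive evaluation fails here too: every exp underflows to 0.0, giving log(0).
print(logsumexp([-1000.0, -1001.0]))
```

Individual terms may still underflow to zero, but that is harmless: the largest term is exactly 1, so the argument of the logarithm never vanishes.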
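import torch
# Log-softmax: subtract the row-wise log-sum-exp (logits assumed to have shape (batch, classes)).
normalized_logits = logits - logits.logsumexp(dim=1, keepdim=True)
softmax = normalized_logits.exp()  # equals exp(logits) / sum(exp(logits)) along dim=1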
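The following check is a sketch under assumed inputs (the example `logits` tensor and the tolerance `atol=1e-4` are choices made here, not from the notes). It compares a naive softmax with the log-sum-exp version on extreme values:

```python
import torch

# Hypothetical batch with very large and very small logits.
logits = torch.tensor([[1000.0, 1001.0, 1002.0],
                       [-1000.0, -1001.0, -1002.0]])

# Naive softmax: exp(1000) overflows to inf and exp(-1000) underflows to 0,
# so the rows evaluate to inf/inf and 0/0, which are both nan.
naive = logits.exp() / logits.exp().sum(dim=1, keepdim=True)

# Stable version: subtract the row-wise log-sum-exp, then exponentiate.
log_probs = logits - logits.logsumexp(dim=1, keepdim=True)
stable = log_probs.exp()

print(torch.isnan(naive).all())   # every naive entry is nan
print(stable.sum(dim=1))          # each row sums to approximately 1
# Loose tolerance: in float32, subtracting a value near 1000 costs a few digits.
print(torch.allclose(stable, torch.softmax(logits, dim=1), atol=1e-4))
```

In practice `torch.softmax` and `torch.log_softmax` perform an equivalent stabilization in a single call, so the manual form is mainly useful when the log-probabilities are needed as an intermediate result.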