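Goal: compute $\log \sum_{i=1}^{n} \exp(x_i)$
Problem: numerical stability for large positive or negative $x_i$ ($\exp$ overflows for large positive inputs and underflows to 0 for large negative ones)
Solution:

$$\log \sum_{i=1}^{n} \exp(x_i) = a + \log \sum_{i=1}^{n} \exp(x_i - a)$$

where $a = \max_i x_i$

This way:

  • all exponents $x_i - a$ are $\le 0$
  • one term will have exponent $0$, where $\exp(0) = 1$
  • → the other terms may be small numbers, but the sum is at least 1, so no underflow/overflow in the result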
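
The max-shift trick can be sketched in plain Python as follows. This is a minimal illustration rather than a library implementation: the names `logsumexp` and `naive_logsumexp` are chosen here for the example, and the handling of infinite inputs is an added assumption about how one might want edge cases to behave.

```python
import math

def naive_logsumexp(xs):
    # Direct evaluation: math.exp raises OverflowError for large inputs
    # and underflows to 0.0 for very negative ones.
    return math.log(sum(math.exp(x) for x in xs))

def logsumexp(xs):
    """Compute log(sum(exp(x) for x in xs)) by shifting by the maximum."""
    a = max(xs)
    if math.isinf(a):
        # Assumption for this sketch: return the infinite maximum directly,
        # since shifting by it would produce inf - inf = nan.
        return a
    # Every exponent x - a is <= 0, and the maximum contributes exp(0) = 1,
    # so the sum lies in [1, len(xs)] and the log is well defined.
    return a + math.log(sum(math.exp(x - a) for x in xs))

large = [1000.0, 1000.0]
small = [-1000.0, -1000.0]
print(logsumexp(large))  # close to 1000 + log(2)
print(logsumexp(small))  # close to -1000 + log(2)
```

On the same inputs, `naive_logsumexp(large)` raises `OverflowError` and `naive_logsumexp(small)` raises `ValueError`, because the sum underflows to 0 before the logarithm is taken.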

# PyTorch; assumes logits is a 2-D tensor of shape (batch, classes)
normalized_logits = logits - logits.logsumexp(dim=1, keepdim=True)  # log-softmax
softmax = normalized_logits.exp()  # exp(logits) / sum(exp(logits)) along dim=1
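
To see why this route to softmax is stable, here is a dependency-free sketch of the same two steps for a single list of logits. The function names `log_softmax` and `softmax` are picked for this example and are not the PyTorch API; the sample logits are made up.

```python
import math

def log_softmax(logits):
    # Step 1: subtract log-sum-exp, computed with the max-shift trick.
    a = max(logits)
    lse = a + math.log(sum(math.exp(x - a) for x in logits))
    return [x - lse for x in logits]

def softmax(logits):
    # Step 2: exponentiate the normalized logits. Each is <= 0,
    # so math.exp cannot overflow here.
    return [math.exp(v) for v in log_softmax(logits)]

# Hypothetical logits large enough that math.exp(x) alone would overflow.
probs = softmax([1000.0, 1001.0, 1002.0])
print(probs)
print(sum(probs))
```

Softmax is invariant to adding a constant to every logit, so these probabilities should match those for `[0.0, 1.0, 2.0]` up to rounding.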

References

logarithm, exponential, softmax

https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/