최근 수업 중 custom loss를 구현하는 과제에서 알게 된 trick인데 softmax 연산을 수행할 때 내부적으로 overflow나 underflow를 방지하기 위한 방법으로 이 방식을 사용한다고 한다. 이름에 나오듯 Log와 Summation 그리고 Exponential을 이용한 트릭이다.
exp를 기존 방식대로 계산해보면 아래와 같이 overflow가 발생하게 된다.
import numpy as np
x = np.array([1000, 1000, 1000])
print(np.exp(x))
# [inf inf inf]
x와 같은 logit(입력값)을 구해 softmax 연산을 위해 적용하였을 때 overflow가 발생하게 되고 이는 결국 loss가 발산하거나 backpropagation이 실패하는 결과로 이어질 수 있다. 물론 x와 같이 극단적인 크기의 logit이 발생하는 경우는 정규화 등을 통해 방지할 수 있겠지만, 저런 식으로 발생할 수밖에 없는 경우도 있을 수 있기에 logsumexp를 이용해 사전에 방지할 수 있다.
Softmax와 LogSumExp
Softmax의 수식은 다음과 같다.
$$ p_i = \text{softmax}(x_i) = \frac{\exp(x_i)}{\sum_{n=1}^{N} \exp(x_n)} $$
여기서 \( p_i \)는 다음과 같이 정의된다.
$$ \sum_{n=1}^{N} p_i = 1 $$
이제 LogSumExp의 수식을 살펴보자.
$$ \mathrm{LSE}(x_1, \dots, x_N) = \log\left( \sum_{n=1}^{N} \exp(x_n) \right) $$
이름 그대로 Log, Summation, Exponential 순서로 전개된다.
Softmax 수식 유도
이번엔 자주 사용하는 softmax를 LSE를 사용하는 형태로 유도해 보자. 사실 간단하게 softmax 수식에 양변에 log를 취하면 되지만 절차에 따라 조금 더 상세히 살펴보려한다.
$$ p_i = \frac{\exp(x_i)}{\sum_{n=1}^{N} \exp(x_n)} $$
위 softmax 정의에서 denominator를 양변에 곱해주면 다음과 같다.
$$ \exp(x_i) = p_i \sum_{n=1}^{N} \exp(x_n) $$
여기서 양변에 log를 취해주면 다음과 같이 정리된다.
\begin{align*}
\log(\exp(x_i)) = \log(p_i \sum_{n=1}^{N} \exp(x_n)) \\
x_i = log(p_i) + \log(\sum_{n=1}^{N} \exp(x_n)))
\end{align*}
이제 \( log(p_i) \)에 대한 수식으로 정리하고 양변에 exp를 취해주면 다음과 같다.
\begin{align*}
\log(p_i) &= x_i - \log\left( \sum_{n=1}^{N} \exp(x_n) \right) \\
p_i &= \exp\left( x_i - \log \sum_{n=1}^{N} \exp(x_n) \right) \\
&= \exp\left( x_i - \mathrm{LSE}(x_1, \dots, x_N) \right)
\end{align*}
지금까지의 수식 전개는 softmax와 LSE 간의 수학적 관계를 나타낸 것이며, 해당 수식 전개는 단순한 수학적 정리일 뿐이다.
이제 trick에 대한 내용을 알아보자.
LogSumExp trick
LogSumExp trick을 위해선 overflow 방지용 수치 안정화(numerical stability)가 핵심이다.
\begin{align*}
y &= \log\left( \sum_{n=1}^{N} \exp(x_n) \right) \\
e^y &= \sum_{n=1}^{N} \exp(x_n) \\
e^y &= \sum_{n=1}^{N} \exp(x_n - c + c), \quad \exp(x_n - c + c) = e^{x_n-c} \cdot e^c \\
e^y &= e^c \sum_{n=1}^{N} \exp(x_n - c) \\
y &= c + \log \sum_{n=1}^{N} \exp(x_n - c)
\end{align*}
이렇게 전개하면 결과적으로 마지막에 c가 나오게 된다. 여기서 c는 현재 들어온 입력 벡터의 최댓값이 된다.
$$ c = \max\{x_1, \dots, x_N\} $$
실제로 구현하여 PyTorch에서 제공해 주는 logsumexp 메소드의 결과와 비교해 보자.
import numpy as np
import torch
def logsumexp(x):
c = max(x)
return c + np.log(np.sum(np.exp(x - c)))
x = np.array([1000, 1000, 1000])
print(logsumexp(x))
# 1001.0986122886682
print(torch.logsumexp(torch.tensor(x), dim=0).item())
# 1001.0986328125
# softmax 연산
print(np.exp(x - logsumexp(x)))
# [0.33333333 0.33333333 0.33333333]
결과에선 소수점 이하에서 미세한 차이가 발생하는 것을 확인할 수 있다. 이는 NumPy는 기본적으로 float64(배정밀도)를 사용하고 PyTorch는 기본적으로 float32(단정밀도)를 사용하기 때문이다.
참고 자료
https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
https://blog.naver.com/pkk1113/221327027304
https://heygeronimo.tistory.com/88