Contents
- 개념
- GeLU/ReLU vs. GLU
- GLU의 종류들
- SwiGLU 수식
- Swish/SiLU의 장점
- 2/3 rule (for 파라미터/연산량 맞추기)
- 코드
1. 개념
GLU (gated linear unit) 계열의 activation function
Gate 함수: Swish/SiLU
- \[\sigma_{\text{SiLU}}(z)=z\cdot\text{sigmoid}(z)\]
PaLM (구글), LLaMA (메타) 등 LLM에서 일반적으로 사용
2. GeLU/ReLU vs. GLU
[1] 표준 FFN(GeLU/ReLU)
- \(\text{FFN}(x)=W_2\,\phi(W_1 x+b_1)+b_2\),
- \(\phi=\text{GeLU/ReLU}\).
- hidden dim: \(d_{\text{ff}} \approx 4\,d_{\text{model}}\).
[2] GLU 계열
- “gate”를 도입함
- 직관: gate가 정보 선택을 도와 더 표현력이 좋아짐
-
첫 번째 선형을 두 갈래로 나눠!
- 하나는 내용
- 하나는 gate
\(\rightarrow\) element-wise multiplication으로 결합
- \(\text{GLU}(x)=\big(W_{v}x+b_v\big)\;\odot\;g\big(W_{g}x+b_g\big)\).
- \(g(\cdot)\): gate activation
3. GLU의 종류들
- ReGLU: \(g=\text{ReLU}\)
- Rectified Linear Unit
- \(\sigma_{\text{ReLU}}(z) = \max(0, z)\).
- GEGLU: \(g=\text{GeLU}\)
- Gaussian Error Linear Unit
- \(\sigma_{\text{GeLU}}(z) = z \cdot \Phi(z)\).
- \(\Phi(z)\)는 Gaussian distn의 CDF
- \(\sigma_{\text{GeLU}}(z) \approx 0.5 \, z \left(1 + \tanh\!\Big(\sqrt{\tfrac{2}{\pi}} \,\big(z + 0.044715 z^3\big)\Big)\right)\).
- SwiGLU: \(g=\text{Swish/SiLU}\).
- Sigmoid Linear Unit
- \[\sigma_{\text{SiLU}}(z)=z\cdot\text{sigmoid}(z)\]
4. SwiGLU 수식
Notation
- 입력 \(x\in\mathbb{R}^{d}\)
- Hidden dimension: \(m\)
\(\text{SwiGLU}(x)=\underbrace{W_v x + b_v}{\in\mathbb{R}^{m}} \;\odot\; \underbrace{\text{SiLU}(W_g x + b_g)}{\in\mathbb{R}^{m}}\),
- where \(\text{SiLU}(z)=z\cdot\sigma(z)\).
최종 FFN 출력:
- \(\text{FFN}_{\text{SwiGLU}}(x)=W_o\,\text{SwiGLU}(x)+b_o,\quad W_o\in\mathbb{R}^{d\times m}\).
5. Swish/SiLU의 장점
-
부드러운 비선형성(smooth, non-monotonic): ReLU 대비 “음수 영역”도 연속적으로 보정
→ 미분/최적화가 안정적!
-
LLM에서 경험적으로 성능 우수
- Shazeer(2020) “GLU Variants Improve Transformer”
6. 2/3 rule (for 파라미터/연산량 맞추기)
[When?] ReGLU, GEGLU, SwiGLU 같은 GLU 계열 활성화 함수 사용 시
[Why?] Parameter 수를 기존 GeLU-FFN과 맞추기 위해 hidden dimension을 줄이는 경험적 규칙
[From] Shazeer (2020) “GLU Variants Improve Transformer”
(1) 표준 Transformer FFN (예: GeLU)
- 입력 차원 \(d_{\text{model}}, hidden 차원 d_{\text{ff}} \approx 4d_{\text{model}}\).
- 파라미터 수 (bias 무시):
- \(\underbrace{d_{\text{model}} \times d_{\text{ff}}}{W_1} + \underbrace{d_{\text{ff}} \times d_{\text{model}}}{W_2} = 2 d_{\text{model}} \, d_{\text{ff}}\).
(2) GLU 계열 FFN (ReGLU, GEGLU, SwiGLU)
- 첫 선형층에서 두 갈래 (value branch + gate branch)를 만들어 element-wise multiplication
- 중간 차원 m일 때 파라미터 수:
- \(\underbrace{d_{\text{model}} \times m}{W_v} + \underbrace{d_{\text{model}} \times m}{W_g} + \underbrace{m \times d_{\text{model}}}{W_o} = 3 d_{\text{model}} \, m\).
2/3 규칙 도출
GLU FFN의 파라미터 수를 표준 FFN과 같게 맞추려면 …
-
\(3 d_{\text{model}} m = 2 d_{\text{model}} d_{\text{ff}}\).
→ \(m = \tfrac{2}{3} d_{\text{ff}}\)
즉, hidden dimension을 2/3로 줄이면 GLU FFN과 GeLU FFN의 파라미터 수가 동일해짐
적용 예시
- LLaMA, PaLM: FFN에서 SwiGLU + 2/3 규칙을 채택.
- LLaMA 계열: RMSNorm + SwiGLU를 조합하는 패턴이 흔함
- 이렇게 하면 파라미터 수, 연산량은 유지하면서도 성능은 GLU 계열 이득을 가져갑니다.
7. 코드
import torch
import torch.nn as nn
import torch.nn.functional as F
class SwiGLU_FFN(nn.Module):
def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.0, match_params_with_gelu: bool = True):
"""
d_model: 모델 차원
d_ff: 표준 FFN의 중간 차원(기본 4*d_model 가정 시 None)
match_params_with_gelu: True면 2/3 규칙을 적용해 파라미터 수를 표준 FFN과 맞춤
"""
super().__init__()
if d_ff is None:
d_ff = 4 * d_model # 표준 설정 가정
m = int((2/3) * d_ff) if match_params_with_gelu else d_ff # 2/3 규칙
self.W_v = nn.Linear(d_model, m, bias=True)
self.W_g = nn.Linear(d_model, m, bias=True)
self.W_o = nn.Linear(m, d_model, bias=True)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
v = self.W_v(x) # 값(branch)
g = self.W_g(x) # 게이트(branch)
# SiLU/Swish: z * sigmoid(z)
y = v * F.silu(g)
y = self.W_o(y)
return self.dropout(y)