RMSNorm
- https://arxiv.org/pdf/1910.07467 (NeurIPS 2019)
Contents
- 개념
- 수식
- 장점
- 위치
- LN vs. RMSNorm
- Code
1. 개념
한 줄 요약: RMSNorm = 평균 빼지 않고 RMS로만 스케일 정규화 + 학습 가능한 gain.
RMSNorm (Root Mean Square Layer Normalization)
- Transformer에서 자주 쓰이는 정규화
- 평균을 빼지 않고, 마지막 차원(hidden dim)에서 RMS로만 스케일을 맞춤.
LayerNorm vs. RMSNorm
- LayerNorm: \((x-\mu)/\sqrt{\sigma^2+\varepsilon}\).
- RMSNorm: \(x/\sqrt{\text{mean}(x^2)+\varepsilon}\).
Summary
- 분산(=중심화된 2차 모멘트)을 사용 X
- 0을 기준으로 한 2차 모멘트만 사용. O
\(\rightarrow\) 계산이 더 단순 & LLM에서는 속도·안정성·메모리 측면에서 더 선호됨 (e.g., LLaMA 계열).
2. 수식
(배치의 한 토큰 벡터 \(x\in\mathbb{R}^{d}\)에 대해)
\(\text{RMSNorm}(x)=g\odot\frac{x}{\text{rms}(x)}\),
-
\(\text{rms}(x)=\sqrt{\frac{1}{d}\sum_{i=1}^{d}x_i^2+\varepsilon}\).
-
\(g\in\mathbb{R}^{d}\): Learnable한 스케일(“gain”) 파라미터 (보통 bias 없음)
3. 장점
-
효율성:
- 평균 계산·중심화가 없음
- 연산과 메모리 접근이 줄고, 저정밀(FP16/BF16)에서도 수치적으로 덜 까다로움
-
대형 모델에서 안정적:
-
Residual 경로가 많은 LLM(Pre-Norm)에서 부드러운 스케일링만으로도 충분히 학습
-
LLaMA 등에서 RMSNorm(+SwiGLU)이 기본 조합으로 자리잡음
-
참고)
- 평균을 빼지 않으므로 평균 시프트(mean shift) 는 그대로 남아있음
- 하지만 Residual/게이팅 구조와 결합될 때 보통 문제가 되지 X
4. 위치
Pre-Norm 구조 예:
# x: (B, T, D)
x = x + Attn( RMSNorm(D)(x) )
x = x + MLP ( RMSNorm(D)(x) )
- LayerNorm을 RMSNorm으로 치환할 때 보통 그대로 대체
- MLP는 SwiGLU와 함께 쓰는 패턴이 흔함 (LLaMA: RMSNorm + SwiGLU).
5. LayerNorm vs. RMSNorm
항목 | LayerNorm | RMSNorm |
---|---|---|
중심화(mean subtraction) | 있음 | 없음 |
정규화 기준 | \(\sqrt{\sigma^2+\varepsilon}\) | \(\sqrt{E[x^2]+\varepsilon}\) |
파라미터 | \(\gamma,\beta\) (보통 둘 다) | \(g\) (보통 bias 없음) |
수치/성능 | 표준 선택지 | 대형 LLM에서 자주 더 빠르고 안정적 |
구현 난도 | 보통 | 더 단순 |
6. Code
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim)) # gain
# bias는 보통 두지 않습니다.
def forward(self, x):
# x: (B, T, D) 가정, 마지막 차원 D에 대해 정규화
rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
x_hat = x / rms
return x_hat * self.weight