RMSNorm

  • https://arxiv.org/pdf/1910.07467 (NeurIPS 2019)


Contents

  1. 개념
  2. 수식
  3. 장점
  4. 위치
  5. LN vs. RMSNorm
  6. Code


1. 개념

한 줄 요약: RMSNorm = 평균 빼지 않고 RMS로만 스케일 정규화 + 학습 가능한 gain.


RMSNorm (Root Mean Square Layer Normalization)

  • Transformer에서 자주 쓰이는 정규화
  • 평균을 빼지 않고, 마지막 차원(hidden dim)에서 RMS로만 스케일을 맞춤.


LayerNorm vs. RMSNorm

  • LayerNorm: \((x-\mu)/\sqrt{\sigma^2+\varepsilon}\).
  • RMSNorm: \(x/\sqrt{\text{mean}(x^2)+\varepsilon}\).


Summary

  • 분산(=중심화된 2차 모멘트)을 사용 X
  • 0을 기준으로 한 2차 모멘트만 사용. O

\(\rightarrow\) 계산이 더 단순 & LLM에서는 속도·안정성·메모리 측면에서 더 선호됨 (e.g., LLaMA 계열).


2. 수식

(배치의 한 토큰 벡터 \(x\in\mathbb{R}^{d}\)에 대해)

\(\text{RMSNorm}(x)=g\odot\frac{x}{\text{rms}(x)}\),

  • \(\text{rms}(x)=\sqrt{\frac{1}{d}\sum_{i=1}^{d}x_i^2+\varepsilon}\).

  • \(g\in\mathbb{R}^{d}\): Learnable한 스케일(“gain”) 파라미터 (보통 bias 없음)


3. 장점

  • 효율성:

    • 평균 계산·중심화가 없음
    • 연산과 메모리 접근이 줄고, 저정밀(FP16/BF16)에서도 수치적으로 덜 까다로움
  • 대형 모델에서 안정적:

    • Residual 경로가 많은 LLM(Pre-Norm)에서 부드러운 스케일링만으로도 충분히 학습

    • LLaMA 등에서 RMSNorm(+SwiGLU)이 기본 조합으로 자리잡음

참고)

  • 평균을 빼지 않으므로 평균 시프트(mean shift) 는 그대로 남아있음
  • 하지만 Residual/게이팅 구조와 결합될 때 보통 문제가 되지 X


4. 위치

Pre-Norm 구조 예:

# x: (B, T, D)
x = x + Attn( RMSNorm(D)(x) )
x = x + MLP ( RMSNorm(D)(x) )
  • LayerNorm을 RMSNorm으로 치환할 때 보통 그대로 대체
  • MLP는 SwiGLU와 함께 쓰는 패턴이 흔함 (LLaMA: RMSNorm + SwiGLU).


5. LayerNorm vs. RMSNorm

항목 LayerNorm RMSNorm
중심화(mean subtraction) 있음 없음
정규화 기준 \(\sqrt{\sigma^2+\varepsilon}\) \(\sqrt{E[x^2]+\varepsilon}\)
파라미터 \(\gamma,\beta\) (보통 둘 다) \(g\) (보통 bias 없음)
수치/성능 표준 선택지 대형 LLM에서 자주 더 빠르고 안정적
구현 난도 보통 더 단순


6. Code

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # gain
        # bias는 보통 두지 않습니다.

    def forward(self, x):
        # x: (B, T, D) 가정, 마지막 차원 D에 대해 정규화
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        x_hat = x / rms
        return x_hat * self.weight

Categories: ,

Updated: