Rotary Positional Embedding (RoPE)

최근 LLM에서 많이 쓰이는 PE 기법!


Contents

  1. Positional Embedding (PE)
  2. Relative PE의 등장
    1. Absolute PE의 문제점
    2. Relative PE의 등장
    3. Relative PE의 한계점: “느림”
  3. Rotation Matrix
  4. Overview of RoPE
  5. Details of RoPE
    1. 수식
    2. Summary
    3. Q,V에 적용
    4. 긴 sequence를 핸들링하는 이유
  6. Pytorch Code


1. Positional Embedding (PE)

Transformer는 order-invariant

\(\rightarrow\) Positional Embedding (PE)이 필요


Previous works

  • Absolute PE: 각 위치에 고정된 vector를 더함 (e.g., BERT) → 길이 고정, extrapolation 불가
  • Relative PE: 위치 간 상대적 거리만 반영 (e.g., Transformer-XL, T5 등).


2. Relative PE의 등장

(1) Absolute PE의 문제점

  1. Sequence length 일반화 불가 (extrapolation 문제)

    • “절대” 위치마다 “고정된” vector를 더하기 때문에!
    • (학습 시) 본 적 없는 더 긴 sequence 길이에서는 불가능
  2. Relative distance 정보 부족

    • 절대 위치만 더해주므로, “두 token이 얼마나 떨어져 있는지”를 직접적으로 알기 어려움.
  3. 비효율적 표현

    • 같은 상대 거리 (relative distance), 하지만 absolute embedding은 서로 다른 vector를 가짐 → Redundancy

    • e.g.,

      • position 5와 6
      • position 100과 101

      \(\rightarrow\) 모두 상대적 거리가 1인데, absolute embedding은 완전히 다른 vector


(2) Relative PE의 등장

  • 토큰 쌍 \((i,j)\)에 대해 “상대적 거리 \(i-j\)“를 embedding
    • 즉, (token 개수가 \(N\)개 라면) \(N\times N\) 짜리 attention matrix에 더하게 됨.
  • 장점:
    • Seq len 일반화 가능 → 학습 길이보다 긴 문장에도 잘 적용.
    • Relative distance 정보가 직접 attention에 들어감 → 문법/구조 학습에 유리.
    • 효율성 → 같은 상대 거리에 같은 embedding 사용.


(3) Relative PE의 한계점: “느림”

Absolute vs. Relative PE

  • (a) Absolute PE

    • 길이 \(N\)인 sequence에 대해 PE \(N\)만 준비 → 각 토큰 임베딩에 단순히 더해줌!

      → attention 연산( \(QK^\top\) )에는 추가 연산 없음.

    • 추가 비용: 선형(\(O(N·d)\))

  • (b) Relative PE

    • Attention score를 계산할 때:
      • \[\text{Score}(i,j) = \frac{q_i^\top k_j}{\sqrt{d}} + b_{i-j}\]
    • 여기서 \(b_{i-j}\)는 “두 위치 간 상대 거리”를 lookup해서 가져오는 값

\(\rightarrow\) 즉, 모든 \((i,j)\) 쌍에 대해 한 번씩, \(N \times N\) 의 bias를 추가해야 함.


Summary

  • Absolute: Token 개수 \(N\)만큼만 추가 연산

  • Relative: attention 행렬 자체 (\(N \times N\))에 element-wise 연산이 들어감.
  • \(N\)이 크면 \((N^2)\) 항의 overhead가 커짐.


3. Rotation Matrix

(1) Rotation

2차원 평면에서 어떤 점 \((x, y)\)를 “원점”을 기준으로 돌린다고 해보면

  • e.g., (1, 0)을 90°(시계 반대) 돌리면 → (0, 1)


“점들을 일정한 각도만큼 돌린다” = 회전 변환


(2) Rotation matrix

이런 변환을 행렬 곱으로 표현 가능!

  • 점: 열벡터 \(\begin{bmatrix}x \\ y\end{bmatrix}\)
  • 변환: 행렬 \(R\)
  • 변환된 점: \(R \begin{bmatrix}x \\ y\end{bmatrix}\)


e.g., 원점을 기준으로 \(θ\)만큼 시계 반대 방향 회전시키는 행렬:

  • \(R(\theta) = \begin{bmatrix} \cos \theta & -\sin \theta \\ \sin \theta & \cos \theta \end{bmatrix}\).


\(\begin{bmatrix} x’ \\ y’ \end{bmatrix} = \begin{bmatrix} \cos \theta & -\sin \theta \\ \sin \theta & \cos \theta \end{bmatrix} \begin{bmatrix} x \\ y \end{bmatrix}\) \(= \begin{bmatrix} x\cos\theta - y\sin\theta \\ x\sin\theta + y\cos\theta \end{bmatrix}\).


(3) 직관

(1,0)을 (반시계) 90° 돌린다:

\(R(90^\circ)\begin{bmatrix}1\\0\end{bmatrix} = \begin{bmatrix}0\\1\end{bmatrix}\).

→ 오른쪽을 보던 화살표가 위쪽으로


(0,1)을 (반시계) 90° 돌린다:

\(R(90^\circ)\begin{bmatrix}0\\1\end{bmatrix} = \begin{bmatrix}-1\\0\end{bmatrix}\).

→ 위쪽을 보던 화살표가 왼쪽으로


(4) 성질

  1. 길이 보존:
    • 회전은 “크기”를 바꾸지 않음
    • 단순히 방향만 바꿈!!
    • \(\mid \mid (x, y) \mid \mid = \mid \mid (x’, y’)\mid \mid\).
  2. 각도 보존:
    • 두 벡터 사이 각도도 변하지 않음
  3. 직교행렬:
    • \(R(\theta)^\top R(\theta) = I\). (전치 곱하면 항등행렬).
  4. determinant = 1:
    • 순수 회전 (크기 변화나 반사 없음)


(6) Summary

  • 회전 행렬 = “점(벡터)을 일정한 각도만큼 도는 공식”을 행렬로!

  • (공식) \(R(\theta) = \begin{bmatrix} \cos \theta & -\sin \theta \\ \sin \theta & \cos \theta \end{bmatrix}\)

  • 크기/각도 보존 & 단순히 방향만 바꿈!!
  • (RoPE에서는) 이 회전 성질을 이용해 위치 차이(=상대적 거리)를 자연스럽게 encoding!


4. Overview of RoPE

RoPE의 두 가지 핵심

  • Relative 위치를 “내적 구조”에 직접 녹여냄
  • “단순”하면서도 길이 확장성이 뛰어남


Details

각 쿼리(Q), 키(K) vector에 “위치에 따라 회전(Rotation)을 적용”

\(\rightarrow\) 이렇게 하면 내적(Q·K) 계산 시, 두 위치 간 상대적 거리 정보가 자연스럽게 포함


즉, “어디에 있는지(absolute)”가 아니라 “얼마나 떨어져 있는지(relative)”가 attention score에 반영


5. Details of RoPE

(1) 수식

  • 입력: \(x \in \mathbb{R}^d\)

  • 위치: \(m \in \mathbb{Z}^+\).

  • RoPE 적용: \(\text{RoPE}(x, m) = R_m \, x\).

    • \(R_m\): 2D rotation matrix가 block-diagonal 형태로 \(d/2\)개 반복된 것.
      • 각 블록은 2D 회전행렬(\(\cos, \sin\))이고
      • \(d/2\)개 블록이 쌓여 \(d \times d\) 전체 행렬을 만듬
    • 참고) \(N \times N\) matrix에 곱하는 것이 아니라, \(N\)개의 vector에 곱해지는 것임!


\(R_m = \begin{bmatrix} \cos(m\theta_1) & -\sin(m\theta_1) & & & \\ \sin(m\theta_1) & \cos(m\theta_1) & & & \\ & & \cos(m\theta_2) & -\sin(m\theta_2) & \\ & & \sin(m\theta_2) & \cos(m\theta_2) & \\ & & & & \ddots \end{bmatrix}\).

  • \((x_{2i}, x_{2i+1}) \mapsto (x_{2i}\cos(m\theta_i) - x_{2i+1}\sin(m\theta_i),\; x_{2i}\sin(m\theta_i) + x_{2i+1}\cos(m\theta_i))\) 를 적용하는 것이고
  • 모든 (거리 조합을) 한 줄로 나타내기 위해 “행렬곱”으로 표현
  • 각각의 \(\theta_i\)는 고유 주파수:

    • \(\theta_i = 10000^{-2(i-1)/d}\) (기존 sinusoidal embedding과 동일한 스케일링)


(2) Summary

  • \(R_m\)은 벡터 (토큰) 하나에 대해 위치 \(m\)을 반영할 때 쓰이는 회전 행렬
    • 즉, 이 행렬은 그 위치에 있는 벡터 \(x \in \mathbb{R}^d\)에만 적용
    • \(\text{RoPE}(x_m, m) = R_m \, x_m\).
  • 시퀀스 길이 \(N\)일 때

    • Token sequence \((x_1, x_2, \dots, x_N)\)이 있으면, 각 토큰마다 자기 위치에 해당하는 \(R_m\)을 사용

    • 1번째 토큰 → \(R_1 x_1\)

    • 2번째 토큰 → \(R_2 x_2\)

    • m번째 토큰 → \(R_m x_m\)


(3) Q, K에 적용

Q, K에 동일한 RoPE 적용

  • \(Q’_m = R_m Q\).
  • \(K’_n = R_n K\).


Attention score:

  • \((Q’_m)^\top K’_n = Q^\top (R_m^\top R_n) K\).


여기서 \(R_m^\top R_n = R_{n-m}\) 이므로, 점수는 m,n의 차이(=상대 위치)에만 의존!


(4) 긴 sequence를 핸들링하는 이유

  • 절대 위치 vector를 더하는 방식: (학습한 길이 이상의) extrapolation 불가!

  • RoPE는 회전 각도 공식이 연속적/주기적

    \(\rightarrow\) 원하는 만큼 큰 \(m\)에 대해서도 \(R_m\)을 계산할 수 있음.

    (절대 위치가 아니라) 상대 위치로 바뀌므로 sequence 길이가 늘어나도 동작!

  • 즉:

    • 길이가 늘어나도 위치 index \(m\)만 커지면 됨.
    • 기존 학습 범위를 넘어선 길이도 generalize 가능.
  • 다만, 각도 주파수는 주기성을 갖기 때문에 아주 길어지면 aliasing/주기 오버랩 문제가 발생할 수 있어,

    이를 보완하기 위해 RoPE scaling (NTK-aware scaling, Linear scaling 등) 기법이 연구되어 적용됩니다 (LLaMA-2/3).

6. Pytorch Code

import torch

def apply_rope(x, position_ids, dim):
    """
    x: (batch, seq_len, dim)
    position_ids: (seq_len,) 각 토큰 위치
    dim: 전체 차원 (짝수)
    """
    half_dim = dim // 2
    freq_seq = torch.arange(half_dim, dtype=torch.float32)
    freq = 10000 ** (-2 * freq_seq / dim)   # (half_dim,)

    # 각 위치의 각도
    theta = torch.einsum("i,j->ij", position_ids.float(), freq)  # (seq_len, half_dim)

    cos, sin = torch.cos(theta), torch.sin(theta)
    x1, x2 = x[..., :half_dim], x[..., half_dim:]
    x_rot = torch.cat([x1 * cos - x2 * sin,
                       x1 * sin + x2 * cos], dim=-1)
    return x_rot

Categories: ,

Updated: