Dark Experience Replay (NeurIPS 2020)

https://arxiv.org/pdf/2004.07211

Buzzega, Pietro, et al. "Dark experience for general continual learning: a strong, simple baseline." Advances in neural information processing systems 33 (2020): 15920-15930.


Key Idea

  • 기본 컨셉: “우리가 기억하고 싶은 것은 단지 입력과 정답(label)이 아니라, 그때 모델이 그것을 어떻게 해석했는가(logit)이다.”

  • 한 줄 요약: “모델이 과거 데이터를 어떻게 예측했는지를 기억하자!!”

    • i.e., 과거 데이터의 soft target (logit)을 함께 저장 + 이를 예측 잘 예측하도록!


1. Buffer

각 task에서 일부 샘플을 buffer에 저장.

저장 대상:

  • 입력 \(x\)
  • 레이블 \(y\)
  • 모델의 이전 예측 \(f_{\theta_{old}}(x)\) = Dark Knowledge


2. Loss

(1) 현재 Loss

현재 task에 대한 supervised loss

  • \(\mathcal{L}_{\text{task}} = \text{CrossEntropy}(f\theta(x_t), y_t)\).

(2) 과거 Loss

과거 task (Replay 샘플)에 대한 logit matching loss

  • \[\mathcal{L}_{\text{replay}} = \sum_{(x_r, z_r)} \mid \mid f_\theta(x_r) - z_r \mid \mid^2\]
  • \(z_r\): 저장된 logit
  • \(x_r\): replay 샘플)


(3) 최종 Loss = (1) + (2)

  • \(\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{task}} + \lambda \cdot \mathcal{L}_{\text{replay}}\).


3. Buffer Update

  • 각 step마다 buffer에 새 sample을 추가!
  • 샘플 교체: reservoir sampling 또는 FIFO 방식으로 샘플 교체


4. DER++

Replay loss

  • DER: \(\mathcal{L}_{\text{replay}} = \mid \mid f_\theta(x_r) - z_r \mid \mid^2\)
  • DER++: \(\mathcal{L}_{replay} = \text{CrossEntropy}(f(x_r), y_r) + \alpha \cdot \mid \mid f(x_r) - z_r \mid \mid^2\).


5. Code

# Assume buffer = list of (x, y, logit)
loss_task = criterion(model(x_curr), y_curr)

# DER Loss
x_replay, y_replay, logit_replay = buffer.sample()
logits_now = model(x_replay)
loss_der = F.mse_loss(logits_now, logit_replay)

loss = loss_task + lambda_ * loss_der
loss.backward()
optimizer.step()


figure2

figure2

Categories: ,

Updated: