Dark Experience Replay (NeurIPS 2020)
https://arxiv.org/pdf/2004.07211
Buzzega, Pietro, et al. "Dark experience for general continual learning: a strong, simple baseline." Advances in neural information processing systems 33 (2020): 15920-15930.
Key Idea
-
기본 컨셉: “우리가 기억하고 싶은 것은 단지 입력과 정답(label)이 아니라, 그때 모델이 그것을 어떻게 해석했는가(logit)이다.”
-
한 줄 요약: “모델이 과거 데이터를 어떻게 예측했는지를 기억하자!!”
- i.e., 과거 데이터의 soft target (logit)을 함께 저장 + 이를 예측 잘 예측하도록!
1. Buffer
각 task에서 일부 샘플을 buffer에 저장.
저장 대상:
- 입력 \(x\)
- 레이블 \(y\)
- 모델의 이전 예측 \(f_{\theta_{old}}(x)\) = Dark Knowledge
2. Loss
(1) 현재 Loss
현재 task에 대한 supervised loss
- \(\mathcal{L}_{\text{task}} = \text{CrossEntropy}(f\theta(x_t), y_t)\).
(2) 과거 Loss
과거 task (Replay 샘플)에 대한 logit matching loss
- \[\mathcal{L}_{\text{replay}} = \sum_{(x_r, z_r)} \mid \mid f_\theta(x_r) - z_r \mid \mid^2\]
- \(z_r\): 저장된 logit
- \(x_r\): replay 샘플)
(3) 최종 Loss = (1) + (2)
- \(\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{task}} + \lambda \cdot \mathcal{L}_{\text{replay}}\).
3. Buffer Update
- 각 step마다 buffer에 새 sample을 추가!
- 샘플 교체: reservoir sampling 또는 FIFO 방식으로 샘플 교체
4. DER++
Replay loss
- DER: \(\mathcal{L}_{\text{replay}} = \mid \mid f_\theta(x_r) - z_r \mid \mid^2\)
- DER++: \(\mathcal{L}_{replay} = \text{CrossEntropy}(f(x_r), y_r) + \alpha \cdot \mid \mid f(x_r) - z_r \mid \mid^2\).
5. Code
# Assume buffer = list of (x, y, logit)
loss_task = criterion(model(x_curr), y_curr)
# DER Loss
x_replay, y_replay, logit_replay = buffer.sample()
logits_now = model(x_replay)
loss_der = F.mse_loss(logits_now, logit_replay)
loss = loss_task + lambda_ * loss_der
loss.backward()
optimizer.step()