Gradient Episodic Memory for Continual Learning
Contents
- Abstract
 - Introduction
 - A Framework for Continual Learning
 - Gradient of Epsiodic Memory (GEM)
 - Algorithm
 
0. Abstract
2가지를 제안함
- (1) metrics to evaluate models learning over a continuum of data
 - (2) model for continual learning, called GEM (Gradient Episodic Memory)
 
1. Introduction
(대부분의) Supervise learning의 특징
- input data에 대한 i.i.d 가정
 - 목표 : 보지 못한 데이터 (unseen data)에 대한 loss를 minimize하도록 학습됨
 - ERM (Empirical Risk Minimization) principle 차용
 
하지만, 인간들은 다르다.
- Humans observe data as an ORDERED SEQUENCE! ( i.i.d가 아니다 )
 - 적은 수의 데이터 밖에 기억하지 못한다
 
\(\therefore\), 현실적으로, ERM을 적용하면 catastrophic forgetting 발생한다!
\(\rightarrow\) 이 paper의 목표 : ERM과 human-like learning 사이의 gap 줄이기!
Notation
continuum of data : \(\left(x_{1}, t_{1}, y_{1}\right), \ldots,\left(x_{i}, t_{i}, y_{i}\right), \ldots,\left(x_{n}, t_{n}, y_{n}\right)\)
- 이들은 서로 i.i.d가 아니다
 - task descriptor : \(t_{i} \in \mathcal{T}\)
 - data pair : \(\left(x_{i}, y_{i}\right) \sim P_{t_{i}}\)
 
challenges unknown to ERM
- 
    
1) Non-iid input data
 - 
    
2) Catastrophic forgetting
 - 
    
3) Transfer learning
( 만약 continuum내의 task들이 서로 related 되어있다면, transfer learninng을 활용할 여지 O )
 
2. A Framework for Continual Learning
( 가정 : continuum들은 locally i.i.d이다. 즉, \(\left(x_{i}, y_{i}\right) \stackrel{i i d}{\sim} P_{t_{i}}(X, Y)\) )
Training Protocol & Evaluation Metrics
일반적으로, sequence of tasks에 대해 학습하는 것은, 아래와 같은 setting을 가진다.
- 1) task의 수는 적다
 - 2) task 별 데이터 수는 충분하다
 - 3) 각 task내의 example에 대해 여러번의 pass를 거침
 - 4) average performance across all tasks를 metric으로 삼음
 
하지만, 이 논문은 보다 “human-like” setting을 가정한다 ( 보다 현실적 )
그러기 위해….
- 
    
training time에 learner에게 ONLY ONE example at a time만을 제공!
 - 
    
똑같은 data가 2번 제공되지 않음
 - 
    
tasks는 sequence로 들어옴
 - 
    
아래의 (1) 뿐만 아니라, (2) 또한 중시함
- 
        
(1) performance across tasks
 - 
        
(2) ability of learner to TRANSFER KNOWLEDGE
( 아래의 2가지 measure 참고 )
 
 - 
        
 
측정하고자 하는 measure
- Backward Transfer (BWT)
    
- task \(t\)에 대해서 학습하는 것이 PREVIOUS task \(t-1,...1\)에 미치는 영향
 - positive BWT & negative BWT ( = catastrophic forgetting )
 
 - Forward Transfer (FWT)
    
- task \(t\)에 대해서 학습하는 것이 FUTURE task \(t+1,...N\)에 미치는 영향
 
 - Test classsification accuracy : \(R_{i,j}\)
    
- task \(t_i\)를 관측한 뒤, task \(t_j\)에 대한 accuarcy
 
 
제안한 3가지 metric
\(\begin{aligned} \text { Average Accuracy: } \mathrm{ACC} &=\frac{1}{T} \sum_{i=1}^{T} R_{T, i} \\ \text { Backward Transfer: } \mathrm{BWT} &=\frac{1}{T-1} \sum_{i=1}^{T-1} R_{T, i}-R_{i, i} \\ \text { Forward Transfer: FWT } &=\frac{1}{T-1} \sum_{i=2}^{T} R_{i-1, i}-\bar{b}_{i} \end{aligned}\).
- 높을수록 좋은 metric이다
 
3. Gradient of Epsiodic Memory (GEM)
GEM의 핵심 특징 : “EPISODIC memory” ( \(\mathcal{M}_t\) )
- 
    
stores a “subset of observed examples” of task \(t\)
 - 
    
현실적으로 메모리 제약! total budget = \(M\)
( \(m=M/T\) memories for each task )
( 만약 task의 개수를 모를 경우, gradually reduce \(m\) )
 - 
    
목표 : minimize BACKWARD transfer ( catastrophic forgetting ), by using episodic memory
 
Loss at memories from the \(k\)-th task :
- \(\ell\left(f_{\theta}, \mathcal{M}_{k}\right)=\frac{1}{ \mid \mathcal{M}_{k} \mid } \sum_{\left(x_{i}, k, y_{i}\right) \in \mathcal{M}_{k}} \ell\left(f_{\theta}\left(x_{i}, k\right), y_{i}\right)\).
 
하지만, “현재 데이터의 loss”와 함께 “위의 memory의 loss function”을 둘 다 minimize하는 것은, \(\mathcal{M_k}\) 메모리에 저장된 데이터셋에 overfitting 위험!
따라서, 위의 loss (Loss at memories from the \(k\)-th task)는 inequality constraints로만 사용한다!
최종 목표 :
\(\begin{aligned} \operatorname{minimize}_{\theta} & \ell\left(f_{\theta}(x, t), y\right) \\ \text { subject to } & \ell\left(f_{\theta}, \mathcal{M}_{k}\right) \leq \ell\left(f_{\theta}^{t-1}, \mathcal{M}_{k}\right) \text { for all } k<t \end{aligned}\).
위 식을 효율적으로 풀고자 함.
- 
    
1) old predictor \(f_{\theta}^{t-1}\) 를 저장할 필요가 없음
( 단지, 이전 task의 loss가 parameter update \(g\)이후 loss가 늘어나지 않음만 “확인”하면 되니까 )
 - 
    
2) 위의 “확인”은, loss gradient vector & proposed update 사이의 angle을 계산함으로써 가능
- \(\left\langle g, g_{k}\right\rangle:=\left\langle\frac{\partial \ell\left(f_{\theta}(x, t), y\right)}{\partial \theta}, \frac{\partial \ell\left(f_{\theta}, \mathcal{M}_{k}\right)}{\partial \theta}\right\rangle \geq 0, \text { for all } k<t\).
 
 
위의 식을, Quadratic Program / dual problem / primal 등을 사용하여 정리하면, 최종적으로 아래와 같다.
\(\begin{aligned} \operatorname{minimize}_{z} & \frac{1}{2} z^{\top} z-g^{\top} z+\frac{1}{2} g^{\top} g \\ \text { subject to } & G z \geq 0, \end{aligned}\).
where \(G=-\left(g_{1}, \ldots, g_{t-1}\right)\)
4. Algorithm

5. Experiments
사용한 데이터셋
( 모든 데이터셋에 대해 \(T = 20\) tasks )
- 1) MNIST Permutations
 - 2) MNIST Rotations,
 - 3) Incremental CIFAR100
    
- each task introduces a new set of classes
 - For a total number of T tasks, each new task concerns examples from a disjoint subset of 100/T classes
 
 

