Masked Autoencoders Are Effective Tokenizers for Diffusion Models (ICML 2025 Spotlight)
1. Introduction
Background: Diffusion models
- 초창기: Pixel space
 - 이후: Latent space (feat. LDM)
    
- LDM: Tokenizers (보통 VAE 기반)로 차원 축소된 latent space에서 학습/생성
 
 
Question: 어떤 latent space가 diffusion 학습에 좋은가?
- VAE: smooth distribution은 만들지만 pixel fidelity가 떨어짐.
 - AE: pixel fidelity는 높지만 latent space 구조가 복잡하고 entangled
 
Proposal: Masked Autoencoder(MAE) 방식을 tokenizer로 활용
- discriminative latent space 제공
 - variational constraint 불필요
 - 더 효율적이고 성능 좋은 diffusion 모델 학습 가능
 
2. On the Latent Space and Diffusion Models
[Empirical analysis]: Latent space의 구조를 GMM으로 fitting.
- 
    
Latent space의 mode 수가 많으면….
\(\rightarrow\) (Entangled) diffusion loss ↑ & gFID(생성 품질) ↓.
 - 
    
Fewer modes → More discriminative → Diffusion 학습 쉬워지고 생성 품질 높음!
- Q) mode가 많으면, 오히려 클래스별로 구분된다는 뜻 아닌지?
 - A)
        
- No! 같은 클래스가 이곳저곳에 흩어져있다는 뜻
 - Class별로 구분이 아니라 class 내부적으로도 더 많은 모드가 생기니, entangled 구조
 
 
 

[Theoretical analysis]:
- GMM 모드 수 \(K\)가 많을수록 diffusion training에 더 많은 sample 필요.
 - 따라서 finite data setting에서 K가 많으면 generation 품질이 나빠짐.
 
3. Method (MAETok)
Overview
- Idea: AE를 MAE 방식으로 훈련시켜 “더 구조화된 latent space”를 학습
 - 
    
Architecture
- Encoder: ViT 기반, 이미지 패치 + learnable latent tokens 입력 → latent representation \(h\) 출력.
 - Decoder: Masked tokens + latent representation 입력 → 픽셀 reconstruction
 - Auxiliary shallow decoders: HOG, DINOv2, CLIP 등의 feature를 예측하도록 추가 supervision .
 
 

- 
    
Training objectives:
- \(L = L_{recon} + \lambda_1 L_{percep} + \lambda_2 L_{adv}\).
 - a) pixel MSE
 - b) perceptual loss
 - c) adversarial loss
 
 - Mask Modeling:
    
- 
        
Encoder 입력에서 40–60% 패치 mask.
 - 
        
Encoder가 더 discriminative feature를 학습하게 유도.
 
 - 
        
 - Pixel Decoder Fine-tuning: Encoder freeze, Decoder만 fine-tune → high fidelity reconstruction 회복.
 
(1) Architecture
- 
    
Encoder:
- Vision Transformer (ViT) 기반.
 - Patchify 후 일부 masking
 - MaskingX patch: w/ learnable latent tokens
 
 - 
    
Decoder:
- 
        
이미지 reconstruction
 - 
        
단순히 픽셀만 복원하는 게 아니라 (이 경우는 masked 된 부분만)
후술할 auxiliary objectives를 통해 feature-level 복원 (이 경우네는 모든 부분에 대해) 도 수행
 
 - 
        
 - 
    
Auxiliary shallow decoders:
- 
        
Encoder latent representation으로부터 CLIP, DINOv2, HOG 등 semantic feature를 예측
 - 
        
의미:
- Pixel fidelity만 맞추는 것이 아니라
 - Semantic feature도 복원하도록 유도
 
→ Latent space가 더 Discriminative하게 됨.
 
 - 
        
 
(2) Training Objectives
\(L = L_{recon} + \lambda_1 L_{percep} + \lambda_2 L_{adv}\).
- Reconstruction Loss
    
- Mask된 영역을 복원할 때 pixel 단위 MSE
 
 - Perceptual Loss:
    
- DINOv2, CLIP 같은 pretrained model의 feature space에서 distance 최소화.
 - 즉, “복원 이미지가 인간 지각적 의미에서도 비슷한가?”를 측정
 
 - Adversarial Loss:
    
- PatchGAN 스타일 discriminator로 복원 이미지가 자연스러운지 판별.
 
 - Auxiliary Loss:
    
- Shallow decoder를 통한 CLIP/DINO/HOG feature 예측 loss.
 
 
(3) Masked Modeling
MM이 latent space를 discriminative하게 하는 핵심 역할
- MAE에서 핵심: 입력 이미지를 Patch 단위로 random masking
 - Encoder는 masking 상황에서 정보를 압축해야 하므로,
    
- “어떤 feature가 중요한지”를 더 잘 구분하도록 학습.
 
 - Mask ratio 실험: 40–60%가 가장 효율적.
    
- 너무 낮으면: latent가 덜 구조화
 - 너무 높으면: 복원 fidelity가 떨어짐
 
 
3.4 Pixel Decoder Fine-tuning
기존 MM의 문제점:
- (장) Latent space를 discriminative하게 만들지만
 - (단) Pixel fidelity(세밀한 복원력)는 조금 떨어짐
 
해결: Pretrained encoder는 freeze, pixel decoder만 따로 fine-tuning.
- 이렇게 하면 latent space의 구조는 그대로 두고, decoder의 픽셀 복원력만 끌어올릴 수 있음.
 
4. Summary (Procedure)
Step 1: Tokenizer 학습 (MAE 기반 AE)
Step 1-1: Encoder + Decoder 공동 학습
- 목표: Encoder가 discriminative latent space를 만들고, Decoder가 기본적인 복원 능력을 가지도록.
 - 방법: Masked Modeling(MAE-style)로 학습.
    
- Encoder는 unmasked patches와 latent tokens을 받아 feature를 뽑음.
 - Decoder는 masked 부분을 복원.
 
 - Loss:
    
- Reconstruction loss (masked 영역)
 - Perceptual loss
 - Adversarial loss
 - Auxiliary feature loss(CLIP/DINO/HOG) 등.
 
 - 결과: Encoder는 “semantic하게 구분 잘 되는 latent space”를 학습, Decoder는 기본 복원력 확보.
 
Step 1-2: Decoder Fine-tuning (Encoder freeze)
- 문제: Step 1-1만 하면 latent space는 좋지만 pixel fidelity(세밀한 복원력)이 부족.
 - 해결: Encoder를 고정(freeze)하고 Decoder만 따로 학습.
    
- 이 단계에서는 주로 pixel reconstruction loss를 통해 디테일 회복에 집중.
 
 - 
    
목적: Latent space의 구조(semantic separability)는 그대로 유지하면서, Decoder의 디테일 복원 능력만 개선.
 - 결과: Encoder+Decoder 쌍이 완성된 Tokenizer.
 
Step 2: Diffusion 학습
- Encoder: 이미지 → latent tokens
 - Diffusion: latent tokens 공간에서 학습/샘플링
 - Decoder: latent tokens → 최종 이미지
 
5. Experiments


