year: 2021
paper: https://openreview.net/pdf?id=uCQfPZwRaUu
website:
code:
connections: prediction, RL, representation learning, self-supervised learning
Train agents to predict their own future latent representations. Augments Rainbow DQN with self-supervised loss that predicts future states in latent space using transition model and target encoder (EMA of online encoder).
Unlike pixel reconstruction methods, operates entirely in latent space. Unlike contrastive methods, no negative samples needed.
| Problem with vanilla value/policy loss | How an auxiliary SPR loss helps |
| --- | --- |
| Sparse / delayed rewards → almost no gradient early in training. | Every frame supplies a dense self-supervised target (the future latent), so the encoder learns useful structure before rewards appear. |
| Representation drift: slight weight updates for one state perturb Q-values everywhere. | Predicting k-step-ahead latents encourages smooth, locally linear dynamics in feature space, making values more stable. |
| Overfitting to pixels: the network can latch onto color patches that correlate with reward but don't generalize. | To minimize the SPR loss it must track objects that persist and move, implicitly disentangling position/velocity. |
| Bootstrapping noise: value targets are themselves estimates. | SPR provides an independent, low-variance learning signal that regularizes the network. |
Core setup
Sample a sequence of K+1 states $s_t, \dots, s_{t+K}$ from the replay buffer. The online encoder $f_o$ produces representations $z_t = f_o(s_t)$; the target encoder $f_m$ (an EMA of $f_o$) produces the prediction targets $\tilde{z}_{t+k} = f_m(s_{t+k})$.
Transition model $h$ predicts future representations iteratively: $\hat{z}_{t+k+1} = h(\hat{z}_{t+k}, a_{t+k})$, starting from $\hat{z}_t = z_t$.
Loss: negative cosine similarity between projected predictions and projected targets, summed over the K steps:
$$\mathcal{L}^{\text{SPR}}_\theta = -\sum_{k=1}^{K} \left(\frac{\tilde{y}_{t+k}}{\lVert \tilde{y}_{t+k} \rVert_2}\right)^{\!\top} \frac{\hat{y}_{t+k}}{\lVert \hat{y}_{t+k} \rVert_2}, \qquad \hat{y}_{t+k} = q(g_o(\hat{z}_{t+k})), \quad \tilde{y}_{t+k} = g_m(\tilde{z}_{t+k})$$
Uses projection heads before computing similarity (like BYOL). K=5 optimal.
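A minimal sketch of this objective, assuming PyTorch and placeholder module names (`online_encoder`, `target_encoder`, `transition`, `projector`, `target_projector`, `predictor` standing in for $f_o$, $f_m$, $h$, $g_o$, $g_m$, $q$); this is illustrative, not the authors' code:

```python
import torch
import torch.nn.functional as F

def spr_loss(obs, actions, online_encoder, target_encoder,
             transition, projector, target_projector, predictor, K=5):
    """obs: (B, K+1, C, H, W) frame stacks; actions: (B, K) discrete actions."""
    z = online_encoder(obs[:, 0])                 # z_t from the online encoder
    loss = 0.0
    for k in range(1, K + 1):
        z = transition(z, actions[:, k - 1])      # roll the latent forward one step
        pred = predictor(projector(z))            # online projection + prediction head
        with torch.no_grad():                     # targets carry no gradient
            target = target_projector(target_encoder(obs[:, k]))
        # negative cosine similarity between normalized prediction and target
        loss = loss - F.cosine_similarity(pred, target, dim=-1).mean()
    return loss
```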
Transition model: 2 conv layers (64 channels, 3×3), operates on 64×7×7 encoder output. Action one-hot appended at each spatial location.
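A rough PyTorch sketch of such a transition model; layer names, the ReLU activations, and the default action count are assumptions, not taken from the paper text here:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransitionModel(nn.Module):
    def __init__(self, latent_channels=64, num_actions=18):
        super().__init__()
        self.num_actions = num_actions
        self.conv1 = nn.Conv2d(latent_channels + num_actions, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, latent_channels, kernel_size=3, padding=1)

    def forward(self, z, action):
        # z: (B, 64, 7, 7) latent; action: (B,) integer actions
        B, _, H, W = z.shape
        onehot = F.one_hot(action.long(), self.num_actions).float()      # (B, A)
        onehot = onehot.view(B, self.num_actions, 1, 1).expand(-1, -1, H, W)
        x = torch.cat([z, onehot], dim=1)        # action appended at every spatial location
        return F.relu(self.conv2(F.relu(self.conv1(x))))
```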
With augmentation: random shifts + color jitter (same as DrQ)
Without augmentation: dropout=0.5 in encoders works better
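A simple sketch of the pad-and-crop random shift (DrQ-style); the pad size of 4 and the per-image loop are assumptions, and color jitter would be layered on separately:

```python
import torch
import torch.nn.functional as F

def random_shift(imgs, pad=4):
    """imgs: (B, C, H, W) float images. Replicate-pad, then randomly crop back to H x W per image."""
    B, C, H, W = imgs.shape
    padded = F.pad(imgs, (pad, pad, pad, pad), mode='replicate')
    out = torch.empty_like(imgs)
    for i in range(B):
        top = torch.randint(0, 2 * pad + 1, (1,)).item()
        left = torch.randint(0, 2 * pad + 1, (1,)).item()
        out[i] = padded[i, :, top:top + H, left:left + W]
    return out

# Color jitter (brightness/contrast/saturation/hue) can be added on top,
# e.g. with kornia.augmentation.ColorJitter; treated here as an assumption.
```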
Key findings
- Target encoder crucial - online-only causes collapse
- Cosine similarity prevents collapse (L2 fails)
- τ=0 (no momentum) works with augmentation, τ=0.99 without (EMA update sketched after this list)
- Each component helps independently (future prediction alone beats prior methods)
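A sketch of the EMA target update controlled by τ, under the convention θ_target ← τ·θ_target + (1−τ)·θ_online, so τ=0 simply copies the online weights every step:

```python
import torch

@torch.no_grad()
def update_target(online_net, target_net, tau=0.99):
    # theta_target <- tau * theta_target + (1 - tau) * theta_online
    for p_online, p_target in zip(online_net.parameters(), target_net.parameters()):
        p_target.data.mul_(tau).add_((1.0 - tau) * p_online.data)
```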
Atari 100k: 0.415 median human-normalized score (55% above prior SOTA). Beats humans on 7/26 games with 2 hours of gameplay.