TD-JEPA: Latent-predictive Representations for Zero-Shot Reinforcement Learning
Marco Bagatella, Matteo Pirotta, Ahmed Touati, Alessandro Lazaric, Andrea Tirinzoni
We propose a temporal-difference latent-predictive method for zero-shot unsupervised RL.
Abstract
Latent prediction, where agents learn by predicting their own latents, has emerged as a powerful paradigm for training general representations in machine learning. In reinforcement learning (RL), this approach has been explored to define auxiliary losses for a variety of settings, including reward-based and unsupervised RL, behavior cloning, and world modeling. While existing methods are typically limited to single-task learning, one-step prediction, or on-policy trajectory data, we show that temporal difference (TD) learning enables learning representations predictive of long-term latent dynamics across multiple policies from offline, reward-free transitions. Building on this, we introduce TD-JEPA, which brings TD-based latent-predictive representations to unsupervised RL. TD-JEPA trains explicit state and task encoders, a policy-conditioned multi-step predictor, and a set of parameterized policies directly in latent space. This enables zero-shot optimization of any reward function at test time. Theoretically, we show that an idealized variant of TD-JEPA avoids collapse with proper initialization and learns encoders that capture a low-rank factorization of long-term policy dynamics, while the predictor recovers their successor features in latent space. Empirically, TD-JEPA matches or outperforms state-of-the-art baselines on locomotion, navigation, and manipulation tasks across 13 datasets in ExoRL and OGBench, especially in the challenging setting of zero-shot RL from pixels.
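The statement that TD learning lets the predictor recover successor features in latent space can be illustrated in the simplest possible case. The sketch below is not the authors' implementation. It assumes a finite MDP with one fixed policy (transition matrix P), replaces the learned state encoder with a fixed feature matrix Phi, and applies the expected (full-batch) TD update rather than sampled transitions, with no policy conditioning. The function names are hypothetical. Under these assumptions, iterating the TD target converges to the closed-form successor features (I - gamma*P)^(-1) P Phi.

```python
import numpy as np


def td_successor_features(P, Phi, gamma=0.9, iters=500):
    """Expected TD iteration: Psi <- P @ Phi + gamma * P @ Psi.

    P:   (n, n) row-stochastic transition matrix of one fixed policy.
    Phi: (n, d) fixed latent features per state (stand-in for a learned encoder).
    Returns Psi: (n, d), the discounted sum of future latents from each state.
    """
    Psi = np.zeros_like(Phi)
    for _ in range(iters):
        # TD target: latent of the next state plus the discounted prediction there
        Psi = P @ Phi + gamma * (P @ Psi)
    return Psi


def closed_form_successor_features(P, Phi, gamma=0.9):
    """Fixed point of the TD target: Psi = (I - gamma P)^{-1} P Phi."""
    n = P.shape[0]
    return np.linalg.solve(np.eye(n) - gamma * P, P @ Phi)


rng = np.random.default_rng(0)
n, d = 6, 3
P = rng.random((n, n))
P /= P.sum(axis=1, keepdims=True)  # normalize rows into transition probabilities
Phi = rng.standard_normal((n, d))

Psi_td = td_successor_features(P, Phi)
Psi_exact = closed_form_successor_features(P, Phi)
print(np.allclose(Psi_td, Psi_exact))  # True
```

Because gamma < 1 and P is row-stochastic, this update is a gamma-contraction and converges from any initialization. With fixed features there is nothing to collapse; once the features are learned jointly with the predictor, as in TD-JEPA, a trivial all-zero solution also satisfies the TD target, which is why the collapse analysis in the abstract matters.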
Learns zero-shot RL representations via temporal-difference latent prediction, recovering a low-rank factorization of long-term policy dynamics.
- Introduces TD-JEPA, which brings TD-based latent-predictive representations to unsupervised RL
- Shows TD learning enables representations predictive of long-term latent dynamics across multiple policies
- Proves that an idealized variant avoids collapse under proper initialization and learns a low-rank factorization of policy dynamics
- Matches or outperforms state-of-the-art baselines on 13 datasets in ExoRL and OGBench
- Temporal difference learning
- Latent prediction
- Unsupervised reinforcement learning
- ExoRL
- OGBench
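The idealized result above (encoders capturing a low-rank factorization of long-term policy dynamics) and the promise of zero-shot reward optimization can be made concrete for a single fixed policy in a small tabular MDP. This is an assumption-laden sketch, not TD-JEPA itself: it obtains the factorization by a truncated SVD of the exact successor measure instead of by TD latent prediction, ignores policy conditioning and policy learning, and uses hypothetical names (successor_measure, low_rank_factors, zero_shot_values, factors F and B, task vector z). It shows how, once the dynamics are factorized, a new reward reduces to a d-dimensional task vector from which values follow without further training.

```python
import numpy as np


def successor_measure(P, gamma=0.9):
    """Discounted future-state occupancy M = sum_{t>=0} gamma^t P^{t+1} = (I - gamma P)^{-1} P."""
    n = P.shape[0]
    return np.linalg.solve(np.eye(n) - gamma * P, P)


def low_rank_factors(M, d):
    """Truncated SVD M ~= F @ B.T with F, B of shape (n, d); also returns all singular values."""
    U, S, Vt = np.linalg.svd(M)
    F = U[:, :d] * S[:d]  # "forward" factor: current state -> d-dim embedding
    B = Vt[:d].T          # "backward" factor: future state -> d-dim embedding
    return F, B, S


def zero_shot_values(F, B, r):
    """Evaluate a new reward r without retraining: z = B^T r, then V ~= F z."""
    z = B.T @ r  # d-dimensional task vector summarizing the reward
    return F @ z


rng = np.random.default_rng(1)
n = 8
P = rng.random((n, n))
P /= P.sum(axis=1, keepdims=True)
M = successor_measure(P)
r = rng.standard_normal(n)  # an arbitrary per-state reward

# Full rank: the factorization is exact, so V = M r is recovered.
F, B, S = low_rank_factors(M, d=n)
print(np.allclose(zero_shot_values(F, B, r), M @ r))  # True

# Rank 3: the spectral-norm error equals the 4th singular value (Eckart-Young).
F3, B3, _ = low_rank_factors(M, d=3)
print(np.isclose(np.linalg.norm(M - F3 @ B3.T, 2), S[3]))  # True
```

This only evaluates one policy. The abstract describes task encoders and a set of parameterized policies trained in latent space, so optimizing a new reward presumably means mapping it to a latent task and running the corresponding policy; the exact inference rule is not given on this page.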
Authors did not state explicit limitations.
- Study learning objectives compatible with asymmetric successor measures (from the paper)
- Benchmark on large-scale real robotic datasets (from the paper)
Author keywords
- zero-shot reinforcement learning
- unsupervised reinforcement learning
- self-predictive representations
- joint embedding predictive architecture