Deniz Qian

JEPA Navigator

Sep 2024 - Dec 2024 | GitHub | Download PDF

Self-Supervised JEPA World Model for Two-Room Navigation

This project explored building a self-supervised Joint Embedding Predictive Architecture (JEPA) to learn representations of an agent navigating a simple two-room environment. The goal was to train a model that predicts future latent representations from current observations and actions alone, without explicitly reconstructing pixel frames. We trained our JEPA models on 2.5 million frames of agent trajectories, ultimately learning structured latent spaces that capture spatiotemporal dynamics and can be probed to recover the true (x, y) agent coordinates.

Machine Learning Techniques

  • Implemented both recurrent and teacher-forcing JEPA variants with a ResNet-based CNN encoder to embed each frame into 128-D latent vectors.
  • Used a lightweight MLP to fuse these embeddings with action vectors, predicting the next latent state to model environment dynamics.
  • Optimized an energy-based objective by minimizing MSE between predicted and target encoder outputs, supported by a contrastive margin loss to encourage separation in representation space.
  • Applied VICReg-inspired variance and covariance regularization to prevent representational collapse and maintain feature diversity.
  • Stabilized training by introducing a momentum-updated target encoder (similar to BYOL) that provides slowly changing, consistent prediction targets.
  • Employed a quadratic warm-up followed by cosine decay for learning-rate scheduling, and ran extensive hyperparameter sweeps tracked with Weights & Biases for reproducibility.
  • Accelerated training and handled large-scale data with PyTorch’s Distributed Data Parallel (DDP) across multiple GPUs, allowing us to process over 2.5 million frames efficiently.
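
A minimal single-step sketch of how the pieces listed above could fit together in PyTorch is given here. It is an illustration under stated assumptions, not the project's actual code: `make_encoder` is a tiny stand-in for the ResNet-based encoder, the input shape, `action_dim=2`, the loss weights, and the momentum value are all hypothetical, and the contrastive margin term and the recurrent rollout are left out for brevity.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


def make_encoder(in_channels=2, latent_dim=128):
    """Tiny CNN stand-in (hypothetical) for the ResNet-based encoder."""
    return nn.Sequential(
        nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(64, latent_dim),
    )


class Predictor(nn.Module):
    """Lightweight MLP that fuses a latent state with an action vector."""

    def __init__(self, latent_dim=128, action_dim=2, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )

    def forward(self, z, action):
        return self.net(torch.cat([z, action], dim=-1))


def variance_covariance_penalty(z, eps=1e-4):
    """VICReg-style variance and covariance terms for embeddings of shape (B, D)."""
    z = z - z.mean(dim=0)
    std = torch.sqrt(z.var(dim=0) + eps)
    var_loss = F.relu(1.0 - std).mean()  # hinge: push each dimension's std toward 1
    cov = (z.T @ z) / (z.shape[0] - 1)
    off_diag = cov - torch.diag(torch.diag(cov))
    cov_loss = off_diag.pow(2).sum() / z.shape[1]  # penalize correlated dimensions
    return var_loss, cov_loss


@torch.no_grad()
def ema_update(target, online, momentum=0.99):
    """BYOL-style update: target <- m * target + (1 - m) * online."""
    for p_t, p_o in zip(target.parameters(), online.parameters()):
        p_t.mul_(momentum).add_(p_o, alpha=1.0 - momentum)


def jepa_step_loss(encoder, target_encoder, predictor, obs_t, action_t, obs_next,
                   var_weight=1.0, cov_weight=0.04):
    """Prediction MSE in latent space plus the regularization terms."""
    z_t = encoder(obs_t)
    with torch.no_grad():  # targets come from the momentum encoder, no gradient
        z_target = target_encoder(obs_next)
    z_pred = predictor(z_t, action_t)
    pred_loss = F.mse_loss(z_pred, z_target)
    var_loss, cov_loss = variance_covariance_penalty(z_t)
    return pred_loss + var_weight * var_loss + cov_weight * cov_loss
```

In a training loop, the target encoder would be created once with `copy.deepcopy(encoder).requires_grad_(False)`, and `ema_update(target_encoder, encoder)` would be called after each optimizer step.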

Results

The final model demonstrated a promising ability to predict future latent states, as evidenced by probing experiments in which a 2-layer MLP recovered ground-truth (x, y) coordinates from the latents. We evaluated performance on standard trajectories, wall-collision sequences, long-horizon prediction, and out-of-domain test sets, reporting mean squared errors that confirmed the quality of our learned representations.
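
The probing setup described above can be sketched roughly as follows. This is an assumed, simplified version: the hidden width, optimizer, learning rate, and epoch count are hypothetical, and `latents` stands for embeddings from a frozen encoder (no gradient flows back into it).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def make_probe(latent_dim=128, hidden_dim=64):
    """2-layer MLP mapping a latent vector to (x, y); hidden_dim is an assumption."""
    return nn.Sequential(
        nn.Linear(latent_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, 2),
    )


def fit_probe(probe, latents, coords, epochs=200, lr=1e-3):
    """Train only the probe; latents are detached so the encoder stays frozen."""
    latents = latents.detach()
    optimizer = torch.optim.Adam(probe.parameters(), lr=lr)
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = F.mse_loss(probe(latents), coords)
        loss.backward()
        optimizer.step()
    return probe


@torch.no_grad()
def probe_mse(probe, latents, coords):
    """Mean squared error between probed and ground-truth coordinates."""
    return F.mse_loss(probe(latents), coords).item()
```

A low probe error on held-out trajectories suggests that position is recoverable from the latent space, without the encoder ever having been trained on coordinates.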

Although we faced challenges with instability and early overfitting (mitigated by gradient clipping, adjusted learning rates, and robust regularization), our best models showed clear improvements over initial baselines. Future work would include scaling to richer environments and experimenting with more advanced self-supervised losses.
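
For reference, the quadratic warm-up with cosine decay used for learning-rate scheduling could look roughly like the function below. It is a framework-free sketch; the function name, the step counts, and the `min_lr` floor are assumptions, not values from the project.

```python
import math


def lr_at_step(step, total_steps, warmup_steps, base_lr, min_lr=0.0):
    """Learning rate at a given step: quadratic warm-up, then cosine decay."""
    if step < warmup_steps:
        # Quadratic ramp: reaches base_lr on the last warm-up step.
        frac = (step + 1) / warmup_steps
        return base_lr * frac ** 2
    # Cosine decay from base_lr down to min_lr over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

With `base_lr=1.0`, the same function can serve as the multiplicative factor passed to `torch.optim.lr_scheduler.LambdaLR`, and gradient clipping can be applied before each optimizer step with `torch.nn.utils.clip_grad_norm_`.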