More Data: JEPA as a Learning Attachment for LLMs
An experimental approach using JEPA Attach to test whether Joint Embedding Predictive Architectures can help Large Language Models learn from abundant multimodal data when high-quality text becomes scarce.
The looming data wall problem
Large Language Models have reached a crossroads. Compute-optimal scaling laws like Chinchilla show that training tokens must scale with parameters, but high-quality human text is finite. Conservative estimates suggest we'll exhaust publicly available human-generated text for frontier model training between 2026 and 2032.
Meanwhile, alternatives like synthetic data carry documented risks of recursive degradation (model collapse), and RAG helps with knowledge access but doesn't provide new pretraining signals. This motivates our experimental approach: testing whether LLMs can learn from the abundant streams of unlabeled multimodal data that dwarf text corpora.
The JEPA insight
Joint Embedding Predictive Architectures (JEPA) learn by predicting latent representations instead of raw pixels or tokens. The core mathematical insight is elegantly simple.
The basic idea: Instead of trying to predict every pixel in an image, JEPA predicts high-level concepts and relationships. This approach, similar to MAE but operating in latent space rather than pixel space, is much more efficient and captures what actually matters.
Given a context encoder $f_\theta$, a target encoder $f_{\theta^-}$ (an EMA copy), and a predictor $g_\phi$:

$$\mathcal{L}_{\text{JEPA}} = \big\lVert\, g_\phi\big(f_\theta(x_c)\big) - \mathrm{sg}\big[f_{\theta^-}(x_t)\big] \,\big\rVert_2^2$$

Breaking down the math:
- $x$ = input image patches
- $x_c$ = unmasked patches (context)
- $x_t$ = masked patches (what to predict)
- $M$ = mask indicating which patches to predict (so $x_c$ are the patches where $M=0$ and $x_t$ those where $M=1$)
- $f_\theta$ = context encoder, $f_{\theta^-}$ = target encoder (EMA copy)
- $g_\phi$ = predictor network
- $\mathrm{sg}[\cdot]$ = stop-gradient on the target to prevent collapse
The equation measures how well we can predict masked patch representations from visible context. This forces learning of semantic representations rather than pixel level reconstruction, making it scalable and transferable.
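For intuition, here is a minimal toy sketch of that objective with stand-in linear encoders and made-up dimensions; the real prototype uses ViT-based encoders and the block-masking utilities shown later.

import torch
import torch.nn as nn
import torch.nn.functional as F

latent_dim = 64
f_theta = nn.Linear(256, latent_dim)       # stand-in context encoder
f_theta_bar = nn.Linear(256, latent_dim)   # stand-in target encoder (EMA copy in practice)
g_phi = nn.Linear(latent_dim, latent_dim)  # stand-in predictor

x_context = torch.randn(8, 256)            # flattened visible patches
x_target = torch.randn(8, 256)             # flattened masked patches

pred = g_phi(f_theta(x_context))           # predict the target's representation from context
target = f_theta_bar(x_target).detach()    # stop-gradient: no gradients flow into the target branch
loss = F.mse_loss(pred, target)            # compared in latent space, never in pixel space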
Building JEPA Attach
The architectural fusion
Our JEPA Attach prototype aims to inject multimodal predictive learning signals into an LLM's training loop. The key innovation is a set of alignment adapters that project LLM hidden states into JEPA's latent space.
What this means in simple terms: We connect the text model and vision model by training them to understand similar concepts in a shared embedding space, similar to approaches in CLIP. When the text model processes language about visual concepts, it can leverage knowledge learned from images.
$$\mathcal{L}_{\text{align}} = \big\lVert\, p_\psi(h_{\text{text}}) - p_\psi(h_{\text{img}}) \,\big\rVert_2^2$$

Breaking down the math:
- $h_{\text{text}}$ = text embeddings from the LLM's final layer
- $h_{\text{img}}$ = image embeddings from the JEPA encoder
- $p_\psi$ = shared projection head mapping both into the same space
- The equation measures alignment between text and vision in the shared embedding space
This creates a bridge between linguistic representations and multimodal world knowledge.
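One possible sketch of such an adapter is below. It matches the call signature of the ProjectionHead used in the training code later in this post, but it is an illustrative guess, not the exact module from the prototype.

import torch
import torch.nn as nn

class ProjectionHead(nn.Module):
    """Illustrative shared adapter: maps 768-d text and vision features into one joint space."""
    def __init__(self, text_dim: int = 768, vision_dim: int = 768, joint_dim: int = 256):
        super().__init__()
        assert text_dim == vision_dim, "this simple sketch shares one projection for both modalities"
        self.proj = nn.Sequential(
            nn.Linear(text_dim, joint_dim),
            nn.GELU(),
            nn.Linear(joint_dim, joint_dim),
        )

    def forward(self, emb: torch.Tensor) -> torch.Tensor:
        # Normalize so the alignment loss compares directions, not magnitudes
        return nn.functional.normalize(self.proj(emb), dim=-1)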
The combined objective
Our total loss elegantly balances language modeling with predictive learning:

$$\mathcal{L}_{\text{total}} = \lambda_{\text{LM}}\,\mathcal{L}_{\text{LM}} + \lambda_{\text{JEPA}}\,\mathcal{L}_{\text{JEPA}} + \lambda_{\text{align}}\,\mathcal{L}_{\text{align}}$$

What this means: We train the model on three things at once:
- $\mathcal{L}_{\text{LM}}$ = normal language modeling (predicting next words)
- $\mathcal{L}_{\text{JEPA}}$ = visual understanding (predicting image representations)
- $\mathcal{L}_{\text{align}}$ = connecting text and images (alignment loss)
- $\lambda$ terms = how much weight to give each part

The beauty of this formulation is that $\mathcal{L}_{\text{LM}}$ scales with scarce text, while $\mathcal{L}_{\text{JEPA}}$ scales with abundant visual streams. The LLM receives gradients from both finite text and effectively unlimited multimodal data.
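In code, the combination is just a weighted sum; the weights below mirror the illustrative values used in the training step later in this post (1.0, 0.1, and 0.05) and are not tuned.

# Weighted combination of the three objectives (illustrative weights, not tuned values)
LAMBDA_LM, LAMBDA_JEPA, LAMBDA_ALIGN = 1.0, 0.1, 0.05

def combine_losses(lm_loss, jepa_loss, alignment_loss):
    """L_total = λ_LM * L_LM + λ_JEPA * L_JEPA + λ_align * L_align"""
    return LAMBDA_LM * lm_loss + LAMBDA_JEPA * jepa_loss + LAMBDA_ALIGN * alignment_loss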
Implementation details that matter
Apple Silicon optimization
Our prototype targets Apple Silicon (MPS backend) to make this research approach accessible. The model combines GPT-2 small (124M) with ViT base (86M) plus JEPA components (~2M), totaling roughly 212M parameters, which is small enough for comfortable M1/M2 Mac testing.
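A small device-selection helper (a generic PyTorch pattern, not code from the prototype) lets the same scripts fall back from MPS to CUDA or CPU. The core patching, masking, and EMA utilities then look like this:

import torch

def get_device() -> torch.device:
    """Prefer Apple Silicon's MPS backend, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

device = get_device()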
import torch
import torch.nn.functional as F

def create_image_patches(image: torch.Tensor, patch_size: int = 16) -> torch.Tensor:
    """Convert an image batch to non-overlapping patches for JEPA masking."""
    B, C, H, W = image.shape
    # unfold twice -> (B, C, H/p, W/p, p, p), then flatten the grid into a single patch axis
    patches = image.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    return patches.contiguous().view(B, C, -1, patch_size, patch_size)
def sample_block_mask(num_patches: int, mask_ratio: float = 0.6) -> torch.Tensor:
    """Sample one contiguous block of patch indices to mask (not random patches)."""
    num_masked = int(mask_ratio * num_patches)
    # Sample a contiguous block - key insight from I-JEPA
    start_idx = torch.randint(0, num_patches - num_masked + 1, (1,)).item()
    return torch.arange(start_idx, start_idx + num_masked)
def encode_context_target(image_patches: torch.Tensor, mask_indices: torch.Tensor, encoder):
    """Encode visible patches (context) and masked patches (target)."""
    visible_patches = image_patches.clone()
    visible_patches[:, :, mask_indices] = 0.0  # zero out target patches (patch axis is dim 2)
    context_emb = encoder(visible_patches)  # encode visible context
    target_emb = encoder(image_patches[:, :, mask_indices]).detach()  # target with stop-gradient
    return context_emb, target_emb
def compute_jepa_loss(context_emb: torch.Tensor, target_emb: torch.Tensor, predictor) -> torch.Tensor:
    """Predict target representations from context."""
    predicted_target = predictor(context_emb)  # predict representations of the missing patches
    return F.mse_loss(predicted_target, target_emb)  # how well did we predict?
def update_target_encoder_ema(model, tau: float = 0.996):
    """Exponential moving average update: θ⁻ ← τθ⁻ + (1-τ)θ"""
    with torch.no_grad():  # no gradients for the EMA update
        for target_param, online_param in zip(model.target_encoder.parameters(), model.online_encoder.parameters()):
            target_param.data.mul_(tau).add_(online_param.data, alpha=1.0 - tau)
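Chained together on a dummy batch, the utilities above behave like this. The stand-in encoder below just flattens each patch through a linear layer so the shapes work end to end; it is a placeholder for the ViT-based encoder in the prototype.

import torch
import torch.nn as nn

class ToyPatchEncoder(nn.Module):
    """Placeholder encoder: flattens each patch and applies one linear projection."""
    def __init__(self, patch_size: int = 16, channels: int = 3, dim: int = 128):
        super().__init__()
        self.proj = nn.Linear(channels * patch_size * patch_size, dim)

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        B, C, N, P, _ = patches.shape
        flat = patches.permute(0, 2, 1, 3, 4).reshape(B, N, C * P * P)  # (B, N, C*P*P)
        return self.proj(flat)                                          # (B, N, dim)

images = torch.randn(2, 3, 224, 224)
patches = create_image_patches(images)                 # (2, 3, 196, 16, 16)
mask = sample_block_mask(num_patches=patches.size(2))  # contiguous block of 117 indices at 0.6 ratio
context_emb, target_emb = encode_context_target(patches, mask, ToyPatchEncoder())
# context_emb: (2, 196, 128); target_emb: (2, 117, 128). A real PredictorNetwork must map the
# context embeddings to predictions at the masked positions before compute_jepa_loss compares them.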
The fair comparison framework
The critical research question: does JEPA Attach improve data efficiency? Our proposed experimental design would use controlled A/B tests comparing LLM only baselines against JEPA Attach under identical compute budgets. Both models would see the same limited text, but JEPA Attach would also learn from abundant image data.
The hypothesis we're testing
Data efficiency prediction
If the hypothesis holds, we expect JEPA Attach to achieve lower perplexity when both models train on identical text corpora (1000 samples). The key test: does the language model learn better representations when it can also learn from abundant unlabeled images, even though we only evaluate on text? We haven't tested this yet.
Expected representation changes
We hypothesize that probing LLM hidden states after JEPA alignment would reveal an interesting phenomenon: the internal representations should become more grounded in visual semantic concepts. If successful, the model would develop richer abstractions that generalize beyond pure linguistic patterns, but this remains to be tested.
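One hypothetical way to run such a probe (not part of the prototype): pool frozen hidden states for sentences labeled with visual concepts, fit a linear classifier, and compare held-out accuracy before and after alignment. The feature and label arrays below are placeholders for data the experiment would have to collect.

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

def probe_accuracy(hidden_states: np.ndarray, concept_labels: np.ndarray) -> float:
    """Fit a linear probe on frozen LLM features and report held-out accuracy."""
    X_train, X_test, y_train, y_test = train_test_split(
        hidden_states, concept_labels, test_size=0.2, random_state=0
    )
    probe = LogisticRegression(max_iter=1000).fit(X_train, y_train)
    return probe.score(X_test, y_test)

# Compare probe_accuracy(baseline_features, labels) vs. probe_accuracy(jepa_features, labels)
# on the same sentences; a higher score after alignment would support the hypothesis.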
The research methodology
Controlled experimental design
Our planned comparison aims to isolate JEPA's contribution from confounding factors:
- Same compute budget: Identical training steps for both models
- Same text data: Both models see identical limited text corpora
- Same evaluation: Language modeling perplexity on held out text
- Only difference: JEPA Attach sees additional visual data
This controls for architecture size effects and focuses on the learning signal.
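A small shared config object (hypothetical names; the prototype may organize this differently) helps keep the budgets identical across the two arms. The setup and training code for the comparison follows.

from dataclasses import dataclass

@dataclass
class ExperimentConfig:
    """Shared knobs for the baseline and JEPA Attach runs (illustrative defaults)."""
    text_budget: int = 1000    # same limited text corpus for both models
    image_budget: int = 10000  # extra unlabeled images, seen only by JEPA Attach
    steps: int = 1000          # identical number of training steps
    lr: float = 5e-5           # same optimizer settings for both runs
    seed: int = 0              # fixed seed so both models see identical text batches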
from transformers import GPT2LMHeadModel, ViTModel

def setup_baseline_model(device):
    """Standard GPT-2 trained only on text."""
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    return model.to(device)

def setup_jepa_model(device):
    """GPT-2 + JEPA vision encoder + alignment layers."""
    text_model = GPT2LMHeadModel.from_pretrained("gpt2")  # LM head needed for the text loss below
    vision_encoder = ViTModel.from_pretrained("google/vit-base-patch16-224")
    predictor = PredictorNetwork(hidden_dim=768)
    projection_head = ProjectionHead(text_dim=768, vision_dim=768, joint_dim=256)
    return JEPAAttachModel(text_model, vision_encoder, predictor, projection_head).to(device)
def train_baseline_step(model, text_batch, optimizer):
    """Train standard GPT-2 on text only."""
    outputs = model(text_batch, labels=text_batch)
    loss = outputs.loss  # standard next-token prediction
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
def train_jepa_step(model, text_batch, image_batch, optimizer):
    """Train JEPA Attach: text modeling + visual prediction + alignment."""
    # 1. Language modeling loss (same as baseline)
    text_outputs = model.text_model(text_batch, labels=text_batch)
    lm_loss = text_outputs.loss
    # 2. JEPA visual prediction loss
    image_patches = create_image_patches(image_batch)
    mask_indices = sample_block_mask(image_patches.size(2))
    context_emb, target_emb = encode_context_target(image_patches, mask_indices, model.vision_encoder)
    jepa_loss = compute_jepa_loss(context_emb, target_emb, model.predictor)
    # 3. Text-vision alignment loss
    text_emb = model.text_model.transformer(text_batch).last_hidden_state.mean(dim=1)  # GPT-2 backbone hidden states
    vision_emb = model.vision_encoder(image_batch).last_hidden_state.mean(dim=1)
    alignment_loss = F.mse_loss(model.projection_head(text_emb), model.projection_head(vision_emb))
    # Combined objective
    total_loss = lm_loss + 0.1 * jepa_loss + 0.05 * alignment_loss
    total_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return total_loss.item()
def run_controlled_experiment(text_budget: int = 1000, image_budget: int = 10000, steps: int = 1000):
    """Full controlled experiment comparing baseline vs JEPA Attach."""
    # Initialize models and optimizers (learning rate is illustrative, identical for both arms)
    baseline_model = setup_baseline_model(device)
    jepa_model = setup_jepa_model(device)
    baseline_optimizer = torch.optim.AdamW(baseline_model.parameters(), lr=5e-5)
    jepa_optimizer = torch.optim.AdamW(jepa_model.parameters(), lr=5e-5)
    # Same text dataset for both models
    text_dataset = load_limited_text_data(budget=text_budget)
    image_dataset = load_image_data(budget=image_budget)  # only for the JEPA model
    # Training loop with identical compute budget
    for step in range(steps):
        # Same text batch for both models
        text_batch = next(text_dataset)
        # Train baseline (text only)
        baseline_loss = train_baseline_step(baseline_model, text_batch, baseline_optimizer)
        # Train JEPA Attach (same text + extra images)
        image_batch = next(image_dataset)
        jepa_loss = train_jepa_step(jepa_model, text_batch, image_batch, jepa_optimizer)
        # Update the JEPA target encoder via EMA
        update_target_encoder_ema(jepa_model)
    # Evaluate both on the same text-only test set
    test_text = load_test_text_data()
    baseline_ppl = evaluate_perplexity(baseline_model, test_text)
    jepa_ppl = evaluate_perplexity(jepa_model.text_model, test_text)
    return baseline_ppl, jepa_ppl, (baseline_ppl - jepa_ppl)
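The evaluate_perplexity helper is not shown above; a minimal version might look like the sketch below, assuming the test set is an iterable of pre-tokenized input_ids tensors (averaging per-batch mean losses is an approximation when batch sizes vary).

import math
import torch

@torch.no_grad()
def evaluate_perplexity(model, test_batches) -> float:
    """Average next-token loss over the test set, exponentiated into perplexity."""
    model.eval()
    total_loss, num_batches = 0.0, 0
    for input_ids in test_batches:
        outputs = model(input_ids, labels=input_ids)  # HF causal LMs return the mean token loss
        total_loss += outputs.loss.item()
        num_batches += 1
    return math.exp(total_loss / max(num_batches, 1))

# Example: baseline_ppl, jepa_ppl, gap = run_controlled_experiment()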
Scaling considerations
Starting with modest hardware requirements (Apple Silicon, small datasets), the architecture scales naturally to cloud GPUs and larger corpora. The mathematical formulation remains identical; only the data throughput changes.
What we expect to learn
The multimodal alignment hypothesis
If this approach works, the most interesting finding would be whether LLM representations can quickly align with visual semantics. The hypothesis is that cross modal learning creates meaningful bridges between language and vision that improve text only performance, but this needs empirical validation.
Implementation simplicity
Despite the complex motivation, the actual implementation is surprisingly clean: PyTorch with Hugging Face transformers, standard optimizers, and straightforward loss combinations. No exotic architectures or training tricks are required.
Apple Silicon capabilities
Modern Mac hardware with unified memory handles 200M+ parameter multimodal models gracefully. This democratizes JEPA research; you don't need datacenter resources to test the core hypotheses.
The bigger picture
This approach could represent a paradigm shift if the experiments validate the hypothesis. Instead of fighting over scarce text data or risking model collapse with synthetic generation, we might be able to tap into the vast streams of unlabeled multimodal content that grow daily.
The implications would extend beyond academic curiosity. If JEPA Attach proves effective at scale, it could suggest a path to continued AI capability growth that doesn't depend on finite human written text. But first, we need to run the controlled experiments to see if the core hypothesis holds.
What's next
First, we need to run the basic experiments to validate the core hypothesis. Then, if promising, the research agenda would scale to larger models, diverse datasets, and comprehensive baselines including RAG and synthetic data augmentation. The broader vision involves extending JEPA Attach to video, audio, and other modalities where human annotation is sparse but raw data is abundant.
The key insight driving this work: when you hit a wall, don't just push harder; find a different path. JEPA Attach might just be that path around the data wall, but we won't know until we test it.
References
Core Papers
- Training Compute-Optimal Large Language Models - Chinchilla scaling laws
- Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture - I-JEPA foundation
- The Curse of Recursion: Training on Generated Data Makes Models Forget - Model collapse analysis
- Will we run out of data? An analysis of the limits of scaling datasets in Machine Learning - Data wall timeline
Technical Resources
- Meta AI I-JEPA Implementation - Official PyTorch code
- Hugging Face Transformers - GPT-2 and ViT models
- PyTorch MPS Documentation - Apple Silicon optimization
Related Work
- CLIP: Learning Transferable Visual Models From Natural Language Supervision - Contrastive vision-language pretraining
- SimCLR: A Simple Framework for Contrastive Learning of Visual Representations - Self-supervised representation learning
- MAE: Masked Autoencoders Are Scalable Vision Learners - Masked prediction for vision