Pre-training World Models
Pre-training is the process of training world models on large-scale data so they learn general representations of world dynamics. This lesson covers common pre-training objectives, distributed training infrastructure, curriculum strategies, and how to monitor training.
Pre-training Objectives
1. Next Frame Prediction
```python
import torch
import torch.nn.functional as F

def next_frame_loss(model, video_batch):
    """Predict the next frame given all previous frames."""
    context = video_batch[:, :-1]  # All frames except the last
    target = video_batch[:, -1]    # The last frame
    prediction = model(context)
    loss = F.mse_loss(prediction, target)
    return loss
```
2. Video Reconstruction (Autoencoding)
```python
def reconstruction_loss(model, video_batch):
    """Encode the video into a latent and reconstruct it."""
    latent = model.encode(video_batch)
    reconstruction = model.decode(latent)
    loss = F.mse_loss(reconstruction, video_batch)
    return loss
```
3. Masked Video Modeling
```python
def masked_video_loss(model, video_batch, mask_ratio=0.75,
                      patch_size=(2, 16, 16)):
    """Predict masked spatiotemporal patches from the visible ones."""
    B, T, C, H, W = video_batch.shape
    patch_t, patch_h, patch_w = patch_size  # Illustrative default patch size
    num_patches = (T // patch_t) * (H // patch_h) * (W // patch_w)
    num_masked = int(num_patches * mask_ratio)

    # Random boolean mask over patches: True = masked (helper defined elsewhere)
    mask = create_random_mask(B, num_patches, num_masked)

    # Encode only the visible patches
    visible_patches = apply_mask(video_batch, ~mask)
    encoded = model.encoder(visible_patches)

    # Decoder predicts every patch position from the visible encodings
    predictions = model.decoder(encoded, mask)

    # Loss is computed only on the masked patches
    targets = extract_patches(video_batch, patch_size)
    loss = F.mse_loss(predictions[mask], targets[mask])
    return loss
```
4. Contrastive Learning
```python
def contrastive_loss(model, video_batch, temperature=0.07):
    """Learn representations via contrastive learning."""
    # Create two augmented views of the same clips (augment defined elsewhere)
    view1 = augment(video_batch)
    view2 = augment(video_batch)

    # Encode both views
    z1 = model.encode(view1)
    z2 = model.encode(view2)

    # L2-normalize the embeddings
    z1 = F.normalize(z1, dim=-1)
    z2 = F.normalize(z2, dim=-1)

    # Pairwise similarity matrix, scaled by temperature
    similarity = z1 @ z2.T / temperature

    # Labels: diagonal elements are the positive pairs
    labels = torch.arange(z1.shape[0], device=z1.device)
    loss = F.cross_entropy(similarity, labels)
    return loss
```
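In practice this InfoNCE-style loss is often symmetrized by also computing the cross-entropy over `similarity.T` (treating `z2` as the queries) and averaging the two terms; treat that as an optional refinement rather than part of the recipe above.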
Training Infrastructure
Distributed Training Setup
```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_distributed():
    """Initialize the process group and bind this process to its GPU."""
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    return local_rank

def train_distributed(model, dataloader, epochs):
    local_rank = setup_distributed()
    model = model.to(local_rank)
    model = DDP(model, device_ids=[local_rank])

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scaler = torch.cuda.amp.GradScaler()  # Mixed-precision loss scaling

    for epoch in range(epochs):
        for batch in dataloader:
            batch = batch.to(local_rank)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                loss = compute_loss(model, batch)  # One of the objectives above
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
```
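Each process drives a single GPU, so the script is normally started with a launcher such as `torchrun --nproc_per_node=<num_gpus> train.py` (assuming the script is saved as `train.py`), which sets the `LOCAL_RANK` environment variable read in `setup_distributed`. If the dataloader uses a `DistributedSampler`, call `sampler.set_epoch(epoch)` at the start of each epoch so shuffling differs across epochs.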
Gradient Checkpointing
For large models, use gradient checkpointing to reduce activation memory at the cost of recomputing activations during the backward pass:
```python
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class MemoryEfficientBlock(nn.Module):
    """Wraps a block so its activations are recomputed in the backward pass."""
    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, x):
        return checkpoint(self.block, x, use_reentrant=False)
```
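As a usage sketch, assuming the world model exposes its transformer blocks as an `nn.ModuleList` named `model.blocks` (a hypothetical attribute, not defined above), each block can be wrapped in place:

```python
# Hypothetical: wrap every transformer block so its activations are
# recomputed during the backward pass instead of being stored.
model.blocks = nn.ModuleList(
    MemoryEfficientBlock(block) for block in model.blocks
)
```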
Training Curriculum
Stage 1: Low Resolution
- Train on 256×256 video
- Faster iteration, learn basic dynamics
Stage 2: High Resolution
- Fine-tune on 512×512 or higher
- Learn fine details
Stage 3: Long Context
- Extend temporal context
- Learn long-range dependencies
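One minimal way to express such a curriculum is as a list of stage configurations that the training loop steps through. The field names and specific values below are illustrative, not prescribed by any particular framework:

```python
from dataclasses import dataclass

@dataclass
class CurriculumStage:
    name: str
    resolution: int   # Spatial resolution (square frames)
    num_frames: int   # Temporal context length
    steps: int        # Training steps spent in this stage

# Illustrative schedule mirroring the three stages above
CURRICULUM = [
    CurriculumStage("low_res", resolution=256, num_frames=16, steps=300_000),
    CurriculumStage("high_res", resolution=512, num_frames=16, steps=100_000),
    CurriculumStage("long_context", resolution=512, num_frames=64, steps=50_000),
]

for stage in CURRICULUM:
    # Rebuild the dataloader at the stage's resolution and context length,
    # then train for stage.steps steps (both omitted here).
    ...
```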
Hyperparameters
| Parameter | Typical Value | Notes |
|---|---|---|
| Learning Rate | 1e-4 to 3e-4 | With warmup |
| Batch Size | 256-1024 | Global (per-GPU batch × num GPUs) |
| Weight Decay | 0.01-0.1 | AdamW |
| Warmup Steps | 1000-10000 | Linear warmup |
| Total Steps | 100K-1M | Depends on data size |
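As a sketch of how these values combine, assuming `model` is the world model being trained and using mid-range values from the table (cosine decay after the linear warmup is a common choice, though the table does not mandate a particular decay schedule):

```python
import math

import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)

warmup_steps = 5_000
total_steps = 300_000

def lr_lambda(step):
    # Linear warmup, then cosine decay to zero over the remaining steps
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```

The scheduler's `step()` is then called once per optimization step inside the training loop.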
Monitoring Training
```python
import wandb

def log_training_metrics(step, loss, model, optimizer, sample_batch):
    # Log scalar metrics every step
    wandb.log({
        "loss": loss,
        "learning_rate": get_lr(optimizer),
        "gradient_norm": compute_grad_norm(model),
    }, step=step)

    # Log sample generations periodically
    if step % 1000 == 0:
        with torch.no_grad():
            samples = model.generate(sample_batch[:4])
        # wandb.Video expects a (time, channel, height, width) uint8 array or a file path
        wandb.log({"samples": wandb.Video(samples)}, step=step)
```
Summary
Pre-training world models requires careful choice of objectives, efficient infrastructure, and proper monitoring. The goal is to learn general representations that can be fine-tuned for specific applications.