Pre-training World Models

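Pre-training is the process of training world models on large-scale data to learn general representations of world dynamics. This lesson covers common pre-training objectives, distributed training infrastructure, a resolution and context curriculum, typical hyperparameters, and training monitoring.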

Pre-training Objectives

1. Next Frame Prediction

python
import torch.nn.functional as F

def next_frame_loss(model, video_batch):
    """Predict the next frame given previous frames.

    video_batch: (B, T, C, H, W); model maps (B, T-1, C, H, W) -> (B, C, H, W)
    """
    context = video_batch[:, :-1]  # All frames except the last
    target = video_batch[:, -1]    # Last frame
    
    prediction = model(context)
    loss = F.mse_loss(prediction, target)
    return loss
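The loss above leaves the model abstract. A minimal sketch of a model that fits its interface makes the expected tensor shapes concrete. `TinyNextFramePredictor` is a hypothetical toy that stacks the context frames along the channel axis and applies two convolutions; real world models use much larger spatiotemporal architectures.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyNextFramePredictor(nn.Module):
    """Hypothetical toy model: (B, T, C, H, W) context -> (B, C, H, W) next frame."""
    def __init__(self, num_context_frames, channels=3, hidden=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(num_context_frames * channels, hidden, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden, channels, kernel_size=3, padding=1),
        )

    def forward(self, context):
        B, T, C, H, W = context.shape
        # Fold the time axis into channels so a 2D conv sees all context frames at once
        return self.net(context.reshape(B, T * C, H, W))

def next_frame_loss(model, video_batch):
    context = video_batch[:, :-1]  # (B, T-1, C, H, W)
    target = video_batch[:, -1]    # (B, C, H, W)
    return F.mse_loss(model(context), target)

video = torch.rand(2, 8, 3, 32, 32)  # (B, T, C, H, W) random stand-in for real clips
model = TinyNextFramePredictor(num_context_frames=7)
loss = next_frame_loss(model, video)
loss.backward()
```

Because the number of context frames is baked into the first convolution, this toy only accepts a fixed clip length; attention-based models avoid that restriction.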

2. Video Reconstruction (Autoencoding)

python
import torch.nn.functional as F

def reconstruction_loss(model, video_batch):
    """Encode video into a latent and reconstruct it"""
    latent = model.encode(video_batch)
    reconstruction = model.decode(latent)
    
    loss = F.mse_loss(reconstruction, video_batch)
    return loss
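The reconstruction objective assumes a model with `encode` and `decode` methods. The sketch below is a hypothetical per-frame convolutional autoencoder that satisfies that interface: it folds time into the batch axis and downsamples each frame by 4× spatially. It does not compress along time, which practical video tokenizers usually also do.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyVideoAutoencoder(nn.Module):
    """Hypothetical per-frame autoencoder exposing encode()/decode() on (B, T, C, H, W)."""
    def __init__(self, channels=3, latent_channels=8):
        super().__init__()
        # Each stride-2 conv halves H and W; each transposed conv doubles them back
        self.encoder = nn.Sequential(
            nn.Conv2d(channels, 32, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, latent_channels, 4, stride=2, padding=1),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(latent_channels, 32, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(32, channels, 4, stride=2, padding=1),
        )

    def encode(self, video):
        B, T, C, H, W = video.shape
        z = self.encoder(video.reshape(B * T, C, H, W))  # fold time into batch
        return z.reshape(B, T, *z.shape[1:])             # (B, T, latent_C, H/4, W/4)

    def decode(self, latent):
        B, T = latent.shape[:2]
        x = self.decoder(latent.reshape(B * T, *latent.shape[2:]))
        return x.reshape(B, T, *x.shape[1:])             # (B, T, C, H, W)

model = TinyVideoAutoencoder()
video = torch.rand(2, 4, 3, 32, 32)
latent = model.encode(video)
recon = model.decode(latent)
loss = F.mse_loss(recon, video)
```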

3. Masked Video Modeling

python
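import torch.nn.functional as F

def masked_video_loss(model, video_batch, mask_ratio=0.75,
                      patch_t=2, patch_h=16, patch_w=16):
    """Predict masked spatiotemporal patches in video (patch sizes are example values)"""
    B, T, C, H, W = video_batch.shape
    patch_size = dict(patch_t=patch_t, patch_h=patch_h, patch_w=patch_w)
    num_patches = (T // patch_t) * (H // patch_h) * (W // patch_w)
    num_masked = int(num_patches * mask_ratio)
    
    # Random boolean mask of shape (B, num_patches); True marks a masked patch.
    # create_random_mask, apply_mask and extract_patches are helpers not defined here.
    mask = create_random_mask(B, num_patches, num_masked, device=video_batch.device)
    
    # Encode only the visible patches
    visible_patches = apply_mask(video_batch, ~mask, **patch_size)
    encoded = model.encoder(visible_patches)
    
    # Decoder predicts every patch position: (B, num_patches, patch_dim)
    predictions = model.decoder(encoded, mask)
    
    # Loss only on masked patches
    targets = extract_patches(video_batch, **patch_size)[mask]
    loss = F.mse_loss(predictions[mask], targets)
    return loss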
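The loss above calls `create_random_mask`, `apply_mask` and `extract_patches` without defining them. One possible sketch of those helpers follows. The signatures and the default 2×16×16 patch (tubelet) size are assumptions for illustration. `apply_mask` relies on every row masking the same number of patches, which `create_random_mask` guarantees.

```python
import torch

def extract_patches(video, patch_t=2, patch_h=16, patch_w=16):
    """(B, T, C, H, W) -> (B, num_patches, patch_t * C * patch_h * patch_w)."""
    B, T, C, H, W = video.shape
    nt, nh, nw = T // patch_t, H // patch_h, W // patch_w
    x = video[:, :nt * patch_t, :, :nh * patch_h, :nw * patch_w]  # drop remainders
    x = x.reshape(B, nt, patch_t, C, nh, patch_h, nw, patch_w)
    # Put the patch-grid axes first, then the axes inside each patch
    x = x.permute(0, 1, 4, 6, 2, 3, 5, 7)
    return x.reshape(B, nt * nh * nw, patch_t * C * patch_h * patch_w)

def create_random_mask(batch_size, num_patches, num_masked, device=None):
    """Boolean mask (B, num_patches) with exactly num_masked True entries per row."""
    noise = torch.rand(batch_size, num_patches, device=device)
    ranks = noise.argsort(dim=1).argsort(dim=1)  # random rank of each patch in its row
    return ranks < num_masked

def apply_mask(video, visible, **patch_kwargs):
    """Gather visible patches into (B, num_visible, patch_dim)."""
    patches = extract_patches(video, **patch_kwargs)
    B, N, D = patches.shape
    # Boolean indexing flattens rows in order; equal counts per row allow the reshape
    return patches[visible].reshape(B, -1, D)

video = torch.rand(2, 4, 3, 32, 32)
mask = create_random_mask(2, 8, 6)   # 8 patches per clip, 75% masked
visible = apply_mask(video, ~mask)
```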

4. Contrastive Learning

python
import torch
import torch.nn.functional as F

def contrastive_loss(model, video_batch, temperature=0.07):
    """Learn representations via contrastive learning (InfoNCE)"""
    # Create two augmented views (augment is a user-supplied augmentation function)
    view1 = augment(video_batch)
    view2 = augment(video_batch)
    
    # Encode both views; embeddings must be pooled to shape (B, D)
    z1 = model.encode(view1)
    z2 = model.encode(view2)
    
    # Normalize so the dot product is cosine similarity
    z1 = F.normalize(z1, dim=-1)
    z2 = F.normalize(z2, dim=-1)
    
    # Compute (B, B) similarity matrix
    similarity = z1 @ z2.T / temperature
    
    # Labels: diagonal elements are positives; other clips in the batch are negatives
    labels = torch.arange(z1.shape[0], device=z1.device)
    
    loss = F.cross_entropy(similarity, labels)
    return loss
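The loss above only scores the view1 → view2 direction. A common variant (used, for example, by CLIP) averages the cross-entropy over both directions. The sketch below takes already-pooled `(B, D)` embeddings, so it can be checked without a model. The sanity inputs are one-hot vectors: perfectly aligned pairs should give a loss near zero, and shifted pairs a large one.

```python
import torch
import torch.nn.functional as F

def symmetric_info_nce(z1, z2, temperature=0.07):
    """Symmetric InfoNCE on (B, D) embeddings; row i of z1 and row i of z2 are positives."""
    z1 = F.normalize(z1, dim=-1)
    z2 = F.normalize(z2, dim=-1)
    logits = z1 @ z2.T / temperature  # (B, B) scaled cosine similarities
    labels = torch.arange(z1.shape[0], device=z1.device)
    # Average the view1 -> view2 (rows) and view2 -> view1 (columns) losses
    return 0.5 * (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels))

aligned = torch.eye(4)                  # each clip's two views map to the same vector
shifted = aligned.roll(1, dims=0)       # every positive pair is mismatched
low = symmetric_info_nce(aligned, aligned)
high = symmetric_info_nce(aligned, shifted)
```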

Training Infrastructure

Distributed Training Setup

python
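import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_distributed():
    # Expects launch via torchrun, which sets LOCAL_RANK (plus RANK and WORLD_SIZE)
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    return local_rank

def train_distributed(model, dataloader, epochs, loss_fn):
    # loss_fn is one of the objectives above, called as loss_fn(model, batch).
    # DDP wraps only forward(): objectives that call model.encode/decode need those
    # calls routed through forward(), because DDP relies on forward() to sync gradients.
    local_rank = setup_distributed()
    model = model.to(local_rank)
    model = DDP(model, device_ids=[local_rank])
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    # fp16 loss scaling (older PyTorch releases: torch.cuda.amp.GradScaler())
    scaler = torch.amp.GradScaler('cuda')
    
    for epoch in range(epochs):
        # With a DistributedSampler, reshuffle each rank's shard every epoch
        if hasattr(dataloader.sampler, 'set_epoch'):
            dataloader.sampler.set_epoch(epoch)
        for batch in dataloader:
            batch = batch.to(local_rank, non_blocking=True)
            
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                loss = loss_fn(model, batch)
            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
    
    dist.destroy_process_group()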
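The training loop assumes each rank reads a different shard of the data. One way to arrange that is `torch.utils.data.distributed.DistributedSampler`. The helper below is a hypothetical sketch: the function name and its defaults are illustrative. Passing `num_replicas` and `rank` explicitly lets the sharding be checked without starting a process group; in real training, leave them as `None` after calling `setup_distributed()`.

```python
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def make_distributed_loader(dataset, per_gpu_batch_size, num_replicas=None,
                            rank=None, num_workers=4):
    """Shard `dataset` across ranks; returns (loader, sampler)."""
    # With num_replicas/rank left as None, DistributedSampler reads them from the
    # initialized process group, so set up torch.distributed first.
    sampler = DistributedSampler(dataset, num_replicas=num_replicas, rank=rank,
                                 shuffle=True)
    loader = DataLoader(
        dataset,
        batch_size=per_gpu_batch_size,  # global batch = this x world size
        sampler=sampler,                # the sampler shuffles; no shuffle=True here
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        drop_last=True,                 # every batch has exactly per_gpu_batch_size items
    )
    return loader, sampler
```

Because shuffling is seeded with the epoch number, call `sampler.set_epoch(epoch)` at the start of every epoch, as the loop above does, or each epoch sees the same order.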

Gradient Checkpointing

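For large models, gradient checkpointing reduces memory by discarding the activations inside each wrapped block after the forward pass and recomputing them during the backward pass, trading extra compute for lower memory use: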

python
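import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class MemoryEfficientBlock(nn.Module):
    """Wraps a block so its activations are recomputed in backward instead of stored."""
    def __init__(self, block):
        super().__init__()
        self.block = block
    
    def forward(self, x):
        # use_reentrant=False is the non-reentrant variant recommended by PyTorch
        return checkpoint(self.block, x, use_reentrant=False)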
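In practice the wrapper is applied to every block of a deep stack. The sketch below is a hypothetical `CheckpointedStack` that checkpoints each block only when gradients are being recorded, since there is nothing to save at inference time. It then checks that checkpointing changes memory use but not the math: the gradients match an identical un-checkpointed copy. The non-reentrant mode also produces parameter gradients when the input itself does not require grad, as here.

```python
import copy

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedStack(nn.Module):
    """Hypothetical wrapper: runs blocks in sequence, checkpointing each during training."""
    def __init__(self, blocks):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.blocks:
            if self.training and torch.is_grad_enabled():
                # Block activations are freed after forward and recomputed in backward
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x

torch.manual_seed(0)
blocks = [nn.Sequential(nn.Linear(16, 16), nn.GELU()) for _ in range(4)]
plain = nn.Sequential(*copy.deepcopy(blocks))  # same weights, no checkpointing
stack = CheckpointedStack(blocks)

x = torch.randn(8, 16)
stack(x).sum().backward()
plain(x).sum().backward()
```

Blocks that use randomness (for example dropout) stay consistent because `checkpoint` preserves the RNG state by default when recomputing.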

Training Curriculum

Stage 1: Low Resolution

  • Train on 256×256 video
  • Faster iteration, learn basic dynamics

Stage 2: High Resolution

  • Fine-tune on 512×512 or higher
  • Learn fine details

Stage 3: Long Context

  • Extend temporal context
  • Learn long-range dependencies
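A data pipeline needs to know which stage is active at a given step so it can pick the clip resolution and length. A minimal sketch of such a schedule follows. The `Stage` class, the step boundaries and the frame counts are hypothetical placeholders, not recommended values.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Stage:
    name: str
    start_step: int
    resolution: int  # square frame size in pixels
    num_frames: int  # clip length (temporal context)

# Hypothetical three-stage curriculum mirroring the stages above
CURRICULUM = [
    Stage("low_res", 0, 256, 16),
    Stage("high_res", 200_000, 512, 16),
    Stage("long_context", 400_000, 512, 64),
]

def stage_for_step(step, curriculum=CURRICULUM):
    """Return the last stage whose start_step has been reached."""
    current = curriculum[0]
    for stage in curriculum:
        if step >= stage.start_step:
            current = stage
    return current
```

Keeping the schedule as data rather than branching in the training loop makes it easy to resume from a checkpoint at any step and rebuild the matching data loader.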

Hyperparameters

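| Parameter | Typical Value | Notes |
|---|---|---|
| Learning Rate | 1e-4 to 3e-4 | With warmup |
| Batch Size | 256-1024 | Global batch (per-GPU batch × number of GPUs) |
| Weight Decay | 0.01-0.1 | AdamW |
| Warmup Steps | 1000-10000 | Linear warmup |
| Total Steps | 100K-1M | Depends on data size |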
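The table specifies linear warmup but not what happens afterwards. The sketch below pairs linear warmup with cosine decay to a floor of 10% of the peak learning rate. The cosine shape and the `min_ratio` floor are assumptions, and `warmup_cosine_lambda` is a hypothetical helper. It plugs into PyTorch's `LambdaLR`, which multiplies the optimizer's base learning rate by the returned factor.

```python
import math

import torch

def warmup_cosine_lambda(warmup_steps, total_steps, min_ratio=0.1):
    """Return step -> LR multiplier: linear warmup, then cosine decay to min_ratio."""
    def lr_lambda(step):
        if step < warmup_steps:
            return (step + 1) / warmup_steps  # reaches 1.0 at the last warmup step
        progress = min((step - warmup_steps) / max(1, total_steps - warmup_steps), 1.0)
        return min_ratio + (1 - min_ratio) * 0.5 * (1 + math.cos(math.pi * progress))
    return lr_lambda

model = torch.nn.Linear(8, 8)  # stand-in for a world model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, warmup_cosine_lambda(warmup_steps=2000, total_steps=200_000))
# In the training loop, call scheduler.step() once after each optimizer.step()
```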

Monitoring Training

python
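import torch
import wandb

def log_training_metrics(step, loss, model, optimizer, sample_batch):
    # Call after backward() and before zero_grad() so gradients are still populated.
    # In distributed runs, call this on rank 0 only.
    grads = [p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None]
    grad_norm = torch.norm(torch.stack(grads), 2).item() if grads else 0.0

    # Log scalar metrics
    wandb.log({
        "loss": loss.item(),
        "learning_rate": optimizer.param_groups[0]["lr"],
        "gradient_norm": grad_norm,
    }, step=step)
    
    # Log sample generations periodically
    if step % 1000 == 0:
        with torch.no_grad():
            # model.generate is the model's own sampler, assumed to return
            # frames in [0, 1] with shape (B, T, C, H, W)
            samples = model.generate(sample_batch[:4])
        # wandb.Video expects uint8 frames shaped (T, C, H, W) or (B, T, C, H, W)
        frames = (samples.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy()
        wandb.log({"samples": wandb.Video(frames, fps=4)}, step=step)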

Summary

Pre-training world models requires careful choice of objectives, efficient infrastructure, and proper monitoring. The goal is to learn general representations that can be fine-tuned for specific applications.