Diffusion Models for World Generation

Diffusion models have emerged as the leading approach for generating high-quality video in world models. They learn to reverse a noise process, enabling generation of realistic, temporally coherent sequences.

How Diffusion Works

The diffusion process has two phases:

Forward Process (Adding Noise)

Clean Data x₀ → x₁ → x₂ → ... → xₜ → Pure Noise
                (gradually add noise)

Reverse Process (Denoising)

Pure Noise xₜ → xₜ₋₁ → ... → x₁ → x₀ Clean Data
                (learned denoising)

Mathematical Foundation

Forward Process

The forward process adds Gaussian noise according to a schedule:

python
import torch

def forward_diffusion(x_0, t, noise_schedule):
    """Add noise to data according to timestep t"""
    # Reshape alpha_bar[t] from (B,) so it broadcasts over the data dims,
    # e.g. to (B, 1, 1, 1, 1) for video tensors of shape (B, C, T, H, W)
    alpha_bar = noise_schedule.alpha_bar[t].view(-1, *([1] * (x_0.dim() - 1)))
    noise = torch.randn_like(x_0)
    
    # x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * noise
    x_t = torch.sqrt(alpha_bar) * x_0 + torch.sqrt(1 - alpha_bar) * noise
    return x_t, noise
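
The `noise_schedule` object is assumed to expose `alpha_bar` (the cumulative product of the per-step signal fractions) and `num_steps`. A minimal sketch using the common linear beta schedule (the class name and default values here are illustrative, not canonical):

```python
import torch

class NoiseSchedule:
    """Linear beta schedule; alpha_bar[t] is the fraction of the original
    signal surviving after t noising steps."""

    def __init__(self, num_steps=1000, beta_start=1e-4, beta_end=0.02):
        self.num_steps = num_steps
        self.beta = torch.linspace(beta_start, beta_end, num_steps)
        self.alpha = 1.0 - self.beta
        # alpha_bar[t] = prod_{s<=t} alpha_s; decays from ~1 toward ~0
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
```

Because `alpha_bar` decays monotonically toward zero, `x_t` interpolates from the clean sample at t = 0 to nearly pure Gaussian noise at t = num_steps - 1.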

Reverse Process (Training Objective)

The model learns to predict the noise:

python
def diffusion_loss(model, x_0, noise_schedule):
    """Simple diffusion training loss"""
    batch_size = x_0.shape[0]
    
    # Sample one random timestep per example, on the same device as the data
    t = torch.randint(0, noise_schedule.num_steps, (batch_size,), device=x_0.device)
    
    # Add noise
    x_t, noise = forward_diffusion(x_0, t, noise_schedule)
    
    # Predict noise
    noise_pred = model(x_t, t)
    
    # MSE loss
    loss = torch.nn.functional.mse_loss(noise_pred, noise)
    return loss
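
To make the objective concrete, here is a self-contained toy training loop on flat vectors. The schedule, denoiser, and dimensions are illustrative stand-ins, not the document's classes; a real setup would use the video U-Net described below.

```python
import torch
import torch.nn as nn

class ToySchedule:
    """Illustrative linear schedule (stand-in for the noise_schedule above)."""
    def __init__(self, num_steps=1000):
        beta = torch.linspace(1e-4, 0.02, num_steps)
        self.num_steps = num_steps
        self.alpha_bar = torch.cumprod(1.0 - beta, dim=0)

def toy_loss(model, x_0, schedule):
    b = x_0.shape[0]
    t = torch.randint(0, schedule.num_steps, (b,))
    # Reshape alpha_bar[t] to (b, 1) so it broadcasts over feature dims
    a = schedule.alpha_bar[t].view(b, *([1] * (x_0.dim() - 1)))
    noise = torch.randn_like(x_0)
    x_t = torch.sqrt(a) * x_0 + torch.sqrt(1 - a) * noise
    # The model is trained to recover the noise that was just added
    return nn.functional.mse_loss(model(x_t, t), noise)

class ToyDenoiser(nn.Module):
    """Stand-in for the video U-Net: one linear layer that ignores t."""
    def __init__(self, dim=8):
        super().__init__()
        self.net = nn.Linear(dim, dim)
    def forward(self, x_t, t):
        return self.net(x_t)

model = ToyDenoiser()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
schedule = ToySchedule()
x_0 = torch.randn(16, 8)

for _ in range(10):
    optimizer.zero_grad()
    loss = toy_loss(model, x_0, schedule)
    loss.backward()
    optimizer.step()
```

The same three steps (sample timesteps, noise the data, regress the noise) apply unchanged when the data is video and the denoiser is a 3D U-Net.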

Video Diffusion Architecture

python
import torch
import torch.nn as nn

class VideoDiffusionModel(nn.Module):
    """3D U-Net for video diffusion.

    ResBlock3D, SpatioTemporalAttention, Downsample3D and SinusoidalPosEmb
    are assumed to be defined elsewhere; blocks that need the time embedding
    take it as a second argument.
    """
    
    def __init__(self, in_channels=3, model_channels=256, num_res_blocks=2):
        super().__init__()
        self.model_channels = model_channels
        
        # Time embedding
        self.time_embed = nn.Sequential(
            SinusoidalPosEmb(model_channels),
            nn.Linear(model_channels, model_channels * 4),
            nn.GELU(),
            nn.Linear(model_channels * 4, model_channels * 4)
        )
        
        # Encoder
        self.encoder = nn.ModuleList([
            ResBlock3D(in_channels, model_channels),
            ResBlock3D(model_channels, model_channels),
            Downsample3D(model_channels),
            ResBlock3D(model_channels, model_channels * 2),
            SpatioTemporalAttention(model_channels * 2),
            Downsample3D(model_channels * 2),
        ])
        
        # Middle: a ModuleList rather than nn.Sequential, because
        # nn.Sequential cannot forward the extra time-embedding argument
        self.middle = nn.ModuleList([
            ResBlock3D(model_channels * 2, model_channels * 2),
            SpatioTemporalAttention(model_channels * 2),
            ResBlock3D(model_channels * 2, model_channels * 2),
        ])
        
        # Decoder (symmetric to encoder)
        self.decoder = nn.ModuleList([...])
        
        # Output
        self.out = nn.Conv3d(model_channels, in_channels, 3, padding=1)
    
    def forward(self, x, t, condition=None):
        # x: (B, C, T, H, W)
        # t: (B,) timesteps
        
        t_emb = self.time_embed(t)
        if condition is not None:
            # Conditioning embeddings (e.g. actions) are added to the
            # time embedding
            t_emb = t_emb + condition
        
        # Encoder with skip connections
        skips = []
        for layer in self.encoder:
            x = layer(x, t_emb)
            skips.append(x)
        
        # Middle
        for layer in self.middle:
            x = layer(x, t_emb)
        
        # Decoder
        for layer in self.decoder:
            x = torch.cat([x, skips.pop()], dim=1)
            x = layer(x, t_emb)
        
        return self.out(x)

Conditioning Mechanisms

World models need to condition generation on various inputs:

1. Text Conditioning (Cross-Attention)

python
import math

class CrossAttention(nn.Module):
    def __init__(self, dim, context_dim):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(context_dim, dim)
        self.to_v = nn.Linear(context_dim, dim)
        self.to_out = nn.Linear(dim, dim)
    
    def forward(self, x, context):
        # Queries come from the video tokens; keys and values come from
        # the context (e.g. text embeddings)
        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)
        
        # Scaled dot-product attention over the context tokens
        attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1]), dim=-1)
        out = attn @ v
        return self.to_out(out)
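
As a shape sanity check, the module can be exercised on dummy tensors. The class is restated here so the snippet runs on its own, and the dimensions (64-d video tokens, 77 text tokens of width 32) are arbitrary illustrative choices:

```python
import math
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, dim, context_dim):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(context_dim, dim)
        self.to_v = nn.Linear(context_dim, dim)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x, context):
        q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
        attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1]), dim=-1)
        return self.to_out(attn @ v)

# 2 videos, 1024 spatio-temporal tokens of dim 64; 77 text tokens of dim 32
layer = CrossAttention(dim=64, context_dim=32)
x = torch.randn(2, 1024, 64)
text = torch.randn(2, 77, 32)
out = layer(x, text)
```

The output keeps the query shape `(2, 1024, 64)`: every video token is updated with information gathered from the text tokens, regardless of how many text tokens there are.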

2. Action Conditioning

python
class ActionConditionedDiffusion(nn.Module):
    def __init__(self, base_model, action_dim):
        super().__init__()
        self.base_model = base_model
        # Project actions to the width of the base model's time embedding
        self.action_encoder = nn.Sequential(
            nn.Linear(action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, base_model.model_channels * 4)
        )
    
    def forward(self, x, t, action):
        # Pass the action embedding as conditioning; the base model adds it
        # to its own time embedding rather than receiving it in place of t
        action_emb = self.action_encoder(action)
        return self.base_model(x, t, condition=action_emb)

Sampling Strategies

DDPM Sampling (Slow but High Quality)

python
@torch.no_grad()
def ddpm_sample(model, shape, noise_schedule, num_steps=1000):
    x = torch.randn(shape)
    
    for t in reversed(range(num_steps)):
        t_batch = torch.full((shape[0],), t)
        noise_pred = model(x, t_batch)
        x = denoise_step(x, noise_pred, t, noise_schedule)
    
    return x
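
The `denoise_step` helper is used above but not defined. A sketch following the standard DDPM update, where the schedule is assumed to expose `beta` and `alpha_bar` tensors indexed by timestep:

```python
import torch

def denoise_step(x_t, noise_pred, t, schedule):
    """One reverse step following the standard DDPM update (a sketch of
    the helper assumed above; the schedule attributes are assumptions)."""
    beta_t = schedule.beta[t]
    alpha_t = 1.0 - beta_t
    alpha_bar_t = schedule.alpha_bar[t]

    # Posterior mean: subtract the predicted noise contribution, then rescale
    mean = (x_t - beta_t / torch.sqrt(1.0 - alpha_bar_t) * noise_pred) / torch.sqrt(alpha_t)

    if t == 0:
        return mean  # no fresh noise at the final step
    # Simple variance choice sigma_t^2 = beta_t from the DDPM paper
    return mean + torch.sqrt(beta_t) * torch.randn_like(x_t)
```

Each step removes a little of the predicted noise and re-injects a smaller amount of fresh noise, which is why the full 1000-step loop is slow.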

DDIM Sampling (Fast)

python
@torch.no_grad()
def ddim_sample(model, shape, noise_schedule, num_steps=50):
    x = torch.randn(shape)
    # Evenly spaced subset of the full schedule, from most to least noisy
    timesteps = torch.linspace(noise_schedule.num_steps - 1, 0, num_steps).long()
    
    for i, t in enumerate(timesteps[:-1]):
        t_next = timesteps[i + 1]
        noise_pred = model(x, t.expand(shape[0]))
        x = ddim_step(x, noise_pred, t, t_next, noise_schedule)
    
    return x
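
The `ddim_step` helper can likewise be sketched. With eta = 0 the update is deterministic: estimate the clean sample implied by the current noise prediction, then re-noise it directly to the next (smaller) noise level, which is what allows skipping most timesteps. As above, `schedule.alpha_bar` is an assumed attribute:

```python
import torch

def ddim_step(x_t, noise_pred, t, t_next, schedule):
    """One deterministic DDIM update (eta = 0); a sketch of the helper
    assumed above."""
    alpha_bar_t = schedule.alpha_bar[t]
    alpha_bar_next = schedule.alpha_bar[t_next]

    # Estimate the clean sample implied by the current noise prediction
    x0_pred = (x_t - torch.sqrt(1.0 - alpha_bar_t) * noise_pred) / torch.sqrt(alpha_bar_t)

    # Jump directly to the less noisy timestep t_next
    return torch.sqrt(alpha_bar_next) * x0_pred + torch.sqrt(1.0 - alpha_bar_next) * noise_pred
```

Because no fresh noise is added, the same initial noise always yields the same sample, and 50 steps often suffice where DDPM needs 1000.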

Summary

Diffusion models provide a powerful framework for video generation in world models. By learning to denoise data, they can generate high-quality, temporally coherent video conditioned on text, actions, or other inputs.

Knowledge Check

Test your understanding with 1 question