Diffusion Models for World Generation
30 min
Diffusion Models for World Generation
Diffusion models have emerged as the leading approach for generating high-quality video in world models. They learn to reverse a noise process, enabling generation of realistic, temporally coherent sequences.
How Diffusion Works
The diffusion process has two phases:
Forward Process (Adding Noise)
Clean Data x₀ → x₁ → x₂ → ... → xₜ → Pure Noise
(gradually add noise)
Clean Data x₀ → x₁ → x₂ → ... → xₜ → Pure Noise
(gradually add noise)
Reverse Process (Denoising)
Pure Noise xₜ → xₜ₋₁ → ... → x₁ → x₀ Clean Data
(learned denoising)
Pure Noise xₜ → xₜ₋₁ → ... → x₁ → x₀ Clean Data
(learned denoising)
Mathematical Foundation
Forward Process
The forward process adds Gaussian noise according to a schedule:
python
import torch
def forward_diffusion(x_0, t, noise_schedule):
    """Diffuse clean data x_0 directly to timestep t in one jump.

    Implements the closed-form forward process
    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.

    Args:
        x_0: clean data tensor, shape (B, ...).
        t: timestep index — a scalar or a (B,) batch of indices.
        noise_schedule: object exposing ``alpha_bar``, a 1-D tensor of
            cumulative alpha products.

    Returns:
        (x_t, noise): the noised sample and the Gaussian noise added.
    """
    alpha_bar = noise_schedule.alpha_bar[t]
    # For batched t, alpha_bar is (B,): reshape to (B, 1, 1, ...) so the
    # per-sample coefficients broadcast against x_0 of any rank. The
    # original code failed here for batched t (shape mismatch).
    if alpha_bar.dim() > 0:
        alpha_bar = alpha_bar.view(-1, *([1] * (x_0.dim() - 1)))
    noise = torch.randn_like(x_0)
    x_t = torch.sqrt(alpha_bar) * x_0 + torch.sqrt(1 - alpha_bar) * noise
    return x_t, noise
import torch

def forward_diffusion(x_0, t, noise_schedule):
    """Diffuse clean data x_0 directly to timestep t in one jump.

    Implements the closed-form forward process
    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.

    Args:
        x_0: clean data tensor, shape (B, ...).
        t: timestep index — a scalar or a (B,) batch of indices.
        noise_schedule: object exposing ``alpha_bar``, a 1-D tensor of
            cumulative alpha products.

    Returns:
        (x_t, noise): the noised sample and the Gaussian noise added.
    """
    alpha_bar = noise_schedule.alpha_bar[t]
    # For batched t, alpha_bar is (B,): reshape to (B, 1, 1, ...) so the
    # per-sample coefficients broadcast against x_0 of any rank. The
    # original code failed here for batched t (shape mismatch).
    if alpha_bar.dim() > 0:
        alpha_bar = alpha_bar.view(-1, *([1] * (x_0.dim() - 1)))
    noise = torch.randn_like(x_0)
    x_t = torch.sqrt(alpha_bar) * x_0 + torch.sqrt(1 - alpha_bar) * noise
    return x_t, noise
Reverse Process (Training Objective)
The model learns to predict the noise:
python
def diffusion_loss(model, x_0, noise_schedule):
    """Simple epsilon-prediction diffusion training loss.

    Samples one random timestep per batch element, noises x_0 to that
    timestep, and penalises the model's noise prediction with MSE.

    Args:
        model: callable (x_t, t) -> predicted noise, same shape as x_t.
        x_0: clean data batch, shape (B, ...).
        noise_schedule: object exposing ``num_steps`` and ``alpha_bar``.

    Returns:
        Scalar MSE loss tensor.
    """
    batch_size = x_0.shape[0]
    # Sample timesteps on the data's device; the original allocated them
    # on CPU, which breaks (or silently syncs) when x_0 lives on GPU.
    t = torch.randint(0, noise_schedule.num_steps, (batch_size,), device=x_0.device)
    # Add noise via the closed-form forward process.
    x_t, noise = forward_diffusion(x_0, t, noise_schedule)
    # Predict the noise that was added.
    noise_pred = model(x_t, t)
    # MSE between predicted and true noise.
    loss = torch.nn.functional.mse_loss(noise_pred, noise)
    return loss
def diffusion_loss(model, x_0, noise_schedule):
    """Simple epsilon-prediction diffusion training loss.

    Samples one random timestep per batch element, noises x_0 to that
    timestep, and penalises the model's noise prediction with MSE.

    Args:
        model: callable (x_t, t) -> predicted noise, same shape as x_t.
        x_0: clean data batch, shape (B, ...).
        noise_schedule: object exposing ``num_steps`` and ``alpha_bar``.

    Returns:
        Scalar MSE loss tensor.
    """
    batch_size = x_0.shape[0]
    # Sample timesteps on the data's device; the original allocated them
    # on CPU, which breaks (or silently syncs) when x_0 lives on GPU.
    t = torch.randint(0, noise_schedule.num_steps, (batch_size,), device=x_0.device)
    # Add noise via the closed-form forward process.
    x_t, noise = forward_diffusion(x_0, t, noise_schedule)
    # Predict the noise that was added.
    noise_pred = model(x_t, t)
    # MSE between predicted and true noise.
    loss = torch.nn.functional.mse_loss(noise_pred, noise)
    return loss
Video Diffusion Architecture
python
class VideoDiffusionModel(nn.Module):
    """3D U-Net for video diffusion.

    Denoises video tensors shaped (B, C, T, H, W), conditioned on the
    diffusion timestep t.
    """

    def __init__(self, in_channels=3, model_channels=256, num_res_blocks=2):
        super().__init__()
        # Expose the base width so wrappers (e.g. action conditioning)
        # can size their embeddings to time_embed's output (channels * 4).
        self.model_channels = model_channels
        # Timestep embedding: sinusoidal features -> 2-layer MLP,
        # output width model_channels * 4.
        self.time_embed = nn.Sequential(
            SinusoidalPosEmb(model_channels),
            nn.Linear(model_channels, model_channels * 4),
            nn.GELU(),
            nn.Linear(model_channels * 4, model_channels * 4),
        )
        # Encoder: res blocks plus spatio-temporal attention, with two
        # downsampling stages (base width -> 2x width).
        self.encoder = nn.ModuleList([
            ResBlock3D(in_channels, model_channels),
            ResBlock3D(model_channels, model_channels),
            Downsample3D(model_channels),
            ResBlock3D(model_channels, model_channels * 2),
            SpatioTemporalAttention(model_channels * 2),
            Downsample3D(model_channels * 2),
        ])
        # Middle (bottleneck). A ModuleList, not nn.Sequential: each layer
        # receives the extra t_emb argument, which nn.Sequential cannot
        # forward (the original Sequential call would raise at runtime).
        self.middle = nn.ModuleList([
            ResBlock3D(model_channels * 2, model_channels * 2),
            SpatioTemporalAttention(model_channels * 2),
            ResBlock3D(model_channels * 2, model_channels * 2),
        ])
        # Decoder (symmetric to encoder) — placeholder in this tutorial.
        self.decoder = nn.ModuleList([...])
        # Project back to the input channel count.
        self.out = nn.Conv3d(model_channels, in_channels, 3, padding=1)

    def forward(self, x, t, condition=None):
        """Predict noise for x at timesteps t.

        Args:
            x: noisy video, shape (B, C, T, H, W).
            t: timesteps, shape (B,).
            condition: optional conditioning input (unused here).
        """
        t_emb = self.time_embed(t)
        # Encoder, stashing activations for U-Net skip connections.
        # NOTE(review): a skip is stored after *every* layer, including
        # downsamples — the decoder must pop in matching order; verify
        # channel counts once the decoder placeholder is filled in.
        skips = []
        for layer in self.encoder:
            x = layer(x, t_emb)
            skips.append(x)
        # Bottleneck.
        for layer in self.middle:
            x = layer(x, t_emb)
        # Decoder: concatenate the matching skip along the channel dim.
        for layer in self.decoder:
            x = torch.cat([x, skips.pop()], dim=1)
            x = layer(x, t_emb)
        return self.out(x)
class VideoDiffusionModel(nn.Module):
    """3D U-Net for video diffusion.

    Denoises video tensors shaped (B, C, T, H, W), conditioned on the
    diffusion timestep t.
    """

    def __init__(self, in_channels=3, model_channels=256, num_res_blocks=2):
        super().__init__()
        # Expose the base width so wrappers (e.g. action conditioning)
        # can size their embeddings to time_embed's output (channels * 4).
        self.model_channels = model_channels
        # Timestep embedding: sinusoidal features -> 2-layer MLP,
        # output width model_channels * 4.
        self.time_embed = nn.Sequential(
            SinusoidalPosEmb(model_channels),
            nn.Linear(model_channels, model_channels * 4),
            nn.GELU(),
            nn.Linear(model_channels * 4, model_channels * 4),
        )
        # Encoder: res blocks plus spatio-temporal attention, with two
        # downsampling stages (base width -> 2x width).
        self.encoder = nn.ModuleList([
            ResBlock3D(in_channels, model_channels),
            ResBlock3D(model_channels, model_channels),
            Downsample3D(model_channels),
            ResBlock3D(model_channels, model_channels * 2),
            SpatioTemporalAttention(model_channels * 2),
            Downsample3D(model_channels * 2),
        ])
        # Middle (bottleneck). A ModuleList, not nn.Sequential: each layer
        # receives the extra t_emb argument, which nn.Sequential cannot
        # forward (the original Sequential call would raise at runtime).
        self.middle = nn.ModuleList([
            ResBlock3D(model_channels * 2, model_channels * 2),
            SpatioTemporalAttention(model_channels * 2),
            ResBlock3D(model_channels * 2, model_channels * 2),
        ])
        # Decoder (symmetric to encoder) — placeholder in this tutorial.
        self.decoder = nn.ModuleList([...])
        # Project back to the input channel count.
        self.out = nn.Conv3d(model_channels, in_channels, 3, padding=1)

    def forward(self, x, t, condition=None):
        """Predict noise for x at timesteps t.

        Args:
            x: noisy video, shape (B, C, T, H, W).
            t: timesteps, shape (B,).
            condition: optional conditioning input (unused here).
        """
        t_emb = self.time_embed(t)
        # Encoder, stashing activations for U-Net skip connections.
        # NOTE(review): a skip is stored after *every* layer, including
        # downsamples — the decoder must pop in matching order; verify
        # channel counts once the decoder placeholder is filled in.
        skips = []
        for layer in self.encoder:
            x = layer(x, t_emb)
            skips.append(x)
        # Bottleneck.
        for layer in self.middle:
            x = layer(x, t_emb)
        # Decoder: concatenate the matching skip along the channel dim.
        for layer in self.decoder:
            x = torch.cat([x, skips.pop()], dim=1)
            x = layer(x, t_emb)
        return self.out(x)
Conditioning Mechanisms
World models need to condition generation on various inputs:
1. Text Conditioning (Cross-Attention)
python
class CrossAttention(nn.Module):
    """Single-head cross-attention: queries from x, keys/values from context."""

    def __init__(self, dim, context_dim):
        super().__init__()
        # Projections declared in the same order as registration matters
        # for reproducible parameter initialization.
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(context_dim, dim)
        self.to_v = nn.Linear(context_dim, dim)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x, context):
        """Attend from x (B, N, dim) over context (B, M, context_dim)."""
        queries = self.to_q(x)
        keys = self.to_k(context)
        values = self.to_v(context)
        # Scaled dot-product attention over the context positions.
        scale = math.sqrt(queries.shape[-1])
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / scale
        weights = scores.softmax(dim=-1)
        attended = torch.matmul(weights, values)
        return self.to_out(attended)
class CrossAttention(nn.Module):
    """Single-head cross-attention: queries from x, keys/values from context."""

    def __init__(self, dim, context_dim):
        super().__init__()
        # Projections declared in the same order as registration matters
        # for reproducible parameter initialization.
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(context_dim, dim)
        self.to_v = nn.Linear(context_dim, dim)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x, context):
        """Attend from x (B, N, dim) over context (B, M, context_dim)."""
        queries = self.to_q(x)
        keys = self.to_k(context)
        values = self.to_v(context)
        # Scaled dot-product attention over the context positions.
        scale = math.sqrt(queries.shape[-1])
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / scale
        weights = scores.softmax(dim=-1)
        attended = torch.matmul(weights, values)
        return self.to_out(attended)
2. Action Conditioning
python
class ActionConditionedDiffusion(nn.Module):
    """Wraps a diffusion model so generation is conditioned on actions.

    The action vector is encoded to the width of the base model's time
    embedding (model_channels * 4) and summed into it.
    """

    def __init__(self, base_model, action_dim):
        super().__init__()
        self.base_model = base_model
        # MLP: action vector -> time-embedding width.
        # NOTE(review): requires base_model to expose a `model_channels`
        # attribute — confirm the base model sets it.
        self.action_encoder = nn.Sequential(
            nn.Linear(action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, base_model.model_channels * 4)
        )

    def forward(self, x, t, action):
        # Embed the timestep, then fold the action into the embedding.
        t_emb = self.base_model.time_embed(t)
        action_emb = self.action_encoder(action)
        combined_emb = t_emb + action_emb
        # NOTE(review): this passes the combined *embedding* where the base
        # model's forward expects raw timesteps `t` (it calls time_embed
        # internally), so the base model would embed an embedding. The base
        # model likely needs a path accepting a precomputed embedding —
        # confirm the intended API.
        return self.base_model(x, combined_emb)
class ActionConditionedDiffusion(nn.Module):
    """Wraps a diffusion model so generation is conditioned on actions.

    The action vector is encoded to the width of the base model's time
    embedding (model_channels * 4) and summed into it.
    """

    def __init__(self, base_model, action_dim):
        super().__init__()
        self.base_model = base_model
        # MLP: action vector -> time-embedding width.
        # NOTE(review): requires base_model to expose a `model_channels`
        # attribute — confirm the base model sets it.
        self.action_encoder = nn.Sequential(
            nn.Linear(action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, base_model.model_channels * 4)
        )

    def forward(self, x, t, action):
        # Embed the timestep, then fold the action into the embedding.
        t_emb = self.base_model.time_embed(t)
        action_emb = self.action_encoder(action)
        combined_emb = t_emb + action_emb
        # NOTE(review): this passes the combined *embedding* where the base
        # model's forward expects raw timesteps `t` (it calls time_embed
        # internally), so the base model would embed an embedding. The base
        # model likely needs a path accepting a precomputed embedding —
        # confirm the intended API.
        return self.base_model(x, combined_emb)
Sampling Strategies
DDPM Sampling (Slow but High Quality)
python
@torch.no_grad()
def ddpm_sample(model, shape, noise_schedule, num_steps=1000, device=None):
    """Ancestral (DDPM) sampling: slow but high quality.

    Starts from pure Gaussian noise and applies the learned denoiser for
    every timestep from num_steps - 1 down to 0.

    Args:
        model: callable (x_t, t) -> predicted noise.
        shape: output tensor shape, (B, ...).
        noise_schedule: schedule object consumed by ``denoise_step``.
        num_steps: number of reverse-process steps.
        device: optional torch device for sampling; defaults to None
            (CPU), matching the original behavior.

    Returns:
        The final denoised sample.
    """
    x = torch.randn(shape, device=device)
    for t in reversed(range(num_steps)):
        # One timestep index per batch element, pinned to the sampling
        # device and to long dtype as timestep embeddings expect indices.
        t_batch = torch.full((shape[0],), t, dtype=torch.long, device=device)
        noise_pred = model(x, t_batch)
        x = denoise_step(x, noise_pred, t, noise_schedule)
    return x
@torch.no_grad()
def ddpm_sample(model, shape, noise_schedule, num_steps=1000, device=None):
    """Ancestral (DDPM) sampling: slow but high quality.

    Starts from pure Gaussian noise and applies the learned denoiser for
    every timestep from num_steps - 1 down to 0.

    Args:
        model: callable (x_t, t) -> predicted noise.
        shape: output tensor shape, (B, ...).
        noise_schedule: schedule object consumed by ``denoise_step``.
        num_steps: number of reverse-process steps.
        device: optional torch device for sampling; defaults to None
            (CPU), matching the original behavior.

    Returns:
        The final denoised sample.
    """
    x = torch.randn(shape, device=device)
    for t in reversed(range(num_steps)):
        # One timestep index per batch element, pinned to the sampling
        # device and to long dtype as timestep embeddings expect indices.
        t_batch = torch.full((shape[0],), t, dtype=torch.long, device=device)
        noise_pred = model(x, t_batch)
        x = denoise_step(x, noise_pred, t, noise_schedule)
    return x
DDIM Sampling (Fast)
python
@torch.no_grad()
def ddim_sample(model, shape, noise_schedule, num_steps=50):
    """Deterministic DDIM sampling: few steps, fast.

    Walks a strided subset of the schedule's timesteps from the highest
    index down to 0, applying the DDIM update between consecutive pairs.

    Args:
        model: callable (x_t, t) -> predicted noise.
        shape: output tensor shape, (B, ...).
        noise_schedule: schedule object exposing ``num_steps``.
        num_steps: number of DDIM steps to take (typically 25-100).

    Returns:
        The generated sample.
    """
    x = torch.randn(shape)
    # Derive the starting timestep from the schedule rather than the
    # hard-coded 999 of the original, which silently assumed a
    # 1000-step schedule.
    t_max = noise_schedule.num_steps - 1
    timesteps = torch.linspace(t_max, 0, num_steps).long()
    for i, t in enumerate(timesteps[:-1]):
        t_next = timesteps[i + 1]
        noise_pred = model(x, t.expand(shape[0]))
        x = ddim_step(x, noise_pred, t, t_next, noise_schedule)
    return x
@torch.no_grad()
def ddim_sample(model, shape, noise_schedule, num_steps=50):
    """Deterministic DDIM sampling: few steps, fast.

    Walks a strided subset of the schedule's timesteps from the highest
    index down to 0, applying the DDIM update between consecutive pairs.

    Args:
        model: callable (x_t, t) -> predicted noise.
        shape: output tensor shape, (B, ...).
        noise_schedule: schedule object exposing ``num_steps``.
        num_steps: number of DDIM steps to take (typically 25-100).

    Returns:
        The generated sample.
    """
    x = torch.randn(shape)
    # Derive the starting timestep from the schedule rather than the
    # hard-coded 999 of the original, which silently assumed a
    # 1000-step schedule.
    t_max = noise_schedule.num_steps - 1
    timesteps = torch.linspace(t_max, 0, num_steps).long()
    for i, t in enumerate(timesteps[:-1]):
        t_next = timesteps[i + 1]
        noise_pred = model(x, t.expand(shape[0]))
        x = ddim_step(x, noise_pred, t, t_next, noise_schedule)
    return x
Summary
Diffusion models provide a powerful framework for video generation in world models. By learning to denoise data, they can generate high-quality, temporally coherent video conditioned on text, actions, or other inputs.
Knowledge Check
Test your understanding with 1 question