Transformer Architectures for World Models
25 min
Transformer Architectures for World Models
Transformers have become the dominant architecture for world models, enabling them to process long sequences of visual data and capture complex temporal dependencies.
Why Transformers for World Models?
Transformers excel at world modeling because they can:
- Handle Long Sequences: Attention mechanisms capture long-range dependencies
- Process Multimodal Data: Unified architecture for text, images, video
- Scale Efficiently: Parallelizable training on large datasets
- Enable Conditioning: Easy to condition on actions, text, or other inputs
Video Transformer Architecture
Input Video: [Frame1, Frame2, ..., FrameN]
↓
┌───────────────┐
│ Patchify & │
│ Tokenize │
└───────────────┘
↓
[Token1, Token2, ..., TokenM]
↓
┌───────────────┐
│ Positional │
│ Encoding │
└───────────────┘
↓
┌───────────────┐
│ Transformer │
│ Blocks │
│ (L layers) │
└───────────────┘
↓
┌───────────────┐
│ Prediction │
│ Head │
└───────────────┘
↓
Next Frame Tokens
Input Video: [Frame1, Frame2, ..., FrameN]
↓
┌───────────────┐
│ Patchify & │
│ Tokenize │
└───────────────┘
↓
[Token1, Token2, ..., TokenM]
↓
┌───────────────┐
│ Positional │
│ Encoding │
└───────────────┘
↓
┌───────────────┐
│ Transformer │
│ Blocks │
│ (L layers) │
└───────────────┘
↓
┌───────────────┐
│ Prediction │
│ Head │
└───────────────┘
↓
Next Frame Tokens
Key Components
1. Spatial-Temporal Attention
python
import torch
import torch.nn as nn
class SpatioTemporalAttention(nn.Module):
    """Factorized spatio-temporal attention for video transformers.

    Runs self-attention over the spatial axis (tokens within each frame)
    and then over the temporal axis (same spatial location across frames),
    which is far cheaper than joint attention over all T*H*W tokens.
    """

    def __init__(self, dim, num_heads=8):
        super().__init__()
        # batch_first=True so inputs are (batch, seq, dim), matching the
        # (B*T, H*W, C) and (B*H*W, T, C) reshapes in forward(). Without it,
        # nn.MultiheadAttention treats dim 0 as the sequence axis and would
        # attend across the batch — a silent correctness bug.
        self.spatial_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.temporal_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x, T, H, W):
        """Apply spatial then temporal self-attention.

        Args:
            x: (B, N, C) token sequence with N == T * H * W, laid out as
               time-major (all tokens of frame 0, then frame 1, ...).
            T, H, W: temporal length and spatial grid of the token volume.

        Returns:
            (B, N, C) tensor with the same layout as the input.
        """
        B, N, C = x.shape  # N = T * H * W
        # Spatial attention: fold time into the batch -> (B*T, H*W, C).
        x_spatial = x.view(B * T, H * W, C)
        x_spatial = self.norm1(
            x_spatial + self.spatial_attn(x_spatial, x_spatial, x_spatial)[0]
        )
        # Temporal attention: fold space into the batch -> (B*H*W, T, C).
        x_temporal = (
            x_spatial.view(B, T, H * W, C)
            .permute(0, 2, 1, 3)
            .reshape(B * H * W, T, C)
        )
        x_temporal = self.norm2(
            x_temporal + self.temporal_attn(x_temporal, x_temporal, x_temporal)[0]
        )
        # Restore the original (B, T*H*W, C) layout.
        out = (
            x_temporal.view(B, H * W, T, C)
            .permute(0, 2, 1, 3)
            .reshape(B, N, C)
        )
        return out
import torch
import torch.nn as nn
class SpatioTemporalAttention(nn.Module):
    """Factorized attention for video transformers.

    Decomposes full attention over a T*H*W token volume into a spatial
    pass (within each frame) followed by a temporal pass (across frames).
    """

    def __init__(self, dim, num_heads=8):
        super().__init__()
        # FIX: batch_first=True. The forward() reshapes assume a
        # (batch, sequence, channels) layout; the nn.MultiheadAttention
        # default is sequence-first, which would attend across the batch.
        self.spatial_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.temporal_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x, T, H, W):
        """Run spatial then temporal attention on x of shape (B, T*H*W, C).

        Tokens are assumed time-major: frame 0's H*W tokens first, then
        frame 1's, and so on. Returns a tensor with the same layout.
        """
        B, N, C = x.shape  # N = T * H * W
        # Spatial pass: each of the B*T frames attends over its H*W tokens.
        x_spatial = x.view(B * T, H * W, C)
        attn_s = self.spatial_attn(x_spatial, x_spatial, x_spatial)[0]
        x_spatial = self.norm1(x_spatial + attn_s)
        # Temporal pass: each of the B*H*W spatial sites attends over T steps.
        x_temporal = (
            x_spatial.view(B, T, H * W, C)
            .permute(0, 2, 1, 3)
            .reshape(B * H * W, T, C)
        )
        attn_t = self.temporal_attn(x_temporal, x_temporal, x_temporal)[0]
        x_temporal = self.norm2(x_temporal + attn_t)
        # Undo the permutation to recover (B, T*H*W, C).
        return (
            x_temporal.view(B, H * W, T, C)
            .permute(0, 2, 1, 3)
            .reshape(B, N, C)
        )
2. 3D Positional Encoding
python
class PositionalEncoding3D(nn.Module):
    """Learnable factorized 3D positional encoding for video tokens.

    Keeps three separate embedding tables — time, height, width — whose
    broadcast sum gives a code per (t, h, w) location. This uses
    O(T + H + W) parameters instead of O(T*H*W) for a full 3D table.
    """

    def __init__(self, dim, max_t=32, max_h=16, max_w=16):
        super().__init__()
        self.temporal_embed = nn.Embedding(max_t, dim)
        self.height_embed = nn.Embedding(max_h, dim)
        self.width_embed = nn.Embedding(max_w, dim)
        # Remember the capacities so forward() can raise a clear error
        # instead of an opaque embedding IndexError when exceeded.
        self.max_t, self.max_h, self.max_w = max_t, max_h, max_w

    def forward(self, x, T, H, W):
        """Add positional codes to x.

        Args:
            x: (B, N, C) token sequence with N == T * H * W, time-major.
            T, H, W: grid sizes; each must not exceed its configured maximum.

        Returns:
            (B, N, C) tensor: x plus the broadcast positional encoding.

        Raises:
            ValueError: if the grid exceeds the table sizes or N != T*H*W.
        """
        B, N, C = x.shape
        if T > self.max_t or H > self.max_h or W > self.max_w:
            raise ValueError(
                f"grid ({T}, {H}, {W}) exceeds configured maxima "
                f"({self.max_t}, {self.max_h}, {self.max_w})"
            )
        if N != T * H * W:
            raise ValueError(f"expected N == T*H*W == {T * H * W}, got {N}")
        t_pos = torch.arange(T, device=x.device)
        h_pos = torch.arange(H, device=x.device)
        w_pos = torch.arange(W, device=x.device)
        # Broadcast the three 1D tables to a (T, H, W, C) grid.
        t_embed = self.temporal_embed(t_pos)[:, None, None, :]  # (T, 1, 1, C)
        h_embed = self.height_embed(h_pos)[None, :, None, :]    # (1, H, 1, C)
        w_embed = self.width_embed(w_pos)[None, None, :, :]     # (1, 1, W, C)
        pos_embed = (t_embed + h_embed + w_embed).view(1, T * H * W, C)
        return x + pos_embed
class PositionalEncoding3D(nn.Module):
    """Learnable 3D positional encoding for video.

    The code for position (t, h, w) is the sum of three learned per-axis
    embeddings, broadcast over the full (T, H, W) grid and flattened to
    match a time-major token sequence.
    """

    def __init__(self, dim, max_t=32, max_h=16, max_w=16):
        super().__init__()
        self.temporal_embed = nn.Embedding(max_t, dim)
        self.height_embed = nn.Embedding(max_h, dim)
        self.width_embed = nn.Embedding(max_w, dim)

    def forward(self, x, T, H, W):
        """Return x (shape (B, T*H*W, C)) plus the positional grid."""
        batch, n_tokens, channels = x.shape
        device = x.device
        # One embedding row per index along each axis.
        emb_t = self.temporal_embed(torch.arange(T, device=device))
        emb_h = self.height_embed(torch.arange(H, device=device))
        emb_w = self.width_embed(torch.arange(W, device=device))
        # Broadcast-sum into a (T, H, W, C) grid.
        grid = (
            emb_t.reshape(T, 1, 1, channels)
            + emb_h.reshape(1, H, 1, channels)
            + emb_w.reshape(1, 1, W, channels)
        )
        # Flatten to (1, T*H*W, C) so it broadcasts over the batch.
        return x + grid.reshape(1, T * H * W, channels)
Notable Architectures
Video Vision Transformer (ViViT)
| Variant | Description | Complexity |
|---|---|---|
| Factorized Encoder | Separate spatial and temporal transformers | O(T·H·W) |
| Factorized Self-Attention | Joint encoder with factorized attention | O(T·H·W) |
| Factorized Dot-Product | Single attention with factorized computation | O(T + H·W) |
Sora Architecture (Diffusion Transformer)
Sora uses a Diffusion Transformer (DiT) architecture:
python
class DiTBlock(nn.Module):
    """Diffusion Transformer (DiT) block with adaptive layer norm (adaLN).

    The conditioning vector c (timestep / class / text embedding) regresses
    per-sample shift, scale, and gate parameters that modulate the two
    residual branches (self-attention and MLP), as in the DiT architecture.
    """

    def __init__(self, dim, num_heads):
        super().__init__()
        # elementwise_affine=False: the affine transform comes from the
        # conditioning, not from fixed learned LayerNorm parameters.
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        # FIX: batch_first=True so x is (B, N, C); the default layout is
        # sequence-first and would attend across the batch dimension.
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )
        # Regresses 6 modulation vectors: shift/scale/gate for each branch.
        self.adaLN_modulation = nn.Linear(dim, 6 * dim)

    def forward(self, x, c):
        """Apply the block.

        Args:
            x: (B, N, C) token sequence.
            c: (B, C) conditioning embedding (timestep, text, etc.).

        Returns:
            (B, N, C) updated tokens.
        """
        # Chunk into six (B, C) vectors, then unsqueeze to (B, 1, C) so
        # each broadcasts over the token axis of the (B, N, C) activations.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            m.unsqueeze(1) for m in self.adaLN_modulation(c).chunk(6, dim=-1)
        )
        # Modulated self-attention branch.
        x_norm = self.norm1(x) * (1 + scale_msa) + shift_msa
        x = x + gate_msa * self.attn(x_norm, x_norm, x_norm)[0]
        # Modulated MLP branch.
        x_norm = self.norm2(x) * (1 + scale_mlp) + shift_mlp
        x = x + gate_mlp * self.mlp(x_norm)
        return x
class DiTBlock(nn.Module):
    """Diffusion Transformer block with adaptive layer norm conditioning.

    Each forward pass regresses shift/scale/gate modulation parameters
    from a conditioning vector and applies them around the attention and
    MLP residual branches (the adaLN scheme from the DiT paper).
    """

    def __init__(self, dim, num_heads):
        super().__init__()
        # No learned affine on the norms — modulation supplies it instead.
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        # FIX: batch_first=True to match the (B, N, C) token layout; the
        # nn.MultiheadAttention default expects (N, B, C).
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )
        # Adaptive layer norm parameters: 3 per branch, 2 branches.
        self.adaLN_modulation = nn.Linear(dim, 6 * dim)

    def forward(self, x, c):
        """Run the block on tokens x (B, N, C) with conditioning c (B, C)."""
        # Six (B, C) modulation vectors, reshaped to (B, 1, C) so they
        # broadcast across all N tokens of each sample.
        mods = self.adaLN_modulation(c).chunk(6, dim=-1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            m.unsqueeze(1) for m in mods
        )
        # Attention branch, gated and residual.
        x_norm = self.norm1(x) * (1 + scale_msa) + shift_msa
        x = x + gate_msa * self.attn(x_norm, x_norm, x_norm)[0]
        # MLP branch, gated and residual.
        x_norm = self.norm2(x) * (1 + scale_mlp) + shift_mlp
        x = x + gate_mlp * self.mlp(x_norm)
        return x
Scaling Considerations
| Model Size | Parameters | Training Compute | Typical Use |
|---|---|---|---|
| Small | 100M-500M | 100 GPU-days | Research, prototyping |
| Medium | 1B-5B | 1000 GPU-days | Production applications |
| Large | 10B+ | 10000+ GPU-days | Foundation models |
Summary
Transformer architectures enable world models to process video data effectively through spatial-temporal attention mechanisms. The choice of architecture depends on the trade-off between computational cost and modeling capability.
Knowledge Check
Test your understanding with 1 question