Transformer Architectures for World Models

Transformers have become the dominant architecture for world models, enabling them to process long sequences of visual data and capture complex temporal dependencies.

Why Transformers for World Models?

Transformers excel at world modeling because they can:

  1. Handle Long Sequences: Attention mechanisms capture long-range dependencies
  2. Process Multimodal Data: Unified architecture for text, images, video
  3. Scale Efficiently: Parallelizable training on large datasets
  4. Enable Conditioning: Easy to condition on actions, text, or other inputs
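Point 4 can be sketched in a few lines: one common pattern is to project each action (or other conditioning signal) into the model dimension and prepend it to the token sequence, so attention can route information from it to every frame token. The names and dimensions below are illustrative, not from any specific model.

```python
import torch
import torch.nn as nn

dim = 64
action_proj = nn.Linear(8, dim)            # hypothetical 8-dim action vector -> one token
frame_tokens = torch.randn(2, 16, dim)     # (batch, tokens, dim)
actions = torch.randn(2, 1, 8)             # one action per clip

# Prepend the projected action as an extra token the model can attend to
seq = torch.cat([action_proj(actions), frame_tokens], dim=1)
print(seq.shape)  # torch.Size([2, 17, 64])
```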

Video Transformer Architecture

Input Video: [Frame1, Frame2, ..., FrameN]

            ┌───────────────┐
            │  Patchify &   │
            │   Tokenize    │
            └───────────────┘

            [Token1, Token2, ..., TokenM]

            ┌───────────────┐
            │   Positional  │
            │   Encoding    │
            └───────────────┘

            ┌───────────────┐
            │  Transformer  │
            │    Blocks     │
            │   (L layers)  │
            └───────────────┘

            ┌───────────────┐
            │   Prediction  │
            │     Head      │
            └───────────────┘

            Next Frame Tokens
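The "Patchify & Tokenize" step above can be sketched as a pure reshape: split each frame into non-overlapping patches and flatten every patch into a token vector. This is a minimal sketch; production tokenizers typically use a learned projection (e.g. a strided convolution) or a pretrained VAE instead of raw pixels.

```python
import torch

def patchify(video, patch_size=8):
    """Split (B, T, C, H, W) video into flattened patch tokens."""
    B, T, C, H, W = video.shape
    p = patch_size
    # Carve each frame into a (H/p, W/p) grid of p x p patches
    x = video.reshape(B, T, C, H // p, p, W // p, p)
    # Group each patch's pixels into one vector: (B, T, H/p, W/p, C*p*p)
    x = x.permute(0, 1, 3, 5, 2, 4, 6).reshape(B, T, H // p, W // p, C * p * p)
    # Flatten space-time into a single token axis: (B, T * H/p * W/p, C*p*p)
    return x.reshape(B, -1, C * p * p)

video = torch.randn(2, 4, 3, 32, 32)   # 2 clips of 4 RGB frames at 32x32
tokens = patchify(video)
print(tokens.shape)  # torch.Size([2, 64, 192])
```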

Key Components

1. Spatial-Temporal Attention

python
import torch
import torch.nn as nn

class SpatioTemporalAttention(nn.Module):
    """Factorized attention for video transformers"""
    
    def __init__(self, dim, num_heads=8):
        super().__init__()
        # batch_first=True so inputs are (batch, sequence, dim); the default
        # (False) would silently attend across the batch dimension here
        self.spatial_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.temporal_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
    
    def forward(self, x, T, H, W):
        B, N, C = x.shape  # N = T * H * W
        
        # Reshape for spatial attention: (B*T, H*W, C)
        x_spatial = x.view(B * T, H * W, C)
        x_spatial = self.norm1(x_spatial + self.spatial_attn(x_spatial, x_spatial, x_spatial)[0])
        
        # Reshape for temporal attention: (B*H*W, T, C)
        x_temporal = x_spatial.view(B, T, H * W, C).permute(0, 2, 1, 3).reshape(B * H * W, T, C)
        x_temporal = self.norm2(x_temporal + self.temporal_attn(x_temporal, x_temporal, x_temporal)[0])
        
        # Reshape back: (B, T*H*W, C)
        out = x_temporal.view(B, H * W, T, C).permute(0, 2, 1, 3).reshape(B, N, C)
        return out
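To see why this factorization pays off, count attended token pairs. Joint space-time attention scales with the square of all T·H·W tokens, while attending spatially within each frame and then temporally at each grid position is far cheaper (a back-of-envelope count; constant factors omitted):

```python
# Attended token pairs for a 16-frame clip tokenized into a 16x16 grid
T, H, W = 16, 16, 16
N = T * H * W                                    # 4096 tokens
joint = N * N                                    # full space-time attention
factored = T * (H * W) ** 2 + (H * W) * T ** 2   # spatial pass + temporal pass
print(joint, factored)  # 16777216 1114112  (~15x fewer pairs)
```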

2. 3D Positional Encoding

python
class PositionalEncoding3D(nn.Module):
    """Learnable 3D positional encoding for video"""
    
    def __init__(self, dim, max_t=32, max_h=16, max_w=16):
        super().__init__()
        self.temporal_embed = nn.Embedding(max_t, dim)
        self.height_embed = nn.Embedding(max_h, dim)
        self.width_embed = nn.Embedding(max_w, dim)
    
    def forward(self, x, T, H, W):
        B, N, C = x.shape
        
        t_pos = torch.arange(T, device=x.device)
        h_pos = torch.arange(H, device=x.device)
        w_pos = torch.arange(W, device=x.device)
        
        # Create position grid
        t_embed = self.temporal_embed(t_pos)[:, None, None, :]  # (T, 1, 1, C)
        h_embed = self.height_embed(h_pos)[None, :, None, :]    # (1, H, 1, C)
        w_embed = self.width_embed(w_pos)[None, None, :, :]     # (1, 1, W, C)
        
        pos_embed = (t_embed + h_embed + w_embed).view(1, T * H * W, C)
        return x + pos_embed
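The additive factorization above is also what keeps the embedding tables small: three 1-D tables need only max_t + max_h + max_w rows, instead of one row per (t, h, w) combination.

```python
# Table rows needed: factorized (additive) vs. one joint 3-D table,
# using the default sizes from PositionalEncoding3D above
max_t, max_h, max_w = 32, 16, 16
factorized_rows = max_t + max_h + max_w   # three small 1-D tables
joint_rows = max_t * max_h * max_w        # one row per position triple
print(factorized_rows, joint_rows)  # 64 8192
```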

Notable Architectures

Video Vision Transformer (ViViT)

| Variant | Description | Complexity |
|---|---|---|
| Factorized Encoder | Separate spatial and temporal transformers | O(T·H·W) |
| Factorized Self-Attention | Joint encoder with factorized attention | O(T·H·W) |
| Factorized Dot-Product | Single attention with factorized computation | O(T + H·W) |
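The factorized-encoder variant can be sketched with stock PyTorch modules: a spatial transformer summarizes each frame's patches into a single representation, then a temporal transformer models the sequence of frame representations. Layer counts and sizes here are illustrative, not the paper's configuration.

```python
import torch
import torch.nn as nn

dim, heads = 64, 4
spatial = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(dim, heads, batch_first=True), num_layers=2)
temporal = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(dim, heads, batch_first=True), num_layers=2)

B, T, P = 2, 8, 16                      # batch, frames, patches per frame
tokens = torch.randn(B * T, P, dim)     # per-frame patch tokens
frame_repr = spatial(tokens).mean(dim=1)        # pool each frame to one vector
clip = temporal(frame_repr.view(B, T, dim))     # model dynamics across frames
print(clip.shape)  # torch.Size([2, 8, 64])
```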

Sora Architecture (Diffusion Transformer)

Sora uses a Diffusion Transformer (DiT) architecture:

python
class DiTBlock(nn.Module):
    """Diffusion Transformer block with adaptive layer norm"""
    
    def __init__(self, dim, num_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )
        # Adaptive layer norm parameters (the DiT paper uses SiLU + a
        # zero-initialized Linear, "adaLN-Zero"; simplified here)
        self.adaLN_modulation = nn.Linear(dim, 6 * dim)
    
    def forward(self, x, c):
        # x: (B, N, C) tokens; c: (B, C) conditioning (timestep, text embedding, etc.)
        # unsqueeze so each modulation term is (B, 1, C) and broadcasts over tokens
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
            self.adaLN_modulation(c).unsqueeze(1).chunk(6, dim=-1)
        
        # Modulated attention
        x_norm = self.norm1(x) * (1 + scale_msa) + shift_msa
        x = x + gate_msa * self.attn(x_norm, x_norm, x_norm)[0]
        
        # Modulated MLP
        x_norm = self.norm2(x) * (1 + scale_mlp) + shift_mlp
        x = x + gate_mlp * self.mlp(x_norm)
        
        return x
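The core adaLN mechanism can be isolated in a few lines: instead of learned affine LayerNorm parameters, the conditioning vector produces a per-sample shift, scale, and gate. A minimal standalone sketch (shapes and names illustrative):

```python
import torch
import torch.nn as nn

B, N, C = 2, 16, 32
x = torch.randn(B, N, C)           # token sequence
c = torch.randn(B, C)              # conditioning, e.g. a timestep embedding

modulation = nn.Linear(C, 3 * C)   # plays the role of adaLN_modulation
shift, scale, gate = modulation(c).unsqueeze(1).chunk(3, dim=-1)  # (B, 1, C) each

norm = nn.LayerNorm(C, elementwise_affine=False)  # no learned affine params
branch = norm(x) * (1 + scale) + shift            # conditioning sets the affine
x = x + gate * branch                             # gated residual update
print(x.shape)  # torch.Size([2, 16, 32])
```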

Scaling Considerations

| Model Size | Parameters | Training Compute | Typical Use |
|---|---|---|---|
| Small | 100M-500M | 100 GPU-days | Research, prototyping |
| Medium | 1B-5B | 1,000 GPU-days | Production applications |
| Large | 10B+ | 10,000+ GPU-days | Foundation models |
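A quick sanity check on the parameter column: each transformer block holds roughly 12·d² weights (about 4·d² for attention projections plus 8·d² for an MLP with 4x expansion), so a model with L layers and width d has on the order of 12·L·d² parameters, ignoring embeddings and heads.

```python
def approx_params(num_layers, dim):
    """Rough transformer parameter count: ~12 * L * d^2 (embeddings excluded)."""
    return 12 * num_layers * dim * dim

print(f"{approx_params(24, 1024) / 1e6:.0f}M")  # 302M -> "Small" range
print(f"{approx_params(48, 4096) / 1e9:.1f}B")  # 9.7B -> near "Large" range
```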

Summary

Transformer architectures enable world models to process video data effectively through spatial-temporal attention mechanisms. The choice of architecture depends on the trade-off between computational cost and modeling capability.
