Video Generation with World Models
Video Generation with World Models
World models have revolutionized video generation, enabling AI to create realistic, physically plausible videos from text descriptions or images. This lesson explores how world models power video generation.
The Video Generation Pipeline
Text/Image Input → World Model → Video Frames → Post-Processing → Output Video
For example: the prompt "A robot walking" is fed to the world model, which simulates world dynamics; frames are generated one by one; post-processing then upscales the frames and smooths temporal artifacts.
Text/Image Input → World Model → Video Frames → Post-Processing → Output Video
For example: the prompt "A robot walking" is fed to the world model, which simulates world dynamics; frames are generated one by one; post-processing then upscales the frames and smooths temporal artifacts.
Key Models
OpenAI Sora
Sora demonstrates world simulation through video generation:
"Sora is a diffusion model... trained on videos and images of variable durations, resolutions and aspect ratios... We find that video models exhibit a number of interesting emergent capabilities when trained at scale."
Key capabilities:
- 3D consistency
- Long-range coherence
- Object permanence
- Physical interactions
Google Veo
Veo generates high-quality video with:
- 1080p resolution
- 60+ second duration
- Cinematic effects
- Style control
Architecture Deep Dive
Latent Diffusion for Video
class VideoLatentDiffusion(nn.Module):
    """Text-conditioned latent diffusion model for video generation.

    A video autoencoder compresses clips into a latent space; a 3D U-Net
    runs the denoising diffusion process in that space, conditioned on a
    T5 text embedding.
    """

    def __init__(self):
        super().__init__()
        # Autoencoder pair: raw video <-> compact latent representation.
        self.encoder = VideoEncoder()
        self.decoder = VideoDecoder()
        # Noise-prediction network operating in latent space.
        self.unet = VideoUNet3D()
        # Text encoder producing the conditioning embedding.
        self.text_encoder = T5Encoder()

    def encode(self, video):
        """Compress a video tensor into its latent representation."""
        return self.encoder(video)

    def decode(self, latent):
        """Reconstruct a video tensor from a latent representation."""
        return self.decoder(latent)

    def forward(self, video, text, timestep):
        """Predict the diffusion noise for ``video`` at ``timestep``.

        NOTE(review): this encodes its first argument before denoising, so
        callers must pass raw video here, not an already-encoded latent.
        """
        latent = self.encode(video)
        conditioning = self.text_encoder(text)
        return self.unet(latent, timestep, conditioning)
class VideoLatentDiffusion(nn.Module):
    """Latent diffusion model for text-to-video generation.

    Videos are mapped to a latent space by an autoencoder, and the
    diffusion (noise-prediction) step runs there via a 3D U-Net that is
    conditioned on a T5 embedding of the prompt.
    """

    def __init__(self):
        super().__init__()
        # Video autoencoder: encoder compresses, decoder reconstructs.
        self.encoder = VideoEncoder()
        self.decoder = VideoDecoder()
        # Latent-space denoiser.
        self.unet = VideoUNet3D()
        # Prompt encoder for conditioning.
        self.text_encoder = T5Encoder()

    def encode(self, video):
        """Return the latent representation of ``video``."""
        return self.encoder(video)

    def decode(self, latent):
        """Return the video reconstructed from ``latent``."""
        return self.decoder(latent)

    def forward(self, video, text, timestep):
        """Predict the noise added to ``video`` at diffusion ``timestep``.

        NOTE(review): the input is encoded inside this method, so it must
        be a raw video, not a pre-computed latent.
        """
        z = self.encode(video)
        text_emb = self.text_encoder(text)
        noise_pred = self.unet(z, timestep, text_emb)
        return noise_pred
Autoregressive Video Generation
class AutoregressiveVideoModel(nn.Module):
    """Frame-by-frame autoregressive video generator.

    Each new frame latent is predicted by a transformer from all latents
    produced so far, conditioned on the text prompt, then decoded to pixels.
    """

    def __init__(self):
        super().__init__()
        self.frame_encoder = FrameEncoder()
        self.temporal_model = TransformerDecoder()
        self.frame_decoder = FrameDecoder()

    def generate(self, first_frame, text, num_frames=16):
        """Roll out ``num_frames`` frames starting from ``first_frame``.

        Returns the generated frames stacked along dim 1 (the time axis).
        """
        frames = [first_frame]
        # Running latent context, seeded with the encoded first frame.
        latents = [self.frame_encoder(first_frame)]
        for _ in range(num_frames - 1):
            # Predict the next latent from the full history plus the prompt.
            next_latent = self.temporal_model(latents, text_conditioning=text)
            frames.append(self.frame_decoder(next_latent))
            latents.append(next_latent)
        return torch.stack(frames, dim=1)
class AutoregressiveVideoModel(nn.Module):
    """Generate video one frame at a time, conditioned on text.

    A transformer decoder predicts each next frame latent from the latents
    generated so far; a frame decoder maps latents back to pixel frames.
    """

    def __init__(self):
        super().__init__()
        self.frame_encoder = FrameEncoder()
        self.temporal_model = TransformerDecoder()
        self.frame_decoder = FrameDecoder()

    def generate(self, first_frame, text, num_frames=16):
        """Autoregressively produce ``num_frames`` frames.

        The result is a tensor with frames stacked on dim 1.
        """
        frames = [first_frame]
        z = self.frame_encoder(first_frame)
        context = [z]
        step = 0
        while step < num_frames - 1:
            # Next-latent prediction from the accumulated context.
            z_next = self.temporal_model(context, text_conditioning=text)
            # Decode to a pixel frame and extend both histories.
            frames.append(self.frame_decoder(z_next))
            context.append(z_next)
            step += 1
        return torch.stack(frames, dim=1)
Training Video Generation Models
Data Preparation
class VideoDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing video clips with text captions.

    Args:
        video_paths: sequence of paths to video files.
        caption_file: caption source; assumed index-aligned with
            ``video_paths`` (TODO confirm alignment against the loader).
        num_frames: temporal clip length sampled from each video.
    """

    def __init__(self, video_paths, caption_file, num_frames=16):
        self.video_paths = video_paths
        self.captions = load_captions(caption_file)
        self.num_frames = num_frames

    def __len__(self):
        # Fix: a map-style Dataset must define __len__ — the default
        # sampler and len(dataloader) rely on it.
        return len(self.video_paths)

    def __getitem__(self, idx):
        video = load_video(self.video_paths[idx])
        # Random temporal crop of num_frames frames (videos shorter than
        # num_frames pass through unchanged).
        if len(video) > self.num_frames:
            start = random.randint(0, len(video) - self.num_frames)
            video = video[start:start + self.num_frames]
        caption = self.captions[idx]
        video = self.augment(video)
        return {
            "video": video,
            "caption": caption
        }

    def augment(self, video):
        """Apply light augmentation (flip + color jitter) to a clip."""
        # Random horizontal flip; assumes the last dim is width — TODO confirm.
        if random.random() > 0.5:
            video = torch.flip(video, dims=[-1])
        video = color_jitter(video)
        return video
class VideoDataset(torch.utils.data.Dataset):
    """Dataset of (video clip, caption) pairs for video-generation training.

    Args:
        video_paths: paths to the raw video files.
        caption_file: caption source; presumably index-aligned with
            ``video_paths`` — verify against the caption loader.
        num_frames: number of frames sampled per clip.
    """

    def __init__(self, video_paths, caption_file, num_frames=16):
        self.video_paths = video_paths
        self.captions = load_captions(caption_file)
        self.num_frames = num_frames

    def __len__(self):
        # Fix: __len__ was missing; map-style datasets need it for the
        # default sampler and for len() on the DataLoader.
        return len(self.video_paths)

    def __getitem__(self, idx):
        video = load_video(self.video_paths[idx])
        # Sample a random contiguous window of num_frames frames.
        if len(video) > self.num_frames:
            start = random.randint(0, len(video) - self.num_frames)
            video = video[start:start + self.num_frames]
        caption = self.captions[idx]
        video = self.augment(video)
        return {
            "video": video,
            "caption": caption
        }

    def augment(self, video):
        """Horizontal-flip and color-jitter augmentation for one clip."""
        # Flip along the last dim with p=0.5 (last dim assumed to be
        # width — TODO confirm).
        if random.random() > 0.5:
            video = torch.flip(video, dims=[-1])
        video = color_jitter(video)
        return video
Training Loop
def train_video_model(model, dataloader, epochs=100):
    """Train a latent video diffusion model with the DDPM noise objective.

    For each batch: encode the video to a latent, add scheduler noise at a
    randomly sampled timestep, and regress the model's noise prediction
    against the true noise with MSE.

    Args:
        model: diffusion model exposing ``encode`` and a noise-predicting
            ``forward(latent_or_video, caption, timesteps)``.
        dataloader: yields dicts with "video" and "caption" keys.
        epochs: number of passes over ``dataloader``.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    num_train_timesteps = 1000  # hoisted so sampling and scheduler agree
    noise_scheduler = DDPMScheduler(num_train_timesteps=num_train_timesteps)
    model.train()
    for epoch in range(epochs):
        for batch in dataloader:
            video = batch["video"]
            caption = batch["caption"]
            # Diffusion runs in latent space.
            latent = model.encode(video)
            noise = torch.randn_like(latent)
            # Fix: sample timesteps on the latent's device so GPU training
            # does not mix CPU and GPU tensors in add_noise / forward.
            timesteps = torch.randint(
                0, num_train_timesteps, (video.shape[0],),
                device=latent.device,
            )
            noisy_latent = noise_scheduler.add_noise(latent, noise, timesteps)
            # NOTE(review): model.forward as written encodes its first
            # argument again — confirm it accepts latents here.
            noise_pred = model(noisy_latent, caption, timesteps)
            loss = F.mse_loss(noise_pred, noise)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
def train_video_model(model, dataloader, epochs=100):
    """Run DDPM training for a latent video diffusion model.

    Each step encodes a batch of videos to latents, perturbs them with
    scheduler noise at random timesteps, and minimizes the MSE between the
    predicted and true noise.

    Args:
        model: exposes ``encode`` plus a forward pass that predicts noise
            given (input, caption, timesteps).
        dataloader: yields dicts with "video" and "caption" entries.
        epochs: number of full passes over the data.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    num_train_timesteps = 1000  # single source of truth for the schedule
    noise_scheduler = DDPMScheduler(num_train_timesteps=num_train_timesteps)
    model.train()
    for epoch in range(epochs):
        for batch in dataloader:
            video = batch["video"]
            caption = batch["caption"]
            latent = model.encode(video)
            noise = torch.randn_like(latent)
            # Fix: timesteps must live on the same device as the latents;
            # torch.randint defaults to CPU otherwise.
            timesteps = torch.randint(
                0, num_train_timesteps, (video.shape[0],),
                device=latent.device,
            )
            noisy_latent = noise_scheduler.add_noise(latent, noise, timesteps)
            # NOTE(review): forward re-encodes its first argument in the
            # model as defined above — verify it handles latents.
            noise_pred = model(noisy_latent, caption, timesteps)
            loss = F.mse_loss(noise_pred, noise)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
Inference and Sampling
@torch.no_grad()
def generate_video(model, prompt, num_frames=16, num_steps=50):
    """Sample a video from a text prompt via DDIM denoising.

    Starts from Gaussian noise in latent space and iteratively denoises
    it conditioned on the encoded prompt, then decodes the final latent.

    Args:
        model: diffusion model exposing ``text_encoder``, ``unet``, ``decode``.
        prompt: text description of the target video.
        num_frames: number of latent frames to generate.
        num_steps: number of DDIM denoising steps.

    Returns:
        The decoded video tensor.
    """
    text_emb = model.text_encoder(prompt)
    # Fix: allocate the initial noise on the model's device and dtype;
    # torch.randn defaults to CPU float32 and would crash on a GPU model.
    param = next(model.parameters())
    latent = torch.randn(
        (1, num_frames, 4, 64, 64),  # (batch, frames, channels, H, W) — TODO confirm latent shape
        device=param.device, dtype=param.dtype,
    )
    scheduler = DDIMScheduler(num_inference_steps=num_steps)
    # NOTE(review): diffusers schedulers normally require
    # scheduler.set_timesteps(num_steps) before iterating .timesteps —
    # confirm this scheduler's API.
    for t in scheduler.timesteps:
        noise_pred = model.unet(latent, t, text_emb)
        latent = scheduler.step(noise_pred, t, latent).prev_sample
    return model.decode(latent)
@torch.no_grad()
def generate_video(model, prompt, num_frames=16, num_steps=50):
    """Generate a video for ``prompt`` with DDIM sampling.

    Initializes a latent with Gaussian noise, runs the reverse diffusion
    loop conditioned on the prompt embedding, and decodes the result.

    Args:
        model: exposes ``text_encoder``, ``unet`` and ``decode``.
        prompt: text prompt to condition on.
        num_frames: latent frame count.
        num_steps: DDIM inference steps.

    Returns:
        Decoded video tensor.
    """
    text_emb = model.text_encoder(prompt)
    # Fix: build the starting noise on the model's device/dtype rather
    # than the CPU float32 default.
    param = next(model.parameters())
    latent_shape = (1, num_frames, 4, 64, 64)  # TODO confirm latent layout
    latent = torch.randn(latent_shape, device=param.device, dtype=param.dtype)
    scheduler = DDIMScheduler(num_inference_steps=num_steps)
    # NOTE(review): confirm whether scheduler.set_timesteps(num_steps)
    # must be called before iterating scheduler.timesteps (diffusers API).
    for t in scheduler.timesteps:
        noise_pred = model.unet(latent, t, text_emb)
        latent = scheduler.step(noise_pred, t, latent).prev_sample
    video = model.decode(latent)
    return video
Evaluation
Metrics
| Metric | Description | What it Measures |
|---|---|---|
| FVD | Fréchet Video Distance | Overall quality |
| FID | Fréchet Inception Distance | Frame quality |
| CLIP Score | Text-video alignment | Semantic accuracy |
| Temporal Consistency | Frame-to-frame coherence | Smoothness |
Human Evaluation
def human_evaluation_protocol():
    """Return the rubric for human rating of generated videos.

    Each key names an evaluation axis; each value is the exact question
    shown to raters, all on a 1-5 scale.
    """
    return {
        "visual_quality": "Rate the visual quality (1-5)",
        "motion_realism": "How realistic is the motion? (1-5)",
        "text_alignment": "Does the video match the prompt? (1-5)",
        "temporal_coherence": "Is the video temporally consistent? (1-5)",
        "physical_plausibility": "Are physics realistic? (1-5)",
    }
def human_evaluation_protocol():
    """Build and return the human-evaluation rubric for generated videos.

    Maps each evaluation axis to the rater-facing question (1-5 scale).
    """
    criteria = {}
    criteria["visual_quality"] = "Rate the visual quality (1-5)"
    criteria["motion_realism"] = "How realistic is the motion? (1-5)"
    criteria["text_alignment"] = "Does the video match the prompt? (1-5)"
    criteria["temporal_coherence"] = "Is the video temporally consistent? (1-5)"
    criteria["physical_plausibility"] = "Are physics realistic? (1-5)"
    return criteria
Applications
- Content Creation: Generate videos for marketing, entertainment
- Prototyping: Visualize concepts before production
- Data Augmentation: Generate training data for other models
- Simulation: Create scenarios for testing autonomous systems
Summary
Video generation with world models represents a major advance in AI capabilities. By learning to simulate world dynamics, these models can generate realistic, coherent videos that follow physical laws and match semantic descriptions.