Reinforcement Learning for World Models
Reinforcement Learning (RL) enables agents to learn effective behaviors through interaction with an environment, and world models play a central role in that process. This lesson covers how world models are used in RL and how RL techniques can in turn train and improve world models.
RL Fundamentals
The RL Framework
Agent ──(Action)──→ Environment
Agent ←──(State, Reward)── Environment
Key components (illustrated in the interaction loop sketch after this list):
- State (s): Current observation
- Action (a): Agent's decision
- Reward (r): Feedback signal
- Policy (π): Action selection strategy
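To make these pieces concrete, here is a minimal interaction loop written against the Gymnasium API, with a random policy standing in for a learned agent; the environment name is just an example.

```python
import gymnasium as gym

env = gym.make("CartPole-v1")           # example environment
state, info = env.reset(seed=0)

for _ in range(200):
    action = env.action_space.sample()  # placeholder for the policy pi(a | s)
    next_state, reward, terminated, truncated, info = env.step(action)
    state = next_state
    if terminated or truncated:         # episode over: start a new one
        state, info = env.reset()

env.close()
```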
World Models in RL
World models serve two key roles in RL:
1. Learning Environment Dynamics
```python
import torch
import torch.nn as nn


class LearnedWorldModel(nn.Module):
    """World model that predicts the next state and reward from (state, action)."""

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.dynamics = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.state_head = nn.Linear(hidden_dim, state_dim)
        self.reward_head = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        h = self.dynamics(x)
        next_state = self.state_head(h)
        reward = self.reward_head(h)
        return next_state, reward
```
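As a usage sketch (not part of the lesson code), such a model can be fit to real transitions by regressing its predictions onto observed next states and rewards. The dimensions and the batch layout below are assumed placeholders.

```python
import torch
import torch.nn.functional as F

# Hypothetical dimensions and batch format; adjust to your environment.
model = LearnedWorldModel(state_dim=4, action_dim=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def world_model_update(batch):
    """One supervised update on a batch of real (s, a, r, s') transitions."""
    pred_next_state, pred_reward = model(batch["states"], batch["actions"])
    loss = (
        F.mse_loss(pred_next_state, batch["next_states"])  # dynamics loss
        + F.mse_loss(pred_reward, batch["rewards"])         # reward loss
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```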
2. Planning in Imagination
```python
def imagine_trajectory(world_model, policy, initial_state, horizon=50):
    """Roll out the policy in the imagined world."""
    states = [initial_state]
    actions = []
    rewards = []
    state = initial_state
    for _ in range(horizon):
        action = policy(state)
        next_state, reward = world_model(state, action)
        states.append(next_state)
        actions.append(action)
        rewards.append(reward)
        state = next_state
    return states, actions, rewards
```
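One common use of such imagined rollouts is simple shooting-based planning: sample several candidate action sequences, score each by its imagined return, and execute the first action of the best one. The sketch below is illustrative rather than part of the lesson; `action_sampler` is a hypothetical stand-in for a policy or action prior.

```python
import torch

def plan_by_random_shooting(world_model, state, action_sampler,
                            num_candidates=64, horizon=10):
    """Pick the first action of the highest-scoring imagined rollout."""
    best_return, best_first_action = float("-inf"), None
    with torch.no_grad():
        for _ in range(num_candidates):
            total_reward, s, first_action = 0.0, state, None
            for t in range(horizon):
                a = action_sampler(s)        # sample a candidate action
                s, r = world_model(s, a)     # imagined next state and reward
                total_reward += r.sum().item()
                if t == 0:
                    first_action = a
            if total_reward > best_return:
                best_return, best_first_action = total_reward, first_action
    return best_first_action
```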
Dreamer Algorithm
Dreamer is a model-based RL algorithm that learns a latent world model from real experience and trains an actor-critic entirely on trajectories imagined inside that model:
```python
import torch
import torch.nn.functional as F

# Encoder, RSSM, Decoder, RewardModel, Actor, Critic, kl_divergence and
# compute_lambda_returns are schematic components, not full implementations.

class Dreamer:
    def __init__(self, state_dim, action_dim):
        # World model components
        self.encoder = Encoder(state_dim)
        self.dynamics = RSSM(latent_dim=256)  # Recurrent State-Space Model
        self.decoder = Decoder(state_dim)
        self.reward_model = RewardModel()
        # Actor-critic, trained purely on imagined rollouts
        self.actor = Actor(action_dim)
        self.critic = Critic()

    def train_world_model(self, replay_buffer, batch_size=50):
        """Train the world model on real experience."""
        batch = replay_buffer.sample(batch_size)
        # Encode observations
        embedded = self.encoder(batch.observations)
        # Run the dynamics model to get posterior and prior latent states
        posts, priors = self.dynamics(embedded, batch.actions)
        # Reconstruction loss
        recon = self.decoder(posts)
        recon_loss = F.mse_loss(recon, batch.observations)
        # KL divergence between posterior and prior
        kl_loss = kl_divergence(posts, priors).mean()
        # Reward prediction loss
        reward_pred = self.reward_model(posts)
        reward_loss = F.mse_loss(reward_pred, batch.rewards)
        return recon_loss + kl_loss + reward_loss

    def train_actor_critic(self, horizon=15):
        """Train the policy in imagination."""
        # Start from latent states inferred from real observations
        initial_state = self.dynamics.get_initial_state()
        # Imagine trajectories by unrolling the dynamics with the actor
        imagined_states = []
        imagined_rewards = []
        state = initial_state
        for _ in range(horizon):
            action = self.actor(state)
            state = self.dynamics.imagine_step(state, action)
            reward = self.reward_model(state)
            imagined_states.append(state)
            imagined_rewards.append(reward)
        # Stack the imagined rollout into (horizon, ...) tensors
        imagined_states = torch.stack(imagined_states)
        imagined_rewards = torch.stack(imagined_rewards)
        # Compute lambda-returns from imagined rewards and critic values
        returns = compute_lambda_returns(imagined_rewards, self.critic(imagined_states))
        # Actor loss: maximize imagined returns
        actor_loss = -returns.mean()
        # Critic loss: regress onto the (detached) returns
        critic_loss = F.mse_loss(self.critic(imagined_states), returns.detach())
        return actor_loss, critic_loss
```
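The sketch above leaves `compute_lambda_returns` undefined. The λ-return blends n-step returns with bootstrapped value estimates via the recursion R_t = r_t + γ[(1 − λ)·V(s_{t+1}) + λ·R_{t+1}], bootstrapping from the critic at the end of the horizon. A minimal version, assuming `rewards` and `values` are tensors of shape (horizon, batch) and ignoring termination inside the imagined rollout:

```python
import torch

def compute_lambda_returns(rewards, values, gamma=0.99, lam=0.95):
    """Lambda-returns over imagined trajectories of shape (horizon, batch)."""
    horizon = rewards.shape[0]
    returns = torch.zeros_like(rewards)
    # Approximate the bootstrap beyond the horizon with the final value estimate
    next_return = values[-1]
    for t in reversed(range(horizon)):
        bootstrap = values[t + 1] if t + 1 < horizon else values[-1]
        next_return = rewards[t] + gamma * ((1 - lam) * bootstrap + lam * next_return)
        returns[t] = next_return
    return returns
```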
RLHF for World Models
Reinforcement Learning from Human Feedback (RLHF) uses human preference judgments to train a reward model, which can then guide fine-tuning of a generative world model:
```python
import torch


class RLHFWorldModel:
    def __init__(self, world_model, reward_model):
        self.world_model = world_model
        self.reward_model = reward_model  # Learned from human preferences

    def collect_preferences(self, num_comparisons=1000):
        """Collect human preferences between generated videos."""
        preferences = []
        for _ in range(num_comparisons):
            # Generate two candidate videos
            video_a = self.world_model.generate()
            video_b = self.world_model.generate()
            # Ask a human annotator which video they prefer
            preference = get_human_preference(video_a, video_b)
            preferences.append((video_a, video_b, preference))
        return preferences

    def train_reward_model(self, preferences, optimizer):
        """Train the reward model on human preferences."""
        for video_a, video_b, pref in preferences:
            reward_a = self.reward_model(video_a)
            reward_b = self.reward_model(video_b)
            # Bradley-Terry model: probability that A is preferred over B
            prob_a_preferred = torch.sigmoid(reward_a - reward_b)
            if pref == 'a':
                loss = -torch.log(prob_a_preferred)
            else:
                loss = -torch.log(1 - prob_a_preferred)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    def fine_tune_with_rl(self, optimizer, num_iterations=1000):
        """Fine-tune the world model to maximize the learned reward."""
        for _ in range(num_iterations):
            # Generate a video with the current world model
            video = self.world_model.generate()
            # Score it with the learned reward model
            reward = self.reward_model(video)
            # Gradient step that increases the reward (policy-gradient style)
            loss = -reward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
```
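Putting the three steps together, a hypothetical end-to-end loop might look like the following; the world model, reward model, and learning rates are assumed to be defined elsewhere.

```python
import torch

rlhf = RLHFWorldModel(world_model, reward_model)

# 1. Collect pairwise human preferences over generated videos
preferences = rlhf.collect_preferences(num_comparisons=500)

# 2. Fit the reward model to those preferences (Bradley-Terry loss)
reward_optimizer = torch.optim.Adam(reward_model.parameters(), lr=1e-4)
rlhf.train_reward_model(preferences, reward_optimizer)

# 3. Fine-tune the world model to maximize the learned reward
model_optimizer = torch.optim.Adam(world_model.parameters(), lr=1e-5)
rlhf.fine_tune_with_rl(model_optimizer, num_iterations=1000)
```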
Summary
Reinforcement learning and world models reinforce each other: a learned world model supplies the dynamics an agent can plan and learn in, while RL techniques, including human feedback, can improve the world model itself. Whether learning environment dynamics, planning in imagination, or incorporating human preferences, these methods are central to world model development.