
Reinforcement Learning for World Models

30 min


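Reinforcement Learning (RL) trains agents to choose actions that maximize cumulative reward through interaction with an environment. World models and RL work in both directions: a learned world model lets an agent learn and plan inside a simulation, which can make learning more sample-efficient, and RL can in turn be used to train and improve the world model itself. This lesson covers both.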

RL Fundamentals

The RL Framework

Agent  --(action a)-->  Environment
Agent  <--(next state s, reward r)--  Environment

Key components:

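  • State (s): The agent's current observation of the environment
  • Action (a): The decision the agent makes at each step
  • Reward (r): Scalar feedback signal for the last action
  • Policy (π): Strategy that maps states to actions (or to a distribution over actions)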
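The cycle above can be made concrete with a few lines of Python. This is a minimal sketch built on a hypothetical `LineWorld` environment (positions 0 to `size - 1`, reward +1 for reaching the right end) and a random policy; neither comes from any library.

```python
import random

class LineWorld:
    """Hypothetical toy environment: walk along cells 0..size-1; the right end pays +1."""

    def __init__(self, size=5):
        self.size = size
        self.pos = 0

    def reset(self):
        self.pos = 0
        return self.pos  # initial state s

    def step(self, action):
        # action is -1 (left) or +1 (right); the position is clipped to the grid
        self.pos = min(max(self.pos + action, 0), self.size - 1)
        done = self.pos == self.size - 1
        reward = 1.0 if done else 0.0  # reward r
        return self.pos, reward, done

def random_policy(state):
    """Policy pi: ignores the state and picks a direction uniformly at random."""
    return random.choice([-1, 1])

def run_episode(env, policy, max_steps=100):
    """Run one episode and return (total reward, number of steps taken)."""
    state = env.reset()
    total_reward = 0.0
    for t in range(max_steps):
        action = policy(state)                  # agent -> environment
        state, reward, done = env.step(action)  # environment -> agent
        total_reward += reward
        if done:
            return total_reward, t + 1
    return total_reward, max_steps

random.seed(0)
print(run_episode(LineWorld(), random_policy))
```

Replacing `random_policy` with `lambda s: 1` (always step right) reaches the goal of a 5-cell `LineWorld` in 4 steps, so `run_episode` returns `(1.0, 4)`.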

World Models in RL

World models serve two key roles in RL:

1. Learning Environment Dynamics

```python
import torch
import torch.nn as nn

class LearnedWorldModel(nn.Module):
    """World model that predicts the next state and reward from (state, action)."""

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        # Shared MLP trunk over the concatenated state-action vector
        self.dynamics = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        # Separate linear heads for the next state and the scalar reward
        self.state_head = nn.Linear(hidden_dim, state_dim)
        self.reward_head = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        h = self.dynamics(x)
        next_state = self.state_head(h)
        reward = self.reward_head(h)
        return next_state, reward
```
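A dynamics model like this is typically fit by plain supervised regression on logged transitions (s, a, r, s'). The sketch below assumes a hypothetical linear ground truth (`true_env_step`: s' = s + a, r = sum of the components of s') and uses a hypothetical linear `ToyWorldModel` with the same `(state, action) -> (next_state, reward)` interface so the example stays small; `LearnedWorldModel` can be swapped in without changing the loss or training loop.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
STATE_DIM, ACTION_DIM = 2, 2

def true_env_step(state, action):
    """Hypothetical ground-truth dynamics, used only to generate training data."""
    next_state = state + action
    reward = next_state.sum(dim=-1, keepdim=True)
    return next_state, reward

class ToyWorldModel(nn.Module):
    """Linear stand-in with the same interface as LearnedWorldModel."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.state_dim = state_dim
        self.net = nn.Linear(state_dim + action_dim, state_dim + 1)

    def forward(self, state, action):
        out = self.net(torch.cat([state, action], dim=-1))
        return out[..., :self.state_dim], out[..., self.state_dim:]

def world_model_loss(model, states, actions, next_states, rewards):
    """Regress predicted next states and rewards onto the observed ones."""
    pred_next, pred_reward = model(states, actions)
    return F.mse_loss(pred_next, next_states) + F.mse_loss(pred_reward, rewards)

# A dataset of "real" transitions (s, a, r, s')
states = torch.randn(512, STATE_DIM)
actions = torch.randn(512, ACTION_DIM)
next_states, rewards = true_env_step(states, actions)

model = ToyWorldModel(STATE_DIM, ACTION_DIM)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
initial_loss = world_model_loss(model, states, actions, next_states, rewards).item()
for step in range(500):
    loss = world_model_loss(model, states, actions, next_states, rewards)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
final_loss = loss.item()
print(f"loss: {initial_loss:.3f} -> {final_loss:.5f}")
```

Predicting the next state directly is one design choice; many implementations predict the change s' - s instead, which is often easier to learn when consecutive states are similar.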

2. Planning in Imagination

```python
def imagine_trajectory(world_model, policy, initial_state, horizon=50):
    """Roll out a policy inside the learned model, without touching the real environment."""
    states = [initial_state]
    actions = []
    rewards = []

    state = initial_state
    for _ in range(horizon):
        action = policy(state)                           # policy picks an action
        next_state, reward = world_model(state, action)  # model predicts the outcome

        states.append(next_state)
        actions.append(action)
        rewards.append(reward)

        # Predictions are fed back in, so model errors can compound over the horizon
        state = next_state

    return states, actions, rewards
```
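Imagined rollouts become a planner when candidate action sequences are scored inside the model and only the best first action is executed. The sketch below uses random-shooting model-predictive control (one simple planner among several) with a hypothetical 1-D `toy_model` standing in for a learned world model; the goal position 3.0, the horizon, and the candidate count are arbitrary choices.

```python
import random

def toy_model(state, action):
    """Stand-in for a learned model: 1-D point; reward is minus the distance to 3.0."""
    next_state = state + action
    reward = -abs(next_state - 3.0)
    return next_state, reward

def imagined_return(model, state, actions, gamma=0.99):
    """Discounted reward of an action sequence, computed entirely inside the model."""
    total, discount = 0.0, 1.0
    for a in actions:
        state, r = model(state, a)
        total += discount * r
        discount *= gamma
    return total

def plan_random_shooting(model, state, horizon=5, num_candidates=200, rng=random):
    """Sample action sequences, score them in imagination, return the best first action."""
    best_seq, best_ret = None, float("-inf")
    for _ in range(num_candidates):
        seq = [rng.uniform(-1.0, 1.0) for _ in range(horizon)]
        ret = imagined_return(model, state, seq)
        if ret > best_ret:
            best_seq, best_ret = seq, ret
    return best_seq[0]  # MPC: execute only the first action, then replan

rng = random.Random(0)
state = 0.0
for _ in range(10):
    action = plan_random_shooting(toy_model, state, rng=rng)
    state, _ = toy_model(state, action)  # in this toy, the "real" env equals the model
print(round(state, 2))
```

Because the toy model plays both roles here, the real state follows the imagined one exactly; with a learned model, replanning after every real step is what limits the impact of model error.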

Dreamer Algorithm

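Dreamer (Hafner et al.) is a widely used model-based RL algorithm. It learns a latent world model from replayed experience and trains an actor-critic entirely on trajectories imagined in that latent space. In the simplified sketch below, the network components are placeholders: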

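```python
import torch
import torch.nn.functional as F
from torch.distributions import kl_divergence

# Encoder, RSSM, Decoder, RewardModel, Actor, Critic and compute_lambda_returns
# are placeholders for the real networks/helpers; only the training flow is shown.
class Dreamer:
    def __init__(self, state_dim, action_dim):
        # World model components
        self.encoder = Encoder(state_dim)
        self.dynamics = RSSM(latent_dim=256)  # Recurrent State-Space Model
        self.decoder = Decoder(state_dim)
        self.reward_model = RewardModel()

        # Actor-critic, trained only on imagined latent trajectories
        self.actor = Actor(action_dim)
        self.critic = Critic()

    def train_world_model(self, replay_buffer, batch_size=50):
        """Train the world model on real experience.

        Returns the loss and detached posterior states to start imagination from.
        """
        batch = replay_buffer.sample(batch_size)

        # Encode observations
        embedded = self.encoder(batch.observations)

        # Run the RSSM: the posterior sees the observation, the prior does not.
        # Assumed to return torch.distributions objects plus sampled latent states.
        post_dist, prior_dist, post_states = self.dynamics(embedded, batch.actions)

        # Reconstruction loss
        recon = self.decoder(post_states)
        recon_loss = F.mse_loss(recon, batch.observations)

        # KL loss pulls the prior (used for imagination) toward the posterior
        kl_loss = kl_divergence(post_dist, prior_dist).mean()

        # Reward prediction loss
        reward_pred = self.reward_model(post_states)
        reward_loss = F.mse_loss(reward_pred, batch.rewards)

        return recon_loss + kl_loss + reward_loss, post_states.detach()

    def train_actor_critic(self, start_states, horizon=15):
        """Train the policy in imagination, starting from encoded real states."""
        imagined_states = []
        imagined_rewards = []

        state = start_states  # posterior states returned by train_world_model
        for _ in range(horizon):
            action = self.actor(state)
            state = self.dynamics.imagine_step(state, action)  # prior step, no observation
            reward = self.reward_model(state)

            imagined_states.append(state)
            imagined_rewards.append(reward)

        states = torch.stack(imagined_states)    # (horizon, batch, latent)
        rewards = torch.stack(imagined_rewards)  # (horizon, batch, 1)

        # Lambda-returns blend imagined rewards with critic value estimates
        returns = compute_lambda_returns(rewards, self.critic(states))

        # Actor loss: maximize returns by backpropagating through the learned dynamics
        actor_loss = -returns.mean()

        # Critic loss: regress values of (detached) states onto fixed return targets
        critic_loss = F.mse_loss(self.critic(states.detach()), returns.detach())

        return actor_loss, critic_loss
```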
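`compute_lambda_returns` is not defined above. One possible implementation, in plain Python, follows the TD(λ) recursion that Dreamer's value targets are based on. The argument convention (`next_values[t]` is the critic's estimate for the state reached after step `t`) and the defaults γ = 0.99, λ = 0.95 are assumptions of this sketch; the same recursion applies elementwise to batched tensors.

```python
def lambda_returns(rewards, next_values, gamma=0.99, lam=0.95):
    """TD(lambda) returns, computed backwards over an imagined trajectory.

    rewards[t]     -- reward r_t for the transition s_t -> s_{t+1}
    next_values[t] -- critic estimate V(s_{t+1})
    Recursion: G_t = r_t + gamma * ((1 - lam) * V(s_{t+1}) + lam * G_{t+1}),
    with G_T = V(s_T) bootstrapping beyond the horizon.
    """
    assert len(rewards) == len(next_values) > 0
    returns = [0.0] * len(rewards)
    g = next_values[-1]  # G_T = V(s_T)
    for t in reversed(range(len(rewards))):
        g = rewards[t] + gamma * ((1 - lam) * next_values[t] + lam * g)
        returns[t] = g
    return returns
```

With `lam=0` each target is the one-step TD target r_t + γV(s_{t+1}); with `lam=1` it is the discounted sum of imagined rewards bootstrapped by V(s_T) at the horizon. For example, `lambda_returns([1, 1, 1], [0, 0, 5], gamma=1.0, lam=1.0)` gives `[8.0, 7.0, 6.0]`.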

RLHF for World Models

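Reinforcement Learning from Human Feedback (RLHF) can fine-tune a generative world model, such as a video generator, toward outputs that people judge as better: collect pairwise human preferences, fit a reward model to them, then optimize the generator against that reward.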

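```python
import torch
import torch.nn.functional as F

class RLHFWorldModel:
    def __init__(self, world_model, reward_model, lr=1e-5):
        self.world_model = world_model
        self.reward_model = reward_model  # Learned from human preferences
        self.reward_opt = torch.optim.Adam(reward_model.parameters(), lr=lr)
        self.model_opt = torch.optim.Adam(world_model.parameters(), lr=lr)

    def collect_preferences(self, num_comparisons=1000):
        """Collect human preferences between pairs of generated videos."""
        preferences = []
        for _ in range(num_comparisons):
            # Generate two videos (no gradients needed for labeling)
            with torch.no_grad():
                video_a = self.world_model.generate()
                video_b = self.world_model.generate()

            # get_human_preference is a placeholder for a labeling interface
            # that returns 'a' or 'b'
            preference = get_human_preference(video_a, video_b)
            preferences.append((video_a, video_b, preference))
        return preferences

    def train_reward_model(self, preferences):
        """Fit the reward model with the Bradley-Terry preference loss."""
        for video_a, video_b, pref in preferences:
            margin = self.reward_model(video_a) - self.reward_model(video_b)

            # P(a preferred) = sigmoid(r_a - r_b); logsigmoid is numerically stable,
            # and log(1 - sigmoid(m)) = logsigmoid(-m)
            if pref == 'a':
                loss = -F.logsigmoid(margin).mean()
            else:
                loss = -F.logsigmoid(-margin).mean()

            self.reward_opt.zero_grad()
            loss.backward()
            self.reward_opt.step()

    def fine_tune_with_rl(self, num_iterations=1000):
        """Fine-tune the world model with REINFORCE against the frozen reward model."""
        for _ in range(num_iterations):
            # Assumes a hypothetical generate_with_log_prob() that also returns
            # the log-probability of the sampled video under the model
            video, log_prob = self.world_model.generate_with_log_prob()

            # Score the sample; the reward model itself is not updated here
            with torch.no_grad():
                reward = self.reward_model(video)

            # Policy gradient: raise the log-probability of high-reward samples
            loss = -(reward * log_prob).mean()
            self.model_opt.zero_grad()
            loss.backward()
            self.model_opt.step()
```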
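The RLHF pipeline can be run end to end on a toy problem. In this sketch the "world model" is a hypothetical categorical generator over four clips, human labels are simulated from a hidden `true_quality` vector, the reward model is one learnable score per clip fit with the Bradley-Terry loss, and the generator is then fine-tuned with REINFORCE against the frozen scores. A real video model would supply the log-probability of each generated sample in place of `dist.log_prob`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
NUM_CLIPS = 4
# Hypothetical hidden quality of each clip; used only to simulate human labels
true_quality = torch.tensor([0.0, 0.5, 2.0, 1.0])

# 1) Simulated preferences: label 1.0 means the first clip of the pair was preferred
pairs = torch.randint(0, NUM_CLIPS, (2000, 2))
p_first = torch.sigmoid(true_quality[pairs[:, 0]] - true_quality[pairs[:, 1]])
labels = torch.bernoulli(p_first)

# 2) Reward model: one learnable score per clip, fit with the Bradley-Terry loss
scores = torch.zeros(NUM_CLIPS, requires_grad=True)
rm_opt = torch.optim.Adam([scores], lr=0.05)
for _ in range(500):
    margin = scores[pairs[:, 0]] - scores[pairs[:, 1]]
    # BCE on the margin = -[y log sigmoid(m) + (1 - y) log sigmoid(-m)]
    rm_loss = F.binary_cross_entropy_with_logits(margin, labels)
    rm_opt.zero_grad()
    rm_loss.backward()
    rm_opt.step()
reward_model = scores.detach()  # frozen during the RL stage

# 3) REINFORCE fine-tuning of the toy categorical "generator"
logits = torch.zeros(NUM_CLIPS, requires_grad=True)
pg_opt = torch.optim.Adam([logits], lr=0.1)
for _ in range(300):
    dist = torch.distributions.Categorical(logits=logits)
    samples = dist.sample((64,))
    rewards = reward_model[samples]
    advantage = rewards - rewards.mean()  # batch-mean baseline reduces variance
    pg_loss = -(advantage * dist.log_prob(samples)).mean()
    pg_opt.zero_grad()
    pg_loss.backward()
    pg_opt.step()

probs = torch.softmax(logits.detach(), dim=0)
print("learned scores:", reward_model)
print("generator probs:", probs)
```

The learned scores are only defined up to an additive constant (Bradley-Terry depends on score differences), which is one reason the batch-mean baseline is convenient: it removes that constant from the policy-gradient signal.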

Summary

Reinforcement learning and world models complement each other. A learned model of environment dynamics lets an agent plan and train in imagination, as Dreamer does, and RL from human feedback can fine-tune a generative world model toward outputs people prefer. Together, these make RL a powerful tool for world model development.