Back to Applications & Use Cases

Autonomous Vehicles

28 min

World Models for Autonomous Vehicles

Autonomous vehicles (AVs) are one of the most important applications of world models. A world model lets a vehicle predict how a traffic scene will evolve, plan safe trajectories, and handle rare edge cases.

The AV Stack

┌─────────────────────────────────────────────────────────┐
│                    Autonomous Vehicle                    │
├─────────────────────────────────────────────────────────┤
│  Sensors → Perception → Prediction → Planning → Control │
│    ↓           ↓            ↓           ↓          ↓   │
│  Camera    Detection    World Model   Path      Steering│
│  LiDAR     Tracking     Forecasting   Planning  Throttle│
│  Radar     Fusion                               Brake   │
└─────────────────────────────────────────────────────────┘

World Models in AVs

1. Scene Understanding

World models help AVs understand complex traffic scenes:

python
class TrafficWorldModel(nn.Module):
    """Encode -> roll forward in time -> decode pipeline for a traffic scene.

    The three stages are delegated to ``SceneEncoder``, ``TemporalModel`` and
    ``SceneDecoder``, which are defined outside this snippet.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: turn (multi-sensor) input into a latent scene representation.
        self.encoder = SceneEncoder()
        # Stage 2: advance the latent scene in time, conditioned on history.
        self.dynamics = TemporalModel()
        # Stage 3: map the predicted latent back to a scene reconstruction/prediction.
        self.decoder = SceneDecoder()

    def forward(self, sensor_data, history):
        """Return the decoded prediction for ``sensor_data`` given ``history``."""
        latent_now = self.encoder(sensor_data)
        latent_next = self.dynamics(latent_now, history)
        return self.decoder(latent_next)

2. Trajectory Prediction

Predict future trajectories of other road users:

python
class TrajectoryPredictor(nn.Module):
    """Multi-modal forecaster for the future paths of surrounding vehicles and pedestrians.

    Args:
        num_modes: number of alternative futures requested from
            ``MultiModalDecoder`` (default 6).

    The sub-modules (``AgentEncoder``, ``MapEncoder``, ``InteractionModule``,
    ``MultiModalDecoder``) are defined outside this snippet.
    """

    def __init__(self, num_modes=6):
        super().__init__()
        self.encoder = AgentEncoder()            # features from each agent's past motion
        self.map_encoder = MapEncoder()          # features from the road/map context
        self.interaction = InteractionModule()   # fuses agent-agent and agent-map information
        self.decoder = MultiModalDecoder(num_modes)

    def forward(self, agent_history, map_data):
        """Predict several possible futures per agent.

        Returns:
            ``(trajectories, probabilities)`` as produced by the decoder.
            NOTE(review): trajectories are presumably shaped
            (num_agents, num_modes, future_steps, 2) and probabilities
            (num_agents, num_modes) — confirm against MultiModalDecoder.
        """
        # Arguments are evaluated left to right: agent encoding runs before map encoding.
        fused = self.interaction(
            self.encoder(agent_history),
            self.map_encoder(map_data),
        )
        trajectories, probabilities = self.decoder(fused)
        return trajectories, probabilities

3. Scenario Generation

Generate diverse training scenarios:

python
class ScenarioGenerator:
    """Generate driving scenarios by rolling out a world model.

    Args:
        world_model: object exposing ``sample_next(state)``, which returns the
            state that follows ``state``.

    NOTE(review): ``generate_edge_case`` dispatches to ``add_sudden_brake``,
    ``add_cut_in`` and ``add_pedestrian``, which are not defined on this class;
    a subclass (or code elsewhere) must provide them.
    """

    # Supported perturbation_type -> name of the method that applies it.
    # Names (not bound methods) are stored so lookup happens only for the
    # requested perturbation; a subclass may implement just the ones it needs.
    _PERTURBATIONS = {
        "sudden_brake": "add_sudden_brake",
        "cut_in": "add_cut_in",
        "pedestrian_crossing": "add_pedestrian",
    }

    def __init__(self, world_model):
        self.world_model = world_model

    def generate_scenario(self, initial_state, num_steps=50):
        """Roll the world model forward from ``initial_state``.

        Args:
            initial_state: starting state passed to ``world_model.sample_next``.
            num_steps: number of successor states to sample.

        Returns:
            list holding ``initial_state`` followed by each sampled state
            (``num_steps + 1`` entries for non-negative ``num_steps``).
        """
        scenario = [initial_state]
        state = initial_state
        for _ in range(num_steps):
            state = self.world_model.sample_next(state)
            scenario.append(state)
        return scenario

    def generate_edge_case(self, base_scenario, perturbation_type):
        """Generate an edge case by applying a named perturbation to ``base_scenario``.

        Args:
            base_scenario: scenario to perturb (e.g. from ``generate_scenario``).
            perturbation_type: one of ``"sudden_brake"``, ``"cut_in"``,
                ``"pedestrian_crossing"``.

        Returns:
            whatever the corresponding ``add_*`` method returns.

        Raises:
            ValueError: if ``perturbation_type`` is not supported, instead of
                silently returning ``None``.
        """
        try:
            method_name = self._PERTURBATIONS[perturbation_type]
        except KeyError:
            supported = ", ".join(sorted(self._PERTURBATIONS))
            raise ValueError(
                f"Unknown perturbation_type {perturbation_type!r}; "
                f"expected one of: {supported}"
            ) from None
        return getattr(self, method_name)(base_scenario)

Data Sources

Waymo Open Dataset

python
import tensorflow as tf
from waymo_open_dataset import dataset_pb2

def load_waymo_data(tfrecord_path):
    """Yield camera, LiDAR and label data for each frame in a Waymo Open Dataset TFRecord.

    Args:
        tfrecord_path: path(s) to TFRecord file(s) whose records are serialized
            ``dataset_pb2.Frame`` messages; forwarded to ``tf.data.TFRecordDataset``.

    Yields:
        One dict per frame:
            "images": list of ``tf.image.decode_jpeg`` results, one per ``frame.images`` entry.
            "lidar":  list of ``extract_points(laser)`` results, one per ``frame.lasers`` entry.
            "labels": list of ``{"box": label.box, "type": label.type}`` dicts, one per
                ``frame.laser_labels`` entry.
    """
    dataset = tf.data.TFRecordDataset(tfrecord_path)

    for record in dataset:
        frame = dataset_pb2.Frame()
        frame.ParseFromString(record.numpy())

        # Accumulate every entry: a frame carries several cameras, lasers and
        # labels, and all of them belong in the yielded sample.
        images = [tf.image.decode_jpeg(image.image) for image in frame.images]

        # NOTE(review): extract_points is not defined in this snippet. Waymo stores
        # LiDAR as range images, so it presumably wraps the waymo_open_dataset
        # frame_utils range-image -> point-cloud conversion; confirm.
        points = [extract_points(laser) for laser in frame.lasers]

        labels = [
            {"box": label.box, "type": label.type}
            for label in frame.laser_labels
        ]

        yield {
            "images": images,
            "lidar": points,
            "labels": labels,
        }

nuScenes Dataset

python
from nuscenes.nuscenes import NuScenes

nusc = NuScenes(version='v1.0-trainval', dataroot='/data/nuscenes')

def get_scene_data(scene_token, camera_channel='CAM_FRONT', lidar_channel='LIDAR_TOP'):
    """Collect camera, LiDAR and annotation data for every sample in a nuScenes scene.

    Walks the scene's samples starting at ``first_sample_token`` and following
    each sample's ``next`` token until it is falsy (nuScenes uses ``''`` after
    the last sample).

    Args:
        scene_token: token of the scene record to read.
        camera_channel: key in ``sample['data']`` for the camera to load.
        lidar_channel: key in ``sample['data']`` for the LiDAR to load.

    Returns:
        list of dicts with keys 'camera', 'lidar' and 'annotations', in sample order.

    ``load_image``, ``load_pointcloud`` and ``get_annotations`` are helpers
    defined outside this snippet.
    """
    scene = nusc.get('scene', scene_token)

    samples = []
    sample_token = scene['first_sample_token']
    while sample_token:
        sample = nusc.get('sample', sample_token)

        # sample_data['filename'] is relative to the dataset root; resolve it to a
        # full path so loading does not depend on the current working directory.
        cam_path = nusc.get_sample_data_path(sample['data'][camera_channel])
        lidar_path = nusc.get_sample_data_path(sample['data'][lidar_channel])

        samples.append({
            'camera': load_image(cam_path),
            'lidar': load_pointcloud(lidar_path),
            'annotations': get_annotations(sample),
        })

        sample_token = sample['next']

    return samples

Simulation for AVs

CARLA Simulator

python
import carla

def collect_driving_data(num_ticks=1000, host='localhost', port=2000):
    """Drive an autopilot ego vehicle in CARLA and collect processed camera frames.

    Spawns a Tesla Model 3 at a random spawn point, attaches an RGB camera and a
    ray-cast LiDAR, enables autopilot, and advances the simulation ``num_ticks``
    frames in synchronous mode. Sensors are stopped, all spawned actors are
    destroyed, and the world's original settings are restored before returning,
    even if an error occurs.

    Args:
        num_ticks: number of simulation frames to advance.
        host: hostname of the CARLA server.
        port: RPC port of the CARLA server.

    Returns:
        list of ``process_image(img)`` results, one per camera frame received.
        ``process_image`` is defined outside this snippet.
        NOTE(review): the LiDAR is spawned but never listened to, so no point
        clouds are collected; add a ``lidar.listen`` callback if they are needed.
    """
    import random

    client = carla.Client(host, port)
    client.set_timeout(10.0)  # fail fast instead of hanging if the server is unreachable
    world = client.get_world()
    library = world.get_blueprint_library()

    original_settings = world.get_settings()
    traffic_manager = client.get_trafficmanager()

    actors = []
    camera = None
    data = []
    try:
        # world.tick() only drives the simulation step-by-step in synchronous mode;
        # a fixed delta makes each tick advance the same simulated time (20 Hz).
        settings = world.get_settings()
        settings.synchronous_mode = True
        settings.fixed_delta_seconds = 0.05
        world.apply_settings(settings)
        # The Traffic Manager (which runs autopilot) must follow the same mode.
        traffic_manager.set_synchronous_mode(True)

        # Spawn ego vehicle
        blueprint = library.find('vehicle.tesla.model3')
        spawn_point = random.choice(world.get_map().get_spawn_points())
        vehicle = world.spawn_actor(blueprint, spawn_point)
        actors.append(vehicle)

        # Attach sensors. An identity Transform would place them at the vehicle
        # origin, inside the car body; mount them forward/on the roof instead.
        camera = world.spawn_actor(
            library.find('sensor.camera.rgb'),
            carla.Transform(carla.Location(x=1.5, z=2.4)),
            attach_to=vehicle,
        )
        actors.append(camera)

        lidar = world.spawn_actor(
            library.find('sensor.lidar.ray_cast'),
            carla.Transform(carla.Location(z=2.5)),
            attach_to=vehicle,
        )
        actors.append(lidar)

        # Collect data (the callback runs on CARLA's sensor thread).
        camera.listen(lambda img: data.append(process_image(img)))

        # Run autopilot via the synchronous Traffic Manager.
        vehicle.set_autopilot(True, traffic_manager.get_port())

        for _ in range(num_ticks):
            world.tick()
    finally:
        # Actors persist on the server after the client exits, so clean up explicitly.
        if camera is not None:
            camera.stop()
        for actor in reversed(actors):
            actor.destroy()
        traffic_manager.set_synchronous_mode(False)
        world.apply_settings(original_settings)

    return data

Evaluation Metrics

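| Metric | Description | Target |
|---|---|---|
| ADE | Average Displacement Error: mean L2 distance between predicted and ground-truth positions over all future timesteps | < 1 m |
| FDE | Final Displacement Error: L2 distance between predicted and ground-truth positions at the final timestep | < 2 m |
| Miss Rate | Fraction of predictions whose final position is farther than a distance threshold (commonly 2 m) from ground truth | < 10% |
| Collision Rate | Fraction of predicted trajectories that collide with other agents | < 1% |

For multi-modal predictors, ADE and FDE are usually reported as minADE/minFDE: the error of the best of the K predicted modes.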

Summary

World models are essential for autonomous vehicles, enabling prediction, planning, and scenario generation. They help AVs handle the long tail of rare but critical driving situations.