World Models for Autonomous Vehicles
Autonomous vehicles (AVs) are one of the most important applications of world models. A learned model of how traffic evolves lets a vehicle predict upcoming scenarios, plan safe trajectories, and handle rare edge cases.
The AV Stack
┌────────────────────────────────────────────────────────┐
│                   Autonomous Vehicle                   │
├────────────────────────────────────────────────────────┤
│ Sensors → Perception → Prediction → Planning → Control │
│    ↓          ↓            ↓           ↓          ↓    │
│ Camera    Detection   World Model     Path    Steering │
│ LiDAR     Tracking    Forecasting   Planning  Throttle │
│ Radar      Fusion                              Brake   │
└────────────────────────────────────────────────────────┘
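Read left to right, the stack is a data pipeline: raw sensor frames are fused into a scene, a world model forecasts how the scene evolves, and planning and control act on that forecast. The sketch below makes the data flow concrete; every name in it is a toy placeholder, not a real AV framework API.

```python
# Toy, self-contained sketch of the pipeline above; all names are illustrative.
def perceive(sensor_frame: dict) -> dict:
    # Detection, tracking, and sensor fusion would happen here.
    return {"agents": sensor_frame.get("detections", [])}

def predict_world(scene: dict) -> list:
    # A world model forecasts how each agent's state evolves.
    return [{"agent": a, "motion": "constant_velocity"} for a in scene["agents"]]

def plan(scene: dict, forecasts: list) -> str:
    # Planning consumes both the current scene and the forecast.
    return "yield" if forecasts else "proceed"

def control(path: str) -> dict:
    # Control converts the planned path into actuation commands.
    return {"steer": 0.0, "throttle": 0.2 if path == "proceed" else 0.0}

scene = perceive({"detections": ["car_1", "cyclist_2"]})
command = control(plan(scene, predict_world(scene)))  # one tick of the loop
```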
World Models in AVs
1. Scene Understanding
World models help AVs understand complex traffic scenes:
```python
import torch.nn as nn

class TrafficWorldModel(nn.Module):
    """World model for traffic scene understanding."""

    def __init__(self):
        super().__init__()
        # SceneEncoder / TemporalModel / SceneDecoder are application-specific
        # submodules (e.g., a CNN encoder, RNN or transformer dynamics, CNN decoder).
        self.encoder = SceneEncoder()    # encode multi-sensor input
        self.dynamics = TemporalModel()  # model scene dynamics
        self.decoder = SceneDecoder()    # reconstruct/predict the scene

    def forward(self, sensor_data, history):
        z_t = self.encoder(sensor_data)         # encode current observation
        z_future = self.dynamics(z_t, history)  # roll temporal dynamics forward
        return self.decoder(z_future)           # decode predictions
```
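Since the submodules above are placeholders, here is a minimal runnable counterpart under toy assumptions: a linear encoder/decoder and a GRU for the dynamics, with all dimensions chosen arbitrarily for illustration.

```python
import torch
import torch.nn as nn

class MiniTrafficWorldModel(nn.Module):
    """Tiny encode -> latent dynamics -> decode model (illustrative sizes only)."""
    def __init__(self, obs_dim=64, latent_dim=32):
        super().__init__()
        self.encoder = nn.Linear(obs_dim, latent_dim)                     # observation -> latent
        self.dynamics = nn.GRU(latent_dim, latent_dim, batch_first=True)  # temporal model
        self.decoder = nn.Linear(latent_dim, obs_dim)                     # latent -> prediction

    def forward(self, obs_seq):                # obs_seq: (batch, time, obs_dim)
        z = torch.relu(self.encoder(obs_seq))  # per-step latent codes
        z_future, _ = self.dynamics(z)         # roll latent state forward in time
        return self.decoder(z_future)          # predicted observations

model = MiniTrafficWorldModel()
pred = model(torch.randn(2, 10, 64))  # 2 sequences, 10 timesteps -> (2, 10, 64)
```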
2. Trajectory Prediction
Predict future trajectories of other road users:
```python
import torch.nn as nn

class TrajectoryPredictor(nn.Module):
    """Predict future trajectories of vehicles and pedestrians."""

    def __init__(self, num_modes=6):
        super().__init__()
        # Submodules are application-specific (e.g., VectorNet-style encoders).
        self.encoder = AgentEncoder()
        self.map_encoder = MapEncoder()
        self.interaction = InteractionModule()
        self.decoder = MultiModalDecoder(num_modes)

    def forward(self, agent_history, map_data):
        agent_features = self.encoder(agent_history)  # encode agent histories
        map_features = self.map_encoder(map_data)     # encode map context
        # Model agent-agent and agent-map interactions
        interaction_features = self.interaction(agent_features, map_features)
        # Predict multiple possible futures and their likelihoods
        trajectories, probabilities = self.decoder(interaction_features)
        # trajectories: (num_agents, num_modes, future_steps, 2)
        # probabilities: (num_agents, num_modes)
        return trajectories, probabilities
```
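A planner typically consumes this multimodal output by weighting the modes or picking the most likely one. A short sketch, assuming the shapes documented above (the tensors here are random stand-ins for real model outputs):

```python
import torch

# Random stand-ins with the documented shapes:
# trajectories: (num_agents, num_modes, future_steps, 2), probabilities: (num_agents, num_modes)
trajectories = torch.randn(3, 6, 30, 2)
probabilities = torch.softmax(torch.randn(3, 6), dim=-1)

best_mode = probabilities.argmax(dim=-1)                  # most likely mode per agent
idx = best_mode[:, None, None, None].expand(-1, 1, 30, 2)
most_likely = trajectories.gather(1, idx).squeeze(1)      # (num_agents, future_steps, 2)
```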
3. Scenario Generation
Generate diverse training scenarios:
```python
class ScenarioGenerator:
    """Generate diverse driving scenarios for training."""

    def __init__(self, world_model):
        self.world_model = world_model

    def generate_scenario(self, initial_state, num_steps=50):
        """Roll the world model forward to produce a plausible traffic scenario."""
        scenario = [initial_state]
        state = initial_state
        for _ in range(num_steps):
            state = self.world_model.sample_next(state)  # sample next state
            scenario.append(state)
        return scenario

    def generate_edge_case(self, base_scenario, perturbation_type):
        """Generate an edge case by perturbing a base scenario."""
        if perturbation_type == "sudden_brake":
            return self.add_sudden_brake(base_scenario)
        elif perturbation_type == "cut_in":
            return self.add_cut_in(base_scenario)
        elif perturbation_type == "pedestrian_crossing":
            return self.add_pedestrian(base_scenario)
        raise ValueError(f"unknown perturbation type: {perturbation_type}")
```
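Because `sample_next`'s interface is not pinned down here, a usage sketch needs a stand-in. The stub world model below just jitters agent positions; it exists only to show how `ScenarioGenerator` is driven.

```python
import random

class StubWorldModel:
    """Stand-in for a learned model; sample_next nudges each agent forward with noise."""
    def sample_next(self, state):
        return {aid: (x + random.gauss(1.0, 0.2), y) for aid, (x, y) in state.items()}

gen = ScenarioGenerator(StubWorldModel())
scenario = gen.generate_scenario({"ego": (0.0, 0.0), "lead": (20.0, 0.0)}, num_steps=10)
print(len(scenario))  # 11 states: the initial state plus 10 sampled steps
```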
Data Sources
Waymo Open Dataset
```python
import tensorflow as tf
from waymo_open_dataset import dataset_pb2

def load_waymo_data(tfrecord_path):
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    for data in dataset:
        frame = dataset_pb2.Frame()
        frame.ParseFromString(data.numpy())
        # Extract camera images
        images = [tf.image.decode_jpeg(image.image) for image in frame.images]
        # Extract LiDAR points (extract_points is a project-specific helper,
        # e.g., built on waymo_open_dataset.utils.frame_utils)
        points = [extract_points(laser) for laser in frame.lasers]
        # Extract 3D box labels
        labels = [(label.box, label.type) for label in frame.laser_labels]
        yield {
            "images": images,
            "lidar": points,
            "labels": labels,
        }
```
nuScenes Dataset
```python
from nuscenes.nuscenes import NuScenes

nusc = NuScenes(version='v1.0-trainval', dataroot='/data/nuscenes')

def get_scene_data(scene_token):
    scene = nusc.get('scene', scene_token)
    sample_token = scene['first_sample_token']
    samples = []
    while sample_token:  # sample['next'] is '' at the end of the scene
        sample = nusc.get('sample', sample_token)
        # Get sensor data records for the front camera and top LiDAR
        cam_data = nusc.get('sample_data', sample['data']['CAM_FRONT'])
        lidar_data = nusc.get('sample_data', sample['data']['LIDAR_TOP'])
        # load_image / load_pointcloud / get_annotations are project-specific
        # helpers; filenames are relative to dataroot
        samples.append({
            'camera': load_image(cam_data['filename']),
            'lidar': load_pointcloud(lidar_data['filename']),
            'annotations': get_annotations(sample),
        })
        sample_token = sample['next']
    return samples
```
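To process a whole split, iterate the devkit's `nusc.scene` list (each record carries a `token` and a human-readable `name`):

```python
# Iterate every scene in the loaded split
for scene in nusc.scene:
    samples = get_scene_data(scene['token'])
    print(scene['name'], len(samples), 'samples')
```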
Simulation for AVs
CARLA Simulator
```python
import random
import carla

def collect_driving_data(num_ticks=1000):
    client = carla.Client('localhost', 2000)
    world = client.get_world()
    # Use synchronous mode so world.tick() advances the simulation deterministically
    settings = world.get_settings()
    settings.synchronous_mode = True
    settings.fixed_delta_seconds = 0.05
    world.apply_settings(settings)
    # Spawn ego vehicle
    blueprint = world.get_blueprint_library().find('vehicle.tesla.model3')
    spawn_point = random.choice(world.get_map().get_spawn_points())
    vehicle = world.spawn_actor(blueprint, spawn_point)
    # Attach sensors
    camera_bp = world.get_blueprint_library().find('sensor.camera.rgb')
    camera = world.spawn_actor(camera_bp, carla.Transform(), attach_to=vehicle)
    lidar_bp = world.get_blueprint_library().find('sensor.lidar.ray_cast')
    lidar = world.spawn_actor(lidar_bp, carla.Transform(), attach_to=vehicle)
    # Collect data as sensor callbacks fire
    data = []
    camera.listen(lambda img: data.append(process_image(img)))  # process_image: user-defined
    lidar.listen(lambda pts: data.append(process_lidar(pts)))   # process_lidar: user-defined
    # Drive with the built-in autopilot
    vehicle.set_autopilot(True)
    for _ in range(num_ticks):
        world.tick()
    # Clean up actors before returning
    camera.destroy()
    lidar.destroy()
    vehicle.destroy()
    return data
```
Evaluation Metrics
| Metric | Description | Target |
|---|---|---|
| ADE | Average Displacement Error: mean L2 distance between predicted and ground-truth positions over the horizon | < 1 m |
| FDE | Final Displacement Error: L2 distance at the final predicted timestep | < 2 m |
| Miss Rate | Fraction of predictions whose endpoint falls outside a distance threshold | < 10% |
| Collision Rate | Fraction of predicted trajectories that collide with other agents | < 1% |
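These metrics are straightforward to compute from predicted and ground-truth positions. A minimal NumPy sketch, assuming one trajectory per agent in meters (for multimodal predictions, the usual convention is to score the best of the K modes):

```python
import numpy as np

def ade_fde_miss(pred, gt, miss_threshold=2.0):
    """pred, gt: (num_agents, future_steps, 2) positions in meters."""
    dists = np.linalg.norm(pred - gt, axis=-1)           # per-step L2 errors
    ade = dists.mean()                                   # average displacement error
    fde = dists[:, -1].mean()                            # final displacement error
    miss_rate = (dists[:, -1] > miss_threshold).mean()   # endpoint misses
    return ade, fde, miss_rate

pred = np.zeros((4, 30, 2))
gt = np.ones((4, 30, 2))
print(ade_fde_miss(pred, gt))  # ADE = FDE ≈ 1.41 m, miss rate 0.0
```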
Summary
World models are essential for autonomous vehicles, enabling prediction, planning, and scenario generation. They help AVs handle the long tail of rare but critical driving situations.