Visual Tokenization Fundamentals
Tokenization converts high-dimensional visual data (images, video) into discrete or continuous tokens that can be efficiently processed by neural networks. This is a crucial component of modern world models.
Why Tokenize Visual Data?
Raw video is extremely high-dimensional:
- 1080p video at 30 fps: 1920 × 1080 × 3 × 30 ≈ 187 million values per second
Tokenization provides (a rough sizing sketch follows this list):
- Compression: Reduce data dimensionality
- Semantic Units: Create meaningful discrete units
- Efficiency: Enable transformer processing
- Generalization: Learn reusable representations
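To make the compression benefit concrete, here is a back-of-the-envelope sketch using illustrative numbers (a 256×256 frame and a 16× spatial downsampling factor, neither of which is prescribed by the text above):

```python
# Illustrative numbers only: one 256x256 RGB frame tokenized on a 16x16 latent grid.
pixels_per_frame = 256 * 256 * 3               # 196,608 raw values
tokens_per_frame = (256 // 16) * (256 // 16)   # 256 discrete tokens
print(pixels_per_frame / tokens_per_frame)     # 768x fewer units for the transformer to attend over
```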
Types of Tokenization
1. Discrete Tokenization
Converts visual data to integer indices (like words in text):
```python
import torch
import torch.nn as nn


class DiscreteTokenizer(nn.Module):
    """VQ-VAE style discrete tokenizer.

    Encoder and Decoder are assumed to be convolutional modules defined elsewhere.
    """

    def __init__(self, vocab_size=8192, embed_dim=256):
        super().__init__()
        self.encoder = Encoder(output_dim=embed_dim)
        self.codebook = nn.Embedding(vocab_size, embed_dim)
        self.decoder = Decoder(input_dim=embed_dim)

    def encode(self, x):
        z = self.encoder(x)                                      # (B, C, H, W)
        # Find the nearest codebook entry for every spatial position
        z_flat = z.permute(0, 2, 3, 1).reshape(-1, z.shape[1])   # (B*H*W, C)
        distances = torch.cdist(z_flat, self.codebook.weight)    # (B*H*W, vocab_size)
        indices = distances.argmin(dim=-1)
        return indices.view(x.shape[0], -1)                      # (B, num_tokens)

    def decode(self, indices):
        z_q = self.codebook(indices)                             # (B, num_tokens, embed_dim)
        # Reshape the token sequence back to a (square) spatial grid
        b, n, c = z_q.shape
        h = w = int(n ** 0.5)
        z_q = z_q.view(b, h, w, c).permute(0, 3, 1, 2)           # (B, C, H, W)
        return self.decoder(z_q)
```
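A minimal round-trip sketch of how this class is used, assuming Encoder and Decoder are implemented and the encoder downsamples 256×256 images to a 16×16 latent grid (illustrative shapes, not fixed by the code above):

```python
tokenizer = DiscreteTokenizer(vocab_size=8192, embed_dim=256)
images = torch.randn(2, 3, 256, 256)        # dummy batch
indices = tokenizer.encode(images)          # (2, 256) integer token ids
reconstruction = tokenizer.decode(indices)  # (2, 3, 256, 256)
```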
2. Continuous Tokenization
Converts visual data to continuous vectors:
```python
import torch
import torch.nn as nn


class ContinuousTokenizer(nn.Module):
    """VAE-style continuous tokenizer.

    Encoder and Decoder are assumed to be convolutional modules defined elsewhere.
    """

    def __init__(self, latent_dim=256):
        super().__init__()
        self.encoder = Encoder(output_dim=latent_dim * 2)  # predicts mean and log-variance
        self.decoder = Decoder(input_dim=latent_dim)

    def encode(self, x):
        h = self.encoder(x)
        mean, logvar = h.chunk(2, dim=1)
        # Reparameterization trick: sample z while keeping the graph differentiable
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mean + eps * std
        return z, mean, logvar

    def decode(self, z):
        return self.decoder(z)
```
Comparison
| Aspect | Discrete | Continuous |
|---|---|---|
| Representation | Integer indices | Float vectors |
| Compression | Higher (fixed vocab) | Variable |
| Generation | Autoregressive | Diffusion/Flow |
| Quality | Good for structure | Better for details |
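The "Generation" row is the practical consequence of the representation: discrete token ids can be modeled exactly like text with a next-token cross-entropy, while continuous latents are usually modeled with a denoising (diffusion/flow) objective. A minimal sketch of the discrete case, assuming a hypothetical autoregressive `world_model` that returns per-position logits over the codebook:

```python
import torch.nn.functional as F

# indices: (B, num_tokens) integer ids from a discrete tokenizer
logits = world_model(indices[:, :-1])        # (B, num_tokens - 1, vocab_size)
loss = F.cross_entropy(
    logits.reshape(-1, logits.shape[-1]),    # predict token t from tokens < t
    indices[:, 1:].reshape(-1),
)
```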
Video Tokenization
Video tokenizers must also handle the temporal dimension:
```python
import torch.nn as nn


class VideoTokenizer(nn.Module):
    """3D (spatiotemporal) discrete tokenizer for video.

    VectorQuantizer is assumed to be a standard VQ layer defined elsewhere.
    The architecture hard-codes 8x spatial and 4x temporal downsampling.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            # Spatial downsampling: three stride-2 convs -> 8x per side
            nn.Conv3d(3, 64, (1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1)),
            nn.ReLU(),
            nn.Conv3d(64, 128, (1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1)),
            nn.ReLU(),
            nn.Conv3d(128, 256, (1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1)),
            nn.ReLU(),
            # Temporal downsampling: 4x (assumes T divisible by 4)
            nn.Conv3d(256, 256, (4, 1, 1), stride=(4, 1, 1)),
        )
        self.quantizer = VectorQuantizer(num_embeddings=8192, embedding_dim=256)
        self.decoder = nn.Sequential(
            # Temporal upsampling
            nn.ConvTranspose3d(256, 256, (4, 1, 1), stride=(4, 1, 1)),
            nn.ReLU(),
            # Spatial upsampling
            nn.ConvTranspose3d(256, 128, (1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1)),
            nn.ReLU(),
            nn.ConvTranspose3d(128, 64, (1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1)),
            nn.ReLU(),
            nn.ConvTranspose3d(64, 3, (1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1)),
        )

    def encode(self, video):
        # video: (B, C, T, H, W)
        z = self.encoder(video)                 # (B, 256, T/4, H/8, W/8)
        z_q, indices, loss = self.quantizer(z)
        return indices, loss

    def decode(self, indices):
        z_q = self.quantizer.embed(indices)     # look up codebook vectors
        return self.decoder(z_q)
```
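A round-trip sketch of the shapes involved, assuming a VectorQuantizer implementation is available and its embed() call restores the (B, channels, T', H', W') latent layout (both assumptions, not guaranteed by the code above):

```python
import torch

video = torch.randn(1, 3, 16, 128, 128)     # (B, C, T, H, W) dummy clip
tokenizer = VideoTokenizer()
indices, vq_loss = tokenizer.encode(video)  # latent grid: 4 x 16 x 16 positions
reconstruction = tokenizer.decode(indices)  # (1, 3, 16, 128, 128)
```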
NVIDIA Cosmos Tokenizer
Cosmos uses a state-of-the-art video tokenizer; the sketch after this list works through what these compression ratios mean in practice:
- Compression Ratio: 8×8 spatial, 8× temporal
- Codebook Size: 64K tokens
- Quality: Near-lossless reconstruction
- Speed: Real-time encoding/decoding on GPU
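A back-of-the-envelope calculation using the 8×8 spatial and 8× temporal ratios listed above; the clip size is illustrative and the result is not an official Cosmos figure:

```python
# Illustrative token-count arithmetic at 8x8 spatial, 8x temporal compression.
frames, height, width = 32, 1024, 1024
raw_values = frames * height * width * 3                          # ~100M RGB values
latent_positions = (frames // 8) * (height // 8) * (width // 8)   # 4 * 128 * 128 = 65,536
print(raw_values, latent_positions, raw_values / latent_positions)
```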
Training Tokenizers
```python
import torch
import torch.nn.functional as F

# `lpips` is assumed to be a perceptual-loss module (e.g. from the lpips package)
# instantiated elsewhere.

def train_tokenizer(tokenizer, dataloader, epochs=100):
    optimizer = torch.optim.Adam(tokenizer.parameters(), lr=1e-4)
    for epoch in range(epochs):
        for batch in dataloader:
            # Encode and decode.  NOTE: in a real VQ tokenizer the decoder must
            # receive the straight-through quantized latents so that gradients
            # reach the encoder; decoding from integer indices alone would cut
            # the gradient path (see the sketch below).
            indices, vq_loss = tokenizer.encode(batch)
            reconstruction = tokenizer.decode(indices)

            # Reconstruction loss
            recon_loss = F.mse_loss(reconstruction, batch)
            # Perceptual loss (optional but improves visual quality)
            perceptual_loss = lpips(reconstruction, batch)

            # Total loss (the 0.1 weights are typical starting points, not tuned values)
            loss = recon_loss + 0.1 * vq_loss + 0.1 * perceptual_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
```
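As noted in the comment above, the codebook lookup (argmin) is not differentiable. The standard workaround, introduced with VQ-VAE, is the straight-through estimator applied inside the quantizer, together with the codebook and commitment losses. A minimal sketch of that step (function name is illustrative):

```python
import torch
import torch.nn.functional as F

def quantize_with_straight_through(z, codebook, beta=0.25):
    # z: (N, D) encoder outputs, codebook: (K, D) embedding weights
    distances = torch.cdist(z, codebook)
    indices = distances.argmin(dim=-1)
    z_q = codebook[indices]
    # Codebook loss pulls embeddings toward encoder outputs; commitment loss does the reverse
    vq_loss = F.mse_loss(z_q, z.detach()) + beta * F.mse_loss(z, z_q.detach())
    # Straight-through: forward pass uses z_q, backward pass copies gradients to z
    z_q = z + (z_q - z).detach()
    return z_q, indices, vq_loss
```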
Summary
Visual tokenization is essential for efficient world model training. The choice between discrete and continuous tokenization depends on the downstream task and generation method.