Visual Tokenization Fundamentals

Tokenization converts high-dimensional visual data (images, video) into a compact sequence of discrete or continuous tokens that neural networks can process efficiently. It is a core component of modern world models.

Why Tokenize Visual Data?

Raw video is extremely high-dimensional:

  • 1080p video at 30fps = 1920 × 1080 × 3 × 30 = 186,624,000 (about 187 million) values per second
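The arithmetic above can be checked directly. This is a minimal sketch in plain Python; the function name `raw_values_per_second` is illustrative and not taken from any library.

```python
def raw_values_per_second(width, height, channels, fps):
    """Number of scalar values in one second of uncompressed video."""
    return width * height * channels * fps


values = raw_values_per_second(1920, 1080, 3, 30)
print(f"{values:,} values per second")  # 186,624,000 values per second

# At 8 bits per value this is also the number of bytes per second.
print(f"{values / 1e6:.1f} MB/s at 1 byte per value")
```

A transformer that attended over every one of these values would be impractical, which is why the input is compressed into far fewer tokens first.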

Tokenization provides:

  1. Compression: Reduce the number of values the model has to process
  2. Semantic Units: Group pixels into meaningful units
  3. Efficiency: Keep sequences short enough for transformer attention
  4. Generalization: Learn representations that can be reused across tasks

Types of Tokenization

1. Discrete Tokenization

Converts visual data to integer indices (like words in text):

python
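import math

import torch
import torch.nn as nn


class DiscreteTokenizer(nn.Module):
    """VQ-VAE style discrete tokenizer.

    Encoder and Decoder are placeholder conv networks defined elsewhere.
    """

    def __init__(self, vocab_size=8192, embed_dim=256):
        super().__init__()
        self.encoder = Encoder(output_dim=embed_dim)
        self.codebook = nn.Embedding(vocab_size, embed_dim)
        self.decoder = Decoder(input_dim=embed_dim)

    def encode(self, x):
        z = self.encoder(x)  # (B, C, H, W)

        # Flatten to (B*H*W, C) and find the nearest codebook entry per position
        z_flat = z.permute(0, 2, 3, 1).reshape(-1, z.shape[1])
        distances = torch.cdist(z_flat, self.codebook.weight)  # (B*H*W, vocab_size)
        indices = distances.argmin(dim=-1)

        return indices.view(x.shape[0], -1)  # (B, num_tokens), num_tokens = H*W

    def decode(self, indices):
        z_q = self.codebook(indices)  # (B, num_tokens, embed_dim)
        # Reshape to spatial (B, embed_dim, H, W).
        # Assumption: the token grid is square, so H = W = sqrt(num_tokens).
        b, n, d = z_q.shape
        side = math.isqrt(n)
        z_q = z_q.view(b, side, side, d).permute(0, 3, 1, 2).contiguous()
        return self.decoder(z_q)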
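The nearest-codebook lookup is easier to follow on a toy example without any neural network. The sketch below uses plain Python and a hand-written four-entry codebook; the names `nearest_code`, `encode` and `decode` are illustrative, and a real tokenizer would learn the codebook rather than fix it.

```python
import math


def nearest_code(vector, codebook):
    """Index of the codebook entry closest to `vector` (Euclidean distance)."""
    distances = [math.dist(vector, code) for code in codebook]
    return distances.index(min(distances))


def encode(vectors, codebook):
    """Map each latent vector to the integer index of its nearest code."""
    return [nearest_code(v, codebook) for v in vectors]


def decode(indices, codebook):
    """Map each index back to its codebook vector (a lossy reconstruction)."""
    return [codebook[i] for i in indices]


codebook = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]
latents = [(0.1, 0.2), (0.9, 0.1), (0.2, 0.8), (0.7, 0.9)]

indices = encode(latents, codebook)
print(indices)  # [0, 1, 2, 3]
print(decode(indices, codebook))
```

Decoding returns the codebook vectors, not the original latents, so the gap between the two is the quantization error that training tries to keep small.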

2. Continuous Tokenization

Converts visual data to continuous vectors:

python
class ContinuousTokenizer(nn.Module):
    """VAE-style continuous tokenizer"""
    
    def __init__(self, latent_dim=256):
        super().__init__()
        self.encoder = Encoder(output_dim=latent_dim * 2)  # mean and logvar
        self.decoder = Decoder(input_dim=latent_dim)
    
    def encode(self, x):
        h = self.encoder(x)
        mean, logvar = h.chunk(2, dim=1)
        
        # Reparameterization trick
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mean + eps * std
        
        return z, mean, logvar
    
    def decode(self, z):
        return self.decoder(z)
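The `mean` and `logvar` returned by `encode` are normally used for a KL regularization term during VAE training. That term is not shown in the class above, so the following is a dependency-free sketch under that assumption, with the hypothetical helper names `reparameterize` and `kl_to_standard_normal`.

```python
import math
import random


def reparameterize(mean, logvar, rng):
    """Sample z = mean + eps * std element-wise, with eps ~ N(0, 1)."""
    return [m + rng.gauss(0.0, 1.0) * math.exp(0.5 * lv)
            for m, lv in zip(mean, logvar)]


def kl_to_standard_normal(mean, logvar):
    """KL(N(mean, exp(logvar)) || N(0, 1)), summed over dimensions."""
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0
                     for m, lv in zip(mean, logvar))


rng = random.Random(0)
mean, logvar = [0.5, -1.0], [0.0, -2.0]

z = reparameterize(mean, logvar, rng)
print(z)

# The KL term is zero exactly when the posterior equals the N(0, 1) prior.
print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))  # 0.0
print(kl_to_standard_normal(mean, logvar))
```

Sampling through `mean + eps * std` keeps the randomness in `eps`, which is what lets gradients reach `mean` and `logvar` in an autograd framework.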

Comparison

| Aspect | Discrete | Continuous |
|---|---|---|
| Representation | Integer indices | Float vectors |
| Compression | Higher (fixed vocab) | Variable |
| Generation | Autoregressive | Diffusion/Flow |
| Quality | Good for structure | Better for details |
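The compression difference can be made concrete by counting bits per token. This rough sketch uses the default sizes from the two classes above (a vocabulary of 8192 and a 256-dimensional latent) and assumes 32-bit floats; the function names are illustrative.

```python
def discrete_bits_per_token(vocab_size):
    """Bits needed to store one index into a codebook of `vocab_size` entries."""
    return (vocab_size - 1).bit_length()


def continuous_bits_per_token(latent_dim, bits_per_float=32):
    """Bits needed to store one latent vector of `latent_dim` floats."""
    return latent_dim * bits_per_float


print(discrete_bits_per_token(8192))   # 13
print(continuous_bits_per_token(256))  # 8192
```

In practice continuous tokenizers often use far fewer channels or lower precision than this, which is why the table lists their compression as variable.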

Video Tokenization

Video tokenizers must also handle the temporal dimension:

python
class VideoTokenizer(nn.Module):
    """3D tokenizer for video"""
    
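    def __init__(self, spatial_downsample=8, temporal_downsample=4):
        super().__init__()
        # Note: the layers below hard-code these factors (three stride-2 spatial
        # convs give 2*2*2 = 8, one stride-4 temporal conv gives 4); the two
        # arguments only document them and are not used to build the layers.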
        self.encoder = nn.Sequential(
            # Spatial downsampling
            nn.Conv3d(3, 64, (1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1)),
            nn.ReLU(),
            nn.Conv3d(64, 128, (1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1)),
            nn.ReLU(),
            nn.Conv3d(128, 256, (1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1)),
            nn.ReLU(),
            # Temporal downsampling
            nn.Conv3d(256, 256, (4, 1, 1), stride=(4, 1, 1), padding=(0, 0, 0)),
        )
        
        self.quantizer = VectorQuantizer(num_embeddings=8192, embedding_dim=256)
        
        self.decoder = nn.Sequential(
            # Temporal upsampling
            nn.ConvTranspose3d(256, 256, (4, 1, 1), stride=(4, 1, 1)),
            nn.ReLU(),
            # Spatial upsampling
            nn.ConvTranspose3d(256, 128, (1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1)),
            nn.ReLU(),
            nn.ConvTranspose3d(128, 64, (1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1)),
            nn.ReLU(),
            nn.ConvTranspose3d(64, 3, (1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1)),
        )
    
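    def encode(self, video):
        # video: (B, C, T, H, W)
        z = self.encoder(video)  # (B, 256, T/4, H/8, W/8)
        # z_q (the quantized latents) is discarded here; only indices are returned
        z_q, indices, loss = self.quantizer(z)
        return indices, loss

    def decode(self, indices):
        # embed() must return latents laid out as (B, 256, T/4, H/8, W/8)
        z_q = self.quantizer.embed(indices)
        return self.decoder(z_q)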
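To see how many tokens this encoder produces, apply the standard convolution output-size formula to each layer. The sketch below hard-codes the kernel, stride and padding values from the encoder above; `conv_out` and `token_grid` are illustrative names, and the 16-frame 256×256 clip is an assumed input.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def token_grid(frames, height, width):
    """Latent grid (T, H, W) produced by the encoder layers listed above."""
    for _ in range(3):  # three spatial layers: kernel 4, stride 2, padding 1
        height = conv_out(height, 4, 2, 1)
        width = conv_out(width, 4, 2, 1)
    frames = conv_out(frames, 4, 4, 0)  # one temporal layer: kernel 4, stride 4
    return frames, height, width


t, h, w = token_grid(frames=16, height=256, width=256)
print((t, h, w), t * h * w)  # (4, 32, 32) 4096
```

The clip holds 16 × 256 × 256 × 3 = 3,145,728 raw values, so each of the 4096 tokens stands in for 768 of them.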

NVIDIA Cosmos Tokenizer

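NVIDIA's Cosmos tokenizer family includes both continuous and discrete video tokenizers. Typical figures quoted for it:

  • Compression Ratio: 8×8 spatial, 8× temporal (other configurations exist)
  • Codebook Size: 64K tokens (discrete variant)
  • Quality: High-fidelity reconstruction at these compression ratios
  • Speed: Fast encoding/decoding on GPU; whether it runs in real time depends on resolution and hardware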

Training Tokenizers

python
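import torch
import torch.nn.functional as F


def train_tokenizer(tokenizer, dataloader, epochs=100, perceptual_fn=None):
    """Train a VQ tokenizer; perceptual_fn is an optional callable such as LPIPS."""
    optimizer = torch.optim.Adam(tokenizer.parameters(), lr=1e-4)

    for epoch in range(epochs):
        for batch in dataloader:
            # Decode from the quantized latents z_q, not from the integer
            # indices: argmin has no gradient, so decoding from indices would
            # cut the encoder off from the reconstruction loss. This assumes
            # the quantizer returns z_q with a straight-through estimator.
            z = tokenizer.encoder(batch)
            z_q, indices, vq_loss = tokenizer.quantizer(z)
            reconstruction = tokenizer.decoder(z_q)

            # Reconstruction loss
            recon_loss = F.mse_loss(reconstruction, batch)

            # Perceptual loss (optional but improves quality)
            perceptual_loss = reconstruction.new_zeros(())
            if perceptual_fn is not None:
                perceptual_loss = perceptual_fn(reconstruction, batch).mean()

            # Total loss
            loss = recon_loss + 0.1 * vq_loss + 0.1 * perceptual_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()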
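The reconstruction loss in this loop is a mean squared error, which is hard to interpret on its own. One common way to read it is to convert it to PSNR in decibels. The helper below is a small sketch, not part of the training loop, and it assumes pixel values scaled to the range [0, 1].

```python
import math


def psnr_from_mse(mse, max_value=1.0):
    """Peak signal-to-noise ratio in dB for pixel values in [0, max_value]."""
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)


# An MSE of 1e-3 on [0, 1] pixels corresponds to roughly 30 dB.
print(psnr_from_mse(1e-3))
```

Higher is better. Because the scale is logarithmic, reducing the MSE by a factor of ten adds 10 dB.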

Summary

Visual tokenization is essential for efficient world model training. The choice between discrete and continuous tokenization depends on the downstream task and generation method.