Vision Transformer (ViT): Bringing Transformers to Computer Vision

Author: Jared Chung

Imagine if you could look at a photo the same way you read a book - breaking it into meaningful chunks and understanding how each piece relates to the others. That's exactly what Vision Transformers (ViT) do, and they've revolutionized computer vision in the process.

In 2020, Google's "An Image is Worth 16x16 Words" paper showed that the attention mechanisms powering language models work just as well for images. The result? A simpler, more elegant approach that can match or outperform traditional convolutional neural networks when pre-trained on enough data.

The Problem with Traditional Approaches

Convolutional Neural Networks (CNNs) dominated computer vision for years by mimicking how our visual cortex processes images - starting with local features (edges, textures) and gradually building up to global understanding (objects, scenes).

But CNNs have limitations:

  • Local Bias: They focus on nearby pixels, potentially missing long-range relationships
  • Fixed Patterns: Convolution filters are the same everywhere in the image
  • Inductive Bias: They assume nearby pixels are more related than distant ones

What if we could process all parts of an image simultaneously, like attention does for text?

The Vision Transformer Breakthrough

ViT's core insight: treat image patches like words in a sentence.

The "16x16 Words" Analogy

Think of reading a comic book:

  • Instead of reading letter by letter (pixel by pixel)
  • You look at panels (patches) and understand the story
  • Each panel can reference any other panel (global attention)
  • The sequence of panels matters (positional encoding)

Here's how ViT transforms images:

Step 1: Divide and Conquer

Original Image (224×224)  →  196 patches (each 16×16)
[Entire photo]           →  [Panel 1][Panel 2]...[Panel 196]

Step 2: Make Patches "Speakable"

Each 16×16×3 patch (768 numbers) becomes a single vector that the transformer can understand - like turning a complex idea into a word.

Step 3: Add Position Memory

Since attention doesn't inherently understand spatial layout, we add positional embeddings - like numbering comic panels so we know their arrangement.

Step 4: Let Attention Work Its Magic

Now patches can "ask" other patches: "What are you showing that's relevant to what I'm showing?"
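
To make Step 4 concrete, here is a minimal sketch of self-attention over a sequence of patch embeddings using PyTorch's built-in nn.MultiheadAttention. The random tensor simply stands in for the real patch embeddings built later in this post; the 196-patch, 768-dimension, 12-head shape matches the ViT-Base configuration used throughout.

import torch
import torch.nn as nn

# Stand-in for real patch embeddings: (batch, num_patches, embed_dim)
patch_embeddings = torch.randn(1, 196, 768)

attention = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)

# Every patch acts as query, key, and value, so any patch can attend to any other
attended, weights = attention(patch_embeddings, patch_embeddings, patch_embeddings)

print(attended.shape)  # (1, 196, 768) - same sequence, now context-aware
print(weights.shape)   # (1, 196, 196) - how much each patch attends to every other patch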

Understanding ViT Through Examples

Example 1: Recognizing a Dog

Traditional CNN thinking:

  • Early layers: "I see edges and textures"
  • Middle layers: "These edges form ear shapes and fur patterns"
  • Later layers: "These patterns combine to form a dog"

ViT thinking:

  • "The patch with the nose should pay attention to patches with ears"
  • "Patches with fur texture should connect to patches with similar texture"
  • "The background patches can be largely ignored"

Example 2: Understanding Scenes

For a photo of "person riding a bicycle":

  • Person patches attend to bicycle patches to understand the relationship
  • Background patches get less attention when classifying the main action
  • Global context emerges from patch interactions

The Core Innovation: Patch Embeddings

Let's see how ViT converts images into "words":

import torch
import torch.nn as nn

def image_to_patches(image, patch_size=16):
    """
    Convert an image into patches - like cutting a photo into puzzle pieces
    
    Input: image of shape (batch, 3, 224, 224) - batch of RGB images
    Output: sequence of patches (batch, 196, 768) - 196 "words" of 768 dimensions each
    """
    batch_size, channels, height, width = image.shape
    
    # Calculate number of patches
    num_patches_h = height // patch_size  # 224 // 16 = 14
    num_patches_w = width // patch_size   # 224 // 16 = 14
    
    # Slice the image into non-overlapping patches
    patches = image.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    # Shape: (batch, 3, 14, 14, 16, 16)
    
    # Move the patch grid in front of the channel dimension so each patch's
    # pixels stay together, then flatten every patch into a single vector
    patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()
    patches = patches.view(batch_size, num_patches_h * num_patches_w, channels * patch_size * patch_size)
    # Shape: (batch, 196, 768) - 196 patches, each with 768 values
    
    return patches

# Example: Transform a typical ImageNet image
image = torch.randn(1, 3, 224, 224)  # Batch of 1 RGB image
patches = image_to_patches(image)
print(f"Original image: {image.shape}")
print(f"Patches: {patches.shape}")  # (1, 196, 768)

Adding the Special [CLS] Token

Just like BERT uses a special [CLS] token for classification, ViT adds a learnable "class token":

class VisionTransformerEmbedding(nn.Module):
    def __init__(self, image_size=224, patch_size=16, embed_dim=768):
        super().__init__()
        
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        
        # Convert patches to embeddings
        self.patch_embedding = nn.Linear(patch_size * patch_size * 3, embed_dim)
        
        # Special [CLS] token for classification
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        
        # Positional embeddings (where is each patch located?)
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))
    
    def forward(self, image):
        # Step 1: Break image into patches
        patches = image_to_patches(image, self.patch_size)
        batch_size = patches.shape[0]
        
        # Step 2: Convert patches to embeddings
        patch_embeddings = self.patch_embedding(patches)
        
        # Step 3: Add [CLS] token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat([cls_tokens, patch_embeddings], dim=1)
        
        # Step 4: Add positional information
        embeddings = embeddings + self.pos_embedding
        
        return embeddings

# Example usage
vit_embedding = VisionTransformerEmbedding()
image = torch.randn(2, 3, 224, 224)  # Batch of 2 images
embeddings = vit_embedding(image)
print(f"Output shape: {embeddings.shape}")  # (2, 197, 768)
# 197 = 1 [CLS] token + 196 image patches
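
With the embeddings in place, the rest of ViT is a stack of standard transformer encoder blocks plus a classification head that reads the [CLS] token. The sketch below is not the paper's exact implementation (which has its own initialization and training details); it is a minimal illustration that assumes the imports and the VisionTransformerEmbedding module defined above, built from PyTorch's nn.TransformerEncoder.

class SimpleViT(nn.Module):
    def __init__(self, num_classes=1000, embed_dim=768, depth=12, num_heads=12):
        super().__init__()
        self.embedding = VisionTransformerEmbedding(embed_dim=embed_dim)
        
        # Standard transformer encoder blocks; norm_first=True and GELU mirror
        # the pre-norm design used by ViT
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        
        # Classification head applied to the [CLS] token only
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
    
    def forward(self, image):
        tokens = self.embedding(image)        # (batch, 197, 768)
        tokens = self.encoder(tokens)         # (batch, 197, 768)
        cls_output = self.norm(tokens[:, 0])  # (batch, 768) - the [CLS] token
        return self.head(cls_output)          # (batch, num_classes)

model = SimpleViT(num_classes=10)
logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)  # (2, 10)

With embed_dim=768, depth=12, and 12 heads, this mirrors the ViT-Base configuration described in the paper.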

Why Vision Transformers Work So Well

1. Global Understanding from Day One

Unlike CNNs that build understanding layer by layer, ViT can connect any patch to any other patch immediately. A dog's ear can directly "talk to" its tail.

2. Adaptive Attention Patterns

Different attention heads learn different types of relationships (a sketch for inspecting per-head attention follows this list):

  • Spatial heads: Focus on nearby patches (like CNN receptive fields)
  • Semantic heads: Connect patches with similar content regardless of distance
  • Object heads: Link patches belonging to the same object
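
These patterns are learned from data rather than built into the architecture. To inspect them yourself, one minimal sketch (assuming a recent PyTorch version where nn.MultiheadAttention accepts average_attn_weights) is:

import torch
import torch.nn as nn

attention = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)
tokens = torch.randn(1, 197, 768)  # [CLS] + 196 patch embeddings

# average_attn_weights=False returns one attention map per head instead of the mean
_, per_head_weights = attention(
    tokens, tokens, tokens,
    need_weights=True,
    average_attn_weights=False,
)
print(per_head_weights.shape)  # (1, 12, 197, 197) - one 197x197 map per head

Visualising each 197×197 map (for a trained model, not the random weights above) shows which patches a given head links - for example whether it favours nearby patches or patches with similar content.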

3. Scale Beautifully

ViTs improve predictably with more data and larger models, following similar scaling laws to language models.

4. Transfer Learning Champions

Pre-trained ViTs transfer exceptionally well to new tasks - learning general visual representations that work across domains.
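
As a concrete example, the sketch below swaps the classification head of a pre-trained ViT-Base for a new task. It assumes torchvision 0.13 or later for the weights API; libraries such as timm offer equivalent helpers.

import torch.nn as nn
from torchvision.models import vit_b_16, ViT_B_16_Weights

# Load a ViT-Base/16 pre-trained on ImageNet
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)

# Replace the classification head for a new task, e.g. 10 classes
model.heads.head = nn.Linear(model.heads.head.in_features, 10)

# Optionally freeze the backbone and train only the new head
for name, param in model.named_parameters():
    if not name.startswith("heads"):
        param.requires_grad = False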

Comparing ViT with CNNs

| Aspect           | CNNs                  | Vision Transformers           |
|------------------|-----------------------|-------------------------------|
| Processing       | Local → Global        | Global from start             |
| Inductive Bias   | Strong spatial bias   | Minimal assumptions           |
| Data Efficiency  | Good with small data  | Needs large datasets          |
| Interpretability | Hard to interpret     | Attention maps show reasoning |
| Flexibility      | Fixed filter patterns | Adaptive attention            |

When to Use Vision Transformers

ViTs Excel When:

  • You have large datasets (ImageNet-21k scale)
  • You need interpretable attention patterns
  • You want unified architecture across modalities
  • Transfer learning is important

CNNs Still Win When:

  • Working with small datasets
  • Need strong inductive biases
  • Computational efficiency is critical
  • Working with very high-resolution images

The Bigger Picture: Unified Architectures

ViT's true significance isn't just better image classification - it's proof that attention is all you need across modalities:

  • Text: Transformers (GPT, BERT)
  • Images: Vision Transformers
  • Audio: Audio Transformers
  • Video: Video Transformers
  • Multimodal: CLIP, DALL-E

This convergence enables:

  • Shared knowledge across modalities
  • Simpler architectures (one design for everything)
  • Cross-modal understanding (text + images)

Key Takeaways for Learning

Mental Models:

  • Patches = Words: Treat image regions like tokens in a sentence
  • Global Attention: Every patch can "see" every other patch
  • Position Matters: Positional encoding teaches spatial relationships
  • CLS Token: Special token summarizes the entire image

When to Remember ViT:

  • Understanding modern multimodal models (CLIP, DALL-E)
  • Designing attention-based vision systems
  • Thinking about inductive biases in ML
  • Appreciating the power of unified architectures

The Fundamental Insight: CNNs assume local patterns are most important. ViTs let the data decide what's important through attention. Sometimes the dog's tail is more informative than the nearby grass for recognizing "dog."

What's Next?

Understanding ViT opens doors to:

  • Multimodal Models: How CLIP connects text and images
  • Object Detection: DETR and attention-based detection
  • Generative Models: How diffusion models use attention
  • Efficient Transformers: Making ViTs faster and smaller

Vision Transformers proved that good ideas transcend domains. The attention mechanism that revolutionized NLP works just as well for computer vision, opening possibilities we're still exploring today.

References

  • Dosovitskiy, A., et al. (2020). "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale."
  • Vaswani, A., et al. (2017). "Attention is All You Need."
  • Radford, A., et al. (2021). "Learning Transferable Visual Models From Natural Language Supervision."