Vision Transformer (ViT)

Transformer architecture adapted for computer vision tasks, processing images as sequences of patches.

What is Vision Transformer (ViT)?

Vision Transformer (ViT) is a neural network architecture that applies the transformer model, originally designed for natural language processing, to computer vision tasks. Instead of processing images through convolutional layers, ViT divides images into fixed-size patches and processes them as sequences, similar to how transformers handle tokens in text.
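
As a quick worked example (assuming the common 224x224 RGB input with 16x16 patches), the arithmetic below shows how an image becomes a token sequence:

# Worked example: how many tokens does a 224x224 RGB image produce?
img_size, patch_size, channels = 224, 16, 3
n_patches = (img_size // patch_size) ** 2       # 14 * 14 = 196 patch tokens
patch_dim = patch_size * patch_size * channels  # 768 raw values per patch
seq_len = n_patches + 1                         # +1 for the class token
print(n_patches, patch_dim, seq_len)            # 196 768 197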

Key Characteristics

  • Patch-Based Processing: Divides images into non-overlapping patches
  • Transformer Architecture: Uses self-attention mechanisms
  • Positional Embeddings: Encodes spatial information
  • Global Context: Captures long-range dependencies
  • Scalability: Performs well with large datasets
  • Parallel Processing: Efficient computation on modern hardware
  • Transfer Learning: Effective with pre-training on large datasets
  • Flexible Input Size: Can handle variable image sizes (by interpolating positional embeddings)

Architecture Overview

graph TD
    A[Input Image] --> B[Patch Embedding]
    B --> C[Add Positional Embeddings]
    C --> D[Transformer Encoder]
    D --> E[Classification Head]
    D --> F[Feature Representation]

    subgraph Transformer Encoder
        G[Layer Normalization] --> H[Multi-Head Attention]
        H --> I[Residual Connection]
        I --> J[Layer Normalization]
        J --> K[MLP]
        K --> L[Residual Connection]
    end

Core Components

Patch Embedding

# Patch embedding implementation
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super(PatchEmbedding, self).__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

        # Linear projection of patches
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

    def forward(self, x):
        """
        x: (batch_size, in_channels, img_size, img_size)
        Returns: (batch_size, n_patches, embed_dim)
        """
        # Project patches
        x = self.proj(x)  # (batch_size, embed_dim, n_patches**0.5, n_patches**0.5)

        # Flatten patches
        x = x.flatten(2)  # (batch_size, embed_dim, n_patches)

        # Transpose to get (batch_size, n_patches, embed_dim)
        x = x.transpose(1, 2)

        return x
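
A minimal shape check for the module above (a sketch, assuming a batch of two 224x224 RGB images):

# Shape check for PatchEmbedding (sketch)
patch_embed = PatchEmbedding(img_size=224, patch_size=16, in_channels=3, embed_dim=768)
dummy = torch.randn(2, 3, 224, 224)
tokens = patch_embed(dummy)
print(tokens.shape)  # torch.Size([2, 196, 768])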

Positional Embedding

# Positional embedding implementation (also owns the learnable class token)
class PositionalEmbedding(nn.Module):
    def __init__(self, n_patches, embed_dim):
        super(PositionalEmbedding, self).__init__()
        # The class token and positional embeddings are learnable parameters
        # created once here, not re-created on every forward pass
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        """
        x: (batch_size, n_patches, embed_dim)
        Returns: (batch_size, n_patches + 1, embed_dim)
        """
        # Prepend the class token, expanded across the batch
        batch_size = x.size(0)
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, x], dim=1)

        # Add positional embeddings
        x = x + self.pos_embed

        return x
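
A quick check that the class token and positional embeddings produce the expected sequence length (continuing the shapes from the patch-embedding example):

# Shape check for PositionalEmbedding (sketch)
pos_embed = PositionalEmbedding(n_patches=196, embed_dim=768)
tokens_with_pos = pos_embed(tokens)  # tokens from the PatchEmbedding example
print(tokens_with_pos.shape)         # torch.Size([2, 197, 768])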

Transformer Encoder

# Transformer encoder block implementation
class TransformerEncoderBlock(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4, dropout=0.1):
        super(TransformerEncoderBlock, self).__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        # batch_first=True so inputs are (batch, n_tokens, embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)

        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # Self-attention
        x_norm = self.norm1(x)
        attn_output, _ = self.attn(x_norm, x_norm, x_norm)
        x = x + attn_output

        # MLP
        x_norm = self.norm2(x)
        mlp_output = self.mlp(x_norm)
        x = x + mlp_output

        return x
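
The block preserves the (batch, sequence, embedding) shape, which a quick forward pass confirms:

# Shape check for TransformerEncoderBlock (sketch)
block = TransformerEncoderBlock(embed_dim=768, num_heads=12)
out = block(torch.randn(2, 197, 768))
print(out.shape)  # torch.Size([2, 197, 768])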

Complete ViT Architecture

# Complete Vision Transformer implementation
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4, dropout=0.1):
        super(VisionTransformer, self).__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        n_patches = self.patch_embed.n_patches

        self.pos_embed = PositionalEmbedding(n_patches, embed_dim)

        # Transformer encoder
        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)

        # Classification head
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        # Patch embedding
        x = self.patch_embed(x)

        # Add positional embeddings
        x = self.pos_embed(x)

        # Transformer encoder
        for block in self.blocks:
            x = block(x)

        # Classification
        x = self.norm(x)
        class_token = x[:, 0]  # Take class token
        x = self.head(class_token)

        return x
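
Putting the pieces together, a minimal forward pass through a ViT-Base-sized model (randomly initialized, for shape checking only):

# End-to-end shape check (sketch, randomly initialized weights)
model = VisionTransformer(img_size=224, patch_size=16, num_classes=1000)
logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)  # torch.Size([2, 1000])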

ViT Variants

Standard ViT Architectures

| Variant | Layers (depth) | Embedding Dim | Heads | Parameters | Use Case |
|---------|----------------|---------------|-------|------------|----------|
| ViT-Tiny | 12 | 192 | 3 | ~5M | Lightweight applications |
| ViT-Small | 12 | 384 | 6 | ~22M | Mobile/edge devices |
| ViT-Base | 12 | 768 | 12 | ~86M | General purpose |
| ViT-Large | 24 | 1024 | 16 | ~307M | High performance |
| ViT-Huge | 32 | 1280 | 16 | ~632M | Very complex tasks |
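
The parameter counts in the table can be roughly reproduced with the sketch implementation above; exact numbers differ slightly from the reference models depending on implementation details such as the classification head size.

# Rough parameter-count check against the table above (sketch)
def count_params(model):
    return sum(p.numel() for p in model.parameters())

vit_base = VisionTransformer(embed_dim=768, depth=12, num_heads=12)
print(f'{count_params(vit_base) / 1e6:.1f}M parameters')  # roughly 86M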

Modified ViT Architectures

# Data-efficient Image Transformer (DeiT) - simplified sketch
class DeiT(VisionTransformer):
    def __init__(self, *args, embed_dim=768, **kwargs):
        super(DeiT, self).__init__(*args, embed_dim=embed_dim, **kwargs)
        n_patches = self.patch_embed.n_patches
        # DeiT prepends a class token and a distillation token, so it needs
        # its own positional embeddings sized for n_patches + 2 tokens
        # (the base class's pos_embed module is not used here)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.deit_pos_embed = nn.Parameter(torch.zeros(1, n_patches + 2, embed_dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.dist_token, std=0.02)
        nn.init.trunc_normal_(self.deit_pos_embed, std=0.02)

    def forward(self, x):
        # Patch embedding
        x = self.patch_embed(x)

        # Prepend class and distillation tokens
        batch_size = x.size(0)
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        dist_token = self.dist_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, dist_token, x], dim=1)

        # Add positional embeddings
        x = x + self.deit_pos_embed

        # Transformer encoder
        for block in self.blocks:
            x = block(x)

        # Average the class-token and distillation-token predictions
        x = self.norm(x)
        cls_output = self.head(x[:, 0])
        dist_output = self.head(x[:, 1])

        return (cls_output + dist_output) / 2

# Swin Transformer
class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, window_size=7, shift_size=0):
        super(SwinTransformerBlock, self).__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size

        # Layer normalization
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        # Window-based multi-head attention (batch_first so inputs are
        # (num_windows * B, window_size * window_size, C))
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

        # MLP
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim)
        )

    def forward(self, x, H, W):
        # x: (B, H*W, C) token sequence for an H x W feature map
        B, L, C = x.shape
        assert L == H * W, "Input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # Shift windows if needed
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))

        # Partition into non-overlapping windows
        x_windows = window_partition(x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)

        # Window attention (simplified: no relative position bias or shift mask)
        attn_output, _ = self.attn(x_windows, x_windows, x_windows)
        attn_output = attn_output.view(-1, self.window_size, self.window_size, C)
        x = window_reverse(attn_output, self.window_size, H, W)

        # Reverse shift if needed
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))

        # Residual connection around the attention branch
        x = shortcut + x.view(B, H * W, C)

        # MLP with residual connection
        x = x + self.mlp(self.norm2(x))

        return x

def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = windows.view(-1, window_size, window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    x = x.view(B, H, W, -1)
    return x
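
A quick round-trip check of the two helpers (assuming H and W are divisible by the window size):

# Round-trip check: window_partition followed by window_reverse
# should reconstruct the original tensor (sketch)
x = torch.randn(2, 14, 14, 96)                 # (B, H, W, C)
windows = window_partition(x, window_size=7)   # (8, 7, 7, 96)
x_back = window_reverse(windows, 7, 14, 14)
print(torch.allclose(x, x_back))               # True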

ViT vs CNNs

| Feature | Vision Transformer (ViT) | Convolutional Neural Networks (CNNs) |
|---------|--------------------------|--------------------------------------|
| Feature Extraction | Patch-based, global attention | Local convolutional filters |
| Inductive Bias | Minimal (learns from data) | Strong (translation equivariance) |
| Global Context | Built-in through self-attention | Requires deep stacks |
| Parameter Efficiency | More parameters for same performance | Fewer parameters |
| Data Efficiency | Requires large datasets | Works well with smaller datasets |
| Training Stability | Can be unstable | Generally stable |
| Hardware Utilization | Optimized for parallel processing | Optimized for convolution operations |
| Interpretability | Attention maps | Feature maps |
| Scalability | Excellent with large data | Limited by architecture |
| Transfer Learning | Excellent with pre-training | Good with pre-training |
| Computational Cost | High (quadratic attention) | Lower (linear convolutions) |
| Memory Usage | High | Lower |
| Flexibility | High (variable input sizes) | Limited (fixed receptive fields) |
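
To make the "quadratic attention" row concrete: the attention matrix is (tokens x tokens), so halving the patch size quadruples the token count and increases the attention cost roughly 16-fold (a back-of-the-envelope sketch):

# Attention cost grows with the square of the token count (sketch)
for patch_size in (32, 16, 8):
    tokens = (224 // patch_size) ** 2 + 1        # patches + class token
    print(patch_size, tokens, tokens ** 2)       # e.g. 16 -> 197 tokens, ~38.8k attention entries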

Training ViT

Training Configuration

# Training configuration for ViT
def get_training_config():
    return {
        'optimizer': 'AdamW',
        'learning_rate': 3e-4,
        'weight_decay': 0.05,
        'lr_scheduler': {
            'type': 'CosineAnnealingLR',
            'T_max': 300,
            'eta_min': 1e-5
        },
        'batch_size': 1024,
        'epochs': 300,
        'augmentation': {
            'random_resized_crop': True,
            'random_horizontal_flip': True,
            'color_jitter': 0.4,
            'auto_augment': 'rand-m9-mstd0.5-inc1',
            'mixup': 0.8,
            'cutmix': 1.0,
            'label_smoothing': 0.1
        },
        'gradient_clipping': 1.0,
        'warmup_epochs': 5
    }
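
The config lists warmup_epochs, which the training loop below only approximates by delaying scheduler.step(); an explicit warmup-then-cosine schedule can be built with PyTorch's SequentialLR (a sketch, assuming one scheduler step per epoch):

# Linear warmup followed by cosine annealing (sketch)
def build_scheduler(optimizer, config):
    warmup = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.01, total_iters=config['warmup_epochs']
    )
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=config['epochs'] - config['warmup_epochs'],
        eta_min=config['lr_scheduler']['eta_min']
    )
    return torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup, cosine],
        milestones=[config['warmup_epochs']]
    )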

Training Loop

# Training loop for ViT
def train_vit(model, train_loader, val_loader, config, device):
    # Optimizer
    if config['optimizer'] == 'AdamW':
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )
    else:
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=config['learning_rate'],
            momentum=0.9,
            weight_decay=config['weight_decay']
        )

    # Learning rate scheduler
    if config['lr_scheduler']['type'] == 'CosineAnnealingLR':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=config['lr_scheduler']['T_max'],
            eta_min=config['lr_scheduler']['eta_min']
        )
    else:
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=config['lr_scheduler']['step_size'],
            gamma=config['lr_scheduler']['gamma']
        )

    # Loss function
    criterion = nn.CrossEntropyLoss(label_smoothing=config['augmentation']['label_smoothing'])

    # Training loop
    for epoch in range(config['epochs']):
        model.train()
        train_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            # Forward pass
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)

            # Backward pass
            loss.backward()

            # Gradient clipping
            if 'gradient_clipping' in config:
                torch.nn.utils.clip_grad_norm_(model.parameters(), config['gradient_clipping'])

            optimizer.step()

            # Statistics
            train_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

        # Validation
        val_loss, val_acc = validate_vit(model, val_loader, criterion, device)

        # Update learning rate (warmup is approximated here by holding the
        # scheduler fixed until the warmup epochs have passed)
        if epoch >= config.get('warmup_epochs', 0):
            scheduler.step()

        # Print statistics
        print(f'Epoch {epoch+1}/{config["epochs"]}')
        print(f'Train Loss: {train_loss/len(train_loader):.4f} | '
              f'Train Acc: {100.*correct/total:.2f}%')
        print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')
        print('-' * 50)

def validate_vit(model, val_loader, criterion, device):
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)

            val_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

    return val_loss/len(val_loader), 100.*correct/total

ViT Applications

Image Classification

# Image classification with ViT
class ImageClassifier:
    def __init__(self, num_classes=1000, variant='ViT-Base'):
        self.variant = variant
        self.model = self._create_model(num_classes)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

    def _create_model(self, num_classes):
        """Create ViT model based on variant"""
        if self.variant == 'ViT-Tiny':
            return VisionTransformer(
                img_size=224, patch_size=16, in_channels=3,
                num_classes=num_classes, embed_dim=192, depth=12,
                num_heads=3, mlp_ratio=4
            )
        elif self.variant == 'ViT-Small':
            return VisionTransformer(
                img_size=224, patch_size=16, in_channels=3,
                num_classes=num_classes, embed_dim=384, depth=12,
                num_heads=6, mlp_ratio=4
            )
        elif self.variant == 'ViT-Base':
            return VisionTransformer(
                img_size=224, patch_size=16, in_channels=3,
                num_classes=num_classes, embed_dim=768, depth=12,
                num_heads=12, mlp_ratio=4
            )
        elif self.variant == 'ViT-Large':
            return VisionTransformer(
                img_size=224, patch_size=16, in_channels=3,
                num_classes=num_classes, embed_dim=1024, depth=24,
                num_heads=16, mlp_ratio=4
            )
        else:
            raise ValueError(f'Unknown ViT variant: {self.variant}')

    def train(self, train_loader, val_loader, epochs=300):
        """Train the ViT model"""
        config = get_training_config()
        config['epochs'] = epochs
        train_vit(self.model, train_loader, val_loader, config, self.device)

    def predict(self, image):
        """Predict class for an image"""
        self.model.eval()
        with torch.no_grad():
            image = image.unsqueeze(0).to(self.device)
            output = self.model(image)
            return output.argmax(dim=1).item()

    def save(self, path):
        """Save model weights"""
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        """Load model weights"""
        self.model.load_state_dict(torch.load(path, map_location=self.device))
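
The predict method expects a preprocessed tensor; a typical pipeline (a sketch, assuming torchvision, ImageNet normalization statistics, and a hypothetical file 'example.jpg') looks like this:

# Example preprocessing for ImageClassifier.predict (sketch)
from torchvision import transforms
from PIL import Image

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

classifier = ImageClassifier(num_classes=1000, variant='ViT-Base')
image = preprocess(Image.open('example.jpg').convert('RGB'))  # hypothetical input file
prediction = classifier.predict(image)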

Object Detection

# Object detection with ViT backbone (conceptual)
class ViTBackbone(nn.Module):
    def __init__(self, variant='ViT-Base'):
        super(ViTBackbone, self).__init__()
        # Create base ViT
        if variant == 'ViT-Tiny':
            self.vit = VisionTransformer(
                img_size=224, patch_size=16, in_channels=3,
                num_classes=1000, embed_dim=192, depth=12,
                num_heads=3, mlp_ratio=4
            )
        elif variant == 'ViT-Small':
            self.vit = VisionTransformer(
                img_size=224, patch_size=16, in_channels=3,
                num_classes=1000, embed_dim=384, depth=12,
                num_heads=6, mlp_ratio=4
            )
        else:  # ViT-Base
            self.vit = VisionTransformer(
                img_size=224, patch_size=16, in_channels=3,
                num_classes=1000, embed_dim=768, depth=12,
                num_heads=12, mlp_ratio=4
            )

        # Remove classification head
        self.vit.head = nn.Identity()

    def forward(self, x):
        # Get patch embeddings
        x = self.vit.patch_embed(x)

        # Add positional embeddings
        x = self.vit.pos_embed(x)

        # Transformer encoder
        for block in self.vit.blocks:
            x = block(x)

        # Get features at different stages
        # Return class token and patch tokens
        class_token = x[:, 0:1]
        patch_tokens = x[:, 1:]

        return class_token, patch_tokens

class ViTDetector(nn.Module):
    def __init__(self, num_classes, variant='ViT-Base'):
        super(ViTDetector, self).__init__()
        # Backbone
        self.backbone = ViTBackbone(variant)

        # Feature pyramid (FeaturePyramidNetwork is a placeholder module,
        # not defined in this conceptual sketch)
        self.fpn = FeaturePyramidNetwork()

        # Detection head (DetectionHead is likewise a placeholder)
        self.detection_head = DetectionHead(num_classes)

    def forward(self, x):
        # Extract features
        class_token, patch_tokens = self.backbone(x)

        # Build feature pyramid
        features = self.fpn(patch_tokens)

        # Detection
        class_scores, bbox_preds = self.detection_head(features)

        return class_scores, bbox_preds

Medical Imaging

# Medical imaging with ViT
class MedicalViT(nn.Module):
    def __init__(self, in_channels=1, num_classes=2, variant='ViT-Base'):
        super(MedicalViT, self).__init__()
        # Select dimensions for the requested variant
        if variant == 'ViT-Tiny':
            embed_dim, depth, num_heads = 192, 12, 3
        elif variant == 'ViT-Small':
            embed_dim, depth, num_heads = 384, 12, 6
        else:  # ViT-Base
            embed_dim, depth, num_heads = 768, 12, 12

        # Create base ViT with modified input channels
        self.vit = VisionTransformer(
            img_size=224, patch_size=16, in_channels=in_channels,
            num_classes=num_classes, embed_dim=embed_dim, depth=depth,
            num_heads=num_heads, mlp_ratio=4
        )

        # Add segmentation head (input size matches the chosen embed_dim)
        self.segmentation_head = nn.Sequential(
            nn.Linear(embed_dim, 256),
            nn.GELU(),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Classification path
        class_output = self.vit(x)

        # Segmentation path
        # Get patch embeddings
        x_patch = self.vit.patch_embed(x)

        # Add positional embeddings (without class token)
        pos_embed = self.vit.pos_embed.pos_embed[:, 1:, :]
        x_patch = x_patch + pos_embed

        # Transformer encoder
        for block in self.vit.blocks:
            x_patch = block(x_patch)

        # Get patch features
        x_patch = self.vit.norm(x_patch)

        # Reshape to image dimensions
        B, N, C = x_patch.shape
        H = W = int(N ** 0.5)
        x_patch = x_patch.permute(0, 2, 1).view(B, C, H, W)

        # Upsample to original image size
        seg_output = F.interpolate(
            x_patch, size=(224, 224), mode='bilinear', align_corners=False
        )

        # Apply segmentation head
        seg_output = self.segmentation_head(seg_output.permute(0, 2, 3, 1))
        seg_output = seg_output.permute(0, 3, 1, 2)

        return class_output, seg_output

ViT Research

Key Papers

  1. "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" (Dosovitskiy et al., 2020)
    • Introduced Vision Transformer (ViT)
    • Demonstrated effectiveness of transformers for vision
    • Foundation for ViT research
  2. "Training data-efficient image transformers & distillation through attention" (Touvron et al., 2021)
    • Introduced DeiT (Data-efficient Image Transformers)
    • Demonstrated efficient training with less data
    • Foundation for practical ViT applications
  3. "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" (Liu et al., 2021)
    • Introduced Swin Transformer
    • Demonstrated hierarchical feature learning
    • Foundation for efficient ViT architectures
  4. "Masked Autoencoders Are Scalable Vision Learners" (He et al., 2021)
    • Introduced MAE (Masked Autoencoders)
    • Demonstrated self-supervised learning for ViT
    • Foundation for self-supervised ViT
  5. "CoCa: Contrastive Captioners are Image-Text Foundation Models" (Yu et al., 2022)
    • Introduced CoCa (Contrastive Captioners)
    • Demonstrated multimodal learning with ViT
    • Foundation for multimodal ViT

Emerging Research Directions

  • Efficient ViTs: More compute-efficient architectures
  • Neural Architecture Search: Automated ViT design
  • Self-Supervised ViTs: Learning without labeled data
  • Multimodal ViTs: Combining vision and language
  • Explainable ViTs: More interpretable attention
  • Few-Shot ViTs: Learning from few examples
  • Adversarial ViTs: Robust ViT architectures
  • Theoretical Foundations: Better understanding of ViT
  • Hardware Acceleration: Specialized hardware for ViT
  • Real-Time ViTs: Faster inference for edge devices
  • 3D ViTs: ViT for volumetric data
  • Video ViTs: Temporal ViT for videos
  • Foundation ViT Models: Large pre-trained ViT models

Best Practices

Implementation Guidelines

| Aspect | Recommendation | Notes |
|--------|----------------|-------|
| Patch Size | 16x16 for most cases | Balance between detail and computation |
| Model Size | Start with ViT-Base | Good balance of performance and cost |
| Pre-training | Use large datasets (e.g., ImageNet-21k) | Critical for good performance |
| Fine-tuning | Use lower learning rates | Prevents catastrophic forgetting |
| Augmentation | Heavy augmentation | Improves generalization |
| Optimizer | AdamW with weight decay | Works best for ViT |
| Learning Rate | 3e-4 to 1e-3 | Use learning rate scheduling |
| Batch Size | 1024 or larger | Larger batches for stability |
| Gradient Clipping | Use with value 1.0 | Prevents exploding gradients |
| Warmup | 5-10 epochs | Stabilizes early training |
| Label Smoothing | 0.1 | Improves calibration |
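
For the fine-tuning row, a common recipe is to replace the classification head and give the pre-trained backbone a lower learning rate than the new head; a sketch using parameter groups (the learning rates shown are illustrative assumptions):

# Fine-tuning sketch: new head, lower learning rate for the backbone
model = VisionTransformer(num_classes=1000)
# (load pre-trained weights here), then swap in a head for the new task
model.head = nn.Linear(768, 10)  # e.g. 10 target classes, 768 = ViT-Base embed_dim

optimizer = torch.optim.AdamW([
    {'params': [p for n, p in model.named_parameters() if not n.startswith('head')],
     'lr': 1e-5},                                     # backbone: small learning rate
    {'params': model.head.parameters(), 'lr': 1e-3},  # new head: larger learning rate
], weight_decay=0.05)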

Common Pitfalls and Solutions

| Pitfall | Solution | Example |
|---------|----------|---------|
| Data Hunger | Use pre-training or DeiT | Pre-train on ImageNet-21k |
| Training Instability | Use gradient clipping, warmup | Clip gradients at 1.0 |
| Overfitting | Use strong augmentation, dropout | RandAugment, Mixup, CutMix |
| Slow Convergence | Use learning rate scheduling | Cosine annealing with warmup |
| Memory Issues | Use gradient checkpointing | Enable gradient checkpointing |
| Class Imbalance | Use weighted loss, oversampling | Weight classes by inverse frequency |
| Small Objects | Use smaller patch sizes | Try 8x8 or 12x12 patches |
| High Resolution | Use hierarchical ViT (e.g., Swin) | Swin Transformer for high-res images |
| Numerical Instability | Use layer normalization | Built into ViT architecture |
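
For the "Memory Issues" row, gradient checkpointing trades compute for memory by recomputing activations during the backward pass; a minimal sketch for the VisionTransformer above (assuming a recent PyTorch with torch.utils.checkpoint):

# Gradient checkpointing sketch for the encoder blocks
from torch.utils.checkpoint import checkpoint

def forward_with_checkpointing(model, x):
    x = model.patch_embed(x)
    x = model.pos_embed(x)
    for block in model.blocks:
        # Recompute each block's activations during backward to save memory
        x = checkpoint(block, x, use_reentrant=False)
    x = model.norm(x)
    return model.head(x[:, 0])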

Future Directions

  • Foundation ViT Models: Large pre-trained ViT models for transfer learning
  • Automated ViT Design: Neural architecture search for optimal ViT configurations
  • Self-Supervised ViTs: Learning from unlabeled data at scale
  • Multimodal ViTs: Combining vision, language, and other modalities
  • Explainable ViTs: More interpretable attention mechanisms
  • Real-Time ViTs: Optimized architectures for edge devices
  • 3D ViTs: Better architectures for volumetric data
  • Video ViTs: Temporal ViT for dynamic scenes
  • Few-Shot ViTs: Learning from very few labeled examples
  • Adversarial ViTs: Robust ViT against adversarial attacks
  • Neuromorphic ViTs: Brain-inspired ViT architectures
  • Quantum ViTs: ViT architectures for quantum computing
  • Green ViTs: Energy-efficient ViT models

External Resources