Vision Transformer (ViT)
Transformer architecture adapted for computer vision tasks, processing images as sequences of patches.
What is Vision Transformer (ViT)?
Vision Transformer (ViT) is a neural network architecture that applies the transformer model, originally designed for natural language processing, to computer vision tasks. Instead of processing images through convolutional layers, ViT divides images into fixed-size patches and processes them as sequences, similar to how transformers handle tokens in text.
Key Characteristics
- Patch-Based Processing: Divides images into non-overlapping patches (see the sizing sketch after this list)
- Transformer Architecture: Uses self-attention mechanisms
- Positional Embeddings: Encodes spatial information
- Global Context: Captures long-range dependencies
- Scalability: Performs well with large datasets
- Parallel Processing: Efficient computation on modern hardware
- Transfer Learning: Effective with pre-training on large datasets
- Flexible Input Size: Can handle variable image sizes by interpolating the positional embeddings
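To make the patch-based view concrete, the arithmetic below is a minimal sketch using the standard ViT-Base settings (224x224 images, 16x16 patches) of how an image becomes a token sequence.
# Sizing arithmetic for the standard 224x224 / 16x16 configuration (illustrative)
img_size, patch_size, in_channels, embed_dim = 224, 16, 3, 768

n_patches = (img_size // patch_size) ** 2              # 14 x 14 = 196 patches
patch_values = in_channels * patch_size * patch_size   # 3 x 16 x 16 = 768 raw values per patch
seq_len = n_patches + 1                                # plus one learnable class token

print(n_patches, patch_values, seq_len)                # 196 768 197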
Architecture Overview
graph TD
A[Input Image] --> B[Patch Embedding]
B --> C[Add Positional Embeddings]
C --> D[Transformer Encoder]
D --> E[Classification Head]
D --> F[Feature Representation]
subgraph Transformer Encoder
G[Layer Normalization] --> H[Multi-Head Attention]
H --> I[Residual Connection]
I --> J[Layer Normalization]
J --> K[MLP]
K --> L[Residual Connection]
end
Core Components
Patch Embedding
# Patch embedding implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
class PatchEmbedding(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
super(PatchEmbedding, self).__init__()
self.img_size = img_size
self.patch_size = patch_size
self.n_patches = (img_size // patch_size) ** 2
# Linear projection of patches
self.proj = nn.Conv2d(
in_channels, embed_dim,
kernel_size=patch_size,
stride=patch_size
)
def forward(self, x):
"""
x: (batch_size, in_channels, img_size, img_size)
Returns: (batch_size, n_patches, embed_dim)
"""
# Project patches
x = self.proj(x) # (batch_size, embed_dim, n_patches**0.5, n_patches**0.5)
# Flatten patches
x = x.flatten(2) # (batch_size, embed_dim, n_patches)
# Transpose to get (batch_size, n_patches, embed_dim)
x = x.transpose(1, 2)
return x
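A quick, illustrative shape check of the module above (assuming its default arguments):
# Example usage: PatchEmbedding turns a batch of images into a sequence of patch tokens
patch_embed = PatchEmbedding(img_size=224, patch_size=16, in_channels=3, embed_dim=768)
images = torch.randn(2, 3, 224, 224)   # dummy batch of 2 RGB images
tokens = patch_embed(images)
print(tokens.shape)                    # torch.Size([2, 196, 768])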
Positional Embedding
# Positional embedding implementation (also owns the learnable class token)
class PositionalEmbedding(nn.Module):
    def __init__(self, n_patches, embed_dim):
        super(PositionalEmbedding, self).__init__()
        # Class token and positional embeddings are registered once and learned with the model
        self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        nn.init.trunc_normal_(self.class_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
    def forward(self, x):
        """
        x: (batch_size, n_patches, embed_dim)
        Returns: (batch_size, n_patches + 1, embed_dim)
        """
        # Prepend the class token, expanded across the batch (no new parameters created per call)
        batch_size = x.size(0)
        class_token = self.class_token.expand(batch_size, -1, -1)
        x = torch.cat([class_token, x], dim=1)
        # Add positional embeddings
        x = x + self.pos_embed
        return x
Transformer Encoder
# Transformer encoder block implementation
class TransformerEncoderBlock(nn.Module):
def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4, dropout=0.1):
super(TransformerEncoderBlock, self).__init__()
self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)  # inputs are (batch, tokens, dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.mlp = nn.Sequential(
nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
nn.Dropout(dropout)
)
def forward(self, x):
# Self-attention
x_norm = self.norm1(x)
attn_output, _ = self.attn(x_norm, x_norm, x_norm)
x = x + attn_output
# MLP
x_norm = self.norm2(x)
mlp_output = self.mlp(x_norm)
x = x + mlp_output
return x
Complete ViT Architecture
# Complete Vision Transformer implementation
class VisionTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_channels=3,
num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4, dropout=0.1):
super(VisionTransformer, self).__init__()
self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
n_patches = self.patch_embed.n_patches
self.pos_embed = PositionalEmbedding(n_patches, embed_dim)
# Transformer encoder
self.blocks = nn.ModuleList([
TransformerEncoderBlock(embed_dim, num_heads, mlp_ratio, dropout)
for _ in range(depth)
])
self.norm = nn.LayerNorm(embed_dim)
# Classification head
self.head = nn.Linear(embed_dim, num_classes)
def forward(self, x):
# Patch embedding
x = self.patch_embed(x)
# Add positional embeddings
x = self.pos_embed(x)
# Transformer encoder
for block in self.blocks:
x = block(x)
# Classification
x = self.norm(x)
class_token = x[:, 0] # Take class token
x = self.head(class_token)
return x
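As a quick smoke test of the model defined above (illustrative only, using ViT-Base-style defaults):
# Example usage: forward a dummy batch through the VisionTransformer above
vit = VisionTransformer(img_size=224, patch_size=16, in_channels=3,
                        num_classes=1000, embed_dim=768, depth=12,
                        num_heads=12, mlp_ratio=4, dropout=0.1)
images = torch.randn(2, 3, 224, 224)
logits = vit(images)
print(logits.shape)   # torch.Size([2, 1000])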
ViT Variants
Standard ViT Architectures
| Variant | Layers (depth) | Embedding Dim | Heads | Parameters | Use Case |
|---|---|---|---|---|---|
| ViT-Tiny | 12 | 192 | 3 | ~5M | Lightweight applications |
| ViT-Small | 12 | 384 | 6 | ~22M | Mobile/edge devices |
| ViT-Base | 12 | 768 | 12 | ~86M | General purpose |
| ViT-Large | 24 | 1024 | 16 | ~307M | High performance |
| ViT-Huge | 32 | 1280 | 16 | ~632M | Very complex tasks |
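The rows of this table map directly onto the VisionTransformer constructor; a small factory such as the sketch below (names are illustrative, not from any library) avoids repeating the hyperparameters:
# Illustrative factory for the standard variants listed above
VIT_CONFIGS = {
    'ViT-Tiny':  dict(embed_dim=192,  depth=12, num_heads=3),
    'ViT-Small': dict(embed_dim=384,  depth=12, num_heads=6),
    'ViT-Base':  dict(embed_dim=768,  depth=12, num_heads=12),
    'ViT-Large': dict(embed_dim=1024, depth=24, num_heads=16),
    'ViT-Huge':  dict(embed_dim=1280, depth=32, num_heads=16),
}

def build_vit(variant='ViT-Base', num_classes=1000, img_size=224, patch_size=16):
    cfg = VIT_CONFIGS[variant]
    return VisionTransformer(img_size=img_size, patch_size=patch_size,
                             in_channels=3, num_classes=num_classes, **cfg)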
Modified ViT Architectures
# Data-efficient Image Transformer (DeiT): ViT with an extra distillation token
class DeiT(VisionTransformer):
    def __init__(self, *args, **kwargs):
        super(DeiT, self).__init__(*args, **kwargs)
        # Infer the embedding dimension from the parent model
        embed_dim = self.head.in_features
        n_patches = self.patch_embed.n_patches
        # Add a distillation token and positional embeddings sized for n_patches + 2 tokens
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.deit_pos_embed = nn.Parameter(torch.zeros(1, n_patches + 2, embed_dim))
        nn.init.trunc_normal_(self.dist_token, std=0.02)
        nn.init.trunc_normal_(self.deit_pos_embed, std=0.02)
    def forward(self, x):
        # Patch embedding
        x = self.patch_embed(x)
        # Prepend class and distillation tokens (class token reused from the parent's embedding module)
        batch_size = x.size(0)
        class_token = self.pos_embed.class_token.expand(batch_size, -1, -1)
        dist_token = self.dist_token.expand(batch_size, -1, -1)
        x = torch.cat([class_token, dist_token, x], dim=1)
        # Add positional embeddings sized for the two extra tokens
        x = x + self.deit_pos_embed
        # Transformer encoder
        for block in self.blocks:
            x = block(x)
        # Classification: average the predictions from the class and distillation tokens
        x = self.norm(x)
        class_token = x[:, 0]
        dist_token = x[:, 1]
        x = (self.head(class_token) + self.head(dist_token)) / 2
        return x
# Swin Transformer
class SwinTransformerBlock(nn.Module):
def __init__(self, dim, num_heads, window_size=7, shift_size=0):
super(SwinTransformerBlock, self).__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
# Layer normalization
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
# Window-based multi-head attention
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
# MLP
self.mlp = nn.Sequential(
nn.Linear(dim, 4 * dim),
nn.GELU(),
nn.Linear(4 * dim, dim)
)
    def forward(self, x, H, W):
        B, L, C = x.shape
        assert L == H * W, "Input feature has wrong size"
        # Keep the input for the residual connection around the attention sub-layer
        shortcut = x
        # Pre-norm, then reshape to the spatial layout for window partitioning
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        # Shift windows if needed
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        # Partition windows
        x_windows = window_partition(x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        # Window attention: each window attends only within itself (the attention mask the
        # real Swin Transformer applies for shifted windows is omitted in this simplified block)
        attn_output, _ = self.attn(x_windows, x_windows, x_windows)
        attn_output = attn_output.view(-1, self.window_size, self.window_size, C)
        x = window_reverse(attn_output, self.window_size, H, W)
        # Reverse shift if needed
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        # Residual connection around attention, then MLP with its own residual connection
        x = shortcut + x.view(B, H * W, C)
        x = x + self.mlp(self.norm2(x))
        return x
def window_partition(x, window_size):
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
windows = windows.view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
x = x.view(B, H, W, -1)
return x
ViT vs CNNs
| Feature | Vision Transformer (ViT) | Convolutional Neural Networks (CNNs) |
|---|---|---|
| Feature Extraction | Patch-based, global attention | Local convolutional filters |
| Inductive Bias | Minimal (learns from data) | Strong (translation equivariance) |
| Global Context | Built-in through self-attention | Requires deep stacks |
| Parameter Efficiency | More parameters for same performance | Fewer parameters |
| Data Efficiency | Requires large datasets | Works well with smaller datasets |
| Training Stability | Can be unstable | Generally stable |
| Hardware Utilization | Optimized for parallel processing | Optimized for convolution operations |
| Interpretability | Attention maps | Feature maps |
| Scalability | Excellent with large data | Limited by architecture |
| Transfer Learning | Excellent with pre-training | Good with pre-training |
| Computational Cost | High (attention is quadratic in the token count; see the sketch after this table) | Lower (convolution cost grows linearly with pixels) |
| Memory Usage | High | Lower |
| Flexibility | High (variable input sizes) | Limited (fixed receptive fields) |
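The quadratic-attention row can be made concrete: the number of pairwise attention entries grows with the square of the token count, so halving the patch size (4x more tokens) raises attention cost roughly 16x. A back-of-the-envelope sketch:
# Back-of-the-envelope attention cost vs. patch size (illustrative arithmetic only)
def attention_pairs(img_size, patch_size):
    n_tokens = (img_size // patch_size) ** 2 + 1   # patches plus class token
    return n_tokens, n_tokens ** 2                 # pairwise attention entries per head, per layer

for p in (32, 16, 8):
    n, pairs = attention_pairs(224, p)
    print(f"patch {p}x{p}: {n} tokens, {pairs:,} attention entries")
# patch 32x32: 50 tokens | patch 16x16: 197 tokens | patch 8x8: 785 tokens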
Training ViT
Training Configuration
# Training configuration for ViT
def get_training_config():
return {
'optimizer': 'AdamW',
'learning_rate': 3e-4,
'weight_decay': 0.05,
'lr_scheduler': {
'type': 'CosineAnnealingLR',
'T_max': 300,
'eta_min': 1e-5
},
'batch_size': 1024,
'epochs': 300,
'augmentation': {
'random_resized_crop': True,
'random_horizontal_flip': True,
'color_jitter': 0.4,
'auto_augment': 'rand-m9-mstd0.5-inc1',
'mixup': 0.8,
'cutmix': 1.0,
'label_smoothing': 0.1
},
'gradient_clipping': 1.0,
'warmup_epochs': 5
}
Training Loop
# Training loop for ViT
def train_vit(model, train_loader, val_loader, config, device):
# Optimizer
if config['optimizer'] == 'AdamW':
optimizer = torch.optim.AdamW(
model.parameters(),
lr=config['learning_rate'],
weight_decay=config['weight_decay']
)
else:
optimizer = torch.optim.SGD(
model.parameters(),
lr=config['learning_rate'],
momentum=0.9,
weight_decay=config['weight_decay']
)
# Learning rate scheduler
if config['lr_scheduler']['type'] == 'CosineAnnealingLR':
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer,
T_max=config['lr_scheduler']['T_max'],
eta_min=config['lr_scheduler']['eta_min']
)
else:
scheduler = torch.optim.lr_scheduler.StepLR(
optimizer,
step_size=config['lr_scheduler']['step_size'],
gamma=config['lr_scheduler']['gamma']
)
# Loss function
criterion = nn.CrossEntropyLoss(label_smoothing=config['augmentation']['label_smoothing'])
# Training loop
for epoch in range(config['epochs']):
model.train()
train_loss = 0.0
correct = 0
total = 0
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
# Forward pass
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
# Backward pass
loss.backward()
# Gradient clipping
if 'gradient_clipping' in config:
torch.nn.utils.clip_grad_norm_(model.parameters(), config['gradient_clipping'])
optimizer.step()
# Statistics
train_loss += loss.item()
_, predicted = output.max(1)
total += target.size(0)
correct += predicted.eq(target).sum().item()
# Validation
val_loss, val_acc = validate_vit(model, val_loader, criterion, device)
            # Update learning rate (scheduler steps are skipped during the warmup epochs;
            # a full warmup would also ramp the learning rate up before decaying it)
            if epoch >= config.get('warmup_epochs', 0):
                scheduler.step()
# Print statistics
print(f'Epoch {epoch+1}/{config["epochs"]}')
print(f'Train Loss: {train_loss/len(train_loader):.4f} | '
f'Train Acc: {100.*correct/total:.2f}%')
print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')
print('-' * 50)
def validate_vit(model, val_loader, criterion, device):
model.eval()
val_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
for data, target in val_loader:
data, target = data.to(device), target.to(device)
output = model(data)
loss = criterion(output, target)
val_loss += loss.item()
_, predicted = output.max(1)
total += target.size(0)
correct += predicted.eq(target).sum().item()
return val_loss/len(val_loader), 100.*correct/total
ViT Applications
Image Classification
# Image classification with ViT
class ImageClassifier:
def __init__(self, num_classes=1000, variant='ViT-Base'):
self.variant = variant
self.model = self._create_model(num_classes)
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model.to(self.device)
def _create_model(self, num_classes):
"""Create ViT model based on variant"""
if self.variant == 'ViT-Tiny':
return VisionTransformer(
img_size=224, patch_size=16, in_channels=3,
num_classes=num_classes, embed_dim=192, depth=12,
num_heads=3, mlp_ratio=4
)
elif self.variant == 'ViT-Small':
return VisionTransformer(
img_size=224, patch_size=16, in_channels=3,
num_classes=num_classes, embed_dim=384, depth=12,
num_heads=6, mlp_ratio=4
)
elif self.variant == 'ViT-Base':
return VisionTransformer(
img_size=224, patch_size=16, in_channels=3,
num_classes=num_classes, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4
)
elif self.variant == 'ViT-Large':
return VisionTransformer(
img_size=224, patch_size=16, in_channels=3,
num_classes=num_classes, embed_dim=1024, depth=24,
num_heads=16, mlp_ratio=4
)
else:
raise ValueError(f'Unknown ViT variant: {self.variant}')
def train(self, train_loader, val_loader, epochs=300):
"""Train the ViT model"""
config = get_training_config()
config['epochs'] = epochs
train_vit(self.model, train_loader, val_loader, config, self.device)
def predict(self, image):
"""Predict class for an image"""
self.model.eval()
with torch.no_grad():
image = image.unsqueeze(0).to(self.device)
output = self.model(image)
return output.argmax(dim=1).item()
def save(self, path):
"""Save model weights"""
torch.save(self.model.state_dict(), path)
def load(self, path):
"""Load model weights"""
self.model.load_state_dict(torch.load(path, map_location=self.device))
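In practice, `predict` expects an already-preprocessed tensor; a typical torchvision pipeline with ImageNet normalization is sketched below (the file name is only a placeholder).
# Example inference usage with a standard torchvision preprocessing pipeline (assumed, not part of the class)
from PIL import Image
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet statistics
])

classifier = ImageClassifier(num_classes=1000, variant='ViT-Base')
image = preprocess(Image.open('example.jpg').convert('RGB'))   # placeholder file name
print(classifier.predict(image))   # predicted class index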
Object Detection
# Object detection with ViT backbone (conceptual)
class ViTBackbone(nn.Module):
def __init__(self, variant='ViT-Base'):
super(ViTBackbone, self).__init__()
# Create base ViT
if variant == 'ViT-Tiny':
self.vit = VisionTransformer(
img_size=224, patch_size=16, in_channels=3,
num_classes=1000, embed_dim=192, depth=12,
num_heads=3, mlp_ratio=4
)
elif variant == 'ViT-Small':
self.vit = VisionTransformer(
img_size=224, patch_size=16, in_channels=3,
num_classes=1000, embed_dim=384, depth=12,
num_heads=6, mlp_ratio=4
)
else: # ViT-Base
self.vit = VisionTransformer(
img_size=224, patch_size=16, in_channels=3,
num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4
)
# Remove classification head
self.vit.head = nn.Identity()
def forward(self, x):
# Get patch embeddings
x = self.vit.patch_embed(x)
# Add positional embeddings
x = self.vit.pos_embed(x)
# Transformer encoder
for block in self.vit.blocks:
x = block(x)
        # Split the output sequence into the class token and the patch tokens
class_token = x[:, 0:1]
patch_tokens = x[:, 1:]
return class_token, patch_tokens
class ViTDetector(nn.Module):
def __init__(self, num_classes, variant='ViT-Base'):
super(ViTDetector, self).__init__()
# Backbone
self.backbone = ViTBackbone(variant)
        # Feature pyramid (FeaturePyramidNetwork is a placeholder for a real FPN module)
        self.fpn = FeaturePyramidNetwork()
        # Detection head (DetectionHead is likewise a placeholder; neither is defined here)
        self.detection_head = DetectionHead(num_classes)
def forward(self, x):
# Extract features
class_token, patch_tokens = self.backbone(x)
# Build feature pyramid
features = self.fpn(patch_tokens)
# Detection
class_scores, bbox_preds = self.detection_head(features)
return class_scores, bbox_preds
Medical Imaging
# Medical imaging with ViT
class MedicalViT(nn.Module):
def __init__(self, in_channels=1, num_classes=2, variant='ViT-Base'):
super(MedicalViT, self).__init__()
# Create base ViT with modified input channels
if variant == 'ViT-Tiny':
self.vit = VisionTransformer(
img_size=224, patch_size=16, in_channels=in_channels,
num_classes=num_classes, embed_dim=192, depth=12,
num_heads=3, mlp_ratio=4
)
elif variant == 'ViT-Small':
self.vit = VisionTransformer(
img_size=224, patch_size=16, in_channels=in_channels,
num_classes=num_classes, embed_dim=384, depth=12,
num_heads=6, mlp_ratio=4
)
else: # ViT-Base
self.vit = VisionTransformer(
img_size=224, patch_size=16, in_channels=in_channels,
num_classes=num_classes, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4
)
        # Add segmentation head (input width must match the chosen variant's embedding dimension)
        embed_dim = {'ViT-Tiny': 192, 'ViT-Small': 384}.get(variant, 768)
        self.segmentation_head = nn.Sequential(
            nn.Linear(embed_dim, 256),
            nn.GELU(),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
def forward(self, x):
# Classification path
class_output = self.vit(x)
# Segmentation path
# Get patch embeddings
x_patch = self.vit.patch_embed(x)
# Add positional embeddings (without class token)
pos_embed = self.vit.pos_embed.pos_embed[:, 1:, :]
x_patch = x_patch + pos_embed
# Transformer encoder
for block in self.vit.blocks:
x_patch = block(x_patch)
# Get patch features
x_patch = self.vit.norm(x_patch)
# Reshape to image dimensions
B, N, C = x_patch.shape
H = W = int(N ** 0.5)
x_patch = x_patch.permute(0, 2, 1).view(B, C, H, W)
# Upsample to original image size
seg_output = F.interpolate(
x_patch, size=(224, 224), mode='bilinear', align_corners=False
)
# Apply segmentation head
seg_output = self.segmentation_head(seg_output.permute(0, 2, 3, 1))
seg_output = seg_output.permute(0, 3, 1, 2)
return class_output, seg_output
ViT Research
Key Papers
- "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" (Dosovitskiy et al., 2020)
- Introduced Vision Transformer (ViT)
- Demonstrated effectiveness of transformers for vision
- Foundation for ViT research
- "Training data-efficient image transformers & distillation through attention" (Touvron et al., 2021)
- Introduced DeiT (Data-efficient Image Transformers)
- Demonstrated efficient training with less data
- Foundation for practical ViT applications
- "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" (Liu et al., 2021)
- Introduced Swin Transformer
- Demonstrated hierarchical feature learning
- Foundation for efficient ViT architectures
- "Masked Autoencoders Are Scalable Vision Learners" (He et al., 2021)
- Introduced MAE (Masked Autoencoders)
- Demonstrated self-supervised learning for ViT
- Foundation for self-supervised ViT
- "CoCa: Contrastive Captioners are Image-Text Foundation Models" (Yu et al., 2022)
- Introduced CoCa (Contrastive Captioners)
- Demonstrated multimodal learning with ViT
- Foundation for multimodal ViT
Emerging Research Directions
- Efficient ViTs: More compute-efficient architectures
- Neural Architecture Search: Automated ViT design
- Self-Supervised ViTs: Learning without labeled data
- Multimodal ViTs: Combining vision and language
- Explainable ViTs: More interpretable attention
- Few-Shot ViTs: Learning from few examples
- Adversarial ViTs: Robust ViT architectures
- Theoretical Foundations: Better understanding of ViT
- Hardware Acceleration: Specialized hardware for ViT
- Real-Time ViTs: Faster inference for edge devices
- 3D ViTs: ViT for volumetric data
- Video ViTs: Temporal ViT for videos
- Foundation ViT Models: Large pre-trained ViT models
Best Practices
Implementation Guidelines
| Aspect | Recommendation | Notes |
|---|---|---|
| Patch Size | 16x16 for most cases | Balance between detail and computation |
| Model Size | Start with ViT-Base | Good balance of performance and cost |
| Pre-training | Use large datasets (e.g., ImageNet-21k) | Critical for good performance |
| Fine-tuning | Use lower learning rates | Prevents catastrophic forgetting |
| Augmentation | Heavy augmentation | Improves generalization |
| Optimizer | AdamW with weight decay | Works best for ViT |
| Learning Rate | 3e-4 to 1e-3 | Use learning rate scheduling |
| Batch Size | 1024 or larger | Larger batches for stability |
| Gradient Clipping | Use with value 1.0 | Prevents exploding gradients |
| Warmup | 5-10 epochs | Stabilizes early training (see the scheduler sketch after this table) |
| Label Smoothing | 0.1 | Improves calibration |
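The warmup and learning-rate-scheduling recommendations can be combined in one scheduler. The sketch below uses PyTorch's LambdaLR with a linear warmup followed by cosine decay, one common recipe among several; unlike the training loop above, which simply skips scheduler steps during warmup, it also ramps the learning rate up before decaying it.
# Linear warmup followed by cosine decay, expressed as a LambdaLR multiplier (illustrative sketch)
import math

def warmup_cosine_lambda(warmup_epochs=5, total_epochs=300, min_lr_ratio=0.01):
    def fn(epoch):
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs                    # linear ramp toward the base LR
        progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))       # decays from 1 to 0
        return min_lr_ratio + (1.0 - min_lr_ratio) * cosine
    return fn

# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine_lambda())
# call scheduler.step() once per epoch after the training loop body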
Common Pitfalls and Solutions
| Pitfall | Solution | Example |
|---|---|---|
| Data Hunger | Use pre-training or DeiT | Pre-train on ImageNet-21k |
| Training Instability | Use gradient clipping, warmup | Clip gradients at 1.0 |
| Overfitting | Use strong augmentation, dropout | RandAugment, Mixup, CutMix |
| Slow Convergence | Use learning rate scheduling | Cosine annealing with warmup |
| Memory Issues | Use gradient checkpointing | Enable gradient checkpointing (sketch after this table) |
| Class Imbalance | Use weighted loss, oversampling | Weight classes by inverse frequency |
| Small Objects | Use smaller patch sizes | Try 8x8 or 12x12 patches |
| High Resolution | Use hierarchical ViT (e.g., Swin) | Swin Transformer for high-res images |
| Numerical Instability | Use layer normalization | Built into ViT architecture |
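For the memory-issues row, PyTorch's activation checkpointing can wrap each encoder block so intermediate activations are recomputed during the backward pass instead of being stored. A hedged sketch against the VisionTransformer defined earlier (function name is illustrative):
# Activation (gradient) checkpointing around each encoder block: trades extra compute for lower memory
from torch.utils.checkpoint import checkpoint

def forward_with_checkpointing(vit, x):
    x = vit.patch_embed(x)
    x = vit.pos_embed(x)
    for block in vit.blocks:
        # Recompute this block's activations in the backward pass instead of caching them
        x = checkpoint(block, x, use_reentrant=False)
    x = vit.norm(x)
    return vit.head(x[:, 0])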
Future Directions
- Foundation ViT Models: Large pre-trained ViT models for transfer learning
- Automated ViT Design: Neural architecture search for optimal ViT configurations
- Self-Supervised ViTs: Learning from unlabeled data at scale
- Multimodal ViTs: Combining vision, language, and other modalities
- Explainable ViTs: More interpretable attention mechanisms
- Real-Time ViTs: Optimized architectures for edge devices
- 3D ViTs: Better architectures for volumetric data
- Video ViTs: Temporal ViT for dynamic scenes
- Few-Shot ViTs: Learning from very few labeled examples
- Adversarial ViTs: Robust ViT against adversarial attacks
- Neuromorphic ViTs: Brain-inspired ViT architectures
- Quantum ViTs: ViT architectures for quantum computing
- Green ViTs: Energy-efficient ViT models
External Resources
- Original ViT Paper (Dosovitskiy et al.)
- DeiT Paper (Touvron et al.)
- Swin Transformer Paper (Liu et al.)
- MAE Paper (He et al.)
- CoCa Paper (Yu et al.)
- ViT Implementation (PyTorch)
- ViT Tutorial (YouTube)
- ViT for Medical Imaging (arXiv)
- Efficient ViTs (arXiv)
- ViT for Object Detection (arXiv)
- ViT Survey (arXiv)
- ViT Hardware Acceleration (arXiv)
- Self-Supervised ViTs (arXiv)
- Adversarial ViTs (arXiv)
- ViT Datasets