Vision Transformer (ViT) from Scratch
Hi there! Let's break down Vision Transformers (ViT) in the simplest way possible. No complex math, just plain English and code that you can actually understand.
Full source code here: ViT using PyTorch
 
What's a Vision Transformer?
Think of a Vision Transformer as a smart image analyzer that first chops up an image into smaller pieces (like a puzzle), then looks at how these pieces relate to each other to understand what's in the image. Cool, right?
 
Part 1: Slicing the Image (Patch Embedding)
We need to cut our image into smaller squares called patches. It's like taking a big photo and cutting it into a grid of smaller photos. Here's how we do it:
import torch
from torch import nn

class PatchEmbedding(nn.Module):
    def __init__(self,
                 in_channels=3,       # RGB images have 3 channels
                 patch_size=16,       # We'll cut the image into 16x16 patches
                 embedding_dim=768):  # How big we want our patch representations to be
        super().__init__()
        self.patch_size = patch_size
        # This conv layer does the actual patch creation and embedding:
        # a 16x16 kernel with stride 16 turns each patch into one embedding_dim-long vector
        self.patcher = nn.Conv2d(in_channels=in_channels,
                                 out_channels=embedding_dim,
                                 kernel_size=patch_size,
                                 stride=patch_size,
                                 padding=0)
        # Flatten the patch grid (height x width) into a single sequence dimension
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)

    def forward(self, x):
        x = self.patcher(x)        # [batch, embedding_dim, H/patch_size, W/patch_size]
        x = self.flatten(x)        # [batch, embedding_dim, num_patches]
        return x.permute(0, 2, 1)  # [batch, num_patches, embedding_dim] -- the order the transformer expects
Note: If you have a 224x224 pixel image and use 16x16 patches, you'll end up with 196 patches (14x14). Each patch becomes its own piece of the puzzle!
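If you want to convince yourself of those numbers, here's a quick sanity check. It just runs the PatchEmbedding above on a random fake image, so only the shapes matter here:

import torch

patchify = PatchEmbedding(in_channels=3, patch_size=16, embedding_dim=768)
dummy_image = torch.randn(1, 3, 224, 224)  # [batch, channels, height, width]
patches = patchify(dummy_image)
print(patches.shape)                        # torch.Size([1, 196, 768]) -> 196 = 14 x 14 patches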
Part 2: The Class Token
Here's something interesting - we add a special token at the start of our sequence. This token will eventually help us make the final decision about what's in the image:
# Inside the ViT class
self.class_token = nn.Parameter(torch.randn(1, 1, embedding_dim),
                                requires_grad=True)
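To make the shapes concrete, here's a small standalone sketch of how that token gets used at runtime. The 196 comes from the 224x224 image with 16x16 patches above, and the batch size of 32 is just an example value:

import torch
from torch import nn

embedding_dim = 768
class_token = nn.Parameter(torch.randn(1, 1, embedding_dim), requires_grad=True)

patch_sequence = torch.randn(32, 196, embedding_dim)           # [batch, num_patches, embedding_dim]
expanded = class_token.expand(32, -1, -1)                       # one copy of the token per image
with_class_token = torch.cat((expanded, patch_sequence), dim=1)
print(with_class_token.shape)                                   # torch.Size([32, 197, 768])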
Part 3: Adding Position Information
Just like you need to know where puzzle pieces go, our model needs to know where each patch came from in the original image. We do this by adding position information:
# Calculate how many patches we have
num_patches = (img_size * img_size) // patch_size ** 2
# Create position embeddings for patches + 1 for the class token
self.positional_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embedding_dim))
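For the usual 224x224 image with 16x16 patches, that works out to 196 + 1 = 197 positions. The leading dimension of 1 lets the same embeddings broadcast across the whole batch when we add them, which is worth seeing once (again, the batch size of 32 is just an example):

import torch

img_size, patch_size, embedding_dim = 224, 16, 768
num_patches = (img_size * img_size) // patch_size ** 2          # 196

# 197 learnable position rows: one per patch plus one for the class token
positional_embedding = torch.randn(1, num_patches + 1, embedding_dim)

tokens = torch.randn(32, num_patches + 1, embedding_dim)        # a batch of 32 token sequences
tokens = tokens + positional_embedding                          # broadcasts over the batch dimension
print(tokens.shape)                                             # torch.Size([32, 197, 768])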
Part 4: The Transformer Magic
Now comes the cool part! The transformer layers look at all the patches and figure out how they relate to each other. It's like having someone look at all the puzzle pieces and understanding how they fit together:
self.transformer_encoder = nn.TransformerEncoder(
    encoder_layer=nn.TransformerEncoderLayer(
        d_model=embedding_dim,
        nhead=num_heads,
        dim_feedforward=mlp_size,
        activation='gelu',
        batch_first=True,
        norm_first=True),
    num_layers=num_transformers_layers)
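A useful property to remember: the encoder doesn't change the shape of the sequence, it only mixes information between the tokens. Here's a standalone check using ViT-Base-style values (num_heads=12, mlp_size=3072 and 12 layers are assumptions for this example; plug in whatever your model uses):

import torch
from torch import nn

embedding_dim, num_heads, mlp_size, num_transformers_layers = 768, 12, 3072, 12

encoder = nn.TransformerEncoder(
    encoder_layer=nn.TransformerEncoderLayer(d_model=embedding_dim,
                                             nhead=num_heads,
                                             dim_feedforward=mlp_size,
                                             activation='gelu',
                                             batch_first=True,
                                             norm_first=True),
    num_layers=num_transformers_layers)

tokens = torch.randn(32, 197, embedding_dim)   # [batch, sequence length, embedding_dim]
out = encoder(tokens)
print(out.shape)                                # torch.Size([32, 197, 768]) -- same shape in, same shape out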
Part 5: Making the Final Decision
Finally, we look at what our class token (the "manager" we added back in Part 2) learned about the image and use it to make the prediction:
self.mlp_head = nn.Sequential(
    nn.LayerNorm(normalized_shape=embedding_dim),
    nn.Linear(in_features=embedding_dim,
              out_features=num_classes)
)
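Note that this head never sees the patch tokens directly: it gets one 768-dimensional vector per image (the class token's output row) and turns it into class scores. A quick standalone check, where num_classes=10 is just an example value:

import torch
from torch import nn

embedding_dim, num_classes = 768, 10

mlp_head = nn.Sequential(
    nn.LayerNorm(normalized_shape=embedding_dim),
    nn.Linear(in_features=embedding_dim, out_features=num_classes))

class_token_output = torch.randn(32, embedding_dim)   # the class-token row for each image in the batch
logits = mlp_head(class_token_output)
print(logits.shape)                                    # torch.Size([32, 10]) -- one score per class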
Here's how the whole process flows in the forward pass:
def forward(self, x):
    batch_size = x.shape[0]
    # 1. Create patches
    x = self.patch_embedding(x)
    # 2. Prepend the class token (one copy per image in the batch)
    class_token = self.class_token.expand(batch_size, -1, -1)
    x = torch.cat((class_token, x), dim=1)
    # 3. Add position information
    x = self.positional_embedding + x
    # 4. Apply the transformer encoder
    x = self.transformer_encoder(x)
    # 5. Make the final prediction from the class token
    x = self.mlp_head(x[:, 0])
    return x
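If you assemble all of the pieces above into a single ViT class (as in the full source code linked at the top), an end-to-end smoke test looks something like this. The constructor arguments here are assumptions about how that class is parameterized, so match them to your own code:

# Assumes a ViT(nn.Module) built from the snippets above; the keyword names are hypothetical.
import torch

model = ViT(img_size=224,
            patch_size=16,
            embedding_dim=768,
            num_heads=12,
            mlp_size=3072,
            num_transformers_layers=12,
            num_classes=10)

dummy_batch = torch.randn(8, 3, 224, 224)   # 8 RGB images of 224x224 pixels
logits = model(dummy_batch)
print(logits.shape)                          # torch.Size([8, 10]) -- one row of class scores per image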
That's how a Vision Transformer works: it chops an image into patches, adds a class token and position information, lets the transformer layers figure out how all the pieces relate to one another, and makes a prediction from the class token.
Remember, while this might seem like a lot, it's really just a series of simple steps working together to do something amazing.
Happy coding :)
References
-> Daniel Bourke
-> PyTorch Documentation