Sreedeep.

Vision Transformer (ViT) from Scratch

Hi there! Let's break down Vision Transformers (ViT) in the simplest way possible. No complex math, just plain English and code that you can actually understand.

Full source code here: ViT using PyTorch

 

What's a Vision Transformer?

Think of a Vision Transformer as a smart image analyzer. It first chops an image into smaller pieces (like a puzzle), then uses a Transformer to look at how these pieces relate to each other and work out what's in the image. Cool, right?

 

Part 1: Slicing the Image (Patch Embedding)

We need to cut our image into smaller squares called patches. It's like taking a big photo and cutting it into a grid of smaller photos. Here's how we do it:

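import torch
from torch import nn

class PatchEmbedding(nn.Module):
    def __init__(self,
                 in_channels=3,  # RGB images have 3 channels
                 patch_size=16,  # We'll cut the image into 16x16 patches
                 embedding_dim=768):  # How big we want our patch representations to be
        super().__init__()

        self.patch_size = patch_size

        # This layer does the actual patch creation and embedding.
        # kernel_size == stride == patch_size, so each output position sees exactly one patch
        self.patcher = nn.Conv2d(in_channels=in_channels,
                                 out_channels=embedding_dim,
                                 kernel_size=patch_size,
                                 stride=patch_size,
                                 padding=0)

        # Flatten the 2D grid of patches into a sequence
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)

    def forward(self, x):
        # The image must split evenly into whole patches
        assert x.shape[-1] % self.patch_size == 0, "image size must be divisible by patch size"
        x = self.patcher(x)        # (batch, 3, 224, 224) -> (batch, 768, 14, 14)
        x = self.flatten(x)        # -> (batch, 768, 196)
        return x.permute(0, 2, 1)  # -> (batch, 196, 768): one row per patch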
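Why a Conv2d? Because kernel_size and stride are both equal to patch_size, every output position sees exactly one non-overlapping patch. That makes the convolution the same as cutting out each patch, flattening it, and passing it through one shared linear layer. Here's a small sketch that checks this numerically. The sizes (32x32 images, 8x8 patches, 64-dim embeddings) are just assumptions picked so it runs fast.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
patch_size, embedding_dim = 8, 64
conv = nn.Conv2d(3, embedding_dim, kernel_size=patch_size, stride=patch_size)

# A linear layer that reuses the conv's weights: each 3x8x8 patch -> 64 numbers
linear = nn.Linear(3 * patch_size * patch_size, embedding_dim)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(embedding_dim, -1))
    linear.bias.copy_(conv.bias)

x = torch.randn(2, 3, 32, 32)  # a batch of 2 RGB images

# Route 1: the convolution trick, reshaped to (batch, num_patches, embedding_dim)
conv_out = conv(x).flatten(2).transpose(1, 2)

# Route 2: cut out the patches explicitly, flatten each one, apply the linear layer
patches = nn.functional.unfold(x, kernel_size=patch_size, stride=patch_size)  # (2, 192, 16)
linear_out = linear(patches.transpose(1, 2))

print(conv_out.shape)                                  # torch.Size([2, 16, 64])
print(torch.allclose(conv_out, linear_out, atol=1e-5))  # True
```

The convolution version is simply shorter to write and lets PyTorch use its optimized conv kernels.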

Note: If you have a 224x224 pixel image and use 16x16 patches, you'll end up with 196 patches (14x14). Each patch becomes its own piece of the puzzle!
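The arithmetic behind that note is just (image side / patch side) squared. A tiny sketch, where the helper name num_patches and the extra image/patch sizes are only illustrative:

```python
def num_patches(img_size: int, patch_size: int) -> int:
    # The image side must split evenly into whole patches
    assert img_size % patch_size == 0, "image size must be divisible by patch size"
    patches_per_side = img_size // patch_size   # e.g. 224 // 16 = 14
    return patches_per_side * patches_per_side  # 14 * 14 = 196

for img_size, patch_size in [(224, 16), (224, 32), (384, 16)]:
    n = num_patches(img_size, patch_size)
    # +1 because the class token (next part) joins the sequence
    print(f"{img_size}x{img_size} image, {patch_size}x{patch_size} patches -> "
          f"{n} patches, sequence length {n + 1}")
```

For 224x224 images with 16x16 patches this gives 196 patches, and a sequence of 197 tokens once the class token is added. Bigger patches mean a shorter sequence (49 for 32x32 patches), which is cheaper but coarser.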

Part 2: The Class Token

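Here's something interesting - we add a special, learnable token at the start of our sequence. It starts out as random numbers, gathers information from all the patches as it passes through the transformer, and its final value is what we use to decide what's in the image: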

# Inside the ViT class 

self.class_token = nn.Parameter(torch.randn(1, 1, embedding_dim),
                               requires_grad=True)
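To see what the class token does to the tensor shapes, here's a small standalone sketch. The patch tokens are random stand-ins for the patch embedding output (an assumption, just to keep it self-contained); the interesting bit is the expand-then-concatenate step, which is the same one the forward pass uses later.

```python
import torch
import torch.nn as nn

batch_size, num_patches, embedding_dim = 4, 196, 768

# One learnable token, shared by every image in the batch
class_token = nn.Parameter(torch.randn(1, 1, embedding_dim))

# Stand-in for the patch embedding output: (batch, num_patches, embedding_dim)
patch_tokens = torch.randn(batch_size, num_patches, embedding_dim)

# expand() gives a batch-sized view without copying the parameter's data,
# then the token is glued onto the front of each image's sequence
tokens = torch.cat((class_token.expand(batch_size, -1, -1), patch_tokens), dim=1)

print(tokens.shape)  # torch.Size([4, 197, 768])
```

Because every image uses a view of the same parameter, gradients from the whole batch flow back into that single learnable vector.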

Part 3: Adding Position Information

Just like you need to know where puzzle pieces go, our model needs to know where each patch came from in the original image. We do this by adding a learnable position embedding, one vector for every position in the sequence (including the class token):

# Calculate how many patches we have
num_patches = (img_size * img_size) // patch_size ** 2
# Create position embeddings for patches + 1 for the class token
self.positional_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embedding_dim)) 
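Why is this needed at all? Self-attention on its own has no sense of order: shuffle the patches and the outputs come out shuffled the same way, but otherwise unchanged. This sketch checks that with a single nn.TransformerEncoderLayer (small, assumed sizes; eval mode so dropout is off), then shows that adding a position embedding breaks the symmetry. The position embedding here is random just for the demo; in the real model it is learned.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_patches, embedding_dim = 16, 32
layer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=4,
                                   batch_first=True, norm_first=True).eval()

x = torch.randn(1, num_patches, embedding_dim)  # one image's patch tokens
perm = torch.randperm(num_patches)              # shuffle the patches

with torch.no_grad():
    # Without positions: shuffling the inputs only shuffles the outputs
    same = torch.allclose(layer(x)[:, perm], layer(x[:, perm]), atol=1e-5)

    # With positions added, each token carries its location,
    # so moving a patch to a new slot changes what the layer computes
    pos = torch.randn(1, num_patches, embedding_dim)
    differs = not torch.allclose(layer(x + pos)[:, perm],
                                 layer(x[:, perm] + pos), atol=1e-5)

print(same, differs)  # True True
```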

Part 4: The Transformer Magic

Now comes the cool part! The transformer layers look at all the patches and figure out how they relate to each other. It's like having someone look at all the puzzle pieces and understand how they fit together:

self.transformer_encoder = nn.TransformerEncoder(
    encoder_layer=nn.TransformerEncoderLayer(
        d_model=embedding_dim,
        activation='gelu',
        batch_first=True,
        norm_first=True,
        nhead=num_heads,
        dim_feedforward=mlp_size),
    num_layers=num_transformers_layers)
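The piece inside each encoder layer that lets patches "look at" each other is self-attention. Below is a minimal single-head sketch of scaled dot-product attention; the real nn.TransformerEncoderLayer runs num_heads of these in parallel and adds the MLP, LayerNorms and residual connections. The class and the names w_q, w_k, w_v are illustrative, not PyTorch internals.

```python
import math
import torch
import torch.nn as nn

class SingleHeadSelfAttention(nn.Module):
    """Hypothetical one-head version of the attention inside an encoder layer."""
    def __init__(self, embedding_dim: int):
        super().__init__()
        self.w_q = nn.Linear(embedding_dim, embedding_dim)  # queries: "what am I looking for?"
        self.w_k = nn.Linear(embedding_dim, embedding_dim)  # keys: "what do I contain?"
        self.w_v = nn.Linear(embedding_dim, embedding_dim)  # values: "what do I pass on?"

    def forward(self, x):  # x: (batch, tokens, embedding_dim)
        q, k, v = self.w_q(x), self.w_k(x), self.w_v(x)
        # How strongly each token relates to every other token: (batch, tokens, tokens)
        scores = q @ k.transpose(-2, -1) / math.sqrt(x.shape[-1])
        weights = scores.softmax(dim=-1)  # each row sums to 1
        return weights @ v, weights       # each output is a weighted mix of all values

attn = SingleHeadSelfAttention(embedding_dim=64)
tokens = torch.randn(2, 197, 64)  # class token + 196 patches, small embedding for speed
out, weights = attn(tokens)
print(out.shape, weights.shape)  # torch.Size([2, 197, 64]) torch.Size([2, 197, 197])
```

Row i of weights says how much token i "listens to" every other token, which is how the class token ends up collecting information from all the patches.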

Part 5: Making the Final Decision

Finally, we look at what our manager token (remember that special class token?) learned about the image and make a prediction:

self.mlp_head = nn.Sequential(
    nn.LayerNorm(normalized_shape=embedding_dim),
    nn.Linear(in_features=embedding_dim,
              out_features=num_classes)
)
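Here's a quick standalone sketch of that last step: pick out the class token from the transformer's output, push it through the head, and read off a prediction. The encoder output is random and num_classes=10 is an assumption, just to show the shapes.

```python
import torch
import torch.nn as nn

batch_size, num_tokens, embedding_dim, num_classes = 2, 197, 768, 10

mlp_head = nn.Sequential(
    nn.LayerNorm(normalized_shape=embedding_dim),
    nn.Linear(in_features=embedding_dim, out_features=num_classes),
)

encoder_output = torch.randn(batch_size, num_tokens, embedding_dim)  # stand-in for the transformer output
class_token_output = encoder_output[:, 0]  # only the class token: (batch, embedding_dim)
logits = mlp_head(class_token_output)      # one raw score per class: (batch, num_classes)
probs = logits.softmax(dim=-1)             # turn scores into probabilities
predicted = probs.argmax(dim=-1)           # most likely class for each image

print(logits.shape, predicted.shape)  # torch.Size([2, 10]) torch.Size([2])
```

During training you would normally feed the raw logits straight into nn.CrossEntropyLoss, which applies the softmax internally.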

Here's how the whole process flows in the forward pass:

def forward(self, x):
    batch_size = x.shape[0]
    
    # 1. Create patches
    x = self.patch_embedding(x)
    
    # 2. Add the manager token
    class_token = self.class_token.expand(batch_size, -1, -1)
    x = torch.cat((class_token, x), dim=1)
    
    # 3. Add position information
    x = self.positional_embedding + x
    
    # 4. Apply transformer magic
    x = self.transformer_encoder(x)
    
    # 5. Make the final prediction
    x = self.mlp_head(x[:, 0])
    
    return x
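Putting all the pieces together, here's one self-contained sketch of the full model with a quick smoke test. The constructor arguments and their defaults (12 layers, 12 heads, an MLP size of 3072, 1000 classes, in the style of ViT-Base) are assumptions, since the snippets above don't spell out where those values come from; the smoke test uses a much smaller configuration so it runs in a blink.

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, patch_size=16, embedding_dim=768):
        super().__init__()
        self.patcher = nn.Conv2d(in_channels, embedding_dim,
                                 kernel_size=patch_size, stride=patch_size, padding=0)
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)

    def forward(self, x):
        # (B, C, H, W) -> (B, D, H/P, W/P) -> (B, D, N) -> (B, N, D)
        return self.flatten(self.patcher(x)).permute(0, 2, 1)

class ViT(nn.Module):
    def __init__(self, img_size=224, in_channels=3, patch_size=16, embedding_dim=768,
                 num_transformers_layers=12, num_heads=12, mlp_size=3072, num_classes=1000):
        super().__init__()
        assert img_size % patch_size == 0, "image size must be divisible by patch size"
        num_patches = (img_size * img_size) // patch_size ** 2
        self.patch_embedding = PatchEmbedding(in_channels, patch_size, embedding_dim)
        self.class_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
        self.positional_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embedding_dim))
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(
                d_model=embedding_dim, nhead=num_heads, dim_feedforward=mlp_size,
                activation='gelu', batch_first=True, norm_first=True),
            num_layers=num_transformers_layers)
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(normalized_shape=embedding_dim),
            nn.Linear(in_features=embedding_dim, out_features=num_classes))

    def forward(self, x):
        batch_size = x.shape[0]
        x = self.patch_embedding(x)                                # 1. patches
        class_token = self.class_token.expand(batch_size, -1, -1)
        x = torch.cat((class_token, x), dim=1)                     # 2. class token
        x = self.positional_embedding + x                          # 3. positions
        x = self.transformer_encoder(x)                            # 4. transformer
        return self.mlp_head(x[:, 0])                              # 5. prediction

# Small, hypothetical configuration for a fast smoke test
model = ViT(img_size=32, patch_size=8, embedding_dim=64, num_transformers_layers=2,
            num_heads=4, mlp_size=128, num_classes=10)
images = torch.randn(2, 3, 32, 32)
logits = model(images)
print(logits.shape)  # torch.Size([2, 10])
```

Note that embedding_dim must be divisible by num_heads, because nn.TransformerEncoderLayer splits the embedding evenly across the attention heads.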

That's how a Vision Transformer works: it chops an image into patches, adds a special class token and position information, lets the transformer work out how everything relates to everything else, and makes a prediction from the class token.

Remember, while this might seem like a lot, it's really just a series of simple steps working together to do something amazing.

Happy coding :)

Reference
-> Daniel Bourke
-> PyTorch Documentation