Model Summary
- Architecture: Vision Transformer (ViT).
- Backbone: Token embedding via image patches, Multi-Head Self-Attention (MHSA), and MLP blocks.
- Dataset: CIFAR-10 (10 classes, 60k images).
- Training Framework: PyTorch.
- Performance: ~41.5% test accuracy from a short, demonstration-level training run (see the classification report below).
Training Process
Dataset & Transforms
- We used CIFAR-10 (32×32 color images).
- Images were resized to 224×224 to match the original ViT patching approach.
- [Optional] Normalization can be applied as needed, e.g., using the CIFAR-10 channel mean/std (see the data-loading sketch below).
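A minimal sketch of this setup with torchvision; the batch size and the (commented-out) normalization statistics are illustrative, and the original training code may differ:

import torch
import torchvision
import torchvision.transforms as T

# Resize CIFAR-10's 32x32 images to 224x224 for ViT patching;
# normalization with CIFAR-10 statistics is optional.
transform_cifar = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    # T.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616)),
])

train_set = torchvision.datasets.CIFAR10(root="./data", train=True,
                                         download=True, transform=transform_cifar)
test_set = torchvision.datasets.CIFAR10(root="./data", train=False,
                                        download=True, transform=transform_cifar)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=2, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=2, shuffle=False)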
Model Architecture
- Patches of size P × P.
- Embedding dimension D.
- Multi-Head Self-Attention with k heads.
- MLP dimension of mlp_size.
- A stack of L Transformer blocks (see the sketch below).
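The actual class lives in model_definition.py (imported later as ViT). The following is only an illustrative sketch of how patch embedding, MHSA, and MLP blocks might be wired together, using PyTorch's built-in encoder layers rather than the repository's custom blocks, so its state-dict keys will not match the released checkpoint:

import torch
import torch.nn as nn

class ViT(nn.Module):
    """Minimal Vision Transformer sketch (not the exact model_definition.ViT)."""
    def __init__(self, image_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, num_heads=12, mlp_size=3072,
                 depth=12, n_classes=10, dropout=0.1):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2

        # Patch embedding: a conv with stride = patch size maps each P x P patch to a D-dim token.
        self.patch_embed = nn.Conv2d(in_channels, embed_dim,
                                     kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)

        # Stack of L Transformer blocks, each with MHSA (k heads) and an MLP of size mlp_size.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=mlp_size,
            dropout=dropout, activation="gelu", batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        self.head = nn.Linear(embed_dim, n_classes)

    def forward(self, x):                         # x: [B, 3, 224, 224]
        x = self.patch_embed(x)                   # [B, D, H/P, W/P]
        x = x.flatten(2).transpose(1, 2)          # [B, N, D]
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1) + self.pos_embed
        x = self.dropout(x)
        x = self.encoder(x)                       # [B, N + 1, D]
        return self.head(x[:, 0])                 # classify from the [CLS] token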
Optimizer & Loss
- Optimizer: Adam (learning rate = 1e-4).
- Loss: CrossEntropyLoss.
Training Loop
- Standard PyTorch loop with mini-batches.
- Multiple epochs.
- Tracked the training loss and accuracy.
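A minimal sketch of such a loop, assuming the ViT class and train_loader from the sketches above (the epoch count is illustrative, and the original notebook's loop may differ):

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ViT().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 5  # illustrative; the original run may use a different count
for epoch in range(num_epochs):
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Track training loss and accuracy
        running_loss += loss.item() * images.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)

    print(f"Epoch {epoch + 1}: loss={running_loss / total:.4f}, "
          f"train_acc={correct / total:.4f}")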
How to Use the Model
1. Installation
Make sure you have the following libraries installed:
pip install torch torchvision matplotlib gradio huggingface_hub
2. Loading the Model
If you have a local vit_cifar_model.pth (the trained state dict), you can load the model like this:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Import or define your ViT class
from model_definition import ViT # your model code
model_cifar = ViT().to(device)
checkpoint = torch.load("vit_cifar_model.pth", map_location=device)
model_cifar.load_state_dict(checkpoint)
model_cifar.eval()
3. Inference on a Single Image
from PIL import Image
import torchvision.transforms as T
transform_cifar = T.Compose([
T.Resize((224, 224)),
T.ToTensor(),
])
img = Image.open("some_image.jpg").convert("RGB")  # load an image and force 3 RGB channels
x = transform_cifar(img).unsqueeze(0).to(device) # shape [1, 3, 224, 224]
with torch.no_grad():
    logits = model_cifar(x)
pred = torch.argmax(logits, dim=1).item()
print("Predicted class:", pred)
Training & Evaluation Graphs
Below is a conceptual summary of the typical outputs you might see after training. (In your code, these graphs are generated using Matplotlib.)
- Training Loss Plot: shows the training loss decreasing over epochs.
- Training Accuracy Plot: tracks the percentage of correct predictions on the training set each epoch.
- Test Set Accuracy: evaluates the model on the test set across epochs; the final test accuracy was 41.51%.
- Confusion Matrix: visual representation of true labels vs. predicted labels.
(Note: Replace placeholder image URLs with your actual plots if you have them hosted somewhere.)
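For reference, a report like the one below can be produced with scikit-learn. This is only a sketch, assuming test-set predictions and labels are collected via a test_loader as in the data-loading sketch above:

from sklearn.metrics import classification_report, confusion_matrix

all_preds, all_labels = [], []
model_cifar.eval()
with torch.no_grad():
    for images, labels in test_loader:
        logits = model_cifar(images.to(device))
        all_preds.extend(logits.argmax(dim=1).cpu().tolist())
        all_labels.extend(labels.tolist())

print(classification_report(all_labels, all_preds, digits=4))  # per-class precision/recall/F1
print(confusion_matrix(all_labels, all_preds))                 # raw counts for the confusion matrix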
Classification Report:

              precision    recall  f1-score   support

           0     0.5618    0.4090    0.4734      1000
           1     0.5385    0.3500    0.4242      1000
           2     0.2884    0.2030    0.2383      1000
           3     0.3481    0.1570    0.2164      1000
           4     0.3686    0.5050    0.4262      1000
           5     0.3280    0.3910    0.3568      1000
           6     0.5423    0.4680    0.5024      1000
           7     0.4477    0.4110    0.4286      1000
           8     0.4668    0.5770    0.5161      1000
           9     0.3602    0.6800    0.4709      1000

    accuracy                         0.4151     10000
   macro avg     0.4250    0.4151    0.4053     10000
weighted avg     0.4250    0.4151    0.4053     10000
###############################################################################
# CELL: Vision Transformer Hyperparameters
# All important parameters for your ViT model.
###############################################################################

# Batch size
B = 2  # e.g., for demonstration
# Number of channels (RGB = 3)
C = 3
# Image height and width
H = 224
W = 224
# Patch size
P = 16
# Number of patches (derived from H, W, and P)
N = (H // P) * (W // P)
# Embedding dimension
D = 768
# Number of attention heads
k = 12
# Dimension per head (must be compatible with D)
Dh = D // k
# Dropout probability
p = 0.1
# Hidden layer size for MLP inside the Transformer block
mlp_size = 3072
# Number of Transformer blocks (depth of the encoder)
L = 12
# Number of output classes (e.g., CIFAR-10 has 10 classes)
n_classes = 10

# Print them out in a structured format
print("=== Vision Transformer Parameters ===")
print(f"B (Batch Size): {B}")
print(f"C (Channels): {C}")
print(f"H (Image Height): {H}")
print(f"W (Image Width): {W}")
print(f"P (Patch Size): {P}")
print(f"N (Number of Patches): {N}")
print(f"D (Embedding Dimension): {D}")
print(f"k (Attention Heads): {k}")
print(f"Dh (Dim per Head): {Dh}")
print(f"p (Dropout Probability): {p}")
print(f"mlp_size (MLP Hidden): {mlp_size}")
print(f"L (Num Transformer Blocks): {L}")
print(f"n_classes (Output Classes): {n_classes}")
print("=====================================")
Integration with Gradio & Hugging Face Spaces
Gradio Demo
A simple Gradio demo can be created to classify uploaded images:
import gradio as gr
import torch
import torchvision.transforms as T
from PIL import Image
model_cifar.eval()
class_names_cifar = [
"airplane", "automobile", "bird", "cat", "deer",
"dog", "frog", "horse", "ship", "truck"
]
def predict_cifar(img):
    img = img.convert("RGB")  # force 3 channels for uploads with alpha or grayscale
    x = T.Compose([T.Resize((224, 224)), T.ToTensor()])(img).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model_cifar(x)
    pred_id = torch.argmax(logits, dim=1).item()
    return f"Prediction: {class_names_cifar[pred_id]}"

gr.Interface(
    fn=predict_cifar,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="ViT on CIFAR-10"
).launch()
Hugging Face Hub
You can push the model and code to the Hugging Face Hub:
from huggingface_hub import HfApi

api = HfApi()
repo_id = "username/my-cifar-vit"
api.create_repo(repo_id=repo_id, exist_ok=True)

api.upload_file(
    path_or_fileobj="vit_cifar_model.pth",
    path_in_repo="vit_cifar_model.pth",
    repo_id=repo_id,
    repo_type="model"
)
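Afterwards, the checkpoint can be pulled back down with hf_hub_download (same placeholder repo and file names as above):

from huggingface_hub import hf_hub_download

weights_path = hf_hub_download(repo_id="username/my-cifar-vit",
                               filename="vit_cifar_model.pth")
model_cifar.load_state_dict(torch.load(weights_path, map_location=device))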
Then create a Space with Gradio integration if you want a hosted web app.
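If you prefer to script the Space as well, something along these lines should work, assuming the Gradio demo above has been saved locally as app.py (the Space name is a placeholder):

from huggingface_hub import HfApi

api = HfApi()
space_id = "username/my-cifar-vit-demo"  # hypothetical Space name

# Create a Gradio-SDK Space and upload the demo script
api.create_repo(repo_id=space_id, repo_type="space", space_sdk="gradio", exist_ok=True)
api.upload_file(path_or_fileobj="app.py", path_in_repo="app.py",
                repo_id=space_id, repo_type="space")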
License
MIT License or any license of your choice.
Author
- HNM – GitHub Profile