Model Summary

  • Architecture: Vision Transformer (ViT).
  • Backbone: Token embedding via image patches, Multi-Head Self-Attention (MHSA), and MLP blocks.
  • Dataset: CIFAR-10 (10 classes, 60,000 images: 50,000 for training and 10,000 for testing).
  • Training Framework: PyTorch.
  • Performance: Demonstration-level training loop for illustration; the trained model reaches 41.51% test accuracy (see the classification report below).

Training Process

  1. Dataset & Transforms

    • We used CIFAR-10 (32×32 color images).
    • Images were resized to 224×224 to match the original ViT patching approach.
    • [Optional] Normalization can be applied as needed, e.g. using mean/std of CIFAR-10.
  2. Model Architecture

    • Patches of size P × P.
    • Embedding dimension D.
    • Multi-Head Self-Attention with k heads.
    • MLP dimension of mlp_size.
    • A stack of L Transformer blocks.
  3. Optimizer & Loss

    • Optimizer: Adam (learning rate = 1e-4).
    • Loss: CrossEntropyLoss.
  4. Training Loop

    • Standard PyTorch loop with mini-batches.
    • Multiple epochs.
    • Tracked the training loss and accuracy.
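The training steps above can be sketched as follows. This is a minimal illustration, not the author's original script: the helper names (build_cifar10_loaders, train_one_epoch, evaluate, fit), the batch size, and the optional normalization statistics (commonly quoted CIFAR-10 per-channel mean/std) are assumptions. It uses Adam with a learning rate of 1e-4 and CrossEntropyLoss as listed, and records training loss, training accuracy, and test accuracy per epoch.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def build_cifar10_loaders(root="./data", batch_size=64, normalize=False):
    """Hypothetical helper: CIFAR-10 loaders resized to 224x224 (downloads on first use)."""
    import torchvision  # imported here so the training helpers below work without it
    import torchvision.transforms as T

    steps = [T.Resize((224, 224)), T.ToTensor()]
    if normalize:
        # Commonly quoted CIFAR-10 per-channel statistics (an assumption, not from the original)
        steps.append(T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)))
    tf = T.Compose(steps)
    train_set = torchvision.datasets.CIFAR10(root, train=True, download=True, transform=tf)
    test_set = torchvision.datasets.CIFAR10(root, train=False, download=True, transform=tf)
    return (DataLoader(train_set, batch_size=batch_size, shuffle=True),
            DataLoader(test_set, batch_size=batch_size))


def train_one_epoch(model, loader, optimizer, loss_fn, device):
    """One pass over the training data; returns (mean loss, accuracy in [0, 1])."""
    model.train()
    total_loss, correct, seen = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(images)
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * labels.size(0)  # undo the per-batch mean
        correct += (logits.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)
    return total_loss / seen, correct / seen


@torch.no_grad()
def evaluate(model, loader, device):
    """Accuracy in [0, 1] on a held-out loader, without gradient tracking."""
    model.eval()
    correct, seen = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        correct += (model(images).argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)
    return correct / seen


def fit(model, train_loader, test_loader, epochs, device, lr=1e-4):
    """Adam + CrossEntropyLoss over several epochs, keeping per-epoch metrics for plotting."""
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    history = {"train_loss": [], "train_acc": [], "test_acc": []}
    for epoch in range(epochs):
        loss, acc = train_one_epoch(model, train_loader, optimizer, loss_fn, device)
        test_acc = evaluate(model, test_loader, device)
        history["train_loss"].append(loss)
        history["train_acc"].append(acc)
        history["test_acc"].append(test_acc)
        print(f"epoch {epoch + 1}/{epochs}: loss={loss:.4f} "
              f"train_acc={acc:.2%} test_acc={test_acc:.2%}")
    return history
```

For example: train_loader, test_loader = build_cifar10_loaders(), then history = fit(model, train_loader, test_loader, epochs=10, device=device). If you enable normalize=True, apply the same T.Normalize in the inference and Gradio transforms as well, otherwise predictions will be off.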

How to Use the Model

1. Installation

Make sure you have the following libraries installed:

pip install torch torchvision matplotlib gradio huggingface_hub

2. Loading the Model

If you have a local vit_cifar_model.pth (the trained state dict), you can load the model like this:

import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Import or define your ViT class
from model_definition import ViT  # your model code

model_cifar = ViT().to(device)
checkpoint = torch.load("vit_cifar_model.pth", map_location=device)
model_cifar.load_state_dict(checkpoint)
model_cifar.eval()
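The loading code imports ViT from a model_definition module that is not included in this card. If you need a definition to train from scratch, the following is a minimal sketch of a standard ViT classifier (pre-norm Transformer blocks, learnable [CLS] token and position embeddings) that uses this card's hyperparameter names as defaults. It is an assumption, not the author's class: its parameter names will very likely differ from those stored in vit_cifar_model.pth, so load_state_dict on that checkpoint is expected to fail unless the two architectures match exactly.

```python
import torch
import torch.nn as nn


class TransformerBlock(nn.Module):
    """Pre-norm encoder block: x + MHSA(LN(x)), then x + MLP(LN(x))."""

    def __init__(self, D, k, mlp_size, p):
        super().__init__()
        self.norm1 = nn.LayerNorm(D)
        self.attn = nn.MultiheadAttention(D, k, dropout=p, batch_first=True)
        self.norm2 = nn.LayerNorm(D)
        self.mlp = nn.Sequential(
            nn.Linear(D, mlp_size), nn.GELU(), nn.Dropout(p),
            nn.Linear(mlp_size, D), nn.Dropout(p),
        )

    def forward(self, x):
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        return x + self.mlp(self.norm2(x))


class ViT(nn.Module):
    """Minimal ViT classifier; defaults mirror the hyperparameter cell (ViT-Base sizes)."""

    def __init__(self, H=224, W=224, C=3, P=16, D=768, k=12,
                 mlp_size=3072, L=12, p=0.1, n_classes=10):
        super().__init__()
        assert H % P == 0 and W % P == 0, "H and W must be divisible by P"
        assert D % k == 0, "D must be divisible by k"
        N = (H // P) * (W // P)
        # A P x P conv with stride P acts as a shared linear projection of each patch
        self.patch_embed = nn.Conv2d(C, D, kernel_size=P, stride=P)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, D))
        self.pos_embed = nn.Parameter(torch.zeros(1, N + 1, D))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.dropout = nn.Dropout(p)
        self.blocks = nn.Sequential(*[TransformerBlock(D, k, mlp_size, p) for _ in range(L)])
        self.norm = nn.LayerNorm(D)
        self.head = nn.Linear(D, n_classes)

    def forward(self, x):
        x = self.patch_embed(x).flatten(2).transpose(1, 2)  # [B, N, D]
        cls = self.cls_token.expand(x.shape[0], -1, -1)      # [B, 1, D]
        x = torch.cat([cls, x], dim=1) + self.pos_embed       # [B, N + 1, D]
        x = self.blocks(self.dropout(x))
        return self.head(self.norm(x[:, 0]))                  # logits from the [CLS] token
```

With the defaults this is a ViT-Base/16 sized model (about 86M parameters), which is heavy for 32×32 CIFAR-10 images; smaller D and L train much faster for demonstration purposes.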

3. Inference on a Single Image

from PIL import Image
import torchvision.transforms as T

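transform_cifar = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    # If you normalized during training, add the same T.Normalize(mean, std) here
])

img = Image.open("some_image.jpg").convert("RGB")  # load an image; force 3 channels (drops alpha)
x = transform_cifar(img).unsqueeze(0).to(device)  # shape [1, 3, 224, 224]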

with torch.no_grad():
    logits = model_cifar(x)
pred = torch.argmax(logits, dim=1).item()
print("Predicted class:", pred)
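To report class names and confidences instead of a bare index, a small helper can apply softmax to the logits and keep the top k. This sketch assumes the standard torchvision CIFAR-10 label order (the same list used in the Gradio demo below); predict_topk and CIFAR10_CLASSES are hypothetical names.

```python
import torch
import torch.nn.functional as F

# Assumed label order: torchvision's CIFAR-10 classes, indices 0..9
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]


@torch.no_grad()
def predict_topk(model, x, k=3, class_names=CIFAR10_CLASSES):
    """Return the k most likely (class_name, probability) pairs for a batch of one image."""
    model.eval()
    probs = F.softmax(model(x), dim=1)[0]  # logits [1, n_classes] -> probabilities
    top_p, top_i = probs.topk(k)
    return [(class_names[i], p) for p, i in zip(top_p.tolist(), top_i.tolist())]
```

Usage: for name, prob in predict_topk(model_cifar, x): print(f"{name}: {prob:.1%}").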

Training & Evaluation Graphs

Below is a conceptual summary of the typical outputs you might see after training. (In your code, these graphs are generated using Matplotlib.)

  1. Training Loss Plot

    [Figure: training loss vs. epoch]

    Shows the training loss decreasing over epochs.

  2. Training Accuracy Plot

    [Figure: training accuracy vs. epoch]

    Tracks the percentage of correct predictions on the training set each epoch.

  3. Test Set Accuracy

    [Figure: test accuracy vs. epoch]

    Evaluates the model on the test set across epochs. Test accuracy: 41.51%.

  4. Confusion Matrix

    [Figure: confusion matrix]

    Visual representation of true labels vs. predicted labels.

(Note: Replace placeholder image URLs with your actual plots if you have them hosted somewhere.)
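The graphs described above can be reproduced with a sketch like the one below. It assumes a per-epoch history dict with train_loss, train_acc and test_acc lists (a hypothetical format, not taken from the original code) and counts (true, predicted) pairs with torch.bincount to build the confusion matrix; the function names are illustrative.

```python
import torch
import matplotlib.pyplot as plt


@torch.no_grad()
def compute_confusion_matrix(model, loader, n_classes=10, device="cpu"):
    """cm[i, j] counts images of true class i that the model predicted as class j."""
    model.eval()
    cm = torch.zeros(n_classes, n_classes, dtype=torch.long)
    for images, labels in loader:
        preds = model(images.to(device)).argmax(dim=1).cpu()
        # Encode each (true, pred) pair as a single index, then count all pairs at once
        pairs = labels.long() * n_classes + preds
        cm += torch.bincount(pairs, minlength=n_classes * n_classes).reshape(n_classes, n_classes)
    return cm


def plot_history(history):
    """Training loss and train/test accuracy curves from a per-epoch history dict."""
    epochs = range(1, len(history["train_loss"]) + 1)
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
    ax_loss.plot(epochs, history["train_loss"])
    ax_loss.set(title="Training loss", xlabel="epoch", ylabel="loss")
    ax_acc.plot(epochs, history["train_acc"], label="train")
    ax_acc.plot(epochs, history["test_acc"], label="test")
    ax_acc.set(title="Accuracy", xlabel="epoch", ylabel="accuracy")
    ax_acc.legend()
    fig.tight_layout()
    return fig


def plot_confusion_matrix(cm, class_names):
    """Heatmap of a confusion matrix with true labels on rows, predictions on columns."""
    fig, ax = plt.subplots(figsize=(7, 6))
    im = ax.imshow(cm.numpy(), cmap="Blues")
    ax.set_xticks(range(len(class_names)))
    ax.set_xticklabels(class_names, rotation=45, ha="right")
    ax.set_yticks(range(len(class_names)))
    ax.set_yticklabels(class_names)
    ax.set(xlabel="predicted label", ylabel="true label", title="Confusion matrix")
    fig.colorbar(im, ax=ax)
    fig.tight_layout()
    return fig
```

Calling fig.savefig("confusion_matrix.png") on the returned figures gives image files you can host in place of the placeholders; class_names_cifar from the Gradio section works as the class_names argument.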


Classification Report (test set):

              precision    recall  f1-score   support

           0     0.5618    0.4090    0.4734      1000
           1     0.5385    0.3500    0.4242      1000
           2     0.2884    0.2030    0.2383      1000
           3     0.3481    0.1570    0.2164      1000
           4     0.3686    0.5050    0.4262      1000
           5     0.3280    0.3910    0.3568      1000
           6     0.5423    0.4680    0.5024      1000
           7     0.4477    0.4110    0.4286      1000
           8     0.4668    0.5770    0.5161      1000
           9     0.3602    0.6800    0.4709      1000

    accuracy                         0.4151     10000
   macro avg     0.4250    0.4151    0.4053     10000
weighted avg     0.4250    0.4151    0.4053     10000
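Each row of the report follows from the confusion matrix: precision divides the diagonal by the column sums (how often a predicted class is correct), recall divides it by the row sums (how many true samples of a class are found), and F1 is their harmonic mean. The sketch below is an illustrative re-derivation, not necessarily how the report above was produced (its layout resembles scikit-learn's classification_report with digits=4); per_class_metrics and macro_and_weighted are hypothetical helpers. Because every CIFAR-10 test class has the same support (1,000 images), the weighted average equals the macro average and overall accuracy equals macro recall (0.4151), which is exactly what the report shows.

```python
import torch


def per_class_metrics(cm):
    """Precision, recall, F1 per class and overall accuracy from cm[true, pred]."""
    cm = cm.double()
    tp = cm.diag()
    predicted = cm.sum(dim=0)  # column sums: how often each class was predicted
    support = cm.sum(dim=1)    # row sums: true samples per class
    precision = tp / predicted.clamp(min=1)  # 0 for a class that is never predicted
    recall = tp / support.clamp(min=1)
    f1 = 2 * precision * recall / (precision + recall).clamp(min=1e-12)
    accuracy = tp.sum() / cm.sum()
    return precision, recall, f1, accuracy


def macro_and_weighted(metric, support):
    """Unweighted mean over classes vs. mean weighted by class support."""
    return metric.mean().item(), (metric * support).sum().item() / support.sum().item()
```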

Vision Transformer Hyperparameters

# =============================================================================
# CELL: Vision Transformer Hyperparameters
# All important parameters for your ViT model.
# =============================================================================

# Batch size
B = 2  # e.g., for demonstration

# Number of channels (RGB = 3)
C = 3

# Image height and width
H = 224
W = 224

# Patch size
P = 16

# Number of patches (derived from H, W, and P)
N = (H // P) * (W // P)

# Embedding dimension
D = 768

# Number of attention heads
k = 12

# Dimension per head (D must be divisible by k)
Dh = D // k

# Dropout probability
p = 0.1

# Hidden layer size for the MLP inside each Transformer block
mlp_size = 3072

# Number of Transformer blocks (depth of the encoder)
L = 12

# Number of output classes (e.g., CIFAR-10 has 10 classes)
n_classes = 10

# Print them out in a structured format
print("=== Vision Transformer Parameters ===")
print(f"B (Batch Size): {B}")
print(f"C (Channels): {C}")
print(f"H (Image Height): {H}")
print(f"W (Image Width): {W}")
print(f"P (Patch Size): {P}")
print(f"N (Number of Patches): {N}")
print(f"D (Embedding Dimension): {D}")
print(f"k (Attention Heads): {k}")
print(f"Dh (Dim per Head): {Dh}")
print(f"p (Dropout Probability): {p}")
print(f"mlp_size (MLP Hidden): {mlp_size}")
print(f"L (Num Transformer Blocks): {L}")
print(f"n_classes (Output Classes): {n_classes}")
print("=====================================")
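With these values, N = (224 // 16) * (224 // 16) = 14 * 14 = 196 patches and Dh = 768 // 12 = 64. The sketch below checks the patch-embedding shapes for these settings and shows that a P × P convolution with stride P is the same as flattening each patch and applying one shared linear projection. This is a common way to implement ViT patch embedding; whether the original model uses a convolution is an assumption. Prepending a [CLS] token, as in the original ViT, gives a sequence of N + 1 = 197 tokens.

```python
import torch
import torch.nn as nn

# Same values as the hyperparameter cell above (re-declared so this runs standalone)
B, C, H, W, P, D = 2, 3, 224, 224, 16, 768
N = (H // P) * (W // P)  # 196

torch.manual_seed(0)
x = torch.randn(B, C, H, W)

# Convolutional patch embedding: [B, C, H, W] -> [B, D, H/P, W/P] -> [B, N, D]
patch_embed = nn.Conv2d(C, D, kernel_size=P, stride=P)
tokens = patch_embed(x).flatten(2).transpose(1, 2)

# Equivalent view: cut non-overlapping P x P patches, flatten each to C*P*P values,
# then multiply by the conv kernel reshaped into a [D, C*P*P] matrix
patches = x.unfold(2, P, P).unfold(3, P, P)  # [B, C, H/P, W/P, P, P]
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, N, C * P * P)
linear_tokens = patches @ patch_embed.weight.reshape(D, -1).T + patch_embed.bias

print(tokens.shape)                                      # torch.Size([2, 196, 768])
print(torch.allclose(tokens, linear_tokens, atol=1e-4))  # True

# Prepend a [CLS] token (zeros here, learnable in a real model) -> N + 1 tokens
cls = torch.zeros(B, 1, D)
print(torch.cat([cls, tokens], dim=1).shape)             # torch.Size([2, 197, 768])
```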

Integration with Gradio & Hugging Face Spaces

Gradio Demo

A simple Gradio demo can be created to classify uploaded images:

import gradio as gr
import torch
import torchvision.transforms as T
from PIL import Image

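model_cifar.eval()  # model_cifar and device come from the loading step above

class_names_cifar = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck"
]

# Build the preprocessing pipeline once (it must match the transforms used in training)
preprocess = T.Compose([T.Resize((224, 224)), T.ToTensor()])

def predict_cifar(img):
    # Convert to RGB so RGBA or grayscale uploads still yield 3 channels
    x = preprocess(img.convert("RGB")).unsqueeze(0).to(device)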
    with torch.no_grad():
        logits = model_cifar(x)
    pred_id = torch.argmax(logits, dim=1).item()
    return f"Prediction: {class_names_cifar[pred_id]}"

gr.Interface(
    fn=predict_cifar,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="ViT on CIFAR-10"
).launch()

Hugging Face Hub

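You can push the trained weights (and, the same way, your model code) to the Hugging Face Hub. This requires authentication, e.g. by running huggingface-cli login or setting the HF_TOKEN environment variable:

from huggingface_hub import HfApi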

api = HfApi()
repo_id = "username/my-cifar-vit"

api.create_repo(repo_id=repo_id, exist_ok=True)
api.upload_file(
    path_or_fileobj="vit_cifar_model.pth",
    path_in_repo="vit_cifar_model.pth",
    repo_id=repo_id,
    repo_type="model"
)

Then create a Space with Gradio integration if you want a hosted web app.
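Inside the Space (or anywhere else), the uploaded weights can be downloaded with hf_hub_download and loaded into your model. A minimal sketch, assuming the file was uploaded as vit_cifar_model.pth to the placeholder repo username/my-cifar-vit shown above, and that model is an instance of the same ViT class used for training; load_vit_from_hub is a hypothetical helper name.

```python
import torch


def load_vit_from_hub(model, repo_id="username/my-cifar-vit",
                      filename="vit_cifar_model.pth", device="cpu"):
    """Download a state dict from the Hub (cached locally) and load it into model."""
    from huggingface_hub import hf_hub_download  # imported lazily: only needed here

    path = hf_hub_download(repo_id=repo_id, filename=filename)
    state_dict = torch.load(path, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device).eval()
```

Usage: model_cifar = load_vit_from_hub(ViT(), device=device). Later calls reuse the locally cached file instead of downloading it again.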


License

MIT License or any license of your choice.


Author


Dataset used to train Omarrran/HNM-Vision-model-cifar-vit