---
library_name: aim
pipeline_tag: image-classification
license: other
license_name: apple-sample-code-license
license_link: LICENSE
datasets:
- imagenet-1k
metrics:
- accuracy
tags:
- large-scale-vision-models
- pytorch
- mlx
- jax
- vision
- ssl
- pre-training
- DFN
---
# AIM: Autoregressive Image Models
*Alaaeldin El-Nouby, Michal Klein, Shuangfei Zhai, Miguel Angel Bautista, Alexander Toshev, Vaishaal Shankar,
Joshua M Susskind, and Armand Joulin*
This software project accompanies the research paper *Scalable Pre-training of Large Autoregressive Image Models*.
We introduce **AIM**, a collection of vision models pre-trained with an autoregressive generative objective.
We show that autoregressive pre-training of image features exhibits scaling properties similar to those of its
textual counterpart (i.e., Large Language Models). Specifically, we highlight two findings:
1. the model capacity can be trivially scaled to billions of parameters, and
2. AIM effectively leverages large collections of uncurated image data.
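
As a rough illustration of what an autoregressive objective over image patches means, the sketch below splits an image into non-overlapping patches in raster order and scores next-patch predictions with a mean squared error. This is a minimal NumPy sketch, not the AIM training code: the patch size, the per-patch normalization of the targets, and the hypothetical `predict_next` stand-in (which simply repeats the previous patch) are assumptions made for illustration. In practice a Transformer with causal attention would produce the predictions.

```python
import numpy as np


def patchify(image: np.ndarray, patch: int) -> np.ndarray:
    """Split an (H, W, C) image into raster-ordered, flattened patches of shape (K, patch*patch*C)."""
    h, w, c = image.shape
    assert h % patch == 0 and w % patch == 0, "image size must be divisible by the patch size"
    x = image.reshape(h // patch, patch, w // patch, patch, c)
    x = x.transpose(0, 2, 1, 3, 4)  # (patch rows, patch cols, patch, patch, C)
    return x.reshape(-1, patch * patch * c)


def normalize_patches(patches: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Assumed target normalization: zero mean and unit variance within each patch."""
    mean = patches.mean(axis=1, keepdims=True)
    std = patches.std(axis=1, keepdims=True)
    return (patches - mean) / (std + eps)


def predict_next(prefix: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for a causal model: predict the next patch as a copy of the last one seen."""
    return prefix[-1]


def autoregressive_loss(image: np.ndarray, patch: int = 4) -> float:
    """Mean squared error of predicting patch k from patches 0..k-1, for k = 1..K-1."""
    targets = normalize_patches(patchify(image, patch))
    preds = np.stack([predict_next(targets[:k]) for k in range(1, len(targets))])
    return float(np.mean((preds - targets[1:]) ** 2))
```

The key property is the ordering: each prediction only sees earlier patches, so the loss for a 16x16 image with 4x4 patches averages 15 next-patch errors. Swapping `predict_next` for a learned model is what turns this toy into a pre-training objective.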
## Installation
Please install PyTorch using the official [installation instructions](https://pytorch.org/get-started/locally/).
Afterward, install the package as:
```commandline
pip install git+https://git@github.com/apple/ml-aim.git
```
We also offer [MLX](https://github.com/ml-explore/mlx) backend support for research and experimentation on Apple silicon.
To enable MLX support, simply run:
```commandline
pip install mlx
```
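
The usage examples below pass a `backend` argument (`"torch"`, `"mlx"`, or `"jax"`) to `load_pretrained`, so it can help to check which frameworks are installed before choosing one. A minimal standard-library sketch; the helper names, the backend-to-module mapping, and the default preference order are hypothetical and not part of the `aim` package.

```python
import importlib.util
from typing import Optional, Sequence

# Hypothetical mapping from the backend names used in the examples to importable module names.
BACKEND_MODULES = {"torch": "torch", "mlx": "mlx", "jax": "jax"}


def available_backends() -> list[str]:
    """Return the backends whose framework package can be found in the current environment."""
    return [name for name, module in BACKEND_MODULES.items()
            if importlib.util.find_spec(module) is not None]


def pick_backend(preferred: Sequence[str] = ("torch", "mlx", "jax"),
                 available: Optional[Sequence[str]] = None) -> str:
    """Hypothetical helper: return the first preferred backend that is installed."""
    installed = available_backends() if available is None else list(available)
    for name in preferred:
        if name in installed:
            return name
    raise RuntimeError(f"none of {list(preferred)} is installed; found {installed}")
```

For example, `load_pretrained("aim-600M-2B-imgs", backend=pick_backend())`. Note that the JAX backend returns `(model, params)` rather than a single model, and that the MLX and JAX examples still import `val_transforms` from `aim.torch.data`, so PyTorch is needed for preprocessing regardless of the backend chosen.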
## Usage
Below we provide an example of usage in [PyTorch](https://pytorch.org/):
```python
from PIL import Image
from aim.utils import load_pretrained
from aim.torch.data import val_transforms
img = Image.open(...)
model = load_pretrained("aim-600M-2B-imgs", backend="torch")
transform = val_transforms()
inp = transform(img).unsqueeze(0)
logits, _ = model(inp)
```
and in [MLX](https://github.com/ml-explore/mlx):
```python
from PIL import Image
import mlx.core as mx
from aim.utils import load_pretrained
from aim.torch.data import val_transforms
img = Image.open(...)
model = load_pretrained("aim-600M-2B-imgs", backend="mlx")
transform = val_transforms()
inp = transform(img).unsqueeze(0)
inp = mx.array(inp.numpy())
logits, _ = model(inp)
```
and in JAX:
```python
from PIL import Image
import jax.numpy as jnp
from aim.utils import load_pretrained
from aim.torch.data import val_transforms
img = Image.open(...)
model, params = load_pretrained("aim-600M-2B-imgs", backend="jax")
transform = val_transforms()
inp = transform(img).unsqueeze(0)
inp = jnp.array(inp.numpy())
(logits, _), _ = model.apply(params, inp, mutable=['batch_stats'])
```
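
In all three backends, `logits` holds one row of class scores per image in the batch (a batch of one here). The sketch below turns one row into the top-k classes with softmax probabilities. It is framework-agnostic and assumes the row is first converted to a plain Python list, e.g. with `logits[0].tolist()`, which PyTorch tensors, MLX arrays, and JAX arrays all provide. The helper is hypothetical, and mapping indices to ImageNet-1k class names is left out.

```python
import math
from typing import Sequence


def top_k(logits: Sequence[float], k: int = 5) -> list[tuple[int, float]]:
    """Hypothetical helper: return the k (class_index, probability) pairs with the highest softmax probability."""
    m = max(logits)  # subtract the max logit so math.exp cannot overflow
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    return [(i, probs[i]) for i in ranked[:k]]
```

Subtracting the maximum logit leaves the softmax unchanged but keeps very large scores from overflowing `math.exp`.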
## Pre-trained checkpoints
The pre-trained models can be accessed either via [Hugging Face](https://huggingface.co/collections/apple/aim-65aa3ce948c718a574f09eb7):
```python
# after running pip install git+https://git@github.com/apple/ml-aim.git
from aim.torch.models import AIMForImageClassification
aim_600m = AIMForImageClassification.from_pretrained("apple/aim-600M")
aim_1b = AIMForImageClassification.from_pretrained("apple/aim-1B")
aim_3b = AIMForImageClassification.from_pretrained("apple/aim-3B")
aim_7b = AIMForImageClassification.from_pretrained("apple/aim-7B")
```
or [PyTorch Hub](https://pytorch.org/hub/) as:
```python
import torch
aim_600m = torch.hub.load("apple/ml-aim", "aim_600M")
aim_1b = torch.hub.load("apple/ml-aim", "aim_1B")
aim_3b = torch.hub.load("apple/ml-aim", "aim_3B")
aim_7b = torch.hub.load("apple/ml-aim", "aim_7B")
```
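
Both loading paths expose the same four checkpoints under parallel names. The sketch below picks the largest model that fits a parameter budget and derives its PyTorch Hub entrypoint name. The helpers are hypothetical: the parameter counts are copied from the backbone table below, and the repo-ID-to-entrypoint mapping is inferred from the four examples above rather than taken from a documented API.

```python
# Repo IDs from the Hugging Face snippet; parameter counts (billions) from the backbone table.
AIM_CHECKPOINTS = {
    "apple/aim-600M": 0.6,
    "apple/aim-1B": 1.0,
    "apple/aim-3B": 3.0,
    "apple/aim-7B": 7.0,
}


def largest_within_budget(max_params_b: float) -> str:
    """Hypothetical helper: repo ID of the largest AIM checkpoint with at most max_params_b billion parameters."""
    fitting = [repo for repo, n in AIM_CHECKPOINTS.items() if n <= max_params_b]
    if not fitting:
        raise ValueError(f"no AIM checkpoint has at most {max_params_b}B parameters")
    return max(fitting, key=AIM_CHECKPOINTS.__getitem__)


def hub_entrypoint(repo_id: str) -> str:
    """Hypothetical helper: map e.g. "apple/aim-1B" to the Hub entrypoint name "aim_1B"."""
    org, _, name = repo_id.partition("/")
    if org != "apple" or not name.startswith("aim-"):
        raise ValueError(f"unexpected AIM repo ID: {repo_id!r}")
    return name.replace("-", "_")
```

For example, `AIMForImageClassification.from_pretrained(largest_within_budget(2.0))` and `torch.hub.load("apple/ml-aim", hub_entrypoint(largest_within_budget(2.0)))` both select the 1B model.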
### Pre-trained backbones
The following table lists the pre-trained backbones used in our paper.

| model | #params | attn probe top-1 IN-1k (best layer) | backbone, SHA256 |
|---|---|---|---|
| AIM-0.6B | 0.6B | 79.4% | link, 0d6f6b8f |
| AIM-1B | 1B | 82.3% | link, d254ecd3 |
| AIM-3B | 3B | 83.3% | link, 8475ce4e |
| AIM-7B | 7B | 84.0% | link, 184ed94c |
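
The SHA256 column lists only the first 8 hexadecimal digits of each checkpoint's digest. Assuming these are prefixes of the SHA256 hex digest of the downloaded checkpoint file, a download can be checked with the standard library alone. The helper names are hypothetical.

```python
import hashlib
from pathlib import Path

# First 8 hex digits of each backbone's SHA256, copied from the table above.
BACKBONE_SHA256_PREFIXES = {
    "AIM-0.6B": "0d6f6b8f",
    "AIM-1B": "d254ecd3",
    "AIM-3B": "8475ce4e",
    "AIM-7B": "184ed94c",
}


def sha256_hex(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hypothetical helper: hash a (possibly multi-GB) file in fixed-size chunks to keep memory use flat."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_checkpoint(path: Path, expected_prefix: str) -> bool:
    """Hypothetical helper: True if the file's SHA256 hex digest starts with the listed prefix."""
    return sha256_hex(path).startswith(expected_prefix.lower())
```

For example, `verify_checkpoint(Path("/path/to/backbone_ckpt.pth"), BACKBONE_SHA256_PREFIXES["AIM-7B"])`. The same check applies to the attention-head checkpoints, whose prefixes are listed in the next table.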
### Pre-trained attention heads
The table below contains the classification results on the ImageNet-1k validation set.

| model | top-1 IN-1k (last layer) | top-1 IN-1k (best layer) | attention head (last layer), SHA256 | attention head (best layer), SHA256 |
|---|---|---|---|---|
| AIM-0.6B | 78.5% | 79.4% | link, 5ce5a341 | link, ebd45c05 |
| AIM-1B | 80.6% | 82.3% | link, db3be2ad | link, f1ed7852 |
| AIM-3B | 82.2% | 83.3% | link, 5c057b30 | link, ad380e16 |
| AIM-7B | 82.4% | 84.0% | link, 1e5c99ba | link, 73ecd732 |
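
To compare the two accuracy columns directly, the sketch below computes how many points of top-1 accuracy probing the best layer adds over the last layer for each model. The numbers are copied from the table above; the helper name is hypothetical.

```python
# Top-1 IN-1k accuracy (%) from the table above, as (last layer, best layer).
ATTN_PROBE_TOP1 = {
    "AIM-0.6B": (78.5, 79.4),
    "AIM-1B": (80.6, 82.3),
    "AIM-3B": (82.2, 83.3),
    "AIM-7B": (82.4, 84.0),
}


def best_layer_gain(results: dict[str, tuple[float, float]]) -> dict[str, float]:
    """Hypothetical helper: best-layer minus last-layer accuracy, rounded to one decimal to absorb float error."""
    return {model: round(best - last, 1) for model, (last, best) in results.items()}
```

Across the four models, probing the best layer adds between 0.9 and 1.7 points of top-1 accuracy.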
## Reproducing the IN-1k classification results
The command below reproduces the [attention probe results](#pre-trained-attention-heads) on the ImageNet-1k
validation set. We run the evaluation on a single node with 8 GPUs:
```commandline
torchrun --standalone --nnodes=1 --nproc-per-node=8 main_attnprobe.py \
--model=aim-7B \
--batch-size=64 \
--data-path=/path/to/imagenet \
--probe-layers=last \
--backbone-ckpt-path=/path/to/backbone_ckpt.pth \
--head-ckpt-path=/path/to/head_ckpt.pth
```
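By default, we probe the last 6 layers (`--probe-layers=last`, as in the command above). To evaluate the best-layer heads instead, pass `--probe-layers=best` together with the matching best-layer head checkpoint.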
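
To sweep several models or both probe settings, it can be convenient to build the command programmatically. The sketch below reproduces the flags of the command above as an argument list suitable for `subprocess.run`. The helper name is hypothetical, and restricting `probe_layers` to `last` or `best` reflects only the two values mentioned here, not the script's own validation.

```python
import shlex


def attnprobe_command(model: str, data_path: str, backbone_ckpt: str, head_ckpt: str,
                      probe_layers: str = "last", batch_size: int = 64, gpus: int = 8) -> list[str]:
    """Hypothetical helper: build the single-node torchrun invocation shown above as an argv list."""
    if probe_layers not in ("last", "best"):
        raise ValueError("probe_layers must be 'last' or 'best'")
    return [
        "torchrun", "--standalone", "--nnodes=1", f"--nproc-per-node={gpus}", "main_attnprobe.py",
        f"--model={model}",
        f"--batch-size={batch_size}",
        f"--data-path={data_path}",
        f"--probe-layers={probe_layers}",
        f"--backbone-ckpt-path={backbone_ckpt}",
        f"--head-ckpt-path={head_ckpt}",
    ]
```

`print(shlex.join(cmd))` shows the shell form for inspection, and `subprocess.run(cmd, check=True)` runs it from the repository root. Passing the argument list directly avoids shell quoting issues with unusual paths.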