---
license: mit
datasets:
- imagenet1k
metrics:
- accuracy
---

# VGG-like Kolmogorov-Arnold Convolutional network with Gram polynomials

This model is a convolutional version of the Kolmogorov-Arnold Network (KAN) with a VGG-11-like architecture, pretrained on the Imagenet1k dataset. KANs were originally presented in [1, 2]. The Gram version of KAN was originally presented in [3]. For more details, visit our [torch-conv-kan](https://github.com/IvanDrokin/torch-conv-kan) repository on GitHub.

## Model description

The model consists of 10 consecutive Gram ConvKAN layers with InstanceNorm2d and polynomial degree 5, followed by global average pooling and a linear classification head:

1. KAGN Convolution, 32 filters, 3x3
2. Max pooling, 2x2
3. KAGN Convolution, 64 filters, 3x3
4. Max pooling, 2x2
5. KAGN Convolution, 128 filters, 3x3
6. KAGN Convolution, 128 filters, 3x3
7. Max pooling, 2x2
8. KAGN Convolution, 256 filters, 3x3
9. KAGN Convolution, 256 filters, 3x3
10. Max pooling, 2x2
11. KAGN Convolution, 256 filters, 3x3
12. KAGN Convolution, 256 filters, 3x3
13. Max pooling, 2x2
14. KAGN Convolution, 512 filters, 3x3
15. KAGN Convolution, 512 filters, 3x3
16. Global Average pooling
17. Output layer, 1000 nodes.

[Figure: model image]

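As a rough cross-check of the layer list above, the sequence can be written as a plain configuration and summarized. This is a minimal sketch, not code from the repository: the name `VGG11_V4_LAYERS` and the helper `summarize` are hypothetical, the filter counts are copied from the list as written, and it assumes that the convolutions preserve spatial size while each 2x2 max pooling halves it.

```python
# Layer sequence transcribed from the list above: an int is a 3x3 KAGN
# convolution with that many filters, "M" is a 2x2 max pooling.
# (Hypothetical name, not taken from the torch-conv-kan repository.)
VGG11_V4_LAYERS = [32, "M", 64, "M", 128, 128, "M", 256, 256, "M",
                   256, 256, "M", 512, 512]


def summarize(layers, input_size=224):
    """Count convolutions and track the spatial size through the poolings.

    Assumes convolutions keep the spatial size (e.g. padding=1) and each
    2x2 max pooling halves it (stride 2, floor division).
    """
    conv_filters = [layer for layer in layers if layer != "M"]
    size = input_size
    for layer in layers:
        if layer == "M":
            size //= 2
    # Number of convolutions, final spatial size, final channel count.
    return len(conv_filters), size, conv_filters[-1]


n_conv, size, channels = summarize(VGG11_V4_LAYERS)
print(f"{n_conv} KAGN convolutions, final feature map {size}x{size}x{channels}")
```

Under these assumptions, a 224x224 input passes through 10 KAGN convolutions and 5 poolings before global average pooling reduces the remaining feature map to a single vector for the output layer.
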
## Intended uses & limitations

You can use the raw model for image classification or as a pretrained model for further fine-tuning.

### How to use

First, clone the repository and install its dependencies:

```
git clone https://github.com/IvanDrokin/torch-conv-kan.git
cd torch-conv-kan
pip install -r requirements.txt
```

Then you can initialize the model and load the weights:

```python
import torch
from models import vggkagn  # provided by the cloned torch-conv-kan repository

# Build the VGG11v4 KAGN model: 3 input channels, 1000 output classes.
model = vggkagn(3,
                1000,
                groups=1,
                degree=5,
                dropout=0.15,
                l1_decay=0,
                dropout_linear=0.25,
                width_scale=2,
                vgg_type='VGG11v4',
                expected_feature_shape=(1, 1),
                affine=True
                )
# Load the pretrained weights.
model.from_pretrained('brivangl/vgg_kagn11_v4')
```

Transforms used for validation on Imagenet1k:

```python
import torch
from torchvision.transforms import v2

transforms_val = v2.Compose([
    v2.ToImage(),
    v2.Resize(256, antialias=True),
    v2.CenterCrop(224),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```

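To make the geometry of the validation transforms concrete, the sketch below reproduces the arithmetic of `Resize(256)` followed by `CenterCrop(224)` and of the per-channel normalization in plain Python. It is an illustration only: the helper names `resize_then_center_crop` and `normalize_pixel` are hypothetical, and the exact rounding inside torchvision may differ by a pixel in edge cases.

```python
def resize_then_center_crop(width, height, resize_to=256, crop=224):
    """Return the resized (width, height) and the crop box (left, top, right, bottom).

    The shorter side is scaled to `resize_to`, the longer side keeps the
    aspect ratio, and a `crop` x `crop` window is taken from the center.
    """
    short, long = min(width, height), max(width, height)
    new_long = int(resize_to * long / short)
    if width <= height:
        new_w, new_h = resize_to, new_long
    else:
        new_w, new_h = new_long, resize_to
    left = round((new_w - crop) / 2)
    top = round((new_h - crop) / 2)
    return (new_w, new_h), (left, top, left + crop, top + crop)


def normalize_pixel(value, mean, std):
    """Scale an 8-bit value to [0, 1], then standardize with mean and std."""
    return (value / 255.0 - mean) / std


resized, box = resize_then_center_crop(500, 400)
print("resized to", resized, "crop box", box)
print("white red-channel pixel ->", normalize_pixel(255, mean=0.485, std=0.229))
```

Whatever the input resolution, the model therefore always receives a 224x224 crop from the center of the image.
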
## Training data

This model was trained on the Imagenet1k dataset (1,281,167 images in the training set).

## Training procedure

The model was trained for 200 full epochs with the AdamW optimizer, using the following parameters:

```python
{'learning_rate': 0.0009, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 5e-06,
 'adam_epsilon': 1e-08, 'lr_warmup_steps': 7500, 'lr_power': 0.3, 'lr_end': 1e-07, 'set_grads_to_none': False}
```

And the following augmentations:

```python
import torch
from torchvision.transforms import AutoAugmentPolicy, v2

transforms_train = v2.Compose([
    v2.ToImage(),
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomResizedCrop(224, antialias=True),
    # Randomly pick one of the two AutoAugment policies for each sample.
    v2.RandomChoice([v2.AutoAugment(AutoAugmentPolicy.CIFAR10),
                     v2.AutoAugment(AutoAugmentPolicy.IMAGENET)
                     ]),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```

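The `lr_warmup_steps`, `lr_power`, and `lr_end` parameters above suggest a linear warmup followed by a polynomial decay, although the card does not name the scheduler. The sketch below shows what such a schedule could look like under that assumption. The function name `polynomial_lr` is hypothetical, and `total_steps` is an illustrative value, since the batch size and the number of steps per epoch are not stated here.

```python
def polynomial_lr(step, total_steps, base_lr=9e-4, warmup_steps=7500,
                  power=0.3, lr_end=1e-7):
    """Learning rate at `step`: linear warmup, then polynomial decay.

    Assumed schedule shape; the parameter defaults are the values listed
    in the training parameters above.
    """
    if step < warmup_steps:
        # Linear ramp from 0 to base_lr over the warmup steps.
        return base_lr * step / max(1, warmup_steps)
    if step >= total_steps:
        return lr_end
    # Fraction of the decay phase that is still ahead, in (0, 1].
    remaining = 1.0 - (step - warmup_steps) / (total_steps - warmup_steps)
    return (base_lr - lr_end) * remaining ** power + lr_end


TOTAL_STEPS = 100_000  # hypothetical value, for illustration only
for step in (0, 3_750, 7_500, 50_000, 100_000):
    print(step, polynomial_lr(step, TOTAL_STEPS))
```

With a power below 1, such a schedule keeps the learning rate relatively high for most of training and drops it sharply near the end.
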
## Evaluation results

On Imagenet1k Validation:

| Accuracy, top1 (%) | Accuracy, top5 (%) | AUC (ovo, %) | AUC (ovr, %) |
|:------------------:|:------------------:|:------------:|:------------:|
| 61.17 | 83.26 | 99.42 | 99.43 |

On Imagenet1k Test:

Coming soon

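For readers unfamiliar with the top-1 and top-5 accuracy figures reported above, the sketch below shows how top-k accuracy is typically computed from per-class scores. It is a generic illustration in plain Python, not the evaluation code used for this model. The helper name `topk_accuracy` is hypothetical and the scores are made up.

```python
def topk_accuracy(scores, labels, k=1):
    """Percentage of samples whose true label is among the k highest scores."""
    hits = 0
    for row, label in zip(scores, labels):
        # Indices of the k classes with the largest scores for this sample.
        topk = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += label in topk
    return 100.0 * hits / len(labels)


# Toy example: 3 samples, 4 classes (made-up scores, not model output).
scores = [
    [0.1, 0.6, 0.2, 0.1],  # highest score: class 1
    [0.5, 0.1, 0.3, 0.1],  # highest score: class 0
    [0.2, 0.3, 0.4, 0.1],  # highest score: class 2
]
labels = [1, 2, 1]

print("top-1:", topk_accuracy(scores, labels, k=1))
print("top-2:", topk_accuracy(scores, labels, k=2))
```

Top-5 accuracy on Imagenet1k follows the same rule with `k=5` over the 1000 output classes.
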
### BibTeX entry and citation info

If you use this project in your research or wish to refer to the baseline results, please use the following BibTeX entry.

```bibtex
@misc{torch-conv-kan,
  author = {Ivan Drokin},
  title = {Torch Conv KAN},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/IvanDrokin/torch-conv-kan}}
}
```

## References

- [1] Ziming Liu et al., "KAN: Kolmogorov-Arnold Networks", 2024, arXiv. https://arxiv.org/abs/2404.19756
- [2] https://github.com/KindXiaoming/pykan
- [3] https://github.com/Khochawongwat/GRAMKAN