---
tags:
- pytorch_model_hub_mixin
- model_hub_mixin
---

### How to use

First, clone the repository and install its requirements:

```
git clone https://github.com/IvanDrokin/torch-conv-kan.git
cd torch-conv-kan
pip install -r requirements.txt
```

Then, from the repository root, initialize the model and load the pretrained weights from the Hugging Face Hub:

```python
import torch
import torch.nn as nn  # needed for nn.BatchNorm2d below

from models import VGGKAGN_BN

# from_pretrained downloads the checkpoint from the Hub; the keyword arguments
# configure the VGG11v4 KAGN architecture so that it matches the saved weights.
model = VGGKAGN_BN.from_pretrained('brivangl/vgg_kagn_bn11_v4_opt',
                                   groups=1,
                                   degree=3,
                                   dropout=0.05,
                                   l1_decay=0,
                                   width_scale=3,
                                   affine=True,
                                   norm_layer=nn.BatchNorm2d,
                                   expected_feature_shape=(1, 1),
                                   vgg_type='VGG11v4')
```
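
Once loaded, the model can be used for classification. The following is a minimal sketch under stated assumptions: it assumes the model maps a preprocessed float batch of shape `(N, 3, 224, 224)` to raw logits of shape `(N, 1000)`, and `predict_topk` is a hypothetical helper, not part of the repository. Calling `model.eval()` matters because this architecture uses dropout and batch normalization (`dropout=0.05`, `norm_layer=nn.BatchNorm2d`). A tiny stand-in network replaces the pretrained model so the snippet runs on its own.

```python
import torch
import torch.nn as nn


def predict_topk(model: nn.Module, batch: torch.Tensor, k: int = 5):
    """Return (probabilities, class indices) of the k most likely classes.

    Hypothetical helper. Assumes `model` maps a float batch of shape
    (N, 3, H, W) to raw logits of shape (N, num_classes).
    """
    model.eval()  # disable dropout, use BatchNorm running statistics
    with torch.no_grad():  # no autograd graph is needed for inference
        logits = model(batch)
    probs = torch.softmax(logits, dim=1)  # logits -> class probabilities
    return torch.topk(probs, k=k, dim=1)  # sorted in descending order


if __name__ == "__main__":
    # Stand-in classifier for illustration only (NOT the pretrained model).
    toy = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
    values, indices = predict_topk(toy, torch.randn(2, 3, 8, 8), k=5)
    print(values.shape, indices.shape)
```

With the real model, the input batch would come from the validation transforms below, e.g. `predict_topk(model, transforms_val(img).unsqueeze(0))`.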

Transforms used for validation on ImageNet-1k:

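```python
import torch  # needed for torch.float32 below
from torchvision.transforms import v2

transforms_val = v2.Compose([
    v2.ToImage(),                           # PIL image / ndarray -> uint8 CHW tensor
    v2.Resize(256, antialias=True),         # resize the shorter side to 256 px
    v2.CenterCrop(224),                     # take the central 224x224 crop
    v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```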