---
tags:
- pytorch_model_hub_mixin
- model_hub_mixin
---

### How to use

First, clone the repository:

```bash
git clone https://github.com/IvanDrokin/torch-conv-kan.git
cd torch-conv-kan
pip install -r requirements.txt
```
Then initialize the model and load the pretrained weights from the Hugging Face Hub:

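```python
import torch
from torch import nn  # required for norm_layer=nn.BatchNorm2d

# Run from the torch-conv-kan repository root so that `models` is importable.
from models import VGGKAGN_BN

model = VGGKAGN_BN.from_pretrained('brivangl/vgg_kagn_bn11_v4_opt',
                                   groups=1,
                                   degree=3,
                                   dropout=0.05,
                                   l1_decay=0,
                                   width_scale=3,
                                   affine=True,
                                   norm_layer=nn.BatchNorm2d,
                                   expected_feature_shape=(1, 1),
                                   vgg_type='VGG11v4')
```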

Transforms used for validation on ImageNet-1k:

```python
import torch
from torchvision.transforms import v2

transforms_val = v2.Compose([
    v2.ToImage(),
    v2.Resize(256, antialias=True),
    v2.CenterCrop(224),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```
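To make the validation pipeline concrete, here is a stdlib-only sketch of the geometry and arithmetic it performs: the shorter edge is resized to 256 pixels with the aspect ratio preserved, a 224×224 window is cropped from the center, and each uint8 channel value is scaled to [0, 1] and then normalized as `(x - mean) / std`. The helper names (`resized_size`, `center_crop_box`, `normalize_pixel`) are hypothetical, and the rounding choices are assumptions that may differ in small ways from torchvision's internals.

```python
# Hypothetical helpers: a stdlib-only sketch of what the validation transforms do.
# Rounding details are assumptions and may not match torchvision exactly.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def resized_size(width: int, height: int, short_side: int = 256) -> tuple[int, int]:
    """Scale so the shorter edge equals short_side, keeping the aspect ratio.

    Assumes the longer edge is truncated to an int.
    """
    if width <= height:
        return short_side, int(short_side * height / width)
    return int(short_side * width / height), short_side


def center_crop_box(width: int, height: int, crop: int = 224) -> tuple[int, int, int, int]:
    """Return (left, top, right, bottom) of a centered crop x crop window."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


def normalize_pixel(rgb: tuple[int, int, int]) -> tuple[float, ...]:
    """Scale uint8 values to [0, 1] (like ToDtype(scale=True)), then apply (x - mean) / std."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD))


if __name__ == "__main__":
    w, h = resized_size(600, 400)      # a 3:2 landscape image
    print((w, h))                      # (384, 256)
    print(center_crop_box(w, h))       # (80, 16, 304, 240)
    print(normalize_pixel((255, 255, 255)))
```

Because only the shorter edge is fixed to 256, the 224×224 crop always fits inside the resized image, and the per-channel statistics are the standard ImageNet mean and standard deviation listed in the transforms above.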