Doven committed on
Commit f7009b3 · 1 Parent(s): 6abdfdc

update code.

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitignore +10 -0
  2. README.md +79 -3
  3. checkpoint/generalization.pth +0 -0
  4. dataset/__init__.py +1 -0
  5. dataset/cifar100_resnet18bn/model.py +27 -0
  6. dataset/cifar100_resnet18bn/prepare.py +192 -0
  7. dataset/cifar100_resnet18bn/test.py +28 -0
  8. dataset/cifar100_resnet18bn/train.py +195 -0
  9. dataset/cifar10_cnnmedium/model.py +48 -0
  10. dataset/cifar10_cnnmedium/test.py +28 -0
  11. dataset/cifar10_cnnmedium/train.py +192 -0
  12. dataset/cifar10_cnnsmall/model.py +48 -0
  13. dataset/cifar10_cnnsmall/test.py +28 -0
  14. dataset/cifar10_cnnsmall/train.py +192 -0
  15. dataset/cifar10_mobilenetv3/model.py +21 -0
  16. dataset/cifar10_mobilenetv3/test.py +28 -0
  17. dataset/cifar10_mobilenetv3/train.py +199 -0
  18. dataset/cifar10_resnet18/model.py +17 -0
  19. dataset/cifar10_resnet18/test.py +28 -0
  20. dataset/cifar10_resnet18/train.py +191 -0
  21. dataset/cifar10_vitbase/model.py +17 -0
  22. dataset/cifar10_vitbase/test.py +28 -0
  23. dataset/cifar10_vitbase/train.py +199 -0
  24. dataset/condition_classinput_inference/dataset.py +41 -0
  25. dataset/condition_classinput_inference/model.py +25 -0
  26. dataset/condition_classinput_inference/test.py +30 -0
  27. dataset/condition_classinput_inference/train.py +209 -0
  28. dataset/condition_classinput_vittiny/dataset.py +41 -0
  29. dataset/condition_classinput_vittiny/detail.py +58 -0
  30. dataset/condition_classinput_vittiny/finetune.py +215 -0
  31. dataset/condition_classinput_vittiny/model.py +25 -0
  32. dataset/condition_classinput_vittiny/split.sh +28 -0
  33. dataset/condition_classinput_vittiny/test.py +30 -0
  34. dataset/condition_classinput_vittiny/train.py +212 -0
  35. dataset/condition_classinput_vittiny/train.sh +10 -0
  36. dataset/condition_imageinput_vittiny/README.md +1 -0
  37. dataset/condition_imageinput_vittiny/dataset.py +46 -0
  38. dataset/condition_imageinput_vittiny/model.py +18 -0
  39. dataset/condition_imageinput_vittiny/test.py +30 -0
  40. dataset/condition_imageinput_vittiny/train.py +208 -0
  41. dataset/condition_imageinput_vittiny/train.sh +11 -0
  42. dataset/condition_permutation_vittiny/model.py +18 -0
  43. dataset/condition_permutation_vittiny/test.py +31 -0
  44. dataset/condition_permutation_vittiny/train.py +210 -0
  45. dataset/condition_permutation_vittiny/train.sh +10 -0
  46. dataset/config.json +1 -0
  47. dataset/dataset.py +327 -0
  48. dataset/downtask_detection/README.md +1 -0
  49. dataset/downtask_detection/test.sh +11 -0
  50. dataset/downtask_dora_r16/adapter_config.json +23 -0
.gitignore ADDED
@@ -0,0 +1,10 @@
1
+ /.idea
2
+ /.vscode
3
+ **/checkpoint*/
4
+ **/__pycache__/
5
+ **/generated*/
6
+ **/wandb/
7
+ **/full_model.pth
8
+ /rubbish
9
+ **/*cache*
10
+ /workspace/classinput/Qwen25llm/
README.md CHANGED
@@ -1,3 +1,79 @@
- ---
- license: unknown
- ---
+ # Recurrent Parameter Generation
+ The official repository of the paper [Recurrent Diffusion for Large-Scale Parameter Generation]().
+
+
+ ## Introduction
+ Parameter generation has long struggled to scale, significantly limiting its applications.
+ In this study, we introduce Recurrent diffusion for large-scale Parameter Generation, or RPG,
+ which models large-scale parameter generation through a recurrent diffusion process.
+ We divide the trained parameters into non-overlapping parts and propose a recurrent model to learn their relationships.
+ The outputs of this recurrent model, serving as conditions, are then fed into a diffusion model to generate neural network parameters.
+ Using only a single GPU, our method can generate parameters for popular vision and language models, such as ConvNeXt-L and LoRA parameters for LLaMA-7B.
+ Across various architectures and tasks, the generated parameters consistently achieve performance comparable to that of trained networks.
+ Additionally, our approach demonstrates potential in generating models capable of handling unseen tasks,
+ indicating that recurrent diffusion greatly enhances the practicality of parameter generation.
+
+
+
+
+
+
+
+
+ ## Environment
+ Before you get started, you need to set up a conda environment.
+ 1. Create your conda environment.
+ ```shell
+ conda create -n rpg python=3.11
+ conda activate rpg
+ conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=12.1 -c pytorch -c nvidia
+ ```
+ 2. Install mamba-ssm. (You may run into compilation issues; refer to the [official mamba-ssm repository](https://github.com/state-spaces/mamba) for details.)
+ ```shell
+ pip install mamba-ssm[causal-conv1d]
+ pip install causal-conv1d
+ ```
+ 3. Install the other dependencies for this repository.
+ ```shell
+ git lfs install
+ git clone https://huggingface.co/MTDoven/Recurrent-Parameter-Generation
+ cd Recurrent-Parameter-Generation
+ pip install -r requirements.txt
+ ```
+
+
+
+
+ ## Quick Start
+ 1. Modify your config file.
+ ```shell
+ # Set up your configs interactively.
+ python ./workspace/set_configs.py
+ ```
+
+ 2. Download checkpoint datasets.
+ ```shell
+ # Download the ViTTiny1022 dataset to /path/to/your/download/ViTTiny1022
+ mv /path/to/your/download/ViTTiny1022/* ./dataset/condition_classinput_vittiny/
+ ```
+
+ 3. Generate with the RPG model.
+ ```shell
+ cd ./workspace
+ CUDA_VISIBLE_DEVICES=0 python ./classinput/launch.py
+ # CUDA_VISIBLE_DEVICES=<GPU_index> python ./classinput/launch.py
+ ```
+
+ You can find more information on [GitHub](https://github.com/NUS-HPC-AI-Lab/Recurrent-Parameter-Generation).
+
+
+
+
+ ## Acknowledgment
+ coming soon...
+
+
+ ## Citation
+ coming soon...
+
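To make the Introduction in the README above more concrete, here is a minimal sketch of the described pipeline: trained parameters are split into non-overlapping chunks, a recurrent model reads the chunk sequence and emits one condition vector per chunk, and a conditional denoiser maps noise plus that condition to generated parameter chunks. Every module name, size, and the single-step denoiser below are illustrative assumptions for readability, not this repository's actual implementation.

```python
# Illustrative sketch only: a recurrent model emits one condition per parameter
# chunk, and a conditional denoiser maps (noisy chunk, condition) to a generated
# chunk. Names and sizes are assumptions, not the repository's real code.
import torch
import torch.nn as nn


class RecurrentConditioner(nn.Module):
    """Produces one condition vector per parameter chunk with a GRU."""
    def __init__(self, chunk_size=256, hidden=512):
        super().__init__()
        self.gru = nn.GRU(chunk_size, hidden, batch_first=True)

    def forward(self, chunks):            # chunks: (B, num_chunks, chunk_size)
        conditions, _ = self.gru(chunks)  # (B, num_chunks, hidden)
        return conditions


class ConditionalDenoiser(nn.Module):
    """Maps a noisy chunk plus its condition back to a parameter chunk."""
    def __init__(self, chunk_size=256, hidden=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(chunk_size + hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, chunk_size),
        )

    def forward(self, noisy_chunks, conditions):
        return self.net(torch.cat([noisy_chunks, conditions], dim=-1))


if __name__ == "__main__":
    batch, num_chunks, chunk_size = 2, 8, 256
    flat_params = torch.randn(batch, num_chunks * chunk_size)   # pretend trained parameters
    chunks = flat_params.view(batch, num_chunks, chunk_size)    # non-overlapping parts
    conditioner, denoiser = RecurrentConditioner(), ConditionalDenoiser()
    conditions = conditioner(chunks)                            # recurrent relationships
    noisy = torch.randn_like(chunks)                            # single-step "diffusion" stand-in
    generated = denoiser(noisy, conditions)                     # generated parameter chunks
    print(generated.shape)                                      # torch.Size([2, 8, 256])
```

In the actual method, the mamba-ssm dependency listed under Environment suggests a state-space recurrent backbone rather than the GRU used here, and a full diffusion process replaces the single denoising step.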
checkpoint/generalization.pth ADDED
File without changes
dataset/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .register import *
dataset/cifar100_resnet18bn/model.py ADDED
@@ -0,0 +1,27 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ import timm
4
+ import os
5
+
6
+
7
+ def Model():
8
+ model = timm.create_model("resnet18", pretrained=True)
9
+ model.fc = nn.Linear(512, 100)
10
+ if os.path.exists(os.path.join(os.path.dirname(__file__), "full_model.pth")):
11
+ model.load_state_dict(torch.load(os.path.join(os.path.dirname(__file__), "full_model.pth"), map_location="cpu"))
12
+ for k, v in model.named_parameters():
13
+ if k in ["layer4.1.bn1.weight", "layer4.1.bn1.bias", "layer4.1.bn2.weight", "layer4.1.bn2.bias"]:
14
+ v.requires_grad = True
15
+ else: # requires_grad = False
16
+ v.requires_grad = False
17
+ return model, model.fc
18
+
19
+
20
+ if __name__ == "__main__":
21
+ model, _ = Model()
22
+ print(model)
23
+ num_param = 0
24
+ for k, v in model.named_parameters():
25
+ num_param += v.numel()
26
+ print(k)
27
+ print("num_param:", num_param)
dataset/cifar100_resnet18bn/prepare.py ADDED
@@ -0,0 +1,192 @@
1
+ # set global seed
2
+ import random
3
+ import numpy as np
4
+ import torch
5
+ seed = SEED = 20
6
+ torch.manual_seed(seed)
7
+ torch.cuda.manual_seed(seed)
8
+ torch.cuda.manual_seed_all(seed)
9
+ torch.backends.cudnn.deterministic = True
10
+ torch.backends.cudnn.benchmark = True
11
+ np.random.seed(seed)
12
+ random.seed(seed)
13
+
14
+
15
+ try: # relative import
16
+ from model import Model
17
+ except ImportError:
18
+ from .model import Model
19
+
20
+ # import
21
+ import torch.nn as nn
22
+ from torch import optim
23
+ from torch.optim import lr_scheduler
24
+ from torch.utils.data import DataLoader
25
+ import torchvision.transforms as transforms
26
+ from torchvision.datasets import CIFAR100 as Dataset
27
+ from tqdm.auto import tqdm
28
+ import os
29
+ import warnings
30
+ warnings.filterwarnings("ignore", category=UserWarning)
31
+
32
+ # load additional config
33
+ import json
34
+ config_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), "config.json")
35
+ with open(config_file, "r") as f:
36
+ additional_config = json.load(f)
37
+
38
+
39
+
40
+
41
+ # config
42
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
+ config = {
44
+ "dataset_root": "from_additional_config",
45
+ "batch_size": 500 if __name__ == "__main__" else 200,
46
+ "num_workers": 32,
47
+ "learning_rate": 0.0005,
48
+ "weight_decay": 0.000005,
49
+ "epochs": 200,
50
+ "save_learning_rate": 0.0,
51
+ "total_save_number": 1,
52
+ "tag": os.path.basename(os.path.dirname(__file__)),
53
+ }
54
+ config.update(additional_config)
55
+
56
+
57
+
58
+
59
+ # Data
60
+ dataset = Dataset(
61
+ root=config["dataset_root"],
62
+ download=True,
63
+ train=True,
64
+ transform=transforms.Compose([
65
+ transforms.Resize(80),
66
+ transforms.RandomHorizontalFlip(),
67
+ transforms.AutoAugment(policy=transforms.AutoAugmentPolicy("cifar10")),
68
+ transforms.ToTensor(),
69
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
70
+ ])
71
+ )
72
+ train_loader = DataLoader(
73
+ dataset=dataset,
74
+ batch_size=config["batch_size"],
75
+ num_workers=config["num_workers"],
76
+ shuffle=True,
77
+ drop_last=True,
78
+ pin_memory=True,
79
+ )
80
+ test_loader = DataLoader(
81
+ dataset=Dataset(
82
+ root=config["dataset_root"],
83
+ download=True,
84
+ train=False,
85
+ transform=transforms.Compose([
86
+ transforms.Resize(80),
87
+ transforms.ToTensor(),
88
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
89
+ ])),
90
+ batch_size=config["batch_size"],
91
+ num_workers=config["num_workers"],
92
+ shuffle=False,
93
+ pin_memory=True,
94
+ )
95
+
96
+ # Model
97
+ model, head = Model()
98
+ model = model.to(device)
99
+ criterion = nn.CrossEntropyLoss()
100
+ pre_optimizer = optim.AdamW(
101
+ head.parameters(),
102
+ lr=0.001,
103
+ weight_decay=config["weight_decay"],
104
+ )
105
+ optimizer = optim.AdamW(
106
+ model.parameters(),
107
+ lr=config["learning_rate"],
108
+ weight_decay=config["weight_decay"],
109
+ )
110
+ scheduler = lr_scheduler.CosineAnnealingLR(
111
+ optimizer,
112
+ T_max=config["epochs"],
113
+ eta_min=config["save_learning_rate"],
114
+ )
115
+
116
+
117
+
118
+
119
+ # Training
120
+ def train(model=model, optimizer=optimizer, scheduler=scheduler):
121
+ model.train()
122
+ for batch_idx, (inputs, targets) in tqdm(enumerate(train_loader),
123
+ total=len(dataset) // config["batch_size"]):
124
+ inputs, targets = inputs.to(device), targets.to(device)
125
+ optimizer.zero_grad()
126
+ with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
127
+ outputs = model(inputs)
128
+ loss = criterion(outputs, targets)
129
+ loss.backward()
130
+ optimizer.step()
131
+ if scheduler is not None:
132
+ scheduler.step()
133
+
134
+ # test
135
+ @torch.no_grad()
136
+ def test(model=model):
137
+ model.eval()
138
+ all_targets = []
139
+ all_predicts = []
140
+ test_loss = 0
141
+ correct = 0
142
+ total = 0
143
+ for batch_idx, (inputs, targets) in tqdm(enumerate(test_loader),
144
+ total=len(test_loader.dataset) // config["batch_size"]):
145
+ inputs, targets = inputs.to(device), targets.to(device)
146
+ with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
147
+ outputs = model(inputs)
148
+ loss = criterion(outputs, targets)
149
+ # to logging losses
150
+ all_targets.extend(targets.flatten().tolist())
151
+ test_loss += loss.item()
152
+ _, predicts = outputs.max(1)
153
+ all_predicts.extend(predicts.flatten().tolist())
154
+ total += targets.size(0)
155
+ correct += predicts.eq(targets).sum().item()
156
+ loss = test_loss / (batch_idx + 1)
157
+ acc = correct / total
158
+ print(f"Loss: {loss:.4f} | Acc: {acc:.4f}")
159
+ model.train()
160
+ return loss, acc, all_targets, all_predicts
161
+
162
+ # save train
163
+ def save_train(model=model, optimizer=optimizer):
164
+ model.train()
165
+ for batch_idx, (inputs, targets) in enumerate(train_loader):
166
+ inputs, targets = inputs.to(device), targets.to(device)
167
+ optimizer.zero_grad()
168
+ with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
169
+ outputs = model(inputs)
170
+ loss = criterion(outputs, targets)
171
+ loss.backward()
172
+ optimizer.step()
173
+ # Save checkpoint
174
+ _, acc, _, _ = test(model=model)
175
+ if not os.path.isdir('checkpoint'):
176
+ os.mkdir('checkpoint')
177
+ save_state = {key: value.cpu().to(torch.float32) for key, value in model.state_dict().items()}
178
+ torch.save(save_state, f"full_model.pth")
179
+ print("save:", f"full_model.pth")
180
+
181
+
182
+
183
+
184
+ # main
185
+ if __name__ == '__main__':
186
+ test(model=model)
187
+ train(model=model, optimizer=pre_optimizer, scheduler=scheduler)
188
+ train(model=model, optimizer=pre_optimizer, scheduler=scheduler)
189
+ for epoch in range(config["epochs"]):
190
+ train(model=model, optimizer=optimizer, scheduler=scheduler)
191
+ test(model=model)
192
+ save_train(model=model, optimizer=optimizer)
dataset/cifar100_resnet18bn/test.py ADDED
@@ -0,0 +1,28 @@
1
+ import os
2
+ import sys
3
+ if __name__ == "__main__":
4
+ from train import *
5
+ else: # relative import
6
+ from .train import *
7
+
8
+
9
+
10
+
11
+ try:
12
+ test_item = sys.argv[1]
13
+ except IndexError:
14
+ assert __name__ == "__main__"
15
+ test_item = "./checkpoint"
16
+ test_items = []
17
+ if os.path.isdir(test_item):
18
+ for item in os.listdir(test_item):
19
+ item = os.path.join(test_item, item)
20
+ test_items.append(item)
21
+ elif os.path.isfile(test_item):
22
+ test_items.append(test_item)
23
+
24
+
25
+ for item in test_items:
26
+ state = torch.load(item, map_location="cpu")
27
+ model.load_state_dict({key: value.to(torch.float32).to(device) for key, value in state.items()}, strict=False)
28
+ loss, acc, all_targets, all_predicts = test(model=model)
dataset/cifar100_resnet18bn/train.py ADDED
@@ -0,0 +1,195 @@
1
+ # set global seed
2
+ import random
3
+ import numpy as np
4
+ import torch
5
+ seed = SEED = 20
6
+ torch.manual_seed(seed)
7
+ torch.cuda.manual_seed(seed)
8
+ torch.cuda.manual_seed_all(seed)
9
+ torch.backends.cudnn.deterministic = True
10
+ torch.backends.cudnn.benchmark = True
11
+ np.random.seed(seed)
12
+ random.seed(seed)
13
+
14
+
15
+ try: # relative import
16
+ from model import Model
17
+ except ImportError:
18
+ from .model import Model
19
+
20
+ # import
21
+ import torch.nn as nn
22
+ from torch import optim
23
+ from torch.optim import lr_scheduler
24
+ from torch.utils.data import DataLoader
25
+ import torchvision.transforms as transforms
26
+ from torchvision.datasets import CIFAR100 as Dataset
27
+ from tqdm.auto import tqdm
28
+ import os
29
+ import warnings
30
+ warnings.filterwarnings("ignore", category=UserWarning)
31
+
32
+ # load additional config
33
+ import json
34
+ config_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), "config.json")
35
+ with open(config_file, "r") as f:
36
+ additional_config = json.load(f)
37
+
38
+
39
+
40
+
41
+ # config
42
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
+ config = {
44
+ "dataset_root": "from_additional_config",
45
+ "batch_size": 100 if __name__ == "__main__" else 200,
46
+ "num_workers": 4,
47
+ "learning_rate": 0.01,
48
+ "weight_decay": 5e-6,
49
+ "epochs": 1,
50
+ "save_learning_rate": 0.01,
51
+ "total_save_number": 200,
52
+ "tag": os.path.basename(os.path.dirname(__file__)),
53
+ }
54
+ config.update(additional_config)
55
+
56
+
57
+
58
+
59
+ # Data
60
+ dataset = Dataset(
61
+ root=config["dataset_root"],
62
+ download=True,
63
+ train=True,
64
+ transform=transforms.Compose([
65
+ transforms.Resize(80),
66
+ transforms.RandomHorizontalFlip(),
67
+ transforms.RandAugment(),
68
+ transforms.ToTensor(),
69
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
70
+ ])
71
+ )
72
+ train_loader = DataLoader(
73
+ dataset=dataset,
74
+ batch_size=config["batch_size"],
75
+ num_workers=config["num_workers"],
76
+ shuffle=True,
77
+ drop_last=True,
78
+ pin_memory=True,
79
+ persistent_workers=False,
80
+ )
81
+ test_loader = DataLoader(
82
+ dataset=Dataset(
83
+ root=config["dataset_root"],
84
+ download=True,
85
+ train=False,
86
+ transform=transforms.Compose([
87
+ transforms.Resize(80),
88
+ transforms.CenterCrop(80),
89
+ transforms.ToTensor(),
90
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
91
+ ])),
92
+ batch_size=config["batch_size"],
93
+ num_workers=config["num_workers"],
94
+ shuffle=False,
95
+ pin_memory=True,
96
+ persistent_workers=False,
97
+ pin_memory_device="cuda",
98
+ )
99
+
100
+ # Model
101
+ model, head = Model()
102
+ model = model.to(device)
103
+ criterion = nn.CrossEntropyLoss()
104
+ optimizer = optim.AdamW(
105
+ model.parameters(),
106
+ lr=config["learning_rate"],
107
+ weight_decay=config["weight_decay"],
108
+ )
109
+ scheduler = lr_scheduler.CosineAnnealingLR(
110
+ optimizer,
111
+ T_max=config["epochs"],
112
+ eta_min=config["save_learning_rate"],
113
+ )
114
+
115
+
116
+
117
+
118
+ # Training
119
+ def train(model=model, optimizer=optimizer, scheduler=scheduler):
120
+ model.train()
121
+ for batch_idx, (inputs, targets) in tqdm(enumerate(train_loader),
122
+ total=len(dataset) // config["batch_size"]):
123
+ inputs, targets = inputs.to(device), targets.to(device)
124
+ optimizer.zero_grad()
125
+ with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
126
+ outputs = model(inputs)
127
+ loss = criterion(outputs, targets)
128
+ loss.backward()
129
+ optimizer.step()
130
+ if scheduler is not None:
131
+ scheduler.step()
132
+
133
+ # test
134
+ @torch.no_grad()
135
+ def test(model=model):
136
+ model.eval()
137
+ all_targets = []
138
+ all_predicts = []
139
+ test_loss = 0
140
+ correct = 0
141
+ total = 0
142
+ for batch_idx, (inputs, targets) in tqdm(enumerate(test_loader),
143
+ total=len(test_loader.dataset) // config["batch_size"]):
144
+ inputs, targets = inputs.to(device), targets.to(device)
145
+ with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
146
+ outputs = model(inputs)
147
+ loss = criterion(outputs, targets)
148
+ # to logging losses
149
+ all_targets.extend(targets.flatten().tolist())
150
+ test_loss += loss.item()
151
+ _, predicts = outputs.max(1)
152
+ all_predicts.extend(predicts.flatten().tolist())
153
+ total += targets.size(0)
154
+ correct += predicts.eq(targets).sum().item()
155
+ loss = test_loss / (batch_idx + 1)
156
+ acc = correct / total
157
+ print(f"Loss: {loss:.4f} | Acc: {acc:.4f}\n")
158
+ model.train()
159
+ return loss, acc, all_targets, all_predicts
160
+
161
+ # save train
162
+ def save_train(model=model, optimizer=optimizer):
163
+ model.train()
164
+ saved_number = 0
165
+ for batch_idx, (inputs, targets) in enumerate(train_loader):
166
+ inputs, targets = inputs.to(device), targets.to(device)
167
+ optimizer.zero_grad()
168
+ with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
169
+ outputs = model(inputs)
170
+ loss = criterion(outputs, targets)
171
+ loss.backward()
172
+ optimizer.step()
173
+ # Save checkpoint
174
+ if batch_idx % (len(dataset) // train_loader.batch_size // config["total_save_number"]) == 0:
175
+ _, acc, _, _ = test(model=model)
176
+ if not os.path.isdir('checkpoint'):
177
+ os.mkdir('checkpoint')
178
+ save_state = {key: value.cpu().to(torch.float32) for key, value in model.state_dict().items() \
179
+ if key in ["layer4.1.bn1.weight", "layer4.1.bn1.bias", "layer4.1.bn2.weight", "layer4.1.bn2.bias"]}
180
+ torch.save(save_state, f"checkpoint/{str(batch_idx).zfill(4)}_acc{acc:.4f}_seed{seed:04d}_{config['tag']}.pth")
181
+ print("save:", f"checkpoint/{str(batch_idx).zfill(4)}_acc{acc:.4f}_seed{seed:04d}_{config['tag']}.pth")
182
+ saved_number += 1
183
+ if saved_number >= config["total_save_number"]:
184
+ break
185
+
186
+
187
+
188
+
189
+ # main
190
+ if __name__ == '__main__':
191
+ test(model=model)
192
+ for epoch in range(config["epochs"]):
193
+ train(model=model, optimizer=optimizer, scheduler=scheduler)
194
+ test(model=model)
195
+ save_train(model=model, optimizer=optimizer)
dataset/cifar10_cnnmedium/model.py ADDED
@@ -0,0 +1,48 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+ import timm
5
+
6
+
7
+ class CNNMedium(nn.Module):
8
+ def __init__(self):
9
+ super().__init__()
10
+ self.module = nn.Sequential(
11
+ nn.Conv2d(3, 16, 3),
12
+ nn.MaxPool2d(2, 2),
13
+ nn.LeakyReLU(),
14
+ nn.Conv2d(16, 32, 3),
15
+ nn.MaxPool2d(2, 2),
16
+ nn.LeakyReLU(),
17
+ nn.Conv2d(32, 15, 3),
18
+ nn.MaxPool2d(2, 2),
19
+ nn.LeakyReLU(),
20
+ nn.Flatten(start_dim=1),
21
+ )
22
+ self.head = nn.Sequential(
23
+ nn.Linear(60, 20),
24
+ nn.LeakyReLU(),
25
+ nn.Linear(20, 10),
26
+ )
27
+
28
+ def forward(self, x):
29
+ x = self.module(x)
30
+ x = self.head(x)
31
+ return x
32
+
33
+
34
+ def Model():
35
+ model = CNNMedium()
36
+ return model, model.head
37
+
38
+
39
+ if __name__ == "__main__":
40
+ model, _ = Model()
41
+ x = torch.ones([4, 3, 32, 32])
42
+ y = model(x)
43
+ print(y.shape)
44
+ print(model)
45
+ num_param = 0
46
+ for v in model.parameters():
47
+ num_param += v.numel()
48
+ print("num_param:", num_param)
dataset/cifar10_cnnmedium/test.py ADDED
@@ -0,0 +1,28 @@
1
+ import os
2
+ import sys
3
+ if __name__ == "__main__":
4
+ from train import *
5
+ else: # relative import
6
+ from .train import *
7
+
8
+
9
+
10
+
11
+ try:
12
+ test_item = sys.argv[1]
13
+ except IndexError:
14
+ assert __name__ == "__main__"
15
+ test_item = "./checkpoint"
16
+ test_items = []
17
+ if os.path.isdir(test_item):
18
+ for item in os.listdir(test_item):
19
+ item = os.path.join(test_item, item)
20
+ test_items.append(item)
21
+ elif os.path.isfile(test_item):
22
+ test_items.append(test_item)
23
+
24
+
25
+ for item in test_items:
26
+ state = torch.load(item, map_location="cpu")
27
+ model.load_state_dict({key: value.to(torch.float32).to(device) for key, value in state.items()})
28
+ loss, acc, all_targets, all_predicts = test(model=model)
dataset/cifar10_cnnmedium/train.py ADDED
@@ -0,0 +1,192 @@
1
+ # set global seed
2
+ import random
3
+ import numpy as np
4
+ import torch
5
+ seed = SEED = 20
6
+ torch.manual_seed(seed)
7
+ torch.cuda.manual_seed(seed)
8
+ torch.cuda.manual_seed_all(seed)
9
+ torch.backends.cudnn.deterministic = True
10
+ torch.backends.cudnn.benchmark = True
11
+ np.random.seed(seed)
12
+ random.seed(seed)
13
+
14
+
15
+ try: # relative import
16
+ from model import Model
17
+ except ImportError:
18
+ from .model import Model
19
+
20
+ # import
21
+ import torch.nn as nn
22
+ from torch import optim
23
+ from torch.optim import lr_scheduler
24
+ from torch.utils.data import DataLoader
25
+ import torchvision.transforms as transforms
26
+ from torchvision.datasets import CIFAR10 as Dataset
27
+ from tqdm.auto import tqdm
28
+ import os
29
+ import warnings
30
+ warnings.filterwarnings("ignore", category=UserWarning)
31
+
32
+ # load additional config
33
+ import json
34
+ config_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), "config.json")
35
+ with open(config_file, "r") as f:
36
+ additional_config = json.load(f)
37
+
38
+
39
+
40
+
41
+ # config
42
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
+ config = {
44
+ "dataset_root": "from_additional_config",
45
+ "batch_size": 500 if __name__ == "__main__" else 200,
46
+ "num_workers": 32,
47
+ "learning_rate": 1e-2,
48
+ "weight_decay": 0.00666,
49
+ "epochs": 50,
50
+ "save_learning_rate": 1e-5,
51
+ "total_save_number": 50,
52
+ "tag": os.path.basename(os.path.dirname(__file__)),
53
+ }
54
+ config.update(additional_config)
55
+
56
+
57
+
58
+
59
+ # Data
60
+ dataset = Dataset(
61
+ root=config["dataset_root"],
62
+ download=True,
63
+ train=True,
64
+ transform=transforms.Compose([
65
+ transforms.Resize(32),
66
+ transforms.RandomCrop(32),
67
+ transforms.RandomHorizontalFlip(),
68
+ transforms.AutoAugment(policy=transforms.AutoAugmentPolicy("cifar10")),
69
+ transforms.ToTensor(),
70
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
71
+ ])
72
+ )
73
+ train_loader = DataLoader(
74
+ dataset=dataset,
75
+ batch_size=config["batch_size"],
76
+ num_workers=config["num_workers"],
77
+ shuffle=True,
78
+ drop_last=True,
79
+ pin_memory=True,
80
+ persistent_workers=True,
81
+ )
82
+ test_loader = DataLoader(
83
+ dataset=Dataset(
84
+ root=config["dataset_root"],
85
+ download=True,
86
+ train=False,
87
+ transform=transforms.Compose([
88
+ transforms.Resize(32),
89
+ transforms.CenterCrop(32),
90
+ transforms.ToTensor(),
91
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
92
+ ])),
93
+ batch_size=config["batch_size"],
94
+ num_workers=config["num_workers"],
95
+ shuffle=False,
96
+ pin_memory=True,
97
+ persistent_workers=True,
98
+ pin_memory_device="cuda",
99
+ )
100
+
101
+ # Model
102
+ model, head = Model()
103
+ model = model.to(device)
104
+ criterion = nn.CrossEntropyLoss()
105
+ optimizer = optim.SGD(
106
+ model.parameters(),
107
+ lr=config["learning_rate"],
108
+ weight_decay=config["weight_decay"],
109
+ momentum=0.9,
110
+ )
111
+ scheduler = lr_scheduler.CosineAnnealingLR(
112
+ optimizer,
113
+ T_max=config["epochs"],
114
+ eta_min=config["save_learning_rate"],
115
+ )
116
+
117
+
118
+
119
+
120
+ # Training
121
+ def train(model=model, optimizer=optimizer, scheduler=scheduler):
122
+ model.train()
123
+ for batch_idx, (inputs, targets) in tqdm(enumerate(train_loader),
124
+ total=len(dataset) // config["batch_size"]):
125
+ inputs, targets = inputs.to(device), targets.to(device)
126
+ optimizer.zero_grad()
127
+ with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
128
+ outputs = model(inputs)
129
+ loss = criterion(outputs, targets)
130
+ loss.backward()
131
+ optimizer.step()
132
+ if scheduler is not None:
133
+ scheduler.step()
134
+
135
+ # test
136
+ @torch.no_grad()
137
+ def test(model=model):
138
+ model.eval()
139
+ all_targets = []
140
+ all_predicts = []
141
+ test_loss = 0
142
+ correct = 0
143
+ total = 0
144
+ for batch_idx, (inputs, targets) in tqdm(enumerate(test_loader),
145
+ total=len(test_loader.dataset) // config["batch_size"]):
146
+ inputs, targets = inputs.to(device), targets.to(device)
147
+ with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
148
+ outputs = model(inputs)
149
+ loss = criterion(outputs, targets)
150
+ # to logging losses
151
+ all_targets.extend(targets.flatten().tolist())
152
+ test_loss += loss.item()
153
+ _, predicts = outputs.max(1)
154
+ all_predicts.extend(predicts.flatten().tolist())
155
+ total += targets.size(0)
156
+ correct += predicts.eq(targets).sum().item()
157
+ loss = test_loss / (batch_idx + 1)
158
+ acc = correct / total
159
+ print(f"Loss: {loss:.4f} | Acc: {acc:.4f}\n")
160
+ model.train()
161
+ return loss, acc, all_targets, all_predicts
162
+
163
+ # save train
164
+ def save_train(model=model, optimizer=optimizer):
165
+ model.train()
166
+ for batch_idx, (inputs, targets) in enumerate(train_loader):
167
+ inputs, targets = inputs.to(device), targets.to(device)
168
+ optimizer.zero_grad()
169
+ with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
170
+ outputs = model(inputs)
171
+ loss = criterion(outputs, targets)
172
+ loss.backward()
173
+ optimizer.step()
174
+ # Save checkpoint
175
+ if batch_idx % (len(dataset) // train_loader.batch_size // config["total_save_number"]) == 0:
176
+ _, acc, _, _ = test(model=model)
177
+ if not os.path.isdir('checkpoint'):
178
+ os.mkdir('checkpoint')
179
+ save_state = {key: value.cpu().to(torch.float32) for key, value in model.state_dict().items()}
180
+ torch.save(save_state, f"checkpoint/{str(batch_idx).zfill(4)}_acc{acc:.4f}_seed{seed:04d}_{config['tag']}.pth")
181
+ print("save:", f"checkpoint/{str(batch_idx).zfill(4)}_acc{acc:.4f}_seed{seed:04d}_{config['tag']}.pth")
182
+
183
+
184
+
185
+
186
+ # main
187
+ if __name__ == '__main__':
188
+ test(model=model)
189
+ for epoch in range(config["epochs"]):
190
+ train(model=model, optimizer=optimizer, scheduler=scheduler)
191
+ test(model=model)
192
+ save_train(model=model, optimizer=optimizer)
dataset/cifar10_cnnsmall/model.py ADDED
@@ -0,0 +1,48 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+ import timm
5
+
6
+
7
+ class CNNSmall(nn.Module):
8
+ def __init__(self):
9
+ super().__init__()
10
+ self.module = nn.Sequential(
11
+ nn.Conv2d(3, 8, 5),
12
+ nn.MaxPool2d(2, 2),
13
+ nn.LeakyReLU(),
14
+ nn.Conv2d(8, 6, 5),
15
+ nn.MaxPool2d(2, 2),
16
+ nn.LeakyReLU(),
17
+ nn.Conv2d(6, 4, 2),
18
+ nn.LeakyReLU(),
19
+ nn.Flatten(start_dim=1),
20
+ )
21
+ self.head = nn.Sequential(
22
+ nn.Linear(36, 20),
23
+ nn.LeakyReLU(),
24
+ nn.Linear(20, 10),
25
+ )
26
+
27
+ def forward(self, x):
28
+ x = F.interpolate(x, (28, 28), mode='bilinear')
29
+ x = self.module(x)
30
+ x = self.head(x)
31
+ return x
32
+
33
+
34
+ def Model():
35
+ model = CNNSmall()
36
+ return model, model.head
37
+
38
+
39
+ if __name__ == "__main__":
40
+ model, _ = Model()
41
+ x = torch.ones([4, 3, 28, 28])
42
+ y = model(x)
43
+ print(y.shape)
44
+ print(model)
45
+ num_param = 0
46
+ for v in model.parameters():
47
+ num_param += v.numel()
48
+ print("num_param:", num_param)
dataset/cifar10_cnnsmall/test.py ADDED
@@ -0,0 +1,28 @@
1
+ import os
2
+ import sys
3
+ if __name__ == "__main__":
4
+ from train import *
5
+ else: # relative import
6
+ from .train import *
7
+
8
+
9
+
10
+
11
+ try:
12
+ test_item = sys.argv[1]
13
+ except IndexError:
14
+ assert __name__ == "__main__"
15
+ test_item = "./checkpoint"
16
+ test_items = []
17
+ if os.path.isdir(test_item):
18
+ for item in os.listdir(test_item):
19
+ item = os.path.join(test_item, item)
20
+ test_items.append(item)
21
+ elif os.path.isfile(test_item):
22
+ test_items.append(test_item)
23
+
24
+
25
+ for item in test_items:
26
+ state = torch.load(item, map_location="cpu")
27
+ model.load_state_dict({key: value.to(torch.float32).to(device) for key, value in state.items()})
28
+ loss, acc, all_targets, all_predicts = test(model=model)
dataset/cifar10_cnnsmall/train.py ADDED
@@ -0,0 +1,192 @@
1
+ # set global seed
2
+ import random
3
+ import numpy as np
4
+ import torch
5
+ seed = SEED = 20
6
+ torch.manual_seed(seed)
7
+ torch.cuda.manual_seed(seed)
8
+ torch.cuda.manual_seed_all(seed)
9
+ torch.backends.cudnn.deterministic = True
10
+ torch.backends.cudnn.benchmark = True
11
+ np.random.seed(seed)
12
+ random.seed(seed)
13
+
14
+
15
+ try: # relative import
16
+ from model import Model
17
+ except ImportError:
18
+ from .model import Model
19
+
20
+ # import
21
+ import torch.nn as nn
22
+ from torch import optim
23
+ from torch.optim import lr_scheduler
24
+ from torch.utils.data import DataLoader
25
+ import torchvision.transforms as transforms
26
+ from torchvision.datasets import CIFAR10 as Dataset
27
+ from tqdm.auto import tqdm
28
+ import os
29
+ import warnings
30
+ warnings.filterwarnings("ignore", category=UserWarning)
31
+
32
+ # load additional config
33
+ import json
34
+ config_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), "config.json")
35
+ with open(config_file, "r") as f:
36
+ additional_config = json.load(f)
37
+
38
+
39
+
40
+
41
+ # config
42
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
+ config = {
44
+ "dataset_root": "from_additional_config",
45
+ "batch_size": 500 if __name__ == "__main__" else 200,
46
+ "num_workers": 32,
47
+ "learning_rate": 1e-2,
48
+ "weight_decay": 0.001,
49
+ "epochs": 50,
50
+ "save_learning_rate": 1e-5,
51
+ "total_save_number": 50,
52
+ "tag": os.path.basename(os.path.dirname(__file__)),
53
+ }
54
+ config.update(additional_config)
55
+
56
+
57
+
58
+
59
+ # Data
60
+ dataset = Dataset(
61
+ root=config["dataset_root"],
62
+ download=True,
63
+ train=True,
64
+ transform=transforms.Compose([
65
+ transforms.Resize(32),
66
+ transforms.RandomCrop(32),
67
+ transforms.RandomHorizontalFlip(),
68
+ transforms.AutoAugment(policy=transforms.AutoAugmentPolicy("cifar10")),
69
+ transforms.ToTensor(),
70
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
71
+ ])
72
+ )
73
+ train_loader = DataLoader(
74
+ dataset=dataset,
75
+ batch_size=config["batch_size"],
76
+ num_workers=config["num_workers"],
77
+ shuffle=True,
78
+ drop_last=True,
79
+ pin_memory=True,
80
+ persistent_workers=True,
81
+ )
82
+ test_loader = DataLoader(
83
+ dataset=Dataset(
84
+ root=config["dataset_root"],
85
+ download=True,
86
+ train=False,
87
+ transform=transforms.Compose([
88
+ transforms.Resize(32),
89
+ transforms.CenterCrop(32),
90
+ transforms.ToTensor(),
91
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
92
+ ])),
93
+ batch_size=config["batch_size"],
94
+ num_workers=config["num_workers"],
95
+ shuffle=False,
96
+ pin_memory=True,
97
+ persistent_workers=True,
98
+ pin_memory_device="cuda",
99
+ )
100
+
101
+ # Model
102
+ model, head = Model()
103
+ model = model.to(device)
104
+ criterion = nn.CrossEntropyLoss()
105
+ optimizer = optim.SGD(
106
+ model.parameters(),
107
+ lr=config["learning_rate"],
108
+ weight_decay=config["weight_decay"],
109
+ momentum=0.9,
110
+ )
111
+ scheduler = lr_scheduler.CosineAnnealingLR(
112
+ optimizer,
113
+ T_max=config["epochs"],
114
+ eta_min=config["save_learning_rate"],
115
+ )
116
+
117
+
118
+
119
+
120
+ # Training
121
+ def train(model=model, optimizer=optimizer, scheduler=scheduler):
122
+ model.train()
123
+ for batch_idx, (inputs, targets) in tqdm(enumerate(train_loader),
124
+ total=len(dataset) // config["batch_size"]):
125
+ inputs, targets = inputs.to(device), targets.to(device)
126
+ optimizer.zero_grad()
127
+ with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
128
+ outputs = model(inputs)
129
+ loss = criterion(outputs, targets)
130
+ loss.backward()
131
+ optimizer.step()
132
+ if scheduler is not None:
133
+ scheduler.step()
134
+
135
+ # test
136
+ @torch.no_grad()
137
+ def test(model=model):
138
+ model.eval()
139
+ all_targets = []
140
+ all_predicts = []
141
+ test_loss = 0
142
+ correct = 0
143
+ total = 0
144
+ for batch_idx, (inputs, targets) in tqdm(enumerate(test_loader),
145
+ total=len(test_loader.dataset) // config["batch_size"]):
146
+ inputs, targets = inputs.to(device), targets.to(device)
147
+ with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
148
+ outputs = model(inputs)
149
+ loss = criterion(outputs, targets)
150
+ # to logging losses
151
+ all_targets.extend(targets.flatten().tolist())
152
+ test_loss += loss.item()
153
+ _, predicts = outputs.max(1)
154
+ all_predicts.extend(predicts.flatten().tolist())
155
+ total += targets.size(0)
156
+ correct += predicts.eq(targets).sum().item()
157
+ loss = test_loss / (batch_idx + 1)
158
+ acc = correct / total
159
+ print(f"Loss: {loss:.4f} | Acc: {acc:.4f}\n")
160
+ model.train()
161
+ return loss, acc, all_targets, all_predicts
162
+
163
+ # save train
164
+ def save_train(model=model, optimizer=optimizer):
165
+ model.train()
166
+ for batch_idx, (inputs, targets) in enumerate(train_loader):
167
+ inputs, targets = inputs.to(device), targets.to(device)
168
+ optimizer.zero_grad()
169
+ with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
170
+ outputs = model(inputs)
171
+ loss = criterion(outputs, targets)
172
+ loss.backward()
173
+ optimizer.step()
174
+ # Save checkpoint
175
+ if batch_idx % (len(dataset) // train_loader.batch_size // config["total_save_number"]) == 0:
176
+ _, acc, _, _ = test(model=model)
177
+ if not os.path.isdir('checkpoint'):
178
+ os.mkdir('checkpoint')
179
+ save_state = {key: value.cpu().to(torch.float32) for key, value in model.state_dict().items()}
180
+ torch.save(save_state, f"checkpoint/{str(batch_idx).zfill(4)}_acc{acc:.4f}_seed{seed:04d}_{config['tag']}.pth")
181
+ print("save:", f"checkpoint/{str(batch_idx).zfill(4)}_acc{acc:.4f}_seed{seed:04d}_{config['tag']}.pth")
182
+
183
+
184
+
185
+
186
+ # main
187
+ if __name__ == '__main__':
188
+ test(model=model)
189
+ for epoch in range(config["epochs"]):
190
+ train(model=model, optimizer=optimizer, scheduler=scheduler)
191
+ test(model=model)
192
+ save_train(model=model, optimizer=optimizer)
dataset/cifar10_mobilenetv3/model.py ADDED
@@ -0,0 +1,21 @@
1
+ import torch.nn as nn
2
+ import timm
3
+
4
+
5
+ def Model():
6
+ model = timm.create_model("mobilenetv3_large_100", pretrained=True)
7
+ model.classifier = nn.Linear(1280, 10)
8
+ for name, param in model.named_parameters():
9
+ if "bn" in name:
10
+ # print(f"freeze {name}")
11
+ param.requires_grad = False
12
+ return model, model.classifier
13
+
14
+
15
+ if __name__ == "__main__":
16
+ model, _ = Model()
17
+ print(model)
18
+ num_param = 0
19
+ for v in model.parameters():
20
+ num_param += v.numel()
21
+ print("num_param:", num_param)
dataset/cifar10_mobilenetv3/test.py ADDED
@@ -0,0 +1,28 @@
1
+ import os
2
+ import sys
3
+ if __name__ == "__main__":
4
+ from train import *
5
+ else: # relative import
6
+ from .train import *
7
+
8
+
9
+
10
+
11
+ try:
12
+ test_item = sys.argv[1]
13
+ except IndexError:
14
+ assert __name__ == "__main__"
15
+ test_item = "./checkpoint"
16
+ test_items = []
17
+ if os.path.isdir(test_item):
18
+ for item in os.listdir(test_item):
19
+ item = os.path.join(test_item, item)
20
+ test_items.append(item)
21
+ elif os.path.isfile(test_item):
22
+ test_items.append(test_item)
23
+
24
+
25
+ for item in test_items:
26
+ state = torch.load(item, map_location="cpu")
27
+ model.load_state_dict({key: value.to(torch.float32).to(device) for key, value in state.items()})
28
+ loss, acc, all_targets, all_predicts = test(model=model)
dataset/cifar10_mobilenetv3/train.py ADDED
@@ -0,0 +1,199 @@
1
+ # set global seed
2
+ import random
3
+ import numpy as np
4
+ import torch
5
+ seed = SEED = 20
6
+ torch.manual_seed(seed)
7
+ torch.cuda.manual_seed(seed)
8
+ torch.cuda.manual_seed_all(seed)
9
+ torch.backends.cudnn.deterministic = True
10
+ torch.backends.cudnn.benchmark = True
11
+ np.random.seed(seed)
12
+ random.seed(seed)
13
+
14
+
15
+ try: # relative import
16
+ from model import Model
17
+ except ImportError:
18
+ from .model import Model
19
+
20
+ # import
21
+ import torch.nn as nn
22
+ from torch import optim
23
+ from torch.optim import lr_scheduler
24
+ from torch.utils.data import DataLoader
25
+ import torchvision.transforms as transforms
26
+ from torchvision.datasets import CIFAR10 as Dataset
27
+ from tqdm.auto import tqdm
28
+ import os
29
+ import warnings
30
+ warnings.filterwarnings("ignore", category=UserWarning)
31
+
32
+ # load additional config
33
+ import json
34
+ config_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), "config.json")
35
+ with open(config_file, "r") as f:
36
+ additional_config = json.load(f)
37
+
38
+
39
+
40
+
41
+ # config
42
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
+ config = {
44
+ "dataset_root": "from_additional_config",
45
+ "batch_size": 500 if __name__ == "__main__" else 200,
46
+ "num_workers": 4,
47
+ "learning_rate": 3e-3,
48
+ "weight_decay": 0.1,
49
+ "epochs": 5,
50
+ "save_learning_rate": 1e-6,
51
+ "total_save_number": 50,
52
+ "tag": os.path.basename(os.path.dirname(__file__)),
53
+ }
54
+ config.update(additional_config)
55
+
56
+
57
+
58
+
59
+ # Data
60
+ dataset = Dataset(
61
+ root=config["dataset_root"],
62
+ download=True,
63
+ train=True,
64
+ transform=transforms.Compose([
65
+ transforms.Resize(224),
66
+ transforms.RandomCrop(224),
67
+ transforms.RandomHorizontalFlip(),
68
+ transforms.AutoAugment(policy=transforms.AutoAugmentPolicy("cifar10")),
69
+ transforms.ToTensor(),
70
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
71
+ ])
72
+ )
73
+ train_loader = DataLoader(
74
+ dataset=dataset,
75
+ batch_size=config["batch_size"],
76
+ num_workers=config["num_workers"],
77
+ shuffle=True,
78
+ drop_last=True,
79
+ pin_memory=True,
80
+ persistent_workers=True,
81
+ )
82
+ test_loader = DataLoader(
83
+ dataset=Dataset(
84
+ root=config["dataset_root"],
85
+ download=True,
86
+ train=False,
87
+ transform=transforms.Compose([
88
+ transforms.Resize(224),
89
+ transforms.CenterCrop(224),
90
+ transforms.ToTensor(),
91
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
92
+ ])),
93
+ batch_size=config["batch_size"],
94
+ num_workers=config["num_workers"],
95
+ shuffle=False,
96
+ pin_memory=True,
97
+ persistent_workers=True,
98
+ pin_memory_device="cuda",
99
+ )
100
+
101
+ # Model
102
+ model, head = Model()
103
+ model = model.to(device)
104
+ criterion = nn.CrossEntropyLoss()
105
+ pre_optimizer = optim.AdamW(
106
+ head.parameters(),
107
+ lr=0.05,
108
+ weight_decay=0.01,
109
+ )
110
+ optimizer = optim.AdamW(
111
+ model.parameters(),
112
+ lr=config["learning_rate"],
113
+ weight_decay=config["weight_decay"],
114
+ )
115
+ scheduler = lr_scheduler.CosineAnnealingLR(
116
+ optimizer,
117
+ T_max=config["epochs"],
118
+ eta_min=config["save_learning_rate"],
119
+ )
120
+
121
+
122
+
123
+
124
+ # Training
125
+ def train(model=model, optimizer=optimizer, scheduler=scheduler):
126
+ model.train()
127
+ for batch_idx, (inputs, targets) in tqdm(enumerate(train_loader),
128
+ total=len(dataset) // config["batch_size"]):
129
+ inputs, targets = inputs.to(device), targets.to(device)
130
+ optimizer.zero_grad()
131
+ with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
132
+ outputs = model(inputs)
133
+ loss = criterion(outputs, targets)
134
+ loss.backward()
135
+ optimizer.step()
136
+ if scheduler is not None:
137
+ scheduler.step()
138
+
139
+ # test
140
+ @torch.no_grad()
141
+ def test(model=model):
142
+ model.eval()
143
+ all_targets = []
144
+ all_predicts = []
145
+ test_loss = 0
146
+ correct = 0
147
+ total = 0
148
+ for batch_idx, (inputs, targets) in tqdm(enumerate(test_loader),
149
+ total=len(test_loader.dataset) // config["batch_size"]):
150
+ inputs, targets = inputs.to(device), targets.to(device)
151
+ with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
152
+ outputs = model(inputs)
153
+ loss = criterion(outputs, targets)
154
+ # to logging losses
155
+ all_targets.extend(targets.flatten().tolist())
156
+ test_loss += loss.item()
157
+ _, predicts = outputs.max(1)
158
+ all_predicts.extend(predicts.flatten().tolist())
159
+ total += targets.size(0)
160
+ correct += predicts.eq(targets).sum().item()
161
+ loss = test_loss / (batch_idx + 1)
162
+ acc = correct / total
163
+ print(f"Loss: {loss:.4f} | Acc: {acc:.4f}\n")
164
+ model.train()
165
+ return loss, acc, all_targets, all_predicts
166
+
167
+ # save train
168
+ def save_train(model=model, optimizer=optimizer):
169
+ model.train()
170
+ for batch_idx, (inputs, targets) in enumerate(train_loader):
171
+ inputs, targets = inputs.to(device), targets.to(device)
172
+ optimizer.zero_grad()
173
+ with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
174
+ outputs = model(inputs)
175
+ loss = criterion(outputs, targets)
176
+ loss.backward()
177
+ optimizer.step()
178
+ # Save checkpoint
179
+ if batch_idx % (len(dataset) // train_loader.batch_size // config["total_save_number"]) == 0:
180
+ _, acc, _, _ = test(model=model)
181
+ if not os.path.isdir('checkpoint'):
182
+ os.mkdir('checkpoint')
183
+ save_state = {key: value.cpu().to(torch.float32) for key, value in model.state_dict().items()}
184
+ torch.save(save_state, f"checkpoint/{str(batch_idx).zfill(4)}_acc{acc:.4f}_seed{seed:04d}_{config['tag']}.pth")
185
+ print("save:", f"checkpoint/{str(batch_idx).zfill(4)}_acc{acc:.4f}_seed{seed:04d}_{config['tag']}.pth")
186
+
187
+
188
+
189
+
190
+ # main
191
+ if __name__ == '__main__':
192
+ test(model=model)
193
+ for _ in range(1):
194
+ train(model=model, optimizer=pre_optimizer)
195
+ test(model=model)
196
+ for epoch in range(config["epochs"]):
197
+ train(model=model, optimizer=optimizer, scheduler=scheduler)
198
+ test(model=model)
199
+ save_train(model=model, optimizer=optimizer)
dataset/cifar10_resnet18/model.py ADDED
@@ -0,0 +1,17 @@
1
+ import torch.nn as nn
2
+ import timm
3
+
4
+
5
+ def Model():
6
+ model = timm.create_model("resnet18", pretrained=True)
7
+ model.fc = nn.Linear(512, 10)
8
+ return model, model.fc
9
+
10
+
11
+ if __name__ == "__main__":
12
+ model, _ = Model()
13
+ print(model)
14
+ num_param = 0
15
+ for v in model.parameters():
16
+ num_param += v.numel()
17
+ print("num_param:", num_param)
dataset/cifar10_resnet18/test.py ADDED
@@ -0,0 +1,28 @@
1
+ import os
2
+ import sys
3
+ if __name__ == "__main__":
4
+ from train import *
5
+ else: # relative import
6
+ from .train import *
7
+
8
+
9
+
10
+
11
+ try:
12
+ test_item = sys.argv[1]
13
+ except IndexError:
14
+ assert __name__ == "__main__"
15
+ test_item = "./checkpoint"
16
+ test_items = []
17
+ if os.path.isdir(test_item):
18
+ for item in os.listdir(test_item):
19
+ item = os.path.join(test_item, item)
20
+ test_items.append(item)
21
+ elif os.path.isfile(test_item):
22
+ test_items.append(test_item)
23
+
24
+
25
+ for item in test_items:
26
+ state = torch.load(item, map_location="cpu")
27
+ model.load_state_dict({key: value.to(torch.float32).to(device) for key, value in state.items()})
28
+ loss, acc, all_targets, all_predicts = test(model=model)
dataset/cifar10_resnet18/train.py ADDED
@@ -0,0 +1,191 @@
1
+ # set global seed
2
+ import random
3
+ import numpy as np
4
+ import torch
5
+ seed = SEED = 20
6
+ torch.manual_seed(seed)
7
+ torch.cuda.manual_seed(seed)
8
+ torch.cuda.manual_seed_all(seed)
9
+ torch.backends.cudnn.deterministic = True
10
+ torch.backends.cudnn.benchmark = True
11
+ np.random.seed(seed)
12
+ random.seed(seed)
13
+
14
+
15
+ try: # relative import
16
+ from model import Model
17
+ except ImportError:
18
+ from .model import Model
19
+
20
+ # import
21
+ import torch.nn as nn
22
+ from torch import optim
23
+ from torch.optim import lr_scheduler
24
+ from torch.utils.data import DataLoader
25
+ import torchvision.transforms as transforms
26
+ from torchvision.datasets import CIFAR10 as Dataset
27
+ from tqdm.auto import tqdm
28
+ import os
29
+ import warnings
30
+ warnings.filterwarnings("ignore", category=UserWarning)
31
+
32
+ # load additional config
33
+ import json
34
+ config_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), "config.json")
35
+ with open(config_file, "r") as f:
36
+ additional_config = json.load(f)
37
+
38
+
39
+
40
+
41
+ # config
42
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
+ config = {
44
+ "dataset_root": "from_additional_config",
45
+ "batch_size": 500 if __name__ == "__main__" else 200,
46
+ "num_workers": 32,
47
+ "learning_rate": 3e-3,
48
+ "weight_decay": 0.1,
49
+ "epochs": 50,
50
+ "save_learning_rate": 1e-5,
51
+ "total_save_number": 50,
52
+ "tag": os.path.basename(os.path.dirname(__file__)),
53
+ }
54
+ config.update(additional_config)
55
+
56
+
57
+
58
+
59
+ # Data
60
+ dataset = Dataset(
61
+ root=config["dataset_root"],
62
+ download=True,
63
+ train=True,
64
+ transform=transforms.Compose([
65
+ transforms.Resize(64),
66
+ transforms.RandomCrop(64),
67
+ transforms.RandomHorizontalFlip(),
68
+ transforms.AutoAugment(policy=transforms.AutoAugmentPolicy("cifar10")),
69
+ transforms.ToTensor(),
70
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
71
+ ])
72
+ )
73
+ train_loader = DataLoader(
74
+ dataset=dataset,
75
+ batch_size=config["batch_size"],
76
+ num_workers=config["num_workers"],
77
+ shuffle=True,
78
+ drop_last=True,
79
+ pin_memory=True,
80
+ persistent_workers=True,
81
+ )
82
+ test_loader = DataLoader(
83
+ dataset=Dataset(
84
+ root=config["dataset_root"],
85
+ download=True,
86
+ train=False,
87
+ transform=transforms.Compose([
88
+ transforms.Resize(64),
89
+ transforms.CenterCrop(64),
90
+ transforms.ToTensor(),
91
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
92
+ ])),
93
+ batch_size=config["batch_size"],
94
+ num_workers=config["num_workers"],
95
+ shuffle=False,
96
+ pin_memory=True,
97
+ persistent_workers=True,
98
+ pin_memory_device="cuda",
99
+ )
100
+
101
+ # Model
102
+ model, head = Model()
103
+ model = model.to(device)
104
+ criterion = nn.CrossEntropyLoss()
105
+ optimizer = optim.AdamW(
106
+ model.parameters(),
107
+ lr=config["learning_rate"],
108
+ weight_decay=config["weight_decay"],
109
+ )
110
+ scheduler = lr_scheduler.CosineAnnealingLR(
111
+ optimizer,
112
+ T_max=config["epochs"],
113
+ eta_min=config["save_learning_rate"],
114
+ )
115
+
116
+
117
+
118
+
119
+ # Training
120
+ def train(model=model, optimizer=optimizer, scheduler=scheduler):
121
+ model.train()
122
+ for batch_idx, (inputs, targets) in tqdm(enumerate(train_loader),
123
+ total=len(dataset) // config["batch_size"]):
124
+ inputs, targets = inputs.to(device), targets.to(device)
125
+ optimizer.zero_grad()
126
+ with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
127
+ outputs = model(inputs)
128
+ loss = criterion(outputs, targets)
129
+ loss.backward()
130
+ optimizer.step()
131
+ if scheduler is not None:
132
+ scheduler.step()
133
+
134
+ # test
135
+ @torch.no_grad()
136
+ def test(model=model):
137
+ model.eval()
138
+ all_targets = []
139
+ all_predicts = []
140
+ test_loss = 0
141
+ correct = 0
142
+ total = 0
143
+ for batch_idx, (inputs, targets) in tqdm(enumerate(test_loader),
144
+ total=len(test_loader.dataset) // config["batch_size"]):
145
+ inputs, targets = inputs.to(device), targets.to(device)
146
+ with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
147
+ outputs = model(inputs)
148
+ loss = criterion(outputs, targets)
149
+ # to logging losses
150
+ all_targets.extend(targets.flatten().tolist())
151
+ test_loss += loss.item()
152
+ _, predicts = outputs.max(1)
153
+ all_predicts.extend(predicts.flatten().tolist())
154
+ total += targets.size(0)
155
+ correct += predicts.eq(targets).sum().item()
156
+ loss = test_loss / (batch_idx + 1)
157
+ acc = correct / total
158
+ print(f"Loss: {loss:.4f} | Acc: {acc:.4f}\n")
159
+ model.train()
160
+ return loss, acc, all_targets, all_predicts
161
+
162
+ # save train
163
+ def save_train(model=model, optimizer=optimizer):
164
+ model.train()
165
+ for batch_idx, (inputs, targets) in enumerate(train_loader):
166
+ inputs, targets = inputs.to(device), targets.to(device)
167
+ optimizer.zero_grad()
168
+ with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
169
+ outputs = model(inputs)
170
+ loss = criterion(outputs, targets)
171
+ loss.backward()
172
+ optimizer.step()
173
+ # Save checkpoint
174
+ if batch_idx % (len(dataset) // train_loader.batch_size // config["total_save_number"]) == 0:
175
+ _, acc, _, _ = test(model=model)
176
+ if not os.path.isdir('checkpoint'):
177
+ os.mkdir('checkpoint')
178
+ save_state = {key: value.cpu().to(torch.float32) for key, value in model.state_dict().items()}
179
+ torch.save(save_state, f"checkpoint/{str(batch_idx).zfill(4)}_acc{acc:.4f}_seed{seed:04d}_{config['tag']}.pth")
180
+ print("save:", f"checkpoint/{str(batch_idx).zfill(4)}_acc{acc:.4f}_seed{seed:04d}_{config['tag']}.pth")
181
+
182
+
183
+
184
+
185
+ # main
186
+ if __name__ == '__main__':
187
+ test(model=model)
188
+ for epoch in range(config["epochs"]):
189
+ train(model=model, optimizer=optimizer, scheduler=scheduler)
190
+ test(model=model)
191
+ save_train(model=model, optimizer=optimizer)
dataset/cifar10_vitbase/model.py ADDED
@@ -0,0 +1,17 @@
1
+ import torch.nn as nn
2
+ import timm
3
+
4
+
5
+ def Model():
6
+ model = timm.create_model("vit_base_patch16_224", pretrained=True)
7
+ model.head = nn.Linear(768, 10)
8
+ return model, model.head
9
+
10
+
11
+ if __name__ == "__main__":
12
+ model, _ = Model()
13
+ print(model)
14
+ num_param = 0
15
+ for v in model.parameters():
16
+ num_param += v.numel()
17
+ print("num_param:", num_param)
dataset/cifar10_vitbase/test.py ADDED
@@ -0,0 +1,28 @@
1
+ import os
2
+ import sys
3
+ if __name__ == "__main__":
4
+ from train import *
5
+ else: # relative import
6
+ from .train import *
7
+
8
+
9
+
10
+
11
+ try:
12
+ test_item = sys.argv[1]
13
+ except IndexError:
14
+ assert __name__ == "__main__"
15
+ test_item = "./checkpoint"
16
+ test_items = []
17
+ if os.path.isdir(test_item):
18
+ for item in os.listdir(test_item):
19
+ item = os.path.join(test_item, item)
20
+ test_items.append(item)
21
+ elif os.path.isfile(test_item):
22
+ test_items.append(test_item)
23
+
24
+
25
+ for item in test_items:
26
+ state = torch.load(item, map_location="cpu")
27
+ model.load_state_dict({key: value.to(torch.float32).to(device) for key, value in state.items()})
28
+ loss, acc, all_targets, all_predicts = test(model=model)
dataset/cifar10_vitbase/train.py ADDED
@@ -0,0 +1,199 @@
1
+ # set global seed
2
+ import random
3
+ import numpy as np
4
+ import torch
5
+ seed = SEED = 20
6
+ torch.manual_seed(seed)
7
+ torch.cuda.manual_seed(seed)
8
+ torch.cuda.manual_seed_all(seed)
9
+ torch.backends.cudnn.deterministic = True
10
+ torch.backends.cudnn.benchmark = True
11
+ np.random.seed(seed)
12
+ random.seed(seed)
13
+
14
+
15
+ try: # relative import
16
+ from model import Model
17
+ except ImportError:
18
+ from .model import Model
19
+
20
+ # import
21
+ import torch.nn as nn
22
+ from torch import optim
23
+ from torch.optim import lr_scheduler
24
+ from torch.utils.data import DataLoader
25
+ import torchvision.transforms as transforms
26
+ from torchvision.datasets import CIFAR10 as Dataset
27
+ from tqdm.auto import tqdm
28
+ import os
29
+ import warnings
30
+ warnings.filterwarnings("ignore", category=UserWarning)
31
+
32
+ # load additional config
33
+ import json
34
+ config_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), "config.json")
35
+ with open(config_file, "r") as f:
36
+ additional_config = json.load(f)
37
+
38
+
39
+
40
+
41
+ # config
42
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
+ config = {
44
+ "dataset_root": "from_additional_config",
45
+ "batch_size": 500 if __name__ == "__main__" else 200,
46
+ "num_workers": 32,
47
+ "learning_rate": 3e-5,
48
+ "weight_decay": 0.1,
49
+ "epochs": 7,
50
+ "save_learning_rate": 1e-5,
51
+ "total_save_number": 50,
52
+ "tag": os.path.basename(os.path.dirname(__file__)),
53
+ }
54
+ config.update(additional_config)
55
+
56
+
57
+
58
+
59
+ # Data
60
+ dataset = Dataset(
61
+ root=config["dataset_root"],
62
+ download=True,
63
+ train=True,
64
+ transform=transforms.Compose([
65
+ transforms.Resize(224),
66
+ transforms.RandomCrop(224),
67
+ transforms.RandomHorizontalFlip(),
68
+ transforms.AutoAugment(policy=transforms.AutoAugmentPolicy("cifar10")),
69
+ transforms.ToTensor(),
70
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
71
+ ])
72
+ )
73
+ train_loader = DataLoader(
74
+ dataset=dataset,
75
+ batch_size=config["batch_size"],
76
+ num_workers=config["num_workers"],
77
+ shuffle=True,
78
+ drop_last=True,
79
+ pin_memory=True,
80
+ persistent_workers=True,
81
+ )
82
+ test_loader = DataLoader(
83
+ dataset=Dataset(
84
+ root=config["dataset_root"],
85
+ download=True,
86
+ train=False,
87
+ transform=transforms.Compose([
88
+ transforms.Resize(224),
89
+ transforms.CenterCrop(224),
90
+ transforms.ToTensor(),
91
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
92
+ ])),
93
+ batch_size=config["batch_size"],
94
+ num_workers=config["num_workers"],
95
+ shuffle=False,
96
+ pin_memory=True,
97
+ persistent_workers=True,
98
+ pin_memory_device="cuda",
99
+ )
100
+
101
+ # Model
102
+ model, head = Model()
103
+ model = model.to(device)
104
+ criterion = nn.CrossEntropyLoss()
105
+ pre_optimizer = optim.AdamW(
106
+ head.parameters(),
107
+ lr=0.05,
108
+ weight_decay=0.01,
109
+ )
110
+ optimizer = optim.AdamW(
111
+ model.parameters(),
112
+ lr=config["learning_rate"],
113
+ weight_decay=config["weight_decay"],
114
+ )
115
+ scheduler = lr_scheduler.CosineAnnealingLR(
116
+ optimizer,
117
+ T_max=config["epochs"],
118
+ eta_min=config["save_learning_rate"],
119
+ )
120
+
121
+
122
+
123
+
124
+ # Training
125
+ def train(model=model, optimizer=optimizer, scheduler=scheduler):
126
+ model.train()
127
+ for batch_idx, (inputs, targets) in tqdm(enumerate(train_loader),
128
+ total=len(dataset) // config["batch_size"]):
129
+ inputs, targets = inputs.to(device), targets.to(device)
130
+ optimizer.zero_grad()
131
+ with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
132
+ outputs = model(inputs)
133
+ loss = criterion(outputs, targets)
134
+ loss.backward()
135
+ optimizer.step()
136
+ if scheduler is not None:
137
+ scheduler.step()
138
+
139
+ # test
140
+ @torch.no_grad()
141
+ def test(model=model):
142
+ model.eval()
143
+ all_targets = []
144
+ all_predicts = []
145
+ test_loss = 0
146
+ correct = 0
147
+ total = 0
148
+ for batch_idx, (inputs, targets) in tqdm(enumerate(test_loader),
149
+ total=len(test_loader.dataset) // config["batch_size"]):
150
+ inputs, targets = inputs.to(device), targets.to(device)
151
+ with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
152
+ outputs = model(inputs)
153
+ loss = criterion(outputs, targets)
154
+ # to logging losses
155
+ all_targets.extend(targets.flatten().tolist())
156
+ test_loss += loss.item()
157
+ _, predicts = outputs.max(1)
158
+ all_predicts.extend(predicts.flatten().tolist())
159
+ total += targets.size(0)
160
+ correct += predicts.eq(targets).sum().item()
161
+ loss = test_loss / (batch_idx + 1)
162
+ acc = correct / total
163
+ print(f"Loss: {loss:.4f} | Acc: {acc:.4f}\n")
164
+ model.train()
165
+ return loss, acc, all_targets, all_predicts
166
+
167
+ # save train
168
+ def save_train(model=model, optimizer=optimizer):
169
+ model.train()
170
+ for batch_idx, (inputs, targets) in enumerate(train_loader):
171
+ inputs, targets = inputs.to(device), targets.to(device)
172
+ optimizer.zero_grad()
173
+ with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
174
+ outputs = model(inputs)
175
+ loss = criterion(outputs, targets)
176
+ loss.backward()
177
+ optimizer.step()
178
+ # Save checkpoint
179
+ if batch_idx % (len(dataset) // train_loader.batch_size // config["total_save_number"]) == 0:
180
+ _, acc, _, _ = test(model=model)
181
+ if not os.path.isdir('checkpoint'):
182
+ os.mkdir('checkpoint')
183
+ save_state = {key: value.cpu().to(torch.float32) for key, value in model.state_dict().items()}
184
+ torch.save(save_state, f"checkpoint/{str(batch_idx).zfill(4)}_acc{acc:.4f}_seed{seed:04d}_{config['tag']}.pth")
185
+ print("save:", f"checkpoint/{str(batch_idx).zfill(4)}_acc{acc:.4f}_seed{seed:04d}_{config['tag']}.pth")
186
+
187
+
188
+
189
+
190
+ # main
191
+ if __name__ == '__main__':
192
+ test(model=model)
193
+ for _ in range(3):
194
+ train(model=model, optimizer=pre_optimizer, scheduler=None)  # head warm-up; do not step the main cosine schedule here
195
+ test(model=model)
196
+ for epoch in range(config["epochs"]):
197
+ train(model=model, optimizer=optimizer, scheduler=scheduler)
198
+ test(model=model)
199
+ save_train(model=model, optimizer=optimizer)
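The `save_train` pass above appears to be what turns ordinary fine-tuning into a parameter dataset: it keeps training at `save_learning_rate` and snapshots the full `state_dict` at a fixed batch interval over one extra pass through the data. A minimal sketch of the resulting schedule, assuming the values in the config above (50,000 CIFAR-10 training images, batch size 500, `total_save_number` 50):

```python
# Sketch of the save schedule implied by save_train() above (values assumed from the config).
dataset_size = 50_000                                      # CIFAR-10 training set
batch_size = 500                                           # config["batch_size"]
total_save_number = 50                                     # config["total_save_number"]

batches_per_epoch = dataset_size // batch_size             # 100
save_interval = batches_per_epoch // total_save_number     # 2

save_steps = [b for b in range(batches_per_epoch) if b % save_interval == 0]
print(len(save_steps))                                     # 50 checkpoints from one extra pass
```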
dataset/condition_classinput_inference/dataset.py ADDED
@@ -0,0 +1,41 @@
1
+ import re
2
+ import sys
3
+ from torch.utils.data import Dataset
4
+ from torchvision.datasets import CIFAR10
5
+ import torchvision.transforms as transforms
6
+
7
+
8
+ class BinaryClassifierDataset(Dataset):
9
+ def __init__(self, root, train, optimize_class: list):
10
+ self.optimize_class = optimize_class
11
+ self.dataset = CIFAR10(
12
+ root=root,
13
+ train=train,
14
+ download=True,
15
+ transform=transforms.Compose([
16
+ transforms.Resize(224),
17
+ transforms.RandomHorizontalFlip(),
18
+ transforms.AutoAugment(policy=transforms.AutoAugmentPolicy("cifar10")),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
21
+ ])
22
+ )
23
+
24
+ def __getitem__(self, index):
25
+ img, origin_target = self.dataset[index]
26
+ target = 1 if origin_target in self.optimize_class else 0
27
+ return img, target
28
+
29
+ def __len__(self):
30
+ return self.dataset.__len__()
31
+
32
+
33
+ def get_optimize_class():
34
+ try: # get string
35
+ string = sys.argv[1]
36
+ except IndexError:
37
+ raise RuntimeError("sys.argv[1] not found")
38
+ class_int_string = str(re.search(r'class(\d+)', string).group(1)).zfill(4)
39
+ one_hot_string = bin(int(class_int_string))[2:].zfill(10)
40
+ optimize_class = [index for index, i in enumerate(one_hot_string) if i == "1"]
41
+ return list(optimize_class), class_int_string
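`get_optimize_class` above turns a `classNNNN` tag from the command line into a multi-hot mask over the ten CIFAR-10 classes: the decimal integer is rendered as a 10-bit binary string, and the positions holding a `1` become the positive classes of the binary task. A small worked example, using `class0314` (one of the tags that appears later in this commit):

```python
# Decode a classNNNN tag exactly as get_optimize_class() does.
mask_int = 314                                   # parsed from a tag such as "class0314"
one_hot_string = bin(mask_int)[2:].zfill(10)     # "0100111010"
optimize_class = [i for i, bit in enumerate(one_hot_string) if bit == "1"]
print(optimize_class)                            # [1, 4, 5, 6, 8] -> these classes get label 1
```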
dataset/condition_classinput_inference/model.py ADDED
@@ -0,0 +1,25 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+
5
+
6
+ def Model():
7
+ model = timm.create_model("vit_tiny_patch16_224", pretrained=True)
8
+ model.head = nn.Sequential(
9
+ nn.Linear(192, 192, bias=True),
10
+ nn.SiLU(),
11
+ nn.Linear(192, 2, bias=False),
12
+ )
13
+ for param in model.head.parameters():
14
+ param.data = torch.ones_like(param) / 192  # initialize in place; rebinding the loop variable would leave the head unchanged
15
+ param.requires_grad = True
16
+ return model, model.head
17
+
18
+
19
+ if __name__ == "__main__":
20
+ model, _ = Model()
21
+ print(model)
22
+ num_param = 0
23
+ for v in model.parameters():
24
+ num_param += v.numel()
25
+ print("num_param:", num_param)
dataset/condition_classinput_inference/test.py ADDED
@@ -0,0 +1,30 @@
1
+ import os
2
+ import sys
3
+ if __name__ == "__main__":
4
+ from train import *
5
+ else: # relative import
6
+ from .train import *
7
+
8
+
9
+
10
+
11
+ try:
12
+ test_item = sys.argv[1]
13
+ except IndexError:
14
+ assert __name__ == "__main__"
15
+ test_item = "./checkpoint_test"
16
+ test_items = []
17
+ if os.path.isdir(test_item):
18
+ for item in os.listdir(test_item):
19
+ item = os.path.join(test_item, item)
20
+ test_items.append(item)
21
+ elif os.path.isfile(test_item):
22
+ test_items.append(test_item)
23
+
24
+
25
+
26
+
27
+ for item in test_items:
28
+ state = torch.load(item, map_location="cpu")
29
+ model.load_state_dict({key: value.to(torch.float32).to(device) for key, value in state.items()})
30
+ loss, acc, all_targets, all_predicts = test(model=model)
dataset/condition_classinput_inference/train.py ADDED
@@ -0,0 +1,209 @@
1
+ # set global seed
2
+ import random
3
+ import numpy as np
4
+ import torch
5
+ seed = SEED = 20
6
+ torch.manual_seed(seed)
7
+ torch.cuda.manual_seed(seed)
8
+ torch.cuda.manual_seed_all(seed)
9
+ torch.backends.cudnn.deterministic = True
10
+ torch.backends.cudnn.benchmark = True
11
+ np.random.seed(seed)
12
+ random.seed(seed)
13
+
14
+
15
+ try: # relative import
16
+ from model import Model
17
+ from dataset import BinaryClassifierDataset as Dataset
18
+ from dataset import get_optimize_class
19
+ except ImportError:
20
+ from .model import Model
21
+ from .dataset import BinaryClassifierDataset as Dataset
22
+ from .dataset import get_optimize_class
23
+
24
+ # import
25
+ import torch.nn as nn
26
+ from torch import optim
27
+ from torch.optim import lr_scheduler
28
+ from torch.utils.data import DataLoader
29
+ from torch.nn import functional as F
30
+ import os
31
+ import sys
32
+ import warnings
33
+ warnings.filterwarnings("ignore", category=UserWarning)
34
+
35
+ # load additional config
36
+ import json
37
+ config_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), "config.json")
38
+ with open(config_file, "r") as f:
39
+ additional_config = json.load(f)
40
+
41
+
42
+
43
+
44
+ # config
45
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
+ config = {
47
+ "dataset_root": "from_additional_config",
48
+ "batch_size": 500 if __name__ == "__main__" else 50,
49
+ "num_workers": 16,
50
+ "pre_learning_rate": 0.01,
51
+ "learning_rate": 1e-4,
52
+ "pre_epochs": 2,
53
+ "epochs": 13,
54
+ "weight_decay": 0.1,
55
+ "save_learning_rate": 2e-5,
56
+ "total_save_number": 5,
57
+ "tag": os.path.basename(os.path.dirname(__file__)),
58
+ "optimize_class": get_optimize_class()[0],
59
+ "optimize_class_int": get_optimize_class()[1],
60
+ }
61
+ config.update(additional_config)
62
+ print("Training/Testing:", config["optimize_class"])
63
+
64
+
65
+
66
+
67
+ # Data
68
+ dataset = Dataset(
69
+ root=config["dataset_root"],
70
+ train=True,
71
+ optimize_class=config["optimize_class"],
72
+ )
73
+ train_loader = DataLoader(
74
+ dataset=dataset,
75
+ batch_size=config["batch_size"],
76
+ num_workers=config["num_workers"],
77
+ shuffle=True,
78
+ drop_last=True,
79
+ pin_memory=True,
80
+ persistent_workers=True,
81
+ )
82
+ test_loader = DataLoader(
83
+ dataset=Dataset(
84
+ root=config["dataset_root"],
85
+ train=False,
86
+ optimize_class=config["optimize_class"],
87
+ ),
88
+ batch_size=config["batch_size"],
89
+ num_workers=config["num_workers"],
90
+ shuffle=False,
91
+ )
92
+
93
+ # Model
94
+ model, head = Model()
95
+ model = model.to(device)
96
+ class FocalLoss(nn.Module):
97
+ def __init__(self, weight=None, gamma=2):
98
+ super(FocalLoss, self).__init__()
99
+ self.weight = weight
100
+ self.gamma = gamma
101
+ def forward(self, input, target):
102
+ ce_loss = F.cross_entropy(input, target, reduction='none', weight=self.weight)
103
+ pt = torch.exp(-ce_loss)
104
+ focal_loss = (1 - pt) ** self.gamma * ce_loss
105
+ return focal_loss.mean()
106
+ criterion = FocalLoss()
107
+
108
+ # Optimizer
109
+ head_optimizer = optim.AdamW(
110
+ head.parameters(),
111
+ lr=config["pre_learning_rate"],
112
+ weight_decay=config["weight_decay"],
113
+ )
114
+ optimizer = optim.AdamW(
115
+ model.parameters(),
116
+ lr=config["learning_rate"],
117
+ weight_decay=config["weight_decay"],
118
+ )
119
+ scheduler = lr_scheduler.CosineAnnealingLR(
120
+ optimizer,
121
+ T_max=config["epochs"],
122
+ eta_min=config["save_learning_rate"],
123
+ )
124
+
125
+
126
+
127
+
128
+ # Training
129
+ def train(model=model, optimizer=optimizer, scheduler=scheduler):
130
+ model.train()
131
+ for batch_idx, (inputs, targets) in enumerate(train_loader):
132
+ inputs, targets = inputs.to(device), targets.to(device)
133
+ optimizer.zero_grad()
134
+ with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
135
+ outputs = model(inputs)
136
+ loss = criterion(outputs, targets)
137
+ loss.backward()
138
+ optimizer.step()
139
+ if scheduler is not None:
140
+ scheduler.step()
141
+
142
+ # test
143
+ @torch.no_grad()
144
+ def test(model=model):
145
+ model.eval()
146
+ all_targets = []
147
+ all_predicts = []
148
+ test_loss = 0
149
+ correct = 0
150
+ total = 0
151
+ for batch_idx, (inputs, targets) in enumerate(test_loader):
152
+ inputs, targets = inputs.to(device), targets.to(device)
153
+ with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
154
+ outputs = model(inputs)
155
+ loss = criterion(outputs, targets)
156
+ # to logging losses
157
+ all_targets.extend(targets.flatten().tolist())
158
+ test_loss += loss.item()
159
+ _, predicts = outputs.max(1)
160
+ all_predicts.extend(predicts.flatten().tolist())
161
+ total += targets.size(0)
162
+ correct += predicts.eq(targets).sum().item()
163
+ loss = test_loss / (batch_idx + 1)
164
+ acc = correct / total
165
+ print(f"Loss: {loss:.4f} | Acc: {acc:.4f}\n")
166
+ model.train()
167
+ return loss, acc, all_targets, all_predicts
168
+
169
+ # save train
170
+ def save_train(model=model, optimizer=optimizer):
171
+ data_loader = DataLoader(
172
+ dataset=dataset,
173
+ batch_size=min(len(dataset) // config["total_save_number"], config["batch_size"]),
174
+ num_workers=config["num_workers"],
175
+ shuffle=True,
176
+ drop_last=True,
177
+ )
178
+ model.train()
179
+ for batch_idx, (inputs, targets) in enumerate(data_loader):
180
+ inputs, targets = inputs.to(device), targets.to(device)
181
+ optimizer.zero_grad()
182
+ with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
183
+ outputs = model(inputs)
184
+ loss = criterion(outputs, targets)
185
+ loss.backward()
186
+ optimizer.step()
187
+ # Save checkpoint
188
+ _, acc, _, _ = test(model=model)
189
+ if not os.path.isdir('checkpoint'):
190
+ os.mkdir('checkpoint')
191
+ save_state = {key: value.cpu().to(torch.float32) for key, value in model.state_dict().items()}
192
+ torch.save(save_state, f"checkpoint/{str(batch_idx).zfill(4)}_acc{acc:.4f}_class{config['optimize_class_int']}_{config['tag']}.pth")
193
+ print("save:", f"checkpoint/{str(batch_idx).zfill(4)}_acc{acc:.4f}_class{config['optimize_class_int']}_{config['tag']}.pth")
194
+ # exit loop
195
+ if batch_idx+1 == config["total_save_number"]:
196
+ break
197
+
198
+
199
+
200
+
201
+ # main
202
+ if __name__ == '__main__':
203
+ for epoch in range(config["pre_epochs"]):
204
+ train(model=model, optimizer=head_optimizer, scheduler=None)
205
+ # test(model=model)
206
+ for epoch in range(config["epochs"]):
207
+ train(model=model, optimizer=optimizer, scheduler=scheduler)
208
+ # test(model=model)
209
+ save_train(model=model, optimizer=optimizer)
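These binary tasks are heavily imbalanced whenever only a few classes are marked positive, which is presumably why the script swaps plain cross-entropy for the `FocalLoss` defined above. A quick numeric sanity check of that loss with the default `gamma=2` (values are approximate):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.0]])                       # confident, correct prediction
target = torch.tensor([0])
ce = F.cross_entropy(logits, target, reduction="none")    # ~0.127
pt = torch.exp(-ce)                                       # ~0.881
focal = (1 - pt) ** 2 * ce                                # ~0.0018: easy examples are down-weighted
print(ce.item(), focal.item())
```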
dataset/condition_classinput_vittiny/dataset.py ADDED
@@ -0,0 +1,41 @@
1
+ import re
2
+ import sys
3
+ from torch.utils.data import Dataset
4
+ from torchvision.datasets import CIFAR10
5
+ import torchvision.transforms as transforms
6
+
7
+
8
+ class BinaryClassifierDataset(Dataset):
9
+ def __init__(self, root, train, optimize_class: list):
10
+ self.optimize_class = optimize_class
11
+ self.dataset = CIFAR10(
12
+ root=root,
13
+ train=train,
14
+ download=True,
15
+ transform=transforms.Compose([
16
+ transforms.Resize(224),
17
+ transforms.RandomHorizontalFlip(),
18
+ transforms.AutoAugment(policy=transforms.AutoAugmentPolicy("cifar10")),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
21
+ ])
22
+ )
23
+
24
+ def __getitem__(self, index):
25
+ img, origin_target = self.dataset[index]
26
+ target = 1 if origin_target in self.optimize_class else 0
27
+ return img, target
28
+
29
+ def __len__(self):
30
+ return self.dataset.__len__()
31
+
32
+
33
+ def get_optimize_class():
34
+ try: # get string
35
+ string = sys.argv[1]
36
+ except IndexError:
37
+ raise RuntimeError("sys.argv[1] not found")
38
+ class_int_string = str(re.search(r'class(\d+)', string).group(1)).zfill(4)
39
+ one_hot_string = bin(int(class_int_string))[2:].zfill(10)
40
+ optimize_class = [index for index, i in enumerate(one_hot_string) if i == "1"]
41
+ return list(optimize_class), class_int_string
dataset/condition_classinput_vittiny/detail.py ADDED
@@ -0,0 +1,58 @@
1
+ import os
2
+ import sys
3
+ if __name__ == "__main__":
4
+ from train import *
5
+ else: # relative import
6
+ from .train import *
7
+ from torchvision.datasets import CIFAR10
8
+ from torchvision import transforms
9
+
10
+
11
+
12
+
13
+ try:
14
+ test_item = sys.argv[1]
15
+ except IndexError:
16
+ assert __name__ == "__main__"
17
+ test_item = "./generated"
18
+ test_items = []
19
+ if os.path.isdir(test_item):
20
+ for item in os.listdir(test_item):
21
+ item = os.path.join(test_item, item)
22
+ test_items.append(item)
23
+ elif os.path.isfile(test_item):
24
+ test_items.append(test_item)
25
+
26
+
27
+
28
+
29
+ original_dataset = CIFAR10(
30
+ root=config["dataset_root"],
31
+ train=False,
32
+ download=True,
33
+ transform=transforms.Compose([
34
+ transforms.Resize(224),
35
+ transforms.ToTensor(),
36
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
37
+ ])
38
+ )
39
+ original_targets = [original_dataset[i][1] for i in range(len(original_dataset))]
40
+ original_targets = torch.tensor(original_targets, dtype=torch.long)
41
+
42
+
43
+
44
+
45
+ for item in test_items:
46
+ state = torch.load(item, map_location="cpu")
47
+ model.load_state_dict({key: value.to(torch.float32).to(device) for key, value in state.items()})
48
+ loss, acc, all_targets, all_predicts = test(model=model)
49
+ all_targets, all_predicts = torch.tensor(all_targets), torch.tensor(all_predicts)
50
+
51
+ for class_idx in range(10):
52
+ class_mask = torch.where(original_targets == class_idx, 1, 0)
53
+ total_number = torch.sum(class_mask)
54
+ correct = torch.where(all_targets == all_predicts, 1, 0)
55
+ class_correct = class_mask * correct
56
+ correct_number = torch.sum(class_correct)
57
+ class_acc = correct_number.item() / total_number.item()
58
+ print(f"class{class_idx}:", class_acc)
dataset/condition_classinput_vittiny/finetune.py ADDED
@@ -0,0 +1,215 @@
1
+ # set global seed
2
+ import time
3
+ print("time stamp:", time.time())
4
+ import random
5
+ import numpy as np
6
+ import torch
7
+ seed = SEED = 20
8
+ torch.manual_seed(seed)
9
+ torch.cuda.manual_seed(seed)
10
+ torch.cuda.manual_seed_all(seed)
11
+ torch.backends.cudnn.deterministic = True
12
+ torch.backends.cudnn.benchmark = True
13
+ np.random.seed(seed)
14
+ random.seed(seed)
15
+
16
+
17
+ try: # relative import
18
+ from model import Model
19
+ from dataset import BinaryClassifierDataset as Dataset
20
+ from dataset import get_optimize_class
21
+ except ImportError:
22
+ from .model import Model
23
+ from .dataset import BinaryClassifierDataset as Dataset
24
+ from .dataset import get_optimize_class
25
+
26
+ # import
27
+ import torch.nn as nn
28
+ from torch import optim
29
+ from torch.optim import lr_scheduler
30
+ from torch.utils.data import DataLoader
31
+ from torch.nn import functional as F
32
+ import os
33
+ import sys
34
+ import warnings
35
+ warnings.filterwarnings("ignore", category=UserWarning)
36
+
37
+ # load additional config
38
+ import json
39
+ config_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), "config.json")
40
+ with open(config_file, "r") as f:
41
+ additional_config = json.load(f)
42
+
43
+
44
+
45
+
46
+ # config
47
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48
+ config = {
49
+ "dataset_root": "from_additional_config",
50
+ "batch_size": 500 if __name__ == "__main__" else 50,
51
+ "num_workers": 16,
52
+ "pre_learning_rate": 0.01,
53
+ "learning_rate": 2e-5,
54
+ "pre_epochs": 0,
55
+ "epochs": 50,
56
+ "weight_decay": 0.1,
57
+ "save_learning_rate": 1e-6,
58
+ "total_save_number": 5,
59
+ "tag": os.path.basename(os.path.dirname(__file__)),
60
+ "optimize_class": get_optimize_class()[0],
61
+ "optimize_class_int": get_optimize_class()[1],
62
+ }
63
+ config.update(additional_config)
64
+ print("Training:", config["optimize_class"])
65
+
66
+
67
+
68
+
69
+ # Data
70
+ dataset = Dataset(
71
+ root=config["dataset_root"],
72
+ train=True,
73
+ optimize_class=config["optimize_class"],
74
+ )
75
+ train_loader = DataLoader(
76
+ dataset=dataset,
77
+ batch_size=config["batch_size"],
78
+ num_workers=config["num_workers"],
79
+ shuffle=True,
80
+ drop_last=True,
81
+ pin_memory=True,
82
+ persistent_workers=True,
83
+ )
84
+ test_loader = DataLoader(
85
+ dataset=Dataset(
86
+ root=config["dataset_root"],
87
+ train=False,
88
+ optimize_class=config["optimize_class"],
89
+ ),
90
+ batch_size=config["batch_size"],
91
+ num_workers=config["num_workers"],
92
+ shuffle=False,
93
+ )
94
+
95
+ # Model
96
+ model, head = Model()
97
+ model.load_state_dict(torch.load(sys.argv[1], map_location="cpu", weights_only=True))
98
+ model = model.to(device)
99
+ class FocalLoss(nn.Module):
100
+ def __init__(self, weight=None, gamma=2):
101
+ super(FocalLoss, self).__init__()
102
+ self.weight = weight
103
+ self.gamma = gamma
104
+ def forward(self, input, target):
105
+ ce_loss = F.cross_entropy(input, target, reduction='none', weight=self.weight)
106
+ pt = torch.exp(-ce_loss)
107
+ focal_loss = (1 - pt) ** self.gamma * ce_loss
108
+ return focal_loss.mean()
109
+ criterion = FocalLoss()
110
+
111
+ # Optimizer
112
+ head_optimizer = optim.AdamW(
113
+ head.parameters(),
114
+ lr=config["pre_learning_rate"],
115
+ weight_decay=config["weight_decay"],
116
+ )
117
+ optimizer = optim.AdamW(
118
+ model.parameters(),
119
+ lr=config["learning_rate"],
120
+ weight_decay=config["weight_decay"],
121
+ )
122
+ scheduler = lr_scheduler.CosineAnnealingLR(
123
+ optimizer,
124
+ T_max=config["epochs"],
125
+ eta_min=config["save_learning_rate"],
126
+ )
127
+
128
+
129
+
130
+
131
+ # Training
132
+ def train(model=model, optimizer=optimizer, scheduler=scheduler):
133
+ model.train()
134
+ for batch_idx, (inputs, targets) in enumerate(train_loader):
135
+ inputs, targets = inputs.to(device), targets.to(device)
136
+ optimizer.zero_grad()
137
+ with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
138
+ outputs = model(inputs)
139
+ loss = criterion(outputs, targets)
140
+ loss.backward()
141
+ optimizer.step()
142
+ if scheduler is not None:
143
+ scheduler.step()
144
+
145
+ # test
146
+ @torch.no_grad()
147
+ def test(model=model):
148
+ model.eval()
149
+ all_targets = []
150
+ all_predicts = []
151
+ test_loss = 0
152
+ correct = 0
153
+ total = 0
154
+ for batch_idx, (inputs, targets) in enumerate(test_loader):
155
+ inputs, targets = inputs.to(device), targets.to(device)
156
+ with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
157
+ outputs = model(inputs)
158
+ loss = criterion(outputs, targets)
159
+ # to logging losses
160
+ all_targets.extend(targets.flatten().tolist())
161
+ test_loss += loss.item()
162
+ _, predicts = outputs.max(1)
163
+ all_predicts.extend(predicts.flatten().tolist())
164
+ total += targets.size(0)
165
+ correct += predicts.eq(targets).sum().item()
166
+ loss = test_loss / (batch_idx + 1)
167
+ acc = correct / total
168
+ print(f"Loss: {loss:.4f} | Acc: {acc:.4f}\n")
169
+ model.train()
170
+ return loss, acc, all_targets, all_predicts
171
+
172
+ # save train
173
+ def save_train(model=model, optimizer=optimizer):
174
+ data_loader = DataLoader(
175
+ dataset=dataset,
176
+ batch_size=min(len(dataset) // config["total_save_number"], config["batch_size"]),
177
+ num_workers=config["num_workers"],
178
+ shuffle=True,
179
+ drop_last=True,
180
+ )
181
+ model.train()
182
+ for batch_idx, (inputs, targets) in enumerate(data_loader):
183
+ inputs, targets = inputs.to(device), targets.to(device)
184
+ optimizer.zero_grad()
185
+ with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
186
+ outputs = model(inputs)
187
+ loss = criterion(outputs, targets)
188
+ loss.backward()
189
+ optimizer.step()
190
+ # Save checkpoint
191
+ # _, acc, _, _ = test(model=model)
192
+ acc = 1.0
193
+ if not os.path.isdir('checkpoint'):
194
+ os.mkdir('checkpoint')
195
+ save_state = {key: value.cpu().to(torch.float32) for key, value in model.state_dict().items()}
196
+ torch.save(save_state, f"checkpoint/{str(batch_idx).zfill(4)}_acc{acc:.4f}_class{config['optimize_class_int']}_{config['tag']}.pth")
197
+ print("save:", f"checkpoint/{str(batch_idx).zfill(4)}_acc{acc:.4f}_class{config['optimize_class_int']}_{config['tag']}.pth")
198
+ # exit loop
199
+ if batch_idx+1 == config["total_save_number"]:
200
+ break
201
+
202
+
203
+
204
+
205
+ # main
206
+ if __name__ == '__main__':
207
+ test(model=model)
208
+ for epoch in range(config["pre_epochs"]):
209
+ train(model=model, optimizer=head_optimizer, scheduler=None)
210
+ test(model=model)
211
+ for epoch in range(config["epochs"]):
212
+ train(model=model, optimizer=optimizer, scheduler=scheduler)
213
+ test(model=model)
214
+ # save_train(model=model, optimizer=optimizer)
215
+ print("time stamp:", time.time())
dataset/condition_classinput_vittiny/model.py ADDED
@@ -0,0 +1,25 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+
5
+
6
+ def Model():
7
+ model = timm.create_model("vit_tiny_patch16_224", pretrained=True)
8
+ model.head = nn.Sequential(
9
+ nn.Linear(192, 192, bias=True),
10
+ nn.SiLU(),
11
+ nn.Linear(192, 2, bias=False),
12
+ )
13
+ for param in model.head.parameters():
14
+ param.data = torch.ones_like(param) / 192  # initialize in place; rebinding the loop variable would leave the head unchanged
15
+ param.requires_grad = True
16
+ return model, model.head
17
+
18
+
19
+ if __name__ == "__main__":
20
+ model, _ = Model()
21
+ print(model)
22
+ num_param = 0
23
+ for v in model.parameters():
24
+ num_param += v.numel()
25
+ print("num_param:", num_param)
dataset/condition_classinput_vittiny/split.sh ADDED
@@ -0,0 +1,28 @@
1
+ mkdir checkpoint_test
2
+ mkdir checkpoint_train
3
+ mkdir generated
4
+
5
+ mv ./checkpoint/*class0314* ./checkpoint_test
6
+ mv ./checkpoint/*class0482* ./checkpoint_test
7
+ mv ./checkpoint/*class0589* ./checkpoint_test
8
+ mv ./checkpoint/*class0197* ./checkpoint_test
9
+ mv ./checkpoint/*class0462* ./checkpoint_test
10
+ mv ./checkpoint/*class0111* ./checkpoint_test
11
+ mv ./checkpoint/*class0101* ./checkpoint_test
12
+ mv ./checkpoint/*class0278* ./checkpoint_test
13
+ mv ./checkpoint/*class0793* ./checkpoint_test
14
+ mv ./checkpoint/*class0279* ./checkpoint_test
15
+ mv ./checkpoint/*class0653* ./checkpoint_test
16
+ mv ./checkpoint/*class0238* ./checkpoint_test
17
+ mv ./checkpoint/*class1001* ./checkpoint_test
18
+ mv ./checkpoint/*class0141* ./checkpoint_test
19
+ mv ./checkpoint/*class0884* ./checkpoint_test
20
+ mv ./checkpoint/*class0592* ./checkpoint_test
21
+ mv ./checkpoint/*class0502* ./checkpoint_test
22
+ mv ./checkpoint/*class0643* ./checkpoint_test
23
+ mv ./checkpoint/*class0383* ./checkpoint_test
24
+ mv ./checkpoint/*class0128* ./checkpoint_test
25
+
26
+ mv ./checkpoint/* ./checkpoint_train
27
+
28
+ rm checkpoint -r
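`split.sh` holds out twenty class-mask tags as `checkpoint_test` and keeps the rest in `checkpoint_train`, presumably so that generation can later be evaluated on binary tasks whose checkpoints were never seen during training. For reference, the held-out tags copied from the script above:

```python
# Class-mask tags moved to checkpoint_test by split.sh.
held_out_tags = [314, 482, 589, 197, 462, 111, 101, 278, 793, 279,
                 653, 238, 1001, 141, 884, 592, 502, 643, 383, 128]
assert len(held_out_tags) == 20
```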
dataset/condition_classinput_vittiny/test.py ADDED
@@ -0,0 +1,30 @@
1
+ import os
2
+ import sys
3
+ if __name__ == "__main__":
4
+ from train import *
5
+ else: # relative import
6
+ from .train import *
7
+
8
+
9
+
10
+
11
+ try:
12
+ test_item = sys.argv[1]
13
+ except IndexError:
14
+ assert __name__ == "__main__"
15
+ test_item = "./checkpoint_test"
16
+ test_items = []
17
+ if os.path.isdir(test_item):
18
+ for item in os.listdir(test_item):
19
+ item = os.path.join(test_item, item)
20
+ test_items.append(item)
21
+ elif os.path.isfile(test_item):
22
+ test_items.append(test_item)
23
+
24
+
25
+
26
+
27
+ for item in test_items:
28
+ state = torch.load(item, map_location="cpu")
29
+ model.load_state_dict({key: value.to(torch.float32).to(device) for key, value in state.items()})
30
+ loss, acc, all_targets, all_predicts = test(model=model)
dataset/condition_classinput_vittiny/train.py ADDED
@@ -0,0 +1,212 @@
1
+ # set global seed
2
+ import time
3
+ print("time stamp:", time.time())
4
+ import random
5
+ import numpy as np
6
+ import torch
7
+ seed = SEED = 20
8
+ torch.manual_seed(seed)
9
+ torch.cuda.manual_seed(seed)
10
+ torch.cuda.manual_seed_all(seed)
11
+ torch.backends.cudnn.deterministic = True
12
+ torch.backends.cudnn.benchmark = True
13
+ np.random.seed(seed)
14
+ random.seed(seed)
15
+
16
+
17
+ try: # relative import
18
+ from model import Model
19
+ from dataset import BinaryClassifierDataset as Dataset
20
+ from dataset import get_optimize_class
21
+ except ImportError:
22
+ from .model import Model
23
+ from .dataset import BinaryClassifierDataset as Dataset
24
+ from .dataset import get_optimize_class
25
+
26
+ # import
27
+ import torch.nn as nn
28
+ from torch import optim
29
+ from torch.optim import lr_scheduler
30
+ from torch.utils.data import DataLoader
31
+ from torch.nn import functional as F
32
+ import os
33
+ import sys
34
+ import warnings
35
+ warnings.filterwarnings("ignore", category=UserWarning)
36
+
37
+ # load additional config
38
+ import json
39
+ config_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), "config.json")
40
+ with open(config_file, "r") as f:
41
+ additional_config = json.load(f)
42
+
43
+
44
+
45
+
46
+ # config
47
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48
+ config = {
49
+ "dataset_root": "from_additional_config",
50
+ "batch_size": 500 if __name__ == "__main__" else 50,
51
+ "num_workers": 16,
52
+ "pre_learning_rate": 0.01,
53
+ "learning_rate": 1e-4,
54
+ "pre_epochs": 2,
55
+ "epochs": 13,
56
+ "weight_decay": 0.1,
57
+ "save_learning_rate": 2e-5,
58
+ "total_save_number": 5,
59
+ "tag": os.path.basename(os.path.dirname(__file__)),
60
+ "optimize_class": get_optimize_class()[0],
61
+ "optimize_class_int": get_optimize_class()[1],
62
+ }
63
+ config.update(additional_config)
64
+ print("Training:", config["optimize_class"])
65
+
66
+
67
+
68
+
69
+ # Data
70
+ dataset = Dataset(
71
+ root=config["dataset_root"],
72
+ train=True,
73
+ optimize_class=config["optimize_class"],
74
+ )
75
+ train_loader = DataLoader(
76
+ dataset=dataset,
77
+ batch_size=config["batch_size"],
78
+ num_workers=config["num_workers"],
79
+ shuffle=True,
80
+ drop_last=True,
81
+ pin_memory=True,
82
+ persistent_workers=True,
83
+ )
84
+ test_loader = DataLoader(
85
+ dataset=Dataset(
86
+ root=config["dataset_root"],
87
+ train=False,
88
+ optimize_class=config["optimize_class"],
89
+ ),
90
+ batch_size=config["batch_size"],
91
+ num_workers=config["num_workers"],
92
+ shuffle=False,
93
+ )
94
+
95
+ # Model
96
+ model, head = Model()
97
+ model = model.to(device)
98
+ class FocalLoss(nn.Module):
99
+ def __init__(self, weight=None, gamma=2):
100
+ super(FocalLoss, self).__init__()
101
+ self.weight = weight
102
+ self.gamma = gamma
103
+ def forward(self, input, target):
104
+ ce_loss = F.cross_entropy(input, target, reduction='none', weight=self.weight)
105
+ pt = torch.exp(-ce_loss)
106
+ focal_loss = (1 - pt) ** self.gamma * ce_loss
107
+ return focal_loss.mean()
108
+ criterion = FocalLoss()
109
+
110
+ # Optimizer
111
+ head_optimizer = optim.AdamW(
112
+ head.parameters(),
113
+ lr=config["pre_learning_rate"],
114
+ weight_decay=config["weight_decay"],
115
+ )
116
+ optimizer = optim.AdamW(
117
+ model.parameters(),
118
+ lr=config["learning_rate"],
119
+ weight_decay=config["weight_decay"],
120
+ )
121
+ scheduler = lr_scheduler.CosineAnnealingLR(
122
+ optimizer,
123
+ T_max=config["epochs"],
124
+ eta_min=config["save_learning_rate"],
125
+ )
126
+
127
+
128
+
129
+
130
+ # Training
131
+ def train(model=model, optimizer=optimizer, scheduler=scheduler):
132
+ model.train()
133
+ for batch_idx, (inputs, targets) in enumerate(train_loader):
134
+ inputs, targets = inputs.to(device), targets.to(device)
135
+ optimizer.zero_grad()
136
+ with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
137
+ outputs = model(inputs)
138
+ loss = criterion(outputs, targets)
139
+ loss.backward()
140
+ optimizer.step()
141
+ if scheduler is not None:
142
+ scheduler.step()
143
+
144
+ # test
145
+ @torch.no_grad()
146
+ def test(model=model):
147
+ model.eval()
148
+ all_targets = []
149
+ all_predicts = []
150
+ test_loss = 0
151
+ correct = 0
152
+ total = 0
153
+ for batch_idx, (inputs, targets) in enumerate(test_loader):
154
+ inputs, targets = inputs.to(device), targets.to(device)
155
+ with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
156
+ outputs = model(inputs)
157
+ loss = criterion(outputs, targets)
158
+ # to logging losses
159
+ all_targets.extend(targets.flatten().tolist())
160
+ test_loss += loss.item()
161
+ _, predicts = outputs.max(1)
162
+ all_predicts.extend(predicts.flatten().tolist())
163
+ total += targets.size(0)
164
+ correct += predicts.eq(targets).sum().item()
165
+ loss = test_loss / (batch_idx + 1)
166
+ acc = correct / total
167
+ print(f"Loss: {loss:.4f} | Acc: {acc:.4f}\n")
168
+ model.train()
169
+ return loss, acc, all_targets, all_predicts
170
+
171
+ # save train
172
+ def save_train(model=model, optimizer=optimizer):
173
+ data_loader = DataLoader(
174
+ dataset=dataset,
175
+ batch_size=min(len(dataset) // config["total_save_number"], config["batch_size"]),
176
+ num_workers=config["num_workers"],
177
+ shuffle=True,
178
+ drop_last=True,
179
+ )
180
+ model.train()
181
+ for batch_idx, (inputs, targets) in enumerate(data_loader):
182
+ inputs, targets = inputs.to(device), targets.to(device)
183
+ optimizer.zero_grad()
184
+ with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
185
+ outputs = model(inputs)
186
+ loss = criterion(outputs, targets)
187
+ loss.backward()
188
+ optimizer.step()
189
+ # Save checkpoint
190
+ _, acc, _, _ = test(model=model)
191
+ if not os.path.isdir('checkpoint'):
192
+ os.mkdir('checkpoint')
193
+ save_state = {key: value.cpu().to(torch.float32) for key, value in model.state_dict().items()}
194
+ torch.save(save_state, f"checkpoint/{str(batch_idx).zfill(4)}_acc{acc:.4f}_class{config['optimize_class_int']}_{config['tag']}.pth")
195
+ print("save:", f"checkpoint/{str(batch_idx).zfill(4)}_acc{acc:.4f}_class{config['optimize_class_int']}_{config['tag']}.pth")
196
+ # exit loop
197
+ if batch_idx+1 == config["total_save_number"]:
198
+ break
199
+
200
+
201
+
202
+
203
+ # main
204
+ if __name__ == '__main__':
205
+ for epoch in range(config["pre_epochs"]):
206
+ train(model=model, optimizer=head_optimizer, scheduler=None)
207
+ # test(model=model)
208
+ for epoch in range(config["epochs"]):
209
+ train(model=model, optimizer=optimizer, scheduler=scheduler)
210
+ # test(model=model)
211
+ save_train(model=model, optimizer=optimizer)
212
+ print("time stamp:", time.time())
dataset/condition_classinput_vittiny/train.sh ADDED
@@ -0,0 +1,10 @@
1
+ #!/bin/bash
2
+
3
+ start=1
4
+ end=1022
5
+
6
+ for i in $(seq $start $end)
7
+ do
8
+ python train.py class$i
9
+ sleep 1
10
+ done
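`train.sh` sweeps `class1` through `class1022`. Under the binary-mask encoding in `dataset.py`, that range covers every subset of the ten CIFAR-10 classes except the empty mask (0) and the all-classes mask (1023), i.e. 2^10 - 2 = 1022 distinct binary tasks:

```python
# Each i in [1, 1022] decodes to a distinct, non-trivial 10-bit class mask.
masks = [bin(i)[2:].zfill(10) for i in range(1, 1023)]
assert len(set(masks)) == 2 ** 10 - 2   # 1022 tasks
```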
dataset/condition_imageinput_vittiny/README.md ADDED
@@ -0,0 +1 @@
1
+ Code for condition_imageinput_vittiny is coming...
dataset/condition_imageinput_vittiny/dataset.py ADDED
@@ -0,0 +1,46 @@
1
+ import re
2
+ import sys
3
+ from torch.utils.data import Dataset
4
+ from torchvision.datasets import CIFAR10
5
+ import torchvision.transforms as transforms
6
+
7
+
8
+
9
+
10
+ class BinaryClassifierDataset(Dataset):
11
+ def __init__(self, root, train, optimize_class):
12
+ optimize_class = [optimize_class,] if isinstance(optimize_class, int) else optimize_class
13
+ self.optimize_class = optimize_class
14
+ self.dataset = CIFAR10(
15
+ root=root,
16
+ train=train,
17
+ download=True,
18
+ transform=transforms.Compose([
19
+ transforms.Resize(224),
20
+ transforms.RandomHorizontalFlip(),
21
+ transforms.AutoAugment(policy=transforms.AutoAugmentPolicy("cifar10")),
22
+ transforms.ToTensor(),
23
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
24
+ ])
25
+ )
26
+
27
+ def __getitem__(self, index):
28
+ img, origin_target = self.dataset[index]
29
+ target = 1 if origin_target in self.optimize_class else 0
30
+ return img, target
31
+
32
+ def __len__(self):
33
+ return self.dataset.__len__()
34
+
35
+
36
+
37
+
38
+ def get_optimize_class():
39
+ try: # get string
40
+ string = sys.argv[1]
41
+ except IndexError:
42
+ raise RuntimeError("sys.argv[1] not found")
43
+ class_int_string = str(re.search(r'class(\d+)', string).group(1)).zfill(4)
44
+ one_hot_string = bin(int(class_int_string))[2:].zfill(10)
45
+ optimize_class = [index for index, i in enumerate(one_hot_string) if i == "1"]
46
+ return list(optimize_class), class_int_string
dataset/condition_imageinput_vittiny/model.py ADDED
@@ -0,0 +1,18 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+
5
+
6
+ def Model():
7
+ model = timm.create_model("vit_tiny_patch16_224", pretrained=True)
8
+ model.head = nn.Linear(192, 2)
9
+ return model, model.head
10
+
11
+
12
+ if __name__ == "__main__":
13
+ model, _ = Model()
14
+ print(model)
15
+ num_param = 0
16
+ for v in model.parameters():
17
+ num_param += v.numel()
18
+ print("num_param:", num_param)
dataset/condition_imageinput_vittiny/test.py ADDED
@@ -0,0 +1,30 @@
1
+ import os
2
+ import sys
3
+ if __name__ == "__main__":
4
+ from train import *
5
+ else: # relative import
6
+ from .train import *
7
+
8
+
9
+
10
+
11
+ try:
12
+ test_item = sys.argv[1]
13
+ except IndexError:
14
+ assert __name__ == "__main__"
15
+ test_item = "./checkpoint"
16
+ test_items = []
17
+ if os.path.isdir(test_item):
18
+ for item in os.listdir(test_item):
19
+ item = os.path.join(test_item, item)
20
+ test_items.append(item)
21
+ elif os.path.isfile(test_item):
22
+ test_items.append(test_item)
23
+
24
+
25
+
26
+
27
+ for item in test_items:
28
+ state = torch.load(item, map_location="cpu")
29
+ model.load_state_dict({key: value.to(torch.float32).to(device) for key, value in state.items()})
30
+ loss, acc, all_targets, all_predicts = test(model=model)
dataset/condition_imageinput_vittiny/train.py ADDED
@@ -0,0 +1,208 @@
1
+ # set global seed
2
+ import random
3
+ import numpy as np
4
+ import torch
5
+ seed = SEED = 20
6
+ torch.manual_seed(seed)
7
+ torch.cuda.manual_seed(seed)
8
+ torch.cuda.manual_seed_all(seed)
9
+ torch.backends.cudnn.deterministic = True
10
+ torch.backends.cudnn.benchmark = True
11
+ np.random.seed(seed)
12
+ random.seed(seed)
13
+
14
+ try: # relative import
15
+ from model import Model
16
+ from dataset import BinaryClassifierDataset as Dataset
17
+ from dataset import get_optimize_class
18
+ except ImportError:
19
+ from .model import Model
20
+ from .dataset import BinaryClassifierDataset as Dataset
21
+ from .dataset import get_optimize_class
22
+
23
+ # import
24
+ import torch.nn as nn
25
+ from torch import optim
26
+ from torch.optim import lr_scheduler
27
+ from torch.utils.data import DataLoader
28
+ from torch.nn import functional as F
29
+ import os
30
+ import sys
31
+ import warnings
32
+ warnings.filterwarnings("ignore", category=UserWarning)
33
+
34
+ # load additional config
35
+ import json
36
+ config_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), "config.json")
37
+ with open(config_file, "r") as f:
38
+ additional_config = json.load(f)
39
+
40
+
41
+
42
+
43
+ # config
44
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
+ config = {
46
+ "dataset_root": "from_additional_config",
47
+ "batch_size": 250 if __name__ == "__main__" else 50,
48
+ "num_workers": 20,
49
+ "pre_learning_rate": 0.01,
50
+ "learning_rate": 3e-5,
51
+ "pre_epochs": 2,
52
+ "epochs": 13,
53
+ "weight_decay": 0.1,
54
+ "save_learning_rate": 1e-5,
55
+ "total_save_number": 10,
56
+ "tag": os.path.basename(os.path.dirname(__file__)),
57
+ "optimize_class": get_optimize_class()[0],
58
+ "optimize_class_int": get_optimize_class()[1],
59
+ }
60
+ config.update(additional_config)
61
+ print("Training:", config["optimize_class"])
62
+
63
+
64
+
65
+
66
+ # Data
67
+ dataset = Dataset(
68
+ root=config["dataset_root"],
69
+ train=True,
70
+ optimize_class=config["optimize_class"],
71
+ )
72
+ train_loader = DataLoader(
73
+ dataset=dataset,
74
+ batch_size=config["batch_size"],
75
+ num_workers=config["num_workers"],
76
+ shuffle=True,
77
+ drop_last=True,
78
+ pin_memory=True,
79
+ persistent_workers=True,
80
+ )
81
+ test_loader = DataLoader(
82
+ dataset=Dataset(
83
+ root=config["dataset_root"],
84
+ train=False,
85
+ optimize_class=config["optimize_class"],
86
+ ),
87
+ batch_size=config["batch_size"],
88
+ num_workers=config["num_workers"],
89
+ shuffle=False,
90
+ )
91
+
92
+ # Model
93
+ model, head = Model()
94
+ model = model.to(device)
95
+ class FocalLoss(nn.Module):
96
+ def __init__(self, weight=None, gamma=2):
97
+ super(FocalLoss, self).__init__()
98
+ self.weight = weight
99
+ self.gamma = gamma
100
+ def forward(self, input, target):
101
+ ce_loss = F.cross_entropy(input, target, reduction='none', weight=self.weight)
102
+ pt = torch.exp(-ce_loss)
103
+ focal_loss = (1 - pt) ** self.gamma * ce_loss
104
+ return focal_loss.mean()
105
+ criterion = FocalLoss()
106
+
107
+ # Optimizer
108
+ head_optimizer = optim.AdamW(
109
+ head.parameters(),
110
+ lr=config["pre_learning_rate"],
111
+ weight_decay=config["weight_decay"],
112
+ )
113
+ optimizer = optim.AdamW(
114
+ model.parameters(),
115
+ lr=config["learning_rate"],
116
+ weight_decay=config["weight_decay"],
117
+ )
118
+ scheduler = lr_scheduler.CosineAnnealingLR(
119
+ optimizer,
120
+ T_max=config["epochs"],
121
+ eta_min=config["save_learning_rate"],
122
+ )
123
+
124
+
125
+
126
+
127
+ # Training
128
+ def train(model=model, optimizer=optimizer, scheduler=scheduler):
129
+ model.train()
130
+ for batch_idx, (inputs, targets) in enumerate(train_loader):
131
+ inputs, targets = inputs.to(device), targets.to(device)
132
+ optimizer.zero_grad()
133
+ with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
134
+ outputs = model(inputs)
135
+ loss = criterion(outputs, targets)
136
+ loss.backward()
137
+ optimizer.step()
138
+ if scheduler is not None:
139
+ scheduler.step()
140
+
141
+ # test
142
+ @torch.no_grad()
143
+ def test(model=model):
144
+ model.eval()
145
+ all_targets = []
146
+ all_predicts = []
147
+ test_loss = 0
148
+ correct = 0
149
+ total = 0
150
+ for batch_idx, (inputs, targets) in enumerate(test_loader):
151
+ inputs, targets = inputs.to(device), targets.to(device)
152
+ with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
153
+ outputs = model(inputs)
154
+ loss = criterion(outputs, targets)
155
+ # to logging losses
156
+ all_targets.extend(targets.flatten().tolist())
157
+ test_loss += loss.item()
158
+ _, predicts = outputs.max(1)
159
+ all_predicts.extend(predicts.flatten().tolist())
160
+ total += targets.size(0)
161
+ correct += predicts.eq(targets).sum().item()
162
+ loss = test_loss / (batch_idx + 1)
163
+ acc = correct / total
164
+ print(f"Loss: {loss:.4f} | Acc: {acc:.4f}\n")
165
+ model.train()
166
+ return loss, acc, all_targets, all_predicts
167
+
168
+ # save train
169
+ def save_train(model=model, optimizer=optimizer):
170
+ data_loader = DataLoader(
171
+ dataset=dataset,
172
+ batch_size=min(len(dataset) // config["total_save_number"], config["batch_size"]),
173
+ num_workers=config["num_workers"],
174
+ shuffle=True,
175
+ drop_last=True,
176
+ )
177
+ model.train()
178
+ for batch_idx, (inputs, targets) in enumerate(data_loader):
179
+ inputs, targets = inputs.to(device), targets.to(device)
180
+ optimizer.zero_grad()
181
+ with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
182
+ outputs = model(inputs)
183
+ loss = criterion(outputs, targets)
184
+ loss.backward()
185
+ optimizer.step()
186
+ # Save checkpoint
187
+ _, acc, _, _ = test(model=model)
188
+ if not os.path.isdir('checkpoint'):
189
+ os.mkdir('checkpoint')
190
+ save_state = {key: value.cpu().to(torch.float32) for key, value in model.state_dict().items()}
191
+ torch.save(save_state, f"checkpoint/{str(batch_idx).zfill(4)}_acc{acc:.4f}_class{config['optimize_class_int']}_{config['tag']}.pth")
192
+ print("save:", f"checkpoint/{str(batch_idx).zfill(4)}_acc{acc:.4f}_class{config['optimize_class_int']}_{config['tag']}.pth")
193
+ # exit loop
194
+ if batch_idx+1 == config["total_save_number"]:
195
+ break
196
+
197
+
198
+
199
+
200
+ # main
201
+ if __name__ == '__main__':
202
+ for epoch in range(config["pre_epochs"]):
203
+ train(model=model, optimizer=head_optimizer, scheduler=None)
204
+ test(model=model)
205
+ for epoch in range(config["epochs"]):
206
+ train(model=model, optimizer=optimizer, scheduler=scheduler)
207
+ test(model=model)
208
+ save_train(model=model, optimizer=optimizer)
dataset/condition_imageinput_vittiny/train.sh ADDED
@@ -0,0 +1,11 @@
1
+ #!/bin/bash
2
+
3
+ start=0
4
+ end=9
5
+
6
+ for i in $(seq $start $end)
7
+ do
8
+ power=$((2**i))
9
+ CUDA_VISIBLE_DEVICES=5 python train.py class$power
10
+ sleep 1
11
+ done
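This sweep passes `class$((2**i))` for i = 0..9, so every run receives a mask with exactly one bit set: each checkpoint is trained to separate a single CIFAR-10 class from the rest. A short check, assuming the same decoding as `dataset.py`:

```python
# Powers of two decode to single-class masks.
for i in range(10):
    mask = bin(2 ** i)[2:].zfill(10)
    positive = [j for j, bit in enumerate(mask) if bit == "1"]
    assert len(positive) == 1            # e.g. class0512 -> [0], class0001 -> [9]
```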
dataset/condition_permutation_vittiny/model.py ADDED
@@ -0,0 +1,18 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+
5
+
6
+ def Model():
7
+ model = timm.create_model("vit_tiny_patch16_224", pretrained=False)
8
+ model.head = nn.Linear(192, 10)
9
+ return model, model.head
10
+
11
+
12
+ if __name__ == "__main__":
13
+ model, _ = Model()
14
+ print(model)
15
+ num_param = 0
16
+ for v in model.parameters():
17
+ num_param += v.numel()
18
+ print("num_param:", num_param)
dataset/condition_permutation_vittiny/test.py ADDED
@@ -0,0 +1,31 @@
1
+ import os
2
+ import sys
3
+ if __name__ == "__main__":
4
+ from train import *
5
+ else: # relative import
6
+ from .train import *
7
+
8
+
9
+
10
+
11
+ try:
12
+ test_item = sys.argv[1]
13
+ except IndexError:
14
+ assert __name__ == "__main__"
15
+ test_item = "./checkpoint"
16
+ test_items = []
17
+ if os.path.isdir(test_item):
18
+ for item in os.listdir(test_item):
19
+ item = os.path.join(test_item, item)
20
+ test_items.append(item)
21
+ elif os.path.isfile(test_item):
22
+ test_items.append(test_item)
23
+
24
+
25
+
26
+
27
+ for item in test_items:
28
+ print(f"testing: {item}")
29
+ state = torch.load(item, map_location="cpu")
30
+ model.load_state_dict({key: value.to(torch.float32).to(device) for key, value in state.items()})
31
+ loss, acc, all_targets, all_predicts = test(model=model)
dataset/condition_permutation_vittiny/train.py ADDED
@@ -0,0 +1,210 @@
1
+ # set global seed
2
+ import time
3
+ print("time stamp:", time.time())
4
+ import random
5
+ import numpy as np
6
+ import torch
7
+ import re
8
+ import sys
9
+ if __name__ == "__main__":
10
+ def get_permutation_state():
11
+ try: # get string
12
+ string = sys.argv[1]
13
+ except IndexError:
14
+ raise RuntimeError("sys.argv[1] not found")
15
+ class_int_string = str(re.search(r'class(\d+)', string).group(1)).zfill(4)
16
+ return int(class_int_string)
17
+ seed = SEED = get_permutation_state()
18
+ else: # when testing
19
+ seed = SEED = 0
20
+ torch.manual_seed(seed)
21
+ torch.cuda.manual_seed(seed)
22
+ torch.cuda.manual_seed_all(seed)
23
+ torch.backends.cudnn.deterministic = True
24
+ torch.backends.cudnn.benchmark = True
25
+ np.random.seed(seed)
26
+ random.seed(seed)
27
+ print("Seed:", SEED)
28
+
29
+ try: # relative import
30
+ from model import Model
31
+ except ImportError:
32
+ from .model import Model
33
+
34
+ # import
35
+ import torch.nn as nn
36
+ from torch import optim
37
+ from torch.optim import lr_scheduler
38
+ from torch.utils.data import DataLoader
39
+ from torchvision.datasets import CIFAR10 as Dataset
40
+ from torchvision import transforms
41
+ from torch.nn import functional as F
42
+ import warnings
43
+ warnings.filterwarnings("ignore", category=UserWarning)
44
+
45
+ # load additional config
46
+ import os
47
+ import json
48
+ config_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), "config.json")
49
+ with open(config_file, "r") as f:
50
+ additional_config = json.load(f)
51
+
52
+
53
+
54
+
55
+ # config
56
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
+ config = {
58
+ "dataset_root": "from_additional_config",
59
+ "batch_size": 250 if __name__ == "__main__" else 50,
60
+ "num_workers": 16,
61
+ "learning_rate": 5e-3,
62
+ "epochs": 200,
63
+ "weight_decay": 0.1,
64
+ "save_learning_rate": 2e-5,
65
+ "total_save_number": 5,
66
+ "tag": os.path.basename(os.path.dirname(__file__)),
67
+ }
68
+ config.update(additional_config)
69
+
70
+
71
+
72
+
73
+ # Data
74
+ dataset = Dataset(
75
+ root=config["dataset_root"],
76
+ train=True,
77
+ download=True,
78
+ transform=transforms.Compose([
79
+ transforms.Resize(224),
80
+ transforms.RandomCrop(224, padding=32),
81
+ transforms.RandomHorizontalFlip(),
82
+ transforms.AutoAugment(policy=transforms.AutoAugmentPolicy("cifar10")),
83
+ transforms.ToTensor(),
84
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)),
85
+ ])
86
+ )
87
+ train_loader = DataLoader(
88
+ dataset=dataset,
89
+ batch_size=config["batch_size"],
90
+ num_workers=config["num_workers"],
91
+ shuffle=True,
92
+ drop_last=True,
93
+ pin_memory=True,
94
+ persistent_workers=True,
95
+ )
96
+ test_loader = DataLoader(
97
+ dataset=Dataset(
98
+ root=config["dataset_root"],
99
+ train=False,
100
+ download=True,
101
+ transform=transforms.Compose([
102
+ transforms.Resize(224),
103
+ transforms.ToTensor(),
104
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)),
105
+ ])),
106
+ batch_size=config["batch_size"],
107
+ num_workers=config["num_workers"],
108
+ shuffle=False,
109
+ )
110
+
111
+ # Model
112
+ model, head = Model()
113
+ model = model.to(device)
114
+ criterion = nn.CrossEntropyLoss()
115
+
116
+ # Optimizer
117
+ optimizer = optim.AdamW(
118
+ model.parameters(),
119
+ lr=config["learning_rate"],
120
+ weight_decay=config["weight_decay"],
121
+ )
122
+ scheduler = lr_scheduler.CosineAnnealingLR(
123
+ optimizer,
124
+ T_max=config["epochs"],
125
+ eta_min=config["save_learning_rate"],
126
+ )
127
+
128
+
129
+
130
+
131
+ # Training
132
+ def train(model=model, optimizer=optimizer, scheduler=scheduler):
133
+ model.train()
134
+ for batch_idx, (inputs, targets) in enumerate(train_loader):
135
+ inputs, targets = inputs.to(device), targets.to(device)
136
+ optimizer.zero_grad()
137
+ with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
138
+ outputs = model(inputs)
139
+ loss = criterion(outputs, targets)
140
+ loss.backward()
141
+ optimizer.step()
142
+ if scheduler is not None:
143
+ scheduler.step()
144
+
145
+ # test
146
+ @torch.no_grad()
147
+ def test(model=model):
148
+ model.eval()
149
+ all_targets = []
150
+ all_predicts = []
151
+ test_loss = 0
152
+ correct = 0
153
+ total = 0
154
+ for batch_idx, (inputs, targets) in enumerate(test_loader):
155
+ inputs, targets = inputs.to(device), targets.to(device)
156
+ with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
157
+ outputs = model(inputs)
158
+ loss = criterion(outputs, targets)
159
+ # to logging losses
160
+ all_targets.extend(targets.flatten().tolist())
161
+ test_loss += loss.item()
162
+ _, predicts = outputs.max(1)
163
+ all_predicts.extend(predicts.flatten().tolist())
164
+ total += targets.size(0)
165
+ correct += predicts.eq(targets).sum().item()
166
+ loss = test_loss / (batch_idx + 1)
167
+ acc = correct / total
168
+ print(f"Loss: {loss:.4f} | Acc: {acc:.4f}\n")
169
+ model.train()
170
+ return loss, acc, all_targets, all_predicts
171
+
172
+ # save train
173
+ def save_train(model=model, optimizer=optimizer):
174
+ data_loader = DataLoader(
175
+ dataset=dataset,
176
+ batch_size=min(len(dataset) // config["total_save_number"], config["batch_size"]),
177
+ num_workers=config["num_workers"],
178
+ shuffle=True,
179
+ drop_last=True,
180
+ )
181
+ model.train()
182
+ for batch_idx, (inputs, targets) in enumerate(data_loader):
183
+ inputs, targets = inputs.to(device), targets.to(device)
184
+ optimizer.zero_grad()
185
+ with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
186
+ outputs = model(inputs)
187
+ loss = criterion(outputs, targets)
188
+ loss.backward()
189
+ optimizer.step()
190
+ # Save checkpoint
191
+ _, acc, _, _ = test(model=model)
192
+ if not os.path.isdir('checkpoint'):
193
+ os.mkdir('checkpoint')
194
+ save_state = {key: value.cpu().to(torch.float32) for key, value in model.state_dict().items()}
195
+ torch.save(save_state, f"checkpoint/{str(batch_idx).zfill(4)}_acc{acc:.4f}_class{SEED:04d}_{config['tag']}.pth")
196
+ print("save:", f"checkpoint/{str(batch_idx).zfill(4)}_acc{acc:.4f}_class{SEED:04d}_{config['tag']}.pth")
197
+ # exit loop
198
+ if batch_idx+1 == config["total_save_number"]:
199
+ break
200
+
201
+
202
+
203
+
204
+ # main
205
+ if __name__ == '__main__':
206
+ for epoch in range(config["epochs"]):
207
+ train(model=model, optimizer=optimizer, scheduler=scheduler)
208
+ test(model=model)
209
+ save_train(model=model, optimizer=optimizer)
210
+ print("time stamp:", time.time())
dataset/condition_permutation_vittiny/train.sh ADDED
@@ -0,0 +1,10 @@
+ #!/bin/bash
+
+ start=0
+ end=19
+
+ for i in $(seq $start $end)
+ do
+     python train.py class$i
+     sleep 1
+ done
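Each loop iteration launches `train.py` with an argument of the form `class$i`; judging from the checkpoint names produced above (`class{SEED:04d}`), this argument appears to select the seed for that run, so the script sweeps seeds 0 through 19.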
dataset/config.json ADDED
@@ -0,0 +1 @@
+ {"dataset_root": "path_to_your_dataset", "imagenet_root": {"train": null, "test": null}, "dora_root": "/home/wangkai/arpgen/DoRA/commonsense_reasoning", "dora_env_name": "dora_llama"}
dataset/dataset.py ADDED
@@ -0,0 +1,327 @@
+ import torch
+ import einops
+ from torch.utils.data import Dataset
+ from torchvision.datasets import CIFAR10
+ from torchvision import transforms
+ import os
+ import math
+ import random
+ import json
+ from abc import ABC
+ import pickle
+
+
+
+
+ def pad_to_length(x, common_factor, **config):
+     if x.numel() % common_factor == 0:
+         return x.flatten()
+     # print(f"padding {x.shape} according to {common_factor}")
+     full_length = (x.numel() // common_factor + 1) * common_factor
+     padding_length = full_length - len(x.flatten())
+     padding = torch.full([padding_length, ], dtype=x.dtype, device=x.device, fill_value=config["fill_value"])
+     x = torch.cat((x.flatten(), padding), dim=0)
+     return x
+
+ def layer_to_token(x, common_factor, **config):
+     if config["granularity"] == 2:  # split by output
+         if x.numel() <= common_factor:
+             return pad_to_length(x.flatten(), common_factor, **config)[None]
+         dim2 = x[0].numel()
+         dim1 = x.shape[0]
+         if dim2 <= common_factor:
+             i = int(dim1 / (common_factor / dim2))
+             while True:
+                 if dim1 % i == 0 and dim2 * (dim1 // i) <= common_factor:
+                     output = x.view(-1, dim2 * (dim1 // i))
+                     output = [pad_to_length(item, common_factor, **config) for item in output]
+                     return torch.stack(output, dim=0)
+                 i += 1
+         else:  # dim2 > common_factor
+             output = [layer_to_token(item, common_factor, **config) for item in x]
+             return torch.cat(output, dim=0)
+     elif config["granularity"] == 1:  # split by layer
+         return pad_to_length(x.flatten(), common_factor, **config).view(-1, common_factor)
+     elif config["granularity"] == 0:  # flatten directly
+         return x.flatten()
+     else:  # NotImplementedError
+         raise NotImplementedError("granularity: 0: flatten directly, 1: split by layer, 2: split by output dim")
+
+
+ def token_to_layer(tokens, shape, **config):
+     common_factor = tokens.shape[-1]
+     if config["granularity"] == 2:  # split by output
+         num_element = math.prod(shape)
+         if num_element <= common_factor:
+             param = tokens[0][:num_element].view(shape)
+             tokens = tokens[1:]
+             return param, tokens
+         dim2 = num_element // shape[0]
+         dim1 = shape[0]
+         if dim2 <= common_factor:
+             i = int(dim1 / (common_factor / dim2))
+             while True:
+                 if dim1 % i == 0 and dim2 * (dim1 // i) <= common_factor:
+                     item_per_token = dim2 * (dim1 // i)
+                     length = num_element // item_per_token
+                     output = [item[:item_per_token] for item in tokens[:length]]
+                     param = torch.cat(output, dim=0).view(shape)
+                     tokens = tokens[length:]
+                     return param, tokens
+                 i += 1
+         else:  # dim2 > common_factor
+             output = []
+             for i in range(shape[0]):
+                 param, tokens = token_to_layer(tokens, shape[1:], **config)
+                 output.append(param.flatten())
+             param = torch.cat(output, dim=0).view(shape)
+             return param, tokens
+     elif config["granularity"] == 1:  # split by layer
+         num_element = math.prod(shape)
+         token_num = num_element // common_factor if num_element % common_factor == 0 \
+                 else num_element // common_factor + 1
+         param = tokens.flatten()[:num_element].view(shape)
+         tokens = tokens[token_num:]
+         return param, tokens
+     elif config["granularity"] == 0:  # flatten directly
+         num_element = math.prod(shape)
+         param = tokens.flatten()[:num_element].view(shape)
+         tokens = pad_to_length(tokens.flatten()[num_element:],
+                                common_factor, fill_value=torch.nan).view(-1, common_factor)
+         return param, tokens
+     else:  # NotImplementedError
+         raise NotImplementedError("granularity: 0: flatten directly, 1: split by layer, 2: split by output dim")
+
+
+ def positional_embedding_2d(dim1, dim2, d_model):
+     assert d_model % 4 == 0, f"Cannot use sin/cos positional encoding with odd dimension {d_model}"
+     pe = torch.zeros(d_model, dim1, dim2)
+     d_model = int(d_model / 2)  # Each dimension use half of d_model
+     div_term = torch.exp(torch.arange(0., d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / d_model))
+     pos_w = torch.arange(0., dim2).unsqueeze(1)
+     pos_h = torch.arange(0., dim1).unsqueeze(1)
+     pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, dim1, 1)
+     pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, dim1, 1)
+     pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, dim2)
+     pe[d_model+1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, dim2)
+     return pe.permute(1, 2, 0)
+
+
+ def positional_embedding_1d(dim1, d_model):
+     pe = torch.zeros(dim1, d_model)
+     position = torch.arange(0, dim1, dtype=torch.float).unsqueeze(1)
+     div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+     pe[:, 0::2] = torch.sin(position * div_term)
+     pe[:, 1::2] = torch.cos(position * div_term)
+     return pe
+
+
+
+
+ class BaseDataset(Dataset, ABC):
+     data_path = None
+     generated_path = None
+     test_command = None
+     config = {
+         "fill_value": torch.nan,
+         "granularity": 1,  # 0: flatten directly, 1: split by layer, 2: split by output
+         "pe_granularity": 2,  # 0: no embedding, 1: 1d embedding, 2: 2d embedding
+     }
+
+     def __init__(self, checkpoint_path=None, dim_per_token=8192, **kwargs):
+         if not os.path.exists(self.data_path):
+             os.makedirs(self.data_path, exist_ok=False)
+         if self.generated_path is not None and not os.path.exists(os.path.dirname(self.generated_path)):
+             os.makedirs(os.path.dirname(self.generated_path))
+         self.config.update(kwargs)
+         checkpoint_path = self.data_path if checkpoint_path is None else checkpoint_path
+         assert os.path.exists(checkpoint_path)
+         self.dim_per_token = dim_per_token
+         self.structure = None  # set in get_structure()
+         self.sequence_length = None  # set in get_structure()
+         # load checkpoint_list
+         checkpoint_list = os.listdir(checkpoint_path)
+         self.checkpoint_list = list([os.path.join(checkpoint_path, item) for item in checkpoint_list])
+         self.length = self.real_length = len(self.checkpoint_list)
+         self.set_infinite_dataset()
+         # get structure
+         structure_cache_file = os.path.join(os.path.dirname(self.data_path), "structure.cache")
+         try:  # try to load cache file
+             assert os.path.exists(structure_cache_file)
+             with open(structure_cache_file, "rb") as f:
+                 print(f"Loading cache from {structure_cache_file}")
+                 cache_file = pickle.load(f)
+             if len(self.checkpoint_list) != 0:
+                 assert set(cache_file["checkpoint_list"]) == set(self.checkpoint_list)
+                 self.structure = cache_file["structure"]
+             else:  # empty checkpoint_list, only generate
+                 print("Cannot find any trained checkpoint, loading cache file for generating!")
+                 self.structure = cache_file["structure"]
+                 fake_diction = {key: torch.zeros(item[0]) for key, item in self.structure.items()}
+                 torch.save(fake_diction, os.path.join(checkpoint_path, "fake_checkpoint.pth"))
+                 self.checkpoint_list.append(os.path.join(checkpoint_path, "fake_checkpoint.pth"))
+                 self.length = self.real_length = len(self.checkpoint_list)
+                 self.set_infinite_dataset()
+                 os.system(f"rm {os.path.join(checkpoint_path, 'fake_checkpoint.pth')}")
+         except AssertionError:  # recompute cache file
+             print("==> Organizing structure..")
+             self.structure = self.get_structure()
+             with open(structure_cache_file, "wb") as f:
+                 pickle.dump({"structure": self.structure, "checkpoint_list": self.checkpoint_list}, f)
+         # get sequence_length
+         self.sequence_length = self.get_sequence_length()
+
+     def get_sequence_length(self):
+         fake_diction = {key: torch.zeros(item[0]) for key, item in self.structure.items()}
+         # get sequence_length
+         param = self.preprocess(fake_diction)
+         self.sequence_length = param.size(0)
+         return self.sequence_length
+
+     def get_structure(self):
+         # get structure
+         checkpoint_list = self.checkpoint_list
+         structures = [{} for _ in range(len(checkpoint_list))]
+         for i, checkpoint in enumerate(checkpoint_list):
+             diction = torch.load(checkpoint, map_location="cpu")
+             for key, value in diction.items():
+                 if ("num_batches_tracked" in key) or (value.numel() == 1) or not torch.is_floating_point(value):
+                     structures[i][key] = (value.shape, value, None)
+                 elif "running_var" in key:
+                     pre_mean = value.mean() * 0.95
+                     value = torch.log(value / pre_mean + 0.05)
+                     structures[i][key] = (value.shape, pre_mean, value.mean(), value.std())
+                 else:  # conv & linear
+                     structures[i][key] = (value.shape, value.mean(), value.std())
+         final_structure = {}
+         structure_diction = torch.load(checkpoint_list[0], map_location="cpu")
+         for key, param in structure_diction.items():
+             if ("num_batches_tracked" in key) or (param.numel() == 1) or not torch.is_floating_point(param):
+                 final_structure[key] = (param.shape, param, None)
+             elif "running_var" in key:
+                 value = [param.shape, 0., 0., 0.]
+                 for structure in structures:
+                     for i in [1, 2, 3]:
+                         value[i] += structure[key][i]
+                 for i in [1, 2, 3]:
+                     value[i] /= len(structures)
+                 final_structure[key] = tuple(value)
+             else:  # conv & linear
+                 value = [param.shape, 0., 0.]
+                 for structure in structures:
+                     for i in [1, 2]:
+                         value[i] += structure[key][i]
+                 for i in [1, 2]:
+                     value[i] /= len(structures)
+                 final_structure[key] = tuple(value)
+         self.structure = final_structure
+         return self.structure
+
+     def set_infinite_dataset(self, max_num=None):
+         if max_num is None:
+             max_num = self.length * 1000000
+         self.length = max_num
+         return self
+
+     @property
+     def max_permutation_state(self):
+         return self.real_length
+
+     def get_position_embedding(self, positional_embedding_dim=None):
+         if positional_embedding_dim is None:
+             positional_embedding_dim = self.dim_per_token // 2
+         assert self.structure is not None, "run get_structure before get_position_embedding"
+         if self.config["pe_granularity"] == 2:
+             print("Use 2d positional embedding")
+             positional_embedding_index = []
+             for key, item in self.structure.items():
+                 if ("num_batches_tracked" in key) or (item[-1] is None):
+                     continue
+                 else:  # conv & linear
+                     shape, *_ = item
+                     fake_param = torch.ones(size=shape)
+                     fake_param = layer_to_token(fake_param, self.dim_per_token, **self.config)
+                     positional_embedding_index.append(list(range(fake_param.size(0))))
+             dim1 = len(positional_embedding_index)
+             dim2 = max([len(token_per_layer) for token_per_layer in positional_embedding_index])
+             full_pe = positional_embedding_2d(dim1, dim2, positional_embedding_dim)
+             positional_embedding = []
+             for layer_index, token_indexes in enumerate(positional_embedding_index):
+                 for token_index in token_indexes:
+                     this_pe = full_pe[layer_index, token_index]
+                     positional_embedding.append(this_pe)
+             positional_embedding = torch.stack(positional_embedding)
+             return positional_embedding
+         elif self.config["pe_granularity"] == 1:
+             print("Use 1d positional embedding")
+             return positional_embedding_1d(self.sequence_length, positional_embedding_dim)
+         elif self.config["pe_granularity"] == 0:
+             print("Not use positional embedding")
+             return torch.zeros_like(self.__getitem__(0))
+         else:  # NotImplementedError
+             raise NotImplementedError("pe_granularity: 0: no embedding, 1: 1d embedding, 2: 2d embedding")
+
+     def __len__(self):
+         return self.length
+
+     def __getitem__(self, index):
+         index = index % self.real_length
+         diction = torch.load(self.checkpoint_list[index], map_location="cpu")
+         param = self.preprocess(diction)
+         return param, index
+
+     def save_params(self, params, save_path):
+         diction = self.postprocess(params.cpu().to(torch.float32))
+         torch.save(diction, save_path)
+
+     def preprocess(self, diction: dict, **kwargs) -> torch.Tensor:
+         param_list = []
+         for key, value in diction.items():
+             if ("num_batches_tracked" in key) or (value.numel() == 1) or not torch.is_floating_point(value):
+                 continue
+             elif "running_var" in key:
+                 shape, pre_mean, mean, std = self.structure[key]
+                 value = torch.log(value / pre_mean + 0.05)
+             else:  # normal
+                 shape, mean, std = self.structure[key]
+             value = (value - mean) / std
+             value = layer_to_token(value, self.dim_per_token, **self.config)
+             param_list.append(value)
+         param = torch.cat(param_list, dim=0)
+         if self.config["granularity"] == 0:  # padding directly process tail
+             param = pad_to_length(param, self.dim_per_token, **self.config).view(-1, self.dim_per_token)
+         # print("Sequence length:", param.size(0))
+         return param.to(torch.float32)
+
+     def postprocess(self, params: torch.Tensor, **kwargs) -> dict:
+         diction = {}
+         params = params if len(params.shape) == 2 else params.squeeze(0)
+         for key, item in self.structure.items():
+             if ("num_batches_tracked" in key) or (item[-1] is None):
+                 shape, mean, std = item
+                 diction[key] = mean
+                 continue
+             elif "running_var" in key:
+                 shape, pre_mean, mean, std = item
+             else:  # conv & linear
+                 shape, mean, std = item
+             this_param, params = token_to_layer(params, shape, **self.config)
+             this_param = this_param * std + mean
+             if "running_var" in key:
+                 this_param = torch.clip(torch.exp(this_param) - 0.05, min=0.001) * pre_mean
+             diction[key] = this_param
+         return diction
+
+
+ class ConditionalDataset(BaseDataset, ABC):
+     def _extract_condition(self, index: int):
+         name = self.checkpoint_list[index]
+         condition_list = os.path.basename(name).split("_")
+         return condition_list
+
+     def __getitem__(self, index):
+         index = index % self.real_length
+         diction = torch.load(self.checkpoint_list[index], map_location="cpu")
+         condition = self._extract_condition(index)
+         param = self.preprocess(diction)
+         return param, condition
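The classes above are meant to be subclassed once per task, with `data_path` pointing at a folder of saved checkpoints. A minimal sketch (the class attributes, import path, and `dim_per_token` are illustrative, and it assumes `./checkpoint` already contains trained `.pth` files):

```python
# Tokenize saved checkpoints and recover a state dict from tokens.
from dataset.dataset import BaseDataset  # import path depends on where this runs


class CheckpointFolderDataset(BaseDataset):
    data_path = "./checkpoint"                 # folder of trained .pth files
    generated_path = "./generated/model.pth"
    test_command = "python test.py ./generated/model.pth"


ds = CheckpointFolderDataset(dim_per_token=8192)
tokens, index = ds[0]                          # shape: (sequence_length, dim_per_token)
state_dict = ds.postprocess(tokens)            # tokens back to named parameters
ds.save_params(tokens, "./generated/model.pth")
```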
dataset/downtask_detection/README.md ADDED
@@ -0,0 +1 @@
+ Code for segmentation is coming...
dataset/downtask_detection/test.sh ADDED
@@ -0,0 +1,11 @@
+ #!/usr/bin/env bash
+
+ source /path/to/miniconda3/bin/activate /path/to/miniconda3/envs/environment
+
+ CLUSTER=True \
+ DETECTRON2_DATASETS="/path/to/" \
+ PYTHONPATH="$(dirname $0)/Detection":$PYTHONPATH \
+ python $(dirname $0)/Detection/tools/lazyconfig_train_net.py --config-file $(dirname $0)/Detection/projects/ViTDet/configs/COCO/our_vit_b_100ep.py --finetune "VIT_BASE_IN21K" \
+     --num-gpus 1 \
+     --fulltune \
+     --eval-only "train.init_checkpoint='$1'"
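The script's single positional argument (`$1`) is the checkpoint to evaluate; it is injected as `train.init_checkpoint`, so a generated parameter file can be scored on COCO detection in eval-only mode without retraining.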
dataset/downtask_dora_r16/adapter_config.json ADDED
@@ -0,0 +1,23 @@
+ {
+     "Wdecompose_target_modules": null,
+     "base_model_name_or_path": "yahma/llama-7b-hf",
+     "bias": "none",
+     "dora_simple": true,
+     "enable_lora": null,
+     "fan_in_fan_out": false,
+     "inference_mode": true,
+     "lora_alpha": 32,
+     "lora_dropout": 0.05,
+     "merge_weights": false,
+     "modules_to_save": null,
+     "peft_type": "DORA",
+     "r": 16,
+     "target_modules": [
+         "q_proj",
+         "k_proj",
+         "v_proj",
+         "up_proj",
+         "down_proj"
+     ],
+     "task_type": "CAUSAL_LM"
+ }
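For reference, the adapter settings can be inspected programmatically; a small sketch (field semantics follow the DoRA codebase pointed to by `dora_root` in `dataset/config.json`):

```python
# Inspect the DoRA adapter configuration shipped for the rank-16 setting.
import json

with open("dataset/downtask_dora_r16/adapter_config.json") as f:
    adapter_cfg = json.load(f)

print(adapter_cfg["r"], adapter_cfg["lora_alpha"])  # rank 16, scaling alpha 32
print(adapter_cfg["target_modules"])                # projections that receive adapters
```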