jmemon committed on
Commit
92697e6
·
1 Parent(s): 43f4b92

Files: Epoch -1

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
__pycache__/config.cpython-310.pyc ADDED
Binary file (866 Bytes). View file
 
config.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+
4
@dataclass
class TrainingConfig:
    """Hyperparameters and Hub settings for the DDPM fine-tuning script.

    Every attribute carries a type annotation so that ``@dataclass`` registers
    it as a real field. Without annotations the decorator sees no fields at
    all: the values would only be plain class attributes, the generated
    ``__init__`` would accept no arguments, and ``repr`` would be empty.
    With them, any value can be overridden at construction time, e.g.
    ``TrainingConfig(num_epochs=5)``.
    """

    image_size: int = 128  # the generated image resolution
    train_batch_size: int = 4
    eval_batch_size: int = 4  # how many images to sample during evaluation
    num_epochs: int = 50
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    save_image_epochs: int = 1  # sample images every N epochs
    save_model_epochs: int = 3  # save/push the model every N epochs
    mixed_precision: str = 'fp16'  # `no` for float32, `fp16` for automatic mixed precision
    output_dir: str = 'ddpm-paintings-128-finetuned-cifar10'  # the model name locally and on the HF Hub

    push_to_hub: bool = True  # whether to upload the saved model to the HF Hub
    hub_model_id: str = 'jmemon/ddpm-paintings-128-finetuned-cifar10'  # the name of the repository to create on the HF Hub
    hub_private_repo: bool = False
    overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook
    seed: int = 0
ddpm-paintings-128-finetuned-cifar10/logs/ddpm-paintings-128-finetuned-cifar10/events.out.tfevents.1701696166.coffee.14798.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2fbf486914eb9ed63fdbcf637c2874ca608a32f1ec948a4567e37a8e2e412f3
3
+ size 427942
ddpm-paintings-128-finetuned-cifar10/logs/ddpm-paintings-128-finetuned-cifar10/events.out.tfevents.1701704512.coffee.17529.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7c99374f9a97092f9da24ebc79289a21d0e48e40598c677c8206b1f453c2b050
3
+ size 88
ddpm-paintings-128-finetuned-cifar10/samples/0000.png ADDED
ddpm-paintings-128-finetuned-cifar10/samples/0001.png ADDED
ddpm-paintings-128-finetuned-cifar10/samples/0002.png ADDED
main.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import PIL
3
+ from tqdm import tqdm
4
+
5
+ from accelerate import Accelerator
6
+ from datasets import load_dataset
7
+ from diffusers import DDPMPipeline, UNet2DModel, DDPMScheduler
8
+ from diffusers.optimization import get_cosine_schedule_with_warmup
9
+ from diffusers.utils import make_image_grid
10
+ from huggingface_hub import create_repo, upload_folder
11
+ from peft import LoraConfig, get_peft_model
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torchvision import transforms
15
+
16
+ from config import TrainingConfig
17
+
18
+
19
+ """
20
+ Or diffusion for simple images (cifar10 or fashion-mnist or mnist) and explore subtly different
21
+ x_T's and what the output is.
22
+
23
+ Denoise each x_T multiple times to get a better picture of the distribution.
24
+ Maybe use a set sequence of seeds for every denoising run (torch.Generator(seed=__)).
25
+
26
+ Inter-concept space. Consciousness.
27
+ """
28
+
29
+
30
def evaluate(config, epoch, pipeline):
    """Sample images from ``pipeline`` and save them as a single grid PNG.

    Args:
        config: Object providing ``eval_batch_size``, ``seed`` and
            ``output_dir``.
        epoch: Integer used for the zero-padded file name (e.g. ``0003.png``).
        pipeline: Callable pipeline whose result exposes ``.images``
            (presumably a list of PIL images, the pipeline's default output
            type).

    The grid is written to ``<output_dir>/samples/<epoch>.png``.
    """
    import math

    # Use a dedicated generator: it keeps the samples reproducible without
    # re-seeding the global RNG that the training loop draws its noise from.
    generator = torch.Generator(device='cpu').manual_seed(config.seed)

    # Sample some images from random noise (this is the backward diffusion process).
    images = pipeline(
        batch_size=config.eval_batch_size,
        generator=generator,
        num_inference_steps=50,
    ).images

    # Derive the most square grid that holds exactly len(images) cells, so the
    # function works for any eval batch size (4 images still gives 2 x 2).
    count = len(images)
    rows = math.isqrt(count) or 1
    while count % rows:
        rows -= 1
    cols = count // rows
    image_grid = make_image_grid(images, rows=rows, cols=cols)

    # Save the grid; create missing parent directories on the first call.
    test_dir = Path(config.output_dir) / 'samples'
    test_dir.mkdir(parents=True, exist_ok=True)
    image_grid.save(test_dir / f'{epoch:04d}.png')
46
+
47
+
48
def print_trainable_parameters(model):
    """Print the trainable and total parameter counts of ``model``.

    Args:
        model: Object exposing ``named_parameters()``, which yields
            ``(name, param)`` pairs where each ``param`` has ``numel()`` and
            ``requires_grad``.

    Prints one summary line with both counts and the trainable percentage.
    A model without any parameters reports ``0.00`` percent instead of
    raising ``ZeroDivisionError``.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count

    # Guard the ratio so an empty model does not divide by zero.
    percent = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {percent:.2f}"
    )
58
+
59
+
60
if __name__ == '__main__':
    # Script entry point: fine-tune LoRA adapters on a pretrained DDPM UNet
    # using a painting dataset, log to TensorBoard, and periodically sample
    # images and save/push the model.
    import os

    config = TrainingConfig()
    config.dataset_name = 'keremberke/painting-style-classification'

    # Credentials must never be committed with the source. Read the Hub token
    # from the environment; when it is unset, ``None`` is passed on to the
    # huggingface_hub calls (presumably falling back to the locally cached
    # login -- confirm against the huggingface_hub docs).
    hub_token = os.environ.get('HF_TOKEN')

    ds_dict = load_dataset(config.dataset_name, name='full')

    preprocess = transforms.Compose([
        transforms.Resize((config.image_size, config.image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    def transform(examples):
        """Convert a batch of dataset images to normalized RGB tensors."""
        return {
            'images': [preprocess(img.convert('RGB')) for img in examples['image']]
        }

    ds_dict.set_transform(transform)  # automatically applies preprocessing to samples as we load them

    train_dataloader = torch.utils.data.DataLoader(ds_dict['train'], batch_size=config.train_batch_size, shuffle=True)
    valid_dataloader = torch.utils.data.DataLoader(ds_dict['validation'], batch_size=config.eval_batch_size, shuffle=False)
    test_dataloader = torch.utils.data.DataLoader(ds_dict['test'], batch_size=config.eval_batch_size, shuffle=False)

    # NOTE(review): the documented keyword appears to be `use_safetensors`;
    # confirm that `safetensors=True` has any effect before relying on it.
    # NOTE(review): the device is hard-coded to 'mps' here and in the loop.
    unet = UNet2DModel.from_pretrained(
        'google/ddpm-celebahq-256',
        safetensors=True
    ).to('mps')

    scheduler = DDPMScheduler.from_pretrained(
        'google/ddpm-celebahq-256'
    )

    # Attach LoRA adapters to the attention key/value projections only.
    lora_config = LoraConfig(r=8, lora_alpha=8, target_modules=['to_k', 'to_v'], lora_dropout=0.1, bias='none')
    lora_unet = get_peft_model(unet, lora_config)

    print_trainable_parameters(lora_unet)

    optimizer = torch.optim.AdamW(lora_unet.parameters(), lr=config.learning_rate)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps,
        num_training_steps=(len(train_dataloader) * config.num_epochs)
    )

    # NOTE(review): model, optimizer, dataloader and lr scheduler are never
    # passed through `accelerator.prepare(...)`; confirm whether mixed
    # precision and gradient accumulation take effect without it.
    accelerator = Accelerator(
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        mixed_precision=config.mixed_precision,
        log_with='tensorboard',
        project_dir=Path(config.output_dir) / 'logs'
    )

    def push_checkpoint(pipeline, epoch):
        """Upload the project folder and the pipeline for ``epoch`` to the Hub."""
        upload_folder(
            repo_id=repo_id,
            folder_path=Path(config.output_dir).parent,
            commit_message=f'Files: Epoch {epoch}',
            # Also skip OS/bytecode artefacts that otherwise end up in the repo.
            ignore_patterns=['step_*', 'epoch_*', '*.DS_Store', '*.pyc'],
            token=hub_token
        )

        pipeline.push_to_hub(
            repo_id=config.hub_model_id,
            commit_message=f'Model: Epoch {epoch}',
            token=hub_token
        )

    if accelerator.is_main_process:
        if config.push_to_hub:
            repo_id = create_repo(repo_id=config.hub_model_id, exist_ok=True, token=hub_token).repo_id

        accelerator.init_trackers('ddpm-paintings-128-finetuned-cifar10')

        # Push the untrained baseline as "Epoch -1". This only happens when
        # pushing is enabled, because `repo_id` exists only in that case.
        if config.push_to_hub:
            pipeline = DDPMPipeline(unet=accelerator.unwrap_model(lora_unet), scheduler=scheduler)
            push_checkpoint(pipeline, -1)

    # The original script called exit() here, which left the training loop
    # below unreachable.
    global_step = 0

    for epoch in range(config.num_epochs):
        pbar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        pbar.set_description(f'Epoch {epoch}')

        for idx, batch in enumerate(train_dataloader):
            clean_images = batch['images'].to('mps')

            # Draw the noise and one random timestep per image, then build the
            # noisy inputs with the scheduler.
            noise = torch.randn(clean_images.shape, device=clean_images.device)
            bs = clean_images.shape[0]

            ts = torch.randint(0, scheduler.config.num_train_timesteps, (bs,), device=clean_images.device, dtype=torch.int64)

            noisy_images = scheduler.add_noise(clean_images, noise, ts)

            # Accumulate on the model that is actually trained (the LoRA
            # wrapper), not on the bare base UNet.
            with accelerator.accumulate(lora_unet):
                noise_pred = lora_unet(noisy_images, ts, return_dict=False)[0]
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                accelerator.clip_grad_norm_(lora_unet.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            logs = {'loss': loss.detach().item(), 'lr': lr_scheduler.get_last_lr()[0], 'step': global_step}
            pbar.update(1)
            pbar.set_postfix(loss=logs['loss'], step=idx + 1)
            accelerator.log(logs, step=global_step)
            global_step += 1

        pbar.close()

        if accelerator.is_main_process:
            pipeline = DDPMPipeline(unet=accelerator.unwrap_model(lora_unet), scheduler=scheduler)
            is_last_epoch = epoch == config.num_epochs - 1

            if (epoch + 1) % config.save_image_epochs == 0 or is_last_epoch:
                # Save some images for model trained at end of epoch
                evaluate(config, epoch, pipeline)

            if (epoch + 1) % config.save_model_epochs == 0 or is_last_epoch:
                if config.push_to_hub:
                    push_checkpoint(pipeline, epoch)
                else:
                    pipeline.save_pretrained(config.output_dir)

    # Flush and close the experiment trackers.
    accelerator.end_training()
unet.txt ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ UNet2DModel(
2
+ (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
3
+ (time_proj): Timesteps()
4
+ (time_embedding): TimestepEmbedding(
5
+ (linear_1): LoRACompatibleLinear(in_features=128, out_features=512, bias=True)
6
+ (act): SiLU()
7
+ (linear_2): LoRACompatibleLinear(in_features=512, out_features=512, bias=True)
8
+ )
9
+ (down_blocks): ModuleList(
10
+ (0-1): 2 x DownBlock2D(
11
+ (resnets): ModuleList(
12
+ (0-1): 2 x ResnetBlock2D(
13
+ (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
14
+ (conv1): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
15
+ (time_emb_proj): LoRACompatibleLinear(in_features=512, out_features=128, bias=True)
16
+ (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
17
+ (dropout): Dropout(p=0.0, inplace=False)
18
+ (conv2): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
19
+ (nonlinearity): SiLU()
20
+ )
21
+ )
22
+ (downsamplers): ModuleList(
23
+ (0): Downsample2D(
24
+ (conv): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(2, 2))
25
+ )
26
+ )
27
+ )
28
+ (2): DownBlock2D(
29
+ (resnets): ModuleList(
30
+ (0): ResnetBlock2D(
31
+ (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
32
+ (conv1): LoRACompatibleConv(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
33
+ (time_emb_proj): LoRACompatibleLinear(in_features=512, out_features=256, bias=True)
34
+ (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
35
+ (dropout): Dropout(p=0.0, inplace=False)
36
+ (conv2): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
37
+ (nonlinearity): SiLU()
38
+ (conv_shortcut): LoRACompatibleConv(128, 256, kernel_size=(1, 1), stride=(1, 1))
39
+ )
40
+ (1): ResnetBlock2D(
41
+ (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
42
+ (conv1): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
43
+ (time_emb_proj): LoRACompatibleLinear(in_features=512, out_features=256, bias=True)
44
+ (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
45
+ (dropout): Dropout(p=0.0, inplace=False)
46
+ (conv2): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
47
+ (nonlinearity): SiLU()
48
+ )
49
+ )
50
+ (downsamplers): ModuleList(
51
+ (0): Downsample2D(
52
+ (conv): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(2, 2))
53
+ )
54
+ )
55
+ )
56
+ (3): DownBlock2D(
57
+ (resnets): ModuleList(
58
+ (0-1): 2 x ResnetBlock2D(
59
+ (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
60
+ (conv1): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
61
+ (time_emb_proj): LoRACompatibleLinear(in_features=512, out_features=256, bias=True)
62
+ (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
63
+ (dropout): Dropout(p=0.0, inplace=False)
64
+ (conv2): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
65
+ (nonlinearity): SiLU()
66
+ )
67
+ )
68
+ (downsamplers): ModuleList(
69
+ (0): Downsample2D(
70
+ (conv): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(2, 2))
71
+ )
72
+ )
73
+ )
74
+ (4): AttnDownBlock2D(
75
+ (attentions): ModuleList(
76
+ (0-1): 2 x Attention(
77
+ (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)
78
+ (to_q): LoRACompatibleLinear(in_features=512, out_features=512, bias=True)
79
+ (to_k): LoRACompatibleLinear(in_features=512, out_features=512, bias=True)
80
+ (to_v): LoRACompatibleLinear(in_features=512, out_features=512, bias=True)
81
+ (to_out): ModuleList(
82
+ (0): LoRACompatibleLinear(in_features=512, out_features=512, bias=True)
83
+ (1): Dropout(p=0.0, inplace=False)
84
+ )
85
+ )
86
+ )
87
+ (resnets): ModuleList(
88
+ (0): ResnetBlock2D(
89
+ (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
90
+ (conv1): LoRACompatibleConv(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
91
+ (time_emb_proj): LoRACompatibleLinear(in_features=512, out_features=512, bias=True)
92
+ (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
93
+ (dropout): Dropout(p=0.0, inplace=False)
94
+ (conv2): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
95
+ (nonlinearity): SiLU()
96
+ (conv_shortcut): LoRACompatibleConv(256, 512, kernel_size=(1, 1), stride=(1, 1))
97
+ )
98
+ (1): ResnetBlock2D(
99
+ (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
100
+ (conv1): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
101
+ (time_emb_proj): LoRACompatibleLinear(in_features=512, out_features=512, bias=True)
102
+ (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
103
+ (dropout): Dropout(p=0.0, inplace=False)
104
+ (conv2): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
105
+ (nonlinearity): SiLU()
106
+ )
107
+ )
108
+ (downsamplers): ModuleList(
109
+ (0): Downsample2D(
110
+ (conv): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(2, 2))
111
+ )
112
+ )
113
+ )
114
+ (5): DownBlock2D(
115
+ (resnets): ModuleList(
116
+ (0-1): 2 x ResnetBlock2D(
117
+ (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
118
+ (conv1): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
119
+ (time_emb_proj): LoRACompatibleLinear(in_features=512, out_features=512, bias=True)
120
+ (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
121
+ (dropout): Dropout(p=0.0, inplace=False)
122
+ (conv2): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
123
+ (nonlinearity): SiLU()
124
+ )
125
+ )
126
+ )
127
+ )
128
+ (up_blocks): ModuleList(
129
+ (0): UpBlock2D(
130
+ (resnets): ModuleList(
131
+ (0-2): 3 x ResnetBlock2D(
132
+ (norm1): GroupNorm(32, 1024, eps=1e-06, affine=True)
133
+ (conv1): LoRACompatibleConv(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
134
+ (time_emb_proj): LoRACompatibleLinear(in_features=512, out_features=512, bias=True)
135
+ (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
136
+ (dropout): Dropout(p=0.0, inplace=False)
137
+ (conv2): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
138
+ (nonlinearity): SiLU()
139
+ (conv_shortcut): LoRACompatibleConv(1024, 512, kernel_size=(1, 1), stride=(1, 1))
140
+ )
141
+ )
142
+ (upsamplers): ModuleList(
143
+ (0): Upsample2D(
144
+ (conv): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
145
+ )
146
+ )
147
+ )
148
+ (1): AttnUpBlock2D(
149
+ (attentions): ModuleList(
150
+ (0-2): 3 x Attention(
151
+ (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)
152
+ (to_q): LoRACompatibleLinear(in_features=512, out_features=512, bias=True)
153
+ (to_k): LoRACompatibleLinear(in_features=512, out_features=512, bias=True)
154
+ (to_v): LoRACompatibleLinear(in_features=512, out_features=512, bias=True)
155
+ (to_out): ModuleList(
156
+ (0): LoRACompatibleLinear(in_features=512, out_features=512, bias=True)
157
+ (1): Dropout(p=0.0, inplace=False)
158
+ )
159
+ )
160
+ )
161
+ (resnets): ModuleList(
162
+ (0-1): 2 x ResnetBlock2D(
163
+ (norm1): GroupNorm(32, 1024, eps=1e-06, affine=True)
164
+ (conv1): LoRACompatibleConv(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
165
+ (time_emb_proj): LoRACompatibleLinear(in_features=512, out_features=512, bias=True)
166
+ (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
167
+ (dropout): Dropout(p=0.0, inplace=False)
168
+ (conv2): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
169
+ (nonlinearity): SiLU()
170
+ (conv_shortcut): LoRACompatibleConv(1024, 512, kernel_size=(1, 1), stride=(1, 1))
171
+ )
172
+ (2): ResnetBlock2D(
173
+ (norm1): GroupNorm(32, 768, eps=1e-06, affine=True)
174
+ (conv1): LoRACompatibleConv(768, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
175
+ (time_emb_proj): LoRACompatibleLinear(in_features=512, out_features=512, bias=True)
176
+ (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
177
+ (dropout): Dropout(p=0.0, inplace=False)
178
+ (conv2): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
179
+ (nonlinearity): SiLU()
180
+ (conv_shortcut): LoRACompatibleConv(768, 512, kernel_size=(1, 1), stride=(1, 1))
181
+ )
182
+ )
183
+ (upsamplers): ModuleList(
184
+ (0): Upsample2D(
185
+ (conv): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
186
+ )
187
+ )
188
+ )
189
+ (2): UpBlock2D(
190
+ (resnets): ModuleList(
191
+ (0): ResnetBlock2D(
192
+ (norm1): GroupNorm(32, 768, eps=1e-06, affine=True)
193
+ (conv1): LoRACompatibleConv(768, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
194
+ (time_emb_proj): LoRACompatibleLinear(in_features=512, out_features=256, bias=True)
195
+ (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
196
+ (dropout): Dropout(p=0.0, inplace=False)
197
+ (conv2): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
198
+ (nonlinearity): SiLU()
199
+ (conv_shortcut): LoRACompatibleConv(768, 256, kernel_size=(1, 1), stride=(1, 1))
200
+ )
201
+ (1-2): 2 x ResnetBlock2D(
202
+ (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
203
+ (conv1): LoRACompatibleConv(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
204
+ (time_emb_proj): LoRACompatibleLinear(in_features=512, out_features=256, bias=True)
205
+ (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
206
+ (dropout): Dropout(p=0.0, inplace=False)
207
+ (conv2): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
208
+ (nonlinearity): SiLU()
209
+ (conv_shortcut): LoRACompatibleConv(512, 256, kernel_size=(1, 1), stride=(1, 1))
210
+ )
211
+ )
212
+ (upsamplers): ModuleList(
213
+ (0): Upsample2D(
214
+ (conv): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
215
+ )
216
+ )
217
+ )
218
+ (3): UpBlock2D(
219
+ (resnets): ModuleList(
220
+ (0-1): 2 x ResnetBlock2D(
221
+ (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
222
+ (conv1): LoRACompatibleConv(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
223
+ (time_emb_proj): LoRACompatibleLinear(in_features=512, out_features=256, bias=True)
224
+ (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
225
+ (dropout): Dropout(p=0.0, inplace=False)
226
+ (conv2): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
227
+ (nonlinearity): SiLU()
228
+ (conv_shortcut): LoRACompatibleConv(512, 256, kernel_size=(1, 1), stride=(1, 1))
229
+ )
230
+ (2): ResnetBlock2D(
231
+ (norm1): GroupNorm(32, 384, eps=1e-06, affine=True)
232
+ (conv1): LoRACompatibleConv(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
233
+ (time_emb_proj): LoRACompatibleLinear(in_features=512, out_features=256, bias=True)
234
+ (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
235
+ (dropout): Dropout(p=0.0, inplace=False)
236
+ (conv2): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
237
+ (nonlinearity): SiLU()
238
+ (conv_shortcut): LoRACompatibleConv(384, 256, kernel_size=(1, 1), stride=(1, 1))
239
+ )
240
+ )
241
+ (upsamplers): ModuleList(
242
+ (0): Upsample2D(
243
+ (conv): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
244
+ )
245
+ )
246
+ )
247
+ (4): UpBlock2D(
248
+ (resnets): ModuleList(
249
+ (0): ResnetBlock2D(
250
+ (norm1): GroupNorm(32, 384, eps=1e-06, affine=True)
251
+ (conv1): LoRACompatibleConv(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
252
+ (time_emb_proj): LoRACompatibleLinear(in_features=512, out_features=128, bias=True)
253
+ (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
254
+ (dropout): Dropout(p=0.0, inplace=False)
255
+ (conv2): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
256
+ (nonlinearity): SiLU()
257
+ (conv_shortcut): LoRACompatibleConv(384, 128, kernel_size=(1, 1), stride=(1, 1))
258
+ )
259
+ (1-2): 2 x ResnetBlock2D(
260
+ (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
261
+ (conv1): LoRACompatibleConv(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
262
+ (time_emb_proj): LoRACompatibleLinear(in_features=512, out_features=128, bias=True)
263
+ (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
264
+ (dropout): Dropout(p=0.0, inplace=False)
265
+ (conv2): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
266
+ (nonlinearity): SiLU()
267
+ (conv_shortcut): LoRACompatibleConv(256, 128, kernel_size=(1, 1), stride=(1, 1))
268
+ )
269
+ )
270
+ (upsamplers): ModuleList(
271
+ (0): Upsample2D(
272
+ (conv): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
273
+ )
274
+ )
275
+ )
276
+ (5): UpBlock2D(
277
+ (resnets): ModuleList(
278
+ (0-2): 3 x ResnetBlock2D(
279
+ (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
280
+ (conv1): LoRACompatibleConv(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
281
+ (time_emb_proj): LoRACompatibleLinear(in_features=512, out_features=128, bias=True)
282
+ (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
283
+ (dropout): Dropout(p=0.0, inplace=False)
284
+ (conv2): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
285
+ (nonlinearity): SiLU()
286
+ (conv_shortcut): LoRACompatibleConv(256, 128, kernel_size=(1, 1), stride=(1, 1))
287
+ )
288
+ )
289
+ )
290
+ )
291
+ (mid_block): UNetMidBlock2D(
292
+ (attentions): ModuleList(
293
+ (0): Attention(
294
+ (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)
295
+ (to_q): LoRACompatibleLinear(in_features=512, out_features=512, bias=True)
296
+ (to_k): LoRACompatibleLinear(in_features=512, out_features=512, bias=True)
297
+ (to_v): LoRACompatibleLinear(in_features=512, out_features=512, bias=True)
298
+ (to_out): ModuleList(
299
+ (0): LoRACompatibleLinear(in_features=512, out_features=512, bias=True)
300
+ (1): Dropout(p=0.0, inplace=False)
301
+ )
302
+ )
303
+ )
304
+ (resnets): ModuleList(
305
+ (0-1): 2 x ResnetBlock2D(
306
+ (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
307
+ (conv1): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
308
+ (time_emb_proj): LoRACompatibleLinear(in_features=512, out_features=512, bias=True)
309
+ (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
310
+ (dropout): Dropout(p=0.0, inplace=False)
311
+ (conv2): LoRACompatibleConv(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
312
+ (nonlinearity): SiLU()
313
+ )
314
+ )
315
+ )
316
+ (conv_norm_out): GroupNorm(32, 128, eps=1e-06, affine=True)
317
+ (conv_act): SiLU()
318
+ (conv_out): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
319
+ )