Blackroot committed on
Commit 1717ad7 · verified · 1 Parent(s): b7a0936

Upload 8 files

.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ epoch_159.png filter=lfs diff=lfs merge=lfs -text
+ epoch_39.png filter=lfs diff=lfs merge=lfs -text
+ epoch_459.png filter=lfs diff=lfs merge=lfs -text
+ epoch_799.png filter=lfs diff=lfs merge=lfs -text
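These four patterns are the entries Git LFS writes to .gitattributes when files are tracked (e.g. via git lfs track), so the sample PNGs below are stored as LFS pointer files rather than inline blobs.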
epoch_159.png ADDED

Git LFS Details

  • SHA256: a9dfd4c8bc0712fa7920c21ffacdc8bbc88fd8ea952f7f934d5e5c60d1f662f8
  • Pointer size: 132 Bytes
  • Size of remote file: 1.79 MB
epoch_39.png ADDED

Git LFS Details

  • SHA256: b2d498c4ab4bb9407244dd4cb39c45e5347752a545249127e0452e7dc9615357
  • Pointer size: 132 Bytes
  • Size of remote file: 2.31 MB
epoch_459.png ADDED

Git LFS Details

  • SHA256: f7dfb4c4a845199bd4694b290dc6d8972280da5c33e4a5fc7e83aa46c6f3bf85
  • Pointer size: 132 Bytes
  • Size of remote file: 1.57 MB
epoch_799.png ADDED

Git LFS Details

  • SHA256: ec639eeb8d56909e1b62df0f857bb8e160b0bd325d10b9e00ed1bc392f95c3a9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.49 MB
models/__init__.py ADDED
@@ -0,0 +1,3 @@
from .uvit import AsymmetricResidualUDiT, xATGLU

__all__ = ['AsymmetricResidualUDiT', 'xATGLU']
models/uvit.py ADDED
@@ -0,0 +1,231 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# Changelog since the original version:
# xATGLU instead of the top linear in the transformer block.
# Added a learned residual scale to all blocks and all residuals. This allowed bfloat16 training to stabilize; before that it was simply exploding.

# This architecture is my attempt at the following Simple Diffusion paper, with some modifications:
# https://arxiv.org/pdf/2410.19324v1

# Very similar to GeGLU or SwiGLU: there's a learned gate FN, using arctan as the activation fn.
class xATGLU(nn.Module):
    def __init__(self, input_dim, output_dim, bias=True):
        super().__init__()
        # GATE path | VALUE path
        self.proj = nn.Linear(input_dim, output_dim * 2, bias=bias)
        nn.init.kaiming_normal_(self.proj.weight, nonlinearity='linear')

        self.alpha = nn.Parameter(torch.zeros(1))
        self.half_pi = torch.pi / 2
        self.inv_pi = 1 / torch.pi

    def forward(self, x):
        projected = self.proj(x)
        gate_path, value_path = projected.chunk(2, dim=-1)

        # Apply arctan gating with expanded range via learned alpha -- https://arxiv.org/pdf/2405.20768
        gate = (torch.arctan(gate_path) + self.half_pi) * self.inv_pi
        expanded_gate = gate * (1 + 2 * self.alpha) - self.alpha

        return expanded_gate * value_path  # g(x) × y

class ResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(32, channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(32, channels)

        self.learned_residual_scale = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x):
        h = self.conv1(F.silu(self.norm1(x)))
        h = self.conv2(F.silu(self.norm2(h)))
        return x + h * self.learned_residual_scale

class TransformerBlock(nn.Module):
    def __init__(self, channels, num_heads=8):
        super().__init__()
        self.norm1 = nn.LayerNorm(channels)
        self.norm2 = nn.LayerNorm(channels)

        # Params recommended by the TPA paper; they seem to work fine.
        self.attn = nn.MultiheadAttention(channels, num_heads)

        self.mlp = nn.Sequential(
            xATGLU(channels, 2 * channels, bias=False),
            nn.Linear(2 * channels, channels, bias=False)  # Candidate for a bias
        )

        self.learned_residual_scale_attn = nn.Parameter(torch.ones(1) * 0.1)
        self.learned_residual_scale_mlp = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x):
        # Input shape B C H W
        b, c, h, w = x.shape

        x = x.reshape(b, h * w, c)  # [B, H*W, C]

        # Pre-norm architecture; this was really helpful for network stability when using bf16
        identity = x
        x = self.norm1(x)
        h_attn, _ = self.attn(x, x, x)
        x = identity + h_attn * self.learned_residual_scale_attn

        identity = x
        x = self.norm2(x)
        h_mlp = self.mlp(x)
        x = identity + h_mlp * self.learned_residual_scale_mlp

        # Reshape back to B C H W
        x = x.permute(1, 2, 0).reshape(b, c, h, w)
        return x

class LevelBlock(nn.Module):
    def __init__(self, channels, num_blocks, block_type='res'):
        super().__init__()
        self.blocks = nn.ModuleList()
        for _ in range(num_blocks):
            if block_type == 'transformer':
                self.blocks.append(TransformerBlock(channels))
            else:
                self.blocks.append(ResBlock(channels))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

class AsymmetricResidualUDiT(nn.Module):
    def __init__(self,
                 in_channels=3,        # Input color channels
                 base_channels=128,    # Initial feature size; dramatically increases the parameter count of the network.
                 patch_size=2,         # Smaller patches dramatically increase FLOPs and compute cost. Recommend >=4 unless you have real compute.
                 num_levels=3,         # Feature downsampling depth, essentially the U-Net depth -- so we down/upsample three times. Dramatically increases parameters as you increase it.
                 encoder_blocks=3,     # Can be a different number of blocks than decoder_blocks
                 decoder_blocks=7,     # Can be a different number of blocks than encoder_blocks
                 encoder_transformer_thresh=2,  # When to start using transformer blocks instead of res blocks in the encoder. (>=)
                 decoder_transformer_thresh=4,  # When to stop using transformer blocks instead of res blocks in the decoder. (<=)
                 mid_blocks=16,        # Number of middle transformer blocks. Relatively cheap, as this sits at the bottom of the U-Net feature bottleneck.
                 ):
        super().__init__()
        self.learned_middle_residual_scale = nn.Parameter(torch.ones(1) * 0.1)
        # Initial projection from image space
        self.patch_embed = nn.Conv2d(in_channels, base_channels,
                                     kernel_size=patch_size, stride=patch_size)

        self.encoders = nn.ModuleList()
        curr_channels = base_channels

        for level in range(num_levels):
            use_transformer = level >= encoder_transformer_thresh  # Use transformers for the later levels

            # Encoder blocks -- N = encoder_blocks
            self.encoders.append(
                LevelBlock(curr_channels, encoder_blocks, use_transformer)
            )

            # Each successive encoder level doubles the feature channels, except at the last level.
            if level < num_levels - 1:
                self.encoders.append(
                    nn.Conv2d(curr_channels, curr_channels * 2, 1)
                )
                curr_channels *= 2

        # Middle transformer blocks -- N = mid_blocks
        self.middle = nn.ModuleList([
            TransformerBlock(curr_channels) for _ in range(mid_blocks)
        ])

        # Create decoder levels
        self.decoders = nn.ModuleList()

        for level in range(num_levels):
            use_transformer = level <= decoder_transformer_thresh  # Use transformers for the early levels (inverse of the encoder)

            # Decoder blocks -- N = decoder_blocks
            self.decoders.append(
                LevelBlock(curr_channels, decoder_blocks, use_transformer)
            )

            # Each successive decoder level halves the feature channels, except at the last level.
            if level < num_levels - 1:
                self.decoders.append(
                    nn.Conv2d(curr_channels, curr_channels // 2, 1)
                )
                curr_channels //= 2

        # Final projection back to image space
        self.final_proj = nn.ConvTranspose2d(base_channels, in_channels,
                                             kernel_size=patch_size, stride=patch_size)

    def downsample(self, x):
        return F.avg_pool2d(x, kernel_size=2)

    def upsample(self, x):
        return F.interpolate(x, scale_factor=2, mode='nearest')

    def forward(self, x, t=None):
        # x shape B C H W
        # This patchifies our input; for example, given an input shape like:
        # From 2, 3, 256, 256
        x = self.patch_embed(x)
        # Our shape now has more channels and smaller W and H
        # To 2, 128, 64, 64

        # *Per resolution, i.e. per num_levels resolution block, more or less:
        # f(x) = fu( U(fm(D(h)) - D(h)) + h ) where h = fd(x)
        #
        # Where
        # 1. h = fd(x)     : Encoder path processes input
        # 2. D(h)          : Downsample the encoded features
        # 3. fm(D(h))      : Middle transformer blocks process downsampled features
        # 4. fm(D(h))-D(h) : Subtract original downsampled features (residual connection)
        # 5. U(...)        : Upsample the processed features
        # 6. ... + h       : Add back original encoder features (skip connection)
        # 7. fu(...)       : Decoder path processes the combined features

        residuals = []
        curr_res = x

        # Encoder path (computing h = fd(x))
        h = x
        for i, blocks in enumerate(self.encoders):
            if isinstance(blocks, LevelBlock):
                h = blocks(h)
            else:
                # Save residual before downsampling
                residuals.append(curr_res)
                # Downsample and update current residual
                h = self.downsample(blocks(h))
                curr_res = h

        # Middle blocks (fm)
        x = h
        for block in self.middle:
            x = block(x)

        # Subtract the residual at this level (D(h))
        x = x - curr_res * self.learned_middle_residual_scale

        # Decoder path (fu)
        for i, blocks in enumerate(self.decoders):
            if isinstance(blocks, LevelBlock):
                x = blocks(x)
            else:
                # Channel reduction
                x = blocks(x)
                # Upsample
                x = self.upsample(x)
                # Add the residual from the encoder at this level. LIFO: the last residual saved is the first we want, since this is a U-shape.
                curr_res = residuals.pop()
                x = x + curr_res * self.learned_middle_residual_scale

        # Final projection
        x = self.final_proj(x)

        return x
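For reference, a minimal sketch (not part of the uploaded files) of exercising the model above on dummy data, using the same hyperparameters train.py passes in; the expected shapes follow the comments in forward():

import torch
from models import AsymmetricResidualUDiT

model = AsymmetricResidualUDiT(
    in_channels=3, base_channels=128, patch_size=4, num_levels=3,
    encoder_blocks=3, decoder_blocks=7,
    encoder_transformer_thresh=2, decoder_transformer_thresh=4,
    mid_blocks=16,
)
x = torch.randn(2, 3, 256, 256)   # B C H W
t = torch.rand(2, 1, 1, 1)        # per-sample timestep, same broadcast shape train.py uses
with torch.no_grad():
    v = model(x, t)
print(v.shape)                    # expected: torch.Size([2, 3, 256, 256])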
step_799.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:142069210ea246387de5d7e00264185a9fba49e86984b6b008be3949098eb7ae
size 417394856
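step_799.safetensors stores only the model state_dict (see save_checkpoint in train.py below). A minimal loading sketch, assuming the architecture is constructed with the same configuration used for training; the _orig_mod key-prefix handling is an assumption, only needed if the weights were saved from a torch.compile-wrapped module:

import torch
from safetensors.torch import load_file
from models import AsymmetricResidualUDiT

model = AsymmetricResidualUDiT(
    in_channels=3, base_channels=128, patch_size=4, num_levels=3,
    encoder_blocks=3, decoder_blocks=7,
    encoder_transformer_thresh=2, decoder_transformer_thresh=4,
    mid_blocks=16,
)
state = load_file("step_799.safetensors")  # same call load_checkpoint() in train.py uses
# Assumption: keys may carry an "_orig_mod." prefix from torch.compile; strip it if load_state_dict complains.
state = {k.removeprefix("_orig_mod."): v for k, v in state.items()}
model.load_state_dict(state)
model.eval()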
train.py ADDED
@@ -0,0 +1,307 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
from datasets import load_dataset, load_from_disk
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
from safetensors.torch import save_file, load_file
import os, time
from models import AsymmetricResidualUDiT, xATGLU
from torch.cuda.amp import autocast

from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.distributions import Normal
from schedulefree import AdamWScheduleFree
from distributed_shampoo import AdamGraftingConfig, DistributedShampoo

# Changes
# MAE replaces MSE
# Larger Shampoo preconditioner step for stability
# Larger Shampoo preconditioner dim: 1024 -> 2048
# Commented out grad norm clipping.

def preload_dataset(image_size=256, device="cuda", max_images=50000):
    """Preload and cache the entire dataset in GPU memory"""
    print("Loading and preprocessing dataset...")
    dataset = load_dataset("jiovine/pixel-art-nouns-2k", split="train")
    #dataset = load_dataset("reach-vb/pokemon-blip-captions", split="train")
    #dataset = load_from_disk("./new_dataset")

    transform = transforms.Compose([
        transforms.ToTensor(),
        #transforms.Pad((35, 0), fill=0),  # Add 35 pixels on each side horizontally (70 total, to get from 186 to 256)
        transforms.Resize((256, 256), antialias=True),
        transforms.Lambda(lambda x: (x * 2) - 1)  # Scale to [-1, 1]
    ])

    all_images = []

    for i, example in enumerate(dataset):
        if max_images and i >= max_images:
            break

        img_tensor = transform(example['image'])

        all_images.extend([
            img_tensor,
        ])

    # Stack the entire dataset onto the GPU
    images_tensor = torch.stack(all_images).to(device)
    print(f"Dataset loaded: {images_tensor.shape} ({images_tensor.element_size() * images_tensor.nelement() / 1024/1024:.2f} MB)")

    return TensorDataset(images_tensor)

def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    print(f'Total parameters: {total_params:,} ({total_params/1e6:.2f}M)')

def save_checkpoint(model, optimizer, filename="checkpoint.safetensors"):
    model_state = model.state_dict()
    save_file(model_state, filename)

def load_checkpoint(model, optimizer, filename="checkpoint.safetensors"):
    model_state = load_file(filename)
    model.load_state_dict(model_state)

# https://arxiv.org/abs/2210.02747
class OptimalTransportLinearFlowGenerator():
    def __init__(self, sigma_min=0.001):
        self.sigma_min = sigma_min

    def loss(self, model, x1, device):
        batch_size = x1.shape[0]
        # Uniform dist 0..1 -- t ~ U[0, 1]
        t = torch.rand(batch_size, 1, 1, 1, device=device)

        # Sample noise -- x0 ~ N(0, I)
        x0 = torch.randn_like(x1)

        # Compute the OT conditional flow matching path interpolation

        # My understanding of this process -- We start at some random time t (per sample).
        # We have a pure noise value at x0, which is a totally destroyed signal.
        # We have the actual image as x1, which is a perfect signal.
        # We are going to destroy an amount of the image equal to t% of the signal. So if t is 0.3, we're destroying about 30% of the signal (image).
        # The final x_t represents our combined noisy signal; you can imagine 30% random noise overlaid onto the normal image.
        # We calculate the shortest path between x0 and x1, a straight line segment (call it a displacement vector) in their space, conditioned on the timestep.
        # We then try to predict that displacement vector, given our partially noisy signal and our conditioning timestep.
        # We check the prediction against the real displacement vector we calculated to see how good the prediction was. Then we backpropagate, baby.

        sigma_t = 1 - (1 - self.sigma_min) * t  # As t increases this value decreases. This is almost 1 - t.
        mu_t = t * x1                           # As t increases this increases.
        x_t = sigma_t * x0 + mu_t               # This is essentially a mixture of noise and signal: ((1-t) * x0) + (t * x1)

        # Compute target
        target = x1 - (1 - self.sigma_min) * x0  # This is the target displacement vector (direction and magnitude) we need to travel from x0 to x1.
        v_t = model(x_t, t)                      # v_t is our displacement vector prediction

        # Magnitude-corrected MSE
        # The factor of 69 helps with very small gradients; this loss tends to be in [0..1], and the factor rescales it to something more like [0..69].
        # Other values like 420 might lead to numerical instability if the loss gets too large.
        loss = F.mse_loss(v_t, target) * 69  # Compare the displacement vector the network predicted to the displacement we calculated.

        return loss

def write_logs(writer, model, loss, batch_idx, epoch, epoch_time, batch_size, lr, log_gradients=True):
    """
    TensorBoard logging

    Args:
        writer: torch.utils.tensorboard.SummaryWriter instance
        model: torch.nn.Module - the model being trained
        loss: float or torch.Tensor - the loss value to log
        batch_idx: int - current batch index
        epoch: int - current epoch
        epoch_time: float - time taken for the epoch
        batch_size: int - current batch size
        lr: float - current learning rate
        log_gradients: bool - whether to log gradient norms
    """
    total_steps = epoch * batch_idx

    writer.add_scalar('Loss/batch', loss, total_steps)
    writer.add_scalar('Time/epoch', epoch_time, epoch)
    writer.add_scalar('Training/batch_size', batch_size, epoch)
    writer.add_scalar('Training/learning_rate', lr, epoch)

    # Gradient logging
    if log_gradients:
        total_norm = 0.0
        for p in model.parameters():
            if p.grad is not None:
                param_norm = p.grad.detach().data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5
        writer.add_scalar('Gradients/total_norm', total_norm, total_steps)

def train_udit_flow(num_epochs=1000, initial_batch_sizes=[8, 16, 32, 64, 128], epoch_batch_drop_at=40, device="cuda", dtype=torch.float32):
    dataset = preload_dataset(device=device)
    temp_loader = DataLoader(dataset, batch_size=initial_batch_sizes[0], shuffle=True)
    first_batch = next(iter(temp_loader))
    image_shape = first_batch[0].shape[1:]

    writer = SummaryWriter('logs/current_run')

    model = AsymmetricResidualUDiT(
        in_channels=3,
        base_channels=128,
        num_levels=3,
        patch_size=4,
        encoder_blocks=3,
        decoder_blocks=7,
        encoder_transformer_thresh=2,
        decoder_transformer_thresh=4,
        mid_blocks=16
    ).to(device).to(torch.float32)
    model.train()
    count_parameters(model)

    # optimizer = AdamWScheduleFree(
    #     model.parameters(),
    #     lr=4e-5,
    #     warmup_steps=100
    # )
    # optimizer.train()

    optimizer = DistributedShampoo(
        model.parameters(),
        lr=0.001,
        betas=(0.9, 0.999),
        epsilon=1e-10,
        weight_decay=1e-05,
        max_preconditioner_dim=2048,
        precondition_frequency=100,
        start_preconditioning_step=250,
        use_decoupled_weight_decay=False,
        grafting_config=AdamGraftingConfig(
            beta2=0.999,
            epsilon=1e-10,
        ),
    )

    scaler = torch.amp.GradScaler("cuda")

    scheduler = CosineAnnealingLR(
        optimizer,
        T_max=num_epochs,
        eta_min=1e-5
    )

    current_batch_sizes = initial_batch_sizes.copy()
    next_drop_epoch = epoch_batch_drop_at
    interval_multiplier = 2

    torch.set_float32_matmul_precision('high')
    # torch.backends.cudnn.benchmark = True
    # torch.backends.cuda.matmul.allow_fp16_accumulation = True

    model = torch.compile(
        model,
        backend='inductor',
        dynamic=False,
        fullgraph=True,
        options={
            "epilogue_fusion": True,
            "max_autotune": True,
            "cuda.use_fast_math": True,
        }
    )

    flow_transport = OptimalTransportLinearFlowGenerator(sigma_min=0.001)

    current_batch_size = current_batch_sizes[-1]
    dataloader = DataLoader(dataset, batch_size=current_batch_size, shuffle=True)

    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        total_loss = 0

        # Batch size decay logic (currently disabled)
        # Geometric growth: every X*N + (X-1)*N + ... epochs, use the next batch size in the list.
        if False:
            if epoch > 0 and epoch == next_drop_epoch and len(current_batch_sizes) > 1:
                current_batch_sizes.pop()
                next_interval = epoch_batch_drop_at * interval_multiplier
                next_drop_epoch += next_interval
                interval_multiplier += 1
                print(f"\nEpoch {epoch}: Reducing batch size to {current_batch_sizes[-1]}")
                print(f"Next drop will occur at epoch {next_drop_epoch} (interval: {next_interval})")

        curr_lr = optimizer.param_groups[0]['lr']

        for batch_idx, batch in enumerate(dataloader):
            optimizer.zero_grad()
            with torch.autocast(device_type='cuda', dtype=dtype):
                x1 = batch[0]
                batch_size = x1.shape[0]

                # x1 shape: B, C, H, W
                loss = flow_transport.loss(model, x1, device)

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            #torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)

        epoch_time = time.time() - epoch_start_time
        print(f"Epoch {epoch}, Took: {epoch_time:.2f}s, Batch Size: {current_batch_size}, "
              f"Average Loss: {avg_loss:.4f}, Learning Rate: {curr_lr:.2e}")

        write_logs(writer, model, avg_loss, batch_idx, epoch, epoch_time, current_batch_size, curr_lr)

        if (epoch + 1) % 10 == 0:
            with torch.amp.autocast('cuda', dtype=dtype):
                sampling_start_time = time.time()
                samples = sample(model, device=device, dtype=dtype)
                os.makedirs("samples", exist_ok=True)
                vutils.save_image(samples, f"samples/epoch_{epoch}.png", nrow=4, padding=2)

                sample_time = time.time() - sampling_start_time
                print(f"Sampling took: {sample_time:.2f}s")

        if (epoch + 1) % 50 == 0:
            save_checkpoint(model, optimizer, f"step_{epoch}.safetensors")

        scheduler.step()

    return model

def sample(model, n_samples=16, n_steps=50, image_size=256, device="cuda", sigma_min=0.001, dtype=torch.float32):
    with torch.amp.autocast('cuda', dtype=dtype):

        x = torch.randn(n_samples, 3, image_size, image_size, device=device)
        ts = torch.linspace(0, 1, n_steps, device=device)
        dt = 1 / n_steps

        # Forward Euler integration over t in 0..1
        with torch.no_grad():
            for i in range(len(ts)):
                t = ts[i]
                t_input = t.repeat(n_samples, 1, 1, 1)

                v_t = model(x, t_input)

                x = x + v_t * dt

    return x.float()

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    model = train_udit_flow(
        device=device,
        initial_batch_sizes=[16, 32, 64],
        epoch_batch_drop_at=100,
        dtype=torch.bfloat16
    )

    print("Training complete! Samples saved in the 'samples' directory")
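As a quick sanity check on the flow-matching setup (a sketch, not part of the upload): if the network predicted the true displacement x1 - (1 - sigma_min) * x0 exactly, the forward Euler loop in sample() would carry pure noise x0 to x1 + sigma_min * x0, i.e. essentially recover the image.

import torch

sigma_min = 0.001
x0 = torch.randn(4, 3, 8, 8)          # noise sample
x1 = torch.rand(4, 3, 8, 8) * 2 - 1   # stand-in "image" in [-1, 1]

# Ideal velocity field, matching the regression target in OptimalTransportLinearFlowGenerator.loss
v_true = x1 - (1 - sigma_min) * x0

# Forward Euler from t=0 to t=1, mirroring sample() but with the ideal field
n_steps = 50
x = x0.clone()
for _ in range(n_steps):
    x = x + v_true * (1 / n_steps)

# Integrating a constant field over unit time gives x0 + v_true = x1 + sigma_min * x0
print(torch.allclose(x, x1 + sigma_min * x0, atol=1e-5))  # True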