import argparse
import os
from typing import Optional

import torch
import torch.nn as nn
import torchvision.utils as vutils
from safetensors.torch import load_file

from models import AsymmetricResidualUDiT

# torch.compile wraps the model in a module that stores the original network
# under the attribute `_orig_mod`, so a state dict saved from the compiled
# model has every key prefixed with this string.
_COMPILED_PREFIX = "_orig_mod."


def load_checkpoint(model: nn.Module, checkpoint_path: str) -> None:
    """Load a ``.safetensors`` checkpoint into ``model`` in place.

    If *every* key carries the ``torch.compile`` prefix (``_orig_mod.``),
    the prefix is stripped so the weights load into the plain, uncompiled
    module.

    Args:
        model: Module that receives the weights.
        checkpoint_path: Path to the ``.safetensors`` file.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` (strict loading)
            when keys or tensor shapes do not match the model.
    """
    state_dict = load_file(checkpoint_path)
    if all(k.startswith(_COMPILED_PREFIX) for k in state_dict):
        state_dict = {k[len(_COMPILED_PREFIX):]: v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)


def sample(model, n_samples=16, n_steps=50, image_size=256, device="cuda",
           sigma_min=0.001, dtype=torch.float32):
    """Generate images by Euler-integrating the model's velocity from t=0 to t=1.

    Integration starts from standard Gaussian noise at t=0. It then applies
    ``n_steps`` forward-Euler updates ``x <- x + model(x, t) * dt`` with
    ``dt = 1 / n_steps``. The model is evaluated at the left edge of each
    step: t = 0, 1/n_steps, ..., (n_steps - 1)/n_steps.

    Args:
        model: Callable ``model(x, t)`` returning a tensor that can be scaled
            and added to ``x``. ``t`` is passed with shape
            ``(n_samples, 1, 1, 1)``.
        n_samples: Number of images to generate.
        n_steps: Number of Euler steps (must be >= 1).
        image_size: Height and width of the generated images.
        device: Device (string or ``torch.device``) on which the noise and
            time tensors are created; also selects the autocast device type.
        sigma_min: Currently unused; kept for backward compatibility.
        dtype: Autocast dtype. With ``torch.float32`` (the default),
            autocast is disabled and the model runs in full precision.

    Returns:
        A float32 tensor of shape ``(n_samples, 3, image_size, image_size)``.

    Raises:
        ValueError: If ``n_steps < 1``.
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")

    device_type = torch.device(device).type
    dt = 1.0 / n_steps

    with torch.no_grad(), torch.autocast(device_type=device_type, dtype=dtype,
                                         enabled=dtype != torch.float32):
        x = torch.randn(n_samples, 3, image_size, image_size, device=device)
        # Left endpoints of the n_steps intervals of width dt that tile [0, 1].
        # (torch.linspace(0, 1, n_steps) would space points 1/(n_steps-1)
        # apart, inconsistent with dt, and would query the model at t=1.)
        ts = torch.arange(n_steps, device=device, dtype=torch.float32) * dt
        for t in ts:
            t_input = t.repeat(n_samples, 1, 1, 1)
            v_t = model(x, t_input)
            x = x + v_t * dt
    return x.float()


def main():
    """Command-line entry point: load a checkpoint, sample, and save a grid."""
    parser = argparse.ArgumentParser(description="Generate samples from a trained UDiT model")
    parser.add_argument("checkpoint", type=str, help="Path to the model checkpoint (.safetensors)")
    parser.add_argument("--samples", type=int, default=16, help="Number of samples to generate")
    parser.add_argument("--steps", type=int, default=50, help="Number of sampling steps")
    parser.add_argument("--output", type=str, default="output.png", help="Output filename")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to run inference on (cuda/cpu)")
    parser.add_argument("--image-size", type=int, default=256,
                        help="Height/width of the generated images")
    parser.add_argument("--nrow", type=int, default=4,
                        help="Number of images per row in the saved grid")
    args = parser.parse_args()

    if args.samples < 1:
        parser.error("--samples must be >= 1")
    if args.steps < 1:
        parser.error("--steps must be >= 1")

    device = args.device

    # Architecture hyperparameters must match those used to train the checkpoint.
    model = AsymmetricResidualUDiT(
        in_channels=3,
        base_channels=128,
        num_levels=3,
        patch_size=4,
        encoder_blocks=3,
        decoder_blocks=7,
        encoder_transformer_thresh=2,
        decoder_transformer_thresh=4,
        mid_blocks=8,
    ).to(device)

    # Load state dict into model
    load_checkpoint(model, args.checkpoint)
    model.eval()

    # Generate samples (sample() already disables gradient tracking).
    print(f"Generating {args.samples} samples with {args.steps} steps...")
    samples = sample(
        model,
        n_samples=args.samples,
        n_steps=args.steps,
        image_size=args.image_size,
        device=device,
        dtype=torch.float32,
    )

    # Save samples.
    # NOTE(review): save_image clamps pixel values to [0, 1]. If the model was
    # trained on data scaled to [-1, 1], pass normalize/value_range -- confirm
    # against the training pipeline.
    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    vutils.save_image(samples, args.output, nrow=args.nrow, padding=2)
    print(f"Samples saved to {args.output}")


if __name__ == "__main__":
    main()