File size: 2,989 Bytes
4a9ad28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import torch.nn as nn
import torchvision.utils as vutils
from models import AsymmetricResidualUDiT
from safetensors.torch import load_file
import os
import argparse
from typing import Optional

def load_checkpoint(model: nn.Module, checkpoint_path: str, strict: bool = True) -> None:
    """Load a safetensors state dict from *checkpoint_path* into *model* in place.

    Checkpoints saved from a ``torch.compile``-wrapped module carry an
    ``_orig_mod.`` prefix on every key; that prefix is stripped so the weights
    load into the plain (uncompiled) module.

    Args:
        model: Module that receives the weights.
        checkpoint_path: Path to a ``.safetensors`` file.
        strict: Forwarded to ``nn.Module.load_state_dict``. When True (the
            default, matching the previous behavior) missing or unexpected
            keys raise an error.
    """
    compile_prefix = "_orig_mod."
    state_dict = load_file(checkpoint_path)

    # Training ran under torch.compile, which stores the original module as
    # `_orig_mod`, so every key of a compiled model's state dict is prefixed.
    # Only strip when *all* keys carry the prefix so a normal checkpoint is
    # never rewritten. Slicing by len(prefix) keeps the offset tied to the
    # prefix instead of a hard-coded 10.
    if state_dict and all(k.startswith(compile_prefix) for k in state_dict):
        state_dict = {k[len(compile_prefix):]: v for k, v in state_dict.items()}

    model.load_state_dict(state_dict, strict=strict)
    
def sample(model, n_samples=16, n_steps=50, image_size=256, device="cuda", sigma_min=0.001, dtype=torch.float32):
    with torch.amp.autocast('cuda', dtype=dtype):
        x = torch.randn(n_samples, 3, image_size, image_size, device=device)
        ts = torch.linspace(0, 1, n_steps, device=device)
        dt = 1/n_steps
        
        # Forward Euler Integration step 0..1
        with torch.no_grad():
            for i in range(len(ts)):
                t = ts[i]
                t_input = t.repeat(n_samples, 1, 1, 1)
                
                v_t = model(x, t_input)
                x = x + v_t * dt
    
    return x.float()

def main():
    """Command-line entry point: load a UDiT checkpoint, sample images, save a grid.

    Parses CLI arguments, builds ``AsymmetricResidualUDiT`` with fixed
    hyperparameters, loads the checkpoint weights, draws samples with
    ``sample`` and writes them as a single image grid to ``--output``.
    """
    parser = argparse.ArgumentParser(description="Generate samples from a trained UDiT model")
    parser.add_argument("checkpoint", type=str, help="Path to the model checkpoint (.safetensors)")
    parser.add_argument("--samples", type=int, default=16, help="Number of samples to generate")
    parser.add_argument("--steps", type=int, default=50, help="Number of sampling steps")
    parser.add_argument("--output", type=str, default="output.png", help="Output filename")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                      help="Device to run inference on (cuda/cpu)")
    parser.add_argument("--nrow", type=int, default=4,
                        help="Number of images per row in the output grid")
    parser.add_argument("--seed", type=int, default=None,
                        help="Random seed for reproducible sampling (default: nondeterministic)")
    args = parser.parse_args()
    device = args.device

    # These hyperparameters must match the architecture the checkpoint was
    # trained with; otherwise load_state_dict (strict) fails on key/shape mismatches.
    model = AsymmetricResidualUDiT(
        in_channels=3,
        base_channels=128,
        num_levels=3,
        patch_size=4,
        encoder_blocks=3,
        decoder_blocks=7,
        encoder_transformer_thresh=2,
        decoder_transformer_thresh=4,
        mid_blocks=8
    ).to(device)

    # Load state dict into model
    load_checkpoint(model, args.checkpoint)
    model.eval()

    # Seed right before sampling so the initial noise is reproducible.
    if args.seed is not None:
        torch.manual_seed(args.seed)

    # Generate samples (sample() disables gradient tracking itself).
    print(f"Generating {args.samples} samples with {args.steps} steps...")
    samples = sample(
        model,
        n_samples=args.samples,
        n_steps=args.steps,
        device=device,
        dtype=torch.float32
    )

    # Save samples.
    # NOTE(review): save_image is called without normalize=True; if the model
    # was trained on data scaled to [-1, 1], samples may need rescaling to
    # [0, 1] first — confirm against the training pipeline.
    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    vutils.save_image(samples, args.output, nrow=args.nrow, padding=2)
    print(f"Samples saved to {args.output}")

if __name__ == "__main__":
    main()