File size: 1,639 Bytes
8eb7fcd
 
 
 
 
 
 
 
 
098fc8a
 
 
8eb7fcd
94539f6
06c31c1
cdd9a51
8eb7fcd
 
 
 
 
585cc65
36338f6
8eb7fcd
475978d
8eb7fcd
36338f6
 
098fc8a
36338f6
60e1bdf
8eb7fcd
 
 
 
 
 
475978d
8eb7fcd
 
cdd9a51
 
 
 
 
06c31c1
 
 
8eb7fcd
 
585cc65
cdd9a51
 
 
 
 
 
 
 
 
 
 
 
 
 
8eb7fcd
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
import numpy as np
from PIL import Image
import requests
import io

from unet import Unet, ConditionalUnet

from diffusion import GaussianDiffusion, DiffusionImageAPI

# Run on CUDA when PyTorch reports it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def inference1(url="https://picsum.photos/120/80", timeout=10.0):
  """Download an image over HTTP and return it as a PIL image.

  Despite its name this runs no model; by default it fetches
  https://picsum.photos/120/80, which is handy as a throwaway input image.

  Args:
    url: Address of the image to download.
    timeout: Seconds to wait for the server. Without a timeout ``requests``
      can block forever on an unresponsive host.

  Returns:
    A ``PIL.Image.Image`` opened from the downloaded bytes.

  Raises:
    requests.HTTPError: If the server answers with a 4xx/5xx status. This
      happens before an HTML error page can reach Pillow.
    requests.Timeout: If the server does not respond within ``timeout``.
  """
  response = requests.get(url, timeout=timeout)
  response.raise_for_status()
  return Image.open(io.BytesIO(response.content))

def _build_diffusion(model_path):
  """Build the conditional U-Net, load its weights and wrap it in a diffusion sampler.

  Args:
    model_path: Path to a ``state_dict`` checkpoint for the ``ConditionalUnet``.

  Returns:
    A ``(model, diffusion)`` tuple. The caller is responsible for moving
    both to ``device``.
  """
  model = ConditionalUnet(
    unet=Unet(
      image_channels=3,
      dropout=0.1,
    ),
    num_classes=13,
  )
  model.load_state_dict(torch.load(model_path, map_location=device))

  diffusion = GaussianDiffusion(
    model=model,
    noise_steps=1000,
    beta_0=1e-4,
    beta_T=0.02,
    image_size=(192, 128),
  )
  return model, diffusion


def _save_gif(frames, gif_path, duration=100):
  """Write the PIL images in ``frames`` to ``gif_path`` as a looping animated GIF."""
  from pathlib import Path

  path = Path(gif_path)
  # Pillow does not create missing directories; make sure the target exists.
  path.parent.mkdir(parents=True, exist_ok=True)
  frames[0].save(
    path,
    save_all=True,
    append_images=frames[1:],
    duration=duration,
    loop=0,
  )


def inference(cond, x0=None, gif=False, callback=None, *,
              model_path="./model_final.pt",
              gif_path="./gif_output/versions.gif"):
  """Sample one image from the conditional diffusion model.

  Args:
    cond: Conditioning forwarded unchanged to ``GaussianDiffusion.sample``.
      Presumably class labels for the 13-class ``ConditionalUnet``; confirm
      against the caller.
    x0: Optional starting image. When given, it is normalized with
      ``diffusion.normalize_image`` and reordered with ``permute(2, 0, 1)``
      (presumably HWC -> CHW). It is then given a batch dimension and moved
      to ``device``.
    gif: When True, also save every intermediate sample returned by
      ``sample`` as an animated GIF at ``gif_path``.
    callback: Forwarded to ``sample`` as ``cb``.
    model_path: Checkpoint loaded into the model.
    gif_path: Destination of the GIF when ``gif`` is True.

  Returns:
    The final sample converted with ``DiffusionImageAPI.tensor_to_image``.
  """
  model, diffusion = _build_diffusion(model_path)

  if x0 is not None:
    x0 = diffusion.normalize_image(x0)
    x0 = x0.permute(2, 0, 1).unsqueeze(0)
    # The model is moved to `device` below; keep the starting image alongside it.
    x0 = x0.to(device)

  model.to(device)
  diffusion.to(device)
  # Unet is built with dropout=0.1. Switch to eval mode so dropout is disabled
  # while sampling (a no-op if sample() already does this).
  model.eval()

  image_api = DiffusionImageAPI(diffusion)

  new_images, versions = diffusion.sample(1, cond=cond, x0=x0, cb=callback)
  if gif:
    frames = [image_api.tensor_to_image(step.squeeze(0)) for step in versions]
    # Nothing to animate when sample() returned no intermediate steps.
    if frames:
      _save_gif(frames, gif_path)
  return image_api.tensor_to_image(new_images.squeeze(0))

if __name__ == "__main__":
  import sys

  # `inference` requires a `cond` argument; calling it with none raises TypeError.
  # NOTE(review): assumes `cond` is a LongTensor of class indices in [0, 13)
  # (ConditionalUnet is built with num_classes=13). Confirm against ConditionalUnet.
  class_index = int(sys.argv[1]) if len(sys.argv) > 1 else 0
  cond = torch.tensor([class_index], dtype=torch.long, device=device)
  inference(cond).show()