import io

import requests
import torch
from PIL import Image

from unet import Unet, ConditionalUnet
from diffusion import GaussianDiffusion, DiffusionImageAPI

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def inference1():
    # Fetch a random placeholder image from the web and return it as a PIL image.
    image = requests.get("https://picsum.photos/120/80").content
    return Image.open(io.BytesIO(image))


def inference(cond, x0=None, gif=False, callback=None):
    # Build the class-conditional U-Net and load the trained weights.
    model = Unet(
        image_channels=3,
        dropout=0.1,
    )
    model = ConditionalUnet(
        unet=model,
        num_classes=13,
    )
    model.load_state_dict(torch.load("./model_final.pt", map_location=device))

    # Wrap the model in the Gaussian diffusion sampler (1000 noise steps, linear beta schedule).
    diffusion = GaussianDiffusion(
        model=model,
        noise_steps=1000,
        beta_0=1e-4,
        beta_T=0.02,
        image_size=(192, 128),
    )

    # If a starting image is given, normalize it and reshape HWC -> 1xCxHxW.
    if x0 is not None:
        x0 = diffusion.normalize_image(x0)
        x0 = x0.permute(2, 0, 1)
        x0 = x0.unsqueeze(0)

    model.to(device)
    diffusion.to(device)
    imageAPI = DiffusionImageAPI(diffusion)

    # Sample one image; `versions` holds the intermediate denoising steps.
    new_images, versions = diffusion.sample(1, cond=cond, x0=x0, cb=callback)

    if gif:
        # Convert every intermediate step to a PIL image and save them as an animated GIF.
        images = [imageAPI.tensor_to_image(image.squeeze(0)) for image in versions]
        print(f"saving {len(images)} frames to ./gif_output/versions.gif")
        images[0].save(
            "./gif_output/versions.gif",
            save_all=True,
            append_images=images[1:],
            duration=100,
            loop=0,
        )

    return imageAPI.tensor_to_image(new_images.squeeze(0))


if __name__ == "__main__":
    # inference() requires a conditioning input; 0 is used here as an example
    # class index (valid values are 0..12 given num_classes=13).
    inference(cond=0).show()
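
# A minimal usage sketch (an assumption, not part of the original script):
# generate a sample conditioned on class 3, save the denoising trajectory as a
# GIF, and report progress. The callback signature is not specified here --
# diffusion.sample() simply forwards `callback` via its `cb` argument, so adapt
# `report` to whatever that implementation actually calls it with. `cond=3`
# assumes an integer class index in [0, 12].
#
#     def report(*args):
#         # hypothetical progress hook; prints whatever diffusion.sample passes in
#         print("denoising progress:", *args)
#
#     image = inference(cond=3, gif=True, callback=report)
#     image.save("sample_class3.png")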