movie-diffusion / inference.py
Anton Forsman
new weights
17b3ae7
raw
history blame
1.64 kB
import io
import os

import numpy as np
import requests
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from tqdm import tqdm

from diffusion import GaussianDiffusion, DiffusionImageAPI
from unet import Unet, ConditionalUnet
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def inference1():
    """Download an image from picsum.photos and return it opened with PIL.

    Returns:
        The object produced by ``Image.open`` for the bytes served at
        ``https://picsum.photos/120/80``.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an HTTP error status.
    """
    # A timeout keeps the call from blocking forever on an unresponsive host.
    response = requests.get("https://picsum.photos/120/80", timeout=10)
    # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
def inference(cond, x0=None, gif=False, callback=None):
    """Sample one image from the pretrained conditional diffusion model.

    Builds a ``Unet`` wrapped in a 13-class ``ConditionalUnet``, loads the
    weights from ``./model_final2.pt`` and asks ``GaussianDiffusion.sample``
    for a single sample.

    Args:
        cond: Conditioning value forwarded unchanged to ``diffusion.sample``.
            NOTE(review): presumably a class label compatible with
            ``num_classes=13`` -- confirm against ``ConditionalUnet``.
        x0: Optional starting image. When given it is passed through
            ``diffusion.normalize_image``, permuted with ``(2, 0, 1)``
            (presumably HWC -> CHW -- confirm), given a leading axis and
            moved to ``device`` before sampling.
        gif: When true, every intermediate version returned by the sampler
            is converted to an image and the sequence is written as an
            animated GIF to ``./gif_output/versions.gif``.
        callback: Forwarded to ``diffusion.sample`` as ``cb``.

    Returns:
        The result of ``DiffusionImageAPI.tensor_to_image`` for the final
        sample (presumably a PIL image, given how the GIF frames are saved).
    """
    model = ConditionalUnet(
        unet=Unet(
            image_channels=3,
            dropout=0.1,
        ),
        num_classes=13,
    )
    # NOTE(security): torch.load unpickles the checkpoint, so only load
    # weight files that come from a trusted source.
    state_dict = torch.load("./model_final2.pt", map_location=device)
    model.load_state_dict(state_dict)
    # Inference only: make sure dropout (0.1 above) is disabled while sampling.
    model.eval()

    diffusion = GaussianDiffusion(
        model=model,
        noise_steps=1000,
        beta_0=1e-4,
        beta_T=0.02,
        image_size=(192, 128),
    )

    if x0 is not None:
        x0 = diffusion.normalize_image(x0)
        x0 = x0.permute(2, 0, 1)
        x0 = x0.unsqueeze(0)

    model.to(device)
    diffusion.to(device)
    if x0 is not None:
        # Keep the starting image on the same device as the model.
        x0 = x0.to(device)

    imageAPI = DiffusionImageAPI(diffusion)
    new_images, versions = diffusion.sample(1, cond=cond, x0=x0, cb=callback)

    if gif:
        frames = [
            imageAPI.tensor_to_image(version.squeeze(0)) for version in versions
        ]
        print(len(frames))
        print(frames[0])

        gif_path = './gif_output/versions.gif'
        # Saving fails with FileNotFoundError if the output folder is missing.
        os.makedirs(os.path.dirname(gif_path), exist_ok=True)
        # Build the animated GIF from the Pillow frames.
        frames[0].save(
            gif_path,
            save_all=True,
            append_images=frames[1:],
            duration=100,
            loop=0,
        )

    return imageAPI.tensor_to_image(new_images.squeeze(0))
if __name__ == "__main__":
    # Script entry point: run one sampling pass and display the result.
    # NOTE(review): `inference` declares `cond` as a required parameter with
    # no default, so this call raises TypeError as written -- TODO pass the
    # intended conditioning value here.
    inference().show()