# movie-diffusion / inference.py
# Author: Anton Forsman
# Commit 8eb7fcd ("put in everything"), 1.15 kB
import torch
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
import numpy as np
from PIL import Image
import requests
import io
from model import Unet, ConditionalUnet, GaussianDiffusion, DiffusionImageAPI
def inference1(timeout=10):
    """Fetch a random image from https://picsum.photos/120/80 and open it.

    Args:
        timeout: Seconds to wait for the HTTP request before giving up.
            Without a timeout ``requests`` may block indefinitely.

    Returns:
        A PIL image decoded from the downloaded bytes.

    Raises:
        requests.HTTPError: If the server answers with an error status, so a
            non-image error page is never handed to ``Image.open``.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    # new image from web page
    response = requests.get("https://picsum.photos/120/80", timeout=timeout)
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
def inference(checkpoint_path="./model_final.pt", gif_path=None):
    """Sample one image from the trained diffusion model.

    Builds a 3-channel ``Unet``, loads its weights from ``checkpoint_path``,
    wraps it in a 1000-step ``GaussianDiffusion`` (betas 1e-4 .. 0.02,
    image_size=(120, 80)) and draws a single sample.

    Args:
        checkpoint_path: Path of the state dict passed to ``torch.load``.
        gif_path: If given, every intermediate version produced by
            ``diffusion.sample`` is converted and saved there as an animated
            GIF (100 ms per frame, looping forever).

    Returns:
        The last intermediate version converted with
        ``DiffusionImageAPI.tensor_to_image``. This is presumably a PIL image,
        since callers invoke ``.show()`` on it.

    Raises:
        IndexError: If ``diffusion.sample`` yields no intermediate versions.
    """
    model = Unet(
        image_channels=3,
    )
    # NOTE(review): torch.load unpickles the file. Only load trusted
    # checkpoints, and consider weights_only=True where PyTorch supports it.
    model.load_state_dict(torch.load(checkpoint_path))
    # Switch off training-only behaviour (e.g. dropout) for sampling.
    model.eval()
    diffusion = GaussianDiffusion(
        model=model,
        noise_steps=1000,
        beta_0=1e-4,
        beta_T=0.02,
        image_size=(120, 80),
    )
    image_api = DiffusionImageAPI(diffusion)

    # Pure inference: no autograd graph is needed across the sampling steps.
    with torch.no_grad():
        # The first return value (final images) is unused. The result is taken
        # from the last intermediate version, matching the previous behaviour.
        _, versions = diffusion.sample(1)
    versions = list(versions)

    if gif_path is None:
        # Only the final frame is needed, so skip converting every step.
        return image_api.tensor_to_image(versions[-1].squeeze(0))

    # Each version carries a leading batch dimension of size 1; drop it.
    frames = [image_api.tensor_to_image(version.squeeze(0)) for version in versions]
    frames[0].save(
        gif_path,
        save_all=True,
        append_images=frames[1:],
        duration=100,
        loop=0,
    )
    return frames[-1]
if __name__ == "__main__":
    # Script entry point: sample one image and hand it to the image viewer.
    sampled_image = inference()
    sampled_image.show()