Spaces:
Paused
Paused
import torch | |
import torch.optim as optim | |
import torch.nn as nn | |
from tqdm import tqdm | |
import numpy as np | |
from PIL import Image | |
import requests | |
import io | |
from model import Unet, ConditionalUnet, GaussianDiffusion, DiffusionImageAPI | |
def inference1(url="https://picsum.photos/120/80", timeout=10.0):
    """Download an image from *url* and return it as a PIL image.

    This does not touch the diffusion model; it just fetches a picture from
    the web. The default URL presumably requests a 120x80 image, matching the
    ``image_size=(120, 80)`` used by ``inference`` (TODO confirm the
    width/height order).

    Args:
        url: Address to fetch the image bytes from.
        timeout: Seconds to wait for the server before giving up, so a stalled
            connection cannot block forever.

    Returns:
        A PIL image opened from the downloaded bytes.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on HTTP errors instead of feeding an error page to PIL.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
def inference(checkpoint_path="./model_final.pt", gif_path=None):
    """Sample one image from the trained diffusion model and return it.

    Builds a 3-channel ``Unet``, loads its weights from *checkpoint_path*, and
    wraps it in a ``GaussianDiffusion`` sampler. The sampler uses 1000 noise
    steps, ``beta_0=1e-4``, ``beta_T=0.02`` and ``image_size=(120, 80)``.
    It draws a single sample and converts the last entry of the
    ``versions`` sequence returned by ``diffusion.sample`` into a PIL image.

    Args:
        checkpoint_path: Path of the saved ``state_dict`` to load.
        gif_path: Optional output path. When given, every entry of
            ``versions`` is converted to a frame and the frames are saved as an
            animated GIF (100 ms per frame, looping forever).

    Returns:
        The PIL image built from the final entry of ``versions``.
    """
    model = Unet(image_channels=3)
    # Load tensors onto the CPU first. load_state_dict then copies them into
    # the model's parameters, so a GPU-saved checkpoint also loads on a
    # CPU-only machine.
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
    # Inference mode for layers such as dropout/batch-norm, if the Unet has any.
    model.eval()

    diffusion = GaussianDiffusion(
        model=model,
        noise_steps=1000,
        beta_0=1e-4,
        beta_T=0.02,
        image_size=(120, 80),
    )
    image_api = DiffusionImageAPI(diffusion)

    # Pure sampling: no autograd graph is needed across the denoising steps.
    with torch.no_grad():
        _final, versions = diffusion.sample(1)
    # Materialize once so indexing works whatever iterable sample() returns.
    versions = list(versions)

    if gif_path is None:
        # Only the last version is returned, so convert just that one.
        return image_api.tensor_to_image(versions[-1].squeeze(0))

    frames = [image_api.tensor_to_image(version.squeeze(0)) for version in versions]
    frames[0].save(
        gif_path,
        save_all=True,
        append_images=frames[1:],
        duration=100,
        loop=0,
    )
    return frames[-1]
if __name__ == "__main__":
    # Script entry point: run one sampling pass and display the returned image.
    sampled_image = inference()
    sampled_image.show()