# Hugging Face Space page header (Space status: Paused). File size: 1,639 bytes.
import torch
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
import numpy as np
from PIL import Image
import requests
import io
from unet import Unet, ConditionalUnet
from diffusion import GaussianDiffusion, DiffusionImageAPI
# Use the current CUDA device when CUDA is available; otherwise compute on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def inference1(*, width=120, height=80, timeout=10.0):
    """Fetch a fresh image from https://picsum.photos and return it as a PIL image.

    Args:
        width: Requested image width in pixels (default 120).
        height: Requested image height in pixels (default 80).
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        A ``PIL.Image.Image`` opened from the downloaded bytes.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    # Without an explicit timeout, requests.get can block indefinitely on a stalled connection.
    response = requests.get(f"https://picsum.photos/{width}/{height}", timeout=timeout)
    # Surface HTTP failures clearly instead of letting PIL fail to parse an error page.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
def _load_diffusion():
    """Build the conditional U-Net + GaussianDiffusion pipeline on ``device`` in eval mode."""
    unet = Unet(
        image_channels=3,
        dropout=0.1,
    )
    model = ConditionalUnet(
        unet=unet,
        num_classes=13,
    )
    # NOTE(review): torch.load unpickles the checkpoint, so only load trusted files.
    # If the checkpoint is a plain state dict, consider weights_only=True (torch >= 1.13).
    model.load_state_dict(torch.load("./model_final.pt", map_location=device))
    diffusion = GaussianDiffusion(
        model=model,
        noise_steps=1000,
        beta_0=1e-4,
        beta_T=0.02,
        image_size=(192, 128),
    )
    model.to(device)
    diffusion.to(device)
    # The U-Net is built with dropout=0.1; eval() disables dropout for deterministic sampling.
    model.eval()
    return diffusion


def _save_gif(image_api, versions, path='./gif_output/versions.gif'):
    """Convert each tensor in *versions* to an image and write them as an animated GIF at *path*."""
    from pathlib import Path

    frames = [image_api.tensor_to_image(version.squeeze(0)) for version in versions]
    if not frames:
        # Nothing to animate; PIL needs at least one frame to write a GIF.
        return
    # Create the output directory so the save does not fail on a fresh checkout.
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    frames[0].save(path,
                   save_all=True,
                   append_images=frames[1:],
                   duration=100,
                   loop=0)


def inference(cond, x0=None, gif=False, callback=None):
    """Sample one image from the conditional diffusion model.

    Args:
        cond: Conditioning forwarded unchanged to ``diffusion.sample``.
            Presumably a class label for the 13-class ConditionalUnet (TODO confirm).
        x0: Optional starting image. It is normalized with ``diffusion.normalize_image``,
            permuted with ``permute(2, 0, 1)``, given a leading batch dimension, and
            moved to ``device``. Presumably channels-last before the permute (TODO confirm).
        gif: If true, also write the per-step outputs returned by ``diffusion.sample``
            to ``./gif_output/versions.gif``, creating the directory if needed.
        callback: Forwarded to ``diffusion.sample`` as ``cb``.

    Returns:
        The sampled image, as produced by ``DiffusionImageAPI.tensor_to_image``.
    """
    diffusion = _load_diffusion()
    if x0 is not None:
        x0 = diffusion.normalize_image(x0)
        # Move channels to the front, add a batch dim, and keep x0 on the model's device.
        x0 = x0.permute(2, 0, 1).unsqueeze(0).to(device)
    image_api = DiffusionImageAPI(diffusion)
    new_images, versions = diffusion.sample(1, cond=cond, x0=x0, cb=callback)
    if gif:
        _save_gif(image_api, versions)
    return image_api.tensor_to_image(new_images.squeeze(0))
if __name__ == "__main__":
    # Script entry point: run one sampling pass and show the returned image.
    # NOTE(review): `inference` declares a required positional `cond` argument with no
    # default, so this call raises TypeError as written. Pass a conditioning value here;
    # presumably a class label for the 13-class ConditionalUnet. TODO confirm the expected
    # type/shape against `diffusion.sample`.
    inference().show()
|