from diffusers import DiffusionPipeline
import torch
from torchvision import transforms
from tqdm import tqdm
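

# Sketch-conditioned DDPM pipeline: at each denoising step the UNet
# receives the current noisy image concatenated channel-wise with the
# normalized sketch and predicts the noise residual, which the DDPM
# scheduler uses to step the sample back toward t = 0.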
class DDPMSketch2ImgPipeline(DiffusionPipeline):
    # TODO: Move transforms to another class
    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    def __call__(self, sketch, num_inference_step=1000, tqdm_leave=True):
        # sketch : PIL.Image
        # return : PIL.Image
        sketch = transforms.functional.pil_to_tensor(sketch).float()
        sketch = self.normalize(sketch).to(self.device)
        sketch = sketch.unsqueeze(0)  # add batch dim: (c, h, w) -> (1, c, h, w)
        image = self.sample(sketch, num_inference_step, tqdm_leave)
        image = image.squeeze(0)
        image = self.denormalize(image)
        image = self.denormalized_tensor_to_pil(image)
        return image

    def sample(self, transformed_sketch, num_inference_step, tqdm_leave=True):
        assert (
            len(transformed_sketch.shape) == 4
        ), f"(bs, c, h, w) but {transformed_sketch.shape}"
        # set_timesteps must run before each sampling loop so the
        # scheduler's schedule matches num_inference_step.
        self.scheduler.set_timesteps(num_inference_step, device=self.device)
        s = transformed_sketch.shape
        # Start from pure Gaussian noise; assume the generated image has
        # the same number of channels as the UNet's out_channels.
        image = torch.randn((s[0], self.unet.config["out_channels"], s[2], s[3])).to(
            self.device
        )
        for t in tqdm(self.scheduler.timesteps, leave=tqdm_leave):
            # Condition on the sketch by concatenating it channel-wise
            # with the current noisy image.
            model_input = torch.cat([image, transformed_sketch], dim=1).to(self.device)
            with torch.no_grad():
                model_output = self.unet(model_input, t).sample
            image = self.scheduler.step(model_output, t, image).prev_sample
        return image

    def denormalized_tensor_to_pil(self, tensor):
        assert len(tensor.shape) == 3, f"(c, h, w) but {tensor.shape}"
        tensor = tensor.cpu().clip(0, 255).to(torch.uint8)
        pil = transforms.functional.to_pil_image(tensor)
        return pil

    def normalize(self, x):
        assert x.dtype == torch.float
        # Map pixel values from [0, 255] to [-1, 1]: scale to [0, 1],
        # then apply (x - 0.5) / 0.5 via Normalize(mean=0.5, std=0.5).
        x = x / 255.0
        x = transforms.Normalize([0.5], [0.5])(x)
        return x

    def denormalize(self, x):
        assert x.dtype == torch.float
        x = x * 0.5 + 0.5  # map from (-1, 1) back to (0, 1)
        x = x * 255.0
        return x
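

if __name__ == "__main__":
    # Hedged usage sketch (assumptions, not part of the original file):
    # a UNet2DModel trained with in_channels = image channels + sketch
    # channels, plus a DDPMScheduler; the checkpoint path, sketch file,
    # and output path below are hypothetical.
    from PIL import Image
    from diffusers import DDPMScheduler, UNet2DModel

    unet = UNet2DModel.from_pretrained("path/to/unet")  # hypothetical path
    scheduler = DDPMScheduler(num_train_timesteps=1000)
    pipe = DDPMSketch2ImgPipeline(unet=unet, scheduler=scheduler)
    pipe.to("cuda" if torch.cuda.is_available() else "cpu")

    sketch = Image.open("sketch.png").convert("L")  # hypothetical input
    image = pipe(sketch, num_inference_step=1000)
    image.save("generated.png")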