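"""DDPM sketch-to-image pipeline.

Wraps a trained UNet and DDPM scheduler in a `diffusers` pipeline that
conditions generation on a sketch: at every denoising step the sketch is
concatenated with the noisy image along the channel dimension.
"""
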
from diffusers import DiffusionPipeline
import torch
from torchvision import transforms
from tqdm import tqdm


class DDPMSketch2ImgPipeline(DiffusionPipeline):
    # TODO: Move transforms to another class

    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    def __call__(self, sketch, num_inference_step=1000, tqdm_leave=True):
        # sketch : PIL image
        # return : PIL image

        sketch = transforms.functional.pil_to_tensor(sketch).float()
        sketch = self.normalize(sketch).to(self.device)
        sketch = sketch.unsqueeze(0)

        image = self.sample(sketch, num_inference_step, tqdm_leave)

        image = image.squeeze(0)
        image = self.denormalize(image)
        image = self.denormalized_tensor_to_pil(image)

        return image

    def sample(self, transformed_sketch, num_inference_step, tqdm_leave=True):
        assert (
            len(transformed_sketch.shape) == 4
        ), f"(bs, c, h, w) but {transformed_sketch.shape}"

        # set_timesteps must be called before each sampling run so the
        # scheduler's timestep sequence matches num_inference_step.
        self.scheduler.set_timesteps(num_inference_step, device=self.device)

        bs, _, h, w = transformed_sketch.shape
        # Start the reverse process from pure Gaussian noise. The UNet's
        # out_channels equals the number of image channels, since the sketch
        # channels are only concatenated on the input side.
        image = torch.randn((bs, self.unet.config["out_channels"], h, w)).to(
            self.device
        )

        for t in tqdm(self.scheduler.timesteps, leave=tqdm_leave):
            # Condition on the sketch by concatenating it with the noisy
            # image along the channel dimension.
            model_input = torch.cat([image, transformed_sketch], dim=1).to(
                self.device
            )
            with torch.no_grad():
                model_output = self.unet(model_input, t).sample
            # One reverse step: estimate the previous, slightly less noisy image.
            image = self.scheduler.step(model_output, t, image).prev_sample

        return image

    def denormalized_tensor_to_pil(self, tensor):
        assert len(tensor.shape) == 3, f"(c, h, w) but {tensor.shape}"

        tensor = tensor.cpu().clip(0, 255).to(torch.uint8)
        pil = transforms.functional.to_pil_image(tensor)
        return pil

    def normalize(self, x):
        assert x.dtype == torch.float
        # Scale [0, 255] to [0, 1], then shift to [-1, 1] via (x - 0.5) / 0.5,
        # the range the diffusion model operates in (see denormalize below).
        x = x / 255.0
        x = transforms.Normalize([0.5], [0.5])(x)
        return x

    def denormalize(self, x):
        assert x.dtype == torch.float
        x = x * 0.5 + 0.5  # map from (-1, 1) back to (0, 1)
        x = x * 255.0
        return x
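

# Example usage (a minimal sketch, not part of the pipeline itself; the
# checkpoint paths below are hypothetical placeholders and assume a
# UNet2DModel / DDPMScheduler pair trained for sketch conditioning):
#
#     from diffusers import UNet2DModel, DDPMScheduler
#     from PIL import Image
#
#     unet = UNet2DModel.from_pretrained("path/to/unet")
#     scheduler = DDPMScheduler.from_pretrained("path/to/scheduler")
#     pipe = DDPMSketch2ImgPipeline(unet=unet, scheduler=scheduler).to("cuda")
#
#     sketch = Image.open("sketch.png").convert("L")
#     image = pipe(sketch, num_inference_step=1000)
#     image.save("output.png")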