IzumiSatoshi committed on
Commit
8cdf719
·
1 Parent(s): 15bfd1a

Rename Sketch2ImgPipeline.py to pipeline_ddpm_sketch2img.py

Browse files
Sketch2ImgPipeline.py DELETED
@@ -1,52 +0,0 @@
1
- import torchvision
2
- import torch
3
- from tqdm import tqdm
4
- from diffusers.pipeline_utils import DiffusionPipeline
5
-
6
-
7
class Sketch2ImgPipeline(DiffusionPipeline):
    """Diffusion pipeline that denoises random noise conditioned on sketches.

    At every scheduler timestep the current noisy samples are concatenated
    with the normalized sketches along the channel axis and passed to the
    UNet; the scheduler then uses the UNet output to compute the next sample.
    """

    def __init__(self, unet, scheduler):
        """Register the UNet and the noise scheduler as pipeline modules."""
        super().__init__()
        self.register_modules(
            unet=unet,
            scheduler=scheduler,
        )

    @torch.no_grad()
    def __call__(
        self,
        sketches,
        num_inference_step=10,
    ):
        """Generate images from a batch of sketches.

        Args:
            sketches: numpy array; per the original author's note the expected
                layout is (batch_size, 1, 28, 28) with values in 0~255.
            num_inference_step: number of scheduler timesteps to run.

        Returns:
            numpy array of generated samples, nominally in the -1 ~ 1 range
            (the output is not denormalized).
        """
        # TODO: map output to 0 ~ 255 (see `denormalize`)
        sketches = torch.from_numpy(sketches).float()
        sketches = self.normalize(sketches)

        self.scheduler.set_timesteps(num_inference_step, device=self.device)

        sketches = sketches.to(self.device)
        # randn_like allocates on the same device as `sketches`.
        samples = torch.randn_like(sketches)

        for t in tqdm(self.scheduler.timesteps):
            # Condition the UNet by stacking noisy samples and sketches
            # along the channel dimension.
            x = torch.concat([samples, sketches], dim=1)
            residuals = self.unet(x, t).sample
            samples = self.scheduler.step(residuals, t, samples).prev_sample

        return samples.cpu().numpy()

    def normalize(self, x):
        """Map values from the 0~255 range to the -1~1 range.

        The division is done out of place: `torch.from_numpy` shares memory
        with the source array and `.float()` does not copy float32 data, so an
        in-place `x /= 255` would overwrite the caller's numpy array.
        """
        x = x / 255  # 0~255 -> 0~1
        # Normalize(mean=0.5, std=0.5) computes (x - 0.5) / 0.5: 0~1 -> -1~1
        x = torchvision.transforms.Normalize(0.5, 0.5)(x)
        return x

    def denormalize(self, x):
        """Inverse of `normalize`: map values from -1~1 back to 0~255."""
        x = x * 0.5 + 0.5  # map from (-1, 1) back to (0, 1)
        x = x * 255
        return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pipeline_ddpm_sketch2img.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import DiffusionPipeline
2
+ import torch
3
+ from torchvision import transforms
4
+ from tqdm import tqdm
5
+
6
+
7
class DDPMSketch2ImgPipeline(DiffusionPipeline):
    """Sketch-conditioned diffusion pipeline: PIL sketch in, PIL image out.

    At every scheduler timestep the current noisy image is concatenated with
    the normalized sketch along the channel axis and passed to the UNet; the
    scheduler then uses the UNet output to compute the next sample.
    """

    # TODO: Move transforms to another class

    def __init__(self, unet, scheduler):
        """Register the UNet and the noise scheduler as pipeline modules."""
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    def __call__(self, sketch, num_inference_step=1000, tqdm_leave=True):
        """Generate one image from one sketch.

        Args:
            sketch: PIL image used as the conditioning input.
            num_inference_step: number of scheduler timesteps to run.
            tqdm_leave: forwarded to tqdm's `leave` argument.

        Returns:
            The generated PIL image.
        """
        sketch = transforms.functional.pil_to_tensor(sketch).float()
        sketch = self.normalize(sketch).to(self.device)
        sketch = sketch.unsqueeze(0)  # add the batch dimension

        image = self.sample(sketch, num_inference_step, tqdm_leave)

        image = image.squeeze(0)  # drop the batch dimension
        image = self.denormalize(image)
        image = self.denormalized_tensor_to_pil(image)

        return image

    @torch.no_grad()
    def sample(self, transformed_sketch, num_inference_step, tqdm_leave=True):
        """Run the denoising loop on an already-normalized sketch batch.

        Args:
            transformed_sketch: 4-D tensor (bs, c, h, w), already normalized.
            num_inference_step: number of scheduler timesteps to run.
            tqdm_leave: forwarded to tqdm's `leave` argument.

        Returns:
            The final sample tensor produced by the scheduler.

        Raises:
            ValueError: if `transformed_sketch` is not 4-dimensional.
        """
        if len(transformed_sketch.shape) != 4:
            raise ValueError(f"(bs, c, h, w) but {transformed_sketch.shape}")

        # Is this the right place to set timesteps?
        self.scheduler.set_timesteps(num_inference_step, device=self.device)

        # Move the conditioning input once, so the concat below never mixes
        # devices when `sample` is called directly with a tensor elsewhere.
        transformed_sketch = transformed_sketch.to(self.device)

        batch_size, _, height, width = transformed_sketch.shape
        # Assume image's channels == out_channels
        image = torch.randn(
            (batch_size, self.unet.config["out_channels"], height, width)
        ).to(self.device)

        for t in tqdm(self.scheduler.timesteps, leave=tqdm_leave):
            # Condition the UNet by stacking the noisy image and the sketch
            # along the channel dimension.
            model_input = torch.concat([image, transformed_sketch], dim=1)
            model_output = self.unet(model_input, t).sample
            image = self.scheduler.step(model_output, t, image).prev_sample

        return image

    def denormalized_tensor_to_pil(self, tensor):
        """Convert a (c, h, w) tensor holding 0~255 values to a PIL image.

        Raises:
            ValueError: if `tensor` is not 3-dimensional.
        """
        if len(tensor.shape) != 3:
            raise ValueError(f"(c, h, w) but {tensor.shape}")

        # Round before the uint8 cast: a plain cast truncates toward zero,
        # which biases every pixel darker (e.g. 254.9 -> 254).
        tensor = tensor.cpu().clip(0, 255).round().to(torch.uint8)
        pil = transforms.functional.to_pil_image(tensor)
        return pil

    def normalize(self, x):
        """Map a float tensor from the 0~255 range to the -1~1 range.

        Raises:
            TypeError: if `x` is not a float32 tensor.
        """
        if x.dtype != torch.float:
            raise TypeError(f"expected a torch.float tensor but got {x.dtype}")
        x = x / 255.0  # 0~255 -> 0~1
        # Normalize(mean=0.5, std=0.5) computes (x - 0.5) / 0.5: 0~1 -> -1~1
        x = transforms.Normalize([0.5], [0.5])(x)
        return x

    def denormalize(self, x):
        """Inverse of `normalize`: map a float tensor from -1~1 back to 0~255.

        Raises:
            TypeError: if `x` is not a float32 tensor.
        """
        if x.dtype != torch.float:
            raise TypeError(f"expected a torch.float tensor but got {x.dtype}")
        x = x * 0.5 + 0.5  # map from (-1, 1) back to (0, 1)
        x = x * 255.0
        return x