Shiroi-max committed on
Commit
1688bb3
1 Parent(s): 0782bce

Upload pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +66 -0
pipeline.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+ from diffusers import DiffusionPipeline, ImagePipelineOutput, randn_tensor
3
+
4
+ import torch
5
+
6
+
7
class DDPMConditionalPipeline(DiffusionPipeline):
    """Label-conditioned image generation pipeline built from a UNet and a scheduler.

    Starting from Gaussian noise, the pipeline repeatedly calls the UNet with
    the current sample, the current timestep and ``label``, and lets the
    scheduler compute the next (less noisy) sample.
    """

    # Only the UNet takes part in model CPU offloading.
    model_cpu_offload_seq = "unet"

    def __init__(self, unet, scheduler):
        """Register ``unet`` and ``scheduler`` as the pipeline's modules."""
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        label,
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 1000,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        """Run the full denoising loop and return the generated images.

        Args:
            label: Conditioning value forwarded to the UNet as its third
                positional argument at every step. Tensors are moved to the
                pipeline device; other non-``None`` values (ints, lists,
                arrays) are converted with ``torch.as_tensor``. ``None`` is
                passed through unchanged.
                NOTE(review): presumably class labels with one entry per
                batch element -- confirm against the UNet configuration.
            batch_size: Number of noise samples (and thus images) to create.
            generator: Optional generator(s) used both for the initial noise
                and for the scheduler steps.
            num_inference_steps: Number of timesteps handed to
                ``scheduler.set_timesteps``.
            output_type: ``"pil"`` converts the result with
                ``numpy_to_pil``; any other value returns the NumPy array
                laid out as (batch, height, width, channels).
            return_dict: When false, return a one-element tuple instead of
                an ``ImagePipelineOutput``.

        Returns:
            ``ImagePipelineOutput(images=...)`` or ``(images,)``.
        """
        # Build the noise shape; ``sample_size`` may be an int (square
        # images) or an iterable of spatial dimensions.
        sample_size = self.unet.config.sample_size
        if isinstance(sample_size, int):
            spatial_dims = (sample_size, sample_size)
        else:
            spatial_dims = tuple(sample_size)
        image_shape = (batch_size, self.unet.config.in_channels, *spatial_dims)

        # Sample Gaussian noise to begin the loop.
        if self.device.type == "mps":
            # randn does not work reproducibly on mps, so sample without a
            # target device first and move the result afterwards.
            image = randn_tensor(image_shape, generator=generator)
            image = image.to(self.device)
        else:
            image = randn_tensor(image_shape, generator=generator, device=self.device)

        # Keep the conditioning input on the same device as the sample so the
        # UNet does not fail with a device mismatch, and accept plain Python
        # or NumPy labels by turning them into tensors.
        if label is not None:
            label = torch.as_tensor(label, device=self.device)

        # Set step values.
        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # 1. Predict the model output for the current sample.
            model_output = self.unet(image, t, label).sample

            # 2. Compute the previous image: x_t -> x_t-1.
            image = self.scheduler.step(
                model_output, t, image, generator=generator
            ).prev_sample

        # Apply x / 2 + 0.5 and clamp to [0, 1] (assumes samples live in
        # [-1, 1] -- TODO confirm), then convert to channels-last NumPy.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)