Upload pipeline_ncsn.py with huggingface_hub
Browse files- pipeline_ncsn.py +175 -0
pipeline_ncsn.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, Dict, List, Optional, Self, Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
5 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
6 |
+
from einops import rearrange
|
7 |
+
|
8 |
+
from .scheduling_ncsn import (
|
9 |
+
AnnealedLangevinDynamicOutput,
|
10 |
+
AnnealedLangevinDynamicScheduler,
|
11 |
+
)
|
12 |
+
from .unet_2d_ncsn import UNet2DModelForNCSN
|
13 |
+
|
14 |
+
|
15 |
+
def normalize_images(image: torch.Tensor) -> torch.Tensor:
    """Min-max normalize every image of a batch to the range [0, 1].

    Each image (one entry along the first dimension) is scaled independently
    with its own minimum and maximum. The input tensor is not modified; a new
    tensor is returned.

    Args:
        image (torch.Tensor): A 4-dimensional batch of images whose first
            dimension is the batch dimension.

    Returns:
        torch.Tensor: A tensor with the same shape as `image` in which every
        image spans [0, 1]. An image whose values are all equal is mapped to
        all zeros instead of NaN.

    Raises:
        ValueError: If `image` is not 4-dimensional.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if image.ndim != 4:
        raise ValueError(
            f"Expected a 4-dimensional batch of images, got {image.ndim} dimensions."
        )

    # Per-image extrema, kept as (batch, 1, 1, 1) so they broadcast over the image.
    reduce_dims = (1, 2, 3)
    lowest = image.amin(dim=reduce_dims, keepdim=True)
    highest = image.amax(dim=reduce_dims, keepdim=True)

    # A constant image has zero range; divide by one there to avoid 0 / 0 = NaN.
    value_range = highest - lowest
    value_range = torch.where(
        value_range > 0, value_range, torch.ones_like(value_range)
    )

    # Out-of-place arithmetic: the caller's tensor is left untouched.
    return (image - lowest) / value_range
|
33 |
+
|
34 |
+
|
35 |
+
class NCSNPipeline(DiffusionPipeline):
    r"""
    Pipeline for unconditional image generation using Noise Conditional Score Network (NCSN).

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        unet ([`UNet2DModelForNCSN`]):
            A `UNet2DModelForNCSN` to estimate the score of the image.
        scheduler ([`AnnealedLangevinDynamicScheduler`]):
            An `AnnealedLangevinDynamicScheduler` to be used in combination with `unet` to estimate the score of
            the image.
    """

    unet: UNet2DModelForNCSN
    scheduler: AnnealedLangevinDynamicScheduler

    # Default names of `__call__` locals that are handed to a step-end callback.
    _callback_tensor_inputs: List[str] = ["samples"]

    def __init__(
        self, unet: UNet2DModelForNCSN, scheduler: AnnealedLangevinDynamicScheduler
    ) -> None:
        """Register the score network and the scheduler as pipeline modules."""
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    def decode_samples(self, samples: torch.Tensor) -> torch.Tensor:
        """Turn raw samples into channel-last images.

        Args:
            samples (torch.Tensor): The batch of generated samples, laid out with the channel axis second.

        Returns:
            torch.Tensor: The samples after `normalize_images`, with the channel axis moved to the end.
        """
        # Normalize the generated image
        samples = normalize_images(samples)
        # Rearrange the generated image to the correct format
        # (the channel axis moves from position 1 to the last position)
        samples = rearrange(samples, "b c w h -> b w h c")
        return samples

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        num_inference_steps: int = 10,
        generator: Optional[torch.Generator] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[
            Union[
                Callable[[Self, int, int, Dict], Dict],
                PipelineCallback,
                MultiPipelineCallbacks,
            ]
        ] = None,
        callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
        **kwargs,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        The call function to the pipeline for generation.

        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            num_inference_steps (`int`, *optional*, defaults to 10):
                The number of inference steps passed to `scheduler.set_timesteps`.
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic. It is used for the initial samples and forwarded to `scheduler.step`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. With `"pil"` the decoded samples are converted through
                `self.numpy_to_pil`; with any other value the decoded samples are returned as a `torch.Tensor`
                without further conversion.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. The flag is also
                forwarded to `scheduler.step`.
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the
                end of each denoising step during inference, with the following arguments:
                `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
                `callback_kwargs` will include all tensors specified by `callback_on_step_end_tensor_inputs`.
                If the returned dict contains `"samples"`, that value replaces the current samples.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class. Defaults to `._callback_tensor_inputs`
                and is overridden by `callback_on_step_end.tensor_inputs` when a `PipelineCallback` or
                `MultiPipelineCallbacks` is given.
            **kwargs:
                Accepted but not used by this method.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element holds the generated images.
        """
        # Fall back to the class-level default when no tensor names were requested.
        callback_on_step_end_tensor_inputs = (
            callback_on_step_end_tensor_inputs or self._callback_tensor_inputs
        )
        # Callback objects declare the tensors they need themselves; that takes precedence.
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # (batch, channels, size, size) as configured on the score network
        samples_shape = (
            batch_size,
            self.unet.config.in_channels,  # type: ignore
            self.unet.config.sample_size,  # type: ignore
            self.unet.config.sample_size,  # type: ignore
        )

        # Generate a random sample
        # NOTE: The behavior of random number generation is different between CPU and GPU,
        # so first generate random numbers on CPU and then move them to GPU (if available).
        samples = torch.rand(samples_shape, generator=generator)
        samples = samples.to(self.device)

        # Set the number of inference steps for the scheduler
        self.scheduler.set_timesteps(num_inference_steps)

        # Perform the reverse diffusion process
        for t in self.progress_bar(self.scheduler.timesteps):
            # Perform `num_annealed_steps` annealing steps
            for i in range(self.scheduler.num_annealed_steps):
                # Predict the score using the model
                model_output = self.unet(samples, t).sample  # type: ignore

                # Perform the annealed langevin dynamics
                output = self.scheduler.step(
                    model_output=model_output,
                    timestep=t,
                    samples=samples,
                    generator=generator,
                    return_dict=return_dict,
                )
                # `step` returns an output object or a plain tuple depending on `return_dict`.
                samples = (
                    output.prev_sample
                    if isinstance(output, AnnealedLangevinDynamicOutput)
                    else output[0]
                )

            # Perform the callback on step end if provided
            # NOTE(review): this block is placed at timestep level (after the inner annealing loop), so `i`
            # is the last inner-loop index when the callback runs — confirm the nesting against upstream.
            if callback_on_step_end is not None:
                callback_kwargs = {}
                # Look the requested tensors up by name among this method's local variables.
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]

                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                # Let the callback replace the current samples; keep them otherwise.
                samples = callback_outputs.pop("samples", samples)

        samples = self.decode_samples(samples)

        # Only "pil" triggers a conversion; other values keep the decoded tensor.
        if output_type == "pil":
            samples = self.numpy_to_pil(samples.cpu().numpy())

        if return_dict:
            return ImagePipelineOutput(images=samples)  # type: ignore
        else:
            return (samples,)
|