jiuntian committed
Commit 09e5125
1 Parent(s): 6fce1d6

update pipeline

README.md CHANGED
@@ -1,3 +1,35 @@
 ---
 license: bsd
 ---
+
+# InteractDiffusion Diffusers Implementation
+
+## How to Use
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained(
+    "interactdiffusion/diffusers-v1-2",
+    trust_remote_code=True,
+    variant="fp16", torch_dtype=torch.float16
+)
+pipeline = pipeline.to("cuda")
+
+images = pipeline(
+    prompt="a person is feeding a cat",
+    interactdiffusion_subject_phrases=["person"],
+    interactdiffusion_object_phrases=["cat"],
+    interactdiffusion_action_phrases=["feeding"],
+    interactdiffusion_subject_boxes=[[0.0332, 0.1660, 0.3359, 0.7305]],
+    interactdiffusion_object_boxes=[[0.2891, 0.4766, 0.6680, 0.7930]],
+    interactdiffusion_scheduled_sampling_beta=1,
+    output_type="pil",
+    num_inference_steps=50,
+).images
+
+images[0].save("out.jpg")
+```
+
+For more information, please check the project homepage.
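
Note: the subject and object boxes in the example above appear to be normalized `[xmin, ymin, xmax, ymax]` coordinates in `[0, 1]`, and `interactdiffusion_scheduled_sampling_beta` appears to control the fraction of denoising steps during which the grounding tokens are applied (GLIGEN-style scheduled sampling). A minimal sketch for producing such boxes from pixel coordinates; the `to_normalized_box` helper is hypothetical, not part of this pipeline:

```python
# Hypothetical helper: convert a pixel-space box to the normalized
# [xmin, ymin, xmax, ymax] format the README example appears to use.
def to_normalized_box(box_px, width, height):
    xmin, ymin, xmax, ymax = box_px
    return [xmin / width, ymin / height, xmax / width, ymax / height]

# For a 512x512 image, [17, 85, 172, 374] -> [0.0332, 0.1660, 0.3359, 0.7305],
# i.e. the subject box used in the example above.
subject_box = to_normalized_box([17, 85, 172, 374], 512, 512)
```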
pipeline_stable_diffusion_interactdiffusion.py CHANGED
@@ -26,6 +26,7 @@ from diffusers.image_processor import VaeImageProcessor
 from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.models.attention import GatedSelfAttentionDense
+from diffusers.models.attention_processor import FusedAttnProcessor2_0
 from diffusers.models.embeddings import get_fourier_embeds_from_boundingbox
 from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.schedulers import KarrasDiffusionSchedulers
@@ -38,7 +39,7 @@ from diffusers.utils import (
     unscale_lora_layers,
 )
 from diffusers.utils.torch_utils import randn_tensor
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
@@ -46,7 +47,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class StableDiffusionInteractDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
+class StableDiffusionInteractDiffusionPipeline(DiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion with Interaction-to-Image Generation (InteractDiffusion).
 
@@ -105,17 +106,6 @@ class StableDiffusionInteractDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
                 "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
-
-        # # load position_net
-        # positive_len = 768
-        # if isinstance(unet.config.cross_attention_dim, int):
-        #     positive_len = unet.config.cross_attention_dim
-        # elif isinstance(unet.config.cross_attention_dim, tuple) or isinstance(unet.config.cross_attention_dim, list):
-        #     positive_len = unet.config.cross_attention_dim[0]
-
-        # self.position_net = InteractDiffusionInteractionProjection(
-        #     in_dim=positive_len, out_dim=unet.config.cross_attention_dim
-        # )
 
         self.register_modules(
             vae=vae,
@@ -130,6 +120,125 @@ class StableDiffusionInteractDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
         self.register_to_config(requires_safety_checker=requires_safety_checker)
 
+    ### Backward compatibility with diffusers < 0.27.0, where this class cannot inherit from StableDiffusionMixin
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
+    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+        r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+
+        The suffixes after the scaling factors represent the stages where they are being applied.
+
+        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+        that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+        Args:
+            s1 (`float`):
+                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+                mitigate the "oversmoothing effect" in the enhanced denoising process.
+            s2 (`float`):
+                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+                mitigate the "oversmoothing effect" in the enhanced denoising process.
+            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+        """
+        if not hasattr(self, "unet"):
+            raise ValueError("The pipeline must have `unet` for using FreeU.")
+        self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
+
+    def disable_freeu(self):
+        """Disables the FreeU mechanism if enabled."""
+        self.unet.disable_freeu()
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
+    def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+        """
+        self.fusing_unet = False
+        self.fusing_vae = False
+
+        if unet:
+            self.fusing_unet = True
+            self.unet.fuse_qkv_projections()
+            self.unet.set_attn_processor(FusedAttnProcessor2_0())
+
+        if vae:
+            if not isinstance(self.vae, AutoencoderKL):
+                raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
+
+            self.fusing_vae = True
+            self.vae.fuse_qkv_projections()
+            self.vae.set_attn_processor(FusedAttnProcessor2_0())
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
+    def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """Disable QKV projection fusion if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+
+        """
+        if unet:
+            if not self.fusing_unet:
+                logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.unet.unfuse_qkv_projections()
+                self.fusing_unet = False
+
+        if vae:
+            if not self.fusing_vae:
+                logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.vae.unfuse_qkv_projections()
+                self.fusing_vae = False
+
+    ### end of the section
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
     def _encode_prompt(
         self,
@@ -464,7 +573,6 @@ class StableDiffusionInteractDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
             module.enabled = enabled
 
     @torch.no_grad()
-    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
         self,
         prompt: Union[str, List[str]] = None,
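
The compatibility shims added in this commit reimplement the memory and speed helpers that newer diffusers versions provide via `StableDiffusionMixin`. A minimal usage sketch, assuming the pipeline loads as in the README above; the FreeU factors are illustrative placeholders, not values recommended by this commit:

```python
import torch
from diffusers import DiffusionPipeline

# Load as in the README example above.
pipeline = DiffusionPipeline.from_pretrained(
    "interactdiffusion/diffusers-v1-2",
    trust_remote_code=True,
    variant="fp16", torch_dtype=torch.float16,
).to("cuda")

pipeline.enable_vae_slicing()  # decode batches slice by slice to save memory
pipeline.enable_vae_tiling()   # tile VAE decoding/encoding for large images

# FreeU (https://arxiv.org/abs/2309.11497) re-weights skip (s1, s2) and
# backbone (b1, b2) features; these factors are illustrative only, see the
# FreeU repository for values tuned to specific Stable Diffusion versions.
pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)

# Experimental: fuse Q/K/V projections in attention modules.
pipeline.fuse_qkv_projections()

# ... run inference as shown in the README ...

pipeline.unfuse_qkv_projections()
pipeline.disable_freeu()
pipeline.disable_vae_tiling()
pipeline.disable_vae_slicing()
```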
 