shisheng7
commited on
Commit
•
bd6c4af
1
Parent(s):
f7e8357
update home
Browse files- configs/inference/inference.yaml +118 -0
- configs/unet/unet.yaml +44 -0
- data/inference.json +12 -0
- joyhallo/__init__.py +0 -0
- joyhallo/animate/__init__.py +0 -0
- joyhallo/animate/face_animate.py +441 -0
- joyhallo/animate/face_animate_static.py +480 -0
- joyhallo/datasets/__init__.py +0 -0
- joyhallo/datasets/audio_processor.py +176 -0
- joyhallo/datasets/image_processor.py +345 -0
- joyhallo/datasets/mask_image.py +153 -0
- joyhallo/datasets/talk_video.py +321 -0
- joyhallo/models/__init__.py +0 -0
- joyhallo/models/attention.py +893 -0
- joyhallo/models/audio_proj.py +124 -0
- joyhallo/models/face_locator.py +113 -0
- joyhallo/models/image_proj.py +76 -0
- joyhallo/models/motion_module.py +605 -0
- joyhallo/models/mutual_self_attention.py +495 -0
- joyhallo/models/resnet.py +429 -0
- joyhallo/models/transformer_2d.py +428 -0
- joyhallo/models/transformer_3d.py +256 -0
- joyhallo/models/unet_2d_blocks.py +1340 -0
- joyhallo/models/unet_2d_condition.py +1428 -0
- joyhallo/models/unet_3d.py +840 -0
- joyhallo/models/unet_3d_blocks.py +1398 -0
- joyhallo/models/wav2vec.py +206 -0
- joyhallo/utils/__init__.py +0 -0
- joyhallo/utils/config.py +25 -0
- joyhallo/utils/util.py +976 -0
- scripts/inference.py +690 -0
configs/inference/inference.yaml
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data:
|
2 |
+
train_bs: 4
|
3 |
+
val_bs: 1
|
4 |
+
train_width: 512
|
5 |
+
train_height: 512
|
6 |
+
fps: 25
|
7 |
+
sample_rate: 16000
|
8 |
+
n_motion_frames: 2
|
9 |
+
n_sample_frames: 16
|
10 |
+
audio_margin: 2
|
11 |
+
train_meta_paths:
|
12 |
+
- "./data/inference.json"
|
13 |
+
|
14 |
+
wav2vec_config:
|
15 |
+
audio_type: "vocals" # audio vocals
|
16 |
+
model_scale: "base" # base large
|
17 |
+
features: "all" # last avg all
|
18 |
+
model_path: ./pretrained_models/chinese-wav2vec2-base
|
19 |
+
audio_separator:
|
20 |
+
model_path: ./pretrained_models/audio_separator/Kim_Vocal_2.onnx
|
21 |
+
face_expand_ratio: 1.2
|
22 |
+
|
23 |
+
solver:
|
24 |
+
gradient_accumulation_steps: 1
|
25 |
+
mixed_precision: "no"
|
26 |
+
enable_xformers_memory_efficient_attention: True
|
27 |
+
gradient_checkpointing: True
|
28 |
+
max_train_steps: 30000
|
29 |
+
max_grad_norm: 1.0
|
30 |
+
# lr
|
31 |
+
learning_rate: 1e-5
|
32 |
+
scale_lr: False
|
33 |
+
lr_warmup_steps: 1
|
34 |
+
lr_scheduler: "constant"
|
35 |
+
|
36 |
+
# optimizer
|
37 |
+
use_8bit_adam: True
|
38 |
+
adam_beta1: 0.9
|
39 |
+
adam_beta2: 0.999
|
40 |
+
adam_weight_decay: 1.0e-2
|
41 |
+
adam_epsilon: 1.0e-8
|
42 |
+
|
43 |
+
val:
|
44 |
+
validation_steps: 1000
|
45 |
+
|
46 |
+
noise_scheduler_kwargs:
|
47 |
+
num_train_timesteps: 1000
|
48 |
+
beta_start: 0.00085
|
49 |
+
beta_end: 0.012
|
50 |
+
beta_schedule: "linear"
|
51 |
+
steps_offset: 1
|
52 |
+
clip_sample: false
|
53 |
+
|
54 |
+
unet_additional_kwargs:
|
55 |
+
use_inflated_groupnorm: true
|
56 |
+
unet_use_cross_frame_attention: false
|
57 |
+
unet_use_temporal_attention: false
|
58 |
+
use_motion_module: true
|
59 |
+
use_audio_module: true
|
60 |
+
motion_module_resolutions:
|
61 |
+
- 1
|
62 |
+
- 2
|
63 |
+
- 4
|
64 |
+
- 8
|
65 |
+
motion_module_mid_block: true
|
66 |
+
motion_module_decoder_only: false
|
67 |
+
motion_module_type: Vanilla
|
68 |
+
motion_module_kwargs:
|
69 |
+
num_attention_heads: 8
|
70 |
+
num_transformer_block: 1
|
71 |
+
attention_block_types:
|
72 |
+
- Temporal_Self
|
73 |
+
- Temporal_Self
|
74 |
+
temporal_position_encoding: true
|
75 |
+
temporal_position_encoding_max_len: 32
|
76 |
+
temporal_attention_dim_div: 1
|
77 |
+
audio_attention_dim: 768
|
78 |
+
stack_enable_blocks_name:
|
79 |
+
- "up"
|
80 |
+
- "down"
|
81 |
+
- "mid"
|
82 |
+
stack_enable_blocks_depth: [0,1,2,3]
|
83 |
+
|
84 |
+
trainable_para:
|
85 |
+
- audio_modules
|
86 |
+
- motion_modules
|
87 |
+
|
88 |
+
base_model_path: "./pretrained_models/stable-diffusion-v1-5"
|
89 |
+
vae_model_path: "./pretrained_models/sd-vae-ft-mse"
|
90 |
+
face_analysis_model_path: "./pretrained_models/face_analysis"
|
91 |
+
mm_path: "./pretrained_models/motion_module/mm_sd_v15_v2.ckpt"
|
92 |
+
|
93 |
+
weight_dtype: "fp16" # [fp16, fp32]
|
94 |
+
uncond_img_ratio: 0.05
|
95 |
+
uncond_audio_ratio: 0.05
|
96 |
+
uncond_ia_ratio: 0.05
|
97 |
+
start_ratio: 0.05
|
98 |
+
noise_offset: 0.05
|
99 |
+
snr_gamma: 5.0
|
100 |
+
enable_zero_snr: True
|
101 |
+
stage1_ckpt_dir: "./exp_output/stage1/"
|
102 |
+
|
103 |
+
single_inference_times: 10
|
104 |
+
inference_steps: 40
|
105 |
+
cfg_scale: 3.5
|
106 |
+
|
107 |
+
seed: 42
|
108 |
+
resume_from_checkpoint: "latest"
|
109 |
+
checkpointing_steps: 500
|
110 |
+
|
111 |
+
exp_name: "joyhallo"
|
112 |
+
output_dir: "./opts"
|
113 |
+
|
114 |
+
audio_ckpt_dir: "./pretrained_models/joyhallo/net.pth"
|
115 |
+
|
116 |
+
ref_img_path: None
|
117 |
+
|
118 |
+
audio_path: None
|
configs/unet/unet.yaml
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
unet_additional_kwargs:
|
2 |
+
use_inflated_groupnorm: true
|
3 |
+
unet_use_cross_frame_attention: false
|
4 |
+
unet_use_temporal_attention: false
|
5 |
+
use_motion_module: true
|
6 |
+
use_audio_module: true
|
7 |
+
motion_module_resolutions:
|
8 |
+
- 1
|
9 |
+
- 2
|
10 |
+
- 4
|
11 |
+
- 8
|
12 |
+
motion_module_mid_block: true
|
13 |
+
motion_module_decoder_only: false
|
14 |
+
motion_module_type: Vanilla
|
15 |
+
motion_module_kwargs:
|
16 |
+
num_attention_heads: 8
|
17 |
+
num_transformer_block: 1
|
18 |
+
attention_block_types:
|
19 |
+
- Temporal_Self
|
20 |
+
- Temporal_Self
|
21 |
+
temporal_position_encoding: true
|
22 |
+
temporal_position_encoding_max_len: 32
|
23 |
+
temporal_attention_dim_div: 1
|
24 |
+
audio_attention_dim: 768
|
25 |
+
stack_enable_blocks_name:
|
26 |
+
- "up"
|
27 |
+
- "down"
|
28 |
+
- "mid"
|
29 |
+
stack_enable_blocks_depth: [0,1,2,3]
|
30 |
+
|
31 |
+
enable_zero_snr: true
|
32 |
+
|
33 |
+
noise_scheduler_kwargs:
|
34 |
+
beta_start: 0.00085
|
35 |
+
beta_end: 0.012
|
36 |
+
beta_schedule: "linear"
|
37 |
+
clip_sample: false
|
38 |
+
steps_offset: 1
|
39 |
+
### Zero-SNR params
|
40 |
+
prediction_type: "v_prediction"
|
41 |
+
rescale_betas_zero_snr: True
|
42 |
+
timestep_spacing: "trailing"
|
43 |
+
|
44 |
+
sampler: DDIM
|
data/inference.json
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"video_path": "",
|
4 |
+
"mask_path": "",
|
5 |
+
"sep_mask_border": "",
|
6 |
+
"sep_mask_face": "",
|
7 |
+
"sep_mask_lip": "",
|
8 |
+
"face_emb_path": "",
|
9 |
+
"audio_path": "",
|
10 |
+
"vocals_emb_base_all": ""
|
11 |
+
}
|
12 |
+
]
|
joyhallo/__init__.py
ADDED
File without changes
|
joyhallo/animate/__init__.py
ADDED
File without changes
|
joyhallo/animate/face_animate.py
ADDED
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This module is responsible for animating faces in videos using a combination of deep learning techniques.
|
3 |
+
It provides a pipeline for generating face animations by processing video frames and extracting face features.
|
4 |
+
The module utilizes various schedulers and utilities for efficient face animation and supports different types
|
5 |
+
of latents for more control over the animation process.
|
6 |
+
|
7 |
+
Functions and Classes:
|
8 |
+
- FaceAnimatePipeline: A class that extends the DiffusionPipeline class from the diffusers library to handle face animation tasks.
|
9 |
+
- __init__: Initializes the pipeline with the necessary components (VAE, UNets, face locator, etc.).
|
10 |
+
- prepare_latents: Generates or loads latents for the animation process, scaling them according to the scheduler's requirements.
|
11 |
+
- prepare_extra_step_kwargs: Prepares extra keyword arguments for the scheduler step, ensuring compatibility with different schedulers.
|
12 |
+
- decode_latents: Decodes the latents into video frames, ready for animation.
|
13 |
+
|
14 |
+
Usage:
|
15 |
+
- Import the necessary packages and classes.
|
16 |
+
- Create a FaceAnimatePipeline instance with the required components.
|
17 |
+
- Prepare the latents for the animation process.
|
18 |
+
- Use the pipeline to generate the animated video.
|
19 |
+
|
20 |
+
Note:
|
21 |
+
- This module is designed to work with the diffusers library, which provides the underlying framework for face animation using deep learning.
|
22 |
+
- The module is intended for research and development purposes, and further optimization and customization may be required for specific use cases.
|
23 |
+
"""
|
24 |
+
|
25 |
+
import inspect
|
26 |
+
from dataclasses import dataclass
|
27 |
+
from typing import Callable, List, Optional, Union
|
28 |
+
|
29 |
+
import numpy as np
|
30 |
+
import torch
|
31 |
+
from diffusers import (DDIMScheduler, DiffusionPipeline,
|
32 |
+
DPMSolverMultistepScheduler,
|
33 |
+
EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
|
34 |
+
LMSDiscreteScheduler, PNDMScheduler)
|
35 |
+
from diffusers.image_processor import VaeImageProcessor
|
36 |
+
from diffusers.utils import BaseOutput
|
37 |
+
from diffusers.utils.torch_utils import randn_tensor
|
38 |
+
from einops import rearrange, repeat
|
39 |
+
from tqdm import tqdm
|
40 |
+
|
41 |
+
from joyhallo.models.mutual_self_attention import ReferenceAttentionControl
|
42 |
+
|
43 |
+
|
44 |
+
@dataclass
|
45 |
+
class FaceAnimatePipelineOutput(BaseOutput):
|
46 |
+
"""
|
47 |
+
FaceAnimatePipelineOutput is a custom class that inherits from BaseOutput and represents the output of the FaceAnimatePipeline.
|
48 |
+
|
49 |
+
Attributes:
|
50 |
+
videos (Union[torch.Tensor, np.ndarray]): A tensor or numpy array containing the generated video frames.
|
51 |
+
|
52 |
+
Methods:
|
53 |
+
__init__(self, videos: Union[torch.Tensor, np.ndarray]): Initializes the FaceAnimatePipelineOutput object with the generated video frames.
|
54 |
+
"""
|
55 |
+
videos: Union[torch.Tensor, np.ndarray]
|
56 |
+
|
57 |
+
class FaceAnimatePipeline(DiffusionPipeline):
|
58 |
+
"""
|
59 |
+
FaceAnimatePipeline is a custom DiffusionPipeline for animating faces.
|
60 |
+
|
61 |
+
It inherits from the DiffusionPipeline class and is used to animate faces by
|
62 |
+
utilizing a variational autoencoder (VAE), a reference UNet, a denoising UNet,
|
63 |
+
a face locator, and an image processor. The pipeline is responsible for generating
|
64 |
+
and animating face latents, and decoding the latents to produce the final video output.
|
65 |
+
|
66 |
+
Attributes:
|
67 |
+
vae (VaeImageProcessor): Variational autoencoder for processing images.
|
68 |
+
reference_unet (nn.Module): Reference UNet for mutual self-attention.
|
69 |
+
denoising_unet (nn.Module): Denoising UNet for image denoising.
|
70 |
+
face_locator (nn.Module): Face locator for detecting and cropping faces.
|
71 |
+
image_proj (nn.Module): Image projector for processing images.
|
72 |
+
scheduler (Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler,
|
73 |
+
EulerDiscreteScheduler, EulerAncestralDiscreteScheduler,
|
74 |
+
DPMSolverMultistepScheduler]): Diffusion scheduler for
|
75 |
+
controlling the noise level.
|
76 |
+
|
77 |
+
Methods:
|
78 |
+
__init__(self, vae, reference_unet, denoising_unet, face_locator,
|
79 |
+
image_proj, scheduler): Initializes the FaceAnimatePipeline
|
80 |
+
with the given components and scheduler.
|
81 |
+
prepare_latents(self, batch_size, num_channels_latents, width, height,
|
82 |
+
video_length, dtype, device, generator=None, latents=None):
|
83 |
+
Prepares the initial latents for video generation.
|
84 |
+
prepare_extra_step_kwargs(self, generator, eta): Prepares extra keyword
|
85 |
+
arguments for the scheduler step.
|
86 |
+
decode_latents(self, latents): Decodes the latents to produce the final
|
87 |
+
video output.
|
88 |
+
"""
|
89 |
+
def __init__(
|
90 |
+
self,
|
91 |
+
vae,
|
92 |
+
reference_unet,
|
93 |
+
denoising_unet,
|
94 |
+
face_locator,
|
95 |
+
image_proj,
|
96 |
+
scheduler: Union[
|
97 |
+
DDIMScheduler,
|
98 |
+
PNDMScheduler,
|
99 |
+
LMSDiscreteScheduler,
|
100 |
+
EulerDiscreteScheduler,
|
101 |
+
EulerAncestralDiscreteScheduler,
|
102 |
+
DPMSolverMultistepScheduler,
|
103 |
+
],
|
104 |
+
) -> None:
|
105 |
+
super().__init__()
|
106 |
+
|
107 |
+
self.register_modules(
|
108 |
+
vae=vae,
|
109 |
+
reference_unet=reference_unet,
|
110 |
+
denoising_unet=denoising_unet,
|
111 |
+
face_locator=face_locator,
|
112 |
+
scheduler=scheduler,
|
113 |
+
image_proj=image_proj,
|
114 |
+
)
|
115 |
+
|
116 |
+
self.vae_scale_factor: int = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
117 |
+
|
118 |
+
self.ref_image_processor = VaeImageProcessor(
|
119 |
+
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True,
|
120 |
+
)
|
121 |
+
|
122 |
+
@property
|
123 |
+
def _execution_device(self):
|
124 |
+
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
125 |
+
return self.device
|
126 |
+
for module in self.unet.modules():
|
127 |
+
if (
|
128 |
+
hasattr(module, "_hf_hook")
|
129 |
+
and hasattr(module._hf_hook, "execution_device")
|
130 |
+
and module._hf_hook.execution_device is not None
|
131 |
+
):
|
132 |
+
return torch.device(module._hf_hook.execution_device)
|
133 |
+
return self.device
|
134 |
+
|
135 |
+
def prepare_latents(
|
136 |
+
self,
|
137 |
+
batch_size: int, # Number of videos to generate in parallel
|
138 |
+
num_channels_latents: int, # Number of channels in the latents
|
139 |
+
width: int, # Width of the video frame
|
140 |
+
height: int, # Height of the video frame
|
141 |
+
video_length: int, # Length of the video in frames
|
142 |
+
dtype: torch.dtype, # Data type of the latents
|
143 |
+
device: torch.device, # Device to store the latents on
|
144 |
+
generator: Optional[torch.Generator] = None, # Random number generator for reproducibility
|
145 |
+
latents: Optional[torch.Tensor] = None # Pre-generated latents (optional)
|
146 |
+
):
|
147 |
+
"""
|
148 |
+
Prepares the initial latents for video generation.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
batch_size (int): Number of videos to generate in parallel.
|
152 |
+
num_channels_latents (int): Number of channels in the latents.
|
153 |
+
width (int): Width of the video frame.
|
154 |
+
height (int): Height of the video frame.
|
155 |
+
video_length (int): Length of the video in frames.
|
156 |
+
dtype (torch.dtype): Data type of the latents.
|
157 |
+
device (torch.device): Device to store the latents on.
|
158 |
+
generator (Optional[torch.Generator]): Random number generator for reproducibility.
|
159 |
+
latents (Optional[torch.Tensor]): Pre-generated latents (optional).
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
latents (torch.Tensor): Tensor of shape (batch_size, num_channels_latents, width, height)
|
163 |
+
containing the initial latents for video generation.
|
164 |
+
"""
|
165 |
+
shape = (
|
166 |
+
batch_size,
|
167 |
+
num_channels_latents,
|
168 |
+
video_length,
|
169 |
+
height // self.vae_scale_factor,
|
170 |
+
width // self.vae_scale_factor,
|
171 |
+
)
|
172 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
173 |
+
raise ValueError(
|
174 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
175 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
176 |
+
)
|
177 |
+
|
178 |
+
if latents is None:
|
179 |
+
latents = randn_tensor(
|
180 |
+
shape, generator=generator, device=device, dtype=dtype
|
181 |
+
)
|
182 |
+
else:
|
183 |
+
latents = latents.to(device)
|
184 |
+
|
185 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
186 |
+
latents = latents * self.scheduler.init_noise_sigma
|
187 |
+
return latents
|
188 |
+
|
189 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
190 |
+
"""
|
191 |
+
Prepares extra keyword arguments for the scheduler step.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
generator (Optional[torch.Generator]): Random number generator for reproducibility.
|
195 |
+
eta (float): The eta (η) parameter used with the DDIMScheduler.
|
196 |
+
It corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502) and should be between [0, 1].
|
197 |
+
|
198 |
+
Returns:
|
199 |
+
dict: A dictionary containing the extra keyword arguments for the scheduler step.
|
200 |
+
"""
|
201 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
202 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
203 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
204 |
+
# and should be between [0, 1]
|
205 |
+
|
206 |
+
accepts_eta = "eta" in set(
|
207 |
+
inspect.signature(self.scheduler.step).parameters.keys()
|
208 |
+
)
|
209 |
+
extra_step_kwargs = {}
|
210 |
+
if accepts_eta:
|
211 |
+
extra_step_kwargs["eta"] = eta
|
212 |
+
|
213 |
+
# check if the scheduler accepts generator
|
214 |
+
accepts_generator = "generator" in set(
|
215 |
+
inspect.signature(self.scheduler.step).parameters.keys()
|
216 |
+
)
|
217 |
+
if accepts_generator:
|
218 |
+
extra_step_kwargs["generator"] = generator
|
219 |
+
return extra_step_kwargs
|
220 |
+
|
221 |
+
def decode_latents(self, latents):
|
222 |
+
"""
|
223 |
+
Decode the latents to produce a video.
|
224 |
+
|
225 |
+
Parameters:
|
226 |
+
latents (torch.Tensor): The latents to be decoded.
|
227 |
+
|
228 |
+
Returns:
|
229 |
+
video (torch.Tensor): The decoded video.
|
230 |
+
video_length (int): The length of the video in frames.
|
231 |
+
"""
|
232 |
+
video_length = latents.shape[2]
|
233 |
+
latents = 1 / 0.18215 * latents
|
234 |
+
latents = rearrange(latents, "b c f h w -> (b f) c h w")
|
235 |
+
# video = self.vae.decode(latents).sample
|
236 |
+
video = []
|
237 |
+
for frame_idx in tqdm(range(latents.shape[0])):
|
238 |
+
video.append(self.vae.decode(
|
239 |
+
latents[frame_idx: frame_idx + 1]).sample)
|
240 |
+
video = torch.cat(video)
|
241 |
+
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
|
242 |
+
video = (video / 2 + 0.5).clamp(0, 1)
|
243 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
244 |
+
video = video.cpu().float().numpy()
|
245 |
+
return video
|
246 |
+
|
247 |
+
|
248 |
+
@torch.no_grad()
|
249 |
+
def __call__(
|
250 |
+
self,
|
251 |
+
ref_image,
|
252 |
+
face_emb,
|
253 |
+
audio_tensor,
|
254 |
+
face_mask,
|
255 |
+
pixel_values_full_mask,
|
256 |
+
pixel_values_face_mask,
|
257 |
+
pixel_values_lip_mask,
|
258 |
+
width,
|
259 |
+
height,
|
260 |
+
video_length,
|
261 |
+
num_inference_steps,
|
262 |
+
guidance_scale,
|
263 |
+
num_images_per_prompt=1,
|
264 |
+
eta: float = 0.0,
|
265 |
+
motion_scale: Optional[List[torch.Tensor]] = None,
|
266 |
+
generator: Optional[Union[torch.Generator,
|
267 |
+
List[torch.Generator]]] = None,
|
268 |
+
output_type: Optional[str] = "tensor",
|
269 |
+
return_dict: bool = True,
|
270 |
+
callback: Optional[Callable[[
|
271 |
+
int, int, torch.FloatTensor], None]] = None,
|
272 |
+
callback_steps: Optional[int] = 1,
|
273 |
+
**kwargs,
|
274 |
+
):
|
275 |
+
# Default height and width to unet
|
276 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
277 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
278 |
+
|
279 |
+
device = self._execution_device
|
280 |
+
|
281 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
282 |
+
|
283 |
+
# Prepare timesteps
|
284 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
285 |
+
timesteps = self.scheduler.timesteps
|
286 |
+
|
287 |
+
batch_size = 1
|
288 |
+
|
289 |
+
# prepare clip image embeddings
|
290 |
+
clip_image_embeds = face_emb
|
291 |
+
clip_image_embeds = clip_image_embeds.to(self.image_proj.device, self.image_proj.dtype)
|
292 |
+
|
293 |
+
encoder_hidden_states = self.image_proj(clip_image_embeds)
|
294 |
+
uncond_encoder_hidden_states = self.image_proj(torch.zeros_like(clip_image_embeds))
|
295 |
+
|
296 |
+
if do_classifier_free_guidance:
|
297 |
+
encoder_hidden_states = torch.cat([uncond_encoder_hidden_states, encoder_hidden_states], dim=0)
|
298 |
+
|
299 |
+
reference_control_writer = ReferenceAttentionControl(
|
300 |
+
self.reference_unet,
|
301 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
302 |
+
mode="write",
|
303 |
+
batch_size=batch_size,
|
304 |
+
fusion_blocks="full",
|
305 |
+
)
|
306 |
+
reference_control_reader = ReferenceAttentionControl(
|
307 |
+
self.denoising_unet,
|
308 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
309 |
+
mode="read",
|
310 |
+
batch_size=batch_size,
|
311 |
+
fusion_blocks="full",
|
312 |
+
)
|
313 |
+
|
314 |
+
num_channels_latents = self.denoising_unet.in_channels
|
315 |
+
|
316 |
+
latents = self.prepare_latents(
|
317 |
+
batch_size * num_images_per_prompt,
|
318 |
+
num_channels_latents,
|
319 |
+
width,
|
320 |
+
height,
|
321 |
+
video_length,
|
322 |
+
clip_image_embeds.dtype,
|
323 |
+
device,
|
324 |
+
generator,
|
325 |
+
)
|
326 |
+
|
327 |
+
# Prepare extra step kwargs.
|
328 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
329 |
+
|
330 |
+
# Prepare ref image latents
|
331 |
+
ref_image_tensor = rearrange(ref_image, "b f c h w -> (b f) c h w")
|
332 |
+
ref_image_tensor = self.ref_image_processor.preprocess(ref_image_tensor, height=height, width=width) # (bs, c, width, height)
|
333 |
+
ref_image_tensor = ref_image_tensor.to(dtype=self.vae.dtype, device=self.vae.device)
|
334 |
+
ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
|
335 |
+
ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
|
336 |
+
|
337 |
+
|
338 |
+
face_mask = face_mask.unsqueeze(1).to(dtype=self.face_locator.dtype, device=self.face_locator.device) # (bs, f, c, H, W)
|
339 |
+
face_mask = repeat(face_mask, "b f c h w -> b (repeat f) c h w", repeat=video_length)
|
340 |
+
face_mask = face_mask.transpose(1, 2) # (bs, c, f, H, W)
|
341 |
+
face_mask = self.face_locator(face_mask)
|
342 |
+
face_mask = torch.cat([torch.zeros_like(face_mask), face_mask], dim=0) if do_classifier_free_guidance else face_mask
|
343 |
+
|
344 |
+
pixel_values_full_mask = (
|
345 |
+
[torch.cat([mask] * 2) for mask in pixel_values_full_mask]
|
346 |
+
if do_classifier_free_guidance
|
347 |
+
else pixel_values_full_mask
|
348 |
+
)
|
349 |
+
pixel_values_face_mask = (
|
350 |
+
[torch.cat([mask] * 2) for mask in pixel_values_face_mask]
|
351 |
+
if do_classifier_free_guidance
|
352 |
+
else pixel_values_face_mask
|
353 |
+
)
|
354 |
+
pixel_values_lip_mask = (
|
355 |
+
[torch.cat([mask] * 2) for mask in pixel_values_lip_mask]
|
356 |
+
if do_classifier_free_guidance
|
357 |
+
else pixel_values_lip_mask
|
358 |
+
)
|
359 |
+
pixel_values_face_mask_ = []
|
360 |
+
for mask in pixel_values_face_mask:
|
361 |
+
pixel_values_face_mask_.append(
|
362 |
+
mask.to(device=self.denoising_unet.device, dtype=self.denoising_unet.dtype))
|
363 |
+
pixel_values_face_mask = pixel_values_face_mask_
|
364 |
+
pixel_values_lip_mask_ = []
|
365 |
+
for mask in pixel_values_lip_mask:
|
366 |
+
pixel_values_lip_mask_.append(
|
367 |
+
mask.to(device=self.denoising_unet.device, dtype=self.denoising_unet.dtype))
|
368 |
+
pixel_values_lip_mask = pixel_values_lip_mask_
|
369 |
+
pixel_values_full_mask_ = []
|
370 |
+
for mask in pixel_values_full_mask:
|
371 |
+
pixel_values_full_mask_.append(
|
372 |
+
mask.to(device=self.denoising_unet.device, dtype=self.denoising_unet.dtype))
|
373 |
+
pixel_values_full_mask = pixel_values_full_mask_
|
374 |
+
|
375 |
+
|
376 |
+
uncond_audio_tensor = torch.zeros_like(audio_tensor)
|
377 |
+
audio_tensor = torch.cat([uncond_audio_tensor, audio_tensor], dim=0)
|
378 |
+
audio_tensor = audio_tensor.to(dtype=self.denoising_unet.dtype, device=self.denoising_unet.device)
|
379 |
+
|
380 |
+
# denoising loop
|
381 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
382 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
383 |
+
for i, t in enumerate(timesteps):
|
384 |
+
# Forward reference image
|
385 |
+
if i == 0:
|
386 |
+
self.reference_unet(
|
387 |
+
ref_image_latents.repeat(
|
388 |
+
(2 if do_classifier_free_guidance else 1), 1, 1, 1
|
389 |
+
),
|
390 |
+
torch.zeros_like(t),
|
391 |
+
encoder_hidden_states=encoder_hidden_states,
|
392 |
+
return_dict=False,
|
393 |
+
)
|
394 |
+
reference_control_reader.update(reference_control_writer)
|
395 |
+
|
396 |
+
# expand the latents if we are doing classifier free guidance
|
397 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
398 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
399 |
+
|
400 |
+
noise_pred = self.denoising_unet(
|
401 |
+
latent_model_input,
|
402 |
+
t,
|
403 |
+
encoder_hidden_states=encoder_hidden_states,
|
404 |
+
mask_cond_fea=face_mask,
|
405 |
+
full_mask=pixel_values_full_mask,
|
406 |
+
face_mask=pixel_values_face_mask,
|
407 |
+
lip_mask=pixel_values_lip_mask,
|
408 |
+
audio_embedding=audio_tensor,
|
409 |
+
motion_scale=motion_scale,
|
410 |
+
return_dict=False,
|
411 |
+
)[0]
|
412 |
+
|
413 |
+
# perform guidance
|
414 |
+
if do_classifier_free_guidance:
|
415 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
416 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
417 |
+
|
418 |
+
# compute the previous noisy sample x_t -> x_t-1
|
419 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
420 |
+
|
421 |
+
# call the callback, if provided
|
422 |
+
if i == len(timesteps) - 1 or (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
|
423 |
+
progress_bar.update()
|
424 |
+
if callback is not None and i % callback_steps == 0:
|
425 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
426 |
+
callback(step_idx, t, latents)
|
427 |
+
|
428 |
+
reference_control_reader.clear()
|
429 |
+
reference_control_writer.clear()
|
430 |
+
|
431 |
+
# Post-processing
|
432 |
+
images = self.decode_latents(latents) # (b, c, f, h, w)
|
433 |
+
|
434 |
+
# Convert to tensor
|
435 |
+
if output_type == "tensor":
|
436 |
+
images = torch.from_numpy(images)
|
437 |
+
|
438 |
+
if not return_dict:
|
439 |
+
return images
|
440 |
+
|
441 |
+
return FaceAnimatePipelineOutput(videos=images)
|
joyhallo/animate/face_animate_static.py
ADDED
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This module is responsible for handling the animation of faces using a combination of deep learning models and image processing techniques.
|
3 |
+
It provides a pipeline to generate realistic face animations by incorporating user-provided conditions such as facial expressions and environments.
|
4 |
+
The module utilizes various schedulers and utilities to optimize the animation process and ensure efficient performance.
|
5 |
+
|
6 |
+
Functions and Classes:
|
7 |
+
- StaticPipelineOutput: A class that represents the output of the animation pipeline, c
|
8 |
+
ontaining properties and methods related to the generated images.
|
9 |
+
- prepare_latents: A function that prepares the initial noise for the animation process,
|
10 |
+
scaling it according to the scheduler's requirements.
|
11 |
+
- prepare_condition: A function that processes the user-provided conditions
|
12 |
+
(e.g., facial expressions) and prepares them for use in the animation pipeline.
|
13 |
+
- decode_latents: A function that decodes the latent representations of the face animations into
|
14 |
+
their corresponding image formats.
|
15 |
+
- prepare_extra_step_kwargs: A function that prepares additional parameters for each step of
|
16 |
+
the animation process, such as the generator and eta values.
|
17 |
+
|
18 |
+
Dependencies:
|
19 |
+
- numpy: A library for numerical computing.
|
20 |
+
- torch: A machine learning library based on PyTorch.
|
21 |
+
- diffusers: A library for image-to-image diffusion models.
|
22 |
+
- transformers: A library for pre-trained transformer models.
|
23 |
+
|
24 |
+
Usage:
|
25 |
+
- To create an instance of the animation pipeline, provide the necessary components such as
|
26 |
+
the VAE, reference UNET, denoising UNET, face locator, and image processor.
|
27 |
+
- Use the pipeline's methods to prepare the latents, conditions, and extra step arguments as
|
28 |
+
required for the animation process.
|
29 |
+
- Generate the face animations by decoding the latents and processing the conditions.
|
30 |
+
|
31 |
+
Note:
|
32 |
+
- The module is designed to work with the diffusers library, which is based on
|
33 |
+
the paper "Diffusion Models for Image-to-Image Translation" (https://arxiv.org/abs/2102.02765).
|
34 |
+
- The face animations generated by this module should be used for entertainment purposes
|
35 |
+
only and should respect the rights and privacy of the individuals involved.
|
36 |
+
"""
|
37 |
+
import inspect
|
38 |
+
from dataclasses import dataclass
|
39 |
+
from typing import Callable, List, Optional, Union
|
40 |
+
|
41 |
+
import numpy as np
|
42 |
+
import torch
|
43 |
+
from diffusers import DiffusionPipeline
|
44 |
+
from diffusers.image_processor import VaeImageProcessor
|
45 |
+
from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler,
|
46 |
+
EulerAncestralDiscreteScheduler,
|
47 |
+
EulerDiscreteScheduler, LMSDiscreteScheduler,
|
48 |
+
PNDMScheduler)
|
49 |
+
from diffusers.utils import BaseOutput, is_accelerate_available
|
50 |
+
from diffusers.utils.torch_utils import randn_tensor
|
51 |
+
from einops import rearrange
|
52 |
+
from tqdm import tqdm
|
53 |
+
from transformers import CLIPImageProcessor
|
54 |
+
|
55 |
+
from joyhallo.models.mutual_self_attention import ReferenceAttentionControl
|
56 |
+
|
57 |
+
if is_accelerate_available():
|
58 |
+
from accelerate import cpu_offload
|
59 |
+
else:
|
60 |
+
raise ImportError("Please install accelerate via `pip install accelerate`")
|
61 |
+
|
62 |
+
|
63 |
+
@dataclass
|
64 |
+
class StaticPipelineOutput(BaseOutput):
|
65 |
+
"""
|
66 |
+
StaticPipelineOutput is a class that represents the output of the static pipeline.
|
67 |
+
It contains the images generated by the pipeline as a union of torch.Tensor and np.ndarray.
|
68 |
+
|
69 |
+
Attributes:
|
70 |
+
images (Union[torch.Tensor, np.ndarray]): The generated images.
|
71 |
+
"""
|
72 |
+
images: Union[torch.Tensor, np.ndarray]
|
73 |
+
|
74 |
+
|
75 |
+
class StaticPipeline(DiffusionPipeline):
|
76 |
+
"""
|
77 |
+
StaticPipelineOutput is a class that represents the output of the static pipeline.
|
78 |
+
It contains the images generated by the pipeline as a union of torch.Tensor and np.ndarray.
|
79 |
+
|
80 |
+
Attributes:
|
81 |
+
images (Union[torch.Tensor, np.ndarray]): The generated images.
|
82 |
+
"""
|
83 |
+
_optional_components = []
|
84 |
+
|
85 |
+
def __init__(
|
86 |
+
self,
|
87 |
+
vae,
|
88 |
+
reference_unet,
|
89 |
+
denoising_unet,
|
90 |
+
face_locator,
|
91 |
+
imageproj,
|
92 |
+
scheduler: Union[
|
93 |
+
DDIMScheduler,
|
94 |
+
PNDMScheduler,
|
95 |
+
LMSDiscreteScheduler,
|
96 |
+
EulerDiscreteScheduler,
|
97 |
+
EulerAncestralDiscreteScheduler,
|
98 |
+
DPMSolverMultistepScheduler,
|
99 |
+
],
|
100 |
+
):
|
101 |
+
super().__init__()
|
102 |
+
|
103 |
+
self.register_modules(
|
104 |
+
vae=vae,
|
105 |
+
reference_unet=reference_unet,
|
106 |
+
denoising_unet=denoising_unet,
|
107 |
+
face_locator=face_locator,
|
108 |
+
scheduler=scheduler,
|
109 |
+
imageproj=imageproj,
|
110 |
+
)
|
111 |
+
self.vae_scale_factor = 2 ** (
|
112 |
+
len(self.vae.config.block_out_channels) - 1)
|
113 |
+
self.clip_image_processor = CLIPImageProcessor()
|
114 |
+
self.ref_image_processor = VaeImageProcessor(
|
115 |
+
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
|
116 |
+
)
|
117 |
+
self.cond_image_processor = VaeImageProcessor(
|
118 |
+
vae_scale_factor=self.vae_scale_factor,
|
119 |
+
do_convert_rgb=True,
|
120 |
+
do_normalize=False,
|
121 |
+
)
|
122 |
+
|
123 |
+
def enable_vae_slicing(self):
|
124 |
+
"""
|
125 |
+
Enable VAE slicing.
|
126 |
+
|
127 |
+
This method enables slicing for the VAE model, which can help improve the performance of decoding latents when working with large images.
|
128 |
+
"""
|
129 |
+
self.vae.enable_slicing()
|
130 |
+
|
131 |
+
def disable_vae_slicing(self):
|
132 |
+
"""
|
133 |
+
Disable vae slicing.
|
134 |
+
|
135 |
+
This function disables the vae slicing for the StaticPipeline object.
|
136 |
+
It calls the `disable_slicing()` method of the vae model.
|
137 |
+
This is useful when you want to use the entire vae model for decoding latents
|
138 |
+
instead of slicing it for better performance.
|
139 |
+
"""
|
140 |
+
self.vae.disable_slicing()
|
141 |
+
|
142 |
+
def enable_sequential_cpu_offload(self, gpu_id=0):
|
143 |
+
"""
|
144 |
+
Offloads selected models to the GPU for increased performance.
|
145 |
+
|
146 |
+
Args:
|
147 |
+
gpu_id (int, optional): The ID of the GPU to offload models to. Defaults to 0.
|
148 |
+
"""
|
149 |
+
device = torch.device(f"cuda:{gpu_id}")
|
150 |
+
|
151 |
+
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
152 |
+
if cpu_offloaded_model is not None:
|
153 |
+
cpu_offload(cpu_offloaded_model, device)
|
154 |
+
|
155 |
+
@property
|
156 |
+
def _execution_device(self):
|
157 |
+
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
158 |
+
return self.device
|
159 |
+
for module in self.unet.modules():
|
160 |
+
if (
|
161 |
+
hasattr(module, "_hf_hook")
|
162 |
+
and hasattr(module._hf_hook, "execution_device")
|
163 |
+
and module._hf_hook.execution_device is not None
|
164 |
+
):
|
165 |
+
return torch.device(module._hf_hook.execution_device)
|
166 |
+
return self.device
|
167 |
+
|
168 |
+
def decode_latents(self, latents):
|
169 |
+
"""
|
170 |
+
Decode the given latents to video frames.
|
171 |
+
|
172 |
+
Parameters:
|
173 |
+
latents (torch.Tensor): The latents to be decoded. Shape: (batch_size, num_channels_latents, video_length, height, width).
|
174 |
+
|
175 |
+
Returns:
|
176 |
+
video (torch.Tensor): The decoded video frames. Shape: (batch_size, num_channels_latents, video_length, height, width).
|
177 |
+
"""
|
178 |
+
video_length = latents.shape[2]
|
179 |
+
latents = 1 / 0.18215 * latents
|
180 |
+
latents = rearrange(latents, "b c f h w -> (b f) c h w")
|
181 |
+
# video = self.vae.decode(latents).sample
|
182 |
+
video = []
|
183 |
+
for frame_idx in tqdm(range(latents.shape[0])):
|
184 |
+
video.append(self.vae.decode(
|
185 |
+
latents[frame_idx: frame_idx + 1]).sample)
|
186 |
+
video = torch.cat(video)
|
187 |
+
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
|
188 |
+
video = (video / 2 + 0.5).clamp(0, 1)
|
189 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
190 |
+
video = video.cpu().float().numpy()
|
191 |
+
return video
|
192 |
+
|
193 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
194 |
+
"""
|
195 |
+
Prepare extra keyword arguments for the scheduler step.
|
196 |
+
|
197 |
+
Since not all schedulers have the same signature, this function helps to create a consistent interface for the scheduler.
|
198 |
+
|
199 |
+
Args:
|
200 |
+
generator (Optional[torch.Generator]): A random number generator for reproducibility.
|
201 |
+
eta (float): The eta parameter used with the DDIMScheduler. It should be between 0 and 1.
|
202 |
+
|
203 |
+
Returns:
|
204 |
+
dict: A dictionary containing the extra keyword arguments for the scheduler step.
|
205 |
+
"""
|
206 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
207 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
208 |
+
# and should be between [0, 1]
|
209 |
+
|
210 |
+
accepts_eta = "eta" in set(
|
211 |
+
inspect.signature(self.scheduler.step).parameters.keys()
|
212 |
+
)
|
213 |
+
extra_step_kwargs = {}
|
214 |
+
if accepts_eta:
|
215 |
+
extra_step_kwargs["eta"] = eta
|
216 |
+
|
217 |
+
# check if the scheduler accepts generator
|
218 |
+
accepts_generator = "generator" in set(
|
219 |
+
inspect.signature(self.scheduler.step).parameters.keys()
|
220 |
+
)
|
221 |
+
if accepts_generator:
|
222 |
+
extra_step_kwargs["generator"] = generator
|
223 |
+
return extra_step_kwargs
|
224 |
+
|
225 |
+
def prepare_latents(
|
226 |
+
self,
|
227 |
+
batch_size,
|
228 |
+
num_channels_latents,
|
229 |
+
width,
|
230 |
+
height,
|
231 |
+
dtype,
|
232 |
+
device,
|
233 |
+
generator,
|
234 |
+
latents=None,
|
235 |
+
):
|
236 |
+
"""
|
237 |
+
Prepares the initial latents for the diffusion pipeline.
|
238 |
+
|
239 |
+
Args:
|
240 |
+
batch_size (int): The number of images to generate in one forward pass.
|
241 |
+
num_channels_latents (int): The number of channels in the latents tensor.
|
242 |
+
width (int): The width of the latents tensor.
|
243 |
+
height (int): The height of the latents tensor.
|
244 |
+
dtype (torch.dtype): The data type of the latents tensor.
|
245 |
+
device (torch.device): The device to place the latents tensor on.
|
246 |
+
generator (Optional[torch.Generator], optional): A random number generator
|
247 |
+
for reproducibility. Defaults to None.
|
248 |
+
latents (Optional[torch.Tensor], optional): Pre-computed latents to use as
|
249 |
+
initial conditions for the diffusion process. Defaults to None.
|
250 |
+
|
251 |
+
Returns:
|
252 |
+
torch.Tensor: The prepared latents tensor.
|
253 |
+
"""
|
254 |
+
shape = (
|
255 |
+
batch_size,
|
256 |
+
num_channels_latents,
|
257 |
+
height // self.vae_scale_factor,
|
258 |
+
width // self.vae_scale_factor,
|
259 |
+
)
|
260 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
261 |
+
raise ValueError(
|
262 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
263 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
264 |
+
)
|
265 |
+
|
266 |
+
if latents is None:
|
267 |
+
latents = randn_tensor(
|
268 |
+
shape, generator=generator, device=device, dtype=dtype
|
269 |
+
)
|
270 |
+
else:
|
271 |
+
latents = latents.to(device)
|
272 |
+
|
273 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
274 |
+
latents = latents * self.scheduler.init_noise_sigma
|
275 |
+
return latents
|
276 |
+
|
277 |
+
def prepare_condition(
|
278 |
+
self,
|
279 |
+
cond_image,
|
280 |
+
width,
|
281 |
+
height,
|
282 |
+
device,
|
283 |
+
dtype,
|
284 |
+
do_classififer_free_guidance=False,
|
285 |
+
):
|
286 |
+
"""
|
287 |
+
Prepares the condition for the face animation pipeline.
|
288 |
+
|
289 |
+
Args:
|
290 |
+
cond_image (torch.Tensor): The conditional image tensor.
|
291 |
+
width (int): The width of the output image.
|
292 |
+
height (int): The height of the output image.
|
293 |
+
device (torch.device): The device to run the pipeline on.
|
294 |
+
dtype (torch.dtype): The data type of the tensor.
|
295 |
+
do_classififer_free_guidance (bool, optional): Whether to use classifier-free guidance or not. Defaults to False.
|
296 |
+
|
297 |
+
Returns:
|
298 |
+
Tuple[torch.Tensor, torch.Tensor]: A tuple of processed condition and mask tensors.
|
299 |
+
"""
|
300 |
+
image = self.cond_image_processor.preprocess(
|
301 |
+
cond_image, height=height, width=width
|
302 |
+
).to(dtype=torch.float32)
|
303 |
+
|
304 |
+
image = image.to(device=device, dtype=dtype)
|
305 |
+
|
306 |
+
if do_classififer_free_guidance:
|
307 |
+
image = torch.cat([image] * 2)
|
308 |
+
|
309 |
+
return image
|
310 |
+
|
311 |
+
@torch.no_grad()
|
312 |
+
def __call__(
|
313 |
+
self,
|
314 |
+
ref_image,
|
315 |
+
face_mask,
|
316 |
+
width,
|
317 |
+
height,
|
318 |
+
num_inference_steps,
|
319 |
+
guidance_scale,
|
320 |
+
face_embedding,
|
321 |
+
num_images_per_prompt=1,
|
322 |
+
eta: float = 0.0,
|
323 |
+
generator: Optional[Union[torch.Generator,
|
324 |
+
List[torch.Generator]]] = None,
|
325 |
+
output_type: Optional[str] = "tensor",
|
326 |
+
return_dict: bool = True,
|
327 |
+
callback: Optional[Callable[[
|
328 |
+
int, int, torch.FloatTensor], None]] = None,
|
329 |
+
callback_steps: Optional[int] = 1,
|
330 |
+
**kwargs,
|
331 |
+
):
|
332 |
+
# Default height and width to unet
|
333 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
334 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
335 |
+
|
336 |
+
device = self._execution_device
|
337 |
+
|
338 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
339 |
+
|
340 |
+
# Prepare timesteps
|
341 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
342 |
+
timesteps = self.scheduler.timesteps
|
343 |
+
|
344 |
+
batch_size = 1
|
345 |
+
|
346 |
+
image_prompt_embeds = self.imageproj(face_embedding)
|
347 |
+
uncond_image_prompt_embeds = self.imageproj(
|
348 |
+
torch.zeros_like(face_embedding))
|
349 |
+
|
350 |
+
if do_classifier_free_guidance:
|
351 |
+
image_prompt_embeds = torch.cat(
|
352 |
+
[uncond_image_prompt_embeds, image_prompt_embeds], dim=0
|
353 |
+
)
|
354 |
+
|
355 |
+
reference_control_writer = ReferenceAttentionControl(
|
356 |
+
self.reference_unet,
|
357 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
358 |
+
mode="write",
|
359 |
+
batch_size=batch_size,
|
360 |
+
fusion_blocks="full",
|
361 |
+
)
|
362 |
+
reference_control_reader = ReferenceAttentionControl(
|
363 |
+
self.denoising_unet,
|
364 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
365 |
+
mode="read",
|
366 |
+
batch_size=batch_size,
|
367 |
+
fusion_blocks="full",
|
368 |
+
)
|
369 |
+
|
370 |
+
num_channels_latents = self.denoising_unet.in_channels
|
371 |
+
latents = self.prepare_latents(
|
372 |
+
batch_size * num_images_per_prompt,
|
373 |
+
num_channels_latents,
|
374 |
+
width,
|
375 |
+
height,
|
376 |
+
face_embedding.dtype,
|
377 |
+
device,
|
378 |
+
generator,
|
379 |
+
)
|
380 |
+
latents = latents.unsqueeze(2) # (bs, c, 1, h', w')
|
381 |
+
# latents_dtype = latents.dtype
|
382 |
+
|
383 |
+
# Prepare extra step kwargs.
|
384 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
385 |
+
|
386 |
+
# Prepare ref image latents
|
387 |
+
ref_image_tensor = self.ref_image_processor.preprocess(
|
388 |
+
ref_image, height=height, width=width
|
389 |
+
) # (bs, c, width, height)
|
390 |
+
ref_image_tensor = ref_image_tensor.to(
|
391 |
+
dtype=self.vae.dtype, device=self.vae.device
|
392 |
+
)
|
393 |
+
ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
|
394 |
+
ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
|
395 |
+
|
396 |
+
# Prepare face mask image
|
397 |
+
face_mask_tensor = self.cond_image_processor.preprocess(
|
398 |
+
face_mask, height=height, width=width
|
399 |
+
)
|
400 |
+
face_mask_tensor = face_mask_tensor.unsqueeze(2) # (bs, c, 1, h, w)
|
401 |
+
face_mask_tensor = face_mask_tensor.to(
|
402 |
+
device=device, dtype=self.face_locator.dtype
|
403 |
+
)
|
404 |
+
mask_fea = self.face_locator(face_mask_tensor)
|
405 |
+
mask_fea = (
|
406 |
+
torch.cat(
|
407 |
+
[mask_fea] * 2) if do_classifier_free_guidance else mask_fea
|
408 |
+
)
|
409 |
+
|
410 |
+
# denoising loop
|
411 |
+
num_warmup_steps = len(timesteps) - \
|
412 |
+
num_inference_steps * self.scheduler.order
|
413 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
414 |
+
for i, t in enumerate(timesteps):
|
415 |
+
# 1. Forward reference image
|
416 |
+
if i == 0:
|
417 |
+
self.reference_unet(
|
418 |
+
ref_image_latents.repeat(
|
419 |
+
(2 if do_classifier_free_guidance else 1), 1, 1, 1
|
420 |
+
),
|
421 |
+
torch.zeros_like(t),
|
422 |
+
encoder_hidden_states=image_prompt_embeds,
|
423 |
+
return_dict=False,
|
424 |
+
)
|
425 |
+
|
426 |
+
# 2. Update reference unet feature into denosing net
|
427 |
+
reference_control_reader.update(reference_control_writer)
|
428 |
+
|
429 |
+
# 3.1 expand the latents if we are doing classifier free guidance
|
430 |
+
latent_model_input = (
|
431 |
+
torch.cat(
|
432 |
+
[latents] * 2) if do_classifier_free_guidance else latents
|
433 |
+
)
|
434 |
+
latent_model_input = self.scheduler.scale_model_input(
|
435 |
+
latent_model_input, t
|
436 |
+
)
|
437 |
+
|
438 |
+
noise_pred = self.denoising_unet(
|
439 |
+
latent_model_input,
|
440 |
+
t,
|
441 |
+
encoder_hidden_states=image_prompt_embeds,
|
442 |
+
mask_cond_fea=mask_fea,
|
443 |
+
return_dict=False,
|
444 |
+
)[0]
|
445 |
+
|
446 |
+
# perform guidance
|
447 |
+
if do_classifier_free_guidance:
|
448 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
449 |
+
noise_pred = noise_pred_uncond + guidance_scale * (
|
450 |
+
noise_pred_text - noise_pred_uncond
|
451 |
+
)
|
452 |
+
|
453 |
+
# compute the previous noisy sample x_t -> x_t-1
|
454 |
+
latents = self.scheduler.step(
|
455 |
+
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
|
456 |
+
)[0]
|
457 |
+
|
458 |
+
# call the callback, if provided
|
459 |
+
if i == len(timesteps) - 1 or (
|
460 |
+
(i + 1) > num_warmup_steps and (i +
|
461 |
+
1) % self.scheduler.order == 0
|
462 |
+
):
|
463 |
+
progress_bar.update()
|
464 |
+
if callback is not None and i % callback_steps == 0:
|
465 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
466 |
+
callback(step_idx, t, latents)
|
467 |
+
reference_control_reader.clear()
|
468 |
+
reference_control_writer.clear()
|
469 |
+
|
470 |
+
# Post-processing
|
471 |
+
image = self.decode_latents(latents) # (b, c, 1, h, w)
|
472 |
+
|
473 |
+
# Convert to tensor
|
474 |
+
if output_type == "tensor":
|
475 |
+
image = torch.from_numpy(image)
|
476 |
+
|
477 |
+
if not return_dict:
|
478 |
+
return image
|
479 |
+
|
480 |
+
return StaticPipelineOutput(images=image)
|
joyhallo/datasets/__init__.py
ADDED
File without changes
|
joyhallo/datasets/audio_processor.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
This module contains the AudioProcessor class and related functions for processing audio data.
|
3 |
+
It utilizes various libraries and models to perform tasks such as preprocessing, feature extraction,
|
4 |
+
and audio separation. The class is initialized with configuration parameters and can process
|
5 |
+
audio files using the provided models.
|
6 |
+
'''
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
|
10 |
+
import librosa
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
from audio_separator.separator import Separator
|
14 |
+
from einops import rearrange
|
15 |
+
from transformers import Wav2Vec2FeatureExtractor
|
16 |
+
|
17 |
+
from joyhallo.models.wav2vec import Wav2VecModel
|
18 |
+
from joyhallo.utils.util import resample_audio
|
19 |
+
|
20 |
+
|
21 |
+
class AudioProcessor:
|
22 |
+
"""
|
23 |
+
AudioProcessor is a class that handles the processing of audio files.
|
24 |
+
It takes care of preprocessing the audio files, extracting features
|
25 |
+
using wav2vec models, and separating audio signals if needed.
|
26 |
+
|
27 |
+
:param sample_rate: Sampling rate of the audio file
|
28 |
+
:param fps: Frames per second for the extracted features
|
29 |
+
:param wav2vec_model_path: Path to the wav2vec model
|
30 |
+
:param only_last_features: Whether to only use the last features
|
31 |
+
:param audio_separator_model_path: Path to the audio separator model
|
32 |
+
:param audio_separator_model_name: Name of the audio separator model
|
33 |
+
:param cache_dir: Directory to cache the intermediate results
|
34 |
+
:param device: Device to run the processing on
|
35 |
+
"""
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
sample_rate,
|
39 |
+
fps,
|
40 |
+
wav2vec_model_path,
|
41 |
+
only_last_features,
|
42 |
+
audio_separator_model_path:str=None,
|
43 |
+
audio_separator_model_name:str=None,
|
44 |
+
cache_dir:str='',
|
45 |
+
device="cuda:0",
|
46 |
+
) -> None:
|
47 |
+
self.sample_rate = sample_rate
|
48 |
+
self.fps = fps
|
49 |
+
self.device = device
|
50 |
+
|
51 |
+
self.audio_encoder = Wav2VecModel.from_pretrained(wav2vec_model_path, local_files_only=True).to(device=device)
|
52 |
+
self.audio_encoder.feature_extractor._freeze_parameters()
|
53 |
+
self.only_last_features = only_last_features
|
54 |
+
|
55 |
+
if audio_separator_model_name is not None:
|
56 |
+
try:
|
57 |
+
os.makedirs(cache_dir, exist_ok=True)
|
58 |
+
except OSError as _:
|
59 |
+
print("Fail to create the output cache dir.")
|
60 |
+
self.audio_separator = Separator(
|
61 |
+
output_dir=cache_dir,
|
62 |
+
output_single_stem="vocals",
|
63 |
+
model_file_dir=audio_separator_model_path,
|
64 |
+
)
|
65 |
+
self.audio_separator.load_model(audio_separator_model_name)
|
66 |
+
assert self.audio_separator.model_instance is not None, "Fail to load audio separate model."
|
67 |
+
else:
|
68 |
+
self.audio_separator=None
|
69 |
+
print("Use audio directly without vocals seperator.")
|
70 |
+
|
71 |
+
|
72 |
+
self.wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True)
|
73 |
+
|
74 |
+
|
75 |
+
def preprocess(self, wav_file: str, clip_length: int=-1):
|
76 |
+
"""
|
77 |
+
Preprocess a WAV audio file by separating the vocals from the background and resampling it to a 16 kHz sample rate.
|
78 |
+
The separated vocal track is then converted into wav2vec2 for further processing or analysis.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
wav_file (str): The path to the WAV file to be processed. This file should be accessible and in WAV format.
|
82 |
+
|
83 |
+
Raises:
|
84 |
+
RuntimeError: Raises an exception if the WAV file cannot be processed. This could be due to issues
|
85 |
+
such as file not found, unsupported file format, or errors during the audio processing steps.
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
torch.tensor: Returns an audio embedding as a torch.tensor
|
89 |
+
"""
|
90 |
+
if self.audio_separator is not None:
|
91 |
+
# 1. separate vocals
|
92 |
+
# TODO: process in memory
|
93 |
+
outputs = self.audio_separator.separate(wav_file)
|
94 |
+
if len(outputs) <= 0:
|
95 |
+
raise RuntimeError("Audio separate failed.")
|
96 |
+
|
97 |
+
vocal_audio_file = outputs[0]
|
98 |
+
vocal_audio_name, _ = os.path.splitext(vocal_audio_file)
|
99 |
+
vocal_audio_file = os.path.join(self.audio_separator.output_dir, vocal_audio_file)
|
100 |
+
vocal_audio_file = resample_audio(vocal_audio_file, os.path.join(self.audio_separator.output_dir, f"{vocal_audio_name}-16k.wav"), self.sample_rate)
|
101 |
+
else:
|
102 |
+
vocal_audio_file=wav_file
|
103 |
+
|
104 |
+
# 2. extract wav2vec features
|
105 |
+
speech_array, sampling_rate = librosa.load(vocal_audio_file, sr=self.sample_rate)
|
106 |
+
audio_feature = np.squeeze(self.wav2vec_feature_extractor(speech_array, sampling_rate=sampling_rate).input_values)
|
107 |
+
seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps)
|
108 |
+
audio_length = seq_len
|
109 |
+
|
110 |
+
audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device)
|
111 |
+
|
112 |
+
if clip_length>0 and seq_len % clip_length != 0:
|
113 |
+
audio_feature = torch.nn.functional.pad(audio_feature, (0, (clip_length - seq_len % clip_length) * (self.sample_rate // self.fps)), 'constant', 0.0)
|
114 |
+
seq_len += clip_length - seq_len % clip_length
|
115 |
+
audio_feature = audio_feature.unsqueeze(0)
|
116 |
+
|
117 |
+
with torch.no_grad():
|
118 |
+
embeddings = self.audio_encoder(audio_feature, seq_len=seq_len, output_hidden_states=True)
|
119 |
+
assert len(embeddings) > 0, "Fail to extract audio embedding"
|
120 |
+
if self.only_last_features:
|
121 |
+
audio_emb = embeddings.last_hidden_state.squeeze()
|
122 |
+
else:
|
123 |
+
audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
|
124 |
+
audio_emb = rearrange(audio_emb, "b s d -> s b d")
|
125 |
+
|
126 |
+
audio_emb = audio_emb.cpu().detach()
|
127 |
+
|
128 |
+
return audio_emb, audio_length
|
129 |
+
|
130 |
+
def get_embedding(self, wav_file: str):
|
131 |
+
"""preprocess wav audio file convert to embeddings
|
132 |
+
|
133 |
+
Args:
|
134 |
+
wav_file (str): The path to the WAV file to be processed. This file should be accessible and in WAV format.
|
135 |
+
|
136 |
+
Returns:
|
137 |
+
torch.tensor: Returns an audio embedding as a torch.tensor
|
138 |
+
"""
|
139 |
+
speech_array, sampling_rate = librosa.load(
|
140 |
+
wav_file, sr=self.sample_rate)
|
141 |
+
assert sampling_rate == 16000, "The audio sample rate must be 16000"
|
142 |
+
audio_feature = np.squeeze(self.wav2vec_feature_extractor(
|
143 |
+
speech_array, sampling_rate=sampling_rate).input_values)
|
144 |
+
seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps)
|
145 |
+
|
146 |
+
audio_feature = torch.from_numpy(
|
147 |
+
audio_feature).float().to(device=self.device)
|
148 |
+
audio_feature = audio_feature.unsqueeze(0)
|
149 |
+
|
150 |
+
with torch.no_grad():
|
151 |
+
embeddings = self.audio_encoder(
|
152 |
+
audio_feature, seq_len=seq_len, output_hidden_states=True)
|
153 |
+
assert len(embeddings) > 0, "Fail to extract audio embedding"
|
154 |
+
|
155 |
+
if self.only_last_features:
|
156 |
+
audio_emb = embeddings.last_hidden_state.squeeze()
|
157 |
+
else:
|
158 |
+
audio_emb = torch.stack(
|
159 |
+
embeddings.hidden_states[1:], dim=1).squeeze(0)
|
160 |
+
audio_emb = rearrange(audio_emb, "b s d -> s b d")
|
161 |
+
|
162 |
+
audio_emb = audio_emb.cpu().detach()
|
163 |
+
|
164 |
+
return audio_emb
|
165 |
+
|
166 |
+
def close(self):
|
167 |
+
"""
|
168 |
+
TODO: to be implemented
|
169 |
+
"""
|
170 |
+
return self
|
171 |
+
|
172 |
+
def __enter__(self):
|
173 |
+
return self
|
174 |
+
|
175 |
+
def __exit__(self, _exc_type, _exc_val, _exc_tb):
|
176 |
+
self.close()
|
joyhallo/datasets/image_processor.py
ADDED
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This module is responsible for processing images, particularly for face-related tasks.
|
3 |
+
It uses various libraries such as OpenCV, NumPy, and InsightFace to perform tasks like
|
4 |
+
face detection, augmentation, and mask rendering. The ImageProcessor class encapsulates
|
5 |
+
the functionality for these operations.
|
6 |
+
"""
|
7 |
+
import os
|
8 |
+
from typing import List
|
9 |
+
|
10 |
+
import cv2
|
11 |
+
import mediapipe as mp
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
from insightface.app import FaceAnalysis
|
15 |
+
from PIL import Image
|
16 |
+
from torchvision import transforms
|
17 |
+
|
18 |
+
from ..utils.util import (blur_mask, get_landmark_overframes, get_mask,
|
19 |
+
get_union_face_mask, get_union_lip_mask)
|
20 |
+
|
21 |
+
MEAN = 0.5
|
22 |
+
STD = 0.5
|
23 |
+
|
24 |
+
class ImageProcessor:
|
25 |
+
"""
|
26 |
+
ImageProcessor is a class responsible for processing images, particularly for face-related tasks.
|
27 |
+
It takes in an image and performs various operations such as augmentation, face detection,
|
28 |
+
face embedding extraction, and rendering a face mask. The processed images are then used for
|
29 |
+
further analysis or recognition purposes.
|
30 |
+
|
31 |
+
Attributes:
|
32 |
+
img_size (int): The size of the image to be processed.
|
33 |
+
face_analysis_model_path (str): The path to the face analysis model.
|
34 |
+
|
35 |
+
Methods:
|
36 |
+
preprocess(source_image_path, cache_dir):
|
37 |
+
Preprocesses the input image by performing augmentation, face detection,
|
38 |
+
face embedding extraction, and rendering a face mask.
|
39 |
+
|
40 |
+
close():
|
41 |
+
Closes the ImageProcessor and releases any resources being used.
|
42 |
+
|
43 |
+
_augmentation(images, transform, state=None):
|
44 |
+
Applies image augmentation to the input images using the given transform and state.
|
45 |
+
|
46 |
+
__enter__():
|
47 |
+
Enters a runtime context and returns the ImageProcessor object.
|
48 |
+
|
49 |
+
__exit__(_exc_type, _exc_val, _exc_tb):
|
50 |
+
Exits a runtime context and handles any exceptions that occurred during the processing.
|
51 |
+
"""
|
52 |
+
def __init__(self, img_size, face_analysis_model_path) -> None:
|
53 |
+
self.img_size = img_size
|
54 |
+
|
55 |
+
self.pixel_transform = transforms.Compose(
|
56 |
+
[
|
57 |
+
transforms.Resize(self.img_size),
|
58 |
+
transforms.ToTensor(),
|
59 |
+
transforms.Normalize([MEAN], [STD]),
|
60 |
+
]
|
61 |
+
)
|
62 |
+
|
63 |
+
self.cond_transform = transforms.Compose(
|
64 |
+
[
|
65 |
+
transforms.Resize(self.img_size),
|
66 |
+
transforms.ToTensor(),
|
67 |
+
]
|
68 |
+
)
|
69 |
+
|
70 |
+
self.attn_transform_64 = transforms.Compose(
|
71 |
+
[
|
72 |
+
transforms.Resize(
|
73 |
+
(self.img_size[0] // 8, self.img_size[0] // 8)),
|
74 |
+
transforms.ToTensor(),
|
75 |
+
]
|
76 |
+
)
|
77 |
+
self.attn_transform_32 = transforms.Compose(
|
78 |
+
[
|
79 |
+
transforms.Resize(
|
80 |
+
(self.img_size[0] // 16, self.img_size[0] // 16)),
|
81 |
+
transforms.ToTensor(),
|
82 |
+
]
|
83 |
+
)
|
84 |
+
self.attn_transform_16 = transforms.Compose(
|
85 |
+
[
|
86 |
+
transforms.Resize(
|
87 |
+
(self.img_size[0] // 32, self.img_size[0] // 32)),
|
88 |
+
transforms.ToTensor(),
|
89 |
+
]
|
90 |
+
)
|
91 |
+
self.attn_transform_8 = transforms.Compose(
|
92 |
+
[
|
93 |
+
transforms.Resize(
|
94 |
+
(self.img_size[0] // 64, self.img_size[0] // 64)),
|
95 |
+
transforms.ToTensor(),
|
96 |
+
]
|
97 |
+
)
|
98 |
+
|
99 |
+
self.face_analysis = FaceAnalysis(
|
100 |
+
name="",
|
101 |
+
root=face_analysis_model_path,
|
102 |
+
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
|
103 |
+
)
|
104 |
+
self.face_analysis.prepare(ctx_id=0, det_size=(640, 640))
|
105 |
+
|
106 |
+
def preprocess(self, source_image_path: str, cache_dir: str, face_region_ratio: float):
|
107 |
+
"""
|
108 |
+
Apply preprocessing to the source image to prepare for face analysis.
|
109 |
+
|
110 |
+
Parameters:
|
111 |
+
source_image_path (str): The path to the source image.
|
112 |
+
cache_dir (str): The directory to cache intermediate results.
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
None
|
116 |
+
"""
|
117 |
+
source_image = Image.open(source_image_path)
|
118 |
+
ref_image_pil = source_image.convert("RGB")
|
119 |
+
# 1. image augmentation
|
120 |
+
pixel_values_ref_img = self._augmentation(ref_image_pil, self.pixel_transform)
|
121 |
+
|
122 |
+
# 2.1 detect face
|
123 |
+
faces = self.face_analysis.get(cv2.cvtColor(np.array(ref_image_pil.copy()), cv2.COLOR_RGB2BGR))
|
124 |
+
if not faces:
|
125 |
+
print("No faces detected in the image. Using the entire image as the face region.")
|
126 |
+
# Use the entire image as the face region
|
127 |
+
face = {
|
128 |
+
"bbox": [0, 0, ref_image_pil.width, ref_image_pil.height],
|
129 |
+
"embedding": np.zeros(512)
|
130 |
+
}
|
131 |
+
else:
|
132 |
+
# Sort faces by size and select the largest one
|
133 |
+
faces_sorted = sorted(faces, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]), reverse=True)
|
134 |
+
face = faces_sorted[0] # Select the largest face
|
135 |
+
|
136 |
+
# 2.2 face embedding
|
137 |
+
face_emb = face["embedding"]
|
138 |
+
|
139 |
+
# 2.3 render face mask
|
140 |
+
get_mask(source_image_path, cache_dir, face_region_ratio)
|
141 |
+
file_name = os.path.basename(source_image_path).split(".")[0]
|
142 |
+
face_mask_pil = Image.open(
|
143 |
+
os.path.join(cache_dir, f"{file_name}_face_mask.png")).convert("RGB")
|
144 |
+
|
145 |
+
face_mask = self._augmentation(face_mask_pil, self.cond_transform)
|
146 |
+
|
147 |
+
# 2.4 detect and expand lip, face mask
|
148 |
+
sep_background_mask = Image.open(
|
149 |
+
os.path.join(cache_dir, f"{file_name}_sep_background.png"))
|
150 |
+
sep_face_mask = Image.open(
|
151 |
+
os.path.join(cache_dir, f"{file_name}_sep_face.png"))
|
152 |
+
sep_lip_mask = Image.open(
|
153 |
+
os.path.join(cache_dir, f"{file_name}_sep_lip.png"))
|
154 |
+
|
155 |
+
pixel_values_face_mask = [
|
156 |
+
self._augmentation(sep_face_mask, self.attn_transform_64),
|
157 |
+
self._augmentation(sep_face_mask, self.attn_transform_32),
|
158 |
+
self._augmentation(sep_face_mask, self.attn_transform_16),
|
159 |
+
self._augmentation(sep_face_mask, self.attn_transform_8),
|
160 |
+
]
|
161 |
+
pixel_values_lip_mask = [
|
162 |
+
self._augmentation(sep_lip_mask, self.attn_transform_64),
|
163 |
+
self._augmentation(sep_lip_mask, self.attn_transform_32),
|
164 |
+
self._augmentation(sep_lip_mask, self.attn_transform_16),
|
165 |
+
self._augmentation(sep_lip_mask, self.attn_transform_8),
|
166 |
+
]
|
167 |
+
pixel_values_full_mask = [
|
168 |
+
self._augmentation(sep_background_mask, self.attn_transform_64),
|
169 |
+
self._augmentation(sep_background_mask, self.attn_transform_32),
|
170 |
+
self._augmentation(sep_background_mask, self.attn_transform_16),
|
171 |
+
self._augmentation(sep_background_mask, self.attn_transform_8),
|
172 |
+
]
|
173 |
+
|
174 |
+
pixel_values_full_mask = [mask.view(1, -1)
|
175 |
+
for mask in pixel_values_full_mask]
|
176 |
+
pixel_values_face_mask = [mask.view(1, -1)
|
177 |
+
for mask in pixel_values_face_mask]
|
178 |
+
pixel_values_lip_mask = [mask.view(1, -1)
|
179 |
+
for mask in pixel_values_lip_mask]
|
180 |
+
|
181 |
+
return pixel_values_ref_img, face_mask, face_emb, pixel_values_full_mask, pixel_values_face_mask, pixel_values_lip_mask
|
182 |
+
|
183 |
+
def close(self):
|
184 |
+
"""
|
185 |
+
Closes the ImageProcessor and releases any resources held by the FaceAnalysis instance.
|
186 |
+
|
187 |
+
Args:
|
188 |
+
self: The ImageProcessor instance.
|
189 |
+
|
190 |
+
Returns:
|
191 |
+
None.
|
192 |
+
"""
|
193 |
+
for _, model in self.face_analysis.models.items():
|
194 |
+
if hasattr(model, "Dispose"):
|
195 |
+
model.Dispose()
|
196 |
+
|
197 |
+
def _augmentation(self, images, transform, state=None):
|
198 |
+
if state is not None:
|
199 |
+
torch.set_rng_state(state)
|
200 |
+
if isinstance(images, List):
|
201 |
+
transformed_images = [transform(img) for img in images]
|
202 |
+
ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w)
|
203 |
+
else:
|
204 |
+
ret_tensor = transform(images) # (c, h, w)
|
205 |
+
return ret_tensor
|
206 |
+
|
207 |
+
def __enter__(self):
|
208 |
+
return self
|
209 |
+
|
210 |
+
def __exit__(self, _exc_type, _exc_val, _exc_tb):
|
211 |
+
self.close()
|
212 |
+
|
213 |
+
|
214 |
+
class ImageProcessorForDataProcessing():
|
215 |
+
"""
|
216 |
+
ImageProcessor is a class responsible for processing images, particularly for face-related tasks.
|
217 |
+
It takes in an image and performs various operations such as augmentation, face detection,
|
218 |
+
face embedding extraction, and rendering a face mask. The processed images are then used for
|
219 |
+
further analysis or recognition purposes.
|
220 |
+
|
221 |
+
Attributes:
|
222 |
+
img_size (int): The size of the image to be processed.
|
223 |
+
face_analysis_model_path (str): The path to the face analysis model.
|
224 |
+
|
225 |
+
Methods:
|
226 |
+
preprocess(source_image_path, cache_dir):
|
227 |
+
Preprocesses the input image by performing augmentation, face detection,
|
228 |
+
face embedding extraction, and rendering a face mask.
|
229 |
+
|
230 |
+
close():
|
231 |
+
Closes the ImageProcessor and releases any resources being used.
|
232 |
+
|
233 |
+
_augmentation(images, transform, state=None):
|
234 |
+
Applies image augmentation to the input images using the given transform and state.
|
235 |
+
|
236 |
+
__enter__():
|
237 |
+
Enters a runtime context and returns the ImageProcessor object.
|
238 |
+
|
239 |
+
__exit__(_exc_type, _exc_val, _exc_tb):
|
240 |
+
Exits a runtime context and handles any exceptions that occurred during the processing.
|
241 |
+
"""
|
242 |
+
def __init__(self, face_analysis_model_path, landmark_model_path, step) -> None:
|
243 |
+
if step == 2:
|
244 |
+
self.face_analysis = FaceAnalysis(
|
245 |
+
name="",
|
246 |
+
root=face_analysis_model_path,
|
247 |
+
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
|
248 |
+
)
|
249 |
+
self.face_analysis.prepare(ctx_id=0, det_size=(640, 640))
|
250 |
+
self.landmarker = None
|
251 |
+
else:
|
252 |
+
BaseOptions = mp.tasks.BaseOptions
|
253 |
+
FaceLandmarker = mp.tasks.vision.FaceLandmarker
|
254 |
+
FaceLandmarkerOptions = mp.tasks.vision.FaceLandmarkerOptions
|
255 |
+
VisionRunningMode = mp.tasks.vision.RunningMode
|
256 |
+
# Create a face landmarker instance with the video mode:
|
257 |
+
options = FaceLandmarkerOptions(
|
258 |
+
base_options=BaseOptions(model_asset_path=landmark_model_path),
|
259 |
+
running_mode=VisionRunningMode.IMAGE,
|
260 |
+
)
|
261 |
+
self.landmarker = FaceLandmarker.create_from_options(options)
|
262 |
+
self.face_analysis = None
|
263 |
+
|
264 |
+
def preprocess(self, source_image_path: str):
|
265 |
+
"""
|
266 |
+
Apply preprocessing to the source image to prepare for face analysis.
|
267 |
+
|
268 |
+
Parameters:
|
269 |
+
source_image_path (str): The path to the source image.
|
270 |
+
cache_dir (str): The directory to cache intermediate results.
|
271 |
+
|
272 |
+
Returns:
|
273 |
+
None
|
274 |
+
"""
|
275 |
+
# 1. get face embdeding
|
276 |
+
face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask = None, None, None, None, None
|
277 |
+
if self.face_analysis:
|
278 |
+
for frame in sorted(os.listdir(source_image_path)):
|
279 |
+
try:
|
280 |
+
source_image = Image.open(
|
281 |
+
os.path.join(source_image_path, frame))
|
282 |
+
ref_image_pil = source_image.convert("RGB")
|
283 |
+
# 2.1 detect face
|
284 |
+
faces = self.face_analysis.get(cv2.cvtColor(
|
285 |
+
np.array(ref_image_pil.copy()), cv2.COLOR_RGB2BGR))
|
286 |
+
# use max size face
|
287 |
+
face = sorted(faces, key=lambda x: (
|
288 |
+
x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))[-1]
|
289 |
+
# 2.2 face embedding
|
290 |
+
face_emb = face["embedding"]
|
291 |
+
if face_emb is not None:
|
292 |
+
break
|
293 |
+
except Exception as _:
|
294 |
+
continue
|
295 |
+
|
296 |
+
if self.landmarker:
|
297 |
+
# 3.1 get landmark
|
298 |
+
landmarks, height, width = get_landmark_overframes(
|
299 |
+
self.landmarker, source_image_path)
|
300 |
+
assert len(landmarks) == len(os.listdir(source_image_path))
|
301 |
+
|
302 |
+
# 3 render face and lip mask
|
303 |
+
face_mask = get_union_face_mask(landmarks, height, width)
|
304 |
+
lip_mask = get_union_lip_mask(landmarks, height, width)
|
305 |
+
|
306 |
+
# 4 gaussian blur
|
307 |
+
blur_face_mask = blur_mask(face_mask, (64, 64), (51, 51))
|
308 |
+
blur_lip_mask = blur_mask(lip_mask, (64, 64), (31, 31))
|
309 |
+
|
310 |
+
# 5 seperate mask
|
311 |
+
sep_face_mask = cv2.subtract(blur_face_mask, blur_lip_mask)
|
312 |
+
sep_pose_mask = 255.0 - blur_face_mask
|
313 |
+
sep_lip_mask = blur_lip_mask
|
314 |
+
|
315 |
+
return face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask
|
316 |
+
|
317 |
+
def close(self):
|
318 |
+
"""
|
319 |
+
Closes the ImageProcessor and releases any resources held by the FaceAnalysis instance.
|
320 |
+
|
321 |
+
Args:
|
322 |
+
self: The ImageProcessor instance.
|
323 |
+
|
324 |
+
Returns:
|
325 |
+
None.
|
326 |
+
"""
|
327 |
+
for _, model in self.face_analysis.models.items():
|
328 |
+
if hasattr(model, "Dispose"):
|
329 |
+
model.Dispose()
|
330 |
+
|
331 |
+
def _augmentation(self, images, transform, state=None):
|
332 |
+
if state is not None:
|
333 |
+
torch.set_rng_state(state)
|
334 |
+
if isinstance(images, List):
|
335 |
+
transformed_images = [transform(img) for img in images]
|
336 |
+
ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w)
|
337 |
+
else:
|
338 |
+
ret_tensor = transform(images) # (c, h, w)
|
339 |
+
return ret_tensor
|
340 |
+
|
341 |
+
def __enter__(self):
|
342 |
+
return self
|
343 |
+
|
344 |
+
def __exit__(self, _exc_type, _exc_val, _exc_tb):
|
345 |
+
self.close()
|
joyhallo/datasets/mask_image.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This module contains the code for a dataset class called FaceMaskDataset, which is used to process and
|
3 |
+
load image data related to face masks. The dataset class inherits from the PyTorch Dataset class and
|
4 |
+
provides methods for data augmentation, getting items from the dataset, and determining the length of the
|
5 |
+
dataset. The module also includes imports for necessary libraries such as json, random, pathlib, torch,
|
6 |
+
PIL, and transformers.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import json
|
10 |
+
import random
|
11 |
+
from pathlib import Path
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from PIL import Image
|
15 |
+
from torch.utils.data import Dataset
|
16 |
+
from torchvision import transforms
|
17 |
+
from transformers import CLIPImageProcessor
|
18 |
+
|
19 |
+
|
20 |
+
class FaceMaskDataset(Dataset):
|
21 |
+
"""
|
22 |
+
FaceMaskDataset is a custom dataset for face mask images.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
img_size (int): The size of the input images.
|
26 |
+
drop_ratio (float, optional): The ratio of dropped pixels during data augmentation. Defaults to 0.1.
|
27 |
+
data_meta_paths (list, optional): The paths to the metadata files containing image paths and labels. Defaults to ["./data/HDTF_meta.json"].
|
28 |
+
sample_margin (int, optional): The margin for sampling regions in the image. Defaults to 30.
|
29 |
+
|
30 |
+
Attributes:
|
31 |
+
img_size (int): The size of the input images.
|
32 |
+
drop_ratio (float): The ratio of dropped pixels during data augmentation.
|
33 |
+
data_meta_paths (list): The paths to the metadata files containing image paths and labels.
|
34 |
+
sample_margin (int): The margin for sampling regions in the image.
|
35 |
+
processor (CLIPImageProcessor): The image processor for preprocessing images.
|
36 |
+
transform (transforms.Compose): The image augmentation transform.
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
img_size,
|
42 |
+
drop_ratio=0.1,
|
43 |
+
data_meta_paths=None,
|
44 |
+
sample_margin=30,
|
45 |
+
):
|
46 |
+
super().__init__()
|
47 |
+
|
48 |
+
self.img_size = img_size
|
49 |
+
self.sample_margin = sample_margin
|
50 |
+
|
51 |
+
vid_meta = []
|
52 |
+
for data_meta_path in data_meta_paths:
|
53 |
+
with open(data_meta_path, "r", encoding="utf-8") as f:
|
54 |
+
vid_meta.extend(json.load(f))
|
55 |
+
self.vid_meta = vid_meta
|
56 |
+
self.length = len(self.vid_meta)
|
57 |
+
|
58 |
+
self.clip_image_processor = CLIPImageProcessor()
|
59 |
+
|
60 |
+
self.transform = transforms.Compose(
|
61 |
+
[
|
62 |
+
transforms.Resize(self.img_size),
|
63 |
+
transforms.ToTensor(),
|
64 |
+
transforms.Normalize([0.5], [0.5]),
|
65 |
+
]
|
66 |
+
)
|
67 |
+
|
68 |
+
self.cond_transform = transforms.Compose(
|
69 |
+
[
|
70 |
+
transforms.Resize(self.img_size),
|
71 |
+
transforms.ToTensor(),
|
72 |
+
]
|
73 |
+
)
|
74 |
+
|
75 |
+
self.drop_ratio = drop_ratio
|
76 |
+
|
77 |
+
def augmentation(self, image, transform, state=None):
|
78 |
+
"""
|
79 |
+
Apply data augmentation to the input image.
|
80 |
+
|
81 |
+
Args:
|
82 |
+
image (PIL.Image): The input image.
|
83 |
+
transform (torchvision.transforms.Compose): The data augmentation transforms.
|
84 |
+
state (dict, optional): The random state for reproducibility. Defaults to None.
|
85 |
+
|
86 |
+
Returns:
|
87 |
+
PIL.Image: The augmented image.
|
88 |
+
"""
|
89 |
+
if state is not None:
|
90 |
+
torch.set_rng_state(state)
|
91 |
+
return transform(image)
|
92 |
+
|
93 |
+
def __getitem__(self, index):
|
94 |
+
video_meta = self.vid_meta[index]
|
95 |
+
video_path = video_meta["image_path"]
|
96 |
+
mask_path = video_meta["mask_path"]
|
97 |
+
face_emb_path = video_meta["face_emb"]
|
98 |
+
|
99 |
+
video_frames = sorted(Path(video_path).iterdir())
|
100 |
+
video_length = len(video_frames)
|
101 |
+
|
102 |
+
margin = min(self.sample_margin, video_length)
|
103 |
+
|
104 |
+
ref_img_idx = random.randint(0, video_length - 1)
|
105 |
+
if ref_img_idx + margin < video_length:
|
106 |
+
tgt_img_idx = random.randint(
|
107 |
+
ref_img_idx + margin, video_length - 1)
|
108 |
+
elif ref_img_idx - margin > 0:
|
109 |
+
tgt_img_idx = random.randint(0, ref_img_idx - margin)
|
110 |
+
else:
|
111 |
+
tgt_img_idx = random.randint(0, video_length - 1)
|
112 |
+
|
113 |
+
ref_img_pil = Image.open(video_frames[ref_img_idx])
|
114 |
+
tgt_img_pil = Image.open(video_frames[tgt_img_idx])
|
115 |
+
|
116 |
+
tgt_mask_pil = Image.open(mask_path)
|
117 |
+
|
118 |
+
assert ref_img_pil is not None, "Fail to load reference image."
|
119 |
+
assert tgt_img_pil is not None, "Fail to load target image."
|
120 |
+
assert tgt_mask_pil is not None, "Fail to load target mask."
|
121 |
+
|
122 |
+
state = torch.get_rng_state()
|
123 |
+
tgt_img = self.augmentation(tgt_img_pil, self.transform, state)
|
124 |
+
tgt_mask_img = self.augmentation(
|
125 |
+
tgt_mask_pil, self.cond_transform, state)
|
126 |
+
tgt_mask_img = tgt_mask_img.repeat(3, 1, 1)
|
127 |
+
ref_img_vae = self.augmentation(
|
128 |
+
ref_img_pil, self.transform, state)
|
129 |
+
face_emb = torch.load(face_emb_path)
|
130 |
+
|
131 |
+
|
132 |
+
sample = {
|
133 |
+
"video_dir": video_path,
|
134 |
+
"img": tgt_img,
|
135 |
+
"tgt_mask": tgt_mask_img,
|
136 |
+
"ref_img": ref_img_vae,
|
137 |
+
"face_emb": face_emb,
|
138 |
+
}
|
139 |
+
|
140 |
+
return sample
|
141 |
+
|
142 |
+
def __len__(self):
|
143 |
+
return len(self.vid_meta)
|
144 |
+
|
145 |
+
|
146 |
+
if __name__ == "__main__":
|
147 |
+
data = FaceMaskDataset(img_size=(512, 512))
|
148 |
+
train_dataloader = torch.utils.data.DataLoader(
|
149 |
+
data, batch_size=4, shuffle=True, num_workers=1
|
150 |
+
)
|
151 |
+
for step, batch in enumerate(train_dataloader):
|
152 |
+
print(batch["tgt_mask"].shape)
|
153 |
+
break
|
joyhallo/datasets/talk_video.py
ADDED
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
talking_video_dataset.py
|
3 |
+
|
4 |
+
This module defines the TalkingVideoDataset class, a custom PyTorch dataset
|
5 |
+
for handling talking video data. The dataset uses video files, masks, and
|
6 |
+
embeddings to prepare data for tasks such as video generation and
|
7 |
+
speech-driven video animation.
|
8 |
+
|
9 |
+
Classes:
|
10 |
+
TalkingVideoDataset
|
11 |
+
|
12 |
+
Dependencies:
|
13 |
+
json
|
14 |
+
random
|
15 |
+
torch
|
16 |
+
decord.VideoReader, decord.cpu
|
17 |
+
PIL.Image
|
18 |
+
torch.utils.data.Dataset
|
19 |
+
torchvision.transforms
|
20 |
+
|
21 |
+
Example:
|
22 |
+
from talking_video_dataset import TalkingVideoDataset
|
23 |
+
from torch.utils.data import DataLoader
|
24 |
+
|
25 |
+
# Example configuration for the Wav2Vec model
|
26 |
+
class Wav2VecConfig:
|
27 |
+
def __init__(self, audio_type, model_scale, features):
|
28 |
+
self.audio_type = audio_type
|
29 |
+
self.model_scale = model_scale
|
30 |
+
self.features = features
|
31 |
+
|
32 |
+
wav2vec_cfg = Wav2VecConfig(audio_type="wav2vec2", model_scale="base", features="feature")
|
33 |
+
|
34 |
+
# Initialize dataset
|
35 |
+
dataset = TalkingVideoDataset(
|
36 |
+
img_size=(512, 512),
|
37 |
+
sample_rate=16000,
|
38 |
+
audio_margin=2,
|
39 |
+
n_motion_frames=0,
|
40 |
+
n_sample_frames=16,
|
41 |
+
data_meta_paths=["path/to/meta1.json", "path/to/meta2.json"],
|
42 |
+
wav2vec_cfg=wav2vec_cfg,
|
43 |
+
)
|
44 |
+
|
45 |
+
# Initialize dataloader
|
46 |
+
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
|
47 |
+
|
48 |
+
# Fetch one batch of data
|
49 |
+
batch = next(iter(dataloader))
|
50 |
+
print(batch["pixel_values_vid"].shape) # Example output: (4, 16, 3, 512, 512)
|
51 |
+
|
52 |
+
The TalkingVideoDataset class provides methods for loading video frames, masks,
|
53 |
+
audio embeddings, and other relevant data, applying transformations, and preparing
|
54 |
+
the data for training and evaluation in a deep learning pipeline.
|
55 |
+
|
56 |
+
Attributes:
|
57 |
+
img_size (tuple): The dimensions to resize the video frames to.
|
58 |
+
sample_rate (int): The audio sample rate.
|
59 |
+
audio_margin (int): The margin for audio sampling.
|
60 |
+
n_motion_frames (int): The number of motion frames.
|
61 |
+
n_sample_frames (int): The number of sample frames.
|
62 |
+
data_meta_paths (list): List of paths to the JSON metadata files.
|
63 |
+
wav2vec_cfg (object): Configuration for the Wav2Vec model.
|
64 |
+
|
65 |
+
Methods:
|
66 |
+
augmentation(images, transform, state=None): Apply transformation to input images.
|
67 |
+
__getitem__(index): Get a sample from the dataset at the specified index.
|
68 |
+
__len__(): Return the length of the dataset.
|
69 |
+
"""
|
70 |
+
|
71 |
+
import json
|
72 |
+
import random
|
73 |
+
from typing import List
|
74 |
+
|
75 |
+
import torch
|
76 |
+
from decord import VideoReader, cpu
|
77 |
+
from PIL import Image
|
78 |
+
from torch.utils.data import Dataset
|
79 |
+
from torchvision import transforms
|
80 |
+
|
81 |
+
|
82 |
+
class TalkingVideoDataset(Dataset):
|
83 |
+
"""
|
84 |
+
A dataset class for processing talking video data.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
img_size (tuple, optional): The size of the output images. Defaults to (512, 512).
|
88 |
+
sample_rate (int, optional): The sample rate of the audio data. Defaults to 16000.
|
89 |
+
audio_margin (int, optional): The margin for the audio data. Defaults to 2.
|
90 |
+
n_motion_frames (int, optional): The number of motion frames. Defaults to 0.
|
91 |
+
n_sample_frames (int, optional): The number of sample frames. Defaults to 16.
|
92 |
+
data_meta_paths (list, optional): The paths to the data metadata. Defaults to None.
|
93 |
+
wav2vec_cfg (dict, optional): The configuration for the wav2vec model. Defaults to None.
|
94 |
+
|
95 |
+
Attributes:
|
96 |
+
img_size (tuple): The size of the output images.
|
97 |
+
sample_rate (int): The sample rate of the audio data.
|
98 |
+
audio_margin (int): The margin for the audio data.
|
99 |
+
n_motion_frames (int): The number of motion frames.
|
100 |
+
n_sample_frames (int): The number of sample frames.
|
101 |
+
data_meta_paths (list): The paths to the data metadata.
|
102 |
+
wav2vec_cfg (dict): The configuration for the wav2vec model.
|
103 |
+
"""
|
104 |
+
|
105 |
+
def __init__(
|
106 |
+
self,
|
107 |
+
img_size=(512, 512),
|
108 |
+
sample_rate=16000,
|
109 |
+
audio_margin=2,
|
110 |
+
n_motion_frames=0,
|
111 |
+
n_sample_frames=16,
|
112 |
+
data_meta_paths=None,
|
113 |
+
wav2vec_cfg=None,
|
114 |
+
):
|
115 |
+
super().__init__()
|
116 |
+
self.sample_rate = sample_rate
|
117 |
+
self.img_size = img_size
|
118 |
+
self.audio_margin = audio_margin
|
119 |
+
self.n_motion_frames = n_motion_frames
|
120 |
+
self.n_sample_frames = n_sample_frames
|
121 |
+
self.audio_type = wav2vec_cfg.audio_type
|
122 |
+
self.audio_model = wav2vec_cfg.model_scale
|
123 |
+
self.audio_features = wav2vec_cfg.features
|
124 |
+
|
125 |
+
vid_meta = []
|
126 |
+
for data_meta_path in data_meta_paths:
|
127 |
+
with open(data_meta_path, "r", encoding="utf-8") as f:
|
128 |
+
vid_meta.extend(json.load(f))
|
129 |
+
self.vid_meta = vid_meta
|
130 |
+
self.length = len(self.vid_meta)
|
131 |
+
self.pixel_transform = transforms.Compose(
|
132 |
+
[
|
133 |
+
transforms.Resize(self.img_size),
|
134 |
+
transforms.ToTensor(),
|
135 |
+
transforms.Normalize([0.5], [0.5]),
|
136 |
+
]
|
137 |
+
)
|
138 |
+
|
139 |
+
self.cond_transform = transforms.Compose(
|
140 |
+
[
|
141 |
+
transforms.Resize(self.img_size),
|
142 |
+
transforms.ToTensor(),
|
143 |
+
]
|
144 |
+
)
|
145 |
+
self.attn_transform_64 = transforms.Compose(
|
146 |
+
[
|
147 |
+
transforms.Resize(
|
148 |
+
(self.img_size[0] // 8, self.img_size[0] // 8)),
|
149 |
+
transforms.ToTensor(),
|
150 |
+
]
|
151 |
+
)
|
152 |
+
self.attn_transform_32 = transforms.Compose(
|
153 |
+
[
|
154 |
+
transforms.Resize(
|
155 |
+
(self.img_size[0] // 16, self.img_size[0] // 16)),
|
156 |
+
transforms.ToTensor(),
|
157 |
+
]
|
158 |
+
)
|
159 |
+
self.attn_transform_16 = transforms.Compose(
|
160 |
+
[
|
161 |
+
transforms.Resize(
|
162 |
+
(self.img_size[0] // 32, self.img_size[0] // 32)),
|
163 |
+
transforms.ToTensor(),
|
164 |
+
]
|
165 |
+
)
|
166 |
+
self.attn_transform_8 = transforms.Compose(
|
167 |
+
[
|
168 |
+
transforms.Resize(
|
169 |
+
(self.img_size[0] // 64, self.img_size[0] // 64)),
|
170 |
+
transforms.ToTensor(),
|
171 |
+
]
|
172 |
+
)
|
173 |
+
|
174 |
+
def augmentation(self, images, transform, state=None):
|
175 |
+
"""
|
176 |
+
Apply the given transformation to the input images.
|
177 |
+
|
178 |
+
Args:
|
179 |
+
images (List[PIL.Image] or PIL.Image): The input images to be transformed.
|
180 |
+
transform (torchvision.transforms.Compose): The transformation to be applied to the images.
|
181 |
+
state (torch.ByteTensor, optional): The state of the random number generator.
|
182 |
+
If provided, it will set the RNG state to this value before applying the transformation. Defaults to None.
|
183 |
+
|
184 |
+
Returns:
|
185 |
+
torch.Tensor: The transformed images as a tensor.
|
186 |
+
If the input was a list of images, the tensor will have shape (f, c, h, w),
|
187 |
+
where f is the number of images, c is the number of channels, h is the height, and w is the width.
|
188 |
+
If the input was a single image, the tensor will have shape (c, h, w),
|
189 |
+
where c is the number of channels, h is the height, and w is the width.
|
190 |
+
"""
|
191 |
+
if state is not None:
|
192 |
+
torch.set_rng_state(state)
|
193 |
+
if isinstance(images, List):
|
194 |
+
transformed_images = [transform(img) for img in images]
|
195 |
+
ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w)
|
196 |
+
else:
|
197 |
+
ret_tensor = transform(images) # (c, h, w)
|
198 |
+
return ret_tensor
|
199 |
+
|
200 |
+
def __getitem__(self, index):
|
201 |
+
video_meta = self.vid_meta[index]
|
202 |
+
video_path = video_meta["video_path"]
|
203 |
+
mask_path = video_meta["mask_path"]
|
204 |
+
lip_mask_union_path = video_meta.get("sep_mask_lip", None)
|
205 |
+
face_mask_union_path = video_meta.get("sep_mask_face", None)
|
206 |
+
full_mask_union_path = video_meta.get("sep_mask_border", None)
|
207 |
+
face_emb_path = video_meta["face_emb_path"]
|
208 |
+
audio_emb_path = video_meta[
|
209 |
+
f"{self.audio_type}_emb_{self.audio_model}_{self.audio_features}"
|
210 |
+
]
|
211 |
+
tgt_mask_pil = Image.open(mask_path)
|
212 |
+
video_frames = VideoReader(video_path, ctx=cpu(0))
|
213 |
+
assert tgt_mask_pil is not None, "Fail to load target mask."
|
214 |
+
assert (video_frames is not None and len(video_frames) > 0), "Fail to load video frames."
|
215 |
+
|
216 |
+
# 提前加载的位置,确认长度
|
217 |
+
audio_emb = torch.load(audio_emb_path)
|
218 |
+
|
219 |
+
# print(len(video_frames), len(audio_emb))
|
220 |
+
# 避免长度不一致,超索引范围
|
221 |
+
video_length = min(len(video_frames), len(audio_emb))
|
222 |
+
|
223 |
+
assert (
|
224 |
+
video_length
|
225 |
+
> self.n_sample_frames + self.n_motion_frames + 2 * self.audio_margin
|
226 |
+
)
|
227 |
+
start_idx = random.randint(
|
228 |
+
self.n_motion_frames,
|
229 |
+
video_length - self.n_sample_frames - self.audio_margin - 1,
|
230 |
+
)
|
231 |
+
|
232 |
+
videos = video_frames[start_idx : start_idx + self.n_sample_frames]
|
233 |
+
|
234 |
+
frame_list = [
|
235 |
+
Image.fromarray(video).convert("RGB") for video in videos.asnumpy()
|
236 |
+
]
|
237 |
+
|
238 |
+
face_masks_list = [Image.open(face_mask_union_path)] * self.n_sample_frames
|
239 |
+
lip_masks_list = [Image.open(lip_mask_union_path)] * self.n_sample_frames
|
240 |
+
full_masks_list = [Image.open(full_mask_union_path)] * self.n_sample_frames
|
241 |
+
assert face_masks_list[0] is not None, "Fail to load face mask."
|
242 |
+
assert lip_masks_list[0] is not None, "Fail to load lip mask."
|
243 |
+
assert full_masks_list[0] is not None, "Fail to load full mask."
|
244 |
+
|
245 |
+
|
246 |
+
face_emb = torch.load(face_emb_path)
|
247 |
+
|
248 |
+
indices = (
|
249 |
+
torch.arange(2 * self.audio_margin + 1) - self.audio_margin
|
250 |
+
) # Generates [-2, -1, 0, 1, 2]
|
251 |
+
center_indices = torch.arange(
|
252 |
+
start_idx,
|
253 |
+
start_idx + self.n_sample_frames,
|
254 |
+
).unsqueeze(1) + indices.unsqueeze(0)
|
255 |
+
audio_tensor = audio_emb[center_indices]
|
256 |
+
|
257 |
+
ref_img_idx = random.randint(
|
258 |
+
self.n_motion_frames,
|
259 |
+
video_length - self.n_sample_frames - self.audio_margin - 1,
|
260 |
+
)
|
261 |
+
ref_img = video_frames[ref_img_idx].asnumpy()
|
262 |
+
ref_img = Image.fromarray(ref_img)
|
263 |
+
|
264 |
+
if self.n_motion_frames > 0:
|
265 |
+
motions = video_frames[start_idx - self.n_motion_frames : start_idx]
|
266 |
+
motion_list = [
|
267 |
+
Image.fromarray(motion).convert("RGB") for motion in motions.asnumpy()
|
268 |
+
]
|
269 |
+
|
270 |
+
# transform
|
271 |
+
state = torch.get_rng_state()
|
272 |
+
pixel_values_vid = self.augmentation(frame_list, self.pixel_transform, state)
|
273 |
+
|
274 |
+
pixel_values_mask = self.augmentation(tgt_mask_pil, self.cond_transform, state)
|
275 |
+
pixel_values_mask = pixel_values_mask.repeat(3, 1, 1)
|
276 |
+
|
277 |
+
pixel_values_face_mask = [
|
278 |
+
self.augmentation(face_masks_list, self.attn_transform_64, state),
|
279 |
+
self.augmentation(face_masks_list, self.attn_transform_32, state),
|
280 |
+
self.augmentation(face_masks_list, self.attn_transform_16, state),
|
281 |
+
self.augmentation(face_masks_list, self.attn_transform_8, state),
|
282 |
+
]
|
283 |
+
pixel_values_lip_mask = [
|
284 |
+
self.augmentation(lip_masks_list, self.attn_transform_64, state),
|
285 |
+
self.augmentation(lip_masks_list, self.attn_transform_32, state),
|
286 |
+
self.augmentation(lip_masks_list, self.attn_transform_16, state),
|
287 |
+
self.augmentation(lip_masks_list, self.attn_transform_8, state),
|
288 |
+
]
|
289 |
+
pixel_values_full_mask = [
|
290 |
+
self.augmentation(full_masks_list, self.attn_transform_64, state),
|
291 |
+
self.augmentation(full_masks_list, self.attn_transform_32, state),
|
292 |
+
self.augmentation(full_masks_list, self.attn_transform_16, state),
|
293 |
+
self.augmentation(full_masks_list, self.attn_transform_8, state),
|
294 |
+
]
|
295 |
+
|
296 |
+
pixel_values_ref_img = self.augmentation(ref_img, self.pixel_transform, state)
|
297 |
+
pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
|
298 |
+
if self.n_motion_frames > 0:
|
299 |
+
pixel_values_motion = self.augmentation(
|
300 |
+
motion_list, self.pixel_transform, state
|
301 |
+
)
|
302 |
+
pixel_values_ref_img = torch.cat(
|
303 |
+
[pixel_values_ref_img, pixel_values_motion], dim=0
|
304 |
+
)
|
305 |
+
|
306 |
+
sample = {
|
307 |
+
"video_dir": video_path,
|
308 |
+
"pixel_values_vid": pixel_values_vid,
|
309 |
+
"pixel_values_mask": pixel_values_mask,
|
310 |
+
"pixel_values_face_mask": pixel_values_face_mask,
|
311 |
+
"pixel_values_lip_mask": pixel_values_lip_mask,
|
312 |
+
"pixel_values_full_mask": pixel_values_full_mask,
|
313 |
+
"audio_tensor": audio_tensor,
|
314 |
+
"pixel_values_ref_img": pixel_values_ref_img,
|
315 |
+
"face_emb": face_emb,
|
316 |
+
}
|
317 |
+
|
318 |
+
return sample
|
319 |
+
|
320 |
+
def __len__(self):
|
321 |
+
return len(self.vid_meta)
|
joyhallo/models/__init__.py
ADDED
File without changes
|
joyhallo/models/attention.py
ADDED
@@ -0,0 +1,893 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This module contains various transformer blocks for different applications, such as BasicTransformerBlock,
|
3 |
+
TemporalBasicTransformerBlock, and AudioTemporalBasicTransformerBlock. These blocks are used in various models,
|
4 |
+
such as GLIGEN, UNet, and others. The transformer blocks implement self-attention, cross-attention, feed-forward
|
5 |
+
networks, and other related functions.
|
6 |
+
|
7 |
+
Functions and classes included in this module are:
|
8 |
+
- BasicTransformerBlock: A basic transformer block with self-attention, cross-attention, and feed-forward layers.
|
9 |
+
- TemporalBasicTransformerBlock: A transformer block with additional temporal attention mechanisms for video data.
|
10 |
+
- AudioTemporalBasicTransformerBlock: A transformer block with additional audio-specific mechanisms for audio data.
|
11 |
+
- zero_module: A function to zero out the parameters of a given module.
|
12 |
+
|
13 |
+
For more information on each specific class and function, please refer to the respective docstrings.
|
14 |
+
"""
|
15 |
+
|
16 |
+
from typing import Any, Dict, List, Optional
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from diffusers.models.attention import (AdaLayerNorm, AdaLayerNormZero,
|
20 |
+
Attention, FeedForward)
|
21 |
+
from diffusers.models.embeddings import SinusoidalPositionalEmbedding
|
22 |
+
from einops import rearrange
|
23 |
+
from torch import nn
|
24 |
+
|
25 |
+
|
26 |
+
class GatedSelfAttentionDense(nn.Module):
|
27 |
+
"""
|
28 |
+
A gated self-attention dense layer that combines visual features and object features.
|
29 |
+
|
30 |
+
Parameters:
|
31 |
+
query_dim (`int`): The number of channels in the query.
|
32 |
+
context_dim (`int`): The number of channels in the context.
|
33 |
+
n_heads (`int`): The number of heads to use for attention.
|
34 |
+
d_head (`int`): The number of channels in each head.
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
|
38 |
+
super().__init__()
|
39 |
+
|
40 |
+
# we need a linear projection since we need cat visual feature and obj feature
|
41 |
+
self.linear = nn.Linear(context_dim, query_dim)
|
42 |
+
|
43 |
+
self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
|
44 |
+
self.ff = FeedForward(query_dim, activation_fn="geglu")
|
45 |
+
|
46 |
+
self.norm1 = nn.LayerNorm(query_dim)
|
47 |
+
self.norm2 = nn.LayerNorm(query_dim)
|
48 |
+
|
49 |
+
self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
|
50 |
+
self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
|
51 |
+
|
52 |
+
self.enabled = True
|
53 |
+
|
54 |
+
def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
|
55 |
+
"""
|
56 |
+
Apply the Gated Self-Attention mechanism to the input tensor `x` and object tensor `objs`.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
x (torch.Tensor): The input tensor.
|
60 |
+
objs (torch.Tensor): The object tensor.
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
torch.Tensor: The output tensor after applying Gated Self-Attention.
|
64 |
+
"""
|
65 |
+
if not self.enabled:
|
66 |
+
return x
|
67 |
+
|
68 |
+
n_visual = x.shape[1]
|
69 |
+
objs = self.linear(objs)
|
70 |
+
|
71 |
+
x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
|
72 |
+
x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
|
73 |
+
|
74 |
+
return x
|
75 |
+
|
76 |
+
class BasicTransformerBlock(nn.Module):
|
77 |
+
r"""
|
78 |
+
A basic Transformer block.
|
79 |
+
|
80 |
+
Parameters:
|
81 |
+
dim (`int`): The number of channels in the input and output.
|
82 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
83 |
+
attention_head_dim (`int`): The number of channels in each head.
|
84 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
85 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
86 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
87 |
+
num_embeds_ada_norm (:
|
88 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
89 |
+
attention_bias (:
|
90 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
91 |
+
only_cross_attention (`bool`, *optional*):
|
92 |
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
93 |
+
double_self_attention (`bool`, *optional*):
|
94 |
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
95 |
+
upcast_attention (`bool`, *optional*):
|
96 |
+
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
97 |
+
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
98 |
+
Whether to use learnable elementwise affine parameters for normalization.
|
99 |
+
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
100 |
+
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
101 |
+
final_dropout (`bool` *optional*, defaults to False):
|
102 |
+
Whether to apply a final dropout after the last feed-forward layer.
|
103 |
+
attention_type (`str`, *optional*, defaults to `"default"`):
|
104 |
+
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
105 |
+
positional_embeddings (`str`, *optional*, defaults to `None`):
|
106 |
+
The type of positional embeddings to apply to.
|
107 |
+
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
108 |
+
The maximum number of positional embeddings to apply.
|
109 |
+
"""
|
110 |
+
|
111 |
+
def __init__(
|
112 |
+
self,
|
113 |
+
dim: int,
|
114 |
+
num_attention_heads: int,
|
115 |
+
attention_head_dim: int,
|
116 |
+
dropout=0.0,
|
117 |
+
cross_attention_dim: Optional[int] = None,
|
118 |
+
activation_fn: str = "geglu",
|
119 |
+
num_embeds_ada_norm: Optional[int] = None,
|
120 |
+
attention_bias: bool = False,
|
121 |
+
only_cross_attention: bool = False,
|
122 |
+
double_self_attention: bool = False,
|
123 |
+
upcast_attention: bool = False,
|
124 |
+
norm_elementwise_affine: bool = True,
|
125 |
+
# 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
|
126 |
+
norm_type: str = "layer_norm",
|
127 |
+
norm_eps: float = 1e-5,
|
128 |
+
final_dropout: bool = False,
|
129 |
+
attention_type: str = "default",
|
130 |
+
positional_embeddings: Optional[str] = None,
|
131 |
+
num_positional_embeddings: Optional[int] = None,
|
132 |
+
):
|
133 |
+
super().__init__()
|
134 |
+
self.only_cross_attention = only_cross_attention
|
135 |
+
|
136 |
+
self.use_ada_layer_norm_zero = (
|
137 |
+
num_embeds_ada_norm is not None
|
138 |
+
) and norm_type == "ada_norm_zero"
|
139 |
+
self.use_ada_layer_norm = (
|
140 |
+
num_embeds_ada_norm is not None
|
141 |
+
) and norm_type == "ada_norm"
|
142 |
+
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
|
143 |
+
self.use_layer_norm = norm_type == "layer_norm"
|
144 |
+
|
145 |
+
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
146 |
+
raise ValueError(
|
147 |
+
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
148 |
+
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
149 |
+
)
|
150 |
+
|
151 |
+
if positional_embeddings and (num_positional_embeddings is None):
|
152 |
+
raise ValueError(
|
153 |
+
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
|
154 |
+
)
|
155 |
+
|
156 |
+
if positional_embeddings == "sinusoidal":
|
157 |
+
self.pos_embed = SinusoidalPositionalEmbedding(
|
158 |
+
dim, max_seq_length=num_positional_embeddings
|
159 |
+
)
|
160 |
+
else:
|
161 |
+
self.pos_embed = None
|
162 |
+
|
163 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
164 |
+
# 1. Self-Attn
|
165 |
+
if self.use_ada_layer_norm:
|
166 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
167 |
+
elif self.use_ada_layer_norm_zero:
|
168 |
+
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
169 |
+
else:
|
170 |
+
self.norm1 = nn.LayerNorm(
|
171 |
+
dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
|
172 |
+
)
|
173 |
+
|
174 |
+
self.attn1 = Attention(
|
175 |
+
query_dim=dim,
|
176 |
+
heads=num_attention_heads,
|
177 |
+
dim_head=attention_head_dim,
|
178 |
+
dropout=dropout,
|
179 |
+
bias=attention_bias,
|
180 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
181 |
+
upcast_attention=upcast_attention,
|
182 |
+
)
|
183 |
+
|
184 |
+
# 2. Cross-Attn
|
185 |
+
if cross_attention_dim is not None or double_self_attention:
|
186 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
187 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
188 |
+
# the second cross attention block.
|
189 |
+
self.norm2 = (
|
190 |
+
AdaLayerNorm(dim, num_embeds_ada_norm)
|
191 |
+
if self.use_ada_layer_norm
|
192 |
+
else nn.LayerNorm(
|
193 |
+
dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
|
194 |
+
)
|
195 |
+
)
|
196 |
+
self.attn2 = Attention(
|
197 |
+
query_dim=dim,
|
198 |
+
cross_attention_dim=(
|
199 |
+
cross_attention_dim if not double_self_attention else None
|
200 |
+
),
|
201 |
+
heads=num_attention_heads,
|
202 |
+
dim_head=attention_head_dim,
|
203 |
+
dropout=dropout,
|
204 |
+
bias=attention_bias,
|
205 |
+
upcast_attention=upcast_attention,
|
206 |
+
) # is self-attn if encoder_hidden_states is none
|
207 |
+
else:
|
208 |
+
self.norm2 = None
|
209 |
+
self.attn2 = None
|
210 |
+
|
211 |
+
# 3. Feed-forward
|
212 |
+
if not self.use_ada_layer_norm_single:
|
213 |
+
self.norm3 = nn.LayerNorm(
|
214 |
+
dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
|
215 |
+
)
|
216 |
+
|
217 |
+
self.ff = FeedForward(
|
218 |
+
dim,
|
219 |
+
dropout=dropout,
|
220 |
+
activation_fn=activation_fn,
|
221 |
+
final_dropout=final_dropout,
|
222 |
+
)
|
223 |
+
|
224 |
+
# 4. Fuser
|
225 |
+
if attention_type in {"gated", "gated-text-image"}: # Updated line
|
226 |
+
self.fuser = GatedSelfAttentionDense(
|
227 |
+
dim, cross_attention_dim, num_attention_heads, attention_head_dim
|
228 |
+
)
|
229 |
+
|
230 |
+
# 5. Scale-shift for PixArt-Alpha.
|
231 |
+
if self.use_ada_layer_norm_single:
|
232 |
+
self.scale_shift_table = nn.Parameter(
|
233 |
+
torch.randn(6, dim) / dim**0.5)
|
234 |
+
|
235 |
+
# let chunk size default to None
|
236 |
+
self._chunk_size = None
|
237 |
+
self._chunk_dim = 0
|
238 |
+
|
239 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
|
240 |
+
"""
|
241 |
+
Sets the chunk size for feed-forward processing in the transformer block.
|
242 |
+
|
243 |
+
Args:
|
244 |
+
chunk_size (Optional[int]): The size of the chunks to process in feed-forward layers.
|
245 |
+
If None, the chunk size is set to the maximum possible value.
|
246 |
+
dim (int, optional): The dimension along which to split the input tensor into chunks. Defaults to 0.
|
247 |
+
|
248 |
+
Returns:
|
249 |
+
None.
|
250 |
+
"""
|
251 |
+
self._chunk_size = chunk_size
|
252 |
+
self._chunk_dim = dim
|
253 |
+
|
254 |
+
def forward(
|
255 |
+
self,
|
256 |
+
hidden_states: torch.FloatTensor,
|
257 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
258 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
259 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
260 |
+
timestep: Optional[torch.LongTensor] = None,
|
261 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
262 |
+
class_labels: Optional[torch.LongTensor] = None,
|
263 |
+
) -> torch.FloatTensor:
|
264 |
+
"""
|
265 |
+
This function defines the forward pass of the BasicTransformerBlock.
|
266 |
+
|
267 |
+
Args:
|
268 |
+
self (BasicTransformerBlock):
|
269 |
+
An instance of the BasicTransformerBlock class.
|
270 |
+
hidden_states (torch.FloatTensor):
|
271 |
+
A tensor containing the hidden states.
|
272 |
+
attention_mask (Optional[torch.FloatTensor], optional):
|
273 |
+
A tensor containing the attention mask. Defaults to None.
|
274 |
+
encoder_hidden_states (Optional[torch.FloatTensor], optional):
|
275 |
+
A tensor containing the encoder hidden states. Defaults to None.
|
276 |
+
encoder_attention_mask (Optional[torch.FloatTensor], optional):
|
277 |
+
A tensor containing the encoder attention mask. Defaults to None.
|
278 |
+
timestep (Optional[torch.LongTensor], optional):
|
279 |
+
A tensor containing the timesteps. Defaults to None.
|
280 |
+
cross_attention_kwargs (Dict[str, Any], optional):
|
281 |
+
Additional cross-attention arguments. Defaults to None.
|
282 |
+
class_labels (Optional[torch.LongTensor], optional):
|
283 |
+
A tensor containing the class labels. Defaults to None.
|
284 |
+
|
285 |
+
Returns:
|
286 |
+
torch.FloatTensor:
|
287 |
+
A tensor containing the transformed hidden states.
|
288 |
+
"""
|
289 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
290 |
+
# 0. Self-Attention
|
291 |
+
batch_size = hidden_states.shape[0]
|
292 |
+
|
293 |
+
gate_msa = None
|
294 |
+
scale_mlp = None
|
295 |
+
shift_mlp = None
|
296 |
+
gate_mlp = None
|
297 |
+
if self.use_ada_layer_norm:
|
298 |
+
norm_hidden_states = self.norm1(hidden_states, timestep)
|
299 |
+
elif self.use_ada_layer_norm_zero:
|
300 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
301 |
+
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
302 |
+
)
|
303 |
+
elif self.use_layer_norm:
|
304 |
+
norm_hidden_states = self.norm1(hidden_states)
|
305 |
+
elif self.use_ada_layer_norm_single:
|
306 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
307 |
+
self.scale_shift_table[None] +
|
308 |
+
timestep.reshape(batch_size, 6, -1)
|
309 |
+
).chunk(6, dim=1)
|
310 |
+
norm_hidden_states = self.norm1(hidden_states)
|
311 |
+
norm_hidden_states = norm_hidden_states * \
|
312 |
+
(1 + scale_msa) + shift_msa
|
313 |
+
norm_hidden_states = norm_hidden_states.squeeze(1)
|
314 |
+
else:
|
315 |
+
raise ValueError("Incorrect norm used")
|
316 |
+
|
317 |
+
if self.pos_embed is not None:
|
318 |
+
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
319 |
+
|
320 |
+
# 1. Retrieve lora scale.
|
321 |
+
lora_scale = (
|
322 |
+
cross_attention_kwargs.get("scale", 1.0)
|
323 |
+
if cross_attention_kwargs is not None
|
324 |
+
else 1.0
|
325 |
+
)
|
326 |
+
|
327 |
+
# 2. Prepare GLIGEN inputs
|
328 |
+
cross_attention_kwargs = (
|
329 |
+
cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
330 |
+
)
|
331 |
+
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
|
332 |
+
|
333 |
+
attn_output = self.attn1(
|
334 |
+
norm_hidden_states,
|
335 |
+
encoder_hidden_states=(
|
336 |
+
encoder_hidden_states if self.only_cross_attention else None
|
337 |
+
),
|
338 |
+
attention_mask=attention_mask,
|
339 |
+
**cross_attention_kwargs,
|
340 |
+
)
|
341 |
+
if self.use_ada_layer_norm_zero:
|
342 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
343 |
+
elif self.use_ada_layer_norm_single:
|
344 |
+
attn_output = gate_msa * attn_output
|
345 |
+
|
346 |
+
hidden_states = attn_output + hidden_states
|
347 |
+
if hidden_states.ndim == 4:
|
348 |
+
hidden_states = hidden_states.squeeze(1)
|
349 |
+
|
350 |
+
# 2.5 GLIGEN Control
|
351 |
+
if gligen_kwargs is not None:
|
352 |
+
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
|
353 |
+
|
354 |
+
# 3. Cross-Attention
|
355 |
+
if self.attn2 is not None:
|
356 |
+
if self.use_ada_layer_norm:
|
357 |
+
norm_hidden_states = self.norm2(hidden_states, timestep)
|
358 |
+
elif self.use_ada_layer_norm_zero or self.use_layer_norm:
|
359 |
+
norm_hidden_states = self.norm2(hidden_states)
|
360 |
+
elif self.use_ada_layer_norm_single:
|
361 |
+
# For PixArt norm2 isn't applied here:
|
362 |
+
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
|
363 |
+
norm_hidden_states = hidden_states
|
364 |
+
else:
|
365 |
+
raise ValueError("Incorrect norm")
|
366 |
+
|
367 |
+
if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
|
368 |
+
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
369 |
+
|
370 |
+
attn_output = self.attn2(
|
371 |
+
norm_hidden_states,
|
372 |
+
encoder_hidden_states=encoder_hidden_states,
|
373 |
+
attention_mask=encoder_attention_mask,
|
374 |
+
**cross_attention_kwargs,
|
375 |
+
)
|
376 |
+
hidden_states = attn_output + hidden_states
|
377 |
+
|
378 |
+
# 4. Feed-forward
|
379 |
+
if not self.use_ada_layer_norm_single:
|
380 |
+
norm_hidden_states = self.norm3(hidden_states)
|
381 |
+
|
382 |
+
if self.use_ada_layer_norm_zero:
|
383 |
+
norm_hidden_states = (
|
384 |
+
norm_hidden_states *
|
385 |
+
(1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
386 |
+
)
|
387 |
+
|
388 |
+
if self.use_ada_layer_norm_single:
|
389 |
+
norm_hidden_states = self.norm2(hidden_states)
|
390 |
+
norm_hidden_states = norm_hidden_states * \
|
391 |
+
(1 + scale_mlp) + shift_mlp
|
392 |
+
|
393 |
+
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
|
394 |
+
|
395 |
+
if self.use_ada_layer_norm_zero:
|
396 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
397 |
+
elif self.use_ada_layer_norm_single:
|
398 |
+
ff_output = gate_mlp * ff_output
|
399 |
+
|
400 |
+
hidden_states = ff_output + hidden_states
|
401 |
+
if hidden_states.ndim == 4:
|
402 |
+
hidden_states = hidden_states.squeeze(1)
|
403 |
+
|
404 |
+
return hidden_states
|
405 |
+
|
406 |
+
|
407 |
+
class TemporalBasicTransformerBlock(nn.Module):
|
408 |
+
"""
|
409 |
+
A PyTorch module that extends the BasicTransformerBlock to include temporal attention mechanisms.
|
410 |
+
This class is particularly useful for video-related tasks where capturing temporal information within the sequence of frames is necessary.
|
411 |
+
|
412 |
+
Attributes:
|
413 |
+
dim (int): The dimension of the input and output embeddings.
|
414 |
+
num_attention_heads (int): The number of attention heads in the multi-head self-attention mechanism.
|
415 |
+
attention_head_dim (int): The dimension of each attention head.
|
416 |
+
dropout (float): The dropout probability for the attention scores.
|
417 |
+
cross_attention_dim (Optional[int]): The dimension of the cross-attention mechanism.
|
418 |
+
activation_fn (str): The activation function used in the feed-forward layer.
|
419 |
+
num_embeds_ada_norm (Optional[int]): The number of embeddings for adaptive normalization.
|
420 |
+
attention_bias (bool): If True, uses bias in the attention mechanism.
|
421 |
+
only_cross_attention (bool): If True, only uses cross-attention.
|
422 |
+
upcast_attention (bool): If True, upcasts the attention mechanism for better performance.
|
423 |
+
unet_use_cross_frame_attention (Optional[bool]): If True, uses cross-frame attention in the UNet model.
|
424 |
+
unet_use_temporal_attention (Optional[bool]): If True, uses temporal attention in the UNet model.
|
425 |
+
"""
|
426 |
+
def __init__(
|
427 |
+
self,
|
428 |
+
dim: int,
|
429 |
+
num_attention_heads: int,
|
430 |
+
attention_head_dim: int,
|
431 |
+
dropout=0.0,
|
432 |
+
cross_attention_dim: Optional[int] = None,
|
433 |
+
activation_fn: str = "geglu",
|
434 |
+
num_embeds_ada_norm: Optional[int] = None,
|
435 |
+
attention_bias: bool = False,
|
436 |
+
only_cross_attention: bool = False,
|
437 |
+
upcast_attention: bool = False,
|
438 |
+
unet_use_cross_frame_attention=None,
|
439 |
+
unet_use_temporal_attention=None,
|
440 |
+
):
|
441 |
+
"""
|
442 |
+
The TemporalBasicTransformerBlock class is a PyTorch module that extends the BasicTransformerBlock to include temporal attention mechanisms.
|
443 |
+
This is particularly useful for video-related tasks, where the model needs to capture the temporal information within the sequence of frames.
|
444 |
+
The block consists of self-attention, cross-attention, feed-forward, and temporal attention mechanisms.
|
445 |
+
|
446 |
+
dim (int): The dimension of the input and output embeddings.
|
447 |
+
num_attention_heads (int): The number of attention heads in the multi-head self-attention mechanism.
|
448 |
+
attention_head_dim (int): The dimension of each attention head.
|
449 |
+
dropout (float, optional): The dropout probability for the attention scores. Defaults to 0.0.
|
450 |
+
cross_attention_dim (int, optional): The dimension of the cross-attention mechanism. Defaults to None.
|
451 |
+
activation_fn (str, optional): The activation function used in the feed-forward layer. Defaults to "geglu".
|
452 |
+
num_embeds_ada_norm (int, optional): The number of embeddings for adaptive normalization. Defaults to None.
|
453 |
+
attention_bias (bool, optional): If True, uses bias in the attention mechanism. Defaults to False.
|
454 |
+
only_cross_attention (bool, optional): If True, only uses cross-attention. Defaults to False.
|
455 |
+
upcast_attention (bool, optional): If True, upcasts the attention mechanism for better performance. Defaults to False.
|
456 |
+
unet_use_cross_frame_attention (bool, optional): If True, uses cross-frame attention in the UNet model. Defaults to None.
|
457 |
+
unet_use_temporal_attention (bool, optional): If True, uses temporal attention in the UNet model. Defaults to None.
|
458 |
+
|
459 |
+
Forward method:
|
460 |
+
hidden_states (torch.FloatTensor): The input hidden states.
|
461 |
+
encoder_hidden_states (torch.FloatTensor, optional): The encoder hidden states. Defaults to None.
|
462 |
+
timestep (torch.LongTensor, optional): The current timestep for the transformer model. Defaults to None.
|
463 |
+
attention_mask (torch.FloatTensor, optional): The attention mask for the self-attention mechanism. Defaults to None.
|
464 |
+
video_length (int, optional): The length of the video sequence. Defaults to None.
|
465 |
+
|
466 |
+
Returns:
|
467 |
+
torch.FloatTensor: The output hidden states after passing through the TemporalBasicTransformerBlock.
|
468 |
+
"""
|
469 |
+
super().__init__()
|
470 |
+
self.only_cross_attention = only_cross_attention
|
471 |
+
self.use_ada_layer_norm = num_embeds_ada_norm is not None
|
472 |
+
self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
|
473 |
+
self.unet_use_temporal_attention = unet_use_temporal_attention
|
474 |
+
|
475 |
+
# SC-Attn
|
476 |
+
self.attn1 = Attention(
|
477 |
+
query_dim=dim,
|
478 |
+
heads=num_attention_heads,
|
479 |
+
dim_head=attention_head_dim,
|
480 |
+
dropout=dropout,
|
481 |
+
bias=attention_bias,
|
482 |
+
upcast_attention=upcast_attention,
|
483 |
+
)
|
484 |
+
self.norm1 = (
|
485 |
+
AdaLayerNorm(dim, num_embeds_ada_norm)
|
486 |
+
if self.use_ada_layer_norm
|
487 |
+
else nn.LayerNorm(dim)
|
488 |
+
)
|
489 |
+
|
490 |
+
# Cross-Attn
|
491 |
+
if cross_attention_dim is not None:
|
492 |
+
self.attn2 = Attention(
|
493 |
+
query_dim=dim,
|
494 |
+
cross_attention_dim=cross_attention_dim,
|
495 |
+
heads=num_attention_heads,
|
496 |
+
dim_head=attention_head_dim,
|
497 |
+
dropout=dropout,
|
498 |
+
bias=attention_bias,
|
499 |
+
upcast_attention=upcast_attention,
|
500 |
+
)
|
501 |
+
else:
|
502 |
+
self.attn2 = None
|
503 |
+
|
504 |
+
if cross_attention_dim is not None:
|
505 |
+
self.norm2 = (
|
506 |
+
AdaLayerNorm(dim, num_embeds_ada_norm)
|
507 |
+
if self.use_ada_layer_norm
|
508 |
+
else nn.LayerNorm(dim)
|
509 |
+
)
|
510 |
+
else:
|
511 |
+
self.norm2 = None
|
512 |
+
|
513 |
+
# Feed-forward
|
514 |
+
self.ff = FeedForward(dim, dropout=dropout,
|
515 |
+
activation_fn=activation_fn)
|
516 |
+
self.norm3 = nn.LayerNorm(dim)
|
517 |
+
self.use_ada_layer_norm_zero = False
|
518 |
+
|
519 |
+
# Temp-Attn
|
520 |
+
# assert unet_use_temporal_attention is not None
|
521 |
+
if unet_use_temporal_attention is None:
|
522 |
+
unet_use_temporal_attention = False
|
523 |
+
if unet_use_temporal_attention:
|
524 |
+
self.attn_temp = Attention(
|
525 |
+
query_dim=dim,
|
526 |
+
heads=num_attention_heads,
|
527 |
+
dim_head=attention_head_dim,
|
528 |
+
dropout=dropout,
|
529 |
+
bias=attention_bias,
|
530 |
+
upcast_attention=upcast_attention,
|
531 |
+
)
|
532 |
+
nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
|
533 |
+
self.norm_temp = (
|
534 |
+
AdaLayerNorm(dim, num_embeds_ada_norm)
|
535 |
+
if self.use_ada_layer_norm
|
536 |
+
else nn.LayerNorm(dim)
|
537 |
+
)
|
538 |
+
|
539 |
+
def forward(
|
540 |
+
self,
|
541 |
+
hidden_states,
|
542 |
+
encoder_hidden_states=None,
|
543 |
+
timestep=None,
|
544 |
+
attention_mask=None,
|
545 |
+
video_length=None,
|
546 |
+
):
|
547 |
+
"""
|
548 |
+
Forward pass for the TemporalBasicTransformerBlock.
|
549 |
+
|
550 |
+
Args:
|
551 |
+
hidden_states (torch.FloatTensor): The input hidden states with shape (batch_size, seq_len, dim).
|
552 |
+
encoder_hidden_states (torch.FloatTensor, optional): The encoder hidden states with shape (batch_size, src_seq_len, dim).
|
553 |
+
timestep (torch.LongTensor, optional): The timestep for the transformer block.
|
554 |
+
attention_mask (torch.FloatTensor, optional): The attention mask with shape (batch_size, seq_len, seq_len).
|
555 |
+
video_length (int, optional): The length of the video sequence.
|
556 |
+
|
557 |
+
Returns:
|
558 |
+
torch.FloatTensor: The output tensor after passing through the transformer block with shape (batch_size, seq_len, dim).
|
559 |
+
"""
|
560 |
+
norm_hidden_states = (
|
561 |
+
self.norm1(hidden_states, timestep)
|
562 |
+
if self.use_ada_layer_norm
|
563 |
+
else self.norm1(hidden_states)
|
564 |
+
)
|
565 |
+
|
566 |
+
if self.unet_use_cross_frame_attention:
|
567 |
+
hidden_states = (
|
568 |
+
self.attn1(
|
569 |
+
norm_hidden_states,
|
570 |
+
attention_mask=attention_mask,
|
571 |
+
video_length=video_length,
|
572 |
+
)
|
573 |
+
+ hidden_states
|
574 |
+
)
|
575 |
+
else:
|
576 |
+
hidden_states = (
|
577 |
+
self.attn1(norm_hidden_states, attention_mask=attention_mask)
|
578 |
+
+ hidden_states
|
579 |
+
)
|
580 |
+
|
581 |
+
if self.attn2 is not None:
|
582 |
+
# Cross-Attention
|
583 |
+
norm_hidden_states = (
|
584 |
+
self.norm2(hidden_states, timestep)
|
585 |
+
if self.use_ada_layer_norm
|
586 |
+
else self.norm2(hidden_states)
|
587 |
+
)
|
588 |
+
hidden_states = (
|
589 |
+
self.attn2(
|
590 |
+
norm_hidden_states,
|
591 |
+
encoder_hidden_states=encoder_hidden_states,
|
592 |
+
attention_mask=attention_mask,
|
593 |
+
)
|
594 |
+
+ hidden_states
|
595 |
+
)
|
596 |
+
|
597 |
+
# Feed-forward
|
598 |
+
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
599 |
+
|
600 |
+
# Temporal-Attention
|
601 |
+
if self.unet_use_temporal_attention:
|
602 |
+
d = hidden_states.shape[1]
|
603 |
+
hidden_states = rearrange(
|
604 |
+
hidden_states, "(b f) d c -> (b d) f c", f=video_length
|
605 |
+
)
|
606 |
+
norm_hidden_states = (
|
607 |
+
self.norm_temp(hidden_states, timestep)
|
608 |
+
if self.use_ada_layer_norm
|
609 |
+
else self.norm_temp(hidden_states)
|
610 |
+
)
|
611 |
+
hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
|
612 |
+
hidden_states = rearrange(
|
613 |
+
hidden_states, "(b d) f c -> (b f) d c", d=d)
|
614 |
+
|
615 |
+
return hidden_states
|
616 |
+
|
617 |
+
|
618 |
+
class AudioTemporalBasicTransformerBlock(nn.Module):
|
619 |
+
"""
|
620 |
+
A PyTorch module designed to handle audio data within a transformer framework, including temporal attention mechanisms.
|
621 |
+
|
622 |
+
Attributes:
|
623 |
+
dim (int): The dimension of the input and output embeddings.
|
624 |
+
num_attention_heads (int): The number of attention heads.
|
625 |
+
attention_head_dim (int): The dimension of each attention head.
|
626 |
+
dropout (float): The dropout probability.
|
627 |
+
cross_attention_dim (Optional[int]): The dimension of the cross-attention mechanism.
|
628 |
+
activation_fn (str): The activation function for the feed-forward network.
|
629 |
+
num_embeds_ada_norm (Optional[int]): The number of embeddings for adaptive normalization.
|
630 |
+
attention_bias (bool): If True, uses bias in the attention mechanism.
|
631 |
+
only_cross_attention (bool): If True, only uses cross-attention.
|
632 |
+
upcast_attention (bool): If True, upcasts the attention mechanism to float32.
|
633 |
+
unet_use_cross_frame_attention (Optional[bool]): If True, uses cross-frame attention in UNet.
|
634 |
+
unet_use_temporal_attention (Optional[bool]): If True, uses temporal attention in UNet.
|
635 |
+
depth (int): The depth of the transformer block.
|
636 |
+
unet_block_name (Optional[str]): The name of the UNet block.
|
637 |
+
stack_enable_blocks_name (Optional[List[str]]): The list of enabled blocks in the stack.
|
638 |
+
stack_enable_blocks_depth (Optional[List[int]]): The list of depths for the enabled blocks in the stack.
|
639 |
+
"""
|
640 |
+
def __init__(
|
641 |
+
self,
|
642 |
+
dim: int,
|
643 |
+
num_attention_heads: int,
|
644 |
+
attention_head_dim: int,
|
645 |
+
dropout=0.0,
|
646 |
+
cross_attention_dim: Optional[int] = None,
|
647 |
+
activation_fn: str = "geglu",
|
648 |
+
num_embeds_ada_norm: Optional[int] = None,
|
649 |
+
attention_bias: bool = False,
|
650 |
+
only_cross_attention: bool = False,
|
651 |
+
upcast_attention: bool = False,
|
652 |
+
unet_use_cross_frame_attention=None,
|
653 |
+
unet_use_temporal_attention=None,
|
654 |
+
depth=0,
|
655 |
+
unet_block_name=None,
|
656 |
+
stack_enable_blocks_name: Optional[List[str]] = None,
|
657 |
+
stack_enable_blocks_depth: Optional[List[int]] = None,
|
658 |
+
):
|
659 |
+
"""
|
660 |
+
Initializes the AudioTemporalBasicTransformerBlock module.
|
661 |
+
|
662 |
+
Args:
|
663 |
+
dim (int): The dimension of the input and output embeddings.
|
664 |
+
num_attention_heads (int): The number of attention heads in the multi-head self-attention mechanism.
|
665 |
+
attention_head_dim (int): The dimension of each attention head.
|
666 |
+
dropout (float, optional): The dropout probability for the attention mechanism. Defaults to 0.0.
|
667 |
+
cross_attention_dim (Optional[int], optional): The dimension of the cross-attention mechanism. Defaults to None.
|
668 |
+
activation_fn (str, optional): The activation function to be used in the feed-forward network. Defaults to "geglu".
|
669 |
+
num_embeds_ada_norm (Optional[int], optional): The number of embeddings for adaptive normalization. Defaults to None.
|
670 |
+
attention_bias (bool, optional): If True, uses bias in the attention mechanism. Defaults to False.
|
671 |
+
only_cross_attention (bool, optional): If True, only uses cross-attention. Defaults to False.
|
672 |
+
upcast_attention (bool, optional): If True, upcasts the attention mechanism to float32. Defaults to False.
|
673 |
+
unet_use_cross_frame_attention (Optional[bool], optional): If True, uses cross-frame attention in UNet. Defaults to None.
|
674 |
+
unet_use_temporal_attention (Optional[bool], optional): If True, uses temporal attention in UNet. Defaults to None.
|
675 |
+
depth (int, optional): The depth of the transformer block. Defaults to 0.
|
676 |
+
unet_block_name (Optional[str], optional): The name of the UNet block. Defaults to None.
|
677 |
+
stack_enable_blocks_name (Optional[List[str]], optional): The list of enabled blocks in the stack. Defaults to None.
|
678 |
+
stack_enable_blocks_depth (Optional[List[int]], optional): The list of depths for the enabled blocks in the stack. Defaults to None.
|
679 |
+
"""
|
680 |
+
super().__init__()
|
681 |
+
self.only_cross_attention = only_cross_attention
|
682 |
+
self.use_ada_layer_norm = num_embeds_ada_norm is not None
|
683 |
+
self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
|
684 |
+
self.unet_use_temporal_attention = unet_use_temporal_attention
|
685 |
+
self.unet_block_name = unet_block_name
|
686 |
+
self.depth = depth
|
687 |
+
|
688 |
+
zero_conv_full = nn.Conv2d(
|
689 |
+
dim, dim, kernel_size=1)
|
690 |
+
self.zero_conv_full = zero_module(zero_conv_full)
|
691 |
+
|
692 |
+
zero_conv_face = nn.Conv2d(
|
693 |
+
dim, dim, kernel_size=1)
|
694 |
+
self.zero_conv_face = zero_module(zero_conv_face)
|
695 |
+
|
696 |
+
zero_conv_lip = nn.Conv2d(
|
697 |
+
dim, dim, kernel_size=1)
|
698 |
+
self.zero_conv_lip = zero_module(zero_conv_lip)
|
699 |
+
# SC-Attn
|
700 |
+
self.attn1 = Attention(
|
701 |
+
query_dim=dim,
|
702 |
+
heads=num_attention_heads,
|
703 |
+
dim_head=attention_head_dim,
|
704 |
+
dropout=dropout,
|
705 |
+
bias=attention_bias,
|
706 |
+
upcast_attention=upcast_attention,
|
707 |
+
)
|
708 |
+
self.norm1 = (
|
709 |
+
AdaLayerNorm(dim, num_embeds_ada_norm)
|
710 |
+
if self.use_ada_layer_norm
|
711 |
+
else nn.LayerNorm(dim)
|
712 |
+
)
|
713 |
+
|
714 |
+
# Cross-Attn
|
715 |
+
if cross_attention_dim is not None:
|
716 |
+
if (stack_enable_blocks_name is not None and
|
717 |
+
stack_enable_blocks_depth is not None and
|
718 |
+
self.unet_block_name in stack_enable_blocks_name and
|
719 |
+
self.depth in stack_enable_blocks_depth):
|
720 |
+
self.attn2_0 = Attention(
|
721 |
+
query_dim=dim,
|
722 |
+
cross_attention_dim=cross_attention_dim,
|
723 |
+
heads=num_attention_heads,
|
724 |
+
dim_head=attention_head_dim,
|
725 |
+
dropout=dropout,
|
726 |
+
bias=attention_bias,
|
727 |
+
upcast_attention=upcast_attention,
|
728 |
+
)
|
729 |
+
self.attn2 = None
|
730 |
+
|
731 |
+
else:
|
732 |
+
self.attn2 = Attention(
|
733 |
+
query_dim=dim,
|
734 |
+
cross_attention_dim=cross_attention_dim,
|
735 |
+
heads=num_attention_heads,
|
736 |
+
dim_head=attention_head_dim,
|
737 |
+
dropout=dropout,
|
738 |
+
bias=attention_bias,
|
739 |
+
upcast_attention=upcast_attention,
|
740 |
+
)
|
741 |
+
self.attn2_0=None
|
742 |
+
else:
|
743 |
+
self.attn2 = None
|
744 |
+
self.attn2_0 = None
|
745 |
+
|
746 |
+
if cross_attention_dim is not None:
|
747 |
+
self.norm2 = (
|
748 |
+
AdaLayerNorm(dim, num_embeds_ada_norm)
|
749 |
+
if self.use_ada_layer_norm
|
750 |
+
else nn.LayerNorm(dim)
|
751 |
+
)
|
752 |
+
else:
|
753 |
+
self.norm2 = None
|
754 |
+
|
755 |
+
# Feed-forward
|
756 |
+
self.ff = FeedForward(dim, dropout=dropout,
|
757 |
+
activation_fn=activation_fn)
|
758 |
+
self.norm3 = nn.LayerNorm(dim)
|
759 |
+
self.use_ada_layer_norm_zero = False
|
760 |
+
|
761 |
+
|
762 |
+
|
763 |
+
def forward(
|
764 |
+
self,
|
765 |
+
hidden_states,
|
766 |
+
encoder_hidden_states=None,
|
767 |
+
timestep=None,
|
768 |
+
attention_mask=None,
|
769 |
+
full_mask=None,
|
770 |
+
face_mask=None,
|
771 |
+
lip_mask=None,
|
772 |
+
motion_scale=None,
|
773 |
+
video_length=None,
|
774 |
+
):
|
775 |
+
"""
|
776 |
+
Forward pass for the AudioTemporalBasicTransformerBlock.
|
777 |
+
|
778 |
+
Args:
|
779 |
+
hidden_states (torch.FloatTensor): The input hidden states.
|
780 |
+
encoder_hidden_states (torch.FloatTensor, optional): The encoder hidden states. Defaults to None.
|
781 |
+
timestep (torch.LongTensor, optional): The timestep for the transformer block. Defaults to None.
|
782 |
+
attention_mask (torch.FloatTensor, optional): The attention mask. Defaults to None.
|
783 |
+
full_mask (torch.FloatTensor, optional): The full mask. Defaults to None.
|
784 |
+
face_mask (torch.FloatTensor, optional): The face mask. Defaults to None.
|
785 |
+
lip_mask (torch.FloatTensor, optional): The lip mask. Defaults to None.
|
786 |
+
video_length (int, optional): The length of the video. Defaults to None.
|
787 |
+
|
788 |
+
Returns:
|
789 |
+
torch.FloatTensor: The output tensor after passing through the AudioTemporalBasicTransformerBlock.
|
790 |
+
"""
|
791 |
+
norm_hidden_states = (
|
792 |
+
self.norm1(hidden_states, timestep)
|
793 |
+
if self.use_ada_layer_norm
|
794 |
+
else self.norm1(hidden_states)
|
795 |
+
)
|
796 |
+
|
797 |
+
if self.unet_use_cross_frame_attention:
|
798 |
+
hidden_states = (
|
799 |
+
self.attn1(
|
800 |
+
norm_hidden_states,
|
801 |
+
attention_mask=attention_mask,
|
802 |
+
video_length=video_length,
|
803 |
+
)
|
804 |
+
+ hidden_states
|
805 |
+
)
|
806 |
+
else:
|
807 |
+
hidden_states = (
|
808 |
+
self.attn1(norm_hidden_states, attention_mask=attention_mask)
|
809 |
+
+ hidden_states
|
810 |
+
)
|
811 |
+
|
812 |
+
if self.attn2 is not None:
|
813 |
+
# Cross-Attention
|
814 |
+
norm_hidden_states = (
|
815 |
+
self.norm2(hidden_states, timestep)
|
816 |
+
if self.use_ada_layer_norm
|
817 |
+
else self.norm2(hidden_states)
|
818 |
+
)
|
819 |
+
hidden_states = self.attn2(
|
820 |
+
norm_hidden_states,
|
821 |
+
encoder_hidden_states=encoder_hidden_states,
|
822 |
+
attention_mask=attention_mask,
|
823 |
+
) + hidden_states
|
824 |
+
|
825 |
+
elif self.attn2_0 is not None:
|
826 |
+
norm_hidden_states = (
|
827 |
+
self.norm2(hidden_states, timestep)
|
828 |
+
if self.use_ada_layer_norm
|
829 |
+
else self.norm2(hidden_states)
|
830 |
+
)
|
831 |
+
|
832 |
+
level = self.depth
|
833 |
+
all_hidden_states = self.attn2_0(
|
834 |
+
norm_hidden_states,
|
835 |
+
encoder_hidden_states=encoder_hidden_states,
|
836 |
+
attention_mask=attention_mask,
|
837 |
+
)
|
838 |
+
|
839 |
+
full_hidden_states = (
|
840 |
+
all_hidden_states * full_mask[level][:, :, None]
|
841 |
+
)
|
842 |
+
bz, sz, c = full_hidden_states.shape
|
843 |
+
sz_sqrt = int(sz ** 0.5)
|
844 |
+
full_hidden_states = full_hidden_states.reshape(
|
845 |
+
bz, sz_sqrt, sz_sqrt, c).permute(0, 3, 1, 2)
|
846 |
+
full_hidden_states = self.zero_conv_full(full_hidden_states).permute(0, 2, 3, 1).reshape(bz, -1, c)
|
847 |
+
|
848 |
+
face_hidden_state = (
|
849 |
+
all_hidden_states * face_mask[level][:, :, None]
|
850 |
+
)
|
851 |
+
face_hidden_state = face_hidden_state.reshape(
|
852 |
+
bz, sz_sqrt, sz_sqrt, c).permute(0, 3, 1, 2)
|
853 |
+
face_hidden_state = self.zero_conv_face(
|
854 |
+
face_hidden_state).permute(0, 2, 3, 1).reshape(bz, -1, c)
|
855 |
+
|
856 |
+
lip_hidden_state = (
|
857 |
+
all_hidden_states * lip_mask[level][:, :, None]
|
858 |
+
) # [32, 4096, 320]
|
859 |
+
lip_hidden_state = lip_hidden_state.reshape(
|
860 |
+
bz, sz_sqrt, sz_sqrt, c).permute(0, 3, 1, 2)
|
861 |
+
lip_hidden_state = self.zero_conv_lip(
|
862 |
+
lip_hidden_state).permute(0, 2, 3, 1).reshape(bz, -1, c)
|
863 |
+
|
864 |
+
if motion_scale is not None:
|
865 |
+
hidden_states = (
|
866 |
+
motion_scale[0] * full_hidden_states +
|
867 |
+
motion_scale[1] * face_hidden_state +
|
868 |
+
motion_scale[2] * lip_hidden_state + hidden_states
|
869 |
+
)
|
870 |
+
else:
|
871 |
+
hidden_states = (
|
872 |
+
full_hidden_states +
|
873 |
+
face_hidden_state +
|
874 |
+
lip_hidden_state + hidden_states
|
875 |
+
)
|
876 |
+
# Feed-forward
|
877 |
+
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
878 |
+
|
879 |
+
return hidden_states
|
880 |
+
|
881 |
+
def zero_module(module):
|
882 |
+
"""
|
883 |
+
Zeroes out the parameters of a given module.
|
884 |
+
|
885 |
+
Args:
|
886 |
+
module (nn.Module): The module whose parameters need to be zeroed out.
|
887 |
+
|
888 |
+
Returns:
|
889 |
+
None.
|
890 |
+
"""
|
891 |
+
for p in module.parameters():
|
892 |
+
nn.init.zeros_(p)
|
893 |
+
return module
|
joyhallo/models/audio_proj.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This module provides the implementation of an Audio Projection Model, which is designed for
|
3 |
+
audio processing tasks. The model takes audio embeddings as input and outputs context tokens
|
4 |
+
that can be used for various downstream applications, such as audio analysis or synthesis.
|
5 |
+
|
6 |
+
The AudioProjModel class is based on the ModelMixin class from the diffusers library, which
|
7 |
+
provides a foundation for building custom models. This implementation includes multiple linear
|
8 |
+
layers with ReLU activation functions and a LayerNorm for normalization.
|
9 |
+
|
10 |
+
Key Features:
|
11 |
+
- Audio embedding input with flexible sequence length and block structure.
|
12 |
+
- Multiple linear layers for feature transformation.
|
13 |
+
- ReLU activation for non-linear transformation.
|
14 |
+
- LayerNorm for stabilizing and speeding up training.
|
15 |
+
- Rearrangement of input embeddings to match the model's expected input shape.
|
16 |
+
- Customizable number of blocks, channels, and context tokens for adaptability.
|
17 |
+
|
18 |
+
The module is structured to be easily integrated into larger systems or used as a standalone
|
19 |
+
component for audio feature extraction and processing.
|
20 |
+
|
21 |
+
Classes:
|
22 |
+
- AudioProjModel: A class representing the audio projection model with configurable parameters.
|
23 |
+
|
24 |
+
Functions:
|
25 |
+
- (none)
|
26 |
+
|
27 |
+
Dependencies:
|
28 |
+
- torch: For tensor operations and neural network components.
|
29 |
+
- diffusers: For the ModelMixin base class.
|
30 |
+
- einops: For tensor rearrangement operations.
|
31 |
+
|
32 |
+
"""
|
33 |
+
|
34 |
+
import torch
|
35 |
+
from diffusers import ModelMixin
|
36 |
+
from einops import rearrange
|
37 |
+
from torch import nn
|
38 |
+
|
39 |
+
|
40 |
+
class AudioProjModel(ModelMixin):
|
41 |
+
"""Audio Projection Model
|
42 |
+
|
43 |
+
This class defines an audio projection model that takes audio embeddings as input
|
44 |
+
and produces context tokens as output. The model is based on the ModelMixin class
|
45 |
+
and consists of multiple linear layers and activation functions. It can be used
|
46 |
+
for various audio processing tasks.
|
47 |
+
|
48 |
+
Attributes:
|
49 |
+
seq_len (int): The length of the audio sequence.
|
50 |
+
blocks (int): The number of blocks in the audio projection model.
|
51 |
+
channels (int): The number of channels in the audio projection model.
|
52 |
+
intermediate_dim (int): The intermediate dimension of the model.
|
53 |
+
context_tokens (int): The number of context tokens in the output.
|
54 |
+
output_dim (int): The output dimension of the context tokens.
|
55 |
+
|
56 |
+
Methods:
|
57 |
+
__init__(self, seq_len=5, blocks=12, channels=768, intermediate_dim=512, context_tokens=32, output_dim=768):
|
58 |
+
Initializes the AudioProjModel with the given parameters.
|
59 |
+
forward(self, audio_embeds):
|
60 |
+
Defines the forward pass for the AudioProjModel.
|
61 |
+
Parameters:
|
62 |
+
audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, blocks, channels).
|
63 |
+
Returns:
|
64 |
+
context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim).
|
65 |
+
|
66 |
+
"""
|
67 |
+
|
68 |
+
def __init__(
|
69 |
+
self,
|
70 |
+
seq_len=5,
|
71 |
+
blocks=12, # add a new parameter blocks
|
72 |
+
channels=768, # add a new parameter channels
|
73 |
+
intermediate_dim=512,
|
74 |
+
output_dim=768,
|
75 |
+
context_tokens=32,
|
76 |
+
):
|
77 |
+
super().__init__()
|
78 |
+
|
79 |
+
self.seq_len = seq_len
|
80 |
+
self.blocks = blocks
|
81 |
+
self.channels = channels
|
82 |
+
self.input_dim = (
|
83 |
+
seq_len * blocks * channels
|
84 |
+
) # update input_dim to be the product of blocks and channels.
|
85 |
+
self.intermediate_dim = intermediate_dim
|
86 |
+
self.context_tokens = context_tokens
|
87 |
+
self.output_dim = output_dim
|
88 |
+
|
89 |
+
# define multiple linear layers
|
90 |
+
self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
|
91 |
+
self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
|
92 |
+
self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
|
93 |
+
|
94 |
+
self.norm = nn.LayerNorm(output_dim)
|
95 |
+
|
96 |
+
def forward(self, audio_embeds):
|
97 |
+
"""
|
98 |
+
Defines the forward pass for the AudioProjModel.
|
99 |
+
|
100 |
+
Parameters:
|
101 |
+
audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, blocks, channels).
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim).
|
105 |
+
"""
|
106 |
+
# merge
|
107 |
+
video_length = audio_embeds.shape[1]
|
108 |
+
audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
|
109 |
+
batch_size, window_size, blocks, channels = audio_embeds.shape
|
110 |
+
audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
|
111 |
+
|
112 |
+
audio_embeds = torch.relu(self.proj1(audio_embeds))
|
113 |
+
audio_embeds = torch.relu(self.proj2(audio_embeds))
|
114 |
+
|
115 |
+
context_tokens = self.proj3(audio_embeds).reshape(
|
116 |
+
batch_size, self.context_tokens, self.output_dim
|
117 |
+
)
|
118 |
+
|
119 |
+
context_tokens = self.norm(context_tokens)
|
120 |
+
context_tokens = rearrange(
|
121 |
+
context_tokens, "(bz f) m c -> bz f m c", f=video_length
|
122 |
+
)
|
123 |
+
|
124 |
+
return context_tokens
|
joyhallo/models/face_locator.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This module implements the FaceLocator class, which is a neural network model designed to
|
3 |
+
locate and extract facial features from input images or tensors. It uses a series of
|
4 |
+
convolutional layers to progressively downsample and refine the facial feature map.
|
5 |
+
|
6 |
+
The FaceLocator class is part of a larger system that may involve facial recognition or
|
7 |
+
similar tasks where precise location and extraction of facial features are required.
|
8 |
+
|
9 |
+
Attributes:
|
10 |
+
conditioning_embedding_channels (int): The number of channels in the output embedding.
|
11 |
+
conditioning_channels (int): The number of input channels for the conditioning tensor.
|
12 |
+
block_out_channels (Tuple[int]): A tuple of integers representing the output channels
|
13 |
+
for each block in the model.
|
14 |
+
|
15 |
+
The model uses the following components:
|
16 |
+
- InflatedConv3d: A convolutional layer that inflates the input to increase the depth.
|
17 |
+
- zero_module: A utility function that may set certain parameters to zero for regularization
|
18 |
+
or other purposes.
|
19 |
+
|
20 |
+
The forward method of the FaceLocator class takes a conditioning tensor as input and
|
21 |
+
produces an embedding tensor as output, which can be used for further processing or analysis.
|
22 |
+
"""
|
23 |
+
|
24 |
+
from typing import Tuple
|
25 |
+
|
26 |
+
import torch.nn.functional as F
|
27 |
+
from diffusers.models.modeling_utils import ModelMixin
|
28 |
+
from torch import nn
|
29 |
+
|
30 |
+
from .motion_module import zero_module
|
31 |
+
from .resnet import InflatedConv3d
|
32 |
+
|
33 |
+
|
34 |
+
class FaceLocator(ModelMixin):
|
35 |
+
"""
|
36 |
+
The FaceLocator class is a neural network model designed to process and extract facial
|
37 |
+
features from an input tensor. It consists of a series of convolutional layers that
|
38 |
+
progressively downsample the input while increasing the depth of the feature map.
|
39 |
+
|
40 |
+
The model is built using InflatedConv3d layers, which are designed to inflate the
|
41 |
+
feature channels, allowing for more complex feature extraction. The final output is a
|
42 |
+
conditioning embedding that can be used for various tasks such as facial recognition or
|
43 |
+
feature-based image manipulation.
|
44 |
+
|
45 |
+
Parameters:
|
46 |
+
conditioning_embedding_channels (int): The number of channels in the output embedding.
|
47 |
+
conditioning_channels (int, optional): The number of input channels for the conditioning tensor. Default is 3.
|
48 |
+
block_out_channels (Tuple[int], optional): A tuple of integers representing the output channels
|
49 |
+
for each block in the model. The default is (16, 32, 64, 128), which defines the
|
50 |
+
progression of the network's depth.
|
51 |
+
|
52 |
+
Attributes:
|
53 |
+
conv_in (InflatedConv3d): The initial convolutional layer that starts the feature extraction process.
|
54 |
+
blocks (ModuleList[InflatedConv3d]): A list of convolutional layers that form the core of the model.
|
55 |
+
conv_out (InflatedConv3d): The final convolutional layer that produces the output embedding.
|
56 |
+
|
57 |
+
The forward method applies the convolutional layers to the input conditioning tensor and
|
58 |
+
returns the resulting embedding tensor.
|
59 |
+
"""
|
60 |
+
def __init__(
|
61 |
+
self,
|
62 |
+
conditioning_embedding_channels: int,
|
63 |
+
conditioning_channels: int = 3,
|
64 |
+
block_out_channels: Tuple[int] = (16, 32, 64, 128),
|
65 |
+
):
|
66 |
+
super().__init__()
|
67 |
+
self.conv_in = InflatedConv3d(
|
68 |
+
conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
|
69 |
+
)
|
70 |
+
|
71 |
+
self.blocks = nn.ModuleList([])
|
72 |
+
|
73 |
+
for i in range(len(block_out_channels) - 1):
|
74 |
+
channel_in = block_out_channels[i]
|
75 |
+
channel_out = block_out_channels[i + 1]
|
76 |
+
self.blocks.append(
|
77 |
+
InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
|
78 |
+
)
|
79 |
+
self.blocks.append(
|
80 |
+
InflatedConv3d(
|
81 |
+
channel_in, channel_out, kernel_size=3, padding=1, stride=2
|
82 |
+
)
|
83 |
+
)
|
84 |
+
|
85 |
+
self.conv_out = zero_module(
|
86 |
+
InflatedConv3d(
|
87 |
+
block_out_channels[-1],
|
88 |
+
conditioning_embedding_channels,
|
89 |
+
kernel_size=3,
|
90 |
+
padding=1,
|
91 |
+
)
|
92 |
+
)
|
93 |
+
|
94 |
+
def forward(self, conditioning):
|
95 |
+
"""
|
96 |
+
Forward pass of the FaceLocator model.
|
97 |
+
|
98 |
+
Args:
|
99 |
+
conditioning (Tensor): The input conditioning tensor.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
Tensor: The output embedding tensor.
|
103 |
+
"""
|
104 |
+
embedding = self.conv_in(conditioning)
|
105 |
+
embedding = F.silu(embedding)
|
106 |
+
|
107 |
+
for block in self.blocks:
|
108 |
+
embedding = block(embedding)
|
109 |
+
embedding = F.silu(embedding)
|
110 |
+
|
111 |
+
embedding = self.conv_out(embedding)
|
112 |
+
|
113 |
+
return embedding
|
joyhallo/models/image_proj.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
image_proj_model.py
|
3 |
+
|
4 |
+
This module defines the ImageProjModel class, which is responsible for
|
5 |
+
projecting image embeddings into a different dimensional space. The model
|
6 |
+
leverages a linear transformation followed by a layer normalization to
|
7 |
+
reshape and normalize the input image embeddings for further processing in
|
8 |
+
cross-attention mechanisms or other downstream tasks.
|
9 |
+
|
10 |
+
Classes:
|
11 |
+
ImageProjModel
|
12 |
+
|
13 |
+
Dependencies:
|
14 |
+
torch
|
15 |
+
diffusers.ModelMixin
|
16 |
+
|
17 |
+
"""
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from diffusers import ModelMixin
|
21 |
+
|
22 |
+
|
23 |
+
class ImageProjModel(ModelMixin):
|
24 |
+
"""
|
25 |
+
ImageProjModel is a class that projects image embeddings into a different
|
26 |
+
dimensional space. It inherits from ModelMixin, providing additional functionalities
|
27 |
+
specific to image projection.
|
28 |
+
|
29 |
+
Attributes:
|
30 |
+
cross_attention_dim (int): The dimension of the cross attention.
|
31 |
+
clip_embeddings_dim (int): The dimension of the CLIP embeddings.
|
32 |
+
clip_extra_context_tokens (int): The number of extra context tokens in CLIP.
|
33 |
+
|
34 |
+
Methods:
|
35 |
+
forward(image_embeds): Forward pass of the ImageProjModel, which takes in image
|
36 |
+
embeddings and returns the projected tokens.
|
37 |
+
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(
|
41 |
+
self,
|
42 |
+
cross_attention_dim=1024,
|
43 |
+
clip_embeddings_dim=1024,
|
44 |
+
clip_extra_context_tokens=4,
|
45 |
+
):
|
46 |
+
super().__init__()
|
47 |
+
|
48 |
+
self.generator = None
|
49 |
+
self.cross_attention_dim = cross_attention_dim
|
50 |
+
self.clip_extra_context_tokens = clip_extra_context_tokens
|
51 |
+
self.proj = torch.nn.Linear(
|
52 |
+
clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim
|
53 |
+
)
|
54 |
+
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
55 |
+
|
56 |
+
def forward(self, image_embeds):
|
57 |
+
"""
|
58 |
+
Forward pass of the ImageProjModel, which takes in image embeddings and returns the
|
59 |
+
projected tokens after reshaping and normalization.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
image_embeds (torch.Tensor): The input image embeddings, with shape
|
63 |
+
batch_size x num_image_tokens x clip_embeddings_dim.
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
clip_extra_context_tokens (torch.Tensor): The projected tokens after reshaping
|
67 |
+
and normalization, with shape batch_size x (clip_extra_context_tokens *
|
68 |
+
cross_attention_dim).
|
69 |
+
|
70 |
+
"""
|
71 |
+
embeds = image_embeds
|
72 |
+
clip_extra_context_tokens = self.proj(embeds).reshape(
|
73 |
+
-1, self.clip_extra_context_tokens, self.cross_attention_dim
|
74 |
+
)
|
75 |
+
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
|
76 |
+
return clip_extra_context_tokens
|
joyhallo/models/motion_module.py
ADDED
@@ -0,0 +1,605 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
temporal_transformers.py
|
3 |
+
|
4 |
+
This module provides classes and functions for implementing Temporal Transformers
|
5 |
+
in PyTorch, designed for handling video data and temporal sequences within transformer-based models.
|
6 |
+
|
7 |
+
Functions:
|
8 |
+
zero_module(module)
|
9 |
+
Zero out the parameters of a module and return it.
|
10 |
+
|
11 |
+
Classes:
|
12 |
+
TemporalTransformer3DModelOutput(BaseOutput)
|
13 |
+
Dataclass for storing the output of TemporalTransformer3DModel.
|
14 |
+
|
15 |
+
VanillaTemporalModule(nn.Module)
|
16 |
+
A Vanilla Temporal Module class for handling temporal data.
|
17 |
+
|
18 |
+
TemporalTransformer3DModel(nn.Module)
|
19 |
+
A Temporal Transformer 3D Model class for transforming temporal data.
|
20 |
+
|
21 |
+
TemporalTransformerBlock(nn.Module)
|
22 |
+
A Temporal Transformer Block class for building the transformer architecture.
|
23 |
+
|
24 |
+
PositionalEncoding(nn.Module)
|
25 |
+
A Positional Encoding module for transformers to encode positional information.
|
26 |
+
|
27 |
+
Dependencies:
|
28 |
+
math
|
29 |
+
dataclasses.dataclass
|
30 |
+
typing (Callable, Optional)
|
31 |
+
torch
|
32 |
+
diffusers (FeedForward, Attention, AttnProcessor)
|
33 |
+
diffusers.utils (BaseOutput)
|
34 |
+
diffusers.utils.import_utils (is_xformers_available)
|
35 |
+
einops (rearrange, repeat)
|
36 |
+
torch.nn
|
37 |
+
xformers
|
38 |
+
xformers.ops
|
39 |
+
|
40 |
+
Example Usage:
|
41 |
+
>>> motion_module = get_motion_module(in_channels=512, motion_module_type="Vanilla", motion_module_kwargs={})
|
42 |
+
>>> output = motion_module(input_tensor, temb, encoder_hidden_states)
|
43 |
+
|
44 |
+
This module is designed to facilitate the creation, training, and inference of transformer models
|
45 |
+
that operate on temporal data, such as videos or time-series. It includes mechanisms for applying temporal attention,
|
46 |
+
managing positional encoding, and integrating with external libraries for efficient attention operations.
|
47 |
+
"""
|
48 |
+
|
49 |
+
# This code is copied from https://github.com/guoyww/AnimateDiff.
|
50 |
+
|
51 |
+
import math
|
52 |
+
|
53 |
+
import torch
|
54 |
+
import xformers
|
55 |
+
import xformers.ops
|
56 |
+
from diffusers.models.attention import FeedForward
|
57 |
+
from diffusers.models.attention_processor import Attention, AttnProcessor
|
58 |
+
from diffusers.utils import BaseOutput
|
59 |
+
from diffusers.utils.import_utils import is_xformers_available
|
60 |
+
from einops import rearrange, repeat
|
61 |
+
from torch import nn
|
62 |
+
|
63 |
+
|
64 |
+
def zero_module(module):
|
65 |
+
"""
|
66 |
+
Zero out the parameters of a module and return it.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
- module: A PyTorch module to zero out its parameters.
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
A zeroed out PyTorch module.
|
73 |
+
"""
|
74 |
+
for p in module.parameters():
|
75 |
+
p.detach().zero_()
|
76 |
+
return module
|
77 |
+
|
78 |
+
|
79 |
+
class TemporalTransformer3DModelOutput(BaseOutput):
|
80 |
+
"""
|
81 |
+
Output class for the TemporalTransformer3DModel.
|
82 |
+
|
83 |
+
Attributes:
|
84 |
+
sample (torch.FloatTensor): The output sample tensor from the model.
|
85 |
+
"""
|
86 |
+
sample: torch.FloatTensor
|
87 |
+
|
88 |
+
def get_sample_shape(self):
|
89 |
+
"""
|
90 |
+
Returns the shape of the sample tensor.
|
91 |
+
|
92 |
+
Returns:
|
93 |
+
Tuple: The shape of the sample tensor.
|
94 |
+
"""
|
95 |
+
return self.sample.shape
|
96 |
+
|
97 |
+
|
98 |
+
def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
|
99 |
+
"""
|
100 |
+
This function returns a motion module based on the given type and parameters.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
- in_channels (int): The number of input channels for the motion module.
|
104 |
+
- motion_module_type (str): The type of motion module to create. Currently, only "Vanilla" is supported.
|
105 |
+
- motion_module_kwargs (dict): Additional keyword arguments to pass to the motion module constructor.
|
106 |
+
|
107 |
+
Returns:
|
108 |
+
VanillaTemporalModule: The created motion module.
|
109 |
+
|
110 |
+
Raises:
|
111 |
+
ValueError: If an unsupported motion_module_type is provided.
|
112 |
+
"""
|
113 |
+
if motion_module_type == "Vanilla":
|
114 |
+
return VanillaTemporalModule(
|
115 |
+
in_channels=in_channels,
|
116 |
+
**motion_module_kwargs,
|
117 |
+
)
|
118 |
+
|
119 |
+
raise ValueError
|
120 |
+
|
121 |
+
|
122 |
+
class VanillaTemporalModule(nn.Module):
|
123 |
+
"""
|
124 |
+
A Vanilla Temporal Module class.
|
125 |
+
|
126 |
+
Args:
|
127 |
+
- in_channels (int): The number of input channels for the motion module.
|
128 |
+
- num_attention_heads (int): Number of attention heads.
|
129 |
+
- num_transformer_block (int): Number of transformer blocks.
|
130 |
+
- attention_block_types (tuple): Types of attention blocks.
|
131 |
+
- cross_frame_attention_mode: Mode for cross-frame attention.
|
132 |
+
- temporal_position_encoding (bool): Flag for temporal position encoding.
|
133 |
+
- temporal_position_encoding_max_len (int): Maximum length for temporal position encoding.
|
134 |
+
- temporal_attention_dim_div (int): Divisor for temporal attention dimension.
|
135 |
+
- zero_initialize (bool): Flag for zero initialization.
|
136 |
+
"""
|
137 |
+
|
138 |
+
def __init__(
|
139 |
+
self,
|
140 |
+
in_channels,
|
141 |
+
num_attention_heads=8,
|
142 |
+
num_transformer_block=2,
|
143 |
+
attention_block_types=("Temporal_Self", "Temporal_Self"),
|
144 |
+
cross_frame_attention_mode=None,
|
145 |
+
temporal_position_encoding=False,
|
146 |
+
temporal_position_encoding_max_len=24,
|
147 |
+
temporal_attention_dim_div=1,
|
148 |
+
zero_initialize=True,
|
149 |
+
):
|
150 |
+
super().__init__()
|
151 |
+
|
152 |
+
self.temporal_transformer = TemporalTransformer3DModel(
|
153 |
+
in_channels=in_channels,
|
154 |
+
num_attention_heads=num_attention_heads,
|
155 |
+
attention_head_dim=in_channels
|
156 |
+
// num_attention_heads
|
157 |
+
// temporal_attention_dim_div,
|
158 |
+
num_layers=num_transformer_block,
|
159 |
+
attention_block_types=attention_block_types,
|
160 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
161 |
+
temporal_position_encoding=temporal_position_encoding,
|
162 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
163 |
+
)
|
164 |
+
|
165 |
+
if zero_initialize:
|
166 |
+
self.temporal_transformer.proj_out = zero_module(
|
167 |
+
self.temporal_transformer.proj_out
|
168 |
+
)
|
169 |
+
|
170 |
+
def forward(
|
171 |
+
self,
|
172 |
+
input_tensor,
|
173 |
+
encoder_hidden_states,
|
174 |
+
attention_mask=None,
|
175 |
+
):
|
176 |
+
"""
|
177 |
+
Forward pass of the TemporalTransformer3DModel.
|
178 |
+
|
179 |
+
Args:
|
180 |
+
hidden_states (torch.Tensor): The hidden states of the model.
|
181 |
+
encoder_hidden_states (torch.Tensor, optional): The hidden states of the encoder.
|
182 |
+
attention_mask (torch.Tensor, optional): The attention mask.
|
183 |
+
|
184 |
+
Returns:
|
185 |
+
torch.Tensor: The output tensor after the forward pass.
|
186 |
+
"""
|
187 |
+
hidden_states = input_tensor
|
188 |
+
hidden_states = self.temporal_transformer(
|
189 |
+
hidden_states, encoder_hidden_states
|
190 |
+
)
|
191 |
+
|
192 |
+
output = hidden_states
|
193 |
+
return output
|
194 |
+
|
195 |
+
|
196 |
+
class TemporalTransformer3DModel(nn.Module):
|
197 |
+
"""
|
198 |
+
A Temporal Transformer 3D Model class.
|
199 |
+
|
200 |
+
Args:
|
201 |
+
- in_channels (int): The number of input channels.
|
202 |
+
- num_attention_heads (int): Number of attention heads.
|
203 |
+
- attention_head_dim (int): Dimension of attention heads.
|
204 |
+
- num_layers (int): Number of transformer layers.
|
205 |
+
- attention_block_types (tuple): Types of attention blocks.
|
206 |
+
- dropout (float): Dropout rate.
|
207 |
+
- norm_num_groups (int): Number of groups for normalization.
|
208 |
+
- cross_attention_dim (int): Dimension for cross-attention.
|
209 |
+
- activation_fn (str): Activation function.
|
210 |
+
- attention_bias (bool): Flag for attention bias.
|
211 |
+
- upcast_attention (bool): Flag for upcast attention.
|
212 |
+
- cross_frame_attention_mode: Mode for cross-frame attention.
|
213 |
+
- temporal_position_encoding (bool): Flag for temporal position encoding.
|
214 |
+
- temporal_position_encoding_max_len (int): Maximum length for temporal position encoding.
|
215 |
+
"""
|
216 |
+
def __init__(
|
217 |
+
self,
|
218 |
+
in_channels,
|
219 |
+
num_attention_heads,
|
220 |
+
attention_head_dim,
|
221 |
+
num_layers,
|
222 |
+
attention_block_types=(
|
223 |
+
"Temporal_Self",
|
224 |
+
"Temporal_Self",
|
225 |
+
),
|
226 |
+
dropout=0.0,
|
227 |
+
norm_num_groups=32,
|
228 |
+
cross_attention_dim=768,
|
229 |
+
activation_fn="geglu",
|
230 |
+
attention_bias=False,
|
231 |
+
upcast_attention=False,
|
232 |
+
cross_frame_attention_mode=None,
|
233 |
+
temporal_position_encoding=False,
|
234 |
+
temporal_position_encoding_max_len=24,
|
235 |
+
):
|
236 |
+
super().__init__()
|
237 |
+
|
238 |
+
inner_dim = num_attention_heads * attention_head_dim
|
239 |
+
|
240 |
+
self.norm = torch.nn.GroupNorm(
|
241 |
+
num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
|
242 |
+
)
|
243 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
244 |
+
|
245 |
+
self.transformer_blocks = nn.ModuleList(
|
246 |
+
[
|
247 |
+
TemporalTransformerBlock(
|
248 |
+
dim=inner_dim,
|
249 |
+
num_attention_heads=num_attention_heads,
|
250 |
+
attention_head_dim=attention_head_dim,
|
251 |
+
attention_block_types=attention_block_types,
|
252 |
+
dropout=dropout,
|
253 |
+
cross_attention_dim=cross_attention_dim,
|
254 |
+
activation_fn=activation_fn,
|
255 |
+
attention_bias=attention_bias,
|
256 |
+
upcast_attention=upcast_attention,
|
257 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
258 |
+
temporal_position_encoding=temporal_position_encoding,
|
259 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
260 |
+
)
|
261 |
+
for d in range(num_layers)
|
262 |
+
]
|
263 |
+
)
|
264 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
265 |
+
|
266 |
+
def forward(self, hidden_states, encoder_hidden_states=None):
|
267 |
+
"""
|
268 |
+
Forward pass for the TemporalTransformer3DModel.
|
269 |
+
|
270 |
+
Args:
|
271 |
+
hidden_states (torch.Tensor): The input hidden states with shape (batch_size, sequence_length, in_channels).
|
272 |
+
encoder_hidden_states (torch.Tensor, optional): The encoder hidden states with shape (batch_size, encoder_sequence_length, in_channels).
|
273 |
+
|
274 |
+
Returns:
|
275 |
+
torch.Tensor: The output hidden states with shape (batch_size, sequence_length, in_channels).
|
276 |
+
"""
|
277 |
+
assert (
|
278 |
+
hidden_states.dim() == 5
|
279 |
+
), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
|
280 |
+
video_length = hidden_states.shape[2]
|
281 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
282 |
+
|
283 |
+
batch, _, height, weight = hidden_states.shape
|
284 |
+
residual = hidden_states
|
285 |
+
|
286 |
+
hidden_states = self.norm(hidden_states)
|
287 |
+
inner_dim = hidden_states.shape[1]
|
288 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
|
289 |
+
batch, height * weight, inner_dim
|
290 |
+
)
|
291 |
+
hidden_states = self.proj_in(hidden_states)
|
292 |
+
|
293 |
+
# Transformer Blocks
|
294 |
+
for block in self.transformer_blocks:
|
295 |
+
hidden_states = block(
|
296 |
+
hidden_states,
|
297 |
+
encoder_hidden_states=encoder_hidden_states,
|
298 |
+
video_length=video_length,
|
299 |
+
)
|
300 |
+
|
301 |
+
# output
|
302 |
+
hidden_states = self.proj_out(hidden_states)
|
303 |
+
hidden_states = (
|
304 |
+
hidden_states.reshape(batch, height, weight, inner_dim)
|
305 |
+
.permute(0, 3, 1, 2)
|
306 |
+
.contiguous()
|
307 |
+
)
|
308 |
+
|
309 |
+
output = hidden_states + residual
|
310 |
+
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
|
311 |
+
|
312 |
+
return output
|
313 |
+
|
314 |
+
|
315 |
+
class TemporalTransformerBlock(nn.Module):
|
316 |
+
"""
|
317 |
+
A Temporal Transformer Block class.
|
318 |
+
|
319 |
+
Args:
|
320 |
+
- dim (int): Dimension of the block.
|
321 |
+
- num_attention_heads (int): Number of attention heads.
|
322 |
+
- attention_head_dim (int): Dimension of attention heads.
|
323 |
+
- attention_block_types (tuple): Types of attention blocks.
|
324 |
+
- dropout (float): Dropout rate.
|
325 |
+
- cross_attention_dim (int): Dimension for cross-attention.
|
326 |
+
- activation_fn (str): Activation function.
|
327 |
+
- attention_bias (bool): Flag for attention bias.
|
328 |
+
- upcast_attention (bool): Flag for upcast attention.
|
329 |
+
- cross_frame_attention_mode: Mode for cross-frame attention.
|
330 |
+
- temporal_position_encoding (bool): Flag for temporal position encoding.
|
331 |
+
- temporal_position_encoding_max_len (int): Maximum length for temporal position encoding.
|
332 |
+
"""
|
333 |
+
def __init__(
|
334 |
+
self,
|
335 |
+
dim,
|
336 |
+
num_attention_heads,
|
337 |
+
attention_head_dim,
|
338 |
+
attention_block_types=(
|
339 |
+
"Temporal_Self",
|
340 |
+
"Temporal_Self",
|
341 |
+
),
|
342 |
+
dropout=0.0,
|
343 |
+
cross_attention_dim=768,
|
344 |
+
activation_fn="geglu",
|
345 |
+
attention_bias=False,
|
346 |
+
upcast_attention=False,
|
347 |
+
cross_frame_attention_mode=None,
|
348 |
+
temporal_position_encoding=False,
|
349 |
+
temporal_position_encoding_max_len=24,
|
350 |
+
):
|
351 |
+
super().__init__()
|
352 |
+
|
353 |
+
attention_blocks = []
|
354 |
+
norms = []
|
355 |
+
|
356 |
+
for block_name in attention_block_types:
|
357 |
+
attention_blocks.append(
|
358 |
+
VersatileAttention(
|
359 |
+
attention_mode=block_name.split("_", maxsplit=1)[0],
|
360 |
+
cross_attention_dim=cross_attention_dim
|
361 |
+
if block_name.endswith("_Cross")
|
362 |
+
else None,
|
363 |
+
query_dim=dim,
|
364 |
+
heads=num_attention_heads,
|
365 |
+
dim_head=attention_head_dim,
|
366 |
+
dropout=dropout,
|
367 |
+
bias=attention_bias,
|
368 |
+
upcast_attention=upcast_attention,
|
369 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
370 |
+
temporal_position_encoding=temporal_position_encoding,
|
371 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
372 |
+
)
|
373 |
+
)
|
374 |
+
norms.append(nn.LayerNorm(dim))
|
375 |
+
|
376 |
+
self.attention_blocks = nn.ModuleList(attention_blocks)
|
377 |
+
self.norms = nn.ModuleList(norms)
|
378 |
+
|
379 |
+
self.ff = FeedForward(dim, dropout=dropout,
|
380 |
+
activation_fn=activation_fn)
|
381 |
+
self.ff_norm = nn.LayerNorm(dim)
|
382 |
+
|
383 |
+
def forward(
|
384 |
+
self,
|
385 |
+
hidden_states,
|
386 |
+
encoder_hidden_states=None,
|
387 |
+
video_length=None,
|
388 |
+
):
|
389 |
+
"""
|
390 |
+
Forward pass for the TemporalTransformerBlock.
|
391 |
+
|
392 |
+
Args:
|
393 |
+
hidden_states (torch.Tensor): The input hidden states with shape
|
394 |
+
(batch_size, video_length, in_channels).
|
395 |
+
encoder_hidden_states (torch.Tensor, optional): The encoder hidden states
|
396 |
+
with shape (batch_size, encoder_length, in_channels).
|
397 |
+
video_length (int, optional): The length of the video.
|
398 |
+
|
399 |
+
Returns:
|
400 |
+
torch.Tensor: The output hidden states with shape
|
401 |
+
(batch_size, video_length, in_channels).
|
402 |
+
"""
|
403 |
+
for attention_block, norm in zip(self.attention_blocks, self.norms):
|
404 |
+
norm_hidden_states = norm(hidden_states)
|
405 |
+
hidden_states = (
|
406 |
+
attention_block(
|
407 |
+
norm_hidden_states,
|
408 |
+
encoder_hidden_states=encoder_hidden_states
|
409 |
+
if attention_block.is_cross_attention
|
410 |
+
else None,
|
411 |
+
video_length=video_length,
|
412 |
+
)
|
413 |
+
+ hidden_states
|
414 |
+
)
|
415 |
+
|
416 |
+
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
|
417 |
+
|
418 |
+
output = hidden_states
|
419 |
+
return output
|
420 |
+
|
421 |
+
|
422 |
+
class PositionalEncoding(nn.Module):
|
423 |
+
"""
|
424 |
+
Positional Encoding module for transformers.
|
425 |
+
|
426 |
+
Args:
|
427 |
+
- d_model (int): Model dimension.
|
428 |
+
- dropout (float): Dropout rate.
|
429 |
+
- max_len (int): Maximum length for positional encoding.
|
430 |
+
"""
|
431 |
+
def __init__(self, d_model, dropout=0.0, max_len=24):
|
432 |
+
super().__init__()
|
433 |
+
self.dropout = nn.Dropout(p=dropout)
|
434 |
+
position = torch.arange(max_len).unsqueeze(1)
|
435 |
+
div_term = torch.exp(
|
436 |
+
torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
|
437 |
+
)
|
438 |
+
pe = torch.zeros(1, max_len, d_model)
|
439 |
+
pe[0, :, 0::2] = torch.sin(position * div_term)
|
440 |
+
pe[0, :, 1::2] = torch.cos(position * div_term)
|
441 |
+
self.register_buffer("pe", pe)
|
442 |
+
|
443 |
+
def forward(self, x):
|
444 |
+
"""
|
445 |
+
Forward pass of the PositionalEncoding module.
|
446 |
+
|
447 |
+
This method takes an input tensor `x` and adds the positional encoding to it. The positional encoding is
|
448 |
+
generated based on the input tensor's shape and is added to the input tensor element-wise.
|
449 |
+
|
450 |
+
Args:
|
451 |
+
x (torch.Tensor): The input tensor to be positionally encoded.
|
452 |
+
|
453 |
+
Returns:
|
454 |
+
torch.Tensor: The positionally encoded tensor.
|
455 |
+
"""
|
456 |
+
x = x + self.pe[:, : x.size(1)]
|
457 |
+
return self.dropout(x)
|
458 |
+
|
459 |
+
|
460 |
+
class VersatileAttention(Attention):
|
461 |
+
"""
|
462 |
+
Versatile Attention class.
|
463 |
+
|
464 |
+
Args:
|
465 |
+
- attention_mode: Attention mode.
|
466 |
+
- temporal_position_encoding (bool): Flag for temporal position encoding.
|
467 |
+
- temporal_position_encoding_max_len (int): Maximum length for temporal position encoding.
|
468 |
+
"""
|
469 |
+
def __init__(
|
470 |
+
self,
|
471 |
+
*args,
|
472 |
+
attention_mode=None,
|
473 |
+
cross_frame_attention_mode=None,
|
474 |
+
temporal_position_encoding=False,
|
475 |
+
temporal_position_encoding_max_len=24,
|
476 |
+
**kwargs,
|
477 |
+
):
|
478 |
+
super().__init__(*args, **kwargs)
|
479 |
+
assert attention_mode == "Temporal"
|
480 |
+
|
481 |
+
self.attention_mode = attention_mode
|
482 |
+
self.is_cross_attention = kwargs.get("cross_attention_dim") is not None
|
483 |
+
|
484 |
+
self.pos_encoder = (
|
485 |
+
PositionalEncoding(
|
486 |
+
kwargs["query_dim"],
|
487 |
+
dropout=0.0,
|
488 |
+
max_len=temporal_position_encoding_max_len,
|
489 |
+
)
|
490 |
+
if (temporal_position_encoding and attention_mode == "Temporal")
|
491 |
+
else None
|
492 |
+
)
|
493 |
+
|
494 |
+
def extra_repr(self):
|
495 |
+
"""
|
496 |
+
Returns a string representation of the module with information about the attention mode and whether it is cross-attention.
|
497 |
+
|
498 |
+
Returns:
|
499 |
+
str: A string representation of the module.
|
500 |
+
"""
|
501 |
+
return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
|
502 |
+
|
503 |
+
def set_use_memory_efficient_attention_xformers(
|
504 |
+
self,
|
505 |
+
use_memory_efficient_attention_xformers: bool,
|
506 |
+
attention_op = None,
|
507 |
+
):
|
508 |
+
"""
|
509 |
+
Sets the use of memory-efficient attention xformers for the VersatileAttention class.
|
510 |
+
|
511 |
+
Args:
|
512 |
+
use_memory_efficient_attention_xformers (bool): A boolean flag indicating whether to use memory-efficient attention xformers or not.
|
513 |
+
|
514 |
+
Returns:
|
515 |
+
None
|
516 |
+
|
517 |
+
"""
|
518 |
+
if use_memory_efficient_attention_xformers:
|
519 |
+
if not is_xformers_available():
|
520 |
+
raise ModuleNotFoundError(
|
521 |
+
(
|
522 |
+
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
523 |
+
" xformers"
|
524 |
+
),
|
525 |
+
name="xformers",
|
526 |
+
)
|
527 |
+
|
528 |
+
if not torch.cuda.is_available():
|
529 |
+
raise ValueError(
|
530 |
+
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
|
531 |
+
" only available for GPU "
|
532 |
+
)
|
533 |
+
|
534 |
+
try:
|
535 |
+
# Make sure we can run the memory efficient attention
|
536 |
+
_ = xformers.ops.memory_efficient_attention(
|
537 |
+
torch.randn((1, 2, 40), device="cuda"),
|
538 |
+
torch.randn((1, 2, 40), device="cuda"),
|
539 |
+
torch.randn((1, 2, 40), device="cuda"),
|
540 |
+
)
|
541 |
+
except Exception as e:
|
542 |
+
raise e
|
543 |
+
processor = AttnProcessor()
|
544 |
+
else:
|
545 |
+
processor = AttnProcessor()
|
546 |
+
|
547 |
+
self.set_processor(processor)
|
548 |
+
|
549 |
+
def forward(
|
550 |
+
self,
|
551 |
+
hidden_states,
|
552 |
+
encoder_hidden_states=None,
|
553 |
+
attention_mask=None,
|
554 |
+
video_length=None,
|
555 |
+
**cross_attention_kwargs,
|
556 |
+
):
|
557 |
+
"""
|
558 |
+
Args:
|
559 |
+
hidden_states (`torch.Tensor`):
|
560 |
+
The hidden states to be passed through the model.
|
561 |
+
encoder_hidden_states (`torch.Tensor`, optional):
|
562 |
+
The encoder hidden states to be passed through the model.
|
563 |
+
attention_mask (`torch.Tensor`, optional):
|
564 |
+
The attention mask to be used in the model.
|
565 |
+
video_length (`int`, optional):
|
566 |
+
The length of the video.
|
567 |
+
cross_attention_kwargs (`dict`, optional):
|
568 |
+
Additional keyword arguments to be used for cross-attention.
|
569 |
+
|
570 |
+
Returns:
|
571 |
+
`torch.Tensor`:
|
572 |
+
The output tensor after passing through the model.
|
573 |
+
|
574 |
+
"""
|
575 |
+
if self.attention_mode == "Temporal":
|
576 |
+
d = hidden_states.shape[1] # d means HxW
|
577 |
+
hidden_states = rearrange(
|
578 |
+
hidden_states, "(b f) d c -> (b d) f c", f=video_length
|
579 |
+
)
|
580 |
+
|
581 |
+
if self.pos_encoder is not None:
|
582 |
+
hidden_states = self.pos_encoder(hidden_states)
|
583 |
+
|
584 |
+
encoder_hidden_states = (
|
585 |
+
repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
|
586 |
+
if encoder_hidden_states is not None
|
587 |
+
else encoder_hidden_states
|
588 |
+
)
|
589 |
+
|
590 |
+
else:
|
591 |
+
raise NotImplementedError
|
592 |
+
|
593 |
+
hidden_states = self.processor(
|
594 |
+
self,
|
595 |
+
hidden_states,
|
596 |
+
encoder_hidden_states=encoder_hidden_states,
|
597 |
+
attention_mask=attention_mask,
|
598 |
+
**cross_attention_kwargs,
|
599 |
+
)
|
600 |
+
|
601 |
+
if self.attention_mode == "Temporal":
|
602 |
+
hidden_states = rearrange(
|
603 |
+
hidden_states, "(b d) f c -> (b f) d c", d=d)
|
604 |
+
|
605 |
+
return hidden_states
|
joyhallo/models/mutual_self_attention.py
ADDED
@@ -0,0 +1,495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This module contains the implementation of mutual self-attention,
|
3 |
+
which is a type of attention mechanism used in deep learning models.
|
4 |
+
The module includes several classes and functions related to attention mechanisms,
|
5 |
+
such as BasicTransformerBlock and TemporalBasicTransformerBlock.
|
6 |
+
The main purpose of this module is to provide a comprehensive attention mechanism for various tasks in deep learning,
|
7 |
+
such as image and video processing, natural language processing, and so on.
|
8 |
+
"""
|
9 |
+
|
10 |
+
from typing import Any, Dict, Optional
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from einops import rearrange
|
14 |
+
|
15 |
+
from .attention import BasicTransformerBlock, TemporalBasicTransformerBlock
|
16 |
+
|
17 |
+
|
18 |
+
def torch_dfs(model: torch.nn.Module):
|
19 |
+
"""
|
20 |
+
Perform a depth-first search (DFS) traversal on a PyTorch model's neural network architecture.
|
21 |
+
|
22 |
+
This function recursively traverses all the children modules of a given PyTorch model and returns a list
|
23 |
+
containing all the modules in the model's architecture. The DFS approach starts with the input model and
|
24 |
+
explores its children modules depth-wise before backtracking and exploring other branches.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
model (torch.nn.Module): The root module of the neural network to traverse.
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
list: A list of all the modules in the model's architecture.
|
31 |
+
"""
|
32 |
+
result = [model]
|
33 |
+
for child in model.children():
|
34 |
+
result += torch_dfs(child)
|
35 |
+
return result
|
36 |
+
|
37 |
+
|
38 |
+
class ReferenceAttentionControl:
|
39 |
+
"""
|
40 |
+
This class is used to control the reference attention mechanism in a neural network model.
|
41 |
+
It is responsible for managing the guidance and fusion blocks, and modifying the self-attention
|
42 |
+
and group normalization mechanisms. The class also provides methods for registering reference hooks
|
43 |
+
and updating/clearing the internal state of the attention control object.
|
44 |
+
|
45 |
+
Attributes:
|
46 |
+
unet: The UNet model associated with this attention control object.
|
47 |
+
mode: The operating mode of the attention control object, either 'write' or 'read'.
|
48 |
+
do_classifier_free_guidance: Whether to use classifier-free guidance in the attention mechanism.
|
49 |
+
attention_auto_machine_weight: The weight assigned to the attention auto-machine.
|
50 |
+
gn_auto_machine_weight: The weight assigned to the group normalization auto-machine.
|
51 |
+
style_fidelity: The style fidelity parameter for the attention mechanism.
|
52 |
+
reference_attn: Whether to use reference attention in the model.
|
53 |
+
reference_adain: Whether to use reference AdaIN in the model.
|
54 |
+
fusion_blocks: The type of fusion blocks to use in the model ('midup', 'late', or 'nofusion').
|
55 |
+
batch_size: The batch size used for processing video frames.
|
56 |
+
|
57 |
+
Methods:
|
58 |
+
register_reference_hooks: Registers the reference hooks for the attention control object.
|
59 |
+
hacked_basic_transformer_inner_forward: The modified inner forward method for the basic transformer block.
|
60 |
+
update: Updates the internal state of the attention control object using the provided writer and dtype.
|
61 |
+
clear: Clears the internal state of the attention control object.
|
62 |
+
"""
|
63 |
+
def __init__(
|
64 |
+
self,
|
65 |
+
unet,
|
66 |
+
mode="write",
|
67 |
+
do_classifier_free_guidance=False,
|
68 |
+
attention_auto_machine_weight=float("inf"),
|
69 |
+
gn_auto_machine_weight=1.0,
|
70 |
+
style_fidelity=1.0,
|
71 |
+
reference_attn=True,
|
72 |
+
reference_adain=False,
|
73 |
+
fusion_blocks="midup",
|
74 |
+
batch_size=1,
|
75 |
+
) -> None:
|
76 |
+
"""
|
77 |
+
Initializes the ReferenceAttentionControl class.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
unet (torch.nn.Module): The UNet model.
|
81 |
+
mode (str, optional): The mode of operation. Defaults to "write".
|
82 |
+
do_classifier_free_guidance (bool, optional): Whether to do classifier-free guidance. Defaults to False.
|
83 |
+
attention_auto_machine_weight (float, optional): The weight for attention auto-machine. Defaults to infinity.
|
84 |
+
gn_auto_machine_weight (float, optional): The weight for group-norm auto-machine. Defaults to 1.0.
|
85 |
+
style_fidelity (float, optional): The style fidelity. Defaults to 1.0.
|
86 |
+
reference_attn (bool, optional): Whether to use reference attention. Defaults to True.
|
87 |
+
reference_adain (bool, optional): Whether to use reference AdaIN. Defaults to False.
|
88 |
+
fusion_blocks (str, optional): The fusion blocks to use. Defaults to "midup".
|
89 |
+
batch_size (int, optional): The batch size. Defaults to 1.
|
90 |
+
|
91 |
+
Raises:
|
92 |
+
ValueError: If the mode is not recognized.
|
93 |
+
ValueError: If the fusion blocks are not recognized.
|
94 |
+
"""
|
95 |
+
# 10. Modify self attention and group norm
|
96 |
+
self.unet = unet
|
97 |
+
assert mode in ["read", "write"]
|
98 |
+
assert fusion_blocks in ["midup", "full"]
|
99 |
+
self.reference_attn = reference_attn
|
100 |
+
self.reference_adain = reference_adain
|
101 |
+
self.fusion_blocks = fusion_blocks
|
102 |
+
self.register_reference_hooks(
|
103 |
+
mode,
|
104 |
+
do_classifier_free_guidance,
|
105 |
+
attention_auto_machine_weight,
|
106 |
+
gn_auto_machine_weight,
|
107 |
+
style_fidelity,
|
108 |
+
reference_attn,
|
109 |
+
reference_adain,
|
110 |
+
fusion_blocks,
|
111 |
+
batch_size=batch_size,
|
112 |
+
)
|
113 |
+
|
114 |
+
def register_reference_hooks(
|
115 |
+
self,
|
116 |
+
mode,
|
117 |
+
do_classifier_free_guidance,
|
118 |
+
_attention_auto_machine_weight,
|
119 |
+
_gn_auto_machine_weight,
|
120 |
+
_style_fidelity,
|
121 |
+
_reference_attn,
|
122 |
+
_reference_adain,
|
123 |
+
_dtype=torch.float16,
|
124 |
+
batch_size=1,
|
125 |
+
num_images_per_prompt=1,
|
126 |
+
device=torch.device("cpu"),
|
127 |
+
_fusion_blocks="midup",
|
128 |
+
):
|
129 |
+
"""
|
130 |
+
Registers reference hooks for the model.
|
131 |
+
|
132 |
+
This function is responsible for registering reference hooks in the model,
|
133 |
+
which are used to modify the attention mechanism and group normalization layers.
|
134 |
+
It takes various parameters as input, such as mode,
|
135 |
+
do_classifier_free_guidance, _attention_auto_machine_weight, _gn_auto_machine_weight, _style_fidelity,
|
136 |
+
_reference_attn, _reference_adain, _dtype, batch_size, num_images_per_prompt, device, and _fusion_blocks.
|
137 |
+
|
138 |
+
Args:
|
139 |
+
self: Reference to the instance of the class.
|
140 |
+
mode: The mode of operation for the reference hooks.
|
141 |
+
do_classifier_free_guidance: A boolean flag indicating whether to use classifier-free guidance.
|
142 |
+
_attention_auto_machine_weight: The weight for the attention auto-machine.
|
143 |
+
_gn_auto_machine_weight: The weight for the group normalization auto-machine.
|
144 |
+
_style_fidelity: The style fidelity for the reference hooks.
|
145 |
+
_reference_attn: A boolean flag indicating whether to use reference attention.
|
146 |
+
_reference_adain: A boolean flag indicating whether to use reference AdaIN.
|
147 |
+
_dtype: The data type for the reference hooks.
|
148 |
+
batch_size: The batch size for the reference hooks.
|
149 |
+
num_images_per_prompt: The number of images per prompt for the reference hooks.
|
150 |
+
device: The device for the reference hooks.
|
151 |
+
_fusion_blocks: The fusion blocks for the reference hooks.
|
152 |
+
|
153 |
+
Returns:
|
154 |
+
None
|
155 |
+
"""
|
156 |
+
MODE = mode
|
157 |
+
if do_classifier_free_guidance:
|
158 |
+
uc_mask = (
|
159 |
+
torch.Tensor(
|
160 |
+
[1] * batch_size * num_images_per_prompt * 16
|
161 |
+
+ [0] * batch_size * num_images_per_prompt * 16
|
162 |
+
)
|
163 |
+
.to(device)
|
164 |
+
.bool()
|
165 |
+
)
|
166 |
+
else:
|
167 |
+
uc_mask = (
|
168 |
+
torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
|
169 |
+
.to(device)
|
170 |
+
.bool()
|
171 |
+
)
|
172 |
+
|
173 |
+
def hacked_basic_transformer_inner_forward(
|
174 |
+
self,
|
175 |
+
hidden_states: torch.FloatTensor,
|
176 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
177 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
178 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
179 |
+
timestep: Optional[torch.LongTensor] = None,
|
180 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
181 |
+
class_labels: Optional[torch.LongTensor] = None,
|
182 |
+
video_length=None,
|
183 |
+
):
|
184 |
+
gate_msa = None
|
185 |
+
shift_mlp = None
|
186 |
+
scale_mlp = None
|
187 |
+
gate_mlp = None
|
188 |
+
|
189 |
+
if self.use_ada_layer_norm: # False
|
190 |
+
norm_hidden_states = self.norm1(hidden_states, timestep)
|
191 |
+
elif self.use_ada_layer_norm_zero:
|
192 |
+
(
|
193 |
+
norm_hidden_states,
|
194 |
+
gate_msa,
|
195 |
+
shift_mlp,
|
196 |
+
scale_mlp,
|
197 |
+
gate_mlp,
|
198 |
+
) = self.norm1(
|
199 |
+
hidden_states,
|
200 |
+
timestep,
|
201 |
+
class_labels,
|
202 |
+
hidden_dtype=hidden_states.dtype,
|
203 |
+
)
|
204 |
+
else:
|
205 |
+
norm_hidden_states = self.norm1(hidden_states)
|
206 |
+
|
207 |
+
# 1. Self-Attention
|
208 |
+
# self.only_cross_attention = False
|
209 |
+
cross_attention_kwargs = (
|
210 |
+
cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
211 |
+
)
|
212 |
+
if self.only_cross_attention:
|
213 |
+
attn_output = self.attn1(
|
214 |
+
norm_hidden_states,
|
215 |
+
encoder_hidden_states=(
|
216 |
+
encoder_hidden_states if self.only_cross_attention else None
|
217 |
+
),
|
218 |
+
attention_mask=attention_mask,
|
219 |
+
**cross_attention_kwargs,
|
220 |
+
)
|
221 |
+
else:
|
222 |
+
if MODE == "write":
|
223 |
+
self.bank.append(norm_hidden_states.clone())
|
224 |
+
attn_output = self.attn1(
|
225 |
+
norm_hidden_states,
|
226 |
+
encoder_hidden_states=(
|
227 |
+
encoder_hidden_states if self.only_cross_attention else None
|
228 |
+
),
|
229 |
+
attention_mask=attention_mask,
|
230 |
+
**cross_attention_kwargs,
|
231 |
+
)
|
232 |
+
if MODE == "read":
|
233 |
+
|
234 |
+
bank_fea = [
|
235 |
+
rearrange(
|
236 |
+
rearrange(
|
237 |
+
d,
|
238 |
+
"(b s) l c -> b s l c",
|
239 |
+
b=norm_hidden_states.shape[0] // video_length,
|
240 |
+
)[:, 0, :, :]
|
241 |
+
# .unsqueeze(1)
|
242 |
+
.repeat(1, video_length, 1, 1),
|
243 |
+
"b t l c -> (b t) l c",
|
244 |
+
)
|
245 |
+
for d in self.bank
|
246 |
+
]
|
247 |
+
motion_frames_fea = [rearrange(
|
248 |
+
d,
|
249 |
+
"(b s) l c -> b s l c",
|
250 |
+
b=norm_hidden_states.shape[0] // video_length,
|
251 |
+
)[:, 1:, :, :] for d in self.bank]
|
252 |
+
modify_norm_hidden_states = torch.cat(
|
253 |
+
[norm_hidden_states] + bank_fea, dim=1
|
254 |
+
)
|
255 |
+
hidden_states_uc = (
|
256 |
+
self.attn1(
|
257 |
+
norm_hidden_states,
|
258 |
+
encoder_hidden_states=modify_norm_hidden_states,
|
259 |
+
attention_mask=attention_mask,
|
260 |
+
)
|
261 |
+
+ hidden_states
|
262 |
+
)
|
263 |
+
if do_classifier_free_guidance:
|
264 |
+
hidden_states_c = hidden_states_uc.clone()
|
265 |
+
_uc_mask = uc_mask.clone()
|
266 |
+
if hidden_states.shape[0] != _uc_mask.shape[0]:
|
267 |
+
_uc_mask = (
|
268 |
+
torch.Tensor(
|
269 |
+
[1] * (hidden_states.shape[0] // 2)
|
270 |
+
+ [0] * (hidden_states.shape[0] // 2)
|
271 |
+
)
|
272 |
+
.to(device)
|
273 |
+
.bool()
|
274 |
+
)
|
275 |
+
hidden_states_c[_uc_mask] = (
|
276 |
+
self.attn1(
|
277 |
+
norm_hidden_states[_uc_mask],
|
278 |
+
encoder_hidden_states=norm_hidden_states[_uc_mask],
|
279 |
+
attention_mask=attention_mask,
|
280 |
+
)
|
281 |
+
+ hidden_states[_uc_mask]
|
282 |
+
)
|
283 |
+
hidden_states = hidden_states_c.clone()
|
284 |
+
else:
|
285 |
+
hidden_states = hidden_states_uc
|
286 |
+
|
287 |
+
# self.bank.clear()
|
288 |
+
if self.attn2 is not None:
|
289 |
+
# Cross-Attention
|
290 |
+
norm_hidden_states = (
|
291 |
+
self.norm2(hidden_states, timestep)
|
292 |
+
if self.use_ada_layer_norm
|
293 |
+
else self.norm2(hidden_states)
|
294 |
+
)
|
295 |
+
hidden_states = (
|
296 |
+
self.attn2(
|
297 |
+
norm_hidden_states,
|
298 |
+
encoder_hidden_states=encoder_hidden_states,
|
299 |
+
attention_mask=attention_mask,
|
300 |
+
)
|
301 |
+
+ hidden_states
|
302 |
+
)
|
303 |
+
|
304 |
+
# Feed-forward
|
305 |
+
hidden_states = self.ff(self.norm3(
|
306 |
+
hidden_states)) + hidden_states
|
307 |
+
|
308 |
+
# Temporal-Attention
|
309 |
+
if self.unet_use_temporal_attention:
|
310 |
+
d = hidden_states.shape[1]
|
311 |
+
hidden_states = rearrange(
|
312 |
+
hidden_states, "(b f) d c -> (b d) f c", f=video_length
|
313 |
+
)
|
314 |
+
norm_hidden_states = (
|
315 |
+
self.norm_temp(hidden_states, timestep)
|
316 |
+
if self.use_ada_layer_norm
|
317 |
+
else self.norm_temp(hidden_states)
|
318 |
+
)
|
319 |
+
hidden_states = (
|
320 |
+
self.attn_temp(norm_hidden_states) + hidden_states
|
321 |
+
)
|
322 |
+
hidden_states = rearrange(
|
323 |
+
hidden_states, "(b d) f c -> (b f) d c", d=d
|
324 |
+
)
|
325 |
+
|
326 |
+
return hidden_states, motion_frames_fea
|
327 |
+
|
328 |
+
if self.use_ada_layer_norm_zero:
|
329 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
330 |
+
hidden_states = attn_output + hidden_states
|
331 |
+
|
332 |
+
if self.attn2 is not None:
|
333 |
+
norm_hidden_states = (
|
334 |
+
self.norm2(hidden_states, timestep)
|
335 |
+
if self.use_ada_layer_norm
|
336 |
+
else self.norm2(hidden_states)
|
337 |
+
)
|
338 |
+
|
339 |
+
# 2. Cross-Attention
|
340 |
+
tmp = norm_hidden_states.shape[0] // encoder_hidden_states.shape[0]
|
341 |
+
attn_output = self.attn2(
|
342 |
+
norm_hidden_states,
|
343 |
+
# TODO: repeat这个地方需要斟酌一下
|
344 |
+
encoder_hidden_states=encoder_hidden_states.repeat(
|
345 |
+
tmp, 1, 1),
|
346 |
+
attention_mask=encoder_attention_mask,
|
347 |
+
**cross_attention_kwargs,
|
348 |
+
)
|
349 |
+
hidden_states = attn_output + hidden_states
|
350 |
+
|
351 |
+
# 3. Feed-forward
|
352 |
+
norm_hidden_states = self.norm3(hidden_states)
|
353 |
+
|
354 |
+
if self.use_ada_layer_norm_zero:
|
355 |
+
norm_hidden_states = (
|
356 |
+
norm_hidden_states *
|
357 |
+
(1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
358 |
+
)
|
359 |
+
|
360 |
+
ff_output = self.ff(norm_hidden_states)
|
361 |
+
|
362 |
+
if self.use_ada_layer_norm_zero:
|
363 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
364 |
+
|
365 |
+
hidden_states = ff_output + hidden_states
|
366 |
+
|
367 |
+
return hidden_states
|
368 |
+
|
369 |
+
if self.reference_attn:
|
370 |
+
if self.fusion_blocks == "midup":
|
371 |
+
attn_modules = [
|
372 |
+
module
|
373 |
+
for module in (
|
374 |
+
torch_dfs(self.unet.mid_block) +
|
375 |
+
torch_dfs(self.unet.up_blocks)
|
376 |
+
)
|
377 |
+
if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock))
|
378 |
+
]
|
379 |
+
elif self.fusion_blocks == "full":
|
380 |
+
attn_modules = [
|
381 |
+
module
|
382 |
+
for module in torch_dfs(self.unet)
|
383 |
+
if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock))
|
384 |
+
]
|
385 |
+
attn_modules = sorted(
|
386 |
+
attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
|
387 |
+
)
|
388 |
+
|
389 |
+
for i, module in enumerate(attn_modules):
|
390 |
+
module._original_inner_forward = module.forward
|
391 |
+
if isinstance(module, BasicTransformerBlock):
|
392 |
+
module.forward = hacked_basic_transformer_inner_forward.__get__(
|
393 |
+
module,
|
394 |
+
BasicTransformerBlock)
|
395 |
+
if isinstance(module, TemporalBasicTransformerBlock):
|
396 |
+
module.forward = hacked_basic_transformer_inner_forward.__get__(
|
397 |
+
module,
|
398 |
+
TemporalBasicTransformerBlock)
|
399 |
+
|
400 |
+
module.bank = []
|
401 |
+
module.attn_weight = float(i) / float(len(attn_modules))
|
402 |
+
|
403 |
+
def update(self, writer, dtype=torch.float16):
|
404 |
+
"""
|
405 |
+
Update the model's parameters.
|
406 |
+
|
407 |
+
Args:
|
408 |
+
writer (torch.nn.Module): The model's writer object.
|
409 |
+
dtype (torch.dtype, optional): The data type to be used for the update. Defaults to torch.float16.
|
410 |
+
|
411 |
+
Returns:
|
412 |
+
None.
|
413 |
+
"""
|
414 |
+
if self.reference_attn:
|
415 |
+
if self.fusion_blocks == "midup":
|
416 |
+
reader_attn_modules = [
|
417 |
+
module
|
418 |
+
for module in (
|
419 |
+
torch_dfs(self.unet.mid_block) +
|
420 |
+
torch_dfs(self.unet.up_blocks)
|
421 |
+
)
|
422 |
+
if isinstance(module, TemporalBasicTransformerBlock)
|
423 |
+
]
|
424 |
+
writer_attn_modules = [
|
425 |
+
module
|
426 |
+
for module in (
|
427 |
+
torch_dfs(writer.unet.mid_block)
|
428 |
+
+ torch_dfs(writer.unet.up_blocks)
|
429 |
+
)
|
430 |
+
if isinstance(module, BasicTransformerBlock)
|
431 |
+
]
|
432 |
+
elif self.fusion_blocks == "full":
|
433 |
+
reader_attn_modules = [
|
434 |
+
module
|
435 |
+
for module in torch_dfs(self.unet)
|
436 |
+
if isinstance(module, TemporalBasicTransformerBlock)
|
437 |
+
]
|
438 |
+
writer_attn_modules = [
|
439 |
+
module
|
440 |
+
for module in torch_dfs(writer.unet)
|
441 |
+
if isinstance(module, BasicTransformerBlock)
|
442 |
+
]
|
443 |
+
|
444 |
+
assert len(reader_attn_modules) == len(writer_attn_modules)
|
445 |
+
reader_attn_modules = sorted(
|
446 |
+
reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
|
447 |
+
)
|
448 |
+
writer_attn_modules = sorted(
|
449 |
+
writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
|
450 |
+
)
|
451 |
+
for r, w in zip(reader_attn_modules, writer_attn_modules):
|
452 |
+
r.bank = [v.clone().to(dtype) for v in w.bank]
|
453 |
+
|
454 |
+
|
455 |
+
def clear(self):
|
456 |
+
"""
|
457 |
+
Clears the attention bank of all reader attention modules.
|
458 |
+
|
459 |
+
This method is used when the `reference_attn` attribute is set to `True`.
|
460 |
+
It clears the attention bank of all reader attention modules inside the UNet
|
461 |
+
model based on the selected `fusion_blocks` mode.
|
462 |
+
|
463 |
+
If `fusion_blocks` is set to "midup", it searches for reader attention modules
|
464 |
+
in both the mid block and up blocks of the UNet model. If `fusion_blocks` is set
|
465 |
+
to "full", it searches for reader attention modules in the entire UNet model.
|
466 |
+
|
467 |
+
It sorts the reader attention modules by the number of neurons in their
|
468 |
+
`norm1.normalized_shape[0]` attribute in descending order. This sorting ensures
|
469 |
+
that the modules with more neurons are cleared first.
|
470 |
+
|
471 |
+
Finally, it iterates through the sorted list of reader attention modules and
|
472 |
+
calls the `clear()` method on each module's `bank` attribute to clear the
|
473 |
+
attention bank.
|
474 |
+
"""
|
475 |
+
if self.reference_attn:
|
476 |
+
if self.fusion_blocks == "midup":
|
477 |
+
reader_attn_modules = [
|
478 |
+
module
|
479 |
+
for module in (
|
480 |
+
torch_dfs(self.unet.mid_block) +
|
481 |
+
torch_dfs(self.unet.up_blocks)
|
482 |
+
)
|
483 |
+
if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock))
|
484 |
+
]
|
485 |
+
elif self.fusion_blocks == "full":
|
486 |
+
reader_attn_modules = [
|
487 |
+
module
|
488 |
+
for module in torch_dfs(self.unet)
|
489 |
+
if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock))
|
490 |
+
]
|
491 |
+
reader_attn_modules = sorted(
|
492 |
+
reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
|
493 |
+
)
|
494 |
+
for r in reader_attn_modules:
|
495 |
+
r.bank.clear()
|
joyhallo/models/resnet.py
ADDED
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This module defines various components used in the ResNet model, such as InflatedConv3D, InflatedGroupNorm,
|
3 |
+
Upsample3D, Downsample3D, ResnetBlock3D, and Mish activation function. These components are used to construct
|
4 |
+
a deep neural network model for image classification or other computer vision tasks.
|
5 |
+
|
6 |
+
Classes:
|
7 |
+
- InflatedConv3d: An inflated 3D convolutional layer, inheriting from nn.Conv2d.
|
8 |
+
- InflatedGroupNorm: An inflated group normalization layer, inheriting from nn.GroupNorm.
|
9 |
+
- Upsample3D: A 3D upsampling module, used to increase the resolution of the input tensor.
|
10 |
+
- Downsample3D: A 3D downsampling module, used to decrease the resolution of the input tensor.
|
11 |
+
- ResnetBlock3D: A 3D residual block, commonly used in ResNet architectures.
|
12 |
+
- Mish: A Mish activation function, which is a smooth, non-monotonic activation function.
|
13 |
+
|
14 |
+
To use this module, simply import the classes and functions you need and follow the instructions provided in
|
15 |
+
the respective class and function docstrings.
|
16 |
+
"""
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn.functional as F
|
20 |
+
from einops import rearrange
|
21 |
+
from torch import nn
|
22 |
+
|
23 |
+
|
24 |
+
class InflatedConv3d(nn.Conv2d):
|
25 |
+
"""
|
26 |
+
InflatedConv3d is a class that inherits from torch.nn.Conv2d and overrides the forward method.
|
27 |
+
|
28 |
+
This class is used to perform 3D convolution on input tensor x. It is a specialized type of convolutional layer
|
29 |
+
commonly used in deep learning models for computer vision tasks. The main difference between a regular Conv2d and
|
30 |
+
InflatedConv3d is that InflatedConv3d is designed to handle 3D input tensors, which are typically the result of
|
31 |
+
inflating 2D convolutional layers to 3D for use in 3D deep learning tasks.
|
32 |
+
|
33 |
+
Attributes:
|
34 |
+
Same as torch.nn.Conv2d.
|
35 |
+
|
36 |
+
Methods:
|
37 |
+
forward(self, x):
|
38 |
+
Performs 3D convolution on the input tensor x using the InflatedConv3d layer.
|
39 |
+
|
40 |
+
Example:
|
41 |
+
conv_layer = InflatedConv3d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
|
42 |
+
output = conv_layer(input_tensor)
|
43 |
+
"""
|
44 |
+
def forward(self, x):
|
45 |
+
"""
|
46 |
+
Forward pass of the InflatedConv3d layer.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
x (torch.Tensor): Input tensor to the layer.
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
torch.Tensor: Output tensor after applying the InflatedConv3d layer.
|
53 |
+
"""
|
54 |
+
video_length = x.shape[2]
|
55 |
+
|
56 |
+
x = rearrange(x, "b c f h w -> (b f) c h w")
|
57 |
+
x = super().forward(x)
|
58 |
+
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
|
59 |
+
|
60 |
+
return x
|
61 |
+
|
62 |
+
|
63 |
+
class InflatedGroupNorm(nn.GroupNorm):
|
64 |
+
"""
|
65 |
+
InflatedGroupNorm is a custom class that inherits from torch.nn.GroupNorm.
|
66 |
+
It is used to apply group normalization to 3D tensors.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
num_groups (int): The number of groups to divide the channels into.
|
70 |
+
num_channels (int): The number of channels in the input tensor.
|
71 |
+
eps (float, optional): A small constant to add to the variance to avoid division by zero. Defaults to 1e-5.
|
72 |
+
affine (bool, optional): If True, the module has learnable affine parameters. Defaults to True.
|
73 |
+
|
74 |
+
Attributes:
|
75 |
+
weight (torch.Tensor): The learnable weight tensor for scale.
|
76 |
+
bias (torch.Tensor): The learnable bias tensor for shift.
|
77 |
+
|
78 |
+
Forward method:
|
79 |
+
x (torch.Tensor): Input tensor to be normalized.
|
80 |
+
return (torch.Tensor): Normalized tensor.
|
81 |
+
"""
|
82 |
+
def forward(self, x):
|
83 |
+
"""
|
84 |
+
Performs a forward pass through the CustomClassName.
|
85 |
+
|
86 |
+
:param x: Input tensor of shape (batch_size, channels, video_length, height, width).
|
87 |
+
:return: Output tensor of shape (batch_size, channels, video_length, height, width).
|
88 |
+
"""
|
89 |
+
video_length = x.shape[2]
|
90 |
+
|
91 |
+
x = rearrange(x, "b c f h w -> (b f) c h w")
|
92 |
+
x = super().forward(x)
|
93 |
+
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
|
94 |
+
|
95 |
+
return x
|
96 |
+
|
97 |
+
|
98 |
+
class Upsample3D(nn.Module):
|
99 |
+
"""
|
100 |
+
Upsample3D is a PyTorch module that upsamples a 3D tensor.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
channels (int): The number of channels in the input tensor.
|
104 |
+
use_conv (bool): Whether to use a convolutional layer for upsampling.
|
105 |
+
use_conv_transpose (bool): Whether to use a transposed convolutional layer for upsampling.
|
106 |
+
out_channels (int): The number of channels in the output tensor.
|
107 |
+
name (str): The name of the convolutional layer.
|
108 |
+
"""
|
109 |
+
def __init__(
|
110 |
+
self,
|
111 |
+
channels,
|
112 |
+
use_conv=False,
|
113 |
+
use_conv_transpose=False,
|
114 |
+
out_channels=None,
|
115 |
+
name="conv",
|
116 |
+
):
|
117 |
+
super().__init__()
|
118 |
+
self.channels = channels
|
119 |
+
self.out_channels = out_channels or channels
|
120 |
+
self.use_conv = use_conv
|
121 |
+
self.use_conv_transpose = use_conv_transpose
|
122 |
+
self.name = name
|
123 |
+
|
124 |
+
if use_conv_transpose:
|
125 |
+
raise NotImplementedError
|
126 |
+
if use_conv:
|
127 |
+
self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
|
128 |
+
|
129 |
+
def forward(self, hidden_states, output_size=None):
|
130 |
+
"""
|
131 |
+
Forward pass of the Upsample3D class.
|
132 |
+
|
133 |
+
Args:
|
134 |
+
hidden_states (torch.Tensor): Input tensor to be upsampled.
|
135 |
+
output_size (tuple, optional): Desired output size of the upsampled tensor.
|
136 |
+
|
137 |
+
Returns:
|
138 |
+
torch.Tensor: Upsampled tensor.
|
139 |
+
|
140 |
+
Raises:
|
141 |
+
AssertionError: If the number of channels in the input tensor does not match the expected channels.
|
142 |
+
"""
|
143 |
+
assert hidden_states.shape[1] == self.channels
|
144 |
+
|
145 |
+
if self.use_conv_transpose:
|
146 |
+
raise NotImplementedError
|
147 |
+
|
148 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
149 |
+
dtype = hidden_states.dtype
|
150 |
+
if dtype == torch.bfloat16:
|
151 |
+
hidden_states = hidden_states.to(torch.float32)
|
152 |
+
|
153 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
154 |
+
if hidden_states.shape[0] >= 64:
|
155 |
+
hidden_states = hidden_states.contiguous()
|
156 |
+
|
157 |
+
# if `output_size` is passed we force the interpolation output
|
158 |
+
# size and do not make use of `scale_factor=2`
|
159 |
+
if output_size is None:
|
160 |
+
hidden_states = F.interpolate(
|
161 |
+
hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest"
|
162 |
+
)
|
163 |
+
else:
|
164 |
+
hidden_states = F.interpolate(
|
165 |
+
hidden_states, size=output_size, mode="nearest"
|
166 |
+
)
|
167 |
+
|
168 |
+
# If the input is bfloat16, we cast back to bfloat16
|
169 |
+
if dtype == torch.bfloat16:
|
170 |
+
hidden_states = hidden_states.to(dtype)
|
171 |
+
|
172 |
+
# if self.use_conv:
|
173 |
+
# if self.name == "conv":
|
174 |
+
# hidden_states = self.conv(hidden_states)
|
175 |
+
# else:
|
176 |
+
# hidden_states = self.Conv2d_0(hidden_states)
|
177 |
+
hidden_states = self.conv(hidden_states)
|
178 |
+
|
179 |
+
return hidden_states
|
180 |
+
|
181 |
+
|
182 |
+
class Downsample3D(nn.Module):
|
183 |
+
"""
|
184 |
+
The Downsample3D class is a PyTorch module for downsampling a 3D tensor, which is used to
|
185 |
+
reduce the spatial resolution of feature maps, commonly in the encoder part of a neural network.
|
186 |
+
|
187 |
+
Attributes:
|
188 |
+
channels (int): Number of input channels.
|
189 |
+
use_conv (bool): Flag to use a convolutional layer for downsampling.
|
190 |
+
out_channels (int, optional): Number of output channels. Defaults to input channels if None.
|
191 |
+
padding (int): Padding added to the input.
|
192 |
+
name (str): Name of the convolutional layer used for downsampling.
|
193 |
+
|
194 |
+
Methods:
|
195 |
+
forward(self, hidden_states):
|
196 |
+
Downsamples the input tensor hidden_states and returns the downsampled tensor.
|
197 |
+
"""
|
198 |
+
def __init__(
|
199 |
+
self, channels, use_conv=False, out_channels=None, padding=1, name="conv"
|
200 |
+
):
|
201 |
+
"""
|
202 |
+
Downsamples the given input in the 3D space.
|
203 |
+
|
204 |
+
Args:
|
205 |
+
channels: The number of input channels.
|
206 |
+
use_conv: Whether to use a convolutional layer for downsampling.
|
207 |
+
out_channels: The number of output channels. If None, the input channels are used.
|
208 |
+
padding: The amount of padding to be added to the input.
|
209 |
+
name: The name of the convolutional layer.
|
210 |
+
"""
|
211 |
+
super().__init__()
|
212 |
+
self.channels = channels
|
213 |
+
self.out_channels = out_channels or channels
|
214 |
+
self.use_conv = use_conv
|
215 |
+
self.padding = padding
|
216 |
+
stride = 2
|
217 |
+
self.name = name
|
218 |
+
|
219 |
+
if use_conv:
|
220 |
+
self.conv = InflatedConv3d(
|
221 |
+
self.channels, self.out_channels, 3, stride=stride, padding=padding
|
222 |
+
)
|
223 |
+
else:
|
224 |
+
raise NotImplementedError
|
225 |
+
|
226 |
+
def forward(self, hidden_states):
|
227 |
+
"""
|
228 |
+
Forward pass for the Downsample3D class.
|
229 |
+
|
230 |
+
Args:
|
231 |
+
hidden_states (torch.Tensor): Input tensor to be downsampled.
|
232 |
+
|
233 |
+
Returns:
|
234 |
+
torch.Tensor: Downsampled tensor.
|
235 |
+
|
236 |
+
Raises:
|
237 |
+
AssertionError: If the number of channels in the input tensor does not match the expected channels.
|
238 |
+
"""
|
239 |
+
assert hidden_states.shape[1] == self.channels
|
240 |
+
if self.use_conv and self.padding == 0:
|
241 |
+
raise NotImplementedError
|
242 |
+
|
243 |
+
assert hidden_states.shape[1] == self.channels
|
244 |
+
hidden_states = self.conv(hidden_states)
|
245 |
+
|
246 |
+
return hidden_states
|
247 |
+
|
248 |
+
|
249 |
+
class ResnetBlock3D(nn.Module):
|
250 |
+
"""
|
251 |
+
The ResnetBlock3D class defines a 3D residual block, a common building block in ResNet
|
252 |
+
architectures for both image and video modeling tasks.
|
253 |
+
|
254 |
+
Attributes:
|
255 |
+
in_channels (int): Number of input channels.
|
256 |
+
out_channels (int, optional): Number of output channels, defaults to in_channels if None.
|
257 |
+
conv_shortcut (bool): Flag to use a convolutional shortcut.
|
258 |
+
dropout (float): Dropout rate.
|
259 |
+
temb_channels (int): Number of channels in the time embedding tensor.
|
260 |
+
groups (int): Number of groups for the group normalization layers.
|
261 |
+
eps (float): Epsilon value for group normalization.
|
262 |
+
non_linearity (str): Type of nonlinearity to apply after convolutions.
|
263 |
+
time_embedding_norm (str): Type of normalization for the time embedding.
|
264 |
+
output_scale_factor (float): Scaling factor for the output tensor.
|
265 |
+
use_in_shortcut (bool): Flag to include the input tensor in the shortcut connection.
|
266 |
+
use_inflated_groupnorm (bool): Flag to use inflated group normalization layers.
|
267 |
+
|
268 |
+
Methods:
|
269 |
+
forward(self, input_tensor, temb):
|
270 |
+
Passes the input tensor and time embedding through the residual block and
|
271 |
+
returns the output tensor.
|
272 |
+
"""
|
273 |
+
def __init__(
|
274 |
+
self,
|
275 |
+
*,
|
276 |
+
in_channels,
|
277 |
+
out_channels=None,
|
278 |
+
conv_shortcut=False,
|
279 |
+
dropout=0.0,
|
280 |
+
temb_channels=512,
|
281 |
+
groups=32,
|
282 |
+
groups_out=None,
|
283 |
+
pre_norm=True,
|
284 |
+
eps=1e-6,
|
285 |
+
non_linearity="swish",
|
286 |
+
time_embedding_norm="default",
|
287 |
+
output_scale_factor=1.0,
|
288 |
+
use_in_shortcut=None,
|
289 |
+
use_inflated_groupnorm=None,
|
290 |
+
):
|
291 |
+
super().__init__()
|
292 |
+
self.pre_norm = pre_norm
|
293 |
+
self.pre_norm = True
|
294 |
+
self.in_channels = in_channels
|
295 |
+
out_channels = in_channels if out_channels is None else out_channels
|
296 |
+
self.out_channels = out_channels
|
297 |
+
self.use_conv_shortcut = conv_shortcut
|
298 |
+
self.time_embedding_norm = time_embedding_norm
|
299 |
+
self.output_scale_factor = output_scale_factor
|
300 |
+
|
301 |
+
if groups_out is None:
|
302 |
+
groups_out = groups
|
303 |
+
|
304 |
+
assert use_inflated_groupnorm is not None
|
305 |
+
if use_inflated_groupnorm:
|
306 |
+
self.norm1 = InflatedGroupNorm(
|
307 |
+
num_groups=groups, num_channels=in_channels, eps=eps, affine=True
|
308 |
+
)
|
309 |
+
else:
|
310 |
+
self.norm1 = torch.nn.GroupNorm(
|
311 |
+
num_groups=groups, num_channels=in_channels, eps=eps, affine=True
|
312 |
+
)
|
313 |
+
|
314 |
+
self.conv1 = InflatedConv3d(
|
315 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
316 |
+
)
|
317 |
+
|
318 |
+
if temb_channels is not None:
|
319 |
+
if self.time_embedding_norm == "default":
|
320 |
+
time_emb_proj_out_channels = out_channels
|
321 |
+
elif self.time_embedding_norm == "scale_shift":
|
322 |
+
time_emb_proj_out_channels = out_channels * 2
|
323 |
+
else:
|
324 |
+
raise ValueError(
|
325 |
+
f"unknown time_embedding_norm : {self.time_embedding_norm} "
|
326 |
+
)
|
327 |
+
|
328 |
+
self.time_emb_proj = torch.nn.Linear(
|
329 |
+
temb_channels, time_emb_proj_out_channels
|
330 |
+
)
|
331 |
+
else:
|
332 |
+
self.time_emb_proj = None
|
333 |
+
|
334 |
+
if use_inflated_groupnorm:
|
335 |
+
self.norm2 = InflatedGroupNorm(
|
336 |
+
num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
|
337 |
+
)
|
338 |
+
else:
|
339 |
+
self.norm2 = torch.nn.GroupNorm(
|
340 |
+
num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
|
341 |
+
)
|
342 |
+
self.dropout = torch.nn.Dropout(dropout)
|
343 |
+
self.conv2 = InflatedConv3d(
|
344 |
+
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
345 |
+
)
|
346 |
+
|
347 |
+
if non_linearity == "swish":
|
348 |
+
self.nonlinearity = F.silu()
|
349 |
+
elif non_linearity == "mish":
|
350 |
+
self.nonlinearity = Mish()
|
351 |
+
elif non_linearity == "silu":
|
352 |
+
self.nonlinearity = nn.SiLU()
|
353 |
+
|
354 |
+
self.use_in_shortcut = (
|
355 |
+
self.in_channels != self.out_channels
|
356 |
+
if use_in_shortcut is None
|
357 |
+
else use_in_shortcut
|
358 |
+
)
|
359 |
+
|
360 |
+
self.conv_shortcut = None
|
361 |
+
if self.use_in_shortcut:
|
362 |
+
self.conv_shortcut = InflatedConv3d(
|
363 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
364 |
+
)
|
365 |
+
|
366 |
+
def forward(self, input_tensor, temb):
|
367 |
+
"""
|
368 |
+
Forward pass for the ResnetBlock3D class.
|
369 |
+
|
370 |
+
Args:
|
371 |
+
input_tensor (torch.Tensor): Input tensor to the ResnetBlock3D layer.
|
372 |
+
temb (torch.Tensor): Token embedding tensor.
|
373 |
+
|
374 |
+
Returns:
|
375 |
+
torch.Tensor: Output tensor after passing through the ResnetBlock3D layer.
|
376 |
+
"""
|
377 |
+
hidden_states = input_tensor
|
378 |
+
|
379 |
+
hidden_states = self.norm1(hidden_states)
|
380 |
+
hidden_states = self.nonlinearity(hidden_states)
|
381 |
+
|
382 |
+
hidden_states = self.conv1(hidden_states)
|
383 |
+
|
384 |
+
if temb is not None:
|
385 |
+
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
|
386 |
+
|
387 |
+
if temb is not None and self.time_embedding_norm == "default":
|
388 |
+
hidden_states = hidden_states + temb
|
389 |
+
|
390 |
+
hidden_states = self.norm2(hidden_states)
|
391 |
+
|
392 |
+
if temb is not None and self.time_embedding_norm == "scale_shift":
|
393 |
+
scale, shift = torch.chunk(temb, 2, dim=1)
|
394 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
395 |
+
|
396 |
+
hidden_states = self.nonlinearity(hidden_states)
|
397 |
+
|
398 |
+
hidden_states = self.dropout(hidden_states)
|
399 |
+
hidden_states = self.conv2(hidden_states)
|
400 |
+
|
401 |
+
if self.conv_shortcut is not None:
|
402 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
403 |
+
|
404 |
+
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
405 |
+
|
406 |
+
return output_tensor
|
407 |
+
|
408 |
+
|
409 |
+
class Mish(torch.nn.Module):
|
410 |
+
"""
|
411 |
+
The Mish class implements the Mish activation function, a smooth, non-monotonic function
|
412 |
+
that can be used in neural networks as an alternative to traditional activation functions like ReLU.
|
413 |
+
|
414 |
+
Methods:
|
415 |
+
forward(self, hidden_states):
|
416 |
+
Applies the Mish activation function to the input tensor hidden_states and
|
417 |
+
returns the resulting tensor.
|
418 |
+
"""
|
419 |
+
def forward(self, hidden_states):
|
420 |
+
"""
|
421 |
+
Mish activation function.
|
422 |
+
|
423 |
+
Args:
|
424 |
+
hidden_states (torch.Tensor): The input tensor to apply the Mish activation function to.
|
425 |
+
|
426 |
+
Returns:
|
427 |
+
hidden_states (torch.Tensor): The output tensor after applying the Mish activation function.
|
428 |
+
"""
|
429 |
+
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
|
joyhallo/models/transformer_2d.py
ADDED
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This module defines the Transformer2DModel, a PyTorch model that extends ModelMixin and ConfigMixin. It includes
|
3 |
+
methods for gradient checkpointing, forward propagation, and various utility functions. The model is designed for
|
4 |
+
2D image-related tasks and uses LoRa (Low-Rank All-Attention) compatible layers for efficient attention computation.
|
5 |
+
|
6 |
+
The file includes the following import statements:
|
7 |
+
|
8 |
+
- From dataclasses import dataclass
|
9 |
+
- From typing import Any, Dict, Optional
|
10 |
+
- Import torch
|
11 |
+
- From diffusers.configuration_utils import ConfigMixin, register_to_config
|
12 |
+
- From diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
|
13 |
+
- From diffusers.models.modeling_utils import ModelMixin
|
14 |
+
- From diffusers.models.normalization import AdaLayerNormSingle
|
15 |
+
- From diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, deprecate,
|
16 |
+
is_torch_version)
|
17 |
+
- From torch import nn
|
18 |
+
- From .attention import BasicTransformerBlock
|
19 |
+
|
20 |
+
The file also includes the following classes and functions:
|
21 |
+
|
22 |
+
- Transformer2DModel: A model class that extends ModelMixin and ConfigMixin. It includes methods for gradient
|
23 |
+
checkpointing, forward propagation, and various utility functions.
|
24 |
+
- _set_gradient_checkpointing: A utility function to set gradient checkpointing for a given module.
|
25 |
+
- forward: The forward propagation method for the Transformer2DModel.
|
26 |
+
|
27 |
+
To use this module, you can import the Transformer2DModel class and create an instance of the model with the desired
|
28 |
+
configuration. Then, you can use the forward method to pass input tensors through the model and get the output tensors.
|
29 |
+
"""
|
30 |
+
|
31 |
+
from dataclasses import dataclass
|
32 |
+
from typing import Any, Dict, Optional
|
33 |
+
|
34 |
+
import torch
|
35 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
36 |
+
# from diffusers.models.embeddings import CaptionProjection
|
37 |
+
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
|
38 |
+
from diffusers.models.modeling_utils import ModelMixin
|
39 |
+
from diffusers.models.normalization import AdaLayerNormSingle
|
40 |
+
from diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, deprecate,
|
41 |
+
is_torch_version)
|
42 |
+
from torch import nn
|
43 |
+
|
44 |
+
from .attention import BasicTransformerBlock
|
45 |
+
|
46 |
+
|
47 |
+
@dataclass
|
48 |
+
class Transformer2DModelOutput(BaseOutput):
|
49 |
+
"""
|
50 |
+
The output of [`Transformer2DModel`].
|
51 |
+
|
52 |
+
Args:
|
53 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`
|
54 |
+
or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
55 |
+
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
|
56 |
+
distributions for the unnoised latent pixels.
|
57 |
+
"""
|
58 |
+
|
59 |
+
sample: torch.FloatTensor
|
60 |
+
ref_feature: torch.FloatTensor
|
61 |
+
|
62 |
+
|
63 |
+
class Transformer2DModel(ModelMixin, ConfigMixin):
|
64 |
+
"""
|
65 |
+
A 2D Transformer model for image-like data.
|
66 |
+
|
67 |
+
Parameters:
|
68 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
69 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
70 |
+
in_channels (`int`, *optional*):
|
71 |
+
The number of channels in the input and output (specify if the input is **continuous**).
|
72 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
73 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
74 |
+
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
75 |
+
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
76 |
+
This is fixed during training since it is used to learn a number of position embeddings.
|
77 |
+
num_vector_embeds (`int`, *optional*):
|
78 |
+
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
|
79 |
+
Includes the class for the masked latent pixel.
|
80 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
|
81 |
+
num_embeds_ada_norm ( `int`, *optional*):
|
82 |
+
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
|
83 |
+
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
|
84 |
+
added to the hidden states.
|
85 |
+
|
86 |
+
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
|
87 |
+
attention_bias (`bool`, *optional*):
|
88 |
+
Configure if the `TransformerBlocks` attention should contain a bias parameter.
|
89 |
+
"""
|
90 |
+
|
91 |
+
_supports_gradient_checkpointing = True
|
92 |
+
|
93 |
+
@register_to_config
|
94 |
+
def __init__(
|
95 |
+
self,
|
96 |
+
num_attention_heads: int = 16,
|
97 |
+
attention_head_dim: int = 88,
|
98 |
+
in_channels: Optional[int] = None,
|
99 |
+
out_channels: Optional[int] = None,
|
100 |
+
num_layers: int = 1,
|
101 |
+
dropout: float = 0.0,
|
102 |
+
norm_num_groups: int = 32,
|
103 |
+
cross_attention_dim: Optional[int] = None,
|
104 |
+
attention_bias: bool = False,
|
105 |
+
num_vector_embeds: Optional[int] = None,
|
106 |
+
patch_size: Optional[int] = None,
|
107 |
+
activation_fn: str = "geglu",
|
108 |
+
num_embeds_ada_norm: Optional[int] = None,
|
109 |
+
use_linear_projection: bool = False,
|
110 |
+
only_cross_attention: bool = False,
|
111 |
+
double_self_attention: bool = False,
|
112 |
+
upcast_attention: bool = False,
|
113 |
+
norm_type: str = "layer_norm",
|
114 |
+
norm_elementwise_affine: bool = True,
|
115 |
+
norm_eps: float = 1e-5,
|
116 |
+
attention_type: str = "default",
|
117 |
+
):
|
118 |
+
super().__init__()
|
119 |
+
self.use_linear_projection = use_linear_projection
|
120 |
+
self.num_attention_heads = num_attention_heads
|
121 |
+
self.attention_head_dim = attention_head_dim
|
122 |
+
inner_dim = num_attention_heads * attention_head_dim
|
123 |
+
|
124 |
+
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
125 |
+
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
126 |
+
|
127 |
+
# 1. Transformer2DModel can process both standard continuous images of
|
128 |
+
# shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of
|
129 |
+
# shape `(batch_size, num_image_vectors)`
|
130 |
+
# Define whether input is continuous or discrete depending on configuration
|
131 |
+
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
|
132 |
+
self.is_input_vectorized = num_vector_embeds is not None
|
133 |
+
self.is_input_patches = in_channels is not None and patch_size is not None
|
134 |
+
|
135 |
+
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
|
136 |
+
deprecation_message = (
|
137 |
+
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
|
138 |
+
" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
|
139 |
+
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
|
140 |
+
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
|
141 |
+
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
|
142 |
+
)
|
143 |
+
deprecate(
|
144 |
+
"norm_type!=num_embeds_ada_norm",
|
145 |
+
"1.0.0",
|
146 |
+
deprecation_message,
|
147 |
+
standard_warn=False,
|
148 |
+
)
|
149 |
+
norm_type = "ada_norm"
|
150 |
+
|
151 |
+
if self.is_input_continuous and self.is_input_vectorized:
|
152 |
+
raise ValueError(
|
153 |
+
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
|
154 |
+
" sure that either `in_channels` or `num_vector_embeds` is None."
|
155 |
+
)
|
156 |
+
|
157 |
+
if self.is_input_vectorized and self.is_input_patches:
|
158 |
+
raise ValueError(
|
159 |
+
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
|
160 |
+
" sure that either `num_vector_embeds` or `num_patches` is None."
|
161 |
+
)
|
162 |
+
|
163 |
+
if (
|
164 |
+
not self.is_input_continuous
|
165 |
+
and not self.is_input_vectorized
|
166 |
+
and not self.is_input_patches
|
167 |
+
):
|
168 |
+
raise ValueError(
|
169 |
+
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
|
170 |
+
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
|
171 |
+
)
|
172 |
+
|
173 |
+
# 2. Define input layers
|
174 |
+
self.in_channels = in_channels
|
175 |
+
|
176 |
+
self.norm = torch.nn.GroupNorm(
|
177 |
+
num_groups=norm_num_groups,
|
178 |
+
num_channels=in_channels,
|
179 |
+
eps=1e-6,
|
180 |
+
affine=True,
|
181 |
+
)
|
182 |
+
if use_linear_projection:
|
183 |
+
self.proj_in = linear_cls(in_channels, inner_dim)
|
184 |
+
else:
|
185 |
+
self.proj_in = conv_cls(
|
186 |
+
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
|
187 |
+
)
|
188 |
+
|
189 |
+
# 3. Define transformers blocks
|
190 |
+
self.transformer_blocks = nn.ModuleList(
|
191 |
+
[
|
192 |
+
BasicTransformerBlock(
|
193 |
+
inner_dim,
|
194 |
+
num_attention_heads,
|
195 |
+
attention_head_dim,
|
196 |
+
dropout=dropout,
|
197 |
+
cross_attention_dim=cross_attention_dim,
|
198 |
+
activation_fn=activation_fn,
|
199 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
200 |
+
attention_bias=attention_bias,
|
201 |
+
only_cross_attention=only_cross_attention,
|
202 |
+
double_self_attention=double_self_attention,
|
203 |
+
upcast_attention=upcast_attention,
|
204 |
+
norm_type=norm_type,
|
205 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
206 |
+
norm_eps=norm_eps,
|
207 |
+
attention_type=attention_type,
|
208 |
+
)
|
209 |
+
for d in range(num_layers)
|
210 |
+
]
|
211 |
+
)
|
212 |
+
|
213 |
+
# 4. Define output layers
|
214 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
215 |
+
# TODO: should use out_channels for continuous projections
|
216 |
+
if use_linear_projection:
|
217 |
+
self.proj_out = linear_cls(inner_dim, in_channels)
|
218 |
+
else:
|
219 |
+
self.proj_out = conv_cls(
|
220 |
+
inner_dim, in_channels, kernel_size=1, stride=1, padding=0
|
221 |
+
)
|
222 |
+
|
223 |
+
# 5. PixArt-Alpha blocks.
|
224 |
+
self.adaln_single = None
|
225 |
+
self.use_additional_conditions = False
|
226 |
+
if norm_type == "ada_norm_single":
|
227 |
+
self.use_additional_conditions = self.config.sample_size == 128
|
228 |
+
# TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
|
229 |
+
# additional conditions until we find better name
|
230 |
+
self.adaln_single = AdaLayerNormSingle(
|
231 |
+
inner_dim, use_additional_conditions=self.use_additional_conditions
|
232 |
+
)
|
233 |
+
|
234 |
+
self.caption_projection = None
|
235 |
+
|
236 |
+
self.gradient_checkpointing = False
|
237 |
+
|
238 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
239 |
+
if hasattr(module, "gradient_checkpointing"):
|
240 |
+
module.gradient_checkpointing = value
|
241 |
+
|
242 |
+
def forward(
|
243 |
+
self,
|
244 |
+
hidden_states: torch.Tensor,
|
245 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
246 |
+
timestep: Optional[torch.LongTensor] = None,
|
247 |
+
_added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
248 |
+
class_labels: Optional[torch.LongTensor] = None,
|
249 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
250 |
+
attention_mask: Optional[torch.Tensor] = None,
|
251 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
252 |
+
return_dict: bool = True,
|
253 |
+
):
|
254 |
+
"""
|
255 |
+
The [`Transformer2DModel`] forward method.
|
256 |
+
|
257 |
+
Args:
|
258 |
+
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete,
|
259 |
+
`torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
260 |
+
Input `hidden_states`.
|
261 |
+
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
262 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
263 |
+
self-attention.
|
264 |
+
timestep ( `torch.LongTensor`, *optional*):
|
265 |
+
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
266 |
+
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
267 |
+
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
268 |
+
`AdaLayerZeroNorm`.
|
269 |
+
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
|
270 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
271 |
+
`self.processor` in
|
272 |
+
[diffusers.models.attention_processor]
|
273 |
+
(https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
274 |
+
attention_mask ( `torch.Tensor`, *optional*):
|
275 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
276 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
277 |
+
negative values to the attention scores corresponding to "discard" tokens.
|
278 |
+
encoder_attention_mask ( `torch.Tensor`, *optional*):
|
279 |
+
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
|
280 |
+
|
281 |
+
* Mask `(batch, sequence_length)` True = keep, False = discard.
|
282 |
+
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
|
283 |
+
|
284 |
+
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
285 |
+
above. This bias will be added to the cross-attention scores.
|
286 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
287 |
+
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
288 |
+
tuple.
|
289 |
+
|
290 |
+
Returns:
|
291 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
292 |
+
`tuple` where the first element is the sample tensor.
|
293 |
+
"""
|
294 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
295 |
+
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
296 |
+
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
297 |
+
# expects mask of shape:
|
298 |
+
# [batch, key_tokens]
|
299 |
+
# adds singleton query_tokens dimension:
|
300 |
+
# [batch, 1, key_tokens]
|
301 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
302 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
303 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
304 |
+
if attention_mask is not None and attention_mask.ndim == 2:
|
305 |
+
# assume that mask is expressed as:
|
306 |
+
# (1 = keep, 0 = discard)
|
307 |
+
# convert mask into a bias that can be added to attention scores:
|
308 |
+
# (keep = +0, discard = -10000.0)
|
309 |
+
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
310 |
+
attention_mask = attention_mask.unsqueeze(1)
|
311 |
+
|
312 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
313 |
+
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
314 |
+
encoder_attention_mask = (
|
315 |
+
1 - encoder_attention_mask.to(hidden_states.dtype)
|
316 |
+
) * -10000.0
|
317 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
318 |
+
|
319 |
+
# Retrieve lora scale.
|
320 |
+
lora_scale = (
|
321 |
+
cross_attention_kwargs.get("scale", 1.0)
|
322 |
+
if cross_attention_kwargs is not None
|
323 |
+
else 1.0
|
324 |
+
)
|
325 |
+
|
326 |
+
# 1. Input
|
327 |
+
batch, _, height, width = hidden_states.shape
|
328 |
+
residual = hidden_states
|
329 |
+
|
330 |
+
hidden_states = self.norm(hidden_states)
|
331 |
+
if not self.use_linear_projection:
|
332 |
+
hidden_states = (
|
333 |
+
self.proj_in(hidden_states, scale=lora_scale)
|
334 |
+
if not USE_PEFT_BACKEND
|
335 |
+
else self.proj_in(hidden_states)
|
336 |
+
)
|
337 |
+
inner_dim = hidden_states.shape[1]
|
338 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
|
339 |
+
batch, height * width, inner_dim
|
340 |
+
)
|
341 |
+
else:
|
342 |
+
inner_dim = hidden_states.shape[1]
|
343 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
|
344 |
+
batch, height * width, inner_dim
|
345 |
+
)
|
346 |
+
hidden_states = (
|
347 |
+
self.proj_in(hidden_states, scale=lora_scale)
|
348 |
+
if not USE_PEFT_BACKEND
|
349 |
+
else self.proj_in(hidden_states)
|
350 |
+
)
|
351 |
+
|
352 |
+
# 2. Blocks
|
353 |
+
if self.caption_projection is not None:
|
354 |
+
batch_size = hidden_states.shape[0]
|
355 |
+
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
356 |
+
encoder_hidden_states = encoder_hidden_states.view(
|
357 |
+
batch_size, -1, hidden_states.shape[-1]
|
358 |
+
)
|
359 |
+
|
360 |
+
ref_feature = hidden_states.reshape(batch, height, width, inner_dim)
|
361 |
+
for block in self.transformer_blocks:
|
362 |
+
if self.training and self.gradient_checkpointing:
|
363 |
+
|
364 |
+
def create_custom_forward(module, return_dict=None):
|
365 |
+
def custom_forward(*inputs):
|
366 |
+
if return_dict is not None:
|
367 |
+
return module(*inputs, return_dict=return_dict)
|
368 |
+
|
369 |
+
return module(*inputs)
|
370 |
+
|
371 |
+
return custom_forward
|
372 |
+
|
373 |
+
ckpt_kwargs: Dict[str, Any] = (
|
374 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
375 |
+
)
|
376 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
377 |
+
create_custom_forward(block),
|
378 |
+
hidden_states,
|
379 |
+
attention_mask,
|
380 |
+
encoder_hidden_states,
|
381 |
+
encoder_attention_mask,
|
382 |
+
timestep,
|
383 |
+
cross_attention_kwargs,
|
384 |
+
class_labels,
|
385 |
+
**ckpt_kwargs,
|
386 |
+
)
|
387 |
+
else:
|
388 |
+
hidden_states = block(
|
389 |
+
hidden_states, # shape [5, 4096, 320]
|
390 |
+
attention_mask=attention_mask,
|
391 |
+
encoder_hidden_states=encoder_hidden_states, # shape [1,4,768]
|
392 |
+
encoder_attention_mask=encoder_attention_mask,
|
393 |
+
timestep=timestep,
|
394 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
395 |
+
class_labels=class_labels,
|
396 |
+
)
|
397 |
+
|
398 |
+
# 3. Output
|
399 |
+
output = None
|
400 |
+
if self.is_input_continuous:
|
401 |
+
if not self.use_linear_projection:
|
402 |
+
hidden_states = (
|
403 |
+
hidden_states.reshape(batch, height, width, inner_dim)
|
404 |
+
.permute(0, 3, 1, 2)
|
405 |
+
.contiguous()
|
406 |
+
)
|
407 |
+
hidden_states = (
|
408 |
+
self.proj_out(hidden_states, scale=lora_scale)
|
409 |
+
if not USE_PEFT_BACKEND
|
410 |
+
else self.proj_out(hidden_states)
|
411 |
+
)
|
412 |
+
else:
|
413 |
+
hidden_states = (
|
414 |
+
self.proj_out(hidden_states, scale=lora_scale)
|
415 |
+
if not USE_PEFT_BACKEND
|
416 |
+
else self.proj_out(hidden_states)
|
417 |
+
)
|
418 |
+
hidden_states = (
|
419 |
+
hidden_states.reshape(batch, height, width, inner_dim)
|
420 |
+
.permute(0, 3, 1, 2)
|
421 |
+
.contiguous()
|
422 |
+
)
|
423 |
+
|
424 |
+
output = hidden_states + residual
|
425 |
+
if not return_dict:
|
426 |
+
return (output, ref_feature)
|
427 |
+
|
428 |
+
return Transformer2DModelOutput(sample=output, ref_feature=ref_feature)
|
joyhallo/models/transformer_3d.py
ADDED
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This module implements the Transformer3DModel, a PyTorch model designed for processing
|
3 |
+
3D data such as videos. It extends ModelMixin and ConfigMixin to provide a transformer
|
4 |
+
model with support for gradient checkpointing and various types of attention mechanisms.
|
5 |
+
The model can be configured with different parameters such as the number of attention heads,
|
6 |
+
attention head dimension, and the number of layers. It also supports the use of audio modules
|
7 |
+
for enhanced feature extraction from video data.
|
8 |
+
"""
|
9 |
+
|
10 |
+
from dataclasses import dataclass
|
11 |
+
from typing import Optional
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
15 |
+
from diffusers.models import ModelMixin
|
16 |
+
from diffusers.utils import BaseOutput
|
17 |
+
from einops import rearrange, repeat
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
from .attention import (AudioTemporalBasicTransformerBlock,
|
21 |
+
TemporalBasicTransformerBlock)
|
22 |
+
|
23 |
+
|
24 |
+
@dataclass
|
25 |
+
class Transformer3DModelOutput(BaseOutput):
|
26 |
+
"""
|
27 |
+
The output of the [`Transformer3DModel`].
|
28 |
+
|
29 |
+
Attributes:
|
30 |
+
sample (`torch.FloatTensor`):
|
31 |
+
The output tensor from the transformer model, which is the result of processing the input
|
32 |
+
hidden states through the transformer blocks and any subsequent layers.
|
33 |
+
"""
|
34 |
+
sample: torch.FloatTensor
|
35 |
+
|
36 |
+
|
37 |
+
class Transformer3DModel(ModelMixin, ConfigMixin):
|
38 |
+
"""
|
39 |
+
Transformer3DModel is a PyTorch model that extends `ModelMixin` and `ConfigMixin` to create a 3D transformer model.
|
40 |
+
It implements the forward pass for processing input hidden states, encoder hidden states, and various types of attention masks.
|
41 |
+
The model supports gradient checkpointing, which can be enabled by calling the `enable_gradient_checkpointing()` method.
|
42 |
+
"""
|
43 |
+
_supports_gradient_checkpointing = True
|
44 |
+
|
45 |
+
@register_to_config
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
num_attention_heads: int = 16,
|
49 |
+
attention_head_dim: int = 88,
|
50 |
+
in_channels: Optional[int] = None,
|
51 |
+
num_layers: int = 1,
|
52 |
+
dropout: float = 0.0,
|
53 |
+
norm_num_groups: int = 32,
|
54 |
+
cross_attention_dim: Optional[int] = None,
|
55 |
+
attention_bias: bool = False,
|
56 |
+
activation_fn: str = "geglu",
|
57 |
+
num_embeds_ada_norm: Optional[int] = None,
|
58 |
+
use_linear_projection: bool = False,
|
59 |
+
only_cross_attention: bool = False,
|
60 |
+
upcast_attention: bool = False,
|
61 |
+
unet_use_cross_frame_attention=None,
|
62 |
+
unet_use_temporal_attention=None,
|
63 |
+
use_audio_module=False,
|
64 |
+
depth=0,
|
65 |
+
unet_block_name=None,
|
66 |
+
stack_enable_blocks_name = None,
|
67 |
+
stack_enable_blocks_depth = None,
|
68 |
+
):
|
69 |
+
super().__init__()
|
70 |
+
self.use_linear_projection = use_linear_projection
|
71 |
+
self.num_attention_heads = num_attention_heads
|
72 |
+
self.attention_head_dim = attention_head_dim
|
73 |
+
inner_dim = num_attention_heads * attention_head_dim
|
74 |
+
self.use_audio_module = use_audio_module
|
75 |
+
# Define input layers
|
76 |
+
self.in_channels = in_channels
|
77 |
+
|
78 |
+
self.norm = torch.nn.GroupNorm(
|
79 |
+
num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
|
80 |
+
)
|
81 |
+
if use_linear_projection:
|
82 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
83 |
+
else:
|
84 |
+
self.proj_in = nn.Conv2d(
|
85 |
+
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
|
86 |
+
)
|
87 |
+
|
88 |
+
if use_audio_module:
|
89 |
+
self.transformer_blocks = nn.ModuleList(
|
90 |
+
[
|
91 |
+
AudioTemporalBasicTransformerBlock(
|
92 |
+
inner_dim,
|
93 |
+
num_attention_heads,
|
94 |
+
attention_head_dim,
|
95 |
+
dropout=dropout,
|
96 |
+
cross_attention_dim=cross_attention_dim,
|
97 |
+
activation_fn=activation_fn,
|
98 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
99 |
+
attention_bias=attention_bias,
|
100 |
+
only_cross_attention=only_cross_attention,
|
101 |
+
upcast_attention=upcast_attention,
|
102 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
103 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
104 |
+
depth=depth,
|
105 |
+
unet_block_name=unet_block_name,
|
106 |
+
stack_enable_blocks_name=stack_enable_blocks_name,
|
107 |
+
stack_enable_blocks_depth=stack_enable_blocks_depth,
|
108 |
+
)
|
109 |
+
for d in range(num_layers)
|
110 |
+
]
|
111 |
+
)
|
112 |
+
else:
|
113 |
+
# Define transformers blocks
|
114 |
+
self.transformer_blocks = nn.ModuleList(
|
115 |
+
[
|
116 |
+
TemporalBasicTransformerBlock(
|
117 |
+
inner_dim,
|
118 |
+
num_attention_heads,
|
119 |
+
attention_head_dim,
|
120 |
+
dropout=dropout,
|
121 |
+
cross_attention_dim=cross_attention_dim,
|
122 |
+
activation_fn=activation_fn,
|
123 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
124 |
+
attention_bias=attention_bias,
|
125 |
+
only_cross_attention=only_cross_attention,
|
126 |
+
upcast_attention=upcast_attention,
|
127 |
+
)
|
128 |
+
for d in range(num_layers)
|
129 |
+
]
|
130 |
+
)
|
131 |
+
|
132 |
+
# 4. Define output layers
|
133 |
+
if use_linear_projection:
|
134 |
+
self.proj_out = nn.Linear(in_channels, inner_dim)
|
135 |
+
else:
|
136 |
+
self.proj_out = nn.Conv2d(
|
137 |
+
inner_dim, in_channels, kernel_size=1, stride=1, padding=0
|
138 |
+
)
|
139 |
+
|
140 |
+
self.gradient_checkpointing = False
|
141 |
+
|
142 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
143 |
+
if hasattr(module, "gradient_checkpointing"):
|
144 |
+
module.gradient_checkpointing = value
|
145 |
+
|
146 |
+
def forward(
|
147 |
+
self,
|
148 |
+
hidden_states,
|
149 |
+
encoder_hidden_states=None,
|
150 |
+
attention_mask=None,
|
151 |
+
full_mask=None,
|
152 |
+
face_mask=None,
|
153 |
+
lip_mask=None,
|
154 |
+
motion_scale=None,
|
155 |
+
timestep=None,
|
156 |
+
return_dict: bool = True,
|
157 |
+
):
|
158 |
+
"""
|
159 |
+
Forward pass for the Transformer3DModel.
|
160 |
+
|
161 |
+
Args:
|
162 |
+
hidden_states (torch.Tensor): The input hidden states.
|
163 |
+
encoder_hidden_states (torch.Tensor, optional): The input encoder hidden states.
|
164 |
+
attention_mask (torch.Tensor, optional): The attention mask.
|
165 |
+
full_mask (torch.Tensor, optional): The full mask.
|
166 |
+
face_mask (torch.Tensor, optional): The face mask.
|
167 |
+
lip_mask (torch.Tensor, optional): The lip mask.
|
168 |
+
timestep (int, optional): The current timestep.
|
169 |
+
return_dict (bool, optional): Whether to return a dictionary or a tuple.
|
170 |
+
|
171 |
+
Returns:
|
172 |
+
output (Union[Tuple, BaseOutput]): The output of the Transformer3DModel.
|
173 |
+
"""
|
174 |
+
# Input
|
175 |
+
assert (
|
176 |
+
hidden_states.dim() == 5
|
177 |
+
), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
|
178 |
+
video_length = hidden_states.shape[2]
|
179 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
180 |
+
|
181 |
+
# TODO
|
182 |
+
if self.use_audio_module:
|
183 |
+
encoder_hidden_states = rearrange(
|
184 |
+
encoder_hidden_states,
|
185 |
+
"bs f margin dim -> (bs f) margin dim",
|
186 |
+
)
|
187 |
+
else:
|
188 |
+
if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
|
189 |
+
encoder_hidden_states = repeat(
|
190 |
+
encoder_hidden_states, "b n c -> (b f) n c", f=video_length
|
191 |
+
)
|
192 |
+
|
193 |
+
batch, _, height, weight = hidden_states.shape
|
194 |
+
residual = hidden_states
|
195 |
+
|
196 |
+
hidden_states = self.norm(hidden_states)
|
197 |
+
if not self.use_linear_projection:
|
198 |
+
hidden_states = self.proj_in(hidden_states)
|
199 |
+
inner_dim = hidden_states.shape[1]
|
200 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
|
201 |
+
batch, height * weight, inner_dim
|
202 |
+
)
|
203 |
+
else:
|
204 |
+
inner_dim = hidden_states.shape[1]
|
205 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
|
206 |
+
batch, height * weight, inner_dim
|
207 |
+
)
|
208 |
+
hidden_states = self.proj_in(hidden_states)
|
209 |
+
|
210 |
+
# Blocks
|
211 |
+
motion_frames = []
|
212 |
+
for _, block in enumerate(self.transformer_blocks):
|
213 |
+
if isinstance(block, TemporalBasicTransformerBlock):
|
214 |
+
hidden_states, motion_frame_fea = block(
|
215 |
+
hidden_states,
|
216 |
+
encoder_hidden_states=encoder_hidden_states,
|
217 |
+
timestep=timestep,
|
218 |
+
video_length=video_length,
|
219 |
+
)
|
220 |
+
motion_frames.append(motion_frame_fea)
|
221 |
+
else:
|
222 |
+
hidden_states = block(
|
223 |
+
hidden_states, # shape [2, 4096, 320]
|
224 |
+
encoder_hidden_states=encoder_hidden_states, # shape [2, 20, 640]
|
225 |
+
attention_mask=attention_mask,
|
226 |
+
full_mask=full_mask,
|
227 |
+
face_mask=face_mask,
|
228 |
+
lip_mask=lip_mask,
|
229 |
+
timestep=timestep,
|
230 |
+
video_length=video_length,
|
231 |
+
motion_scale=motion_scale,
|
232 |
+
)
|
233 |
+
|
234 |
+
# Output
|
235 |
+
if not self.use_linear_projection:
|
236 |
+
hidden_states = (
|
237 |
+
hidden_states.reshape(batch, height, weight, inner_dim)
|
238 |
+
.permute(0, 3, 1, 2)
|
239 |
+
.contiguous()
|
240 |
+
)
|
241 |
+
hidden_states = self.proj_out(hidden_states)
|
242 |
+
else:
|
243 |
+
hidden_states = self.proj_out(hidden_states)
|
244 |
+
hidden_states = (
|
245 |
+
hidden_states.reshape(batch, height, weight, inner_dim)
|
246 |
+
.permute(0, 3, 1, 2)
|
247 |
+
.contiguous()
|
248 |
+
)
|
249 |
+
|
250 |
+
output = hidden_states + residual
|
251 |
+
|
252 |
+
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
|
253 |
+
if not return_dict:
|
254 |
+
return (output, motion_frames)
|
255 |
+
|
256 |
+
return Transformer3DModelOutput(sample=output)
|
joyhallo/models/unet_2d_blocks.py
ADDED
@@ -0,0 +1,1340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file defines the 2D blocks for the UNet model in a PyTorch implementation.
|
3 |
+
The UNet model is a popular architecture for image segmentation tasks,
|
4 |
+
which consists of an encoder, a decoder, and a skip connection mechanism.
|
5 |
+
The 2D blocks in this file include various types of layers, such as ResNet blocks,
|
6 |
+
Transformer blocks, and cross-attention blocks,
|
7 |
+
which are used to build the encoder and decoder parts of the UNet model.
|
8 |
+
The AutoencoderTinyBlock class is a simple autoencoder block for tiny models,
|
9 |
+
and the UNetMidBlock2D and CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D,
|
10 |
+
and UpBlock2D classes are used for the middle and decoder parts of the UNet model.
|
11 |
+
The classes and functions in this file provide a flexible and modular way
|
12 |
+
to construct the UNet model for different image segmentation tasks.
|
13 |
+
"""
|
14 |
+
|
15 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from diffusers.models.activations import get_activation
|
19 |
+
from diffusers.models.attention_processor import Attention
|
20 |
+
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
|
21 |
+
from diffusers.models.transformers.dual_transformer_2d import \
|
22 |
+
DualTransformer2DModel
|
23 |
+
from diffusers.utils import is_torch_version, logging
|
24 |
+
from diffusers.utils.torch_utils import apply_freeu
|
25 |
+
from torch import nn
|
26 |
+
|
27 |
+
from .transformer_2d import Transformer2DModel
|
28 |
+
|
29 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
30 |
+
|
31 |
+
|
32 |
+
def get_down_block(
|
33 |
+
down_block_type: str,
|
34 |
+
num_layers: int,
|
35 |
+
in_channels: int,
|
36 |
+
out_channels: int,
|
37 |
+
temb_channels: int,
|
38 |
+
add_downsample: bool,
|
39 |
+
resnet_eps: float,
|
40 |
+
resnet_act_fn: str,
|
41 |
+
transformer_layers_per_block: int = 1,
|
42 |
+
num_attention_heads: Optional[int] = None,
|
43 |
+
resnet_groups: Optional[int] = None,
|
44 |
+
cross_attention_dim: Optional[int] = None,
|
45 |
+
downsample_padding: Optional[int] = None,
|
46 |
+
dual_cross_attention: bool = False,
|
47 |
+
use_linear_projection: bool = False,
|
48 |
+
only_cross_attention: bool = False,
|
49 |
+
upcast_attention: bool = False,
|
50 |
+
resnet_time_scale_shift: str = "default",
|
51 |
+
attention_type: str = "default",
|
52 |
+
attention_head_dim: Optional[int] = None,
|
53 |
+
dropout: float = 0.0,
|
54 |
+
):
|
55 |
+
""" This function creates and returns a UpBlock2D or CrossAttnUpBlock2D object based on the given up_block_type.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
up_block_type (str): The type of up block to create. Must be either "UpBlock2D" or "CrossAttnUpBlock2D".
|
59 |
+
num_layers (int): The number of layers in the ResNet block.
|
60 |
+
in_channels (int): The number of input channels.
|
61 |
+
out_channels (int): The number of output channels.
|
62 |
+
prev_output_channel (int): The number of channels in the previous output.
|
63 |
+
temb_channels (int): The number of channels in the token embedding.
|
64 |
+
add_upsample (bool): Whether to add an upsample layer after the ResNet block. Defaults to True.
|
65 |
+
resnet_eps (float): The epsilon value for the ResNet block. Defaults to 1e-6.
|
66 |
+
resnet_act_fn (str): The activation function to use in the ResNet block. Defaults to "swish".
|
67 |
+
resnet_groups (int): The number of groups in the ResNet block. Defaults to 32.
|
68 |
+
resnet_pre_norm (bool): Whether to use pre-normalization in the ResNet block. Defaults to True.
|
69 |
+
output_scale_factor (float): The scale factor to apply to the output. Defaults to 1.0.
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
nn.Module: The created UpBlock2D or CrossAttnUpBlock2D object.
|
73 |
+
"""
|
74 |
+
# If attn head dim is not defined, we default it to the number of heads
|
75 |
+
if attention_head_dim is None:
|
76 |
+
logger.warning("It is recommended to provide `attention_head_dim` when calling `get_down_block`.")
|
77 |
+
logger.warning(f"Defaulting `attention_head_dim` to {num_attention_heads}.")
|
78 |
+
attention_head_dim = num_attention_heads
|
79 |
+
|
80 |
+
down_block_type = (
|
81 |
+
down_block_type[7:]
|
82 |
+
if down_block_type.startswith("UNetRes")
|
83 |
+
else down_block_type
|
84 |
+
)
|
85 |
+
if down_block_type == "DownBlock2D":
|
86 |
+
return DownBlock2D(
|
87 |
+
num_layers=num_layers,
|
88 |
+
in_channels=in_channels,
|
89 |
+
out_channels=out_channels,
|
90 |
+
temb_channels=temb_channels,
|
91 |
+
dropout=dropout,
|
92 |
+
add_downsample=add_downsample,
|
93 |
+
resnet_eps=resnet_eps,
|
94 |
+
resnet_act_fn=resnet_act_fn,
|
95 |
+
resnet_groups=resnet_groups,
|
96 |
+
downsample_padding=downsample_padding,
|
97 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
98 |
+
)
|
99 |
+
|
100 |
+
if down_block_type == "CrossAttnDownBlock2D":
|
101 |
+
if cross_attention_dim is None:
|
102 |
+
raise ValueError(
|
103 |
+
"cross_attention_dim must be specified for CrossAttnDownBlock2D"
|
104 |
+
)
|
105 |
+
return CrossAttnDownBlock2D(
|
106 |
+
num_layers=num_layers,
|
107 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
108 |
+
in_channels=in_channels,
|
109 |
+
out_channels=out_channels,
|
110 |
+
temb_channels=temb_channels,
|
111 |
+
dropout=dropout,
|
112 |
+
add_downsample=add_downsample,
|
113 |
+
resnet_eps=resnet_eps,
|
114 |
+
resnet_act_fn=resnet_act_fn,
|
115 |
+
resnet_groups=resnet_groups,
|
116 |
+
downsample_padding=downsample_padding,
|
117 |
+
cross_attention_dim=cross_attention_dim,
|
118 |
+
num_attention_heads=num_attention_heads,
|
119 |
+
dual_cross_attention=dual_cross_attention,
|
120 |
+
use_linear_projection=use_linear_projection,
|
121 |
+
only_cross_attention=only_cross_attention,
|
122 |
+
upcast_attention=upcast_attention,
|
123 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
124 |
+
attention_type=attention_type,
|
125 |
+
)
|
126 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
127 |
+
|
128 |
+
|
129 |
+
def get_up_block(
|
130 |
+
up_block_type: str,
|
131 |
+
num_layers: int,
|
132 |
+
in_channels: int,
|
133 |
+
out_channels: int,
|
134 |
+
prev_output_channel: int,
|
135 |
+
temb_channels: int,
|
136 |
+
add_upsample: bool,
|
137 |
+
resnet_eps: float,
|
138 |
+
resnet_act_fn: str,
|
139 |
+
resolution_idx: Optional[int] = None,
|
140 |
+
transformer_layers_per_block: int = 1,
|
141 |
+
num_attention_heads: Optional[int] = None,
|
142 |
+
resnet_groups: Optional[int] = None,
|
143 |
+
cross_attention_dim: Optional[int] = None,
|
144 |
+
dual_cross_attention: bool = False,
|
145 |
+
use_linear_projection: bool = False,
|
146 |
+
only_cross_attention: bool = False,
|
147 |
+
upcast_attention: bool = False,
|
148 |
+
resnet_time_scale_shift: str = "default",
|
149 |
+
attention_type: str = "default",
|
150 |
+
attention_head_dim: Optional[int] = None,
|
151 |
+
dropout: float = 0.0,
|
152 |
+
) -> nn.Module:
|
153 |
+
""" This function ...
|
154 |
+
Args:
|
155 |
+
Returns:
|
156 |
+
"""
|
157 |
+
# If attn head dim is not defined, we default it to the number of heads
|
158 |
+
if attention_head_dim is None:
|
159 |
+
logger.warning("It is recommended to provide `attention_head_dim` when calling `get_up_block`.")
|
160 |
+
logger.warning(f"Defaulting `attention_head_dim` to {num_attention_heads}.")
|
161 |
+
attention_head_dim = num_attention_heads
|
162 |
+
|
163 |
+
up_block_type = (
|
164 |
+
up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
165 |
+
)
|
166 |
+
if up_block_type == "UpBlock2D":
|
167 |
+
return UpBlock2D(
|
168 |
+
num_layers=num_layers,
|
169 |
+
in_channels=in_channels,
|
170 |
+
out_channels=out_channels,
|
171 |
+
prev_output_channel=prev_output_channel,
|
172 |
+
temb_channels=temb_channels,
|
173 |
+
resolution_idx=resolution_idx,
|
174 |
+
dropout=dropout,
|
175 |
+
add_upsample=add_upsample,
|
176 |
+
resnet_eps=resnet_eps,
|
177 |
+
resnet_act_fn=resnet_act_fn,
|
178 |
+
resnet_groups=resnet_groups,
|
179 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
180 |
+
)
|
181 |
+
if up_block_type == "CrossAttnUpBlock2D":
|
182 |
+
if cross_attention_dim is None:
|
183 |
+
raise ValueError(
|
184 |
+
"cross_attention_dim must be specified for CrossAttnUpBlock2D"
|
185 |
+
)
|
186 |
+
return CrossAttnUpBlock2D(
|
187 |
+
num_layers=num_layers,
|
188 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
189 |
+
in_channels=in_channels,
|
190 |
+
out_channels=out_channels,
|
191 |
+
prev_output_channel=prev_output_channel,
|
192 |
+
temb_channels=temb_channels,
|
193 |
+
resolution_idx=resolution_idx,
|
194 |
+
dropout=dropout,
|
195 |
+
add_upsample=add_upsample,
|
196 |
+
resnet_eps=resnet_eps,
|
197 |
+
resnet_act_fn=resnet_act_fn,
|
198 |
+
resnet_groups=resnet_groups,
|
199 |
+
cross_attention_dim=cross_attention_dim,
|
200 |
+
num_attention_heads=num_attention_heads,
|
201 |
+
dual_cross_attention=dual_cross_attention,
|
202 |
+
use_linear_projection=use_linear_projection,
|
203 |
+
only_cross_attention=only_cross_attention,
|
204 |
+
upcast_attention=upcast_attention,
|
205 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
206 |
+
attention_type=attention_type,
|
207 |
+
)
|
208 |
+
|
209 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
210 |
+
|
211 |
+
|
212 |
+
class AutoencoderTinyBlock(nn.Module):
|
213 |
+
"""
|
214 |
+
Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU
|
215 |
+
blocks.
|
216 |
+
|
217 |
+
Args:
|
218 |
+
in_channels (`int`): The number of input channels.
|
219 |
+
out_channels (`int`): The number of output channels.
|
220 |
+
act_fn (`str`):
|
221 |
+
` The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.
|
222 |
+
|
223 |
+
Returns:
|
224 |
+
`torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
|
225 |
+
`out_channels`.
|
226 |
+
"""
|
227 |
+
|
228 |
+
def __init__(self, in_channels: int, out_channels: int, act_fn: str):
|
229 |
+
super().__init__()
|
230 |
+
act_fn = get_activation(act_fn)
|
231 |
+
self.conv = nn.Sequential(
|
232 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
|
233 |
+
act_fn,
|
234 |
+
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
|
235 |
+
act_fn,
|
236 |
+
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
|
237 |
+
)
|
238 |
+
self.skip = (
|
239 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
|
240 |
+
if in_channels != out_channels
|
241 |
+
else nn.Identity()
|
242 |
+
)
|
243 |
+
self.fuse = nn.ReLU()
|
244 |
+
|
245 |
+
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
|
246 |
+
"""
|
247 |
+
Forward pass of the AutoencoderTinyBlock class.
|
248 |
+
|
249 |
+
Parameters:
|
250 |
+
x (torch.FloatTensor): The input tensor to the AutoencoderTinyBlock.
|
251 |
+
|
252 |
+
Returns:
|
253 |
+
torch.FloatTensor: The output tensor after passing through the AutoencoderTinyBlock.
|
254 |
+
"""
|
255 |
+
return self.fuse(self.conv(x) + self.skip(x))
|
256 |
+
|
257 |
+
|
258 |
+
class UNetMidBlock2D(nn.Module):
|
259 |
+
"""
|
260 |
+
A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
|
261 |
+
|
262 |
+
Args:
|
263 |
+
in_channels (`int`): The number of input channels.
|
264 |
+
temb_channels (`int`): The number of temporal embedding channels.
|
265 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
|
266 |
+
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
|
267 |
+
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
268 |
+
resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
|
269 |
+
The type of normalization to apply to the time embeddings. This can help to improve the performance of the
|
270 |
+
model on tasks with long-range temporal dependencies.
|
271 |
+
resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
|
272 |
+
resnet_groups (`int`, *optional*, defaults to 32):
|
273 |
+
The number of groups to use in the group normalization layers of the resnet blocks.
|
274 |
+
attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
|
275 |
+
resnet_pre_norm (`bool`, *optional*, defaults to `True`):
|
276 |
+
Whether to use pre-normalization for the resnet blocks.
|
277 |
+
add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
|
278 |
+
attention_head_dim (`int`, *optional*, defaults to 1):
|
279 |
+
Dimension of a single attention head. The number of attention heads is determined based on this value and
|
280 |
+
the number of input channels.
|
281 |
+
output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
|
282 |
+
|
283 |
+
Returns:
|
284 |
+
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
285 |
+
in_channels, height, width)`.
|
286 |
+
|
287 |
+
"""
|
288 |
+
|
289 |
+
def __init__(
|
290 |
+
self,
|
291 |
+
in_channels: int,
|
292 |
+
temb_channels: int,
|
293 |
+
dropout: float = 0.0,
|
294 |
+
num_layers: int = 1,
|
295 |
+
resnet_eps: float = 1e-6,
|
296 |
+
resnet_time_scale_shift: str = "default", # default, spatial
|
297 |
+
resnet_act_fn: str = "swish",
|
298 |
+
resnet_groups: int = 32,
|
299 |
+
attn_groups: Optional[int] = None,
|
300 |
+
resnet_pre_norm: bool = True,
|
301 |
+
add_attention: bool = True,
|
302 |
+
attention_head_dim: int = 1,
|
303 |
+
output_scale_factor: float = 1.0,
|
304 |
+
):
|
305 |
+
super().__init__()
|
306 |
+
resnet_groups = (
|
307 |
+
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
308 |
+
)
|
309 |
+
self.add_attention = add_attention
|
310 |
+
|
311 |
+
if attn_groups is None:
|
312 |
+
attn_groups = (
|
313 |
+
resnet_groups if resnet_time_scale_shift == "default" else None
|
314 |
+
)
|
315 |
+
|
316 |
+
# there is always at least one resnet
|
317 |
+
resnets = [
|
318 |
+
ResnetBlock2D(
|
319 |
+
in_channels=in_channels,
|
320 |
+
out_channels=in_channels,
|
321 |
+
temb_channels=temb_channels,
|
322 |
+
eps=resnet_eps,
|
323 |
+
groups=resnet_groups,
|
324 |
+
dropout=dropout,
|
325 |
+
time_embedding_norm=resnet_time_scale_shift,
|
326 |
+
non_linearity=resnet_act_fn,
|
327 |
+
output_scale_factor=output_scale_factor,
|
328 |
+
pre_norm=resnet_pre_norm,
|
329 |
+
)
|
330 |
+
]
|
331 |
+
attentions = []
|
332 |
+
|
333 |
+
if attention_head_dim is None:
|
334 |
+
logger.warning(
|
335 |
+
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
|
336 |
+
)
|
337 |
+
attention_head_dim = in_channels
|
338 |
+
|
339 |
+
for _ in range(num_layers):
|
340 |
+
if self.add_attention:
|
341 |
+
attentions.append(
|
342 |
+
Attention(
|
343 |
+
in_channels,
|
344 |
+
heads=in_channels // attention_head_dim,
|
345 |
+
dim_head=attention_head_dim,
|
346 |
+
rescale_output_factor=output_scale_factor,
|
347 |
+
eps=resnet_eps,
|
348 |
+
norm_num_groups=attn_groups,
|
349 |
+
spatial_norm_dim=(
|
350 |
+
temb_channels
|
351 |
+
if resnet_time_scale_shift == "spatial"
|
352 |
+
else None
|
353 |
+
),
|
354 |
+
residual_connection=True,
|
355 |
+
bias=True,
|
356 |
+
upcast_softmax=True,
|
357 |
+
_from_deprecated_attn_block=True,
|
358 |
+
)
|
359 |
+
)
|
360 |
+
else:
|
361 |
+
attentions.append(None)
|
362 |
+
|
363 |
+
resnets.append(
|
364 |
+
ResnetBlock2D(
|
365 |
+
in_channels=in_channels,
|
366 |
+
out_channels=in_channels,
|
367 |
+
temb_channels=temb_channels,
|
368 |
+
eps=resnet_eps,
|
369 |
+
groups=resnet_groups,
|
370 |
+
dropout=dropout,
|
371 |
+
time_embedding_norm=resnet_time_scale_shift,
|
372 |
+
non_linearity=resnet_act_fn,
|
373 |
+
output_scale_factor=output_scale_factor,
|
374 |
+
pre_norm=resnet_pre_norm,
|
375 |
+
)
|
376 |
+
)
|
377 |
+
|
378 |
+
self.attentions = nn.ModuleList(attentions)
|
379 |
+
self.resnets = nn.ModuleList(resnets)
|
380 |
+
|
381 |
+
def forward(
|
382 |
+
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None
|
383 |
+
) -> torch.FloatTensor:
|
384 |
+
"""
|
385 |
+
Forward pass of the UNetMidBlock2D class.
|
386 |
+
|
387 |
+
Args:
|
388 |
+
hidden_states (torch.FloatTensor): The input tensor to the UNetMidBlock2D.
|
389 |
+
temb (Optional[torch.FloatTensor], optional): The token embedding tensor. Defaults to None.
|
390 |
+
|
391 |
+
Returns:
|
392 |
+
torch.FloatTensor: The output tensor after passing through the UNetMidBlock2D.
|
393 |
+
"""
|
394 |
+
# Your implementation here
|
395 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
396 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
397 |
+
if attn is not None:
|
398 |
+
hidden_states = attn(hidden_states, temb=temb)
|
399 |
+
hidden_states = resnet(hidden_states, temb)
|
400 |
+
|
401 |
+
return hidden_states
|
402 |
+
|
403 |
+
|
404 |
+
class UNetMidBlock2DCrossAttn(nn.Module):
|
405 |
+
"""
|
406 |
+
UNetMidBlock2DCrossAttn is a class that represents a mid-block 2D UNet with cross-attention.
|
407 |
+
|
408 |
+
This block is responsible for processing the input tensor with a series of residual blocks,
|
409 |
+
and applying cross-attention mechanism to attend to the global information in the encoder.
|
410 |
+
|
411 |
+
Args:
|
412 |
+
in_channels (int): The number of input channels.
|
413 |
+
temb_channels (int): The number of channels for the token embedding.
|
414 |
+
dropout (float, optional): The dropout rate. Defaults to 0.0.
|
415 |
+
num_layers (int, optional): The number of layers in the residual blocks. Defaults to 1.
|
416 |
+
resnet_eps (float, optional): The epsilon value for the residual blocks. Defaults to 1e-6.
|
417 |
+
resnet_time_scale_shift (str, optional): The time scale shift type for the residual blocks. Defaults to "default".
|
418 |
+
resnet_act_fn (str, optional): The activation function for the residual blocks. Defaults to "swish".
|
419 |
+
resnet_groups (int, optional): The number of groups for the residual blocks. Defaults to 32.
|
420 |
+
resnet_pre_norm (bool, optional): Whether to apply pre-normalization for the residual blocks. Defaults to True.
|
421 |
+
num_attention_heads (int, optional): The number of attention heads for cross-attention. Defaults to 1.
|
422 |
+
cross_attention_dim (int, optional): The dimension of the cross-attention. Defaults to 1280.
|
423 |
+
output_scale_factor (float, optional): The scale factor for the output tensor. Defaults to 1.0.
|
424 |
+
"""
|
425 |
+
def __init__(
|
426 |
+
self,
|
427 |
+
in_channels: int,
|
428 |
+
temb_channels: int,
|
429 |
+
dropout: float = 0.0,
|
430 |
+
num_layers: int = 1,
|
431 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
432 |
+
resnet_eps: float = 1e-6,
|
433 |
+
resnet_time_scale_shift: str = "default",
|
434 |
+
resnet_act_fn: str = "swish",
|
435 |
+
resnet_groups: int = 32,
|
436 |
+
resnet_pre_norm: bool = True,
|
437 |
+
num_attention_heads: int = 1,
|
438 |
+
output_scale_factor: float = 1.0,
|
439 |
+
cross_attention_dim: int = 1280,
|
440 |
+
dual_cross_attention: bool = False,
|
441 |
+
use_linear_projection: bool = False,
|
442 |
+
upcast_attention: bool = False,
|
443 |
+
attention_type: str = "default",
|
444 |
+
):
|
445 |
+
super().__init__()
|
446 |
+
|
447 |
+
self.has_cross_attention = True
|
448 |
+
self.num_attention_heads = num_attention_heads
|
449 |
+
resnet_groups = (
|
450 |
+
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
451 |
+
)
|
452 |
+
|
453 |
+
# support for variable transformer layers per block
|
454 |
+
if isinstance(transformer_layers_per_block, int):
|
455 |
+
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
456 |
+
|
457 |
+
# there is always at least one resnet
|
458 |
+
resnets = [
|
459 |
+
ResnetBlock2D(
|
460 |
+
in_channels=in_channels,
|
461 |
+
out_channels=in_channels,
|
462 |
+
temb_channels=temb_channels,
|
463 |
+
eps=resnet_eps,
|
464 |
+
groups=resnet_groups,
|
465 |
+
dropout=dropout,
|
466 |
+
time_embedding_norm=resnet_time_scale_shift,
|
467 |
+
non_linearity=resnet_act_fn,
|
468 |
+
output_scale_factor=output_scale_factor,
|
469 |
+
pre_norm=resnet_pre_norm,
|
470 |
+
)
|
471 |
+
]
|
472 |
+
attentions = []
|
473 |
+
|
474 |
+
for i in range(num_layers):
|
475 |
+
if not dual_cross_attention:
|
476 |
+
attentions.append(
|
477 |
+
Transformer2DModel(
|
478 |
+
num_attention_heads,
|
479 |
+
in_channels // num_attention_heads,
|
480 |
+
in_channels=in_channels,
|
481 |
+
num_layers=transformer_layers_per_block[i],
|
482 |
+
cross_attention_dim=cross_attention_dim,
|
483 |
+
norm_num_groups=resnet_groups,
|
484 |
+
use_linear_projection=use_linear_projection,
|
485 |
+
upcast_attention=upcast_attention,
|
486 |
+
attention_type=attention_type,
|
487 |
+
)
|
488 |
+
)
|
489 |
+
else:
|
490 |
+
attentions.append(
|
491 |
+
DualTransformer2DModel(
|
492 |
+
num_attention_heads,
|
493 |
+
in_channels // num_attention_heads,
|
494 |
+
in_channels=in_channels,
|
495 |
+
num_layers=1,
|
496 |
+
cross_attention_dim=cross_attention_dim,
|
497 |
+
norm_num_groups=resnet_groups,
|
498 |
+
)
|
499 |
+
)
|
500 |
+
resnets.append(
|
501 |
+
ResnetBlock2D(
|
502 |
+
in_channels=in_channels,
|
503 |
+
out_channels=in_channels,
|
504 |
+
temb_channels=temb_channels,
|
505 |
+
eps=resnet_eps,
|
506 |
+
groups=resnet_groups,
|
507 |
+
dropout=dropout,
|
508 |
+
time_embedding_norm=resnet_time_scale_shift,
|
509 |
+
non_linearity=resnet_act_fn,
|
510 |
+
output_scale_factor=output_scale_factor,
|
511 |
+
pre_norm=resnet_pre_norm,
|
512 |
+
)
|
513 |
+
)
|
514 |
+
|
515 |
+
self.attentions = nn.ModuleList(attentions)
|
516 |
+
self.resnets = nn.ModuleList(resnets)
|
517 |
+
|
518 |
+
self.gradient_checkpointing = False
|
519 |
+
|
520 |
+
def forward(
|
521 |
+
self,
|
522 |
+
hidden_states: torch.FloatTensor,
|
523 |
+
temb: Optional[torch.FloatTensor] = None,
|
524 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
525 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
526 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
527 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
528 |
+
) -> torch.FloatTensor:
|
529 |
+
"""
|
530 |
+
Forward pass for the UNetMidBlock2DCrossAttn class.
|
531 |
+
|
532 |
+
Args:
|
533 |
+
hidden_states (torch.FloatTensor): The input hidden states tensor.
|
534 |
+
temb (Optional[torch.FloatTensor], optional): The optional tensor for time embeddings.
|
535 |
+
encoder_hidden_states (Optional[torch.FloatTensor], optional): The optional encoder hidden states tensor.
|
536 |
+
attention_mask (Optional[torch.FloatTensor], optional): The optional attention mask tensor.
|
537 |
+
cross_attention_kwargs (Optional[Dict[str, Any]], optional): The optional cross-attention kwargs tensor.
|
538 |
+
encoder_attention_mask (Optional[torch.FloatTensor], optional): The optional encoder attention mask tensor.
|
539 |
+
|
540 |
+
Returns:
|
541 |
+
torch.FloatTensor: The output tensor after passing through the UNetMidBlock2DCrossAttn layers.
|
542 |
+
"""
|
543 |
+
lora_scale = (
|
544 |
+
cross_attention_kwargs.get("scale", 1.0)
|
545 |
+
if cross_attention_kwargs is not None
|
546 |
+
else 1.0
|
547 |
+
)
|
548 |
+
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
|
549 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
550 |
+
if self.training and self.gradient_checkpointing:
|
551 |
+
|
552 |
+
def create_custom_forward(module, return_dict=None):
|
553 |
+
def custom_forward(*inputs):
|
554 |
+
if return_dict is not None:
|
555 |
+
return module(*inputs, return_dict=return_dict)
|
556 |
+
|
557 |
+
return module(*inputs)
|
558 |
+
|
559 |
+
return custom_forward
|
560 |
+
|
561 |
+
ckpt_kwargs: Dict[str, Any] = (
|
562 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
563 |
+
)
|
564 |
+
hidden_states, _ref_feature = attn(
|
565 |
+
hidden_states,
|
566 |
+
encoder_hidden_states=encoder_hidden_states,
|
567 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
568 |
+
attention_mask=attention_mask,
|
569 |
+
encoder_attention_mask=encoder_attention_mask,
|
570 |
+
return_dict=False,
|
571 |
+
)
|
572 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
573 |
+
create_custom_forward(resnet),
|
574 |
+
hidden_states,
|
575 |
+
temb,
|
576 |
+
**ckpt_kwargs,
|
577 |
+
)
|
578 |
+
else:
|
579 |
+
hidden_states, _ref_feature = attn(
|
580 |
+
hidden_states,
|
581 |
+
encoder_hidden_states=encoder_hidden_states,
|
582 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
583 |
+
attention_mask=attention_mask,
|
584 |
+
encoder_attention_mask=encoder_attention_mask,
|
585 |
+
return_dict=False,
|
586 |
+
)
|
587 |
+
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
588 |
+
|
589 |
+
return hidden_states
|
590 |
+
|
591 |
+
|
592 |
+
class CrossAttnDownBlock2D(nn.Module):
|
593 |
+
"""
|
594 |
+
CrossAttnDownBlock2D is a class that represents a 2D cross-attention downsampling block.
|
595 |
+
|
596 |
+
This block is used in the UNet model and consists of a series of ResNet blocks and Transformer layers.
|
597 |
+
It takes input hidden states, a tensor embedding, and optional encoder hidden states, attention mask,
|
598 |
+
and cross-attention kwargs. The block performs a series of operations including downsampling, cross-attention,
|
599 |
+
and residual connections.
|
600 |
+
|
601 |
+
Attributes:
|
602 |
+
in_channels (int): The number of input channels.
|
603 |
+
out_channels (int): The number of output channels.
|
604 |
+
temb_channels (int): The number of tensor embedding channels.
|
605 |
+
dropout (float): The dropout rate.
|
606 |
+
num_layers (int): The number of ResNet layers.
|
607 |
+
transformer_layers_per_block (Union[int, Tuple[int]]): The number of Transformer layers per block.
|
608 |
+
resnet_eps (float): The ResNet epsilon value.
|
609 |
+
resnet_time_scale_shift (str): The ResNet time scale shift type.
|
610 |
+
resnet_act_fn (str): The ResNet activation function.
|
611 |
+
resnet_groups (int): The ResNet group size.
|
612 |
+
resnet_pre_norm (bool): Whether to use ResNet pre-normalization.
|
613 |
+
num_attention_heads (int): The number of attention heads.
|
614 |
+
cross_attention_dim (int): The cross-attention dimension.
|
615 |
+
output_scale_factor (float): The output scale factor.
|
616 |
+
downsample_padding (int): The downsampling padding.
|
617 |
+
add_downsample (bool): Whether to add downsampling.
|
618 |
+
dual_cross_attention (bool): Whether to use dual cross-attention.
|
619 |
+
use_linear_projection (bool): Whether to use linear projection.
|
620 |
+
only_cross_attention (bool): Whether to use only cross-attention.
|
621 |
+
upcast_attention (bool): Whether to upcast attention.
|
622 |
+
attention_type (str): The attention type.
|
623 |
+
"""
|
624 |
+
def __init__(
|
625 |
+
self,
|
626 |
+
in_channels: int,
|
627 |
+
out_channels: int,
|
628 |
+
temb_channels: int,
|
629 |
+
dropout: float = 0.0,
|
630 |
+
num_layers: int = 1,
|
631 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
632 |
+
resnet_eps: float = 1e-6,
|
633 |
+
resnet_time_scale_shift: str = "default",
|
634 |
+
resnet_act_fn: str = "swish",
|
635 |
+
resnet_groups: int = 32,
|
636 |
+
resnet_pre_norm: bool = True,
|
637 |
+
num_attention_heads: int = 1,
|
638 |
+
cross_attention_dim: int = 1280,
|
639 |
+
output_scale_factor: float = 1.0,
|
640 |
+
downsample_padding: int = 1,
|
641 |
+
add_downsample: bool = True,
|
642 |
+
dual_cross_attention: bool = False,
|
643 |
+
use_linear_projection: bool = False,
|
644 |
+
only_cross_attention: bool = False,
|
645 |
+
upcast_attention: bool = False,
|
646 |
+
attention_type: str = "default",
|
647 |
+
):
|
648 |
+
super().__init__()
|
649 |
+
resnets = []
|
650 |
+
attentions = []
|
651 |
+
|
652 |
+
self.has_cross_attention = True
|
653 |
+
self.num_attention_heads = num_attention_heads
|
654 |
+
if isinstance(transformer_layers_per_block, int):
|
655 |
+
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
656 |
+
|
657 |
+
for i in range(num_layers):
|
658 |
+
in_channels = in_channels if i == 0 else out_channels
|
659 |
+
resnets.append(
|
660 |
+
ResnetBlock2D(
|
661 |
+
in_channels=in_channels,
|
662 |
+
out_channels=out_channels,
|
663 |
+
temb_channels=temb_channels,
|
664 |
+
eps=resnet_eps,
|
665 |
+
groups=resnet_groups,
|
666 |
+
dropout=dropout,
|
667 |
+
time_embedding_norm=resnet_time_scale_shift,
|
668 |
+
non_linearity=resnet_act_fn,
|
669 |
+
output_scale_factor=output_scale_factor,
|
670 |
+
pre_norm=resnet_pre_norm,
|
671 |
+
)
|
672 |
+
)
|
673 |
+
if not dual_cross_attention:
|
674 |
+
attentions.append(
|
675 |
+
Transformer2DModel(
|
676 |
+
num_attention_heads,
|
677 |
+
out_channels // num_attention_heads,
|
678 |
+
in_channels=out_channels,
|
679 |
+
num_layers=transformer_layers_per_block[i],
|
680 |
+
cross_attention_dim=cross_attention_dim,
|
681 |
+
norm_num_groups=resnet_groups,
|
682 |
+
use_linear_projection=use_linear_projection,
|
683 |
+
only_cross_attention=only_cross_attention,
|
684 |
+
upcast_attention=upcast_attention,
|
685 |
+
attention_type=attention_type,
|
686 |
+
)
|
687 |
+
)
|
688 |
+
else:
|
689 |
+
attentions.append(
|
690 |
+
DualTransformer2DModel(
|
691 |
+
num_attention_heads,
|
692 |
+
out_channels // num_attention_heads,
|
693 |
+
in_channels=out_channels,
|
694 |
+
num_layers=1,
|
695 |
+
cross_attention_dim=cross_attention_dim,
|
696 |
+
norm_num_groups=resnet_groups,
|
697 |
+
)
|
698 |
+
)
|
699 |
+
self.attentions = nn.ModuleList(attentions)
|
700 |
+
self.resnets = nn.ModuleList(resnets)
|
701 |
+
|
702 |
+
if add_downsample:
|
703 |
+
self.downsamplers = nn.ModuleList(
|
704 |
+
[
|
705 |
+
Downsample2D(
|
706 |
+
out_channels,
|
707 |
+
use_conv=True,
|
708 |
+
out_channels=out_channels,
|
709 |
+
padding=downsample_padding,
|
710 |
+
name="op",
|
711 |
+
)
|
712 |
+
]
|
713 |
+
)
|
714 |
+
else:
|
715 |
+
self.downsamplers = None
|
716 |
+
|
717 |
+
self.gradient_checkpointing = False
|
718 |
+
|
719 |
+
def forward(
|
720 |
+
self,
|
721 |
+
hidden_states: torch.FloatTensor,
|
722 |
+
temb: Optional[torch.FloatTensor] = None,
|
723 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
724 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
725 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
726 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
727 |
+
additional_residuals: Optional[torch.FloatTensor] = None,
|
728 |
+
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
729 |
+
"""
|
730 |
+
Forward pass for the CrossAttnDownBlock2D class.
|
731 |
+
|
732 |
+
Args:
|
733 |
+
hidden_states (torch.FloatTensor): The input hidden states.
|
734 |
+
temb (Optional[torch.FloatTensor], optional): The token embeddings. Defaults to None.
|
735 |
+
encoder_hidden_states (Optional[torch.FloatTensor], optional): The encoder hidden states. Defaults to None.
|
736 |
+
attention_mask (Optional[torch.FloatTensor], optional): The attention mask. Defaults to None.
|
737 |
+
cross_attention_kwargs (Optional[Dict[str, Any]], optional): The cross-attention kwargs. Defaults to None.
|
738 |
+
encoder_attention_mask (Optional[torch.FloatTensor], optional): The encoder attention mask. Defaults to None.
|
739 |
+
additional_residuals (Optional[torch.FloatTensor], optional): The additional residuals. Defaults to None.
|
740 |
+
|
741 |
+
Returns:
|
742 |
+
Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: The output hidden states and residuals.
|
743 |
+
"""
|
744 |
+
output_states = ()
|
745 |
+
|
746 |
+
lora_scale = (
|
747 |
+
cross_attention_kwargs.get("scale", 1.0)
|
748 |
+
if cross_attention_kwargs is not None
|
749 |
+
else 1.0
|
750 |
+
)
|
751 |
+
|
752 |
+
blocks = list(zip(self.resnets, self.attentions))
|
753 |
+
|
754 |
+
for i, (resnet, attn) in enumerate(blocks):
|
755 |
+
if self.training and self.gradient_checkpointing:
|
756 |
+
|
757 |
+
def create_custom_forward(module, return_dict=None):
|
758 |
+
def custom_forward(*inputs):
|
759 |
+
if return_dict is not None:
|
760 |
+
return module(*inputs, return_dict=return_dict)
|
761 |
+
|
762 |
+
return module(*inputs)
|
763 |
+
|
764 |
+
return custom_forward
|
765 |
+
|
766 |
+
ckpt_kwargs: Dict[str, Any] = (
|
767 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
768 |
+
)
|
769 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
770 |
+
create_custom_forward(resnet),
|
771 |
+
hidden_states,
|
772 |
+
temb,
|
773 |
+
**ckpt_kwargs,
|
774 |
+
)
|
775 |
+
hidden_states, _ref_feature = attn(
|
776 |
+
hidden_states,
|
777 |
+
encoder_hidden_states=encoder_hidden_states,
|
778 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
779 |
+
attention_mask=attention_mask,
|
780 |
+
encoder_attention_mask=encoder_attention_mask,
|
781 |
+
return_dict=False,
|
782 |
+
)
|
783 |
+
else:
|
784 |
+
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
785 |
+
hidden_states, _ref_feature = attn(
|
786 |
+
hidden_states,
|
787 |
+
encoder_hidden_states=encoder_hidden_states,
|
788 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
789 |
+
attention_mask=attention_mask,
|
790 |
+
encoder_attention_mask=encoder_attention_mask,
|
791 |
+
return_dict=False,
|
792 |
+
)
|
793 |
+
|
794 |
+
# apply additional residuals to the output of the last pair of resnet and attention blocks
|
795 |
+
if i == len(blocks) - 1 and additional_residuals is not None:
|
796 |
+
hidden_states = hidden_states + additional_residuals
|
797 |
+
|
798 |
+
output_states = output_states + (hidden_states,)
|
799 |
+
|
800 |
+
if self.downsamplers is not None:
|
801 |
+
for downsampler in self.downsamplers:
|
802 |
+
hidden_states = downsampler(hidden_states, scale=lora_scale)
|
803 |
+
|
804 |
+
output_states = output_states + (hidden_states,)
|
805 |
+
|
806 |
+
return hidden_states, output_states
|
807 |
+
|
808 |
+
|
809 |
+
class DownBlock2D(nn.Module):
|
810 |
+
"""
|
811 |
+
DownBlock2D is a class that represents a 2D downsampling block in a neural network.
|
812 |
+
|
813 |
+
It takes the following parameters:
|
814 |
+
- in_channels (int): The number of input channels in the block.
|
815 |
+
- out_channels (int): The number of output channels in the block.
|
816 |
+
- temb_channels (int): The number of channels in the token embedding.
|
817 |
+
- dropout (float): The dropout rate for the block.
|
818 |
+
- num_layers (int): The number of layers in the block.
|
819 |
+
- resnet_eps (float): The epsilon value for the ResNet layer.
|
820 |
+
- resnet_time_scale_shift (str): The type of activation function for the ResNet layer.
|
821 |
+
- resnet_act_fn (str): The activation function for the ResNet layer.
|
822 |
+
- resnet_groups (int): The number of groups in the ResNet layer.
|
823 |
+
- resnet_pre_norm (bool): Whether to apply layer normalization before the ResNet layer.
|
824 |
+
- output_scale_factor (float): The scale factor for the output.
|
825 |
+
- add_downsample (bool): Whether to add a downsampling layer.
|
826 |
+
- downsample_padding (int): The padding value for the downsampling layer.
|
827 |
+
|
828 |
+
The DownBlock2D class inherits from the nn.Module class and defines the following methods:
|
829 |
+
- __init__: Initializes the DownBlock2D class with the given parameters.
|
830 |
+
- forward: Forward pass of the DownBlock2D class.
|
831 |
+
|
832 |
+
The forward method takes the following parameters:
|
833 |
+
- hidden_states (torch.FloatTensor): The input tensor to the block.
|
834 |
+
- temb (Optional[torch.FloatTensor]): The token embedding tensor.
|
835 |
+
- scale (float): The scale factor for the input tensor.
|
836 |
+
|
837 |
+
The forward method returns a tuple containing the output tensor and a tuple of hidden states.
|
838 |
+
"""
|
839 |
+
def __init__(
|
840 |
+
self,
|
841 |
+
in_channels: int,
|
842 |
+
out_channels: int,
|
843 |
+
temb_channels: int,
|
844 |
+
dropout: float = 0.0,
|
845 |
+
num_layers: int = 1,
|
846 |
+
resnet_eps: float = 1e-6,
|
847 |
+
resnet_time_scale_shift: str = "default",
|
848 |
+
resnet_act_fn: str = "swish",
|
849 |
+
resnet_groups: int = 32,
|
850 |
+
resnet_pre_norm: bool = True,
|
851 |
+
output_scale_factor: float = 1.0,
|
852 |
+
add_downsample: bool = True,
|
853 |
+
downsample_padding: int = 1,
|
854 |
+
):
|
855 |
+
super().__init__()
|
856 |
+
resnets = []
|
857 |
+
|
858 |
+
for i in range(num_layers):
|
859 |
+
in_channels = in_channels if i == 0 else out_channels
|
860 |
+
resnets.append(
|
861 |
+
ResnetBlock2D(
|
862 |
+
in_channels=in_channels,
|
863 |
+
out_channels=out_channels,
|
864 |
+
temb_channels=temb_channels,
|
865 |
+
eps=resnet_eps,
|
866 |
+
groups=resnet_groups,
|
867 |
+
dropout=dropout,
|
868 |
+
time_embedding_norm=resnet_time_scale_shift,
|
869 |
+
non_linearity=resnet_act_fn,
|
870 |
+
output_scale_factor=output_scale_factor,
|
871 |
+
pre_norm=resnet_pre_norm,
|
872 |
+
)
|
873 |
+
)
|
874 |
+
|
875 |
+
self.resnets = nn.ModuleList(resnets)
|
876 |
+
|
877 |
+
if add_downsample:
|
878 |
+
self.downsamplers = nn.ModuleList(
|
879 |
+
[
|
880 |
+
Downsample2D(
|
881 |
+
out_channels,
|
882 |
+
use_conv=True,
|
883 |
+
out_channels=out_channels,
|
884 |
+
padding=downsample_padding,
|
885 |
+
name="op",
|
886 |
+
)
|
887 |
+
]
|
888 |
+
)
|
889 |
+
else:
|
890 |
+
self.downsamplers = None
|
891 |
+
|
892 |
+
self.gradient_checkpointing = False
|
893 |
+
|
894 |
+
def forward(
|
895 |
+
self,
|
896 |
+
hidden_states: torch.FloatTensor,
|
897 |
+
temb: Optional[torch.FloatTensor] = None,
|
898 |
+
scale: float = 1.0,
|
899 |
+
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
900 |
+
"""
|
901 |
+
Forward pass of the DownBlock2D class.
|
902 |
+
|
903 |
+
Args:
|
904 |
+
hidden_states (torch.FloatTensor): The input tensor to the DownBlock2D layer.
|
905 |
+
temb (Optional[torch.FloatTensor], optional): The token embedding tensor. Defaults to None.
|
906 |
+
scale (float, optional): The scale factor for the input tensor. Defaults to 1.0.
|
907 |
+
|
908 |
+
Returns:
|
909 |
+
Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: The output tensor and any additional hidden states.
|
910 |
+
"""
|
911 |
+
output_states = ()
|
912 |
+
|
913 |
+
for resnet in self.resnets:
|
914 |
+
if self.training and self.gradient_checkpointing:
|
915 |
+
|
916 |
+
def create_custom_forward(module):
|
917 |
+
def custom_forward(*inputs):
|
918 |
+
return module(*inputs)
|
919 |
+
|
920 |
+
return custom_forward
|
921 |
+
|
922 |
+
if is_torch_version(">=", "1.11.0"):
|
923 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
924 |
+
create_custom_forward(resnet),
|
925 |
+
hidden_states,
|
926 |
+
temb,
|
927 |
+
use_reentrant=False,
|
928 |
+
)
|
929 |
+
else:
|
930 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
931 |
+
create_custom_forward(resnet), hidden_states, temb
|
932 |
+
)
|
933 |
+
else:
|
934 |
+
hidden_states = resnet(hidden_states, temb, scale=scale)
|
935 |
+
|
936 |
+
output_states = output_states + (hidden_states,)
|
937 |
+
|
938 |
+
if self.downsamplers is not None:
|
939 |
+
for downsampler in self.downsamplers:
|
940 |
+
hidden_states = downsampler(hidden_states, scale=scale)
|
941 |
+
|
942 |
+
output_states = output_states + (hidden_states,)
|
943 |
+
|
944 |
+
return hidden_states, output_states
|
945 |
+
|
946 |
+
|
947 |
+
class CrossAttnUpBlock2D(nn.Module):
|
948 |
+
"""
|
949 |
+
CrossAttnUpBlock2D is a class that represents a cross-attention UpBlock in a 2D UNet architecture.
|
950 |
+
|
951 |
+
This block is responsible for upsampling the input tensor and performing cross-attention with the encoder's hidden states.
|
952 |
+
|
953 |
+
Args:
|
954 |
+
in_channels (int): The number of input channels in the tensor.
|
955 |
+
out_channels (int): The number of output channels in the tensor.
|
956 |
+
prev_output_channel (int): The number of channels in the previous output tensor.
|
957 |
+
temb_channels (int): The number of channels in the token embedding tensor.
|
958 |
+
resolution_idx (Optional[int]): The index of the resolution in the model.
|
959 |
+
dropout (float): The dropout rate for the layer.
|
960 |
+
num_layers (int): The number of layers in the ResNet block.
|
961 |
+
transformer_layers_per_block (Union[int, Tuple[int]]): The number of transformer layers per block.
|
962 |
+
resnet_eps (float): The epsilon value for the ResNet layer.
|
963 |
+
resnet_time_scale_shift (str): The type of time scale shift to be applied in the ResNet layer.
|
964 |
+
resnet_act_fn (str): The activation function to be used in the ResNet layer.
|
965 |
+
resnet_groups (int): The number of groups in the ResNet layer.
|
966 |
+
resnet_pre_norm (bool): Whether to use pre-normalization in the ResNet layer.
|
967 |
+
num_attention_heads (int): The number of attention heads in the cross-attention layer.
|
968 |
+
cross_attention_dim (int): The dimension of the cross-attention layer.
|
969 |
+
output_scale_factor (float): The scale factor for the output tensor.
|
970 |
+
add_upsample (bool): Whether to add upsampling to the block.
|
971 |
+
dual_cross_attention (bool): Whether to use dual cross-attention.
|
972 |
+
use_linear_projection (bool): Whether to use linear projection in the cross-attention layer.
|
973 |
+
only_cross_attention (bool): Whether to only use cross-attention and no self-attention.
|
974 |
+
upcast_attention (bool): Whether to upcast the attention weights.
|
975 |
+
attention_type (str): The type of attention to be used in the cross-attention layer.
|
976 |
+
|
977 |
+
Attributes:
|
978 |
+
up_block (nn.Module): The UpBlock module responsible for upsampling the input tensor.
|
979 |
+
cross_attn (nn.Module): The cross-attention module that performs attention between
|
980 |
+
the decoder's hidden states and the encoder's hidden states.
|
981 |
+
resnet_blocks (nn.ModuleList): A list of ResNet blocks that make up the ResNet portion of the block.
|
982 |
+
"""
|
983 |
+
|
984 |
+
def __init__(
|
985 |
+
self,
|
986 |
+
in_channels: int,
|
987 |
+
out_channels: int,
|
988 |
+
prev_output_channel: int,
|
989 |
+
temb_channels: int,
|
990 |
+
resolution_idx: Optional[int] = None,
|
991 |
+
dropout: float = 0.0,
|
992 |
+
num_layers: int = 1,
|
993 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
994 |
+
resnet_eps: float = 1e-6,
|
995 |
+
resnet_time_scale_shift: str = "default",
|
996 |
+
resnet_act_fn: str = "swish",
|
997 |
+
resnet_groups: int = 32,
|
998 |
+
resnet_pre_norm: bool = True,
|
999 |
+
num_attention_heads: int = 1,
|
1000 |
+
cross_attention_dim: int = 1280,
|
1001 |
+
output_scale_factor: float = 1.0,
|
1002 |
+
add_upsample: bool = True,
|
1003 |
+
dual_cross_attention: bool = False,
|
1004 |
+
use_linear_projection: bool = False,
|
1005 |
+
only_cross_attention: bool = False,
|
1006 |
+
upcast_attention: bool = False,
|
1007 |
+
attention_type: str = "default",
|
1008 |
+
):
|
1009 |
+
super().__init__()
|
1010 |
+
resnets = []
|
1011 |
+
attentions = []
|
1012 |
+
|
1013 |
+
self.has_cross_attention = True
|
1014 |
+
self.num_attention_heads = num_attention_heads
|
1015 |
+
|
1016 |
+
if isinstance(transformer_layers_per_block, int):
|
1017 |
+
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
1018 |
+
|
1019 |
+
for i in range(num_layers):
|
1020 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1021 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1022 |
+
|
1023 |
+
resnets.append(
|
1024 |
+
ResnetBlock2D(
|
1025 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1026 |
+
out_channels=out_channels,
|
1027 |
+
temb_channels=temb_channels,
|
1028 |
+
eps=resnet_eps,
|
1029 |
+
groups=resnet_groups,
|
1030 |
+
dropout=dropout,
|
1031 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1032 |
+
non_linearity=resnet_act_fn,
|
1033 |
+
output_scale_factor=output_scale_factor,
|
1034 |
+
pre_norm=resnet_pre_norm,
|
1035 |
+
)
|
1036 |
+
)
|
1037 |
+
if not dual_cross_attention:
|
1038 |
+
attentions.append(
|
1039 |
+
Transformer2DModel(
|
1040 |
+
num_attention_heads,
|
1041 |
+
out_channels // num_attention_heads,
|
1042 |
+
in_channels=out_channels,
|
1043 |
+
num_layers=transformer_layers_per_block[i],
|
1044 |
+
cross_attention_dim=cross_attention_dim,
|
1045 |
+
norm_num_groups=resnet_groups,
|
1046 |
+
use_linear_projection=use_linear_projection,
|
1047 |
+
only_cross_attention=only_cross_attention,
|
1048 |
+
upcast_attention=upcast_attention,
|
1049 |
+
attention_type=attention_type,
|
1050 |
+
)
|
1051 |
+
)
|
1052 |
+
else:
|
1053 |
+
attentions.append(
|
1054 |
+
DualTransformer2DModel(
|
1055 |
+
num_attention_heads,
|
1056 |
+
out_channels // num_attention_heads,
|
1057 |
+
in_channels=out_channels,
|
1058 |
+
num_layers=1,
|
1059 |
+
cross_attention_dim=cross_attention_dim,
|
1060 |
+
norm_num_groups=resnet_groups,
|
1061 |
+
)
|
1062 |
+
)
|
1063 |
+
self.attentions = nn.ModuleList(attentions)
|
1064 |
+
self.resnets = nn.ModuleList(resnets)
|
1065 |
+
|
1066 |
+
if add_upsample:
|
1067 |
+
self.upsamplers = nn.ModuleList(
|
1068 |
+
[Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
|
1069 |
+
)
|
1070 |
+
else:
|
1071 |
+
self.upsamplers = None
|
1072 |
+
|
1073 |
+
self.gradient_checkpointing = False
|
1074 |
+
self.resolution_idx = resolution_idx
|
1075 |
+
|
1076 |
+
def forward(
|
1077 |
+
self,
|
1078 |
+
hidden_states: torch.FloatTensor,
|
1079 |
+
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
1080 |
+
temb: Optional[torch.FloatTensor] = None,
|
1081 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1082 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1083 |
+
upsample_size: Optional[int] = None,
|
1084 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1085 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
1086 |
+
) -> torch.FloatTensor:
|
1087 |
+
"""
|
1088 |
+
Forward pass for the CrossAttnUpBlock2D class.
|
1089 |
+
|
1090 |
+
Args:
|
1091 |
+
self (CrossAttnUpBlock2D): An instance of the CrossAttnUpBlock2D class.
|
1092 |
+
hidden_states (torch.FloatTensor): The input hidden states tensor.
|
1093 |
+
res_hidden_states_tuple (Tuple[torch.FloatTensor, ...]): A tuple of residual hidden states tensors.
|
1094 |
+
temb (Optional[torch.FloatTensor], optional): The token embeddings tensor. Defaults to None.
|
1095 |
+
encoder_hidden_states (Optional[torch.FloatTensor], optional): The encoder hidden states tensor. Defaults to None.
|
1096 |
+
cross_attention_kwargs (Optional[Dict[str, Any]], optional): Additional keyword arguments for cross attention. Defaults to None.
|
1097 |
+
upsample_size (Optional[int], optional): The upsample size. Defaults to None.
|
1098 |
+
attention_mask (Optional[torch.FloatTensor], optional): The attention mask tensor. Defaults to None.
|
1099 |
+
encoder_attention_mask (Optional[torch.FloatTensor], optional): The encoder attention mask tensor. Defaults to None.
|
1100 |
+
|
1101 |
+
Returns:
|
1102 |
+
torch.FloatTensor: The output tensor after passing through the block.
|
1103 |
+
"""
|
1104 |
+
lora_scale = (
|
1105 |
+
cross_attention_kwargs.get("scale", 1.0)
|
1106 |
+
if cross_attention_kwargs is not None
|
1107 |
+
else 1.0
|
1108 |
+
)
|
1109 |
+
is_freeu_enabled = (
|
1110 |
+
getattr(self, "s1", None)
|
1111 |
+
and getattr(self, "s2", None)
|
1112 |
+
and getattr(self, "b1", None)
|
1113 |
+
and getattr(self, "b2", None)
|
1114 |
+
)
|
1115 |
+
|
1116 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
1117 |
+
# pop res hidden states
|
1118 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1119 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1120 |
+
|
1121 |
+
# FreeU: Only operate on the first two stages
|
1122 |
+
if is_freeu_enabled:
|
1123 |
+
hidden_states, res_hidden_states = apply_freeu(
|
1124 |
+
self.resolution_idx,
|
1125 |
+
hidden_states,
|
1126 |
+
res_hidden_states,
|
1127 |
+
s1=self.s1,
|
1128 |
+
s2=self.s2,
|
1129 |
+
b1=self.b1,
|
1130 |
+
b2=self.b2,
|
1131 |
+
)
|
1132 |
+
|
1133 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1134 |
+
|
1135 |
+
if self.training and self.gradient_checkpointing:
|
1136 |
+
|
1137 |
+
def create_custom_forward(module, return_dict=None):
|
1138 |
+
def custom_forward(*inputs):
|
1139 |
+
if return_dict is not None:
|
1140 |
+
return module(*inputs, return_dict=return_dict)
|
1141 |
+
|
1142 |
+
return module(*inputs)
|
1143 |
+
|
1144 |
+
return custom_forward
|
1145 |
+
|
1146 |
+
ckpt_kwargs: Dict[str, Any] = (
|
1147 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
1148 |
+
)
|
1149 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1150 |
+
create_custom_forward(resnet),
|
1151 |
+
hidden_states,
|
1152 |
+
temb,
|
1153 |
+
**ckpt_kwargs,
|
1154 |
+
)
|
1155 |
+
hidden_states, _ref_feature = attn(
|
1156 |
+
hidden_states,
|
1157 |
+
encoder_hidden_states=encoder_hidden_states,
|
1158 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1159 |
+
attention_mask=attention_mask,
|
1160 |
+
encoder_attention_mask=encoder_attention_mask,
|
1161 |
+
return_dict=False,
|
1162 |
+
)
|
1163 |
+
else:
|
1164 |
+
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
1165 |
+
hidden_states, _ref_feature = attn(
|
1166 |
+
hidden_states,
|
1167 |
+
encoder_hidden_states=encoder_hidden_states,
|
1168 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1169 |
+
attention_mask=attention_mask,
|
1170 |
+
encoder_attention_mask=encoder_attention_mask,
|
1171 |
+
return_dict=False,
|
1172 |
+
)
|
1173 |
+
|
1174 |
+
if self.upsamplers is not None:
|
1175 |
+
for upsampler in self.upsamplers:
|
1176 |
+
hidden_states = upsampler(
|
1177 |
+
hidden_states, upsample_size, scale=lora_scale
|
1178 |
+
)
|
1179 |
+
|
1180 |
+
return hidden_states
|
1181 |
+
|
1182 |
+
|
1183 |
+
class UpBlock2D(nn.Module):
|
1184 |
+
"""
|
1185 |
+
UpBlock2D is a class that represents a 2D upsampling block in a neural network.
|
1186 |
+
|
1187 |
+
This block is used for upsampling the input tensor by a factor of 2 in both dimensions.
|
1188 |
+
It takes the previous output channel, input channels, and output channels as input
|
1189 |
+
and applies a series of convolutional layers, batch normalization, and activation
|
1190 |
+
functions to produce the upsampled tensor.
|
1191 |
+
|
1192 |
+
Args:
|
1193 |
+
in_channels (int): The number of input channels in the tensor.
|
1194 |
+
prev_output_channel (int): The number of channels in the previous output tensor.
|
1195 |
+
out_channels (int): The number of output channels in the tensor.
|
1196 |
+
temb_channels (int): The number of channels in the time embedding tensor.
|
1197 |
+
resolution_idx (Optional[int], optional): The index of the resolution in the sequence of resolutions. Defaults to None.
|
1198 |
+
dropout (float, optional): The dropout rate to be applied to the convolutional layers. Defaults to 0.0.
|
1199 |
+
num_layers (int, optional): The number of convolutional layers in the block. Defaults to 1.
|
1200 |
+
resnet_eps (float, optional): The epsilon value used in the batch normalization layer. Defaults to 1e-6.
|
1201 |
+
resnet_time_scale_shift (str, optional): The type of activation function to be applied after the convolutional layers. Defaults to "default".
|
1202 |
+
resnet_act_fn (str, optional): The activation function to be applied after the batch normalization layer. Defaults to "swish".
|
1203 |
+
resnet_groups (int, optional): The number of groups in the group normalization layer. Defaults to 32.
|
1204 |
+
resnet_pre_norm (bool, optional): A flag indicating whether to apply layer normalization before the activation function. Defaults to True.
|
1205 |
+
output_scale_factor (float, optional): The scale factor to be applied to the output tensor. Defaults to 1.0.
|
1206 |
+
add_upsample (bool, optional): A flag indicating whether to add an upsampling layer to the block. Defaults to True.
|
1207 |
+
|
1208 |
+
Attributes:
|
1209 |
+
layers (nn.ModuleList): A list of nn.Module objects representing the convolutional layers in the block.
|
1210 |
+
upsample (nn.Module): The upsampling layer in the block, if add_upsample is True.
|
1211 |
+
|
1212 |
+
"""
|
1213 |
+
|
1214 |
+
def __init__(
|
1215 |
+
self,
|
1216 |
+
in_channels: int,
|
1217 |
+
prev_output_channel: int,
|
1218 |
+
out_channels: int,
|
1219 |
+
temb_channels: int,
|
1220 |
+
resolution_idx: Optional[int] = None,
|
1221 |
+
dropout: float = 0.0,
|
1222 |
+
num_layers: int = 1,
|
1223 |
+
resnet_eps: float = 1e-6,
|
1224 |
+
resnet_time_scale_shift: str = "default",
|
1225 |
+
resnet_act_fn: str = "swish",
|
1226 |
+
resnet_groups: int = 32,
|
1227 |
+
resnet_pre_norm: bool = True,
|
1228 |
+
output_scale_factor: float = 1.0,
|
1229 |
+
add_upsample: bool = True,
|
1230 |
+
):
|
1231 |
+
super().__init__()
|
1232 |
+
resnets = []
|
1233 |
+
|
1234 |
+
for i in range(num_layers):
|
1235 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1236 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1237 |
+
|
1238 |
+
resnets.append(
|
1239 |
+
ResnetBlock2D(
|
1240 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1241 |
+
out_channels=out_channels,
|
1242 |
+
temb_channels=temb_channels,
|
1243 |
+
eps=resnet_eps,
|
1244 |
+
groups=resnet_groups,
|
1245 |
+
dropout=dropout,
|
1246 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1247 |
+
non_linearity=resnet_act_fn,
|
1248 |
+
output_scale_factor=output_scale_factor,
|
1249 |
+
pre_norm=resnet_pre_norm,
|
1250 |
+
)
|
1251 |
+
)
|
1252 |
+
|
1253 |
+
self.resnets = nn.ModuleList(resnets)
|
1254 |
+
|
1255 |
+
if add_upsample:
|
1256 |
+
self.upsamplers = nn.ModuleList(
|
1257 |
+
[Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
|
1258 |
+
)
|
1259 |
+
else:
|
1260 |
+
self.upsamplers = None
|
1261 |
+
|
1262 |
+
self.gradient_checkpointing = False
|
1263 |
+
self.resolution_idx = resolution_idx
|
1264 |
+
|
1265 |
+
def forward(
|
1266 |
+
self,
|
1267 |
+
hidden_states: torch.FloatTensor,
|
1268 |
+
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
1269 |
+
temb: Optional[torch.FloatTensor] = None,
|
1270 |
+
upsample_size: Optional[int] = None,
|
1271 |
+
scale: float = 1.0,
|
1272 |
+
) -> torch.FloatTensor:
|
1273 |
+
|
1274 |
+
"""
|
1275 |
+
Forward pass for the UpBlock2D class.
|
1276 |
+
|
1277 |
+
Args:
|
1278 |
+
self (UpBlock2D): An instance of the UpBlock2D class.
|
1279 |
+
hidden_states (torch.FloatTensor): The input tensor to the block.
|
1280 |
+
res_hidden_states_tuple (Tuple[torch.FloatTensor, ...]): A tuple of residual hidden states.
|
1281 |
+
temb (Optional[torch.FloatTensor], optional): The token embeddings. Defaults to None.
|
1282 |
+
upsample_size (Optional[int], optional): The size to upsample the input tensor to. Defaults to None.
|
1283 |
+
scale (float, optional): The scale factor to apply to the input tensor. Defaults to 1.0.
|
1284 |
+
|
1285 |
+
Returns:
|
1286 |
+
torch.FloatTensor: The output tensor after passing through the block.
|
1287 |
+
"""
|
1288 |
+
is_freeu_enabled = (
|
1289 |
+
getattr(self, "s1", None)
|
1290 |
+
and getattr(self, "s2", None)
|
1291 |
+
and getattr(self, "b1", None)
|
1292 |
+
and getattr(self, "b2", None)
|
1293 |
+
)
|
1294 |
+
|
1295 |
+
for resnet in self.resnets:
|
1296 |
+
# pop res hidden states
|
1297 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1298 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1299 |
+
|
1300 |
+
# FreeU: Only operate on the first two stages
|
1301 |
+
if is_freeu_enabled:
|
1302 |
+
hidden_states, res_hidden_states = apply_freeu(
|
1303 |
+
self.resolution_idx,
|
1304 |
+
hidden_states,
|
1305 |
+
res_hidden_states,
|
1306 |
+
s1=self.s1,
|
1307 |
+
s2=self.s2,
|
1308 |
+
b1=self.b1,
|
1309 |
+
b2=self.b2,
|
1310 |
+
)
|
1311 |
+
|
1312 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1313 |
+
|
1314 |
+
if self.training and self.gradient_checkpointing:
|
1315 |
+
|
1316 |
+
def create_custom_forward(module):
|
1317 |
+
def custom_forward(*inputs):
|
1318 |
+
return module(*inputs)
|
1319 |
+
|
1320 |
+
return custom_forward
|
1321 |
+
|
1322 |
+
if is_torch_version(">=", "1.11.0"):
|
1323 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1324 |
+
create_custom_forward(resnet),
|
1325 |
+
hidden_states,
|
1326 |
+
temb,
|
1327 |
+
use_reentrant=False,
|
1328 |
+
)
|
1329 |
+
else:
|
1330 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1331 |
+
create_custom_forward(resnet), hidden_states, temb
|
1332 |
+
)
|
1333 |
+
else:
|
1334 |
+
hidden_states = resnet(hidden_states, temb, scale=scale)
|
1335 |
+
|
1336 |
+
if self.upsamplers is not None:
|
1337 |
+
for upsampler in self.upsamplers:
|
1338 |
+
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
|
1339 |
+
|
1340 |
+
return hidden_states
|
joyhallo/models/unet_2d_condition.py
ADDED
@@ -0,0 +1,1428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This module implements the `UNet2DConditionModel`,
|
3 |
+
a variant of the 2D U-Net architecture designed for conditional image generation tasks.
|
4 |
+
The model is capable of taking a noisy input sample and conditioning it based on additional information such as class labels,
|
5 |
+
time steps, and encoder hidden states to produce a denoised output.
|
6 |
+
|
7 |
+
The `UNet2DConditionModel` leverages various components such as time embeddings,
|
8 |
+
class embeddings, and cross-attention mechanisms to integrate the conditioning information effectively.
|
9 |
+
It is built upon several sub-blocks including down-blocks, a middle block, and up-blocks,
|
10 |
+
each responsible for different stages of the U-Net's downsampling and upsampling process.
|
11 |
+
|
12 |
+
Key Features:
|
13 |
+
- Support for multiple types of down and up blocks, including those with cross-attention capabilities.
|
14 |
+
- Flexible configuration of the model's layers, including the number of layers per block and the output channels for each block.
|
15 |
+
- Integration of time embeddings and class embeddings to condition the model's output on additional information.
|
16 |
+
- Implementation of cross-attention to leverage encoder hidden states for conditional generation.
|
17 |
+
- The model supports gradient checkpointing to reduce memory usage during training.
|
18 |
+
|
19 |
+
The module also includes utility functions and classes such as `UNet2DConditionOutput` for structured output
|
20 |
+
and `load_change_cross_attention_dim` for loading and modifying pre-trained models.
|
21 |
+
|
22 |
+
Example Usage:
|
23 |
+
>>> import torch
|
24 |
+
>>> from unet_2d_condition_model import UNet2DConditionModel
|
25 |
+
>>> model = UNet2DConditionModel(
|
26 |
+
... sample_size=(64, 64),
|
27 |
+
... in_channels=3,
|
28 |
+
... out_channels=3,
|
29 |
+
... encoder_hid_dim=512,
|
30 |
+
... cross_attention_dim=1024,
|
31 |
+
... )
|
32 |
+
>>> # Prepare input tensors
|
33 |
+
>>> sample = torch.randn(1, 3, 64, 64)
|
34 |
+
>>> timestep = 0
|
35 |
+
>>> encoder_hidden_states = torch.randn(1, 14, 512)
|
36 |
+
>>> # Forward pass through the model
|
37 |
+
>>> output = model(sample, timestep, encoder_hidden_states)
|
38 |
+
|
39 |
+
This module is part of a larger ecosystem of diffusion models and can be used for various conditional image generation tasks.
|
40 |
+
"""
|
41 |
+
|
42 |
+
from dataclasses import dataclass
|
43 |
+
from os import PathLike
|
44 |
+
from pathlib import Path
|
45 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
46 |
+
|
47 |
+
import torch
|
48 |
+
import torch.utils.checkpoint
|
49 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
50 |
+
from diffusers.loaders import UNet2DConditionLoadersMixin
|
51 |
+
from diffusers.models.activations import get_activation
|
52 |
+
from diffusers.models.attention_processor import (
|
53 |
+
ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS,
|
54 |
+
AttentionProcessor, AttnAddedKVProcessor, AttnProcessor)
|
55 |
+
from diffusers.models.embeddings import (GaussianFourierProjection,
|
56 |
+
GLIGENTextBoundingboxProjection,
|
57 |
+
ImageHintTimeEmbedding,
|
58 |
+
ImageProjection, ImageTimeEmbedding,
|
59 |
+
TextImageProjection,
|
60 |
+
TextImageTimeEmbedding,
|
61 |
+
TextTimeEmbedding, TimestepEmbedding,
|
62 |
+
Timesteps)
|
63 |
+
from diffusers.models.modeling_utils import ModelMixin
|
64 |
+
from diffusers.utils import (SAFETENSORS_WEIGHTS_NAME, USE_PEFT_BACKEND,
|
65 |
+
WEIGHTS_NAME, BaseOutput, deprecate, logging,
|
66 |
+
scale_lora_layers, unscale_lora_layers)
|
67 |
+
from safetensors.torch import load_file
|
68 |
+
from torch import nn
|
69 |
+
|
70 |
+
from .unet_2d_blocks import (UNetMidBlock2D, UNetMidBlock2DCrossAttn,
|
71 |
+
get_down_block, get_up_block)
|
72 |
+
|
73 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
74 |
+
|
75 |
+
@dataclass
|
76 |
+
class UNet2DConditionOutput(BaseOutput):
|
77 |
+
"""
|
78 |
+
The output of [`UNet2DConditionModel`].
|
79 |
+
|
80 |
+
Args:
|
81 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
82 |
+
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
83 |
+
"""
|
84 |
+
|
85 |
+
sample: torch.FloatTensor = None
|
86 |
+
ref_features: Tuple[torch.FloatTensor] = None
|
87 |
+
|
88 |
+
|
89 |
+
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
90 |
+
r"""
|
91 |
+
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
|
92 |
+
shaped output.
|
93 |
+
|
94 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
95 |
+
for all models (such as downloading or saving).
|
96 |
+
|
97 |
+
Parameters:
|
98 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
99 |
+
Height and width of input/output sample.
|
100 |
+
in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
|
101 |
+
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
|
102 |
+
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
103 |
+
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
104 |
+
Whether to flip the sin to cos in the time embedding.
|
105 |
+
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
106 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to
|
107 |
+
`("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
108 |
+
The tuple of downsample blocks to use.
|
109 |
+
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
110 |
+
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
|
111 |
+
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
|
112 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
113 |
+
The tuple of upsample blocks to use.
|
114 |
+
only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
|
115 |
+
Whether to include self-attention in the basic transformer blocks, see
|
116 |
+
[`~models.attention.BasicTransformerBlock`].
|
117 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
118 |
+
The tuple of output channels for each block.
|
119 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
120 |
+
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
121 |
+
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
122 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
123 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
124 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
125 |
+
If `None`, normalization and activation layers is skipped in post-processing.
|
126 |
+
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
127 |
+
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
128 |
+
The dimension of the cross attention features.
|
129 |
+
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
|
130 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
131 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
132 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
133 |
+
reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
|
134 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
|
135 |
+
blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
|
136 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
137 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
138 |
+
encoder_hid_dim (`int`, *optional*, defaults to None):
|
139 |
+
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
140 |
+
dimension to `cross_attention_dim`.
|
141 |
+
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
142 |
+
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
143 |
+
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
144 |
+
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
145 |
+
num_attention_heads (`int`, *optional*):
|
146 |
+
The number of attention heads. If not defined, defaults to `attention_head_dim`
|
147 |
+
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
148 |
+
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
|
149 |
+
class_embed_type (`str`, *optional*, defaults to `None`):
|
150 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
151 |
+
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
152 |
+
addition_embed_type (`str`, *optional*, defaults to `None`):
|
153 |
+
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
154 |
+
"text". "text" will use the `TextTimeEmbedding` layer.
|
155 |
+
addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
|
156 |
+
Dimension for the timestep embeddings.
|
157 |
+
num_class_embeds (`int`, *optional*, defaults to `None`):
|
158 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
159 |
+
class conditioning with `class_embed_type` equal to `None`.
|
160 |
+
time_embedding_type (`str`, *optional*, defaults to `positional`):
|
161 |
+
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
|
162 |
+
time_embedding_dim (`int`, *optional*, defaults to `None`):
|
163 |
+
An optional override for the dimension of the projected time embedding.
|
164 |
+
time_embedding_act_fn (`str`, *optional*, defaults to `None`):
|
165 |
+
Optional activation function to use only once on the time embeddings before they are passed to the rest of
|
166 |
+
the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
|
167 |
+
timestep_post_act (`str`, *optional*, defaults to `None`):
|
168 |
+
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
|
169 |
+
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
|
170 |
+
The dimension of `cond_proj` layer in the timestep embedding.
|
171 |
+
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,
|
172 |
+
*optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`,
|
173 |
+
*optional*): The dimension of the `class_labels` input when
|
174 |
+
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
|
175 |
+
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
|
176 |
+
embeddings with the class embeddings.
|
177 |
+
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
|
178 |
+
Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
|
179 |
+
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
|
180 |
+
`only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
|
181 |
+
otherwise.
|
182 |
+
"""
|
183 |
+
|
184 |
+
_supports_gradient_checkpointing = True
|
185 |
+
|
186 |
+
@register_to_config
|
187 |
+
def __init__(
|
188 |
+
self,
|
189 |
+
sample_size: Optional[int] = None,
|
190 |
+
in_channels: int = 4,
|
191 |
+
_out_channels: int = 4,
|
192 |
+
_center_input_sample: bool = False,
|
193 |
+
flip_sin_to_cos: bool = True,
|
194 |
+
freq_shift: int = 0,
|
195 |
+
down_block_types: Tuple[str] = (
|
196 |
+
"CrossAttnDownBlock2D",
|
197 |
+
"CrossAttnDownBlock2D",
|
198 |
+
"CrossAttnDownBlock2D",
|
199 |
+
"DownBlock2D",
|
200 |
+
),
|
201 |
+
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
202 |
+
up_block_types: Tuple[str] = (
|
203 |
+
"UpBlock2D",
|
204 |
+
"CrossAttnUpBlock2D",
|
205 |
+
"CrossAttnUpBlock2D",
|
206 |
+
"CrossAttnUpBlock2D",
|
207 |
+
),
|
208 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
209 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
210 |
+
layers_per_block: Union[int, Tuple[int]] = 2,
|
211 |
+
downsample_padding: int = 1,
|
212 |
+
mid_block_scale_factor: float = 1,
|
213 |
+
dropout: float = 0.0,
|
214 |
+
act_fn: str = "silu",
|
215 |
+
norm_num_groups: Optional[int] = 32,
|
216 |
+
norm_eps: float = 1e-5,
|
217 |
+
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
218 |
+
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
|
219 |
+
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
|
220 |
+
encoder_hid_dim: Optional[int] = None,
|
221 |
+
encoder_hid_dim_type: Optional[str] = None,
|
222 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
223 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
224 |
+
dual_cross_attention: bool = False,
|
225 |
+
use_linear_projection: bool = False,
|
226 |
+
class_embed_type: Optional[str] = None,
|
227 |
+
addition_embed_type: Optional[str] = None,
|
228 |
+
addition_time_embed_dim: Optional[int] = None,
|
229 |
+
num_class_embeds: Optional[int] = None,
|
230 |
+
upcast_attention: bool = False,
|
231 |
+
resnet_time_scale_shift: str = "default",
|
232 |
+
time_embedding_type: str = "positional",
|
233 |
+
time_embedding_dim: Optional[int] = None,
|
234 |
+
time_embedding_act_fn: Optional[str] = None,
|
235 |
+
timestep_post_act: Optional[str] = None,
|
236 |
+
time_cond_proj_dim: Optional[int] = None,
|
237 |
+
conv_in_kernel: int = 3,
|
238 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
239 |
+
attention_type: str = "default",
|
240 |
+
class_embeddings_concat: bool = False,
|
241 |
+
mid_block_only_cross_attention: Optional[bool] = None,
|
242 |
+
addition_embed_type_num_heads=64,
|
243 |
+
_landmark_net=False,
|
244 |
+
):
|
245 |
+
super().__init__()
|
246 |
+
|
247 |
+
self.sample_size = sample_size
|
248 |
+
|
249 |
+
if num_attention_heads is not None:
|
250 |
+
raise ValueError(
|
251 |
+
"At the moment it is not possible to define the number of attention heads via `num_attention_heads`"
|
252 |
+
"because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131."
|
253 |
+
"Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
254 |
+
)
|
255 |
+
|
256 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
257 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
258 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
259 |
+
# when this library was created. The incorrect naming was only discovered much later in
|
260 |
+
# https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
261 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
262 |
+
# which is why we correct for the naming here.
|
263 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
264 |
+
|
265 |
+
# Check inputs
|
266 |
+
if len(down_block_types) != len(up_block_types):
|
267 |
+
raise ValueError(
|
268 |
+
"Must provide the same number of `down_block_types` as `up_block_types`."
|
269 |
+
f"`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
270 |
+
)
|
271 |
+
|
272 |
+
if len(block_out_channels) != len(down_block_types):
|
273 |
+
raise ValueError(
|
274 |
+
"Must provide the same number of `block_out_channels` as `down_block_types`."
|
275 |
+
f"`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
276 |
+
)
|
277 |
+
|
278 |
+
if not isinstance(only_cross_attention, bool) and len(
|
279 |
+
only_cross_attention
|
280 |
+
) != len(down_block_types):
|
281 |
+
raise ValueError(
|
282 |
+
"Must provide the same number of `only_cross_attention` as `down_block_types`."
|
283 |
+
f"`only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
284 |
+
)
|
285 |
+
|
286 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
|
287 |
+
down_block_types
|
288 |
+
):
|
289 |
+
raise ValueError(
|
290 |
+
"Must provide the same number of `num_attention_heads` as `down_block_types`."
|
291 |
+
f"`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
292 |
+
)
|
293 |
+
|
294 |
+
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
|
295 |
+
down_block_types
|
296 |
+
):
|
297 |
+
raise ValueError(
|
298 |
+
"Must provide the same number of `attention_head_dim` as `down_block_types`."
|
299 |
+
f"`attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
300 |
+
)
|
301 |
+
|
302 |
+
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(
|
303 |
+
down_block_types
|
304 |
+
):
|
305 |
+
raise ValueError(
|
306 |
+
"Must provide the same number of `cross_attention_dim` as `down_block_types`."
|
307 |
+
f"`cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
308 |
+
)
|
309 |
+
|
310 |
+
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(
|
311 |
+
down_block_types
|
312 |
+
):
|
313 |
+
raise ValueError(
|
314 |
+
"Must provide the same number of `layers_per_block` as `down_block_types`."
|
315 |
+
f"`layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
316 |
+
)
|
317 |
+
if (
|
318 |
+
isinstance(transformer_layers_per_block, list)
|
319 |
+
and reverse_transformer_layers_per_block is None
|
320 |
+
):
|
321 |
+
for layer_number_per_block in transformer_layers_per_block:
|
322 |
+
if isinstance(layer_number_per_block, list):
|
323 |
+
raise ValueError(
|
324 |
+
"Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet."
|
325 |
+
)
|
326 |
+
|
327 |
+
# input
|
328 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
329 |
+
self.conv_in = nn.Conv2d(
|
330 |
+
in_channels,
|
331 |
+
block_out_channels[0],
|
332 |
+
kernel_size=conv_in_kernel,
|
333 |
+
padding=conv_in_padding,
|
334 |
+
)
|
335 |
+
|
336 |
+
# time
|
337 |
+
if time_embedding_type == "fourier":
|
338 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
|
339 |
+
if time_embed_dim % 2 != 0:
|
340 |
+
raise ValueError(
|
341 |
+
f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
|
342 |
+
)
|
343 |
+
self.time_proj = GaussianFourierProjection(
|
344 |
+
time_embed_dim // 2,
|
345 |
+
set_W_to_weight=False,
|
346 |
+
log=False,
|
347 |
+
flip_sin_to_cos=flip_sin_to_cos,
|
348 |
+
)
|
349 |
+
timestep_input_dim = time_embed_dim
|
350 |
+
elif time_embedding_type == "positional":
|
351 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
352 |
+
|
353 |
+
self.time_proj = Timesteps(
|
354 |
+
block_out_channels[0], flip_sin_to_cos, freq_shift
|
355 |
+
)
|
356 |
+
timestep_input_dim = block_out_channels[0]
|
357 |
+
else:
|
358 |
+
raise ValueError(
|
359 |
+
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
|
360 |
+
)
|
361 |
+
|
362 |
+
self.time_embedding = TimestepEmbedding(
|
363 |
+
timestep_input_dim,
|
364 |
+
time_embed_dim,
|
365 |
+
act_fn=act_fn,
|
366 |
+
post_act_fn=timestep_post_act,
|
367 |
+
cond_proj_dim=time_cond_proj_dim,
|
368 |
+
)
|
369 |
+
|
370 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
371 |
+
encoder_hid_dim_type = "text_proj"
|
372 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
373 |
+
logger.info(
|
374 |
+
"encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
|
375 |
+
)
|
376 |
+
|
377 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
378 |
+
raise ValueError(
|
379 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
380 |
+
)
|
381 |
+
|
382 |
+
if encoder_hid_dim_type == "text_proj":
|
383 |
+
self.encoder_hid_proj = nn.Linear(
|
384 |
+
encoder_hid_dim, cross_attention_dim)
|
385 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
386 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
387 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
388 |
+
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
389 |
+
self.encoder_hid_proj = TextImageProjection(
|
390 |
+
text_embed_dim=encoder_hid_dim,
|
391 |
+
image_embed_dim=cross_attention_dim,
|
392 |
+
cross_attention_dim=cross_attention_dim,
|
393 |
+
)
|
394 |
+
elif encoder_hid_dim_type == "image_proj":
|
395 |
+
# Kandinsky 2.2
|
396 |
+
self.encoder_hid_proj = ImageProjection(
|
397 |
+
image_embed_dim=encoder_hid_dim,
|
398 |
+
cross_attention_dim=cross_attention_dim,
|
399 |
+
)
|
400 |
+
elif encoder_hid_dim_type is not None:
|
401 |
+
raise ValueError(
|
402 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
403 |
+
)
|
404 |
+
else:
|
405 |
+
self.encoder_hid_proj = None
|
406 |
+
|
407 |
+
# class embedding
|
408 |
+
if class_embed_type is None and num_class_embeds is not None:
|
409 |
+
self.class_embedding = nn.Embedding(
|
410 |
+
num_class_embeds, time_embed_dim)
|
411 |
+
elif class_embed_type == "timestep":
|
412 |
+
self.class_embedding = TimestepEmbedding(
|
413 |
+
timestep_input_dim, time_embed_dim, act_fn=act_fn
|
414 |
+
)
|
415 |
+
elif class_embed_type == "identity":
|
416 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
417 |
+
elif class_embed_type == "projection":
|
418 |
+
if projection_class_embeddings_input_dim is None:
|
419 |
+
raise ValueError(
|
420 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
421 |
+
)
|
422 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
423 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
424 |
+
# 2. it projects from an arbitrary input dimension.
|
425 |
+
#
|
426 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
427 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
428 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
429 |
+
self.class_embedding = TimestepEmbedding(
|
430 |
+
projection_class_embeddings_input_dim, time_embed_dim
|
431 |
+
)
|
432 |
+
elif class_embed_type == "simple_projection":
|
433 |
+
if projection_class_embeddings_input_dim is None:
|
434 |
+
raise ValueError(
|
435 |
+
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
|
436 |
+
)
|
437 |
+
self.class_embedding = nn.Linear(
|
438 |
+
projection_class_embeddings_input_dim, time_embed_dim
|
439 |
+
)
|
440 |
+
else:
|
441 |
+
self.class_embedding = None
|
442 |
+
|
443 |
+
if addition_embed_type == "text":
|
444 |
+
if encoder_hid_dim is not None:
|
445 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
446 |
+
else:
|
447 |
+
text_time_embedding_from_dim = cross_attention_dim
|
448 |
+
|
449 |
+
self.add_embedding = TextTimeEmbedding(
|
450 |
+
text_time_embedding_from_dim,
|
451 |
+
time_embed_dim,
|
452 |
+
num_heads=addition_embed_type_num_heads,
|
453 |
+
)
|
454 |
+
elif addition_embed_type == "text_image":
|
455 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
456 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
457 |
+
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
458 |
+
self.add_embedding = TextImageTimeEmbedding(
|
459 |
+
text_embed_dim=cross_attention_dim,
|
460 |
+
image_embed_dim=cross_attention_dim,
|
461 |
+
time_embed_dim=time_embed_dim,
|
462 |
+
)
|
463 |
+
elif addition_embed_type == "text_time":
|
464 |
+
self.add_time_proj = Timesteps(
|
465 |
+
addition_time_embed_dim, flip_sin_to_cos, freq_shift
|
466 |
+
)
|
467 |
+
self.add_embedding = TimestepEmbedding(
|
468 |
+
projection_class_embeddings_input_dim, time_embed_dim
|
469 |
+
)
|
470 |
+
elif addition_embed_type == "image":
|
471 |
+
# Kandinsky 2.2
|
472 |
+
self.add_embedding = ImageTimeEmbedding(
|
473 |
+
image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
|
474 |
+
)
|
475 |
+
elif addition_embed_type == "image_hint":
|
476 |
+
# Kandinsky 2.2 ControlNet
|
477 |
+
self.add_embedding = ImageHintTimeEmbedding(
|
478 |
+
image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
|
479 |
+
)
|
480 |
+
elif addition_embed_type is not None:
|
481 |
+
raise ValueError(
|
482 |
+
f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
|
483 |
+
)
|
484 |
+
|
485 |
+
if time_embedding_act_fn is None:
|
486 |
+
self.time_embed_act = None
|
487 |
+
else:
|
488 |
+
self.time_embed_act = get_activation(time_embedding_act_fn)
|
489 |
+
|
490 |
+
self.down_blocks = nn.ModuleList([])
|
491 |
+
self.up_blocks = nn.ModuleList([])
|
492 |
+
|
493 |
+
if isinstance(only_cross_attention, bool):
|
494 |
+
if mid_block_only_cross_attention is None:
|
495 |
+
mid_block_only_cross_attention = only_cross_attention
|
496 |
+
|
497 |
+
only_cross_attention = [
|
498 |
+
only_cross_attention] * len(down_block_types)
|
499 |
+
|
500 |
+
if mid_block_only_cross_attention is None:
|
501 |
+
mid_block_only_cross_attention = False
|
502 |
+
|
503 |
+
if isinstance(num_attention_heads, int):
|
504 |
+
num_attention_heads = (num_attention_heads,) * \
|
505 |
+
len(down_block_types)
|
506 |
+
|
507 |
+
if isinstance(attention_head_dim, int):
|
508 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
509 |
+
|
510 |
+
if isinstance(cross_attention_dim, int):
|
511 |
+
cross_attention_dim = (cross_attention_dim,) * \
|
512 |
+
len(down_block_types)
|
513 |
+
|
514 |
+
if isinstance(layers_per_block, int):
|
515 |
+
layers_per_block = [layers_per_block] * len(down_block_types)
|
516 |
+
|
517 |
+
if isinstance(transformer_layers_per_block, int):
|
518 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(
|
519 |
+
down_block_types
|
520 |
+
)
|
521 |
+
|
522 |
+
if class_embeddings_concat:
|
523 |
+
# The time embeddings are concatenated with the class embeddings. The dimension of the
|
524 |
+
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
|
525 |
+
# regular time embeddings
|
526 |
+
blocks_time_embed_dim = time_embed_dim * 2
|
527 |
+
else:
|
528 |
+
blocks_time_embed_dim = time_embed_dim
|
529 |
+
|
530 |
+
# down
|
531 |
+
output_channel = block_out_channels[0]
|
532 |
+
for i, down_block_type in enumerate(down_block_types):
|
533 |
+
input_channel = output_channel
|
534 |
+
output_channel = block_out_channels[i]
|
535 |
+
is_final_block = i == len(block_out_channels) - 1
|
536 |
+
|
537 |
+
down_block = get_down_block(
|
538 |
+
down_block_type,
|
539 |
+
num_layers=layers_per_block[i],
|
540 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
541 |
+
in_channels=input_channel,
|
542 |
+
out_channels=output_channel,
|
543 |
+
temb_channels=blocks_time_embed_dim,
|
544 |
+
add_downsample=not is_final_block,
|
545 |
+
resnet_eps=norm_eps,
|
546 |
+
resnet_act_fn=act_fn,
|
547 |
+
resnet_groups=norm_num_groups,
|
548 |
+
cross_attention_dim=cross_attention_dim[i],
|
549 |
+
num_attention_heads=num_attention_heads[i],
|
550 |
+
downsample_padding=downsample_padding,
|
551 |
+
dual_cross_attention=dual_cross_attention,
|
552 |
+
use_linear_projection=use_linear_projection,
|
553 |
+
only_cross_attention=only_cross_attention[i],
|
554 |
+
upcast_attention=upcast_attention,
|
555 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
556 |
+
attention_type=attention_type,
|
557 |
+
attention_head_dim=(
|
558 |
+
attention_head_dim[i]
|
559 |
+
if attention_head_dim[i] is not None
|
560 |
+
else output_channel
|
561 |
+
),
|
562 |
+
dropout=dropout,
|
563 |
+
)
|
564 |
+
self.down_blocks.append(down_block)
|
565 |
+
|
566 |
+
# mid
|
567 |
+
if mid_block_type == "UNetMidBlock2DCrossAttn":
|
568 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
569 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
570 |
+
in_channels=block_out_channels[-1],
|
571 |
+
temb_channels=blocks_time_embed_dim,
|
572 |
+
dropout=dropout,
|
573 |
+
resnet_eps=norm_eps,
|
574 |
+
resnet_act_fn=act_fn,
|
575 |
+
output_scale_factor=mid_block_scale_factor,
|
576 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
577 |
+
cross_attention_dim=cross_attention_dim[-1],
|
578 |
+
num_attention_heads=num_attention_heads[-1],
|
579 |
+
resnet_groups=norm_num_groups,
|
580 |
+
dual_cross_attention=dual_cross_attention,
|
581 |
+
use_linear_projection=use_linear_projection,
|
582 |
+
upcast_attention=upcast_attention,
|
583 |
+
attention_type=attention_type,
|
584 |
+
)
|
585 |
+
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
|
586 |
+
raise NotImplementedError(
|
587 |
+
f"Unsupport mid_block_type: {mid_block_type}")
|
588 |
+
elif mid_block_type == "UNetMidBlock2D":
|
589 |
+
self.mid_block = UNetMidBlock2D(
|
590 |
+
in_channels=block_out_channels[-1],
|
591 |
+
temb_channels=blocks_time_embed_dim,
|
592 |
+
dropout=dropout,
|
593 |
+
num_layers=0,
|
594 |
+
resnet_eps=norm_eps,
|
595 |
+
resnet_act_fn=act_fn,
|
596 |
+
output_scale_factor=mid_block_scale_factor,
|
597 |
+
resnet_groups=norm_num_groups,
|
598 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
599 |
+
add_attention=False,
|
600 |
+
)
|
601 |
+
elif mid_block_type is None:
|
602 |
+
self.mid_block = None
|
603 |
+
else:
|
604 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
605 |
+
|
606 |
+
# count how many layers upsample the images
|
607 |
+
self.num_upsamplers = 0
|
608 |
+
|
609 |
+
# up
|
610 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
611 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
612 |
+
reversed_layers_per_block = list(reversed(layers_per_block))
|
613 |
+
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
614 |
+
reversed_transformer_layers_per_block = (
|
615 |
+
list(reversed(transformer_layers_per_block))
|
616 |
+
if reverse_transformer_layers_per_block is None
|
617 |
+
else reverse_transformer_layers_per_block
|
618 |
+
)
|
619 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
620 |
+
|
621 |
+
output_channel = reversed_block_out_channels[0]
|
622 |
+
for i, up_block_type in enumerate(up_block_types):
|
623 |
+
is_final_block = i == len(block_out_channels) - 1
|
624 |
+
|
625 |
+
prev_output_channel = output_channel
|
626 |
+
output_channel = reversed_block_out_channels[i]
|
627 |
+
input_channel = reversed_block_out_channels[
|
628 |
+
min(i + 1, len(block_out_channels) - 1)
|
629 |
+
]
|
630 |
+
|
631 |
+
# add upsample block for all BUT final layer
|
632 |
+
if not is_final_block:
|
633 |
+
add_upsample = True
|
634 |
+
self.num_upsamplers += 1
|
635 |
+
else:
|
636 |
+
add_upsample = False
|
637 |
+
|
638 |
+
up_block = get_up_block(
|
639 |
+
up_block_type,
|
640 |
+
num_layers=reversed_layers_per_block[i] + 1,
|
641 |
+
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
642 |
+
in_channels=input_channel,
|
643 |
+
out_channels=output_channel,
|
644 |
+
prev_output_channel=prev_output_channel,
|
645 |
+
temb_channels=blocks_time_embed_dim,
|
646 |
+
add_upsample=add_upsample,
|
647 |
+
resnet_eps=norm_eps,
|
648 |
+
resnet_act_fn=act_fn,
|
649 |
+
resolution_idx=i,
|
650 |
+
resnet_groups=norm_num_groups,
|
651 |
+
cross_attention_dim=reversed_cross_attention_dim[i],
|
652 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
653 |
+
dual_cross_attention=dual_cross_attention,
|
654 |
+
use_linear_projection=use_linear_projection,
|
655 |
+
only_cross_attention=only_cross_attention[i],
|
656 |
+
upcast_attention=upcast_attention,
|
657 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
658 |
+
attention_type=attention_type,
|
659 |
+
attention_head_dim=(
|
660 |
+
attention_head_dim[i]
|
661 |
+
if attention_head_dim[i] is not None
|
662 |
+
else output_channel
|
663 |
+
),
|
664 |
+
dropout=dropout,
|
665 |
+
)
|
666 |
+
self.up_blocks.append(up_block)
|
667 |
+
prev_output_channel = output_channel
|
668 |
+
|
669 |
+
# out
|
670 |
+
if norm_num_groups is not None:
|
671 |
+
self.conv_norm_out = nn.GroupNorm(
|
672 |
+
num_channels=block_out_channels[0],
|
673 |
+
num_groups=norm_num_groups,
|
674 |
+
eps=norm_eps,
|
675 |
+
)
|
676 |
+
|
677 |
+
self.conv_act = get_activation(act_fn)
|
678 |
+
|
679 |
+
else:
|
680 |
+
self.conv_norm_out = None
|
681 |
+
self.conv_act = None
|
682 |
+
self.conv_norm_out = None
|
683 |
+
|
684 |
+
if attention_type in ["gated", "gated-text-image"]:
|
685 |
+
positive_len = 768
|
686 |
+
if isinstance(cross_attention_dim, int):
|
687 |
+
positive_len = cross_attention_dim
|
688 |
+
elif isinstance(cross_attention_dim, (tuple, list)):
|
689 |
+
positive_len = cross_attention_dim[0]
|
690 |
+
|
691 |
+
feature_type = "text-only" if attention_type == "gated" else "text-image"
|
692 |
+
self.position_net = GLIGENTextBoundingboxProjection(
|
693 |
+
positive_len=positive_len,
|
694 |
+
out_dim=cross_attention_dim,
|
695 |
+
feature_type=feature_type,
|
696 |
+
)
|
697 |
+
|
698 |
+
@property
|
699 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
700 |
+
r"""
|
701 |
+
Returns:
|
702 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
703 |
+
indexed by its weight name.
|
704 |
+
"""
|
705 |
+
# set recursively
|
706 |
+
processors = {}
|
707 |
+
|
708 |
+
def fn_recursive_add_processors(
|
709 |
+
name: str,
|
710 |
+
module: torch.nn.Module,
|
711 |
+
processors: Dict[str, AttentionProcessor],
|
712 |
+
):
|
713 |
+
if hasattr(module, "get_processor"):
|
714 |
+
processors[f"{name}.processor"] = module.get_processor(
|
715 |
+
return_deprecated_lora=True
|
716 |
+
)
|
717 |
+
|
718 |
+
for sub_name, child in module.named_children():
|
719 |
+
fn_recursive_add_processors(
|
720 |
+
f"{name}.{sub_name}", child, processors)
|
721 |
+
|
722 |
+
return processors
|
723 |
+
|
724 |
+
for name, module in self.named_children():
|
725 |
+
fn_recursive_add_processors(name, module, processors)
|
726 |
+
|
727 |
+
return processors
|
728 |
+
|
729 |
+
def set_attn_processor(
|
730 |
+
self,
|
731 |
+
processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
|
732 |
+
_remove_lora=False,
|
733 |
+
):
|
734 |
+
r"""
|
735 |
+
Sets the attention processor to use to compute attention.
|
736 |
+
|
737 |
+
Parameters:
|
738 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
739 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
740 |
+
for **all** `Attention` layers.
|
741 |
+
|
742 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
743 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
744 |
+
|
745 |
+
"""
|
746 |
+
count = len(self.attn_processors.keys())
|
747 |
+
|
748 |
+
if isinstance(processor, dict) and len(processor) != count:
|
749 |
+
raise ValueError(
|
750 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
751 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
752 |
+
)
|
753 |
+
|
754 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
755 |
+
if hasattr(module, "set_processor"):
|
756 |
+
if not isinstance(processor, dict):
|
757 |
+
module.set_processor(processor, _remove_lora=_remove_lora)
|
758 |
+
else:
|
759 |
+
module.set_processor(
|
760 |
+
processor.pop(f"{name}.processor"), _remove_lora=_remove_lora
|
761 |
+
)
|
762 |
+
|
763 |
+
for sub_name, child in module.named_children():
|
764 |
+
fn_recursive_attn_processor(
|
765 |
+
f"{name}.{sub_name}", child, processor)
|
766 |
+
|
767 |
+
for name, module in self.named_children():
|
768 |
+
fn_recursive_attn_processor(name, module, processor)
|
769 |
+
|
770 |
+
def set_default_attn_processor(self):
|
771 |
+
"""
|
772 |
+
Disables custom attention processors and sets the default attention implementation.
|
773 |
+
"""
|
774 |
+
if all(
|
775 |
+
proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
|
776 |
+
for proc in self.attn_processors.values()
|
777 |
+
):
|
778 |
+
processor = AttnAddedKVProcessor()
|
779 |
+
elif all(
|
780 |
+
proc.__class__ in CROSS_ATTENTION_PROCESSORS
|
781 |
+
for proc in self.attn_processors.values()
|
782 |
+
):
|
783 |
+
processor = AttnProcessor()
|
784 |
+
else:
|
785 |
+
raise ValueError(
|
786 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
787 |
+
)
|
788 |
+
|
789 |
+
self.set_attn_processor(processor, _remove_lora=True)
|
790 |
+
|
791 |
+
def set_attention_slice(self, slice_size):
|
792 |
+
r"""
|
793 |
+
Enable sliced attention computation.
|
794 |
+
|
795 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
796 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
797 |
+
|
798 |
+
Args:
|
799 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
800 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
801 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
802 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
803 |
+
must be a multiple of `slice_size`.
|
804 |
+
"""
|
805 |
+
sliceable_head_dims = []
|
806 |
+
|
807 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
808 |
+
if hasattr(module, "set_attention_slice"):
|
809 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
810 |
+
|
811 |
+
for child in module.children():
|
812 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
813 |
+
|
814 |
+
# retrieve number of attention layers
|
815 |
+
for module in self.children():
|
816 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
817 |
+
|
818 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
819 |
+
|
820 |
+
if slice_size == "auto":
|
821 |
+
# half the attention head size is usually a good trade-off between
|
822 |
+
# speed and memory
|
823 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
824 |
+
elif slice_size == "max":
|
825 |
+
# make smallest slice possible
|
826 |
+
slice_size = num_sliceable_layers * [1]
|
827 |
+
|
828 |
+
slice_size = (
|
829 |
+
num_sliceable_layers * [slice_size]
|
830 |
+
if not isinstance(slice_size, list)
|
831 |
+
else slice_size
|
832 |
+
)
|
833 |
+
|
834 |
+
if len(slice_size) != len(sliceable_head_dims):
|
835 |
+
raise ValueError(
|
836 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
837 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
838 |
+
)
|
839 |
+
|
840 |
+
for i, size in enumerate(slice_size):
|
841 |
+
dim = sliceable_head_dims[i]
|
842 |
+
if size is not None and size > dim:
|
843 |
+
raise ValueError(
|
844 |
+
f"size {size} has to be smaller or equal to {dim}.")
|
845 |
+
|
846 |
+
# Recursively walk through all the children.
|
847 |
+
# Any children which exposes the set_attention_slice method
|
848 |
+
# gets the message
|
849 |
+
def fn_recursive_set_attention_slice(
|
850 |
+
module: torch.nn.Module, slice_size: List[int]
|
851 |
+
):
|
852 |
+
if hasattr(module, "set_attention_slice"):
|
853 |
+
module.set_attention_slice(slice_size.pop())
|
854 |
+
|
855 |
+
for child in module.children():
|
856 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
857 |
+
|
858 |
+
reversed_slice_size = list(reversed(slice_size))
|
859 |
+
for module in self.children():
|
860 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
861 |
+
|
862 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
863 |
+
if hasattr(module, "gradient_checkpointing"):
|
864 |
+
module.gradient_checkpointing = value
|
865 |
+
|
866 |
+
def enable_freeu(self, s1, s2, b1, b2):
|
867 |
+
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
868 |
+
|
869 |
+
The suffixes after the scaling factors represent the stage blocks where they are being applied.
|
870 |
+
|
871 |
+
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
|
872 |
+
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
873 |
+
|
874 |
+
Args:
|
875 |
+
s1 (`float`):
|
876 |
+
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
877 |
+
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
878 |
+
s2 (`float`):
|
879 |
+
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
880 |
+
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
881 |
+
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
882 |
+
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
883 |
+
"""
|
884 |
+
for _, upsample_block in enumerate(self.up_blocks):
|
885 |
+
setattr(upsample_block, "s1", s1)
|
886 |
+
setattr(upsample_block, "s2", s2)
|
887 |
+
setattr(upsample_block, "b1", b1)
|
888 |
+
setattr(upsample_block, "b2", b2)
|
889 |
+
|
890 |
+
def disable_freeu(self):
|
891 |
+
"""Disables the FreeU mechanism."""
|
892 |
+
freeu_keys = {"s1", "s2", "b1", "b2"}
|
893 |
+
for _, upsample_block in enumerate(self.up_blocks):
|
894 |
+
for k in freeu_keys:
|
895 |
+
if (
|
896 |
+
hasattr(upsample_block, k)
|
897 |
+
or getattr(upsample_block, k, None) is not None
|
898 |
+
):
|
899 |
+
setattr(upsample_block, k, None)
|
900 |
+
|
901 |
+
def forward(
|
902 |
+
self,
|
903 |
+
sample: torch.FloatTensor,
|
904 |
+
timestep: Union[torch.Tensor, float, int],
|
905 |
+
encoder_hidden_states: torch.Tensor,
|
906 |
+
cond_tensor: torch.FloatTensor=None,
|
907 |
+
class_labels: Optional[torch.Tensor] = None,
|
908 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
909 |
+
attention_mask: Optional[torch.Tensor] = None,
|
910 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
911 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
912 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
913 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
914 |
+
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
915 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
916 |
+
return_dict: bool = True,
|
917 |
+
post_process: bool = False,
|
918 |
+
) -> Union[UNet2DConditionOutput, Tuple]:
|
919 |
+
r"""
|
920 |
+
The [`UNet2DConditionModel`] forward method.
|
921 |
+
|
922 |
+
Args:
|
923 |
+
sample (`torch.FloatTensor`):
|
924 |
+
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
925 |
+
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
926 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
927 |
+
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
|
928 |
+
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
929 |
+
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
930 |
+
timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
|
931 |
+
Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
|
932 |
+
through the `self.time_embedding` layer to obtain the timestep embeddings.
|
933 |
+
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
934 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
935 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
936 |
+
negative values to the attention scores corresponding to "discard" tokens.
|
937 |
+
cross_attention_kwargs (`dict`, *optional*):
|
938 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
939 |
+
`self.processor` in
|
940 |
+
[diffusers.models.attention_processor]
|
941 |
+
(https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
942 |
+
added_cond_kwargs: (`dict`, *optional*):
|
943 |
+
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
|
944 |
+
are passed along to the UNet blocks.
|
945 |
+
down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
|
946 |
+
A tuple of tensors that if specified are added to the residuals of down unet blocks.
|
947 |
+
mid_block_additional_residual: (`torch.Tensor`, *optional*):
|
948 |
+
A tensor that if specified is added to the residual of the middle unet block.
|
949 |
+
encoder_attention_mask (`torch.Tensor`):
|
950 |
+
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
951 |
+
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
952 |
+
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
953 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
954 |
+
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
955 |
+
tuple.
|
956 |
+
cross_attention_kwargs (`dict`, *optional*):
|
957 |
+
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
958 |
+
added_cond_kwargs: (`dict`, *optional*):
|
959 |
+
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
|
960 |
+
are passed along to the UNet blocks.
|
961 |
+
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
962 |
+
additional residuals to be added to UNet long skip connections from down blocks to up blocks for
|
963 |
+
example from ControlNet side model(s)
|
964 |
+
mid_block_additional_residual (`torch.Tensor`, *optional*):
|
965 |
+
additional residual to be added to UNet mid block output, for example from ControlNet side model
|
966 |
+
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
967 |
+
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
|
968 |
+
|
969 |
+
Returns:
|
970 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
971 |
+
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
972 |
+
a `tuple` is returned where the first element is the sample tensor.
|
973 |
+
"""
|
974 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
975 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
976 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
977 |
+
# on the fly if necessary.
|
978 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
979 |
+
|
980 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
981 |
+
forward_upsample_size = False
|
982 |
+
upsample_size = None
|
983 |
+
|
984 |
+
for dim in sample.shape[-2:]:
|
985 |
+
if dim % default_overall_up_factor != 0:
|
986 |
+
# Forward upsample size to force interpolation output size.
|
987 |
+
forward_upsample_size = True
|
988 |
+
break
|
989 |
+
|
990 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
991 |
+
# expects mask of shape:
|
992 |
+
# [batch, key_tokens]
|
993 |
+
# adds singleton query_tokens dimension:
|
994 |
+
# [batch, 1, key_tokens]
|
995 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
996 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
997 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
998 |
+
if attention_mask is not None:
|
999 |
+
# assume that mask is expressed as:
|
1000 |
+
# (1 = keep, 0 = discard)
|
1001 |
+
# convert mask into a bias that can be added to attention scores:
|
1002 |
+
# (keep = +0, discard = -10000.0)
|
1003 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
1004 |
+
attention_mask = attention_mask.unsqueeze(1)
|
1005 |
+
|
1006 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
1007 |
+
if encoder_attention_mask is not None:
|
1008 |
+
encoder_attention_mask = (
|
1009 |
+
1 - encoder_attention_mask.to(sample.dtype)
|
1010 |
+
) * -10000.0
|
1011 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
1012 |
+
|
1013 |
+
# 0. center input if necessary
|
1014 |
+
if self.config.center_input_sample:
|
1015 |
+
sample = 2 * sample - 1.0
|
1016 |
+
|
1017 |
+
# 1. time
|
1018 |
+
timesteps = timestep
|
1019 |
+
if not torch.is_tensor(timesteps):
|
1020 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
1021 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
1022 |
+
is_mps = sample.device.type == "mps"
|
1023 |
+
if isinstance(timestep, float):
|
1024 |
+
dtype = torch.float32 if is_mps else torch.float64
|
1025 |
+
else:
|
1026 |
+
dtype = torch.int32 if is_mps else torch.int64
|
1027 |
+
timesteps = torch.tensor(
|
1028 |
+
[timesteps], dtype=dtype, device=sample.device)
|
1029 |
+
elif len(timesteps.shape) == 0:
|
1030 |
+
timesteps = timesteps[None].to(sample.device)
|
1031 |
+
|
1032 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
1033 |
+
timesteps = timesteps.expand(sample.shape[0])
|
1034 |
+
|
1035 |
+
t_emb = self.time_proj(timesteps)
|
1036 |
+
|
1037 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
1038 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
1039 |
+
# there might be better ways to encapsulate this.
|
1040 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
1041 |
+
|
1042 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
1043 |
+
aug_emb = None
|
1044 |
+
|
1045 |
+
if self.class_embedding is not None:
|
1046 |
+
if class_labels is None:
|
1047 |
+
raise ValueError(
|
1048 |
+
"class_labels should be provided when num_class_embeds > 0"
|
1049 |
+
)
|
1050 |
+
|
1051 |
+
if self.config.class_embed_type == "timestep":
|
1052 |
+
class_labels = self.time_proj(class_labels)
|
1053 |
+
|
1054 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
1055 |
+
# there might be better ways to encapsulate this.
|
1056 |
+
class_labels = class_labels.to(dtype=sample.dtype)
|
1057 |
+
|
1058 |
+
class_emb = self.class_embedding(
|
1059 |
+
class_labels).to(dtype=sample.dtype)
|
1060 |
+
|
1061 |
+
if self.config.class_embeddings_concat:
|
1062 |
+
emb = torch.cat([emb, class_emb], dim=-1)
|
1063 |
+
else:
|
1064 |
+
emb = emb + class_emb
|
1065 |
+
|
1066 |
+
if self.config.addition_embed_type == "text":
|
1067 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
1068 |
+
elif self.config.addition_embed_type == "text_image":
|
1069 |
+
# Kandinsky 2.1 - style
|
1070 |
+
if "image_embeds" not in added_cond_kwargs:
|
1071 |
+
raise ValueError(
|
1072 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image'"
|
1073 |
+
"which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
1074 |
+
)
|
1075 |
+
|
1076 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
1077 |
+
text_embs = added_cond_kwargs.get(
|
1078 |
+
"text_embeds", encoder_hidden_states)
|
1079 |
+
aug_emb = self.add_embedding(text_embs, image_embs)
|
1080 |
+
elif self.config.addition_embed_type == "text_time":
|
1081 |
+
# SDXL - style
|
1082 |
+
if "text_embeds" not in added_cond_kwargs:
|
1083 |
+
raise ValueError(
|
1084 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time'"
|
1085 |
+
"which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
1086 |
+
)
|
1087 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
1088 |
+
if "time_ids" not in added_cond_kwargs:
|
1089 |
+
raise ValueError(
|
1090 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time'"
|
1091 |
+
"which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
1092 |
+
)
|
1093 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
1094 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
1095 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
1096 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
1097 |
+
add_embeds = add_embeds.to(emb.dtype)
|
1098 |
+
aug_emb = self.add_embedding(add_embeds)
|
1099 |
+
elif self.config.addition_embed_type == "image":
|
1100 |
+
# Kandinsky 2.2 - style
|
1101 |
+
if "image_embeds" not in added_cond_kwargs:
|
1102 |
+
raise ValueError(
|
1103 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image'"
|
1104 |
+
"which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
1105 |
+
)
|
1106 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
1107 |
+
aug_emb = self.add_embedding(image_embs)
|
1108 |
+
elif self.config.addition_embed_type == "image_hint":
|
1109 |
+
# Kandinsky 2.2 - style
|
1110 |
+
if (
|
1111 |
+
"image_embeds" not in added_cond_kwargs
|
1112 |
+
or "hint" not in added_cond_kwargs
|
1113 |
+
):
|
1114 |
+
raise ValueError(
|
1115 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint'"
|
1116 |
+
"which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
|
1117 |
+
)
|
1118 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
1119 |
+
hint = added_cond_kwargs.get("hint")
|
1120 |
+
aug_emb, hint = self.add_embedding(image_embs, hint)
|
1121 |
+
sample = torch.cat([sample, hint], dim=1)
|
1122 |
+
|
1123 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
1124 |
+
|
1125 |
+
if self.time_embed_act is not None:
|
1126 |
+
emb = self.time_embed_act(emb)
|
1127 |
+
|
1128 |
+
if (
|
1129 |
+
self.encoder_hid_proj is not None
|
1130 |
+
and self.config.encoder_hid_dim_type == "text_proj"
|
1131 |
+
):
|
1132 |
+
encoder_hidden_states = self.encoder_hid_proj(
|
1133 |
+
encoder_hidden_states)
|
1134 |
+
elif (
|
1135 |
+
self.encoder_hid_proj is not None
|
1136 |
+
and self.config.encoder_hid_dim_type == "text_image_proj"
|
1137 |
+
):
|
1138 |
+
# Kadinsky 2.1 - style
|
1139 |
+
if "image_embeds" not in added_cond_kwargs:
|
1140 |
+
raise ValueError(
|
1141 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj'"
|
1142 |
+
"which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
1143 |
+
)
|
1144 |
+
|
1145 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
1146 |
+
encoder_hidden_states = self.encoder_hid_proj(
|
1147 |
+
encoder_hidden_states, image_embeds
|
1148 |
+
)
|
1149 |
+
elif (
|
1150 |
+
self.encoder_hid_proj is not None
|
1151 |
+
and self.config.encoder_hid_dim_type == "image_proj"
|
1152 |
+
):
|
1153 |
+
# Kandinsky 2.2 - style
|
1154 |
+
if "image_embeds" not in added_cond_kwargs:
|
1155 |
+
raise ValueError(
|
1156 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj'"
|
1157 |
+
"which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
1158 |
+
)
|
1159 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
1160 |
+
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
|
1161 |
+
elif (
|
1162 |
+
self.encoder_hid_proj is not None
|
1163 |
+
and self.config.encoder_hid_dim_type == "ip_image_proj"
|
1164 |
+
):
|
1165 |
+
if "image_embeds" not in added_cond_kwargs:
|
1166 |
+
raise ValueError(
|
1167 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj'"
|
1168 |
+
"which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
1169 |
+
)
|
1170 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
1171 |
+
image_embeds = self.encoder_hid_proj(image_embeds).to(
|
1172 |
+
encoder_hidden_states.dtype
|
1173 |
+
)
|
1174 |
+
encoder_hidden_states = torch.cat(
|
1175 |
+
[encoder_hidden_states, image_embeds], dim=1
|
1176 |
+
)
|
1177 |
+
|
1178 |
+
# 2. pre-process
|
1179 |
+
sample = self.conv_in(sample)
|
1180 |
+
if cond_tensor is not None:
|
1181 |
+
sample = sample + cond_tensor
|
1182 |
+
|
1183 |
+
# 2.5 GLIGEN position net
|
1184 |
+
if (
|
1185 |
+
cross_attention_kwargs is not None
|
1186 |
+
and cross_attention_kwargs.get("gligen", None) is not None
|
1187 |
+
):
|
1188 |
+
cross_attention_kwargs = cross_attention_kwargs.copy()
|
1189 |
+
gligen_args = cross_attention_kwargs.pop("gligen")
|
1190 |
+
cross_attention_kwargs["gligen"] = {
|
1191 |
+
"objs": self.position_net(**gligen_args)
|
1192 |
+
}
|
1193 |
+
|
1194 |
+
# 3. down
|
1195 |
+
lora_scale = (
|
1196 |
+
cross_attention_kwargs.get("scale", 1.0)
|
1197 |
+
if cross_attention_kwargs is not None
|
1198 |
+
else 1.0
|
1199 |
+
)
|
1200 |
+
if USE_PEFT_BACKEND:
|
1201 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
1202 |
+
scale_lora_layers(self, lora_scale)
|
1203 |
+
|
1204 |
+
is_controlnet = (
|
1205 |
+
mid_block_additional_residual is not None
|
1206 |
+
and down_block_additional_residuals is not None
|
1207 |
+
)
|
1208 |
+
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
|
1209 |
+
is_adapter = down_intrablock_additional_residuals is not None
|
1210 |
+
# maintain backward compatibility for legacy usage, where
|
1211 |
+
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
|
1212 |
+
# but can only use one or the other
|
1213 |
+
if (
|
1214 |
+
not is_adapter
|
1215 |
+
and mid_block_additional_residual is None
|
1216 |
+
and down_block_additional_residuals is not None
|
1217 |
+
):
|
1218 |
+
deprecate(
|
1219 |
+
"T2I should not use down_block_additional_residuals",
|
1220 |
+
"1.3.0",
|
1221 |
+
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
|
1222 |
+
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
|
1223 |
+
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
|
1224 |
+
standard_warn=False,
|
1225 |
+
)
|
1226 |
+
down_intrablock_additional_residuals = down_block_additional_residuals
|
1227 |
+
is_adapter = True
|
1228 |
+
|
1229 |
+
down_block_res_samples = (sample,)
|
1230 |
+
for downsample_block in self.down_blocks:
|
1231 |
+
if (
|
1232 |
+
hasattr(downsample_block, "has_cross_attention")
|
1233 |
+
and downsample_block.has_cross_attention
|
1234 |
+
):
|
1235 |
+
# For t2i-adapter CrossAttnDownBlock2D
|
1236 |
+
additional_residuals = {}
|
1237 |
+
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
1238 |
+
additional_residuals["additional_residuals"] = (
|
1239 |
+
down_intrablock_additional_residuals.pop(0)
|
1240 |
+
)
|
1241 |
+
|
1242 |
+
sample, res_samples = downsample_block(
|
1243 |
+
hidden_states=sample,
|
1244 |
+
temb=emb,
|
1245 |
+
encoder_hidden_states=encoder_hidden_states,
|
1246 |
+
attention_mask=attention_mask,
|
1247 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1248 |
+
encoder_attention_mask=encoder_attention_mask,
|
1249 |
+
**additional_residuals,
|
1250 |
+
)
|
1251 |
+
else:
|
1252 |
+
sample, res_samples = downsample_block(
|
1253 |
+
hidden_states=sample, temb=emb, scale=lora_scale
|
1254 |
+
)
|
1255 |
+
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
1256 |
+
sample += down_intrablock_additional_residuals.pop(0)
|
1257 |
+
|
1258 |
+
down_block_res_samples += res_samples
|
1259 |
+
|
1260 |
+
if is_controlnet:
|
1261 |
+
new_down_block_res_samples = ()
|
1262 |
+
|
1263 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
1264 |
+
down_block_res_samples, down_block_additional_residuals
|
1265 |
+
):
|
1266 |
+
down_block_res_sample = (
|
1267 |
+
down_block_res_sample + down_block_additional_residual
|
1268 |
+
)
|
1269 |
+
new_down_block_res_samples = new_down_block_res_samples + (
|
1270 |
+
down_block_res_sample,
|
1271 |
+
)
|
1272 |
+
|
1273 |
+
down_block_res_samples = new_down_block_res_samples
|
1274 |
+
|
1275 |
+
# 4. mid
|
1276 |
+
if self.mid_block is not None:
|
1277 |
+
if (
|
1278 |
+
hasattr(self.mid_block, "has_cross_attention")
|
1279 |
+
and self.mid_block.has_cross_attention
|
1280 |
+
):
|
1281 |
+
sample = self.mid_block(
|
1282 |
+
sample,
|
1283 |
+
emb,
|
1284 |
+
encoder_hidden_states=encoder_hidden_states,
|
1285 |
+
attention_mask=attention_mask,
|
1286 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1287 |
+
encoder_attention_mask=encoder_attention_mask,
|
1288 |
+
)
|
1289 |
+
else:
|
1290 |
+
sample = self.mid_block(sample, emb)
|
1291 |
+
|
1292 |
+
# To support T2I-Adapter-XL
|
1293 |
+
if (
|
1294 |
+
is_adapter
|
1295 |
+
and len(down_intrablock_additional_residuals) > 0
|
1296 |
+
and sample.shape == down_intrablock_additional_residuals[0].shape
|
1297 |
+
):
|
1298 |
+
sample += down_intrablock_additional_residuals.pop(0)
|
1299 |
+
|
1300 |
+
if is_controlnet:
|
1301 |
+
sample = sample + mid_block_additional_residual
|
1302 |
+
|
1303 |
+
# 5. up
|
1304 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
1305 |
+
is_final_block = i == len(self.up_blocks) - 1
|
1306 |
+
|
1307 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
|
1308 |
+
down_block_res_samples = down_block_res_samples[
|
1309 |
+
: -len(upsample_block.resnets)
|
1310 |
+
]
|
1311 |
+
|
1312 |
+
# if we have not reached the final block and need to forward the
|
1313 |
+
# upsample size, we do it here
|
1314 |
+
if not is_final_block and forward_upsample_size:
|
1315 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
1316 |
+
|
1317 |
+
if (
|
1318 |
+
hasattr(upsample_block, "has_cross_attention")
|
1319 |
+
and upsample_block.has_cross_attention
|
1320 |
+
):
|
1321 |
+
sample = upsample_block(
|
1322 |
+
hidden_states=sample,
|
1323 |
+
temb=emb,
|
1324 |
+
res_hidden_states_tuple=res_samples,
|
1325 |
+
encoder_hidden_states=encoder_hidden_states,
|
1326 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1327 |
+
upsample_size=upsample_size,
|
1328 |
+
attention_mask=attention_mask,
|
1329 |
+
encoder_attention_mask=encoder_attention_mask,
|
1330 |
+
)
|
1331 |
+
else:
|
1332 |
+
sample = upsample_block(
|
1333 |
+
hidden_states=sample,
|
1334 |
+
temb=emb,
|
1335 |
+
res_hidden_states_tuple=res_samples,
|
1336 |
+
upsample_size=upsample_size,
|
1337 |
+
scale=lora_scale,
|
1338 |
+
)
|
1339 |
+
|
1340 |
+
# 6. post-process
|
1341 |
+
if post_process:
|
1342 |
+
if self.conv_norm_out:
|
1343 |
+
sample = self.conv_norm_out(sample)
|
1344 |
+
sample = self.conv_act(sample)
|
1345 |
+
sample = self.conv_out(sample)
|
1346 |
+
|
1347 |
+
if USE_PEFT_BACKEND:
|
1348 |
+
# remove `lora_scale` from each PEFT layer
|
1349 |
+
unscale_lora_layers(self, lora_scale)
|
1350 |
+
|
1351 |
+
if not return_dict:
|
1352 |
+
return (sample,)
|
1353 |
+
|
1354 |
+
return UNet2DConditionOutput(sample=sample)
|
1355 |
+
|
1356 |
+
@classmethod
|
1357 |
+
def load_change_cross_attention_dim(
|
1358 |
+
cls,
|
1359 |
+
pretrained_model_path: PathLike,
|
1360 |
+
subfolder=None,
|
1361 |
+
# unet_additional_kwargs=None,
|
1362 |
+
):
|
1363 |
+
"""
|
1364 |
+
Load or change the cross-attention dimension of a pre-trained model.
|
1365 |
+
|
1366 |
+
Parameters:
|
1367 |
+
pretrained_model_name_or_path (:class:`~typing.Union[str, :class:`~pathlib.Path`]`):
|
1368 |
+
The identifier of the pre-trained model or the path to the local folder containing the model.
|
1369 |
+
force_download (:class:`~bool`):
|
1370 |
+
If True, re-download the model even if it is already cached.
|
1371 |
+
resume_download (:class:`~bool`):
|
1372 |
+
If True, resume the download of the model if partially downloaded.
|
1373 |
+
proxies (:class:`~dict`):
|
1374 |
+
A dictionary of proxy servers to use for downloading the model.
|
1375 |
+
cache_dir (:class:`~Optional[str]`):
|
1376 |
+
The path to the cache directory for storing downloaded models.
|
1377 |
+
use_auth_token (:class:`~bool`):
|
1378 |
+
If True, use the authentication token for private models.
|
1379 |
+
revision (:class:`~str`):
|
1380 |
+
The specific model version to use.
|
1381 |
+
use_safetensors (:class:`~bool`):
|
1382 |
+
If True, use the SafeTensors format for loading the model weights.
|
1383 |
+
**kwargs (:class:`~dict`):
|
1384 |
+
Additional keyword arguments passed to the model.
|
1385 |
+
|
1386 |
+
"""
|
1387 |
+
pretrained_model_path = Path(pretrained_model_path)
|
1388 |
+
if subfolder is not None:
|
1389 |
+
pretrained_model_path = pretrained_model_path.joinpath(subfolder)
|
1390 |
+
config_file = pretrained_model_path / "config.json"
|
1391 |
+
if not (config_file.exists() and config_file.is_file()):
|
1392 |
+
raise RuntimeError(
|
1393 |
+
f"{config_file} does not exist or is not a file")
|
1394 |
+
|
1395 |
+
unet_config = cls.load_config(config_file)
|
1396 |
+
unet_config["cross_attention_dim"] = 1024
|
1397 |
+
|
1398 |
+
model = cls.from_config(unet_config)
|
1399 |
+
# load the vanilla weights
|
1400 |
+
if pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME).exists():
|
1401 |
+
logger.debug(
|
1402 |
+
f"loading safeTensors weights from {pretrained_model_path} ..."
|
1403 |
+
)
|
1404 |
+
state_dict = load_file(
|
1405 |
+
pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME), device="cpu"
|
1406 |
+
)
|
1407 |
+
|
1408 |
+
elif pretrained_model_path.joinpath(WEIGHTS_NAME).exists():
|
1409 |
+
logger.debug(f"loading weights from {pretrained_model_path} ...")
|
1410 |
+
state_dict = torch.load(
|
1411 |
+
pretrained_model_path.joinpath(WEIGHTS_NAME),
|
1412 |
+
map_location="cpu",
|
1413 |
+
weights_only=True,
|
1414 |
+
)
|
1415 |
+
else:
|
1416 |
+
raise FileNotFoundError(
|
1417 |
+
f"no weights file found in {pretrained_model_path}")
|
1418 |
+
|
1419 |
+
model_state_dict = model.state_dict()
|
1420 |
+
for k in state_dict:
|
1421 |
+
if k in model_state_dict:
|
1422 |
+
if state_dict[k].shape != model_state_dict[k].shape:
|
1423 |
+
state_dict[k] = model_state_dict[k]
|
1424 |
+
# load the weights into the model
|
1425 |
+
m, u = model.load_state_dict(state_dict, strict=False)
|
1426 |
+
print(m, u)
|
1427 |
+
|
1428 |
+
return model
|
joyhallo/models/unet_3d.py
ADDED
@@ -0,0 +1,840 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This is the main file for the UNet3DConditionModel, which defines the UNet3D model architecture.
|
3 |
+
|
4 |
+
The UNet3D model is a 3D convolutional neural network designed for image segmentation and
|
5 |
+
other computer vision tasks. It consists of an encoder, a decoder, and skip connections between
|
6 |
+
the corresponding layers of the encoder and decoder. The model can handle 3D data and
|
7 |
+
performs well on tasks such as image segmentation, object detection, and video analysis.
|
8 |
+
|
9 |
+
This file contains the necessary imports, the main UNet3DConditionModel class, and its
|
10 |
+
methods for setting attention slice, setting gradient checkpointing, setting attention
|
11 |
+
processor, and the forward method for model inference.
|
12 |
+
|
13 |
+
The module provides a comprehensive solution for 3D image segmentation tasks and can be
|
14 |
+
easily extended for other computer vision tasks as well.
|
15 |
+
"""
|
16 |
+
|
17 |
+
from collections import OrderedDict
|
18 |
+
from dataclasses import dataclass
|
19 |
+
from os import PathLike
|
20 |
+
from pathlib import Path
|
21 |
+
from typing import Dict, List, Optional, Tuple, Union
|
22 |
+
|
23 |
+
import torch
|
24 |
+
import torch.nn as nn
|
25 |
+
import torch.utils.checkpoint
|
26 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
27 |
+
from diffusers.models.attention_processor import AttentionProcessor
|
28 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
29 |
+
from diffusers.models.modeling_utils import ModelMixin
|
30 |
+
from diffusers.utils import (SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME,
|
31 |
+
BaseOutput, logging)
|
32 |
+
from safetensors.torch import load_file
|
33 |
+
|
34 |
+
from .resnet import InflatedConv3d, InflatedGroupNorm
|
35 |
+
from .unet_3d_blocks import (UNetMidBlock3DCrossAttn, get_down_block,
|
36 |
+
get_up_block)
|
37 |
+
|
38 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
39 |
+
|
40 |
+
|
41 |
+
@dataclass
|
42 |
+
class UNet3DConditionOutput(BaseOutput):
|
43 |
+
"""
|
44 |
+
Data class that serves as the output of the UNet3DConditionModel.
|
45 |
+
|
46 |
+
Attributes:
|
47 |
+
sample (`torch.FloatTensor`):
|
48 |
+
A tensor representing the processed sample. The shape and nature of this tensor will depend on the
|
49 |
+
specific configuration of the model and the input data.
|
50 |
+
"""
|
51 |
+
sample: torch.FloatTensor
|
52 |
+
|
53 |
+
|
54 |
+
class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
55 |
+
"""
|
56 |
+
A 3D UNet model designed to handle conditional image and video generation tasks. This model is particularly
|
57 |
+
suited for tasks that require the generation of 3D data, such as volumetric medical imaging or 3D video
|
58 |
+
generation, while incorporating additional conditioning information.
|
59 |
+
|
60 |
+
The model consists of an encoder-decoder structure with skip connections. It utilizes a series of downsampling
|
61 |
+
and upsampling blocks, with a middle block for further processing. Each block can be customized with different
|
62 |
+
types of layers and attention mechanisms.
|
63 |
+
|
64 |
+
Parameters:
|
65 |
+
sample_size (`int`, optional): The size of the input sample.
|
66 |
+
in_channels (`int`, defaults to 8): The number of input channels.
|
67 |
+
out_channels (`int`, defaults to 8): The number of output channels.
|
68 |
+
center_input_sample (`bool`, defaults to False): Whether to center the input sample.
|
69 |
+
flip_sin_to_cos (`bool`, defaults to True): Whether to flip the sine to cosine in the time embedding.
|
70 |
+
freq_shift (`int`, defaults to 0): The frequency shift for the time embedding.
|
71 |
+
down_block_types (`Tuple[str]`): A tuple of strings specifying the types of downsampling blocks.
|
72 |
+
mid_block_type (`str`): The type of middle block.
|
73 |
+
up_block_types (`Tuple[str]`): A tuple of strings specifying the types of upsampling blocks.
|
74 |
+
only_cross_attention (`Union[bool, Tuple[bool]]`): Whether to use only cross-attention.
|
75 |
+
block_out_channels (`Tuple[int]`): A tuple of integers specifying the output channels for each block.
|
76 |
+
layers_per_block (`int`, defaults to 2): The number of layers per block.
|
77 |
+
downsample_padding (`int`, defaults to 1): The padding used in downsampling.
|
78 |
+
mid_block_scale_factor (`float`, defaults to 1): The scale factor for the middle block.
|
79 |
+
act_fn (`str`, defaults to 'silu'): The activation function to be used.
|
80 |
+
norm_num_groups (`int`, defaults to 32): The number of groups for normalization.
|
81 |
+
norm_eps (`float`, defaults to 1e-5): The epsilon for normalization.
|
82 |
+
cross_attention_dim (`int`, defaults to 1280): The dimension for cross-attention.
|
83 |
+
attention_head_dim (`Union[int, Tuple[int]]`): The dimension for attention heads.
|
84 |
+
dual_cross_attention (`bool`, defaults to False): Whether to use dual cross-attention.
|
85 |
+
use_linear_projection (`bool`, defaults to False): Whether to use linear projection.
|
86 |
+
class_embed_type (`str`, optional): The type of class embedding.
|
87 |
+
num_class_embeds (`int`, optional): The number of class embeddings.
|
88 |
+
upcast_attention (`bool`, defaults to False): Whether to upcast attention.
|
89 |
+
resnet_time_scale_shift (`str`, defaults to 'default'): The time scale shift for the ResNet.
|
90 |
+
use_inflated_groupnorm (`bool`, defaults to False): Whether to use inflated group normalization.
|
91 |
+
use_motion_module (`bool`, defaults to False): Whether to use a motion module.
|
92 |
+
motion_module_resolutions (`Tuple[int]`): A tuple of resolutions for the motion module.
|
93 |
+
motion_module_mid_block (`bool`, defaults to False): Whether to use a motion module in the middle block.
|
94 |
+
motion_module_decoder_only (`bool`, defaults to False): Whether to use the motion module only in the decoder.
|
95 |
+
motion_module_type (`str`, optional): The type of motion module.
|
96 |
+
motion_module_kwargs (`dict`): Keyword arguments for the motion module.
|
97 |
+
unet_use_cross_frame_attention (`bool`, optional): Whether to use cross-frame attention in the UNet.
|
98 |
+
unet_use_temporal_attention (`bool`, optional): Whether to use temporal attention in the UNet.
|
99 |
+
use_audio_module (`bool`, defaults to False): Whether to use an audio module.
|
100 |
+
audio_attention_dim (`int`, defaults to 768): The dimension for audio attention.
|
101 |
+
|
102 |
+
The model supports various features such as gradient checkpointing, attention processors, and sliced attention
|
103 |
+
computation, making it flexible and efficient for different computational requirements and use cases.
|
104 |
+
|
105 |
+
The forward method of the model accepts a sample, timestep, and encoder hidden states as input, and it returns
|
106 |
+
the processed sample as output. The method also supports additional conditioning information such as class
|
107 |
+
labels, audio embeddings, and masks for specialized tasks.
|
108 |
+
|
109 |
+
The from_pretrained_2d class method allows loading a pre-trained 2D UNet model and adapting it for 3D tasks by
|
110 |
+
incorporating motion modules and other 3D specific features.
|
111 |
+
"""
|
112 |
+
|
113 |
+
_supports_gradient_checkpointing = True
|
114 |
+
|
115 |
+
@register_to_config
|
116 |
+
def __init__(
|
117 |
+
self,
|
118 |
+
sample_size: Optional[int] = None,
|
119 |
+
in_channels: int = 8,
|
120 |
+
out_channels: int = 8,
|
121 |
+
flip_sin_to_cos: bool = True,
|
122 |
+
freq_shift: int = 0,
|
123 |
+
down_block_types: Tuple[str] = (
|
124 |
+
"CrossAttnDownBlock3D",
|
125 |
+
"CrossAttnDownBlock3D",
|
126 |
+
"CrossAttnDownBlock3D",
|
127 |
+
"DownBlock3D",
|
128 |
+
),
|
129 |
+
mid_block_type: str = "UNetMidBlock3DCrossAttn",
|
130 |
+
up_block_types: Tuple[str] = (
|
131 |
+
"UpBlock3D",
|
132 |
+
"CrossAttnUpBlock3D",
|
133 |
+
"CrossAttnUpBlock3D",
|
134 |
+
"CrossAttnUpBlock3D",
|
135 |
+
),
|
136 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
137 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
138 |
+
layers_per_block: int = 2,
|
139 |
+
downsample_padding: int = 1,
|
140 |
+
mid_block_scale_factor: float = 1,
|
141 |
+
act_fn: str = "silu",
|
142 |
+
norm_num_groups: int = 32,
|
143 |
+
norm_eps: float = 1e-5,
|
144 |
+
cross_attention_dim: int = 1280,
|
145 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
146 |
+
dual_cross_attention: bool = False,
|
147 |
+
use_linear_projection: bool = False,
|
148 |
+
class_embed_type: Optional[str] = None,
|
149 |
+
num_class_embeds: Optional[int] = None,
|
150 |
+
upcast_attention: bool = False,
|
151 |
+
resnet_time_scale_shift: str = "default",
|
152 |
+
use_inflated_groupnorm=False,
|
153 |
+
# Additional
|
154 |
+
use_motion_module=False,
|
155 |
+
motion_module_resolutions=(1, 2, 4, 8),
|
156 |
+
motion_module_mid_block=False,
|
157 |
+
motion_module_decoder_only=False,
|
158 |
+
motion_module_type=None,
|
159 |
+
motion_module_kwargs=None,
|
160 |
+
unet_use_cross_frame_attention=None,
|
161 |
+
unet_use_temporal_attention=None,
|
162 |
+
# audio
|
163 |
+
use_audio_module=False,
|
164 |
+
audio_attention_dim=768,
|
165 |
+
stack_enable_blocks_name=None,
|
166 |
+
stack_enable_blocks_depth=None,
|
167 |
+
):
|
168 |
+
super().__init__()
|
169 |
+
|
170 |
+
self.sample_size = sample_size
|
171 |
+
time_embed_dim = block_out_channels[0] * 4
|
172 |
+
|
173 |
+
# input
|
174 |
+
self.conv_in = InflatedConv3d(
|
175 |
+
in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)
|
176 |
+
)
|
177 |
+
|
178 |
+
# time
|
179 |
+
self.time_proj = Timesteps(
|
180 |
+
block_out_channels[0], flip_sin_to_cos, freq_shift)
|
181 |
+
timestep_input_dim = block_out_channels[0]
|
182 |
+
|
183 |
+
self.time_embedding = TimestepEmbedding(
|
184 |
+
timestep_input_dim, time_embed_dim)
|
185 |
+
|
186 |
+
# class embedding
|
187 |
+
if class_embed_type is None and num_class_embeds is not None:
|
188 |
+
self.class_embedding = nn.Embedding(
|
189 |
+
num_class_embeds, time_embed_dim)
|
190 |
+
elif class_embed_type == "timestep":
|
191 |
+
self.class_embedding = TimestepEmbedding(
|
192 |
+
timestep_input_dim, time_embed_dim)
|
193 |
+
elif class_embed_type == "identity":
|
194 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
195 |
+
else:
|
196 |
+
self.class_embedding = None
|
197 |
+
|
198 |
+
self.down_blocks = nn.ModuleList([])
|
199 |
+
self.mid_block = None
|
200 |
+
self.up_blocks = nn.ModuleList([])
|
201 |
+
|
202 |
+
if isinstance(only_cross_attention, bool):
|
203 |
+
only_cross_attention = [
|
204 |
+
only_cross_attention] * len(down_block_types)
|
205 |
+
|
206 |
+
if isinstance(attention_head_dim, int):
|
207 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
208 |
+
|
209 |
+
# down
|
210 |
+
output_channel = block_out_channels[0]
|
211 |
+
for i, down_block_type in enumerate(down_block_types):
|
212 |
+
res = 2**i
|
213 |
+
input_channel = output_channel
|
214 |
+
output_channel = block_out_channels[i]
|
215 |
+
is_final_block = i == len(block_out_channels) - 1
|
216 |
+
|
217 |
+
down_block = get_down_block(
|
218 |
+
down_block_type,
|
219 |
+
num_layers=layers_per_block,
|
220 |
+
in_channels=input_channel,
|
221 |
+
out_channels=output_channel,
|
222 |
+
temb_channels=time_embed_dim,
|
223 |
+
add_downsample=not is_final_block,
|
224 |
+
resnet_eps=norm_eps,
|
225 |
+
resnet_act_fn=act_fn,
|
226 |
+
resnet_groups=norm_num_groups,
|
227 |
+
cross_attention_dim=cross_attention_dim,
|
228 |
+
attn_num_head_channels=attention_head_dim[i],
|
229 |
+
downsample_padding=downsample_padding,
|
230 |
+
dual_cross_attention=dual_cross_attention,
|
231 |
+
use_linear_projection=use_linear_projection,
|
232 |
+
only_cross_attention=only_cross_attention[i],
|
233 |
+
upcast_attention=upcast_attention,
|
234 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
235 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
236 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
237 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
238 |
+
use_motion_module=use_motion_module
|
239 |
+
and (res in motion_module_resolutions)
|
240 |
+
and (not motion_module_decoder_only),
|
241 |
+
motion_module_type=motion_module_type,
|
242 |
+
motion_module_kwargs=motion_module_kwargs,
|
243 |
+
use_audio_module=use_audio_module,
|
244 |
+
audio_attention_dim=audio_attention_dim,
|
245 |
+
depth=i,
|
246 |
+
stack_enable_blocks_name=stack_enable_blocks_name,
|
247 |
+
stack_enable_blocks_depth=stack_enable_blocks_depth,
|
248 |
+
)
|
249 |
+
self.down_blocks.append(down_block)
|
250 |
+
|
251 |
+
# mid
|
252 |
+
if mid_block_type == "UNetMidBlock3DCrossAttn":
|
253 |
+
self.mid_block = UNetMidBlock3DCrossAttn(
|
254 |
+
in_channels=block_out_channels[-1],
|
255 |
+
temb_channels=time_embed_dim,
|
256 |
+
resnet_eps=norm_eps,
|
257 |
+
resnet_act_fn=act_fn,
|
258 |
+
output_scale_factor=mid_block_scale_factor,
|
259 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
260 |
+
cross_attention_dim=cross_attention_dim,
|
261 |
+
attn_num_head_channels=attention_head_dim[-1],
|
262 |
+
resnet_groups=norm_num_groups,
|
263 |
+
dual_cross_attention=dual_cross_attention,
|
264 |
+
use_linear_projection=use_linear_projection,
|
265 |
+
upcast_attention=upcast_attention,
|
266 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
267 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
268 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
269 |
+
use_motion_module=use_motion_module and motion_module_mid_block,
|
270 |
+
motion_module_type=motion_module_type,
|
271 |
+
motion_module_kwargs=motion_module_kwargs,
|
272 |
+
use_audio_module=use_audio_module,
|
273 |
+
audio_attention_dim=audio_attention_dim,
|
274 |
+
depth=3,
|
275 |
+
stack_enable_blocks_name=stack_enable_blocks_name,
|
276 |
+
stack_enable_blocks_depth=stack_enable_blocks_depth,
|
277 |
+
)
|
278 |
+
else:
|
279 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
280 |
+
|
281 |
+
# count how many layers upsample the videos
|
282 |
+
self.num_upsamplers = 0
|
283 |
+
|
284 |
+
# up
|
285 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
286 |
+
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
287 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
288 |
+
output_channel = reversed_block_out_channels[0]
|
289 |
+
for i, up_block_type in enumerate(up_block_types):
|
290 |
+
res = 2 ** (3 - i)
|
291 |
+
is_final_block = i == len(block_out_channels) - 1
|
292 |
+
|
293 |
+
prev_output_channel = output_channel
|
294 |
+
output_channel = reversed_block_out_channels[i]
|
295 |
+
input_channel = reversed_block_out_channels[
|
296 |
+
min(i + 1, len(block_out_channels) - 1)
|
297 |
+
]
|
298 |
+
|
299 |
+
# add upsample block for all BUT final layer
|
300 |
+
if not is_final_block:
|
301 |
+
add_upsample = True
|
302 |
+
self.num_upsamplers += 1
|
303 |
+
else:
|
304 |
+
add_upsample = False
|
305 |
+
|
306 |
+
up_block = get_up_block(
|
307 |
+
up_block_type,
|
308 |
+
num_layers=layers_per_block + 1,
|
309 |
+
in_channels=input_channel,
|
310 |
+
out_channels=output_channel,
|
311 |
+
prev_output_channel=prev_output_channel,
|
312 |
+
temb_channels=time_embed_dim,
|
313 |
+
add_upsample=add_upsample,
|
314 |
+
resnet_eps=norm_eps,
|
315 |
+
resnet_act_fn=act_fn,
|
316 |
+
resnet_groups=norm_num_groups,
|
317 |
+
cross_attention_dim=cross_attention_dim,
|
318 |
+
attn_num_head_channels=reversed_attention_head_dim[i],
|
319 |
+
dual_cross_attention=dual_cross_attention,
|
320 |
+
use_linear_projection=use_linear_projection,
|
321 |
+
only_cross_attention=only_cross_attention[i],
|
322 |
+
upcast_attention=upcast_attention,
|
323 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
324 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
325 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
326 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
327 |
+
use_motion_module=use_motion_module
|
328 |
+
and (res in motion_module_resolutions),
|
329 |
+
motion_module_type=motion_module_type,
|
330 |
+
motion_module_kwargs=motion_module_kwargs,
|
331 |
+
use_audio_module=use_audio_module,
|
332 |
+
audio_attention_dim=audio_attention_dim,
|
333 |
+
depth=3-i,
|
334 |
+
stack_enable_blocks_name=stack_enable_blocks_name,
|
335 |
+
stack_enable_blocks_depth=stack_enable_blocks_depth,
|
336 |
+
)
|
337 |
+
self.up_blocks.append(up_block)
|
338 |
+
prev_output_channel = output_channel
|
339 |
+
|
340 |
+
# out
|
341 |
+
if use_inflated_groupnorm:
|
342 |
+
self.conv_norm_out = InflatedGroupNorm(
|
343 |
+
num_channels=block_out_channels[0],
|
344 |
+
num_groups=norm_num_groups,
|
345 |
+
eps=norm_eps,
|
346 |
+
)
|
347 |
+
else:
|
348 |
+
self.conv_norm_out = nn.GroupNorm(
|
349 |
+
num_channels=block_out_channels[0],
|
350 |
+
num_groups=norm_num_groups,
|
351 |
+
eps=norm_eps,
|
352 |
+
)
|
353 |
+
self.conv_act = nn.SiLU()
|
354 |
+
self.conv_out = InflatedConv3d(
|
355 |
+
block_out_channels[0], out_channels, kernel_size=3, padding=1
|
356 |
+
)
|
357 |
+
|
358 |
+
@property
|
359 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
360 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
361 |
+
r"""
|
362 |
+
Returns:
|
363 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
364 |
+
indexed by its weight name.
|
365 |
+
"""
|
366 |
+
# set recursively
|
367 |
+
processors = {}
|
368 |
+
|
369 |
+
def fn_recursive_add_processors(
|
370 |
+
name: str,
|
371 |
+
module: torch.nn.Module,
|
372 |
+
processors: Dict[str, AttentionProcessor],
|
373 |
+
):
|
374 |
+
if hasattr(module, "set_processor"):
|
375 |
+
processors[f"{name}.processor"] = module.processor
|
376 |
+
|
377 |
+
for sub_name, child in module.named_children():
|
378 |
+
if "temporal_transformer" not in sub_name:
|
379 |
+
fn_recursive_add_processors(
|
380 |
+
f"{name}.{sub_name}", child, processors)
|
381 |
+
|
382 |
+
return processors
|
383 |
+
|
384 |
+
for name, module in self.named_children():
|
385 |
+
if "temporal_transformer" not in name:
|
386 |
+
fn_recursive_add_processors(name, module, processors)
|
387 |
+
|
388 |
+
return processors
|
389 |
+
|
390 |
+
def set_attention_slice(self, slice_size):
|
391 |
+
r"""
|
392 |
+
Enable sliced attention computation.
|
393 |
+
|
394 |
+
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
395 |
+
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
396 |
+
|
397 |
+
Args:
|
398 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
399 |
+
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
400 |
+
`"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
|
401 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
402 |
+
must be a multiple of `slice_size`.
|
403 |
+
"""
|
404 |
+
sliceable_head_dims = []
|
405 |
+
|
406 |
+
def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
|
407 |
+
if hasattr(module, "set_attention_slice"):
|
408 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
409 |
+
|
410 |
+
for child in module.children():
|
411 |
+
fn_recursive_retrieve_slicable_dims(child)
|
412 |
+
|
413 |
+
# retrieve number of attention layers
|
414 |
+
for module in self.children():
|
415 |
+
fn_recursive_retrieve_slicable_dims(module)
|
416 |
+
|
417 |
+
num_slicable_layers = len(sliceable_head_dims)
|
418 |
+
|
419 |
+
if slice_size == "auto":
|
420 |
+
# half the attention head size is usually a good trade-off between
|
421 |
+
# speed and memory
|
422 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
423 |
+
elif slice_size == "max":
|
424 |
+
# make smallest slice possible
|
425 |
+
slice_size = num_slicable_layers * [1]
|
426 |
+
|
427 |
+
slice_size = (
|
428 |
+
num_slicable_layers * [slice_size]
|
429 |
+
if not isinstance(slice_size, list)
|
430 |
+
else slice_size
|
431 |
+
)
|
432 |
+
|
433 |
+
if len(slice_size) != len(sliceable_head_dims):
|
434 |
+
raise ValueError(
|
435 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
436 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
437 |
+
)
|
438 |
+
|
439 |
+
for i, size in enumerate(slice_size):
|
440 |
+
dim = sliceable_head_dims[i]
|
441 |
+
if size is not None and size > dim:
|
442 |
+
raise ValueError(
|
443 |
+
f"size {size} has to be smaller or equal to {dim}.")
|
444 |
+
|
445 |
+
# Recursively walk through all the children.
|
446 |
+
# Any children which exposes the set_attention_slice method
|
447 |
+
# gets the message
|
448 |
+
def fn_recursive_set_attention_slice(
|
449 |
+
module: torch.nn.Module, slice_size: List[int]
|
450 |
+
):
|
451 |
+
if hasattr(module, "set_attention_slice"):
|
452 |
+
module.set_attention_slice(slice_size.pop())
|
453 |
+
|
454 |
+
for child in module.children():
|
455 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
456 |
+
|
457 |
+
reversed_slice_size = list(reversed(slice_size))
|
458 |
+
for module in self.children():
|
459 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
460 |
+
|
461 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
462 |
+
if hasattr(module, "gradient_checkpointing"):
|
463 |
+
module.gradient_checkpointing = value
|
464 |
+
|
465 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
466 |
+
def set_attn_processor(
|
467 |
+
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
|
468 |
+
):
|
469 |
+
r"""
|
470 |
+
Sets the attention processor to use to compute attention.
|
471 |
+
|
472 |
+
Parameters:
|
473 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
474 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
475 |
+
for **all** `Attention` layers.
|
476 |
+
|
477 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
478 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
479 |
+
|
480 |
+
"""
|
481 |
+
count = len(self.attn_processors.keys())
|
482 |
+
|
483 |
+
if isinstance(processor, dict) and len(processor) != count:
|
484 |
+
raise ValueError(
|
485 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
486 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
487 |
+
)
|
488 |
+
|
489 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
490 |
+
if hasattr(module, "set_processor"):
|
491 |
+
if not isinstance(processor, dict):
|
492 |
+
module.set_processor(processor)
|
493 |
+
else:
|
494 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
495 |
+
|
496 |
+
for sub_name, child in module.named_children():
|
497 |
+
if "temporal_transformer" not in sub_name:
|
498 |
+
fn_recursive_attn_processor(
|
499 |
+
f"{name}.{sub_name}", child, processor)
|
500 |
+
|
501 |
+
for name, module in self.named_children():
|
502 |
+
if "temporal_transformer" not in name:
|
503 |
+
fn_recursive_attn_processor(name, module, processor)
|
504 |
+
|
505 |
+
def forward(
|
506 |
+
self,
|
507 |
+
sample: torch.FloatTensor,
|
508 |
+
timestep: Union[torch.Tensor, float, int],
|
509 |
+
encoder_hidden_states: torch.Tensor,
|
510 |
+
audio_embedding: Optional[torch.Tensor] = None,
|
511 |
+
class_labels: Optional[torch.Tensor] = None,
|
512 |
+
mask_cond_fea: Optional[torch.Tensor] = None,
|
513 |
+
attention_mask: Optional[torch.Tensor] = None,
|
514 |
+
full_mask: Optional[torch.Tensor] = None,
|
515 |
+
face_mask: Optional[torch.Tensor] = None,
|
516 |
+
lip_mask: Optional[torch.Tensor] = None,
|
517 |
+
motion_scale: Optional[torch.Tensor] = None,
|
518 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
519 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
520 |
+
return_dict: bool = True,
|
521 |
+
# start: bool = False,
|
522 |
+
) -> Union[UNet3DConditionOutput, Tuple]:
|
523 |
+
r"""
|
524 |
+
Args:
|
525 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
526 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
527 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states, face_emb
|
528 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
529 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
530 |
+
|
531 |
+
mask_cond_fea (`torch.FloatTensor`, *optional*): mask_feature tensor
|
532 |
+
audio_embedding (`torch.FloatTensor`, *optional*): audio embedding tensor, audio_emb
|
533 |
+
full_mask (`torch.FloatTensor`, *optional*): full mask tensor, full_mask
|
534 |
+
face_mask (`torch.FloatTensor`, *optional*): face mask tensor, face_mask
|
535 |
+
lip_mask (`torch.FloatTensor`, *optional*): lip mask tensor, lip_mask
|
536 |
+
|
537 |
+
Returns:
|
538 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
539 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
540 |
+
returning a tuple, the first element is the sample tensor.
|
541 |
+
"""
|
542 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
543 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
544 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
545 |
+
# on the fly if necessary.
|
546 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
547 |
+
|
548 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
549 |
+
forward_upsample_size = False
|
550 |
+
upsample_size = None
|
551 |
+
|
552 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
553 |
+
logger.info(
|
554 |
+
"Forward upsample size to force interpolation output size.")
|
555 |
+
forward_upsample_size = True
|
556 |
+
|
557 |
+
# prepare attention_mask
|
558 |
+
if attention_mask is not None:
|
559 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
560 |
+
attention_mask = attention_mask.unsqueeze(1)
|
561 |
+
|
562 |
+
# center input if necessary
|
563 |
+
if self.config.center_input_sample:
|
564 |
+
sample = 2 * sample - 1.0
|
565 |
+
|
566 |
+
# time
|
567 |
+
timesteps = timestep
|
568 |
+
if not torch.is_tensor(timesteps):
|
569 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
570 |
+
is_mps = sample.device.type == "mps"
|
571 |
+
if isinstance(timestep, float):
|
572 |
+
dtype = torch.float32 if is_mps else torch.float64
|
573 |
+
else:
|
574 |
+
dtype = torch.int32 if is_mps else torch.int64
|
575 |
+
timesteps = torch.tensor(
|
576 |
+
[timesteps], dtype=dtype, device=sample.device)
|
577 |
+
elif len(timesteps.shape) == 0:
|
578 |
+
timesteps = timesteps[None].to(sample.device)
|
579 |
+
|
580 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
581 |
+
timesteps = timesteps.expand(sample.shape[0])
|
582 |
+
|
583 |
+
t_emb = self.time_proj(timesteps)
|
584 |
+
|
585 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
586 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
587 |
+
# there might be better ways to encapsulate this.
|
588 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
589 |
+
emb = self.time_embedding(t_emb)
|
590 |
+
|
591 |
+
if self.class_embedding is not None:
|
592 |
+
if class_labels is None:
|
593 |
+
raise ValueError(
|
594 |
+
"class_labels should be provided when num_class_embeds > 0"
|
595 |
+
)
|
596 |
+
|
597 |
+
if self.config.class_embed_type == "timestep":
|
598 |
+
class_labels = self.time_proj(class_labels)
|
599 |
+
|
600 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
601 |
+
emb = emb + class_emb
|
602 |
+
|
603 |
+
# pre-process
|
604 |
+
sample = self.conv_in(sample)
|
605 |
+
if mask_cond_fea is not None:
|
606 |
+
sample = sample + mask_cond_fea
|
607 |
+
|
608 |
+
# down
|
609 |
+
down_block_res_samples = (sample,)
|
610 |
+
for downsample_block in self.down_blocks:
|
611 |
+
if (
|
612 |
+
hasattr(downsample_block, "has_cross_attention")
|
613 |
+
and downsample_block.has_cross_attention
|
614 |
+
):
|
615 |
+
sample, res_samples = downsample_block(
|
616 |
+
hidden_states=sample,
|
617 |
+
temb=emb,
|
618 |
+
encoder_hidden_states=encoder_hidden_states,
|
619 |
+
attention_mask=attention_mask,
|
620 |
+
full_mask=full_mask,
|
621 |
+
face_mask=face_mask,
|
622 |
+
lip_mask=lip_mask,
|
623 |
+
audio_embedding=audio_embedding,
|
624 |
+
motion_scale=motion_scale,
|
625 |
+
)
|
626 |
+
# print("")
|
627 |
+
else:
|
628 |
+
sample, res_samples = downsample_block(
|
629 |
+
hidden_states=sample,
|
630 |
+
temb=emb,
|
631 |
+
encoder_hidden_states=encoder_hidden_states,
|
632 |
+
# audio_embedding=audio_embedding,
|
633 |
+
)
|
634 |
+
# print("")
|
635 |
+
|
636 |
+
down_block_res_samples += res_samples
|
637 |
+
|
638 |
+
if down_block_additional_residuals is not None:
|
639 |
+
new_down_block_res_samples = ()
|
640 |
+
|
641 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
642 |
+
down_block_res_samples, down_block_additional_residuals
|
643 |
+
):
|
644 |
+
down_block_res_sample = (
|
645 |
+
down_block_res_sample + down_block_additional_residual
|
646 |
+
)
|
647 |
+
new_down_block_res_samples += (down_block_res_sample,)
|
648 |
+
|
649 |
+
down_block_res_samples = new_down_block_res_samples
|
650 |
+
|
651 |
+
# mid
|
652 |
+
sample = self.mid_block(
|
653 |
+
sample,
|
654 |
+
emb,
|
655 |
+
encoder_hidden_states=encoder_hidden_states,
|
656 |
+
attention_mask=attention_mask,
|
657 |
+
full_mask=full_mask,
|
658 |
+
face_mask=face_mask,
|
659 |
+
lip_mask=lip_mask,
|
660 |
+
audio_embedding=audio_embedding,
|
661 |
+
motion_scale=motion_scale,
|
662 |
+
)
|
663 |
+
|
664 |
+
if mid_block_additional_residual is not None:
|
665 |
+
sample = sample + mid_block_additional_residual
|
666 |
+
|
667 |
+
# up
|
668 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
669 |
+
is_final_block = i == len(self.up_blocks) - 1
|
670 |
+
|
671 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
|
672 |
+
down_block_res_samples = down_block_res_samples[
|
673 |
+
: -len(upsample_block.resnets)
|
674 |
+
]
|
675 |
+
|
676 |
+
# if we have not reached the final block and need to forward the
|
677 |
+
# upsample size, we do it here
|
678 |
+
if not is_final_block and forward_upsample_size:
|
679 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
680 |
+
|
681 |
+
if (
|
682 |
+
hasattr(upsample_block, "has_cross_attention")
|
683 |
+
and upsample_block.has_cross_attention
|
684 |
+
):
|
685 |
+
sample = upsample_block(
|
686 |
+
hidden_states=sample,
|
687 |
+
temb=emb,
|
688 |
+
res_hidden_states_tuple=res_samples,
|
689 |
+
encoder_hidden_states=encoder_hidden_states,
|
690 |
+
upsample_size=upsample_size,
|
691 |
+
attention_mask=attention_mask,
|
692 |
+
full_mask=full_mask,
|
693 |
+
face_mask=face_mask,
|
694 |
+
lip_mask=lip_mask,
|
695 |
+
audio_embedding=audio_embedding,
|
696 |
+
motion_scale=motion_scale,
|
697 |
+
)
|
698 |
+
else:
|
699 |
+
sample = upsample_block(
|
700 |
+
hidden_states=sample,
|
701 |
+
temb=emb,
|
702 |
+
res_hidden_states_tuple=res_samples,
|
703 |
+
upsample_size=upsample_size,
|
704 |
+
encoder_hidden_states=encoder_hidden_states,
|
705 |
+
# audio_embedding=audio_embedding,
|
706 |
+
)
|
707 |
+
|
708 |
+
# post-process
|
709 |
+
sample = self.conv_norm_out(sample)
|
710 |
+
sample = self.conv_act(sample)
|
711 |
+
sample = self.conv_out(sample)
|
712 |
+
|
713 |
+
if not return_dict:
|
714 |
+
return (sample,)
|
715 |
+
|
716 |
+
return UNet3DConditionOutput(sample=sample)
|
717 |
+
|
718 |
+
@classmethod
|
719 |
+
def from_pretrained_2d(
|
720 |
+
cls,
|
721 |
+
pretrained_model_path: PathLike,
|
722 |
+
motion_module_path: PathLike,
|
723 |
+
subfolder=None,
|
724 |
+
unet_additional_kwargs=None,
|
725 |
+
mm_zero_proj_out=False,
|
726 |
+
use_landmark=True,
|
727 |
+
):
|
728 |
+
"""
|
729 |
+
Load a pre-trained 2D UNet model from a given directory.
|
730 |
+
|
731 |
+
Parameters:
|
732 |
+
pretrained_model_path (`str` or `PathLike`):
|
733 |
+
Path to the directory containing a pre-trained 2D UNet model.
|
734 |
+
dtype (`torch.dtype`, *optional*):
|
735 |
+
The data type of the loaded model. If not provided, the default data type is used.
|
736 |
+
device (`torch.device`, *optional*):
|
737 |
+
The device on which the loaded model will be placed. If not provided, the default device is used.
|
738 |
+
**kwargs (`Any`):
|
739 |
+
Additional keyword arguments passed to the model.
|
740 |
+
|
741 |
+
Returns:
|
742 |
+
`UNet3DConditionModel`:
|
743 |
+
The loaded 2D UNet model.
|
744 |
+
"""
|
745 |
+
pretrained_model_path = Path(pretrained_model_path)
|
746 |
+
motion_module_path = Path(motion_module_path)
|
747 |
+
if subfolder is not None:
|
748 |
+
pretrained_model_path = pretrained_model_path.joinpath(subfolder)
|
749 |
+
logger.info(
|
750 |
+
f"loaded temporal unet's pretrained weights from {pretrained_model_path} ..."
|
751 |
+
)
|
752 |
+
|
753 |
+
config_file = pretrained_model_path / "config.json"
|
754 |
+
if not (config_file.exists() and config_file.is_file()):
|
755 |
+
raise RuntimeError(
|
756 |
+
f"{config_file} does not exist or is not a file")
|
757 |
+
|
758 |
+
unet_config = cls.load_config(config_file)
|
759 |
+
unet_config["_class_name"] = cls.__name__
|
760 |
+
unet_config["down_block_types"] = [
|
761 |
+
"CrossAttnDownBlock3D",
|
762 |
+
"CrossAttnDownBlock3D",
|
763 |
+
"CrossAttnDownBlock3D",
|
764 |
+
"DownBlock3D",
|
765 |
+
]
|
766 |
+
unet_config["up_block_types"] = [
|
767 |
+
"UpBlock3D",
|
768 |
+
"CrossAttnUpBlock3D",
|
769 |
+
"CrossAttnUpBlock3D",
|
770 |
+
"CrossAttnUpBlock3D",
|
771 |
+
]
|
772 |
+
unet_config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
|
773 |
+
if use_landmark:
|
774 |
+
unet_config["in_channels"] = 8
|
775 |
+
unet_config["out_channels"] = 8
|
776 |
+
|
777 |
+
model = cls.from_config(unet_config, **unet_additional_kwargs)
|
778 |
+
# load the vanilla weights
|
779 |
+
if pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME).exists():
|
780 |
+
logger.debug(
|
781 |
+
f"loading safeTensors weights from {pretrained_model_path} ..."
|
782 |
+
)
|
783 |
+
state_dict = load_file(
|
784 |
+
pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME), device="cpu"
|
785 |
+
)
|
786 |
+
|
787 |
+
elif pretrained_model_path.joinpath(WEIGHTS_NAME).exists():
|
788 |
+
logger.debug(f"loading weights from {pretrained_model_path} ...")
|
789 |
+
state_dict = torch.load(
|
790 |
+
pretrained_model_path.joinpath(WEIGHTS_NAME),
|
791 |
+
map_location="cpu",
|
792 |
+
weights_only=True,
|
793 |
+
)
|
794 |
+
else:
|
795 |
+
raise FileNotFoundError(
|
796 |
+
f"no weights file found in {pretrained_model_path}")
|
797 |
+
|
798 |
+
# load the motion module weights
|
799 |
+
if motion_module_path.exists() and motion_module_path.is_file():
|
800 |
+
if motion_module_path.suffix.lower() in [".pth", ".pt", ".ckpt"]:
|
801 |
+
print(
|
802 |
+
f"Load motion module params from {motion_module_path}")
|
803 |
+
motion_state_dict = torch.load(
|
804 |
+
motion_module_path, map_location="cpu", weights_only=True
|
805 |
+
)
|
806 |
+
elif motion_module_path.suffix.lower() == ".safetensors":
|
807 |
+
motion_state_dict = load_file(motion_module_path, device="cpu")
|
808 |
+
else:
|
809 |
+
raise RuntimeError(
|
810 |
+
f"unknown file format for motion module weights: {motion_module_path.suffix}"
|
811 |
+
)
|
812 |
+
if mm_zero_proj_out:
|
813 |
+
logger.info(
|
814 |
+
"Zero initialize proj_out layers in motion module...")
|
815 |
+
new_motion_state_dict = OrderedDict()
|
816 |
+
for k in motion_state_dict:
|
817 |
+
if "proj_out" in k:
|
818 |
+
continue
|
819 |
+
new_motion_state_dict[k] = motion_state_dict[k]
|
820 |
+
motion_state_dict = new_motion_state_dict
|
821 |
+
|
822 |
+
# merge the state dicts
|
823 |
+
state_dict.update(motion_state_dict)
|
824 |
+
|
825 |
+
model_state_dict = model.state_dict()
|
826 |
+
for k in state_dict:
|
827 |
+
if k in model_state_dict:
|
828 |
+
if state_dict[k].shape != model_state_dict[k].shape:
|
829 |
+
state_dict[k] = model_state_dict[k]
|
830 |
+
# load the weights into the model
|
831 |
+
m, u = model.load_state_dict(state_dict, strict=False)
|
832 |
+
logger.debug(
|
833 |
+
f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
834 |
+
|
835 |
+
params = [
|
836 |
+
p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()
|
837 |
+
]
|
838 |
+
logger.info(f"Loaded {sum(params) / 1e6}M-parameter motion module")
|
839 |
+
|
840 |
+
return model
|
joyhallo/models/unet_3d_blocks.py
ADDED
@@ -0,0 +1,1398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This module defines various 3D UNet blocks used in the video model.
|
3 |
+
|
4 |
+
The blocks include:
|
5 |
+
- UNetMidBlock3DCrossAttn: The middle block of the UNet with cross attention.
|
6 |
+
- CrossAttnDownBlock3D: The downsampling block with cross attention.
|
7 |
+
- DownBlock3D: The standard downsampling block without cross attention.
|
8 |
+
- CrossAttnUpBlock3D: The upsampling block with cross attention.
|
9 |
+
- UpBlock3D: The standard upsampling block without cross attention.
|
10 |
+
|
11 |
+
These blocks are used to construct the 3D UNet architecture for video-related tasks.
|
12 |
+
"""
|
13 |
+
|
14 |
+
import torch
|
15 |
+
from einops import rearrange
|
16 |
+
from torch import nn
|
17 |
+
|
18 |
+
from .motion_module import get_motion_module
|
19 |
+
from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
|
20 |
+
from .transformer_3d import Transformer3DModel
|
21 |
+
|
22 |
+
|
23 |
+
def get_down_block(
|
24 |
+
down_block_type,
|
25 |
+
num_layers,
|
26 |
+
in_channels,
|
27 |
+
out_channels,
|
28 |
+
temb_channels,
|
29 |
+
add_downsample,
|
30 |
+
resnet_eps,
|
31 |
+
resnet_act_fn,
|
32 |
+
attn_num_head_channels,
|
33 |
+
resnet_groups=None,
|
34 |
+
cross_attention_dim=None,
|
35 |
+
audio_attention_dim=None,
|
36 |
+
downsample_padding=None,
|
37 |
+
dual_cross_attention=False,
|
38 |
+
use_linear_projection=False,
|
39 |
+
only_cross_attention=False,
|
40 |
+
upcast_attention=False,
|
41 |
+
resnet_time_scale_shift="default",
|
42 |
+
unet_use_cross_frame_attention=None,
|
43 |
+
unet_use_temporal_attention=None,
|
44 |
+
use_inflated_groupnorm=None,
|
45 |
+
use_motion_module=None,
|
46 |
+
motion_module_type=None,
|
47 |
+
motion_module_kwargs=None,
|
48 |
+
use_audio_module=None,
|
49 |
+
depth=0,
|
50 |
+
stack_enable_blocks_name=None,
|
51 |
+
stack_enable_blocks_depth=None,
|
52 |
+
):
|
53 |
+
"""
|
54 |
+
Factory function to instantiate a down-block module for the 3D UNet architecture.
|
55 |
+
|
56 |
+
Down blocks are used in the downsampling part of the U-Net to reduce the spatial dimensions
|
57 |
+
of the feature maps while increasing the depth. This function can create blocks with or without
|
58 |
+
cross attention based on the specified parameters.
|
59 |
+
|
60 |
+
Parameters:
|
61 |
+
- down_block_type (str): The type of down block to instantiate.
|
62 |
+
- num_layers (int): The number of layers in the block.
|
63 |
+
- in_channels (int): The number of input channels.
|
64 |
+
- out_channels (int): The number of output channels.
|
65 |
+
- temb_channels (int): The number of token embedding channels.
|
66 |
+
- add_downsample (bool): Flag to add a downsampling layer.
|
67 |
+
- resnet_eps (float): Epsilon for residual block stability.
|
68 |
+
- resnet_act_fn (callable): Activation function for the residual block.
|
69 |
+
- ... (remaining parameters): Additional parameters for configuring the block.
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
- nn.Module: An instance of a down-sampling block module.
|
73 |
+
"""
|
74 |
+
down_block_type = (
|
75 |
+
down_block_type[7:]
|
76 |
+
if down_block_type.startswith("UNetRes")
|
77 |
+
else down_block_type
|
78 |
+
)
|
79 |
+
if down_block_type == "DownBlock3D":
|
80 |
+
return DownBlock3D(
|
81 |
+
num_layers=num_layers,
|
82 |
+
in_channels=in_channels,
|
83 |
+
out_channels=out_channels,
|
84 |
+
temb_channels=temb_channels,
|
85 |
+
add_downsample=add_downsample,
|
86 |
+
resnet_eps=resnet_eps,
|
87 |
+
resnet_act_fn=resnet_act_fn,
|
88 |
+
resnet_groups=resnet_groups,
|
89 |
+
downsample_padding=downsample_padding,
|
90 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
91 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
92 |
+
use_motion_module=use_motion_module,
|
93 |
+
motion_module_type=motion_module_type,
|
94 |
+
motion_module_kwargs=motion_module_kwargs,
|
95 |
+
)
|
96 |
+
|
97 |
+
if down_block_type == "CrossAttnDownBlock3D":
|
98 |
+
if cross_attention_dim is None:
|
99 |
+
raise ValueError(
|
100 |
+
"cross_attention_dim must be specified for CrossAttnDownBlock3D"
|
101 |
+
)
|
102 |
+
return CrossAttnDownBlock3D(
|
103 |
+
num_layers=num_layers,
|
104 |
+
in_channels=in_channels,
|
105 |
+
out_channels=out_channels,
|
106 |
+
temb_channels=temb_channels,
|
107 |
+
add_downsample=add_downsample,
|
108 |
+
resnet_eps=resnet_eps,
|
109 |
+
resnet_act_fn=resnet_act_fn,
|
110 |
+
resnet_groups=resnet_groups,
|
111 |
+
downsample_padding=downsample_padding,
|
112 |
+
cross_attention_dim=cross_attention_dim,
|
113 |
+
audio_attention_dim=audio_attention_dim,
|
114 |
+
attn_num_head_channels=attn_num_head_channels,
|
115 |
+
dual_cross_attention=dual_cross_attention,
|
116 |
+
use_linear_projection=use_linear_projection,
|
117 |
+
only_cross_attention=only_cross_attention,
|
118 |
+
upcast_attention=upcast_attention,
|
119 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
120 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
121 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
122 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
123 |
+
use_motion_module=use_motion_module,
|
124 |
+
motion_module_type=motion_module_type,
|
125 |
+
motion_module_kwargs=motion_module_kwargs,
|
126 |
+
use_audio_module=use_audio_module,
|
127 |
+
depth=depth,
|
128 |
+
stack_enable_blocks_name=stack_enable_blocks_name,
|
129 |
+
stack_enable_blocks_depth=stack_enable_blocks_depth,
|
130 |
+
)
|
131 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
132 |
+
|
133 |
+
|
134 |
+
def get_up_block(
|
135 |
+
up_block_type,
|
136 |
+
num_layers,
|
137 |
+
in_channels,
|
138 |
+
out_channels,
|
139 |
+
prev_output_channel,
|
140 |
+
temb_channels,
|
141 |
+
add_upsample,
|
142 |
+
resnet_eps,
|
143 |
+
resnet_act_fn,
|
144 |
+
attn_num_head_channels,
|
145 |
+
resnet_groups=None,
|
146 |
+
cross_attention_dim=None,
|
147 |
+
audio_attention_dim=None,
|
148 |
+
dual_cross_attention=False,
|
149 |
+
use_linear_projection=False,
|
150 |
+
only_cross_attention=False,
|
151 |
+
upcast_attention=False,
|
152 |
+
resnet_time_scale_shift="default",
|
153 |
+
unet_use_cross_frame_attention=None,
|
154 |
+
unet_use_temporal_attention=None,
|
155 |
+
use_inflated_groupnorm=None,
|
156 |
+
use_motion_module=None,
|
157 |
+
motion_module_type=None,
|
158 |
+
motion_module_kwargs=None,
|
159 |
+
use_audio_module=None,
|
160 |
+
depth=0,
|
161 |
+
stack_enable_blocks_name=None,
|
162 |
+
stack_enable_blocks_depth=None,
|
163 |
+
):
|
164 |
+
"""
|
165 |
+
Factory function to instantiate an up-block module for the 3D UNet architecture.
|
166 |
+
|
167 |
+
Up blocks are used in the upsampling part of the U-Net to increase the spatial dimensions
|
168 |
+
of the feature maps while decreasing the depth. This function can create blocks with or without
|
169 |
+
cross attention based on the specified parameters.
|
170 |
+
|
171 |
+
Parameters:
|
172 |
+
- up_block_type (str): The type of up block to instantiate.
|
173 |
+
- num_layers (int): The number of layers in the block.
|
174 |
+
- in_channels (int): The number of input channels.
|
175 |
+
- out_channels (int): The number of output channels.
|
176 |
+
- prev_output_channel (int): The number of channels from the previous layer's output.
|
177 |
+
- temb_channels (int): The number of token embedding channels.
|
178 |
+
- add_upsample (bool): Flag to add an upsampling layer.
|
179 |
+
- resnet_eps (float): Epsilon for residual block stability.
|
180 |
+
- resnet_act_fn (callable): Activation function for the residual block.
|
181 |
+
- ... (remaining parameters): Additional parameters for configuring the block.
|
182 |
+
|
183 |
+
Returns:
|
184 |
+
- nn.Module: An instance of an up-sampling block module.
|
185 |
+
"""
|
186 |
+
up_block_type = (
|
187 |
+
up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
188 |
+
)
|
189 |
+
if up_block_type == "UpBlock3D":
|
190 |
+
return UpBlock3D(
|
191 |
+
num_layers=num_layers,
|
192 |
+
in_channels=in_channels,
|
193 |
+
out_channels=out_channels,
|
194 |
+
prev_output_channel=prev_output_channel,
|
195 |
+
temb_channels=temb_channels,
|
196 |
+
add_upsample=add_upsample,
|
197 |
+
resnet_eps=resnet_eps,
|
198 |
+
resnet_act_fn=resnet_act_fn,
|
199 |
+
resnet_groups=resnet_groups,
|
200 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
201 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
202 |
+
use_motion_module=use_motion_module,
|
203 |
+
motion_module_type=motion_module_type,
|
204 |
+
motion_module_kwargs=motion_module_kwargs,
|
205 |
+
)
|
206 |
+
|
207 |
+
if up_block_type == "CrossAttnUpBlock3D":
|
208 |
+
if cross_attention_dim is None:
|
209 |
+
raise ValueError(
|
210 |
+
"cross_attention_dim must be specified for CrossAttnUpBlock3D"
|
211 |
+
)
|
212 |
+
return CrossAttnUpBlock3D(
|
213 |
+
num_layers=num_layers,
|
214 |
+
in_channels=in_channels,
|
215 |
+
out_channels=out_channels,
|
216 |
+
prev_output_channel=prev_output_channel,
|
217 |
+
temb_channels=temb_channels,
|
218 |
+
add_upsample=add_upsample,
|
219 |
+
resnet_eps=resnet_eps,
|
220 |
+
resnet_act_fn=resnet_act_fn,
|
221 |
+
resnet_groups=resnet_groups,
|
222 |
+
cross_attention_dim=cross_attention_dim,
|
223 |
+
audio_attention_dim=audio_attention_dim,
|
224 |
+
attn_num_head_channels=attn_num_head_channels,
|
225 |
+
dual_cross_attention=dual_cross_attention,
|
226 |
+
use_linear_projection=use_linear_projection,
|
227 |
+
only_cross_attention=only_cross_attention,
|
228 |
+
upcast_attention=upcast_attention,
|
229 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
230 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
231 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
232 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
233 |
+
use_motion_module=use_motion_module,
|
234 |
+
motion_module_type=motion_module_type,
|
235 |
+
motion_module_kwargs=motion_module_kwargs,
|
236 |
+
use_audio_module=use_audio_module,
|
237 |
+
depth=depth,
|
238 |
+
stack_enable_blocks_name=stack_enable_blocks_name,
|
239 |
+
stack_enable_blocks_depth=stack_enable_blocks_depth,
|
240 |
+
)
|
241 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
242 |
+
|
243 |
+
|
244 |
+
class UNetMidBlock3DCrossAttn(nn.Module):
|
245 |
+
"""
|
246 |
+
A 3D UNet middle block with cross attention mechanism. This block is part of the U-Net architecture
|
247 |
+
and is used for feature extraction in the middle of the downsampling path.
|
248 |
+
|
249 |
+
Parameters:
|
250 |
+
- in_channels (int): Number of input channels.
|
251 |
+
- temb_channels (int): Number of token embedding channels.
|
252 |
+
- dropout (float): Dropout rate.
|
253 |
+
- num_layers (int): Number of layers in the block.
|
254 |
+
- resnet_eps (float): Epsilon for residual block.
|
255 |
+
- resnet_time_scale_shift (str): Time scale shift for time embedding normalization.
|
256 |
+
- resnet_act_fn (str): Activation function for the residual block.
|
257 |
+
- resnet_groups (int): Number of groups for the convolutions in the residual block.
|
258 |
+
- resnet_pre_norm (bool): Whether to use pre-normalization in the residual block.
|
259 |
+
- attn_num_head_channels (int): Number of attention heads.
|
260 |
+
- cross_attention_dim (int): Dimensionality of the cross attention layers.
|
261 |
+
- audio_attention_dim (int): Dimensionality of the audio attention layers.
|
262 |
+
- dual_cross_attention (bool): Whether to use dual cross attention.
|
263 |
+
- use_linear_projection (bool): Whether to use linear projection in attention.
|
264 |
+
- upcast_attention (bool): Whether to upcast attention to the original input dimension.
|
265 |
+
- unet_use_cross_frame_attention (bool): Whether to use cross frame attention in U-Net.
|
266 |
+
- unet_use_temporal_attention (bool): Whether to use temporal attention in U-Net.
|
267 |
+
- use_inflated_groupnorm (bool): Whether to use inflated group normalization.
|
268 |
+
- use_motion_module (bool): Whether to use motion module.
|
269 |
+
- motion_module_type (str): Type of motion module.
|
270 |
+
- motion_module_kwargs (dict): Keyword arguments for the motion module.
|
271 |
+
- use_audio_module (bool): Whether to use audio module.
|
272 |
+
- depth (int): Depth of the block in the network.
|
273 |
+
- stack_enable_blocks_name (str): Name of the stack enable blocks.
|
274 |
+
- stack_enable_blocks_depth (int): Depth of the stack enable blocks.
|
275 |
+
|
276 |
+
Forward method:
|
277 |
+
The forward method applies the residual blocks, cross attention, and optional motion and audio modules
|
278 |
+
to the input hidden states. It returns the transformed hidden states.
|
279 |
+
"""
|
280 |
+
def __init__(
|
281 |
+
self,
|
282 |
+
in_channels: int,
|
283 |
+
temb_channels: int,
|
284 |
+
dropout: float = 0.0,
|
285 |
+
num_layers: int = 1,
|
286 |
+
resnet_eps: float = 1e-6,
|
287 |
+
resnet_time_scale_shift: str = "default",
|
288 |
+
resnet_act_fn: str = "swish",
|
289 |
+
resnet_groups: int = 32,
|
290 |
+
resnet_pre_norm: bool = True,
|
291 |
+
attn_num_head_channels=1,
|
292 |
+
output_scale_factor=1.0,
|
293 |
+
cross_attention_dim=1280,
|
294 |
+
audio_attention_dim=1024,
|
295 |
+
dual_cross_attention=False,
|
296 |
+
use_linear_projection=False,
|
297 |
+
upcast_attention=False,
|
298 |
+
unet_use_cross_frame_attention=None,
|
299 |
+
unet_use_temporal_attention=None,
|
300 |
+
use_inflated_groupnorm=None,
|
301 |
+
use_motion_module=None,
|
302 |
+
motion_module_type=None,
|
303 |
+
motion_module_kwargs=None,
|
304 |
+
use_audio_module=None,
|
305 |
+
depth=0,
|
306 |
+
stack_enable_blocks_name=None,
|
307 |
+
stack_enable_blocks_depth=None,
|
308 |
+
):
|
309 |
+
super().__init__()
|
310 |
+
|
311 |
+
self.has_cross_attention = True
|
312 |
+
self.attn_num_head_channels = attn_num_head_channels
|
313 |
+
resnet_groups = (
|
314 |
+
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
315 |
+
)
|
316 |
+
|
317 |
+
# there is always at least one resnet
|
318 |
+
resnets = [
|
319 |
+
ResnetBlock3D(
|
320 |
+
in_channels=in_channels,
|
321 |
+
out_channels=in_channels,
|
322 |
+
temb_channels=temb_channels,
|
323 |
+
eps=resnet_eps,
|
324 |
+
groups=resnet_groups,
|
325 |
+
dropout=dropout,
|
326 |
+
time_embedding_norm=resnet_time_scale_shift,
|
327 |
+
non_linearity=resnet_act_fn,
|
328 |
+
output_scale_factor=output_scale_factor,
|
329 |
+
pre_norm=resnet_pre_norm,
|
330 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
331 |
+
)
|
332 |
+
]
|
333 |
+
attentions = []
|
334 |
+
motion_modules = []
|
335 |
+
audio_modules = []
|
336 |
+
|
337 |
+
for _ in range(num_layers):
|
338 |
+
if dual_cross_attention:
|
339 |
+
raise NotImplementedError
|
340 |
+
attentions.append(
|
341 |
+
Transformer3DModel(
|
342 |
+
attn_num_head_channels,
|
343 |
+
in_channels // attn_num_head_channels,
|
344 |
+
in_channels=in_channels,
|
345 |
+
num_layers=1,
|
346 |
+
cross_attention_dim=cross_attention_dim,
|
347 |
+
norm_num_groups=resnet_groups,
|
348 |
+
use_linear_projection=use_linear_projection,
|
349 |
+
upcast_attention=upcast_attention,
|
350 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
351 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
352 |
+
)
|
353 |
+
)
|
354 |
+
audio_modules.append(
|
355 |
+
Transformer3DModel(
|
356 |
+
attn_num_head_channels,
|
357 |
+
in_channels // attn_num_head_channels,
|
358 |
+
in_channels=in_channels,
|
359 |
+
num_layers=1,
|
360 |
+
cross_attention_dim=audio_attention_dim,
|
361 |
+
norm_num_groups=resnet_groups,
|
362 |
+
use_linear_projection=use_linear_projection,
|
363 |
+
upcast_attention=upcast_attention,
|
364 |
+
use_audio_module=use_audio_module,
|
365 |
+
depth=depth,
|
366 |
+
unet_block_name="mid",
|
367 |
+
stack_enable_blocks_name=stack_enable_blocks_name,
|
368 |
+
stack_enable_blocks_depth=stack_enable_blocks_depth,
|
369 |
+
)
|
370 |
+
if use_audio_module
|
371 |
+
else None
|
372 |
+
)
|
373 |
+
|
374 |
+
motion_modules.append(
|
375 |
+
get_motion_module(
|
376 |
+
in_channels=in_channels,
|
377 |
+
motion_module_type=motion_module_type,
|
378 |
+
motion_module_kwargs=motion_module_kwargs,
|
379 |
+
)
|
380 |
+
if use_motion_module
|
381 |
+
else None
|
382 |
+
)
|
383 |
+
resnets.append(
|
384 |
+
ResnetBlock3D(
|
385 |
+
in_channels=in_channels,
|
386 |
+
out_channels=in_channels,
|
387 |
+
temb_channels=temb_channels,
|
388 |
+
eps=resnet_eps,
|
389 |
+
groups=resnet_groups,
|
390 |
+
dropout=dropout,
|
391 |
+
time_embedding_norm=resnet_time_scale_shift,
|
392 |
+
non_linearity=resnet_act_fn,
|
393 |
+
output_scale_factor=output_scale_factor,
|
394 |
+
pre_norm=resnet_pre_norm,
|
395 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
396 |
+
)
|
397 |
+
)
|
398 |
+
|
399 |
+
self.attentions = nn.ModuleList(attentions)
|
400 |
+
self.resnets = nn.ModuleList(resnets)
|
401 |
+
self.audio_modules = nn.ModuleList(audio_modules)
|
402 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
403 |
+
|
404 |
+
def forward(
|
405 |
+
self,
|
406 |
+
hidden_states,
|
407 |
+
temb=None,
|
408 |
+
encoder_hidden_states=None,
|
409 |
+
attention_mask=None,
|
410 |
+
full_mask=None,
|
411 |
+
face_mask=None,
|
412 |
+
lip_mask=None,
|
413 |
+
audio_embedding=None,
|
414 |
+
motion_scale=None,
|
415 |
+
):
|
416 |
+
"""
|
417 |
+
Forward pass for the UNetMidBlock3DCrossAttn class.
|
418 |
+
|
419 |
+
Args:
|
420 |
+
self (UNetMidBlock3DCrossAttn): An instance of the UNetMidBlock3DCrossAttn class.
|
421 |
+
hidden_states (Tensor): The input hidden states tensor.
|
422 |
+
temb (Tensor, optional): The input temporal embedding tensor. Defaults to None.
|
423 |
+
encoder_hidden_states (Tensor, optional): The encoder hidden states tensor. Defaults to None.
|
424 |
+
attention_mask (Tensor, optional): The attention mask tensor. Defaults to None.
|
425 |
+
full_mask (Tensor, optional): The full mask tensor. Defaults to None.
|
426 |
+
face_mask (Tensor, optional): The face mask tensor. Defaults to None.
|
427 |
+
lip_mask (Tensor, optional): The lip mask tensor. Defaults to None.
|
428 |
+
audio_embedding (Tensor, optional): The audio embedding tensor. Defaults to None.
|
429 |
+
|
430 |
+
Returns:
|
431 |
+
Tensor: The output tensor after passing through the UNetMidBlock3DCrossAttn layers.
|
432 |
+
"""
|
433 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
434 |
+
for attn, resnet, audio_module, motion_module in zip(
|
435 |
+
self.attentions, self.resnets[1:], self.audio_modules, self.motion_modules
|
436 |
+
):
|
437 |
+
hidden_states, motion_frame = attn(
|
438 |
+
hidden_states,
|
439 |
+
encoder_hidden_states=encoder_hidden_states,
|
440 |
+
return_dict=False,
|
441 |
+
) # .sample
|
442 |
+
if len(motion_frame[0]) > 0:
|
443 |
+
# if motion_frame[0][0].numel() > 0:
|
444 |
+
motion_frames = motion_frame[0][0]
|
445 |
+
motion_frames = rearrange(
|
446 |
+
motion_frames,
|
447 |
+
"b f (d1 d2) c -> b c f d1 d2",
|
448 |
+
d1=hidden_states.size(-1),
|
449 |
+
)
|
450 |
+
|
451 |
+
else:
|
452 |
+
motion_frames = torch.zeros(
|
453 |
+
hidden_states.shape[0],
|
454 |
+
hidden_states.shape[1],
|
455 |
+
4,
|
456 |
+
hidden_states.shape[3],
|
457 |
+
hidden_states.shape[4],
|
458 |
+
)
|
459 |
+
|
460 |
+
n_motion_frames = motion_frames.size(2)
|
461 |
+
if audio_module is not None:
|
462 |
+
hidden_states = (
|
463 |
+
audio_module(
|
464 |
+
hidden_states,
|
465 |
+
encoder_hidden_states=audio_embedding,
|
466 |
+
attention_mask=attention_mask,
|
467 |
+
full_mask=full_mask,
|
468 |
+
face_mask=face_mask,
|
469 |
+
lip_mask=lip_mask,
|
470 |
+
motion_scale=motion_scale,
|
471 |
+
return_dict=False,
|
472 |
+
)
|
473 |
+
)[0] # .sample
|
474 |
+
if motion_module is not None:
|
475 |
+
motion_frames = motion_frames.to(
|
476 |
+
device=hidden_states.device, dtype=hidden_states.dtype
|
477 |
+
)
|
478 |
+
|
479 |
+
_hidden_states = (
|
480 |
+
torch.cat([motion_frames, hidden_states], dim=2)
|
481 |
+
if n_motion_frames > 0
|
482 |
+
else hidden_states
|
483 |
+
)
|
484 |
+
hidden_states = motion_module(
|
485 |
+
_hidden_states, encoder_hidden_states=encoder_hidden_states
|
486 |
+
)
|
487 |
+
hidden_states = hidden_states[:, :, n_motion_frames:]
|
488 |
+
|
489 |
+
hidden_states = resnet(hidden_states, temb)
|
490 |
+
|
491 |
+
return hidden_states
|
492 |
+
|
493 |
+
|
494 |
+
class CrossAttnDownBlock3D(nn.Module):
|
495 |
+
"""
|
496 |
+
A 3D downsampling block with cross attention for the U-Net architecture.
|
497 |
+
|
498 |
+
Parameters:
|
499 |
+
- (same as above, refer to the constructor for details)
|
500 |
+
|
501 |
+
Forward method:
|
502 |
+
The forward method downsamples the input hidden states using residual blocks and cross attention.
|
503 |
+
It also applies optional motion and audio modules. The method supports gradient checkpointing
|
504 |
+
to save memory during training.
|
505 |
+
"""
|
506 |
+
def __init__(
|
507 |
+
self,
|
508 |
+
in_channels: int,
|
509 |
+
out_channels: int,
|
510 |
+
temb_channels: int,
|
511 |
+
dropout: float = 0.0,
|
512 |
+
num_layers: int = 1,
|
513 |
+
resnet_eps: float = 1e-6,
|
514 |
+
resnet_time_scale_shift: str = "default",
|
515 |
+
resnet_act_fn: str = "swish",
|
516 |
+
resnet_groups: int = 32,
|
517 |
+
resnet_pre_norm: bool = True,
|
518 |
+
attn_num_head_channels=1,
|
519 |
+
cross_attention_dim=1280,
|
520 |
+
audio_attention_dim=1024,
|
521 |
+
output_scale_factor=1.0,
|
522 |
+
downsample_padding=1,
|
523 |
+
add_downsample=True,
|
524 |
+
dual_cross_attention=False,
|
525 |
+
use_linear_projection=False,
|
526 |
+
only_cross_attention=False,
|
527 |
+
upcast_attention=False,
|
528 |
+
unet_use_cross_frame_attention=None,
|
529 |
+
unet_use_temporal_attention=None,
|
530 |
+
use_inflated_groupnorm=None,
|
531 |
+
use_motion_module=None,
|
532 |
+
motion_module_type=None,
|
533 |
+
motion_module_kwargs=None,
|
534 |
+
use_audio_module=None,
|
535 |
+
depth=0,
|
536 |
+
stack_enable_blocks_name=None,
|
537 |
+
stack_enable_blocks_depth=None,
|
538 |
+
):
|
539 |
+
super().__init__()
|
540 |
+
resnets = []
|
541 |
+
attentions = []
|
542 |
+
audio_modules = []
|
543 |
+
motion_modules = []
|
544 |
+
|
545 |
+
self.has_cross_attention = True
|
546 |
+
self.attn_num_head_channels = attn_num_head_channels
|
547 |
+
|
548 |
+
for i in range(num_layers):
|
549 |
+
in_channels = in_channels if i == 0 else out_channels
|
550 |
+
resnets.append(
|
551 |
+
ResnetBlock3D(
|
552 |
+
in_channels=in_channels,
|
553 |
+
out_channels=out_channels,
|
554 |
+
temb_channels=temb_channels,
|
555 |
+
eps=resnet_eps,
|
556 |
+
groups=resnet_groups,
|
557 |
+
dropout=dropout,
|
558 |
+
time_embedding_norm=resnet_time_scale_shift,
|
559 |
+
non_linearity=resnet_act_fn,
|
560 |
+
output_scale_factor=output_scale_factor,
|
561 |
+
pre_norm=resnet_pre_norm,
|
562 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
563 |
+
)
|
564 |
+
)
|
565 |
+
if dual_cross_attention:
|
566 |
+
raise NotImplementedError
|
567 |
+
attentions.append(
|
568 |
+
Transformer3DModel(
|
569 |
+
attn_num_head_channels,
|
570 |
+
out_channels // attn_num_head_channels,
|
571 |
+
in_channels=out_channels,
|
572 |
+
num_layers=1,
|
573 |
+
cross_attention_dim=cross_attention_dim,
|
574 |
+
norm_num_groups=resnet_groups,
|
575 |
+
use_linear_projection=use_linear_projection,
|
576 |
+
only_cross_attention=only_cross_attention,
|
577 |
+
upcast_attention=upcast_attention,
|
578 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
579 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
580 |
+
)
|
581 |
+
)
|
582 |
+
# TODO:检查维度
|
583 |
+
audio_modules.append(
|
584 |
+
Transformer3DModel(
|
585 |
+
attn_num_head_channels,
|
586 |
+
in_channels // attn_num_head_channels,
|
587 |
+
in_channels=out_channels,
|
588 |
+
num_layers=1,
|
589 |
+
cross_attention_dim=audio_attention_dim,
|
590 |
+
norm_num_groups=resnet_groups,
|
591 |
+
use_linear_projection=use_linear_projection,
|
592 |
+
only_cross_attention=only_cross_attention,
|
593 |
+
upcast_attention=upcast_attention,
|
594 |
+
use_audio_module=use_audio_module,
|
595 |
+
depth=depth,
|
596 |
+
unet_block_name="down",
|
597 |
+
stack_enable_blocks_name=stack_enable_blocks_name,
|
598 |
+
stack_enable_blocks_depth=stack_enable_blocks_depth,
|
599 |
+
)
|
600 |
+
if use_audio_module
|
601 |
+
else None
|
602 |
+
)
|
603 |
+
motion_modules.append(
|
604 |
+
get_motion_module(
|
605 |
+
in_channels=out_channels,
|
606 |
+
motion_module_type=motion_module_type,
|
607 |
+
motion_module_kwargs=motion_module_kwargs,
|
608 |
+
)
|
609 |
+
if use_motion_module
|
610 |
+
else None
|
611 |
+
)
|
612 |
+
|
613 |
+
self.attentions = nn.ModuleList(attentions)
|
614 |
+
self.resnets = nn.ModuleList(resnets)
|
615 |
+
self.audio_modules = nn.ModuleList(audio_modules)
|
616 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
617 |
+
|
618 |
+
if add_downsample:
|
619 |
+
self.downsamplers = nn.ModuleList(
|
620 |
+
[
|
621 |
+
Downsample3D(
|
622 |
+
out_channels,
|
623 |
+
use_conv=True,
|
624 |
+
out_channels=out_channels,
|
625 |
+
padding=downsample_padding,
|
626 |
+
name="op",
|
627 |
+
)
|
628 |
+
]
|
629 |
+
)
|
630 |
+
else:
|
631 |
+
self.downsamplers = None
|
632 |
+
|
633 |
+
self.gradient_checkpointing = False
|
634 |
+
|
635 |
+
def forward(
|
636 |
+
self,
|
637 |
+
hidden_states,
|
638 |
+
temb=None,
|
639 |
+
encoder_hidden_states=None,
|
640 |
+
attention_mask=None,
|
641 |
+
full_mask=None,
|
642 |
+
face_mask=None,
|
643 |
+
lip_mask=None,
|
644 |
+
audio_embedding=None,
|
645 |
+
motion_scale=None,
|
646 |
+
):
|
647 |
+
"""
|
648 |
+
Defines the forward pass for the CrossAttnDownBlock3D class.
|
649 |
+
|
650 |
+
Parameters:
|
651 |
+
- hidden_states : torch.Tensor
|
652 |
+
The input tensor to the block.
|
653 |
+
temb : torch.Tensor, optional
|
654 |
+
The token embeddings from the previous block.
|
655 |
+
encoder_hidden_states : torch.Tensor, optional
|
656 |
+
The hidden states from the encoder.
|
657 |
+
attention_mask : torch.Tensor, optional
|
658 |
+
The attention mask for the cross-attention mechanism.
|
659 |
+
full_mask : torch.Tensor, optional
|
660 |
+
The full mask for the cross-attention mechanism.
|
661 |
+
face_mask : torch.Tensor, optional
|
662 |
+
The face mask for the cross-attention mechanism.
|
663 |
+
lip_mask : torch.Tensor, optional
|
664 |
+
The lip mask for the cross-attention mechanism.
|
665 |
+
audio_embedding : torch.Tensor, optional
|
666 |
+
The audio embedding for the cross-attention mechanism.
|
667 |
+
|
668 |
+
Returns:
|
669 |
+
-- torch.Tensor
|
670 |
+
The output tensor from the block.
|
671 |
+
"""
|
672 |
+
output_states = ()
|
673 |
+
|
674 |
+
for _, (resnet, attn, audio_module, motion_module) in enumerate(
|
675 |
+
zip(self.resnets, self.attentions, self.audio_modules, self.motion_modules)
|
676 |
+
):
|
677 |
+
# self.gradient_checkpointing = False
|
678 |
+
if self.training and self.gradient_checkpointing:
|
679 |
+
|
680 |
+
def create_custom_forward(module, return_dict=None):
|
681 |
+
def custom_forward(*inputs):
|
682 |
+
if return_dict is not None:
|
683 |
+
return module(*inputs, return_dict=return_dict)
|
684 |
+
|
685 |
+
return module(*inputs)
|
686 |
+
|
687 |
+
return custom_forward
|
688 |
+
|
689 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
690 |
+
create_custom_forward(resnet), hidden_states, temb
|
691 |
+
)
|
692 |
+
|
693 |
+
motion_frames = []
|
694 |
+
hidden_states, motion_frame = torch.utils.checkpoint.checkpoint(
|
695 |
+
create_custom_forward(attn, return_dict=False),
|
696 |
+
hidden_states,
|
697 |
+
encoder_hidden_states,
|
698 |
+
)
|
699 |
+
if len(motion_frame[0]) > 0:
|
700 |
+
motion_frames = motion_frame[0][0]
|
701 |
+
# motion_frames = torch.cat(motion_frames, dim=0)
|
702 |
+
motion_frames = rearrange(
|
703 |
+
motion_frames,
|
704 |
+
"b f (d1 d2) c -> b c f d1 d2",
|
705 |
+
d1=hidden_states.size(-1),
|
706 |
+
)
|
707 |
+
|
708 |
+
else:
|
709 |
+
motion_frames = torch.zeros(
|
710 |
+
hidden_states.shape[0],
|
711 |
+
hidden_states.shape[1],
|
712 |
+
4,
|
713 |
+
hidden_states.shape[3],
|
714 |
+
hidden_states.shape[4],
|
715 |
+
)
|
716 |
+
|
717 |
+
n_motion_frames = motion_frames.size(2)
|
718 |
+
|
719 |
+
if audio_module is not None:
|
720 |
+
# audio_embedding = audio_embedding
|
721 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
722 |
+
create_custom_forward(audio_module, return_dict=False),
|
723 |
+
hidden_states,
|
724 |
+
audio_embedding,
|
725 |
+
attention_mask,
|
726 |
+
full_mask,
|
727 |
+
face_mask,
|
728 |
+
lip_mask,
|
729 |
+
motion_scale,
|
730 |
+
)[0]
|
731 |
+
|
732 |
+
# add motion module
|
733 |
+
if motion_module is not None:
|
734 |
+
motion_frames = motion_frames.to(
|
735 |
+
device=hidden_states.device, dtype=hidden_states.dtype
|
736 |
+
)
|
737 |
+
_hidden_states = torch.cat(
|
738 |
+
[motion_frames, hidden_states], dim=2
|
739 |
+
) # if n_motion_frames > 0 else hidden_states
|
740 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
741 |
+
create_custom_forward(motion_module),
|
742 |
+
_hidden_states,
|
743 |
+
encoder_hidden_states,
|
744 |
+
)
|
745 |
+
hidden_states = hidden_states[:, :, n_motion_frames:]
|
746 |
+
|
747 |
+
else:
|
748 |
+
hidden_states = resnet(hidden_states, temb)
|
749 |
+
hidden_states = attn(
|
750 |
+
hidden_states,
|
751 |
+
encoder_hidden_states=encoder_hidden_states,
|
752 |
+
).sample
|
753 |
+
if audio_module is not None:
|
754 |
+
hidden_states = audio_module(
|
755 |
+
hidden_states,
|
756 |
+
audio_embedding,
|
757 |
+
attention_mask=attention_mask,
|
758 |
+
full_mask=full_mask,
|
759 |
+
face_mask=face_mask,
|
760 |
+
lip_mask=lip_mask,
|
761 |
+
return_dict=False,
|
762 |
+
)[0]
|
763 |
+
# add motion module
|
764 |
+
if motion_module is not None:
|
765 |
+
hidden_states = motion_module(
|
766 |
+
hidden_states, encoder_hidden_states=encoder_hidden_states
|
767 |
+
)
|
768 |
+
|
769 |
+
output_states += (hidden_states,)
|
770 |
+
|
771 |
+
if self.downsamplers is not None:
|
772 |
+
for downsampler in self.downsamplers:
|
773 |
+
hidden_states = downsampler(hidden_states)
|
774 |
+
|
775 |
+
output_states += (hidden_states,)
|
776 |
+
|
777 |
+
return hidden_states, output_states
|
778 |
+
|
779 |
+
|
780 |
+
class DownBlock3D(nn.Module):
|
781 |
+
"""
|
782 |
+
A 3D downsampling block for the U-Net architecture. This block performs downsampling operations
|
783 |
+
using residual blocks and an optional motion module.
|
784 |
+
|
785 |
+
Parameters:
|
786 |
+
- in_channels (int): Number of input channels.
|
787 |
+
- out_channels (int): Number of output channels.
|
788 |
+
- temb_channels (int): Number of token embedding channels.
|
789 |
+
- dropout (float): Dropout rate for the block.
|
790 |
+
- num_layers (int): Number of layers in the block.
|
791 |
+
- resnet_eps (float): Epsilon for residual block stability.
|
792 |
+
- resnet_time_scale_shift (str): Time scale shift for the residual block's time embedding.
|
793 |
+
- resnet_act_fn (str): Activation function used in the residual block.
|
794 |
+
- resnet_groups (int): Number of groups for the convolutions in the residual block.
|
795 |
+
- resnet_pre_norm (bool): Whether to use pre-normalization in the residual block.
|
796 |
+
- output_scale_factor (float): Scaling factor for the block's output.
|
797 |
+
- add_downsample (bool): Whether to add a downsampling layer.
|
798 |
+
- downsample_padding (int): Padding for the downsampling layer.
|
799 |
+
- use_inflated_groupnorm (bool): Whether to use inflated group normalization.
|
800 |
+
- use_motion_module (bool): Whether to include a motion module.
|
801 |
+
- motion_module_type (str): Type of motion module to use.
|
802 |
+
- motion_module_kwargs (dict): Keyword arguments for the motion module.
|
803 |
+
|
804 |
+
Forward method:
|
805 |
+
The forward method processes the input hidden states through the residual blocks and optional
|
806 |
+
motion modules, followed by an optional downsampling step. It supports gradient checkpointing
|
807 |
+
during training to reduce memory usage.
|
808 |
+
"""
|
809 |
+
def __init__(
|
810 |
+
self,
|
811 |
+
in_channels: int,
|
812 |
+
out_channels: int,
|
813 |
+
temb_channels: int,
|
814 |
+
dropout: float = 0.0,
|
815 |
+
num_layers: int = 1,
|
816 |
+
resnet_eps: float = 1e-6,
|
817 |
+
resnet_time_scale_shift: str = "default",
|
818 |
+
resnet_act_fn: str = "swish",
|
819 |
+
resnet_groups: int = 32,
|
820 |
+
resnet_pre_norm: bool = True,
|
821 |
+
output_scale_factor=1.0,
|
822 |
+
add_downsample=True,
|
823 |
+
downsample_padding=1,
|
824 |
+
use_inflated_groupnorm=None,
|
825 |
+
use_motion_module=None,
|
826 |
+
motion_module_type=None,
|
827 |
+
motion_module_kwargs=None,
|
828 |
+
):
|
829 |
+
super().__init__()
|
830 |
+
resnets = []
|
831 |
+
motion_modules = []
|
832 |
+
|
833 |
+
# use_motion_module = False
|
834 |
+
for i in range(num_layers):
|
835 |
+
in_channels = in_channels if i == 0 else out_channels
|
836 |
+
resnets.append(
|
837 |
+
ResnetBlock3D(
|
838 |
+
in_channels=in_channels,
|
839 |
+
out_channels=out_channels,
|
840 |
+
temb_channels=temb_channels,
|
841 |
+
eps=resnet_eps,
|
842 |
+
groups=resnet_groups,
|
843 |
+
dropout=dropout,
|
844 |
+
time_embedding_norm=resnet_time_scale_shift,
|
845 |
+
non_linearity=resnet_act_fn,
|
846 |
+
output_scale_factor=output_scale_factor,
|
847 |
+
pre_norm=resnet_pre_norm,
|
848 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
849 |
+
)
|
850 |
+
)
|
851 |
+
motion_modules.append(
|
852 |
+
get_motion_module(
|
853 |
+
in_channels=out_channels,
|
854 |
+
motion_module_type=motion_module_type,
|
855 |
+
motion_module_kwargs=motion_module_kwargs,
|
856 |
+
)
|
857 |
+
if use_motion_module
|
858 |
+
else None
|
859 |
+
)
|
860 |
+
|
861 |
+
self.resnets = nn.ModuleList(resnets)
|
862 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
863 |
+
|
864 |
+
if add_downsample:
|
865 |
+
self.downsamplers = nn.ModuleList(
|
866 |
+
[
|
867 |
+
Downsample3D(
|
868 |
+
out_channels,
|
869 |
+
use_conv=True,
|
870 |
+
out_channels=out_channels,
|
871 |
+
padding=downsample_padding,
|
872 |
+
name="op",
|
873 |
+
)
|
874 |
+
]
|
875 |
+
)
|
876 |
+
else:
|
877 |
+
self.downsamplers = None
|
878 |
+
|
879 |
+
self.gradient_checkpointing = False
|
880 |
+
|
881 |
+
def forward(
|
882 |
+
self,
|
883 |
+
hidden_states,
|
884 |
+
temb=None,
|
885 |
+
encoder_hidden_states=None,
|
886 |
+
):
|
887 |
+
"""
|
888 |
+
forward method for the DownBlock3D class.
|
889 |
+
|
890 |
+
Args:
|
891 |
+
hidden_states (Tensor): The input tensor to the DownBlock3D layer.
|
892 |
+
temb (Tensor, optional): The token embeddings, if using transformer.
|
893 |
+
encoder_hidden_states (Tensor, optional): The hidden states from the encoder.
|
894 |
+
|
895 |
+
Returns:
|
896 |
+
Tensor: The output tensor after passing through the DownBlock3D layer.
|
897 |
+
"""
|
898 |
+
output_states = ()
|
899 |
+
|
900 |
+
for resnet, motion_module in zip(self.resnets, self.motion_modules):
|
901 |
+
# print(f"DownBlock3D {self.gradient_checkpointing = }")
|
902 |
+
if self.training and self.gradient_checkpointing:
|
903 |
+
|
904 |
+
def create_custom_forward(module):
|
905 |
+
def custom_forward(*inputs):
|
906 |
+
return module(*inputs)
|
907 |
+
|
908 |
+
return custom_forward
|
909 |
+
|
910 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
911 |
+
create_custom_forward(resnet), hidden_states, temb
|
912 |
+
)
|
913 |
+
|
914 |
+
else:
|
915 |
+
hidden_states = resnet(hidden_states, temb)
|
916 |
+
|
917 |
+
# add motion module
|
918 |
+
hidden_states = (
|
919 |
+
motion_module(
|
920 |
+
hidden_states, encoder_hidden_states=encoder_hidden_states
|
921 |
+
)
|
922 |
+
if motion_module is not None
|
923 |
+
else hidden_states
|
924 |
+
)
|
925 |
+
|
926 |
+
output_states += (hidden_states,)
|
927 |
+
|
928 |
+
if self.downsamplers is not None:
|
929 |
+
for downsampler in self.downsamplers:
|
930 |
+
hidden_states = downsampler(hidden_states)
|
931 |
+
|
932 |
+
output_states += (hidden_states,)
|
933 |
+
|
934 |
+
return hidden_states, output_states
|
935 |
+
|
936 |
+
|
937 |
+
class CrossAttnUpBlock3D(nn.Module):
|
938 |
+
"""
|
939 |
+
Standard 3D downsampling block for the U-Net architecture. This block performs downsampling
|
940 |
+
operations in the U-Net using residual blocks and an optional motion module.
|
941 |
+
|
942 |
+
Parameters:
|
943 |
+
- in_channels (int): Number of input channels.
|
944 |
+
- out_channels (int): Number of output channels.
|
945 |
+
- temb_channels (int): Number of channels for the temporal embedding.
|
946 |
+
- dropout (float): Dropout rate for the block.
|
947 |
+
- num_layers (int): Number of layers in the block.
|
948 |
+
- resnet_eps (float): Epsilon for residual block stability.
|
949 |
+
- resnet_time_scale_shift (str): Time scale shift for the residual block's time embedding.
|
950 |
+
- resnet_act_fn (str): Activation function used in the residual block.
|
951 |
+
- resnet_groups (int): Number of groups for the convolutions in the residual block.
|
952 |
+
- resnet_pre_norm (bool): Whether to use pre-normalization in the residual block.
|
953 |
+
- output_scale_factor (float): Scaling factor for the block's output.
|
954 |
+
- add_downsample (bool): Whether to add a downsampling layer.
|
955 |
+
- downsample_padding (int): Padding for the downsampling layer.
|
956 |
+
- use_inflated_groupnorm (bool): Whether to use inflated group normalization.
|
957 |
+
- use_motion_module (bool): Whether to include a motion module.
|
958 |
+
- motion_module_type (str): Type of motion module to use.
|
959 |
+
- motion_module_kwargs (dict): Keyword arguments for the motion module.
|
960 |
+
|
961 |
+
Forward method:
|
962 |
+
The forward method processes the input hidden states through the residual blocks and optional
|
963 |
+
motion modules, followed by an optional downsampling step. It supports gradient checkpointing
|
964 |
+
during training to reduce memory usage.
|
965 |
+
"""
|
966 |
+
def __init__(
|
967 |
+
self,
|
968 |
+
in_channels: int,
|
969 |
+
out_channels: int,
|
970 |
+
prev_output_channel: int,
|
971 |
+
temb_channels: int,
|
972 |
+
dropout: float = 0.0,
|
973 |
+
num_layers: int = 1,
|
974 |
+
resnet_eps: float = 1e-6,
|
975 |
+
resnet_time_scale_shift: str = "default",
|
976 |
+
resnet_act_fn: str = "swish",
|
977 |
+
resnet_groups: int = 32,
|
978 |
+
resnet_pre_norm: bool = True,
|
979 |
+
attn_num_head_channels=1,
|
980 |
+
cross_attention_dim=1280,
|
981 |
+
audio_attention_dim=1024,
|
982 |
+
output_scale_factor=1.0,
|
983 |
+
add_upsample=True,
|
984 |
+
dual_cross_attention=False,
|
985 |
+
use_linear_projection=False,
|
986 |
+
only_cross_attention=False,
|
987 |
+
upcast_attention=False,
|
988 |
+
unet_use_cross_frame_attention=None,
|
989 |
+
unet_use_temporal_attention=None,
|
990 |
+
use_motion_module=None,
|
991 |
+
use_inflated_groupnorm=None,
|
992 |
+
motion_module_type=None,
|
993 |
+
motion_module_kwargs=None,
|
994 |
+
use_audio_module=None,
|
995 |
+
depth=0,
|
996 |
+
stack_enable_blocks_name=None,
|
997 |
+
stack_enable_blocks_depth=None,
|
998 |
+
):
|
999 |
+
super().__init__()
|
1000 |
+
resnets = []
|
1001 |
+
attentions = []
|
1002 |
+
audio_modules = []
|
1003 |
+
motion_modules = []
|
1004 |
+
|
1005 |
+
self.has_cross_attention = True
|
1006 |
+
self.attn_num_head_channels = attn_num_head_channels
|
1007 |
+
|
1008 |
+
for i in range(num_layers):
|
1009 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1010 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1011 |
+
|
1012 |
+
resnets.append(
|
1013 |
+
ResnetBlock3D(
|
1014 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1015 |
+
out_channels=out_channels,
|
1016 |
+
temb_channels=temb_channels,
|
1017 |
+
eps=resnet_eps,
|
1018 |
+
groups=resnet_groups,
|
1019 |
+
dropout=dropout,
|
1020 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1021 |
+
non_linearity=resnet_act_fn,
|
1022 |
+
output_scale_factor=output_scale_factor,
|
1023 |
+
pre_norm=resnet_pre_norm,
|
1024 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
1025 |
+
)
|
1026 |
+
)
|
1027 |
+
|
1028 |
+
if dual_cross_attention:
|
1029 |
+
raise NotImplementedError
|
1030 |
+
attentions.append(
|
1031 |
+
Transformer3DModel(
|
1032 |
+
attn_num_head_channels,
|
1033 |
+
out_channels // attn_num_head_channels,
|
1034 |
+
in_channels=out_channels,
|
1035 |
+
num_layers=1,
|
1036 |
+
cross_attention_dim=cross_attention_dim,
|
1037 |
+
norm_num_groups=resnet_groups,
|
1038 |
+
use_linear_projection=use_linear_projection,
|
1039 |
+
only_cross_attention=only_cross_attention,
|
1040 |
+
upcast_attention=upcast_attention,
|
1041 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
1042 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
1043 |
+
)
|
1044 |
+
)
|
1045 |
+
audio_modules.append(
|
1046 |
+
Transformer3DModel(
|
1047 |
+
attn_num_head_channels,
|
1048 |
+
in_channels // attn_num_head_channels,
|
1049 |
+
in_channels=out_channels,
|
1050 |
+
num_layers=1,
|
1051 |
+
cross_attention_dim=audio_attention_dim,
|
1052 |
+
norm_num_groups=resnet_groups,
|
1053 |
+
use_linear_projection=use_linear_projection,
|
1054 |
+
only_cross_attention=only_cross_attention,
|
1055 |
+
upcast_attention=upcast_attention,
|
1056 |
+
use_audio_module=use_audio_module,
|
1057 |
+
depth=depth,
|
1058 |
+
unet_block_name="up",
|
1059 |
+
stack_enable_blocks_name=stack_enable_blocks_name,
|
1060 |
+
stack_enable_blocks_depth=stack_enable_blocks_depth,
|
1061 |
+
)
|
1062 |
+
if use_audio_module
|
1063 |
+
else None
|
1064 |
+
)
|
1065 |
+
motion_modules.append(
|
1066 |
+
get_motion_module(
|
1067 |
+
in_channels=out_channels,
|
1068 |
+
motion_module_type=motion_module_type,
|
1069 |
+
motion_module_kwargs=motion_module_kwargs,
|
1070 |
+
)
|
1071 |
+
if use_motion_module
|
1072 |
+
else None
|
1073 |
+
)
|
1074 |
+
|
1075 |
+
self.attentions = nn.ModuleList(attentions)
|
1076 |
+
self.resnets = nn.ModuleList(resnets)
|
1077 |
+
self.audio_modules = nn.ModuleList(audio_modules)
|
1078 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
1079 |
+
|
1080 |
+
if add_upsample:
|
1081 |
+
self.upsamplers = nn.ModuleList(
|
1082 |
+
[Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
|
1083 |
+
)
|
1084 |
+
else:
|
1085 |
+
self.upsamplers = None
|
1086 |
+
|
1087 |
+
self.gradient_checkpointing = False
|
1088 |
+
|
1089 |
+
def forward(
|
1090 |
+
self,
|
1091 |
+
hidden_states,
|
1092 |
+
res_hidden_states_tuple,
|
1093 |
+
temb=None,
|
1094 |
+
encoder_hidden_states=None,
|
1095 |
+
upsample_size=None,
|
1096 |
+
attention_mask=None,
|
1097 |
+
full_mask=None,
|
1098 |
+
face_mask=None,
|
1099 |
+
lip_mask=None,
|
1100 |
+
audio_embedding=None,
|
1101 |
+
motion_scale=None,
|
1102 |
+
):
|
1103 |
+
"""
|
1104 |
+
Forward pass for the CrossAttnUpBlock3D class.
|
1105 |
+
|
1106 |
+
Args:
|
1107 |
+
self (CrossAttnUpBlock3D): An instance of the CrossAttnUpBlock3D class.
|
1108 |
+
hidden_states (Tensor): The input hidden states tensor.
|
1109 |
+
res_hidden_states_tuple (Tuple[Tensor]): A tuple of residual hidden states tensors.
|
1110 |
+
temb (Tensor, optional): The token embeddings tensor. Defaults to None.
|
1111 |
+
encoder_hidden_states (Tensor, optional): The encoder hidden states tensor. Defaults to None.
|
1112 |
+
upsample_size (int, optional): The upsample size. Defaults to None.
|
1113 |
+
attention_mask (Tensor, optional): The attention mask tensor. Defaults to None.
|
1114 |
+
full_mask (Tensor, optional): The full mask tensor. Defaults to None.
|
1115 |
+
face_mask (Tensor, optional): The face mask tensor. Defaults to None.
|
1116 |
+
lip_mask (Tensor, optional): The lip mask tensor. Defaults to None.
|
1117 |
+
audio_embedding (Tensor, optional): The audio embedding tensor. Defaults to None.
|
1118 |
+
|
1119 |
+
Returns:
|
1120 |
+
Tensor: The output tensor after passing through the CrossAttnUpBlock3D.
|
1121 |
+
"""
|
1122 |
+
for _, (resnet, attn, audio_module, motion_module) in enumerate(
|
1123 |
+
zip(self.resnets, self.attentions, self.audio_modules, self.motion_modules)
|
1124 |
+
):
|
1125 |
+
# pop res hidden states
|
1126 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1127 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1128 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1129 |
+
|
1130 |
+
if self.training and self.gradient_checkpointing:
|
1131 |
+
|
1132 |
+
def create_custom_forward(module, return_dict=None):
|
1133 |
+
def custom_forward(*inputs):
|
1134 |
+
if return_dict is not None:
|
1135 |
+
return module(*inputs, return_dict=return_dict)
|
1136 |
+
|
1137 |
+
return module(*inputs)
|
1138 |
+
|
1139 |
+
return custom_forward
|
1140 |
+
|
1141 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1142 |
+
create_custom_forward(resnet), hidden_states, temb
|
1143 |
+
)
|
1144 |
+
|
1145 |
+
motion_frames = []
|
1146 |
+
hidden_states, motion_frame = torch.utils.checkpoint.checkpoint(
|
1147 |
+
create_custom_forward(attn, return_dict=False),
|
1148 |
+
hidden_states,
|
1149 |
+
encoder_hidden_states,
|
1150 |
+
)
|
1151 |
+
if len(motion_frame[0]) > 0:
|
1152 |
+
motion_frames = motion_frame[0][0]
|
1153 |
+
# motion_frames = torch.cat(motion_frames, dim=0)
|
1154 |
+
motion_frames = rearrange(
|
1155 |
+
motion_frames,
|
1156 |
+
"b f (d1 d2) c -> b c f d1 d2",
|
1157 |
+
d1=hidden_states.size(-1),
|
1158 |
+
)
|
1159 |
+
else:
|
1160 |
+
motion_frames = torch.zeros(
|
1161 |
+
hidden_states.shape[0],
|
1162 |
+
hidden_states.shape[1],
|
1163 |
+
4,
|
1164 |
+
hidden_states.shape[3],
|
1165 |
+
hidden_states.shape[4],
|
1166 |
+
)
|
1167 |
+
|
1168 |
+
n_motion_frames = motion_frames.size(2)
|
1169 |
+
|
1170 |
+
if audio_module is not None:
|
1171 |
+
# audio_embedding = audio_embedding
|
1172 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1173 |
+
create_custom_forward(audio_module, return_dict=False),
|
1174 |
+
hidden_states,
|
1175 |
+
audio_embedding,
|
1176 |
+
attention_mask,
|
1177 |
+
full_mask,
|
1178 |
+
face_mask,
|
1179 |
+
lip_mask,
|
1180 |
+
motion_scale,
|
1181 |
+
)[0]
|
1182 |
+
|
1183 |
+
# add motion module
|
1184 |
+
if motion_module is not None:
|
1185 |
+
motion_frames = motion_frames.to(
|
1186 |
+
device=hidden_states.device, dtype=hidden_states.dtype
|
1187 |
+
)
|
1188 |
+
|
1189 |
+
_hidden_states = (
|
1190 |
+
torch.cat([motion_frames, hidden_states], dim=2)
|
1191 |
+
if n_motion_frames > 0
|
1192 |
+
else hidden_states
|
1193 |
+
)
|
1194 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1195 |
+
create_custom_forward(motion_module),
|
1196 |
+
_hidden_states,
|
1197 |
+
encoder_hidden_states,
|
1198 |
+
)
|
1199 |
+
hidden_states = hidden_states[:, :, n_motion_frames:]
|
1200 |
+
else:
|
1201 |
+
hidden_states = resnet(hidden_states, temb)
|
1202 |
+
hidden_states = attn(
|
1203 |
+
hidden_states,
|
1204 |
+
encoder_hidden_states=encoder_hidden_states,
|
1205 |
+
).sample
|
1206 |
+
|
1207 |
+
if audio_module is not None:
|
1208 |
+
|
1209 |
+
hidden_states = (
|
1210 |
+
audio_module(
|
1211 |
+
hidden_states,
|
1212 |
+
encoder_hidden_states=audio_embedding,
|
1213 |
+
attention_mask=attention_mask,
|
1214 |
+
full_mask=full_mask,
|
1215 |
+
face_mask=face_mask,
|
1216 |
+
lip_mask=lip_mask,
|
1217 |
+
)
|
1218 |
+
).sample
|
1219 |
+
# add motion module
|
1220 |
+
hidden_states = (
|
1221 |
+
motion_module(
|
1222 |
+
hidden_states, encoder_hidden_states=encoder_hidden_states
|
1223 |
+
)
|
1224 |
+
if motion_module is not None
|
1225 |
+
else hidden_states
|
1226 |
+
)
|
1227 |
+
|
1228 |
+
if self.upsamplers is not None:
|
1229 |
+
for upsampler in self.upsamplers:
|
1230 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
1231 |
+
|
1232 |
+
return hidden_states
|
1233 |
+
|
1234 |
+
|
1235 |
+
class UpBlock3D(nn.Module):
|
1236 |
+
"""
|
1237 |
+
3D upsampling block with cross attention for the U-Net architecture. This block performs
|
1238 |
+
upsampling operations and incorporates cross attention mechanisms, which allow the model to
|
1239 |
+
focus on different parts of the input when upscaling.
|
1240 |
+
|
1241 |
+
Parameters:
|
1242 |
+
- in_channels (int): Number of input channels.
|
1243 |
+
- out_channels (int): Number of output channels.
|
1244 |
+
- prev_output_channel (int): Number of channels from the previous layer's output.
|
1245 |
+
- temb_channels (int): Number of channels for the temporal embedding.
|
1246 |
+
- dropout (float): Dropout rate for the block.
|
1247 |
+
- num_layers (int): Number of layers in the block.
|
1248 |
+
- resnet_eps (float): Epsilon for residual block stability.
|
1249 |
+
- resnet_time_scale_shift (str): Time scale shift for the residual block's time embedding.
|
1250 |
+
- resnet_act_fn (str): Activation function used in the residual block.
|
1251 |
+
- resnet_groups (int): Number of groups for the convolutions in the residual block.
|
1252 |
+
- resnet_pre_norm (bool): Whether to use pre-normalization in the residual block.
|
1253 |
+
- attn_num_head_channels (int): Number of attention heads for the cross attention mechanism.
|
1254 |
+
- cross_attention_dim (int): Dimensionality of the cross attention layers.
|
1255 |
+
- audio_attention_dim (int): Dimensionality of the audio attention layers.
|
1256 |
+
- output_scale_factor (float): Scaling factor for the block's output.
|
1257 |
+
- add_upsample (bool): Whether to add an upsampling layer.
|
1258 |
+
- dual_cross_attention (bool): Whether to use dual cross attention (not implemented).
|
1259 |
+
- use_linear_projection (bool): Whether to use linear projection in the cross attention.
|
1260 |
+
- only_cross_attention (bool): Whether to use only cross attention (no self-attention).
|
1261 |
+
- upcast_attention (bool): Whether to upcast attention to the original input dimension.
|
1262 |
+
- unet_use_cross_frame_attention (bool): Whether to use cross frame attention in U-Net.
|
1263 |
+
- unet_use_temporal_attention (bool): Whether to use temporal attention in U-Net.
|
1264 |
+
- use_motion_module (bool): Whether to include a motion module.
|
1265 |
+
- use_inflated_groupnorm (bool): Whether to use inflated group normalization.
|
1266 |
+
- motion_module_type (str): Type of motion module to use.
|
1267 |
+
- motion_module_kwargs (dict): Keyword arguments for the motion module.
|
1268 |
+
- use_audio_module (bool): Whether to include an audio module.
|
1269 |
+
- depth (int): Depth of the block in the network.
|
1270 |
+
- stack_enable_blocks_name (str): Name of the stack enable blocks.
|
1271 |
+
- stack_enable_blocks_depth (int): Depth of the stack enable blocks.
|
1272 |
+
|
1273 |
+
Forward method:
|
1274 |
+
The forward method upsamples the input hidden states and residual hidden states, processes
|
1275 |
+
them through the residual and cross attention blocks, and optional motion and audio modules.
|
1276 |
+
It supports gradient checkpointing during training.
|
1277 |
+
"""
|
1278 |
+
def __init__(
|
1279 |
+
self,
|
1280 |
+
in_channels: int,
|
1281 |
+
prev_output_channel: int,
|
1282 |
+
out_channels: int,
|
1283 |
+
temb_channels: int,
|
1284 |
+
dropout: float = 0.0,
|
1285 |
+
num_layers: int = 1,
|
1286 |
+
resnet_eps: float = 1e-6,
|
1287 |
+
resnet_time_scale_shift: str = "default",
|
1288 |
+
resnet_act_fn: str = "swish",
|
1289 |
+
resnet_groups: int = 32,
|
1290 |
+
resnet_pre_norm: bool = True,
|
1291 |
+
output_scale_factor=1.0,
|
1292 |
+
add_upsample=True,
|
1293 |
+
use_inflated_groupnorm=None,
|
1294 |
+
use_motion_module=None,
|
1295 |
+
motion_module_type=None,
|
1296 |
+
motion_module_kwargs=None,
|
1297 |
+
):
|
1298 |
+
super().__init__()
|
1299 |
+
resnets = []
|
1300 |
+
motion_modules = []
|
1301 |
+
|
1302 |
+
# use_motion_module = False
|
1303 |
+
for i in range(num_layers):
|
1304 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1305 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1306 |
+
|
1307 |
+
resnets.append(
|
1308 |
+
ResnetBlock3D(
|
1309 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1310 |
+
out_channels=out_channels,
|
1311 |
+
temb_channels=temb_channels,
|
1312 |
+
eps=resnet_eps,
|
1313 |
+
groups=resnet_groups,
|
1314 |
+
dropout=dropout,
|
1315 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1316 |
+
non_linearity=resnet_act_fn,
|
1317 |
+
output_scale_factor=output_scale_factor,
|
1318 |
+
pre_norm=resnet_pre_norm,
|
1319 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
1320 |
+
)
|
1321 |
+
)
|
1322 |
+
motion_modules.append(
|
1323 |
+
get_motion_module(
|
1324 |
+
in_channels=out_channels,
|
1325 |
+
motion_module_type=motion_module_type,
|
1326 |
+
motion_module_kwargs=motion_module_kwargs,
|
1327 |
+
)
|
1328 |
+
if use_motion_module
|
1329 |
+
else None
|
1330 |
+
)
|
1331 |
+
|
1332 |
+
self.resnets = nn.ModuleList(resnets)
|
1333 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
1334 |
+
|
1335 |
+
if add_upsample:
|
1336 |
+
self.upsamplers = nn.ModuleList(
|
1337 |
+
[Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
|
1338 |
+
)
|
1339 |
+
else:
|
1340 |
+
self.upsamplers = None
|
1341 |
+
|
1342 |
+
self.gradient_checkpointing = False
|
1343 |
+
|
1344 |
+
def forward(
|
1345 |
+
self,
|
1346 |
+
hidden_states,
|
1347 |
+
res_hidden_states_tuple,
|
1348 |
+
temb=None,
|
1349 |
+
upsample_size=None,
|
1350 |
+
encoder_hidden_states=None,
|
1351 |
+
):
|
1352 |
+
"""
|
1353 |
+
Forward pass for the UpBlock3D class.
|
1354 |
+
|
1355 |
+
Args:
|
1356 |
+
self (UpBlock3D): An instance of the UpBlock3D class.
|
1357 |
+
hidden_states (Tensor): The input hidden states tensor.
|
1358 |
+
res_hidden_states_tuple (Tuple[Tensor]): A tuple of residual hidden states tensors.
|
1359 |
+
temb (Tensor, optional): The token embeddings tensor. Defaults to None.
|
1360 |
+
upsample_size (int, optional): The upsample size. Defaults to None.
|
1361 |
+
encoder_hidden_states (Tensor, optional): The encoder hidden states tensor. Defaults to None.
|
1362 |
+
|
1363 |
+
Returns:
|
1364 |
+
Tensor: The output tensor after passing through the UpBlock3D layers.
|
1365 |
+
"""
|
1366 |
+
for resnet, motion_module in zip(self.resnets, self.motion_modules):
|
1367 |
+
# pop res hidden states
|
1368 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1369 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1370 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1371 |
+
|
1372 |
+
# print(f"UpBlock3D {self.gradient_checkpointing = }")
|
1373 |
+
if self.training and self.gradient_checkpointing:
|
1374 |
+
|
1375 |
+
def create_custom_forward(module):
|
1376 |
+
def custom_forward(*inputs):
|
1377 |
+
return module(*inputs)
|
1378 |
+
|
1379 |
+
return custom_forward
|
1380 |
+
|
1381 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1382 |
+
create_custom_forward(resnet), hidden_states, temb
|
1383 |
+
)
|
1384 |
+
else:
|
1385 |
+
hidden_states = resnet(hidden_states, temb)
|
1386 |
+
hidden_states = (
|
1387 |
+
motion_module(
|
1388 |
+
hidden_states, encoder_hidden_states=encoder_hidden_states
|
1389 |
+
)
|
1390 |
+
if motion_module is not None
|
1391 |
+
else hidden_states
|
1392 |
+
)
|
1393 |
+
|
1394 |
+
if self.upsamplers is not None:
|
1395 |
+
for upsampler in self.upsamplers:
|
1396 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
1397 |
+
|
1398 |
+
return hidden_states
|
joyhallo/models/wav2vec.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This module defines the Wav2Vec model, which is a pre-trained model for speech recognition and understanding.
|
3 |
+
It inherits from the Wav2Vec2Model class in the transformers library and provides additional functionalities
|
4 |
+
such as feature extraction and encoding.
|
5 |
+
|
6 |
+
Classes:
|
7 |
+
Wav2VecModel: Inherits from Wav2Vec2Model and adds additional methods for feature extraction and encoding.
|
8 |
+
|
9 |
+
Functions:
|
10 |
+
linear_interpolation: Interpolates the features based on the sequence length.
|
11 |
+
"""
|
12 |
+
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from transformers import Wav2Vec2Model
|
15 |
+
from transformers.modeling_outputs import BaseModelOutput
|
16 |
+
|
17 |
+
|
18 |
+
class Wav2VecModel(Wav2Vec2Model):
|
19 |
+
"""
|
20 |
+
Wav2VecModel is a custom model class that extends the Wav2Vec2Model class from the transformers library.
|
21 |
+
It inherits all the functionality of the Wav2Vec2Model and adds additional methods for feature extraction and encoding.
|
22 |
+
...
|
23 |
+
|
24 |
+
Attributes:
|
25 |
+
base_model (Wav2Vec2Model): The base Wav2Vec2Model object.
|
26 |
+
|
27 |
+
Methods:
|
28 |
+
forward(input_values, seq_len, attention_mask=None, mask_time_indices=None
|
29 |
+
, output_attentions=None, output_hidden_states=None, return_dict=None):
|
30 |
+
Forward pass of the Wav2VecModel.
|
31 |
+
It takes input_values, seq_len, and other optional parameters as input and returns the output of the base model.
|
32 |
+
|
33 |
+
feature_extract(input_values, seq_len):
|
34 |
+
Extracts features from the input_values using the base model.
|
35 |
+
|
36 |
+
encode(extract_features, attention_mask=None, mask_time_indices=None, output_attentions=None, output_hidden_states=None, return_dict=None):
|
37 |
+
Encodes the extracted features using the base model and returns the encoded features.
|
38 |
+
"""
|
39 |
+
def forward(
|
40 |
+
self,
|
41 |
+
input_values,
|
42 |
+
seq_len,
|
43 |
+
attention_mask=None,
|
44 |
+
mask_time_indices=None,
|
45 |
+
output_attentions=None,
|
46 |
+
output_hidden_states=None,
|
47 |
+
return_dict=None,
|
48 |
+
):
|
49 |
+
"""
|
50 |
+
Forward pass of the Wav2Vec model.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
self: The instance of the model.
|
54 |
+
input_values: The input values (waveform) to the model.
|
55 |
+
seq_len: The sequence length of the input values.
|
56 |
+
attention_mask: Attention mask to be used for the model.
|
57 |
+
mask_time_indices: Mask indices to be used for the model.
|
58 |
+
output_attentions: If set to True, returns attentions.
|
59 |
+
output_hidden_states: If set to True, returns hidden states.
|
60 |
+
return_dict: If set to True, returns a BaseModelOutput instead of a tuple.
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
The output of the Wav2Vec model.
|
64 |
+
"""
|
65 |
+
self.config.output_attentions = True
|
66 |
+
|
67 |
+
output_hidden_states = (
|
68 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
69 |
+
)
|
70 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
71 |
+
|
72 |
+
extract_features = self.feature_extractor(input_values)
|
73 |
+
extract_features = extract_features.transpose(1, 2)
|
74 |
+
extract_features = linear_interpolation(extract_features, seq_len=seq_len)
|
75 |
+
|
76 |
+
if attention_mask is not None:
|
77 |
+
# compute reduced attention_mask corresponding to feature vectors
|
78 |
+
attention_mask = self._get_feature_vector_attention_mask(
|
79 |
+
extract_features.shape[1], attention_mask, add_adapter=False
|
80 |
+
)
|
81 |
+
|
82 |
+
hidden_states, extract_features = self.feature_projection(extract_features)
|
83 |
+
hidden_states = self._mask_hidden_states(
|
84 |
+
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
|
85 |
+
)
|
86 |
+
|
87 |
+
encoder_outputs = self.encoder(
|
88 |
+
hidden_states,
|
89 |
+
attention_mask=attention_mask,
|
90 |
+
output_attentions=output_attentions,
|
91 |
+
output_hidden_states=output_hidden_states,
|
92 |
+
return_dict=return_dict,
|
93 |
+
)
|
94 |
+
|
95 |
+
hidden_states = encoder_outputs[0]
|
96 |
+
|
97 |
+
if self.adapter is not None:
|
98 |
+
hidden_states = self.adapter(hidden_states)
|
99 |
+
|
100 |
+
if not return_dict:
|
101 |
+
return (hidden_states, ) + encoder_outputs[1:]
|
102 |
+
return BaseModelOutput(
|
103 |
+
last_hidden_state=hidden_states,
|
104 |
+
hidden_states=encoder_outputs.hidden_states,
|
105 |
+
attentions=encoder_outputs.attentions,
|
106 |
+
)
|
107 |
+
|
108 |
+
|
109 |
+
def feature_extract(
|
110 |
+
self,
|
111 |
+
input_values,
|
112 |
+
seq_len,
|
113 |
+
):
|
114 |
+
"""
|
115 |
+
Extracts features from the input values and returns the extracted features.
|
116 |
+
|
117 |
+
Parameters:
|
118 |
+
input_values (torch.Tensor): The input values to be processed.
|
119 |
+
seq_len (torch.Tensor): The sequence lengths of the input values.
|
120 |
+
|
121 |
+
Returns:
|
122 |
+
extracted_features (torch.Tensor): The extracted features from the input values.
|
123 |
+
"""
|
124 |
+
extract_features = self.feature_extractor(input_values)
|
125 |
+
extract_features = extract_features.transpose(1, 2)
|
126 |
+
extract_features = linear_interpolation(extract_features, seq_len=seq_len)
|
127 |
+
|
128 |
+
return extract_features
|
129 |
+
|
130 |
+
def encode(
|
131 |
+
self,
|
132 |
+
extract_features,
|
133 |
+
attention_mask=None,
|
134 |
+
mask_time_indices=None,
|
135 |
+
output_attentions=None,
|
136 |
+
output_hidden_states=None,
|
137 |
+
return_dict=None,
|
138 |
+
):
|
139 |
+
"""
|
140 |
+
Encodes the input features into the output space.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
extract_features (torch.Tensor): The extracted features from the audio signal.
|
144 |
+
attention_mask (torch.Tensor, optional): Attention mask to be used for padding.
|
145 |
+
mask_time_indices (torch.Tensor, optional): Masked indices for the time dimension.
|
146 |
+
output_attentions (bool, optional): If set to True, returns the attention weights.
|
147 |
+
output_hidden_states (bool, optional): If set to True, returns all hidden states.
|
148 |
+
return_dict (bool, optional): If set to True, returns a BaseModelOutput instead of the tuple.
|
149 |
+
|
150 |
+
Returns:
|
151 |
+
The encoded output features.
|
152 |
+
"""
|
153 |
+
self.config.output_attentions = True
|
154 |
+
|
155 |
+
output_hidden_states = (
|
156 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
157 |
+
)
|
158 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
159 |
+
|
160 |
+
if attention_mask is not None:
|
161 |
+
# compute reduced attention_mask corresponding to feature vectors
|
162 |
+
attention_mask = self._get_feature_vector_attention_mask(
|
163 |
+
extract_features.shape[1], attention_mask, add_adapter=False
|
164 |
+
)
|
165 |
+
|
166 |
+
hidden_states, extract_features = self.feature_projection(extract_features)
|
167 |
+
hidden_states = self._mask_hidden_states(
|
168 |
+
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
|
169 |
+
)
|
170 |
+
|
171 |
+
encoder_outputs = self.encoder(
|
172 |
+
hidden_states,
|
173 |
+
attention_mask=attention_mask,
|
174 |
+
output_attentions=output_attentions,
|
175 |
+
output_hidden_states=output_hidden_states,
|
176 |
+
return_dict=return_dict,
|
177 |
+
)
|
178 |
+
|
179 |
+
hidden_states = encoder_outputs[0]
|
180 |
+
|
181 |
+
if self.adapter is not None:
|
182 |
+
hidden_states = self.adapter(hidden_states)
|
183 |
+
|
184 |
+
if not return_dict:
|
185 |
+
return (hidden_states, ) + encoder_outputs[1:]
|
186 |
+
return BaseModelOutput(
|
187 |
+
last_hidden_state=hidden_states,
|
188 |
+
hidden_states=encoder_outputs.hidden_states,
|
189 |
+
attentions=encoder_outputs.attentions,
|
190 |
+
)
|
191 |
+
|
192 |
+
|
193 |
+
def linear_interpolation(features, seq_len):
|
194 |
+
"""
|
195 |
+
Transpose the features to interpolate linearly.
|
196 |
+
|
197 |
+
Args:
|
198 |
+
features (torch.Tensor): The extracted features to be interpolated.
|
199 |
+
seq_len (torch.Tensor): The sequence lengths of the features.
|
200 |
+
|
201 |
+
Returns:
|
202 |
+
torch.Tensor: The interpolated features.
|
203 |
+
"""
|
204 |
+
features = features.transpose(1, 2)
|
205 |
+
output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
|
206 |
+
return output_features.transpose(1, 2)
|
joyhallo/utils/__init__.py
ADDED
File without changes
|
joyhallo/utils/config.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This module provides utility functions for configuration manipulation.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from typing import Dict
|
6 |
+
|
7 |
+
|
8 |
+
def filter_non_none(dict_obj: Dict):
|
9 |
+
"""
|
10 |
+
Filters out key-value pairs from the given dictionary where the value is None.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
dict_obj (Dict): The dictionary to be filtered.
|
14 |
+
|
15 |
+
Returns:
|
16 |
+
Dict: The dictionary with key-value pairs removed where the value was None.
|
17 |
+
|
18 |
+
This function creates a new dictionary containing only the key-value pairs from
|
19 |
+
the original dictionary where the value is not None. It then clears the original
|
20 |
+
dictionary and updates it with the filtered key-value pairs.
|
21 |
+
"""
|
22 |
+
non_none_filter = { k: v for k, v in dict_obj.items() if v is not None }
|
23 |
+
dict_obj.clear()
|
24 |
+
dict_obj.update(non_none_filter)
|
25 |
+
return dict_obj
|
joyhallo/utils/util.py
ADDED
@@ -0,0 +1,976 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
utils.py
|
3 |
+
|
4 |
+
This module provides utility functions for various tasks such as setting random seeds,
|
5 |
+
importing modules from files, managing checkpoint files, and saving video files from
|
6 |
+
sequences of PIL images.
|
7 |
+
|
8 |
+
Functions:
|
9 |
+
seed_everything(seed)
|
10 |
+
import_filename(filename)
|
11 |
+
delete_additional_ckpt(base_path, num_keep)
|
12 |
+
save_videos_from_pil(pil_images, path, fps=8)
|
13 |
+
|
14 |
+
Dependencies:
|
15 |
+
importlib
|
16 |
+
os
|
17 |
+
os.path as osp
|
18 |
+
random
|
19 |
+
shutil
|
20 |
+
sys
|
21 |
+
pathlib.Path
|
22 |
+
av
|
23 |
+
cv2
|
24 |
+
mediapipe as mp
|
25 |
+
numpy as np
|
26 |
+
torch
|
27 |
+
torchvision
|
28 |
+
einops.rearrange
|
29 |
+
moviepy.editor.AudioFileClip, VideoClip
|
30 |
+
PIL.Image
|
31 |
+
|
32 |
+
Examples:
|
33 |
+
seed_everything(42)
|
34 |
+
imported_module = import_filename('path/to/your/module.py')
|
35 |
+
delete_additional_ckpt('path/to/checkpoints', 1)
|
36 |
+
save_videos_from_pil(pil_images, 'output/video.mp4', fps=12)
|
37 |
+
|
38 |
+
The functions in this module ensure reproducibility of experiments by seeding random number
|
39 |
+
generators, allow dynamic importing of modules, manage checkpoint files by deleting extra ones,
|
40 |
+
and provide a way to save sequences of images as video files.
|
41 |
+
|
42 |
+
Function Details:
|
43 |
+
seed_everything(seed)
|
44 |
+
Seeds all random number generators to ensure reproducibility.
|
45 |
+
|
46 |
+
import_filename(filename)
|
47 |
+
Imports a module from a given file location.
|
48 |
+
|
49 |
+
delete_additional_ckpt(base_path, num_keep)
|
50 |
+
Deletes additional checkpoint files in the given directory.
|
51 |
+
|
52 |
+
save_videos_from_pil(pil_images, path, fps=8)
|
53 |
+
Saves a sequence of images as a video using the Pillow library.
|
54 |
+
|
55 |
+
Attributes:
|
56 |
+
_ (str): Placeholder for static type checking
|
57 |
+
"""
|
58 |
+
|
59 |
+
import importlib
|
60 |
+
import os
|
61 |
+
import os.path as osp
|
62 |
+
import random
|
63 |
+
import shutil
|
64 |
+
import subprocess
|
65 |
+
import sys
|
66 |
+
from pathlib import Path
|
67 |
+
from typing import List
|
68 |
+
|
69 |
+
import av
|
70 |
+
import cv2
|
71 |
+
import mediapipe as mp
|
72 |
+
import numpy as np
|
73 |
+
import torch
|
74 |
+
import torchvision
|
75 |
+
from einops import rearrange
|
76 |
+
from moviepy.editor import AudioFileClip, VideoClip
|
77 |
+
from PIL import Image
|
78 |
+
|
79 |
+
|
80 |
+
def seed_everything(seed):
|
81 |
+
"""
|
82 |
+
Seeds all random number generators to ensure reproducibility.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
seed (int): The seed value to set for all random number generators.
|
86 |
+
"""
|
87 |
+
torch.manual_seed(seed)
|
88 |
+
torch.cuda.manual_seed_all(seed)
|
89 |
+
np.random.seed(seed % (2**32))
|
90 |
+
random.seed(seed)
|
91 |
+
|
92 |
+
|
93 |
+
def import_filename(filename):
|
94 |
+
"""
|
95 |
+
Import a module from a given file location.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
filename (str): The path to the file containing the module to be imported.
|
99 |
+
|
100 |
+
Returns:
|
101 |
+
module: The imported module.
|
102 |
+
|
103 |
+
Raises:
|
104 |
+
ImportError: If the module cannot be imported.
|
105 |
+
|
106 |
+
Example:
|
107 |
+
>>> imported_module = import_filename('path/to/your/module.py')
|
108 |
+
"""
|
109 |
+
spec = importlib.util.spec_from_file_location("mymodule", filename)
|
110 |
+
module = importlib.util.module_from_spec(spec)
|
111 |
+
sys.modules[spec.name] = module
|
112 |
+
spec.loader.exec_module(module)
|
113 |
+
return module
|
114 |
+
|
115 |
+
|
116 |
+
def delete_additional_ckpt(base_path, num_keep):
|
117 |
+
"""
|
118 |
+
Deletes additional checkpoint files in the given directory.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
base_path (str): The path to the directory containing the checkpoint files.
|
122 |
+
num_keep (int): The number of most recent checkpoint files to keep.
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
None
|
126 |
+
|
127 |
+
Raises:
|
128 |
+
FileNotFoundError: If the base_path does not exist.
|
129 |
+
|
130 |
+
Example:
|
131 |
+
>>> delete_additional_ckpt('path/to/checkpoints', 1)
|
132 |
+
# This will delete all but the most recent checkpoint file in 'path/to/checkpoints'.
|
133 |
+
"""
|
134 |
+
dirs = []
|
135 |
+
for d in os.listdir(base_path):
|
136 |
+
if d.startswith("checkpoint-"):
|
137 |
+
dirs.append(d)
|
138 |
+
num_tot = len(dirs)
|
139 |
+
if num_tot <= num_keep:
|
140 |
+
return
|
141 |
+
# ensure ckpt is sorted and delete the ealier!
|
142 |
+
del_dirs = sorted(dirs, key=lambda x: int(
|
143 |
+
x.split("-")[-1]))[: num_tot - num_keep]
|
144 |
+
for d in del_dirs:
|
145 |
+
path_to_dir = osp.join(base_path, d)
|
146 |
+
if osp.exists(path_to_dir):
|
147 |
+
shutil.rmtree(path_to_dir)
|
148 |
+
|
149 |
+
|
150 |
+
def save_videos_from_pil(pil_images, path, fps=8):
|
151 |
+
"""
|
152 |
+
Save a sequence of images as a video using the Pillow library.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
pil_images (List[PIL.Image]): A list of PIL.Image objects representing the frames of the video.
|
156 |
+
path (str): The output file path for the video.
|
157 |
+
fps (int, optional): The frames per second rate of the video. Defaults to 8.
|
158 |
+
|
159 |
+
Returns:
|
160 |
+
None
|
161 |
+
|
162 |
+
Raises:
|
163 |
+
ValueError: If the save format is not supported.
|
164 |
+
|
165 |
+
This function takes a list of PIL.Image objects and saves them as a video file with a specified frame rate.
|
166 |
+
The output file format is determined by the file extension of the provided path. Supported formats include
|
167 |
+
.mp4, .avi, and .mkv. The function uses the Pillow library to handle the image processing and video
|
168 |
+
creation.
|
169 |
+
"""
|
170 |
+
save_fmt = Path(path).suffix
|
171 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
172 |
+
width, height = pil_images[0].size
|
173 |
+
|
174 |
+
if save_fmt == ".mp4":
|
175 |
+
codec = "libx264"
|
176 |
+
container = av.open(path, "w")
|
177 |
+
stream = container.add_stream(codec, rate=fps)
|
178 |
+
|
179 |
+
stream.width = width
|
180 |
+
stream.height = height
|
181 |
+
|
182 |
+
for pil_image in pil_images:
|
183 |
+
# pil_image = Image.fromarray(image_arr).convert("RGB")
|
184 |
+
av_frame = av.VideoFrame.from_image(pil_image)
|
185 |
+
container.mux(stream.encode(av_frame))
|
186 |
+
container.mux(stream.encode())
|
187 |
+
container.close()
|
188 |
+
|
189 |
+
elif save_fmt == ".gif":
|
190 |
+
pil_images[0].save(
|
191 |
+
fp=path,
|
192 |
+
format="GIF",
|
193 |
+
append_images=pil_images[1:],
|
194 |
+
save_all=True,
|
195 |
+
duration=(1 / fps * 1000),
|
196 |
+
loop=0,
|
197 |
+
)
|
198 |
+
else:
|
199 |
+
raise ValueError("Unsupported file type. Use .mp4 or .gif.")
|
200 |
+
|
201 |
+
|
202 |
+
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
|
203 |
+
"""
|
204 |
+
Save a grid of videos as an animation or video.
|
205 |
+
|
206 |
+
Args:
|
207 |
+
videos (torch.Tensor): A tensor of shape (batch_size, channels, time, height, width)
|
208 |
+
containing the videos to save.
|
209 |
+
path (str): The path to save the video grid. Supported formats are .mp4, .avi, and .gif.
|
210 |
+
rescale (bool, optional): If True, rescale the video to the original resolution.
|
211 |
+
Defaults to False.
|
212 |
+
n_rows (int, optional): The number of rows in the video grid. Defaults to 6.
|
213 |
+
fps (int, optional): The frame rate of the saved video. Defaults to 8.
|
214 |
+
|
215 |
+
Raises:
|
216 |
+
ValueError: If the video format is not supported.
|
217 |
+
|
218 |
+
Returns:
|
219 |
+
None
|
220 |
+
"""
|
221 |
+
videos = rearrange(videos, "b c t h w -> t b c h w")
|
222 |
+
# height, width = videos.shape[-2:]
|
223 |
+
outputs = []
|
224 |
+
|
225 |
+
for x in videos:
|
226 |
+
x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w)
|
227 |
+
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c)
|
228 |
+
if rescale:
|
229 |
+
x = (x + 1.0) / 2.0 # -1,1 -> 0,1
|
230 |
+
x = (x * 255).numpy().astype(np.uint8)
|
231 |
+
x = Image.fromarray(x)
|
232 |
+
|
233 |
+
outputs.append(x)
|
234 |
+
|
235 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
236 |
+
|
237 |
+
save_videos_from_pil(outputs, path, fps)
|
238 |
+
|
239 |
+
|
240 |
+
def read_frames(video_path):
|
241 |
+
"""
|
242 |
+
Reads video frames from a given video file.
|
243 |
+
|
244 |
+
Args:
|
245 |
+
video_path (str): The path to the video file.
|
246 |
+
|
247 |
+
Returns:
|
248 |
+
container (av.container.InputContainer): The input container object
|
249 |
+
containing the video stream.
|
250 |
+
|
251 |
+
Raises:
|
252 |
+
FileNotFoundError: If the video file is not found.
|
253 |
+
RuntimeError: If there is an error in reading the video stream.
|
254 |
+
|
255 |
+
The function reads the video frames from the specified video file using the
|
256 |
+
Python AV library (av). It returns an input container object that contains
|
257 |
+
the video stream. If the video file is not found, it raises a FileNotFoundError,
|
258 |
+
and if there is an error in reading the video stream, it raises a RuntimeError.
|
259 |
+
"""
|
260 |
+
container = av.open(video_path)
|
261 |
+
|
262 |
+
video_stream = next(s for s in container.streams if s.type == "video")
|
263 |
+
frames = []
|
264 |
+
for packet in container.demux(video_stream):
|
265 |
+
for frame in packet.decode():
|
266 |
+
image = Image.frombytes(
|
267 |
+
"RGB",
|
268 |
+
(frame.width, frame.height),
|
269 |
+
frame.to_rgb().to_ndarray(),
|
270 |
+
)
|
271 |
+
frames.append(image)
|
272 |
+
|
273 |
+
return frames
|
274 |
+
|
275 |
+
|
276 |
+
def get_fps(video_path):
|
277 |
+
"""
|
278 |
+
Get the frame rate (FPS) of a video file.
|
279 |
+
|
280 |
+
Args:
|
281 |
+
video_path (str): The path to the video file.
|
282 |
+
|
283 |
+
Returns:
|
284 |
+
int: The frame rate (FPS) of the video file.
|
285 |
+
"""
|
286 |
+
container = av.open(video_path)
|
287 |
+
video_stream = next(s for s in container.streams if s.type == "video")
|
288 |
+
fps = video_stream.average_rate
|
289 |
+
container.close()
|
290 |
+
return fps
|
291 |
+
|
292 |
+
|
293 |
+
def tensor_to_video(tensor, output_video_file, audio_source, fps=25):
|
294 |
+
"""
|
295 |
+
Converts a Tensor with shape [c, f, h, w] into a video and adds an audio track from the specified audio file.
|
296 |
+
|
297 |
+
Args:
|
298 |
+
tensor (Tensor): The Tensor to be converted, shaped [c, f, h, w].
|
299 |
+
output_video_file (str): The file path where the output video will be saved.
|
300 |
+
audio_source (str): The path to the audio file (WAV file) that contains the audio track to be added.
|
301 |
+
fps (int): The frame rate of the output video. Default is 25 fps.
|
302 |
+
"""
|
303 |
+
tensor = tensor.permute(1, 2, 3, 0).cpu(
|
304 |
+
).numpy() # convert to [f, h, w, c]
|
305 |
+
tensor = np.clip(tensor * 255, 0, 255).astype(
|
306 |
+
np.uint8
|
307 |
+
) # to [0, 255]
|
308 |
+
|
309 |
+
def make_frame(t):
|
310 |
+
# get index
|
311 |
+
frame_index = min(int(t * fps), tensor.shape[0] - 1)
|
312 |
+
return tensor[frame_index]
|
313 |
+
new_video_clip = VideoClip(make_frame, duration=tensor.shape[0] / fps)
|
314 |
+
audio_clip = AudioFileClip(audio_source).subclip(0, tensor.shape[0] / fps)
|
315 |
+
new_video_clip = new_video_clip.set_audio(audio_clip)
|
316 |
+
new_video_clip.write_videofile(output_video_file, fps=fps, audio_codec='aac')
|
317 |
+
|
318 |
+
|
319 |
+
silhouette_ids = [
|
320 |
+
10, 338, 297, 332, 284, 251, 389, 356, 454, 323, 361, 288,
|
321 |
+
397, 365, 379, 378, 400, 377, 152, 148, 176, 149, 150, 136,
|
322 |
+
172, 58, 132, 93, 234, 127, 162, 21, 54, 103, 67, 109
|
323 |
+
]
|
324 |
+
lip_ids = [61, 185, 40, 39, 37, 0, 267, 269, 270, 409, 291,
|
325 |
+
146, 91, 181, 84, 17, 314, 405, 321, 375]
|
326 |
+
|
327 |
+
|
328 |
+
def compute_face_landmarks(detection_result, h, w):
|
329 |
+
"""
|
330 |
+
Compute face landmarks from a detection result.
|
331 |
+
|
332 |
+
Args:
|
333 |
+
detection_result (mediapipe.solutions.face_mesh.FaceMesh): The detection result containing face landmarks.
|
334 |
+
h (int): The height of the video frame.
|
335 |
+
w (int): The width of the video frame.
|
336 |
+
|
337 |
+
Returns:
|
338 |
+
face_landmarks_list (list): A list of face landmarks.
|
339 |
+
"""
|
340 |
+
face_landmarks_list = detection_result.face_landmarks
|
341 |
+
if len(face_landmarks_list) != 1:
|
342 |
+
print("#face is invalid:", len(face_landmarks_list))
|
343 |
+
return []
|
344 |
+
return [[p.x * w, p.y * h] for p in face_landmarks_list[0]]
|
345 |
+
|
346 |
+
|
347 |
+
def get_landmark(file):
|
348 |
+
"""
|
349 |
+
This function takes a file as input and returns the facial landmarks detected in the file.
|
350 |
+
|
351 |
+
Args:
|
352 |
+
file (str): The path to the file containing the video or image to be processed.
|
353 |
+
|
354 |
+
Returns:
|
355 |
+
Tuple[List[float], List[float]]: A tuple containing two lists of floats representing the x and y coordinates of the facial landmarks.
|
356 |
+
"""
|
357 |
+
model_path = "pretrained_models/face_analysis/models/face_landmarker_v2_with_blendshapes.task"
|
358 |
+
BaseOptions = mp.tasks.BaseOptions
|
359 |
+
FaceLandmarker = mp.tasks.vision.FaceLandmarker
|
360 |
+
FaceLandmarkerOptions = mp.tasks.vision.FaceLandmarkerOptions
|
361 |
+
VisionRunningMode = mp.tasks.vision.RunningMode
|
362 |
+
# Create a face landmarker instance with the video mode:
|
363 |
+
options = FaceLandmarkerOptions(
|
364 |
+
base_options=BaseOptions(model_asset_path=model_path),
|
365 |
+
running_mode=VisionRunningMode.IMAGE,
|
366 |
+
)
|
367 |
+
|
368 |
+
with FaceLandmarker.create_from_options(options) as landmarker:
|
369 |
+
image = mp.Image.create_from_file(str(file))
|
370 |
+
height, width = image.height, image.width
|
371 |
+
face_landmarker_result = landmarker.detect(image)
|
372 |
+
face_landmark = compute_face_landmarks(
|
373 |
+
face_landmarker_result, height, width)
|
374 |
+
|
375 |
+
return np.array(face_landmark), height, width
|
376 |
+
|
377 |
+
|
378 |
+
def get_landmark_overframes(landmark_model, frames_path):
|
379 |
+
"""
|
380 |
+
This function iterate frames and returns the facial landmarks detected in each frame.
|
381 |
+
|
382 |
+
Args:
|
383 |
+
landmark_model: mediapipe landmark model instance
|
384 |
+
frames_path (str): The path to the video frames.
|
385 |
+
|
386 |
+
Returns:
|
387 |
+
List[List[float], float, float]: A List containing two lists of floats representing the x and y coordinates of the facial landmarks.
|
388 |
+
"""
|
389 |
+
|
390 |
+
face_landmarks = []
|
391 |
+
|
392 |
+
for file in sorted(os.listdir(frames_path)):
|
393 |
+
image = mp.Image.create_from_file(os.path.join(frames_path, file))
|
394 |
+
height, width = image.height, image.width
|
395 |
+
landmarker_result = landmark_model.detect(image)
|
396 |
+
frame_landmark = compute_face_landmarks(
|
397 |
+
landmarker_result, height, width)
|
398 |
+
face_landmarks.append(frame_landmark)
|
399 |
+
|
400 |
+
return face_landmarks, height, width
|
401 |
+
|
402 |
+
|
403 |
+
def get_lip_mask(landmarks, height, width, out_path=None, expand_ratio=2.0):
|
404 |
+
"""
|
405 |
+
Extracts the lip region from the given landmarks and saves it as an image.
|
406 |
+
|
407 |
+
Parameters:
|
408 |
+
landmarks (numpy.ndarray): Array of facial landmarks.
|
409 |
+
height (int): Height of the output lip mask image.
|
410 |
+
width (int): Width of the output lip mask image.
|
411 |
+
out_path (pathlib.Path): Path to save the lip mask image.
|
412 |
+
expand_ratio (float): Expand ratio of mask.
|
413 |
+
"""
|
414 |
+
lip_landmarks = np.take(landmarks, lip_ids, 0)
|
415 |
+
min_xy_lip = np.round(np.min(lip_landmarks, 0))
|
416 |
+
max_xy_lip = np.round(np.max(lip_landmarks, 0))
|
417 |
+
min_xy_lip[0], max_xy_lip[0], min_xy_lip[1], max_xy_lip[1] = expand_region(
|
418 |
+
[min_xy_lip[0], max_xy_lip[0], min_xy_lip[1], max_xy_lip[1]], width, height, expand_ratio)
|
419 |
+
lip_mask = np.zeros((height, width), dtype=np.uint8)
|
420 |
+
lip_mask[round(min_xy_lip[1]):round(max_xy_lip[1]),
|
421 |
+
round(min_xy_lip[0]):round(max_xy_lip[0])] = 255
|
422 |
+
if out_path:
|
423 |
+
cv2.imwrite(str(out_path), lip_mask)
|
424 |
+
return None
|
425 |
+
|
426 |
+
return lip_mask
|
427 |
+
|
428 |
+
|
429 |
+
def get_union_lip_mask(landmarks, height, width, expand_ratio=1):
|
430 |
+
"""
|
431 |
+
Extracts the lip region from the given landmarks and saves it as an image.
|
432 |
+
|
433 |
+
Parameters:
|
434 |
+
landmarks (numpy.ndarray): Array of facial landmarks.
|
435 |
+
height (int): Height of the output lip mask image.
|
436 |
+
width (int): Width of the output lip mask image.
|
437 |
+
expand_ratio (float): Expand ratio of mask.
|
438 |
+
"""
|
439 |
+
lip_masks = []
|
440 |
+
for landmark in landmarks:
|
441 |
+
lip_masks.append(get_lip_mask(landmarks=landmark, height=height,
|
442 |
+
width=width, expand_ratio=expand_ratio))
|
443 |
+
union_mask = get_union_mask(lip_masks)
|
444 |
+
return union_mask
|
445 |
+
|
446 |
+
|
447 |
+
def get_face_mask(landmarks, height, width, out_path=None, expand_ratio=1.2):
|
448 |
+
"""
|
449 |
+
Generate a face mask based on the given landmarks.
|
450 |
+
|
451 |
+
Args:
|
452 |
+
landmarks (numpy.ndarray): The landmarks of the face.
|
453 |
+
height (int): The height of the output face mask image.
|
454 |
+
width (int): The width of the output face mask image.
|
455 |
+
out_path (pathlib.Path): The path to save the face mask image.
|
456 |
+
expand_ratio (float): Expand ratio of mask.
|
457 |
+
Returns:
|
458 |
+
None. The face mask image is saved at the specified path.
|
459 |
+
"""
|
460 |
+
face_landmarks = np.take(landmarks, silhouette_ids, 0)
|
461 |
+
min_xy_face = np.round(np.min(face_landmarks, 0))
|
462 |
+
max_xy_face = np.round(np.max(face_landmarks, 0))
|
463 |
+
min_xy_face[0], max_xy_face[0], min_xy_face[1], max_xy_face[1] = expand_region(
|
464 |
+
[min_xy_face[0], max_xy_face[0], min_xy_face[1], max_xy_face[1]], width, height, expand_ratio)
|
465 |
+
face_mask = np.zeros((height, width), dtype=np.uint8)
|
466 |
+
face_mask[round(min_xy_face[1]):round(max_xy_face[1]),
|
467 |
+
round(min_xy_face[0]):round(max_xy_face[0])] = 255
|
468 |
+
if out_path:
|
469 |
+
cv2.imwrite(str(out_path), face_mask)
|
470 |
+
return None
|
471 |
+
|
472 |
+
return face_mask
|
473 |
+
|
474 |
+
|
475 |
+
def get_union_face_mask(landmarks, height, width, expand_ratio=1):
|
476 |
+
"""
|
477 |
+
Generate a face mask based on the given landmarks.
|
478 |
+
|
479 |
+
Args:
|
480 |
+
landmarks (numpy.ndarray): The landmarks of the face.
|
481 |
+
height (int): The height of the output face mask image.
|
482 |
+
width (int): The width of the output face mask image.
|
483 |
+
expand_ratio (float): Expand ratio of mask.
|
484 |
+
Returns:
|
485 |
+
None. The face mask image is saved at the specified path.
|
486 |
+
"""
|
487 |
+
face_masks = []
|
488 |
+
for landmark in landmarks:
|
489 |
+
face_masks.append(get_face_mask(landmarks=landmark,height=height,width=width,expand_ratio=expand_ratio))
|
490 |
+
union_mask = get_union_mask(face_masks)
|
491 |
+
return union_mask
|
492 |
+
|
493 |
+
def get_mask(file, cache_dir, face_expand_raio):
|
494 |
+
"""
|
495 |
+
Generate a face mask based on the given landmarks and save it to the specified cache directory.
|
496 |
+
|
497 |
+
Args:
|
498 |
+
file (str): The path to the file containing the landmarks.
|
499 |
+
cache_dir (str): The directory to save the generated face mask.
|
500 |
+
|
501 |
+
Returns:
|
502 |
+
None
|
503 |
+
"""
|
504 |
+
landmarks, height, width = get_landmark(file)
|
505 |
+
file_name = os.path.basename(file).split(".")[0]
|
506 |
+
get_lip_mask(landmarks, height, width, os.path.join(
|
507 |
+
cache_dir, f"{file_name}_lip_mask.png"))
|
508 |
+
get_face_mask(landmarks, height, width, os.path.join(
|
509 |
+
cache_dir, f"{file_name}_face_mask.png"), face_expand_raio)
|
510 |
+
get_blur_mask(os.path.join(
|
511 |
+
cache_dir, f"{file_name}_face_mask.png"), os.path.join(
|
512 |
+
cache_dir, f"{file_name}_face_mask_blur.png"), kernel_size=(51, 51))
|
513 |
+
get_blur_mask(os.path.join(
|
514 |
+
cache_dir, f"{file_name}_lip_mask.png"), os.path.join(
|
515 |
+
cache_dir, f"{file_name}_sep_lip.png"), kernel_size=(31, 31))
|
516 |
+
get_background_mask(os.path.join(
|
517 |
+
cache_dir, f"{file_name}_face_mask_blur.png"), os.path.join(
|
518 |
+
cache_dir, f"{file_name}_sep_background.png"))
|
519 |
+
get_sep_face_mask(os.path.join(
|
520 |
+
cache_dir, f"{file_name}_face_mask_blur.png"), os.path.join(
|
521 |
+
cache_dir, f"{file_name}_sep_lip.png"), os.path.join(
|
522 |
+
cache_dir, f"{file_name}_sep_face.png"))
|
523 |
+
|
524 |
+
|
525 |
+
def expand_region(region, image_w, image_h, expand_ratio=1.0):
|
526 |
+
"""
|
527 |
+
Expand the given region by a specified ratio.
|
528 |
+
Args:
|
529 |
+
region (tuple): A tuple containing the coordinates (min_x, max_x, min_y, max_y) of the region.
|
530 |
+
image_w (int): The width of the image.
|
531 |
+
image_h (int): The height of the image.
|
532 |
+
expand_ratio (float, optional): The ratio by which the region should be expanded. Defaults to 1.0.
|
533 |
+
|
534 |
+
Returns:
|
535 |
+
tuple: A tuple containing the expanded coordinates (min_x, max_x, min_y, max_y) of the region.
|
536 |
+
"""
|
537 |
+
|
538 |
+
min_x, max_x, min_y, max_y = region
|
539 |
+
mid_x = (max_x + min_x) // 2
|
540 |
+
side_len_x = (max_x - min_x) * expand_ratio
|
541 |
+
mid_y = (max_y + min_y) // 2
|
542 |
+
side_len_y = (max_y - min_y) * expand_ratio
|
543 |
+
min_x = mid_x - side_len_x // 2
|
544 |
+
max_x = mid_x + side_len_x // 2
|
545 |
+
min_y = mid_y - side_len_y // 2
|
546 |
+
max_y = mid_y + side_len_y // 2
|
547 |
+
if min_x < 0:
|
548 |
+
max_x -= min_x
|
549 |
+
min_x = 0
|
550 |
+
if max_x > image_w:
|
551 |
+
min_x -= max_x - image_w
|
552 |
+
max_x = image_w
|
553 |
+
if min_y < 0:
|
554 |
+
max_y -= min_y
|
555 |
+
min_y = 0
|
556 |
+
if max_y > image_h:
|
557 |
+
min_y -= max_y - image_h
|
558 |
+
max_y = image_h
|
559 |
+
|
560 |
+
return round(min_x), round(max_x), round(min_y), round(max_y)
|
561 |
+
|
562 |
+
|
563 |
+
def get_blur_mask(file_path, output_file_path, resize_dim=(64, 64), kernel_size=(101, 101)):
|
564 |
+
"""
|
565 |
+
Read, resize, blur, normalize, and save an image.
|
566 |
+
|
567 |
+
Parameters:
|
568 |
+
file_path (str): Path to the input image file.
|
569 |
+
output_dir (str): Path to the output directory to save blurred images.
|
570 |
+
resize_dim (tuple): Dimensions to resize the images to.
|
571 |
+
kernel_size (tuple): Size of the kernel to use for Gaussian blur.
|
572 |
+
"""
|
573 |
+
# Read the mask image
|
574 |
+
mask = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
|
575 |
+
|
576 |
+
# Check if the image is loaded successfully
|
577 |
+
if mask is not None:
|
578 |
+
normalized_mask = blur_mask(mask,resize_dim=resize_dim,kernel_size=kernel_size)
|
579 |
+
# Save the normalized mask image
|
580 |
+
cv2.imwrite(output_file_path, normalized_mask)
|
581 |
+
return f"Processed, normalized, and saved: {output_file_path}"
|
582 |
+
return f"Failed to load image: {file_path}"
|
583 |
+
|
584 |
+
|
585 |
+
def blur_mask(mask, resize_dim=(64, 64), kernel_size=(51, 51)):
|
586 |
+
"""
|
587 |
+
Read, resize, blur, normalize, and save an image.
|
588 |
+
|
589 |
+
Parameters:
|
590 |
+
file_path (str): Path to the input image file.
|
591 |
+
resize_dim (tuple): Dimensions to resize the images to.
|
592 |
+
kernel_size (tuple): Size of the kernel to use for Gaussian blur.
|
593 |
+
"""
|
594 |
+
# Check if the image is loaded successfully
|
595 |
+
normalized_mask = None
|
596 |
+
if mask is not None:
|
597 |
+
# Resize the mask image
|
598 |
+
resized_mask = cv2.resize(mask, resize_dim)
|
599 |
+
# Apply Gaussian blur to the resized mask image
|
600 |
+
blurred_mask = cv2.GaussianBlur(resized_mask, kernel_size, 0)
|
601 |
+
# Normalize the blurred image
|
602 |
+
normalized_mask = cv2.normalize(
|
603 |
+
blurred_mask, None, 0, 255, cv2.NORM_MINMAX)
|
604 |
+
# Save the normalized mask image
|
605 |
+
return normalized_mask
|
606 |
+
|
607 |
+
def get_background_mask(file_path, output_file_path):
|
608 |
+
"""
|
609 |
+
Read an image, invert its values, and save the result.
|
610 |
+
|
611 |
+
Parameters:
|
612 |
+
file_path (str): Path to the input image file.
|
613 |
+
output_dir (str): Path to the output directory to save the inverted image.
|
614 |
+
"""
|
615 |
+
# Read the image
|
616 |
+
image = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
|
617 |
+
|
618 |
+
if image is None:
|
619 |
+
print(f"Failed to load image: {file_path}")
|
620 |
+
return
|
621 |
+
|
622 |
+
# Invert the image
|
623 |
+
inverted_image = 1.0 - (
|
624 |
+
image / 255.0
|
625 |
+
) # Assuming the image values are in [0, 255] range
|
626 |
+
# Convert back to uint8
|
627 |
+
inverted_image = (inverted_image * 255).astype(np.uint8)
|
628 |
+
|
629 |
+
# Save the inverted image
|
630 |
+
cv2.imwrite(output_file_path, inverted_image)
|
631 |
+
print(f"Processed and saved: {output_file_path}")
|
632 |
+
|
633 |
+
|
634 |
+
def get_sep_face_mask(file_path1, file_path2, output_file_path):
|
635 |
+
"""
|
636 |
+
Read two images, subtract the second one from the first, and save the result.
|
637 |
+
|
638 |
+
Parameters:
|
639 |
+
output_dir (str): Path to the output directory to save the subtracted image.
|
640 |
+
"""
|
641 |
+
|
642 |
+
# Read the images
|
643 |
+
mask1 = cv2.imread(file_path1, cv2.IMREAD_GRAYSCALE)
|
644 |
+
mask2 = cv2.imread(file_path2, cv2.IMREAD_GRAYSCALE)
|
645 |
+
|
646 |
+
if mask1 is None or mask2 is None:
|
647 |
+
print(f"Failed to load images: {file_path1}")
|
648 |
+
return
|
649 |
+
|
650 |
+
# Ensure the images are the same size
|
651 |
+
if mask1.shape != mask2.shape:
|
652 |
+
print(
|
653 |
+
f"Image shapes do not match for {file_path1}: {mask1.shape} vs {mask2.shape}"
|
654 |
+
)
|
655 |
+
return
|
656 |
+
|
657 |
+
# Subtract the second mask from the first
|
658 |
+
result_mask = cv2.subtract(mask1, mask2)
|
659 |
+
|
660 |
+
# Save the result mask image
|
661 |
+
cv2.imwrite(output_file_path, result_mask)
|
662 |
+
print(f"Processed and saved: {output_file_path}")
|
663 |
+
|
664 |
+
def resample_audio(input_audio_file: str, output_audio_file: str, sample_rate: int):
|
665 |
+
p = subprocess.Popen([
|
666 |
+
"ffmpeg", "-y", "-v", "error", "-i", input_audio_file, "-ar", str(sample_rate), output_audio_file
|
667 |
+
])
|
668 |
+
ret = p.wait()
|
669 |
+
assert ret == 0, "Resample audio failed!"
|
670 |
+
return output_audio_file
|
671 |
+
|
672 |
+
def get_face_region(image_path: str, detector):
|
673 |
+
try:
|
674 |
+
image = cv2.imread(image_path)
|
675 |
+
if image is None:
|
676 |
+
print(f"Failed to open image: {image_path}. Skipping...")
|
677 |
+
return None, None
|
678 |
+
|
679 |
+
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image)
|
680 |
+
detection_result = detector.detect(mp_image)
|
681 |
+
|
682 |
+
# Adjust mask creation for the three-channel image
|
683 |
+
mask = np.zeros_like(image, dtype=np.uint8)
|
684 |
+
|
685 |
+
for detection in detection_result.detections:
|
686 |
+
bbox = detection.bounding_box
|
687 |
+
start_point = (int(bbox.origin_x), int(bbox.origin_y))
|
688 |
+
end_point = (int(bbox.origin_x + bbox.width),
|
689 |
+
int(bbox.origin_y + bbox.height))
|
690 |
+
cv2.rectangle(mask, start_point, end_point,
|
691 |
+
(255, 255, 255), thickness=-1)
|
692 |
+
|
693 |
+
save_path = image_path.replace("images", "face_masks")
|
694 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
695 |
+
cv2.imwrite(save_path, mask)
|
696 |
+
# print(f"Processed and saved {save_path}")
|
697 |
+
return image_path, mask
|
698 |
+
except Exception as e:
|
699 |
+
print(f"Error processing image {image_path}: {e}")
|
700 |
+
return None, None
|
701 |
+
|
702 |
+
|
703 |
+
def save_checkpoint(model: torch.nn.Module, save_dir: str, prefix: str, ckpt_num: int, total_limit: int = -1) -> None:
|
704 |
+
"""
|
705 |
+
Save the model's state_dict to a checkpoint file.
|
706 |
+
|
707 |
+
If `total_limit` is provided, this function will remove the oldest checkpoints
|
708 |
+
until the total number of checkpoints is less than the specified limit.
|
709 |
+
|
710 |
+
Args:
|
711 |
+
model (nn.Module): The model whose state_dict is to be saved.
|
712 |
+
save_dir (str): The directory where the checkpoint will be saved.
|
713 |
+
prefix (str): The prefix for the checkpoint file name.
|
714 |
+
ckpt_num (int): The checkpoint number to be saved.
|
715 |
+
total_limit (int, optional): The maximum number of checkpoints to keep.
|
716 |
+
Defaults to None, in which case no checkpoints will be removed.
|
717 |
+
|
718 |
+
Raises:
|
719 |
+
FileNotFoundError: If the save directory does not exist.
|
720 |
+
ValueError: If the checkpoint number is negative.
|
721 |
+
OSError: If there is an error saving the checkpoint.
|
722 |
+
"""
|
723 |
+
|
724 |
+
if not osp.exists(save_dir):
|
725 |
+
raise FileNotFoundError(
|
726 |
+
f"The save directory {save_dir} does not exist.")
|
727 |
+
|
728 |
+
if ckpt_num < 0:
|
729 |
+
raise ValueError(f"Checkpoint number {ckpt_num} must be non-negative.")
|
730 |
+
|
731 |
+
save_path = osp.join(save_dir, f"{prefix}-{ckpt_num}.pth")
|
732 |
+
|
733 |
+
if total_limit > 0:
|
734 |
+
checkpoints = os.listdir(save_dir)
|
735 |
+
checkpoints = [d for d in checkpoints if d.startswith(prefix)]
|
736 |
+
checkpoints = sorted(
|
737 |
+
checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
|
738 |
+
)
|
739 |
+
|
740 |
+
if len(checkpoints) >= total_limit:
|
741 |
+
num_to_remove = len(checkpoints) - total_limit + 1
|
742 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
743 |
+
print(
|
744 |
+
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
745 |
+
)
|
746 |
+
print(
|
747 |
+
f"Removing checkpoints: {', '.join(removing_checkpoints)}"
|
748 |
+
)
|
749 |
+
|
750 |
+
for removing_checkpoint in removing_checkpoints:
|
751 |
+
removing_checkpoint_path = osp.join(
|
752 |
+
save_dir, removing_checkpoint)
|
753 |
+
try:
|
754 |
+
os.remove(removing_checkpoint_path)
|
755 |
+
except OSError as e:
|
756 |
+
print(
|
757 |
+
f"Error removing checkpoint {removing_checkpoint_path}: {e}")
|
758 |
+
|
759 |
+
state_dict = model.state_dict()
|
760 |
+
try:
|
761 |
+
torch.save(state_dict, save_path)
|
762 |
+
print(f"Checkpoint saved at {save_path}")
|
763 |
+
except OSError as e:
|
764 |
+
raise OSError(f"Error saving checkpoint at {save_path}: {e}") from e
|
765 |
+
|
766 |
+
|
767 |
+
def init_output_dir(dir_list: List[str]):
|
768 |
+
"""
|
769 |
+
Initialize the output directories.
|
770 |
+
|
771 |
+
This function creates the directories specified in the `dir_list`. If a directory already exists, it does nothing.
|
772 |
+
|
773 |
+
Args:
|
774 |
+
dir_list (List[str]): List of directory paths to create.
|
775 |
+
"""
|
776 |
+
for path in dir_list:
|
777 |
+
os.makedirs(path, exist_ok=True)
|
778 |
+
|
779 |
+
|
780 |
+
def load_checkpoint(cfg, save_dir, accelerator):
|
781 |
+
"""
|
782 |
+
Load the most recent checkpoint from the specified directory.
|
783 |
+
|
784 |
+
This function loads the latest checkpoint from the `save_dir` if the `resume_from_checkpoint` parameter is set to "latest".
|
785 |
+
If a specific checkpoint is provided in `resume_from_checkpoint`, it loads that checkpoint. If no checkpoint is found,
|
786 |
+
it starts training from scratch.
|
787 |
+
|
788 |
+
Args:
|
789 |
+
cfg: The configuration object containing training parameters.
|
790 |
+
save_dir (str): The directory where checkpoints are saved.
|
791 |
+
accelerator: The accelerator object for distributed training.
|
792 |
+
|
793 |
+
Returns:
|
794 |
+
int: The global step at which to resume training.
|
795 |
+
"""
|
796 |
+
if cfg.resume_from_checkpoint != "latest":
|
797 |
+
resume_dir = cfg.resume_from_checkpoint
|
798 |
+
else:
|
799 |
+
resume_dir = save_dir
|
800 |
+
# Get the most recent checkpoint
|
801 |
+
dirs = os.listdir(resume_dir)
|
802 |
+
|
803 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
804 |
+
if len(dirs) > 0:
|
805 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
806 |
+
path = dirs[-1]
|
807 |
+
accelerator.load_state(os.path.join(resume_dir, path))
|
808 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
809 |
+
global_step = int(path.split("-")[1])
|
810 |
+
else:
|
811 |
+
accelerator.print(
|
812 |
+
f"Could not find checkpoint under {resume_dir}, start training from scratch")
|
813 |
+
global_step = 0
|
814 |
+
|
815 |
+
return global_step
|
816 |
+
|
817 |
+
|
818 |
+
def compute_snr(noise_scheduler, timesteps):
|
819 |
+
"""
|
820 |
+
Computes SNR as per
|
821 |
+
https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/
|
822 |
+
521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
823 |
+
"""
|
824 |
+
alphas_cumprod = noise_scheduler.alphas_cumprod
|
825 |
+
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
826 |
+
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
827 |
+
|
828 |
+
# Expand the tensors.
|
829 |
+
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/
|
830 |
+
# 521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
|
831 |
+
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
|
832 |
+
timesteps
|
833 |
+
].float()
|
834 |
+
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
835 |
+
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
836 |
+
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
837 |
+
|
838 |
+
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
|
839 |
+
device=timesteps.device
|
840 |
+
)[timesteps].float()
|
841 |
+
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
842 |
+
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
843 |
+
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
844 |
+
|
845 |
+
# Compute SNR.
|
846 |
+
snr = (alpha / sigma) ** 2
|
847 |
+
return snr
|
848 |
+
|
849 |
+
|
850 |
+
def extract_audio_from_videos(video_path: Path, audio_output_path: Path) -> Path:
|
851 |
+
"""
|
852 |
+
Extract audio from a video file and save it as a WAV file.
|
853 |
+
|
854 |
+
This function uses ffmpeg to extract the audio stream from a given video file and saves it as a WAV file
|
855 |
+
in the specified output directory.
|
856 |
+
|
857 |
+
Args:
|
858 |
+
video_path (Path): The path to the input video file.
|
859 |
+
output_dir (Path): The directory where the extracted audio file will be saved.
|
860 |
+
|
861 |
+
Returns:
|
862 |
+
Path: The path to the extracted audio file.
|
863 |
+
|
864 |
+
Raises:
|
865 |
+
subprocess.CalledProcessError: If the ffmpeg command fails to execute.
|
866 |
+
"""
|
867 |
+
ffmpeg_command = [
|
868 |
+
'ffmpeg', '-y',
|
869 |
+
'-i', str(video_path),
|
870 |
+
'-vn', '-acodec',
|
871 |
+
"pcm_s16le", '-ar', '16000', '-ac', '2',
|
872 |
+
str(audio_output_path)
|
873 |
+
]
|
874 |
+
|
875 |
+
try:
|
876 |
+
print(f"Running command: {' '.join(ffmpeg_command)}")
|
877 |
+
subprocess.run(ffmpeg_command, check=True)
|
878 |
+
except subprocess.CalledProcessError as e:
|
879 |
+
print(f"Error extracting audio from video: {e}")
|
880 |
+
raise
|
881 |
+
|
882 |
+
return audio_output_path
|
883 |
+
|
884 |
+
|
885 |
+
def convert_video_to_images(video_path: Path, output_dir: Path) -> Path:
|
886 |
+
"""
|
887 |
+
Convert a video file into a sequence of images.
|
888 |
+
|
889 |
+
This function uses ffmpeg to convert each frame of the given video file into an image. The images are saved
|
890 |
+
in a directory named after the video file stem under the specified output directory.
|
891 |
+
|
892 |
+
Args:
|
893 |
+
video_path (Path): The path to the input video file.
|
894 |
+
output_dir (Path): The directory where the extracted images will be saved.
|
895 |
+
|
896 |
+
Returns:
|
897 |
+
Path: The path to the directory containing the extracted images.
|
898 |
+
|
899 |
+
Raises:
|
900 |
+
subprocess.CalledProcessError: If the ffmpeg command fails to execute.
|
901 |
+
"""
|
902 |
+
ffmpeg_command = [
|
903 |
+
'ffmpeg',
|
904 |
+
'-i', str(video_path),
|
905 |
+
'-vf', 'fps=25',
|
906 |
+
str(output_dir / '%04d.png')
|
907 |
+
]
|
908 |
+
|
909 |
+
try:
|
910 |
+
print(f"Running command: {' '.join(ffmpeg_command)}")
|
911 |
+
subprocess.run(ffmpeg_command, check=True)
|
912 |
+
except subprocess.CalledProcessError as e:
|
913 |
+
print(f"Error converting video to images: {e}")
|
914 |
+
raise
|
915 |
+
|
916 |
+
return output_dir
|
917 |
+
|
918 |
+
|
919 |
+
def get_union_mask(masks):
|
920 |
+
"""
|
921 |
+
Compute the union of a list of masks.
|
922 |
+
|
923 |
+
This function takes a list of masks and computes their union by taking the maximum value at each pixel location.
|
924 |
+
Additionally, it finds the bounding box of the non-zero regions in the mask and sets the bounding box area to white.
|
925 |
+
|
926 |
+
Args:
|
927 |
+
masks (list of np.ndarray): List of masks to be combined.
|
928 |
+
|
929 |
+
Returns:
|
930 |
+
np.ndarray: The union of the input masks.
|
931 |
+
"""
|
932 |
+
union_mask = None
|
933 |
+
for mask in masks:
|
934 |
+
if union_mask is None:
|
935 |
+
union_mask = mask
|
936 |
+
else:
|
937 |
+
union_mask = np.maximum(union_mask, mask)
|
938 |
+
|
939 |
+
if union_mask is not None:
|
940 |
+
# Find the bounding box of the non-zero regions in the mask
|
941 |
+
rows = np.any(union_mask, axis=1)
|
942 |
+
cols = np.any(union_mask, axis=0)
|
943 |
+
try:
|
944 |
+
ymin, ymax = np.where(rows)[0][[0, -1]]
|
945 |
+
xmin, xmax = np.where(cols)[0][[0, -1]]
|
946 |
+
except Exception as e:
|
947 |
+
print(str(e))
|
948 |
+
return 0.0
|
949 |
+
|
950 |
+
# Set bounding box area to white
|
951 |
+
union_mask[ymin: ymax + 1, xmin: xmax + 1] = np.max(union_mask)
|
952 |
+
|
953 |
+
return union_mask
|
954 |
+
|
955 |
+
|
956 |
+
def move_final_checkpoint(save_dir, module_dir, prefix):
|
957 |
+
"""
|
958 |
+
Move the final checkpoint file to the save directory.
|
959 |
+
|
960 |
+
This function identifies the latest checkpoint file based on the given prefix and moves it to the specified save directory.
|
961 |
+
|
962 |
+
Args:
|
963 |
+
save_dir (str): The directory where the final checkpoint file should be saved.
|
964 |
+
module_dir (str): The directory containing the checkpoint files.
|
965 |
+
prefix (str): The prefix used to identify checkpoint files.
|
966 |
+
|
967 |
+
Raises:
|
968 |
+
ValueError: If no checkpoint files are found with the specified prefix.
|
969 |
+
"""
|
970 |
+
checkpoints = os.listdir(module_dir)
|
971 |
+
checkpoints = [d for d in checkpoints if d.startswith(prefix)]
|
972 |
+
checkpoints = sorted(
|
973 |
+
checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
|
974 |
+
)
|
975 |
+
shutil.copy2(os.path.join(
|
976 |
+
module_dir, checkpoints[-1]), os.path.join(save_dir, prefix + '.pth'))
|
scripts/inference.py
ADDED
@@ -0,0 +1,690 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This script is a gradio web ui.
|
3 |
+
|
4 |
+
The script takes an image and an audio clip, and lets you configure all the
|
5 |
+
variables such as cfg_scale, pose_weight, face_weight, lip_weight, etc.
|
6 |
+
|
7 |
+
Usage:
|
8 |
+
This script can be run from the command line with the following command:
|
9 |
+
|
10 |
+
python scripts/app.py
|
11 |
+
"""
|
12 |
+
|
13 |
+
import gradio as gr
|
14 |
+
import argparse
|
15 |
+
import copy
|
16 |
+
import logging
|
17 |
+
import math
|
18 |
+
import os
|
19 |
+
import random
|
20 |
+
import time
|
21 |
+
import warnings
|
22 |
+
from datetime import datetime
|
23 |
+
from typing import List, Tuple
|
24 |
+
|
25 |
+
import diffusers
|
26 |
+
import mlflow
|
27 |
+
import torch
|
28 |
+
import torch.nn.functional as F
|
29 |
+
import torch.utils.checkpoint
|
30 |
+
import transformers
|
31 |
+
from accelerate import Accelerator
|
32 |
+
from accelerate.logging import get_logger
|
33 |
+
from accelerate.utils import DistributedDataParallelKwargs
|
34 |
+
from diffusers import AutoencoderKL, DDIMScheduler
|
35 |
+
from diffusers.optimization import get_scheduler
|
36 |
+
from diffusers.utils import check_min_version
|
37 |
+
from diffusers.utils.import_utils import is_xformers_available
|
38 |
+
from einops import rearrange, repeat
|
39 |
+
from omegaconf import OmegaConf
|
40 |
+
from torch import nn
|
41 |
+
from tqdm.auto import tqdm
|
42 |
+
import uuid
|
43 |
+
|
44 |
+
import sys
|
45 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
46 |
+
|
47 |
+
from joyhallo.animate.face_animate import FaceAnimatePipeline
|
48 |
+
from joyhallo.datasets.audio_processor import AudioProcessor
|
49 |
+
from joyhallo.datasets.image_processor import ImageProcessor
|
50 |
+
from joyhallo.datasets.talk_video import TalkingVideoDataset
|
51 |
+
from joyhallo.models.audio_proj import AudioProjModel
|
52 |
+
from joyhallo.models.face_locator import FaceLocator
|
53 |
+
from joyhallo.models.image_proj import ImageProjModel
|
54 |
+
from joyhallo.models.mutual_self_attention import ReferenceAttentionControl
|
55 |
+
from joyhallo.models.unet_2d_condition import UNet2DConditionModel
|
56 |
+
from joyhallo.models.unet_3d import UNet3DConditionModel
|
57 |
+
from joyhallo.utils.util import (compute_snr, delete_additional_ckpt,
|
58 |
+
import_filename, init_output_dir,
|
59 |
+
load_checkpoint, save_checkpoint,
|
60 |
+
seed_everything, tensor_to_video)
|
61 |
+
|
62 |
+
warnings.filterwarnings("ignore")
|
63 |
+
|
64 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
65 |
+
check_min_version("0.10.0.dev0")
|
66 |
+
|
67 |
+
logger = get_logger(__name__, log_level="INFO")
|
68 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
69 |
+
|
70 |
+
|
71 |
+
class Net(nn.Module):
|
72 |
+
"""
|
73 |
+
The Net class defines a neural network model that combines a reference UNet2DConditionModel,
|
74 |
+
a denoising UNet3DConditionModel, a face locator, and other components to animate a face in a static image.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
reference_unet (UNet2DConditionModel): The reference UNet2DConditionModel used for face animation.
|
78 |
+
denoising_unet (UNet3DConditionModel): The denoising UNet3DConditionModel used for face animation.
|
79 |
+
face_locator (FaceLocator): The face locator model used for face animation.
|
80 |
+
reference_control_writer: The reference control writer component.
|
81 |
+
reference_control_reader: The reference control reader component.
|
82 |
+
imageproj: The image projection model.
|
83 |
+
audioproj: The audio projection model.
|
84 |
+
|
85 |
+
Forward method:
|
86 |
+
noisy_latents (torch.Tensor): The noisy latents tensor.
|
87 |
+
timesteps (torch.Tensor): The timesteps tensor.
|
88 |
+
ref_image_latents (torch.Tensor): The reference image latents tensor.
|
89 |
+
face_emb (torch.Tensor): The face embeddings tensor.
|
90 |
+
audio_emb (torch.Tensor): The audio embeddings tensor.
|
91 |
+
mask (torch.Tensor): Hard face mask for face locator.
|
92 |
+
full_mask (torch.Tensor): Pose Mask.
|
93 |
+
face_mask (torch.Tensor): Face Mask
|
94 |
+
lip_mask (torch.Tensor): Lip Mask
|
95 |
+
uncond_img_fwd (bool): A flag indicating whether to perform reference image unconditional forward pass.
|
96 |
+
uncond_audio_fwd (bool): A flag indicating whether to perform audio unconditional forward pass.
|
97 |
+
|
98 |
+
Returns:
|
99 |
+
torch.Tensor: The output tensor of the neural network model.
|
100 |
+
"""
|
101 |
+
def __init__(
|
102 |
+
self,
|
103 |
+
reference_unet: UNet2DConditionModel,
|
104 |
+
denoising_unet: UNet3DConditionModel,
|
105 |
+
face_locator: FaceLocator,
|
106 |
+
reference_control_writer,
|
107 |
+
reference_control_reader,
|
108 |
+
imageproj,
|
109 |
+
audioproj,
|
110 |
+
):
|
111 |
+
super().__init__()
|
112 |
+
self.reference_unet = reference_unet
|
113 |
+
self.denoising_unet = denoising_unet
|
114 |
+
self.face_locator = face_locator
|
115 |
+
self.reference_control_writer = reference_control_writer
|
116 |
+
self.reference_control_reader = reference_control_reader
|
117 |
+
self.imageproj = imageproj
|
118 |
+
self.audioproj = audioproj
|
119 |
+
|
120 |
+
def forward(
|
121 |
+
self,
|
122 |
+
noisy_latents: torch.Tensor,
|
123 |
+
timesteps: torch.Tensor,
|
124 |
+
ref_image_latents: torch.Tensor,
|
125 |
+
face_emb: torch.Tensor,
|
126 |
+
audio_emb: torch.Tensor,
|
127 |
+
mask: torch.Tensor,
|
128 |
+
full_mask: torch.Tensor,
|
129 |
+
face_mask: torch.Tensor,
|
130 |
+
lip_mask: torch.Tensor,
|
131 |
+
uncond_img_fwd: bool = False,
|
132 |
+
uncond_audio_fwd: bool = False,
|
133 |
+
):
|
134 |
+
"""
|
135 |
+
simple docstring to prevent pylint error
|
136 |
+
"""
|
137 |
+
face_emb = self.imageproj(face_emb)
|
138 |
+
mask = mask.to(device=device)
|
139 |
+
mask_feature = self.face_locator(mask)
|
140 |
+
audio_emb = audio_emb.to(
|
141 |
+
device=self.audioproj.device, dtype=self.audioproj.dtype)
|
142 |
+
audio_emb = self.audioproj(audio_emb)
|
143 |
+
|
144 |
+
# condition forward
|
145 |
+
if not uncond_img_fwd:
|
146 |
+
ref_timesteps = torch.zeros_like(timesteps)
|
147 |
+
ref_timesteps = repeat(
|
148 |
+
ref_timesteps,
|
149 |
+
"b -> (repeat b)",
|
150 |
+
repeat=ref_image_latents.size(0) // ref_timesteps.size(0),
|
151 |
+
)
|
152 |
+
self.reference_unet(
|
153 |
+
ref_image_latents,
|
154 |
+
ref_timesteps,
|
155 |
+
encoder_hidden_states=face_emb,
|
156 |
+
return_dict=False,
|
157 |
+
)
|
158 |
+
self.reference_control_reader.update(self.reference_control_writer)
|
159 |
+
|
160 |
+
if uncond_audio_fwd:
|
161 |
+
audio_emb = torch.zeros_like(audio_emb).to(
|
162 |
+
device=audio_emb.device, dtype=audio_emb.dtype
|
163 |
+
)
|
164 |
+
|
165 |
+
model_pred = self.denoising_unet(
|
166 |
+
noisy_latents,
|
167 |
+
timesteps,
|
168 |
+
mask_cond_fea=mask_feature,
|
169 |
+
encoder_hidden_states=face_emb,
|
170 |
+
audio_embedding=audio_emb,
|
171 |
+
full_mask=full_mask,
|
172 |
+
face_mask=face_mask,
|
173 |
+
lip_mask=lip_mask
|
174 |
+
).sample
|
175 |
+
|
176 |
+
return model_pred
|
177 |
+
|
178 |
+
|
179 |
+
def get_attention_mask(mask: torch.Tensor, weight_dtype: torch.dtype) -> torch.Tensor:
|
180 |
+
"""
|
181 |
+
Rearrange the mask tensors to the required format.
|
182 |
+
|
183 |
+
Args:
|
184 |
+
mask (torch.Tensor): The input mask tensor.
|
185 |
+
weight_dtype (torch.dtype): The data type for the mask tensor.
|
186 |
+
|
187 |
+
Returns:
|
188 |
+
torch.Tensor: The rearranged mask tensor.
|
189 |
+
"""
|
190 |
+
if isinstance(mask, List):
|
191 |
+
_mask = []
|
192 |
+
for m in mask:
|
193 |
+
_mask.append(
|
194 |
+
rearrange(m, "b f 1 h w -> (b f) (h w)").to(weight_dtype))
|
195 |
+
return _mask
|
196 |
+
mask = rearrange(mask, "b f 1 h w -> (b f) (h w)").to(weight_dtype)
|
197 |
+
return mask
|
198 |
+
|
199 |
+
|
200 |
+
def get_noise_scheduler(cfg: argparse.Namespace) -> Tuple[DDIMScheduler, DDIMScheduler]:
|
201 |
+
"""
|
202 |
+
Create noise scheduler for training.
|
203 |
+
|
204 |
+
Args:
|
205 |
+
cfg (argparse.Namespace): Configuration object.
|
206 |
+
|
207 |
+
Returns:
|
208 |
+
Tuple[DDIMScheduler, DDIMScheduler]: Train noise scheduler and validation noise scheduler.
|
209 |
+
"""
|
210 |
+
|
211 |
+
sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs)
|
212 |
+
if cfg.enable_zero_snr:
|
213 |
+
sched_kwargs.update(
|
214 |
+
rescale_betas_zero_snr=True,
|
215 |
+
timestep_spacing="trailing",
|
216 |
+
prediction_type="v_prediction",
|
217 |
+
)
|
218 |
+
val_noise_scheduler = DDIMScheduler(**sched_kwargs)
|
219 |
+
sched_kwargs.update({"beta_schedule": "scaled_linear"})
|
220 |
+
train_noise_scheduler = DDIMScheduler(**sched_kwargs)
|
221 |
+
|
222 |
+
return train_noise_scheduler, val_noise_scheduler
|
223 |
+
|
224 |
+
|
225 |
+
def process_audio_emb(audio_emb: torch.Tensor) -> torch.Tensor:
|
226 |
+
"""
|
227 |
+
Process the audio embedding to concatenate with other tensors.
|
228 |
+
|
229 |
+
Parameters:
|
230 |
+
audio_emb (torch.Tensor): The audio embedding tensor to process.
|
231 |
+
|
232 |
+
Returns:
|
233 |
+
concatenated_tensors (List[torch.Tensor]): The concatenated tensor list.
|
234 |
+
"""
|
235 |
+
concatenated_tensors = []
|
236 |
+
|
237 |
+
for i in range(audio_emb.shape[0]):
|
238 |
+
vectors_to_concat = [
|
239 |
+
audio_emb[max(min(i + j, audio_emb.shape[0] - 1), 0)]for j in range(-2, 3)]
|
240 |
+
concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0))
|
241 |
+
|
242 |
+
audio_emb = torch.stack(concatenated_tensors, dim=0)
|
243 |
+
|
244 |
+
return audio_emb
|
245 |
+
|
246 |
+
|
247 |
+
def log_validation(
|
248 |
+
accelerator: Accelerator,
|
249 |
+
vae: AutoencoderKL,
|
250 |
+
net: Net,
|
251 |
+
scheduler: DDIMScheduler,
|
252 |
+
width: int,
|
253 |
+
height: int,
|
254 |
+
clip_length: int = 24,
|
255 |
+
generator: torch.Generator = None,
|
256 |
+
cfg: dict = None,
|
257 |
+
save_dir: str = None,
|
258 |
+
global_step: int = 0,
|
259 |
+
times: int = None,
|
260 |
+
face_analysis_model_path: str = "",
|
261 |
+
) -> None:
|
262 |
+
"""
|
263 |
+
Log validation video during the training process.
|
264 |
+
|
265 |
+
Args:
|
266 |
+
accelerator (Accelerator): The accelerator for distributed training.
|
267 |
+
vae (AutoencoderKL): The autoencoder model.
|
268 |
+
net (Net): The main neural network model.
|
269 |
+
scheduler (DDIMScheduler): The scheduler for noise.
|
270 |
+
width (int): The width of the input images.
|
271 |
+
height (int): The height of the input images.
|
272 |
+
clip_length (int): The length of the video clips. Defaults to 24.
|
273 |
+
generator (torch.Generator): The random number generator. Defaults to None.
|
274 |
+
cfg (dict): The configuration dictionary. Defaults to None.
|
275 |
+
save_dir (str): The directory to save validation results. Defaults to None.
|
276 |
+
global_step (int): The current global step in training. Defaults to 0.
|
277 |
+
times (int): The number of inference times. Defaults to None.
|
278 |
+
face_analysis_model_path (str): The path to the face analysis model. Defaults to "".
|
279 |
+
|
280 |
+
Returns:
|
281 |
+
torch.Tensor: The tensor result of the validation.
|
282 |
+
"""
|
283 |
+
ori_net = accelerator.unwrap_model(net)
|
284 |
+
reference_unet = ori_net.reference_unet
|
285 |
+
denoising_unet = ori_net.denoising_unet
|
286 |
+
face_locator = ori_net.face_locator
|
287 |
+
imageproj = ori_net.imageproj
|
288 |
+
audioproj = ori_net.audioproj
|
289 |
+
tmp_denoising_unet = copy.deepcopy(denoising_unet)
|
290 |
+
|
291 |
+
pipeline = FaceAnimatePipeline(
|
292 |
+
vae=vae,
|
293 |
+
reference_unet=reference_unet,
|
294 |
+
denoising_unet=tmp_denoising_unet,
|
295 |
+
face_locator=face_locator,
|
296 |
+
image_proj=imageproj,
|
297 |
+
scheduler=scheduler,
|
298 |
+
)
|
299 |
+
pipeline = pipeline.to(device)
|
300 |
+
|
301 |
+
image_processor = ImageProcessor((width, height), face_analysis_model_path)
|
302 |
+
audio_processor = AudioProcessor(
|
303 |
+
cfg.data.sample_rate,
|
304 |
+
cfg.data.fps,
|
305 |
+
cfg.wav2vec_config.model_path,
|
306 |
+
cfg.wav2vec_config.features == "last",
|
307 |
+
os.path.dirname(cfg.audio_separator.model_path),
|
308 |
+
os.path.basename(cfg.audio_separator.model_path),
|
309 |
+
os.path.join(save_dir, '.cache', "audio_preprocess"),
|
310 |
+
device=device,
|
311 |
+
)
|
312 |
+
return cfg, image_processor, audio_processor, pipeline, audioproj, save_dir, global_step, clip_length
|
313 |
+
|
314 |
+
|
315 |
+
def inference(cfg, image_processor, audio_processor, pipeline, audioproj, save_dir, global_step, clip_length):
|
316 |
+
ref_img_path = cfg.ref_img_path
|
317 |
+
audio_path = cfg.audio_path
|
318 |
+
source_image_pixels, \
|
319 |
+
source_image_face_region, \
|
320 |
+
source_image_face_emb, \
|
321 |
+
source_image_full_mask, \
|
322 |
+
source_image_face_mask, \
|
323 |
+
source_image_lip_mask = image_processor.preprocess(
|
324 |
+
ref_img_path, os.path.join(save_dir, '.cache'), cfg.face_expand_ratio)
|
325 |
+
audio_emb, audio_length = audio_processor.preprocess(
|
326 |
+
audio_path, clip_length)
|
327 |
+
|
328 |
+
audio_emb = process_audio_emb(audio_emb)
|
329 |
+
|
330 |
+
source_image_pixels = source_image_pixels.unsqueeze(0)
|
331 |
+
source_image_face_region = source_image_face_region.unsqueeze(0)
|
332 |
+
source_image_face_emb = source_image_face_emb.reshape(1, -1)
|
333 |
+
source_image_face_emb = torch.tensor(source_image_face_emb)
|
334 |
+
|
335 |
+
source_image_full_mask = [
|
336 |
+
(mask.repeat(clip_length, 1))
|
337 |
+
for mask in source_image_full_mask
|
338 |
+
]
|
339 |
+
source_image_face_mask = [
|
340 |
+
(mask.repeat(clip_length, 1))
|
341 |
+
for mask in source_image_face_mask
|
342 |
+
]
|
343 |
+
source_image_lip_mask = [
|
344 |
+
(mask.repeat(clip_length, 1))
|
345 |
+
for mask in source_image_lip_mask
|
346 |
+
]
|
347 |
+
|
348 |
+
times = audio_emb.shape[0] // clip_length
|
349 |
+
tensor_result = []
|
350 |
+
generator = torch.manual_seed(42)
|
351 |
+
for t in range(times):
|
352 |
+
print(f"[{t+1}/{times}]")
|
353 |
+
|
354 |
+
if len(tensor_result) == 0:
|
355 |
+
# The first iteration
|
356 |
+
motion_zeros = source_image_pixels.repeat(
|
357 |
+
cfg.data.n_motion_frames, 1, 1, 1)
|
358 |
+
motion_zeros = motion_zeros.to(
|
359 |
+
dtype=source_image_pixels.dtype, device=source_image_pixels.device)
|
360 |
+
pixel_values_ref_img = torch.cat(
|
361 |
+
[source_image_pixels, motion_zeros], dim=0) # concat the ref image and the first motion frames
|
362 |
+
else:
|
363 |
+
motion_frames = tensor_result[-1][0]
|
364 |
+
motion_frames = motion_frames.permute(1, 0, 2, 3)
|
365 |
+
motion_frames = motion_frames[0 - cfg.data.n_motion_frames:]
|
366 |
+
motion_frames = motion_frames * 2.0 - 1.0
|
367 |
+
motion_frames = motion_frames.to(
|
368 |
+
dtype=source_image_pixels.dtype, device=source_image_pixels.device)
|
369 |
+
pixel_values_ref_img = torch.cat(
|
370 |
+
[source_image_pixels, motion_frames], dim=0) # concat the ref image and the motion frames
|
371 |
+
|
372 |
+
pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
|
373 |
+
|
374 |
+
audio_tensor = audio_emb[
|
375 |
+
t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0])
|
376 |
+
]
|
377 |
+
audio_tensor = audio_tensor.unsqueeze(0)
|
378 |
+
audio_tensor = audio_tensor.to(
|
379 |
+
device=audioproj.device, dtype=audioproj.dtype)
|
380 |
+
audio_tensor = audioproj(audio_tensor)
|
381 |
+
|
382 |
+
pipeline_output = pipeline(
|
383 |
+
ref_image=pixel_values_ref_img,
|
384 |
+
audio_tensor=audio_tensor,
|
385 |
+
face_emb=source_image_face_emb,
|
386 |
+
face_mask=source_image_face_region,
|
387 |
+
pixel_values_full_mask=source_image_full_mask,
|
388 |
+
pixel_values_face_mask=source_image_face_mask,
|
389 |
+
pixel_values_lip_mask=source_image_lip_mask,
|
390 |
+
width=cfg.data.train_width,
|
391 |
+
height=cfg.data.train_height,
|
392 |
+
video_length=clip_length,
|
393 |
+
num_inference_steps=cfg.inference_steps,
|
394 |
+
guidance_scale=cfg.cfg_scale,
|
395 |
+
generator=generator,
|
396 |
+
)
|
397 |
+
|
398 |
+
tensor_result.append(pipeline_output.videos)
|
399 |
+
|
400 |
+
tensor_result = torch.cat(tensor_result, dim=2)
|
401 |
+
tensor_result = tensor_result.squeeze(0)
|
402 |
+
tensor_result = tensor_result[:, :audio_length]
|
403 |
+
output_file = cfg.output
|
404 |
+
tensor_to_video(tensor_result, output_file, audio_path)
|
405 |
+
return output_file
|
406 |
+
|
407 |
+
|
408 |
+
def get_model(cfg: argparse.Namespace) -> None:
|
409 |
+
"""
|
410 |
+
Trains the model using the given configuration (cfg).
|
411 |
+
|
412 |
+
Args:
|
413 |
+
cfg (dict): The configuration dictionary containing the parameters for training.
|
414 |
+
|
415 |
+
Notes:
|
416 |
+
- This function trains the model using the given configuration.
|
417 |
+
- It initializes the necessary components for training, such as the pipeline, optimizer, and scheduler.
|
418 |
+
- The training progress is logged and tracked using the accelerator.
|
419 |
+
- The trained model is saved after the training is completed.
|
420 |
+
"""
|
421 |
+
kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
|
422 |
+
accelerator = Accelerator(
|
423 |
+
gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
|
424 |
+
mixed_precision=cfg.solver.mixed_precision,
|
425 |
+
log_with="mlflow",
|
426 |
+
project_dir="./mlruns",
|
427 |
+
kwargs_handlers=[kwargs],
|
428 |
+
)
|
429 |
+
|
430 |
+
# Make one log on every process with the configuration for debugging.
|
431 |
+
logging.basicConfig(
|
432 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
433 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
434 |
+
level=logging.INFO,
|
435 |
+
)
|
436 |
+
logger.info(accelerator.state, main_process_only=False)
|
437 |
+
if accelerator.is_local_main_process:
|
438 |
+
transformers.utils.logging.set_verbosity_warning()
|
439 |
+
diffusers.utils.logging.set_verbosity_info()
|
440 |
+
else:
|
441 |
+
transformers.utils.logging.set_verbosity_error()
|
442 |
+
diffusers.utils.logging.set_verbosity_error()
|
443 |
+
|
444 |
+
# If passed along, set the training seed now.
|
445 |
+
if cfg.seed is not None:
|
446 |
+
seed_everything(cfg.seed)
|
447 |
+
|
448 |
+
# create output dir for training
|
449 |
+
exp_name = cfg.exp_name
|
450 |
+
save_dir = f"{cfg.output_dir}/{exp_name}"
|
451 |
+
validation_dir = save_dir
|
452 |
+
if accelerator.is_main_process:
|
453 |
+
init_output_dir([save_dir])
|
454 |
+
|
455 |
+
accelerator.wait_for_everyone()
|
456 |
+
|
457 |
+
if cfg.weight_dtype == "fp16":
|
458 |
+
weight_dtype = torch.float16
|
459 |
+
elif cfg.weight_dtype == "bf16":
|
460 |
+
weight_dtype = torch.bfloat16
|
461 |
+
elif cfg.weight_dtype == "fp32":
|
462 |
+
weight_dtype = torch.float32
|
463 |
+
else:
|
464 |
+
raise ValueError(
|
465 |
+
f"Do not support weight dtype: {cfg.weight_dtype} during training"
|
466 |
+
)
|
467 |
+
|
468 |
+
if not torch.cuda.is_available():
|
469 |
+
weight_dtype = torch.float32
|
470 |
+
|
471 |
+
# Create Models
|
472 |
+
vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to(
|
473 |
+
device=device, dtype=weight_dtype
|
474 |
+
)
|
475 |
+
reference_unet = UNet2DConditionModel.from_pretrained(
|
476 |
+
cfg.base_model_path,
|
477 |
+
subfolder="unet",
|
478 |
+
).to(device=device, dtype=weight_dtype)
|
479 |
+
denoising_unet = UNet3DConditionModel.from_pretrained_2d(
|
480 |
+
cfg.base_model_path,
|
481 |
+
cfg.mm_path,
|
482 |
+
subfolder="unet",
|
483 |
+
unet_additional_kwargs=OmegaConf.to_container(
|
484 |
+
cfg.unet_additional_kwargs),
|
485 |
+
use_landmark=False
|
486 |
+
).to(device=device, dtype=weight_dtype)
|
487 |
+
imageproj = ImageProjModel(
|
488 |
+
cross_attention_dim=denoising_unet.config.cross_attention_dim,
|
489 |
+
clip_embeddings_dim=512,
|
490 |
+
clip_extra_context_tokens=4,
|
491 |
+
).to(device=device, dtype=weight_dtype)
|
492 |
+
face_locator = FaceLocator(
|
493 |
+
conditioning_embedding_channels=320,
|
494 |
+
).to(device=device, dtype=weight_dtype)
|
495 |
+
audioproj = AudioProjModel(
|
496 |
+
seq_len=5,
|
497 |
+
blocks=12,
|
498 |
+
channels=768,
|
499 |
+
intermediate_dim=512,
|
500 |
+
output_dim=768,
|
501 |
+
context_tokens=32,
|
502 |
+
).to(device=device, dtype=weight_dtype)
|
503 |
+
|
504 |
+
# Freeze
|
505 |
+
vae.requires_grad_(False)
|
506 |
+
imageproj.requires_grad_(False)
|
507 |
+
reference_unet.requires_grad_(False)
|
508 |
+
denoising_unet.requires_grad_(False)
|
509 |
+
face_locator.requires_grad_(False)
|
510 |
+
audioproj.requires_grad_(True)
|
511 |
+
|
512 |
+
# Set motion module learnable
|
513 |
+
trainable_modules = cfg.trainable_para
|
514 |
+
for name, module in denoising_unet.named_modules():
|
515 |
+
if any(trainable_mod in name for trainable_mod in trainable_modules):
|
516 |
+
for params in module.parameters():
|
517 |
+
params.requires_grad_(True)
|
518 |
+
|
519 |
+
reference_control_writer = ReferenceAttentionControl(
|
520 |
+
reference_unet,
|
521 |
+
do_classifier_free_guidance=False,
|
522 |
+
mode="write",
|
523 |
+
fusion_blocks="full",
|
524 |
+
)
|
525 |
+
reference_control_reader = ReferenceAttentionControl(
|
526 |
+
denoising_unet,
|
527 |
+
do_classifier_free_guidance=False,
|
528 |
+
mode="read",
|
529 |
+
fusion_blocks="full",
|
530 |
+
)
|
531 |
+
|
532 |
+
net = Net(
|
533 |
+
reference_unet,
|
534 |
+
denoising_unet,
|
535 |
+
face_locator,
|
536 |
+
reference_control_writer,
|
537 |
+
reference_control_reader,
|
538 |
+
imageproj,
|
539 |
+
audioproj,
|
540 |
+
).to(dtype=weight_dtype)
|
541 |
+
|
542 |
+
m,u = net.load_state_dict(
|
543 |
+
torch.load(
|
544 |
+
cfg.audio_ckpt_dir,
|
545 |
+
map_location="cpu",
|
546 |
+
),
|
547 |
+
)
|
548 |
+
assert len(m) == 0 and len(u) == 0, "Fail to load correct checkpoint."
|
549 |
+
print("loaded weight from ", os.path.join(cfg.audio_ckpt_dir))
|
550 |
+
|
551 |
+
# get noise scheduler
|
552 |
+
_, val_noise_scheduler = get_noise_scheduler(cfg)
|
553 |
+
|
554 |
+
if cfg.solver.enable_xformers_memory_efficient_attention and torch.cuda.is_available():
|
555 |
+
if is_xformers_available():
|
556 |
+
reference_unet.enable_xformers_memory_efficient_attention()
|
557 |
+
denoising_unet.enable_xformers_memory_efficient_attention()
|
558 |
+
|
559 |
+
else:
|
560 |
+
raise ValueError(
|
561 |
+
"xformers is not available. Make sure it is installed correctly"
|
562 |
+
)
|
563 |
+
|
564 |
+
if cfg.solver.gradient_checkpointing:
|
565 |
+
reference_unet.enable_gradient_checkpointing()
|
566 |
+
denoising_unet.enable_gradient_checkpointing()
|
567 |
+
|
568 |
+
if cfg.solver.scale_lr:
|
569 |
+
learning_rate = (
|
570 |
+
cfg.solver.learning_rate
|
571 |
+
* cfg.solver.gradient_accumulation_steps
|
572 |
+
* cfg.data.train_bs
|
573 |
+
* accelerator.num_processes
|
574 |
+
)
|
575 |
+
else:
|
576 |
+
learning_rate = cfg.solver.learning_rate
|
577 |
+
|
578 |
+
# Initialize the optimizer
|
579 |
+
optimizer_cls = torch.optim.AdamW
|
580 |
+
|
581 |
+
trainable_params = list(
|
582 |
+
filter(lambda p: p.requires_grad, net.parameters()))
|
583 |
+
|
584 |
+
optimizer = optimizer_cls(
|
585 |
+
trainable_params,
|
586 |
+
lr=learning_rate,
|
587 |
+
betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
|
588 |
+
weight_decay=cfg.solver.adam_weight_decay,
|
589 |
+
eps=cfg.solver.adam_epsilon,
|
590 |
+
)
|
591 |
+
|
592 |
+
# Scheduler
|
593 |
+
lr_scheduler = get_scheduler(
|
594 |
+
cfg.solver.lr_scheduler,
|
595 |
+
optimizer=optimizer,
|
596 |
+
num_warmup_steps=cfg.solver.lr_warmup_steps
|
597 |
+
* cfg.solver.gradient_accumulation_steps,
|
598 |
+
num_training_steps=cfg.solver.max_train_steps
|
599 |
+
* cfg.solver.gradient_accumulation_steps,
|
600 |
+
)
|
601 |
+
|
602 |
+
# get data loader
|
603 |
+
train_dataset = TalkingVideoDataset(
|
604 |
+
img_size=(cfg.data.train_width, cfg.data.train_height),
|
605 |
+
sample_rate=cfg.data.sample_rate,
|
606 |
+
n_sample_frames=cfg.data.n_sample_frames,
|
607 |
+
n_motion_frames=cfg.data.n_motion_frames,
|
608 |
+
audio_margin=cfg.data.audio_margin,
|
609 |
+
data_meta_paths=cfg.data.train_meta_paths,
|
610 |
+
wav2vec_cfg=cfg.wav2vec_config,
|
611 |
+
)
|
612 |
+
train_dataloader = torch.utils.data.DataLoader(
|
613 |
+
train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=16
|
614 |
+
)
|
615 |
+
|
616 |
+
# Prepare everything with our `accelerator`.
|
617 |
+
(
|
618 |
+
net,
|
619 |
+
optimizer,
|
620 |
+
train_dataloader,
|
621 |
+
lr_scheduler,
|
622 |
+
) = accelerator.prepare(
|
623 |
+
net,
|
624 |
+
optimizer,
|
625 |
+
train_dataloader,
|
626 |
+
lr_scheduler,
|
627 |
+
)
|
628 |
+
|
629 |
+
return accelerator, vae, net, val_noise_scheduler, cfg, validation_dir
|
630 |
+
|
631 |
+
|
632 |
+
def load_config(config_path: str) -> dict:
|
633 |
+
"""
|
634 |
+
Loads the configuration file.
|
635 |
+
|
636 |
+
Args:
|
637 |
+
config_path (str): Path to the configuration file.
|
638 |
+
|
639 |
+
Returns:
|
640 |
+
dict: The configuration dictionary.
|
641 |
+
"""
|
642 |
+
|
643 |
+
if config_path.endswith(".yaml"):
|
644 |
+
return OmegaConf.load(config_path)
|
645 |
+
if config_path.endswith(".py"):
|
646 |
+
return import_filename(config_path).cfg
|
647 |
+
raise ValueError("Unsupported format for config file")
|
648 |
+
|
649 |
+
args = argparse.Namespace()
|
650 |
+
_config = load_config('configs/inference/inference.yaml')
|
651 |
+
for key, value in _config.items():
|
652 |
+
setattr(args, key, value)
|
653 |
+
accelerator, vae, net, val_noise_scheduler, cfg, validation_dir = get_model(args)
|
654 |
+
cfg, image_processor, audio_processor, pipeline, audioproj, save_dir, global_step, clip_length = log_validation(
|
655 |
+
accelerator=accelerator,
|
656 |
+
vae=vae,
|
657 |
+
net=net,
|
658 |
+
scheduler=val_noise_scheduler,
|
659 |
+
width=cfg.data.train_width,
|
660 |
+
height=cfg.data.train_height,
|
661 |
+
clip_length=cfg.data.n_sample_frames,
|
662 |
+
cfg=cfg,
|
663 |
+
save_dir=validation_dir,
|
664 |
+
global_step=0,
|
665 |
+
times=cfg.single_inference_times if cfg.single_inference_times is not None else None,
|
666 |
+
face_analysis_model_path=cfg.face_analysis_model_path
|
667 |
+
)
|
668 |
+
|
669 |
+
def predict(image, audio, pose_weight, face_weight, lip_weight, face_expand_ratio, progress=gr.Progress(track_tqdm=True)):
|
670 |
+
"""
|
671 |
+
Create a gradio interface with the configs.
|
672 |
+
"""
|
673 |
+
_ = progress
|
674 |
+
unique_id = uuid.uuid4()
|
675 |
+
config = {
|
676 |
+
'ref_img_path': image,
|
677 |
+
'audio_path': audio,
|
678 |
+
'pose_weight': pose_weight,
|
679 |
+
'face_weight': face_weight,
|
680 |
+
'lip_weight': lip_weight,
|
681 |
+
'face_expand_ratio': face_expand_ratio,
|
682 |
+
'config': 'configs/inference/inference.yaml',
|
683 |
+
'checkpoint': None,
|
684 |
+
'output': f'output-{unique_id}.mp4'
|
685 |
+
}
|
686 |
+
global cfg, image_processor, audio_processor, pipeline, audioproj, save_dir, global_step, clip_length
|
687 |
+
for key, value in config.items():
|
688 |
+
setattr(cfg, key, value)
|
689 |
+
|
690 |
+
return inference(cfg, image_processor, audio_processor, pipeline, audioproj, save_dir, global_step, clip_length)
|