Spaces:
Running
on
Zero
Running
on
Zero
xierui.0097
commited on
Commit
·
f0e9666
1
Parent(s):
470b11c
Add application file
Browse files- README.md +106 -13
- __pycache__/inference_utils.cpython-39.pyc +0 -0
- inference_utils.py +148 -0
- requirements.txt +15 -0
- video_super_resolution/__pycache__/color_fix.cpython-39.pyc +0 -0
- video_super_resolution/color_fix.py +122 -0
- video_super_resolution/dataset.py +113 -0
- video_super_resolution/scripts/inference_sr.py +140 -0
- video_super_resolution/scripts/inference_sr.sh +56 -0
- video_to_video/__init__.py +0 -0
- video_to_video/__pycache__/__init__.cpython-39.pyc +0 -0
- video_to_video/__pycache__/video_to_video_model.cpython-39.pyc +0 -0
- video_to_video/diffusion/__init__.py +0 -0
- video_to_video/diffusion/__pycache__/__init__.cpython-39.pyc +0 -0
- video_to_video/diffusion/__pycache__/diffusion_sdedit.cpython-39.pyc +0 -0
- video_to_video/diffusion/__pycache__/schedules_sdedit.cpython-39.pyc +0 -0
- video_to_video/diffusion/__pycache__/solvers_sdedit.cpython-39.pyc +0 -0
- video_to_video/diffusion/diffusion_sdedit.py +443 -0
- video_to_video/diffusion/schedules_sdedit.py +85 -0
- video_to_video/diffusion/solvers_sdedit.py +204 -0
- video_to_video/modules/__init__.py +3 -0
- video_to_video/modules/__pycache__/__init__.cpython-39.pyc +0 -0
- video_to_video/modules/__pycache__/embedder.cpython-39.pyc +0 -0
- video_to_video/modules/__pycache__/t5.cpython-39.pyc +0 -0
- video_to_video/modules/__pycache__/unet_v2v.cpython-39.pyc +0 -0
- video_to_video/modules/__pycache__/unet_v2v_LocalConv.cpython-39.pyc +0 -0
- video_to_video/modules/__pycache__/unet_v2v_deform.cpython-39.pyc +0 -0
- video_to_video/modules/embedder.py +75 -0
- video_to_video/modules/t5.py +335 -0
- video_to_video/modules/unet_v2v.py +2332 -0
- video_to_video/utils/__init__.py +0 -0
- video_to_video/utils/__pycache__/__init__.cpython-39.pyc +0 -0
- video_to_video/utils/__pycache__/config.cpython-39.pyc +0 -0
- video_to_video/utils/__pycache__/logger.cpython-39.pyc +0 -0
- video_to_video/utils/__pycache__/seed.cpython-39.pyc +0 -0
- video_to_video/utils/config.py +169 -0
- video_to_video/utils/logger.py +94 -0
- video_to_video/utils/seed.py +14 -0
- video_to_video/video_to_video_model.py +210 -0
README.md
CHANGED
@@ -1,13 +1,106 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center">
|
2 |
+
<h1>
|
3 |
+
STAR: Spatial-Temporal Augmentation with Text-to-Video Models for Real-World Video Super-Resolution
|
4 |
+
</h1>
|
5 |
+
<div>
|
6 |
+
<a href='https://github.com/CSRuiXie' target='_blank'>Rui Xie<sup>1*</sup></a>, 
|
7 |
+
<a href='https://github.com/yhliu04' target='_blank'>Yinhong Liu<sup>1*</sup></a>, 
|
8 |
+
<a href='https://scholar.google.com/citations?user=Uhp3JKgAAAAJ&hl=zh-CN&oi=sra' target='_blank'>Chen Zhao<sup>1</sup></a>, 
|
9 |
+
<a href='https://scholar.google.com/citations?hl=zh-CN&user=yWq1Fd4AAAAJ' target='_blank'>Penghao Zhou<sup>2</sup></a>, 
|
10 |
+
<a href='https://scholar.google.com/citations?hl=zh-CN&user=Ds5wwRoAAAAJ' target='_blank'>Zhenheng Yang<sup>2</sup></a><br>
|
11 |
+
<a href='https://scholar.google.com/citations?hl=zh-CN&user=w03CHFwAAAAJ' target='_blank'>Jun Zhou<sup>3</sup></a>, 
|
12 |
+
<a href='https://cszn.github.io/' target='_blank'>Kai Zhang<sup>1</sup></a>, 
|
13 |
+
<a href='https://jessezhang92.github.io/' target='_blank'>Zhenyu Zhang<sup>1</sup></a>, 
|
14 |
+
<a href='https://scholar.google.com.hk/citations?user=6CIDtZQAAAAJ&hl=zh-CN' target='_blank'>Jian Yang<sup>1</sup></a>, 
|
15 |
+
<a href='https://tyshiwo.github.io/index.html' target='_blank'>Ying Tai<sup>1†</sup></a>
|
16 |
+
</div>
|
17 |
+
<div>
|
18 |
+
<sup>1</sup>Nanjing University, <sup>2</sup>ByteDance,  <sup>3</sup>Southwest University
|
19 |
+
</div>
|
20 |
+
<div>
|
21 |
+
<h4 align="center">
|
22 |
+
<a href="https://nju-pcalab.github.io/projects/STAR" target='_blank'>
|
23 |
+
<img src="https://img.shields.io/badge/🌟-Project%20Page-blue">
|
24 |
+
</a>
|
25 |
+
<a href="https://arxiv.org/abs/2407.07667" target='_blank'>
|
26 |
+
<img src="https://img.shields.io/badge/arXiv-2312.06640-b31b1b.svg">
|
27 |
+
</a>
|
28 |
+
<a href="https://youtu.be/hx0zrql-SrU" target='_blank'>
|
29 |
+
<img src="https://img.shields.io/badge/Demo%20Video-%23FF0000.svg?logo=YouTube&logoColor=white">
|
30 |
+
</a>
|
31 |
+
</h4>
|
32 |
+
</div>
|
33 |
+
</div>
|
34 |
+
|
35 |
+
|
36 |
+
### 🔆 Updates
|
37 |
+
- **2024.12.01** The pretrained STAR model (I2VGen-XL version) and inference code have been released.
|
38 |
+
|
39 |
+
|
40 |
+
## 🔎 Method Overview
|
41 |
+
![STAR](assets/overview.png)
|
42 |
+
|
43 |
+
|
44 |
+
## 📷 Results Display
|
45 |
+
![STAR](assets/teaser.png)
|
46 |
+
![STAR](assets/real_world.png)
|
47 |
+
👀 More visual results can be found in our [Project Page](https://nju-pcalab.github.io/projects/STAR) and [Video Demo](https://youtu.be/hx0zrql-SrU).
|
48 |
+
|
49 |
+
|
50 |
+
## ⚙️ Dependencies and Installation
|
51 |
+
```
|
52 |
+
## git clone this repository
|
53 |
+
git clone https://github.com/NJU-PCALab/STAR.git
|
54 |
+
cd STAR
|
55 |
+
|
56 |
+
## create an environment
|
57 |
+
conda create -n star python=3.10
|
58 |
+
conda activate star
|
59 |
+
pip install -r requirements.txt
|
60 |
+
sudo apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
|
61 |
+
```
|
62 |
+
|
63 |
+
## 🚀 Inference
|
64 |
+
#### Step 1: Download the pretrained model STAR from [HuggingFace](https://huggingface.co/SherryX/STAR).
|
65 |
+
We provide two verisions, `heavy_deg.pt` for heavy degraded videos and `light_deg.pt` for light degraded videos (e.g., the low-resolution video downloaded from video websites).
|
66 |
+
|
67 |
+
You can put the weight into `pretrained_weight/`.
|
68 |
+
|
69 |
+
|
70 |
+
#### Step 2: Prepare testing data
|
71 |
+
You can put the testing videos in the `input/video/`.
|
72 |
+
|
73 |
+
As for the prompt, there are three options: 1. No prompt. 2. Automatically generate a prompt [using Pllava](https://github.com/hpcaitech/Open-Sora/tree/main/tools/caption#pllava-captioning). 3. Manually write the prompt. You can put the txt file in the `input/text/`.
|
74 |
+
|
75 |
+
|
76 |
+
#### Step 3: Change the path
|
77 |
+
You need to change the paths in `video_super_resolution/scripts/inference_sr.sh` to your local corresponding paths, including `video_folder_path`, `txt_file_path`, `model_path`, and `save_dir`.
|
78 |
+
|
79 |
+
|
80 |
+
#### Step 4: Running inference command
|
81 |
+
```
|
82 |
+
bash video_super_resolution/scripts/inference_sr.sh
|
83 |
+
```
|
84 |
+
|
85 |
+
|
86 |
+
## ❤️ Acknowledgments
|
87 |
+
This project is based on [I2VGen-XL](https://github.com/ali-vilab/VGen), [VEnhancer](https://github.com/Vchitect/VEnhancer) and [CogVideoX](https://github.com/THUDM/CogVideo). Thanks for their awesome works.
|
88 |
+
|
89 |
+
|
90 |
+
## 🎓Citations
|
91 |
+
If our project helps your research or work, please consider citing our paper:
|
92 |
+
|
93 |
+
```
|
94 |
+
@misc{xie2024addsr,
|
95 |
+
title={AddSR: Accelerating Diffusion-based Blind Super-Resolution with Adversarial Diffusion Distillation},
|
96 |
+
author={Rui Xie and Ying Tai and Kai Zhang and Zhenyu Zhang and Jun Zhou and Jian Yang},
|
97 |
+
year={2024},
|
98 |
+
eprint={2404.01717},
|
99 |
+
archivePrefix={arXiv},
|
100 |
+
primaryClass={cs.CV}
|
101 |
+
}
|
102 |
+
```
|
103 |
+
|
104 |
+
|
105 |
+
## 📧 Contact
|
106 |
+
If you have any inquiries, please don't hesitate to reach out via email at `[email protected]`
|
__pycache__/inference_utils.cpython-39.pyc
ADDED
Binary file (5.07 kB). View file
|
|
inference_utils.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import subprocess
|
3 |
+
import tempfile
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
from typing import Mapping
|
8 |
+
from einops import rearrange
|
9 |
+
import numpy as np
|
10 |
+
import torchvision.transforms.functional as transforms_F
|
11 |
+
from video_to_video.utils.logger import get_logger
|
12 |
+
|
13 |
+
logger = get_logger()
|
14 |
+
|
15 |
+
|
16 |
+
def tensor2vid(video, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
|
17 |
+
mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
|
18 |
+
std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
|
19 |
+
video = video.mul_(std).add_(mean)
|
20 |
+
video.clamp_(0, 1)
|
21 |
+
video = video * 255.0
|
22 |
+
images = rearrange(video, 'b c f h w -> b f h w c')[0]
|
23 |
+
return images
|
24 |
+
|
25 |
+
|
26 |
+
def preprocess(input_frames):
|
27 |
+
out_frame_list = []
|
28 |
+
for pointer in range(len(input_frames)):
|
29 |
+
frame = input_frames[pointer]
|
30 |
+
frame = frame[:, :, ::-1]
|
31 |
+
frame = Image.fromarray(frame.astype('uint8')).convert('RGB')
|
32 |
+
frame = transforms_F.to_tensor(frame)
|
33 |
+
out_frame_list.append(frame)
|
34 |
+
out_frames = torch.stack(out_frame_list, dim=0)
|
35 |
+
out_frames.clamp_(0, 1)
|
36 |
+
mean = out_frames.new_tensor([0.5, 0.5, 0.5]).view(-1)
|
37 |
+
std = out_frames.new_tensor([0.5, 0.5, 0.5]).view(-1)
|
38 |
+
out_frames.sub_(mean.view(1, -1, 1, 1)).div_(std.view(1, -1, 1, 1))
|
39 |
+
return out_frames
|
40 |
+
|
41 |
+
|
42 |
+
def adjust_resolution(h, w, up_scale):
|
43 |
+
if h*up_scale < 720:
|
44 |
+
up_s = 720/h
|
45 |
+
target_h = int(up_s*h//2*2)
|
46 |
+
target_w = int(up_s*w//2*2)
|
47 |
+
elif h*w*up_scale*up_scale > 1280*2048:
|
48 |
+
up_s = np.sqrt(1280*2048/(h*w))
|
49 |
+
target_h = int(up_s*h//2*2)
|
50 |
+
target_w = int(up_s*w//2*2)
|
51 |
+
else:
|
52 |
+
target_h = int(up_scale*h//2*2)
|
53 |
+
target_w = int(up_scale*w//2*2)
|
54 |
+
return (target_h, target_w)
|
55 |
+
|
56 |
+
|
57 |
+
def make_mask_cond(in_f_num, interp_f_num):
|
58 |
+
mask_cond = []
|
59 |
+
interp_cond = [-1 for _ in range(interp_f_num)]
|
60 |
+
for i in range(in_f_num):
|
61 |
+
mask_cond.append(i)
|
62 |
+
if i != in_f_num - 1:
|
63 |
+
mask_cond += interp_cond
|
64 |
+
return mask_cond
|
65 |
+
|
66 |
+
|
67 |
+
def load_video(vid_path):
|
68 |
+
capture = cv2.VideoCapture(vid_path)
|
69 |
+
_fps = capture.get(cv2.CAP_PROP_FPS)
|
70 |
+
_total_frame_num = capture.get(cv2.CAP_PROP_FRAME_COUNT)
|
71 |
+
pointer = 0
|
72 |
+
frame_list = []
|
73 |
+
stride = 1
|
74 |
+
while len(frame_list) < _total_frame_num:
|
75 |
+
ret, frame = capture.read()
|
76 |
+
pointer += 1
|
77 |
+
if (not ret) or (frame is None):
|
78 |
+
break
|
79 |
+
if pointer >= _total_frame_num + 1:
|
80 |
+
break
|
81 |
+
if pointer % stride == 0:
|
82 |
+
frame_list.append(frame)
|
83 |
+
capture.release()
|
84 |
+
return frame_list, _fps
|
85 |
+
|
86 |
+
|
87 |
+
def save_video(video, save_dir, file_name, fps=16.0):
|
88 |
+
output_path = os.path.join(save_dir, file_name)
|
89 |
+
images = [(img.numpy()).astype('uint8') for img in video]
|
90 |
+
temp_dir = tempfile.mkdtemp()
|
91 |
+
|
92 |
+
for fid, frame in enumerate(images):
|
93 |
+
tpth = os.path.join(temp_dir, '%06d.png' % (fid + 1))
|
94 |
+
cv2.imwrite(tpth, frame[:, :, ::-1])
|
95 |
+
|
96 |
+
tmp_path = os.path.join(save_dir, 'tmp.mp4')
|
97 |
+
cmd = f'ffmpeg -y -f image2 -framerate {fps} -i {temp_dir}/%06d.png \
|
98 |
+
-vcodec libx264 -preset ultrafast -crf 0 -pix_fmt yuv420p {tmp_path}'
|
99 |
+
|
100 |
+
status, output = subprocess.getstatusoutput(cmd)
|
101 |
+
if status != 0:
|
102 |
+
logger.error('Save Video Error with {}'.format(output))
|
103 |
+
|
104 |
+
os.system(f'rm -rf {temp_dir}')
|
105 |
+
os.rename(tmp_path, output_path)
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
def collate_fn(data, device):
|
110 |
+
"""Prepare the input just before the forward function.
|
111 |
+
This method will move the tensors to the right device.
|
112 |
+
Usually this method does not need to be overridden.
|
113 |
+
|
114 |
+
Args:
|
115 |
+
data: The data out of the dataloader.
|
116 |
+
device: The device to move data to.
|
117 |
+
|
118 |
+
Returns: The processed data.
|
119 |
+
|
120 |
+
"""
|
121 |
+
from torch.utils.data.dataloader import default_collate
|
122 |
+
|
123 |
+
def get_class_name(obj):
|
124 |
+
return obj.__class__.__name__
|
125 |
+
|
126 |
+
if isinstance(data, dict) or isinstance(data, Mapping):
|
127 |
+
return type(data)({
|
128 |
+
k: collate_fn(v, device) if k != 'img_metas' else v
|
129 |
+
for k, v in data.items()
|
130 |
+
})
|
131 |
+
elif isinstance(data, (tuple, list)):
|
132 |
+
if 0 == len(data):
|
133 |
+
return torch.Tensor([])
|
134 |
+
if isinstance(data[0], (int, float)):
|
135 |
+
return default_collate(data).to(device)
|
136 |
+
else:
|
137 |
+
return type(data)(collate_fn(v, device) for v in data)
|
138 |
+
elif isinstance(data, np.ndarray):
|
139 |
+
if data.dtype.type is np.str_:
|
140 |
+
return data
|
141 |
+
else:
|
142 |
+
return collate_fn(torch.from_numpy(data), device)
|
143 |
+
elif isinstance(data, torch.Tensor):
|
144 |
+
return data.to(device)
|
145 |
+
elif isinstance(data, (bytes, str, int, float, bool, type(None))):
|
146 |
+
return data
|
147 |
+
else:
|
148 |
+
raise ValueError(f'Unsupported data type {type(data)}')
|
requirements.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.0.1
|
2 |
+
torchvision==0.15.2
|
3 |
+
torchaudio==2.0.2
|
4 |
+
opencv-python==4.10.0.84
|
5 |
+
easydict==1.13
|
6 |
+
einops==0.8.0
|
7 |
+
open-clip-torch==2.20.0
|
8 |
+
xformers==0.0.21
|
9 |
+
fairscale==0.4.13
|
10 |
+
torchsde==0.2.6
|
11 |
+
pytorch-lightning==2.0.1
|
12 |
+
diffusers==0.30.0
|
13 |
+
huggingface_hub==0.23.3
|
14 |
+
gradio==4.41.0
|
15 |
+
numpy==1.24
|
video_super_resolution/__pycache__/color_fix.cpython-39.pyc
ADDED
Binary file (4.01 kB). View file
|
|
video_super_resolution/color_fix.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
# --------------------------------------------------------------------------------
|
3 |
+
# Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
|
4 |
+
# --------------------------------------------------------------------------------
|
5 |
+
'''
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from PIL import Image
|
9 |
+
from torch import Tensor
|
10 |
+
from torch.nn import functional as F
|
11 |
+
|
12 |
+
from torchvision.transforms import ToTensor, ToPILImage
|
13 |
+
from einops import rearrange
|
14 |
+
|
15 |
+
def adain_color_fix(target: Image, source: Image):
|
16 |
+
# Convert images to tensors
|
17 |
+
target = rearrange(target, 'T H W C -> T C H W') / 255
|
18 |
+
source = (source + 1) / 2
|
19 |
+
|
20 |
+
# Apply adaptive instance normalization
|
21 |
+
result_tensor_list = []
|
22 |
+
for i in range(0, target.shape[0]):
|
23 |
+
result_tensor_list.append(adaptive_instance_normalization(target[i].unsqueeze(0), source[i].unsqueeze(0)))
|
24 |
+
|
25 |
+
# Convert tensor back to image
|
26 |
+
result_tensor = torch.cat(result_tensor_list, dim=0).clamp_(0.0, 1.0)
|
27 |
+
result_video = rearrange(result_tensor, "T C H W -> T H W C") * 255
|
28 |
+
|
29 |
+
return result_video
|
30 |
+
|
31 |
+
def wavelet_color_fix(target, source):
|
32 |
+
# Convert images to tensors
|
33 |
+
target = rearrange(target, 'T H W C -> T C H W') / 255
|
34 |
+
source = (source + 1) / 2
|
35 |
+
|
36 |
+
# Apply wavelet reconstruction
|
37 |
+
result_tensor_list = []
|
38 |
+
for i in range(0, target.shape[0]):
|
39 |
+
result_tensor_list.append(wavelet_reconstruction(target[i].unsqueeze(0), source[i].unsqueeze(0)))
|
40 |
+
|
41 |
+
# Convert tensor back to image
|
42 |
+
result_tensor = torch.cat(result_tensor_list, dim=0).clamp_(0.0, 1.0)
|
43 |
+
result_video = rearrange(result_tensor, "T C H W -> T H W C") * 255
|
44 |
+
|
45 |
+
return result_video
|
46 |
+
|
47 |
+
def calc_mean_std(feat: Tensor, eps=1e-5):
|
48 |
+
"""Calculate mean and std for adaptive_instance_normalization.
|
49 |
+
Args:
|
50 |
+
feat (Tensor): 4D tensor.
|
51 |
+
eps (float): A small value added to the variance to avoid
|
52 |
+
divide-by-zero. Default: 1e-5.
|
53 |
+
"""
|
54 |
+
size = feat.size()
|
55 |
+
assert len(size) == 4, 'The input feature should be 4D tensor.'
|
56 |
+
b, c = size[:2]
|
57 |
+
feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
|
58 |
+
feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
|
59 |
+
feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
|
60 |
+
return feat_mean, feat_std
|
61 |
+
|
62 |
+
def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
|
63 |
+
"""Adaptive instance normalization.
|
64 |
+
Adjust the reference features to have the similar color and illuminations
|
65 |
+
as those in the degradate features.
|
66 |
+
Args:
|
67 |
+
content_feat (Tensor): The reference feature.
|
68 |
+
style_feat (Tensor): The degradate features.
|
69 |
+
"""
|
70 |
+
size = content_feat.size()
|
71 |
+
style_mean, style_std = calc_mean_std(style_feat)
|
72 |
+
content_mean, content_std = calc_mean_std(content_feat)
|
73 |
+
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
74 |
+
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
75 |
+
|
76 |
+
def wavelet_blur(image: Tensor, radius: int):
|
77 |
+
"""
|
78 |
+
Apply wavelet blur to the input tensor.
|
79 |
+
"""
|
80 |
+
# input shape: (1, 3, H, W)
|
81 |
+
# convolution kernel
|
82 |
+
kernel_vals = [
|
83 |
+
[0.0625, 0.125, 0.0625],
|
84 |
+
[0.125, 0.25, 0.125],
|
85 |
+
[0.0625, 0.125, 0.0625],
|
86 |
+
]
|
87 |
+
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
|
88 |
+
# add channel dimensions to the kernel to make it a 4D tensor
|
89 |
+
kernel = kernel[None, None]
|
90 |
+
# repeat the kernel across all input channels
|
91 |
+
kernel = kernel.repeat(3, 1, 1, 1)
|
92 |
+
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
|
93 |
+
# apply convolution
|
94 |
+
output = F.conv2d(image, kernel, groups=3, dilation=radius)
|
95 |
+
return output
|
96 |
+
|
97 |
+
def wavelet_decomposition(image: Tensor, levels=5):
|
98 |
+
"""
|
99 |
+
Apply wavelet decomposition to the input tensor.
|
100 |
+
This function only returns the low frequency & the high frequency.
|
101 |
+
"""
|
102 |
+
high_freq = torch.zeros_like(image)
|
103 |
+
for i in range(levels):
|
104 |
+
radius = 2 ** i
|
105 |
+
low_freq = wavelet_blur(image, radius)
|
106 |
+
high_freq += (image - low_freq)
|
107 |
+
image = low_freq
|
108 |
+
|
109 |
+
return high_freq, low_freq
|
110 |
+
|
111 |
+
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
|
112 |
+
"""
|
113 |
+
Apply wavelet decomposition, so that the content will have the same color as the style.
|
114 |
+
"""
|
115 |
+
# calculate the wavelet decomposition of the content feature
|
116 |
+
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
|
117 |
+
del content_low_freq
|
118 |
+
# calculate the wavelet decomposition of the style feature
|
119 |
+
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
|
120 |
+
del style_high_freq
|
121 |
+
# reconstruct the content feature with the style's high frequency
|
122 |
+
return content_high_freq + style_low_freq
|
video_super_resolution/dataset.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import glob
|
4 |
+
import torchvision
|
5 |
+
from einops import rearrange
|
6 |
+
from torch.utils import data as data
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torchvision import transforms
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
class PairedCaptionVideoDataset(data.Dataset):
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
root_folders=None,
|
15 |
+
null_text_ratio=0.5,
|
16 |
+
num_frames=16
|
17 |
+
):
|
18 |
+
super(PairedCaptionVideoDataset, self).__init__()
|
19 |
+
|
20 |
+
self.null_text_ratio = null_text_ratio
|
21 |
+
self.lr_list = []
|
22 |
+
self.gt_list = []
|
23 |
+
self.tag_path_list = []
|
24 |
+
self.num_frames = num_frames
|
25 |
+
|
26 |
+
# root_folders = root_folders.split(',')
|
27 |
+
for root_folder in root_folders:
|
28 |
+
lr_path = root_folder +'/lq'
|
29 |
+
tag_path = root_folder +'/text'
|
30 |
+
gt_path = root_folder +'/gt'
|
31 |
+
|
32 |
+
self.lr_list += glob.glob(os.path.join(lr_path, '*.mp4'))
|
33 |
+
self.gt_list += glob.glob(os.path.join(gt_path, '*.mp4'))
|
34 |
+
self.tag_path_list += glob.glob(os.path.join(tag_path, '*.txt'))
|
35 |
+
|
36 |
+
assert len(self.lr_list) == len(self.gt_list)
|
37 |
+
assert len(self.lr_list) == len(self.tag_path_list)
|
38 |
+
|
39 |
+
def __getitem__(self, index):
|
40 |
+
|
41 |
+
gt_path = self.gt_list[index]
|
42 |
+
vframes_gt, _, info = torchvision.io.read_video(filename=gt_path, pts_unit="sec", output_format="TCHW")
|
43 |
+
fps = info['video_fps']
|
44 |
+
vframes_gt = (rearrange(vframes_gt, "T C H W -> C T H W") / 255) * 2 - 1
|
45 |
+
# gt = self.trandform(vframes_gt)
|
46 |
+
|
47 |
+
lq_path = self.lr_list[index]
|
48 |
+
vframes_lq, _, _ = torchvision.io.read_video(filename=lq_path, pts_unit="sec", output_format="TCHW")
|
49 |
+
vframes_lq = (rearrange(vframes_lq, "T C H W -> C T H W") / 255) * 2 - 1
|
50 |
+
# lq = self.trandform(vframes_lq)
|
51 |
+
|
52 |
+
if random.random() < self.null_text_ratio:
|
53 |
+
tag = ''
|
54 |
+
else:
|
55 |
+
tag_path = self.tag_path_list[index]
|
56 |
+
with open(tag_path, 'r', encoding='utf-8') as file:
|
57 |
+
tag = file.read()
|
58 |
+
|
59 |
+
return {"gt": vframes_gt[:, :self.num_frames, :, :], "lq": vframes_lq[:, :self.num_frames, :, :], "text": tag, 'fps': fps}
|
60 |
+
|
61 |
+
def __len__(self):
|
62 |
+
return len(self.gt_list)
|
63 |
+
|
64 |
+
|
65 |
+
class PairedCaptionImageDataset(data.Dataset):
|
66 |
+
def __init__(
|
67 |
+
self,
|
68 |
+
root_folder=None,
|
69 |
+
):
|
70 |
+
super(PairedCaptionImageDataset, self).__init__()
|
71 |
+
|
72 |
+
self.lr_list = []
|
73 |
+
self.gt_list = []
|
74 |
+
self.tag_path_list = []
|
75 |
+
|
76 |
+
lr_path = root_folder +'/sr_bicubic'
|
77 |
+
gt_path = root_folder +'/gt'
|
78 |
+
|
79 |
+
self.lr_list += glob.glob(os.path.join(lr_path, '*.png'))
|
80 |
+
self.gt_list += glob.glob(os.path.join(gt_path, '*.png'))
|
81 |
+
|
82 |
+
assert len(self.lr_list) == len(self.gt_list)
|
83 |
+
|
84 |
+
self.img_preproc = transforms.Compose([
|
85 |
+
transforms.ToTensor(),
|
86 |
+
])
|
87 |
+
|
88 |
+
# Define the crop size (e.g., 256x256)
|
89 |
+
crop_size = (720, 1280)
|
90 |
+
|
91 |
+
# CenterCrop transform
|
92 |
+
self.center_crop = transforms.CenterCrop(crop_size)
|
93 |
+
|
94 |
+
def __getitem__(self, index):
|
95 |
+
|
96 |
+
gt_path = self.gt_list[index]
|
97 |
+
gt_img = Image.open(gt_path).convert('RGB')
|
98 |
+
gt_img = self.center_crop(self.img_preproc(gt_img))
|
99 |
+
|
100 |
+
lq_path = self.lr_list[index]
|
101 |
+
lq_img = Image.open(lq_path).convert('RGB')
|
102 |
+
lq_img = self.center_crop(self.img_preproc(lq_img))
|
103 |
+
|
104 |
+
example = dict()
|
105 |
+
|
106 |
+
example["lq"] = (lq_img.squeeze(0) * 2.0 - 1.0).unsqueeze(1)
|
107 |
+
example["gt"] = (gt_img.squeeze(0) * 2.0 - 1.0).unsqueeze(1)
|
108 |
+
example["text"] = ""
|
109 |
+
|
110 |
+
return example
|
111 |
+
|
112 |
+
def __len__(self):
|
113 |
+
return len(self.gt_list)
|
video_super_resolution/scripts/inference_sr.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from argparse import ArgumentParser, Namespace
|
4 |
+
import json
|
5 |
+
from typing import Any, Dict, List, Mapping, Tuple
|
6 |
+
from easydict import EasyDict
|
7 |
+
|
8 |
+
from video_to_video.video_to_video_model import VideoToVideo_sr
|
9 |
+
from video_to_video.utils.seed import setup_seed
|
10 |
+
from video_to_video.utils.logger import get_logger
|
11 |
+
from video_super_resolution.color_fix import adain_color_fix
|
12 |
+
|
13 |
+
from inference_utils import *
|
14 |
+
|
15 |
+
logger = get_logger()
|
16 |
+
|
17 |
+
|
18 |
+
class VEnhancer_sr():
|
19 |
+
def __init__(self,
|
20 |
+
result_dir='./results/',
|
21 |
+
file_name='000_video.mp4',
|
22 |
+
model_path='',
|
23 |
+
solver_mode='fast',
|
24 |
+
steps=15,
|
25 |
+
guide_scale=7.5,
|
26 |
+
upscale=4,
|
27 |
+
max_chunk_len=32,
|
28 |
+
variant_info=None,
|
29 |
+
):
|
30 |
+
self.model_path=model_path
|
31 |
+
logger.info('checkpoint_path: {}'.format(self.model_path))
|
32 |
+
|
33 |
+
self.result_dir = result_dir
|
34 |
+
self.file_name = file_name
|
35 |
+
os.makedirs(self.result_dir, exist_ok=True)
|
36 |
+
|
37 |
+
model_cfg = EasyDict(__name__='model_cfg')
|
38 |
+
model_cfg.model_path = self.model_path
|
39 |
+
self.model = VideoToVideo_sr(model_cfg)
|
40 |
+
|
41 |
+
steps = 15 if solver_mode == 'fast' else steps
|
42 |
+
self.solver_mode=solver_mode
|
43 |
+
self.steps=steps
|
44 |
+
self.guide_scale=guide_scale
|
45 |
+
self.upscale = upscale
|
46 |
+
self.max_chunk_len=max_chunk_len
|
47 |
+
self.variant_info=variant_info
|
48 |
+
|
49 |
+
def enhance_a_video(self, video_path, prompt):
|
50 |
+
logger.info('input video path: {}'.format(video_path))
|
51 |
+
text = prompt
|
52 |
+
logger.info('text: {}'.format(text))
|
53 |
+
caption = text + self.model.positive_prompt
|
54 |
+
|
55 |
+
input_frames, input_fps = load_video(video_path)
|
56 |
+
in_f_num = len(input_frames)
|
57 |
+
logger.info('input frames length: {}'.format(in_f_num))
|
58 |
+
logger.info('input fps: {}'.format(input_fps))
|
59 |
+
|
60 |
+
video_data = preprocess(input_frames)
|
61 |
+
_, _, h, w = video_data.shape
|
62 |
+
logger.info('input resolution: {}'.format((h, w)))
|
63 |
+
target_h, target_w = h * self.upscale, w * self.upscale # adjust_resolution(h, w, up_scale=4)
|
64 |
+
logger.info('target resolution: {}'.format((target_h, target_w)))
|
65 |
+
|
66 |
+
pre_data = {'video_data': video_data, 'y': caption}
|
67 |
+
pre_data['target_res'] = (target_h, target_w)
|
68 |
+
|
69 |
+
total_noise_levels = 900
|
70 |
+
setup_seed(666)
|
71 |
+
|
72 |
+
with torch.no_grad():
|
73 |
+
data_tensor = collate_fn(pre_data, 'cuda:0')
|
74 |
+
output = self.model.test(data_tensor, total_noise_levels, steps=self.steps, \
|
75 |
+
solver_mode=self.solver_mode, guide_scale=self.guide_scale, \
|
76 |
+
max_chunk_len=self.max_chunk_len
|
77 |
+
)
|
78 |
+
|
79 |
+
output = tensor2vid(output)
|
80 |
+
|
81 |
+
# Using color fix
|
82 |
+
output = adain_color_fix(output, video_data)
|
83 |
+
|
84 |
+
save_video(output, self.result_dir, self.file_name, fps=input_fps)
|
85 |
+
return os.path.join(self.result_dir, self.file_name)
|
86 |
+
|
87 |
+
|
88 |
+
def parse_args():
|
89 |
+
parser = ArgumentParser()
|
90 |
+
|
91 |
+
parser.add_argument("--input_path", required=True, type=str, help="input video path")
|
92 |
+
parser.add_argument("--save_dir", type=str, default='results', help="save directory")
|
93 |
+
parser.add_argument("--file_name", type=str, help="file name")
|
94 |
+
parser.add_argument("--model_path", type=str, default='./pretrained_weight/model.pt', help="model path")
|
95 |
+
parser.add_argument("--prompt", type=str, default='a good video', help="prompt")
|
96 |
+
parser.add_argument("--upscale", type=int, default=4, help='up-scale')
|
97 |
+
parser.add_argument("--max_chunk_len", type=int, default=32, help='max_chunk_len')
|
98 |
+
parser.add_argument("--variant_info", type=str, default=None, help='information of inference strategy')
|
99 |
+
|
100 |
+
parser.add_argument("--cfg", type=float, default=7.5)
|
101 |
+
parser.add_argument("--solver_mode", type=str, default='fast', help='fast | normal')
|
102 |
+
parser.add_argument("--steps", type=int, default=15)
|
103 |
+
|
104 |
+
return parser.parse_args()
|
105 |
+
|
106 |
+
def main():
|
107 |
+
|
108 |
+
args = parse_args()
|
109 |
+
|
110 |
+
input_path = args.input_path
|
111 |
+
prompt = args.prompt
|
112 |
+
model_path = args.model_path
|
113 |
+
save_dir = args.save_dir
|
114 |
+
file_name = args.file_name
|
115 |
+
upscale = args.upscale
|
116 |
+
max_chunk_len = args.max_chunk_len
|
117 |
+
|
118 |
+
steps = args.steps
|
119 |
+
solver_mode = args.solver_mode
|
120 |
+
guide_scale = args.cfg
|
121 |
+
|
122 |
+
assert solver_mode in ('fast', 'normal')
|
123 |
+
|
124 |
+
venhancer_sr = VEnhancer_sr(
|
125 |
+
result_dir=save_dir,
|
126 |
+
file_name=file_name, # new added
|
127 |
+
model_path=model_path,
|
128 |
+
solver_mode=solver_mode,
|
129 |
+
steps=steps,
|
130 |
+
guide_scale=guide_scale,
|
131 |
+
upscale=upscale,
|
132 |
+
max_chunk_len=max_chunk_len,
|
133 |
+
variant_info=None,
|
134 |
+
)
|
135 |
+
|
136 |
+
venhancer_sr.enhance_a_video(input_path, prompt)
|
137 |
+
|
138 |
+
|
139 |
+
if __name__ == '__main__':
|
140 |
+
main()
|
video_super_resolution/scripts/inference_sr.sh
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Folder paths
|
4 |
+
video_folder_path='./input/video'
|
5 |
+
txt_file_path='./input/text/prompt.txt'
|
6 |
+
|
7 |
+
# Get all .mp4 files in the folder using find to handle special characters
|
8 |
+
mapfile -t mp4_files < <(find "$video_folder_path" -type f -name "*.mp4")
|
9 |
+
|
10 |
+
# Print the list of MP4 files
|
11 |
+
echo "MP4 files to be processed:"
|
12 |
+
for mp4_file in "${mp4_files[@]}"; do
|
13 |
+
echo "$mp4_file"
|
14 |
+
done
|
15 |
+
|
16 |
+
# Read lines from the text file, skipping empty lines
|
17 |
+
mapfile -t lines < <(grep -v '^\s*$' "$txt_file_path")
|
18 |
+
|
19 |
+
# List of frame counts
|
20 |
+
frame_length=32
|
21 |
+
|
22 |
+
# Debugging output
|
23 |
+
echo "Number of MP4 files: ${#mp4_files[@]}"
|
24 |
+
echo "Number of lines in the text file: ${#lines[@]}"
|
25 |
+
|
26 |
+
# Ensure the number of video files matches the number of lines
|
27 |
+
if [ ${#mp4_files[@]} -ne ${#lines[@]} ]; then
|
28 |
+
echo "Number of MP4 files and lines in the text file do not match."
|
29 |
+
exit 1
|
30 |
+
fi
|
31 |
+
|
32 |
+
# Loop through video files and corresponding lines
|
33 |
+
for i in "${!mp4_files[@]}"; do
|
34 |
+
mp4_file="${mp4_files[$i]}"
|
35 |
+
line="${lines[$i]}"
|
36 |
+
|
37 |
+
# Extract the filename without the extension
|
38 |
+
file_name=$(basename "$mp4_file" .mp4)
|
39 |
+
|
40 |
+
echo "Processing video file: $mp4_file with prompt: $line"
|
41 |
+
|
42 |
+
# Run Python script with parameters
|
43 |
+
python \
|
44 |
+
./video_super_resolution/scripts/inference_sr.py \
|
45 |
+
--solver_mode 'fast' \
|
46 |
+
--steps 15 \
|
47 |
+
--input_path "${mp4_file}" \
|
48 |
+
--model_path /mnt/bn/videodataset/VSR/pretrained_models/STAR/model.pt \
|
49 |
+
--prompt "${line}" \
|
50 |
+
--upscale 4 \
|
51 |
+
--max_chunk_len ${frame_length} \
|
52 |
+
--file_name "${file_name}.mp4" \
|
53 |
+
--save_dir ./results
|
54 |
+
done
|
55 |
+
|
56 |
+
echo "All videos processed successfully."
|
video_to_video/__init__.py
ADDED
File without changes
|
video_to_video/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (152 Bytes). View file
|
|
video_to_video/__pycache__/video_to_video_model.cpython-39.pyc
ADDED
Binary file (6.11 kB). View file
|
|
video_to_video/diffusion/__init__.py
ADDED
File without changes
|
video_to_video/diffusion/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (162 Bytes). View file
|
|
video_to_video/diffusion/__pycache__/diffusion_sdedit.cpython-39.pyc
ADDED
Binary file (10.4 kB). View file
|
|
video_to_video/diffusion/__pycache__/schedules_sdedit.cpython-39.pyc
ADDED
Binary file (2.68 kB). View file
|
|
video_to_video/diffusion/__pycache__/solvers_sdedit.cpython-39.pyc
ADDED
Binary file (6.18 kB). View file
|
|
video_to_video/diffusion/diffusion_sdedit.py
ADDED
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from .schedules_sdedit import karras_schedule
|
6 |
+
from .solvers_sdedit import sample_dpmpp_2m_sde, sample_heun
|
7 |
+
|
8 |
+
from video_to_video.utils.logger import get_logger
|
9 |
+
|
10 |
+
logger = get_logger()
|
11 |
+
|
12 |
+
__all__ = ['GaussianDiffusion']
|
13 |
+
|
14 |
+
|
15 |
+
def _i(tensor, t, x):
|
16 |
+
shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
|
17 |
+
return tensor[t.to(tensor.device)].view(shape).to(x.device)
|
18 |
+
|
19 |
+
class GaussianDiffusion(object):
|
20 |
+
|
21 |
+
def __init__(self, sigmas):
|
22 |
+
self.sigmas = sigmas
|
23 |
+
self.alphas = torch.sqrt(1 - sigmas**2)
|
24 |
+
self.num_timesteps = len(sigmas)
|
25 |
+
|
26 |
+
def diffuse(self, x0, t, noise=None):
|
27 |
+
noise = torch.randn_like(x0) if noise is None else noise
|
28 |
+
xt = _i(self.alphas, t, x0) * x0 + _i(self.sigmas, t, x0) * noise
|
29 |
+
|
30 |
+
return xt
|
31 |
+
|
32 |
+
def get_velocity(self, x0, xt, t):
|
33 |
+
sigmas = _i(self.sigmas, t, xt)
|
34 |
+
alphas = _i(self.alphas, t, xt)
|
35 |
+
velocity = (alphas * xt - x0) / sigmas
|
36 |
+
return velocity
|
37 |
+
|
38 |
+
def get_x0(self, v, xt, t):
|
39 |
+
sigmas = _i(self.sigmas, t, xt)
|
40 |
+
alphas = _i(self.alphas, t, xt)
|
41 |
+
x0 = alphas * xt - sigmas * v
|
42 |
+
return x0
|
43 |
+
|
44 |
+
def denoise(self,
|
45 |
+
xt,
|
46 |
+
t,
|
47 |
+
s,
|
48 |
+
model,
|
49 |
+
model_kwargs={},
|
50 |
+
guide_scale=None,
|
51 |
+
guide_rescale=None,
|
52 |
+
clamp=None,
|
53 |
+
percentile=None,
|
54 |
+
variant_info=None,):
|
55 |
+
s = t - 1 if s is None else s
|
56 |
+
|
57 |
+
# hyperparams
|
58 |
+
sigmas = _i(self.sigmas, t, xt)
|
59 |
+
alphas = _i(self.alphas, t, xt)
|
60 |
+
alphas_s = _i(self.alphas, s.clamp(0), xt)
|
61 |
+
alphas_s[s < 0] = 1.
|
62 |
+
sigmas_s = torch.sqrt(1 - alphas_s**2)
|
63 |
+
|
64 |
+
# precompute variables
|
65 |
+
betas = 1 - (alphas / alphas_s)**2
|
66 |
+
coef1 = betas * alphas_s / sigmas**2
|
67 |
+
coef2 = (alphas * sigmas_s**2) / (alphas_s * sigmas**2)
|
68 |
+
var = betas * (sigmas_s / sigmas)**2
|
69 |
+
log_var = torch.log(var).clamp_(-20, 20)
|
70 |
+
|
71 |
+
# prediction
|
72 |
+
if guide_scale is None:
|
73 |
+
assert isinstance(model_kwargs, dict)
|
74 |
+
out = model(xt, t=t, **model_kwargs)
|
75 |
+
else:
|
76 |
+
# classifier-free guidance
|
77 |
+
assert isinstance(model_kwargs, list)
|
78 |
+
if len(model_kwargs) > 3:
|
79 |
+
y_out = model(xt, t=t, **model_kwargs[0], **model_kwargs[2], **model_kwargs[3], **model_kwargs[4], **model_kwargs[5])
|
80 |
+
else:
|
81 |
+
y_out = model(xt, t=t, **model_kwargs[0], **model_kwargs[2], variant_info=variant_info)
|
82 |
+
if guide_scale == 1.:
|
83 |
+
out = y_out
|
84 |
+
else:
|
85 |
+
if len(model_kwargs) > 3:
|
86 |
+
u_out = model(xt, t=t, **model_kwargs[1], **model_kwargs[2], **model_kwargs[3], **model_kwargs[4], **model_kwargs[5])
|
87 |
+
else:
|
88 |
+
u_out = model(xt, t=t, **model_kwargs[1], **model_kwargs[2], variant_info=variant_info)
|
89 |
+
out = u_out + guide_scale * (y_out - u_out)
|
90 |
+
|
91 |
+
if guide_rescale is not None:
|
92 |
+
assert guide_rescale >= 0 and guide_rescale <= 1
|
93 |
+
ratio = (
|
94 |
+
y_out.flatten(1).std(dim=1) / # noqa
|
95 |
+
(out.flatten(1).std(dim=1) + 1e-12)
|
96 |
+
).view((-1, ) + (1, ) * (y_out.ndim - 1))
|
97 |
+
out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0
|
98 |
+
|
99 |
+
x0 = alphas * xt - sigmas * out
|
100 |
+
|
101 |
+
# restrict the range of x0
|
102 |
+
if percentile is not None:
|
103 |
+
assert percentile > 0 and percentile <= 1
|
104 |
+
s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1)
|
105 |
+
s = s.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1))
|
106 |
+
x0 = torch.min(s, torch.max(-s, x0)) / s
|
107 |
+
elif clamp is not None:
|
108 |
+
x0 = x0.clamp(-clamp, clamp)
|
109 |
+
|
110 |
+
# recompute eps using the restricted x0
|
111 |
+
eps = (xt - alphas * x0) / sigmas
|
112 |
+
|
113 |
+
# compute mu (mean of posterior distribution) using the restricted x0
|
114 |
+
mu = coef1 * x0 + coef2 * xt
|
115 |
+
return mu, var, log_var, x0, eps
|
116 |
+
|
117 |
+
|
118 |
+
@torch.no_grad()
|
119 |
+
def sample(self,
|
120 |
+
noise,
|
121 |
+
model,
|
122 |
+
model_kwargs={},
|
123 |
+
condition_fn=None,
|
124 |
+
guide_scale=None,
|
125 |
+
guide_rescale=None,
|
126 |
+
clamp=None,
|
127 |
+
percentile=None,
|
128 |
+
solver='euler_a',
|
129 |
+
solver_mode='fast',
|
130 |
+
steps=20,
|
131 |
+
t_max=None,
|
132 |
+
t_min=None,
|
133 |
+
discretization=None,
|
134 |
+
discard_penultimate_step=None,
|
135 |
+
return_intermediate=None,
|
136 |
+
show_progress=False,
|
137 |
+
seed=-1,
|
138 |
+
chunk_inds=None,
|
139 |
+
**kwargs):
|
140 |
+
# sanity check
|
141 |
+
assert isinstance(steps, (int, torch.LongTensor))
|
142 |
+
assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1)
|
143 |
+
assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1)
|
144 |
+
assert discretization in (None, 'leading', 'linspace', 'trailing')
|
145 |
+
assert discard_penultimate_step in (None, True, False)
|
146 |
+
assert return_intermediate in (None, 'x0', 'xt')
|
147 |
+
|
148 |
+
# function of diffusion solver
|
149 |
+
solver_fn = {
|
150 |
+
'heun': sample_heun,
|
151 |
+
'dpmpp_2m_sde': sample_dpmpp_2m_sde
|
152 |
+
}[solver]
|
153 |
+
|
154 |
+
# options
|
155 |
+
schedule = 'karras' if 'karras' in solver else None
|
156 |
+
discretization = discretization or 'linspace'
|
157 |
+
seed = seed if seed >= 0 else random.randint(0, 2**31)
|
158 |
+
if isinstance(steps, torch.LongTensor):
|
159 |
+
discard_penultimate_step = False
|
160 |
+
if discard_penultimate_step is None:
|
161 |
+
discard_penultimate_step = True if solver in (
|
162 |
+
'dpm2', 'dpm2_ancestral', 'dpmpp_2m_sde', 'dpm2_karras',
|
163 |
+
'dpm2_ancestral_karras', 'dpmpp_2m_sde_karras') else False
|
164 |
+
|
165 |
+
# function for denoising xt to get x0
|
166 |
+
intermediates = []
|
167 |
+
|
168 |
+
def model_fn(xt, sigma):
|
169 |
+
# denoising
|
170 |
+
t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
|
171 |
+
x0 = self.denoise(xt, t, None, model, model_kwargs, guide_scale,
|
172 |
+
guide_rescale, clamp, percentile)[-2]
|
173 |
+
|
174 |
+
# collect intermediate outputs
|
175 |
+
if return_intermediate == 'xt':
|
176 |
+
intermediates.append(xt)
|
177 |
+
elif return_intermediate == 'x0':
|
178 |
+
intermediates.append(x0)
|
179 |
+
return x0
|
180 |
+
|
181 |
+
mask_cond = model_kwargs[3]['mask_cond']
|
182 |
+
def model_chunk_fn(xt, sigma):
|
183 |
+
# denoising
|
184 |
+
t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
|
185 |
+
O_LEN = chunk_inds[0][-1]-chunk_inds[1][0]
|
186 |
+
cut_f_ind = O_LEN//2
|
187 |
+
|
188 |
+
results_list = []
|
189 |
+
for i in range(len(chunk_inds)):
|
190 |
+
ind_start, ind_end = chunk_inds[i]
|
191 |
+
xt_chunk = xt[:,:,ind_start:ind_end].clone()
|
192 |
+
cur_f = xt_chunk.size(2)
|
193 |
+
model_kwargs[3]['mask_cond'] = mask_cond[:,ind_start:ind_end].clone()
|
194 |
+
x0_chunk = self.denoise(xt_chunk, t, None, model, model_kwargs, guide_scale,
|
195 |
+
guide_rescale, clamp, percentile)[-2]
|
196 |
+
if i == 0:
|
197 |
+
results_list.append(x0_chunk[:,:,:cur_f+cut_f_ind-O_LEN])
|
198 |
+
elif i == len(chunk_inds)-1:
|
199 |
+
results_list.append(x0_chunk[:,:,cut_f_ind:])
|
200 |
+
else:
|
201 |
+
results_list.append(x0_chunk[:,:,cut_f_ind:cur_f+cut_f_ind-O_LEN])
|
202 |
+
x0 = torch.concat(results_list, dim=2)
|
203 |
+
torch.cuda.empty_cache()
|
204 |
+
return x0
|
205 |
+
|
206 |
+
# get timesteps
|
207 |
+
if isinstance(steps, int):
|
208 |
+
steps += 1 if discard_penultimate_step else 0
|
209 |
+
t_max = self.num_timesteps - 1 if t_max is None else t_max
|
210 |
+
t_min = 0 if t_min is None else t_min
|
211 |
+
|
212 |
+
# discretize timesteps
|
213 |
+
if discretization == 'leading':
|
214 |
+
steps = torch.arange(t_min, t_max + 1,
|
215 |
+
(t_max - t_min + 1) / steps).flip(0)
|
216 |
+
elif discretization == 'linspace':
|
217 |
+
steps = torch.linspace(t_max, t_min, steps)
|
218 |
+
elif discretization == 'trailing':
|
219 |
+
steps = torch.arange(t_max, t_min - 1,
|
220 |
+
-((t_max - t_min + 1) / steps))
|
221 |
+
if solver_mode == 'fast':
|
222 |
+
t_mid = 500
|
223 |
+
steps1 = torch.arange(t_max, t_mid - 1,
|
224 |
+
-((t_max - t_mid + 1) / 4))
|
225 |
+
steps2 = torch.arange(t_mid, t_min - 1,
|
226 |
+
-((t_mid - t_min + 1) / 11))
|
227 |
+
steps = torch.concat([steps1, steps2])
|
228 |
+
else:
|
229 |
+
raise NotImplementedError(
|
230 |
+
f'{discretization} discretization not implemented')
|
231 |
+
steps = steps.clamp_(t_min, t_max)
|
232 |
+
steps = torch.as_tensor(
|
233 |
+
steps, dtype=torch.float32, device=noise.device)
|
234 |
+
|
235 |
+
# get sigmas
|
236 |
+
sigmas = self._t_to_sigma(steps)
|
237 |
+
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
238 |
+
if schedule == 'karras':
|
239 |
+
if sigmas[0] == float('inf'):
|
240 |
+
sigmas = karras_schedule(
|
241 |
+
n=len(steps) - 1,
|
242 |
+
sigma_min=sigmas[sigmas > 0].min().item(),
|
243 |
+
sigma_max=sigmas[sigmas < float('inf')].max().item(),
|
244 |
+
rho=7.).to(sigmas)
|
245 |
+
sigmas = torch.cat([
|
246 |
+
sigmas.new_tensor([float('inf')]), sigmas,
|
247 |
+
sigmas.new_zeros([1])
|
248 |
+
])
|
249 |
+
else:
|
250 |
+
sigmas = karras_schedule(
|
251 |
+
n=len(steps),
|
252 |
+
sigma_min=sigmas[sigmas > 0].min().item(),
|
253 |
+
sigma_max=sigmas.max().item(),
|
254 |
+
rho=7.).to(sigmas)
|
255 |
+
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
256 |
+
if discard_penultimate_step:
|
257 |
+
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
258 |
+
|
259 |
+
fn = model_chunk_fn if chunk_inds is not None else model_fn
|
260 |
+
x0 = solver_fn(
|
261 |
+
noise, fn, sigmas, show_progress=show_progress, **kwargs)
|
262 |
+
return (x0, intermediates) if return_intermediate is not None else x0
|
263 |
+
|
264 |
+
@torch.no_grad()
|
265 |
+
def sample_sr(self,
|
266 |
+
noise,
|
267 |
+
model,
|
268 |
+
model_kwargs={},
|
269 |
+
condition_fn=None,
|
270 |
+
guide_scale=None,
|
271 |
+
guide_rescale=None,
|
272 |
+
clamp=None,
|
273 |
+
percentile=None,
|
274 |
+
solver='euler_a',
|
275 |
+
solver_mode='fast',
|
276 |
+
steps=20,
|
277 |
+
t_max=None,
|
278 |
+
t_min=None,
|
279 |
+
discretization=None,
|
280 |
+
discard_penultimate_step=None,
|
281 |
+
return_intermediate=None,
|
282 |
+
show_progress=False,
|
283 |
+
seed=-1,
|
284 |
+
chunk_inds=None,
|
285 |
+
variant_info=None,
|
286 |
+
**kwargs):
|
287 |
+
# sanity check
|
288 |
+
assert isinstance(steps, (int, torch.LongTensor))
|
289 |
+
assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1)
|
290 |
+
assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1)
|
291 |
+
assert discretization in (None, 'leading', 'linspace', 'trailing')
|
292 |
+
assert discard_penultimate_step in (None, True, False)
|
293 |
+
assert return_intermediate in (None, 'x0', 'xt')
|
294 |
+
|
295 |
+
# function of diffusion solver
|
296 |
+
solver_fn = {
|
297 |
+
'heun': sample_heun,
|
298 |
+
'dpmpp_2m_sde': sample_dpmpp_2m_sde
|
299 |
+
}[solver]
|
300 |
+
|
301 |
+
# options
|
302 |
+
schedule = 'karras' if 'karras' in solver else None
|
303 |
+
discretization = discretization or 'linspace'
|
304 |
+
seed = seed if seed >= 0 else random.randint(0, 2**31)
|
305 |
+
if isinstance(steps, torch.LongTensor):
|
306 |
+
discard_penultimate_step = False
|
307 |
+
if discard_penultimate_step is None:
|
308 |
+
discard_penultimate_step = True if solver in (
|
309 |
+
'dpm2', 'dpm2_ancestral', 'dpmpp_2m_sde', 'dpm2_karras',
|
310 |
+
'dpm2_ancestral_karras', 'dpmpp_2m_sde_karras') else False
|
311 |
+
|
312 |
+
# function for denoising xt to get x0
|
313 |
+
intermediates = []
|
314 |
+
|
315 |
+
def model_fn(xt, sigma, variant_info=None):
|
316 |
+
# denoising
|
317 |
+
t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
|
318 |
+
x0 = self.denoise(xt, t, None, model, model_kwargs, guide_scale,
|
319 |
+
guide_rescale, clamp, percentile, variant_info=variant_info)[-2]
|
320 |
+
|
321 |
+
# collect intermediate outputs
|
322 |
+
if return_intermediate == 'xt':
|
323 |
+
intermediates.append(xt)
|
324 |
+
elif return_intermediate == 'x0':
|
325 |
+
print('add intermediate outputs x0')
|
326 |
+
intermediates.append(x0)
|
327 |
+
return x0
|
328 |
+
|
329 |
+
# mask_cond = model_kwargs[3]['mask_cond']
|
330 |
+
def model_chunk_fn(xt, sigma, variant_info=None):
|
331 |
+
# denoising
|
332 |
+
t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
|
333 |
+
O_LEN = chunk_inds[0][-1]-chunk_inds[1][0]
|
334 |
+
cut_f_ind = O_LEN//2
|
335 |
+
|
336 |
+
results_list = []
|
337 |
+
for i in range(len(chunk_inds)):
|
338 |
+
ind_start, ind_end = chunk_inds[i]
|
339 |
+
xt_chunk = xt[:,:,ind_start:ind_end].clone()
|
340 |
+
model_kwargs[2]['hint_chunk'] = model_kwargs[2]['hint'][:,:,ind_start:ind_end].clone() # new added
|
341 |
+
cur_f = xt_chunk.size(2)
|
342 |
+
# model_kwargs[3]['mask_cond'] = mask_cond[:,ind_start:ind_end].clone()
|
343 |
+
x0_chunk = self.denoise(xt_chunk, t, None, model, model_kwargs, guide_scale,
|
344 |
+
guide_rescale, clamp, percentile, variant_info=variant_info)[-2]
|
345 |
+
if i == 0:
|
346 |
+
results_list.append(x0_chunk[:,:,:cur_f+cut_f_ind-O_LEN])
|
347 |
+
elif i == len(chunk_inds)-1:
|
348 |
+
results_list.append(x0_chunk[:,:,cut_f_ind:])
|
349 |
+
else:
|
350 |
+
results_list.append(x0_chunk[:,:,cut_f_ind:cur_f+cut_f_ind-O_LEN])
|
351 |
+
x0 = torch.concat(results_list, dim=2)
|
352 |
+
torch.cuda.empty_cache()
|
353 |
+
return x0
|
354 |
+
|
355 |
+
# get timesteps
|
356 |
+
if isinstance(steps, int):
|
357 |
+
steps += 1 if discard_penultimate_step else 0
|
358 |
+
t_max = self.num_timesteps - 1 if t_max is None else t_max
|
359 |
+
t_min = 0 if t_min is None else t_min
|
360 |
+
|
361 |
+
# discretize timesteps
|
362 |
+
if discretization == 'leading':
|
363 |
+
steps = torch.arange(t_min, t_max + 1,
|
364 |
+
(t_max - t_min + 1) / steps).flip(0)
|
365 |
+
elif discretization == 'linspace':
|
366 |
+
steps = torch.linspace(t_max, t_min, steps)
|
367 |
+
elif discretization == 'trailing':
|
368 |
+
steps = torch.arange(t_max, t_min - 1,
|
369 |
+
-((t_max - t_min + 1) / steps))
|
370 |
+
if solver_mode == 'fast':
|
371 |
+
t_mid = 500
|
372 |
+
steps1 = torch.arange(t_max, t_mid - 1,
|
373 |
+
-((t_max - t_mid + 1) / 4))
|
374 |
+
steps2 = torch.arange(t_mid, t_min - 1,
|
375 |
+
-((t_mid - t_min + 1) / 11))
|
376 |
+
steps = torch.concat([steps1, steps2])
|
377 |
+
else:
|
378 |
+
raise NotImplementedError(
|
379 |
+
f'{discretization} discretization not implemented')
|
380 |
+
steps = steps.clamp_(t_min, t_max)
|
381 |
+
steps = torch.as_tensor(
|
382 |
+
steps, dtype=torch.float32, device=noise.device)
|
383 |
+
|
384 |
+
# get sigmas
|
385 |
+
sigmas = self._t_to_sigma(steps)
|
386 |
+
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
387 |
+
if schedule == 'karras':
|
388 |
+
if sigmas[0] == float('inf'):
|
389 |
+
sigmas = karras_schedule(
|
390 |
+
n=len(steps) - 1,
|
391 |
+
sigma_min=sigmas[sigmas > 0].min().item(),
|
392 |
+
sigma_max=sigmas[sigmas < float('inf')].max().item(),
|
393 |
+
rho=7.).to(sigmas)
|
394 |
+
sigmas = torch.cat([
|
395 |
+
sigmas.new_tensor([float('inf')]), sigmas,
|
396 |
+
sigmas.new_zeros([1])
|
397 |
+
])
|
398 |
+
else:
|
399 |
+
sigmas = karras_schedule(
|
400 |
+
n=len(steps),
|
401 |
+
sigma_min=sigmas[sigmas > 0].min().item(),
|
402 |
+
sigma_max=sigmas.max().item(),
|
403 |
+
rho=7.).to(sigmas)
|
404 |
+
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
405 |
+
if discard_penultimate_step:
|
406 |
+
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
407 |
+
|
408 |
+
|
409 |
+
fn = model_chunk_fn if chunk_inds is not None else model_fn
|
410 |
+
x0 = solver_fn(
|
411 |
+
noise, fn, sigmas, variant_info=variant_info, show_progress=show_progress, **kwargs)
|
412 |
+
return (x0, intermediates) if return_intermediate is not None else x0
|
413 |
+
|
414 |
+
|
415 |
+
def _sigma_to_t(self, sigma):
|
416 |
+
if sigma == float('inf'):
|
417 |
+
t = torch.full_like(sigma, len(self.sigmas) - 1)
|
418 |
+
else:
|
419 |
+
log_sigmas = torch.sqrt(self.sigmas**2 / # noqa
|
420 |
+
(1 - self.sigmas**2)).log().to(sigma)
|
421 |
+
log_sigma = sigma.log()
|
422 |
+
dists = log_sigma - log_sigmas[:, None]
|
423 |
+
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(
|
424 |
+
max=log_sigmas.shape[0] - 2)
|
425 |
+
high_idx = low_idx + 1
|
426 |
+
low, high = log_sigmas[low_idx], log_sigmas[high_idx]
|
427 |
+
w = (low - log_sigma) / (low - high)
|
428 |
+
w = w.clamp(0, 1)
|
429 |
+
t = (1 - w) * low_idx + w * high_idx
|
430 |
+
t = t.view(sigma.shape)
|
431 |
+
if t.ndim == 0:
|
432 |
+
t = t.unsqueeze(0)
|
433 |
+
return t
|
434 |
+
|
435 |
+
def _t_to_sigma(self, t):
|
436 |
+
t = t.float()
|
437 |
+
low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
|
438 |
+
log_sigmas = torch.sqrt(self.sigmas**2 / # noqa
|
439 |
+
(1 - self.sigmas**2)).log().to(t)
|
440 |
+
log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]
|
441 |
+
log_sigma[torch.isnan(log_sigma)
|
442 |
+
| torch.isinf(log_sigma)] = float('inf')
|
443 |
+
return log_sigma.exp()
|
video_to_video/diffusion/schedules_sdedit.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
2 |
+
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
def betas_to_sigmas(betas):
|
9 |
+
return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0))
|
10 |
+
|
11 |
+
|
12 |
+
def sigmas_to_betas(sigmas):
|
13 |
+
square_alphas = 1 - sigmas**2
|
14 |
+
betas = 1 - torch.cat(
|
15 |
+
[square_alphas[:1], square_alphas[1:] / square_alphas[:-1]])
|
16 |
+
return betas
|
17 |
+
|
18 |
+
|
19 |
+
def logsnrs_to_sigmas(logsnrs):
|
20 |
+
return torch.sqrt(torch.sigmoid(-logsnrs))
|
21 |
+
|
22 |
+
|
23 |
+
def sigmas_to_logsnrs(sigmas):
|
24 |
+
square_sigmas = sigmas**2
|
25 |
+
return torch.log(square_sigmas / (1 - square_sigmas))
|
26 |
+
|
27 |
+
|
28 |
+
def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15):
|
29 |
+
t_min = math.atan(math.exp(-0.5 * logsnr_min))
|
30 |
+
t_max = math.atan(math.exp(-0.5 * logsnr_max))
|
31 |
+
t = torch.linspace(1, 0, n)
|
32 |
+
logsnrs = -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))
|
33 |
+
return logsnrs
|
34 |
+
|
35 |
+
|
36 |
+
def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2):
|
37 |
+
logsnrs = _logsnr_cosine(n, logsnr_min, logsnr_max)
|
38 |
+
logsnrs += 2 * math.log(1 / scale)
|
39 |
+
return logsnrs
|
40 |
+
|
41 |
+
|
42 |
+
def _logsnr_cosine_interp(n,
|
43 |
+
logsnr_min=-15,
|
44 |
+
logsnr_max=15,
|
45 |
+
scale_min=2,
|
46 |
+
scale_max=4):
|
47 |
+
t = torch.linspace(1, 0, n)
|
48 |
+
logsnrs_min = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_min)
|
49 |
+
logsnrs_max = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_max)
|
50 |
+
logsnrs = t * logsnrs_min + (1 - t) * logsnrs_max
|
51 |
+
return logsnrs
|
52 |
+
|
53 |
+
|
54 |
+
def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
|
55 |
+
ramp = torch.linspace(1, 0, n)
|
56 |
+
min_inv_rho = sigma_min**(1 / rho)
|
57 |
+
max_inv_rho = sigma_max**(1 / rho)
|
58 |
+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho
|
59 |
+
sigmas = torch.sqrt(sigmas**2 / (1 + sigmas**2))
|
60 |
+
return sigmas
|
61 |
+
|
62 |
+
|
63 |
+
def logsnr_cosine_interp_schedule(n,
|
64 |
+
logsnr_min=-15,
|
65 |
+
logsnr_max=15,
|
66 |
+
scale_min=2,
|
67 |
+
scale_max=4):
|
68 |
+
return logsnrs_to_sigmas(
|
69 |
+
_logsnr_cosine_interp(n, logsnr_min, logsnr_max, scale_min, scale_max))
|
70 |
+
|
71 |
+
|
72 |
+
def noise_schedule(schedule='logsnr_cosine_interp',
|
73 |
+
n=1000,
|
74 |
+
zero_terminal_snr=False,
|
75 |
+
**kwargs):
|
76 |
+
# compute sigmas
|
77 |
+
sigmas = {
|
78 |
+
'logsnr_cosine_interp': logsnr_cosine_interp_schedule
|
79 |
+
}[schedule](n, **kwargs)
|
80 |
+
|
81 |
+
# post-processing
|
82 |
+
if zero_terminal_snr and sigmas.max() != 1.0:
|
83 |
+
scale = (1.0 - sigmas.min()) / (sigmas.max() - sigmas.min())
|
84 |
+
sigmas = sigmas.min() + scale * (sigmas - sigmas.min())
|
85 |
+
return sigmas
|
video_to_video/diffusion/solvers_sdedit.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torchsde
|
5 |
+
from tqdm.auto import trange
|
6 |
+
|
7 |
+
from video_to_video.utils.logger import get_logger
|
8 |
+
|
9 |
+
logger = get_logger()
|
10 |
+
|
11 |
+
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
|
12 |
+
"""
|
13 |
+
Calculates the noise level (sigma_down) to step down to and the amount
|
14 |
+
of noise to add (sigma_up) when doing an ancestral sampling step.
|
15 |
+
"""
|
16 |
+
if not eta:
|
17 |
+
return sigma_to, 0.
|
18 |
+
sigma_up = min(
|
19 |
+
sigma_to,
|
20 |
+
eta * (
|
21 |
+
sigma_to**2 * # noqa
|
22 |
+
(sigma_from**2 - sigma_to**2) / sigma_from**2)**0.5)
|
23 |
+
sigma_down = (sigma_to**2 - sigma_up**2)**0.5
|
24 |
+
return sigma_down, sigma_up
|
25 |
+
|
26 |
+
|
27 |
+
def get_scalings(sigma):
|
28 |
+
c_out = -sigma
|
29 |
+
c_in = 1 / (sigma**2 + 1.**2)**0.5
|
30 |
+
return c_out, c_in
|
31 |
+
|
32 |
+
|
33 |
+
@torch.no_grad()
|
34 |
+
def sample_heun(noise,
|
35 |
+
model,
|
36 |
+
sigmas,
|
37 |
+
s_churn=0.,
|
38 |
+
s_tmin=0.,
|
39 |
+
s_tmax=float('inf'),
|
40 |
+
s_noise=1.,
|
41 |
+
show_progress=True):
|
42 |
+
"""
|
43 |
+
Implements Algorithm 2 (Heun steps) from Karras et al. (2022).
|
44 |
+
"""
|
45 |
+
x = noise * sigmas[0]
|
46 |
+
for i in trange(len(sigmas) - 1, disable=not show_progress):
|
47 |
+
gamma = 0.
|
48 |
+
if s_tmin <= sigmas[i] <= s_tmax and sigmas[i] < float('inf'):
|
49 |
+
gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
|
50 |
+
eps = torch.randn_like(x) * s_noise
|
51 |
+
sigma_hat = sigmas[i] * (gamma + 1)
|
52 |
+
if gamma > 0:
|
53 |
+
x = x + eps * (sigma_hat**2 - sigmas[i]**2)**0.5
|
54 |
+
if sigmas[i] == float('inf'):
|
55 |
+
# Euler method
|
56 |
+
denoised = model(noise, sigma_hat)
|
57 |
+
x = denoised + sigmas[i + 1] * (gamma + 1) * noise
|
58 |
+
else:
|
59 |
+
_, c_in = get_scalings(sigma_hat)
|
60 |
+
denoised = model(x * c_in, sigma_hat)
|
61 |
+
d = (x - denoised) / sigma_hat
|
62 |
+
dt = sigmas[i + 1] - sigma_hat
|
63 |
+
if sigmas[i + 1] == 0:
|
64 |
+
# Euler method
|
65 |
+
x = x + d * dt
|
66 |
+
else:
|
67 |
+
# Heun's method
|
68 |
+
x_2 = x + d * dt
|
69 |
+
_, c_in = get_scalings(sigmas[i + 1])
|
70 |
+
denoised_2 = model(x_2 * c_in, sigmas[i + 1])
|
71 |
+
d_2 = (x_2 - denoised_2) / sigmas[i + 1]
|
72 |
+
d_prime = (d + d_2) / 2
|
73 |
+
x = x + d_prime * dt
|
74 |
+
return x
|
75 |
+
|
76 |
+
|
77 |
+
class BatchedBrownianTree:
|
78 |
+
"""
|
79 |
+
A wrapper around torchsde.BrownianTree that enables batches of entropy.
|
80 |
+
"""
|
81 |
+
|
82 |
+
def __init__(self, x, t0, t1, seed=None, **kwargs):
|
83 |
+
t0, t1, self.sign = self.sort(t0, t1)
|
84 |
+
w0 = kwargs.get('w0', torch.zeros_like(x))
|
85 |
+
if seed is None:
|
86 |
+
seed = torch.randint(0, 2**63 - 1, []).item()
|
87 |
+
self.batched = True
|
88 |
+
try:
|
89 |
+
assert len(seed) == x.shape[0]
|
90 |
+
w0 = w0[0]
|
91 |
+
except TypeError:
|
92 |
+
seed = [seed]
|
93 |
+
self.batched = False
|
94 |
+
self.trees = [
|
95 |
+
torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs)
|
96 |
+
for s in seed
|
97 |
+
]
|
98 |
+
|
99 |
+
@staticmethod
|
100 |
+
def sort(a, b):
|
101 |
+
return (a, b, 1) if a < b else (b, a, -1)
|
102 |
+
|
103 |
+
def __call__(self, t0, t1):
|
104 |
+
t0, t1, sign = self.sort(t0, t1)
|
105 |
+
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (
|
106 |
+
self.sign * sign)
|
107 |
+
return w if self.batched else w[0]
|
108 |
+
|
109 |
+
|
110 |
+
class BrownianTreeNoiseSampler:
|
111 |
+
"""
|
112 |
+
A noise sampler backed by a torchsde.BrownianTree.
|
113 |
+
|
114 |
+
Args:
|
115 |
+
x (Tensor): The tensor whose shape, device and dtype to use to generate
|
116 |
+
random samples.
|
117 |
+
sigma_min (float): The low end of the valid interval.
|
118 |
+
sigma_max (float): The high end of the valid interval.
|
119 |
+
seed (int or List[int]): The random seed. If a list of seeds is
|
120 |
+
supplied instead of a single integer, then the noise sampler will
|
121 |
+
use one BrownianTree per batch item, each with its own seed.
|
122 |
+
transform (callable): A function that maps sigma to the sampler's
|
123 |
+
internal timestep.
|
124 |
+
"""
|
125 |
+
|
126 |
+
def __init__(self,
|
127 |
+
x,
|
128 |
+
sigma_min,
|
129 |
+
sigma_max,
|
130 |
+
seed=None,
|
131 |
+
transform=lambda x: x):
|
132 |
+
self.transform = transform
|
133 |
+
t0 = self.transform(torch.as_tensor(sigma_min))
|
134 |
+
t1 = self.transform(torch.as_tensor(sigma_max))
|
135 |
+
self.tree = BatchedBrownianTree(x, t0, t1, seed)
|
136 |
+
|
137 |
+
def __call__(self, sigma, sigma_next):
|
138 |
+
t0 = self.transform(torch.as_tensor(sigma))
|
139 |
+
t1 = self.transform(torch.as_tensor(sigma_next))
|
140 |
+
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
|
141 |
+
|
142 |
+
|
143 |
+
@torch.no_grad()
|
144 |
+
def sample_dpmpp_2m_sde(noise,
|
145 |
+
model,
|
146 |
+
sigmas,
|
147 |
+
eta=1.,
|
148 |
+
s_noise=1.,
|
149 |
+
solver_type='midpoint',
|
150 |
+
show_progress=True,
|
151 |
+
variant_info=None):
|
152 |
+
"""
|
153 |
+
DPM-Solver++ (2M) SDE.
|
154 |
+
"""
|
155 |
+
assert solver_type in {'heun', 'midpoint'}
|
156 |
+
|
157 |
+
x = noise * sigmas[0]
|
158 |
+
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas[
|
159 |
+
sigmas < float('inf')].max()
|
160 |
+
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max)
|
161 |
+
old_denoised = None
|
162 |
+
h_last = None
|
163 |
+
|
164 |
+
for i in trange(len(sigmas) - 1, disable=not show_progress):
|
165 |
+
logger.info(f'step: {i}')
|
166 |
+
if sigmas[i] == float('inf'):
|
167 |
+
# Euler method
|
168 |
+
denoised = model(noise, sigmas[i], variant_info=variant_info)
|
169 |
+
x = denoised + sigmas[i + 1] * noise
|
170 |
+
else:
|
171 |
+
_, c_in = get_scalings(sigmas[i])
|
172 |
+
denoised = model(x * c_in, sigmas[i], variant_info=variant_info)
|
173 |
+
if sigmas[i + 1] == 0:
|
174 |
+
# Denoising step
|
175 |
+
x = denoised
|
176 |
+
else:
|
177 |
+
# DPM-Solver++(2M) SDE
|
178 |
+
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
179 |
+
h = s - t
|
180 |
+
eta_h = eta * h
|
181 |
+
|
182 |
+
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + \
|
183 |
+
(-h - eta_h).expm1().neg() * denoised
|
184 |
+
|
185 |
+
if old_denoised is not None:
|
186 |
+
r = h_last / h
|
187 |
+
if solver_type == 'heun':
|
188 |
+
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * \
|
189 |
+
(1 / r) * (denoised - old_denoised)
|
190 |
+
elif solver_type == 'midpoint':
|
191 |
+
x = x + 0.5 * (-h - eta_h).expm1().neg() * \
|
192 |
+
(1 / r) * (denoised - old_denoised)
|
193 |
+
|
194 |
+
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[
|
195 |
+
i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
|
196 |
+
|
197 |
+
old_denoised = denoised
|
198 |
+
h_last = h
|
199 |
+
|
200 |
+
if variant_info is not None and variant_info.get('type') == 'variant1':
|
201 |
+
x_long, x_short = x.chunk(2, dim=0)
|
202 |
+
x = x_long * (1-variant_info['alpha']) + x_short * variant_info['alpha']
|
203 |
+
|
204 |
+
return x
|
video_to_video/modules/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .embedder import *
|
2 |
+
from .unet_v2v import *
|
3 |
+
# from .unet_v2v_deform import *
|
video_to_video/modules/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (205 Bytes). View file
|
|
video_to_video/modules/__pycache__/embedder.cpython-39.pyc
ADDED
Binary file (2.58 kB). View file
|
|
video_to_video/modules/__pycache__/t5.cpython-39.pyc
ADDED
Binary file (7.07 kB). View file
|
|
video_to_video/modules/__pycache__/unet_v2v.cpython-39.pyc
ADDED
Binary file (47.6 kB). View file
|
|
video_to_video/modules/__pycache__/unet_v2v_LocalConv.cpython-39.pyc
ADDED
Binary file (47.8 kB). View file
|
|
video_to_video/modules/__pycache__/unet_v2v_deform.cpython-39.pyc
ADDED
Binary file (48.2 kB). View file
|
|
video_to_video/modules/embedder.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
2 |
+
|
3 |
+
import os
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import open_clip
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torchvision.transforms as T
|
10 |
+
|
11 |
+
|
12 |
+
class FrozenOpenCLIPEmbedder(nn.Module):
|
13 |
+
"""
|
14 |
+
Uses the OpenCLIP transformer encoder for text
|
15 |
+
"""
|
16 |
+
LAYERS = ['last', 'penultimate']
|
17 |
+
|
18 |
+
def __init__(self,
|
19 |
+
pretrained='laion2b_s32b_b79k',
|
20 |
+
arch='ViT-H-14',
|
21 |
+
device='cuda',
|
22 |
+
max_length=77,
|
23 |
+
freeze=True,
|
24 |
+
layer='penultimate'):
|
25 |
+
super().__init__()
|
26 |
+
assert layer in self.LAYERS
|
27 |
+
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=pretrained)
|
28 |
+
|
29 |
+
del model.visual
|
30 |
+
self.model = model
|
31 |
+
self.device = device
|
32 |
+
self.max_length = max_length
|
33 |
+
|
34 |
+
if freeze:
|
35 |
+
self.freeze()
|
36 |
+
self.layer = layer
|
37 |
+
if self.layer == 'last':
|
38 |
+
self.layer_idx = 0
|
39 |
+
elif self.layer == 'penultimate':
|
40 |
+
self.layer_idx = 1
|
41 |
+
else:
|
42 |
+
raise NotImplementedError()
|
43 |
+
|
44 |
+
def freeze(self):
|
45 |
+
self.model = self.model.eval()
|
46 |
+
for param in self.parameters():
|
47 |
+
param.requires_grad = False
|
48 |
+
|
49 |
+
def forward(self, text):
|
50 |
+
tokens = open_clip.tokenize(text)
|
51 |
+
z = self.encode_with_transformer(tokens.to(self.device))
|
52 |
+
return z
|
53 |
+
|
54 |
+
def encode_with_transformer(self, text):
|
55 |
+
x = self.model.token_embedding(text)
|
56 |
+
x = x + self.model.positional_embedding
|
57 |
+
x = x.permute(1, 0, 2)
|
58 |
+
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
|
59 |
+
x = x.permute(1, 0, 2)
|
60 |
+
x = self.model.ln_final(x)
|
61 |
+
return x
|
62 |
+
|
63 |
+
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
|
64 |
+
for i, r in enumerate(self.model.transformer.resblocks):
|
65 |
+
if i == len(self.model.transformer.resblocks) - self.layer_idx:
|
66 |
+
break
|
67 |
+
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
|
68 |
+
):
|
69 |
+
x = checkpoint(r, x, attn_mask)
|
70 |
+
else:
|
71 |
+
x = r(x, attn_mask=attn_mask)
|
72 |
+
return x
|
73 |
+
|
74 |
+
def encode(self, text):
|
75 |
+
return self(text)
|
video_to_video/modules/t5.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from PixArt
|
2 |
+
#
|
3 |
+
# Copyright (C) 2023 PixArt-alpha/PixArt-alpha
|
4 |
+
#
|
5 |
+
# This program is free software: you can redistribute it and/or modify
|
6 |
+
# it under the terms of the GNU Affero General Public License as published
|
7 |
+
# by the Free Software Foundation, either version 3 of the License, or
|
8 |
+
# (at your option) any later version.
|
9 |
+
#
|
10 |
+
# This program is distributed in the hope that it will be useful,
|
11 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
12 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
13 |
+
# GNU Affero General Public License for more details.
|
14 |
+
#
|
15 |
+
#
|
16 |
+
# This source code is licensed under the license found in the
|
17 |
+
# LICENSE file in the root directory of this source tree.
|
18 |
+
# --------------------------------------------------------
|
19 |
+
# References:
|
20 |
+
# PixArt: https://github.com/PixArt-alpha/PixArt-alpha
|
21 |
+
# T5: https://github.com/google-research/text-to-text-transfer-transformer
|
22 |
+
# --------------------------------------------------------
|
23 |
+
|
24 |
+
import html
|
25 |
+
import re
|
26 |
+
|
27 |
+
import ftfy
|
28 |
+
import torch
|
29 |
+
from transformers import AutoTokenizer, T5EncoderModel
|
30 |
+
|
31 |
+
# from opensora.registry import MODELS
|
32 |
+
|
33 |
+
|
34 |
+
class T5Embedder:
|
35 |
+
def __init__(
|
36 |
+
self,
|
37 |
+
device,
|
38 |
+
from_pretrained=None,
|
39 |
+
*,
|
40 |
+
cache_dir=None,
|
41 |
+
hf_token=None,
|
42 |
+
use_text_preprocessing=True,
|
43 |
+
t5_model_kwargs=None,
|
44 |
+
torch_dtype=None,
|
45 |
+
use_offload_folder=None,
|
46 |
+
model_max_length=120,
|
47 |
+
local_files_only=False,
|
48 |
+
):
|
49 |
+
self.device = torch.device(device)
|
50 |
+
self.torch_dtype = torch_dtype or torch.bfloat16
|
51 |
+
self.cache_dir = cache_dir
|
52 |
+
|
53 |
+
if t5_model_kwargs is None:
|
54 |
+
t5_model_kwargs = {
|
55 |
+
"low_cpu_mem_usage": True,
|
56 |
+
"torch_dtype": self.torch_dtype,
|
57 |
+
}
|
58 |
+
|
59 |
+
if use_offload_folder is not None:
|
60 |
+
t5_model_kwargs["offload_folder"] = use_offload_folder
|
61 |
+
t5_model_kwargs["device_map"] = {
|
62 |
+
"shared": self.device,
|
63 |
+
"encoder.embed_tokens": self.device,
|
64 |
+
"encoder.block.0": self.device,
|
65 |
+
"encoder.block.1": self.device,
|
66 |
+
"encoder.block.2": self.device,
|
67 |
+
"encoder.block.3": self.device,
|
68 |
+
"encoder.block.4": self.device,
|
69 |
+
"encoder.block.5": self.device,
|
70 |
+
"encoder.block.6": self.device,
|
71 |
+
"encoder.block.7": self.device,
|
72 |
+
"encoder.block.8": self.device,
|
73 |
+
"encoder.block.9": self.device,
|
74 |
+
"encoder.block.10": self.device,
|
75 |
+
"encoder.block.11": self.device,
|
76 |
+
"encoder.block.12": "disk",
|
77 |
+
"encoder.block.13": "disk",
|
78 |
+
"encoder.block.14": "disk",
|
79 |
+
"encoder.block.15": "disk",
|
80 |
+
"encoder.block.16": "disk",
|
81 |
+
"encoder.block.17": "disk",
|
82 |
+
"encoder.block.18": "disk",
|
83 |
+
"encoder.block.19": "disk",
|
84 |
+
"encoder.block.20": "disk",
|
85 |
+
"encoder.block.21": "disk",
|
86 |
+
"encoder.block.22": "disk",
|
87 |
+
"encoder.block.23": "disk",
|
88 |
+
"encoder.final_layer_norm": "disk",
|
89 |
+
"encoder.dropout": "disk",
|
90 |
+
}
|
91 |
+
else:
|
92 |
+
t5_model_kwargs["device_map"] = {
|
93 |
+
"shared": self.device,
|
94 |
+
"encoder": self.device,
|
95 |
+
}
|
96 |
+
|
97 |
+
self.use_text_preprocessing = use_text_preprocessing
|
98 |
+
self.hf_token = hf_token
|
99 |
+
|
100 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
101 |
+
from_pretrained,
|
102 |
+
cache_dir=cache_dir,
|
103 |
+
local_files_only=local_files_only,
|
104 |
+
)
|
105 |
+
self.model = T5EncoderModel.from_pretrained(
|
106 |
+
from_pretrained,
|
107 |
+
cache_dir=cache_dir,
|
108 |
+
local_files_only=local_files_only,
|
109 |
+
**t5_model_kwargs,
|
110 |
+
).eval()
|
111 |
+
self.model_max_length = model_max_length
|
112 |
+
|
113 |
+
def get_text_embeddings(self, texts):
|
114 |
+
text_tokens_and_mask = self.tokenizer(
|
115 |
+
texts,
|
116 |
+
max_length=self.model_max_length,
|
117 |
+
padding="max_length",
|
118 |
+
truncation=True,
|
119 |
+
return_attention_mask=True,
|
120 |
+
add_special_tokens=True,
|
121 |
+
return_tensors="pt",
|
122 |
+
)
|
123 |
+
|
124 |
+
input_ids = text_tokens_and_mask["input_ids"].to(self.device)
|
125 |
+
attention_mask = text_tokens_and_mask["attention_mask"].to(self.device)
|
126 |
+
with torch.no_grad():
|
127 |
+
text_encoder_embs = self.model(
|
128 |
+
input_ids=input_ids,
|
129 |
+
attention_mask=attention_mask,
|
130 |
+
)["last_hidden_state"].detach()
|
131 |
+
return text_encoder_embs, attention_mask
|
132 |
+
|
133 |
+
|
134 |
+
# @MODELS.register_module("t5")
|
135 |
+
class T5Encoder:
|
136 |
+
def __init__(
|
137 |
+
self,
|
138 |
+
from_pretrained=None,
|
139 |
+
model_max_length=120,
|
140 |
+
device="cuda",
|
141 |
+
dtype=torch.float,
|
142 |
+
cache_dir=None,
|
143 |
+
shardformer=False,
|
144 |
+
local_files_only=False,
|
145 |
+
):
|
146 |
+
assert from_pretrained is not None, "Please specify the path to the T5 model"
|
147 |
+
|
148 |
+
self.t5 = T5Embedder(
|
149 |
+
device=device,
|
150 |
+
torch_dtype=dtype,
|
151 |
+
from_pretrained=from_pretrained,
|
152 |
+
cache_dir=cache_dir,
|
153 |
+
model_max_length=model_max_length,
|
154 |
+
local_files_only=local_files_only,
|
155 |
+
)
|
156 |
+
self.t5.model.to(dtype=dtype)
|
157 |
+
self.y_embedder = None
|
158 |
+
|
159 |
+
self.model_max_length = model_max_length
|
160 |
+
self.output_dim = self.t5.model.config.d_model
|
161 |
+
self.dtype = dtype
|
162 |
+
|
163 |
+
if shardformer:
|
164 |
+
self.shardformer_t5()
|
165 |
+
|
166 |
+
def shardformer_t5(self):
|
167 |
+
from colossalai.shardformer import ShardConfig, ShardFormer
|
168 |
+
|
169 |
+
from opensora.acceleration.shardformer.policy.t5_encoder import T5EncoderPolicy
|
170 |
+
from opensora.utils.misc import requires_grad
|
171 |
+
|
172 |
+
shard_config = ShardConfig(
|
173 |
+
tensor_parallel_process_group=None,
|
174 |
+
pipeline_stage_manager=None,
|
175 |
+
enable_tensor_parallelism=False,
|
176 |
+
enable_fused_normalization=False,
|
177 |
+
enable_flash_attention=False,
|
178 |
+
enable_jit_fused=True,
|
179 |
+
enable_sequence_parallelism=False,
|
180 |
+
enable_sequence_overlap=False,
|
181 |
+
)
|
182 |
+
shard_former = ShardFormer(shard_config=shard_config)
|
183 |
+
optim_model, _ = shard_former.optimize(self.t5.model, policy=T5EncoderPolicy())
|
184 |
+
self.t5.model = optim_model.to(self.dtype)
|
185 |
+
|
186 |
+
# ensure the weights are frozen
|
187 |
+
requires_grad(self.t5.model, False)
|
188 |
+
|
189 |
+
def encode(self, text):
|
190 |
+
caption_embs, emb_masks = self.t5.get_text_embeddings(text)
|
191 |
+
caption_embs = caption_embs[:, None]
|
192 |
+
return dict(y=caption_embs, mask=emb_masks)
|
193 |
+
|
194 |
+
def null(self, n):
|
195 |
+
null_y = self.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None]
|
196 |
+
return null_y
|
197 |
+
|
198 |
+
|
199 |
+
def basic_clean(text):
|
200 |
+
text = ftfy.fix_text(text)
|
201 |
+
text = html.unescape(html.unescape(text))
|
202 |
+
return text.strip()
|
203 |
+
|
204 |
+
|
205 |
+
BAD_PUNCT_REGEX = re.compile(
|
206 |
+
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
|
207 |
+
) # noqa
|
208 |
+
|
209 |
+
|
210 |
+
def clean_caption(caption):
|
211 |
+
import urllib.parse as ul
|
212 |
+
|
213 |
+
from bs4 import BeautifulSoup
|
214 |
+
|
215 |
+
caption = str(caption)
|
216 |
+
caption = ul.unquote_plus(caption)
|
217 |
+
caption = caption.strip().lower()
|
218 |
+
caption = re.sub("<person>", "person", caption)
|
219 |
+
# urls:
|
220 |
+
caption = re.sub(
|
221 |
+
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
222 |
+
"",
|
223 |
+
caption,
|
224 |
+
) # regex for urls
|
225 |
+
caption = re.sub(
|
226 |
+
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
227 |
+
"",
|
228 |
+
caption,
|
229 |
+
) # regex for urls
|
230 |
+
# html:
|
231 |
+
caption = BeautifulSoup(caption, features="html.parser").text
|
232 |
+
|
233 |
+
# @<nickname>
|
234 |
+
caption = re.sub(r"@[\w\d]+\b", "", caption)
|
235 |
+
|
236 |
+
# 31C0—31EF CJK Strokes
|
237 |
+
# 31F0—31FF Katakana Phonetic Extensions
|
238 |
+
# 3200—32FF Enclosed CJK Letters and Months
|
239 |
+
# 3300—33FF CJK Compatibility
|
240 |
+
# 3400—4DBF CJK Unified Ideographs Extension A
|
241 |
+
# 4DC0—4DFF Yijing Hexagram Symbols
|
242 |
+
# 4E00—9FFF CJK Unified Ideographs
|
243 |
+
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
|
244 |
+
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
|
245 |
+
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
|
246 |
+
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
|
247 |
+
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
|
248 |
+
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
|
249 |
+
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
|
250 |
+
#######################################################
|
251 |
+
|
252 |
+
# все виды тире / all types of dash --> "-"
|
253 |
+
caption = re.sub(
|
254 |
+
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
|
255 |
+
"-",
|
256 |
+
caption,
|
257 |
+
)
|
258 |
+
|
259 |
+
# кавычки к одному стандарту
|
260 |
+
caption = re.sub(r"[`´«»“”¨]", '"', caption)
|
261 |
+
caption = re.sub(r"[‘’]", "'", caption)
|
262 |
+
|
263 |
+
# "
|
264 |
+
caption = re.sub(r""?", "", caption)
|
265 |
+
# &
|
266 |
+
caption = re.sub(r"&", "", caption)
|
267 |
+
|
268 |
+
# ip adresses:
|
269 |
+
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
270 |
+
|
271 |
+
# article ids:
|
272 |
+
caption = re.sub(r"\d:\d\d\s+$", "", caption)
|
273 |
+
|
274 |
+
# \n
|
275 |
+
caption = re.sub(r"\\n", " ", caption)
|
276 |
+
|
277 |
+
# "#123"
|
278 |
+
caption = re.sub(r"#\d{1,3}\b", "", caption)
|
279 |
+
# "#12345.."
|
280 |
+
caption = re.sub(r"#\d{5,}\b", "", caption)
|
281 |
+
# "123456.."
|
282 |
+
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
283 |
+
# filenames:
|
284 |
+
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
|
285 |
+
|
286 |
+
#
|
287 |
+
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
288 |
+
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
289 |
+
|
290 |
+
caption = re.sub(BAD_PUNCT_REGEX, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
291 |
+
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
292 |
+
|
293 |
+
# this-is-my-cute-cat / this_is_my_cute_cat
|
294 |
+
regex2 = re.compile(r"(?:\-|\_)")
|
295 |
+
if len(re.findall(regex2, caption)) > 3:
|
296 |
+
caption = re.sub(regex2, " ", caption)
|
297 |
+
|
298 |
+
caption = basic_clean(caption)
|
299 |
+
|
300 |
+
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
|
301 |
+
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
|
302 |
+
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
|
303 |
+
|
304 |
+
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
305 |
+
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
306 |
+
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
307 |
+
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
|
308 |
+
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
309 |
+
|
310 |
+
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
|
311 |
+
|
312 |
+
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
313 |
+
|
314 |
+
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
|
315 |
+
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
|
316 |
+
caption = re.sub(r"\s+", " ", caption)
|
317 |
+
|
318 |
+
caption.strip()
|
319 |
+
|
320 |
+
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
|
321 |
+
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
|
322 |
+
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
|
323 |
+
caption = re.sub(r"^\.\S+$", "", caption)
|
324 |
+
|
325 |
+
return caption.strip()
|
326 |
+
|
327 |
+
|
328 |
+
def text_preprocessing(text, use_text_preprocessing: bool = True):
|
329 |
+
if use_text_preprocessing:
|
330 |
+
# The exact text cleaning as was in the training stage:
|
331 |
+
text = clean_caption(text)
|
332 |
+
text = clean_caption(text)
|
333 |
+
return text
|
334 |
+
else:
|
335 |
+
return text.lower().strip()
|
video_to_video/modules/unet_v2v.py
ADDED
@@ -0,0 +1,2332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
2 |
+
|
3 |
+
import math
|
4 |
+
import os
|
5 |
+
from abc import abstractmethod
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import xformers
|
11 |
+
import xformers.ops
|
12 |
+
from einops import rearrange
|
13 |
+
from fairscale.nn.checkpoint import checkpoint_wrapper
|
14 |
+
from timm.models.vision_transformer import Mlp
|
15 |
+
|
16 |
+
|
17 |
+
USE_TEMPORAL_TRANSFORMER = True
|
18 |
+
|
19 |
+
|
20 |
+
class CaptionEmbedder(nn.Module):
|
21 |
+
"""
|
22 |
+
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate="tanh"), token_num=120):
|
26 |
+
super().__init__()
|
27 |
+
self.y_proj = Mlp(
|
28 |
+
in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0
|
29 |
+
)
|
30 |
+
self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels**0.5))
|
31 |
+
self.uncond_prob = uncond_prob
|
32 |
+
|
33 |
+
def token_drop(self, caption, force_drop_ids=None):
|
34 |
+
"""
|
35 |
+
Drops labels to enable classifier-free guidance.
|
36 |
+
"""
|
37 |
+
if force_drop_ids is None:
|
38 |
+
drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
|
39 |
+
else:
|
40 |
+
drop_ids = force_drop_ids == 1
|
41 |
+
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
|
42 |
+
return caption
|
43 |
+
|
44 |
+
def forward(self, caption, train, force_drop_ids=None):
|
45 |
+
if train:
|
46 |
+
assert caption.shape[2:] == self.y_embedding.shape
|
47 |
+
use_dropout = self.uncond_prob > 0
|
48 |
+
if (train and use_dropout) or (force_drop_ids is not None):
|
49 |
+
caption = self.token_drop(caption, force_drop_ids)
|
50 |
+
caption = self.y_proj(caption)
|
51 |
+
return caption
|
52 |
+
|
53 |
+
|
54 |
+
class DropPath(nn.Module):
|
55 |
+
r"""DropPath but without rescaling and supports optional all-zero and/or all-keep.
|
56 |
+
"""
|
57 |
+
|
58 |
+
def __init__(self, p):
|
59 |
+
super(DropPath, self).__init__()
|
60 |
+
self.p = p
|
61 |
+
|
62 |
+
def forward(self, *args, zero=None, keep=None):
|
63 |
+
if not self.training:
|
64 |
+
return args[0] if len(args) == 1 else args
|
65 |
+
|
66 |
+
# params
|
67 |
+
x = args[0]
|
68 |
+
b = x.size(0)
|
69 |
+
n = (torch.rand(b) < self.p).sum()
|
70 |
+
|
71 |
+
# non-zero and non-keep mask
|
72 |
+
mask = x.new_ones(b, dtype=torch.bool)
|
73 |
+
if keep is not None:
|
74 |
+
mask[keep] = False
|
75 |
+
if zero is not None:
|
76 |
+
mask[zero] = False
|
77 |
+
|
78 |
+
# drop-path index
|
79 |
+
index = torch.where(mask)[0]
|
80 |
+
index = index[torch.randperm(len(index))[:n]]
|
81 |
+
if zero is not None:
|
82 |
+
index = torch.cat([index, torch.where(zero)[0]], dim=0)
|
83 |
+
|
84 |
+
# drop-path multiplier
|
85 |
+
multiplier = x.new_ones(b)
|
86 |
+
multiplier[index] = 0.0
|
87 |
+
output = tuple(u * self.broadcast(multiplier, u) for u in args)
|
88 |
+
return output[0] if len(args) == 1 else output
|
89 |
+
|
90 |
+
def broadcast(self, src, dst):
|
91 |
+
assert src.size(0) == dst.size(0)
|
92 |
+
shape = (dst.size(0), ) + (1, ) * (dst.ndim - 1)
|
93 |
+
return src.view(shape)
|
94 |
+
|
95 |
+
|
96 |
+
def sinusoidal_embedding(timesteps, dim):
|
97 |
+
# check input
|
98 |
+
half = dim // 2
|
99 |
+
timesteps = timesteps.float()
|
100 |
+
|
101 |
+
# compute sinusoidal embedding
|
102 |
+
sinusoid = torch.outer(
|
103 |
+
timesteps, torch.pow(10000,
|
104 |
+
-torch.arange(half).to(timesteps).div(half)))
|
105 |
+
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
106 |
+
if dim % 2 != 0:
|
107 |
+
x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
|
108 |
+
return x
|
109 |
+
|
110 |
+
|
111 |
+
def exists(x):
|
112 |
+
return x is not None
|
113 |
+
|
114 |
+
|
115 |
+
def default(val, d):
|
116 |
+
if exists(val):
|
117 |
+
return val
|
118 |
+
return d() if callable(d) else d
|
119 |
+
|
120 |
+
|
121 |
+
def prob_mask_like(shape, prob, device):
|
122 |
+
if prob == 1:
|
123 |
+
return torch.ones(shape, device=device, dtype=torch.bool)
|
124 |
+
elif prob == 0:
|
125 |
+
return torch.zeros(shape, device=device, dtype=torch.bool)
|
126 |
+
else:
|
127 |
+
mask = torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
|
128 |
+
# aviod mask all, which will cause find_unused_parameters error
|
129 |
+
if mask.all():
|
130 |
+
mask[0] = False
|
131 |
+
return mask
|
132 |
+
|
133 |
+
|
134 |
+
class MemoryEfficientCrossAttention(nn.Module):
|
135 |
+
|
136 |
+
def __init__(self,
|
137 |
+
query_dim,
|
138 |
+
context_dim=None,
|
139 |
+
heads=8,
|
140 |
+
dim_head=64,
|
141 |
+
max_bs=16384,
|
142 |
+
dropout=0.0):
|
143 |
+
super().__init__()
|
144 |
+
inner_dim = dim_head * heads
|
145 |
+
context_dim = default(context_dim, query_dim)
|
146 |
+
|
147 |
+
self.max_bs = max_bs
|
148 |
+
self.heads = heads
|
149 |
+
self.dim_head = dim_head
|
150 |
+
|
151 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
152 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
153 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
154 |
+
self.to_out = nn.Sequential(
|
155 |
+
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
156 |
+
self.attention_op: Optional[Any] = None
|
157 |
+
|
158 |
+
def forward(self, x, context=None, mask=None):
|
159 |
+
q = self.to_q(x)
|
160 |
+
context = default(context, x)
|
161 |
+
k = self.to_k(context)
|
162 |
+
v = self.to_v(context)
|
163 |
+
|
164 |
+
b, _, _ = q.shape
|
165 |
+
q, k, v = map(
|
166 |
+
lambda t: t.unsqueeze(3).reshape(b, t.shape[
|
167 |
+
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
|
168 |
+
b * self.heads, t.shape[1], self.dim_head).contiguous(),
|
169 |
+
(q, k, v),
|
170 |
+
)
|
171 |
+
|
172 |
+
# actually compute the attention, what we cannot get enough of.
|
173 |
+
if q.shape[0] > self.max_bs:
|
174 |
+
q_list = torch.chunk(q, q.shape[0] // self.max_bs, dim=0)
|
175 |
+
k_list = torch.chunk(k, k.shape[0] // self.max_bs, dim=0)
|
176 |
+
v_list = torch.chunk(v, v.shape[0] // self.max_bs, dim=0)
|
177 |
+
out_list = []
|
178 |
+
for q_1, k_1, v_1 in zip(q_list, k_list, v_list):
|
179 |
+
out = xformers.ops.memory_efficient_attention(
|
180 |
+
q_1, k_1, v_1, attn_bias=None, op=self.attention_op)
|
181 |
+
out_list.append(out)
|
182 |
+
out = torch.cat(out_list, dim=0)
|
183 |
+
else:
|
184 |
+
out = xformers.ops.memory_efficient_attention(
|
185 |
+
q, k, v, attn_bias=None, op=self.attention_op)
|
186 |
+
|
187 |
+
if exists(mask):
|
188 |
+
raise NotImplementedError
|
189 |
+
out = (
|
190 |
+
out.unsqueeze(0).reshape(
|
191 |
+
b, self.heads, out.shape[1],
|
192 |
+
self.dim_head).permute(0, 2, 1,
|
193 |
+
3).reshape(b, out.shape[1],
|
194 |
+
self.heads * self.dim_head))
|
195 |
+
return self.to_out(out)
|
196 |
+
|
197 |
+
|
198 |
+
class RelativePositionBias(nn.Module):
|
199 |
+
|
200 |
+
def __init__(self, heads=8, num_buckets=32, max_distance=128):
|
201 |
+
super().__init__()
|
202 |
+
self.num_buckets = num_buckets
|
203 |
+
self.max_distance = max_distance
|
204 |
+
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
|
205 |
+
|
206 |
+
@staticmethod
|
207 |
+
def _relative_position_bucket(relative_position,
|
208 |
+
num_buckets=32,
|
209 |
+
max_distance=128):
|
210 |
+
ret = 0
|
211 |
+
n = -relative_position
|
212 |
+
|
213 |
+
num_buckets //= 2
|
214 |
+
ret += (n < 0).long() * num_buckets
|
215 |
+
n = torch.abs(n)
|
216 |
+
|
217 |
+
max_exact = num_buckets // 2
|
218 |
+
is_small = n < max_exact
|
219 |
+
|
220 |
+
val_if_large = max_exact + (
|
221 |
+
torch.log(n.float() / max_exact)
|
222 |
+
/ math.log(max_distance / max_exact) * # noqa
|
223 |
+
(num_buckets - max_exact)).long()
|
224 |
+
val_if_large = torch.min(
|
225 |
+
val_if_large, torch.full_like(val_if_large, num_buckets - 1))
|
226 |
+
|
227 |
+
ret += torch.where(is_small, n, val_if_large)
|
228 |
+
return ret
|
229 |
+
|
230 |
+
def forward(self, n, device):
|
231 |
+
q_pos = torch.arange(n, dtype=torch.long, device=device)
|
232 |
+
k_pos = torch.arange(n, dtype=torch.long, device=device)
|
233 |
+
rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
|
234 |
+
rp_bucket = self._relative_position_bucket(
|
235 |
+
rel_pos,
|
236 |
+
num_buckets=self.num_buckets,
|
237 |
+
max_distance=self.max_distance)
|
238 |
+
values = self.relative_attention_bias(rp_bucket)
|
239 |
+
return rearrange(values, 'i j h -> h i j')
|
240 |
+
|
241 |
+
|
242 |
+
class SpatialTransformer(nn.Module):
|
243 |
+
"""
|
244 |
+
Transformer block for image-like data.
|
245 |
+
First, project the input (aka embedding)
|
246 |
+
and reshape to b, t, d.
|
247 |
+
Then apply standard transformer action.
|
248 |
+
Finally, reshape to image
|
249 |
+
NEW: use_linear for more efficiency instead of the 1x1 convs
|
250 |
+
"""
|
251 |
+
|
252 |
+
def __init__(self,
|
253 |
+
in_channels,
|
254 |
+
n_heads,
|
255 |
+
d_head,
|
256 |
+
depth=1,
|
257 |
+
dropout=0.,
|
258 |
+
context_dim=None,
|
259 |
+
disable_self_attn=False,
|
260 |
+
use_linear=False,
|
261 |
+
use_checkpoint=True,
|
262 |
+
is_ctrl=False):
|
263 |
+
super().__init__()
|
264 |
+
if exists(context_dim) and not isinstance(context_dim, list):
|
265 |
+
context_dim = [context_dim]
|
266 |
+
self.in_channels = in_channels
|
267 |
+
inner_dim = n_heads * d_head
|
268 |
+
self.norm = torch.nn.GroupNorm(
|
269 |
+
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
270 |
+
if not use_linear:
|
271 |
+
self.proj_in = nn.Conv2d(
|
272 |
+
in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
273 |
+
else:
|
274 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
275 |
+
|
276 |
+
self.transformer_blocks = nn.ModuleList([
|
277 |
+
BasicTransformerBlock(
|
278 |
+
inner_dim,
|
279 |
+
n_heads,
|
280 |
+
d_head,
|
281 |
+
dropout=dropout,
|
282 |
+
context_dim=context_dim[d],
|
283 |
+
disable_self_attn=disable_self_attn,
|
284 |
+
checkpoint=use_checkpoint,
|
285 |
+
local_type='space',
|
286 |
+
is_ctrl=is_ctrl) for d in range(depth)
|
287 |
+
])
|
288 |
+
if not use_linear:
|
289 |
+
self.proj_out = zero_module(
|
290 |
+
nn.Conv2d(
|
291 |
+
inner_dim, in_channels, kernel_size=1, stride=1,
|
292 |
+
padding=0))
|
293 |
+
else:
|
294 |
+
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
295 |
+
self.use_linear = use_linear
|
296 |
+
|
297 |
+
def forward(self, x, context=None):
|
298 |
+
# note: if no context is given, cross-attention defaults to self-attention
|
299 |
+
if not isinstance(context, list):
|
300 |
+
context = [context]
|
301 |
+
_, _, h, w = x.shape
|
302 |
+
# print('x shape:', x.shape) # [64, 320, 90, 160]
|
303 |
+
x_in = x
|
304 |
+
x = self.norm(x)
|
305 |
+
if not self.use_linear:
|
306 |
+
x = self.proj_in(x)
|
307 |
+
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
|
308 |
+
if self.use_linear:
|
309 |
+
x = self.proj_in(x)
|
310 |
+
for i, block in enumerate(self.transformer_blocks):
|
311 |
+
x = block(x, context=context[i], h=h, w=w)
|
312 |
+
if self.use_linear:
|
313 |
+
x = self.proj_out(x)
|
314 |
+
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
|
315 |
+
if not self.use_linear:
|
316 |
+
x = self.proj_out(x)
|
317 |
+
return x + x_in
|
318 |
+
|
319 |
+
|
320 |
+
_ATTN_PRECISION = os.environ.get('ATTN_PRECISION', 'fp32')
|
321 |
+
|
322 |
+
|
323 |
+
class CrossAttention(nn.Module):
|
324 |
+
|
325 |
+
def __init__(self,
|
326 |
+
query_dim,
|
327 |
+
context_dim=None,
|
328 |
+
heads=8,
|
329 |
+
dim_head=64,
|
330 |
+
dropout=0.):
|
331 |
+
super().__init__()
|
332 |
+
inner_dim = dim_head * heads
|
333 |
+
context_dim = default(context_dim, query_dim)
|
334 |
+
|
335 |
+
self.scale = dim_head**-0.5
|
336 |
+
self.heads = heads
|
337 |
+
|
338 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
339 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
340 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
341 |
+
|
342 |
+
self.to_out = nn.Sequential(
|
343 |
+
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
344 |
+
|
345 |
+
def forward(self, x, context=None, mask=None):
|
346 |
+
h = self.heads
|
347 |
+
|
348 |
+
q = self.to_q(x)
|
349 |
+
context = default(context, x)
|
350 |
+
k = self.to_k(context)
|
351 |
+
v = self.to_v(context)
|
352 |
+
|
353 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
354 |
+
(q, k, v))
|
355 |
+
|
356 |
+
# force cast to fp32 to avoid overflowing
|
357 |
+
if _ATTN_PRECISION == 'fp32':
|
358 |
+
with torch.autocast(enabled=False, device_type='cuda'):
|
359 |
+
q, k = q.float(), k.float()
|
360 |
+
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
|
361 |
+
else:
|
362 |
+
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
|
363 |
+
|
364 |
+
del q, k
|
365 |
+
|
366 |
+
if exists(mask):
|
367 |
+
mask = rearrange(mask, 'b ... -> b (...)')
|
368 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
369 |
+
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
370 |
+
sim.masked_fill_(~mask, max_neg_value)
|
371 |
+
|
372 |
+
# attention, what we cannot get enough of
|
373 |
+
sim = sim.softmax(dim=-1)
|
374 |
+
|
375 |
+
out = torch.einsum('b i j, b j d -> b i d', sim, v)
|
376 |
+
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
377 |
+
return self.to_out(out)
|
378 |
+
|
379 |
+
|
380 |
+
|
381 |
+
|
382 |
+
class SpatialAttention(nn.Module):
|
383 |
+
def __init__(self):
|
384 |
+
super(SpatialAttention, self).__init__()
|
385 |
+
self.conv1 = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, padding=7 // 2, bias=False)
|
386 |
+
self.sigmoid = nn.Sigmoid()
|
387 |
+
def forward(self, x):
|
388 |
+
|
389 |
+
max_out, _ = torch.max(x, dim=1, keepdim=True)
|
390 |
+
avg_out = torch.mean(x, dim=1, keepdim=True)
|
391 |
+
|
392 |
+
weight = torch.cat([max_out, avg_out], dim=1)
|
393 |
+
weight = self.conv1(weight)
|
394 |
+
|
395 |
+
out = self.sigmoid(weight) * x
|
396 |
+
return out
|
397 |
+
|
398 |
+
class TemporalLocalAttention(nn.Module): # b c t h w
|
399 |
+
def __init__(self, dim, kernel_size=7):
|
400 |
+
super(TemporalLocalAttention, self).__init__()
|
401 |
+
self.conv1 = nn.Linear(in_features=2, out_features=1, bias=False)
|
402 |
+
self.sigmoid = nn.Sigmoid()
|
403 |
+
|
404 |
+
def forward(self, x):
|
405 |
+
|
406 |
+
max_out, _ = torch.max(x, dim=-1, keepdim=True)
|
407 |
+
avg_out = torch.mean(x, dim=-1, keepdim=True)
|
408 |
+
|
409 |
+
weight = torch.cat([max_out, avg_out], dim=-1)
|
410 |
+
weight = self.conv1(weight)
|
411 |
+
|
412 |
+
out = self.sigmoid(weight) * x
|
413 |
+
return out
|
414 |
+
|
415 |
+
|
416 |
+
class BasicTransformerBlock(nn.Module):
|
417 |
+
|
418 |
+
def __init__(self,
|
419 |
+
dim,
|
420 |
+
n_heads,
|
421 |
+
d_head,
|
422 |
+
dropout=0.,
|
423 |
+
context_dim=None,
|
424 |
+
gated_ff=True,
|
425 |
+
checkpoint=True,
|
426 |
+
disable_self_attn=False,
|
427 |
+
local_type=None,
|
428 |
+
is_ctrl=False):
|
429 |
+
super().__init__()
|
430 |
+
self.local_type = local_type
|
431 |
+
self.is_ctrl = is_ctrl
|
432 |
+
attn_cls = MemoryEfficientCrossAttention
|
433 |
+
self.disable_self_attn = disable_self_attn
|
434 |
+
self.attn1 = attn_cls( # self-attn
|
435 |
+
query_dim=dim,
|
436 |
+
heads=n_heads,
|
437 |
+
dim_head=d_head,
|
438 |
+
dropout=dropout,
|
439 |
+
context_dim=context_dim if self.disable_self_attn else None)
|
440 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
441 |
+
|
442 |
+
attn_cls2 = MemoryEfficientCrossAttention
|
443 |
+
|
444 |
+
self.attn2 = attn_cls2(
|
445 |
+
query_dim=dim,
|
446 |
+
context_dim=context_dim,
|
447 |
+
heads=n_heads,
|
448 |
+
dim_head=d_head,
|
449 |
+
dropout=dropout)
|
450 |
+
self.norm1 = nn.LayerNorm(dim)
|
451 |
+
self.norm2 = nn.LayerNorm(dim)
|
452 |
+
self.norm3 = nn.LayerNorm(dim)
|
453 |
+
self.checkpoint = checkpoint
|
454 |
+
|
455 |
+
if self.local_type == 'space' and self.is_ctrl:
|
456 |
+
self.local1 = SpatialAttention()
|
457 |
+
|
458 |
+
if self.local_type == 'temp' and self.is_ctrl:
|
459 |
+
self.local1 = TemporalLocalAttention(dim=dim)
|
460 |
+
self.local2 = TemporalLocalAttention(dim=dim)
|
461 |
+
|
462 |
+
def forward_(self, x, context=None):
|
463 |
+
return checkpoint(self._forward, (x, context), self.parameters(),
|
464 |
+
self.checkpoint)
|
465 |
+
|
466 |
+
def forward(self, x, context=None, h=None, w=None):
|
467 |
+
|
468 |
+
if self.local_type == 'space' and self.is_ctrl: # [b*t,(hw), c]
|
469 |
+
|
470 |
+
x_local = rearrange(x, 'b (h w) c -> b c h w', h=h)
|
471 |
+
x_local = self.local1(x_local)
|
472 |
+
x_local = rearrange(x_local, 'b c h w -> b (h w) c')
|
473 |
+
|
474 |
+
x = self.attn1(
|
475 |
+
self.norm1(x_local),
|
476 |
+
context=context if self.disable_self_attn else None) + x
|
477 |
+
|
478 |
+
x = self.attn2(self.norm2(x), context=context) + x # cross attention or self-attention
|
479 |
+
x = self.ff(self.norm3(x)) + x
|
480 |
+
|
481 |
+
if self.local_type == 'temp' and self.is_ctrl:
|
482 |
+
|
483 |
+
# x_local = rearrange(x, '(b h w) t c -> b c t h w', h=h, w=w)
|
484 |
+
x_local = self.local1(x)
|
485 |
+
|
486 |
+
x = self.attn1(
|
487 |
+
self.norm1(x_local),
|
488 |
+
context=context if self.disable_self_attn else None) + x
|
489 |
+
|
490 |
+
# x_local = rearrange(x, '(b h w) t c -> b c t h w', h=h, w=w)
|
491 |
+
x_local = self.local2(x)
|
492 |
+
|
493 |
+
x = self.attn2(self.norm2(x_local), context=context) + x
|
494 |
+
x = self.ff(self.norm3(x)) + x
|
495 |
+
|
496 |
+
# elif self.local_type == 'space' and self.is_ctrl:
|
497 |
+
# # print('*** use original attention ***')
|
498 |
+
# x = self.attn1(
|
499 |
+
# self.norm1(x),
|
500 |
+
# context=context if self.disable_self_attn else None) + x # self-attention
|
501 |
+
|
502 |
+
# x = self.attn2(self.norm2(x), context=context) + x # cross attention or self-attention
|
503 |
+
# x = self.ff(self.norm3(x)) + x
|
504 |
+
|
505 |
+
return x
|
506 |
+
|
507 |
+
|
508 |
+
# feedforward
|
509 |
+
class GEGLU(nn.Module):
|
510 |
+
|
511 |
+
def __init__(self, dim_in, dim_out):
|
512 |
+
super().__init__()
|
513 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
514 |
+
|
515 |
+
def forward(self, x):
|
516 |
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
517 |
+
return x * F.gelu(gate)
|
518 |
+
|
519 |
+
|
520 |
+
def zero_module(module):
|
521 |
+
"""
|
522 |
+
Zero out the parameters of a module and return it.
|
523 |
+
"""
|
524 |
+
for p in module.parameters():
|
525 |
+
p.detach().zero_()
|
526 |
+
return module
|
527 |
+
|
528 |
+
|
529 |
+
class FeedForward(nn.Module):
|
530 |
+
|
531 |
+
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
532 |
+
super().__init__()
|
533 |
+
inner_dim = int(dim * mult)
|
534 |
+
dim_out = default(dim_out, dim)
|
535 |
+
project_in = nn.Sequential(nn.Linear(
|
536 |
+
dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
|
537 |
+
|
538 |
+
self.net = nn.Sequential(project_in, nn.Dropout(dropout),
|
539 |
+
nn.Linear(inner_dim, dim_out))
|
540 |
+
|
541 |
+
def forward(self, x):
|
542 |
+
return self.net(x)
|
543 |
+
|
544 |
+
|
545 |
+
class Upsample(nn.Module):
|
546 |
+
"""
|
547 |
+
An upsampling layer with an optional convolution.
|
548 |
+
:param channels: channels in the inputs and outputs.
|
549 |
+
:param use_conv: a bool determining if a convolution is applied.
|
550 |
+
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
551 |
+
upsampling occurs in the inner-two dimensions.
|
552 |
+
"""
|
553 |
+
|
554 |
+
def __init__(self,
|
555 |
+
channels,
|
556 |
+
use_conv,
|
557 |
+
dims=2,
|
558 |
+
out_channels=None,
|
559 |
+
padding=1):
|
560 |
+
super().__init__()
|
561 |
+
self.channels = channels
|
562 |
+
self.out_channels = out_channels or channels
|
563 |
+
self.use_conv = use_conv
|
564 |
+
self.dims = dims
|
565 |
+
if use_conv:
|
566 |
+
self.conv = nn.Conv2d(
|
567 |
+
self.channels, self.out_channels, 3, padding=padding)
|
568 |
+
|
569 |
+
def forward(self, x):
|
570 |
+
assert x.shape[1] == self.channels
|
571 |
+
if self.dims == 3:
|
572 |
+
x = F.interpolate(
|
573 |
+
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
|
574 |
+
mode='nearest')
|
575 |
+
else:
|
576 |
+
x = F.interpolate(x, scale_factor=2, mode='nearest')
|
577 |
+
x = x[..., 1:-1, :]
|
578 |
+
if self.use_conv:
|
579 |
+
x = self.conv(x)
|
580 |
+
return x
|
581 |
+
|
582 |
+
|
583 |
+
class ResBlock(nn.Module):
|
584 |
+
"""
|
585 |
+
A residual block that can optionally change the number of channels.
|
586 |
+
:param channels: the number of input channels.
|
587 |
+
:param emb_channels: the number of timestep embedding channels.
|
588 |
+
:param dropout: the rate of dropout.
|
589 |
+
:param out_channels: if specified, the number of out channels.
|
590 |
+
:param use_conv: if True and out_channels is specified, use a spatial
|
591 |
+
convolution instead of a smaller 1x1 convolution to change the
|
592 |
+
channels in the skip connection.
|
593 |
+
:param dims: determines if the signal is 1D, 2D, or 3D.
|
594 |
+
:param use_checkpoint: if True, use gradient checkpointing on this module.
|
595 |
+
:param up: if True, use this block for upsampling.
|
596 |
+
:param down: if True, use this block for downsampling.
|
597 |
+
"""
|
598 |
+
|
599 |
+
def __init__(
|
600 |
+
self,
|
601 |
+
channels,
|
602 |
+
emb_channels,
|
603 |
+
dropout,
|
604 |
+
out_channels=None,
|
605 |
+
use_conv=False,
|
606 |
+
use_scale_shift_norm=False,
|
607 |
+
dims=2,
|
608 |
+
up=False,
|
609 |
+
down=False,
|
610 |
+
use_temporal_conv=True,
|
611 |
+
use_image_dataset=False,
|
612 |
+
):
|
613 |
+
super().__init__()
|
614 |
+
self.channels = channels
|
615 |
+
self.emb_channels = emb_channels
|
616 |
+
self.dropout = dropout
|
617 |
+
self.out_channels = out_channels or channels
|
618 |
+
self.use_conv = use_conv
|
619 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
620 |
+
self.use_temporal_conv = use_temporal_conv
|
621 |
+
|
622 |
+
self.in_layers = nn.Sequential(
|
623 |
+
nn.GroupNorm(32, channels),
|
624 |
+
nn.SiLU(),
|
625 |
+
nn.Conv2d(channels, self.out_channels, 3, padding=1),
|
626 |
+
)
|
627 |
+
|
628 |
+
self.updown = up or down
|
629 |
+
|
630 |
+
if up:
|
631 |
+
self.h_upd = Upsample(channels, False, dims)
|
632 |
+
self.x_upd = Upsample(channels, False, dims)
|
633 |
+
elif down:
|
634 |
+
self.h_upd = Downsample(channels, False, dims)
|
635 |
+
self.x_upd = Downsample(channels, False, dims)
|
636 |
+
else:
|
637 |
+
self.h_upd = self.x_upd = nn.Identity()
|
638 |
+
|
639 |
+
self.emb_layers = nn.Sequential(
|
640 |
+
nn.SiLU(),
|
641 |
+
nn.Linear(
|
642 |
+
emb_channels,
|
643 |
+
2 * self.out_channels
|
644 |
+
if use_scale_shift_norm else self.out_channels,
|
645 |
+
),
|
646 |
+
)
|
647 |
+
self.out_layers = nn.Sequential(
|
648 |
+
nn.GroupNorm(32, self.out_channels),
|
649 |
+
nn.SiLU(),
|
650 |
+
nn.Dropout(p=dropout),
|
651 |
+
zero_module(
|
652 |
+
nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
|
653 |
+
)
|
654 |
+
|
655 |
+
if self.out_channels == channels:
|
656 |
+
self.skip_connection = nn.Identity()
|
657 |
+
elif use_conv:
|
658 |
+
self.skip_connection = conv_nd(
|
659 |
+
dims, channels, self.out_channels, 3, padding=1)
|
660 |
+
else:
|
661 |
+
self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)
|
662 |
+
|
663 |
+
if self.use_temporal_conv:
|
664 |
+
self.temopral_conv = TemporalConvBlock_v2(
|
665 |
+
self.out_channels,
|
666 |
+
self.out_channels,
|
667 |
+
dropout=0.1,
|
668 |
+
use_image_dataset=use_image_dataset)
|
669 |
+
|
670 |
+
def forward(self, x, emb, batch_size, variant_info=None):
|
671 |
+
"""
|
672 |
+
Apply the block to a Tensor, conditioned on a timestep embedding.
|
673 |
+
:param x: an [N x C x ...] Tensor of features.
|
674 |
+
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
675 |
+
:return: an [N x C x ...] Tensor of outputs.
|
676 |
+
"""
|
677 |
+
return self._forward(x, emb, batch_size, variant_info)
|
678 |
+
|
679 |
+
def _forward(self, x, emb, batch_size, variant_info):
|
680 |
+
if self.updown:
|
681 |
+
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
682 |
+
h = in_rest(x)
|
683 |
+
h = self.h_upd(h)
|
684 |
+
x = self.x_upd(x)
|
685 |
+
h = in_conv(h)
|
686 |
+
else:
|
687 |
+
h = self.in_layers(x)
|
688 |
+
emb_out = self.emb_layers(emb).type(h.dtype)
|
689 |
+
while len(emb_out.shape) < len(h.shape):
|
690 |
+
emb_out = emb_out[..., None]
|
691 |
+
if self.use_scale_shift_norm:
|
692 |
+
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
693 |
+
scale, shift = th.chunk(emb_out, 2, dim=1)
|
694 |
+
h = out_norm(h) * (1 + scale) + shift
|
695 |
+
h = out_rest(h)
|
696 |
+
else:
|
697 |
+
h = h + emb_out
|
698 |
+
h = self.out_layers(h)
|
699 |
+
h = self.skip_connection(x) + h
|
700 |
+
|
701 |
+
if self.use_temporal_conv:
|
702 |
+
h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size)
|
703 |
+
h = self.temopral_conv(h, variant_info=variant_info)
|
704 |
+
h = rearrange(h, 'b c f h w -> (b f) c h w')
|
705 |
+
return h
|
706 |
+
|
707 |
+
|
708 |
+
class Downsample(nn.Module):
|
709 |
+
"""
|
710 |
+
A downsampling layer with an optional convolution.
|
711 |
+
:param channels: channels in the inputs and outputs.
|
712 |
+
:param use_conv: a bool determining if a convolution is applied.
|
713 |
+
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
714 |
+
downsampling occurs in the inner-two dimensions.
|
715 |
+
"""
|
716 |
+
|
717 |
+
def __init__(self,
|
718 |
+
channels,
|
719 |
+
use_conv,
|
720 |
+
dims=2,
|
721 |
+
out_channels=None,
|
722 |
+
padding=(2, 1)):
|
723 |
+
super().__init__()
|
724 |
+
self.channels = channels
|
725 |
+
self.out_channels = out_channels or channels
|
726 |
+
self.use_conv = use_conv
|
727 |
+
self.dims = dims
|
728 |
+
stride = 2 if dims != 3 else (1, 2, 2)
|
729 |
+
if use_conv:
|
730 |
+
self.op = nn.Conv2d(
|
731 |
+
self.channels,
|
732 |
+
self.out_channels,
|
733 |
+
3,
|
734 |
+
stride=stride,
|
735 |
+
padding=padding)
|
736 |
+
else:
|
737 |
+
assert self.channels == self.out_channels
|
738 |
+
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
739 |
+
|
740 |
+
def forward(self, x):
|
741 |
+
assert x.shape[1] == self.channels
|
742 |
+
return self.op(x)
|
743 |
+
|
744 |
+
|
745 |
+
class Resample(nn.Module):
|
746 |
+
|
747 |
+
def __init__(self, in_dim, out_dim, mode):
|
748 |
+
assert mode in ['none', 'upsample', 'downsample']
|
749 |
+
super(Resample, self).__init__()
|
750 |
+
self.in_dim = in_dim
|
751 |
+
self.out_dim = out_dim
|
752 |
+
self.mode = mode
|
753 |
+
|
754 |
+
def forward(self, x, reference=None):
|
755 |
+
if self.mode == 'upsample':
|
756 |
+
assert reference is not None
|
757 |
+
x = F.interpolate(x, size=reference.shape[-2:], mode='nearest')
|
758 |
+
elif self.mode == 'downsample':
|
759 |
+
x = F.adaptive_avg_pool2d(
|
760 |
+
x, output_size=tuple(u // 2 for u in x.shape[-2:]))
|
761 |
+
return x
|
762 |
+
|
763 |
+
|
764 |
+
class ResidualBlock(nn.Module):
|
765 |
+
|
766 |
+
def __init__(self,
|
767 |
+
in_dim,
|
768 |
+
embed_dim,
|
769 |
+
out_dim,
|
770 |
+
use_scale_shift_norm=True,
|
771 |
+
mode='none',
|
772 |
+
dropout=0.0):
|
773 |
+
super(ResidualBlock, self).__init__()
|
774 |
+
self.in_dim = in_dim
|
775 |
+
self.embed_dim = embed_dim
|
776 |
+
self.out_dim = out_dim
|
777 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
778 |
+
self.mode = mode
|
779 |
+
|
780 |
+
# layers
|
781 |
+
self.layer1 = nn.Sequential(
|
782 |
+
nn.GroupNorm(32, in_dim), nn.SiLU(),
|
783 |
+
nn.Conv2d(in_dim, out_dim, 3, padding=1))
|
784 |
+
self.resample = Resample(in_dim, in_dim, mode)
|
785 |
+
self.embedding = nn.Sequential(
|
786 |
+
nn.SiLU(),
|
787 |
+
nn.Linear(embed_dim,
|
788 |
+
out_dim * 2 if use_scale_shift_norm else out_dim))
|
789 |
+
self.layer2 = nn.Sequential(
|
790 |
+
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
|
791 |
+
nn.Conv2d(out_dim, out_dim, 3, padding=1))
|
792 |
+
self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(
|
793 |
+
in_dim, out_dim, 1)
|
794 |
+
|
795 |
+
# zero out the last layer params
|
796 |
+
nn.init.zeros_(self.layer2[-1].weight)
|
797 |
+
|
798 |
+
def forward(self, x, e, reference=None):
|
799 |
+
identity = self.resample(x, reference)
|
800 |
+
x = self.layer1[-1](self.resample(self.layer1[:-1](x), reference))
|
801 |
+
e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype)
|
802 |
+
if self.use_scale_shift_norm:
|
803 |
+
scale, shift = e.chunk(2, dim=1)
|
804 |
+
x = self.layer2[0](x) * (1 + scale) + shift
|
805 |
+
x = self.layer2[1:](x)
|
806 |
+
else:
|
807 |
+
x = x + e
|
808 |
+
x = self.layer2(x)
|
809 |
+
x = x + self.shortcut(identity)
|
810 |
+
return x
|
811 |
+
|
812 |
+
|
813 |
+
class AttentionBlock(nn.Module):
|
814 |
+
|
815 |
+
def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
|
816 |
+
# consider head_dim first, then num_heads
|
817 |
+
num_heads = dim // head_dim if head_dim else num_heads
|
818 |
+
head_dim = dim // num_heads
|
819 |
+
assert num_heads * head_dim == dim
|
820 |
+
super(AttentionBlock, self).__init__()
|
821 |
+
self.dim = dim
|
822 |
+
self.context_dim = context_dim
|
823 |
+
self.num_heads = num_heads
|
824 |
+
self.head_dim = head_dim
|
825 |
+
self.scale = math.pow(head_dim, -0.25)
|
826 |
+
|
827 |
+
# layers
|
828 |
+
self.norm = nn.GroupNorm(32, dim)
|
829 |
+
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
830 |
+
if context_dim is not None:
|
831 |
+
self.context_kv = nn.Linear(context_dim, dim * 2)
|
832 |
+
self.proj = nn.Conv2d(dim, dim, 1)
|
833 |
+
|
834 |
+
# zero out the last layer params
|
835 |
+
nn.init.zeros_(self.proj.weight)
|
836 |
+
|
837 |
+
def forward(self, x, context=None):
|
838 |
+
r"""x: [B, C, H, W].
|
839 |
+
context: [B, L, C] or None.
|
840 |
+
"""
|
841 |
+
identity = x
|
842 |
+
b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim
|
843 |
+
|
844 |
+
# compute query, key, value
|
845 |
+
x = self.norm(x)
|
846 |
+
q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
|
847 |
+
if context is not None:
|
848 |
+
ck, cv = self.context_kv(context).reshape(b, -1, n * 2,
|
849 |
+
d).permute(0, 2, 3,
|
850 |
+
1).chunk(
|
851 |
+
2, dim=1)
|
852 |
+
k = torch.cat([ck, k], dim=-1)
|
853 |
+
v = torch.cat([cv, v], dim=-1)
|
854 |
+
|
855 |
+
# compute attention
|
856 |
+
attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)
|
857 |
+
attn = F.softmax(attn, dim=-1)
|
858 |
+
|
859 |
+
# gather context
|
860 |
+
x = torch.matmul(v, attn.transpose(-1, -2))
|
861 |
+
x = x.reshape(b, c, h, w)
|
862 |
+
|
863 |
+
# output
|
864 |
+
x = self.proj(x)
|
865 |
+
return x + identity
|
866 |
+
|
867 |
+
|
868 |
+
class TemporalAttentionBlock(nn.Module):
|
869 |
+
|
870 |
+
def __init__(self,
|
871 |
+
dim,
|
872 |
+
heads=4,
|
873 |
+
dim_head=32,
|
874 |
+
rotary_emb=None,
|
875 |
+
use_image_dataset=False,
|
876 |
+
use_sim_mask=False):
|
877 |
+
super().__init__()
|
878 |
+
# consider num_heads first, as pos_bias needs fixed num_heads
|
879 |
+
dim_head = dim // heads
|
880 |
+
assert heads * dim_head == dim
|
881 |
+
self.use_image_dataset = use_image_dataset
|
882 |
+
self.use_sim_mask = use_sim_mask
|
883 |
+
|
884 |
+
self.scale = dim_head**-0.5
|
885 |
+
self.heads = heads
|
886 |
+
hidden_dim = dim_head * heads
|
887 |
+
|
888 |
+
self.norm = nn.GroupNorm(32, dim)
|
889 |
+
self.rotary_emb = rotary_emb
|
890 |
+
self.to_qkv = nn.Linear(dim, hidden_dim * 3)
|
891 |
+
self.to_out = nn.Linear(hidden_dim, dim)
|
892 |
+
|
893 |
+
def forward(self,
|
894 |
+
x,
|
895 |
+
pos_bias=None,
|
896 |
+
focus_present_mask=None,
|
897 |
+
video_mask=None):
|
898 |
+
|
899 |
+
identity = x
|
900 |
+
n, height, device = x.shape[2], x.shape[-2], x.device
|
901 |
+
|
902 |
+
x = self.norm(x)
|
903 |
+
x = rearrange(x, 'b c f h w -> b (h w) f c')
|
904 |
+
|
905 |
+
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
906 |
+
|
907 |
+
if exists(focus_present_mask) and focus_present_mask.all():
|
908 |
+
# if all batch samples are focusing on present
|
909 |
+
# it would be equivalent to passing that token's values (v=qkv[-1]) through to the output
|
910 |
+
values = qkv[-1]
|
911 |
+
out = self.to_out(values)
|
912 |
+
out = rearrange(out, 'b (h w) f c -> b c f h w', h=height)
|
913 |
+
|
914 |
+
return out + identity
|
915 |
+
|
916 |
+
# split out heads
|
917 |
+
q = rearrange(qkv[0], '... n (h d) -> ... h n d', h=self.heads)
|
918 |
+
k = rearrange(qkv[1], '... n (h d) -> ... h n d', h=self.heads)
|
919 |
+
v = rearrange(qkv[2], '... n (h d) -> ... h n d', h=self.heads)
|
920 |
+
|
921 |
+
# scale
|
922 |
+
|
923 |
+
q = q * self.scale
|
924 |
+
|
925 |
+
# rotate positions into queries and keys for time attention
|
926 |
+
if exists(self.rotary_emb):
|
927 |
+
q = self.rotary_emb.rotate_queries_or_keys(q)
|
928 |
+
k = self.rotary_emb.rotate_queries_or_keys(k)
|
929 |
+
|
930 |
+
# similarity
|
931 |
+
# shape [b (hw) h n n], n=f
|
932 |
+
sim = torch.einsum('... h i d, ... h j d -> ... h i j', q, k)
|
933 |
+
|
934 |
+
# relative positional bias
|
935 |
+
|
936 |
+
if exists(pos_bias):
|
937 |
+
sim = sim + pos_bias
|
938 |
+
|
939 |
+
if (focus_present_mask is None and video_mask is not None):
|
940 |
+
# video_mask: [B, n]
|
941 |
+
mask = video_mask[:, None, :] * video_mask[:, :, None]
|
942 |
+
mask = mask.unsqueeze(1).unsqueeze(1)
|
943 |
+
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
|
944 |
+
elif exists(focus_present_mask) and not (~focus_present_mask).all():
|
945 |
+
attend_all_mask = torch.ones((n, n),
|
946 |
+
device=device,
|
947 |
+
dtype=torch.bool)
|
948 |
+
attend_self_mask = torch.eye(n, device=device, dtype=torch.bool)
|
949 |
+
|
950 |
+
mask = torch.where(
|
951 |
+
rearrange(focus_present_mask, 'b -> b 1 1 1 1'),
|
952 |
+
rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),
|
953 |
+
rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),
|
954 |
+
)
|
955 |
+
|
956 |
+
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
|
957 |
+
|
958 |
+
if self.use_sim_mask:
|
959 |
+
sim_mask = torch.tril(
|
960 |
+
torch.ones((n, n), device=device, dtype=torch.bool),
|
961 |
+
diagonal=0)
|
962 |
+
sim = sim.masked_fill(~sim_mask, -torch.finfo(sim.dtype).max)
|
963 |
+
|
964 |
+
# numerical stability
|
965 |
+
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
966 |
+
attn = sim.softmax(dim=-1)
|
967 |
+
|
968 |
+
# aggregate values
|
969 |
+
|
970 |
+
out = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v)
|
971 |
+
out = rearrange(out, '... h n d -> ... n (h d)')
|
972 |
+
out = self.to_out(out)
|
973 |
+
|
974 |
+
out = rearrange(out, 'b (h w) f c -> b c f h w', h=height)
|
975 |
+
|
976 |
+
if self.use_image_dataset:
|
977 |
+
out = identity + 0 * out
|
978 |
+
else:
|
979 |
+
out = identity + out
|
980 |
+
return out
|
981 |
+
|
982 |
+
|
983 |
+
class TemporalTransformer(nn.Module):
|
984 |
+
"""
|
985 |
+
Transformer block for image-like data.
|
986 |
+
First, project the input (aka embedding)
|
987 |
+
and reshape to b, t, d.
|
988 |
+
Then apply standard transformer action.
|
989 |
+
Finally, reshape to image
|
990 |
+
"""
|
991 |
+
|
992 |
+
def __init__(self,
|
993 |
+
in_channels,
|
994 |
+
n_heads,
|
995 |
+
d_head,
|
996 |
+
depth=1,
|
997 |
+
dropout=0.,
|
998 |
+
context_dim=None,
|
999 |
+
disable_self_attn=False,
|
1000 |
+
use_linear=False,
|
1001 |
+
use_checkpoint=True,
|
1002 |
+
only_self_att=True,
|
1003 |
+
multiply_zero=False,
|
1004 |
+
is_ctrl=False):
|
1005 |
+
super().__init__()
|
1006 |
+
self.multiply_zero = multiply_zero
|
1007 |
+
self.only_self_att = only_self_att
|
1008 |
+
self.use_adaptor = False
|
1009 |
+
if self.only_self_att:
|
1010 |
+
context_dim = None
|
1011 |
+
if not isinstance(context_dim, list):
|
1012 |
+
context_dim = [context_dim]
|
1013 |
+
self.in_channels = in_channels
|
1014 |
+
inner_dim = n_heads * d_head
|
1015 |
+
self.norm = torch.nn.GroupNorm(
|
1016 |
+
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
1017 |
+
if not use_linear:
|
1018 |
+
self.proj_in = nn.Conv1d(
|
1019 |
+
in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
1020 |
+
else:
|
1021 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
1022 |
+
if self.use_adaptor:
|
1023 |
+
self.adaptor_in = nn.Linear(frames, frames)
|
1024 |
+
|
1025 |
+
self.transformer_blocks = nn.ModuleList([
|
1026 |
+
BasicTransformerBlock(
|
1027 |
+
inner_dim,
|
1028 |
+
n_heads,
|
1029 |
+
d_head,
|
1030 |
+
dropout=dropout,
|
1031 |
+
context_dim=context_dim[d],
|
1032 |
+
checkpoint=use_checkpoint,
|
1033 |
+
local_type='temp',
|
1034 |
+
is_ctrl=is_ctrl) for d in range(depth)
|
1035 |
+
])
|
1036 |
+
if not use_linear:
|
1037 |
+
self.proj_out = zero_module(
|
1038 |
+
nn.Conv1d(
|
1039 |
+
inner_dim, in_channels, kernel_size=1, stride=1,
|
1040 |
+
padding=0))
|
1041 |
+
else:
|
1042 |
+
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
|
1043 |
+
if self.use_adaptor:
|
1044 |
+
self.adaptor_out = nn.Linear(frames, frames)
|
1045 |
+
self.use_linear = use_linear
|
1046 |
+
|
1047 |
+
def forward(self, x, context=None):
|
1048 |
+
# note: if no context is given, cross-attention defaults to self-attention
|
1049 |
+
if self.only_self_att:
|
1050 |
+
context = None
|
1051 |
+
if not isinstance(context, list):
|
1052 |
+
context = [context]
|
1053 |
+
b, _, _, h, w = x.shape
|
1054 |
+
x_in = x
|
1055 |
+
x = self.norm(x)
|
1056 |
+
|
1057 |
+
if not self.use_linear:
|
1058 |
+
x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous()
|
1059 |
+
x = self.proj_in(x)
|
1060 |
+
if self.use_linear:
|
1061 |
+
x = rearrange(
|
1062 |
+
x, 'b c f h w -> (b h w) f c').contiguous()
|
1063 |
+
x = self.proj_in(x)
|
1064 |
+
x = rearrange(
|
1065 |
+
x, 'bhw f c -> bhw c f').contiguous()
|
1066 |
+
|
1067 |
+
# print('x shape:', x.shape) # [28800, 512, 32]
|
1068 |
+
if self.only_self_att: # no cross-attention
|
1069 |
+
x = rearrange(x, 'bhw c f -> bhw f c').contiguous()
|
1070 |
+
for i, block in enumerate(self.transformer_blocks):
|
1071 |
+
x = block(x, h=h, w=w)
|
1072 |
+
# print('x shape:', x.shape) # [43200, 32, 512]
|
1073 |
+
x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
|
1074 |
+
else:
|
1075 |
+
x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous()
|
1076 |
+
for i, block in enumerate(self.transformer_blocks):
|
1077 |
+
context[i] = rearrange(
|
1078 |
+
context[i], '(b f) l con -> b f l con',
|
1079 |
+
f=self.frames).contiguous()
|
1080 |
+
# calculate each batch one by one (since number in shape could not greater then 65,535 for some package)
|
1081 |
+
for j in range(b):
|
1082 |
+
context_i_j = repeat(
|
1083 |
+
context[i][j],
|
1084 |
+
'f l con -> (f r) l con',
|
1085 |
+
r=(h * w) // self.frames,
|
1086 |
+
f=self.frames).contiguous()
|
1087 |
+
x[j] = block(x[j], context=context_i_j)
|
1088 |
+
|
1089 |
+
if self.use_linear:
|
1090 |
+
x = rearrange(x, 'b hw f c -> (b hw) f c').contiguous()
|
1091 |
+
x = self.proj_out(x)
|
1092 |
+
x = rearrange(
|
1093 |
+
x, '(b h w) f c -> b c f h w', b=b, h=h, w=w).contiguous()
|
1094 |
+
if not self.use_linear:
|
1095 |
+
# print('x shape:', x.shape) # [2, 21600, 32, 512]
|
1096 |
+
x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous()
|
1097 |
+
x = self.proj_out(x)
|
1098 |
+
x = rearrange(
|
1099 |
+
x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous()
|
1100 |
+
|
1101 |
+
if self.multiply_zero:
|
1102 |
+
x = 0.0 * x + x_in
|
1103 |
+
else:
|
1104 |
+
x = x + x_in
|
1105 |
+
return x
|
1106 |
+
|
1107 |
+
|
1108 |
+
class TemporalAttentionMultiBlock(nn.Module):
|
1109 |
+
|
1110 |
+
def __init__(
|
1111 |
+
self,
|
1112 |
+
dim,
|
1113 |
+
heads=4,
|
1114 |
+
dim_head=32,
|
1115 |
+
rotary_emb=None,
|
1116 |
+
use_image_dataset=False,
|
1117 |
+
use_sim_mask=False,
|
1118 |
+
temporal_attn_times=1,
|
1119 |
+
):
|
1120 |
+
super().__init__()
|
1121 |
+
self.att_layers = nn.ModuleList([
|
1122 |
+
TemporalAttentionBlock(dim, heads, dim_head, rotary_emb,
|
1123 |
+
use_image_dataset, use_sim_mask)
|
1124 |
+
for _ in range(temporal_attn_times)
|
1125 |
+
])
|
1126 |
+
|
1127 |
+
def forward(self,
|
1128 |
+
x,
|
1129 |
+
pos_bias=None,
|
1130 |
+
focus_present_mask=None,
|
1131 |
+
video_mask=None):
|
1132 |
+
for layer in self.att_layers:
|
1133 |
+
x = layer(x, pos_bias, focus_present_mask, video_mask)
|
1134 |
+
return x
|
1135 |
+
|
1136 |
+
|
1137 |
+
class InitTemporalConvBlock(nn.Module):
|
1138 |
+
|
1139 |
+
def __init__(self,
|
1140 |
+
in_dim,
|
1141 |
+
out_dim=None,
|
1142 |
+
dropout=0.0,
|
1143 |
+
use_image_dataset=False):
|
1144 |
+
super(InitTemporalConvBlock, self).__init__()
|
1145 |
+
if out_dim is None:
|
1146 |
+
out_dim = in_dim
|
1147 |
+
self.in_dim = in_dim
|
1148 |
+
self.out_dim = out_dim
|
1149 |
+
self.use_image_dataset = use_image_dataset
|
1150 |
+
|
1151 |
+
# conv layers
|
1152 |
+
self.conv = nn.Sequential(
|
1153 |
+
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
|
1154 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
|
1155 |
+
|
1156 |
+
# zero out the last layer params,so the conv block is identity
|
1157 |
+
nn.init.zeros_(self.conv[-1].weight)
|
1158 |
+
nn.init.zeros_(self.conv[-1].bias)
|
1159 |
+
|
1160 |
+
def forward(self, x):
|
1161 |
+
identity = x
|
1162 |
+
x = self.conv(x)
|
1163 |
+
if self.use_image_dataset:
|
1164 |
+
x = identity + 0 * x
|
1165 |
+
else:
|
1166 |
+
x = identity + x
|
1167 |
+
return x
|
1168 |
+
|
1169 |
+
|
1170 |
+
class TemporalConvBlock(nn.Module):
|
1171 |
+
|
1172 |
+
def __init__(self,
|
1173 |
+
in_dim,
|
1174 |
+
out_dim=None,
|
1175 |
+
dropout=0.0,
|
1176 |
+
use_image_dataset=False):
|
1177 |
+
super(TemporalConvBlock, self).__init__()
|
1178 |
+
if out_dim is None:
|
1179 |
+
out_dim = in_dim
|
1180 |
+
self.in_dim = in_dim
|
1181 |
+
self.out_dim = out_dim
|
1182 |
+
self.use_image_dataset = use_image_dataset
|
1183 |
+
|
1184 |
+
# conv layers
|
1185 |
+
self.conv1 = nn.Sequential(
|
1186 |
+
nn.GroupNorm(32, in_dim), nn.SiLU(),
|
1187 |
+
nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
|
1188 |
+
self.conv2 = nn.Sequential(
|
1189 |
+
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
|
1190 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
|
1191 |
+
|
1192 |
+
# zero out the last layer params,so the conv block is identity
|
1193 |
+
nn.init.zeros_(self.conv2[-1].weight)
|
1194 |
+
nn.init.zeros_(self.conv2[-1].bias)
|
1195 |
+
|
1196 |
+
def forward(self, x):
|
1197 |
+
identity = x
|
1198 |
+
x = self.conv1(x)
|
1199 |
+
x = self.conv2(x)
|
1200 |
+
if self.use_image_dataset:
|
1201 |
+
x = identity + 0 * x
|
1202 |
+
else:
|
1203 |
+
x = identity + x
|
1204 |
+
return x
|
1205 |
+
|
1206 |
+
|
1207 |
+
class TemporalConvBlock_v2(nn.Module):
|
1208 |
+
|
1209 |
+
def __init__(self,
|
1210 |
+
in_dim,
|
1211 |
+
out_dim=None,
|
1212 |
+
dropout=0.0,
|
1213 |
+
use_image_dataset=False):
|
1214 |
+
super(TemporalConvBlock_v2, self).__init__()
|
1215 |
+
if out_dim is None:
|
1216 |
+
out_dim = in_dim
|
1217 |
+
self.in_dim = in_dim
|
1218 |
+
self.out_dim = out_dim
|
1219 |
+
self.use_image_dataset = use_image_dataset
|
1220 |
+
|
1221 |
+
# conv layers
|
1222 |
+
self.conv1 = nn.Sequential(
|
1223 |
+
nn.GroupNorm(32, in_dim), nn.SiLU(),
|
1224 |
+
nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
|
1225 |
+
self.conv2 = nn.Sequential(
|
1226 |
+
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
|
1227 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
|
1228 |
+
self.conv3 = nn.Sequential(
|
1229 |
+
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
|
1230 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
|
1231 |
+
self.conv4 = nn.Sequential(
|
1232 |
+
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
|
1233 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
|
1234 |
+
|
1235 |
+
# zero out the last layer params,so the conv block is identity
|
1236 |
+
nn.init.zeros_(self.conv4[-1].weight)
|
1237 |
+
nn.init.zeros_(self.conv4[-1].bias)
|
1238 |
+
|
1239 |
+
def forward(self, x, variant_info=None):
|
1240 |
+
if variant_info is not None and variant_info.get('type') == 'variant2':
|
1241 |
+
# print(x.shape) # torch.Size([1, 320, 32, 90, 160])
|
1242 |
+
_, _, f, _, _ = x.shape
|
1243 |
+
assert f % 4 == 0, "f must be divisible by 4"
|
1244 |
+
x_short = rearrange(x, "b c (n s) h w -> (n b) c s h w", n=4)
|
1245 |
+
x_short = self.conv1(x_short)
|
1246 |
+
x_short = self.conv2(x_short)
|
1247 |
+
x_short = self.conv3(x_short)
|
1248 |
+
x_short = self.conv4(x_short)
|
1249 |
+
x_short = rearrange(x_short, "(n b) c s h w -> b c (n s) h w", n=4)
|
1250 |
+
|
1251 |
+
identity = x
|
1252 |
+
x = self.conv1(x)
|
1253 |
+
x = self.conv2(x)
|
1254 |
+
x = self.conv3(x)
|
1255 |
+
x = self.conv4(x)
|
1256 |
+
|
1257 |
+
x = x * (1-variant_info['alpha']) + x_short * variant_info['alpha']
|
1258 |
+
|
1259 |
+
|
1260 |
+
elif variant_info is not None and variant_info.get('type') == 'variant1':
|
1261 |
+
identity = x
|
1262 |
+
x_long, x_short = x.chunk(2, dim=0)
|
1263 |
+
|
1264 |
+
x_short = rearrange(x_short, "b c (n s) h w -> (n b) c s h w", n=4)
|
1265 |
+
x_short = self.conv1(x_short)
|
1266 |
+
x_short = self.conv2(x_short)
|
1267 |
+
x_short = self.conv3(x_short)
|
1268 |
+
x_short = self.conv4(x_short)
|
1269 |
+
x_short = rearrange(x_short, "(n b) c s h w -> b c (n s) h w", n=4)
|
1270 |
+
|
1271 |
+
x_long = self.conv1(x_long)
|
1272 |
+
x_long = self.conv2(x_long)
|
1273 |
+
x_long = self.conv3(x_long)
|
1274 |
+
x_long = self.conv4(x_long)
|
1275 |
+
|
1276 |
+
x = torch.cat([x_long, x_short], dim=0)
|
1277 |
+
|
1278 |
+
|
1279 |
+
elif variant_info is None:
|
1280 |
+
identity = x
|
1281 |
+
x = self.conv1(x)
|
1282 |
+
x = self.conv2(x)
|
1283 |
+
x = self.conv3(x)
|
1284 |
+
x = self.conv4(x)
|
1285 |
+
|
1286 |
+
|
1287 |
+
if self.use_image_dataset:
|
1288 |
+
x = identity + 0.0 * x
|
1289 |
+
else:
|
1290 |
+
x = identity + x
|
1291 |
+
return x
|
1292 |
+
|
1293 |
+
|
1294 |
+
class Vid2VidSDUNet(nn.Module):
|
1295 |
+
|
1296 |
+
def __init__(self,
|
1297 |
+
in_dim=4,
|
1298 |
+
dim=320,
|
1299 |
+
y_dim=1024,
|
1300 |
+
context_dim=1024,
|
1301 |
+
out_dim=4,
|
1302 |
+
dim_mult=[1, 2, 4, 4],
|
1303 |
+
num_heads=8,
|
1304 |
+
head_dim=64,
|
1305 |
+
num_res_blocks=2,
|
1306 |
+
attn_scales=[1 / 1, 1 / 2, 1 / 4],
|
1307 |
+
use_scale_shift_norm=True,
|
1308 |
+
dropout=0.1,
|
1309 |
+
temporal_attn_times=1,
|
1310 |
+
temporal_attention=True,
|
1311 |
+
use_checkpoint=True,
|
1312 |
+
use_image_dataset=False,
|
1313 |
+
use_fps_condition=False,
|
1314 |
+
use_sim_mask=False,
|
1315 |
+
training=False,
|
1316 |
+
inpainting=True):
|
1317 |
+
embed_dim = dim * 4
|
1318 |
+
num_heads = num_heads if num_heads else dim // 32
|
1319 |
+
super(Vid2VidSDUNet, self).__init__()
|
1320 |
+
self.in_dim = in_dim
|
1321 |
+
self.dim = dim
|
1322 |
+
self.y_dim = y_dim
|
1323 |
+
self.context_dim = context_dim
|
1324 |
+
self.embed_dim = embed_dim
|
1325 |
+
self.out_dim = out_dim
|
1326 |
+
self.dim_mult = dim_mult
|
1327 |
+
# for temporal attention
|
1328 |
+
self.num_heads = num_heads
|
1329 |
+
# for spatial attention
|
1330 |
+
self.head_dim = head_dim
|
1331 |
+
self.num_res_blocks = num_res_blocks
|
1332 |
+
self.attn_scales = attn_scales
|
1333 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
1334 |
+
self.temporal_attn_times = temporal_attn_times
|
1335 |
+
self.temporal_attention = temporal_attention
|
1336 |
+
self.use_checkpoint = use_checkpoint
|
1337 |
+
self.use_image_dataset = use_image_dataset
|
1338 |
+
self.use_fps_condition = use_fps_condition
|
1339 |
+
self.use_sim_mask = use_sim_mask
|
1340 |
+
self.training = training
|
1341 |
+
self.inpainting = inpainting
|
1342 |
+
|
1343 |
+
use_linear_in_temporal = False
|
1344 |
+
transformer_depth = 1
|
1345 |
+
disabled_sa = False
|
1346 |
+
# params
|
1347 |
+
enc_dims = [dim * u for u in [1] + dim_mult]
|
1348 |
+
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
1349 |
+
shortcut_dims = []
|
1350 |
+
scale = 1.0
|
1351 |
+
|
1352 |
+
# embeddings
|
1353 |
+
self.time_embed = nn.Sequential(
|
1354 |
+
nn.Linear(dim, embed_dim), nn.SiLU(),
|
1355 |
+
nn.Linear(embed_dim, embed_dim))
|
1356 |
+
|
1357 |
+
if self.use_fps_condition:
|
1358 |
+
self.fps_embedding = nn.Sequential(
|
1359 |
+
nn.Linear(dim, embed_dim), nn.SiLU(),
|
1360 |
+
nn.Linear(embed_dim, embed_dim))
|
1361 |
+
nn.init.zeros_(self.fps_embedding[-1].weight)
|
1362 |
+
nn.init.zeros_(self.fps_embedding[-1].bias)
|
1363 |
+
|
1364 |
+
# encoder
|
1365 |
+
self.input_blocks = nn.ModuleList()
|
1366 |
+
init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)])
|
1367 |
+
# need an initial temporal attention?
|
1368 |
+
if temporal_attention:
|
1369 |
+
if USE_TEMPORAL_TRANSFORMER:
|
1370 |
+
init_block.append(
|
1371 |
+
TemporalTransformer(
|
1372 |
+
dim,
|
1373 |
+
num_heads,
|
1374 |
+
head_dim,
|
1375 |
+
depth=transformer_depth,
|
1376 |
+
context_dim=context_dim,
|
1377 |
+
disable_self_attn=disabled_sa,
|
1378 |
+
use_linear=use_linear_in_temporal,
|
1379 |
+
multiply_zero=use_image_dataset,
|
1380 |
+
is_ctrl=True
|
1381 |
+
))
|
1382 |
+
else:
|
1383 |
+
init_block.append(
|
1384 |
+
TemporalAttentionMultiBlock(
|
1385 |
+
dim,
|
1386 |
+
num_heads,
|
1387 |
+
head_dim,
|
1388 |
+
rotary_emb=self.rotary_emb,
|
1389 |
+
temporal_attn_times=temporal_attn_times,
|
1390 |
+
use_image_dataset=use_image_dataset))
|
1391 |
+
self.input_blocks.append(init_block)
|
1392 |
+
shortcut_dims.append(dim)
|
1393 |
+
for i, (in_dim,
|
1394 |
+
out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
|
1395 |
+
for j in range(num_res_blocks):
|
1396 |
+
block = nn.ModuleList([
|
1397 |
+
ResBlock(
|
1398 |
+
in_dim,
|
1399 |
+
embed_dim,
|
1400 |
+
dropout,
|
1401 |
+
out_channels=out_dim,
|
1402 |
+
use_scale_shift_norm=False,
|
1403 |
+
use_image_dataset=use_image_dataset,
|
1404 |
+
)
|
1405 |
+
])
|
1406 |
+
if scale in attn_scales:
|
1407 |
+
block.append(
|
1408 |
+
SpatialTransformer(
|
1409 |
+
out_dim,
|
1410 |
+
out_dim // head_dim,
|
1411 |
+
head_dim,
|
1412 |
+
depth=1,
|
1413 |
+
context_dim=self.context_dim,
|
1414 |
+
disable_self_attn=False,
|
1415 |
+
use_linear=True,
|
1416 |
+
is_ctrl=True
|
1417 |
+
))
|
1418 |
+
if self.temporal_attention:
|
1419 |
+
if USE_TEMPORAL_TRANSFORMER:
|
1420 |
+
block.append(
|
1421 |
+
TemporalTransformer(
|
1422 |
+
out_dim,
|
1423 |
+
out_dim // head_dim,
|
1424 |
+
head_dim,
|
1425 |
+
depth=transformer_depth,
|
1426 |
+
context_dim=context_dim,
|
1427 |
+
disable_self_attn=disabled_sa,
|
1428 |
+
use_linear=use_linear_in_temporal,
|
1429 |
+
multiply_zero=use_image_dataset,
|
1430 |
+
is_ctrl=True
|
1431 |
+
))
|
1432 |
+
else:
|
1433 |
+
block.append(
|
1434 |
+
TemporalAttentionMultiBlock(
|
1435 |
+
out_dim,
|
1436 |
+
num_heads,
|
1437 |
+
head_dim,
|
1438 |
+
rotary_emb=self.rotary_emb,
|
1439 |
+
use_image_dataset=use_image_dataset,
|
1440 |
+
use_sim_mask=use_sim_mask,
|
1441 |
+
temporal_attn_times=temporal_attn_times))
|
1442 |
+
in_dim = out_dim
|
1443 |
+
self.input_blocks.append(block)
|
1444 |
+
shortcut_dims.append(out_dim)
|
1445 |
+
|
1446 |
+
# downsample
|
1447 |
+
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
|
1448 |
+
downsample = Downsample(
|
1449 |
+
out_dim, True, dims=2, out_channels=out_dim)
|
1450 |
+
shortcut_dims.append(out_dim)
|
1451 |
+
scale /= 2.0
|
1452 |
+
self.input_blocks.append(downsample)
|
1453 |
+
|
1454 |
+
self.middle_block = nn.ModuleList([
|
1455 |
+
ResBlock(
|
1456 |
+
out_dim,
|
1457 |
+
embed_dim,
|
1458 |
+
dropout,
|
1459 |
+
use_scale_shift_norm=False,
|
1460 |
+
use_image_dataset=use_image_dataset,
|
1461 |
+
),
|
1462 |
+
SpatialTransformer(
|
1463 |
+
out_dim,
|
1464 |
+
out_dim // head_dim,
|
1465 |
+
head_dim,
|
1466 |
+
depth=1,
|
1467 |
+
context_dim=self.context_dim,
|
1468 |
+
disable_self_attn=False,
|
1469 |
+
use_linear=True,
|
1470 |
+
is_ctrl=True
|
1471 |
+
)
|
1472 |
+
])
|
1473 |
+
|
1474 |
+
if self.temporal_attention:
|
1475 |
+
if USE_TEMPORAL_TRANSFORMER:
|
1476 |
+
self.middle_block.append(
|
1477 |
+
TemporalTransformer(
|
1478 |
+
out_dim,
|
1479 |
+
out_dim // head_dim,
|
1480 |
+
head_dim,
|
1481 |
+
depth=transformer_depth,
|
1482 |
+
context_dim=context_dim,
|
1483 |
+
disable_self_attn=disabled_sa,
|
1484 |
+
use_linear=use_linear_in_temporal,
|
1485 |
+
multiply_zero=use_image_dataset,
|
1486 |
+
is_ctrl=True
|
1487 |
+
|
1488 |
+
))
|
1489 |
+
else:
|
1490 |
+
self.middle_block.append(
|
1491 |
+
TemporalAttentionMultiBlock(
|
1492 |
+
out_dim,
|
1493 |
+
num_heads,
|
1494 |
+
head_dim,
|
1495 |
+
rotary_emb=self.rotary_emb,
|
1496 |
+
use_image_dataset=use_image_dataset,
|
1497 |
+
use_sim_mask=use_sim_mask,
|
1498 |
+
temporal_attn_times=temporal_attn_times))
|
1499 |
+
|
1500 |
+
self.middle_block.append(
|
1501 |
+
ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False))
|
1502 |
+
|
1503 |
+
# decoder
|
1504 |
+
self.output_blocks = nn.ModuleList()
|
1505 |
+
for i, (in_dim,
|
1506 |
+
out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
|
1507 |
+
for j in range(num_res_blocks + 1):
|
1508 |
+
block = nn.ModuleList([
|
1509 |
+
ResBlock(
|
1510 |
+
in_dim + shortcut_dims.pop(),
|
1511 |
+
embed_dim,
|
1512 |
+
dropout,
|
1513 |
+
out_dim,
|
1514 |
+
use_scale_shift_norm=False,
|
1515 |
+
use_image_dataset=use_image_dataset,
|
1516 |
+
)
|
1517 |
+
])
|
1518 |
+
if scale in attn_scales:
|
1519 |
+
block.append(
|
1520 |
+
SpatialTransformer(
|
1521 |
+
out_dim,
|
1522 |
+
out_dim // head_dim,
|
1523 |
+
head_dim,
|
1524 |
+
depth=1,
|
1525 |
+
context_dim=1024,
|
1526 |
+
disable_self_attn=False,
|
1527 |
+
use_linear=True,
|
1528 |
+
is_ctrl=True))
|
1529 |
+
if self.temporal_attention:
|
1530 |
+
if USE_TEMPORAL_TRANSFORMER:
|
1531 |
+
block.append(
|
1532 |
+
TemporalTransformer(
|
1533 |
+
out_dim,
|
1534 |
+
out_dim // head_dim,
|
1535 |
+
head_dim,
|
1536 |
+
depth=transformer_depth,
|
1537 |
+
context_dim=context_dim,
|
1538 |
+
disable_self_attn=disabled_sa,
|
1539 |
+
use_linear=use_linear_in_temporal,
|
1540 |
+
multiply_zero=use_image_dataset,
|
1541 |
+
is_ctrl=True))
|
1542 |
+
else:
|
1543 |
+
block.append(
|
1544 |
+
TemporalAttentionMultiBlock(
|
1545 |
+
out_dim,
|
1546 |
+
num_heads,
|
1547 |
+
head_dim,
|
1548 |
+
rotary_emb=self.rotary_emb,
|
1549 |
+
use_image_dataset=use_image_dataset,
|
1550 |
+
use_sim_mask=use_sim_mask,
|
1551 |
+
temporal_attn_times=temporal_attn_times))
|
1552 |
+
in_dim = out_dim
|
1553 |
+
|
1554 |
+
# upsample
|
1555 |
+
if i != len(dim_mult) - 1 and j == num_res_blocks:
|
1556 |
+
upsample = Upsample(
|
1557 |
+
out_dim, True, dims=2.0, out_channels=out_dim)
|
1558 |
+
scale *= 2.0
|
1559 |
+
block.append(upsample)
|
1560 |
+
self.output_blocks.append(block)
|
1561 |
+
|
1562 |
+
# head
|
1563 |
+
self.out = nn.Sequential(
|
1564 |
+
nn.GroupNorm(32, out_dim), nn.SiLU(),
|
1565 |
+
nn.Conv2d(out_dim, self.out_dim, 3, padding=1))
|
1566 |
+
|
1567 |
+
# zero out the last layer params
|
1568 |
+
nn.init.zeros_(self.out[-1].weight)
|
1569 |
+
|
1570 |
+
def forward(self,
|
1571 |
+
x,
|
1572 |
+
t,
|
1573 |
+
y,
|
1574 |
+
x_lr=None,
|
1575 |
+
fps=None,
|
1576 |
+
video_mask=None,
|
1577 |
+
focus_present_mask=None,
|
1578 |
+
prob_focus_present=0.,
|
1579 |
+
mask_last_frame_num=0):
|
1580 |
+
|
1581 |
+
batch, c, f, h, w = x.shape
|
1582 |
+
device = x.device
|
1583 |
+
self.batch = batch
|
1584 |
+
|
1585 |
+
# image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
|
1586 |
+
if mask_last_frame_num > 0:
|
1587 |
+
focus_present_mask = None
|
1588 |
+
video_mask[-mask_last_frame_num:] = False
|
1589 |
+
else:
|
1590 |
+
focus_present_mask = default(
|
1591 |
+
focus_present_mask, lambda: prob_mask_like(
|
1592 |
+
(batch, ), prob_focus_present, device=device))
|
1593 |
+
|
1594 |
+
if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
|
1595 |
+
time_rel_pos_bias = self.time_rel_pos_bias(
|
1596 |
+
x.shape[2], device=x.device)
|
1597 |
+
else:
|
1598 |
+
time_rel_pos_bias = None
|
1599 |
+
|
1600 |
+
# embeddings
|
1601 |
+
e = self.time_embed(sinusoidal_embedding(t, self.dim))
|
1602 |
+
context = y
|
1603 |
+
|
1604 |
+
# repeat f times for spatial e and context
|
1605 |
+
e = e.repeat_interleave(repeats=f, dim=0)
|
1606 |
+
context = context.repeat_interleave(repeats=f, dim=0)
|
1607 |
+
|
1608 |
+
# always in shape (b f) c h w, except for temporal layer
|
1609 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
1610 |
+
# encoder
|
1611 |
+
xs = []
|
1612 |
+
for ind, block in enumerate(self.input_blocks):
|
1613 |
+
x = self._forward_single(block, x, e, context, time_rel_pos_bias,
|
1614 |
+
focus_present_mask, video_mask)
|
1615 |
+
xs.append(x)
|
1616 |
+
|
1617 |
+
# middle
|
1618 |
+
for block in self.middle_block:
|
1619 |
+
x = self._forward_single(block, x, e, context, time_rel_pos_bias,
|
1620 |
+
focus_present_mask, video_mask)
|
1621 |
+
|
1622 |
+
# decoder
|
1623 |
+
for block in self.output_blocks:
|
1624 |
+
x = torch.cat([x, xs.pop()], dim=1)
|
1625 |
+
x = self._forward_single(
|
1626 |
+
block,
|
1627 |
+
x,
|
1628 |
+
e,
|
1629 |
+
context,
|
1630 |
+
time_rel_pos_bias,
|
1631 |
+
focus_present_mask,
|
1632 |
+
video_mask,
|
1633 |
+
reference=xs[-1] if len(xs) > 0 else None)
|
1634 |
+
|
1635 |
+
# head
|
1636 |
+
x = self.out(x)
|
1637 |
+
|
1638 |
+
# reshape back to (b c f h w)
|
1639 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b=batch)
|
1640 |
+
return x
|
1641 |
+
|
1642 |
+
def _forward_single(self,
|
1643 |
+
module,
|
1644 |
+
x,
|
1645 |
+
e,
|
1646 |
+
context,
|
1647 |
+
time_rel_pos_bias,
|
1648 |
+
focus_present_mask,
|
1649 |
+
video_mask,
|
1650 |
+
reference=None):
|
1651 |
+
if isinstance(module, ResidualBlock):
|
1652 |
+
module = checkpoint_wrapper(
|
1653 |
+
module) if self.use_checkpoint else module
|
1654 |
+
x = x.contiguous()
|
1655 |
+
x = module(x, e, reference)
|
1656 |
+
elif isinstance(module, ResBlock):
|
1657 |
+
module = checkpoint_wrapper(
|
1658 |
+
module) if self.use_checkpoint else module
|
1659 |
+
x = x.contiguous()
|
1660 |
+
x = module(x, e, self.batch)
|
1661 |
+
elif isinstance(module, SpatialTransformer):
|
1662 |
+
module = checkpoint_wrapper(
|
1663 |
+
module) if self.use_checkpoint else module
|
1664 |
+
x = module(x, context)
|
1665 |
+
elif isinstance(module, TemporalTransformer):
|
1666 |
+
module = checkpoint_wrapper(
|
1667 |
+
module) if self.use_checkpoint else module
|
1668 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
1669 |
+
x = module(x, context)
|
1670 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
1671 |
+
elif isinstance(module, CrossAttention):
|
1672 |
+
module = checkpoint_wrapper(
|
1673 |
+
module) if self.use_checkpoint else module
|
1674 |
+
x = module(x, context)
|
1675 |
+
elif isinstance(module, MemoryEfficientCrossAttention):
|
1676 |
+
module = checkpoint_wrapper(
|
1677 |
+
module) if self.use_checkpoint else module
|
1678 |
+
x = module(x, context)
|
1679 |
+
elif isinstance(module, BasicTransformerBlock):
|
1680 |
+
module = checkpoint_wrapper(
|
1681 |
+
module) if self.use_checkpoint else module
|
1682 |
+
x = module(x, context)
|
1683 |
+
elif isinstance(module, FeedForward):
|
1684 |
+
x = module(x, context)
|
1685 |
+
elif isinstance(module, Upsample):
|
1686 |
+
x = module(x)
|
1687 |
+
elif isinstance(module, Downsample):
|
1688 |
+
x = module(x)
|
1689 |
+
elif isinstance(module, Resample):
|
1690 |
+
x = module(x, reference)
|
1691 |
+
elif isinstance(module, TemporalAttentionBlock):
|
1692 |
+
module = checkpoint_wrapper(
|
1693 |
+
module) if self.use_checkpoint else module
|
1694 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
1695 |
+
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
|
1696 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
1697 |
+
elif isinstance(module, TemporalAttentionMultiBlock):
|
1698 |
+
module = checkpoint_wrapper(
|
1699 |
+
module) if self.use_checkpoint else module
|
1700 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
1701 |
+
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
|
1702 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
1703 |
+
elif isinstance(module, InitTemporalConvBlock):
|
1704 |
+
module = checkpoint_wrapper(
|
1705 |
+
module) if self.use_checkpoint else module
|
1706 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
1707 |
+
x = module(x)
|
1708 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
1709 |
+
elif isinstance(module, TemporalConvBlock):
|
1710 |
+
module = checkpoint_wrapper(
|
1711 |
+
module) if self.use_checkpoint else module
|
1712 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
1713 |
+
x = module(x)
|
1714 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
1715 |
+
elif isinstance(module, nn.ModuleList):
|
1716 |
+
for block in module:
|
1717 |
+
x = self._forward_single(block, x, e, context,
|
1718 |
+
time_rel_pos_bias, focus_present_mask,
|
1719 |
+
video_mask, reference)
|
1720 |
+
else:
|
1721 |
+
x = module(x)
|
1722 |
+
return x
|
1723 |
+
|
1724 |
+
|
1725 |
+
class ControlledV2VUNet(Vid2VidSDUNet):
|
1726 |
+
def __init__(self):
|
1727 |
+
super(ControlledV2VUNet, self).__init__()
|
1728 |
+
self.VideoControlNet = VideoControlNet()
|
1729 |
+
|
1730 |
+
def forward(self,
|
1731 |
+
x,
|
1732 |
+
t,
|
1733 |
+
y,
|
1734 |
+
hint=None,
|
1735 |
+
variant_info=None,
|
1736 |
+
hint_chunk=None,
|
1737 |
+
t_hint=None,
|
1738 |
+
s_cond=None,
|
1739 |
+
mask_cond=None,
|
1740 |
+
x_lr=None,
|
1741 |
+
fps=None,
|
1742 |
+
mask=None,
|
1743 |
+
video_mask=None,
|
1744 |
+
focus_present_mask=None,
|
1745 |
+
prob_focus_present=0.,
|
1746 |
+
mask_last_frame_num=0,
|
1747 |
+
):
|
1748 |
+
|
1749 |
+
batch, _, f, _, _= x.shape
|
1750 |
+
device = x.device
|
1751 |
+
self.batch = batch
|
1752 |
+
|
1753 |
+
# Process text (new added for t5 encoder)
|
1754 |
+
# y = self.VideoControlNet.y_embedder(y, self.training).squeeze(1) # [1, 1, 120, 4096] -> [B, 1, 120, 1024].squeeze(1) -> [B, 120, 1024]
|
1755 |
+
|
1756 |
+
if hint_chunk is not None:
|
1757 |
+
hint = hint_chunk
|
1758 |
+
|
1759 |
+
control = self.VideoControlNet(x, t, y, hint=hint, t_hint=t_hint, \
|
1760 |
+
mask_cond=mask_cond, s_cond=s_cond, \
|
1761 |
+
variant_info=variant_info)
|
1762 |
+
|
1763 |
+
# image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
|
1764 |
+
if mask_last_frame_num > 0:
|
1765 |
+
focus_present_mask = None
|
1766 |
+
video_mask[-mask_last_frame_num:] = False
|
1767 |
+
else:
|
1768 |
+
focus_present_mask = default(
|
1769 |
+
focus_present_mask, lambda: prob_mask_like(
|
1770 |
+
(batch, ), prob_focus_present, device=device))
|
1771 |
+
|
1772 |
+
if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
|
1773 |
+
time_rel_pos_bias = self.time_rel_pos_bias(
|
1774 |
+
x.shape[2], device=x.device)
|
1775 |
+
else:
|
1776 |
+
time_rel_pos_bias = None
|
1777 |
+
|
1778 |
+
e = self.time_embed(sinusoidal_embedding(t, self.dim))
|
1779 |
+
e = e.repeat_interleave(repeats=f, dim=0)
|
1780 |
+
|
1781 |
+
# context = y
|
1782 |
+
context = y.repeat_interleave(repeats=f, dim=0)
|
1783 |
+
|
1784 |
+
# always in shape (b f) c h w, except for temporal layer
|
1785 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
1786 |
+
# encoder
|
1787 |
+
xs = []
|
1788 |
+
for block in self.input_blocks:
|
1789 |
+
x = self._forward_single(block, x, e, context, time_rel_pos_bias,
|
1790 |
+
focus_present_mask, video_mask, variant_info=variant_info)
|
1791 |
+
xs.append(x)
|
1792 |
+
# middle
|
1793 |
+
for block in self.middle_block:
|
1794 |
+
x = self._forward_single(block, x, e, context, time_rel_pos_bias,
|
1795 |
+
focus_present_mask, video_mask, variant_info=variant_info)
|
1796 |
+
|
1797 |
+
if control is not None:
|
1798 |
+
x = control.pop() + x
|
1799 |
+
|
1800 |
+
# decoder
|
1801 |
+
for block in self.output_blocks:
|
1802 |
+
if control is None:
|
1803 |
+
x = torch.cat([x, xs.pop()], dim=1)
|
1804 |
+
else:
|
1805 |
+
x = torch.cat([x, xs.pop() + control.pop()], dim=1)
|
1806 |
+
x = self._forward_single(
|
1807 |
+
block,
|
1808 |
+
x,
|
1809 |
+
e,
|
1810 |
+
context,
|
1811 |
+
time_rel_pos_bias,
|
1812 |
+
focus_present_mask,
|
1813 |
+
video_mask,
|
1814 |
+
reference=xs[-1] if len(xs) > 0 else None,
|
1815 |
+
variant_info=variant_info)
|
1816 |
+
|
1817 |
+
# head
|
1818 |
+
x = self.out(x)
|
1819 |
+
|
1820 |
+
# reshape back to (b c f h w)
|
1821 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b=batch)
|
1822 |
+
return x
|
1823 |
+
|
1824 |
+
def _forward_single(self,
|
1825 |
+
module,
|
1826 |
+
x,
|
1827 |
+
e,
|
1828 |
+
context,
|
1829 |
+
time_rel_pos_bias,
|
1830 |
+
focus_present_mask,
|
1831 |
+
video_mask,
|
1832 |
+
reference=None,
|
1833 |
+
variant_info=None):
|
1834 |
+
variant_info = None # For Debug
|
1835 |
+
if isinstance(module, ResidualBlock):
|
1836 |
+
module = checkpoint_wrapper(
|
1837 |
+
module) if self.use_checkpoint else module
|
1838 |
+
x = x.contiguous()
|
1839 |
+
x = module(x, e, reference)
|
1840 |
+
elif isinstance(module, ResBlock):
|
1841 |
+
module = checkpoint_wrapper(
|
1842 |
+
module) if self.use_checkpoint else module
|
1843 |
+
x = x.contiguous()
|
1844 |
+
x = module(x, e, self.batch, variant_info)
|
1845 |
+
elif isinstance(module, SpatialTransformer):
|
1846 |
+
module = checkpoint_wrapper(
|
1847 |
+
module) if self.use_checkpoint else module
|
1848 |
+
x = module(x, context)
|
1849 |
+
elif isinstance(module, TemporalTransformer):
|
1850 |
+
module = checkpoint_wrapper(
|
1851 |
+
module) if self.use_checkpoint else module
|
1852 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
1853 |
+
x = module(x, context)
|
1854 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
1855 |
+
elif isinstance(module, CrossAttention):
|
1856 |
+
module = checkpoint_wrapper(
|
1857 |
+
module) if self.use_checkpoint else module
|
1858 |
+
x = module(x, context)
|
1859 |
+
elif isinstance(module, MemoryEfficientCrossAttention):
|
1860 |
+
module = checkpoint_wrapper(
|
1861 |
+
module) if self.use_checkpoint else module
|
1862 |
+
x = module(x, context)
|
1863 |
+
elif isinstance(module, BasicTransformerBlock):
|
1864 |
+
module = checkpoint_wrapper(
|
1865 |
+
module) if self.use_checkpoint else module
|
1866 |
+
x = module(x, context)
|
1867 |
+
elif isinstance(module, FeedForward):
|
1868 |
+
x = module(x, context)
|
1869 |
+
elif isinstance(module, Upsample):
|
1870 |
+
x = module(x)
|
1871 |
+
elif isinstance(module, Downsample):
|
1872 |
+
x = module(x)
|
1873 |
+
elif isinstance(module, Resample):
|
1874 |
+
x = module(x, reference)
|
1875 |
+
elif isinstance(module, TemporalAttentionBlock):
|
1876 |
+
module = checkpoint_wrapper(
|
1877 |
+
module) if self.use_checkpoint else module
|
1878 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
1879 |
+
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
|
1880 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
1881 |
+
elif isinstance(module, TemporalAttentionMultiBlock):
|
1882 |
+
module = checkpoint_wrapper(
|
1883 |
+
module) if self.use_checkpoint else module
|
1884 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
1885 |
+
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
|
1886 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
1887 |
+
elif isinstance(module, InitTemporalConvBlock):
|
1888 |
+
module = checkpoint_wrapper(
|
1889 |
+
module) if self.use_checkpoint else module
|
1890 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
1891 |
+
x = module(x)
|
1892 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
1893 |
+
elif isinstance(module, TemporalConvBlock):
|
1894 |
+
module = checkpoint_wrapper(
|
1895 |
+
module) if self.use_checkpoint else module
|
1896 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
1897 |
+
x = module(x)
|
1898 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
1899 |
+
elif isinstance(module, nn.ModuleList):
|
1900 |
+
for block in module:
|
1901 |
+
x = self._forward_single(block, x, e, context,
|
1902 |
+
time_rel_pos_bias, focus_present_mask,
|
1903 |
+
video_mask, reference, variant_info)
|
1904 |
+
else:
|
1905 |
+
x = module(x)
|
1906 |
+
return x
|
1907 |
+
|
1908 |
+
|
1909 |
+
class VideoControlNet(nn.Module):
|
1910 |
+
|
1911 |
+
def __init__(self,
|
1912 |
+
in_dim=4,
|
1913 |
+
dim=320,
|
1914 |
+
y_dim=1024,
|
1915 |
+
context_dim=1024,
|
1916 |
+
out_dim=4,
|
1917 |
+
dim_mult=[1, 2, 4, 4],
|
1918 |
+
num_heads=8,
|
1919 |
+
head_dim=64,
|
1920 |
+
num_res_blocks=2,
|
1921 |
+
attn_scales=[1 / 1, 1 / 2, 1 / 4],
|
1922 |
+
use_scale_shift_norm=True,
|
1923 |
+
dropout=0.1,
|
1924 |
+
temporal_attn_times=1,
|
1925 |
+
temporal_attention=True,
|
1926 |
+
use_checkpoint=True,
|
1927 |
+
use_image_dataset=False,
|
1928 |
+
use_fps_condition=False,
|
1929 |
+
use_sim_mask=False,
|
1930 |
+
training=False,
|
1931 |
+
inpainting=True):
|
1932 |
+
embed_dim = dim * 4
|
1933 |
+
num_heads = num_heads if num_heads else dim // 32
|
1934 |
+
super(VideoControlNet, self).__init__()
|
1935 |
+
self.in_dim = in_dim
|
1936 |
+
self.dim = dim
|
1937 |
+
self.y_dim = y_dim
|
1938 |
+
self.context_dim = context_dim
|
1939 |
+
self.embed_dim = embed_dim
|
1940 |
+
self.out_dim = out_dim
|
1941 |
+
self.dim_mult = dim_mult
|
1942 |
+
# for temporal attention
|
1943 |
+
self.num_heads = num_heads
|
1944 |
+
# for spatial attention
|
1945 |
+
self.head_dim = head_dim
|
1946 |
+
self.num_res_blocks = num_res_blocks
|
1947 |
+
self.attn_scales = attn_scales
|
1948 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
1949 |
+
self.temporal_attn_times = temporal_attn_times
|
1950 |
+
self.temporal_attention = temporal_attention
|
1951 |
+
self.use_checkpoint = use_checkpoint
|
1952 |
+
self.use_image_dataset = use_image_dataset
|
1953 |
+
self.use_fps_condition = use_fps_condition
|
1954 |
+
self.use_sim_mask = use_sim_mask
|
1955 |
+
self.training = training
|
1956 |
+
self.inpainting = inpainting
|
1957 |
+
|
1958 |
+
use_linear_in_temporal = False
|
1959 |
+
transformer_depth = 1
|
1960 |
+
disabled_sa = False
|
1961 |
+
# params
|
1962 |
+
enc_dims = [dim * u for u in [1] + dim_mult]
|
1963 |
+
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
1964 |
+
shortcut_dims = []
|
1965 |
+
scale = 1.0
|
1966 |
+
|
1967 |
+
# CaptionEmbedder (new add)
|
1968 |
+
# approx_gelu = lambda: nn.GELU(approximate="tanh")
|
1969 |
+
# self.y_embedder = CaptionEmbedder(
|
1970 |
+
# in_channels=4096,
|
1971 |
+
# hidden_size=1024,
|
1972 |
+
# uncond_prob=0.1,
|
1973 |
+
# act_layer=approx_gelu,
|
1974 |
+
# token_num=120,
|
1975 |
+
# )
|
1976 |
+
|
1977 |
+
# embeddings
|
1978 |
+
self.time_embed = nn.Sequential(
|
1979 |
+
nn.Linear(dim, embed_dim), nn.SiLU(),
|
1980 |
+
nn.Linear(embed_dim, embed_dim))
|
1981 |
+
|
1982 |
+
# self.hint_time_zero_linear = zero_module(nn.Linear(embed_dim, embed_dim))
|
1983 |
+
|
1984 |
+
# scale prompt
|
1985 |
+
# self.scale_cond = nn.Sequential(
|
1986 |
+
# nn.Linear(dim, embed_dim), nn.SiLU(),
|
1987 |
+
# zero_module(nn.Linear(embed_dim, embed_dim)))
|
1988 |
+
|
1989 |
+
if self.use_fps_condition:
|
1990 |
+
self.fps_embedding = nn.Sequential(
|
1991 |
+
nn.Linear(dim, embed_dim), nn.SiLU(),
|
1992 |
+
nn.Linear(embed_dim, embed_dim))
|
1993 |
+
nn.init.zeros_(self.fps_embedding[-1].weight)
|
1994 |
+
nn.init.zeros_(self.fps_embedding[-1].bias)
|
1995 |
+
|
1996 |
+
# encoder
|
1997 |
+
self.input_blocks = nn.ModuleList()
|
1998 |
+
init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)])
|
1999 |
+
# need an initial temporal attention?
|
2000 |
+
if temporal_attention:
|
2001 |
+
if USE_TEMPORAL_TRANSFORMER:
|
2002 |
+
init_block.append(
|
2003 |
+
TemporalTransformer(
|
2004 |
+
dim,
|
2005 |
+
num_heads,
|
2006 |
+
head_dim,
|
2007 |
+
depth=transformer_depth,
|
2008 |
+
context_dim=context_dim,
|
2009 |
+
disable_self_attn=disabled_sa,
|
2010 |
+
use_linear=use_linear_in_temporal,
|
2011 |
+
multiply_zero=use_image_dataset,
|
2012 |
+
is_ctrl=True,))
|
2013 |
+
else:
|
2014 |
+
init_block.append(
|
2015 |
+
TemporalAttentionMultiBlock(
|
2016 |
+
dim,
|
2017 |
+
num_heads,
|
2018 |
+
head_dim,
|
2019 |
+
rotary_emb=self.rotary_emb,
|
2020 |
+
temporal_attn_times=temporal_attn_times,
|
2021 |
+
use_image_dataset=use_image_dataset))
|
2022 |
+
self.input_blocks.append(init_block)
|
2023 |
+
self.zero_convs = nn.ModuleList([self.make_zero_conv(dim)])
|
2024 |
+
shortcut_dims.append(dim)
|
2025 |
+
for i, (in_dim,
|
2026 |
+
out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
|
2027 |
+
for j in range(num_res_blocks):
|
2028 |
+
block = nn.ModuleList([
|
2029 |
+
ResBlock(
|
2030 |
+
in_dim,
|
2031 |
+
embed_dim,
|
2032 |
+
dropout,
|
2033 |
+
out_channels=out_dim,
|
2034 |
+
use_scale_shift_norm=False,
|
2035 |
+
use_image_dataset=use_image_dataset,
|
2036 |
+
)
|
2037 |
+
])
|
2038 |
+
if scale in attn_scales:
|
2039 |
+
block.append(
|
2040 |
+
SpatialTransformer(
|
2041 |
+
out_dim,
|
2042 |
+
out_dim // head_dim,
|
2043 |
+
head_dim,
|
2044 |
+
depth=1,
|
2045 |
+
context_dim=self.context_dim,
|
2046 |
+
disable_self_attn=False,
|
2047 |
+
use_linear=True,
|
2048 |
+
is_ctrl=True))
|
2049 |
+
if self.temporal_attention:
|
2050 |
+
if USE_TEMPORAL_TRANSFORMER:
|
2051 |
+
block.append(
|
2052 |
+
TemporalTransformer(
|
2053 |
+
out_dim,
|
2054 |
+
out_dim // head_dim,
|
2055 |
+
head_dim,
|
2056 |
+
depth=transformer_depth,
|
2057 |
+
context_dim=context_dim,
|
2058 |
+
disable_self_attn=disabled_sa,
|
2059 |
+
use_linear=use_linear_in_temporal,
|
2060 |
+
multiply_zero=use_image_dataset,
|
2061 |
+
is_ctrl=True,))
|
2062 |
+
else:
|
2063 |
+
block.append(
|
2064 |
+
TemporalAttentionMultiBlock(
|
2065 |
+
out_dim,
|
2066 |
+
num_heads,
|
2067 |
+
head_dim,
|
2068 |
+
rotary_emb=self.rotary_emb,
|
2069 |
+
use_image_dataset=use_image_dataset,
|
2070 |
+
use_sim_mask=use_sim_mask,
|
2071 |
+
temporal_attn_times=temporal_attn_times))
|
2072 |
+
in_dim = out_dim
|
2073 |
+
self.input_blocks.append(block)
|
2074 |
+
self.zero_convs.append(self.make_zero_conv(out_dim))
|
2075 |
+
shortcut_dims.append(out_dim)
|
2076 |
+
|
2077 |
+
# downsample
|
2078 |
+
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
|
2079 |
+
downsample = Downsample(
|
2080 |
+
out_dim, True, dims=2, out_channels=out_dim)
|
2081 |
+
shortcut_dims.append(out_dim)
|
2082 |
+
scale /= 2.0
|
2083 |
+
self.input_blocks.append(downsample)
|
2084 |
+
self.zero_convs.append(self.make_zero_conv(out_dim))
|
2085 |
+
|
2086 |
+
self.middle_block = nn.ModuleList([
|
2087 |
+
ResBlock(
|
2088 |
+
out_dim,
|
2089 |
+
embed_dim,
|
2090 |
+
dropout,
|
2091 |
+
use_scale_shift_norm=False,
|
2092 |
+
use_image_dataset=use_image_dataset,
|
2093 |
+
),
|
2094 |
+
SpatialTransformer(
|
2095 |
+
out_dim,
|
2096 |
+
out_dim // head_dim,
|
2097 |
+
head_dim,
|
2098 |
+
depth=1,
|
2099 |
+
context_dim=self.context_dim,
|
2100 |
+
disable_self_attn=False,
|
2101 |
+
use_linear=True,
|
2102 |
+
is_ctrl=True)
|
2103 |
+
])
|
2104 |
+
|
2105 |
+
if self.temporal_attention:
|
2106 |
+
if USE_TEMPORAL_TRANSFORMER:
|
2107 |
+
self.middle_block.append(
|
2108 |
+
TemporalTransformer(
|
2109 |
+
out_dim,
|
2110 |
+
out_dim // head_dim,
|
2111 |
+
head_dim,
|
2112 |
+
depth=transformer_depth,
|
2113 |
+
context_dim=context_dim,
|
2114 |
+
disable_self_attn=disabled_sa,
|
2115 |
+
use_linear=use_linear_in_temporal,
|
2116 |
+
multiply_zero=use_image_dataset,
|
2117 |
+
is_ctrl=True,
|
2118 |
+
))
|
2119 |
+
else:
|
2120 |
+
self.middle_block.append(
|
2121 |
+
TemporalAttentionMultiBlock(
|
2122 |
+
out_dim,
|
2123 |
+
num_heads,
|
2124 |
+
head_dim,
|
2125 |
+
rotary_emb=self.rotary_emb,
|
2126 |
+
use_image_dataset=use_image_dataset,
|
2127 |
+
use_sim_mask=use_sim_mask,
|
2128 |
+
temporal_attn_times=temporal_attn_times))
|
2129 |
+
|
2130 |
+
self.middle_block.append(
|
2131 |
+
ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False))
|
2132 |
+
|
2133 |
+
self.middle_block_out = self.make_zero_conv(embed_dim)
|
2134 |
+
|
2135 |
+
'''
|
2136 |
+
add prompt
|
2137 |
+
'''
|
2138 |
+
add_dim = 320
|
2139 |
+
self.add_dim = add_dim
|
2140 |
+
|
2141 |
+
self.input_hint_block = zero_module(nn.Conv2d(4, add_dim, 3, padding=1))
|
2142 |
+
|
2143 |
+
def make_zero_conv(self, in_channels, out_channels=None):
|
2144 |
+
out_channels = in_channels if out_channels is None else out_channels
|
2145 |
+
return TimestepEmbedSequential(zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0)))
|
2146 |
+
|
2147 |
+
def forward(self,
|
2148 |
+
x,
|
2149 |
+
t,
|
2150 |
+
y,
|
2151 |
+
s_cond=None,
|
2152 |
+
hint=None,
|
2153 |
+
variant_info=None,
|
2154 |
+
t_hint=None,
|
2155 |
+
mask_cond=None,
|
2156 |
+
fps=None,
|
2157 |
+
video_mask=None,
|
2158 |
+
focus_present_mask=None,
|
2159 |
+
prob_focus_present=0.,
|
2160 |
+
mask_last_frame_num=0):
|
2161 |
+
|
2162 |
+
batch, _, f, _, _ = x.shape
|
2163 |
+
device = x.device
|
2164 |
+
self.batch = batch
|
2165 |
+
|
2166 |
+
# image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
|
2167 |
+
if mask_last_frame_num > 0:
|
2168 |
+
focus_present_mask = None
|
2169 |
+
video_mask[-mask_last_frame_num:] = False
|
2170 |
+
else:
|
2171 |
+
focus_present_mask = default(
|
2172 |
+
focus_present_mask, lambda: prob_mask_like(
|
2173 |
+
(batch, ), prob_focus_present, device=device))
|
2174 |
+
|
2175 |
+
if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
|
2176 |
+
time_rel_pos_bias = self.time_rel_pos_bias(
|
2177 |
+
x.shape[2], device=x.device)
|
2178 |
+
else:
|
2179 |
+
time_rel_pos_bias = None
|
2180 |
+
|
2181 |
+
if hint is not None:
|
2182 |
+
# add = x.new_zeros(batch, self.add_dim, f, h, w)
|
2183 |
+
hint = rearrange(hint, 'b c f h w -> (b f) c h w')
|
2184 |
+
hint = self.input_hint_block(hint)
|
2185 |
+
# hint = rearrange(hint, '(b f) c h w -> b c f h w', b = batch)
|
2186 |
+
|
2187 |
+
e = self.time_embed(sinusoidal_embedding(t, self.dim))
|
2188 |
+
e = e.repeat_interleave(repeats=f, dim=0)
|
2189 |
+
|
2190 |
+
context = y.repeat_interleave(repeats=f, dim=0)
|
2191 |
+
|
2192 |
+
# always in shape (b f) c h w, except for temporal layer
|
2193 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
2194 |
+
# print('before x shape:', x.shape) [64, 320, 90, 160]
|
2195 |
+
# print('hint shape:', hint.shape) [32, 320, 90, 160]
|
2196 |
+
|
2197 |
+
# encoder
|
2198 |
+
xs = []
|
2199 |
+
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
2200 |
+
if hint is not None:
|
2201 |
+
for block in module:
|
2202 |
+
x = self._forward_single(block, x, e, context, time_rel_pos_bias,
|
2203 |
+
focus_present_mask, video_mask, variant_info=variant_info)
|
2204 |
+
if not isinstance(block, TemporalTransformer):
|
2205 |
+
if hint is not None:
|
2206 |
+
x += hint
|
2207 |
+
hint = None
|
2208 |
+
else:
|
2209 |
+
x = self._forward_single(module, x, e, context, time_rel_pos_bias,
|
2210 |
+
focus_present_mask, video_mask, variant_info=variant_info)
|
2211 |
+
xs.append(zero_conv(x, e, context))
|
2212 |
+
|
2213 |
+
# middle
|
2214 |
+
for block in self.middle_block:
|
2215 |
+
x = self._forward_single(block, x, e, context, time_rel_pos_bias,
|
2216 |
+
focus_present_mask, video_mask, variant_info=variant_info)
|
2217 |
+
xs.append(self.middle_block_out(x, e, context))
|
2218 |
+
|
2219 |
+
return xs
|
2220 |
+
|
2221 |
+
def _forward_single(self,
|
2222 |
+
module,
|
2223 |
+
x,
|
2224 |
+
e,
|
2225 |
+
context,
|
2226 |
+
time_rel_pos_bias,
|
2227 |
+
focus_present_mask,
|
2228 |
+
video_mask,
|
2229 |
+
reference=None,
|
2230 |
+
variant_info=None,):
|
2231 |
+
# variant_info = None # For Debug
|
2232 |
+
if isinstance(module, ResidualBlock):
|
2233 |
+
module = checkpoint_wrapper(
|
2234 |
+
module) if self.use_checkpoint else module
|
2235 |
+
x = x.contiguous()
|
2236 |
+
x = module(x, e, reference)
|
2237 |
+
elif isinstance(module, ResBlock):
|
2238 |
+
module = checkpoint_wrapper(
|
2239 |
+
module) if self.use_checkpoint else module
|
2240 |
+
x = x.contiguous()
|
2241 |
+
x = module(x, e, self.batch, variant_info)
|
2242 |
+
elif isinstance(module, SpatialTransformer):
|
2243 |
+
module = checkpoint_wrapper(
|
2244 |
+
module) if self.use_checkpoint else module
|
2245 |
+
x = module(x, context)
|
2246 |
+
elif isinstance(module, TemporalTransformer):
|
2247 |
+
module = checkpoint_wrapper(
|
2248 |
+
module) if self.use_checkpoint else module
|
2249 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
2250 |
+
# print("x shape:", x.shape) # [2, 320, 32, 90, 160]
|
2251 |
+
x = module(x, context)
|
2252 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
2253 |
+
elif isinstance(module, CrossAttention):
|
2254 |
+
module = checkpoint_wrapper(
|
2255 |
+
module) if self.use_checkpoint else module
|
2256 |
+
x = module(x, context)
|
2257 |
+
elif isinstance(module, MemoryEfficientCrossAttention):
|
2258 |
+
module = checkpoint_wrapper(
|
2259 |
+
module) if self.use_checkpoint else module
|
2260 |
+
x = module(x, context)
|
2261 |
+
elif isinstance(module, BasicTransformerBlock):
|
2262 |
+
module = checkpoint_wrapper(
|
2263 |
+
module) if self.use_checkpoint else module
|
2264 |
+
x = module(x, context)
|
2265 |
+
elif isinstance(module, FeedForward):
|
2266 |
+
x = module(x, context)
|
2267 |
+
elif isinstance(module, Upsample):
|
2268 |
+
x = module(x)
|
2269 |
+
elif isinstance(module, Downsample):
|
2270 |
+
x = module(x)
|
2271 |
+
elif isinstance(module, Resample):
|
2272 |
+
x = module(x, reference)
|
2273 |
+
elif isinstance(module, TemporalAttentionBlock):
|
2274 |
+
module = checkpoint_wrapper(
|
2275 |
+
module) if self.use_checkpoint else module
|
2276 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
2277 |
+
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
|
2278 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
2279 |
+
elif isinstance(module, TemporalAttentionMultiBlock):
|
2280 |
+
module = checkpoint_wrapper(
|
2281 |
+
module) if self.use_checkpoint else module
|
2282 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
2283 |
+
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
|
2284 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
2285 |
+
elif isinstance(module, InitTemporalConvBlock):
|
2286 |
+
module = checkpoint_wrapper(
|
2287 |
+
module) if self.use_checkpoint else module
|
2288 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
2289 |
+
x = module(x)
|
2290 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
2291 |
+
elif isinstance(module, TemporalConvBlock):
|
2292 |
+
module = checkpoint_wrapper(
|
2293 |
+
module) if self.use_checkpoint else module
|
2294 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
2295 |
+
x = module(x)
|
2296 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
2297 |
+
elif isinstance(module, nn.ModuleList):
|
2298 |
+
for block in module:
|
2299 |
+
x = self._forward_single(block, x, e, context,
|
2300 |
+
time_rel_pos_bias, focus_present_mask,
|
2301 |
+
video_mask, reference, variant_info)
|
2302 |
+
else:
|
2303 |
+
x = module(x)
|
2304 |
+
return x
|
2305 |
+
|
2306 |
+
|
2307 |
+
class TimestepBlock(nn.Module):
|
2308 |
+
"""
|
2309 |
+
Any module where forward() takes timestep embeddings as a second argument.
|
2310 |
+
"""
|
2311 |
+
|
2312 |
+
@abstractmethod
|
2313 |
+
def forward(self, x, emb):
|
2314 |
+
"""
|
2315 |
+
Apply the module to `x` given `emb` timestep embeddings.
|
2316 |
+
"""
|
2317 |
+
|
2318 |
+
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
2319 |
+
"""
|
2320 |
+
A sequential module that passes timestep embeddings to the children that
|
2321 |
+
support it as an extra input.
|
2322 |
+
"""
|
2323 |
+
|
2324 |
+
def forward(self, x, emb, context=None):
|
2325 |
+
for layer in self:
|
2326 |
+
if isinstance(layer, TimestepBlock):
|
2327 |
+
x = layer(x, emb)
|
2328 |
+
elif isinstance(layer, SpatialTransformer):
|
2329 |
+
x = layer(x, context)
|
2330 |
+
else:
|
2331 |
+
x = layer(x)
|
2332 |
+
return x
|
video_to_video/utils/__init__.py
ADDED
File without changes
|
video_to_video/utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (158 Bytes). View file
|
|
video_to_video/utils/__pycache__/config.cpython-39.pyc
ADDED
Binary file (3.43 kB). View file
|
|
video_to_video/utils/__pycache__/logger.cpython-39.pyc
ADDED
Binary file (2.14 kB). View file
|
|
video_to_video/utils/__pycache__/seed.cpython-39.pyc
ADDED
Binary file (466 Bytes). View file
|
|
video_to_video/utils/config.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
2 |
+
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import os.path as osp
|
6 |
+
from datetime import datetime
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from easydict import EasyDict
|
10 |
+
|
11 |
+
cfg = EasyDict(__name__='Config: VideoLDM Decoder')
|
12 |
+
|
13 |
+
# ---------------------------work dir--------------------------
|
14 |
+
cfg.work_dir = 'workspace/'
|
15 |
+
|
16 |
+
# ---------------------------Global Variable-----------------------------------
|
17 |
+
cfg.resolution = [448, 256]
|
18 |
+
cfg.max_frames = 32
|
19 |
+
# -----------------------------------------------------------------------------
|
20 |
+
|
21 |
+
# ---------------------------Dataset Parameter---------------------------------
|
22 |
+
cfg.mean = [0.5, 0.5, 0.5]
|
23 |
+
cfg.std = [0.5, 0.5, 0.5]
|
24 |
+
cfg.max_words = 1000
|
25 |
+
|
26 |
+
# PlaceHolder
|
27 |
+
cfg.vit_out_dim = 1024
|
28 |
+
cfg.vit_resolution = [224, 224]
|
29 |
+
cfg.depth_clamp = 10.0
|
30 |
+
cfg.misc_size = 384
|
31 |
+
cfg.depth_std = 20.0
|
32 |
+
|
33 |
+
cfg.frame_lens = 32
|
34 |
+
cfg.sample_fps = 8
|
35 |
+
|
36 |
+
cfg.batch_sizes = 1
|
37 |
+
# -----------------------------------------------------------------------------
|
38 |
+
|
39 |
+
# ---------------------------Mode Parameters-----------------------------------
|
40 |
+
# Diffusion
|
41 |
+
cfg.schedule = 'cosine'
|
42 |
+
cfg.num_timesteps = 1000
|
43 |
+
cfg.mean_type = 'v'
|
44 |
+
cfg.var_type = 'fixed_small'
|
45 |
+
cfg.loss_type = 'mse'
|
46 |
+
cfg.ddim_timesteps = 50
|
47 |
+
cfg.ddim_eta = 0.0
|
48 |
+
cfg.clamp = 1.0
|
49 |
+
cfg.share_noise = False
|
50 |
+
cfg.use_div_loss = False
|
51 |
+
cfg.noise_strength = 0.1
|
52 |
+
|
53 |
+
# classifier-free guidance
|
54 |
+
cfg.p_zero = 0.1
|
55 |
+
cfg.guide_scale = 3.0
|
56 |
+
|
57 |
+
# clip vision encoder
|
58 |
+
cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073]
|
59 |
+
cfg.vit_std = [0.26862954, 0.26130258, 0.27577711]
|
60 |
+
|
61 |
+
# Model
|
62 |
+
cfg.scale_factor = 0.18215
|
63 |
+
cfg.use_fp16 = True
|
64 |
+
cfg.temporal_attention = True
|
65 |
+
cfg.decoder_bs = 8
|
66 |
+
|
67 |
+
cfg.UNet = {
|
68 |
+
'type': 'Vid2VidSDUNet',
|
69 |
+
'in_dim': 4,
|
70 |
+
'dim': 320,
|
71 |
+
'y_dim': cfg.vit_out_dim,
|
72 |
+
'context_dim': 1024,
|
73 |
+
'out_dim': 8 if cfg.var_type.startswith('learned') else 4,
|
74 |
+
'dim_mult': [1, 2, 4, 4],
|
75 |
+
'num_heads': 8,
|
76 |
+
'head_dim': 64,
|
77 |
+
'num_res_blocks': 2,
|
78 |
+
'attn_scales': [1 / 1, 1 / 2, 1 / 4],
|
79 |
+
'dropout': 0.1,
|
80 |
+
'temporal_attention': cfg.temporal_attention,
|
81 |
+
'temporal_attn_times': 1,
|
82 |
+
'use_checkpoint': False,
|
83 |
+
'use_fps_condition': False,
|
84 |
+
'use_sim_mask': False,
|
85 |
+
'num_tokens': 4,
|
86 |
+
'default_fps': 8,
|
87 |
+
'input_dim': 1024
|
88 |
+
}
|
89 |
+
|
90 |
+
cfg.guidances = []
|
91 |
+
|
92 |
+
# auotoencoder from stabel diffusion
|
93 |
+
cfg.auto_encoder = {
|
94 |
+
'type': 'AutoencoderKL',
|
95 |
+
'ddconfig': {
|
96 |
+
'double_z': True,
|
97 |
+
'z_channels': 4,
|
98 |
+
'resolution': 256,
|
99 |
+
'in_channels': 3,
|
100 |
+
'out_ch': 3,
|
101 |
+
'ch': 128,
|
102 |
+
'ch_mult': [1, 2, 4, 4],
|
103 |
+
'num_res_blocks': 2,
|
104 |
+
'attn_resolutions': [],
|
105 |
+
'dropout': 0.0
|
106 |
+
},
|
107 |
+
'embed_dim': 4,
|
108 |
+
'pretrained': 'models/v2-1_512-ema-pruned.ckpt'
|
109 |
+
}
|
110 |
+
# clip embedder
|
111 |
+
cfg.embedder = {
|
112 |
+
'type': 'FrozenOpenCLIPEmbedder',
|
113 |
+
'layer': 'penultimate',
|
114 |
+
'vit_resolution': [224, 224],
|
115 |
+
'pretrained': 'open_clip_pytorch_model.bin'
|
116 |
+
}
|
117 |
+
# -----------------------------------------------------------------------------
|
118 |
+
|
119 |
+
# ---------------------------Training Settings---------------------------------
|
120 |
+
# training and optimizer
|
121 |
+
cfg.ema_decay = 0.9999
|
122 |
+
cfg.num_steps = 600000
|
123 |
+
cfg.lr = 5e-5
|
124 |
+
cfg.weight_decay = 0.0
|
125 |
+
cfg.betas = (0.9, 0.999)
|
126 |
+
cfg.eps = 1.0e-8
|
127 |
+
cfg.chunk_size = 16
|
128 |
+
cfg.alpha = 0.7
|
129 |
+
cfg.save_ckp_interval = 1000
|
130 |
+
# -----------------------------------------------------------------------------
|
131 |
+
|
132 |
+
# ----------------------------Pretrain Settings---------------------------------
|
133 |
+
# Default: load 2d pretrain
|
134 |
+
cfg.fix_weight = False
|
135 |
+
cfg.load_match = False
|
136 |
+
cfg.pretrained_checkpoint = 'v2-1_512-ema-pruned.ckpt'
|
137 |
+
cfg.pretrained_image_keys = 'stable_diffusion_image_key_temporal_attention_x1.json'
|
138 |
+
cfg.resume_checkpoint = 'img2video_ldm_0779000.pth'
|
139 |
+
# -----------------------------------------------------------------------------
|
140 |
+
|
141 |
+
# -----------------------------Visual-------------------------------------------
|
142 |
+
# Visual videos
|
143 |
+
cfg.viz_interval = 1000
|
144 |
+
cfg.visual_train = {
|
145 |
+
'type': 'VisualVideoTextDuringTrain',
|
146 |
+
}
|
147 |
+
cfg.visual_inference = {
|
148 |
+
'type': 'VisualGeneratedVideos',
|
149 |
+
}
|
150 |
+
cfg.inference_list_path = ''
|
151 |
+
|
152 |
+
# logging
|
153 |
+
cfg.log_interval = 100
|
154 |
+
|
155 |
+
# Default log_dir
|
156 |
+
cfg.log_dir = 'workspace/output_data'
|
157 |
+
# -----------------------------------------------------------------------------
|
158 |
+
|
159 |
+
# ---------------------------Others--------------------------------------------
|
160 |
+
# seed
|
161 |
+
cfg.seed = 8888
|
162 |
+
|
163 |
+
cfg.negative_prompt = 'painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, \
|
164 |
+
CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, \
|
165 |
+
signature, jpeg artifacts, deformed, lowres, over-smooth'
|
166 |
+
|
167 |
+
cfg.positive_prompt = 'Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, \
|
168 |
+
hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, \
|
169 |
+
skin pore detailing, hyper sharpness, perfect without deformations.'
|
video_to_video/utils/logger.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
2 |
+
|
3 |
+
import importlib
|
4 |
+
import logging
|
5 |
+
from typing import Optional
|
6 |
+
from torch import distributed as dist
|
7 |
+
|
8 |
+
init_loggers = {}
|
9 |
+
|
10 |
+
formatter = logging.Formatter(
|
11 |
+
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
12 |
+
|
13 |
+
|
14 |
+
def get_logger(log_file: Optional[str] = None,
|
15 |
+
log_level: int = logging.INFO,
|
16 |
+
file_mode: str = 'w'):
|
17 |
+
""" Get logging logger
|
18 |
+
|
19 |
+
Args:
|
20 |
+
log_file: Log filename, if specified, file handler will be added to
|
21 |
+
logger
|
22 |
+
log_level: Logging level.
|
23 |
+
file_mode: Specifies the mode to open the file, if filename is
|
24 |
+
specified (if filemode is unspecified, it defaults to 'w').
|
25 |
+
"""
|
26 |
+
|
27 |
+
logger_name = __name__.split('.')[0]
|
28 |
+
logger = logging.getLogger(logger_name)
|
29 |
+
logger.propagate = False
|
30 |
+
if logger_name in init_loggers:
|
31 |
+
add_file_handler_if_needed(logger, log_file, file_mode, log_level)
|
32 |
+
return logger
|
33 |
+
|
34 |
+
# handle duplicate logs to the console
|
35 |
+
# Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)
|
36 |
+
# to the root logger. As logger.propagate is True by default, this root
|
37 |
+
# level handler causes logging messages from rank>0 processes to
|
38 |
+
# unexpectedly show up on the console, creating much unwanted clutter.
|
39 |
+
# To fix this issue, we set the root logger's StreamHandler, if any, to log
|
40 |
+
# at the ERROR level.
|
41 |
+
for handler in logger.root.handlers:
|
42 |
+
if type(handler) is logging.StreamHandler:
|
43 |
+
handler.setLevel(logging.ERROR)
|
44 |
+
|
45 |
+
stream_handler = logging.StreamHandler()
|
46 |
+
handlers = [stream_handler]
|
47 |
+
|
48 |
+
if importlib.util.find_spec('torch') is not None:
|
49 |
+
is_worker0 = is_master()
|
50 |
+
else:
|
51 |
+
is_worker0 = True
|
52 |
+
|
53 |
+
if is_worker0 and log_file is not None:
|
54 |
+
file_handler = logging.FileHandler(log_file, file_mode)
|
55 |
+
handlers.append(file_handler)
|
56 |
+
|
57 |
+
for handler in handlers:
|
58 |
+
handler.setFormatter(formatter)
|
59 |
+
handler.setLevel(log_level)
|
60 |
+
logger.addHandler(handler)
|
61 |
+
|
62 |
+
if is_worker0:
|
63 |
+
logger.setLevel(log_level)
|
64 |
+
else:
|
65 |
+
logger.setLevel(logging.ERROR)
|
66 |
+
|
67 |
+
init_loggers[logger_name] = True
|
68 |
+
|
69 |
+
return logger
|
70 |
+
|
71 |
+
|
72 |
+
def add_file_handler_if_needed(logger, log_file, file_mode, log_level):
|
73 |
+
for handler in logger.handlers:
|
74 |
+
if isinstance(handler, logging.FileHandler):
|
75 |
+
return
|
76 |
+
|
77 |
+
if importlib.util.find_spec('torch') is not None:
|
78 |
+
is_worker0 = is_master()
|
79 |
+
else:
|
80 |
+
is_worker0 = True
|
81 |
+
|
82 |
+
if is_worker0 and log_file is not None:
|
83 |
+
file_handler = logging.FileHandler(log_file, file_mode)
|
84 |
+
file_handler.setFormatter(formatter)
|
85 |
+
file_handler.setLevel(log_level)
|
86 |
+
logger.addHandler(file_handler)
|
87 |
+
|
88 |
+
|
89 |
+
def is_master(group=None):
|
90 |
+
return dist.get_rank(group) == 0 if is_dist() else True
|
91 |
+
|
92 |
+
|
93 |
+
def is_dist():
|
94 |
+
return dist.is_available() and dist.is_initialized()
|
video_to_video/utils/seed.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
2 |
+
|
3 |
+
import random
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
def setup_seed(seed):
|
10 |
+
torch.manual_seed(seed)
|
11 |
+
torch.cuda.manual_seed_all(seed)
|
12 |
+
np.random.seed(seed)
|
13 |
+
random.seed(seed)
|
14 |
+
torch.backends.cudnn.deterministic = True
|
video_to_video/video_to_video_model.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import os.path as osp
|
3 |
+
import random
|
4 |
+
from typing import Any, Dict
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.cuda.amp as amp
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from video_to_video.modules import *
|
11 |
+
from video_to_video.utils.config import cfg
|
12 |
+
from video_to_video.diffusion.diffusion_sdedit import GaussianDiffusion
|
13 |
+
from video_to_video.diffusion.schedules_sdedit import noise_schedule
|
14 |
+
from video_to_video.utils.logger import get_logger
|
15 |
+
|
16 |
+
from diffusers import AutoencoderKLTemporalDecoder
|
17 |
+
|
18 |
+
logger = get_logger()
|
19 |
+
|
20 |
+
class VideoToVideo_sr():
|
21 |
+
def __init__(self, opt, device=torch.device(f'cuda:0')):
|
22 |
+
self.opt = opt
|
23 |
+
self.device = device # torch.device(f'cuda:0')
|
24 |
+
|
25 |
+
# text_encoder
|
26 |
+
text_encoder = FrozenOpenCLIPEmbedder(device=self.device, pretrained="laion2b_s32b_b79k")
|
27 |
+
text_encoder.model.to(self.device)
|
28 |
+
self.text_encoder = text_encoder
|
29 |
+
logger.info(f'Build encoder with FrozenOpenCLIPEmbedder')
|
30 |
+
|
31 |
+
# U-Net with ControlNet
|
32 |
+
generator = ControlledV2VUNet()
|
33 |
+
generator = generator.to(self.device)
|
34 |
+
generator.eval()
|
35 |
+
|
36 |
+
cfg.model_path = opt.model_path
|
37 |
+
load_dict = torch.load(cfg.model_path, map_location='cpu')
|
38 |
+
if 'state_dict' in load_dict:
|
39 |
+
load_dict = load_dict['state_dict']
|
40 |
+
ret = generator.load_state_dict(load_dict, strict=False)
|
41 |
+
|
42 |
+
self.generator = generator.half()
|
43 |
+
logger.info('Load model path {}, with local status {}'.format(cfg.model_path, ret))
|
44 |
+
|
45 |
+
# Noise scheduler
|
46 |
+
sigmas = noise_schedule(
|
47 |
+
schedule='logsnr_cosine_interp',
|
48 |
+
n=1000,
|
49 |
+
zero_terminal_snr=True,
|
50 |
+
scale_min=2.0,
|
51 |
+
scale_max=4.0)
|
52 |
+
diffusion = GaussianDiffusion(sigmas=sigmas)
|
53 |
+
self.diffusion = diffusion
|
54 |
+
logger.info('Build diffusion with GaussianDiffusion')
|
55 |
+
|
56 |
+
# Temporal VAE
|
57 |
+
vae = AutoencoderKLTemporalDecoder.from_pretrained(
|
58 |
+
"stabilityai/stable-video-diffusion-img2vid", subfolder="vae", variant="fp16"
|
59 |
+
)
|
60 |
+
vae.eval()
|
61 |
+
vae.requires_grad_(False)
|
62 |
+
vae.to(self.device)
|
63 |
+
self.vae = vae
|
64 |
+
logger.info('Build Temporal VAE')
|
65 |
+
|
66 |
+
torch.cuda.empty_cache()
|
67 |
+
|
68 |
+
self.negative_prompt = cfg.negative_prompt
|
69 |
+
self.positive_prompt = cfg.positive_prompt
|
70 |
+
|
71 |
+
negative_y = text_encoder(self.negative_prompt).detach()
|
72 |
+
self.negative_y = negative_y
|
73 |
+
|
74 |
+
|
75 |
+
def test(self, input: Dict[str, Any], total_noise_levels=1000, \
|
76 |
+
steps=50, solver_mode='fast', guide_scale=7.5, max_chunk_len=32):
|
77 |
+
video_data = input['video_data']
|
78 |
+
y = input['y']
|
79 |
+
(target_h, target_w) = input['target_res']
|
80 |
+
|
81 |
+
video_data = F.interpolate(video_data, [target_h,target_w], mode='bilinear')
|
82 |
+
|
83 |
+
logger.info(f'video_data shape: {video_data.shape}')
|
84 |
+
frames_num, _, h, w = video_data.shape
|
85 |
+
|
86 |
+
padding = pad_to_fit(h, w)
|
87 |
+
video_data = F.pad(video_data, padding, 'constant', 1)
|
88 |
+
|
89 |
+
video_data = video_data.unsqueeze(0)
|
90 |
+
bs = 1
|
91 |
+
video_data = video_data.to(self.device)
|
92 |
+
|
93 |
+
video_data_feature = self.vae_encode(video_data)
|
94 |
+
torch.cuda.empty_cache()
|
95 |
+
|
96 |
+
y = self.text_encoder(y).detach()
|
97 |
+
|
98 |
+
with amp.autocast(enabled=True):
|
99 |
+
|
100 |
+
t = torch.LongTensor([total_noise_levels-1]).to(self.device)
|
101 |
+
noised_lr = self.diffusion.diffuse(video_data_feature, t)
|
102 |
+
|
103 |
+
model_kwargs = [{'y': y}, {'y': self.negative_y}]
|
104 |
+
model_kwargs.append({'hint': video_data_feature})
|
105 |
+
|
106 |
+
torch.cuda.empty_cache()
|
107 |
+
chunk_inds = make_chunks(frames_num, interp_f_num=0, max_chunk_len=max_chunk_len) if frames_num > max_chunk_len else None
|
108 |
+
|
109 |
+
solver = 'dpmpp_2m_sde' # 'heun' | 'dpmpp_2m_sde'
|
110 |
+
gen_vid = self.diffusion.sample_sr(
|
111 |
+
noise=noised_lr,
|
112 |
+
model=self.generator,
|
113 |
+
model_kwargs=model_kwargs,
|
114 |
+
guide_scale=guide_scale,
|
115 |
+
guide_rescale=0.2,
|
116 |
+
solver=solver,
|
117 |
+
solver_mode=solver_mode,
|
118 |
+
return_intermediate=None,
|
119 |
+
steps=steps,
|
120 |
+
t_max=total_noise_levels - 1,
|
121 |
+
t_min=0,
|
122 |
+
discretization='trailing',
|
123 |
+
chunk_inds=chunk_inds,)
|
124 |
+
torch.cuda.empty_cache()
|
125 |
+
|
126 |
+
logger.info(f'sampling, finished.')
|
127 |
+
vid_tensor_gen = self.vae_decode_chunk(gen_vid, chunk_size=3)
|
128 |
+
|
129 |
+
logger.info(f'temporal vae decoding, finished.')
|
130 |
+
|
131 |
+
w1, w2, h1, h2 = padding
|
132 |
+
vid_tensor_gen = vid_tensor_gen[:,:,h1:h+h1,w1:w+w1]
|
133 |
+
|
134 |
+
gen_video = rearrange(
|
135 |
+
vid_tensor_gen, '(b f) c h w -> b c f h w', b=bs)
|
136 |
+
|
137 |
+
torch.cuda.empty_cache()
|
138 |
+
|
139 |
+
return gen_video.type(torch.float32).cpu()
|
140 |
+
|
141 |
+
def temporal_vae_decode(self, z, num_f):
|
142 |
+
return self.vae.decode(z/self.vae.config.scaling_factor, num_frames=num_f).sample
|
143 |
+
|
144 |
+
def vae_decode_chunk(self, z, chunk_size=3):
|
145 |
+
z = rearrange(z, "b c f h w -> (b f) c h w")
|
146 |
+
video = []
|
147 |
+
for ind in range(0, z.shape[0], chunk_size):
|
148 |
+
num_f = z[ind:ind+chunk_size].shape[0]
|
149 |
+
video.append(self.temporal_vae_decode(z[ind:ind+chunk_size],num_f))
|
150 |
+
video = torch.cat(video)
|
151 |
+
return video
|
152 |
+
|
153 |
+
def vae_encode(self, t, chunk_size=1):
|
154 |
+
num_f = t.shape[1]
|
155 |
+
t = rearrange(t, "b f c h w -> (b f) c h w")
|
156 |
+
z_list = []
|
157 |
+
for ind in range(0,t.shape[0],chunk_size):
|
158 |
+
z_list.append(self.vae.encode(t[ind:ind+chunk_size]).latent_dist.sample())
|
159 |
+
z = torch.cat(z_list, dim=0)
|
160 |
+
z = rearrange(z, "(b f) c h w -> b c f h w", f=num_f)
|
161 |
+
return z * self.vae.config.scaling_factor
|
162 |
+
|
163 |
+
|
164 |
+
def pad_to_fit(h, w):
|
165 |
+
BEST_H, BEST_W = 720, 1280
|
166 |
+
|
167 |
+
if h < BEST_H:
|
168 |
+
h1, h2 = _create_pad(h, BEST_H)
|
169 |
+
elif h == BEST_H:
|
170 |
+
h1 = h2 = 0
|
171 |
+
else:
|
172 |
+
h1 = 0
|
173 |
+
h2 = int((h + 48) // 64 * 64) + 64 - 48 - h
|
174 |
+
|
175 |
+
if w < BEST_W:
|
176 |
+
w1, w2 = _create_pad(w, BEST_W)
|
177 |
+
elif w == BEST_W:
|
178 |
+
w1 = w2 = 0
|
179 |
+
else:
|
180 |
+
w1 = 0
|
181 |
+
w2 = int(w // 64 * 64) + 64 - w
|
182 |
+
return (w1, w2, h1, h2)
|
183 |
+
|
184 |
+
def _create_pad(h, max_len):
|
185 |
+
h1 = int((max_len - h) // 2)
|
186 |
+
h2 = max_len - h1 - h
|
187 |
+
return h1, h2
|
188 |
+
|
189 |
+
|
190 |
+
def make_chunks(f_num, interp_f_num, max_chunk_len, chunk_overlap_ratio=0.5):
|
191 |
+
MAX_CHUNK_LEN = max_chunk_len
|
192 |
+
MAX_O_LEN = MAX_CHUNK_LEN * chunk_overlap_ratio
|
193 |
+
chunk_len = int((MAX_CHUNK_LEN-1)//(1+interp_f_num)*(interp_f_num+1)+1)
|
194 |
+
o_len = int((MAX_O_LEN-1)//(1+interp_f_num)*(interp_f_num+1)+1)
|
195 |
+
chunk_inds = sliding_windows_1d(f_num, chunk_len, o_len)
|
196 |
+
return chunk_inds
|
197 |
+
|
198 |
+
|
199 |
+
def sliding_windows_1d(length, window_size, overlap_size):
|
200 |
+
stride = window_size - overlap_size
|
201 |
+
ind = 0
|
202 |
+
coords = []
|
203 |
+
while ind<length:
|
204 |
+
if ind+window_size*1.25>=length:
|
205 |
+
coords.append((ind,length))
|
206 |
+
break
|
207 |
+
else:
|
208 |
+
coords.append((ind,ind+window_size))
|
209 |
+
ind += stride
|
210 |
+
return coords
|