Spaces:
Runtime error
Runtime error
chenyangqi
commited on
Commit
·
3060b7e
1
Parent(s):
8094e3b
add FateZero code
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- FateZero +0 -1
- FateZero/.gitignore +176 -0
- FateZero/LICENSE.md +21 -0
- FateZero/README.md +393 -0
- FateZero/colab_fatezero.ipynb +0 -0
- FateZero/config/.gitignore +1 -0
- FateZero/config/attribute/bear_tiger_lion_leopard.yaml +108 -0
- FateZero/config/attribute/bus_gpu.yaml +100 -0
- FateZero/config/attribute/cat_tiger_leopard_grass.yaml +112 -0
- FateZero/config/attribute/dog_robotic_corgi.yaml +103 -0
- FateZero/config/attribute/duck_rubber.yaml +99 -0
- FateZero/config/attribute/fox_wolf_snow.yaml +107 -0
- FateZero/config/attribute/rabbit_straberry_leaves_flowers.yaml +114 -0
- FateZero/config/attribute/squ_carrot_robot_eggplant.yaml +123 -0
- FateZero/config/attribute/swan_swa.yaml +102 -0
- FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps.yaml +83 -0
- FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps_disk_store.yaml +84 -0
- FateZero/config/style/jeep_watercolor.yaml +94 -0
- FateZero/config/style/lily_monet.yaml +93 -0
- FateZero/config/style/rabit_pokemon.yaml +92 -0
- FateZero/config/style/sun_flower_van_gogh.yaml +86 -0
- FateZero/config/style/surf_ukiyo.yaml +90 -0
- FateZero/config/style/swan_cartoon.yaml +101 -0
- FateZero/config/style/train_shinkai.yaml +97 -0
- FateZero/config/teaser/jeep_posche.yaml +93 -0
- FateZero/config/teaser/jeep_watercolor.yaml +94 -0
- FateZero/data/.gitignore +4 -0
- FateZero/data/teaser_car-turn/00000.png +0 -0
- FateZero/data/teaser_car-turn/00001.png +0 -0
- FateZero/data/teaser_car-turn/00002.png +0 -0
- FateZero/data/teaser_car-turn/00003.png +0 -0
- FateZero/data/teaser_car-turn/00004.png +0 -0
- FateZero/data/teaser_car-turn/00005.png +0 -0
- FateZero/data/teaser_car-turn/00006.png +0 -0
- FateZero/data/teaser_car-turn/00007.png +0 -0
- FateZero/docs/EditingGuidance.md +65 -0
- FateZero/docs/OpenSans-Regular.ttf +0 -0
- FateZero/requirements.txt +17 -0
- FateZero/test_fatezero.py +290 -0
- FateZero/test_fatezero_dataset.py +52 -0
- FateZero/test_install.py +23 -0
- FateZero/train_tune_a_video.py +426 -0
- FateZero/video_diffusion/common/image_util.py +203 -0
- FateZero/video_diffusion/common/instantiate_from_config.py +33 -0
- FateZero/video_diffusion/common/logger.py +17 -0
- FateZero/video_diffusion/common/set_seed.py +28 -0
- FateZero/video_diffusion/common/util.py +73 -0
- FateZero/video_diffusion/data/dataset.py +158 -0
- FateZero/video_diffusion/data/transform.py +48 -0
- FateZero/video_diffusion/models/attention.py +482 -0
FateZero
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
Subproject commit 6992d238770f464c03a0a74cbcec4f99da4635ec
|
|
|
|
FateZero/.gitignore
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
start_hold
|
2 |
+
chenyangqi
|
3 |
+
trash/**
|
4 |
+
runs*/**
|
5 |
+
result/**
|
6 |
+
ckpt/**
|
7 |
+
ckpt
|
8 |
+
**.whl
|
9 |
+
stable-diffusion-v1-4
|
10 |
+
trash
|
11 |
+
# data/**
|
12 |
+
|
13 |
+
# Initially taken from Github's Python gitignore file
|
14 |
+
|
15 |
+
# Byte-compiled / optimized / DLL files
|
16 |
+
__pycache__/
|
17 |
+
*.py[cod]
|
18 |
+
*$py.class
|
19 |
+
|
20 |
+
# C extensions
|
21 |
+
*.so
|
22 |
+
|
23 |
+
# tests and logs
|
24 |
+
tests/fixtures/cached_*_text.txt
|
25 |
+
logs/
|
26 |
+
lightning_logs/
|
27 |
+
lang_code_data/
|
28 |
+
|
29 |
+
# Distribution / packaging
|
30 |
+
.Python
|
31 |
+
build/
|
32 |
+
develop-eggs/
|
33 |
+
dist/
|
34 |
+
downloads/
|
35 |
+
eggs/
|
36 |
+
.eggs/
|
37 |
+
lib/
|
38 |
+
lib64/
|
39 |
+
parts/
|
40 |
+
sdist/
|
41 |
+
var/
|
42 |
+
wheels/
|
43 |
+
*.egg-info/
|
44 |
+
.installed.cfg
|
45 |
+
*.egg
|
46 |
+
MANIFEST
|
47 |
+
|
48 |
+
# PyInstaller
|
49 |
+
# Usually these files are written by a python script from a template
|
50 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
51 |
+
*.manifest
|
52 |
+
*.spec
|
53 |
+
|
54 |
+
# Installer logs
|
55 |
+
pip-log.txt
|
56 |
+
pip-delete-this-directory.txt
|
57 |
+
|
58 |
+
# Unit test / coverage reports
|
59 |
+
htmlcov/
|
60 |
+
.tox/
|
61 |
+
.nox/
|
62 |
+
.coverage
|
63 |
+
.coverage.*
|
64 |
+
.cache
|
65 |
+
nosetests.xml
|
66 |
+
coverage.xml
|
67 |
+
*.cover
|
68 |
+
.hypothesis/
|
69 |
+
.pytest_cache/
|
70 |
+
|
71 |
+
# Translations
|
72 |
+
*.mo
|
73 |
+
*.pot
|
74 |
+
|
75 |
+
# Django stuff:
|
76 |
+
*.log
|
77 |
+
local_settings.py
|
78 |
+
db.sqlite3
|
79 |
+
|
80 |
+
# Flask stuff:
|
81 |
+
instance/
|
82 |
+
.webassets-cache
|
83 |
+
|
84 |
+
# Scrapy stuff:
|
85 |
+
.scrapy
|
86 |
+
|
87 |
+
# Sphinx documentation
|
88 |
+
docs/_build/
|
89 |
+
|
90 |
+
# PyBuilder
|
91 |
+
target/
|
92 |
+
|
93 |
+
# Jupyter Notebook
|
94 |
+
.ipynb_checkpoints
|
95 |
+
|
96 |
+
# IPython
|
97 |
+
profile_default/
|
98 |
+
ipython_config.py
|
99 |
+
|
100 |
+
# pyenv
|
101 |
+
.python-version
|
102 |
+
|
103 |
+
# celery beat schedule file
|
104 |
+
celerybeat-schedule
|
105 |
+
|
106 |
+
# SageMath parsed files
|
107 |
+
*.sage.py
|
108 |
+
|
109 |
+
# Environments
|
110 |
+
.env
|
111 |
+
.venv
|
112 |
+
env/
|
113 |
+
venv/
|
114 |
+
ENV/
|
115 |
+
env.bak/
|
116 |
+
venv.bak/
|
117 |
+
|
118 |
+
# Spyder project settings
|
119 |
+
.spyderproject
|
120 |
+
.spyproject
|
121 |
+
|
122 |
+
# Rope project settings
|
123 |
+
.ropeproject
|
124 |
+
|
125 |
+
# mkdocs documentation
|
126 |
+
/site
|
127 |
+
|
128 |
+
# mypy
|
129 |
+
.mypy_cache/
|
130 |
+
.dmypy.json
|
131 |
+
dmypy.json
|
132 |
+
|
133 |
+
# Pyre type checker
|
134 |
+
.pyre/
|
135 |
+
|
136 |
+
# vscode
|
137 |
+
.vs
|
138 |
+
.vscode
|
139 |
+
|
140 |
+
# Pycharm
|
141 |
+
.idea
|
142 |
+
|
143 |
+
# TF code
|
144 |
+
tensorflow_code
|
145 |
+
|
146 |
+
# Models
|
147 |
+
proc_data
|
148 |
+
|
149 |
+
# examples
|
150 |
+
runs
|
151 |
+
/runs_old
|
152 |
+
/wandb
|
153 |
+
/examples/runs
|
154 |
+
/examples/**/*.args
|
155 |
+
/examples/rag/sweep
|
156 |
+
|
157 |
+
# emacs
|
158 |
+
*.*~
|
159 |
+
debug.env
|
160 |
+
|
161 |
+
# vim
|
162 |
+
.*.swp
|
163 |
+
|
164 |
+
#ctags
|
165 |
+
tags
|
166 |
+
|
167 |
+
# pre-commit
|
168 |
+
.pre-commit*
|
169 |
+
|
170 |
+
# .lock
|
171 |
+
*.lock
|
172 |
+
|
173 |
+
# DS_Store (MacOS)
|
174 |
+
.DS_Store
|
175 |
+
# RL pipelines may produce mp4 outputs
|
176 |
+
*.mp4
|
FateZero/LICENSE.md
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 Chenyang QI
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
FateZero/README.md
ADDED
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## FateZero: Fusing Attentions for Zero-shot Text-based Video Editing
|
2 |
+
|
3 |
+
[Chenyang Qi](https://chenyangqiqi.github.io/), [Xiaodong Cun](http://vinthony.github.io/), [Yong Zhang](https://yzhang2016.github.io), [Chenyang Lei](https://chenyanglei.github.io/), [Xintao Wang](https://xinntao.github.io/), [Ying Shan](https://scholar.google.com/citations?hl=zh-CN&user=4oXBp9UAAAAJ), and [Qifeng Chen](https://cqf.io)
|
4 |
+
|
5 |
+
<a href='https://arxiv.org/abs/2303.09535'><img src='https://img.shields.io/badge/ArXiv-2303.09535-red'></a>
|
6 |
+
<a href='https://fate-zero-edit.github.io/'><img src='https://img.shields.io/badge/Project-Page-Green'></a> [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ChenyangQiQi/FateZero/blob/main/colab_fatezero.ipynb)
|
7 |
+
[![GitHub](https://img.shields.io/github/stars/ChenyangQiQi/FateZero?style=social)](https://github.com/ChenyangQiQi/FateZero)
|
8 |
+
|
9 |
+
|
10 |
+
<!-- ![fatezero_demo](./docs/teaser.png) -->
|
11 |
+
|
12 |
+
<table class="center">
|
13 |
+
<td><img src="docs/gif_results/17_car_posche_01_concat_result.gif"></td>
|
14 |
+
<td><img src="docs/gif_results/3_sunflower_vangogh_conat_result.gif"></td>
|
15 |
+
<tr>
|
16 |
+
<td width=25% style="text-align:center;">"silver jeep ➜ posche car"</td>
|
17 |
+
<td width=25% style="text-align:center;">"+ Van Gogh style"</td>
|
18 |
+
<!-- <td width=25% style="text-align:center;">"Wonder Woman, wearing a cowboy hat, is skiing"</td>
|
19 |
+
<td width=25% style="text-align:center;">"A man, wearing pink clothes, is skiing at sunset"</td> -->
|
20 |
+
</tr>
|
21 |
+
</table >
|
22 |
+
|
23 |
+
## Abstract
|
24 |
+
<b>TL;DR: Using FateZero, Edits your video via pretrained Diffusion models without training.</b>
|
25 |
+
|
26 |
+
<details><summary>CLICK for full abstract</summary>
|
27 |
+
|
28 |
+
|
29 |
+
> The diffusion-based generative models have achieved
|
30 |
+
remarkable success in text-based image generation. However,
|
31 |
+
since it contains enormous randomness in generation
|
32 |
+
progress, it is still challenging to apply such models for
|
33 |
+
real-world visual content editing, especially in videos. In
|
34 |
+
this paper, we propose FateZero, a zero-shot text-based editing method on real-world videos without per-prompt
|
35 |
+
training or use-specific mask. To edit videos consistently,
|
36 |
+
we propose several techniques based on the pre-trained
|
37 |
+
models. Firstly, in contrast to the straightforward DDIM
|
38 |
+
inversion technique, our approach captures intermediate
|
39 |
+
attention maps during inversion, which effectively retain
|
40 |
+
both structural and motion information. These maps are
|
41 |
+
directly fused in the editing process rather than generated
|
42 |
+
during denoising. To further minimize semantic leakage of
|
43 |
+
the source video, we then fuse self-attentions with a blending
|
44 |
+
mask obtained by cross-attention features from the source
|
45 |
+
prompt. Furthermore, we have implemented a reform of the
|
46 |
+
self-attention mechanism in denoising UNet by introducing
|
47 |
+
spatial-temporal attention to ensure frame consistency. Yet
|
48 |
+
succinct, our method is the first one to show the ability of
|
49 |
+
zero-shot text-driven video style and local attribute editing
|
50 |
+
from the trained text-to-image model. We also have a better
|
51 |
+
zero-shot shape-aware editing ability based on the text-tovideo
|
52 |
+
model. Extensive experiments demonstrate our
|
53 |
+
superior temporal consistency and editing capability than
|
54 |
+
previous works.
|
55 |
+
</details>
|
56 |
+
|
57 |
+
## Changelog
|
58 |
+
- 2023.03.27 Release [`attribute editing config`](config/attribute) and
|
59 |
+
<!-- [`data`](https://hkustconnect-my.sharepoint.com/:u:/g/personal/cqiaa_connect_ust_hk/Ee7J2IzZuaVGkefh-ZRp1GwB7RCUYU7MVJCKqeNWmOIpfg?e=dcOwb7) -->
|
60 |
+
[`data`](https://github.com/ChenyangQiQi/FateZero/releases/download/v0.0.1/attribute.zip) used in the paper.
|
61 |
+
- 2023.03.22 Upload a `colab notebook` [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ChenyangQiQi/FateZero/blob/main/colab_fatezero.ipynb). Enjoy the fun of zero-shot video-editing freely!
|
62 |
+
- 2023.03.22 Release [`style editing config`](config/style) and
|
63 |
+
<!--[`data`](https://hkustconnect-my.sharepoint.com/:u:/g/personal/cqiaa_connect_ust_hk/EaTqRAuW0eJLj0z_JJrURkcBZCC3Zvgsdo6zsXHhpyHhHQ?e=FzuiNG) -->
|
64 |
+
[`data`](https://github.com/ChenyangQiQi/FateZero/releases/download/v0.0.1/style.zip)
|
65 |
+
used in the paper.
|
66 |
+
- 2023.03.21 [Editing guidance](docs/EditingGuidance.md) is provided to help users to edit in-the-wild video. Welcome to play and give feedback!
|
67 |
+
- 2023.03.21 Update the `codebase and configuration`. Now, it can run with lower resources (16G GPU and less than 16G CPU RAM) with [new configuration](config/low_resource_teaser) in `config/low_resource_teaser`.
|
68 |
+
<!-- A new option store all the attentions in hard disk, which require less ram. -->
|
69 |
+
- 2023.03.17 Release Code and Paper!
|
70 |
+
|
71 |
+
## Todo
|
72 |
+
|
73 |
+
- [x] Release the edit config for teaser
|
74 |
+
- [x] Memory and runtime profiling
|
75 |
+
- [x] Hands-on guidance of hyperparameters tuning
|
76 |
+
- [x] Colab
|
77 |
+
- [x] Release configs for other result and in-the-wild dataset
|
78 |
+
<!-- - [x] Style editing: done
|
79 |
+
- [-] Attribute editing: in progress -->
|
80 |
+
- [-] hugging-face: inprogress
|
81 |
+
- [ ] Tune-a-video optimization and shape editing configs
|
82 |
+
- [ ] Release more application
|
83 |
+
|
84 |
+
## Setup Environment
|
85 |
+
Our method is tested using cuda11, fp16 of accelerator and xformers on a single A100 or 3090.
|
86 |
+
|
87 |
+
```bash
|
88 |
+
conda create -n fatezero38 python=3.8
|
89 |
+
conda activate fatezero38
|
90 |
+
|
91 |
+
pip install -r requirements.txt
|
92 |
+
```
|
93 |
+
|
94 |
+
`xformers` is recommended for A100 GPU to save memory and running time.
|
95 |
+
|
96 |
+
<details><summary>Click for xformers installation </summary>
|
97 |
+
|
98 |
+
We find its installation not stable. You may try the following wheel:
|
99 |
+
```bash
|
100 |
+
wget https://github.com/ShivamShrirao/xformers-wheels/releases/download/4c06c79/xformers-0.0.15.dev0+4c06c79.d20221201-cp38-cp38-linux_x86_64.whl
|
101 |
+
pip install xformers-0.0.15.dev0+4c06c79.d20221201-cp38-cp38-linux_x86_64.whl
|
102 |
+
```
|
103 |
+
|
104 |
+
</details>
|
105 |
+
|
106 |
+
Validate the installation by
|
107 |
+
```
|
108 |
+
python test_install.py
|
109 |
+
```
|
110 |
+
|
111 |
+
Our environment is similar to Tune-A-video ([official](https://github.com/showlab/Tune-A-Video), [unofficial](https://github.com/bryandlee/Tune-A-Video)) and [prompt-to-prompt](https://github.com/google/prompt-to-prompt/). You may check them for more details.
|
112 |
+
|
113 |
+
|
114 |
+
## FateZero Editing
|
115 |
+
|
116 |
+
#### Style and Attribute Editing in Teaser
|
117 |
+
|
118 |
+
Download the [stable diffusion v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) (or other interesting image diffusion model) and put it to `./ckpt/stable-diffusion-v1-4`.
|
119 |
+
|
120 |
+
<details><summary>Click for bash command: </summary>
|
121 |
+
|
122 |
+
```
|
123 |
+
mkdir ./ckpt
|
124 |
+
# download from huggingface face, takes 20G space
|
125 |
+
git lfs install
|
126 |
+
git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
|
127 |
+
cd ./ckpt
|
128 |
+
ln -s ../stable-diffusion-v1-4 .
|
129 |
+
```
|
130 |
+
</details>
|
131 |
+
|
132 |
+
Then, you could reproduce style and shape editing result in our teaser by running:
|
133 |
+
|
134 |
+
```bash
|
135 |
+
accelerate launch test_fatezero.py --config config/teaser/jeep_watercolor.yaml
|
136 |
+
# or CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/teaser/jeep_watercolor.yaml
|
137 |
+
```
|
138 |
+
|
139 |
+
<details><summary>The result is saved at `./result` . (Click for directory structure) </summary>
|
140 |
+
|
141 |
+
```
|
142 |
+
result
|
143 |
+
├── teaser
|
144 |
+
│ ├── jeep_posche
|
145 |
+
│ ├── jeep_watercolor
|
146 |
+
│ ├── cross-attention # visualization of cross-attention during inversion
|
147 |
+
│ ├── sample # result
|
148 |
+
│ ├── train_samples # the input video
|
149 |
+
|
150 |
+
```
|
151 |
+
|
152 |
+
</details>
|
153 |
+
|
154 |
+
Editing 8 frames on an Nvidia 3090, use `100G CPU memory, 12G GPU memory` for editing. We also provide some [`low cost setting`](config/low_resource_teaser) of style editing by different hyper-parameters on a 16GB GPU.
|
155 |
+
You may try these low cost setting on colab.
|
156 |
+
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ChenyangQiQi/FateZero/blob/main/colab_fatezero.ipynb)
|
157 |
+
|
158 |
+
More the speed and hardware benchmark [here](docs/EditingGuidance.md#ddim-hyperparameters).
|
159 |
+
|
160 |
+
#### Shape and large motion editing with Tune-A-Video
|
161 |
+
|
162 |
+
Besides style and attribution editing above, we also provide a `Tune-A-Video` [checkpoint](https://hkustconnect-my.sharepoint.com/:f:/g/personal/cqiaa_connect_ust_hk/EviSTWoAOs1EmHtqZruq50kBZu1E8gxDknCPigSvsS96uQ?e=492khj). You may download the it and move it to `./ckpt/jeep_tuned_200/`.
|
163 |
+
<!-- We provide the [Tune-a-Video](https://drive.google.com/file/d/166eNbabM6TeJVy7hxol2gL1kUGKHi3Do/view?usp=share_link), you could download the data, unzip and put it to `data`. : -->
|
164 |
+
|
165 |
+
<details><summary>The directory structure should like this: (Click for directory structure) </summary>
|
166 |
+
|
167 |
+
```
|
168 |
+
ckpt
|
169 |
+
├── stable-diffusion-v1-4
|
170 |
+
├── jeep_tuned_200
|
171 |
+
...
|
172 |
+
data
|
173 |
+
├── car-turn
|
174 |
+
│ ├── 00000000.png
|
175 |
+
│ ├── 00000001.png
|
176 |
+
│ ├── ...
|
177 |
+
video_diffusion
|
178 |
+
```
|
179 |
+
</details>
|
180 |
+
|
181 |
+
You could reproduce the shape editing result in our teaser by running:
|
182 |
+
|
183 |
+
```bash
|
184 |
+
accelerate launch test_fatezero.py --config config/teaser/jeep_posche.yaml
|
185 |
+
```
|
186 |
+
|
187 |
+
|
188 |
+
### Reproduce other results in the paper (in progress)
|
189 |
+
<!-- Download the data of [style editing](https://hkustconnect-my.sharepoint.com/:u:/g/personal/cqiaa_connect_ust_hk/EaTqRAuW0eJLj0z_JJrURkcBZCC3Zvgsdo6zsXHhpyHhHQ?e=FzuiNG) and [attribute editing](https://hkustconnect-my.sharepoint.com/:u:/g/personal/cqiaa_connect_ust_hk/Ee7J2IzZuaVGkefh-ZRp1GwB7RCUYU7MVJCKqeNWmOIpfg?e=dcOwb7)
|
190 |
+
-->
|
191 |
+
Download the data of style editing and attribute editing
|
192 |
+
from [onedrive](https://hkustconnect-my.sharepoint.com/:f:/g/personal/cqiaa_connect_ust_hk/EkIeHj3CQiBNhm6iEEhJQZwBEBJNCGt3FsANmyqeAYbuXQ?e=FxYtJk) or from Github [Release](https://github.com/ChenyangQiQi/FateZero/releases/tag/v0.0.1).
|
193 |
+
<details><summary>Click for wget bash command: </summary>
|
194 |
+
|
195 |
+
```
|
196 |
+
wget https://github.com/ChenyangQiQi/FateZero/releases/download/v0.0.1/attribute.zip
|
197 |
+
wget https://github.com/ChenyangQiQi/FateZero/releases/download/v0.0.1/style.zip
|
198 |
+
```
|
199 |
+
</details>
|
200 |
+
|
201 |
+
Unzip and Place it in ['./data'](data). Then use the command in ['config/style'](config/style) and ['config/attribute'](config/attribute) to get the results.
|
202 |
+
|
203 |
+
The config of our tune-a-video ckpts will be updated latter.
|
204 |
+
|
205 |
+
## Tuning guidance to edit YOUR video
|
206 |
+
We provided a tuning guidance to edit in-the-wild video at [here](./docs/EditingGuidance.md). The work is still in progress. Welcome to give your feedback in issues.
|
207 |
+
|
208 |
+
## Style Editing Results with Stable Diffusion
|
209 |
+
We show the difference of source prompt and target prompt in the box below each video.
|
210 |
+
|
211 |
+
Note mp4 and gif files in this github page are compressed.
|
212 |
+
Please check our [Project Page](https://fate-zero-edit.github.io/) for mp4 files of original video editing results.
|
213 |
+
<table class="center">
|
214 |
+
|
215 |
+
<tr>
|
216 |
+
<td><img src="docs/gif_results/style/1_surf_ukiyo_01_concat_result.gif"></td>
|
217 |
+
<td><img src="docs/gif_results/style/2_car_watercolor_01_concat_result.gif"></td>
|
218 |
+
<td><img src="docs/gif_results/style/6_lily_monet_01_concat_result.gif"></td>
|
219 |
+
<!-- <td><img src="https://tuneavideo.github.io/assets/results/tuneavideo/man-skiing/wonder-woman.gif"></td>
|
220 |
+
<td><img src="https://tuneavideo.github.io/assets/results/tuneavideo/man-skiing/pink-sunset.gif"></td> -->
|
221 |
+
</tr>
|
222 |
+
<tr>
|
223 |
+
<td width=25% style="text-align:center;">"+ Ukiyo-e style"</td>
|
224 |
+
<td width=25% style="text-align:center;">"+ watercolor painting"</td>
|
225 |
+
<td width=25% style="text-align:center;">"+ Monet style"</td>
|
226 |
+
</tr>
|
227 |
+
|
228 |
+
<tr>
|
229 |
+
<td><img src="docs/gif_results/style/4_rabit_pokemon_01_concat_result.gif"></td>
|
230 |
+
<td><img src="docs/gif_results/style/5_train_shikai_01_concat_result.gif"></td>
|
231 |
+
<td><img src="docs/gif_results/style/7_swan_carton_01_concat_result.gif"></td>
|
232 |
+
|
233 |
+
</tr>
|
234 |
+
<tr>
|
235 |
+
|
236 |
+
</tr>
|
237 |
+
<tr>
|
238 |
+
<td width=25% style="text-align:center;">"+ Pokémon cartoon style"</td>
|
239 |
+
<td width=25% style="text-align:center;">"+ Makoto Shinkai style"</td>
|
240 |
+
<td width=25% style="text-align:center;">"+ cartoon style"</td>
|
241 |
+
</tr>
|
242 |
+
</table>
|
243 |
+
|
244 |
+
## Attribute Editing Results with Stable Diffusion
|
245 |
+
<table class="center">
|
246 |
+
|
247 |
+
<tr>
|
248 |
+
|
249 |
+
<td><img src="docs/gif_results/attri/15_rabbit_eat_01_concat_result.gif"></td>
|
250 |
+
<td><img src="docs/gif_results/attri/15_rabbit_eat_02_concat_result.gif"></td>
|
251 |
+
<td><img src="docs/gif_results/attri/15_rabbit_eat_04_concat_result.gif"></td>
|
252 |
+
|
253 |
+
</tr>
|
254 |
+
<tr>
|
255 |
+
<td width=25% style="text-align:center;">"rabbit, strawberry ➜ white rabbit, flower"</td>
|
256 |
+
<td width=25% style="text-align:center;">"rabbit, strawberry ➜ squirrel, carrot"</td>
|
257 |
+
<td width=25% style="text-align:center;">"rabbit, strawberry ➜ white rabbit, leaves"</td>
|
258 |
+
|
259 |
+
</tr>
|
260 |
+
<tr>
|
261 |
+
|
262 |
+
<td><img src="docs/gif_results/attri/16_sq_eat_04_concat_result.gif"></td>
|
263 |
+
<td><img src="docs/gif_results/attri/16_sq_eat_02_concat_result.gif"></td>
|
264 |
+
<td><img src="docs/gif_results/attri/16_sq_eat_03_concat_result.gif"></td>
|
265 |
+
|
266 |
+
</tr>
|
267 |
+
<tr>
|
268 |
+
<td width=25% style="text-align:center;">"squirrel ➜ robot squirrel"</td>
|
269 |
+
<td width=25% style="text-align:center;">"squirrel, Carrot ➜ rabbit, eggplant"</td>
|
270 |
+
<td width=25% style="text-align:center;">"squirrel, Carrot ➜ robot mouse, screwdriver"</td>
|
271 |
+
|
272 |
+
</tr>
|
273 |
+
|
274 |
+
<tr>
|
275 |
+
|
276 |
+
<td><img src="docs/gif_results/attri/13_bear_tiger_leopard_lion_01_concat_result.gif"></td>
|
277 |
+
<td><img src="docs/gif_results/attri/13_bear_tiger_leopard_lion_02_concat_result.gif"></td>
|
278 |
+
<td><img src="docs/gif_results/attri/13_bear_tiger_leopard_lion_03_concat_result.gif"></td>
|
279 |
+
|
280 |
+
</tr>
|
281 |
+
<tr>
|
282 |
+
<td width=25% style="text-align:center;">"bear ➜ a red tiger"</td>
|
283 |
+
<td width=25% style="text-align:center;">"bear ➜ a yellow leopard"</td>
|
284 |
+
<td width=25% style="text-align:center;">"bear ➜ a brown lion"</td>
|
285 |
+
|
286 |
+
</tr>
|
287 |
+
<tr>
|
288 |
+
|
289 |
+
<td><img src="docs/gif_results/attri/14_cat_grass_tiger_corgin_02_concat_result.gif"></td>
|
290 |
+
<td><img src="docs/gif_results/attri/14_cat_grass_tiger_corgin_03_concat_result.gif"></td>
|
291 |
+
<td><img src="docs/gif_results/attri/14_cat_grass_tiger_corgin_04_concat_result.gif"></td>
|
292 |
+
|
293 |
+
</tr>
|
294 |
+
<tr>
|
295 |
+
<td width=25% style="text-align:center;">"cat ➜ black cat, grass..."</td>
|
296 |
+
<td width=25% style="text-align:center;">"cat ➜ red tiger"</td>
|
297 |
+
<td width=25% style="text-align:center;">"cat ➜ Shiba-Inu"</td>
|
298 |
+
|
299 |
+
</tr>
|
300 |
+
|
301 |
+
<tr>
|
302 |
+
|
303 |
+
<td><img src="docs/gif_results/attri/10_bus_gpu_01_concat_result.gif"></td>
|
304 |
+
<td><img src="docs/gif_results/attri/11_dog_robotic_corgin_01_concat_result.gif"></td>
|
305 |
+
<td><img src="docs/gif_results/attri/11_dog_robotic_corgin_02_concat_result.gif"></td>
|
306 |
+
|
307 |
+
</tr>
|
308 |
+
<tr>
|
309 |
+
<td width=25% style="text-align:center;">"bus ➜ GPU"</td>
|
310 |
+
<td width=25% style="text-align:center;">"gray dog ➜ yellow corgi"</td>
|
311 |
+
<td width=25% style="text-align:center;">"gray dog ➜ robotic dog"</td>
|
312 |
+
|
313 |
+
</tr>
|
314 |
+
<tr>
|
315 |
+
|
316 |
+
<td><img src="docs/gif_results/attri/9_duck_rubber_01_concat_result.gif"></td>
|
317 |
+
<td><img src="docs/gif_results/attri/12_fox_snow_wolf_01_concat_result.gif"></td>
|
318 |
+
<td><img src="docs/gif_results/attri/12_fox_snow_wolf_02_concat_result.gif"></td>
|
319 |
+
|
320 |
+
</tr>
|
321 |
+
<tr>
|
322 |
+
<td width=25% style="text-align:center;">"white duck ➜ yellow rubber duck"</td>
|
323 |
+
<td width=25% style="text-align:center;">"grass ➜ snow"</td>
|
324 |
+
<td width=25% style="text-align:center;">"white fox ➜ grey wolf"</td>
|
325 |
+
|
326 |
+
</tr>
|
327 |
+
|
328 |
+
|
329 |
+
</table>
|
330 |
+
|
331 |
+
## Shape and large motion editing with Tune-A-Video
|
332 |
+
<table class="center">
|
333 |
+
|
334 |
+
<tr>
|
335 |
+
<td><img src="docs/gif_results/shape/17_car_posche_01_concat_result.gif"></td>
|
336 |
+
<td><img src="docs/gif_results/shape/18_swan_01_concat_result.gif"></td>
|
337 |
+
<td><img src="docs/gif_results/shape/18_swan_02_concat_result.gif"></td>
|
338 |
+
<!-- <td><img src="https://tuneavideo.github.io/assets/results/tuneavideo/man-skiing/wonder-woman.gif"></td>
|
339 |
+
<td><img src="https://tuneavideo.github.io/assets/results/tuneavideo/man-skiing/pink-sunset.gif"></td> -->
|
340 |
+
</tr>
|
341 |
+
<tr>
|
342 |
+
<td width=25% style="text-align:center;">"silver jeep ➜ posche car"</td>
|
343 |
+
<td width=25% style="text-align:center;">"Swan ➜ White Duck"</td>
|
344 |
+
<td width=25% style="text-align:center;">"Swan ➜ Pink flamingo"</td>
|
345 |
+
</tr>
|
346 |
+
|
347 |
+
<tr>
|
348 |
+
<td><img src="docs/gif_results/shape/19_man_wonder_01_concat_result.gif"></td>
|
349 |
+
<td><img src="docs/gif_results/shape/19_man_wonder_02_concat_result.gif"></td>
|
350 |
+
<td><img src="docs/gif_results/shape/19_man_wonder_03_concat_result.gif"></td>
|
351 |
+
|
352 |
+
</tr>
|
353 |
+
<tr>
|
354 |
+
|
355 |
+
</tr>
|
356 |
+
<tr>
|
357 |
+
<td width=25% style="text-align:center;">"A man ➜ A Batman"</td>
|
358 |
+
<td width=25% style="text-align:center;">"A man ➜ A Wonder Woman, With cowboy hat"</td>
|
359 |
+
<td width=25% style="text-align:center;">"A man ➜ A Spider-Man"</td>
|
360 |
+
</tr>
|
361 |
+
</table>
|
362 |
+
|
363 |
+
|
364 |
+
## Demo Video
|
365 |
+
|
366 |
+
https://user-images.githubusercontent.com/45789244/225698509-79c14793-3153-4bba-9d6e-ede7d811d7f8.mp4
|
367 |
+
|
368 |
+
The video here is compressed due to the size limit of github.
|
369 |
+
The original full resolution video is [here](https://hkustconnect-my.sharepoint.com/:v:/g/personal/cqiaa_connect_ust_hk/EXKDI_nahEhKtiYPvvyU9SkBDTG2W4G1AZ_vkC7ekh3ENw?e=Xhgtmk).
|
370 |
+
|
371 |
+
|
372 |
+
## Citation
|
373 |
+
|
374 |
+
```
|
375 |
+
@misc{qi2023fatezero,
|
376 |
+
title={FateZero: Fusing Attentions for Zero-shot Text-based Video Editing},
|
377 |
+
author={Chenyang Qi and Xiaodong Cun and Yong Zhang and Chenyang Lei and Xintao Wang and Ying Shan and Qifeng Chen},
|
378 |
+
year={2023},
|
379 |
+
eprint={2303.09535},
|
380 |
+
archivePrefix={arXiv},
|
381 |
+
primaryClass={cs.CV}
|
382 |
+
}
|
383 |
+
```
|
384 |
+
|
385 |
+
|
386 |
+
## Acknowledgements
|
387 |
+
|
388 |
+
This repository borrows heavily from [Tune-A-Video](https://github.com/showlab/Tune-A-Video) and [prompt-to-prompt](https://github.com/google/prompt-to-prompt/). thanks the authors for sharing their code and models.
|
389 |
+
|
390 |
+
## Maintenance
|
391 |
+
|
392 |
+
This is the codebase for our research work. We are still working hard to update this repo and more details are coming in days. If you have any questions or ideas to discuss, feel free to contact [Chenyang Qi]([email protected]) or [Xiaodong Cun]([email protected]).
|
393 |
+
|
FateZero/colab_fatezero.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
FateZero/config/.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# debug/**
|
FateZero/config/attribute/bear_tiger_lion_leopard.yaml
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/attribute/bear_tiger_lion_leopard.yaml
|
2 |
+
|
3 |
+
pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
|
4 |
+
|
5 |
+
|
6 |
+
train_dataset:
|
7 |
+
path: "data/attribute/bear_tiger_lion_leopard"
|
8 |
+
prompt: "a brown bear walking on the rock against a wall"
|
9 |
+
n_sample_frame: 8
|
10 |
+
# n_sample_frame: 22
|
11 |
+
sampling_rate: 1
|
12 |
+
stride: 80
|
13 |
+
offset:
|
14 |
+
left: 0
|
15 |
+
right: 0
|
16 |
+
top: 0
|
17 |
+
bottom: 0
|
18 |
+
|
19 |
+
validation_sample_logger_config:
|
20 |
+
use_train_latents: True
|
21 |
+
use_inversion_attention: True
|
22 |
+
guidance_scale: 7.5
|
23 |
+
prompts: [
|
24 |
+
# source prompt
|
25 |
+
a brown bear walking on the rock against a wall,
|
26 |
+
|
27 |
+
# foreground texture style
|
28 |
+
a red tiger walking on the rock against a wall,
|
29 |
+
a yellow leopard walking on the rock against a wall,
|
30 |
+
a brown lion walking on the rock against a wall,
|
31 |
+
]
|
32 |
+
p2p_config:
|
33 |
+
0:
|
34 |
+
# Whether to directly copy the cross attention from source
|
35 |
+
# True: directly copy, better for object replacement
|
36 |
+
# False: keep source attention, better for style
|
37 |
+
is_replace_controller: False
|
38 |
+
|
39 |
+
# Semantic preserving and replacement Debug me
|
40 |
+
cross_replace_steps:
|
41 |
+
default_: 0.8
|
42 |
+
|
43 |
+
# Source background structure preserving, in [0, 1].
|
44 |
+
# e.g., =0.6 Replace the first 60% steps self-attention
|
45 |
+
self_replace_steps: 0.6
|
46 |
+
|
47 |
+
|
48 |
+
# Amplify the target-words cross attention, larger value, more close to target
|
49 |
+
eq_params:
|
50 |
+
words: ["silver", "sculpture"]
|
51 |
+
values: [2,2]
|
52 |
+
|
53 |
+
# Target structure-divergence hyperparames
|
54 |
+
# If you change the shape of object better to use all three line, otherwise, no need.
|
55 |
+
# Without following three lines, all self-attention will be replaced
|
56 |
+
blend_words: [['cat',], ["cat",]]
|
57 |
+
masked_self_attention: True
|
58 |
+
# masked_latents: False # performance not so good in our case, need debug
|
59 |
+
bend_th: [2, 2]
|
60 |
+
# preserve source structure of blend_words , [0, 1]
|
61 |
+
# default is bend_th: [2, 2] # preserve all source self-attention
|
62 |
+
# bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention
|
63 |
+
|
64 |
+
|
65 |
+
1:
|
66 |
+
is_replace_controller: true
|
67 |
+
cross_replace_steps:
|
68 |
+
default_: 0.7
|
69 |
+
self_replace_steps: 0.7
|
70 |
+
2:
|
71 |
+
is_replace_controller: true
|
72 |
+
cross_replace_steps:
|
73 |
+
default_: 0.7
|
74 |
+
self_replace_steps: 0.7
|
75 |
+
3:
|
76 |
+
is_replace_controller: true
|
77 |
+
cross_replace_steps:
|
78 |
+
default_: 0.7
|
79 |
+
self_replace_steps: 0.7
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
|
84 |
+
clip_length: "${..train_dataset.n_sample_frame}"
|
85 |
+
sample_seeds: [0]
|
86 |
+
val_all_frames: False
|
87 |
+
|
88 |
+
num_inference_steps: 50
|
89 |
+
prompt2prompt_edit: True
|
90 |
+
|
91 |
+
|
92 |
+
model_config:
|
93 |
+
lora: 160
|
94 |
+
# temporal_downsample_time: 4
|
95 |
+
SparseCausalAttention_index: ['mid']
|
96 |
+
least_sc_channel: 640
|
97 |
+
# least_sc_channel: 100000
|
98 |
+
|
99 |
+
test_pipeline_config:
|
100 |
+
target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
|
101 |
+
num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
|
102 |
+
|
103 |
+
epsilon: 1e-5
|
104 |
+
train_steps: 10
|
105 |
+
seed: 0
|
106 |
+
learning_rate: 1e-5
|
107 |
+
train_temporal_conv: False
|
108 |
+
guidance_scale: "${validation_sample_logger_config.guidance_scale}"
|
FateZero/config/attribute/bus_gpu.yaml
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/attribute/bus_gpu.yaml
|
2 |
+
|
3 |
+
pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
|
4 |
+
|
5 |
+
|
6 |
+
train_dataset:
|
7 |
+
path: "data/attribute/bus_gpu"
|
8 |
+
prompt: "a white and blue bus on the road"
|
9 |
+
n_sample_frame: 8
|
10 |
+
# n_sample_frame: 22
|
11 |
+
sampling_rate: 1
|
12 |
+
stride: 80
|
13 |
+
offset:
|
14 |
+
left: 0
|
15 |
+
right: 0
|
16 |
+
top: 0
|
17 |
+
bottom: 0
|
18 |
+
|
19 |
+
validation_sample_logger_config:
|
20 |
+
use_train_latents: True
|
21 |
+
use_inversion_attention: True
|
22 |
+
guidance_scale: 7.5
|
23 |
+
prompts: [
|
24 |
+
# source prompt
|
25 |
+
a white and blue bus on the road,
|
26 |
+
|
27 |
+
# foreground texture style
|
28 |
+
a black and green GPU on the road
|
29 |
+
]
|
30 |
+
p2p_config:
|
31 |
+
0:
|
32 |
+
# Whether to directly copy the cross attention from source
|
33 |
+
# True: directly copy, better for object replacement
|
34 |
+
# False: keep source attention, better for style
|
35 |
+
is_replace_controller: False
|
36 |
+
|
37 |
+
# Semantic preserving and replacement Debug me
|
38 |
+
cross_replace_steps:
|
39 |
+
default_: 0.8
|
40 |
+
|
41 |
+
# Source background structure preserving, in [0, 1].
|
42 |
+
# e.g., =0.6 Replace the first 60% steps self-attention
|
43 |
+
self_replace_steps: 0.6
|
44 |
+
|
45 |
+
|
46 |
+
# Amplify the target-words cross attention, larger value, more close to target
|
47 |
+
eq_params:
|
48 |
+
words: ["silver", "sculpture"]
|
49 |
+
values: [2,2]
|
50 |
+
|
51 |
+
# Target structure-divergence hyperparames
|
52 |
+
# If you change the shape of object better to use all three line, otherwise, no need.
|
53 |
+
# Without following three lines, all self-attention will be replaced
|
54 |
+
blend_words: [['cat',], ["cat",]]
|
55 |
+
masked_self_attention: True
|
56 |
+
# masked_latents: False # performance not so good in our case, need debug
|
57 |
+
bend_th: [2, 2]
|
58 |
+
# preserve source structure of blend_words , [0, 1]
|
59 |
+
# default is bend_th: [2, 2] # preserve all source self-attention
|
60 |
+
# bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention
|
61 |
+
|
62 |
+
|
63 |
+
1:
|
64 |
+
is_replace_controller: true
|
65 |
+
cross_replace_steps:
|
66 |
+
default_: 0.1
|
67 |
+
self_replace_steps: 0.1
|
68 |
+
|
69 |
+
eq_params:
|
70 |
+
words: ["Nvidia", "GPU"]
|
71 |
+
values: [10, 10] # amplify attention to the word "tiger" by *2
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
|
76 |
+
clip_length: "${..train_dataset.n_sample_frame}"
|
77 |
+
sample_seeds: [0]
|
78 |
+
val_all_frames: False
|
79 |
+
|
80 |
+
num_inference_steps: 50
|
81 |
+
prompt2prompt_edit: True
|
82 |
+
|
83 |
+
|
84 |
+
model_config:
|
85 |
+
lora: 160
|
86 |
+
# temporal_downsample_time: 4
|
87 |
+
SparseCausalAttention_index: ['mid']
|
88 |
+
least_sc_channel: 640
|
89 |
+
# least_sc_channel: 100000
|
90 |
+
|
91 |
+
test_pipeline_config:
|
92 |
+
target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
|
93 |
+
num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
|
94 |
+
|
95 |
+
epsilon: 1e-5
|
96 |
+
train_steps: 10
|
97 |
+
seed: 0
|
98 |
+
learning_rate: 1e-5
|
99 |
+
train_temporal_conv: False
|
100 |
+
guidance_scale: "${validation_sample_logger_config.guidance_scale}"
|
FateZero/config/attribute/cat_tiger_leopard_grass.yaml
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/attribute/cat_tiger_leopard_grass.yaml
|
2 |
+
|
3 |
+
pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
|
4 |
+
|
5 |
+
|
6 |
+
train_dataset:
|
7 |
+
path: "data/attribute/cat_tiger_leopard_grass"
|
8 |
+
prompt: "A black cat walking on the floor next to a wall"
|
9 |
+
n_sample_frame: 8
|
10 |
+
# n_sample_frame: 22
|
11 |
+
sampling_rate: 1
|
12 |
+
stride: 80
|
13 |
+
offset:
|
14 |
+
left: 0
|
15 |
+
right: 0
|
16 |
+
top: 0
|
17 |
+
bottom: 0
|
18 |
+
|
19 |
+
validation_sample_logger_config:
|
20 |
+
use_train_latents: True
|
21 |
+
use_inversion_attention: True
|
22 |
+
guidance_scale: 7.5
|
23 |
+
prompts: [
|
24 |
+
# source prompt
|
25 |
+
A black cat walking on the floor next to a wall,
|
26 |
+
A black cat walking on the grass next to a wall,
|
27 |
+
A red tiger walking on the floor next to a wall,
|
28 |
+
a yellow cute Shiba-Inu walking on the floor next to a wall,
|
29 |
+
a yellow cute leopard walking on the floor next to a wall,
|
30 |
+
]
|
31 |
+
p2p_config:
|
32 |
+
0:
|
33 |
+
# Whether to directly copy the cross attention from source
|
34 |
+
# True: directly copy, better for object replacement
|
35 |
+
# False: keep source attention, better for style
|
36 |
+
is_replace_controller: False
|
37 |
+
|
38 |
+
# Semantic preserving and replacement Debug me
|
39 |
+
cross_replace_steps:
|
40 |
+
default_: 0.8
|
41 |
+
|
42 |
+
# Source background structure preserving, in [0, 1].
|
43 |
+
# e.g., =0.6 Replace the first 60% steps self-attention
|
44 |
+
self_replace_steps: 0.6
|
45 |
+
|
46 |
+
|
47 |
+
# Amplify the target-words cross attention, larger value, more close to target
|
48 |
+
eq_params:
|
49 |
+
words: ["silver", "sculpture"]
|
50 |
+
values: [2,2]
|
51 |
+
|
52 |
+
# Target structure-divergence hyperparames
|
53 |
+
# If you change the shape of object better to use all three line, otherwise, no need.
|
54 |
+
# Without following three lines, all self-attention will be replaced
|
55 |
+
blend_words: [['cat',], ["cat",]]
|
56 |
+
masked_self_attention: True
|
57 |
+
# masked_latents: False # performance not so good in our case, need debug
|
58 |
+
bend_th: [2, 2]
|
59 |
+
# preserve source structure of blend_words , [0, 1]
|
60 |
+
# default is bend_th: [2, 2] # preserve all source self-attention
|
61 |
+
# bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention
|
62 |
+
|
63 |
+
|
64 |
+
1:
|
65 |
+
is_replace_controller: false
|
66 |
+
cross_replace_steps:
|
67 |
+
default_: 0.5
|
68 |
+
self_replace_steps: 0.5
|
69 |
+
2:
|
70 |
+
is_replace_controller: false
|
71 |
+
cross_replace_steps:
|
72 |
+
default_: 0.5
|
73 |
+
self_replace_steps: 0.5
|
74 |
+
3:
|
75 |
+
is_replace_controller: false
|
76 |
+
cross_replace_steps:
|
77 |
+
default_: 0.5
|
78 |
+
self_replace_steps: 0.5
|
79 |
+
4:
|
80 |
+
is_replace_controller: false
|
81 |
+
cross_replace_steps:
|
82 |
+
default_: 0.7
|
83 |
+
self_replace_steps: 0.7
|
84 |
+
|
85 |
+
|
86 |
+
|
87 |
+
|
88 |
+
clip_length: "${..train_dataset.n_sample_frame}"
|
89 |
+
sample_seeds: [0]
|
90 |
+
val_all_frames: False
|
91 |
+
|
92 |
+
num_inference_steps: 50
|
93 |
+
prompt2prompt_edit: True
|
94 |
+
|
95 |
+
|
96 |
+
model_config:
|
97 |
+
lora: 160
|
98 |
+
# temporal_downsample_time: 4
|
99 |
+
SparseCausalAttention_index: ['mid']
|
100 |
+
least_sc_channel: 640
|
101 |
+
# least_sc_channel: 100000
|
102 |
+
|
103 |
+
test_pipeline_config:
|
104 |
+
target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
|
105 |
+
num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
|
106 |
+
|
107 |
+
epsilon: 1e-5
|
108 |
+
train_steps: 10
|
109 |
+
seed: 0
|
110 |
+
learning_rate: 1e-5
|
111 |
+
train_temporal_conv: False
|
112 |
+
guidance_scale: "${validation_sample_logger_config.guidance_scale}"
|
FateZero/config/attribute/dog_robotic_corgi.yaml
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/attribute/dog_robotic_corgi.yaml
|
2 |
+
|
3 |
+
pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
|
4 |
+
|
5 |
+
train_dataset:
|
6 |
+
path: "data/attribute/gray_dog"
|
7 |
+
prompt: "A gray dog sitting on the mat"
|
8 |
+
n_sample_frame: 8
|
9 |
+
# n_sample_frame: 22
|
10 |
+
sampling_rate: 1
|
11 |
+
stride: 80
|
12 |
+
offset:
|
13 |
+
left: 0
|
14 |
+
right: 0
|
15 |
+
top: 0
|
16 |
+
bottom: 0
|
17 |
+
|
18 |
+
validation_sample_logger_config:
|
19 |
+
use_train_latents: True
|
20 |
+
use_inversion_attention: True
|
21 |
+
guidance_scale: 7.5
|
22 |
+
prompts: [
|
23 |
+
# source prompt
|
24 |
+
A gray dog sitting on the mat,
|
25 |
+
|
26 |
+
# foreground texture style
|
27 |
+
A robotic dog sitting on the mat,
|
28 |
+
A yellow corgi sitting on the mat
|
29 |
+
]
|
30 |
+
p2p_config:
|
31 |
+
0:
|
32 |
+
# Whether to directly copy the cross attention from source
|
33 |
+
# True: directly copy, better for object replacement
|
34 |
+
# False: keep source attention, better for style
|
35 |
+
is_replace_controller: False
|
36 |
+
|
37 |
+
# Semantic preserving and replacement Debug me
|
38 |
+
cross_replace_steps:
|
39 |
+
default_: 0.8
|
40 |
+
|
41 |
+
# Source background structure preserving, in [0, 1].
|
42 |
+
# e.g., =0.6 Replace the first 60% steps self-attention
|
43 |
+
self_replace_steps: 0.6
|
44 |
+
|
45 |
+
|
46 |
+
# Amplify the target-words cross attention, larger value, more close to target
|
47 |
+
eq_params:
|
48 |
+
words: ["silver", "sculpture"]
|
49 |
+
values: [2,2]
|
50 |
+
|
51 |
+
# Target structure-divergence hyperparames
|
52 |
+
# If you change the shape of object better to use all three line, otherwise, no need.
|
53 |
+
# Without following three lines, all self-attention will be replaced
|
54 |
+
blend_words: [['cat',], ["cat",]]
|
55 |
+
masked_self_attention: True
|
56 |
+
# masked_latents: False # performance not so good in our case, need debug
|
57 |
+
bend_th: [2, 2]
|
58 |
+
# preserve source structure of blend_words , [0, 1]
|
59 |
+
# default is bend_th: [2, 2] # preserve all source self-attention
|
60 |
+
# bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention
|
61 |
+
|
62 |
+
|
63 |
+
1:
|
64 |
+
is_replace_controller: false
|
65 |
+
cross_replace_steps:
|
66 |
+
default_: 0.5
|
67 |
+
self_replace_steps: 0.5
|
68 |
+
|
69 |
+
eq_params:
|
70 |
+
words: ["robotic"]
|
71 |
+
values: [10] # amplify attention to the word "tiger" by *2
|
72 |
+
|
73 |
+
2:
|
74 |
+
is_replace_controller: false
|
75 |
+
cross_replace_steps:
|
76 |
+
default_: 0.5
|
77 |
+
self_replace_steps: 0.5
|
78 |
+
|
79 |
+
clip_length: "${..train_dataset.n_sample_frame}"
|
80 |
+
sample_seeds: [0]
|
81 |
+
val_all_frames: False
|
82 |
+
|
83 |
+
num_inference_steps: 50
|
84 |
+
prompt2prompt_edit: True
|
85 |
+
|
86 |
+
|
87 |
+
model_config:
|
88 |
+
lora: 160
|
89 |
+
# temporal_downsample_time: 4
|
90 |
+
SparseCausalAttention_index: ['mid']
|
91 |
+
least_sc_channel: 640
|
92 |
+
# least_sc_channel: 100000
|
93 |
+
|
94 |
+
test_pipeline_config:
|
95 |
+
target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
|
96 |
+
num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
|
97 |
+
|
98 |
+
epsilon: 1e-5
|
99 |
+
train_steps: 10
|
100 |
+
seed: 0
|
101 |
+
learning_rate: 1e-5
|
102 |
+
train_temporal_conv: False
|
103 |
+
guidance_scale: "${validation_sample_logger_config.guidance_scale}"
|
FateZero/config/attribute/duck_rubber.yaml
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/attribute/duck_rubber.yaml
|
2 |
+
|
3 |
+
pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
|
4 |
+
|
5 |
+
train_dataset:
|
6 |
+
path: "data/attribute/duck_rubber"
|
7 |
+
prompt: "a sleepy white duck"
|
8 |
+
n_sample_frame: 8
|
9 |
+
# n_sample_frame: 22
|
10 |
+
sampling_rate: 1
|
11 |
+
stride: 80
|
12 |
+
offset:
|
13 |
+
left: 0
|
14 |
+
right: 0
|
15 |
+
top: 0
|
16 |
+
bottom: 0
|
17 |
+
|
18 |
+
validation_sample_logger_config:
|
19 |
+
use_train_latents: True
|
20 |
+
use_inversion_attention: True
|
21 |
+
guidance_scale: 7.5
|
22 |
+
prompts: [
|
23 |
+
# source prompt
|
24 |
+
a sleepy white duck,
|
25 |
+
|
26 |
+
# foreground texture style
|
27 |
+
a sleepy yellow rubber duck
|
28 |
+
]
|
29 |
+
p2p_config:
|
30 |
+
0:
|
31 |
+
# Whether to directly copy the cross attention from source
|
32 |
+
# True: directly copy, better for object replacement
|
33 |
+
# False: keep source attention, better for style
|
34 |
+
is_replace_controller: False
|
35 |
+
|
36 |
+
# Semantic preserving and replacement Debug me
|
37 |
+
cross_replace_steps:
|
38 |
+
default_: 0.8
|
39 |
+
|
40 |
+
# Source background structure preserving, in [0, 1].
|
41 |
+
# e.g., =0.6 Replace the first 60% steps self-attention
|
42 |
+
self_replace_steps: 0.6
|
43 |
+
|
44 |
+
|
45 |
+
# Amplify the target-words cross attention, larger value, more close to target
|
46 |
+
eq_params:
|
47 |
+
words: ["silver", "sculpture"]
|
48 |
+
values: [2,2]
|
49 |
+
|
50 |
+
# Target structure-divergence hyperparames
|
51 |
+
# If you change the shape of object better to use all three line, otherwise, no need.
|
52 |
+
# Without following three lines, all self-attention will be replaced
|
53 |
+
blend_words: [['cat',], ["cat",]]
|
54 |
+
masked_self_attention: True
|
55 |
+
# masked_latents: False # performance not so good in our case, need debug
|
56 |
+
bend_th: [2, 2]
|
57 |
+
# preserve source structure of blend_words , [0, 1]
|
58 |
+
# default is bend_th: [2, 2] # preserve all source self-attention
|
59 |
+
# bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention
|
60 |
+
|
61 |
+
|
62 |
+
1:
|
63 |
+
is_replace_controller: False
|
64 |
+
cross_replace_steps:
|
65 |
+
default_: 0.7
|
66 |
+
self_replace_steps: 0.7
|
67 |
+
|
68 |
+
# eq_params:
|
69 |
+
# words: ["yellow", "rubber"]
|
70 |
+
# values: [10, 10] # amplify attention to the word "tiger" by *2
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
clip_length: "${..train_dataset.n_sample_frame}"
|
76 |
+
sample_seeds: [0]
|
77 |
+
val_all_frames: False
|
78 |
+
|
79 |
+
num_inference_steps: 50
|
80 |
+
prompt2prompt_edit: True
|
81 |
+
|
82 |
+
|
83 |
+
model_config:
|
84 |
+
lora: 160
|
85 |
+
# temporal_downsample_time: 4
|
86 |
+
SparseCausalAttention_index: ['mid']
|
87 |
+
least_sc_channel: 640
|
88 |
+
# least_sc_channel: 100000
|
89 |
+
|
90 |
+
test_pipeline_config:
|
91 |
+
target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
|
92 |
+
num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
|
93 |
+
|
94 |
+
epsilon: 1e-5
|
95 |
+
train_steps: 10
|
96 |
+
seed: 0
|
97 |
+
learning_rate: 1e-5
|
98 |
+
train_temporal_conv: False
|
99 |
+
guidance_scale: "${validation_sample_logger_config.guidance_scale}"
|
FateZero/config/attribute/fox_wolf_snow.yaml
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/attribute/fox_wolf_snow.yaml
|
2 |
+
|
3 |
+
pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
|
4 |
+
|
5 |
+
train_dataset:
|
6 |
+
path: "data/attribute/fox_wolf_snow"
|
7 |
+
prompt: "a white fox sitting in the grass"
|
8 |
+
n_sample_frame: 8
|
9 |
+
# n_sample_frame: 22
|
10 |
+
sampling_rate: 1
|
11 |
+
stride: 80
|
12 |
+
offset:
|
13 |
+
left: 0
|
14 |
+
right: 0
|
15 |
+
top: 0
|
16 |
+
bottom: 0
|
17 |
+
|
18 |
+
validation_sample_logger_config:
|
19 |
+
use_train_latents: True
|
20 |
+
use_inversion_attention: True
|
21 |
+
guidance_scale: 7.5
|
22 |
+
prompts: [
|
23 |
+
# source prompt
|
24 |
+
a white fox sitting in the grass,
|
25 |
+
|
26 |
+
# foreground texture style
|
27 |
+
a grey wolf sitting in the grass,
|
28 |
+
a white fox sitting in the snow
|
29 |
+
]
|
30 |
+
p2p_config:
|
31 |
+
0:
|
32 |
+
# Whether to directly copy the cross attention from source
|
33 |
+
# True: directly copy, better for object replacement
|
34 |
+
# False: keep source attention, better for style
|
35 |
+
is_replace_controller: False
|
36 |
+
|
37 |
+
# Semantic preserving and replacement Debug me
|
38 |
+
cross_replace_steps:
|
39 |
+
default_: 0.8
|
40 |
+
|
41 |
+
# Source background structure preserving, in [0, 1].
|
42 |
+
# e.g., =0.6 Replace the first 60% steps self-attention
|
43 |
+
self_replace_steps: 0.6
|
44 |
+
|
45 |
+
|
46 |
+
# Amplify the target-words cross attention, larger value, more close to target
|
47 |
+
eq_params:
|
48 |
+
words: ["silver", "sculpture"]
|
49 |
+
values: [2,2]
|
50 |
+
|
51 |
+
# Target structure-divergence hyperparames
|
52 |
+
# If you change the shape of object better to use all three line, otherwise, no need.
|
53 |
+
# Without following three lines, all self-attention will be replaced
|
54 |
+
blend_words: [['cat',], ["cat",]]
|
55 |
+
masked_self_attention: True
|
56 |
+
# masked_latents: False # performance not so good in our case, need debug
|
57 |
+
bend_th: [2, 2]
|
58 |
+
# preserve source structure of blend_words , [0, 1]
|
59 |
+
# default is bend_th: [2, 2] # preserve all source self-attention
|
60 |
+
# bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention
|
61 |
+
|
62 |
+
|
63 |
+
1:
|
64 |
+
is_replace_controller: false
|
65 |
+
cross_replace_steps:
|
66 |
+
default_: 0.5
|
67 |
+
self_replace_steps: 0.5
|
68 |
+
|
69 |
+
eq_params:
|
70 |
+
words: ["robotic"]
|
71 |
+
values: [10] # amplify attention to the word "tiger" by *2
|
72 |
+
|
73 |
+
2:
|
74 |
+
is_replace_controller: false
|
75 |
+
cross_replace_steps:
|
76 |
+
default_: 0.5
|
77 |
+
self_replace_steps: 0.5
|
78 |
+
eq_params:
|
79 |
+
words: ["snow"]
|
80 |
+
values: [10] # amplify attention to the word "tiger" by *2
|
81 |
+
|
82 |
+
|
83 |
+
clip_length: "${..train_dataset.n_sample_frame}"
|
84 |
+
sample_seeds: [0]
|
85 |
+
val_all_frames: False
|
86 |
+
|
87 |
+
num_inference_steps: 50
|
88 |
+
prompt2prompt_edit: True
|
89 |
+
|
90 |
+
|
91 |
+
model_config:
|
92 |
+
lora: 160
|
93 |
+
# temporal_downsample_time: 4
|
94 |
+
SparseCausalAttention_index: ['mid']
|
95 |
+
least_sc_channel: 640
|
96 |
+
# least_sc_channel: 100000
|
97 |
+
|
98 |
+
test_pipeline_config:
|
99 |
+
target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
|
100 |
+
num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
|
101 |
+
|
102 |
+
epsilon: 1e-5
|
103 |
+
train_steps: 10
|
104 |
+
seed: 0
|
105 |
+
learning_rate: 1e-5
|
106 |
+
train_temporal_conv: False
|
107 |
+
guidance_scale: "${validation_sample_logger_config.guidance_scale}"
|
FateZero/config/attribute/rabbit_straberry_leaves_flowers.yaml
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CUDA_VISIBLE_DEVICES=1 python test_fatezero.py --config config/attribute/rabbit_straberry_leaves_flowers.yaml
|
2 |
+
|
3 |
+
pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
|
4 |
+
|
5 |
+
|
6 |
+
train_dataset:
|
7 |
+
path: "data/attribute/rabbit_strawberry"
|
8 |
+
prompt: "A rabbit is eating strawberries"
|
9 |
+
n_sample_frame: 8
|
10 |
+
# n_sample_frame: 22
|
11 |
+
sampling_rate: 1
|
12 |
+
stride: 80
|
13 |
+
offset:
|
14 |
+
left: 0
|
15 |
+
right: 0
|
16 |
+
top: 0
|
17 |
+
bottom: 0
|
18 |
+
|
19 |
+
validation_sample_logger_config:
|
20 |
+
use_train_latents: True
|
21 |
+
use_inversion_attention: True
|
22 |
+
guidance_scale: 7.5
|
23 |
+
prompts: [
|
24 |
+
# source prompt
|
25 |
+
A rabbit is eating strawberries,
|
26 |
+
|
27 |
+
# foreground texture style
|
28 |
+
A white rabbit is eating leaves,
|
29 |
+
A white rabbit is eating flower,
|
30 |
+
A white rabbit is eating orange,
|
31 |
+
|
32 |
+
# a brown lion walking on the rock against a wall,
|
33 |
+
]
|
34 |
+
p2p_config:
|
35 |
+
0:
|
36 |
+
# Whether to directly copy the cross attention from source
|
37 |
+
# True: directly copy, better for object replacement
|
38 |
+
# False: keep source attention, better for style
|
39 |
+
is_replace_controller: False
|
40 |
+
|
41 |
+
# Semantic preserving and replacement Debug me
|
42 |
+
cross_replace_steps:
|
43 |
+
default_: 0.8
|
44 |
+
|
45 |
+
# Source background structure preserving, in [0, 1].
|
46 |
+
# e.g., =0.6 Replace the first 60% steps self-attention
|
47 |
+
self_replace_steps: 0.6
|
48 |
+
|
49 |
+
|
50 |
+
# Amplify the target-words cross attention, larger value, more close to target
|
51 |
+
eq_params:
|
52 |
+
words: ["silver", "sculpture"]
|
53 |
+
values: [2,2]
|
54 |
+
|
55 |
+
# Target structure-divergence hyperparames
|
56 |
+
# If you change the shape of object better to use all three line, otherwise, no need.
|
57 |
+
# Without following three lines, all self-attention will be replaced
|
58 |
+
blend_words: [['cat',], ["cat",]]
|
59 |
+
masked_self_attention: True
|
60 |
+
# masked_latents: False # performance not so good in our case, need debug
|
61 |
+
bend_th: [2, 2]
|
62 |
+
# preserve source structure of blend_words , [0, 1]
|
63 |
+
# default is bend_th: [2, 2] # preserve all source self-attention
|
64 |
+
# bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention
|
65 |
+
1:
|
66 |
+
is_replace_controller: false
|
67 |
+
cross_replace_steps:
|
68 |
+
default_: 0.5
|
69 |
+
self_replace_steps: 0.5
|
70 |
+
eq_params:
|
71 |
+
words: ["leaves"]
|
72 |
+
values: [10]
|
73 |
+
2:
|
74 |
+
is_replace_controller: false
|
75 |
+
cross_replace_steps:
|
76 |
+
default_: 0.5
|
77 |
+
self_replace_steps: 0.5
|
78 |
+
eq_params:
|
79 |
+
words: ["flower"]
|
80 |
+
values: [10]
|
81 |
+
3:
|
82 |
+
is_replace_controller: false
|
83 |
+
cross_replace_steps:
|
84 |
+
default_: 0.5
|
85 |
+
self_replace_steps: 0.5
|
86 |
+
eq_params:
|
87 |
+
words: ["orange"]
|
88 |
+
values: [10]
|
89 |
+
|
90 |
+
clip_length: "${..train_dataset.n_sample_frame}"
|
91 |
+
sample_seeds: [0]
|
92 |
+
val_all_frames: False
|
93 |
+
|
94 |
+
num_inference_steps: 50
|
95 |
+
prompt2prompt_edit: True
|
96 |
+
|
97 |
+
|
98 |
+
model_config:
|
99 |
+
lora: 160
|
100 |
+
# temporal_downsample_time: 4
|
101 |
+
SparseCausalAttention_index: ['mid']
|
102 |
+
least_sc_channel: 640
|
103 |
+
# least_sc_channel: 100000
|
104 |
+
|
105 |
+
test_pipeline_config:
|
106 |
+
target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
|
107 |
+
num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
|
108 |
+
|
109 |
+
epsilon: 1e-5
|
110 |
+
train_steps: 10
|
111 |
+
seed: 0
|
112 |
+
learning_rate: 1e-5
|
113 |
+
train_temporal_conv: False
|
114 |
+
guidance_scale: "${validation_sample_logger_config.guidance_scale}"
|
FateZero/config/attribute/squ_carrot_robot_eggplant.yaml
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/attribute/squ_carrot_robot_eggplant.yaml
|
2 |
+
|
3 |
+
pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
|
4 |
+
|
5 |
+
|
6 |
+
train_dataset:
|
7 |
+
path: "data/attribute/squirrel_carrot"
|
8 |
+
prompt: "A squirrel is eating a carrot"
|
9 |
+
n_sample_frame: 8
|
10 |
+
# n_sample_frame: 22
|
11 |
+
sampling_rate: 1
|
12 |
+
stride: 80
|
13 |
+
offset:
|
14 |
+
left: 0
|
15 |
+
right: 0
|
16 |
+
top: 0
|
17 |
+
bottom: 0
|
18 |
+
|
19 |
+
validation_sample_logger_config:
|
20 |
+
use_train_latents: True
|
21 |
+
use_inversion_attention: True
|
22 |
+
guidance_scale: 7.5
|
23 |
+
prompts: [
|
24 |
+
# source prompt
|
25 |
+
A squirrel is eating a carrot,
|
26 |
+
A robot squirrel is eating a carrot,
|
27 |
+
A rabbit is eating a eggplant,
|
28 |
+
A robot mouse is eating a screwdriver,
|
29 |
+
A white mouse is eating a peanut,
|
30 |
+
]
|
31 |
+
p2p_config:
|
32 |
+
0:
|
33 |
+
# Whether to directly copy the cross attention from source
|
34 |
+
# True: directly copy, better for object replacement
|
35 |
+
# False: keep source attention, better for style
|
36 |
+
is_replace_controller: False
|
37 |
+
|
38 |
+
# Semantic preserving and replacement Debug me
|
39 |
+
cross_replace_steps:
|
40 |
+
default_: 0.8
|
41 |
+
|
42 |
+
# Source background structure preserving, in [0, 1].
|
43 |
+
# e.g., =0.6 Replace the first 60% steps self-attention
|
44 |
+
self_replace_steps: 0.6
|
45 |
+
|
46 |
+
|
47 |
+
# Amplify the target-words cross attention, larger value, more close to target
|
48 |
+
eq_params:
|
49 |
+
words: ["silver", "sculpture"]
|
50 |
+
values: [2,2]
|
51 |
+
|
52 |
+
# Target structure-divergence hyperparames
|
53 |
+
# If you change the shape of object better to use all three line, otherwise, no need.
|
54 |
+
# Without following three lines, all self-attention will be replaced
|
55 |
+
blend_words: [['cat',], ["cat",]]
|
56 |
+
masked_self_attention: True
|
57 |
+
# masked_latents: False # performance not so good in our case, need debug
|
58 |
+
bend_th: [2, 2]
|
59 |
+
# preserve source structure of blend_words , [0, 1]
|
60 |
+
# default is bend_th: [2, 2] # preserve all source self-attention
|
61 |
+
# bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention
|
62 |
+
|
63 |
+
|
64 |
+
1:
|
65 |
+
is_replace_controller: false
|
66 |
+
cross_replace_steps:
|
67 |
+
default_: 0.5
|
68 |
+
self_replace_steps: 0.4
|
69 |
+
eq_params:
|
70 |
+
words: ["rabbit", "mouse", "robot", "eggplant", "peanut", "screwdriver"]
|
71 |
+
values: [10, 10, 20, 10, 10, 10]
|
72 |
+
2:
|
73 |
+
is_replace_controller: false
|
74 |
+
cross_replace_steps:
|
75 |
+
default_: 0.5
|
76 |
+
self_replace_steps: 0.5
|
77 |
+
eq_params:
|
78 |
+
words: ["rabbit", "mouse", "robot", "eggplant", "peanut", "screwdriver"]
|
79 |
+
values: [10, 10, 20, 10, 10, 10]
|
80 |
+
3:
|
81 |
+
is_replace_controller: false
|
82 |
+
cross_replace_steps:
|
83 |
+
default_: 0.5
|
84 |
+
self_replace_steps: 0.5
|
85 |
+
eq_params:
|
86 |
+
words: ["rabbit", "mouse", "robot", "eggplant", "peanut", "screwdriver"]
|
87 |
+
values: [10, 10, 20, 10, 10, 10]
|
88 |
+
4:
|
89 |
+
is_replace_controller: false
|
90 |
+
cross_replace_steps:
|
91 |
+
default_: 0.5
|
92 |
+
self_replace_steps: 0.5
|
93 |
+
eq_params:
|
94 |
+
words: ["rabbit", "mouse", "robot", "eggplant", "peanut", "screwdriver"]
|
95 |
+
values: [10, 10, 20, 10, 10, 10]
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
clip_length: "${..train_dataset.n_sample_frame}"
|
100 |
+
sample_seeds: [0]
|
101 |
+
val_all_frames: False
|
102 |
+
|
103 |
+
num_inference_steps: 50
|
104 |
+
prompt2prompt_edit: True
|
105 |
+
|
106 |
+
|
107 |
+
model_config:
|
108 |
+
lora: 160
|
109 |
+
# temporal_downsample_time: 4
|
110 |
+
SparseCausalAttention_index: ['mid']
|
111 |
+
least_sc_channel: 640
|
112 |
+
# least_sc_channel: 100000
|
113 |
+
|
114 |
+
test_pipeline_config:
|
115 |
+
target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
|
116 |
+
num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
|
117 |
+
|
118 |
+
epsilon: 1e-5
|
119 |
+
train_steps: 10
|
120 |
+
seed: 0
|
121 |
+
learning_rate: 1e-5
|
122 |
+
train_temporal_conv: False
|
123 |
+
guidance_scale: "${validation_sample_logger_config.guidance_scale}"
|
FateZero/config/attribute/swan_swa.yaml
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/attribute/swan_swa.yaml
|
2 |
+
|
3 |
+
pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
|
4 |
+
|
5 |
+
|
6 |
+
train_dataset:
|
7 |
+
path: "data/attribute/swan_swarov"
|
8 |
+
prompt: "a black swan with a red beak swimming in a river near a wall and bushes,"
|
9 |
+
n_sample_frame: 8
|
10 |
+
# n_sample_frame: 22
|
11 |
+
sampling_rate: 1
|
12 |
+
stride: 80
|
13 |
+
offset:
|
14 |
+
left: 0
|
15 |
+
right: 0
|
16 |
+
top: 0
|
17 |
+
bottom: 0
|
18 |
+
|
19 |
+
use_train_latents: True
|
20 |
+
|
21 |
+
validation_sample_logger_config:
|
22 |
+
use_train_latents: True
|
23 |
+
use_inversion_attention: True
|
24 |
+
guidance_scale: 7.5
|
25 |
+
prompts: [
|
26 |
+
# source prompt
|
27 |
+
a black swan with a red beak swimming in a river near a wall and bushes,
|
28 |
+
|
29 |
+
# foreground texture style
|
30 |
+
a Swarovski crystal swan with a red beak swimming in a river near a wall and bushes,
|
31 |
+
]
|
32 |
+
p2p_config:
|
33 |
+
0:
|
34 |
+
# Whether to directly copy the cross attention from source
|
35 |
+
# True: directly copy, better for object replacement
|
36 |
+
# False: keep source attention, better for style
|
37 |
+
is_replace_controller: False
|
38 |
+
|
39 |
+
# Semantic preserving and replacement Debug me
|
40 |
+
cross_replace_steps:
|
41 |
+
default_: 0.8
|
42 |
+
|
43 |
+
# Source background structure preserving, in [0, 1].
|
44 |
+
# e.g., =0.6 Replace the first 60% steps self-attention
|
45 |
+
self_replace_steps: 0.6
|
46 |
+
|
47 |
+
|
48 |
+
# Amplify the target-words cross attention, larger value, more close to target
|
49 |
+
eq_params:
|
50 |
+
words: ["silver", "sculpture"]
|
51 |
+
values: [2,2]
|
52 |
+
|
53 |
+
# Target structure-divergence hyperparames
|
54 |
+
# If you change the shape of object better to use all three line, otherwise, no need.
|
55 |
+
# Without following three lines, all self-attention will be replaced
|
56 |
+
blend_words: [['cat',], ["cat",]]
|
57 |
+
masked_self_attention: True
|
58 |
+
# masked_latents: False # performance not so good in our case, need debug
|
59 |
+
bend_th: [2, 2]
|
60 |
+
# preserve source structure of blend_words , [0, 1]
|
61 |
+
# default is bend_th: [2, 2] # preserve all source self-attention
|
62 |
+
# bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention
|
63 |
+
|
64 |
+
|
65 |
+
1:
|
66 |
+
is_replace_controller: False
|
67 |
+
cross_replace_steps:
|
68 |
+
default_: 0.8
|
69 |
+
self_replace_steps: 0.6
|
70 |
+
|
71 |
+
eq_params:
|
72 |
+
words: ["Swarovski", "crystal"]
|
73 |
+
values: [5, 5] # amplify attention to the word "tiger" by *2
|
74 |
+
use_inversion_attention: True
|
75 |
+
|
76 |
+
|
77 |
+
|
78 |
+
clip_length: "${..train_dataset.n_sample_frame}"
|
79 |
+
sample_seeds: [0]
|
80 |
+
val_all_frames: False
|
81 |
+
|
82 |
+
num_inference_steps: 50
|
83 |
+
prompt2prompt_edit: True
|
84 |
+
|
85 |
+
|
86 |
+
model_config:
|
87 |
+
lora: 160
|
88 |
+
# temporal_downsample_time: 4
|
89 |
+
SparseCausalAttention_index: ['mid']
|
90 |
+
least_sc_channel: 1280
|
91 |
+
# least_sc_channel: 100000
|
92 |
+
|
93 |
+
test_pipeline_config:
|
94 |
+
target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
|
95 |
+
num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
|
96 |
+
|
97 |
+
epsilon: 1e-5
|
98 |
+
train_steps: 10
|
99 |
+
seed: 0
|
100 |
+
learning_rate: 1e-5
|
101 |
+
train_temporal_conv: False
|
102 |
+
guidance_scale: "${validation_sample_logger_config.guidance_scale}"
|
FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps.yaml
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/low_resource_teaser/jeep_watercolor.yaml
|
2 |
+
|
3 |
+
pretrained_model_path: "FateZero/ckpt/stable-diffusion-v1-4"
|
4 |
+
|
5 |
+
train_dataset:
|
6 |
+
path: "FateZero/data/teaser_car-turn"
|
7 |
+
prompt: "a silver jeep driving down a curvy road in the countryside"
|
8 |
+
n_sample_frame: 8
|
9 |
+
sampling_rate: 1
|
10 |
+
stride: 80
|
11 |
+
offset:
|
12 |
+
left: 0
|
13 |
+
right: 0
|
14 |
+
top: 0
|
15 |
+
bottom: 0
|
16 |
+
|
17 |
+
|
18 |
+
validation_sample_logger_config:
|
19 |
+
use_train_latents: true
|
20 |
+
use_inversion_attention: true
|
21 |
+
guidance_scale: 7.5
|
22 |
+
source_prompt: "${train_dataset.prompt}"
|
23 |
+
prompts: [
|
24 |
+
# a silver jeep driving down a curvy road in the countryside,
|
25 |
+
watercolor painting of a silver jeep driving down a curvy road in the countryside,
|
26 |
+
]
|
27 |
+
p2p_config:
|
28 |
+
0:
|
29 |
+
# Whether to directly copy the cross attention from source
|
30 |
+
# True: directly copy, better for object replacement
|
31 |
+
# False: keep source attention, better for style
|
32 |
+
|
33 |
+
is_replace_controller: False
|
34 |
+
|
35 |
+
# Semantic layout preserving. High steps, replace more cross attention to preserve semantic layout
|
36 |
+
cross_replace_steps:
|
37 |
+
default_: 0.8
|
38 |
+
|
39 |
+
# Source background structure preserving, in [0, 1].
|
40 |
+
# e.g., =0.6 Replace the first 60% steps self-attention
|
41 |
+
self_replace_steps: 0.8
|
42 |
+
|
43 |
+
|
44 |
+
# Amplify the target-words cross attention, larger value, more close to target
|
45 |
+
eq_params:
|
46 |
+
words: ["watercolor"]
|
47 |
+
values: [10,10]
|
48 |
+
|
49 |
+
# Target structure-divergence hyperparames
|
50 |
+
# If you change the shape of object better to use all three line, otherwise, no need.
|
51 |
+
# Without following three lines, all self-attention will be replaced
|
52 |
+
# blend_words: [['jeep',], ["car",]]
|
53 |
+
# masked_self_attention: True
|
54 |
+
# masked_latents: False # performance not so good in our case, need debug
|
55 |
+
# bend_th: [2, 2]
|
56 |
+
# preserve source structure of blend_words , [0, 1]
|
57 |
+
# default is bend_th: [2, 2] # replace full-resolution edit source with self-attention
|
58 |
+
# bend_th-> [0.0, 0.0], mask -> 1, use more edit self-attention, more generated shape, less source acttention
|
59 |
+
|
60 |
+
|
61 |
+
clip_length: "${..train_dataset.n_sample_frame}"
|
62 |
+
sample_seeds: [0]
|
63 |
+
|
64 |
+
num_inference_steps: 10
|
65 |
+
prompt2prompt_edit: True
|
66 |
+
|
67 |
+
model_config:
|
68 |
+
lora: 160
|
69 |
+
# temporal_downsample_time: 4
|
70 |
+
SparseCausalAttention_index: ['mid']
|
71 |
+
least_sc_channel: 640
|
72 |
+
# least_sc_channel: 100000
|
73 |
+
|
74 |
+
test_pipeline_config:
|
75 |
+
target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
|
76 |
+
num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
|
77 |
+
|
78 |
+
epsilon: 1e-5
|
79 |
+
train_steps: 10
|
80 |
+
seed: 0
|
81 |
+
learning_rate: 1e-5
|
82 |
+
train_temporal_conv: False
|
83 |
+
guidance_scale: "${validation_sample_logger_config.guidance_scale}"
|
FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps_disk_store.yaml
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/low_resource_teaser/jeep_watercolor_ddim_10_steps_disk_store.yaml
|
2 |
+
|
3 |
+
pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
|
4 |
+
|
5 |
+
train_dataset:
|
6 |
+
path: "data/teaser_car-turn"
|
7 |
+
prompt: "a silver jeep driving down a curvy road in the countryside"
|
8 |
+
n_sample_frame: 8
|
9 |
+
sampling_rate: 1
|
10 |
+
stride: 80
|
11 |
+
offset:
|
12 |
+
left: 0
|
13 |
+
right: 0
|
14 |
+
top: 0
|
15 |
+
bottom: 0
|
16 |
+
|
17 |
+
|
18 |
+
validation_sample_logger_config:
|
19 |
+
use_train_latents: true
|
20 |
+
use_inversion_attention: true
|
21 |
+
guidance_scale: 7.5
|
22 |
+
source_prompt: "${train_dataset.prompt}"
|
23 |
+
prompts: [
|
24 |
+
# a silver jeep driving down a curvy road in the countryside,
|
25 |
+
watercolor painting of a silver jeep driving down a curvy road in the countryside,
|
26 |
+
]
|
27 |
+
p2p_config:
|
28 |
+
0:
|
29 |
+
# Whether to directly copy the cross attention from source
|
30 |
+
# True: directly copy, better for object replacement
|
31 |
+
# False: keep source attention, better for style
|
32 |
+
|
33 |
+
is_replace_controller: False
|
34 |
+
|
35 |
+
# Semantic layout preserving. High steps, replace more cross attention to preserve semantic layout
|
36 |
+
cross_replace_steps:
|
37 |
+
default_: 0.8
|
38 |
+
|
39 |
+
# Source background structure preserving, in [0, 1].
|
40 |
+
# e.g., =0.6 Replace the first 60% steps self-attention
|
41 |
+
self_replace_steps: 0.8
|
42 |
+
|
43 |
+
|
44 |
+
# Amplify the target-words cross attention, larger value, more close to target
|
45 |
+
eq_params:
|
46 |
+
words: ["watercolor"]
|
47 |
+
values: [10,10]
|
48 |
+
|
49 |
+
# Target structure-divergence hyperparames
|
50 |
+
# If you change the shape of object better to use all three line, otherwise, no need.
|
51 |
+
# Without following three lines, all self-attention will be replaced
|
52 |
+
# blend_words: [['jeep',], ["car",]]
|
53 |
+
# masked_self_attention: True
|
54 |
+
# masked_latents: False # performance not so good in our case, need debug
|
55 |
+
# bend_th: [2, 2]
|
56 |
+
# preserve source structure of blend_words , [0, 1]
|
57 |
+
# default is bend_th: [2, 2] # replace full-resolution edit source with self-attention
|
58 |
+
# bend_th-> [0.0, 0.0], mask -> 1, use more edit self-attention, more generated shape, less source acttention
|
59 |
+
|
60 |
+
|
61 |
+
clip_length: "${..train_dataset.n_sample_frame}"
|
62 |
+
sample_seeds: [0]
|
63 |
+
|
64 |
+
num_inference_steps: 10
|
65 |
+
prompt2prompt_edit: True
|
66 |
+
|
67 |
+
disk_store: True
|
68 |
+
model_config:
|
69 |
+
lora: 160
|
70 |
+
# temporal_downsample_time: 4
|
71 |
+
SparseCausalAttention_index: ['mid']
|
72 |
+
least_sc_channel: 640
|
73 |
+
# least_sc_channel: 100000
|
74 |
+
|
75 |
+
test_pipeline_config:
|
76 |
+
target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
|
77 |
+
num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
|
78 |
+
|
79 |
+
epsilon: 1e-5
|
80 |
+
train_steps: 10
|
81 |
+
seed: 0
|
82 |
+
learning_rate: 1e-5
|
83 |
+
train_temporal_conv: False
|
84 |
+
guidance_scale: "${validation_sample_logger_config.guidance_scale}"
|
FateZero/config/style/jeep_watercolor.yaml
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/teaser/jeep_watercolor.yaml
|
2 |
+
|
3 |
+
pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
|
4 |
+
|
5 |
+
train_dataset:
|
6 |
+
path: "data/teaser_car-turn"
|
7 |
+
prompt: "a silver jeep driving down a curvy road in the countryside"
|
8 |
+
n_sample_frame: 8
|
9 |
+
sampling_rate: 1
|
10 |
+
stride: 80
|
11 |
+
offset:
|
12 |
+
left: 0
|
13 |
+
right: 0
|
14 |
+
top: 0
|
15 |
+
bottom: 0
|
16 |
+
|
17 |
+
|
18 |
+
validation_sample_logger_config:
|
19 |
+
use_train_latents: true
|
20 |
+
use_inversion_attention: true
|
21 |
+
guidance_scale: 7.5
|
22 |
+
prompts: [
|
23 |
+
a silver jeep driving down a curvy road in the countryside,
|
24 |
+
watercolor painting of a silver jeep driving down a curvy road in the countryside,
|
25 |
+
]
|
26 |
+
p2p_config:
|
27 |
+
0:
|
28 |
+
# Whether to directly copy the cross attention from source
|
29 |
+
# True: directly copy, better for object replacement
|
30 |
+
# False: keep source attention, better for style
|
31 |
+
is_replace_controller: False
|
32 |
+
|
33 |
+
# Semantic layout preserving. High steps, replace more cross attention to preserve semantic layout
|
34 |
+
cross_replace_steps:
|
35 |
+
default_: 0.8
|
36 |
+
|
37 |
+
# Source background structure preserving, in [0, 1].
|
38 |
+
# e.g., =0.6 Replace the first 60% steps self-attention
|
39 |
+
self_replace_steps: 0.9
|
40 |
+
|
41 |
+
|
42 |
+
# Amplify the target-words cross attention, larger value, more close to target
|
43 |
+
# eq_params:
|
44 |
+
# words: ["", ""]
|
45 |
+
# values: [10,10]
|
46 |
+
|
47 |
+
# Target structure-divergence hyperparames
|
48 |
+
# If you change the shape of object better to use all three line, otherwise, no need.
|
49 |
+
# Without following three lines, all self-attention will be replaced
|
50 |
+
# blend_words: [['jeep',], ["car",]]
|
51 |
+
masked_self_attention: True
|
52 |
+
# masked_latents: False # Directly copy the latents, performance not so good in our case
|
53 |
+
bend_th: [2, 2]
|
54 |
+
# preserve source structure of blend_words , [0, 1]
|
55 |
+
# default is bend_th: [2, 2] # replace full-resolution edit source with self-attention
|
56 |
+
# bend_th-> [0.0, 0.0], mask -> 1, use more edit self-attention, more generated shape, less source acttention
|
57 |
+
|
58 |
+
|
59 |
+
1:
|
60 |
+
cross_replace_steps:
|
61 |
+
default_: 0.8
|
62 |
+
self_replace_steps: 0.8
|
63 |
+
|
64 |
+
eq_params:
|
65 |
+
words: ["watercolor"]
|
66 |
+
values: [10] # amplify attention to the word "tiger" by *2
|
67 |
+
use_inversion_attention: True
|
68 |
+
is_replace_controller: False
|
69 |
+
|
70 |
+
|
71 |
+
clip_length: "${..train_dataset.n_sample_frame}"
|
72 |
+
sample_seeds: [0]
|
73 |
+
|
74 |
+
num_inference_steps: 50
|
75 |
+
prompt2prompt_edit: True
|
76 |
+
|
77 |
+
|
78 |
+
model_config:
|
79 |
+
lora: 160
|
80 |
+
# temporal_downsample_time: 4
|
81 |
+
SparseCausalAttention_index: ['mid']
|
82 |
+
least_sc_channel: 640
|
83 |
+
# least_sc_channel: 100000
|
84 |
+
|
85 |
+
test_pipeline_config:
|
86 |
+
target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
|
87 |
+
num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
|
88 |
+
|
89 |
+
epsilon: 1e-5
|
90 |
+
train_steps: 10
|
91 |
+
seed: 0
|
92 |
+
learning_rate: 1e-5
|
93 |
+
train_temporal_conv: False
|
94 |
+
guidance_scale: "${validation_sample_logger_config.guidance_scale}"
|
FateZero/config/style/lily_monet.yaml
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
|
2 |
+
|
3 |
+
|
4 |
+
train_dataset:
|
5 |
+
path: "data/style/red_water_lily_opening"
|
6 |
+
prompt: "a pink water lily"
|
7 |
+
start_sample_frame: 1
|
8 |
+
n_sample_frame: 8
|
9 |
+
# n_sample_frame: 22
|
10 |
+
sampling_rate: 20
|
11 |
+
stride: 8000
|
12 |
+
# offset:
|
13 |
+
# left: 300
|
14 |
+
# right: 0
|
15 |
+
# top: 0
|
16 |
+
# bottom: 0
|
17 |
+
|
18 |
+
validation_sample_logger_config:
|
19 |
+
use_train_latents: True
|
20 |
+
use_inversion_attention: True
|
21 |
+
guidance_scale: 7.5
|
22 |
+
prompts: [
|
23 |
+
a pink water lily,
|
24 |
+
Claude Monet painting of a pink water lily,
|
25 |
+
]
|
26 |
+
p2p_config:
|
27 |
+
0:
|
28 |
+
# Whether to directly copy the cross attention from source
|
29 |
+
# True: directly copy, better for object replacement
|
30 |
+
# False: keep source attention, better for style
|
31 |
+
is_replace_controller: False
|
32 |
+
|
33 |
+
# Semantic preserving and replacement Debug me
|
34 |
+
cross_replace_steps:
|
35 |
+
default_: 0.7
|
36 |
+
|
37 |
+
# Source background structure preserving, in [0, 1].
|
38 |
+
# e.g., =0.6 Replace the first 60% steps self-attention
|
39 |
+
self_replace_steps: 0.7
|
40 |
+
|
41 |
+
|
42 |
+
# Amplify the target-words cross attention, larger value, more close to target
|
43 |
+
eq_params:
|
44 |
+
words: ["silver", "sculpture"]
|
45 |
+
values: [2,2]
|
46 |
+
|
47 |
+
# Target structure-divergence hyperparames
|
48 |
+
# If you change the shape of object better to use all three line, otherwise, no need.
|
49 |
+
# Without following three lines, all self-attention will be replaced
|
50 |
+
blend_words: [['cat',], ["cat",]]
|
51 |
+
masked_self_attention: True
|
52 |
+
# masked_latents: False # performance not so good in our case, need debug
|
53 |
+
bend_th: [2, 2]
|
54 |
+
# preserve source structure of blend_words , [0, 1]
|
55 |
+
# default is bend_th: [2, 2] # preserve all source self-attention
|
56 |
+
# bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention
|
57 |
+
|
58 |
+
|
59 |
+
1:
|
60 |
+
is_replace_controller: False
|
61 |
+
cross_replace_steps:
|
62 |
+
default_: 0.5
|
63 |
+
self_replace_steps: 0.5
|
64 |
+
|
65 |
+
eq_params:
|
66 |
+
words: ["Monet"]
|
67 |
+
values: [10]
|
68 |
+
|
69 |
+
clip_length: "${..train_dataset.n_sample_frame}"
|
70 |
+
sample_seeds: [0]
|
71 |
+
val_all_frames: False
|
72 |
+
|
73 |
+
num_inference_steps: 50
|
74 |
+
prompt2prompt_edit: True
|
75 |
+
|
76 |
+
|
77 |
+
model_config:
|
78 |
+
lora: 160
|
79 |
+
# temporal_downsample_time: 4
|
80 |
+
SparseCausalAttention_index: ['mid']
|
81 |
+
least_sc_channel: 1280
|
82 |
+
# least_sc_channel: 100000
|
83 |
+
|
84 |
+
test_pipeline_config:
|
85 |
+
target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
|
86 |
+
num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
|
87 |
+
|
88 |
+
epsilon: 1e-5
|
89 |
+
train_steps: 10
|
90 |
+
seed: 0
|
91 |
+
learning_rate: 1e-5
|
92 |
+
train_temporal_conv: False
|
93 |
+
guidance_scale: "${validation_sample_logger_config.guidance_scale}"
|
FateZero/config/style/rabit_pokemon.yaml
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
|
2 |
+
|
3 |
+
|
4 |
+
train_dataset:
|
5 |
+
path: "data/style/rabit"
|
6 |
+
prompt: "A rabbit is eating a watermelon"
|
7 |
+
n_sample_frame: 8
|
8 |
+
# n_sample_frame: 22
|
9 |
+
sampling_rate: 3
|
10 |
+
stride: 80
|
11 |
+
|
12 |
+
|
13 |
+
validation_sample_logger_config:
|
14 |
+
use_train_latents: True
|
15 |
+
use_inversion_attention: True
|
16 |
+
guidance_scale: 7.5
|
17 |
+
prompts: [
|
18 |
+
# source prompt
|
19 |
+
A rabbit is eating a watermelon,
|
20 |
+
# overall style
|
21 |
+
pokemon cartoon of A rabbit is eating a watermelon,
|
22 |
+
]
|
23 |
+
p2p_config:
|
24 |
+
0:
|
25 |
+
# Whether to directly copy the cross attention from source
|
26 |
+
# True: directly copy, better for object replacement
|
27 |
+
# False: keep source attention, better for style
|
28 |
+
is_replace_controller: False
|
29 |
+
|
30 |
+
# Semantic preserving and replacement Debug me
|
31 |
+
cross_replace_steps:
|
32 |
+
default_: 0.8
|
33 |
+
|
34 |
+
# Source background structure preserving, in [0, 1].
|
35 |
+
# e.g., =0.6 Replace the first 60% steps self-attention
|
36 |
+
self_replace_steps: 0.6
|
37 |
+
|
38 |
+
|
39 |
+
# Amplify the target-words cross attention, larger value, more close to target
|
40 |
+
eq_params:
|
41 |
+
words: ["silver", "sculpture"]
|
42 |
+
values: [2,2]
|
43 |
+
|
44 |
+
# Target structure-divergence hyperparames
|
45 |
+
# If you change the shape of object better to use all three line, otherwise, no need.
|
46 |
+
# Without following three lines, all self-attention will be replaced
|
47 |
+
blend_words: [['cat',], ["cat",]]
|
48 |
+
masked_self_attention: True
|
49 |
+
# masked_latents: False # performance not so good in our case, need debug
|
50 |
+
bend_th: [2, 2]
|
51 |
+
# preserve source structure of blend_words , [0, 1]
|
52 |
+
# default is bend_th: [2, 2] # preserve all source self-attention
|
53 |
+
# bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention
|
54 |
+
|
55 |
+
|
56 |
+
1:
|
57 |
+
is_replace_controller: False
|
58 |
+
cross_replace_steps:
|
59 |
+
default_: 0.7
|
60 |
+
self_replace_steps: 0.7
|
61 |
+
|
62 |
+
eq_params:
|
63 |
+
words: ["pokemon", "cartoon"]
|
64 |
+
values: [3, 3] # amplify attention to the word "tiger" by *2
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
clip_length: "${..train_dataset.n_sample_frame}"
|
69 |
+
sample_seeds: [0]
|
70 |
+
val_all_frames: False
|
71 |
+
|
72 |
+
num_inference_steps: 50
|
73 |
+
prompt2prompt_edit: True
|
74 |
+
|
75 |
+
|
76 |
+
model_config:
|
77 |
+
# lora: 160
|
78 |
+
# temporal_downsample_time: 4
|
79 |
+
# SparseCausalAttention_index: ['mid']
|
80 |
+
# least_sc_channel: 640
|
81 |
+
# least_sc_channel: 100000
|
82 |
+
|
83 |
+
test_pipeline_config:
|
84 |
+
target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
|
85 |
+
num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
|
86 |
+
|
87 |
+
epsilon: 1e-5
|
88 |
+
train_steps: 50
|
89 |
+
seed: 0
|
90 |
+
learning_rate: 1e-5
|
91 |
+
train_temporal_conv: False
|
92 |
+
guidance_scale: "${validation_sample_logger_config.guidance_scale}"
|
FateZero/config/style/sun_flower_van_gogh.yaml
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
|
2 |
+
|
3 |
+
train_dataset:
|
4 |
+
path: "data/style/sunflower"
|
5 |
+
prompt: "a yellow sunflower"
|
6 |
+
start_sample_frame: 0
|
7 |
+
n_sample_frame: 8
|
8 |
+
sampling_rate: 1
|
9 |
+
|
10 |
+
|
11 |
+
validation_sample_logger_config:
|
12 |
+
use_train_latents: True
|
13 |
+
use_inversion_attention: True
|
14 |
+
guidance_scale: 7.5
|
15 |
+
prompts: [
|
16 |
+
a yellow sunflower,
|
17 |
+
van gogh style painting of a yellow sunflower,
|
18 |
+
]
|
19 |
+
p2p_config:
|
20 |
+
0:
|
21 |
+
# Whether to directly copy the cross attention from source
|
22 |
+
# True: directly copy, better for object replacement
|
23 |
+
# False: keep source attention, better for style
|
24 |
+
is_replace_controller: False
|
25 |
+
|
26 |
+
# Semantic preserving and replacement Debug me
|
27 |
+
cross_replace_steps:
|
28 |
+
default_: 0.7
|
29 |
+
|
30 |
+
# Source background structure preserving, in [0, 1].
|
31 |
+
# e.g., =0.6 Replace the first 60% steps self-attention
|
32 |
+
self_replace_steps: 0.7
|
33 |
+
|
34 |
+
|
35 |
+
# Amplify the target-words cross attention, larger value, more close to target
|
36 |
+
eq_params:
|
37 |
+
words: ["silver", "sculpture"]
|
38 |
+
values: [2,2]
|
39 |
+
|
40 |
+
# Target structure-divergence hyperparames
|
41 |
+
# If you change the shape of object better to use all three line, otherwise, no need.
|
42 |
+
# Without following three lines, all self-attention will be replaced
|
43 |
+
blend_words: [['cat',], ["cat",]]
|
44 |
+
masked_self_attention: True
|
45 |
+
# masked_latents: False # performance not so good in our case, need debug
|
46 |
+
bend_th: [2, 2]
|
47 |
+
# preserve source structure of blend_words , [0, 1]
|
48 |
+
# default is bend_th: [2, 2] # preserve all source self-attention
|
49 |
+
# bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention
|
50 |
+
|
51 |
+
|
52 |
+
1:
|
53 |
+
is_replace_controller: False
|
54 |
+
cross_replace_steps:
|
55 |
+
default_: 0.5
|
56 |
+
self_replace_steps: 0.5
|
57 |
+
|
58 |
+
eq_params:
|
59 |
+
words: ["van", "gogh"]
|
60 |
+
values: [10, 10] # amplify attention to the word "tiger" by *2
|
61 |
+
|
62 |
+
clip_length: "${..train_dataset.n_sample_frame}"
|
63 |
+
sample_seeds: [0]
|
64 |
+
val_all_frames: False
|
65 |
+
|
66 |
+
num_inference_steps: 50
|
67 |
+
prompt2prompt_edit: True
|
68 |
+
|
69 |
+
|
70 |
+
model_config:
|
71 |
+
lora: 160
|
72 |
+
# temporal_downsample_time: 4
|
73 |
+
SparseCausalAttention_index: ['mid']
|
74 |
+
least_sc_channel: 640
|
75 |
+
# least_sc_channel: 100000
|
76 |
+
|
77 |
+
test_pipeline_config:
|
78 |
+
target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
|
79 |
+
num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
|
80 |
+
|
81 |
+
epsilon: 1e-5
|
82 |
+
train_steps: 10
|
83 |
+
seed: 0
|
84 |
+
learning_rate: 1e-5
|
85 |
+
train_temporal_conv: False
|
86 |
+
guidance_scale: "${validation_sample_logger_config.guidance_scale}"
|
FateZero/config/style/surf_ukiyo.yaml
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
|
2 |
+
|
3 |
+
train_dataset:
|
4 |
+
path: "data/style/surf"
|
5 |
+
prompt: "a man with round helmet surfing on a white wave in blue ocean with a rope"
|
6 |
+
n_sample_frame: 1
|
7 |
+
|
8 |
+
sampling_rate: 8
|
9 |
+
|
10 |
+
|
11 |
+
# use_train_latents: True
|
12 |
+
|
13 |
+
validation_sample_logger_config:
|
14 |
+
use_train_latents: true
|
15 |
+
use_inversion_attention: true
|
16 |
+
guidance_scale: 7.5
|
17 |
+
prompts: [
|
18 |
+
a man with round helmet surfing on a white wave in blue ocean with a rope,
|
19 |
+
The Ukiyo-e style painting of a man with round helmet surfing on a white wave in blue ocean with a rope
|
20 |
+
]
|
21 |
+
p2p_config:
|
22 |
+
0:
|
23 |
+
# Whether to directly copy the cross attention from source
|
24 |
+
# True: directly copy, better for object replacement
|
25 |
+
# False: keep source attention, better for style
|
26 |
+
is_replace_controller: False
|
27 |
+
|
28 |
+
# Semantic preserving and replacement Debug me
|
29 |
+
cross_replace_steps:
|
30 |
+
default_: 0.8
|
31 |
+
|
32 |
+
# Source background structure preserving, in [0, 1].
|
33 |
+
# e.g., =0.6 Replace the first 60% steps self-attention
|
34 |
+
self_replace_steps: 0.8
|
35 |
+
|
36 |
+
|
37 |
+
# Amplify the target-words cross attention, larger value, more close to target
|
38 |
+
eq_params:
|
39 |
+
words: ["silver", "sculpture"]
|
40 |
+
values: [2,2]
|
41 |
+
|
42 |
+
# Target structure-divergence hyperparames
|
43 |
+
# If you change the shape of object better to use all three line, otherwise, no need.
|
44 |
+
# Without following three lines, all self-attention will be replaced
|
45 |
+
blend_words: [['cat',], ["cat",]]
|
46 |
+
masked_self_attention: True
|
47 |
+
# masked_latents: False # performance not so good in our case, need debug
|
48 |
+
bend_th: [2, 2]
|
49 |
+
# preserve source structure of blend_words , [0, 1]
|
50 |
+
# default is bend_th: [2, 2] # preserve all source self-attention
|
51 |
+
# bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention
|
52 |
+
|
53 |
+
1:
|
54 |
+
is_replace_controller: False
|
55 |
+
cross_replace_steps:
|
56 |
+
default_: 0.9
|
57 |
+
self_replace_steps: 0.9
|
58 |
+
|
59 |
+
eq_params:
|
60 |
+
words: ["Ukiyo-e"]
|
61 |
+
values: [10, 10] # amplify attention to the word "tiger" by *2
|
62 |
+
|
63 |
+
|
64 |
+
|
65 |
+
|
66 |
+
clip_length: "${..train_dataset.n_sample_frame}"
|
67 |
+
sample_seeds: [0]
|
68 |
+
val_all_frames: False
|
69 |
+
|
70 |
+
num_inference_steps: 50
|
71 |
+
prompt2prompt_edit: True
|
72 |
+
|
73 |
+
|
74 |
+
model_config:
|
75 |
+
# lora: 160
|
76 |
+
# temporal_downsample_time: 4
|
77 |
+
SparseCausalAttention_index: ['mid']
|
78 |
+
least_sc_channel: 640
|
79 |
+
# least_sc_channel: 100000
|
80 |
+
|
81 |
+
test_pipeline_config:
|
82 |
+
target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
|
83 |
+
num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
|
84 |
+
|
85 |
+
epsilon: 1e-5
|
86 |
+
train_steps: 50
|
87 |
+
seed: 0
|
88 |
+
learning_rate: 1e-5
|
89 |
+
train_temporal_conv: False
|
90 |
+
guidance_scale: "${validation_sample_logger_config.guidance_scale}"
|
FateZero/config/style/swan_cartoon.yaml
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
|
2 |
+
|
3 |
+
|
4 |
+
train_dataset:
|
5 |
+
path: "data/style/blackswan"
|
6 |
+
prompt: "a black swan with a red beak swimming in a river near a wall and bushes,"
|
7 |
+
n_sample_frame: 8
|
8 |
+
# n_sample_frame: 22
|
9 |
+
sampling_rate: 6
|
10 |
+
stride: 80
|
11 |
+
offset:
|
12 |
+
left: 0
|
13 |
+
right: 0
|
14 |
+
top: 0
|
15 |
+
bottom: 0
|
16 |
+
|
17 |
+
# use_train_latents: True
|
18 |
+
|
19 |
+
validation_sample_logger_config:
|
20 |
+
use_train_latents: true
|
21 |
+
use_inversion_attention: true
|
22 |
+
guidance_scale: 7.5
|
23 |
+
prompts: [
|
24 |
+
# source prompt
|
25 |
+
a black swan with a red beak swimming in a river near a wall and bushes,
|
26 |
+
cartoon photo of a black swan with a red beak swimming in a river near a wall and bushes,
|
27 |
+
]
|
28 |
+
p2p_config:
|
29 |
+
0:
|
30 |
+
# Whether to directly copy the cross attention from source
|
31 |
+
# True: directly copy, better for object replacement
|
32 |
+
# False: keep source attention, better for style
|
33 |
+
is_replace_controller: False
|
34 |
+
|
35 |
+
# Semantic preserving and replacement Debug me
|
36 |
+
cross_replace_steps:
|
37 |
+
default_: 0.8
|
38 |
+
|
39 |
+
# Source background structure preserving, in [0, 1].
|
40 |
+
# e.g., =0.6 Replace the first 60% steps self-attention
|
41 |
+
self_replace_steps: 0.6
|
42 |
+
|
43 |
+
|
44 |
+
# Amplify the target-words cross attention, larger value, more close to target
|
45 |
+
eq_params:
|
46 |
+
words: ["silver", "sculpture"]
|
47 |
+
values: [2,2]
|
48 |
+
|
49 |
+
# Target structure-divergence hyperparames
|
50 |
+
# If you change the shape of object better to use all three line, otherwise, no need.
|
51 |
+
# Without following three lines, all self-attention will be replaced
|
52 |
+
blend_words: [['cat',], ["cat",]]
|
53 |
+
masked_self_attention: True
|
54 |
+
# masked_latents: False # performance not so good in our case, need debug
|
55 |
+
bend_th: [2, 2]
|
56 |
+
# preserve source structure of blend_words , [0, 1]
|
57 |
+
# default is bend_th: [2, 2] # preserve all source self-attention
|
58 |
+
# bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention
|
59 |
+
|
60 |
+
# Fixed hyperparams
|
61 |
+
use_inversion_attention: True
|
62 |
+
|
63 |
+
1:
|
64 |
+
is_replace_controller: False
|
65 |
+
cross_replace_steps:
|
66 |
+
default_: 0.8
|
67 |
+
self_replace_steps: 0.7
|
68 |
+
|
69 |
+
eq_params:
|
70 |
+
words: ["cartoon"]
|
71 |
+
values: [10] # amplify attention to the word "tiger" by *2
|
72 |
+
use_inversion_attention: True
|
73 |
+
|
74 |
+
|
75 |
+
|
76 |
+
clip_length: "${..train_dataset.n_sample_frame}"
|
77 |
+
sample_seeds: [0]
|
78 |
+
val_all_frames: False
|
79 |
+
|
80 |
+
num_inference_steps: 50
|
81 |
+
# guidance_scale: 7.5
|
82 |
+
prompt2prompt_edit: True
|
83 |
+
|
84 |
+
|
85 |
+
model_config:
|
86 |
+
lora: 160
|
87 |
+
# temporal_downsample_time: 4
|
88 |
+
SparseCausalAttention_index: ['mid']
|
89 |
+
least_sc_channel: 640
|
90 |
+
# least_sc_channel: 100000
|
91 |
+
|
92 |
+
test_pipeline_config:
|
93 |
+
target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
|
94 |
+
num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
|
95 |
+
|
96 |
+
epsilon: 1e-5
|
97 |
+
train_steps: 10
|
98 |
+
seed: 0
|
99 |
+
learning_rate: 1e-5
|
100 |
+
train_temporal_conv: False
|
101 |
+
guidance_scale: "${validation_sample_logger_config.guidance_scale}"
|
FateZero/config/style/train_shinkai.yaml
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pretrained_model_path: "./ckpt/stable-diffusion-v1-4"
|
2 |
+
|
3 |
+
train_dataset:
|
4 |
+
path: "data/style/train"
|
5 |
+
prompt: "a train traveling down tracks next to a forest filled with trees and flowers and a man on the side of the track"
|
6 |
+
n_sample_frame: 32
|
7 |
+
# n_sample_frame: 22
|
8 |
+
sampling_rate: 7
|
9 |
+
stride: 80
|
10 |
+
# offset:
|
11 |
+
# left: 300
|
12 |
+
# right: 0
|
13 |
+
# top: 0
|
14 |
+
# bottom: 0
|
15 |
+
|
16 |
+
use_train_latents: True
|
17 |
+
|
18 |
+
validation_sample_logger_config:
|
19 |
+
use_train_latents: True
|
20 |
+
use_inversion_attention: True
|
21 |
+
guidance_scale: 7.5
|
22 |
+
prompts: [
|
23 |
+
a train traveling down tracks next to a forest filled with trees and flowers and a man on the side of the track,
|
24 |
+
a train traveling down tracks next to a forest filled with trees and flowers and a man on the side of the track Makoto Shinkai style
|
25 |
+
|
26 |
+
]
|
27 |
+
p2p_config:
|
28 |
+
0:
|
29 |
+
# Whether to directly copy the cross attention from source
|
30 |
+
# True: directly copy, better for object replacement
|
31 |
+
# False: keep source attention, better for style
|
32 |
+
is_replace_controller: False
|
33 |
+
|
34 |
+
# Semantic preserving and replacement Debug me
|
35 |
+
cross_replace_steps:
|
36 |
+
default_: 1.0
|
37 |
+
|
38 |
+
# Source background structure preserving, in [0, 1].
|
39 |
+
# e.g., =0.6 Replace the first 60% steps self-attention
|
40 |
+
self_replace_steps: 1.0
|
41 |
+
|
42 |
+
|
43 |
+
# Amplify the target-words cross attention, larger value, more close to target
|
44 |
+
# eq_params:
|
45 |
+
# words: ["silver", "sculpture"]
|
46 |
+
# values: [2,2]
|
47 |
+
|
48 |
+
# Target structure-divergence hyperparames
|
49 |
+
# If you change the shape of object better to use all three line, otherwise, no need.
|
50 |
+
# Without following three lines, all self-attention will be replaced
|
51 |
+
# blend_words: [['cat',], ["cat",]]
|
52 |
+
# masked_self_attention: True
|
53 |
+
# # masked_latents: False # performance not so good in our case, need debug
|
54 |
+
# bend_th: [2, 2]
|
55 |
+
# preserve source structure of blend_words , [0, 1]
|
56 |
+
# default is bend_th: [2, 2] # preserve all source self-attention
|
57 |
+
# bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention
|
58 |
+
|
59 |
+
|
60 |
+
1:
|
61 |
+
is_replace_controller: False
|
62 |
+
cross_replace_steps:
|
63 |
+
default_: 1.0
|
64 |
+
self_replace_steps: 0.9
|
65 |
+
|
66 |
+
eq_params:
|
67 |
+
words: ["Makoto", "Shinkai"]
|
68 |
+
values: [10, 10] # amplify attention to the word "tiger" by *2
|
69 |
+
|
70 |
+
|
71 |
+
|
72 |
+
|
73 |
+
clip_length: "${..train_dataset.n_sample_frame}"
|
74 |
+
sample_seeds: [0]
|
75 |
+
val_all_frames: False
|
76 |
+
|
77 |
+
num_inference_steps: 50
|
78 |
+
prompt2prompt_edit: True
|
79 |
+
|
80 |
+
|
81 |
+
model_config:
|
82 |
+
lora: 160
|
83 |
+
# temporal_downsample_time: 4
|
84 |
+
SparseCausalAttention_index: ['mid']
|
85 |
+
least_sc_channel: 1280
|
86 |
+
# least_sc_channel: 100000
|
87 |
+
|
88 |
+
test_pipeline_config:
|
89 |
+
target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
|
90 |
+
num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
|
91 |
+
|
92 |
+
epsilon: 1e-5
|
93 |
+
train_steps: 10
|
94 |
+
seed: 0
|
95 |
+
learning_rate: 1e-5
|
96 |
+
train_temporal_conv: False
|
97 |
+
guidance_scale: "${validation_sample_logger_config.guidance_scale}"
|
FateZero/config/teaser/jeep_posche.yaml
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/teaser/jeep_posche.yaml
|
2 |
+
|
3 |
+
pretrained_model_path: "./ckpt/jeep_tuned_200"
|
4 |
+
|
5 |
+
train_dataset:
|
6 |
+
path: "data/teaser_car-turn"
|
7 |
+
prompt: "a silver jeep driving down a curvy road in the countryside,"
|
8 |
+
n_sample_frame: 8
|
9 |
+
sampling_rate: 1
|
10 |
+
stride: 80
|
11 |
+
offset:
|
12 |
+
left: 0
|
13 |
+
right: 0
|
14 |
+
top: 0
|
15 |
+
bottom: 0
|
16 |
+
|
17 |
+
|
18 |
+
validation_sample_logger_config:
|
19 |
+
use_train_latents: true
|
20 |
+
use_inversion_attention: true
|
21 |
+
guidance_scale: 7.5
|
22 |
+
prompts: [
|
23 |
+
a silver jeep driving down a curvy road in the countryside,
|
24 |
+
a Porsche car driving down a curvy road in the countryside,
|
25 |
+
]
|
26 |
+
p2p_config:
|
27 |
+
0:
|
28 |
+
# Whether to directly copy the cross attention from source
|
29 |
+
# True: directly copy, better for object replacement
|
30 |
+
# False: keep source attention, better for style
|
31 |
+
is_replace_controller: False
|
32 |
+
|
33 |
+
# Semantic layout preserving. High steps, replace more cross attention to preserve semantic layout
|
34 |
+
cross_replace_steps:
|
35 |
+
default_: 0.8
|
36 |
+
|
37 |
+
# Source background structure preserving, in [0, 1].
|
38 |
+
# e.g., =0.6 Replace the first 60% steps self-attention
|
39 |
+
self_replace_steps: 0.9
|
40 |
+
|
41 |
+
|
42 |
+
# Amplify the target-words cross attention, larger value, more close to target
|
43 |
+
# Usefull in style editing
|
44 |
+
eq_params:
|
45 |
+
words: ["watercolor", "painting"]
|
46 |
+
values: [10,10]
|
47 |
+
|
48 |
+
# Target structure-divergence hyperparames
|
49 |
+
# If you change the shape of object better to use all three line, otherwise, no need.
|
50 |
+
# Without following three lines, all self-attention will be replaced
|
51 |
+
# Usefull in shape editing
|
52 |
+
blend_words: [['jeep',], ["car",]]
|
53 |
+
masked_self_attention: True
|
54 |
+
# masked_latents: False # Directly copy the latents, performance not so good in our case
|
55 |
+
|
56 |
+
# preserve source structure of blend_words , [0, 1]
|
57 |
+
# bend_th-> [1.0, 1.0], mask -> 0, use inversion-time attention, the structure is similar to the input
|
58 |
+
# bend_th-> [0.0, 0.0], mask -> 1, use more edit self-attention, more generated shape, less source acttention
|
59 |
+
bend_th: [0.3, 0.3]
|
60 |
+
|
61 |
+
1:
|
62 |
+
cross_replace_steps:
|
63 |
+
default_: 0.5
|
64 |
+
self_replace_steps: 0.5
|
65 |
+
|
66 |
+
use_inversion_attention: True
|
67 |
+
is_replace_controller: True
|
68 |
+
|
69 |
+
blend_words: [['silver', 'jeep'], ["Porsche", 'car']] # for local edit. If it is not local yet - use only the source object: blend_word = ((('cat',), ("cat",))).
|
70 |
+
masked_self_attention: True
|
71 |
+
bend_th: [0.3, 0.3]
|
72 |
+
|
73 |
+
clip_length: "${..train_dataset.n_sample_frame}"
|
74 |
+
sample_seeds: [0]
|
75 |
+
|
76 |
+
num_inference_steps: 50
|
77 |
+
prompt2prompt_edit: True
|
78 |
+
|
79 |
+
|
80 |
+
model_config:
|
81 |
+
lora: 160
|
82 |
+
|
83 |
+
|
84 |
+
test_pipeline_config:
|
85 |
+
target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
|
86 |
+
num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
|
87 |
+
|
88 |
+
epsilon: 1e-5
|
89 |
+
train_steps: 10
|
90 |
+
seed: 0
|
91 |
+
learning_rate: 1e-5
|
92 |
+
train_temporal_conv: False
|
93 |
+
guidance_scale: "${validation_sample_logger_config.guidance_scale}"
|
FateZero/config/teaser/jeep_watercolor.yaml
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/teaser/jeep_watercolor.yaml
|
2 |
+
|
3 |
+
pretrained_model_path: "FateZero/ckpt/stable-diffusion-v1-4"
|
4 |
+
|
5 |
+
train_dataset:
|
6 |
+
path: "FateZero/data/teaser_car-turn"
|
7 |
+
prompt: "a silver jeep driving down a curvy road in the countryside"
|
8 |
+
n_sample_frame: 8
|
9 |
+
sampling_rate: 1
|
10 |
+
stride: 80
|
11 |
+
offset:
|
12 |
+
left: 0
|
13 |
+
right: 0
|
14 |
+
top: 0
|
15 |
+
bottom: 0
|
16 |
+
|
17 |
+
|
18 |
+
validation_sample_logger_config:
|
19 |
+
use_train_latents: true
|
20 |
+
use_inversion_attention: true
|
21 |
+
guidance_scale: 7.5
|
22 |
+
prompts: [
|
23 |
+
a silver jeep driving down a curvy road in the countryside,
|
24 |
+
watercolor painting of a silver jeep driving down a curvy road in the countryside,
|
25 |
+
]
|
26 |
+
p2p_config:
|
27 |
+
0:
|
28 |
+
# Whether to directly copy the cross attention from source
|
29 |
+
# True: directly copy, better for object replacement
|
30 |
+
# False: keep source attention, better for style
|
31 |
+
is_replace_controller: False
|
32 |
+
|
33 |
+
# Semantic layout preserving. High steps, replace more cross attention to preserve semantic layout
|
34 |
+
cross_replace_steps:
|
35 |
+
default_: 0.8
|
36 |
+
|
37 |
+
# Source background structure preserving, in [0, 1].
|
38 |
+
# e.g., =0.6 Replace the first 60% steps self-attention
|
39 |
+
self_replace_steps: 0.9
|
40 |
+
|
41 |
+
|
42 |
+
# Amplify the target-words cross attention, larger value, more close to target
|
43 |
+
# eq_params:
|
44 |
+
# words: ["", ""]
|
45 |
+
# values: [10,10]
|
46 |
+
|
47 |
+
# Target structure-divergence hyperparames
|
48 |
+
# If you change the shape of object better to use all three line, otherwise, no need.
|
49 |
+
# Without following three lines, all self-attention will be replaced
|
50 |
+
# blend_words: [['jeep',], ["car",]]
|
51 |
+
masked_self_attention: True
|
52 |
+
# masked_latents: False # Directly copy the latents, performance not so good in our case
|
53 |
+
bend_th: [2, 2]
|
54 |
+
# preserve source structure of blend_words , [0, 1]
|
55 |
+
# default is bend_th: [2, 2] # replace full-resolution edit source with self-attention
|
56 |
+
# bend_th-> [0.0, 0.0], mask -> 1, use more edit self-attention, more generated shape, less source acttention
|
57 |
+
|
58 |
+
|
59 |
+
1:
|
60 |
+
cross_replace_steps:
|
61 |
+
default_: 0.8
|
62 |
+
self_replace_steps: 0.8
|
63 |
+
|
64 |
+
eq_params:
|
65 |
+
words: ["watercolor"]
|
66 |
+
values: [10] # amplify attention to the word "tiger" by *2
|
67 |
+
use_inversion_attention: True
|
68 |
+
is_replace_controller: False
|
69 |
+
|
70 |
+
|
71 |
+
clip_length: "${..train_dataset.n_sample_frame}"
|
72 |
+
sample_seeds: [0]
|
73 |
+
|
74 |
+
num_inference_steps: 50
|
75 |
+
prompt2prompt_edit: True
|
76 |
+
|
77 |
+
|
78 |
+
model_config:
|
79 |
+
lora: 160
|
80 |
+
# temporal_downsample_time: 4
|
81 |
+
SparseCausalAttention_index: ['mid']
|
82 |
+
least_sc_channel: 640
|
83 |
+
# least_sc_channel: 100000
|
84 |
+
|
85 |
+
test_pipeline_config:
|
86 |
+
target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline
|
87 |
+
num_inference_steps: "${..validation_sample_logger.num_inference_steps}"
|
88 |
+
|
89 |
+
epsilon: 1e-5
|
90 |
+
train_steps: 10
|
91 |
+
seed: 0
|
92 |
+
learning_rate: 1e-5
|
93 |
+
train_temporal_conv: False
|
94 |
+
guidance_scale: "${validation_sample_logger_config.guidance_scale}"
|
FateZero/data/.gitignore
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*
|
2 |
+
!teaser_car-turn
|
3 |
+
!teaser_car-turn/*
|
4 |
+
!.gitignore
|
FateZero/data/teaser_car-turn/00000.png
ADDED
FateZero/data/teaser_car-turn/00001.png
ADDED
FateZero/data/teaser_car-turn/00002.png
ADDED
FateZero/data/teaser_car-turn/00003.png
ADDED
FateZero/data/teaser_car-turn/00004.png
ADDED
FateZero/data/teaser_car-turn/00005.png
ADDED
FateZero/data/teaser_car-turn/00006.png
ADDED
FateZero/data/teaser_car-turn/00007.png
ADDED
FateZero/docs/EditingGuidance.md
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# EditingGuidance
|
2 |
+
|
3 |
+
## Prompt Engineering
|
4 |
+
For the results in the paper and webpage, we get the source prompt using the BLIP model embedded in the [Stable Diffusion WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui/).
|
5 |
+
|
6 |
+
Click the "interrogate CLIP", and we will get a source prompt automatically. Then, we remove the last few useless words.
|
7 |
+
|
8 |
+
<img src="../docs/blip.png" height="220px"/>
|
9 |
+
|
10 |
+
During stylization, you may use a very simple source prompt "A photo" as a baseline if your input video is too complicated to describe by one sentence.
|
11 |
+
|
12 |
+
### Validate the prompt
|
13 |
+
|
14 |
+
- Put the source prompt into the stable diffusion. If the generated image is close to our input video, it can be a good source prompt.
|
15 |
+
- A good prompt describes each frame and most objects in video. Especially, it has the object or attribute that we want to edit or preserve.
|
16 |
+
- Put the target prompt into the stable diffusion. We can check the upper bound of our editing effect. A reasonable composition of video may achieve better results(e.g., "sunflower" video with "Van Gogh" prompt is better than "sunflower" with "Monet")
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
## FateZero hyperparameters
|
24 |
+
We give a simple analysis of the involved hyperparaters as follows:
|
25 |
+
``` yaml
|
26 |
+
# Whether to directly copy the cross attention from source
|
27 |
+
# True: directly copy, better for object replacement
|
28 |
+
# False: keep source attention, better for style
|
29 |
+
is_replace_controller: False
|
30 |
+
|
31 |
+
# Semantic layout preserving. High steps, replace more cross attention to preserve semantic layout
|
32 |
+
cross_replace_steps:
|
33 |
+
default_: 0.8
|
34 |
+
|
35 |
+
# Source background structure preserving, in [0, 1].
|
36 |
+
# e.g., =0.6 Replace the first 60% steps self-attention
|
37 |
+
self_replace_steps: 0.8
|
38 |
+
|
39 |
+
|
40 |
+
# Amplify the target-words cross attention, larger value, more close to target
|
41 |
+
# eq_params:
|
42 |
+
# words: ["", ""]
|
43 |
+
# values: [10,10]
|
44 |
+
|
45 |
+
# Target structure-divergence hyperparames
|
46 |
+
# If you change the shape of object, it is better to use all three line; otherwise, no need.
|
47 |
+
# Without following three lines, all self-attention will be replaced
|
48 |
+
blend_words: [['jeep',], ["car",]]
|
49 |
+
masked_self_attention: True
|
50 |
+
# masked_latents: False # Directly copy the latents, performance not so good in our case
|
51 |
+
bend_th: [2, 2]
|
52 |
+
# preserve source structure of blend_words in [0, 1]
|
53 |
+
# default is bend_th: [2, 2] # replace full-resolution edit source with self-attention
|
54 |
+
# bend_th-> [0.0, 0.0], mask -> 1, use more edit self-attention, more generated shape, less source acttention
|
55 |
+
```
|
56 |
+
|
57 |
+
## DDIM hyperparameters
|
58 |
+
|
59 |
+
We profile the cost of editing 8 frames on an Nvidia 3090, fp16 of accelerator, xformers.
|
60 |
+
|
61 |
+
| Configs | Attention location | DDIM Inver. Step | CPU memory | GPU memory | Inversion time | Editing time time | Quality
|
62 |
+
|------------------|------------------ |------------------|------------------|------------------|------------------|----| ---- |
|
63 |
+
| [basic](../config/teaser/jeep_watercolor.yaml) | RAM | 50 | 100G | 12G | 60s | 40s | Full support
|
64 |
+
| [low cost](../config/low_resource_teaser/jeep_watercolor_ddim_10_steps.yaml) | RAM | 10 | 15G | 12G | 10s | 10s | OK for Style, not work for shape
|
65 |
+
| [lower cost](../config/low_resource_teaser/jeep_watercolor_ddim_10_steps_disk_store.yaml) | DISK | 10 | 6G | 12G | 33 s | 100s | OK for Style, not work for shape
|
FateZero/docs/OpenSans-Regular.ttf
ADDED
Binary file (148 kB). View file
|
|
FateZero/requirements.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cu113
|
2 |
+
torch==1.12.1+cu113 # --index-url https://download.pytorch.org/whl/cu113
|
3 |
+
torchvision==0.13.1+cu113 # --index-url https://download.pytorch.org/whl/cu113
|
4 |
+
diffusers[torch]==0.11.1
|
5 |
+
accelerate==0.15.0
|
6 |
+
transformers==4.25.1
|
7 |
+
bitsandbytes==0.35.4
|
8 |
+
einops
|
9 |
+
omegaconf
|
10 |
+
ftfy
|
11 |
+
tensorboard
|
12 |
+
modelcards
|
13 |
+
imageio
|
14 |
+
triton
|
15 |
+
click
|
16 |
+
opencv-python
|
17 |
+
imageio[ffmpeg]
|
FateZero/test_fatezero.py
ADDED
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from glob import glob
|
3 |
+
import copy
|
4 |
+
from typing import Optional,Dict
|
5 |
+
from tqdm.auto import tqdm
|
6 |
+
from omegaconf import OmegaConf
|
7 |
+
import click
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.utils.data
|
11 |
+
import torch.utils.checkpoint
|
12 |
+
|
13 |
+
from accelerate import Accelerator
|
14 |
+
from accelerate.logging import get_logger
|
15 |
+
from accelerate.utils import set_seed
|
16 |
+
from diffusers import (
|
17 |
+
AutoencoderKL,
|
18 |
+
DDIMScheduler,
|
19 |
+
)
|
20 |
+
from diffusers.utils.import_utils import is_xformers_available
|
21 |
+
from transformers import AutoTokenizer, CLIPTextModel
|
22 |
+
from einops import rearrange
|
23 |
+
|
24 |
+
import sys
|
25 |
+
sys.path.append('FateZero')
|
26 |
+
from video_diffusion.models.unet_3d_condition import UNetPseudo3DConditionModel
|
27 |
+
from video_diffusion.data.dataset import ImageSequenceDataset
|
28 |
+
from video_diffusion.common.util import get_time_string, get_function_args
|
29 |
+
from video_diffusion.common.image_util import log_train_samples
|
30 |
+
from video_diffusion.common.instantiate_from_config import instantiate_from_config
|
31 |
+
from video_diffusion.pipelines.p2pvalidation_loop import p2pSampleLogger
|
32 |
+
|
33 |
+
logger = get_logger(__name__)
|
34 |
+
|
35 |
+
|
36 |
+
def collate_fn(examples):
|
37 |
+
"""Concat a batch of sampled image in dataloader
|
38 |
+
"""
|
39 |
+
batch = {
|
40 |
+
"prompt_ids": torch.cat([example["prompt_ids"] for example in examples], dim=0),
|
41 |
+
"images": torch.stack([example["images"] for example in examples]),
|
42 |
+
}
|
43 |
+
return batch
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
def test(
|
48 |
+
config: str,
|
49 |
+
pretrained_model_path: str,
|
50 |
+
train_dataset: Dict,
|
51 |
+
logdir: str = None,
|
52 |
+
validation_sample_logger_config: Optional[Dict] = None,
|
53 |
+
test_pipeline_config: Optional[Dict] = None,
|
54 |
+
gradient_accumulation_steps: int = 1,
|
55 |
+
seed: Optional[int] = None,
|
56 |
+
mixed_precision: Optional[str] = "fp16",
|
57 |
+
train_batch_size: int = 1,
|
58 |
+
model_config: dict={},
|
59 |
+
verbose: bool=True,
|
60 |
+
**kwargs
|
61 |
+
|
62 |
+
):
|
63 |
+
args = get_function_args()
|
64 |
+
|
65 |
+
time_string = get_time_string()
|
66 |
+
if logdir is None:
|
67 |
+
logdir = config.replace('config', 'result').replace('.yml', '').replace('.yaml', '')
|
68 |
+
logdir += f"_{time_string}"
|
69 |
+
|
70 |
+
accelerator = Accelerator(
|
71 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
72 |
+
mixed_precision=mixed_precision,
|
73 |
+
)
|
74 |
+
if accelerator.is_main_process:
|
75 |
+
os.makedirs(logdir, exist_ok=True)
|
76 |
+
OmegaConf.save(args, os.path.join(logdir, "config.yml"))
|
77 |
+
|
78 |
+
if seed is not None:
|
79 |
+
set_seed(seed)
|
80 |
+
|
81 |
+
# Load the tokenizer
|
82 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
83 |
+
pretrained_model_path,
|
84 |
+
subfolder="tokenizer",
|
85 |
+
use_fast=False,
|
86 |
+
)
|
87 |
+
|
88 |
+
# Load models and create wrapper for stable diffusion
|
89 |
+
text_encoder = CLIPTextModel.from_pretrained(
|
90 |
+
pretrained_model_path,
|
91 |
+
subfolder="text_encoder",
|
92 |
+
)
|
93 |
+
|
94 |
+
vae = AutoencoderKL.from_pretrained(
|
95 |
+
pretrained_model_path,
|
96 |
+
subfolder="vae",
|
97 |
+
)
|
98 |
+
|
99 |
+
unet = UNetPseudo3DConditionModel.from_2d_model(
|
100 |
+
os.path.join(pretrained_model_path, "unet"), model_config=model_config
|
101 |
+
)
|
102 |
+
|
103 |
+
if 'target' not in test_pipeline_config:
|
104 |
+
test_pipeline_config['target'] = 'video_diffusion.pipelines.stable_diffusion.SpatioTemporalStableDiffusionPipeline'
|
105 |
+
|
106 |
+
pipeline = instantiate_from_config(
|
107 |
+
test_pipeline_config,
|
108 |
+
vae=vae,
|
109 |
+
text_encoder=text_encoder,
|
110 |
+
tokenizer=tokenizer,
|
111 |
+
unet=unet,
|
112 |
+
scheduler=DDIMScheduler.from_pretrained(
|
113 |
+
pretrained_model_path,
|
114 |
+
subfolder="scheduler",
|
115 |
+
),
|
116 |
+
disk_store=kwargs.get('disk_store', False)
|
117 |
+
)
|
118 |
+
pipeline.scheduler.set_timesteps(validation_sample_logger_config['num_inference_steps'])
|
119 |
+
pipeline.set_progress_bar_config(disable=True)
|
120 |
+
|
121 |
+
|
122 |
+
if is_xformers_available():
|
123 |
+
try:
|
124 |
+
pipeline.enable_xformers_memory_efficient_attention()
|
125 |
+
except Exception as e:
|
126 |
+
logger.warning(
|
127 |
+
"Could not enable memory efficient attention. Make sure xformers is installed"
|
128 |
+
f" correctly and a GPU is available: {e}"
|
129 |
+
)
|
130 |
+
|
131 |
+
vae.requires_grad_(False)
|
132 |
+
unet.requires_grad_(False)
|
133 |
+
text_encoder.requires_grad_(False)
|
134 |
+
prompt_ids = tokenizer(
|
135 |
+
train_dataset["prompt"],
|
136 |
+
truncation=True,
|
137 |
+
padding="max_length",
|
138 |
+
max_length=tokenizer.model_max_length,
|
139 |
+
return_tensors="pt",
|
140 |
+
).input_ids
|
141 |
+
train_dataset = ImageSequenceDataset(**train_dataset, prompt_ids=prompt_ids)
|
142 |
+
|
143 |
+
train_dataloader = torch.utils.data.DataLoader(
|
144 |
+
train_dataset,
|
145 |
+
batch_size=train_batch_size,
|
146 |
+
shuffle=True,
|
147 |
+
num_workers=4,
|
148 |
+
collate_fn=collate_fn,
|
149 |
+
)
|
150 |
+
train_sample_save_path = os.path.join(logdir, "train_samples.gif")
|
151 |
+
log_train_samples(save_path=train_sample_save_path, train_dataloader=train_dataloader)
|
152 |
+
|
153 |
+
unet, train_dataloader = accelerator.prepare(
|
154 |
+
unet, train_dataloader
|
155 |
+
)
|
156 |
+
|
157 |
+
weight_dtype = torch.float32
|
158 |
+
if accelerator.mixed_precision == "fp16":
|
159 |
+
weight_dtype = torch.float16
|
160 |
+
print('use fp16')
|
161 |
+
elif accelerator.mixed_precision == "bf16":
|
162 |
+
weight_dtype = torch.bfloat16
|
163 |
+
|
164 |
+
# Move text_encode and vae to gpu.
|
165 |
+
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
166 |
+
# These models are only used for inference, keeping weights in full precision is not required.
|
167 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
168 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
169 |
+
|
170 |
+
|
171 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
172 |
+
# The trackers initializes automatically on the main process.
|
173 |
+
if accelerator.is_main_process:
|
174 |
+
accelerator.init_trackers("video") # , config=vars(args))
|
175 |
+
logger.info("***** wait to fix the logger path *****")
|
176 |
+
|
177 |
+
if validation_sample_logger_config is not None and accelerator.is_main_process:
|
178 |
+
validation_sample_logger = p2pSampleLogger(**validation_sample_logger_config, logdir=logdir)
|
179 |
+
# validation_sample_logger.log_sample_images(
|
180 |
+
# pipeline=pipeline,
|
181 |
+
# device=accelerator.device,
|
182 |
+
# step=0,
|
183 |
+
# )
|
184 |
+
def make_data_yielder(dataloader):
|
185 |
+
while True:
|
186 |
+
for batch in dataloader:
|
187 |
+
yield batch
|
188 |
+
accelerator.wait_for_everyone()
|
189 |
+
|
190 |
+
train_data_yielder = make_data_yielder(train_dataloader)
|
191 |
+
|
192 |
+
|
193 |
+
batch = next(train_data_yielder)
|
194 |
+
if validation_sample_logger_config.get('use_train_latents', False):
|
195 |
+
# Precompute the latents for this video to align the initial latents in training and test
|
196 |
+
assert batch["images"].shape[0] == 1, "Only support, overfiting on a single video"
|
197 |
+
# we only inference for latents, no training
|
198 |
+
vae.eval()
|
199 |
+
text_encoder.eval()
|
200 |
+
unet.eval()
|
201 |
+
|
202 |
+
text_embeddings = pipeline._encode_prompt(
|
203 |
+
train_dataset.prompt,
|
204 |
+
device = accelerator.device,
|
205 |
+
num_images_per_prompt = 1,
|
206 |
+
do_classifier_free_guidance = True,
|
207 |
+
negative_prompt=None
|
208 |
+
)
|
209 |
+
|
210 |
+
use_inversion_attention = validation_sample_logger_config.get('use_inversion_attention', False)
|
211 |
+
batch['latents_all_step'] = pipeline.prepare_latents_ddim_inverted(
|
212 |
+
rearrange(batch["images"].to(dtype=weight_dtype), "b c f h w -> (b f) c h w"),
|
213 |
+
batch_size = 1,
|
214 |
+
num_images_per_prompt = 1, # not sure how to use it
|
215 |
+
text_embeddings = text_embeddings,
|
216 |
+
prompt = train_dataset.prompt,
|
217 |
+
store_attention=use_inversion_attention,
|
218 |
+
LOW_RESOURCE = True, # not classifier-free guidance
|
219 |
+
save_path = logdir if verbose else None
|
220 |
+
)
|
221 |
+
|
222 |
+
batch['ddim_init_latents'] = batch['latents_all_step'][-1]
|
223 |
+
|
224 |
+
else:
|
225 |
+
batch['ddim_init_latents'] = None
|
226 |
+
|
227 |
+
vae.eval()
|
228 |
+
text_encoder.eval()
|
229 |
+
unet.eval()
|
230 |
+
|
231 |
+
# with accelerator.accumulate(unet):
|
232 |
+
# Convert images to latent space
|
233 |
+
images = batch["images"].to(dtype=weight_dtype)
|
234 |
+
images = rearrange(images, "b c f h w -> (b f) c h w")
|
235 |
+
|
236 |
+
|
237 |
+
if accelerator.is_main_process:
|
238 |
+
|
239 |
+
if validation_sample_logger is not None:
|
240 |
+
unet.eval()
|
241 |
+
samples_all, save_path = validation_sample_logger.log_sample_images(
|
242 |
+
image=images, # torch.Size([8, 3, 512, 512])
|
243 |
+
pipeline=pipeline,
|
244 |
+
device=accelerator.device,
|
245 |
+
step=0,
|
246 |
+
latents = batch['ddim_init_latents'],
|
247 |
+
save_dir = logdir if verbose else None
|
248 |
+
)
|
249 |
+
# accelerator.log(logs, step=step)
|
250 |
+
print('accelerator.end_training()')
|
251 |
+
accelerator.end_training()
|
252 |
+
return save_path
|
253 |
+
|
254 |
+
|
255 |
+
# @click.command()
|
256 |
+
# @click.option("--config", type=str, default="FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps.yaml")
|
257 |
+
def run(config='FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps.yaml'):
|
258 |
+
print(f'in run function {config}')
|
259 |
+
Omegadict = OmegaConf.load(config)
|
260 |
+
if 'unet' in os.listdir(Omegadict['pretrained_model_path']):
|
261 |
+
test(config=config, **Omegadict)
|
262 |
+
print('test finished')
|
263 |
+
return '/home/cqiaa/diffusion/hugging_face/Tune-A-Video-inference/FateZero/result/low_resource_teaser/jeep_watercolor_ddim_10_steps_230327-200651/sample/step_0_0_0.mp4'
|
264 |
+
else:
|
265 |
+
# Go through all ckpt if possible
|
266 |
+
checkpoint_list = sorted(glob(os.path.join(Omegadict['pretrained_model_path'], 'checkpoint_*')))
|
267 |
+
print('checkpoint to evaluate:')
|
268 |
+
for checkpoint in checkpoint_list:
|
269 |
+
epoch = checkpoint.split('_')[-1]
|
270 |
+
|
271 |
+
for checkpoint in tqdm(checkpoint_list):
|
272 |
+
epoch = checkpoint.split('_')[-1]
|
273 |
+
if 'pretrained_epoch_list' not in Omegadict or int(epoch) in Omegadict['pretrained_epoch_list']:
|
274 |
+
print(f'Evaluate {checkpoint}')
|
275 |
+
# Update saving dir and ckpt
|
276 |
+
Omegadict_checkpoint = copy.deepcopy(Omegadict)
|
277 |
+
Omegadict_checkpoint['pretrained_model_path'] = checkpoint
|
278 |
+
|
279 |
+
if 'logdir' not in Omegadict_checkpoint:
|
280 |
+
logdir = config.replace('config', 'result').replace('.yml', '').replace('.yaml', '')
|
281 |
+
logdir += f"/{os.path.basename(checkpoint)}"
|
282 |
+
|
283 |
+
Omegadict_checkpoint['logdir'] = logdir
|
284 |
+
print(f'Saving at {logdir}')
|
285 |
+
|
286 |
+
test(config=config, **Omegadict_checkpoint)
|
287 |
+
|
288 |
+
|
289 |
+
if __name__ == "__main__":
|
290 |
+
run('FateZero/config/teaser/jeep_watercolor.yaml')
|
FateZero/test_fatezero_dataset.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
from test_fatezero import *
|
4 |
+
from glob import glob
|
5 |
+
import copy
|
6 |
+
|
7 |
+
@click.command()
|
8 |
+
@click.option("--edit_config", type=str, default="config/supp/style/0313_style_edit_warp_640.yaml")
|
9 |
+
@click.option("--dataset_config", type=str, default="data/supp_edit_dataset/dataset_prompt.yaml")
|
10 |
+
def run(edit_config, dataset_config):
|
11 |
+
Omegadict_edit_config = OmegaConf.load(edit_config)
|
12 |
+
Omegadict_dataset_config = OmegaConf.load(dataset_config)
|
13 |
+
|
14 |
+
# Go trough all data sample
|
15 |
+
data_sample_list = sorted(Omegadict_dataset_config.keys())
|
16 |
+
print(f'Datasample to evaluate: {data_sample_list}')
|
17 |
+
dataset_time_string = get_time_string()
|
18 |
+
for data_sample in data_sample_list:
|
19 |
+
print(f'Evaluate {data_sample}')
|
20 |
+
|
21 |
+
for p2p_config_index, p2p_config in Omegadict_edit_config['validation_sample_logger_config']['p2p_config'].items():
|
22 |
+
edit_config_now = copy.deepcopy(Omegadict_edit_config)
|
23 |
+
edit_config_now['train_dataset'] = copy.deepcopy(Omegadict_dataset_config[data_sample])
|
24 |
+
edit_config_now['train_dataset'].pop('target')
|
25 |
+
if 'eq_params' in edit_config_now['train_dataset']:
|
26 |
+
edit_config_now['train_dataset'].pop('eq_params')
|
27 |
+
# edit_config_now['train_dataset']['prompt'] = Omegadict_dataset_config[data_sample]['source']
|
28 |
+
|
29 |
+
edit_config_now['validation_sample_logger_config']['prompts'] \
|
30 |
+
= copy.deepcopy( [Omegadict_dataset_config[data_sample]['prompt'],]+ OmegaConf.to_object(Omegadict_dataset_config[data_sample]['target']))
|
31 |
+
p2p_config_now = dict()
|
32 |
+
for i in range(len(edit_config_now['validation_sample_logger_config']['prompts'])):
|
33 |
+
p2p_config_now[i] = p2p_config
|
34 |
+
if 'eq_params' in Omegadict_dataset_config[data_sample]:
|
35 |
+
p2p_config_now[i]['eq_params'] = Omegadict_dataset_config[data_sample]['eq_params']
|
36 |
+
|
37 |
+
edit_config_now['validation_sample_logger_config']['p2p_config'] = copy.deepcopy(p2p_config_now)
|
38 |
+
edit_config_now['validation_sample_logger_config']['source_prompt'] = Omegadict_dataset_config[data_sample]['prompt']
|
39 |
+
# edit_config_now['validation_sample_logger_config']['source_prompt'] = Omegadict_dataset_config[data_sample]['eq_params']
|
40 |
+
|
41 |
+
|
42 |
+
# if 'logdir' not in edit_config_now:
|
43 |
+
logdir = edit_config.replace('config', 'result').replace('.yml', '').replace('.yaml', '')+f'_config_{p2p_config_index}'+f'_{os.path.basename(dataset_config)[:-5]}'+f'_{dataset_time_string}'
|
44 |
+
logdir += f"/{data_sample}"
|
45 |
+
edit_config_now['logdir'] = logdir
|
46 |
+
print(f'Saving at {logdir}')
|
47 |
+
|
48 |
+
test(config=edit_config, **edit_config_now)
|
49 |
+
|
50 |
+
|
51 |
+
if __name__ == "__main__":
|
52 |
+
run()
|
FateZero/test_install.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
|
4 |
+
import sys
|
5 |
+
print(f"python version {sys.version}")
|
6 |
+
print(f"torch version {torch.__version__}")
|
7 |
+
print(f"validate gpu status:")
|
8 |
+
print( torch.tensor(1.0).cuda()*2)
|
9 |
+
os.system("nvcc --version")
|
10 |
+
|
11 |
+
import diffusers
|
12 |
+
print(diffusers.__version__)
|
13 |
+
print(diffusers.__file__)
|
14 |
+
|
15 |
+
try:
|
16 |
+
import bitsandbytes
|
17 |
+
print(bitsandbytes.__file__)
|
18 |
+
except:
|
19 |
+
print("fail to import bitsandbytes")
|
20 |
+
|
21 |
+
os.system("accelerate env")
|
22 |
+
|
23 |
+
os.system("python -m xformers.info")
|
FateZero/train_tune_a_video.py
ADDED
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os,copy
|
2 |
+
import inspect
|
3 |
+
from typing import Optional, List, Dict, Union
|
4 |
+
import PIL
|
5 |
+
import click
|
6 |
+
from omegaconf import OmegaConf
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.utils.data
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import torch.utils.checkpoint
|
12 |
+
|
13 |
+
from accelerate import Accelerator
|
14 |
+
from accelerate.utils import set_seed
|
15 |
+
from diffusers import (
|
16 |
+
AutoencoderKL,
|
17 |
+
DDPMScheduler,
|
18 |
+
DDIMScheduler,
|
19 |
+
UNet2DConditionModel,
|
20 |
+
)
|
21 |
+
from diffusers.optimization import get_scheduler
|
22 |
+
from diffusers.utils.import_utils import is_xformers_available
|
23 |
+
from diffusers.pipeline_utils import DiffusionPipeline
|
24 |
+
|
25 |
+
from tqdm.auto import tqdm
|
26 |
+
from transformers import AutoTokenizer, CLIPTextModel
|
27 |
+
from einops import rearrange
|
28 |
+
|
29 |
+
from video_diffusion.models.unet_3d_condition import UNetPseudo3DConditionModel
|
30 |
+
from video_diffusion.data.dataset import ImageSequenceDataset
|
31 |
+
from video_diffusion.common.util import get_time_string, get_function_args
|
32 |
+
from video_diffusion.common.logger import get_logger_config_path
|
33 |
+
from video_diffusion.common.image_util import log_train_samples, log_train_reg_samples
|
34 |
+
from video_diffusion.common.instantiate_from_config import instantiate_from_config, get_obj_from_str
|
35 |
+
from video_diffusion.pipelines.validation_loop import SampleLogger
|
36 |
+
|
37 |
+
|
38 |
+
def collate_fn(examples):
|
39 |
+
batch = {
|
40 |
+
"prompt_ids": torch.cat([example["prompt_ids"] for example in examples], dim=0),
|
41 |
+
"images": torch.stack([example["images"] for example in examples]),
|
42 |
+
|
43 |
+
}
|
44 |
+
if "class_images" in examples[0]:
|
45 |
+
batch["class_prompt_ids"] = torch.cat([example["class_prompt_ids"] for example in examples], dim=0)
|
46 |
+
batch["class_images"] = torch.stack([example["class_images"] for example in examples])
|
47 |
+
return batch
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
def train(
|
52 |
+
config: str,
|
53 |
+
pretrained_model_path: str,
|
54 |
+
train_dataset: Dict,
|
55 |
+
logdir: str = None,
|
56 |
+
train_steps: int = 300,
|
57 |
+
validation_steps: int = 1000,
|
58 |
+
validation_sample_logger_config: Optional[Dict] = None,
|
59 |
+
test_pipeline_config: Optional[Dict] = dict(),
|
60 |
+
trainer_pipeline_config: Optional[Dict] = dict(),
|
61 |
+
gradient_accumulation_steps: int = 1,
|
62 |
+
seed: Optional[int] = None,
|
63 |
+
mixed_precision: Optional[str] = "fp16",
|
64 |
+
enable_xformers: bool = True,
|
65 |
+
train_batch_size: int = 1,
|
66 |
+
learning_rate: float = 3e-5,
|
67 |
+
scale_lr: bool = False,
|
68 |
+
lr_scheduler: str = "constant", # ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
|
69 |
+
lr_warmup_steps: int = 0,
|
70 |
+
use_8bit_adam: bool = True,
|
71 |
+
adam_beta1: float = 0.9,
|
72 |
+
adam_beta2: float = 0.999,
|
73 |
+
adam_weight_decay: float = 1e-2,
|
74 |
+
adam_epsilon: float = 1e-08,
|
75 |
+
max_grad_norm: float = 1.0,
|
76 |
+
gradient_checkpointing: bool = False,
|
77 |
+
train_temporal_conv: bool = False,
|
78 |
+
checkpointing_steps: int = 1000,
|
79 |
+
model_config: dict={},
|
80 |
+
# use_train_latents: bool=False,
|
81 |
+
# kwr
|
82 |
+
# **kwargs
|
83 |
+
):
|
84 |
+
args = get_function_args()
|
85 |
+
# args.update(kwargs)
|
86 |
+
train_dataset_config = copy.deepcopy(train_dataset)
|
87 |
+
time_string = get_time_string()
|
88 |
+
if logdir is None:
|
89 |
+
logdir = config.replace('config', 'result').replace('.yml', '').replace('.yaml', '')
|
90 |
+
logdir += f"_{time_string}"
|
91 |
+
|
92 |
+
accelerator = Accelerator(
|
93 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
94 |
+
mixed_precision=mixed_precision,
|
95 |
+
)
|
96 |
+
if accelerator.is_main_process:
|
97 |
+
os.makedirs(logdir, exist_ok=True)
|
98 |
+
OmegaConf.save(args, os.path.join(logdir, "config.yml"))
|
99 |
+
logger = get_logger_config_path(logdir)
|
100 |
+
if seed is not None:
|
101 |
+
set_seed(seed)
|
102 |
+
|
103 |
+
# Load the tokenizer
|
104 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
105 |
+
pretrained_model_path,
|
106 |
+
subfolder="tokenizer",
|
107 |
+
use_fast=False,
|
108 |
+
)
|
109 |
+
|
110 |
+
# Load models and create wrapper for stable diffusion
|
111 |
+
text_encoder = CLIPTextModel.from_pretrained(
|
112 |
+
pretrained_model_path,
|
113 |
+
subfolder="text_encoder",
|
114 |
+
)
|
115 |
+
|
116 |
+
vae = AutoencoderKL.from_pretrained(
|
117 |
+
pretrained_model_path,
|
118 |
+
subfolder="vae",
|
119 |
+
)
|
120 |
+
|
121 |
+
unet = UNetPseudo3DConditionModel.from_2d_model(
|
122 |
+
os.path.join(pretrained_model_path, "unet"), model_config=model_config
|
123 |
+
)
|
124 |
+
|
125 |
+
|
126 |
+
if 'target' not in test_pipeline_config:
|
127 |
+
test_pipeline_config['target'] = 'video_diffusion.pipelines.stable_diffusion.SpatioTemporalStableDiffusionPipeline'
|
128 |
+
|
129 |
+
pipeline = instantiate_from_config(
|
130 |
+
test_pipeline_config,
|
131 |
+
vae=vae,
|
132 |
+
text_encoder=text_encoder,
|
133 |
+
tokenizer=tokenizer,
|
134 |
+
unet=unet,
|
135 |
+
scheduler=DDIMScheduler.from_pretrained(
|
136 |
+
pretrained_model_path,
|
137 |
+
subfolder="scheduler",
|
138 |
+
),
|
139 |
+
)
|
140 |
+
pipeline.scheduler.set_timesteps(validation_sample_logger_config['num_inference_steps'])
|
141 |
+
pipeline.set_progress_bar_config(disable=True)
|
142 |
+
|
143 |
+
|
144 |
+
if is_xformers_available() and enable_xformers:
|
145 |
+
# if False: # Disable xformers for null inversion
|
146 |
+
try:
|
147 |
+
pipeline.enable_xformers_memory_efficient_attention()
|
148 |
+
print('enable xformers in the training and testing')
|
149 |
+
except Exception as e:
|
150 |
+
logger.warning(
|
151 |
+
"Could not enable memory efficient attention. Make sure xformers is installed"
|
152 |
+
f" correctly and a GPU is available: {e}"
|
153 |
+
)
|
154 |
+
|
155 |
+
vae.requires_grad_(False)
|
156 |
+
unet.requires_grad_(False)
|
157 |
+
text_encoder.requires_grad_(False)
|
158 |
+
|
159 |
+
# Start of config trainable parameters in Unet and optimizer
|
160 |
+
trainable_modules = ("attn_temporal", ".to_q")
|
161 |
+
if train_temporal_conv:
|
162 |
+
trainable_modules += ("conv_temporal",)
|
163 |
+
for name, module in unet.named_modules():
|
164 |
+
if name.endswith(trainable_modules):
|
165 |
+
for params in module.parameters():
|
166 |
+
params.requires_grad = True
|
167 |
+
|
168 |
+
|
169 |
+
if gradient_checkpointing:
|
170 |
+
print('enable gradient checkpointing in the training and testing')
|
171 |
+
unet.enable_gradient_checkpointing()
|
172 |
+
|
173 |
+
if scale_lr:
|
174 |
+
learning_rate = (
|
175 |
+
learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
|
176 |
+
)
|
177 |
+
|
178 |
+
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
179 |
+
if use_8bit_adam:
|
180 |
+
try:
|
181 |
+
import bitsandbytes as bnb
|
182 |
+
except ImportError:
|
183 |
+
raise ImportError(
|
184 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
185 |
+
)
|
186 |
+
|
187 |
+
optimizer_class = bnb.optim.AdamW8bit
|
188 |
+
else:
|
189 |
+
optimizer_class = torch.optim.AdamW
|
190 |
+
|
191 |
+
params_to_optimize = unet.parameters()
|
192 |
+
num_trainable_modules = 0
|
193 |
+
num_trainable_params = 0
|
194 |
+
num_unet_params = 0
|
195 |
+
for params in params_to_optimize:
|
196 |
+
num_unet_params += params.numel()
|
197 |
+
if params.requires_grad == True:
|
198 |
+
num_trainable_modules +=1
|
199 |
+
num_trainable_params += params.numel()
|
200 |
+
|
201 |
+
logger.info(f"Num of trainable modules: {num_trainable_modules}")
|
202 |
+
logger.info(f"Num of trainable params: {num_trainable_params/(1024*1024):.2f} M")
|
203 |
+
logger.info(f"Num of unet params: {num_unet_params/(1024*1024):.2f} M ")
|
204 |
+
|
205 |
+
|
206 |
+
params_to_optimize = unet.parameters()
|
207 |
+
optimizer = optimizer_class(
|
208 |
+
params_to_optimize,
|
209 |
+
lr=learning_rate,
|
210 |
+
betas=(adam_beta1, adam_beta2),
|
211 |
+
weight_decay=adam_weight_decay,
|
212 |
+
eps=adam_epsilon,
|
213 |
+
)
|
214 |
+
# End of config trainable parameters in Unet and optimizer
|
215 |
+
|
216 |
+
|
217 |
+
prompt_ids = tokenizer(
|
218 |
+
train_dataset["prompt"],
|
219 |
+
truncation=True,
|
220 |
+
padding="max_length",
|
221 |
+
max_length=tokenizer.model_max_length,
|
222 |
+
return_tensors="pt",
|
223 |
+
).input_ids
|
224 |
+
|
225 |
+
if 'class_data_root' in train_dataset_config:
|
226 |
+
if 'class_data_prompt' not in train_dataset_config:
|
227 |
+
train_dataset_config['class_data_prompt'] = train_dataset_config['prompt']
|
228 |
+
class_prompt_ids = tokenizer(
|
229 |
+
train_dataset_config["class_data_prompt"],
|
230 |
+
truncation=True,
|
231 |
+
padding="max_length",
|
232 |
+
max_length=tokenizer.model_max_length,
|
233 |
+
return_tensors="pt",
|
234 |
+
).input_ids
|
235 |
+
else:
|
236 |
+
class_prompt_ids = None
|
237 |
+
train_dataset = ImageSequenceDataset(**train_dataset, prompt_ids=prompt_ids, class_prompt_ids=class_prompt_ids)
|
238 |
+
|
239 |
+
train_dataloader = torch.utils.data.DataLoader(
|
240 |
+
train_dataset,
|
241 |
+
batch_size=train_batch_size,
|
242 |
+
shuffle=True,
|
243 |
+
num_workers=16,
|
244 |
+
collate_fn=collate_fn,
|
245 |
+
)
|
246 |
+
|
247 |
+
train_sample_save_path = os.path.join(logdir, "train_samples.gif")
|
248 |
+
log_train_samples(save_path=train_sample_save_path, train_dataloader=train_dataloader)
|
249 |
+
if 'class_data_root' in train_dataset_config:
|
250 |
+
log_train_reg_samples(save_path=train_sample_save_path.replace('train_samples', 'class_data_samples'), train_dataloader=train_dataloader)
|
251 |
+
|
252 |
+
# Prepare learning rate scheduler in accelerate config
|
253 |
+
lr_scheduler = get_scheduler(
|
254 |
+
lr_scheduler,
|
255 |
+
optimizer=optimizer,
|
256 |
+
num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
|
257 |
+
num_training_steps=train_steps * gradient_accumulation_steps,
|
258 |
+
)
|
259 |
+
|
260 |
+
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
261 |
+
unet, optimizer, train_dataloader, lr_scheduler
|
262 |
+
)
|
263 |
+
accelerator.register_for_checkpointing(lr_scheduler)
|
264 |
+
|
265 |
+
weight_dtype = torch.float32
|
266 |
+
if accelerator.mixed_precision == "fp16":
|
267 |
+
weight_dtype = torch.float16
|
268 |
+
print('enable float16 in the training and testing')
|
269 |
+
elif accelerator.mixed_precision == "bf16":
|
270 |
+
weight_dtype = torch.bfloat16
|
271 |
+
|
272 |
+
# Move text_encode and vae to gpu.
|
273 |
+
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
274 |
+
# as these models are only used for inference, keeping weights in full precision is not required.
|
275 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
276 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
277 |
+
|
278 |
+
|
279 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
280 |
+
# The trackers initializes automatically on the main process.
|
281 |
+
if accelerator.is_main_process:
|
282 |
+
accelerator.init_trackers("video") # , config=vars(args))
|
283 |
+
|
284 |
+
# Start of config trainer
|
285 |
+
trainer = instantiate_from_config(
|
286 |
+
trainer_pipeline_config,
|
287 |
+
vae=vae,
|
288 |
+
text_encoder=text_encoder,
|
289 |
+
tokenizer=tokenizer,
|
290 |
+
unet=unet,
|
291 |
+
scheduler= DDPMScheduler.from_pretrained(
|
292 |
+
pretrained_model_path,
|
293 |
+
subfolder="scheduler",
|
294 |
+
),
|
295 |
+
# training hyperparams
|
296 |
+
weight_dtype=weight_dtype,
|
297 |
+
accelerator=accelerator,
|
298 |
+
optimizer=optimizer,
|
299 |
+
max_grad_norm=max_grad_norm,
|
300 |
+
lr_scheduler=lr_scheduler,
|
301 |
+
prior_preservation=None
|
302 |
+
)
|
303 |
+
trainer.print_pipeline(logger)
|
304 |
+
# Train!
|
305 |
+
total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
|
306 |
+
logger.info("***** Running training *****")
|
307 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
308 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
309 |
+
logger.info(f" Instantaneous batch size per device = {train_batch_size}")
|
310 |
+
logger.info(
|
311 |
+
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
|
312 |
+
)
|
313 |
+
logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
|
314 |
+
logger.info(f" Total optimization steps = {train_steps}")
|
315 |
+
step = 0
|
316 |
+
# End of config trainer
|
317 |
+
|
318 |
+
if validation_sample_logger_config is not None and accelerator.is_main_process:
|
319 |
+
validation_sample_logger = SampleLogger(**validation_sample_logger_config, logdir=logdir)
|
320 |
+
|
321 |
+
|
322 |
+
# Only show the progress bar once on each machine.
|
323 |
+
progress_bar = tqdm(
|
324 |
+
range(step, train_steps),
|
325 |
+
disable=not accelerator.is_local_main_process,
|
326 |
+
)
|
327 |
+
progress_bar.set_description("Steps")
|
328 |
+
|
329 |
+
def make_data_yielder(dataloader):
|
330 |
+
while True:
|
331 |
+
for batch in dataloader:
|
332 |
+
yield batch
|
333 |
+
accelerator.wait_for_everyone()
|
334 |
+
|
335 |
+
train_data_yielder = make_data_yielder(train_dataloader)
|
336 |
+
|
337 |
+
|
338 |
+
assert(train_dataset.overfit_length == 1), "Only support overfiting on a single video"
|
339 |
+
# batch = next(train_data_yielder)
|
340 |
+
|
341 |
+
|
342 |
+
while step < train_steps:
|
343 |
+
batch = next(train_data_yielder)
|
344 |
+
"""************************* start of an iteration*******************************"""
|
345 |
+
loss = trainer.step(batch)
|
346 |
+
# torch.cuda.empty_cache()
|
347 |
+
|
348 |
+
"""************************* end of an iteration*******************************"""
|
349 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
350 |
+
if accelerator.sync_gradients:
|
351 |
+
progress_bar.update(1)
|
352 |
+
step += 1
|
353 |
+
|
354 |
+
if accelerator.is_main_process:
|
355 |
+
|
356 |
+
if validation_sample_logger is not None and (step % validation_steps == 0):
|
357 |
+
unet.eval()
|
358 |
+
|
359 |
+
val_image = rearrange(batch["images"].to(dtype=weight_dtype), "b c f h w -> (b f) c h w")
|
360 |
+
|
361 |
+
# Unet is changing in different iteration; we should invert online
|
362 |
+
if validation_sample_logger_config.get('use_train_latents', False):
|
363 |
+
# Precompute the latents for this video to align the initial latents in training and test
|
364 |
+
assert batch["images"].shape[0] == 1, "Only support, overfiting on a single video"
|
365 |
+
# we only inference for latents, no training
|
366 |
+
vae.eval()
|
367 |
+
text_encoder.eval()
|
368 |
+
unet.eval()
|
369 |
+
|
370 |
+
text_embeddings = pipeline._encode_prompt(
|
371 |
+
train_dataset.prompt,
|
372 |
+
device = accelerator.device,
|
373 |
+
num_images_per_prompt = 1,
|
374 |
+
do_classifier_free_guidance = True,
|
375 |
+
negative_prompt=None
|
376 |
+
)
|
377 |
+
batch['latents_all_step'] = pipeline.prepare_latents_ddim_inverted(
|
378 |
+
rearrange(batch["images"].to(dtype=weight_dtype), "b c f h w -> (b f) c h w"),
|
379 |
+
batch_size = 1 ,
|
380 |
+
num_images_per_prompt = 1, # not sure how to use it
|
381 |
+
text_embeddings = text_embeddings
|
382 |
+
)
|
383 |
+
batch['ddim_init_latents'] = batch['latents_all_step'][-1]
|
384 |
+
else:
|
385 |
+
batch['ddim_init_latents'] = None
|
386 |
+
|
387 |
+
|
388 |
+
|
389 |
+
validation_sample_logger.log_sample_images(
|
390 |
+
# image=rearrange(train_dataset.get_all()["images"].to(accelerator.device, dtype=weight_dtype), "c f h w -> f c h w"), # torch.Size([8, 3, 512, 512])
|
391 |
+
image= val_image, # torch.Size([8, 3, 512, 512])
|
392 |
+
pipeline=pipeline,
|
393 |
+
device=accelerator.device,
|
394 |
+
step=step,
|
395 |
+
latents = batch['ddim_init_latents'],
|
396 |
+
)
|
397 |
+
torch.cuda.empty_cache()
|
398 |
+
unet.train()
|
399 |
+
|
400 |
+
if step % checkpointing_steps == 0:
|
401 |
+
accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
|
402 |
+
inspect.signature(accelerator.unwrap_model).parameters.keys()
|
403 |
+
)
|
404 |
+
extra_args = {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {}
|
405 |
+
pipeline_save = get_obj_from_str(test_pipeline_config["target"]).from_pretrained(
|
406 |
+
pretrained_model_path,
|
407 |
+
unet=accelerator.unwrap_model(unet, **extra_args),
|
408 |
+
)
|
409 |
+
checkpoint_save_path = os.path.join(logdir, f"checkpoint_{step}")
|
410 |
+
pipeline_save.save_pretrained(checkpoint_save_path)
|
411 |
+
|
412 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
413 |
+
progress_bar.set_postfix(**logs)
|
414 |
+
accelerator.log(logs, step=step)
|
415 |
+
|
416 |
+
accelerator.end_training()
|
417 |
+
|
418 |
+
|
419 |
+
@click.command()
|
420 |
+
@click.option("--config", type=str, default="config/sample.yml")
|
421 |
+
def run(config):
|
422 |
+
train(config=config, **OmegaConf.load(config))
|
423 |
+
|
424 |
+
|
425 |
+
if __name__ == "__main__":
|
426 |
+
run()
|
FateZero/video_diffusion/common/image_util.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
import textwrap
|
4 |
+
|
5 |
+
import imageio
|
6 |
+
import numpy as np
|
7 |
+
from typing import Sequence
|
8 |
+
import requests
|
9 |
+
import cv2
|
10 |
+
from PIL import Image, ImageDraw, ImageFont
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torchvision import transforms
|
14 |
+
from einops import rearrange
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
IMAGE_EXTENSION = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
|
22 |
+
|
23 |
+
FONT_URL = "https://raw.github.com/googlefonts/opensans/main/fonts/ttf/OpenSans-Regular.ttf"
|
24 |
+
FONT_PATH = "./docs/OpenSans-Regular.ttf"
|
25 |
+
|
26 |
+
|
27 |
+
def pad(image: Image.Image, top=0, right=0, bottom=0, left=0, color=(255, 255, 255)) -> Image.Image:
|
28 |
+
new_image = Image.new(image.mode, (image.width + right + left, image.height + top + bottom), color)
|
29 |
+
new_image.paste(image, (left, top))
|
30 |
+
return new_image
|
31 |
+
|
32 |
+
|
33 |
+
def download_font_opensans(path=FONT_PATH):
|
34 |
+
font_url = FONT_URL
|
35 |
+
response = requests.get(font_url)
|
36 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
37 |
+
with open(path, "wb") as f:
|
38 |
+
f.write(response.content)
|
39 |
+
|
40 |
+
|
41 |
+
def annotate_image_with_font(image: Image.Image, text: str, font: ImageFont.FreeTypeFont) -> Image.Image:
|
42 |
+
image_w = image.width
|
43 |
+
_, _, text_w, text_h = font.getbbox(text)
|
44 |
+
line_size = math.floor(len(text) * image_w / text_w)
|
45 |
+
|
46 |
+
lines = textwrap.wrap(text, width=line_size)
|
47 |
+
padding = text_h * len(lines)
|
48 |
+
image = pad(image, top=padding + 3)
|
49 |
+
|
50 |
+
ImageDraw.Draw(image).text((0, 0), "\n".join(lines), fill=(0, 0, 0), font=font)
|
51 |
+
return image
|
52 |
+
|
53 |
+
|
54 |
+
def annotate_image(image: Image.Image, text: str, font_size: int = 15):
|
55 |
+
if not os.path.isfile(FONT_PATH):
|
56 |
+
download_font_opensans()
|
57 |
+
font = ImageFont.truetype(FONT_PATH, size=font_size)
|
58 |
+
return annotate_image_with_font(image=image, text=text, font=font)
|
59 |
+
|
60 |
+
|
61 |
+
def make_grid(images: Sequence[Image.Image], rows=None, cols=None) -> Image.Image:
|
62 |
+
if isinstance(images[0], np.ndarray):
|
63 |
+
images = [Image.fromarray(i) for i in images]
|
64 |
+
|
65 |
+
if rows is None:
|
66 |
+
assert cols is not None
|
67 |
+
rows = math.ceil(len(images) / cols)
|
68 |
+
else:
|
69 |
+
cols = math.ceil(len(images) / rows)
|
70 |
+
|
71 |
+
w, h = images[0].size
|
72 |
+
grid = Image.new("RGB", size=(cols * w, rows * h))
|
73 |
+
for i, image in enumerate(images):
|
74 |
+
if image.size != (w, h):
|
75 |
+
image = image.resize((w, h))
|
76 |
+
grid.paste(image, box=(i % cols * w, i // cols * h))
|
77 |
+
return grid
|
78 |
+
|
79 |
+
|
80 |
+
def save_images_as_gif(
|
81 |
+
images: Sequence[Image.Image],
|
82 |
+
save_path: str,
|
83 |
+
loop=0,
|
84 |
+
duration=100,
|
85 |
+
optimize=False,
|
86 |
+
) -> None:
|
87 |
+
|
88 |
+
images[0].save(
|
89 |
+
save_path,
|
90 |
+
save_all=True,
|
91 |
+
append_images=images[1:],
|
92 |
+
optimize=optimize,
|
93 |
+
loop=loop,
|
94 |
+
duration=duration,
|
95 |
+
)
|
96 |
+
|
97 |
+
def save_images_as_mp4(
|
98 |
+
images: Sequence[Image.Image],
|
99 |
+
save_path: str,
|
100 |
+
) -> None:
|
101 |
+
# images[0].save(
|
102 |
+
# save_path,
|
103 |
+
# save_all=True,
|
104 |
+
# append_images=images[1:],
|
105 |
+
# optimize=optimize,
|
106 |
+
# loop=loop,
|
107 |
+
# duration=duration,
|
108 |
+
# )
|
109 |
+
writer_edit = imageio.get_writer(
|
110 |
+
save_path,
|
111 |
+
fps=10)
|
112 |
+
for i in images:
|
113 |
+
init_image = i.convert("RGB")
|
114 |
+
writer_edit.append_data(np.array(init_image))
|
115 |
+
writer_edit.close()
|
116 |
+
|
117 |
+
|
118 |
+
|
119 |
+
def save_images_as_folder(
|
120 |
+
images: Sequence[Image.Image],
|
121 |
+
save_path: str,
|
122 |
+
) -> None:
|
123 |
+
os.makedirs(save_path, exist_ok=True)
|
124 |
+
for index, image in enumerate(images):
|
125 |
+
init_image = image
|
126 |
+
if len(np.array(init_image).shape) == 3:
|
127 |
+
cv2.imwrite(os.path.join(save_path, f"{index:05d}.png"), np.array(init_image)[:, :, ::-1])
|
128 |
+
else:
|
129 |
+
cv2.imwrite(os.path.join(save_path, f"{index:05d}.png"), np.array(init_image))
|
130 |
+
|
131 |
+
def log_train_samples(
|
132 |
+
train_dataloader,
|
133 |
+
save_path,
|
134 |
+
num_batch: int = 4,
|
135 |
+
):
|
136 |
+
train_samples = []
|
137 |
+
for idx, batch in enumerate(train_dataloader):
|
138 |
+
if idx >= num_batch:
|
139 |
+
break
|
140 |
+
train_samples.append(batch["images"])
|
141 |
+
|
142 |
+
train_samples = torch.cat(train_samples).numpy()
|
143 |
+
train_samples = rearrange(train_samples, "b c f h w -> b f h w c")
|
144 |
+
train_samples = (train_samples * 0.5 + 0.5).clip(0, 1)
|
145 |
+
train_samples = numpy_batch_seq_to_pil(train_samples)
|
146 |
+
train_samples = [make_grid(images, cols=int(np.ceil(np.sqrt(len(train_samples))))) for images in zip(*train_samples)]
|
147 |
+
# save_images_as_gif(train_samples, save_path)
|
148 |
+
save_gif_mp4_folder_type(train_samples, save_path)
|
149 |
+
|
150 |
+
def log_train_reg_samples(
|
151 |
+
train_dataloader,
|
152 |
+
save_path,
|
153 |
+
num_batch: int = 4,
|
154 |
+
):
|
155 |
+
train_samples = []
|
156 |
+
for idx, batch in enumerate(train_dataloader):
|
157 |
+
if idx >= num_batch:
|
158 |
+
break
|
159 |
+
train_samples.append(batch["class_images"])
|
160 |
+
|
161 |
+
train_samples = torch.cat(train_samples).numpy()
|
162 |
+
train_samples = rearrange(train_samples, "b c f h w -> b f h w c")
|
163 |
+
train_samples = (train_samples * 0.5 + 0.5).clip(0, 1)
|
164 |
+
train_samples = numpy_batch_seq_to_pil(train_samples)
|
165 |
+
train_samples = [make_grid(images, cols=int(np.ceil(np.sqrt(len(train_samples))))) for images in zip(*train_samples)]
|
166 |
+
# save_images_as_gif(train_samples, save_path)
|
167 |
+
save_gif_mp4_folder_type(train_samples, save_path)
|
168 |
+
|
169 |
+
|
170 |
+
def save_gif_mp4_folder_type(images, save_path, save_gif=False):
|
171 |
+
|
172 |
+
if isinstance(images[0], np.ndarray):
|
173 |
+
images = [Image.fromarray(i) for i in images]
|
174 |
+
elif isinstance(images[0], torch.Tensor):
|
175 |
+
images = [transforms.ToPILImage()(i.cpu().clone()[0]) for i in images]
|
176 |
+
save_path_mp4 = save_path.replace('gif', 'mp4')
|
177 |
+
save_path_folder = save_path.replace('.gif', '')
|
178 |
+
if save_gif: save_images_as_gif(images, save_path)
|
179 |
+
save_images_as_mp4(images, save_path_mp4)
|
180 |
+
save_images_as_folder(images, save_path_folder)
|
181 |
+
|
182 |
+
# copy from video_diffusion/pipelines/stable_diffusion.py
|
183 |
+
def numpy_seq_to_pil(images):
|
184 |
+
"""
|
185 |
+
Convert a numpy image or a batch of images to a PIL image.
|
186 |
+
"""
|
187 |
+
if images.ndim == 3:
|
188 |
+
images = images[None, ...]
|
189 |
+
images = (images * 255).round().astype("uint8")
|
190 |
+
if images.shape[-1] == 1:
|
191 |
+
# special case for grayscale (single channel) images
|
192 |
+
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
|
193 |
+
else:
|
194 |
+
pil_images = [Image.fromarray(image) for image in images]
|
195 |
+
|
196 |
+
return pil_images
|
197 |
+
|
198 |
+
# copy from diffusers-0.11.1/src/diffusers/pipeline_utils.py
|
199 |
+
def numpy_batch_seq_to_pil(images):
|
200 |
+
pil_images = []
|
201 |
+
for sequence in images:
|
202 |
+
pil_images.append(numpy_seq_to_pil(sequence))
|
203 |
+
return pil_images
|
FateZero/video_diffusion/common/instantiate_from_config.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copy from stable diffusion
|
3 |
+
"""
|
4 |
+
import importlib
|
5 |
+
|
6 |
+
|
7 |
+
def instantiate_from_config(config:dict, **args_from_code):
|
8 |
+
"""Util funciton to decompose differenct modules using config
|
9 |
+
|
10 |
+
Args:
|
11 |
+
config (dict): with key of "target" and "params", better from yaml
|
12 |
+
static
|
13 |
+
args_from_code: additional con
|
14 |
+
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
a validation/training pipeline, a module
|
18 |
+
"""
|
19 |
+
if not "target" in config:
|
20 |
+
if config == '__is_first_stage__':
|
21 |
+
return None
|
22 |
+
elif config == "__is_unconditional__":
|
23 |
+
return None
|
24 |
+
raise KeyError("Expected key `target` to instantiate.")
|
25 |
+
return get_obj_from_str(config["target"])(**config.get("params", dict()), **args_from_code)
|
26 |
+
|
27 |
+
|
28 |
+
def get_obj_from_str(string, reload=False):
|
29 |
+
module, cls = string.rsplit(".", 1)
|
30 |
+
if reload:
|
31 |
+
module_imp = importlib.import_module(module)
|
32 |
+
importlib.reload(module_imp)
|
33 |
+
return getattr(importlib.import_module(module, package=None), cls)
|
FateZero/video_diffusion/common/logger.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging, logging.handlers
|
3 |
+
from accelerate.logging import get_logger
|
4 |
+
|
5 |
+
def get_logger_config_path(logdir):
|
6 |
+
# accelerate handles the logger in multiprocessing
|
7 |
+
logger = get_logger(__name__)
|
8 |
+
logging.basicConfig(
|
9 |
+
level=logging.INFO,
|
10 |
+
format='%(asctime)s:%(levelname)s : %(message)s',
|
11 |
+
datefmt='%a, %d %b %Y %H:%M:%S',
|
12 |
+
filename=os.path.join(logdir, 'log.log'),
|
13 |
+
filemode='w')
|
14 |
+
chlr = logging.StreamHandler()
|
15 |
+
chlr.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s : %(message)s'))
|
16 |
+
logger.logger.addHandler(chlr)
|
17 |
+
return logger
|
FateZero/video_diffusion/common/set_seed.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import random
|
7 |
+
|
8 |
+
from accelerate.utils import set_seed
|
9 |
+
|
10 |
+
|
11 |
+
def video_set_seed(seed: int):
|
12 |
+
"""
|
13 |
+
Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
seed (`int`): The seed to set.
|
17 |
+
device_specific (`bool`, *optional*, defaults to `False`):
|
18 |
+
Whether to differ the seed on each device slightly with `self.process_index`.
|
19 |
+
"""
|
20 |
+
set_seed(seed)
|
21 |
+
random.seed(seed)
|
22 |
+
np.random.seed(seed)
|
23 |
+
torch.manual_seed(seed)
|
24 |
+
torch.cuda.manual_seed_all(seed)
|
25 |
+
torch.backends.cudnn.benchmark = False
|
26 |
+
# torch.use_deterministic_algorithms(True, warn_only=True)
|
27 |
+
# [W Context.cpp:82] Warning: efficient_attention_forward_cutlass does not have a deterministic implementation, but you set 'torch.use_deterministic_algorithms(True, warn_only=True)'. You can file an issue at https://github.com/pytorch/pytorch/issues to help us prioritize adding deterministic support for this operation. (function alertNotDeterministic)
|
28 |
+
|
FateZero/video_diffusion/common/util.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import copy
|
4 |
+
import inspect
|
5 |
+
import datetime
|
6 |
+
from typing import List, Tuple, Optional, Dict
|
7 |
+
|
8 |
+
|
9 |
+
def glob_files(
|
10 |
+
root_path: str,
|
11 |
+
extensions: Tuple[str],
|
12 |
+
recursive: bool = True,
|
13 |
+
skip_hidden_directories: bool = True,
|
14 |
+
max_directories: Optional[int] = None,
|
15 |
+
max_files: Optional[int] = None,
|
16 |
+
relative_path: bool = False,
|
17 |
+
) -> Tuple[List[str], bool, bool]:
|
18 |
+
"""glob files with specified extensions
|
19 |
+
|
20 |
+
Args:
|
21 |
+
root_path (str): _description_
|
22 |
+
extensions (Tuple[str]): _description_
|
23 |
+
recursive (bool, optional): _description_. Defaults to True.
|
24 |
+
skip_hidden_directories (bool, optional): _description_. Defaults to True.
|
25 |
+
max_directories (Optional[int], optional): max number of directories to search. Defaults to None.
|
26 |
+
max_files (Optional[int], optional): max file number limit. Defaults to None.
|
27 |
+
relative_path (bool, optional): _description_. Defaults to False.
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
Tuple[List[str], bool, bool]: _description_
|
31 |
+
"""
|
32 |
+
paths = []
|
33 |
+
hit_max_directories = False
|
34 |
+
hit_max_files = False
|
35 |
+
for directory_idx, (directory, _, fnames) in enumerate(os.walk(root_path, followlinks=True)):
|
36 |
+
if skip_hidden_directories and os.path.basename(directory).startswith("."):
|
37 |
+
continue
|
38 |
+
|
39 |
+
if max_directories is not None and directory_idx >= max_directories:
|
40 |
+
hit_max_directories = True
|
41 |
+
break
|
42 |
+
|
43 |
+
paths += [
|
44 |
+
os.path.join(directory, fname)
|
45 |
+
for fname in sorted(fnames)
|
46 |
+
if fname.lower().endswith(extensions)
|
47 |
+
]
|
48 |
+
|
49 |
+
if not recursive:
|
50 |
+
break
|
51 |
+
|
52 |
+
if max_files is not None and len(paths) > max_files:
|
53 |
+
hit_max_files = True
|
54 |
+
paths = paths[:max_files]
|
55 |
+
break
|
56 |
+
|
57 |
+
if relative_path:
|
58 |
+
paths = [os.path.relpath(p, root_path) for p in paths]
|
59 |
+
|
60 |
+
return paths, hit_max_directories, hit_max_files
|
61 |
+
|
62 |
+
|
63 |
+
def get_time_string() -> str:
|
64 |
+
x = datetime.datetime.now()
|
65 |
+
return f"{(x.year - 2000):02d}{x.month:02d}{x.day:02d}-{x.hour:02d}{x.minute:02d}{x.second:02d}"
|
66 |
+
|
67 |
+
|
68 |
+
def get_function_args() -> Dict:
|
69 |
+
frame = sys._getframe(1)
|
70 |
+
args, _, _, values = inspect.getargvalues(frame)
|
71 |
+
args_dict = copy.deepcopy({arg: values[arg] for arg in args})
|
72 |
+
|
73 |
+
return args_dict
|
FateZero/video_diffusion/data/dataset.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
from einops import rearrange
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch.utils.data import Dataset
|
10 |
+
|
11 |
+
from .transform import short_size_scale, random_crop, center_crop, offset_crop
|
12 |
+
from ..common.image_util import IMAGE_EXTENSION
|
13 |
+
|
14 |
+
import sys
|
15 |
+
sys.path.append('FateZero')
|
16 |
+
|
17 |
+
class ImageSequenceDataset(Dataset):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
path: str,
|
21 |
+
prompt_ids: torch.Tensor,
|
22 |
+
prompt: str,
|
23 |
+
start_sample_frame: int=0,
|
24 |
+
n_sample_frame: int = 8,
|
25 |
+
sampling_rate: int = 1,
|
26 |
+
stride: int = 1,
|
27 |
+
image_mode: str = "RGB",
|
28 |
+
image_size: int = 512,
|
29 |
+
crop: str = "center",
|
30 |
+
|
31 |
+
class_data_root: str = None,
|
32 |
+
class_prompt_ids: torch.Tensor = None,
|
33 |
+
|
34 |
+
offset: dict = {
|
35 |
+
"left": 0,
|
36 |
+
"right": 0,
|
37 |
+
"top": 0,
|
38 |
+
"bottom": 0
|
39 |
+
}
|
40 |
+
):
|
41 |
+
self.path = path
|
42 |
+
self.images = self.get_image_list(path)
|
43 |
+
self.n_images = len(self.images)
|
44 |
+
self.offset = offset
|
45 |
+
|
46 |
+
if n_sample_frame < 0:
|
47 |
+
n_sample_frame = len(self.images)
|
48 |
+
self.start_sample_frame = start_sample_frame
|
49 |
+
|
50 |
+
self.n_sample_frame = n_sample_frame
|
51 |
+
self.sampling_rate = sampling_rate
|
52 |
+
|
53 |
+
self.sequence_length = (n_sample_frame - 1) * sampling_rate + 1
|
54 |
+
if self.n_images < self.sequence_length:
|
55 |
+
raise ValueError("self.n_images < self.sequence_length")
|
56 |
+
self.stride = stride
|
57 |
+
|
58 |
+
self.image_mode = image_mode
|
59 |
+
self.image_size = image_size
|
60 |
+
crop_methods = {
|
61 |
+
"center": center_crop,
|
62 |
+
"random": random_crop,
|
63 |
+
}
|
64 |
+
if crop not in crop_methods:
|
65 |
+
raise ValueError
|
66 |
+
self.crop = crop_methods[crop]
|
67 |
+
|
68 |
+
self.prompt = prompt
|
69 |
+
self.prompt_ids = prompt_ids
|
70 |
+
self.overfit_length = (self.n_images - self.sequence_length) // self.stride + 1
|
71 |
+
# Negative prompt for regularization
|
72 |
+
if class_data_root is not None:
|
73 |
+
self.class_data_root = Path(class_data_root)
|
74 |
+
self.class_images_path = sorted(list(self.class_data_root.iterdir()))
|
75 |
+
self.num_class_images = len(self.class_images_path)
|
76 |
+
self.class_prompt_ids = class_prompt_ids
|
77 |
+
|
78 |
+
self.video_len = (self.n_images - self.sequence_length) // self.stride + 1
|
79 |
+
|
80 |
+
def __len__(self):
|
81 |
+
max_len = (self.n_images - self.sequence_length) // self.stride + 1
|
82 |
+
|
83 |
+
if hasattr(self, 'num_class_images'):
|
84 |
+
max_len = max(max_len, self.num_class_images)
|
85 |
+
# return (self.n_images - self.sequence_length) // self.stride + 1
|
86 |
+
return max_len
|
87 |
+
|
88 |
+
def __getitem__(self, index):
|
89 |
+
return_batch = {}
|
90 |
+
frame_indices = self.get_frame_indices(index%self.video_len)
|
91 |
+
frames = [self.load_frame(i) for i in frame_indices]
|
92 |
+
frames = self.transform(frames)
|
93 |
+
|
94 |
+
return_batch.update(
|
95 |
+
{
|
96 |
+
"images": frames,
|
97 |
+
"prompt_ids": self.prompt_ids,
|
98 |
+
}
|
99 |
+
)
|
100 |
+
|
101 |
+
if hasattr(self, 'class_data_root'):
|
102 |
+
class_index = index % (self.num_class_images - self.n_sample_frame)
|
103 |
+
class_indices = self.get_class_indices(class_index)
|
104 |
+
frames = [self.load_class_frame(i) for i in class_indices]
|
105 |
+
return_batch["class_images"] = self.tensorize_frames(frames)
|
106 |
+
return_batch["class_prompt_ids"] = self.class_prompt_ids
|
107 |
+
return return_batch
|
108 |
+
|
109 |
+
def get_all(self, val_length=None):
|
110 |
+
if val_length is None:
|
111 |
+
val_length = len(self.images)
|
112 |
+
frame_indices = (i for i in range(val_length))
|
113 |
+
frames = [self.load_frame(i) for i in frame_indices]
|
114 |
+
frames = self.transform(frames)
|
115 |
+
|
116 |
+
return {
|
117 |
+
"images": frames,
|
118 |
+
"prompt_ids": self.prompt_ids,
|
119 |
+
}
|
120 |
+
|
121 |
+
def transform(self, frames):
|
122 |
+
frames = self.tensorize_frames(frames)
|
123 |
+
frames = offset_crop(frames, **self.offset)
|
124 |
+
frames = short_size_scale(frames, size=self.image_size)
|
125 |
+
frames = self.crop(frames, height=self.image_size, width=self.image_size)
|
126 |
+
return frames
|
127 |
+
|
128 |
+
@staticmethod
|
129 |
+
def tensorize_frames(frames):
|
130 |
+
frames = rearrange(np.stack(frames), "f h w c -> c f h w")
|
131 |
+
return torch.from_numpy(frames).div(255) * 2 - 1
|
132 |
+
|
133 |
+
def load_frame(self, index):
|
134 |
+
image_path = os.path.join(self.path, self.images[index])
|
135 |
+
return Image.open(image_path).convert(self.image_mode)
|
136 |
+
|
137 |
+
def load_class_frame(self, index):
|
138 |
+
image_path = self.class_images_path[index]
|
139 |
+
return Image.open(image_path).convert(self.image_mode)
|
140 |
+
|
141 |
+
def get_frame_indices(self, index):
|
142 |
+
if self.start_sample_frame is not None:
|
143 |
+
frame_start = self.start_sample_frame + self.stride * index
|
144 |
+
else:
|
145 |
+
frame_start = self.stride * index
|
146 |
+
return (frame_start + i * self.sampling_rate for i in range(self.n_sample_frame))
|
147 |
+
|
148 |
+
def get_class_indices(self, index):
|
149 |
+
frame_start = index
|
150 |
+
return (frame_start + i for i in range(self.n_sample_frame))
|
151 |
+
|
152 |
+
@staticmethod
|
153 |
+
def get_image_list(path):
|
154 |
+
images = []
|
155 |
+
for file in sorted(os.listdir(path)):
|
156 |
+
if file.endswith(IMAGE_EXTENSION):
|
157 |
+
images.append(file)
|
158 |
+
return images
|
FateZero/video_diffusion/data/transform.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
def short_size_scale(images, size):
|
7 |
+
h, w = images.shape[-2:]
|
8 |
+
short, long = (h, w) if h < w else (w, h)
|
9 |
+
|
10 |
+
scale = size / short
|
11 |
+
long_target = int(scale * long)
|
12 |
+
|
13 |
+
target_size = (size, long_target) if h < w else (long_target, size)
|
14 |
+
|
15 |
+
return torch.nn.functional.interpolate(
|
16 |
+
input=images, size=target_size, mode="bilinear", antialias=True
|
17 |
+
)
|
18 |
+
|
19 |
+
|
20 |
+
def random_short_side_scale(images, size_min, size_max):
|
21 |
+
size = random.randint(size_min, size_max)
|
22 |
+
return short_size_scale(images, size)
|
23 |
+
|
24 |
+
|
25 |
+
def random_crop(images, height, width):
|
26 |
+
image_h, image_w = images.shape[-2:]
|
27 |
+
h_start = random.randint(0, image_h - height)
|
28 |
+
w_start = random.randint(0, image_w - width)
|
29 |
+
return images[:, :, h_start : h_start + height, w_start : w_start + width]
|
30 |
+
|
31 |
+
|
32 |
+
def center_crop(images, height, width):
|
33 |
+
# offset_crop(images, 0,0, 200, 0)
|
34 |
+
image_h, image_w = images.shape[-2:]
|
35 |
+
h_start = (image_h - height) // 2
|
36 |
+
w_start = (image_w - width) // 2
|
37 |
+
return images[:, :, h_start : h_start + height, w_start : w_start + width]
|
38 |
+
|
39 |
+
def offset_crop(image, left=0, right=0, top=200, bottom=0):
|
40 |
+
|
41 |
+
n, c, h, w = image.shape
|
42 |
+
left = min(left, w-1)
|
43 |
+
right = min(right, w - left - 1)
|
44 |
+
top = min(top, h - 1)
|
45 |
+
bottom = min(bottom, h - top - 1)
|
46 |
+
image = image[:, :, top:h-bottom, left:w-right]
|
47 |
+
|
48 |
+
return image
|
FateZero/video_diffusion/models/attention.py
ADDED
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# code mostly taken from https://github.com/huggingface/diffusers
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
9 |
+
from diffusers.modeling_utils import ModelMixin
|
10 |
+
from diffusers.models.attention import FeedForward, CrossAttention, AdaLayerNorm
|
11 |
+
from diffusers.utils import BaseOutput
|
12 |
+
from diffusers.utils.import_utils import is_xformers_available
|
13 |
+
|
14 |
+
from einops import rearrange
|
15 |
+
|
16 |
+
|
17 |
+
@dataclass
|
18 |
+
class SpatioTemporalTransformerModelOutput(BaseOutput):
|
19 |
+
"""torch.FloatTensor of shape [batch x channel x frames x height x width]"""
|
20 |
+
|
21 |
+
sample: torch.FloatTensor
|
22 |
+
|
23 |
+
|
24 |
+
if is_xformers_available():
|
25 |
+
import xformers
|
26 |
+
import xformers.ops
|
27 |
+
else:
|
28 |
+
xformers = None
|
29 |
+
|
30 |
+
|
31 |
+
class SpatioTemporalTransformerModel(ModelMixin, ConfigMixin):
|
32 |
+
@register_to_config
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
num_attention_heads: int = 16,
|
36 |
+
attention_head_dim: int = 88,
|
37 |
+
in_channels: Optional[int] = None,
|
38 |
+
num_layers: int = 1,
|
39 |
+
dropout: float = 0.0,
|
40 |
+
norm_num_groups: int = 32,
|
41 |
+
cross_attention_dim: Optional[int] = None,
|
42 |
+
attention_bias: bool = False,
|
43 |
+
activation_fn: str = "geglu",
|
44 |
+
num_embeds_ada_norm: Optional[int] = None,
|
45 |
+
use_linear_projection: bool = False,
|
46 |
+
only_cross_attention: bool = False,
|
47 |
+
upcast_attention: bool = False,
|
48 |
+
model_config: dict = {},
|
49 |
+
**transformer_kwargs,
|
50 |
+
):
|
51 |
+
super().__init__()
|
52 |
+
self.use_linear_projection = use_linear_projection
|
53 |
+
self.num_attention_heads = num_attention_heads
|
54 |
+
self.attention_head_dim = attention_head_dim
|
55 |
+
inner_dim = num_attention_heads * attention_head_dim
|
56 |
+
|
57 |
+
# Define input layers
|
58 |
+
self.in_channels = in_channels
|
59 |
+
|
60 |
+
self.norm = torch.nn.GroupNorm(
|
61 |
+
num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
|
62 |
+
)
|
63 |
+
if use_linear_projection:
|
64 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
65 |
+
else:
|
66 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
67 |
+
|
68 |
+
# Define transformers blocks
|
69 |
+
self.transformer_blocks = nn.ModuleList(
|
70 |
+
[
|
71 |
+
SpatioTemporalTransformerBlock(
|
72 |
+
inner_dim,
|
73 |
+
num_attention_heads,
|
74 |
+
attention_head_dim,
|
75 |
+
dropout=dropout,
|
76 |
+
cross_attention_dim=cross_attention_dim,
|
77 |
+
activation_fn=activation_fn,
|
78 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
79 |
+
attention_bias=attention_bias,
|
80 |
+
only_cross_attention=only_cross_attention,
|
81 |
+
upcast_attention=upcast_attention,
|
82 |
+
model_config=model_config,
|
83 |
+
**transformer_kwargs,
|
84 |
+
)
|
85 |
+
for d in range(num_layers)
|
86 |
+
]
|
87 |
+
)
|
88 |
+
|
89 |
+
# Define output layers
|
90 |
+
if use_linear_projection:
|
91 |
+
self.proj_out = nn.Linear(in_channels, inner_dim)
|
92 |
+
else:
|
93 |
+
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
94 |
+
|
95 |
+
def forward(
|
96 |
+
self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True
|
97 |
+
):
|
98 |
+
# 1. Input
|
99 |
+
clip_length = None
|
100 |
+
is_video = hidden_states.ndim == 5
|
101 |
+
if is_video:
|
102 |
+
clip_length = hidden_states.shape[2]
|
103 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
104 |
+
encoder_hidden_states = encoder_hidden_states.repeat_interleave(clip_length, 0)
|
105 |
+
else:
|
106 |
+
# To adapt to classifier-free guidance where encoder_hidden_states=2
|
107 |
+
batch_size = hidden_states.shape[0]//encoder_hidden_states.shape[0]
|
108 |
+
encoder_hidden_states = encoder_hidden_states.repeat_interleave(batch_size, 0)
|
109 |
+
*_, h, w = hidden_states.shape
|
110 |
+
residual = hidden_states
|
111 |
+
|
112 |
+
hidden_states = self.norm(hidden_states)
|
113 |
+
if not self.use_linear_projection:
|
114 |
+
hidden_states = self.proj_in(hidden_states)
|
115 |
+
hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c") # (bf) (hw) c
|
116 |
+
else:
|
117 |
+
hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c")
|
118 |
+
hidden_states = self.proj_in(hidden_states)
|
119 |
+
|
120 |
+
# 2. Blocks
|
121 |
+
for block in self.transformer_blocks:
|
122 |
+
hidden_states = block(
|
123 |
+
hidden_states, # [16, 4096, 320]
|
124 |
+
encoder_hidden_states=encoder_hidden_states, # ([1, 77, 768]
|
125 |
+
timestep=timestep,
|
126 |
+
clip_length=clip_length,
|
127 |
+
)
|
128 |
+
|
129 |
+
# 3. Output
|
130 |
+
if not self.use_linear_projection:
|
131 |
+
hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=h, w=w).contiguous()
|
132 |
+
hidden_states = self.proj_out(hidden_states)
|
133 |
+
else:
|
134 |
+
hidden_states = self.proj_out(hidden_states)
|
135 |
+
hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=h, w=w).contiguous()
|
136 |
+
|
137 |
+
output = hidden_states + residual
|
138 |
+
if is_video:
|
139 |
+
output = rearrange(output, "(b f) c h w -> b c f h w", f=clip_length)
|
140 |
+
|
141 |
+
if not return_dict:
|
142 |
+
return (output,)
|
143 |
+
|
144 |
+
return SpatioTemporalTransformerModelOutput(sample=output)
|
145 |
+
|
146 |
+
import copy
|
147 |
+
class SpatioTemporalTransformerBlock(nn.Module):
|
148 |
+
def __init__(
|
149 |
+
self,
|
150 |
+
dim: int,
|
151 |
+
num_attention_heads: int,
|
152 |
+
attention_head_dim: int,
|
153 |
+
dropout=0.0,
|
154 |
+
cross_attention_dim: Optional[int] = None,
|
155 |
+
activation_fn: str = "geglu",
|
156 |
+
num_embeds_ada_norm: Optional[int] = None,
|
157 |
+
attention_bias: bool = False,
|
158 |
+
only_cross_attention: bool = False,
|
159 |
+
upcast_attention: bool = False,
|
160 |
+
use_sparse_causal_attention: bool = True,
|
161 |
+
temporal_attention_position: str = "after_feedforward",
|
162 |
+
model_config: dict = {}
|
163 |
+
):
|
164 |
+
super().__init__()
|
165 |
+
|
166 |
+
self.only_cross_attention = only_cross_attention
|
167 |
+
self.use_ada_layer_norm = num_embeds_ada_norm is not None
|
168 |
+
self.use_sparse_causal_attention = use_sparse_causal_attention
|
169 |
+
# For safety, freeze the model_config
|
170 |
+
self.model_config = copy.deepcopy(model_config)
|
171 |
+
if 'least_sc_channel' in model_config:
|
172 |
+
if dim< model_config['least_sc_channel']:
|
173 |
+
self.model_config['SparseCausalAttention_index'] = []
|
174 |
+
|
175 |
+
self.temporal_attention_position = temporal_attention_position
|
176 |
+
temporal_attention_positions = ["after_spatial", "after_cross", "after_feedforward"]
|
177 |
+
if temporal_attention_position not in temporal_attention_positions:
|
178 |
+
raise ValueError(
|
179 |
+
f"`temporal_attention_position` must be one of {temporal_attention_positions}"
|
180 |
+
)
|
181 |
+
|
182 |
+
# 1. Spatial-Attn
|
183 |
+
spatial_attention = SparseCausalAttention if use_sparse_causal_attention else CrossAttention
|
184 |
+
self.attn1 = spatial_attention(
|
185 |
+
query_dim=dim,
|
186 |
+
heads=num_attention_heads,
|
187 |
+
dim_head=attention_head_dim,
|
188 |
+
dropout=dropout,
|
189 |
+
bias=attention_bias,
|
190 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
191 |
+
upcast_attention=upcast_attention,
|
192 |
+
) # is a self-attention
|
193 |
+
self.norm1 = (
|
194 |
+
AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
195 |
+
)
|
196 |
+
|
197 |
+
# 2. Cross-Attn
|
198 |
+
if cross_attention_dim is not None:
|
199 |
+
self.attn2 = CrossAttention(
|
200 |
+
query_dim=dim,
|
201 |
+
cross_attention_dim=cross_attention_dim,
|
202 |
+
heads=num_attention_heads,
|
203 |
+
dim_head=attention_head_dim,
|
204 |
+
dropout=dropout,
|
205 |
+
bias=attention_bias,
|
206 |
+
upcast_attention=upcast_attention,
|
207 |
+
) # is self-attn if encoder_hidden_states is none
|
208 |
+
self.norm2 = (
|
209 |
+
AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
210 |
+
)
|
211 |
+
else:
|
212 |
+
self.attn2 = None
|
213 |
+
self.norm2 = None
|
214 |
+
|
215 |
+
# 3. Temporal-Attn
|
216 |
+
self.attn_temporal = CrossAttention(
|
217 |
+
query_dim=dim,
|
218 |
+
heads=num_attention_heads,
|
219 |
+
dim_head=attention_head_dim,
|
220 |
+
dropout=dropout,
|
221 |
+
bias=attention_bias,
|
222 |
+
upcast_attention=upcast_attention,
|
223 |
+
)
|
224 |
+
nn.init.zeros_(self.attn_temporal.to_out[0].weight.data) # initialize as an identity function
|
225 |
+
self.norm_temporal = (
|
226 |
+
AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
227 |
+
)
|
228 |
+
# efficient_attention_backward_cutlass is not implemented for large channels
|
229 |
+
self.use_xformers = (dim <= 320) or "3090" not in torch.cuda.get_device_name(0)
|
230 |
+
|
231 |
+
# 4. Feed-forward
|
232 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
233 |
+
self.norm3 = nn.LayerNorm(dim)
|
234 |
+
|
235 |
+
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
236 |
+
if not is_xformers_available():
|
237 |
+
print("Here is how to install it")
|
238 |
+
raise ModuleNotFoundError(
|
239 |
+
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
240 |
+
" xformers",
|
241 |
+
name="xformers",
|
242 |
+
)
|
243 |
+
elif not torch.cuda.is_available():
|
244 |
+
raise ValueError(
|
245 |
+
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
|
246 |
+
" available for GPU "
|
247 |
+
)
|
248 |
+
else:
|
249 |
+
try:
|
250 |
+
# Make sure we can run the memory efficient attention
|
251 |
+
if use_memory_efficient_attention_xformers is True:
|
252 |
+
|
253 |
+
_ = xformers.ops.memory_efficient_attention(
|
254 |
+
torch.randn((1, 2, 40), device="cuda"),
|
255 |
+
torch.randn((1, 2, 40), device="cuda"),
|
256 |
+
torch.randn((1, 2, 40), device="cuda"),
|
257 |
+
)
|
258 |
+
else:
|
259 |
+
|
260 |
+
pass
|
261 |
+
except Exception as e:
|
262 |
+
raise e
|
263 |
+
# self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
264 |
+
# self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
265 |
+
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers and self.use_xformers
|
266 |
+
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers and self.use_xformers
|
267 |
+
# self.attn_temporal._use_memory_efficient_attention_xformers = (
|
268 |
+
# use_memory_efficient_attention_xformers
|
269 |
+
# ), # FIXME: enabling this raises CUDA ERROR. Gotta dig in.
|
270 |
+
|
271 |
+
def forward(
|
272 |
+
self,
|
273 |
+
hidden_states,
|
274 |
+
encoder_hidden_states=None,
|
275 |
+
timestep=None,
|
276 |
+
attention_mask=None,
|
277 |
+
clip_length=None,
|
278 |
+
):
|
279 |
+
# 1. Self-Attention
|
280 |
+
norm_hidden_states = (
|
281 |
+
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
|
282 |
+
)
|
283 |
+
|
284 |
+
kwargs = dict(
|
285 |
+
hidden_states=norm_hidden_states,
|
286 |
+
attention_mask=attention_mask,
|
287 |
+
)
|
288 |
+
if self.only_cross_attention:
|
289 |
+
kwargs.update(encoder_hidden_states=encoder_hidden_states)
|
290 |
+
if self.use_sparse_causal_attention:
|
291 |
+
kwargs.update(clip_length=clip_length)
|
292 |
+
if 'SparseCausalAttention_index' in self.model_config.keys():
|
293 |
+
kwargs.update(SparseCausalAttention_index = self.model_config['SparseCausalAttention_index'])
|
294 |
+
|
295 |
+
hidden_states = hidden_states + self.attn1(**kwargs)
|
296 |
+
|
297 |
+
if clip_length is not None and self.temporal_attention_position == "after_spatial":
|
298 |
+
hidden_states = self.apply_temporal_attention(hidden_states, timestep, clip_length)
|
299 |
+
|
300 |
+
if self.attn2 is not None:
|
301 |
+
# 2. Cross-Attention
|
302 |
+
norm_hidden_states = (
|
303 |
+
self.norm2(hidden_states, timestep)
|
304 |
+
if self.use_ada_layer_norm
|
305 |
+
else self.norm2(hidden_states)
|
306 |
+
)
|
307 |
+
hidden_states = (
|
308 |
+
self.attn2(
|
309 |
+
norm_hidden_states, # [16, 4096, 320]
|
310 |
+
encoder_hidden_states=encoder_hidden_states, # [1, 77, 768]
|
311 |
+
attention_mask=attention_mask,
|
312 |
+
)
|
313 |
+
+ hidden_states
|
314 |
+
)
|
315 |
+
|
316 |
+
if clip_length is not None and self.temporal_attention_position == "after_cross":
|
317 |
+
hidden_states = self.apply_temporal_attention(hidden_states, timestep, clip_length)
|
318 |
+
|
319 |
+
# 3. Feed-forward
|
320 |
+
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
321 |
+
|
322 |
+
if clip_length is not None and self.temporal_attention_position == "after_feedforward":
|
323 |
+
hidden_states = self.apply_temporal_attention(hidden_states, timestep, clip_length)
|
324 |
+
|
325 |
+
return hidden_states
|
326 |
+
|
327 |
+
def apply_temporal_attention(self, hidden_states, timestep, clip_length):
|
328 |
+
d = hidden_states.shape[1]
|
329 |
+
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=clip_length)
|
330 |
+
norm_hidden_states = (
|
331 |
+
self.norm_temporal(hidden_states, timestep)
|
332 |
+
if self.use_ada_layer_norm
|
333 |
+
else self.norm_temporal(hidden_states)
|
334 |
+
)
|
335 |
+
hidden_states = self.attn_temporal(norm_hidden_states) + hidden_states
|
336 |
+
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
|
337 |
+
return hidden_states
|
338 |
+
|
339 |
+
|
340 |
+
class SparseCausalAttention(CrossAttention):
|
341 |
+
def forward(
|
342 |
+
self,
|
343 |
+
hidden_states,
|
344 |
+
encoder_hidden_states=None,
|
345 |
+
attention_mask=None,
|
346 |
+
clip_length: int = None,
|
347 |
+
SparseCausalAttention_index: list = [-1, 'first']
|
348 |
+
):
|
349 |
+
if (
|
350 |
+
self.added_kv_proj_dim is not None
|
351 |
+
or encoder_hidden_states is not None
|
352 |
+
or attention_mask is not None
|
353 |
+
):
|
354 |
+
raise NotImplementedError
|
355 |
+
|
356 |
+
if self.group_norm is not None:
|
357 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
358 |
+
|
359 |
+
query = self.to_q(hidden_states)
|
360 |
+
dim = query.shape[-1]
|
361 |
+
query = self.reshape_heads_to_batch_dim(query)
|
362 |
+
|
363 |
+
key = self.to_k(hidden_states)
|
364 |
+
value = self.to_v(hidden_states)
|
365 |
+
|
366 |
+
if clip_length is not None:
|
367 |
+
key = rearrange(key, "(b f) d c -> b f d c", f=clip_length)
|
368 |
+
value = rearrange(value, "(b f) d c -> b f d c", f=clip_length)
|
369 |
+
|
370 |
+
|
371 |
+
# ***********************Start of SparseCausalAttention_index**********
|
372 |
+
frame_index_list = []
|
373 |
+
# print(f'SparseCausalAttention_index {str(SparseCausalAttention_index)}')
|
374 |
+
if len(SparseCausalAttention_index) > 0:
|
375 |
+
for index in SparseCausalAttention_index:
|
376 |
+
if isinstance(index, str):
|
377 |
+
if index == 'first':
|
378 |
+
frame_index = [0] * clip_length
|
379 |
+
if index == 'last':
|
380 |
+
frame_index = [clip_length-1] * clip_length
|
381 |
+
if (index == 'mid') or (index == 'middle'):
|
382 |
+
frame_index = [int(clip_length-1)//2] * clip_length
|
383 |
+
else:
|
384 |
+
assert isinstance(index, int), 'relative index must be int'
|
385 |
+
frame_index = torch.arange(clip_length) + index
|
386 |
+
frame_index = frame_index.clip(0, clip_length-1)
|
387 |
+
|
388 |
+
frame_index_list.append(frame_index)
|
389 |
+
|
390 |
+
key = torch.cat([ key[:, frame_index] for frame_index in frame_index_list
|
391 |
+
], dim=2)
|
392 |
+
value = torch.cat([ value[:, frame_index] for frame_index in frame_index_list
|
393 |
+
], dim=2)
|
394 |
+
|
395 |
+
|
396 |
+
# ***********************End of SparseCausalAttention_index**********
|
397 |
+
key = rearrange(key, "b f d c -> (b f) d c", f=clip_length)
|
398 |
+
value = rearrange(value, "b f d c -> (b f) d c", f=clip_length)
|
399 |
+
|
400 |
+
|
401 |
+
key = self.reshape_heads_to_batch_dim(key)
|
402 |
+
value = self.reshape_heads_to_batch_dim(value)
|
403 |
+
|
404 |
+
# attention, what we cannot get enough of
|
405 |
+
if self._use_memory_efficient_attention_xformers:
|
406 |
+
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
407 |
+
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
408 |
+
hidden_states = hidden_states.to(query.dtype)
|
409 |
+
else:
|
410 |
+
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
411 |
+
hidden_states = self._attention(query, key, value, attention_mask)
|
412 |
+
else:
|
413 |
+
hidden_states = self._sliced_attention(
|
414 |
+
query, key, value, hidden_states.shape[1], dim, attention_mask
|
415 |
+
)
|
416 |
+
|
417 |
+
# linear proj
|
418 |
+
hidden_states = self.to_out[0](hidden_states)
|
419 |
+
|
420 |
+
# dropout
|
421 |
+
hidden_states = self.to_out[1](hidden_states)
|
422 |
+
return hidden_states
|
423 |
+
|
424 |
+
# FIXME
|
425 |
+
class SparseCausalAttention_fixme(CrossAttention):
|
426 |
+
def forward(
|
427 |
+
self,
|
428 |
+
hidden_states,
|
429 |
+
encoder_hidden_states=None,
|
430 |
+
attention_mask=None,
|
431 |
+
clip_length: int = None,
|
432 |
+
):
|
433 |
+
if (
|
434 |
+
self.added_kv_proj_dim is not None
|
435 |
+
or encoder_hidden_states is not None
|
436 |
+
or attention_mask is not None
|
437 |
+
):
|
438 |
+
raise NotImplementedError
|
439 |
+
|
440 |
+
if self.group_norm is not None:
|
441 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
442 |
+
|
443 |
+
query = self.to_q(hidden_states)
|
444 |
+
dim = query.shape[-1]
|
445 |
+
query = self.reshape_heads_to_batch_dim(query)
|
446 |
+
|
447 |
+
key = self.to_k(hidden_states)
|
448 |
+
value = self.to_v(hidden_states)
|
449 |
+
|
450 |
+
prev_frame_index = torch.arange(clip_length) - 1
|
451 |
+
prev_frame_index[0] = 0
|
452 |
+
|
453 |
+
key = rearrange(key, "(b f) d c -> b f d c", f=clip_length)
|
454 |
+
key = torch.cat([key[:, [0] * clip_length], key[:, prev_frame_index]], dim=2)
|
455 |
+
key = rearrange(key, "b f d c -> (b f) d c", f=clip_length)
|
456 |
+
|
457 |
+
value = rearrange(value, "(b f) d c -> b f d c", f=clip_length)
|
458 |
+
value = torch.cat([value[:, [0] * clip_length], value[:, prev_frame_index]], dim=2)
|
459 |
+
value = rearrange(value, "b f d c -> (b f) d c", f=clip_length)
|
460 |
+
|
461 |
+
key = self.reshape_heads_to_batch_dim(key)
|
462 |
+
value = self.reshape_heads_to_batch_dim(value)
|
463 |
+
|
464 |
+
|
465 |
+
if self._use_memory_efficient_attention_xformers:
|
466 |
+
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
467 |
+
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
468 |
+
hidden_states = hidden_states.to(query.dtype)
|
469 |
+
else:
|
470 |
+
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
471 |
+
hidden_states = self._attention(query, key, value, attention_mask)
|
472 |
+
else:
|
473 |
+
hidden_states = self._sliced_attention(
|
474 |
+
query, key, value, hidden_states.shape[1], dim, attention_mask
|
475 |
+
)
|
476 |
+
|
477 |
+
# linear proj
|
478 |
+
hidden_states = self.to_out[0](hidden_states)
|
479 |
+
|
480 |
+
# dropout
|
481 |
+
hidden_states = self.to_out[1](hidden_states)
|
482 |
+
return hidden_states
|