gokaygokay committed
Commit 0a88b62
1 Parent(s): 23ce364

Upload 93 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
Files changed (50) hide show
  1. .gitattributes +37 -35
  2. README.md +4 -7
  3. app.py +284 -0
  4. assets/logo.png +0 -0
  5. assets/overview_3.png +0 -0
  6. assets/radar.png +0 -0
  7. assets/runtime.png +0 -0
  8. assets/teaser.png +3 -0
  9. demos/example_000.png +0 -0
  10. demos/example_001.png +0 -0
  11. demos/example_002.png +0 -0
  12. demos/example_003.png +3 -0
  13. demos/example_list.txt +2 -0
  14. infer/__init__.py +28 -0
  15. infer/gif_render.py +55 -0
  16. infer/image_to_views.py +81 -0
  17. infer/rembg.py +26 -0
  18. infer/text_to_image.py +80 -0
  19. infer/utils.py +77 -0
  20. infer/views_to_mesh.py +94 -0
  21. mvd/__init__.py +0 -0
  22. mvd/hunyuan3d_mvd_lite_pipeline.py +493 -0
  23. mvd/hunyuan3d_mvd_std_pipeline.py +471 -0
  24. mvd/utils.py +85 -0
  25. requirements.txt +22 -0
  26. scripts/image_to_3d.sh +8 -0
  27. scripts/image_to_3d_demo.sh +8 -0
  28. scripts/image_to_3d_fast.sh +6 -0
  29. scripts/image_to_3d_fast_demo.sh +6 -0
  30. scripts/text_to_3d.sh +7 -0
  31. scripts/text_to_3d_demo.sh +7 -0
  32. scripts/text_to_3d_fast.sh +6 -0
  33. scripts/text_to_3d_fast_demo.sh +6 -0
  34. svrm/.DS_Store +0 -0
  35. svrm/configs/2024-10-24T22-36-18-project.yaml +32 -0
  36. svrm/configs/svrm.yaml +32 -0
  37. svrm/ldm/.DS_Store +0 -0
  38. svrm/ldm/models/svrm.py +263 -0
  39. svrm/ldm/modules/attention.py +457 -0
  40. svrm/ldm/modules/encoders/__init__.py +0 -0
  41. svrm/ldm/modules/encoders/dinov2/__init__.py +0 -0
  42. svrm/ldm/modules/encoders/dinov2/hub/__init__.py +0 -0
  43. svrm/ldm/modules/encoders/dinov2/hub/backbones.py +156 -0
  44. svrm/ldm/modules/encoders/dinov2/hub/utils.py +39 -0
  45. svrm/ldm/modules/encoders/dinov2/layers/__init__.py +11 -0
  46. svrm/ldm/modules/encoders/dinov2/layers/attention.py +89 -0
  47. svrm/ldm/modules/encoders/dinov2/layers/block.py +269 -0
  48. svrm/ldm/modules/encoders/dinov2/layers/dino_head.py +58 -0
  49. svrm/ldm/modules/encoders/dinov2/layers/drop_path.py +34 -0
  50. svrm/ldm/modules/encoders/dinov2/layers/layer_scale.py +27 -0
.gitattributes CHANGED
@@ -1,35 +1,37 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text

+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/teaser.png filter=lfs diff=lfs merge=lfs -text
+ demos/example_003.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,14 +1,11 @@
  ---
- title: InstantIR
- emoji: 🖼
  colorFrom: purple
  colorTo: red
  sdk: gradio
  sdk_version: 4.42.0
  app_file: app.py
  pinned: false
- license: apache-2.0
- short_description: diffusion-based Image Restoration model
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

  ---
+ title: Hunyuan3D-1.0
+ emoji: 😻
  colorFrom: purple
  colorTo: red
  sdk: gradio
  sdk_version: 4.42.0
  app_file: app.py
  pinned: false
+ short_description: Text-to-3D and Image-to-3D Generation
+ ---
 
 
app.py ADDED
@@ -0,0 +1,284 @@
+ import os
+ import warnings
+ from huggingface_hub import hf_hub_download
+ import gradio as gr
+ from glob import glob
+ import shutil
+ import torch
+ import numpy as np
+ from PIL import Image
+ from einops import rearrange
+ import argparse
+
+ # Suppress warnings
+ warnings.simplefilter('ignore', category=UserWarning)
+ warnings.simplefilter('ignore', category=FutureWarning)
+ warnings.simplefilter('ignore', category=DeprecationWarning)
+
+ def download_models():
+     # Create weights directory if it doesn't exist
+     os.makedirs("weights", exist_ok=True)
+     os.makedirs("weights/hunyuanDiT", exist_ok=True)
+
+     # Download Hunyuan3D-1 model
+     try:
+         hf_hub_download(
+             repo_id="tencent/Hunyuan3D-1",
+             local_dir="./weights",
+             resume_download=True
+         )
+         print("Successfully downloaded Hunyuan3D-1 model")
+     except Exception as e:
+         print(f"Error downloading Hunyuan3D-1: {e}")
+
+     # Download HunyuanDiT model
+     try:
+         hf_hub_download(
+             repo_id="Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled",
+             local_dir="./weights/hunyuanDiT",
+             resume_download=True
+         )
+         print("Successfully downloaded HunyuanDiT model")
+     except Exception as e:
+         print(f"Error downloading HunyuanDiT: {e}")
+
+ # Download models before starting the app
+ download_models()
+
+ # Parse arguments
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--use_lite", default=False, action="store_true")
+ parser.add_argument("--mv23d_cfg_path", default="./svrm/configs/svrm.yaml", type=str)
+ parser.add_argument("--mv23d_ckt_path", default="weights/svrm/svrm.safetensors", type=str)
+ parser.add_argument("--text2image_path", default="weights/hunyuanDiT", type=str)
+ parser.add_argument("--save_memory", default=False, action="store_true")
+ parser.add_argument("--device", default="cuda:0", type=str)
+ args = parser.parse_args()
+
+ # Constants
+ CONST_PORT = 8080
+ CONST_MAX_QUEUE = 1
+ CONST_SERVER = '0.0.0.0'
+
+ CONST_HEADER = '''
+ <h2><b>Official 🤗 Gradio Demo</b></h2>
+ <h2><a href='https://github.com/tencent/Hunyuan3D-1' target='_blank'>
+ <b>Hunyuan3D-1.0: A Unified Framework for Text-to-3D and Image-to-3D Generation</b></a></h2>
+ '''
+
+ # Helper functions
+ def get_example_img_list():
+     print('Loading example img list ...')
+     return sorted(glob('./demos/example_*.png'))
+
+ def get_example_txt_list():
+     print('Loading example txt list ...')
+     txt_list = []
+     for line in open('./demos/example_list.txt'):
+         txt_list.append(line.strip())
+     return txt_list
+
+ example_is = get_example_img_list()
+ example_ts = get_example_txt_list()
+
+ # Import required workers
+ from infer import seed_everything, save_gif
+ from infer import Text2Image, Removebg, Image2Views, Views2Mesh, GifRenderer
+
+ # Initialize workers
+ worker_xbg = Removebg()
+ print(f"loading {args.text2image_path}")
+ worker_t2i = Text2Image(
+     pretrain=args.text2image_path,
+     device=args.device,
+     save_memory=args.save_memory
+ )
+ worker_i2v = Image2Views(
+     use_lite=args.use_lite,
+     device=args.device
+ )
+ worker_v23 = Views2Mesh(
+     args.mv23d_cfg_path,
+     args.mv23d_ckt_path,
+     use_lite=args.use_lite,
+     device=args.device
+ )
+ worker_gif = GifRenderer(args.device)
+
+ # Pipeline stages
+ def stage_0_t2i(text, image, seed, step):
+     os.makedirs('./outputs/app_output', exist_ok=True)
+     exists = set(int(_) for _ in os.listdir('./outputs/app_output') if not _.startswith("."))
+     cur_id = min(set(range(30)) - exists) if len(exists) < 30 else 0
+
+     if os.path.exists(f"./outputs/app_output/{(cur_id + 1) % 30}"):
+         shutil.rmtree(f"./outputs/app_output/{(cur_id + 1) % 30}")
+     save_folder = f'./outputs/app_output/{cur_id}'
+     os.makedirs(save_folder, exist_ok=True)
+
+     dst = save_folder + '/img.png'
+
+     if not text:
+         if image is None:
+             return dst, save_folder
+         image.save(dst)
+         return dst, save_folder
+
+     image = worker_t2i(text, seed, step)
+     image.save(dst)
+     dst = worker_xbg(image, save_folder)
+     return dst, save_folder
+
+ def stage_1_xbg(image, save_folder):
+     if isinstance(image, str):
+         image = Image.open(image)
+     dst = save_folder + '/img_nobg.png'
+     rgba = worker_xbg(image)
+     rgba.save(dst)
+     return dst
+
+ def stage_2_i2v(image, seed, step, save_folder):
+     if isinstance(image, str):
+         image = Image.open(image)
+     gif_dst = save_folder + '/views.gif'
+     res_img, pils = worker_i2v(image, seed, step)
+     save_gif(pils, gif_dst)
+     views_img, cond_img = res_img[0], res_img[1]
+     img_array = np.asarray(views_img, dtype=np.uint8)
+     show_img = rearrange(img_array, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
+     show_img = show_img[worker_i2v.order, ...]
+     show_img = rearrange(show_img, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
+     show_img = Image.fromarray(show_img)
+     return views_img, cond_img, show_img
+
+ def stage_3_v23(views_pil, cond_pil, seed, save_folder, target_face_count=30000,
+                 do_texture_mapping=True, do_render=True):
+     do_texture_mapping = do_texture_mapping or do_render
+     obj_dst = save_folder + '/mesh_with_colors.obj'
+     glb_dst = save_folder + '/mesh.glb'
+     worker_v23(
+         views_pil,
+         cond_pil,
+         seed=seed,
+         save_folder=save_folder,
+         target_face_count=target_face_count,
+         do_texture_mapping=do_texture_mapping
+     )
+     return obj_dst, glb_dst
+
+ def stage_4_gif(obj_dst, save_folder, do_render_gif=True):
+     if not do_render_gif:
+         return None
+     gif_dst = save_folder + '/output.gif'
+     worker_gif(
+         save_folder + '/mesh.obj',
+         gif_dst_path=gif_dst
+     )
+     return gif_dst
+
+ # Gradio Interface
+ with gr.Blocks() as demo:
+     gr.Markdown(CONST_HEADER)
+
+     with gr.Row(variant="panel"):
+         with gr.Column(scale=2):
+             with gr.Tab("Text to 3D"):
+                 with gr.Column():
+                     text = gr.TextArea('一只黑白相间的熊猫在白色背景上居中坐着,呈现出卡通风格和可爱氛围。',
+                                        lines=1, max_lines=10, label='Input text')
+                     with gr.Row():
+                         textgen_seed = gr.Number(value=0, label="T2I seed", precision=0)
+                         textgen_step = gr.Number(value=25, label="T2I step", precision=0)
+                         textgen_SEED = gr.Number(value=0, label="Gen seed", precision=0)
+                         textgen_STEP = gr.Number(value=50, label="Gen step", precision=0)
+                         textgen_max_faces = gr.Number(value=90000, label="max number of faces", precision=0)
+
+                     with gr.Row():
+                         textgen_do_texture_mapping = gr.Checkbox(label="texture mapping", value=False)
+                         textgen_do_render_gif = gr.Checkbox(label="Render gif", value=False)
+                         textgen_submit = gr.Button("Generate", variant="primary")
+
+                     gr.Examples(examples=example_ts, inputs=[text], label="Txt examples")
+
+             with gr.Tab("Image to 3D"):
+                 with gr.Column():
+                     input_image = gr.Image(label="Input image", width=256, height=256,
+                                            type="pil", image_mode="RGBA", sources="upload")
+                     with gr.Row():
+                         imggen_SEED = gr.Number(value=0, label="Gen seed", precision=0)
+                         imggen_STEP = gr.Number(value=50, label="Gen step", precision=0)
+                         imggen_max_faces = gr.Number(value=90000, label="max number of faces", precision=0)
+
+                     with gr.Row():
+                         imggen_do_texture_mapping = gr.Checkbox(label="texture mapping", value=False)
+                         imggen_do_render_gif = gr.Checkbox(label="Render gif", value=False)
+                         imggen_submit = gr.Button("Generate", variant="primary")
+
+                     gr.Examples(examples=example_is, inputs=[input_image], label="Img examples")
+
+         with gr.Column(scale=3):
+             with gr.Tab("rembg image"):
+                 rem_bg_image = gr.Image(label="No background image", width=256, height=256,
+                                         type="pil", image_mode="RGBA")
+
+             with gr.Tab("Multi views"):
+                 result_image = gr.Image(label="Multi views", type="pil")
+             with gr.Tab("Obj"):
+                 result_3dobj = gr.Model3D(label="Output obj")
+             with gr.Tab("Glb"):
+                 result_3dglb = gr.Model3D(label="Output glb")
+             with gr.Tab("GIF"):
+                 result_gif = gr.Image(label="Rendered GIF")
+
+     # States
+     none = gr.State(None)
+     save_folder = gr.State()
+     cond_image = gr.State()
+     views_image = gr.State()
+     text_image = gr.State()
+
+     # Event handlers
+     textgen_submit.click(
+         fn=stage_0_t2i,
+         inputs=[text, none, textgen_seed, textgen_step],
+         outputs=[rem_bg_image, save_folder],
+     ).success(
+         fn=stage_2_i2v,
+         inputs=[rem_bg_image, textgen_SEED, textgen_STEP, save_folder],
+         outputs=[views_image, cond_image, result_image],
+     ).success(
+         fn=stage_3_v23,
+         inputs=[views_image, cond_image, textgen_SEED, save_folder, textgen_max_faces,
+                 textgen_do_texture_mapping, textgen_do_render_gif],
+         outputs=[result_3dobj, result_3dglb],
+     ).success(
+         fn=stage_4_gif,
+         inputs=[result_3dglb, save_folder, textgen_do_render_gif],
+         outputs=[result_gif],
+     )
+
+     imggen_submit.click(
+         fn=stage_0_t2i,
+         inputs=[none, input_image, textgen_seed, textgen_step],
+         outputs=[text_image, save_folder],
+     ).success(
+         fn=stage_1_xbg,
+         inputs=[text_image, save_folder],
+         outputs=[rem_bg_image],
+     ).success(
+         fn=stage_2_i2v,
+         inputs=[rem_bg_image, imggen_SEED, imggen_STEP, save_folder],
+         outputs=[views_image, cond_image, result_image],
+     ).success(
+         fn=stage_3_v23,
+         inputs=[views_image, cond_image, imggen_SEED, save_folder, imggen_max_faces,
+                 imggen_do_texture_mapping, imggen_do_render_gif],
+         outputs=[result_3dobj, result_3dglb],
+     ).success(
+         fn=stage_4_gif,
+         inputs=[result_3dglb, save_folder, imggen_do_render_gif],
+         outputs=[result_gif],
+     )
+
+ demo.queue(max_size=CONST_MAX_QUEUE)
+ demo.launch(server_name=CONST_SERVER, server_port=CONST_PORT)
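The same staged flow (text → image → background removal → multi-view → mesh → turntable GIF) can be driven without the Gradio UI. Below is a minimal headless sketch, assuming the weights/ directory has been populated by download_models(), the infer package is importable, and a CUDA device is available; the output paths and step counts are illustrative, and the mesh filename mirrors what stage_4_gif expects.

from PIL import Image
from infer import Text2Image, Removebg, Image2Views, Views2Mesh, GifRenderer, save_gif

t2i  = Text2Image(pretrain="weights/hunyuanDiT", device="cuda:0")
xbg  = Removebg()
i2v  = Image2Views(use_lite=False, device="cuda:0")
v23  = Views2Mesh("./svrm/configs/svrm.yaml", "weights/svrm/svrm.safetensors", device="cuda:0")
gifr = GifRenderer("cuda:0")

img  = t2i("a lovely rabbit eating carrots", seed=0, steps=25)   # text -> image
rgba = xbg(img)                                                   # remove background
res, pils = i2v(rgba, seed=0, steps=50)                           # image -> multi-view grid
views, cond = res[0], res[1]
save_gif(pils, "outputs/demo/views.gif")
v23(views, cond, seed=0, target_face_count=90000, save_folder="outputs/demo")  # views -> mesh
gifr("outputs/demo/mesh.obj", gif_dst_path="outputs/demo/output.gif")          # mesh -> turntable gif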
assets/logo.png ADDED
assets/overview_3.png ADDED
assets/radar.png ADDED
assets/runtime.png ADDED
assets/teaser.png ADDED

Git LFS Details

  • SHA256: af24eeebe39864d377b7ef8e11521a8b7cba964c14032cc28bd0d95bd5219c00
  • Pointer size: 132 Bytes
  • Size of remote file: 3.1 MB
demos/example_000.png ADDED
demos/example_001.png ADDED
demos/example_002.png ADDED
demos/example_003.png ADDED

Git LFS Details

  • SHA256: d947e0ef10baf761abb78d2842519ae7428bc6eadab26a159510ddcaf2a47e67
  • Pointer size: 132 Bytes
  • Size of remote file: 1.07 MB
demos/example_list.txt ADDED
@@ -0,0 +1,2 @@
+ a pot of green plants grows in a red flower pot.
+ a lovely rabbit eating carrots
infer/__init__.py ADDED
@@ -0,0 +1,28 @@
+ # Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
+ # The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
+
+ # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+ # The below software and/or models in this distribution may have been
+ # modified by THL A29 Limited ("Tencent Modifications").
+ # All Tencent Modifications are Copyright (C) THL A29 Limited.
+
+ # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
+ # except for the third-party components listed below.
+ # Hunyuan 3D does not impose any additional limitations beyond what is outlined
+ # in the respective licenses of these third-party components.
+ # Users must comply with all terms and conditions of original licenses of these third-party
+ # components and must ensure that the usage of the third party components adheres to
+ # all relevant laws and regulations.
+
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
+ # their software and algorithms, including trained model weights, parameters (including
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
+
+ from .utils import seed_everything, timing_decorator, auto_amp_inference
+ from .rembg import Removebg
+ from .text_to_image import Text2Image
+ from .image_to_views import Image2Views, save_gif
+ from .views_to_mesh import Views2Mesh
+ from .gif_render import GifRenderer
infer/gif_render.py ADDED
@@ -0,0 +1,55 @@
+ # Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
+ # The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
+
+ # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+ # The below software and/or models in this distribution may have been
+ # modified by THL A29 Limited ("Tencent Modifications").
+ # All Tencent Modifications are Copyright (C) THL A29 Limited.
+
+ # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
+ # except for the third-party components listed below.
+ # Hunyuan 3D does not impose any additional limitations beyond what is outlined
+ # in the respective licenses of these third-party components.
+ # Users must comply with all terms and conditions of original licenses of these third-party
+ # components and must ensure that the usage of the third party components adheres to
+ # all relevant laws and regulations.
+
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
+ # their software and algorithms, including trained model weights, parameters (including
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
+
+ from svrm.ldm.vis_util import render
+ from .utils import seed_everything, timing_decorator
+
+ class GifRenderer():
+     '''
+     render frame(s) of a mesh using pytorch3d
+     '''
+     def __init__(self, device="cuda:0"):
+         self.device = device
+
+     @timing_decorator("gif render")
+     def __call__(
+         self,
+         obj_filename,
+         elev=0,
+         azim=0,
+         resolution=512,
+         gif_dst_path='',
+         n_views=120,
+         fps=30,
+         rgb=True
+     ):
+         render(
+             obj_filename,
+             elev=elev,
+             azim=azim,
+             resolution=resolution,
+             gif_dst_path=gif_dst_path,
+             n_views=n_views,
+             fps=fps,
+             device=self.device,
+             rgb=rgb
+         )
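A small usage sketch of the class above, assuming pytorch3d and the svrm package are installed and a CUDA device is available; the OBJ path is illustrative (typically the mesh written by Views2Mesh).

from infer.gif_render import GifRenderer

renderer = GifRenderer(device="cuda:0")
renderer(
    "outputs/test/mesh.obj",               # mesh produced by Views2Mesh (illustrative path)
    resolution=512,
    gif_dst_path="outputs/test/output.gif",
    n_views=120,
    fps=30,
)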
infer/image_to_views.py ADDED
@@ -0,0 +1,81 @@
+ # Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
+ # The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
+
+ # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+ # The below software and/or models in this distribution may have been
+ # modified by THL A29 Limited ("Tencent Modifications").
+ # All Tencent Modifications are Copyright (C) THL A29 Limited.
+
+ # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
+ # except for the third-party components listed below.
+ # Hunyuan 3D does not impose any additional limitations beyond what is outlined
+ # in the respective licenses of these third-party components.
+ # Users must comply with all terms and conditions of original licenses of these third-party
+ # components and must ensure that the usage of the third party components adheres to
+ # all relevant laws and regulations.
+
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
+ # their software and algorithms, including trained model weights, parameters (including
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
+
+ import os
+ import time
+ import torch
+ import random
+ import numpy as np
+ from PIL import Image
+ from einops import rearrange
+ from PIL import Image, ImageSequence
+
+ from .utils import seed_everything, timing_decorator, auto_amp_inference
+ from .utils import get_parameter_number, set_parameter_grad_false
+ from mvd.hunyuan3d_mvd_std_pipeline import HunYuan3D_MVD_Std_Pipeline
+ from mvd.hunyuan3d_mvd_lite_pipeline import Hunyuan3d_MVD_Lite_Pipeline
+
+
+ def save_gif(pils, save_path, df=False):
+     # save a list of PIL.Image to gif
+     spf = 4000 / len(pils)
+     os.makedirs(os.path.dirname(save_path), exist_ok=True)
+     pils[0].save(save_path, format="GIF", save_all=True, append_images=pils[1:], duration=spf, loop=0)
+     return save_path
+
+
+ class Image2Views():
+     def __init__(self, device="cuda:0", use_lite=False):
+         self.device = device
+         if use_lite:
+             self.pipe = Hunyuan3d_MVD_Lite_Pipeline.from_pretrained(
+                 "./weights/mvd_lite",
+                 torch_dtype = torch.float16,
+                 use_safetensors = True,
+             )
+         else:
+             self.pipe = HunYuan3D_MVD_Std_Pipeline.from_pretrained(
+                 "./weights/mvd_std",
+                 torch_dtype = torch.float16,
+                 use_safetensors = True,
+             )
+         self.pipe = self.pipe.to(device)
+         self.order = [0, 1, 2, 3, 4, 5] if use_lite else [0, 2, 4, 5, 3, 1]
+         set_parameter_grad_false(self.pipe.unet)
+         print('image2views unet model', get_parameter_number(self.pipe.unet))
+
+     @torch.no_grad()
+     @timing_decorator("image to views")
+     @auto_amp_inference
+     def __call__(self, pil_img, seed=0, steps=50, guidance_scale=2.0, guidance_curve=lambda t:2.0):
+         seed_everything(seed)
+         generator = torch.Generator(device=self.device)
+         res_img = self.pipe(pil_img,
+                             num_inference_steps=steps,
+                             guidance_scale=guidance_scale,
+                             guidance_curve=guidance_curve,
+                             generator=generator).images
+         show_image = rearrange(np.asarray(res_img[0], dtype=np.uint8), '(n h) (m w) c -> (n m) h w c', n=3, m=2)
+         pils = [res_img[1]]+[Image.fromarray(show_image[idx]) for idx in self.order]
+         torch.cuda.empty_cache()
+         return res_img, pils
+
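Usage sketch for the class above, assuming ./weights/mvd_std (or ./weights/mvd_lite) has been downloaded; it turns one input image into the 3x2 multi-view grid plus a preview GIF, mirroring stage_2_i2v in app.py. The output paths are illustrative.

from PIL import Image
from infer.image_to_views import Image2Views, save_gif

i2v = Image2Views(device="cuda:0", use_lite=False)
img = Image.open("demos/example_000.png")
res_img, pils = i2v(img, seed=0, steps=50)   # res_img[0]: view grid, res_img[1]: conditioning image
save_gif(pils, "outputs/test/views.gif")
res_img[0].save("outputs/test/views.png")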
infer/rembg.py ADDED
@@ -0,0 +1,26 @@
+ from rembg import remove, new_session
+ from .utils import timing_decorator
+
+ class Removebg():
+     def __init__(self, name="u2net"):
+         '''
+         name: rembg model/session name, e.g. "u2net"
+         '''
+         self.session = new_session(name)
+
+     @timing_decorator("remove background")
+     def __call__(self, rgb_img, force=False):
+         '''
+         inputs:
+             rgb_img: PIL.Image, RGB mode expected
+             force: bool, re-run matting even if the input is already RGBA
+         return:
+             rgba_img: PIL.Image in RGBA mode
+         '''
+         if rgb_img.mode == "RGBA":
+             if force:
+                 rgb_img = rgb_img.convert("RGB")
+             else:
+                 return rgb_img
+         rgba_img = remove(rgb_img, session=self.session)
+         return rgba_img
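A brief usage sketch: strip the background from an RGB image before multi-view generation. The example path is illustrative; converting to RGB first ensures the matting actually runs even if the source file already carries an alpha channel.

from PIL import Image
from infer.rembg import Removebg

rembg_worker = Removebg(name="u2net")
rgba = rembg_worker(Image.open("demos/example_000.png").convert("RGB"))
rgba.save("outputs/test/img_nobg.png")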
infer/text_to_image.py ADDED
@@ -0,0 +1,80 @@
+ # Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
+ # The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
+
+ # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+ # The below software and/or models in this distribution may have been
+ # modified by THL A29 Limited ("Tencent Modifications").
+ # All Tencent Modifications are Copyright (C) THL A29 Limited.
+
+ # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
+ # except for the third-party components listed below.
+ # Hunyuan 3D does not impose any additional limitations beyond what is outlined
+ # in the respective licenses of these third-party components.
+ # Users must comply with all terms and conditions of original licenses of these third-party
+ # components and must ensure that the usage of the third party components adheres to
+ # all relevant laws and regulations.
+
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
+ # their software and algorithms, including trained model weights, parameters (including
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
+
+ import torch
+ from .utils import seed_everything, timing_decorator, auto_amp_inference
+ from .utils import get_parameter_number, set_parameter_grad_false
+ from diffusers import HunyuanDiTPipeline, AutoPipelineForText2Image
+
+ class Text2Image():
+     def __init__(self, pretrain="weights/hunyuanDiT", device="cuda:0", save_memory=False):
+         '''
+         save_memory: if GPU memory is low, set True to keep the pipeline on CPU and move it to the GPU only while generating
+         '''
+         self.save_memory = save_memory
+         self.device = device
+         self.pipe = AutoPipelineForText2Image.from_pretrained(
+             pretrain,
+             torch_dtype = torch.float16,
+             enable_pag = True,
+             pag_applied_layers = ["blocks.(16|17|18|19)"]
+         )
+         set_parameter_grad_false(self.pipe.transformer)
+         print('text2image transformer model', get_parameter_number(self.pipe.transformer))
+         if not save_memory:
+             self.pipe = self.pipe.to(device)
+         self.neg_txt = "文本,特写,裁剪,出框,最差质量,低质量,JPEG伪影,PGLY,重复,病态,残缺,多余的手指,变异的手," \
+                        "画得不好的手,画得不好的脸,变异,畸形,模糊,脱水,糟糕的解剖学,糟糕的比例,多余的肢体,克隆的脸," \
+                        "毁容,恶心的比例,畸形的肢体,缺失的手臂,缺失的腿,额外的手臂,额外的腿,融合的手指,手指太多,长脖子"
+
+     @torch.no_grad()
+     @timing_decorator('text to image')
+     @auto_amp_inference
+     def __call__(self, *args, **kwargs):
+         if self.save_memory:
+             self.pipe = self.pipe.to(self.device)
+             torch.cuda.empty_cache()
+             res = self.call(*args, **kwargs)
+             self.pipe = self.pipe.to("cpu")
+         else:
+             res = self.call(*args, **kwargs)
+         torch.cuda.empty_cache()
+         return res
+
+     def call(self, prompt, seed=0, steps=25):
+         '''
+         inputs:
+             prompt: str
+             seed: int
+             steps: int
+         return:
+             rgb: PIL.Image
+         '''
+         prompt = prompt + ",白色背景,3D风格,最佳质量"
+         seed_everything(seed)
+         generator = torch.Generator(device=self.device)
+         if seed is not None: generator = generator.manual_seed(int(seed))
+         rgb = self.pipe(prompt=prompt, negative_prompt=self.neg_txt, num_inference_steps=steps,
+                         pag_scale=1.3, width=1024, height=1024, generator=generator, return_dict=False)[0][0]
+         torch.cuda.empty_cache()
+         return rgb
+
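Note that call() appends a fixed Chinese suffix to every prompt (roughly "white background, 3D style, best quality") and uses a fixed Chinese negative prompt listing common artifacts (text, cropping, worst/low quality, JPEG artifacts, deformed or extra hands, fingers and limbs, blur, bad anatomy, long neck, and so on). A minimal usage sketch, assuming weights/hunyuanDiT holds the distilled HunyuanDiT diffusers checkpoint downloaded by app.py:

from infer.text_to_image import Text2Image

t2i = Text2Image(pretrain="weights/hunyuanDiT", device="cuda:0", save_memory=False)
image = t2i("a lovely rabbit eating carrots", seed=0, steps=25)   # returns a 1024x1024 PIL.Image
image.save("outputs/test/img.png")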
infer/utils.py ADDED
@@ -0,0 +1,77 @@
+ # Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
+ # The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
+
+ # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+ # The below software and/or models in this distribution may have been
+ # modified by THL A29 Limited ("Tencent Modifications").
+ # All Tencent Modifications are Copyright (C) THL A29 Limited.
+
+ # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
+ # except for the third-party components listed below.
+ # Hunyuan 3D does not impose any additional limitations beyond what is outlined
+ # in the respective licenses of these third-party components.
+ # Users must comply with all terms and conditions of original licenses of these third-party
+ # components and must ensure that the usage of the third party components adheres to
+ # all relevant laws and regulations.
+
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
+ # their software and algorithms, including trained model weights, parameters (including
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
+
+ import os
+ import time
+ import random
+ import numpy as np
+ import torch
+ from torch.cuda.amp import autocast, GradScaler
+ from functools import wraps
+
+ def seed_everything(seed):
+     '''
+     seed everything
+     '''
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     os.environ["PL_GLOBAL_SEED"] = str(seed)
+
+ def timing_decorator(category: str):
+     '''
+     timing_decorator: record time
+     '''
+     def decorator(func):
+         func.call_count = 0
+         @wraps(func)
+         def wrapper(*args, **kwargs):
+             start_time = time.time()
+             result = func(*args, **kwargs)
+             end_time = time.time()
+             elapsed_time = end_time - start_time
+             func.call_count += 1
+             print(f"[HunYuan3D]-[{category}], cost time: {elapsed_time:.4f}s") # huiwen
+             return result
+         return wrapper
+     return decorator
+
+ def auto_amp_inference(func):
+     '''
+     run the wrapped function inside
+     torch.cuda.amp.autocast()
+     '''
+     @wraps(func)
+     def wrapper(*args, **kwargs):
+         with autocast():
+             output = func(*args, **kwargs)
+         return output
+     return wrapper
+
+ def get_parameter_number(model):
+     total_num = sum(p.numel() for p in model.parameters())
+     trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     return {'Total': total_num, 'Trainable': trainable_num}
+
+ def set_parameter_grad_false(model):
+     for p in model.parameters():
+         p.requires_grad = False
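An illustrative sketch of how these helpers compose: the decorators time a call and run it under torch.cuda.amp.autocast(), while seed_everything fixes the RNG state. my_step is a hypothetical function, not part of the repo, and the example assumes a CUDA device.

import torch
from infer.utils import seed_everything, timing_decorator, auto_amp_inference

@timing_decorator("demo step")
@auto_amp_inference
def my_step(x):
    return x @ x.T            # runs in half precision under autocast on CUDA

seed_everything(0)
out = my_step(torch.randn(64, 64, device="cuda"))   # prints "[HunYuan3D]-[demo step], cost time: ..."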
infer/views_to_mesh.py ADDED
@@ -0,0 +1,94 @@
+ # Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
+ # The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
+
+ # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+ # The below software and/or models in this distribution may have been
+ # modified by THL A29 Limited ("Tencent Modifications").
+ # All Tencent Modifications are Copyright (C) THL A29 Limited.
+
+ # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
+ # except for the third-party components listed below.
+ # Hunyuan 3D does not impose any additional limitations beyond what is outlined
+ # in the respective licenses of these third-party components.
+ # Users must comply with all terms and conditions of original licenses of these third-party
+ # components and must ensure that the usage of the third party components adheres to
+ # all relevant laws and regulations.
+
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
+ # their software and algorithms, including trained model weights, parameters (including
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
+
+ import os
+ import time
+ import torch
+ import random
+ import numpy as np
+ from PIL import Image
+ from einops import rearrange
+ from PIL import Image, ImageSequence
+
+ from .utils import seed_everything, timing_decorator, auto_amp_inference
+ from .utils import get_parameter_number, set_parameter_grad_false
+ from svrm.predictor import MV23DPredictor
+
+
+ class Views2Mesh():
+     def __init__(self, mv23d_cfg_path, mv23d_ckt_path, device="cuda:0", use_lite=False):
+         '''
+         mv23d_cfg_path: config yaml file
+         mv23d_ckt_path: path to ckpt
+         use_lite: whether the views come from the lite multi-view model (affects view ordering)
+         '''
+         self.mv23d_predictor = MV23DPredictor(mv23d_ckt_path, mv23d_cfg_path, device=device)
+         self.mv23d_predictor.model.eval()
+         self.order = [0, 1, 2, 3, 4, 5] if use_lite else [0, 2, 4, 5, 3, 1]
+         set_parameter_grad_false(self.mv23d_predictor.model)
+         print('view2mesh model', get_parameter_number(self.mv23d_predictor.model))
+
+     @torch.no_grad()
+     @timing_decorator("views to mesh")
+     @auto_amp_inference
+     def __call__(
+         self,
+         views_pil=None,
+         cond_pil=None,
+         gif_pil=None,
+         seed=0,
+         target_face_count = 10000,
+         do_texture_mapping = True,
+         save_folder='./outputs/test'
+     ):
+         '''
+         set views_pil and cond_pil together, or set gif_pil only
+         seed: int
+         target_face_count: int
+         save_folder: path to save mesh files
+         '''
+         save_dir = save_folder
+         os.makedirs(save_dir, exist_ok=True)
+
+         if views_pil is not None and cond_pil is not None:
+             show_image = rearrange(np.asarray(views_pil, dtype=np.uint8),
+                                    '(n h) (m w) c -> (n m) h w c', n=3, m=2)
+             views = [Image.fromarray(show_image[idx]) for idx in self.order]
+             image_list = [cond_pil] + views
+             image_list = [img.convert('RGB') for img in image_list]
+         elif gif_pil is not None:
+             image_list = [img.convert('RGB') for img in ImageSequence.Iterator(gif_pil)]
+
+         image_input = image_list[0]
+         image_list = image_list[1:] + image_list[:1]
+
+         seed_everything(seed)
+         self.mv23d_predictor.predict(
+             image_list,
+             save_dir = save_dir,
+             image_input = image_input,
+             target_face_count = target_face_count,
+             do_texture_mapping = do_texture_mapping
+         )
+         torch.cuda.empty_cache()
+         return save_dir
+
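An end-to-end sketch feeding the Image2Views output into Views2Mesh, assuming the svrm checkpoint referenced in app.py has been downloaded; paths and parameters mirror the defaults used there and the mesh files are written into save_folder by MV23DPredictor.

from PIL import Image
from infer.image_to_views import Image2Views
from infer.views_to_mesh import Views2Mesh

i2v = Image2Views(device="cuda:0", use_lite=False)
v23 = Views2Mesh("./svrm/configs/svrm.yaml", "weights/svrm/svrm.safetensors",
                 device="cuda:0", use_lite=False)

res_img, _ = i2v(Image.open("demos/example_000.png"), seed=0, steps=50)
v23(views_pil=res_img[0], cond_pil=res_img[1], seed=0,
    target_face_count=90000, do_texture_mapping=True, save_folder="outputs/test")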
mvd/__init__.py ADDED
File without changes
mvd/hunyuan3d_mvd_lite_pipeline.py ADDED
@@ -0,0 +1,493 @@
1
+ # Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
2
+ # The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
3
+
4
+ # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
5
+ # The below software and/or models in this distribution may have been
6
+ # modified by THL A29 Limited ("Tencent Modifications").
7
+ # All Tencent Modifications are Copyright (C) THL A29 Limited.
8
+
9
+ # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
10
+ # except for the third-party components listed below.
11
+ # Hunyuan 3D does not impose any additional limitations beyond what is outlined
12
+ # in the respective licenses of these third-party components.
13
+ # Users must comply with all terms and conditions of original licenses of these third-party
14
+ # components and must ensure that the usage of the third party components adheres to
15
+ # all relevant laws and regulations.
16
+
17
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
18
+ # their software and algorithms, including trained model weights, parameters (including
19
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
20
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
21
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
22
+
23
+ import math
24
+ import numpy
25
+ import torch
26
+ import inspect
27
+ import warnings
28
+ from PIL import Image
29
+ from einops import rearrange
30
+ import torch.nn.functional as F
31
+ from diffusers.utils.torch_utils import randn_tensor
32
+ from diffusers.configuration_utils import FrozenDict
33
+ from diffusers.image_processor import VaeImageProcessor
34
+ from typing import Any, Callable, Dict, List, Optional, Union
35
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
36
+ from diffusers.schedulers import KarrasDiffusionSchedulers
37
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
38
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
39
+ from diffusers import DDPMScheduler, EulerAncestralDiscreteScheduler, ImagePipelineOutput
40
+ from diffusers.loaders import (
41
+ FromSingleFileMixin,
42
+ LoraLoaderMixin,
43
+ TextualInversionLoaderMixin
44
+ )
45
+ from transformers import (
46
+ CLIPImageProcessor,
47
+ CLIPTextModel,
48
+ CLIPTokenizer,
49
+ CLIPVisionModelWithProjection
50
+ )
51
+ from diffusers.models.attention_processor import (
52
+ Attention,
53
+ AttnProcessor,
54
+ XFormersAttnProcessor,
55
+ AttnProcessor2_0
56
+ )
57
+
58
+ from .utils import to_rgb_image, white_out_background, recenter_img
59
+
60
+
61
+ EXAMPLE_DOC_STRING = """
62
+ Examples:
63
+ ```py
64
+ >>> import torch
65
+ >>> from here import Hunyuan3d_MVD_Qing_Pipeline
66
+
67
+ >>> pipe = Hunyuan3d_MVD_Qing_Pipeline.from_pretrained(
68
+ ... "Tencent-Hunyuan-3D/MVD-Qing", torch_dtype=torch.float16
69
+ ... )
70
+ >>> pipe.to("cuda")
71
+
72
+ >>> img = Image.open("demo.png")
73
+ >>> res_img = pipe(img).images[0]
74
+ """
75
+
76
+ def unscale_latents(latents): return latents / 0.75 + 0.22
77
+ def unscale_image (image ): return image / 0.50 * 0.80
78
+
79
+
80
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
81
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
82
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
83
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
84
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
85
+ return noise_cfg
86
+
87
+
88
+
89
+ class ReferenceOnlyAttnProc(torch.nn.Module):
90
+ # reference attention
91
+ def __init__(self, chained_proc, enabled=False, name=None):
92
+ super().__init__()
93
+ self.enabled = enabled
94
+ self.chained_proc = chained_proc
95
+ self.name = name
96
+
97
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, mode="w", ref_dict=None):
98
+ if encoder_hidden_states is None: encoder_hidden_states = hidden_states
99
+ if self.enabled:
100
+ if mode == 'w':
101
+ ref_dict[self.name] = encoder_hidden_states
102
+ elif mode == 'r':
103
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)
104
+ res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
105
+ return res
106
+
107
+
108
+ # class RowWiseAttnProcessor2_0:
109
+ # def __call__(self, attn,
110
+ # hidden_states,
111
+ # encoder_hidden_states=None,
112
+ # attention_mask=None,
113
+ # temb=None,
114
+ # num_views=6,
115
+ # *args,
116
+ # **kwargs):
117
+ # residual = hidden_states
118
+ # if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb)
119
+
120
+ # input_ndim = hidden_states.ndim
121
+ # if input_ndim == 4:
122
+ # batch_size, channel, height, width = hidden_states.shape
123
+ # hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
124
+
125
+ # if encoder_hidden_states is None:
126
+ # batch_size, sequence_length, _ = hidden_states.shape
127
+ # else:
128
+ # batch_size, sequence_length, _ = encoder_hidden_states.shape
129
+
130
+ # if attention_mask is not None:
131
+ # attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
132
+ # attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
133
+ # if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
134
+
135
+ # query = attn.to_q(hidden_states)
136
+ # if encoder_hidden_states is None: encoder_hidden_states = hidden_states
137
+ # elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
138
+
139
+ # # encoder_hidden_states [B, 6hw+hw, C] if ref att
140
+ # key = attn.to_k(encoder_hidden_states) # [B, Vhw+hw, C]
141
+ # value = attn.to_v(encoder_hidden_states) # [B, Vhw+hw, C]
142
+
143
+ # mv_flag = hidden_states.shape[1] < encoder_hidden_states.shape[1] and encoder_hidden_states.shape[1] != 77
144
+ # if mv_flag:
145
+ # target_size = int(math.sqrt(hidden_states.shape[1] // num_views))
146
+ # assert target_size ** 2 * num_views == hidden_states.shape[1]
147
+
148
+ # gen_key = key[:, :num_views*target_size*target_size, :]
149
+ # ref_key = key[:, num_views*target_size*target_size:, :]
150
+ # gen_value = value[:, :num_views*target_size*target_size, :]
151
+ # ref_value = value[:, num_views*target_size*target_size:, :]
152
+
153
+ # # rowwise attention
154
+ # query, gen_key, gen_value = \
155
+ # rearrange( query, "b (v1 h v2 w) c -> (b h) (v1 v2 w) c",
156
+ # v1=num_views//2, v2=2, h=target_size, w=target_size), \
157
+ # rearrange( gen_key, "b (v1 h v2 w) c -> (b h) (v1 v2 w) c",
158
+ # v1=num_views//2, v2=2, h=target_size, w=target_size), \
159
+ # rearrange(gen_value, "b (v1 h v2 w) c -> (b h) (v1 v2 w) c",
160
+ # v1=num_views//2, v2=2, h=target_size, w=target_size)
161
+
162
+ # inner_dim = key.shape[-1]
163
+ # ref_size = int(math.sqrt(ref_key.shape[1]))
164
+ # ref_key_expanded = ref_key.view(batch_size, 1, ref_size * ref_size, inner_dim)
165
+ # ref_key_expanded = ref_key_expanded.expand(-1, target_size, -1, -1).contiguous()
166
+ # ref_key_expanded = ref_key_expanded.view(batch_size * target_size, ref_size * ref_size, inner_dim)
167
+ # key = torch.cat([ gen_key, ref_key_expanded], dim=1)
168
+
169
+ # ref_value_expanded = ref_value.view(batch_size, 1, ref_size * ref_size, inner_dim)
170
+ # ref_value_expanded = ref_value_expanded.expand(-1, target_size, -1, -1).contiguous()
171
+ # ref_value_expanded = ref_value_expanded.view(batch_size * target_size, ref_size * ref_size, inner_dim)
172
+ # value = torch.cat([gen_value, ref_value_expanded], dim=1)
173
+ # h = target_size
174
+ # else:
175
+ # target_size = int(math.sqrt(hidden_states.shape[1]))
176
+ # h = 1
177
+ # num_views = 1
178
+
179
+ # inner_dim = key.shape[-1]
180
+ # head_dim = inner_dim // attn.heads
181
+
182
+ # query = query.view(batch_size * h, -1, attn.heads, head_dim).transpose(1, 2)
183
+ # key = key.view(batch_size * h, -1, attn.heads, head_dim).transpose(1, 2)
184
+ # value = value.view(batch_size * h, -1, attn.heads, head_dim).transpose(1, 2)
185
+
186
+ # hidden_states = F.scaled_dot_product_attention(query, key, value,
187
+ # attn_mask=attention_mask,
188
+ # dropout_p=0.0,
189
+ # is_causal=False)
190
+ # hidden_states = hidden_states.transpose(1, 2).reshape(batch_size * h,
191
+ # -1,
192
+ # attn.heads * head_dim).to(query.dtype)
193
+ # hidden_states = attn.to_out[1](attn.to_out[0](hidden_states))
194
+
195
+ # if mv_flag: hidden_states = rearrange(hidden_states, "(b h) (v1 v2 w) c -> b (v1 h v2 w) c",
196
+ # b=batch_size, v1=num_views//2,
197
+ # v2=2, h=target_size, w=target_size)
198
+
199
+ # if input_ndim == 4:
200
+ # hidden_states = hidden_states.transpose(-1, -2)
201
+ # hidden_states = hidden_states.reshape(batch_size,
202
+ # channel,
203
+ # target_size,
204
+ # target_size)
205
+ # if attn.residual_connection: hidden_states = hidden_states + residual
206
+ # hidden_states = hidden_states / attn.rescale_output_factor
207
+ # return hidden_states
208
+
209
+
210
+ class RefOnlyNoisedUNet(torch.nn.Module):
211
+ def __init__(self, unet, train_sched, val_sched):
212
+ super().__init__()
213
+ self.unet = unet
214
+ self.train_sched = train_sched
215
+ self.val_sched = val_sched
216
+
217
+ unet_lora_attn_procs = dict()
218
+ for name, _ in unet.attn_processors.items():
219
+ unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(AttnProcessor2_0(),
220
+ enabled=name.endswith("attn1.processor"),
221
+ name=name)
222
+ unet.set_attn_processor(unet_lora_attn_procs)
223
+
224
+ def __getattr__(self, name: str):
225
+ try:
226
+ return super().__getattr__(name)
227
+ except AttributeError:
228
+ return getattr(self.unet, name)
229
+
230
+ def forward(self, sample, timestep, encoder_hidden_states, *args, cross_attention_kwargs, **kwargs):
231
+ cond_lat = cross_attention_kwargs['cond_lat']
232
+ noise = torch.randn_like(cond_lat)
233
+ if self.training:
234
+ noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
235
+ noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)
236
+ else:
237
+ noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1))
238
+ noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1))
239
+
240
+ ref_dict = {}
241
+ self.unet(noisy_cond_lat,
242
+ timestep,
243
+ encoder_hidden_states,
244
+ *args,
245
+ cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
246
+ **kwargs)
247
+ return self.unet(sample,
248
+ timestep,
249
+ encoder_hidden_states,
250
+ *args,
251
+ cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict),
252
+ **kwargs)
253
+
254
+
255
+ class Hunyuan3d_MVD_Lite_Pipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin):
256
+ def __init__(
257
+ self,
258
+ vae: AutoencoderKL,
259
+ text_encoder: CLIPTextModel,
260
+ tokenizer: CLIPTokenizer,
261
+ unet: UNet2DConditionModel,
262
+ scheduler: KarrasDiffusionSchedulers,
263
+ vision_encoder: CLIPVisionModelWithProjection,
264
+ feature_extractor_clip: CLIPImageProcessor,
265
+ feature_extractor_vae: CLIPImageProcessor,
266
+ ramping_coefficients: Optional[list] = None,
267
+ safety_checker=None,
268
+ ):
269
+ DiffusionPipeline.__init__(self)
270
+ self.register_modules(
271
+ vae=vae,
272
+ unet=unet,
273
+ tokenizer=tokenizer,
274
+ scheduler=scheduler,
275
+ text_encoder=text_encoder,
276
+ vision_encoder=vision_encoder,
277
+ feature_extractor_vae=feature_extractor_vae,
278
+ feature_extractor_clip=feature_extractor_clip)
279
+ '''
280
+ rewrite the stable diffusion pipeline
281
+ vae: vae
282
+ unet: unet
283
+ tokenizer: tokenizer
284
+ scheduler: scheduler
285
+ text_encoder: text_encoder
286
+ vision_encoder: vision_encoder
287
+ feature_extractor_vae: feature_extractor_vae
288
+ feature_extractor_clip: feature_extractor_clip
289
+ '''
290
+ self.register_to_config(ramping_coefficients=ramping_coefficients)
291
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
292
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
293
+
294
+ def prepare_extra_step_kwargs(self, generator, eta):
295
+ extra_step_kwargs = {}
296
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
297
+ if accepts_eta: extra_step_kwargs["eta"] = eta
298
+
299
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
300
+ if accepts_generator: extra_step_kwargs["generator"] = generator
301
+ return extra_step_kwargs
302
+
303
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
304
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
305
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
306
+ latents = latents * self.scheduler.init_noise_sigma
307
+ return latents
308
+
309
+ @torch.no_grad()
310
+ def _encode_prompt(
311
+ self,
312
+ prompt,
313
+ device,
314
+ num_images_per_prompt,
315
+ do_classifier_free_guidance,
316
+ negative_prompt=None,
317
+ prompt_embeds: Optional[torch.FloatTensor] = None,
318
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
319
+ lora_scale: Optional[float] = None,
320
+ ):
321
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
322
+ self._lora_scale = lora_scale
323
+
324
+ if prompt is not None and isinstance(prompt, str):
325
+ batch_size = 1
326
+ elif prompt is not None and isinstance(prompt, list):
327
+ batch_size = len(prompt)
328
+ else:
329
+ batch_size = prompt_embeds.shape[0]
330
+
331
+ if prompt_embeds is None:
332
+ if isinstance(self, TextualInversionLoaderMixin):
333
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
334
+
335
+ text_inputs = self.tokenizer(
336
+ prompt,
337
+ padding="max_length",
338
+ max_length=self.tokenizer.model_max_length,
339
+ truncation=True,
340
+ return_tensors="pt",
341
+ )
342
+ text_input_ids = text_inputs.input_ids
343
+
344
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
345
+ attention_mask = text_inputs.attention_mask.to(device)
346
+ else:
347
+ attention_mask = None
348
+
349
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)[0]
350
+
351
+ if self.text_encoder is not None:
352
+ prompt_embeds_dtype = self.text_encoder.dtype
353
+ elif self.unet is not None:
354
+ prompt_embeds_dtype = self.unet.dtype
355
+ else:
356
+ prompt_embeds_dtype = prompt_embeds.dtype
357
+
358
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
359
+ bs_embed, seq_len, _ = prompt_embeds.shape
360
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
361
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
362
+
363
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
364
+ uncond_tokens: List[str]
365
+ if negative_prompt is None: uncond_tokens = [""] * batch_size
366
+ elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError()
367
+ elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt]
368
+ elif batch_size != len(negative_prompt): raise ValueError()
369
+ else: uncond_tokens = negative_prompt
370
+ if isinstance(self, TextualInversionLoaderMixin):
371
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
372
+
373
+ max_length = prompt_embeds.shape[1]
374
+ uncond_input = self.tokenizer(uncond_tokens,
375
+ padding="max_length",
376
+ max_length=max_length,
377
+ truncation=True,
378
+ return_tensors="pt")
379
+
380
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
381
+ attention_mask = uncond_input.attention_mask.to(device)
382
+ else:
383
+ attention_mask = None
384
+
385
+ negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(device), attention_mask=attention_mask)
386
+ negative_prompt_embeds = negative_prompt_embeds[0]
387
+
388
+ if do_classifier_free_guidance:
389
+ seq_len = negative_prompt_embeds.shape[1]
390
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
391
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
392
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
393
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
394
+
395
+ return prompt_embeds
396
+
397
+ @torch.no_grad()
398
+ def encode_condition_image(self, image: torch.Tensor): return self.vae.encode(image).latent_dist.sample()
399
+
400
+ @torch.no_grad()
401
+ def __call__(self, image=None,
402
+ width=640,
403
+ height=960,
404
+ num_inference_steps=75,
405
+ return_dict=True,
406
+ generator=None,
407
+ **kwargs):
408
+ batch_size = 1
409
+ num_images_per_prompt = 1
410
+ output_type = 'pil'
411
+ do_classifier_free_guidance = True
412
+ guidance_rescale = 0.
413
+ if isinstance(self.unet, UNet2DConditionModel):
414
+ self.unet = RefOnlyNoisedUNet(self.unet, None, self.scheduler).eval()
415
+
416
+ cond_image = recenter_img(image)
417
+ cond_image = to_rgb_image(cond_image)
418
+ image = cond_image
419
+ image_1 = self.feature_extractor_vae(images=image, return_tensors="pt").pixel_values
420
+ image_2 = self.feature_extractor_clip(images=image, return_tensors="pt").pixel_values
421
+ image_1 = image_1.to(device=self.vae.device, dtype=self.vae.dtype)
422
+ image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype)
423
+
424
+ cond_lat = self.encode_condition_image(image_1)
425
+ negative_lat = self.encode_condition_image(torch.zeros_like(image_1))
426
+ cond_lat = torch.cat([negative_lat, cond_lat])
427
+ cross_attention_kwargs = dict(cond_lat=cond_lat)
428
+
429
+ global_embeds = self.vision_encoder(image_2, output_hidden_states=False).image_embeds.unsqueeze(-2)
430
+ encoder_hidden_states = self._encode_prompt('', self.device, num_images_per_prompt, False)
431
+ ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
432
+ prompt_embeds = torch.cat([encoder_hidden_states, encoder_hidden_states + global_embeds * ramp])
433
+
434
+ device = self._execution_device
435
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
436
+ timesteps = self.scheduler.timesteps
437
+ num_channels_latents = self.unet.config.in_channels
438
+ latents = self.prepare_latents(batch_size * num_images_per_prompt,
439
+ num_channels_latents,
440
+ height,
441
+ width,
442
+ prompt_embeds.dtype,
443
+ device,
444
+ generator,
445
+ None)
446
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, 0.0)
447
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
448
+
449
+ # set adaptive cfg
450
+ # the image order is:
451
+ # [0, 60,
452
+ # 120, 180,
453
+ # 240, 300]
454
+ # the per-view cfg multipliers are 3, 2.5 / 2, 1.5 / 2, 2.5 (row by row, top-left to bottom-right)
455
+
456
+ tmp_guidance_scale = torch.ones_like(latents)
457
+ tmp_guidance_scale[:, :, :40, :40] = 3
458
+ tmp_guidance_scale[:, :, :40, 40:] = 2.5
459
+ tmp_guidance_scale[:, :, 40:80, :40] = 2
460
+ tmp_guidance_scale[:, :, 40:80, 40:] = 1.5
461
+ tmp_guidance_scale[:, :, 80:120, :40] = 2
462
+ tmp_guidance_scale[:, :, 80:120, 40:] = 2.5
463
+
464
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
465
+ for i, t in enumerate(timesteps):
466
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
467
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
468
+
469
+ noise_pred = self.unet(latent_model_input, t,
470
+ encoder_hidden_states=prompt_embeds,
471
+ cross_attention_kwargs=cross_attention_kwargs,
472
+ return_dict=False)[0]
473
+
474
+ adaptive_guidance_scale = (2 + 16 * (t / 1000) ** 5) / 3
475
+ if do_classifier_free_guidance:
476
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
477
+ noise_pred = noise_pred_uncond + \
478
+ tmp_guidance_scale * adaptive_guidance_scale * \
479
+ (noise_pred_text - noise_pred_uncond)
480
+
481
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
482
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
483
+
484
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
485
+ if i==len(timesteps)-1 or ((i+1)>num_warmup_steps and (i+1)%self.scheduler.order==0):
486
+ progress_bar.update()
487
+
488
+ latents = unscale_latents(latents)
489
+ image = unscale_image(self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0])
490
+ image = self.image_processor.postprocess(image, output_type='pil')[0]
491
+ image = [image, cond_image]
492
+ return ImagePipelineOutput(images=image) if return_dict else (image,)
493
+
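A minimal usage sketch for the lite multi-view pipeline defined above. The class name Hunyuan3d_MVD_Lite_Pipeline and the local checkpoint directory are assumptions for illustration; only the __call__ arguments and the [generated views, condition image] return value are taken from the code above.

import torch
from PIL import Image
from mvd.hunyuan3d_mvd_lite_pipeline import Hunyuan3d_MVD_Lite_Pipeline  # assumed class name

pipe = Hunyuan3d_MVD_Lite_Pipeline.from_pretrained(
    "weights/mvd_lite",               # placeholder checkpoint directory
    torch_dtype=torch.float16,
).to("cuda")

img = Image.open("./demos/example_000.png")        # RGBA image with the background removed
out = pipe(img, width=640, height=960, num_inference_steps=50,
           generator=torch.Generator("cuda").manual_seed(0))
views, cond = out.images                           # 640x960 grid of six views + preprocessed condition image
views.save("./outputs/test/views_lite.png")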
mvd/hunyuan3d_mvd_std_pipeline.py ADDED
@@ -0,0 +1,471 @@
1
+ # Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
2
+ # The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
3
+
4
+ # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
5
+ # The below software and/or models in this distribution may have been
6
+ # modified by THL A29 Limited ("Tencent Modifications").
7
+ # All Tencent Modifications are Copyright (C) THL A29 Limited.
8
+
9
+ # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
10
+ # except for the third-party components listed below.
11
+ # Hunyuan 3D does not impose any additional limitations beyond what is outlined
12
+ # in the respective licenses of these third-party components.
13
+ # Users must comply with all terms and conditions of original licenses of these third-party
14
+ # components and must ensure that the usage of the third party components adheres to
15
+ # all relevant laws and regulations.
16
+
17
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
18
+ # their software and algorithms, including trained model weights, parameters (including
19
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
20
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
21
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
22
+
23
+ import inspect
24
+ from typing import Any, Dict, List, Optional, Tuple, Union
26
+
27
+ import os
28
+ import torch
29
+ import numpy as np
30
+ from PIL import Image
31
+
32
+ import diffusers
33
+ from diffusers.image_processor import VaeImageProcessor
34
+ from diffusers.utils.import_utils import is_xformers_available
35
+ from diffusers.schedulers import KarrasDiffusionSchedulers
36
+ from diffusers.utils.torch_utils import randn_tensor
38
+ from diffusers.models.attention_processor import (
39
+ Attention,
40
+ AttnProcessor,
41
+ XFormersAttnProcessor,
42
+ AttnProcessor2_0
43
+ )
44
+ from diffusers import (
45
+ AutoencoderKL,
46
+ DDPMScheduler,
47
+ DiffusionPipeline,
48
+ EulerAncestralDiscreteScheduler,
49
+ UNet2DConditionModel,
50
+ ImagePipelineOutput
51
+ )
52
+ import transformers
53
+ from transformers import (
54
+ CLIPImageProcessor,
55
+ CLIPTextModel,
56
+ CLIPTokenizer,
57
+ CLIPVisionModelWithProjection,
58
+ CLIPTextModelWithProjection
59
+ )
60
+
61
+ from .utils import to_rgb_image, white_out_background, recenter_img
62
+
63
+ EXAMPLE_DOC_STRING = """
64
+ Examples:
65
+ ```py
66
+ >>> import torch
67
+ >>> from diffusers import Hunyuan3d_MVD_XL_Pipeline
68
+
69
+ >>> pipe = Hunyuan3d_MVD_XL_Pipeline.from_pretrained(
70
+ ... "Tencent-Hunyuan-3D/MVD-XL", torch_dtype=torch.float16
71
+ ... )
72
+ >>> pipe.to("cuda")
73
+
74
+ >>> img = Image.open("demo.png")
75
+ >>> res_img = pipe(img).images[0]
76
+ ```
77
+ """
78
+
79
+
80
+
81
+ def scale_latents(latents): return (latents - 0.22) * 0.75
82
+ def unscale_latents(latents): return (latents / 0.75) + 0.22
83
+ def scale_image(image): return (image - 0.5) / 0.5
84
+ def scale_image_2(image): return (image * 0.5) / 0.8
85
+ def unscale_image(image): return (image * 0.5) + 0.5
86
+ def unscale_image_2(image): return (image * 0.8) / 0.5
87
+
88
+
89
+
90
+
91
+ class ReferenceOnlyAttnProc(torch.nn.Module):
92
+ def __init__(self, chained_proc, enabled=False, name=None):
93
+ super().__init__()
94
+ self.enabled = enabled
95
+ self.chained_proc = chained_proc
96
+ self.name = name
97
+
98
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, mode="w", ref_dict=None):
99
+ encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states
100
+ if self.enabled:
101
+ if mode == 'w': ref_dict[self.name] = encoder_hidden_states
102
+ elif mode == 'r': encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)
103
+ else: raise Exception(f"mode should not be {mode}")
104
+ return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
105
+
106
+
107
+ class RefOnlyNoisedUNet(torch.nn.Module):
108
+ def __init__(self, unet, scheduler) -> None:
109
+ super().__init__()
110
+ self.unet = unet
111
+ self.scheduler = scheduler
112
+
113
+ unet_attn_procs = dict()
114
+ for name, _ in unet.attn_processors.items():
115
+ if torch.__version__ >= '2.0': default_attn_proc = AttnProcessor2_0()
116
+ elif is_xformers_available(): default_attn_proc = XFormersAttnProcessor()
117
+ else: default_attn_proc = AttnProcessor()
118
+ unet_attn_procs[name] = ReferenceOnlyAttnProc(
119
+ default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
120
+ )
121
+ unet.set_attn_processor(unet_attn_procs)
122
+
123
+ def __getattr__(self, name: str):
124
+ try:
125
+ return super().__getattr__(name)
126
+ except AttributeError:
127
+ return getattr(self.unet, name)
128
+
129
+ def forward(
130
+ self,
131
+ sample: torch.FloatTensor,
132
+ timestep: Union[torch.Tensor, float, int],
133
+ encoder_hidden_states: torch.Tensor,
134
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
135
+ class_labels: Optional[torch.Tensor] = None,
136
+ down_block_res_samples: Optional[Tuple[torch.Tensor]] = None,
137
+ mid_block_res_sample: Optional[Tuple[torch.Tensor]] = None,
138
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
139
+ return_dict: bool = True,
140
+ **kwargs
141
+ ):
142
+
143
+ dtype = self.unet.dtype
144
+
145
+ # cond_lat add same level noise
146
+ cond_lat = cross_attention_kwargs['cond_lat']
147
+ noise = torch.randn_like(cond_lat)
148
+
149
+ noisy_cond_lat = self.scheduler.add_noise(cond_lat, noise, timestep.reshape(-1))
150
+ noisy_cond_lat = self.scheduler.scale_model_input(noisy_cond_lat, timestep.reshape(-1))
151
+
152
+ ref_dict = {}
153
+
154
+ _ = self.unet(
155
+ noisy_cond_lat,
156
+ timestep,
157
+ encoder_hidden_states = encoder_hidden_states,
158
+ class_labels = class_labels,
159
+ cross_attention_kwargs = dict(mode="w", ref_dict=ref_dict),
160
+ added_cond_kwargs = added_cond_kwargs,
161
+ return_dict = return_dict,
162
+ **kwargs
163
+ )
164
+
165
+ res = self.unet(
166
+ sample,
167
+ timestep,
168
+ encoder_hidden_states,
169
+ class_labels=class_labels,
170
+ cross_attention_kwargs = dict(mode="r", ref_dict=ref_dict),
171
+ down_block_additional_residuals = [
172
+ sample.to(dtype=dtype) for sample in down_block_res_samples
173
+ ] if down_block_res_samples is not None else None,
174
+ mid_block_additional_residual = (
175
+ mid_block_res_sample.to(dtype=dtype)
176
+ if mid_block_res_sample is not None else None),
177
+ added_cond_kwargs = added_cond_kwargs,
178
+ return_dict = return_dict,
179
+ **kwargs
180
+ )
181
+ return res
182
+
183
+
184
+
185
+ class HunYuan3D_MVD_Std_Pipeline(diffusers.DiffusionPipeline):
186
+ def __init__(
187
+ self,
188
+ vae: AutoencoderKL,
189
+ unet: UNet2DConditionModel,
190
+ scheduler: KarrasDiffusionSchedulers,
191
+ feature_extractor_vae: CLIPImageProcessor,
192
+ vision_processor: CLIPImageProcessor,
193
+ vision_encoder: CLIPVisionModelWithProjection,
194
+ vision_encoder_2: CLIPVisionModelWithProjection,
195
+ ramping_coefficients: Optional[list] = None,
196
+ add_watermarker: Optional[bool] = None,
197
+ safety_checker = None,
198
+ ):
199
+ DiffusionPipeline.__init__(self)
200
+
201
+ self.register_modules(
202
+ vae=vae, unet=unet, scheduler=scheduler, safety_checker=None, feature_extractor_vae=feature_extractor_vae,
203
+ vision_processor=vision_processor, vision_encoder=vision_encoder, vision_encoder_2=vision_encoder_2,
204
+ )
205
+ self.register_to_config( ramping_coefficients = ramping_coefficients)
206
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
207
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
208
+ self.default_sample_size = self.unet.config.sample_size
209
+ self.watermark = None
210
+ self.prepare_init = False
211
+
212
+ def prepare(self):
213
+ assert isinstance(self.unet, UNet2DConditionModel), "unet should be UNet2DConditionModel"
214
+ self.unet = RefOnlyNoisedUNet(self.unet, self.scheduler).eval()
215
+ self.prepare_init = True
216
+
217
+ def encode_image(self, image: torch.Tensor, scale_factor: bool = False):
218
+ latent = self.vae.encode(image).latent_dist.sample()
219
+ return (latent * self.vae.config.scaling_factor) if scale_factor else latent
220
+
221
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
222
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
223
+ shape = (
224
+ batch_size,
225
+ num_channels_latents,
226
+ int(height) // self.vae_scale_factor,
227
+ int(width) // self.vae_scale_factor,
228
+ )
229
+ if isinstance(generator, list) and len(generator) != batch_size:
230
+ raise ValueError(
231
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
232
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
233
+ )
234
+
235
+ if latents is None:
236
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
237
+ else:
238
+ latents = latents.to(device)
239
+
240
+ # scale the initial noise by the standard deviation required by the scheduler
241
+ latents = latents * self.scheduler.init_noise_sigma
242
+ return latents
243
+
244
+ def _get_add_time_ids(
245
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
246
+ ):
247
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
248
+
249
+ passed_add_embed_dim = (
250
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
251
+ )
252
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
253
+
254
+ if expected_add_embed_dim != passed_add_embed_dim:
255
+ raise ValueError(
256
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, " \
257
+ f"but a vector of {passed_add_embed_dim} was created. The model has an incorrect config." \
258
+ f" Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
259
+ )
260
+
261
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
262
+ return add_time_ids
263
+
264
+ def prepare_extra_step_kwargs(self, generator, eta):
265
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
266
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
267
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
268
+ # and should be between [0, 1]
269
+
270
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
271
+ extra_step_kwargs = {}
272
+ if accepts_eta: extra_step_kwargs["eta"] = eta
273
+
274
+ # check if the scheduler accepts generator
275
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
276
+ if accepts_generator: extra_step_kwargs["generator"] = generator
277
+ return extra_step_kwargs
278
+
279
+ @property
280
+ def guidance_scale(self):
281
+ return self._guidance_scale
282
+
283
+ @property
284
+ def interrupt(self):
285
+ return self._interrupt
286
+
287
+ @property
288
+ def do_classifier_free_guidance(self):
289
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
290
+
291
+ @torch.no_grad()
292
+ def __call__(
293
+ self,
294
+ image: Image.Image = None,
295
+ guidance_scale = 2.0,
296
+ output_type: Optional[str] = "pil",
297
+ num_inference_steps: int = 50,
298
+ return_dict: bool = True,
299
+ eta: float = 0.0,
300
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
301
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
302
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
303
+ latent: torch.Tensor = None,
304
+ guidance_curve = None,
305
+ **kwargs
306
+ ):
307
+ if not self.prepare_init:
308
+ self.prepare()
309
+
310
+ here = dict(device=self.vae.device, dtype=self.vae.dtype)
311
+
312
+ batch_size = 1
313
+ num_images_per_prompt = 1
314
+ width, height = 512 * 2, 512 * 3
315
+ target_size = original_size = (height, width)
316
+
317
+ self._guidance_scale = guidance_scale
318
+ self._cross_attention_kwargs = cross_attention_kwargs
319
+ self._interrupt = False
320
+
321
+ device = self._execution_device
322
+
323
+ # Prepare timesteps
324
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
325
+ timesteps = self.scheduler.timesteps
326
+
327
+ # Prepare latent variables
328
+ num_channels_latents = self.unet.config.in_channels
329
+ latents = self.prepare_latents(
330
+ batch_size * num_images_per_prompt,
331
+ num_channels_latents,
332
+ height,
333
+ width,
334
+ self.vae.dtype,
335
+ device,
336
+ generator,
337
+ latents=latent,
338
+ )
339
+
340
+ # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
341
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
342
+
343
+
344
+ # Prepare added time ids & embeddings
345
+ text_encoder_projection_dim = 1280
346
+ add_time_ids = self._get_add_time_ids(
347
+ original_size,
348
+ crops_coords_top_left,
349
+ target_size,
350
+ dtype=self.vae.dtype,
351
+ text_encoder_projection_dim=text_encoder_projection_dim,
352
+ )
353
+ negative_add_time_ids = add_time_ids
354
+
355
+ # hw: preprocess
356
+ cond_image = recenter_img(image)
357
+ cond_image = to_rgb_image(cond_image)
358
+ image_vae = self.feature_extractor_vae(images=cond_image, return_tensors="pt").pixel_values.to(**here)
359
+ image_clip = self.vision_processor(images=cond_image, return_tensors="pt").pixel_values.to(**here)
360
+
361
+ # hw: get cond_lat from cond_img using vae
362
+ cond_lat = self.encode_image(image_vae, scale_factor=False)
363
+ negative_lat = self.encode_image(torch.zeros_like(image_vae), scale_factor=False)
364
+ cond_lat = torch.cat([negative_lat, cond_lat])
365
+
366
+ # hw: get visual global embedding using clip
367
+ global_embeds_1 = self.vision_encoder(image_clip, output_hidden_states=False).image_embeds.unsqueeze(-2)
368
+ global_embeds_2 = self.vision_encoder_2(image_clip, output_hidden_states=False).image_embeds.unsqueeze(-2)
369
+ global_embeds = torch.concat([global_embeds_1, global_embeds_2], dim=-1)
370
+
371
+ ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
372
+ prompt_embeds = self.uc_text_emb.to(**here)
373
+ pooled_prompt_embeds = self.uc_text_emb_2.to(**here)
374
+
375
+ prompt_embeds = prompt_embeds + global_embeds * ramp
376
+ add_text_embeds = pooled_prompt_embeds
377
+
378
+ if self.do_classifier_free_guidance:
379
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
380
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
381
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
382
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
383
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
384
+
385
+ prompt_embeds = prompt_embeds.to(device)
386
+ add_text_embeds = add_text_embeds.to(device)
387
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
388
+
389
+ # Denoising loop
390
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
391
+ timestep_cond = None
392
+ self._num_timesteps = len(timesteps)
393
+
394
+ if guidance_curve is None:
395
+ guidance_curve = lambda t: guidance_scale
396
+
397
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
398
+ for i, t in enumerate(timesteps):
399
+ if self.interrupt:
400
+ continue
401
+
402
+ # expand the latents if we are doing classifier free guidance
403
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
404
+
405
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
406
+
407
+ # predict the noise residual
408
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
409
+
410
+ noise_pred = self.unet(
411
+ latent_model_input,
412
+ t,
413
+ encoder_hidden_states=prompt_embeds,
414
+ timestep_cond=timestep_cond,
415
+ cross_attention_kwargs=dict(cond_lat=cond_lat),
416
+ added_cond_kwargs=added_cond_kwargs,
417
+ return_dict=False,
418
+ )[0]
419
+
420
+ # perform guidance
421
+
422
+ # cur_guidance_scale = self.guidance_scale
423
+ cur_guidance_scale = guidance_curve(t) # 1.5 + 2.5 * ((t/1000)**2)
424
+
425
+ if self.do_classifier_free_guidance:
426
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
427
+ noise_pred = noise_pred_uncond + cur_guidance_scale * (noise_pred_text - noise_pred_uncond)
428
+
429
+ # cur_guidance_scale_topleft = (cur_guidance_scale - 1.0) * 4 + 1.0
430
+ # noise_pred_top_left = noise_pred_uncond +
431
+ # cur_guidance_scale_topleft * (noise_pred_text - noise_pred_uncond)
432
+ # _, _, h, w = noise_pred.shape
433
+ # noise_pred[:, :, :h//3, :w//2] = noise_pred_top_left[:, :, :h//3, :w//2]
434
+
435
+ # compute the previous noisy sample x_t -> x_t-1
436
+ latents_dtype = latents.dtype
437
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
438
+
439
+ # call the callback, if provided
440
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
441
+ progress_bar.update()
442
+
443
+ latents = unscale_latents(latents)
444
+
445
+ if output_type=="latent":
446
+ image = latents
447
+ else:
448
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
449
+ image = unscale_image(unscale_image_2(image)).clamp(0, 1)
450
+ image = [
451
+ Image.fromarray((image[0]*255+0.5).clamp_(0, 255).permute(1, 2, 0).cpu().numpy().astype("uint8")),
452
+ # self.image_processor.postprocess(image, output_type=output_type)[0],
453
+ cond_image.resize((512, 512))
454
+ ]
455
+
456
+ if not return_dict: return (image,)
457
+ return ImagePipelineOutput(images=image)
458
+
459
+ def save_pretrained(self, save_directory):
460
+ # uc_text_emb.pt and uc_text_emb_2.pt are inferenced and saved in advance
461
+ super().save_pretrained(save_directory)
462
+ torch.save(self.uc_text_emb, os.path.join(save_directory, "uc_text_emb.pt"))
463
+ torch.save(self.uc_text_emb_2, os.path.join(save_directory, "uc_text_emb_2.pt"))
464
+
465
+ @classmethod
466
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
467
+ # uc_text_emb.pt and uc_text_emb_2.pt are inferenced and saved in advance
468
+ pipeline = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
469
+ pipeline.uc_text_emb = torch.load(os.path.join(pretrained_model_name_or_path, "uc_text_emb.pt"))
470
+ pipeline.uc_text_emb_2 = torch.load(os.path.join(pretrained_model_name_or_path, "uc_text_emb_2.pt"))
471
+ return pipeline
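A minimal usage sketch for HunYuan3D_MVD_Std_Pipeline as defined above. The checkpoint directory is a placeholder; from_pretrained() additionally expects uc_text_emb.pt and uc_text_emb_2.pt inside it, as implemented above.

import torch
from PIL import Image
from mvd.hunyuan3d_mvd_std_pipeline import HunYuan3D_MVD_Std_Pipeline

pipe = HunYuan3D_MVD_Std_Pipeline.from_pretrained(
    "weights/mvd_std",                # placeholder checkpoint directory
    torch_dtype=torch.float16,
).to("cuda")

img = Image.open("./demos/example_000.png")        # RGBA image with the background removed
out = pipe(img, guidance_scale=2.0, num_inference_steps=50,
           generator=torch.Generator("cuda").manual_seed(0))
grid, cond = out.images                            # 1024x1536 six-view grid + 512x512 condition image
grid.save("./outputs/test/views_std.png")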
mvd/utils.py ADDED
@@ -0,0 +1,85 @@
1
+ # Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
2
+ # The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
3
+
4
+ # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
5
+ # The below software and/or models in this distribution may have been
6
+ # modified by THL A29 Limited ("Tencent Modifications").
7
+ # All Tencent Modifications are Copyright (C) THL A29 Limited.
8
+
9
+ # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
10
+ # except for the third-party components listed below.
11
+ # Hunyuan 3D does not impose any additional limitations beyond what is outlined
12
+ # in the respective licenses of these third-party components.
13
+ # Users must comply with all terms and conditions of original licenses of these third-party
14
+ # components and must ensure that the usage of the third party components adheres to
15
+ # all relevant laws and regulations.
16
+
17
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
18
+ # their software and algorithms, including trained model weights, parameters (including
19
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
20
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
21
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
22
+
23
+ import numpy as np
24
+ from PIL import Image
25
+
26
+ def to_rgb_image(maybe_rgba: Image.Image):
27
+ '''
28
+ convert a PIL.Image to rgb mode with white background
29
+ maybe_rgba: PIL.Image
30
+ return: PIL.Image
31
+ '''
32
+ if maybe_rgba.mode == 'RGB':
33
+ return maybe_rgba
34
+ elif maybe_rgba.mode == 'RGBA':
35
+ rgba = maybe_rgba
36
+ img = np.random.randint(255, 256, size=[rgba.size[1], rgba.size[0], 3], dtype=np.uint8)
37
+ img = Image.fromarray(img, 'RGB')
38
+ img.paste(rgba, mask=rgba.getchannel('A'))
39
+ return img
40
+ else:
41
+ raise ValueError("Unsupported image type.", maybe_rgba.mode)
42
+
43
+ def white_out_background(pil_img, is_gray_fg=True):
44
+ data = pil_img.getdata()
45
+ new_data = []
46
+ # clamp near-white foreground pixels to light gray so they stay distinct from the white background
47
+ for r, g, b, a in data:
48
+ if a < 16:
49
+ new_data.append((255, 255, 255, 0)) # make the background white and fully transparent
50
+ else:
51
+ is_white = is_gray_fg and (r>235) and (g>235) and (b>235)
52
+ new_r = 235 if is_white else r
53
+ new_g = 235 if is_white else g
54
+ new_b = 235 if is_white else b
55
+ new_data.append((new_r, new_g, new_b, a))
56
+ pil_img.putdata(new_data)
57
+ return pil_img
58
+
59
+ def recenter_img(img, size=512, color=(255,255,255)):
60
+ img = white_out_background(img)
61
+ mask = np.array(img)[..., 3]
62
+ image = np.array(img)[..., :3]
63
+
64
+ H, W, C = image.shape
65
+ coords = np.nonzero(mask)
66
+ x_min, x_max = coords[0].min(), coords[0].max()
67
+ y_min, y_max = coords[1].min(), coords[1].max()
68
+ h = x_max - x_min
69
+ w = y_max - y_min
70
+ if h == 0 or w == 0: raise ValueError("recenter_img: foreground mask is empty")
71
+ roi = image[x_min:x_max, y_min:y_max]
72
+
73
+ border_ratio = 0.15 # 0.2
74
+ pad_h = int(h * border_ratio)
75
+ pad_w = int(w * border_ratio)
76
+
77
+ result_tmp = np.full((h + pad_h, w + pad_w, C), color, dtype=np.uint8)
78
+ result_tmp[pad_h // 2: pad_h // 2 + h, pad_w // 2: pad_w // 2 + w] = roi
79
+
80
+ cur_h, cur_w = result_tmp.shape[:2]
81
+ side = max(cur_h, cur_w)
82
+ result = np.full((side, side, C), color, dtype=np.uint8)
83
+ result[(side-cur_h)//2:(side-cur_h)//2+cur_h, (side-cur_w)//2:(side - cur_w)//2+cur_w,:] = result_tmp
84
+ result = Image.fromarray(result)
85
+ return result.resize((size, size), Image.LANCZOS) if size else result
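A short sketch of the preprocessing chain these helpers implement, using the demo image shipped with the repo; recenter_img needs an RGBA input because it reads the alpha channel as the foreground mask. The output path is arbitrary.

from PIL import Image
from mvd.utils import to_rgb_image, recenter_img

rgba = Image.open("./demos/example_000.png").convert("RGBA")
centered = recenter_img(rgba, size=512)   # pad the object by ~15% and paste it centered on a white square
rgb = to_rgb_image(centered)              # already RGB at this point, so this is a pass-through
rgb.save("./outputs/cond_512.png")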
requirements.txt ADDED
@@ -0,0 +1,22 @@
1
+ --find-links https://download.pytorch.org/whl/cu118
2
+ torch==2.2.0
3
+ torchvision==0.17.0
4
+ diffusers
5
+ transformers
6
+ rembg
7
+ tqdm
8
+ omegaconf
9
+ matplotlib
10
+ opencv-python
11
+ imageio
12
+ jaxtyping
13
+ einops
14
+ SentencePiece
15
+ accelerate
16
+ trimesh
17
+ PyMCubes
18
+ xatlas
19
+ libigl
20
+ git+https://github.com/facebookresearch/pytorch3d
21
+ git+https://github.com/NVlabs/nvdiffrast
22
+ open3d
scripts/image_to_3d.sh ADDED
@@ -0,0 +1,8 @@
1
+ # image to 3d
2
+
3
+ python main.py \
4
+ --image_prompt ./demos/example_000.png \
5
+ --save_folder ./outputs/test/ \
6
+ --max_faces_num 90000 \
7
+ --do_texture \
8
+ --do_render
scripts/image_to_3d_demo.sh ADDED
@@ -0,0 +1,8 @@
1
+ # image to 3d
2
+
3
+ python main.py \
4
+ --image_prompt ./demos/example_000.png \
5
+ --save_folder ./outputs/test/ \
6
+ --max_faces_num 90000 \
7
+ --do_texture_mapping \
8
+ --do_render
scripts/image_to_3d_fast.sh ADDED
@@ -0,0 +1,6 @@
1
+ # image to 3d fast
2
+ python main.py \
3
+ --image_prompt ./demos/example_000.png \
4
+ --save_folder ./outputs/test/ \
5
+ --max_faces_num 10000 \
6
+ --use_lite
scripts/image_to_3d_fast_demo.sh ADDED
@@ -0,0 +1,6 @@
1
+ # image to 3d fast
2
+ python main.py \
3
+ --image_prompt ./demos/example_000.png \
4
+ --save_folder ./outputs/test/ \
5
+ --max_faces_num 10000 \
6
+ --use_lite
scripts/text_to_3d.sh ADDED
@@ -0,0 +1,7 @@
1
+ # text to 3d
2
+ python main.py \
3
+ --text_prompt "a lovely cat" \
4
+ --save_folder ./outputs/test/ \
5
+ --max_faces_num 90000 \
6
+ --do_texture \
7
+ --do_render
scripts/text_to_3d_demo.sh ADDED
@@ -0,0 +1,7 @@
1
+ # text to 3d
2
+ python main.py \
3
+ --text_prompt "a lovely rabbit" \
4
+ --save_folder ./outputs/test/ \
5
+ --max_faces_num 90000 \
6
+ --do_texture_mapping \
7
+ --do_render
scripts/text_to_3d_fast.sh ADDED
@@ -0,0 +1,6 @@
1
+ # text to 3d fast
2
+ python main.py \
3
+ --text_prompt "一个广式茶杯" \
4
+ --save_folder ./outputs/test/ \
5
+ --max_faces_num 10000 \
6
+ --use_lite
scripts/text_to_3d_fast_demo.sh ADDED
@@ -0,0 +1,6 @@
1
+ # text to 3d fast
2
+ python main.py \
3
+ --text_prompt "一个广式茶杯" \
4
+ --save_folder ./outputs/test/ \
5
+ --max_faces_num 10000 \
6
+ --use_lite
svrm/.DS_Store ADDED
Binary file (6.15 kB).
 
svrm/configs/2024-10-24T22-36-18-project.yaml ADDED
@@ -0,0 +1,32 @@
1
+ model:
2
+ base_learning_rate: 3.0e-05
3
+ target: svrm.ldm.models.svrm.SVRMModel
4
+ params:
5
+
6
+ img_encoder_config:
7
+ target: svrm.ldm.modules.encoders.dinov2_mod.FrozenDinoV2ImageEmbedder
8
+ params:
9
+ version: dinov2_vitb14
10
+
11
+ img_to_triplane_config:
12
+ target: svrm.ldm.modules.translator.img_to_triplane.ImgToTriplaneModel
13
+ params:
14
+ pos_emb_size: 64
15
+ pos_emb_dim: 1024
16
+ cam_cond_dim: 20
17
+ n_heads: 16
18
+ d_head: 64
19
+ depth: 16
20
+ context_dim: 768
21
+ triplane_dim: 120
22
+ use_fp16: true
23
+ use_bf16: false
24
+ upsample_time: 2
25
+
26
+ render_config:
27
+ target: svrm.ldm.modules.rendering_neus.synthesizer.TriplaneSynthesizer
28
+ params:
29
+ triplane_dim: 120
30
+ samples_per_ray: 128
31
+
32
+
svrm/configs/svrm.yaml ADDED
@@ -0,0 +1,32 @@
1
+ model:
2
+ base_learning_rate: 3.0e-05
3
+ target: svrm.ldm.models.svrm.SVRMModel
4
+ params:
5
+
6
+ img_encoder_config:
7
+ target: svrm.ldm.modules.encoders.dinov2_mod.FrozenDinoV2ImageEmbedder
8
+ params:
9
+ version: dinov2_vitb14
10
+
11
+ img_to_triplane_config:
12
+ target: svrm.ldm.modules.translator.img_to_triplane.ImgToTriplaneModel
13
+ params:
14
+ pos_emb_size: 64
15
+ pos_emb_dim: 1024
16
+ cam_cond_dim: 20
17
+ n_heads: 16
18
+ d_head: 64
19
+ depth: 16
20
+ context_dim: 768
21
+ triplane_dim: 120
22
+ use_fp16: true
23
+ use_bf16: false
24
+ upsample_time: 2
25
+
26
+ render_config:
27
+ target: svrm.ldm.modules.rendering_neus.synthesizer.TriplaneSynthesizer
28
+ params:
29
+ triplane_dim: 120
30
+ samples_per_ray: 128
31
+
32
+
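Both YAML files above describe the same SVRM configuration. Below is a minimal sketch of turning such a config into a model, assuming the usual instantiate_from_config helper lives in svrm/ldm/util.py (it is imported that way by svrm/ldm/models/svrm.py later in this commit) and that pretrained weights are loaded separately.

import torch
from omegaconf import OmegaConf
from svrm.ldm.util import instantiate_from_config   # assumed helper module

config = OmegaConf.load("svrm/configs/svrm.yaml")
model = instantiate_from_config(config.model)        # builds SVRMModel with encoder, translator and renderer
model = model.half().eval().to("cuda:0")             # loading the pretrained state_dict is omitted here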
svrm/ldm/.DS_Store ADDED
Binary file (6.15 kB).
 
svrm/ldm/models/svrm.py ADDED
@@ -0,0 +1,263 @@
1
+ # Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
2
+ # The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
3
+
4
+ # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
5
+ # The below software and/or models in this distribution may have been
6
+ # modified by THL A29 Limited ("Tencent Modifications").
7
+ # All Tencent Modifications are Copyright (C) THL A29 Limited.
8
+
9
+ # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
10
+ # except for the third-party components listed below.
11
+ # Hunyuan 3D does not impose any additional limitations beyond what is outlined
12
+ # in the repsective licenses of these third-party components.
13
+ # Users must comply with all terms and conditions of original licenses of these third-party
14
+ # components and must ensure that the usage of the third party components adheres to
15
+ # all relevant laws and regulations.
16
+
17
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
18
+ # their software and algorithms, including trained model weights, parameters (including
19
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
20
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
21
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
22
+
23
+ import os
24
+ import time
25
+ import math
26
+ import cv2
27
+ import numpy as np
28
+ import itertools
29
+ import shutil
30
+ from tqdm import tqdm
31
+ import torch
32
+ import torch.nn.functional as F
33
+ from einops import rearrange
34
+ try:
35
+ import trimesh
36
+ import mcubes
37
+ import xatlas
38
+ import open3d as o3d
39
+ except ImportError as e:
40
+ raise ImportError("failed to import 3D mesh libraries (trimesh, mcubes, xatlas, open3d)") from e
41
+
42
+ from ..modules.rendering_neus.mesh import Mesh
43
+ from ..modules.rendering_neus.rasterize import NVDiffRasterizerContext
44
+
45
+ from ..utils.ops import scale_tensor
46
+ from ..util import count_params, instantiate_from_config
47
+ from ..vis_util import render
48
+
49
+
50
+ def unwrap_uv(v_pos, t_pos_idx):
51
+ print("Using xatlas to perform UV unwrapping, may take a while ...")
52
+ atlas = xatlas.Atlas()
53
+ atlas.add_mesh(v_pos, t_pos_idx)
54
+ atlas.generate(xatlas.ChartOptions(), xatlas.PackOptions())
55
+ _, indices, uvs = atlas.get_mesh(0)
56
+ indices = indices.astype(np.int64, casting="same_kind")
57
+ return uvs, indices
58
+
59
+
60
+ def uv_padding(image, hole_mask, uv_padding_size = 2):
61
+ return cv2.inpaint(
62
+ (image.detach().cpu().numpy() * 255).astype(np.uint8),
63
+ (hole_mask.detach().cpu().numpy() * 255).astype(np.uint8),
64
+ uv_padding_size,
65
+ cv2.INPAINT_TELEA
66
+ )
67
+
68
+ def refine_mesh(vtx_refine, faces_refine):
69
+ mesh = o3d.geometry.TriangleMesh(
70
+ vertices=o3d.utility.Vector3dVector(vtx_refine),
71
+ triangles=o3d.utility.Vector3iVector(faces_refine))
72
+
73
+ mesh = mesh.remove_unreferenced_vertices()
74
+ mesh = mesh.remove_duplicated_triangles()
75
+ mesh = mesh.remove_duplicated_vertices()
76
+
77
+ voxel_size = max(mesh.get_max_bound() - mesh.get_min_bound())
78
+
79
+ mesh = mesh.simplify_vertex_clustering(
80
+ voxel_size=0.007, # 0.005
81
+ contraction=o3d.geometry.SimplificationContraction.Average)
82
+
83
+ mesh = mesh.filter_smooth_simple(number_of_iterations=2)
84
+
85
+ vtx_refine = np.asarray(mesh.vertices).astype(np.float32)
86
+ faces_refine = np.asarray(mesh.triangles)
87
+ return vtx_refine, faces_refine, mesh
88
+
89
+
90
+ class SVRMModel(torch.nn.Module):
91
+ def __init__(
92
+ self,
93
+ img_encoder_config,
94
+ img_to_triplane_config,
95
+ render_config,
96
+ device = "cuda:0",
97
+ **kwargs
98
+ ):
99
+ super().__init__()
100
+
101
+ self.img_encoder = instantiate_from_config(img_encoder_config).half()
102
+ self.img_to_triplane_decoder = instantiate_from_config(img_to_triplane_config).half()
103
+ self.render = instantiate_from_config(render_config).half()
104
+ self.device = device
105
+ count_params(self, verbose=True)
106
+
107
+ @torch.no_grad()
108
+ def export_mesh_with_uv(
109
+ self,
110
+ data,
111
+ mesh_size: int = 384,
112
+ ctx = None,
113
+ context_type = 'cuda',
114
+ texture_res = 1024,
115
+ target_face_count = 10000,
116
+ do_texture_mapping = True,
117
+ out_dir = 'outputs/test'
118
+ ):
119
+ """
120
+ color_type: 0 for ray texture, 1 for vertices texture
121
+ """
122
+ st = time.time()
123
+ here = {'device': self.device, 'dtype': torch.float16}
124
+ input_view_image = data["input_view"].to(**here) # [b, m, c, h, w]
125
+ input_view_cam = data["input_view_cam"].to(**here) # [b, m, 20]
126
+
127
+ batch_size, input_view_num, *_ = input_view_image.shape
128
+ assert batch_size == 1, "batch size should be 1"
129
+
130
+ input_view_image = rearrange(input_view_image, 'b m c h w -> (b m) c h w')
131
+ input_view_cam = rearrange(input_view_cam, 'b m d -> (b m) d')
132
+ input_view_feat = self.img_encoder(input_view_image, input_view_cam)
133
+ input_view_feat = rearrange(input_view_feat, '(b m) l d -> b (l m) d', m=input_view_num)
134
+
135
+ # -- decoder
136
+ torch.cuda.empty_cache()
137
+ triplane_gen = self.img_to_triplane_decoder(input_view_feat) # [b, 3, tri_dim, h, w]
138
+ del input_view_feat
139
+ torch.cuda.empty_cache()
140
+
141
+ # --- triplane nerf render
142
+
143
+ cur_triplane = triplane_gen[0:1]
144
+
145
+ aabb = torch.tensor([[-0.6, -0.6, -0.6], [0.6, 0.6, 0.6]]).unsqueeze(0).to(**here)
146
+ grid_out = self.render.forward_grid(planes=cur_triplane, grid_size=mesh_size, aabb=aabb)
147
+
148
+ print(f"=====> LRM forward time: {time.time() - st}")
149
+ st = time.time()
150
+
151
+ vtx, faces = mcubes.marching_cubes(0. - grid_out['sdf'].squeeze(0).squeeze(-1).cpu().float().numpy(), 0)
152
+
153
+ bbox = aabb[0].cpu().numpy()
154
+ vtx = vtx / (mesh_size - 1)
155
+ vtx = vtx * (bbox[1] - bbox[0]) + bbox[0]
156
+
157
+ # refine mesh
158
+ vtx_refine, faces_refine, mesh = refine_mesh(vtx, faces)
159
+
160
+ # reduce faces
161
+ if faces_refine.shape[0] > target_face_count:
162
+ print(f"reduce face: {faces_refine.shape[0]} -> {target_face_count}")
163
+ mesh = o3d.geometry.TriangleMesh(
164
+ vertices = o3d.utility.Vector3dVector(vtx_refine),
165
+ triangles = o3d.utility.Vector3iVector(faces_refine)
166
+ )
167
+
168
+ # Function to simplify mesh using Quadric Error Metric Decimation by Garland and Heckbert
169
+ mesh = mesh.simplify_quadric_decimation(target_face_count, boundary_weight=1.0)
170
+
171
+ mesh = Mesh(
172
+ v_pos = torch.from_numpy(np.asarray(mesh.vertices)).to(self.device),
173
+ t_pos_idx = torch.from_numpy(np.asarray(mesh.triangles)).to(self.device),
174
+ v_rgb = torch.from_numpy(np.asarray(mesh.vertex_colors)).to(self.device)
175
+ )
176
+ vtx_refine = mesh.v_pos.cpu().numpy()
177
+ faces_refine = mesh.t_pos_idx.cpu().numpy()
178
+
179
+ vtx_colors = self.render.forward_points(cur_triplane, torch.tensor(vtx_refine).unsqueeze(0).to(**here))
180
+ vtx_colors = vtx_colors['rgb'].float().squeeze(0).cpu().numpy()
181
+
182
+ color_ratio = 0.8 # increase brightness
183
+ with open(f'{out_dir}/mesh_with_colors.obj', 'w') as fid:
184
+ verts = vtx_refine[:, [1,2,0]]
185
+ for pidx, pp in enumerate(verts):
186
+ color = vtx_colors[pidx]
187
+ color = [color[0]**color_ratio, color[1]**color_ratio, color[2]**color_ratio]
188
+ fid.write('v %f %f %f %f %f %f\n' % (pp[0], pp[1], pp[2], color[0], color[1], color[2]))
189
+ for i, f in enumerate(faces_refine):
190
+ f1 = f + 1
191
+ fid.write('f %d %d %d\n' % (f1[0], f1[1], f1[2]))
192
+
193
+ mesh = trimesh.load_mesh(f'{out_dir}/mesh_with_colors.obj')
194
+ print(f"=====> generate mesh with vertex shading time: {time.time() - st}")
195
+ st = time.time()
196
+
197
+ if not do_texture_mapping:
198
+ shutil.copy(f'{out_dir}/mesh_with_colors.obj', f'{out_dir}/mesh.obj')
199
+ mesh.export(f'{out_dir}/mesh.glb', file_type='glb')
200
+ return None
201
+
202
+ ########## export texture ########
203
+ st = time.time()
204
+
205
+ # uv unwrap
206
+ vtx_tex, t_tex_idx = unwrap_uv(vtx_refine, faces_refine)
207
+ vtx_refine = torch.from_numpy(vtx_refine).to(self.device)
208
+ faces_refine = torch.from_numpy(faces_refine).to(self.device)
209
+ t_tex_idx = torch.from_numpy(t_tex_idx).to(self.device)
210
+ uv_clip = torch.from_numpy(vtx_tex * 2.0 - 1.0).to(self.device)
211
+
212
+ # rasterize
213
+ ctx = NVDiffRasterizerContext(context_type, cur_triplane.device) if ctx is None else ctx
214
+ rast = ctx.rasterize_one(
215
+ torch.cat([
216
+ uv_clip,
217
+ torch.zeros_like(uv_clip[..., 0:1]),
218
+ torch.ones_like(uv_clip[..., 0:1])
219
+ ], dim=-1),
220
+ t_tex_idx,
221
+ (texture_res, texture_res)
222
+ )[0]
223
+ hole_mask = ~(rast[:, :, 3] > 0)
224
+
225
+ # Interpolate world space position
226
+ gb_pos = ctx.interpolate_one(vtx_refine, rast[None, ...], faces_refine)[0][0]
227
+ with torch.no_grad():
228
+ gb_mask_pos_scale = scale_tensor(gb_pos.unsqueeze(0).view(1, -1, 3), (-1, 1), (-1, 1))
229
+ tex_map = self.render.forward_points(cur_triplane, gb_mask_pos_scale)['rgb']
230
+ tex_map = tex_map.float().squeeze(0) # (0, 1)
231
+ tex_map = tex_map.view((texture_res, texture_res, 3))
232
+ img = uv_padding(tex_map, hole_mask)
233
+ img = ((img/255.0) ** color_ratio) * 255 # increase brightness
234
+ img = img.clip(0, 255).astype(np.uint8)
235
+
236
+ verts = vtx_refine.cpu().numpy()[:, [1,2,0]]
237
+ faces = faces_refine.cpu().numpy()
238
+
239
+ with open(f'{out_dir}/texture.mtl', 'w') as fid:
240
+ fid.write('newmtl material_0\n')
241
+ fid.write("Ka 1.000 1.000 1.000\n")
242
+ fid.write("Kd 1.000 1.000 1.000\n")
243
+ fid.write("Ks 0.000 0.000 0.000\n")
244
+ fid.write("d 1.0\n")
245
+ fid.write("illum 2\n")
246
+ fid.write(f'map_Kd texture.png\n')
247
+
248
+ with open(f'{out_dir}/mesh.obj', 'w') as fid:
249
+ fid.write(f'mtllib texture.mtl\n')
250
+ for pidx, pp in enumerate(verts):
251
+ fid.write('v %f %f %f\n' % (pp[0], pp[1], pp[2]))
252
+ for pidx, pp in enumerate(vtx_tex):
253
+ fid.write('vt %f %f\n' % (pp[0], 1 - pp[1]))
254
+ fid.write('usemtl material_0\n')
255
+ for i, f in enumerate(faces):
256
+ f1 = f + 1
257
+ f2 = t_tex_idx[i] + 1
258
+ fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2],))
259
+
260
+ cv2.imwrite(f'{out_dir}/texture.png', img[..., [2, 1, 0]])
261
+ mesh = trimesh.load_mesh(f'{out_dir}/mesh.obj')
262
+ mesh.export(f'{out_dir}/mesh.glb', file_type='glb')
263
+
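A minimal sketch of driving SVRMModel.export_mesh_with_uv. The random tensors stand in for the multi-view images and 20-dim camera embeddings that the MVD stage and the repo's view preprocessing normally provide; the view count and resolution are assumptions (504 is simply a multiple of DINOv2's 14-pixel patch size).

import os
import torch
from omegaconf import OmegaConf
from svrm.ldm.util import instantiate_from_config    # assumed helper, see the config sketch above

model = instantiate_from_config(OmegaConf.load("svrm/configs/svrm.yaml").model).to("cuda:0").eval()

views = torch.rand(1, 6, 3, 504, 504, device="cuda:0", dtype=torch.float16)  # [batch, views, C, H, W] placeholders
cams = torch.rand(1, 6, 20, device="cuda:0", dtype=torch.float16)            # [batch, views, 20] camera conditioning

os.makedirs("outputs/test", exist_ok=True)
model.export_mesh_with_uv(
    {"input_view": views, "input_view_cam": cams},
    mesh_size=384,
    target_face_count=90000,
    do_texture_mapping=True,
    out_dir="outputs/test",    # mesh_with_colors.obj, mesh.obj, mesh.glb and texture.png are written here
)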
svrm/ldm/modules/attention.py ADDED
@@ -0,0 +1,457 @@
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn, einsum
6
+ from einops import rearrange, repeat
7
+ import numpy as np
+ from typing import Any, Optional
8
+
9
+ FLASH_IS_AVAILABLE = XFORMERS_IS_AVAILBLE = False
10
+ try:
11
+ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
12
+ FLASH_IS_AVAILABLE = True
13
+ except:
14
+ try:
15
+ import xformers
16
+ import xformers.ops
17
+ XFORMERS_IS_AVAILBLE = True
18
+ except:
19
+ pass
20
+
21
+ def exists(val):
22
+ return val is not None
23
+
24
+
25
+ def uniq(arr):
26
+ return {el: True for el in arr}.keys()
27
+
28
+
29
+ def default(val, d):
30
+ if exists(val):
31
+ return val
32
+ return d() if isfunction(d) else d
33
+
34
+
35
+ def max_neg_value(t):
36
+ return -torch.finfo(t.dtype).max
37
+
38
+
39
+ def init_(tensor):
40
+ dim = tensor.shape[-1]
41
+ std = 1 / math.sqrt(dim)
42
+ tensor.uniform_(-std, std)
43
+ return tensor
44
+
45
+ def checkpoint(func, inputs, params, flag):
46
+ """
47
+ Evaluate a function without caching intermediate activations, allowing for
48
+ reduced memory at the expense of extra compute in the backward pass.
49
+ :param func: the function to evaluate.
50
+ :param inputs: the argument sequence to pass to `func`.
51
+ :param params: a sequence of parameters `func` depends on but does not
52
+ explicitly take as arguments.
53
+ :param flag: if False, disable gradient checkpointing.
54
+ """
55
+ if flag:
56
+ args = tuple(inputs) + tuple(params)
57
+ return CheckpointFunction.apply(func, len(inputs), *args)
58
+ else:
59
+ return func(*inputs)
60
+
61
+
62
+ class CheckpointFunction(torch.autograd.Function):
63
+ @staticmethod
64
+ def forward(ctx, run_function, length, *args):
65
+ ctx.run_function = run_function
66
+ ctx.input_tensors = list(args[:length])
67
+ ctx.input_params = list(args[length:])
68
+
69
+ with torch.no_grad():
70
+ output_tensors = ctx.run_function(*ctx.input_tensors)
71
+ return output_tensors
72
+
73
+ @staticmethod
74
+ def backward(ctx, *output_grads):
75
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
76
+ with torch.enable_grad():
77
+ # Fixes a bug where the first op in run_function modifies the
78
+ # Tensor storage in place, which is not allowed for detach()'d
79
+ # Tensors.
80
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
81
+ output_tensors = ctx.run_function(*shallow_copies)
82
+ input_grads = torch.autograd.grad(
83
+ output_tensors,
84
+ ctx.input_tensors + ctx.input_params,
85
+ output_grads,
86
+ allow_unused=True,
87
+ )
88
+ del ctx.input_tensors
89
+ del ctx.input_params
90
+ del output_tensors
91
+ return (None, None) + input_grads
92
+
93
+
94
+ # feedforward
95
+ class GEGLU(nn.Module):
96
+ def __init__(self, dim_in, dim_out):
97
+ super().__init__()
98
+ self.proj = nn.Linear(dim_in, dim_out * 2)
99
+
100
+ def forward(self, x):
101
+ x, gate = self.proj(x).chunk(2, dim=-1)
102
+ return x * F.gelu(gate)
103
+
104
+
105
+ class FeedForward(nn.Module):
106
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
107
+ super().__init__()
108
+ inner_dim = int(dim * mult)
109
+ dim_out = default(dim_out, dim)
110
+ project_in = nn.Sequential(
111
+ nn.Linear(dim, inner_dim),
112
+ nn.GELU()
113
+ ) if not glu else GEGLU(dim, inner_dim)
114
+
115
+ self.net = nn.Sequential(
116
+ project_in,
117
+ nn.Dropout(dropout),
118
+ nn.Linear(inner_dim, dim_out)
119
+ )
120
+
121
+ def forward(self, x):
122
+ return self.net(x)
123
+
124
+
125
+ def zero_module(module):
126
+ """
127
+ Zero out the parameters of a module and return it.
128
+ """
129
+ for p in module.parameters():
130
+ p.detach().zero_()
131
+ return module
132
+
133
+
134
+ def Normalize(in_channels):
135
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
136
+
137
+
138
+ class LinearAttention(nn.Module):
139
+ def __init__(self, dim, heads=4, dim_head=32):
140
+ super().__init__()
141
+ self.heads = heads
142
+ hidden_dim = dim_head * heads
143
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
144
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
145
+
146
+ def forward(self, x):
147
+ b, c, h, w = x.shape
148
+ qkv = self.to_qkv(x)
149
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
150
+ k = k.softmax(dim=-1)
151
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
152
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
153
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
154
+ return self.to_out(out)
155
+
156
+
157
+ class SpatialSelfAttention(nn.Module):
158
+ def __init__(self, in_channels):
159
+ super().__init__()
160
+ self.in_channels = in_channels
161
+
162
+ self.norm = Normalize(in_channels)
163
+ self.q = torch.nn.Conv2d(in_channels,
164
+ in_channels,
165
+ kernel_size=1,
166
+ stride=1,
167
+ padding=0)
168
+ self.k = torch.nn.Conv2d(in_channels,
169
+ in_channels,
170
+ kernel_size=1,
171
+ stride=1,
172
+ padding=0)
173
+ self.v = torch.nn.Conv2d(in_channels,
174
+ in_channels,
175
+ kernel_size=1,
176
+ stride=1,
177
+ padding=0)
178
+ self.proj_out = torch.nn.Conv2d(in_channels,
179
+ in_channels,
180
+ kernel_size=1,
181
+ stride=1,
182
+ padding=0)
183
+
184
+ def forward(self, x):
185
+ h_ = x
186
+ h_ = self.norm(h_)
187
+ q = self.q(h_)
188
+ k = self.k(h_)
189
+ v = self.v(h_)
190
+
191
+ # compute attention
192
+ b,c,h,w = q.shape
193
+ q = rearrange(q, 'b c h w -> b (h w) c')
194
+ k = rearrange(k, 'b c h w -> b c (h w)')
195
+ w_ = torch.einsum('bij,bjk->bik', q, k)
196
+
197
+ w_ = w_ * (int(c)**(-0.5))
198
+ w_ = torch.nn.functional.softmax(w_, dim=2)
199
+
200
+ # attend to values
201
+ v = rearrange(v, 'b c h w -> b c (h w)')
202
+ w_ = rearrange(w_, 'b i j -> b j i')
203
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
204
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
205
+ h_ = self.proj_out(h_)
206
+
207
+ return x+h_
208
+
209
+
210
+ class CrossAttention(nn.Module):
211
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
212
+ super().__init__()
213
+ inner_dim = dim_head * heads
214
+ context_dim = default(context_dim, query_dim)
215
+ self.scale = dim_head ** -0.5
216
+ self.heads = heads
217
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
218
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
219
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
220
+
221
+ self.to_out = nn.Sequential(
222
+ nn.Linear(inner_dim, query_dim),
223
+ nn.Dropout(dropout)
224
+ )
225
+
226
+ def forward(self, x, context=None, mask=None):
227
+ h = self.heads
228
+ q = self.to_q(x)
229
+ context = default(context, x)
230
+ k = self.to_k(context)
231
+ v = self.to_v(context)
232
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
233
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
234
+ if exists(mask):
235
+ mask = rearrange(mask, 'b ... -> b (...)')
236
+ max_neg_value = -torch.finfo(sim.dtype).max
237
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
238
+ sim.masked_fill_(~mask, max_neg_value)
239
+ # attention, what we cannot get enough of
240
+ attn = sim.softmax(dim=-1)
241
+ out = einsum('b i j, b j d -> b i d', attn, v) # [b*h, n, d]
242
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
243
+ return self.to_out(out)
244
+
245
+
246
+ class FlashAttention(nn.Module):
247
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
248
+ super().__init__()
249
+ print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
250
+ f"{heads} heads.")
251
+ inner_dim = dim_head * heads
252
+ context_dim = default(context_dim, query_dim)
253
+ self.scale = dim_head ** -0.5
254
+ self.heads = heads
255
+ self.dropout = dropout
256
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
257
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
258
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
259
+ self.to_out = nn.Sequential(
260
+ nn.Linear(inner_dim, query_dim),
261
+ nn.Dropout(dropout)
262
+ )
263
+
264
+ def forward(self, x, context=None, mask=None):
265
+ context = default(context, x)
266
+ h = self.heads
267
+ dtype = torch.bfloat16 # torch.half
268
+ q = self.to_q(x).to(dtype)
269
+ k = self.to_k(context).to(dtype)
270
+ v = self.to_v(context).to(dtype)
271
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q, k, v)) # q is [b, 3079, 16, 64]
272
+ out = flash_attn_func(q, k, v, dropout_p=self.dropout, softmax_scale=None, causal=False, window_size=(-1, -1)) # out is same shape to q
273
+ out = rearrange(out, 'b n h d -> b n (h d)', h=h)
274
+ return self.to_out(out.float())
275
+
276
+ class MemoryEfficientCrossAttention(nn.Module):
277
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
278
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
279
+ super().__init__()
280
+ print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
281
+ f"{heads} heads.")
282
+ inner_dim = dim_head * heads
283
+ context_dim = default(context_dim, query_dim)
284
+
285
+ self.heads = heads
286
+ self.dim_head = dim_head
287
+
288
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
289
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
290
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
291
+
292
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
293
+ self.attention_op: Optional[Any] = None
294
+
295
+ def forward(self, x, context=None, mask=None):
296
+ q = self.to_q(x)
297
+ context = default(context, x)
298
+ k = self.to_k(context)
299
+ v = self.to_v(context)
300
+
301
+ b, _, _ = q.shape
302
+ q, k, v = map(
303
+ lambda t: t.unsqueeze(3)
304
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
305
+ .permute(0, 2, 1, 3)
306
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
307
+ .contiguous(),
308
+ (q, k, v),
309
+ )
310
+
311
+ # actually compute the attention, what we cannot get enough of
312
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
313
+
314
+ if exists(mask):
315
+ raise NotImplementedError
316
+ out = (
317
+ out.unsqueeze(0)
318
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
319
+ .permute(0, 2, 1, 3)
320
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
321
+ )
322
+ return self.to_out(out)
323
+
324
+ class BasicTransformerBlock(nn.Module):
325
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
326
+ disable_self_attn=False):
327
+ super().__init__()
328
+ self.disable_self_attn = disable_self_attn
329
+ self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
330
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
331
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
332
+ self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
333
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
334
+ self.norm1 = Fp32LayerNorm(dim)
335
+ self.norm2 = Fp32LayerNorm(dim)
336
+ self.norm3 = Fp32LayerNorm(dim)
337
+ self.checkpoint = checkpoint
338
+
339
+ def forward(self, x, context=None):
340
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
341
+
342
+ def _forward(self, x, context=None):
343
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
344
+ x = self.attn2(self.norm2(x), context=context) + x
345
+ x = self.ff(self.norm3(x)) + x
346
+ return x
347
+
348
+ ATTENTION_MODES = {
349
+ "softmax": CrossAttention, # vanilla attention
350
+ "softmax-xformers": MemoryEfficientCrossAttention,
351
+ "softmax-flash": FlashAttention
352
+ }
353
+
354
+ def modulate(x, shift, scale):
355
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
356
+
357
+
358
+ class Fp32LayerNorm(nn.LayerNorm):
359
+ def __init__(self, *args, **kwargs):
360
+ super().__init__(*args, **kwargs)
361
+ def forward(self, x):
362
+ return super().forward(x.float()).type(x.dtype)
363
+
364
+
365
+ class AdaNorm(nn.Module):
366
+ def __init__(self, dim):
367
+ super().__init__()
368
+ self.adaLN_modulation = nn.Sequential(
369
+ nn.SiLU(),
370
+ nn.Linear(dim, 2 * dim, bias=True)
371
+ )
372
+ self.norm = Fp32LayerNorm(dim, elementwise_affine=False, eps=1e-6)
373
+
374
+ def forward(self, x, c): # x is fp32; c may be fp16/bf16
375
+ shift, scale = self.adaLN_modulation(c.float()).chunk(2, dim=1) # c is cast to fp32 for the modulation MLP
376
+ x = modulate(self.norm(x), shift, scale) # fp32
377
+ return x
378
+
379
+
380
+ class BasicTransformerBlockLRM(nn.Module):
381
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, \
382
+ checkpoint=True):
383
+ super().__init__()
384
+
385
+ attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
386
+ attn_mode = "softmax-flash" if FLASH_IS_AVAILABLE else attn_mode
387
+ assert attn_mode in ATTENTION_MODES
388
+ attn_cls = ATTENTION_MODES[attn_mode]
389
+
390
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, \
391
+ context_dim=context_dim) # cross-attn
392
+ self.attn2 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, \
393
+ context_dim=None) # self-attn
394
+
395
+ self.norm1 = Fp32LayerNorm(dim)
396
+ self.norm2 = Fp32LayerNorm(dim)
397
+ self.norm3 = Fp32LayerNorm(dim)
398
+
399
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
400
+ self.checkpoint = checkpoint
401
+
402
+ def forward(self, x, context=None, cam_emb=None): # (torch.float32, torch.float32, torch.bfloat16); cam_emb is accepted but not forwarded
403
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
404
+
405
+
406
+ def _forward(self, x, context=None, cam_emb=None):
407
+
408
+ x = self.attn1(self.norm1(x), context=context) + x # cross-attn
409
+ x = self.attn2(self.norm2(x), context=None) + x # self-attn
410
+ x = self.ff(self.norm3(x)) + x
411
+
412
+ return x
413
+
414
+ class ImgToTriplaneTransformer(nn.Module):
415
+ """
416
+ Transformer applied to triplane tokens.
417
+ The input is already a token sequence of shape (b, t, d);
418
+ each block cross-attends to the image context and then
419
+ self-attends over the triplane tokens.
420
+ A final layer norm is applied to the output.
421
+ """
422
+ def __init__(self, query_dim, n_heads, d_head, depth=1, dropout=0., context_dim=None, triplane_size=64):
423
+ super().__init__()
424
+
425
+ self.transformer_blocks = nn.ModuleList([
426
+ BasicTransformerBlockLRM(query_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
427
+ for d in range(depth)])
428
+
429
+ self.norm = Fp32LayerNorm(query_dim, eps=1e-6)
430
+
431
+ self.initialize_weights()
432
+
433
+ def initialize_weights(self):
434
+ # Initialize transformer layers:
435
+ def _basic_init(module):
436
+ if isinstance(module, nn.Linear):
437
+ torch.nn.init.xavier_uniform_(module.weight)
438
+ if module.bias is not None:
439
+ nn.init.constant_(module.bias, 0)
440
+ elif isinstance(module, nn.LayerNorm):
441
+ if module.bias is not None:
442
+ nn.init.constant_(module.bias, 0)
443
+ if module.weight is not None:
444
+ nn.init.constant_(module.weight, 1.0)
445
+ self.apply(_basic_init)
446
+
447
+ def forward(self, x, context=None, cam_emb=None):
448
+ # note: if no context is given, cross-attention defaults to self-attention
449
+ for block in self.transformer_blocks:
450
+ x = block(x, context=context)
451
+ x = self.norm(x)
452
+ return x
453
+
454
+
455
+
456
+
457
+
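For orientation (not part of the commit): a minimal usage sketch of the classes added above. The shapes, the 768-dimensional context, and the import path are illustrative assumptions based on the repository layout, not values taken from this diff.

import torch
from svrm.ldm.modules.attention import ImgToTriplaneTransformer, AdaNorm

# Triplane tokens cross-attend to image features, then self-attend (BasicTransformerBlockLRM).
dim, n_heads, d_head = 1024, 16, 64                   # query_dim == n_heads * d_head here
triplane_tokens = torch.randn(2, 3 * 32 * 32, dim)    # (b, t, d): three 32x32 planes, flattened
image_features = torch.randn(2, 1370, 768)            # e.g. DINOv2 patch tokens used as context

model = ImgToTriplaneTransformer(query_dim=dim, n_heads=n_heads, d_head=d_head,
                                 depth=2, context_dim=768)
out = model(triplane_tokens, context=image_features)  # same shape as triplane_tokens
# (With the flash-attention path selected, inputs would need to be fp16/bf16 on GPU;
#  the vanilla and xformers paths accept fp32.)

# AdaNorm applies a conditioning-dependent shift/scale, i.e. modulate(norm(x), shift, scale).
ada = AdaNorm(dim)
cond = torch.randn(2, dim)
modulated = ada(out, cond)                            # computed in fp32 internally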
svrm/ldm/modules/encoders/__init__.py ADDED
File without changes
svrm/ldm/modules/encoders/dinov2/__init__.py ADDED
File without changes
svrm/ldm/modules/encoders/dinov2/hub/__init__.py ADDED
File without changes
svrm/ldm/modules/encoders/dinov2/hub/backbones.py ADDED
@@ -0,0 +1,156 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from enum import Enum
7
+ from typing import Union
8
+
9
+ import torch
10
+
11
+ from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
12
+
13
+
14
+ class Weights(Enum):
15
+ LVD142M = "LVD142M"
16
+
17
+
18
+ def _make_dinov2_model(
19
+ *,
20
+ arch_name: str = "vit_large",
21
+ img_size: int = 518,
22
+ patch_size: int = 14,
23
+ init_values: float = 1.0,
24
+ ffn_layer: str = "mlp",
25
+ block_chunks: int = 0,
26
+ num_register_tokens: int = 0,
27
+ interpolate_antialias: bool = False,
28
+ interpolate_offset: float = 0.1,
29
+ pretrained: bool = True,
30
+ weights: Union[Weights, str] = Weights.LVD142M,
31
+ **kwargs,
32
+ ):
33
+ from ..models import vision_transformer as vits
34
+
35
+ if isinstance(weights, str):
36
+ try:
37
+ weights = Weights[weights]
38
+ except KeyError:
39
+ raise AssertionError(f"Unsupported weights: {weights}")
40
+
41
+ model_base_name = _make_dinov2_model_name(arch_name, patch_size)
42
+ vit_kwargs = dict(
43
+ img_size=img_size,
44
+ patch_size=patch_size,
45
+ init_values=init_values,
46
+ ffn_layer=ffn_layer,
47
+ block_chunks=block_chunks,
48
+ num_register_tokens=num_register_tokens,
49
+ interpolate_antialias=interpolate_antialias,
50
+ interpolate_offset=interpolate_offset,
51
+ )
52
+ vit_kwargs.update(**kwargs)
53
+ model = vits.__dict__[arch_name](**vit_kwargs)
54
+
55
+ if pretrained:
56
+ model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
57
+ url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
58
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
59
+ model.load_state_dict(state_dict, strict=True)
60
+
61
+ return model
62
+
63
+
64
+ def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
65
+ """
66
+ DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
67
+ """
68
+ return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
69
+
70
+
71
+ def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
72
+ """
73
+ DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
74
+ """
75
+ return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
76
+
77
+
78
+ def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
79
+ """
80
+ DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
81
+ """
82
+ return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
83
+
84
+
85
+ def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
86
+ """
87
+ DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
88
+ """
89
+ return _make_dinov2_model(
90
+ arch_name="vit_giant2",
91
+ ffn_layer="swiglufused",
92
+ weights=weights,
93
+ pretrained=pretrained,
94
+ **kwargs,
95
+ )
96
+
97
+
98
+ def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
99
+ """
100
+ DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
101
+ """
102
+ return _make_dinov2_model(
103
+ arch_name="vit_small",
104
+ pretrained=pretrained,
105
+ weights=weights,
106
+ num_register_tokens=4,
107
+ interpolate_antialias=True,
108
+ interpolate_offset=0.0,
109
+ **kwargs,
110
+ )
111
+
112
+
113
+ def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
114
+ """
115
+ DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
116
+ """
117
+ return _make_dinov2_model(
118
+ arch_name="vit_base",
119
+ pretrained=pretrained,
120
+ weights=weights,
121
+ num_register_tokens=4,
122
+ interpolate_antialias=True,
123
+ interpolate_offset=0.0,
124
+ **kwargs,
125
+ )
126
+
127
+
128
+ def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
129
+ """
130
+ DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
131
+ """
132
+ return _make_dinov2_model(
133
+ arch_name="vit_large",
134
+ pretrained=pretrained,
135
+ weights=weights,
136
+ num_register_tokens=4,
137
+ interpolate_antialias=True,
138
+ interpolate_offset=0.0,
139
+ **kwargs,
140
+ )
141
+
142
+
143
+ def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
144
+ """
145
+ DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
146
+ """
147
+ return _make_dinov2_model(
148
+ arch_name="vit_giant2",
149
+ ffn_layer="swiglufused",
150
+ weights=weights,
151
+ pretrained=pretrained,
152
+ num_register_tokens=4,
153
+ interpolate_antialias=True,
154
+ interpolate_offset=0.0,
155
+ **kwargs,
156
+ )
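A hedged sketch of using the hub builders above: they assemble the (locally modified) vision transformer and optionally pull the LVD-142M checkpoint whose URL is derived from _make_dinov2_model_name. Only construction is shown, since the modified ViT in this repo may expect extra camera-embedding inputs at forward time; the import path is assumed from the package layout.

from svrm.ldm.modules.encoders.dinov2.hub.backbones import dinov2_vitb14_reg

# Build ViT-B/14 with 4 register tokens; pretrained=True would download the LVD-142M weights.
model = dinov2_vitb14_reg(pretrained=False)
n_params = sum(p.numel() for p in model.parameters())
print(f"ViT-B/14 (with registers): {n_params / 1e6:.1f}M parameters")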
svrm/ldm/modules/encoders/dinov2/hub/utils.py ADDED
@@ -0,0 +1,39 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import itertools
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
15
+
16
+
17
+ def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
18
+ compact_arch_name = arch_name.replace("_", "")[:4]
19
+ registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
20
+ return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
21
+
22
+
23
+ class CenterPadding(nn.Module):
24
+ def __init__(self, multiple):
25
+ super().__init__()
26
+ self.multiple = multiple
27
+
28
+ def _get_pad(self, size):
29
+ new_size = math.ceil(size / self.multiple) * self.multiple
30
+ pad_size = new_size - size
31
+ pad_size_left = pad_size // 2
32
+ pad_size_right = pad_size - pad_size_left
33
+ return pad_size_left, pad_size_right
34
+
35
+ @torch.inference_mode()
36
+ def forward(self, x):
37
+ pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
38
+ output = F.pad(x, pads)
39
+ return output
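CenterPadding pads the trailing spatial dimensions up to the next multiple of `multiple` (here the 14-pixel patch size), splitting the padding evenly between the two sides. A small sketch, assuming the import path above:

import torch
from svrm.ldm.modules.encoders.dinov2.hub.utils import CenterPadding

pad = CenterPadding(multiple=14)
x = torch.randn(1, 3, 500, 500)
y = pad(x)                  # 500 -> 504 (next multiple of 14), padded 2 px on each side
print(y.shape)              # torch.Size([1, 3, 504, 504])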
svrm/ldm/modules/encoders/dinov2/layers/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .dino_head import DINOHead
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
9
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10
+ from .block import NestedTensorBlockMod
11
+ from .attention import MemEffAttention
svrm/ldm/modules/encoders/dinov2/layers/attention.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import os
12
+ import warnings
13
+
14
+ from torch import Tensor
15
+ from torch import nn
16
+
17
+
18
+ logger = logging.getLogger("dinov2")
19
+
20
+
21
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
22
+ try:
23
+ if XFORMERS_ENABLED:
24
+ from xformers.ops import memory_efficient_attention, unbind
25
+
26
+ XFORMERS_AVAILABLE = True
27
+ warnings.warn("xFormers is available (Attention)")
28
+ else:
29
+ warnings.warn("xFormers is disabled (Attention)")
30
+ raise ImportError
31
+ except ImportError:
32
+ XFORMERS_AVAILABLE = False
33
+ warnings.warn("xFormers is not available (Attention)")
34
+
35
+
36
+ class Attention(nn.Module):
37
+ def __init__(
38
+ self,
39
+ dim: int,
40
+ num_heads: int = 8,
41
+ qkv_bias: bool = False,
42
+ proj_bias: bool = True,
43
+ attn_drop: float = 0.0,
44
+ proj_drop: float = 0.0,
45
+ ) -> None:
46
+ super().__init__()
47
+ self.num_heads = num_heads
48
+ head_dim = dim // num_heads
49
+ self.scale = head_dim**-0.5
50
+
51
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52
+ self.attn_drop = nn.Dropout(attn_drop)
53
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
54
+ self.proj_drop = nn.Dropout(proj_drop)
55
+
56
+ def forward(self, x: Tensor) -> Tensor:
57
+ B, N, C = x.shape
58
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
59
+
60
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
61
+ attn = q @ k.transpose(-2, -1)
62
+
63
+ attn = attn.softmax(dim=-1)
64
+ attn = self.attn_drop(attn)
65
+
66
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
67
+ x = self.proj(x)
68
+ x = self.proj_drop(x)
69
+ return x
70
+
71
+
72
+ class MemEffAttention(Attention):
73
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
74
+ if not XFORMERS_AVAILABLE:
75
+ if attn_bias is not None:
76
+ raise AssertionError("xFormers is required for using nested tensors")
77
+ return super().forward(x)
78
+
79
+ B, N, C = x.shape
80
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
81
+
82
+ q, k, v = unbind(qkv, 2)
83
+
84
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
85
+ x = x.reshape([B, N, C])
86
+
87
+ x = self.proj(x)
88
+ x = self.proj_drop(x)
89
+ return x
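Both paths above compute the same scaled dot-product attention; the xformers path simply avoids materialising the N x N attention matrix (the vanilla path pre-scales q by head_dim**-0.5, while memory_efficient_attention applies that scale internally). A minimal sketch with assumed token counts:

import torch
from svrm.ldm.modules.encoders.dinov2.layers.attention import MemEffAttention

attn = MemEffAttention(dim=768, num_heads=12)        # head_dim = 64
tokens = torch.randn(2, 1370, 768)                   # e.g. 1369 patch tokens + 1 CLS token
out = attn(tokens)                                   # falls back to vanilla attention if xformers is absent
print(out.shape)                                     # torch.Size([2, 1370, 768])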
svrm/ldm/modules/encoders/dinov2/layers/block.py ADDED
@@ -0,0 +1,269 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ import os
11
+ import logging
12
+ import warnings
13
+ from typing import Callable, List, Any, Tuple, Dict
14
+
15
+ import torch
16
+ from torch import nn, Tensor
17
+
18
+ from .attention import Attention, MemEffAttention
19
+ from .drop_path import DropPath
20
+ from .layer_scale import LayerScale
21
+ from .mlp import Mlp
22
+
23
+ from ....attention import AdaNorm
24
+
25
+
26
+ logger = logging.getLogger("dinov2")
27
+
28
+
29
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
30
+ try:
31
+ if XFORMERS_ENABLED:
32
+ from xformers.ops import fmha, scaled_index_add, index_select_cat
33
+
34
+ XFORMERS_AVAILABLE = True
35
+ warnings.warn("xFormers is available (Block)")
36
+ else:
37
+ warnings.warn("xFormers is disabled (Block)")
38
+ raise ImportError
39
+ except ImportError:
40
+ XFORMERS_AVAILABLE = False
41
+
42
+ warnings.warn("xFormers is not available (Block)")
43
+
44
+
45
+ class BlockMod(nn.Module):
46
+ '''
47
+ using Modified Block, see below
48
+ '''
49
+ def __init__(
50
+ self,
51
+ dim: int,
52
+ num_heads: int,
53
+ mlp_ratio: float = 4.0,
54
+ qkv_bias: bool = False,
55
+ proj_bias: bool = True,
56
+ ffn_bias: bool = True,
57
+ drop: float = 0.0,
58
+ attn_drop: float = 0.0,
59
+ init_values=None,
60
+ drop_path: float = 0.0,
61
+ act_layer: Callable[..., nn.Module] = nn.GELU,
62
+ norm_layer: Callable[..., nn.Module] = AdaNorm,
63
+ attn_class: Callable[..., nn.Module] = Attention,
64
+ ffn_layer: Callable[..., nn.Module] = Mlp,
65
+ ) -> None:
66
+ super().__init__()
67
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
68
+ self.norm1 = norm_layer(dim)
69
+ self.attn = attn_class(
70
+ dim,
71
+ num_heads=num_heads,
72
+ qkv_bias=qkv_bias,
73
+ proj_bias=proj_bias,
74
+ attn_drop=attn_drop,
75
+ proj_drop=drop,
76
+ )
77
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
78
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
79
+
80
+ self.norm2 = norm_layer(dim)
81
+ mlp_hidden_dim = int(dim * mlp_ratio)
82
+ self.mlp = ffn_layer(
83
+ in_features=dim,
84
+ hidden_features=mlp_hidden_dim,
85
+ act_layer=act_layer,
86
+ drop=drop,
87
+ bias=ffn_bias,
88
+ )
89
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
90
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
91
+
92
+ self.sample_drop_ratio = drop_path
93
+
94
+ def forward(self, x: Tensor, cam_emb: Tensor) -> Tensor:
95
+ def attn_residual_func(x: Tensor, cam_emb: Tensor = None) -> Tensor:
96
+ return self.ls1(self.attn(self.norm1(x, cam_emb)))
97
+
98
+ def ffn_residual_func(x: Tensor, cam_emb: Tensor = None) -> Tensor:
99
+ return self.ls2(self.mlp(self.norm2(x, cam_emb)))
100
+
101
+ if self.training and self.sample_drop_ratio > 0.1:
102
+ # the overhead is compensated only for a drop path rate larger than 0.1
103
+ x = drop_add_residual_stochastic_depth(
104
+ x,
105
+ residual_func=attn_residual_func,
106
+ sample_drop_ratio=self.sample_drop_ratio,
107
+ )
108
+ x = drop_add_residual_stochastic_depth(
109
+ x,
110
+ residual_func=ffn_residual_func,
111
+ sample_drop_ratio=self.sample_drop_ratio,
112
+ )
113
+ elif self.training and self.sample_drop_ratio > 0.0:
114
+ x = x + self.drop_path1(attn_residual_func(x, cam_emb))
115
+ x = x + self.drop_path1(ffn_residual_func(x, cam_emb)) # FIXME: drop_path2
116
+ else:
117
+ x = x + attn_residual_func(x, cam_emb)
118
+ x = x + ffn_residual_func(x, cam_emb)
119
+ return x
120
+
121
+
122
+ def drop_add_residual_stochastic_depth(
123
+ x: Tensor,
124
+ residual_func: Callable[[Tensor], Tensor],
125
+ sample_drop_ratio: float = 0.0,
126
+ ) -> Tensor:
127
+ # drop_add_residual_stochastic_depth_list
128
+
129
+ # 1) extract subset using permutation
130
+ b, n, d = x.shape
131
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
132
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
133
+ x_subset = x[brange]
134
+
135
+ # 2) apply residual_func to get residual
136
+ residual = residual_func(x_subset)
137
+
138
+ x_flat = x.flatten(1)
139
+ residual = residual.flatten(1)
140
+
141
+ residual_scale_factor = b / sample_subset_size
142
+
143
+ # 3) add the residual
144
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
145
+ return x_plus_residual.view_as(x)
146
+
147
+
148
+ def get_branges_scales(x, sample_drop_ratio=0.0):
149
+ # get_branges_scales
150
+ b, n, d = x.shape
151
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
152
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
153
+ residual_scale_factor = b / sample_subset_size
154
+ return brange, residual_scale_factor
155
+
156
+
157
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
158
+ # add residuals
159
+ if scaling_vector is None:
160
+ x_flat = x.flatten(1)
161
+ residual = residual.flatten(1)
162
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
163
+ else:
164
+ x_plus_residual = scaled_index_add(
165
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
166
+ )
167
+ return x_plus_residual
168
+
169
+
170
+ attn_bias_cache: Dict[Tuple, Any] = {}
171
+
172
+
173
+ def get_attn_bias_and_cat(x_list, branges=None):
174
+ """
175
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
176
+ """
177
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
178
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
179
+ if all_shapes not in attn_bias_cache.keys():
180
+ seqlens = []
181
+ for b, x in zip(batch_sizes, x_list):
182
+ for _ in range(b):
183
+ seqlens.append(x.shape[1])
184
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
185
+ attn_bias._batch_sizes = batch_sizes
186
+ attn_bias_cache[all_shapes] = attn_bias
187
+
188
+ if branges is not None:
189
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
190
+ else:
191
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
192
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
193
+
194
+ return attn_bias_cache[all_shapes], cat_tensors
195
+
196
+
197
+ def drop_add_residual_stochastic_list(
198
+ x_list: List[Tensor],
199
+ residual_func: Callable[[Tensor, Any], Tensor],
200
+ sample_drop_ratio: float = 0.0,
201
+ scaling_vector=None,
202
+ ) -> Tensor:
203
+ # 1) generate random set of indices for dropping samples in the batch
204
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
205
+ branges = [s[0] for s in branges_scales]
206
+ residual_scale_factors = [s[1] for s in branges_scales]
207
+
208
+ # 2) get attention bias and index+concat the tensors
209
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
210
+
211
+ # 3) apply residual_func to get residual, and split the result
212
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
213
+
214
+ outputs = []
215
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
216
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
217
+ return outputs
218
+
219
+
220
+ class NestedTensorBlockMod(BlockMod):
221
+ def forward_nested(self, x_list: List[Tensor], cam_emb_list: List[Tensor]) -> List[Tensor]:
222
+ """
223
+ x_list contains a list of tensors to nest together and run
224
+ """
225
+ assert isinstance(self.attn, MemEffAttention)
226
+
227
+ if self.training and self.sample_drop_ratio > 0.0:
228
+
229
+ def attn_residual_func(x: Tensor, cam_emb: Tensor, attn_bias=None) -> Tensor:
230
+ return self.attn(self.norm1(x, cam_emb), attn_bias=attn_bias)
231
+
232
+ def ffn_residual_func(x: Tensor, cam_emb: Tensor, attn_bias=None) -> Tensor:
233
+ return self.mlp(self.norm2(x, cam_emb))
234
+
235
+ x_list = drop_add_residual_stochastic_list(
236
+ x_list,
237
+ residual_func=attn_residual_func,
238
+ sample_drop_ratio=self.sample_drop_ratio,
239
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
240
+ )
241
+ x_list = drop_add_residual_stochastic_list(
242
+ x_list,
243
+ residual_func=ffn_residual_func,
244
+ sample_drop_ratio=self.sample_drop_ratio,
245
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
246
+ )
247
+ return x_list
248
+ else:
249
+
250
+ def attn_residual_func(x: Tensor, cam_emb: Tensor, attn_bias=None) -> Tensor:
251
+ return self.ls1(self.attn(self.norm1(x, cam_emb), attn_bias=attn_bias))
252
+
253
+ def ffn_residual_func(x: Tensor, cam_emb: Tensor, attn_bias=None) -> Tensor:
254
+ return self.ls2(self.mlp(self.norm2(x, cam_emb)))
255
+
256
+ attn_bias, x = get_attn_bias_and_cat(x_list)
257
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
258
+ x = x + ffn_residual_func(x)
259
+ return attn_bias.split(x)
260
+
261
+ def forward(self, x_or_x_list, cam_emb_or_cam_emb_list):
262
+ if isinstance(x_or_x_list, Tensor) and isinstance(cam_emb_or_cam_emb_list, Tensor) :
263
+ return super().forward(x_or_x_list, cam_emb_or_cam_emb_list)
264
+ elif isinstance(x_or_x_list, list) and isinstance(cam_emb_or_cam_emb_list, list):
265
+ if not XFORMERS_AVAILABLE:
266
+ raise AssertionError("xFormers is required for using nested tensors")
267
+ return self.forward_nested(x_or_x_list, cam_emb_or_cam_emb_list)
268
+ else:
269
+ raise AssertionError
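The batch-wise stochastic depth above applies the residual branch to a random subset of samples and scales the kept residuals by b / subset_size, so the expected update matches running the branch on the full batch. A sketch with an identity stand-in for the attention/MLP branch:

import torch
from svrm.ldm.modules.encoders.dinov2.layers.block import drop_add_residual_stochastic_depth

x = torch.randn(8, 16, 32)                           # (batch, tokens, dim)
out = drop_add_residual_stochastic_depth(
    x,
    residual_func=lambda t: torch.ones_like(t),      # stand-in for the attention / MLP branch
    sample_drop_ratio=0.25,                          # 2 of the 8 samples skip the branch
)
print((out - x).mean().item())                       # ~1.0: the 6 kept residuals are scaled by 8/6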
svrm/ldm/modules/encoders/dinov2/layers/dino_head.py ADDED
@@ -0,0 +1,58 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.nn.init import trunc_normal_
9
+ from torch.nn.utils import weight_norm
10
+
11
+
12
+ class DINOHead(nn.Module):
13
+ def __init__(
14
+ self,
15
+ in_dim,
16
+ out_dim,
17
+ use_bn=False,
18
+ nlayers=3,
19
+ hidden_dim=2048,
20
+ bottleneck_dim=256,
21
+ mlp_bias=True,
22
+ ):
23
+ super().__init__()
24
+ nlayers = max(nlayers, 1)
25
+ self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
26
+ self.apply(self._init_weights)
27
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
28
+ self.last_layer.weight_g.data.fill_(1)
29
+
30
+ def _init_weights(self, m):
31
+ if isinstance(m, nn.Linear):
32
+ trunc_normal_(m.weight, std=0.02)
33
+ if isinstance(m, nn.Linear) and m.bias is not None:
34
+ nn.init.constant_(m.bias, 0)
35
+
36
+ def forward(self, x):
37
+ x = self.mlp(x)
38
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
39
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
40
+ x = self.last_layer(x)
41
+ return x
42
+
43
+
44
+ def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
45
+ if nlayers == 1:
46
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
47
+ else:
48
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
49
+ if use_bn:
50
+ layers.append(nn.BatchNorm1d(hidden_dim))
51
+ layers.append(nn.GELU())
52
+ for _ in range(nlayers - 2):
53
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
54
+ if use_bn:
55
+ layers.append(nn.BatchNorm1d(hidden_dim))
56
+ layers.append(nn.GELU())
57
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
58
+ return nn.Sequential(*layers)
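DINOHead maps backbone features to prototype logits: an MLP bottleneck, an L2 normalisation, and a weight-normalised last layer whose gain is initialised to 1. A minimal sketch with assumed dimensions:

import torch
from svrm.ldm.modules.encoders.dinov2.layers.dino_head import DINOHead

head = DINOHead(in_dim=768, out_dim=65536, nlayers=3, hidden_dim=2048, bottleneck_dim=256)
feats = torch.randn(4, 768)
logits = head(feats)
print(logits.shape)                                  # torch.Size([4, 65536])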
svrm/ldm/modules/encoders/dinov2/layers/drop_path.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
9
+
10
+
11
+ from torch import nn
12
+
13
+
14
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
15
+ if drop_prob == 0.0 or not training:
16
+ return x
17
+ keep_prob = 1 - drop_prob
18
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
19
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
20
+ if keep_prob > 0.0:
21
+ random_tensor.div_(keep_prob)
22
+ output = x * random_tensor
23
+ return output
24
+
25
+
26
+ class DropPath(nn.Module):
27
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
28
+
29
+ def __init__(self, drop_prob=None):
30
+ super(DropPath, self).__init__()
31
+ self.drop_prob = drop_prob
32
+
33
+ def forward(self, x):
34
+ return drop_path(x, self.drop_prob, self.training)
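DropPath zeroes the residual branch for a random subset of samples with probability drop_prob and divides the surviving branches by (1 - drop_prob), keeping the expected output equal to the evaluation-time output. A small sketch:

import torch
from svrm.ldm.modules.encoders.dinov2.layers.drop_path import DropPath

dp = DropPath(drop_prob=0.2)
x = torch.ones(1000, 8)

dp.train()
y_train = dp(x)              # ~20% of rows zeroed, the rest scaled to 1/0.8 = 1.25
dp.eval()
y_eval = dp(x)               # identity at evaluation time
print(y_train.mean().item(), y_eval.mean().item())   # both close to 1.0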
svrm/ldm/modules/encoders/dinov2/layers/layer_scale.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
7
+
8
+ from typing import Union
9
+
10
+ import torch
11
+ from torch import Tensor
12
+ from torch import nn
13
+
14
+
15
+ class LayerScale(nn.Module):
16
+ def __init__(
17
+ self,
18
+ dim: int,
19
+ init_values: Union[float, Tensor] = 1e-5,
20
+ inplace: bool = False,
21
+ ) -> None:
22
+ super().__init__()
23
+ self.inplace = inplace
24
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
25
+
26
+ def forward(self, x: Tensor) -> Tensor:
27
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
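LayerScale multiplies a residual branch by a learnable per-channel gain initialised to a small value (1e-5 by default), which keeps early residual updates small and stabilises deep ViT training. A short sketch:

import torch
from svrm.ldm.modules.encoders.dinov2.layers.layer_scale import LayerScale

ls = LayerScale(dim=768, init_values=1e-5)
branch_out = torch.randn(2, 1370, 768)
scaled = ls(branch_out)                 # branch_out * gamma, broadcast over batch and tokens
print(scaled.abs().mean().item())       # ~1e-5 of the input's scale at initialisation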