Haoxin Chen committed
Commit af6180a · 1 Parent(s): a34a4fc

update VideoCrafter1

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +1 -0
  2. README.md +0 -13
  3. app.py +36 -46
  4. models/adapter_t2v_depth/model_config.yaml → configs/inference_i2v_512_v1.0.yaml +30 -36
  5. models/base_t2v/model_config.yaml → configs/inference_t2v_1024_v1.0.yaml +24 -19
  6. configs/inference_t2v_512_v1.0.yaml +77 -0
  7. demo_test.py +4 -4
  8. extralibs/midas/__init__.py +0 -0
  9. extralibs/midas/api.py +0 -171
  10. extralibs/midas/midas/__init__.py +0 -0
  11. extralibs/midas/midas/base_model.py +0 -16
  12. extralibs/midas/midas/blocks.py +0 -342
  13. extralibs/midas/midas/dpt_depth.py +0 -110
  14. extralibs/midas/midas/midas_net.py +0 -76
  15. extralibs/midas/midas/midas_net_custom.py +0 -128
  16. extralibs/midas/midas/transforms.py +0 -234
  17. extralibs/midas/midas/vit.py +0 -489
  18. extralibs/midas/utils.py +0 -189
  19. i2v_test.py +78 -0
  20. input/flamingo.mp4 +0 -0
  21. input/prompts.txt +0 -2
  22. lvdm/basics.py +100 -0
  23. lvdm/common.py +95 -0
  24. lvdm/data/webvid.py +0 -188
  25. lvdm/{models/modules/distributions.py → distributions.py} +19 -0
  26. lvdm/ema.py +76 -0
  27. lvdm/models/autoencoder.py +29 -12
  28. lvdm/models/ddpm3d.py +261 -982
  29. lvdm/models/modules/adapter.py +0 -105
  30. lvdm/models/modules/attention_temporal.py +0 -399
  31. lvdm/models/modules/condition_modules.py +0 -40
  32. lvdm/models/modules/lora.py +0 -1251
  33. lvdm/{samplers → models/samplers}/ddim.py +120 -51
  34. lvdm/models/{modules/util.py → utils_diffusion.py} +28 -272
  35. lvdm/modules/attention.py +475 -0
  36. lvdm/modules/encoders/condition.py +392 -0
  37. lvdm/modules/encoders/ip_resampler.py +136 -0
  38. lvdm/{models/modules/autoencoder_modules.py → modules/networks/ae_modules.py} +294 -45
  39. lvdm/{models/modules → modules/networks}/openaimodel3d.py +287 -380
  40. lvdm/modules/x_transformer.py +640 -0
  41. lvdm/utils/common_utils.py +0 -132
  42. lvdm/utils/dist_utils.py +0 -19
  43. lvdm/utils/saving_utils.py +0 -269
  44. prompts/i2v_prompts/horse.png +0 -0
  45. prompts/i2v_prompts/seashore.png +0 -0
  46. prompts/i2v_prompts/test_prompts.txt +2 -0
  47. prompts/test_prompts.txt +2 -0
  48. requirements.txt +4 -3
  49. sample_adapter.sh +0 -22
  50. sample_text2video.sh +0 -16
.gitignore CHANGED
@@ -8,3 +8,4 @@ results
  *.ckpt
  *.pt
  *.pth
+ checkpoints
README.md CHANGED
@@ -1,13 +0,0 @@
- ---
- title: VideoCrafter
- emoji: 🌍
- colorFrom: gray
- colorTo: red
- sdk: gradio
- sdk_version: 3.24.1
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,30 +1,30 @@
  import os
  import sys
  import gradio as gr
- # from demo_test import Text2Video, VideoControl
- from videocrafter_test import Text2Video
- from videocontrol_test import VideoControl
+ # from demo_test import Text2Video, Image2Video
+ from t2v_test import Text2Video
+ from i2v_test import Image2Video
  sys.path.insert(1, os.path.join(sys.path[0], 'lvdm'))

  t2v_examples = [
- ['an elephant is walking under the sea, 4K, high definition',50,'origin',1,15,1,],
- ['an astronaut riding a horse in outer space',25,'origin',1,15,1,],
- ['a monkey is playing a piano',25,'vangogh',1,15,1,],
- ['A fire is burning on a candle',25,'frozen',1,15,1,],
- ['a horse is drinking in the river',25,'yourname',1,15,1,],
- ['Robot dancing in times square',25,'coco',1,15,1,],
+ ['an elephant is walking under the sea, 4K, high definition',50, 12,1, 16],
+ ['an astronaut riding a horse in outer space',25,12,1,16],
+ ['a monkey is playing a piano',25,12,1,16],
+ ['A fire is burning on a candle',25,12,1,16],
+ ['a horse is drinking in the river',25,12,1,16],
+ ['Robot dancing in times square',25,12,1,16],
  ]

- control_examples = [
- ['input/flamingo.mp4', 'An ostrich walking in the desert, photorealistic, 4k', 0, 50, 15, 1, 16, 256]
+ i2v_examples = [
+ ['prompts/i2v_prompts/horse.png', 'horses are walking on the grassland', 50, 12, 1, 16]
  ]

  def videocrafter_demo(result_dir='./tmp/'):
  text2video = Text2Video(result_dir)
- videocontrol = VideoControl(result_dir)
+ image2video = Image2Video(result_dir)
  with gr.Blocks(analytics_enabled=False) as videocrafter_iface:
- gr.Markdown("<div align='center'> <h2> VideoCrafter: A Toolkit for Text-to-Video Generation and Editing </span> </h2> \
- <a style='font-size:18px;color: #000000' href='https://github.com/VideoCrafter/VideoCrafter'> Github </div>")
+ gr.Markdown("<div align='center'> <h2> VideoCrafter1: Open Diffusion Models for High-Quality Video Generation </span> </h2> \
+ <a style='font-size:18px;color: #000000' href='https://github.com/AILab-CVC/VideoCrafter'> Github </div>")

  gr.Markdown("<b> You may duplicate the space and upgrade to GPU in settings for better performance and faster inference without waiting in the queue. <a style='display:inline-block' href='https://huggingface.co/spaces/VideoCrafter/VideoCrafter?duplicate=true'> <img src='https://bit.ly/3gLdBN6' alt='Duplicate Space'></a> </b>")
  #######t2v#######
@@ -33,65 +33,55 @@ def videocrafter_demo(result_dir='./tmp/'):
  with gr.Row().style(equal_height=False):
  with gr.Column():
  input_text = gr.Text(label='Prompts')
- model_choices=['origin','vangogh','frozen','yourname', 'coco']
- with gr.Row():
- model_index = gr.Dropdown(label='Models', elem_id=f"model", choices=model_choices, value=model_choices[0], type="index",interactive=True)
  with gr.Row():
  steps = gr.Slider(minimum=1, maximum=60, step=1, elem_id=f"steps", label="Sampling steps", value=50)
  eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label='ETA', value=1.0, elem_id="eta")
  with gr.Row():
- lora_scale = gr.Slider(minimum=0.0, maximum=2.0, step=0.1, label='Lora Scale', value=1.0, elem_id="lora_scale")
- cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=15.0, elem_id="cfg_scale")
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=12.0, elem_id="cfg_scale")
+ fps = gr.Slider(minimum=4, maximum=32, step=1, label='fps', value=16, elem_id="fps")
  send_btn = gr.Button("Send")
  with gr.Tab(label='result'):
- output_video_1 = gr.Video().style(width=384)
+ output_video_1 = gr.Video().style(width=320)
  gr.Examples(examples=t2v_examples,
- inputs=[input_text,steps,model_index,eta,cfg_scale,lora_scale],
+ inputs=[input_text,steps,cfg_scale,eta],
  outputs=[output_video_1],
  fn=text2video.get_prompt,
  cache_examples=False)
  #cache_examples=os.getenv('SYSTEM') == 'spaces')
  send_btn.click(
  fn=text2video.get_prompt,
- inputs=[input_text,steps,model_index,eta,cfg_scale,lora_scale,],
+ inputs=[input_text,steps,cfg_scale,eta,fps],
  outputs=[output_video_1],
  )
- #######videocontrol######
- with gr.Tab(label='VideoControl'):
+ #######image2video######
+ with gr.Tab(label='Image2Video'):
  with gr.Column():
  with gr.Row():
- # with gr.Tab(label='input'):
  with gr.Column():
  with gr.Row():
- vc_input_video = gr.Video(label="Input Video").style(width=256)
- vc_origin_video = gr.Video(label='Center-cropped Video').style(width=256)
- with gr.Row():
- vc_input_text = gr.Text(label='Prompts')
+ i2v_input_image = gr.Image(label="Input Image").style(width=256)
  with gr.Row():
- vc_eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label='ETA', value=1.0, elem_id="vc_eta")
- vc_cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=15.0, elem_id="vc_cfg_scale")
+ i2v_input_text = gr.Text(label='Prompts')
  with gr.Row():
- vc_steps = gr.Slider(minimum=1, maximum=60, step=1, elem_id="vc_steps", label="Sampling steps", value=50)
- frame_stride = gr.Slider(minimum=0 , maximum=100, step=1, label='Frame Stride', value=0, elem_id="vc_frame_stride")
+ i2v_eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label='ETA', value=1.0, elem_id="i2v_eta")
+ i2v_cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=12.0, elem_id="i2v_cfg_scale")
  with gr.Row():
- resolution = gr.Slider(minimum=128 , maximum=512, step=8, label='Long Side Resolution', value=256, elem_id="vc_resolution")
- video_frames = gr.Slider(minimum=8 , maximum=64, step=1, label='Video Frame Num', value=16, elem_id="vc_video_frames")
- vc_end_btn = gr.Button("Send")
+ i2v_steps = gr.Slider(minimum=1, maximum=60, step=1, elem_id="i2v_steps", label="Sampling steps", value=50)
+ i2v_fps = gr.Slider(minimum=4, maximum=32, step=1, elem_id="i2v_fps", label="Generative fps", value=16)
+ i2v_end_btn = gr.Button("Send")
  with gr.Tab(label='Result'):
- vc_output_info = gr.Text(label='Info')
  with gr.Row():
- vc_depth_video = gr.Video(label="Depth Video").style(width=256)
- vc_output_video = gr.Video(label="Generated Video").style(width=256)
+ i2v_output_video = gr.Video(label="Generated Video").style(width=320)

- gr.Examples(examples=control_examples,
- inputs=[vc_input_video, vc_input_text, frame_stride, vc_steps, vc_cfg_scale, vc_eta, video_frames, resolution],
- outputs=[vc_output_info, vc_origin_video, vc_depth_video, vc_output_video],
- fn = videocontrol.get_video,
+ gr.Examples(examples=i2v_examples,
+ inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_fps],
+ outputs=[i2v_output_video],
+ fn = image2video.get_image,
  cache_examples=os.getenv('SYSTEM') == 'spaces',
  )
- vc_end_btn.click(inputs=[vc_input_video, vc_input_text, frame_stride, vc_steps, vc_cfg_scale, vc_eta, video_frames, resolution],
- outputs=[vc_output_info, vc_origin_video, vc_depth_video, vc_output_video],
- fn = videocontrol.get_video
+ i2v_end_btn.click(inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_fps],
+ outputs=[i2v_output_video],
+ fn = image2video.get_image
  )

  return videocrafter_iface
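For context, a minimal sketch of how the interface built by videocrafter_demo() might be served locally (standard Gradio usage; the queue and server arguments below are illustrative assumptions, not part of this commit):

# Hypothetical launcher for the demo defined in app.py above.
from app import videocrafter_demo

if __name__ == "__main__":
    demo = videocrafter_demo(result_dir='./tmp/')   # returns the gr.Blocks object
    demo.queue(concurrency_count=1)                 # assumed: serialize GPU jobs
    demo.launch(server_name='0.0.0.0')              # assumed host; adjust as needed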
models/adapter_t2v_depth/model_config.yaml → configs/inference_i2v_512_v1.0.yaml RENAMED
@@ -1,27 +1,28 @@
  model:
- target: lvdm.models.ddpm3d.T2VAdapterDepth
+ target: lvdm.models.ddpm3d.LatentVisualDiffusion
  params:
  linear_start: 0.00085
  linear_end: 0.012
  num_timesteps_cond: 1
- log_every_t: 200
  timesteps: 1000
  first_stage_key: video
  cond_stage_key: caption
- image_size:
- - 32
- - 32
- video_length: 16
- channels: 4
  cond_stage_trainable: false
  conditioning_key: crossattn
+ image_size:
+ - 40
+ - 64
+ channels: 4
  scale_by_std: false
  scale_factor: 0.18215
-
+ use_ema: false
+ uncond_type: empty_seq
+ use_scale: true
+ scale_b: 0.7
+ finegrained: true
  unet_config:
- target: lvdm.models.modules.openaimodel3d.UNetModel
+ target: lvdm.modules.networks.openaimodel3d.UNetModel
  params:
- image_size: 32
  in_channels: 4
  out_channels: 4
  model_channels: 320
@@ -35,16 +36,20 @@ model:
  - 2
  - 4
  - 4
- num_heads: 8
+ num_head_channels: 64
  transformer_depth: 1
- context_dim: 768
+ context_dim: 1024
+ use_linear: true
  use_checkpoint: true
- legacy: false
- kernel_size_t: 1
- padding_t: 0
+ temporal_conv: true
+ temporal_attention: true
+ temporal_selfatt_only: true
+ use_relative_position: false
+ use_causal_attention: false
+ use_image_attention: true
  temporal_length: 16
- use_relative_position: true
-
+ addition_attention: true
+ fps_cond: true
  first_stage_config:
  target: lvdm.models.autoencoder.AutoencoderKL
  params:
@@ -53,7 +58,7 @@ model:
  ddconfig:
  double_z: true
  z_channels: 4
- resolution: 256
+ resolution: 512
  in_channels: 3
  out_ch: 3
  ch: 128
@@ -67,23 +72,12 @@ model:
  dropout: 0.0
  lossconfig:
  target: torch.nn.Identity
-
  cond_stage_config:
- target: lvdm.models.modules.condition_modules.FrozenCLIPEmbedder
-
- depth_stage_config:
- target: extralibs.midas.api.MiDaSInference
+ target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder
+ params:
+ freeze: true
+ layer: penultimate
+ cond_img_config:
+ target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
  params:
- model_type: "dpt_hybrid"
- model_path: models/adapter_t2v_depth/dpt_hybrid-midas.pt
-
- adapter_config:
- target: lvdm.models.modules.adapter.Adapter
- cond_name: depth
- params:
- cin: 64
- channels: [320, 640, 1280, 1280]
- nums_rb: 2
- ksize: 1
- sk: True
- use_conv: False
+ freeze: true
models/base_t2v/model_config.yaml → configs/inference_t2v_1024_v1.0.yaml RENAMED
@@ -4,24 +4,24 @@ model:
  linear_start: 0.00085
  linear_end: 0.012
  num_timesteps_cond: 1
- log_every_t: 200
  timesteps: 1000
  first_stage_key: video
  cond_stage_key: caption
- image_size:
- - 32
- - 32
- video_length: 16
- channels: 4
  cond_stage_trainable: false
  conditioning_key: crossattn
+ image_size:
+ - 72
+ - 128
+ channels: 4
  scale_by_std: false
  scale_factor: 0.18215
-
+ use_ema: false
+ uncond_type: empty_seq
+ use_scale: true
+ fix_scale_bug: true
  unet_config:
- target: lvdm.models.modules.openaimodel3d.UNetModel
+ target: lvdm.modules.networks.openaimodel3d.UNetModel
  params:
- image_size: 32
  in_channels: 4
  out_channels: 4
  model_channels: 320
@@ -35,16 +35,19 @@ model:
  - 2
  - 4
  - 4
- num_heads: 8
+ num_head_channels: 64
  transformer_depth: 1
- context_dim: 768
+ context_dim: 1024
+ use_linear: true
  use_checkpoint: true
- legacy: false
- kernel_size_t: 1
- padding_t: 0
- temporal_length: 16
+ temporal_conv: false
+ temporal_attention: true
+ temporal_selfatt_only: true
  use_relative_position: true
-
+ use_causal_attention: false
+ temporal_length: 16
+ addition_attention: true
+ fps_cond: true
  first_stage_config:
  target: lvdm.models.autoencoder.AutoencoderKL
  params:
@@ -53,7 +56,7 @@ model:
  ddconfig:
  double_z: true
  z_channels: 4
- resolution: 256
+ resolution: 512
  in_channels: 3
  out_ch: 3
  ch: 128
@@ -67,6 +70,8 @@ model:
  dropout: 0.0
  lossconfig:
  target: torch.nn.Identity
-
  cond_stage_config:
- target: lvdm.models.modules.condition_modules.FrozenCLIPEmbedder
+ target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder
+ params:
+ freeze: true
+ layer: penultimate
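As a quick sanity check on the new image_size values (a worked example, not part of the diff): these configs use an autoencoder with ch_mult [1, 2, 4, 4], i.e. three 2x downsampling stages, so the latent size times 8 gives the pixel resolution each config targets.

# Latent (H, W) -> pixel resolution, assuming the usual 8x VAE downsampling factor.
vae_factor = 2 ** 3
print(40 * vae_factor, 64 * vae_factor)    # 320 512  -> the 512 configs
print(72 * vae_factor, 128 * vae_factor)   # 576 1024 -> the 1024 config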
configs/inference_t2v_512_v1.0.yaml ADDED
@@ -0,0 +1,77 @@
+ model:
+ target: lvdm.models.ddpm3d.LatentDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.012
+ num_timesteps_cond: 1
+ timesteps: 1000
+ first_stage_key: video
+ cond_stage_key: caption
+ cond_stage_trainable: false
+ conditioning_key: crossattn
+ image_size:
+ - 40
+ - 64
+ channels: 4
+ scale_by_std: false
+ scale_factor: 0.18215
+ use_ema: false
+ uncond_type: empty_seq
+ use_scale: true
+ scale_b: 0.7
+ unet_config:
+ target: lvdm.modules.networks.openaimodel3d.UNetModel
+ params:
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions:
+ - 4
+ - 2
+ - 1
+ num_res_blocks: 2
+ channel_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_head_channels: 64
+ transformer_depth: 1
+ context_dim: 1024
+ use_linear: true
+ use_checkpoint: true
+ temporal_conv: true
+ temporal_attention: true
+ temporal_selfatt_only: true
+ use_relative_position: false
+ use_causal_attention: false
+ temporal_length: 16
+ addition_attention: true
+ fps_cond: true
+ first_stage_config:
+ target: lvdm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 512
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+ cond_stage_config:
+ target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder
+ params:
+ freeze: true
+ layer: penultimate
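The added config follows the usual target/params layout, so it can be resolved with OmegaConf and a small import helper. A minimal sketch, assuming OmegaConf is installed and the lvdm package is importable; the helper below is illustrative, not something defined in this commit:

# Load the new inference config and instantiate the model class it names.
import importlib
from omegaconf import OmegaConf

def instantiate(target, params):
    # target is a dotted path such as lvdm.models.ddpm3d.LatentDiffusion
    module_name, cls_name = target.rsplit('.', 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**params)

cfg = OmegaConf.load('configs/inference_t2v_512_v1.0.yaml')
model = instantiate(cfg.model.target, OmegaConf.to_container(cfg.model.params, resolve=True))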
demo_test.py CHANGED
@@ -2,15 +2,15 @@ class Text2Video():
  def __init__(self, result_dir='./tmp/') -> None:
  pass

- def get_prompt(self, input_text, steps=50, model_index=0, eta=1.0, cfg_scale=15.0, lora_scale=1.0):
+ def get_prompt(self, input_text, steps=50, cfg_scale=15.0, eta=1.0, fps=16):

  return '01.mp4'

- class VideoControl:
+ class Image2Video:
  def __init__(self, result_dir='./tmp/') -> None:
  pass

- def get_video(self, input_video, input_prompt, frame_stride=0, vc_steps=50, vc_cfg_scale=15.0, vc_eta=1.0, video_frames=16, resolution=256):
+ def get_image(self, input_image, input_prompt, i2v_steps=50, i2v_cfg_scale=15.0, i2v_eta=1.0, i2v_fps=16):

- return 'su','01.mp4'
+ return '01.mp4'
extralibs/midas/__init__.py DELETED
File without changes
extralibs/midas/api.py DELETED
@@ -1,171 +0,0 @@
1
- # based on https://github.com/isl-org/MiDaS
2
-
3
- import cv2
4
- import torch
5
- import torch.nn as nn
6
- from torchvision.transforms import Compose
7
-
8
- from extralibs.midas.midas.dpt_depth import DPTDepthModel
9
- from extralibs.midas.midas.midas_net import MidasNet
10
- from extralibs.midas.midas.midas_net_custom import MidasNet_small
11
- from extralibs.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
12
-
13
-
14
- ISL_PATHS = {
15
- "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
16
- "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
17
- "midas_v21": "",
18
- "midas_v21_small": "",
19
- }
20
-
21
-
22
- def disabled_train(self, mode=True):
23
- """Overwrite model.train with this function to make sure train/eval mode
24
- does not change anymore."""
25
- return self
26
-
27
-
28
- def load_midas_transform(model_type):
29
- # https://github.com/isl-org/MiDaS/blob/master/run.py
30
- # load transform only
31
- if model_type == "dpt_large": # DPT-Large
32
- net_w, net_h = 384, 384
33
- resize_mode = "minimal"
34
- normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
35
-
36
- elif model_type == "dpt_hybrid": # DPT-Hybrid
37
- net_w, net_h = 384, 384
38
- resize_mode = "minimal"
39
- normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
40
-
41
- elif model_type == "midas_v21":
42
- net_w, net_h = 384, 384
43
- resize_mode = "upper_bound"
44
- normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
45
-
46
- elif model_type == "midas_v21_small":
47
- net_w, net_h = 256, 256
48
- resize_mode = "upper_bound"
49
- normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
50
-
51
- else:
52
- assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
53
-
54
- transform = Compose(
55
- [
56
- Resize(
57
- net_w,
58
- net_h,
59
- resize_target=None,
60
- keep_aspect_ratio=True,
61
- ensure_multiple_of=32,
62
- resize_method=resize_mode,
63
- image_interpolation_method=cv2.INTER_CUBIC,
64
- ),
65
- normalization,
66
- PrepareForNet(),
67
- ]
68
- )
69
-
70
- return transform
71
-
72
-
73
- def load_model(model_type, model_path=None):
74
- # https://github.com/isl-org/MiDaS/blob/master/run.py
75
- # load network
76
- if model_path is None:
77
- model_path = ISL_PATHS[model_type]
78
- if model_type == "dpt_large": # DPT-Large
79
- model = DPTDepthModel(
80
- path=model_path,
81
- backbone="vitl16_384",
82
- non_negative=True,
83
- )
84
- net_w, net_h = 384, 384
85
- resize_mode = "minimal"
86
- normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
87
-
88
- elif model_type == "dpt_hybrid": # DPT-Hybrid
89
- model = DPTDepthModel(
90
- path=model_path,
91
- backbone="vitb_rn50_384",
92
- non_negative=True,
93
- )
94
- net_w, net_h = 384, 384
95
- resize_mode = "minimal"
96
- normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
97
-
98
- elif model_type == "midas_v21":
99
- model = MidasNet(model_path, non_negative=True)
100
- net_w, net_h = 384, 384
101
- resize_mode = "upper_bound"
102
- normalization = NormalizeImage(
103
- mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
104
- )
105
-
106
- elif model_type == "midas_v21_small":
107
- model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
108
- non_negative=True, blocks={'expand': True})
109
- net_w, net_h = 256, 256
110
- resize_mode = "upper_bound"
111
- normalization = NormalizeImage(
112
- mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
113
- )
114
-
115
- else:
116
- print(f"model_type '{model_type}' not implemented, use: --model_type large")
117
- assert False
118
-
119
- transform = Compose(
120
- [
121
- Resize(
122
- net_w,
123
- net_h,
124
- resize_target=None,
125
- keep_aspect_ratio=True,
126
- ensure_multiple_of=32,
127
- resize_method=resize_mode,
128
- image_interpolation_method=cv2.INTER_CUBIC,
129
- ),
130
- normalization,
131
- PrepareForNet(),
132
- ]
133
- )
134
-
135
- return model.eval(), transform
136
-
137
-
138
- class MiDaSInference(nn.Module):
139
- MODEL_TYPES_TORCH_HUB = [
140
- "DPT_Large",
141
- "DPT_Hybrid",
142
- "MiDaS_small"
143
- ]
144
- MODEL_TYPES_ISL = [
145
- "dpt_large",
146
- "dpt_hybrid",
147
- "midas_v21",
148
- "midas_v21_small",
149
- ]
150
-
151
- def __init__(self, model_type, model_path):
152
- super().__init__()
153
- assert (model_type in self.MODEL_TYPES_ISL)
154
- model, _ = load_model(model_type, model_path)
155
- self.model = model
156
- self.model.train = disabled_train
157
-
158
- def forward(self, x):
159
- # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
160
- # NOTE: we expect that the correct transform has been called during dataloading.
161
- with torch.no_grad():
162
- prediction = self.model(x)
163
- prediction = torch.nn.functional.interpolate(
164
- prediction.unsqueeze(1),
165
- size=x.shape[2:],
166
- mode="bicubic",
167
- align_corners=False,
168
- )
169
- assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
170
- return prediction
171
-
 
 
 
 
 
extralibs/midas/midas/__init__.py DELETED
File without changes
extralibs/midas/midas/base_model.py DELETED
@@ -1,16 +0,0 @@
- import torch
-
-
- class BaseModel(torch.nn.Module):
- def load(self, path):
- """Load model from file.
-
- Args:
- path (str): file path
- """
- parameters = torch.load(path, map_location=torch.device('cpu'))
-
- if "optimizer" in parameters:
- parameters = parameters["model"]
-
- self.load_state_dict(parameters)
extralibs/midas/midas/blocks.py DELETED
@@ -1,342 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
- from .vit import (
5
- _make_pretrained_vitb_rn50_384,
6
- _make_pretrained_vitl16_384,
7
- _make_pretrained_vitb16_384,
8
- forward_vit,
9
- )
10
-
11
- def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
12
- if backbone == "vitl16_384":
13
- pretrained = _make_pretrained_vitl16_384(
14
- use_pretrained, hooks=hooks, use_readout=use_readout
15
- )
16
- scratch = _make_scratch(
17
- [256, 512, 1024, 1024], features, groups=groups, expand=expand
18
- ) # ViT-L/16 - 85.0% Top1 (backbone)
19
- elif backbone == "vitb_rn50_384":
20
- pretrained = _make_pretrained_vitb_rn50_384(
21
- use_pretrained,
22
- hooks=hooks,
23
- use_vit_only=use_vit_only,
24
- use_readout=use_readout,
25
- )
26
- scratch = _make_scratch(
27
- [256, 512, 768, 768], features, groups=groups, expand=expand
28
- ) # ViT-H/16 - 85.0% Top1 (backbone)
29
- elif backbone == "vitb16_384":
30
- pretrained = _make_pretrained_vitb16_384(
31
- use_pretrained, hooks=hooks, use_readout=use_readout
32
- )
33
- scratch = _make_scratch(
34
- [96, 192, 384, 768], features, groups=groups, expand=expand
35
- ) # ViT-B/16 - 84.6% Top1 (backbone)
36
- elif backbone == "resnext101_wsl":
37
- pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
38
- scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
39
- elif backbone == "efficientnet_lite3":
40
- pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
41
- scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
42
- else:
43
- print(f"Backbone '{backbone}' not implemented")
44
- assert False
45
-
46
- return pretrained, scratch
47
-
48
-
49
- def _make_scratch(in_shape, out_shape, groups=1, expand=False):
50
- scratch = nn.Module()
51
-
52
- out_shape1 = out_shape
53
- out_shape2 = out_shape
54
- out_shape3 = out_shape
55
- out_shape4 = out_shape
56
- if expand==True:
57
- out_shape1 = out_shape
58
- out_shape2 = out_shape*2
59
- out_shape3 = out_shape*4
60
- out_shape4 = out_shape*8
61
-
62
- scratch.layer1_rn = nn.Conv2d(
63
- in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
64
- )
65
- scratch.layer2_rn = nn.Conv2d(
66
- in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
67
- )
68
- scratch.layer3_rn = nn.Conv2d(
69
- in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
70
- )
71
- scratch.layer4_rn = nn.Conv2d(
72
- in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
73
- )
74
-
75
- return scratch
76
-
77
-
78
- def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
79
- efficientnet = torch.hub.load(
80
- "rwightman/gen-efficientnet-pytorch",
81
- "tf_efficientnet_lite3",
82
- pretrained=use_pretrained,
83
- exportable=exportable
84
- )
85
- return _make_efficientnet_backbone(efficientnet)
86
-
87
-
88
- def _make_efficientnet_backbone(effnet):
89
- pretrained = nn.Module()
90
-
91
- pretrained.layer1 = nn.Sequential(
92
- effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
93
- )
94
- pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
95
- pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
96
- pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
97
-
98
- return pretrained
99
-
100
-
101
- def _make_resnet_backbone(resnet):
102
- pretrained = nn.Module()
103
- pretrained.layer1 = nn.Sequential(
104
- resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
105
- )
106
-
107
- pretrained.layer2 = resnet.layer2
108
- pretrained.layer3 = resnet.layer3
109
- pretrained.layer4 = resnet.layer4
110
-
111
- return pretrained
112
-
113
-
114
- def _make_pretrained_resnext101_wsl(use_pretrained):
115
- resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
116
- return _make_resnet_backbone(resnet)
117
-
118
-
119
-
120
- class Interpolate(nn.Module):
121
- """Interpolation module.
122
- """
123
-
124
- def __init__(self, scale_factor, mode, align_corners=False):
125
- """Init.
126
-
127
- Args:
128
- scale_factor (float): scaling
129
- mode (str): interpolation mode
130
- """
131
- super(Interpolate, self).__init__()
132
-
133
- self.interp = nn.functional.interpolate
134
- self.scale_factor = scale_factor
135
- self.mode = mode
136
- self.align_corners = align_corners
137
-
138
- def forward(self, x):
139
- """Forward pass.
140
-
141
- Args:
142
- x (tensor): input
143
-
144
- Returns:
145
- tensor: interpolated data
146
- """
147
-
148
- x = self.interp(
149
- x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
150
- )
151
-
152
- return x
153
-
154
-
155
- class ResidualConvUnit(nn.Module):
156
- """Residual convolution module.
157
- """
158
-
159
- def __init__(self, features):
160
- """Init.
161
-
162
- Args:
163
- features (int): number of features
164
- """
165
- super().__init__()
166
-
167
- self.conv1 = nn.Conv2d(
168
- features, features, kernel_size=3, stride=1, padding=1, bias=True
169
- )
170
-
171
- self.conv2 = nn.Conv2d(
172
- features, features, kernel_size=3, stride=1, padding=1, bias=True
173
- )
174
-
175
- self.relu = nn.ReLU(inplace=True)
176
-
177
- def forward(self, x):
178
- """Forward pass.
179
-
180
- Args:
181
- x (tensor): input
182
-
183
- Returns:
184
- tensor: output
185
- """
186
- out = self.relu(x)
187
- out = self.conv1(out)
188
- out = self.relu(out)
189
- out = self.conv2(out)
190
-
191
- return out + x
192
-
193
-
194
- class FeatureFusionBlock(nn.Module):
195
- """Feature fusion block.
196
- """
197
-
198
- def __init__(self, features):
199
- """Init.
200
-
201
- Args:
202
- features (int): number of features
203
- """
204
- super(FeatureFusionBlock, self).__init__()
205
-
206
- self.resConfUnit1 = ResidualConvUnit(features)
207
- self.resConfUnit2 = ResidualConvUnit(features)
208
-
209
- def forward(self, *xs):
210
- """Forward pass.
211
-
212
- Returns:
213
- tensor: output
214
- """
215
- output = xs[0]
216
-
217
- if len(xs) == 2:
218
- output += self.resConfUnit1(xs[1])
219
-
220
- output = self.resConfUnit2(output)
221
-
222
- output = nn.functional.interpolate(
223
- output, scale_factor=2, mode="bilinear", align_corners=True
224
- )
225
-
226
- return output
227
-
228
-
229
-
230
-
231
- class ResidualConvUnit_custom(nn.Module):
232
- """Residual convolution module.
233
- """
234
-
235
- def __init__(self, features, activation, bn):
236
- """Init.
237
-
238
- Args:
239
- features (int): number of features
240
- """
241
- super().__init__()
242
-
243
- self.bn = bn
244
-
245
- self.groups=1
246
-
247
- self.conv1 = nn.Conv2d(
248
- features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
249
- )
250
-
251
- self.conv2 = nn.Conv2d(
252
- features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
253
- )
254
-
255
- if self.bn==True:
256
- self.bn1 = nn.BatchNorm2d(features)
257
- self.bn2 = nn.BatchNorm2d(features)
258
-
259
- self.activation = activation
260
-
261
- self.skip_add = nn.quantized.FloatFunctional()
262
-
263
- def forward(self, x):
264
- """Forward pass.
265
-
266
- Args:
267
- x (tensor): input
268
-
269
- Returns:
270
- tensor: output
271
- """
272
-
273
- out = self.activation(x)
274
- out = self.conv1(out)
275
- if self.bn==True:
276
- out = self.bn1(out)
277
-
278
- out = self.activation(out)
279
- out = self.conv2(out)
280
- if self.bn==True:
281
- out = self.bn2(out)
282
-
283
- if self.groups > 1:
284
- out = self.conv_merge(out)
285
-
286
- return self.skip_add.add(out, x)
287
-
288
- # return out + x
289
-
290
-
291
- class FeatureFusionBlock_custom(nn.Module):
292
- """Feature fusion block.
293
- """
294
-
295
- def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
296
- """Init.
297
-
298
- Args:
299
- features (int): number of features
300
- """
301
- super(FeatureFusionBlock_custom, self).__init__()
302
-
303
- self.deconv = deconv
304
- self.align_corners = align_corners
305
-
306
- self.groups=1
307
-
308
- self.expand = expand
309
- out_features = features
310
- if self.expand==True:
311
- out_features = features//2
312
-
313
- self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
314
-
315
- self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
316
- self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
317
-
318
- self.skip_add = nn.quantized.FloatFunctional()
319
-
320
- def forward(self, *xs):
321
- """Forward pass.
322
-
323
- Returns:
324
- tensor: output
325
- """
326
- output = xs[0]
327
-
328
- if len(xs) == 2:
329
- res = self.resConfUnit1(xs[1])
330
- output = self.skip_add.add(output, res)
331
- # output += res
332
-
333
- output = self.resConfUnit2(output)
334
-
335
- output = nn.functional.interpolate(
336
- output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
337
- )
338
-
339
- output = self.out_conv(output)
340
-
341
- return output
342
-
 
 
 
 
 
 
 
 
 
extralibs/midas/midas/dpt_depth.py DELETED
@@ -1,110 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- from .base_model import BaseModel
6
- from .blocks import (
7
- FeatureFusionBlock,
8
- FeatureFusionBlock_custom,
9
- Interpolate,
10
- _make_encoder,
11
- forward_vit,
12
- )
13
-
14
-
15
- def _make_fusion_block(features, use_bn):
16
- return FeatureFusionBlock_custom(
17
- features,
18
- nn.ReLU(False),
19
- deconv=False,
20
- bn=use_bn,
21
- expand=False,
22
- align_corners=True,
23
- )
24
-
25
-
26
- class DPT(BaseModel):
27
- def __init__(
28
- self,
29
- head,
30
- features=256,
31
- backbone="vitb_rn50_384",
32
- readout="project",
33
- channels_last=False,
34
- use_bn=False,
35
- ):
36
-
37
- super(DPT, self).__init__()
38
-
39
- self.channels_last = channels_last
40
-
41
- hooks = {
42
- "vitb_rn50_384": [0, 1, 8, 11],
43
- "vitb16_384": [2, 5, 8, 11],
44
- "vitl16_384": [5, 11, 17, 23],
45
- }
46
-
47
- # Instantiate backbone and reassemble blocks
48
- self.pretrained, self.scratch = _make_encoder(
49
- backbone,
50
- features,
51
- False, # Set to true of you want to train from scratch, uses ImageNet weights
52
- groups=1,
53
- expand=False,
54
- exportable=False,
55
- hooks=hooks[backbone],
56
- use_readout=readout,
57
- )
58
-
59
- self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
60
- self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
61
- self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
62
- self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
63
-
64
- self.scratch.output_conv = head
65
-
66
-
67
- def forward(self, x):
68
- if self.channels_last == True:
69
- x.contiguous(memory_format=torch.channels_last)
70
-
71
- layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
72
-
73
- layer_1_rn = self.scratch.layer1_rn(layer_1)
74
- layer_2_rn = self.scratch.layer2_rn(layer_2)
75
- layer_3_rn = self.scratch.layer3_rn(layer_3)
76
- layer_4_rn = self.scratch.layer4_rn(layer_4)
77
-
78
- path_4 = self.scratch.refinenet4(layer_4_rn)
79
- path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
80
- path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
81
- path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
82
-
83
- out = self.scratch.output_conv(path_1)
84
-
85
- return out
86
-
87
-
88
- class DPTDepthModel(DPT):
89
- def __init__(self, path=None, non_negative=True, **kwargs):
90
- features = kwargs["features"] if "features" in kwargs else 256
91
-
92
- head = nn.Sequential(
93
- nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
94
- Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
95
- nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
96
- nn.ReLU(True),
97
- nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
98
- nn.ReLU(True) if non_negative else nn.Identity(),
99
- nn.Identity(),
100
- )
101
-
102
- super().__init__(head, **kwargs)
103
-
104
- if path is not None:
105
- self.load(path)
106
- print("Midas depth estimation model loaded.")
107
-
108
- def forward(self, x):
109
- return super().forward(x).squeeze(dim=1)
110
-
 
 
 
 
 
extralibs/midas/midas/midas_net.py DELETED
@@ -1,76 +0,0 @@
1
- """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2
- This file contains code that is adapted from
3
- https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
- """
5
- import torch
6
- import torch.nn as nn
7
-
8
- from .base_model import BaseModel
9
- from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10
-
11
-
12
- class MidasNet(BaseModel):
13
- """Network for monocular depth estimation.
14
- """
15
-
16
- def __init__(self, path=None, features=256, non_negative=True):
17
- """Init.
18
-
19
- Args:
20
- path (str, optional): Path to saved model. Defaults to None.
21
- features (int, optional): Number of features. Defaults to 256.
22
- backbone (str, optional): Backbone network for encoder. Defaults to resnet50
23
- """
24
- print("Loading weights: ", path)
25
-
26
- super(MidasNet, self).__init__()
27
-
28
- use_pretrained = False if path is None else True
29
-
30
- self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
31
-
32
- self.scratch.refinenet4 = FeatureFusionBlock(features)
33
- self.scratch.refinenet3 = FeatureFusionBlock(features)
34
- self.scratch.refinenet2 = FeatureFusionBlock(features)
35
- self.scratch.refinenet1 = FeatureFusionBlock(features)
36
-
37
- self.scratch.output_conv = nn.Sequential(
38
- nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
39
- Interpolate(scale_factor=2, mode="bilinear"),
40
- nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
41
- nn.ReLU(True),
42
- nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
43
- nn.ReLU(True) if non_negative else nn.Identity(),
44
- )
45
-
46
- if path:
47
- self.load(path)
48
-
49
- def forward(self, x):
50
- """Forward pass.
51
-
52
- Args:
53
- x (tensor): input data (image)
54
-
55
- Returns:
56
- tensor: depth
57
- """
58
-
59
- layer_1 = self.pretrained.layer1(x)
60
- layer_2 = self.pretrained.layer2(layer_1)
61
- layer_3 = self.pretrained.layer3(layer_2)
62
- layer_4 = self.pretrained.layer4(layer_3)
63
-
64
- layer_1_rn = self.scratch.layer1_rn(layer_1)
65
- layer_2_rn = self.scratch.layer2_rn(layer_2)
66
- layer_3_rn = self.scratch.layer3_rn(layer_3)
67
- layer_4_rn = self.scratch.layer4_rn(layer_4)
68
-
69
- path_4 = self.scratch.refinenet4(layer_4_rn)
70
- path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
71
- path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
72
- path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
73
-
74
- out = self.scratch.output_conv(path_1)
75
-
76
- return torch.squeeze(out, dim=1)
 
 
 
 
 
extralibs/midas/midas/midas_net_custom.py DELETED
@@ -1,128 +0,0 @@
1
- """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2
- This file contains code that is adapted from
3
- https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
- """
5
- import torch
6
- import torch.nn as nn
7
-
8
- from .base_model import BaseModel
9
- from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
10
-
11
-
12
- class MidasNet_small(BaseModel):
13
- """Network for monocular depth estimation.
14
- """
15
-
16
- def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
17
- blocks={'expand': True}):
18
- """Init.
19
-
20
- Args:
21
- path (str, optional): Path to saved model. Defaults to None.
22
- features (int, optional): Number of features. Defaults to 256.
23
- backbone (str, optional): Backbone network for encoder. Defaults to resnet50
24
- """
25
- print("Loading weights: ", path)
26
-
27
- super(MidasNet_small, self).__init__()
28
-
29
- use_pretrained = False if path else True
30
-
31
- self.channels_last = channels_last
32
- self.blocks = blocks
33
- self.backbone = backbone
34
-
35
- self.groups = 1
36
-
37
- features1=features
38
- features2=features
39
- features3=features
40
- features4=features
41
- self.expand = False
42
- if "expand" in self.blocks and self.blocks['expand'] == True:
43
- self.expand = True
44
- features1=features
45
- features2=features*2
46
- features3=features*4
47
- features4=features*8
48
-
49
- self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
50
-
51
- self.scratch.activation = nn.ReLU(False)
52
-
53
- self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
54
- self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
55
- self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
56
- self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
57
-
58
-
59
- self.scratch.output_conv = nn.Sequential(
60
- nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
61
- Interpolate(scale_factor=2, mode="bilinear"),
62
- nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
63
- self.scratch.activation,
64
- nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
65
- nn.ReLU(True) if non_negative else nn.Identity(),
66
- nn.Identity(),
67
- )
68
-
69
- if path:
70
- self.load(path)
71
-
72
-
73
- def forward(self, x):
74
- """Forward pass.
75
-
76
- Args:
77
- x (tensor): input data (image)
78
-
79
- Returns:
80
- tensor: depth
81
- """
82
- if self.channels_last==True:
83
- print("self.channels_last = ", self.channels_last)
84
- x.contiguous(memory_format=torch.channels_last)
85
-
86
-
87
- layer_1 = self.pretrained.layer1(x)
88
- layer_2 = self.pretrained.layer2(layer_1)
89
- layer_3 = self.pretrained.layer3(layer_2)
90
- layer_4 = self.pretrained.layer4(layer_3)
91
-
92
- layer_1_rn = self.scratch.layer1_rn(layer_1)
93
- layer_2_rn = self.scratch.layer2_rn(layer_2)
94
- layer_3_rn = self.scratch.layer3_rn(layer_3)
95
- layer_4_rn = self.scratch.layer4_rn(layer_4)
96
-
97
-
98
- path_4 = self.scratch.refinenet4(layer_4_rn)
99
- path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
100
- path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
101
- path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
102
-
103
- out = self.scratch.output_conv(path_1)
104
-
105
- return torch.squeeze(out, dim=1)
106
-
107
-
108
-
109
- def fuse_model(m):
110
- prev_previous_type = nn.Identity()
111
- prev_previous_name = ''
112
- previous_type = nn.Identity()
113
- previous_name = ''
114
- for name, module in m.named_modules():
115
- if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
116
- # print("FUSED ", prev_previous_name, previous_name, name)
117
- torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
118
- elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
119
- # print("FUSED ", prev_previous_name, previous_name)
120
- torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
121
- # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
122
- # print("FUSED ", previous_name, name)
123
- # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
124
-
125
- prev_previous_type = previous_type
126
- prev_previous_name = previous_name
127
- previous_type = type(module)
128
- previous_name = name
 
 
 
 
 
extralibs/midas/midas/transforms.py DELETED
@@ -1,234 +0,0 @@
1
- import numpy as np
2
- import cv2
3
- import math
4
-
5
-
6
- def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7
- """Rezise the sample to ensure the given size. Keeps aspect ratio.
8
-
9
- Args:
10
- sample (dict): sample
11
- size (tuple): image size
12
-
13
- Returns:
14
- tuple: new size
15
- """
16
- shape = list(sample["disparity"].shape)
17
-
18
- if shape[0] >= size[0] and shape[1] >= size[1]:
19
- return sample
20
-
21
- scale = [0, 0]
22
- scale[0] = size[0] / shape[0]
23
- scale[1] = size[1] / shape[1]
24
-
25
- scale = max(scale)
26
-
27
- shape[0] = math.ceil(scale * shape[0])
28
- shape[1] = math.ceil(scale * shape[1])
29
-
30
- # resize
31
- sample["image"] = cv2.resize(
32
- sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33
- )
34
-
35
- sample["disparity"] = cv2.resize(
36
- sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37
- )
38
- sample["mask"] = cv2.resize(
39
- sample["mask"].astype(np.float32),
40
- tuple(shape[::-1]),
41
- interpolation=cv2.INTER_NEAREST,
42
- )
43
- sample["mask"] = sample["mask"].astype(bool)
44
-
45
- return tuple(shape)
46
-
47
-
48
- class Resize(object):
49
- """Resize sample to given size (width, height).
50
- """
51
-
52
- def __init__(
53
- self,
54
- width,
55
- height,
56
- resize_target=True,
57
- keep_aspect_ratio=False,
58
- ensure_multiple_of=1,
59
- resize_method="lower_bound",
60
- image_interpolation_method=cv2.INTER_AREA,
61
- ):
62
- """Init.
63
-
64
- Args:
65
- width (int): desired output width
66
- height (int): desired output height
67
- resize_target (bool, optional):
68
- True: Resize the full sample (image, mask, target).
69
- False: Resize image only.
70
- Defaults to True.
71
- keep_aspect_ratio (bool, optional):
72
- True: Keep the aspect ratio of the input sample.
73
- Output sample might not have the given width and height, and
74
- resize behaviour depends on the parameter 'resize_method'.
75
- Defaults to False.
76
- ensure_multiple_of (int, optional):
77
- Output width and height is constrained to be multiple of this parameter.
78
- Defaults to 1.
79
- resize_method (str, optional):
80
- "lower_bound": Output will be at least as large as the given size.
81
- "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
82
- "minimal": Scale as least as possible. (Output size might be smaller than given size.)
83
- Defaults to "lower_bound".
84
- """
85
- self.__width = width
86
- self.__height = height
87
-
88
- self.__resize_target = resize_target
89
- self.__keep_aspect_ratio = keep_aspect_ratio
90
- self.__multiple_of = ensure_multiple_of
91
- self.__resize_method = resize_method
92
- self.__image_interpolation_method = image_interpolation_method
93
-
94
- def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
95
- y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
96
-
97
- if max_val is not None and y > max_val:
98
- y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
99
-
100
- if y < min_val:
101
- y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
102
-
103
- return y
104
-
105
- def get_size(self, width, height):
106
- # determine new height and width
107
- scale_height = self.__height / height
108
- scale_width = self.__width / width
109
-
110
- if self.__keep_aspect_ratio:
111
- if self.__resize_method == "lower_bound":
112
- # scale such that output size is lower bound
113
- if scale_width > scale_height:
114
- # fit width
115
- scale_height = scale_width
116
- else:
117
- # fit height
118
- scale_width = scale_height
119
- elif self.__resize_method == "upper_bound":
120
- # scale such that output size is upper bound
121
- if scale_width < scale_height:
122
- # fit width
123
- scale_height = scale_width
124
- else:
125
- # fit height
126
- scale_width = scale_height
127
- elif self.__resize_method == "minimal":
128
- # scale as least as possbile
129
- if abs(1 - scale_width) < abs(1 - scale_height):
130
- # fit width
131
- scale_height = scale_width
132
- else:
133
- # fit height
134
- scale_width = scale_height
135
- else:
136
- raise ValueError(
137
- f"resize_method {self.__resize_method} not implemented"
138
- )
139
-
140
- if self.__resize_method == "lower_bound":
141
- new_height = self.constrain_to_multiple_of(
142
- scale_height * height, min_val=self.__height
143
- )
144
- new_width = self.constrain_to_multiple_of(
145
- scale_width * width, min_val=self.__width
146
- )
147
- elif self.__resize_method == "upper_bound":
148
- new_height = self.constrain_to_multiple_of(
149
- scale_height * height, max_val=self.__height
150
- )
151
- new_width = self.constrain_to_multiple_of(
152
- scale_width * width, max_val=self.__width
153
- )
154
- elif self.__resize_method == "minimal":
155
- new_height = self.constrain_to_multiple_of(scale_height * height)
156
- new_width = self.constrain_to_multiple_of(scale_width * width)
157
- else:
158
- raise ValueError(f"resize_method {self.__resize_method} not implemented")
159
-
160
- return (new_width, new_height)
161
-
162
- def __call__(self, sample):
163
- width, height = self.get_size(
164
- sample["image"].shape[1], sample["image"].shape[0]
165
- )
166
-
167
- # resize sample
168
- sample["image"] = cv2.resize(
169
- sample["image"],
170
- (width, height),
171
- interpolation=self.__image_interpolation_method,
172
- )
173
-
174
- if self.__resize_target:
175
- if "disparity" in sample:
176
- sample["disparity"] = cv2.resize(
177
- sample["disparity"],
178
- (width, height),
179
- interpolation=cv2.INTER_NEAREST,
180
- )
181
-
182
- if "depth" in sample:
183
- sample["depth"] = cv2.resize(
184
- sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
185
- )
186
-
187
- sample["mask"] = cv2.resize(
188
- sample["mask"].astype(np.float32),
189
- (width, height),
190
- interpolation=cv2.INTER_NEAREST,
191
- )
192
- sample["mask"] = sample["mask"].astype(bool)
193
-
194
- return sample
195
-
196
-
197
- class NormalizeImage(object):
198
- """Normlize image by given mean and std.
199
- """
200
-
201
- def __init__(self, mean, std):
202
- self.__mean = mean
203
- self.__std = std
204
-
205
- def __call__(self, sample):
206
- sample["image"] = (sample["image"] - self.__mean) / self.__std
207
-
208
- return sample
209
-
210
-
211
- class PrepareForNet(object):
212
- """Prepare sample for usage as network input.
213
- """
214
-
215
- def __init__(self):
216
- pass
217
-
218
- def __call__(self, sample):
219
- image = np.transpose(sample["image"], (2, 0, 1))
220
- sample["image"] = np.ascontiguousarray(image).astype(np.float32)
221
-
222
- if "mask" in sample:
223
- sample["mask"] = sample["mask"].astype(np.float32)
224
- sample["mask"] = np.ascontiguousarray(sample["mask"])
225
-
226
- if "disparity" in sample:
227
- disparity = sample["disparity"].astype(np.float32)
228
- sample["disparity"] = np.ascontiguousarray(disparity)
229
-
230
- if "depth" in sample:
231
- depth = sample["depth"].astype(np.float32)
232
- sample["depth"] = np.ascontiguousarray(depth)
233
-
234
- return sample
 
 
 
 
 
 
 
 
 
extralibs/midas/midas/vit.py DELETED
@@ -1,489 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import timm
4
- import types
5
- import math
6
- import torch.nn.functional as F
7
-
8
-
9
- class Slice(nn.Module):
10
- def __init__(self, start_index=1):
11
- super(Slice, self).__init__()
12
- self.start_index = start_index
13
-
14
- def forward(self, x):
15
- return x[:, self.start_index :]
16
-
17
-
18
- class AddReadout(nn.Module):
19
- def __init__(self, start_index=1):
20
- super(AddReadout, self).__init__()
21
- self.start_index = start_index
22
-
23
- def forward(self, x):
24
- if self.start_index == 2:
25
- readout = (x[:, 0] + x[:, 1]) / 2
26
- else:
27
- readout = x[:, 0]
28
- return x[:, self.start_index :] + readout.unsqueeze(1)
29
-
30
-
31
- class ProjectReadout(nn.Module):
32
- def __init__(self, in_features, start_index=1):
33
- super(ProjectReadout, self).__init__()
34
- self.start_index = start_index
35
-
36
- self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
37
-
38
- def forward(self, x):
39
- readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
40
- features = torch.cat((x[:, self.start_index :], readout), -1)
41
-
42
- return self.project(features)
43
-
44
-
45
- class Transpose(nn.Module):
46
- def __init__(self, dim0, dim1):
47
- super(Transpose, self).__init__()
48
- self.dim0 = dim0
49
- self.dim1 = dim1
50
-
51
- def forward(self, x):
52
- x = x.transpose(self.dim0, self.dim1)
53
- return x
54
-
55
-
56
- activations = {}
57
- def forward_vit(pretrained, x):
58
- b, c, h, w = x.shape
59
-
60
- glob = pretrained.model.forward_flex(x)
61
- pretrained.activations = activations
62
-
63
- layer_1 = pretrained.activations["1"]
64
- layer_2 = pretrained.activations["2"]
65
- layer_3 = pretrained.activations["3"]
66
- layer_4 = pretrained.activations["4"]
67
-
68
- layer_1 = pretrained.act_postprocess1[0:2](layer_1)
69
- layer_2 = pretrained.act_postprocess2[0:2](layer_2)
70
- layer_3 = pretrained.act_postprocess3[0:2](layer_3)
71
- layer_4 = pretrained.act_postprocess4[0:2](layer_4)
72
-
73
- unflatten = nn.Sequential(
74
- nn.Unflatten(
75
- 2,
76
- torch.Size(
77
- [
78
- h // pretrained.model.patch_size[1],
79
- w // pretrained.model.patch_size[0],
80
- ]
81
- ),
82
- )
83
- )
84
-
85
- if layer_1.ndim == 3:
86
- layer_1 = unflatten(layer_1)
87
- if layer_2.ndim == 3:
88
- layer_2 = unflatten(layer_2)
89
- if layer_3.ndim == 3:
90
- layer_3 = unflatten(layer_3)
91
- if layer_4.ndim == 3:
92
- layer_4 = unflatten(layer_4)
93
-
94
- layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
95
- layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
96
- layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
97
- layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
98
-
99
- return layer_1, layer_2, layer_3, layer_4
100
-
101
-
102
- def _resize_pos_embed(self, posemb, gs_h, gs_w):
103
- posemb_tok, posemb_grid = (
104
- posemb[:, : self.start_index],
105
- posemb[0, self.start_index :],
106
- )
107
-
108
- gs_old = int(math.sqrt(len(posemb_grid)))
109
-
110
- posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
111
- posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
112
- posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
113
-
114
- posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
115
-
116
- return posemb
117
-
118
-
119
- def forward_flex(self, x):
120
- b, c, h, w = x.shape
121
-
122
- pos_embed = self._resize_pos_embed(
123
- self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
124
- )
125
-
126
- B = x.shape[0]
127
-
128
- if hasattr(self.patch_embed, "backbone"):
129
- x = self.patch_embed.backbone(x)
130
- if isinstance(x, (list, tuple)):
131
- x = x[-1] # last feature if backbone outputs list/tuple of features
132
-
133
- x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
134
-
135
- if getattr(self, "dist_token", None) is not None:
136
- cls_tokens = self.cls_token.expand(
137
- B, -1, -1
138
- ) # stole cls_tokens impl from Phil Wang, thanks
139
- dist_token = self.dist_token.expand(B, -1, -1)
140
- x = torch.cat((cls_tokens, dist_token, x), dim=1)
141
- else:
142
- cls_tokens = self.cls_token.expand(
143
- B, -1, -1
144
- ) # stole cls_tokens impl from Phil Wang, thanks
145
- x = torch.cat((cls_tokens, x), dim=1)
146
-
147
- x = x + pos_embed
148
- x = self.pos_drop(x)
149
-
150
- for blk in self.blocks:
151
- x = blk(x)
152
-
153
- x = self.norm(x)
154
-
155
- return x
156
-
157
-
158
- def get_activation(name):
159
- def hook(model, input, output):
160
- activations[name] = output
161
- return hook
162
-
163
-
164
- def get_readout_oper(vit_features, features, use_readout, start_index=1):
165
- if use_readout == "ignore":
166
- readout_oper = [Slice(start_index)] * len(features)
167
- elif use_readout == "add":
168
- readout_oper = [AddReadout(start_index)] * len(features)
169
- elif use_readout == "project":
170
- readout_oper = [
171
- ProjectReadout(vit_features, start_index) for out_feat in features
172
- ]
173
- else:
174
- assert (
175
- False
176
- ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
177
-
178
- return readout_oper
179
-
180
-
181
- def _make_vit_b16_backbone(
182
- model,
183
- features=[96, 192, 384, 768],
184
- size=[384, 384],
185
- hooks=[2, 5, 8, 11],
186
- vit_features=768,
187
- use_readout="ignore",
188
- start_index=1,
189
- ):
190
- pretrained = nn.Module()
191
-
192
- pretrained.model = model
193
- pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
194
- pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
195
- pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
196
- pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
197
-
198
- pretrained.activations = activations
199
-
200
- readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
201
-
202
- # 32, 48, 136, 384
203
- pretrained.act_postprocess1 = nn.Sequential(
204
- readout_oper[0],
205
- Transpose(1, 2),
206
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
207
- nn.Conv2d(
208
- in_channels=vit_features,
209
- out_channels=features[0],
210
- kernel_size=1,
211
- stride=1,
212
- padding=0,
213
- ),
214
- nn.ConvTranspose2d(
215
- in_channels=features[0],
216
- out_channels=features[0],
217
- kernel_size=4,
218
- stride=4,
219
- padding=0,
220
- bias=True,
221
- dilation=1,
222
- groups=1,
223
- ),
224
- )
225
-
226
- pretrained.act_postprocess2 = nn.Sequential(
227
- readout_oper[1],
228
- Transpose(1, 2),
229
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
230
- nn.Conv2d(
231
- in_channels=vit_features,
232
- out_channels=features[1],
233
- kernel_size=1,
234
- stride=1,
235
- padding=0,
236
- ),
237
- nn.ConvTranspose2d(
238
- in_channels=features[1],
239
- out_channels=features[1],
240
- kernel_size=2,
241
- stride=2,
242
- padding=0,
243
- bias=True,
244
- dilation=1,
245
- groups=1,
246
- ),
247
- )
248
-
249
- pretrained.act_postprocess3 = nn.Sequential(
250
- readout_oper[2],
251
- Transpose(1, 2),
252
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
253
- nn.Conv2d(
254
- in_channels=vit_features,
255
- out_channels=features[2],
256
- kernel_size=1,
257
- stride=1,
258
- padding=0,
259
- ),
260
- )
261
-
262
- pretrained.act_postprocess4 = nn.Sequential(
263
- readout_oper[3],
264
- Transpose(1, 2),
265
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
266
- nn.Conv2d(
267
- in_channels=vit_features,
268
- out_channels=features[3],
269
- kernel_size=1,
270
- stride=1,
271
- padding=0,
272
- ),
273
- nn.Conv2d(
274
- in_channels=features[3],
275
- out_channels=features[3],
276
- kernel_size=3,
277
- stride=2,
278
- padding=1,
279
- ),
280
- )
281
-
282
- pretrained.model.start_index = start_index
283
- pretrained.model.patch_size = [16, 16]
284
-
285
- # We inject this function into the VisionTransformer instances so that
286
- # we can use it with interpolated position embeddings without modifying the library source.
287
- pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
288
- pretrained.model._resize_pos_embed = types.MethodType(
289
- _resize_pos_embed, pretrained.model
290
- )
291
-
292
- return pretrained
293
-
294
-
295
- def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
296
- model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
297
-
298
- hooks = [5, 11, 17, 23] if hooks == None else hooks
299
- return _make_vit_b16_backbone(
300
- model,
301
- features=[256, 512, 1024, 1024],
302
- hooks=hooks,
303
- vit_features=1024,
304
- use_readout=use_readout,
305
- )
306
-
307
-
308
- def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
309
- model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
310
-
311
- hooks = [2, 5, 8, 11] if hooks == None else hooks
312
- return _make_vit_b16_backbone(
313
- model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
314
- )
315
-
316
-
317
- def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
318
- model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
319
-
320
- hooks = [2, 5, 8, 11] if hooks == None else hooks
321
- return _make_vit_b16_backbone(
322
- model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
323
- )
324
-
325
-
326
- def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
327
- model = timm.create_model(
328
- "vit_deit_base_distilled_patch16_384", pretrained=pretrained
329
- )
330
-
331
- hooks = [2, 5, 8, 11] if hooks == None else hooks
332
- return _make_vit_b16_backbone(
333
- model,
334
- features=[96, 192, 384, 768],
335
- hooks=hooks,
336
- use_readout=use_readout,
337
- start_index=2,
338
- )
339
-
340
-
341
- def _make_vit_b_rn50_backbone(
342
- model,
343
- features=[256, 512, 768, 768],
344
- size=[384, 384],
345
- hooks=[0, 1, 8, 11],
346
- vit_features=768,
347
- use_vit_only=False,
348
- use_readout="ignore",
349
- start_index=1,
350
- ):
351
- pretrained = nn.Module()
352
-
353
- pretrained.model = model
354
-
355
- if use_vit_only == True:
356
- pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
357
- pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
358
- else:
359
- pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
360
- get_activation("1")
361
- )
362
- pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
363
- get_activation("2")
364
- )
365
-
366
- pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
367
- pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
368
-
369
- pretrained.activations = activations
370
-
371
- readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
372
-
373
- if use_vit_only == True:
374
- pretrained.act_postprocess1 = nn.Sequential(
375
- readout_oper[0],
376
- Transpose(1, 2),
377
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
378
- nn.Conv2d(
379
- in_channels=vit_features,
380
- out_channels=features[0],
381
- kernel_size=1,
382
- stride=1,
383
- padding=0,
384
- ),
385
- nn.ConvTranspose2d(
386
- in_channels=features[0],
387
- out_channels=features[0],
388
- kernel_size=4,
389
- stride=4,
390
- padding=0,
391
- bias=True,
392
- dilation=1,
393
- groups=1,
394
- ),
395
- )
396
-
397
- pretrained.act_postprocess2 = nn.Sequential(
398
- readout_oper[1],
399
- Transpose(1, 2),
400
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
401
- nn.Conv2d(
402
- in_channels=vit_features,
403
- out_channels=features[1],
404
- kernel_size=1,
405
- stride=1,
406
- padding=0,
407
- ),
408
- nn.ConvTranspose2d(
409
- in_channels=features[1],
410
- out_channels=features[1],
411
- kernel_size=2,
412
- stride=2,
413
- padding=0,
414
- bias=True,
415
- dilation=1,
416
- groups=1,
417
- ),
418
- )
419
- else:
420
- pretrained.act_postprocess1 = nn.Sequential(
421
- nn.Identity(), nn.Identity(), nn.Identity()
422
- )
423
- pretrained.act_postprocess2 = nn.Sequential(
424
- nn.Identity(), nn.Identity(), nn.Identity()
425
- )
426
-
427
- pretrained.act_postprocess3 = nn.Sequential(
428
- readout_oper[2],
429
- Transpose(1, 2),
430
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
431
- nn.Conv2d(
432
- in_channels=vit_features,
433
- out_channels=features[2],
434
- kernel_size=1,
435
- stride=1,
436
- padding=0,
437
- ),
438
- )
439
-
440
- pretrained.act_postprocess4 = nn.Sequential(
441
- readout_oper[3],
442
- Transpose(1, 2),
443
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
444
- nn.Conv2d(
445
- in_channels=vit_features,
446
- out_channels=features[3],
447
- kernel_size=1,
448
- stride=1,
449
- padding=0,
450
- ),
451
- nn.Conv2d(
452
- in_channels=features[3],
453
- out_channels=features[3],
454
- kernel_size=3,
455
- stride=2,
456
- padding=1,
457
- ),
458
- )
459
-
460
- pretrained.model.start_index = start_index
461
- pretrained.model.patch_size = [16, 16]
462
-
463
- # We inject this function into the VisionTransformer instances so that
464
- # we can use it with interpolated position embeddings without modifying the library source.
465
- pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
466
-
467
- # We inject this function into the VisionTransformer instances so that
468
- # we can use it with interpolated position embeddings without modifying the library source.
469
- pretrained.model._resize_pos_embed = types.MethodType(
470
- _resize_pos_embed, pretrained.model
471
- )
472
-
473
- return pretrained
474
-
475
-
476
- def _make_pretrained_vitb_rn50_384(
477
- pretrained, use_readout="ignore", hooks=None, use_vit_only=False
478
- ):
479
- model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
480
-
481
- hooks = [0, 1, 8, 11] if hooks == None else hooks
482
- return _make_vit_b_rn50_backbone(
483
- model,
484
- features=[256, 512, 768, 768],
485
- size=[384, 384],
486
- hooks=hooks,
487
- use_vit_only=use_vit_only,
488
- use_readout=use_readout,
489
- )
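For reference, the _resize_pos_embed helper removed above implements the usual trick for running a ViT at a resolution other than its training size: keep the readout token(s), reshape the patch position embeddings into a grid, and bilinearly interpolate that grid to the new token layout. A self-contained sketch of the same idea (tensor names and sizes here are illustrative, not taken from the repository):

import math
import torch
import torch.nn.functional as F

def resize_pos_embed(posemb: torch.Tensor, gs_h: int, gs_w: int, n_prefix: int = 1) -> torch.Tensor:
    """Interpolate ViT position embeddings [1, n_prefix + gs_old**2, C] to a gs_h x gs_w grid."""
    posemb_tok, posemb_grid = posemb[:, :n_prefix], posemb[0, n_prefix:]
    gs_old = int(math.sqrt(posemb_grid.shape[0]))            # original grid side length
    grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(gs_h, gs_w), mode="bilinear", align_corners=False)
    grid = grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
    return torch.cat([posemb_tok, grid], dim=1)

# e.g. a 384x384 / patch-16 ViT (24x24 grid) evaluated on a 512x384 input (32x24 grid):
pe = torch.randn(1, 1 + 24 * 24, 768)
print(resize_pos_embed(pe, 32, 24).shape)                    # torch.Size([1, 769, 768]) = 1 + 32*24 tokens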
extralibs/midas/utils.py DELETED
@@ -1,189 +0,0 @@
1
- """Utils for monoDepth."""
2
- import sys
3
- import re
4
- import numpy as np
5
- import cv2
6
- import torch
7
-
8
-
9
- def read_pfm(path):
10
- """Read pfm file.
11
-
12
- Args:
13
- path (str): path to file
14
-
15
- Returns:
16
- tuple: (data, scale)
17
- """
18
- with open(path, "rb") as file:
19
-
20
- color = None
21
- width = None
22
- height = None
23
- scale = None
24
- endian = None
25
-
26
- header = file.readline().rstrip()
27
- if header.decode("ascii") == "PF":
28
- color = True
29
- elif header.decode("ascii") == "Pf":
30
- color = False
31
- else:
32
- raise Exception("Not a PFM file: " + path)
33
-
34
- dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
35
- if dim_match:
36
- width, height = list(map(int, dim_match.groups()))
37
- else:
38
- raise Exception("Malformed PFM header.")
39
-
40
- scale = float(file.readline().decode("ascii").rstrip())
41
- if scale < 0:
42
- # little-endian
43
- endian = "<"
44
- scale = -scale
45
- else:
46
- # big-endian
47
- endian = ">"
48
-
49
- data = np.fromfile(file, endian + "f")
50
- shape = (height, width, 3) if color else (height, width)
51
-
52
- data = np.reshape(data, shape)
53
- data = np.flipud(data)
54
-
55
- return data, scale
56
-
57
-
58
- def write_pfm(path, image, scale=1):
59
- """Write pfm file.
60
-
61
- Args:
62
- path (str): pathto file
63
- image (array): data
64
- scale (int, optional): Scale. Defaults to 1.
65
- """
66
-
67
- with open(path, "wb") as file:
68
- color = None
69
-
70
- if image.dtype.name != "float32":
71
- raise Exception("Image dtype must be float32.")
72
-
73
- image = np.flipud(image)
74
-
75
- if len(image.shape) == 3 and image.shape[2] == 3: # color image
76
- color = True
77
- elif (
78
- len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
79
- ): # greyscale
80
- color = False
81
- else:
82
- raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
83
-
84
- file.write("PF\n" if color else "Pf\n".encode())
85
- file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
86
-
87
- endian = image.dtype.byteorder
88
-
89
- if endian == "<" or endian == "=" and sys.byteorder == "little":
90
- scale = -scale
91
-
92
- file.write("%f\n".encode() % scale)
93
-
94
- image.tofile(file)
95
-
96
-
97
- def read_image(path):
98
- """Read image and output RGB image (0-1).
99
-
100
- Args:
101
- path (str): path to file
102
-
103
- Returns:
104
- array: RGB image (0-1)
105
- """
106
- img = cv2.imread(path)
107
-
108
- if img.ndim == 2:
109
- img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
110
-
111
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
112
-
113
- return img
114
-
115
-
116
- def resize_image(img):
117
- """Resize image and make it fit for network.
118
-
119
- Args:
120
- img (array): image
121
-
122
- Returns:
123
- tensor: data ready for network
124
- """
125
- height_orig = img.shape[0]
126
- width_orig = img.shape[1]
127
-
128
- if width_orig > height_orig:
129
- scale = width_orig / 384
130
- else:
131
- scale = height_orig / 384
132
-
133
- height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
134
- width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
135
-
136
- img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
137
-
138
- img_resized = (
139
- torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
140
- )
141
- img_resized = img_resized.unsqueeze(0)
142
-
143
- return img_resized
144
-
145
-
146
- def resize_depth(depth, width, height):
147
- """Resize depth map and bring to CPU (numpy).
148
-
149
- Args:
150
- depth (tensor): depth
151
- width (int): image width
152
- height (int): image height
153
-
154
- Returns:
155
- array: processed depth
156
- """
157
- depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
158
-
159
- depth_resized = cv2.resize(
160
- depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
161
- )
162
-
163
- return depth_resized
164
-
165
- def write_depth(path, depth, bits=1):
166
- """Write depth map to pfm and png file.
167
-
168
- Args:
169
- path (str): filepath without extension
170
- depth (array): depth
171
- """
172
- write_pfm(path + ".pfm", depth.astype(np.float32))
173
-
174
- depth_min = depth.min()
175
- depth_max = depth.max()
176
-
177
- max_val = (2**(8*bits))-1
178
-
179
- if depth_max - depth_min > np.finfo("float").eps:
180
- out = max_val * (depth - depth_min) / (depth_max - depth_min)
181
- else:
182
- out = np.zeros(depth.shape, dtype=depth.type)
183
-
184
- if bits == 1:
185
- cv2.imwrite(path + ".png", out.astype("uint8"))
186
- elif bits == 2:
187
- cv2.imwrite(path + ".png", out.astype("uint16"))
188
-
189
- return
i2v_test.py ADDED
@@ -0,0 +1,78 @@
1
+ import os
2
+ from omegaconf import OmegaConf
3
+
4
+ import torch
5
+
6
+ from scripts.evaluation.funcs import load_model_checkpoint, load_image_batch, save_videos, batch_ddim_sampling
7
+ from utils.utils import instantiate_from_config
8
+ from huggingface_hub import hf_hub_download
9
+
10
+ class Image2Video():
11
+ def __init__(self,result_dir='./tmp/',gpu_num=1) -> None:
12
+ self.download_model()
13
+ self.result_dir = result_dir
14
+ if not os.path.exists(self.result_dir):
15
+ os.mkdir(self.result_dir)
16
+ ckpt_path='checkpoints/i2v_512_v1/model.ckpt'
17
+ config_file='configs/inference_i2v_512_v1.0.yaml'
18
+ config = OmegaConf.load(config_file)
19
+ model_config = config.pop("model", OmegaConf.create())
20
+ model_config['params']['unet_config']['params']['use_checkpoint']=False
21
+ model_list = []
22
+ for gpu_id in range(gpu_num):
23
+ model = instantiate_from_config(model_config)
24
+ model = model.cuda(gpu_id)
25
+ assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
26
+ model = load_model_checkpoint(model, ckpt_path)
27
+ model.eval()
28
+ model_list.append(model)
29
+ self.model_list = model_list
30
+ self.save_fps = 8
31
+
32
+ def get_image(self, image, prompt, steps=50, cfg_scale=12.0, eta=1.0, fps=16):
33
+ gpu_id=0
34
+ if steps > 60:
35
+ steps = 60
36
+ model = self.model_list[gpu_id]
37
+ batch_size=1
38
+ channels = model.model.diffusion_model.in_channels
39
+ frames = model.temporal_length
40
+ h, w = 320 // 8, 512 // 8
41
+ noise_shape = [batch_size, channels, frames, h, w]
42
+
43
+ #prompts = batch_size * [""]
44
+ text_emb = model.get_learned_conditioning([prompt])
45
+
46
+ # cond_images = load_image_batch([image_path])
47
+ img_tensor = torch.from_numpy(image).permute(2, 0, 1).float()
48
+ img_tensor = (img_tensor / 255. - 0.5) * 2
49
+ img_tensor = img_tensor.unsqueeze(0)
50
+ cond_images = img_tensor.to(model.device)
51
+ img_emb = model.get_image_embeds(cond_images)
52
+ imtext_cond = torch.cat([text_emb, img_emb], dim=1)
53
+ cond = {"c_crossattn": [imtext_cond], "fps": fps}
54
+
55
+ ## inference
56
+ batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
57
+ ## b,samples,c,t,h,w
58
+ prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
59
+ prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str
60
+ prompt_str=prompt_str[:30]
61
+
62
+ save_videos(batch_samples, self.result_dir, filenames=[prompt_str], fps=self.save_fps)
63
+ return os.path.join(self.result_dir, f"{prompt_str}.mp4")
64
+
65
+ def download_model(self):
66
+ REPO_ID = 'VideoCrafter/Image2Video-512-v1.0'
67
+ filename_list = ['model.ckpt']
68
+ if not os.path.exists('./checkpoints/i2v_512_v1/'):
69
+ os.makedirs('./checkpoints/i2v_512_v1/')
70
+ for filename in filename_list:
71
+ local_file = os.path.join('./checkpoints/i2v_512_v1/', filename)
72
+ if not os.path.exists(local_file):
73
+ hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/i2v_512_v1/', local_dir_use_symlinks=False)
74
+
75
+ if __name__ == '__main__':
76
+ i2v = Image2Video()
77
+ video_path = i2v.get_image('prompts/i2v_prompts/horse.png','horses are walking on the grassland')
78
+ print('done', video_path)
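The new Image2Video.get_image expects the conditioning image as an H x W x 3 uint8 array and rescales it to [-1, 1] before computing the image embedding. A minimal sketch of just that preprocessing step, using a synthetic array so it runs on its own (preprocess_image is an illustrative name, not a function from the repository):

import numpy as np
import torch

def preprocess_image(image: np.ndarray) -> torch.Tensor:
    """uint8 H x W x 3 array -> [-1, 1] float tensor of shape [1, 3, H, W] (same scaling as get_image)."""
    t = torch.from_numpy(image).permute(2, 0, 1).float()     # -> 3 x H x W
    t = (t / 255.0 - 0.5) * 2.0                              # 0..255 -> [-1, 1]
    return t.unsqueeze(0)                                    # add batch dimension

dummy = np.random.randint(0, 256, (320, 512, 3), dtype=np.uint8)   # stands in for a real input frame
x = preprocess_image(dummy)
print(x.shape)                                               # torch.Size([1, 3, 320, 512])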
input/flamingo.mp4 DELETED
Binary file (897 kB)
 
input/prompts.txt DELETED
@@ -1,2 +0,0 @@
1
- astronaut riding a horse
2
- Flying through an intense battle between pirate ships in a stormy ocean
 
 
 
lvdm/basics.py ADDED
@@ -0,0 +1,100 @@
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+ import torch.nn as nn
11
+ from utils.utils import instantiate_from_config
12
+
13
+
14
+ def disabled_train(self, mode=True):
15
+ """Overwrite model.train with this function to make sure train/eval mode
16
+ does not change anymore."""
17
+ return self
18
+
19
+ def zero_module(module):
20
+ """
21
+ Zero out the parameters of a module and return it.
22
+ """
23
+ for p in module.parameters():
24
+ p.detach().zero_()
25
+ return module
26
+
27
+ def scale_module(module, scale):
28
+ """
29
+ Scale the parameters of a module and return it.
30
+ """
31
+ for p in module.parameters():
32
+ p.detach().mul_(scale)
33
+ return module
34
+
35
+
36
+ def conv_nd(dims, *args, **kwargs):
37
+ """
38
+ Create a 1D, 2D, or 3D convolution module.
39
+ """
40
+ if dims == 1:
41
+ return nn.Conv1d(*args, **kwargs)
42
+ elif dims == 2:
43
+ return nn.Conv2d(*args, **kwargs)
44
+ elif dims == 3:
45
+ return nn.Conv3d(*args, **kwargs)
46
+ raise ValueError(f"unsupported dimensions: {dims}")
47
+
48
+
49
+ def linear(*args, **kwargs):
50
+ """
51
+ Create a linear module.
52
+ """
53
+ return nn.Linear(*args, **kwargs)
54
+
55
+
56
+ def avg_pool_nd(dims, *args, **kwargs):
57
+ """
58
+ Create a 1D, 2D, or 3D average pooling module.
59
+ """
60
+ if dims == 1:
61
+ return nn.AvgPool1d(*args, **kwargs)
62
+ elif dims == 2:
63
+ return nn.AvgPool2d(*args, **kwargs)
64
+ elif dims == 3:
65
+ return nn.AvgPool3d(*args, **kwargs)
66
+ raise ValueError(f"unsupported dimensions: {dims}")
67
+
68
+
69
+ def nonlinearity(type='silu'):
70
+ if type == 'silu':
71
+ return nn.SiLU()
72
+ elif type == 'leaky_relu':
73
+ return nn.LeakyReLU()
74
+
75
+
76
+ class GroupNormSpecific(nn.GroupNorm):
77
+ def forward(self, x):
78
+ return super().forward(x.float()).type(x.dtype)
79
+
80
+
81
+ def normalization(channels, num_groups=32):
82
+ """
83
+ Make a standard normalization layer.
84
+ :param channels: number of input channels.
85
+ :return: an nn.Module for normalization.
86
+ """
87
+ return GroupNormSpecific(num_groups, channels)
88
+
89
+
90
+ class HybridConditioner(nn.Module):
91
+
92
+ def __init__(self, c_concat_config, c_crossattn_config):
93
+ super().__init__()
94
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
95
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
96
+
97
+ def forward(self, c_concat, c_crossattn):
98
+ c_concat = self.concat_conditioner(c_concat)
99
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
100
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
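The helpers in the new lvdm/basics.py are the small building blocks the UNet assembles into larger units. As a rough illustration of how they compose, the sketch below wires a norm -> SiLU -> 3D-conv unit together; the helper bodies are restated inline so the snippet runs outside the repository, and the channel/frame sizes are arbitrary:

import torch
import torch.nn as nn

# Inline stand-ins for lvdm.basics.conv_nd / normalization, so this sketch is self-contained.
def conv_nd(dims, *args, **kwargs):
    return {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}[dims](*args, **kwargs)

def normalization(channels, num_groups=32):
    return nn.GroupNorm(num_groups, channels)

block = nn.Sequential(                        # a typical norm -> SiLU -> conv unit inside a UNet block
    normalization(64),
    nn.SiLU(),
    conv_nd(3, 64, 64, kernel_size=3, padding=1),   # dims=3 -> Conv3d over (frames, height, width)
)
x = torch.randn(1, 64, 16, 32, 32)            # batch, channels, frames, height, width
print(block(x).shape)                         # torch.Size([1, 64, 16, 32, 32])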
lvdm/common.py ADDED
@@ -0,0 +1,95 @@
1
+ import math
2
+ from inspect import isfunction
3
+ import torch
4
+ from torch import nn
5
+ import torch.distributed as dist
6
+
7
+
8
+ def gather_data(data, return_np=True):
9
+ ''' gather data from multiple processes to one list '''
10
+ data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())]
11
+ dist.all_gather(data_list, data) # gather not supported with NCCL
12
+ if return_np:
13
+ data_list = [data.cpu().numpy() for data in data_list]
14
+ return data_list
15
+
16
+ def autocast(f):
17
+ def do_autocast(*args, **kwargs):
18
+ with torch.cuda.amp.autocast(enabled=True,
19
+ dtype=torch.get_autocast_gpu_dtype(),
20
+ cache_enabled=torch.is_autocast_cache_enabled()):
21
+ return f(*args, **kwargs)
22
+ return do_autocast
23
+
24
+
25
+ def extract_into_tensor(a, t, x_shape):
26
+ b, *_ = t.shape
27
+ out = a.gather(-1, t)
28
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
29
+
30
+
31
+ def noise_like(shape, device, repeat=False):
32
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
33
+ noise = lambda: torch.randn(shape, device=device)
34
+ return repeat_noise() if repeat else noise()
35
+
36
+
37
+ def default(val, d):
38
+ if exists(val):
39
+ return val
40
+ return d() if isfunction(d) else d
41
+
42
+ def exists(val):
43
+ return val is not None
44
+
45
+ def identity(*args, **kwargs):
46
+ return nn.Identity()
47
+
48
+ def uniq(arr):
49
+ return{el: True for el in arr}.keys()
50
+
51
+ def mean_flat(tensor):
52
+ """
53
+ Take the mean over all non-batch dimensions.
54
+ """
55
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
56
+
57
+ def ismap(x):
58
+ if not isinstance(x, torch.Tensor):
59
+ return False
60
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
61
+
62
+ def isimage(x):
63
+ if not isinstance(x,torch.Tensor):
64
+ return False
65
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
66
+
67
+ def max_neg_value(t):
68
+ return -torch.finfo(t.dtype).max
69
+
70
+ def shape_to_str(x):
71
+ shape_str = "x".join([str(x) for x in x.shape])
72
+ return shape_str
73
+
74
+ def init_(tensor):
75
+ dim = tensor.shape[-1]
76
+ std = 1 / math.sqrt(dim)
77
+ tensor.uniform_(-std, std)
78
+ return tensor
79
+
80
+ ckpt = torch.utils.checkpoint.checkpoint
81
+ def checkpoint(func, inputs, params, flag):
82
+ """
83
+ Evaluate a function without caching intermediate activations, allowing for
84
+ reduced memory at the expense of extra compute in the backward pass.
85
+ :param func: the function to evaluate.
86
+ :param inputs: the argument sequence to pass to `func`.
87
+ :param params: a sequence of parameters `func` depends on but does not
88
+ explicitly take as arguments.
89
+ :param flag: if False, disable gradient checkpointing.
90
+ """
91
+ if flag:
92
+ return ckpt(func, *inputs)
93
+ else:
94
+ return func(*inputs)
95
+
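extract_into_tensor in the new lvdm/common.py gathers one scalar per sample from a schedule buffer and reshapes it for broadcasting against a video latent, which is how DDPM.q_sample applies the forward diffusion. A standalone sketch using a generic linear beta schedule (the schedule values here are illustrative, not the checkpoint's):

import torch

def extract_into_tensor(a, t, x_shape):
    # Gather per-timestep scalars and reshape for broadcasting against x (same logic as lvdm.common).
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

T = 1000
betas = torch.linspace(1e-4, 2e-2, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

x0 = torch.randn(2, 4, 16, 40, 64)            # latent video: b, c, t, h, w
t = torch.randint(0, T, (2,))
noise = torch.randn_like(x0)

# forward diffusion q(x_t | x_0), the same formula q_sample applies with its registered buffers
xt = (extract_into_tensor(alphas_cumprod.sqrt(), t, x0.shape) * x0
      + extract_into_tensor((1.0 - alphas_cumprod).sqrt(), t, x0.shape) * noise)
print(xt.shape)                               # torch.Size([2, 4, 16, 40, 64])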
lvdm/data/webvid.py DELETED
@@ -1,188 +0,0 @@
1
- import os
2
- import random
3
- import bisect
4
-
5
- import pandas as pd
6
-
7
- import omegaconf
8
- import torch
9
- from torch.utils.data import Dataset
10
- from torchvision import transforms
11
- from decord import VideoReader, cpu
12
- import torchvision.transforms._transforms_video as transforms_video
13
-
14
- class WebVid(Dataset):
15
- """
16
- WebVid Dataset.
17
- Assumes webvid data is structured as follows.
18
- Webvid/
19
- videos/
20
- 000001_000050/ ($page_dir)
21
- 1.mp4 (videoid.mp4)
22
- ...
23
- 5000.mp4
24
- ...
25
- """
26
- def __init__(self,
27
- meta_path,
28
- data_dir,
29
- subsample=None,
30
- video_length=16,
31
- resolution=[256, 512],
32
- frame_stride=1,
33
- spatial_transform=None,
34
- crop_resolution=None,
35
- fps_max=None,
36
- load_raw_resolution=False,
37
- fps_schedule=None,
38
- fs_probs=None,
39
- bs_per_gpu=None,
40
- trigger_word='',
41
- dataname='',
42
- ):
43
- self.meta_path = meta_path
44
- self.data_dir = data_dir
45
- self.subsample = subsample
46
- self.video_length = video_length
47
- self.resolution = [resolution, resolution] if isinstance(resolution, int) else resolution
48
- self.frame_stride = frame_stride
49
- self.fps_max = fps_max
50
- self.load_raw_resolution = load_raw_resolution
51
- self.fs_probs = fs_probs
52
- self.trigger_word = trigger_word
53
- self.dataname = dataname
54
-
55
- self._load_metadata()
56
- if spatial_transform is not None:
57
- if spatial_transform == "random_crop":
58
- self.spatial_transform = transforms_video.RandomCropVideo(crop_resolution)
59
- elif spatial_transform == "resize_center_crop":
60
- assert(self.resolution[0] == self.resolution[1])
61
- self.spatial_transform = transforms.Compose([
62
- transforms.Resize(resolution),
63
- transforms_video.CenterCropVideo(resolution),
64
- ])
65
- else:
66
- raise NotImplementedError
67
- else:
68
- self.spatial_transform = None
69
-
70
- self.fps_schedule = fps_schedule
71
- self.bs_per_gpu = bs_per_gpu
72
- if self.fps_schedule is not None:
73
- assert(self.bs_per_gpu is not None)
74
- self.counter = 0
75
- self.stage_idx = 0
76
-
77
- def _load_metadata(self):
78
- metadata = pd.read_csv(self.meta_path)
79
- if self.subsample is not None:
80
- metadata = metadata.sample(self.subsample, random_state=0)
81
- metadata['caption'] = metadata['name']
82
- del metadata['name']
83
- self.metadata = metadata
84
- self.metadata.dropna(inplace=True)
85
- # self.metadata['caption'] = self.metadata['caption'].str[:350]
86
-
87
- def _get_video_path(self, sample):
88
- if self.dataname == "loradata":
89
- rel_video_fp = str(sample['videoid']) + '.mp4'
90
- full_video_fp = os.path.join(self.data_dir, rel_video_fp)
91
- else:
92
- rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4')
93
- full_video_fp = os.path.join(self.data_dir, 'videos', rel_video_fp)
94
- return full_video_fp, rel_video_fp
95
-
96
- def get_fs_based_on_schedule(self, frame_strides, schedule):
97
- assert(len(frame_strides) == len(schedule) + 1) # nstage=len_fps_schedule + 1
98
- global_step = self.counter // self.bs_per_gpu # TODO: support resume.
99
- stage_idx = bisect.bisect(schedule, global_step)
100
- frame_stride = frame_strides[stage_idx]
101
- # log stage change
102
- if stage_idx != self.stage_idx:
103
- print(f'fps stage: {stage_idx} start ... new frame stride = {frame_stride}')
104
- self.stage_idx = stage_idx
105
- return frame_stride
106
-
107
- def get_fs_based_on_probs(self, frame_strides, probs):
108
- assert(len(frame_strides) == len(probs))
109
- return random.choices(frame_strides, weights=probs)[0]
110
-
111
- def get_fs_randomly(self, frame_strides):
112
- return random.choice(frame_strides)
113
-
114
- def __getitem__(self, index):
115
-
116
- if isinstance(self.frame_stride, list) or isinstance(self.frame_stride, omegaconf.listconfig.ListConfig):
117
- if self.fps_schedule is not None:
118
- frame_stride = self.get_fs_based_on_schedule(self.frame_stride, self.fps_schedule)
119
- elif self.fs_probs is not None:
120
- frame_stride = self.get_fs_based_on_probs(self.frame_stride, self.fs_probs)
121
- else:
122
- frame_stride = self.get_fs_randomly(self.frame_stride)
123
- else:
124
- frame_stride = self.frame_stride
125
- assert(isinstance(frame_stride, int)), type(frame_stride)
126
-
127
- while True:
128
- index = index % len(self.metadata)
129
- sample = self.metadata.iloc[index]
130
- video_path, rel_fp = self._get_video_path(sample)
131
- caption = sample['caption']+self.trigger_word
132
-
133
- # make reader
134
- try:
135
- if self.load_raw_resolution:
136
- video_reader = VideoReader(video_path, ctx=cpu(0))
137
- else:
138
- video_reader = VideoReader(video_path, ctx=cpu(0), width=self.resolution[1], height=self.resolution[0])
139
- if len(video_reader) < self.video_length:
140
- print(f"video length ({len(video_reader)}) is smaller than target length({self.video_length})")
141
- index += 1
142
- continue
143
- else:
144
- pass
145
- except:
146
- index += 1
147
- print(f"Load video failed! path = {video_path}")
148
- continue
149
-
150
- # sample strided frames
151
- all_frames = list(range(0, len(video_reader), frame_stride))
152
- if len(all_frames) < self.video_length: # recal a max fs
153
- frame_stride = len(video_reader) // self.video_length
154
- assert(frame_stride != 0)
155
- all_frames = list(range(0, len(video_reader), frame_stride))
156
-
157
- # select a random clip
158
- rand_idx = random.randint(0, len(all_frames) - self.video_length)
159
- frame_indices = all_frames[rand_idx:rand_idx+self.video_length]
160
- try:
161
- frames = video_reader.get_batch(frame_indices)
162
- break
163
- except:
164
- print(f"Get frames failed! path = {video_path}")
165
- index += 1
166
- continue
167
-
168
- assert(frames.shape[0] == self.video_length),f'{len(frames)}, self.video_length={self.video_length}'
169
- frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float() # [t,h,w,c] -> [c,t,h,w]
170
- if self.spatial_transform is not None:
171
- frames = self.spatial_transform(frames)
172
- if self.resolution is not None:
173
- assert(frames.shape[2] == self.resolution[0] and frames.shape[3] == self.resolution[1]), f'frames={frames.shape}, self.resolution={self.resolution}'
174
- frames = (frames / 255 - 0.5) * 2
175
-
176
- fps_ori = video_reader.get_avg_fps()
177
- fps_clip = fps_ori // frame_stride
178
- if self.fps_max is not None and fps_clip > self.fps_max:
179
- fps_clip = self.fps_max
180
-
181
- data = {'video': frames, 'caption': caption, 'path': video_path, 'fps': fps_clip, 'frame_stride': frame_stride}
182
-
183
- if self.fps_schedule is not None:
184
- self.counter += 1
185
- return data
186
-
187
- def __len__(self):
188
- return len(self.metadata)
lvdm/{models/modules/distributions.py → distributions.py} RENAMED
@@ -2,6 +2,25 @@ import torch
2
  import numpy as np
3
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  class DiagonalGaussianDistribution(object):
6
  def __init__(self, parameters, deterministic=False):
7
  self.parameters = parameters
 
2
  import numpy as np
3
 
4
 
5
+ class AbstractDistribution:
6
+ def sample(self):
7
+ raise NotImplementedError()
8
+
9
+ def mode(self):
10
+ raise NotImplementedError()
11
+
12
+
13
+ class DiracDistribution(AbstractDistribution):
14
+ def __init__(self, value):
15
+ self.value = value
16
+
17
+ def sample(self):
18
+ return self.value
19
+
20
+ def mode(self):
21
+ return self.value
22
+
23
+
24
  class DiagonalGaussianDistribution(object):
25
  def __init__(self, parameters, deterministic=False):
26
  self.parameters = parameters
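Only the head of DiagonalGaussianDistribution is visible in this hunk; the class follows the standard KL-autoencoder formulation, in which the encoder output is split into mean and log-variance, samples are drawn with the reparameterization trick, and the KL to a unit Gaussian regularizes the latent. A sketch of that standard formulation (shapes are illustrative; this is not the class body from the diff):

import torch

# Encoder "moments": 2 * z_channels along dim 1, split into mean and log-variance.
params = torch.randn(2, 8, 32, 32)
mean, logvar = torch.chunk(params, 2, dim=1)
logvar = torch.clamp(logvar, -30.0, 20.0)
std = torch.exp(0.5 * logvar)

z = mean + std * torch.randn_like(mean)       # sample(): reparameterization, differentiable w.r.t. mean/std
kl = 0.5 * torch.sum(mean.pow(2) + logvar.exp() - 1.0 - logvar, dim=[1, 2, 3])   # KL(N(mean, std^2) || N(0, 1))
print(z.shape, kl.shape)                      # torch.Size([2, 4, 32, 32]) torch.Size([2])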
lvdm/ema.py ADDED
@@ -0,0 +1,76 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class LitEma(nn.Module):
6
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
7
+ super().__init__()
8
+ if decay < 0.0 or decay > 1.0:
9
+ raise ValueError('Decay must be between 0 and 1')
10
+
11
+ self.m_name2s_name = {}
12
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13
+ self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14
+ else torch.tensor(-1,dtype=torch.int))
15
+
16
+ for name, p in model.named_parameters():
17
+ if p.requires_grad:
18
+ #remove as '.'-character is not allowed in buffers
19
+ s_name = name.replace('.','')
20
+ self.m_name2s_name.update({name:s_name})
21
+ self.register_buffer(s_name,p.clone().detach().data)
22
+
23
+ self.collected_params = []
24
+
25
+ def forward(self,model):
26
+ decay = self.decay
27
+
28
+ if self.num_updates >= 0:
29
+ self.num_updates += 1
30
+ decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31
+
32
+ one_minus_decay = 1.0 - decay
33
+
34
+ with torch.no_grad():
35
+ m_param = dict(model.named_parameters())
36
+ shadow_params = dict(self.named_buffers())
37
+
38
+ for key in m_param:
39
+ if m_param[key].requires_grad:
40
+ sname = self.m_name2s_name[key]
41
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43
+ else:
44
+ assert not key in self.m_name2s_name
45
+
46
+ def copy_to(self, model):
47
+ m_param = dict(model.named_parameters())
48
+ shadow_params = dict(self.named_buffers())
49
+ for key in m_param:
50
+ if m_param[key].requires_grad:
51
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52
+ else:
53
+ assert not key in self.m_name2s_name
54
+
55
+ def store(self, parameters):
56
+ """
57
+ Save the current parameters for restoring later.
58
+ Args:
59
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60
+ temporarily stored.
61
+ """
62
+ self.collected_params = [param.clone() for param in parameters]
63
+
64
+ def restore(self, parameters):
65
+ """
66
+ Restore the parameters stored with the `store` method.
67
+ Useful to validate the model with EMA parameters without affecting the
68
+ original optimization process. Store the parameters before the
69
+ `copy_to` method. After validation (or model saving), use this to
70
+ restore the former parameters.
71
+ Args:
72
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73
+ updated with the stored parameters.
74
+ """
75
+ for c_param, param in zip(self.collected_params, parameters):
76
+ param.data.copy_(c_param.data)
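LitEma follows the usual shadow-parameter pattern: update the shadow copy after every optimizer step, temporarily swap it in for evaluation, then restore the raw weights. The sketch below uses a tiny stand-in class so it runs without the repository; SimpleEma is illustrative and not part of lvdm/ema.py:

import torch
from torch import nn

class SimpleEma:
    """Minimal exponential-moving-average of model parameters (stand-in for LitEma)."""
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {k: p.detach().clone() for k, p in model.named_parameters()}
    def update(self, model):
        # shadow = decay * shadow + (1 - decay) * param
        for k, p in model.named_parameters():
            self.shadow[k].mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)
    def copy_to(self, model):
        self.backup = {k: p.detach().clone() for k, p in model.named_parameters()}
        for k, p in model.named_parameters():
            p.data.copy_(self.shadow[k])
    def restore(self, model):
        for k, p in model.named_parameters():
            p.data.copy_(self.backup[k])

model = nn.Linear(4, 4)
ema = SimpleEma(model)
for _ in range(10):                            # pretend training steps
    model.weight.data += 0.01
    ema.update(model)                          # analogous to on_train_batch_end in the Lightning module
ema.copy_to(model)                             # evaluate with EMA weights ...
ema.restore(model)                             # ... then restore the raw training weights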
lvdm/models/autoencoder.py CHANGED
@@ -1,12 +1,14 @@
1
- import torch
2
- import pytorch_lightning as pl
3
- import torch.nn.functional as F
4
  import os
 
 
 
5
  from einops import rearrange
 
 
 
 
 
6
 
7
- from lvdm.models.modules.autoencoder_modules import Encoder, Decoder
8
- from lvdm.models.modules.distributions import DiagonalGaussianDistribution
9
- from lvdm.utils.common_utils import instantiate_from_config
10
 
11
  class AutoencoderKL(pl.LightningModule):
12
  def __init__(self,
@@ -69,12 +71,12 @@ class AutoencoderKL(pl.LightningModule):
69
  if self.test_args.save_input:
70
  os.makedirs(self.root_inputs, exist_ok=True)
71
  assert(self.test_args is not None)
72
- self.test_maximum = getattr(self.test_args, 'test_maximum', None) #1500 # 12000/8
73
  self.count = 0
74
  self.eval_metrics = {}
75
  self.decodes = []
76
  self.save_decode_samples = 2048
77
-
78
  def init_from_ckpt(self, path, ignore_keys=list()):
79
  sd = torch.load(path, map_location="cpu")
80
  try:
@@ -115,10 +117,6 @@ class AutoencoderKL(pl.LightningModule):
115
 
116
  def get_input(self, batch, k):
117
  x = batch[k]
118
- # if len(x.shape) == 3:
119
- # x = x[..., None]
120
- # if x.dim() == 4:
121
- # x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
122
  if x.dim() == 5 and self.input_dim == 4:
123
  b,c,t,h,w = x.shape
124
  self.b = b
@@ -200,3 +198,22 @@ class AutoencoderKL(pl.LightningModule):
200
  x = F.conv2d(x, weight=self.colorize)
201
  x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
202
  return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ from contextlib import contextmanager
3
+ import torch
4
+ import numpy as np
5
  from einops import rearrange
6
+ import torch.nn.functional as F
7
+ import pytorch_lightning as pl
8
+ from lvdm.modules.networks.ae_modules import Encoder, Decoder
9
+ from lvdm.distributions import DiagonalGaussianDistribution
10
+ from utils.utils import instantiate_from_config
11
 
 
 
 
12
 
13
  class AutoencoderKL(pl.LightningModule):
14
  def __init__(self,
 
71
  if self.test_args.save_input:
72
  os.makedirs(self.root_inputs, exist_ok=True)
73
  assert(self.test_args is not None)
74
+ self.test_maximum = getattr(self.test_args, 'test_maximum', None)
75
  self.count = 0
76
  self.eval_metrics = {}
77
  self.decodes = []
78
  self.save_decode_samples = 2048
79
+
80
  def init_from_ckpt(self, path, ignore_keys=list()):
81
  sd = torch.load(path, map_location="cpu")
82
  try:
 
117
 
118
  def get_input(self, batch, k):
119
  x = batch[k]
 
 
 
 
120
  if x.dim() == 5 and self.input_dim == 4:
121
  b,c,t,h,w = x.shape
122
  self.b = b
 
198
  x = F.conv2d(x, weight=self.colorize)
199
  x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
200
  return x
201
+
202
+ class IdentityFirstStage(torch.nn.Module):
203
+ def __init__(self, *args, vq_interface=False, **kwargs):
204
+ self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
205
+ super().__init__()
206
+
207
+ def encode(self, x, *args, **kwargs):
208
+ return x
209
+
210
+ def decode(self, x, *args, **kwargs):
211
+ return x
212
+
213
+ def quantize(self, x, *args, **kwargs):
214
+ if self.vq_interface:
215
+ return x, None, [None, None, None]
216
+ return x
217
+
218
+ def forward(self, x, *args, **kwargs):
219
+ return x
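The IdentityFirstStage added above is a no-op stand-in for the KL autoencoder (encode and decode return their input), so the same latent-diffusion code path can run directly in pixel space. For the real first stage, the latent resolution follows from the number of downsampling stages (num_downs = len(ch_mult) - 1 in LatentDiffusion below); a quick check of the factor-8 case that matches h, w = 320 // 8, 512 // 8 in i2v_test.py (the ch_mult default here is an assumption, not read from the configs):

def latent_hw(height, width, ch_mult=(1, 2, 4, 4)):
    """Spatial size of the first-stage latents: each downsampling stage halves H and W."""
    num_downs = len(ch_mult) - 1               # mirrors LatentDiffusion.num_downs
    factor = 2 ** num_downs
    return height // factor, width // factor

print(latent_hw(320, 512))                     # (40, 64) -- i.e. 320 // 8, 512 // 8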
lvdm/models/ddpm3d.py CHANGED
@@ -1,54 +1,42 @@
1
- import os
2
- import time
3
- import random
4
- import itertools
 
 
 
 
5
  from functools import partial
6
  from contextlib import contextmanager
7
-
8
  import numpy as np
9
  from tqdm import tqdm
10
  from einops import rearrange, repeat
11
-
 
12
  import torch
13
  import torch.nn as nn
14
- import pytorch_lightning as pl
15
  from torchvision.utils import make_grid
16
- from torch.optim.lr_scheduler import LambdaLR
17
- from pytorch_lightning.utilities import rank_zero_only
18
- from lvdm.models.modules.distributions import normal_kl, DiagonalGaussianDistribution
19
- from lvdm.models.modules.util import make_beta_schedule, extract_into_tensor, noise_like
20
- from lvdm.models.modules.lora import inject_trainable_lora
21
- from lvdm.samplers.ddim import DDIMSampler
22
- from lvdm.utils.common_utils import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config, check_istarget
23
-
24
-
25
- def disabled_train(self, mode=True):
26
- """Overwrite model.train with this function to make sure train/eval mode
27
- does not change anymore."""
28
- return self
29
-
30
-
31
- def uniform_on_device(r1, r2, shape, device):
32
- return (r1 - r2) * torch.rand(*shape, device=device) + r2
33
-
34
-
35
- def split_video_to_clips(video, clip_length, drop_left=True):
36
- video_length = video.shape[2]
37
- shape = video.shape
38
- if video_length % clip_length != 0 and drop_left:
39
- video = video[:, :, :video_length // clip_length * clip_length, :, :]
40
- print(f'[split_video_to_clips] Drop frames from {shape} to {video.shape}')
41
- nclips = video_length // clip_length
42
- clips = rearrange(video, 'b c (nc cl) h w -> (b nc) c cl h w', cl=clip_length, nc=nclips)
43
- return clips
44
-
45
- def merge_clips_to_videos(clips, bs):
46
- nclips = clips.shape[0] // bs
47
- video = rearrange(clips, '(b nc) c t h w -> b c (nc t) h w', nc=nclips)
48
- return video
49
 
50
  class DDPM(pl.LightningModule):
51
- # classic DDPM with Gaussian diffusion, in pixel space
52
  def __init__(self,
53
  unet_config,
54
  timesteps=1000,
@@ -57,11 +45,10 @@ class DDPM(pl.LightningModule):
57
  ckpt_path=None,
58
  ignore_keys=[],
59
  load_only_unet=False,
60
- monitor="val/loss",
61
  use_ema=True,
62
  first_stage_key="image",
63
  image_size=256,
64
- video_length=None,
65
  channels=3,
66
  log_every_t=100,
67
  clip_denoised=True,
@@ -70,35 +57,35 @@ class DDPM(pl.LightningModule):
70
  cosine_s=8e-3,
71
  given_betas=None,
72
  original_elbo_weight=0.,
73
- v_posterior=0.,
74
  l_simple_weight=1.,
75
  conditioning_key=None,
76
- parameterization="eps",
77
  scheduler_config=None,
 
78
  learn_logvar=False,
79
- logvar_init=0.,
80
- *args, **kwargs
81
  ):
82
  super().__init__()
83
  assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
84
  self.parameterization = parameterization
85
- print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
86
  self.cond_stage_model = None
87
  self.clip_denoised = clip_denoised
88
  self.log_every_t = log_every_t
89
  self.first_stage_key = first_stage_key
90
- self.image_size = image_size # try conv?
91
-
 
92
  if isinstance(self.image_size, int):
93
  self.image_size = [self.image_size, self.image_size]
94
- self.channels = channels
95
  self.model = DiffusionWrapper(unet_config, conditioning_key)
96
- self.conditioning_key = conditioning_key # also register conditioning_key in diffusion
97
-
98
- self.temporal_length = video_length if video_length is not None else unet_config.params.temporal_length
99
- count_params(self.model, verbose=True)
100
  self.use_ema = use_ema
101
-
 
 
 
102
  self.use_scheduler = scheduler_config is not None
103
  if self.use_scheduler:
104
  self.scheduler_config = scheduler_config
@@ -122,6 +109,7 @@ class DDPM(pl.LightningModule):
122
  if self.learn_logvar:
123
  self.logvar = nn.Parameter(self.logvar, requires_grad=True)
124
 
 
125
  def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
126
  linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
127
  if exists(given_betas):
@@ -182,14 +170,14 @@ class DDPM(pl.LightningModule):
182
  self.model_ema.store(self.model.parameters())
183
  self.model_ema.copy_to(self.model)
184
  if context is not None:
185
- print(f"{context}: Switched to EMA weights")
186
  try:
187
  yield None
188
  finally:
189
  if self.use_ema:
190
  self.model_ema.restore(self.model.parameters())
191
  if context is not None:
192
- print(f"{context}: Restored training weights")
193
 
194
  def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
195
  sd = torch.load(path, map_location="cpu")
@@ -198,16 +186,16 @@ class DDPM(pl.LightningModule):
198
  keys = list(sd.keys())
199
  for k in keys:
200
  for ik in ignore_keys:
201
- if k.startswith(ik) or (ik.startswith('**') and ik.split('**')[-1] in k):
202
- print("Deleting key {} from state_dict.".format(k))
203
  del sd[k]
204
  missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
205
  sd, strict=False)
206
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
207
  if len(missing) > 0:
208
- print(f"Missing Keys: {missing}")
209
  if len(unexpected) > 0:
210
- print(f"Unexpected Keys: {unexpected}")
211
 
212
  def q_mean_variance(self, x_start, t):
213
  """
@@ -274,196 +262,51 @@ class DDPM(pl.LightningModule):
274
 
275
  @torch.no_grad()
276
  def sample(self, batch_size=16, return_intermediates=False):
 
277
  channels = self.channels
278
- video_length = self.total_length
279
- size = (batch_size, channels, video_length, *self.image_size)
280
- return self.p_sample_loop(size,
281
  return_intermediates=return_intermediates)
282
 
283
  def q_sample(self, x_start, t, noise=None):
284
  noise = default(noise, lambda: torch.randn_like(x_start))
285
- return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
 
286
  extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
287
 
288
- def get_loss(self, pred, target, mean=True, mask=None):
289
- if self.loss_type == 'l1':
290
- loss = (target - pred).abs()
291
- if mean:
292
- loss = loss.mean()
293
- elif self.loss_type == 'l2':
294
- if mean:
295
- loss = torch.nn.functional.mse_loss(target, pred)
296
- else:
297
- loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
298
- else:
299
- raise NotImplementedError("unknown loss type '{loss_type}'")
300
- if mask is not None:
301
- assert(mean is False)
302
- assert(loss.shape[2:] == mask.shape[2:]) #thw need be the same
303
- loss = loss * mask
304
- return loss
305
-
306
- def p_losses(self, x_start, t, noise=None):
307
- noise = default(noise, lambda: torch.randn_like(x_start))
308
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
309
- model_out = self.model(x_noisy, t)
310
-
311
- loss_dict = {}
312
- if self.parameterization == "eps":
313
- target = noise
314
- elif self.parameterization == "x0":
315
- target = x_start
316
- else:
317
- raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
318
-
319
- loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3, 4])
320
-
321
- log_prefix = 'train' if self.training else 'val'
322
-
323
- loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
324
- loss_simple = loss.mean() * self.l_simple_weight
325
-
326
- loss_vlb = (self.lvlb_weights[t] * loss).mean()
327
- loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
328
-
329
- loss = loss_simple + self.original_elbo_weight * loss_vlb
330
-
331
- loss_dict.update({f'{log_prefix}/loss': loss})
332
-
333
- return loss, loss_dict
334
-
335
- def forward(self, x, *args, **kwargs):
336
- t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
337
- return self.p_losses(x, t, *args, **kwargs)
338
-
339
  def get_input(self, batch, k):
340
  x = batch[k]
341
  x = x.to(memory_format=torch.contiguous_format).float()
342
  return x
343
 
344
- def shared_step(self, batch):
345
- x = self.get_input(batch, self.first_stage_key)
346
- loss, loss_dict = self(x)
347
- return loss, loss_dict
348
-
349
- def training_step(self, batch, batch_idx):
350
- loss, loss_dict = self.shared_step(batch)
351
-
352
- self.log_dict(loss_dict, prog_bar=True,
353
- logger=True, on_step=True, on_epoch=True)
354
-
355
- self.log("global_step", self.global_step,
356
- prog_bar=True, logger=True, on_step=True, on_epoch=False)
357
-
358
- if self.use_scheduler:
359
- lr = self.optimizers().param_groups[0]['lr']
360
- self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
361
-
362
- if self.log_time:
363
- total_train_time = (time.time() - self.start_time) / (3600*24)
364
- avg_step_time = (time.time() - self.start_time) / (self.global_step + 1)
365
- left_time_2w_step = (20000-self.global_step -1) * avg_step_time / (3600*24)
366
- left_time_5w_step = (50000-self.global_step -1) * avg_step_time / (3600*24)
367
- with open(self.logger_path, 'w') as f:
368
- print(f'total_train_time = {total_train_time:.1f} days \n\
369
- total_train_step = {self.global_step + 1} steps \n\
370
- left_time_2w_step = {left_time_2w_step:.1f} days \n\
371
- left_time_5w_step = {left_time_5w_step:.1f} days', file=f)
372
- return loss
373
-
374
- @torch.no_grad()
375
- def validation_step(self, batch, batch_idx):
376
- # _, loss_dict_no_ema = self.shared_step_validate(batch)
377
- # with self.ema_scope():
378
- # _, loss_dict_ema = self.shared_step_validate(batch)
379
- # loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
380
- # self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
381
- # self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
382
- if (self.global_step) % self.val_fvd_interval == 0 and self.global_step != 0:
383
- print(f'sample for fvd...')
384
- self.log_images_kwargs = {
385
- 'inpaint': False,
386
- 'plot_diffusion_rows': False,
387
- 'plot_progressive_rows': False,
388
- 'ddim_steps': 50,
389
- 'unconditional_guidance_scale': 15.0,
390
- }
391
- torch.cuda.empty_cache()
392
- logs = self.log_images(batch, **self.log_images_kwargs)
393
- self.log("batch_idx", batch_idx,
394
- prog_bar=True, on_step=True, on_epoch=False)
395
- return {'real': logs['inputs'], 'fake': logs['samples'], 'conditioning_txt_img': logs['conditioning_txt_img']}
396
-
397
- def get_condition_validate(self, prompt):
398
- """ text embd
399
- """
400
- if isinstance(prompt, str):
401
- prompt = [prompt]
402
- c = self.get_learned_conditioning(prompt)
403
- bs = c.shape[0]
404
-
405
- return c
406
-
407
- def on_train_batch_end(self, *args, **kwargs):
408
- if self.use_ema:
409
- self.model_ema(self.model)
410
-
411
- def training_epoch_end(self, outputs):
412
-
413
- if (self.current_epoch == 0) or self.resume_new_epoch == 0:
414
- self.epoch_start_time = time.time()
415
- self.current_epoch_time = 0
416
- self.total_time = 0
417
- self.epoch_time_avg = 0
418
- else:
419
- self.current_epoch_time = time.time() - self.epoch_start_time
420
- self.epoch_start_time = time.time()
421
- self.total_time += self.current_epoch_time
422
- self.epoch_time_avg = self.total_time / self.current_epoch
423
- self.resume_new_epoch += 1
424
- epoch_avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
425
-
426
- self.log('train/epoch/loss', epoch_avg_loss, logger=True, on_epoch=True)
427
- self.log('train/epoch/idx', self.current_epoch, logger=True, on_epoch=True)
428
- self.log('train/epoch/time', self.current_epoch_time, logger=True, on_epoch=True)
429
- self.log('train/epoch/time_avg', self.epoch_time_avg, logger=True, on_epoch=True)
430
- self.log('train/epoch/time_avg_min', self.epoch_time_avg / 60, logger=True, on_epoch=True)
431
-
432
  def _get_rows_from_list(self, samples):
433
  n_imgs_per_row = len(samples)
434
- denoise_grid = rearrange(samples, 'n b c t h w -> b n c t h w')
435
- denoise_grid = rearrange(denoise_grid, 'b n c t h w -> (b n) c t h w')
436
- denoise_grid = rearrange(denoise_grid, 'n c t h w -> (n t) c h w')
437
  denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
438
  return denoise_grid
439
 
440
  @torch.no_grad()
441
- def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None,
442
- plot_diffusion_rows=True, plot_denoise_rows=True, **kwargs):
443
- """ log images for DDPM """
444
  log = dict()
445
  x = self.get_input(batch, self.first_stage_key)
446
  N = min(x.shape[0], N)
447
  n_row = min(x.shape[0], n_row)
448
  x = x.to(self.device)[:N]
449
  log["inputs"] = x
450
- if 'fps' in batch:
451
- log['fps'] = batch['fps']
452
 
453
- if plot_diffusion_rows:
454
- # get diffusion row
455
- diffusion_row = list()
456
- x_start = x[:n_row]
457
 
458
- for t in range(self.num_timesteps):
459
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
460
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
461
- t = t.to(self.device).long()
462
- noise = torch.randn_like(x_start)
463
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
464
- diffusion_row.append(x_noisy)
465
 
466
- log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
467
 
468
  if sample:
469
  # get denoise row
@@ -471,8 +314,7 @@ class DDPM(pl.LightningModule):
471
  samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
472
 
473
  log["samples"] = samples
474
- if plot_denoise_rows:
475
- log["denoise_row"] = self._get_rows_from_list(denoise_row)
476
 
477
  if return_keys:
478
  if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
@@ -481,14 +323,6 @@ class DDPM(pl.LightningModule):
481
  return {key: log[key] for key in return_keys}
482
  return log
483
 
484
- def configure_optimizers(self):
485
- lr = self.learning_rate
486
- params = list(self.model.parameters())
487
- if self.learn_logvar:
488
- params = params + [self.logvar]
489
- opt = torch.optim.AdamW(params, lr=lr)
490
- return opt
491
-
492
 
493
  class LatentDiffusion(DDPM):
494
  """main class"""
@@ -496,36 +330,51 @@ class LatentDiffusion(DDPM):
496
  first_stage_config,
497
  cond_stage_config,
498
  num_timesteps_cond=None,
499
- cond_stage_key="image",
500
  cond_stage_trainable=False,
501
- concat_mode=True,
502
  cond_stage_forward=None,
503
  conditioning_key=None,
 
 
504
  scale_factor=1.0,
505
  scale_by_std=False,
506
  encoder_type="2d",
507
- shift_factor=0.0,
508
- split_clips=True,
509
- downfactor_t=None,
510
- clip_length=None,
511
  only_model=False,
512
- lora_args={},
 
 
 
 
513
  *args, **kwargs):
514
  self.num_timesteps_cond = default(num_timesteps_cond, 1)
515
  self.scale_by_std = scale_by_std
516
  assert self.num_timesteps_cond <= kwargs['timesteps']
517
  # for backwards compatibility after implementation of DiffusionWrapper
518
-
519
- if conditioning_key is None:
520
- conditioning_key = 'concat' if concat_mode else 'crossattn'
521
- if cond_stage_config == '__is_unconditional__':
522
- conditioning_key = None
523
  ckpt_path = kwargs.pop("ckpt_path", None)
524
  ignore_keys = kwargs.pop("ignore_keys", [])
 
525
  super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
526
- self.concat_mode = concat_mode
527
  self.cond_stage_trainable = cond_stage_trainable
528
  self.cond_stage_key = cond_stage_key
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529
  try:
530
  self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
531
  except:
@@ -536,71 +385,44 @@ class LatentDiffusion(DDPM):
536
  self.register_buffer('scale_factor', torch.tensor(scale_factor))
537
  self.instantiate_first_stage(first_stage_config)
538
  self.instantiate_cond_stage(cond_stage_config)
539
- self.cond_stage_forward = cond_stage_forward
540
- self.clip_denoised = False
541
- self.bbox_tokenizer = None
542
- self.cond_stage_config = cond_stage_config
543
  self.first_stage_config = first_stage_config
 
 
 
 
544
  self.encoder_type = encoder_type
545
  assert(encoder_type in ["2d", "3d"])
 
 
 
 
 
 
546
  self.restarted_from_ckpt = False
547
- self.shift_factor = shift_factor
548
  if ckpt_path is not None:
549
  self.init_from_ckpt(ckpt_path, ignore_keys, only_model=only_model)
550
  self.restarted_from_ckpt = True
551
- self.split_clips = split_clips
552
- self.downfactor_t = downfactor_t
553
- self.clip_length = clip_length
554
- # lora related args
555
- self.inject_unet = getattr(lora_args, "inject_unet", False)
556
- self.inject_clip = getattr(lora_args, "inject_clip", False)
557
- self.inject_unet_key_word = getattr(lora_args, "inject_unet_key_word", None)
558
- self.inject_clip_key_word = getattr(lora_args, "inject_clip_key_word", None)
559
- self.lora_rank = getattr(lora_args, "lora_rank", 4)
560
 
561
  def make_cond_schedule(self, ):
562
  self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
563
  ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
564
  self.cond_ids[:self.num_timesteps_cond] = ids
565
 
566
- def inject_lora(self, lora_scale=1.0):
567
- if self.inject_unet:
568
- self.lora_require_grad_params, self.lora_names = inject_trainable_lora(self.model, self.inject_unet_key_word,
569
- r=self.lora_rank,
570
- scale=lora_scale
571
- )
572
- if self.inject_clip:
573
- self.lora_require_grad_params_clip, self.lora_names_clip = inject_trainable_lora(self.cond_stage_model, self.inject_clip_key_word,
574
- r=self.lora_rank,
575
- scale=lora_scale
576
- )
577
-
578
- @rank_zero_only
579
- @torch.no_grad()
580
- def on_train_batch_start(self, batch, batch_idx, dataloader_idx=None):
581
- # only for very first batch, reset the self.scale_factor
582
- if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
583
- assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
584
- # set rescale weight to 1./std of encodings
585
- print("### USING STD-RESCALING ###")
586
- x = super().get_input(batch, self.first_stage_key)
587
- x = x.to(self.device)
588
- encoder_posterior = self.encode_first_stage(x)
589
- z = self.get_first_stage_encoding(encoder_posterior).detach()
590
- del self.scale_factor
591
- self.register_buffer('scale_factor', 1. / z.flatten().std())
592
- print(f"setting self.scale_factor to {self.scale_factor}")
593
- print("### USING STD-RESCALING ###")
594
- print(f"std={z.flatten().std()}")
595
-
596
- def register_schedule(self,
597
- given_betas=None, beta_schedule="linear", timesteps=1000,
598
- linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
599
- super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
600
 
601
- self.shorten_cond_schedule = self.num_timesteps_cond > 1
602
- if self.shorten_cond_schedule:
603
- self.make_cond_schedule()
 
604
 
605
  def instantiate_first_stage(self, config):
606
  model = instantiate_from_config(config)
@@ -610,40 +432,16 @@ class LatentDiffusion(DDPM):
610
  param.requires_grad = False
611
 
612
  def instantiate_cond_stage(self, config):
613
- if config is None:
614
- self.cond_stage_model = None
615
- return
616
  if not self.cond_stage_trainable:
617
- if config == "__is_first_stage__":
618
- print("Using first stage also as cond stage.")
619
- self.cond_stage_model = self.first_stage_model
620
- elif config == "__is_unconditional__":
621
- print(f"Training {self.__class__.__name__} as an unconditional model.")
622
- self.cond_stage_model = None
623
- else:
624
- model = instantiate_from_config(config)
625
- self.cond_stage_model = model.eval()
626
- self.cond_stage_model.train = disabled_train
627
- for param in self.cond_stage_model.parameters():
628
- param.requires_grad = False
629
  else:
630
- assert config != '__is_first_stage__'
631
- assert config != '__is_unconditional__'
632
  model = instantiate_from_config(config)
633
  self.cond_stage_model = model
634
-
635
-
636
- def get_first_stage_encoding(self, encoder_posterior, noise=None):
637
- if isinstance(encoder_posterior, DiagonalGaussianDistribution):
638
- z = encoder_posterior.sample(noise=noise)
639
- elif isinstance(encoder_posterior, torch.Tensor):
640
- z = encoder_posterior
641
- else:
642
- raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
643
- z = self.scale_factor * (z + self.shift_factor)
644
- return z
645
-
646
-
647
  def get_learned_conditioning(self, c):
648
  if self.cond_stage_forward is None:
649
  if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
@@ -657,197 +455,61 @@ class LatentDiffusion(DDPM):
657
  c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
658
  return c
659
 
660
-
661
- @torch.no_grad()
662
- def get_condition(self, batch, x, bs, force_c_encode, k, cond_key, is_imgs=False):
663
- is_conditional = self.model.conditioning_key is not None # crossattn
664
- if is_conditional:
665
- if cond_key is None:
666
- cond_key = self.cond_stage_key
667
-
668
- # get condition batch of different condition type
669
- if cond_key != self.first_stage_key:
670
- assert(cond_key in ["caption", "txt"])
671
- xc = batch[cond_key]
672
- else:
673
- xc = x
674
-
675
- # if static video
676
- if self.static_video:
677
- xc_ = [c + ' (static)' for c in xc]
678
- xc = xc_
679
-
680
- # get learned condition.
681
- # can directly skip it: c = xc
682
- if self.cond_stage_config is not None and (not self.cond_stage_trainable or force_c_encode):
683
- if isinstance(xc, torch.Tensor):
684
- xc = xc.to(self.device)
685
- c = self.get_learned_conditioning(xc)
686
- else:
687
- c = xc
688
-
689
- if self.classfier_free_guidance:
690
- if cond_key in ['caption', "txt"] and self.uncond_type == 'empty_seq':
691
- for i, ci in enumerate(c):
692
- if random.random() < self.prob:
693
- c[i] = ""
694
- elif cond_key == 'class_label' and self.uncond_type == 'zero_embed':
695
- pass
696
- elif cond_key == 'class_label' and self.uncond_type == 'learned_embed':
697
- import pdb;pdb.set_trace()
698
- for i, ci in enumerate(c):
699
- if random.random() < self.prob:
700
- c[i]['class_label'] = self.n_classes
701
-
702
- else:
703
- raise NotImplementedError
704
-
705
- if self.zero_cond_embed:
706
- import pdb;pdb.set_trace()
707
- c = torch.zeros_like(c)
708
-
709
- # process c
710
- if bs is not None:
711
- if (is_imgs and not self.static_video):
712
- c = c[:bs*self.temporal_length] # each random img (in T axis) has a corresponding prompt
713
- else:
714
- c = c[:bs]
715
-
716
  else:
717
- c = None
718
- xc = None
719
-
720
- return c, xc
721
-
722
  @torch.no_grad()
723
- def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
724
- cond_key=None, return_original_cond=False, bs=None, mask_temporal=False):
725
- """ Get input in LDM
726
- """
727
- # get input images
728
- x = super().get_input(batch, k) # k = first_stage_key=image
729
- is_imgs = True if k == 'jpg' else False
730
- if is_imgs:
731
- if self.static_video:
732
- # repeat single img to a static video
733
- x = x.unsqueeze(2) # bchw -> bc1hw
734
- x = x.repeat(1,1,self.temporal_length,1,1) # bc1hw -> bcthw
735
- else:
736
- # rearrange to videos with T random img
737
- bs_load = x.shape[0] // self.temporal_length
738
- x = x[:bs_load*self.temporal_length, ...]
739
- x = rearrange(x, '(b t) c h w -> b c t h w', t=self.temporal_length, b=bs_load)
740
-
741
- if bs is not None:
742
- x = x[:bs]
743
-
744
- x = x.to(self.device)
745
- x_ori = x
746
-
747
- b, _, t, h, w = x.shape
748
-
749
- # encode video frames x to z via a 2D encoder
750
- x = rearrange(x, 'b c t h w -> (b t) c h w')
751
- encoder_posterior = self.encode_first_stage(x, mask_temporal)
752
- z = self.get_first_stage_encoding(encoder_posterior).detach()
753
- z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
754
-
755
 
756
- c, xc = self.get_condition(batch, x, bs, force_c_encode, k, cond_key, is_imgs)
757
- out = [z, c]
758
 
759
- if return_first_stage_outputs:
760
- xrec = self.decode_first_stage(z, mask_temporal=mask_temporal)
761
- out.extend([x_ori, xrec])
762
- if return_original_cond:
763
- if isinstance(xc, torch.Tensor) and xc.dim() == 4:
764
- xc = rearrange(xc, '(b t) c h w -> b c t h w', b=b, t=t)
765
- out.append(xc)
766
 
767
- return out
768
-
769
- @torch.no_grad()
770
- def decode(self, z, **kwargs,):
771
- z = 1. / self.scale_factor * z - self.shift_factor
772
- results = self.first_stage_model.decode(z,**kwargs)
773
  return results
774
 
775
  @torch.no_grad()
776
- def decode_first_stage_2DAE(self, z, decode_bs=16, return_cpu=True, **kwargs):
777
- b, _, t, _, _ = z.shape
778
- z = rearrange(z, 'b c t h w -> (b t) c h w')
779
- if decode_bs is None:
780
- results = self.decode(z, **kwargs)
781
- else:
782
- z = torch.split(z, decode_bs, dim=0)
783
- if return_cpu:
784
- results = torch.cat([self.decode(z_, **kwargs).cpu() for z_ in z], dim=0)
785
- else:
786
- results = torch.cat([self.decode(z_, **kwargs) for z_ in z], dim=0)
787
- results = rearrange(results, '(b t) c h w -> b c t h w', b=b,t=t).contiguous()
788
- return results
789
-
790
- @torch.no_grad()
791
- def decode_first_stage(self, z, decode_bs=16, return_cpu=True, **kwargs):
792
- assert(self.encoder_type == "2d" and z.dim() == 5)
793
- return self.decode_first_stage_2DAE(z, decode_bs=decode_bs, return_cpu=return_cpu, **kwargs)
794
 
795
- @torch.no_grad()
796
- def encode_first_stage_2DAE(self, x, encode_bs=16):
797
  b, _, t, _, _ = x.shape
798
- x = rearrange(x, 'b c t h w -> (b t) c h w')
799
- if encode_bs is None:
800
- results = self.first_stage_model.encode(x)
801
- else:
802
- x = torch.split(x, encode_bs, dim=0)
803
- zs = []
804
- for x_ in x:
805
- encoder_posterior = self.first_stage_model.encode(x_)
806
- z = self.get_first_stage_encoding(encoder_posterior).detach()
807
- zs.append(z)
808
- results = torch.cat(zs, dim=0)
809
- results = rearrange(results, '(b t) c h w -> b c t h w', b=b,t=t)
810
  return results
811
 
812
- @torch.no_grad()
813
- def encode_first_stage(self, x):
814
- assert(self.encoder_type == "2d" and x.dim() == 5)
815
- b, _, t, _, _ = x.shape
816
- x = rearrange(x, 'b c t h w -> (b t) c h w')
817
- results = self.first_stage_model.encode(x)
818
- results = rearrange(results, '(b t) c h w -> b c t h w', b=b,t=t)
819
- return results
 
820
 
821
- def shared_step(self, batch, **kwargs):
822
- """ shared step of LDM.
823
- If the condition is learned, c is the raw condition (e.g. text).
824
- Encoding of the condition is performed in the forward function below.
825
- """
826
- x, c = self.get_input(batch, self.first_stage_key)
827
- loss = self(x, c)
828
- return loss
829
-
830
- def forward(self, x, c, *args, **kwargs):
831
- start_t = getattr(self, "start_t", 0)
832
- end_t = getattr(self, "end_t", self.num_timesteps)
833
- t = torch.randint(start_t, end_t, (x.shape[0],), device=self.device).long()
834
-
835
- if self.model.conditioning_key is not None:
836
- assert c is not None
837
- if self.cond_stage_trainable:
838
- c = self.get_learned_conditioning(c)
839
- if self.classfier_free_guidance and self.uncond_type == 'zero_embed':
840
- for i, ci in enumerate(c):
841
- if random.random() < self.prob:
842
- c[i] = torch.zeros_like(c[i])
843
- if self.shorten_cond_schedule: # TODO: drop this option
844
- tc = self.cond_ids[t].to(self.device)
845
- c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
846
-
847
- return self.p_losses(x, c, t, *args, **kwargs)
848
 
849
- def apply_model(self, x_noisy, t, cond, return_ids=False, **kwargs):
 
 
850
 
 
851
  if isinstance(cond, dict):
852
  # hybrid case, cond is expected to be a dict
853
  pass
@@ -859,104 +521,55 @@ class LatentDiffusion(DDPM):
859
 
860
  x_recon = self.model(x_noisy, t, **cond, **kwargs)
861
 
862
- if isinstance(x_recon, tuple) and not return_ids:
863
  return x_recon[0]
864
  else:
865
  return x_recon
866
 
867
- def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
868
- return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
869
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
 
 
870
 
871
- def _prior_bpd(self, x_start):
872
- """
873
- Get the prior KL term for the variational lower-bound, measured in
874
- bits-per-dim.
875
- This term can't be optimized, as it only depends on the encoder.
876
- :param x_start: the [N x C x ...] tensor of inputs.
877
- :return: a batch of [N] KL values (in bits), one per batch element.
878
- """
879
- batch_size = x_start.shape[0]
880
- t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
881
- qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
882
- kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
883
- return mean_flat(kl_prior) / np.log(2.0)
884
-
885
- def p_losses(self, x_start, cond, t, noise=None, skip_qsample=False, x_noisy=None, cond_mask=None, **kwargs,):
886
- if not skip_qsample:
887
- noise = default(noise, lambda: torch.randn_like(x_start))
888
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
889
  else:
890
- assert(x_noisy is not None)
891
- assert(noise is not None)
892
- model_output = self.apply_model(x_noisy, t, cond, **kwargs)
893
 
894
- loss_dict = {}
895
- prefix = 'train' if self.training else 'val'
896
 
897
- if self.parameterization == "x0":
898
- target = x_start
899
- elif self.parameterization == "eps":
900
- target = noise
901
- else:
902
- raise NotImplementedError()
903
-
904
- loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3, 4])
905
- loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
906
- if self.logvar.device != self.device:
907
- self.logvar = self.logvar.to(self.device)
908
- logvar_t = self.logvar[t]
909
- loss = loss_simple / torch.exp(logvar_t) + logvar_t
910
- if self.learn_logvar:
911
- loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
912
- loss_dict.update({'logvar': self.logvar.data.mean()})
913
 
914
- loss = self.l_simple_weight * loss.mean()
 
 
915
 
916
- loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3, 4))
917
- loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
918
- loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
919
- loss += (self.original_elbo_weight * loss_vlb)
920
- loss_dict.update({f'{prefix}/loss': loss})
921
 
922
- return loss, loss_dict
923
 
924
- def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
925
- return_x0=False, score_corrector=None, corrector_kwargs=None,
926
- unconditional_guidance_scale=1., unconditional_conditioning=None,
927
- uc_type=None,):
928
  t_in = t
929
- if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
930
- model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
931
- else:
932
- # with unconditional condition
933
- if isinstance(c, torch.Tensor):
934
- x_in = torch.cat([x] * 2)
935
- t_in = torch.cat([t] * 2)
936
- c_in = torch.cat([unconditional_conditioning, c])
937
- model_out_uncond, model_out = self.apply_model(x_in, t_in, c_in, return_ids=return_codebook_ids).chunk(2)
938
- elif isinstance(c, dict):
939
- model_out = self.apply_model(x, t, c, return_ids=return_codebook_ids)
940
- model_out_uncond = self.apply_model(x, t, unconditional_conditioning, return_ids=return_codebook_ids)
941
- else:
942
- raise NotImplementedError
943
- if uc_type is None:
944
- model_out = model_out_uncond + unconditional_guidance_scale * (model_out - model_out_uncond)
945
- else:
946
- if uc_type == 'cfg_original':
947
- model_out = model_out + unconditional_guidance_scale * (model_out - model_out_uncond)
948
- elif uc_type == 'cfg_ours':
949
- model_out = model_out + unconditional_guidance_scale * (model_out_uncond - model_out)
950
- else:
951
- raise NotImplementedError
952
 
953
  if score_corrector is not None:
954
  assert self.parameterization == "eps"
955
  model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
956
 
957
- if return_codebook_ids:
958
- model_out, logits = model_out
959
-
960
  if self.parameterization == "eps":
961
  x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
962
  elif self.parameterization == "x0":
@@ -966,34 +579,21 @@ class LatentDiffusion(DDPM):
966
 
967
  if clip_denoised:
968
  x_recon.clamp_(-1., 1.)
969
- if quantize_denoised:
970
- x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
971
  model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
972
- if return_codebook_ids:
973
- return model_mean, posterior_variance, posterior_log_variance, logits
974
- elif return_x0:
975
  return model_mean, posterior_variance, posterior_log_variance, x_recon
976
  else:
977
  return model_mean, posterior_variance, posterior_log_variance
978
 
979
  @torch.no_grad()
980
- def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
981
- return_codebook_ids=False, quantize_denoised=False, return_x0=False,
982
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
983
- unconditional_guidance_scale=1., unconditional_conditioning=None,
984
- uc_type=None,):
985
  b, *_, device = *x.shape, x.device
986
- outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
987
- return_codebook_ids=return_codebook_ids,
988
- quantize_denoised=quantize_denoised,
989
- return_x0=return_x0,
990
- score_corrector=score_corrector, corrector_kwargs=corrector_kwargs,
991
- unconditional_guidance_scale=unconditional_guidance_scale,
992
- unconditional_conditioning=unconditional_conditioning,
993
- uc_type=uc_type,)
994
- if return_codebook_ids:
995
- raise DeprecationWarning("Support dropped.")
996
- elif return_x0:
997
  model_mean, _, model_log_variance, x0 = outputs
998
  else:
999
  model_mean, _, model_log_variance = outputs
@@ -1001,99 +601,35 @@ class LatentDiffusion(DDPM):
1001
  noise = noise_like(x.shape, device, repeat_noise) * temperature
1002
  if noise_dropout > 0.:
1003
  noise = torch.nn.functional.dropout(noise, p=noise_dropout)
1004
-
1005
  nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
1006
 
1007
- if return_codebook_ids:
1008
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
1009
  if return_x0:
1010
  return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
1011
  else:
1012
  return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
1013
 
1014
  @torch.no_grad()
1015
- def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
1016
- img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
1017
- score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
1018
- log_every_t=None):
1019
- if not log_every_t:
1020
- log_every_t = self.log_every_t
1021
- timesteps = self.num_timesteps
1022
- if batch_size is not None:
1023
- b = batch_size if batch_size is not None else shape[0]
1024
- shape = [batch_size] + list(shape)
1025
- else:
1026
- b = batch_size = shape[0]
1027
- if x_T is None:
1028
- img = torch.randn(shape, device=self.device)
1029
- else:
1030
- img = x_T
1031
- intermediates = []
1032
- if cond is not None:
1033
- if isinstance(cond, dict):
1034
- cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1035
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1036
- else:
1037
- cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1038
-
1039
- if start_T is not None:
1040
- timesteps = min(timesteps, start_T)
1041
- iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
1042
- total=timesteps) if verbose else reversed(
1043
- range(0, timesteps))
1044
- if type(temperature) == float:
1045
- temperature = [temperature] * timesteps
1046
-
1047
- for i in iterator:
1048
- ts = torch.full((b,), i, device=self.device, dtype=torch.long)
1049
- if self.shorten_cond_schedule:
1050
- assert self.model.conditioning_key != 'hybrid'
1051
- tc = self.cond_ids[ts].to(cond.device)
1052
- cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1053
-
1054
- img, x0_partial = self.p_sample(img, cond, ts,
1055
- clip_denoised=self.clip_denoised,
1056
- quantize_denoised=quantize_denoised, return_x0=True,
1057
- temperature=temperature[i], noise_dropout=noise_dropout,
1058
- score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1059
- if mask is not None:
1060
- assert x0 is not None
1061
- img_orig = self.q_sample(x0, ts)
1062
- img = img_orig * mask + (1. - mask) * img
1063
-
1064
- if i % log_every_t == 0 or i == timesteps - 1:
1065
- intermediates.append(x0_partial)
1066
- if callback: callback(i)
1067
- if img_callback: img_callback(img, i)
1068
- return img, intermediates
1069
-
1070
- @torch.no_grad()
1071
- def p_sample_loop(self, cond, shape, return_intermediates=False,
1072
- x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
1073
- mask=None, x0=None, img_callback=None, start_T=None,
1074
- log_every_t=None,
1075
- unconditional_guidance_scale=1., unconditional_conditioning=None,
1076
- uc_type=None,):
1077
 
1078
  if not log_every_t:
1079
  log_every_t = self.log_every_t
1080
  device = self.betas.device
1081
- b = shape[0]
1082
-
1083
  # sample an initial noise
1084
  if x_T is None:
1085
  img = torch.randn(shape, device=device)
1086
  else:
1087
  img = x_T
1088
-
1089
  intermediates = [img]
1090
  if timesteps is None:
1091
  timesteps = self.num_timesteps
1092
-
1093
  if start_T is not None:
1094
  timesteps = min(timesteps, start_T)
1095
- iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
1096
- range(0, timesteps))
1097
 
1098
  if mask is not None:
1099
  assert x0 is not None
@@ -1106,12 +642,7 @@ class LatentDiffusion(DDPM):
1106
  tc = self.cond_ids[ts].to(cond.device)
1107
  cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1108
 
1109
- img = self.p_sample(img, cond, ts,
1110
- clip_denoised=self.clip_denoised,
1111
- quantize_denoised=quantize_denoised,
1112
- unconditional_guidance_scale=unconditional_guidance_scale,
1113
- unconditional_conditioning=unconditional_conditioning,
1114
- uc_type=uc_type)
1115
  if mask is not None:
1116
  img_orig = self.q_sample(x0, ts)
1117
  img = img_orig * mask + (1. - mask) * img
@@ -1125,253 +656,54 @@ class LatentDiffusion(DDPM):
1125
  return img, intermediates
1126
  return img
1127
 
1128
- @torch.no_grad()
1129
- def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
1130
- verbose=True, timesteps=None, quantize_denoised=False,
1131
- mask=None, x0=None, shape=None, **kwargs):
1132
- if shape is None:
1133
- shape = (batch_size, self.channels, self.total_length, *self.image_size)
1134
- if cond is not None:
1135
- if isinstance(cond, dict):
1136
- cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1137
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1138
- else:
1139
- cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1140
- return self.p_sample_loop(cond,
1141
- shape,
1142
- return_intermediates=return_intermediates, x_T=x_T,
1143
- verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
1144
- mask=mask, x0=x0,)
1145
-
1146
- @torch.no_grad()
1147
- def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
1148
-
1149
- if ddim:
1150
- ddim_sampler = DDIMSampler(self)
1151
- shape = (self.channels, self.total_length, *self.image_size)
1152
- samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
1153
- shape,cond,verbose=False, **kwargs)
1154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1155
  else:
1156
- samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
1157
- return_intermediates=True, **kwargs)
 
 
1158
 
1159
- return samples, intermediates
1160
-
1161
- @torch.no_grad()
1162
- def log_condition(self, log, batch, xc, x, c, cond_stage_key=None):
1163
- """
1164
- xc: original condition before encoding.
1165
- c: condition after encoding.
1166
- """
1167
- if x.dim() == 5:
1168
- txt_img_shape = [x.shape[3], x.shape[4]]
1169
- elif x.dim() == 4:
1170
- txt_img_shape = [x.shape[2], x.shape[3]]
1171
- else:
1172
- raise ValueError
1173
- if self.model.conditioning_key is not None: #concat-time-mask
1174
- if hasattr(self.cond_stage_model, "decode"):
1175
- xc = self.cond_stage_model.decode(c)
1176
- log["conditioning"] = xc
1177
- elif cond_stage_key in ["caption", "txt"]:
1178
- log["conditioning_txt_img"] = log_txt_as_img(txt_img_shape, batch[cond_stage_key], size=x.shape[3]//25)
1179
- log["conditioning_txt"] = batch[cond_stage_key]
1180
- elif cond_stage_key == 'class_label':
1181
- try:
1182
- xc = log_txt_as_img(txt_img_shape, batch["human_label"], size=x.shape[3]//25)
1183
- except:
1184
- xc = log_txt_as_img(txt_img_shape, batch["class_name"], size=x.shape[3]//25)
1185
- log['conditioning'] = xc
1186
- elif isimage(xc):
1187
- log["conditioning"] = xc
1188
- if ismap(xc):
1189
- log["original_conditioning"] = self.to_rgb(xc)
1190
- if isinstance(c, dict) and 'mask' in c:
1191
- log['mask'] =self.mask_to_rgb(c['mask'])
1192
- return log
1193
-
1194
- @torch.no_grad()
1195
- def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., unconditional_guidance_scale=1.0,
1196
- first_stage_key2=None, cond_key2=None,
1197
- c=None,
1198
- **kwargs):
1199
- """ log images for LatentDiffusion """
1200
- use_ddim = ddim_steps is not None
1201
- is_imgs = first_stage_key2 is not None
1202
- if is_imgs:
1203
- assert(cond_key2 is not None)
1204
- log = dict()
1205
-
1206
- # get input
1207
- z, c, x, xrec, xc = self.get_input(batch,
1208
- k=self.first_stage_key if first_stage_key2 is None else first_stage_key2,
1209
- return_first_stage_outputs=True,
1210
- force_c_encode=True,
1211
- return_original_cond=True,
1212
- bs=N,
1213
- cond_key=cond_key2 if cond_key2 is not None else None,
1214
- )
1215
-
1216
- N_ori = N
1217
- N = min(z.shape[0], N)
1218
- n_row = min(x.shape[0], n_row)
1219
 
1220
- if unconditional_guidance_scale != 1.0:
1221
- prompts = N * self.temporal_length * [""] if (is_imgs and not self.static_video) else N * [""]
1222
- uc = self.get_condition_validate(prompts)
1223
-
1224
- else:
1225
- uc = None
1226
-
1227
- log["inputs"] = x
1228
- log["reconstruction"] = xrec
1229
- log = self.log_condition(log, batch, xc, x, c,
1230
- cond_stage_key=self.cond_stage_key if cond_key2 is None else cond_key2
1231
- )
1232
-
1233
- if sample:
1234
- with self.ema_scope("Plotting"):
1235
- samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1236
- ddim_steps=ddim_steps,eta=ddim_eta,
1237
- temporal_length=self.video_length,
1238
- unconditional_guidance_scale=unconditional_guidance_scale,
1239
- unconditional_conditioning=uc, **kwargs,
1240
- )
1241
- # decode samples
1242
- x_samples = self.decode_first_stage(samples)
1243
- log["samples"] = x_samples
1244
- return log
1245
-
1246
- def configure_optimizers(self):
1247
- """ configure_optimizers for LatentDiffusion """
1248
- lr = self.learning_rate
1249
-
1250
- # --------------------------------------------------------------------------------
1251
- # set parameters
1252
- if hasattr(self, "only_optimize_empty_parameters") and self.only_optimize_empty_parameters:
1253
- print("[INFO] Optimize only empty parameters!")
1254
- assert(hasattr(self, "empty_paras"))
1255
- params = [p for n, p in self.model.named_parameters() if n in self.empty_paras]
1256
- elif hasattr(self, "only_optimize_pretrained_parameters") and self.only_optimize_pretrained_parameters:
1257
- print("[INFO] Optimize only pretrained parameters!")
1258
- assert(hasattr(self, "empty_paras"))
1259
- params = [p for n, p in self.model.named_parameters() if n not in self.empty_paras]
1260
- assert(len(params) != 0)
1261
- elif getattr(self, "optimize_empty_and_spatialattn", False):
1262
- print("[INFO] Optimize empty parameters + spatial transformer!")
1263
- assert(hasattr(self, "empty_paras"))
1264
- empty_paras = [p for n, p in self.model.named_parameters() if n in self.empty_paras]
1265
- SA_list = [".attn1.", ".attn2.", ".ff.", ".norm1.", ".norm2.", ".norm3."]
1266
- SA_params = [p for n, p in self.model.named_parameters() if check_istarget(n, SA_list)]
1267
- if getattr(self, "spatial_lr_decay", False):
1268
- params = [
1269
- {"params": empty_paras},
1270
- {"params": SA_params, "lr": lr * self.spatial_lr_decay}
1271
- ]
1272
- else:
1273
- params = empty_paras + SA_params
1274
- else:
1275
- # optimize whole denoiser
1276
- if hasattr(self, "spatial_lr_decay") and self.spatial_lr_decay:
1277
- print("[INFO] Optimize the whole net with different lr!")
1278
- print(f"[INFO] {lr} for empty paras, {lr * self.spatial_lr_decay} for pretrained paras!")
1279
- empty_paras = [p for n, p in self.model.named_parameters() if n in self.empty_paras]
1280
- # assert(len(empty_paras) == len(self.empty_paras)) # self.empty_paras:cond_stage_model.embedding.weight not in diffusion model params
1281
- pretrained_paras = [p for n, p in self.model.named_parameters() if n not in self.empty_paras]
1282
- params = [
1283
- {"params": empty_paras},
1284
- {"params": pretrained_paras, "lr": lr * self.spatial_lr_decay}
1285
- ]
1286
- print(f"[INFO] Empty paras: {len(empty_paras)}, Pretrained paras: {len(pretrained_paras)}")
1287
-
1288
- else:
1289
- params = list(self.model.parameters())
1290
-
1291
- if hasattr(self, "generator_trainable") and not self.generator_trainable:
1292
- # fix unet denoiser
1293
- params = list()
1294
-
1295
- if self.inject_unet:
1296
- params = itertools.chain(*self.lora_require_grad_params)
1297
-
1298
- if self.inject_clip:
1299
- if self.inject_unet:
1300
- params = list(params)+list(itertools.chain(*self.lora_require_grad_params_clip))
1301
- else:
1302
- params = itertools.chain(*self.lora_require_grad_params_clip)
1303
-
1304
-
1305
- # append paras
1306
- # ------------------------------------------------------------------
1307
- def add_cond_model(cond_model, params):
1308
- if isinstance(params[0], dict):
1309
- # parameter groups
1310
- params.append({"params": list(cond_model.parameters())})
1311
- else:
1312
- # parameter list: [torch.nn.parameter.Parameter]
1313
- params = params + list(cond_model.parameters())
1314
- return params
1315
- # ------------------------------------------------------------------
1316
-
1317
- if self.cond_stage_trainable:
1318
- # print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
1319
- params = add_cond_model(self.cond_stage_model, params)
1320
-
1321
- if self.learn_logvar:
1322
- print('Diffusion model optimizing logvar')
1323
- if isinstance(params[0], dict):
1324
- params.append({"params": [self.logvar]})
1325
- else:
1326
- params.append(self.logvar)
1327
-
1328
- # --------------------------------------------------------------------------------
1329
- opt = torch.optim.AdamW(params, lr=lr)
1330
-
1331
- # lr scheduler
1332
- if self.use_scheduler:
1333
- assert 'target' in self.scheduler_config
1334
- scheduler = instantiate_from_config(self.scheduler_config)
1335
-
1336
- print("Setting up LambdaLR scheduler...")
1337
- scheduler = [
1338
- {
1339
- 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
1340
- 'interval': 'step',
1341
- 'frequency': 1
1342
- }]
1343
- return [opt], scheduler
1344
-
1345
- return opt
1346
-
1347
- @torch.no_grad()
1348
- def to_rgb(self, x):
1349
- x = x.float()
1350
- if not hasattr(self, "colorize"):
1351
- self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
1352
- x = nn.functional.conv2d(x, weight=self.colorize)
1353
- x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
1354
- return x
1355
-
1356
- @torch.no_grad()
1357
- def mask_to_rgb(self, x):
1358
- x = x * 255
1359
- x = x.int()
1360
- return x
1361
 
1362
  class DiffusionWrapper(pl.LightningModule):
1363
  def __init__(self, diff_model_config, conditioning_key):
1364
  super().__init__()
1365
  self.diffusion_model = instantiate_from_config(diff_model_config)
1366
- print('Successfully initialize the diffusion model !')
1367
  self.conditioning_key = conditioning_key
1368
- # assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'resblockcond', 'hybrid-adm', 'hybrid-time']
1369
 
1370
  def forward(self, x, t, c_concat: list = None, c_crossattn: list = None,
1371
  c_adm=None, s=None, mask=None, **kwargs):
1372
  # temporal_context = fps if fps is not None
1373
  if self.conditioning_key is None:
1374
- out = self.diffusion_model(x, t, **kwargs)
1375
  elif self.conditioning_key == 'concat':
1376
  xc = torch.cat([x] + c_concat, dim=1)
1377
  out = self.diffusion_model(xc, t, **kwargs)
@@ -1379,106 +711,53 @@ class DiffusionWrapper(pl.LightningModule):
1379
  cc = torch.cat(c_crossattn, 1)
1380
  out = self.diffusion_model(x, t, context=cc, **kwargs)
1381
  elif self.conditioning_key == 'hybrid':
 
1382
  xc = torch.cat([x] + c_concat, dim=1)
1383
  cc = torch.cat(c_crossattn, 1)
1384
- out = self.diffusion_model(xc, t, context=cc, **kwargs)
1385
  elif self.conditioning_key == 'resblockcond':
1386
  cc = c_crossattn[0]
1387
- out = self.diffusion_model(x, t, context=cc, **kwargs)
1388
  elif self.conditioning_key == 'adm':
1389
  cc = c_crossattn[0]
1390
- out = self.diffusion_model(x, t, y=cc, **kwargs)
1391
  elif self.conditioning_key == 'hybrid-adm':
1392
  assert c_adm is not None
1393
  xc = torch.cat([x] + c_concat, dim=1)
1394
  cc = torch.cat(c_crossattn, 1)
1395
- out = self.diffusion_model(xc, t, context=cc, y=c_adm, **kwargs)
1396
  elif self.conditioning_key == 'hybrid-time':
1397
  assert s is not None
1398
  xc = torch.cat([x] + c_concat, dim=1)
1399
  cc = torch.cat(c_crossattn, 1)
1400
- out = self.diffusion_model(xc, t, context=cc, s=s, **kwargs)
1401
  elif self.conditioning_key == 'concat-time-mask':
1402
  # assert s is not None
1403
- # print('x & mask:',x.shape,c_concat[0].shape)
1404
  xc = torch.cat([x] + c_concat, dim=1)
1405
- out = self.diffusion_model(xc, t, context=None, s=s, mask=mask, **kwargs)
1406
  elif self.conditioning_key == 'concat-adm-mask':
1407
  # assert s is not None
1408
- # print('x & mask:',x.shape,c_concat[0].shape)
1409
  if c_concat is not None:
1410
  xc = torch.cat([x] + c_concat, dim=1)
1411
  else:
1412
  xc = x
1413
- out = self.diffusion_model(xc, t, context=None, y=s, mask=mask, **kwargs)
1414
- elif self.conditioning_key == 'crossattn-adm':
1415
- cc = torch.cat(c_crossattn, 1)
1416
- out = self.diffusion_model(x, t, context=cc, y=s, **kwargs)
1417
  elif self.conditioning_key == 'hybrid-adm-mask':
1418
  cc = torch.cat(c_crossattn, 1)
1419
  if c_concat is not None:
1420
  xc = torch.cat([x] + c_concat, dim=1)
1421
  else:
1422
  xc = x
1423
- out = self.diffusion_model(xc, t, context=cc, y=s, mask=mask, **kwargs)
1424
  elif self.conditioning_key == 'hybrid-time-adm': # adm means y, e.g., class index
1425
  # assert s is not None
1426
  assert c_adm is not None
1427
  xc = torch.cat([x] + c_concat, dim=1)
1428
  cc = torch.cat(c_crossattn, 1)
1429
- out = self.diffusion_model(xc, t, context=cc, s=s, y=c_adm, **kwargs)
1430
  else:
1431
  raise NotImplementedError()
1432
 
1433
- return out
1434
-
1435
-
1436
- class T2VAdapterDepth(LatentDiffusion):
1437
- def __init__(self, depth_stage_config, adapter_config, *args, **kwargs):
1438
- super(T2VAdapterDepth, self).__init__(*args, **kwargs)
1439
- self.adapter = instantiate_from_config(adapter_config)
1440
- self.condtype = adapter_config.cond_name
1441
- self.depth_stage_model = instantiate_from_config(depth_stage_config)
1442
-
1443
- def prepare_midas_input(self, batch_x):
1444
- # input: b,c,h,w
1445
- x_midas = torch.nn.functional.interpolate(batch_x, size=(384, 384), mode='bicubic')
1446
- return x_midas
1447
-
1448
- @torch.no_grad()
1449
- def get_batch_depth(self, batch_x, target_size, encode_bs=1):
1450
- b, c, t, h, w = batch_x.shape
1451
- merge_x = rearrange(batch_x, 'b c t h w -> (b t) c h w')
1452
- split_x = torch.split(merge_x, encode_bs, dim=0)
1453
- cond_depth_list = []
1454
- for x in split_x:
1455
- x_midas = self.prepare_midas_input(x)
1456
- cond_depth = self.depth_stage_model(x_midas)
1457
- cond_depth = torch.nn.functional.interpolate(
1458
- cond_depth,
1459
- size=target_size,
1460
- mode="bicubic",
1461
- align_corners=False,
1462
- )
1463
- depth_min, depth_max = torch.amin(cond_depth, dim=[1, 2, 3], keepdim=True), torch.amax(cond_depth, dim=[1, 2, 3], keepdim=True)
1464
- cond_depth = 2. * (cond_depth - depth_min) / (depth_max - depth_min + 1e-7) - 1.
1465
- cond_depth_list.append(cond_depth)
1466
- batch_cond_depth=torch.cat(cond_depth_list, dim=0)
1467
- batch_cond_depth = rearrange(batch_cond_depth, '(b t) c h w -> b c t h w', b=b, t=t)
1468
- return batch_cond_depth
1469
-
1470
- def get_adapter_features(self, extra_cond, encode_bs=1):
1471
- b, c, t, h, w = extra_cond.shape
1472
- ## process in 2D manner
1473
- merge_extra_cond = rearrange(extra_cond, 'b c t h w -> (b t) c h w')
1474
- split_extra_cond = torch.split(merge_extra_cond, encode_bs, dim=0)
1475
- features_adapter_list = []
1476
- for extra_cond in split_extra_cond:
1477
- features_adapter = self.adapter(extra_cond)
1478
- features_adapter_list.append(features_adapter)
1479
- merge_features_adapter_list = []
1480
- for i in range(len(features_adapter_list[0])):
1481
- merge_features_adapter = torch.cat([features_adapter_list[num][i] for num in range(len(features_adapter_list))], dim=0)
1482
- merge_features_adapter_list.append(merge_features_adapter)
1483
- merge_features_adapter_list = [rearrange(feature, '(b t) c h w -> b c t h w', b=b, t=t) for feature in merge_features_adapter_list]
1484
- return merge_features_adapter_list
 
1
+ """
2
+ wild mixture of
3
+ https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
4
+ https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ https://github.com/CompVis/taming-transformers
6
+ -- merci
7
+ """
8
+
9
  from functools import partial
10
  from contextlib import contextmanager
 
11
  import numpy as np
12
  from tqdm import tqdm
13
  from einops import rearrange, repeat
14
+ import logging
15
+ mainlogger = logging.getLogger('mainlogger')
16
  import torch
17
  import torch.nn as nn
 
18
  from torchvision.utils import make_grid
19
+ import pytorch_lightning as pl
20
+ from utils.utils import instantiate_from_config
21
+ from lvdm.ema import LitEma
22
+ from lvdm.distributions import DiagonalGaussianDistribution
23
+ from lvdm.models.utils_diffusion import make_beta_schedule
24
+ from lvdm.modules.encoders.ip_resampler import ImageProjModel, Resampler
25
+ from lvdm.basics import disabled_train
26
+ from lvdm.common import (
27
+ extract_into_tensor,
28
+ noise_like,
29
+ exists,
30
+ default
31
+ )
32
+
33
+
34
+ __conditioning_keys__ = {'concat': 'c_concat',
35
+ 'crossattn': 'c_crossattn',
36
+ 'adm': 'y'}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  class DDPM(pl.LightningModule):
39
+ # classic DDPM with Gaussian diffusion, in image space
40
  def __init__(self,
41
  unet_config,
42
  timesteps=1000,
 
45
  ckpt_path=None,
46
  ignore_keys=[],
47
  load_only_unet=False,
48
+ monitor=None,
49
  use_ema=True,
50
  first_stage_key="image",
51
  image_size=256,
 
52
  channels=3,
53
  log_every_t=100,
54
  clip_denoised=True,
 
57
  cosine_s=8e-3,
58
  given_betas=None,
59
  original_elbo_weight=0.,
60
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
61
  l_simple_weight=1.,
62
  conditioning_key=None,
63
+ parameterization="eps", # all assuming fixed variance schedules
64
  scheduler_config=None,
65
+ use_positional_encodings=False,
66
  learn_logvar=False,
67
+ logvar_init=0.
 
68
  ):
69
  super().__init__()
70
  assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
71
  self.parameterization = parameterization
72
+ mainlogger.info(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
73
  self.cond_stage_model = None
74
  self.clip_denoised = clip_denoised
75
  self.log_every_t = log_every_t
76
  self.first_stage_key = first_stage_key
77
+ self.channels = channels
78
+ self.temporal_length = unet_config.params.temporal_length
79
+ self.image_size = image_size
80
  if isinstance(self.image_size, int):
81
  self.image_size = [self.image_size, self.image_size]
82
+ self.use_positional_encodings = use_positional_encodings
83
  self.model = DiffusionWrapper(unet_config, conditioning_key)
 
 
 
 
84
  self.use_ema = use_ema
85
+ if self.use_ema:
86
+ self.model_ema = LitEma(self.model)
87
+ mainlogger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
88
+
89
  self.use_scheduler = scheduler_config is not None
90
  if self.use_scheduler:
91
  self.scheduler_config = scheduler_config
 
109
  if self.learn_logvar:
110
  self.logvar = nn.Parameter(self.logvar, requires_grad=True)
111
 
112
+
113
  def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
114
  linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
115
  if exists(given_betas):
 
170
  self.model_ema.store(self.model.parameters())
171
  self.model_ema.copy_to(self.model)
172
  if context is not None:
173
+ mainlogger.info(f"{context}: Switched to EMA weights")
174
  try:
175
  yield None
176
  finally:
177
  if self.use_ema:
178
  self.model_ema.restore(self.model.parameters())
179
  if context is not None:
180
+ mainlogger.info(f"{context}: Restored training weights")
181
 
182
  def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
183
  sd = torch.load(path, map_location="cpu")
 
186
  keys = list(sd.keys())
187
  for k in keys:
188
  for ik in ignore_keys:
189
+ if k.startswith(ik):
190
+ mainlogger.info("Deleting key {} from state_dict.".format(k))
191
  del sd[k]
192
  missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
193
  sd, strict=False)
194
+ mainlogger.info(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
195
  if len(missing) > 0:
196
+ mainlogger.info(f"Missing Keys: {missing}")
197
  if len(unexpected) > 0:
198
+ mainlogger.info(f"Unexpected Keys: {unexpected}")
199
 
200
  def q_mean_variance(self, x_start, t):
201
  """
 
262
 
263
  @torch.no_grad()
264
  def sample(self, batch_size=16, return_intermediates=False):
265
+ image_size = self.image_size
266
  channels = self.channels
267
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
 
 
268
  return_intermediates=return_intermediates)
269
 
270
  def q_sample(self, x_start, t, noise=None):
271
  noise = default(noise, lambda: torch.randn_like(x_start))
272
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start *
273
+ extract_into_tensor(self.scale_arr, t, x_start.shape) +
274
  extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
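For reference, with the per-timestep scale s_t taken from scale_arr, the forward marginal implemented by this q_sample is

q(x_t \mid x_0) = \mathcal{N}\big(x_t;\ \sqrt{\bar{\alpha}_t}\, s_t\, x_0,\ (1 - \bar{\alpha}_t)\,\mathbf{I}\big), \qquad s_t = \mathrm{scale\_arr}[t],

which reduces to the standard DDPM forward process when s_t = 1.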
275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  def get_input(self, batch, k):
277
  x = batch[k]
278
  x = x.to(memory_format=torch.contiguous_format).float()
279
  return x
280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  def _get_rows_from_list(self, samples):
282
  n_imgs_per_row = len(samples)
283
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
284
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
 
285
  denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
286
  return denoise_grid
287
 
288
  @torch.no_grad()
289
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
 
 
290
  log = dict()
291
  x = self.get_input(batch, self.first_stage_key)
292
  N = min(x.shape[0], N)
293
  n_row = min(x.shape[0], n_row)
294
  x = x.to(self.device)[:N]
295
  log["inputs"] = x
 
 
296
 
297
+ # get diffusion row
298
+ diffusion_row = list()
299
+ x_start = x[:n_row]
 
300
 
301
+ for t in range(self.num_timesteps):
302
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
303
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
304
+ t = t.to(self.device).long()
305
+ noise = torch.randn_like(x_start)
306
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
307
+ diffusion_row.append(x_noisy)
308
 
309
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
310
 
311
  if sample:
312
  # get denoise row
 
314
  samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
315
 
316
  log["samples"] = samples
317
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
 
318
 
319
  if return_keys:
320
  if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
 
323
  return {key: log[key] for key in return_keys}
324
  return log
325
 
 
 
 
 
 
 
 
 
326
 
327
  class LatentDiffusion(DDPM):
328
  """main class"""
 
330
  first_stage_config,
331
  cond_stage_config,
332
  num_timesteps_cond=None,
333
+ cond_stage_key="caption",
334
  cond_stage_trainable=False,
 
335
  cond_stage_forward=None,
336
  conditioning_key=None,
337
+ uncond_prob=0.2,
338
+ uncond_type="empty_seq",
339
  scale_factor=1.0,
340
  scale_by_std=False,
341
  encoder_type="2d",
 
 
 
 
342
  only_model=False,
343
+ use_scale=False,
344
+ scale_a=1,
345
+ scale_b=0.3,
346
+ mid_step=400,
347
+ fix_scale_bug=False,
348
  *args, **kwargs):
349
  self.num_timesteps_cond = default(num_timesteps_cond, 1)
350
  self.scale_by_std = scale_by_std
351
  assert self.num_timesteps_cond <= kwargs['timesteps']
352
  # for backwards compatibility after implementation of DiffusionWrapper
 
 
 
 
 
353
  ckpt_path = kwargs.pop("ckpt_path", None)
354
  ignore_keys = kwargs.pop("ignore_keys", [])
355
+ conditioning_key = default(conditioning_key, 'crossattn')
356
  super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
357
+
358
  self.cond_stage_trainable = cond_stage_trainable
359
  self.cond_stage_key = cond_stage_key
360
+
361
+ # scale factor
362
+ self.use_scale=use_scale
363
+ if self.use_scale:
364
+ self.scale_a=scale_a
365
+ self.scale_b=scale_b
366
+ if fix_scale_bug:
367
+ scale_step=self.num_timesteps-mid_step
368
+ else: #bug
369
+ scale_step = self.num_timesteps
370
+
371
+ scale_arr1 = np.linspace(scale_a, scale_b, mid_step)
372
+ scale_arr2 = np.full(scale_step, scale_b)
373
+ scale_arr = np.concatenate((scale_arr1, scale_arr2))
374
+ scale_arr_prev = np.append(scale_a, scale_arr[:-1])
375
+ to_torch = partial(torch.tensor, dtype=torch.float32)
376
+ self.register_buffer('scale_arr', to_torch(scale_arr))
377
+
378
  try:
379
  self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
380
  except:
 
385
  self.register_buffer('scale_factor', torch.tensor(scale_factor))
386
  self.instantiate_first_stage(first_stage_config)
387
  self.instantiate_cond_stage(cond_stage_config)
 
 
 
 
388
  self.first_stage_config = first_stage_config
389
+ self.cond_stage_config = cond_stage_config
390
+ self.clip_denoised = False
391
+
392
+ self.cond_stage_forward = cond_stage_forward
393
  self.encoder_type = encoder_type
394
  assert(encoder_type in ["2d", "3d"])
395
+ self.uncond_prob = uncond_prob
396
+ self.classifier_free_guidance = True if uncond_prob > 0 else False
397
+ assert(uncond_type in ["zero_embed", "empty_seq"])
398
+ self.uncond_type = uncond_type
399
+
400
+
401
  self.restarted_from_ckpt = False
 
402
  if ckpt_path is not None:
403
  self.init_from_ckpt(ckpt_path, ignore_keys, only_model=only_model)
404
  self.restarted_from_ckpt = True
405
+
 
 
 
 
 
 
 
 
406
 
407
  def make_cond_schedule(self, ):
408
  self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
409
  ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
410
  self.cond_ids[:self.num_timesteps_cond] = ids
411
 
412
+ def q_sample(self, x_start, t, noise=None):
413
+ noise = default(noise, lambda: torch.randn_like(x_start))
414
+ if self.use_scale:
415
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start *
416
+ extract_into_tensor(self.scale_arr, t, x_start.shape) +
417
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
418
+ else:
419
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
420
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
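A minimal, self-contained sketch of the use_scale branch above; the schedule values, latent shape, and timestep are assumptions, and the real alphas_cumprod comes from register_schedule:

import torch

T, mid_step, scale_a, scale_b = 1000, 400, 1.0, 0.3        # assumed config values
scale_arr = torch.cat([torch.linspace(scale_a, scale_b, mid_step),
                       torch.full((T - mid_step,), scale_b)])   # s_t, as built in __init__
betas = torch.linspace(1e-4, 2e-2, T)                           # toy linear beta schedule
alphas_cumprod = torch.cumprod(1. - betas, dim=0)

t = torch.tensor([500])
x0 = torch.randn(1, 4, 16, 32, 32)                              # latent video: b c t h w
eps = torch.randn_like(x0)
x_t = alphas_cumprod[t].sqrt() * scale_arr[t] * x0 + (1. - alphas_cumprod[t]).sqrt() * eps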
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
 
422
+
423
+ def _freeze_model(self):
424
+ for name, para in self.model.diffusion_model.named_parameters():
425
+ para.requires_grad = False
426
 
427
  def instantiate_first_stage(self, config):
428
  model = instantiate_from_config(config)
 
432
  param.requires_grad = False
433
 
434
  def instantiate_cond_stage(self, config):
 
 
 
435
  if not self.cond_stage_trainable:
436
+ model = instantiate_from_config(config)
437
+ self.cond_stage_model = model.eval()
438
+ self.cond_stage_model.train = disabled_train
439
+ for param in self.cond_stage_model.parameters():
440
+ param.requires_grad = False
 
 
 
 
 
 
 
441
  else:
 
 
442
  model = instantiate_from_config(config)
443
  self.cond_stage_model = model
444
+
 
 
 
 
 
 
 
 
 
 
 
 
445
  def get_learned_conditioning(self, c):
446
  if self.cond_stage_forward is None:
447
  if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
 
455
  c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
456
  return c
457
 
458
+ def get_first_stage_encoding(self, encoder_posterior, noise=None):
459
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
460
+ z = encoder_posterior.sample(noise=noise)
461
+ elif isinstance(encoder_posterior, torch.Tensor):
462
+ z = encoder_posterior
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
463
  else:
464
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
465
+ return self.scale_factor * z
466
+
 
 
467
  @torch.no_grad()
468
+ def encode_first_stage(self, x):
469
+ if self.encoder_type == "2d" and x.dim() == 5:
470
+ b, _, t, _, _ = x.shape
471
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
472
+ reshape_back = True
473
+ else:
474
+ reshape_back = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
475
 
476
+ encoder_posterior = self.first_stage_model.encode(x)
477
+ results = self.get_first_stage_encoding(encoder_posterior).detach()
478
 
479
+ if reshape_back:
480
+ results = rearrange(results, '(b t) c h w -> b c t h w', b=b,t=t)
 
 
 
 
 
481
 
 
 
 
 
 
 
482
  return results
483
 
484
  @torch.no_grad()
485
+ def encode_first_stage_2DAE(self, x):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
486
 
 
 
487
  b, _, t, _, _ = x.shape
488
+ results = torch.cat([self.get_first_stage_encoding(self.first_stage_model.encode(x[:,:,i])).detach().unsqueeze(2) for i in range(t)], dim=2)
489
+
 
 
 
 
 
 
 
 
 
 
490
  return results
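encode_first_stage_2DAE above loops over the time axis so a 2D autoencoder can be reused frame by frame. A toy illustration of the same pattern (the Conv2d stands in for the real first-stage encoder, and all shapes are assumptions):

import torch
import torch.nn as nn

toy_encoder = nn.Conv2d(3, 4, kernel_size=8, stride=8)      # stand-in: 3-ch frame -> 4-ch latent, /8
video = torch.rand(2, 3, 16, 64, 64)                         # b c t h w
latents = torch.cat([toy_encoder(video[:, :, i]).unsqueeze(2)
                     for i in range(video.shape[2])], dim=2)
print(latents.shape)                                         # torch.Size([2, 4, 16, 8, 8])

Looping per frame keeps peak memory at a single frame's activations, at the cost of one encoder call per frame.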
491
 
492
+ def decode_core(self, z, **kwargs):
493
+ if self.encoder_type == "2d" and z.dim() == 5:
494
+ b, _, t, _, _ = z.shape
495
+ z = rearrange(z, 'b c t h w -> (b t) c h w')
496
+ reshape_back = True
497
+ else:
498
+ reshape_back = False
499
+
500
+ z = 1. / self.scale_factor * z
501
 
502
+ results = self.first_stage_model.decode(z, **kwargs)
503
+
504
+ if reshape_back:
505
+ results = rearrange(results, '(b t) c h w -> b c t h w', b=b,t=t)
506
+ return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
507
 
508
+ @torch.no_grad()
509
+ def decode_first_stage(self, z, **kwargs):
510
+ return self.decode_core(z, **kwargs)
511
 
512
+ def apply_model(self, x_noisy, t, cond, **kwargs):
513
  if isinstance(cond, dict):
514
  # hybrid case, cond is expected to be a dict
515
  pass
 
521
 
522
  x_recon = self.model(x_noisy, t, **cond, **kwargs)
523
 
524
+ if isinstance(x_recon, tuple):
525
  return x_recon[0]
526
  else:
527
  return x_recon
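apply_model forwards the conditioning dict straight into DiffusionWrapper. A hypothetical minimal example of the crossattn path (tensor shapes are assumptions; the dict key follows the module-level __conditioning_keys__ mapping, redefined locally here so the snippet is self-contained):

import torch

conditioning_keys = {'concat': 'c_concat', 'crossattn': 'c_crossattn', 'adm': 'y'}
text_emb = torch.rand(2, 77, 1024)                  # assumed text-embedding shape (b, seq, dim)
cond = {conditioning_keys['crossattn']: [text_emb]}
# In DiffusionWrapper.forward (crossattn branch) the list is concatenated along dim 1
# and passed to the UNet as `context`:
context = torch.cat(cond['c_crossattn'], dim=1)
print(context.shape)                                # torch.Size([2, 77, 1024])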
528
 
529
+ def _get_denoise_row_from_list(self, samples, desc=''):
530
+ denoise_row = []
531
+ for zd in tqdm(samples, desc=desc):
532
+ denoise_row.append(self.decode_first_stage(zd.to(self.device)))
533
+ n_log_timesteps = len(denoise_row)
534
 
535
+ denoise_row = torch.stack(denoise_row) # n_log_timesteps, b, C, H, W
536
+
537
+ if denoise_row.dim() == 5:
538
+ # img, num_imgs= n_log_timesteps * bs, grid_size=[bs,n_log_timesteps]
539
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
540
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
541
+ denoise_grid = make_grid(denoise_grid, nrow=n_log_timesteps)
542
+ elif denoise_row.dim() == 6:
543
+ # video, grid_size=[n_log_timesteps*bs, t]
544
+ video_length = denoise_row.shape[3]
545
+ denoise_grid = rearrange(denoise_row, 'n b c t h w -> b n c t h w')
546
+ denoise_grid = rearrange(denoise_grid, 'b n c t h w -> (b n) c t h w')
547
+ denoise_grid = rearrange(denoise_grid, 'n c t h w -> (n t) c h w')
548
+ denoise_grid = make_grid(denoise_grid, nrow=video_length)
 
 
 
 
549
  else:
550
+ raise ValueError
 
 
551
 
552
+ return denoise_grid
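
_get_denoise_row_from_list is a logging helper: it decodes a list of intermediate latents and tiles them with torchvision's make_grid. A hedged sketch of the intended call site (the intermediates come from p_sample_loop below; cond and shape are assumptions):

    samples, intermediates = model.p_sample_loop(cond, shape, return_intermediates=True)
    grid = model._get_denoise_row_from_list(intermediates, desc='denoise row')
    # grid is a single tiled image tensor, ready to be written to an image logger
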
553
+
554
 
555
+ @torch.no_grad()
556
+ def decode_first_stage_2DAE(self, z, **kwargs):
557
 
558
+ b, _, t, _, _ = z.shape
559
+ z = 1. / self.scale_factor * z
560
+ results = torch.cat([self.first_stage_model.decode(z[:,:,i], **kwargs).unsqueeze(2) for i in range(t)], dim=2)
561
 
562
+ return results
 
 
 
 
563
 
 
564
 
565
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_x0=False, score_corrector=None, corrector_kwargs=None, **kwargs):
 
 
 
566
  t_in = t
567
+ model_out = self.apply_model(x, t_in, c, **kwargs)
568
 
569
  if score_corrector is not None:
570
  assert self.parameterization == "eps"
571
  model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
572
 
 
 
 
573
  if self.parameterization == "eps":
574
  x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
575
  elif self.parameterization == "x0":
 
579
 
580
  if clip_denoised:
581
  x_recon.clamp_(-1., 1.)
582
+
 
583
  model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
584
+
585
+ if return_x0:
 
586
  return model_mean, posterior_variance, posterior_log_variance, x_recon
587
  else:
588
  return model_mean, posterior_variance, posterior_log_variance
589
 
590
  @torch.no_grad()
591
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, return_x0=False, \
592
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, **kwargs):
 
 
 
593
  b, *_, device = *x.shape, x.device
594
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, return_x0=return_x0, \
595
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, **kwargs)
596
+ if return_x0:
597
  model_mean, _, model_log_variance, x0 = outputs
598
  else:
599
  model_mean, _, model_log_variance = outputs
 
601
  noise = noise_like(x.shape, device, repeat_noise) * temperature
602
  if noise_dropout > 0.:
603
  noise = torch.nn.functional.dropout(noise, p=noise_dropout)
604
+ # no noise when t == 0
605
  nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
606
 
 
 
607
  if return_x0:
608
  return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
609
  else:
610
  return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
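
Together, p_mean_variance and p_sample implement one ancestral DDPM reverse step. A hedged restatement of the code above (variable names are illustrative):

    # x_t: current noisy latent, cond: conditioning, t: timestep tensor
    #   eps       = apply_model(x_t, t, cond)                  # parameterization == "eps"
    #   x0_hat    = predict_start_from_noise(x_t, t, eps)      # optionally clamped to [-1, 1]
    #   mean, var = q_posterior(x0_hat, x_t, t)
    #   x_{t-1}   = mean + exp(0.5 * log_var) * noise          # noise suppressed at t == 0 via nonzero_mask
    x_prev = model.p_sample(x_t, cond, t, clip_denoised=False)
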
611
 
612
  @torch.no_grad()
613
+ def p_sample_loop(self, cond, shape, return_intermediates=False, x_T=None, verbose=True, callback=None, \
614
+ timesteps=None, mask=None, x0=None, img_callback=None, start_T=None, log_every_t=None, **kwargs):
615
 
616
  if not log_every_t:
617
  log_every_t = self.log_every_t
618
  device = self.betas.device
619
+ b = shape[0]
 
620
  # sample an initial noise
621
  if x_T is None:
622
  img = torch.randn(shape, device=device)
623
  else:
624
  img = x_T
625
+
626
  intermediates = [img]
627
  if timesteps is None:
628
  timesteps = self.num_timesteps
 
629
  if start_T is not None:
630
  timesteps = min(timesteps, start_T)
631
+
632
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(range(0, timesteps))
633
 
634
  if mask is not None:
635
  assert x0 is not None
 
642
  tc = self.cond_ids[ts].to(cond.device)
643
  cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
644
 
645
+ img = self.p_sample(img, cond, ts, clip_denoised=self.clip_denoised, **kwargs)
 
 
 
 
 
646
  if mask is not None:
647
  img_orig = self.q_sample(x0, ts)
648
  img = img_orig * mask + (1. - mask) * img
 
656
  return img, intermediates
657
  return img
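
A minimal, hedged text-to-video sampling sketch built on this loop; the latent shape is an assumption, and in practice a faster DDIM-style sampler is normally used instead of the full loop:

    cond = model.get_learned_conditioning(["a corgi running on the beach"])   # hypothetical prompt
    shape = (1, 4, 16, 32, 32)                 # b c t h w in latent space (illustrative)
    z = model.p_sample_loop(cond, shape, verbose=True)
    video = model.decode_first_stage(z)        # pixel-space frames, [1, 3, 16, 256, 256] under these assumptions
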
658
 
659
 
660
+ class LatentVisualDiffusion(LatentDiffusion):
661
+ def __init__(self, cond_img_config, finegrained=False, random_cond=False, *args, **kwargs):
662
+ super().__init__(*args, **kwargs)
663
+ self.random_cond = random_cond
664
+ self.instantiate_img_embedder(cond_img_config, freeze=True)
665
+ num_tokens = 16 if finegrained else 4
666
+ self.image_proj_model = self.init_projector(use_finegrained=finegrained, num_tokens=num_tokens, input_dim=1024,\
667
+ cross_attention_dim=1024, dim=1280)
668
+
669
+ def instantiate_img_embedder(self, config, freeze=True):
670
+ embedder = instantiate_from_config(config)
671
+ if freeze:
672
+ self.embedder = embedder.eval()
673
+ self.embedder.train = disabled_train
674
+ for param in self.embedder.parameters():
675
+ param.requires_grad = False
676
+
677
+ def init_projector(self, use_finegrained, num_tokens, input_dim, cross_attention_dim, dim):
678
+ if not use_finegrained:
679
+ image_proj_model = ImageProjModel(clip_extra_context_tokens=num_tokens, cross_attention_dim=cross_attention_dim,
680
+ clip_embeddings_dim=input_dim
681
+ )
682
  else:
683
+ image_proj_model = Resampler(dim=input_dim, depth=4, dim_head=64, heads=12, num_queries=num_tokens,
684
+ embedding_dim=dim, output_dim=cross_attention_dim, ff_mult=4
685
+ )
686
+ return image_proj_model
687
 
688
+ ## Never delete this func: it is used in log_images() and inference stage
689
+ def get_image_embeds(self, batch_imgs):
690
+ ## img: b c h w
691
+ img_token = self.embedder(batch_imgs)
692
+ img_emb = self.image_proj_model(img_token)
693
+ return img_emb
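
A hedged sketch of the image-conditioning path added here; the embedder is assumed to be a frozen CLIP-style image encoder supplied via cond_img_config, and the input resolution is illustrative:

    import torch
    img = torch.randn(1, 3, 224, 224)          # hypothetical conditioning frame, b c h w
    img_emb = model.get_image_embeds(img)      # embedder tokens -> image_proj_model -> cross-attention context
    # finegrained=False yields 4 projected tokens per image (ImageProjModel);
    # finegrained=True uses the Resampler and yields 16.
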
694
 
695
 
696
  class DiffusionWrapper(pl.LightningModule):
697
  def __init__(self, diff_model_config, conditioning_key):
698
  super().__init__()
699
  self.diffusion_model = instantiate_from_config(diff_model_config)
 
700
  self.conditioning_key = conditioning_key
 
701
 
702
  def forward(self, x, t, c_concat: list = None, c_crossattn: list = None,
703
  c_adm=None, s=None, mask=None, **kwargs):
704
  # temporal_context = fps is None
705
  if self.conditioning_key is None:
706
+ out = self.diffusion_model(x, t)
707
  elif self.conditioning_key == 'concat':
708
  xc = torch.cat([x] + c_concat, dim=1)
709
  out = self.diffusion_model(xc, t, **kwargs)
 
711
  cc = torch.cat(c_crossattn, 1)
712
  out = self.diffusion_model(x, t, context=cc, **kwargs)
713
  elif self.conditioning_key == 'hybrid':
714
+ ## input is [b,c,t,h,w]: concatenate the condition along the channel dim
715
  xc = torch.cat([x] + c_concat, dim=1)
716
  cc = torch.cat(c_crossattn, 1)
717
+ out = self.diffusion_model(xc, t, context=cc)
718
  elif self.conditioning_key == 'resblockcond':
719
  cc = c_crossattn[0]
720
+ out = self.diffusion_model(x, t, context=cc)
721
  elif self.conditioning_key == 'adm':
722
  cc = c_crossattn[0]
723
+ out = self.diffusion_model(x, t, y=cc)
724
  elif self.conditioning_key == 'hybrid-adm':
725
  assert c_adm is not None
726
  xc = torch.cat([x] + c_concat, dim=1)
727
  cc = torch.cat(c_crossattn, 1)
728
+ out = self.diffusion_model(xc, t, context=cc, y=c_adm)
729
  elif self.conditioning_key == 'hybrid-time':
730
  assert s is not None
731
  xc = torch.cat([x] + c_concat, dim=1)
732
  cc = torch.cat(c_crossattn, 1)
733
+ out = self.diffusion_model(xc, t, context=cc, s=s)
734
  elif self.conditioning_key == 'concat-time-mask':
735
  # assert s is not None
736
+ # mainlogger.info('x & mask:',x.shape,c_concat[0].shape)
737
  xc = torch.cat([x] + c_concat, dim=1)
738
+ out = self.diffusion_model(xc, t, context=None, s=s, mask=mask)
739
  elif self.conditioning_key == 'concat-adm-mask':
740
  # assert s is not None
741
+ # mainlogger.info('x & mask:',x.shape,c_concat[0].shape)
742
  if c_concat is not None:
743
  xc = torch.cat([x] + c_concat, dim=1)
744
  else:
745
  xc = x
746
+ out = self.diffusion_model(xc, t, context=None, y=s, mask=mask)
 
 
 
747
  elif self.conditioning_key == 'hybrid-adm-mask':
748
  cc = torch.cat(c_crossattn, 1)
749
  if c_concat is not None:
750
  xc = torch.cat([x] + c_concat, dim=1)
751
  else:
752
  xc = x
753
+ out = self.diffusion_model(xc, t, context=cc, y=s, mask=mask)
754
  elif self.conditioning_key == 'hybrid-time-adm': # adm means y, e.g., class index
755
  # assert s is not None
756
  assert c_adm is not None
757
  xc = torch.cat([x] + c_concat, dim=1)
758
  cc = torch.cat(c_crossattn, 1)
759
+ out = self.diffusion_model(xc, t, context=cc, s=s, y=c_adm)
760
  else:
761
  raise NotImplementedError()
762
 
763
+ return out
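
For reference, a hedged sketch of how the conditioning_key dispatch above is exercised (tensor names are assumptions; LatentDiffusion.apply_model is what normally builds these keyword lists):

    # 'crossattn': concatenated text/image tokens become the attention context
    out = diffusion_wrapper(x_noisy, t, c_crossattn=[text_emb])
    # 'hybrid': an extra conditioning volume is also concatenated to x along the channel dim
    out = diffusion_wrapper(x_noisy, t, c_concat=[cond_frames], c_crossattn=[text_emb])
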
 
lvdm/models/modules/adapter.py DELETED
@@ -1,105 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from collections import OrderedDict
4
- from lvdm.models.modules.util import (
5
- zero_module,
6
- conv_nd,
7
- avg_pool_nd
8
- )
9
-
10
- class Downsample(nn.Module):
11
- """
12
- A downsampling layer with an optional convolution.
13
- :param channels: channels in the inputs and outputs.
14
- :param use_conv: a bool determining if a convolution is applied.
15
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
16
- downsampling occurs in the inner-two dimensions.
17
- """
18
-
19
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
20
- super().__init__()
21
- self.channels = channels
22
- self.out_channels = out_channels or channels
23
- self.use_conv = use_conv
24
- self.dims = dims
25
- stride = 2 if dims != 3 else (1, 2, 2)
26
- if use_conv:
27
- self.op = conv_nd(
28
- dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
29
- )
30
- else:
31
- assert self.channels == self.out_channels
32
- self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
33
-
34
- def forward(self, x):
35
- assert x.shape[1] == self.channels
36
- return self.op(x)
37
-
38
-
39
- class ResnetBlock(nn.Module):
40
- def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
41
- super().__init__()
42
- ps = ksize // 2
43
- if in_c != out_c or sk == False:
44
- self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
45
- else:
46
- # print('n_in')
47
- self.in_conv = None
48
- self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
49
- self.act = nn.ReLU()
50
- self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
51
- if sk == False:
52
- self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
53
- else:
54
- self.skep = None
55
-
56
- self.down = down
57
- if self.down == True:
58
- self.down_opt = Downsample(in_c, use_conv=use_conv)
59
-
60
- def forward(self, x):
61
- if self.down == True:
62
- x = self.down_opt(x)
63
- if self.in_conv is not None: # edit
64
- x = self.in_conv(x)
65
-
66
- h = self.block1(x)
67
- h = self.act(h)
68
- h = self.block2(h)
69
- if self.skep is not None:
70
- return h + self.skep(x)
71
- else:
72
- return h + x
73
-
74
-
75
- class Adapter(nn.Module):
76
- def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True):
77
- super(Adapter, self).__init__()
78
- self.unshuffle = nn.PixelUnshuffle(8)
79
- self.channels = channels
80
- self.nums_rb = nums_rb
81
- self.body = []
82
- for i in range(len(channels)):
83
- for j in range(nums_rb):
84
- if (i != 0) and (j == 0):
85
- self.body.append(
86
- ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
87
- else:
88
- self.body.append(
89
- ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
90
- self.body = nn.ModuleList(self.body)
91
- self.conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1)
92
-
93
- def forward(self, x):
94
- # unshuffle
95
- x = self.unshuffle(x)
96
- # extract features
97
- features = []
98
- x = self.conv_in(x)
99
- for i in range(len(self.channels)):
100
- for j in range(self.nums_rb):
101
- idx = i * self.nums_rb + j
102
- x = self.body[idx](x)
103
- features.append(x)
104
-
105
- return features
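
For the record, a hedged sketch of how this now-removed Adapter was typically driven; the control-map resolution and cin=64 (= 1 channel x 8 x 8 after PixelUnshuffle) are assumptions:

    import torch
    adapter = Adapter(channels=[320, 640, 1280, 1280], nums_rb=3, cin=64)
    cond_map = torch.randn(1, 1, 512, 512)     # hypothetical single-channel control map
    features = adapter(cond_map)               # list of 4 feature maps, one per channel stage
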
 
lvdm/models/modules/attention_temporal.py DELETED
@@ -1,399 +0,0 @@
1
- from typing import Optional, Any
2
-
3
- import torch
4
- import torch as th
5
- from torch import nn, einsum
6
- from einops import rearrange, repeat
7
- try:
8
- import xformers
9
- import xformers.ops
10
- XFORMERS_IS_AVAILBLE = True
11
- except:
12
- XFORMERS_IS_AVAILBLE = False
13
-
14
- from lvdm.models.modules.util import (
15
- GEGLU,
16
- exists,
17
- default,
18
- Normalize,
19
- checkpoint,
20
- zero_module,
21
- )
22
-
23
-
24
- # ---------------------------------------------------------------------------------------------------
25
- class FeedForward(nn.Module):
26
- def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
27
- super().__init__()
28
- inner_dim = int(dim * mult)
29
- dim_out = default(dim_out, dim)
30
- project_in = nn.Sequential(
31
- nn.Linear(dim, inner_dim),
32
- nn.GELU()
33
- ) if not glu else GEGLU(dim, inner_dim)
34
-
35
- self.net = nn.Sequential(
36
- project_in,
37
- nn.Dropout(dropout),
38
- nn.Linear(inner_dim, dim_out)
39
- )
40
-
41
- def forward(self, x):
42
- return self.net(x)
43
-
44
-
45
- # ---------------------------------------------------------------------------------------------------
46
- class RelativePosition(nn.Module):
47
- """ https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py """
48
-
49
- def __init__(self, num_units, max_relative_position):
50
- super().__init__()
51
- self.num_units = num_units
52
- self.max_relative_position = max_relative_position
53
- self.embeddings_table = nn.Parameter(th.Tensor(max_relative_position * 2 + 1, num_units))
54
- nn.init.xavier_uniform_(self.embeddings_table)
55
-
56
- def forward(self, length_q, length_k):
57
- device = self.embeddings_table.device
58
- range_vec_q = th.arange(length_q, device=device)
59
- range_vec_k = th.arange(length_k, device=device)
60
- distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
61
- distance_mat_clipped = th.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
62
- final_mat = distance_mat_clipped + self.max_relative_position
63
- final_mat = final_mat.long()
64
- embeddings = self.embeddings_table[final_mat]
65
- return embeddings
66
-
67
-
68
- # ---------------------------------------------------------------------------------------------------
69
- class TemporalCrossAttention(nn.Module):
70
- def __init__(self,
71
- query_dim,
72
- context_dim=None,
73
- heads=8,
74
- dim_head=64,
75
- dropout=0.,
76
- use_relative_position=False, # whether use relative positional representation in temporal attention.
77
- temporal_length=None, # relative positional representation
78
- **kwargs,
79
- ):
80
- super().__init__()
81
- inner_dim = dim_head * heads
82
- context_dim = default(context_dim, query_dim)
83
- self.context_dim = context_dim
84
- self.scale = dim_head ** -0.5
85
- self.heads = heads
86
- self.temporal_length = temporal_length
87
- self.use_relative_position = use_relative_position
88
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
89
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
90
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
91
- self.to_out = nn.Sequential(
92
- nn.Linear(inner_dim, query_dim),
93
- nn.Dropout(dropout)
94
- )
95
-
96
- if use_relative_position:
97
- assert(temporal_length is not None)
98
- self.relative_position_k = RelativePosition(num_units=dim_head, max_relative_position=temporal_length)
99
- self.relative_position_v = RelativePosition(num_units=dim_head, max_relative_position=temporal_length)
100
-
101
- nn.init.constant_(self.to_q.weight, 0)
102
- nn.init.constant_(self.to_k.weight, 0)
103
- nn.init.constant_(self.to_v.weight, 0)
104
- nn.init.constant_(self.to_out[0].weight, 0)
105
- nn.init.constant_(self.to_out[0].bias, 0)
106
-
107
- def forward(self, x, context=None, mask=None):
108
- nh = self.heads
109
- out = x
110
-
111
- # cal qkv
112
- q = self.to_q(out)
113
- context = default(context, x)
114
- k = self.to_k(context)
115
- v = self.to_v(context)
116
-
117
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=nh), (q, k, v))
118
- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
119
-
120
- # relative positional embedding
121
- if self.use_relative_position:
122
- len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
123
- k2 = self.relative_position_k(len_q, len_k)
124
- sim2 = einsum('b t d, t s d -> b t s', q, k2) * self.scale
125
- sim += sim2
126
-
127
- # mask attention
128
- if mask is not None:
129
- max_neg_value = -1e9
130
- sim = sim + (1-mask.float()) * max_neg_value # 1=masking,0=no masking
131
-
132
- # attend to values
133
- attn = sim.softmax(dim=-1)
134
- out = einsum('b i j, b j d -> b i d', attn, v)
135
-
136
- # relative positional embedding
137
- if self.use_relative_position:
138
- v2 = self.relative_position_v(len_q, len_v)
139
- out2 = einsum('b t s, t s d -> b t d', attn, v2)
140
- out += out2
141
-
142
- # merge head
143
- out = rearrange(out, '(b h) n d -> b n (h d)', h=nh)
144
- return self.to_out(out)
145
-
146
-
147
- # ---------------------------------------------------------------------------------------------------
148
- class CrossAttention(nn.Module):
149
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.,
150
- **kwargs,):
151
- super().__init__()
152
- inner_dim = dim_head * heads
153
- context_dim = default(context_dim, query_dim)
154
-
155
- self.scale = dim_head ** -0.5
156
- self.heads = heads
157
-
158
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
159
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
160
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
161
-
162
- self.to_out = nn.Sequential(
163
- nn.Linear(inner_dim, query_dim),
164
- nn.Dropout(dropout)
165
- )
166
-
167
- def forward(self, x, context=None, mask=None):
168
- h = self.heads
169
- b = x.shape[0]
170
-
171
- q = self.to_q(x)
172
- context = default(context, x)
173
- k = self.to_k(context)
174
- v = self.to_v(context)
175
-
176
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
177
-
178
- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
179
-
180
- if exists(mask):
181
- mask = rearrange(mask, 'b ... -> b (...)')
182
- max_neg_value = -torch.finfo(sim.dtype).max
183
- mask = repeat(mask, 'b j -> (b h) () j', h=h)
184
- sim.masked_fill_(~mask, max_neg_value)
185
-
186
- attn = sim.softmax(dim=-1)
187
-
188
- out = einsum('b i j, b j d -> b i d', attn, v)
189
- out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
190
- return self.to_out(out)
191
-
192
-
193
- # ---------------------------------------------------------------------------------------------------
194
- class MemoryEfficientCrossAttention(nn.Module):
195
- """https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
196
- """
197
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0,
198
- **kwargs,):
199
- super().__init__()
200
- print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
201
- f"{heads} heads."
202
- )
203
- inner_dim = dim_head * heads
204
- context_dim = default(context_dim, query_dim)
205
-
206
- self.heads = heads
207
- self.dim_head = dim_head
208
-
209
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
210
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
211
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
212
-
213
- self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
214
- self.attention_op: Optional[Any] = None
215
-
216
- def forward(self, x, context=None, mask=None):
217
- q = self.to_q(x)
218
- context = default(context, x)
219
- k = self.to_k(context)
220
- v = self.to_v(context)
221
-
222
- b, _, _ = q.shape
223
- q, k, v = map(
224
- lambda t: t.unsqueeze(3)
225
- .reshape(b, t.shape[1], self.heads, self.dim_head)
226
- .permute(0, 2, 1, 3)
227
- .reshape(b * self.heads, t.shape[1], self.dim_head)
228
- .contiguous(),
229
- (q, k, v),
230
- )
231
- out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
232
-
233
- if exists(mask):
234
- raise NotImplementedError
235
- out = (
236
- out.unsqueeze(0)
237
- .reshape(b, self.heads, out.shape[1], self.dim_head)
238
- .permute(0, 2, 1, 3)
239
- .reshape(b, out.shape[1], self.heads * self.dim_head)
240
- )
241
- return self.to_out(out)
242
-
243
-
244
- # ---------------------------------------------------------------------------------------------------
245
- class BasicTransformerBlockST(nn.Module):
246
- """
247
- if no context is given to forward function, cross-attention defaults to self-attention
248
- """
249
- def __init__(self,
250
- # Spatial
251
- dim,
252
- n_heads,
253
- d_head,
254
- dropout=0.,
255
- context_dim=None,
256
- gated_ff=True,
257
- checkpoint=True,
258
- # Temporal
259
- temporal_length=None,
260
- use_relative_position=True,
261
- **kwargs,
262
- ):
263
- super().__init__()
264
-
265
- # spatial self attention (if context_dim is None) and spatial cross attention
266
- if XFORMERS_IS_AVAILBLE:
267
- self.attn1 = MemoryEfficientCrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, **kwargs,)
268
- self.attn2 = MemoryEfficientCrossAttention(query_dim=dim, context_dim=context_dim,
269
- heads=n_heads, dim_head=d_head, dropout=dropout, **kwargs,)
270
- else:
271
- self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, **kwargs,)
272
- self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
273
- heads=n_heads, dim_head=d_head, dropout=dropout, **kwargs,)
274
- self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
275
-
276
- self.norm1 = nn.LayerNorm(dim)
277
- self.norm2 = nn.LayerNorm(dim)
278
- self.norm3 = nn.LayerNorm(dim)
279
- self.checkpoint = checkpoint
280
-
281
- # Temporal attention
282
- self.attn1_tmp = TemporalCrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
283
- temporal_length=temporal_length,
284
- use_relative_position=use_relative_position,
285
- **kwargs,
286
- )
287
- self.attn2_tmp = TemporalCrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
288
- # cross attn
289
- context_dim=None,
290
- # temporal attn
291
- temporal_length=temporal_length,
292
- use_relative_position=use_relative_position,
293
- **kwargs,
294
- )
295
- self.norm4 = nn.LayerNorm(dim)
296
- self.norm5 = nn.LayerNorm(dim)
297
-
298
- def forward(self, x, context=None, **kwargs):
299
- return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
300
-
301
- def _forward(self, x, context=None, mask=None,):
302
- assert(x.dim() == 5), f"x shape = {x.shape}"
303
- b, c, t, h, w = x.shape
304
-
305
- # spatial self attention
306
- x = rearrange(x, 'b c t h w -> (b t) (h w) c')
307
- x = self.attn1(self.norm1(x)) + x
308
- x = rearrange(x, '(b t) (h w) c -> b c t h w', b=b,h=h)
309
-
310
- # temporal self attention
311
- x = rearrange(x, 'b c t h w -> (b h w) t c')
312
- x = self.attn1_tmp(self.norm4(x), mask=mask) + x
313
- x = rearrange(x, '(b h w) t c -> b c t h w', b=b,h=h,w=w) # 3d -> 5d
314
-
315
- # spatial cross attention
316
- x = rearrange(x, 'b c t h w -> (b t) (h w) c')
317
- if context is not None:
318
- context_ = []
319
- for i in range(context.shape[0]):
320
- context_.append(context[i].unsqueeze(0).repeat(t, 1, 1))
321
- context_ = torch.cat(context_,dim=0)
322
- else:
323
- context_ = None
324
- x = self.attn2(self.norm2(x), context=context_) + x
325
- x = rearrange(x, '(b t) (h w) c -> b c t h w', b=b,h=h)
326
-
327
- # temporal cross attention
328
- x = rearrange(x, 'b c t h w -> (b h w) t c')
329
- x = self.attn2_tmp(self.norm5(x), context=None, mask=mask) + x
330
-
331
- # feedforward
332
- x = self.ff(self.norm3(x)) + x
333
- x = rearrange(x, '(b h w) t c -> b c t h w', b=b,h=h,w=w) # 3d -> 5d
334
-
335
- return x
336
-
337
-
338
- # ---------------------------------------------------------------------------------------------------
339
- class SpatialTemporalTransformer(nn.Module):
340
- """
341
- Transformer block for video-like data (5D tensor).
342
- First, project the input (aka embedding) with NO reshape.
343
- Then apply standard transformer action.
344
- The 5D -> 3D reshape operation will be done in the specific attention module.
345
- """
346
- def __init__(
347
- self,
348
- in_channels, n_heads, d_head,
349
- depth=1, dropout=0.,
350
- context_dim=None,
351
- # Temporal
352
- temporal_length=None,
353
- use_relative_position=True,
354
- **kwargs,
355
- ):
356
- super().__init__()
357
-
358
- self.in_channels = in_channels
359
- inner_dim = n_heads * d_head
360
-
361
- self.norm = Normalize(in_channels)
362
- self.proj_in = nn.Conv3d(in_channels,
363
- inner_dim,
364
- kernel_size=1,
365
- stride=1,
366
- padding=0)
367
-
368
- self.transformer_blocks = nn.ModuleList(
369
- [BasicTransformerBlockST(
370
- inner_dim, n_heads, d_head, dropout=dropout,
371
- # cross attn
372
- context_dim=context_dim,
373
- # temporal attn
374
- temporal_length=temporal_length,
375
- use_relative_position=use_relative_position,
376
- **kwargs
377
- ) for d in range(depth)]
378
- )
379
-
380
- self.proj_out = zero_module(nn.Conv3d(inner_dim,
381
- in_channels,
382
- kernel_size=1,
383
- stride=1,
384
- padding=0))
385
-
386
- def forward(self, x, context=None, **kwargs):
387
-
388
- assert(x.dim() == 5), f"x shape = {x.shape}"
389
- x_in = x
390
-
391
- x = self.norm(x)
392
- x = self.proj_in(x)
393
-
394
- for block in self.transformer_blocks:
395
- x = block(x, context=context, **kwargs)
396
-
397
- x = self.proj_out(x)
398
-
399
- return x + x_in
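
A hedged sketch of how this removed spatio-temporal transformer block operated on 5D video features; the channel count, head layout, context width, and 16-frame length are assumptions:

    import torch
    block = SpatialTemporalTransformer(in_channels=320, n_heads=8, d_head=40,
                                       context_dim=768, temporal_length=16)
    feat = torch.randn(2, 320, 16, 32, 32)     # b c t h w
    ctx  = torch.randn(2, 77, 768)             # per-sample text context, repeated per frame inside the block
    out  = block(feat, context=ctx)            # same shape as feat, with a residual connection to the input
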
 
lvdm/models/modules/condition_modules.py DELETED
@@ -1,40 +0,0 @@
1
- import torch.nn as nn
2
- from transformers import logging
3
- from transformers import CLIPTokenizer, CLIPTextModel
4
- logging.set_verbosity_error()
5
-
6
-
7
- class AbstractEncoder(nn.Module):
8
- def __init__(self):
9
- super().__init__()
10
-
11
- def encode(self, *args, **kwargs):
12
- raise NotImplementedError
13
-
14
-
15
- class FrozenCLIPEmbedder(AbstractEncoder):
16
- """Uses the CLIP transformer encoder for text (from huggingface)"""
17
- def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
18
- super().__init__()
19
- self.tokenizer = CLIPTokenizer.from_pretrained(version)
20
- self.transformer = CLIPTextModel.from_pretrained(version)
21
- self.device = device
22
- self.max_length = max_length
23
- self.freeze()
24
-
25
- def freeze(self):
26
- self.transformer = self.transformer.eval()
27
- for param in self.parameters():
28
- param.requires_grad = False
29
-
30
- def forward(self, text):
31
- batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
32
- return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
33
- tokens = batch_encoding["input_ids"].to(self.device)
34
- outputs = self.transformer(input_ids=tokens)
35
-
36
- z = outputs.last_hidden_state
37
- return z
38
-
39
- def encode(self, text):
40
- return self(text)
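
And a hedged sketch of how this removed FrozenCLIPEmbedder was used; moving the module to CUDA explicitly is an assumption, made to match its hard-coded device="cuda" for the tokens:

    encoder = FrozenCLIPEmbedder(version="openai/clip-vit-large-patch14").to("cuda")
    z = encoder.encode(["a painting of a fox"])   # padded/truncated to 77 tokens -> last_hidden_state [1, 77, 768]
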
 
lvdm/models/modules/lora.py DELETED
@@ -1,1251 +0,0 @@
1
- import json
2
- from itertools import groupby
3
- from typing import Dict, List, Optional, Set, Tuple, Type, Union
4
-
5
-
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
-
10
- # try:
11
- # from safetensors.torch import safe_open
12
- # from safetensors.torch import save_file as safe_save
13
-
14
- # safetensors_available = True
15
- # except ImportError:
16
- # from .safe_open import safe_open
17
-
18
- # def safe_save(
19
- # tensors: Dict[str, torch.Tensor],
20
- # filename: str,
21
- # metadata: Optional[Dict[str, str]] = None,
22
- # ) -> None:
23
- # raise EnvironmentError(
24
- # "Saving safetensors requires the safetensors library. Please install with pip or similar."
25
- # )
26
-
27
- # safetensors_available = False
28
-
29
-
30
- class LoraInjectedLinear(nn.Module):
31
- def __init__(
32
- self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
33
- ):
34
- super().__init__()
35
-
36
- if r > min(in_features, out_features):
37
- raise ValueError(
38
- f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
39
- )
40
- self.r = r
41
- self.linear = nn.Linear(in_features, out_features, bias)
42
- self.lora_down = nn.Linear(in_features, r, bias=False)
43
- self.dropout = nn.Dropout(dropout_p)
44
- self.lora_up = nn.Linear(r, out_features, bias=False)
45
- self.scale = scale
46
- self.selector = nn.Identity()
47
-
48
- nn.init.normal_(self.lora_down.weight, std=1 / r)
49
- nn.init.zeros_(self.lora_up.weight)
50
-
51
- def forward(self, input):
52
- return (
53
- self.linear(input)
54
- + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
55
- * self.scale
56
- )
57
-
58
- def realize_as_lora(self):
59
- return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
60
-
61
- def set_selector_from_diag(self, diag: torch.Tensor):
62
- # diag is a 1D tensor of size (r,)
63
- assert diag.shape == (self.r,)
64
- self.selector = nn.Linear(self.r, self.r, bias=False)
65
- self.selector.weight.data = torch.diag(diag)
66
- self.selector.weight.data = self.selector.weight.data.to(
67
- self.lora_up.weight.device
68
- ).to(self.lora_up.weight.dtype)
69
-
70
-
71
- class LoraInjectedConv2d(nn.Module):
72
- def __init__(
73
- self,
74
- in_channels: int,
75
- out_channels: int,
76
- kernel_size,
77
- stride=1,
78
- padding=0,
79
- dilation=1,
80
- groups: int = 1,
81
- bias: bool = True,
82
- r: int = 4,
83
- dropout_p: float = 0.1,
84
- scale: float = 1.0,
85
- ):
86
- super().__init__()
87
- if r > min(in_channels, out_channels):
88
- raise ValueError(
89
- f"LoRA rank {r} must be less or equal than {min(in_channels, out_channels)}"
90
- )
91
- self.r = r
92
- self.conv = nn.Conv2d(
93
- in_channels=in_channels,
94
- out_channels=out_channels,
95
- kernel_size=kernel_size,
96
- stride=stride,
97
- padding=padding,
98
- dilation=dilation,
99
- groups=groups,
100
- bias=bias,
101
- )
102
-
103
- self.lora_down = nn.Conv2d(
104
- in_channels=in_channels,
105
- out_channels=r,
106
- kernel_size=kernel_size,
107
- stride=stride,
108
- padding=padding,
109
- dilation=dilation,
110
- groups=groups,
111
- bias=False,
112
- )
113
- self.dropout = nn.Dropout(dropout_p)
114
- self.lora_up = nn.Conv2d(
115
- in_channels=r,
116
- out_channels=out_channels,
117
- kernel_size=1,
118
- stride=1,
119
- padding=0,
120
- bias=False,
121
- )
122
- self.selector = nn.Identity()
123
- self.scale = scale
124
-
125
- nn.init.normal_(self.lora_down.weight, std=1 / r)
126
- nn.init.zeros_(self.lora_up.weight)
127
-
128
- def forward(self, input):
129
- return (
130
- self.conv(input)
131
- + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
132
- * self.scale
133
- )
134
-
135
- def realize_as_lora(self):
136
- return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
137
-
138
- def set_selector_from_diag(self, diag: torch.Tensor):
139
- # diag is a 1D tensor of size (r,)
140
- assert diag.shape == (self.r,)
141
- self.selector = nn.Conv2d(
142
- in_channels=self.r,
143
- out_channels=self.r,
144
- kernel_size=1,
145
- stride=1,
146
- padding=0,
147
- bias=False,
148
- )
149
- self.selector.weight.data = torch.diag(diag)
150
-
151
- # same device + dtype as lora_up
152
- self.selector.weight.data = self.selector.weight.data.to(
153
- self.lora_up.weight.device
154
- ).to(self.lora_up.weight.dtype)
155
-
156
-
157
- UNET_DEFAULT_TARGET_REPLACE = {"MemoryEfficientCrossAttention","CrossAttention", "Attention", "GEGLU"}
158
-
159
- UNET_EXTENDED_TARGET_REPLACE = {"TimestepEmbedSequential","SpatialTemporalTransformer", "MemoryEfficientCrossAttention","CrossAttention", "Attention", "GEGLU"}
160
-
161
- TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"}
162
-
163
- TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPMLP","CLIPAttention"}
164
-
165
- DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE
166
-
167
- EMBED_FLAG = "<embed>"
168
-
169
-
170
- def _find_children(
171
- model,
172
- search_class: List[Type[nn.Module]] = [nn.Linear],
173
- ):
174
- """
175
- Find all modules of a certain class (or union of classes).
176
-
177
- Returns all matching modules, along with the parent of those moduless and the
178
- names they are referenced by.
179
- """
180
- # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
181
- for parent in model.modules():
182
- for name, module in parent.named_children():
183
- if any([isinstance(module, _class) for _class in search_class]):
184
- yield parent, name, module
185
-
186
-
187
- def _find_modules_v2(
188
- model,
189
- ancestor_class: Optional[Set[str]] = None,
190
- search_class: List[Type[nn.Module]] = [nn.Linear],
191
- exclude_children_of: Optional[List[Type[nn.Module]]] = [
192
- LoraInjectedLinear,
193
- LoraInjectedConv2d,
194
- ],
195
- ):
196
- """
197
- Find all modules of a certain class (or union of classes) that are direct or
198
- indirect descendants of other modules of a certain class (or union of classes).
199
-
200
- Returns all matching modules, along with the parent of those moduless and the
201
- names they are referenced by.
202
- """
203
-
204
- # Get the targets we should replace all linears under
205
- if type(ancestor_class) is not set:
206
- ancestor_class = set(ancestor_class)
207
- print(ancestor_class)
208
- if ancestor_class is not None:
209
- ancestors = (
210
- module
211
- for module in model.modules()
212
- if module.__class__.__name__ in ancestor_class
213
- )
214
- else:
215
- # this, incase you want to naively iterate over all modules.
216
- ancestors = [module for module in model.modules()]
217
-
218
- # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
219
- for ancestor in ancestors:
220
- for fullname, module in ancestor.named_children():
221
- if any([isinstance(module, _class) for _class in search_class]):
222
- # Find the direct parent if this is a descendant, not a child, of target
223
- *path, name = fullname.split(".")
224
- parent = ancestor
225
- while path:
226
- parent = parent.get_submodule(path.pop(0))
227
- # Skip this linear if it's a child of a LoraInjectedLinear
228
- if exclude_children_of and any(
229
- [isinstance(parent, _class) for _class in exclude_children_of]
230
- ):
231
- continue
232
- # Otherwise, yield it
233
- yield parent, name, module
234
-
235
-
236
- def _find_modules_old(
237
- model,
238
- ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE,
239
- search_class: List[Type[nn.Module]] = [nn.Linear],
240
- exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear],
241
- ):
242
- ret = []
243
- for _module in model.modules():
244
- if _module.__class__.__name__ in ancestor_class:
245
-
246
- for name, _child_module in _module.named_children():
247
- if _child_module.__class__ in search_class:
248
- ret.append((_module, name, _child_module))
249
- print(ret)
250
- return ret
251
-
252
-
253
- _find_modules = _find_modules_v2
254
-
255
-
256
- def inject_trainable_lora(
257
- model: nn.Module,
258
- target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
259
- r: int = 4,
260
- loras=None, # path to lora .pt
261
- verbose: bool = False,
262
- dropout_p: float = 0.0,
263
- scale: float = 1.0,
264
- ):
265
- """
266
- inject lora into model, and returns lora parameter groups.
267
- """
268
-
269
- require_grad_params = []
270
- names = []
271
-
272
- if loras != None:
273
- loras = torch.load(loras)
274
-
275
- for _module, name, _child_module in _find_modules(
276
- model, target_replace_module, search_class=[nn.Linear]
277
- ):
278
- weight = _child_module.weight
279
- bias = _child_module.bias
280
- if verbose:
281
- print("LoRA Injection : injecting lora into ", name)
282
- print("LoRA Injection : weight shape", weight.shape)
283
- _tmp = LoraInjectedLinear(
284
- _child_module.in_features,
285
- _child_module.out_features,
286
- _child_module.bias is not None,
287
- r=r,
288
- dropout_p=dropout_p,
289
- scale=scale,
290
- )
291
- _tmp.linear.weight = weight
292
- if bias is not None:
293
- _tmp.linear.bias = bias
294
-
295
- # switch the module
296
- _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
297
- _module._modules[name] = _tmp
298
-
299
- require_grad_params.append(_module._modules[name].lora_up.parameters())
300
- require_grad_params.append(_module._modules[name].lora_down.parameters())
301
-
302
- if loras != None:
303
- _module._modules[name].lora_up.weight = loras.pop(0)
304
- _module._modules[name].lora_down.weight = loras.pop(0)
305
-
306
- _module._modules[name].lora_up.weight.requires_grad = True
307
- _module._modules[name].lora_down.weight.requires_grad = True
308
- names.append(name)
309
-
310
- return require_grad_params, names
311
-
312
-
313
- def inject_trainable_lora_extended(
314
- model: nn.Module,
315
- target_replace_module: Set[str] = UNET_EXTENDED_TARGET_REPLACE,
316
- r: int = 4,
317
- loras=None, # path to lora .pt
318
- ):
319
- """
320
- inject lora into model, and returns lora parameter groups.
321
- """
322
-
323
- require_grad_params = []
324
- names = []
325
-
326
- if loras != None:
327
- loras = torch.load(loras)
328
-
329
- for _module, name, _child_module in _find_modules(
330
- model, target_replace_module, search_class=[nn.Linear, nn.Conv2d]
331
- ):
332
- if _child_module.__class__ == nn.Linear:
333
- weight = _child_module.weight
334
- bias = _child_module.bias
335
- _tmp = LoraInjectedLinear(
336
- _child_module.in_features,
337
- _child_module.out_features,
338
- _child_module.bias is not None,
339
- r=r,
340
- )
341
- _tmp.linear.weight = weight
342
- if bias is not None:
343
- _tmp.linear.bias = bias
344
- elif _child_module.__class__ == nn.Conv2d:
345
- weight = _child_module.weight
346
- bias = _child_module.bias
347
- _tmp = LoraInjectedConv2d(
348
- _child_module.in_channels,
349
- _child_module.out_channels,
350
- _child_module.kernel_size,
351
- _child_module.stride,
352
- _child_module.padding,
353
- _child_module.dilation,
354
- _child_module.groups,
355
- _child_module.bias is not None,
356
- r=r,
357
- )
358
-
359
- _tmp.conv.weight = weight
360
- if bias is not None:
361
- _tmp.conv.bias = bias
362
-
363
- # switch the module
364
- _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
365
- if bias is not None:
366
- _tmp.to(_child_module.bias.device).to(_child_module.bias.dtype)
367
-
368
- _module._modules[name] = _tmp
369
-
370
- require_grad_params.append(_module._modules[name].lora_up.parameters())
371
- require_grad_params.append(_module._modules[name].lora_down.parameters())
372
-
373
- if loras != None:
374
- _module._modules[name].lora_up.weight = loras.pop(0)
375
- _module._modules[name].lora_down.weight = loras.pop(0)
376
-
377
- _module._modules[name].lora_up.weight.requires_grad = True
378
- _module._modules[name].lora_down.weight.requires_grad = True
379
- names.append(name)
380
-
381
- return require_grad_params, names
382
-
383
-
384
- def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
385
-
386
- loras = []
387
-
388
- for _m, _n, _child_module in _find_modules(
389
- model,
390
- target_replace_module,
391
- search_class=[LoraInjectedLinear, LoraInjectedConv2d],
392
- ):
393
- loras.append((_child_module.lora_up, _child_module.lora_down))
394
-
395
- if len(loras) == 0:
396
- raise ValueError("No lora injected.")
397
-
398
- return loras
399
-
400
-
401
- def extract_lora_as_tensor(
402
- model, target_replace_module=DEFAULT_TARGET_REPLACE, as_fp16=True
403
- ):
404
-
405
- loras = []
406
-
407
- for _m, _n, _child_module in _find_modules(
408
- model,
409
- target_replace_module,
410
- search_class=[LoraInjectedLinear, LoraInjectedConv2d],
411
- ):
412
- up, down = _child_module.realize_as_lora()
413
- if as_fp16:
414
- up = up.to(torch.float16)
415
- down = down.to(torch.float16)
416
-
417
- loras.append((up, down))
418
-
419
- if len(loras) == 0:
420
- raise ValueError("No lora injected.")
421
-
422
- return loras
423
-
424
-
425
- def save_lora_weight(
426
- model,
427
- path="./lora.pt",
428
- target_replace_module=DEFAULT_TARGET_REPLACE,
429
- ):
430
- weights = []
431
- for _up, _down in extract_lora_ups_down(
432
- model, target_replace_module=target_replace_module
433
- ):
434
- weights.append(_up.weight.to("cpu").to(torch.float16))
435
- weights.append(_down.weight.to("cpu").to(torch.float16))
436
-
437
- torch.save(weights, path)
438
-
439
-
440
- def save_lora_as_json(model, path="./lora.json"):
441
- weights = []
442
- for _up, _down in extract_lora_ups_down(model):
443
- weights.append(_up.weight.detach().cpu().numpy().tolist())
444
- weights.append(_down.weight.detach().cpu().numpy().tolist())
445
-
446
- import json
447
-
448
- with open(path, "w") as f:
449
- json.dump(weights, f)
450
-
451
-
452
- def save_safeloras_with_embeds(
453
- modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
454
- embeds: Dict[str, torch.Tensor] = {},
455
- outpath="./lora.safetensors",
456
- ):
457
- """
458
- Saves the Lora from multiple modules in a single safetensor file.
459
-
460
- modelmap is a dictionary of {
461
- "module name": (module, target_replace_module)
462
- }
463
- """
464
- weights = {}
465
- metadata = {}
466
-
467
- for name, (model, target_replace_module) in modelmap.items():
468
- metadata[name] = json.dumps(list(target_replace_module))
469
-
470
- for i, (_up, _down) in enumerate(
471
- extract_lora_as_tensor(model, target_replace_module)
472
- ):
473
- rank = _down.shape[0]
474
-
475
- metadata[f"{name}:{i}:rank"] = str(rank)
476
- weights[f"{name}:{i}:up"] = _up
477
- weights[f"{name}:{i}:down"] = _down
478
-
479
- for token, tensor in embeds.items():
480
- metadata[token] = EMBED_FLAG
481
- weights[token] = tensor
482
-
483
- print(f"Saving weights to {outpath}")
484
- safe_save(weights, outpath, metadata)
485
-
486
-
487
- def save_safeloras(
488
- modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
489
- outpath="./lora.safetensors",
490
- ):
491
- return save_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
492
-
493
-
494
- def convert_loras_to_safeloras_with_embeds(
495
- modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
496
- embeds: Dict[str, torch.Tensor] = {},
497
- outpath="./lora.safetensors",
498
- ):
499
- """
500
- Converts the Lora from multiple pytorch .pt files into a single safetensor file.
501
-
502
- modelmap is a dictionary of {
503
- "module name": (pytorch_model_path, target_replace_module, rank)
504
- }
505
- """
506
-
507
- weights = {}
508
- metadata = {}
509
-
510
- for name, (path, target_replace_module, r) in modelmap.items():
511
- metadata[name] = json.dumps(list(target_replace_module))
512
-
513
- lora = torch.load(path)
514
- for i, weight in enumerate(lora):
515
- is_up = i % 2 == 0
516
- i = i // 2
517
-
518
- if is_up:
519
- metadata[f"{name}:{i}:rank"] = str(r)
520
- weights[f"{name}:{i}:up"] = weight
521
- else:
522
- weights[f"{name}:{i}:down"] = weight
523
-
524
- for token, tensor in embeds.items():
525
- metadata[token] = EMBED_FLAG
526
- weights[token] = tensor
527
-
528
- print(f"Saving weights to {outpath}")
529
- safe_save(weights, outpath, metadata)
530
-
531
-
532
- def convert_loras_to_safeloras(
533
- modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
534
- outpath="./lora.safetensors",
535
- ):
536
- convert_loras_to_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
537
-
538
-
539
- def parse_safeloras(
540
- safeloras,
541
- ) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]:
542
- """
543
- Converts a loaded safetensor file that contains a set of module Loras
544
- into Parameters and other information
545
-
546
- Output is a dictionary of {
547
- "module name": (
548
- [list of weights],
549
- [list of ranks],
550
- target_replacement_modules
551
- )
552
- }
553
- """
554
- loras = {}
555
- metadata = safeloras.metadata()
556
-
557
- get_name = lambda k: k.split(":")[0]
558
-
559
- keys = list(safeloras.keys())
560
- keys.sort(key=get_name)
561
-
562
- for name, module_keys in groupby(keys, get_name):
563
- info = metadata.get(name)
564
-
565
- if not info:
566
- raise ValueError(
567
- f"Tensor {name} has no metadata - is this a Lora safetensor?"
568
- )
569
-
570
- # Skip Textual Inversion embeds
571
- if info == EMBED_FLAG:
572
- continue
573
-
574
- # Handle Loras
575
- # Extract the targets
576
- target = json.loads(info)
577
-
578
- # Build the result lists - Python needs us to preallocate lists to insert into them
579
- module_keys = list(module_keys)
580
- ranks = [4] * (len(module_keys) // 2)
581
- weights = [None] * len(module_keys)
582
-
583
- for key in module_keys:
584
- # Split the model name and index out of the key
585
- _, idx, direction = key.split(":")
586
- idx = int(idx)
587
-
588
- # Add the rank
589
- ranks[idx] = int(metadata[f"{name}:{idx}:rank"])
590
-
591
- # Insert the weight into the list
592
- idx = idx * 2 + (1 if direction == "down" else 0)
593
- weights[idx] = nn.parameter.Parameter(safeloras.get_tensor(key))
594
-
595
- loras[name] = (weights, ranks, target)
596
-
597
- return loras
598
-
599
-
600
- def parse_safeloras_embeds(
601
- safeloras,
602
- ) -> Dict[str, torch.Tensor]:
603
- """
604
- Converts a loaded safetensor file that contains Textual Inversion embeds into
605
- a dictionary of embed_token: Tensor
606
- """
607
- embeds = {}
608
- metadata = safeloras.metadata()
609
-
610
- for key in safeloras.keys():
611
- # Only handle Textual Inversion embeds
612
- meta = metadata.get(key)
613
- if not meta or meta != EMBED_FLAG:
614
- continue
615
-
616
- embeds[key] = safeloras.get_tensor(key)
617
-
618
- return embeds
619
-
620
- def net_load_lora(net, checkpoint_path, alpha=1.0, remove=False):
621
- visited=[]
622
- state_dict = torch.load(checkpoint_path)
623
- for k, v in state_dict.items():
624
- state_dict[k] = v.to(net.device)
625
-
626
- for key in state_dict:
627
- if ".alpha" in key or key in visited:
628
- continue
629
- layer_infos = key.split(".")[:-2] # remove lora_up and down weight
630
- curr_layer = net
631
- # find the target layer
632
- temp_name = layer_infos.pop(0)
633
- while len(layer_infos) > -1:
634
- curr_layer = curr_layer.__getattr__(temp_name)
635
- if len(layer_infos) > 0:
636
- temp_name = layer_infos.pop(0)
637
- elif len(layer_infos) == 0:
638
- break
639
- if curr_layer.__class__ not in [nn.Linear, nn.Conv2d]:
640
- print('missing param at:', key)
641
- continue
642
- pair_keys = []
643
- if "lora_down" in key:
644
- pair_keys.append(key.replace("lora_down", "lora_up"))
645
- pair_keys.append(key)
646
- else:
647
- pair_keys.append(key)
648
- pair_keys.append(key.replace("lora_up", "lora_down"))
649
-
650
- # update weight
651
- if len(state_dict[pair_keys[0]].shape) == 4:
652
- # for conv
653
- weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
654
- weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
655
- if remove:
656
- curr_layer.weight.data -= alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
657
- else:
658
- curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
659
- else:
660
- # for linear
661
- weight_up = state_dict[pair_keys[0]].to(torch.float32)
662
- weight_down = state_dict[pair_keys[1]].to(torch.float32)
663
- if remove:
664
- curr_layer.weight.data -= alpha * torch.mm(weight_up, weight_down)
665
- else:
666
- curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
667
-
668
- # update visited list
669
- for item in pair_keys:
670
- visited.append(item)
671
- print('load_weight_num:',len(visited))
672
- return
673
-
674
- def change_lora(model, inject_lora=False, lora_scale=1.0, lora_path='', last_time_lora='', last_time_lora_scale=1.0):
675
- # remove lora
676
- if last_time_lora != '':
677
- net_load_lora(model, last_time_lora, alpha=last_time_lora_scale, remove=True)
678
- # add new lora
679
- if inject_lora:
680
- net_load_lora(model, lora_path, alpha=lora_scale)
681
-
682
-
683
- def net_load_lora_v2(net, checkpoint_path, alpha=1.0, remove=False, origin_weight=None):
684
- visited=[]
685
- state_dict = torch.load(checkpoint_path)
686
- for k, v in state_dict.items():
687
- state_dict[k] = v.to(net.device)
688
-
689
- for key in state_dict:
690
- if ".alpha" in key or key in visited:
691
- continue
692
- layer_infos = key.split(".")[:-2] # remove lora_up and down weight
693
- curr_layer = net
694
- # find the target layer
695
- temp_name = layer_infos.pop(0)
696
- while len(layer_infos) > -1:
697
- curr_layer = curr_layer.__getattr__(temp_name)
698
- if len(layer_infos) > 0:
699
- temp_name = layer_infos.pop(0)
700
- elif len(layer_infos) == 0:
701
- break
702
- if curr_layer.__class__ not in [nn.Linear, nn.Conv2d]:
703
- print('missing param at:', key)
704
- continue
705
- pair_keys = []
706
- if "lora_down" in key:
707
- pair_keys.append(key.replace("lora_down", "lora_up"))
708
- pair_keys.append(key)
709
- else:
710
- pair_keys.append(key)
711
- pair_keys.append(key.replace("lora_up", "lora_down"))
712
-
713
- # storage weight
714
- if origin_weight is None:
715
- origin_weight = dict()
716
- storage_key = key.replace("lora_down", "lora").replace("lora_up", "lora")
717
- origin_weight[storage_key] = curr_layer.weight.data.clone()
718
- else:
719
- storage_key = key.replace("lora_down", "lora").replace("lora_up", "lora")
720
- if storage_key not in origin_weight.keys():
721
- origin_weight[storage_key] = curr_layer.weight.data.clone()
722
-
723
-
724
- # update
725
- if len(state_dict[pair_keys[0]].shape) == 4:
726
- # for conv
727
- if remove:
728
- curr_layer.weight.data = origin_weight[storage_key].clone()
729
- else:
730
- weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
731
- weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
732
- curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
733
- else:
734
- # for linear
735
- if remove:
736
- curr_layer.weight.data = origin_weight[storage_key].clone()
737
- else:
738
- weight_up = state_dict[pair_keys[0]].to(torch.float32)
739
- weight_down = state_dict[pair_keys[1]].to(torch.float32)
740
- curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
741
-
742
- # update visited list
743
- for item in pair_keys:
744
- visited.append(item)
745
- print('load_weight_num:',len(visited))
746
- return origin_weight
747
-
748
- def change_lora_v2(model, inject_lora=False, lora_scale=1.0, lora_path='', last_time_lora='', last_time_lora_scale=1.0, origin_weight=None):
749
- # remove lora
750
- if last_time_lora != '':
751
- origin_weight = net_load_lora_v2(model, last_time_lora, alpha=last_time_lora_scale, remove=True, origin_weight=origin_weight)
752
- # add new lora
753
- if inject_lora:
754
- origin_weight = net_load_lora_v2(model, lora_path, alpha=lora_scale, origin_weight=origin_weight)
755
- return origin_weight
756
-
757
-
758
-
759
-
760
-
761
- def load_safeloras(path, device="cpu"):
762
- safeloras = safe_open(path, framework="pt", device=device)
763
- return parse_safeloras(safeloras)
764
-
765
-
766
- def load_safeloras_embeds(path, device="cpu"):
767
- safeloras = safe_open(path, framework="pt", device=device)
768
- return parse_safeloras_embeds(safeloras)
769
-
770
-
771
- def load_safeloras_both(path, device="cpu"):
772
- safeloras = safe_open(path, framework="pt", device=device)
773
- return parse_safeloras(safeloras), parse_safeloras_embeds(safeloras)
774
-
775
-
776
- def collapse_lora(model, alpha=1.0):
777
-
778
- for _module, name, _child_module in _find_modules(
779
- model,
780
- UNET_EXTENDED_TARGET_REPLACE | TEXT_ENCODER_EXTENDED_TARGET_REPLACE,
781
- search_class=[LoraInjectedLinear, LoraInjectedConv2d],
782
- ):
783
-
784
- if isinstance(_child_module, LoraInjectedLinear):
785
- print("Collapsing Lin Lora in", name)
786
-
787
- _child_module.linear.weight = nn.Parameter(
788
- _child_module.linear.weight.data
789
- + alpha
790
- * (
791
- _child_module.lora_up.weight.data
792
- @ _child_module.lora_down.weight.data
793
- )
794
- .type(_child_module.linear.weight.dtype)
795
- .to(_child_module.linear.weight.device)
796
- )
797
-
798
- else:
799
- print("Collapsing Conv Lora in", name)
800
- _child_module.conv.weight = nn.Parameter(
801
- _child_module.conv.weight.data
802
- + alpha
803
- * (
804
- _child_module.lora_up.weight.data.flatten(start_dim=1)
805
- @ _child_module.lora_down.weight.data.flatten(start_dim=1)
806
- )
807
- .reshape(_child_module.conv.weight.data.shape)
808
- .type(_child_module.conv.weight.dtype)
809
- .to(_child_module.conv.weight.device)
810
- )
811
-
812
-
813
- def monkeypatch_or_replace_lora(
814
- model,
815
- loras,
816
- target_replace_module=DEFAULT_TARGET_REPLACE,
817
- r: Union[int, List[int]] = 4,
818
- ):
819
- for _module, name, _child_module in _find_modules(
820
- model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear]
821
- ):
822
- _source = (
823
- _child_module.linear
824
- if isinstance(_child_module, LoraInjectedLinear)
825
- else _child_module
826
- )
827
-
828
- weight = _source.weight
829
- bias = _source.bias
830
- _tmp = LoraInjectedLinear(
831
- _source.in_features,
832
- _source.out_features,
833
- _source.bias is not None,
834
- r=r.pop(0) if isinstance(r, list) else r,
835
- )
836
- _tmp.linear.weight = weight
837
-
838
- if bias is not None:
839
- _tmp.linear.bias = bias
840
-
841
- # switch the module
842
- _module._modules[name] = _tmp
843
-
844
- up_weight = loras.pop(0)
845
- down_weight = loras.pop(0)
846
-
847
- _module._modules[name].lora_up.weight = nn.Parameter(
848
- up_weight.type(weight.dtype)
849
- )
850
- _module._modules[name].lora_down.weight = nn.Parameter(
851
- down_weight.type(weight.dtype)
852
- )
853
-
854
- _module._modules[name].to(weight.device)
855
-
856
-
857
- def monkeypatch_or_replace_lora_extended(
858
- model,
859
- loras,
860
- target_replace_module=DEFAULT_TARGET_REPLACE,
861
- r: Union[int, List[int]] = 4,
862
- ):
863
- for _module, name, _child_module in _find_modules(
864
- model,
865
- target_replace_module,
866
- search_class=[nn.Linear, LoraInjectedLinear, nn.Conv2d, LoraInjectedConv2d],
867
- ):
868
-
869
- if (_child_module.__class__ == nn.Linear) or (
870
- _child_module.__class__ == LoraInjectedLinear
871
- ):
872
- if len(loras[0].shape) != 2:
873
- continue
874
-
875
- _source = (
876
- _child_module.linear
877
- if isinstance(_child_module, LoraInjectedLinear)
878
- else _child_module
879
- )
880
-
881
- weight = _source.weight
882
- bias = _source.bias
883
- _tmp = LoraInjectedLinear(
884
- _source.in_features,
885
- _source.out_features,
886
- _source.bias is not None,
887
- r=r.pop(0) if isinstance(r, list) else r,
888
- )
889
- _tmp.linear.weight = weight
890
-
891
- if bias is not None:
892
- _tmp.linear.bias = bias
893
-
894
- elif (_child_module.__class__ == nn.Conv2d) or (
895
- _child_module.__class__ == LoraInjectedConv2d
896
- ):
897
- if len(loras[0].shape) != 4:
898
- continue
899
- _source = (
900
- _child_module.conv
901
- if isinstance(_child_module, LoraInjectedConv2d)
902
- else _child_module
903
- )
904
-
905
- weight = _source.weight
906
- bias = _source.bias
907
- _tmp = LoraInjectedConv2d(
908
- _source.in_channels,
909
- _source.out_channels,
910
- _source.kernel_size,
911
- _source.stride,
912
- _source.padding,
913
- _source.dilation,
914
- _source.groups,
915
- _source.bias is not None,
916
- r=r.pop(0) if isinstance(r, list) else r,
917
- )
918
-
919
- _tmp.conv.weight = weight
920
-
921
- if bias is not None:
922
- _tmp.conv.bias = bias
923
-
924
- # switch the module
925
- _module._modules[name] = _tmp
926
-
927
- up_weight = loras.pop(0)
928
- down_weight = loras.pop(0)
929
-
930
- _module._modules[name].lora_up.weight = nn.Parameter(
931
- up_weight.type(weight.dtype)
932
- )
933
- _module._modules[name].lora_down.weight = nn.Parameter(
934
- down_weight.type(weight.dtype)
935
- )
936
-
937
- _module._modules[name].to(weight.device)
938
-
939
-
940
- def monkeypatch_or_replace_safeloras(models, safeloras):
941
- loras = parse_safeloras(safeloras)
942
-
943
- for name, (lora, ranks, target) in loras.items():
944
- model = getattr(models, name, None)
945
-
946
- if not model:
947
- print(f"No model provided for {name}, contained in Lora")
948
- continue
949
-
950
- monkeypatch_or_replace_lora_extended(model, lora, target, ranks)
951
-
952
-
953
- def monkeypatch_remove_lora(model):
954
- for _module, name, _child_module in _find_modules(
955
- model, search_class=[LoraInjectedLinear, LoraInjectedConv2d]
956
- ):
957
- if isinstance(_child_module, LoraInjectedLinear):
958
- _source = _child_module.linear
959
- weight, bias = _source.weight, _source.bias
960
-
961
- _tmp = nn.Linear(
962
- _source.in_features, _source.out_features, bias is not None
963
- )
964
-
965
- _tmp.weight = weight
966
- if bias is not None:
967
- _tmp.bias = bias
968
-
969
- else:
970
- _source = _child_module.conv
971
- weight, bias = _source.weight, _source.bias
972
-
973
- _tmp = nn.Conv2d(
974
- in_channels=_source.in_channels,
975
- out_channels=_source.out_channels,
976
- kernel_size=_source.kernel_size,
977
- stride=_source.stride,
978
- padding=_source.padding,
979
- dilation=_source.dilation,
980
- groups=_source.groups,
981
- bias=bias is not None,
982
- )
983
-
984
- _tmp.weight = weight
985
- if bias is not None:
986
- _tmp.bias = bias
987
-
988
- _module._modules[name] = _tmp
989
-
990
-
991
- def monkeypatch_add_lora(
992
- model,
993
- loras,
994
- target_replace_module=DEFAULT_TARGET_REPLACE,
995
- alpha: float = 1.0,
996
- beta: float = 1.0,
997
- ):
998
- for _module, name, _child_module in _find_modules(
999
- model, target_replace_module, search_class=[LoraInjectedLinear]
1000
- ):
1001
- weight = _child_module.linear.weight
1002
-
1003
- up_weight = loras.pop(0)
1004
- down_weight = loras.pop(0)
1005
-
1006
- _module._modules[name].lora_up.weight = nn.Parameter(
1007
- up_weight.type(weight.dtype).to(weight.device) * alpha
1008
- + _module._modules[name].lora_up.weight.to(weight.device) * beta
1009
- )
1010
- _module._modules[name].lora_down.weight = nn.Parameter(
1011
- down_weight.type(weight.dtype).to(weight.device) * alpha
1012
- + _module._modules[name].lora_down.weight.to(weight.device) * beta
1013
- )
1014
-
1015
- _module._modules[name].to(weight.device)
1016
-
1017
-
1018
- def tune_lora_scale(model, alpha: float = 1.0):
1019
- for _module in model.modules():
1020
- if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]:
1021
- _module.scale = alpha
1022
-
1023
-
1024
- def set_lora_diag(model, diag: torch.Tensor):
1025
- for _module in model.modules():
1026
- if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]:
1027
- _module.set_selector_from_diag(diag)
1028
-
1029
-
1030
- def _text_lora_path(path: str) -> str:
1031
- assert path.endswith(".pt"), "Only .pt files are supported"
1032
- return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
1033
-
1034
-
1035
- def _ti_lora_path(path: str) -> str:
1036
- assert path.endswith(".pt"), "Only .pt files are supported"
1037
- return ".".join(path.split(".")[:-1] + ["ti", "pt"])
1038
-
1039
-
1040
- def apply_learned_embed_in_clip(
1041
- learned_embeds,
1042
- text_encoder,
1043
- tokenizer,
1044
- token: Optional[Union[str, List[str]]] = None,
1045
- idempotent=False,
1046
- ):
1047
- if isinstance(token, str):
1048
- trained_tokens = [token]
1049
- elif isinstance(token, list):
1050
- assert len(learned_embeds.keys()) == len(
1051
- token
1052
- ), "The number of tokens and the number of embeds should be the same"
1053
- trained_tokens = token
1054
- else:
1055
- trained_tokens = list(learned_embeds.keys())
1056
-
1057
- for token in trained_tokens:
1058
- print(token)
1059
- embeds = learned_embeds[token]
1060
-
1061
- # cast to dtype of text_encoder
1062
- dtype = text_encoder.get_input_embeddings().weight.dtype
1063
- num_added_tokens = tokenizer.add_tokens(token)
1064
-
1065
- i = 1
1066
- if not idempotent:
1067
- while num_added_tokens == 0:
1068
- print(f"The tokenizer already contains the token {token}.")
1069
- token = f"{token[:-1]}-{i}>"
1070
- print(f"Attempting to add the token {token}.")
1071
- num_added_tokens = tokenizer.add_tokens(token)
1072
- i += 1
1073
- elif num_added_tokens == 0 and idempotent:
1074
- print(f"The tokenizer already contains the token {token}.")
1075
- print(f"Replacing {token} embedding.")
1076
-
1077
- # resize the token embeddings
1078
- text_encoder.resize_token_embeddings(len(tokenizer))
1079
-
1080
- # get the id for the token and assign the embeds
1081
- token_id = tokenizer.convert_tokens_to_ids(token)
1082
- text_encoder.get_input_embeddings().weight.data[token_id] = embeds
1083
- return token
1084
-
1085
-
1086
- def load_learned_embed_in_clip(
1087
- learned_embeds_path,
1088
- text_encoder,
1089
- tokenizer,
1090
- token: Optional[Union[str, List[str]]] = None,
1091
- idempotent=False,
1092
- ):
1093
- learned_embeds = torch.load(learned_embeds_path)
1094
- apply_learned_embed_in_clip(
1095
- learned_embeds, text_encoder, tokenizer, token, idempotent
1096
- )
1097
-
1098
-
1099
- def patch_pipe(
1100
- pipe,
1101
- maybe_unet_path,
1102
- token: Optional[str] = None,
1103
- r: int = 4,
1104
- patch_unet=True,
1105
- patch_text=True,
1106
- patch_ti=True,
1107
- idempotent_token=True,
1108
- unet_target_replace_module=DEFAULT_TARGET_REPLACE,
1109
- text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
1110
- ):
1111
- if maybe_unet_path.endswith(".pt"):
1112
- # torch format
1113
-
1114
- if maybe_unet_path.endswith(".ti.pt"):
1115
- unet_path = maybe_unet_path[:-6] + ".pt"
1116
- elif maybe_unet_path.endswith(".text_encoder.pt"):
1117
- unet_path = maybe_unet_path[:-16] + ".pt"
1118
- else:
1119
- unet_path = maybe_unet_path
1120
-
1121
- ti_path = _ti_lora_path(unet_path)
1122
- text_path = _text_lora_path(unet_path)
1123
-
1124
- if patch_unet:
1125
- print("LoRA : Patching Unet")
1126
- monkeypatch_or_replace_lora(
1127
- pipe.unet,
1128
- torch.load(unet_path),
1129
- r=r,
1130
- target_replace_module=unet_target_replace_module,
1131
- )
1132
-
1133
- if patch_text:
1134
- print("LoRA : Patching text encoder")
1135
- monkeypatch_or_replace_lora(
1136
- pipe.text_encoder,
1137
- torch.load(text_path),
1138
- target_replace_module=text_target_replace_module,
1139
- r=r,
1140
- )
1141
- if patch_ti:
1142
- print("LoRA : Patching token input")
1143
- token = load_learned_embed_in_clip(
1144
- ti_path,
1145
- pipe.text_encoder,
1146
- pipe.tokenizer,
1147
- token=token,
1148
- idempotent=idempotent_token,
1149
- )
1150
-
1151
- elif maybe_unet_path.endswith(".safetensors"):
1152
- safeloras = safe_open(maybe_unet_path, framework="pt", device="cpu")
1153
- monkeypatch_or_replace_safeloras(pipe, safeloras)
1154
- tok_dict = parse_safeloras_embeds(safeloras)
1155
- if patch_ti:
1156
- apply_learned_embed_in_clip(
1157
- tok_dict,
1158
- pipe.text_encoder,
1159
- pipe.tokenizer,
1160
- token=token,
1161
- idempotent=idempotent_token,
1162
- )
1163
- return tok_dict
1164
-
1165
-
1166
- @torch.no_grad()
1167
- def inspect_lora(model):
1168
- moved = {}
1169
-
1170
- for name, _module in model.named_modules():
1171
- if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]:
1172
- ups = _module.lora_up.weight.data.clone()
1173
- downs = _module.lora_down.weight.data.clone()
1174
-
1175
- wght: torch.Tensor = ups.flatten(1) @ downs.flatten(1)
1176
-
1177
- dist = wght.flatten().abs().mean().item()
1178
- if name in moved:
1179
- moved[name].append(dist)
1180
- else:
1181
- moved[name] = [dist]
1182
-
1183
- return moved
1184
-
1185
-
1186
- def save_all(
1187
- unet,
1188
- text_encoder,
1189
- save_path,
1190
- placeholder_token_ids=None,
1191
- placeholder_tokens=None,
1192
- save_lora=True,
1193
- save_ti=True,
1194
- target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
1195
- target_replace_module_unet=DEFAULT_TARGET_REPLACE,
1196
- safe_form=True,
1197
- ):
1198
- if not safe_form:
1199
- # save ti
1200
- if save_ti:
1201
- ti_path = _ti_lora_path(save_path)
1202
- learned_embeds_dict = {}
1203
- for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
1204
- learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
1205
- print(
1206
- f"Current Learned Embeddings for {tok}:, id {tok_id} ",
1207
- learned_embeds[:4],
1208
- )
1209
- learned_embeds_dict[tok] = learned_embeds.detach().cpu()
1210
-
1211
- torch.save(learned_embeds_dict, ti_path)
1212
- print("Ti saved to ", ti_path)
1213
-
1214
- # save text encoder
1215
- if save_lora:
1216
-
1217
- save_lora_weight(
1218
- unet, save_path, target_replace_module=target_replace_module_unet
1219
- )
1220
- print("Unet saved to ", save_path)
1221
-
1222
- save_lora_weight(
1223
- text_encoder,
1224
- _text_lora_path(save_path),
1225
- target_replace_module=target_replace_module_text,
1226
- )
1227
- print("Text Encoder saved to ", _text_lora_path(save_path))
1228
-
1229
- else:
1230
- assert save_path.endswith(
1231
- ".safetensors"
1232
- ), f"Save path : {save_path} should end with .safetensors"
1233
-
1234
- loras = {}
1235
- embeds = {}
1236
-
1237
- if save_lora:
1238
-
1239
- loras["unet"] = (unet, target_replace_module_unet)
1240
- loras["text_encoder"] = (text_encoder, target_replace_module_text)
1241
-
1242
- if save_ti:
1243
- for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
1244
- learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
1245
- print(
1246
- f"Current Learned Embeddings for {tok}:, id {tok_id} ",
1247
- learned_embeds[:4],
1248
- )
1249
- embeds[tok] = learned_embeds.detach().cpu()
1250
-
1251
- save_safeloras_with_embeds(loras, embeds, save_path)
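The removed lvdm/models/modules/lora.py bundled the LoRA tooling in one place: module injection, safetensors (de)serialization, weight collapsing, and textual-inversion embedding patching. As a standalone illustration of the merge step that collapse_lora performs above, the sketch below folds a low-rank update W += alpha * (up @ down) into a plain nn.Linear; fold_lora_into_linear and the toy tensors are hypothetical and are not part of the repository.

import torch
import torch.nn as nn

def fold_lora_into_linear(linear: nn.Linear, lora_up: torch.Tensor,
                          lora_down: torch.Tensor, alpha: float = 1.0) -> nn.Linear:
    # Fold the low-rank update into the base weight, mirroring what
    # collapse_lora does for LoraInjectedLinear: W <- W + alpha * (up @ down).
    update = (lora_up @ lora_down).to(dtype=linear.weight.dtype,
                                      device=linear.weight.device)
    linear.weight = nn.Parameter(linear.weight.data + alpha * update)
    return linear

# Rank-4 update on a 320 -> 320 projection, with toy tensors:
base = nn.Linear(320, 320, bias=False)
up, down = torch.randn(320, 4) * 0.01, torch.randn(4, 320) * 0.01
fold_lora_into_linear(base, up, down, alpha=0.8)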
lvdm/{samplers → models/samplers}/ddim.py RENAMED
@@ -1,10 +1,8 @@
1
- """SAMPLING ONLY."""
2
-
3
- import torch
4
  import numpy as np
5
  from tqdm import tqdm
6
-
7
- from lvdm.models.modules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
 
8
 
9
 
10
  class DDIMSampler(object):
@@ -31,6 +29,15 @@ class DDIMSampler(object):
31
  self.register_buffer('betas', to_torch(self.model.betas))
32
  self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
33
  self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
 
 
 
 
 
 
 
 
 
34
 
35
  # calculations for diffusion q(x_t | x_{t-1}) and others
36
  self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
@@ -59,6 +66,7 @@ class DDIMSampler(object):
59
  shape,
60
  conditioning=None,
61
  callback=None,
 
62
  img_callback=None,
63
  quantize_x0=False,
64
  eta=0.,
@@ -74,9 +82,6 @@ class DDIMSampler(object):
74
  log_every_t=100,
75
  unconditional_guidance_scale=1.,
76
  unconditional_conditioning=None,
77
- postprocess_fn=None,
78
- sample_noise=None,
79
- cond_fn=None,
80
  # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
81
  **kwargs
82
  ):
@@ -86,11 +91,11 @@ class DDIMSampler(object):
86
  if isinstance(conditioning, dict):
87
  try:
88
  cbs = conditioning[list(conditioning.keys())[0]].shape[0]
89
- if cbs != batch_size:
90
- print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
91
  except:
92
- # cbs = conditioning[list(conditioning.keys())[0]][0].shape[0]
93
- pass
 
 
94
  else:
95
  if conditioning.shape[0] != batch_size:
96
  print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
@@ -104,6 +109,7 @@ class DDIMSampler(object):
104
  elif len(shape) == 4:
105
  C, T, H, W = shape
106
  size = (batch_size, C, T, H, W)
 
107
 
108
  samples, intermediates = self.ddim_sampling(conditioning, size,
109
  callback=callback,
@@ -119,12 +125,8 @@ class DDIMSampler(object):
119
  log_every_t=log_every_t,
120
  unconditional_guidance_scale=unconditional_guidance_scale,
121
  unconditional_conditioning=unconditional_conditioning,
122
- postprocess_fn=postprocess_fn,
123
- sample_noise=sample_noise,
124
- cond_fn=cond_fn,
125
  verbose=verbose,
126
- **kwargs
127
- )
128
  return samples, intermediates
129
 
130
  @torch.no_grad()
@@ -133,13 +135,11 @@ class DDIMSampler(object):
133
  callback=None, timesteps=None, quantize_denoised=False,
134
  mask=None, x0=None, img_callback=None, log_every_t=100,
135
  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
136
- unconditional_guidance_scale=1., unconditional_conditioning=None,
137
- postprocess_fn=None,sample_noise=None,cond_fn=None,
138
- uc_type=None, verbose=True, **kwargs,
139
- ):
140
-
141
- device = self.model.betas.device
142
-
143
  b = shape[0]
144
  if x_T is None:
145
  img = torch.randn(shape, device=device)
@@ -151,6 +151,7 @@ class DDIMSampler(object):
151
  elif timesteps is not None and not ddim_use_original_steps:
152
  subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
153
  timesteps = self.ddim_timesteps[:subset_end]
 
154
  intermediates = {'x_inter': [img], 'pred_x0': [img]}
155
  time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
156
  total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
@@ -159,31 +160,46 @@ class DDIMSampler(object):
159
  else:
160
  iterator = time_range
161
 
 
 
162
  for i, step in enumerate(iterator):
163
  index = total_steps - i - 1
164
  ts = torch.full((b,), step, device=device, dtype=torch.long)
 
 
 
 
 
 
 
165
 
166
- if postprocess_fn is not None:
167
- img = postprocess_fn(img, ts)
 
 
 
 
 
 
168
 
 
 
 
 
 
 
 
 
169
  outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
170
  quantize_denoised=quantize_denoised, temperature=temperature,
171
  noise_dropout=noise_dropout, score_corrector=score_corrector,
172
  corrector_kwargs=corrector_kwargs,
173
  unconditional_guidance_scale=unconditional_guidance_scale,
174
  unconditional_conditioning=unconditional_conditioning,
175
- sample_noise=sample_noise,cond_fn=cond_fn,uc_type=uc_type, **kwargs,)
176
- img, pred_x0 = outs
177
-
178
- if mask is not None:
179
- # use mask to blend x_known_t-1 & x_sample_t-1
180
- assert x0 is not None
181
- x0 = x0.to(img.device)
182
- mask = mask.to(img.device)
183
- t = torch.tensor([step-1]*x0.shape[0], dtype=torch.long, device=img.device)
184
- img_known = self.model.q_sample(x0, t)
185
- img = img_known * mask + (1. - mask) * img
186
 
 
187
  if callback: callback(i)
188
  if img_callback: img_callback(pred_x0, i)
189
 
@@ -196,10 +212,8 @@ class DDIMSampler(object):
196
  @torch.no_grad()
197
  def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
198
  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
199
- unconditional_guidance_scale=1., unconditional_conditioning=None, sample_noise=None,
200
- cond_fn=None, uc_type=None,
201
- **kwargs,
202
- ):
203
  b, *_, device = *x.shape, x.device
204
  if x.dim() == 5:
205
  is_video = True
@@ -227,7 +241,12 @@ class DDIMSampler(object):
227
  e_t = e_t + unconditional_guidance_scale * (e_t_uncond - e_t)
228
  else:
229
  raise NotImplementedError
230
-
 
 
 
 
 
231
  if score_corrector is not None:
232
  assert self.model.parameterization == "eps"
233
  e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
@@ -249,19 +268,69 @@ class DDIMSampler(object):
249
 
250
  # current prediction for x_0
251
  pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
252
- # print(f't={t}, pred_x0, min={torch.min(pred_x0)}, max={torch.max(pred_x0)}',file=f)
253
  if quantize_denoised:
254
  pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
255
  # direction pointing to x_t
256
  dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
 
 
 
 
257
 
258
- if sample_noise is None:
259
- noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
260
- if noise_dropout > 0.:
261
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
 
 
 
 
262
  else:
263
- noise = sigma_t * sample_noise * temperature
264
-
265
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
266
-
267
  return x_prev, pred_x0
1
  import numpy as np
2
  from tqdm import tqdm
3
+ import torch
4
+ from lvdm.models.utils_diffusion import make_ddim_sampling_parameters, make_ddim_timesteps
5
+ from lvdm.common import noise_like
6
 
7
 
8
  class DDIMSampler(object):
 
29
  self.register_buffer('betas', to_torch(self.model.betas))
30
  self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
31
  self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
32
+ self.use_scale = self.model.use_scale
33
+ print('DDIM scale', self.use_scale)
34
+
35
+ if self.use_scale:
36
+ self.register_buffer('scale_arr', to_torch(self.model.scale_arr))
37
+ ddim_scale_arr = self.scale_arr.cpu()[self.ddim_timesteps]
38
+ self.register_buffer('ddim_scale_arr', ddim_scale_arr)
39
+ ddim_scale_arr = np.asarray([self.scale_arr.cpu()[0]] + self.scale_arr.cpu()[self.ddim_timesteps[:-1]].tolist())
40
+ self.register_buffer('ddim_scale_arr_prev', ddim_scale_arr)
41
 
42
  # calculations for diffusion q(x_t | x_{t-1}) and others
43
  self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
 
66
  shape,
67
  conditioning=None,
68
  callback=None,
69
+ normals_sequence=None,
70
  img_callback=None,
71
  quantize_x0=False,
72
  eta=0.,
 
82
  log_every_t=100,
83
  unconditional_guidance_scale=1.,
84
  unconditional_conditioning=None,
 
 
 
85
  # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
86
  **kwargs
87
  ):
 
91
  if isinstance(conditioning, dict):
92
  try:
93
  cbs = conditioning[list(conditioning.keys())[0]].shape[0]
 
 
94
  except:
95
+ cbs = conditioning[list(conditioning.keys())[0]][0].shape[0]
96
+
97
+ if cbs != batch_size:
98
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
99
  else:
100
  if conditioning.shape[0] != batch_size:
101
  print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
 
109
  elif len(shape) == 4:
110
  C, T, H, W = shape
111
  size = (batch_size, C, T, H, W)
112
+ # print(f'Data shape for DDIM sampling is {size}, eta {eta}')
113
 
114
  samples, intermediates = self.ddim_sampling(conditioning, size,
115
  callback=callback,
 
125
  log_every_t=log_every_t,
126
  unconditional_guidance_scale=unconditional_guidance_scale,
127
  unconditional_conditioning=unconditional_conditioning,
 
 
 
128
  verbose=verbose,
129
+ **kwargs)
 
130
  return samples, intermediates
131
 
132
  @torch.no_grad()
 
135
  callback=None, timesteps=None, quantize_denoised=False,
136
  mask=None, x0=None, img_callback=None, log_every_t=100,
137
  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
138
+ unconditional_guidance_scale=1., unconditional_conditioning=None, verbose=True,
139
+ cond_tau=1., target_size=None, start_timesteps=None,
140
+ **kwargs):
141
+ device = self.model.betas.device
142
+ print('ddim device', device)
 
 
143
  b = shape[0]
144
  if x_T is None:
145
  img = torch.randn(shape, device=device)
 
151
  elif timesteps is not None and not ddim_use_original_steps:
152
  subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
153
  timesteps = self.ddim_timesteps[:subset_end]
154
+
155
  intermediates = {'x_inter': [img], 'pred_x0': [img]}
156
  time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
157
  total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
 
160
  else:
161
  iterator = time_range
162
 
163
+ init_x0 = False
164
+ clean_cond = kwargs.pop("clean_cond", False)
165
  for i, step in enumerate(iterator):
166
  index = total_steps - i - 1
167
  ts = torch.full((b,), step, device=device, dtype=torch.long)
168
+ if start_timesteps is not None:
169
+ assert x0 is not None
170
+ if step > start_timesteps*time_range[0]:
171
+ continue
172
+ elif not init_x0:
173
+ img = self.model.q_sample(x0, ts)
174
+ init_x0 = True
175
 
176
+ # use mask to blend noised original latent (img_orig) & new sampled latent (img)
177
+ if mask is not None:
178
+ assert x0 is not None
179
+ if clean_cond:
180
+ img_orig = x0
181
+ else:
182
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? <ddim inversion>
183
+ img = img_orig * mask + (1. - mask) * img # keep original & modify use img
184
 
185
+ index_clip = int((1 - cond_tau) * total_steps)
186
+ if index <= index_clip and target_size is not None:
187
+ target_size_ = [target_size[0], target_size[1]//8, target_size[2]//8]
188
+ img = torch.nn.functional.interpolate(
189
+ img,
190
+ size=target_size_,
191
+ mode="nearest",
192
+ )
193
  outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
194
  quantize_denoised=quantize_denoised, temperature=temperature,
195
  noise_dropout=noise_dropout, score_corrector=score_corrector,
196
  corrector_kwargs=corrector_kwargs,
197
  unconditional_guidance_scale=unconditional_guidance_scale,
198
  unconditional_conditioning=unconditional_conditioning,
199
+ x0=x0,
200
+ **kwargs)
 
 
 
 
 
 
 
 
 
201
 
202
+ img, pred_x0 = outs
203
  if callback: callback(i)
204
  if img_callback: img_callback(pred_x0, i)
205
 
 
212
  @torch.no_grad()
213
  def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
214
  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
215
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
216
+ uc_type=None, conditional_guidance_scale_temporal=None, **kwargs):
 
 
217
  b, *_, device = *x.shape, x.device
218
  if x.dim() == 5:
219
  is_video = True
 
241
  e_t = e_t + unconditional_guidance_scale * (e_t_uncond - e_t)
242
  else:
243
  raise NotImplementedError
244
+ # temporal guidance
245
+ if conditional_guidance_scale_temporal is not None:
246
+ e_t_temporal = self.model.apply_model(x, t, c, **kwargs)
247
+ e_t_image = self.model.apply_model(x, t, c, no_temporal_attn=True, **kwargs)
248
+ e_t = e_t + conditional_guidance_scale_temporal * (e_t_temporal - e_t_image)
249
+
250
  if score_corrector is not None:
251
  assert self.model.parameterization == "eps"
252
  e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
 
268
 
269
  # current prediction for x_0
270
  pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
 
271
  if quantize_denoised:
272
  pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
273
  # direction pointing to x_t
274
  dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
275
+
276
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
277
+ if noise_dropout > 0.:
278
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
279
 
280
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
281
+ if self.use_scale:
282
+ scale_arr = self.model.scale_arr if use_original_steps else self.ddim_scale_arr
283
+ scale_t = torch.full(size, scale_arr[index], device=device)
284
+ scale_arr_prev = self.model.scale_arr_prev if use_original_steps else self.ddim_scale_arr_prev
285
+ scale_t_prev = torch.full(size, scale_arr_prev[index], device=device)
286
+ pred_x0 /= scale_t
287
+ x_prev = a_prev.sqrt() * scale_t_prev * pred_x0 + dir_xt + noise
288
  else:
289
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
290
+
 
 
291
  return x_prev, pred_x0
292
+
293
+
294
+ @torch.no_grad()
295
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
296
+ # fast, but does not allow for exact reconstruction
297
+ # t serves as an index to gather the correct alphas
298
+ if use_original_steps:
299
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
300
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
301
+ else:
302
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
303
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
304
+
305
+ if noise is None:
306
+ noise = torch.randn_like(x0)
307
+
308
+ def extract_into_tensor(a, t, x_shape):
309
+ b, *_ = t.shape
310
+ out = a.gather(-1, t)
311
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
312
+
313
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
314
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
315
+
316
+ @torch.no_grad()
317
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
318
+ use_original_steps=False):
319
+
320
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
321
+ timesteps = timesteps[:t_start]
322
+
323
+ time_range = np.flip(timesteps)
324
+ total_steps = timesteps.shape[0]
325
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
326
+
327
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
328
+ x_dec = x_latent
329
+ for i, step in enumerate(iterator):
330
+ index = total_steps - i - 1
331
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
332
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
333
+ unconditional_guidance_scale=unconditional_guidance_scale,
334
+ unconditional_conditioning=unconditional_conditioning)
335
+ return x_dec
336
+
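The relocated sampler keeps the LDM-style DDIM loop but adds mask-guided latent blending, an optional per-step rescaling path driven by model.use_scale / model.scale_arr, temporal classifier-free guidance, and the stochastic_encode / decode helpers. A minimal text-to-video usage sketch follows; it assumes the standard latent-diffusion entry point sample(S, batch_size, shape, ...) sets up the DDIM schedule internally, and that `model` is an already-loaded VideoCrafter LatentDiffusion model exposing get_learned_conditioning and decode_first_stage (the prompts and latent shape below are placeholders, not values taken from this commit).

import torch
from lvdm.models.samplers.ddim import DDIMSampler

sampler = DDIMSampler(model)                       # `model`: loaded VideoCrafter model (assumed)
prompts = ["an astronaut riding a horse in outer space"]
cond = model.get_learned_conditioning(prompts)     # text conditioning
uc = model.get_learned_conditioning([""] * len(prompts))  # unconditional branch for CFG

noise_shape = (model.channels, 16, 40, 64)         # (C, T, H, W) latent shape, illustrative
with torch.no_grad():
    samples, _ = sampler.sample(S=50,
                                batch_size=len(prompts),
                                shape=noise_shape,
                                conditioning=cond,
                                unconditional_guidance_scale=12.0,
                                unconditional_conditioning=uc,
                                eta=1.0)
    videos = model.decode_first_stage(samples)     # back to pixel space, [B, C, T, H, W]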
lvdm/models/{modules/util.py → utils_diffusion.py} RENAMED
@@ -1,13 +1,31 @@
1
  import math
2
- from inspect import isfunction
3
-
4
- import torch
5
  import numpy as np
6
- import torch.nn as nn
7
  from einops import repeat
 
8
  import torch.nn.functional as F
9
 
10
- from lvdm.utils.common_utils import instantiate_from_config
11
 
12
 
13
  def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
@@ -15,6 +33,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
15
  betas = (
16
  torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
17
  )
 
18
  elif schedule == "cosine":
19
  timesteps = (
20
  torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
@@ -24,6 +43,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
24
  alphas = alphas / alphas[0]
25
  betas = 1 - alphas[1:] / alphas[:-1]
26
  betas = np.clip(betas, a_min=0, a_max=0.999)
 
27
  elif schedule == "sqrt_linear":
28
  betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
29
  elif schedule == "sqrt":
@@ -42,6 +62,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
42
  else:
43
  raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
44
 
 
45
  # add one to get the final alpha values right (the ones from first scale to data during sampling)
46
  steps_out = ddim_timesteps + 1
47
  if verbose:
@@ -51,6 +72,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
51
 
52
  def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
53
  # select alphas for computing the variance schedule
 
54
  alphas = alphacums[ddim_timesteps]
55
  alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
56
 
@@ -79,270 +101,4 @@ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
79
  t1 = i / num_diffusion_timesteps
80
  t2 = (i + 1) / num_diffusion_timesteps
81
  betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
82
- return np.array(betas)
83
-
84
-
85
- def extract_into_tensor(a, t, x_shape):
86
- b, *_ = t.shape
87
- out = a.gather(-1, t)
88
- return out.reshape(b, *((1,) * (len(x_shape) - 1)))
89
-
90
-
91
- def checkpoint(func, inputs, params, flag):
92
- """
93
- Evaluate a function without caching intermediate activations, allowing for
94
- reduced memory at the expense of extra compute in the backward pass.
95
- :param func: the function to evaluate.
96
- :param inputs: the argument sequence to pass to `func`.
97
- :param params: a sequence of parameters `func` depends on but does not
98
- explicitly take as arguments.
99
- :param flag: if False, disable gradient checkpointing.
100
- """
101
- if flag:
102
- args = tuple(inputs) + tuple(params)
103
- return CheckpointFunction.apply(func, len(inputs), *args)
104
- else:
105
- return func(*inputs)
106
-
107
-
108
- class CheckpointFunction(torch.autograd.Function):
109
- @staticmethod
110
- @torch.cuda.amp.custom_fwd
111
- def forward(ctx, run_function, length, *args):
112
- ctx.run_function = run_function
113
- ctx.input_tensors = list(args[:length])
114
- ctx.input_params = list(args[length:])
115
-
116
- with torch.no_grad():
117
- output_tensors = ctx.run_function(*ctx.input_tensors)
118
- return output_tensors
119
-
120
- @staticmethod
121
- @torch.cuda.amp.custom_bwd
122
- def backward(ctx, *output_grads):
123
- ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
124
- with torch.enable_grad():
125
- # Fixes a bug where the first op in run_function modifies the
126
- # Tensor storage in place, which is not allowed for detach()'d
127
- # Tensors.
128
- shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
129
- output_tensors = ctx.run_function(*shallow_copies)
130
- input_grads = torch.autograd.grad(
131
- output_tensors,
132
- ctx.input_tensors + ctx.input_params,
133
- output_grads,
134
- allow_unused=True,
135
- )
136
- del ctx.input_tensors
137
- del ctx.input_params
138
- del output_tensors
139
- return (None, None) + input_grads
140
-
141
-
142
- def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
143
- """
144
- Create sinusoidal timestep embeddings.
145
- :param timesteps: a 1-D Tensor of N indices, one per batch element.
146
- These may be fractional.
147
- :param dim: the dimension of the output.
148
- :param max_period: controls the minimum frequency of the embeddings.
149
- :return: an [N x dim] Tensor of positional embeddings.
150
- """
151
- if not repeat_only:
152
- half = dim // 2
153
- freqs = torch.exp(
154
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
155
- ).to(device=timesteps.device)
156
- args = timesteps[:, None].float() * freqs[None]
157
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
158
- if dim % 2:
159
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
160
- else:
161
- embedding = repeat(timesteps, 'b -> b d', d=dim)
162
- return embedding
163
-
164
-
165
- def zero_module(module):
166
- """
167
- Zero out the parameters of a module and return it.
168
- """
169
- for p in module.parameters():
170
- p.detach().zero_()
171
- return module
172
-
173
-
174
- def scale_module(module, scale):
175
- """
176
- Scale the parameters of a module and return it.
177
- """
178
- for p in module.parameters():
179
- p.detach().mul_(scale)
180
- return module
181
-
182
-
183
- def mean_flat(tensor):
184
- """
185
- Take the mean over all non-batch dimensions.
186
- """
187
- return tensor.mean(dim=list(range(1, len(tensor.shape))))
188
-
189
-
190
- def normalization(channels):
191
- """
192
- Make a standard normalization layer.
193
- :param channels: number of input channels.
194
- :return: an nn.Module for normalization.
195
- """
196
- return GroupNorm32(32, channels)
197
-
198
- def Normalize(in_channels):
199
- return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
200
-
201
- def identity(*args, **kwargs):
202
- return nn.Identity()
203
-
204
- class Normalization(nn.Module):
205
- def __init__(self, output_size, eps=1e-5, norm_type='gn'):
206
- super(Normalization, self).__init__()
207
- # epsilon to avoid dividing by 0
208
- self.eps = eps
209
- self.norm_type = norm_type
210
-
211
- if self.norm_type in ['bn', 'in']:
212
- self.register_buffer('stored_mean', torch.zeros(output_size))
213
- self.register_buffer('stored_var', torch.ones(output_size))
214
-
215
- def forward(self, x):
216
- if self.norm_type == 'bn':
217
- out = F.batch_norm(x, self.stored_mean, self.stored_var, None,
218
- None,
219
- self.training, 0.1, self.eps)
220
- elif self.norm_type == 'in':
221
- out = F.instance_norm(x, self.stored_mean, self.stored_var,
222
- None, None,
223
- self.training, 0.1, self.eps)
224
- elif self.norm_type == 'gn':
225
- out = F.group_norm(x, 32)
226
- elif self.norm_type == 'nonorm':
227
- out = x
228
- return out
229
-
230
-
231
- class CCNormalization(nn.Module):
232
- def __init__(self, embed_dim, feature_dim, *args, **kwargs):
233
- super(CCNormalization, self).__init__()
234
-
235
- self.embed_dim = embed_dim
236
- self.feature_dim = feature_dim
237
- self.norm = Normalization(feature_dim, *args, **kwargs)
238
-
239
- self.gain = nn.Linear(self.embed_dim, self.feature_dim)
240
- self.bias = nn.Linear(self.embed_dim, self.feature_dim)
241
-
242
- def forward(self, x, y):
243
- shape = [1] * (x.dim() - 2)
244
- gain = (1 + self.gain(y)).view(y.size(0), -1, *shape)
245
- bias = self.bias(y).view(y.size(0), -1, *shape)
246
- return self.norm(x) * gain + bias
247
-
248
-
249
- def nonlinearity(type='silu'):
250
- if type == 'silu':
251
- return nn.SiLU()
252
- elif type == 'leaky_relu':
253
- return nn.LeakyReLU()
254
-
255
-
256
- class GEGLU(nn.Module):
257
- def __init__(self, dim_in, dim_out):
258
- super().__init__()
259
- self.proj = nn.Linear(dim_in, dim_out * 2)
260
-
261
- def forward(self, x):
262
- x, gate = self.proj(x).chunk(2, dim=-1)
263
- return x * F.gelu(gate)
264
-
265
-
266
- class SiLU(nn.Module):
267
- def forward(self, x):
268
- return x * torch.sigmoid(x)
269
-
270
-
271
- class GroupNorm32(nn.GroupNorm):
272
- def forward(self, x):
273
- return super().forward(x.float()).type(x.dtype)
274
-
275
-
276
- def conv_nd(dims, *args, **kwargs):
277
- """
278
- Create a 1D, 2D, or 3D convolution module.
279
- """
280
- if dims == 1:
281
- return nn.Conv1d(*args, **kwargs)
282
- elif dims == 2:
283
- return nn.Conv2d(*args, **kwargs)
284
- elif dims == 3:
285
- return nn.Conv3d(*args, **kwargs)
286
- raise ValueError(f"unsupported dimensions: {dims}")
287
-
288
-
289
- def linear(*args, **kwargs):
290
- """
291
- Create a linear module.
292
- """
293
- return nn.Linear(*args, **kwargs)
294
-
295
-
296
- def avg_pool_nd(dims, *args, **kwargs):
297
- """
298
- Create a 1D, 2D, or 3D average pooling module.
299
- """
300
- if dims == 1:
301
- return nn.AvgPool1d(*args, **kwargs)
302
- elif dims == 2:
303
- return nn.AvgPool2d(*args, **kwargs)
304
- elif dims == 3:
305
- return nn.AvgPool3d(*args, **kwargs)
306
- raise ValueError(f"unsupported dimensions: {dims}")
307
-
308
-
309
- class HybridConditioner(nn.Module):
310
-
311
- def __init__(self, c_concat_config, c_crossattn_config):
312
- super().__init__()
313
- self.concat_conditioner = instantiate_from_config(c_concat_config)
314
- self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
315
-
316
- def forward(self, c_concat, c_crossattn):
317
- c_concat = self.concat_conditioner(c_concat)
318
- c_crossattn = self.crossattn_conditioner(c_crossattn)
319
- return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
320
-
321
-
322
- def noise_like(shape, device, repeat=False):
323
- repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
324
- noise = lambda: torch.randn(shape, device=device)
325
- return repeat_noise() if repeat else noise()
326
-
327
-
328
- def init_(tensor):
329
- dim = tensor.shape[-1]
330
- std = 1 / math.sqrt(dim)
331
- tensor.uniform_(-std, std)
332
- return tensor
333
-
334
-
335
- def exists(val):
336
- return val is not None
337
-
338
-
339
- def uniq(arr):
340
- return{el: True for el in arr}.keys()
341
-
342
-
343
- def default(val, d):
344
- if exists(val):
345
- return val
346
- return d() if isfunction(d) else d
347
-
348
-
 
1
  import math
 
 
 
2
  import numpy as np
 
3
  from einops import repeat
4
+ import torch
5
  import torch.nn.functional as F
6
 
7
+
8
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
9
+ """
10
+ Create sinusoidal timestep embeddings.
11
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
12
+ These may be fractional.
13
+ :param dim: the dimension of the output.
14
+ :param max_period: controls the minimum frequency of the embeddings.
15
+ :return: an [N x dim] Tensor of positional embeddings.
16
+ """
17
+ if not repeat_only:
18
+ half = dim // 2
19
+ freqs = torch.exp(
20
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
21
+ ).to(device=timesteps.device)
22
+ args = timesteps[:, None].float() * freqs[None]
23
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
24
+ if dim % 2:
25
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
26
+ else:
27
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
28
+ return embedding
29
 
30
 
31
  def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
 
33
  betas = (
34
  torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
35
  )
36
+
37
  elif schedule == "cosine":
38
  timesteps = (
39
  torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
 
43
  alphas = alphas / alphas[0]
44
  betas = 1 - alphas[1:] / alphas[:-1]
45
  betas = np.clip(betas, a_min=0, a_max=0.999)
46
+
47
  elif schedule == "sqrt_linear":
48
  betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
49
  elif schedule == "sqrt":
 
62
  else:
63
  raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
64
 
65
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
66
  # add one to get the final alpha values right (the ones from first scale to data during sampling)
67
  steps_out = ddim_timesteps + 1
68
  if verbose:
 
72
 
73
  def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
74
  # select alphas for computing the variance schedule
75
+ # print(f'ddim_timesteps={ddim_timesteps}, len_alphacums={len(alphacums)}')
76
  alphas = alphacums[ddim_timesteps]
77
  alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
78
 
 
101
  t1 = i / num_diffusion_timesteps
102
  t2 = (i + 1) / num_diffusion_timesteps
103
  betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
104
+ return np.array(betas)
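After the rename, utils_diffusion.py keeps only the schedule and embedding helpers (timestep_embedding, make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, betas_for_alpha_bar); the generic layers and checkpointing utilities it used to carry now live in lvdm/basics.py and lvdm/common.py. A quick shape check of the two most commonly used helpers, with illustrative numbers (the linear_start/linear_end values are the usual Stable-Diffusion defaults, not read from this repo's configs):

import torch
from lvdm.models.utils_diffusion import timestep_embedding, make_beta_schedule

t = torch.arange(4)                              # one timestep index per batch element
emb = timestep_embedding(t, dim=320)             # sinusoidal embeddings
print(emb.shape)                                 # torch.Size([4, 320])

betas = make_beta_schedule("linear", n_timestep=1000,
                           linear_start=0.00085, linear_end=0.012)
print(betas.shape, float(betas[0]), float(betas[-1]))   # 1000 values from ~8.5e-4 to ~1.2e-2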
lvdm/modules/attention.py ADDED
@@ -0,0 +1,475 @@
1
+ from functools import partial
2
+ import torch
3
+ from torch import nn, einsum
4
+ import torch.nn.functional as F
5
+ from einops import rearrange, repeat
6
+ try:
7
+ import xformers
8
+ import xformers.ops
9
+ XFORMERS_IS_AVAILBLE = True
10
+ except:
11
+ XFORMERS_IS_AVAILBLE = False
12
+ from lvdm.common import (
13
+ checkpoint,
14
+ exists,
15
+ default,
16
+ )
17
+ from lvdm.basics import (
18
+ zero_module,
19
+ )
20
+
21
+ class RelativePosition(nn.Module):
22
+ """ https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py """
23
+
24
+ def __init__(self, num_units, max_relative_position):
25
+ super().__init__()
26
+ self.num_units = num_units
27
+ self.max_relative_position = max_relative_position
28
+ self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units))
29
+ nn.init.xavier_uniform_(self.embeddings_table)
30
+
31
+ def forward(self, length_q, length_k):
32
+ device = self.embeddings_table.device
33
+ range_vec_q = torch.arange(length_q, device=device)
34
+ range_vec_k = torch.arange(length_k, device=device)
35
+ distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
36
+ distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
37
+ final_mat = distance_mat_clipped + self.max_relative_position
38
+ final_mat = final_mat.long()
39
+ embeddings = self.embeddings_table[final_mat]
40
+ return embeddings
41
+
42
+
43
+ class CrossAttention(nn.Module):
44
+
45
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.,
46
+ relative_position=False, temporal_length=None, img_cross_attention=False):
47
+ super().__init__()
48
+ inner_dim = dim_head * heads
49
+ context_dim = default(context_dim, query_dim)
50
+
51
+ self.scale = dim_head**-0.5
52
+ self.heads = heads
53
+ self.dim_head = dim_head
54
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
55
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
56
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
57
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
58
+
59
+ self.image_cross_attention_scale = 1.0
60
+ self.text_context_len = 77
61
+ self.img_cross_attention = img_cross_attention
62
+ if self.img_cross_attention:
63
+ self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
64
+ self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
65
+
66
+ self.relative_position = relative_position
67
+ if self.relative_position:
68
+ assert(temporal_length is not None)
69
+ self.relative_position_k = RelativePosition(num_units=dim_head, max_relative_position=temporal_length)
70
+ self.relative_position_v = RelativePosition(num_units=dim_head, max_relative_position=temporal_length)
71
+ else:
72
+ ## only used for spatial attention, while NOT for temporal attention
73
+ if XFORMERS_IS_AVAILBLE and temporal_length is None:
74
+ self.forward = self.efficient_forward
75
+
76
+ def forward(self, x, context=None, mask=None):
77
+ h = self.heads
78
+
79
+ q = self.to_q(x)
80
+ context = default(context, x)
81
+ ## considering image token additionally
82
+ if context is not None and self.img_cross_attention:
83
+ context, context_img = context[:,:self.text_context_len,:], context[:,self.text_context_len:,:]
84
+ k = self.to_k(context)
85
+ v = self.to_v(context)
86
+ k_ip = self.to_k_ip(context_img)
87
+ v_ip = self.to_v_ip(context_img)
88
+ else:
89
+ k = self.to_k(context)
90
+ v = self.to_v(context)
91
+
92
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
93
+ sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
94
+ if self.relative_position:
95
+ len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
96
+ k2 = self.relative_position_k(len_q, len_k)
97
+ sim2 = einsum('b t d, t s d -> b t s', q, k2) * self.scale # TODO check
98
+ sim += sim2
99
+ del k
100
+
101
+ if exists(mask):
102
+ ## feasible for causal attention mask only
103
+ max_neg_value = -torch.finfo(sim.dtype).max
104
+ mask = repeat(mask, 'b i j -> (b h) i j', h=h)
105
+ sim.masked_fill_(~(mask>0.5), max_neg_value)
106
+
107
+ # attention, what we cannot get enough of
108
+ sim = sim.softmax(dim=-1)
109
+ out = torch.einsum('b i j, b j d -> b i d', sim, v)
110
+ if self.relative_position:
111
+ v2 = self.relative_position_v(len_q, len_v)
112
+ out2 = einsum('b t s, t s d -> b t d', sim, v2) # TODO check
113
+ out += out2
114
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
115
+
116
+ ## considering image token additionally
117
+ if context is not None and self.img_cross_attention:
118
+ k_ip, v_ip = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (k_ip, v_ip))
119
+ sim_ip = torch.einsum('b i d, b j d -> b i j', q, k_ip) * self.scale
120
+ del k_ip
121
+ sim_ip = sim_ip.softmax(dim=-1)
122
+ out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip)
123
+ out_ip = rearrange(out, '(b h) n d -> b n (h d)', h=h)
124
+ out = out + self.image_cross_attention_scale * out_ip
125
+ del q
126
+
127
+ return self.to_out(out)
128
+
129
+ def efficient_forward(self, x, context=None, mask=None):
130
+ q = self.to_q(x)
131
+ context = default(context, x)
132
+
133
+ ## considering image token additionally
134
+ if context is not None and self.img_cross_attention:
135
+ context, context_img = context[:,:self.text_context_len,:], context[:,self.text_context_len:,:]
136
+ k = self.to_k(context)
137
+ v = self.to_v(context)
138
+ k_ip = self.to_k_ip(context_img)
139
+ v_ip = self.to_v_ip(context_img)
140
+ else:
141
+ k = self.to_k(context)
142
+ v = self.to_v(context)
143
+
144
+ b, _, _ = q.shape
145
+ q, k, v = map(
146
+ lambda t: t.unsqueeze(3)
147
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
148
+ .permute(0, 2, 1, 3)
149
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
150
+ .contiguous(),
151
+ (q, k, v),
152
+ )
153
+ # actually compute the attention, what we cannot get enough of
154
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None)
155
+
156
+ ## considering image token additionally
157
+ if context is not None and self.img_cross_attention:
158
+ k_ip, v_ip = map(
159
+ lambda t: t.unsqueeze(3)
160
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
161
+ .permute(0, 2, 1, 3)
162
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
163
+ .contiguous(),
164
+ (k_ip, v_ip),
165
+ )
166
+ out_ip = xformers.ops.memory_efficient_attention(q, k_ip, v_ip, attn_bias=None, op=None)
167
+ out_ip = (
168
+ out_ip.unsqueeze(0)
169
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
170
+ .permute(0, 2, 1, 3)
171
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
172
+ )
173
+
174
+ if exists(mask):
175
+ raise NotImplementedError
176
+ out = (
177
+ out.unsqueeze(0)
178
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
179
+ .permute(0, 2, 1, 3)
180
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
181
+ )
182
+ if context is not None and self.img_cross_attention:
183
+ out = out + self.image_cross_attention_scale * out_ip
184
+ return self.to_out(out)
185
+
186
+
187
+ class BasicTransformerBlock(nn.Module):
188
+
189
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
190
+ disable_self_attn=False, attention_cls=None, img_cross_attention=False):
191
+ super().__init__()
192
+ attn_cls = CrossAttention if attention_cls is None else attention_cls
193
+ self.disable_self_attn = disable_self_attn
194
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
195
+ context_dim=context_dim if self.disable_self_attn else None)
196
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
197
+ self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
198
+ img_cross_attention=img_cross_attention)
199
+ self.norm1 = nn.LayerNorm(dim)
200
+ self.norm2 = nn.LayerNorm(dim)
201
+ self.norm3 = nn.LayerNorm(dim)
202
+ self.checkpoint = checkpoint
203
+
204
+ def forward(self, x, context=None, mask=None):
205
+ ## implementation tricks: because checkpointing doesn't support non-tensor (e.g. None or scalar) arguments
206
+ input_tuple = (x,) ## should not be (x), otherwise *input_tuple will decouple x into multiple arguments
207
+ if context is not None:
208
+ input_tuple = (x, context)
209
+ if mask is not None:
210
+ forward_mask = partial(self._forward, mask=mask)
211
+ return checkpoint(forward_mask, (x,), self.parameters(), self.checkpoint)
212
+ if context is not None and mask is not None:
213
+ input_tuple = (x, context, mask)
214
+ return checkpoint(self._forward, input_tuple, self.parameters(), self.checkpoint)
215
+
216
+ def _forward(self, x, context=None, mask=None):
217
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None, mask=mask) + x
218
+ x = self.attn2(self.norm2(x), context=context, mask=mask) + x
219
+ x = self.ff(self.norm3(x)) + x
220
+ return x
221
+
222
+
223
+ class SpatialTransformer(nn.Module):
224
+ """
225
+ Transformer block for image-like data in spatial axis.
226
+ First, project the input (aka embedding)
227
+ and reshape to b, t, d.
228
+ Then apply standard transformer action.
229
+ Finally, reshape to image
230
+ NEW: use_linear for more efficiency instead of the 1x1 convs
231
+ """
232
+
233
+ def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None,
234
+ use_checkpoint=True, disable_self_attn=False, use_linear=False, img_cross_attention=False):
235
+ super().__init__()
236
+ self.in_channels = in_channels
237
+ inner_dim = n_heads * d_head
238
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
239
+ if not use_linear:
240
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
241
+ else:
242
+ self.proj_in = nn.Linear(in_channels, inner_dim)
243
+
244
+ self.transformer_blocks = nn.ModuleList([
245
+ BasicTransformerBlock(
246
+ inner_dim,
247
+ n_heads,
248
+ d_head,
249
+ dropout=dropout,
250
+ context_dim=context_dim,
251
+ img_cross_attention=img_cross_attention,
252
+ disable_self_attn=disable_self_attn,
253
+ checkpoint=use_checkpoint) for d in range(depth)
254
+ ])
255
+ if not use_linear:
256
+ self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
257
+ else:
258
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
259
+ self.use_linear = use_linear
260
+
261
+
262
+ def forward(self, x, context=None):
263
+ b, c, h, w = x.shape
264
+ x_in = x
265
+ x = self.norm(x)
266
+ if not self.use_linear:
267
+ x = self.proj_in(x)
268
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
269
+ if self.use_linear:
270
+ x = self.proj_in(x)
271
+ for i, block in enumerate(self.transformer_blocks):
272
+ x = block(x, context=context)
273
+ if self.use_linear:
274
+ x = self.proj_out(x)
275
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
276
+ if not self.use_linear:
277
+ x = self.proj_out(x)
278
+ return x + x_in
279
+
280
+
281
+ class TemporalTransformer(nn.Module):
282
+ """
283
+ Transformer block for image-like data in temporal axis.
284
+ First, reshape to b, t, d.
285
+ Then apply standard transformer action.
286
+ Finally, reshape to image
287
+ """
288
+ def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None,
289
+ use_checkpoint=True, use_linear=False, only_self_att=True, causal_attention=False,
290
+ relative_position=False, temporal_length=None):
291
+ super().__init__()
292
+ self.only_self_att = only_self_att
293
+ self.relative_position = relative_position
294
+ self.causal_attention = causal_attention
295
+ self.in_channels = in_channels
296
+ inner_dim = n_heads * d_head
297
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
298
+ self.proj_in = nn.Conv1d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
299
+ if not use_linear:
300
+ self.proj_in = nn.Conv1d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
301
+ else:
302
+ self.proj_in = nn.Linear(in_channels, inner_dim)
303
+
304
+ if relative_position:
305
+ assert(temporal_length is not None)
306
+ attention_cls = partial(CrossAttention, relative_position=True, temporal_length=temporal_length)
307
+ else:
308
+ attention_cls = None
309
+ if self.causal_attention:
310
+ assert(temporal_length is not None)
311
+ self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length]))
312
+
313
+ if self.only_self_att:
314
+ context_dim = None
315
+ self.transformer_blocks = nn.ModuleList([
316
+ BasicTransformerBlock(
317
+ inner_dim,
318
+ n_heads,
319
+ d_head,
320
+ dropout=dropout,
321
+ context_dim=context_dim,
322
+ attention_cls=attention_cls,
323
+ checkpoint=use_checkpoint) for d in range(depth)
324
+ ])
325
+ if not use_linear:
326
+ self.proj_out = zero_module(nn.Conv1d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
327
+ else:
328
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
329
+ self.use_linear = use_linear
330
+
331
+ def forward(self, x, context=None):
332
+ b, c, t, h, w = x.shape
333
+ x_in = x
334
+ x = self.norm(x)
335
+ x = rearrange(x, 'b c t h w -> (b h w) c t').contiguous()
336
+ if not self.use_linear:
337
+ x = self.proj_in(x)
338
+ x = rearrange(x, 'bhw c t -> bhw t c').contiguous()
339
+ if self.use_linear:
340
+ x = self.proj_in(x)
341
+
342
+ if self.causal_attention:
343
+ mask = self.mask.to(x.device)
344
+ mask = repeat(mask, 'l i j -> (l bhw) i j', bhw=b*h*w)
345
+ else:
346
+ mask = None
347
+
348
+ if self.only_self_att:
349
+ ## note: if no context is given, cross-attention defaults to self-attention
350
+ for i, block in enumerate(self.transformer_blocks):
351
+ x = block(x, mask=mask)
352
+ x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
353
+ else:
354
+ x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
355
+ context = rearrange(context, '(b t) l con -> b t l con', t=t).contiguous()
356
+ for i, block in enumerate(self.transformer_blocks):
357
+ # process each batch element separately (some attention backends cannot handle a dimension larger than 65,535)
358
+ for j in range(b):
359
+ context_j = repeat(
360
+ context[j],
361
+ 't l con -> (t r) l con', r=(h * w) // t, t=t).contiguous()
362
+ ## note: the causal mask is not applied in the cross-attention case
363
+ x[j] = block(x[j], context=context_j)
364
+
365
+ if self.use_linear:
366
+ x = self.proj_out(x)
367
+ x = rearrange(x, 'b (h w) t c -> b c t h w', h=h, w=w).contiguous()
368
+ if not self.use_linear:
369
+ x = rearrange(x, 'b hw t c -> (b hw) c t').contiguous()
370
+ x = self.proj_out(x)
371
+ x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h, w=w).contiguous()
372
+
373
+ return x + x_in
374
+
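For readers skimming the diff: the temporal masking above is easier to follow in isolation. A minimal, standalone sketch (shapes are illustrative assumptions, not part of the commit) of how the lower-triangular causal mask is built and broadcast to the '(b h w) t c' layout that TemporalTransformer uses:

import torch
from einops import repeat

temporal_length = 4          # illustrative number of frames
b, h, w = 2, 8, 8            # illustrative batch size and latent spatial size

# frame i may only attend to frames j <= i
mask = torch.tril(torch.ones(1, temporal_length, temporal_length))

# one copy of the mask per (batch, spatial position), matching the flattened batch above
mask = repeat(mask, 'l i j -> (l bhw) i j', bhw=b * h * w)
print(mask.shape)            # torch.Size([128, 4, 4])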
375
+
376
+ class GEGLU(nn.Module):
377
+ def __init__(self, dim_in, dim_out):
378
+ super().__init__()
379
+ self.proj = nn.Linear(dim_in, dim_out * 2)
380
+
381
+ def forward(self, x):
382
+ x, gate = self.proj(x).chunk(2, dim=-1)
383
+ return x * F.gelu(gate)
384
+
385
+
386
+ class FeedForward(nn.Module):
387
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
388
+ super().__init__()
389
+ inner_dim = int(dim * mult)
390
+ dim_out = default(dim_out, dim)
391
+ project_in = nn.Sequential(
392
+ nn.Linear(dim, inner_dim),
393
+ nn.GELU()
394
+ ) if not glu else GEGLU(dim, inner_dim)
395
+
396
+ self.net = nn.Sequential(
397
+ project_in,
398
+ nn.Dropout(dropout),
399
+ nn.Linear(inner_dim, dim_out)
400
+ )
401
+
402
+ def forward(self, x):
403
+ return self.net(x)
404
+
405
+
406
+ class LinearAttention(nn.Module):
407
+ def __init__(self, dim, heads=4, dim_head=32):
408
+ super().__init__()
409
+ self.heads = heads
410
+ hidden_dim = dim_head * heads
411
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
412
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
413
+
414
+ def forward(self, x):
415
+ b, c, h, w = x.shape
416
+ qkv = self.to_qkv(x)
417
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
418
+ k = k.softmax(dim=-1)
419
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
420
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
421
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
422
+ return self.to_out(out)
423
+
424
+
425
+ class SpatialSelfAttention(nn.Module):
426
+ def __init__(self, in_channels):
427
+ super().__init__()
428
+ self.in_channels = in_channels
429
+
430
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
431
+ self.q = torch.nn.Conv2d(in_channels,
432
+ in_channels,
433
+ kernel_size=1,
434
+ stride=1,
435
+ padding=0)
436
+ self.k = torch.nn.Conv2d(in_channels,
437
+ in_channels,
438
+ kernel_size=1,
439
+ stride=1,
440
+ padding=0)
441
+ self.v = torch.nn.Conv2d(in_channels,
442
+ in_channels,
443
+ kernel_size=1,
444
+ stride=1,
445
+ padding=0)
446
+ self.proj_out = torch.nn.Conv2d(in_channels,
447
+ in_channels,
448
+ kernel_size=1,
449
+ stride=1,
450
+ padding=0)
451
+
452
+ def forward(self, x):
453
+ h_ = x
454
+ h_ = self.norm(h_)
455
+ q = self.q(h_)
456
+ k = self.k(h_)
457
+ v = self.v(h_)
458
+
459
+ # compute attention
460
+ b,c,h,w = q.shape
461
+ q = rearrange(q, 'b c h w -> b (h w) c')
462
+ k = rearrange(k, 'b c h w -> b c (h w)')
463
+ w_ = torch.einsum('bij,bjk->bik', q, k)
464
+
465
+ w_ = w_ * (int(c)**(-0.5))
466
+ w_ = torch.nn.functional.softmax(w_, dim=2)
467
+
468
+ # attend to values
469
+ v = rearrange(v, 'b c h w -> b c (h w)')
470
+ w_ = rearrange(w_, 'b i j -> b j i')
471
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
472
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
473
+ h_ = self.proj_out(h_)
474
+
475
+ return x+h_
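Taken together, the spatial blocks attend within each frame and the temporal blocks attend across frames; the real wiring lives in the UNet (openaimodel3d.py below). As a rough, hedged usage sketch outside the UNet (all shapes and hyper-parameters are illustrative assumptions; it also assumes the repo root is on PYTHONPATH):

import torch
from lvdm.modules.attention import SpatialTransformer, TemporalTransformer

b, c, t, h, w = 1, 320, 4, 16, 16                  # illustrative video latent shape
x = torch.randn(b, c, t, h, w)
context = torch.randn(b * t, 77, 1024)             # e.g. per-frame text embeddings

spatial = SpatialTransformer(c, n_heads=8, d_head=40, depth=1,
                             context_dim=1024, use_linear=True, use_checkpoint=False)
temporal = TemporalTransformer(c, n_heads=8, d_head=40, depth=1,
                               only_self_att=True, use_checkpoint=False)

# spatial (per-frame) attention operates on (b*t, c, h, w)
x2d = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
x2d = spatial(x2d, context=context)
x = x2d.reshape(b, t, c, h, w).permute(0, 2, 1, 3, 4)

# temporal attention operates on the full (b, c, t, h, w) tensor
x = temporal(x)
print(x.shape)                                     # torch.Size([1, 320, 4, 16, 16])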
lvdm/modules/encoders/condition.py ADDED
@@ -0,0 +1,392 @@
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+ import kornia
5
+ import open_clip
6
+ from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
7
+ from lvdm.common import autocast
8
+ from utils.utils import count_params
9
+
10
+ class AbstractEncoder(nn.Module):
11
+ def __init__(self):
12
+ super().__init__()
13
+
14
+ def encode(self, *args, **kwargs):
15
+ raise NotImplementedError
16
+
17
+
18
+ class IdentityEncoder(AbstractEncoder):
19
+
20
+ def encode(self, x):
21
+ return x
22
+
23
+
24
+ class ClassEmbedder(nn.Module):
25
+ def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
26
+ super().__init__()
27
+ self.key = key
28
+ self.embedding = nn.Embedding(n_classes, embed_dim)
29
+ self.n_classes = n_classes
30
+ self.ucg_rate = ucg_rate
31
+
32
+ def forward(self, batch, key=None, disable_dropout=False):
33
+ if key is None:
34
+ key = self.key
35
+ # this is for use in crossattn
36
+ c = batch[key][:, None]
37
+ if self.ucg_rate > 0. and not disable_dropout:
38
+ mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
39
+ c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
40
+ c = c.long()
41
+ c = self.embedding(c)
42
+ return c
43
+
44
+ def get_unconditional_conditioning(self, bs, device="cuda"):
45
+ uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
46
+ uc = torch.ones((bs,), device=device) * uc_class
47
+ uc = {self.key: uc}
48
+ return uc
49
+
50
+
51
+ def disabled_train(self, mode=True):
52
+ """Overwrite model.train with this function to make sure train/eval mode
53
+ does not change anymore."""
54
+ return self
55
+
56
+
57
+ class FrozenT5Embedder(AbstractEncoder):
58
+ """Uses the T5 transformer encoder for text"""
59
+
60
+ def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77,
61
+ freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
62
+ super().__init__()
63
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
64
+ self.transformer = T5EncoderModel.from_pretrained(version)
65
+ self.device = device
66
+ self.max_length = max_length # TODO: typical value?
67
+ if freeze:
68
+ self.freeze()
69
+
70
+ def freeze(self):
71
+ self.transformer = self.transformer.eval()
72
+ # self.train = disabled_train
73
+ for param in self.parameters():
74
+ param.requires_grad = False
75
+
76
+ def forward(self, text):
77
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
78
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
79
+ tokens = batch_encoding["input_ids"].to(self.device)
80
+ outputs = self.transformer(input_ids=tokens)
81
+
82
+ z = outputs.last_hidden_state
83
+ return z
84
+
85
+ def encode(self, text):
86
+ return self(text)
87
+
88
+
89
+ class FrozenCLIPEmbedder(AbstractEncoder):
90
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
91
+ LAYERS = [
92
+ "last",
93
+ "pooled",
94
+ "hidden"
95
+ ]
96
+
97
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
98
+ freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
99
+ super().__init__()
100
+ assert layer in self.LAYERS
101
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
102
+ self.transformer = CLIPTextModel.from_pretrained(version)
103
+ self.device = device
104
+ self.max_length = max_length
105
+ if freeze:
106
+ self.freeze()
107
+ self.layer = layer
108
+ self.layer_idx = layer_idx
109
+ if layer == "hidden":
110
+ assert layer_idx is not None
111
+ assert 0 <= abs(layer_idx) <= 12
112
+
113
+ def freeze(self):
114
+ self.transformer = self.transformer.eval()
115
+ # self.train = disabled_train
116
+ for param in self.parameters():
117
+ param.requires_grad = False
118
+
119
+ def forward(self, text):
120
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
121
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
122
+ tokens = batch_encoding["input_ids"].to(self.device)
123
+ outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
124
+ if self.layer == "last":
125
+ z = outputs.last_hidden_state
126
+ elif self.layer == "pooled":
127
+ z = outputs.pooler_output[:, None, :]
128
+ else:
129
+ z = outputs.hidden_states[self.layer_idx]
130
+ return z
131
+
132
+ def encode(self, text):
133
+ return self(text)
134
+
135
+
136
+ class ClipImageEmbedder(nn.Module):
137
+ def __init__(
138
+ self,
139
+ model,
140
+ jit=False,
141
+ device='cuda' if torch.cuda.is_available() else 'cpu',
142
+ antialias=True,
143
+ ucg_rate=0.
144
+ ):
145
+ super().__init__()
146
+ from clip import load as load_clip
147
+ self.model, _ = load_clip(name=model, device=device, jit=jit)
148
+
149
+ self.antialias = antialias
150
+
151
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
152
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
153
+ self.ucg_rate = ucg_rate
154
+
155
+ def preprocess(self, x):
156
+ # normalize to [0,1]
157
+ x = kornia.geometry.resize(x, (224, 224),
158
+ interpolation='bicubic', align_corners=True,
159
+ antialias=self.antialias)
160
+ x = (x + 1.) / 2.
161
+ # re-normalize according to clip
162
+ x = kornia.enhance.normalize(x, self.mean, self.std)
163
+ return x
164
+
165
+ def forward(self, x, no_dropout=False):
166
+ # x is assumed to be in range [-1,1]
167
+ out = self.model.encode_image(self.preprocess(x))
168
+ out = out.to(x.dtype)
169
+ if self.ucg_rate > 0. and not no_dropout:
170
+ out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
171
+ return out
172
+
173
+
174
+ class FrozenOpenCLIPEmbedder(AbstractEncoder):
175
+ """
176
+ Uses the OpenCLIP transformer encoder for text
177
+ """
178
+ LAYERS = [
179
+ # "pooled",
180
+ "last",
181
+ "penultimate"
182
+ ]
183
+
184
+ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
185
+ freeze=True, layer="last"):
186
+ super().__init__()
187
+ assert layer in self.LAYERS
188
+ model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'))
189
+ del model.visual
190
+ self.model = model
191
+
192
+ self.device = device
193
+ self.max_length = max_length
194
+ if freeze:
195
+ self.freeze()
196
+ self.layer = layer
197
+ if self.layer == "last":
198
+ self.layer_idx = 0
199
+ elif self.layer == "penultimate":
200
+ self.layer_idx = 1
201
+ else:
202
+ raise NotImplementedError()
203
+
204
+ def freeze(self):
205
+ self.model = self.model.eval()
206
+ for param in self.parameters():
207
+ param.requires_grad = False
208
+
209
+ def forward(self, text):
210
+ self.device = self.model.positional_embedding.device
211
+ tokens = open_clip.tokenize(text)
212
+ z = self.encode_with_transformer(tokens.to(self.device))
213
+ return z
214
+
215
+ def encode_with_transformer(self, text):
216
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
217
+ x = x + self.model.positional_embedding
218
+ x = x.permute(1, 0, 2) # NLD -> LND
219
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
220
+ x = x.permute(1, 0, 2) # LND -> NLD
221
+ x = self.model.ln_final(x)
222
+ return x
223
+
224
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
225
+ for i, r in enumerate(self.model.transformer.resblocks):
226
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
227
+ break
228
+ if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
229
+ x = checkpoint(r, x, attn_mask)
230
+ else:
231
+ x = r(x, attn_mask=attn_mask)
232
+ return x
233
+
234
+ def encode(self, text):
235
+ return self(text)
236
+
237
+
238
+ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
239
+ """
240
+ Uses the OpenCLIP vision transformer encoder for images
241
+ """
242
+
243
+ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
244
+ freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
245
+ super().__init__()
246
+ model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
247
+ pretrained=version, )
248
+ del model.transformer
249
+ self.model = model
250
+
251
+ self.device = device
252
+ self.max_length = max_length
253
+ if freeze:
254
+ self.freeze()
255
+ self.layer = layer
256
+ if self.layer == "penultimate":
257
+ raise NotImplementedError()
258
+ self.layer_idx = 1
259
+
260
+ self.antialias = antialias
261
+
262
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
263
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
264
+ self.ucg_rate = ucg_rate
265
+
266
+ def preprocess(self, x):
267
+ # normalize to [0,1]
268
+ x = kornia.geometry.resize(x, (224, 224),
269
+ interpolation='bicubic', align_corners=True,
270
+ antialias=self.antialias)
271
+ x = (x + 1.) / 2.
272
+ # renormalize according to clip
273
+ x = kornia.enhance.normalize(x, self.mean, self.std)
274
+ return x
275
+
276
+ def freeze(self):
277
+ self.model = self.model.eval()
278
+ for param in self.parameters():
279
+ param.requires_grad = False
280
+
281
+ @autocast
282
+ def forward(self, image, no_dropout=False):
283
+ z = self.encode_with_vision_transformer(image)
284
+ if self.ucg_rate > 0. and not no_dropout:
285
+ z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
286
+ return z
287
+
288
+ def encode_with_vision_transformer(self, img):
289
+ img = self.preprocess(img)
290
+ x = self.model.visual(img)
291
+ return x
292
+
293
+ def encode(self, text):
294
+ return self(text)
295
+
296
+
297
+
298
+ class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder):
299
+ """
300
+ Uses the OpenCLIP vision transformer encoder for images
301
+ """
302
+
303
+ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda",
304
+ freeze=True, layer="pooled", antialias=True):
305
+ super().__init__()
306
+ model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
307
+ pretrained=version, )
308
+ del model.transformer
309
+ self.model = model
310
+ self.device = device
311
+
312
+ if freeze:
313
+ self.freeze()
314
+ self.layer = layer
315
+ if self.layer == "penultimate":
316
+ raise NotImplementedError()
317
+ self.layer_idx = 1
318
+
319
+ self.antialias = antialias
320
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
321
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
322
+
323
+
324
+ def preprocess(self, x):
325
+ # normalize to [0,1]
326
+ x = kornia.geometry.resize(x, (224, 224),
327
+ interpolation='bicubic', align_corners=True,
328
+ antialias=self.antialias)
329
+ x = (x + 1.) / 2.
330
+ # renormalize according to clip
331
+ x = kornia.enhance.normalize(x, self.mean, self.std)
332
+ return x
333
+
334
+ def freeze(self):
335
+ self.model = self.model.eval()
336
+ for param in self.model.parameters():
337
+ param.requires_grad = False
338
+
339
+ def forward(self, image, no_dropout=False):
340
+ ## image: b c h w
341
+ z = self.encode_with_vision_transformer(image)
342
+ return z
343
+
344
+ def encode_with_vision_transformer(self, x):
345
+ x = self.preprocess(x)
346
+
347
+ # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
348
+ if self.model.visual.input_patchnorm:
349
+ # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
350
+ x = x.reshape(x.shape[0], x.shape[1], self.model.visual.grid_size[0], self.model.visual.patch_size[0], self.model.visual.grid_size[1], self.model.visual.patch_size[1])
351
+ x = x.permute(0, 2, 4, 1, 3, 5)
352
+ x = x.reshape(x.shape[0], self.model.visual.grid_size[0] * self.model.visual.grid_size[1], -1)
353
+ x = self.model.visual.patchnorm_pre_ln(x)
354
+ x = self.model.visual.conv1(x)
355
+ else:
356
+ x = self.model.visual.conv1(x) # shape = [*, width, grid, grid]
357
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
358
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
359
+
360
+ # class embeddings and positional embeddings
361
+ x = torch.cat(
362
+ [self.model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
363
+ x], dim=1) # shape = [*, grid ** 2 + 1, width]
364
+ x = x + self.model.visual.positional_embedding.to(x.dtype)
365
+
366
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
367
+ x = self.model.visual.patch_dropout(x)
368
+ x = self.model.visual.ln_pre(x)
369
+
370
+ x = x.permute(1, 0, 2) # NLD -> LND
371
+ x = self.model.visual.transformer(x)
372
+ x = x.permute(1, 0, 2) # LND -> NLD
373
+
374
+ return x
375
+
376
+
377
+ class FrozenCLIPT5Encoder(AbstractEncoder):
378
+ def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
379
+ clip_max_length=77, t5_max_length=77):
380
+ super().__init__()
381
+ self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
382
+ self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
383
+ print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
384
+ f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.")
385
+
386
+ def encode(self, text):
387
+ return self(text)
388
+
389
+ def forward(self, text):
390
+ clip_z = self.clip_encoder.encode(text)
391
+ t5_z = self.t5_encoder.encode(text)
392
+ return [clip_z, t5_z]
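For orientation, a minimal sketch of how the frozen OpenCLIP text encoder above produces cross-attention context (the prompt and layer choice are illustrative; note that this class does not pass a pretrained= tag to open_clip, so in practice meaningful weights come from the full VideoCrafter checkpoint):

import torch
from lvdm.modules.encoders.condition import FrozenOpenCLIPEmbedder

text_encoder = FrozenOpenCLIPEmbedder(arch="ViT-H-14", layer="penultimate", freeze=True)
with torch.no_grad():
    context = text_encoder.encode(["a corgi surfing a wave at sunset"])
print(context.shape)   # torch.Size([1, 77, 1024]) -- 77 tokens, ViT-H text width 1024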
lvdm/modules/encoders/ip_resampler.py ADDED
@@ -0,0 +1,136 @@
 
 
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class ImageProjModel(nn.Module):
8
+ """Projection Model"""
9
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
10
+ super().__init__()
11
+ self.cross_attention_dim = cross_attention_dim
12
+ self.clip_extra_context_tokens = clip_extra_context_tokens
13
+ self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
14
+ self.norm = nn.LayerNorm(cross_attention_dim)
15
+
16
+ def forward(self, image_embeds):
17
+ #embeds = image_embeds
18
+ embeds = image_embeds.type(list(self.proj.parameters())[0].dtype)
19
+ clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
20
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
21
+ return clip_extra_context_tokens
22
+
23
+ # FFN
24
+ def FeedForward(dim, mult=4):
25
+ inner_dim = int(dim * mult)
26
+ return nn.Sequential(
27
+ nn.LayerNorm(dim),
28
+ nn.Linear(dim, inner_dim, bias=False),
29
+ nn.GELU(),
30
+ nn.Linear(inner_dim, dim, bias=False),
31
+ )
32
+
33
+
34
+ def reshape_tensor(x, heads):
35
+ bs, length, width = x.shape
36
+ #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
37
+ x = x.view(bs, length, heads, -1)
38
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
39
+ x = x.transpose(1, 2)
40
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
41
+ x = x.reshape(bs, heads, length, -1)
42
+ return x
43
+
44
+
45
+ class PerceiverAttention(nn.Module):
46
+ def __init__(self, *, dim, dim_head=64, heads=8):
47
+ super().__init__()
48
+ self.scale = dim_head**-0.5
49
+ self.dim_head = dim_head
50
+ self.heads = heads
51
+ inner_dim = dim_head * heads
52
+
53
+ self.norm1 = nn.LayerNorm(dim)
54
+ self.norm2 = nn.LayerNorm(dim)
55
+
56
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
57
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
58
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
59
+
60
+
61
+ def forward(self, x, latents):
62
+ """
63
+ Args:
64
+ x (torch.Tensor): image features
65
+ shape (b, n1, D)
66
+ latent (torch.Tensor): latent features
67
+ shape (b, n2, D)
68
+ """
69
+ x = self.norm1(x)
70
+ latents = self.norm2(latents)
71
+
72
+ b, l, _ = latents.shape
73
+
74
+ q = self.to_q(latents)
75
+ kv_input = torch.cat((x, latents), dim=-2)
76
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
77
+
78
+ q = reshape_tensor(q, self.heads)
79
+ k = reshape_tensor(k, self.heads)
80
+ v = reshape_tensor(v, self.heads)
81
+
82
+ # attention
83
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
84
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
85
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
86
+ out = weight @ v
87
+
88
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
89
+
90
+ return self.to_out(out)
91
+
92
+
93
+ class Resampler(nn.Module):
94
+ def __init__(
95
+ self,
96
+ dim=1024,
97
+ depth=8,
98
+ dim_head=64,
99
+ heads=16,
100
+ num_queries=8,
101
+ embedding_dim=768,
102
+ output_dim=1024,
103
+ ff_mult=4,
104
+ ):
105
+ super().__init__()
106
+
107
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
108
+
109
+ self.proj_in = nn.Linear(embedding_dim, dim)
110
+
111
+ self.proj_out = nn.Linear(dim, output_dim)
112
+ self.norm_out = nn.LayerNorm(output_dim)
113
+
114
+ self.layers = nn.ModuleList([])
115
+ for _ in range(depth):
116
+ self.layers.append(
117
+ nn.ModuleList(
118
+ [
119
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
120
+ FeedForward(dim=dim, mult=ff_mult),
121
+ ]
122
+ )
123
+ )
124
+
125
+ def forward(self, x):
126
+
127
+ latents = self.latents.repeat(x.size(0), 1, 1)
128
+
129
+ x = self.proj_in(x)
130
+
131
+ for attn, ff in self.layers:
132
+ latents = attn(x, latents) + latents
133
+ latents = ff(latents) + latents
134
+
135
+ latents = self.proj_out(latents)
136
+ return self.norm_out(latents)
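A hedged shape check for the resampler above (the dimensions are illustrative assumptions, e.g. 257 ViT-H/14 image tokens of width 1280; the actual query count and widths are set by the inference configs):

import torch
from lvdm.modules.encoders.ip_resampler import Resampler

resampler = Resampler(dim=1024, depth=4, dim_head=64, heads=12,
                      num_queries=16, embedding_dim=1280, output_dim=1024)
image_tokens = torch.randn(2, 257, 1280)   # e.g. patch tokens from a CLIP vision tower
context_img = resampler(image_tokens)
print(context_img.shape)                   # torch.Size([2, 16, 1024])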
lvdm/{models/modules/autoencoder_modules.py → modules/networks/ae_modules.py} RENAMED
@@ -1,30 +1,11 @@
 
1
  import math
2
-
3
  import torch
4
  import numpy as np
5
- from torch import nn
6
  from einops import rearrange
7
-
8
-
9
- def get_timestep_embedding(timesteps, embedding_dim):
10
- """
11
- This matches the implementation in Denoising Diffusion Probabilistic Models:
12
- From Fairseq.
13
- Build sinusoidal embeddings.
14
- This matches the implementation in tensor2tensor, but differs slightly
15
- from the description in Section 3.5 of "Attention Is All You Need".
16
- """
17
- assert len(timesteps.shape) == 1
18
-
19
- half_dim = embedding_dim // 2
20
- emb = math.log(10000) / (half_dim - 1)
21
- emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
22
- emb = emb.to(device=timesteps.device)
23
- emb = timesteps.float()[:, None] * emb[None, :]
24
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
25
- if embedding_dim % 2 == 1: # zero pad
26
- emb = torch.nn.functional.pad(emb, (0,1,0,0))
27
- return emb
28
 
29
  def nonlinearity(x):
30
  # swish
@@ -36,25 +17,6 @@ def Normalize(in_channels, num_groups=32):
36
 
37
 
38
 
39
- class LinearAttention(nn.Module):
40
- def __init__(self, dim, heads=4, dim_head=32):
41
- super().__init__()
42
- self.heads = heads
43
- hidden_dim = dim_head * heads
44
- self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
45
- self.to_out = nn.Conv2d(hidden_dim, dim, 1)
46
-
47
- def forward(self, x):
48
- b, c, h, w = x.shape
49
- qkv = self.to_qkv(x)
50
- q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
51
- k = k.softmax(dim=-1)
52
- context = torch.einsum('bhdn,bhen->bhde', k, v)
53
- out = torch.einsum('bhde,bhdn->bhen', context, q)
54
- out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
55
- return self.to_out(out)
56
-
57
-
58
  class LinAttnBlock(LinearAttention):
59
  """to match AttnBlock usage"""
60
  def __init__(self, in_channels):
@@ -115,10 +77,9 @@ class AttnBlock(nn.Module):
115
 
116
  return x+h_
117
 
118
-
119
  def make_attn(in_channels, attn_type="vanilla"):
120
  assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
121
- print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
122
  if attn_type == "vanilla":
123
  return AttnBlock(in_channels)
124
  elif attn_type == "none":
@@ -165,6 +126,27 @@ class Upsample(nn.Module):
165
  x = self.conv(x)
166
  return x
167
 
 
 
168
 
169
  class ResnetBlock(nn.Module):
170
  def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
@@ -502,7 +484,7 @@ class Decoder(nn.Module):
502
  block_in = ch*ch_mult[self.num_resolutions-1]
503
  curr_res = resolution // 2**(self.num_resolutions-1)
504
  self.z_shape = (1,z_channels,curr_res,curr_res)
505
- print("Working with z of shape {} = {} dimensions.".format(
506
  self.z_shape, np.prod(self.z_shape)))
507
 
508
  # z to block_in
@@ -594,3 +576,270 @@ class Decoder(nn.Module):
594
  if self.tanh_out:
595
  h = torch.tanh(h)
596
  return h
 
 
1
+ # pytorch_diffusion + derived encoder decoder
2
  import math
 
3
  import torch
4
  import numpy as np
5
+ import torch.nn as nn
6
  from einops import rearrange
7
+ from utils.utils import instantiate_from_config
8
+ from lvdm.modules.attention import LinearAttention
 
 
9
 
10
  def nonlinearity(x):
11
  # swish
 
17
 
18
 
19
 
 
 
20
  class LinAttnBlock(LinearAttention):
21
  """to match AttnBlock usage"""
22
  def __init__(self, in_channels):
 
77
 
78
  return x+h_
79
 
 
80
  def make_attn(in_channels, attn_type="vanilla"):
81
  assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
82
+ #print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
83
  if attn_type == "vanilla":
84
  return AttnBlock(in_channels)
85
  elif attn_type == "none":
 
126
  x = self.conv(x)
127
  return x
128
 
129
+ def get_timestep_embedding(timesteps, embedding_dim):
130
+ """
131
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
132
+ From Fairseq.
133
+ Build sinusoidal embeddings.
134
+ This matches the implementation in tensor2tensor, but differs slightly
135
+ from the description in Section 3.5 of "Attention Is All You Need".
136
+ """
137
+ assert len(timesteps.shape) == 1
138
+
139
+ half_dim = embedding_dim // 2
140
+ emb = math.log(10000) / (half_dim - 1)
141
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
142
+ emb = emb.to(device=timesteps.device)
143
+ emb = timesteps.float()[:, None] * emb[None, :]
144
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
145
+ if embedding_dim % 2 == 1: # zero pad
146
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
147
+ return emb
148
+
149
+
150
 
151
  class ResnetBlock(nn.Module):
152
  def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
 
484
  block_in = ch*ch_mult[self.num_resolutions-1]
485
  curr_res = resolution // 2**(self.num_resolutions-1)
486
  self.z_shape = (1,z_channels,curr_res,curr_res)
487
+ print("AE working on z of shape {} = {} dimensions.".format(
488
  self.z_shape, np.prod(self.z_shape)))
489
 
490
  # z to block_in
 
576
  if self.tanh_out:
577
  h = torch.tanh(h)
578
  return h
579
+
580
+
581
+ class SimpleDecoder(nn.Module):
582
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
583
+ super().__init__()
584
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
585
+ ResnetBlock(in_channels=in_channels,
586
+ out_channels=2 * in_channels,
587
+ temb_channels=0, dropout=0.0),
588
+ ResnetBlock(in_channels=2 * in_channels,
589
+ out_channels=4 * in_channels,
590
+ temb_channels=0, dropout=0.0),
591
+ ResnetBlock(in_channels=4 * in_channels,
592
+ out_channels=2 * in_channels,
593
+ temb_channels=0, dropout=0.0),
594
+ nn.Conv2d(2*in_channels, in_channels, 1),
595
+ Upsample(in_channels, with_conv=True)])
596
+ # end
597
+ self.norm_out = Normalize(in_channels)
598
+ self.conv_out = torch.nn.Conv2d(in_channels,
599
+ out_channels,
600
+ kernel_size=3,
601
+ stride=1,
602
+ padding=1)
603
+
604
+ def forward(self, x):
605
+ for i, layer in enumerate(self.model):
606
+ if i in [1,2,3]:
607
+ x = layer(x, None)
608
+ else:
609
+ x = layer(x)
610
+
611
+ h = self.norm_out(x)
612
+ h = nonlinearity(h)
613
+ x = self.conv_out(h)
614
+ return x
615
+
616
+
617
+ class UpsampleDecoder(nn.Module):
618
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
619
+ ch_mult=(2,2), dropout=0.0):
620
+ super().__init__()
621
+ # upsampling
622
+ self.temb_ch = 0
623
+ self.num_resolutions = len(ch_mult)
624
+ self.num_res_blocks = num_res_blocks
625
+ block_in = in_channels
626
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
627
+ self.res_blocks = nn.ModuleList()
628
+ self.upsample_blocks = nn.ModuleList()
629
+ for i_level in range(self.num_resolutions):
630
+ res_block = []
631
+ block_out = ch * ch_mult[i_level]
632
+ for i_block in range(self.num_res_blocks + 1):
633
+ res_block.append(ResnetBlock(in_channels=block_in,
634
+ out_channels=block_out,
635
+ temb_channels=self.temb_ch,
636
+ dropout=dropout))
637
+ block_in = block_out
638
+ self.res_blocks.append(nn.ModuleList(res_block))
639
+ if i_level != self.num_resolutions - 1:
640
+ self.upsample_blocks.append(Upsample(block_in, True))
641
+ curr_res = curr_res * 2
642
+
643
+ # end
644
+ self.norm_out = Normalize(block_in)
645
+ self.conv_out = torch.nn.Conv2d(block_in,
646
+ out_channels,
647
+ kernel_size=3,
648
+ stride=1,
649
+ padding=1)
650
+
651
+ def forward(self, x):
652
+ # upsampling
653
+ h = x
654
+ for k, i_level in enumerate(range(self.num_resolutions)):
655
+ for i_block in range(self.num_res_blocks + 1):
656
+ h = self.res_blocks[i_level][i_block](h, None)
657
+ if i_level != self.num_resolutions - 1:
658
+ h = self.upsample_blocks[k](h)
659
+ h = self.norm_out(h)
660
+ h = nonlinearity(h)
661
+ h = self.conv_out(h)
662
+ return h
663
+
664
+
665
+ class LatentRescaler(nn.Module):
666
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
667
+ super().__init__()
668
+ # residual block, interpolate, residual block
669
+ self.factor = factor
670
+ self.conv_in = nn.Conv2d(in_channels,
671
+ mid_channels,
672
+ kernel_size=3,
673
+ stride=1,
674
+ padding=1)
675
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
676
+ out_channels=mid_channels,
677
+ temb_channels=0,
678
+ dropout=0.0) for _ in range(depth)])
679
+ self.attn = AttnBlock(mid_channels)
680
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
681
+ out_channels=mid_channels,
682
+ temb_channels=0,
683
+ dropout=0.0) for _ in range(depth)])
684
+
685
+ self.conv_out = nn.Conv2d(mid_channels,
686
+ out_channels,
687
+ kernel_size=1,
688
+ )
689
+
690
+ def forward(self, x):
691
+ x = self.conv_in(x)
692
+ for block in self.res_block1:
693
+ x = block(x, None)
694
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
695
+ x = self.attn(x)
696
+ for block in self.res_block2:
697
+ x = block(x, None)
698
+ x = self.conv_out(x)
699
+ return x
700
+
701
+
702
+ class MergedRescaleEncoder(nn.Module):
703
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
704
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
705
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
706
+ super().__init__()
707
+ intermediate_chn = ch * ch_mult[-1]
708
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
709
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
710
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
711
+ out_ch=None)
712
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
713
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
714
+
715
+ def forward(self, x):
716
+ x = self.encoder(x)
717
+ x = self.rescaler(x)
718
+ return x
719
+
720
+
721
+ class MergedRescaleDecoder(nn.Module):
722
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
723
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
724
+ super().__init__()
725
+ tmp_chn = z_channels*ch_mult[-1]
726
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
727
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
728
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
729
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
730
+ out_channels=tmp_chn, depth=rescale_module_depth)
731
+
732
+ def forward(self, x):
733
+ x = self.rescaler(x)
734
+ x = self.decoder(x)
735
+ return x
736
+
737
+
738
+ class Upsampler(nn.Module):
739
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
740
+ super().__init__()
741
+ assert out_size >= in_size
742
+ num_blocks = int(np.log2(out_size//in_size))+1
743
+ factor_up = 1.+ (out_size % in_size)
744
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
745
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
746
+ out_channels=in_channels)
747
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
748
+ attn_resolutions=[], in_channels=None, ch=in_channels,
749
+ ch_mult=[ch_mult for _ in range(num_blocks)])
750
+
751
+ def forward(self, x):
752
+ x = self.rescaler(x)
753
+ x = self.decoder(x)
754
+ return x
755
+
756
+
757
+ class Resize(nn.Module):
758
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
759
+ super().__init__()
760
+ self.with_conv = learned
761
+ self.mode = mode
762
+ if self.with_conv:
763
+ print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
764
+ raise NotImplementedError()
765
+ assert in_channels is not None
766
+ # no asymmetric padding in torch conv, must do it ourselves
767
+ self.conv = torch.nn.Conv2d(in_channels,
768
+ in_channels,
769
+ kernel_size=4,
770
+ stride=2,
771
+ padding=1)
772
+
773
+ def forward(self, x, scale_factor=1.0):
774
+ if scale_factor==1.0:
775
+ return x
776
+ else:
777
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
778
+ return x
779
+
780
+ class FirstStagePostProcessor(nn.Module):
781
+
782
+ def __init__(self, ch_mult:list, in_channels,
783
+ pretrained_model:nn.Module=None,
784
+ reshape=False,
785
+ n_channels=None,
786
+ dropout=0.,
787
+ pretrained_config=None):
788
+ super().__init__()
789
+ if pretrained_config is None:
790
+ assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
791
+ self.pretrained_model = pretrained_model
792
+ else:
793
+ assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
794
+ self.instantiate_pretrained(pretrained_config)
795
+
796
+ self.do_reshape = reshape
797
+
798
+ if n_channels is None:
799
+ n_channels = self.pretrained_model.encoder.ch
800
+
801
+ self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
802
+ self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
803
+ stride=1,padding=1)
804
+
805
+ blocks = []
806
+ downs = []
807
+ ch_in = n_channels
808
+ for m in ch_mult:
809
+ blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
810
+ ch_in = m * n_channels
811
+ downs.append(Downsample(ch_in, with_conv=False))
812
+
813
+ self.model = nn.ModuleList(blocks)
814
+ self.downsampler = nn.ModuleList(downs)
815
+
816
+
817
+ def instantiate_pretrained(self, config):
818
+ model = instantiate_from_config(config)
819
+ self.pretrained_model = model.eval()
820
+ # self.pretrained_model.train = False
821
+ for param in self.pretrained_model.parameters():
822
+ param.requires_grad = False
823
+
824
+
825
+ @torch.no_grad()
826
+ def encode_with_pretrained(self,x):
827
+ c = self.pretrained_model.encode(x)
828
+ if isinstance(c, DiagonalGaussianDistribution):
829
+ c = c.mode()
830
+ return c
831
+
832
+ def forward(self,x):
833
+ z_fs = self.encode_with_pretrained(x)
834
+ z = self.proj_norm(z_fs)
835
+ z = self.proj(z)
836
+ z = nonlinearity(z)
837
+
838
+ for submodel, downmodel in zip(self.model,self.downsampler):
839
+ z = submodel(z,temb=None)
840
+ z = downmodel(z)
841
+
842
+ if self.do_reshape:
843
+ z = rearrange(z,'b c h w -> b (h w) c')
844
+ return z
845
+
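And a quick, illustrative check of the sinusoidal timestep embedding that this rename moves into ae_modules.py (the timesteps and embedding width are arbitrary):

import torch
from lvdm.modules.networks.ae_modules import get_timestep_embedding

t = torch.tensor([0, 10, 999])        # diffusion timesteps
emb = get_timestep_embedding(t, embedding_dim=320)
print(emb.shape)                      # torch.Size([3, 320]) -- sin and cos halves concatenated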
lvdm/{models/modules → modules/networks}/openaimodel3d.py RENAMED
@@ -1,38 +1,25 @@
1
- from abc import abstractmethod
2
- import math
3
- from einops import rearrange
4
  from functools import partial
5
- import numpy as np
6
- import torch as th
7
  import torch.nn as nn
 
8
  import torch.nn.functional as F
9
- from omegaconf.listconfig import ListConfig
10
-
11
- from lvdm.models.modules.util import (
12
- checkpoint,
13
  conv_nd,
14
  linear,
15
  avg_pool_nd,
16
- zero_module,
17
- normalization,
18
- timestep_embedding,
19
- nonlinearity,
20
  )
 
21
 
22
- # dummy replace
23
- def convert_module_to_f16(x):
24
- pass
25
-
26
- def convert_module_to_f32(x):
27
- pass
28
 
29
- ## go
30
- # ---------------------------------------------------------------------------------------------------
31
  class TimestepBlock(nn.Module):
32
  """
33
  Any module where forward() takes timestep embeddings as a second argument.
34
  """
35
-
36
  @abstractmethod
37
  def forward(self, x, emb):
38
  """
@@ -40,107 +27,85 @@ class TimestepBlock(nn.Module):
40
  """
41
 
42
 
43
- # ---------------------------------------------------------------------------------------------------
44
  class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
45
  """
46
  A sequential module that passes timestep embeddings to the children that
47
  support it as an extra input.
48
  """
49
 
50
- def forward(self, x, emb, context, **kwargs):
51
  for layer in self:
52
  if isinstance(layer, TimestepBlock):
53
- x = layer(x, emb, **kwargs)
54
- elif isinstance(layer, STTransformerClass):
55
- x = layer(x, context, **kwargs)
 
 
 
 
56
  else:
57
- x = layer(x)
58
  return x
59
 
60
 
61
- # ---------------------------------------------------------------------------------------------------
62
- class Upsample(nn.Module):
63
  """
64
- An upsampling layer with an optional convolution.
65
  :param channels: channels in the inputs and outputs.
66
  :param use_conv: a bool determining if a convolution is applied.
67
  :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
68
- upsampling occurs in the inner-two dimensions.
69
  """
70
 
71
- def __init__(self, channels, use_conv, dims=2, out_channels=None,
72
- kernel_size_t=3,
73
- padding_t=1,
74
- ):
75
  super().__init__()
76
  self.channels = channels
77
  self.out_channels = out_channels or channels
78
  self.use_conv = use_conv
79
  self.dims = dims
 
80
  if use_conv:
81
- self.conv = conv_nd(dims, self.channels, self.out_channels, (kernel_size_t, 3,3), padding=(padding_t, 1,1))
82
-
83
- def forward(self, x):
84
- assert x.shape[1] == self.channels
85
- if self.dims == 3:
86
- x = F.interpolate(
87
- x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
88
  )
89
  else:
90
- x = F.interpolate(x, scale_factor=2, mode="nearest")
91
- if self.use_conv:
92
- x = self.conv(x)
93
- return x
94
-
95
-
96
- # ---------------------------------------------------------------------------------------------------
97
- class TransposedUpsample(nn.Module):
98
- 'Learned 2x upsampling without padding'
99
- def __init__(self, channels, out_channels=None, ks=5):
100
- super().__init__()
101
- self.channels = channels
102
- self.out_channels = out_channels or channels
103
-
104
- self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
105
 
106
- def forward(self,x):
107
- return self.up(x)
 
108
 
109
 
110
- # ---------------------------------------------------------------------------------------------------
111
- class Downsample(nn.Module):
112
  """
113
- A downsampling layer with an optional convolution.
114
  :param channels: channels in the inputs and outputs.
115
  :param use_conv: a bool determining if a convolution is applied.
116
  :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
117
- downsampling occurs in the inner-two dimensions.
118
  """
119
 
120
- def __init__(self, channels, use_conv, dims=2, out_channels=None,
121
- kernel_size_t=3,
122
- padding_t=1,
123
- ):
124
  super().__init__()
125
  self.channels = channels
126
  self.out_channels = out_channels or channels
127
  self.use_conv = use_conv
128
  self.dims = dims
129
- stride = 2 if dims != 3 else (1, 2, 2)
130
  if use_conv:
131
- self.op = conv_nd(
132
- dims, self.channels, self.out_channels, (kernel_size_t, 3,3), stride=stride, padding=(padding_t, 1,1)
133
- )
134
- else:
135
- assert self.channels == self.out_channels
136
- self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
137
 
138
  def forward(self, x):
139
  assert x.shape[1] == self.channels
140
- return self.op(x)
 
 
 
 
 
 
141
 
142
 
143
- # ---------------------------------------------------------------------------------------------------
144
  class ResBlock(TimestepBlock):
145
  """
146
  A residual block that can optionally change the number of channels.
@@ -152,7 +117,6 @@ class ResBlock(TimestepBlock):
152
  convolution instead of a smaller 1x1 convolution to change the
153
  channels in the skip connection.
154
  :param dims: determines if the signal is 1D, 2D, or 3D.
155
- :param use_checkpoint: if True, use gradient checkpointing on this module.
156
  :param up: if True, use this block for upsampling.
157
  :param down: if True, use this block for downsampling.
158
  """
@@ -163,17 +127,14 @@ class ResBlock(TimestepBlock):
163
  emb_channels,
164
  dropout,
165
  out_channels=None,
166
- use_conv=False,
167
  use_scale_shift_norm=False,
168
  dims=2,
169
  use_checkpoint=False,
 
170
  up=False,
171
  down=False,
172
- # temporal
173
- kernel_size_t=3,
174
- padding_t=1,
175
- nonlinearity_type='silu',
176
- **kwargs
177
  ):
178
  super().__init__()
179
  self.channels = channels
@@ -183,65 +144,68 @@ class ResBlock(TimestepBlock):
183
  self.use_conv = use_conv
184
  self.use_checkpoint = use_checkpoint
185
  self.use_scale_shift_norm = use_scale_shift_norm
186
- self.nonlinearity_type = nonlinearity_type
187
 
188
  self.in_layers = nn.Sequential(
189
  normalization(channels),
190
- nonlinearity(nonlinearity_type),
191
- conv_nd(dims, channels, self.out_channels, (kernel_size_t, 3,3), padding=(padding_t, 1,1)),
192
  )
193
 
194
  self.updown = up or down
195
 
196
  if up:
197
- self.h_upd = Upsample(channels, False, dims, kernel_size_t=kernel_size_t, padding_t=padding_t)
198
- self.x_upd = Upsample(channels, False, dims, kernel_size_t=kernel_size_t, padding_t=padding_t)
199
  elif down:
200
- self.h_upd = Downsample(channels, False, dims, kernel_size_t=kernel_size_t, padding_t=padding_t)
201
- self.x_upd = Downsample(channels, False, dims, kernel_size_t=kernel_size_t, padding_t=padding_t)
202
  else:
203
  self.h_upd = self.x_upd = nn.Identity()
204
 
205
  self.emb_layers = nn.Sequential(
206
- nonlinearity(nonlinearity_type),
207
- linear(
208
  emb_channels,
209
  2 * self.out_channels if use_scale_shift_norm else self.out_channels,
210
  ),
211
  )
212
  self.out_layers = nn.Sequential(
213
  normalization(self.out_channels),
214
- nonlinearity(nonlinearity_type),
215
  nn.Dropout(p=dropout),
216
- zero_module(
217
- conv_nd(dims, self.out_channels, self.out_channels, (kernel_size_t, 3,3), padding=(padding_t, 1,1))
218
- ),
219
  )
220
 
221
  if self.out_channels == channels:
222
  self.skip_connection = nn.Identity()
223
  elif use_conv:
224
- self.skip_connection = conv_nd(
225
- dims, channels, self.out_channels, (kernel_size_t, 3,3), padding=(padding_t, 1,1)
226
- )
227
  else:
228
  self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
229
-
230
 
231
- def forward(self, x, emb, **kwargs):
 
 
 
 
 
 
 
 
232
  """
233
  Apply the block to a Tensor, conditioned on a timestep embedding.
234
  :param x: an [N x C x ...] Tensor of features.
235
  :param emb: an [N x emb_channels] Tensor of timestep embeddings.
236
  :return: an [N x C x ...] Tensor of outputs.
237
  """
238
- return checkpoint(self._forward,
239
- (x, emb),
240
- self.parameters(),
241
- self.use_checkpoint
242
- )
243
 
244
- def _forward(self, x, emb,):
245
  if self.updown:
246
  in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
247
  h = in_rest(x)
@@ -250,38 +214,72 @@ class ResBlock(TimestepBlock):
250
  h = in_conv(h)
251
  else:
252
  h = self.in_layers(x)
253
-
254
  emb_out = self.emb_layers(emb).type(h.dtype)
255
- if emb_out.dim() == 3: # btc for video data
256
- emb_out = rearrange(emb_out, 'b t c -> b c t')
257
- while len(emb_out.shape) < h.dim():
258
- emb_out = emb_out[..., None] # bct -> bct11 or bc -> bc111
259
-
260
  if self.use_scale_shift_norm:
261
  out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
262
- scale, shift = th.chunk(emb_out, 2, dim=1)
263
  h = out_norm(h) * (1 + scale) + shift
264
  h = out_rest(h)
265
  else:
266
  h = h + emb_out
267
  h = self.out_layers(h)
268
-
269
- out = self.skip_connection(x) + h
270
 
271
- return out
 
 
 
272
 
273
- # ---------------------------------------------------------------------------------------------------
274
- def make_spatialtemporal_transformer(module_name='attention_temporal', class_name='SpatialTemporalTransformer'):
275
- module = __import__(f"lvdm.models.modules.{module_name}", fromlist=[class_name])
276
- global STTransformerClass
277
- STTransformerClass = getattr(module, class_name)
278
- return STTransformerClass
279
 
280
- # ---------------------------------------------------------------------------------------------------
281
  class UNetModel(nn.Module):
282
  """
283
  The full UNet model with attention and timestep embedding.
284
- :param in_channels: channels in the input Tensor.
285
  :param model_channels: base channel count for the model.
286
  :param out_channels: channels in the output Tensor.
287
  :param num_res_blocks: number of residual blocks per downsample.
@@ -304,67 +302,45 @@ class UNetModel(nn.Module):
304
  of heads for upsampling. Deprecated.
305
  :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
306
  :param resblock_updown: use residual blocks for up/downsampling.
307
- :param use_new_attention_order: use a different attention pattern for potentially
308
- increased efficiency.
309
  """
310
 
311
- def __init__(
312
- self,
313
- image_size, # not used in UNetModel
314
- in_channels,
315
- model_channels,
316
- out_channels,
317
- num_res_blocks,
318
- attention_resolutions,
319
- dropout=0,
320
- channel_mult=(1, 2, 4, 8),
321
- conv_resample=True,
322
- dims=3,
323
- num_classes=None,
324
- use_checkpoint=False,
325
- use_fp16=False,
326
- num_heads=-1,
327
- num_head_channels=-1,
328
- num_heads_upsample=-1,
329
- use_scale_shift_norm=False,
330
- resblock_updown=False,
331
- transformer_depth=1, # custom transformer support
332
- context_dim=None, # custom transformer support
333
- legacy=True,
334
- # temporal related
335
- kernel_size_t=1,
336
- padding_t=1,
337
- use_temporal_transformer=True,
338
- temporal_length=None,
339
- use_relative_position=False,
340
- cross_attn_on_tempoal=False,
341
- temporal_crossattn_type="crossattn",
342
- order="stst",
343
- nonlinearity_type='silu',
344
- temporalcrossfirst=False,
345
- split_stcontext=False,
346
- temporal_context_dim=None,
347
- use_tempoal_causal_attn=False,
348
- ST_transformer_module='attention_temporal',
349
- ST_transformer_class='SpatialTemporalTransformer',
350
- **kwargs,
351
- ):
352
- super().__init__()
353
- assert(use_temporal_transformer)
354
- if context_dim is not None:
355
- if type(context_dim) == ListConfig:
356
- context_dim = list(context_dim)
357
-
358
- if num_heads_upsample == -1:
359
- num_heads_upsample = num_heads
360
-
361
  if num_heads == -1:
362
  assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
363
-
364
  if num_head_channels == -1:
365
  assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
366
 
367
- self.image_size = image_size
368
  self.in_channels = in_channels
369
  self.model_channels = model_channels
370
  self.out_channels = out_channels
@@ -373,65 +349,55 @@ class UNetModel(nn.Module):
373
  self.dropout = dropout
374
  self.channel_mult = channel_mult
375
  self.conv_resample = conv_resample
376
- self.num_classes = num_classes
 
377
  self.use_checkpoint = use_checkpoint
378
- self.dtype = th.float16 if use_fp16 else th.float32
379
- self.num_heads = num_heads
380
- self.num_head_channels = num_head_channels
381
- self.num_heads_upsample = num_heads_upsample
382
-
383
- self.use_relative_position = use_relative_position
384
- self.temporal_length = temporal_length
385
- self.cross_attn_on_tempoal = cross_attn_on_tempoal
386
- self.temporal_crossattn_type = temporal_crossattn_type
387
- self.order = order
388
- self.temporalcrossfirst = temporalcrossfirst
389
- self.split_stcontext = split_stcontext
390
- self.temporal_context_dim = temporal_context_dim
391
- self.nonlinearity_type = nonlinearity_type
392
- self.use_tempoal_causal_attn = use_tempoal_causal_attn
393
-
394
 
395
- time_embed_dim = model_channels * 4
396
- self.time_embed_dim = time_embed_dim
397
  self.time_embed = nn.Sequential(
398
  linear(model_channels, time_embed_dim),
399
- nonlinearity(nonlinearity_type),
400
  linear(time_embed_dim, time_embed_dim),
401
  )
402
-
403
- if self.num_classes is not None:
404
- self.label_emb = nn.Embedding(num_classes, time_embed_dim)
405
-
406
- STTransformerClass = make_spatialtemporal_transformer(module_name=ST_transformer_module,
407
- class_name=ST_transformer_class)
408
 
409
  self.input_blocks = nn.ModuleList(
410
  [
411
- TimestepEmbedSequential(
412
- conv_nd(dims, in_channels, model_channels, (kernel_size_t, 3,3), padding=(padding_t, 1,1))
413
- )
414
  ]
415
  )
416
- self._feature_size = model_channels
 
 
 
 
 
 
 
 
 
 
 
417
  input_block_chans = [model_channels]
418
  ch = model_channels
419
  ds = 1
420
  for level, mult in enumerate(channel_mult):
421
  for _ in range(num_res_blocks):
422
  layers = [
423
- ResBlock(
424
- ch,
425
- time_embed_dim,
426
- dropout,
427
- out_channels=mult * model_channels,
428
- dims=dims,
429
- use_checkpoint=use_checkpoint,
430
- use_scale_shift_norm=use_scale_shift_norm,
431
- kernel_size_t=kernel_size_t,
432
- padding_t=padding_t,
433
- nonlinearity_type=nonlinearity_type,
434
- **kwargs
435
  )
436
  ]
437
  ch = mult * model_channels
@@ -441,120 +407,85 @@ class UNetModel(nn.Module):
441
  else:
442
  num_heads = ch // num_head_channels
443
  dim_head = num_head_channels
444
- if legacy:
445
- dim_head = ch // num_heads if use_temporal_transformer else num_head_channels
446
- layers.append(STTransformerClass(
447
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
448
- # temporal related
449
- temporal_length=temporal_length,
450
- use_relative_position=use_relative_position,
451
- cross_attn_on_tempoal=cross_attn_on_tempoal,
452
- temporal_crossattn_type=temporal_crossattn_type,
453
- order=order,
454
- temporalcrossfirst=temporalcrossfirst,
455
- split_stcontext=split_stcontext,
456
- temporal_context_dim=temporal_context_dim,
457
- use_tempoal_causal_attn=use_tempoal_causal_attn,
458
- **kwargs,
459
- ))
460
  self.input_blocks.append(TimestepEmbedSequential(*layers))
461
- self._feature_size += ch
462
  input_block_chans.append(ch)
463
  if level != len(channel_mult) - 1:
464
  out_ch = ch
465
  self.input_blocks.append(
466
  TimestepEmbedSequential(
467
- ResBlock(
468
- ch,
469
- time_embed_dim,
470
- dropout,
471
- out_channels=out_ch,
472
- dims=dims,
473
- use_checkpoint=use_checkpoint,
474
  use_scale_shift_norm=use_scale_shift_norm,
475
- down=True,
476
- kernel_size_t=kernel_size_t,
477
- padding_t=padding_t,
478
- nonlinearity_type=nonlinearity_type,
479
- **kwargs
480
  )
481
  if resblock_updown
482
- else Downsample(
483
- ch, conv_resample, dims=dims, out_channels=out_ch, kernel_size_t=kernel_size_t, padding_t=padding_t
484
- )
485
  )
486
  )
487
  ch = out_ch
488
  input_block_chans.append(ch)
489
  ds *= 2
490
- self._feature_size += ch
491
 
492
  if num_head_channels == -1:
493
  dim_head = ch // num_heads
494
  else:
495
  num_heads = ch // num_head_channels
496
  dim_head = num_head_channels
497
- if legacy:
498
- dim_head = ch // num_heads if use_temporal_transformer else num_head_channels
499
- self.middle_block = TimestepEmbedSequential(
500
- ResBlock(
501
- ch,
502
- time_embed_dim,
503
- dropout,
504
- dims=dims,
505
- use_checkpoint=use_checkpoint,
506
- use_scale_shift_norm=use_scale_shift_norm,
507
- kernel_size_t=kernel_size_t,
508
- padding_t=padding_t,
509
- nonlinearity_type=nonlinearity_type,
510
- **kwargs
511
- ),
512
- STTransformerClass(
513
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
514
- # temporal related
515
- temporal_length=temporal_length,
516
- use_relative_position=use_relative_position,
517
- cross_attn_on_tempoal=cross_attn_on_tempoal,
518
- temporal_crossattn_type=temporal_crossattn_type,
519
- order=order,
520
- temporalcrossfirst=temporalcrossfirst,
521
- split_stcontext=split_stcontext,
522
- temporal_context_dim=temporal_context_dim,
523
- use_tempoal_causal_attn=use_tempoal_causal_attn,
524
- **kwargs,
525
- ),
526
- ResBlock(
527
- ch,
528
- time_embed_dim,
529
- dropout,
530
- dims=dims,
531
- use_checkpoint=use_checkpoint,
532
- use_scale_shift_norm=use_scale_shift_norm,
533
- kernel_size_t=kernel_size_t,
534
- padding_t=padding_t,
535
- nonlinearity_type=nonlinearity_type,
536
- **kwargs
537
  ),
538
  )
539
- self._feature_size += ch
540
 
541
  self.output_blocks = nn.ModuleList([])
542
  for level, mult in list(enumerate(channel_mult))[::-1]:
543
  for i in range(num_res_blocks + 1):
544
  ich = input_block_chans.pop()
545
  layers = [
546
- ResBlock(
547
- ch + ich,
548
- time_embed_dim,
549
- dropout,
550
- out_channels=model_channels * mult,
551
- dims=dims,
552
- use_checkpoint=use_checkpoint,
553
- use_scale_shift_norm=use_scale_shift_norm,
554
- kernel_size_t=kernel_size_t,
555
- padding_t=padding_t,
556
- nonlinearity_type=nonlinearity_type,
557
- **kwargs
558
  )
559
  ]
560
  ch = model_channels * mult
@@ -564,107 +495,83 @@ class UNetModel(nn.Module):
564
  else:
565
  num_heads = ch // num_head_channels
566
  dim_head = num_head_channels
567
- if legacy:
568
- dim_head = ch // num_heads if use_temporal_transformer else num_head_channels
569
  layers.append(
570
- STTransformerClass(
571
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
572
- # temporal related
573
- temporal_length=temporal_length,
574
- use_relative_position=use_relative_position,
575
- cross_attn_on_tempoal=cross_attn_on_tempoal,
576
- temporal_crossattn_type=temporal_crossattn_type,
577
- order=order,
578
- temporalcrossfirst=temporalcrossfirst,
579
- split_stcontext=split_stcontext,
580
- temporal_context_dim=temporal_context_dim,
581
- use_tempoal_causal_attn=use_tempoal_causal_attn,
582
- **kwargs,
583
  )
584
  )
585
  if level and i == num_res_blocks:
586
  out_ch = ch
587
  layers.append(
588
- ResBlock(
589
- ch,
590
- time_embed_dim,
591
- dropout,
592
- out_channels=out_ch,
593
- dims=dims,
594
- use_checkpoint=use_checkpoint,
595
  use_scale_shift_norm=use_scale_shift_norm,
596
- up=True,
597
- kernel_size_t=kernel_size_t,
598
- padding_t=padding_t,
599
- nonlinearity_type=nonlinearity_type,
600
- **kwargs
601
  )
602
  if resblock_updown
603
- else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, kernel_size_t=kernel_size_t, padding_t=padding_t)
604
  )
605
  ds //= 2
606
  self.output_blocks.append(TimestepEmbedSequential(*layers))
607
- self._feature_size += ch
608
 
609
  self.out = nn.Sequential(
610
  normalization(ch),
611
- nonlinearity(nonlinearity_type),
612
- zero_module(conv_nd(dims, model_channels, out_channels, (kernel_size_t, 3,3), padding=(padding_t, 1,1))),
613
  )
614
-
615
 
616
- def convert_to_fp16(self):
617
- """
618
- Convert the torso of the model to float16.
619
- """
620
- self.input_blocks.apply(convert_module_to_f16)
621
- self.middle_block.apply(convert_module_to_f16)
622
- self.output_blocks.apply(convert_module_to_f16)
623
 
624
- def convert_to_fp32(self):
625
- """
626
- Convert the torso of the model to float32.
627
- """
628
- self.input_blocks.apply(convert_module_to_f32)
629
- self.middle_block.apply(convert_module_to_f32)
630
- self.output_blocks.apply(convert_module_to_f32)
631
 
632
- def forward(self, x, timesteps=None, time_emb_replace=None, context=None, features_adapter=None, y=None, **kwargs):
633
- """
634
- Apply the model to an input batch.
635
- :param x: an [N x C x ...] Tensor of inputs.
636
- :param timesteps: a 1-D batch of timesteps.
637
- :param context: conditioning plugged in via crossattn
638
- :param y: an [N] Tensor of labels, if class-conditional.
639
- :return: an [N x C x ...] Tensor of outputs.
640
- """
641
-
642
- hs = []
643
- if time_emb_replace is None:
644
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
645
- emb = self.time_embed(t_emb)
646
- else:
647
- emb = time_emb_replace
648
-
649
- if y is not None: # if class-conditional model, inject class labels
650
- assert y.shape == (x.shape[0],)
651
- emb = emb + self.label_emb(y)
652
 
653
  h = x.type(self.dtype)
654
  adapter_idx = 0
 
655
  for id, module in enumerate(self.input_blocks):
656
- h = module(h, emb, context, **kwargs)
 
 
657
  ## plug-in adapter features
658
  if ((id+1)%3 == 0) and features_adapter is not None:
659
  h = h + features_adapter[adapter_idx]
660
  adapter_idx += 1
661
  hs.append(h)
662
  if features_adapter is not None:
663
- assert len(features_adapter)==adapter_idx, 'Mismatch features adapter'
664
 
665
- h = self.middle_block(h, emb, context, **kwargs)
666
  for module in self.output_blocks:
667
- h = th.cat([h, hs.pop()], dim=1)
668
- h = module(h, emb, context, **kwargs)
669
  h = h.type(x.dtype)
670
- return self.out(h)
1
  from functools import partial
2
+ from abc import abstractmethod
3
+ import torch
4
  import torch.nn as nn
5
+ from einops import rearrange
6
  import torch.nn.functional as F
7
+ from lvdm.models.utils_diffusion import timestep_embedding
8
+ from lvdm.common import checkpoint
9
+ from lvdm.basics import (
10
+ zero_module,
11
  conv_nd,
12
  linear,
13
  avg_pool_nd,
14
+ normalization
 
 
 
15
  )
16
+ from lvdm.modules.attention import SpatialTransformer, TemporalTransformer
17
 
 
 
 
 
 
 
18
 
 
 
19
  class TimestepBlock(nn.Module):
20
  """
21
  Any module where forward() takes timestep embeddings as a second argument.
22
  """
 
23
  @abstractmethod
24
  def forward(self, x, emb):
25
  """
 
27
  """
28
 
29
 
 
30
  class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
31
  """
32
  A sequential module that passes timestep embeddings to the children that
33
  support it as an extra input.
34
  """
35
 
36
+ def forward(self, x, emb, context=None, batch_size=None):
37
  for layer in self:
38
  if isinstance(layer, TimestepBlock):
39
+ x = layer(x, emb, batch_size)
40
+ elif isinstance(layer, SpatialTransformer):
41
+ x = layer(x, context)
42
+ elif isinstance(layer, TemporalTransformer):
43
+ x = rearrange(x, '(b f) c h w -> b c f h w', b=batch_size)
44
+ x = layer(x, context)
45
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
46
  else:
47
+ x = layer(x,)
48
  return x
49
 
50
 
51
+ class Downsample(nn.Module):
 
52
  """
53
+ A downsampling layer with an optional convolution.
54
  :param channels: channels in the inputs and outputs.
55
  :param use_conv: a bool determining if a convolution is applied.
56
  :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
57
+ downsampling occurs in the inner-two dimensions.
58
  """
59
 
60
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
 
 
 
61
  super().__init__()
62
  self.channels = channels
63
  self.out_channels = out_channels or channels
64
  self.use_conv = use_conv
65
  self.dims = dims
66
+ stride = 2 if dims != 3 else (1, 2, 2)
67
  if use_conv:
68
+ self.op = conv_nd(
69
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
 
 
 
 
 
70
  )
71
  else:
72
+ assert self.channels == self.out_channels
73
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
 
 
74
 
75
+ def forward(self, x):
76
+ assert x.shape[1] == self.channels
77
+ return self.op(x)
78
 
79
 
80
+ class Upsample(nn.Module):
 
81
  """
82
+ An upsampling layer with an optional convolution.
83
  :param channels: channels in the inputs and outputs.
84
  :param use_conv: a bool determining if a convolution is applied.
85
  :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
86
+ upsampling occurs in the inner-two dimensions.
87
  """
88
 
89
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
 
 
 
90
  super().__init__()
91
  self.channels = channels
92
  self.out_channels = out_channels or channels
93
  self.use_conv = use_conv
94
  self.dims = dims
 
95
  if use_conv:
96
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
 
 
 
 
 
97
 
98
  def forward(self, x):
99
  assert x.shape[1] == self.channels
100
+ if self.dims == 3:
101
+ x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest')
102
+ else:
103
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
104
+ if self.use_conv:
105
+ x = self.conv(x)
106
+ return x
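With dims=3, Downsample strides only the spatial axes ((1, 2, 2)) and Upsample interpolates only the last two dimensions, so both layers preserve the temporal length of a video tensor. A minimal sketch of that shape behaviour, not part of the diff, assuming the module is importable as lvdm.modules.networks.openaimodel3d (the path used in this commit):

    # Illustrative sketch: 3D down/upsampling keeps t fixed, rescales h and w.
    import torch
    from lvdm.modules.networks.openaimodel3d import Downsample, Upsample

    x = torch.randn(1, 8, 16, 32, 32)              # (b, c, t, h, w)
    down = Downsample(8, use_conv=True, dims=3)    # conv stride (1, 2, 2)
    up = Upsample(8, use_conv=True, dims=3)        # interpolates h and w only
    print(down(x).shape)                           # torch.Size([1, 8, 16, 16, 16])
    print(up(x).shape)                             # torch.Size([1, 8, 16, 64, 64])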
107
 
108
 
 
109
  class ResBlock(TimestepBlock):
110
  """
111
  A residual block that can optionally change the number of channels.
 
117
  convolution instead of a smaller 1x1 convolution to change the
118
  channels in the skip connection.
119
  :param dims: determines if the signal is 1D, 2D, or 3D.
 
120
  :param up: if True, use this block for upsampling.
121
  :param down: if True, use this block for downsampling.
122
  """
 
127
  emb_channels,
128
  dropout,
129
  out_channels=None,
 
130
  use_scale_shift_norm=False,
131
  dims=2,
132
  use_checkpoint=False,
133
+ use_conv=False,
134
  up=False,
135
  down=False,
136
+ use_temporal_conv=False,
137
+ tempspatial_aware=False
 
 
 
138
  ):
139
  super().__init__()
140
  self.channels = channels
 
144
  self.use_conv = use_conv
145
  self.use_checkpoint = use_checkpoint
146
  self.use_scale_shift_norm = use_scale_shift_norm
147
+ self.use_temporal_conv = use_temporal_conv
148
 
149
  self.in_layers = nn.Sequential(
150
  normalization(channels),
151
+ nn.SiLU(),
152
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
153
  )
154
 
155
  self.updown = up or down
156
 
157
  if up:
158
+ self.h_upd = Upsample(channels, False, dims)
159
+ self.x_upd = Upsample(channels, False, dims)
160
  elif down:
161
+ self.h_upd = Downsample(channels, False, dims)
162
+ self.x_upd = Downsample(channels, False, dims)
163
  else:
164
  self.h_upd = self.x_upd = nn.Identity()
165
 
166
  self.emb_layers = nn.Sequential(
167
+ nn.SiLU(),
168
+ nn.Linear(
169
  emb_channels,
170
  2 * self.out_channels if use_scale_shift_norm else self.out_channels,
171
  ),
172
  )
173
  self.out_layers = nn.Sequential(
174
  normalization(self.out_channels),
175
+ nn.SiLU(),
176
  nn.Dropout(p=dropout),
177
+ zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
 
 
178
  )
179
 
180
  if self.out_channels == channels:
181
  self.skip_connection = nn.Identity()
182
  elif use_conv:
183
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
 
 
184
  else:
185
  self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
 
186
 
187
+ if self.use_temporal_conv:
188
+ self.temopral_conv = TemporalConvBlock(
189
+ self.out_channels,
190
+ self.out_channels,
191
+ dropout=0.1,
192
+ spatial_aware=tempspatial_aware
193
+ )
194
+
195
+ def forward(self, x, emb, batch_size=None):
196
  """
197
  Apply the block to a Tensor, conditioned on a timestep embedding.
198
  :param x: an [N x C x ...] Tensor of features.
199
  :param emb: an [N x emb_channels] Tensor of timestep embeddings.
200
  :return: an [N x C x ...] Tensor of outputs.
201
  """
202
+ input_tuple = (x, emb,)
203
+ if batch_size:
204
+ forward_batchsize = partial(self._forward, batch_size=batch_size)
205
+ return checkpoint(forward_batchsize, input_tuple, self.parameters(), self.use_checkpoint)
206
+ return checkpoint(self._forward, input_tuple, self.parameters(), self.use_checkpoint)
207
 
208
+ def _forward(self, x, emb, batch_size=None,):
209
  if self.updown:
210
  in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
211
  h = in_rest(x)
 
214
  h = in_conv(h)
215
  else:
216
  h = self.in_layers(x)
 
217
  emb_out = self.emb_layers(emb).type(h.dtype)
218
+ while len(emb_out.shape) < len(h.shape):
219
+ emb_out = emb_out[..., None]
 
 
 
220
  if self.use_scale_shift_norm:
221
  out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
222
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
223
  h = out_norm(h) * (1 + scale) + shift
224
  h = out_rest(h)
225
  else:
226
  h = h + emb_out
227
  h = self.out_layers(h)
228
+ h = self.skip_connection(x) + h
 
229
 
230
+ if self.use_temporal_conv and batch_size:
231
+ h = rearrange(h, '(b t) c h w -> b c t h w', b=batch_size)
232
+ h = self.temopral_conv(h)
233
+ h = rearrange(h, 'b c t h w -> (b t) c h w')
234
+ return h
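When use_scale_shift_norm is enabled, the projected timestep embedding carries 2 * out_channels channels and is split into a scale and a shift that modulate the normalized features (FiLM-style conditioning) instead of being added directly. A standalone sketch of that branch with made-up sizes, not part of the diff:

    # Sketch of the scale-shift branch in ResBlock._forward (illustrative shapes).
    import torch
    import torch.nn as nn

    h = torch.randn(4, 64, 32, 32)        # (b*t, c, h, w) features
    emb_out = torch.randn(4, 128, 1, 1)   # emb_layers output, 2 * out_channels, broadcastable
    norm = nn.GroupNorm(32, 64)           # stand-in for normalization(out_channels)
    scale, shift = torch.chunk(emb_out, 2, dim=1)
    h = norm(h) * (1 + scale) + shift     # conditioned features, shape unchanged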
235
+
236
+
237
+ class TemporalConvBlock(nn.Module):
238
+ """
239
+ Adapted from modelscope: https://github.com/modelscope/modelscope/blob/master/modelscope/models/multi_modal/video_synthesis/unet_sd.py
240
+ """
241
+
242
+ def __init__(self, in_channels, out_channels=None, dropout=0.0, spatial_aware=False):
243
+ super(TemporalConvBlock, self).__init__()
244
+ if out_channels is None:
245
+ out_channels = in_channels
246
+ self.in_channels = in_channels
247
+ self.out_channels = out_channels
248
+ kernel_shape = (3, 1, 1) if not spatial_aware else (3, 3, 3)
249
+ padding_shape = (1, 0, 0) if not spatial_aware else (1, 1, 1)
250
+
251
+ # conv layers
252
+ self.conv1 = nn.Sequential(
253
+ nn.GroupNorm(32, in_channels), nn.SiLU(),
254
+ nn.Conv3d(in_channels, out_channels, kernel_shape, padding=padding_shape))
255
+ self.conv2 = nn.Sequential(
256
+ nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
257
+ nn.Conv3d(out_channels, in_channels, kernel_shape, padding=padding_shape))
258
+ self.conv3 = nn.Sequential(
259
+ nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
260
+ nn.Conv3d(out_channels, in_channels, (3, 1, 1), padding=(1, 0, 0)))
261
+ self.conv4 = nn.Sequential(
262
+ nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
263
+ nn.Conv3d(out_channels, in_channels, (3, 1, 1), padding=(1, 0, 0)))
264
+
265
+ # zero out the last layer params, so the conv block is an identity at initialization
266
+ nn.init.zeros_(self.conv4[-1].weight)
267
+ nn.init.zeros_(self.conv4[-1].bias)
268
+
269
+ def forward(self, x):
270
+ identity = x
271
+ x = self.conv1(x)
272
+ x = self.conv2(x)
273
+ x = self.conv3(x)
274
+ x = self.conv4(x)
275
+
276
+ return x + identity
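Because the final convolution in conv4 is zero-initialized, a freshly constructed TemporalConvBlock is an identity mapping, so adding temporal convolutions does not perturb a pretrained spatial model at the start of training. A quick sketch of that property, not part of the diff (import path assumed from this commit's layout):

    # A newly initialized TemporalConvBlock should return its input unchanged.
    import torch
    from lvdm.modules.networks.openaimodel3d import TemporalConvBlock

    block = TemporalConvBlock(64).eval()
    x = torch.randn(2, 64, 16, 8, 8)          # (b, c, t, h, w)
    with torch.no_grad():
        print(torch.allclose(block(x), x))    # True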
277
 
 
278
 
 
279
  class UNetModel(nn.Module):
280
  """
281
  The full UNet model with attention and timestep embedding.
282
+ :param in_channels: in_channels in the input Tensor.
283
  :param model_channels: base channel count for the model.
284
  :param out_channels: channels in the output Tensor.
285
  :param num_res_blocks: number of residual blocks per downsample.
 
302
  of heads for upsampling. Deprecated.
303
  :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
304
  :param resblock_updown: use residual blocks for up/downsampling.
 
 
305
  """
306
 
307
+ def __init__(self,
308
+ in_channels,
309
+ model_channels,
310
+ out_channels,
311
+ num_res_blocks,
312
+ attention_resolutions,
313
+ dropout=0.0,
314
+ channel_mult=(1, 2, 4, 8),
315
+ conv_resample=True,
316
+ dims=2,
317
+ context_dim=None,
318
+ use_scale_shift_norm=False,
319
+ resblock_updown=False,
320
+ num_heads=-1,
321
+ num_head_channels=-1,
322
+ transformer_depth=1,
323
+ use_linear=False,
324
+ use_checkpoint=False,
325
+ temporal_conv=False,
326
+ tempspatial_aware=False,
327
+ temporal_attention=True,
328
+ temporal_selfatt_only=True,
329
+ use_relative_position=True,
330
+ use_causal_attention=False,
331
+ temporal_length=None,
332
+ use_fp16=False,
333
+ addition_attention=False,
334
+ use_image_attention=False,
335
+ temporal_transformer_depth=1,
336
+ fps_cond=False,
337
+ ):
338
+ super(UNetModel, self).__init__()
 
 
339
  if num_heads == -1:
340
  assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
 
341
  if num_head_channels == -1:
342
  assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
343
 
 
344
  self.in_channels = in_channels
345
  self.model_channels = model_channels
346
  self.out_channels = out_channels
 
349
  self.dropout = dropout
350
  self.channel_mult = channel_mult
351
  self.conv_resample = conv_resample
352
+ self.temporal_attention = temporal_attention
353
+ time_embed_dim = model_channels * 4
354
  self.use_checkpoint = use_checkpoint
355
+ self.dtype = torch.float16 if use_fp16 else torch.float32
356
+ self.addition_attention=addition_attention
357
+ self.use_image_attention = use_image_attention
358
+ self.fps_cond=fps_cond
359
+
360
+
 
 
 
 
 
 
 
 
 
 
361
 
 
 
362
  self.time_embed = nn.Sequential(
363
  linear(model_channels, time_embed_dim),
364
+ nn.SiLU(),
365
  linear(time_embed_dim, time_embed_dim),
366
  )
367
+ if self.fps_cond:
368
+ self.fps_embedding = nn.Sequential(
369
+ linear(model_channels, time_embed_dim),
370
+ nn.SiLU(),
371
+ linear(time_embed_dim, time_embed_dim),
372
+ )
373
 
374
  self.input_blocks = nn.ModuleList(
375
  [
376
+ TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))
 
 
377
  ]
378
  )
379
+ if self.addition_attention:
380
+ self.init_attn=TimestepEmbedSequential(
381
+ TemporalTransformer(
382
+ model_channels,
383
+ n_heads=8,
384
+ d_head=num_head_channels,
385
+ depth=transformer_depth,
386
+ context_dim=context_dim,
387
+ use_checkpoint=use_checkpoint, only_self_att=temporal_selfatt_only,
388
+ causal_attention=use_causal_attention, relative_position=use_relative_position,
389
+ temporal_length=temporal_length))
390
+
391
  input_block_chans = [model_channels]
392
  ch = model_channels
393
  ds = 1
394
  for level, mult in enumerate(channel_mult):
395
  for _ in range(num_res_blocks):
396
  layers = [
397
+ ResBlock(ch, time_embed_dim, dropout,
398
+ out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint,
399
+ use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
400
+ use_temporal_conv=temporal_conv
 
 
 
 
 
 
 
 
401
  )
402
  ]
403
  ch = mult * model_channels
 
407
  else:
408
  num_heads = ch // num_head_channels
409
  dim_head = num_head_channels
410
+ layers.append(
411
+ SpatialTransformer(ch, num_heads, dim_head,
412
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
413
+ use_checkpoint=use_checkpoint, disable_self_attn=False,
414
+ img_cross_attention=self.use_image_attention
415
+ )
416
+ )
417
+ if self.temporal_attention:
418
+ layers.append(
419
+ TemporalTransformer(ch, num_heads, dim_head,
420
+ depth=temporal_transformer_depth, context_dim=context_dim, use_linear=use_linear,
421
+ use_checkpoint=use_checkpoint, only_self_att=temporal_selfatt_only,
422
+ causal_attention=use_causal_attention, relative_position=use_relative_position,
423
+ temporal_length=temporal_length
424
+ )
425
+ )
426
  self.input_blocks.append(TimestepEmbedSequential(*layers))
 
427
  input_block_chans.append(ch)
428
  if level != len(channel_mult) - 1:
429
  out_ch = ch
430
  self.input_blocks.append(
431
  TimestepEmbedSequential(
432
+ ResBlock(ch, time_embed_dim, dropout,
433
+ out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint,
 
 
 
 
 
434
  use_scale_shift_norm=use_scale_shift_norm,
435
+ down=True
 
 
 
 
436
  )
437
  if resblock_updown
438
+ else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
 
 
439
  )
440
  )
441
  ch = out_ch
442
  input_block_chans.append(ch)
443
  ds *= 2
 
444
 
445
  if num_head_channels == -1:
446
  dim_head = ch // num_heads
447
  else:
448
  num_heads = ch // num_head_channels
449
  dim_head = num_head_channels
450
+ layers = [
451
+ ResBlock(ch, time_embed_dim, dropout,
452
+ dims=dims, use_checkpoint=use_checkpoint,
453
+ use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
454
+ use_temporal_conv=temporal_conv
 
 
455
  ),
456
+ SpatialTransformer(ch, num_heads, dim_head,
457
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
458
+ use_checkpoint=use_checkpoint, disable_self_attn=False,
459
+ img_cross_attention=self.use_image_attention
460
+ )
461
+ ]
462
+ if self.temporal_attention:
463
+ layers.append(
464
+ TemporalTransformer(ch, num_heads, dim_head,
465
+ depth=temporal_transformer_depth, context_dim=context_dim, use_linear=use_linear,
466
+ use_checkpoint=use_checkpoint, only_self_att=temporal_selfatt_only,
467
+ causal_attention=use_causal_attention, relative_position=use_relative_position,
468
+ temporal_length=temporal_length
469
+ )
470
+ )
471
+ layers.append(
472
+ ResBlock(ch, time_embed_dim, dropout,
473
+ dims=dims, use_checkpoint=use_checkpoint,
474
+ use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
475
+ use_temporal_conv=temporal_conv
476
+ )
477
  )
478
+ self.middle_block = TimestepEmbedSequential(*layers)
479
 
480
  self.output_blocks = nn.ModuleList([])
481
  for level, mult in list(enumerate(channel_mult))[::-1]:
482
  for i in range(num_res_blocks + 1):
483
  ich = input_block_chans.pop()
484
  layers = [
485
+ ResBlock(ch + ich, time_embed_dim, dropout,
486
+ out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint,
487
+ use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
488
+ use_temporal_conv=temporal_conv
 
 
 
 
 
 
 
 
489
  )
490
  ]
491
  ch = model_channels * mult
 
495
  else:
496
  num_heads = ch // num_head_channels
497
  dim_head = num_head_channels
 
 
498
  layers.append(
499
+ SpatialTransformer(ch, num_heads, dim_head,
500
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
501
+ use_checkpoint=use_checkpoint, disable_self_attn=False,
502
+ img_cross_attention=self.use_image_attention
 
 
503
  )
504
  )
505
+ if self.temporal_attention:
506
+ layers.append(
507
+ TemporalTransformer(ch, num_heads, dim_head,
508
+ depth=temporal_transformer_depth, context_dim=context_dim, use_linear=use_linear,
509
+ use_checkpoint=use_checkpoint, only_self_att=temporal_selfatt_only,
510
+ causal_attention=use_causal_attention, relative_position=use_relative_position,
511
+ temporal_length=temporal_length
512
+ )
513
+ )
514
  if level and i == num_res_blocks:
515
  out_ch = ch
516
  layers.append(
517
+ ResBlock(ch, time_embed_dim, dropout,
518
+ out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint,
 
 
 
 
 
519
  use_scale_shift_norm=use_scale_shift_norm,
520
+ up=True
 
 
 
 
521
  )
522
  if resblock_updown
523
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
524
  )
525
  ds //= 2
526
  self.output_blocks.append(TimestepEmbedSequential(*layers))
 
527
 
528
  self.out = nn.Sequential(
529
  normalization(ch),
530
+ nn.SiLU(),
531
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
532
  )
 
533
 
534
+ def forward(self, x, timesteps, context=None, features_adapter=None, fps=16, **kwargs):
535
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
536
+ emb = self.time_embed(t_emb)
 
 
 
 
537
 
538
+ if self.fps_cond:
539
+ if type(fps) == int:
540
+ fps = torch.full_like(timesteps, fps)
541
+ fps_emb = timestep_embedding(fps,self.model_channels, repeat_only=False)
542
+ emb += self.fps_embedding(fps_emb)
 
 
543
 
544
+ b,_,t,_,_ = x.shape
545
+ ## repeat t times for context [(b t) 77 768] & time embedding
546
+ context = context.repeat_interleave(repeats=t, dim=0)
547
+ emb = emb.repeat_interleave(repeats=t, dim=0)
548
+
549
+ ## always in shape (b t) c h w, except for temporal layer
550
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
 
 
 
551
 
552
  h = x.type(self.dtype)
553
  adapter_idx = 0
554
+ hs = []
555
  for id, module in enumerate(self.input_blocks):
556
+ h = module(h, emb, context=context, batch_size=b)
557
+ if id ==0 and self.addition_attention:
558
+ h = self.init_attn(h, emb, context=context, batch_size=b)
559
  ## plug-in adapter features
560
  if ((id+1)%3 == 0) and features_adapter is not None:
561
  h = h + features_adapter[adapter_idx]
562
  adapter_idx += 1
563
  hs.append(h)
564
  if features_adapter is not None:
565
+ assert len(features_adapter)==adapter_idx, 'Wrong features_adapter'
566
 
567
+ h = self.middle_block(h, emb, context=context, batch_size=b)
568
  for module in self.output_blocks:
569
+ h = torch.cat([h, hs.pop()], dim=1)
570
+ h = module(h, emb, context=context, batch_size=b)
571
  h = h.type(x.dtype)
572
+ y = self.out(h)
573
+
574
+ # reshape back to (b c t h w)
575
+ y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
576
+ return y
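Taken together, the rewritten UNetModel.forward expects a latent video of shape (b, c, t, h, w), per-sample timesteps, and a text-conditioning tensor, and returns a tensor of the same shape. A small smoke-test sketch, not part of the diff; the hyperparameters below are placeholders rather than the values shipped in the configs/ YAML files:

    # Illustrative smoke test with made-up sizes (not the released config).
    import torch
    from lvdm.modules.networks.openaimodel3d import UNetModel

    unet = UNetModel(
        in_channels=4, model_channels=64, out_channels=4,
        num_res_blocks=1, attention_resolutions=[2],
        channel_mult=(1, 2), num_head_channels=32,
        context_dim=768, temporal_length=16,
    )
    b, t = 1, 16
    x = torch.randn(b, 4, t, 32, 32)          # latent video (b, c, t, h, w)
    timesteps = torch.randint(0, 1000, (b,))
    context = torch.randn(b, 77, 768)         # text embedding, e.g. (b, 77, context_dim)
    with torch.no_grad():
        out = unet(x, timesteps, context=context)
    print(out.shape)                          # torch.Size([1, 4, 16, 32, 32])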
577
+
lvdm/modules/x_transformer.py ADDED
@@ -0,0 +1,640 @@
 
 
1
+ """shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
2
+ from functools import partial
3
+ from inspect import isfunction
4
+ from collections import namedtuple
5
+ from einops import rearrange, repeat
6
+ import torch
7
+ from torch import nn, einsum
8
+ import torch.nn.functional as F
9
+
10
+ # constants
11
+ DEFAULT_DIM_HEAD = 64
12
+
13
+ Intermediates = namedtuple('Intermediates', [
14
+ 'pre_softmax_attn',
15
+ 'post_softmax_attn'
16
+ ])
17
+
18
+ LayerIntermediates = namedtuple('Intermediates', [
19
+ 'hiddens',
20
+ 'attn_intermediates'
21
+ ])
22
+
23
+
24
+ class AbsolutePositionalEmbedding(nn.Module):
25
+ def __init__(self, dim, max_seq_len):
26
+ super().__init__()
27
+ self.emb = nn.Embedding(max_seq_len, dim)
28
+ self.init_()
29
+
30
+ def init_(self):
31
+ nn.init.normal_(self.emb.weight, std=0.02)
32
+
33
+ def forward(self, x):
34
+ n = torch.arange(x.shape[1], device=x.device)
35
+ return self.emb(n)[None, :, :]
36
+
37
+
38
+ class FixedPositionalEmbedding(nn.Module):
39
+ def __init__(self, dim):
40
+ super().__init__()
41
+ inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
42
+ self.register_buffer('inv_freq', inv_freq)
43
+
44
+ def forward(self, x, seq_dim=1, offset=0):
45
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
46
+ sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
47
+ emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
48
+ return emb[None, :, :]
49
+
50
+
51
+ # helpers
52
+
53
+ def exists(val):
54
+ return val is not None
55
+
56
+
57
+ def default(val, d):
58
+ if exists(val):
59
+ return val
60
+ return d() if isfunction(d) else d
61
+
62
+
63
+ def always(val):
64
+ def inner(*args, **kwargs):
65
+ return val
66
+ return inner
67
+
68
+
69
+ def not_equals(val):
70
+ def inner(x):
71
+ return x != val
72
+ return inner
73
+
74
+
75
+ def equals(val):
76
+ def inner(x):
77
+ return x == val
78
+ return inner
79
+
80
+
81
+ def max_neg_value(tensor):
82
+ return -torch.finfo(tensor.dtype).max
83
+
84
+
85
+ # keyword argument helpers
86
+
87
+ def pick_and_pop(keys, d):
88
+ values = list(map(lambda key: d.pop(key), keys))
89
+ return dict(zip(keys, values))
90
+
91
+
92
+ def group_dict_by_key(cond, d):
93
+ return_val = [dict(), dict()]
94
+ for key in d.keys():
95
+ match = bool(cond(key))
96
+ ind = int(not match)
97
+ return_val[ind][key] = d[key]
98
+ return (*return_val,)
99
+
100
+
101
+ def string_begins_with(prefix, str):
102
+ return str.startswith(prefix)
103
+
104
+
105
+ def group_by_key_prefix(prefix, d):
106
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
107
+
108
+
109
+ def groupby_prefix_and_trim(prefix, d):
110
+ kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
111
+ kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
112
+ return kwargs_without_prefix, kwargs
113
+
114
+
115
+ # classes
116
+ class Scale(nn.Module):
117
+ def __init__(self, value, fn):
118
+ super().__init__()
119
+ self.value = value
120
+ self.fn = fn
121
+
122
+ def forward(self, x, **kwargs):
123
+ x, *rest = self.fn(x, **kwargs)
124
+ return (x * self.value, *rest)
125
+
126
+
127
+ class Rezero(nn.Module):
128
+ def __init__(self, fn):
129
+ super().__init__()
130
+ self.fn = fn
131
+ self.g = nn.Parameter(torch.zeros(1))
132
+
133
+ def forward(self, x, **kwargs):
134
+ x, *rest = self.fn(x, **kwargs)
135
+ return (x * self.g, *rest)
136
+
137
+
138
+ class ScaleNorm(nn.Module):
139
+ def __init__(self, dim, eps=1e-5):
140
+ super().__init__()
141
+ self.scale = dim ** -0.5
142
+ self.eps = eps
143
+ self.g = nn.Parameter(torch.ones(1))
144
+
145
+ def forward(self, x):
146
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
147
+ return x / norm.clamp(min=self.eps) * self.g
148
+
149
+
150
+ class RMSNorm(nn.Module):
151
+ def __init__(self, dim, eps=1e-8):
152
+ super().__init__()
153
+ self.scale = dim ** -0.5
154
+ self.eps = eps
155
+ self.g = nn.Parameter(torch.ones(dim))
156
+
157
+ def forward(self, x):
158
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
159
+ return x / norm.clamp(min=self.eps) * self.g
160
+
161
+
162
+ class Residual(nn.Module):
163
+ def forward(self, x, residual):
164
+ return x + residual
165
+
166
+
167
+ class GRUGating(nn.Module):
168
+ def __init__(self, dim):
169
+ super().__init__()
170
+ self.gru = nn.GRUCell(dim, dim)
171
+
172
+ def forward(self, x, residual):
173
+ gated_output = self.gru(
174
+ rearrange(x, 'b n d -> (b n) d'),
175
+ rearrange(residual, 'b n d -> (b n) d')
176
+ )
177
+
178
+ return gated_output.reshape_as(x)
179
+
180
+
181
+ # feedforward
182
+
183
+ class GEGLU(nn.Module):
184
+ def __init__(self, dim_in, dim_out):
185
+ super().__init__()
186
+ self.proj = nn.Linear(dim_in, dim_out * 2)
187
+
188
+ def forward(self, x):
189
+ x, gate = self.proj(x).chunk(2, dim=-1)
190
+ return x * F.gelu(gate)
191
+
192
+
193
+ class FeedForward(nn.Module):
194
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
195
+ super().__init__()
196
+ inner_dim = int(dim * mult)
197
+ dim_out = default(dim_out, dim)
198
+ project_in = nn.Sequential(
199
+ nn.Linear(dim, inner_dim),
200
+ nn.GELU()
201
+ ) if not glu else GEGLU(dim, inner_dim)
202
+
203
+ self.net = nn.Sequential(
204
+ project_in,
205
+ nn.Dropout(dropout),
206
+ nn.Linear(inner_dim, dim_out)
207
+ )
208
+
209
+ def forward(self, x):
210
+ return self.net(x)
211
+
212
+
213
+ # attention.
214
+ class Attention(nn.Module):
215
+ def __init__(
216
+ self,
217
+ dim,
218
+ dim_head=DEFAULT_DIM_HEAD,
219
+ heads=8,
220
+ causal=False,
221
+ mask=None,
222
+ talking_heads=False,
223
+ sparse_topk=None,
224
+ use_entmax15=False,
225
+ num_mem_kv=0,
226
+ dropout=0.,
227
+ on_attn=False
228
+ ):
229
+ super().__init__()
230
+ if use_entmax15:
231
+ raise NotImplementedError("Check out entmax activation instead of softmax activation!")
232
+ self.scale = dim_head ** -0.5
233
+ self.heads = heads
234
+ self.causal = causal
235
+ self.mask = mask
236
+
237
+ inner_dim = dim_head * heads
238
+
239
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
240
+ self.to_k = nn.Linear(dim, inner_dim, bias=False)
241
+ self.to_v = nn.Linear(dim, inner_dim, bias=False)
242
+ self.dropout = nn.Dropout(dropout)
243
+
244
+ # talking heads
245
+ self.talking_heads = talking_heads
246
+ if talking_heads:
247
+ self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
248
+ self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
249
+
250
+ # explicit topk sparse attention
251
+ self.sparse_topk = sparse_topk
252
+
253
+ # entmax
254
+ #self.attn_fn = entmax15 if use_entmax15 else F.softmax
255
+ self.attn_fn = F.softmax
256
+
257
+ # add memory key / values
258
+ self.num_mem_kv = num_mem_kv
259
+ if num_mem_kv > 0:
260
+ self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
261
+ self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
262
+
263
+ # attention on attention
264
+ self.attn_on_attn = on_attn
265
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
266
+
267
+ def forward(
268
+ self,
269
+ x,
270
+ context=None,
271
+ mask=None,
272
+ context_mask=None,
273
+ rel_pos=None,
274
+ sinusoidal_emb=None,
275
+ prev_attn=None,
276
+ mem=None
277
+ ):
278
+ b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
279
+ kv_input = default(context, x)
280
+
281
+ q_input = x
282
+ k_input = kv_input
283
+ v_input = kv_input
284
+
285
+ if exists(mem):
286
+ k_input = torch.cat((mem, k_input), dim=-2)
287
+ v_input = torch.cat((mem, v_input), dim=-2)
288
+
289
+ if exists(sinusoidal_emb):
290
+ # in shortformer, the query would start at a position offset depending on the past cached memory
291
+ offset = k_input.shape[-2] - q_input.shape[-2]
292
+ q_input = q_input + sinusoidal_emb(q_input, offset=offset)
293
+ k_input = k_input + sinusoidal_emb(k_input)
294
+
295
+ q = self.to_q(q_input)
296
+ k = self.to_k(k_input)
297
+ v = self.to_v(v_input)
298
+
299
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
300
+
301
+ input_mask = None
302
+ if any(map(exists, (mask, context_mask))):
303
+ q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
304
+ k_mask = q_mask if not exists(context) else context_mask
305
+ k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
306
+ q_mask = rearrange(q_mask, 'b i -> b () i ()')
307
+ k_mask = rearrange(k_mask, 'b j -> b () () j')
308
+ input_mask = q_mask * k_mask
309
+
310
+ if self.num_mem_kv > 0:
311
+ mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
312
+ k = torch.cat((mem_k, k), dim=-2)
313
+ v = torch.cat((mem_v, v), dim=-2)
314
+ if exists(input_mask):
315
+ input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
316
+
317
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
318
+ mask_value = max_neg_value(dots)
319
+
320
+ if exists(prev_attn):
321
+ dots = dots + prev_attn
322
+
323
+ pre_softmax_attn = dots
324
+
325
+ if talking_heads:
326
+ dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
327
+
328
+ if exists(rel_pos):
329
+ dots = rel_pos(dots)
330
+
331
+ if exists(input_mask):
332
+ dots.masked_fill_(~input_mask, mask_value)
333
+ del input_mask
334
+
335
+ if self.causal:
336
+ i, j = dots.shape[-2:]
337
+ r = torch.arange(i, device=device)
338
+ mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
339
+ mask = F.pad(mask, (j - i, 0), value=False)
340
+ dots.masked_fill_(mask, mask_value)
341
+ del mask
342
+
343
+ if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
344
+ top, _ = dots.topk(self.sparse_topk, dim=-1)
345
+ vk = top[..., -1].unsqueeze(-1).expand_as(dots)
346
+ mask = dots < vk
347
+ dots.masked_fill_(mask, mask_value)
348
+ del mask
349
+
350
+ attn = self.attn_fn(dots, dim=-1)
351
+ post_softmax_attn = attn
352
+
353
+ attn = self.dropout(attn)
354
+
355
+ if talking_heads:
356
+ attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
357
+
358
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
359
+ out = rearrange(out, 'b h n d -> b n (h d)')
360
+
361
+ intermediates = Intermediates(
362
+ pre_softmax_attn=pre_softmax_attn,
363
+ post_softmax_attn=post_softmax_attn
364
+ )
365
+
366
+ return self.to_out(out), intermediates
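Note that Attention returns both the projected output and an Intermediates tuple holding the pre- and post-softmax attention maps; AttentionLayers relies on these for residual attention. A minimal shape sketch, not part of the diff:

    # Illustrative usage of the Attention interface.
    import torch
    from lvdm.modules.x_transformer import Attention

    attn = Attention(dim=256, heads=8, causal=True)
    x = torch.randn(2, 10, 256)               # (batch, seq_len, dim)
    out, inter = attn(x)
    print(out.shape)                          # torch.Size([2, 10, 256])
    print(inter.pre_softmax_attn.shape)       # torch.Size([2, 8, 10, 10])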
367
+
368
+
369
+ class AttentionLayers(nn.Module):
370
+ def __init__(
371
+ self,
372
+ dim,
373
+ depth,
374
+ heads=8,
375
+ causal=False,
376
+ cross_attend=False,
377
+ only_cross=False,
378
+ use_scalenorm=False,
379
+ use_rmsnorm=False,
380
+ use_rezero=False,
381
+ rel_pos_num_buckets=32,
382
+ rel_pos_max_distance=128,
383
+ position_infused_attn=False,
384
+ custom_layers=None,
385
+ sandwich_coef=None,
386
+ par_ratio=None,
387
+ residual_attn=False,
388
+ cross_residual_attn=False,
389
+ macaron=False,
390
+ pre_norm=True,
391
+ gate_residual=False,
392
+ **kwargs
393
+ ):
394
+ super().__init__()
395
+ ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
396
+ attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
397
+
398
+ dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
399
+
400
+ self.dim = dim
401
+ self.depth = depth
402
+ self.layers = nn.ModuleList([])
403
+
404
+ self.has_pos_emb = position_infused_attn
405
+ self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
406
+ self.rotary_pos_emb = always(None)
407
+
408
+ assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
409
+ self.rel_pos = None
410
+
411
+ self.pre_norm = pre_norm
412
+
413
+ self.residual_attn = residual_attn
414
+ self.cross_residual_attn = cross_residual_attn
415
+
416
+ norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
417
+ norm_class = RMSNorm if use_rmsnorm else norm_class
418
+ norm_fn = partial(norm_class, dim)
419
+
420
+ norm_fn = nn.Identity if use_rezero else norm_fn
421
+ branch_fn = Rezero if use_rezero else None
422
+
423
+ if cross_attend and not only_cross:
424
+ default_block = ('a', 'c', 'f')
425
+ elif cross_attend and only_cross:
426
+ default_block = ('c', 'f')
427
+ else:
428
+ default_block = ('a', 'f')
429
+
430
+ if macaron:
431
+ default_block = ('f',) + default_block
432
+
433
+ if exists(custom_layers):
434
+ layer_types = custom_layers
435
+ elif exists(par_ratio):
436
+ par_depth = depth * len(default_block)
437
+ assert 1 < par_ratio <= par_depth, 'par ratio out of range'
438
+ default_block = tuple(filter(not_equals('f'), default_block))
439
+ par_attn = par_depth // par_ratio
440
+ depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
441
+ par_width = (depth_cut + depth_cut // par_attn) // par_attn
442
+ assert len(default_block) <= par_width, 'default block is too large for par_ratio'
443
+ par_block = default_block + ('f',) * (par_width - len(default_block))
444
+ par_head = par_block * par_attn
445
+ layer_types = par_head + ('f',) * (par_depth - len(par_head))
446
+ elif exists(sandwich_coef):
447
+ assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
448
+ layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
449
+ else:
450
+ layer_types = default_block * depth
451
+
452
+ self.layer_types = layer_types
453
+ self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
454
+
455
+ for layer_type in self.layer_types:
456
+ if layer_type == 'a':
457
+ layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
458
+ elif layer_type == 'c':
459
+ layer = Attention(dim, heads=heads, **attn_kwargs)
460
+ elif layer_type == 'f':
461
+ layer = FeedForward(dim, **ff_kwargs)
462
+ layer = layer if not macaron else Scale(0.5, layer)
463
+ else:
464
+ raise Exception(f'invalid layer type {layer_type}')
465
+
466
+ if isinstance(layer, Attention) and exists(branch_fn):
467
+ layer = branch_fn(layer)
468
+
469
+ if gate_residual:
470
+ residual_fn = GRUGating(dim)
471
+ else:
472
+ residual_fn = Residual()
473
+
474
+ self.layers.append(nn.ModuleList([
475
+ norm_fn(),
476
+ layer,
477
+ residual_fn
478
+ ]))
479
+
480
+ def forward(
481
+ self,
482
+ x,
483
+ context=None,
484
+ mask=None,
485
+ context_mask=None,
486
+ mems=None,
487
+ return_hiddens=False
488
+ ):
489
+ hiddens = []
490
+ intermediates = []
491
+ prev_attn = None
492
+ prev_cross_attn = None
493
+
494
+ mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
495
+
496
+ for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
497
+ is_last = ind == (len(self.layers) - 1)
498
+
499
+ if layer_type == 'a':
500
+ hiddens.append(x)
501
+ layer_mem = mems.pop(0)
502
+
503
+ residual = x
504
+
505
+ if self.pre_norm:
506
+ x = norm(x)
507
+
508
+ if layer_type == 'a':
509
+ out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
510
+ prev_attn=prev_attn, mem=layer_mem)
511
+ elif layer_type == 'c':
512
+ out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
513
+ elif layer_type == 'f':
514
+ out = block(x)
515
+
516
+ x = residual_fn(out, residual)
517
+
518
+ if layer_type in ('a', 'c'):
519
+ intermediates.append(inter)
520
+
521
+ if layer_type == 'a' and self.residual_attn:
522
+ prev_attn = inter.pre_softmax_attn
523
+ elif layer_type == 'c' and self.cross_residual_attn:
524
+ prev_cross_attn = inter.pre_softmax_attn
525
+
526
+ if not self.pre_norm and not is_last:
527
+ x = norm(x)
528
+
529
+ if return_hiddens:
530
+ intermediates = LayerIntermediates(
531
+ hiddens=hiddens,
532
+ attn_intermediates=intermediates
533
+ )
534
+
535
+ return x, intermediates
536
+
537
+ return x
538
+
539
+
540
+ class Encoder(AttentionLayers):
541
+ def __init__(self, **kwargs):
542
+ assert 'causal' not in kwargs, 'cannot set causality on encoder'
543
+ super().__init__(causal=False, **kwargs)
544
+
545
+
546
+
547
+ class TransformerWrapper(nn.Module):
548
+ def __init__(
549
+ self,
550
+ *,
551
+ num_tokens,
552
+ max_seq_len,
553
+ attn_layers,
554
+ emb_dim=None,
555
+ max_mem_len=0.,
556
+ emb_dropout=0.,
557
+ num_memory_tokens=None,
558
+ tie_embedding=False,
559
+ use_pos_emb=True
560
+ ):
561
+ super().__init__()
562
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
563
+
564
+ dim = attn_layers.dim
565
+ emb_dim = default(emb_dim, dim)
566
+
567
+ self.max_seq_len = max_seq_len
568
+ self.max_mem_len = max_mem_len
569
+ self.num_tokens = num_tokens
570
+
571
+ self.token_emb = nn.Embedding(num_tokens, emb_dim)
572
+ self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
573
+ use_pos_emb and not attn_layers.has_pos_emb) else always(0)
574
+ self.emb_dropout = nn.Dropout(emb_dropout)
575
+
576
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
577
+ self.attn_layers = attn_layers
578
+ self.norm = nn.LayerNorm(dim)
579
+
580
+ self.init_()
581
+
582
+ self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
583
+
584
+ # memory tokens (like [cls]) from Memory Transformers paper
585
+ num_memory_tokens = default(num_memory_tokens, 0)
586
+ self.num_memory_tokens = num_memory_tokens
587
+ if num_memory_tokens > 0:
588
+ self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
589
+
590
+ # let funnel encoder know number of memory tokens, if specified
591
+ if hasattr(attn_layers, 'num_memory_tokens'):
592
+ attn_layers.num_memory_tokens = num_memory_tokens
593
+
594
+ def init_(self):
595
+ nn.init.normal_(self.token_emb.weight, std=0.02)
596
+
597
+ def forward(
598
+ self,
599
+ x,
600
+ return_embeddings=False,
601
+ mask=None,
602
+ return_mems=False,
603
+ return_attn=False,
604
+ mems=None,
605
+ **kwargs
606
+ ):
607
+ b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
608
+ x = self.token_emb(x)
609
+ x += self.pos_emb(x)
610
+ x = self.emb_dropout(x)
611
+
612
+ x = self.project_emb(x)
613
+
614
+ if num_mem > 0:
615
+ mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
616
+ x = torch.cat((mem, x), dim=1)
617
+
618
+ # auto-handle masking after appending memory tokens
619
+ if exists(mask):
620
+ mask = F.pad(mask, (num_mem, 0), value=True)
621
+
622
+ x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
623
+ x = self.norm(x)
624
+
625
+ mem, x = x[:, :num_mem], x[:, num_mem:]
626
+
627
+ out = self.to_logits(x) if not return_embeddings else x
628
+
629
+ if return_mems:
630
+ hiddens = intermediates.hiddens
631
+ new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
632
+ new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
633
+ return out, new_mems
634
+
635
+ if return_attn:
636
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
637
+ return out, attn_maps
638
+
639
+ return out
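In practice this module is consumed through TransformerWrapper wrapped around an Encoder stack, mapping token ids to per-token embeddings. A usage sketch, not part of the diff, with placeholder vocabulary and model sizes:

    # Illustrative usage; vocabulary size and dims are placeholders.
    import torch
    from lvdm.modules.x_transformer import Encoder, TransformerWrapper

    model = TransformerWrapper(
        num_tokens=30522, max_seq_len=77,
        attn_layers=Encoder(dim=512, depth=4, heads=8),
    )
    tokens = torch.randint(0, 30522, (2, 77))
    emb = model(tokens, return_embeddings=True)
    print(emb.shape)                          # torch.Size([2, 77, 512])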
640
+
lvdm/utils/common_utils.py DELETED
@@ -1,132 +0,0 @@
1
-
2
- import importlib
3
-
4
- import torch
5
- import numpy as np
6
-
7
- from inspect import isfunction
8
- from PIL import Image, ImageDraw, ImageFont
9
-
10
-
11
- def str2bool(v):
12
- if isinstance(v, bool):
13
- return v
14
- if v.lower() in ('yes', 'true', 't', 'y', '1'):
15
- return True
16
- elif v.lower() in ('no', 'false', 'f', 'n', '0'):
17
- return False
18
- else:
19
- raise ValueError('Boolean value expected.')
20
-
21
-
22
- def instantiate_from_config(config):
23
- if not "target" in config:
24
- if config == '__is_first_stage__':
25
- return None
26
- elif config == "__is_unconditional__":
27
- return None
28
- raise KeyError("Expected key `target` to instantiate.")
29
-
30
- return get_obj_from_str(config["target"])(**config.get("params", dict()))
31
-
32
- def get_obj_from_str(string, reload=False):
33
- module, cls = string.rsplit(".", 1)
34
- if reload:
35
- module_imp = importlib.import_module(module)
36
- importlib.reload(module_imp)
37
- return getattr(importlib.import_module(module, package=None), cls)
38
-
39
- def log_txt_as_img(wh, xc, size=10):
40
- # wh a tuple of (width, height)
41
- # xc a list of captions to plot
42
- b = len(xc)
43
- txts = list()
44
- for bi in range(b):
45
- txt = Image.new("RGB", wh, color="white")
46
- draw = ImageDraw.Draw(txt)
47
- font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
48
- nc = int(40 * (wh[0] / 256))
49
- lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
50
-
51
- try:
52
- draw.text((0, 0), lines, fill="black", font=font)
53
- except UnicodeEncodeError:
54
- print("Cant encode string for logging. Skipping.")
55
-
56
- txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
57
- txts.append(txt)
58
- txts = np.stack(txts)
59
- txts = torch.tensor(txts)
60
- return txts
61
-
62
-
63
- def ismap(x):
64
- if not isinstance(x, torch.Tensor):
65
- return False
66
- return (len(x.shape) == 4) and (x.shape[1] > 3)
67
-
68
-
69
- def isimage(x):
70
- if not isinstance(x,torch.Tensor):
71
- return False
72
- return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
73
-
74
-
75
- def exists(x):
76
- return x is not None
77
-
78
-
79
- def default(val, d):
80
- if exists(val):
81
- return val
82
- return d() if isfunction(d) else d
83
-
84
-
85
- def mean_flat(tensor):
86
- """
87
- https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
88
- Take the mean over all non-batch dimensions.
89
- """
90
- return tensor.mean(dim=list(range(1, len(tensor.shape))))
91
-
92
-
93
- def count_params(model, verbose=False):
94
- total_params = sum(p.numel() for p in model.parameters())
95
- if verbose:
96
- print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
97
- return total_params
98
-
99
-
100
- def instantiate_from_config(config):
101
- if not "target" in config:
102
- if config == '__is_first_stage__':
103
- return None
104
- elif config == "__is_unconditional__":
105
- return None
106
- raise KeyError("Expected key `target` to instantiate.")
107
-
108
- if "instantiate_with_dict" in config and config["instantiate_with_dict"]:
109
- # input parameter is one dict
110
- return get_obj_from_str(config["target"])(config.get("params", dict()), **kwargs)
111
- else:
112
- return get_obj_from_str(config["target"])(**config.get("params", dict()))
113
-
114
-
115
- def get_obj_from_str(string, reload=False):
116
- module, cls = string.rsplit(".", 1)
117
- if reload:
118
- module_imp = importlib.import_module(module)
119
- importlib.reload(module_imp)
120
- return getattr(importlib.import_module(module, package=None), cls)
121
-
122
-
123
- def check_istarget(name, para_list):
124
- """
125
- name: full name of source para
126
- para_list: partial name of target para
127
- """
128
- istarget=False
129
- for para in para_list:
130
- if para in name:
131
- return True
132
- return istarget
 
 
lvdm/utils/dist_utils.py DELETED
@@ -1,19 +0,0 @@
1
- import torch
2
- import torch.distributed as dist
3
-
4
- def setup_dist(local_rank):
5
- if dist.is_initialized():
6
- return
7
- torch.cuda.set_device(local_rank)
8
- torch.distributed.init_process_group(
9
- 'nccl',
10
- init_method='env://'
11
- )
12
-
13
- def gather_data(data, return_np=True):
14
- ''' gather data from multiple processes to one list '''
15
- data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())]
16
- dist.all_gather(data_list, data) # gather not supported with NCCL
17
- if return_np:
18
- data_list = [data.cpu().numpy() for data in data_list]
19
- return data_list
 
 
lvdm/utils/saving_utils.py DELETED
@@ -1,269 +0,0 @@
1
- import numpy as np
2
- import cv2
3
- import os
4
- import time
5
- import imageio
6
- from tqdm import tqdm
7
- from PIL import Image
8
- import os
9
- import sys
10
- sys.path.insert(1, os.path.join(sys.path[0], '..'))
11
- import torch
12
- import torchvision
13
- from torchvision.utils import make_grid
14
- from torch import Tensor
15
- from torchvision.transforms.functional import to_tensor
16
-
17
-
18
- def tensor_to_mp4(video, savepath, fps, rescale=True, nrow=None):
19
- """
20
- video: torch.Tensor, b,c,t,h,w, 0-1
21
- if -1~1, enable rescale=True
22
- """
23
- n = video.shape[0]
24
- video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
25
- nrow = int(np.sqrt(n)) if nrow is None else nrow
26
- frame_grids = [torchvision.utils.make_grid(framesheet, nrow=nrow) for framesheet in video] # [3, grid_h, grid_w]
27
- grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [T, 3, grid_h, grid_w]
28
- grid = torch.clamp(grid.float(), -1., 1.)
29
- if rescale:
30
- grid = (grid + 1.0) / 2.0
31
- grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) # [T, 3, grid_h, grid_w] -> [T, grid_h, grid_w, 3]
32
- #print(f'Save video to {savepath}')
33
- torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'})
34
-
35
- # ----------------------------------------------------------------------------------------------
36
- def savenp2sheet(imgs, savepath, nrow=None):
37
- """ save multiple imgs (in numpy array type) to a img sheet.
38
- img sheet is one row.
39
-
40
- imgs:
41
- np array of size [N, H, W, 3] or List[array] with array size = [H,W,3]
42
- """
43
- if imgs.ndim == 4:
44
- img_list = [imgs[i] for i in range(imgs.shape[0])]
45
- imgs = img_list
46
-
47
- imgs_new = []
48
- for i, img in enumerate(imgs):
49
- if img.ndim == 3 and img.shape[0] == 3:
50
- img = np.transpose(img,(1,2,0))
51
-
52
- assert(img.ndim == 3 and img.shape[-1] == 3), img.shape # h,w,3
53
- img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
54
- imgs_new.append(img)
55
- n = len(imgs)
56
- if nrow is not None:
57
- n_cols = nrow
58
- else:
59
- n_cols=int(n**0.5)
60
- n_rows=int(np.ceil(n/n_cols))
61
- print(n_cols)
62
- print(n_rows)
63
-
64
- imgsheet = cv2.vconcat([cv2.hconcat(imgs_new[i*n_cols:(i+1)*n_cols]) for i in range(n_rows)])
65
- cv2.imwrite(savepath, imgsheet)
66
- print(f'saved in {savepath}')
67
-
68
- # ----------------------------------------------------------------------------------------------
- def save_np_to_img(img, path, norm=True):
-     if norm:
-         img = (img + 1) / 2 * 255
-     img = img.astype(np.uint8)
-     image = Image.fromarray(img)
-     image.save(path, q=95)
-
- # ----------------------------------------------------------------------------------------------
- def npz_to_imgsheet_5d(data_path, res_dir, nrow=None,):
-     if isinstance(data_path, str):
-         imgs = np.load(data_path)['arr_0'] # NTHWC
-     elif isinstance(data_path, np.ndarray):
-         imgs = data_path
-     else:
-         raise Exception
-
-     if os.path.isdir(res_dir):
-         res_path = os.path.join(res_dir, f'samples.jpg')
-     else:
-         assert(res_dir.endswith('.jpg'))
-         res_path = res_dir
-     imgs = np.concatenate([imgs[i] for i in range(imgs.shape[0])], axis=0)
-     savenp2sheet(imgs, res_path, nrow=nrow)
-
- # ----------------------------------------------------------------------------------------------
- def npz_to_imgsheet_4d(data_path, res_path, nrow=None,):
-     if isinstance(data_path, str):
-         imgs = np.load(data_path)['arr_0'] # NHWC
-     elif isinstance(data_path, np.ndarray):
-         imgs = data_path
-     else:
-         raise Exception
-     print(imgs.shape)
-     savenp2sheet(imgs, res_path, nrow=nrow)
-
-
- # ----------------------------------------------------------------------------------------------
- def tensor_to_imgsheet(tensor, save_path):
-     """
-     save a batch of videos in one image sheet with shape of [batch_size * num_frames].
-     data: [b,c,t,h,w]
-     """
-     assert(tensor.dim() == 5)
-     b,c,t,h,w = tensor.shape
-     imgs = [tensor[bi,:,ti, :, :] for bi in range(b) for ti in range(t)]
-     torchvision.utils.save_image(imgs, save_path, normalize=True, nrow=t)
-
-
- # ----------------------------------------------------------------------------------------------
- def npz_to_frames(data_path, res_dir, norm, num_frames=None, num_samples=None):
-     start = time.time()
-     arr = np.load(data_path)
-     imgs = arr['arr_0'] # [N, T, H, W, 3]
-     print('original data shape: ', imgs.shape)
-
-     if num_samples is not None:
-         imgs = imgs[:num_samples, :, :, :, :]
-         print('after sample selection: ', imgs.shape)
-
-     if num_frames is not None:
-         imgs = imgs[:, :num_frames, :, :, :]
-         print('after frame selection: ', imgs.shape)
-
-     for vid in tqdm(range(imgs.shape[0]), desc='Video'):
-         video_dir = os.path.join(res_dir, f'video{vid:04d}')
-         os.makedirs(video_dir, exist_ok=True)
-         for fid in range(imgs.shape[1]):
-             frame = imgs[vid, fid, :, :, :] #HW3
-             save_np_to_img(frame, os.path.join(video_dir, f'frame{fid:04d}.jpg'), norm=norm)
-     print('Finish')
-     print(f'Total time = {time.time()- start}')
-
- # ----------------------------------------------------------------------------------------------
- def npz_to_gifs(data_path, res_dir, duration=0.2, start_idx=0, num_videos=None, mode='gif'):
-     os.makedirs(res_dir, exist_ok=True)
-     if isinstance(data_path, str):
-         imgs = np.load(data_path)['arr_0'] # NTHWC
-     elif isinstance(data_path, np.ndarray):
-         imgs = data_path
-     else:
-         raise Exception
-
-     for i in range(imgs.shape[0]):
-         frames = [imgs[i,j,:,:,:] for j in range(imgs[i].shape[0])] # [(h,w,3)]
-         if mode == 'gif':
-             imageio.mimwrite(os.path.join(res_dir, f'samples_{start_idx+i}.gif'), frames, format='GIF', duration=duration)
-         elif mode == 'mp4':
-             frames = [torch.from_numpy(frame) for frame in frames]
-             frames = torch.stack(frames, dim=0).to(torch.uint8) # [T, H, W, C]
-             torchvision.io.write_video(os.path.join(res_dir, f'samples_{start_idx+i}.mp4'),
-                                        frames, fps=0.5, video_codec='h264', options={'crf': '10'})
-         if i+ 1 == num_videos:
-             break
-
- # ----------------------------------------------------------------------------------------------
- def fill_with_black_squares(video, desired_len: int) -> Tensor:
-     if len(video) >= desired_len:
-         return video
-
-     return torch.cat([
-         video,
-         torch.zeros_like(video[0]).unsqueeze(0).repeat(desired_len - len(video), 1, 1, 1),
-     ], dim=0)
-
- # ----------------------------------------------------------------------------------------------
- def load_num_videos(data_path, num_videos):
-     # data_path can be either data_path of np array
-     if isinstance(data_path, str):
-         videos = np.load(data_path)['arr_0'] # NTHWC
-     elif isinstance(data_path, np.ndarray):
-         videos = data_path
-     else:
-         raise Exception
-
-     if num_videos is not None:
-         videos = videos[:num_videos, :, :, :, :]
-     return videos
-
- # ----------------------------------------------------------------------------------------------
- def npz_to_video_grid(data_path, out_path, num_frames=None, fps=8, num_videos=None, nrow=None, verbose=True):
-     if isinstance(data_path, str):
-         videos = load_num_videos(data_path, num_videos)
-     elif isinstance(data_path, np.ndarray):
-         videos = data_path
-     else:
-         raise Exception
-     n,t,h,w,c = videos.shape
-
-     videos_th = []
-     for i in range(n):
-         video = videos[i, :,:,:,:]
-         images = [video[j, :,:,:] for j in range(t)]
-         images = [to_tensor(img) for img in images]
-         video = torch.stack(images)
-         videos_th.append(video)
-
-     if num_frames is None:
-         num_frames = videos.shape[1]
-     if verbose:
-         videos = [fill_with_black_squares(v, num_frames) for v in tqdm(videos_th, desc='Adding empty frames')] # NTCHW
-     else:
-         videos = [fill_with_black_squares(v, num_frames) for v in videos_th] # NTCHW
-
-     frame_grids = torch.stack(videos).permute(1, 0, 2, 3, 4) # [T, N, C, H, W]
-     if nrow is None:
-         nrow = int(np.ceil(np.sqrt(n)))
-     if verbose:
-         frame_grids = [make_grid(fs, nrow=nrow) for fs in tqdm(frame_grids, desc='Making grids')]
-     else:
-         frame_grids = [make_grid(fs, nrow=nrow) for fs in frame_grids]
-
-     if os.path.dirname(out_path) != "":
-         os.makedirs(os.path.dirname(out_path), exist_ok=True)
-     frame_grids = (torch.stack(frame_grids) * 255).to(torch.uint8).permute(0, 2, 3, 1) # [T, H, W, C]
-     torchvision.io.write_video(out_path, frame_grids, fps=fps, video_codec='h264', options={'crf': '10'})
-
- # ----------------------------------------------------------------------------------------------
- def npz_to_gif_grid(data_path, out_path, n_cols=None, num_videos=20):
-     arr = np.load(data_path)
-     imgs = arr['arr_0'] # [N, T, H, W, 3]
-     imgs = imgs[:num_videos]
-     n, t, h, w, c = imgs.shape
-     assert(n == num_videos)
-     n_cols = n_cols if n_cols else imgs.shape[0]
-     n_rows = np.ceil(imgs.shape[0] / n_cols).astype(np.int8)
-     H, W = h * n_rows, w * n_cols
-     grid = np.zeros((t, H, W, c), dtype=np.uint8)
-
-     for i in range(n_rows):
-         for j in range(n_cols):
-             if i*n_cols+j < imgs.shape[0]:
-                 grid[:, i*h:(i+1)*h, j*w:(j+1)*w, :] = imgs[i*n_cols+j, :, :, :, :]
-
-     videos = [grid[i] for i in range(grid.shape[0])] # grid: TH'W'C
-     imageio.mimwrite(out_path, videos, format='GIF', duration=0.5, palettesize=256)
-
-
- # ----------------------------------------------------------------------------------------------
- def torch_to_video_grid(videos, out_path, num_frames, fps, num_videos=None, nrow=None, verbose=True):
-     """
-     videos: -1 ~ 1, torch.Tensor, BCTHW
-     """
-     n,t,h,w,c = videos.shape
-     videos_th = [videos[i, ...] for i in range(n)]
-     if verbose:
-         videos = [fill_with_black_squares(v, num_frames) for v in tqdm(videos_th, desc='Adding empty frames')] # NTCHW
-     else:
-         videos = [fill_with_black_squares(v, num_frames) for v in videos_th] # NTCHW
-
-     frame_grids = torch.stack(videos).permute(1, 0, 2, 3, 4) # [T, N, C, H, W]
-     if nrow is None:
-         nrow = int(np.ceil(np.sqrt(n)))
-     if verbose:
-         frame_grids = [make_grid(fs, nrow=nrow) for fs in tqdm(frame_grids, desc='Making grids')]
-     else:
-         frame_grids = [make_grid(fs, nrow=nrow) for fs in frame_grids]
-
-     if os.path.dirname(out_path) != "":
-         os.makedirs(os.path.dirname(out_path), exist_ok=True)
-     frame_grids = ((torch.stack(frame_grids) + 1) / 2 * 255).to(torch.uint8).permute(0, 2, 3, 1) # [T, H, W, C]
-     torchvision.io.write_video(out_path, frame_grids, fps=fps, video_codec='h264', options={'crf': '10'})
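Note: this update removes `lvdm/utils/saving_utils.py` entirely, so helpers such as `npz_to_video_grid` and `torch_to_video_grid` are no longer available. If downstream code still relies on the grid-export behaviour, the sketch below reproduces it in a self-contained form. It assumes the same `[N, T, H, W, C]` uint8 layout the removed helpers expected; `save_video_grid` is an illustrative name, not an API of this repository.

```python
# Minimal sketch of the grid-saving behaviour provided by the removed
# saving_utils helpers: build one grid image per time step, then write the
# stack of grids out as a single MP4.
import os

import numpy as np
import torch
import torchvision
from torchvision.utils import make_grid


def save_video_grid(videos: np.ndarray, out_path: str, fps: int = 8, nrow=None):
    # videos: uint8 array shaped [N, T, H, W, C] (same layout as the old npz files)
    n, t, h, w, c = videos.shape
    if nrow is None:
        nrow = int(np.ceil(np.sqrt(n)))  # roughly square grid, as in the old code
    # [N, T, H, W, C] uint8 -> [T, N, C, H, W] float in [0, 1]
    vids = torch.from_numpy(videos).permute(1, 0, 4, 2, 3).float() / 255.0
    # one grid image per time step
    frames = [make_grid(frame, nrow=nrow) for frame in vids]  # each [C, H', W']
    grid = (torch.stack(frames) * 255).to(torch.uint8).permute(0, 2, 3, 1)  # [T, H', W', C]
    if os.path.dirname(out_path):
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
    torchvision.io.write_video(out_path, grid, fps=fps,
                               video_codec='h264', options={'crf': '10'})


if __name__ == '__main__':
    dummy = (np.random.rand(4, 16, 64, 64, 3) * 255).astype(np.uint8)  # 4 fake videos
    save_video_grid(dummy, 'results/grid.mp4', fps=8)
```

As in the removed code, the grids are assembled with `torchvision.utils.make_grid` and written through `torchvision.io.write_video` with the h264 codec.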
prompts/i2v_prompts/horse.png ADDED
prompts/i2v_prompts/seashore.png ADDED
prompts/i2v_prompts/test_prompts.txt ADDED
@@ -0,0 +1,2 @@
+ horses are walking on the grassland
+ a boy and a girl are talking on the seashore
prompts/test_prompts.txt ADDED
@@ -0,0 +1,2 @@
+ A tiger walks in the forest, photorealistic, 4k, high definition
+ A boat moving on the sea, flowers and grassland on the shore
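The new prompt files (`prompts/test_prompts.txt` and `prompts/i2v_prompts/test_prompts.txt`) keep one prompt per line. A minimal loader for batch sampling might look like the sketch below; `load_prompts` is a hypothetical helper, not an API exposed by this repository.

```python
# Read a prompt file with one prompt per line, skipping blank lines.
from pathlib import Path
from typing import List


def load_prompts(path: str) -> List[str]:
    return [line.strip() for line in Path(path).read_text().splitlines() if line.strip()]


if __name__ == '__main__':
    for prompt in load_prompts('prompts/test_prompts.txt'):
        print(prompt)
    # -> A tiger walks in the forest, photorealistic, 4k, high definition
    # -> A boat moving on the sea, flowers and grassland on the shore
```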
requirements.txt CHANGED
@@ -16,7 +16,8 @@ transformers==4.25.1
  moviepy
  av
  xformers
- gradio==3.24.1
- gradio-client==0.1.2
  timm
- # -e .
 
  moviepy
  av
  xformers
+ gradio
  timm
+ scikit-learn
+ open_clip_torch
+ kornia
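The requirements change unpins `gradio`, drops `gradio-client`, and adds `scikit-learn`, `open_clip_torch`, and `kornia`. Below is an optional sanity check that the new dependencies resolve in the environment; note that two of the distribution names differ from their import names.

```python
# Verify that the newly listed dependencies can be imported.
import importlib

for dist, module in [
    ('gradio', 'gradio'),
    ('timm', 'timm'),
    ('scikit-learn', 'sklearn'),       # distribution name differs from module name
    ('open_clip_torch', 'open_clip'),  # distribution name differs from module name
    ('kornia', 'kornia'),
]:
    try:
        importlib.import_module(module)
        print(f'{dist}: OK')
    except ImportError as err:
        print(f'{dist}: missing ({err})')
```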
sample_adapter.sh DELETED
@@ -1,22 +0,0 @@
- PROMPT="An ostrich walking in the desert, photorealistic, 4k"
- VIDEO="input/flamingo.mp4"
- OUTDIR="results/"
-
- NAME="video_adapter"
- CONFIG_PATH="models/adapter_t2v_depth/model_config.yaml"
- BASE_PATH="models/base_t2v/model.ckpt"
- ADAPTER_PATH="models/adapter_t2v_depth/adapter.pth"
-
- python scripts/sample_text2video_adapter.py \
-     --seed 123 \
-     --ckpt_path $BASE_PATH \
-     --adapter_ckpt $ADAPTER_PATH \
-     --base $CONFIG_PATH \
-     --savedir $OUTDIR/$NAME \
-     --bs 1 --height 256 --width 256 \
-     --frame_stride -1 \
-     --unconditional_guidance_scale 15.0 \
-     --ddim_steps 50 \
-     --ddim_eta 1.0 \
-     --prompt "$PROMPT" \
-     --video $VIDEO
sample_text2video.sh DELETED
@@ -1,16 +0,0 @@
-
- PROMPT="astronaut riding a horse" # OR: PROMPT="input/prompts.txt" for sampling multiple prompts
- OUTDIR="results/"
-
- BASE_PATH="models/base_t2v/model.ckpt"
- CONFIG_PATH="models/base_t2v/model_config.yaml"
-
- python scripts/sample_text2video.py \
-     --ckpt_path $BASE_PATH \
-     --config_path $CONFIG_PATH \
-     --prompt "$PROMPT" \
-     --save_dir $OUTDIR \
-     --n_samples 1 \
-     --batch_size 1 \
-     --seed 1000 \
-     --show_denoising_progress
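For reference, the removed wrapper amounts to the Python invocation below, driving the old `scripts/sample_text2video.py` entry point with the same flags the shell script used. Since this update also replaces that entry point and its checkpoints, treat this as a record of the previous workflow rather than a recipe for the new code.

```python
# Python equivalent of the deleted sample_text2video.sh wrapper (old entry point).
import subprocess

cmd = [
    'python', 'scripts/sample_text2video.py',
    '--ckpt_path', 'models/base_t2v/model.ckpt',
    '--config_path', 'models/base_t2v/model_config.yaml',
    '--prompt', 'astronaut riding a horse',  # or a .txt file with one prompt per line
    '--save_dir', 'results/',
    '--n_samples', '1',
    '--batch_size', '1',
    '--seed', '1000',
    '--show_denoising_progress',
]
subprocess.run(cmd, check=True)
```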