ProgramerSalar committed
Commit 524f2f8 · 1 Parent(s): c171fe2
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. Path/.gitattributes +35 -0
  2. Path/README.md +3 -0
  3. Path/causal_video_vae/config.json +92 -0
  4. Path/causal_video_vae/diffusion_pytorch_model.safetensors +3 -0
  5. Path/diffusion_transformer_384p/config.json +20 -0
  6. Path/diffusion_transformer_384p/diffusion_pytorch_model.safetensors +3 -0
  7. Path/diffusion_transformer_768p/config.json +20 -0
  8. Path/diffusion_transformer_768p/diffusion_pytorch_model.safetensors +3 -0
  9. Path/text_encoder/config.json +24 -0
  10. Path/text_encoder/model.safetensors +3 -0
  11. Path/text_encoder_2/config.json +24 -0
  12. Path/text_encoder_2/model.safetensors +3 -0
  13. Path/text_encoder_3/config.json +31 -0
  14. Path/text_encoder_3/model-00001-of-00002.safetensors +3 -0
  15. Path/text_encoder_3/model-00002-of-00002.safetensors +3 -0
  16. Path/text_encoder_3/model.safetensors.index.json +226 -0
  17. Path/tokenizer/merges.txt +0 -0
  18. Path/tokenizer/special_tokens_map.json +30 -0
  19. Path/tokenizer/tokenizer_config.json +30 -0
  20. Path/tokenizer/vocab.json +0 -0
  21. Path/tokenizer_2/merges.txt +0 -0
  22. Path/tokenizer_2/special_tokens_map.json +30 -0
  23. Path/tokenizer_2/tokenizer_config.json +38 -0
  24. Path/tokenizer_2/vocab.json +0 -0
  25. Path/tokenizer_3/special_tokens_map.json +125 -0
  26. Path/tokenizer_3/spiece.model +3 -0
  27. Path/tokenizer_3/tokenizer.json +0 -0
  28. Path/tokenizer_3/tokenizer_config.json +940 -0
  29. annotation/image_text.jsonl +20 -0
  30. annotation/video_text.jsonl +17 -0
  31. app.py +356 -0
  32. app_multigpu.py +143 -0
  33. assets/motivation.jpg +0 -0
  34. assets/the_great_wall.jpg +0 -0
  35. assets/user_study.jpg +0 -0
  36. assets/vbench.jpg +0 -0
  37. causal_video_vae_demo.ipynb +221 -0
  38. dataset/__init__.py +12 -0
  39. dataset/bucket_loader.py +148 -0
  40. dataset/dataloaders.py +190 -0
  41. dataset/dataset_cls.py +377 -0
  42. diffusion_schedulers/__init__.py +2 -0
  43. diffusion_schedulers/scheduling_cosine_ddpm.py +137 -0
  44. diffusion_schedulers/scheduling_flow_matching.py +297 -0
  45. docs/DiT.md +54 -0
  46. docs/VAE.md +43 -0
  47. image_generation_demo.ipynb +123 -0
  48. inference_multigpu.py +123 -0
  49. pyramid_dit/__init__.py +3 -0
  50. pyramid_dit/flux_modules/__init__.py +3 -0
Path/.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
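The rules above route large binary artifacts (checkpoints, archives, serialized tensors) through Git LFS, so the repository stores small pointer files rather than the blobs themselves. As a rough illustration only (not part of the repository), a standard-library sketch of checking a path against a subset of these globs:

from fnmatch import fnmatch

# A subset of the LFS-tracked patterns listed above.
LFS_PATTERNS = ["*.safetensors", "*.bin", "*.ckpt", "*.pt", "*.pth"]

def is_lfs_tracked(path: str) -> bool:
    """Return True if the file name matches one of the LFS-tracked globs."""
    name = path.rsplit("/", 1)[-1]
    return any(fnmatch(name, pattern) for pattern in LFS_PATTERNS)

print(is_lfs_tracked("causal_video_vae/diffusion_pytorch_model.safetensors"))  # True
print(is_lfs_tracked("causal_video_vae/config.json"))                          # False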
Path/README.md ADDED
@@ -0,0 +1,3 @@
+ ---
+ license: mit
+ ---
Path/causal_video_vae/config.json ADDED
@@ -0,0 +1,92 @@
+ {
+ "_class_name": "CausalVideoVAE",
+ "_diffusers_version": "0.29.2",
+ "add_post_quant_conv": true,
+ "decoder_act_fn": "silu",
+ "decoder_block_dropout": [
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ],
+ "decoder_block_out_channels": [
+ 128,
+ 256,
+ 512,
+ 512
+ ],
+ "decoder_in_channels": 16,
+ "decoder_layers_per_block": [
+ 3,
+ 3,
+ 3,
+ 3
+ ],
+ "decoder_norm_num_groups": 32,
+ "decoder_out_channels": 3,
+ "decoder_spatial_up_sample": [
+ true,
+ true,
+ true,
+ false
+ ],
+ "decoder_temporal_up_sample": [
+ true,
+ true,
+ true,
+ false
+ ],
+ "decoder_type": "causal_vae_conv",
+ "decoder_up_block_types": [
+ "UpDecoderBlockCausal3D",
+ "UpDecoderBlockCausal3D",
+ "UpDecoderBlockCausal3D",
+ "UpDecoderBlockCausal3D"
+ ],
+ "downsample_scale": 8,
+ "encoder_act_fn": "silu",
+ "encoder_block_dropout": [
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ],
+ "encoder_block_out_channels": [
+ 128,
+ 256,
+ 512,
+ 512
+ ],
+ "encoder_double_z": true,
+ "encoder_down_block_types": [
+ "DownEncoderBlockCausal3D",
+ "DownEncoderBlockCausal3D",
+ "DownEncoderBlockCausal3D",
+ "DownEncoderBlockCausal3D"
+ ],
+ "encoder_in_channels": 3,
+ "encoder_layers_per_block": [
+ 2,
+ 2,
+ 2,
+ 2
+ ],
+ "encoder_norm_num_groups": 32,
+ "encoder_out_channels": 16,
+ "encoder_spatial_down_sample": [
+ true,
+ true,
+ true,
+ false
+ ],
+ "encoder_temporal_down_sample": [
+ true,
+ true,
+ true,
+ false
+ ],
+ "encoder_type": "causal_vae_conv",
+ "interpolate": false,
+ "sample_size": 256,
+ "scaling_factor": 0.13025
+ }
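The VAE config above implies an 8x spatial (and temporal) compression into a 16-channel latent space. A minimal sketch, using only the standard library, of how these fields relate (the config path is assumed to follow this commit's layout):

import json

# Path assumed from this commit's layout.
with open("Path/causal_video_vae/config.json") as f:
    cfg = json.load(f)

# Three of the four encoder blocks downsample spatially, giving 2**3 = 8.
spatial_factor = 2 ** sum(cfg["encoder_spatial_down_sample"])
assert spatial_factor == cfg["downsample_scale"] == 8

# A 256x256 RGB frame maps to a 32x32 latent with 16 channels.
latent_size = cfg["sample_size"] // cfg["downsample_scale"]
print(latent_size, cfg["encoder_out_channels"])  # 32 16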
Path/causal_video_vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c3e46120712f14a7e8a010264294754c113677085bab2b6cf051f1758cab238e
+ size 1341638684
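As with every large file in this commit, the safetensors weights are stored as a Git LFS pointer: a three-line text stub recording the spec version, the SHA-256 of the real blob, and its size in bytes (about 1.34 GB here). A minimal standard-library sketch of reading such a pointer:

def read_lfs_pointer(path: str) -> dict:
    """Parse a 'key value' Git LFS pointer file into a dict."""
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

ptr = read_lfs_pointer("Path/causal_video_vae/diffusion_pytorch_model.safetensors")
print(ptr["oid"], int(ptr["size"]) / 1e9, "GB")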
Path/diffusion_transformer_384p/config.json ADDED
@@ -0,0 +1,20 @@
+ {
+ "_class_name": "PyramidDiffusionMMDiT",
+ "_diffusers_version": "0.30.0",
+ "attention_head_dim": 64,
+ "caption_projection_dim": 1536,
+ "in_channels": 16,
+ "joint_attention_dim": 4096,
+ "max_num_frames": 200,
+ "num_attention_heads": 24,
+ "num_layers": 24,
+ "patch_size": 2,
+ "pooled_projection_dim": 2048,
+ "pos_embed_max_size": 192,
+ "pos_embed_type": "sincos",
+ "qk_norm": "rms_norm",
+ "sample_size": 128,
+ "use_flash_attn": false,
+ "use_gradient_checkpointing": false,
+ "use_temporal_causal": true
+ }
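The transformer config fixes the model width implicitly: 24 attention heads of 64 dims each give a 1536-wide stream, matching caption_projection_dim, and patch_size 2 over the configured sample_size of 128 yields 64 x 64 patch tokens per latent frame. A quick arithmetic sketch with values copied from the config above:

# Values copied from diffusion_transformer_384p/config.json above.
num_attention_heads = 24
attention_head_dim = 64
sample_size = 128   # side length of the latent grid the model is configured for
patch_size = 2

hidden_size = num_attention_heads * attention_head_dim
tokens_per_frame = (sample_size // patch_size) ** 2

print(hidden_size)       # 1536, matches caption_projection_dim
print(tokens_per_frame)  # 4096 patch tokens per latent frame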
Path/diffusion_transformer_384p/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:59eeee86ade2c2765a6703e628c81e09b7c5e7191775aa80c5fb552313cc33be
+ size 8339919144
Path/diffusion_transformer_768p/config.json ADDED
@@ -0,0 +1,20 @@
+ {
+ "_class_name": "PyramidDiffusionMMDiT",
+ "_diffusers_version": "0.30.0",
+ "attention_head_dim": 64,
+ "caption_projection_dim": 1536,
+ "in_channels": 16,
+ "joint_attention_dim": 4096,
+ "max_num_frames": 200,
+ "num_attention_heads": 24,
+ "num_layers": 24,
+ "patch_size": 2,
+ "pooled_projection_dim": 2048,
+ "pos_embed_max_size": 192,
+ "pos_embed_type": "sincos",
+ "qk_norm": "rms_norm",
+ "sample_size": 128,
+ "use_flash_attn": false,
+ "use_gradient_checkpointing": false,
+ "use_temporal_causal": true
+ }
Path/diffusion_transformer_768p/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7262ac5271e549ea87b8b05610663f16cf490eec19f525d4e2709c929cf52665
+ size 8339919144
Path/text_encoder/config.json ADDED
@@ -0,0 +1,24 @@
+ {
+ "architectures": [
+ "CLIPTextModelWithProjection"
+ ],
+ "attention_dropout": 0.0,
+ "bos_token_id": 0,
+ "dropout": 0.0,
+ "eos_token_id": 2,
+ "hidden_act": "quick_gelu",
+ "hidden_size": 768,
+ "initializer_factor": 1.0,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-05,
+ "max_position_embeddings": 77,
+ "model_type": "clip_text_model",
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "pad_token_id": 1,
+ "projection_dim": 768,
+ "torch_dtype": "float16",
+ "transformers_version": "4.41.2",
+ "vocab_size": 49408
+ }
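This is a standard CLIP text encoder config (12 layers, hidden size 768, 77-token context) stored in half precision. Assuming the directory layout added in this commit, it could be loaded with transformers roughly as follows (the repository path below is a placeholder):

import torch
from transformers import CLIPTextModelWithProjection

# "local_checkpoint_dir" is a placeholder for wherever this repository is downloaded.
text_encoder = CLIPTextModelWithProjection.from_pretrained(
    "local_checkpoint_dir",
    subfolder="text_encoder",
    torch_dtype=torch.float16,
)
print(text_encoder.config.hidden_size, text_encoder.config.max_position_embeddings)  # 768 77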
Path/text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:71e183d11db0c6b6282a4d9e0abb74125edc8692393e89ed8ee5571005f35cb1
+ size 247323896
Path/text_encoder_2/config.json ADDED
@@ -0,0 +1,24 @@
+ {
+ "architectures": [
+ "CLIPTextModelWithProjection"
+ ],
+ "attention_dropout": 0.0,
+ "bos_token_id": 0,
+ "dropout": 0.0,
+ "eos_token_id": 2,
+ "hidden_act": "gelu",
+ "hidden_size": 1280,
+ "initializer_factor": 1.0,
+ "initializer_range": 0.02,
+ "intermediate_size": 5120,
+ "layer_norm_eps": 1e-05,
+ "max_position_embeddings": 77,
+ "model_type": "clip_text_model",
+ "num_attention_heads": 20,
+ "num_hidden_layers": 32,
+ "pad_token_id": 1,
+ "projection_dim": 1280,
+ "torch_dtype": "float16",
+ "transformers_version": "4.41.2",
+ "vocab_size": 49408
+ }
Path/text_encoder_2/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ec310df2af79c318e24d20511b601a591ca8cd4f1fce1d8dff822a356bcdb1f4
+ size 1389382176
Path/text_encoder_3/config.json ADDED
@@ -0,0 +1,31 @@
+ {
+ "architectures": [
+ "T5EncoderModel"
+ ],
+ "classifier_dropout": 0.0,
+ "d_ff": 10240,
+ "d_kv": 64,
+ "d_model": 4096,
+ "decoder_start_token_id": 0,
+ "dense_act_fn": "gelu_new",
+ "dropout_rate": 0.1,
+ "eos_token_id": 1,
+ "feed_forward_proj": "gated-gelu",
+ "initializer_factor": 1.0,
+ "is_encoder_decoder": true,
+ "is_gated_act": true,
+ "layer_norm_epsilon": 1e-06,
+ "model_type": "t5",
+ "num_decoder_layers": 24,
+ "num_heads": 64,
+ "num_layers": 24,
+ "output_past": true,
+ "pad_token_id": 0,
+ "relative_attention_max_distance": 128,
+ "relative_attention_num_buckets": 32,
+ "tie_word_embeddings": false,
+ "torch_dtype": "float16",
+ "transformers_version": "4.41.2",
+ "use_cache": true,
+ "vocab_size": 32128
+ }
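text_encoder_3 is a large T5 encoder (24 layers, d_model 4096, roughly the T5-XXL encoder scale), which is why its weights are sharded across the two safetensors files that follow. A hedged loading sketch, again assuming this commit's directory layout and a placeholder path:

import torch
from transformers import T5EncoderModel

# Placeholder path; transformers resolves the two shards via model.safetensors.index.json.
text_encoder_3 = T5EncoderModel.from_pretrained(
    "local_checkpoint_dir",
    subfolder="text_encoder_3",
    torch_dtype=torch.float16,
)
print(sum(p.numel() for p in text_encoder_3.parameters()))  # roughly 4.7-4.8 billion parameters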
Path/text_encoder_3/model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4f2751ceeb2a96edd693e539dc5d6bba0b8d3814f49a9b3798403a0cec4b2e3d
+ size 4994582104
Path/text_encoder_3/model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f63154532130422309532ff56f11945fbea8266c958e3133e8e5aef85c6293c7
+ size 4530066248
Path/text_encoder_3/model.safetensors.index.json ADDED
@@ -0,0 +1,226 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 9524621312
4
+ },
5
+ "weight_map": {
6
+ "encoder.block.0.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
7
+ "encoder.block.0.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
8
+ "encoder.block.0.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
9
+ "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight": "model-00001-of-00002.safetensors",
10
+ "encoder.block.0.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
11
+ "encoder.block.0.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
12
+ "encoder.block.0.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
13
+ "encoder.block.0.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
14
+ "encoder.block.0.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
15
+ "encoder.block.0.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
16
+ "encoder.block.1.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
17
+ "encoder.block.1.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
18
+ "encoder.block.1.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
19
+ "encoder.block.1.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
20
+ "encoder.block.1.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
21
+ "encoder.block.1.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
22
+ "encoder.block.1.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
23
+ "encoder.block.1.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
24
+ "encoder.block.1.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
25
+ "encoder.block.10.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
26
+ "encoder.block.10.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
27
+ "encoder.block.10.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
28
+ "encoder.block.10.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
29
+ "encoder.block.10.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
30
+ "encoder.block.10.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
31
+ "encoder.block.10.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
32
+ "encoder.block.10.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
33
+ "encoder.block.10.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
34
+ "encoder.block.11.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
35
+ "encoder.block.11.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
36
+ "encoder.block.11.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
37
+ "encoder.block.11.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
38
+ "encoder.block.11.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
39
+ "encoder.block.11.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
40
+ "encoder.block.11.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
41
+ "encoder.block.11.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
42
+ "encoder.block.11.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
43
+ "encoder.block.12.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
44
+ "encoder.block.12.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
45
+ "encoder.block.12.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
46
+ "encoder.block.12.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
47
+ "encoder.block.12.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
48
+ "encoder.block.12.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
49
+ "encoder.block.12.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
50
+ "encoder.block.12.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
51
+ "encoder.block.12.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
52
+ "encoder.block.13.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
53
+ "encoder.block.13.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
54
+ "encoder.block.13.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
55
+ "encoder.block.13.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
56
+ "encoder.block.13.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
57
+ "encoder.block.13.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
58
+ "encoder.block.13.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
59
+ "encoder.block.13.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
60
+ "encoder.block.13.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
61
+ "encoder.block.14.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
62
+ "encoder.block.14.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
63
+ "encoder.block.14.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
64
+ "encoder.block.14.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
65
+ "encoder.block.14.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
66
+ "encoder.block.14.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
67
+ "encoder.block.14.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
68
+ "encoder.block.14.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
69
+ "encoder.block.14.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
70
+ "encoder.block.15.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
71
+ "encoder.block.15.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
72
+ "encoder.block.15.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
73
+ "encoder.block.15.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
74
+ "encoder.block.15.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
75
+ "encoder.block.15.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
76
+ "encoder.block.15.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
77
+ "encoder.block.15.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
78
+ "encoder.block.15.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
79
+ "encoder.block.16.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
80
+ "encoder.block.16.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
81
+ "encoder.block.16.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
82
+ "encoder.block.16.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
83
+ "encoder.block.16.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
84
+ "encoder.block.16.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
85
+ "encoder.block.16.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
86
+ "encoder.block.16.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
87
+ "encoder.block.16.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
88
+ "encoder.block.17.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
89
+ "encoder.block.17.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
90
+ "encoder.block.17.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
91
+ "encoder.block.17.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
92
+ "encoder.block.17.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
93
+ "encoder.block.17.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
94
+ "encoder.block.17.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
95
+ "encoder.block.17.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
96
+ "encoder.block.17.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
97
+ "encoder.block.18.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
98
+ "encoder.block.18.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
99
+ "encoder.block.18.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
100
+ "encoder.block.18.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
101
+ "encoder.block.18.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
102
+ "encoder.block.18.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
103
+ "encoder.block.18.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
104
+ "encoder.block.18.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
105
+ "encoder.block.18.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
106
+ "encoder.block.19.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
107
+ "encoder.block.19.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
108
+ "encoder.block.19.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
109
+ "encoder.block.19.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
110
+ "encoder.block.19.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
111
+ "encoder.block.19.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
112
+ "encoder.block.19.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
113
+ "encoder.block.19.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
114
+ "encoder.block.19.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
115
+ "encoder.block.2.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
116
+ "encoder.block.2.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
117
+ "encoder.block.2.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
118
+ "encoder.block.2.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
119
+ "encoder.block.2.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
120
+ "encoder.block.2.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
121
+ "encoder.block.2.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
122
+ "encoder.block.2.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
123
+ "encoder.block.2.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
124
+ "encoder.block.20.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
125
+ "encoder.block.20.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
126
+ "encoder.block.20.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
127
+ "encoder.block.20.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
128
+ "encoder.block.20.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
129
+ "encoder.block.20.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
130
+ "encoder.block.20.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
131
+ "encoder.block.20.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
132
+ "encoder.block.20.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
133
+ "encoder.block.21.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
134
+ "encoder.block.21.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
135
+ "encoder.block.21.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
136
+ "encoder.block.21.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
137
+ "encoder.block.21.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
138
+ "encoder.block.21.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
139
+ "encoder.block.21.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
140
+ "encoder.block.21.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
141
+ "encoder.block.21.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
142
+ "encoder.block.22.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
143
+ "encoder.block.22.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
144
+ "encoder.block.22.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
145
+ "encoder.block.22.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
146
+ "encoder.block.22.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
147
+ "encoder.block.22.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
148
+ "encoder.block.22.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
149
+ "encoder.block.22.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
150
+ "encoder.block.22.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
151
+ "encoder.block.23.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
152
+ "encoder.block.23.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
153
+ "encoder.block.23.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
154
+ "encoder.block.23.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
155
+ "encoder.block.23.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
156
+ "encoder.block.23.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
157
+ "encoder.block.23.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
158
+ "encoder.block.23.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
159
+ "encoder.block.23.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
160
+ "encoder.block.3.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
161
+ "encoder.block.3.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
162
+ "encoder.block.3.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
163
+ "encoder.block.3.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
164
+ "encoder.block.3.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
165
+ "encoder.block.3.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
166
+ "encoder.block.3.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
167
+ "encoder.block.3.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
168
+ "encoder.block.3.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
169
+ "encoder.block.4.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
170
+ "encoder.block.4.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
171
+ "encoder.block.4.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
172
+ "encoder.block.4.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
173
+ "encoder.block.4.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
174
+ "encoder.block.4.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
175
+ "encoder.block.4.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
176
+ "encoder.block.4.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
177
+ "encoder.block.4.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
178
+ "encoder.block.5.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
179
+ "encoder.block.5.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
180
+ "encoder.block.5.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
181
+ "encoder.block.5.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
182
+ "encoder.block.5.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
183
+ "encoder.block.5.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
184
+ "encoder.block.5.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
185
+ "encoder.block.5.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
186
+ "encoder.block.5.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
187
+ "encoder.block.6.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
188
+ "encoder.block.6.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
189
+ "encoder.block.6.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
190
+ "encoder.block.6.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
191
+ "encoder.block.6.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
192
+ "encoder.block.6.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
193
+ "encoder.block.6.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
194
+ "encoder.block.6.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
195
+ "encoder.block.6.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
196
+ "encoder.block.7.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
197
+ "encoder.block.7.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
198
+ "encoder.block.7.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
199
+ "encoder.block.7.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
200
+ "encoder.block.7.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
201
+ "encoder.block.7.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
202
+ "encoder.block.7.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
203
+ "encoder.block.7.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
204
+ "encoder.block.7.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
205
+ "encoder.block.8.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
206
+ "encoder.block.8.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
207
+ "encoder.block.8.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
208
+ "encoder.block.8.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
209
+ "encoder.block.8.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
210
+ "encoder.block.8.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
211
+ "encoder.block.8.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
212
+ "encoder.block.8.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
213
+ "encoder.block.8.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
214
+ "encoder.block.9.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
215
+ "encoder.block.9.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
216
+ "encoder.block.9.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
217
+ "encoder.block.9.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
218
+ "encoder.block.9.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
219
+ "encoder.block.9.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
220
+ "encoder.block.9.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
221
+ "encoder.block.9.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
222
+ "encoder.block.9.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
223
+ "encoder.final_layer_norm.weight": "model-00002-of-00002.safetensors",
224
+ "shared.weight": "model-00001-of-00002.safetensors"
225
+ }
226
+ }
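The index file above maps every T5 encoder parameter name to the shard that stores it; metadata.total_size records the combined parameter bytes (about 9.5 GB across the two shards). A small standard-library sketch of how such an index can be inspected, grouping tensors by shard (path assumed from this commit's layout):

import json
from collections import Counter

with open("Path/text_encoder_3/model.safetensors.index.json") as f:
    index = json.load(f)

per_shard = Counter(index["weight_map"].values())
print(index["metadata"]["total_size"])   # 9524621312 bytes (~9.5 GB)
for shard, n_tensors in sorted(per_shard.items()):
    print(shard, n_tensors)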
Path/tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
Path/tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
+ {
+ "bos_token": {
+ "content": "<|startoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "eos_token": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "unk_token": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ }
+ }
Path/tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,30 @@
+ {
+ "add_prefix_space": false,
+ "added_tokens_decoder": {
+ "49406": {
+ "content": "<|startoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "49407": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "bos_token": "<|startoftext|>",
+ "clean_up_tokenization_spaces": true,
+ "do_lower_case": true,
+ "eos_token": "<|endoftext|>",
+ "errors": "replace",
+ "model_max_length": 77,
+ "pad_token": "<|endoftext|>",
+ "tokenizer_class": "CLIPTokenizer",
+ "unk_token": "<|endoftext|>"
+ }
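tokenizer (and tokenizer_2 below) are CLIP BPE tokenizers capped at 77 tokens. Assuming this commit's layout and a placeholder repository path, a minimal usage sketch with transformers:

from transformers import CLIPTokenizer

# Placeholder path; subfolder matches the directory added in this commit.
tokenizer = CLIPTokenizer.from_pretrained("local_checkpoint_dir", subfolder="tokenizer")
ids = tokenizer(
    "a cinematic shot of the great wall",
    padding="max_length",
    max_length=tokenizer.model_max_length,  # 77
    truncation=True,
).input_ids
print(len(ids))  # 77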
Path/tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
Path/tokenizer_2/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
Path/tokenizer_2/special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
+ {
+ "bos_token": {
+ "content": "<|startoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "eos_token": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": {
+ "content": "!",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "unk_token": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ }
+ }
Path/tokenizer_2/tokenizer_config.json ADDED
@@ -0,0 +1,38 @@
+ {
+ "add_prefix_space": false,
+ "added_tokens_decoder": {
+ "0": {
+ "content": "!",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "49406": {
+ "content": "<|startoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "49407": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "bos_token": "<|startoftext|>",
+ "clean_up_tokenization_spaces": true,
+ "do_lower_case": true,
+ "eos_token": "<|endoftext|>",
+ "errors": "replace",
+ "model_max_length": 77,
+ "pad_token": "!",
+ "tokenizer_class": "CLIPTokenizer",
+ "unk_token": "<|endoftext|>"
+ }
Path/tokenizer_2/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
Path/tokenizer_3/special_tokens_map.json ADDED
@@ -0,0 +1,125 @@
+ {
+ "additional_special_tokens": [
+ "<extra_id_0>",
+ "<extra_id_1>",
+ "<extra_id_2>",
+ "<extra_id_3>",
+ "<extra_id_4>",
+ "<extra_id_5>",
+ "<extra_id_6>",
+ "<extra_id_7>",
+ "<extra_id_8>",
+ "<extra_id_9>",
+ "<extra_id_10>",
+ "<extra_id_11>",
+ "<extra_id_12>",
+ "<extra_id_13>",
+ "<extra_id_14>",
+ "<extra_id_15>",
+ "<extra_id_16>",
+ "<extra_id_17>",
+ "<extra_id_18>",
+ "<extra_id_19>",
+ "<extra_id_20>",
+ "<extra_id_21>",
+ "<extra_id_22>",
+ "<extra_id_23>",
+ "<extra_id_24>",
+ "<extra_id_25>",
+ "<extra_id_26>",
+ "<extra_id_27>",
+ "<extra_id_28>",
+ "<extra_id_29>",
+ "<extra_id_30>",
+ "<extra_id_31>",
+ "<extra_id_32>",
+ "<extra_id_33>",
+ "<extra_id_34>",
+ "<extra_id_35>",
+ "<extra_id_36>",
+ "<extra_id_37>",
+ "<extra_id_38>",
+ "<extra_id_39>",
+ "<extra_id_40>",
+ "<extra_id_41>",
+ "<extra_id_42>",
+ "<extra_id_43>",
+ "<extra_id_44>",
+ "<extra_id_45>",
+ "<extra_id_46>",
+ "<extra_id_47>",
+ "<extra_id_48>",
+ "<extra_id_49>",
+ "<extra_id_50>",
+ "<extra_id_51>",
+ "<extra_id_52>",
+ "<extra_id_53>",
+ "<extra_id_54>",
+ "<extra_id_55>",
+ "<extra_id_56>",
+ "<extra_id_57>",
+ "<extra_id_58>",
+ "<extra_id_59>",
+ "<extra_id_60>",
+ "<extra_id_61>",
+ "<extra_id_62>",
+ "<extra_id_63>",
+ "<extra_id_64>",
+ "<extra_id_65>",
+ "<extra_id_66>",
+ "<extra_id_67>",
+ "<extra_id_68>",
+ "<extra_id_69>",
+ "<extra_id_70>",
+ "<extra_id_71>",
+ "<extra_id_72>",
+ "<extra_id_73>",
+ "<extra_id_74>",
+ "<extra_id_75>",
+ "<extra_id_76>",
+ "<extra_id_77>",
+ "<extra_id_78>",
+ "<extra_id_79>",
+ "<extra_id_80>",
+ "<extra_id_81>",
+ "<extra_id_82>",
+ "<extra_id_83>",
+ "<extra_id_84>",
+ "<extra_id_85>",
+ "<extra_id_86>",
+ "<extra_id_87>",
+ "<extra_id_88>",
+ "<extra_id_89>",
+ "<extra_id_90>",
+ "<extra_id_91>",
+ "<extra_id_92>",
+ "<extra_id_93>",
+ "<extra_id_94>",
+ "<extra_id_95>",
+ "<extra_id_96>",
+ "<extra_id_97>",
+ "<extra_id_98>",
+ "<extra_id_99>"
+ ],
+ "eos_token": {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": {
+ "content": "<pad>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "unk_token": {
+ "content": "<unk>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ }
+ }
Path/tokenizer_3/spiece.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
+ size 791656
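tokenizer_3 is the SentencePiece-based T5 tokenizer: the 100 <extra_id_*> entries above are T5's reserved sentinel tokens, and spiece.model is the SentencePiece vocabulary (stored via LFS). A hedged usage sketch, assuming this commit's layout and a placeholder repository path:

from transformers import T5TokenizerFast

# Placeholder path; subfolder matches the directory added in this commit.
tokenizer_3 = T5TokenizerFast.from_pretrained("local_checkpoint_dir", subfolder="tokenizer_3")

# "<extra_id_0>" is a T5 sentinel (id 32099 in the added_tokens_decoder map in tokenizer_config.json below).
print(tokenizer_3.convert_tokens_to_ids("<extra_id_0>"))
# T5 tokenizers append </s> (id 1) by default.
print(tokenizer_3("A large crowd walks along the great wall").input_ids[-1])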
Path/tokenizer_3/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
Path/tokenizer_3/tokenizer_config.json ADDED
@@ -0,0 +1,940 @@
1
+ {
2
+ "add_prefix_space": true,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "<pad>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "</s>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "<unk>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "32000": {
29
+ "content": "<extra_id_99>",
30
+ "lstrip": true,
31
+ "normalized": false,
32
+ "rstrip": true,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "32001": {
37
+ "content": "<extra_id_98>",
38
+ "lstrip": true,
39
+ "normalized": false,
40
+ "rstrip": true,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "32002": {
45
+ "content": "<extra_id_97>",
46
+ "lstrip": true,
47
+ "normalized": false,
48
+ "rstrip": true,
49
+ "single_word": false,
50
+ "special": true
51
+ },
52
+ "32003": {
53
+ "content": "<extra_id_96>",
54
+ "lstrip": true,
55
+ "normalized": false,
56
+ "rstrip": true,
57
+ "single_word": false,
58
+ "special": true
59
+ },
60
+ "32004": {
61
+ "content": "<extra_id_95>",
62
+ "lstrip": true,
63
+ "normalized": false,
64
+ "rstrip": true,
65
+ "single_word": false,
66
+ "special": true
67
+ },
68
+ "32005": {
69
+ "content": "<extra_id_94>",
70
+ "lstrip": true,
71
+ "normalized": false,
72
+ "rstrip": true,
73
+ "single_word": false,
74
+ "special": true
75
+ },
76
+ "32006": {
77
+ "content": "<extra_id_93>",
78
+ "lstrip": true,
79
+ "normalized": false,
80
+ "rstrip": true,
81
+ "single_word": false,
82
+ "special": true
83
+ },
84
+ "32007": {
85
+ "content": "<extra_id_92>",
86
+ "lstrip": true,
87
+ "normalized": false,
88
+ "rstrip": true,
89
+ "single_word": false,
90
+ "special": true
91
+ },
92
+ "32008": {
93
+ "content": "<extra_id_91>",
94
+ "lstrip": true,
95
+ "normalized": false,
96
+ "rstrip": true,
97
+ "single_word": false,
98
+ "special": true
99
+ },
100
+ "32009": {
101
+ "content": "<extra_id_90>",
102
+ "lstrip": true,
103
+ "normalized": false,
104
+ "rstrip": true,
105
+ "single_word": false,
106
+ "special": true
107
+ },
108
+ "32010": {
109
+ "content": "<extra_id_89>",
110
+ "lstrip": true,
111
+ "normalized": false,
112
+ "rstrip": true,
113
+ "single_word": false,
114
+ "special": true
115
+ },
116
+ "32011": {
117
+ "content": "<extra_id_88>",
118
+ "lstrip": true,
119
+ "normalized": false,
120
+ "rstrip": true,
121
+ "single_word": false,
122
+ "special": true
123
+ },
124
+ "32012": {
125
+ "content": "<extra_id_87>",
126
+ "lstrip": true,
127
+ "normalized": false,
128
+ "rstrip": true,
129
+ "single_word": false,
130
+ "special": true
131
+ },
132
+ "32013": {
133
+ "content": "<extra_id_86>",
134
+ "lstrip": true,
135
+ "normalized": false,
136
+ "rstrip": true,
137
+ "single_word": false,
138
+ "special": true
139
+ },
140
+ "32014": {
141
+ "content": "<extra_id_85>",
142
+ "lstrip": true,
143
+ "normalized": false,
144
+ "rstrip": true,
145
+ "single_word": false,
146
+ "special": true
147
+ },
148
+ "32015": {
149
+ "content": "<extra_id_84>",
150
+ "lstrip": true,
151
+ "normalized": false,
152
+ "rstrip": true,
153
+ "single_word": false,
154
+ "special": true
155
+ },
156
+ "32016": {
157
+ "content": "<extra_id_83>",
158
+ "lstrip": true,
159
+ "normalized": false,
160
+ "rstrip": true,
161
+ "single_word": false,
162
+ "special": true
163
+ },
164
+ "32017": {
165
+ "content": "<extra_id_82>",
166
+ "lstrip": true,
167
+ "normalized": false,
168
+ "rstrip": true,
169
+ "single_word": false,
170
+ "special": true
171
+ },
172
+ "32018": {
173
+ "content": "<extra_id_81>",
174
+ "lstrip": true,
175
+ "normalized": false,
176
+ "rstrip": true,
177
+ "single_word": false,
178
+ "special": true
179
+ },
180
+ "32019": {
181
+ "content": "<extra_id_80>",
182
+ "lstrip": true,
183
+ "normalized": false,
184
+ "rstrip": true,
185
+ "single_word": false,
186
+ "special": true
187
+ },
188
+ "32020": {
189
+ "content": "<extra_id_79>",
190
+ "lstrip": true,
191
+ "normalized": false,
192
+ "rstrip": true,
193
+ "single_word": false,
194
+ "special": true
195
+ },
196
+ "32021": {
197
+ "content": "<extra_id_78>",
198
+ "lstrip": true,
199
+ "normalized": false,
200
+ "rstrip": true,
201
+ "single_word": false,
202
+ "special": true
203
+ },
204
+ "32022": {
205
+ "content": "<extra_id_77>",
206
+ "lstrip": true,
207
+ "normalized": false,
208
+ "rstrip": true,
209
+ "single_word": false,
210
+ "special": true
211
+ },
212
+ "32023": {
213
+ "content": "<extra_id_76>",
214
+ "lstrip": true,
215
+ "normalized": false,
216
+ "rstrip": true,
217
+ "single_word": false,
218
+ "special": true
219
+ },
220
+ "32024": {
221
+ "content": "<extra_id_75>",
222
+ "lstrip": true,
223
+ "normalized": false,
224
+ "rstrip": true,
225
+ "single_word": false,
226
+ "special": true
227
+ },
228
+ "32025": {
229
+ "content": "<extra_id_74>",
230
+ "lstrip": true,
231
+ "normalized": false,
232
+ "rstrip": true,
233
+ "single_word": false,
234
+ "special": true
235
+ },
236
+ "32026": {
237
+ "content": "<extra_id_73>",
238
+ "lstrip": true,
239
+ "normalized": false,
240
+ "rstrip": true,
241
+ "single_word": false,
242
+ "special": true
243
+ },
244
+ "32027": {
245
+ "content": "<extra_id_72>",
246
+ "lstrip": true,
247
+ "normalized": false,
248
+ "rstrip": true,
249
+ "single_word": false,
250
+ "special": true
251
+ },
252
+ "32028": {
253
+ "content": "<extra_id_71>",
254
+ "lstrip": true,
255
+ "normalized": false,
256
+ "rstrip": true,
257
+ "single_word": false,
258
+ "special": true
259
+ },
260
+ "32029": {
261
+ "content": "<extra_id_70>",
262
+ "lstrip": true,
263
+ "normalized": false,
264
+ "rstrip": true,
265
+ "single_word": false,
266
+ "special": true
267
+ },
268
+ "32030": {
269
+ "content": "<extra_id_69>",
270
+ "lstrip": true,
271
+ "normalized": false,
272
+ "rstrip": true,
273
+ "single_word": false,
274
+ "special": true
275
+ },
276
+ "32031": {
277
+ "content": "<extra_id_68>",
278
+ "lstrip": true,
279
+ "normalized": false,
280
+ "rstrip": true,
281
+ "single_word": false,
282
+ "special": true
283
+ },
284
+ "32032": {
285
+ "content": "<extra_id_67>",
286
+ "lstrip": true,
287
+ "normalized": false,
288
+ "rstrip": true,
289
+ "single_word": false,
290
+ "special": true
291
+ },
292
+ "32033": {
293
+ "content": "<extra_id_66>",
294
+ "lstrip": true,
295
+ "normalized": false,
296
+ "rstrip": true,
297
+ "single_word": false,
298
+ "special": true
299
+ },
300
+ "32034": {
301
+ "content": "<extra_id_65>",
302
+ "lstrip": true,
303
+ "normalized": false,
304
+ "rstrip": true,
305
+ "single_word": false,
306
+ "special": true
307
+ },
308
+ "32035": {
309
+ "content": "<extra_id_64>",
310
+ "lstrip": true,
311
+ "normalized": false,
312
+ "rstrip": true,
313
+ "single_word": false,
314
+ "special": true
315
+ },
316
+ "32036": {
317
+ "content": "<extra_id_63>",
318
+ "lstrip": true,
319
+ "normalized": false,
320
+ "rstrip": true,
321
+ "single_word": false,
322
+ "special": true
323
+ },
324
+ "32037": {
325
+ "content": "<extra_id_62>",
326
+ "lstrip": true,
327
+ "normalized": false,
328
+ "rstrip": true,
329
+ "single_word": false,
330
+ "special": true
331
+ },
332
+ "32038": {
333
+ "content": "<extra_id_61>",
334
+ "lstrip": true,
335
+ "normalized": false,
336
+ "rstrip": true,
337
+ "single_word": false,
338
+ "special": true
339
+ },
340
+ "32039": {
341
+ "content": "<extra_id_60>",
342
+ "lstrip": true,
343
+ "normalized": false,
344
+ "rstrip": true,
345
+ "single_word": false,
346
+ "special": true
347
+ },
348
+ "32040": {
349
+ "content": "<extra_id_59>",
350
+ "lstrip": true,
351
+ "normalized": false,
352
+ "rstrip": true,
353
+ "single_word": false,
354
+ "special": true
355
+ },
356
+ "32041": {
357
+ "content": "<extra_id_58>",
358
+ "lstrip": true,
359
+ "normalized": false,
360
+ "rstrip": true,
361
+ "single_word": false,
362
+ "special": true
363
+ },
364
+ "32042": {
365
+ "content": "<extra_id_57>",
366
+ "lstrip": true,
367
+ "normalized": false,
368
+ "rstrip": true,
369
+ "single_word": false,
370
+ "special": true
371
+ },
372
+ "32043": {
373
+ "content": "<extra_id_56>",
374
+ "lstrip": true,
375
+ "normalized": false,
376
+ "rstrip": true,
377
+ "single_word": false,
378
+ "special": true
379
+ },
380
+ "32044": {
381
+ "content": "<extra_id_55>",
382
+ "lstrip": true,
383
+ "normalized": false,
384
+ "rstrip": true,
385
+ "single_word": false,
386
+ "special": true
387
+ },
388
+ "32045": {
389
+ "content": "<extra_id_54>",
390
+ "lstrip": true,
391
+ "normalized": false,
392
+ "rstrip": true,
393
+ "single_word": false,
394
+ "special": true
395
+ },
396
+ "32046": {
397
+ "content": "<extra_id_53>",
398
+ "lstrip": true,
399
+ "normalized": false,
400
+ "rstrip": true,
401
+ "single_word": false,
402
+ "special": true
403
+ },
404
+ "32047": {
405
+ "content": "<extra_id_52>",
406
+ "lstrip": true,
407
+ "normalized": false,
408
+ "rstrip": true,
409
+ "single_word": false,
410
+ "special": true
411
+ },
412
+ "32048": {
413
+ "content": "<extra_id_51>",
414
+ "lstrip": true,
415
+ "normalized": false,
416
+ "rstrip": true,
417
+ "single_word": false,
418
+ "special": true
419
+ },
420
+ "32049": {
421
+ "content": "<extra_id_50>",
422
+ "lstrip": true,
423
+ "normalized": false,
424
+ "rstrip": true,
425
+ "single_word": false,
426
+ "special": true
427
+ },
428
+ "32050": {
429
+ "content": "<extra_id_49>",
430
+ "lstrip": true,
431
+ "normalized": false,
432
+ "rstrip": true,
433
+ "single_word": false,
434
+ "special": true
435
+ },
436
+ "32051": {
437
+ "content": "<extra_id_48>",
438
+ "lstrip": true,
439
+ "normalized": false,
440
+ "rstrip": true,
441
+ "single_word": false,
442
+ "special": true
443
+ },
444
+ "32052": {
445
+ "content": "<extra_id_47>",
446
+ "lstrip": true,
447
+ "normalized": false,
448
+ "rstrip": true,
449
+ "single_word": false,
450
+ "special": true
451
+ },
452
+ "32053": {
453
+ "content": "<extra_id_46>",
454
+ "lstrip": true,
455
+ "normalized": false,
456
+ "rstrip": true,
457
+ "single_word": false,
458
+ "special": true
459
+ },
460
+ "32054": {
461
+ "content": "<extra_id_45>",
462
+ "lstrip": true,
463
+ "normalized": false,
464
+ "rstrip": true,
465
+ "single_word": false,
466
+ "special": true
467
+ },
468
+ "32055": {
469
+ "content": "<extra_id_44>",
470
+ "lstrip": true,
471
+ "normalized": false,
472
+ "rstrip": true,
473
+ "single_word": false,
474
+ "special": true
475
+ },
476
+ "32056": {
477
+ "content": "<extra_id_43>",
478
+ "lstrip": true,
479
+ "normalized": false,
480
+ "rstrip": true,
481
+ "single_word": false,
482
+ "special": true
483
+ },
484
+ "32057": {
485
+ "content": "<extra_id_42>",
486
+ "lstrip": true,
487
+ "normalized": false,
488
+ "rstrip": true,
489
+ "single_word": false,
490
+ "special": true
491
+ },
492
+ "32058": {
493
+ "content": "<extra_id_41>",
494
+ "lstrip": true,
495
+ "normalized": false,
496
+ "rstrip": true,
497
+ "single_word": false,
498
+ "special": true
499
+ },
500
+ "32059": {
501
+ "content": "<extra_id_40>",
502
+ "lstrip": true,
503
+ "normalized": false,
504
+ "rstrip": true,
505
+ "single_word": false,
506
+ "special": true
507
+ },
508
+ "32060": {
509
+ "content": "<extra_id_39>",
510
+ "lstrip": true,
511
+ "normalized": false,
512
+ "rstrip": true,
513
+ "single_word": false,
514
+ "special": true
515
+ },
516
+ "32061": {
517
+ "content": "<extra_id_38>",
518
+ "lstrip": true,
519
+ "normalized": false,
520
+ "rstrip": true,
521
+ "single_word": false,
522
+ "special": true
523
+ },
524
+ "32062": {
525
+ "content": "<extra_id_37>",
526
+ "lstrip": true,
527
+ "normalized": false,
528
+ "rstrip": true,
529
+ "single_word": false,
530
+ "special": true
531
+ },
532
+ "32063": {
533
+ "content": "<extra_id_36>",
534
+ "lstrip": true,
535
+ "normalized": false,
536
+ "rstrip": true,
537
+ "single_word": false,
538
+ "special": true
539
+ },
540
+ "32064": {
541
+ "content": "<extra_id_35>",
542
+ "lstrip": true,
543
+ "normalized": false,
544
+ "rstrip": true,
545
+ "single_word": false,
546
+ "special": true
547
+ },
548
+ "32065": {
549
+ "content": "<extra_id_34>",
550
+ "lstrip": true,
551
+ "normalized": false,
552
+ "rstrip": true,
553
+ "single_word": false,
554
+ "special": true
555
+ },
556
+ "32066": {
557
+ "content": "<extra_id_33>",
558
+ "lstrip": true,
559
+ "normalized": false,
560
+ "rstrip": true,
561
+ "single_word": false,
562
+ "special": true
563
+ },
564
+ "32067": {
565
+ "content": "<extra_id_32>",
566
+ "lstrip": true,
567
+ "normalized": false,
568
+ "rstrip": true,
569
+ "single_word": false,
570
+ "special": true
571
+ },
572
+ "32068": {
573
+ "content": "<extra_id_31>",
574
+ "lstrip": true,
575
+ "normalized": false,
576
+ "rstrip": true,
577
+ "single_word": false,
578
+ "special": true
579
+ },
580
+ "32069": {
581
+ "content": "<extra_id_30>",
582
+ "lstrip": true,
583
+ "normalized": false,
584
+ "rstrip": true,
585
+ "single_word": false,
586
+ "special": true
587
+ },
588
+ "32070": {
589
+ "content": "<extra_id_29>",
590
+ "lstrip": true,
591
+ "normalized": false,
592
+ "rstrip": true,
593
+ "single_word": false,
594
+ "special": true
595
+ },
596
+ "32071": {
597
+ "content": "<extra_id_28>",
598
+ "lstrip": true,
599
+ "normalized": false,
600
+ "rstrip": true,
601
+ "single_word": false,
602
+ "special": true
603
+ },
604
+ "32072": {
605
+ "content": "<extra_id_27>",
606
+ "lstrip": true,
607
+ "normalized": false,
608
+ "rstrip": true,
609
+ "single_word": false,
610
+ "special": true
611
+ },
612
+ "32073": {
613
+ "content": "<extra_id_26>",
614
+ "lstrip": true,
615
+ "normalized": false,
616
+ "rstrip": true,
617
+ "single_word": false,
618
+ "special": true
619
+ },
620
+ "32074": {
621
+ "content": "<extra_id_25>",
622
+ "lstrip": true,
623
+ "normalized": false,
624
+ "rstrip": true,
625
+ "single_word": false,
626
+ "special": true
627
+ },
628
+ "32075": {
629
+ "content": "<extra_id_24>",
630
+ "lstrip": true,
631
+ "normalized": false,
632
+ "rstrip": true,
633
+ "single_word": false,
634
+ "special": true
635
+ },
636
+ "32076": {
637
+ "content": "<extra_id_23>",
638
+ "lstrip": true,
639
+ "normalized": false,
640
+ "rstrip": true,
641
+ "single_word": false,
642
+ "special": true
643
+ },
644
+ "32077": {
645
+ "content": "<extra_id_22>",
646
+ "lstrip": true,
647
+ "normalized": false,
648
+ "rstrip": true,
649
+ "single_word": false,
650
+ "special": true
651
+ },
652
+ "32078": {
653
+ "content": "<extra_id_21>",
654
+ "lstrip": true,
655
+ "normalized": false,
656
+ "rstrip": true,
657
+ "single_word": false,
658
+ "special": true
659
+ },
660
+ "32079": {
661
+ "content": "<extra_id_20>",
662
+ "lstrip": true,
663
+ "normalized": false,
664
+ "rstrip": true,
665
+ "single_word": false,
666
+ "special": true
667
+ },
668
+ "32080": {
669
+ "content": "<extra_id_19>",
670
+ "lstrip": true,
671
+ "normalized": false,
672
+ "rstrip": true,
673
+ "single_word": false,
674
+ "special": true
675
+ },
676
+ "32081": {
677
+ "content": "<extra_id_18>",
678
+ "lstrip": true,
679
+ "normalized": false,
680
+ "rstrip": true,
681
+ "single_word": false,
682
+ "special": true
683
+ },
684
+ "32082": {
685
+ "content": "<extra_id_17>",
686
+ "lstrip": true,
687
+ "normalized": false,
688
+ "rstrip": true,
689
+ "single_word": false,
690
+ "special": true
691
+ },
692
+ "32083": {
693
+ "content": "<extra_id_16>",
694
+ "lstrip": true,
695
+ "normalized": false,
696
+ "rstrip": true,
697
+ "single_word": false,
698
+ "special": true
699
+ },
700
+ "32084": {
701
+ "content": "<extra_id_15>",
702
+ "lstrip": true,
703
+ "normalized": false,
704
+ "rstrip": true,
705
+ "single_word": false,
706
+ "special": true
707
+ },
708
+ "32085": {
709
+ "content": "<extra_id_14>",
710
+ "lstrip": true,
711
+ "normalized": false,
712
+ "rstrip": true,
713
+ "single_word": false,
714
+ "special": true
715
+ },
716
+ "32086": {
717
+ "content": "<extra_id_13>",
718
+ "lstrip": true,
719
+ "normalized": false,
720
+ "rstrip": true,
721
+ "single_word": false,
722
+ "special": true
723
+ },
724
+ "32087": {
725
+ "content": "<extra_id_12>",
726
+ "lstrip": true,
727
+ "normalized": false,
728
+ "rstrip": true,
729
+ "single_word": false,
730
+ "special": true
731
+ },
732
+ "32088": {
733
+ "content": "<extra_id_11>",
734
+ "lstrip": true,
735
+ "normalized": false,
736
+ "rstrip": true,
737
+ "single_word": false,
738
+ "special": true
739
+ },
740
+ "32089": {
741
+ "content": "<extra_id_10>",
742
+ "lstrip": true,
743
+ "normalized": false,
744
+ "rstrip": true,
745
+ "single_word": false,
746
+ "special": true
747
+ },
748
+ "32090": {
749
+ "content": "<extra_id_9>",
750
+ "lstrip": true,
751
+ "normalized": false,
752
+ "rstrip": true,
753
+ "single_word": false,
754
+ "special": true
755
+ },
756
+ "32091": {
757
+ "content": "<extra_id_8>",
758
+ "lstrip": true,
759
+ "normalized": false,
760
+ "rstrip": true,
761
+ "single_word": false,
762
+ "special": true
763
+ },
764
+ "32092": {
765
+ "content": "<extra_id_7>",
766
+ "lstrip": true,
767
+ "normalized": false,
768
+ "rstrip": true,
769
+ "single_word": false,
770
+ "special": true
771
+ },
772
+ "32093": {
773
+ "content": "<extra_id_6>",
774
+ "lstrip": true,
775
+ "normalized": false,
776
+ "rstrip": true,
777
+ "single_word": false,
778
+ "special": true
779
+ },
780
+ "32094": {
781
+ "content": "<extra_id_5>",
782
+ "lstrip": true,
783
+ "normalized": false,
784
+ "rstrip": true,
785
+ "single_word": false,
786
+ "special": true
787
+ },
788
+ "32095": {
789
+ "content": "<extra_id_4>",
790
+ "lstrip": true,
791
+ "normalized": false,
792
+ "rstrip": true,
793
+ "single_word": false,
794
+ "special": true
795
+ },
796
+ "32096": {
797
+ "content": "<extra_id_3>",
798
+ "lstrip": true,
799
+ "normalized": false,
800
+ "rstrip": true,
801
+ "single_word": false,
802
+ "special": true
803
+ },
804
+ "32097": {
805
+ "content": "<extra_id_2>",
806
+ "lstrip": true,
807
+ "normalized": false,
808
+ "rstrip": true,
809
+ "single_word": false,
810
+ "special": true
811
+ },
812
+ "32098": {
813
+ "content": "<extra_id_1>",
814
+ "lstrip": true,
815
+ "normalized": false,
816
+ "rstrip": true,
817
+ "single_word": false,
818
+ "special": true
819
+ },
820
+ "32099": {
821
+ "content": "<extra_id_0>",
822
+ "lstrip": true,
823
+ "normalized": false,
824
+ "rstrip": true,
825
+ "single_word": false,
826
+ "special": true
827
+ }
828
+ },
829
+ "additional_special_tokens": [
830
+ "<extra_id_0>",
831
+ "<extra_id_1>",
832
+ "<extra_id_2>",
833
+ "<extra_id_3>",
834
+ "<extra_id_4>",
835
+ "<extra_id_5>",
836
+ "<extra_id_6>",
837
+ "<extra_id_7>",
838
+ "<extra_id_8>",
839
+ "<extra_id_9>",
840
+ "<extra_id_10>",
841
+ "<extra_id_11>",
842
+ "<extra_id_12>",
843
+ "<extra_id_13>",
844
+ "<extra_id_14>",
845
+ "<extra_id_15>",
846
+ "<extra_id_16>",
847
+ "<extra_id_17>",
848
+ "<extra_id_18>",
849
+ "<extra_id_19>",
850
+ "<extra_id_20>",
851
+ "<extra_id_21>",
852
+ "<extra_id_22>",
853
+ "<extra_id_23>",
854
+ "<extra_id_24>",
855
+ "<extra_id_25>",
856
+ "<extra_id_26>",
857
+ "<extra_id_27>",
858
+ "<extra_id_28>",
859
+ "<extra_id_29>",
860
+ "<extra_id_30>",
861
+ "<extra_id_31>",
862
+ "<extra_id_32>",
863
+ "<extra_id_33>",
864
+ "<extra_id_34>",
865
+ "<extra_id_35>",
866
+ "<extra_id_36>",
867
+ "<extra_id_37>",
868
+ "<extra_id_38>",
869
+ "<extra_id_39>",
870
+ "<extra_id_40>",
871
+ "<extra_id_41>",
872
+ "<extra_id_42>",
873
+ "<extra_id_43>",
874
+ "<extra_id_44>",
875
+ "<extra_id_45>",
876
+ "<extra_id_46>",
877
+ "<extra_id_47>",
878
+ "<extra_id_48>",
879
+ "<extra_id_49>",
880
+ "<extra_id_50>",
881
+ "<extra_id_51>",
882
+ "<extra_id_52>",
883
+ "<extra_id_53>",
884
+ "<extra_id_54>",
885
+ "<extra_id_55>",
886
+ "<extra_id_56>",
887
+ "<extra_id_57>",
888
+ "<extra_id_58>",
889
+ "<extra_id_59>",
890
+ "<extra_id_60>",
891
+ "<extra_id_61>",
892
+ "<extra_id_62>",
893
+ "<extra_id_63>",
894
+ "<extra_id_64>",
895
+ "<extra_id_65>",
896
+ "<extra_id_66>",
897
+ "<extra_id_67>",
898
+ "<extra_id_68>",
899
+ "<extra_id_69>",
900
+ "<extra_id_70>",
901
+ "<extra_id_71>",
902
+ "<extra_id_72>",
903
+ "<extra_id_73>",
904
+ "<extra_id_74>",
905
+ "<extra_id_75>",
906
+ "<extra_id_76>",
907
+ "<extra_id_77>",
908
+ "<extra_id_78>",
909
+ "<extra_id_79>",
910
+ "<extra_id_80>",
911
+ "<extra_id_81>",
912
+ "<extra_id_82>",
913
+ "<extra_id_83>",
914
+ "<extra_id_84>",
915
+ "<extra_id_85>",
916
+ "<extra_id_86>",
917
+ "<extra_id_87>",
918
+ "<extra_id_88>",
919
+ "<extra_id_89>",
920
+ "<extra_id_90>",
921
+ "<extra_id_91>",
922
+ "<extra_id_92>",
923
+ "<extra_id_93>",
924
+ "<extra_id_94>",
925
+ "<extra_id_95>",
926
+ "<extra_id_96>",
927
+ "<extra_id_97>",
928
+ "<extra_id_98>",
929
+ "<extra_id_99>"
930
+ ],
931
+ "clean_up_tokenization_spaces": true,
932
+ "eos_token": "</s>",
933
+ "extra_ids": 100,
934
+ "legacy": true,
935
+ "model_max_length": 512,
936
+ "pad_token": "<pad>",
937
+ "sp_model_kwargs": {},
938
+ "tokenizer_class": "T5Tokenizer",
939
+ "unk_token": "<unk>"
940
+ }
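The config above describes a standard T5 sentencepiece tokenizer: 100 `<extra_id_*>` sentinel tokens, `</s>` as EOS, `<pad>` for padding, and a `model_max_length` of 512. A minimal sketch of loading it with `transformers`, assuming the checkpoint has been downloaded to a local `./pyramid_flow_model` directory (the path used by `app.py` below); the prompt and `max_length=128` are arbitrary demo values:

    from transformers import T5Tokenizer

    # Hypothetical local path; app.py downloads the full checkpoint to ./pyramid_flow_model
    tokenizer = T5Tokenizer.from_pretrained("./pyramid_flow_model/tokenizer_3")
    enc = tokenizer(
        "FPV flying over the Great Wall",
        padding="max_length", max_length=128,  # model_max_length is 512; 128 is just a demo value
        truncation=True, return_tensors="pt",
    )
    print(enc.input_ids.shape, tokenizer.eos_token, tokenizer.pad_token)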
annotation/image_text.jsonl ADDED
@@ -0,0 +1,20 @@
1
+ {"image": "SAM_filter/000424/sa_4749867.jpg", "text": "a cityscape with a large body of water, such as a lake or a river, in the foreground"}
2
+ {"image": "SAM_filter/000311/sa_3490721.jpg", "text": "a large, stately building with a white and blue color scheme, which gives it a grand and elegant appearance"}
3
+ {"image": "SAM_filter/000273/sa_3059407.jpg", "text": "a close-up of a green bag containing a package of Japanese soybeans, along with a bottle of sake, a traditional Japanese alcoholic beverage"}
4
+ {"image": "SAM_filter/000745/sa_8344729.jpg", "text": "a large, old-fashioned building with a red and white color scheme"}
5
+ {"image": "SAM_filter/000832/sa_9310794.jpg", "text": "a cityscape with a large tower, likely the Eiffel Tower, as the main focal point"}
6
+ {"image": "SAM_filter/000427/sa_4779422.jpg", "text": "a large cruise ship, specifically a Royal Caribbean cruise ship, docked at a pier in a harbor"}
7
+ {"image": "SAM_filter/000105/sa_1178255.jpg", "text": "a close-up view of a computer screen with a magnifying glass placed over it"}
8
+ {"image": "SAM_filter/000765/sa_8560467.jpg", "text": "a tree with a sign attached to it, which is located in a lush green field"}
9
+ {"image": "SAM_filter/000216/sa_2417372.jpg", "text": "a large airport terminal with a long blue and white rope-style security line"}
10
+ {"image": "SAM_filter/000385/sa_4308806.jpg", "text": "a close-up of a cell phone screen displaying a blue and white logo, which appears to be a bank logo"}
11
+ {"image": "SAM_filter/000931/sa_10425835.jpg", "text": "a large body of water, possibly a lake, with a lush green landscape surrounding it"}
12
+ {"image": "SAM_filter/000364/sa_4079002.jpg", "text": "a large, empty airport terminal with a long row of gray metal chairs arranged in a straight line"}
13
+ {"image": "SAM_filter/000474/sa_5306222.jpg", "text": "a large, modern building with a tall, glass structure, which is likely a museum"}
14
+ {"image": "SAM_filter/000584/sa_6536849.jpg", "text": "a city street scene with a black car parked in a parking lot, a building with a balcony, and a city skyline in the background"}
15
+ {"image": "SAM_filter/000188/sa_2104485.jpg", "text": "a large jet fighter airplane flying through the sky, captured in a high-quality photograph"}
16
+ {"image": "SAM_filter/000219/sa_2458908.jpg", "text": "a stone structure with a tall tower, which is situated in a lush green garden"}
17
+ {"image": "SAM_filter/000440/sa_4929413.jpg", "text": "a large city street with a mix of architectural styles, including a Gothic-style building and a modern building"}
18
+ {"image": "SAM_filter/000739/sa_8279296.jpg", "text": "a vintage blue and white bus parked on the side of a dirt road, with a building in the background"}
19
+ {"image": "SAM_filter/000809/sa_9052304.jpg", "text": "a large, old stone building with a clock tower, which is situated in a small town"}
20
+ {"image": "SAM_filter/000294/sa_3300200.jpg", "text": "a table with various utensils, including a bowl, spoon, and fork, placed on a wooden surface"}
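Each line of `annotation/image_text.jsonl` is a self-contained JSON object with an `image` path and a `text` caption; `ImageTextDataset` (see `dataset/dataset_cls.py` below) reads it with `jsonlines`. A minimal sketch of iterating over the annotations, assuming the image paths resolve against your local data root:

    import jsonlines

    with jsonlines.open("annotation/image_text.jsonl") as reader:
        for item in reader:
            print(item["image"], "->", item["text"][:60])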
annotation/video_text.jsonl ADDED
@@ -0,0 +1,17 @@
1
+ {"video": "webvid10m/train/010451_010500/23388121.mp4", "text": "the serene beauty of a valley with a river, mountains, and clouds", "latent": "webvid10m/train/010451_010500/23388121-latent-384-2.pt", "text_fea": "text_feature/webvid10m/train/010451_010500/23388121-text.pt"}
2
+ {"video": "pexels/8440980-uhd_3840_2160_25fps.mp4", "text": "A group of people, including two men and two women, are seen sitting at a table, smiling and waving at the camera, and appear to be in a good mood", "latent": "pexels/8440980-uhd_3840_2160_25fps-latent-384-2.pt", "text_fea": "text_feature/pexels/8440980-uhd_3840_2160_25fps-text.pt"}
3
+ {"video": "webvid10m/train/176251_176300/1011015221.mp4", "text": "an aerial view of a large wheat field with a road running through it, and a car driving on the road", "latent": "webvid10m/train/176251_176300/1011015221-latent-384-4.pt", "text_fea": "text_feature/webvid10m/train/176251_176300/1011015221-text.pt"}
4
+ {"video": "webvid10m/train/005801_005850/22143805.mp4", "text": "a close-up of paint mixing in water, creating swirling patterns", "latent": "webvid10m/train/005801_005850/22143805-latent-384-8.pt", "text_fea": "text_feature/webvid10m/train/005801_005850/22143805-text.pt"}
5
+ {"video": "OpenVid-1M/videos/qsXY7FkNFwE_2_0to743.mp4", "text": "A baby girl in a pink shirt and striped pants sits in a high chair, eats a piece of bread, and looks at the camera", "latent": "OpenVid-1M/videos/qsXY7FkNFwE_2_0to743-latent-384-0.pt", "text_fea": "text_feature/OpenVid-1M/videos/qsXY7FkNFwE_2_0to743-text.pt"}
6
+ {"video": "webvid10m/train/134901_134950/1037990273.mp4", "text": "a field of green wheat waving in the wind", "latent": "webvid10m/train/134901_134950/1037990273-latent-384-6.pt", "text_fea": "text_feature/webvid10m/train/134901_134950/1037990273-text.pt"}
7
+ {"video": "pexels/5263258-uhd_2160_4096_30fps.mp4", "text": "A dog sits patiently in front of its bowl, waiting for it to be filled with food", "latent": "pexels/5263258-uhd_2160_4096_30fps-latent-384-6.pt", "text_fea": "text_feature/pexels/5263258-uhd_2160_4096_30fps-text.pt"}
8
+ {"video": "webvid10m/train/117851_117900/6461432.mp4", "text": "A ladybug crawls along a blade of grass in a serene natural setting", "latent": "webvid10m/train/117851_117900/6461432-latent-384-4.pt", "text_fea": "text_feature/webvid10m/train/117851_117900/6461432-text.pt"}
9
+ {"video": "webvid10m/train/053051_053100/1058396656.mp4", "text": "a group of construction workers working on a rooftop, with a supervisor overseeing the work", "latent": "webvid10m/train/053051_053100/1058396656-latent-384-10.pt", "text_fea": "text_feature/webvid10m/train/053051_053100/1058396656-text.pt"}
10
+ {"video": "webvid10m/train/073651_073700/1021916425.mp4", "text": "an aerial view of a beautiful coastline with rocky islands, blue water, and a white cloud in the sky", "latent": "webvid10m/train/073651_073700/1021916425-latent-384-4.pt", "text_fea": "text_feature/webvid10m/train/073651_073700/1021916425-text.pt"}
11
+ {"video": "webvid10m/train/027051_027100/1032549941.mp4", "text": "a young woman waking up in bed, smiling at the camera, and then lying back down on the bed", "latent": "webvid10m/train/027051_027100/1032549941-latent-384-10.pt", "text_fea": "text_feature/webvid10m/train/027051_027100/1032549941-text.pt"}
12
+ {"video": "pexels/5564564-uhd_3840_2160_24fps.mp4", "text": "a person rolling out dough on a table using a rolling pin", "latent": "pexels/5564564-uhd_3840_2160_24fps-latent-384-8.pt", "text_fea": "text_feature/pexels/5564564-uhd_3840_2160_24fps-text.pt"}
13
+ {"video": "webvid10m/train/073701_073750/24008116.mp4", "text": "a cityscape with a moon in the sky, and the camera pans across the city", "latent": "webvid10m/train/073701_073750/24008116-latent-384-2.pt", "text_fea": "text_feature/webvid10m/train/073701_073750/24008116-text.pt"}
14
+ {"video": "webvid10m/train/118351_118400/23370991.mp4", "text": "a group of dolphins swimming in the ocean, with a person on a boat nearby", "latent": "webvid10m/train/118351_118400/23370991-latent-384-2.pt", "text_fea": "text_feature/webvid10m/train/118351_118400/23370991-text.pt"}
15
+ {"video": "webvid10m/train/022001_022050/1023013066.mp4", "text": "a bird's eye view of a beachfront city, highlighting the hotels, pools, and proximity to the ocean", "latent": "webvid10m/train/022001_022050/1023013066-latent-384-10.pt", "text_fea": "text_feature/webvid10m/train/022001_022050/1023013066-text.pt"}
16
+ {"video": "webvid10m/train/004601_004650/1015979020.mp4", "text": "a bridge over a body of water, with a boat passing under it", "latent": "webvid10m/train/004601_004650/1015979020-latent-384-4.pt", "text_fea": "text_feature/webvid10m/train/004601_004650/1015979020-text.pt"}
17
+ {"video": "webvid10m/train/149701_149750/1034525579.mp4", "text": "a group of owls and a moon, with the moon appearing to grow larger as the video progresses", "latent": "webvid10m/train/149701_149750/1034525579-latent-384-2.pt", "text_fea": "text_feature/webvid10m/train/149701_149750/1034525579-text.pt"}
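The video annotations carry two extra fields: `latent` points to a pre-extracted causal-VAE latent and `text_fea` to pre-extracted text-encoder features, which is what `LengthGroupedVideoTextDataset` expects. A minimal sketch of inspecting one entry, assuming the referenced `.pt` files exist under your data root:

    import jsonlines
    import torch

    with jsonlines.open("annotation/video_text.jsonl") as reader:
        item = next(iter(reader))

    latent = torch.load(item["latent"], map_location="cpu")      # pre-extracted VAE latent
    text_fea = torch.load(item["text_fea"], map_location="cpu")  # pre-extracted text features
    print(item["text"])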
app.py ADDED
@@ -0,0 +1,356 @@
1
+ import os
2
+ import uuid
3
+ import gradio as gr
4
+ import torch
5
+ import PIL
6
+ from PIL import Image
7
+ from pyramid_dit import PyramidDiTForVideoGeneration
8
+ from diffusers.utils import export_to_video
9
+ from huggingface_hub import snapshot_download
10
+ import threading
11
+ import random
12
+
13
+ # Global model cache
14
+ model_cache = {}
15
+
16
+ # Lock to ensure thread-safe access to the model cache
17
+ model_cache_lock = threading.Lock()
18
+
19
+ # Configuration
20
+ model_name = "pyramid_flux" # or pyramid_mmdit
21
+ model_repo = "rain1011/pyramid-flow-sd3" if model_name == "pyramid_mmdit" else "rain1011/pyramid-flow-miniflux"
22
+
23
+ model_dtype = "bf16" # Support bf16 and fp32
24
+ variants = {
25
+ 'high': 'diffusion_transformer_768p', # For high-resolution version
26
+ 'low': 'diffusion_transformer_384p' # For low-resolution version
27
+ }
28
+ required_file = 'config.json' # Ensure config.json is present
29
+ width_high = 1280
30
+ height_high = 768
31
+ width_low = 640
32
+ height_low = 384
33
+ cpu_offloading = True # enable cpu_offloading by default
34
+
35
+ # Get the current working directory and create a folder to store the model
36
+ current_directory = os.getcwd()
37
+ model_path = os.path.join(current_directory, "pyramid_flow_model") # Directory to store the model
38
+
39
+ # Download the model if not already present
40
+ def download_model_from_hf(model_repo, model_dir, variants, required_file):
41
+ need_download = False
42
+ if not os.path.exists(model_dir):
43
+ print(f"[INFO] Model directory '{model_dir}' does not exist. Initiating download...")
44
+ need_download = True
45
+ else:
46
+ # Check if all required files exist for each variant
47
+ for variant_key, variant_dir in variants.items():
48
+ variant_path = os.path.join(model_dir, variant_dir)
49
+ file_path = os.path.join(variant_path, required_file)
50
+ if not os.path.exists(file_path):
51
+ print(f"[WARNING] Required file '{required_file}' missing in '{variant_path}'.")
52
+ need_download = True
53
+ break
54
+
55
+ if need_download:
56
+ print(f"[INFO] Downloading model from '{model_repo}' to '{model_dir}'...")
57
+ try:
58
+ snapshot_download(
59
+ repo_id=model_repo,
60
+ local_dir=model_dir,
61
+ local_dir_use_symlinks=False,
62
+ repo_type='model'
63
+ )
64
+ print("[INFO] Model download complete.")
65
+ except Exception as e:
66
+ print(f"[ERROR] Failed to download the model: {e}")
67
+ raise
68
+ else:
69
+ print(f"[INFO] All required model files are present in '{model_dir}'. Skipping download.")
70
+
71
+ # Download model from Hugging Face if not present
72
+ download_model_from_hf(model_repo, model_path, variants, required_file)
73
+
74
+ # Function to initialize the model based on user options
75
+ def initialize_model(variant):
76
+ print(f"[INFO] Initializing model with variant='{variant}', using bf16 precision...")
77
+
78
+ # Determine the correct variant directory
79
+ variant_dir = variants['high'] if variant == '768p' else variants['low']
80
+ base_path = model_path # Pass the base model path
81
+
82
+ print(f"[DEBUG] Model base path: {base_path}")
83
+
84
+ # Verify that config.json exists in the variant directory
85
+ config_path = os.path.join(model_path, variant_dir, 'config.json')
86
+ if not os.path.exists(config_path):
87
+ print(f"[ERROR] config.json not found in '{os.path.join(model_path, variant_dir)}'.")
88
+ raise FileNotFoundError(f"config.json not found in '{os.path.join(model_path, variant_dir)}'.")
89
+
90
+ if model_dtype == "bf16":
91
+ torch_dtype_selected = torch.bfloat16
92
+ else:
93
+ torch_dtype_selected = torch.float32
94
+
95
+ # Initialize the model
96
+ try:
97
+
98
+ model = PyramidDiTForVideoGeneration(
99
+ base_path, # Pass the base model path
100
+ model_name=model_name, # set to pyramid_flux or pyramid_mmdit
101
+ model_dtype=model_dtype, # Use bf16
102
+ model_variant=variant_dir, # Pass the variant directory name
103
+ cpu_offloading=cpu_offloading, # Pass the CPU offloading flag
104
+ )
105
+
106
+ # Always enable tiling for the VAE
107
+ model.vae.enable_tiling()
108
+
109
+ # Remove manual device placement when using CPU offloading
110
+ # The components will be moved to the appropriate devices automatically
111
+ if torch.cuda.is_available():
112
+ torch.cuda.set_device(0)
113
+ # Manually place modules on the GPU when not using CPU offloading
114
+ if not cpu_offloading:
115
+ model.vae.to("cuda")
116
+ model.dit.to("cuda")
117
+ model.text_encoder.to("cuda")
118
+ else:
119
+ print("[WARNING] CUDA is not available. Proceeding without GPU.")
120
+
121
+ print("[INFO] Model initialized successfully.")
122
+ return model, torch_dtype_selected
123
+ except Exception as e:
124
+ print(f"[ERROR] Error initializing model: {e}")
125
+ raise
126
+
127
+ # Function to get the model from cache or initialize it
128
+ def initialize_model_cached(variant, seed):
129
+ key = variant
130
+
131
+ if seed == 0:
132
+ seed = random.randint(0, 2**8 - 1)
133
+ torch.manual_seed(seed)
134
+ if torch.cuda.is_available():
135
+ torch.cuda.manual_seed(seed)
136
+ torch.cuda.manual_seed_all(seed)
137
+
138
+ # Check if the model is already in the cache
139
+ if key not in model_cache:
140
+ with model_cache_lock:
141
+ # Double-checked locking to prevent race conditions
142
+ if key not in model_cache:
143
+ model, dtype = initialize_model(variant)
144
+ model_cache[key] = (model, dtype)
145
+
146
+ return model_cache[key]
147
+
148
+ def resize_crop_image(img: PIL.Image.Image, tgt_width, tgt_height):
149
+ ori_width, ori_height = img.width, img.height
150
+ scale = max(tgt_width / ori_width, tgt_height / ori_height)
151
+ resized_width = round(ori_width * scale)
152
+ resized_height = round(ori_height * scale)
153
+ img = img.resize((resized_width, resized_height), resample=PIL.Image.LANCZOS)
154
+
155
+ left = (resized_width - tgt_width) / 2
156
+ top = (resized_height - tgt_height) / 2
157
+ right = (resized_width + tgt_width) / 2
158
+ bottom = (resized_height + tgt_height) / 2
159
+
160
+ # Crop the center of the image
161
+ img = img.crop((left, top, right, bottom))
162
+
163
+ return img
164
+
165
+ # Function to generate text-to-video
166
+ def generate_text_to_video(prompt, temp, guidance_scale, video_guidance_scale, resolution, seed, progress=gr.Progress()):
167
+ progress(0, desc="Loading model")
168
+ print("[DEBUG] generate_text_to_video called.")
169
+ variant = '768p' if resolution == "768p" else '384p'
170
+ height = height_high if resolution == "768p" else height_low
171
+ width = width_high if resolution == "768p" else width_low
172
+
173
+ def progress_callback(i, m):
174
+ progress(i/m)
175
+
176
+ # Initialize model based on user options using cached function
177
+ try:
178
+ model, torch_dtype_selected = initialize_model_cached(variant, seed)
179
+ except Exception as e:
180
+ print(f"[ERROR] Model initialization failed: {e}")
181
+ return f"Model initialization failed: {e}"
182
+
183
+ try:
184
+ print("[INFO] Starting text-to-video generation...")
185
+ with torch.no_grad(), torch.autocast('cuda', dtype=torch_dtype_selected):
186
+ frames = model.generate(
187
+ prompt=prompt,
188
+ num_inference_steps=[20, 20, 20],
189
+ video_num_inference_steps=[10, 10, 10],
190
+ height=height,
191
+ width=width,
192
+ temp=temp,
193
+ guidance_scale=guidance_scale,
194
+ video_guidance_scale=video_guidance_scale,
195
+ output_type="pil",
196
+ cpu_offloading=cpu_offloading,
197
+ save_memory=True,
198
+ callback=progress_callback,
199
+ )
200
+ print("[INFO] Text-to-video generation completed.")
201
+ except Exception as e:
202
+ print(f"[ERROR] Error during text-to-video generation: {e}")
203
+ return f"Error during video generation: {e}"
204
+
205
+ video_path = f"{str(uuid.uuid4())}_text_to_video_sample.mp4"
206
+ try:
207
+ export_to_video(frames, video_path, fps=24)
208
+ print(f"[INFO] Video exported to {video_path}.")
209
+ except Exception as e:
210
+ print(f"[ERROR] Error exporting video: {e}")
211
+ return f"Error exporting video: {e}"
212
+ return video_path
213
+
214
+ # Function to generate image-to-video
215
+ def generate_image_to_video(image, prompt, temp, video_guidance_scale, resolution, seed, progress=gr.Progress()):
216
+ progress(0, desc="Loading model")
217
+ print("[DEBUG] generate_image_to_video called.")
218
+ variant = '768p' if resolution == "768p" else '384p'
219
+ height = height_high if resolution == "768p" else height_low
220
+ width = width_high if resolution == "768p" else width_low
221
+
222
+ try:
223
+ image = resize_crop_image(image, width, height)
224
+ print("[INFO] Image resized and cropped successfully.")
225
+ except Exception as e:
226
+ print(f"[ERROR] Error processing image: {e}")
227
+ return f"Error processing image: {e}"
228
+
229
+ def progress_callback(i, m):
230
+ progress(i/m)
231
+
232
+ # Initialize model based on user options using cached function
233
+ try:
234
+ model, torch_dtype_selected = initialize_model_cached(variant, seed)
235
+ except Exception as e:
236
+ print(f"[ERROR] Model initialization failed: {e}")
237
+ return f"Model initialization failed: {e}"
238
+
239
+ try:
240
+ print("[INFO] Starting image-to-video generation...")
241
+ with torch.no_grad(), torch.autocast('cuda', dtype=torch_dtype_selected):
242
+ frames = model.generate_i2v(
243
+ prompt=prompt,
244
+ input_image=image,
245
+ num_inference_steps=[10, 10, 10],
246
+ temp=temp,
247
+ video_guidance_scale=video_guidance_scale,
248
+ output_type="pil",
249
+ cpu_offloading=cpu_offloading,
250
+ save_memory=True,
251
+ callback=progress_callback,
252
+ )
253
+ print("[INFO] Image-to-video generation completed.")
254
+ except Exception as e:
255
+ print(f"[ERROR] Error during image-to-video generation: {e}")
256
+ return f"Error during video generation: {e}"
257
+
258
+ video_path = f"{str(uuid.uuid4())}_image_to_video_sample.mp4"
259
+ try:
260
+ export_to_video(frames, video_path, fps=24)
261
+ print(f"[INFO] Video exported to {video_path}.")
262
+ except Exception as e:
263
+ print(f"[ERROR] Error exporting video: {e}")
264
+ return f"Error exporting video: {e}"
265
+ return video_path
266
+
267
+ def update_slider(resolution):
268
+ if resolution == "768p":
269
+ return [gr.update(maximum=31), gr.update(maximum=31)]
270
+ else:
271
+ return [gr.update(maximum=16), gr.update(maximum=16)]
272
+
273
+ # Gradio interface
274
+ with gr.Blocks() as demo:
275
+ gr.Markdown(
276
+ """
277
+ # Pyramid Flow Video Generation Demo
278
+
279
+ Pyramid Flow is a training-efficient **Autoregressive Video Generation** model based on **Flow Matching**. It is trained only on open-source datasets within 20.7k A100 GPU hours.
280
+
281
+ [[Paper]](https://arxiv.org/abs/2410.05954) [[Project Page]](https://pyramid-flow.github.io) [[Code]](https://github.com/jy0205/Pyramid-Flow) [[Model]](https://huggingface.co/rain1011/pyramid-flow-sd3)
282
+ """
283
+ )
284
+
285
+ # Shared settings
286
+ with gr.Row():
287
+ resolution_dropdown = gr.Dropdown(
288
+ choices=["768p", "384p"],
289
+ value="384p",
290
+ label="Model Resolution"
291
+ )
292
+
293
+ with gr.Tab("Text-to-Video"):
294
+ with gr.Row():
295
+ with gr.Column():
296
+ text_prompt = gr.Textbox(label="Prompt (Less than 128 words)", placeholder="Enter a text prompt for the video", lines=2)
297
+ temp_slider = gr.Slider(1, 16, value=16, step=1, label="Duration")
298
+ guidance_scale_slider = gr.Slider(1.0, 15.0, value=9.0, step=0.1, label="Guidance Scale")
299
+ video_guidance_scale_slider = gr.Slider(1.0, 10.0, value=5.0, step=0.1, label="Video Guidance Scale")
300
+ text_seed = gr.Number(label="Inference Seed (Enter a positive number, 0 for random)", value=0)
301
+ txt_generate = gr.Button("Generate Video")
302
+ with gr.Column():
303
+ txt_output = gr.Video(label="Generated Video")
304
+ gr.Examples(
305
+ examples=[
306
+ ["A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors", 16, 7.0, 5.0, "384p"],
307
+ ["Beautiful, snowy Tokyo city is bustling. The camera moves through the bustling city street, following several people enjoying the beautiful snowy weather and shopping at nearby stalls. Gorgeous sakura petals are flying through the wind along with snowflakes", 16, 7.0, 5.0, "384p"],
308
+ # ["Extreme close-up of chicken and green pepper kebabs grilling on a barbeque with flames. Shallow focus and light smoke. vivid colours", 31, 9.0, 5.0, "768p"],
309
+ ],
310
+ inputs=[text_prompt, temp_slider, guidance_scale_slider, video_guidance_scale_slider, resolution_dropdown, text_seed],
311
+ outputs=[txt_output],
312
+ fn=generate_text_to_video,
313
+ cache_examples='lazy',
314
+ )
315
+
316
+ with gr.Tab("Image-to-Video"):
317
+ with gr.Row():
318
+ with gr.Column():
319
+ image_input = gr.Image(type="pil", label="Input Image")
320
+ image_prompt = gr.Textbox(label="Prompt (Less than 128 words)", placeholder="Enter a text prompt for the video", lines=2)
321
+ image_temp_slider = gr.Slider(2, 16, value=16, step=1, label="Duration")
322
+ image_video_guidance_scale_slider = gr.Slider(1.0, 7.0, value=4.0, step=0.1, label="Video Guidance Scale")
323
+ image_seed = gr.Number(label="Inference Seed (Enter a positive number, 0 for random)", value=0)
324
+ img_generate = gr.Button("Generate Video")
325
+ with gr.Column():
326
+ img_output = gr.Video(label="Generated Video")
327
+ gr.Examples(
328
+ examples=[
329
+ ['assets/the_great_wall.jpg', 'FPV flying over the Great Wall', 16, 4.0, "384p"]
330
+ ],
331
+ inputs=[image_input, image_prompt, image_temp_slider, image_video_guidance_scale_slider, resolution_dropdown, image_seed],
332
+ outputs=[img_output],
333
+ fn=generate_image_to_video,
334
+ cache_examples='lazy',
335
+ )
336
+
337
+ # Update generate functions to include resolution options
338
+ txt_generate.click(
339
+ generate_text_to_video,
340
+ inputs=[text_prompt, temp_slider, guidance_scale_slider, video_guidance_scale_slider, resolution_dropdown, text_seed],
341
+ outputs=txt_output
342
+ )
343
+
344
+ img_generate.click(
345
+ generate_image_to_video,
346
+ inputs=[image_input, image_prompt, image_temp_slider, image_video_guidance_scale_slider, resolution_dropdown, image_seed],
347
+ outputs=img_output
348
+ )
349
+ resolution_dropdown.change(
350
+ fn=update_slider,
351
+ inputs=resolution_dropdown,
352
+ outputs=[temp_slider, image_temp_slider]
353
+ )
354
+
355
+ # Launch Gradio app
356
+ demo.launch(share=False)
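The same pipeline that `generate_text_to_video` wraps can also be driven without Gradio. A minimal sketch mirroring the calls above, assuming the checkpoint has been downloaded to `./pyramid_flow_model` and a CUDA GPU is available; the prompt is only a sample, and the generation arguments are the same ones `app.py` passes:

    import torch
    from pyramid_dit import PyramidDiTForVideoGeneration
    from diffusers.utils import export_to_video

    model = PyramidDiTForVideoGeneration(
        "./pyramid_flow_model",
        model_name="pyramid_flux",
        model_dtype="bf16",
        model_variant="diffusion_transformer_384p",
        cpu_offloading=True,
    )
    model.vae.enable_tiling()

    with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
        frames = model.generate(
            prompt="A hot air balloon drifting over snowy mountains at sunrise",
            num_inference_steps=[20, 20, 20],
            video_num_inference_steps=[10, 10, 10],
            height=384, width=640, temp=16,
            guidance_scale=7.0, video_guidance_scale=5.0,
            output_type="pil", cpu_offloading=True, save_memory=True,
        )
    export_to_video(frames, "text_to_video_sample.mp4", fps=24)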
app_multigpu.py ADDED
@@ -0,0 +1,143 @@
1
+ import os
2
+ import uuid
3
+ import gradio as gr
4
+ import subprocess
5
+ import tempfile
6
+ import shutil
7
+
8
+ def run_inference_multigpu(gpus, variant, model_path, temp, guidance_scale, video_guidance_scale, resolution, prompt):
9
+ """
10
+ Runs the external multi-GPU inference script and returns the path to the generated video.
11
+ """
12
+ # Create a temporary directory to store inputs and outputs
13
+ with tempfile.TemporaryDirectory() as tmpdir:
14
+ output_video = os.path.join(tmpdir, f"{uuid.uuid4()}_output.mp4")
15
+
16
+ # Path to the external shell script
17
+ script_path = "./scripts/app_multigpu_engine.sh" # Updated script path
18
+
19
+ # Prepare the command
20
+ cmd = [
21
+ script_path,
22
+ str(gpus),
23
+ variant,
24
+ model_path,
25
+ 't2v', # Task is always 't2v' since 'i2v' is removed
26
+ str(temp),
27
+ str(guidance_scale),
28
+ str(video_guidance_scale),
29
+ resolution,
30
+ output_video,
31
+ prompt # Pass the prompt directly as an argument
32
+ ]
33
+
34
+ try:
35
+ # Run the external script
36
+ subprocess.run(cmd, check=True)
37
+ except subprocess.CalledProcessError as e:
38
+ raise RuntimeError(f"Error during video generation: {e}")
39
+
40
+ # After generation, move the video to a permanent location
41
+ final_output = os.path.join("generated_videos", f"{uuid.uuid4()}_output.mp4")
42
+ os.makedirs("generated_videos", exist_ok=True)
43
+ shutil.move(output_video, final_output)
44
+
45
+ return final_output
46
+
47
+ def generate_text_to_video(prompt, temp, guidance_scale, video_guidance_scale, resolution, gpus):
48
+ model_path = "./pyramid_flow_model" # Use the model path as specified
49
+ # Determine variant based on resolution
50
+ if resolution == "768p":
51
+ variant = "diffusion_transformer_768p"
52
+ else:
53
+ variant = "diffusion_transformer_384p"
54
+ return run_inference_multigpu(gpus, variant, model_path, temp, guidance_scale, video_guidance_scale, resolution, prompt)
55
+
56
+ # Gradio interface
57
+ with gr.Blocks() as demo:
58
+ gr.Markdown(
59
+ """
60
+ # Pyramid Flow Video Generation Demo
61
+
62
+ Pyramid Flow is a training-efficient **Autoregressive Video Generation** model based on **Flow Matching**. It is trained only on open-source datasets within 20.7k A100 GPU hours.
63
+
64
+ [[Paper]](https://arxiv.org/abs/2410.05954) [[Project Page]](https://pyramid-flow.github.io) [[Code]](https://github.com/jy0205/Pyramid-Flow) [[Model]](https://huggingface.co/rain1011/pyramid-flow-sd3)
65
+ """
66
+ )
67
+
68
+ # Shared settings
69
+ with gr.Row():
70
+ gpus_dropdown = gr.Dropdown(
71
+ choices=[2, 4],
72
+ value=4,
73
+ label="Number of GPUs"
74
+ )
75
+ resolution_dropdown = gr.Dropdown(
76
+ choices=["768p", "384p"],
77
+ value="768p",
78
+ label="Model Resolution"
79
+ )
80
+
81
+ with gr.Tab("Text-to-Video"):
82
+ with gr.Row():
83
+ with gr.Column():
84
+ text_prompt = gr.Textbox(
85
+ label="Prompt (Less than 128 words)",
86
+ placeholder="Enter a text prompt for the video",
87
+ lines=2
88
+ )
89
+ temp_slider = gr.Slider(1, 31, value=16, step=1, label="Duration")
90
+ guidance_scale_slider = gr.Slider(1.0, 15.0, value=9.0, step=0.1, label="Guidance Scale")
91
+ video_guidance_scale_slider = gr.Slider(1.0, 10.0, value=5.0, step=0.1, label="Video Guidance Scale")
92
+ txt_generate = gr.Button("Generate Video")
93
+ with gr.Column():
94
+ txt_output = gr.Video(label="Generated Video")
95
+ gr.Examples(
96
+ examples=[
97
+ [
98
+ "A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors",
99
+ 16,
100
+ 9.0,
101
+ 5.0,
102
+ "768p",
103
+ 4
104
+ ],
105
+ [
106
+ "Beautiful, snowy Tokyo city is bustling. The camera moves through the bustling city street, following several people enjoying the beautiful snowy weather and shopping at nearby stalls. Gorgeous sakura petals are flying through the wind along with snowflakes",
107
+ 16,
108
+ 9.0,
109
+ 5.0,
110
+ "768p",
111
+ 4
112
+ ],
113
+ [
114
+ "Extreme close-up of chicken and green pepper kebabs grilling on a barbeque with flames. Shallow focus and light smoke. vivid colours",
115
+ 31,
116
+ 9.0,
117
+ 5.0,
118
+ "768p",
119
+ 4
120
+ ],
121
+ ],
122
+ inputs=[text_prompt, temp_slider, guidance_scale_slider, video_guidance_scale_slider, resolution_dropdown, gpus_dropdown],
123
+ outputs=[txt_output],
124
+ fn=generate_text_to_video,
125
+ cache_examples='lazy',
126
+ )
127
+
128
+ # Update generate function for Text-to-Video
129
+ txt_generate.click(
130
+ generate_text_to_video,
131
+ inputs=[
132
+ text_prompt,
133
+ temp_slider,
134
+ guidance_scale_slider,
135
+ video_guidance_scale_slider,
136
+ resolution_dropdown,
137
+ gpus_dropdown
138
+ ],
139
+ outputs=txt_output
140
+ )
141
+
142
+ # Launch Gradio app
143
+ demo.launch(share=False)
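Outside Gradio, the multi-GPU path boils down to the subprocess call that `run_inference_multigpu` assembles. A minimal sketch of the equivalent invocation, assuming `./scripts/app_multigpu_engine.sh` is present in the repository and two GPUs are visible; the prompt and output path are placeholders:

    import subprocess

    cmd = [
        "./scripts/app_multigpu_engine.sh",
        "2",                            # number of GPUs
        "diffusion_transformer_768p",   # model variant
        "./pyramid_flow_model",         # model path
        "t2v",                          # task
        "16", "9.0", "5.0",             # temp, guidance_scale, video_guidance_scale
        "768p",                         # resolution
        "output.mp4",                   # output video path
        "A sailboat crossing a calm bay at dusk",  # prompt
    ]
    subprocess.run(cmd, check=True)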
assets/motivation.jpg ADDED
assets/the_great_wall.jpg ADDED
assets/user_study.jpg ADDED
assets/vbench.jpg ADDED
causal_video_vae_demo.ipynb ADDED
@@ -0,0 +1,221 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "import json\n",
11
+ "import cv2\n",
12
+ "import torch\n",
13
+ "import numpy as np\n",
14
+ "import PIL\n",
15
+ "from PIL import Image\n",
16
+ "from einops import rearrange\n",
17
+ "from video_vae import CausalVideoVAELossWrapper\n",
18
+ "from torchvision import transforms as pth_transforms\n",
19
+ "from torchvision.transforms.functional import InterpolationMode\n",
20
+ "from IPython.display import Image as ipython_image\n",
21
+ "from diffusers.utils import load_image, export_to_video, export_to_gif\n",
22
+ "from IPython.display import HTML"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "model_path = \"pyramid-flow-miniflux/causal_video_vae\" # The video-vae checkpoint dir\n",
32
+ "model_dtype = 'bf16'\n",
33
+ "\n",
34
+ "device_id = 3\n",
35
+ "torch.cuda.set_device(device_id)\n",
36
+ "\n",
37
+ "model = CausalVideoVAELossWrapper(\n",
38
+ " model_path,\n",
39
+ " model_dtype,\n",
40
+ " interpolate=False, \n",
41
+ " add_discriminator=False,\n",
42
+ ")\n",
43
+ "model = model.to(\"cuda\")\n",
44
+ "\n",
45
+ "if model_dtype == \"bf16\":\n",
46
+ " torch_dtype = torch.bfloat16 \n",
47
+ "elif model_dtype == \"fp16\":\n",
48
+ " torch_dtype = torch.float16\n",
49
+ "else:\n",
50
+ " torch_dtype = torch.float32\n",
51
+ "\n",
52
+ "def image_transform(images, resize_width, resize_height):\n",
53
+ " transform_list = pth_transforms.Compose([\n",
54
+ " pth_transforms.Resize((resize_height, resize_width), InterpolationMode.BICUBIC, antialias=True),\n",
55
+ " pth_transforms.ToTensor(),\n",
56
+ " pth_transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
57
+ " ])\n",
58
+ " return torch.stack([transform_list(image) for image in images])\n",
59
+ "\n",
60
+ "\n",
61
+ "def get_transform(width, height, new_width=None, new_height=None, resize=False,):\n",
62
+ " transform_list = []\n",
63
+ "\n",
64
+ " if resize:\n",
65
+ " if new_width is None:\n",
66
+ " new_width = width // 8 * 8\n",
67
+ " if new_height is None:\n",
68
+ " new_height = height // 8 * 8\n",
69
+ " transform_list.append(pth_transforms.Resize((new_height, new_width), InterpolationMode.BICUBIC, antialias=True))\n",
70
+ " \n",
71
+ " transform_list.extend([\n",
72
+ " pth_transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
73
+ " ])\n",
74
+ " transform_list = pth_transforms.Compose(transform_list)\n",
75
+ "\n",
76
+ " return transform_list\n",
77
+ "\n",
78
+ "\n",
79
+ "def load_video_and_transform(video_path, frame_number, new_width=None, new_height=None, max_frames=600, sample_fps=24, resize=False):\n",
80
+ " try:\n",
81
+ " video_capture = cv2.VideoCapture(video_path)\n",
82
+ " fps = video_capture.get(cv2.CAP_PROP_FPS)\n",
83
+ " frames = []\n",
84
+ " pil_frames = []\n",
85
+ " while True:\n",
86
+ " flag, frame = video_capture.read()\n",
87
+ " if not flag:\n",
88
+ " break\n",
89
+ " \n",
90
+ " pil_frames.append(np.ascontiguousarray(frame[:, :, ::-1]))\n",
91
+ " frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
92
+ " frame = torch.from_numpy(frame)\n",
93
+ " frame = frame.permute(2, 0, 1)\n",
94
+ " frames.append(frame)\n",
95
+ " if len(frames) >= max_frames:\n",
96
+ " break\n",
97
+ "\n",
98
+ " video_capture.release()\n",
99
+ " interval = max(int(fps / sample_fps), 1)\n",
100
+ " pil_frames = pil_frames[::interval][:frame_number]\n",
101
+ " frames = frames[::interval][:frame_number]\n",
102
+ " frames = torch.stack(frames).float() / 255\n",
103
+ " width = frames.shape[-1]\n",
104
+ " height = frames.shape[-2]\n",
105
+ " video_transform = get_transform(width, height, new_width, new_height, resize=resize)\n",
106
+ " frames = video_transform(frames)\n",
107
+ " pil_frames = [Image.fromarray(frame).convert(\"RGB\") for frame in pil_frames]\n",
108
+ "\n",
109
+ " if resize:\n",
110
+ " if new_width is None:\n",
111
+ " new_width = width // 32 * 32\n",
112
+ " if new_height is None:\n",
113
+ " new_height = height // 32 * 32\n",
114
+ " pil_frames = [frame.resize((new_width or width, new_height or height), PIL.Image.BICUBIC) for frame in pil_frames]\n",
115
+ " return frames, pil_frames\n",
116
+ " except Exception:\n",
117
+ " return None\n",
118
+ "\n",
119
+ "\n",
120
+ "def show_video(ori_path, rec_path, width=\"100%\"):\n",
121
+ " html = ''\n",
122
+ " if ori_path is not None:\n",
123
+ " html += f\"\"\"<video controls=\"\" name=\"media\" data-fullscreen-container=\"true\" width=\"{width}\">\n",
124
+ " <source src=\"{ori_path}\" type=\"video/mp4\">\n",
125
+ " </video>\n",
126
+ " \"\"\"\n",
127
+ " \n",
128
+ " html += f\"\"\"<video controls=\"\" name=\"media\" data-fullscreen-container=\"true\" width=\"{width}\">\n",
129
+ " <source src=\"{rec_path}\" type=\"video/mp4\">\n",
130
+ " </video>\n",
131
+ " \"\"\"\n",
132
+ " return HTML(html)"
133
+ ]
134
+ },
135
+ {
136
+ "attachments": {},
137
+ "cell_type": "markdown",
138
+ "metadata": {},
139
+ "source": [
140
+ "### Image Reconstruction"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": null,
146
+ "metadata": {},
147
+ "outputs": [],
148
+ "source": [
149
+ "image_path = 'image_path'\n",
150
+ "\n",
151
+ "image = Image.open(image_path).convert(\"RGB\")\n",
152
+ "resize_width = image.width // 8 * 8\n",
153
+ "resize_height = image.height // 8 * 8\n",
154
+ "input_image_tensor = image_transform([image], resize_width, resize_height)\n",
155
+ "input_image_tensor = input_image_tensor.permute(1, 0, 2, 3).unsqueeze(0)\n",
156
+ "\n",
157
+ "with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):\n",
158
+ " latent = model.encode_latent(input_image_tensor.to(\"cuda\"), sample=True)\n",
159
+ " rec_images = model.decode_latent(latent)\n",
160
+ "\n",
161
+ "display(image)\n",
162
+ "display(rec_images[0])"
163
+ ]
164
+ },
165
+ {
166
+ "attachments": {},
167
+ "cell_type": "markdown",
168
+ "metadata": {},
169
+ "source": [
170
+ "### Video Reconstruction"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "code",
175
+ "execution_count": null,
176
+ "metadata": {},
177
+ "outputs": [],
178
+ "source": [
179
+ "video_path = 'video_path'\n",
180
+ "\n",
181
+ "frame_number = 57 # x*8 + 1\n",
182
+ "width = 640\n",
183
+ "height = 384\n",
184
+ "\n",
185
+ "video_frames_tensor, pil_video_frames = load_video_and_transform(video_path, frame_number, new_width=width, new_height=height, resize=True)\n",
186
+ "video_frames_tensor = video_frames_tensor.permute(1, 0, 2, 3).unsqueeze(0)\n",
187
+ "print(video_frames_tensor.shape)\n",
188
+ "\n",
189
+ "with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):\n",
190
+ " latent = model.encode_latent(video_frames_tensor.to(\"cuda\"), sample=False, window_size=8, temporal_chunk=True)\n",
191
+ " rec_frames = model.decode_latent(latent.float(), window_size=2, temporal_chunk=True)\n",
192
+ "\n",
193
+ "export_to_video(pil_video_frames, './ori_video.mp4', fps=24)\n",
194
+ "export_to_video(rec_frames, \"./rec_video.mp4\", fps=24)\n",
195
+ "show_video('./ori_video.mp4', \"./rec_video.mp4\", \"60%\")"
196
+ ]
197
+ }
198
+ ],
199
+ "metadata": {
200
+ "kernelspec": {
201
+ "display_name": "Python 3",
202
+ "language": "python",
203
+ "name": "python3"
204
+ },
205
+ "language_info": {
206
+ "codemirror_mode": {
207
+ "name": "ipython",
208
+ "version": 3
209
+ },
210
+ "file_extension": ".py",
211
+ "mimetype": "text/x-python",
212
+ "name": "python",
213
+ "nbconvert_exporter": "python",
214
+ "pygments_lexer": "ipython3",
215
+ "version": "3.8.10"
216
+ },
217
+ "orig_nbformat": 4
218
+ },
219
+ "nbformat": 4,
220
+ "nbformat_minor": 2
221
+ }
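The notebook's `frame_number = 57 # x*8 + 1` comment reflects the causal VAE's temporal compression: the first frame is encoded on its own and every following 8 frames collapse into one latent frame, so valid clip lengths are 8*k + 1 (the same arithmetic as the `16 => (16 - 1) * 8 + 1 = 121` note in `dataset/dataset_cls.py`). A small helper, assuming that 8x temporal downsampling:

    def latent_frames(video_frames: int) -> int:
        """Number of VAE latent frames for a clip of `video_frames` frames (assumes 8x temporal downsampling)."""
        assert (video_frames - 1) % 8 == 0, "clip length must be 8*k + 1"
        return (video_frames - 1) // 8 + 1

    print(latent_frames(57))   # 8 latent frames
    print(latent_frames(121))  # 16 latent frames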
dataset/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ from .dataset_cls import (
2
+ ImageTextDataset,
3
+ LengthGroupedVideoTextDataset,
4
+ ImageDataset,
5
+ VideoDataset,
6
+ )
7
+
8
+ from .dataloaders import (
9
+ create_image_text_dataloaders,
10
+ create_length_grouped_video_text_dataloader,
11
+ create_mixed_dataloaders,
12
+ )
dataset/bucket_loader.py ADDED
@@ -0,0 +1,148 @@
1
+ import torch
2
+ import torchvision
3
+ import numpy as np
4
+ import math
5
+ import random
6
+ import time
7
+
8
+
9
+ class Bucketeer:
10
+ def __init__(
11
+ self, dataloader,
12
+ sizes=[(256, 256), (192, 384), (192, 320), (384, 192), (320, 192)],
13
+ is_infinite=True, epoch=0,
14
+ ):
15
+ # Ratios and Sizes : (w h)
16
+ self.sizes = sizes
17
+ self.batch_size = dataloader.batch_size
18
+ self._dataloader = dataloader
19
+ self.iterator = iter(dataloader)
20
+ self.sampler = dataloader.sampler
21
+ self.buckets = {s: [] for s in self.sizes}
22
+ self.is_infinite = is_infinite
23
+ self._epoch = epoch
24
+
25
+ def get_available_batch(self):
26
+ available_size = []
27
+ for b in self.buckets:
28
+ if len(self.buckets[b]) >= self.batch_size:
29
+ available_size.append(b)
30
+
31
+ if len(available_size) == 0:
32
+ return None
33
+ else:
34
+ b = random.choice(available_size)
35
+ batch = self.buckets[b][:self.batch_size]
36
+ self.buckets[b] = self.buckets[b][self.batch_size:]
37
+ return batch
38
+
39
+ def __next__(self):
40
+ batch = self.get_available_batch()
41
+ while batch is None:
42
+ try:
43
+ elements = next(self.iterator)
44
+ except StopIteration:
45
+ # To make the iterator infinite
46
+ if self.is_infinite:
47
+ self._epoch += 1
48
+ if hasattr(self._dataloader.sampler, "set_epoch"):
49
+ self._dataloader.sampler.set_epoch(self._epoch)
50
+ time.sleep(2) # Prevent possible deadlock during epoch transition
51
+ self.iterator = iter(self._dataloader)
52
+ elements = next(self.iterator)
53
+ else:
54
+ raise StopIteration
55
+
56
+ for dct in elements:
57
+ try:
58
+ img = dct['video']
59
+ size = (img.shape[-1], img.shape[-2])
60
+ self.buckets[size].append({**{'video': img}, **{k:dct[k] for k in dct if k != 'video'}})
61
+ except Exception as e:
62
+ continue
63
+
64
+ batch = self.get_available_batch()
65
+
66
+ out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]}
67
+ return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
68
+
69
+ def __iter__(self):
70
+ return self
71
+
72
+ def __len__(self):
73
+ return len(self.iterator)
74
+
75
+
76
+ class TemporalLengthBucketeer:
77
+ def __init__(
78
+ self, dataloader, max_frames=16, epoch=0,
79
+ ):
80
+ self.batch_size = dataloader.batch_size
81
+ self._dataloader = dataloader
82
+ self.iterator = iter(dataloader)
83
+ self.buckets = {temp: [] for temp in range(1, max_frames + 1)}
84
+ self._epoch = epoch
85
+
86
+ def get_available_batch(self):
87
+ available_size = []
88
+ for b in self.buckets:
89
+ if len(self.buckets[b]) >= self.batch_size:
90
+ available_size.append(b)
91
+
92
+ if len(available_size) == 0:
93
+ return None
94
+ else:
95
+ b = random.choice(available_size)
96
+ batch = self.buckets[b][:self.batch_size]
97
+ self.buckets[b] = self.buckets[b][self.batch_size:]
98
+ return batch
99
+
100
+ def __next__(self):
101
+ batch = self.get_available_batch()
102
+ while batch is None:
103
+ try:
104
+ elements = next(self.iterator)
105
+ except StopIteration:
106
+ # To make the iterator infinite
107
+ self._epoch += 1
108
+ if hasattr(self._dataloader.sampler, "set_epoch"):
109
+ self._dataloader.sampler.set_epoch(self._epoch)
110
+ time.sleep(2) # Prevent possible deadlock during epoch transition
111
+ self.iterator = iter(self._dataloader)
112
+ elements = next(self.iterator)
113
+
114
+ for dct in elements:
115
+ try:
116
+ video_latent = dct['video']
117
+ temp = video_latent.shape[2]
118
+ self.buckets[temp].append({**{'video': video_latent}, **{k:dct[k] for k in dct if k != 'video'}})
119
+ except Exception as e:
120
+ continue
121
+
122
+ batch = self.get_available_batch()
123
+
124
+ out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]}
125
+ out = {k: torch.cat(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
126
+
127
+ if 'prompt_embed' in out:
128
+ # Loading the pre-extracted textual features
129
+ prompt_embeds = out['prompt_embed'].clone()
130
+ del out['prompt_embed']
131
+ prompt_attention_mask = out['prompt_attention_mask'].clone()
132
+ del out['prompt_attention_mask']
133
+ pooled_prompt_embeds = out['pooled_prompt_embed'].clone()
134
+ del out['pooled_prompt_embed']
135
+
136
+ out['text'] = {
137
+ 'prompt_embeds' : prompt_embeds,
138
+ 'prompt_attention_mask': prompt_attention_mask,
139
+ 'pooled_prompt_embeds': pooled_prompt_embeds,
140
+ }
141
+
142
+ return out
143
+
144
+ def __iter__(self):
145
+ return self
146
+
147
+ def __len__(self):
148
+ return len(self.iterator)
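`Bucketeer` drains an ordinary `DataLoader` whose collate function returns the raw list of sample dicts, sorts each sample into a `(width, height)` bucket, and only yields a batch once a bucket holds `batch_size` samples, so every batch shares one resolution. A minimal sketch with a hypothetical toy dataset (the identity-style collate matches what `dataloaders.py` uses):

    import torch
    from torch.utils.data import DataLoader, Dataset
    from dataset.bucket_loader import Bucketeer

    class ToyDataset(Dataset):
        sizes = [(256, 256), (192, 384), (384, 192)]  # (width, height)

        def __len__(self):
            return 64

        def __getitem__(self, idx):
            w, h = self.sizes[idx % len(self.sizes)]
            return {"video": torch.randn(3, h, w), "text": f"sample {idx}", "identifier": "image"}

    loader = DataLoader(ToyDataset(), batch_size=4, collate_fn=lambda samples: samples)
    bucketed = Bucketeer(loader, sizes=ToyDataset.sizes, is_infinite=True)
    batch = next(bucketed)
    print(batch["video"].shape)  # e.g. torch.Size([4, 3, 256, 256]); all samples come from one bucket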
dataset/dataloaders.py ADDED
@@ -0,0 +1,190 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import time
5
+ import random
6
+ from typing import Iterable
7
+
8
+ from collections import OrderedDict
9
+ from PIL import Image
10
+ from torch.utils.data import Dataset, DataLoader, ConcatDataset, IterableDataset, DistributedSampler, RandomSampler
11
+ from torch.utils.data.dataloader import default_collate
12
+ from torchvision import transforms
13
+ from torchvision.transforms.functional import InterpolationMode
14
+ from torchvision.transforms import functional as F
15
+ from .bucket_loader import Bucketeer, TemporalLengthBucketeer
16
+
17
+
18
+ class IterLoader:
19
+ """
20
+ A wrapper that converts a DataLoader into an infinite iterator.
21
+
22
+ Modified from:
23
+ https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
24
+ """
25
+
26
+ def __init__(self, dataloader: DataLoader, use_distributed: bool = False, epoch: int = 0):
27
+ self._dataloader = dataloader
28
+ self.iter_loader = iter(self._dataloader)
29
+ self._use_distributed = use_distributed
30
+ self._epoch = epoch
31
+
32
+ @property
33
+ def epoch(self) -> int:
34
+ return self._epoch
35
+
36
+ def __next__(self):
37
+ try:
38
+ data = next(self.iter_loader)
39
+ except StopIteration:
40
+ self._epoch += 1
41
+ if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
42
+ self._dataloader.sampler.set_epoch(self._epoch)
43
+ time.sleep(2) # Prevent possible deadlock during epoch transition
44
+ self.iter_loader = iter(self._dataloader)
45
+ data = next(self.iter_loader)
46
+
47
+ return data
48
+
49
+ def __iter__(self):
50
+ return self
51
+
52
+ def __len__(self):
53
+ return len(self._dataloader)
54
+
55
+
56
+ def identity(x):
57
+ return x
58
+
59
+
60
+ def create_image_text_dataloaders(dataset, batch_size, num_workers,
61
+ multi_aspect_ratio=True, epoch=0, sizes=[(512, 512), (384, 640), (640, 384)],
62
+ use_distributed=True, world_size=None, rank=None,
63
+ ):
64
+ """
65
+ The dataset has already been split across the different ranks
66
+ """
67
+ if use_distributed:
68
+ assert world_size is not None
69
+ assert rank is not None
70
+ sampler = DistributedSampler(
71
+ dataset,
72
+ shuffle=True,
73
+ num_replicas=world_size,
74
+ rank=rank,
75
+ seed=epoch,
76
+ )
77
+ else:
78
+ sampler = RandomSampler(dataset)
79
+
80
+ dataloader = DataLoader(
81
+ dataset,
82
+ batch_size=batch_size,
83
+ num_workers=num_workers,
84
+ pin_memory=True,
85
+ sampler=sampler,
86
+ collate_fn=identity if multi_aspect_ratio else default_collate,
87
+ drop_last=True,
88
+ )
89
+
90
+ if multi_aspect_ratio:
91
+ dataloader_iterator = Bucketeer(
92
+ dataloader,
93
+ sizes=sizes,
94
+ is_infinite=True, epoch=epoch,
95
+ )
96
+ else:
97
+ dataloader_iterator = iter(dataloader)
98
+
99
+ # To make it infinite
100
+ loader = IterLoader(dataloader_iterator, use_distributed=False, epoch=epoch)
101
+
102
+ return loader
103
+
104
+
105
+ def create_length_grouped_video_text_dataloader(dataset, batch_size, num_workers, max_frames,
106
+ world_size=None, rank=None, epoch=0, use_distributed=False):
107
+ if use_distributed:
108
+ assert world_size is not None
109
+ assert rank is not None
110
+ sampler = DistributedSampler(
111
+ dataset,
112
+ shuffle=True,
113
+ num_replicas=world_size,
114
+ rank=rank,
115
+ seed=epoch,
116
+ )
117
+ else:
118
+ sampler = RandomSampler(dataset)
119
+
120
+ dataloader = DataLoader(
121
+ dataset,
122
+ batch_size=batch_size,
123
+ num_workers=num_workers,
124
+ pin_memory=True,
125
+ sampler=sampler,
126
+ collate_fn=identity,
127
+ drop_last=True,
128
+ )
129
+
130
+ # make it infinite
131
+ dataloader_iterator = TemporalLengthBucketeer(
132
+ dataloader,
133
+ max_frames=max_frames,
134
+ epoch=epoch,
135
+ )
136
+
137
+ return dataloader_iterator
138
+
139
+
140
+ def create_mixed_dataloaders(
141
+ dataset, batch_size, num_workers, world_size=None, rank=None, epoch=0,
142
+ image_mix_ratio=0.1, use_image_video_mixed_training=True,
143
+ ):
144
+ """
145
+ The video & image mixed training dataloader builder
146
+ """
147
+
148
+ assert world_size is not None
149
+ assert rank is not None
150
+
151
+ image_gpus = max(1, int(world_size * image_mix_ratio))
152
+ if use_image_video_mixed_training:
153
+ video_gpus = world_size - image_gpus
154
+ else:
155
+ # only use video data
156
+ video_gpus = world_size
157
+ image_gpus = 0
158
+
159
+ print(f"{image_gpus} gpus for image, {video_gpus} gpus for video")
160
+
161
+ if rank < video_gpus:
162
+ sampler = DistributedSampler(
163
+ dataset,
164
+ shuffle=True,
165
+ num_replicas=video_gpus,
166
+ rank=rank,
167
+ seed=epoch,
168
+ )
169
+ else:
170
+ sampler = DistributedSampler(
171
+ dataset,
172
+ shuffle=True,
173
+ num_replicas=image_gpus,
174
+ rank=rank - video_gpus,
175
+ seed=epoch,
176
+ )
177
+
178
+ loader = DataLoader(
179
+ dataset,
180
+ batch_size=batch_size,
181
+ num_workers=num_workers,
182
+ pin_memory=True,
183
+ sampler=sampler,
184
+ collate_fn=default_collate,
185
+ drop_last=True,
186
+ )
187
+
188
+ # To make it infinite
189
+ loader = IterLoader(loader, use_distributed=True, epoch=epoch)
190
+ return loader
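Putting the pieces together, `create_image_text_dataloaders` wraps `ImageTextDataset` in the aspect-ratio `Bucketeer` and then in `IterLoader`, so `next()` never raises `StopIteration`. A minimal single-GPU sketch (`use_distributed=False`), assuming the image paths in `annotation/image_text.jsonl` resolve locally:

    from dataset import ImageTextDataset, create_image_text_dataloaders

    sizes = [(512, 512), (384, 640), (640, 384)]
    dataset = ImageTextDataset(
        anno_file=["annotation/image_text.jsonl"],
        ratios=[1/1, 3/5, 5/3],
        sizes=sizes,
    )
    loader = create_image_text_dataloaders(
        dataset, batch_size=2, num_workers=2,
        multi_aspect_ratio=True, sizes=sizes, use_distributed=False,
    )
    batch = next(loader)  # infinite iterator; each batch comes from a single aspect-ratio bucket
    print(batch["video"].shape, batch["text"][:2])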
dataset/dataset_cls.py ADDED
@@ -0,0 +1,377 @@
1
+ import os
2
+ import json
3
+ import jsonlines
4
+ import torch
5
+ import math
6
+ import random
7
+ import cv2
8
+
9
+ from tqdm import tqdm
10
+ from collections import OrderedDict
11
+
12
+ from PIL import Image
13
+ from PIL import ImageFile
14
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
15
+
16
+ import numpy as np
17
+ import subprocess
18
+ from torch.utils.data import Dataset, DataLoader
19
+ from torchvision import transforms
20
+ from torchvision.transforms.functional import InterpolationMode
21
+ from torchvision.transforms import functional as F
22
+
23
+
24
+ class ImageTextDataset(Dataset):
25
+ """
26
+ Usage:
27
+ The dataset class for image-text pairs, used for image generation training
28
+ It supports multi-aspect ratio training
29
+ params:
30
+ anno_file: The annotation file list
31
+ add_normalize: whether to normalize the input image pixel to [-1, 1], default: True
32
+ ratios: The aspect ratios during training, format: width / height
33
+ sizes: The resolutions of the training images, format: (width, height)
34
+ """
35
+ def __init__(
36
+ self, anno_file, add_normalize=True,
37
+ ratios=[1/1, 3/5, 5/3],
38
+ sizes=[(1024, 1024), (768, 1280), (1280, 768)],
39
+ crop_mode='random', p_random_ratio=0.0,
40
+ ):
41
+ # Ratios and Sizes : (w h)
42
+ super().__init__()
43
+
44
+ self.image_annos = []
45
+ if not isinstance(anno_file, list):
46
+ anno_file = [anno_file]
47
+
48
+ for anno_file_ in anno_file:
49
+ print(f"Load image annotation files from {anno_file_}")
50
+ with jsonlines.open(anno_file_, 'r') as reader:
51
+ for item in reader:
52
+ self.image_annos.append(item)
53
+
54
+ print(f"Loaded {len(self.image_annos)} images in total")
55
+
56
+ transform_list = [
57
+ transforms.ToTensor(),
58
+ ]
59
+
60
+ if add_normalize:
61
+ transform_list.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
62
+
63
+ self.transform = transforms.Compose(transform_list)
64
+
65
+ print(f"Transform List is {transform_list}")
66
+
67
+ assert crop_mode in ['center', 'random']
68
+ self.crop_mode = crop_mode
69
+ self.ratios = ratios
70
+ self.sizes = sizes
71
+ self.p_random_ratio = p_random_ratio
72
+
73
+ def get_closest_size(self, x):
74
+ if self.p_random_ratio > 0 and np.random.rand() < self.p_random_ratio:
75
+ best_size_idx = np.random.randint(len(self.ratios))
76
+ else:
77
+ w, h = x.width, x.height
78
+ best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios])
79
+ return self.sizes[best_size_idx]
80
+
81
+ def get_resize_size(self, orig_size, tgt_size):
82
+ if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0:
83
+ alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
84
+ resize_size = max(alt_min, min(tgt_size))
85
+ else:
86
+ alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
87
+ resize_size = max(alt_max, max(tgt_size))
88
+ return resize_size
89
+
90
+ def __len__(self):
91
+ return len(self.image_annos)
92
+
93
+ def __getitem__(self, index):
94
+ image_anno = self.image_annos[index]
95
+
96
+ try:
97
+ img = Image.open(image_anno['image']).convert("RGB")
98
+ text = image_anno['text']
99
+
100
+ assert isinstance(text, str), "Text should be str"
101
+
102
+ size = self.get_closest_size(img)
103
+ resize_size = self.get_resize_size((img.width, img.height), size)
104
+
105
+ img = transforms.functional.resize(img, resize_size, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True)
106
+
107
+ if self.crop_mode == 'center':
108
+ img = transforms.functional.center_crop(img, (size[1], size[0]))
109
+ elif self.crop_mode == 'random':
110
+ img = transforms.RandomCrop((size[1], size[0]))(img)
111
+ else:
112
+ img = transforms.functional.center_crop(img, (size[1], size[0]))
113
+
114
+ image_tensor = self.transform(img)
115
+
116
+ return {
117
+ "video": image_tensor, # using keyname `video`, to be compatible with video
118
+ "text" : text,
119
+ "identifier": 'image',
120
+ }
121
+
122
+ except Exception as e:
123
+ print(f'Load Image Error with {e}')
124
+ return self.__getitem__(random.randint(0, self.__len__() - 1))
125
+
126
+
127
+ class LengthGroupedVideoTextDataset(Dataset):
128
+ """
129
+ Usage:
130
+ The dataset class for video-text pairs, used for video generation training
131
+ It groups the video with the same frames together
132
+ Now only supporting fixed resolution during training
133
+ params:
134
+ anno_file: The annotation file list
135
+ max_frames: The maximum temporal lengths (This is the vae latent temporal length) 16 => (16 - 1) * 8 + 1 = 121 frames
136
+ load_vae_latent: Loading the pre-extracted vae latents during training, we recommend to extract the latents in advance
137
+ to reduce the time cost per batch
138
+ load_text_fea: Loading the pre-extracted text features during training, we recommend to extract the prompt textual features
139
+ in advance, since the T5 encoder will cost many GPU memories
140
+ """
141
+
142
+ def __init__(self, anno_file, max_frames=16, resolution='384p', load_vae_latent=True, load_text_fea=True):
143
+ super().__init__()
144
+
145
+ self.video_annos = []
146
+ self.max_frames = max_frames
147
+ self.load_vae_latent = load_vae_latent
148
+ self.load_text_fea = load_text_fea
149
+ self.resolution = resolution
150
+
151
+ assert load_vae_latent, "Now only support loading vae latents, we will support to directly load video frames in the future"
152
+
153
+ if not isinstance(anno_file, list):
154
+ anno_file = [anno_file]
155
+
156
+ for anno_file_ in anno_file:
157
+ with jsonlines.open(anno_file_, 'r') as reader:
158
+ for item in tqdm(reader):
159
+ self.video_annos.append(item)
160
+
161
+ print(f"Totally Remained {len(self.video_annos)} videos")
162
+
163
+ def __len__(self):
164
+ return len(self.video_annos)
165
+
166
+ def __getitem__(self, index):
167
+ try:
168
+ video_anno = self.video_annos[index]
169
+ text = video_anno['text']
170
+ latent_path = video_anno['latent']
171
+ latent = torch.load(latent_path, map_location='cpu') # loading the pre-extracted video latents
172
+
173
+ # TODO: remove the hard code latent shape checking
174
+ if self.resolution == '384p':
175
+ assert latent.shape[-1] == 640 // 8
176
+ assert latent.shape[-2] == 384 // 8
177
+ else:
178
+ assert self.resolution == '768p'
179
+ assert latent.shape[-1] == 1280 // 8
180
+ assert latent.shape[-2] == 768 // 8
181
+
182
+ cur_temp = latent.shape[2]
183
+ cur_temp = min(cur_temp, self.max_frames)
184
+
185
+ video_latent = latent[:,:,:cur_temp].float()
186
+ assert video_latent.shape[1] == 16
187
+
188
+ if self.load_text_fea:
189
+ text_fea_path = video_anno['text_fea']
190
+ text_fea = torch.load(text_fea_path, map_location='cpu')
191
+ return {
192
+ 'video': video_latent,
193
+ 'prompt_embed': text_fea['prompt_embed'],
194
+ 'prompt_attention_mask': text_fea['prompt_attention_mask'],
195
+ 'pooled_prompt_embed': text_fea['pooled_prompt_embed'],
196
+ "identifier": 'video',
197
+ }
198
+
199
+ else:
200
+ return {
201
+ 'video': video_latent,
202
+ 'text': text,
203
+ "identifier": 'video',
204
+ }
205
+
206
+ except Exception as e:
207
+ print(f'Load Video Error with {e}')
208
+ return self.__getitem__(random.randint(0, self.__len__() - 1))
209
+
210
+
211
+ class VideoFrameProcessor:
212
+ # load a video and transform
213
+ def __init__(self, resolution=256, num_frames=24, add_normalize=True, sample_fps=24):
214
+
215
+ image_size = resolution
216
+
217
+ transform_list = [
218
+ transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
219
+ transforms.CenterCrop(image_size),
220
+ ]
221
+
222
+ if add_normalize:
223
+ transform_list.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
224
+
225
+ print(f"Transform List is {transform_list}")
226
+ self.num_frames = num_frames
227
+ self.transform = transforms.Compose(transform_list)
228
+ self.sample_fps = sample_fps
229
+
230
+ def __call__(self, video_path):
231
+ try:
232
+ video_capture = cv2.VideoCapture(video_path)
233
+ fps = video_capture.get(cv2.CAP_PROP_FPS)
234
+ frames = []
235
+
236
+ while True:
237
+ flag, frame = video_capture.read()
238
+ if not flag:
239
+ break
240
+
241
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
242
+ frame = torch.from_numpy(frame)
243
+ frame = frame.permute(2, 0, 1)
244
+ frames.append(frame)
245
+
246
+ video_capture.release()
247
+ sample_fps = self.sample_fps
248
+ interval = max(int(fps / sample_fps), 1)
249
+ frames = frames[::interval]
250
+
251
+ if len(frames) < self.num_frames:
252
+ num_frame_to_pack = self.num_frames - len(frames)
253
+ recurrent_num = num_frame_to_pack // len(frames)
254
+ frames = frames + recurrent_num * frames + frames[:(num_frame_to_pack % len(frames))]
255
+ assert len(frames) >= self.num_frames, f'{len(frames)}'
256
+
257
+ start_indexs = list(range(0, max(0, len(frames) - self.num_frames + 1)))
258
+ start_index = random.choice(start_indexs)
259
+
260
+ filtered_frames = frames[start_index : start_index+self.num_frames]
261
+ assert len(filtered_frames) == self.num_frames, f"The sampled frames should equals to {self.num_frames}"
262
+
263
+ filtered_frames = torch.stack(filtered_frames).float() / 255
264
+ filtered_frames = self.transform(filtered_frames)
265
+ filtered_frames = filtered_frames.permute(1, 0, 2, 3)
266
+
267
+ return filtered_frames, None
268
+
269
+ except Exception as e:
270
+ print(f"Load video: {video_path} Error, Exception {e}")
271
+ return None, None
272
+
273
+
274
+ class VideoDataset(Dataset):
275
+ def __init__(self, anno_file, resolution=256, max_frames=6, add_normalize=True):
276
+ super().__init__()
277
+
278
+ self.video_annos = []
279
+ self.max_frames = max_frames
280
+
281
+ if not isinstance(anno_file, list):
282
+ anno_file = [anno_file]
283
+
284
+ print(f"The training video clip frame number is {max_frames} ")
285
+
286
+ for anno_file_ in anno_file:
287
+ print(f"Load annotation file from {anno_file_}")
288
+
289
+ with jsonlines.open(anno_file_, 'r') as reader:
290
+ for item in tqdm(reader):
291
+ self.video_annos.append(item)
292
+
293
+ print(f"Totally Remained {len(self.video_annos)} videos")
294
+
295
+ self.video_processor = VideoFrameProcessor(resolution, max_frames, add_normalize)
296
+
297
+ def __len__(self):
298
+ return len(self.video_annos)
299
+
300
+ def __getitem__(self, index):
301
+ video_anno = self.video_annos[index]
302
+ video_path = video_anno['video']
303
+
304
+ try:
305
+ video_tensors, video_frames = self.video_processor(video_path)
306
+
307
+ assert video_tensors.shape[1] == self.max_frames
308
+
309
+ return {
310
+ "video": video_tensors,
311
+ "identifier": 'video',
312
+ }
313
+
314
+ except Exception as e:
315
+ print('Loading Video Error with {e}')
316
+ return self.__getitem__(random.randint(0, self.__len__() - 1))
317
+
318
+
319
+ class ImageDataset(Dataset):
320
+ def __init__(self, anno_file, resolution=256, max_frames=8, add_normalize=True):
321
+ super().__init__()
322
+
323
+ self.image_annos = []
324
+ self.max_frames = max_frames
325
+ image_paths = []
326
+
327
+ if not isinstance(anno_file, list):
328
+ anno_file = [anno_file]
329
+
330
+ for anno_file_ in anno_file:
331
+ print(f"Load annotation file from {anno_file_}")
332
+ with jsonlines.open(anno_file_, 'r') as reader:
333
+ for item in tqdm(reader):
334
+ image_paths.append(item['image'])
335
+
336
+ print(f"Totally Remained {len(image_paths)} images")
337
+
338
+ # pack multiple frames
339
+ for idx in range(0, len(image_paths), self.max_frames):
340
+ image_path_shard = image_paths[idx : idx + self.max_frames]
341
+ if len(image_path_shard) < self.max_frames:
342
+ image_path_shard = image_path_shard + image_paths[:self.max_frames - len(image_path_shard)]
343
+ assert len(image_path_shard) == self.max_frames
344
+ self.image_annos.append(image_path_shard)
345
+
346
+ image_size = resolution
347
+ transform_list = [
348
+ transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
349
+ transforms.CenterCrop(image_size),
350
+ transforms.ToTensor(),
351
+ ]
352
+ if add_normalize:
353
+ transform_list.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
354
+
355
+ print(f"Transform List is {transform_list}")
356
+ self.transform = transforms.Compose(transform_list)
357
+
358
+ def __len__(self):
359
+ return len(self.image_annos)
360
+
361
+ def __getitem__(self, index):
362
+ image_paths = self.image_annos[index]
363
+
364
+ try:
365
+ packed_pil_frames = [Image.open(image_path).convert("RGB") for image_path in image_paths]
366
+ filtered_frames = [self.transform(frame) for frame in packed_pil_frames]
367
+ filtered_frames = torch.stack(filtered_frames) # [t, c, h, w]
368
+ filtered_frames = filtered_frames.permute(1, 0, 2, 3) # [c, t, h, w]
369
+
370
+ return {
371
+ "video": filtered_frames,
372
+ "identifier": 'image',
373
+ }
374
+
375
+ except Exception as e:
376
+ print(f'Load Images Error with {e}')
377
+ return self.__getitem__(random.randint(0, self.__len__() - 1))
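
The dataset classes above all read jsonlines annotation files and return dicts keyed by `video` / `text` / `identifier`, so they can be fed to a standard PyTorch `DataLoader`. A minimal usage sketch, assuming annotation files like `annotation/image_text.jsonl` and `annotation/video_text.jsonl`; the paths and batch sizes here are placeholders, not the repo's training settings:

```python
from torch.utils.data import DataLoader
from dataset.dataset_cls import ImageTextDataset, LengthGroupedVideoTextDataset

# Image-text pairs with multi-aspect-ratio bucketing
image_set = ImageTextDataset(
    anno_file='annotation/image_text.jsonl',
    ratios=[1/1, 3/5, 5/3],
    sizes=[(1024, 1024), (768, 1280), (1280, 768)],
    crop_mode='random',
)

# Video-text pairs built on pre-extracted vae latents and text features
video_set = LengthGroupedVideoTextDataset(
    anno_file='annotation/video_text.jsonl',
    max_frames=16,          # 16 latent frames => (16 - 1) * 8 + 1 = 121 video frames
    resolution='384p',
    load_text_fea=True,
)

# With real data in place, each image item carries a [3, H, W] tensor under the key
# "video", and each 384p video item carries a [16, T, 48, 80] latent tensor; samples
# with the same latent length should be batched together (see the class docstring).
image_loader = DataLoader(image_set, batch_size=4, num_workers=4, drop_last=True)
video_loader = DataLoader(video_set, batch_size=2, num_workers=4, drop_last=True)
```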
diffusion_schedulers/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .scheduling_cosine_ddpm import DDPMCosineScheduler
+ from .scheduling_flow_matching import PyramidFlowMatchEulerDiscreteScheduler
diffusion_schedulers/scheduling_cosine_ddpm.py ADDED
@@ -0,0 +1,137 @@
+ import math
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.utils import BaseOutput
+ from diffusers.utils.torch_utils import randn_tensor
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
+
+
+ @dataclass
+ class DDPMSchedulerOutput(BaseOutput):
+     """
+     Output class for the scheduler's step function output.
+
+     Args:
+         prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+             Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+             denoising loop.
+     """
+
+     prev_sample: torch.Tensor
+
+
+ class DDPMCosineScheduler(SchedulerMixin, ConfigMixin):
+
+     @register_to_config
+     def __init__(
+         self,
+         scaler: float = 1.0,
+         s: float = 0.008,
+     ):
+         self.scaler = scaler
+         self.s = torch.tensor([s])
+         self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2
+
+         # standard deviation of the initial noise distribution
+         self.init_noise_sigma = 1.0
+
+     def _alpha_cumprod(self, t, device):
+         if self.scaler > 1:
+             t = 1 - (1 - t) ** self.scaler
+         elif self.scaler < 1:
+             t = t**self.scaler
+         alpha_cumprod = torch.cos(
+             (t + self.s.to(device)) / (1 + self.s.to(device)) * torch.pi * 0.5
+         ) ** 2 / self._init_alpha_cumprod.to(device)
+         return alpha_cumprod.clamp(0.0001, 0.9999)
+
+     def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
+         """
+         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+         current timestep.
+
+         Args:
+             sample (`torch.Tensor`): input sample
+             timestep (`int`, optional): current timestep
+
+         Returns:
+             `torch.Tensor`: scaled input sample
+         """
+         return sample
+
+     def set_timesteps(
+         self,
+         num_inference_steps: int = None,
+         timesteps: Optional[List[int]] = None,
+         device: Union[str, torch.device] = None,
+     ):
+         """
+         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+         Args:
+             num_inference_steps (`int`):
+                 the number of diffusion steps used when generating samples with a pre-trained model. If passed, then
+                 `timesteps` must be `None`.
+             device (`str` or `torch.device`, optional):
+                 the device to which the timesteps are moved.
+         """
+         if timesteps is None:
+             timesteps = torch.linspace(1.0, 0.0, num_inference_steps + 1, device=device)
+         if not isinstance(timesteps, torch.Tensor):
+             timesteps = torch.Tensor(timesteps).to(device)
+         self.timesteps = timesteps
+
+     def step(
+         self,
+         model_output: torch.Tensor,
+         timestep: int,
+         sample: torch.Tensor,
+         generator=None,
+         return_dict: bool = True,
+     ) -> Union[DDPMSchedulerOutput, Tuple]:
+         dtype = model_output.dtype
+         device = model_output.device
+         t = timestep
+
+         prev_t = self.previous_timestep(t)
+
+         alpha_cumprod = self._alpha_cumprod(t, device).view(t.size(0), *[1 for _ in sample.shape[1:]])
+         alpha_cumprod_prev = self._alpha_cumprod(prev_t, device).view(prev_t.size(0), *[1 for _ in sample.shape[1:]])
+         alpha = alpha_cumprod / alpha_cumprod_prev
+
+         mu = (1.0 / alpha).sqrt() * (sample - (1 - alpha) * model_output / (1 - alpha_cumprod).sqrt())
+
+         std_noise = randn_tensor(mu.shape, generator=generator, device=model_output.device, dtype=model_output.dtype)
+         std = ((1 - alpha) * (1.0 - alpha_cumprod_prev) / (1.0 - alpha_cumprod)).sqrt() * std_noise
+         pred = mu + std * (prev_t != 0).float().view(prev_t.size(0), *[1 for _ in sample.shape[1:]])
+
+         if not return_dict:
+             return (pred.to(dtype),)
+
+         return DDPMSchedulerOutput(prev_sample=pred.to(dtype))
+
+     def add_noise(
+         self,
+         original_samples: torch.Tensor,
+         noise: torch.Tensor,
+         timesteps: torch.Tensor,
+     ) -> torch.Tensor:
+         device = original_samples.device
+         dtype = original_samples.dtype
+         alpha_cumprod = self._alpha_cumprod(timesteps, device=device).view(
+             timesteps.size(0), *[1 for _ in original_samples.shape[1:]]
+         )
+         noisy_samples = alpha_cumprod.sqrt() * original_samples + (1 - alpha_cumprod).sqrt() * noise
+         return noisy_samples.to(dtype=dtype)
+
+     def __len__(self):
+         return self.config.num_train_timesteps
+
+     def previous_timestep(self, timestep):
+         index = (self.timesteps - timestep[0]).abs().argmin().item()
+         prev_t = self.timesteps[index + 1][None].expand(timestep.shape[0])
+         return prev_t
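
A minimal sketch of driving `DDPMCosineScheduler` with its continuous timesteps in [1, 0]; the zero tensor below is a stand-in for a real epsilon-prediction model and the sample shape is arbitrary:

```python
import torch
from diffusion_schedulers import DDPMCosineScheduler

scheduler = DDPMCosineScheduler()
scheduler.set_timesteps(num_inference_steps=20)

sample = torch.randn(1, 4, 32, 32)              # start from pure noise
for t in scheduler.timesteps[:-1]:              # the final 0.0 entry is only used as prev_t
    t_batch = t[None].expand(sample.shape[0])   # step() expects one timestep per sample
    noise_pred = torch.zeros_like(sample)       # placeholder for denoiser(sample, t_batch)
    sample = scheduler.step(noise_pred, t_batch, sample).prev_sample
```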
diffusion_schedulers/scheduling_flow_matching.py ADDED
@@ -0,0 +1,297 @@
+ from dataclasses import dataclass
+ from typing import Optional, Tuple, Union, List
+ import math
+ import numpy as np
+ import torch
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.utils import BaseOutput, logging
+ from diffusers.utils.torch_utils import randn_tensor
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
+
+
+ @dataclass
+ class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
+     """
+     Output class for the scheduler's `step` function output.
+
+     Args:
+         prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+             Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+             denoising loop.
+     """
+
+     prev_sample: torch.FloatTensor
+
+
+ class PyramidFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
+     """
+     Euler scheduler.
+
+     This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+     methods the library implements for all schedulers such as loading and saving.
+
+     Args:
+         num_train_timesteps (`int`, defaults to 1000):
+             The number of diffusion steps to train the model.
+         timestep_spacing (`str`, defaults to `"linspace"`):
+             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+         shift (`float`, defaults to 1.0):
+             The shift value for the timestep schedule.
+     """
+
+     _compatibles = []
+     order = 1
+
+     @register_to_config
+     def __init__(
+         self,
+         num_train_timesteps: int = 1000,
+         shift: float = 1.0,  # Following Stable Diffusion 3
+         stages: int = 3,
+         stage_range: List = [0, 1/3, 2/3, 1],
+         gamma: float = 1/3,
+     ):
+
+         self.timestep_ratios = {}      # The timestep ratio for each stage
+         self.timesteps_per_stage = {}  # The detailed timesteps per stage
+         self.sigmas_per_stage = {}
+         self.start_sigmas = {}
+         self.end_sigmas = {}
+         self.ori_start_sigmas = {}
+
+         # self.init_sigmas()
+         self.init_sigmas_for_each_stage()
+         self.sigma_min = self.sigmas[-1].item()
+         self.sigma_max = self.sigmas[0].item()
+         self.gamma = gamma
+
+     def init_sigmas(self):
+         """
+         initialize the global timesteps and sigmas
+         """
+         num_train_timesteps = self.config.num_train_timesteps
+         shift = self.config.shift
+
+         timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
+         timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
+
+         sigmas = timesteps / num_train_timesteps
+         sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
+
+         self.timesteps = sigmas * num_train_timesteps
+
+         self._step_index = None
+         self._begin_index = None
+
+         self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
+
+     def init_sigmas_for_each_stage(self):
+         """
+         Init the timesteps for each stage
+         """
+         self.init_sigmas()
+
+         stage_distance = []
+         stages = self.config.stages
+         training_steps = self.config.num_train_timesteps
+         stage_range = self.config.stage_range
+
+         # Init the start and end point of each stage
+         for i_s in range(stages):
+             # To decide the start and end points
+             start_indice = int(stage_range[i_s] * training_steps)
+             start_indice = max(start_indice, 0)
+             end_indice = int(stage_range[i_s+1] * training_steps)
+             end_indice = min(end_indice, training_steps)
+             start_sigma = self.sigmas[start_indice].item()
+             end_sigma = self.sigmas[end_indice].item() if end_indice < training_steps else 0.0
+             self.ori_start_sigmas[i_s] = start_sigma
+
+             if i_s != 0:
+                 ori_sigma = 1 - start_sigma
+                 gamma = self.config.gamma
+                 corrected_sigma = (1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)) * ori_sigma
+                 # corrected_sigma = 1 / (2 - ori_sigma) * ori_sigma
+                 start_sigma = 1 - corrected_sigma
+
+             stage_distance.append(start_sigma - end_sigma)
+             self.start_sigmas[i_s] = start_sigma
+             self.end_sigmas[i_s] = end_sigma
+
+         # Determine the ratio of each stage according to flow length
+         tot_distance = sum(stage_distance)
+         for i_s in range(stages):
+             if i_s == 0:
+                 start_ratio = 0.0
+             else:
+                 start_ratio = sum(stage_distance[:i_s]) / tot_distance
+             if i_s == stages - 1:
+                 end_ratio = 1.0
+             else:
+                 end_ratio = sum(stage_distance[:i_s+1]) / tot_distance
+
+             self.timestep_ratios[i_s] = (start_ratio, end_ratio)
+
+         # Determine the timesteps and sigmas for each stage
+         for i_s in range(stages):
+             timestep_ratio = self.timestep_ratios[i_s]
+             timestep_max = self.timesteps[int(timestep_ratio[0] * training_steps)]
+             timestep_min = self.timesteps[min(int(timestep_ratio[1] * training_steps), training_steps - 1)]
+             timesteps = np.linspace(
+                 timestep_max, timestep_min, training_steps + 1,
+             )
+             self.timesteps_per_stage[i_s] = timesteps[:-1] if isinstance(timesteps, torch.Tensor) else torch.from_numpy(timesteps[:-1])
+             stage_sigmas = np.linspace(
+                 1, 0, training_steps + 1,
+             )
+             self.sigmas_per_stage[i_s] = torch.from_numpy(stage_sigmas[:-1])
+
+     @property
+     def step_index(self):
+         """
+         The index counter for the current timestep. It will increase by 1 after each scheduler step.
+         """
+         return self._step_index
+
+     @property
+     def begin_index(self):
+         """
+         The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
+         """
+         return self._begin_index
+
+     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+     def set_begin_index(self, begin_index: int = 0):
+         """
+         Sets the begin index for the scheduler. This function should be run from the pipeline before inference.
+
+         Args:
+             begin_index (`int`):
+                 The begin index for the scheduler.
+         """
+         self._begin_index = begin_index
+
+     def _sigma_to_t(self, sigma):
+         return sigma * self.config.num_train_timesteps
+
+     def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Union[str, torch.device] = None):
+         """
+         Setting the timesteps and sigmas for each stage
+         """
+         self.num_inference_steps = num_inference_steps
+         training_steps = self.config.num_train_timesteps
+         self.init_sigmas()
+
+         stage_timesteps = self.timesteps_per_stage[stage_index]
+         timestep_max = stage_timesteps[0].item()
+         timestep_min = stage_timesteps[-1].item()
+
+         timesteps = np.linspace(
+             timestep_max, timestep_min, num_inference_steps,
+         )
+         self.timesteps = torch.from_numpy(timesteps).to(device=device)
+
+         stage_sigmas = self.sigmas_per_stage[stage_index]
+         sigma_max = stage_sigmas[0].item()
+         sigma_min = stage_sigmas[-1].item()
+
+         ratios = np.linspace(
+             sigma_max, sigma_min, num_inference_steps
+         )
+         sigmas = torch.from_numpy(ratios).to(device=device)
+         self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+
+         self._step_index = None
+
+     def index_for_timestep(self, timestep, schedule_timesteps=None):
+         if schedule_timesteps is None:
+             schedule_timesteps = self.timesteps
+
+         indices = (schedule_timesteps == timestep).nonzero()
+
+         # The sigma index that is taken for the **very** first `step`
+         # is always the second index (or the last index if there is only 1)
+         # This way we can ensure we don't accidentally skip a sigma in
+         # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+         pos = 1 if len(indices) > 1 else 0
+
+         return indices[pos].item()
+
+     def _init_step_index(self, timestep):
+         if self.begin_index is None:
+             if isinstance(timestep, torch.Tensor):
+                 timestep = timestep.to(self.timesteps.device)
+             self._step_index = self.index_for_timestep(timestep)
+         else:
+             self._step_index = self._begin_index
+
+     def step(
+         self,
+         model_output: torch.FloatTensor,
+         timestep: Union[float, torch.FloatTensor],
+         sample: torch.FloatTensor,
+         generator: Optional[torch.Generator] = None,
+         return_dict: bool = True,
+     ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
+         """
+         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+         process from the learned model outputs (most often the predicted noise).
+
+         Args:
+             model_output (`torch.FloatTensor`):
+                 The direct output from the learned diffusion model.
+             timestep (`float`):
+                 The current discrete timestep in the diffusion chain.
+             sample (`torch.FloatTensor`):
+                 A current instance of a sample created by the diffusion process.
+             generator (`torch.Generator`, *optional*):
+                 A random number generator.
+             return_dict (`bool`):
+                 Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
+                 tuple.
+
+         Returns:
+             [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
+                 If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
+                 returned, otherwise a tuple is returned where the first element is the sample tensor.
+         """
+
+         if (
+             isinstance(timestep, int)
+             or isinstance(timestep, torch.IntTensor)
+             or isinstance(timestep, torch.LongTensor)
+         ):
+             raise ValueError(
+                 (
+                     "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+                     " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
+                     " one of the `scheduler.timesteps` as a timestep."
+                 ),
+             )
+
+         if self.step_index is None:
+             self._step_index = 0
+
+         # Upcast to avoid precision issues when computing prev_sample
+         sample = sample.to(torch.float32)
+
+         sigma = self.sigmas[self.step_index]
+         sigma_next = self.sigmas[self.step_index + 1]
+
+         prev_sample = sample + (sigma_next - sigma) * model_output
+
+         # Cast sample back to model compatible dtype
+         prev_sample = prev_sample.to(model_output.dtype)
+
+         # upon completion increase step index by one
+         self._step_index += 1
+
+         if not return_dict:
+             return (prev_sample,)
+
+         return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
+
+     def __len__(self):
+         return self.config.num_train_timesteps
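
A minimal sketch of how the staged schedule above might be consumed: `set_timesteps` is called once per pyramid stage with that stage's index, and `step` then walks the per-stage sigmas. The latent shape and the zero `v_pred` are placeholders for a real DiT forward pass, and the real pipeline does extra work between stages that this sketch skips:

```python
import torch
from diffusion_schedulers import PyramidFlowMatchEulerDiscreteScheduler

scheduler = PyramidFlowMatchEulerDiscreteScheduler(stages=3, gamma=1/3)
latent = torch.randn(1, 16, 8, 48, 80)             # illustrative latent shape

for stage in range(scheduler.config.stages):
    # per-stage timesteps and sigmas; the internal step index is reset here
    scheduler.set_timesteps(num_inference_steps=20, stage_index=stage)
    for t in scheduler.timesteps:
        v_pred = torch.zeros_like(latent)          # placeholder for dit(latent, t, ...)
        latent = scheduler.step(v_pred, t, latent).prev_sample
```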
docs/DiT.md ADDED
@@ -0,0 +1,54 @@
+ # Pyramid Flow's DiT Finetuning Guide
+ 
+ This is the finetuning guide for the DiT in Pyramid Flow. We provide instructions for both the autoregressive and non-autoregressive versions. The former is more research-oriented and the latter is more stable (but less efficient without the temporal pyramid). Please refer to [another document](https://github.com/jy0205/Pyramid-Flow/blob/main/docs/VAE) for VAE finetuning.
+ 
+ ## Hardware Requirements
+ 
+ + DiT finetuning: At least 8 A100 GPUs.
+ 
+ 
+ ## Prepare the Dataset
+ 
+ The training dataset should be arranged into a json file with `video` and `text` fields. Since video vae latent extraction is very slow, we strongly recommend pre-extracting the video vae latents to save training time. We provide a video vae latent extraction script in the `tools` folder. You can run it with the following command:
+ 
+ ```bash
+ sh scripts/extract_vae_latent.sh
+ ```
+ 
+ (Optional) Since the T5 text encoder consumes a lot of GPU memory, pre-extracting the text features will save training memory. We also provide a text feature extraction script in the `tools` folder. You can run it with the following command:
+ 
+ ```bash
+ sh scripts/extract_text_feature.sh
+ ```
+ 
+ The final training annotation json file should look like the following format:
+ 
+ ```
+ {"video": video_path, "text": text prompt, "latent": extracted video vae latent, "text_fea": extracted text feature}
+ ```
+ 
+ We provide example json annotation files for [video](https://github.com/jy0205/Pyramid-Flow/blob/main/annotation/video_text.jsonl) and [image](https://github.com/jy0205/Pyramid-Flow/blob/main/annotation/image_text.jsonl) training in the `annotation` folder. You can refer to them when preparing your training dataset.
+ 
+ 
+ ## Run Training
+ We provide two types of training scripts: (1) autoregressive video generation training with the temporal pyramid; (2) full-sequence diffusion training with pyramid-flow for both text-to-image and text-to-video. These correspond to the following two script files. Run these training scripts with at least 8 GPUs:
+ 
+ + `scripts/train_pyramid_flow.sh`: The autoregressive video generation training with temporal pyramid.
+ 
+ ```bash
+ sh scripts/train_pyramid_flow.sh
+ ```
+ 
+ + `scripts/train_pyramid_flow_without_ar.sh`: Using pyramid-flow for full-sequence diffusion training.
+ 
+ ```bash
+ sh scripts/train_pyramid_flow_without_ar.sh
+ ```
+ 
+ 
+ ## Tips
+ 
+ + For the 768p version, make sure to add the arg `--gradient_checkpointing`
+ + Param `NUM_FRAMES` should be set to a multiple of 8
+ + The param `video_sync_group` indicates the number of processes that receive the same input video, used for temporal pyramid AR training. We recommend setting this value to 4, 8 or 16 (16 is better if you have more GPUs)
+ + Make sure to set `NUM_FRAMES % VIDEO_SYNC_GROUP == 0`, `GPUS % VIDEO_SYNC_GROUP == 0`, and `BATCH_SIZE % 4 == 0`
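
The jsonl annotation format described above is line-delimited JSON, matching what the dataset classes read via `jsonlines`. A small sketch of writing such a file; the paths below are placeholders for your own data and for the pre-extracted files produced by `scripts/extract_vae_latent.sh` and `scripts/extract_text_feature.sh`:

```python
import jsonlines

# Hypothetical records; replace the paths with your own files.
records = [
    {
        "video": "data/clips/clip_0001.mp4",
        "text": "a red fox running through snow",
        "latent": "data/latents/clip_0001.pt",
        "text_fea": "data/text_fea/clip_0001.pt",
    },
]

with jsonlines.open("annotation/video_text.jsonl", mode="w") as writer:
    for record in records:
        writer.write(record)
```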
docs/VAE.md ADDED
@@ -0,0 +1,43 @@
+ # Pyramid Flow's VAE Training Guide
+ 
+ This is the training guide for a [MAGVIT-v2](https://arxiv.org/abs/2310.05737)-like continuous 3D VAE, which should be quite flexible. Feel free to build your own video generative model on top of this VAE training code. Please refer to [another document](https://github.com/jy0205/Pyramid-Flow/blob/main/docs/DiT) for DiT finetuning.
+ 
+ ## Hardware Requirements
+ 
+ + VAE training: At least 8 A100 GPUs.
+ 
+ 
+ ## Prepare the Dataset
+ 
+ The training of our causal video vae uses both image and video data. Both should be arranged into a json file with a `video` or `image` field. The final training annotation json file should look like the following format:
+ 
+ ```
+ # For Video
+ {"video": video_path}
+ 
+ # For Image
+ {"image": image_path}
+ ```
+ 
+ ## Run Training
+ 
+ The causal video vae is trained in two stages:
+ + Stage-1: mixed image and video training
+ + Stage-2: pure video training, using context parallel to load videos with more frames
+ 
+ The VAE training script is `scripts/train_causal_video_vae.sh`; run it as follows:
+ 
+ ```bash
+ sh scripts/train_causal_video_vae.sh
+ ```
+ 
+ We also provide a VAE demo `causal_video_vae_demo.ipynb` for image and video reconstruction.
+ > The original vgg lpips download URL is no longer available, so we have shared the one we used at this [URL](https://drive.google.com/file/d/1YeFlX5BKKw-HGkjNd1r7DSwas1iJJwqC/view). You can download it and replace LPIPS_CKPT with the correct path.
+ 
+ 
+ ## Tips
+ 
+ + For stage-1, we use mixed image and video training. Add the param `--use_image_video_mixed_training` to enable the mixed training. The image ratio is set to 0.1 by default.
+ + Setting the `resolution` to 256 is enough for VAE training.
+ + For stage-1, `max_frames` is set to 17, which means we use 17 sampled video frames for training.
+ + For stage-2, we enable the param `use_context_parallel` to distribute long video frames across multiple GPUs. Make sure to set `GPUS % CONTEXT_SIZE == 0` and `NUM_FRAMES=17 * CONTEXT_SIZE + 1` (e.g., `CONTEXT_SIZE=4` gives `NUM_FRAMES=69`)
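
The stage-1 tips above use `resolution=256` and `max_frames=17`; a sketch of loading clips under those settings with the `VideoDataset` class from `dataset/dataset_cls.py` (the annotation path is a placeholder, and this only illustrates the data format, not the full training pipeline):

```python
from torch.utils.data import DataLoader
from dataset.dataset_cls import VideoDataset

video_set = VideoDataset(
    anno_file="annotation/vae_video.jsonl",   # hypothetical file of {"video": video_path} lines
    resolution=256,
    max_frames=17,
    add_normalize=True,                       # pixels normalized to [-1, 1]
)
loader = DataLoader(video_set, batch_size=1, num_workers=4, drop_last=True)
# each batch["video"] should be a [B, 3, 17, 256, 256] clip tensor
```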
image_generation_demo.ipynb ADDED
@@ -0,0 +1,123 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import os\n",
+     "import json\n",
+     "import torch\n",
+     "import numpy as np\n",
+     "import PIL\n",
+     "from PIL import Image\n",
+     "from IPython.display import HTML\n",
+     "from pyramid_dit import PyramidDiTForVideoGeneration\n",
+     "from IPython.display import Image as ipython_image\n",
+     "from diffusers.utils import load_image, export_to_video, export_to_gif"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "variant='diffusion_transformer_image' # For low resolution\n",
+     "model_name = \"pyramid_flux\"\n",
+     "\n",
+     "model_path = \"/home/jinyang06/models/pyramid-flow-miniflux\" # The downloaded checkpoint dir\n",
+     "model_dtype = 'bf16'\n",
+     "\n",
+     "device_id = 0\n",
+     "torch.cuda.set_device(device_id)\n",
+     "\n",
+     "model = PyramidDiTForVideoGeneration(\n",
+     "    model_path,\n",
+     "    model_dtype,\n",
+     "    model_name=model_name,\n",
+     "    model_variant=variant,\n",
+     ")\n",
+     "\n",
+     "model.vae.to(\"cuda\")\n",
+     "model.dit.to(\"cuda\")\n",
+     "model.text_encoder.to(\"cuda\")\n",
+     "\n",
+     "model.vae.enable_tiling()\n",
+     "\n",
+     "if model_dtype == \"bf16\":\n",
+     "    torch_dtype = torch.bfloat16\n",
+     "elif model_dtype == \"fp16\":\n",
+     "    torch_dtype = torch.float16\n",
+     "else:\n",
+     "    torch_dtype = torch.float32"
+    ]
+   },
+   {
+    "attachments": {},
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "### Text-to-Image"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "prompt = \"shoulder and full head portrait of a beautiful 19 year old girl, brunette, smiling, stunning, highly detailed, glamour lighting, HDR, photorealistic, hyperrealism, octane render, unreal engine\"\n",
+     "\n",
+     "# now support 3 aspect ratios\n",
+     "resolution_dict = {\n",
+     "    '1:1' : (1024, 1024),\n",
+     "    '5:3' : (1280, 768),\n",
+     "    '3:5' : (768, 1280),\n",
+     "}\n",
+     "\n",
+     "ratio = '1:1' # 1:1, 5:3, 3:5\n",
+     "\n",
+     "width, height = resolution_dict[ratio]\n",
+     "\n",
+     "\n",
+     "with torch.no_grad(), torch.cuda.amp.autocast(enabled=True if model_dtype != 'fp32' else False, dtype=torch_dtype):\n",
+     "    images = model.generate(\n",
+     "        prompt=prompt,\n",
+     "        num_inference_steps=[20, 20, 20],\n",
+     "        height=height,\n",
+     "        width=width,\n",
+     "        temp=1,\n",
+     "        guidance_scale=9.0,\n",
+     "        output_type=\"pil\",\n",
+     "        save_memory=False,\n",
+     "    )\n",
+     "\n",
+     "display(images[0])"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python 3",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.8.10"
+   },
+   "orig_nbformat": 4
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
inference_multigpu.py ADDED
@@ -0,0 +1,123 @@
+ import os
+ import torch
+ import sys
+ import argparse
+ import random
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ from diffusers.utils import export_to_video
+ from pyramid_dit import PyramidDiTForVideoGeneration
+ from trainer_misc import init_distributed_mode, init_sequence_parallel_group
+ import PIL
+ from PIL import Image
+
+
+ def get_args():
+     parser = argparse.ArgumentParser('Pytorch Multi-process Script', add_help=False)
+     parser.add_argument('--model_name', default='pyramid_flux', type=str, help="The model name", choices=["pyramid_flux", "pyramid_mmdit"])
+     parser.add_argument('--model_dtype', default='bf16', type=str, help="The Model Dtype: bf16")
+     parser.add_argument('--model_path', default='/home/jinyang06/models/pyramid-flow', type=str, help='Set it to the downloaded checkpoint dir')
+     parser.add_argument('--variant', default='diffusion_transformer_768p', type=str,)
+     parser.add_argument('--task', default='t2v', type=str, choices=['i2v', 't2v'])
+     parser.add_argument('--temp', default=16, type=int, help='The generated latent num, num_frames = temp * 8 + 1')
+     parser.add_argument('--sp_group_size', default=2, type=int, help="The number of gpus used for inference, should be 2 or 4")
+     parser.add_argument('--sp_proc_num', default=-1, type=int, help="The number of process used for video training, default=-1 means using all process.")
+
+     return parser.parse_args()
+
+
+ def main():
+     args = get_args()
+
+     # setup DDP
+     init_distributed_mode(args)
+
+     assert args.world_size == args.sp_group_size, "The sequence parallel size should be DDP world size"
+
+     # Enable sequence parallel
+     init_sequence_parallel_group(args)
+
+     device = torch.device('cuda')
+     rank = args.rank
+     model_dtype = args.model_dtype
+
+     model = PyramidDiTForVideoGeneration(
+         args.model_path,
+         model_dtype,
+         model_name=args.model_name,
+         model_variant=args.variant,
+     )
+
+     model.vae.to(device)
+     model.dit.to(device)
+     model.text_encoder.to(device)
+     model.vae.enable_tiling()
+
+     if model_dtype == "bf16":
+         torch_dtype = torch.bfloat16
+     elif model_dtype == "fp16":
+         torch_dtype = torch.float16
+     else:
+         torch_dtype = torch.float32
+
+     # The video generation config
+     if args.variant == 'diffusion_transformer_768p':
+         width = 1280
+         height = 768
+     else:
+         assert args.variant == 'diffusion_transformer_384p'
+         width = 640
+         height = 384
+
+     if args.task == 't2v':
+         prompt = "A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors"
+
+         with torch.no_grad(), torch.cuda.amp.autocast(enabled=True if model_dtype != 'fp32' else False, dtype=torch_dtype):
+             frames = model.generate(
+                 prompt=prompt,
+                 num_inference_steps=[20, 20, 20],
+                 video_num_inference_steps=[10, 10, 10],
+                 height=height,
+                 width=width,
+                 temp=args.temp,
+                 guidance_scale=7.0,         # The guidance for the first frame, set it to 7 for 384p variant
+                 video_guidance_scale=5.0,   # The guidance for the other video latents
+                 output_type="pil",
+                 save_memory=True,           # If you have enough GPU memory, set it to `False` to improve vae decoding speed
+                 cpu_offloading=False,       # If OOM, set it to True to reduce memory usage
+                 inference_multigpu=True,
+             )
+         if rank == 0:
+             export_to_video(frames, "./text_to_video_sample.mp4", fps=24)
+
+     else:
+         assert args.task == 'i2v'
+
+         image_path = 'assets/the_great_wall.jpg'
+         image = Image.open(image_path).convert("RGB")
+         image = image.resize((width, height))
+
+         prompt = "FPV flying over the Great Wall"
+
+         with torch.no_grad(), torch.cuda.amp.autocast(enabled=True if model_dtype != 'fp32' else False, dtype=torch_dtype):
+             frames = model.generate_i2v(
+                 prompt=prompt,
+                 input_image=image,
+                 num_inference_steps=[10, 10, 10],
+                 temp=args.temp,
+                 video_guidance_scale=4.0,
+                 output_type="pil",
+                 save_memory=True,           # If you have enough GPU memory, set it to `False` to improve vae decoding speed
+                 cpu_offloading=False,       # If OOM, set it to True to reduce memory usage
+                 inference_multigpu=True,
+             )
+
+         if rank == 0:
+             export_to_video(frames, "./image_to_video_sample.mp4", fps=24)
+
+     torch.distributed.barrier()
+
+
+ if __name__ == "__main__":
+     main()
pyramid_dit/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .pyramid_dit_for_video_gen_pipeline import PyramidDiTForVideoGeneration
+ from .flux_modules import FluxSingleTransformerBlock, FluxTransformerBlock, FluxTextEncoderWithMask
+ from .mmdit_modules import JointTransformerBlock, SD3TextEncoderWithMask
pyramid_dit/flux_modules/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .modeling_pyramid_flux import PyramidFluxTransformer
+ from .modeling_text_encoder import FluxTextEncoderWithMask
+ from .modeling_flux_block import FluxSingleTransformerBlock, FluxTransformerBlock