jbilcke-hf committed
Commit 69f3483
1 Parent(s): e08f02e
streamv2v/__init__.py ADDED
@@ -0,0 +1 @@
+ from .pipeline import StreamV2V
streamv2v/acceleration/__init__.py ADDED
File without changes
streamv2v/acceleration/sfast/__init__.py ADDED
@@ -0,0 +1,33 @@
+ from typing import Optional
+
+ from sfast.compilers.stable_diffusion_pipeline_compiler import CompilationConfig, compile
+
+ from ...pipeline import StreamV2V
+
+
+ def accelerate_with_stable_fast(
+     stream: StreamV2V,
+     config: Optional[CompilationConfig] = None,
+ ):
+     if config is None:
+         config = CompilationConfig.Default()
+     # xformers and Triton are suggested for achieving best performance.
+     try:
+         import xformers
+
+         config.enable_xformers = True
+     except ImportError:
+         print("xformers not installed, skip")
+     try:
+         import triton
+
+         config.enable_triton = True
+     except ImportError:
+         print("Triton not installed, skip")
+     # CUDA Graph is suggested for small batch sizes and small resolutions to reduce CPU overhead.
+     config.enable_cuda_graph = True
+     stream.pipe = compile(stream.pipe, config)
+     stream.unet = stream.pipe.unet
+     stream.vae = stream.pipe.vae
+     stream.text_encoder = stream.pipe.text_encoder
+     return stream
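
A minimal usage sketch for the helper above, assuming a `StreamV2V` wrapper named `stream` has already been built elsewhere (that variable is not part of this commit):

    from streamv2v.acceleration.sfast import accelerate_with_stable_fast

    # `stream` is an already-built StreamV2V wrapper (assumed); config=None picks CompilationConfig.Default()
    stream = accelerate_with_stable_fast(stream)
    # stream.pipe is now compiled; stream.unet / stream.vae / stream.text_encoder point into it.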
streamv2v/acceleration/tensorrt/__init__.py ADDED
@@ -0,0 +1,188 @@
+ import gc
+ import os
+
+ import torch
+ from diffusers import AutoencoderKL, UNet2DConditionModel
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
+     retrieve_latents,
+ )
+ from polygraphy import cuda
+
+ from ...pipeline import StreamV2V
+ from .builder import EngineBuilder, create_onnx_path
+ from .engine import AutoencoderKLEngine, UNet2DConditionModelEngine
+ from .models import VAE, BaseModel, UNet, VAEEncoder
+
+
+ class TorchVAEEncoder(torch.nn.Module):
+     def __init__(self, vae: AutoencoderKL):
+         super().__init__()
+         self.vae = vae
+
+     def forward(self, x: torch.Tensor):
+         return retrieve_latents(self.vae.encode(x))
+
+
+ def compile_vae_encoder(
+     vae: TorchVAEEncoder,
+     model_data: BaseModel,
+     onnx_path: str,
+     onnx_opt_path: str,
+     engine_path: str,
+     opt_batch_size: int = 1,
+     engine_build_options: dict = {},
+ ):
+     builder = EngineBuilder(model_data, vae, device=torch.device("cuda"))
+     builder.build(
+         onnx_path,
+         onnx_opt_path,
+         engine_path,
+         opt_batch_size=opt_batch_size,
+         **engine_build_options,
+     )
+
+
+ def compile_vae_decoder(
+     vae: AutoencoderKL,
+     model_data: BaseModel,
+     onnx_path: str,
+     onnx_opt_path: str,
+     engine_path: str,
+     opt_batch_size: int = 1,
+     engine_build_options: dict = {},
+ ):
+     vae = vae.to(torch.device("cuda"))
+     builder = EngineBuilder(model_data, vae, device=torch.device("cuda"))
+     builder.build(
+         onnx_path,
+         onnx_opt_path,
+         engine_path,
+         opt_batch_size=opt_batch_size,
+         **engine_build_options,
+     )
+
+
+ def compile_unet(
+     unet: UNet2DConditionModel,
+     model_data: BaseModel,
+     onnx_path: str,
+     onnx_opt_path: str,
+     engine_path: str,
+     opt_batch_size: int = 1,
+     engine_build_options: dict = {},
+ ):
+     unet = unet.to(torch.device("cuda"), dtype=torch.float16)
+     builder = EngineBuilder(model_data, unet, device=torch.device("cuda"))
+     builder.build(
+         onnx_path,
+         onnx_opt_path,
+         engine_path,
+         opt_batch_size=opt_batch_size,
+         **engine_build_options,
+     )
+
+
+ def accelerate_with_tensorrt(
+     stream: StreamV2V,
+     engine_dir: str,
+     max_batch_size: int = 2,
+     min_batch_size: int = 1,
+     use_cuda_graph: bool = False,
+     engine_build_options: dict = {},
+ ):
+     if "opt_batch_size" not in engine_build_options or engine_build_options["opt_batch_size"] is None:
+         engine_build_options["opt_batch_size"] = max_batch_size
+     text_encoder = stream.text_encoder
+     unet = stream.unet
+     vae = stream.vae
+
+     del stream.unet, stream.vae, stream.pipe.unet, stream.pipe.vae
+
+     vae_config = vae.config
+     vae_dtype = vae.dtype
+
+     unet.to(torch.device("cpu"))
+     vae.to(torch.device("cpu"))
+
+     gc.collect()
+     torch.cuda.empty_cache()
+
+     onnx_dir = os.path.join(engine_dir, "onnx")
+     os.makedirs(onnx_dir, exist_ok=True)
+
+     unet_engine_path = f"{engine_dir}/unet.engine"
+     vae_encoder_engine_path = f"{engine_dir}/vae_encoder.engine"
+     vae_decoder_engine_path = f"{engine_dir}/vae_decoder.engine"
+
+     unet_model = UNet(
+         fp16=True,
+         device=stream.device,
+         max_batch_size=max_batch_size,
+         min_batch_size=min_batch_size,
+         embedding_dim=text_encoder.config.hidden_size,
+         unet_dim=unet.config.in_channels,
+     )
+     vae_decoder_model = VAE(
+         device=stream.device,
+         max_batch_size=max_batch_size,
+         min_batch_size=min_batch_size,
+     )
+     vae_encoder_model = VAEEncoder(
+         device=stream.device,
+         max_batch_size=max_batch_size,
+         min_batch_size=min_batch_size,
+     )
+
+     if not os.path.exists(unet_engine_path):
+         compile_unet(
+             unet,
+             unet_model,
+             create_onnx_path("unet", onnx_dir, opt=False),
+             create_onnx_path("unet", onnx_dir, opt=True),
+             unet_engine_path,
+             **engine_build_options,
+         )
+     else:
+         del unet
+
+     if not os.path.exists(vae_decoder_engine_path):
+         vae.forward = vae.decode
+         compile_vae_decoder(
+             vae,
+             vae_decoder_model,
+             create_onnx_path("vae_decoder", onnx_dir, opt=False),
+             create_onnx_path("vae_decoder", onnx_dir, opt=True),
+             vae_decoder_engine_path,
+             **engine_build_options,
+         )
+
+     if not os.path.exists(vae_encoder_engine_path):
+         vae_encoder = TorchVAEEncoder(vae).to(torch.device("cuda"))
+         compile_vae_encoder(
+             vae_encoder,
+             vae_encoder_model,
+             create_onnx_path("vae_encoder", onnx_dir, opt=False),
+             create_onnx_path("vae_encoder", onnx_dir, opt=True),
+             vae_encoder_engine_path,
+             **engine_build_options,
+         )
+
+     del vae
+
+     cuda_stream = cuda.Stream()
+
+     stream.unet = UNet2DConditionModelEngine(unet_engine_path, cuda_stream, use_cuda_graph=use_cuda_graph)
+     stream.vae = AutoencoderKLEngine(
+         vae_encoder_engine_path,
+         vae_decoder_engine_path,
+         cuda_stream,
+         stream.pipe.vae_scale_factor,
+         use_cuda_graph=use_cuda_graph,
+     )
+     setattr(stream.vae, "config", vae_config)
+     setattr(stream.vae, "dtype", vae_dtype)
+
+     gc.collect()
+     torch.cuda.empty_cache()
+
+     return stream
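
A minimal usage sketch for `accelerate_with_tensorrt`, assuming an already-built `StreamV2V` wrapper named `stream` and a writable engine directory (both assumptions, not part of this commit). The first call exports ONNX and builds the TensorRT engines; later calls reuse the cached `.engine` files found in `engine_dir`:

    from streamv2v.acceleration.tensorrt import accelerate_with_tensorrt

    stream = accelerate_with_tensorrt(
        stream,                 # existing StreamV2V wrapper (assumed)
        engine_dir="engines",   # ONNX files land in engines/onnx, engines in engines/*.engine
        max_batch_size=2,
        use_cuda_graph=False,
    )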
streamv2v/acceleration/tensorrt/builder.py ADDED
@@ -0,0 +1,94 @@
+ import gc
+ import os
+ from typing import *
+
+ import torch
+
+ from .models import BaseModel
+ from .utilities import (
+     build_engine,
+     export_onnx,
+     optimize_onnx,
+ )
+
+
+ def create_onnx_path(name, onnx_dir, opt=True):
+     return os.path.join(onnx_dir, name + (".opt" if opt else "") + ".onnx")
+
+
+ class EngineBuilder:
+     def __init__(
+         self,
+         model: BaseModel,
+         network: Any,
+         device=torch.device("cuda"),
+     ):
+         self.device = device
+
+         self.model = model
+         self.network = network
+
+     def build(
+         self,
+         onnx_path: str,
+         onnx_opt_path: str,
+         engine_path: str,
+         opt_image_height: int = 512,
+         opt_image_width: int = 512,
+         opt_batch_size: int = 1,
+         min_image_resolution: int = 256,
+         max_image_resolution: int = 1024,
+         build_enable_refit: bool = False,
+         build_static_batch: bool = False,
+         build_dynamic_shape: bool = False,
+         build_all_tactics: bool = False,
+         onnx_opset: int = 17,
+         force_engine_build: bool = False,
+         force_onnx_export: bool = False,
+         force_onnx_optimize: bool = False,
+     ):
+         if not force_onnx_export and os.path.exists(onnx_path):
+             print(f"Found cached model: {onnx_path}")
+         else:
+             print(f"Exporting model: {onnx_path}")
+             export_onnx(
+                 self.network,
+                 onnx_path=onnx_path,
+                 model_data=self.model,
+                 opt_image_height=opt_image_height,
+                 opt_image_width=opt_image_width,
+                 opt_batch_size=opt_batch_size,
+                 onnx_opset=onnx_opset,
+             )
+             del self.network
+             gc.collect()
+             torch.cuda.empty_cache()
+         if not force_onnx_optimize and os.path.exists(onnx_opt_path):
+             print(f"Found cached model: {onnx_opt_path}")
+         else:
+             print(f"Generating optimized model: {onnx_opt_path}")
+             optimize_onnx(
+                 onnx_path=onnx_path,
+                 onnx_opt_path=onnx_opt_path,
+                 model_data=self.model,
+             )
+         self.model.min_latent_shape = min_image_resolution // 8
+         self.model.max_latent_shape = max_image_resolution // 8
+         if not force_engine_build and os.path.exists(engine_path):
+             print(f"Found cached engine: {engine_path}")
+         else:
+             build_engine(
+                 engine_path=engine_path,
+                 onnx_opt_path=onnx_opt_path,
+                 model_data=self.model,
+                 opt_image_height=opt_image_height,
+                 opt_image_width=opt_image_width,
+                 opt_batch_size=opt_batch_size,
+                 build_static_batch=build_static_batch,
+                 build_dynamic_shape=build_dynamic_shape,
+                 build_all_tactics=build_all_tactics,
+                 build_enable_refit=build_enable_refit,
+             )
+
+         gc.collect()
+         torch.cuda.empty_cache()
streamv2v/acceleration/tensorrt/engine.py ADDED
@@ -0,0 +1,123 @@
+ from typing import *
+
+ import torch
+ from diffusers.models.autoencoder_tiny import AutoencoderTinyOutput
+ from diffusers.models.unet_2d_condition import UNet2DConditionOutput
+ from diffusers.models.vae import DecoderOutput
+ from polygraphy import cuda
+
+ from .utilities import Engine
+
+
+ class UNet2DConditionModelEngine:
+     def __init__(self, filepath: str, stream: cuda.Stream, use_cuda_graph: bool = False):
+         self.engine = Engine(filepath)
+         self.stream = stream
+         self.use_cuda_graph = use_cuda_graph
+
+         self.engine.load()
+         self.engine.activate()
+
+     def __call__(
+         self,
+         latent_model_input: torch.Tensor,
+         timestep: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         **kwargs,
+     ) -> Any:
+         if timestep.dtype != torch.float32:
+             timestep = timestep.float()
+
+         self.engine.allocate_buffers(
+             shape_dict={
+                 "sample": latent_model_input.shape,
+                 "timestep": timestep.shape,
+                 "encoder_hidden_states": encoder_hidden_states.shape,
+                 "latent": latent_model_input.shape,
+             },
+             device=latent_model_input.device,
+         )
+
+         noise_pred = self.engine.infer(
+             {
+                 "sample": latent_model_input,
+                 "timestep": timestep,
+                 "encoder_hidden_states": encoder_hidden_states,
+             },
+             self.stream,
+             use_cuda_graph=self.use_cuda_graph,
+         )["latent"]
+         return UNet2DConditionOutput(sample=noise_pred)
+
+     def to(self, *args, **kwargs):
+         pass
+
+     def forward(self, *args, **kwargs):
+         pass
+
+
+ class AutoencoderKLEngine:
+     def __init__(
+         self,
+         encoder_path: str,
+         decoder_path: str,
+         stream: cuda.Stream,
+         scaling_factor: int,
+         use_cuda_graph: bool = False,
+     ):
+         self.encoder = Engine(encoder_path)
+         self.decoder = Engine(decoder_path)
+         self.stream = stream
+         self.vae_scale_factor = scaling_factor
+         self.use_cuda_graph = use_cuda_graph
+
+         self.encoder.load()
+         self.decoder.load()
+         self.encoder.activate()
+         self.decoder.activate()
+
+     def encode(self, images: torch.Tensor, **kwargs):
+         self.encoder.allocate_buffers(
+             shape_dict={
+                 "images": images.shape,
+                 "latent": (
+                     images.shape[0],
+                     4,
+                     images.shape[2] // self.vae_scale_factor,
+                     images.shape[3] // self.vae_scale_factor,
+                 ),
+             },
+             device=images.device,
+         )
+         latents = self.encoder.infer(
+             {"images": images},
+             self.stream,
+             use_cuda_graph=self.use_cuda_graph,
+         )["latent"]
+         return AutoencoderTinyOutput(latents=latents)
+
+     def decode(self, latent: torch.Tensor, **kwargs):
+         self.decoder.allocate_buffers(
+             shape_dict={
+                 "latent": latent.shape,
+                 "images": (
+                     latent.shape[0],
+                     3,
+                     latent.shape[2] * self.vae_scale_factor,
+                     latent.shape[3] * self.vae_scale_factor,
+                 ),
+             },
+             device=latent.device,
+         )
+         images = self.decoder.infer(
+             {"latent": latent},
+             self.stream,
+             use_cuda_graph=self.use_cuda_graph,
+         )["images"]
+         return DecoderOutput(sample=images)
+
+     def to(self, *args, **kwargs):
+         pass
+
+     def forward(self, *args, **kwargs):
+         pass
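
These engine classes stand in for the diffusers modules they replace: `UNet2DConditionModelEngine` keeps the UNet's `(sample, timestep, encoder_hidden_states)` call shape and wraps the TensorRT result in `UNet2DConditionOutput`. A rough call sketch, with the tensor shapes (batch 2, 4 latent channels, 64x64 latents for 512x512 images, 77 tokens, 768-dim embeddings) and a pre-built `trt_unet` instance assumed for illustration:

    import torch

    latents = torch.randn(2, 4, 64, 64, dtype=torch.float16, device="cuda")
    timestep = torch.tensor([999.0, 999.0], device="cuda")
    text_emb = torch.randn(2, 77, 768, dtype=torch.float16, device="cuda")

    # trt_unet: a UNet2DConditionModelEngine built from an existing unet.engine (assumed)
    noise_pred = trt_unet(latents, timestep, text_emb).sample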
streamv2v/acceleration/tensorrt/models.py ADDED
@@ -0,0 +1,434 @@
+ #! fork: https://github.com/NVIDIA/TensorRT/blob/main/demo/Diffusion/models.py
+
+ #
+ # SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+
+ import onnx_graphsurgeon as gs
+ import torch
+ from onnx import shape_inference
+ from polygraphy.backend.onnx.loader import fold_constants
+
+
+ class Optimizer:
+     def __init__(self, onnx_graph, verbose=False):
+         self.graph = gs.import_onnx(onnx_graph)
+         self.verbose = verbose
+
+     def info(self, prefix):
+         if self.verbose:
+             print(
+                 f"{prefix} .. {len(self.graph.nodes)} nodes, {len(self.graph.tensors().keys())} tensors, {len(self.graph.inputs)} inputs, {len(self.graph.outputs)} outputs"
+             )
+
+     def cleanup(self, return_onnx=False):
+         self.graph.cleanup().toposort()
+         if return_onnx:
+             return gs.export_onnx(self.graph)
+
+     def select_outputs(self, keep, names=None):
+         self.graph.outputs = [self.graph.outputs[o] for o in keep]
+         if names:
+             for i, name in enumerate(names):
+                 self.graph.outputs[i].name = name
+
+     def fold_constants(self, return_onnx=False):
+         onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True)
+         self.graph = gs.import_onnx(onnx_graph)
+         if return_onnx:
+             return onnx_graph
+
+     def infer_shapes(self, return_onnx=False):
+         onnx_graph = gs.export_onnx(self.graph)
+         if onnx_graph.ByteSize() > 2147483648:
+             raise TypeError("ERROR: model size exceeds supported 2GB limit")
+         else:
+             onnx_graph = shape_inference.infer_shapes(onnx_graph)
+
+         self.graph = gs.import_onnx(onnx_graph)
+         if return_onnx:
+             return onnx_graph
+
+
+ class BaseModel:
+     def __init__(
+         self,
+         fp16=False,
+         device="cuda",
+         verbose=True,
+         max_batch_size=16,
+         min_batch_size=1,
+         embedding_dim=768,
+         text_maxlen=77,
+     ):
+         self.name = "SD Model"
+         self.fp16 = fp16
+         self.device = device
+         self.verbose = verbose
+
+         self.min_batch = min_batch_size
+         self.max_batch = max_batch_size
+         self.min_image_shape = 256  # min image resolution: 256x256
+         self.max_image_shape = 1024  # max image resolution: 1024x1024
+         self.min_latent_shape = self.min_image_shape // 8
+         self.max_latent_shape = self.max_image_shape // 8
+
+         self.embedding_dim = embedding_dim
+         self.text_maxlen = text_maxlen
+
+     def get_model(self):
+         pass
+
+     def get_input_names(self):
+         pass
+
+     def get_output_names(self):
+         pass
+
+     def get_dynamic_axes(self):
+         return None
+
+     def get_sample_input(self, batch_size, image_height, image_width):
+         pass
+
+     def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+         return None
+
+     def get_shape_dict(self, batch_size, image_height, image_width):
+         return None
+
+     def optimize(self, onnx_graph):
+         opt = Optimizer(onnx_graph, verbose=self.verbose)
+         opt.info(self.name + ": original")
+         opt.cleanup()
+         opt.info(self.name + ": cleanup")
+         opt.fold_constants()
+         opt.info(self.name + ": fold constants")
+         opt.infer_shapes()
+         opt.info(self.name + ": shape inference")
+         onnx_opt_graph = opt.cleanup(return_onnx=True)
+         opt.info(self.name + ": finished")
+         return onnx_opt_graph
+
+     def check_dims(self, batch_size, image_height, image_width):
+         assert batch_size >= self.min_batch and batch_size <= self.max_batch
+         assert image_height % 8 == 0 or image_width % 8 == 0
+         latent_height = image_height // 8
+         latent_width = image_width // 8
+         assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape
+         assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape
+         return (latent_height, latent_width)
+
+     def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape):
+         min_batch = batch_size if static_batch else self.min_batch
+         max_batch = batch_size if static_batch else self.max_batch
+         latent_height = image_height // 8
+         latent_width = image_width // 8
+         min_image_height = image_height if static_shape else self.min_image_shape
+         max_image_height = image_height if static_shape else self.max_image_shape
+         min_image_width = image_width if static_shape else self.min_image_shape
+         max_image_width = image_width if static_shape else self.max_image_shape
+         min_latent_height = latent_height if static_shape else self.min_latent_shape
+         max_latent_height = latent_height if static_shape else self.max_latent_shape
+         min_latent_width = latent_width if static_shape else self.min_latent_shape
+         max_latent_width = latent_width if static_shape else self.max_latent_shape
+         return (
+             min_batch,
+             max_batch,
+             min_image_height,
+             max_image_height,
+             min_image_width,
+             max_image_width,
+             min_latent_height,
+             max_latent_height,
+             min_latent_width,
+             max_latent_width,
+         )
+
+
+ class CLIP(BaseModel):
+     def __init__(self, device, max_batch_size, embedding_dim, min_batch_size=1):
+         super(CLIP, self).__init__(
+             device=device,
+             max_batch_size=max_batch_size,
+             min_batch_size=min_batch_size,
+             embedding_dim=embedding_dim,
+         )
+         self.name = "CLIP"
+
+     def get_input_names(self):
+         return ["input_ids"]
+
+     def get_output_names(self):
+         return ["text_embeddings", "pooler_output"]
+
+     def get_dynamic_axes(self):
+         return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}}
+
+     def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+         self.check_dims(batch_size, image_height, image_width)
+         min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims(
+             batch_size, image_height, image_width, static_batch, static_shape
+         )
+         return {
+             "input_ids": [
+                 (min_batch, self.text_maxlen),
+                 (batch_size, self.text_maxlen),
+                 (max_batch, self.text_maxlen),
+             ]
+         }
+
+     def get_shape_dict(self, batch_size, image_height, image_width):
+         self.check_dims(batch_size, image_height, image_width)
+         return {
+             "input_ids": (batch_size, self.text_maxlen),
+             "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim),
+         }
+
+     def get_sample_input(self, batch_size, image_height, image_width):
+         self.check_dims(batch_size, image_height, image_width)
+         return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device)
+
+     def optimize(self, onnx_graph):
+         opt = Optimizer(onnx_graph)
+         opt.info(self.name + ": original")
+         opt.select_outputs([0])  # delete graph output#1
+         opt.cleanup()
+         opt.info(self.name + ": remove output[1]")
+         opt.fold_constants()
+         opt.info(self.name + ": fold constants")
+         opt.infer_shapes()
+         opt.info(self.name + ": shape inference")
+         opt.select_outputs([0], names=["text_embeddings"])  # rename network output
+         opt.info(self.name + ": remove output[0]")
+         opt_onnx_graph = opt.cleanup(return_onnx=True)
+         opt.info(self.name + ": finished")
+         return opt_onnx_graph
+
+
+ class UNet(BaseModel):
+     def __init__(
+         self,
+         fp16=False,
+         device="cuda",
+         max_batch_size=16,
+         min_batch_size=1,
+         embedding_dim=768,
+         text_maxlen=77,
+         unet_dim=4,
+     ):
+         super(UNet, self).__init__(
+             fp16=fp16,
+             device=device,
+             max_batch_size=max_batch_size,
+             min_batch_size=min_batch_size,
+             embedding_dim=embedding_dim,
+             text_maxlen=text_maxlen,
+         )
+         self.unet_dim = unet_dim
+         self.name = "UNet"
+
+     def get_input_names(self):
+         return ["sample", "timestep", "encoder_hidden_states"]
+
+     def get_output_names(self):
+         return ["latent"]
+
+     def get_dynamic_axes(self):
+         return {
+             "sample": {0: "2B", 2: "H", 3: "W"},
+             "timestep": {0: "2B"},
+             "encoder_hidden_states": {0: "2B"},
+             "latent": {0: "2B", 2: "H", 3: "W"},
+         }
+
+     def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+         (
+             min_batch,
+             max_batch,
+             _,
+             _,
+             _,
+             _,
+             min_latent_height,
+             max_latent_height,
+             min_latent_width,
+             max_latent_width,
+         ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
+         return {
+             "sample": [
+                 (min_batch, self.unet_dim, min_latent_height, min_latent_width),
+                 (batch_size, self.unet_dim, latent_height, latent_width),
+                 (max_batch, self.unet_dim, max_latent_height, max_latent_width),
+             ],
+             "timestep": [(min_batch,), (batch_size,), (max_batch,)],
+             "encoder_hidden_states": [
+                 (min_batch, self.text_maxlen, self.embedding_dim),
+                 (batch_size, self.text_maxlen, self.embedding_dim),
+                 (max_batch, self.text_maxlen, self.embedding_dim),
+             ],
+         }
+
+     def get_shape_dict(self, batch_size, image_height, image_width):
+         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+         return {
+             "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width),
+             "timestep": (2 * batch_size,),
+             "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim),
+             "latent": (2 * batch_size, 4, latent_height, latent_width),
+         }
+
+     def get_sample_input(self, batch_size, image_height, image_width):
+         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+         dtype = torch.float16 if self.fp16 else torch.float32
+         return (
+             torch.randn(
+                 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device
+             ),
+             torch.ones((2 * batch_size,), dtype=torch.float32, device=self.device),
+             torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device),
+         )
+
+
+ class VAE(BaseModel):
+     def __init__(self, device, max_batch_size, min_batch_size=1):
+         super(VAE, self).__init__(
+             device=device,
+             max_batch_size=max_batch_size,
+             min_batch_size=min_batch_size,
+             embedding_dim=None,
+         )
+         self.name = "VAE decoder"
+
+     def get_input_names(self):
+         return ["latent"]
+
+     def get_output_names(self):
+         return ["images"]
+
+     def get_dynamic_axes(self):
+         return {
+             "latent": {0: "B", 2: "H", 3: "W"},
+             "images": {0: "B", 2: "8H", 3: "8W"},
+         }
+
+     def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+         (
+             min_batch,
+             max_batch,
+             _,
+             _,
+             _,
+             _,
+             min_latent_height,
+             max_latent_height,
+             min_latent_width,
+             max_latent_width,
+         ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
+         return {
+             "latent": [
+                 (min_batch, 4, min_latent_height, min_latent_width),
+                 (batch_size, 4, latent_height, latent_width),
+                 (max_batch, 4, max_latent_height, max_latent_width),
+             ]
+         }
+
+     def get_shape_dict(self, batch_size, image_height, image_width):
+         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+         return {
+             "latent": (batch_size, 4, latent_height, latent_width),
+             "images": (batch_size, 3, image_height, image_width),
+         }
+
+     def get_sample_input(self, batch_size, image_height, image_width):
+         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+         return torch.randn(
+             batch_size,
+             4,
+             latent_height,
+             latent_width,
+             dtype=torch.float32,
+             device=self.device,
+         )
+
+
+ class VAEEncoder(BaseModel):
+     def __init__(self, device, max_batch_size, min_batch_size=1):
+         super(VAEEncoder, self).__init__(
+             device=device,
+             max_batch_size=max_batch_size,
+             min_batch_size=min_batch_size,
+             embedding_dim=None,
+         )
+         self.name = "VAE encoder"
+
+     def get_input_names(self):
+         return ["images"]
+
+     def get_output_names(self):
+         return ["latent"]
+
+     def get_dynamic_axes(self):
+         return {
+             "images": {0: "B", 2: "8H", 3: "8W"},
+             "latent": {0: "B", 2: "H", 3: "W"},
+         }
+
+     def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+         assert batch_size >= self.min_batch and batch_size <= self.max_batch
+         min_batch = batch_size if static_batch else self.min_batch
+         max_batch = batch_size if static_batch else self.max_batch
+         self.check_dims(batch_size, image_height, image_width)
+         (
+             min_batch,
+             max_batch,
+             min_image_height,
+             max_image_height,
+             min_image_width,
+             max_image_width,
+             _,
+             _,
+             _,
+             _,
+         ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
+
+         return {
+             "images": [
+                 (min_batch, 3, min_image_height, min_image_width),
+                 (batch_size, 3, image_height, image_width),
+                 (max_batch, 3, max_image_height, max_image_width),
+             ],
+         }
+
+     def get_shape_dict(self, batch_size, image_height, image_width):
+         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+         return {
+             "images": (batch_size, 3, image_height, image_width),
+             "latent": (batch_size, 4, latent_height, latent_width),
+         }
+
+     def get_sample_input(self, batch_size, image_height, image_width):
+         self.check_dims(batch_size, image_height, image_width)
+         return torch.randn(
+             batch_size,
+             3,
+             image_height,
+             image_width,
+             dtype=torch.float32,
+             device=self.device,
+         )
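
These wrappers mainly describe ONNX I/O names, dynamic axes, and the min/opt/max TensorRT optimization profile. A short sketch of inspecting the UNet profile (the parameter values are illustrative; `embedding_dim=768` matches SD 1.x text encoders):

    # Build a UNet spec and ask for its TensorRT input profile at 512x512.
    unet_spec = UNet(fp16=True, device="cuda", max_batch_size=2, min_batch_size=1, embedding_dim=768, unet_dim=4)
    profile = unet_spec.get_input_profile(
        batch_size=1, image_height=512, image_width=512, static_batch=False, static_shape=True
    )
    # profile["sample"] -> [(1, 4, 64, 64), (1, 4, 64, 64), (2, 4, 64, 64)]  (min / opt / max)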
streamv2v/acceleration/tensorrt/utilities.py ADDED
@@ -0,0 +1,441 @@
+ #! fork: https://github.com/NVIDIA/TensorRT/blob/main/demo/Diffusion/utilities.py
+
+ #
+ # Copyright 2022 The HuggingFace Inc. team.
+ # SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+
+ import gc
+ from collections import OrderedDict
+ from typing import *
+
+ import numpy as np
+ import onnx
+ import onnx_graphsurgeon as gs
+ import tensorrt as trt
+ import torch
+ from cuda import cudart
+ from PIL import Image
+ from polygraphy import cuda
+ from polygraphy.backend.common import bytes_from_path
+ from polygraphy.backend.trt import (
+     CreateConfig,
+     Profile,
+     engine_from_bytes,
+     engine_from_network,
+     network_from_onnx_path,
+     save_engine,
+ )
+ from polygraphy.backend.trt import util as trt_util
+
+ from .models import CLIP, VAE, BaseModel, UNet, VAEEncoder
+
+
+ TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
+
+ # Map of numpy dtype -> torch dtype
+ numpy_to_torch_dtype_dict = {
+     np.uint8: torch.uint8,
+     np.int8: torch.int8,
+     np.int16: torch.int16,
+     np.int32: torch.int32,
+     np.int64: torch.int64,
+     np.float16: torch.float16,
+     np.float32: torch.float32,
+     np.float64: torch.float64,
+     np.complex64: torch.complex64,
+     np.complex128: torch.complex128,
+ }
+ if np.version.full_version >= "1.24.0":
+     numpy_to_torch_dtype_dict[np.bool_] = torch.bool
+ else:
+     numpy_to_torch_dtype_dict[np.bool] = torch.bool
+
+ # Map of torch dtype -> numpy dtype
+ torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}
+
+
+ def CUASSERT(cuda_ret):
+     err = cuda_ret[0]
+     if err != cudart.cudaError_t.cudaSuccess:
+         raise RuntimeError(
+             f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t"
+         )
+     if len(cuda_ret) > 1:
+         return cuda_ret[1]
+     return None
+
+
+ class Engine:
+     def __init__(
+         self,
+         engine_path,
+     ):
+         self.engine_path = engine_path
+         self.engine = None
+         self.context = None
+         self.buffers = OrderedDict()
+         self.tensors = OrderedDict()
+         self.cuda_graph_instance = None  # cuda graph
+
+     def __del__(self):
+         [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)]
+         del self.engine
+         del self.context
+         del self.buffers
+         del self.tensors
+
+     def refit(self, onnx_path, onnx_refit_path):
+         def convert_int64(arr):
+             # TODO: smarter conversion
+             if len(arr.shape) == 0:
+                 return np.int32(arr)
+             return arr
+
+         def add_to_map(refit_dict, name, values):
+             if name in refit_dict:
+                 assert refit_dict[name] is None
+                 if values.dtype == np.int64:
+                     values = convert_int64(values)
+                 refit_dict[name] = values
+
+         print(f"Refitting TensorRT engine with {onnx_refit_path} weights")
+         refit_nodes = gs.import_onnx(onnx.load(onnx_refit_path)).toposort().nodes
+
+         # Construct mapping from weight names in refit model -> original model
+         name_map = {}
+         for n, node in enumerate(gs.import_onnx(onnx.load(onnx_path)).toposort().nodes):
+             refit_node = refit_nodes[n]
+             assert node.op == refit_node.op
+             # Constant nodes in ONNX do not have inputs but have a constant output
+             if node.op == "Constant":
+                 name_map[refit_node.outputs[0].name] = node.outputs[0].name
+             # Handle scale and bias weights
+             elif node.op == "Conv":
+                 if node.inputs[1].__class__ == gs.Constant:
+                     name_map[refit_node.name + "_TRTKERNEL"] = node.name + "_TRTKERNEL"
+                 if node.inputs[2].__class__ == gs.Constant:
+                     name_map[refit_node.name + "_TRTBIAS"] = node.name + "_TRTBIAS"
+             # For all other nodes: find node inputs that are initializers (gs.Constant)
+             else:
+                 for i, inp in enumerate(node.inputs):
+                     if inp.__class__ == gs.Constant:
+                         name_map[refit_node.inputs[i].name] = inp.name
+
+         def map_name(name):
+             if name in name_map:
+                 return name_map[name]
+             return name
+
+         # Construct refit dictionary
+         refit_dict = {}
+         refitter = trt.Refitter(self.engine, TRT_LOGGER)
+         all_weights = refitter.get_all()
+         for layer_name, role in zip(all_weights[0], all_weights[1]):
+             # for specialized roles, use a unique name in the map:
+             if role == trt.WeightsRole.KERNEL:
+                 name = layer_name + "_TRTKERNEL"
+             elif role == trt.WeightsRole.BIAS:
+                 name = layer_name + "_TRTBIAS"
+             else:
+                 name = layer_name
+
+             assert name not in refit_dict, "Found duplicate layer: " + name
+             refit_dict[name] = None
+
+         for n in refit_nodes:
+             # Constant nodes in ONNX do not have inputs but have a constant output
+             if n.op == "Constant":
+                 name = map_name(n.outputs[0].name)
+                 print(f"Add Constant {name}\n")
+                 add_to_map(refit_dict, name, n.outputs[0].values)
+
+             # Handle scale and bias weights
+             elif n.op == "Conv":
+                 if n.inputs[1].__class__ == gs.Constant:
+                     name = map_name(n.name + "_TRTKERNEL")
+                     add_to_map(refit_dict, name, n.inputs[1].values)
+
+                 if n.inputs[2].__class__ == gs.Constant:
+                     name = map_name(n.name + "_TRTBIAS")
+                     add_to_map(refit_dict, name, n.inputs[2].values)
+
+             # For all other nodes: find node inputs that are initializers (AKA gs.Constant)
+             else:
+                 for inp in n.inputs:
+                     name = map_name(inp.name)
+                     if inp.__class__ == gs.Constant:
+                         add_to_map(refit_dict, name, inp.values)
+
+         for layer_name, weights_role in zip(all_weights[0], all_weights[1]):
+             if weights_role == trt.WeightsRole.KERNEL:
+                 custom_name = layer_name + "_TRTKERNEL"
+             elif weights_role == trt.WeightsRole.BIAS:
+                 custom_name = layer_name + "_TRTBIAS"
+             else:
+                 custom_name = layer_name
+
+             # Skip refitting Trilu for now; scalar weights of type int64 value 1 - for clip model
+             if layer_name.startswith("onnx::Trilu"):
+                 continue
+
+             if refit_dict[custom_name] is not None:
+                 refitter.set_weights(layer_name, weights_role, refit_dict[custom_name])
+             else:
+                 print(f"[W] No refit weights for layer: {layer_name}")
+
+         if not refitter.refit_cuda_engine():
+             print("Failed to refit!")
+             exit(0)
+
+     def build(
+         self,
+         onnx_path,
+         fp16,
+         input_profile=None,
+         enable_refit=False,
+         enable_all_tactics=False,
+         timing_cache=None,
+         workspace_size=0,
+     ):
+         print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
+         p = Profile()
+         if input_profile:
+             for name, dims in input_profile.items():
+                 assert len(dims) == 3
+                 p.add(name, min=dims[0], opt=dims[1], max=dims[2])
+
+         config_kwargs = {}
+
+         if workspace_size > 0:
+             config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size}
+         if not enable_all_tactics:
+             config_kwargs["tactic_sources"] = []
+
+         engine = engine_from_network(
+             network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
+             config=CreateConfig(
+                 fp16=fp16, refittable=enable_refit, profiles=[p], load_timing_cache=timing_cache, **config_kwargs
+             ),
+             save_timing_cache=timing_cache,
+         )
+         save_engine(engine, path=self.engine_path)
+
+     def load(self):
+         print(f"Loading TensorRT engine: {self.engine_path}")
+         self.engine = engine_from_bytes(bytes_from_path(self.engine_path))
+
+     def activate(self, reuse_device_memory=None):
+         if reuse_device_memory:
+             self.context = self.engine.create_execution_context_without_device_memory()
+             self.context.device_memory = reuse_device_memory
+         else:
+             self.context = self.engine.create_execution_context()
+
+     def allocate_buffers(self, shape_dict=None, device="cuda"):
+         for idx in range(trt_util.get_bindings_per_profile(self.engine)):
+             binding = self.engine[idx]
+             if shape_dict and binding in shape_dict:
+                 shape = shape_dict[binding]
+             else:
+                 shape = self.engine.get_binding_shape(binding)
+             dtype = trt.nptype(self.engine.get_binding_dtype(binding))
+             if self.engine.binding_is_input(binding):
+                 self.context.set_binding_shape(idx, shape)
+             tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
+             self.tensors[binding] = tensor
+
+     def infer(self, feed_dict, stream, use_cuda_graph=False):
+         for name, buf in feed_dict.items():
+             self.tensors[name].copy_(buf)
+
+         for name, tensor in self.tensors.items():
+             self.context.set_tensor_address(name, tensor.data_ptr())
+
+         if use_cuda_graph:
+             if self.cuda_graph_instance is not None:
+                 CUASSERT(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream.ptr))
+                 CUASSERT(cudart.cudaStreamSynchronize(stream.ptr))
+             else:
+                 # do inference before CUDA graph capture
+                 noerror = self.context.execute_async_v3(stream.ptr)
+                 if not noerror:
+                     raise ValueError("ERROR: inference failed.")
+                 # capture cuda graph
+                 CUASSERT(
+                     cudart.cudaStreamBeginCapture(stream.ptr, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal)
+                 )
+                 self.context.execute_async_v3(stream.ptr)
+                 self.graph = CUASSERT(cudart.cudaStreamEndCapture(stream.ptr))
+                 self.cuda_graph_instance = CUASSERT(cudart.cudaGraphInstantiate(self.graph, 0))
+         else:
+             noerror = self.context.execute_async_v3(stream.ptr)
+             if not noerror:
+                 raise ValueError("ERROR: inference failed.")
+
+         return self.tensors
+
+
+ def decode_images(images: torch.Tensor):
+     images = (
+         ((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy()
+     )
+     return [Image.fromarray(x) for x in images]
+
+
+ def preprocess_image(image: Image.Image):
+     w, h = image.size
+     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+     image = image.resize((w, h))
+     init_image = np.array(image).astype(np.float32) / 255.0
+     init_image = init_image[None].transpose(0, 3, 1, 2)
+     init_image = torch.from_numpy(init_image).contiguous()
+     return 2.0 * init_image - 1.0
+
+
+ def prepare_mask_and_masked_image(image: Image.Image, mask: Image.Image):
+     if isinstance(image, Image.Image):
+         image = np.array(image.convert("RGB"))
+         image = image[None].transpose(0, 3, 1, 2)
+         image = torch.from_numpy(image).to(dtype=torch.float32).contiguous() / 127.5 - 1.0
+     if isinstance(mask, Image.Image):
+         mask = np.array(mask.convert("L"))
+         mask = mask.astype(np.float32) / 255.0
+         mask = mask[None, None]
+         mask[mask < 0.5] = 0
+         mask[mask >= 0.5] = 1
+         mask = torch.from_numpy(mask).to(dtype=torch.float32).contiguous()
+
+     masked_image = image * (mask < 0.5)
+
+     return mask, masked_image
+
+
+ def create_models(
+     model_id: str,
+     use_auth_token: Optional[str],
+     device: Union[str, torch.device],
+     max_batch_size: int,
+     unet_in_channels: int = 4,
+     embedding_dim: int = 768,
+ ):
+     models = {
+         "clip": CLIP(
+             hf_token=use_auth_token,
+             device=device,
+             max_batch_size=max_batch_size,
+             embedding_dim=embedding_dim,
+         ),
+         "unet": UNet(
+             hf_token=use_auth_token,
+             fp16=True,
+             device=device,
+             max_batch_size=max_batch_size,
+             embedding_dim=embedding_dim,
+             unet_dim=unet_in_channels,
+         ),
+         "vae": VAE(
+             hf_token=use_auth_token,
+             device=device,
+             max_batch_size=max_batch_size,
+             embedding_dim=embedding_dim,
+         ),
+         "vae_encoder": VAEEncoder(
+             hf_token=use_auth_token,
+             device=device,
+             max_batch_size=max_batch_size,
+             embedding_dim=embedding_dim,
+         ),
+     }
+     return models
+
+
+ def build_engine(
+     engine_path: str,
+     onnx_opt_path: str,
+     model_data: BaseModel,
+     opt_image_height: int,
+     opt_image_width: int,
+     opt_batch_size: int,
+     build_static_batch: bool = False,
+     build_dynamic_shape: bool = False,
+     build_all_tactics: bool = False,
+     build_enable_refit: bool = False,
+ ):
+     _, free_mem, _ = cudart.cudaMemGetInfo()
+     GiB = 2**30
+     if free_mem > 6 * GiB:
+         activation_carveout = 4 * GiB
+         max_workspace_size = free_mem - activation_carveout
+     else:
+         max_workspace_size = 0
+     engine = Engine(engine_path)
+     input_profile = model_data.get_input_profile(
+         opt_batch_size,
+         opt_image_height,
+         opt_image_width,
+         static_batch=build_static_batch,
+         static_shape=not build_dynamic_shape,
+     )
+     engine.build(
+         onnx_opt_path,
+         fp16=True,
+         input_profile=input_profile,
+         enable_refit=build_enable_refit,
+         enable_all_tactics=build_all_tactics,
+         workspace_size=max_workspace_size,
+     )
+
+     return engine
+
+
+ def export_onnx(
+     model,
+     onnx_path: str,
+     model_data: BaseModel,
+     opt_image_height: int,
+     opt_image_width: int,
+     opt_batch_size: int,
+     onnx_opset: int,
+ ):
+     with torch.inference_mode(), torch.autocast("cuda"):
+         inputs = model_data.get_sample_input(opt_batch_size, opt_image_height, opt_image_width)
+         torch.onnx.export(
+             model,
+             inputs,
+             onnx_path,
+             export_params=True,
+             opset_version=onnx_opset,
+             do_constant_folding=True,
+             input_names=model_data.get_input_names(),
+             output_names=model_data.get_output_names(),
+             dynamic_axes=model_data.get_dynamic_axes(),
+         )
+     del model
+     gc.collect()
+     torch.cuda.empty_cache()
+
+
+ def optimize_onnx(
+     onnx_path: str,
+     onnx_opt_path: str,
+     model_data: BaseModel,
+ ):
+     onnx_opt_graph = model_data.optimize(onnx.load(onnx_path))
+     onnx.save(onnx_opt_graph, onnx_opt_path)
+     del onnx_opt_graph
+     gc.collect()
+     torch.cuda.empty_cache()
streamv2v/image_filter.py ADDED
@@ -0,0 +1,45 @@
+ from typing import Optional
+ import random
+
+ import torch
+
+
+ class SimilarImageFilter:
+     def __init__(self, threshold: float = 0.98, max_skip_frame: float = 10) -> None:
+         self.threshold = threshold
+         self.prev_tensor = None
+         self.cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6)
+         self.max_skip_frame = max_skip_frame
+         self.skip_count = 0
+
+     def __call__(self, x: torch.Tensor) -> Optional[torch.Tensor]:
+         if self.prev_tensor is None:
+             self.prev_tensor = x.detach().clone()
+             return x
+         else:
+             cos_sim = self.cos(self.prev_tensor.reshape(-1), x.reshape(-1)).item()
+             sample = random.uniform(0, 1)
+             if self.threshold >= 1:
+                 skip_prob = 0
+             else:
+                 skip_prob = max(0, 1 - (1 - cos_sim) / (1 - self.threshold))
+
+             # not skip frame
+             if skip_prob < sample:
+                 self.prev_tensor = x.detach().clone()
+                 return x
+             # skip frame
+             else:
+                 if self.skip_count > self.max_skip_frame:
+                     self.skip_count = 0
+                     self.prev_tensor = x.detach().clone()
+                     return x
+                 else:
+                     self.skip_count += 1
+                     return None
+
+     def set_threshold(self, threshold: float) -> None:
+         self.threshold = threshold
+
+     def set_max_skip_frame(self, max_skip_frame: float) -> None:
+         self.max_skip_frame = max_skip_frame
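
A short usage sketch, assuming `frames` is an iterable of image tensors and `process` stands in for the downstream pipeline step (both hypothetical): near-duplicate frames are skipped probabilistically (the call returns `None`), and a frame is always passed through once `max_skip_frame` consecutive skips have accumulated.

    sim_filter = SimilarImageFilter(threshold=0.98, max_skip_frame=10)
    for frame in frames:            # `frames`: iterable of torch.Tensor images (assumed)
        kept = sim_filter(frame)
        if kept is None:
            continue                # frame judged too similar to the previous one; reuse the last output
        process(kept)               # `process`: placeholder for the downstream step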
streamv2v/image_utils.py ADDED
@@ -0,0 +1,173 @@
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import PIL.Image
+ import torch
+ import torchvision
+
+
+ def denormalize(images: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
+     """
+     Denormalize an image array to [0,1].
+     """
+     return (images / 2 + 0.5).clamp(0, 1)
+
+
+ def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
+     """
+     Convert a PyTorch tensor to a NumPy image.
+     """
+     images = images.cpu().permute(0, 2, 3, 1).float().numpy()
+     return images
+
+
+ def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
+     """
+     Convert a NumPy image or a batch of images to a PIL image.
+     """
+     if images.ndim == 3:
+         images = images[None, ...]
+     images = (images * 255).round().astype("uint8")
+     if images.shape[-1] == 1:
+         # special case for grayscale (single channel) images
+         pil_images = [
+             PIL.Image.fromarray(image.squeeze(), mode="L") for image in images
+         ]
+     else:
+         pil_images = [PIL.Image.fromarray(image) for image in images]
+
+     return pil_images
+
+
+ def postprocess_image(
+     image: torch.Tensor,
+     output_type: str = "pil",
+     do_denormalize: Optional[List[bool]] = None,
+ ) -> Union[torch.Tensor, np.ndarray, PIL.Image.Image]:
+     if not isinstance(image, torch.Tensor):
+         raise ValueError(
+             f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
+         )
+
+     if output_type == "latent":
+         return image
+
+     do_normalize_flg = True
+     if do_denormalize is None:
+         do_denormalize = [do_normalize_flg] * image.shape[0]
+
+     image = torch.stack(
+         [
+             denormalize(image[i]) if do_denormalize[i] else image[i]
+             for i in range(image.shape[0])
+         ]
+     )
+
+     if output_type == "pt":
+         return image
+
+     image = pt_to_numpy(image)
+
+     if output_type == "np":
+         return image
+
+     if output_type == "pil":
+         return numpy_to_pil(image)
+
+
+ def process_image(
+     image_pil: PIL.Image.Image, range: Tuple[int, int] = (-1, 1)
+ ) -> Tuple[torch.Tensor, PIL.Image.Image]:
+     image = torchvision.transforms.ToTensor()(image_pil)
+     r_min, r_max = range[0], range[1]
+     image = image * (r_max - r_min) + r_min
+     return image[None, ...], image_pil
+
+
+ def pil2tensor(image_pil: PIL.Image.Image) -> torch.Tensor:
+     height = image_pil.height
+     width = image_pil.width
+     imgs = []
+     img, _ = process_image(image_pil)
+     imgs.append(img)
+     imgs = torch.vstack(imgs)
+     images = torch.nn.functional.interpolate(
+         imgs, size=(height, width), mode="bilinear"
+     )
+     image_tensors = images.to(torch.float16)
+     return image_tensors
+
+
+ ### Optical flow utils
+
+ def coords_grid(b, h, w, homogeneous=False, device=None):
+     y, x = torch.meshgrid(torch.arange(h), torch.arange(w))  # [H, W]
+
+     stacks = [x, y]
+
+     if homogeneous:
+         ones = torch.ones_like(x)  # [H, W]
+         stacks.append(ones)
+
+     grid = torch.stack(stacks, dim=0).float()  # [2, H, W] or [3, H, W]
+
+     grid = grid[None].repeat(b, 1, 1, 1)  # [B, 2, H, W] or [B, 3, H, W]
+
+     if device is not None:
+         grid = grid.to(device)
+
+     return grid
+
+
+ def flow_warp(feature, flow, mask=False, padding_mode='zeros'):
+     b, c, h, w = feature.size()
+     assert flow.size(1) == 2
+
+     grid = coords_grid(b, h, w).to(flow.device) + flow  # [B, 2, H, W]
+
+     return bilinear_sample(feature, grid, padding_mode=padding_mode,
+                            return_mask=mask)
+
+
+ def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False):
+     # img: [B, C, H, W]
+     # sample_coords: [B, 2, H, W] in image scale
+     if sample_coords.size(1) != 2:  # [B, H, W, 2]
+         sample_coords = sample_coords.permute(0, 3, 1, 2)
+
+     b, _, h, w = sample_coords.shape
+
+     # Normalize to [-1, 1]
+     x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1
+     y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1
+
+     grid = torch.stack([x_grid, y_grid], dim=-1)  # [B, H, W, 2]
+
+     img = torch.nn.functional.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True)
+
+     if return_mask:
+         mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1)  # [B, H, W]
+
+         return img, mask
+
+     return img
+
+
+ def forward_backward_consistency_check(fwd_flow, bwd_flow,
+                                        alpha=0.1,
+                                        beta=0.5
+                                        ):
+     # fwd_flow, bwd_flow: [B, 2, H, W]
+     # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837)
+     assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
+     assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2
+     flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1)  # [B, H, W]
+
+     warped_bwd_flow = flow_warp(bwd_flow, fwd_flow)  # [B, 2, H, W]
+     warped_fwd_flow = flow_warp(fwd_flow, bwd_flow)  # [B, 2, H, W]
+
+     diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1)  # [B, H, W]
+     diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)
+
+     threshold = alpha * flow_mag + beta
+
+     fwd_occ = (diff_fwd > threshold).float()  # [B, H, W]
+     bwd_occ = (diff_bwd > threshold).float()
+
+     return fwd_occ, bwd_occ
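
A sketch of the optical-flow helpers above, with random tensors standing in for real flow estimates (shapes are illustrative): `flow_warp` resamples a feature map along a flow field, and `forward_backward_consistency_check` marks pixels whose forward and backward flows disagree as occluded.

    import torch

    b, h, w = 1, 64, 64
    feature = torch.randn(b, 3, h, w)
    fwd_flow = torch.randn(b, 2, h, w)   # stand-in for a real forward flow estimate
    bwd_flow = torch.randn(b, 2, h, w)   # stand-in for a real backward flow estimate

    warped = flow_warp(feature, fwd_flow)                                       # [B, 3, H, W]
    fwd_occ, bwd_occ = forward_backward_consistency_check(fwd_flow, bwd_flow)   # [B, H, W] masks in {0, 1}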
streamv2v/models/__init__.py ADDED
File without changes
streamv2v/models/attention_processor.py ADDED
@@ -0,0 +1,352 @@
1
+ from importlib import import_module
2
+ from typing import Callable, Optional, Union
3
+ from collections import deque
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+ from diffusers.models.attention_processor import Attention
10
+ from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging
11
+ from diffusers.utils.import_utils import is_xformers_available
12
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
13
+ from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer
14
+
15
+ from .utils import get_nn_feats, random_bipartite_soft_matching
16
+
17
+ if is_xformers_available():
18
+ import xformers
19
+ import xformers.ops
20
+ else:
21
+ xformers = None
22
+
23
+ class CachedSTAttnProcessor2_0:
24
+ r"""
25
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
26
+ """
27
+
28
+ def __init__(self, name=None, use_feature_injection=False,
29
+ feature_injection_strength=0.8,
30
+ feature_similarity_threshold=0.98,
31
+ interval=4,
32
+ max_frames=1,
33
+ use_tome_cache=False,
34
+ tome_metric="keys",
35
+ use_grid=False,
36
+ tome_ratio=0.5):
37
+ if not hasattr(F, "scaled_dot_product_attention"):
38
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
39
+ self.name = name
40
+ self.use_feature_injection = use_feature_injection
41
+ self.fi_strength = feature_injection_strength
42
+ self.threshold = feature_similarity_threshold
43
+ self.zero_tensor = torch.tensor(0)
44
+ self.frame_id = torch.tensor(0)
45
+ self.interval = torch.tensor(interval)
46
+ self.max_frames = max_frames
47
+ self.cached_key = None
48
+ self.cached_value = None
49
+ self.cached_output = None
50
+ self.use_tome_cache = use_tome_cache
51
+ self.tome_metric = tome_metric
52
+ self.use_grid = use_grid
53
+ self.tome_ratio = tome_ratio
54
+
55
+ def _tome_step_kvout(self, keys, values, outputs):
56
+ keys = torch.cat([self.cached_key, keys], dim=1)
57
+ values = torch.cat([self.cached_value, values], dim=1)
58
+ outputs = torch.cat([self.cached_output, outputs], dim=1)
59
+ m_kv_out, _, _= random_bipartite_soft_matching(metric=keys, use_grid=self.use_grid, ratio=self.tome_ratio)
60
+ compact_keys, compact_values, compact_outputs = m_kv_out(keys, values, outputs)
61
+ self.cached_key = compact_keys
62
+ self.cached_value = compact_values
63
+ self.cached_output = compact_outputs
64
+
65
+ def __call__(
66
+ self,
67
+ attn: Attention,
68
+ hidden_states: torch.FloatTensor,
69
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
70
+ attention_mask: Optional[torch.FloatTensor] = None,
71
+ temb: Optional[torch.FloatTensor] = None,
72
+ scale: float = 1.0,
73
+ ) -> torch.FloatTensor:
74
+ residual = hidden_states
75
+ if attn.spatial_norm is not None:
76
+ hidden_states = attn.spatial_norm(hidden_states, temb)
77
+
78
+ input_ndim = hidden_states.ndim
79
+
80
+ if input_ndim == 4:
81
+ batch_size, channel, height, width = hidden_states.shape
82
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
83
+
84
+ batch_size, sequence_length, _ = (
85
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
86
+ )
87
+
88
+ if attention_mask is not None:
89
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
90
+ # scaled_dot_product_attention expects attention_mask shape to be
91
+ # (batch, heads, source_length, target_length)
92
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
93
+
94
+ if attn.group_norm is not None:
95
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
96
+
97
+ args = () if USE_PEFT_BACKEND else (scale,)
98
+ query = attn.to_q(hidden_states, *args)
99
+
100
+ is_selfattn = False
101
+ if encoder_hidden_states is None:
102
+ is_selfattn = True
103
+ encoder_hidden_states = hidden_states
104
+ elif attn.norm_cross:
105
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
106
+
107
+ key = attn.to_k(encoder_hidden_states, *args)
108
+ value = attn.to_v(encoder_hidden_states, *args)
109
+
110
+ if is_selfattn:
111
+ cached_key = key.clone()
112
+ cached_value = value.clone()
113
+
114
+ # Avoid data-dependent Python branching where possible: keep frame_id/interval as tensors so the dynamic graph can be replaced with a static one (e.g. for ONNX export).
115
+ if torch.equal(self.frame_id, self.zero_tensor):
116
+ # ONNX
117
+ self.cached_key = cached_key
118
+ self.cached_value = cached_value
119
+
120
+ key = torch.cat([key, self.cached_key], dim=1)
121
+ value = torch.cat([value, self.cached_value], dim=1)
122
+
123
+ inner_dim = key.shape[-1]
124
+ head_dim = inner_dim // attn.heads
125
+
126
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
127
+
128
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
129
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
130
+
131
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
132
+ # TODO: add support for attn.scale when we move to Torch 2.1
133
+ hidden_states = F.scaled_dot_product_attention(
134
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
135
+ )
136
+
137
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
138
+ hidden_states = hidden_states.to(query.dtype)
139
+
140
+ # linear proj
141
+ hidden_states = attn.to_out[0](hidden_states, *args)
142
+ # dropout
143
+ hidden_states = attn.to_out[1](hidden_states)
144
+
145
+ if input_ndim == 4:
146
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
147
+
148
+ if attn.residual_connection:
149
+ hidden_states = hidden_states + residual
150
+
151
+ hidden_states = hidden_states / attn.rescale_output_factor
152
+
153
+ if is_selfattn:
154
+ cached_output = hidden_states.clone()
155
+
156
+ if torch.equal(self.frame_id, self.zero_tensor):
157
+ self.cached_output = cached_output
158
+
159
+ if self.use_feature_injection and ("up_blocks.0" in self.name or "up_blocks.1" in self.name or 'mid_block' in self.name):
160
+ nn_hidden_states = get_nn_feats(hidden_states, self.cached_output, threshold=self.threshold)
161
+ hidden_states = hidden_states * (1-self.fi_strength) + self.fi_strength * nn_hidden_states
162
+
163
+ mod_result = torch.remainder(self.frame_id, self.interval)
164
+ if torch.equal(mod_result, self.zero_tensor) and is_selfattn:
165
+ self._tome_step_kvout(cached_key, cached_value, cached_output)
166
+
167
+ self.frame_id = self.frame_id + 1
168
+
169
+ return hidden_states
170
+
171
+
172
+
173
+ class CachedSTXFormersAttnProcessor:
174
+ r"""
175
+ Processor for memory-efficient attention using xFormers that caches self-attention keys, values, and outputs across frames (optionally compacted via token merging) for temporally consistent video-to-video generation.
176
+
177
+ Args:
178
+ attention_op (`Callable`, *optional*, defaults to `None`):
179
+ The base
180
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
181
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
182
+ operator.
183
+ """
184
+
185
+ def __init__(self, attention_op: Optional[Callable] = None, name=None,
186
+ use_feature_injection=False, feature_injection_strength=0.8, feature_similarity_threshold=0.98,
187
+ interval=4, max_frames=4, use_tome_cache=False, tome_metric="keys", use_grid=False, tome_ratio=0.5):
188
+ self.attention_op = attention_op
189
+ self.name = name
190
+ self.use_feature_injection = use_feature_injection
191
+ self.fi_strength = feature_injection_strength
192
+ self.threshold = feature_similarity_threshold
193
+ self.frame_id = 0
194
+ self.interval = interval
195
+ self.cached_key = deque(maxlen=max_frames)
196
+ self.cached_value = deque(maxlen=max_frames)
197
+ self.cached_output = deque(maxlen=max_frames)
198
+ self.use_tome_cache = use_tome_cache
199
+ self.tome_metric = tome_metric
200
+ self.use_grid = use_grid
201
+ self.tome_ratio = tome_ratio
202
+
203
+ def _tome_step_kvout(self, keys, values, outputs):
204
+ if len(self.cached_value) == 1:
205
+ keys = torch.cat(list(self.cached_key) + [keys], dim=1)
206
+ values = torch.cat(list(self.cached_value) + [values], dim=1)
207
+ outputs = torch.cat(list(self.cached_output) + [outputs], dim=1)
208
+ m_kv_out, _, _ = random_bipartite_soft_matching(metric=eval(self.tome_metric), use_grid=self.use_grid, ratio=self.tome_ratio)
209
+ compact_keys, compact_values, compact_outputs = m_kv_out(keys, values, outputs)
210
+ self.cached_key.append(compact_keys)
211
+ self.cached_value.append(compact_values)
212
+ self.cached_output.append(compact_outputs)
213
+ else:
214
+ self.cached_key.append(keys)
215
+ self.cached_value.append(values)
216
+ self.cached_output.append(outputs)
217
+
218
+ def _tome_step_kv(self, keys, values):
219
+ if len(self.cached_value) == 1:
220
+ keys = torch.cat(list(self.cached_key) + [keys], dim=1)
221
+ values = torch.cat(list(self.cached_value) + [values], dim=1)
222
+ _, m_kv, _ = random_bipartite_soft_matching(metric=eval(self.tome_metric), use_grid=self.use_grid, ratio=self.tome_ratio)
223
+ compact_keys, compact_values = m_kv(keys, values)
224
+ self.cached_key.append(compact_keys)
225
+ self.cached_value.append(compact_values)
226
+ else:
227
+ self.cached_key.append(keys)
228
+ self.cached_value.append(values)
229
+
230
+ def _tome_step_out(self, outputs):
231
+ if len(self.cached_value) == 1:
232
+ outputs = torch.cat(list(self.cached_output) + [outputs], dim=1)
233
+ _, _, m_out = random_bipartite_soft_matching(metric=outputs, use_grid=self.use_grid, ratio=self.tome_ratio)
234
+ compact_outputs = m_out(outputs)
235
+ self.cached_output.append(compact_outputs)
236
+ else:
237
+ self.cached_output.append(outputs)
238
+
239
+ def __call__(
240
+ self,
241
+ attn: Attention,
242
+ hidden_states: torch.FloatTensor,
243
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
244
+ attention_mask: Optional[torch.FloatTensor] = None,
245
+ temb: Optional[torch.FloatTensor] = None,
246
+ scale: float = 1.0,
247
+ ) -> torch.FloatTensor:
248
+ residual = hidden_states
249
+
250
+ args = () if USE_PEFT_BACKEND else (scale,)
251
+
252
+ if attn.spatial_norm is not None:
253
+ hidden_states = attn.spatial_norm(hidden_states, temb)
254
+
255
+ input_ndim = hidden_states.ndim
256
+
257
+ if input_ndim == 4:
258
+ batch_size, channel, height, width = hidden_states.shape
259
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
260
+
261
+ batch_size, key_tokens, _ = (
262
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
263
+ )
264
+
265
+ attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
266
+ if attention_mask is not None:
267
+ # expand our mask's singleton query_tokens dimension:
268
+ # [batch*heads, 1, key_tokens] ->
269
+ # [batch*heads, query_tokens, key_tokens]
270
+ # so that it can be added as a bias onto the attention scores that xformers computes:
271
+ # [batch*heads, query_tokens, key_tokens]
272
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
273
+ _, query_tokens, _ = hidden_states.shape
274
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
275
+
276
+ if attn.group_norm is not None:
277
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
278
+
279
+ query = attn.to_q(hidden_states, *args)
280
+
281
+ is_selfattn = False
282
+ if encoder_hidden_states is None:
283
+ is_selfattn = True
284
+ encoder_hidden_states = hidden_states
285
+ elif attn.norm_cross:
286
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
287
+
288
+ key = attn.to_k(encoder_hidden_states, *args)
289
+ value = attn.to_v(encoder_hidden_states, *args)
290
+
291
+ if is_selfattn:
292
+ cached_key = key.clone()
293
+ cached_value = value.clone()
294
+
295
+ if len(self.cached_key) > 0:
296
+ key = torch.cat([key] + list(self.cached_key), dim=1)
297
+ value = torch.cat([value] + list(self.cached_value), dim=1)
298
+
299
+ ## Code for storing and visualizing features
300
+ # if self.frame_id % self.interval == 0:
301
+ # # if "down_blocks.0" in self.name or "up_blocks.3" in self.name:
302
+ # # feats = {
303
+ # # "hidden_states": hidden_states.clone().cpu(),
304
+ # # "query": query.clone().cpu(),
305
+ # # "key": cached_key.cpu(),
306
+ # # "value": cached_value.cpu(),
307
+ # # }
308
+ # # torch.save(feats, f'./outputs/self_attn_feats_SD/{self.name}.frame{self.frame_id}.pt')
309
+ # if self.use_tome_cache:
310
+ # cached_key, cached_value = self._tome_step(cached_key, cached_value)
311
+
312
+ query = attn.head_to_batch_dim(query).contiguous()
313
+ key = attn.head_to_batch_dim(key).contiguous()
314
+ value = attn.head_to_batch_dim(value).contiguous()
315
+
316
+ hidden_states = xformers.ops.memory_efficient_attention(
317
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
318
+ )
319
+ hidden_states = hidden_states.to(query.dtype)
320
+ hidden_states = attn.batch_to_head_dim(hidden_states)
321
+
322
+ # linear proj
323
+ hidden_states = attn.to_out[0](hidden_states, *args)
324
+ # dropout
325
+ hidden_states = attn.to_out[1](hidden_states)
326
+
327
+ if input_ndim == 4:
328
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
329
+
330
+ if attn.residual_connection:
331
+ hidden_states = hidden_states + residual
332
+
333
+ hidden_states = hidden_states / attn.rescale_output_factor
334
+ if is_selfattn:
335
+ cached_output = hidden_states.clone()
336
+ if self.use_feature_injection and ("up_blocks.0" in self.name or "up_blocks.1" in self.name or 'mid_block' in self.name):
337
+ if len(self.cached_output) > 0:
338
+ nn_hidden_states = get_nn_feats(hidden_states, self.cached_output, threshold=self.threshold)
339
+ hidden_states = hidden_states * (1-self.fi_strength) + self.fi_strength * nn_hidden_states
340
+
341
+ if self.frame_id % self.interval == 0:
342
+ if is_selfattn:
343
+ if self.use_tome_cache:
344
+ self._tome_step_kvout(cached_key, cached_value, cached_output)
345
+ else:
346
+ self.cached_key.append(cached_key)
347
+ self.cached_value.append(cached_value)
348
+ self.cached_output.append(cached_output)
349
+ self.frame_id += 1
350
+
351
+ return hidden_states
352
+
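The two processors above are drop-in replacements for the stock diffusers attention processors. A minimal usage sketch (illustrative, not included in this commit) of how they might be attached to a UNet follows; it assumes this file is importable as streamv2v.models.attention_processor, that xformers is installed for the xFormers variant, and that the checkpoint id is only an example.

import torch
from diffusers import UNet2DConditionModel

from streamv2v.models.attention_processor import CachedSTXFormersAttnProcessor

# Load a UNet (illustrative checkpoint) and move it to the GPU.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
).to("cuda")

# One processor instance per attention module, keyed by module name, so every
# layer keeps its own key/value/output cache across frames.
attn_processors = {
    name: CachedSTXFormersAttnProcessor(
        name=name,
        use_feature_injection=True,
        interval=4,
        max_frames=4,
        use_tome_cache=True,
    )
    for name in unet.attn_processors.keys()
}
unet.set_attn_processor(attn_processors)
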
streamv2v/models/utils.py ADDED
@@ -0,0 +1,127 @@
1
+ from collections import deque
2
+ from typing import Tuple, Callable
3
+
4
+ from einops import rearrange
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ def get_nn_feats(x, y, threshold=0.9):
9
+
10
+ if type(x) is deque:
11
+ x = torch.cat(list(x), dim=1)
12
+ if type(y) is deque:
13
+ y = torch.cat(list(y), dim=1)
14
+
15
+ x_norm = F.normalize(x, p=2, dim=-1)
16
+ y_norm = F.normalize(y, p=2, dim=-1)
17
+
18
+ cosine_similarity = torch.matmul(x_norm, y_norm.transpose(1, 2))
19
+
20
+ max_cosine_values, nearest_neighbors_indices = torch.max(cosine_similarity, dim=-1)
21
+ mask = max_cosine_values < threshold
22
+ # print('mask ratio', torch.sum(mask)/x.shape[0]/x.shape[1])
23
+ indices_expanded = nearest_neighbors_indices.unsqueeze(-1).expand(-1, -1, x_norm.size(-1))
24
+ nearest_neighbor_tensor = torch.gather(y, 1, indices_expanded)
25
+ selected_tensor = torch.where(mask.unsqueeze(-1), x, nearest_neighbor_tensor)
26
+
27
+ return selected_tensor
28
+
29
+ def get_nn_latent(x, y, threshold=0.9):
30
+
31
+ assert len(x.shape) == 4
32
+ _, c, h, w = x.shape
33
+ x_ = rearrange(x, 'n c h w -> n (h w) c')
34
+ y_ = []
35
+ for i in range(len(y)):
36
+ y_.append(rearrange(y[i], 'n c h w -> n (h w) c'))
37
+ y_ = torch.cat(y_, dim=1)
38
+ x_norm = F.normalize(x_, p=2, dim=-1)
39
+ y_norm = F.normalize(y_, p=2, dim=-1)
40
+
41
+ cosine_similarity = torch.matmul(x_norm, y_norm.transpose(1, 2))
42
+
43
+ max_cosine_values, nearest_neighbors_indices = torch.max(cosine_similarity, dim=-1)
44
+ mask = max_cosine_values < threshold
45
+ indices_expanded = nearest_neighbors_indices.unsqueeze(-1).expand(-1, -1, x_norm.size(-1))
46
+ nearest_neighbor_tensor = torch.gather(y_, 1, indices_expanded)
47
+
48
+ # Use values from x where the cosine similarity is below the threshold
49
+ x_expanded = x_.expand_as(nearest_neighbor_tensor)
50
+ selected_tensor = torch.where(mask.unsqueeze(-1), x_expanded, nearest_neighbor_tensor)
51
+
52
+ selected_tensor = rearrange(selected_tensor, 'n (h w) c -> n c h w', h=h, w=w, c=c)
53
+
54
+ return selected_tensor
55
+
56
+
57
+ def random_bipartite_soft_matching(
58
+ metric: torch.Tensor, use_grid: bool = False, ratio: float = 0.5
59
+ ) -> Tuple[Callable, Callable, Callable]:
60
+ """
61
+ Applies ToMe with the two sets as (r chosen randomly, the rest).
62
+ Input size is [batch, tokens, channels].
63
+
64
+ This reduces the number of tokens by a factor of ratio: the merged set keeps (1 - ratio) * N tokens.
65
+ """
66
+
67
+ with torch.no_grad():
68
+ B, N, _ = metric.shape
69
+ if use_grid:
70
+ assert ratio == 0.5
71
+ sample = torch.randint(2, size=(B, N//2, 1), device=metric.device)
72
+ sample_alternate = 1 - sample
73
+ grid = torch.arange(0, N, 2).view(1, N//2, 1).to(device=metric.device)
74
+ grid = grid.repeat(4, 1, 1)
75
+ rand_idx = torch.cat([sample + grid, sample_alternate + grid], dim = 1)
76
+ else:
77
+ rand_idx = torch.rand(B, N, 1, device=metric.device).argsort(dim=1)
78
+ r = int(ratio * N)
79
+ a_idx = rand_idx[:, :r, :]
80
+ b_idx = rand_idx[:, r:, :]
81
+ def split(x):
82
+ C = x.shape[-1]
83
+ a = x.gather(dim=1, index=a_idx.expand(B, r, C))
84
+ b = x.gather(dim=1, index=b_idx.expand(B, N - r, C))
85
+ return a, b
86
+
87
+ metric = metric / metric.norm(dim=-1, keepdim=True)
88
+ a, b = split(metric)
89
+ scores = a @ b.transpose(-1, -2)
90
+
91
+ _, dst_idx = scores.max(dim=-1)
92
+ dst_idx = dst_idx[..., None]
93
+
94
+ def merge_kv_out(keys: torch.Tensor, values: torch.Tensor, outputs: torch.Tensor, mode="mean") -> torch.Tensor:
95
+ src_keys, dst_keys = split(keys)
96
+ C_keys = src_keys.shape[-1]
97
+ dst_keys = dst_keys.scatter_reduce(-2, dst_idx.expand(B, r, C_keys), src_keys, reduce=mode)
98
+
99
+ src_values, dst_values = split(values)
100
+ C_values = src_values.shape[-1]
101
+ dst_values = dst_values.scatter_reduce(-2, dst_idx.expand(B, r, C_values), src_values, reduce=mode)
102
+
103
+ src_outputs, dst_outputs = split(outputs)
104
+ C_outputs = src_outputs.shape[-1]
105
+ dst_outputs = dst_outputs.scatter_reduce(-2, dst_idx.expand(B, r, C_outputs), src_outputs, reduce=mode)
106
+
107
+ return dst_keys, dst_values, dst_outputs
108
+
109
+ def merge_kv(keys: torch.Tensor, values: torch.Tensor, mode="mean") -> torch.Tensor:
110
+ src_keys, dst_keys = split(keys)
111
+ C_keys = src_keys.shape[-1]
112
+ dst_keys = dst_keys.scatter_reduce(-2, dst_idx.expand(B, r, C_keys), src_keys, reduce=mode)
113
+
114
+ src_values, dst_values = split(values)
115
+ C_values = src_values.shape[-1]
116
+ dst_values = dst_values.scatter_reduce(-2, dst_idx.expand(B, r, C_values), src_values, reduce=mode)
117
+
118
+ return dst_keys, dst_values
119
+
120
+ def merge_out(outputs: torch.Tensor, mode="mean") -> torch.Tensor:
121
+ src_outputs, dst_outputs = split(outputs)
122
+ C_outputs = src_outputs.shape[-1]
123
+ dst_outputs = dst_outputs.scatter_reduce(-2, dst_idx.expand(B, r, C_outputs), src_outputs, reduce=mode)
124
+
125
+ return dst_outputs
126
+
127
+ return merge_kv_out, merge_kv, merge_out
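A short sketch exercising random_bipartite_soft_matching on dummy tensors; the shapes are arbitrary and the package is assumed to be importable as streamv2v with its dependencies installed.

import torch

from streamv2v.models.utils import random_bipartite_soft_matching

B, N, C = 1, 4096, 64  # batch, tokens, channels (illustrative)
keys = torch.randn(B, N, C)
values = torch.randn(B, N, C)
outputs = torch.randn(B, N, C)

# Build the merge functions from the keys, then compact all three tensors together.
merge_kv_out, merge_kv, merge_out = random_bipartite_soft_matching(
    metric=keys, use_grid=False, ratio=0.5
)
compact_keys, compact_values, compact_outputs = merge_kv_out(keys, values, outputs)
print(compact_keys.shape)  # torch.Size([1, 2048, 64]) for ratio=0.5
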
streamv2v/pip_utils.py ADDED
@@ -0,0 +1,52 @@
1
+ import importlib
2
+ import importlib.util
3
+ import os
4
+ import subprocess
5
+ import sys
6
+ from typing import Dict, Optional
7
+
8
+ from packaging.version import Version
9
+
10
+
11
+ python = sys.executable
12
+ index_url = os.environ.get("INDEX_URL", "")
13
+
14
+
15
+ def version(package: str) -> Optional[Version]:
16
+ try:
17
+ return Version(importlib.import_module(package).__version__)
18
+ except ModuleNotFoundError:
19
+ return None
20
+
21
+
22
+ def is_installed(package: str) -> bool:
23
+ try:
24
+ spec = importlib.util.find_spec(package)
25
+ except ModuleNotFoundError:
26
+ return False
27
+
28
+ return spec is not None
29
+
30
+
31
+ def run_python(command: str, env: Dict[str, str] = None) -> str:
32
+ run_kwargs = {
33
+ "args": f"\"{python}\" {command}",
34
+ "shell": True,
35
+ "env": os.environ if env is None else env,
36
+ "encoding": "utf8",
37
+ "errors": "ignore",
38
+ }
39
+
40
+ print(run_kwargs["args"])
41
+
42
+ result = subprocess.run(**run_kwargs)
43
+
44
+ if result.returncode != 0:
45
+ print(f"Error running command: {command}", file=sys.stderr)
46
+ raise RuntimeError(f"Error running command: {command}")
47
+
48
+ return result.stdout or ""
49
+
50
+
51
+ def run_pip(command: str, env: Dict[str, str] = None) -> str:
52
+ return run_python(f"-m pip {command}", env)
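A brief sketch of the helpers above, assuming the package is importable as streamv2v; the polygraphy pin mirrors the TensorRT installer later in this commit and is only an example.

from streamv2v.pip_utils import is_installed, run_pip, version

# Install a dependency through the current interpreter only if it is missing.
if not is_installed("polygraphy"):
    run_pip("install polygraphy==0.47.1 --extra-index-url https://pypi.ngc.nvidia.com")

print(version("polygraphy"))  # packaging Version object, or None if the module is absent
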
streamv2v/pipeline.py ADDED
@@ -0,0 +1,495 @@
1
+ import glob
2
+ import os
3
+ import time
4
+ from typing import List, Optional, Union, Any, Dict, Tuple, Literal
5
+ from collections import deque
6
+
7
+ import numpy as np
8
+ import PIL.Image
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torchvision.models.optical_flow import raft_small
12
+
13
+ from diffusers import LCMScheduler, StableDiffusionPipeline
14
+ from diffusers.image_processor import VaeImageProcessor
15
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
16
+ retrieve_latents,
17
+ )
18
+ from .image_utils import postprocess_image, forward_backward_consistency_check
19
+ from .models.utils import get_nn_latent
20
+ from .image_filter import SimilarImageFilter
21
+
22
+
23
+ class StreamV2V:
24
+ def __init__(
25
+ self,
26
+ pipe: StableDiffusionPipeline,
27
+ t_index_list: List[int],
28
+ torch_dtype: torch.dtype = torch.float16,
29
+ width: int = 512,
30
+ height: int = 512,
31
+ do_add_noise: bool = True,
32
+ use_denoising_batch: bool = True,
33
+ frame_buffer_size: int = 1,
34
+ cfg_type: Literal["none", "full", "self", "initialize"] = "self",
35
+ ) -> None:
36
+ self.device = pipe.device
37
+ self.dtype = torch_dtype
38
+ self.generator = None
39
+
40
+ self.height = height
41
+ self.width = width
42
+
43
+ self.latent_height = int(height // pipe.vae_scale_factor)
44
+ self.latent_width = int(width // pipe.vae_scale_factor)
45
+
46
+ self.frame_bff_size = frame_buffer_size
47
+ self.denoising_steps_num = len(t_index_list)
48
+
49
+ self.cfg_type = cfg_type
50
+
51
+ if use_denoising_batch:
52
+ self.batch_size = self.denoising_steps_num * frame_buffer_size
53
+ if self.cfg_type == "initialize":
54
+ self.trt_unet_batch_size = (
55
+ self.denoising_steps_num + 1
56
+ ) * self.frame_bff_size
57
+ elif self.cfg_type == "full":
58
+ self.trt_unet_batch_size = (
59
+ 2 * self.denoising_steps_num * self.frame_bff_size
60
+ )
61
+ else:
62
+ self.trt_unet_batch_size = self.denoising_steps_num * frame_buffer_size
63
+ else:
64
+ self.trt_unet_batch_size = self.frame_bff_size
65
+ self.batch_size = frame_buffer_size
66
+
67
+ self.t_list = t_index_list
68
+
69
+ self.do_add_noise = do_add_noise
70
+ self.use_denoising_batch = use_denoising_batch
71
+
72
+ self.similar_image_filter = False
73
+ self.similar_filter = SimilarImageFilter()
74
+ self.prev_image_tensor = None
75
+ self.prev_x_t_latent = None
76
+ self.prev_image_result = None
77
+
78
+ self.pipe = pipe
79
+ self.image_processor = VaeImageProcessor(pipe.vae_scale_factor)
80
+
81
+ self.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
82
+ self.text_encoder = pipe.text_encoder
83
+ self.unet = pipe.unet
84
+ self.vae = pipe.vae
85
+
86
+ self.flow_model = raft_small(pretrained=True, progress=False).to(device=pipe.device).eval()
87
+
88
+ self.cached_x_t_latent = deque(maxlen=4)
89
+
90
+ self.inference_time_ema = 0
91
+
92
+ def load_lcm_lora(
93
+ self,
94
+ pretrained_model_name_or_path_or_dict: Union[
95
+ str, Dict[str, torch.Tensor]
96
+ ] = "latent-consistency/lcm-lora-sdv1-5",
97
+ adapter_name: Optional[Any] = 'lcm',
98
+ **kwargs,
99
+ ) -> None:
100
+ self.pipe.load_lora_weights(
101
+ pretrained_model_name_or_path_or_dict, adapter_name, **kwargs
102
+ )
103
+
104
+ def load_lora(
105
+ self,
106
+ pretrained_lora_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
107
+ adapter_name: Optional[Any] = None,
108
+ **kwargs,
109
+ ) -> None:
110
+ self.pipe.load_lora_weights(
111
+ pretrained_lora_model_name_or_path_or_dict, adapter_name, **kwargs
112
+ )
113
+
114
+ def fuse_lora(
115
+ self,
116
+ fuse_unet: bool = True,
117
+ fuse_text_encoder: bool = True,
118
+ lora_scale: float = 1.0,
119
+ safe_fusing: bool = False,
120
+ ) -> None:
121
+ self.pipe.fuse_lora(
122
+ fuse_unet=fuse_unet,
123
+ fuse_text_encoder=fuse_text_encoder,
124
+ lora_scale=lora_scale,
125
+ safe_fusing=safe_fusing,
126
+ )
127
+
128
+ def enable_similar_image_filter(self, threshold: float = 0.98, max_skip_frame: float = 10) -> None:
129
+ self.similar_image_filter = True
130
+ self.similar_filter.set_threshold(threshold)
131
+ self.similar_filter.set_max_skip_frame(max_skip_frame)
132
+
133
+ def disable_similar_image_filter(self) -> None:
134
+ self.similar_image_filter = False
135
+
136
+ @torch.no_grad()
137
+ def prepare(
138
+ self,
139
+ prompt: str,
140
+ negative_prompt: str = "",
141
+ num_inference_steps: int = 50,
142
+ guidance_scale: float = 1.2,
143
+ delta: float = 1.0,
144
+ generator: Optional[torch.Generator] = torch.Generator(),
145
+ seed: int = 2,
146
+ ) -> None:
147
+ self.generator = generator
148
+ self.generator.manual_seed(seed)
149
+ # initialize x_t_latent (it can be any random tensor)
150
+ if self.denoising_steps_num > 1:
151
+ self.x_t_latent_buffer = torch.zeros(
152
+ (
153
+ (self.denoising_steps_num - 1) * self.frame_bff_size,
154
+ 4,
155
+ self.latent_height,
156
+ self.latent_width,
157
+ ),
158
+ dtype=self.dtype,
159
+ device=self.device,
160
+ )
161
+ else:
162
+ self.x_t_latent_buffer = None
163
+
164
+ if self.cfg_type == "none":
165
+ self.guidance_scale = 1.0
166
+ else:
167
+ self.guidance_scale = guidance_scale
168
+ self.delta = delta
169
+
170
+ do_classifier_free_guidance = False
171
+ if self.guidance_scale > 1.0:
172
+ do_classifier_free_guidance = True
173
+
174
+ encoder_output = self.pipe.encode_prompt(
175
+ prompt=prompt,
176
+ device=self.device,
177
+ num_images_per_prompt=1,
178
+ do_classifier_free_guidance=True,
179
+ negative_prompt=negative_prompt,
180
+ )
181
+
182
+ self.prompt_embeds = encoder_output[0].repeat(self.batch_size, 1, 1)
183
+ self.null_prompt_embeds = encoder_output[1]
184
+
185
+ if self.use_denoising_batch and self.cfg_type == "full":
186
+ uncond_prompt_embeds = encoder_output[1].repeat(self.batch_size, 1, 1)
187
+ elif self.cfg_type == "initialize":
188
+ uncond_prompt_embeds = encoder_output[1].repeat(self.frame_bff_size, 1, 1)
189
+
190
+ if self.guidance_scale > 1.0 and (
191
+ self.cfg_type == "initialize" or self.cfg_type == "full"
192
+ ):
193
+ self.prompt_embeds = torch.cat(
194
+ [uncond_prompt_embeds, self.prompt_embeds], dim=0
195
+ )
196
+
197
+ self.scheduler.set_timesteps(num_inference_steps, self.device)
198
+ self.timesteps = self.scheduler.timesteps.to(self.device)
199
+
200
+ # Build the sub-timestep list from the indices in t_list and the corresponding values in self.timesteps
201
+ self.sub_timesteps = []
202
+ for t in self.t_list:
203
+ self.sub_timesteps.append(self.timesteps[t])
204
+
205
+ sub_timesteps_tensor = torch.tensor(
206
+ self.sub_timesteps, dtype=torch.long, device=self.device
207
+ )
208
+ self.sub_timesteps_tensor = torch.repeat_interleave(
209
+ sub_timesteps_tensor,
210
+ repeats=self.frame_bff_size if self.use_denoising_batch else 1,
211
+ dim=0,
212
+ )
213
+
214
+ self.init_noise = torch.randn(
215
+ (self.batch_size, 4, self.latent_height, self.latent_width),
216
+ generator=generator,
217
+ ).to(device=self.device, dtype=self.dtype)
218
+
219
+ self.randn_noise = self.init_noise[:1].clone()
220
+ self.warp_noise = self.init_noise[:1].clone()
221
+
222
+ self.stock_noise = torch.zeros_like(self.init_noise)
223
+
224
+ c_skip_list = []
225
+ c_out_list = []
226
+ for timestep in self.sub_timesteps:
227
+ c_skip, c_out = self.scheduler.get_scalings_for_boundary_condition_discrete(
228
+ timestep
229
+ )
230
+ c_skip_list.append(c_skip)
231
+ c_out_list.append(c_out)
232
+
233
+ self.c_skip = (
234
+ torch.stack(c_skip_list)
235
+ .view(len(self.t_list), 1, 1, 1)
236
+ .to(dtype=self.dtype, device=self.device)
237
+ )
238
+ self.c_out = (
239
+ torch.stack(c_out_list)
240
+ .view(len(self.t_list), 1, 1, 1)
241
+ .to(dtype=self.dtype, device=self.device)
242
+ )
243
+
244
+ alpha_prod_t_sqrt_list = []
245
+ beta_prod_t_sqrt_list = []
246
+ for timestep in self.sub_timesteps:
247
+ alpha_prod_t_sqrt = self.scheduler.alphas_cumprod[timestep].sqrt()
248
+ beta_prod_t_sqrt = (1 - self.scheduler.alphas_cumprod[timestep]).sqrt()
249
+ alpha_prod_t_sqrt_list.append(alpha_prod_t_sqrt)
250
+ beta_prod_t_sqrt_list.append(beta_prod_t_sqrt)
251
+ alpha_prod_t_sqrt = (
252
+ torch.stack(alpha_prod_t_sqrt_list)
253
+ .view(len(self.t_list), 1, 1, 1)
254
+ .to(dtype=self.dtype, device=self.device)
255
+ )
256
+ beta_prod_t_sqrt = (
257
+ torch.stack(beta_prod_t_sqrt_list)
258
+ .view(len(self.t_list), 1, 1, 1)
259
+ .to(dtype=self.dtype, device=self.device)
260
+ )
261
+ self.alpha_prod_t_sqrt = torch.repeat_interleave(
262
+ alpha_prod_t_sqrt,
263
+ repeats=self.frame_bff_size if self.use_denoising_batch else 1,
264
+ dim=0,
265
+ )
266
+ self.beta_prod_t_sqrt = torch.repeat_interleave(
267
+ beta_prod_t_sqrt,
268
+ repeats=self.frame_bff_size if self.use_denoising_batch else 1,
269
+ dim=0,
270
+ )
271
+
272
+ @torch.no_grad()
273
+ def update_prompt(self, prompt: str) -> None:
274
+ encoder_output = self.pipe.encode_prompt(
275
+ prompt=prompt,
276
+ device=self.device,
277
+ num_images_per_prompt=1,
278
+ do_classifier_free_guidance=False,
279
+ )
280
+ self.prompt_embeds = encoder_output[0].repeat(self.batch_size, 1, 1)
281
+
282
+ def add_noise(
283
+ self,
284
+ original_samples: torch.Tensor,
285
+ noise: torch.Tensor,
286
+ t_index: int,
287
+ ) -> torch.Tensor:
288
+ noisy_samples = (
289
+ self.alpha_prod_t_sqrt[t_index] * original_samples
290
+ + self.beta_prod_t_sqrt[t_index] * noise
291
+ )
292
+ return noisy_samples
293
+
294
+ def scheduler_step_batch(
295
+ self,
296
+ model_pred_batch: torch.Tensor,
297
+ x_t_latent_batch: torch.Tensor,
298
+ idx: Optional[int] = None,
299
+ ) -> torch.Tensor:
300
+ # TODO: use t_list to select beta_prod_t_sqrt
301
+ if idx is None:
302
+ F_theta = (
303
+ x_t_latent_batch - self.beta_prod_t_sqrt * model_pred_batch
304
+ ) / self.alpha_prod_t_sqrt
305
+ denoised_batch = self.c_out * F_theta + self.c_skip * x_t_latent_batch
306
+ else:
307
+ F_theta = (
308
+ x_t_latent_batch - self.beta_prod_t_sqrt[idx] * model_pred_batch
309
+ ) / self.alpha_prod_t_sqrt[idx]
310
+ denoised_batch = (
311
+ self.c_out[idx] * F_theta + self.c_skip[idx] * x_t_latent_batch
312
+ )
313
+
314
+ return denoised_batch
315
+
316
+ def unet_step(
317
+ self,
318
+ x_t_latent: torch.Tensor,
319
+ t_list: Union[torch.Tensor, list[int]],
320
+ idx: Optional[int] = None,
321
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
322
+ if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"):
323
+ x_t_latent_plus_uc = torch.concat([x_t_latent[0:1], x_t_latent], dim=0)
324
+ t_list = torch.concat([t_list[0:1], t_list], dim=0)
325
+ elif self.guidance_scale > 1.0 and (self.cfg_type == "full"):
326
+ x_t_latent_plus_uc = torch.concat([x_t_latent, x_t_latent], dim=0)
327
+ t_list = torch.concat([t_list, t_list], dim=0)
328
+ else:
329
+ x_t_latent_plus_uc = x_t_latent
330
+
331
+ model_pred = self.unet(
332
+ x_t_latent_plus_uc,
333
+ t_list,
334
+ encoder_hidden_states=self.prompt_embeds,
335
+ return_dict=False,
336
+ )[0]
337
+
338
+ if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"):
339
+ noise_pred_text = model_pred[1:]
340
+ self.stock_noise = torch.concat(
341
+ [model_pred[0:1], self.stock_noise[1:]], dim=0
342
+ ) # commenting this out gives self-output CFG
343
+ elif self.guidance_scale > 1.0 and (self.cfg_type == "full"):
344
+ noise_pred_uncond, noise_pred_text = model_pred.chunk(2)
345
+ else:
346
+ noise_pred_text = model_pred
347
+ if self.guidance_scale > 1.0 and (
348
+ self.cfg_type == "self" or self.cfg_type == "initialize"
349
+ ):
350
+ noise_pred_uncond = self.stock_noise * self.delta
351
+ if self.guidance_scale > 1.0 and self.cfg_type != "none":
352
+ model_pred = noise_pred_uncond + self.guidance_scale * (
353
+ noise_pred_text - noise_pred_uncond
354
+ )
355
+ else:
356
+ model_pred = noise_pred_text
357
+
358
+ # compute the previous noisy sample x_t -> x_t-1
359
+ if self.use_denoising_batch:
360
+ denoised_batch = self.scheduler_step_batch(model_pred, x_t_latent, idx)
361
+ if self.cfg_type == "self" or self.cfg_type == "initialize":
362
+ scaled_noise = self.beta_prod_t_sqrt * self.stock_noise
363
+ delta_x = self.scheduler_step_batch(model_pred, scaled_noise, idx)
364
+ alpha_next = torch.concat(
365
+ [
366
+ self.alpha_prod_t_sqrt[1:],
367
+ torch.ones_like(self.alpha_prod_t_sqrt[0:1]),
368
+ ],
369
+ dim=0,
370
+ )
371
+ delta_x = alpha_next * delta_x
372
+ beta_next = torch.concat(
373
+ [
374
+ self.beta_prod_t_sqrt[1:],
375
+ torch.ones_like(self.beta_prod_t_sqrt[0:1]),
376
+ ],
377
+ dim=0,
378
+ )
379
+ delta_x = delta_x / beta_next
380
+ init_noise = torch.concat(
381
+ [self.init_noise[1:], self.init_noise[0:1]], dim=0
382
+ )
383
+ self.stock_noise = init_noise + delta_x
384
+
385
+ else:
386
+ # denoised_batch = self.scheduler.step(model_pred, t_list[0], x_t_latent).denoised
387
+ denoised_batch = self.scheduler_step_batch(model_pred, x_t_latent, idx)
388
+
389
+ return denoised_batch, model_pred
390
+
391
+
392
+ def norm_noise(self, noise):
393
+ # Compute mean and std of the noise tensor
394
+ mean = noise.mean()
395
+ std = noise.std()
396
+
397
+ # Normalize the noise to have mean=0 and std=1
398
+ normalized_noise = (noise - mean) / std
399
+ return normalized_noise
400
+
401
+ def encode_image(self, image_tensors: torch.Tensor) -> torch.Tensor:
402
+ image_tensors = image_tensors.to(
403
+ device=self.device,
404
+ dtype=self.vae.dtype,
405
+ )
406
+ img_latent = retrieve_latents(self.vae.encode(image_tensors), self.generator)
407
+ img_latent = img_latent * self.vae.config.scaling_factor
408
+ x_t_latent = self.add_noise(img_latent, self.init_noise[0], 0)
409
+ return x_t_latent
410
+
411
+ def decode_image(self, x_0_pred_out: torch.Tensor) -> torch.Tensor:
412
+ output_latent = self.vae.decode(
413
+ x_0_pred_out / self.vae.config.scaling_factor, return_dict=False
414
+ )[0]
415
+ return output_latent
416
+
417
+ def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor:
418
+ prev_latent_batch = self.x_t_latent_buffer
419
+ if self.use_denoising_batch:
420
+ t_list = self.sub_timesteps_tensor
421
+ if self.denoising_steps_num > 1:
422
+ x_t_latent = torch.cat((x_t_latent, prev_latent_batch), dim=0)
423
+ self.stock_noise = torch.cat(
424
+ (self.init_noise[0:1], self.stock_noise[:-1]), dim=0
425
+ )
426
+ x_0_pred_batch, model_pred = self.unet_step(x_t_latent, t_list)
427
+
428
+ if self.denoising_steps_num > 1:
429
+ x_0_pred_out = x_0_pred_batch[-1].unsqueeze(0)
430
+ if self.do_add_noise:
431
+ self.x_t_latent_buffer = (
432
+ self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1]
433
+ + self.beta_prod_t_sqrt[1:] * self.init_noise[1:]
434
+ )
435
+ else:
436
+ self.x_t_latent_buffer = (
437
+ self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1]
438
+ )
439
+ else:
440
+ x_0_pred_out = x_0_pred_batch
441
+ self.x_t_latent_buffer = None
442
+ else:
443
+ self.init_noise = x_t_latent
444
+ for idx, t in enumerate(self.sub_timesteps_tensor):
445
+ t = t.view(
446
+ 1,
447
+ ).repeat(
448
+ self.frame_bff_size,
449
+ )
450
+ x_0_pred, model_pred = self.unet_step(x_t_latent, t, idx)
451
+ if idx < len(self.sub_timesteps_tensor) - 1:
452
+ if self.do_add_noise:
453
+ x_t_latent = self.alpha_prod_t_sqrt[
454
+ idx + 1
455
+ ] * x_0_pred + self.beta_prod_t_sqrt[
456
+ idx + 1
457
+ ] * torch.randn_like(
458
+ x_0_pred, device=self.device, dtype=self.dtype
459
+ )
460
+ else:
461
+ x_t_latent = self.alpha_prod_t_sqrt[idx + 1] * x_0_pred
462
+ x_0_pred_out = x_0_pred
463
+ return x_0_pred_out
464
+
465
+ @torch.no_grad()
466
+ def __call__(
467
+ self, x: Union[torch.Tensor, PIL.Image.Image, np.ndarray] = None
468
+ ) -> torch.Tensor:
469
+ start = torch.cuda.Event(enable_timing=True)
470
+ end = torch.cuda.Event(enable_timing=True)
471
+ start.record()
472
+ if x is not None:
473
+ x = self.image_processor.preprocess(x, self.height, self.width).to(
474
+ device=self.device, dtype=self.dtype
475
+ )
476
+ if self.similar_image_filter:
477
+ x = self.similar_filter(x)
478
+ if x is None:
479
+ time.sleep(self.inference_time_ema)
480
+ return self.prev_image_result
481
+ x_t_latent = self.encode_image(x)
482
+ else:
483
+ # TODO: check the dimension of x_t_latent
484
+ x_t_latent = torch.randn((1, 4, self.latent_height, self.latent_width)).to(
485
+ device=self.device, dtype=self.dtype
486
+ )
487
+ x_0_pred_out = self.predict_x0_batch(x_t_latent)
488
+ x_output = self.decode_image(x_0_pred_out).detach().clone()
489
+
490
+ self.prev_image_result = x_output
491
+ end.record()
492
+ torch.cuda.synchronize()
493
+ inference_time = start.elapsed_time(end) / 1000
494
+ self.inference_time_ema = 0.9 * self.inference_time_ema + 0.1 * inference_time
495
+ return x_output
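A minimal end-to-end sketch of driving StreamV2V frame by frame, assuming an SD 1.5 checkpoint compatible with LCM-LoRA. The checkpoint id, timestep indices, prompt, and the black stand-in frames are illustrative; real video frames would be fed in their place.

import torch
from PIL import Image
from diffusers import StableDiffusionPipeline

from streamv2v import StreamV2V

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

stream = StreamV2V(pipe, t_index_list=[32, 45], width=512, height=512)
stream.load_lcm_lora()  # latent-consistency/lcm-lora-sdv1-5 by default
stream.fuse_lora()
stream.prepare(prompt="a watercolor painting of a person", guidance_scale=1.2)

frames = [Image.new("RGB", (512, 512)) for _ in range(4)]  # stand-ins for video frames
for frame in frames:
    out = stream(frame)  # decoded image tensor for this frame; postprocess as needed
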
streamv2v/tools/__init__.py ADDED
File without changes
streamv2v/tools/install-tensorrt.py ADDED
@@ -0,0 +1,54 @@
1
+ from typing import Literal, Optional
2
+
3
+ import fire
4
+ from packaging.version import Version
5
+
6
+ from ..pip_utils import is_installed, run_pip, version
7
+ import platform
8
+
9
+
10
+ def get_cuda_version_from_torch() -> Optional[Literal["11", "12"]]:
11
+ try:
12
+ import torch
13
+ except ImportError:
14
+ return None
15
+
16
+ return torch.version.cuda.split(".")[0]
17
+
18
+
19
+ def install(cu: Optional[Literal["11", "12"]] = get_cuda_version_from_torch()):
20
+ if cu is None or cu not in ["11", "12"]:
21
+ print("Could not detect CUDA version. Please specify manually.")
22
+ return
23
+ print("Installing TensorRT requirements...")
24
+
25
+ if is_installed("tensorrt"):
26
+ if version("tensorrt") < Version("9.0.0"):
27
+ run_pip("uninstall -y tensorrt")
28
+
29
+ cudnn_name = f"nvidia-cudnn-cu{cu}==8.9.4.25"
30
+
31
+ if not is_installed("tensorrt"):
32
+ run_pip(f"install {cudnn_name} --no-cache-dir")
33
+ run_pip(
34
+ "install --pre --extra-index-url https://pypi.nvidia.com tensorrt==9.0.1.post11.dev4 --no-cache-dir"
35
+ )
36
+
37
+ if not is_installed("polygraphy"):
38
+ run_pip(
39
+ "install polygraphy==0.47.1 --extra-index-url https://pypi.ngc.nvidia.com"
40
+ )
41
+ if not is_installed("onnx_graphsurgeon"):
42
+ run_pip(
43
+ "install onnx-graphsurgeon==0.3.26 --extra-index-url https://pypi.ngc.nvidia.com"
44
+ )
45
+ # if platform.system() == 'Windows' and not is_installed("pywin32"):
46
+ # run_pip(
47
+ # "install pywin32"
48
+ # )
49
+
50
+ pass
51
+
52
+
53
+ if __name__ == "__main__":
54
+ fire.Fire(install)
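The installer is exposed through python-fire. A usage sketch that runs it with the package context set, so the relative import of pip_utils resolves; "--cu 12" is an illustrative choice and the package is assumed to be on the Python path.

import subprocess
import sys

# Run the installer as a module so `from ..pip_utils import ...` works.
subprocess.run(
    [sys.executable, "-m", "streamv2v.tools.install-tensorrt", "--cu", "12"],
    check=True,
)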