"""TensorRT acceleration helpers for StreamV2V: export the UNet and VAE to ONNX and build/install TensorRT engines."""
import gc
import os
from typing import Optional

import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
    retrieve_latents,
)
from polygraphy import cuda

from ...pipeline import StreamV2V
from .builder import EngineBuilder, create_onnx_path
from .engine import AutoencoderKLEngine, UNet2DConditionModelEngine
from .models import VAE, BaseModel, UNet, VAEEncoder
class TorchVAEEncoder(torch.nn.Module):
    """``nn.Module`` wrapper mapping pixel input straight to VAE latents.

    Wrapping ``vae.encode`` plus ``retrieve_latents`` in a single module
    gives the VAE encoder a plain ``forward`` that can be traced/exported
    (e.g. to ONNX for TensorRT engine building).
    """

    def __init__(self, vae: AutoencoderKL):
        super().__init__()
        self.vae = vae

    def forward(self, x: torch.Tensor):
        # Encode to a latent distribution, then pull the latents out of it.
        encoded = self.vae.encode(x)
        return retrieve_latents(encoded)
def compile_vae_encoder(
    vae: TorchVAEEncoder,
    model_data: BaseModel,
    onnx_path: str,
    onnx_opt_path: str,
    engine_path: str,
    opt_batch_size: int = 1,
    engine_build_options: Optional[dict] = None,
):
    """Export the VAE encoder wrapper to ONNX and build a TensorRT engine.

    Args:
        vae: Torch wrapper module producing latents from pixel input.
        model_data: Model description consumed by ``EngineBuilder``.
        onnx_path: Destination for the raw ONNX export.
        onnx_opt_path: Destination for the optimized ONNX graph.
        engine_path: Destination for the serialized TensorRT engine.
        opt_batch_size: Batch size the engine is optimized for.
        engine_build_options: Extra keyword options forwarded to
            ``EngineBuilder.build``; ``None`` means no extra options.
    """
    # Fix: avoid a mutable default argument — build a fresh dict per call.
    if engine_build_options is None:
        engine_build_options = {}
    builder = EngineBuilder(model_data, vae, device=torch.device("cuda"))
    builder.build(
        onnx_path,
        onnx_opt_path,
        engine_path,
        opt_batch_size=opt_batch_size,
        **engine_build_options,
    )
def compile_vae_decoder(
    vae: AutoencoderKL,
    model_data: BaseModel,
    onnx_path: str,
    onnx_opt_path: str,
    engine_path: str,
    opt_batch_size: int = 1,
    engine_build_options: Optional[dict] = None,
):
    """Export the VAE decoder to ONNX and build a TensorRT engine.

    The caller is expected to have pointed ``vae.forward`` at the decode
    path before calling this (see ``accelerate_with_tensorrt``).

    Args:
        vae: The autoencoder whose decoder is exported; moved to CUDA here.
        model_data: Model description consumed by ``EngineBuilder``.
        onnx_path: Destination for the raw ONNX export.
        onnx_opt_path: Destination for the optimized ONNX graph.
        engine_path: Destination for the serialized TensorRT engine.
        opt_batch_size: Batch size the engine is optimized for.
        engine_build_options: Extra keyword options forwarded to
            ``EngineBuilder.build``; ``None`` means no extra options.
    """
    # Fix: avoid a mutable default argument — build a fresh dict per call.
    if engine_build_options is None:
        engine_build_options = {}
    vae = vae.to(torch.device("cuda"))
    builder = EngineBuilder(model_data, vae, device=torch.device("cuda"))
    builder.build(
        onnx_path,
        onnx_opt_path,
        engine_path,
        opt_batch_size=opt_batch_size,
        **engine_build_options,
    )
def compile_unet(
    unet: UNet2DConditionModel,
    model_data: BaseModel,
    onnx_path: str,
    onnx_opt_path: str,
    engine_path: str,
    opt_batch_size: int = 1,
    engine_build_options: Optional[dict] = None,
):
    """Export the UNet to ONNX and build a TensorRT engine.

    Args:
        unet: Conditioned UNet; moved to CUDA in fp16 before export.
        model_data: Model description consumed by ``EngineBuilder``.
        onnx_path: Destination for the raw ONNX export.
        onnx_opt_path: Destination for the optimized ONNX graph.
        engine_path: Destination for the serialized TensorRT engine.
        opt_batch_size: Batch size the engine is optimized for.
        engine_build_options: Extra keyword options forwarded to
            ``EngineBuilder.build``; ``None`` means no extra options.
    """
    # Fix: avoid a mutable default argument — build a fresh dict per call.
    if engine_build_options is None:
        engine_build_options = {}
    unet = unet.to(torch.device("cuda"), dtype=torch.float16)
    builder = EngineBuilder(model_data, unet, device=torch.device("cuda"))
    builder.build(
        onnx_path,
        onnx_opt_path,
        engine_path,
        opt_batch_size=opt_batch_size,
        **engine_build_options,
    )
def accelerate_with_tensorrt(
    stream: StreamV2V,
    engine_dir: str,
    max_batch_size: int = 2,
    min_batch_size: int = 1,
    use_cuda_graph: bool = False,
    engine_build_options: Optional[dict] = None,
):
    """Replace ``stream``'s UNet and VAE with TensorRT engine wrappers.

    Builds TensorRT engines for the UNet, VAE encoder and VAE decoder under
    ``engine_dir`` (reusing any ``*.engine`` file already on disk), then
    installs engine-backed replacements on ``stream``.

    Args:
        stream: Pipeline whose ``unet``/``vae`` are replaced in place.
        engine_dir: Directory for ONNX exports and ``*.engine`` files.
        max_batch_size: Upper batch bound the engines are built for.
        min_batch_size: Lower batch bound the engines are built for.
        use_cuda_graph: Whether the engine wrappers use CUDA graphs.
        engine_build_options: Extra options forwarded to the builders;
            ``opt_batch_size`` defaults to ``max_batch_size`` when unset.

    Returns:
        The same ``stream`` object, with engines installed.
    """
    # Fix: the original used a mutable default dict AND mutated it
    # (``engine_build_options["opt_batch_size"] = ...``), leaking state
    # across calls. Work on a fresh copy instead.
    build_options = dict(engine_build_options) if engine_build_options else {}
    if build_options.get("opt_batch_size") is None:
        build_options["opt_batch_size"] = max_batch_size

    text_encoder = stream.text_encoder
    unet = stream.unet
    vae = stream.vae

    # Detach the torch modules and free GPU memory before exporting —
    # engine building needs the headroom.
    del stream.unet, stream.vae, stream.pipe.unet, stream.pipe.vae
    vae_config = vae.config
    vae_dtype = vae.dtype
    unet.to(torch.device("cpu"))
    vae.to(torch.device("cpu"))
    gc.collect()
    torch.cuda.empty_cache()

    onnx_dir = os.path.join(engine_dir, "onnx")
    os.makedirs(onnx_dir, exist_ok=True)

    unet_engine_path = f"{engine_dir}/unet.engine"
    vae_encoder_engine_path = f"{engine_dir}/vae_encoder.engine"
    vae_decoder_engine_path = f"{engine_dir}/vae_decoder.engine"

    unet_model = UNet(
        fp16=True,
        device=stream.device,
        max_batch_size=max_batch_size,
        min_batch_size=min_batch_size,
        embedding_dim=text_encoder.config.hidden_size,
        unet_dim=unet.config.in_channels,
    )
    vae_decoder_model = VAE(
        device=stream.device,
        max_batch_size=max_batch_size,
        min_batch_size=min_batch_size,
    )
    vae_encoder_model = VAEEncoder(
        device=stream.device,
        max_batch_size=max_batch_size,
        min_batch_size=min_batch_size,
    )

    # Build each engine only when it is not already cached on disk.
    if not os.path.exists(unet_engine_path):
        compile_unet(
            unet,
            unet_model,
            create_onnx_path("unet", onnx_dir, opt=False),
            create_onnx_path("unet", onnx_dir, opt=True),
            unet_engine_path,
            **build_options,
        )
    else:
        del unet

    if not os.path.exists(vae_decoder_engine_path):
        # The ONNX exporter traces ``forward``; point it at the decoder.
        vae.forward = vae.decode
        compile_vae_decoder(
            vae,
            vae_decoder_model,
            create_onnx_path("vae_decoder", onnx_dir, opt=False),
            create_onnx_path("vae_decoder", onnx_dir, opt=True),
            vae_decoder_engine_path,
            **build_options,
        )

    if not os.path.exists(vae_encoder_engine_path):
        vae_encoder = TorchVAEEncoder(vae).to(torch.device("cuda"))
        compile_vae_encoder(
            vae_encoder,
            vae_encoder_model,
            create_onnx_path("vae_encoder", onnx_dir, opt=False),
            create_onnx_path("vae_encoder", onnx_dir, opt=True),
            vae_encoder_engine_path,
            **build_options,
        )
    del vae

    cuda_stream = cuda.Stream()  # fix: was misspelled ``cuda_steram``
    stream.unet = UNet2DConditionModelEngine(unet_engine_path, cuda_stream, use_cuda_graph=use_cuda_graph)
    stream.vae = AutoencoderKLEngine(
        vae_encoder_engine_path,
        vae_decoder_engine_path,
        cuda_stream,
        stream.pipe.vae_scale_factor,
        use_cuda_graph=use_cuda_graph,
    )
    # The engine wrapper lacks the original module's config/dtype; restore
    # them so downstream code that inspects ``stream.vae`` keeps working.
    setattr(stream.vae, "config", vae_config)
    setattr(stream.vae, "dtype", vae_dtype)
    gc.collect()
    torch.cuda.empty_cache()
    return stream