# Provenance: Hugging Face Space file by jbilcke-hf (commit 69f3483, 5.27 kB).
import gc
import os
from typing import Optional

import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
    retrieve_latents,
)
from polygraphy import cuda

from ...pipeline import StreamV2V
from .builder import EngineBuilder, create_onnx_path
from .engine import AutoencoderKLEngine, UNet2DConditionModelEngine
from .models import VAE, BaseModel, UNet, VAEEncoder
class TorchVAEEncoder(torch.nn.Module):
    """Thin ``nn.Module`` exposing only the encode path of a VAE.

    Wrapping ``AutoencoderKL.encode`` in a module whose ``forward``
    returns the latent tensor gives the ONNX exporter a single callable
    entry point to trace.
    """

    def __init__(self, vae: AutoencoderKL):
        super().__init__()
        self.vae = vae

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Encode, then pull the latent tensor out of the encoder output.
        encoded = self.vae.encode(x)
        return retrieve_latents(encoded)
def compile_vae_encoder(
    vae: TorchVAEEncoder,
    model_data: BaseModel,
    onnx_path: str,
    onnx_opt_path: str,
    engine_path: str,
    opt_batch_size: int = 1,
    engine_build_options: Optional[dict] = None,
):
    """Export the wrapped VAE encoder to ONNX and build a TensorRT engine.

    Args:
        vae: Torch module wrapping the VAE encode path (already on the
            target device — this helper does not move it).
        model_data: Model description driving the ONNX export/build.
        onnx_path: Destination for the raw ONNX export.
        onnx_opt_path: Destination for the optimized ONNX graph.
        engine_path: Destination for the serialized TensorRT engine.
        opt_batch_size: Batch size the engine is optimized for.
        engine_build_options: Extra keyword arguments forwarded to
            ``EngineBuilder.build``; ``None`` means none. (Replaces the
            original mutable default ``dict = {}``, which is shared
            across calls.)
    """
    builder = EngineBuilder(model_data, vae, device=torch.device("cuda"))
    builder.build(
        onnx_path,
        onnx_opt_path,
        engine_path,
        opt_batch_size=opt_batch_size,
        **(engine_build_options or {}),
    )
def compile_vae_decoder(
    vae: AutoencoderKL,
    model_data: BaseModel,
    onnx_path: str,
    onnx_opt_path: str,
    engine_path: str,
    opt_batch_size: int = 1,
    engine_build_options: Optional[dict] = None,
):
    """Export the VAE decoder to ONNX and build a TensorRT engine.

    Args:
        vae: VAE whose ``forward`` the caller has pointed at the decode
            path; moved to CUDA here before the export.
        model_data: Model description driving the ONNX export/build.
        onnx_path: Destination for the raw ONNX export.
        onnx_opt_path: Destination for the optimized ONNX graph.
        engine_path: Destination for the serialized TensorRT engine.
        opt_batch_size: Batch size the engine is optimized for.
        engine_build_options: Extra keyword arguments forwarded to
            ``EngineBuilder.build``; ``None`` means none. (Replaces the
            original mutable default ``dict = {}``, which is shared
            across calls.)
    """
    vae = vae.to(torch.device("cuda"))
    builder = EngineBuilder(model_data, vae, device=torch.device("cuda"))
    builder.build(
        onnx_path,
        onnx_opt_path,
        engine_path,
        opt_batch_size=opt_batch_size,
        **(engine_build_options or {}),
    )
def compile_unet(
    unet: UNet2DConditionModel,
    model_data: BaseModel,
    onnx_path: str,
    onnx_opt_path: str,
    engine_path: str,
    opt_batch_size: int = 1,
    engine_build_options: Optional[dict] = None,
):
    """Export the UNet to ONNX and build a TensorRT engine.

    Args:
        unet: Denoising UNet; moved to CUDA and cast to fp16 here before
            the export.
        model_data: Model description driving the ONNX export/build.
        onnx_path: Destination for the raw ONNX export.
        onnx_opt_path: Destination for the optimized ONNX graph.
        engine_path: Destination for the serialized TensorRT engine.
        opt_batch_size: Batch size the engine is optimized for.
        engine_build_options: Extra keyword arguments forwarded to
            ``EngineBuilder.build``; ``None`` means none. (Replaces the
            original mutable default ``dict = {}``, which is shared
            across calls.)
    """
    # fp16 halves memory and is what the UNet model description expects.
    unet = unet.to(torch.device("cuda"), dtype=torch.float16)
    builder = EngineBuilder(model_data, unet, device=torch.device("cuda"))
    builder.build(
        onnx_path,
        onnx_opt_path,
        engine_path,
        opt_batch_size=opt_batch_size,
        **(engine_build_options or {}),
    )
def accelerate_with_tensorrt(
    stream: StreamV2V,
    engine_dir: str,
    max_batch_size: int = 2,
    min_batch_size: int = 1,
    use_cuda_graph: bool = False,
    engine_build_options: Optional[dict] = None,
):
    """Swap the PyTorch UNet/VAE of a StreamV2V pipeline for TensorRT engines.

    Engines and intermediate ONNX files are cached under *engine_dir*;
    existing ``.engine`` files are reused instead of rebuilt.

    Args:
        stream: Pipeline whose ``unet`` and ``vae`` are replaced in place.
        engine_dir: Directory for ``.engine`` files and the ``onnx/`` exports.
        max_batch_size: Upper bound of the dynamic batch dimension.
        min_batch_size: Lower bound of the dynamic batch dimension.
        use_cuda_graph: Whether the engine wrappers use CUDA graphs.
        engine_build_options: Extra keyword arguments forwarded to the
            ``compile_*`` helpers; ``None`` means defaults only.

    Returns:
        The same *stream* instance, now backed by TensorRT engines.
    """
    # Copy so neither the caller's dict nor a shared default is mutated.
    # (The original signature used a mutable default ``dict = {}`` and
    # then wrote "opt_batch_size" into it, leaking state across calls.)
    engine_build_options = dict(engine_build_options or {})
    if engine_build_options.get("opt_batch_size") is None:
        engine_build_options["opt_batch_size"] = max_batch_size

    text_encoder = stream.text_encoder
    unet = stream.unet
    vae = stream.vae
    # Detach the torch modules from the pipeline so they can be garbage
    # collected once the engines exist.
    del stream.unet, stream.vae, stream.pipe.unet, stream.pipe.vae

    # Preserved and re-attached to the engine wrapper below, since callers
    # read config/dtype off stream.vae.
    vae_config = vae.config
    vae_dtype = vae.dtype

    # Move the large modules off the GPU to leave room for the builders.
    unet.to(torch.device("cpu"))
    vae.to(torch.device("cpu"))
    gc.collect()
    torch.cuda.empty_cache()

    onnx_dir = os.path.join(engine_dir, "onnx")
    os.makedirs(onnx_dir, exist_ok=True)  # also creates engine_dir itself

    unet_engine_path = f"{engine_dir}/unet.engine"
    vae_encoder_engine_path = f"{engine_dir}/vae_encoder.engine"
    vae_decoder_engine_path = f"{engine_dir}/vae_decoder.engine"

    unet_model = UNet(
        fp16=True,
        device=stream.device,
        max_batch_size=max_batch_size,
        min_batch_size=min_batch_size,
        embedding_dim=text_encoder.config.hidden_size,
        unet_dim=unet.config.in_channels,
    )
    vae_decoder_model = VAE(
        device=stream.device,
        max_batch_size=max_batch_size,
        min_batch_size=min_batch_size,
    )
    vae_encoder_model = VAEEncoder(
        device=stream.device,
        max_batch_size=max_batch_size,
        min_batch_size=min_batch_size,
    )

    if not os.path.exists(unet_engine_path):
        compile_unet(
            unet,
            unet_model,
            create_onnx_path("unet", onnx_dir, opt=False),
            create_onnx_path("unet", onnx_dir, opt=True),
            unet_engine_path,
            **engine_build_options,
        )
    else:
        # Cached engine exists; release the torch UNet immediately.
        del unet

    if not os.path.exists(vae_decoder_engine_path):
        # The ONNX exporter traces forward(); point it at the decode path.
        vae.forward = vae.decode
        compile_vae_decoder(
            vae,
            vae_decoder_model,
            create_onnx_path("vae_decoder", onnx_dir, opt=False),
            create_onnx_path("vae_decoder", onnx_dir, opt=True),
            vae_decoder_engine_path,
            **engine_build_options,
        )

    if not os.path.exists(vae_encoder_engine_path):
        vae_encoder = TorchVAEEncoder(vae).to(torch.device("cuda"))
        compile_vae_encoder(
            vae_encoder,
            vae_encoder_model,
            create_onnx_path("vae_encoder", onnx_dir, opt=False),
            create_onnx_path("vae_encoder", onnx_dir, opt=True),
            vae_encoder_engine_path,
            **engine_build_options,
        )

    del vae

    cuda_stream = cuda.Stream()  # renamed: was misspelled "cuda_steram"
    stream.unet = UNet2DConditionModelEngine(
        unet_engine_path, cuda_stream, use_cuda_graph=use_cuda_graph
    )
    stream.vae = AutoencoderKLEngine(
        vae_encoder_engine_path,
        vae_decoder_engine_path,
        cuda_stream,
        stream.pipe.vae_scale_factor,
        use_cuda_graph=use_cuda_graph,
    )
    # The engine wrapper lacks these attributes; restore them so callers
    # can keep treating stream.vae like an AutoencoderKL.
    setattr(stream.vae, "config", vae_config)
    setattr(stream.vae, "dtype", vae_dtype)

    gc.collect()
    torch.cuda.empty_cache()
    return stream