#! fork: https://github.com/NVIDIA/TensorRT/blob/main/demo/Diffusion/utilities.py
#
# Copyright 2022 The HuggingFace Inc. team.
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import gc
from collections import OrderedDict
from typing import Optional, Union

import numpy as np
import onnx
import onnx_graphsurgeon as gs
import tensorrt as trt
import torch
from cuda import cudart
from PIL import Image
from polygraphy import cuda
from polygraphy.backend.common import bytes_from_path
from polygraphy.backend.trt import (
    CreateConfig,
    Profile,
    engine_from_bytes,
    engine_from_network,
    network_from_onnx_path,
    save_engine,
)
from polygraphy.backend.trt import util as trt_util

from .models import CLIP, VAE, BaseModel, UNet, VAEEncoder


TRT_LOGGER = trt.Logger(trt.Logger.ERROR)

# Map of numpy dtype -> torch dtype
numpy_to_torch_dtype_dict = {
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int16: torch.int16,
    np.int32: torch.int32,
    np.int64: torch.int64,
    np.float16: torch.float16,
    np.float32: torch.float32,
    np.float64: torch.float64,
    np.complex64: torch.complex64,
    np.complex128: torch.complex128,
}
# np.bool was removed in numpy 1.24; np.bool_ is available in all supported versions,
# so use it unconditionally rather than branching on the (lexicographically compared)
# version string.
numpy_to_torch_dtype_dict[np.bool_] = torch.bool

# Map of torch dtype -> numpy dtype
torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}


def CUASSERT(cuda_ret):
    err = cuda_ret[0]
    if err != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(
            f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t"
        )
    if len(cuda_ret) > 1:
        return cuda_ret[1]
    return None
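
# A minimal usage sketch for CUASSERT: cuda-python's cudart bindings return
# (error_code, result...) tuples, and CUASSERT unwraps them, raising on any
# failure. The function below is illustrative only and is not called anywhere.
def _example_cuassert_usage():
    # cudaStreamCreate returns (err, handle); CUASSERT yields the handle.
    stream = CUASSERT(cudart.cudaStreamCreate())
    # cudaStreamSynchronize returns a bare (err,) tuple; CUASSERT returns None.
    CUASSERT(cudart.cudaStreamSynchronize(stream))
    CUASSERT(cudart.cudaStreamDestroy(stream))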
class Engine:
    def __init__(self, engine_path):
        self.engine_path = engine_path
        self.engine = None
        self.context = None
        self.buffers = OrderedDict()
        self.tensors = OrderedDict()
        self.cuda_graph_instance = None  # CUDA graph

    def __del__(self):
        # Release device buffers before dropping the remaining references.
        [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)]
        del self.engine
        del self.context
        del self.buffers
        del self.tensors

    def refit(self, onnx_path, onnx_refit_path):
        def convert_int64(arr):
            # TODO: smarter conversion
            if len(arr.shape) == 0:
                return np.int32(arr)
            return arr

        def add_to_map(refit_dict, name, values):
            if name in refit_dict:
                assert refit_dict[name] is None
                if values.dtype == np.int64:
                    values = convert_int64(values)
                refit_dict[name] = values

        print(f"Refitting TensorRT engine with {onnx_refit_path} weights")
        refit_nodes = gs.import_onnx(onnx.load(onnx_refit_path)).toposort().nodes

        # Construct mapping from weight names in refit model -> original model
        name_map = {}
        for n, node in enumerate(gs.import_onnx(onnx.load(onnx_path)).toposort().nodes):
            refit_node = refit_nodes[n]
            assert node.op == refit_node.op
            # Constant nodes in ONNX do not have inputs but have a constant output
            if node.op == "Constant":
                name_map[refit_node.outputs[0].name] = node.outputs[0].name
            # Handle scale and bias weights
            elif node.op == "Conv":
                if node.inputs[1].__class__ == gs.Constant:
                    name_map[refit_node.name + "_TRTKERNEL"] = node.name + "_TRTKERNEL"
                if node.inputs[2].__class__ == gs.Constant:
                    name_map[refit_node.name + "_TRTBIAS"] = node.name + "_TRTBIAS"
            # For all other nodes: find node inputs that are initializers (gs.Constant)
            else:
                for i, inp in enumerate(node.inputs):
                    if inp.__class__ == gs.Constant:
                        name_map[refit_node.inputs[i].name] = inp.name

        def map_name(name):
            if name in name_map:
                return name_map[name]
            return name

        # Construct refit dictionary
        refit_dict = {}
        refitter = trt.Refitter(self.engine, TRT_LOGGER)
        all_weights = refitter.get_all()
        for layer_name, role in zip(all_weights[0], all_weights[1]):
            # for specialized roles, use a unique name in the map:
            if role == trt.WeightsRole.KERNEL:
                name = layer_name + "_TRTKERNEL"
            elif role == trt.WeightsRole.BIAS:
                name = layer_name + "_TRTBIAS"
            else:
                name = layer_name

            assert name not in refit_dict, "Found duplicate layer: " + name
            refit_dict[name] = None

        for n in refit_nodes:
            # Constant nodes in ONNX do not have inputs but have a constant output
            if n.op == "Constant":
                name = map_name(n.outputs[0].name)
                print(f"Add Constant {name}\n")
                add_to_map(refit_dict, name, n.outputs[0].values)

            # Handle scale and bias weights
            elif n.op == "Conv":
                if n.inputs[1].__class__ == gs.Constant:
                    name = map_name(n.name + "_TRTKERNEL")
                    add_to_map(refit_dict, name, n.inputs[1].values)

                if n.inputs[2].__class__ == gs.Constant:
                    name = map_name(n.name + "_TRTBIAS")
                    add_to_map(refit_dict, name, n.inputs[2].values)

            # For all other nodes: find node inputs that are initializers (AKA gs.Constant)
            else:
                for inp in n.inputs:
                    name = map_name(inp.name)
                    if inp.__class__ == gs.Constant:
                        add_to_map(refit_dict, name, inp.values)

        for layer_name, weights_role in zip(all_weights[0], all_weights[1]):
            if weights_role == trt.WeightsRole.KERNEL:
                custom_name = layer_name + "_TRTKERNEL"
            elif weights_role == trt.WeightsRole.BIAS:
                custom_name = layer_name + "_TRTBIAS"
            else:
                custom_name = layer_name

            # Skip refitting Trilu for now; scalar weights of type int64 value 1 - for clip model
            if layer_name.startswith("onnx::Trilu"):
                continue

            if refit_dict[custom_name] is not None:
                refitter.set_weights(layer_name, weights_role, refit_dict[custom_name])
            else:
                print(f"[W] No refit weights for layer: {layer_name}")

        if not refitter.refit_cuda_engine():
            print("Failed to refit!")
            exit(1)

    def build(
        self,
        onnx_path,
        fp16,
        input_profile=None,
        enable_refit=False,
        enable_all_tactics=False,
        timing_cache=None,
        workspace_size=0,
    ):
        print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
        p = Profile()
        if input_profile:
            for name, dims in input_profile.items():
                assert len(dims) == 3
                p.add(name, min=dims[0], opt=dims[1], max=dims[2])

        config_kwargs = {}
        if workspace_size > 0:
            config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size}
        if not enable_all_tactics:
            config_kwargs["tactic_sources"] = []

        engine = engine_from_network(
            network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
            config=CreateConfig(
                fp16=fp16, refittable=enable_refit, profiles=[p], load_timing_cache=timing_cache, **config_kwargs
            ),
            save_timing_cache=timing_cache,
        )
        save_engine(engine, path=self.engine_path)

    def load(self):
        print(f"Loading TensorRT engine: {self.engine_path}")
        self.engine = engine_from_bytes(bytes_from_path(self.engine_path))

    def activate(self, reuse_device_memory=None):
        if reuse_device_memory:
            self.context = self.engine.create_execution_context_without_device_memory()
            self.context.device_memory = reuse_device_memory
        else:
            self.context = self.engine.create_execution_context()

    def allocate_buffers(self, shape_dict=None, device="cuda"):
        for idx in range(trt_util.get_bindings_per_profile(self.engine)):
            binding = self.engine[idx]
            if shape_dict and binding in shape_dict:
                shape = shape_dict[binding]
            else:
                shape = self.engine.get_binding_shape(binding)
            dtype = trt.nptype(self.engine.get_binding_dtype(binding))
            if self.engine.binding_is_input(binding):
                self.context.set_binding_shape(idx, shape)
            tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
            self.tensors[binding] = tensor

    def infer(self, feed_dict, stream, use_cuda_graph=False):
        for name, buf in feed_dict.items():
            self.tensors[name].copy_(buf)

        for name, tensor in self.tensors.items():
            self.context.set_tensor_address(name, tensor.data_ptr())

        if use_cuda_graph:
            if self.cuda_graph_instance is not None:
                CUASSERT(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream.ptr))
                CUASSERT(cudart.cudaStreamSynchronize(stream.ptr))
            else:
                # Run inference once before CUDA graph capture
                noerror = self.context.execute_async_v3(stream.ptr)
                if not noerror:
                    raise ValueError("ERROR: inference failed.")
                # Capture the CUDA graph
                CUASSERT(
                    cudart.cudaStreamBeginCapture(stream.ptr, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal)
                )
                self.context.execute_async_v3(stream.ptr)
                self.graph = CUASSERT(cudart.cudaStreamEndCapture(stream.ptr))
                self.cuda_graph_instance = CUASSERT(cudart.cudaGraphInstantiate(self.graph, 0))
        else:
            noerror = self.context.execute_async_v3(stream.ptr)
            if not noerror:
                raise ValueError("ERROR: inference failed.")

        return self.tensors
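
# A minimal lifecycle sketch for Engine. The paths, the binding name "sample",
# and the shapes below are hypothetical placeholders, not part of this module;
# real binding names and shapes come from the exported ONNX model.
def _example_engine_lifecycle():
    engine = Engine("unet.plan")  # hypothetical engine path
    engine.build("unet.onnx", fp16=True)  # hypothetical ONNX path
    engine.load()
    engine.activate()
    engine.allocate_buffers(shape_dict={"sample": (2, 4, 64, 64)}, device="cuda")
    stream = cuda.Stream()  # polygraphy CUDA stream; infer() uses stream.ptr
    feed = {"sample": torch.randn(2, 4, 64, 64, device="cuda")}
    return engine.infer(feed, stream)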
def decode_images(images: torch.Tensor):
    images = (
        ((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy()
    )
    return [Image.fromarray(x) for x in images]


def preprocess_image(image: Image.Image):
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h))
    init_image = np.array(image).astype(np.float32) / 255.0
    init_image = init_image[None].transpose(0, 3, 1, 2)
    init_image = torch.from_numpy(init_image).contiguous()
    return 2.0 * init_image - 1.0


def prepare_mask_and_masked_image(image: Image.Image, mask: Image.Image):
    if isinstance(image, Image.Image):
        image = np.array(image.convert("RGB"))
        image = image[None].transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32).contiguous() / 127.5 - 1.0
    if isinstance(mask, Image.Image):
        mask = np.array(mask.convert("L"))
        mask = mask.astype(np.float32) / 255.0
        mask = mask[None, None]
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = torch.from_numpy(mask).to(dtype=torch.float32).contiguous()

    masked_image = image * (mask < 0.5)

    return mask, masked_image
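
# A short sketch of the image helpers above; the file names are hypothetical
# and the image and mask are assumed to share the same dimensions.
# preprocess_image returns an NCHW float tensor in [-1, 1], which is exactly
# the range decode_images expects when converting back to PIL images.
def _example_image_roundtrip():
    init = preprocess_image(Image.open("input.png"))  # (1, 3, H, W), values in [-1, 1]
    mask, masked = prepare_mask_and_masked_image(Image.open("input.png"), Image.open("mask.png"))
    return decode_images(init), mask, masked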
def create_models(
    model_id: str,
    use_auth_token: Optional[str],
    device: Union[str, torch.device],
    max_batch_size: int,
    unet_in_channels: int = 4,
    embedding_dim: int = 768,
):
    models = {
        "clip": CLIP(
            hf_token=use_auth_token,
            device=device,
            max_batch_size=max_batch_size,
            embedding_dim=embedding_dim,
        ),
        "unet": UNet(
            hf_token=use_auth_token,
            fp16=True,
            device=device,
            max_batch_size=max_batch_size,
            embedding_dim=embedding_dim,
            unet_dim=unet_in_channels,
        ),
        "vae": VAE(
            hf_token=use_auth_token,
            device=device,
            max_batch_size=max_batch_size,
            embedding_dim=embedding_dim,
        ),
        "vae_encoder": VAEEncoder(
            hf_token=use_auth_token,
            device=device,
            max_batch_size=max_batch_size,
            embedding_dim=embedding_dim,
        ),
    }
    return models


def build_engine(
    engine_path: str,
    onnx_opt_path: str,
    model_data: BaseModel,
    opt_image_height: int,
    opt_image_width: int,
    opt_batch_size: int,
    build_static_batch: bool = False,
    build_dynamic_shape: bool = False,
    build_all_tactics: bool = False,
    build_enable_refit: bool = False,
):
    _, free_mem, _ = cudart.cudaMemGetInfo()
    GiB = 2**30
    if free_mem > 6 * GiB:
        activation_carveout = 4 * GiB
        max_workspace_size = free_mem - activation_carveout
    else:
        max_workspace_size = 0
    engine = Engine(engine_path)
    input_profile = model_data.get_input_profile(
        opt_batch_size,
        opt_image_height,
        opt_image_width,
        static_batch=build_static_batch,
        static_shape=not build_dynamic_shape,
    )
    engine.build(
        onnx_opt_path,
        fp16=True,
        input_profile=input_profile,
        enable_refit=build_enable_refit,
        enable_all_tactics=build_all_tactics,
        workspace_size=max_workspace_size,
    )

    return engine


def export_onnx(
    model,
    onnx_path: str,
    model_data: BaseModel,
    opt_image_height: int,
    opt_image_width: int,
    opt_batch_size: int,
    onnx_opset: int,
):
    with torch.inference_mode(), torch.autocast("cuda"):
        inputs = model_data.get_sample_input(opt_batch_size, opt_image_height, opt_image_width)
        torch.onnx.export(
            model,
            inputs,
            onnx_path,
            export_params=True,
            opset_version=onnx_opset,
            do_constant_folding=True,
            input_names=model_data.get_input_names(),
            output_names=model_data.get_output_names(),
            dynamic_axes=model_data.get_dynamic_axes(),
        )
    del model
    gc.collect()
    torch.cuda.empty_cache()


def optimize_onnx(
    onnx_path: str,
    onnx_opt_path: str,
    model_data: BaseModel,
):
    onnx_opt_graph = model_data.optimize(onnx.load(onnx_path))
    onnx.save(onnx_opt_graph, onnx_opt_path)
    del onnx_opt_graph
    gc.collect()
    torch.cuda.empty_cache()
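

# An end-to-end sketch tying the helpers above together: export each model to
# ONNX, optimize the graph, then build a TensorRT engine. It assumes each model
# wrapper exposes its torch module via a get_model() accessor (hypothetical
# here); the checkpoint id, opset, resolution, and file names are placeholders.
def _example_export_and_build(model_id="runwayml/stable-diffusion-v1-5"):
    models = create_models(model_id, use_auth_token=None, device="cuda", max_batch_size=4)
    engines = {}
    for name, model_data in models.items():
        # get_model() is an assumed accessor for the underlying torch.nn.Module.
        export_onnx(model_data.get_model(), f"{name}.onnx", model_data, 512, 512, 1, onnx_opset=17)
        optimize_onnx(f"{name}.onnx", f"{name}.opt.onnx", model_data)
        engines[name] = build_engine(f"{name}.plan", f"{name}.opt.onnx", model_data, 512, 512, 1)
    return engines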