#! fork: https://github.com/NVIDIA/TensorRT/blob/main/demo/Diffusion/utilities.py
#
# Copyright 2022 The HuggingFace Inc. team.
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import gc
from collections import OrderedDict
from typing import Optional, Union
import numpy as np
import onnx
import onnx_graphsurgeon as gs
import tensorrt as trt
import torch
from cuda import cudart
from PIL import Image
from polygraphy import cuda
from polygraphy.backend.common import bytes_from_path
from polygraphy.backend.trt import (
CreateConfig,
Profile,
engine_from_bytes,
engine_from_network,
network_from_onnx_path,
save_engine,
)
from polygraphy.backend.trt import util as trt_util
from .models import CLIP, VAE, BaseModel, UNet, VAEEncoder
TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
# Map of numpy dtype -> torch dtype
numpy_to_torch_dtype_dict = {
np.uint8: torch.uint8,
np.int8: torch.int8,
np.int16: torch.int16,
np.int32: torch.int32,
np.int64: torch.int64,
np.float16: torch.float16,
np.float32: torch.float32,
np.float64: torch.float64,
np.complex64: torch.complex64,
np.complex128: torch.complex128,
}
# np.bool was removed as an alias in NumPy 1.24; compare versions numerically
# rather than as strings (lexicographic comparison would misorder e.g. "1.9").
if np.lib.NumpyVersion(np.version.full_version) >= "1.24.0":
    numpy_to_torch_dtype_dict[np.bool_] = torch.bool
else:
    numpy_to_torch_dtype_dict[np.bool] = torch.bool
# Map of torch dtype -> numpy dtype
torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}
def CUASSERT(cuda_ret):
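    """Check the error code of a cudart call and return the call's result, if any."""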
err = cuda_ret[0]
if err != cudart.cudaError_t.cudaSuccess:
raise RuntimeError(
f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t"
)
if len(cuda_ret) > 1:
return cuda_ret[1]
return None
class Engine:
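    """Thin wrapper around a serialized TensorRT engine: builds or loads the
    engine, manages its execution context and I/O tensors, and runs inference
    (optionally through a captured CUDA graph)."""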
def __init__(
self,
engine_path,
):
self.engine_path = engine_path
self.engine = None
self.context = None
self.buffers = OrderedDict()
self.tensors = OrderedDict()
self.cuda_graph_instance = None # cuda graph
def __del__(self):
        for buf in self.buffers.values():
            if isinstance(buf, cuda.DeviceArray):
                buf.free()
del self.engine
del self.context
del self.buffers
del self.tensors
def refit(self, onnx_path, onnx_refit_path):
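        """Refit the engine weights in place from ``onnx_refit_path``, using the
        graph at ``onnx_path`` to map refit weight names onto the original ones."""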
def convert_int64(arr):
# TODO: smarter conversion
if len(arr.shape) == 0:
return np.int32(arr)
return arr
def add_to_map(refit_dict, name, values):
if name in refit_dict:
assert refit_dict[name] is None
if values.dtype == np.int64:
values = convert_int64(values)
refit_dict[name] = values
print(f"Refitting TensorRT engine with {onnx_refit_path} weights")
refit_nodes = gs.import_onnx(onnx.load(onnx_refit_path)).toposort().nodes
# Construct mapping from weight names in refit model -> original model
name_map = {}
for n, node in enumerate(gs.import_onnx(onnx.load(onnx_path)).toposort().nodes):
refit_node = refit_nodes[n]
assert node.op == refit_node.op
# Constant nodes in ONNX do not have inputs but have a constant output
if node.op == "Constant":
name_map[refit_node.outputs[0].name] = node.outputs[0].name
# Handle scale and bias weights
elif node.op == "Conv":
if node.inputs[1].__class__ == gs.Constant:
name_map[refit_node.name + "_TRTKERNEL"] = node.name + "_TRTKERNEL"
                # The bias input is optional for Conv nodes
                if len(node.inputs) > 2 and node.inputs[2].__class__ == gs.Constant:
                    name_map[refit_node.name + "_TRTBIAS"] = node.name + "_TRTBIAS"
# For all other nodes: find node inputs that are initializers (gs.Constant)
else:
for i, inp in enumerate(node.inputs):
if inp.__class__ == gs.Constant:
name_map[refit_node.inputs[i].name] = inp.name
def map_name(name):
if name in name_map:
return name_map[name]
return name
# Construct refit dictionary
refit_dict = {}
refitter = trt.Refitter(self.engine, TRT_LOGGER)
all_weights = refitter.get_all()
for layer_name, role in zip(all_weights[0], all_weights[1]):
            # For specialized roles, use a unique name in the map:
if role == trt.WeightsRole.KERNEL:
name = layer_name + "_TRTKERNEL"
elif role == trt.WeightsRole.BIAS:
name = layer_name + "_TRTBIAS"
else:
name = layer_name
assert name not in refit_dict, "Found duplicate layer: " + name
refit_dict[name] = None
for n in refit_nodes:
# Constant nodes in ONNX do not have inputs but have a constant output
if n.op == "Constant":
name = map_name(n.outputs[0].name)
print(f"Add Constant {name}\n")
add_to_map(refit_dict, name, n.outputs[0].values)
# Handle scale and bias weights
elif n.op == "Conv":
if n.inputs[1].__class__ == gs.Constant:
name = map_name(n.name + "_TRTKERNEL")
add_to_map(refit_dict, name, n.inputs[1].values)
                # The bias input is optional for Conv nodes
                if len(n.inputs) > 2 and n.inputs[2].__class__ == gs.Constant:
                    name = map_name(n.name + "_TRTBIAS")
                    add_to_map(refit_dict, name, n.inputs[2].values)
# For all other nodes: find node inputs that are initializers (AKA gs.Constant)
else:
for inp in n.inputs:
name = map_name(inp.name)
if inp.__class__ == gs.Constant:
add_to_map(refit_dict, name, inp.values)
for layer_name, weights_role in zip(all_weights[0], all_weights[1]):
if weights_role == trt.WeightsRole.KERNEL:
custom_name = layer_name + "_TRTKERNEL"
elif weights_role == trt.WeightsRole.BIAS:
custom_name = layer_name + "_TRTBIAS"
else:
custom_name = layer_name
            # Skip refitting Trilu nodes for now; their scalar int64 weights
            # (constant value 1, used by the CLIP model) are not refitted.
if layer_name.startswith("onnx::Trilu"):
continue
if refit_dict[custom_name] is not None:
refitter.set_weights(layer_name, weights_role, refit_dict[custom_name])
else:
print(f"[W] No refit weights for layer: {layer_name}")
        if not refitter.refit_cuda_engine():
            raise RuntimeError("Failed to refit the TensorRT engine!")
def build(
self,
onnx_path,
fp16,
input_profile=None,
enable_refit=False,
enable_all_tactics=False,
timing_cache=None,
workspace_size=0,
):
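        """Build a TensorRT engine from the ONNX model at ``onnx_path`` and
        serialize it to ``self.engine_path``."""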
print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
p = Profile()
if input_profile:
for name, dims in input_profile.items():
assert len(dims) == 3
p.add(name, min=dims[0], opt=dims[1], max=dims[2])
config_kwargs = {}
if workspace_size > 0:
config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size}
if not enable_all_tactics:
config_kwargs["tactic_sources"] = []
engine = engine_from_network(
network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
config=CreateConfig(
fp16=fp16, refittable=enable_refit, profiles=[p], load_timing_cache=timing_cache, **config_kwargs
),
save_timing_cache=timing_cache,
)
save_engine(engine, path=self.engine_path)
def load(self):
print(f"Loading TensorRT engine: {self.engine_path}")
self.engine = engine_from_bytes(bytes_from_path(self.engine_path))
def activate(self, reuse_device_memory=None):
if reuse_device_memory:
self.context = self.engine.create_execution_context_without_device_memory()
self.context.device_memory = reuse_device_memory
else:
self.context = self.engine.create_execution_context()
def allocate_buffers(self, shape_dict=None, device="cuda"):
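        """Allocate one torch tensor per engine binding, using shapes from
        ``shape_dict`` when given, otherwise the shapes baked into the engine."""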
for idx in range(trt_util.get_bindings_per_profile(self.engine)):
binding = self.engine[idx]
if shape_dict and binding in shape_dict:
shape = shape_dict[binding]
else:
shape = self.engine.get_binding_shape(binding)
dtype = trt.nptype(self.engine.get_binding_dtype(binding))
if self.engine.binding_is_input(binding):
self.context.set_binding_shape(idx, shape)
tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
self.tensors[binding] = tensor
def infer(self, feed_dict, stream, use_cuda_graph=False):
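        """Copy ``feed_dict`` into the pre-allocated input tensors and execute the
        engine on ``stream``; with ``use_cuda_graph``, the launch is captured into
        a CUDA graph on the first call and replayed on subsequent calls. Returns
        the dict of all I/O tensors."""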
for name, buf in feed_dict.items():
self.tensors[name].copy_(buf)
for name, tensor in self.tensors.items():
self.context.set_tensor_address(name, tensor.data_ptr())
if use_cuda_graph:
if self.cuda_graph_instance is not None:
CUASSERT(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream.ptr))
CUASSERT(cudart.cudaStreamSynchronize(stream.ptr))
else:
# do inference before CUDA graph capture
noerror = self.context.execute_async_v3(stream.ptr)
if not noerror:
raise ValueError("ERROR: inference failed.")
# capture cuda graph
CUASSERT(
cudart.cudaStreamBeginCapture(stream.ptr, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal)
)
self.context.execute_async_v3(stream.ptr)
self.graph = CUASSERT(cudart.cudaStreamEndCapture(stream.ptr))
self.cuda_graph_instance = CUASSERT(cudart.cudaGraphInstantiate(self.graph, 0))
else:
noerror = self.context.execute_async_v3(stream.ptr)
if not noerror:
raise ValueError("ERROR: inference failed.")
return self.tensors
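# Typical Engine lifecycle (a minimal sketch; the paths, profile, shapes and
# feed dict below are illustrative, not defined by this module):
#
#   engine = Engine("unet.plan")
#   engine.build("unet.onnx", fp16=True, input_profile=profile)  # one-time build
#   engine.load()
#   engine.activate()
#   engine.allocate_buffers(shape_dict=shapes, device="cuda")
#   stream = cuda.Stream()  # polygraphy CUDA stream, exposes .ptr
#   outputs = engine.infer(feed_dict, stream)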
def decode_images(images: torch.Tensor):
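    """Convert a batch of [-1, 1] image tensors to a list of uint8 PIL images."""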
images = (
((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy()
)
return [Image.fromarray(x) for x in images]
def preprocess_image(image: Image.Image):
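    """Resize so each side is a multiple of 32 and convert to an NCHW float
    tensor scaled to [-1, 1]."""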
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h))
init_image = np.array(image).astype(np.float32) / 255.0
init_image = init_image[None].transpose(0, 3, 1, 2)
init_image = torch.from_numpy(init_image).contiguous()
return 2.0 * init_image - 1.0
def prepare_mask_and_masked_image(image: Image.Image, mask: Image.Image):
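    """Binarize the inpainting mask and return it together with the image with
    the masked region zeroed out, both as NCHW float tensors."""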
if isinstance(image, Image.Image):
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32).contiguous() / 127.5 - 1.0
if isinstance(mask, Image.Image):
mask = np.array(mask.convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask).to(dtype=torch.float32).contiguous()
masked_image = image * (mask < 0.5)
return mask, masked_image
def create_models(
model_id: str,
use_auth_token: Optional[str],
device: Union[str, torch.device],
max_batch_size: int,
unet_in_channels: int = 4,
embedding_dim: int = 768,
):
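    """Instantiate the CLIP, UNet, VAE decoder and VAE encoder model wrappers
    used for ONNX export and engine building."""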
models = {
"clip": CLIP(
hf_token=use_auth_token,
device=device,
max_batch_size=max_batch_size,
embedding_dim=embedding_dim,
),
"unet": UNet(
hf_token=use_auth_token,
fp16=True,
device=device,
max_batch_size=max_batch_size,
embedding_dim=embedding_dim,
unet_dim=unet_in_channels,
),
"vae": VAE(
hf_token=use_auth_token,
device=device,
max_batch_size=max_batch_size,
embedding_dim=embedding_dim,
),
"vae_encoder": VAEEncoder(
hf_token=use_auth_token,
device=device,
max_batch_size=max_batch_size,
embedding_dim=embedding_dim,
),
}
return models
def build_engine(
engine_path: str,
onnx_opt_path: str,
model_data: BaseModel,
opt_image_height: int,
opt_image_width: int,
opt_batch_size: int,
build_static_batch: bool = False,
build_dynamic_shape: bool = False,
build_all_tactics: bool = False,
build_enable_refit: bool = False,
):
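    """Build a TensorRT engine for ``model_data``, sizing the builder workspace
    from the currently free device memory (a 4 GiB activation carveout is
    reserved when more than 6 GiB is free)."""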
_, free_mem, _ = cudart.cudaMemGetInfo()
GiB = 2**30
if free_mem > 6 * GiB:
activation_carveout = 4 * GiB
max_workspace_size = free_mem - activation_carveout
else:
max_workspace_size = 0
engine = Engine(engine_path)
input_profile = model_data.get_input_profile(
opt_batch_size,
opt_image_height,
opt_image_width,
static_batch=build_static_batch,
static_shape=not build_dynamic_shape,
)
engine.build(
onnx_opt_path,
fp16=True,
input_profile=input_profile,
enable_refit=build_enable_refit,
enable_all_tactics=build_all_tactics,
workspace_size=max_workspace_size,
)
return engine
def export_onnx(
model,
onnx_path: str,
model_data: BaseModel,
opt_image_height: int,
opt_image_width: int,
opt_batch_size: int,
onnx_opset: int,
):
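    """Export ``model`` to ONNX using sample inputs at the given optimal batch
    size and image dimensions, then free the model's GPU memory."""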
with torch.inference_mode(), torch.autocast("cuda"):
inputs = model_data.get_sample_input(opt_batch_size, opt_image_height, opt_image_width)
torch.onnx.export(
model,
inputs,
onnx_path,
export_params=True,
opset_version=onnx_opset,
do_constant_folding=True,
input_names=model_data.get_input_names(),
output_names=model_data.get_output_names(),
dynamic_axes=model_data.get_dynamic_axes(),
)
del model
gc.collect()
torch.cuda.empty_cache()
def optimize_onnx(
onnx_path: str,
onnx_opt_path: str,
model_data: BaseModel,
):
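    """Run the model-specific ONNX graph optimizations and save the result to
    ``onnx_opt_path``."""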
onnx_opt_graph = model_data.optimize(onnx.load(onnx_path))
onnx.save(onnx_opt_graph, onnx_opt_path)
del onnx_opt_graph
gc.collect()
torch.cuda.empty_cache()