from typing import Any

import torch
from diffusers.models.autoencoder_tiny import AutoencoderTinyOutput
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from diffusers.models.vae import DecoderOutput
from polygraphy import cuda

from .utilities import Engine


class UNet2DConditionModelEngine:
    """Wraps a serialized TensorRT UNet engine behind the call signature of
    diffusers' UNet2DConditionModel so it can stand in for the PyTorch model."""

    def __init__(self, filepath: str, stream: cuda.Stream, use_cuda_graph: bool = False):
        self.engine = Engine(filepath)
        self.stream = stream
        self.use_cuda_graph = use_cuda_graph

        # Deserialize the engine and create its execution context up front.
        self.engine.load()
        self.engine.activate()

    def __call__(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        **kwargs,
    ) -> Any:
        # The engine's "timestep" binding expects float32.
        if timestep.dtype != torch.float32:
            timestep = timestep.float()

        # (Re)allocate I/O buffers to match the current input shapes; the
        # "latent" output (noise prediction) has the same shape as the input sample.
        self.engine.allocate_buffers(
            shape_dict={
                "sample": latent_model_input.shape,
                "timestep": timestep.shape,
                "encoder_hidden_states": encoder_hidden_states.shape,
                "latent": latent_model_input.shape,
            },
            device=latent_model_input.device,
        )

        noise_pred = self.engine.infer(
            {
                "sample": latent_model_input,
                "timestep": timestep,
                "encoder_hidden_states": encoder_hidden_states,
            },
            self.stream,
            use_cuda_graph=self.use_cuda_graph,
        )["latent"]
        return UNet2DConditionOutput(sample=noise_pred)

    def to(self, *args, **kwargs):
        # No-op: the engine is already device-resident; kept so pipeline code
        # that calls .to(device) does not fail.
        pass

    def forward(self, *args, **kwargs):
        # No-op placeholder for API compatibility with torch.nn.Module.
        pass


class AutoencoderKLEngine:
    """Wraps TensorRT VAE encoder/decoder engines behind the encode()/decode()
    API of diffusers' AutoencoderKL."""

    def __init__(
        self,
        encoder_path: str,
        decoder_path: str,
        stream: cuda.Stream,
        scaling_factor: int,
        use_cuda_graph: bool = False,
    ):
        self.encoder = Engine(encoder_path)
        self.decoder = Engine(decoder_path)
        self.stream = stream
        # Spatial downsampling factor between image and latent resolution
        # (8 for Stable Diffusion's VAE).
        self.vae_scale_factor = scaling_factor
        self.use_cuda_graph = use_cuda_graph

        # Deserialize both engines and create their execution contexts.
        self.encoder.load()
        self.decoder.load()
        self.encoder.activate()
        self.decoder.activate()

    def encode(self, images: torch.Tensor, **kwargs):
        # Latents have 4 channels at 1/vae_scale_factor of the image resolution.
        self.encoder.allocate_buffers(
            shape_dict={
                "images": images.shape,
                "latent": (
                    images.shape[0],
                    4,
                    images.shape[2] // self.vae_scale_factor,
                    images.shape[3] // self.vae_scale_factor,
                ),
            },
            device=images.device,
        )
        latents = self.encoder.infer(
            {"images": images},
            self.stream,
            use_cuda_graph=self.use_cuda_graph,
        )["latent"]
        return AutoencoderTinyOutput(latents=latents)

    def decode(self, latent: torch.Tensor, **kwargs):
        # Decoded images have 3 channels at vae_scale_factor times the latent
        # resolution.
        self.decoder.allocate_buffers(
            shape_dict={
                "latent": latent.shape,
                "images": (
                    latent.shape[0],
                    3,
                    latent.shape[2] * self.vae_scale_factor,
                    latent.shape[3] * self.vae_scale_factor,
                ),
            },
            device=latent.device,
        )
        images = self.decoder.infer(
            {"latent": latent},
            self.stream,
            use_cuda_graph=self.use_cuda_graph,
        )["images"]
        return DecoderOutput(sample=images)

    def to(self, *args, **kwargs):
        # No-op: engines are device-resident; kept for pipeline compatibility.
        pass

    def forward(self, *args, **kwargs):
        # No-op placeholder for API compatibility with torch.nn.Module.
        pass
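
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not exercised anywhere in this module).
# The wrappers above mimic just enough of the diffusers UNet/VAE interface to
# be swapped into a pipeline that calls them only as shown. The engine file
# names below are hypothetical, and the engines must have been exported with
# the binding names used above ("sample", "timestep", "encoder_hidden_states",
# "latent", "images"). Note the wrappers expose no .config or .dtype, so the
# host pipeline must not depend on those attributes.
#
#     from polygraphy import cuda
#
#     stream = cuda.Stream()
#     pipe.unet = UNet2DConditionModelEngine("unet.engine", stream)
#     pipe.vae = AutoencoderKLEngine(
#         "vae_encoder.engine",
#         "vae_decoder.engine",
#         stream,
#         scaling_factor=8,  # SD's VAE maps 512x512 images to 64x64 latents
#     )
# ---------------------------------------------------------------------------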