from typing import Any

import torch
from diffusers.models.autoencoder_tiny import AutoencoderTinyOutput
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from diffusers.models.vae import DecoderOutput
from polygraphy import cuda

from .utilities import Engine

class UNet2DConditionModelEngine:
    """TensorRT-backed drop-in replacement for a diffusers UNet2DConditionModel."""

    def __init__(self, filepath: str, stream: cuda.Stream, use_cuda_graph: bool = False):
        self.engine = Engine(filepath)
        self.stream = stream
        self.use_cuda_graph = use_cuda_graph

        # Deserialize the TensorRT engine and create its execution context.
        self.engine.load()
        self.engine.activate()
    def __call__(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        **kwargs,
    ) -> Any:
        # The engine is built with a float32 timestep input.
        if timestep.dtype != torch.float32:
            timestep = timestep.float()

        # (Re)allocate device buffers for the current batch/resolution;
        # the "latent" output binding has the same shape as the input sample.
        self.engine.allocate_buffers(
            shape_dict={
                "sample": latent_model_input.shape,
                "timestep": timestep.shape,
                "encoder_hidden_states": encoder_hidden_states.shape,
                "latent": latent_model_input.shape,
            },
            device=latent_model_input.device,
        )

        noise_pred = self.engine.infer(
            {
                "sample": latent_model_input,
                "timestep": timestep,
                "encoder_hidden_states": encoder_hidden_states,
            },
            self.stream,
            use_cuda_graph=self.use_cuda_graph,
        )["latent"]

        # Wrap the prediction so callers can keep using the diffusers output API.
        return UNet2DConditionOutput(sample=noise_pred)
    def to(self, *args, **kwargs):
        # No-op: the engine is already bound to a CUDA device.
        pass

    def forward(self, *args, **kwargs):
        # No-op: inference goes through __call__.
        pass
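
# Minimal usage sketch (hypothetical engine path; shapes are SD-1.5-style
# assumptions), for a prebuilt TensorRT UNet engine whose bindings are named
# "sample", "timestep", "encoder_hidden_states", and "latent" as above:
#
#   stream = cuda.Stream()
#   unet = UNet2DConditionModelEngine("unet.engine", stream)
#   sample = torch.randn(2, 4, 64, 64, device="cuda", dtype=torch.float16)
#   t = torch.tensor([999], device="cuda")
#   text_emb = torch.randn(2, 77, 768, device="cuda", dtype=torch.float16)
#   noise_pred = unet(sample, t, text_emb).sample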

class AutoencoderKLEngine:
    """TensorRT-backed drop-in replacement for a diffusers AutoencoderKL."""

    def __init__(
        self,
        encoder_path: str,
        decoder_path: str,
        stream: cuda.Stream,
        scaling_factor: int,
        use_cuda_graph: bool = False,
    ):
        self.encoder = Engine(encoder_path)
        self.decoder = Engine(decoder_path)
        self.stream = stream
        # Spatial downsampling factor between pixel and latent space
        # (8 for Stable Diffusion VAEs), not the latent scaling constant.
        self.vae_scale_factor = scaling_factor
        self.use_cuda_graph = use_cuda_graph

        # Deserialize both engines and create their execution contexts.
        self.encoder.load()
        self.decoder.load()
        self.encoder.activate()
        self.decoder.activate()
    def encode(self, images: torch.Tensor, **kwargs):
        # Latents have 4 channels and are spatially downsampled by the
        # VAE scale factor relative to the input images.
        self.encoder.allocate_buffers(
            shape_dict={
                "images": images.shape,
                "latent": (
                    images.shape[0],
                    4,
                    images.shape[2] // self.vae_scale_factor,
                    images.shape[3] // self.vae_scale_factor,
                ),
            },
            device=images.device,
        )

        latents = self.encoder.infer(
            {"images": images},
            self.stream,
            use_cuda_graph=self.use_cuda_graph,
        )["latent"]

        return AutoencoderTinyOutput(latents=latents)
    def decode(self, latent: torch.Tensor, **kwargs):
        # Decoded images have 3 channels and are spatially upsampled by the
        # VAE scale factor relative to the latents.
        self.decoder.allocate_buffers(
            shape_dict={
                "latent": latent.shape,
                "images": (
                    latent.shape[0],
                    3,
                    latent.shape[2] * self.vae_scale_factor,
                    latent.shape[3] * self.vae_scale_factor,
                ),
            },
            device=latent.device,
        )

        images = self.decoder.infer(
            {"latent": latent},
            self.stream,
            use_cuda_graph=self.use_cuda_graph,
        )["images"]

        return DecoderOutput(sample=images)
    def to(self, *args, **kwargs):
        # No-op: both engines are already bound to a CUDA device.
        pass

    def forward(self, *args, **kwargs):
        # No-op: use encode()/decode() instead.
        pass
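
# Minimal round-trip sketch (hypothetical engine paths; 512x512 SD-style
# shapes are assumptions), for encoder/decoder engines with bindings named
# "images"/"latent" as above and spatial scale factor 8:
#
#   stream = cuda.Stream()
#   vae = AutoencoderKLEngine("vae_encoder.engine", "vae_decoder.engine", stream, 8)
#   images = torch.randn(1, 3, 512, 512, device="cuda", dtype=torch.float16)
#   latents = vae.encode(images).latents   # -> (1, 4, 64, 64)
#   recon = vae.decode(latents).sample     # -> (1, 3, 512, 512)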