from typing import Any

import torch
# NOTE: these module paths match the diffusers release this file was written
# against; newer diffusers versions moved them under diffusers.models.unets
# and diffusers.models.autoencoders.
from diffusers.models.autoencoder_tiny import AutoencoderTinyOutput
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from diffusers.models.vae import DecoderOutput
from polygraphy import cuda

from .utilities import Engine

class UNet2DConditionModelEngine:
    """Drop-in replacement for a diffusers UNet2DConditionModel that runs a
    serialized TensorRT engine instead of the PyTorch module."""

    def __init__(self, filepath: str, stream: cuda.Stream, use_cuda_graph: bool = False):
        self.engine = Engine(filepath)
        self.stream = stream
        self.use_cuda_graph = use_cuda_graph

        # Deserialize the engine and create its execution context up front.
        self.engine.load()
        self.engine.activate()

    def __call__(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        **kwargs,
    ) -> Any:
        # The engine expects float32 timesteps; integer timesteps coming from
        # a scheduler are cast here.
        if timestep.dtype != torch.float32:
            timestep = timestep.float()
        # (Re)allocate device buffers to match the current input shapes;
        # "latent" is the engine's output binding and mirrors the input shape.
        self.engine.allocate_buffers(
            shape_dict={
                "sample": latent_model_input.shape,
                "timestep": timestep.shape,
                "encoder_hidden_states": encoder_hidden_states.shape,
                "latent": latent_model_input.shape,
            },
            device=latent_model_input.device,
        )

        noise_pred = self.engine.infer(
            {
                "sample": latent_model_input,
                "timestep": timestep,
                "encoder_hidden_states": encoder_hidden_states,
            },
            self.stream,
            use_cuda_graph=self.use_cuda_graph,
        )["latent"]
        return UNet2DConditionOutput(sample=noise_pred)
    def to(self, *args, **kwargs):
        # No-op: the TensorRT engine is already bound to a device.
        pass

    def forward(self, *args, **kwargs):
        # No-op: kept so callers expecting an nn.Module interface don't break.
        pass
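
# Minimal usage sketch (illustrative, not part of the original file): swapping
# the engine in for a pipeline's UNet. "unet.engine", the tensor shapes, and
# the float16 dtype are assumptions; a TensorRT engine must have been built
# beforehand via utilities.Engine.
#
#   stream = cuda.Stream()
#   unet = UNet2DConditionModelEngine("unet.engine", stream)
#   sample = torch.randn(2, 4, 64, 64, device="cuda", dtype=torch.float16)
#   t = torch.tensor([999.0], device="cuda")
#   context = torch.randn(2, 77, 768, device="cuda", dtype=torch.float16)
#   noise_pred = unet(sample, t, context).sample
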
class AutoencoderKLEngine:
    """VAE replacement that runs separate TensorRT engines for the encoder
    and decoder while mimicking the diffusers AutoencoderKL interface."""

    def __init__(
        self,
        encoder_path: str,
        decoder_path: str,
        stream: cuda.Stream,
        scaling_factor: int,
        use_cuda_graph: bool = False,
    ):
        self.encoder = Engine(encoder_path)
        self.decoder = Engine(decoder_path)
        self.stream = stream
        # Spatial downsampling factor of the VAE (8 for Stable Diffusion).
        self.vae_scale_factor = scaling_factor
        self.use_cuda_graph = use_cuda_graph

        # Deserialize both engines and create their execution contexts.
        self.encoder.load()
        self.decoder.load()
        self.encoder.activate()
        self.decoder.activate()
    def encode(self, images: torch.Tensor, **kwargs):
        # Latents have 4 channels at 1/scale_factor of the image resolution.
        self.encoder.allocate_buffers(
            shape_dict={
                "images": images.shape,
                "latent": (
                    images.shape[0],
                    4,
                    images.shape[2] // self.vae_scale_factor,
                    images.shape[3] // self.vae_scale_factor,
                ),
            },
            device=images.device,
        )
        latents = self.encoder.infer(
            {"images": images},
            self.stream,
            use_cuda_graph=self.use_cuda_graph,
        )["latent"]
        return AutoencoderTinyOutput(latents=latents)
    def decode(self, latent: torch.Tensor, **kwargs):
        # Decoded images are RGB (3 channels) at scale_factor times the
        # latent resolution.
        self.decoder.allocate_buffers(
            shape_dict={
                "latent": latent.shape,
                "images": (
                    latent.shape[0],
                    3,
                    latent.shape[2] * self.vae_scale_factor,
                    latent.shape[3] * self.vae_scale_factor,
                ),
            },
            device=latent.device,
        )
        images = self.decoder.infer(
            {"latent": latent},
            self.stream,
            use_cuda_graph=self.use_cuda_graph,
        )["images"]
        return DecoderOutput(sample=images)
    def to(self, *args, **kwargs):
        # No-op: the TensorRT engines are already bound to a device.
        pass

    def forward(self, *args, **kwargs):
        # No-op: kept for interface compatibility with nn.Module callers.
        pass
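
# Minimal usage sketch (illustrative, not part of the original file): the
# engine paths and the scaling factor of 8 are assumptions for a Stable
# Diffusion VAE.
#
#   stream = cuda.Stream()
#   vae = AutoencoderKLEngine(
#       "vae_encoder.engine", "vae_decoder.engine", stream, scaling_factor=8
#   )
#   images = torch.randn(1, 3, 512, 512, device="cuda", dtype=torch.float16)
#   latents = vae.encode(images).latents  # (1, 4, 64, 64)
#   decoded = vae.decode(latents).sample  # (1, 3, 512, 512)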