|
from __future__ import annotations |
|
|
|
import gc |
|
import pathlib |
|
|
|
import gradio as gr |
|
import PIL.Image |
|
import torch |
|
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler |
|
from huggingface_hub import ModelCard |
|
from svdiff_pytorch import load_unet_for_svdiff, SCHEDULER_MAPPING, image_grid |
|
|
|
|
|
|
|
class InferencePipeline:
    """Cache and run a diffusion pipeline whose UNet is loaded via SVDiff.

    The instance remembers the most recently loaded ``model_id`` and its
    ``base_model_id`` so that repeated calls to :meth:`run` with the same
    model do not reload anything, and so that switching between checkpoints
    that share a base model only replaces the UNet.
    """

    def __init__(self, hf_token: str | None = None):
        """Create an empty pipeline holder.

        Args:
            hf_token: Optional Hugging Face token, forwarded when loading
                model cards and pipelines.
        """
        self.hf_token = hf_token
        self.pipe = None
        # Use the first CUDA device when one is available, otherwise the CPU.
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model_id = None
        self.base_model_id = None

    def clear(self) -> None:
        """Drop the cached pipeline and forget which model was loaded."""
        self.model_id = None
        self.base_model_id = None
        # Rebinding releases this object's reference to the pipeline without
        # failing when the attribute was never set.
        self.pipe = None
        # Collect first so that objects only reachable through reference
        # cycles are freed before the CUDA cache is flushed.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def check_if_model_is_local(model_id: str) -> bool:
        """Return True when ``model_id`` names an existing filesystem path."""
        return pathlib.Path(model_id).exists()

    @staticmethod
    def get_model_card(model_id: str,
                       hf_token: str | None = None) -> ModelCard:
        """Load the model card for ``model_id``.

        For an existing local path the card is read from its ``README.md``;
        otherwise ``model_id`` itself is handed to ``ModelCard.load``.

        Args:
            model_id: Local directory or model identifier.
            hf_token: Optional token passed to ``ModelCard.load``.

        Returns:
            The loaded ``ModelCard``.
        """
        if InferencePipeline.check_if_model_is_local(model_id):
            card_path = (pathlib.Path(model_id) / 'README.md').as_posix()
        else:
            card_path = model_id
        return ModelCard.load(card_path, token=hf_token)

    @staticmethod
    def get_base_model_info(model_id: str,
                            hf_token: str | None = None) -> str:
        """Return the ``base_model`` entry of the model card's metadata.

        NOTE(review): the value is returned as-is; presumably it is absent
        (None) when the card does not declare a base model -- confirm.
        """
        card = InferencePipeline.get_model_card(model_id, hf_token)
        return card.data.base_model

    def load_pipe(self, model_id: str) -> None:
        """Make ``self.pipe`` ready to generate with ``model_id``.

        Nothing happens when ``model_id`` is already loaded. Otherwise a
        UNet is built for the checkpoint. If the base model changed (or no
        pipeline is cached) a new pipeline is created around that UNet; if
        the base model is unchanged, only the UNet of the cached pipeline
        is replaced.

        Args:
            model_id: Checkpoint to load; also used to look up the base
                model through the model card.
        """
        if model_id == self.model_id:
            return

        base_model_id = self.get_base_model_info(model_id, self.hf_token)
        unet = load_unet_for_svdiff(base_model_id,
                                    spectral_shifts_ckpt=model_id,
                                    subfolder="unet").to(self.device)

        # Run perform_svd on every module that provides it before any dtype
        # cast is applied to the UNet.
        for module in unet.modules():
            if hasattr(module, "perform_svd"):
                module.perform_svd()

        on_cpu = self.device.type == 'cpu'
        if not on_cpu:
            # Half precision is only used off-CPU, matching the dtype the
            # pipeline itself is loaded with below. On CPU the pipeline is
            # loaded without torch_dtype, so the UNet keeps its dtype too.
            unet = unet.to(self.device, dtype=torch.float16)

        if self.pipe is None or base_model_id != self.base_model_id:
            if on_cpu:
                pipe = DiffusionPipeline.from_pretrained(
                    base_model_id,
                    unet=unet,
                    use_auth_token=self.hf_token
                )
            else:
                pipe = DiffusionPipeline.from_pretrained(
                    base_model_id,
                    unet=unet,
                    torch_dtype=torch.float16,
                    use_auth_token=self.hf_token
                )
                pipe = pipe.to(self.device)
            pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                pipe.scheduler.config)
            self.pipe = pipe
        else:
            # Same base model: reuse the cached pipeline but swap in the new
            # UNet, otherwise the previous checkpoint's weights would keep
            # being used under the new model_id.
            self.pipe.unet = unet

        self.model_id = model_id
        self.base_model_id = base_model_id

    def run(
        self,
        model_id: str,
        prompt: str,
        seed: int,
        n_steps: int,
        guidance_scale: float,
    ) -> PIL.Image.Image:
        """Generate one image for ``prompt`` with the given model.

        Args:
            model_id: Checkpoint to use; loaded on demand via ``load_pipe``.
            prompt: Text prompt passed to the pipeline.
            seed: Seed for the ``torch.Generator`` created on ``self.device``.
            n_steps: Passed as ``num_inference_steps``.
            guidance_scale: Passed as ``guidance_scale``.

        Returns:
            The first image of the pipeline output.
        """
        self.load_pipe(model_id)

        generator = torch.Generator(device=self.device).manual_seed(seed)
        out = self.pipe(
            prompt,
            num_inference_steps=n_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )
        return out.images[0]
|
|