|
from typing import Any, Dict, Optional

import numpy as np
import torch

from pulid import attention_processor as attention
from pulid.pipeline_v1_1 import PuLIDPipeline
from pulid.utils import resize_numpy_image_long
|
|
|
|
|
# Inference-only service: disable autograd globally to avoid building
# computation graphs (saves memory and compute for every request).
torch.set_grad_enabled(False)
|
|
|
class EndpointHandler:
    """Inference-endpoint wrapper around the PuLID SDXL pipeline.

    Receives a flat positional input list (see ``__call__``), runs
    identity-conditioned image generation, and returns the image plus
    any debug images as JSON-serializable nested lists.
    """

    # Required length of the flat ``data["inputs"]`` list consumed by __call__.
    _NUM_INPUTS = 14

    def __init__(self, model_dir: Optional[str] = None):
        """
        Initialize the PuLID pipeline and default generation settings.

        Args:
            model_dir: Directory containing the model weights.
                NOTE(review): currently unused — the pipeline loads from the
                hard-coded repo below; kept for handler-interface compatibility.
        """
        self.pipeline = PuLIDPipeline(sdxl_repo='RunDiffusion/Juggernaut-XL-v9', sampler='dpmpp_sde')
        self.default_cfg = 7.0
        self.default_steps = 25
        # Module reference kept on the instance so __call__ can flip its
        # global ORTHO flags per request.
        self.attention = attention
        self.pipeline.debug_img_list = []

    def _configure_ortho(self, ortho: str) -> None:
        """Set the attention module's global orthogonalization flags.

        'v1' -> ORTHO, 'v2' -> ORTHO_v2, anything else disables both.
        """
        self.attention.ORTHO = ortho == 'v1'
        self.attention.ORTHO_v2 = ortho == 'v2'

    def _get_id_embeddings(self, id_image, supp_images):
        """Resize the ID images and compute identity embeddings.

        Args:
            id_image: Primary identity image (numpy array) or None.
            supp_images: Iterable of optional supplementary identity images.

        Returns:
            Tuple ``(uncond_id_embedding, id_embedding)``; ``(None, None)``
            when no primary ID image is provided.
        """
        if id_image is None:
            return None, None
        resized_main = resize_numpy_image_long(id_image, 1024)
        resized_supp = [
            resize_numpy_image_long(img, 1024) for img in supp_images if img is not None
        ]
        return self.pipeline.get_id_embedding([resized_main] + resized_supp)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Handle an inference request.

        Args:
            data: Request payload. ``data["inputs"]`` must be a list of at
                least 14 elements, positionally:
                [id_image, supp1, supp2, supp3, prompt, neg_prompt, scale,
                 seed, steps, H, W, id_scale, num_zero, ortho].

        Returns:
            Dict with the generated image, the seed actually used (as str),
            and any pipeline debug images — all images serialized via
            ``np.array(...).tolist()``.

        Raises:
            ValueError: If fewer than 14 input elements are supplied.
        """
        inputs = data.get("inputs") or []
        if len(inputs) < self._NUM_INPUTS:
            raise ValueError("Invalid inputs. Expected 14 elements in the input list.")

        id_image = inputs[0]
        supp_images = inputs[1:4]
        prompt = inputs[4]
        neg_prompt = inputs[5]
        scale = float(inputs[6])  # CFG scale; coerce like the int params below
        seed = int(inputs[7])
        steps = int(inputs[8])
        H = int(inputs[9])
        W = int(inputs[10])
        id_scale = float(inputs[11])
        # NOTE(review): num_zero is parsed but never used below — confirm
        # whether the pipeline should receive it.
        num_zero = int(inputs[12])
        ortho = inputs[13]

        # seed == -1 requests a random seed; draw one explicitly so the
        # value actually used can be reported back to the caller.
        if seed == -1:
            seed = torch.Generator(device="cpu").seed()

        self._configure_ortho(ortho)

        uncond_id_embedding, id_embedding = self._get_id_embeddings(id_image, supp_images)

        img = self.pipeline.inference(
            prompt, (1, H, W), neg_prompt, id_embedding, uncond_id_embedding, id_scale, scale, steps, seed
        )[0]

        return {
            "image": np.array(img).tolist(),
            "seed": str(seed),
            "debug_images": [np.array(debug_img).tolist() for debug_img in self.pipeline.debug_img_list],
        }