import torch
import numpy as np
from typing import Any, Dict, Optional

from pulid.pipeline_v1_1 import PuLIDPipeline
from pulid.utils import resize_numpy_image_long
from pulid import attention_processor as attention

# Inference-only module: disable autograd globally to save memory/compute.
torch.set_grad_enabled(False)


class EndpointHandler:
    """Inference endpoint wrapper around a PuLID SDXL pipeline.

    The pipeline is built once at construction time; generation requests
    are served through ``__call__``.
    """

    def __init__(self, model_dir: Optional[str] = None):
        """Initialize the pipeline and default generation settings.

        Args:
            model_dir: Directory containing the model weights.  Currently
                unused — the pipeline loads a fixed SDXL repo — but kept
                for interface compatibility with the hosting framework.
        """
        self.pipeline = PuLIDPipeline(
            sdxl_repo='RunDiffusion/Juggernaut-XL-v9', sampler='dpmpp_sde'
        )
        self.default_cfg = 7.0
        self.default_steps = 25
        # Module handle kept so __call__ can toggle its ORTHO flags.
        self.attention = attention
        self.pipeline.debug_img_list = []

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Args:
            data: Request payload.  ``data["inputs"]`` must be a sequence
                of at least 14 elements, positionally:
                [0] main id image, [1:4] supplementary id images (or None),
                [4] prompt, [5] negative prompt, [6] CFG scale, [7] seed
                (-1 = random), [8] steps, [9] height, [10] width,
                [11] id scale, [12] num_zero, [13] ortho mode
                ('v1', 'v2', anything else = off).

        Returns:
            Dict with the generated image (nested lists), the seed actually
            used (as a string), and any debug images from the pipeline.

        Raises:
            ValueError: If ``inputs`` is missing or shorter than 14 elements.
        """
        inputs = data.get("inputs", [])
        # Bug fix: the original raised with a string literal broken across a
        # physical newline (a syntax error); the message is now one literal.
        if not inputs or len(inputs) < 14:
            raise ValueError("Invalid inputs. Expected 14 elements in the input list.")

        id_image = inputs[0]
        supp_images = inputs[1:4]
        prompt = inputs[4]
        neg_prompt = inputs[5]
        scale = inputs[6]
        seed = int(inputs[7])
        steps = int(inputs[8])
        H = int(inputs[9])
        W = int(inputs[10])
        id_scale = inputs[11]
        # NOTE(review): num_zero is parsed but never used below — confirm
        # whether the pipeline should receive it.
        num_zero = int(inputs[12])
        ortho = inputs[13]

        # -1 requests a fresh random seed.
        if seed == -1:
            seed = torch.Generator(device="cpu").seed()

        # Configure orthogonal attention mixing via module-level flags.
        if ortho == 'v2':
            self.attention.ORTHO = False
            self.attention.ORTHO_v2 = True
        elif ortho == 'v1':
            self.attention.ORTHO = True
            self.attention.ORTHO_v2 = False
        else:
            self.attention.ORTHO = False
            self.attention.ORTHO_v2 = False

        # Bug fix: reset per request so debug images from earlier calls do
        # not accumulate into (and leak through) later responses.
        self.pipeline.debug_img_list = []

        # Build ID embeddings from the main face image plus any supplements.
        if id_image is not None:
            id_image = resize_numpy_image_long(id_image, 1024)
            supp_id_image_list = [
                resize_numpy_image_long(supp_id_image, 1024)
                for supp_id_image in supp_images
                if supp_id_image is not None
            ]
            id_image_list = [id_image] + supp_id_image_list
            uncond_id_embedding, id_embedding = self.pipeline.get_id_embedding(
                id_image_list
            )
        else:
            uncond_id_embedding = None
            id_embedding = None

        # Generate a single image at the requested resolution.
        img = self.pipeline.inference(
            prompt, (1, H, W), neg_prompt,
            id_embedding, uncond_id_embedding,
            id_scale, scale, steps, seed,
        )[0]

        return {
            "image": np.array(img).tolist(),
            "seed": str(seed),
            "debug_images": [
                np.array(debug_img).tolist()
                for debug_img in self.pipeline.debug_img_list
            ],
        }