# pulid-flux-adorabook / handler.py
import torch
import numpy as np
from typing import Any, Dict, Optional

from pulid.pipeline_v1_1 import PuLIDPipeline
from pulid.utils import resize_numpy_image_long
from pulid import attention_processor as attention

# Disable gradient tracking globally; this handler only runs inference.
torch.set_grad_enabled(False)


class EndpointHandler:
    def __init__(self, model_dir: Optional[str] = None):
        """
        Initializes the PuLID pipeline and related components.

        Args:
            model_dir (str, optional): Directory containing the model weights.
                Accepted for endpoint compatibility but unused here; the
                pipeline loads its own weights from the Hugging Face Hub.
        """
        self.pipeline = PuLIDPipeline(sdxl_repo='RunDiffusion/Juggernaut-XL-v9', sampler='dpmpp_sde')
        self.default_cfg = 7.0   # reference guidance scale (informational only)
        self.default_steps = 25  # reference step count (informational only)
        self.attention = attention
        self.pipeline.debug_img_list = []

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Handles an inference request.

        Args:
            data (Dict[str, Any]): Request payload. ``data["inputs"]`` must be
                a 14-element list, unpacked positionally below.

        Returns:
            Dict[str, Any]: The generated image, the seed used, and any debug
            images collected by the pipeline.
        """
        # Unpack the positional input list
        inputs = data.get("inputs", [])
        if not inputs or len(inputs) < 14:
            raise ValueError("Invalid inputs. Expected 14 elements in the input list.")

        id_image = inputs[0]        # main ID face image (numpy array or None)
        supp_images = inputs[1:4]   # up to three supplementary ID images
        prompt = inputs[4]
        neg_prompt = inputs[5]
        scale = inputs[6]           # CFG guidance scale
        seed = int(inputs[7])       # -1 requests a random seed
        steps = int(inputs[8])
        H = int(inputs[9])
        W = int(inputs[10])
        id_scale = inputs[11]       # strength of the identity conditioning
        num_zero = int(inputs[12])  # number of zeroed ID tokens
        ortho = inputs[13]          # orthogonal projection mode: 'v1', 'v2', or off

        # Draw a random seed when the caller passes -1
        if seed == -1:
            seed = torch.Generator(device="cpu").seed()

        # Configure the attention processors: NUM_ZERO controls how many ID
        # tokens are zeroed out, while ORTHO / ORTHO_v2 select PuLID's
        # orthogonal feature-projection variants; any other ortho value
        # disables both.
        self.attention.NUM_ZERO = num_zero
        if ortho == 'v2':
            self.attention.ORTHO = False
            self.attention.ORTHO_v2 = True
        elif ortho == 'v1':
            self.attention.ORTHO = True
            self.attention.ORTHO_v2 = False
        else:
            self.attention.ORTHO = False
            self.attention.ORTHO_v2 = False

        # Build ID embeddings from the provided face image(s)
        if id_image is not None:
            id_image = resize_numpy_image_long(id_image, 1024)
            supp_id_image_list = [
                resize_numpy_image_long(supp_id_image, 1024)
                for supp_id_image in supp_images
                if supp_id_image is not None
            ]
            id_image_list = [id_image] + supp_id_image_list
            uncond_id_embedding, id_embedding = self.pipeline.get_id_embedding(id_image_list)
        else:
            uncond_id_embedding = None
            id_embedding = None

        # Generate a single image (batch size 1) at the requested resolution
        img = self.pipeline.inference(
            prompt, (1, H, W), neg_prompt, id_embedding, uncond_id_embedding, id_scale, scale, steps, seed
        )[0]

        # Serialize images as nested lists so the response is JSON-compatible
        return {
            "image": np.array(img).tolist(),
            "seed": str(seed),
            "debug_images": [np.array(debug_img).tolist() for debug_img in self.pipeline.debug_img_list],
        }
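

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the endpoint contract):
# builds the 14-element positional payload documented in __call__ above. The
# face image here is a synthetic placeholder; a real request would carry
# actual RGB face photos as numpy arrays in the serialized request body.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler()
    face = np.zeros((512, 512, 3), dtype=np.uint8)  # placeholder ID image
    payload = {
        "inputs": [
            face,               # 0: main ID image
            None, None, None,   # 1-3: supplementary ID images (optional)
            "portrait, cinematic lighting, photo",  # 4: prompt
            "lowres, blurry, deformed face",        # 5: negative prompt
            7.0,                # 6: CFG guidance scale
            -1,                 # 7: seed (-1 -> random)
            25,                 # 8: sampling steps
            1024,               # 9: height
            768,                # 10: width
            0.8,                # 11: ID conditioning strength
            20,                 # 12: number of zeroed ID tokens
            "v2",               # 13: orthogonal mode ('v1', 'v2', or other = off)
        ]
    }
    result = handler(payload)
    print("seed:", result["seed"], "| debug images:", len(result["debug_images"]))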