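"""Hugging Face Inference Endpoints handler for PuLID (v1.1, SDXL).

Loads PuLIDPipeline once at startup and serves requests whose payload is a
14-element positional list (ID images, prompts, and sampling parameters),
returning the generated image and debug images as JSON-serializable lists.
"""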
import torch
import numpy as np
from typing import Any, Dict, Optional
from pulid.pipeline_v1_1 import PuLIDPipeline
from pulid.utils import resize_numpy_image_long
from pulid import attention_processor as attention

# Disable gradients for inference
torch.set_grad_enabled(False)

class EndpointHandler:
    def __init__(self, model_dir: Optional[str] = None):
        """
        Initializes the model and necessary components.

        Args:
            model_dir (Optional[str]): Directory passed in by the serving
                runtime; unused here, since PuLIDPipeline downloads its
                weights from the Hugging Face Hub.
        """
        self.pipeline = PuLIDPipeline(sdxl_repo='RunDiffusion/Juggernaut-XL-v9', sampler='dpmpp_sde')
        # Reference defaults; each request supplies its own cfg scale and steps
        self.default_cfg = 7.0
        self.default_steps = 25
        self.attention = attention
        self.pipeline.debug_img_list = []

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Handles inference requests.

        Args:
            data (Dict[str, Any]): Input data for inference.

        Returns:
            Dict[str, Any]: Results containing the generated image and debug information.
        """
        # Parse the positional input payload (14 elements, unpacked below)
        inputs = data.get("inputs", [])
        if not inputs or len(inputs) < 14:
            raise ValueError("Invalid inputs. Expected at least 14 elements in the input list.")

        id_image = inputs[0]        # main ID reference image (numpy array or None)
        supp_images = inputs[1:4]   # up to three supplementary ID images
        prompt = inputs[4]
        neg_prompt = inputs[5]
        scale = inputs[6]           # CFG guidance scale
        seed = int(inputs[7])       # -1 requests a random seed
        steps = int(inputs[8])
        H = int(inputs[9])
        W = int(inputs[10])
        id_scale = inputs[11]       # strength of the ID embedding
        num_zero = int(inputs[12])  # number of zero tokens for the ID attention
        ortho = inputs[13]          # orthogonal attention mode: 'v1', 'v2', or off

        # A seed of -1 means "draw a random seed"
        if seed == -1:
            seed = torch.Generator(device="cpu").seed()

        # Apply PuLID's global attention settings: zero-token count and
        # orthogonal projection variant
        self.attention.NUM_ZERO = num_zero
        if ortho == 'v2':
            self.attention.ORTHO = False
            self.attention.ORTHO_v2 = True
        elif ortho == 'v1':
            self.attention.ORTHO = True
            self.attention.ORTHO_v2 = False
        else:
            self.attention.ORTHO = False
            self.attention.ORTHO_v2 = False

        # Reset debug images from any previous request, then build ID embeddings
        self.pipeline.debug_img_list = []
        if id_image is not None:
            id_image = resize_numpy_image_long(id_image, 1024)
            supp_id_image_list = [
                resize_numpy_image_long(supp_id_image, 1024) for supp_id_image in supp_images if supp_id_image is not None
            ]
            id_image_list = [id_image] + supp_id_image_list
            uncond_id_embedding, id_embedding = self.pipeline.get_id_embedding(id_image_list)
        else:
            uncond_id_embedding = None
            id_embedding = None

        # Generate the image; the size tuple is (num_samples, height, width)
        img = self.pipeline.inference(
            prompt, (1, H, W), neg_prompt, id_embedding, uncond_id_embedding, id_scale, scale, steps, seed
        )[0]

        # Return JSON-serializable results (numpy arrays converted to nested lists)
        return {
            "image": np.array(img).tolist(),
            "seed": str(seed),
            "debug_images": [np.array(debug_img).tolist() for debug_img in self.pipeline.debug_img_list],
        }
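
# A minimal local smoke test, assuming a GPU machine with the PuLID
# dependencies installed. It exercises the payload format without an ID image
# (plain SDXL generation); the prompt strings and parameter values below are
# illustrative placeholders only. Pass a face photo as inputs[0] (numpy array)
# to enable ID conditioning.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": [
            None,              # id_image
            None, None, None,  # supp_images
            "portrait, color, cinematic",   # prompt
            "lowres, blurry, bad anatomy",  # neg_prompt
            7.0,               # scale (CFG)
            -1,                # seed (-1 -> random)
            25,                # steps
            1024,              # H
            768,               # W
            0.8,               # id_scale
            20,                # num_zero
            "v2",              # ortho mode
        ]
    }
    result = handler(payload)
    print("seed used:", result["seed"])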