adorabook commited on
Commit
699428f
·
verified ·
1 Parent(s): f6a08b5

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +82 -0
handler.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ from transformers import AutoTokenizer
5
+ from pulid.pipeline_v1_1 import PuLIDPipeline
6
+ from pulid.utils import resize_numpy_image_long
7
+ from pulid import attention_processor as attention
8
+
9
+ torch.set_grad_enabled(False)
10
+
11
+ # Initialize the model and tokenizer
12
+ class ModelHandler:
13
+ def __init__(self):
14
+ # Set default model parameters
15
+ self.pipeline = PuLIDPipeline(sdxl_repo='RunDiffusion/Juggernaut-XL-v9', sampler='dpmpp_sde')
16
+ self.default_cfg = 7.0
17
+ self.default_steps = 25
18
+ self.attention = attention
19
+ self.pipeline.debug_img_list = []
20
+
21
+ def preprocess(self, input_data):
22
+ # Extracts image and parameters from the input data
23
+ id_image = input_data[0]
24
+ supp_images = input_data[1:4]
25
+ prompt = input_data[4]
26
+ neg_prompt = input_data[5]
27
+ scale = input_data[6]
28
+ seed = int(input_data[7])
29
+ steps = int(input_data[8])
30
+ H = int(input_data[9])
31
+ W = int(input_data[10])
32
+ id_scale = input_data[11]
33
+ num_zero = int(input_data[12])
34
+ ortho = input_data[13]
35
+
36
+ # Set seed if needed
37
+ if seed == -1:
38
+ seed = torch.Generator(device="cpu").seed()
39
+
40
+ # Handle the ortho settings
41
+ if ortho == 'v2':
42
+ self.attention.ORTHO = False
43
+ self.attention.ORTHO_v2 = True
44
+ elif ortho == 'v1':
45
+ self.attention.ORTHO = True
46
+ self.attention.ORTHO_v2 = False
47
+ else:
48
+ self.attention.ORTHO = False
49
+ self.attention.ORTHO_v2 = False
50
+
51
+ # Process the images
52
+ if id_image is not None:
53
+ id_image = resize_numpy_image_long(id_image, 1024)
54
+ supp_id_image_list = [
55
+ resize_numpy_image_long(supp_id_image, 1024) for supp_id_image in supp_images if supp_id_image is not None
56
+ ]
57
+ id_image_list = [id_image] + supp_id_image_list
58
+ uncond_id_embedding, id_embedding = self.pipeline.get_id_embedding(id_image_list)
59
+ else:
60
+ uncond_id_embedding = None
61
+ id_embedding = None
62
+
63
+ return (prompt, neg_prompt, scale, seed, steps, H, W, id_scale, num_zero, uncond_id_embedding, id_embedding)
64
+
65
+ def predict(self, input_data):
66
+ # Preprocess the input data
67
+ (prompt, neg_prompt, scale, seed, steps, H, W, id_scale, num_zero, uncond_id_embedding, id_embedding) = self.preprocess(input_data)
68
+
69
+ # Run the inference pipeline
70
+ img = self.pipeline.inference(
71
+ prompt, (1, H, W), neg_prompt, id_embedding, uncond_id_embedding, id_scale, scale, steps, seed
72
+ )[0]
73
+
74
+ return np.array(img), str(seed), self.pipeline.debug_img_list
75
+
76
+
77
+ # Instantiate the model handler
78
+ handler = ModelHandler()
79
+
80
+ def handler_function(input_data):
81
+ # Predict using the handler
82
+ return handler.predict(input_data)