amaye15 committed on
Commit · 8d41aec
1 Parent(s): aea7238

Docstring added

handler.py +35 -12
handler.py CHANGED
@@ -66,50 +66,74 @@ from io import BytesIO


 class EndpointHandler:
+    """
+    A handler class for processing image data, generating embeddings using a specified model and processor.
+
+    Attributes:
+        model: The pre-trained model used for generating embeddings.
+        processor: The pre-trained processor used to process images before model inference.
+        device: The device (CPU or CUDA) used to run model inference.
+        default_batch_size: The default batch size for processing images in batches.
+    """
+
     def __init__(self, path: str = "", default_batch_size: int = 4):
-
+        """
+        Initializes the EndpointHandler with a specified model path and default batch size.
+
+        Args:
+            path (str): Path to the pre-trained model and processor.
+            default_batch_size (int): Default batch size for image processing.
+        """
         from colpali_engine.models import ColQwen2, ColQwen2Processor

-        # Load the model and processor
         self.model = ColQwen2.from_pretrained(
             path,
             torch_dtype=torch.bfloat16,
         ).eval()
         self.processor = ColQwen2Processor.from_pretrained(path)

-        # Determine the device
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.model.to(self.device)
-
-        # Set default batch size
         self.default_batch_size = default_batch_size

     def _process_batch(self, images: List[Image.Image]) -> List[List[float]]:
-
+        """
+        Processes a batch of images and generates embeddings.
+
+        Args:
+            images (List[Image.Image]): List of images to process.
+
+        Returns:
+            List[List[float]]: List of embeddings for each image.
+        """
         batch_images = self.processor.process_images(images)
         batch_images = {k: v.to(self.device) for k, v in batch_images.items()}

-        # Generate embeddings
         with torch.no_grad():
             image_embeddings = self.model(**batch_images)

-        # Convert embeddings to list format
         return image_embeddings.cpu().tolist()

     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-
+        """
+        Processes input data containing base64-encoded images, decodes them, and generates embeddings.
+
+        Args:
+            data (Dict[str, Any]): Dictionary containing input images and optional batch size.
+
+        Returns:
+            Dict[str, Any]: Dictionary containing generated embeddings or error messages.
+        """
         images_data = data.get("inputs", [])
         batch_size = data.get("batch_size", self.default_batch_size)

         if not images_data:
             return {"error": "No images provided in 'inputs'."}

-        # Decode and validate images
         images = []
         for img_data in images_data:
             if isinstance(img_data, str):
                 try:
-                    # Assume base64-encoded image
                     image_bytes = base64.b64decode(img_data)
                     image = Image.open(BytesIO(image_bytes)).convert("RGB")
                     images.append(image)
@@ -118,7 +142,6 @@ class EndpointHandler:
             else:
                 return {"error": "Images should be base64-encoded strings."}

-        # Process in batches with the specified or default batch size
         embeddings = []
         for i in range(0, len(images), batch_size):
             batch_images = images[i : i + batch_size]