import base64
from io import BytesIO
from typing import Any, Dict, List

import torch
from PIL import Image


class EndpointHandler:
    """
    A handler class for processing image data and generating embeddings using a
    specified model and processor.

    Attributes:
        model: The pre-trained model used for generating embeddings.
        processor: The pre-trained processor used to prepare images before model inference.
        device (:obj:`torch.device`): The device (CPU or CUDA) used to run model inference.
        default_batch_size (:obj:`int`): The default batch size for processing images in batches.
    """

    def __init__(self, path: str = "", default_batch_size: int = 4):
        """
        Initializes the EndpointHandler with a specified model path and default batch size.

        Args:
            path (:obj:`str`, optional): Path to the pre-trained model and processor.
            default_batch_size (:obj:`int`, optional): Default batch size for image processing.
        """
        # Imported lazily so the module can be loaded without colpali_engine
        # installed (e.g., for inspection or testing).
        from colpali_engine.models import ColQwen2, ColQwen2Processor

        self.model = ColQwen2.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
        ).eval()
        self.processor = ColQwen2Processor.from_pretrained(path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.default_batch_size = default_batch_size

    def _process_batch(self, images: List[Image.Image]) -> List[List[float]]:
        """
        Processes a batch of images and generates embeddings.

        Args:
            images (:obj:`List[Image.Image]`): List of images to process.

        Return:
            A :obj:`List[List[float]]`. A list of embeddings for each image,
            where each embedding is a list of floats.
        """
        batch_images = self.processor.process_images(images)
        batch_images = {k: v.to(self.device) for k, v in batch_images.items()}
        with torch.no_grad():
            image_embeddings = self.model(**batch_images)
        return image_embeddings.cpu().tolist()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Processes input data containing base64-encoded images, decodes them,
        and generates embeddings.

        Args:
            data (:obj:`Dict[str, Any]`): Includes the input data and the
                parameters for the inference, such as "inputs" containing a
                list of base64-encoded images and an optional "batch_size".

        Return:
            A :obj:`dict` like
            {"embeddings": [[0.6331314444541931, 0.8802216053009033, ..., -0.7866355180740356]]}
            containing:
            - "embeddings": A list of lists, where each inner list holds the
              floats of one image's embedding.
        """
        images_data = data.get("inputs", [])
        batch_size = data.get("batch_size", self.default_batch_size)

        if not images_data:
            return {"error": "No images provided in 'inputs'."}

        # Decode every base64 string into an RGB PIL image up front so a bad
        # payload fails fast, before any model work is done.
        images = []
        for img_data in images_data:
            if isinstance(img_data, str):
                try:
                    image_bytes = base64.b64decode(img_data)
                    image = Image.open(BytesIO(image_bytes)).convert("RGB")
                    images.append(image)
                except Exception as e:
                    return {"error": f"Invalid image data: {e}"}
            else:
                return {"error": "Images should be base64-encoded strings."}

        # Embed the images in fixed-size batches to bound peak memory usage.
        embeddings = []
        for i in range(0, len(images), batch_size):
            batch_images = images[i : i + batch_size]
            batch_embeddings = self._process_batch(batch_images)
            embeddings.extend(batch_embeddings)

        return {"embeddings": embeddings}
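

# Minimal local smoke test (a sketch, not part of the serving contract): it
# builds a tiny in-memory image, base64-encodes it into the payload shape that
# __call__ expects, and prints the result. The checkpoint name
# "vidore/colqwen2-v1.0" is an assumption for illustration; substitute the
# model path the endpoint is actually deployed with.
if __name__ == "__main__":
    handler = EndpointHandler(path="vidore/colqwen2-v1.0")

    # Encode a small solid-color test image as a base64 PNG string.
    buffer = BytesIO()
    Image.new("RGB", (64, 64), color="white").save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")

    result = handler({"inputs": [encoded], "batch_size": 2})
    if "error" in result:
        print(result["error"])
    else:
        print(f"Returned {len(result['embeddings'])} embedding(s)")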