amaye15 committed
Commit 8d41aec · 1 Parent(s): aea7238

Docstring added

Files changed (1)
  1. handler.py +35 -12
handler.py CHANGED
@@ -66,50 +66,74 @@ from io import BytesIO
 
 
 class EndpointHandler:
+    """
+    A handler class for processing image data, generating embeddings using a specified model and processor.
+
+    Attributes:
+        model: The pre-trained model used for generating embeddings.
+        processor: The pre-trained processor used to process images before model inference.
+        device: The device (CPU or CUDA) used to run model inference.
+        default_batch_size: The default batch size for processing images in batches.
+    """
+
     def __init__(self, path: str = "", default_batch_size: int = 4):
-        # Import your model and processor inside the class
+        """
+        Initializes the EndpointHandler with a specified model path and default batch size.
+
+        Args:
+            path (str): Path to the pre-trained model and processor.
+            default_batch_size (int): Default batch size for image processing.
+        """
         from colpali_engine.models import ColQwen2, ColQwen2Processor
 
-        # Load the model and processor
         self.model = ColQwen2.from_pretrained(
             path,
             torch_dtype=torch.bfloat16,
         ).eval()
         self.processor = ColQwen2Processor.from_pretrained(path)
 
-        # Determine the device
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.model.to(self.device)
-
-        # Set default batch size
         self.default_batch_size = default_batch_size
 
     def _process_batch(self, images: List[Image.Image]) -> List[List[float]]:
-        # Prepare inputs for a batch
+        """
+        Processes a batch of images and generates embeddings.
+
+        Args:
+            images (List[Image.Image]): List of images to process.
+
+        Returns:
+            List[List[float]]: List of embeddings for each image.
+        """
         batch_images = self.processor.process_images(images)
         batch_images = {k: v.to(self.device) for k, v in batch_images.items()}
 
-        # Generate embeddings
         with torch.no_grad():
             image_embeddings = self.model(**batch_images)
 
-        # Convert embeddings to list format
         return image_embeddings.cpu().tolist()
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        # Extract images from the input data
+        """
+        Processes input data containing base64-encoded images, decodes them, and generates embeddings.
+
+        Args:
+            data (Dict[str, Any]): Dictionary containing input images and optional batch size.
+
+        Returns:
+            Dict[str, Any]: Dictionary containing generated embeddings or error messages.
+        """
         images_data = data.get("inputs", [])
         batch_size = data.get("batch_size", self.default_batch_size)
 
         if not images_data:
             return {"error": "No images provided in 'inputs'."}
 
-        # Decode and validate images
         images = []
         for img_data in images_data:
             if isinstance(img_data, str):
                 try:
-                    # Assume base64-encoded image
                     image_bytes = base64.b64decode(img_data)
                     image = Image.open(BytesIO(image_bytes)).convert("RGB")
                     images.append(image)
@@ -118,7 +142,6 @@ class EndpointHandler:
             else:
                 return {"error": "Images should be base64-encoded strings."}
 
-        # Process in batches with the specified or default batch size
         embeddings = []
         for i in range(0, len(images), batch_size):
             batch_images = images[i : i + batch_size]
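
For context, here is a minimal sketch of how a client might exercise the __call__ contract shown in the diff. It is not part of the commit: the checkpoint path "vidore/colqwen2-v0.1", the filename "page.png", and the "embeddings" response key are placeholders (the final return statement of __call__ falls outside the visible hunks, so the actual key is an assumption).

import base64

from handler import EndpointHandler

# Placeholder checkpoint: any local or Hub path holding ColQwen2 weights.
handler = EndpointHandler(path="vidore/colqwen2-v0.1", default_batch_size=4)

# Encode a local image file (placeholder name) as a base64 string,
# matching the "base64-encoded strings" contract enforced in __call__.
with open("page.png", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

# "batch_size" is optional; __call__ falls back to default_batch_size.
result = handler({"inputs": [encoded], "batch_size": 2})

if "error" in result:
    print(result["error"])
else:
    # Assumed response key; see the caveat in the lead-in above.
    print(len(result["embeddings"]))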
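
One detail worth noting when consuming the output: ColQwen2 is a late-interaction (ColPali-style) model, so self.model(**batch_images) returns a matrix of token-level vectors per image rather than a single pooled vector. image_embeddings.cpu().tolist() therefore yields one nested list per image, one level deeper than the List[List[float]] annotation on _process_batch suggests.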