amaye15 committed on
Commit
c26d617
·
1 Parent(s): 8d41aec

Docstring V2

Browse files
Files changed (2) hide show
  1. .gitignore +2 -1
  2. handler.py +25 -73
.gitignore CHANGED
@@ -1 +1,2 @@
1
- *.DS*
 
 
1
+ *.DS*
2
+ *__pycache__*
handler.py CHANGED
@@ -1,63 +1,3 @@
1
- # import torch
2
- # from typing import Dict, Any
3
- # from PIL import Image
4
- # import base64
5
- # from io import BytesIO
6
-
7
-
8
- # class EndpointHandler:
9
- # def __init__(self, path: str = ""):
10
- # # Import your model and processor inside the class
11
- # from colpali_engine.models import ColQwen2, ColQwen2Processor
12
-
13
- # # Load the model and processor
14
- # self.model = ColQwen2.from_pretrained(
15
- # path,
16
- # torch_dtype=torch.bfloat16,
17
- # ).eval()
18
- # self.processor = ColQwen2Processor.from_pretrained(path)
19
-
20
- # # Determine the device
21
- # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
- # self.model.to(self.device)
23
-
24
- # def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
25
- # # Extract images from the input data
26
- # images_data = data.get("inputs", [])
27
-
28
- # if not images_data:
29
- # return {"error": "No images provided in 'inputs'."}
30
-
31
- # # Process images
32
- # images = []
33
- # for img_data in images_data:
34
- # if isinstance(img_data, str):
35
- # try:
36
- # # Assume base64-encoded image
37
- # image_bytes = base64.b64decode(img_data)
38
- # image = Image.open(BytesIO(image_bytes)).convert("RGB")
39
- # images.append(image)
40
- # except Exception as e:
41
- # return {"error": f"Invalid image data: {e}"}
42
- # else:
43
- # return {"error": "Images should be base64-encoded strings."}
44
-
45
- # # Prepare inputs
46
- # batch_images = self.processor.process_images(images)
47
-
48
- # # Move tensors to the device
49
- # batch_images = {k: v.to(self.device) for k, v in batch_images.items()}
50
-
51
- # # Generate embeddings
52
- # with torch.no_grad():
53
- # image_embeddings = self.model(**batch_images)
54
-
55
- # # Convert embeddings to a list
56
- # embeddings_list = image_embeddings.cpu().tolist()
57
-
58
- # return {"embeddings": embeddings_list}
59
-
60
-
61
  import torch
62
  from typing import Dict, Any, List
63
  from PIL import Image
@@ -67,13 +7,17 @@ from io import BytesIO
67
 
68
  class EndpointHandler:
69
  """
70
- A handler class for processing image data, generating embeddings using a specified model and processor.
71
 
72
  Attributes:
73
- model: The pre-trained model used for generating embeddings.
74
- processor: The pre-trained processor used to process images before model inference.
75
- device: The device (CPU or CUDA) used to run model inference.
76
- default_batch_size: The default batch size for processing images in batches.
 
 
 
 
77
  """
78
 
79
  def __init__(self, path: str = "", default_batch_size: int = 4):
@@ -81,8 +25,13 @@ class EndpointHandler:
81
  Initializes the EndpointHandler with a specified model path and default batch size.
82
 
83
  Args:
84
- path (str): Path to the pre-trained model and processor.
85
- default_batch_size (int): Default batch size for image processing.
 
 
 
 
 
86
  """
87
  from colpali_engine.models import ColQwen2, ColQwen2Processor
88
 
@@ -101,10 +50,11 @@ class EndpointHandler:
101
  Processes a batch of images and generates embeddings.
102
 
103
  Args:
104
- images (List[Image.Image]): List of images to process.
 
105
 
106
- Returns:
107
- List[List[float]]: List of embeddings for each image.
108
  """
109
  batch_images = self.processor.process_images(images)
110
  batch_images = {k: v.to(self.device) for k, v in batch_images.items()}
@@ -119,10 +69,12 @@ class EndpointHandler:
119
  Processes input data containing base64-encoded images, decodes them, and generates embeddings.
120
 
121
  Args:
122
- data (Dict[str, Any]): Dictionary containing input images and optional batch size.
 
123
 
124
- Returns:
125
- Dict[str, Any]: Dictionary containing generated embeddings or error messages.
 
126
  """
127
  images_data = data.get("inputs", [])
128
  batch_size = data.get("batch_size", self.default_batch_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  from typing import Dict, Any, List
3
  from PIL import Image
 
7
 
8
  class EndpointHandler:
9
  """
10
+ A handler class for processing image data and generating embeddings using a specified model and processor.
11
 
12
  Attributes:
13
+ model (:obj:`ColQwen2`):
14
+ The pre-trained model used for generating embeddings.
15
+ processor (:obj:`ColQwen2Processor`):
16
+ The pre-trained processor used to process images before model inference.
17
+ device (:obj:`torch.device`):
18
+ The device (CPU or CUDA) used to run model inference.
19
+ default_batch_size (:obj:`int`):
20
+ The default batch size for processing images in batches.
21
  """
22
 
23
  def __init__(self, path: str = "", default_batch_size: int = 4):
 
25
  Initializes the EndpointHandler with a specified model path and default batch size.
26
 
27
  Args:
28
+ path (:obj:`str`, optional):
29
+ Path to the pre-trained model and processor.
30
+ default_batch_size (:obj:`int`, optional):
31
+ Default batch size for image processing.
32
+
33
+ Return:
34
+ None
35
  """
36
  from colpali_engine.models import ColQwen2, ColQwen2Processor
37
 
 
50
  Processes a batch of images and generates embeddings.
51
 
52
  Args:
53
+ images (:obj:`List[Image.Image]`):
54
+ List of images to process.
55
 
56
+ Return:
57
+ A :obj:`List[List[float]]`. A list of embeddings for each image, where each embedding is a list of floats.
58
  """
59
  batch_images = self.processor.process_images(images)
60
  batch_images = {k: v.to(self.device) for k, v in batch_images.items()}
 
69
  Processes input data containing base64-encoded images, decodes them, and generates embeddings.
70
 
71
  Args:
72
+ data (:obj:`Dict[str, Any]`):
73
+ Includes the input data and the parameters for the inference, such as "inputs" containing a list of base64-encoded images and an optional "batch_size".
74
 
75
+ Return:
76
+ A :obj:`dict`. The object returned should be a dict like {"embeddings": [[0.6331314444541931, 0.8802216053009033, ..., -0.7866355180740356]]} containing:
77
+ - "embeddings": A list of lists, where each inner list holds the floats of the embedding for one input image.
78
  """
79
  images_data = data.get("inputs", [])
80
  batch_size = data.get("batch_size", self.default_batch_size)