ahmed-masry commited on
Commit
e8ad0b4
·
verified ·
1 Parent(s): 9194fb8

Update processing_colflor.py

Browse files
Files changed (1) hide show
  1. processing_colflor.py +80 -80
processing_colflor.py CHANGED
@@ -1,81 +1,81 @@
1
- from typing import List, Optional, Union
2
-
3
- import torch
4
- from PIL import Image
5
- from transformers import BatchFeature
6
-
7
- from .processing_florence2 import Florence2Processor
8
-
9
- from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
10
-
11
-
12
- class ColFlorProcessor(BaseVisualRetrieverProcessor, Florence2Processor):
13
- """
14
- Processor for ColPali.
15
- """
16
-
17
- def __init__(self, *args, **kwargs):
18
- super().__init__(*args, **kwargs)
19
- self.mock_image = Image.new("RGB", (16, 16), color="black")
20
-
21
- def process_images(
22
- self,
23
- images: List[Image.Image],
24
- ) -> BatchFeature:
25
- """
26
- Process images for ColFlor2.
27
- """
28
- texts_doc = ["<OCR>"] * len(images)
29
- images = [image.convert("RGB") for image in images]
30
-
31
- batch_doc = self(
32
- text=texts_doc,
33
- images=images,
34
- return_tensors="pt",
35
- padding="longest",
36
- )
37
-
38
- new_part = torch.ones((batch_doc['attention_mask'].size()[0], 577)).to(batch_doc['attention_mask'].device)
39
- batch_doc['full_attention_mask'] = torch.cat([new_part, batch_doc['attention_mask']], dim=1)
40
-
41
- return batch_doc
42
-
43
- def process_queries(
44
- self,
45
- queries: List[str],
46
- max_length: int = 50,
47
- suffix: Optional[str] = None,
48
- ) -> BatchFeature:
49
- """
50
- Process queries for ColFlor2.
51
- """
52
- if suffix is None:
53
- suffix = "<pad>" * 10
54
- texts_query: List[str] = []
55
-
56
- for query in queries:
57
- query = f"Question: {query}"
58
- query += suffix # add suffix (pad tokens)
59
- texts_query.append(query)
60
-
61
- batch_query = self.tokenizer(
62
- #images=[self.mock_image] * len(texts_query),
63
- text=texts_query,
64
- return_tensors="pt",
65
- padding="longest",
66
- max_length= max_length + self.image_seq_length,
67
- )
68
-
69
- return batch_query
70
-
71
- def score(
72
- self,
73
- qs: List[torch.Tensor],
74
- ps: List[torch.Tensor],
75
- device: Optional[Union[str, torch.device]] = None,
76
- **kwargs,
77
- ) -> torch.Tensor:
78
- """
79
- Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
80
- """
81
  return self.score_multi_vector(qs, ps, device=device, **kwargs)
 
1
+ from typing import List, Optional, Union
2
+
3
+ import torch
4
+ from PIL import Image
5
+ from transformers import BatchFeature
6
+
7
+ from processing_florence2 import Florence2Processor
8
+
9
+ from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
10
+
11
+
12
+ class ColFlorProcessor(BaseVisualRetrieverProcessor, Florence2Processor):
13
+ """
14
+ Processor for ColPali.
15
+ """
16
+
17
+ def __init__(self, *args, **kwargs):
18
+ super().__init__(*args, **kwargs)
19
+ self.mock_image = Image.new("RGB", (16, 16), color="black")
20
+
21
+ def process_images(
22
+ self,
23
+ images: List[Image.Image],
24
+ ) -> BatchFeature:
25
+ """
26
+ Process images for ColFlor2.
27
+ """
28
+ texts_doc = ["<OCR>"] * len(images)
29
+ images = [image.convert("RGB") for image in images]
30
+
31
+ batch_doc = self(
32
+ text=texts_doc,
33
+ images=images,
34
+ return_tensors="pt",
35
+ padding="longest",
36
+ )
37
+
38
+ new_part = torch.ones((batch_doc['attention_mask'].size()[0], 577)).to(batch_doc['attention_mask'].device)
39
+ batch_doc['full_attention_mask'] = torch.cat([new_part, batch_doc['attention_mask']], dim=1)
40
+
41
+ return batch_doc
42
+
43
+ def process_queries(
44
+ self,
45
+ queries: List[str],
46
+ max_length: int = 50,
47
+ suffix: Optional[str] = None,
48
+ ) -> BatchFeature:
49
+ """
50
+ Process queries for ColFlor2.
51
+ """
52
+ if suffix is None:
53
+ suffix = "<pad>" * 10
54
+ texts_query: List[str] = []
55
+
56
+ for query in queries:
57
+ query = f"Question: {query}"
58
+ query += suffix # add suffix (pad tokens)
59
+ texts_query.append(query)
60
+
61
+ batch_query = self.tokenizer(
62
+ #images=[self.mock_image] * len(texts_query),
63
+ text=texts_query,
64
+ return_tensors="pt",
65
+ padding="longest",
66
+ max_length= max_length + self.image_seq_length,
67
+ )
68
+
69
+ return batch_query
70
+
71
+ def score(
72
+ self,
73
+ qs: List[torch.Tensor],
74
+ ps: List[torch.Tensor],
75
+ device: Optional[Union[str, torch.device]] = None,
76
+ **kwargs,
77
+ ) -> torch.Tensor:
78
+ """
79
+ Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
80
+ """
81
  return self.score_multi_vector(qs, ps, device=device, **kwargs)