ahmed-masry commited on
Commit
ceb87f6
·
verified ·
1 Parent(s): 3679015

Update modeling_colflor.py

Browse files
Files changed (1) hide show
  1. modeling_colflor.py +94 -95
modeling_colflor.py CHANGED
@@ -1,96 +1,95 @@
1
- from typing import ClassVar
2
- from colpali_engine.models.florence2.colflor.modeling_florence2 import Florence2VisionLanguageModel
3
-
4
- import torch
5
- from torch import nn
6
- from .modeling_florence2 import Florence2ForConditionalGeneration, Florence2VisionLanguageModel
7
- from .configuration_florence2 import Florence2Config
8
-
9
-
10
- class ColFlor2Old(Florence2ForConditionalGeneration):
11
- """
12
- ColFlor2 model implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.
13
- """
14
-
15
- main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related
16
-
17
- def __init__(self, config: Florence2Config, use_cache=False):
18
- super().__init__(config=config)
19
-
20
- self.dim = 128
21
- self.custom_text_proj = nn.Linear(self.config.text_config.d_model, self.dim)
22
- # Now initialize weights properly
23
- self.custom_text_proj.weight.data.normal_(mean=0.0, std=0.02)
24
- self.custom_text_proj.bias.data.zero_()
25
-
26
- self.padding_side = "right"
27
- self.post_init()
28
-
29
- def forward(self, *args, **kwargs) -> torch.Tensor:
30
- # Delete output_hidden_states from kwargs
31
- kwargs.pop("output_hidden_states", None)
32
-
33
- # TO BE DELETED
34
- kwargs['decoder_input_ids'] = kwargs['input_ids']
35
-
36
- # Create Full Attention Mask that includes the image
37
- if 'full_attention_mask' in kwargs:
38
- full_attention_mask = kwargs['full_attention_mask']
39
- del kwargs['full_attention_mask']
40
- else:
41
- full_attention_mask = kwargs['attention_mask']
42
-
43
- outputs = super().forward(*args,
44
- **kwargs) # (batch_size, sequence_length, hidden_size)
45
-
46
- last_hidden_states = outputs['encoder_last_hidden_state'] # (batch_size, sequence_length, hidden_size)
47
-
48
- proj = self.custom_text_proj(last_hidden_states) # (batch_size, sequence_length, dim)
49
- # L2 normalization
50
- proj = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
51
- proj = proj * full_attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
52
-
53
- return proj
54
-
55
-
56
- class ColFlor(Florence2VisionLanguageModel):
57
- """
58
- ColFlor model implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.
59
- """
60
-
61
- main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related
62
-
63
- def __init__(self, config: Florence2Config, use_cache=False):
64
- super().__init__(config=config)
65
-
66
- self.dim = 128
67
- self.custom_text_proj = nn.Linear(self.config.text_config.d_model, self.dim)
68
- # Now initialize weights properly
69
- self.custom_text_proj.weight.data.normal_(mean=0.0, std=0.02)
70
- self.custom_text_proj.bias.data.zero_()
71
-
72
- self.padding_side = "right"
73
- self.post_init()
74
-
75
- def forward(self, *args, **kwargs) -> torch.Tensor:
76
- # Delete output_hidden_states from kwargs
77
- kwargs.pop("output_hidden_states", None)
78
-
79
- # Create Full Attention Mask that includes both the image and text
80
- if 'full_attention_mask' in kwargs:
81
- full_attention_mask = kwargs['full_attention_mask']
82
- del kwargs['full_attention_mask']
83
- else:
84
- full_attention_mask = kwargs['attention_mask']
85
-
86
- outputs = super().forward(*args,
87
- **kwargs) # (batch_size, sequence_length, hidden_size)
88
-
89
- last_hidden_states = outputs['encoder_last_hidden_state'] # (batch_size, sequence_length, hidden_size)
90
-
91
- proj = self.custom_text_proj(last_hidden_states) # (batch_size, sequence_length, dim)
92
- # L2 normalization
93
- proj = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
94
- proj = proj * full_attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
95
-
96
  return proj
 
1
+ from typing import ClassVar
2
+
3
+ import torch
4
+ from torch import nn
5
+ from modeling_florence2 import Florence2ForConditionalGeneration, Florence2VisionLanguageModel
6
+ from configuration_florence2 import Florence2Config
7
+
8
+
9
+ class ColFlor2Old(Florence2ForConditionalGeneration):
10
+ """
11
+ ColFlor2 model implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.
12
+ """
13
+
14
+ main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related
15
+
16
+ def __init__(self, config: Florence2Config, use_cache=False):
17
+ super().__init__(config=config)
18
+
19
+ self.dim = 128
20
+ self.custom_text_proj = nn.Linear(self.config.text_config.d_model, self.dim)
21
+ # Now initialize weights properly
22
+ self.custom_text_proj.weight.data.normal_(mean=0.0, std=0.02)
23
+ self.custom_text_proj.bias.data.zero_()
24
+
25
+ self.padding_side = "right"
26
+ self.post_init()
27
+
28
+ def forward(self, *args, **kwargs) -> torch.Tensor:
29
+ # Delete output_hidden_states from kwargs
30
+ kwargs.pop("output_hidden_states", None)
31
+
32
+ # TO BE DELETED
33
+ kwargs['decoder_input_ids'] = kwargs['input_ids']
34
+
35
+ # Create Full Attention Mask that includes the image
36
+ if 'full_attention_mask' in kwargs:
37
+ full_attention_mask = kwargs['full_attention_mask']
38
+ del kwargs['full_attention_mask']
39
+ else:
40
+ full_attention_mask = kwargs['attention_mask']
41
+
42
+ outputs = super().forward(*args,
43
+ **kwargs) # (batch_size, sequence_length, hidden_size)
44
+
45
+ last_hidden_states = outputs['encoder_last_hidden_state'] # (batch_size, sequence_length, hidden_size)
46
+
47
+ proj = self.custom_text_proj(last_hidden_states) # (batch_size, sequence_length, dim)
48
+ # L2 normalization
49
+ proj = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
50
+ proj = proj * full_attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
51
+
52
+ return proj
53
+
54
+
55
+ class ColFlor(Florence2VisionLanguageModel):
56
+ """
57
+ ColFlor model implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.
58
+ """
59
+
60
+ main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related
61
+
62
+ def __init__(self, config: Florence2Config, use_cache=False):
63
+ super().__init__(config=config)
64
+
65
+ self.dim = 128
66
+ self.custom_text_proj = nn.Linear(self.config.text_config.d_model, self.dim)
67
+ # Now initialize weights properly
68
+ self.custom_text_proj.weight.data.normal_(mean=0.0, std=0.02)
69
+ self.custom_text_proj.bias.data.zero_()
70
+
71
+ self.padding_side = "right"
72
+ self.post_init()
73
+
74
+ def forward(self, *args, **kwargs) -> torch.Tensor:
75
+ # Delete output_hidden_states from kwargs
76
+ kwargs.pop("output_hidden_states", None)
77
+
78
+ # Create Full Attention Mask that includes both the image and text
79
+ if 'full_attention_mask' in kwargs:
80
+ full_attention_mask = kwargs['full_attention_mask']
81
+ del kwargs['full_attention_mask']
82
+ else:
83
+ full_attention_mask = kwargs['attention_mask']
84
+
85
+ outputs = super().forward(*args,
86
+ **kwargs) # (batch_size, sequence_length, hidden_size)
87
+
88
+ last_hidden_states = outputs['encoder_last_hidden_state'] # (batch_size, sequence_length, hidden_size)
89
+
90
+ proj = self.custom_text_proj(last_hidden_states) # (batch_size, sequence_length, dim)
91
+ # L2 normalization
92
+ proj = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
93
+ proj = proj * full_attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
94
+
 
95
  return proj