devve1 committed on
Commit
4155b42
1 Parent(s): ff3a7e6

Create splade_encoder.py

Browse files
Files changed (1) hide show
  1. splade_encoder.py +208 -0
splade_encoder.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The following code is adapted from/inspired by the 'neural-cherche' project:
3
+ https://github.com/raphaelsty/neural-cherche
4
+ Specifically, neural-cherche/neural_cherche/models/splade.py
5
+
6
+ MIT License
7
+
8
+ Copyright (c) 2023 Raphael Sourty
9
+
10
+ Permission is hereby granted, free of charge, to any person obtaining a copy
11
+ of this software and associated documentation files (the "Software"), to deal
12
+ in the Software without restriction, including without limitation the rights
13
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14
+ copies of the Software, and to permit persons to whom the Software is
15
+ furnished to do so, subject to the following conditions:
16
+
17
+ The above copyright notice and this permission notice shall be included in all
18
+ copies or substantial portions of the Software.
19
+
20
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26
+ SOFTWARE.
27
+ """
28
+
29
+ import logging
30
+ from typing import Dict, List, Optional
31
+
32
+ import torch
33
+ from scipy.sparse import csr_array, vstack
34
+
35
+ from milvus_model.base import BaseEmbeddingFunction
36
+ from milvus_model.utils import import_transformers, import_scipy, import_torch
37
+
38
+ import_torch()
39
+ import_scipy()
40
+ import_transformers()
41
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
42
+
43
+ logger = logging.getLogger(__name__)
44
+ logger.setLevel(logging.DEBUG)
45
+
46
+
47
class SpladeEmbeddingFunction(BaseEmbeddingFunction):
    """Sparse embedding function that delegates to a SPLADE model wrapper.

    Queries and documents are each prefixed with their own instruction string
    and may each be limited to their own number of top-scoring tokens before
    being handed to the underlying ``_SpladeImplementation``.
    """

    model_name: str

    def __init__(
        self,
        model_name: str = "naver/splade-cocondenser-ensembledistil",
        batch_size: int = 32,
        query_instruction: str = "",
        doc_instruction: str = "",
        device: Optional[str] = "cpu",
        k_tokens_query: Optional[int] = None,
        k_tokens_document: Optional[int] = None,
        **kwargs,
    ):
        """Build the underlying model wrapper and store the encoding options.

        Args:
            model_name: Model identifier forwarded as ``model_name_or_path``.
            batch_size: Number of texts encoded per forward pass.
            query_instruction: String prepended to every query.
            doc_instruction: String prepended to every document.
            device: Device forwarded to the model wrapper.
            k_tokens_query: Top-k token limit for queries; ``None`` keeps
                every non-zero activation.
            k_tokens_document: Top-k token limit for documents; ``None``
                keeps every non-zero activation.
            **kwargs: Extra options merged into the wrapper's configuration.
        """
        self.model_name = model_name

        # Extra keyword arguments win over the three explicit entries,
        # exactly as ``dict(mapping, **kwargs)`` would resolve them.
        self._model_config = {
            "model_name_or_path": model_name,
            "batch_size": batch_size,
            "device": device,
            **kwargs,
        }
        self.model = _SpladeImplementation(**self._model_config)
        self.device = device
        self.k_tokens_query = k_tokens_query
        self.k_tokens_document = k_tokens_document
        self.query_instruction = query_instruction
        self.doc_instruction = doc_instruction

    def __call__(self, texts: List[str]) -> csr_array:
        """Encode ``texts`` as given: no instruction prefix, no top-k limit."""
        return self._encode(texts, None)

    @property
    def dim(self) -> int:
        """Width of the sparse vectors, taken as the tokenizer's length."""
        return len(self.model.tokenizer)

    def encode_queries(self, queries: List[str]) -> csr_array:
        """Encode queries, each prefixed with ``query_instruction``."""
        prefixed = [self.query_instruction + query for query in queries]
        return self._encode(prefixed, self.k_tokens_query)

    def encode_documents(self, documents: List[str]) -> csr_array:
        """Encode documents, each prefixed with ``doc_instruction``."""
        prefixed = [self.doc_instruction + document for document in documents]
        return self._encode(prefixed, self.k_tokens_document)

    def _encode(self, texts: List[str], k_tokens: Optional[int]) -> csr_array:
        """Run the model wrapper on ``texts``.

        ``k_tokens`` may be ``None`` (``__call__`` passes it that way), in
        which case the wrapper keeps every non-zero activation.
        """
        return self.model.forward(texts, k_tokens=k_tokens)

    def _encode_query(self, query: str) -> csr_array:
        """Encode a single query and return the first row of the result."""
        return self._encode([self.query_instruction + query], self.k_tokens_query)[0]

    def _encode_document(self, document: str) -> csr_array:
        """Encode a single document and return the first row of the result."""
        return self._encode([self.doc_instruction + document], self.k_tokens_document)[0]
102
+
103
+
104
class _SpladeImplementation:
    """Tokenize texts, run a masked-LM, and emit SPLADE-style sparse rows.

    For each batch the model logits are passed through ``log1p(relu(x))`` and
    max-reduced over ``dim=1``; the surviving entries are packed into a
    ``csr_array`` with one row per input text.
    """

    def __init__(
        self,
        model_name_or_path: Optional[str] = None,
        device: Optional[str] = None,
        batch_size: int = 32,
        **kwargs,
    ):
        """Load tokenizer and model and move the model to ``device``.

        Args:
            model_name_or_path: Identifier passed to both ``from_pretrained``
                calls.
            device: Device the model and input tensors are moved to.
            batch_size: Number of texts encoded per forward pass.
            **kwargs: Forwarded to the model's ``from_pretrained`` only.
        """
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        # The original referenced ``ORTModelForMaskedLM``, which this module
        # never imports, so construction failed with NameError. Use the
        # masked-LM class the module does import.
        # NOTE(review): if ONNX Runtime inference was intended, import
        # ORTModelForMaskedLM explicitly and swap it back in.
        self.model = AutoModelForMaskedLM.from_pretrained(model_name_or_path, **kwargs)
        self.model.to(self.device)
        self.batch_size = batch_size

        self.relu = torch.nn.ReLU()
        self.relu.to(self.device)
        # Hidden states are requested here although ``_encode`` only reads
        # ``output.logits``.
        self.model.config.output_hidden_states = True

    def _encode(self, texts: List[str]):
        """Tokenize ``texts`` (padded, truncated) and return the model logits."""
        encoded_input = self.tokenizer.batch_encode_plus(
            texts,
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
            add_special_tokens=True,
            padding=True,
        )
        encoded_input = {key: val.to(self.device) for key, val in encoded_input.items()}
        output = self.model(**encoded_input)
        return output.logits

    def _batchify(self, texts: List[str], batch_size: int) -> List[List[str]]:
        """Split ``texts`` into consecutive chunks of at most ``batch_size``."""
        return [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]

    def forward(self, texts: List[str], k_tokens: Optional[int]) -> csr_array:
        """Encode ``texts`` into one sparse matrix with one row per text.

        Args:
            texts: Texts to encode.
            k_tokens: Keep only the ``k_tokens`` highest activations per
                text; ``None`` keeps every non-zero activation.

        Returns:
            A ``csr_array`` stacking the rows of every batch in input order.
            For empty ``texts`` an empty matrix with zero rows is returned.
        """
        if not texts:
            # ``vstack`` rejects an empty block list, so short-circuit.
            # Width uses the tokenizer length.
            # NOTE(review): assumes this matches the model's logit width.
            return csr_array((0, len(self.tokenizer)))

        batch_matrices = []
        with torch.no_grad():
            for batch_texts in self._batchify(texts, self.batch_size):
                logits = self._encode(texts=batch_texts)
                activations = self._get_activation(logits=logits)
                if k_tokens is None:
                    activations["activations"] = torch.nonzero(
                        activations["sparse_activations"]
                    )
                else:
                    activations = self._update_activations(**activations, k_tokens=k_tokens)
                # Append the whole batch matrix; iterating it row by row
                # (as ``list.extend`` did) only adds slicing overhead.
                batch_matrices.append(self._convert_to_csr_array(activations))

        return vstack(batch_matrices).tocsr()

    def _get_activation(self, logits: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return ``log1p(relu(logits))`` max-reduced over ``dim=1``.

        NOTE(review): logits are presumably (batch, sequence, vocabulary);
        padded positions are not masked out before the max - confirm whether
        the attention mask should be applied here.
        """
        return {"sparse_activations": torch.amax(torch.log1p(self.relu(logits)), dim=1)}

    def _update_activations(
        self, sparse_activations: torch.Tensor, k_tokens: int
    ) -> Dict[str, torch.Tensor]:
        """Keep only the top ``k_tokens`` activations of every row.

        Args:
            sparse_activations: 2-D tensor of per-row activation scores.
            k_tokens: Number of entries to keep per row; values larger than
                the row width are clamped to the row width.

        Returns:
            A dict with ``"activations"`` (an ``(rows * k, 2)`` tensor of
            ``[row, column]`` index pairs) and ``"sparse_activations"`` (the
            input with every non-selected entry set to zero).
        """
        # ``torch.topk`` raises when k exceeds the dimension size.
        k = min(k_tokens, sparse_activations.shape[1])
        top_columns = torch.topk(input=sparse_activations, k=k, dim=1).indices

        # Zero every score that is not among the selected columns. The mask
        # lives on the scores' own device rather than ``self.device``.
        keep_mask = torch.zeros_like(sparse_activations).scatter_(
            dim=1, index=top_columns.long(), value=1
        )
        sparse_activations = sparse_activations * keep_mask

        row_ids = (
            torch.arange(top_columns.shape[0], device=top_columns.device)
            .repeat_interleave(top_columns.shape[1])
            .reshape(-1, 1)
        )
        coordinates = torch.cat((row_ids, top_columns.reshape((-1, 1))), dim=1)

        return {
            "activations": coordinates,
            "sparse_activations": sparse_activations,
        }

    def _filter_activations(
        self, activations: torch.Tensor, k_tokens: int, **kwargs
    ) -> torch.Tensor:
        """Return the column indices of the top ``k_tokens`` entries per row."""
        _, activations = torch.topk(input=activations, k=k_tokens, dim=1, **kwargs)
        return activations

    def _convert_to_csr_array(self, activations: Dict) -> csr_array:
        """Pack the selected scores into a ``csr_array``.

        Args:
            activations: Dict with ``"sparse_activations"`` (2-D score
                tensor) and ``"activations"`` (``(n, 2)`` tensor of
                ``[row, column]`` pairs naming the entries to store).

        Returns:
            A ``csr_array`` with the same shape as the score tensor.
        """
        scores = activations["sparse_activations"]
        coordinates = activations["activations"]
        rows, cols = coordinates[:, 0], coordinates[:, 1]

        values = scores[rows, cols].cpu().detach().numpy()
        row_indices = rows.cpu().detach().numpy()
        col_indices = cols.cpu().detach().numpy()
        return csr_array(
            (values.flatten(), (row_indices, col_indices)),
            shape=tuple(scores.shape),
        )