Create splade_encoder.py
Browse files- splade_encoder.py +208 -0
splade_encoder.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
The following code is adapted from/inspired by the 'neural-cherche' project:
|
3 |
+
https://github.com/raphaelsty/neural-cherche
|
4 |
+
Specifically, neural-cherche/neural_cherche/models/splade.py
|
5 |
+
|
6 |
+
MIT License
|
7 |
+
|
8 |
+
Copyright (c) 2023 Raphael Sourty
|
9 |
+
|
10 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
11 |
+
of this software and associated documentation files (the "Software"), to deal
|
12 |
+
in the Software without restriction, including without limitation the rights
|
13 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
14 |
+
copies of the Software, and to permit persons to whom the Software is
|
15 |
+
furnished to do so, subject to the following conditions:
|
16 |
+
|
17 |
+
The above copyright notice and this permission notice shall be included in all
|
18 |
+
copies or substantial portions of the Software.
|
19 |
+
|
20 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
21 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
22 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
23 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
24 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
25 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
26 |
+
SOFTWARE.
|
27 |
+
"""
|
28 |
+
|
29 |
+
import logging
|
30 |
+
from typing import Dict, List, Optional
|
31 |
+
|
32 |
+
import torch
|
33 |
+
from scipy.sparse import csr_array, vstack
|
34 |
+
|
35 |
+
from milvus_model.base import BaseEmbeddingFunction
|
36 |
+
from milvus_model.utils import import_transformers, import_scipy, import_torch
|
37 |
+
|
38 |
+
import_torch()
|
39 |
+
import_scipy()
|
40 |
+
import_transformers()
|
41 |
+
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
42 |
+
|
43 |
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
44 |
+
# NOTE(review): setting the level here, at import time, makes this logger emit
# DEBUG records regardless of the level configured on the root logger. Library
# modules normally leave the level to the application; confirm this is intended.
logger.setLevel(logging.DEBUG)
|
45 |
+
|
46 |
+
|
47 |
+
class SpladeEmbeddingFunction(BaseEmbeddingFunction):
    """Sparse embedding function that delegates to a ``_SpladeImplementation``.

    Every public entry point returns whatever ``_SpladeImplementation.forward``
    produces for the given texts. Queries and documents may each be given an
    instruction prefix (prepended by string concatenation) and their own
    ``k_tokens`` limit; ``__call__`` uses neither a prefix nor a limit.
    """

    model_name: str

    def __init__(
        self,
        model_name: str = "naver/splade-cocondenser-ensembledistil",
        batch_size: int = 32,
        query_instruction: str = "",
        doc_instruction: str = "",
        device: Optional[str] = "cpu",
        k_tokens_query: Optional[int] = None,
        k_tokens_document: Optional[int] = None,
        **kwargs,
    ):
        """Build the underlying SPLADE model and remember the encoding options.

        Args:
            model_name: Passed to the implementation as ``model_name_or_path``.
            batch_size: Passed through to the implementation.
            query_instruction: Text prepended to every query before encoding.
            doc_instruction: Text prepended to every document before encoding.
            device: Passed through to the implementation and stored on ``self``.
            k_tokens_query: ``k_tokens`` value used when encoding queries.
            k_tokens_document: ``k_tokens`` value used when encoding documents.
            **kwargs: Extra entries merged into the implementation's config.
        """
        self.model_name = model_name

        # Explicit arguments first, then any caller-supplied extras.
        self._model_config = {
            "model_name_or_path": model_name,
            "batch_size": batch_size,
            "device": device,
            **kwargs,
        }
        self.model = _SpladeImplementation(**self._model_config)

        self.device = device
        self.k_tokens_query = k_tokens_query
        self.k_tokens_document = k_tokens_document
        self.query_instruction = query_instruction
        self.doc_instruction = doc_instruction

    def __call__(self, texts: List[str]) -> csr_array:
        """Encode ``texts`` as-is, with ``k_tokens`` set to ``None``."""
        return self._encode(texts, None)

    def encode_documents(self, documents: List[str]) -> csr_array:
        """Encode documents, each prefixed with ``doc_instruction``."""
        prefixed = [self.doc_instruction + document for document in documents]
        return self._encode(prefixed, self.k_tokens_document)

    def _encode(self, texts: List[str], k_tokens: int) -> csr_array:
        """Forward ``texts`` and ``k_tokens`` to the underlying model."""
        return self.model.forward(texts, k_tokens=k_tokens)

    def encode_queries(self, queries: List[str]) -> csr_array:
        """Encode queries, each prefixed with ``query_instruction``."""
        prefixed = [self.query_instruction + query for query in queries]
        return self._encode(prefixed, self.k_tokens_query)

    @property
    def dim(self) -> int:
        """Return ``len()`` of the underlying model's tokenizer."""
        return len(self.model.tokenizer)

    def _encode_query(self, query: str) -> csr_array:
        """Encode one prefixed query and return element ``[0]`` of the result."""
        encoded = self.model.forward(
            [self.query_instruction + query], k_tokens=self.k_tokens_query
        )
        return encoded[0]

    def _encode_document(self, document: str) -> csr_array:
        """Encode one prefixed document and return element ``[0]`` of the result."""
        encoded = self.model.forward(
            [self.doc_instruction + document], k_tokens=self.k_tokens_document
        )
        return encoded[0]
|
102 |
+
|
103 |
+
|
104 |
+
class _SpladeImplementation:
    """SPLADE-style sparse encoder on top of a masked-language-model head.

    For each text the model logits are passed through ``log1p(relu(.))`` and
    max-pooled over dimension 1, giving one weight per output column. The
    weights are then returned as rows of a ``csr_array``, optionally keeping
    only the ``k_tokens`` largest weights per row.

    NOTE(review): the pooling in ``_get_activation`` does not use the
    tokenizer's attention mask, so padded positions take part in the max.
    Confirm whether mask weighting is wanted.
    """

    def __init__(
        self,
        model_name_or_path: Optional[str] = None,
        device: Optional[str] = None,
        batch_size: int = 32,
        **kwargs,
    ):
        """Load tokenizer and model, and move the model to ``device``.

        Args:
            model_name_or_path: Identifier or path given to ``from_pretrained``.
            device: Device the model and the tokenized inputs are moved to.
            batch_size: Number of texts encoded per model call in ``forward``.
            **kwargs: Forwarded to ``AutoModelForMaskedLM.from_pretrained``.
        """
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        # Use the class this module imports. The previous reference to
        # ``ORTModelForMaskedLM`` was never imported and raised NameError.
        self.model = AutoModelForMaskedLM.from_pretrained(model_name_or_path, **kwargs)
        self.model.to(self.device)
        self.batch_size = batch_size

        self.relu = torch.nn.ReLU()
        self.relu.to(self.device)
        # NOTE(review): only ``output.logits`` is read in ``_encode``; confirm
        # that requesting hidden states is still needed.
        self.model.config.output_hidden_states = True

    def _encode(self, texts: List[str]):
        """Tokenize ``texts`` as one padded batch and return the model logits."""
        # Calling the tokenizer directly is the supported batch entry point;
        # ``batch_encode_plus`` is deprecated in its favour.
        encoded_input = self.tokenizer(
            texts,
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
            add_special_tokens=True,
            padding=True,
        )
        encoded_input = {key: val.to(self.device) for key, val in encoded_input.items()}
        output = self.model(**encoded_input)
        return output.logits

    def _batchify(self, texts: List[str], batch_size: int) -> List[List[str]]:
        """Split ``texts`` into consecutive chunks of at most ``batch_size``."""
        return [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]

    def forward(self, texts: List[str], k_tokens: Optional[int]) -> csr_array:
        """Encode ``texts`` into a sparse array with one row per text.

        Args:
            texts: Texts to encode; processed ``self.batch_size`` at a time.
            k_tokens: When ``None`` every non-zero weight is kept; otherwise
                only the ``k_tokens`` largest weights of each row are kept.

        Returns:
            A ``csr_array`` with ``len(texts)`` rows.
        """
        if len(texts) == 0:
            # ``vstack`` rejects an empty block list. Return an empty array
            # whose width matches the ``dim`` that SpladeEmbeddingFunction
            # reports (the tokenizer length).
            return csr_array((0, len(self.tokenizer)), dtype="float32")

        batch_embs = []
        with torch.no_grad():
            for batch_texts in self._batchify(texts, self.batch_size):
                logits = self._encode(texts=batch_texts)
                activations = self._get_activation(logits=logits)
                if k_tokens is None:
                    # (row, col) coordinates of every non-zero weight.
                    activations["activations"] = torch.nonzero(
                        activations["sparse_activations"]
                    )
                else:
                    activations = self._update_activations(**activations, k_tokens=k_tokens)
                # One block per batch: stacking whole batches yields the same
                # rows as stacking row by row, without iterating the array.
                batch_embs.append(self._convert_to_csr_array(activations))

        return vstack(batch_embs).tocsr()

    def _get_activation(self, logits: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return ``{"sparse_activations": max over dim 1 of log1p(relu(logits))}``."""
        return {"sparse_activations": torch.amax(torch.log1p(self.relu(logits)), dim=1)}

    def _update_activations(
        self, sparse_activations: torch.Tensor, k_tokens: int
    ) -> Dict[str, torch.Tensor]:
        """Keep only the ``k_tokens`` largest weights in every row.

        Args:
            sparse_activations: 2-D tensor of weights, one row per text.
            k_tokens: Number of weights to keep per row; values larger than
                the row width are clamped to the row width.

        Returns:
            A dict with ``"activations"`` (an ``(n, 2)`` tensor of
            ``(row, col)`` coordinates of the kept weights) and
            ``"sparse_activations"`` (the input with all other weights zeroed).
        """
        # ``torch.topk`` raises when k exceeds the dimension size.
        k = min(k_tokens, sparse_activations.shape[1])
        top_cols = torch.topk(input=sparse_activations, k=k, dim=1).indices

        # 0/1 mask built on the same device and dtype as the weights.
        keep_mask = torch.zeros_like(sparse_activations).scatter_(
            dim=1, index=top_cols, value=1
        )
        sparse_activations = sparse_activations * keep_mask

        # Pair every kept column index with the index of the row it came from.
        row_ids = torch.arange(top_cols.shape[0], device=top_cols.device).repeat_interleave(
            top_cols.shape[1]
        )
        activations = torch.stack((row_ids, top_cols.reshape(-1)), dim=1)

        return {
            "activations": activations,
            "sparse_activations": sparse_activations,
        }

    def _filter_activations(
        self, activations: torch.Tensor, k_tokens: int, **kwargs
    ) -> torch.Tensor:
        """Return the indices of the ``k_tokens`` largest entries along dim 1."""
        _, activations = torch.topk(input=activations, k=k_tokens, dim=1, **kwargs)
        return activations

    def _convert_to_csr_array(self, activations: Dict):
        """Build a ``csr_array`` from dense weights and their coordinates.

        Args:
            activations: Dict with ``"sparse_activations"`` (2-D weights) and
                ``"activations"`` (``(n, 2)`` tensor of ``(row, col)`` pairs
                selecting the entries to store).

        Returns:
            A ``csr_array`` with the same shape as the dense weights.
        """
        dense = activations["sparse_activations"]
        coords = activations["activations"]
        rows, cols = coords[:, 0], coords[:, 1]

        values = dense[rows, cols].cpu().detach().numpy()
        row_indices = rows.cpu().detach().numpy()
        col_indices = cols.cpu().detach().numpy()
        return csr_array(
            (values.flatten(), (row_indices, col_indices)),
            shape=tuple(dense.shape),
        )
|