Update splade_encoder.py
splade_encoder.py  CHANGED  +7 -10
@@ -50,7 +50,6 @@ class SpladeEmbeddingFunction(BaseEmbeddingFunction):
     def __init__(
         self,
         model_name: str = "naver/splade-cocondenser-ensembledistil",
-        batch_size: int = 32,
         query_instruction: str = "",
         doc_instruction: str = "",
         device: Optional[str] = "cpu",
@@ -60,7 +59,7 @@ class SpladeEmbeddingFunction(BaseEmbeddingFunction):
         self.model_name = model_name

         _model_config = dict(
-            {"model_name_or_path": model_name, "batch_size": batch_size, "device": device}
+            {"model_name_or_path": model_name, "device": device}
         )
         self._model_config = _model_config
         self.model = _SpladeImplementation(**self._model_config)
@@ -70,7 +69,7 @@ class SpladeEmbeddingFunction(BaseEmbeddingFunction):
         self.query_instruction = query_instruction
         self.doc_instruction = doc_instruction

-    def __call__(self, texts: List[str]) -> csr_array:
+    def __call__(self, texts: List[str], batch_size: int = 32) -> csr_array:
         return self._encode(texts, None)

     def encode_documents(self, documents: List[str]) -> csr_array:
@@ -78,8 +77,8 @@ class SpladeEmbeddingFunction(BaseEmbeddingFunction):
             [self.doc_instruction + document for document in documents], self.k_tokens_document,
         )

-    def _encode(self, texts: List[str], k_tokens: int) -> csr_array:
-        return self.model.forward(texts, k_tokens=k_tokens)
+    def _encode(self, texts: List[str], k_tokens: int, batch_size: int) -> csr_array:
+        return self.model.forward(texts, k_tokens=k_tokens, batch_size=batch_size)

     def encode_queries(self, queries: List[str]) -> csr_array:
         return self._encode(
@@ -103,8 +102,7 @@ class _SpladeImplementation:
     def __init__(
         self,
         model_name_or_path: Optional[str] = None,
-        device: Optional[str] = None,
-        batch_size: int = 32
+        device: Optional[str] = None
     ):
         self.device = device
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
@@ -119,7 +117,6 @@ class _SpladeImplementation:
             use_io_binding=True,
             session_options=session_options
         )
-        self.batch_size = batch_size

         self.relu = torch.nn.ReLU()
         self.relu.to(self.device)
@@ -141,9 +138,9 @@ class _SpladeImplementation:
     def _batchify(self, texts: List[str], batch_size: int) -> List[List[str]]:
         return [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]

-    def forward(self, texts: List[str], k_tokens: int) -> csr_array:
+    def forward(self, texts: List[str], k_tokens: int, batch_size: int) -> csr_array:
         with torch.no_grad():
-            batched_texts = self._batchify(texts, self.batch_size)
+            batched_texts = self._batchify(texts, batch_size)
             sparse_embs = []
             for batch_texts in batched_texts:
                 logits = self._encode(texts=batch_texts)
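Net effect: batch_size is no longer fixed at construction on either SpladeEmbeddingFunction or _SpladeImplementation; it is chosen per call and threaded through __call__ -> _encode -> forward -> _batchify. A minimal usage sketch of the resulting API follows; the import path and the call-site plumbing outside these hunks are assumptions, not shown in the diff.

# Sketch only: assumes splade_encoder is importable as below and that the
# encode_* call sites outside these hunks forward batch_size to _encode.
from splade_encoder import SpladeEmbeddingFunction

splade = SpladeEmbeddingFunction(
    model_name="naver/splade-cocondenser-ensembledistil",
    device="cpu",  # batch_size is no longer a constructor argument
)

docs = [
    "SPLADE maps text to a sparse vocabulary-sized vector.",
    "The batch size can now differ between indexing and querying.",
]
embs = splade(docs, batch_size=16)  # per-call batch size; defaults to 32

Choosing the batch size per call, rather than baking it into the model config, lets the same encoder instance use large batches for bulk document indexing and small ones for latency-sensitive queries.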
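For reference, _batchify (unchanged here, but now driven by the per-call value) splits the input into contiguous, order-preserving chunks, with the remainder in a smaller final chunk. A standalone check of that behavior:

from typing import List

def _batchify(texts: List[str], batch_size: int) -> List[List[str]]:
    # Slices of at most batch_size items; order preserved, remainder last.
    return [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]

assert _batchify(["a", "b", "c", "d", "e"], 2) == [["a", "b"], ["c", "d"], ["e"]]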