devve1 committed
Commit 1e0e61e
1 Parent(s): b3e09e9

Update splade_encoder.py

Files changed (1)
  1. splade_encoder.py +8 -11
splade_encoder.py CHANGED
@@ -50,7 +50,6 @@ class SpladeEmbeddingFunction(BaseEmbeddingFunction):
     def __init__(
         self,
         model_name: str = "naver/splade-cocondenser-ensembledistil",
-        batch_size: int = 32,
         query_instruction: str = "",
         doc_instruction: str = "",
         device: Optional[str] = "cpu",
@@ -60,7 +59,7 @@ class SpladeEmbeddingFunction(BaseEmbeddingFunction):
         self.model_name = model_name
 
         _model_config = dict(
-            {"model_name_or_path": model_name, "batch_size": batch_size, "device": device}
+            {"model_name_or_path": model_name, "device": device}
         )
         self._model_config = _model_config
         self.model = _SpladeImplementation(**self._model_config)
@@ -70,7 +69,7 @@ class SpladeEmbeddingFunction(BaseEmbeddingFunction):
         self.query_instruction = query_instruction
         self.doc_instruction = doc_instruction
 
-    def __call__(self, texts: List[str]) -> csr_array:
-        return self._encode(texts, None)
+    def __call__(self, texts: List[str], batch_size: int = 32) -> csr_array:
+        return self._encode(texts, None, batch_size)
 
     def encode_documents(self, documents: List[str]) -> csr_array:
@@ -78,8 +77,8 @@ class SpladeEmbeddingFunction(BaseEmbeddingFunction):
             [self.doc_instruction + document for document in documents], self.k_tokens_document,
         )
 
-    def _encode(self, texts: List[str], k_tokens: int) -> csr_array:
-        return self.model.forward(texts, k_tokens=k_tokens)
+    def _encode(self, texts: List[str], k_tokens: int, batch_size: int = 32) -> csr_array:
+        return self.model.forward(texts, k_tokens=k_tokens, batch_size=batch_size)
 
     def encode_queries(self, queries: List[str]) -> csr_array:
         return self._encode(
@@ -103,8 +102,7 @@ class _SpladeImplementation:
     def __init__(
         self,
         model_name_or_path: Optional[str] = None,
-        device: Optional[str] = None,
-        batch_size: int = 32
+        device: Optional[str] = None
     ):
         self.device = device
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
@@ -119,7 +117,6 @@ class _SpladeImplementation:
             use_io_binding=True,
             session_options=session_options
         )
-        self.batch_size = batch_size
 
         self.relu = torch.nn.ReLU()
         self.relu.to(self.device)
@@ -141,9 +138,9 @@ class _SpladeImplementation:
     def _batchify(self, texts: List[str], batch_size: int) -> List[List[str]]:
         return [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
 
-    def forward(self, texts: List[str], k_tokens: int) -> csr_array:
+    def forward(self, texts: List[str], k_tokens: int, batch_size: int) -> csr_array:
         with torch.no_grad():
-            batched_texts = self._batchify(texts, self.batch_size)
+            batched_texts = self._batchify(texts, batch_size)
             sparse_embs = []
             for batch_texts in batched_texts:
                 logits = self._encode(texts=batch_texts)
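In effect, the commit turns batch size into a per-call knob instead of a construction-time setting. A minimal usage sketch of the updated API, assuming the class is imported from this repo's splade_encoder module (the module name is taken from the filename; the example texts and the batch_size value are illustrative):

    from splade_encoder import SpladeEmbeddingFunction  # module name assumed from the filename

    # After this commit, no batch_size is given at construction time.
    splade = SpladeEmbeddingFunction(
        model_name="naver/splade-cocondenser-ensembledistil",
        device="cpu",
    )

    docs = [
        "SPLADE expands text into sparse lexical vectors.",
        "Batch size is now chosen per call.",
    ]

    # batch_size is now an argument of __call__ (default 32) and is threaded
    # through _encode into _SpladeImplementation.forward.
    embs = splade(docs, batch_size=16)  # returns a scipy.sparse.csr_array

Keeping the encoder stateless with respect to batching lets one instance serve workloads with different throughput needs, e.g. large batches for offline document indexing and small ones for interactive queries.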
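For reference, the helper that consumes the per-call value is a plain list slicer; a standalone sketch mirroring _SpladeImplementation._batchify from the diff above:

    from typing import List

    def batchify(texts: List[str], batch_size: int) -> List[List[str]]:
        # Consecutive chunks of at most batch_size items; the final chunk
        # may be shorter when len(texts) is not a multiple of batch_size.
        return [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]

    print(batchify(["a", "b", "c", "d", "e"], 2))  # [['a', 'b'], ['c', 'd'], ['e']]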