devve1 committed
Commit be242ca
1 Parent(s): ee34418

Update optimum_encoder.py

Files changed (1): optimum_encoder.py (+36, -42)
optimum_encoder.py CHANGED

@@ -115,9 +115,8 @@ class OptimumEncoder(BaseModel, Embeddings):
         docs: List[str],
         batch_size: int = 32,
         normalize_embeddings: bool = True,
-        pooling_strategy: str = "mean",
-        convert_to_numpy: bool = False
-    ) -> List[List[float]] | List[Dict[str, np.ndarray]]:
+        pooling_strategy: str = "mean"
+    ) -> List[List[float]]:
         all_embeddings = []
         for i in tqdm(range(0, len(docs), batch_size)):
             batch_docs = docs[i : i + batch_size]
@@ -129,28 +128,23 @@ class OptimumEncoder(BaseModel, Embeddings):
             with self._torch.no_grad():
                 model_output = self._model(**encoded_input)

-            if pooling_strategy == "mean":
-                embeddings = self._mean_pooling(
-                    model_output, encoded_input["attention_mask"]
-                )
-            elif pooling_strategy == "max":
-                embeddings = self._max_pooling(
-                    model_output, encoded_input["attention_mask"]
-                )
-            else:
-                raise ValueError(
-                    "Invalid pooling_strategy. Please use 'mean' or 'max'."
-                )
-
-            if normalize_embeddings:
-                embeddings = self._torch.nn.functional.normalize(embeddings, p=2, dim=1)
-
-            if convert_to_numpy:
-                embeddings = {'text': embeddings.cpu().detach().numpy()}
-            else:
-                embeddings = embeddings.tolist()
+            if pooling_strategy == "mean":
+                embeddings = self._mean_pooling(
+                    model_output, encoded_input["attention_mask"]
+                )
+            elif pooling_strategy == "max":
+                embeddings = self._max_pooling(
+                    model_output, encoded_input["attention_mask"]
+                )
+            else:
+                raise ValueError(
+                    "Invalid pooling_strategy. Please use 'mean' or 'max'."
+                )
+
+            if normalize_embeddings:
+                embeddings = self._torch.nn.functional.normalize(embeddings, p=2, dim=1)

-            all_embeddings.extend(embeddings)
+            all_embeddings.extend(embeddings.tolist())

         return all_embeddings

@@ -159,7 +153,7 @@ class OptimumEncoder(BaseModel, Embeddings):
         docs: str,
         normalize_embeddings: bool = True,
         pooling_strategy: str = "mean"
-    ) -> np.ndarray:
+    ) -> List[float]:
         encoded_input = self._tokenizer(
             docs, padding=True, truncation=True, return_tensors="pt"
         ).to(self.device)
@@ -167,23 +161,23 @@ class OptimumEncoder(BaseModel, Embeddings):
         with self._torch.no_grad():
             model_output = self._model(**encoded_input)

-        if pooling_strategy == "mean":
-            embeddings = self._mean_pooling(
-                model_output, encoded_input["attention_mask"]
-            )
-        elif pooling_strategy == "max":
-            embeddings = self._max_pooling(
-                model_output, encoded_input["attention_mask"]
-            )
-        else:
-            raise ValueError(
-                "Invalid pooling_strategy. Please use 'mean' or 'max'."
-            )
-
-        if normalize_embeddings:
-            embeddings = self._torch.nn.functional.normalize(embeddings, p=2, dim=1)
-
-        return embeddings.cpu().detach().numpy()
+        if pooling_strategy == "mean":
+            embeddings = self._mean_pooling(
+                model_output, encoded_input["attention_mask"]
+            )
+        elif pooling_strategy == "max":
+            embeddings = self._max_pooling(
+                model_output, encoded_input["attention_mask"]
+            )
+        else:
+            raise ValueError(
+                "Invalid pooling_strategy. Please use 'mean' or 'max'."
+            )
+
+        if normalize_embeddings:
+            embeddings = self._torch.nn.functional.normalize(embeddings, p=2, dim=1)
+        print(embeddings)
+        return embeddings.tolist()

     def _mean_pooling(self, model_output, attention_mask):
         token_embeddings = model_output[0]
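
The last hunk cuts off inside `_mean_pooling`, which the diff calls alongside `_max_pooling`. For context, here is a minimal sketch of what attention-mask-aware mean and max pooling typically look like (the common sentence-transformers pattern, written against this class's `self._torch` convention; the actual bodies in this file may differ):

```python
    def _mean_pooling(self, model_output, attention_mask):
        # Token-level embeddings are the first element of the model output.
        token_embeddings = model_output[0]
        # Broadcast the attention mask across the hidden dimension so that
        # padding tokens contribute nothing to the sum.
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # Average over real (non-padding) tokens only.
        return (token_embeddings * mask).sum(1) / self._torch.clamp(mask.sum(1), min=1e-9)

    def _max_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0]
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # Push padding positions far negative so they never win the max.
        token_embeddings[mask == 0] = -1e9
        return self._torch.max(token_embeddings, 1)[0]
```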
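After this commit both entry points return plain Python lists: the batch method yields `List[List[float]]` (the `convert_to_numpy` branch that returned `{'text': np.ndarray}` is gone) and the single-document method yields `List[float]` instead of `np.ndarray`. A hypothetical usage sketch; the constructor argument and the method names are assumptions, since neither appears in the hunks above:

```python
from typing import List

# Hypothetical construction; the real constructor signature is not shown here.
encoder = OptimumEncoder(name="sentence-transformers/all-MiniLM-L6-v2")

# Batch path: assumed to be __call__; now always a list of float lists.
doc_vectors: List[List[float]] = encoder(["first document", "second document"])

# Single-document path: assumed method name; now List[float], not np.ndarray.
query_vector: List[float] = encoder.embed_query("a search query")
```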