Update optimum_encoder.py
optimum_encoder.py (+36 -42)
```diff
@@ -115,9 +115,8 @@ class OptimumEncoder(BaseModel, Embeddings):
         docs: List[str],
         batch_size: int = 32,
         normalize_embeddings: bool = True,
-        pooling_strategy: str = "mean"
-
-    ) -> List[List[float]] | List[Dict[str, np.ndarray]]:
+        pooling_strategy: str = "mean"
+    ) -> List[List[float]]:
         all_embeddings = []
         for i in tqdm(range(0, len(docs), batch_size)):
             batch_docs = docs[i : i + batch_size]
```
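This hunk adds `pooling_strategy` to the batch method and narrows the return annotation from `List[List[float]] | List[Dict[str, np.ndarray]]` to plain `List[List[float]]`, matching the removal of the `convert_to_numpy` branch in the next hunk. A minimal sketch of how the updated signature presumably sits in the class (the method name is cut off by the hunk header; `embed_documents` is an assumption based on the LangChain `Embeddings` base class, as are the imports):

```python
from typing import List

from pydantic import BaseModel                    # assumed origin of BaseModel
from langchain_core.embeddings import Embeddings  # assumed origin of Embeddings


class OptimumEncoder(BaseModel, Embeddings):
    def embed_documents(  # hypothetical name; the diff shows only the parameters
        self,
        docs: List[str],
        batch_size: int = 32,
        normalize_embeddings: bool = True,
        pooling_strategy: str = "mean",
    ) -> List[List[float]]:
        ...
```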
```diff
@@ -129,28 +128,23 @@ class OptimumEncoder(BaseModel, Embeddings):
         with self._torch.no_grad():
             model_output = self._model(**encoded_input)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            if convert_to_numpy:
-                embeddings = {'text': embeddings.cpu().detach().numpy()}
-            else:
-                embeddings = embeddings.tolist()
+            if pooling_strategy == "mean":
+                embeddings = self._mean_pooling(
+                    model_output, encoded_input["attention_mask"]
+                )
+            elif pooling_strategy == "max":
+                embeddings = self._max_pooling(
+                    model_output, encoded_input["attention_mask"]
+                )
+            else:
+                raise ValueError(
+                    "Invalid pooling_strategy. Please use 'mean' or 'max'."
+                )
+
+            if normalize_embeddings:
+                embeddings = self._torch.nn.functional.normalize(embeddings, p=2, dim=1)
 
-
+            all_embeddings.extend(embeddings.tolist())
 
         return all_embeddings
 
```
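Both branches dispatch to pooling helpers whose bodies fall outside the diff. For reference, the canonical masked mean pooling over transformer outputs, consistent with the `token_embeddings = model_output[0]` line visible in the final hunk, looks roughly like this; a sketch only, written standalone where the file goes through `self._torch`:

```python
import torch


def mean_pooling(model_output, attention_mask):
    # model_output[0]: token-level hidden states, shape (batch, seq_len, hidden).
    token_embeddings = model_output[0]
    # Broadcast the attention mask across the hidden dimension so that
    # padding tokens contribute nothing to the average.
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * mask, dim=1)
    counts = torch.clamp(mask.sum(dim=1), min=1e-9)  # guard against all-pad rows
    return summed / counts
```

The follow-up `normalize(embeddings, p=2, dim=1)` step L2-normalizes each row, so downstream dot products behave as cosine similarities.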
```diff
@@ -159,7 +153,7 @@ class OptimumEncoder(BaseModel, Embeddings):
         docs: str,
         normalize_embeddings: bool = True,
         pooling_strategy: str = "mean"
-    ) ->
+    ) -> List[float]:
         encoded_input = self._tokenizer(
             docs, padding=True, truncation=True, return_tensors="pt"
         ).to(self.device)
```
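Same treatment for the single-document method: the dangling `) ->` annotation becomes a concrete `List[float]`, one vector per input string. That matches LangChain's `embed_query` convention of `str -> List[float]`, though the method name itself is cut off above the hunk.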
```diff
@@ -167,23 +161,23 @@ class OptimumEncoder(BaseModel, Embeddings):
         with self._torch.no_grad():
             model_output = self._model(**encoded_input)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        return embeddings.
+        if pooling_strategy == "mean":
+            embeddings = self._mean_pooling(
+                model_output, encoded_input["attention_mask"]
+            )
+        elif pooling_strategy == "max":
+            embeddings = self._max_pooling(
+                model_output, encoded_input["attention_mask"]
+            )
+        else:
+            raise ValueError(
+                "Invalid pooling_strategy. Please use 'mean' or 'max'."
+            )
+
+        if normalize_embeddings:
+            embeddings = self._torch.nn.functional.normalize(embeddings, p=2, dim=1)
+        print(embeddings)
+        return embeddings.tolist()
 
     def _mean_pooling(self, model_output, attention_mask):
         token_embeddings = model_output[0]
```
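`_max_pooling` is likewise out of frame; standard masked max pooling looks like this (same assumptions as the mean-pooling sketch above):

```python
import torch


def max_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # (batch, seq_len, hidden)
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # Push padded positions to a large negative value so they never win the max.
    token_embeddings = token_embeddings.masked_fill(mask == 0, -1e9)
    return torch.max(token_embeddings, dim=1).values
```

One thing worth flagging in this hunk: the single-document path ships with a `print(embeddings)` right before the return, which reads like a leftover debug statement.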
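Taken together, the new knobs would be exercised like this (a hypothetical usage sketch: the constructor arguments and both method names are assumptions, since the diff shows neither):

```python
encoder = OptimumEncoder()  # constructor arguments not shown in the diff

# Batch path: mean pooling (the default), L2-normalized output.
vectors = encoder.embed_documents(
    ["first document", "second document"],
    batch_size=32,
    pooling_strategy="mean",
)

# Single-document path, with max pooling instead.
query_vector = encoder.embed_query("a query", pooling_strategy="max")

# Any other value raises:
# ValueError: Invalid pooling_strategy. Please use 'mean' or 'max'.
```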