devve1 committed on
Commit
eb93a92
1 Parent(s): a1ddbe5

Update optimum_encoder.py

Browse files
Files changed (1) hide show
  1. optimum_encoder.py +1 -136
optimum_encoder.py CHANGED
@@ -154,139 +154,4 @@ class OptimumEncoder(BaseEncoder):
154
  attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
155
  )
156
  token_embeddings[input_mask_expanded == 0] = -1e9
157
- return self._torch.max(token_embeddings, 1)[0]
158
-
159
-
160
class HFEndpointEncoder(BaseEncoder):
    """
    A class to encode documents using a Hugging Face transformer model endpoint.

    Attributes:
        name (str): The name of the encoder.
        huggingface_url (str): The URL of the Hugging Face API endpoint.
        huggingface_api_key (str): The API key for authenticating with the Hugging Face API.
        score_threshold (float): A threshold value used for filtering or processing the embeddings.
    """

    name: str = "hugging_face_custom_endpoint"
    huggingface_url: Optional[str] = None
    huggingface_api_key: Optional[str] = None
    score_threshold: float = 0.8

    def __init__(
        self,
        name: Optional[str] = "hugging_face_custom_endpoint",
        huggingface_url: Optional[str] = None,
        huggingface_api_key: Optional[str] = None,
        score_threshold: float = 0.8,
    ):
        """
        Initializes the HFEndpointEncoder with the specified parameters.

        The endpoint is probed once with a small request so that a wrong URL or
        API key is reported at construction time rather than on first use.

        Args:
            name (str, optional): The name of the encoder. Defaults to
                "hugging_face_custom_endpoint".
            huggingface_url (str, optional): The URL of the Hugging Face API endpoint.
                Falls back to the ``HF_API_URL`` environment variable.
            huggingface_api_key (str, optional): The API key for the Hugging Face API.
                Falls back to the ``HF_API_KEY`` environment variable.
            score_threshold (float, optional): A threshold for processing the embeddings.
                Defaults to 0.8.

        Raises:
            ValueError: If the URL or the API key cannot be resolved, or if the
                probe request to the endpoint fails.
        """
        # Explicit arguments win; the environment is only a fallback.
        huggingface_url = huggingface_url or os.getenv("HF_API_URL")
        huggingface_api_key = huggingface_api_key or os.getenv("HF_API_KEY")

        super().__init__(name=name, score_threshold=score_threshold)  # type: ignore

        if huggingface_url is None:
            raise ValueError("HuggingFace endpoint url cannot be 'None'.")
        if huggingface_api_key is None:
            raise ValueError("HuggingFace API key cannot be 'None'.")

        self.huggingface_url = huggingface_url
        self.huggingface_api_key = huggingface_api_key

        try:
            self.query({"inputs": "Hello World!", "parameters": {}})
        except Exception as e:
            raise ValueError(
                f"HuggingFace endpoint client failed to initialize. Error: {e}"
            ) from e

    def __call__(self, docs: List[str]) -> List[List[float]]:
        """
        Encodes a list of documents into embeddings using the Hugging Face API.

        One request is sent per document, in order.

        Args:
            docs (List[str]): A list of documents to encode.

        Returns:
            List[List[float]]: A list of embeddings for the given documents.

        Raises:
            ValueError: If the query fails or returns nothing for a document.
        """
        embeddings = []
        for d in docs:
            try:
                output = self.query({"inputs": d, "parameters": {}})
                if not output:
                    raise ValueError("No embeddings returned from the query.")
                embeddings.append(output)
            except Exception as e:
                raise ValueError(
                    f"No embeddings returned for document. Error: {e}"
                ) from e
        return embeddings

    @staticmethod
    def _model_loading_wait(response):
        """Return the positive ``estimated_time`` of a 503 body, or None if absent."""
        try:
            estimated_time = response.json().get("estimated_time")
        except (ValueError, AttributeError):
            # Body is not JSON, or is JSON but not an object.
            return None
        if isinstance(estimated_time, (int, float)) and estimated_time > 0:
            return estimated_time
        return None

    def query(self, payload, max_retries=3, retry_interval=5):
        """
        Sends a query to the Hugging Face API and returns the response.

        The request is attempted at most ``max_retries`` times. A 503 response
        that carries an ``estimated_time`` is waited out for that long; any
        other failure waits ``retry_interval`` seconds, and the interval grows
        by the attempt number after each such wait.

        Args:
            payload (dict): The payload to send in the request.
            max_retries (int, optional): Maximum number of attempts. Defaults to 3.
            retry_interval (int, optional): Initial wait in seconds between
                attempts. Defaults to 5.

        Returns:
            The decoded JSON body of the first successful response.

        Raises:
            ValueError: If every attempt fails, either at the transport level
                or with an HTTP error status.
        """
        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {self.huggingface_api_key}",
            "Content-Type": "application/json",
        }
        # Only reported as-is when max_retries < 1 and no request is sent.
        failure = "Query failed: no request was attempted."
        for attempt in range(1, max_retries + 1):
            loading_wait = None
            try:
                response = requests.post(
                    self.huggingface_url,
                    headers=headers,
                    json=payload,
                )
            except requests.exceptions.RequestException as e:
                # Transport-level failure: there is no response object to
                # report, so describe the exception itself.
                failure = f"Query failed: {e}"
            else:
                if response.ok:
                    # Return at once so a successful request is never re-sent.
                    return response.json()
                failure = (
                    f"Query failed with status {response.status_code}: {response.text}"
                )
                if response.status_code == 503:
                    loading_wait = self._model_loading_wait(response)

            if attempt >= max_retries:
                break

            if loading_wait is not None:
                logger.info(f"Model Initializing wait for - {loading_wait:.2f}s ")
                time.sleep(loading_wait)
            else:
                logger.info(f"Retrying attempt: {attempt} for payload: {payload} ")
                time.sleep(retry_interval)
                retry_interval += attempt

        raise ValueError(failure)
 
154
  attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
155
  )
156
  token_embeddings[input_mask_expanded == 0] = -1e9
157
+ return self._torch.max(token_embeddings, 1)[0]