from __future__ import annotations

import logging
import time
from typing import Any, Callable, List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.huggingface_hub import HuggingFaceHub
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

logger = logging.getLogger(__name__)


def _create_retry_decorator(llm: KronHuggingFaceHub) -> Callable[[Any], Any]:
    """Build a tenacity retry decorator for Hugging Face Hub rate-limit errors.

    Only ``KronHFHubRateExceededException`` triggers a retry; any other
    exception propagates on the first attempt. At most ``llm.max_retries``
    attempts are made in total (the first call included). When they are
    exhausted, the last exception is re-raised (``reraise=True``).

    Args:
        llm: Model instance whose ``max_retries`` bounds the attempt count.

    Returns:
        A decorator that wraps a callable with the retry policy.
    """
    min_seconds = 4
    max_seconds = 10
    # Exponential back-off with multiplier 1, clamped so every wait between
    # attempts lies in [min_seconds, max_seconds].
    return retry(
        reraise=True,
        stop=stop_after_attempt(llm.max_retries),
        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
        retry=retry_if_exception_type(KronHFHubRateExceededException),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )


def completion_with_retry(llm: KronHuggingFaceHub, **kwargs: Any) -> Any:
    """Call ``llm.internal_call(**kwargs)``, retrying on rate-limit errors.

    Args:
        llm: Model to call; its ``max_retries`` bounds the number of attempts.
        **kwargs: Forwarded unchanged to ``llm.internal_call`` on every attempt.

    Returns:
        The value returned by the first successful ``llm.internal_call``.

    Raises:
        KronHFHubRateExceededException: If every attempt hit the rate limit.
        Exception: Any other error from ``llm.internal_call``, without retrying.
    """
    retry_decorator = _create_retry_decorator(llm)

    @retry_decorator
    def _completion_with_retry(**call_kwargs: Any) -> Any:
        return llm.internal_call(**call_kwargs)

    return _completion_with_retry(**kwargs)


class KronHFHubRateExceededException(Exception):
    """Raised when the HF Hub inference API reports "Service Unavailable".

    ``_create_retry_decorator`` retries only on this exception type.
    """

    def __init__(
        self, message: str = "HF Hub Service Unavailable: Rate exceeded."
    ) -> None:
        self.message = message
        super().__init__(self.message)


class KronHuggingFaceHub(HuggingFaceHub):
    """``HuggingFaceHub`` LLM that retries calls rejected as "Service Unavailable"."""

    max_retries: int = 10
    """Maximum number of attempts (including the first) to make when generating."""

    def internal_call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Make a single, non-retried call to the parent ``HuggingFaceHub._call``.

        A ``ValueError`` whose message mentions "Service Unavailable" is
        translated into ``KronHFHubRateExceededException`` so the retry
        policy can recognise it.

        Args:
            prompt: Prompt text sent to the model.
            stop: Optional stop sequences, forwarded to the parent call.
            run_manager: Optional LangChain callback manager, forwarded as-is.
            **kwargs: Extra keyword arguments forwarded to the parent call.

        Returns:
            The text returned by the parent ``_call``.

        Raises:
            KronHFHubRateExceededException: On a "Service Unavailable" error,
                chained to the original ``ValueError``.
            ValueError: Any other ``ValueError`` from the parent call.
        """
        # Debug logging instead of print(): prompts/responses may be large or
        # sensitive and should not be written unconditionally to stdout.
        logger.debug("HF Hub prompt:\n%s", prompt)
        try:
            response = super()._call(prompt, stop, run_manager, **kwargs)
        except ValueError as ve:
            if "Service Unavailable" in str(ve):
                # Keep the original API error as the cause for debugging.
                raise KronHFHubRateExceededException() from ve
            raise
        logger.debug("HF Hub response:\n%s", response)
        return response

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for ``prompt``, retrying on rate-limit errors.

        Delegates to ``completion_with_retry``, which calls ``internal_call``
        up to ``max_retries`` times.
        """
        return completion_with_retry(
            self, prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
        )