|
import asyncio |
|
import json |
|
import logging |
|
from typing import TypeVar, Type, Optional, Callable |
|
from pydantic import BaseModel |
|
from langchain_mistralai.chat_models import ChatMistralAI |
|
from langchain.schema import SystemMessage, HumanMessage |
|
from langchain.schema.messages import BaseMessage |
|
|
|
# Generic type variable for structured responses: any pydantic model subclass
# may be requested as the parsed return type of a generation call.
T = TypeVar('T', bound=BaseModel)


# NOTE(review): calling basicConfig at import time configures the process-wide
# root logger as a side effect; consider moving this to the application entry
# point so library users keep control of logging configuration.
logging.basicConfig(level=logging.INFO)

# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MistralAPIError(Exception):
    """Root of the Mistral client exception hierarchy."""


class MistralRateLimitError(MistralAPIError):
    """The API reported that the rate limit was exceeded."""


class MistralParsingError(MistralAPIError):
    """The model's response could not be parsed (e.g. malformed JSON)."""


class MistralValidationError(MistralAPIError):
    """The parsed response failed validation against the expected model."""
|
|
|
class MistralClient:
    """Async client for Mistral chat models with rate limiting, retries,
    and optional structured-output parsing.

    Wraps two identically-configured ChatMistralAI instances and serializes
    calls through a simple minimum-delay rate limiter plus an exponential
    backoff retry loop.
    """

    def __init__(self, api_key: str, model_name: str = "mistral-large-latest", max_tokens: int = 1000):
        """Create the client.

        Args:
            api_key: Mistral API key.
            model_name: Model identifier passed to ChatMistralAI.
            max_tokens: Maximum tokens per completion.
        """
        logger.info(f"Initializing MistralClient with model: {model_name}, max_tokens: {max_tokens}")
        self.model = ChatMistralAI(
            mistral_api_key=api_key,
            model=model_name,
            max_tokens=max_tokens
        )
        # Secondary model instance configured identically; kept for
        # interface compatibility (presumably intended for output fixing —
        # not used within this class).
        self.fixing_model = ChatMistralAI(
            mistral_api_key=api_key,
            model=model_name,
            max_tokens=max_tokens
        )

        # Rate-limit / retry tuning.
        self.last_call_time = 0    # event-loop timestamp of the last API call
        self.min_delay = 1         # minimum seconds between two API calls
        self.max_retries = 5       # attempts before giving up
        self.backoff_factor = 2    # exponential backoff base
        self.max_backoff = 30      # cap (seconds) on any single backoff sleep

    async def _wait_for_rate_limit(self):
        """Sleep just long enough to keep min_delay seconds between API calls."""
        current_time = asyncio.get_running_loop().time()
        time_since_last_call = current_time - self.last_call_time

        if time_since_last_call < self.min_delay:
            delay = self.min_delay - time_since_last_call
            logger.debug(f"Rate limit: waiting for {delay:.2f} seconds")
            await asyncio.sleep(delay)

        self.last_call_time = asyncio.get_running_loop().time()

    async def _handle_api_error(self, error: Exception, retry_count: int) -> float:
        """Classify an API error; return the backoff (seconds) before retrying.

        Raises:
            MistralRateLimitError: transient rate-limit hit (retryable).
            MistralAPIError: authentication/quota failure (not retryable
                meaningfully, but surfaced with a clear message).
        """
        wait_time = min(self.backoff_factor ** retry_count, self.max_backoff)

        if "rate limit" in str(error).lower():
            logger.warning(f"Rate limit hit, waiting {wait_time}s before retry")
            raise MistralRateLimitError(str(error))
        elif "403" in str(error):
            logger.error("Authentication error - invalid API key or quota exceeded")
            raise MistralAPIError("Authentication failed")

        return wait_time

    async def _generate_with_retry(
        self,
        messages: list[BaseMessage],
        response_model: Optional[Type[T]] = None,
        custom_parser: Optional[Callable[[str], T]] = None,
        error_feedback: Optional[str] = None
    ) -> T | str:
        """Invoke the model with rate limiting, retries, and optional parsing.

        Args:
            messages: Conversation to send to the model.
            response_model: Pydantic model to validate a JSON response against.
            custom_parser: Alternative parser applied to the raw text; takes
                precedence over response_model.
            error_feedback: Extra guidance appended to the conversation on
                retries after a parsing/validation failure.

        Returns:
            The parsed object (T) when a parser/model is given, otherwise the
            raw response text.

        Raises:
            MistralAPIError: on authentication failure.
            Exception: after max_retries failed attempts.
        """
        retry_count = 0
        last_error: Optional[Exception] = None

        while retry_count < self.max_retries:
            try:
                logger.info(f"Attempt {retry_count + 1}/{self.max_retries}")

                # On retries, steer the model away from the previous failure.
                current_messages = messages.copy()
                if error_feedback and retry_count > 0:
                    if isinstance(last_error, MistralParsingError):
                        current_messages.append(HumanMessage(content="Please ensure your response is in valid JSON format."))
                    elif isinstance(last_error, MistralValidationError):
                        current_messages.append(HumanMessage(content=f"Previous error: {error_feedback}. Please try again."))

                await self._wait_for_rate_limit()
                try:
                    response = await self.model.ainvoke(current_messages)
                    content = response.content
                    logger.debug(f"Raw response: {content[:100]}...")
                except MistralRateLimitError:
                    # Let the outer handler apply exponential backoff.
                    raise
                except Exception as api_error:
                    wait_time = await self._handle_api_error(api_error, retry_count)
                    retry_count += 1
                    if retry_count < self.max_retries:
                        await asyncio.sleep(wait_time)
                        continue
                    raise

                # No structured output requested: return the raw text as-is.
                if not response_model and not custom_parser:
                    return content

                try:
                    if custom_parser:
                        return custom_parser(content)

                    data = json.loads(content)
                    return response_model(**data)
                except json.JSONDecodeError as e:
                    last_error = MistralParsingError(f"Invalid JSON format: {str(e)}")
                    logger.error(f"JSON parsing error: {str(e)}")
                    raise last_error
                except Exception as e:
                    last_error = MistralValidationError(str(e))
                    logger.error(f"Validation error: {str(e)}")
                    raise last_error

            except (MistralParsingError, MistralValidationError, MistralRateLimitError) as e:
                # BUG FIX: MistralRateLimitError was previously NOT caught
                # here, so the error raised by _handle_api_error escaped the
                # retry loop on the first rate-limit hit — despite the
                # "waiting ...s before retry" log. It now goes through the
                # same backoff-and-retry path as parse/validation failures.
                logger.error(f"Error on attempt {retry_count + 1}/{self.max_retries}: {str(e)}")
                last_error = e
                retry_count += 1
                if retry_count < self.max_retries:
                    wait_time = min(self.backoff_factor ** retry_count, self.max_backoff)
                    logger.info(f"Waiting {wait_time} seconds before retry...")
                    await asyncio.sleep(wait_time)
                    continue

        logger.error(f"Failed after {self.max_retries} attempts. Last error: {str(last_error)}")
        raise Exception(f"Failed after {self.max_retries} attempts. Last error: {str(last_error)}")

    async def generate(self, messages: list[BaseMessage], response_model: Optional[Type[T]] = None, custom_parser: Optional[Callable[[str], T]] = None) -> T | str:
        """Generate a response from a list of messages, with optional parsing."""
        return await self._generate_with_retry(messages, response_model, custom_parser)

    async def transform_prompt(self, story_text: str, art_prompt: str) -> str:
        """Transform a story passage into a comic-panel art description.

        Falls back to returning the original story text if generation fails.
        """
        # BUG FIX: previously passed raw role/content dicts despite the
        # list[BaseMessage] contract; use the imported message classes
        # (SystemMessage was imported but unused).
        messages = [
            SystemMessage(content=art_prompt),
            HumanMessage(content=f"Transform this story text into a comic panel description:\n{story_text}")
        ]
        try:
            return await self._generate_with_retry(messages)
        except Exception as e:
            # Use the module logger instead of print() for consistency.
            logger.error(f"Error transforming prompt: {str(e)}")
            return story_text

    async def generate_text(self, messages: list[BaseMessage]) -> str:
        """Generate a plain-text response without JSON structure.

        Useful for narrative or descriptive text generation.

        Args:
            messages: Messages to send to the model.

        Returns:
            str: The generated text, stripped of surrounding whitespace.

        Raises:
            Exception: after max_retries failed attempts.
        """
        retry_count = 0
        last_error: Optional[Exception] = None

        while retry_count < self.max_retries:
            try:
                logger.info(f"Attempt {retry_count + 1}/{self.max_retries}")

                await self._wait_for_rate_limit()
                response = await self.model.ainvoke(messages)
                return response.content.strip()

            except Exception as e:
                # BUG FIX: the original never assigned last_error and its
                # final raise referenced `e` outside the except scope (Python
                # unbinds the except target), risking a NameError.
                last_error = e
                logger.error(f"Error on attempt {retry_count + 1}/{self.max_retries}: {str(e)}")
                retry_count += 1
                if retry_count < self.max_retries:
                    wait_time = 2 * retry_count  # linear backoff, as before
                    logger.info(f"Waiting {wait_time} seconds before retry...")
                    await asyncio.sleep(wait_time)

        logger.error(f"Failed after {self.max_retries} attempts. Last error: {last_error}")
        raise Exception(f"Failed after {self.max_retries} attempts. Last error: {last_error}")