"""Flask front end for a chat-completion API protected by bearer API keys."""

import json
import uuid
from functools import wraps
from typing import List, Optional

from flask import Flask, request, jsonify, Response
from pydantic import BaseModel, ValidationError

from API_provider import API_Inference
from core_logic import (
    check_api_key_validity,
    update_request_count,
    get_rate_limit_status,
    get_subscription_status,
    get_available_models,
    get_model_info,
)

app = Flask(__name__)


class Message(BaseModel):
    """One chat message: who said it (``role``) and what was said (``content``)."""

    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    """Validated JSON body accepted by the chat-completion endpoint."""

    model: str
    messages: List[Message]
    stream: Optional[bool] = False
    max_tokens: Optional[int] = 4000
    temperature: Optional[float] = 0.5
    top_p: Optional[float] = 0.95


def get_api_key() -> Optional[str]:
    """Extract the bearer token from the current request's Authorization header.

    Returns:
        The token that follows ``Bearer`` in the header, or ``None`` when the
        header is missing, does not start with ``'Bearer '``, or carries no
        token at all.
    """
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return None
    # split() without a separator collapses runs of whitespace, so a header
    # such as "Bearer   <token>" yields the token rather than an empty string
    # (which split(' ')[1] would produce).
    parts = auth_header.split()
    return parts[1] if len(parts) > 1 else None


def requires_api_key(func):
    """Decorate a view so it is only called when a bearer token is supplied.

    The extracted token is passed to the wrapped view as the ``api_key``
    keyword argument. When no token can be extracted the view is not called
    and a 401 JSON response (``{'detail': 'Not authenticated'}``) is returned.
    This decorator only checks that a token is present; it does not validate it.
    """
    @wraps(func)
    def decorated(*args, **kwargs):
        api_key = get_api_key()
        if not api_key:
            return jsonify({'detail': 'Not authenticated'}), 401
        kwargs['api_key'] = api_key
        return func(*args, **kwargs)

    return decorated


@app.route('/')
def index():
    """Return a fixed plain-text greeting for the root URL."""
    return 'Hello, World!'
@app.route('/chat/completions', methods=['POST', 'GET'])
@requires_api_key
def chat_completions(api_key):
    """Run a chat completion for the caller identified by ``api_key``.

    The JSON request body is validated against ``ChatCompletionRequest``.
    When ``stream`` is true the reply is a ``text/event-stream`` response
    made of ``data: {...}`` chunks followed by a ``data: [DONE]`` event;
    otherwise a single JSON document with the completion, approximate token
    usage and the credits charged is returned.

    Args:
        api_key: Bearer token injected by ``requires_api_key``.

    Returns:
        A Flask response. Error cases are reported as JSON ``{'detail': ...}``
        with status 400 (missing/invalid body, unknown model), 401 (key
        rejected by ``check_api_key_validity``) or 500 (any other failure).
    """
    # Function-scope import keeps this block self-contained; the module is
    # cached in sys.modules after the first request.
    import time

    try:
        # silent=True makes a missing, malformed or wrongly-typed body come
        # back as None instead of raising, so it can be reported as a client
        # error (400) rather than falling into the generic 500 handler below.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'detail': 'Request body must be a JSON object'}), 400
        try:
            chat_request = ChatCompletionRequest(**data)
        except ValidationError as e:
            return jsonify({'detail': e.errors()}), 400

        # Called with default arguments here; the status endpoints pass
        # check_rate_limit=False, so this call presumably also enforces the
        # rate limit -- confirm against core_logic.
        is_valid, error_message = check_api_key_validity(api_key)
        if not is_valid:
            return jsonify({'detail': error_message}), 401

        messages = [
            {"role": msg.role, "content": msg.content}
            for msg in chat_request.messages
        ]

        model_info = get_model_info(chat_request.model)
        if not model_info:
            return jsonify({'detail': 'Invalid model specified'}), 400

        # Public model names that differ from the identifier handed to
        # API_Inference; any other name is passed through unchanged.
        # NOTE(review): "claude-3.5-sonnet" maps to a "claude-3-sonnet" id --
        # confirm this is intentional.
        model_mapping = {
            "meta-llama-405b-turbo": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
            "claude-3.5-sonnet": "claude-3-sonnet-20240229",
        }
        model_name = model_mapping.get(chat_request.model, chat_request.model)

        # Credits charged per request, keyed by the mapped model name.
        # Models not listed here cost 0 credits.
        # NOTE(review): "gemini-1-5-flash" is spelled with dashes while
        # "gemini-1.5-pro" uses a dot -- verify against the real model ids.
        credits_reduction = {
            "gpt-4o": 1,
            "claude-3-sonnet-20240229": 1,
            "gemini-1.5-pro": 1,
            "gemini-1-5-flash": 1,
            "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": 1,
            "o1-mini": 2,
            "o1-preview": 3,
        }.get(model_name, 0)

        if chat_request.stream:
            def generate():
                """Yield the completion as server-sent events."""
                try:
                    for chunk in API_Inference(
                        messages,
                        model=model_name,
                        stream=True,
                        max_tokens=chat_request.max_tokens,
                        temperature=chat_request.temperature,
                        top_p=chat_request.top_p,
                    ):
                        payload = json.dumps(
                            {'choices': [{'delta': {'content': chunk}}]}
                        )
                        yield f"data: {payload}\n\n"
                    # Charge before emitting the final event: code placed
                    # after the last yield never runs if the consumer stops
                    # iterating once it has received that chunk.
                    update_request_count(api_key, credits_reduction)
                    yield f"data: [DONE]\n\nCredits used: {credits_reduction}\n\n"
                except Exception as e:
                    # Headers are already sent, so the failure can only be
                    # reported in-band.
                    yield f"data: [ERROR] {str(e)}\n\n"

            return Response(generate(), mimetype='text/event-stream')

        response = API_Inference(
            messages,
            model=model_name,
            stream=False,
            max_tokens=chat_request.max_tokens,
            temperature=chat_request.temperature,
            top_p=chat_request.top_p,
        )
        update_request_count(api_key, credits_reduction)

        # Usage figures are whitespace-separated word counts, not real
        # tokenizer counts.
        prompt_tokens = sum(len(msg['content'].split()) for msg in messages)
        completion_tokens = len(response.split())
        total_tokens = prompt_tokens + completion_tokens

        return jsonify({
            "id": f"chatcmpl-{str(uuid.uuid4())}",
            "object": "chat.completion",
            # Unix timestamp in seconds. (uuid1().time counts 100 ns ticks
            # since 1582-10-15, so it cannot be used as an epoch time.)
            "created": int(time.time()),
            "model": model_name,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response
                    },
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": total_tokens
            },
            "credits_used": credits_reduction
        })
    except Exception as e:
        # Top-level boundary: report any unexpected failure as a JSON 500.
        return jsonify({'detail': str(e)}), 500


@app.route('/rate_limit/status', methods=['GET'])
@requires_api_key
def get_rate_limit_status_endpoint(api_key):
    """Return ``get_rate_limit_status(api_key)`` as JSON.

    The key is validated with ``check_rate_limit=False``; a rejected key
    yields a 401 JSON response carrying the validator's error message.
    """
    is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
    if not is_valid:
        return jsonify({'detail': error_message}), 401
    return jsonify(get_rate_limit_status(api_key))


@app.route('/subscription/status', methods=['GET'])
@requires_api_key
def get_subscription_status_endpoint(api_key):
    """Return ``get_subscription_status(api_key)`` as JSON.

    The key is validated with ``check_rate_limit=False``; a rejected key
    yields a 401 JSON response carrying the validator's error message.
    """
    is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
    if not is_valid:
        return jsonify({'detail': error_message}), 401
    return jsonify(get_subscription_status(api_key))


@app.route('/models', methods=['GET'])
@requires_api_key
def get_available_models_endpoint(api_key):
    """List the available models as ``{"data": [{"id": ...}, ...]}``.

    One entry is produced per value of the mapping returned by
    ``get_available_models()``. The key is validated with
    ``check_rate_limit=False``; a rejected key yields a 401 JSON response.
    """
    is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
    if not is_valid:
        return jsonify({'detail': error_message}), 401
    return jsonify({"data": [{"id": model} for model in get_available_models().values()]})


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8000)