import json
import time
import uuid
from typing import List, Optional

from fastapi import FastAPI, HTTPException, Depends
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel

from API_provider import API_Inference
from core_logic import (
    check_api_key_validity,
    update_request_count,
    get_rate_limit_status,
    get_subscription_status,
    get_available_models,
    get_model_info,
)

app = FastAPI()
security = HTTPBearer()
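
# HTTPBearer parses the "Authorization: Bearer <key>" header; get_api_key below
# unwraps it to the raw key string so every endpoint can depend on the key directly.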

class Message(BaseModel):
    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Message]
    stream: Optional[bool] = False
    max_tokens: Optional[int] = 4000
    temperature: Optional[float] = 0.5
    top_p: Optional[float] = 0.95


def get_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    return credentials.credentials

@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest, api_key: str = Depends(get_api_key)):
try:
# Check API key validity and rate limit
is_valid, error_message = check_api_key_validity(api_key)
if not is_valid:
raise HTTPException(status_code=401, detail=error_message)
messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
# Get model info
model_info = get_model_info(request.model)
if not model_info:
raise HTTPException(status_code=400, detail="Invalid model specified")
if "meta-llama-405b-turbo" in request.model:
request.model = "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo"
if "claude-3.5-sonnet" in request.model:
request.model = "claude-3-sonnet-20240229"
        if request.stream:
            def generate():
                for chunk in API_Inference(messages, model=request.model, stream=True,
                                           max_tokens=request.max_tokens,
                                           temperature=request.temperature,
                                           top_p=request.top_p):
                    yield f"data: {json.dumps({'choices': [{'delta': {'content': chunk}}]})}\n\n"
                # Non-standard trailer: a credits note is emitted after the [DONE] sentinel
                yield "data: [DONE]\n\nCredits used: 1\n\n"

            # Update request count: o1 models cost more credits per request
            if request.model in {
                "gpt-4o",
                "claude-3-sonnet-20240229",
                "gemini-1.5-pro",
                "gemini-1-5-flash",
                "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
            }:
                update_request_count(api_key, 1)
            elif request.model == "o1-mini":
                update_request_count(api_key, 2)
            elif request.model == "o1-preview":
                update_request_count(api_key, 3)

            return StreamingResponse(generate(), media_type="text/event-stream")
        else:
            response = API_Inference(messages, model=request.model, stream=False,
                                     max_tokens=request.max_tokens,
                                     temperature=request.temperature,
                                     top_p=request.top_p)

            # Update request count (1 credit per non-streaming request; adjust as needed)
            update_request_count(api_key, 1)

            # Rough token counts from whitespace-delimited words, not a real tokenizer
            prompt_tokens = len(' '.join(msg['content'] for msg in messages).split())
            completion_tokens = len(response.split())

            return {
                "id": f"chatcmpl-{uuid.uuid4()}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": request.model,
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": response
                        },
                        "finish_reason": "stop"
                    }
                ],
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens
                },
                "credits_used": 1
            }
    except HTTPException:
        # Let deliberate HTTP errors (401, 400) pass through instead of becoming 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
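
# Example request (illustrative only; the key is a placeholder and the host/port
# assume the uvicorn.run() defaults at the bottom of this file):
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Authorization: Bearer YOUR_API_KEY" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "gpt-4o", "messages": [{"role": "user", "content": "Hello"}]}'
#
# With "stream": true the response is a text/event-stream of
# 'data: {"choices": [{"delta": {"content": ...}}]}' chunks ending in 'data: [DONE]'.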
@app.get("/rate_limit/status")
async def get_rate_limit_status_endpoint(api_key: str = Depends(get_api_key)):
is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
if not is_valid:
raise HTTPException(status_code=401, detail=error_message)
return get_rate_limit_status(api_key)
@app.get("/subscription/status")
async def get_subscription_status_endpoint(api_key: str = Depends(get_api_key)):
is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
if not is_valid:
raise HTTPException(status_code=401, detail=error_message)
return get_subscription_status(api_key)
@app.get("/models")
async def get_available_models_endpoint(api_key: str = Depends(get_api_key)):
is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
if not is_valid:
raise HTTPException(status_code=401, detail=error_message)
return {"data": [{"id": model} for model in get_available_models().values()]}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000) |
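
# Minimal client sketch for consuming the streaming endpoint (an assumption, not
# part of this app; requires the `requests` package and a valid API key):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/v1/chat/completions",
#       headers={"Authorization": "Bearer YOUR_API_KEY"},
#       json={"model": "gpt-4o",
#             "messages": [{"role": "user", "content": "Hello"}],
#             "stream": True},
#       stream=True,
#   )
#   for line in resp.iter_lines():
#       if line:
#           print(line.decode("utf-8"))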