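"""OpenAI-compatible FastAPI server (SERVER.py).

Exposes /v1/chat/completions plus rate-limit, subscription, and model-listing
endpoints. Callers authenticate with a bearer API key; usage is metered in
credits via the helpers in core_logic.
"""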
import json
import time
import uuid
from typing import List, Optional

from fastapi import FastAPI, HTTPException, Depends
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel

from API_provider import API_Inference
from core_logic import (
check_api_key_validity,
update_request_count,
get_rate_limit_status,
get_subscription_status,
get_available_models,
get_model_info,
)

app = FastAPI()
security = HTTPBearer()  # HTTP Bearer scheme; the token is treated as the API key

class Message(BaseModel):
    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Message]
    stream: Optional[bool] = False
    max_tokens: Optional[int] = 4000
    temperature: Optional[float] = 0.5
    top_p: Optional[float] = 0.95

def get_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
    """Extract the bearer token (the caller's API key) from the Authorization header."""
    return credentials.credentials

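# Example request against the completions endpoint (assuming the server runs
# locally on port 8000 and YOUR_API_KEY is a valid key):
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Authorization: Bearer YOUR_API_KEY" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "gpt-4o", "messages": [{"role": "user", "content": "Hi"}], "stream": false}'
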
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest, api_key: str = Depends(get_api_key)):
try:
# Check API key validity and rate limit
is_valid, error_message = check_api_key_validity(api_key)
if not is_valid:
raise HTTPException(status_code=401, detail=error_message)
messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
# Get model info
model_info = get_model_info(request.model)
if not model_info:
raise HTTPException(status_code=400, detail="Invalid model specified")
if "meta-llama-405b-turbo" in request.model:
request.model = "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo"
if "claude-3.5-sonnet" in request.model:
request.model = "claude-3-sonnet-20240229"
        if request.stream:
            def generate():
                # Emit OpenAI-style SSE frames: one JSON delta per chunk
                for chunk in API_Inference(messages, model=request.model, stream=True,
                                           max_tokens=request.max_tokens,
                                           temperature=request.temperature,
                                           top_p=request.top_p):
                    yield f"data: {json.dumps({'choices': [{'delta': {'content': chunk}}]})}\n\n"
                yield "data: [DONE]\n\n"
                # Report credit usage as an SSE comment (leading colon), which
                # spec-compliant clients ignore instead of parsing as data
                yield ": credits used: 1\n\n"

                # Update the request count once the stream has finished;
                # premium models deduct more credits per request
                if request.model in {
                    "gpt-4o",
                    "claude-3-sonnet-20240229",
                    "gemini-1.5-pro",
                    "gemini-1.5-flash",
                    "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
                }:
                    update_request_count(api_key, 1)
                elif request.model == "o1-mini":
                    update_request_count(api_key, 2)
                elif request.model == "o1-preview":
                    update_request_count(api_key, 3)

            return StreamingResponse(generate(), media_type="text/event-stream")
        else:
            response = API_Inference(messages, model=request.model, stream=False,
                                     max_tokens=request.max_tokens,
                                     temperature=request.temperature,
                                     top_p=request.top_p)

            # Update request count; assumes a flat 1 credit per request, adjust as needed
            update_request_count(api_key, 1)

            # Approximate token counts by splitting on whitespace; this is a
            # rough word count, not a real tokenizer
            prompt_tokens = len(" ".join(msg["content"] for msg in messages).split())
            completion_tokens = len(response.split())

            return {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion",
"created": int(uuid.uuid1().time // 1e7),
"model": request.model,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": response
},
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": len(' '.join(msg['content'] for msg in messages).split()),
"completion_tokens": len(response.split()),
"total_tokens": len(' '.join(msg['content'] for msg in messages).split()) + len(response.split())
},
"credits_used": 1
}
    except HTTPException:
        # Re-raise deliberate HTTP errors (401, 400) rather than masking them as 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
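
# On the wire, a streamed completion from the endpoint above looks like:
#   data: {"choices": [{"delta": {"content": "Hel"}}]}
#   data: {"choices": [{"delta": {"content": "lo"}}]}
#   data: [DONE]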
@app.get("/rate_limit/status")
async def get_rate_limit_status_endpoint(api_key: str = Depends(get_api_key)):
is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
if not is_valid:
raise HTTPException(status_code=401, detail=error_message)
return get_rate_limit_status(api_key)
@app.get("/subscription/status")
async def get_subscription_status_endpoint(api_key: str = Depends(get_api_key)):
is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
if not is_valid:
raise HTTPException(status_code=401, detail=error_message)
return get_subscription_status(api_key)
@app.get("/models")
async def get_available_models_endpoint(api_key: str = Depends(get_api_key)):
is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
if not is_valid:
raise HTTPException(status_code=401, detail=error_message)
return {"data": [{"id": model} for model in get_available_models().values()]}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)