import json
import time
import uuid
from typing import List, Optional

from fastapi import FastAPI, HTTPException, Depends
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel

from API_provider import API_Inference
from core_logic import (
    check_api_key_validity,
    update_request_count,
    get_rate_limit_status,
    get_subscription_status,
    get_available_models,
    get_model_info,
)

app = FastAPI()
security = HTTPBearer()
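
# Upstream identifiers for the public-facing model names. The Claude entry
# uses the dated Anthropic release ID for 3.5 Sonnet; adjust if API_provider
# expects different upstream names.
MODEL_ALIASES = {
    "meta-llama-405b-turbo": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
    "claude-3.5-sonnet": "claude-3-5-sonnet-20240620",
}

# Credits charged per request. Any model not listed here is assumed to cost
# 1 credit, matching the non-streaming default; adjust as needed.
MODEL_CREDIT_COSTS = {
    "o1-mini": 2,
    "o1-preview": 3,
}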

class Message(BaseModel):
    role: str
    content: str

class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Message]
    stream: Optional[bool] = False
    max_tokens: Optional[int] = 4000
    temperature: Optional[float] = 0.5
    top_p: Optional[float] = 0.95
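
# Example body this schema accepts (the model name is illustrative; valid
# names come from core_logic.get_available_models):
#
#   {
#       "model": "gpt-4o",
#       "messages": [{"role": "user", "content": "Hello"}],
#       "stream": false
#   }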

def get_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    return credentials.credentials

@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest, api_key: str = Depends(get_api_key)):
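    """OpenAI-style chat completions endpoint supporting both streamed (SSE)
    and non-streamed responses."""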
    try:
        # Check API key validity and rate limit
        is_valid, error_message = check_api_key_validity(api_key)
        if not is_valid:
            raise HTTPException(status_code=401, detail=error_message)

        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
        
        # Validate the requested model (get_model_info is keyed on the
        # public model names, so aliases are resolved afterwards)
        model_info = get_model_info(request.model)
        if not model_info:
            raise HTTPException(status_code=400, detail="Invalid model specified")

        # Resolve public aliases to upstream identifiers and price the request
        request.model = MODEL_ALIASES.get(request.model, request.model)
        credit_cost = MODEL_CREDIT_COSTS.get(request.model, 1)

        if request.stream:
            def generate():
                for chunk in API_Inference(messages, model=request.model, stream=True,
                                           max_tokens=request.max_tokens,
                                           temperature=request.temperature,
                                           top_p=request.top_p):
                    yield f"data: {json.dumps({'choices': [{'delta': {'content': chunk}}]})}\n\n"
                # Note: many SSE clients stop reading at [DONE], so this
                # trailing credits line may never be seen
                yield f"data: [DONE]\n\nCredits used: {credit_cost}\n\n"

            # Charge credits up front; unlisted models cost the default 1 credit
            update_request_count(api_key, credit_cost)

            return StreamingResponse(generate(), media_type="text/event-stream")
        else:
            response = API_Inference(messages, model=request.model, stream=False,
                                     max_tokens=request.max_tokens,
                                     temperature=request.temperature,
                                     top_p=request.top_p)

            # Charge the same per-model cost as the streaming path
            update_request_count(api_key, credit_cost)

            # Rough usage accounting: whitespace word counts, not a real tokenizer
            prompt_tokens = len(' '.join(msg['content'] for msg in messages).split())
            completion_tokens = len(response.split())

            return {
                "id": f"chatcmpl-{uuid.uuid4()}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": request.model,
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": response
                        },
                        "finish_reason": "stop"
                    }
                ],
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens
                },
                "credits_used": credit_cost
            }
    except HTTPException:
        # Re-raise deliberate HTTP errors (401/400) instead of masking them as 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/rate_limit/status")
async def get_rate_limit_status_endpoint(api_key: str = Depends(get_api_key)):
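    """Report the caller's rate-limit status (the key is validated without
    enforcing the limit)."""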
    is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
    if not is_valid:
        raise HTTPException(status_code=401, detail=error_message)
    return get_rate_limit_status(api_key)

@app.get("/subscription/status")
async def get_subscription_status_endpoint(api_key: str = Depends(get_api_key)):
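    """Report the caller's subscription status."""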
    is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
    if not is_valid:
        raise HTTPException(status_code=401, detail=error_message)
    return get_subscription_status(api_key)

@app.get("/models")
async def get_available_models_endpoint(api_key: str = Depends(get_api_key)):
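    """List the models available to the caller in an OpenAI-style
    {"data": [{"id": ...}]} shape."""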
    is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
    if not is_valid:
        raise HTTPException(status_code=401, detail=error_message)
    return {"data": [{"id": model} for model in get_available_models().values()]}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
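
# Example request once the server is running (the bearer token is
# illustrative; use a key recognised by core_logic):
#
#   curl -N http://localhost:8000/v1/chat/completions \
#     -H "Authorization: Bearer sk-example-key" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "gpt-4o", "messages": [{"role": "user", "content": "Hi"}], "stream": true}'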