cotcotquedec committed
Commit 154ea17 · 1 Parent(s): d4c432e
refactor(main): improve code structure and error handling
Refactored the main application file to improve readability and maintainability. Introduced logging for better error tracking, and replaced ad-hoc JSONResponse returns with HTTPException for more consistent error handling. Added detailed docstrings to functions and methods, and restructured the code to separate concerns and improve the logical flow.
Additionally, created a new `schemas.py` file to define data models using Pydantic, which helps validate and organize request and response data structures.
This refactor aims to improve the overall robustness and scalability of the application by ensuring that the code is more modular and easier to maintain.
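
As a condensed sketch of the error-handling pattern this commit moves to (names mirror the diff below; the snippet is illustrative, not a drop-in excerpt of the commit):

import asyncio

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse

app = FastAPI()

@app.middleware("http")
async def auth_middleware(request: Request, call_next):
    try:
        # Before this commit, each failure site returned a JSONResponse
        # directly; now the failure sites raise HTTPException instead...
        if not request.headers.get("Authorization"):
            raise HTTPException(status_code=401, detail="No authorization header")
        return await call_next(request)
    except HTTPException as http_ex:
        # ...and a single except branch renders every error the same way.
        return JSONResponse(
            status_code=http_ex.status_code,
            content={"detail": http_ex.detail},
        )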
- main.py +235 -116
- schemas.py +41 -0
main.py
CHANGED
@@ -1,26 +1,52 @@
-import os
-from fastapi import FastAPI, HTTPException, Depends, Header, Request, Response
-from fastapi.responses import JSONResponse, StreamingResponse
-from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
-from pydantic import BaseModel
-from typing import List, Optional
-from anthropic import Anthropic
 import json
+import logging
+import os
 import time
+from concurrent.futures import ThreadPoolExecutor
 from contextvars import ContextVar
+from typing import Any, Dict, Generator, List

+from anthropic import Anthropic
+from fastapi import FastAPI, HTTPException, Request, Response
+from fastapi.responses import JSONResponse, StreamingResponse
+from fastapi.security import HTTPBearer
+from starlette.concurrency import run_in_threadpool
+
+from schemas import OpenAIChatCompletionForm, FilterForm
+
+# logger
+logger = logging.getLogger()
+
+# FastAPI app initialization
 app = FastAPI()
 security = HTTPBearer()

-# Context variable
+# Context variable for token storage
 token_context = ContextVar('token', default=None)

-#
+# Endpoints that don't require authentication
 PUBLIC_ENDPOINTS = {"/"}

+# Available Anthropic models
+AVAILABLE_MODELS = [
+    "claude-3-haiku-20240307",
+    "claude-3-opus-20240229",
+    "claude-3-sonnet-20240229",
+    "claude-3-5-sonnet-20241022"
+]
+
 @app.middleware("http")
 async def auth_middleware(request: Request, call_next):
-
+    """
+    Middleware for handling authentication and response logging.
+
+    Args:
+        request: The incoming HTTP request
+        call_next: The next middleware in the chain
+
+    Returns:
+        Response: The processed HTTP response
+    """
     if request.url.path in PUBLIC_ENDPOINTS:
         start_time = time.perf_counter()
         response = await call_next(request)

@@ -31,22 +57,20 @@ async def auth_middleware(request: Request, call_next):
     try:
         auth_header = request.headers.get('Authorization')
         if not auth_header:
-            return JSONResponse(
+            raise HTTPException(
                 status_code=401,
-                content={"detail": "No authorization header"}
+                detail="No authorization header"
             )

         scheme, token = auth_header.split()
         if scheme.lower() != 'bearer':
-            return JSONResponse(
+            raise HTTPException(
                 status_code=401,
-                content={"detail": "Invalid authentication scheme"}
+                detail="Invalid authentication scheme"
             )

-        # Store token in context
         token_context.set(token)

-        # Add processing time header
         start_time = time.perf_counter()
         response = await call_next(request)
         process_time = time.perf_counter() - start_time

@@ -54,46 +78,60 @@

         return response

+    except HTTPException as http_ex:
+        logger.error(
+            f"HTTP Exception - Status: {http_ex.status_code} - "
+            f"Detail: {http_ex.detail} - Path: {request.url.path}"
+        )
+        return JSONResponse(
+            status_code=http_ex.status_code,
+            content={"detail": http_ex.detail}
+        )
     except Exception as e:
+        logger.error(
+            f"Unexpected error in middleware - Error: {str(e)} - "
+            f"Path: {request.url.path}",
+            exc_info=True
+        )
         return JSONResponse(
-            status_code=
-            content={"detail": "
+            status_code=500,
+            content={"detail": "Internal server error"}
         )

-
 def get_anthropic_client():
+    """
+    Get an authenticated Anthropic client using the current token.
+
+    Returns:
+        Anthropic: An authenticated Anthropic client instance
+
+    Raises:
+        HTTPException: If no authorization token is found
+    """
     token = token_context.get()
     if not token:
         raise HTTPException(status_code=401, detail="No authorization token found")
     return Anthropic(api_key=token)

-# Available models
-AVAILABLE_MODELS = [
-    "claude-3-haiku-20240307",
-    "claude-3-opus-20240229",
-    "claude-3-sonnet-20240229",
-    "claude-3-5-sonnet-20241022"
-]
-
-class Message(BaseModel):
-    role: str
-    content: str
-
-class ChatCompletionRequest(BaseModel):
-    model: str
-    messages: List[Message]
-    stream: bool = False
-    max_tokens: Optional[int] = 1024

+@app.get("/v1")
 @app.get("/")
 async def read_root():
+    """Root endpoint for API health check."""
     return {"Hello": "World!"}

+
+@app.get("/v1/models")
 @app.get("/models")
 async def get_models():
-    """
-    ...
-    """
+    """
+    Get available Anthropic models.
+
+    Returns:
+        JSONResponse: List of available models and their details
+    """
+    get_anthropic_client()  # Verify token validity

     models = [
         {

@@ -115,97 +153,178 @@
         }
     )

-@app.post("/v1/chat/completions")
-async def create_chat_completion(request: ChatCompletionRequest):
-    """Generate chat completions using Anthropic models."""
-    try:
-        if request.stream:
-            return StreamingResponse(
-                stream_response(request),
-                media_type="text/event-stream"
-            )
-        else:
-            return await generate_completion(request)
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))

-
-async def generate_completion(request: ChatCompletionRequest):
-    messages = ...

+
+def stream_message(
+    model: str,
+    messages: List[Dict[str, Any]]
+) -> Generator[str, None, None]:
+    """
+    Stream messages using the specified model.
+
+    Args:
+        model: The model identifier to use
+        messages: List of messages to process
+
+    Returns:
+        Generator: Stream of SSE formatted responses
+    """
     client = get_anthropic_client()
-
     response = client.messages.create(
-        model=request.model,
-        max_tokens=request.max_tokens,
+        model=model,
+        max_tokens=1024,
+        messages=messages,
+        stream=True
+    )
+
+    def event_stream() -> Generator[str, None, None]:
+        message_id = None
+
+        for chunk in response:
+            if not message_id:
+                message_id = f"chatcmpl-{int(time.time())}"
+
+            if chunk.type == 'content_block_delta':
+                data = {
+                    "id": message_id,
+                    "object": "chat.completion.chunk",
+                    "created": int(time.time()),
+                    "model": model,
+                    "choices": [
+                        {
+                            "index": 0,
+                            "delta": {
+                                "content": (
+                                    chunk.delta.text
+                                    if hasattr(chunk.delta, 'text')
+                                    else ""
+                                )
+                            },
+                            "logprobs": None,
+                            "finish_reason": None,
+                        }
+                    ],
+                }
+                yield f"data: {json.dumps(data)}\n\n"
+
+            elif chunk.type == 'content_block_stop':
+                data = {
+                    "id": message_id,
+                    "object": "chat.completion.chunk",
+                    "created": int(time.time()),
+                    "model": model,
+                    "choices": [
+                        {
+                            "index": 0,
+                            "delta": {},
+                            "logprobs": None,
+                            "finish_reason": "stop",
+                        }
+                    ],
+                }
+                yield f"data: {json.dumps(data)}\n\n"

+        yield "data: [DONE]\n\n"
+
+    return event_stream()
+
+
+def send_message(model: str, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
+    """
+    Send a message via the Anthropic provider without streaming.
+
+    Args:
+        model: The model identifier to use
+        messages: List of messages to process
+
+    Returns:
+        dict: The formatted response from the model
+    """
+    client = get_anthropic_client()
+    response = client.messages.create(
+        model=model,
+        max_tokens=1024,
         messages=messages
     )

+    content = response.content[0].text if response.content else ""
+
     return {
         "id": response.id,
         "object": "chat.completion",
         "created": int(time.time()),
-        "model": request.model,
-        "choices": [
-            ...
-        ],
-        "usage": {
-            ...
-            "total_tokens": response.usage.input_tokens + response.usage.output_tokens
-        }
+        "model": model,
+        "choices": [
+            {
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": content,
+                },
+                "logprobs": None,
+                "finish_reason": "stop",
+            }
+        ],
     }

-
-async def stream_response(request: ChatCompletionRequest):
-    ...
-    for chunk in ...:
-        if chunk.type == "content_block_delta":
-            data = {
-                "id": chunk.message.id,
-                "object": "chat.completion.chunk",
-                "created": int(time.time()),
-                "model": request.model,
-                "choices": [{
-                    "index": 0,
-                    "delta": {"content": chunk.delta.text if hasattr(chunk.delta, "text") else ""},
-                    "finish_reason": None
-                }]
-            }
-            yield f"data: {json.dumps(data)}\n\n"
-
-        elif chunk.type == "content_block_stop":
-            data = {
-                "id": chunk.message.id,
-                "object": "chat.completion.chunk",
-                "created": int(time.time()),
-                "model": request.model,
-                "choices": [{
-                    "index": 0,
-                    "delta": {},
-                    "finish_reason": "stop"
-                }]
-            }
-            yield f"data: {json.dumps(data)}\n\n"
-
+
+@app.post("/v1/chat/completions")
+@app.post("/chat/completions")
+async def generate_chat_completion(form_data: OpenAIChatCompletionForm):
+    """
+    Generate chat completions from the model.
+
+    Args:
+        form_data: The chat completion request parameters
+
+    Returns:
+        Union[StreamingResponse, dict]: Either a streaming response or a complete message
+    """
+    messages = [
+        {"role": message.role, "content": message.content}
+        for message in form_data.messages
+    ]
+    model = form_data.model
+
+    def job():
+        """Handle both streaming and non-streaming modes."""
+        if form_data.stream:
+            return StreamingResponse(
+                stream_message(model=model, messages=messages),
+                media_type="text/event-stream"
+            )
+        return send_message(model=model, messages=messages)
+
+    with ThreadPoolExecutor() as executor:
+        return await run_in_threadpool(job)
+
+
+@app.post("/v1/{pipeline_id}/filter/inlet")
+@app.post("/{pipeline_id}/filter/inlet")
+async def filter_inlet(pipeline_id: str, form_data: FilterForm):
+    """
+    Handle inlet filtering for the pipeline.
+
+    Args:
+        pipeline_id: The ID of the pipeline
+        form_data: The filter parameters
+
+    Returns:
+        dict: The processed request body
+    """
+    return form_data.body
+
+
+@app.post("/v1/{pipeline_id}/filter/outlet")
+@app.post("/{pipeline_id}/filter/outlet")
+async def filter_outlet(pipeline_id: str, form_data: FilterForm):
+    """
+    Handle outlet filtering for the pipeline.
+
+    Args:
+        pipeline_id: The ID of the pipeline
+        form_data: The filter parameters
+
+    Returns:
+        dict: The processed request body
+    """
+    return form_data.body
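
For context, a hedged sketch of how a client might exercise the new OpenAI-compatible streaming endpoint (not part of the commit). It assumes the app is served locally, e.g. `uvicorn main:app --port 8000`, and that the bearer token is an Anthropic API key, as the middleware expects; the host, port, and token below are placeholders.

import json
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",     # placeholder host/port
    headers={"Authorization": "Bearer sk-ant-..."},  # placeholder Anthropic key
    json={
        "model": "claude-3-haiku-20240307",
        "stream": True,
        "messages": [{"role": "user", "content": "Hello!"}],
    },
    stream=True,
)
for line in resp.iter_lines():
    if not line:
        continue
    payload = line.decode().removeprefix("data: ")
    if payload == "[DONE]":
        break
    chunk = json.loads(payload)
    # Each SSE chunk mirrors OpenAI's chat.completion.chunk shape.
    print(chunk["choices"][0]["delta"].get("content", ""), end="")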
schemas.py
ADDED
@@ -0,0 +1,41 @@
+from typing import List, Optional, Union
+from pydantic import BaseModel, ConfigDict, RootModel
+
+
+class ImageContent(BaseModel):
+    """Model for image content in messages."""
+    type: str
+    image_url: dict
+
+
+class TextContent(BaseModel):
+    """Model for text content in messages."""
+    type: str
+    text: str
+
+
+class MessageContent(RootModel):
+    """Model for message content that can be either text or image."""
+    root: Union[TextContent, ImageContent]
+
+
+class OpenAIChatMessage(BaseModel):
+    """Model for chat messages in OpenAI format."""
+    role: str
+    content: Union[str, List[MessageContent]]
+    model_config = ConfigDict(extra="allow")
+
+
+class OpenAIChatCompletionForm(BaseModel):
+    """Model for chat completion request parameters."""
+    stream: bool = True
+    model: str
+    messages: List[OpenAIChatMessage]
+    model_config = ConfigDict(extra="allow")
+
+
+class FilterForm(BaseModel):
+    """Model for filter request parameters."""
+    body: dict
+    user: Optional[dict] = None
+    model_config = ConfigDict(extra="allow")
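
A brief sketch of how these models behave in practice (illustrative, not part of the commit; it assumes Pydantic v2, which the `ConfigDict`/`RootModel` imports imply):

from schemas import OpenAIChatCompletionForm

form = OpenAIChatCompletionForm.model_validate({
    "model": "claude-3-5-sonnet-20241022",
    "messages": [
        # Plain-string content and typed content blocks both validate.
        {"role": "user", "content": "Hi"},
        {"role": "user", "content": [{"type": "text", "text": "Hi again"}]},
    ],
    "temperature": 0.7,  # unknown keys are kept because extra="allow"
})
print(form.stream)  # True: streaming is the default unless disabled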