cotcotquedec committed
Commit 154ea17 · 1 Parent(s): d4c432e

refactor(main): improve code structure and error handling


Refactored the main application file to improve readability and maintainability. Introduced logging for better error tracking, and replaced ad-hoc JSONResponse returns with HTTPException for more consistent error handling. Added detailed docstrings to functions and endpoints, and restructured the code to separate concerns and improve the logical flow.
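The commit wires its log lines through `logging.getLogger()`, which returns the root logger; left unconfigured, Python's "last resort" handler only writes WARNING and above to stderr, with no timestamp or formatting. Below is a minimal sketch of the setup a deployment might add — it is not part of this commit, and the level and format string are illustrative assumptions:

```python
import logging

# Hypothetical setup, not part of this commit: main.py only calls
# logging.getLogger(). Without configuration, the root logger falls back to
# Python's "last resort" handler, which prints WARNING and above to stderr
# without timestamps or logger names.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s - %(message)s",
)

logger = logging.getLogger()
logger.error("HTTP Exception - Status: 401 - Detail: No authorization header - Path: /models")
```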

Additionally, created a new `schemas.py` file that defines the request and response data models with Pydantic, validating and organizing the payloads the API accepts and returns.
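As a quick illustration of what the new schemas enforce at the request boundary, a hedged sketch using Pydantic v2 (matching the `ConfigDict`/`RootModel` usage in the diff below; the payload values are invented):

```python
from pydantic import ValidationError

from schemas import OpenAIChatCompletionForm

# A well-formed payload parses into typed message objects.
form = OpenAIChatCompletionForm.model_validate({
    "model": "claude-3-5-sonnet-20241022",
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": False,
})
print(form.messages[0].role)  # -> user

# A payload missing the required `messages` field is rejected up front,
# before any endpoint logic runs.
try:
    OpenAIChatCompletionForm.model_validate({"model": "claude-3-haiku-20240307"})
except ValidationError as exc:
    print(exc.error_count(), "validation error(s)")  # -> 1 validation error(s)
```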

This refactor aims to make the application more robust and scalable by keeping the code modular and easier to maintain.

Files changed (2)
  1. main.py +235 -116
  2. schemas.py +41 -0
main.py CHANGED
@@ -1,26 +1,52 @@
-import os
-from fastapi import FastAPI, HTTPException, Depends, Header, Request, Response
-from fastapi.responses import JSONResponse, StreamingResponse
-from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
-from pydantic import BaseModel
-from typing import List, Optional
-from anthropic import Anthropic
 import json
+import logging
+import os
 import time
+from concurrent.futures import ThreadPoolExecutor
 from contextvars import ContextVar
+from typing import Any, Dict, Generator, List
 
+from anthropic import Anthropic
+from fastapi import FastAPI, HTTPException, Request, Response
+from fastapi.responses import JSONResponse, StreamingResponse
+from fastapi.security import HTTPBearer
+from starlette.concurrency import run_in_threadpool
+
+from schemas import OpenAIChatCompletionForm, FilterForm
+
+# logger
+logger = logging.getLogger()
+
+# FastAPI app initialization
 app = FastAPI()
 security = HTTPBearer()
 
-# Context variable to store the token
+# Context variable for token storage
 token_context = ContextVar('token', default=None)
 
-# Liste des endpoints qui ne nécessitent pas d'authentification
+# Endpoints that don't require authentication
 PUBLIC_ENDPOINTS = {"/"}
 
+# Available Anthropic models
+AVAILABLE_MODELS = [
+    "claude-3-haiku-20240307",
+    "claude-3-opus-20240229",
+    "claude-3-sonnet-20240229",
+    "claude-3-5-sonnet-20241022"
+]
+
 @app.middleware("http")
 async def auth_middleware(request: Request, call_next):
-    # Skip authentication for public endpoints
+    """
+    Middleware for handling authentication and response logging.
+
+    Args:
+        request: The incoming HTTP request
+        call_next: The next middleware in the chain
+
+    Returns:
+        Response: The processed HTTP response
+    """
     if request.url.path in PUBLIC_ENDPOINTS:
         start_time = time.perf_counter()
         response = await call_next(request)
@@ -31,22 +57,20 @@ async def auth_middleware(request: Request, call_next):
     try:
         auth_header = request.headers.get('Authorization')
         if not auth_header:
-            return JSONResponse(
+            raise HTTPException(
                 status_code=401,
-                content={"detail": "No authorization header"}
+                detail="No authorization header"
             )
 
         scheme, token = auth_header.split()
         if scheme.lower() != 'bearer':
-            return JSONResponse(
+            raise HTTPException(
                 status_code=401,
-                content={"detail": "Invalid authentication scheme"}
+                detail="Invalid authentication scheme"
             )
 
-        # Store token in context
         token_context.set(token)
 
-        # Add processing time header
         start_time = time.perf_counter()
         response = await call_next(request)
         process_time = time.perf_counter() - start_time
@@ -54,46 +78,60 @@ async def auth_middleware(request: Request, call_next):
 
         return response
 
+    except HTTPException as http_ex:
+        logger.error(
+            f"HTTP Exception - Status: {http_ex.status_code} - "
+            f"Detail: {http_ex.detail} - Path: {request.url.path}"
+        )
+        return JSONResponse(
+            status_code=http_ex.status_code,
+            content={"detail": http_ex.detail}
+        )
     except Exception as e:
+        logger.error(
+            f"Unexpected error in middleware - Error: {str(e)} - "
+            f"Path: {request.url.path}",
+            exc_info=True
+        )
         return JSONResponse(
-            status_code=401,
-            content={"detail": "Invalid authorization header"}
+            status_code=500,
+            content={"detail": "Internal server error"}
         )
 
-# Function to get Anthropic client with current token
+
 def get_anthropic_client():
+    """
+    Get an authenticated Anthropic client using the current token.
+
+    Returns:
+        Anthropic: An authenticated Anthropic client instance
+
+    Raises:
+        HTTPException: If no authorization token is found
+    """
     token = token_context.get()
     if not token:
         raise HTTPException(status_code=401, detail="No authorization token found")
     return Anthropic(api_key=token)
 
-# Available models
-AVAILABLE_MODELS = [
-    "claude-3-haiku-20240307",
-    "claude-3-opus-20240229",
-    "claude-3-sonnet-20240229",
-    "claude-3-5-sonnet-20241022"
-]
-
-class Message(BaseModel):
-    role: str
-    content: str
-
-class ChatCompletionRequest(BaseModel):
-    model: str
-    messages: List[Message]
-    stream: bool = False
-    max_tokens: Optional[int] = 1024
 
+@app.get("/v1")
 @app.get("/")
 async def read_root():
+    """Root endpoint for API health check."""
     return {"Hello": "World!"}
 
+
+@app.get("/v1/models")
 @app.get("/models")
 async def get_models():
-    """Get available Anthropic models."""
-    # Test the token by creating a client
-    get_anthropic_client()
+    """
+    Get available Anthropic models.
+
+    Returns:
+        JSONResponse: List of available models and their details
+    """
+    get_anthropic_client()  # Verify token validity
 
     models = [
         {
@@ -115,97 +153,178 @@ async def get_models():
         }
     )
 
-@app.post("/v1/chat/completions")
-async def create_chat_completion(request: ChatCompletionRequest):
-    """Generate chat completions using Anthropic models."""
-    try:
-        if request.stream:
-            return StreamingResponse(
-                stream_response(request),
-                media_type="text/event-stream"
-            )
-        else:
-            return await generate_completion(request)
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
 
-async def generate_completion(request: ChatCompletionRequest):
-    """Generate a non-streaming completion."""
-    messages = [{"role": m.role, "content": m.content} for m in request.messages]
-
-    # Get client with current token
+def stream_message(
+    model: str,
+    messages: List[Dict[str, Any]]
+) -> Generator[str, None, None]:
+    """
+    Stream messages using the specified model.
+
+    Args:
+        model: The model identifier to use
+        messages: List of messages to process
+
+    Returns:
+        Generator: Stream of SSE formatted responses
+    """
     client = get_anthropic_client()
-
     response = client.messages.create(
-        model=request.model,
-        max_tokens=request.max_tokens,
+        model=model,
+        max_tokens=1024,
+        messages=messages,
+        stream=True
+    )
+
+    def event_stream() -> Generator[str, None, None]:
+        message_id = None
+
+        for chunk in response:
+            if not message_id:
+                message_id = f"chatcmpl-{int(time.time())}"
+
+            if chunk.type == 'content_block_delta':
+                data = {
+                    "id": message_id,
+                    "object": "chat.completion.chunk",
+                    "created": int(time.time()),
+                    "model": model,
+                    "choices": [
+                        {
+                            "index": 0,
+                            "delta": {
+                                "content": (
+                                    chunk.delta.text
+                                    if hasattr(chunk.delta, 'text')
+                                    else ""
+                                )
+                            },
+                            "logprobs": None,
+                            "finish_reason": None,
+                        }
+                    ],
+                }
+                yield f"data: {json.dumps(data)}\n\n"
+
+            elif chunk.type == 'content_block_stop':
+                data = {
+                    "id": message_id,
+                    "object": "chat.completion.chunk",
+                    "created": int(time.time()),
+                    "model": model,
+                    "choices": [
+                        {
+                            "index": 0,
+                            "delta": {},
+                            "logprobs": None,
+                            "finish_reason": "stop",
+                        }
+                    ],
+                }
+                yield f"data: {json.dumps(data)}\n\n"
+
+        yield "data: [DONE]\n\n"
+
+    return event_stream()
+
+
+def send_message(model: str, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
+    """
+    Send a message via the Anthropic provider without streaming.
+
+    Args:
+        model: The model identifier to use
+        messages: List of messages to process
+
+    Returns:
+        dict: The formatted response from the model
+    """
+    client = get_anthropic_client()
+    response = client.messages.create(
+        model=model,
+        max_tokens=1024,
         messages=messages
     )
 
+    content = response.content[0].text if response.content else ""
+
     return {
         "id": response.id,
         "object": "chat.completion",
         "created": int(time.time()),
-        "model": request.model,
-        "choices": [{
-            "index": 0,
-            "message": {
-                "role": "assistant",
-                "content": response.content[0].text if response.content else "",
-            },
-            "finish_reason": "stop"
-        }],
-        "usage": {
-            "prompt_tokens": response.usage.input_tokens,
-            "completion_tokens": response.usage.output_tokens,
-            "total_tokens": response.usage.input_tokens + response.usage.output_tokens
-        }
+        "model": model,
+        "choices": [
+            {
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": content,
+                },
+                "logprobs": None,
+                "finish_reason": "stop",
+            }
+        ],
     }
 
-async def stream_response(request: ChatCompletionRequest):
-    """Stream the completion response."""
-    messages = [{"role": m.role, "content": m.content} for m in request.messages]
-
-    # Get client with current token
-    client = get_anthropic_client()
-
-    response = client.messages.create(
-        model=request.model,
-        max_tokens=request.max_tokens,
-        messages=messages,
-        stream=True
-    )
-
-    for chunk in response:
-        if chunk.type == "message_start":
-            continue
-
-        if chunk.type == "content_block_delta":
-            data = {
-                "id": chunk.message.id,
-                "object": "chat.completion.chunk",
-                "created": int(time.time()),
-                "model": request.model,
-                "choices": [{
-                    "index": 0,
-                    "delta": {"content": chunk.delta.text if hasattr(chunk.delta, "text") else ""},
-                    "finish_reason": None
-                }]
-            }
-            yield f"data: {json.dumps(data)}\n\n"
-
-        elif chunk.type == "content_block_stop":
-            data = {
-                "id": chunk.message.id,
-                "object": "chat.completion.chunk",
-                "created": int(time.time()),
-                "model": request.model,
-                "choices": [{
-                    "index": 0,
-                    "delta": {},
-                    "finish_reason": "stop"
-                }]
-            }
-            yield f"data: {json.dumps(data)}\n\n"
 
-    yield "data: [DONE]\n\n"
+@app.post("/v1/chat/completions")
+@app.post("/chat/completions")
+async def generate_chat_completion(form_data: OpenAIChatCompletionForm):
+    """
+    Generate chat completions from the model.
+
+    Args:
+        form_data: The chat completion request parameters
+
+    Returns:
+        Union[StreamingResponse, dict]: Either a streaming response or a complete message
+    """
+    messages = [
+        {"role": message.role, "content": message.content}
+        for message in form_data.messages
+    ]
+    model = form_data.model
+
+    def job():
+        """Handle both streaming and non-streaming modes."""
+        if form_data.stream:
+            return StreamingResponse(
+                stream_message(model=model, messages=messages),
+                media_type="text/event-stream"
+            )
+        return send_message(model=model, messages=messages)
+
+    with ThreadPoolExecutor() as executor:
+        return await run_in_threadpool(job)
+
+
+@app.post("/v1/{pipeline_id}/filter/inlet")
+@app.post("/{pipeline_id}/filter/inlet")
+async def filter_inlet(pipeline_id: str, form_data: FilterForm):
+    """
+    Handle inlet filtering for the pipeline.
+
+    Args:
+        pipeline_id: The ID of the pipeline
+        form_data: The filter parameters
+
+    Returns:
+        dict: The processed request body
+    """
+    return form_data.body
+
+
+@app.post("/v1/{pipeline_id}/filter/outlet")
+@app.post("/{pipeline_id}/filter/outlet")
+async def filter_outlet(pipeline_id: str, form_data: FilterForm):
+    """
+    Handle outlet filtering for the pipeline.

+    Args:
+        pipeline_id: The ID of the pipeline
+        form_data: The filter parameters
+
+    Returns:
+        dict: The processed request body
+    """
+    return form_data.body
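For reference, a hedged sketch of exercising the refactored completion endpoint once the app is served. The host, port, and token are placeholders, and the `requests` client is an assumption for the sketch — the service itself does not depend on it:

```python
import requests  # assumed client for this sketch; not a dependency of the service

# Assumes the app is running locally, e.g. `uvicorn main:app --port 8000`,
# and that a valid Anthropic API key is passed as the bearer token.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    headers={"Authorization": "Bearer <ANTHROPIC_API_KEY>"},  # placeholder token
    json={
        "model": "claude-3-5-sonnet-20241022",
        "messages": [{"role": "user", "content": "Say hello"}],
        "stream": False,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```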
schemas.py ADDED
@@ -0,0 +1,41 @@
+from typing import List, Optional, Union
+from pydantic import BaseModel, ConfigDict, RootModel
+
+
+class ImageContent(BaseModel):
+    """Model for image content in messages."""
+    type: str
+    image_url: dict
+
+
+class TextContent(BaseModel):
+    """Model for text content in messages."""
+    type: str
+    text: str
+
+
+class MessageContent(RootModel):
+    """Model for message content that can be either text or image."""
+    root: Union[TextContent, ImageContent]
+
+
+class OpenAIChatMessage(BaseModel):
+    """Model for chat messages in OpenAI format."""
+    role: str
+    content: Union[str, List[MessageContent]]
+    model_config = ConfigDict(extra="allow")
+
+
+class OpenAIChatCompletionForm(BaseModel):
+    """Model for chat completion request parameters."""
+    stream: bool = True
+    model: str
+    messages: List[OpenAIChatMessage]
+    model_config = ConfigDict(extra="allow")
+
+
+class FilterForm(BaseModel):
+    """Model for filter request parameters."""
+    body: dict
+    user: Optional[dict] = None
+    model_config = ConfigDict(extra="allow")
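One design note on these schemas: `model_config = ConfigDict(extra="allow")` means undeclared OpenAI-style fields are accepted and retained rather than rejected with a 422. A small sketch of that behavior under Pydantic v2 (the `temperature` field is an illustrative assumption, not something the service reads):

```python
from schemas import OpenAIChatCompletionForm

# extra="allow" keeps undeclared fields instead of raising a validation error.
form = OpenAIChatCompletionForm(
    model="claude-3-opus-20240229",
    messages=[{"role": "user", "content": "Hi"}],
    temperature=0.7,  # not declared on the schema; kept because extra="allow"
)
print(form.model_extra)  # -> {'temperature': 0.7}
```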