Update app.py
app.py CHANGED

@@ -1,14 +1,29 @@
 import time
 import json
+from typing import List, Literal
 
 from fastapi import FastAPI
-from fastapi.responses import StreamingResponse
+from fastapi.responses import StreamingResponse, JSONResponse
+from pydantic import BaseModel
 
 from gradio_client import Client
 
 app = FastAPI()
 client = Client("AWeirdDev/mistral-7b-instruct-v0.2")
 
+class Message(BaseModel):
+    role: Literal["user", "assistant", "system"]
+    content: str
+
+class Payload(BaseModel):
+    stream: bool = False
+    model: Literal["mistral-7b-instruct-v0.2"] = "mistral-7b-instruct-v0.2"
+    messages: List[Message]
+    temperature: float
+    presence_penalty: float
+    frequency_penalty: float
+    top_p: float
+
 async def stream(iter):
     while True:
         try:
@@ -37,18 +52,46 @@ def make_chunk_obj(i, delta, fr):
 
 @app.get('/')
 async def index():
-    return { "message": "hello" }
+    return JSONResponse({ "message": "hello" })
 
 @app.post('/chat/completions')
-async def c_cmp():
+async def c_cmp(payload: Payload):
+    if not payload.stream:
+        return JSONResponse(
+            {
+                "id": str(time.time_ns()),
+                "object": "chat.completion",
+                "created": round(time.time()),
+                "model": payload.model,
+                "system_fingerprint": "wtf",
+                "choices": [
+                    {
+                        "index": 0,
+                        "message": {
+                            "role": "assistant",
+                            "content": client.predict(
+                                payload.messages.model_dump_json(),
+                                payload.temperature,
+                                4096,
+                                payload.top_p,
+                                payload.presence_penalty,
+                                api_name="/chat"
+                            )
+                        }
+                    }
+                ]
+            }
+        )
+
+
     def streamer():
         text = ""
         result = client.submit(
-
-
+            payload.messages.model_dump_json(),
+            payload.temperature, # float (numeric value between 0.0 and 1.0) in 'Temperature' Slider component
             4096, # float (numeric value between 0 and 1048) in 'Max new tokens' Slider component
-
-
+            payload.top_p, # float (numeric value between 0.0 and 1) in 'Top-p (nucleus sampling)' Slider component
+            payload.presence_penalty, # float (numeric value between 1.0 and 2.0) in 'Repetition penalty' Slider component
            api_name="/chat"
         )
         for i, item in enumerate(result):
@@ -58,7 +101,7 @@ async def c_cmp():
             )
             text = item
 
-        yield "data: " + json.dumps(make_chunk_obj(i,
+        yield "data: " + json.dumps(make_chunk_obj(i, "", "stop"))
         yield "data: [END]"
 
     return StreamingResponse(streamer())
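
For reference, a minimal sketch of how the new /chat/completions route could be exercised once the app is running (for example via uvicorn app:app --port 8000). The host, port, use of the requests library, and the sample field values are assumptions for illustration, not part of this commit.

# Hypothetical client call for the endpoint added above; the URL and the
# sample values are illustrative assumptions, not taken from the commit.
import requests

resp = requests.post(
    "http://localhost:8000/chat/completions",
    json={
        "stream": False,                      # take the non-streaming JSONResponse path
        "model": "mistral-7b-instruct-v0.2",
        "messages": [{"role": "user", "content": "Hello!"}],
        "temperature": 0.7,                   # 'Temperature' slider, 0.0-1.0
        "presence_penalty": 1.1,              # forwarded to the 'Repetition penalty' slider, 1.0-2.0
        "frequency_penalty": 0.0,             # accepted by Payload but not forwarded to the Space
        "top_p": 0.95,                        # 'Top-p (nucleus sampling)' slider
    },
)
print(resp.json()["choices"][0]["message"]["content"])

With stream set to true, the same route instead returns a StreamingResponse whose body is a sequence of "data: ..." chunks built by make_chunk_obj and terminated by "data: [END]".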