AWeirdDev committed
Commit d80b380 · verified · 1 Parent(s): 066e534

Update app.py

Files changed (1)
  1. app.py +51 -8
app.py CHANGED
@@ -1,14 +1,29 @@
 import time
 import json
+from typing import List, Literal
 
 from fastapi import FastAPI
-from fastapi.responses import StreamingResponse
+from fastapi.responses import StreamingResponse, JSONResponse
+from pydantic import BaseModel
 
 from gradio_client import Client
 
 app = FastAPI()
 client = Client("AWeirdDev/mistral-7b-instruct-v0.2")
 
+class Message(BaseModel):
+    role: Literal["user", "assistant", "system"]
+    content: str
+
+class Payload(BaseModel):
+    stream: bool = False
+    model: Literal["mistral-7b-instruct-v0.2"] = "mistral-7b-instruct-v0.2"
+    messages: List[Message]
+    temperature: float
+    presence_penalty: float
+    frequency_penalty: float
+    top_p: float
+
 async def stream(iter):
     while True:
         try:
@@ -37,18 +52,46 @@ def make_chunk_obj(i, delta, fr):
 
 @app.get('/')
 async def index():
-    return { "message": "hello" }
+    return JSONResponse({ "message": "hello" })
 
 @app.post('/chat/completions')
-async def c_cmp():
+async def c_cmp(payload: Payload):
+    if not payload.stream:
+        return JSONResponse(
+            {
+                "id": str(time.time_ns()),
+                "object": "chat.completion",
+                "created": round(time.time()),
+                "model": payload.model,
+                "system_fingerprint": "wtf",
+                "choices": [
+                    {
+                        "index": 0,
+                        "message": {
+                            "role": "assistant",
+                            "content": client.predict(
+                                payload.messages.model_dump_json(),
+                                payload.temperature,
+                                4096,
+                                payload.top_p,
+                                payload.presence_penalty,
+                                api_name="/chat"
+                            )
+                        }
+                    }
+                ]
+            }
+        )
+
+
     def streamer():
         text = ""
         result = client.submit(
-            "Hello!!",
-            0.9, # float (numeric value between 0.0 and 1.0) in 'Temperature' Slider component
+            payload.messages.model_dump_json(),
+            payload.temperature, # float (numeric value between 0.0 and 1.0) in 'Temperature' Slider component
             4096, # float (numeric value between 0 and 1048) in 'Max new tokens' Slider component
-            .9, # float (numeric value between 0.0 and 1) in 'Top-p (nucleus sampling)' Slider component
-            1, # float (numeric value between 1.0 and 2.0) in 'Repetition penalty' Slider component
+            payload.top_p, # float (numeric value between 0.0 and 1) in 'Top-p (nucleus sampling)' Slider component
+            payload.presence_penalty, # float (numeric value between 1.0 and 2.0) in 'Repetition penalty' Slider component
             api_name="/chat"
         )
         for i, item in enumerate(result):
@@ -58,7 +101,7 @@ async def c_cmp():
             )
             text = item
 
-        yield "data: " + json.dumps(make_chunk_obj(i, delta, "stop"))
+        yield "data: " + json.dumps(make_chunk_obj(i, "", "stop"))
        yield "data: [END]"
 
    return StreamingResponse(streamer())
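
For reference, a non-streaming request to the endpoint added in this commit might look like the sketch below. The base URL and the use of the requests library are assumptions (they do not appear in the diff); the JSON fields mirror the Payload model introduced here, whose float fields have no defaults and are therefore all required.

# Sketch of a non-streaming call to the new /chat/completions endpoint.
# Assumption: the app is served locally (e.g. via uvicorn) on port 8000.
import requests

resp = requests.post(
    "http://localhost:8000/chat/completions",  # assumed local address
    json={
        "stream": False,
        "model": "mistral-7b-instruct-v0.2",
        "messages": [{"role": "user", "content": "Hello!"}],
        # Payload declares these floats without defaults, so all are required.
        "temperature": 0.9,
        "top_p": 0.9,
        "presence_penalty": 1.0,
        "frequency_penalty": 0.0,
    },
)
print(resp.json()["choices"][0]["message"]["content"])

Setting "stream": true instead would exercise the streamer() path, which yields "data: ..." chunks followed by a final "data: [END]" marker.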