AWeirdDev committed on
Commit
5bb4c9e
·
verified ·
1 Parent(s): a19b818

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import json
3
+
4
+ from fastapi import FastAPI
5
+ from fastapi.responses import StreamingResponse
6
+ from gradio_client import Client
7
+
8
+ app = FastAPI()
9
+ client = Client("AWeirdDev/mistral-7b-instruct-v0.2")
10
+
11
async def stream(iter):
    """Asynchronously yield the items of a blocking iterator.

    Every ``next()`` call runs in a worker thread through
    ``asyncio.to_thread`` so that a slow iterator does not stall the
    event loop between items.

    Args:
        iter: A synchronous iterator. The name shadows the builtin but is
            kept so existing keyword callers keep working.

    Yields:
        Each item produced by ``iter``, in order.
    """
    # Imported locally: the file's visible import block does not provide asyncio.
    import asyncio

    # Exhaustion is reported with a private sentinel rather than by letting
    # StopIteration escape the worker thread: asyncio futures refuse to carry
    # StopIteration, so an ``except StopIteration`` around the await never fires.
    exhausted = object()
    while True:
        value = await asyncio.to_thread(next, iter, exhausted)
        if value is exhausted:
            break
        yield value
18
+
19
def make_chunk_obj(i, delta, fr):
    """Build one ``chat.completion.chunk`` payload as a plain dict.

    Args:
        i: Value stored under ``choices[0]["index"]``.
        delta: Text stored under ``choices[0]["delta"]["content"]``.
        fr: Value stored under ``choices[0]["finish_reason"]``.

    Returns:
        A dict with the keys ``id``, ``object``, ``created``, ``model``,
        ``system_fingerprint`` and ``choices`` (a one-element list).
    """
    # The id is the current clock reading in nanoseconds, rendered as a string.
    chunk_id = str(time.time_ns())
    # Creation time is the current epoch time rounded to whole seconds.
    created_at = round(time.time())

    choice = {
        "index": i,
        "delta": {"content": delta},
        "finish_reason": fr,
    }

    payload = {}
    payload["id"] = chunk_id
    payload["object"] = "chat.completion.chunk"
    payload["created"] = created_at
    payload["model"] = "mistral-7b-instruct-v0.2"
    payload["system_fingerprint"] = "wtf"
    payload["choices"] = [choice]
    return payload
36
+
37
@app.get('/chat/completions')
async def index():
    """Stream a chat completion from the upstream Space as server-sent events.

    Submits a fixed prompt to the Gradio client's ``/chat`` endpoint and
    relays the output as ``chat.completion.chunk`` JSON objects, one per
    SSE ``data:`` event, followed by a chunk whose ``finish_reason`` is
    ``"stop"`` and a ``data: [END]`` terminator event.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """

    def streamer():
        # Text relayed so far. Each upstream item is treated as the cumulative
        # output, so only the part beyond this prefix is sent as the delta.
        text = ""
        result = client.submit(
            "Hello!!",
            0.9,  # float (numeric value between 0.0 and 1.0) in 'Temperature' Slider component
            # NOTE(review): 4096 is above the 0-1048 range quoted below — confirm the Space accepts it.
            4096,  # float (numeric value between 0 and 1048) in 'Max new tokens' Slider component
            .9,  # float (numeric value between 0.0 and 1) in 'Top-p (nucleus sampling)' Slider component
            1,  # float (numeric value between 1.0 and 2.0) in 'Repetition penalty' Slider component
            api_name="/chat"
        )
        for item in result:
            delta = item[len(text):]
            text = item
            # Only one choice is ever produced, so its index is always 0.
            # Every SSE event must end with a blank line, otherwise clients
            # cannot tell where one event stops and the next begins.
            yield "data: " + json.dumps(make_chunk_obj(0, delta, None)) + "\n\n"

        # The closing chunk carries no new text: resending the last delta
        # would duplicate content on the client, and an upstream that yielded
        # nothing would leave that delta undefined.
        yield "data: " + json.dumps(make_chunk_obj(0, "", "stop")) + "\n\n"
        # NOTE(review): OpenAI-compatible clients expect "[DONE]" here; "[END]"
        # is kept so existing consumers of this endpoint are not broken.
        yield "data: [END]\n\n"

    return StreamingResponse(streamer(), media_type="text/event-stream")