File size: 1,777 Bytes
5bb4c9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import asyncio
import json
import time

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from gradio_client import Client

# ASGI application object; the /chat/completions route is registered on it.
app = FastAPI()
# Module-level Gradio client for the "AWeirdDev/mistral-7b-instruct-v0.2" app,
# shared by every request that calls client.submit(...).
# NOTE(review): presumably this contacts the remote app at import time — confirm.
client = Client("AWeirdDev/mistral-7b-instruct-v0.2")

async def stream(iter):
    """Asynchronously yield every item of a blocking iterator.

    Each ``next()`` call runs in a worker thread via ``asyncio.to_thread`` so
    a slow iterator does not block the event loop while it produces a value.

    Args:
        iter: A synchronous iterator (anything ``next()`` accepts). The name
            shadows the builtin but is kept so keyword callers keep working.

    Yields:
        The items produced by ``iter``, in order.
    """
    # A StopIteration raised in the worker thread cannot be set on an asyncio
    # future, so exhaustion is reported with a private sentinel instead of
    # being caught as an exception.
    exhausted = object()
    while True:
        value = await asyncio.to_thread(next, iter, exhausted)
        if value is exhausted:
            break
        yield value

def make_chunk_obj(i, delta, fr):
    """Build one chat-completion chunk payload as a plain dict.

    Args:
        i: Value stored as the single choice's ``index``.
        delta: Text stored under the choice's ``delta.content``.
        fr: Value stored as the choice's ``finish_reason``.

    Returns:
        A dict with the keys ``id`` (current nanosecond clock reading as a
        string), ``object``, ``created`` (current time rounded to whole
        seconds), ``model``, ``system_fingerprint`` and ``choices`` (a
        one-element list).
    """
    chunk_id = str(time.time_ns())
    created_at = round(time.time())
    only_choice = {
        "index": i,
        "delta": {"content": delta},
        "finish_reason": fr,
    }
    payload = {
        "id": chunk_id,
        "object": "chat.completion.chunk",
        "created": created_at,
        "model": "mistral-7b-instruct-v0.2",
        "system_fingerprint": "wtf",
    }
    payload["choices"] = [only_choice]
    return payload

@app.get('/chat/completions')
async def index():
    """Stream a chat completion for a fixed prompt as server-sent events.

    Submits the hard-coded prompt ``"Hello!!"`` to the ``/chat`` endpoint of
    the module-level Gradio ``client`` and relays each newly generated piece
    of text as a ``chat.completion.chunk`` JSON object.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream`` whose
        body is a series of ``data: ...`` events, closed by a chunk with
        ``finish_reason="stop"`` and a final ``data: [END]`` event.
    """
    def sse_event(payload):
        """Frame *payload* as one event; the blank line terminates it."""
        return "data: " + payload + "\n\n"

    def streamer():
        """Yield one event per increment of generated text."""
        job = client.submit(
            "Hello!!",
            0.9,   # float (numeric value between 0.0 and 1.0) in 'Temperature' Slider component
            # NOTE(review): 4096 is above the 0-1048 range documented for this slider — confirm.
            4096,  # float (numeric value between 0 and 1048) in 'Max new tokens' Slider component
            .9,    # float (numeric value between 0.0 and 1) in 'Top-p (nucleus sampling)' Slider component
            1,     # float (numeric value between 1.0 and 2.0) in 'Repetition penalty' Slider component
            api_name="/chat"
        )
        seen = ""
        # Pre-set so the closing chunk is well defined even if the job
        # produces no output at all.
        last_index = 0
        for last_index, snapshot in enumerate(job):
            # Each item is handled as the full text so far, so only the part
            # beyond what was already sent is forwarded.
            fresh = snapshot[len(seen):]
            yield sse_event(json.dumps(make_chunk_obj(last_index, fresh, None)))
            seen = snapshot

        # The closing chunk carries no text: repeating the last delta here
        # would make clients that concatenate deltas duplicate it.
        yield sse_event(json.dumps(make_chunk_obj(last_index, "", "stop")))
        # NOTE(review): OpenAI-compatible clients usually expect "[DONE]";
        # "[END]" is kept so existing consumers are not broken.
        yield sse_event("[END]")

    return StreamingResponse(streamer(), media_type="text/event-stream")

# No module-level output: `result` exists only inside the request handler's
# generator, so referencing it here would raise NameError at import time.