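"""
OpenAI-compatible /chat/completions proxy for mistralai/Mistral-7B-Instruct-v0.2,
served with FastAPI on top of huggingface_hub.InferenceClient.
"""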
import asyncio
import json
import time
from typing import List, Literal
from fastapi import FastAPI
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from huggingface_hub import InferenceClient
app = FastAPI()
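# Remote text-generation backend. InferenceClient picks up a saved Hugging Face
# token (or the HF_TOKEN environment variable) automatically if auth is needed.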
client = InferenceClient(
"mistralai/Mistral-7B-Instruct-v0.2"
)
class Message(BaseModel):
role: Literal["user", "assistant"]
content: str
class Payload(BaseModel):
    # Mirrors the OpenAI chat.completions request body; frequency_penalty is
    # forwarded downstream as the HF repetition_penalty parameter.
stream: bool = False
model: Literal["mistral-7b-instruct-v0.2"] = "mistral-7b-instruct-v0.2"
messages: List[Message]
temperature: float = 0.9
frequency_penalty: float = 1.2
top_p: float = 0.9
async def stream(iterator):
    # Drain a blocking token iterator without blocking the event loop: each
    # next() call runs on a worker thread. (Currently unused; the
    # /chat/completions handler iterates generate() directly.)
    while True:
        try:
            value = await asyncio.to_thread(iterator.__next__)
            yield value
        except StopIteration:
            break
def format_prompt(messages: List[dict]):
    # Build the Mistral-Instruct prompt template: user turns are wrapped in
    # [INST] ... [/INST]; assistant turns are terminated with </s>. Callers
    # pass plain dicts (from Payload.model_dump()), hence the subscript access.
    prompt = "<s>"
    for message in messages:
        if message["role"] == "user":
            prompt += f"[INST] {message['content']} [/INST]"
        else:
            prompt += f" {message['content']}</s> "
    return prompt
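
# Example of the template above (single user turn):
#   format_prompt([{"role": "user", "content": "Hi"}]) == "<s>[INST] Hi [/INST]"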
def make_chunk_obj(i, delta, fr):
    # Shape one streamed token as an OpenAI chat.completion.chunk object.
    return {
"id": str(time.time_ns()),
"object": "chat.completion.chunk",
"created": round(time.time()),
"model": "mistral-7b-instruct-v0.2",
"system_fingerprint": "wtf",
"choices": [
{
"index": i,
"delta": {
"content": delta
},
"finish_reason": fr
}
]
}
def generate(
messages,
temperature=0.9,
max_new_tokens=256,
top_p=0.95,
repetition_penalty=1.0,
):
    # The backend requires a strictly positive temperature, so clamp to a small floor.
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
top_p = float(top_p)
generate_kwargs = dict(
temperature=temperature,
max_new_tokens=max_new_tokens,
top_p=top_p,
repetition_penalty=repetition_penalty,
do_sample=True,
seed=None
)
formatted_prompt = format_prompt(messages)
    token_stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    for response in token_stream:
        t = response.token.text
        # Drop the end-of-sequence marker rather than forwarding it to clients.
        if t != "</s>":
            yield t
def generate_norm(*args) -> str:
    # Non-streaming variant: exhaust the token generator and join the pieces.
    return "".join(generate(*args))
@app.get('/')
async def index():
return JSONResponse({ "message": "hello", "url": "https://aweirddev-mistral-7b-instruct-v0-2-leicht.hf.space" })
@app.post('/chat/completions')
async def c_cmp(payload: Payload):
if not payload.stream:
return JSONResponse(
{
"id": str(time.time_ns()),
"object": "chat.completion",
"created": round(time.time()),
"model": payload.model,
"system_fingerprint": "wtf",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": generate_norm(
payload.model_dump()['messages'],
payload.temperature,
4096,
payload.top_p,
payload.frequency_penalty
)
}
}
]
}
)
    def streamer():
        # Emit OpenAI-style server-sent events: one chat.completion.chunk per
        # token, then a final chunk carrying finish_reason and a [DONE] sentinel.
        result = generate(
            payload.model_dump()["messages"],
            payload.temperature,
            4096,                       # max_new_tokens
            payload.top_p,
            payload.frequency_penalty,  # used as repetition_penalty
        )
        for item in result:
            yield "data: " + json.dumps(make_chunk_obj(0, item, None)) + "\n\n"
        yield "data: " + json.dumps(make_chunk_obj(0, "", "stop")) + "\n\n"
        yield "data: [DONE]\n\n"
    return StreamingResponse(streamer(), media_type="text/event-stream")
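
# A minimal local-run sketch (assumes uvicorn is installed and this file is
# saved as app.py):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
# Sample non-streaming request:
#
#   curl http://localhost:8000/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "mistral-7b-instruct-v0.2",
#          "messages": [{"role": "user", "content": "Hello!"}]}'
#
# Set "stream": true in the body to receive server-sent events instead.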