"""FastAPI app serving a static page, a ``/map`` scoring endpoint and an
``/answer`` WebSocket whose work can be forwarded to a client connected on
``/accelerate``."""
from typing import Union

from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.websockets import WebSocket, WebSocketDisconnect, WebSocketState

from accelerator import Accelerator
from answerer import Answerer
from mapper import Mapper

mapper = Mapper("multi-qa-distilbert-cos-v1")
answerer = Answerer(
    model="rwkv-5-world-3b-v2-20231118-ctx16k",
    vocab="rwkv_vocab_v20230424",
    strategy="cpu bf16",
    ctx_limit=16 * 1024,
)
accelerator = Accelerator()

app = FastAPI()

# NOTE(review): the page markup is just a newline here — confirm the intended
# HTML was not lost.
HTML = """
"""


@app.get("/")
def index():
    """Serve the static ``HTML`` page."""
    return HTMLResponse(HTML)


@app.websocket("/accelerate")
async def accelerate(ws: WebSocket):
    """Hand the incoming WebSocket to the module-level ``Accelerator``.

    When ``accelerator.connected()`` reports true, ``/answer`` forwards its
    prompts through ``accelerator.accelerate`` instead of the local model.
    """
    # Previously this handler was also named ``answer`` and was shadowed by
    # the ``/answer`` handler below (and ``url_path_for("answer")`` resolved
    # to this route).
    await accelerator.connect(ws)


@app.post("/map")
def map(query: Union[str, None], items: Union[list[str], None]):
    """Score ``items`` against ``query`` with ``mapper`` and return JSON.

    NOTE(review): with no explicit ``Query``/``Body`` markers FastAPI should
    bind ``query`` from the query string and ``items`` from the JSON request
    body; confirm this matches the clients.
    """
    # Named ``map`` (shadowing the builtin) to keep the existing route name.
    scores = mapper(query, items)
    return JSONResponse(jsonable_encoder(scores))


async def _send_if_connected(ws: WebSocket, text: str) -> bool:
    """Send ``text`` on ``ws`` if the client is still there.

    Returns False when the client is known to be gone, either because its
    state is DISCONNECTED or because the send raised ``WebSocketDisconnect``.
    Returns True otherwise. A socket still in the CONNECTING state is skipped
    without sending.
    """
    state = ws.client_state
    if state == WebSocketState.DISCONNECTED:
        return False
    if state == WebSocketState.CONNECTED:
        # ``client_state`` is only refreshed on receive, so a client that left
        # mid-reply is typically detected by the send itself failing.
        try:
            await ws.send_text(text)
        except WebSocketDisconnect:
            return False
    return True


@app.websocket("/answer")
async def answer(ws: WebSocket):
    """Answer one prompt received over the WebSocket, then close it.

    Reads a single text message. If an accelerator is connected, the prompt
    is sent to it and its reply is returned as one message. Otherwise the
    local ``answerer`` is run and every item it yields is sent as its own
    message. The handler stops quietly if the client disconnects.
    """
    await ws.accept()
    try:
        prompt = await ws.receive_text()
    except WebSocketDisconnect:
        return

    if accelerator.connected():
        output = await accelerator.accelerate(prompt)
        if not await _send_if_connected(ws, output):
            return
    else:
        # NOTE(review): the meaning of ``32`` is defined by ``Answerer`` (not
        # visible here) — presumably a generation budget; confirm.
        async for chunk in answerer(prompt, 32):
            if not await _send_if_connected(ws, chunk):
                return

    try:
        await ws.close()
    except WebSocketDisconnect:
        # Client already went away; nothing left to close.
        return