"""FastAPI app with a static page, an item-mapping endpoint and a streaming
answer websocket.

Both model wrappers (``Mapper`` and ``Answerer``) are constructed at import
time, so importing this module has that side effect.
"""
from typing import List, Union

from fastapi import FastAPI, Query, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse, JSONResponse

from answerer import Answerer
from mapper import Mapper

# Wrapper used by /map; configured with the sentence-transformers
# "multi-qa-distilbert-cos-v1" checkpoint.
mapper = Mapper(
    repo="sentence-transformers",
    model="multi-qa-distilbert-cos-v1",
)

# Wrapper used by /answer; configured with an RWKV-5 World 3B checkpoint on
# CPU in bf16, with ctx_limit 16*1024 (matching the "ctx16k" checkpoint name).
answerer = Answerer(
    model="RWKV-5-World-3B-v2-20231118-ctx16k.pth",
    vocab="rwkv_vocab_v20230424",
    strategy="cpu bf16",
    ctx_limit=16*1024,
)

app = FastAPI()

# Page served at "/".
# NOTE(review): this literal is currently just two newlines, so "/" serves a
# blank page — confirm the intended markup is present.
HTML = """

"""


@app.get("/")
def index():
    """Serve the static ``HTML`` page."""
    return HTMLResponse(HTML)


@app.get("/map")
def map(query: str, items: List[str] = Query(...)):
    """Return the indices ``mapper`` computes for ``items`` against ``query``.

    Both arguments come from the URL; ``items`` is a repeated query parameter,
    e.g. ``/map?query=q&items=a&items=b``. The explicit ``Query(...)`` is
    required: without it FastAPI reads a ``List[str]`` parameter from the JSON
    request body, which a GET request from a browser ``fetch`` cannot send.
    """
    # NOTE: this name shadows the builtin ``map`` at module level; it is kept
    # so that existing imports of this handler keep working.
    indices = mapper(query, items)
    # assumes ``indices`` is JSON-serializable (e.g. a plain list of ints,
    # not a numpy array) — TODO confirm against Mapper.
    return JSONResponse(indices)


@app.websocket("/answer")
async def answer(ws: WebSocket):
    """Receive one prompt over the websocket and stream back the answer.

    Protocol: accept the connection, read one text message, reply ``"OK!"``,
    send each element produced by ``answerer`` as its own text message, then
    close. If the client disconnects (``WebSocketDisconnect``), the handler
    stops instead of raising; the socket is already closed in that case.
    """
    await ws.accept()
    print("ws accepted!")
    try:
        prompt = await ws.receive_text()
        print("input received!")
        await ws.send_text("OK!")
        # NOTE(review): ``answerer`` runs synchronously inside an async
        # handler. If it (or iterating its output) is slow, it blocks the
        # event loop for all other clients. Consider offloading it to a worker
        # thread once Answerer's thread-safety is confirmed. The meaning of
        # ``32`` is defined by Answerer (presumably a length limit) — TODO
        # confirm.
        output = answerer(prompt, 32)
        print("output created!")
        for el in output:
            print(f"sent: '{el}'")
            await ws.send_text(el)
    except WebSocketDisconnect:
        # Client went away; closing again would fail, so just stop.
        print("ws disconnected!")
    else:
        await ws.close()