"""Whisper speech-to-text web app (Starlette).

Serves an upload form at ``/`` and accepts audio uploads at ``/upload``.
All transcription requests are funneled through a single ``asyncio.Queue``
consumed by one background task, so exactly one Whisper pipeline instance
handles every inference request sequentially.
"""

from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.staticfiles import StaticFiles
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.templating import Jinja2Templates
from starlette.routing import Route
from starlette.responses import RedirectResponse
import uvicorn
from transformers import pipeline
from pydub import AudioSegment
import torch
import asyncio

# Pin inference to GPU index 2 when CUDA is available, otherwise fall back to CPU.
device = "cuda:2" if torch.cuda.is_available() else "cpu"

templates = Jinja2Templates(directory="templates")


async def homepage(request):
    """Render the upload form."""
    return templates.TemplateResponse("index.html", {"request": request})


async def upload_file(request):
    """Accept an audio upload, queue it for transcription, render the result.

    The raw bytes plus a private response queue are put on the shared
    ``app.model_queue``; the background ``server_loop`` task answers on the
    response queue once the pipeline has produced a transcript.
    """
    formdata = await request.form()
    file = formdata["file"]
    audio_bytes = await file.read()  # renamed from `input` — don't shadow the builtin
    response_q = asyncio.Queue()
    await request.app.model_queue.put((audio_bytes, response_q))
    output = await response_q.get()
    return templates.TemplateResponse(
        "index.html",
        {"request": request, "content": output['text']},
    )


async def server_loop(q):
    """Single consumer task: load the Whisper pipeline once, then serve jobs.

    Loading happens inside the task so the heavyweight model is constructed
    only after startup, and only once for the process.
    """
    pipe = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-large",
        chunk_length_s=30,
        device=device,
    )
    # Force transcription output to Chinese regardless of detected language.
    pipe.model.config.forced_decoder_ids = (
        pipe.tokenizer.get_decoder_prompt_ids(
            language="zh", task="transcribe"
        )
    )
    while True:
        (audio_bytes, response_q) = await q.get()
        out = pipe(audio_bytes)
        await response_q.put(out)


# BUG FIX: the original code built `app = Starlette()`, attached the static
# mount and CORS middleware to it, and then REBOUND `app` to a brand-new
# `Starlette(routes=...)` — silently discarding both the `/static` mount and
# the CORS middleware. Create the app once with its routes, then attach.
app = Starlette(
    routes=[
        Route("/", homepage, methods=["GET"]),
        Route("/upload", upload_file, methods=["POST"]),
    ],
)
app.mount("/static", StaticFiles(directory="static"), name="static")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_headers=["X-Requested-With", "Content-Type"],
)


@app.on_event("startup")
async def startup_event():
    """Create the shared job queue and start the model-serving task."""
    q = asyncio.Queue()
    app.model_queue = q
    # Keep a strong reference to the task: the event loop holds only weak
    # references, and an unreferenced task can be garbage-collected mid-run.
    app.model_task = asyncio.create_task(server_loop(q))