Spaces:
Running
Running
File size: 3,847 Bytes
5e8fd8b 1c8932d 9d5513a 553dd69 676f7fc 9d5513a 5e8fd8b 9d5513a 5e8fd8b 9f0a9ca 5e8fd8b ab4a5ae 676f7fc 9aab80e 608245a 9d5513a 9aab80e 9d5513a 9aab80e 9d5513a 6eb1c7e 553dd69 6eb1c7e 9d5513a 6eb1c7e 9d5513a 9aab80e 628c689 608245a 3f857b9 9aab80e 676f7fc 608245a 676f7fc 9d5513a 676f7fc 9d5513a 676f7fc 5e8fd8b 9d5513a 9aab80e 9d5513a 676f7fc 5e8fd8b 4c88907 d329125 9aab80e 4c88907 9d5513a 676f7fc 4c88907 5e8fd8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import os
import logging
import shutil
import time
import re
from fastapi import FastAPI, Request, UploadFile
from fastapi.middleware import Middleware
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse, StreamingResponse
from .rag import ChatPDF
# CORS is left fully open: any origin, method and header is accepted.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_methods": ['*'],
    "allow_headers": ['*'],
}
middleware = [Middleware(CORSMiddleware, **_CORS_OPTIONS)]
app = FastAPI(middleware=middleware)

# Staging directory for uploaded files; the /upload handler writes the
# uploads here, ingests the directory and then removes it.
files_dir = os.path.expanduser("~/wtp_be_files/")

# A single assistant instance shared by every request.
session_assistant = ChatPDF()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class Session:
    """Process-wide state backing the single-user session lock.

    Defaults live on the class; the module-level ``session`` instance shadows
    them with instance attributes as request handlers update it.
    """

    isBusy: bool = False  # True while an upload, query or clear response is being produced
    curUserID: str = ""  # id of the user currently holding the session
    prevUserID: str = ""  # id of the user most recently displaced from the session
    lastQueryTimestamp: float = 0  # time.time() of the current user's latest request


session = Session()
@app.middleware("http")
async def resolve_availability(request: Request, call_next):
    """Gate every HTTP request on the single-user session lock.

    The server serves one user at a time, identified by the ``id`` query
    parameter. Responses produced here, without reaching the route:

    * 503 "Server is busy" while an upload/query/clear response is still
      being produced (``session.isBusy``), or while another user holds the
      session and has been active within the last ``session_timeout`` seconds;
    * 400 "Bad request" when ``id`` is missing or empty;
    * 419 "Session has expired" when the caller is the user most recently
      displaced from the session.

    Otherwise the request is forwarded to ``call_next``. The caller's activity
    timestamp is refreshed first, and the session is taken over if the
    previous holder has been idle long enough.

    NOTE(review): this middleware appears to be registered outside the
    CORSMiddleware given to FastAPI(), so the early error responses above may
    lack CORS headers and be unreadable by browsers. Confirm.
    """
    session_timeout = 300  # seconds of inactivity before another user may take over
    if session.isBusy:
        return PlainTextResponse("Server is busy", status_code=503)

    user_id = request.query_params.get("id")
    # Reject an empty id as well as a missing one. Otherwise "" would match the
    # initial curUserID sentinel and be treated as the current user without
    # ever claiming the session.
    if not user_id:
        return PlainTextResponse("Bad request", status_code=400)

    now = time.time()
    if user_id == session.curUserID:
        session.lastQueryTimestamp = now
        return await call_next(request)
    if user_id == session.prevUserID:
        return PlainTextResponse("Session has expired", status_code=419)
    if now - session.lastQueryTimestamp >= session_timeout:
        # The current holder has been idle long enough: hand the session over.
        session.lastQueryTimestamp = now
        session.prevUserID = session.curUserID
        session.curUserID = user_id
        return await call_next(request)
    return PlainTextResponse("Server is busy", status_code=503)
def astreamer(generator):
    """Relay the chunks of *generator* with timing logs, then release the busy flag.

    Used as the body iterator of every StreamingResponse in this module.
    ``session.isBusy`` is cleared in a ``finally`` so it is released even when
    the wrapped generator raises or this generator is closed before being
    exhausted. Otherwise the server would keep answering "Server is busy"
    indefinitely.

    :param generator: any iterable of chunks to stream (e.g. a list of strings).
    :yields: each chunk of *generator*, unchanged.
    """
    t0 = time.time()
    try:
        for chunk in generator:
            logger.info("Chunk being yielded (time %dms) - %s", int((time.time() - t0) * 1000), chunk)
            yield chunk
        logger.info("X-Process-Time: %dms", int((time.time() - t0) * 1000))
    finally:
        session.isBusy = False
@app.get("/query")
async def process_input(text: str):
    """Stream a reply to the user's query *text*.

    An empty or whitespace-only query gets a fixed prompt asking for a query.
    If no PDF has been ingested yet (``pdf_count`` is 0), a fixed prompt asks
    for a document. Otherwise the stripped query goes to
    ``session_assistant.ask`` and its ``response_gen`` is streamed.

    ``session.isBusy`` is set for the lifetime of the response; ``astreamer``
    clears it when streaming finishes. If building the response fails, the
    flag is cleared here and the error re-raised, so the server is not left
    permanently busy.
    """
    session.isBusy = True
    try:
        if not text or not text.strip():
            message = "Your query is empty. Please provide a query for me to process."
            generator = re.split(r'(\s)', message)
        elif session_assistant.pdf_count > 0:
            # NOTE(review): ask() is called synchronously inside an async
            # endpoint. If it does heavy work before returning, it blocks the
            # event loop. Confirm.
            streaming_response = session_assistant.ask(text.strip())
            generator = streaming_response.response_gen
        else:
            message = "Please provide the PDF document you'd like to add."
            # Split on whitespace but keep the separators so the message
            # streams word by word and reassembles exactly.
            generator = re.split(r'(\s)', message)
    except Exception:
        session.isBusy = False
        raise
    return StreamingResponse(astreamer(generator), media_type='text/event-stream')
@app.post("/upload")
def upload(files: list[UploadFile]):
    """Save the uploaded files to ``files_dir``, ingest them, and stream a confirmation.

    Each upload starts from an empty staging directory, which is removed
    afterwards whether or not ingestion succeeded. Only the base name of each
    client-supplied filename is used, so names like ``../../x`` cannot escape
    ``files_dir``. Files with no usable name are skipped with a warning.
    Ingestion runs only after every file has been saved.

    ``session.isBusy`` is set for the lifetime of the response; ``astreamer``
    clears it when streaming finishes. On any error it is cleared here and the
    exception re-raised, so the server is not left permanently busy.
    """
    session.isBusy = True
    try:
        # Drop anything left behind by an earlier failed run so stale files
        # are never ingested, then recreate the directory. makedirs raises if
        # the removal did not succeed.
        shutil.rmtree(files_dir, ignore_errors=True)
        os.makedirs(files_dir)
        try:
            for file in files:
                try:
                    name = os.path.basename(file.filename or "")
                    if name in ("", ".", ".."):
                        logger.warning("Skipping uploaded file with unusable name %r", file.filename)
                        continue
                    file.file.seek(0)
                    with open(os.path.join(files_dir, name), 'wb') as destination:
                        shutil.copyfileobj(file.file, destination)
                finally:
                    file.file.close()
            session_assistant.ingest(files_dir)
        finally:
            # Best-effort cleanup. A leftover directory is cleared at the
            # start of the next upload anyway.
            shutil.rmtree(files_dir, ignore_errors=True)
    except Exception:
        session.isBusy = False
        raise
    message = "All files have been added successfully to your account. Your first query may take a little longer as the system indexes your documents. Please be patient while we process your request."
    # Split on whitespace but keep the separators so the message streams word
    # by word and reassembles exactly.
    generator = re.split(r'(\s)', message)
    return StreamingResponse(astreamer(generator), media_type='text/event-stream')
@app.get("/clear")
def clear():
    """Clear the assistant's ingested documents and stream a confirmation.

    ``session.isBusy`` is set for the lifetime of the response; ``astreamer``
    clears it when streaming finishes. If ``session_assistant.clear()`` raises,
    the flag is cleared here and the error re-raised, so the server is not
    left permanently busy.
    """
    session.isBusy = True
    try:
        session_assistant.clear()
    except Exception:
        session.isBusy = False
        raise
    message = "Your files have been cleared successfully."
    # Split on whitespace but keep the separators so the message streams word
    # by word and reassembles exactly.
    generator = re.split(r'(\s)', message)
    return StreamingResponse(astreamer(generator), media_type='text/event-stream')
@app.get("/")
def ping():
    """Health-check endpoint that always answers with a fixed string.

    NOTE(review): the ``resolve_availability`` HTTP middleware also applies
    here, so callers must pass an ``id`` query parameter or get a 400.
    """
    reply = "Pong!"
    return reply
|