from fastapi import FastAPI, Request, Form
from fastapi.responses import RedirectResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
# from fastapi.middleware.cors import CORSMiddleware
from helpmate_ai import initialize_conversation, retreive_results, rerank_with_cross_encoder, generate_response
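# The helpmate_ai helpers above implement the RAG flow used by this app:
# retreive_results() pulls candidate document chunks for a query,
# rerank_with_cross_encoder() re-orders them, and generate_response()
# builds the final prompt that is sent to Gemini.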
import re
import google.generativeai as genai

# Configure Gemini API (key is read from a local file)
with open("gemini_api_key.txt", "r") as key_file:
    gemini_api_key = key_file.read().strip()
genai.configure(api_key=gemini_api_key)
# Initialize FastAPI app
app = FastAPI()

# Set up templates
templates = Jinja2Templates(directory="templates")

# Serve static files (if needed)
app.mount("/static", StaticFiles(directory="static"), name="static")

# Enable CORS middleware if needed
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["*"],  # Adjust origins as per requirements
#     allow_credentials=True,
#     allow_methods=["*"],
#     allow_headers=["*"],
# )
def format_rag_response(response_text):
    """Convert the model's markdown-style answer into simple HTML for the template."""
    formatted_text = response_text.replace("\n", "<br>")
    # Bold: **text** -> <strong>text</strong>
    formatted_text = re.sub(r'\*\*(.*?)\*\*', r'<strong>\1</strong>', formatted_text)
    # Numbered list items start on a new line and are emphasised
    formatted_text = re.sub(r'(\d+\.\s)', r'<br><strong>\1</strong>', formatted_text)
    # Hyphen bullets become bullet points on their own line
    formatted_text = re.sub(r'\-\s', r'<br>• ', formatted_text)
    # Highlight citation lines
    formatted_text = re.sub(r'(Citations?:\s)', r'<br><em>\1</em>', formatted_text)
    # Pipe-delimited lines: a pipe at the start of a line opens a table row
    # (newlines were already turned into <br>), remaining pipes separate cells
    formatted_text = re.sub(r'<br>\|\s*', r'<tr><td>', formatted_text)
    formatted_text = re.sub(r'\|\s*', r'</td><td>', formatted_text)
    return formatted_text
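# Illustrative example (assumed input) of the formatting above:
#   format_rag_response("**Coverage:**\n- Term life")
#   -> "<strong>Coverage:</strong><br><br>• Term life"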
conversation_bot = []
conversation = initialize_conversation()

# Initialize the Gemini model, seeding it with the system prompt from initialize_conversation()
model = genai.GenerativeModel("gemini-1.5-flash", system_instruction=conversation)

def get_gemini_completions(conversation):
    """Send the prompt/conversation to Gemini and return the generated text."""
    response = model.generate_content(conversation)
    return response.text

introduction = get_gemini_completions(conversation)
conversation_bot.append({'bot': introduction})
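# Note: conversation state lives in module-level globals, so it is shared by
# every client of this app rather than kept per session.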
top_3_laptops = None

# Route path assumed from the redirects below; renders the chat page
@app.get("/")
async def default_func(request: Request):
    global conversation_bot
    return templates.TemplateResponse("index_invite.html", {"request": request, "name_xyz": conversation_bot})
# Route path assumed; resets the conversation and returns to the chat page
@app.post("/end_conv")
async def end_conv():
    global conversation_bot, conversation, top_3_laptops
    conversation_bot = []
    conversation = initialize_conversation()
    introduction = get_gemini_completions(conversation)
    conversation_bot.append({'bot': introduction})
    top_3_laptops = None
    return RedirectResponse(url="/", status_code=303)
# Route path assumed; handles a user message submitted from the chat form
@app.post("/invite")
async def invite(user_input_message: str = Form(...)):
    global conversation_bot, conversation, top_3_laptops
    user_input = user_input_message
    conversation_bot.append({'user': user_input})

    # RAG pipeline: retrieve candidate chunks, rerank with the cross-encoder,
    # then build the prompt and generate the response with Gemini
    results_df = retreive_results(user_input)
    top_docs = rerank_with_cross_encoder(user_input, results_df)
    messages = generate_response(user_input, top_docs)
    response_assistant = get_gemini_completions(messages)

    conversation_bot.append({'bot': format_rag_response(response_assistant)})
    return RedirectResponse(url="/", status_code=303)
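# Both form handlers above follow the POST/redirect/GET pattern: they update the
# global conversation state and then redirect back to "/", where default_func()
# re-renders the chat history.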
# Run the application
if __name__ == '__main__':
    import uvicorn
    # note: uvicorn.run() does not accept a Flask-style debug flag
    uvicorn.run(app, host="0.0.0.0", port=8000)