File size: 3,277 Bytes
27bbfe3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import html
import itertools
import re

import google.generativeai as genai
from fastapi import FastAPI, Request, Form
from fastapi.responses import RedirectResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
# from fastapi.middleware.cors import CORSMiddleware

from helpmate_ai import initialize_conversation, retreive_results, rerank_with_cross_encoder, generate_response

# Configure Gemini API. The key is read from gemini_api_key.txt in the current
# working directory; the `with` block guarantees the file handle is closed.
with open("gemini_api_key.txt", "r", encoding="utf-8") as key_file:
    gemini_api_key = key_file.read().strip()
genai.configure(api_key=gemini_api_key)

# The FastAPI application object that all routes below are registered on.
app = FastAPI()

# Jinja2 templates are looked up in the ./templates directory.
templates = Jinja2Templates(directory="templates")

# Everything under ./static is served at the /static URL prefix.
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)

# CORS is currently disabled. To enable it, uncomment the CORSMiddleware
# import at the top of the file and the registration below.
# NOTE(review): pairing allow_origins=["*"] with allow_credentials=True is
# overly permissive; prefer an explicit list of trusted origins.
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["*"],  # Adjust origins as per requirements
#     allow_credentials=True,
#     allow_methods=["*"],
#     allow_headers=["*"],
# )

# Markdown-ish patterns recognised by format_rag_response.
_BOLD_RE = re.compile(r"\*\*(.+?)\*\*")
_NUMBERED_ITEM_RE = re.compile(r"^\s*(\d+\.\s)")
_BULLET_ITEM_RE = re.compile(r"^\s*[-*]\s+")
_CITATION_RE = re.compile(r"Citations?:\s")
_TABLE_SEPARATOR_CELL_RE = re.compile(r"^:?-+:?$")


def _emphasize_citation(match):
    """Wrap a "Citation(s):" label in <em>, breaking the line if text precedes it."""
    prefix = "<br>" if match.string[:match.start()].strip() else ""
    return f"{prefix}<em>{match.group(0)}</em>"


def _format_inline(text):
    """HTML-escape *text*, then render **bold** spans and citation labels."""
    # Escape first so markup produced by the model cannot inject tags.
    text = html.escape(text, quote=False)
    # Paired ** become <strong>; any unpaired ** left over is removed.
    text = _BOLD_RE.sub(r"<strong>\1</strong>", text).replace("**", "")
    return _CITATION_RE.sub(_emphasize_citation, text)


def _format_line(line):
    """Render one non-table line: a leading list marker, then inline formatting."""
    bullet = _BULLET_ITEM_RE.match(line)
    if bullet:
        return "&bull; " + _format_inline(line[bullet.end():])
    numbered = _NUMBERED_ITEM_RE.match(line)
    if numbered:
        # group(1) is only digits, a dot and whitespace, so it needs no escaping.
        return f"<strong>{numbered.group(1)}</strong>" + _format_inline(line[numbered.end():])
    return _format_inline(line)


def _split_cells(row):
    """Split a Markdown table row such as ``| a | b |`` into stripped cell texts."""
    return [cell.strip() for cell in row.strip().strip("|").split("|")]


def _is_separator(cells):
    """Return True for a header separator row such as ``|---|:--:|``."""
    return all(_TABLE_SEPARATOR_CELL_RE.match(cell) for cell in cells)


def _render_table(rows):
    """Render consecutive Markdown table rows as a single HTML <table>.

    A row directly followed by a separator row is emitted with <th> cells;
    separator rows themselves are dropped.
    """
    parsed = [_split_cells(row) for row in rows]
    html_rows = []
    for index, cells in enumerate(parsed):
        if _is_separator(cells):
            continue
        is_header = index + 1 < len(parsed) and _is_separator(parsed[index + 1])
        tag = "th" if is_header else "td"
        html_rows.append(
            "<tr>" + "".join(f"<{tag}>{_format_inline(cell)}</{tag}>" for cell in cells) + "</tr>"
        )
    return "<table>" + "".join(html_rows) + "</table>"


def format_rag_response(response_text):
    """Convert a Markdown-style model reply into an HTML fragment for the chat page.

    The text is processed line by line:

    - Lines starting with ``- `` or ``* `` become ``&bull;`` bullets.
    - Lines starting with ``N. `` get a bold number.
    - ``**bold**`` spans become ``<strong>``.
    - "Citation:"/"Citations:" labels become ``<em>``.
    - Runs of lines starting with ``|`` become one ``<table>``.
    - Remaining lines are joined with ``<br>``.

    List markers are only recognised at the start of a line, so text such as
    "1,299. " or "a - b" in the middle of a sentence is left alone.

    The result contains HTML tags and is meant to be rendered unescaped, so
    all model text is HTML-escaped before any tags are added.

    Args:
        response_text: Raw text returned by the model. ``None`` or ``""``
            yields ``""``.

    Returns:
        The HTML fragment as a string.
    """
    if not response_text:
        return ""
    lines = response_text.replace("\r\n", "\n").split("\n")
    chunks = []
    # groupby splits the lines into alternating runs of table / non-table lines.
    for is_table, group in itertools.groupby(lines, key=lambda line: line.lstrip().startswith("|")):
        if is_table:
            chunks.append(_render_table(list(group)))
        else:
            chunks.extend(_format_line(line) for line in group)
    return "<br>".join(chunks)

# Chat transcript handed to the index template; entries are {'user': ...}
# or {'bot': ...} dicts appended by the route handlers below.
conversation_bot = []
# Initial conversation from helpmate_ai (presumably the system prompt —
# confirm in helpmate_ai). It doubles as the model's system instruction.
conversation = initialize_conversation()

# Gemini model used for every completion in this module.
model = genai.GenerativeModel(
    "gemini-1.5-flash",
    system_instruction=conversation,
)

def get_gemini_completions(conversation):
    """Send *conversation* to the module-level Gemini ``model`` and return the reply's text."""
    return model.generate_content(conversation).text

# Seed the transcript with the bot's opening message. It is passed through
# format_rag_response, exactly like every reply stored by /invite, so all
# 'bot' entries reach the template in the same HTML form.
introduction = get_gemini_completions(conversation)
conversation_bot.append({'bot': format_rag_response(introduction)})
# Reset by /end_conv; not read anywhere in this file.
top_3_laptops = None

@app.get("/")
async def default_func(request: Request):
    """Render the chat page with the current transcript.

    The template receives ``conversation_bot`` under the key ``name_xyz``.
    """
    context = {"request": request, "name_xyz": conversation_bot}
    return templates.TemplateResponse("index_invite.html", context)

@app.post("/end_conv")
async def end_conv():
    """Start a new chat: reset the transcript and seed it with a fresh introduction.

    Re-creates ``conversation`` via ``initialize_conversation()`` and asks
    Gemini for an opening message. The message is formatted the same way as
    /invite replies. Responds with a 303 redirect to "/".
    """
    global conversation_bot, conversation, top_3_laptops
    # NOTE(review): `model` keeps the system_instruction it was built with at
    # import time; re-initializing `conversation` here does not change it.
    conversation = initialize_conversation()
    introduction = get_gemini_completions(conversation)
    conversation_bot = [{'bot': format_rag_response(introduction)}]
    top_3_laptops = None
    # 303 See Other makes the browser follow the POST with a GET to "/".
    return RedirectResponse(url="/", status_code=303)

@app.post("/invite")
async def invite(user_input_message: str = Form(...)):
    """Answer one user message with the RAG pipeline, then redirect to "/".

    The pipeline retrieves candidate documents for the message, reranks them
    with the cross-encoder, builds the prompt with ``generate_response`` and
    asks Gemini for an answer. Both the user message and the formatted answer
    are appended to ``conversation_bot``.

    Blank or whitespace-only messages are ignored.
    """
    # conversation_bot is only mutated (append), never rebound, so no
    # `global` declaration is needed here.
    user_input = user_input_message.strip()
    if not user_input:
        # Nothing to answer: skip retrieval, reranking and the Gemini call.
        return RedirectResponse(url="/", status_code=303)

    # NOTE(review): the transcript is shared by every visitor. If the template
    # renders entries unescaped, raw user text here is an injection vector.
    conversation_bot.append({'user': user_input})

    results_df = retreive_results(user_input)
    top_docs = rerank_with_cross_encoder(user_input, results_df)

    messages = generate_response(user_input, top_docs)
    response_assistant = get_gemini_completions(messages)

    conversation_bot.append({'bot': format_rag_response(response_assistant)})
    return RedirectResponse(url="/", status_code=303)

# Run the application with uvicorn when executed as a script.
if __name__ == '__main__':
    import uvicorn
    # uvicorn.run() takes no `debug` keyword in current releases (it raises
    # TypeError); log_level="debug" keeps the intended verbose output.
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="debug")