|
import os |
|
from flask import Flask, request, render_template, redirect, url_for, session, flash |
|
from werkzeug.utils import secure_filename |
|
from utils import create_retriever_tool_agent, create_arxiv_tool_agent, google_search, get_prompt |
|
from langchain_groq import ChatGroq |
|
from langchain.agents import create_tool_calling_agent, AgentExecutor |
|
from dotenv import load_dotenv |
|
import textwrap |
|
from datetime import timedelta |
|
|
|
|
|
load_dotenv()

app = Flask(__name__, template_folder="./templates")

# Destination directory for user-uploaded documents.
app.config['UPLOAD_FOLDER'] = "./uploads"

# Lower-case file extensions accepted by allowed_file().
app.config['ALLOWED_EXTENSIONS'] = {'pdf'}

# A key generated per process invalidates every session cookie on restart and
# makes sessions fail when several worker processes serve the app, so prefer a
# stable key from the environment (load_dotenv() above also reads .env).
# Falls back to a random key, which matches the previous behavior.
app.secret_key = os.environ.get("FLASK_SECRET_KEY") or os.urandom(24)
|
|
|
def allowed_file(filename):
    """Tell whether *filename* carries an extension listed in ALLOWED_EXTENSIONS.

    Only the text after the last dot is considered, compared case-insensitively.
    A name that contains no dot at all is rejected.
    """
    _stem, dot, extension = filename.rpartition('.')
    if not dot:
        return False
    return extension.lower() in app.config['ALLOWED_EXTENSIONS']
|
|
|
|
|
def create_agent(api_key):
    """Build a tool-calling AgentExecutor that talks to Groq with *api_key*.

    The agent is given three tools: the retriever tool built from the path in
    ``app.config['file_path']``, the arXiv tool, and ``google_search``.
    The executor runs with ``verbose=True``.
    """
    tools = [
        create_retriever_tool_agent(app.config['file_path']),
        create_arxiv_tool_agent(),
        google_search,
    ]
    prompt = get_prompt()
    llm = ChatGroq(
        model="llama3-8b-8192",
        api_key=api_key,
        temperature=0.5,
        max_tokens=512,
    )
    agent = create_tool_calling_agent(llm=llm, tools=tools, prompt=prompt)
    return AgentExecutor(agent=agent, tools=tools, verbose=True)
|
|
|
|
|
def _wrap_plain_text(segment, max_line_length):
    """Hard-wrap the over-long lines of a non-code segment, keeping its line breaks.

    Lines no longer than *max_line_length* are kept verbatim, including blank
    lines and surrounding spaces. Longer lines are split with
    ``textwrap.wrap``. Each line's original ending (``\\n``, ``\\r\\n``, or none)
    is preserved, so the segment can be concatenated back next to a code fence
    without gaining or losing line breaks.
    """
    pieces = []
    for raw_line in segment.splitlines(keepends=True):
        # splitlines(keepends=True) never yields "", so [0] is always present.
        line = raw_line.splitlines()[0]
        ending = raw_line[len(line):]
        if len(line) > max_line_length:
            line = "\n".join(textwrap.wrap(line, width=max_line_length))
        pieces.append(line + ending)
    return "".join(pieces)


def wrap_text_preserve_code(text, max_line_length=120):
    """Wrap long prose lines in *text* while leaving ``` code blocks untouched.

    The text is split on triple backticks. Even-indexed segments are prose and
    have their over-long lines wrapped to *max_line_length*. Odd-indexed
    segments are code and are re-emitted verbatim inside fences. An
    unterminated opening fence therefore gets a closing fence appended.
    Segments are concatenated directly, so blank lines (paragraph breaks) and
    inline fences are preserved exactly as written.

    Args:
        text: The string to format (e.g. an LLM answer).
        max_line_length: Maximum width for prose lines.

    Returns:
        The formatted string.
    """
    segments = text.split("```")
    out = []
    for index, segment in enumerate(segments):
        if index % 2:
            out.append(f"```{segment}```")
        else:
            out.append(_wrap_plain_text(segment, max_line_length))
    return "".join(out)
|
|
|
|
|
@app.route('/', methods=['GET', 'POST'])
def index():
    """Landing route that only redirects.

    Sessions that already hold an API key go to the chat page (``index2``);
    all others go to the key-entry form (``get_api_key``).
    """
    target = 'index2' if 'api_key' in session else 'get_api_key'
    return redirect(url_for(target))
|
|
|
@app.route('/get_api_key', methods=['GET', 'POST'])
def get_api_key():
    """Show the API-key form (GET) or store the submitted key (POST).

    On a POST with a non-empty key, the key is saved in a permanent session
    with a 30-minute lifetime and the user is redirected to ``index2``. An empty
    or missing key re-renders the form instead of being stored.
    """
    if request.method == 'POST':
        # .get() avoids a 400 on a missing field; surrounding whitespace is a
        # common copy/paste artefact that would make the key invalid.
        api_key = request.form.get('api_key', '').strip()
        if not api_key:
            # Storing an empty key would only fail later, inside the agent.
            return render_template('index.html')

        session['api_key'] = api_key
        session.permanent = True
        app.permanent_session_lifetime = timedelta(minutes=30)
        return redirect(url_for('index2'))

    return render_template('index.html')
|
|
|
def _agent_executor_for(api_key):
    """Return the cached agent executor, (re)building it for *api_key* if needed.

    The executor is stored in ``app.config`` and is therefore shared by every
    session served by this process. It is rebuilt whenever the requesting
    session's key differs from the key it was built with. This ensures a
    request never runs on another user's (or a stale) API key.
    """
    if 'agent_executor' not in app.config:
        # First build: use the bundled document directory.
        app.config['file_path'] = "./Pdfs"
    elif app.config.get('agent_api_key') == api_key:
        return app.config['agent_executor']
    app.config['agent_executor'] = create_agent(api_key)
    app.config['agent_api_key'] = api_key
    return app.config['agent_executor']


@app.route('/index2', methods=['GET', 'POST'])
def index2():
    """Main chat route: accept an optional PDF upload and answer a query.

    Redirects to the API-key form when the session has no key. On POST:
      * a valid uploaded file is saved to UPLOAD_FOLDER, and the agent is
        rebuilt so its retriever is built from that folder;
      * a non-blank ``query`` form field is sent to the agent, and the answer
        (or the error text) is rendered, wrapped to 105 columns.
    """
    if 'api_key' not in session:
        return redirect(url_for('get_api_key'))

    api_key = session['api_key']
    agent_executor = _agent_executor_for(api_key)

    if request.method == 'POST':
        # request.files.get() returns None when the form carried no file part
        # (e.g. a query-only form); dereferencing it would raise.
        file = request.files.get('file')
        if file is not None and file.filename and allowed_file(file.filename):
            filename = secure_filename(file.filename)
            # secure_filename() can reduce a name to "", which would point
            # file.save() at the directory itself.
            if filename:
                upload_dir = app.config['UPLOAD_FOLDER']
                # The folder is otherwise only created by the __main__ block,
                # not when the app is served by `flask run` or a WSGI server.
                os.makedirs(upload_dir, exist_ok=True)
                file_path = os.path.join(upload_dir, filename)
                file.save(file_path)

                app.config['file_path'] = upload_dir
                session['file_path'] = file_path

                # Rebuild so the retriever tool is created from the upload folder.
                agent_executor = create_agent(api_key)
                app.config['agent_executor'] = agent_executor
                app.config['agent_api_key'] = api_key

        query = request.form.get('query', '')
        if query.strip():
            try:
                result = agent_executor.invoke({"input": query})["output"]
            except Exception as e:
                # Best-effort: show the error to the user, but keep a trace too.
                app.logger.exception("Agent invocation failed")
                result = str(e)
            wrapped_result = wrap_text_preserve_code(result, max_line_length=105)
            return render_template('index2.html', result=wrapped_result)

    return render_template('index2.html')
|
|
|
@app.route('/logout')
def logout():
    """Delete the file recorded for this session (if any), clear the session, go home."""
    file_path = session.pop('file_path', None)
    if file_path:
        # EAFP instead of exists()+remove(): the teardown handler or a
        # concurrent request may delete the file in between, and a missing
        # file is exactly the outcome we want.
        try:
            os.remove(file_path)
        except FileNotFoundError:
            pass
    session.clear()
    return redirect(url_for('index'))
|
|
|
@app.before_request
def before_request():
    """Before every request, mark the session permanent with a 30-minute lifetime."""
    lifetime = timedelta(minutes=30)
    session.permanent = True
    app.permanent_session_lifetime = lifetime
|
|
|
@app.teardown_request
def cleanup(exception=None):
    """After each request, delete the file recorded in ``session['file_path']``.

    NOTE(review): this runs after *every* request, including the upload
    request that saved the file. The uploaded PDF therefore only survives
    until that request finishes, which relies on create_agent() having
    consumed it by then; confirm against create_retriever_tool_agent.
    NOTE(review): Flask likely serializes the session before teardown
    handlers run, so this pop is probably not written back to the cookie.

    Teardown handlers should not raise, so deletion failures are logged rather
    than propagated.
    """
    file_path = session.pop('file_path', None)
    if not file_path:
        return
    try:
        os.remove(file_path)
    except FileNotFoundError:
        # Already gone (e.g. removed by logout or a concurrent request).
        pass
    except OSError:
        app.logger.warning("Could not delete uploaded file %s", file_path, exc_info=True)
|
|
|
if __name__ == '__main__':
    # exist_ok=True replaces the racy os.path.exists() + makedirs() pair.
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
    print("Flask app running...")
    app.run(debug=False)
|
|