Arbazkhan-cs's picture
Update app.py
18e3524 verified
raw
history blame contribute delete
No virus
5.18 kB
import os
from flask import Flask, request, render_template, redirect, url_for, session, flash
from werkzeug.utils import secure_filename
from utils import create_retriever_tool_agent, create_arxiv_tool_agent, google_search, get_prompt
from langchain_groq import ChatGroq
from langchain.agents import create_tool_calling_agent, AgentExecutor
from dotenv import load_dotenv
import textwrap
from datetime import timedelta
# Load environment variables from a .env file before anything reads os.environ.
load_dotenv()

app = Flask(__name__, template_folder="./templates")
app.config['UPLOAD_FOLDER'] = "./uploads"
app.config['ALLOWED_EXTENSIONS'] = {'pdf'}
# Secret key for signing the session cookie. Prefer a stable key from the
# environment (or .env) so sessions survive restarts and are accepted by every
# worker process; otherwise fall back to a random per-process key.
app.secret_key = os.environ.get("FLASK_SECRET_KEY") or os.urandom(24)
# Configure the permanent-session lifetime once at startup.
app.permanent_session_lifetime = timedelta(minutes=30)
def allowed_file(filename):
    """Report whether *filename* ends in an extension listed in ALLOWED_EXTENSIONS.

    The text after the last dot is compared case-insensitively; a name with
    no dot at all is rejected.
    """
    _, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in app.config['ALLOWED_EXTENSIONS']
def create_agent(api_key):
    """Build a tool-calling agent executor on a Groq-hosted chat model.

    The agent is given three tools: a retriever built from the path stored in
    ``app.config['file_path']``, an arXiv tool, and ``google_search``.

    Args:
        api_key: API key passed to ``ChatGroq``.

    Returns:
        An ``AgentExecutor`` wrapping the agent and its tools, with verbose
        logging enabled.
    """
    toolset = [
        create_retriever_tool_agent(app.config['file_path']),
        create_arxiv_tool_agent(),
        google_search,
    ]
    chat_prompt = get_prompt()
    llm = ChatGroq(
        model="llama3-8b-8192",
        api_key=api_key,
        temperature=0.5,
        max_tokens=512,
    )
    tool_agent = create_tool_calling_agent(llm=llm, tools=toolset, prompt=chat_prompt)
    return AgentExecutor(agent=tool_agent, tools=toolset, verbose=True)
def wrap_text_preserve_code(text, max_line_length=120):
    """Wrap prose lines to a maximum width while leaving fenced code untouched.

    The text is split on triple-backtick fences: even-indexed segments are
    prose, odd-indexed segments are code. Each prose line is wrapped on its
    own with ``textwrap.wrap``, and blank lines are kept so paragraph breaks
    survive. Code segments are re-emitted verbatim inside their fences. If the
    number of fences is odd, the trailing segment is treated as code and gets
    a closing fence.

    Args:
        text: Text to wrap, e.g. the agent's answer.
        max_line_length: Maximum width of a wrapped prose line.

    Returns:
        The wrapped text, with all pieces joined by newlines.
    """
    pieces = []
    for index, segment in enumerate(text.split("```")):
        if index % 2:
            pieces.append(f"```{segment}```")
            continue
        lines = segment.splitlines()
        # The newline that ends a closing fence is re-added by the final join;
        # drop the empty line it would otherwise contribute.
        if index > 0 and lines and lines[0] == "":
            lines = lines[1:]
        for line in lines:
            # textwrap.wrap("") returns [], which would silently delete blank
            # lines; keep them as empty strings instead.
            pieces.extend(textwrap.wrap(line, width=max_line_length) or [""])
    return "\n".join(pieces)
@app.route('/', methods=['GET', 'POST'])
def index():
    """Entry point: send users without a stored API key to the key form, others to the main UI."""
    target = 'index2' if 'api_key' in session else 'get_api_key'
    return redirect(url_for(target))
@app.route('/get_api_key', methods=['GET', 'POST'])
def get_api_key():
    """Collect the user's API key and store it in the session.

    GET renders the key form (``index.html``). POST stores the submitted key
    (surrounding whitespace stripped) in the session and redirects to the main
    interface. A missing or blank key re-renders the form instead of being saved.
    """
    if request.method != 'POST':
        return render_template('index.html')
    api_key = request.form.get('api_key', '').strip()
    if not api_key:
        # Don't store an empty key; the agent is built from whatever is stored.
        return render_template('index.html')
    session['api_key'] = api_key
    # The 30-minute lifetime is applied app-wide by before_request(); only the
    # per-session flag is set here.
    session.permanent = True
    return redirect(url_for('index2'))
def _ensure_agent(api_key):
    """Create the cached agent, or rebuild it if it was built with a different API key."""
    # NOTE(review): the agent is stored in app.config, so it is shared by every
    # session this process serves. Rebuilding whenever the session's key differs
    # from the one the agent was built with stops requests from running under a
    # stale key (e.g. after re-entering a corrected key) or another user's key,
    # at the cost of a rebuild each time the active key changes.
    if 'agent_executor' in app.config and app.config.get('agent_api_key') == api_key:
        return
    app.config['file_path'] = "./Pdfs"
    app.config['agent_executor'] = create_agent(api_key)
    app.config['agent_api_key'] = api_key


def _save_upload(file, api_key):
    """Save an uploaded PDF and rebuild the agent over the upload folder."""
    # request.files.get() returns None when the POST has no file field
    # (e.g. a query-only submission), so guard before touching .filename.
    if file is None or not file.filename or not allowed_file(file.filename):
        return
    upload_folder = app.config['UPLOAD_FOLDER']
    # The folder is otherwise only created in the __main__ block; make sure it
    # exists when the app is served some other way.
    os.makedirs(upload_folder, exist_ok=True)
    file_path = os.path.join(upload_folder, secure_filename(file.filename))
    file.save(file_path)
    # The retriever is pointed at the whole upload folder, not the single file.
    app.config['file_path'] = upload_folder
    # Remembered so logout() / the teardown cleanup can delete the file.
    session['file_path'] = file_path
    # Rebuild the agent so its retriever sees the new PDF.
    app.config['agent_executor'] = create_agent(api_key)
    app.config['agent_api_key'] = api_key


def _run_agent(query):
    """Send *query* to the cached agent and return its output text, or the error text."""
    agent_executor = app.config['agent_executor']
    try:
        return agent_executor.invoke({"input": query})["output"]
    except Exception as e:
        # Request boundary: show the error text to the user as before, but keep
        # the traceback in the server log.
        app.logger.exception("Agent invocation failed")
        return str(e)


@app.route('/index2', methods=['GET', 'POST'])
def index2():
    """Main interface: optionally upload a PDF and ask the agent a question.

    Redirects to the API-key form when the session has no key. On POST, an
    uploaded allowed file is saved to UPLOAD_FOLDER and the agent is rebuilt
    over that folder. A non-empty ``query`` field is sent to the agent, and the
    wrapped answer is rendered into ``index2.html``.
    """
    if 'api_key' not in session:
        return redirect(url_for('get_api_key'))
    api_key = session['api_key']
    _ensure_agent(api_key)
    if request.method == 'POST':
        _save_upload(request.files.get('file'), api_key)
        query = request.form.get('query', '').strip()
        if query:
            result = _run_agent(query)
            wrapped_result = wrap_text_preserve_code(result, max_line_length=105)
            return render_template('index2.html', result=wrapped_result)
    return render_template('index2.html')
@app.route('/logout')
def logout():
    """Clear the session and delete the uploaded file it recorded, if any."""
    file_path = session.pop('file_path', None)
    if file_path:
        try:
            os.remove(file_path)
        except FileNotFoundError:
            # Already gone, e.g. removed by the teardown cleanup of an earlier
            # or concurrent request; an exists()-then-remove() check would race.
            pass
    session.clear()
    return redirect(url_for('index'))
@app.before_request
def before_request():
    """Before each request, apply the 30-minute session lifetime and mark the session permanent."""
    app.permanent_session_lifetime = timedelta(minutes=30)  # adjust the lifetime as needed
    session.permanent = True
@app.teardown_request
def cleanup(exception=None):
    """Delete the uploaded file recorded in the session when a request ends.

    Args:
        exception: Unhandled exception from the request, if any (unused).
    """
    # NOTE(review): in Flask's request lifecycle the session cookie is saved
    # before teardown handlers run, so this pop() is not persisted. The path
    # stays in the cookie, and later requests just find the file already gone.
    # This also deletes an upload at the end of the very request that saved it,
    # which makes the deletion in logout() mostly redundant — confirm intended.
    file_path = session.pop('file_path', None)
    if not file_path:
        return
    try:
        os.remove(file_path)
    except FileNotFoundError:
        # A concurrent request or logout() may have removed it first; don't let
        # that race raise out of the teardown handler.
        pass
if __name__ == '__main__':
    # Create the upload folder if needed; exist_ok avoids the race between a
    # separate os.path.exists() check and os.makedirs().
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
    print("Flask app running...")
    app.run(debug=False)