Arbazkhan-cs's picture
Update app.py
08bf29a verified
raw
history blame
No virus
2.95 kB
import os
from flask import Flask, request, render_template, redirect, url_for
from werkzeug.utils import secure_filename
from utils import create_retriever_tool_agent, create_arxiv_tool_agent, google_search, get_prompt
from langchain_groq import ChatGroq
from langchain.agents import create_tool_calling_agent, AgentExecutor
from dotenv import load_dotenv
# Load variables from a local .env file into the environment before any client
# is constructed (presumably this supplies the Groq API key ChatGroq needs).
load_dotenv()

app = Flask(__name__, template_folder="./templates")
app.config.update(
    UPLOAD_FOLDER="./uploads",
    ALLOWED_EXTENSIONS={'pdf'},
)

# Initialize tools and agent.
# The PDF retriever starts out indexing the bundled ./Pdfs directory; index()
# rebuilds it (and everything derived from it) after a successful upload.
app.config.update(
    pdf_retriever_tool=create_retriever_tool_agent("./Pdfs"),
    arxiv_tool=create_arxiv_tool_agent(),
)
app.config.update(
    tools=[app.config["pdf_retriever_tool"], app.config["arxiv_tool"], google_search],
    prompt=get_prompt(),
)
app.config["agent"] = create_tool_calling_agent(
    llm=ChatGroq(model="llama3-8b-8192", temperature=0.5),
    tools=app.config["tools"],
    prompt=app.config["prompt"],
)
app.config["agent_executor"] = AgentExecutor(
    agent=app.config["agent"],
    tools=app.config["tools"],
    verbose=True,
)
def allowed_file(filename):
    """Return True if *filename* carries an extension listed in ALLOWED_EXTENSIONS.

    Only the text after the last dot is considered, compared case-insensitively,
    so "report.PDF" is accepted while a bare "pdf" (no dot) is rejected.
    """
    _stem, separator, extension = filename.rpartition('.')
    if not separator:
        # No dot anywhere in the name -> no extension to check.
        return False
    return extension.lower() in app.config['ALLOWED_EXTENSIONS']
def _rebuild_agent_executor(pdf_dir):
    """Rebuild the PDF retriever tool, tool list, agent and executor over *pdf_dir*.

    Every rebuilt object is stored back into ``app.config`` under the same keys
    used at startup; the arXiv tool, google_search and the prompt are reused.
    """
    app.config["pdf_retriever_tool"] = create_retriever_tool_agent(pdf_dir)
    app.config["tools"] = [app.config["pdf_retriever_tool"], app.config["arxiv_tool"], google_search]
    app.config["agent"] = create_tool_calling_agent(
        llm=ChatGroq(model="llama3-8b-8192", temperature=0.5),
        tools=app.config["tools"],
        prompt=app.config["prompt"],
    )
    app.config["agent_executor"] = AgentExecutor(
        agent=app.config["agent"], tools=app.config["tools"], verbose=True
    )


def _hard_wrap(text, width=100):
    """Split *text* into *width*-character chunks joined by newlines.

    Chunks are cut at fixed offsets, without regard to word boundaries;
    an empty string yields an empty string.
    """
    return "\n".join(text[i:i + width] for i in range(0, len(text), width))


@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the main page; on POST, optionally ingest a PDF and answer a query.

    POST form fields (both optional):
        file:  a PDF upload. When present and allowed, it is saved into
               UPLOAD_FOLDER and the retriever/agent are rebuilt over that folder.
        query: a question for the agent executor; a non-blank query renders the
               agent's output hard-wrapped at 100 characters.

    Returns:
        The rendered 'index.html', with ``result`` set when a query was answered.
    """
    print("Index function called...")
    if request.method == 'POST':
        # request.files.get returns None when the form has no 'file' part
        # (e.g. a query-only submission), so guard before reading .filename.
        file = request.files.get('file')
        if file is not None and file.filename and allowed_file(file.filename):
            upload_dir = app.config['UPLOAD_FOLDER']
            # The folder is otherwise only created in the __main__ guard, so make
            # sure it exists when the app is served by some other runner.
            os.makedirs(upload_dir, exist_ok=True)
            file.save(os.path.join(upload_dir, secure_filename(file.filename)))
            # NOTE: this replaces the startup ./Pdfs retriever with one built over
            # the whole upload folder.
            _rebuild_agent_executor(upload_dir)

        query = request.form.get('query', '')
        # Skip the agent/LLM call entirely for a missing or blank question.
        if query.strip():
            result = app.config["agent_executor"].invoke({"input": query})["output"]
            return render_template('index.html', result=_hard_wrap(result))
    return render_template('index.html')
if __name__ == '__main__':
    # exist_ok=True creates the upload folder if needed without the
    # check-then-create race of os.path.exists() followed by os.makedirs().
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
    print("Flask app running...")
    app.run(debug=False)