Spaces:
Runtime error
Runtime error
ridhimamlds
committed on
Upload folder using huggingface_hub
Browse files- README.md +2 -8
- __pycache__/leave.cpython-310.pyc +0 -0
- __pycache__/rag.cpython-310.pyc +0 -0
- leave.py +101 -0
- main.py +61 -0
- rag.py +108 -0
README.md
CHANGED
@@ -1,12 +1,6 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
|
4 |
-
colorFrom: pink
|
5 |
-
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.37.2
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: agent_app
|
3 |
+
app_file: main.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
sdk_version: 4.37.2
|
|
|
|
|
6 |
---
|
|
|
|
__pycache__/leave.cpython-310.pyc
ADDED
Binary file (4.85 kB). View file
|
|
__pycache__/rag.cpython-310.pyc
ADDED
Binary file (4.03 kB). View file
|
|
leave.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.tools import BaseTool
|
2 |
+
from langchain_openai import ChatOpenAI
|
3 |
+
from langchain.agents import AgentExecutor, create_openai_tools_agent
|
4 |
+
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
5 |
+
from langchain_community.utilities import SQLDatabase
|
6 |
+
from langchain.schema import SystemMessage, HumanMessage, AIMessage
|
7 |
+
import os
|
8 |
+
import psycopg2
|
9 |
+
import gradio as gr
|
10 |
+
|
11 |
+
# Set up database connection details
# SECURITY(review): database credentials are hard-coded and committed to a
# public repo — rotate this password and load the URI from an environment
# variable / HF Space secret instead.
pg_uri = "postgresql://ridhima:0skESLQ9D6c3m7smqwG47peapk7HzVvu@dpg-cq2h613v2p9s73esp8eg-a.singapore-postgres.render.com/hr_qugd"

def get_db_connection():
    """Open and return a new psycopg2 connection to the HR Postgres DB.

    The caller is responsible for closing the returned connection.
    """
    conn = psycopg2.connect(pg_uri)
    return conn
|
17 |
+
|
18 |
+
class LeaveRequestInfoTool(BaseTool):
    """Static reference tool: describes the leave-request workflow and the
    'request' table schema so the agent can construct a correct INSERT.
    """

    # FIX: annotated as `str` — on recent langchain/pydantic, unannotated
    # class attributes on a BaseTool subclass raise a field-annotation
    # error at import time (a likely cause of the Space's runtime error).
    name: str = "leave_request_information"
    description: str = "Provides information about the leave request process and database schema."

    def _run(self, query: str) -> str:
        # The query argument is intentionally ignored: this tool always
        # returns the same reference text for the agent to consume.
        return """
        Leave Request Process:
        1. Collect employee ID
        2. Ask for leave type
        3. Get reason for leave
        4. Get start date of leave
        5. Get end date of leave
        6. Calculate duration (in days)
        7. Insert data into request table
        8. Confirm submission to user

        Database Schema:
        Table: request
        Columns:
        - leave_id (auto-increment integer)
        - employee_id (integer)
        - leave_type (text)
        - reason (text)
        - start_of_leave (date)
        - end_of_leave (date)
        - duration (integer, calculated in days)
        - leave_status (text, default 'Pending')

        Instructions:
        - Collect all necessary information from the user one by one.
        - Calculate the duration as the number of days between start_of_leave and end_of_leave.
        - Once all information is collected, formulate an SQL INSERT statement for the 'request' table.
        - REMEMBER TO EXECUTE THE INSERT QUERY FOR EACH REQUEST ONLY ONCE.
        - CRITICAL: EXECUTE THE INSERT QUERY FOR EACH REQUEST ONLY ONCE.
        - If you receive an error message saying an insertion has already been made, DO NOT attempt to insert again.
        - Instead, inform the user that their request has been submitted and ask if they need anything else.
        - After insertion, confirm to the user that their request has been submitted.
        """
|
56 |
+
|
57 |
+
class SQLAgentTool(BaseTool):
    """Executes a SQL statement produced by the agent against the HR database.

    SECURITY(review): this runs LLM-generated SQL verbatim with the
    connection's full privileges — a prompt-injected user could issue
    arbitrary statements. Restrict the DB role to INSERT on 'request' only.
    """

    # FIX: annotated as `str` — unannotated class attributes on a BaseTool
    # subclass fail pydantic field validation on recent langchain versions.
    name: str = "sql_agent"
    description: str = "Use this tool to interact with the database and execute SQL queries."

    def _run(self, query: str) -> str:
        print("Executing Query: ", query)
        conn = get_db_connection()
        cur = conn.cursor()
        try:
            cur.execute(query)
            conn.commit()
            # Success message assumes the only query issued is the leave
            # INSERT (per the system prompt).
            result = "Your leave request has been submitted successfully."
        except Exception as e:
            conn.rollback()
            result = f"An error occurred: {str(e)}"
        finally:
            # Always release cursor and connection, success or failure.
            cur.close()
            conn.close()
        return result
|
76 |
+
|
77 |
+
# Agent prompt: static HR-assistant instructions, then the running chat
# history, the new user turn, and the scratchpad slot that
# create_openai_tools_agent requires.
# NOTE(review): the system prompt refers to a 'rag_info' tool, but the tool
# registered in main.py is named 'RAGTool' (see rag.py) — the agent may not
# select it; align the names.
prompt = ChatPromptTemplate.from_messages([
    SystemMessage(content="""You are an HR assistant. You can help with leave requests and provide information about company policies.
    - For leave requests, ask for each piece of information one at a time. After collecting all information, use the sql_agent tool to INSERT the data into the 'request' table.
    - For information queries about company policies, use the rag_info tool to provide accurate information from the RAG-trained model.

    The pieces of information you need to collect for leave requests are:
    1. Employee ID
    2. Leave Type
    3. Reason for Leave
    4. Start Date of Leave (YYYY-MM-DD)
    5. End Date of Leave (YYYY-MM-DD)
    Calculate the duration as the number of days between start and end dates using PostgreSQL functions and convert duration to integer.
    Use 'Pending' as the default leave_status.

    Remember to use the correct column names as per the table structure:
    request(leave_id, employee_id, leave_type, reason, start_of_leave, end_of_leave, duration, leave_status)
    Where leave_id is auto-increment and should not be included in the INSERT statement.

    For company policies, provide detailed and accurate information based on the RAG-trained model.
    """),
    MessagesPlaceholder(variable_name="chat_history"),
    ("human", "{input}"),
    MessagesPlaceholder(variable_name="agent_scratchpad")
])
|
101 |
+
|
main.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
from langchain.schema import HumanMessage, AIMessage
|
4 |
+
from langchain.agents import AgentExecutor, create_openai_tools_agent
|
5 |
+
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
6 |
+
from rag import create_rag_tool
|
7 |
+
from leave import LeaveRequestInfoTool, SQLAgentTool, prompt
|
8 |
+
|
9 |
+
|
10 |
+
# Initialize LLM
# SECURITY(review): OpenAI API key is hard-coded and committed publicly —
# revoke it and read it from an environment variable / Space secret instead.
key = "sk-proj-LdVhjM2bTI27bA3grOK8T3BlbkFJh5whi2UHYKkgM2pNwpbe"
os.environ["OPENAI_API_KEY"] = key

llm = ChatOpenAI(model="gpt-4", temperature=0)

# Add the RAGTool to the list of tools
# NOTE: create_rag_tool connects to Neo4j (and may ingest the PDF) at
# import time — module load can be slow or fail on network errors.
rag_tool = create_rag_tool(llm=llm)
leave_request_info_tool = LeaveRequestInfoTool()
sql_tool = SQLAgentTool()

tools = [leave_request_info_tool, sql_tool, rag_tool]

# OpenAI tools agent with verbose logging of each tool call.
agent = create_openai_tools_agent(llm, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
|
25 |
+
|
26 |
+
def truncate_chat_history(chat_history, max_tokens=3000):
    """Drop the oldest messages in place until the rough token count fits.

    "Tokens" are approximated as whitespace-separated words of each
    message's ``.content``. Trimming uses ``pop(0)``, so the caller's list
    object is mutated; the same list is also returned.
    """
    total_tokens = sum(len(message.content.split()) for message in chat_history)
    while total_tokens > max_tokens and chat_history:
        # FIX: subtract the removed message's count instead of re-summing
        # the whole list every iteration (original was O(n^2)).
        removed = chat_history.pop(0)
        total_tokens -= len(removed.content.split())
    return chat_history
|
32 |
+
|
33 |
+
def handle_user_input(user_input, chat_history):
    """Run one chat turn through the agent and format output for Gradio.

    Returns ``(messages, chat_history)`` where ``messages`` is a list of
    ``(user_text, bot_text)`` pairs as ``gr.Chatbot`` expects, and
    ``chat_history`` is the LangChain message list kept in ``gr.State``.
    """
    if chat_history is None:
        chat_history = []

    chat_history.append(HumanMessage(content=user_input))
    truncated_chat_history = truncate_chat_history(chat_history)

    response = agent_executor.invoke(
        {"input": user_input, "chat_history": truncated_chat_history}
    )
    ai_response = response['output']

    chat_history.append(AIMessage(content=ai_response))

    # FIX: gr.Chatbot expects (user, bot) message *pairs*; the original
    # built (text, role) tuples, which rendered the literal strings
    # "user"/"bot" as bot replies. Pair each HumanMessage with the
    # AI message that follows it.
    messages = []
    pending_user = None
    for message in chat_history:
        if isinstance(message, HumanMessage):
            pending_user = message.content
        else:
            messages.append((pending_user, message.content))
            pending_user = None
    if pending_user is not None:
        # Unanswered trailing user turn (should not happen here, but keep
        # the pairing well-formed).
        messages.append((pending_user, None))
    return messages, chat_history
|
49 |
+
|
50 |
+
# Minimal Gradio UI: markdown header, chat window, hidden LangChain message
# state, and a textbox that submits on Enter.
with gr.Blocks() as demo:
    gr.Markdown("# HR Assistant Chatbot")
    chatbot = gr.Chatbot()
    # Holds the LangChain message list between turns (None on first turn).
    state = gr.State()
    txt = gr.Textbox(placeholder="Type your message here...")

    txt.submit(handle_user_input, [txt, state], [chatbot, state])

# NOTE(review): share=True is ignored on HF Spaces (and logs a warning);
# plain demo.launch() is sufficient there.
demo.launch(share = True)
|
59 |
+
|
60 |
+
|
61 |
+
|
rag.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel, Field
|
2 |
+
from langchain.tools import Tool
|
3 |
+
from langchain_community.vectorstores import Neo4jVector
|
4 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
5 |
+
from langchain_community.document_loaders import PyPDFLoader
|
6 |
+
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
|
7 |
+
from langchain_core.output_parsers import StrOutputParser
|
8 |
+
from langchain_core.runnables import RunnablePassthrough
|
9 |
+
from langchain import hub
|
10 |
+
import os
|
11 |
+
|
12 |
+
# Initialize LLM
# SECURITY(review): hard-coded OpenAI API key committed to a public repo —
# revoke it and load from an environment variable / Space secret instead.
# (main.py sets the same variable; this duplicate belongs in one place.)
key = "sk-proj-LdVhjM2bTI27bA3grOK8T3BlbkFJh5whi2UHYKkgM2pNwpbe"
os.environ["OPENAI_API_KEY"] = key
|
15 |
+
|
16 |
+
class RAGToolConfig(BaseModel):
    """Connection and source settings for the PDF RAG tool.

    SECURITY(review): Neo4j credentials are committed in source — rotate
    them and load from environment variables / Space secrets.
    """
    NEO4J_URI: str = Field(default="neo4j+s://741a3118.databases.neo4j.io")
    NEO4J_USERNAME: str = Field(default="neo4j")
    NEO4J_PASSWORD: str = Field(default="XvUolnAXmgx9SG_lRSJuisbDClxi2MiTKGIoBdqN53A")
    # NOTE(review): absolute path from the dev machine — will not exist on
    # the deployed Space; confirm the PDF ships with the repo.
    pdf_path: str = Field(default="/mnt/d/atx/hragent/rag/Sirca_Paints.pdf")
|
21 |
+
|
22 |
+
class RAGToolImplementation:
    """Loads (or builds) a Neo4j vector index over the policy PDF and
    exposes a retrieval-augmented QA chain over it.

    Construction performs network I/O (Neo4j, OpenAI embeddings, LangChain
    Hub) and may ingest the whole PDF.
    """

    def __init__(self, config: RAGToolConfig, llm):
        self.config = config
        self.llm = llm  # Store the llm instance
        self.embedding_model = OpenAIEmbeddings()
        self.vectorstore = self._initialize_vectorstore()
        self.rag_chain = self._setup_rag_chain()

    def _initialize_vectorstore(self):
        """Return the existing Neo4j vector index, or build it from the PDF."""
        try:
            # Try to load existing vector store
            vectorstore = Neo4jVector(
                url=self.config.NEO4J_URI,
                username=self.config.NEO4J_USERNAME,
                password=self.config.NEO4J_PASSWORD,
                embedding=self.embedding_model,
                index_name="pdf_embeddings",
                node_label="PDFChunk",
                text_node_property="text",
                embedding_node_property="embedding"
            )
            # Probe query: raises if the index does not exist yet.
            vectorstore.similarity_search("Test query", k=1)
            print("Existing vector store loaded.")
        except Exception as e:
            # NOTE(review): any failure here (auth, network) — not just a
            # missing index — falls through to a full, costly re-ingest.
            print(f"Creating new vector store. Error: {e}")
            # Load and process the PDF
            loader = PyPDFLoader(self.config.pdf_path)
            docs = loader.load()

            # Split the document into chunks
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
            splits = text_splitter.split_documents(docs)

            # Create new vector store
            vectorstore = Neo4jVector.from_documents(
                documents=splits,
                embedding=self.embedding_model,
                url=self.config.NEO4J_URI,
                username=self.config.NEO4J_USERNAME,
                password=self.config.NEO4J_PASSWORD,
                index_name="pdf_embeddings",
                node_label="PDFChunk",
                text_node_property="text",
                embedding_node_property="embedding"
            )
            print("New vector store created and loaded.")
        return vectorstore

    def _setup_rag_chain(self):
        """Wire retriever -> hub prompt -> llm -> string output."""
        retriever = self.vectorstore.as_retriever()
        # Pulls the community "rlm/rag-prompt" template from LangChain Hub
        # (network call at startup).
        prompt = hub.pull("rlm/rag-prompt")

        def format_docs(docs):
            # Join retrieved chunks into a single context string.
            return "\n\n".join(doc.page_content for doc in docs)

        rag_chain = (
            {"context": retriever | format_docs, "question": RunnablePassthrough()}
            | prompt
            | self.llm  # Use the llm instance here
            | StrOutputParser()
        )
        return rag_chain

    def run(self, query: str) -> str:
        """Answer a query via the RAG chain.

        Errors are returned as text (not raised) so the agent can surface
        them to the user instead of crashing.
        """
        try:
            response = self.rag_chain.invoke(query)
            return response
        except Exception as e:
            return f"An error occurred while processing the query: {str(e)}"
|
91 |
+
|
92 |
+
|
93 |
+
def create_rag_tool(config: RAGToolConfig = None, llm=None):
    """Build the LangChain Tool wrapping the PDF RAG chain.

    FIX: the original default ``config: RAGToolConfig = RAGToolConfig()``
    is a mutable default evaluated once at ``def`` time, so one config
    object was shared (and mutable) across every call. A ``None`` sentinel
    creates a fresh config per call; passing a config explicitly still
    works, so callers are unaffected.
    """
    if config is None:
        config = RAGToolConfig()
    implementation = RAGToolImplementation(config, llm)
    return Tool(
        name="RAGTool",
        description="Retrieval-Augmented Generation Tool for querying PDF content about Sirca Paints",
        func=implementation.run
    )
|
100 |
+
|
101 |
+
# # Example Usage
|
102 |
+
# if __name__ == "__main__":
|
103 |
+
# llm = ChatOpenAI(model="gpt-4", temperature=0)
|
104 |
+
# rag_tool = create_rag_tool(llm=llm)
|
105 |
+
|
106 |
+
# # Test the tool
|
107 |
+
# result = rag_tool.run("What is spil ethics?")
|
108 |
+
# print(result)
|