import json
import operator
import uuid
from operator import itemgetter
from typing import Annotated, Sequence, TypedDict

import chainlit as cl
from dotenv import load_dotenv
from langchain.retrievers import ParentDocumentRetriever
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.runnable.config import RunnableConfig
from langchain.storage import InMemoryStore
# from langchain_core.output_parsers import StrOutputParser
from langchain.tools import tool
from langchain_community.document_loaders import ArxivLoader
from langchain_community.tools.arxiv.tool import ArxivQueryRun
from langchain_community.tools.ddg_search import DuckDuckGoSearchRun
from langchain_community.tools.pubmed.tool import PubmedQueryRun
# from langgraph.graph.message import add_messages
from langchain_core.messages import (
    BaseMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_openai import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_qdrant import Qdrant
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import END, StateGraph
from langgraph.checkpoint.aiosqlite import AsyncSqliteSaver
from langgraph.prebuilt import ToolExecutor, ToolInvocation
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
# GLOBAL SCOPE - ENTIRE APPLICATION HAS ACCESS TO VALUES SET IN THIS SCOPE #
# ---- ENV VARIABLES ---- #
"""
load_dotenv() loads our environment file (.env) if it is present.
NOTE: Make sure that .env is in your .gitignore file - it is by default, but please ensure it remains there.
"""
load_dotenv()
"""
We load our environment variables here.
"""
# ---- GLOBAL DECLARATIONS ---- #
# -- RETRIEVAL -- #
"""
1. Load Documents from arXiv
2. Split Documents into Parent/Child Chunks
3. Create the OpenAI Embeddings and Qdrant Vector Store
4. Index the Documents with a ParentDocumentRetriever
"""
### 1. CREATE ARXIV LOADER AND LOAD DOCUMENTS
### NOTE: THE QUERY TARGETS RECENT PAPERS ON DATA/ML IN MENTAL HEALTH COUNSELING.
docs = ArxivLoader(
    query='"mental health counseling" AND (data OR analytics OR "machine learning")',
    load_max_docs=10,
    sort_by="submittedDate",
    sort_order="descending",
).load()
### 2. CREATE QDRANT CLIENT AND VECTOR STORE
client = QdrantClient(":memory:")
client.create_collection(
    collection_name="split_parents",
    # 1536 is the dimensionality of OpenAI's text-embedding-3-small
    vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
)
vectorstore = Qdrant(
    client,
    collection_name="split_parents",
    embeddings=OpenAIEmbeddings(model="text-embedding-3-small"),
)
store = InMemoryStore()
### 3. CREATE PARENT DOCUMENT TEXT SPLITTERS AND INITIALIZE THE RETRIEVER
parent_document_retriever = ParentDocumentRetriever(
    vectorstore=vectorstore,
    docstore=store,
    child_splitter=RecursiveCharacterTextSplitter(chunk_size=400),
    parent_splitter=RecursiveCharacterTextSplitter(chunk_size=2000),
)
parent_document_retriever.add_documents(docs)
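# How this works: add_documents() splits each paper into ~2000-character parent
# chunks (kept in the InMemoryStore) and ~400-character child chunks (embedded
# into Qdrant). Queries are matched against the small child chunks, but the
# retriever returns the larger parent chunks for fuller context.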
### 4. CREATE PROMPT OBJECT
RAG_PROMPT = """\
You are a professional mental health advisor. Use the following context to answer the user's query. If you cannot answer the question, please respond with 'I don't know'.
Question:
{question}
Context:
{context}
"""
rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
### 5. CREATE CHAIN PIPELINE WITH RETRIEVER
openai_chat_model = ChatOpenAI(model="gpt-3.5-turbo", streaming=True)
def create_qa_chain(retriever):
    mental_health_qa_llm = openai_chat_model
    created_qa_chain = (
        {
            "context": itemgetter("question") | retriever,
            "question": itemgetter("question"),
        }
        | RunnablePassthrough.assign(context=itemgetter("context"))
        | {
            "response": rag_prompt | mental_health_qa_llm | StrOutputParser(),
            "context": itemgetter("context"),
        }
    )
    return created_qa_chain
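# Example usage (hypothetical query, following the dict shape the chain returns):
# result = await create_qa_chain(parent_document_retriever).ainvoke(
#     {"question": "What does recent research say about ML in counseling?"}
# )
# result["response"] -> the generated answer; result["context"] -> the parent chunks used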
### 6. DEFINE THE LIST OF TOOLS AVAILABLE TO THE AGENT, AND WRAP A TOOL EXECUTOR AROUND THEM
@tool
async def rag_tool(question: str) -> str:
    """Only use this tool to retrieve research relevant information from the knowledge base."""
    # advanced_rag_prompt=ChatPromptTemplate.from_template(INSTRUCTION_PROMPT_TEMPLATE.format(user_query=question))
    parent_document_retriever_qa_chain = create_qa_chain(parent_document_retriever)
    response = await parent_document_retriever_qa_chain.ainvoke({"question": question})
    return response["response"]
tool_belt = [
    rag_tool,
    PubmedQueryRun(),
    ArxivQueryRun(),
    DuckDuckGoSearchRun(),
]
tool_executor = ToolExecutor(tool_belt)
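# ToolExecutor keeps a name -> tool mapping, so call_tool below can dispatch a
# ToolInvocation to the right tool purely by the name the model emits.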
### 7. CONVERT TOOLS INTO THE FORMAT COMPATIBLE WITH OPENAI'S FUNCTION CALLING API, THEN BIND THEM TO THE MODEL FOR USE DURING GENERATION
model = ChatOpenAI(temperature=0, streaming=True)
functions = [convert_to_openai_function(t) for t in tool_belt]
model = model.bind_functions(functions)
model = model.with_config(tags=["final_node"])
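# The "final_node" tag is attached to every event this model emits, so a
# consumer of astream_events() can filter its events for the final answer.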
### 8. USING TypedDict FROM THE typing MODULE AND THE langchain_core.messages MODULE, A CUSTOM TYPE NAMED AgentState IS CREATED.
# THE AgentState TYPE HAS A FIELD NAMED <messages> OF TYPE Annotated[Sequence[BaseMessage], operator.add].
# Sequence[BaseMessage]: INDICATES THAT MESSAGES ARE A SEQUENCE OF BaseMessage OBJECTS.
# Annotated: USED TO ATTACH METADATA TO THE TYPE; HERE IT TELLS LANGGRAPH TO MERGE UPDATES TO THE messages FIELD BY CONCATENATION (operator.add).
class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
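# Concretely: if the current state is {"messages": [m1]} and a node returns
# {"messages": [m2]}, LangGraph applies operator.add and the new state becomes
# {"messages": [m1, m2]} - nodes append to the history rather than overwrite it.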
### 9. TWO FUNCTIONS DEFINED: 1. call_model AND 2. call_tool
# 1. INVOKES THE MODEL WITH THE MESSAGES EXTRACTED FROM THE STATE, RETURNING A DICT CONTAINING THE RESPONSE MESSAGE.
# 2.1 A ToolInvocation OBJECT IS CREATED USING THE NAME AND ARGUMENTS EXTRACTED FROM THE LAST MESSAGE IN THE STATE.
# 2.2 tool_executor IS INVOKED WITH THE CREATED ToolInvocation OBJECT.
# 2.3 A FunctionMessage OBJECT IS CREATED WITH THE tool_executor RESPONSE AND THE NAME OF THAT TOOL.
# 2.4 THE RETURN VALUE IS A DICT CONTAINING THE FunctionMessage OBJECT.
async def call_model(state):
    messages = state["messages"]
    response = await model.ainvoke(messages)
    return {"messages": [response]}
async def call_tool(state):
    last_message = state["messages"][-1]
    action = ToolInvocation(
        tool=last_message.additional_kwargs["function_call"]["name"],
        tool_input=json.loads(
            last_message.additional_kwargs["function_call"]["arguments"]
        ),
    )
    # debug: show which tool the model chose
    print()
    print(last_message.additional_kwargs["function_call"]["name"])
    print()
    response = await tool_executor.ainvoke(action)
    function_message = FunctionMessage(content=str(response), name=action.tool)
    return {"messages": [function_message]}
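# The FunctionMessage is appended to state, so on the next "agent" turn the
# model sees the tool's output in the conversation and can reason over it.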
### 10. GRAPH CREATION WITH HELPFULNESS EVALUATION
# should_continue CHECKS WHETHER THE LAST MESSAGE IN THE STATE REQUESTS A TOOL CALL (A function_call IN additional_kwargs) OR NOT.
# THE add_conditional_edges() METHOD ROUTES ON THIS RESPONSE: EITHER TRANSITION TO THE action NODE OR END.
def should_continue(state):
    last_message = state["messages"][-1]
    if "function_call" not in last_message.additional_kwargs:
        return "end"
    return "continue"
async def check_helpfulness(state):
    initial_query = state["messages"][0]
    final_response = state["messages"][-1]
    # artificial loop guard: give up after 20 messages so the
    # agent <-> helpfulness cycle cannot run forever
    if len(state["messages"]) > 20:
        return "end"
    prompt_template = """\
Given an initial query and a final response, determine if the final response is extremely helpful or not. Please indicate helpfulness with a 'Y' \
and unhelpfulness as an 'N'.
Initial Query:
{initial_query}
Final Response:
{final_response}"""
    prompt_template = PromptTemplate.from_template(prompt_template)
    helpfulness_check_model = ChatOpenAI(model="gpt-4")
    helpfulness_check_chain = (
        prompt_template | helpfulness_check_model | StrOutputParser()
    )
    helpfulness_response = await helpfulness_check_chain.ainvoke(
        {"initial_query": initial_query, "final_response": final_response}
    )
    if "Y" in helpfulness_response:
        print("helpful!")
        return "end"
    else:
        print("Not helpful!!")
        return "continue"
def dummy_node(state):
    # no state update: this node exists only so the helpfulness check has a
    # graph node to hang its conditional edges on
    return
### 11. SETTING UP THE GRAPH WORKFLOW:
# 1. AN INSTANCE OF StateGraph IS CREATED WITH THE TYPE AgentState. THREE NODES ARE ADDED TO THE GRAPH USING THE add_node() METHOD:
# 1.1 THE "agent" NODE IS ASSOCIATED WITH THE call_model FUNCTION.
# 1.2 THE "action" NODE IS ASSOCIATED WITH THE call_tool FUNCTION.
# 1.3 THE "passthrough" NODE IS A CUSTOM NODE ASSOCIATED WITH CHECKING HELPFULNESS.
# 1.4 THE CONDITIONAL EDGES:
# 1.4.1 FROM THE agent NODE TO EITHER THE action NODE OR THE passthrough NODE.
# 1.4.2 FROM THE passthrough NODE TO EITHER THE agent NODE OR THE END NODE.
# 1.4.3 A PLAIN EDGE FROM action BACK TO agent, SO THE MODEL SEES TOOL RESULTS WHEN GENERATING A RESPONSE.
def get_state_update_bot():
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", call_model)  # agent node has access to the llm
    workflow.add_node("action", call_tool)  # action node has access to the tools
    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {
            "continue": "action",  # tools
            "end": END,
        },
    )
    workflow.add_edge("action", "agent")  # tools
    state_update_bot = workflow.compile()
    return state_update_bot
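# Resulting topology of the basic graph:
#   agent --(function_call present)--> action --> agent
#   agent --(no function_call)-------> END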
# --------------------------------------------------
from langgraph.checkpoint.memory import MemorySaver
def get_state_update_bot_with_helpfulness_node():
    # memory = MemorySaver()
    graph_with_helpfulness_check = StateGraph(AgentState)
    graph_with_helpfulness_check.add_node("agent", call_model)
    graph_with_helpfulness_check.add_node("action", call_tool)
    graph_with_helpfulness_check.add_node("passthrough", dummy_node)
    graph_with_helpfulness_check.set_entry_point("agent")
    graph_with_helpfulness_check.add_conditional_edges(
        "agent", should_continue, {"continue": "action", "end": "passthrough"}
    )
    graph_with_helpfulness_check.add_conditional_edges(
        "passthrough", check_helpfulness, {"continue": "agent", "end": END}
    )
    graph_with_helpfulness_check.add_edge("action", "agent")
    memory = AsyncSqliteSaver.from_conn_string(":memory:")
    return graph_with_helpfulness_check.compile(checkpointer=memory)
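# Because the graph is compiled with a checkpointer, every invocation must carry
# a {"configurable": {"thread_id": ...}} config (see start_chat below) so that
# conversation state can be persisted and resumed per thread.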
### 12. CONVERT THE INCOMING MESSAGES INTO A SYSTEM PROMPT FOR THE AGENT
# def convert_inputs(input_object):
#     system_prompt = f"""You are a qualified psychologist providing mental health advice. Be empathetic in your responses.
#     Always provide a complete response. Be empathetic and provide a follow-up question to find a resolution.
#     First, look up the RAG (retrieval-augmented generation) tool and then the arxiv research tool, or use an internet search:
#     You will operate in a loop of Thought, Action, PAUSE, and Observation. At the end of the loop, you will provide an Answer.
#     Instructions:
#     Thought: Describe your thoughts about the user's question.
#     Action: Choose one of the available actions to gather information or provide insights.
#     PAUSE: Pause to allow the action to complete.
#     Observation: Review the results of the action.
#     Available Actions:
#     Use the tools at your disposal to look up information or resolve the consultancy. You are allowed to make multiple calls (either together or in sequence):
#     1. rag_tool: RAG (Retrieval-Augmented Generation) to access relevant mental health information.
#     2. DuckDuckGoSearchRun: Perform an online search to find up-to-date resources and recommendations.
#     3. ArxivQueryRun: Find relevant research or content.
#     4. PubmedQueryRun: Find specific coping strategies or management techniques from research papers.
#     You may make multiple calls to these tools as needed to provide comprehensive advice.
#     Present your final response in a clear, structured format, including a chart of recommended actions if appropriate.
#     User's question: {input_object["messages"]}
#     Response: When responding to users' personal issues or concerns:
#     1. Start with a brief empathetic acknowledgment of the user's situation, then
#     2. Provide practical, actionable advice, which often includes:
#     3. Suggesting professional help (e.g., therapists, counselors) when appropriate,
#     4. Encouraging open communication and dialogue with involved parties,
#     5. Recommending self-reflection or exploration of emotions and values, and
#     6. Offering specific coping strategies or management techniques.
#     """
#     return {"messages": [SystemMessage(content=system_prompt)]}
def convert_inputs(input_object):
    system_prompt = f"""You are a qualified psychologist providing mental health advice. Be empathetic in your responses.
    Always provide a complete response. Be empathetic and provide a follow-up question to find a resolution.
    You must use the tools at your disposal.
    You must consult pubmed, then rag_tool, then duckduckgo_results_json.
    You must make multiple calls to these tools as needed to provide comprehensive advice.
    User's question: {input_object["messages"]}
    """
    return {"messages": [SystemMessage(content=system_prompt)]}
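# NOTE: the tool names in this prompt are prose hints for the model; what the
# model can actually call is determined by the functions bound in step 7, and
# the exact registered names (e.g. rag_tool vs. the search tools' own names)
# may differ from the wording above.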
# Define the function to parse the output
def parse_output(input_state):
    return input_state
# bot_with_helpfulness_check = get_state_update_bot_with_helpfulness_node()
# bot = get_state_update_bot()
# Create the agent chain
# agent_chain = convert_inputs | bot_with_helpfulness_check  # | StrOutputParser() | parse_output
# Run the agent chain with the input
# messages = agent_chain.invoke({"question": mental_health_counseling_data['test'][14]['Context']})
# ---------------------------------------------------------------------------------------------------------
# DEPLOYMENT
# ---------------------------------------------------------------------------------------------------------
@cl.author_rename
def rename(original_author: str):
    """
    This function can be used to rename the 'author' of a message.
    In this case, we're overriding the 'Assistant' author to be 'Mental Health Advisor Bot'.
    """
    rename_dict = {"Assistant": "Mental Health Advisor Bot"}
    return rename_dict.get(original_author, original_author)
@cl.on_chat_start
async def start_chat():
    """
    This function will be called at the start of every user session.
    We will build our LCEL agent chain here and store it in the user session.
    The user session is a dictionary that is unique to each user session, and is stored in the memory of the server.
    """
    ### BUILD LCEL RAG CHAIN THAT ONLY RETURNS TEXT
    # lcel_rag_chain = ( {"context": itemgetter("query") | hf_retriever, "query": itemgetter("query")}
    #     | rag_prompt | hf_llm
    # )
    # memory = MemorySaver()  # unused here; the graph builder compiles with AsyncSqliteSaver instead
    bot_with_helpfulness_check = get_state_update_bot_with_helpfulness_node()
    lcel_agent_langgraph_chain = (
        convert_inputs | bot_with_helpfulness_check)  # | StrOutputParser()
    # bot = get_state_update_bot()
    # lcel_agent_chain = convert_inputs | bot | parse_output  # StrOutputParser()
    cl.user_session.set("langgraph_agent_chain", lcel_agent_langgraph_chain)
    # Create a thread id and pass it as configuration
    # to be able to use Langgraph's checkpointer
    conversation_id = str(uuid.uuid4())
    config = {"configurable": {"thread_id": conversation_id}}
    cl.user_session.set("config", config)
@cl.on_message
async def main(message: cl.Message):
    """
    This function will be called every time a message is received from a session.
    """
    # msg is the human message, which could be mixed with a system message.
    # agent_message is the agent's response.
    graph = cl.user_session.get("langgraph_agent_chain")
    config = cl.user_session.get("config")
    final_output = ""
    # inputs = {"messages": [("user", message.content)]}
    inputs = {"messages": [HumanMessage(message.content)]}
    agent_message = cl.Message(content="")
    await agent_message.send()
    async for event in graph.astream_events(
        inputs,
        config=config,  # =RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
        version="v2",
    ):
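        # With version="v2", astream_events yields dicts like
        # {"event": "on_chain_start" | "on_chain_stream" | "on_chain_end" | ...,
        #  "name": ..., "tags": [...], "data": {...}}; we route on "event" and "name".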
        kind = event["event"]
        tags = event.get("tags", [])
        name = event.get("name", "")
        print()
        print(f"Received event: {event}")  # Debugging statement
        print()
        if kind == "on_chain_start":
            if (
                event["name"] == "Agent"
            ):  # Only matches if the agent was created with `.with_config({"run_name": "Agent"})`
                print(
                    f"Starting agent: {event['name']} with input: {event['data'].get('input')}"
                )
                # await agent_message.send()
        elif kind == "on_chain_end" and name == "RunnableSequence":  # "tool_end" in tags:
            if "output" in event["data"] and "agent" in event["data"]["output"]:
                agent_output = event["data"]["output"]["agent"]
                if "messages" in agent_output and agent_output["messages"]:
                    final_output = agent_output["messages"][0].content
                    await agent_message.stream_token(final_output)
        # elif kind == "on_chain_stream":
        #     data = event["data"]
        #     if data["chunk"].content:
        #         print(f"Streaming content: {data['chunk'].content}")
        #         await agent_message.stream_token(data["chunk"].content)
    await agent_message.send()
# docker build -t llm-app-langgraph-react-chainlit-mentalmindbt .
# docker run -it -p 7860:7860 llm-app-langgraph-react-chainlit-mentalmindbt:latest