Spaces:
Sleeping
Sleeping
import os | |
import streamlit as st | |
from langchain.embeddings import HuggingFaceInstructEmbeddings, HuggingFaceEmbeddings | |
from langchain.vectorstores.faiss import FAISS | |
from huggingface_hub import snapshot_download | |
from langchain.callbacks import StreamlitCallbackHandler | |
from langchain.agents import OpenAIFunctionsAgent, AgentExecutor | |
from langchain.agents.agent_toolkits import create_retriever_tool | |
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import ( | |
AgentTokenBufferMemory, | |
) | |
from langchain.chat_models import ChatOpenAI | |
from langchain.schema import SystemMessage, AIMessage, HumanMessage | |
from langchain.prompts import MessagesPlaceholder | |
from langsmith import Client | |
# LangSmith client, used by send_feedback to record user ratings.
client = Client()

st.set_page_config(
    page_title="Chat with CFA Level 1",
    page_icon="π",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Load API key. Show a readable error in the app and halt this run instead of
# crashing with a bare KeyError traceback when the secret is not configured.
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
    st.error(
        "OPENAI_API_KEY is not set. Configure it as an environment variable "
        "(or Space secret) and reload the app."
    )
    st.stop()

#### sidebar section 1 ####
with st.sidebar:
    # Only one embedding option is currently offered; load_embedding_models
    # also knows how to build "Instruct".
    book = st.radio("Embedding Model: ", ["Sbert"])
# Load embedding models.
@st.cache_resource
def load_embedding_models(model):
    """Build the LangChain embeddings object for the selected model.

    Wrapped in ``st.cache_resource`` so the HuggingFace model is constructed
    once and reused across Streamlit reruns, instead of on every interaction.

    Args:
        model: Either ``"Sbert"`` (sentence-transformers/all-mpnet-base-v2) or
            ``"Instruct"`` (hkunlp/instructor-large with retrieval instructions).

    Returns:
        A ``HuggingFaceEmbeddings`` or ``HuggingFaceInstructEmbeddings`` instance.

    Raises:
        ValueError: If ``model`` is not one of the supported names. Without this
            guard, an unknown name would surface as an ``UnboundLocalError``.
    """
    if model == "Sbert":
        model_sbert = "sentence-transformers/all-mpnet-base-v2"
        return HuggingFaceEmbeddings(model_name=model_sbert)
    if model == "Instruct":
        embed_instruction = "Represent the financial paragraph for document retrieval: "
        query_instruction = "Represent the question for retrieving supporting documents: "
        model_instr = "hkunlp/instructor-large"
        return HuggingFaceInstructEmbeddings(
            model_name=model_instr,
            embed_instruction=embed_instruction,
            query_instruction=query_instruction,
        )
    raise ValueError(f"Unknown embedding model: {model!r} (expected 'Sbert' or 'Instruct')")


embeddings = load_embedding_models(book)
##### functions ####
@st.cache_resource
def load_vectorstore(_embeddings, k=4):
    """Download the prebuilt CFA Level 1 FAISS index and return a retriever over it.

    The leading underscore on ``_embeddings`` tells ``st.cache_resource`` not to
    hash that argument. The cache is therefore keyed on ``k`` only. If more
    embedding models are ever offered, the cache would not distinguish between
    them, so revisit this then.

    Args:
        _embeddings: Embeddings object used to embed queries. It should match the
            model the index was built with (presumably Sbert; TODO confirm).
        k: Number of documents the retriever returns per query (default 4).

    Returns:
        A LangChain retriever over the loaded FAISS vector store.

    Raises:
        FileNotFoundError: If the downloaded snapshot has no ``CFA_Level_1`` folder.
    """
    cache_dir = "cfa_level_1_cache"
    target_dir = "CFA_Level_1"
    # snapshot_download returns the local folder of the snapshot it resolved.
    # Using that path directly avoids walking cache_dir, where a stale snapshot
    # from an earlier revision could be picked instead of the current one.
    snapshot_path = snapshot_download(
        repo_id="nickmuchi/CFA_Level_1_Text_Embeddings",
        repo_type="dataset",
        revision="main",
        allow_patterns=f"{target_dir}/*",
        cache_dir=cache_dir,
    )
    target_path = os.path.join(snapshot_path, target_dir)
    if not os.path.isdir(target_path):
        raise FileNotFoundError(f"FAISS index folder not found in snapshot: {target_path}")
    # NOTE(security): FAISS.load_local deserializes a pickled docstore from disk.
    # Only load indexes from a repository you trust.
    vectorstore = FAISS.load_local(folder_path=target_path, embeddings=_embeddings)
    return vectorstore.as_retriever(search_kwargs={"k": k})
# Expose the CFA retriever to the agent as its single tool. The description is
# what the LLM reads when deciding whether to call the tool.
cfa_retriever = load_vectorstore(embeddings)
tool = create_retriever_tool(
    cfa_retriever,
    "search_cfa_docs",
    (
        "Searches and returns documents regarding the CFA level 1 curriculum. "
        "CFA is a rigorous program for investment professionals which covers "
        "topics such as ethics, corporate finance, economics, fixed income, "
        "equities and derivatives markets. You do not know anything about the "
        "CFA program, so if you are ever asked about CFA material or curriculum "
        "you should use this tool."
    ),
)
tools = [tool]
# Deterministic (temperature=0), streaming GPT-4 chat model. The key read at
# startup is passed explicitly rather than relying on an implicit env lookup.
llm = ChatOpenAI(temperature=0, streaming=True, model="gpt-4", openai_api_key=api_key)

# Adjacent string literals are concatenated verbatim, so every sentence except
# the last needs its own trailing space.
message = SystemMessage(
    content=(
        "You are a helpful CFA level 1 chatbot who is tasked with answering questions about the CFA level 1 program. "
        "Do not answer any question that is not related to the CFA program or finance. "
        "If there is any ambiguity, politely decline to answer the question."
    )
)

# Named agent_prompt so it is not confused with the user's chat input later on.
# The "history" placeholder is filled from st.session_state on each call.
agent_prompt = OpenAIFunctionsAgent.create_prompt(
    system_message=message,
    extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
)
agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=agent_prompt)
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    verbose=True,
    return_intermediate_steps=True,
)

# Re-created on every script rerun. The durable conversation lives in
# st.session_state["messages"] and is copied into this memory each run.
memory = AgentTokenBufferMemory(llm=llm)
starter_message = "Ask me anything about the CFA Level 1 Curriculum!"

# Render the clear button unconditionally, so it is visible on every run. If it
# sat behind a short-circuiting `or`, it would be skipped whenever no history
# existed yet.
clear_history = st.sidebar.button("Clear message history")
if clear_history or "messages" not in st.session_state:
    # Start (or restart) the conversation with only the assistant greeting.
    st.session_state["messages"] = [AIMessage(content=starter_message)]
def send_feedback(run_id, score):
    """Attach a ``user_score`` feedback entry to a LangSmith run.

    Args:
        run_id: Identifier of the traced agent run being rated.
        score: Rating value, forwarded unchanged to LangSmith.
    """
    feedback_key = "user_score"
    client.create_feedback(run_id, feedback_key, score=score)
# Replay the stored conversation. Draw user and assistant turns in the chat UI,
# and copy every stored message (drawn or not) into this run's agent memory.
for past_message in st.session_state.messages:
    if isinstance(past_message, AIMessage):
        speaker = "assistant"
    elif isinstance(past_message, HumanMessage):
        speaker = "user"
    else:
        speaker = None

    if speaker is not None:
        st.chat_message(speaker).write(past_message.content)
    memory.chat_memory.add_message(past_message)
# Handle one new user turn. The variable is named user_input so it does not
# shadow the agent prompt template defined earlier.
if user_input := st.chat_input(placeholder=starter_message):
    st.chat_message("user").write(user_input)
    with st.chat_message("assistant"):
        # Callback handler that renders the agent's progress into this container.
        st_callback = StreamlitCallbackHandler(st.container())
        response = agent_executor(
            {"input": user_input, "history": st.session_state.messages},
            callbacks=[st_callback],
            include_run_info=True,
        )
        st.write(response["output"])

        # Record this turn in the agent memory, then store the memory's buffer as
        # the conversation history. Appending to st.session_state.messages first
        # would be pointless, because the assignment below replaces that list.
        # NOTE(review): AgentTokenBufferMemory presumably trims old messages to a
        # token limit, so very long chats may lose early turns. Confirm if relevant.
        memory.save_context({"input": user_input}, response)
        st.session_state["messages"] = memory.buffer

        # LangSmith run id (from include_run_info=True). send_feedback needs it
        # once the rating buttons are re-enabled.
        run_id = response["__run"].run_id

        col_blank, col_text, col1, col2 = st.columns([10, 2, 1, 1])
        with col_text:
            st.text("Feedback:")
        # TODO: the rating buttons are disabled, so the "Feedback:" label above
        # currently has nothing next to it. The emoji were mis-encoded in this
        # copy; thumbs up/down is assumed from the 1/0 scores.
        # with col1:
        #     st.button("👍", on_click=send_feedback, args=(run_id, 1))
        # with col2:
        #     st.button("👎", on_click=send_feedback, args=(run_id, 0))