File size: 5,690 Bytes
3e8fafc
 
 
4dfafae
3e8fafc
 
8cbab56
 
 
 
 
 
a3155e0
8cbab56
 
 
 
a3155e0
8cbab56
3e8fafc
8cbab56
 
 
 
 
 
3e8fafc
7345394
 
3e8fafc
 
 
60f3227
 
3e8fafc
8cbab56
 
4dfafae
a49b3d3
4dfafae
 
 
 
 
3e8fafc
4dfafae
 
 
 
 
 
 
3e8fafc
4dfafae
3e8fafc
4dfafae
3e8fafc
 
a49b3d3
6fb2144
3e8fafc
4dfafae
 
3e8fafc
 
6fb2144
3e8fafc
 
 
b9d30a5
3e8fafc
 
 
 
 
 
 
 
b9d30a5
3e8fafc
 
8cbab56
 
 
 
 
a49b3d3
8cbab56
 
 
 
 
 
 
7614cb1
 
 
8cbab56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
014f3db
 
8cbab56
014f3db
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
import streamlit as st

from langchain.embeddings import HuggingFaceInstructEmbeddings, HuggingFaceEmbeddings
from langchain.vectorstores.faiss import FAISS
from huggingface_hub import snapshot_download

from langchain.callbacks import StreamlitCallbackHandler
from langchain.agents import OpenAIFunctionsAgent, AgentExecutor
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
    AgentTokenBufferMemory,
)
from langchain.chat_models import ChatOpenAI
from langchain.schema import SystemMessage, AIMessage, HumanMessage
from langchain.prompts import MessagesPlaceholder
from langsmith import Client

# LangSmith client, used by send_feedback to record user ratings on runs.
client = Client()

st.set_page_config(
    page_title="Chat with CFA Level 1",
    page_icon="📖",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Load API key. Indexing os.environ (rather than using .get) makes the app
# fail fast with a KeyError when OPENAI_API_KEY is not set.
api_key = os.environ["OPENAI_API_KEY"]

#### sidebar section 1 ####
with st.sidebar:
    # Only "Sbert" is offered here, although load_embedding_models also
    # recognises "Instruct".
    book = st.radio(
        "Embedding Model: ",
        ["Sbert"],
    )


# Load embedding models. Cached with st.cache_resource so the model weights
# are not reloaded on every Streamlit rerun.
@st.cache_resource
def load_embedding_models(model):
    """Return the HuggingFace embeddings object for the selected model name.

    Args:
        model: Name of the embedding backend, ``"Sbert"`` or ``"Instruct"``.

    Returns:
        A LangChain embeddings object (``HuggingFaceEmbeddings`` for
        ``"Sbert"``, ``HuggingFaceInstructEmbeddings`` for ``"Instruct"``).

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    if model == 'Sbert':
        model_sbert = "sentence-transformers/all-mpnet-base-v2"
        return HuggingFaceEmbeddings(model_name=model_sbert)

    if model == 'Instruct':
        # Instructor models take separate instructions for indexed documents
        # and for search queries.
        embed_instruction = "Represent the financial paragraph for document retrieval: "
        query_instruction = "Represent the question for retrieving supporting documents: "
        model_instr = "hkunlp/instructor-large"
        return HuggingFaceInstructEmbeddings(
            model_name=model_instr,
            embed_instruction=embed_instruction,
            query_instruction=query_instruction,
        )

    # Any other name has no embeddings to build; report it clearly rather
    # than failing on an unbound local variable.
    raise ValueError(
        f"Unknown embedding model {model!r}; expected 'Sbert' or 'Instruct'."
    )


embeddings = load_embedding_models(book)

##### functions ####
@st.cache_resource
def load_vectorstore(_embeddings):
    """Download the prebuilt CFA Level 1 FAISS index and return a retriever over it.

    The leading underscore on ``_embeddings`` tells Streamlit's cache not to
    hash this argument when computing the cache key.

    Args:
        _embeddings: Embeddings object used to embed queries against the index.
            NOTE(review): presumably must be the same model the index was
            built with — confirm against the dataset.

    Returns:
        A retriever that returns the top 4 matching documents.

    Raises:
        FileNotFoundError: If the downloaded snapshot contains no
            ``CFA_Level_1`` folder.
    """
    cache_dir = "cfa_level_1_cache"
    target_dir = "CFA_Level_1"

    # snapshot_download returns the local folder of the requested revision,
    # so the index location is known directly rather than searched for
    # (a recursive search could pick up a folder from an older revision).
    snapshot_path = snapshot_download(
        repo_id="nickmuchi/CFA_Level_1_Text_Embeddings",
        repo_type="dataset",
        revision="main",
        allow_patterns="CFA_Level_1/*",
        cache_dir=cache_dir,
    )
    target_path = os.path.join(snapshot_path, target_dir)
    if not os.path.isdir(target_path):
        raise FileNotFoundError(
            f"Expected FAISS index folder {target_path!r} in the downloaded snapshot."
        )

    # load faiss
    vectorstore = FAISS.load_local(folder_path=target_path, embeddings=_embeddings)

    return vectorstore.as_retriever(search_kwargs={"k": 4})

# Expose the CFA document retriever to the agent as a callable tool.
tool = create_retriever_tool(
    load_vectorstore(embeddings),
    "search_cfa_docs",
    "Searches and returns documents regarding the CFA level 1 curriculum. CFA is a rigorous program for investment professionals which covers topics such as ethics, corporate finance, economics, fixed income, equities and derivatives markets. You do not know anything about the CFA program, so if you are ever asked about CFA material or curriculum you should use this tool.",
)
tools = [tool]
llm = ChatOpenAI(temperature=0, streaming=True, model="gpt-4")
# Adjacent string literals are concatenated, so each sentence ends with a
# space to keep the sentences separated in the final prompt.
message = SystemMessage(
    content=(
        "You are a helpful CFA level 1 chatbot who is tasked with answering questions about the CFA level 1 program. "
        "Do not answer any question that is not related to the CFA program or finance. "
        "If there is any ambiguity, politely decline to answer the question."
    )
)

# The "history" placeholder receives the prior chat messages on each call.
prompt = OpenAIFunctionsAgent.create_prompt(
    system_message=message,
    extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
)
agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    verbose=True,
    return_intermediate_steps=True,
)
# Rebuilt on every rerun; it is repopulated from st.session_state["messages"].
memory = AgentTokenBufferMemory(llm=llm)
starter_message = "Ask me anything about the CFA Level 1 Curriculum!"
# Create the clear button before the check so it is rendered on every run;
# evaluating it on the right of `or` skipped it on a session's first run.
clear_history = st.sidebar.button("Clear message history")
if clear_history or "messages" not in st.session_state:
    st.session_state["messages"] = [AIMessage(content=starter_message)]


def send_feedback(run_id, score):
    """Attach a user rating to a traced LangSmith run.

    Args:
        run_id: Identifier of the agent run being rated.
        score: Numeric rating recorded under the ``"user_score"`` key
            (e.g. 1 for thumbs-up, 0 for thumbs-down).
    """
    feedback_key = "user_score"
    client.create_feedback(run_id, feedback_key, score=score)


# Replay the stored conversation into the UI and rebuild the agent memory.
for msg in st.session_state.messages:
    if isinstance(msg, AIMessage):
        role = "assistant"
    elif isinstance(msg, HumanMessage):
        role = "user"
    else:
        role = None
    # Messages with empty content (NOTE(review): presumably the function-call
    # messages the agent memory stores) would render as blank chat bubbles,
    # so they are kept in memory but not displayed.
    if role is not None and msg.content:
        st.chat_message(role).write(msg.content)
    # Every message, displayed or not, is fed back into the agent memory.
    memory.chat_memory.add_message(msg)


# Handle one user turn. The walrus target is named user_input so it does not
# shadow the agent's prompt template.
if user_input := st.chat_input(placeholder=starter_message):
    st.chat_message("user").write(user_input)
    with st.chat_message("assistant"):
        # Renders the agent's progress into this container while it runs.
        st_callback = StreamlitCallbackHandler(st.container())
        response = agent_executor(
            {"input": user_input, "history": st.session_state.messages},
            callbacks=[st_callback],
            include_run_info=True,
        )
        st.write(response["output"])
        # Record the exchange in the agent memory and persist its buffer as
        # the session history. This replaces st.session_state["messages"],
        # so appending to the old list first would have no lasting effect.
        memory.save_context({"input": user_input}, response)
        st.session_state["messages"] = memory.buffer
        # LangSmith run id; only used by the (currently disabled) feedback buttons.
        run_id = response["__run"].run_id

        col_blank, col_text, col1, col2 = st.columns([10, 2, 1, 1])
        with col_text:
            st.text("Feedback:")

        # NOTE: feedback buttons are currently disabled.
        # with col1:
        #     st.button("👍", on_click=send_feedback, args=(run_id, 1))

        # with col2:
        #     st.button("👎", on_click=send_feedback, args=(run_id, 0))