File size: 5,468 Bytes
5621d9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47351aa
 
5621d9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47351aa
5621d9a
 
 
 
47351aa
5621d9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47351aa
5621d9a
 
 
 
 
47351aa
5621d9a
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import streamlit as st
import langchain
from langchain.document_loaders import OnlinePDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Pinecone
from langchain.embeddings.openai import OpenAIEmbeddings
import pinecone

st.sidebar.markdown(" # Welcome to Ztudy ")

# ------------------------ PDF ------------------------
# Selectable PDFs, keyed by the label shown in the sidebar selectbox.
# Insertion order is the order the labels are listed in.
# Hard-coded PDFs (TODO: make this dynamic from Google Drive)
pdf_dict = {
    "Field Guide to Data Science": "https://wolfpaulus.com/wp-content/uploads/2017/05/field-guide-to-data-science.pdf",
    "2023 GPT-4 Technical Report": "https://cdn.openai.com/papers/gpt-4.pdf",
    "Administering Data Centers": "https://drive.google.com/file/d/1r3bqHq-ZszXnX6UJLOaeoEEa1plUYXZu",
    "First Aid Reference Guide (Google)": "https://drive.google.com/file/d/1fzN2wa_uJ8INUYim88eCymSvJdyDT2fz/",
    "First Aid Reference Guide (Public)": "https://www.sja.ca/sites/default/files/2021-05/First%20aid%20reference%20guide_V4.1_Public.pdf",
    "Astronomy 2106": "https://drive.google.com/file/d/1XXmjMLENP90-eXEqOaTxQ8O56ZwExsVT",
    "Astronomy 2106 (New)": "https://drive.google.com/file/d/1w1S-TY2PzeJ9mjPVb1yLwcYh5EI44oP7",
    "Learning Deep Learning: Chapter 1": "https://drive.google.com/file/d/1o7feaKFzXd5-95GffZyynAwY_fzGafhr/view?usp=sharing",
}

# -------------------- Globals ------------------------
# Chunks produced by the most recent load_pdf() call; None until a PDF is loaded.
texts = None
# Pinecone index that chunks are upserted into and queried from.
pinecone_index = "group-1"

# Seed per-session state only when the keys are not already present.
if "exchanges" not in st.session_state:
    st.session_state["exchanges"] = []  # chat history: dicts with "role"/"content"
if "temperature" not in st.session_state:
    st.session_state["temperature"] = 0.5  # initial value for the temperature slider

# -------------------- Functions -----------------------
def console_log(msg):
    """Show *msg* in the Streamlit sidebar, which the app uses as its on-screen log."""
    sidebar = st.sidebar
    sidebar.write(msg)

def init_pinecone():
    """Initialise the Pinecone client using credentials from ``st.secrets``.

    Reads ``PINECONE_API_KEY`` (found at app.pinecone.io) and
    ``PINECONE_API_ENV`` (shown next to the API key in the console).
    """
    secrets = st.secrets
    api_key = secrets["PINECONE_API_KEY"]
    environment = secrets["PINECONE_API_ENV"]
    pinecone.init(api_key=api_key, environment=environment)
 
def load_vector_database(documents=None):
    """Embed document chunks with OpenAI and upsert them into the Pinecone index.

    Args:
        documents: Chunks to index (objects exposing ``page_content``). Defaults
            to the module-level ``texts`` populated by ``load_pdf``.

    Raises:
        ValueError: if ``documents`` is omitted and no PDF has been loaded yet
            (``texts`` is still ``None``).
    """
    if documents is None:
        documents = texts
    if documents is None:
        # Without this check, len(None) below raises an opaque TypeError.
        raise ValueError("No documents to index: load a PDF before building the vector database.")
    print(f"Number of vectors: {len(documents)} to be upserted to Index: {pinecone_index}")
    if not documents:
        # Nothing to embed: skip client setup and API calls entirely.
        return
    embeddings = OpenAIEmbeddings(openai_api_key=st.secrets["OPENAI_API_KEY"])
    init_pinecone()
    Pinecone.from_texts([doc.page_content for doc in documents], embeddings, index_name=pinecone_index)

def load_pdf(url, chunk_size=1000, chunk_overlap=0):
    """Download a PDF, split it into chunks, and index the chunks in Pinecone.

    Progress is reported in the sidebar via ``console_log``. The chunks are stored
    in the module-level ``texts`` and then indexed by ``load_vector_database``.

    Args:
        url: PDF location handed to ``OnlinePDFLoader``.
        chunk_size: Maximum chunk size for ``RecursiveCharacterTextSplitter``.
        chunk_overlap: Overlap between consecutive chunks.
    """
    global texts
    console_log(f"Loading {url}")
    data = OnlinePDFLoader(url).load()
    console_log(f'You have {len(data)} document(s) in your data')
    if not data:
        # data[0] below would raise IndexError, and there is nothing to index.
        console_log("No content could be extracted from this PDF; nothing was indexed.")
        return
    console_log(f'There are {len(data[0].page_content)} characters in your document')
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    texts = text_splitter.split_documents(data)
    console_log(f'After splitting, you have {len(texts)} documents')
    load_vector_database()

def chat(query, temperature, k=4):
    """Answer *query* with an OpenAI LLM, grounded in similar chunks from Pinecone.

    Args:
        query: Text used both for the similarity search and as the question for
            the QA chain (the app passes the formatted conversation transcript).
        temperature: Sampling temperature forwarded to the ``OpenAI`` LLM.
        k: Number of chunks to retrieve (4 is LangChain's ``similarity_search``
            default, so omitting it preserves the previous behaviour).

    Returns:
        The output of the "stuff" question-answering chain.
    """
    # Imported locally, only when a question is actually asked.
    from langchain.llms import OpenAI
    from langchain.chains.question_answering import load_qa_chain

    openai_key = st.secrets["OPENAI_API_KEY"]
    llm = OpenAI(temperature=temperature, openai_api_key=openai_key)
    chain = load_qa_chain(llm, chain_type="stuff")

    embeddings = OpenAIEmbeddings(openai_api_key=openai_key)
    init_pinecone()
    vector_store = Pinecone.from_existing_index(pinecone_index, embeddings)
    docs = vector_store.similarity_search(query, k=k, include_metadata=True)

    # Comment/Uncomment to hide/show trace of documents
    with st.expander("See documents for embedding"):
        for doc in docs:
            st.write(doc)

    return chain.run(input_documents=docs, question=query)

def format_exchanges(exchanges):
    """Render the conversation history in the main Streamlit area.

    "user" and "assistant" messages get a three-column row (icon, text, spacer),
    mirrored so the user icon is on the left and the assistant icon on the right,
    followed by a horizontal rule. Any other role is rendered as a single
    ``*role:* content`` markdown line with no icon or divider.

    Args:
        exchanges: Sequence of dicts with ``"role"`` and ``"content"`` keys. For
            "user"/"assistant" the image ``icon_<role>.png`` is displayed
            (presumably resolved relative to the working directory — confirm).
    """
    for exchange in exchanges:
        role = exchange["role"]
        content = exchange["content"]
        if role == "user":
            icon, text, _ = st.columns([1, 8, 1])
        elif role == "assistant":
            _, text, icon = st.columns([1, 8, 1])
        else:
            st.markdown("*" + role + ":* " + content)
            continue

        with icon:
            st.image("icon_" + role + ".png", width=50)
        with text:
            st.markdown(content)
        st.markdown("---")

def format_prompt(exchanges, window=7):
    """Build a ``[Q]``/``[A]`` transcript of the most recent exchanges for the LLM.

    Only the last ``window`` entries are included. With the default of 7 and
    alternating user/assistant entries, that is the current question plus the
    three preceding question/answer pairs. The transcript is also shown in an
    expander for debugging.

    Args:
        exchanges: List of dicts with ``"role"`` and ``"content"`` keys; entries
            whose role is ``"user"`` are tagged ``[Q]``, all others ``[A]``.
        window: Maximum number of trailing entries to include.

    Returns:
        One ``"<tag>: <content>\\n"`` line per included entry, concatenated.
    """
    recent = exchanges[max(len(exchanges) - window, 0):]
    # join avoids the quadratic cost of repeated string +=.
    prompt = "".join(
        ("[Q]" if exchange["role"] == "user" else "[A]") + ": " + exchange["content"] + "\n"
        for exchange in recent
    )
    with st.expander("See prompt sent to LLM"):
        st.write(prompt)
    return prompt

# ------------------------ Load PDF ------------------------
# Sidebar controls: pick a PDF from pdf_dict, then index it with load_pdf,
# which the button registers as its on_click callback with the chosen URL.
with st.sidebar:
    pdf_labels = list(pdf_dict)
    option = st.selectbox("Select a PDF", pdf_labels, key="pdf")
    st.markdown(f"*Selected*: {option}")
    selected_url = pdf_dict[option]
    st.button('Click to start loading PDF', key="load_pdf", on_click=load_pdf, args=[selected_url])

# ------------------------ Chatbot ------------------------
st.slider("Temperature (0 = Most Deterministic)", min_value=0.0, max_value=1.0, step=0.1, key="temperature")
st.text_input("Prompt", placeholder="Ask me anything", key="prompt")

# Streamlit re-runs this script on every widget interaction (moving the slider,
# clicking the load button, ...) and the text input keeps its value across
# reruns. Without tracking what was already answered, the same question would
# be appended to the history and re-sent to the LLM on each of those reruns.
new_prompt = st.session_state.prompt
if not new_prompt:
    # An emptied input box allows the same question to be asked again later.
    st.session_state.last_answered_prompt = None
elif new_prompt != st.session_state.get("last_answered_prompt"):
    st.session_state.exchanges.append({"role": "user", "content": new_prompt})
    try:
        response = chat(format_prompt(st.session_state.exchanges), st.session_state.temperature)
    except Exception as e:
        # Drop the unanswered question so the history keeps alternating Q/A.
        st.session_state.exchanges.pop()
        st.error(e)
        st.stop()
    st.session_state.exchanges.append({"role": "assistant", "content": response})
    st.session_state.last_answered_prompt = new_prompt

# Render the whole conversation on every rerun, not only when a new question arrives.
if st.session_state.exchanges:
    format_exchanges(st.session_state.exchanges)