File size: 5,106 Bytes
92808fd
 
2563fc9
92808fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93b6d4c
92808fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8191bb
f8b532e
 
 
 
 
 
92808fd
a8191bb
92808fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93b6d4c
 
 
 
 
 
92808fd
93b6d4c
 
 
 
92808fd
 
93b6d4c
 
 
 
 
 
 
 
 
f8b532e
93b6d4c
 
 
 
 
 
 
92808fd
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import numpy as np
import streamlit_scrollable_textbox as stx
from ast import literal_eval
import pinecone
import streamlit as st

st.set_page_config(layout="wide")  # isort: split


from utils import nltkmodules
from utils.models import (
    get_bm25_model,
    tokenizer,
    get_data,
    get_instructor_embedding_model,
    preprocess_text,
)
from utils.retriever import (
    query_pinecone,
    format_context,
    format_query,
    get_bm25_search_hits,
    retrieve_transcript,
)


# Page header and a one-line description of what this demo compares.
st.title("Instructor XL Embeddings")


st.write(
    "The app compares the performance of the Instructor-XL Embedding Model on the text from AMD's Q1 2020 Earnings Call Transcript."
)

# Transcript data loaded via the project helper (presumably cached by
# utils.models — TODO confirm).
data = get_data()


# Two equal-width columns: col1 holds the query inputs, col2 the
# index selection form and retrieved results.
col1, col2 = st.columns([3, 3], gap="medium")

# Client for the remote Instructor-XL embedding endpoint; its .predict()
# is called further below to embed the user's query.
instructor_model = get_instructor_embedding_model()

# Preset example questions offered in the "Question Examples" dropdown.
question_choice = [
    "What was discussed regarding Ryzen revenue performance?",
    "What is the impact of the enterprise and cloud on AMD's growth",
    "What was the impact of situation in China on the sales and revenue?",
]

# Preset Instructor-style embedding instructions for the QUERY side
# (the document-side instructions are defined separately below).
question_instruction_choice = [
    "Represent the financial question for retrieving supporting documents:",
    "Represent the financial question for retrieving supporting sentences:",
    "Represent the finance query for retrieving supporting documents:",
    "Represent the finance query for retrieving related documents:",
    "Represent a finance query for retrieving relevant documents:",
]


# Left column: the user picks (or types) a question and the embedding
# instruction used for the query side, plus how many results to fetch.
with col1:
    st.subheader("Question")
    st.write(
        "Choose a preset question example from the dropdown or enter a question in the text box."
    )
    default_query = st.selectbox("Question Examples", question_choice)

    # Free-text box pre-filled with the selected preset; the typed value
    # (query_text) is what is actually embedded and searched.
    query_text = st.text_area(
        "Question",
        value=default_query,
    )

    st.subheader("Question Embedding-Instruction")
    # FIX: corrected user-facing grammar ("a instruction" -> "an instruction").
    st.write(
        "Choose a preset instruction example from the dropdown or enter an instruction in the text box."
    )
    default_query_embedding_instruction = st.selectbox(
        "Question Embedding-Instruction Examples", question_instruction_choice
    )

    query_embedding_instruction = st.text_area(
        "Question Embedding-Instruction",
        value=default_query_embedding_instruction,
    )

    # Top-k for the Pinecone query; bounded to 1..15, default 5.
    num_results = int(
        st.number_input("Number of Results to query", 1, 15, value=5)
    )

# Sparse (lexical) pre-filter: score the whole corpus with BM25 and keep
# the ids of the top-50 hits; these are later passed to Pinecone to
# restrict the dense search.
corpus, bm25 = get_bm25_model(data)

# Tokenize the query the same way the corpus was preprocessed.
tokenized_query = preprocess_text(query_text).split()
# argsort ascending then reverse -> corpus positions ordered by
# descending BM25 score.
sparse_scores = np.argsort(bm25.get_scores(tokenized_query), axis=0)[::-1]
# Map the top-50 scored positions to document ids (helper semantics
# defined in utils.retriever — TODO confirm).
indices = get_bm25_search_hits(corpus, sparse_scores, 50)


# Embed the query with the remote Instructor-XL endpoint. The endpoint
# returns the embedding serialized as a string (e.g. "[0.12, -0.03, ...]"),
# so parse it back into a Python list with literal_eval (safe: it only
# accepts Python literals, never arbitrary expressions).
dense_embedding_output = literal_eval(
    instructor_model.predict(
        query_embedding_instruction,
        query_text,
        api_name="/predict",
    )
)

# BUG FIX: embedding components are floats in [-1, 1]-ish range; the
# previous `int(x)` cast truncated nearly every component to 0, destroying
# the dense vector sent to Pinecone. Cast to float instead.
dense_embedding = [float(x) for x in dense_embedding_output]

# Maps each document-side embedding instruction to the Pinecone index that
# was built with that instruction. Defined first so the dropdown choices
# can be derived from its keys (single source of truth — previously the
# list and the dict keys were duplicated by hand and had to be kept in
# sync manually).
index_mapping = {
    "Represent the financial statement for retrieval:": "week14-instructor-xl-amd-fsr-1",
    "Represent the financial document for retrieval:": "week14-instructor-xl-amd-fdr-2",
    "Represent the finance passage for retrieval:": "week14-instructor-xl-amd-fpr-3",
    "Represent the earnings call transcript for retrieval:": "week14-instructor-xl-amd-ectr-4",
    "Represent the earnings call transcript sentence for retrieval:": "week14-instructor-xl-amd-ects-5",
    "Represent the earnings call transcript answer for retrieval:": "week14-instructor-xl-amd-ecta-6",
}

# Dropdown options shown in the UI — exactly the mapping's keys, in
# insertion order (guaranteed for dicts since Python 3.7).
text_embedding_instructions_choice = list(index_mapping)

# Right column: choose which Pinecone index (one per document-side
# instruction) to query, then run the hybrid search on submit.
with col2:
    with st.form("my_form"):
        text_embedding_instruction = st.selectbox(
            "Select instruction for Text Embedding",
            text_embedding_instructions_choice,
        )

        # Each index has its own API key stored under
        # st.secrets["pinecone_<index-name>"].
        pinecone_index_name = index_mapping[text_embedding_instruction]
        pinecone.init(
            api_key=st.secrets[f"pinecone_{pinecone_index_name}"],
            environment="asia-southeast1-gcp-free",
        )

        pinecone_index = pinecone.Index(pinecone_index_name)

        submitted = st.form_submit_button("Submit")
        if submitted:
            # Dense query restricted to the BM25 top-50 ids computed above.
            matches = query_pinecone(
                dense_embedding, num_results, pinecone_index, indices
            )
            context = format_query(matches)
            output_text = format_context(context)

            st.subheader("Retrieved Text:")
            # FIX: removed the no-op `output = f"""{output}"""` reassignment;
            # it produced the exact same string and had no effect.
            for output in output_text:
                st.write(
                    f"<ul><li><p>{output}</p></li></ul>",
                    unsafe_allow_html=True,
                )

# BUG FIX: st.tabs() returns a LIST of tab containers even for a single
# label; the previous code bound the list itself, so `with tab1:` raised
# (a list is not a context manager). Unpack the single tab instead.
(tab1,) = st.tabs(["View transcript"])


with tab1:
    # Show the full source transcript inside an expander, using a
    # scrollable text box so the page stays a manageable height.
    file_text = retrieve_transcript()
    with st.expander("See Transcript"):
        st.subheader("AMD Q1 2020 Earnings Call Transcript:")
        stx.scrollableTextbox(
            file_text, height=700, border=False, fontFamily="Helvetica"
        )