XINZHANG-Geotab committed on
Commit
2b4a9d2
·
1 Parent(s): 7cd74e0
Files changed (4) hide show
  1. README.md +3 -3
  2. app.py +197 -0
  3. requirements.txt +1 -0
  4. utils.py +41 -0
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: RelevancyApp
3
- emoji: 🏆
4
- colorFrom: green
5
- colorTo: red
6
  sdk: streamlit
7
  sdk_version: 1.35.0
8
  app_file: app.py
 
1
  ---
2
  title: RelevancyApp
3
+ emoji: 💻
4
+ colorFrom: red
5
+ colorTo: yellow
6
  sdk: streamlit
7
  sdk_version: 1.35.0
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import (
2
+ container_emphasis_css,
3
+ container_regular_css,
4
+ question_font,
5
+ modify_sql_guid_injection
6
+ )
7
+
8
+ import streamlit as st
9
+ import clipboard
10
+ import textwrap as tw
11
+ import pandas as pd
12
+ from streamlit_extras.stylable_container import stylable_container
13
+
14
@st.cache_data
def load_data(file):
    """Read *file* as CSV and return the resulting DataFrame.

    Wrapped in ``st.cache_data`` so repeated script reruns with the same
    input reuse the parsed frame instead of parsing the CSV again.
    """
    frame = pd.read_csv(file)
    return frame
17
+
18
+ st.set_page_config(layout="wide", initial_sidebar_state="expanded")
19
+ st.title("Relevancy Compare App")
20
+ st.divider()
21
+ st.markdown(question_font, unsafe_allow_html=True)
22
+
23
+
24
def edit_df_data_change():
    """Apply edits made in the ``edit_df_data`` data editor to session state.

    Used as the editor's ``on_change`` callback. Reads the pending edits from
    ``st.session_state.edit_df_data['edited_rows']`` (a mapping of row index
    to ``{column: new_value}``), makes the first edited row the current row,
    and writes the edited values back into ``st.session_state.df``.
    """
    edited_rows: dict = st.session_state.edit_df_data['edited_rows']
    if not edited_rows:
        # No pending edits: keep the current selection instead of letting
        # next(iter(...)) raise StopIteration.
        return

    # Remember the row we are leaving before jumping to the edited one
    # (the original assigned previous_row_index to itself, a no-op).
    st.session_state.previous_row_index = st.session_state.current_row_index
    st.session_state.current_row_index = next(iter(edited_rows))

    # Reset the existing 'Select' flags. The original used
    # assign(selected=False), which added a new 'selected' column instead.
    frame = st.session_state.df.assign(Select=False)
    frame.update(pd.DataFrame.from_dict(edited_rows, orient='index'))
    st.session_state.df = frame
31
+
32
+
33
# ---------------------------------------------------------------------------
# Page body.
#
# First run: show a CSV uploader. Once a file is provided, load it into
# st.session_state.df (adding 'Select' and 'Comment' columns), then let the
# user step through rows, compare target vs. predicted SQL, and save comments.
# ---------------------------------------------------------------------------
import html  # stdlib; used to escape CSV text before embedding it in HTML


def _cell_text(value):
    """Return a CSV cell as text, mapping missing values (NaN/None) to ''."""
    return '' if pd.isna(value) else str(value)


# Kept in a local variable rather than as an attribute on the `st` module:
# module attributes are process-wide, so they would be shared by every
# session served by this process.
need_rerun = False
if "df" not in st.session_state:
    uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
    need_rerun = True
else:
    uploaded_file = st.session_state.df

if uploaded_file is not None:
    with st.container(height=250):
        if "df" not in st.session_state:
            df = load_data(uploaded_file)
            if df.empty:
                # A header-only CSV has no row to display below.
                st.error("The uploaded CSV has no rows to review.")
                st.stop()
            st.session_state.df_shape = df.shape

            if "Comment" not in df.columns:
                df.insert(0, 'Comment', [''] * len(df))
                st.session_state.default_row_index = 0
            else:
                # fillna before astype(str) so missing comments become ''
                # rather than the literal text 'nan'.
                df['Comment'] = df['Comment'].fillna('').astype(str)
                # Start at the first row without a comment; fall back to
                # row 0 when every row already has one.
                st.session_state.default_row_index = next(
                    (i for i, text in enumerate(df['Comment']) if not text), 0
                )

            st.session_state.current_row_index = st.session_state.default_row_index
            st.session_state.previous_row_index = st.session_state.current_row_index

            df.insert(0, 'Select', [False] * len(df))
            df.at[st.session_state.current_row_index, 'Select'] = True
            st.session_state.comments_add = dict(enumerate(df["Comment"]))
            st.session_state['df'] = df

        # Placeholder column for the data editor, which is filled in at the
        # end of the script after the selection flags are up to date.
        df_col, = st.columns(1)

    st.session_state.previous_row_index = st.session_state.current_row_index
    with st.container():
        bcol1, bcol2, _ = st.columns([1, 1, 5])
        with bcol1:
            previous_row = st.button("Prev Row")
        with bcol2:
            next_row = st.button("Next Row")

    # previous_row_index was just set to the current row above, so only the
    # current index needs to move here.
    if previous_row and st.session_state.current_row_index > 0:
        st.session_state.current_row_index -= 1

    if next_row and st.session_state.current_row_index < st.session_state.df_shape[0] - 1:
        st.session_state.current_row_index += 1

    row_index = st.session_state.current_row_index
    row_data = st.session_state.df.loc[row_index]
    question = row_data['question']
    pred_relevant_choice = row_data['pred_relevant_choice']
    correct_response = row_data['correct_response']

    st.write(f"##### Row {row_index}/{st.session_state.df_shape[0]-1}")
    with st.container(border=True):
        # The question comes from the uploaded CSV and is rendered with
        # unsafe_allow_html, so escape it to keep '<', '>' and '&' literal.
        st.markdown(
            f'<p class="question-font">[Q]: {html.escape(_cell_text(question))}</p>',
            unsafe_allow_html=True,
        )
        select_by = st.selectbox("Mark By: ", ["pred_relevant_choice", "correct_response"])
        col1_res, col2_res, _ = st.columns([0.2, 0.2, 0.6])
        if select_by == 'pred_relevant_choice':
            col1_res.write(f":blue[pred_relevant_choice: {pred_relevant_choice}]")
            col2_res.write(f"correct_response: {correct_response}")
        else:
            col1_res.write(f"pred_relevant_choice: {pred_relevant_choice}")
            col2_res.write(f":blue[correct_response: {correct_response}]")

    # Highlight the container(s) named by the chosen column: 'first' -> target
    # only, 'both' -> both, any other value -> predicted only.
    # _cell_text() keeps a missing cell from crashing on .lower().
    marked = _cell_text(row_data[select_by]).lower()
    first_css = container_emphasis_css if marked in ('first', 'both') else container_regular_css
    second_css = container_regular_css if marked == 'first' else container_emphasis_css

    col1, col2 = st.columns(2)
    with col1:
        st.write("Target SQL:")
        with stylable_container(key="first_sql_container", css_styles=first_css):
            target_sql_text = modify_sql_guid_injection(_cell_text(row_data['target_sql']).strip())
            code_col, _ = st.columns([0.99, 0.01])
            with code_col:
                st.code(target_sql_text, language="sql")

    with col2:
        st.write("Predicted SQL:")
        with stylable_container(key="second_sql_container", css_styles=second_css):
            predicted_sql_text = modify_sql_guid_injection(_cell_text(row_data['predicted_sql']).strip())
            code_col, _ = st.columns([0.99, 0.01])
            with code_col:
                st.code(predicted_sql_text, language="sql")

    comment = st.text_area("Comment", value=row_data['Comment'])

    with st.container():
        bcol1, bcol2, *_ = st.columns(10)
        with bcol1:
            save = st.button("Save")
        with bcol2:
            download = st.button("Download")

    if save:
        st.session_state.comments_add[row_index] = comment
        st.session_state.df.at[row_index, 'Comment'] = comment
        # 'Select' is a UI-only helper column; keep it out of the saved file.
        df = st.session_state.df.drop('Select', axis=1)
        df.to_csv('modified.csv', index=False)

    if download:
        try:
            csv_file = open('modified.csv')
        except FileNotFoundError:
            # Nothing has been written by "Save" yet.
            st.warning("Nothing to download yet - click Save first.")
        else:
            with csv_file:
                st.download_button('Download modified CSV', csv_file, file_name='modified.csv')

    # Make the 'Select' column reflect the row currently being displayed.
    st.session_state.df['Select'] = False
    st.session_state.df.at[st.session_state.current_row_index, 'Select'] = True

    edited_df = df_col.data_editor(
        st.session_state.df,
        num_rows="fixed",
        use_container_width=True,
        key='edit_df_data',
        on_change=edit_df_data_change,
    )
    st.session_state.df = edited_df

    # On the run that first loaded the file, rerun once so the next pass
    # takes the "df already in session state" path and skips the uploader.
    if need_rerun:
        st.rerun()
197
+
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ streamlit-extras
utils.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Text-wrapping declarations shared by both SQL container styles. Kept in one
# place so the emphasised and regular variants cannot drift apart.
_CONTAINER_WRAP_RULES = """\
overflow-wrap: break-word; /* Added to ensure long words break */
word-wrap: break-word; /* Added for compatibility with older browsers */
word-break: break-all; /* Added to handle breaking of long strings */
white-space: pre-wrap; /* Added to respect new lines and wrap text */
"""


def _container_css(border_color, background_color):
    """Build the CSS rule body for a bordered container with the given colours."""
    return (
        "\n{\n"
        f"border: 5px solid {border_color};\n"
        "border-radius: 0.5rem;\n"
        "padding: 0em 1em 1em 0.5em;\n"
        f"background-color: {background_color};\n"
        f"{_CONTAINER_WRAP_RULES}"
        "}\n"
    )


# Highlighted container: solid grey border on a translucent green background.
container_emphasis_css = _container_css(
    "rgba(128, 128, 128, 1)", "rgba(0, 128, 0, 0.2)"
)

# Default container: faint border on a fully transparent background.
container_regular_css = _container_css(
    "rgba(0, 0, 0, 0.2)", "rgba(255, 255, 255, 0)"
)


# <style> block defining the `question-font` class used for the question text.
question_font = """
<style>
.question-font {
font-size: 16px !important;
font-weight: bold;
font-family: 'Arial', sans-serif;
}
</style>
"""
37
+
38
+
39
# Default clause that gets pushed onto its own line for display: the
# hard-coded CompanyGuid filter (including its trailing parenthesis).
_GUID_FILTER_CLAUSE = "where UPPER(CompanyGuid) = '14A328EE-1ABF-448C-9F86-8D4E05838F8D')"


def modify_sql_guid_injection(sql, injected=_GUID_FILTER_CLAUSE):
    """Insert a newline before every occurrence of *injected* in *sql*.

    Args:
        sql: SQL text to reformat.
        injected: Exact (case-sensitive) substring to move onto its own
            line. Defaults to the hard-coded CompanyGuid filter clause.

    Returns:
        *sql* with ``'\\n'`` inserted before each occurrence of *injected*;
        *sql* unchanged when the clause does not occur or *injected* is empty.
    """
    if not injected:
        # str.replace('', ...) matches between every character, which would
        # scatter newlines through the whole statement.
        return sql
    return sql.replace(injected, '\n' + injected)