hertogateis committed on
Commit
2c08fd3
·
verified ·
1 Parent(s): a8b1039

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +177 -3
app.py CHANGED
@@ -23,9 +23,183 @@ style = '''
23
  '''
24
  st.markdown(style, unsafe_allow_html=True)
25
 
26
- st.markdown('<p style="font-family:sans-serif;font-size: 1.9rem;"> HertogAI Table Q&A using TAPAS and Model Language</p>', unsafe_allow_html=True)
27
- st.markdown('<p style="font-family:sans-serif;font-size: 1.9rem;"> This code is based on Jordan Skinner. I recoded and enhanced it </p>', unsafe_allow_html=True)
28
- st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'>Pre-trained TAPAS model runs on max 64 rows and 32 columns data. Make sure the file data doesn't exceed these dimensions.</p>", unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  # Initialize TAPAS pipeline
31
  tqa = pipeline(task="table-question-answering",
 
23
  '''
24
  st.markdown(style, unsafe_allow_html=True)
25
 
# ---------------------------------------------------------------------------
# HertogAI Table Q&A: Streamlit front end around a TAPAS table-question-
# answering pipeline, with a T5 model used to rephrase the TAPAS result.
#
# Flow: render the page header -> load the models -> let the user upload a
# CSV/XLSX file -> show the table -> accept a question -> run TAPAS once and
# present (a) the raw result, (b) a T5 rephrasing, (c) a rule-based summary.
# ---------------------------------------------------------------------------

st.markdown('<p style="font-family:sans-serif;font-size: 1.5rem;text-align: right;"> HertogAI Table Q&A using TAPAS and Model Language</p>', unsafe_allow_html=True)
st.markdown('<p style="font-family:sans-serif;font-size: 0.7rem;text-align: right;"> This code is based on Jordan Skinner. I enhanced his work using Language Model T5</p>', unsafe_allow_html=True)
st.markdown("<p style='font-family:sans-serif;font-size: 0.6rem;text-align: right;'>Pre-trained TAPAS model runs on max 64 rows and 32 columns data. Make sure the file data doesn't exceed these dimensions.</p>", unsafe_allow_html=True)


# Initialize TAPAS pipeline
# NOTE(review): these loaders run at module top level, so they execute every
# time this script is executed. Consider wrapping them in a cached loader
# (e.g. st.cache_resource) if the deployed Streamlit version provides it.
tqa = pipeline(task="table-question-answering",
               model="google/tapas-large-finetuned-wtq",
               device="cpu")

# Initialize T5 tokenizer and model for text generation
t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")

# Shown whenever a question cannot be answered or post-processed.
_RETRY_HINT = "Please retype your question and make sure to use the column name and cell value correctly."

# (keywords to look for, pandas Series method to call, sentence template).
# Order matters: the first entry with a matching keyword wins.
_AGGREGATIONS = (
    (('SUM',), 'sum', "The total {column} is {value}."),
    (('MAX',), 'max', "The highest {column} is {value}, recorded for '{name}'."),
    (('MIN',), 'min', "The lowest {column} is {value}, recorded for '{name}'."),
    (('AVG', 'AVERAGE'), 'mean', "The average {column} is {value}."),
    (('COUNT',), 'count', "The number of entries in {column} is {value}."),
    (('MEDIAN',), 'median', "The median {column} is {value}."),
    (('STD', 'STANDARD DEVIATION'), 'std', "The standard deviation of {column} is {value}."),
)
# Used when no aggregation keyword is found: report the answer text itself.
_PLAIN_VALUE_TEMPLATE = "The {column} value is {value} for '{name}'."


def _load_table(uploaded_file):
    """Read an uploaded file into a DataFrame.

    Args:
        uploaded_file: Object returned by ``st.sidebar.file_uploader``; its
            ``name`` attribute decides which reader is used.

    Returns:
        The parsed DataFrame, or ``None`` when the extension is neither
        ``.csv`` nor ``.xlsx``.

    Raises:
        Whatever ``pd.read_csv`` / ``pd.read_excel`` raise on malformed input.
    """
    # Compare case-insensitively so that e.g. "DATA.CSV" is accepted too.
    lowered_name = uploaded_file.name.lower()
    if lowered_name.endswith('.csv'):
        # The CSV reader is fixed to ';' separators and ISO-8859-1 encoding.
        return pd.read_csv(uploaded_file, sep=';', encoding='ISO-8859-1')  # Adjust encoding if needed
    if lowered_name.endswith('.xlsx'):
        return pd.read_excel(uploaded_file, engine='openpyxl')  # Use openpyxl to read .xlsx files
    return None


def _coerce_numeric_columns(frame):
    """Convert object-dtype columns to numbers where every value parses.

    Columns that cannot be converted are left exactly as they were. This is
    the explicit equivalent of ``pd.to_numeric(..., errors='ignore')``, which
    pandas has deprecated.

    Args:
        frame: DataFrame to convert in place.

    Returns:
        The same DataFrame, for call chaining.
    """
    for column in frame.select_dtypes(include=['object']).columns:
        try:
            frame[column] = pd.to_numeric(frame[column])
        except (ValueError, TypeError):
            # Not a numeric column; keep the original text values.
            continue
    return frame


def _build_base_sentence(question, raw_answer, table):
    """Compose the plain sentence that is handed to T5 for rephrasing.

    Args:
        question: The user's question text.
        raw_answer: Result dict from the TAPAS pipeline; must contain
            ``'answer'`` and may contain ``'coordinates'`` and ``'cells'``.
        table: The string-typed DataFrame that was queried (used to map
            column indices to column names).

    Returns:
        A sentence stating the answer, followed by one clause per selected
        cell when both coordinates and cells are present.
    """
    answer = raw_answer['answer']
    coordinates = raw_answer.get('coordinates', [])
    cells = raw_answer.get('cells', [])

    base_sentence = f"The {question.lower()} of the selected data is {answer}."
    if coordinates and cells:
        rows_info = [f"Row {coordinate[0] + 1}, Column '{table.columns[coordinate[1]]}' with value {cell}"
                     for coordinate, cell in zip(coordinates, cells)]
        rows_description = " and ".join(rows_info)
        base_sentence += f" This includes the following data: {rows_description}."
    return base_sentence


def _rephrase_with_t5(question, base_sentence):
    """Ask the T5 model to rephrase ``base_sentence`` and return its output text."""
    input_text = f"Given the question: '{question}', generate a more human-readable response: {base_sentence}"

    # Tokenize the input and generate a response using T5
    inputs = t5_tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)
    summary_ids = t5_model.generate(inputs, max_length=150, num_beams=4, early_stopping=True)

    # Decode the generated text
    return t5_tokenizer.decode(summary_ids[0], skip_special_tokens=True)


def _describe_answer(raw_answer, table, numeric_table):
    """Turn a TAPAS result into a rule-based natural-language summary.

    Args:
        raw_answer: Result dict from the TAPAS pipeline.
        table: String-typed DataFrame that was queried.
        numeric_table: Copy of the table taken before the string conversion;
            aggregates are computed on this one.

    Returns:
        ``None`` when TAPAS selected no cells; otherwise a tuple
        ``(sentence, row_idx, column_name, row_data)`` describing the first
        selected cell.
    """
    coordinates = raw_answer.get('coordinates') or []
    if not coordinates:
        # Nothing was selected, so there is no row/column to describe.
        return None

    processed_answer = raw_answer['answer'].replace(';', ' ')  # Clean the answer text
    row_idx = coordinates[0][0]  # Row index from TAPAS
    col_idx = coordinates[0][1]  # Column index from TAPAS
    column_name = table.columns[col_idx]
    row_data = table.iloc[row_idx].to_dict()

    # Prefer the dedicated 'aggregator' field when the result carries one, so
    # that a cell value that merely contains e.g. "SUM" is not mistaken for
    # an aggregation. Fall back to scanning the answer text otherwise.
    # NOTE(review): assumes 'aggregator' holds the operator name as text
    # (e.g. "SUM", "NONE") -- confirm against the pipeline documentation.
    aggregator = raw_answer.get('aggregator')
    keyword_source = str(aggregator).upper() if aggregator else processed_answer

    value = processed_answer  # In case of a general answer
    template = _PLAIN_VALUE_TEMPLATE
    for keywords, method_name, candidate_template in _AGGREGATIONS:
        if any(keyword in keyword_source for keyword in keywords):
            # NOTE(review): the aggregate is taken over the whole column, not
            # only over the cells TAPAS selected -- confirm this is intended.
            value = getattr(numeric_table[column_name], method_name)()
            template = candidate_template
            break

    sentence = template.format(column=column_name,
                               value=value,
                               name=row_data.get('Name', 'Unknown'))
    return sentence, row_idx, column_name, row_data


# File uploader in the sidebar
file_name = st.sidebar.file_uploader("Upload file:", type=['csv', 'xlsx'])

# Both stay None until a table has been loaded successfully; the question UI
# below is only rendered once they are set.
df = None
df_numeric = None

# File processing
if file_name is None:
    st.markdown('<p class="custom-font">Please click left side bar to upload an excel or csv file </p>', unsafe_allow_html=True)
else:
    try:
        loaded = _load_table(file_name)
        if loaded is None:
            st.error("Unsupported file type")
        else:
            loaded = _coerce_numeric_columns(loaded)

            st.write("Original Data:")
            st.write(loaded)

            # Keep a typed copy for numerical operations; the table passed to
            # TAPAS is converted to strings.
            df_numeric = loaded.copy()
            df = loaded.astype(str)

            # Display the first 5 rows of the dataframe in an editable grid
            grid_response = AgGrid(
                df.head(5),
                columns_auto_size_mode='FIT_CONTENTS',
                editable=True,
                height=300,
                width='100%',
            )
    except Exception as e:
        st.error(f"Error reading file: {str(e)}")

# Question answering (only once a table is available)
if df is not None:
    # User input for the question
    question = st.text_input('Type your question')

    # Process the answer using TAPAS and T5
    with st.spinner():
        if st.button('Answer'):
            # Run TAPAS exactly once; every later stage reuses this result.
            raw_answer = None
            try:
                raw_answer = tqa(table=df, query=question, truncation=True)
            except Exception:
                st.warning(_RETRY_HINT)

            if raw_answer is not None:
                st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Raw Result From TAPAS: </p>",
                            unsafe_allow_html=True)
                st.success(raw_answer)

                # Stage 1: fluent rephrasing with T5. A failure here must not
                # prevent the rule-based summary below from being shown.
                try:
                    base_sentence = _build_base_sentence(question, raw_answer, df)
                    generated_text = _rephrase_with_t5(question, base_sentence)

                    st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Final Generated Response with LLM: </p>", unsafe_allow_html=True)
                    st.success(generated_text)
                except Exception:
                    st.warning(_RETRY_HINT)

                # Stage 2: rule-based summary of the first selected cell.
                try:
                    description = _describe_answer(raw_answer, df, df_numeric)
                    if description is None:
                        st.warning(_RETRY_HINT)
                    else:
                        natural_language_answer, row_idx, column_name, row_data = description

                        st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Analysis Results: </p>", unsafe_allow_html=True)
                        st.success("\n".join([
                            "",
                            f"• Answer: {natural_language_answer}",
                            "Data Location:",
                            f"• Row: {row_idx + 1}",
                            f"• Column: {column_name}",
                            "Additional Context:",
                            f"• Full Row Data: {row_data}",
                            f"• Query Asked: \"{question}\"",
                            "",
                        ]))
                except Exception:
                    st.warning(_RETRY_HINT)

203
 
204
  # Initialize TAPAS pipeline
205
  tqa = pipeline(task="table-question-answering",