hertogateis commited on
Commit
67b8485
·
verified ·
1 Parent(s): 2c08fd3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -172
app.py CHANGED
@@ -28,178 +28,6 @@ st.markdown('<p style="font-family:sans-serif;font-size: 0.7rem;text-align: righ
28
  st.markdown("<p style='font-family:sans-serif;font-size: 0.6rem;text-align: right;'>Pre-trained TAPAS model runs on max 64 rows and 32 columns data. Make sure the file data doesn't exceed these dimensions.</p>", unsafe_allow_html=True)
29
 
30
 
31
- # Initialize TAPAS pipeline
32
- tqa = pipeline(task="table-question-answering",
33
- model="google/tapas-large-finetuned-wtq",
34
- device="cpu")
35
-
36
- # Initialize T5 tokenizer and model for text generation
37
- t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
38
- t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
39
-
40
- # File uploader in the sidebar
41
- file_name = st.sidebar.file_uploader("Upload file:", type=['csv', 'xlsx'])
42
-
43
- # File processing and question answering
44
- if file_name is None:
45
-
46
- st.markdown('<p class="custom-font">Please click left side bar to upload an excel or csv file </p>', unsafe_allow_html=True)
47
- else:
48
- try:
49
- # Check file type and handle reading accordingly
50
- if file_name.name.endswith('.csv'):
51
- df = pd.read_csv(file_name, sep=';', encoding='ISO-8859-1') # Adjust encoding if needed
52
- elif file_name.name.endswith('.xlsx'):
53
- df = pd.read_excel(file_name, engine='openpyxl') # Use openpyxl to read .xlsx files
54
- else:
55
- st.error("Unsupported file type")
56
- df = None
57
-
58
- # Continue with further processing if df is loaded
59
- if df is not None:
60
- numeric_columns = df.select_dtypes(include=['object']).columns
61
- for col in numeric_columns:
62
- df[col] = pd.to_numeric(df[col], errors='ignore')
63
-
64
- st.write("Original Data:")
65
- st.write(df)
66
-
67
- # Create a copy for numerical operations
68
- df_numeric = df.copy()
69
- df = df.astype(str)
70
-
71
- # Display the first 5 rows of the dataframe in an editable grid
72
- grid_response = AgGrid(
73
- df.head(5),
74
- columns_auto_size_mode='FIT_CONTENTS',
75
- editable=True,
76
- height=300,
77
- width='100%',
78
- )
79
-
80
- except Exception as e:
81
- st.error(f"Error reading file: {str(e)}")
82
-
83
- # User input for the question
84
- question = st.text_input('Type your question')
85
-
86
- # Process the answer using TAPAS and T5
87
- with st.spinner():
88
- if st.button('Answer'):
89
- try:
90
- # Get the raw answer from TAPAS
91
- raw_answer = tqa(table=df, query=question, truncation=True)
92
-
93
- st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Raw Result From TAPAS: </p>",
94
- unsafe_allow_html=True)
95
- st.success(raw_answer)
96
-
97
- # Extract relevant information from the TAPAS result
98
- answer = raw_answer['answer']
99
- aggregator = raw_answer.get('aggregator', '')
100
- coordinates = raw_answer.get('coordinates', [])
101
- cells = raw_answer.get('cells', [])
102
-
103
- # Construct a base sentence replacing 'SUM' with the query term
104
- base_sentence = f"The {question.lower()} of the selected data is {answer}."
105
- if coordinates and cells:
106
- rows_info = [f"Row {coordinate[0] + 1}, Column '{df.columns[coordinate[1]]}' with value {cell}"
107
- for coordinate, cell in zip(coordinates, cells)]
108
- rows_description = " and ".join(rows_info)
109
- base_sentence += f" This includes the following data: {rows_description}."
110
-
111
- # Generate a fluent response using the T5 model, rephrasing the base sentence
112
- input_text = f"Given the question: '{question}', generate a more human-readable response: {base_sentence}"
113
-
114
- # Tokenize the input and generate a fluent response using T5
115
- inputs = t5_tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)
116
- summary_ids = t5_model.generate(inputs, max_length=150, num_beams=4, early_stopping=True)
117
-
118
- # Decode the generated text
119
- generated_text = t5_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
120
-
121
- # Display the final generated response
122
- st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Final Generated Response with LLM: </p>", unsafe_allow_html=True)
123
- st.success(generated_text)
124
-
125
- except Exception as e:
126
- st.warning("Please retype your question and make sure to use the column name and cell value correctly.")
127
-
128
- try:
129
- # Get raw answer again from TAPAS
130
- raw_answer = tqa(table=df, query=question, truncation=True)
131
-
132
- # Display raw result for debugging purposes
133
- st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Raw Result: </p>", unsafe_allow_html=True)
134
- st.success(raw_answer)
135
-
136
- # Processing the raw_answer
137
- processed_answer = raw_answer['answer'].replace(';', ' ') # Clean the answer text
138
- row_idx = raw_answer['coordinates'][0][0] # Row index from TAPAS
139
- col_idx = raw_answer['coordinates'][0][1] # Column index from TAPAS
140
- column_name = df.columns[col_idx] # Column name from the DataFrame
141
- row_data = df.iloc[row_idx].to_dict() # Row data corresponding to the row index
142
-
143
- # Handle different types of answers (e.g., 'SUM', 'MAX', 'MIN', 'AVG', etc.)
144
- if 'SUM' in processed_answer:
145
- summary_type = 'sum'
146
- numeric_value = df_numeric[column_name].sum()
147
- elif 'MAX' in processed_answer:
148
- summary_type = 'maximum'
149
- numeric_value = df_numeric[column_name].max()
150
- elif 'MIN' in processed_answer:
151
- summary_type = 'minimum'
152
- numeric_value = df_numeric[column_name].min()
153
- elif 'AVG' in processed_answer or 'AVERAGE' in processed_answer:
154
- summary_type = 'average'
155
- numeric_value = df_numeric[column_name].mean()
156
- elif 'COUNT' in processed_answer:
157
- summary_type = 'count'
158
- numeric_value = df_numeric[column_name].count()
159
- elif 'MEDIAN' in processed_answer:
160
- summary_type = 'median'
161
- numeric_value = df_numeric[column_name].median()
162
- elif 'STD' in processed_answer or 'STANDARD DEVIATION' in processed_answer:
163
- summary_type = 'std_dev'
164
- numeric_value = df_numeric[column_name].std()
165
- else:
166
- summary_type = 'value'
167
- numeric_value = processed_answer # In case of a general answer
168
-
169
- # Build a natural language response based on the aggregation type
170
- if summary_type == 'sum':
171
- natural_language_answer = f"The total {column_name} is {numeric_value}."
172
- elif summary_type == 'maximum':
173
- natural_language_answer = f"The highest {column_name} is {numeric_value}, recorded for '{row_data.get('Name', 'Unknown')}'."
174
- elif summary_type == 'minimum':
175
- natural_language_answer = f"The lowest {column_name} is {numeric_value}, recorded for '{row_data.get('Name', 'Unknown')}'."
176
- elif summary_type == 'average':
177
- natural_language_answer = f"The average {column_name} is {numeric_value}."
178
- elif summary_type == 'count':
179
- natural_language_answer = f"The number of entries in {column_name} is {numeric_value}."
180
- elif summary_type == 'median':
181
- natural_language_answer = f"The median {column_name} is {numeric_value}."
182
- elif summary_type == 'std_dev':
183
- natural_language_answer = f"The standard deviation of {column_name} is {numeric_value}."
184
- else:
185
- natural_language_answer = f"The {column_name} value is {numeric_value} for '{row_data.get('Name', 'Unknown')}'."
186
-
187
- # Display the final natural language answer
188
- st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Analysis Results: </p>", unsafe_allow_html=True)
189
- st.success(f"""
190
- • Answer: {natural_language_answer}
191
- Data Location:
192
- • Row: {row_idx + 1}
193
- • Column: {column_name}
194
- Additional Context:
195
- • Full Row Data: {row_data}
196
- • Query Asked: "{question}"
197
- """)
198
-
199
- except Exception as e:
200
- st.warning("Please retype your question and make sure to use the column name and cell value correctly.")
201
-
202
-
203
 
204
  # Initialize TAPAS pipeline
205
  tqa = pipeline(task="table-question-answering",
 
28
  st.markdown("<p style='font-family:sans-serif;font-size: 0.6rem;text-align: right;'>Pre-trained TAPAS model runs on max 64 rows and 32 columns data. Make sure the file data doesn't exceed these dimensions.</p>", unsafe_allow_html=True)
29
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  # Initialize TAPAS pipeline
33
  tqa = pipeline(task="table-question-answering",