oliviercaron committed
Commit 5af24b5 · verified · 1 Parent(s): 30f8888

Update app.py

Files changed (1):
  1. app.py +310 -343
app.py CHANGED
@@ -1,343 +1,310 @@
- import os
- import csv
- import streamlit as st
- import polars as pl
- from io import BytesIO, StringIO
- from gliner import GLiNER
- from gliner_file import run_ner
- import time
- import torch
- import platform
- from typing import List
- from streamlit_tags import st_tags  # Importing the st_tags component
-
- # Streamlit page configuration
- st.set_page_config(
-     page_title="GLiNER",
-     page_icon="🔥",
-     layout="wide",
-     initial_sidebar_state="expanded"
- )
-
- # Function to load data from the uploaded file
- @st.cache_data
- def load_data(file):
-     """
-     Loads an uploaded CSV or Excel file with resilient detection of delimiters and types.
-     """
-     with st.spinner("Loading data, please wait..."):
-         try:
-             _, file_ext = os.path.splitext(file.name)
-             if file_ext.lower() in [".xls", ".xlsx"]:
-                 return load_excel(file)
-             elif file_ext.lower() == ".csv":
-                 return load_csv(file)
-             else:
-                 raise ValueError("Unsupported file format. Please upload a CSV or Excel file.")
-         except Exception as e:
-             st.error("Error loading data:")
-             st.error(str(e))
-             return None
-
- def load_excel(file):
-     """
-     Loads an Excel file using `BytesIO` and `polars` for reduced latency.
-     """
-     try:
-         # Load the file into BytesIO for faster reading
-         file_bytes = BytesIO(file.read())
-
-         # Load the Excel file using `polars`
-         df = pl.read_excel(file_bytes, read_options={"ignore_errors": True})
-         return df
-     except Exception as e:
-         raise ValueError(f"Error reading the Excel file: {str(e)}")
-
- def load_csv(file):
-     """
-     Loads a CSV file by detecting the delimiter and using the quote character to handle internal delimiters.
-     """
-     try:
-         file.seek(0)  # Reset file pointer to ensure reading from the beginning
-         raw_data = file.read()
-
-         # Try decoding as UTF-8, else as Latin-1
-         try:
-             file_content = raw_data.decode('utf-8')
-         except UnicodeDecodeError:
-             try:
-                 file_content = raw_data.decode('latin1')
-             except UnicodeDecodeError:
-                 raise ValueError("Unable to decode the file. Ensure it is encoded in UTF-8 or Latin-1.")
-
-         # List of common delimiters
-         delimiters = [",", ";", "|", "\t", " "]
-
-         # Try each delimiter until one works
-         for delimiter in delimiters:
-             try:
-                 # Read CSV with current delimiter and handle quoted fields
-                 df = pl.read_csv(
-                     StringIO(file_content),
-                     separator=delimiter,
-                     quote_char='"',  # Handle internal delimiters with quotes
-                     try_parse_dates=True,
-                     ignore_errors=True,  # Ignore errors for invalid values
-                     truncate_ragged_lines=True
-                 )
-                 # Return the DataFrame if loading succeeds
-                 return df
-             except Exception:
-                 continue  # Move to the next delimiter in case of error
-
-         # If no delimiter worked
-         raise ValueError("Unable to load the file with common delimiters.")
-     except Exception as e:
-         raise ValueError(f"Error reading the CSV file: {str(e)}")
-
- # Function to load the GLiNER model
- @st.cache_resource
- def load_model():
-     """
-     Loads the GLiNER model into memory to avoid multiple reloads.
-     """
-     try:
-         gpu_available = torch.cuda.is_available()
-
-         with st.spinner("Loading the GLiNER model... Please wait."):
-             device = torch.device("cuda" if gpu_available else "cpu")
-             model = GLiNER.from_pretrained(
-                 "urchade/gliner_multi-v2.1"
-             ).to(device)
-             model.eval()
-
-         if gpu_available:
-             device_name = torch.cuda.get_device_name(0)
-             st.success(f"GPU detected: {device_name}. Model loaded on GPU.")
-         else:
-             cpu_name = platform.processor()
-             st.warning(f"No GPU detected. Using CPU: {cpu_name}")
-
-         return model
-     except Exception as e:
-         st.error("Error loading the model:")
-         st.error(str(e))
-         return None
-
- # Function to perform NER and update the user interface
- def perform_ner(filtered_df, selected_column, labels_list, threshold):
-     """
-     Executes named entity recognition (NER) on the filtered data.
-     """
-     try:
-         texts_to_analyze = filtered_df[selected_column].to_list()
-         total_rows = len(texts_to_analyze)
-         ner_results_list = []
-
-         # Initialize progress bar and text
-         progress_bar = st.progress(0)
-         progress_text = st.empty()
-         start_time = time.time()
-
-         # Process each row individually to keep progress updates responsive
-         for index, text in enumerate(texts_to_analyze, 1):
-             if st.session_state.stop_processing:
-                 progress_text.text("Processing stopped by user.")
-                 break
-
-             ner_results = run_ner(
-                 st.session_state.gliner_model,
-                 [text],
-                 labels_list,
-                 threshold=threshold
-             )
-             ner_results_list.append(ner_results)
-
-             # Update progress bar and text after each row
-             progress = index / total_rows
-             elapsed_time = time.time() - start_time
-             progress_bar.progress(progress)
-             progress_text.text(f"Progress: {index}/{total_rows} - {progress * 100:.0f}% (Elapsed time: {elapsed_time:.2f}s)")
-
-         # Add NER results to the DataFrame
-         for label in labels_list:
-             extracted_entities = []
-             for entities in ner_results_list:
-                 texts = [entity["text"] for entity in entities[0] if entity["label"] == label]
-                 concatenated_texts = ", ".join(texts) if texts else ""
-                 extracted_entities.append(concatenated_texts)
-             filtered_df = filtered_df.with_columns(pl.Series(name=label, values=extracted_entities))
-
-         end_time = time.time()
-         st.success(f"Processing completed in {end_time - start_time:.2f} seconds.")
-
-         return filtered_df
-     except Exception as e:
-         st.error(f"Error during NER processing: {str(e)}")
-         return filtered_df
-
- # Main function to run the Streamlit application
- def main():
-     st.title("Use NER with GliNER on your data file")
-     st.markdown("Prototype v0.1")
-
-     # User instructions
-     st.write("""
-     This application performs named entity recognition (NER) on your text data using GLiNER.
-
-     **Instructions:**
-     1. Upload a CSV or Excel file.
-     2. Select the column containing the text to analyze.
-     3. Filter the data if necessary.
-     4. Enter the NER labels you wish to detect.
-     5. Click "Start NER" to begin processing.
-     """)
-
-     # Initializing session state variables
-     if "stop_processing" not in st.session_state:
-         st.session_state.stop_processing = False
-     if "threshold" not in st.session_state:
-         st.session_state.threshold = 0.4
-     if "labels_list" not in st.session_state:
-         st.session_state.labels_list = []
-
-     # Load the model
-     st.session_state.gliner_model = load_model()
-     if st.session_state.gliner_model is None:
-         return
-
-     # File upload
-     uploaded_file = st.sidebar.file_uploader("Choose a file (CSV or Excel)")
-     if uploaded_file is None:
-         st.warning("Please upload a file to continue.")
-         return
-
-     # Loading data
-     df = load_data(uploaded_file)
-     if df is None:
-         return
-
-     # Column selection
-     selected_column = st.selectbox("Select the column containing the text:", df.columns)
-
-     # Data filtering
-     filter_text = st.text_input("Filter the column by text", "")
-     if filter_text:
-         filtered_df = df.filter(pl.col(selected_column).str.contains(f"(?i).*{filter_text}.*"))
-     else:
-         filtered_df = df
-
-     st.write("Filtered data preview:")
-
-     # Rows per page
-     rows_per_page = 100
-
-     # Calculate total rows and pages
-     total_rows = len(filtered_df)
-     total_pages = (total_rows - 1) // rows_per_page + 1
-
-     # Initialize current page in session_state
-     if "current_page" not in st.session_state:
-         st.session_state.current_page = 1
-
-     # Function to update page
-     def update_page(new_page):
-         st.session_state.current_page = new_page
-
-     # Pagination buttons
-     col1, col2, col3, col4, col5 = st.columns(5)
-
-     with col1:
-         first = st.button("⏮️ First")
-     with col2:
-         previous = st.button("⬅️ Previous")
-     with col3:
-         pass  # Page number display will be done after
-     with col4:
-         next = st.button("Next ➡️")
-     with col5:
-         last = st.button("Last ⏭️")
-
-     # Button clicks management
-     if first:
-         update_page(1)
-     elif previous:
-         if st.session_state.current_page > 1:
-             update_page(st.session_state.current_page - 1)
-     elif next:
-         if st.session_state.current_page < total_pages:
-             update_page(st.session_state.current_page + 1)
-     elif last:
-         update_page(total_pages)
-
-     # Now display the page number after updating
-     with col3:
-         st.markdown(f"Page **{st.session_state.current_page}** of **{total_pages}**")
-
-     # Calculate indices for pagination
-     start_idx = (st.session_state.current_page - 1) * rows_per_page
-     end_idx = min(start_idx + rows_per_page, total_rows)
-
-     # Check if the filtered DataFrame is empty
-     if not filtered_df.is_empty():
-         # Retrieve current page data
-         current_page_data = filtered_df.slice(start_idx, end_idx - start_idx)
-         st.write(f"Displaying {start_idx + 1} to {end_idx} of {total_rows} rows")
-         st.dataframe(current_page_data.to_pandas(), use_container_width=True)
-     else:
-         st.warning("The filtered DataFrame is empty. Please check your filters.")
-
-     # Confidence threshold slider
-     st.slider("Set confidence threshold", 0.0, 1.0, st.session_state.threshold, 0.01, key="threshold")
-
-     # Buttons to start and stop NER
-     col1, col2 = st.columns(2)
-     with col1:
-         start_button = st.button("Start NER")
-     with col2:
-         stop_button = st.button("Stop")
-
-     if start_button:
-         st.session_state.stop_processing = False
-
-         if not st.session_state.labels_list:
-             st.warning("Please enter labels for NER.")
-         else:
-             # Run NER
-             updated_df = perform_ner(filtered_df, selected_column, st.session_state.labels_list, st.session_state.threshold)
-             st.write("**NER Results:**")
-             st.dataframe(updated_df.to_pandas(), use_container_width=True)
-
-             # Function to convert DataFrame to Excel
-             def to_excel(df):
-                 output = BytesIO()
-                 df.write_excel(output)
-                 return output.getvalue()
-
-             # Function to convert DataFrame to CSV
-             def to_csv(df):
-                 return df.write_csv().encode('utf-8')
-
-             # Download buttons for results
-             download_col1, download_col2 = st.columns(2)
-             with download_col1:
-                 st.download_button(
-                     label="📥 Download as Excel",
-                     data=to_excel(updated_df),
-                     file_name="ner_results.xlsx",
-                     mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
-                 )
-             with download_col2:
-                 st.download_button(
-                     label="📥 Download as CSV",
-                     data=to_csv(updated_df),
-                     file_name="ner_results.csv",
-                     mime="text/csv",
-                 )
-
-     if stop_button:
-         st.session_state.stop_processing = True
-         st.warning("Processing stopped by user.")
-
- if __name__ == "__main__":
-     main()
+ import os
+ import csv
+ import streamlit as st
+ import polars as pl
+ from io import BytesIO, StringIO
+ from gliner import GLiNER
+ from gliner_file import run_ner
+ import time
+ import torch
+ import platform
+ from typing import List
+ from streamlit_tags import st_tags  # Importing the st_tags component for labels
+
+ # Streamlit page configuration
+ st.set_page_config(
+     page_title="GLiNER",
+     page_icon="🔥",
+     layout="wide",
+     initial_sidebar_state="expanded"
+ )
+
+ # Function to load data from the uploaded file
+ @st.cache_data
+ def load_data(file):
+     """
+     Loads an uploaded CSV or Excel file with resilient detection of delimiters and types.
+     """
+     with st.spinner("Loading data, please wait..."):
+         try:
+             _, file_ext = os.path.splitext(file.name)
+             if file_ext.lower() in [".xls", ".xlsx"]:
+                 return load_excel(file)
+             elif file_ext.lower() == ".csv":
+                 return load_csv(file)
+             else:
+                 raise ValueError("Unsupported file format. Please upload a CSV or Excel file.")
+         except Exception as e:
+             st.error("Error loading data:")
+             st.error(str(e))
+             return None
+
+ def load_excel(file):
+     """
+     Loads an Excel file using `BytesIO` and `polars` for reduced latency.
+     """
+     try:
+         file_bytes = BytesIO(file.read())
+         df = pl.read_excel(file_bytes, read_options={"ignore_errors": True})
+         return df
+     except Exception as e:
+         raise ValueError(f"Error reading the Excel file: {str(e)}")
+
+ def load_csv(file):
+     """
+     Loads a CSV file by detecting the delimiter and using the quote character to handle internal delimiters.
+     """
+     try:
+         file.seek(0)  # Reset file pointer to ensure reading from the beginning
+         raw_data = file.read()
+
+         try:
+             file_content = raw_data.decode('utf-8')
+         except UnicodeDecodeError:
+             try:
+                 file_content = raw_data.decode('latin1')
+             except UnicodeDecodeError:
+                 raise ValueError("Unable to decode the file. Ensure it is encoded in UTF-8 or Latin-1.")
+
+         delimiters = [",", ";", "|", "\t", " "]
+
+         for delimiter in delimiters:
+             try:
+                 df = pl.read_csv(
+                     StringIO(file_content),
+                     separator=delimiter,
+                     quote_char='"',
+                     try_parse_dates=True,
+                     ignore_errors=True,
+                     truncate_ragged_lines=True
+                 )
+                 return df
+             except Exception:
+                 continue
+
+         raise ValueError("Unable to load the file with common delimiters.")
+     except Exception as e:
+         raise ValueError(f"Error reading the CSV file: {str(e)}")
+
+ @st.cache_resource
+ def load_model():
+     """
+     Loads the GLiNER model into memory to avoid multiple reloads.
+     """
+     try:
+         gpu_available = torch.cuda.is_available()
+
+         with st.spinner("Loading the GLiNER model... Please wait."):
+             device = torch.device("cuda" if gpu_available else "cpu")
+             model = GLiNER.from_pretrained(
+                 "urchade/gliner_multi-v2.1"
+             ).to(device)
+             model.eval()
+
+         if gpu_available:
+             device_name = torch.cuda.get_device_name(0)
+             st.success(f"GPU detected: {device_name}. Model loaded on GPU.")
+         else:
+             cpu_name = platform.processor()
+             st.warning(f"No GPU detected. Using CPU: {cpu_name}")
+
+         return model
+     except Exception as e:
+         st.error("Error loading the model:")
+         st.error(str(e))
+         return None
+
+ def perform_ner(filtered_df, selected_column, labels_list, threshold):
+     """
+     Executes named entity recognition (NER) on the filtered data.
+     """
+     try:
+         texts_to_analyze = filtered_df[selected_column].to_list()
+         total_rows = len(texts_to_analyze)
+         ner_results_list = []
+
+         progress_bar = st.progress(0)
+         progress_text = st.empty()
+         start_time = time.time()
+
+         for index, text in enumerate(texts_to_analyze, 1):
+             if st.session_state.stop_processing:
+                 progress_text.text("Processing stopped by user.")
+                 break
+
+             ner_results = run_ner(
+                 st.session_state.gliner_model,
+                 [text],
+                 labels_list,
+                 threshold=threshold
+             )
+             ner_results_list.append(ner_results)
+
+             progress = index / total_rows
+             elapsed_time = time.time() - start_time
+             progress_bar.progress(progress)
+             progress_text.text(f"Progress: {index}/{total_rows} - {progress * 100:.0f}% (Elapsed time: {elapsed_time:.2f}s)")
+
+         for label in labels_list:
+             extracted_entities = []
+             for entities in ner_results_list:
+                 texts = [entity["text"] for entity in entities[0] if entity["label"] == label]
+                 concatenated_texts = ", ".join(texts) if texts else ""
+                 extracted_entities.append(concatenated_texts)
+             filtered_df = filtered_df.with_columns(pl.Series(name=label, values=extracted_entities))
+
+         end_time = time.time()
+         st.success(f"Processing completed in {end_time - start_time:.2f} seconds.")
+
+         return filtered_df
+     except Exception as e:
+         st.error(f"Error during NER processing: {str(e)}")
+         return filtered_df
+
+ def main():
+     st.title("Use NER with GliNER on your data file")
+     st.markdown("Prototype v0.1")
+
+     st.write("""
+     This application performs named entity recognition (NER) on your text data using GLiNER.
+
+     **Instructions:**
+     1. Upload a CSV or Excel file.
+     2. Select the column containing the text to analyze.
+     3. Filter the data if necessary.
+     4. Enter the NER labels you wish to detect.
+     5. Click "Start NER" to begin processing.
+     """)
+
+     if "stop_processing" not in st.session_state:
+         st.session_state.stop_processing = False
+     if "threshold" not in st.session_state:
+         st.session_state.threshold = 0.4
+     if "labels_list" not in st.session_state:
+         st.session_state.labels_list = []
+
+     st.session_state.gliner_model = load_model()
+     if st.session_state.gliner_model is None:
+         return
+
+     uploaded_file = st.sidebar.file_uploader("Choose a file (CSV or Excel)")
+     if uploaded_file is None:
+         st.warning("Please upload a file to continue.")
+         return
+
+     df = load_data(uploaded_file)
+     if df is None:
+         return
+
+     selected_column = st.selectbox("Select the column containing the text:", df.columns)
+
+     filter_text = st.text_input("Filter the column by text", "")
+     if filter_text:
+         filtered_df = df.filter(pl.col(selected_column).str.contains(f"(?i).*{filter_text}.*"))
+     else:
+         filtered_df = df
+
+     st.write("Filtered data preview:")
+
+     rows_per_page = 100
+     total_rows = len(filtered_df)
+     total_pages = (total_rows - 1) // rows_per_page + 1
+
+     if "current_page" not in st.session_state:
+         st.session_state.current_page = 1
+
+     def update_page(new_page):
+         st.session_state.current_page = new_page
+
+     col1, col2, col3, col4, col5 = st.columns(5)
+
+     with col1:
+         first = st.button("⏮️ First")
+     with col2:
+         previous = st.button("⬅️ Previous")
+     with col3:
+         pass
+     with col4:
+         next = st.button("Next ➡️")
+     with col5:
+         last = st.button("Last ⏭️")
+
+     if first:
+         update_page(1)
+     elif previous:
+         if st.session_state.current_page > 1:
+             update_page(st.session_state.current_page - 1)
+     elif next:
+         if st.session_state.current_page < total_pages:
+             update_page(st.session_state.current_page + 1)
+     elif last:
+         update_page(total_pages)
+
+     with col3:
+         st.markdown(f"Page **{st.session_state.current_page}** of **{total_pages}**")
+
+     start_idx = (st.session_state.current_page - 1) * rows_per_page
+     end_idx = min(start_idx + rows_per_page, total_rows)
+
+     if not filtered_df.is_empty():
+         current_page_data = filtered_df.slice(start_idx, end_idx - start_idx)
+         st.write(f"Displaying {start_idx + 1} to {end_idx} of {total_rows} rows")
+         st.dataframe(current_page_data.to_pandas(), use_container_width=True)
+     else:
+         st.warning("The filtered DataFrame is empty. Please check your filters.")
+
+     st.slider("Set confidence threshold", 0.0, 1.0, st.session_state.threshold, 0.01, key="threshold")
+
+     st.session_state.labels_list = st_tags(
+         label="Enter the NER labels to detect",
+         text="Add more labels as needed",
+         value=st.session_state.labels_list,
+         key="1"
+     )
+
+     col1, col2 = st.columns(2)
+     with col1:
+         start_button = st.button("Start NER")
+     with col2:
+         stop_button = st.button("Stop")
+
+     if start_button:
+         st.session_state.stop_processing = False
+
+         if not st.session_state.labels_list:
+             st.warning("Please enter labels for NER.")
+         else:
+             updated_df = perform_ner(filtered_df, selected_column, st.session_state.labels_list, st.session_state.threshold)
+             st.write("**NER Results:**")
+             st.dataframe(updated_df.to_pandas(), use_container_width=True)
+
+             def to_excel(df):
+                 output = BytesIO()
+                 df.write_excel(output)
+                 return output.getvalue()
+
+             def to_csv(df):
+                 return df.write_csv().encode('utf-8')
+
+             download_col1, download_col2 = st.columns(2)
+             with download_col1:
+                 st.download_button(
+                     label="📥 Download as Excel",
+                     data=to_excel(updated_df),
+                     file_name="ner_results.xlsx"#,
+                     #mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
+                 )
+             with download_col2:
+                 st.download_button(
+                     label="📥 Download as CSV",
+                     data=to_csv(updated_df),
+                     file_name="ner_results.csv"#,
+                     #mime="text/csv",
+                 )
+
+     if stop_button:
+         st.session_state.stop_processing = True
+         st.warning("Processing stopped by user.")
+
+ if __name__ == "__main__":
+     main()