Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,343 +1,310 @@
|
|
1 |
-
import os
|
2 |
-
import csv
|
3 |
-
import streamlit as st
|
4 |
-
import polars as pl
|
5 |
-
from io import BytesIO, StringIO
|
6 |
-
from gliner import GLiNER
|
7 |
-
from gliner_file import run_ner
|
8 |
-
import time
|
9 |
-
import torch
|
10 |
-
import platform
|
11 |
-
from typing import List
|
12 |
-
from streamlit_tags import st_tags # Importing the st_tags component
|
13 |
-
|
14 |
-
# Page-level Streamlit settings: browser tab title and icon, wide layout,
# and the sidebar opened by default.
st.set_page_config(
    page_title="GLiNER", page_icon="🔥",
    layout="wide", initial_sidebar_state="expanded",
)
|
21 |
-
|
22 |
-
# Cached entry point for reading the uploaded file.
@st.cache_data
def load_data(file):
    """
    Read an uploaded CSV or Excel file, dispatching on its extension.

    The extension of `file.name` is compared case-insensitively: `.xls` and
    `.xlsx` are handed to `load_excel`, `.csv` to `load_csv`. Any failure,
    including an unsupported extension, is shown with `st.error` and the
    function returns `None` instead of raising.
    """
    with st.spinner("Loading data, please wait..."):
        try:
            extension = os.path.splitext(file.name)[1].lower()
            if extension == ".csv":
                return load_csv(file)
            if extension in (".xls", ".xlsx"):
                return load_excel(file)
            raise ValueError("Unsupported file format. Please upload a CSV or Excel file.")
        except Exception as e:
            st.error("Error loading data:")
            st.error(str(e))
            return None
|
41 |
-
|
42 |
-
def load_excel(file):
|
43 |
-
"""
|
44 |
-
Loads an Excel file using `BytesIO` and `polars` for reduced latency.
|
45 |
-
"""
|
46 |
-
try:
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
# Serialise a DataFrame to Excel bytes held entirely in memory.
def to_excel(df):
    """Return the bytes produced by `df.write_excel` on an in-memory buffer."""
    buffer = BytesIO()
    df.write_excel(buffer)
    excel_bytes = buffer.getvalue()
    return excel_bytes
|
316 |
-
|
317 |
-
# Serialise a DataFrame to UTF-8 encoded CSV bytes.
def to_csv(df):
    """Return the text from `df.write_csv()` encoded as UTF-8 bytes."""
    csv_text = df.write_csv()
    return csv_text.encode('utf-8')
|
320 |
-
|
321 |
-
# Download buttons for results
|
322 |
-
download_col1, download_col2 = st.columns(2)
|
323 |
-
with download_col1:
|
324 |
-
st.download_button(
|
325 |
-
label="📥 Download as Excel",
|
326 |
-
data=to_excel(updated_df),
|
327 |
-
file_name="ner_results.xlsx",
|
328 |
-
mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
329 |
-
)
|
330 |
-
with download_col2:
|
331 |
-
st.download_button(
|
332 |
-
label="📥 Download as CSV",
|
333 |
-
data=to_csv(updated_df),
|
334 |
-
file_name="ner_results.csv",
|
335 |
-
mime="text/csv",
|
336 |
-
)
|
337 |
-
|
338 |
-
if stop_button:
|
339 |
-
st.session_state.stop_processing = True
|
340 |
-
st.warning("Processing stopped by user.")
|
341 |
-
|
342 |
-
if __name__ == "__main__":
|
343 |
-
main()
|
|
|
1 |
+
import os
|
2 |
+
import csv
|
3 |
+
import streamlit as st
|
4 |
+
import polars as pl
|
5 |
+
from io import BytesIO, StringIO
|
6 |
+
from gliner import GLiNER
|
7 |
+
from gliner_file import run_ner
|
8 |
+
import time
|
9 |
+
import torch
|
10 |
+
import platform
|
11 |
+
from typing import List
|
12 |
+
from streamlit_tags import st_tags # Importing the st_tags component for labels
|
13 |
+
|
14 |
+
# Page-level Streamlit settings: browser tab title and icon, wide layout,
# and the sidebar opened by default.
st.set_page_config(
    page_title="GLiNER", page_icon="🔥",
    layout="wide", initial_sidebar_state="expanded",
)
|
21 |
+
|
22 |
+
# Cached entry point for reading the uploaded file.
@st.cache_data
def load_data(file):
    """
    Read an uploaded CSV or Excel file, dispatching on its extension.

    The extension of `file.name` is compared case-insensitively: `.xls` and
    `.xlsx` are handed to `load_excel`, `.csv` to `load_csv`. Any failure,
    including an unsupported extension, is shown with `st.error` and the
    function returns `None` instead of raising.
    """
    with st.spinner("Loading data, please wait..."):
        try:
            extension = os.path.splitext(file.name)[1].lower()
            if extension == ".csv":
                return load_csv(file)
            if extension in (".xls", ".xlsx"):
                return load_excel(file)
            raise ValueError("Unsupported file format. Please upload a CSV or Excel file.")
        except Exception as e:
            st.error("Error loading data:")
            st.error(str(e))
            return None
|
41 |
+
|
42 |
+
def load_excel(file):
    """
    Load an uploaded Excel file into a polars DataFrame.

    The stream is rewound, copied into a fresh `BytesIO`, and handed to
    `pl.read_excel`.

    Args:
        file: Binary file-like object supporting `seek` and `read`.

    Returns:
        Whatever `pl.read_excel` returns for the workbook.

    Raises:
        ValueError: If reading fails for any reason; the original exception
            is attached as `__cause__`.
    """
    try:
        # Rewind first, as load_csv does: if the stream was already consumed,
        # read() would otherwise return empty bytes.
        file.seek(0)
        file_bytes = BytesIO(file.read())
        # NOTE(review): `read_options` is forwarded to the Excel engine in use;
        # confirm the engine of the pinned polars version accepts
        # "ignore_errors".
        return pl.read_excel(file_bytes, read_options={"ignore_errors": True})
    except Exception as e:
        raise ValueError(f"Error reading the Excel file: {str(e)}") from e
|
52 |
+
|
53 |
+
def _decode_csv_bytes(raw_data):
    """Decode uploaded bytes as UTF-8, falling back to Latin-1."""
    try:
        return raw_data.decode('utf-8')
    except UnicodeDecodeError:
        # Latin-1 maps every byte value, so this fallback cannot fail.
        return raw_data.decode('latin1')


def _detect_delimiter(text, candidates, max_rows=50):
    """
    Pick the candidate that splits the first rows into the most columns.

    A candidate only qualifies when every sampled non-empty row has the same
    field count and that count is greater than one. Ties and "nothing
    qualifies" both resolve to the earliest candidate, so the first entry of
    `candidates` is the default.
    """
    # newline="" keeps line endings untouched, as the csv module expects.
    buffer = StringIO(text, newline="")
    best, best_width = candidates[0], 1
    for candidate in candidates:
        buffer.seek(0)
        try:
            reader = csv.reader(buffer, delimiter=candidate, quotechar='"')
            widths = {len(row) for _, row in zip(range(max_rows), reader) if row}
        except csv.Error:
            continue
        if len(widths) != 1:
            continue
        (width,) = widths
        if width > best_width:
            best, best_width = candidate, width
    return best


def load_csv(file):
    """
    Load an uploaded CSV file into a polars DataFrame.

    The bytes are decoded as UTF-8 (Latin-1 as fallback). The delimiter is
    then detected from the first rows, honouring double-quoted fields, and
    tried first. The remaining common delimiters are kept as fallbacks in
    case polars rejects the detected one.

    Args:
        file: Binary file-like object supporting `seek` and `read`.

    Returns:
        The DataFrame returned by `pl.read_csv`.

    Raises:
        ValueError: If the file cannot be decoded or parsed; the underlying
            exception is attached as `__cause__`.
    """
    try:
        file.seek(0)  # Always read from the beginning of the upload
        file_content = _decode_csv_bytes(file.read())

        delimiters = (",", ";", "|", "\t", " ")
        detected = _detect_delimiter(file_content, delimiters)
        # The lenient read options below mean a wrong separator usually still
        # "succeeds" as a one-column frame, so the best guess must go first.
        ordered = [detected] + [d for d in delimiters if d != detected]

        last_error = None
        for delimiter in ordered:
            try:
                return pl.read_csv(
                    StringIO(file_content),
                    separator=delimiter,
                    quote_char='"',
                    try_parse_dates=True,
                    ignore_errors=True,
                    truncate_ragged_lines=True,
                )
            except Exception as exc:
                last_error = exc

        raise ValueError("Unable to load the file with common delimiters.") from last_error
    except Exception as e:
        raise ValueError(f"Error reading the CSV file: {str(e)}") from e
|
88 |
+
|
89 |
+
@st.cache_resource
def load_model():
    """
    Build the GLiNER model once and keep it as a Streamlit cached resource.

    The "urchade/gliner_multi-v2.1" checkpoint is moved to CUDA when
    `torch.cuda.is_available()` is true, otherwise to the CPU, and switched
    to eval mode. A success or warning banner names the device that was used.

    Returns:
        The loaded model, or `None` after the error has been reported with
        `st.error`.
    """
    try:
        has_cuda = torch.cuda.is_available()
        device_kind = "cuda" if has_cuda else "cpu"

        with st.spinner("Loading the GLiNER model... Please wait."):
            gliner = GLiNER.from_pretrained("urchade/gliner_multi-v2.1")
            gliner = gliner.to(torch.device(device_kind))
            gliner.eval()

        if not has_cuda:
            st.warning(f"No GPU detected. Using CPU: {platform.processor()}")
        else:
            st.success(f"GPU detected: {torch.cuda.get_device_name(0)}. Model loaded on GPU.")

        return gliner
    except Exception as e:
        st.error("Error loading the model:")
        st.error(str(e))
        return None
|
116 |
+
|
117 |
+
def perform_ner(filtered_df, selected_column, labels_list, threshold):
    """
    Run named entity recognition row by row and add one column per label.

    Each value of `selected_column` is sent to `run_ner` on its own, using the
    model stored in `st.session_state.gliner_model`. For every label, the
    texts of the matching entities are joined with ", " into a new column
    named after the label. A progress bar and a status line are updated after
    each row.

    Args:
        filtered_df: polars DataFrame holding the rows to analyse.
        selected_column: Name of the column containing the text.
        labels_list: Entity labels to extract; each becomes a column name.
        threshold: Confidence threshold forwarded to `run_ner`.

    Returns:
        The DataFrame with the label columns added. Rows that were not
        processed because the user stopped the run hold null in those
        columns. On error the message is shown with `st.error` and the frame
        is returned as it stood at that point.
    """
    try:
        texts_to_analyze = filtered_df[selected_column].to_list()
        total_rows = len(texts_to_analyze)
        ner_results_list = []

        progress_bar = st.progress(0)
        progress_text = st.empty()
        start_time = time.time()

        for index, text in enumerate(texts_to_analyze, 1):
            if st.session_state.stop_processing:
                progress_text.text("Processing stopped by user.")
                break

            if text is None:
                # Null cell: record "no entities" without calling the model.
                # Same shape as a run_ner result for a single text.
                ner_results = [[]]
            else:
                ner_results = run_ner(
                    st.session_state.gliner_model,
                    [str(text)],  # non-string cells (numbers, dates) are stringified
                    labels_list,
                    threshold=threshold
                )
            ner_results_list.append(ner_results)

            progress = index / total_rows
            elapsed_time = time.time() - start_time
            progress_bar.progress(progress)
            progress_text.text(f"Progress: {index}/{total_rows} - {progress * 100:.0f}% (Elapsed time: {elapsed_time:.2f}s)")

        # After an early stop fewer results exist than rows. Pad with nulls so
        # the new columns match the frame's height and the partial results
        # are kept instead of being lost to a length-mismatch error.
        unprocessed = [None] * (total_rows - len(ner_results_list))

        for label in labels_list:
            extracted_entities = [
                ", ".join(entity["text"] for entity in entities[0] if entity["label"] == label)
                for entities in ner_results_list
            ]
            filtered_df = filtered_df.with_columns(
                pl.Series(name=label, values=extracted_entities + unprocessed, dtype=pl.Utf8)
            )

        end_time = time.time()
        st.success(f"Processing completed in {end_time - start_time:.2f} seconds.")

        return filtered_df
    except Exception as e:
        st.error(f"Error during NER processing: {str(e)}")
        return filtered_df
|
163 |
+
|
164 |
+
def _to_excel(df):
    """Return `df` serialised by `df.write_excel` as in-memory bytes."""
    output = BytesIO()
    df.write_excel(output)
    return output.getvalue()


def _to_csv(df):
    """Return `df` serialised by `df.write_csv` as UTF-8 bytes."""
    return df.write_csv().encode('utf-8')


def _render_downloads(df):
    """Show side-by-side Excel and CSV download buttons for `df`."""
    excel_col, csv_col = st.columns(2)
    # No explicit `mime` is passed; Streamlit's default is used.
    with excel_col:
        st.download_button(
            label="📥 Download as Excel",
            data=_to_excel(df),
            file_name="ner_results.xlsx",
        )
    with csv_col:
        st.download_button(
            label="📥 Download as CSV",
            data=_to_csv(df),
            file_name="ner_results.csv",
        )


def main():
    """
    Render the Streamlit page and drive the NER workflow.

    Steps: initialise session defaults, load the model, read the uploaded
    file, let the user pick and filter a text column, preview the rows in
    pages of 100, collect the threshold and labels, then run `perform_ner`
    and offer the results for download. The function returns early when the
    model or the data could not be loaded, or when no file has been uploaded.
    """
    st.title("Use NER with GliNER on your data file")
    st.markdown("Prototype v0.1")

    st.write("""
    This application performs named entity recognition (NER) on your text data using GLiNER.

    **Instructions:**
    1. Upload a CSV or Excel file.
    2. Select the column containing the text to analyze.
    3. Filter the data if necessary.
    4. Enter the NER labels you wish to detect.
    5. Click "Start NER" to begin processing.
    """)

    if "stop_processing" not in st.session_state:
        st.session_state.stop_processing = False
    if "threshold" not in st.session_state:
        st.session_state.threshold = 0.4
    if "labels_list" not in st.session_state:
        st.session_state.labels_list = []
    if "current_page" not in st.session_state:
        st.session_state.current_page = 1

    st.session_state.gliner_model = load_model()
    if st.session_state.gliner_model is None:
        return

    uploaded_file = st.sidebar.file_uploader("Choose a file (CSV or Excel)")
    if uploaded_file is None:
        st.warning("Please upload a file to continue.")
        return

    df = load_data(uploaded_file)
    if df is None:
        return

    selected_column = st.selectbox("Select the column containing the text:", df.columns)

    filter_text = st.text_input("Filter the column by text", "")
    if filter_text:
        # Case-insensitive literal substring match. The input is not treated
        # as a regex, so characters such as "(" or "[" cannot raise a
        # pattern error. The cast lets non-string columns be filtered too.
        haystack = pl.col(selected_column).cast(pl.Utf8).str.to_lowercase()
        filtered_df = df.filter(haystack.str.contains(filter_text.lower(), literal=True))
    else:
        filtered_df = df

    st.write("Filtered data preview:")

    rows_per_page = 100
    total_rows = len(filtered_df)
    # Ceiling division, but never fewer than one page, so an empty frame
    # shows "Page 1 of 1" rather than "Page 1 of 0".
    total_pages = max(1, (total_rows + rows_per_page - 1) // rows_per_page)

    col1, col2, col3, col4, col5 = st.columns(5)

    with col1:
        first_clicked = st.button("⏮️ First")
    with col2:
        previous_clicked = st.button("⬅️ Previous")
    with col4:
        next_clicked = st.button("Next ➡️")
    with col5:
        last_clicked = st.button("Last ⏭️")

    page = st.session_state.current_page
    if first_clicked:
        page = 1
    elif previous_clicked:
        page -= 1
    elif next_clicked:
        page += 1
    elif last_clicked:
        page = total_pages
    # Clamp on every run: a narrower filter can leave the stored page beyond
    # the new last page.
    page = min(max(page, 1), total_pages)
    st.session_state.current_page = page

    with col3:
        st.markdown(f"Page **{st.session_state.current_page}** of **{total_pages}**")

    start_idx = (st.session_state.current_page - 1) * rows_per_page
    end_idx = min(start_idx + rows_per_page, total_rows)

    if not filtered_df.is_empty():
        current_page_data = filtered_df.slice(start_idx, end_idx - start_idx)
        st.write(f"Displaying {start_idx + 1} to {end_idx} of {total_rows} rows")
        st.dataframe(current_page_data.to_pandas(), use_container_width=True)
    else:
        st.warning("The filtered DataFrame is empty. Please check your filters.")

    # The slider is bound to st.session_state.threshold through its key. The
    # initial value comes from the session default set above, so no `value`
    # argument is passed (passing both triggers a Streamlit warning).
    st.slider("Set confidence threshold", 0.0, 1.0, step=0.01, key="threshold")

    st.session_state.labels_list = st_tags(
        label="Enter the NER labels to detect",
        text="Add more labels as needed",
        value=st.session_state.labels_list,
        key="1"
    )

    col1, col2 = st.columns(2)
    with col1:
        start_button = st.button("Start NER")
    with col2:
        stop_button = st.button("Stop")

    if start_button:
        st.session_state.stop_processing = False

        if not st.session_state.labels_list:
            st.warning("Please enter labels for NER.")
        else:
            updated_df = perform_ner(
                filtered_df,
                selected_column,
                st.session_state.labels_list,
                st.session_state.threshold,
            )
            st.write("**NER Results:**")
            st.dataframe(updated_df.to_pandas(), use_container_width=True)
            _render_downloads(updated_df)

    if stop_button:
        st.session_state.stop_processing = True
        st.warning("Processing stopped by user.")


if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|