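"""Streamlit app for ROTA: assigns charge categories to free-text offense descriptions.

Offers a single-text demo and a bulk coder that reads a CSV/XLSX upload,
deduplicates the selected offense column, and returns downloadable predictions.
"""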
from functools import partial
from pathlib import Path

import streamlit as st
from more_itertools import ichunked
from pandas import DataFrame, read_csv, read_excel
from stqdm import stqdm

from download import download_link
from onnx_model_utils import RELEASE_TAG, max_pred_bulk, predict, predict_bulk
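
# Number of offense descriptions sent to predict_bulk() in each batch.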
PRED_BATCH_SIZE = 4
st.set_page_config(page_title="ROTA", initial_sidebar_state="collapsed")
st.markdown(Path("ABOUT.md").read_text())

st.markdown("## ✏️ Single Coder Demo")
input_text = st.text_input(
    "Input Offense",
    value="FRAUDULENT USE OF A CREDIT CARD OR DEBT CARD >= $25,000",
)
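# predict() is assumed to return one list of {"label": ..., "score": ...}
# dicts per input text, with scores in [0, 1]; only the first input's
# predictions are displayed below.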
predictions = predict(input_text)
st.markdown("Predictions")
labels = ["Charge Category"]
st.dataframe(
    DataFrame(predictions[0])
    .assign(
        confidence=lambda d: d["score"].apply(lambda s: round(s * 100, 0)).astype(int)
    )
    .drop("score", axis="columns")
)
st.markdown("---") | |
st.markdown("## 📑 Bulk Coder") | |
st.warning( | |
"⚠️ *Note:* Your input data will be deduplicated" | |
" on the selected column to reduce computation requirements." | |
" You will need to re-join the results on your offense text column." | |
) | |
st.markdown("1️⃣ **Upload File**") | |
uploaded_file = st.file_uploader("Bulk Upload", type=["xlsx", "csv"]) | |
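# Map file extensions to pandas readers; .xlsx files are read with the
# openpyxl engine (assumed to be available in the environment).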
file_readers = {"csv": read_csv, "xlsx": partial(read_excel, engine="openpyxl")}

if uploaded_file is not None:
    for filetype, reader in file_readers.items():
        if uploaded_file.name.endswith(filetype):
            df = reader(uploaded_file)
    file_name = uploaded_file.name
    del uploaded_file
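    # Drop references as soon as they are no longer needed so large
    # intermediate objects can be garbage-collected.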
st.write("2️⃣ **Select Column of Offense Descriptions**") | |
string_columns = list(df.select_dtypes("object").columns) | |
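    # Heuristic default: the string column with the longest average length is
    # most likely the free-text offense description.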
    longest_column = max(
        [(df[c].str.len().mean(), c) for c in string_columns], key=lambda x: x[0]
    )[1]
    selected_column = st.selectbox(
        "Select Column",
        options=list(string_columns),
        index=string_columns.index(longest_column),
    )
    original_length = len(df)
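    # Deduplicate on the selected column (as warned above) so each unique
    # offense description is scored only once.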
    df_unique = df.drop_duplicates(subset=[selected_column]).copy()
    del df
    st.markdown(
        f"Uploaded Data Sample `(Deduplicated. N Rows = {len(df_unique)}, Original N = {original_length})`"
    )
    st.dataframe(df_unique.head(20))
    st.write(f"3️⃣ **Predict Using Column: `{selected_column}`**")
    column = df_unique[selected_column].copy()
    del df_unique
if st.button(f"Compute Predictions"): | |
input_texts = (value for _, value in column.items()) | |
n_batches = (len(column) // PRED_BATCH_SIZE) + 1 | |
        bulk_preds = []
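        # ichunked() lazily splits the generator into batches of size
        # PRED_BATCH_SIZE; stqdm renders a tqdm-style progress bar in the app.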
        for batch in stqdm(
            ichunked(input_texts, PRED_BATCH_SIZE),
            total=n_batches,
            desc="Bulk Predict Progress",
        ):
            batch_preds = predict_bulk(batch)
            bulk_preds.extend(batch_preds)
        pred_df = column.to_frame()
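        # max_pred_bulk() is assumed to reduce each prediction list to its
        # single highest-scoring {"label", "score"} entry.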
        max_preds = max_pred_bulk(bulk_preds)
        pred_df["charge_category_pred"] = [p["label"] for p in max_preds]
        pred_df["charge_category_pred_confidence"] = [
            int(round(p["score"] * 100, 0)) for p in max_preds
        ]
        del column
        del bulk_preds
        del max_preds
        # TODO: Add all scores
st.write("**Sample Output**") | |
st.dataframe(pred_df.head(100)) | |
tmp_download_link = download_link( | |
pred_df, | |
f"{file_name}-ncrp-predictions.csv", | |
"⬇️ Download as CSV", | |
) | |
st.markdown(tmp_download_link, unsafe_allow_html=True) | |