UnarineLeo's picture
Update app.py
b3dcfe5 verified
raw
history blame
4.85 kB
import streamlit as st
from transformers import pipeline
from io import StringIO
# set_page_config must be the first Streamlit command of the script, so it runs
# before the (cached) model loader, which may display a spinner element.
st.set_page_config(layout="wide")


@st.cache_resource(show_spinner="Loading dsfsi/zabantu-bantu-250m ...")
def _load_unmasker():
    """Build the fill-mask pipeline once and reuse it.

    Streamlit re-executes this whole script on every widget interaction;
    st.cache_resource keeps the returned pipeline across those reruns instead
    of re-instantiating the model each time.
    """
    return pipeline('fill-mask', model='dsfsi/zabantu-bantu-250m')


# Module-level handle used by fill_mask() when no explicit model is passed.
unmasker = _load_unmasker()
def fill_mask(sentences, model=None):
    """Run the fill-mask model on every sentence that contains ``<mask>``.

    Args:
        sentences: Mapping of a label (a language name, or ``input_N`` for
            uploaded lines) to the sentence text.
        model: Fill-mask pipeline to use. It must expose
            ``tokenizer.mask_token`` and be callable on a string. Defaults to
            the module-level ``unmasker``.

    Returns:
        A ``(results, warnings)`` tuple. ``results`` maps each label whose
        sentence contained ``<mask>`` to the model's raw output for it.
        ``warnings`` holds one message per sentence without ``<mask>``.
    """
    if model is None:
        model = unmasker
    # Users type the literal placeholder "<mask>"; translate it to whatever
    # mask token this particular tokenizer expects.
    mask_token = model.tokenizer.mask_token
    results = {}
    warning_messages = []
    for language, sentence in sentences.items():
        if "<mask>" in sentence:
            results[language] = model(sentence.replace('<mask>', mask_token))
        else:
            warning_messages.append(f"Warning: No <mask> token found in sentence: {sentence}")
    return results, warning_messages
def replace_mask(sentence, predicted_word):
    """Return *sentence* with every ``<mask>`` replaced by *predicted_word* in Markdown bold."""
    bold_word = "**{}**".format(predicted_word)
    return sentence.replace("<mask>", bold_word)
# Page header shown above both columns.
st.title("Fill Mask | Zabantu-XLM-Roberta")
st.markdown("Zabantu-XLMR refers to a fleet of models trained on South African Bantu languages...")

# Seed session state only on a session's first run; later reruns keep
# whatever the button handlers have stored there.
for _state_key, _make_default in (('warnings', list), ('results', dict)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# Define layout: inputs on the left, predictions on the right.
col1, col2 = st.columns(2)
with col1:
    st.markdown("### Input :clipboard:")
    select_options = ['Choose option', 'Enter text input', 'Upload a file(csv/txt)']
    sample_sentence = {
        'tshivenda': "Rabulasi wa <mask> u khou bvelela nga u lima.",
        "tsonga": "N'wana wa xisati u <mask> ku tsaka."
    }
    option_selected = st.selectbox("Select an input option:", select_options, index=0)
    input_sentences = {}
    # Session-state results and warnings are replaced together, only when a
    # button below runs fill_mask, so both always describe the same submission.
    if option_selected == 'Enter text input':
        language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']
        for i in range(5):
            language = st.selectbox(f"Select language for input {i+1}:", language_options, key=f'language_{i}')
            sentence = st.text_input(f"Enter sentence for input {i+1} (with <mask>):", key=f'sentence_{i}')
            # Only process filled language and sentence pairs
            if language != 'Choose language' and sentence.strip():
                key = language.lower()
                # Several inputs may use the same language: suffix later ones
                # with their input number so they don't overwrite earlier ones.
                if key in input_sentences:
                    key = f"{key} ({i+1})"
                input_sentences[key] = sentence
        if st.button("Submit"):
            if input_sentences:
                results, warnings = fill_mask(input_sentences)
                st.session_state['results'] = results
                st.session_state['warnings'] = warnings
            else:
                st.warning("Please fill at least one language and sentence.")
    elif option_selected == 'Upload a file(csv/txt)':
        uploaded_file = st.file_uploader("Choose a file (one sentence per line)")
        if uploaded_file:
            try:
                # "utf-8-sig" also drops a leading byte-order mark if present.
                text = uploaded_file.getvalue().decode("utf-8-sig")
            except UnicodeDecodeError:
                st.error("Could not read the file: please upload UTF-8 encoded text.")
                text = ""
            # NOTE: uploaded lines carry no language label, so each sentence is
            # keyed by its 1-based line number in the file.
            for line_number, sentence in enumerate(text.splitlines(), start=1):
                if sentence.strip():  # blank lines are skipped, not reported as missing <mask>
                    input_sentences[f'input_{line_number}'] = sentence
            if st.button("Submit"):
                if input_sentences:
                    results, warnings = fill_mask(input_sentences)
                    st.session_state['results'] = results
                    st.session_state['warnings'] = warnings
                else:
                    st.warning("The uploaded file contains no sentences.")
    st.markdown("### Example")
    st.code(sample_sentence)
    if st.button("Test Example"):
        results, warnings = fill_mask(sample_sentence)
        st.session_state['results'] = results
        st.session_state['warnings'] = warnings
with col2:
    st.markdown("### Output :bar_chart:")
    if st.session_state['results']:
        # Plain container grouping the prediction rows together.
        with st.container():
            for language, predictions in st.session_state['results'].items():
                if not predictions:
                    continue
                # The fill-mask pipeline returns a flat list of candidate dicts
                # for a sentence with one mask, but a list of such lists (one per
                # mask) when the sentence has several; normalise to the latter.
                per_mask = predictions if isinstance(predictions[0], list) else [predictions]
                for mask_number, candidates in enumerate(per_mask, start=1):
                    if not candidates:
                        continue
                    # The first candidate is treated as the top prediction.
                    top_prediction = candidates[0]
                    predicted_word = top_prediction['token_str']
                    score = top_prediction['score'] * 100
                    label = language.capitalize()
                    if len(per_mask) > 1:
                        label = f"{label} (mask {mask_number})"
                    st.markdown(f"**{label} Prediction:** {predicted_word} ({score:.2f}%)")
                    # Score bar; drawn by the .bar/.bar-fill CSS rules defined in this script.
                    st.markdown(
                        f"<div class='bar'><div class='bar-fill' style='width:{score}%;'></div></div>",
                        unsafe_allow_html=True,
                    )
    for warning in st.session_state['warnings']:
        st.warning(warning)
# Page-wide CSS: hides the default footer and styles the prediction score bars.
css = "\n".join([
    "",
    "<style>",
    "footer {display:none !important;}",
    ".bar {width: 70%; background-color: #e6e6e6; border-radius: 12px; height: 5px;}",
    ".bar-fill {background-color: #17152e; height: 100%; border-radius: 12px;}",
    ".container {display: flex; justify-content: space-between; align-items: center; margin-bottom: 5px;}",
    "</style>",
    "",
])
st.markdown(css, unsafe_allow_html=True)