# UnarineLeo's picture
# Update app.py
# e28273f verified
# raw
# history blame
# 4.66 kB
import html

import streamlit as st
from transformers import pipeline
# Page configuration is issued before any other Streamlit call, including the
# cached loader below (which may render a spinner while the model loads).
st.set_page_config(layout="wide")


@st.cache_resource
def _load_unmasker():
    """Build the fill-mask pipeline once and share it across script reruns.

    Streamlit re-executes this script on every widget interaction; without
    caching, the model would be re-instantiated on each rerun.
    """
    return pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')


unmasker = _load_unmasker()
def fill_mask(sentences):
    """Run the fill-mask pipeline on every sentence that contains ``<mask>``.

    Args:
        sentences: Mapping of a language key to the sentence entered for it.

    Returns:
        A ``(results, warnings)`` tuple. ``results`` maps each language key
        whose sentence contains ``<mask>`` to whatever ``unmasker`` returned
        for it. ``warnings`` holds one message per sentence that has no
        ``<mask>`` placeholder; those sentences are never sent to the model.
    """
    results = {}
    warnings = []
    for language, sentence in sentences.items():
        if "<mask>" not in sentence:
            warnings.append(f"Warning: No <mask> token found in sentence: {sentence}")
            continue
        # The UI placeholder is always "<mask>"; swap in whatever mask token
        # the loaded tokenizer actually uses.
        masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
        results[language] = unmasker(masked_sentence)
    return results, warnings
def replace_mask(sentence, predicted_word):
    """Return *sentence* with each ``<mask>`` replaced by the bolded prediction.

    Args:
        sentence: Text containing zero or more ``<mask>`` placeholders.
        predicted_word: Token to insert; it is wrapped in Markdown ``**``.

    Returns:
        The sentence with every ``<mask>`` occurrence substituted. A sentence
        without a placeholder is returned unchanged.
    """
    # Surrounding whitespace is stripped because "** word**" is not rendered
    # as bold by Markdown; the delimiters must touch the text they wrap.
    word = str(predicted_word).strip()
    return sentence.replace("<mask>", f"**{word}**")
st.title("Fill Mask | Zabantu-XLM-Roberta")
st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages.")

# Two equal-width columns: col1 and col2 are filled in further down the script.
col1, col2 = st.columns(2)

# Seed the session-state keys on the first run so later reads never hit a
# missing key. NOTE(review): 'text_input' is not read anywhere in the visible
# script (the text areas use 'text_input_<i>' keys) - confirm before removing.
if 'text_input' not in st.session_state:
    st.session_state['text_input'] = ""

if 'warnings' not in st.session_state:
    st.session_state['warnings'] = []

# Choices offered by each language selectbox.
language_options = ['Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']
# Demo sentences, one per supported language. Shared by the "Test Example"
# button and the "Example" panel so the two cannot drift apart.
SAMPLE_SENTENCES = {
    'zulu': "Le ndoda ithi izo <mask> ukudla.",
    'tshivenda': "Mufana uyo <mask> vhukuma.",
    'sepedi': "Mosadi o <mask> pheka.",
    'tswana': "Monna o <mask> tsamaya.",
    'tsonga': "N'wana wa xisati u <mask> ku tsaka.",
}

# Markup for one prediction bar (classes are styled by the CSS injected at the
# end of the script). Kept flush-left so Markdown never treats it as an
# indented code block.
BAR_TEMPLATE = """
<div class="bar">
<div class="bar-fill" style="width: {score}%;"></div>
</div>
<div class="container">
<div style="align-items: left;">{predicted_word} ({language})</div>
<div style="align-items: right;">{score:.2f}%</div>
</div>
"""

# Always defined, so the output section also works on reruns where neither
# button was clicked (e.g. right after the user edits a text area).
result = {}

with col1:
    with st.container():
        st.markdown("### Input :clipboard:")

        # Maps lower-cased language name -> sentence entered for it.
        input_sentences = {}
        for i in range(5):
            language = st.selectbox(
                f"Select language for sentence {i+1}:",
                language_options,
                # Start every box on a different language; sentences are keyed
                # by language, so identical defaults would overwrite each other.
                index=i % len(language_options),
                key=f'language_{i}',
            )
            sentence = st.text_area(f"Enter sentence for {language} (with <mask>):", key=f'text_input_{i}')
            if sentence:
                language_key = language.lower()
                if language_key in input_sentences:
                    # Make the overwrite visible instead of silently dropping input.
                    st.warning(f"Warning: More than one sentence entered for {language}; only the last one is used.")
                input_sentences[language_key] = sentence

        button1, button2, _ = st.columns([2, 2, 4])
        with button1:
            if st.button("Test Example"):
                # Copy so later code can never mutate the shared constant.
                input_sentences = dict(SAMPLE_SENTENCES)
                result, warnings = fill_mask(input_sentences)
                # Stored here too, so warnings from an earlier Submit are cleared.
                st.session_state['warnings'] = warnings

        with button2:
            if st.button("Submit"):
                result, warnings = fill_mask(input_sentences)
                st.session_state['warnings'] = warnings

        for warning in st.session_state['warnings']:
            st.warning(warning)

        st.markdown("### Example")
        st.code(SAMPLE_SENTENCES, wrap_lines=True)

with col2:
    with st.container():
        st.markdown("### Output :bar_chart:")

        for language, sentence in input_sentences.items():
            if "<mask>" not in sentence:
                # The fill-mask pipeline raises on input without a mask token;
                # such sentences are reported through the warnings list instead.
                continue

            # Reuse predictions already computed by a button click in this run
            # instead of running the model a second time.
            predictions = result.get(language)
            if not predictions:
                masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
                predictions = unmasker(masked_sentence)

            if predictions:
                top_prediction = predictions[0]
                predicted_word = top_prediction['token_str']
                score = top_prediction['score'] * 100

                st.markdown(
                    BAR_TEMPLATE.format(
                        score=score,
                        # Escaped because the markup is rendered as raw HTML.
                        predicted_word=html.escape(predicted_word),
                        language=html.escape(language),
                    ),
                    unsafe_allow_html=True,
                )

        # Completed sentences are listed only for runs triggered by a button,
        # which is the only time `result` is non-empty.
        for language, language_predictions in result.items():
            original_sentence = input_sentences[language]
            predicted_sentence = replace_mask(original_sentence, language_predictions[0]['token_str'])
            st.write(f"{language}: {predicted_sentence}\n")
# Page-wide stylesheet, injected as raw HTML via st.markdown.
#   footer      - hidden entirely.
#   .container  - flex row with space-between, used for the label/score line.
#   .bar        - grey 5px track behind each score bar.
#   .bar-fill   - dark fill inside the track; its width is set inline by the
#                 markup that renders each prediction.
css = """
<style>
footer {display:none !important;}
.container {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 5px;
width: 100%;
}
.bar {
background-color: #e6e6e6;
border-radius: 12px;
overflow: hidden;
margin-right: 10px;
height: 5px;
}
.bar-fill {
background-color: #17152e;
height: 100%;
border-radius: 12px;
}
</style>
"""
# unsafe_allow_html is required for the <style> element to be emitted as HTML.
st.markdown(css, unsafe_allow_html=True)