UnarineLeo's picture
Update app.py
8091567 verified
raw
history blame
4.61 kB
import html

import streamlit as st
from transformers import pipeline
# Page configuration is issued before anything else so it remains the first
# Streamlit command; the cached model loader below may display a spinner the
# first time it runs.
st.set_page_config(layout="wide")


@st.cache_resource
def _load_unmasker():
    """Build the Zabantu XLM-R fill-mask pipeline.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` makes the pipeline get constructed once per server
    process and reused by later reruns, instead of being reloaded each time.

    Returns:
        The ``transformers`` fill-mask pipeline for
        ``dsfsi/zabantu-xlm-roberta``.
    """
    return pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')


# Module-level handle used by fill_mask().
unmasker = _load_unmasker()
def fill_mask(sentences, model=None):
    """Run the fill-mask model on every sentence containing exactly one ``<mask>``.

    Args:
        sentences: Mapping of result key -> ``(language, sentence)`` tuple,
            where ``sentence`` uses the literal placeholder ``<mask>``.
        model: Optional fill-mask pipeline to use instead of the module-level
            ``unmasker`` (e.g. another Zabantu variant, or a stand-in for
            testing). It must be callable on a string and expose
            ``tokenizer.mask_token``.

    Returns:
        ``(results, warnings)``: ``results`` maps each processed key to
        ``(language, predictions)`` where ``predictions`` is the model's output
        for that sentence; ``warnings`` holds one message per skipped sentence.
    """
    if model is None:
        model = unmasker
    # The tokenizer's mask token is the same for every sentence; look it up once.
    mask_token = model.tokenizer.mask_token
    results = {}
    warnings = []
    for key, (language, sentence) in sentences.items():
        mask_count = sentence.count("<mask>")
        if mask_count == 0:
            warnings.append(f"Warning: No <mask> token found in sentence: {sentence}")
            continue
        if mask_count > 1:
            # With several masks the fill-mask pipeline returns one prediction
            # list per mask (a list of lists). The output panel reads a single
            # top prediction and replace_mask() would write that one word into
            # every mask, so reject the sentence instead of crashing later.
            warnings.append(f"Warning: Only one <mask> token is supported per sentence: {sentence}")
            continue
        # Translate the UI placeholder into the tokenizer's own mask token.
        predictions = model(sentence.replace("<mask>", mask_token))
        results[key] = (language, predictions)
    return results, warnings
def replace_mask(sentence, predicted_word):
    """Return *sentence* with each ``<mask>`` placeholder swapped for the
    predicted word, wrapped in Markdown bold (``**word**``)."""
    bolded = "**{}**".format(predicted_word)
    # Splitting on the placeholder and re-joining with the bolded word is
    # equivalent to replacing every occurrence.
    pieces = sentence.split("<mask>")
    return bolded.join(pieces)
# Page heading and one-line description of the model family.
st.title("Fill Mask | Zabantu-XLM-Roberta")
st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages.")

# Two side-by-side panels: inputs go in col1, predictions in col2.
col1, col2 = st.columns(2)

# Seed per-session values only on the first run; later reruns keep whatever
# earlier runs stored. Each entry is (key, factory for its empty default).
# NOTE(review): 'text_input' is not read anywhere in the visible code.
_session_defaults = (
    ('text_input', str),
    ('warnings', list),
    ('result', dict),
)
for _state_key, _make_default in _session_defaults:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

# Language choices for each input row; the first entry is a placeholder.
_SUPPORTED_LANGUAGES = ('Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga')
language_options = ['Choose language', *_SUPPORTED_LANGUAGES]

# Filled by the input rows below on every rerun: key -> (language, sentence).
input_sentences = {}
# ---- Input panel -----------------------------------------------------------
with col1:
    with st.container():
        st.markdown("Input :clipboard:")
        input1, input2 = st.columns(2)
        for i in range(5):
            with input1:
                language = st.selectbox(f"Select language for sentence {i+1}:", language_options, key=f'language_{i}')
            with input2:
                sentence = st.text_input(f"Enter sentence for {language} (with <mask>):", key=f'text_input_{i}')
            if sentence:
                # Key on language + row number so two rows in the same
                # language do not overwrite each other.
                input_sentences[f'{language.lower()}_{i+1}'] = (language.lower(), sentence)

        # Render the two buttons side by side in the narrow columns.
        button1, button2, _ = st.columns([2, 2, 4])
        with button1:
            test_clicked = st.button("Test Example")
        with button2:
            submit_clicked = st.button("Submit")

        if test_clicked:
            sample_sentences = {
                'zulu_1': ('zulu', "Le ndoda ithi izo <mask> ukudla."),
                'tshivenda_2': ('tshivenda', "Vhana vhane vha kha ḓi bva u bebwa vha kha khombo ya u <mask> nga Listeriosis."),
                'tshivenda_3': ('tshivenda', "Rabulasi wa <mask> u khou bvelela nga u lima"),
                'tswana_4': ('tswana', "Monna o <mask> tsamaya."),
                'tsonga_5': ('tsonga', "N'wana wa xisati u <mask> ku tsaka.")
            }
            st.session_state['result'], st.session_state['warnings'] = fill_mask(sample_sentences)
            # Keep the sentences these predictions were made from: the samples
            # are not part of input_sentences, and input_sentences is rebuilt
            # from the widgets on every rerun.
            st.session_state['sentences'] = sample_sentences
        if submit_clicked:
            st.session_state['result'], st.session_state['warnings'] = fill_mask(input_sentences)
            # Snapshot (copy) so later widget edits cannot desync the output.
            st.session_state['sentences'] = dict(input_sentences)

        for warning in st.session_state['warnings']:
            st.warning(warning)

        st.markdown("Example")
        example_sentences = {
            'zulu': "Le ndoda ithi izo <mask> ukudla.",
            'tshivenda': "Mufana uyo <mask> vhukuma.",
            'sepedi': "Mosadi o <mask> pheka.",
            'tswana': "Monna o <mask> tsamaya.",
            'tsonga': "N'wana wa xisati u <mask> ku tsaka."
        }
        # st.code expects a string body; show one "language: sentence" per line
        # as plain (unhighlighted) text.
        st.code(
            "\n".join(f"{lang}: {text}" for lang, text in example_sentences.items()),
            language=None,
            wrap_lines=True,
        )

# ---- Output panel ----------------------------------------------------------
with col2:
    with st.container():
        st.markdown("Output :bar_chart:")
        # Prefer the sentence snapshot saved together with the predictions; fall
        # back to the current inputs when no snapshot is present.
        source_sentences = st.session_state.get('sentences', input_sentences)
        for key, (language, predictions) in st.session_state['result'].items():
            top = predictions[0]
            predicted_word = top['token_str']
            score = top['score'] * 100
            # Model output and labels are interpolated into raw HTML, so escape
            # them: a token containing '<' or '&' must render as text, not markup.
            safe_word = html.escape(predicted_word)
            safe_language = html.escape(language)
            bar_html = (
                '<div class="bar">'
                f'<div class="bar-fill" style="width: {score}%;"></div>'
                '</div>'
                '<div class="container">'
                f'<div style="align-items: left;">{safe_word} ({safe_language})</div>'
                f'<div style="align-items: right;">{score:.2f}%</div>'
                '</div>'
            )
            st.markdown(bar_html, unsafe_allow_html=True)
            entry = source_sentences.get(key)
            if entry is not None:
                predicted_sentence = replace_mask(entry[1], predicted_word)
            else:
                # The original sentence is no longer known; show the pipeline's
                # own filled-in sequence instead of raising KeyError.
                predicted_sentence = top['sequence']
            st.write(f"{language}: {predicted_sentence}\n")
css = """
<style>
footer {display:none !important;}
.container {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 5px;
width: 100%;
}
.bar {
background-color: #e6e6e6;
border-radius: 12px;
overflow: hidden;
margin-right: 10px;
height: 5px;
}
.bar-fill {
background-color: #17152e;
height: 100%;
border-radius: 12px;
}
</style>
"""
st.markdown(css, unsafe_allow_html=True)