UnarineLeo's picture
Update app.py
b51c864 verified
raw
history blame
4.89 kB
import streamlit as st
from transformers import pipeline
@st.cache_resource
def _load_unmasker():
    """Build the fill-mask pipeline once and reuse it across script reruns.

    Streamlit re-executes the script on user interaction; caching the
    pipeline as a resource avoids re-instantiating the model each time.

    Returns:
        The ``fill-mask`` pipeline for the 'dsfsi/zabantu-xlm-roberta' model.
    """
    return pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')


# Shared fill-mask pipeline used for every prediction in this app.
unmasker = _load_unmasker()
# One example sentence per language, keyed by language name.
# The '____' placeholder marks the position the model is asked to fill in.
sample_sentences = {
    "Zulu": "Le ndoda ithi izo____ ukudla.",
    "Tshivenda": "Mufana uyo____ vhukuma.",
    "Sepedi": "Mosadi o ____ pheka.",
    "Tswana": "Monna o ____ tsamaya.",
    "Tsonga": "N'wana wa xisati u ____ ku tsaka.",
}
def fill_mask_for_languages(sentences):
    """Run the fill-mask pipeline over a language-keyed mapping of sentences.

    Args:
        sentences: Mapping of language name to a sentence in which '____'
            marks the word to predict.

    Returns:
        A dict with the same keys as ``sentences``. Each value is whatever
        ``unmasker`` returned for that sentence after '____' was swapped
        for the tokenizer's mask token.
    """
    return {
        lang: unmasker(text.replace('____', unmasker.tokenizer.mask_token))
        for lang, text in sentences.items()
    }
def replace_mask(sentence, predicted_word):
    """Return ``sentence`` with each '____' placeholder replaced by the prediction in bold.

    Args:
        sentence: Text containing the '____' placeholder.
        predicted_word: The word to insert in place of the placeholder.

    Returns:
        The sentence with every '____' replaced by ``**<predicted_word>**``
        (Markdown bold). A sentence without the placeholder is returned
        unchanged.
    """
    # Markdown only renders bold when the asterisks touch the text, so any
    # whitespace surrounding the predicted token is dropped first.
    word = str(predicted_word).strip()
    return sentence.replace("____", f"**{word}**")
# --- Page header ---
st.title("Fill Mask for Multiple Languages | Zabantu-XLM-Roberta")
st.write("This app predicts the missing word for sentences in Zulu, Tshivenda, Sepedi, Tswana, and Tsonga using a Zabantu BERT model.")
st.write("")  # empty write used as a vertical spacer

col1, col2 = st.columns(2)

# Initialised up front so the sections below can test these names directly,
# whether or not the Submit button was pressed during this run.
user_masked_sentence = None
predictions = None

# --- Left column: user input ---
with col1:
    user_sentence = st.text_area(
        "Enter your own sentence with a masked word (use '____'):",
        "\n".join(
            f"'{lang}': '{sentence}'," for lang, sentence in sample_sentences.items()
        ),
    )
    if st.button("Submit"):
        # Swap the user-facing placeholder for the tokenizer's mask token.
        user_masked_sentence = user_sentence.replace('____', unmasker.tokenizer.mask_token)

# --- Right column: run the model ---
with col2:
    if user_masked_sentence:
        if unmasker.tokenizer.mask_token not in user_masked_sentence:
            # The fill-mask pipeline rejects input that contains no mask
            # token, so tell the user instead of letting the call fail.
            st.warning("Please include '____' in your sentence to mark the word to predict.")
        else:
            user_predictions = unmasker(user_masked_sentence)
            if user_predictions:
                st.write("### Predictions for Sample Sentences:")
                predictions = fill_mask_for_languages(sample_sentences)
                # Raw dump of the pipeline output for all sample sentences.
                st.write(f"{predictions}")

# --- Per-language summary: top prediction inserted into each sample ---
if predictions:
    for language, language_predictions in predictions.items():
        original_sentence = sample_sentences[language]
        # Each sample has exactly one mask, so index 0 is the first
        # candidate the pipeline returned for that sentence.
        predicted_sentence = replace_mask(original_sentence, language_predictions[0]['token_str'])
        st.write(f"{language}: {predicted_sentence}\n")
# Custom stylesheet injected into the page as a <style> element.
# Declared as a raw string because the escaped Tailwind-style selectors
# (e.g. `.hover\:bg-orange-50`) contain `\:`, which is an invalid escape
# sequence in a normal string literal; the runtime value is unchanged.
# NOTE(review): the `.gr-button-primary` and `*-orange-*` selectors look like
# they were written for a different UI toolkit — confirm they match anything
# in the rendered Streamlit page.
css = r"""
<style>
footer {display:none !important;}
.gr-button-primary {
z-index: 14;
height: 43px;
width: 130px;
left: 0px;
top: 0px;
padding: 0px;
cursor: pointer !important;
background: none rgb(17, 20, 45) !important;
border: none !important;
text-align: center !important;
font-family: Poppins !important;
font-size: 14px !important;
font-weight: 500 !important;
color: rgb(255, 255, 255) !important;
line-height: 1 !important;
border-radius: 12px !important;
transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
box-shadow: none !important;
}
.gr-button-primary:hover{
z-index: 14;
height: 43px;
width: 130px;
left: 0px;
top: 0px;
padding: 0px;
cursor: pointer !important;
background: none rgb(66, 133, 244) !important;
border: none !important;
text-align: center !important;
font-family: Poppins !important;
font-size: 14px !important;
font-weight: 500 !important;
color: rgb(255, 255, 255) !important;
line-height: 1 !important;
border-radius: 12px !important;
transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
box-shadow: rgb(0 0 0 / 23%) 0px 1px 7px 0px !important;
}
.hover\:bg-orange-50:hover {
--tw-bg-opacity: 1 !important;
background-color: rgb(229,225,255) !important;
}
.to-orange-200 {
--tw-gradient-to: rgb(37 56 133 / 37%) !important;
}
.from-orange-400 {
--tw-gradient-from: rgb(17, 20, 45) !important;
--tw-gradient-to: rgb(255 150 51 / 0);
--tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to) !important;
}
.group-hover\:from-orange-500{
--tw-gradient-from:rgb(17, 20, 45) !important;
--tw-gradient-to: rgb(37 56 133 / 37%);
--tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to) !important;
}
.group:hover .group-hover\:text-orange-500{
--tw-text-opacity: 1 !important;
color:rgb(37 56 133 / var(--tw-text-opacity)) !important;
}
.container {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 5px;
width: 100%;
}
.bar {
width: 70%;
background-color: #e6e6e6;
border-radius: 12px;
overflow: hidden;
margin-right: 10px;
height: 5px;
}
.bar-fill {
background-color: #17152e;
height: 100%;
border-radius: 12px;
}
</style>
"""
# Emit the stylesheet held in `css` through st.markdown.
# NOTE(review): unsafe_allow_html=True is presumably required so the <style>
# markup is passed through as HTML rather than escaped — confirm against the
# Streamlit documentation.
st.markdown(css, unsafe_allow_html=True)