File size: 3,660 Bytes
f5f8f9a 9ba3728 e255bad f5f8f9a 9ba3728 f5f8f9a 7a621b0 f5f8f9a 9ba3728 f5f8f9a 9ba3728 f5f8f9a 9ba3728 7a621b0 9ba3728 e255bad f5f8f9a 9ba3728 aea2baa 9ba3728 d104ff1 9ba3728 f5f8f9a 9ba3728 d104ff1 9ba3728 d104ff1 f5f8f9a d104ff1 9ba3728 d104ff1 9ba3728 d104ff1 9ba3728 d104ff1 f5f8f9a 9ba3728 f5f8f9a 9ba3728 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import streamlit as st
from transformers import pipeline
# Initialize the pipeline for masked language model.
# Streamlit re-executes this whole script on every widget interaction, so the
# pipeline is built inside an st.cache_resource-decorated loader: the model is
# loaded once and the same pipeline object is reused on later reruns instead
# of being re-instantiated each time.
@st.cache_resource
def _load_unmasker(model_name='dsfsi/zabantu-xlm-roberta'):
    """Build the Hugging Face ``fill-mask`` pipeline for *model_name*."""
    return pipeline('fill-mask', model=model_name)


unmasker = _load_unmasker()

# Sample sentences with masked words in various languages. Each sentence holds
# exactly one '____' placeholder, which is swapped for the tokenizer's mask
# token before inference.
sample_sentences = {
    'Zulu': "Le ndoda ithi izo____ ukudla.",
    'Tshivenda': "Mufana uyo____ vhukuma.",
    'Sepedi': "Mosadi o ____ pheka.",
    'Tswana': "Monna o ____ tsamaya.",
    'Tsonga': "N'wana wa xisati u ____ ku tsaka."
}
# Run the model over a whole mapping of sample sentences.
def fill_mask_for_languages(sentences):
    """Predict the masked word for every sentence in *sentences*.

    Args:
        sentences: mapping of language name -> sentence that marks the
            missing word with the ``____`` placeholder.

    Returns:
        dict with the same keys, each mapped to the raw output of
        ``unmasker`` for that sentence after the placeholder has been
        replaced by the tokenizer's mask token.
    """
    return {
        lang: unmasker(text.replace('____', unmasker.tokenizer.mask_token))
        for lang, text in sentences.items()
    }
# Substitute a predicted word back into the original (placeholder) sentence.
def replace_mask(sentence, predicted_word):
    """Return *sentence* with every ``____`` placeholder set to *predicted_word*."""
    fragments = sentence.split("____")
    return predicted_word.join(fragments)
# Streamlit app
st.title("Fill Mask for Multiple Languages | Zabantu-XLM-Roberta")
st.write("This app predicts the missing word for sentences in Zulu, Tshivenda, Sepedi, Tswana, and Tsonga using a Zabantu BERT model.")
# Get user input. The default is a single sample sentence so it contains
# exactly one '____' placeholder. The previous default joined every sample
# (five placeholders); the pipeline then returns one candidate list per mask,
# and the single-mask display in the submit handler failed on it.
# next(iter(...)) avoids hard-coding a language key.
user_sentence = st.text_input(
    "Enter your own sentence with a masked word (use '____'):",
    next(iter(sample_sentences.values())),
)
# When user submits the input sentence
if st.button("Submit"):
    st.write("### Your Input:")
    st.write(f"Original sentence: {user_sentence}")
    if '____' not in user_sentence:
        # The fill-mask pipeline raises when its input contains no mask token,
        # so report the missing placeholder instead of crashing the app.
        st.warning("Please mark the missing word with '____'.")
    else:
        # Replace the placeholder with the actual mask token
        user_masked_sentence = user_sentence.replace('____', unmasker.tokenizer.mask_token)
        # Get predictions for the user's sentence
        user_predictions = unmasker(user_masked_sentence)
        # With one mask the pipeline returns a flat list of candidate dicts;
        # with several masks it returns one such list per mask. Normalise to
        # the nested form so both shapes share the display code below.
        if user_predictions and isinstance(user_predictions[0], dict):
            per_mask = [user_predictions]
        else:
            per_mask = user_predictions
        # Raw pipeline output stays available for inspection, but collapsed.
        with st.expander("Raw model output"):
            st.write(user_predictions)
        if len(per_mask) == 1 and per_mask[0]:
            # 'sequence' is the input text with the mask filled by this candidate.
            st.write(f"Top prediction for the masked token: {per_mask[0][0]['sequence']}")
        elif per_mask:
            # Each mask is scored independently; splice every mask's top
            # candidate back into the user's text, left to right.
            pieces = user_sentence.split('____')
            filled = pieces[0]
            for candidates, tail in zip(per_mask, pieces[1:]):
                filled += (candidates[0]['token_str'] if candidates else '____') + tail
            st.write(f"Top prediction for each masked token: {filled}")
# Show the model's best guess for each built-in sample sentence.
st.write("### Predictions for Sample Sentences:")
predictions = fill_mask_for_languages(sample_sentences)
for lang, candidates in predictions.items():
    source_text = sample_sentences[lang]
    # candidates[0] is the highest-scoring entry; token_str is just the
    # predicted word, which is spliced back into the placeholder sentence.
    best_token = candidates[0]['token_str']
    filled_text = replace_mask(source_text, best_token)
    st.write(f"Original sentence ({lang}): {source_text}")
    st.write(f"Top prediction for the masked token: {filled_text}\n")
    st.write("=" * 80)
# Custom CSS for styling. The <style> block below is injected as raw HTML via
# st.markdown(..., unsafe_allow_html=True). It hides the page footer and
# restyles buttons (including hover), text inputs/areas, markdown paragraphs
# and the .stApp container.
# NOTE(review): .stTextInput/.stTextArea look like Streamlit wrapper elements
# rather than the focusable <input>/<textarea> themselves, so the :focus rule
# may never match — confirm against the rendered DOM.
# NOTE(review): 'Poppins' is not loaded anywhere in this file; unless the
# browser already has it, the sans-serif fallback is used.
css = """
<style>
footer {display:none !important}
.stButton > button {
background-color: #17152e;
color: white;
border: none;
padding: 0.75em 2em;
text-align: center;
text-decoration: none;
display: inline-block;
font-size: 16px;
margin: 4px 2px;
cursor: pointer;
border-radius: 12px;
transition: background-color 0.3s ease;
}
.stButton > button:hover {
background-color: #3c4a6b;
}
.stTextInput, .stTextArea {
border: 1px solid #e6e6e6;
padding: 0.75em;
border-radius: 10px;
font-size: 16px;
width: 100%;
}
.stTextInput:focus, .stTextArea:focus {
border-color: #17152e;
outline: none;
box-shadow: 0px 0px 5px rgba(23, 21, 46, 0.5);
}
div[data-testid="stMarkdownContainer"] p {
font-size: 16px;
}
.stApp {
padding: 2em;
font-family: 'Poppins', sans-serif;
}
</style>
"""
# Inject the stylesheet; unsafe_allow_html is required for the raw <style> tag.
st.markdown(css, unsafe_allow_html=True)
|