File size: 3,660 Bytes
f5f8f9a
 
 
9ba3728
e255bad
f5f8f9a
9ba3728
f5f8f9a
7a621b0
 
 
 
 
f5f8f9a
 
9ba3728
f5f8f9a
 
 
 
 
9ba3728
f5f8f9a
 
9ba3728
7a621b0
 
 
9ba3728
e255bad
f5f8f9a
 
9ba3728
aea2baa
9ba3728
 
d104ff1
9ba3728
f5f8f9a
9ba3728
d104ff1
 
9ba3728
d104ff1
f5f8f9a
d104ff1
 
 
9ba3728
 
 
 
 
 
 
 
d104ff1
9ba3728
 
 
d104ff1
9ba3728
d104ff1
 
 
f5f8f9a
9ba3728
f5f8f9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ba3728
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import streamlit as st
from transformers import pipeline

# Initialize the pipeline for masked language model.
# Streamlit re-executes this script from top to bottom on every widget
# interaction, so the loader is wrapped in st.cache_resource: the pipeline is
# built once and the same object is reused by later reruns.
@st.cache_resource
def _load_unmasker():
    """Build the fill-mask pipeline for the dsfsi/zabantu-xlm-roberta model."""
    return pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')


unmasker = _load_unmasker()

# Example sentences keyed by language name; '____' marks the word to predict.
# Insertion order is the order in which the app lists them.
sample_sentences = {
    "Zulu": "Le ndoda ithi izo____ ukudla.",
    "Tshivenda": "Mufana uyo____ vhukuma.",
    "Sepedi": "Mosadi o ____ pheka.",
    "Tswana": "Monna o ____ tsamaya.",
    "Tsonga": "N'wana wa xisati u ____ ku tsaka.",
}

# Function to fill mask for each language
def fill_mask_for_languages(sentences):
    """Run the fill-mask pipeline on every sentence in *sentences*.

    Each '____' placeholder is swapped for the tokenizer's mask token before
    the sentence is handed to the module-level ``unmasker``.

    Args:
        sentences: Mapping of language name to a sentence containing '____'.

    Returns:
        A dict mapping each language name to whatever ``unmasker`` returned
        for that language's sentence.
    """
    return {
        lang: unmasker(text.replace('____', unmasker.tokenizer.mask_token))
        for lang, text in sentences.items()
    }

# Function to replace the mask placeholder with the predicted word
def replace_mask(sentence, predicted_word):
    """Return *sentence* with every '____' placeholder set to *predicted_word*.

    Args:
        sentence: Text that may contain one or more '____' placeholders.
        predicted_word: Text substituted for each placeholder.

    Returns:
        The substituted sentence; unchanged when no placeholder is present.
    """
    segments = sentence.split("____")
    return predicted_word.join(segments)

# Streamlit app
def _top_filled_sentence(masked_sentence, predictions, mask_token):
    """Return the sentence completed with the top prediction for each mask.

    The fill-mask pipeline returns a flat list of candidate dicts when the
    input holds a single mask, and a list of such lists (one per mask, in
    order of appearance) when it holds several. Both shapes are handled here.

    Args:
        masked_sentence: The sentence as passed to the pipeline, i.e. already
            containing ``mask_token``.
        predictions: Non-empty result returned by the pipeline for
            ``masked_sentence``.
        mask_token: The tokenizer's mask token string.

    Returns:
        The completed sentence as a string.
    """
    if isinstance(predictions[0], dict):
        # Single mask: the pipeline already provides the completed sequence.
        return predictions[0]['sequence']

    # Several masks: each per-mask 'sequence' still contains the other masks,
    # so fill the masks one at a time with each mask's best token instead.
    filled = masked_sentence
    for candidates in predictions:
        filled = filled.replace(mask_token, candidates[0]['token_str'], 1)
    return filled


st.title("Fill Mask for Multiple Languages | Zabantu-XLM-Roberta")
st.write("This app predicts the missing word for sentences in Zulu, Tshivenda, Sepedi, Tswana, and Tsonga using a Zabantu BERT model.")

# Get user input. The default value lists every sample sentence, so it
# contains one '____' placeholder per language (several masks at once).
default_input = "\n".join(
    f"'{lang}': '{sentence}'," for lang, sentence in sample_sentences.items()
)
user_sentence = st.text_input(
    "Enter your own sentence with a masked word (use '____'):",
    default_input,
)

# When user submits the input sentence
if st.button("Submit"):
    # Replace the placeholder with the actual mask token
    mask_token = unmasker.tokenizer.mask_token
    user_masked_sentence = user_sentence.replace('____', mask_token)

    st.write("### Your Input:")
    st.write(f"Original sentence: {user_sentence}")

    if mask_token not in user_masked_sentence:
        # Without a mask the pipeline has nothing to predict and raises;
        # tell the user what to do instead of crashing the app.
        st.warning("No masked word found. Mark the missing word with '____' and submit again.")
    else:
        # Get predictions for the user's sentence
        user_predictions = unmasker(user_masked_sentence)

        # Show the raw predictions so their structure can be inspected
        st.write(user_predictions)

        # Display the top prediction for the masked token(s)
        if user_predictions:
            top_sentence = _top_filled_sentence(
                user_masked_sentence, user_predictions, mask_token
            )
            st.write(f"Top prediction for the masked token: {top_sentence}")

    # Predictions for sample sentences
    st.write("### Predictions for Sample Sentences:")
    predictions = fill_mask_for_languages(sample_sentences)

    for language, language_predictions in predictions.items():
        original_sentence = sample_sentences[language]
        # Each sample holds exactly one mask, so the result is a flat list;
        # use token_str of the best candidate to fill the placeholder.
        predicted_sentence = replace_mask(original_sentence, language_predictions[0]['token_str'])
        st.write(f"Original sentence ({language}): {original_sentence}")
        st.write(f"Top prediction for the masked token: {predicted_sentence}\n")
        st.write("=" * 80)

# Custom CSS for styling.
# The stylesheet below hides the page footer, restyles buttons (including a
# hover colour), text inputs / text areas (including a focus highlight),
# paragraph text inside markdown containers, and the app container's padding
# and font family.
# NOTE(review): 'Poppins' is not loaded anywhere in this stylesheet (no
# @import / @font-face) — presumably it falls back to sans-serif unless the
# font is available on the client; confirm.
css = """
<style>
footer {display:none !important}

.stButton > button {
    background-color: #17152e;
    color: white;
    border: none;
    padding: 0.75em 2em;
    text-align: center;
    text-decoration: none;
    display: inline-block;
    font-size: 16px;
    margin: 4px 2px;
    cursor: pointer;
    border-radius: 12px;
    transition: background-color 0.3s ease;
}

.stButton > button:hover {
    background-color: #3c4a6b;
}

.stTextInput, .stTextArea {
    border: 1px solid #e6e6e6;
    padding: 0.75em;
    border-radius: 10px;
    font-size: 16px;
    width: 100%;
}

.stTextInput:focus, .stTextArea:focus {
    border-color: #17152e;
    outline: none;
    box-shadow: 0px 0px 5px rgba(23, 21, 46, 0.5);
}

div[data-testid="stMarkdownContainer"] p {
    font-size: 16px;
}

.stApp {
    padding: 2em;
    font-family: 'Poppins', sans-serif;
}
</style>
"""
# unsafe_allow_html=True is passed so the <style> element is emitted as HTML.
st.markdown(css, unsafe_allow_html=True)