File size: 2,906 Bytes
f5f8f9a
 
 
 
 
 
7a621b0
 
 
 
 
f5f8f9a
 
 
 
 
 
 
 
 
 
 
 
7a621b0
 
 
f5f8f9a
 
 
aea2baa
 
 
 
d104ff1
f5f8f9a
d104ff1
 
 
f5f8f9a
d104ff1
 
 
 
 
 
 
7a621b0
 
d104ff1
 
 
f5f8f9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a621b0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import streamlit as st
from transformers import pipeline

@st.cache_resource
def _load_unmasker():
    """Build the fill-mask pipeline once and reuse it across script reruns.

    Streamlit re-executes this script from the top on every widget
    interaction; caching the pipeline as a resource avoids constructing the
    model again on each rerun.

    Returns:
        The ``fill-mask`` pipeline for ``dsfsi/zabantu-bantu-250m``.
    """
    return pipeline('fill-mask', model='dsfsi/zabantu-bantu-250m')


# Module-level handle used by the rest of the script.
unmasker = _load_unmasker()

# Example sentences keyed by language name; '____' marks the position of the
# word that is to be predicted.
sample_sentences = {
    "Zulu": "Le ndoda ithi izo____ ukudla.",
    "Tshivenda": "Mufana uyo____ vhukuma.",
    "Sepedi": "Mosadi o ____ pheka.",
    "Tswana": "Monna o ____ tsamaya.",
    "Tsonga": "N'wana wa xisati u ____ ku tsaka.",
}

def fill_mask_for_languages(sentences):
    """Run the fill-mask pipeline over every sentence in a language mapping.

    Args:
        sentences: Mapping of language name to a sentence in which the word
            to predict is written as ``'____'``.

    Returns:
        A dict with the same keys as ``sentences``. Each value is whatever
        the module-level ``unmasker`` returns for that sentence once
        ``'____'`` has been replaced by the tokenizer's mask token.
    """
    return {
        lang: unmasker(text.replace('____', unmasker.tokenizer.mask_token))
        for lang, text in sentences.items()
    }

def replace_mask(sentence: str, predicted_word: str) -> str:
    """Return ``sentence`` with every ``'____'`` placeholder filled in.

    Args:
        sentence: Text containing the ``'____'`` placeholder.
        predicted_word: Text substituted for each placeholder occurrence.

    Returns:
        The sentence with all placeholders replaced; the input unchanged
        when it contains no placeholder.
    """
    placeholder = "____"
    filled = sentence.replace(placeholder, predicted_word)
    return filled

st.title("Fill Mask for Multiple Languages | Zabantu-Bantu-250m")
st.write("This app predicts the missing word for sentences in Zulu, Tshivenda, Sepedi, Tswana, and Tsonga using a Zabantu BERT model.")

# Pre-fill the box with a single sample sentence so the default input holds
# exactly one '____' placeholder and can be submitted as-is.
user_sentence = st.text_input(
    "Enter your own sentence with a masked word (use '____'):",
    next(iter(sample_sentences.values())),
)

if st.button("Submit"):
    st.write("### Your Input:")
    st.write(f"Original sentence: {user_sentence}")

    # The fill-mask pipeline raises when the text has no mask token and
    # returns one result list per mask when there are several, so only the
    # single-mask case can be indexed as predictions[0]['sequence'].
    if user_sentence.count('____') == 1:
        user_masked_sentence = user_sentence.replace('____', unmasker.tokenizer.mask_token)
        user_predictions = unmasker(user_masked_sentence)
        st.write(f"Top prediction for the masked token: {user_predictions[0]['sequence']}")
    else:
        st.warning("Please enter a sentence containing exactly one '____' placeholder.")

    st.write("### Predictions for Sample Sentences:")
    for language, predictions in fill_mask_for_languages(sample_sentences).items():
        original_sentence = sample_sentences[language]
        # 'token_str' is the predicted word itself; substituting it into the
        # original sentence keeps the sample's own spacing and punctuation.
        top_word = predictions[0]['token_str'].strip()
        predicted_sentence = replace_mask(original_sentence, top_word)

        st.write(f"Original sentence ({language}): {original_sentence}")
        st.write(f"Top prediction for the masked token: {predicted_sentence}\n")
        st.write("=" * 80)

# Custom page styling, kept as one raw <style> element. It hides the footer,
# restyles buttons (including their hover colour) and text inputs/areas, and
# sets the paragraph font size plus the app padding and font family.
# NOTE(review): 'Poppins' is referenced but nothing visible here loads it, so
# the sans-serif fallback presumably applies unless it is provided elsewhere.
css = """
<style>
footer {display:none !important}

.stButton > button {
    background-color: #17152e;
    color: white;
    border: none;
    padding: 0.75em 2em;
    text-align: center;
    text-decoration: none;
    display: inline-block;
    font-size: 16px;
    margin: 4px 2px;
    cursor: pointer;
    border-radius: 12px;
    transition: background-color 0.3s ease;
}

.stButton > button:hover {
    background-color: #3c4a6b;
}

.stTextInput, .stTextArea {
    border: 1px solid #e6e6e6;
    padding: 0.75em;
    border-radius: 10px;
    font-size: 16px;
    width: 100%;
}

.stTextInput:focus, .stTextArea:focus {
    border-color: #17152e;
    outline: none;
    box-shadow: 0px 0px 5px rgba(23, 21, 46, 0.5);
}

div[data-testid="stMarkdownContainer"] p {
    font-size: 16px;
}

.stApp {
    padding: 2em;
    font-family: 'Poppins', sans-serif;
}
</style>
"""

# unsafe_allow_html=True is passed so the <style> markup above is emitted as
# HTML instead of being shown as text.
st.markdown(css, unsafe_allow_html=True)