import streamlit as st
from transformers import pipeline

# Configure the page layout before any other Streamlit calls render content.
st.set_page_config(layout="wide")

@st.cache_resource
def load_unmasker():
    # Cache the fill-mask pipeline so the model is loaded only once per server session.
    return pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')

unmasker = load_unmasker()

def fill_mask(sentences):
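    # Run the fill-mask pipeline on each sentence that contains a <mask> placeholder;
    # sentences without one are collected as warnings instead of being sent to the model.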
    results = {}
    warnings = []
    for language, sentence in sentences.items():
        if "<mask>" in sentence:
            masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
            unmasked = unmasker(masked_sentence)
            results[language] = unmasked
        else:
            warnings.append(f"Warning: No <mask> token found in sentence: {sentence}")
    return results, warnings

def replace_mask(sentence, predicted_word):
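    # Substitute the predicted word (bolded for Markdown display) for the <mask> placeholder.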
    return sentence.replace("<mask>", f"**{predicted_word}**")

st.title("Fill Mask | Zabantu-XLM-Roberta")
st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages.")

col1, col2 = st.columns(2)

if 'text_input' not in st.session_state:
    st.session_state['text_input'] = ""

if 'warnings' not in st.session_state:
    st.session_state['warnings'] = []

language_options = ['Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']

with col1:
    with st.container():
        st.markdown("### Input :clipboard:")
        
        input_sentences = {}
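        # Render a language selector and a text area for each of the five input slots.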
        for i in range(5):
            language = st.selectbox(f"Select language for sentence {i+1}:", language_options, key=f'language_{i}')
            sentence = st.text_area(f"Enter sentence for {language} (with <mask>):", key=f'text_input_{i}')
            if sentence:
                input_sentences[language.lower()] = sentence

        button1, button2, _ = st.columns([2, 2, 4])
        with button1:
            if st.button("Test Example"):
                sample_sentence = {
                    'zulu': "Le ndoda ithi izo <mask> ukudla.",
                    'tshivenda': "Mufana uyo <mask> vhukuma.",
                    'sepedi': "Mosadi o <mask> pheka.",
                    'tswana': "Monna o <mask> tsamaya.",
                    'tsonga': "N'wana wa xisati u <mask> ku tsaka."
                }
                input_sentences = sample_sentence
                result, warnings = fill_mask(input_sentences)
                st.session_state['warnings'] = warnings
        
        with button2:
            if st.button("Submit"):
                result, warnings = fill_mask(input_sentences)
                st.session_state['warnings'] = warnings

        if st.session_state['warnings']:
            for warning in st.session_state['warnings']:
                st.warning(warning)

        st.markdown("### Example")
        st.code({
            'zulu': "Le ndoda ithi izo <mask> ukudla.",
            'tshivenda': "Mufana uyo <mask> vhukuma.",
            'sepedi': "Mosadi o <mask> pheka.",
            'tswana': "Monna o <mask> tsamaya.",
            'tsonga': "N'wana wa xisati u <mask> ku tsaka."
        }, wrap_lines=True)

with col2:
    with st.container():
        st.markdown("### Output :bar_chart:")
        if input_sentences:
            for language, sentence in input_sentences.items():
                # Skip sentences without a <mask> token; the pipeline raises an error for them.
                if "<mask>" not in sentence:
                    continue
                masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
                predictions = unmasker(masked_sentence)

                if predictions:
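                    # Show the highest-scoring prediction as a labelled confidence bar.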
                    top_prediction = predictions[0]
                    predicted_word = top_prediction['token_str']
                    score = top_prediction['score'] * 100

                    st.markdown(f"""
                    <div class="bar">
                        <div class="bar-fill" style="width: {score}%;"></div>
                    </div>
                    <div class="container">
                        <div style="align-items: left;">{predicted_word} ({language})</div>
                        <div style="align-items: right;">{score:.2f}%</div>
                    </div>
                    """, unsafe_allow_html=True)

# Show each input sentence with the top predicted word substituted in, once inference has run.
if 'result' in locals() and result:
    for language, language_predictions in result.items():
        original_sentence = input_sentences[language]
        predicted_sentence = replace_mask(original_sentence, language_predictions[0]['token_str'])
        st.write(f"{language}: {predicted_sentence}\n")

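# Custom CSS for the confidence bars and to hide the default Streamlit footer.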
css = """
<style>
footer {display:none !important;}

.container {
    display: flex;
    justify-content: space-between;
    align-items: center;
    margin-bottom: 5px;
    width: 100%;
}
.bar {
    background-color: #e6e6e6;
    border-radius: 12px;
    overflow: hidden;
    margin-right: 10px;
    height: 5px;
}
.bar-fill {
    background-color: #17152e;
    height: 100%;
    border-radius: 12px;
}
</style>
"""
st.markdown(css, unsafe_allow_html=True)