File size: 4,774 Bytes
f5f8f9a
 
 
e255bad
f5f8f9a
4d5270d
f5f8f9a
5b5eeef
 
4d5270d
f5f8f9a
4d5270d
5b5eeef
4d5270d
 
 
5b5eeef
4d5270d
 
 
f5f8f9a
7a621b0
b92b795
7a621b0
4d5270d
36c8dd5
4d5270d
b51c864
d104ff1
5b5eeef
 
 
4d5270d
 
 
5b5eeef
 
 
e2a34c5
e28273f
5b5eeef
 
4d5270d
e28273f
9f4d0d7
fa59459
 
e28273f
 
4eb39ad
5b5eeef
fa59459
5b5eeef
 
 
4d5270d
 
e28273f
5b5eeef
 
91470f6
 
 
 
 
5b5eeef
 
 
 
 
 
4d5270d
 
 
d104ff1
9f4d0d7
5fb1b33
e28273f
 
 
 
 
422acf2
9ba3728
b51c864
e28273f
9f4d0d7
5b5eeef
 
 
4eb39ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91470f6
 
 
 
 
 
5b5eeef
f5f8f9a
 
b51c864
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5f8f9a
 
 
91470f6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import streamlit as st
from transformers import pipeline

# Page configuration is issued before any other Streamlit call (including the
# cached loader below) because some Streamlit versions require
# `set_page_config` to be the first Streamlit command of the script.
st.set_page_config(layout="wide")


@st.cache_resource
def _load_unmasker():
    """Build the fill-mask pipeline once and reuse it across script reruns.

    Streamlit re-executes this script on every widget interaction; without
    caching, the model would be re-instantiated each time.
    """
    return pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')


# Shared fill-mask pipeline used by `fill_mask`.
unmasker = _load_unmasker()

# NOTE(review): `stop_rerun` does not appear to be a documented Streamlit
# setting; this only sets an attribute on the imported module. Kept as-is so
# module-level behaviour is unchanged -- confirm and remove if unused.
st.stop_rerun = True

def fill_mask(sentences):
    """Run the fill-mask model on every sentence that contains ``<mask>``.

    Args:
        sentences: Mapping of an arbitrary key to either a
            ``(language, sentence)`` tuple or a plain sentence string. For a
            plain string the mapping key is used as the language label (this
            is the shape of the built-in example sentences).

    Returns:
        A ``(results, warnings)`` tuple. ``results`` maps each key whose
        sentence contained ``<mask>`` to ``(language, predictions)``, where
        ``predictions`` is whatever the module-level ``unmasker`` pipeline
        returned for that sentence. ``warnings`` is a list of messages, one
        per sentence that had no ``<mask>`` token.
    """
    results = {}
    warnings = []
    for key, entry in sentences.items():
        # Accept both {key: (language, sentence)} and {language: sentence};
        # unpacking a bare string as a 2-tuple would raise ValueError.
        if isinstance(entry, str):
            language, sentence = key, entry
        else:
            language, sentence = entry

        if "<mask>" not in sentence:
            warnings.append(f"Warning: No <mask> token found in sentence: {sentence}")
            continue

        # The UI placeholder is always "<mask>"; swap in the token the loaded
        # tokenizer actually uses before calling the pipeline.
        masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
        results[key] = (language, unmasker(masked_sentence))
    return results, warnings

def replace_mask(sentence, predicted_word):
    """Return *sentence* with every ``<mask>`` swapped for the bolded word.

    The word is wrapped in ``**`` so it renders as bold when the result is
    displayed as Markdown.
    """
    highlighted = f"**{predicted_word}**"
    pieces = sentence.split("<mask>")
    return highlighted.join(pieces)

# ---------------------------------------------------------------------------
# Page body: inputs on the left, prediction bars on the right.
# ---------------------------------------------------------------------------
st.title("Fill Mask | Zabantu-XLM-Roberta")
st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages.")

col1, col2 = st.columns(2)

# Session-state defaults. 'sentences' keeps the sentences that produced the
# current 'result', so the output can still show them after the input widgets
# have been edited.
for state_key, default_value in (
    ('text_input', ""),
    ('warnings', []),
    ('result', {}),
    ('sentences', {}),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = default_value

language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']

# Single definition of the example sentences, used both by the
# "Test Example" button and by the example shown below the inputs.
sample_sentences = {
    'zulu': "Le ndoda ithi izo <mask> ukudla.",
    'tshivenda': "Mufana uyo <mask> vhukuma.",
    'sepedi': "Mosadi o <mask> pheka.",
    'tswana': "Monna o <mask> tsamaya.",
    'tsonga': "N'wana wa xisati u <mask> ku tsaka."
}

# key -> (language, sentence) for every non-empty input row of this run.
input_sentences = {}

with col1:
    with st.container():
        st.markdown("Input :clipboard:")

        input1, input2 = st.columns(2)

        for i in range(5):
            with input1:
                language = st.selectbox(f"Select language for sentence {i+1}:", language_options, key=f'language_{i}', index=0)
            with input2:
                sentence = st.text_input(f"Enter sentence for {language} (with <mask>):", key=f'text_input_{i}')
                if sentence:
                    input_sentences[f'{language.lower()}_{i+1}'] = (language.lower(), sentence)

        # Place the two buttons side by side in the columns created for them.
        button1, button2, _ = st.columns([2, 2, 4])
        with button1:
            example_clicked = st.button("Test Example")
        with button2:
            submit_clicked = st.button("Submit")

        submitted = None
        if example_clicked:
            # fill_mask expects (language, sentence) pairs; passing the bare
            # strings would fail when it unpacks each value.
            submitted = {
                sample_language: (sample_language, sample_text)
                for sample_language, sample_text in sample_sentences.items()
            }
        elif submit_clicked:
            submitted = input_sentences

        if submitted is not None:
            st.session_state['result'], st.session_state['warnings'] = fill_mask(submitted)
            # Snapshot the sentences so the output does not depend on the
            # current (possibly edited) widget values.
            st.session_state['sentences'] = {
                entry_key: entry[1] for entry_key, entry in submitted.items()
            }

        if st.session_state['warnings']:
            for warning in st.session_state['warnings']:
                st.warning(warning)

        st.markdown("Example")
        st.code(sample_sentences, wrap_lines=True)

with col2:
    with st.container():
        st.markdown("Output :bar_chart:")
        if st.session_state['result']:
            for key, (language, predictions) in st.session_state['result'].items():
                original_sentence = st.session_state['sentences'].get(key, "")

                # With more than one mask the pipeline returns one prediction
                # list per mask (a list of lists); the single-bar display
                # below cannot represent that, so report it instead of failing.
                if predictions and isinstance(predictions[0], list):
                    st.warning(f"Only one <mask> per sentence is supported: {original_sentence}")
                    continue

                if predictions:
                    top_prediction = predictions[0]
                    predicted_word = top_prediction['token_str']
                    score = top_prediction['score'] * 100

                    st.markdown(f"""
                    <div class="bar">
                        <div class="bar-fill" style="width: {score}%;"></div>
                    </div>
                    <div class="container">
                        <div style="align-items: left;">{predicted_word} ({language})</div>
                        <div style="align-items: right;">{score:.2f}%</div>
                    </div>
                    """, unsafe_allow_html=True)

                    # Show the sentence with the top prediction filled in.
                    if original_sentence:
                        predicted_sentence = replace_mask(original_sentence, predicted_word)
                        st.write(f"{language}: {predicted_sentence}\n")

# Stylesheet for the page: hides the footer element and defines the
# `.container`, `.bar` and `.bar-fill` classes used by the HTML snippets that
# the output column renders for each prediction.
css = """
<style>
footer {display:none !important;}

.container {
    display: flex;
    justify-content: space-between;
    align-items: center;
    margin-bottom: 5px;
    width: 100%;
}
.bar {
    background-color: #e6e6e6;
    border-radius: 12px;
    overflow: hidden;
    margin-right: 10px;
    height: 5px;
}
.bar-fill {
    background-color: #17152e;
    height: 100%;
    border-radius: 12px;
}
</style>
"""
# unsafe_allow_html is needed so the <style> tag is passed through as HTML
# rather than shown as text.
st.markdown(css, unsafe_allow_html=True)