File size: 4,611 Bytes
f5f8f9a
 
 
e255bad
f5f8f9a
4d5270d
f5f8f9a
4d5270d
f5f8f9a
4d5270d
e2a34c5
4d5270d
 
 
e2a34c5
4d5270d
 
 
f5f8f9a
7a621b0
b92b795
7a621b0
4d5270d
36c8dd5
4d5270d
b51c864
d104ff1
4d5270d
 
 
 
 
 
8091567
 
 
e2a34c5
e28273f
8091567
 
4d5270d
e28273f
9f4d0d7
fa59459
 
e28273f
 
fa59459
 
 
e2a34c5
 
 
 
4d5270d
 
8091567
 
 
 
 
 
 
 
 
 
 
 
 
e28273f
4d5270d
 
 
d104ff1
9f4d0d7
5fb1b33
e28273f
 
 
 
 
422acf2
9ba3728
b51c864
e28273f
9f4d0d7
8091567
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9bd9b9
f5f8f9a
 
b51c864
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5f8f9a
 
 
ffffb96
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import streamlit as st
from transformers import pipeline

# Configure the page before anything else touches Streamlit: the cached loader
# below may render a spinner, and set_page_config is documented as needing to
# be the first Streamlit command on a page.
st.set_page_config(layout="wide")


@st.cache_resource
def _load_unmasker():
    """Build the fill-mask pipeline once and share it across script reruns."""
    return pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')


# Streamlit re-executes this script on every widget interaction; going through
# the cached loader avoids re-instantiating the model on each rerun.
unmasker = _load_unmasker()

def fill_mask(sentences):
    """Run the fill-mask pipeline on every sentence holding exactly one ``<mask>``.

    Args:
        sentences: Mapping of key -> ``(language, sentence)`` tuples.

    Returns:
        A ``(results, warnings)`` tuple. ``results`` maps each accepted key to
        ``(language, predictions)``, where ``predictions`` is the value the
        module-level ``unmasker`` returned for that sentence. ``warnings`` is
        a list of messages, one per skipped sentence.
    """
    results = {}
    warnings = []
    for key, (language, sentence) in sentences.items():
        mask_count = sentence.count("<mask>")
        if mask_count == 0:
            warnings.append(f"Warning: No <mask> token found in sentence: {sentence}")
            continue
        if mask_count > 1:
            # With several masks the pipeline returns one prediction list per
            # mask (a list of lists); consumers that read
            # predictions[0]['token_str'] would fail on that shape, so the
            # sentence is skipped with a warning instead.
            warnings.append(
                f"Warning: More than one <mask> token found in sentence: {sentence}"
            )
            continue
        # Swap the user-facing placeholder for the tokenizer's own mask token.
        masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
        results[key] = (language, unmasker(masked_sentence))
    return results, warnings

def replace_mask(sentence, predicted_word):
    """Return ``sentence`` with each ``<mask>`` replaced by the bolded prediction.

    The prediction is wrapped in ``**`` so it renders bold in Markdown.
    """
    bold_word = f"**{predicted_word}**"
    pieces = sentence.split("<mask>")
    return bold_word.join(pieces)

st.title("Fill Mask | Zabantu-XLM-Roberta")
st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages.")

col1, col2 = st.columns(2)

# Seed the session-state slots this script reads so later lookups never miss.
# 'sentences' holds the exact inputs that produced 'result', so the output
# column can rebuild each sentence even when the widgets hold something else.
for state_key, default_value in (
    ('text_input', ""),
    ('warnings', []),
    ('result', {}),
    ('sentences', {}),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = default_value

language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']

# Sentences currently typed into the widgets, keyed '<language>_<position>'.
input_sentences = {}

with col1:
    with st.container():
        st.markdown("Input :clipboard:")

        input1, input2 = st.columns(2)

        for i in range(5):
            with input1:
                language = st.selectbox(f"Select language for sentence {i+1}:", language_options, key=f'language_{i}')
            with input2:
                sentence = st.text_input(f"Enter sentence for {language} (with <mask>):", key=f'text_input_{i}')
                if sentence:
                    # Create a unique key for each sentence
                    input_sentences[f'{language.lower()}_{i+1}'] = (language.lower(), sentence)

        # Lay the two action buttons out side by side in the narrow columns.
        button1, button2, _ = st.columns([2, 2, 4])
        test_clicked = button1.button("Test Example")
        submit_clicked = button2.button("Submit")

        if test_clicked:
            sample_sentences = {
                'zulu_1': ('zulu', "Le ndoda ithi izo <mask> ukudla."),
                'tshivenda_2': ('tshivenda', "Vhana vhane vha kha ḓi bva u bebwa vha kha khombo ya u <mask> nga Listeriosis."),
                'tshivenda_3': ('tshivenda', "Rabulasi wa <mask> u khou bvelela nga u lima"),
                'tswana_4': ('tswana', "Monna o <mask> tsamaya."),
                'tsonga_5': ('tsonga', "N'wana wa xisati u <mask> ku tsaka.")
            }
            st.session_state['result'], st.session_state['warnings'] = fill_mask(sample_sentences)
            # Remember the inputs behind these results; the sample keys are
            # never present in input_sentences.
            st.session_state['sentences'] = sample_sentences

        if submit_clicked:
            st.session_state['result'], st.session_state['warnings'] = fill_mask(input_sentences)
            st.session_state['sentences'] = dict(input_sentences)

        if st.session_state['warnings']:
            for warning in st.session_state['warnings']:
                st.warning(warning)

        st.markdown("Example")
        st.code({
            'zulu': "Le ndoda ithi izo <mask> ukudla.",
            'tshivenda': "Mufana uyo <mask> vhukuma.",
            'sepedi': "Mosadi o <mask> pheka.",
            'tswana': "Monna o <mask> tsamaya.",
            'tsonga': "N'wana wa xisati u <mask> ku tsaka."
        }, wrap_lines=True)

with col2:
    with st.container():
        st.markdown("Output :bar_chart:")
        if st.session_state['result']:
            for key, (language, predictions) in st.session_state['result'].items():
                # Read the sentence stored with the results rather than the
                # live widget values: results persist across reruns and the
                # sample run uses keys the widgets never produce.
                original_sentence = st.session_state['sentences'][key][1]
                # Only the first prediction in the list is displayed.
                predicted_word = predictions[0]['token_str']
                score = predictions[0]['score'] * 100

                st.markdown(f"""
                <div class="bar">
                    <div class="bar-fill" style="width: {score}%;"></div>
                </div>
                <div class="container">
                    <div style="align-items: left;">{predicted_word} ({language})</div>
                    <div style="align-items: right;">{score:.2f}%</div>
                </div>
                """, unsafe_allow_html=True)

                predicted_sentence = replace_mask(original_sentence, predicted_word)
                st.write(f"{language}: {predicted_sentence}\n")

# Stylesheet injected into the page as raw HTML. It hides `footer` elements
# and defines the `.container`, `.bar` and `.bar-fill` classes referenced by
# the score-bar markup this script emits for each prediction.
css = """
<style>
footer {display:none !important;}

.container {
    display: flex;
    justify-content: space-between;
    align-items: center;
    margin-bottom: 5px;
    width: 100%;
}
.bar {
    background-color: #e6e6e6;
    border-radius: 12px;
    overflow: hidden;
    margin-right: 10px;
    height: 5px;
}
.bar-fill {
    background-color: #17152e;
    height: 100%;
    border-radius: 12px;
}
</style>
"""
# unsafe_allow_html is needed so the <style> tag is emitted as HTML rather
# than escaped text.
st.markdown(css, unsafe_allow_html=True)