File size: 4,874 Bytes
f5f8f9a
 
15eff6a
f5f8f9a
0802504
6cbe2dc
4d5270d
f5f8f9a
4d5270d
f5f8f9a
4d5270d
b854242
4d5270d
0ce811e
 
b854242
4d5270d
 
 
f5f8f9a
7a621b0
b92b795
7a621b0
d99c1e8
4d5270d
d99c1e8
15eff6a
d99c1e8
 
 
 
 
4d5270d
d99c1e8
b51c864
d104ff1
d99c1e8
 
5b5eeef
d99c1e8
 
 
 
 
4d5270d
d99c1e8
 
fa59459
d99c1e8
 
0716ef3
 
d99c1e8
 
 
89b629e
b3dcfe5
d99c1e8
 
 
 
b3dcfe5
d99c1e8
 
 
b3dcfe5
 
d99c1e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ba3728
b51c864
d99c1e8
15eff6a
d99c1e8
b3dcfe5
 
 
 
 
 
 
 
 
 
 
d99c1e8
 
 
 
5b5eeef
d99c1e8
f5f8f9a
 
b51c864
d99c1e8
 
 
f5f8f9a
 
b3dcfe5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import streamlit as st
from transformers import pipeline
from io import StringIO

# Page config is set before anything else so it is the first Streamlit
# command executed in the script.
st.set_page_config(layout="wide")


@st.cache_resource
def _load_unmasker():
    """Build the fill-mask pipeline once and reuse it across reruns/sessions.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the model would be re-created each time.
    """
    return pipeline('fill-mask', model='dsfsi/zabantu-bantu-250m')


unmasker = _load_unmasker()

def fill_mask(sentences):
    """Run the fill-mask model on every sentence containing exactly one ``<mask>``.

    Args:
        sentences: Mapping of a label (language name or input id) to a
            sentence. The literal placeholder ``<mask>`` marks the word to
            predict.

    Returns:
        A ``(results, warnings)`` tuple. ``results`` maps each processed
        label to the pipeline output for its sentence; ``warnings`` is a list
        of messages describing sentences that were skipped.
    """
    results = {}
    messages = []
    for label, sentence in sentences.items():
        mask_count = sentence.count("<mask>")
        if mask_count == 0:
            messages.append(f"Warning: No <mask> token found in sentence: {sentence}")
            continue
        if mask_count > 1:
            # With several masks the pipeline returns one candidate list per
            # mask (a list of lists) instead of a flat list of candidate
            # dicts, so such sentences are skipped rather than stored.
            messages.append(f"Warning: Only one <mask> token is supported per sentence: {sentence}")
            continue
        # Swap the UI placeholder for whatever mask token the tokenizer uses.
        masked_sentence = sentence.replace("<mask>", unmasker.tokenizer.mask_token)
        results[label] = unmasker(masked_sentence)
    return results, messages

def replace_mask(sentence, predicted_word):
    """Return *sentence* with every ``<mask>`` replaced by *predicted_word* in Markdown bold."""
    highlighted = f"**{predicted_word}**"
    pieces = sentence.split("<mask>")
    return highlighted.join(pieces)

# Page heading and short model description
st.title("Fill Mask | Zabantu-XLM-Roberta")
st.markdown("Zabantu-XLMR refers to a fleet of models trained on South African Bantu languages...")

# Seed per-session storage on the first run only; later reruns keep whatever
# values are already stored under these keys.
for _state_key, _initial in (('warnings', []), ('results', {})):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial

# Define layout
col1, col2 = st.columns(2)

with col1:
    st.markdown("### Input :clipboard:")

    select_options = ['Choose option', 'Enter text input', 'Upload a file(csv/txt)']
    sample_sentence = {
        'tshivenda': "Rabulasi wa <mask> u khou bvelela nga u lima.",
        "tsonga": "N'wana wa xisati u <mask> ku tsaka."
    }

    option_selected = st.selectbox("Select an input option:", select_options, index=0)
    # Label -> sentence mapping passed to fill_mask(); the labels become the
    # headings shown in the output column.
    input_sentences = {}

    if option_selected == 'Enter text input':
        st.session_state['warnings'].clear()  # Clear warnings before new input
        language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']

        for i in range(5):
            language = st.selectbox(f"Select language for input {i+1}:", language_options, key=f'language_{i}')
            sentence = st.text_input(f"Enter sentence for input {i+1} (with <mask>):", key=f'sentence_{i}')

            # Only process filled language and sentence pairs
            if language != 'Choose language' and sentence:
                label = language.lower()
                # The same language may be chosen for several inputs; suffix the
                # input number so a later sentence does not overwrite an earlier one.
                if label in input_sentences:
                    label = f"{label} ({i + 1})"
                input_sentences[label] = sentence

        if st.button("Submit"):
            if input_sentences:
                results, warnings = fill_mask(input_sentences)
                st.session_state['results'] = results
                st.session_state['warnings'] = warnings
            else:
                st.warning("Please fill at least one language and sentence.")

    elif option_selected == 'Upload a file(csv/txt)':
        uploaded_file = st.file_uploader("Choose a file (one sentence per line)", type=["csv", "txt"])
        if uploaded_file:
            try:
                # "utf-8-sig" also drops a leading byte-order mark, which would
                # otherwise be glued onto the first sentence.
                text = uploaded_file.getvalue().decode("utf-8-sig")
            except UnicodeDecodeError:
                st.error("The uploaded file is not valid UTF-8 text.")
            else:
                # Each non-blank line is one sentence. No per-line language is
                # available, so sentences are labelled by their line number.
                for line_no, line in enumerate(text.splitlines(), start=1):
                    sentence = line.strip()
                    if sentence:
                        input_sentences[f'input_{line_no}'] = sentence

                if st.button("Submit"):
                    if input_sentences:
                        results, warnings = fill_mask(input_sentences)
                        st.session_state['results'] = results
                        st.session_state['warnings'] = warnings
                    else:
                        st.warning("The uploaded file contains no sentences.")

    st.markdown("### Example")
    st.code(sample_sentence)
    if st.button("Test Example"):
        results, warnings = fill_mask(sample_sentence)
        st.session_state['results'] = results
        st.session_state['warnings'] = warnings

with col2:
    st.markdown("### Output :bar_chart:")

    if st.session_state['results']:
        with st.container():
            for language, predictions in st.session_state['results'].items():
                # A sentence with several masks yields one candidate list per
                # mask (a list of lists); show the candidates for the first mask.
                if predictions and isinstance(predictions[0], list):
                    predictions = predictions[0]
                if not predictions:
                    continue

                # The first candidate is shown as the top prediction (assumes the
                # pipeline returns candidates sorted by score — confirm).
                top_prediction = predictions[0]
                predicted_word = top_prediction['token_str']
                score = top_prediction['score'] * 100

                st.markdown(f"**{language.capitalize()} Prediction:** {predicted_word} ({score:.2f}%)")
                # Horizontal score bar; `.bar` / `.bar-fill` are styled by the page CSS.
                st.markdown(
                    f"<div class='bar'><div class='bar-fill' style='width:{score}%;'></div></div>",
                    unsafe_allow_html=True,
                )

    if st.session_state['warnings']:
        for warning in st.session_state['warnings']:
            st.warning(warning)

# Page-wide CSS: hide Streamlit's default footer and style the score bars
# (`.bar` / `.bar-fill`) drawn in the output column. `.container` is not
# referenced by any markup visible in this file.
css = "\n".join([
    "",
    "<style>",
    "footer {display:none !important;}",
    ".bar {width: 70%; background-color: #e6e6e6; border-radius: 12px; height: 5px;}",
    ".bar-fill {background-color: #17152e; height: 100%; border-radius: 12px;}",
    ".container {display: flex; justify-content: space-between; align-items: center; margin-bottom: 5px;}",
    "</style>",
    "",
])
st.markdown(css, unsafe_allow_html=True)