UnarineLeo
commited on
Commit
•
e28273f
1
Parent(s):
27c53ba
Update app.py
Browse files
app.py
CHANGED
@@ -21,8 +21,6 @@ def replace_mask(sentence, predicted_word):
|
|
21 |
return sentence.replace("<mask>", f"**{predicted_word}**")
|
22 |
|
23 |
st.title("Fill Mask | Zabantu-XLM-Roberta")
|
24 |
-
st.write("")
|
25 |
-
|
26 |
st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages.")
|
27 |
|
28 |
col1, col2 = st.columns(2)
|
@@ -33,56 +31,63 @@ if 'text_input' not in st.session_state:
|
|
33 |
if 'warnings' not in st.session_state:
|
34 |
st.session_state['warnings'] = []
|
35 |
|
|
|
|
|
36 |
with col1:
|
37 |
-
with st.container(
|
38 |
-
st.markdown("Input :clipboard:")
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
text_input = st.text_area(
|
48 |
-
"Enter sentences with <mask> token:",
|
49 |
-
value=st.session_state['text_input']
|
50 |
-
)
|
51 |
-
|
52 |
-
input_sentences = {f"sentence_{i}": sentence for i, sentence in enumerate(text_input.split("\n")) if sentence}
|
53 |
|
54 |
button1, button2, _ = st.columns([2, 2, 4])
|
55 |
with button1:
|
56 |
if st.button("Test Example"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
input_sentences = sample_sentence
|
58 |
-
result, warnings = fill_mask(input_sentences)
|
59 |
-
# st.session_state['text_input'] = "\n".join([f"{lang}: {sentence}" for lang, sentence in sample_sentence.items()])
|
60 |
|
61 |
with button2:
|
62 |
if st.button("Submit"):
|
63 |
result, warnings = fill_mask(input_sentences)
|
64 |
-
st.session_state['warnings'] = warnings
|
65 |
-
|
66 |
if st.session_state['warnings']:
|
67 |
for warning in st.session_state['warnings']:
|
68 |
st.warning(warning)
|
69 |
|
70 |
-
st.markdown("Example")
|
71 |
-
st.code(
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
with col2:
|
74 |
-
with st.container(
|
75 |
-
st.markdown("Output :bar_chart:")
|
76 |
if input_sentences:
|
77 |
for language, sentence in input_sentences.items():
|
78 |
masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
|
79 |
predictions = unmasker(masked_sentence)
|
80 |
-
|
81 |
if predictions:
|
82 |
top_prediction = predictions[0]
|
83 |
predicted_word = top_prediction['token_str']
|
84 |
score = top_prediction['score'] * 100
|
85 |
-
|
86 |
st.markdown(f"""
|
87 |
<div class="bar">
|
88 |
<div class="bar-fill" style="width: {score}%;"></div>
|
@@ -96,8 +101,8 @@ with col2:
|
|
96 |
if 'predictions' in locals():
|
97 |
if result:
|
98 |
for language, language_predictions in result.items():
|
99 |
-
original_sentence =
|
100 |
-
predicted_sentence = replace_mask(original_sentence, language_predictions[0]['token_str'])
|
101 |
st.write(f"{language}: {predicted_sentence}\n")
|
102 |
|
103 |
css = """
|
@@ -125,5 +130,4 @@ footer {display:none !important;}
|
|
125 |
}
|
126 |
</style>
|
127 |
"""
|
128 |
-
|
129 |
st.markdown(css, unsafe_allow_html=True)
|
|
|
21 |
return sentence.replace("<mask>", f"**{predicted_word}**")
|
22 |
|
23 |
st.title("Fill Mask | Zabantu-XLM-Roberta")
|
|
|
|
|
24 |
st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages.")
|
25 |
|
26 |
col1, col2 = st.columns(2)
|
|
|
31 |
if 'warnings' not in st.session_state:
|
32 |
st.session_state['warnings'] = []
|
33 |
|
34 |
+
language_options = ['Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']
|
35 |
+
|
36 |
with col1:
|
37 |
+
with st.container():
|
38 |
+
st.markdown("### Input :clipboard:")
|
39 |
+
|
40 |
+
input_sentences = {}
|
41 |
+
for i in range(5):
|
42 |
+
language = st.selectbox(f"Select language for sentence {i+1}:", language_options, key=f'language_{i}')
|
43 |
+
sentence = st.text_area(f"Enter sentence for {language} (with <mask>):", key=f'text_input_{i}')
|
44 |
+
if sentence:
|
45 |
+
input_sentences[language.lower()] = sentence
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
button1, button2, _ = st.columns([2, 2, 4])
|
48 |
with button1:
|
49 |
if st.button("Test Example"):
|
50 |
+
sample_sentence = {
|
51 |
+
'zulu': "Le ndoda ithi izo <mask> ukudla.",
|
52 |
+
'tshivenda': "Mufana uyo <mask> vhukuma.",
|
53 |
+
'sepedi': "Mosadi o <mask> pheka.",
|
54 |
+
'tswana': "Monna o <mask> tsamaya.",
|
55 |
+
'tsonga': "N'wana wa xisati u <mask> ku tsaka."
|
56 |
+
}
|
57 |
input_sentences = sample_sentence
|
58 |
+
result, warnings = fill_mask(input_sentences)
|
|
|
59 |
|
60 |
with button2:
|
61 |
if st.button("Submit"):
|
62 |
result, warnings = fill_mask(input_sentences)
|
63 |
+
st.session_state['warnings'] = warnings
|
64 |
+
|
65 |
if st.session_state['warnings']:
|
66 |
for warning in st.session_state['warnings']:
|
67 |
st.warning(warning)
|
68 |
|
69 |
+
st.markdown("### Example")
|
70 |
+
st.code({
|
71 |
+
'zulu': "Le ndoda ithi izo <mask> ukudla.",
|
72 |
+
'tshivenda': "Mufana uyo <mask> vhukuma.",
|
73 |
+
'sepedi': "Mosadi o <mask> pheka.",
|
74 |
+
'tswana': "Monna o <mask> tsamaya.",
|
75 |
+
'tsonga': "N'wana wa xisati u <mask> ku tsaka."
|
76 |
+
}, wrap_lines=True)
|
77 |
|
78 |
with col2:
|
79 |
+
with st.container():
|
80 |
+
st.markdown("### Output :bar_chart:")
|
81 |
if input_sentences:
|
82 |
for language, sentence in input_sentences.items():
|
83 |
masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
|
84 |
predictions = unmasker(masked_sentence)
|
85 |
+
|
86 |
if predictions:
|
87 |
top_prediction = predictions[0]
|
88 |
predicted_word = top_prediction['token_str']
|
89 |
score = top_prediction['score'] * 100
|
90 |
+
|
91 |
st.markdown(f"""
|
92 |
<div class="bar">
|
93 |
<div class="bar-fill" style="width: {score}%;"></div>
|
|
|
101 |
if 'predictions' in locals():
|
102 |
if result:
|
103 |
for language, language_predictions in result.items():
|
104 |
+
original_sentence = input_sentences[language]
|
105 |
+
predicted_sentence = replace_mask(original_sentence, language_predictions[0]['token_str'])
|
106 |
st.write(f"{language}: {predicted_sentence}\n")
|
107 |
|
108 |
css = """
|
|
|
130 |
}
|
131 |
</style>
|
132 |
"""
|
|
|
133 |
st.markdown(css, unsafe_allow_html=True)
|