UnarineLeo
committed on
Commit
•
9ba3728
1
Parent(s):
e255bad
Update app.py
Browse files
app.py
CHANGED
@@ -1,8 +1,10 @@
|
|
1 |
import streamlit as st
|
2 |
from transformers import pipeline
|
3 |
|
|
|
4 |
unmasker = pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')
|
5 |
|
|
|
6 |
sample_sentences = {
|
7 |
'Zulu': "Le ndoda ithi izo____ ukudla.",
|
8 |
'Tshivenda': "Mufana uyo____ vhukuma.",
|
@@ -11,46 +13,58 @@ sample_sentences = {
|
|
11 |
'Tsonga': "N'wana wa xisati u ____ ku tsaka."
|
12 |
}
|
13 |
|
|
|
14 |
def fill_mask_for_languages(sentences):
|
15 |
results = {}
|
16 |
for language, sentence in sentences.items():
|
17 |
masked_sentence = sentence.replace('____', unmasker.tokenizer.mask_token)
|
18 |
-
|
19 |
unmasked = unmasker(masked_sentence)
|
20 |
-
|
21 |
-
results[language] = unmasked
|
22 |
return results
|
23 |
|
|
|
24 |
def replace_mask(sentence, predicted_word):
|
25 |
return sentence.replace("____", predicted_word)
|
26 |
|
|
|
27 |
st.title("Fill Mask for Multiple Languages | Zabantu-XLM-Roberta")
|
28 |
st.write("This app predicts the missing word for sentences in Zulu, Tshivenda, Sepedi, Tswana, and Tsonga using a Zabantu BERT model.")
|
29 |
|
|
|
30 |
user_sentence = st.text_input("Enter your own sentence with a masked word (use '____'):", "\n".join(
|
31 |
-
|
32 |
-
|
33 |
-
for lang, sentence in sample_sentences.items()
|
34 |
-
))
|
35 |
|
|
|
36 |
if st.button("Submit"):
|
|
|
37 |
user_masked_sentence = user_sentence.replace('____', unmasker.tokenizer.mask_token)
|
38 |
|
|
|
39 |
user_predictions = unmasker(user_masked_sentence)
|
40 |
|
41 |
st.write("### Your Input:")
|
42 |
st.write(f"Original sentence: {user_sentence}")
|
43 |
-
st.write(f"Top prediction for the masked token: {user_predictions[0]['sequence']}")
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
st.write("### Predictions for Sample Sentences:")
|
46 |
-
|
|
|
|
|
47 |
original_sentence = sample_sentences[language]
|
48 |
-
predicted_sentence = replace_mask(
|
49 |
-
|
50 |
st.write(f"Original sentence ({language}): {original_sentence}")
|
51 |
st.write(f"Top prediction for the masked token: {predicted_sentence}\n")
|
52 |
st.write("=" * 80)
|
53 |
|
|
|
54 |
css = """
|
55 |
<style>
|
56 |
footer {display:none !important}
|
@@ -98,5 +112,4 @@ div[data-testid="stMarkdownContainer"] p {
|
|
98 |
}
|
99 |
</style>
|
100 |
"""
|
101 |
-
|
102 |
-
st.markdown(css, unsafe_allow_html=True)
|
|
|
1 |
import streamlit as st
|
2 |
from transformers import pipeline
|
3 |
|
4 |
+
# Initialize the pipeline for masked language model
|
5 |
unmasker = pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')
|
6 |
|
7 |
+
# Sample sentences with masked words in various languages
|
8 |
sample_sentences = {
|
9 |
'Zulu': "Le ndoda ithi izo____ ukudla.",
|
10 |
'Tshivenda': "Mufana uyo____ vhukuma.",
|
|
|
13 |
'Tsonga': "N'wana wa xisati u ____ ku tsaka."
|
14 |
}
|
15 |
|
16 |
+
# Run the fill-mask pipeline over every sample sentence, one language at a time.
def fill_mask_for_languages(sentences):
    """Return a dict mapping each language to the unmasker's predictions.

    The human-readable '____' placeholder in each sentence is swapped for the
    model's mask token before the pipeline is invoked; the raw prediction
    list returned by the pipeline is stored per language.
    """
    return {
        language: unmasker(sentence.replace('____', unmasker.tokenizer.mask_token))
        for language, sentence in sentences.items()
    }
|
24 |
|
25 |
+
# Substitute the predicted word back into the display-form placeholder.
def replace_mask(sentence, predicted_word):
    """Return *sentence* with every '____' placeholder replaced by *predicted_word*."""
    filled = sentence.replace("____", predicted_word)
    return filled
|
28 |
|
29 |
+
# Streamlit app
|
30 |
st.title("Fill Mask for Multiple Languages | Zabantu-XLM-Roberta")
|
31 |
st.write("This app predicts the missing word for sentences in Zulu, Tshivenda, Sepedi, Tswana, and Tsonga using a Zabantu BERT model.")
|
32 |
|
33 |
+
# Get user input
|
34 |
user_sentence = st.text_input("Enter your own sentence with a masked word (use '____'):", "\n".join(
|
35 |
+
f"'{lang}': '{sentence}'," for lang, sentence in sample_sentences.items()
|
36 |
+
))
|
|
|
|
|
37 |
|
38 |
+
# When user submits the input sentence
|
39 |
if st.button("Submit"):
|
40 |
+
# Replace the placeholder with the actual mask token
|
41 |
user_masked_sentence = user_sentence.replace('____', unmasker.tokenizer.mask_token)
|
42 |
|
43 |
+
# Get predictions for the user's sentence
|
44 |
user_predictions = unmasker(user_masked_sentence)
|
45 |
|
46 |
st.write("### Your Input:")
|
47 |
st.write(f"Original sentence: {user_sentence}")
|
|
|
48 |
|
49 |
+
# Check the structure of predictions
|
50 |
+
st.write(user_predictions) # Print to see the structure
|
51 |
+
|
52 |
+
# Display the top prediction for the masked token
|
53 |
+
if len(user_predictions) > 0:
|
54 |
+
st.write(f"Top prediction for the masked token: {user_predictions[0]['sequence']}")
|
55 |
+
|
56 |
+
# Predictions for sample sentences
|
57 |
st.write("### Predictions for Sample Sentences:")
|
58 |
+
predictions = fill_mask_for_languages(sample_sentences)
|
59 |
+
|
60 |
+
for language, language_predictions in predictions.items():
|
61 |
original_sentence = sample_sentences[language]
|
62 |
+
predicted_sentence = replace_mask(original_sentence, language_predictions[0]['token_str']) # Use token_str for prediction
|
|
|
63 |
st.write(f"Original sentence ({language}): {original_sentence}")
|
64 |
st.write(f"Top prediction for the masked token: {predicted_sentence}\n")
|
65 |
st.write("=" * 80)
|
66 |
|
67 |
+
# Custom CSS for styling
|
68 |
css = """
|
69 |
<style>
|
70 |
footer {display:none !important}
|
|
|
112 |
}
|
113 |
</style>
|
114 |
"""
|
115 |
+
st.markdown(css, unsafe_allow_html=True)
|
|