File size: 4,658 Bytes
f5f8f9a e255bad f5f8f9a 4d5270d f5f8f9a 4d5270d f5f8f9a 4d5270d f5f8f9a 4d5270d a9bd9b9 4d5270d f5f8f9a 7a621b0 b92b795 7a621b0 4d5270d 36c8dd5 4d5270d b51c864 d104ff1 4d5270d e28273f 4d5270d e28273f 4d5270d e28273f 97fa574 e28273f 4d5270d a9bd9b9 e28273f 4d5270d d104ff1 e28273f 9ba3728 b51c864 e28273f ffffb96 e28273f ffffb96 e28273f ffffb96 a9bd9b9 ffffb96 f5f8f9a 97fa574 a9bd9b9 8d50985 e28273f 97fa574 a9bd9b9 f5f8f9a b51c864 f5f8f9a ffffb96 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import html

import streamlit as st
from transformers import pipeline
# Page configuration is issued first so that nothing Streamlit renders while
# the model loads (e.g. the cache spinner) can precede it.
st.set_page_config(layout="wide")


@st.cache_resource
def _load_unmasker():
    """Build the fill-mask pipeline once and share it across script reruns.

    Streamlit re-executes this script on every widget interaction; without
    caching, the model would be re-instantiated on each rerun.
    """
    return pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')


# Module-level handle used by fill_mask() and by the output column.
unmasker = _load_unmasker()
def fill_mask(sentences):
    """Run the fill-mask pipeline on every sentence that contains ``<mask>``.

    Args:
        sentences: Mapping of a language key to the sentence entered for it.

    Returns:
        A ``(results, warnings)`` tuple. ``results`` maps each language key
        whose sentence contained ``<mask>`` to the value returned by the
        module-level ``unmasker`` for that sentence. ``warnings`` is a list
        with one message per sentence that had no ``<mask>`` placeholder.
    """
    predictions_by_language = {}
    missing_mask_messages = []

    for language, sentence in sentences.items():
        # Guard clause: sentences without a placeholder are reported, not run.
        if "<mask>" not in sentence:
            missing_mask_messages.append(f"Warning: No <mask> token found in sentence: {sentence}")
            continue

        # Swap the UI placeholder for the tokenizer's own mask token.
        model_input = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
        predictions_by_language[language] = unmasker(model_input)

    return predictions_by_language, missing_mask_messages
def replace_mask(sentence, predicted_word):
    """Return ``sentence`` with each ``<mask>`` replaced by the bolded word.

    The word is wrapped in ``**`` on both sides, i.e. Markdown bold markers.
    A sentence without ``<mask>`` is returned unchanged.
    """
    highlighted_word = "**{}**".format(predicted_word)
    filled_sentence = sentence.replace("<mask>", highlighted_word)
    return filled_sentence
# Page header and short description of the model family.
st.title("Fill Mask | Zabantu-XLM-Roberta")
st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages.")

# Two equal columns: inputs on the left, predictions on the right.
col1, col2 = st.columns(2)

# Seed the session-state entries on first use; existing values are kept.
for state_key, initial_value in (('text_input', ""), ('warnings', [])):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value

# Languages offered in each sentence's language selector.
language_options = ['Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']
with col1:
    with st.container():
        st.markdown("### Input :clipboard:")

        # One example sentence per supported language. Defined once and shared
        # by the "Test Example" button and the "Example" panel below so the
        # two cannot drift apart.
        sample_sentence = {
            'zulu': "Le ndoda ithi izo <mask> ukudla.",
            'tshivenda': "Mufana uyo <mask> vhukuma.",
            'sepedi': "Mosadi o <mask> pheka.",
            'tswana': "Monna o <mask> tsamaya.",
            'tsonga': "N'wana wa xisati u <mask> ku tsaka."
        }

        # Bind these on every run: they are otherwise assigned only inside the
        # button handlers, and later code reads `result` even on runs where no
        # button was clicked.
        result = {}
        warnings = []

        input_sentences = {}
        for i in range(5):
            # Default each selector to a different language. Sentences are
            # stored under the lower-cased language name, so several boxes
            # left on the same language would silently overwrite each other.
            language = st.selectbox(
                f"Select language for sentence {i+1}:",
                language_options,
                index=i % len(language_options),
                key=f'language_{i}',
            )
            sentence = st.text_area(f"Enter sentence for {language} (with <mask>):", key=f'text_input_{i}')
            if sentence:
                input_sentences[language.lower()] = sentence

        button1, button2, _ = st.columns([2, 2, 4])
        with button1:
            if st.button("Test Example"):
                # Replace whatever was typed with the built-in examples.
                input_sentences = dict(sample_sentence)
                result, warnings = fill_mask(input_sentences)
                # Refresh stored warnings here too, so messages from an earlier
                # "Submit" do not linger next to the example output.
                st.session_state['warnings'] = warnings

        with button2:
            if st.button("Submit"):
                result, warnings = fill_mask(input_sentences)
                st.session_state['warnings'] = warnings

        # Show the warnings recorded by the most recent button click.
        for warning in st.session_state['warnings']:
            st.warning(warning)

        st.markdown("### Example")
        st.code(sample_sentence, wrap_lines=True)
with col2:
    with st.container():
        st.markdown("### Output :bar_chart:")

        # Tolerate `result` being unbound (e.g. a rerun triggered by editing a
        # field rather than clicking a button) instead of raising NameError.
        try:
            submitted_result = result
        except NameError:
            submitted_result = {}

        mask_token = unmasker.tokenizer.mask_token
        for language, sentence in input_sentences.items():
            if "<mask>" not in sentence:
                # The fill-mask pipeline rejects input without a mask token;
                # such sentences are reported by fill_mask(), so skip them.
                continue

            # Reuse predictions fill_mask() already produced for this sentence
            # rather than running the model a second time.
            predictions = submitted_result.get(language)
            if not predictions:
                predictions = unmasker(sentence.replace('<mask>', mask_token))
            if not predictions:
                continue

            top_prediction = predictions[0]
            predicted_word = top_prediction['token_str']
            score = top_prediction['score'] * 100

            # The markup is rendered as raw HTML, so the interpolated text is
            # escaped to keep characters such as "<" from being read as tags.
            st.markdown(f"""
                <div class="bar">
                <div class="bar-fill" style="width: {score}%;"></div>
                </div>
                <div class="container">
                <div style="align-items: left;">{html.escape(predicted_word)} ({html.escape(language)})</div>
                <div style="align-items: right;">{score:.2f}%</div>
                </div>
                """, unsafe_allow_html=True)

        # For results produced by a button click, show each sentence with the
        # top prediction substituted for the placeholder.
        for language, language_predictions in submitted_result.items():
            original_sentence = input_sentences[language]
            predicted_sentence = replace_mask(original_sentence, language_predictions[0]['token_str'])
            st.write(f"{language}: {predicted_sentence}\n")
# Page-wide stylesheet: hides the `footer` element and styles the
# `.container`, `.bar` and `.bar-fill` classes used by the prediction markup
# in the output column.
css = """
<style>
footer {display:none !important;}
.container {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 5px;
width: 100%;
}
.bar {
background-color: #e6e6e6;
border-radius: 12px;
overflow: hidden;
margin-right: 10px;
height: 5px;
}
.bar-fill {
background-color: #17152e;
height: 100%;
border-radius: 12px;
}
</style>
"""
# Emit the <style> element through st.markdown with raw HTML enabled.
st.markdown(css, unsafe_allow_html=True)
|