File size: 4,774 Bytes
f5f8f9a e255bad f5f8f9a 4d5270d f5f8f9a 5b5eeef 4d5270d f5f8f9a 4d5270d 5b5eeef 4d5270d 5b5eeef 4d5270d f5f8f9a 7a621b0 b92b795 7a621b0 4d5270d 36c8dd5 4d5270d b51c864 d104ff1 5b5eeef 4d5270d 5b5eeef e2a34c5 e28273f 5b5eeef 4d5270d e28273f 9f4d0d7 fa59459 e28273f 4eb39ad 5b5eeef fa59459 5b5eeef 4d5270d e28273f 5b5eeef 91470f6 5b5eeef 4d5270d d104ff1 9f4d0d7 5fb1b33 e28273f 422acf2 9ba3728 b51c864 e28273f 9f4d0d7 5b5eeef 4eb39ad 91470f6 5b5eeef f5f8f9a b51c864 f5f8f9a 91470f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import html

import streamlit as st
from transformers import pipeline
# Streamlit documents set_page_config as needing to run before other Streamlit
# commands, so it is issued ahead of the cached loader below (which may show
# a spinner while the model loads).
st.set_page_config(layout="wide")


@st.cache_resource
def _load_unmasker():
    """Build the fill-mask pipeline for the Zabantu XLM-RoBERTa model.

    Streamlit re-executes this script on every widget interaction; wrapping
    the construction in ``st.cache_resource`` lets the same pipeline object be
    reused across reruns instead of being rebuilt each time.
    """
    return pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')


unmasker = _load_unmasker()

# NOTE(review): ``stop_rerun`` does not appear to be a documented Streamlit
# attribute; this only sets an attribute on the imported module. Kept so the
# module-level state matches the original - confirm whether it can be removed.
st.stop_rerun = True
def fill_mask(sentences):
    """Run the fill-mask pipeline on every sentence that contains ``<mask>``.

    Args:
        sentences: Mapping of key -> ``(language, sentence)``. A plain string
            value is also accepted (the form used by the built-in sample
            sentences); in that case the key doubles as the language label.

    Returns:
        A ``(results, warnings)`` tuple. ``results`` maps each key to
        ``(language, pipeline_output)`` for sentences that contained
        ``<mask>``. ``warnings`` is a list of messages, one per sentence that
        had no ``<mask>`` token and was therefore skipped.
    """
    results = {}
    warnings = []
    for key, entry in sentences.items():
        # Unpacking a bare string as (language, sentence) raises ValueError,
        # so string values are normalised first.
        if isinstance(entry, str):
            language, sentence = key, entry
        else:
            language, sentence = entry

        if "<mask>" not in sentence:
            warnings.append(f"Warning: No <mask> token found in sentence: {sentence}")
            continue

        # Translate the UI placeholder into whatever mask token the loaded
        # tokenizer actually uses.
        masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
        results[key] = (language, unmasker(masked_sentence))
    return results, warnings
def replace_mask(sentence, predicted_word):
    """Return ``sentence`` with each ``<mask>`` swapped for the bolded word.

    The prediction is wrapped in ``**`` so that it renders in bold when the
    result is displayed as Markdown.
    """
    highlighted = f"**{predicted_word}**"
    return sentence.replace("<mask>", highlighted)
# ---------------------------------------------------------------------------
# Page layout: inputs on the left, predictions on the right.
# ---------------------------------------------------------------------------
st.title("Fill Mask | Zabantu-XLM-Roberta")
st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages.")

col1, col2 = st.columns(2)

# Session state is what carries results across Streamlit reruns. 'sentences'
# keeps the text that was actually submitted, so the completed sentences can
# still be rendered after the input widgets change.
_state_defaults = {'text_input': "", 'warnings': [], 'result': {}, 'sentences': {}}
for state_key, default_value in _state_defaults.items():
    if state_key not in st.session_state:
        st.session_state[state_key] = default_value

language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']

# Single source of truth for the demo sentences (used by the "Test Example"
# button and by the "Example" display).
sample_sentences = {
    'zulu': "Le ndoda ithi izo <mask> ukudla.",
    'tshivenda': "Mufana uyo <mask> vhukuma.",
    'sepedi': "Mosadi o <mask> pheka.",
    'tswana': "Monna o <mask> tsamaya.",
    'tsonga': "N'wana wa xisati u <mask> ku tsaka."
}

# Markup for one prediction: a confidence bar plus a "word (language) / score"
# row. Styled by the .bar / .bar-fill / .container CSS rules.
prediction_bar_template = """
<div class="bar">
<div class="bar-fill" style="width: {score}%;"></div>
</div>
<div class="container">
<div style="align-items: left;">{predicted_word} ({language})</div>
<div style="align-items: right;">{score:.2f}%</div>
</div>
"""

input_sentences = {}


def _run_fill_mask(submitted):
    """Run the model on ``submitted`` and persist the outcome in session state.

    ``submitted`` maps key -> ``(language, sentence)``. Results, warnings and
    the submitted sentences themselves are all stored so the output column
    can be redrawn on later reruns.
    """
    st.session_state['result'], st.session_state['warnings'] = fill_mask(submitted)
    st.session_state['sentences'] = {
        key: sentence for key, (_, sentence) in submitted.items()
    }


with col1:
    with st.container():
        st.markdown("Input :clipboard:")
        input1, input2 = st.columns(2)
        for i in range(5):
            with input1:
                language = st.selectbox(f"Select language for sentence {i+1}:", language_options, key=f'language_{i}', index=0)
            with input2:
                sentence = st.text_input(f"Enter sentence for {language} (with <mask>):", key=f'text_input_{i}')
            if sentence:
                input_sentences[f'{language.lower()}_{i+1}'] = (language.lower(), sentence)

        # Place the two buttons side by side in the columns created for them.
        button1, button2, _ = st.columns([2, 2, 4])
        if button1.button("Test Example"):
            # fill_mask expects (language, sentence) pairs; the samples are
            # keyed by language, so the key is reused as the language label.
            _run_fill_mask({lang: (lang, text) for lang, text in sample_sentences.items()})
        if button2.button("Submit"):
            _run_fill_mask(input_sentences)

        for warning in st.session_state['warnings']:
            st.warning(warning)

        st.markdown("Example")
        st.code(sample_sentences, wrap_lines=True)

with col2:
    with st.container():
        st.markdown("Output :bar_chart:")

        # Keep only the top-ranked candidate for each sentence that produced
        # any prediction.
        top_predictions = {
            key: (language, predictions[0])
            for key, (language, predictions) in st.session_state['result'].items()
            if predictions
        }

        for language, top_prediction in top_predictions.values():
            # Values are escaped because the markup is rendered with
            # unsafe_allow_html=True and token strings may contain '<' or '>'.
            st.markdown(
                prediction_bar_template.format(
                    predicted_word=html.escape(top_prediction['token_str']),
                    language=html.escape(language),
                    score=top_prediction['score'] * 100,
                ),
                unsafe_allow_html=True,
            )

        # Show each submitted sentence with the mask filled in by the top
        # prediction.
        for key, (language, top_prediction) in top_predictions.items():
            original_sentence = st.session_state['sentences'].get(key, "")
            if original_sentence:
                predicted_sentence = replace_mask(original_sentence, top_prediction['token_str'])
                st.write(f"{language}: {predicted_sentence}\n")
# Page-wide styles: hides the footer element and defines the classes used by
# the prediction markup (.container for the label/score row, .bar and
# .bar-fill for the confidence bar).
css = """
<style>
footer {display:none !important;}
.container {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 5px;
width: 100%;
}
.bar {
background-color: #e6e6e6;
border-radius: 12px;
overflow: hidden;
margin-right: 10px;
height: 5px;
}
.bar-fill {
background-color: #17152e;
height: 100%;
border-radius: 12px;
}
</style>
"""
# unsafe_allow_html is required for the <style> tag to be emitted as markup.
# The stray trailing "|" that followed this call was removed: it left the
# statement syntactically incomplete.
st.markdown(css, unsafe_allow_html=True)