File size: 4,611 Bytes
f5f8f9a e255bad f5f8f9a 4d5270d f5f8f9a 4d5270d f5f8f9a 4d5270d e2a34c5 4d5270d e2a34c5 4d5270d f5f8f9a 7a621b0 b92b795 7a621b0 4d5270d 36c8dd5 4d5270d b51c864 d104ff1 4d5270d 8091567 e2a34c5 e28273f 8091567 4d5270d e28273f 9f4d0d7 fa59459 e28273f fa59459 e2a34c5 4d5270d 8091567 e28273f 4d5270d d104ff1 9f4d0d7 5fb1b33 e28273f 422acf2 9ba3728 b51c864 e28273f 9f4d0d7 8091567 a9bd9b9 f5f8f9a b51c864 f5f8f9a ffffb96 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import html

import streamlit as st
from transformers import pipeline
# Configure the page before any other Streamlit call: the cached loader below
# may render a spinner, and Streamlit documents set_page_config as having to
# be the first Streamlit command of the script.
st.set_page_config(layout="wide")


@st.cache_resource
def _load_unmasker():
    """Build the fill-mask pipeline once and reuse it across script reruns.

    Streamlit re-executes the whole script on every widget interaction;
    without caching, the model would be re-instantiated each time.
    """
    return pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')


unmasker = _load_unmasker()
def fill_mask(sentences):
    """Run the fill-mask model over every sentence that has exactly one mask.

    Args:
        sentences: Mapping of a unique key to a ``(language, sentence)``
            tuple. The sentence marks the word to predict with ``<mask>``.

    Returns:
        A ``(results, warnings)`` tuple. ``results`` maps each accepted key to
        ``(language, predictions)``, where ``predictions`` is whatever the
        module-level ``unmasker`` pipeline returned for that sentence.
        ``warnings`` is a list of human-readable messages, one per rejected
        sentence.
    """
    results = {}
    warnings = []
    for key, (language, sentence) in sentences.items():
        mask_count = sentence.count("<mask>")
        if mask_count == 0:
            warnings.append(f"Warning: No <mask> token found in sentence: {sentence}")
            continue
        if mask_count > 1:
            # With several masks the pipeline returns one candidate list per
            # mask (a list of lists), which the callers in this file do not
            # expect; reject the sentence instead of crashing later.
            warnings.append(f"Warning: More than one <mask> token found in sentence: {sentence}")
            continue
        # Translate the UI placeholder into the tokenizer's own mask token.
        masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
        results[key] = (language, unmasker(masked_sentence))
    return results, warnings
def replace_mask(sentence, predicted_word):
    """Return *sentence* with every ``<mask>`` swapped for the bolded word.

    The replacement is wrapped in ``**`` so that it renders as bold when the
    result is shown as Markdown.
    """
    highlighted = "**{}**".format(predicted_word)
    # Splitting on the placeholder and re-joining with the replacement is
    # equivalent to replacing every occurrence.
    return highlighted.join(sentence.split("<mask>"))
# Page header.
st.title("Fill Mask | Zabantu-XLM-Roberta")
st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages.")

# Two side-by-side panels: inputs go in col1, predictions in col2.
col1, col2 = st.columns(2)

# Seed session state on the first run only, so values written by a button
# click survive the script reruns that Streamlit triggers afterwards.
for state_key, initial_value in (('text_input', ""), ('warnings', []), ('result', {})):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value

# The first entry is the placeholder shown before a language is picked.
language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']

# Filled by the input panel: unique key -> (language, sentence).
input_sentences = {}
# ---------------------------------------------------------------------------
# Input panel (left): five language/sentence rows, the two action buttons,
# any warnings from the last run, and a static example.
# ---------------------------------------------------------------------------
with col1:
    with st.container():
        st.markdown("Input :clipboard:")
        input1, input2 = st.columns(2)
        for i in range(5):
            with input1:
                language = st.selectbox(f"Select language for sentence {i+1}:", language_options, key=f'language_{i}')
            with input2:
                sentence = st.text_input(f"Enter sentence for {language} (with <mask>):", key=f'text_input_{i}')
            if sentence:
                # Language plus row number keeps keys unique even when two
                # rows use the same language.
                input_sentences[f'{language.lower()}_{i+1}'] = (language.lower(), sentence)

        # Place each button in its own narrow column so they sit side by side.
        button1, button2, _ = st.columns([2, 2, 4])
        with button1:
            run_example = st.button("Test Example")
        with button2:
            run_submit = st.button("Submit")

        submitted_sentences = None
        if run_example:
            submitted_sentences = {
                'zulu_1': ('zulu', "Le ndoda ithi izo <mask> ukudla."),
                'tshivenda_2': ('tshivenda', "Vhana vhane vha kha ḓi bva u bebwa vha kha khombo ya u <mask> nga Listeriosis."),
                'tshivenda_3': ('tshivenda', "Rabulasi wa <mask> u khou bvelela nga u lima"),
                'tswana_4': ('tswana', "Monna o <mask> tsamaya."),
                'tsonga_5': ('tsonga', "N'wana wa xisati u <mask> ku tsaka."),
            }
        elif run_submit:
            submitted_sentences = input_sentences

        if submitted_sentences is not None:
            st.session_state['result'], st.session_state['warnings'] = fill_mask(submitted_sentences)
            # Remember exactly what was sent to the model. The output panel
            # must not look results up in `input_sentences`: the example
            # sentences are never in it, and the text boxes may have changed
            # since the results were produced.
            st.session_state['sentences'] = dict(submitted_sentences)

        if st.session_state['warnings']:
            for warning in st.session_state['warnings']:
                st.warning(warning)

        st.markdown("Example")
        st.code({
            'zulu': "Le ndoda ithi izo <mask> ukudla.",
            'tshivenda': "Mufana uyo <mask> vhukuma.",
            'sepedi': "Mosadi o <mask> pheka.",
            'tswana': "Monna o <mask> tsamaya.",
            'tsonga': "N'wana wa xisati u <mask> ku tsaka."
        }, wrap_lines=True)

# ---------------------------------------------------------------------------
# Output panel (right): for each result, a score bar for the top prediction
# followed by the sentence with the predicted word filled in.
# ---------------------------------------------------------------------------
with col2:
    with st.container():
        st.markdown("Output :bar_chart:")
        if st.session_state['result']:
            stored_sentences = st.session_state.get('sentences', {})
            for key, (language, predictions) in st.session_state['result'].items():
                # Only the first prediction in the list is displayed.
                top_prediction = predictions[0]
                predicted_word = top_prediction['token_str']
                score = top_prediction['score'] * 100

                # The markup below is rendered with unsafe_allow_html, so any
                # text interpolated into it is escaped first.
                safe_word = html.escape(str(predicted_word))
                safe_language = html.escape(str(language))
                st.markdown(f"""
                <div class="bar">
                <div class="bar-fill" style="width: {score}%;"></div>
                </div>
                <div class="container">
                <div style="align-items: left;">{safe_word} ({safe_language})</div>
                <div style="align-items: right;">{score:.2f}%</div>
                </div>
                """, unsafe_allow_html=True)

                stored_entry = stored_sentences.get(key)
                if stored_entry is None:
                    # No record of the sentence behind this result; show the
                    # score bar only rather than raising KeyError.
                    continue
                predicted_sentence = replace_mask(stored_entry[1], predicted_word)
                st.write(f"{language}: {predicted_sentence}\n")
# Page-level stylesheet, injected as raw HTML at the end of the script.
# It hides the page footer and styles the `.container`, `.bar` and `.bar-fill`
# classes used by the score-bar markup in the output panel.
css = """
<style>
footer {display:none !important;}
.container {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 5px;
width: 100%;
}
.bar {
background-color: #e6e6e6;
border-radius: 12px;
overflow: hidden;
margin-right: 10px;
height: 5px;
}
.bar-fill {
background-color: #17152e;
height: 100%;
border-radius: 12px;
}
</style>
"""
# unsafe_allow_html is required for the <style> element to be emitted as HTML.
st.markdown(css, unsafe_allow_html=True)
|