|
import streamlit as st |
|
from transformers import pipeline |
|
from io import StringIO |
|
|
|
# Page configuration is issued before any other Streamlit call so that nothing
# (including the cache spinner shown on a cold model load) precedes it.
st.set_page_config(layout="wide")


@st.cache_resource
def _load_unmasker():
    """Build the fill-mask pipeline.

    Streamlit re-executes this script on every widget interaction; caching
    the pipeline as a resource avoids re-instantiating the model each time.
    """
    return pipeline('fill-mask', model='dsfsi/zabantu-bantu-250m')


# Shared fill-mask pipeline used by fill_mask().
unmasker = _load_unmasker()

# NOTE(review): `stop_rerun` does not appear to be a documented Streamlit
# setting, so this assignment likely has no effect -- confirm before removing.
st.stop_rerun = True
|
|
|
def fill_mask(sentences):
    """Run the fill-mask pipeline on every sentence that contains ``<mask>``.

    Args:
        sentences: Mapping of a label (for example a language name) to a
            sentence string.

    Returns:
        A ``(results, warnings)`` tuple. ``results`` maps each label whose
        sentence contained ``<mask>`` to the pipeline output for it;
        ``warnings`` holds one message per sentence without the placeholder.
    """
    predictions_by_label = {}
    skipped_messages = []

    for label, text in sentences.items():
        # Guard clause: sentences without the placeholder are only reported.
        if "<mask>" not in text:
            skipped_messages.append(f"Warning: No <mask> token found in sentence: {text}")
            continue

        # Swap the user-facing placeholder for the tokenizer's own mask token.
        model_input = text.replace('<mask>', unmasker.tokenizer.mask_token)
        predictions_by_label[label] = unmasker(model_input)

    return predictions_by_label, skipped_messages
|
|
|
def replace_mask(sentence, predicted_word):
    """Return *sentence* with every ``<mask>`` replaced by the bolded prediction.

    The prediction is wrapped in ``**`` so it renders bold in Markdown.
    """
    highlighted = f"**{predicted_word}**"
    return sentence.replace("<mask>", highlighted)
|
|
|
|
|
# Page heading and short description of the model family.
st.title("Fill Mask | Zabantu-XLM-Roberta")
st.markdown("Zabantu-XLMR refers to a fleet of models trained on South African Bantu languages...")

# Seed the session-state slots read by the output column, without
# overwriting values that are already present from an earlier run.
for state_key, empty_value in (('warnings', []), ('results', {})):
    if state_key not in st.session_state:
        st.session_state[state_key] = empty_value
|
|
|
|
|
# Two side-by-side columns: inputs on the left, model output on the right.
col1, col2 = st.columns(2)

with col1:
    st.markdown("### Input :clipboard:")

    select_options = ['Choose option', 'Enter text input', 'Upload a file(csv/txt)']
    sample_sentence = {
        'tshivenda': "Rabulasi wa <mask> u khou bvelela nga u lima.",
        "tsonga": "N'wana wa xisati u <mask> ku tsaka."
    }

    option_selected = st.selectbox("Select an input option:", select_options, index=0)
    # Label -> sentence mapping handed to fill_mask() when Submit is pressed.
    input_sentences = {}

    if option_selected == 'Enter text input':
        st.session_state['warnings'].clear()
        language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']

        for i in range(5):
            language = st.selectbox(f"Select language for input {i+1}:", language_options, key=f'language_{i}')
            sentence = st.text_input(f"Enter sentence for input {i+1} (with <mask>):", key=f'sentence_{i}')

            if language != 'Choose language' and sentence:
                label = language.lower()
                # The same language may be picked on several rows; suffix the
                # later rows so they do not overwrite the earlier sentence.
                if label in input_sentences:
                    label = f"{label} (input {i+1})"
                input_sentences[label] = sentence

        if st.button("Submit"):
            if input_sentences:
                results, warnings = fill_mask(input_sentences)
                st.session_state['results'] = results
                st.session_state['warnings'] = warnings
            else:
                st.warning("Please fill at least one language and sentence.")

    elif option_selected == 'Upload a file(csv/txt)':
        uploaded_file = st.file_uploader("Choose a file (one sentence per line)")
        if uploaded_file:
            try:
                # "utf-8-sig" reads plain UTF-8 too and drops a leading BOM,
                # which would otherwise stick to the first sentence.
                file_text = uploaded_file.getvalue().decode("utf-8-sig")
            except UnicodeDecodeError:
                st.error("Could not decode the uploaded file; please upload UTF-8 encoded text.")
                file_text = ""

            for i, sentence in enumerate(file_text.splitlines()):
                # Skip blank lines; keys keep the original line numbers.
                if sentence.strip():
                    input_sentences[f'input_{i+1}'] = sentence

        if st.button("Submit"):
            if input_sentences:
                results, warnings = fill_mask(input_sentences)
                st.session_state['results'] = results
                st.session_state['warnings'] = warnings
            else:
                # Mirror the text-input branch instead of silently doing nothing.
                st.warning("Please upload a file containing at least one sentence.")

    # Built-in example that can be run without entering any input.
    st.markdown("### Example")
    st.code(sample_sentence)
    if st.button("Test Example"):
        result, warnings = fill_mask(sample_sentence)
        st.session_state['results'] = result
        st.session_state['warnings'] = warnings
|
|
|
with col2:
    st.markdown("### Output :bar_chart:")

    if st.session_state['results']:
        with st.container():
            for language, predictions in st.session_state['results'].items():
                if not predictions:
                    continue

                # A sentence with one mask yields a flat list of candidate
                # dicts; with several masks the pipeline yields one such list
                # per mask. Normalise to "list of candidate lists" so both
                # shapes render instead of the nested case raising.
                if isinstance(predictions[0], dict):
                    per_mask = [predictions]
                else:
                    per_mask = predictions
                several_masks = len(per_mask) > 1

                for mask_no, candidates in enumerate(per_mask, start=1):
                    if not candidates:
                        continue

                    # Candidates are taken in the order the pipeline returned them.
                    top_prediction = candidates[0]
                    predicted_word = top_prediction['token_str']
                    score = top_prediction['score'] * 100

                    heading = f"{language.capitalize()} Prediction"
                    if several_masks:
                        heading += f" (mask {mask_no})"

                    st.markdown(f"**{heading}:** {predicted_word} ({score:.2f}%)")
                    # Confidence bar; styled by the .bar/.bar-fill CSS classes.
                    st.markdown(f"<div class='bar'><div class='bar-fill' style='width:{score}%;'></div></div>", unsafe_allow_html=True)

    # Shown regardless of whether any sentence produced a result.
    if st.session_state['warnings']:
        for warning in st.session_state['warnings']:
            st.warning(warning)
|
|
|
|
|
# Page-wide styles: hide the footer element and style the confidence bars
# (.bar / .bar-fill) emitted in the output column. The rules are kept at
# column 0 so Markdown does not treat them as an indented code block.
css = """
<style>
footer {display:none !important;}
.bar {width: 70%; background-color: #e6e6e6; border-radius: 12px; height: 5px;}
.bar-fill {background-color: #17152e; height: 100%; border-radius: 12px;}
.container {display: flex; justify-content: space-between; align-items: center; margin-bottom: 5px;}
</style>
"""

# Raw HTML must be explicitly allowed for the <style> tag to be injected.
st.markdown(css, unsafe_allow_html=True)
|
|