File size: 4,874 Bytes
f5f8f9a 15eff6a f5f8f9a 0802504 6cbe2dc 4d5270d f5f8f9a 4d5270d f5f8f9a 4d5270d b854242 4d5270d 0ce811e b854242 4d5270d f5f8f9a 7a621b0 b92b795 7a621b0 d99c1e8 4d5270d d99c1e8 15eff6a d99c1e8 4d5270d d99c1e8 b51c864 d104ff1 d99c1e8 5b5eeef d99c1e8 4d5270d d99c1e8 fa59459 d99c1e8 0716ef3 d99c1e8 89b629e b3dcfe5 d99c1e8 b3dcfe5 d99c1e8 b3dcfe5 d99c1e8 9ba3728 b51c864 d99c1e8 15eff6a d99c1e8 b3dcfe5 d99c1e8 5b5eeef d99c1e8 f5f8f9a b51c864 d99c1e8 f5f8f9a b3dcfe5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import streamlit as st
from transformers import pipeline
from io import StringIO
# set_page_config must be the first Streamlit command of the script, so it runs
# before the cached model load below (which may display a loading spinner).
st.set_page_config(layout="wide")

# NOTE(review): nothing in this file reads `st.stop_rerun`; this assignment only
# sets an attribute on the streamlit module. Kept as-is — confirm and remove.
st.stop_rerun = True


@st.cache_resource
def _load_unmasker():
    """Build the fill-mask pipeline once and reuse it across reruns.

    Streamlit re-executes the whole script on every widget interaction, so an
    uncached module-level ``pipeline(...)`` call would reload the model each time.
    """
    return pipeline('fill-mask', model='dsfsi/zabantu-bantu-250m')


# Shared fill-mask pipeline used by fill_mask().
unmasker = _load_unmasker()
def fill_mask(sentences):
    """Predict the ``<mask>`` word(s) for every sentence that has a placeholder.

    Args:
        sentences: mapping from a label (a language name, or an ``input_N`` id)
            to the sentence text.

    Returns:
        ``(results, warnings)``.

        - ``results`` maps each label whose sentence contains ``<mask>`` to the
          raw output of the module-level ``unmasker`` pipeline.
        - ``warnings`` holds one message per sentence that has no ``<mask>``,
          in input order.
    """
    predictions_by_label = {}
    messages = []
    for label, text in sentences.items():
        if "<mask>" not in text:
            messages.append(f"Warning: No <mask> token found in sentence: {text}")
            continue
        # Swap the UI placeholder for the tokenizer's own mask token.
        model_input = text.replace('<mask>', unmasker.tokenizer.mask_token)
        predictions_by_label[label] = unmasker(model_input)
    return predictions_by_label, messages
def replace_mask(sentence, predicted_word):
    """Return *sentence* with every ``<mask>`` placeholder swapped for *predicted_word* in Markdown bold."""
    highlighted = f"**{predicted_word}**"
    # Splitting on the placeholder and re-joining replaces every occurrence.
    return highlighted.join(sentence.split("<mask>"))
# Page header: title plus a short model description.
st.title("Fill Mask | Zabantu-XLM-Roberta")
st.markdown("Zabantu-XLMR refers to a fleet of models trained on South African Bantu languages...")

# Seed per-session state on the first run only; later reruns keep whatever
# results/warnings a previous submission stored.
for _state_key, _make_default in (('warnings', list), ('results', dict)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
def _run_and_store(sentences):
    """Run ``fill_mask`` on *sentences* and save the outcome in session state.

    Results and warnings go into ``st.session_state`` so the output column can
    render them and they survive reruns caused by other widgets.
    """
    results, warning_messages = fill_mask(sentences)
    st.session_state['results'] = results
    st.session_state['warnings'] = warning_messages


def _read_uploaded_sentences(uploaded_file):
    """Return ``{'input_<line number>': line}`` for each non-blank line of an upload.

    Blank lines are skipped so they do not produce spurious "No <mask> token"
    warnings. Keys keep the 1-based line number so each result can be traced
    back to its line in the file. Returns ``None`` (after showing an error)
    when the file is not valid UTF-8.
    """
    try:
        # "utf-8-sig" decodes plain UTF-8 too, and additionally drops a leading
        # byte-order mark that would otherwise end up in the first sentence.
        text = uploaded_file.getvalue().decode("utf-8-sig")
    except UnicodeDecodeError:
        st.error("Could not decode the uploaded file; please upload UTF-8 encoded text.")
        return None
    # NOTE(review): each line is treated as one sentence with no language label;
    # CSV rows are not split into columns.
    return {
        f'input_{line_no}': line
        for line_no, line in enumerate(text.splitlines(), start=1)
        if line.strip()
    }


# Define layout: inputs on the left, model output on the right.
col1, col2 = st.columns(2)
with col1:
    st.markdown("### Input :clipboard:")
    select_options = ['Choose option', 'Enter text input', 'Upload a file(csv/txt)']
    sample_sentence = {
        'tshivenda': "Rabulasi wa <mask> u khou bvelela nga u lima.",
        "tsonga": "N'wana wa xisati u <mask> ku tsaka."
    }
    option_selected = st.selectbox("Select an input option:", select_options, index=0)
    input_sentences = {}
    if option_selected == 'Enter text input':
        st.session_state['warnings'].clear()  # Clear warnings before new input
        language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']
        for i in range(5):
            language = st.selectbox(f"Select language for input {i+1}:", language_options, key=f'language_{i}')
            sentence = st.text_input(f"Enter sentence for input {i+1} (with <mask>):", key=f'sentence_{i}')
            # Only process filled language and sentence pairs.
            if language == 'Choose language' or not sentence:
                continue
            key = language.lower()
            # The same language can be chosen for several inputs; suffix the
            # repeats so a later sentence does not overwrite an earlier one.
            if key in input_sentences:
                key = f"{key} (input {i+1})"
            input_sentences[key] = sentence
        if st.button("Submit"):
            if input_sentences:
                _run_and_store(input_sentences)
            else:
                st.warning("Please fill at least one language and sentence.")
    elif option_selected == 'Upload a file(csv/txt)':
        uploaded_file = st.file_uploader("Choose a file (one sentence per line)", type=["csv", "txt"])
        if uploaded_file:
            input_sentences = _read_uploaded_sentences(uploaded_file)
            # No Submit button when the file could not be decoded (None).
            if input_sentences is not None and st.button("Submit"):
                if input_sentences:
                    _run_and_store(input_sentences)
                else:
                    st.warning("The uploaded file contains no sentences.")
    st.markdown("### Example")
    st.code(sample_sentence)
    if st.button("Test Example"):
        _run_and_store(sample_sentence)
def _top_predictions(predictions):
    """Return the best candidate for each ``<mask>`` in one sentence's pipeline output.

    Per the transformers fill-mask pipeline documentation, the output has two
    shapes:

    - a single mask gives a flat list of candidate dicts, best first;
    - several masks give a list of such lists, one per mask.

    Both shapes are reduced to a list holding the first candidate per mask.
    Empty output yields ``[]``.
    """
    if not predictions:
        return []
    if isinstance(predictions[0], dict):
        return [predictions[0]]
    return [candidates[0] for candidates in predictions if candidates]


with col2:
    st.markdown("### Output :bar_chart:")
    if st.session_state['results']:
        # Group all prediction rows in a single container.
        with st.container():
            for language, predictions in st.session_state['results'].items():
                top_predictions = _top_predictions(predictions)
                for mask_no, top_prediction in enumerate(top_predictions, start=1):
                    predicted_word = top_prediction['token_str']
                    score = top_prediction['score'] * 100
                    # Number the rows only when a sentence has several masks,
                    # so single-mask output reads exactly as before.
                    mask_label = f" (mask {mask_no})" if len(top_predictions) > 1 else ""
                    st.markdown(f"**{language.capitalize()} Prediction{mask_label}:** {predicted_word} ({score:.2f}%)")
                    # Score bar; the .bar/.bar-fill classes are styled by the page CSS.
                    st.markdown(f"<div class='bar'><div class='bar-fill' style='width:{score}%;'></div></div>", unsafe_allow_html=True)
    if st.session_state['warnings']:
        for warning in st.session_state['warnings']:
            st.warning(warning)
# Page-wide CSS injected as raw HTML: hides Streamlit's footer and styles the
# score bars drawn in the output column (.bar = track, .bar-fill = filled part).
css = "\n".join([
    "",
    "<style>",
    "footer {display:none !important;}",
    ".bar {width: 70%; background-color: #e6e6e6; border-radius: 12px; height: 5px;}",
    ".bar-fill {background-color: #17152e; height: 100%; border-radius: 12px;}",
    ".container {display: flex; justify-content: space-between; align-items: center; margin-bottom: 5px;}",
    "</style>",
    "",
])
st.markdown(body=css, unsafe_allow_html=True)
|