UnarineLeo's picture
Update app.py
a21e8a3 verified
raw
history blame
10.3 kB
import html
from io import StringIO

import streamlit as st
from transformers import pipeline
# set_page_config is issued before any other Streamlit call that renders
# an element (the cached loader below shows a spinner on first use).
st.set_page_config(layout="wide")


@st.cache_resource
def _load_unmasker():
    """Build the fill-mask pipeline once and reuse it across reruns and sessions.

    Streamlit re-executes this whole script on every widget interaction; without
    the cache the dsfsi/zabantu-bantu-250m model would be reloaded each time.
    """
    return pipeline('fill-mask', model='dsfsi/zabantu-bantu-250m')


# Module-level pipeline handle used by fill_mask().
unmasker = _load_unmasker()
def fill_mask(sentences, pipe=None):
    """Predict the ``<mask>`` token of each sentence with a fill-mask pipeline.

    Args:
        sentences: Either a mapping of label -> sentence (for example
            ``{"zulu_1": "..."}``) or an iterable of sentences such as the
            lines of an uploaded file. Iterable entries are labelled
            ``line_<n>`` by their 1-based position, stripped of surrounding
            whitespace, and skipped when blank.
        pipe: Fill-mask callable to use. Defaults to the module-level
            ``unmasker``; it must expose ``tokenizer.mask_token``.

    Returns:
        ``(results, warnings)``: ``results`` maps each label to ``pipe``'s
        output for that sentence; ``warnings`` holds one user-facing message
        per sentence that was skipped because it has no ``<mask>`` or more
        than one.
    """
    if pipe is None:
        pipe = unmasker
    if hasattr(sentences, "items"):
        labelled = sentences.items()
    else:
        # Plain sequences (e.g. file lines) have no .items(); label them by position.
        labelled = (
            (f"line_{n}", line.strip())
            for n, line in enumerate(sentences, start=1)
            if line.strip()
        )
    results = {}
    warnings = []
    for label, sentence in labelled:
        mask_count = sentence.count("<mask>")
        if mask_count == 0:
            warnings.append(f"Warning: No <mask> token found in sentence: {sentence}")
        elif mask_count > 1:
            # Several masks make the pipeline return one ranking per mask (a
            # nested list); keep results to a single flat ranking per sentence.
            warnings.append(
                f"Warning: Only one <mask> token is supported per sentence: {sentence}"
            )
        else:
            masked_sentence = sentence.replace("<mask>", pipe.tokenizer.mask_token)
            results[label] = pipe(masked_sentence)
    return results, warnings
def replace_mask(sentence, predicted_word):
    """Substitute every ``<mask>`` in *sentence* with *predicted_word* in Markdown bold.

    Returns the new string; a sentence without ``<mask>`` comes back unchanged.
    """
    bolded = "**{}**".format(predicted_word)
    return sentence.replace("<mask>", bolded)
st.title("Fill Mask | Zabantu-XLM-Roberta")
st.write("")
st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages. It supports the following languages Tshivenda, Nguni languages (Zulu, Xhosa, Swati), Sotho languages (Northern Sotho, Southern Sotho, Setswana), and Xitsonga.")

col1, col2 = st.columns(2)

if 'text_input' not in st.session_state:
    st.session_state['text_input'] = ""
# Saved widget values, so typed entries survive switching between input modes
# (Streamlit drops a keyed widget's state on runs where it is not rendered).
if 'input_sentences' not in st.session_state:
    st.session_state['input_sentences'] = {}
if 'languages_selected' not in st.session_state:
    st.session_state['languages_selected'] = {}
# Reset on every rerun: warnings carried over from a previous run (including
# fill_mask's own) would otherwise block the next Submit from processing.
st.session_state['warnings'] = []

select_options = ['Choose option', 'Enter text input', 'Upload a file(csv/txt)']
sample_sentence = {
    'tshivenda': "Rabulasi wa <mask> u khou bvelela nga u lima.",
    "tsonga": "N'wana wa xisati u <mask> ku tsaka.",
}
language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']

# Sentences sent to the model on this run (label -> sentence) and the model
# output per label; both stay empty until one of the buttons is pressed.
submitted = {}
result = {}


def _predict(sentences):
    """Run fill_mask on *sentences*, queue its warnings, return ``(sentences, results)``."""
    predictions, run_warnings = fill_mask(sentences)
    st.session_state['warnings'].extend(run_warnings)
    return sentences, predictions


def _first_mask(predictions):
    """Return the ranked candidates for the first ``<mask>`` of one sentence.

    The transformers fill-mask pipeline returns a flat list of candidate dicts
    for a single mask but a list of such lists when a sentence has several.
    """
    if predictions and isinstance(predictions[0], list):
        return predictions[0]
    return predictions or []


def _render_bar(caption, score, align):
    """Draw a score bar (``score`` is a percentage) with the caption and value below it."""
    # Model tokens and labels are escaped because this markup is rendered as raw HTML.
    markup = (
        '<div class="bar">'
        f'<div class="bar-fill" style="width: {score}%;"></div>'
        '</div>'
        '<div class="container">'
        f'<div style="align-items: left;">{html.escape(caption)}</div>'
        f'<div style="align-items: {align};">{score:.2f}%</div>'
        '</div>'
    )
    st.markdown(markup, unsafe_allow_html=True)


with col1:
    with st.container(border=True):
        st.markdown("Input :clipboard:")
        option_selected = st.selectbox("Select an input option:", select_options, index=0)
        input_sentences = {}

        if option_selected == 'Enter text input':
            saved_languages = st.session_state['languages_selected']
            saved_sentences = st.session_state['input_sentences']
            missing_language = []
            input1, input2 = st.columns(2)
            for i in range(5):
                language_key = f'language_{i}'
                sentence_key = f'text_input_{i}'
                previous_language = saved_languages.get(language_key, 'Choose language')
                previous_sentence = saved_sentences.get(sentence_key, '')
                with input1:
                    language = st.selectbox(
                        f"Select language for input {i+1}:",
                        language_options,
                        key=language_key,
                        index=(language_options.index(previous_language)
                               if previous_language in language_options else 0),
                    )
                with input2:
                    sentence = st.text_input(
                        f"Enter sentence for input {i+1} (with <mask>):",
                        key=sentence_key,
                        value=previous_sentence,
                    )
                saved_languages[language_key] = language
                saved_sentences[sentence_key] = sentence
                if not sentence:
                    continue
                if language == 'Choose language':
                    missing_language.append(f"Warning: Choose the language for input {i+1}")
                else:
                    # Built fresh each run: exactly one entry per filled row,
                    # and rows whose language changed leave no stale entry.
                    input_sentences[f'{language.lower()}_{i+1}'] = sentence
            if st.button("Submit"):
                if missing_language:
                    st.session_state['warnings'].extend(missing_language)
                elif not input_sentences:
                    st.session_state['warnings'].append(
                        "Warning: Enter at least one sentence containing <mask>")
                else:
                    submitted, result = _predict(input_sentences)

        if option_selected == 'Upload a file(csv/txt)':
            uploaded_file = st.file_uploader("Choose a file-(one sentence per line)")
            if uploaded_file is not None:
                try:
                    # utf-8-sig also drops a leading byte-order mark if present.
                    file_text = uploaded_file.getvalue().decode("utf-8-sig")
                except UnicodeDecodeError:
                    file_text = ""
                    st.session_state['warnings'].append(
                        "Warning: The uploaded file must be UTF-8 encoded text")
                # Key each non-blank line by its line number so every result
                # can be mapped back to the sentence it came from.
                for line_number, line in enumerate(file_text.splitlines(), start=1):
                    if line.strip():
                        input_sentences[f'line_{line_number}'] = line.strip()
            if st.button("Submit", use_container_width=True):
                if uploaded_file is None:
                    st.session_state['warnings'].append("Warning: Upload a file before submitting")
                elif input_sentences:
                    submitted, result = _predict(input_sentences)

        # Reserve this spot so warnings render above the example even though
        # the example button further down can still add to them.
        warning_area = st.container()
        st.markdown("Example")
        st.code(sample_sentence, wrap_lines=True)
        if st.button("Test Example", use_container_width=True):
            submitted, result = _predict(sample_sentence)
        with warning_area:
            for message in st.session_state['warnings']:
                st.warning(message)

with col2:
    with st.container(border=True):
        st.markdown("Output :bar_chart:")
        if result:
            if len(result) == 1:
                # One sentence: show every candidate the pipeline ranked.
                for predictions in result.values():
                    for prediction in _first_mask(predictions):
                        _render_bar(prediction['token_str'], prediction['score'] * 100, "center")
            else:
                # Several sentences: show only each one's top candidate, tagged with its label.
                for label, predictions in result.items():
                    ranked = _first_mask(predictions)
                    if ranked:
                        top = ranked[0]
                        _render_bar(f"{top['token_str']} ({label})", top['score'] * 100, "right")
            for line, (label, predictions) in enumerate(result.items(), start=1):
                ranked = _first_mask(predictions)
                if not ranked:
                    continue
                # Fill the submitted sentence itself, not its label (the dict key).
                full_sentence = replace_mask(submitted.get(label, label), ranked[0]['token_str'])
                st.write(f"**Sentence {line}:** {full_sentence}")
# Page-wide styles injected as raw HTML. The .bar, .bar-fill and .container
# rules style the prediction bars drawn in the output column.
# NOTE(review): the .gr-button-* and Tailwind-style (.from-orange-*, .hover\:...)
# selectors look like Gradio leftovers and probably match nothing in Streamlit —
# confirm before removing.
# Raw string: "\:" in those selectors is an invalid escape sequence in a normal
# string literal (SyntaxWarning on Python 3.12+); the CSS text is unchanged.
css = r"""
<style>
footer {display:none !important;}
.gr-button-primary {
    z-index: 14;
    height: 43px;
    width: 130px;
    left: 0px;
    top: 0px;
    padding: 0px;
    cursor: pointer !important;
    background: none rgb(17, 20, 45) !important;
    border: none !important;
    text-align: center !important;
    font-family: Poppins !important;
    font-size: 14px !important;
    font-weight: 500 !important;
    color: rgb(255, 255, 255) !important;
    line-height: 1 !important;
    border-radius: 12px !important;
    transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
    box-shadow: none !important;
}
.gr-button-primary:hover{
    z-index: 14;
    height: 43px;
    width: 130px;
    left: 0px;
    top: 0px;
    padding: 0px;
    cursor: pointer !important;
    background: none rgb(66, 133, 244) !important;
    border: none !important;
    text-align: center !important;
    font-family: Poppins !important;
    font-size: 14px !important;
    font-weight: 500 !important;
    color: rgb(255, 255, 255) !important;
    line-height: 1 !important;
    border-radius: 12px !important;
    transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
    box-shadow: rgb(0 0 0 / 23%) 0px 1px 7px 0px !important;
}
.hover\:bg-orange-50:hover {
    --tw-bg-opacity: 1 !important;
    background-color: rgb(229,225,255) !important;
}
.to-orange-200 {
    --tw-gradient-to: rgb(37 56 133 / 37%) !important;
}
.from-orange-400 {
    --tw-gradient-from: rgb(17, 20, 45) !important;
    --tw-gradient-to: rgb(255 150 51 / 0);
    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to) !important;
}
.group-hover\:from-orange-500{
    --tw-gradient-from:rgb(17, 20, 45) !important;
    --tw-gradient-to: rgb(37 56 133 / 37%);
    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to) !important;
}
.group:hover .group-hover\:text-orange-500{
    --tw-text-opacity: 1 !important;
    color:rgb(37 56 133 / var(--tw-text-opacity)) !important;
}
.container {
    display: flex;
    justify-content: space-between;
    align-items: center;
    margin-bottom: 5px;
    width: 100%;
}
.bar {
    /* width: 70%; */
    background-color: #e6e6e6;
    border-radius: 12px;
    overflow: hidden;
    margin-right: 10px;
    height: 5px;
}
.bar-fill {
    background-color: #17152e;
    height: 100%;
    border-radius: 12px;
}
</style>
"""
st.markdown(css, unsafe_allow_html=True)