|
import html
from io import StringIO

import streamlit as st
from transformers import pipeline
|
|
|
# Configure the page before any other Streamlit call. This includes the
# spinner that the cached loader below may show; older Streamlit versions
# require set_page_config to be the first Streamlit command.
st.set_page_config(layout="wide")


@st.cache_resource
def load_unmasker():
    """Build the Zabantu fill-mask pipeline.

    Streamlit re-executes this whole script on every widget interaction.
    Caching the pipeline as a shared resource keeps one loaded model across
    reruns instead of reloading it each time.
    """
    return pipeline('fill-mask', model='dsfsi/zabantu-bantu-250m')


# Module-level handle used by fill_mask() below.
unmasker = load_unmasker()
|
|
|
def fill_mask(sentences):
    """Run the module-level ``unmasker`` pipeline on each input sentence.

    Args:
        sentences: Mapping from an input key to either a
            ``(language, sentence)`` tuple or a bare sentence string. For a
            bare string, the key itself is used as the language.

    Returns:
        A ``(results, warnings)`` tuple.
        - ``results`` maps each processed key to
          ``(predictions, language, sentence)``, where ``predictions`` is
          the raw pipeline output.
        - ``warnings`` lists a message for every input that was skipped.
    """
    results = {}
    warnings = []

    for key, value in sentences.items():
        # Accept plain strings (e.g. {'tsonga': "..."}) instead of failing
        # to unpack them into two names.
        if isinstance(value, str):
            language, sentence = key, value
        else:
            language, sentence = value

        if language == 'choose language':
            warnings.append(f"Warning: Choose language for {sentence}")
            continue

        if sentence == "":
            warnings.append(f"Warning: Enter sentence for {language}")
            continue

        mask_count = sentence.count("<mask>")
        if mask_count == 0:
            warnings.append(f"Warning: No <mask> token found in sentence: {sentence}")
            continue
        if mask_count > 1:
            # With several masks the pipeline returns one prediction list per
            # mask, a shape the output rendering cannot display.
            warnings.append(f"Warning: Only one <mask> token is supported per sentence: {sentence}")
            continue

        # Translate the UI placeholder into the tokenizer's own mask token.
        masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
        results[key] = (unmasker(masked_sentence), language, sentence)

    return results, warnings
|
|
|
def replace_mask(sentence, predicted_word):
    """Fill every ``<mask>`` placeholder in *sentence* with *predicted_word*.

    The inserted word is wrapped in ``**`` so it renders bold in Markdown.
    """
    highlighted = f"**{predicted_word}**"
    pieces = sentence.split("<mask>")
    return highlighted.join(pieces)
|
|
|
# Page header: spacer, centred logo, links, title and model description.
st.write("")

# Three equal columns; only the middle one is used, to centre the logo.
_, img2, _ = st.columns(3)
with img2:
    with st.container(border=False):
        st.image("logo_transparent_small.png")

st.markdown("""
<div style='text-align: center;'>
<a href='https://github.com/dsfsi' target='_blank'>Github</a> |
<a href='https://docs.google.com/forms/d/e/1FAIpQLSf7S36dyAUPx2egmXbFpnTBuzoRulhL5Elu-N1eoMhaO7v10w/viewform' target='_blank'>Feedback Form</a> |
<a href='https://huggingface.co/papers/1911.02116' target='_blank'>arxiv</a>
</div>
""", unsafe_allow_html=True)

st.markdown("""
<div style='text-align: center;'>
<h2>Fill Mask | Zabantu-XLM-Roberta</h2>
</div>
""", unsafe_allow_html=True)
st.write("")

st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages. It supports the following languages Tshivenda, Nguni languages (Zulu, Xhosa, Swati), Sotho languages (Northern Sotho, Southern Sotho, Setswana), and Xitsonga.")

with st.expander("More information about the space"):
    # Kept flush-left: indented text in Markdown would render as a code block.
    st.write('''
Authors: Alexis Conneau, Kartikay Khandelwal, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer, Veselin Stoyanov
''')
|
col1, col2 = st.columns(2)

if 'text_input' not in st.session_state:
    st.session_state['text_input'] = ""

if 'warnings' not in st.session_state:
    st.session_state['warnings'] = []

# Sentences collected from the chosen input method. Keys are a 1-based
# input/line number; values are (lowercased language, sentence) tuples.
input_sentences = {}

# Model output for this run. It stays empty until a button below runs the
# model; the output column treats an empty value as "nothing to show".
result = {}


def _parse_uploaded_text(text):
    """Parse uploaded text made of ``Language: sentence`` lines.

    Blank lines are ignored. Only the first ``:`` separates the language from
    the sentence, so sentences may contain colons themselves.

    Returns:
        ``(parsed, problems)``.
        - ``parsed`` maps the 1-based line number (as a string) to
          ``(language, sentence)``, with the language stripped and
          lowercased and the sentence stripped.
        - ``problems`` holds a warning for each non-blank line lacking ``:``.
    """
    parsed = {}
    problems = []
    # splitlines() also drops the '\r' of Windows line endings.
    for line_no, line in enumerate(text.splitlines(), start=1):
        if not line.strip():
            continue
        language, separator, sentence = line.partition(":")
        if not separator:
            problems.append(f"Warning: No ':' token found in sentence: {line} in line {line_no}")
            continue
        parsed[f'{line_no}'] = (language.strip().lower(), sentence.strip())
    return parsed, problems


def _show_pending_warnings():
    """Display every warning queued in session state, then empty the queue."""
    for warning in st.session_state['warnings']:
        st.warning(warning)
    st.session_state['warnings'].clear()


with col1:
    with st.container(border=True):
        st.markdown("Input :clipboard:")

        select_options = ['Choose option', 'Enter text input', 'Upload a file(csv/txt)']
        # Same (language, sentence) shape as the other inputs, so fill_mask
        # can unpack each value.
        sample_sentence = {
            '1': ('tshivenda', "Rabulasi wa <mask> u khou bvelela nga u lima."),
            '2': ('tsonga', "N'wana wa xisati u <mask> ku tsaka."),
        }
        language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']

        option_selected = st.selectbox("Select an input option:", select_options, index=0)

        if option_selected == 'Enter text input':
            st.session_state['warnings'].clear()

            @st.fragment
            def choose_language(i):
                """Render the language picker for input ``i`` and return its value."""
                language = st.selectbox(f"Select language for input {i+1}:",
                                        language_options, key=f'language_{i}', index=0)
                return language

            input1, input2 = st.columns(2)
            for i in range(5):
                with input1:
                    language = choose_language(i)
                with input2:
                    sentence = st.text_input(f"Enter sentence for input {i+1} (with <mask>):", key=f'text_input_{i}')
                if sentence:
                    if language:
                        input_sentences[f'{i+1}'] = (language.lower(), sentence)
                    else:
                        # Append rather than overwrite, so every unlabelled input is reported.
                        st.session_state['warnings'].append(f"Warning: Choose the language for input {i+1}")

            if st.button("Submit", use_container_width=True):
                result, warnings = fill_mask(input_sentences)
                st.session_state['warnings'] = warnings

            _show_pending_warnings()

        if option_selected == 'Upload a file(csv/txt)':
            uploaded_file = st.file_uploader("Choose a file-(one sentence per line)")
            if uploaded_file is not None:
                try:
                    # utf-8-sig also strips a leading BOM, which would otherwise
                    # end up glued to the first line's language name.
                    text = uploaded_file.getvalue().decode("utf-8-sig")
                except UnicodeDecodeError:
                    st.session_state['warnings'] = ["Warning: The uploaded file is not valid UTF-8 text"]
                else:
                    parsed, warnings = _parse_uploaded_text(text)
                    input_sentences.update(parsed)
                    st.session_state['warnings'] = warnings

            if st.button("Submit", use_container_width=True):
                result, warnings = fill_mask(input_sentences)
                st.session_state['warnings'] = warnings

            _show_pending_warnings()

        st.markdown("Example")
        code = ("Tshivenda: Rabulasi wa <mask> u khou bvelela nga u lima.\n"
                "Tsonga: N'wana wa xisati u <mask> ku tsaka.")
        st.code(code, language="python", wrap_lines=True)

        if st.button("Test Example", use_container_width=True):
            result, warnings = fill_mask(sample_sentence)
            for warning in warnings:
                st.warning(warning)
|
|
|
def _render_score_bar(label, score, score_align):
    """Draw one prediction as a score bar followed by its label and percentage.

    Args:
        label: Text shown in the left cell. It is HTML-escaped because it
            carries model output and, for uploaded files, a user-supplied
            language name.
        score: The pipeline score multiplied by 100.
        score_align: CSS value for the score cell's ``align-items``.
    """
    bar_html = (
        '<div class="bar">'
        f'<div class="bar-fill" style="width: {score}%;"></div>'
        '</div>'
        '<div class="container">'
        f'<div style="align-items: left;">{html.escape(str(label))}</div>'
        f'<div style="align-items: {score_align};">{score:.2f}%</div>'
        '</div>'
    )
    # Built without leading indentation: indented HTML passed to st.markdown
    # can be parsed as a Markdown code block instead of raw HTML.
    st.markdown(bar_html, unsafe_allow_html=True)


with col2:
    with st.container(border=True):
        st.markdown("Output :bar_chart:")

        # `result` is only bound once the input column has run the model.
        output = globals().get('result') or {}

        if len(output) == 1:
            # Single sentence: list every candidate token the pipeline returned.
            for predictions, language, sentence in output.values():
                for prediction in predictions:
                    _render_score_bar(prediction['token_str'], prediction['score'] * 100, "center")
        else:
            # Several sentences: show only each sentence's top candidate, tagged with its language.
            for predictions, language, sentence in output.values():
                if predictions:
                    top_prediction = predictions[0]
                    _render_score_bar(f"{top_prediction['token_str']} ({language})",
                                      top_prediction['score'] * 100, "right")

        for line, (predictions, language, sentence) in enumerate(output.values(), start=1):
            # Same empty guard as above; predictions[0] would raise IndexError otherwise.
            if not predictions:
                continue
            full_sentence = replace_mask(sentence, predictions[0]['token_str'])
            st.write(f"**Sentence {line}:** {full_sentence}")
|
|
|
# Global stylesheet injected into the page.
# Raw string: selectors such as `hover\:bg-orange-50` need a literal backslash,
# and in a normal string `\:` is an invalid escape sequence (Python warns about it).
# NOTE(review): the .gr-* and Tailwind-style (.from-orange-*, .to-orange-*, ...)
# selectors look like they target Gradio markup and may match nothing that
# Streamlit renders; only .container, .bar and .bar-fill are used by this app.
css = r"""
<style>
footer {display:none !important;}

.gr-button-primary {
    z-index: 14;
    height: 43px;
    width: 130px;
    left: 0px;
    top: 0px;
    padding: 0px;
    cursor: pointer !important;
    background: none rgb(17, 20, 45) !important;
    border: none !important;
    text-align: center !important;
    font-family: Poppins !important;
    font-size: 14px !important;
    font-weight: 500 !important;
    color: rgb(255, 255, 255) !important;
    line-height: 1 !important;
    border-radius: 12px !important;
    transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
    box-shadow: none !important;
}
.gr-button-primary:hover{
    z-index: 14;
    height: 43px;
    width: 130px;
    left: 0px;
    top: 0px;
    padding: 0px;
    cursor: pointer !important;
    background: none rgb(66, 133, 244) !important;
    border: none !important;
    text-align: center !important;
    font-family: Poppins !important;
    font-size: 14px !important;
    font-weight: 500 !important;
    color: rgb(255, 255, 255) !important;
    line-height: 1 !important;
    border-radius: 12px !important;
    transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
    box-shadow: rgb(0 0 0 / 23%) 0px 1px 7px 0px !important;
}
.hover\:bg-orange-50:hover {
    --tw-bg-opacity: 1 !important;
    background-color: rgb(229,225,255) !important;
}
.to-orange-200 {
    --tw-gradient-to: rgb(37 56 133 / 37%) !important;
}
.from-orange-400 {
    --tw-gradient-from: rgb(17, 20, 45) !important;
    --tw-gradient-to: rgb(255 150 51 / 0);
    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to) !important;
}
.group-hover\:from-orange-500{
    --tw-gradient-from:rgb(17, 20, 45) !important;
    --tw-gradient-to: rgb(37 56 133 / 37%);
    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to) !important;
}
.group:hover .group-hover\:text-orange-500{
    --tw-text-opacity: 1 !important;
    color:rgb(37 56 133 / var(--tw-text-opacity)) !important;
}

.container {
    display: flex;
    justify-content: space-between;
    align-items: center;
    margin-bottom: 5px;
    width: 100%;
}
.bar {
    /* width: 70%; */
    background-color: #e6e6e6;
    border-radius: 12px;
    overflow: hidden;
    margin-right: 10px;
    height: 5px;
}
.bar-fill {
    background-color: #17152e;
    height: 100%;
    border-radius: 12px;
}

</style>
"""

st.markdown(css, unsafe_allow_html=True)