"""Streamlit demo: fill-mask predictions with the Zabantu XLM-RoBERTa model.

Sentences containing a mask placeholder are either typed into a form or
uploaded as a text file with one ``language: sentence`` entry per line; the app
shows the model's predictions for the masked word.
"""
import html
from io import StringIO  # noqa: F401 -- kept from the original; the upload is now decoded directly

import streamlit as st
from transformers import pipeline

# Must be the first Streamlit command of the script run, so it precedes the
# cached model load below (st.cache_resource shows a spinner by default).
st.set_page_config(layout="wide")

# Placeholder users type in place of the word to predict.
# NOTE(review): in this copy of the file the literal had been stripped (it read
# "", which matches every position of every string). "<mask>" is
# XLM-RoBERTa's mask token; confirm it is the intended placeholder.
MASK_PLACEHOLDER = "<mask>"

# Lower-cased value of the "no language picked" entry of the language picker.
CHOOSE_LANGUAGE = "choose language"

select_options = ['Choose option', 'Enter text input', 'Upload a file(csv/txt)']
# NOTE(review): these samples contain no mask placeholder (it was probably
# stripped along with the other "<...>" text), so "Test Example" currently
# only yields "no token found" warnings. Restore the placeholder.
sample_sentence = {'tshivenda': "Rabulasi wa u khou bvelela nga u lima.",
                   "tsonga": "N'wana wa xisati u ku tsaka."}
language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']


@st.cache_resource
def load_unmasker():
    """Create the fill-mask pipeline once and reuse it across reruns.

    Streamlit re-executes the whole script on every widget interaction;
    without caching the model would be reloaded each time.
    """
    return pipeline('fill-mask', model='dsfsi/zabantu-bantu-250m')


unmasker = load_unmasker()


def fill_mask(sentences):
    """Run the fill-mask model on each ``(language, sentence)`` entry.

    Args:
        sentences: mapping of an arbitrary key (input or line number) to a
            ``(language, sentence)`` tuple. ``language`` is lower-cased, and
            ``"choose language"`` means no language was picked.

    Returns:
        ``(results, warnings)``. ``results`` maps each accepted key to
        ``(predictions, language, sentence)``, where ``predictions`` is the
        pipeline output for the sentence's single masked position.
        ``warnings`` lists human-readable messages for skipped entries.
    """
    results = {}
    warnings = []
    for key, (language, sentence) in sentences.items():
        if language == CHOOSE_LANGUAGE:
            warnings.append(f"Warning: Choose language for {sentence}")
            continue
        if not sentence:
            warnings.append(f"Warning: Enter sentence for {language}")
            continue
        mask_count = sentence.count(MASK_PLACEHOLDER)
        if mask_count == 0:
            warnings.append(f"Warning: No `{MASK_PLACEHOLDER}` token found in sentence: {sentence}")
            continue
        if mask_count > 1:
            # With several masks the pipeline returns one prediction list per
            # mask (a nested list), which the output section cannot display.
            warnings.append(f"Warning: Use exactly one `{MASK_PLACEHOLDER}` token in sentence: {sentence}")
            continue
        masked_sentence = sentence.replace(MASK_PLACEHOLDER, unmasker.tokenizer.mask_token)
        results[key] = (unmasker(masked_sentence), language, sentence)
    return results, warnings


def replace_mask(sentence, predicted_word):
    """Return ``sentence`` with the placeholder replaced by the word in bold markdown."""
    return sentence.replace(MASK_PLACEHOLDER, f"**{predicted_word.strip()}**")


def parse_uploaded_lines(text):
    """Parse ``language: sentence`` lines from an uploaded file's text.

    Splits on the first ``:`` only, so sentences may themselves contain
    colons. Blank lines are skipped.

    Returns:
        ``(entries, warnings)``. ``entries`` maps the 1-based line number
        (as a string) to ``(language.lower(), sentence)``. ``warnings`` lists
        lines without a ``:`` separator.
    """
    entries = {}
    warnings = []
    for line_no, line in enumerate(text.splitlines(), start=1):
        if not line.strip():
            continue
        language, separator, sentence = line.partition(":")
        if not separator:
            warnings.append(f"Warning: No ':' token found in sentence: {line} in line {line_no}")
            continue
        entries[str(line_no)] = (language.strip().lower(), sentence.strip())
    return entries, warnings


@st.fragment
def choose_language(i):
    """Render the language picker for input ``i`` and return the chosen option.

    As a fragment, changing the selection reruns only this picker. The main
    script sees the new value on its next full rerun (e.g. pressing Submit).
    """
    return st.selectbox(f"Select language for input {i+1}:", language_options,
                        key=f'language_{i}', index=0)


def render_prediction(predicted_word, score_percent, language=None):
    """Render one prediction: the word (plus its language, if given) and its score."""
    label = html.escape(predicted_word)
    if language:
        label = f"{label} ({html.escape(language)})"
    # Values come from model output and user uploads, and are rendered with
    # unsafe_allow_html=True, so they are HTML-escaped above.
    # NOTE(review): the HTML markup around these values seems to have been
    # stripped from this copy of the file; restore it here.
    st.markdown(f"""
{label}
{score_percent:.2f}%
""", unsafe_allow_html=True)


st.title("Fill Mask | Zabantu-XLM-Roberta")
st.write("")
st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages. It supports the following languages Tshivenda, Nguni languages (Zulu, Xhosa, Swati), Sotho languages (Northern Sotho, Southern Sotho, Setswana), and Xitsonga.")

col1, col2 = st.columns(2)

# Per-run state: Streamlit re-executes the script on each interaction, so
# plain locals replace the former session-state bookkeeping.
result = {}
warnings = []
input_sentences = {}

with col1:
    with st.container(border=True):
        st.markdown("Input :clipboard:")
        option_selected = st.selectbox("Select an input option:", select_options, index=0)

        if option_selected == 'Enter text input':
            language_col, sentence_col = st.columns(2)
            for i in range(5):
                with language_col:
                    language = choose_language(i)
                with sentence_col:
                    sentence = st.text_input(
                        f"Enter sentence for input {i+1} (with `{MASK_PLACEHOLDER}`):",
                        key=f'text_input_{i}')
                if sentence:
                    # An unpicked language is reported by fill_mask.
                    input_sentences[f'{i+1}'] = (language.lower(), sentence)
            if st.button("Submit", use_container_width=True):
                result, warnings = fill_mask(input_sentences)

        elif option_selected == 'Upload a file(csv/txt)':
            uploaded_file = st.file_uploader("Choose a file-(one sentence per line)")
            if uploaded_file is not None:
                try:
                    text = uploaded_file.getvalue().decode("utf-8")
                except UnicodeDecodeError:
                    warnings.append("Warning: The uploaded file is not valid UTF-8 text.")
                else:
                    input_sentences, warnings = parse_uploaded_lines(text)
                    if st.button("Submit", use_container_width=True):
                        result, fill_warnings = fill_mask(input_sentences)
                        # Keep parse warnings alongside the model-side ones.
                        warnings.extend(fill_warnings)

        # Reserve the warnings' position here; it is filled after the example
        # button so that example warnings are shown too.
        warning_area = st.container()

        st.markdown("Example")
        st.code("\n".join(f"{lang}: {text}" for lang, text in sample_sentence.items()),
                wrap_lines=True)
        if st.button("Test Example", use_container_width=True):
            # fill_mask expects {key: (language, sentence)}. Passing the raw
            # {language: sentence} dict made it unpack each sentence string.
            example_inputs = {str(n): (lang, text)
                              for n, (lang, text) in enumerate(sample_sentence.items(), start=1)}
            result, warnings = fill_mask(example_inputs)

        with warning_area:
            for warning in warnings:
                st.warning(warning)

with col2:
    with st.container(border=True):
        st.markdown("Output :bar_chart:")
        if len(result) == 1:
            # Single sentence: show every candidate the model returned.
            for predictions, language, sentence in result.values():
                for prediction in predictions:
                    render_prediction(prediction['token_str'], prediction['score'] * 100)
        else:
            # Several sentences: show only the top candidate for each.
            for predictions, language, sentence in result.values():
                if predictions:
                    top_prediction = predictions[0]
                    render_prediction(top_prediction['token_str'],
                                      top_prediction['score'] * 100, language)

        for line, (predictions, language, sentence) in enumerate(result.values(), start=1):
            if predictions:
                full_sentence = replace_mask(sentence, predictions[0]['token_str'])
                st.write(f"**Sentence {line}:** {full_sentence}")

# NOTE(review): the stylesheet content appears to have been stripped from this
# copy of the file; restore the <style> block here.
css = """ """
st.markdown(css, unsafe_allow_html=True)