import streamlit as st
from transformers import pipeline

# set_page_config must be the first Streamlit command in the script, so it is
# called before anything else (including the cached model loader below).
st.set_page_config(layout="wide")

# Literal the user types where the model should predict a word. It is rewritten
# to the tokenizer's own mask token before inference (see fill_mask).
MASK_PLACEHOLDER = "<mask>"


@st.cache_resource
def load_unmasker():
    """Build the fill-mask pipeline once and reuse it across reruns.

    Streamlit re-executes the whole script on every widget interaction; without
    caching, the model would be re-loaded each time.
    """
    return pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')


unmasker = load_unmasker()


def fill_mask(sentences):
    """Predict the masked word in each sentence.

    Args:
        sentences: mapping of result key -> ``(language, sentence)``. Each
            sentence must contain exactly one ``MASK_PLACEHOLDER``.

    Returns:
        ``(results, warnings)`` where ``results`` maps each accepted key to
        ``(language, original_sentence, predictions)`` (``predictions`` is the
        pipeline output for that sentence), and ``warnings`` holds one message
        per skipped sentence.
    """
    results = {}
    warnings = []
    for key, (language, sentence) in sentences.items():
        mask_count = sentence.count(MASK_PLACEHOLDER)
        if mask_count == 0:
            warnings.append(f"Warning: No {MASK_PLACEHOLDER} token found in sentence: {sentence}")
            continue
        if mask_count > 1:
            # With several masks the pipeline returns one prediction list per
            # mask, which the output section below does not handle.
            warnings.append(
                f"Warning: Only one {MASK_PLACEHOLDER} token is supported per sentence: {sentence}"
            )
            continue
        masked_sentence = sentence.replace(MASK_PLACEHOLDER, unmasker.tokenizer.mask_token)
        # Keep the original sentence so the output can be rendered on any later rerun.
        results[key] = (language, sentence, unmasker(masked_sentence))
    return results, warnings


def replace_mask(sentence, predicted_word):
    """Return ``sentence`` with its mask placeholder replaced by the bolded prediction."""
    return sentence.replace(MASK_PLACEHOLDER, f"**{predicted_word}**", 1)


st.title("Fill Mask | Zabantu-XLM-Roberta")
st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages.")

col1, col2 = st.columns(2)

if 'text_input' not in st.session_state:
    st.session_state['text_input'] = ""
if 'warnings' not in st.session_state:
    st.session_state['warnings'] = []
if 'result' not in st.session_state:
    st.session_state['result'] = {}

language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']
# Result key -> (language, sentence), the shape fill_mask expects.
input_sentences = {}

with col1:
    with st.container():
        st.markdown("Input :clipboard:")
        input1, input2 = st.columns(2)
        for i in range(5):
            with input1:
                language = st.selectbox(f"Select language for sentence {i+1}:", language_options, key=f'language_{i}')
            with input2:
                disabled = language == "Choose language"
                sentence = st.text_input(
                    f"Enter sentence for {language} (with {MASK_PLACEHOLDER}):",
                    key=f'text_input_{i}',
                    disabled=disabled,
                )
            if not disabled and sentence:
                # Include the row number so two rows in the same language do not overwrite each other.
                input_sentences[f"{language.lower()}_{i+1}"] = (language.lower(), sentence)

        button1, button2, _ = st.columns([2, 2, 4])
        with button1:
            test_example_clicked = st.button("Test Example")
        with button2:
            submit_clicked = st.button("Submit")

        if test_example_clicked:
            # NOTE(review): none of these sample sentences contains the <mask>
            # placeholder (it looks like it was stripped from the source), so
            # they currently only produce warnings. Restore the placeholder at
            # the intended position in each sentence.
            sample_sentences = {
                'zulu_1': ('zulu', "Le ndoda ithi izo ukudla."),
                'tshivenda_2': ('tshivenda', "Vhana vhane vha kha ḓi bva u bebwa vha kha khombo ya u nga Listeriosis."),
                'tshivenda_3': ('tshivenda', "Rabulasi wa u khou bvelela nga u lima"),
                'tswana_4': ('tswana', "Monna o tsamaya."),
                'tsonga_5': ('tsonga', "N'wana wa xisati u ku tsaka."),
            }
            st.session_state['result'], st.session_state['warnings'] = fill_mask(sample_sentences)

        if submit_clicked:
            st.session_state['result'], st.session_state['warnings'] = fill_mask(input_sentences)

        for warning in st.session_state['warnings']:
            st.warning(warning)

        st.markdown("Example")
        # NOTE(review): these examples also lack the <mask> placeholder.
        st.code({
            'zulu': "Le ndoda ithi izo ukudla.",
            'tshivenda': "Mufana uyo vhukuma.",
            'sepedi': "Mosadi o pheka.",
            'tswana': "Monna o tsamaya.",
            'tsonga': "N'wana wa xisati u ku tsaka."
        }, wrap_lines=True)

with col2:
    with st.container():
        st.markdown("Output :bar_chart:")
        for language, original_sentence, predictions in st.session_state['result'].values():
            # Only the first prediction in the pipeline output is displayed.
            top_prediction = predictions[0]
            predicted_word = top_prediction['token_str']
            score = top_prediction['score'] * 100
            st.markdown(f"""
{predicted_word} ({language})
{score:.2f}%
""", unsafe_allow_html=True)
            predicted_sentence = replace_mask(original_sentence, predicted_word)
            st.write(f"{language}: {predicted_sentence}\n")

css = """ """
st.markdown(css, unsafe_allow_html=True)