"""Streamlit demo: fill-mask predictions with the Zabantu-XLM-Roberta model.

Users enter up to five sentences, each tagged with a South African Bantu
language and containing a ``<mask>`` placeholder. The app shows the model's
top prediction for every placeholder and, after a button press, the sentences
with the predicted words filled in.
"""
import json

import streamlit as st
from transformers import pipeline

# Placeholder users type where the model should predict a word. The previous
# literal was the empty string, which matches every sentence at every position.
MASK_PLACEHOLDER = "<mask>"
LANGUAGE_OPTIONS = ['Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']
NUM_INPUTS = 5
# NOTE(review): mask positions restored from context (they were missing from
# the sentences) — confirm against the intended examples.
EXAMPLE_SENTENCES = {
    'zulu': "Le ndoda ithi izo <mask> ukudla.",
    'tshivenda': "Mufana uyo <mask> vhukuma.",
    'sepedi': "Mosadi o <mask> pheka.",
    'tswana': "Monna o <mask> tsamaya.",
    'tsonga': "N'wana wa xisati u <mask> ku tsaka.",
}

# Must run before any other Streamlit command (the cached loader below shows
# a spinner, which counts as one).
st.set_page_config(layout="wide")


@st.cache_resource
def load_unmasker():
    """Load the fill-mask pipeline once and reuse it across script reruns."""
    return pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')


unmasker = load_unmasker()


def fill_mask(sentences):
    """Run the fill-mask model over every sentence containing a placeholder.

    Args:
        sentences: Mapping of language key (e.g. ``'zulu'``) to input sentence.

    Returns:
        ``(results, warnings)``. ``results`` maps each language whose sentence
        contained ``MASK_PLACEHOLDER`` to the raw pipeline output. ``warnings``
        holds one message per sentence that had no placeholder.
    """
    results = {}
    warnings = []
    for language, sentence in sentences.items():
        if MASK_PLACEHOLDER in sentence:
            # Translate the user-facing placeholder into the tokenizer's own
            # mask token before calling the model.
            masked_sentence = sentence.replace(MASK_PLACEHOLDER, unmasker.tokenizer.mask_token)
            results[language] = unmasker(masked_sentence)
        else:
            warnings.append(f"Warning: No {MASK_PLACEHOLDER} token found in sentence: {sentence}")
    return results, warnings


def replace_mask(sentence, predicted_word, count=-1):
    """Replace mask placeholders in *sentence* with *predicted_word* in bold.

    Args:
        sentence: Text containing ``MASK_PLACEHOLDER``.
        predicted_word: Word to substitute for the placeholder.
        count: Maximum number of placeholders to replace, left to right.
            The default ``-1`` replaces all of them.

    Returns:
        The sentence with the replaced placeholders rendered as ``**word**``.
    """
    return sentence.replace(MASK_PLACEHOLDER, f"**{predicted_word}**", count)


def _top_predictions(predictions):
    """Return the highest-scoring candidate for each mask, in sentence order.

    The pipeline returns a flat candidate list for a single mask and a list of
    candidate lists when the input contains several masks.
    """
    if not predictions:
        return []
    if isinstance(predictions[0], list):
        return [candidates[0] for candidates in predictions if candidates]
    return [predictions[0]]


st.title("Fill Mask | Zabantu-XLM-Roberta")
st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages.")

if 'warnings' not in st.session_state:
    st.session_state['warnings'] = []

col1, col2 = st.columns(2)
input_sentences = {}
# Only filled when a button is pressed. Initialised so the output panel never
# hits an undefined name on an ordinary rerun.
result = {}

with col1:
    with st.container():
        st.markdown("### Input :clipboard:")
        input1, input2 = st.columns(2)
        for i in range(NUM_INPUTS):
            with input1:
                # Default each row to a different language so rows do not
                # silently overwrite one another in ``input_sentences``.
                language = st.selectbox(
                    f"Select language for sentence {i+1}:",
                    LANGUAGE_OPTIONS,
                    index=i % len(LANGUAGE_OPTIONS),
                    key=f'language_{i}',
                )
            with input2:
                sentence = st.text_input(
                    f"Enter sentence for {language} (with {MASK_PLACEHOLDER}):",
                    key=f'text_input_{i}',
                )
            if sentence:
                language_key = language.lower()
                if language_key in input_sentences:
                    st.warning(f"Sentence {i+1} replaces an earlier {language} sentence.")
                input_sentences[language_key] = sentence

        button1, button2, _ = st.columns([2, 2, 4])
        with button1:
            run_example = st.button("Test Example")
        with button2:
            run_submit = st.button("Submit")

        if run_example:
            input_sentences = dict(EXAMPLE_SENTENCES)
        if run_example or run_submit:
            result, new_warnings = fill_mask(input_sentences)
            # Store warnings for both buttons so stale ones don't linger.
            st.session_state['warnings'] = new_warnings

        for warning in st.session_state['warnings']:
            st.warning(warning)

        st.markdown("### Example")
        st.code(
            json.dumps(EXAMPLE_SENTENCES, indent=2, ensure_ascii=False),
            language="json",
            wrap_lines=True,
        )

with col2:
    with st.container():
        st.markdown("### Output :bar_chart:")
        # Reuse the button-triggered predictions when present. Otherwise
        # predict live, but only for sentences that contain a placeholder
        # (the pipeline raises on input without a mask).
        shown_predictions = result or fill_mask(input_sentences)[0]
        for language, predictions in shown_predictions.items():
            for top in _top_predictions(predictions):
                st.markdown(f"**{top['token_str']}** ({language})  \n{top['score'] * 100:.2f}%")

        for language, predictions in result.items():
            predicted_sentence = input_sentences[language]
            for top in _top_predictions(predictions):
                # Fill placeholders left to right, one predicted word each.
                predicted_sentence = replace_mask(predicted_sentence, top['token_str'], 1)
            st.write(f"{language}: {predicted_sentence}\n")