"""Streamlit demo: fill-mask predictions with the Zabantu-XLM-Roberta model."""
import html

import streamlit as st
from transformers import pipeline

# Placeholder users type into the text area; it is swapped for the tokenizer's
# own mask token before inference.
MASK_PLACEHOLDER = "<mask>"

# st.set_page_config must run before any other Streamlit command, including the
# spinner st.cache_resource may show while the model loads below.
st.set_page_config(layout="wide")


@st.cache_resource
def _load_unmasker():
    """Build the fill-mask pipeline once and reuse it across reruns and sessions."""
    return pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')


unmasker = _load_unmasker()

# NOTE(review): the `<mask>` placeholders in these samples appear to have been
# stripped from the original source (no sample contained one); the positions
# below are a best guess -- confirm with the model authors.
sample_sentence = {
    'zulu': "Le ndoda ithi izo <mask> ukudla.",
    'tshivenda': "Mufana uyo <mask> vhukuma.",
    'sepedi': "Mosadi o <mask> pheka.",
    'tswana': "Monna o <mask> tsamaya.",
    'tsonga': "N'wana wa xisati u <mask> ku tsaka.",
}


def fill_mask(sentences):
    """Run the fill-mask model on every sentence that contains the placeholder.

    Args:
        sentences: mapping of label -> sentence text.

    Returns:
        A ``(results, warnings)`` tuple. ``results`` maps each label whose
        sentence contains ``<mask>`` to the raw pipeline output; ``warnings``
        holds one message per sentence without the placeholder (those sentences
        are not sent to the model).
    """
    results = {}
    warnings = []
    for language, sentence in sentences.items():
        if MASK_PLACEHOLDER in sentence:
            masked_sentence = sentence.replace(MASK_PLACEHOLDER, unmasker.tokenizer.mask_token)
            results[language] = unmasker(masked_sentence)
        else:
            warnings.append(f"Warning: No {MASK_PLACEHOLDER} token found in sentence: {sentence}")
    return results, warnings


def replace_mask(sentence, predicted_word):
    """Return *sentence* with every ``<mask>`` placeholder replaced by the bold predicted word."""
    return sentence.replace(MASK_PLACEHOLDER, f"**{predicted_word}**")


def _top_prediction(predictions):
    """Return the highest-ranked candidate from a fill-mask pipeline result, or None.

    For a single mask the pipeline returns a list of candidate dicts; for several
    masks it returns one such list per mask, and only the first mask's candidates
    are used here.
    """
    if predictions and isinstance(predictions[0], list):
        predictions = predictions[0]
    return predictions[0] if predictions else None


def _example_text():
    """Render the sample sentences as ``language: sentence`` lines."""
    return "\n".join(f"{lang}: {sentence}" for lang, sentence in sample_sentence.items())


def _parse_input(text):
    """Split the text area into labelled sentences, one per non-empty line.

    A line of the form ``language: sentence`` (as produced by "Test Example")
    whose prefix is one of the sample languages is labelled with that language
    and the prefix is not sent to the model. Every other line -- or a repeated
    language -- is labelled ``sentence_<line index>``.
    """
    parsed = {}
    for i, line in enumerate(text.split("\n")):
        if not line:
            continue
        prefix, sep, rest = line.partition(":")
        language = prefix.strip().lower()
        if sep and language in sample_sentence and language not in parsed:
            parsed[language] = rest.strip()
        else:
            parsed[f"sentence_{i}"] = line
    return parsed


def _load_example():
    """Button callback: put the sample sentences into the text area.

    Callbacks run before the next script rerun, so the keyed text area picks up
    the new value immediately (assigning a widget's key after the widget has been
    created in the same run is not allowed).
    """
    st.session_state['text_input'] = _example_text()


st.title("Fill Mask | Zabantu-XLM-Roberta")
st.write("")
st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages.")

col1, col2 = st.columns(2)

if 'text_input' not in st.session_state:
    st.session_state['text_input'] = ""
if 'warnings' not in st.session_state:
    st.session_state['warnings'] = []

with col1:
    with st.container(border=True):
        st.markdown("Input :clipboard:")
        # Bound to session state through `key` so the example callback can fill it.
        text_input = st.text_area("Enter sentences with <mask> token:", key='text_input')
        input_sentences = _parse_input(text_input)
        # Run the model once per rerun; the output column reuses these results.
        results, warnings = fill_mask(input_sentences)

        button1, button2, _ = st.columns([2, 2, 4])
        with button1:
            st.button("Test Example", on_click=_load_example)
        with button2:
            if st.button("Submit"):
                st.session_state['warnings'] = warnings

        for warning in st.session_state['warnings']:
            st.warning(warning)

        st.markdown("Example")
        st.code(_example_text(), wrap_lines=True)

with col2:
    with st.container(border=True):
        st.markdown("Output :bar_chart:")
        for language, predictions in results.items():
            top_prediction = _top_prediction(predictions)
            if top_prediction is None:
                continue
            # Escaped because the markup below is rendered with unsafe_allow_html.
            predicted_word = html.escape(top_prediction['token_str'])
            label = html.escape(language)
            score = top_prediction['score'] * 100
            # NOTE(review): the original HTML markup here appears to have been
            # stripped from the source; this is a minimal reconstruction.
            st.markdown(f"""
<div class="prediction">
<strong>{predicted_word}</strong> ({label})<br>
{score:.2f}%
</div>
""", unsafe_allow_html=True)

# NOTE(review): the contents of this style block appear to have been stripped
# from the original source; it currently injects no rules.
css = """ """
st.markdown(css, unsafe_allow_html=True)