"""Streamlit demo: fill-mask predictions with the Zabantu-XLM-Roberta model."""

import streamlit as st
from transformers import pipeline

# Placeholder the user types where the model should predict a word.
MASK_PLACEHOLDER = "<mask>"

# Example sentences, keyed by language; used by the "Test Example" button and
# shown verbatim in the "Example" panel.
# NOTE(review): none of these sentences contains MASK_PLACEHOLDER, so running
# them only produces warnings. The placeholders look like they were lost;
# restore them at the intended positions.
SAMPLE_SENTENCES = {
    'zulu': "Le ndoda ithi izo ukudla.",
    'tshivenda': "Mufana uyo vhukuma.",
    'sepedi': "Mosadi o pheka.",
    'tswana': "Monna o tsamaya.",
    'tsonga': "N'wana wa xisati u ku tsaka.",
}


@st.cache_resource
def _load_unmasker():
    """Build the fill-mask pipeline once and reuse it across script reruns."""
    return pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')


# Page config is issued before the cached loader runs so it stays the first
# Streamlit command of the script.
st.set_page_config(layout="wide")
unmasker = _load_unmasker()
# NOTE(review): `stop_rerun` does not appear to be a documented Streamlit
# attribute; this assignment is kept as-is but probably has no effect.
st.stop_rerun = True


def fill_mask(sentences):
    """Run the fill-mask model on every sentence that contains a mask.

    Args:
        sentences: Mapping of key -> (language, sentence).

    Returns:
        A ``(results, warnings)`` tuple. ``results`` maps each key to
        ``(language, model_output)`` for sentences containing
        ``MASK_PLACEHOLDER``; ``warnings`` holds one message per sentence
        that does not contain it.
    """
    results = {}
    warnings = []
    for key, (language, sentence) in sentences.items():
        if MASK_PLACEHOLDER not in sentence:
            warnings.append(
                f"Warning: No {MASK_PLACEHOLDER} token found in sentence: {sentence}"
            )
            continue
        # Translate the UI placeholder into the tokenizer's own mask token.
        masked_sentence = sentence.replace(
            MASK_PLACEHOLDER, unmasker.tokenizer.mask_token
        )
        results[key] = (language, unmasker(masked_sentence))
    return results, warnings


def replace_mask(sentence, predicted_word):
    """Return *sentence* with the mask placeholder replaced by the bolded word."""
    return sentence.replace(MASK_PLACEHOLDER, f"**{predicted_word}**")


def _run_and_store(sentences):
    """Run ``fill_mask`` and keep its output, plus the inputs, in session state."""
    results, warnings = fill_mask(sentences)
    st.session_state['result'] = results
    st.session_state['warnings'] = warnings
    # Remember the source sentences so the output panel can rebuild them
    # with the predicted word filled in.
    st.session_state['sentences'] = {
        key: text for key, (_, text) in sentences.items()
    }


st.title("Fill Mask | Zabantu-XLM-Roberta")
st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages.")

col1, col2 = st.columns(2)

if 'text_input' not in st.session_state:
    st.session_state['text_input'] = ""
if 'warnings' not in st.session_state:
    st.session_state['warnings'] = []
if 'result' not in st.session_state:
    st.session_state['result'] = {}
if 'sentences' not in st.session_state:
    st.session_state['sentences'] = {}

language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']
input_sentences = {}

with col1:
    with st.container():
        st.markdown("Input :clipboard:")
        input1, input2 = st.columns(2)
        for i in range(5):
            with input1:
                language = st.selectbox(
                    f"Select language for sentence {i+1}:",
                    language_options,
                    key=f'language_{i}',
                    index=0,
                )
            with input2:
                sentence = st.text_input(
                    f"Enter sentence for {language} (with {MASK_PLACEHOLDER}):",
                    key=f'text_input_{i}',
                )
            if sentence:
                input_sentences[f'{language.lower()}_{i+1}'] = (language.lower(), sentence)

        button1, button2, _ = st.columns([2, 2, 4])
        if st.button("Test Example"):
            # fill_mask expects (language, sentence) tuples, so wrap the
            # plain sample strings before passing them on.
            _run_and_store(
                {lang: (lang, text) for lang, text in SAMPLE_SENTENCES.items()}
            )
        if st.button("Submit"):
            _run_and_store(input_sentences)

        for warning in st.session_state['warnings']:
            st.warning(warning)

        st.markdown("Example")
        st.code(SAMPLE_SENTENCES, wrap_lines=True)

with col2:
    with st.container():
        st.markdown("Output :bar_chart:")

        # One card per sentence showing the top prediction and its score.
        for key, (language, predictions) in st.session_state['result'].items():
            if not predictions:
                continue
            top_prediction = predictions[0]
            predicted_word = top_prediction['token_str']
            score = top_prediction['score'] * 100
            st.markdown(f"""
{predicted_word} ({language})
{score:.2f}%
""", unsafe_allow_html=True)

        # Each submitted sentence with the top predicted word filled in.
        for key, (language, predictions) in st.session_state['result'].items():
            original_sentence = st.session_state['sentences'].get(key, "")
            if not predictions or not original_sentence:
                continue
            predicted_sentence = replace_mask(
                original_sentence, predictions[0]['token_str']
            )
            st.write(f"{language}: {predicted_sentence}\n")

css = """ """
st.markdown(css, unsafe_allow_html=True)