import streamlit as st
from transformers import pipeline
from io import StringIO

# Placeholder users type where a word should be predicted. It is swapped for
# the tokenizer's own mask token before the sentence reaches the model.
MASK_PLACEHOLDER = "<mask>"

# Page config goes first so no other Streamlit element (e.g. the cache spinner
# shown by load_unmasker) is rendered before it.
st.set_page_config(layout="wide")
# NOTE(review): `stop_rerun` is not a documented Streamlit option; this line only
# sets an attribute on the module object. Kept to avoid changing module state.
st.stop_rerun = True


@st.cache_resource
def load_unmasker():
    """Build the fill-mask pipeline once and reuse it across reruns.

    Streamlit re-executes this whole script on every widget interaction, so an
    uncached ``pipeline(...)`` call would reload the model each time.
    """
    return pipeline('fill-mask', model='dsfsi/zabantu-bantu-250m')


unmasker = load_unmasker()


def fill_mask(sentences):
    """Run the fill-mask model on every sentence containing MASK_PLACEHOLDER.

    Args:
        sentences: mapping of label (language name or input id) -> sentence.

    Returns:
        (results, warnings): ``results`` maps each label to the raw pipeline
        output for sentences that contained the placeholder; ``warnings`` holds
        one message per sentence that did not.
    """
    results = {}
    warnings = []
    for language, sentence in sentences.items():
        if MASK_PLACEHOLDER in sentence:
            masked_sentence = sentence.replace(MASK_PLACEHOLDER, unmasker.tokenizer.mask_token)
            results[language] = unmasker(masked_sentence)
        else:
            warnings.append(f"Warning: No {MASK_PLACEHOLDER} token found in sentence: {sentence}")
    return results, warnings


def replace_mask(sentence, predicted_word):
    """Return *sentence* with MASK_PLACEHOLDER replaced by the bolded word."""
    return sentence.replace(MASK_PLACEHOLDER, f"**{predicted_word}**")


def _top_predictions(predictions):
    """Return the highest-scoring candidate for each mask in a pipeline result.

    The fill-mask pipeline returns a flat list of candidate dicts for a single
    mask and a list of such lists when the input contains several masks.
    """
    if not predictions:
        return []
    if isinstance(predictions[0], list):
        return [candidates[0] for candidates in predictions if candidates]
    return [predictions[0]]


def _run_and_store(sentences):
    """Run fill_mask on *sentences* and keep results/warnings in session state."""
    results, warnings = fill_mask(sentences)
    st.session_state['results'] = results
    st.session_state['warnings'] = warnings


# Set up title and description
st.title("Fill Mask | Zabantu-XLM-Roberta")
st.markdown("Zabantu-XLMR refers to a fleet of models trained on South African Bantu languages...")

# Initialize session state so results/warnings survive Streamlit reruns
if 'warnings' not in st.session_state:
    st.session_state['warnings'] = []
if 'results' not in st.session_state:
    st.session_state['results'] = {}

# Define layout
col1, col2 = st.columns(2)

with col1:
    st.markdown("### Input :clipboard:")
    select_options = ['Choose option', 'Enter text input', 'Upload a file(csv/txt)']
    # NOTE(review): these samples contain no MASK_PLACEHOLDER (it appears to have
    # been stripped from this file), so "Test Example" currently only produces
    # warnings. Restore the placeholder at the intended positions.
    sample_sentence = {
        'tshivenda': "Rabulasi wa u khou bvelela nga u lima.",
        "tsonga": "N'wana wa xisati u ku tsaka."
    }
    option_selected = st.selectbox("Select an input option:", select_options, index=0)
    input_sentences = {}

    if option_selected == 'Enter text input':
        language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']
        for i in range(5):
            language = st.selectbox(f"Select language for input {i+1}:", language_options, key=f'language_{i}')
            sentence = st.text_input(f"Enter sentence for input {i+1} (with {MASK_PLACEHOLDER}):", key=f'sentence_{i}')
            # Only process filled language and sentence pairs; a repeated
            # language overwrites the earlier input for that language.
            if language != 'Choose language' and sentence:
                input_sentences[language.lower()] = sentence
        if st.button("Submit"):
            if input_sentences:
                _run_and_store(input_sentences)
            else:
                st.warning("Please fill at least one language and sentence.")

    elif option_selected == 'Upload a file(csv/txt)':
        uploaded_file = st.file_uploader("Choose a file (one sentence per line)")
        if uploaded_file:
            try:
                # utf-8-sig also strips a leading BOM, which would otherwise end
                # up glued to the first sentence.
                stringio = StringIO(uploaded_file.getvalue().decode("utf-8-sig"))
            except UnicodeDecodeError:
                st.error("The uploaded file is not valid UTF-8 text.")
            else:
                # Lines carry no language tag, so each is keyed by its 1-based
                # line number; blank lines are skipped instead of being reported
                # as missing a mask.
                for line_no, sentence in enumerate(stringio.read().splitlines(), start=1):
                    if sentence.strip():
                        input_sentences[f'input_{line_no}'] = sentence
                if st.button("Submit"):
                    if input_sentences:
                        _run_and_store(input_sentences)
                    else:
                        st.warning("The uploaded file contains no sentences.")

    st.markdown("### Example")
    st.code(sample_sentence)
    if st.button("Test Example"):
        _run_and_store(sample_sentence)

with col2:
    st.markdown("### Output :bar_chart:")
    if st.session_state['results']:
        with st.container():
            for language, predictions in st.session_state['results'].items():
                top_predictions = _top_predictions(predictions)
                if not top_predictions:
                    continue
                # One line per mask in the sentence (usually exactly one).
                for top_prediction in top_predictions:
                    predicted_word = top_prediction['token_str']
                    score = top_prediction['score'] * 100
                    st.markdown(f"**{language.capitalize()} Prediction:** {predicted_word} ({score:.2f}%)")
                # NOTE(review): the HTML originally passed here was lost from this
                # file; a horizontal rule is used as a placeholder separator.
                st.markdown("<hr>", unsafe_allow_html=True)

    if st.session_state['warnings']:
        for warning in st.session_state['warnings']:
            st.warning(warning)

# CSS for styling
# NOTE(review): the style rules appear to have been stripped from this file; as
# written this injects an effectively empty string.
css = """ """
st.markdown(css, unsafe_allow_html=True)