"""Streamlit demo: fill-mask predictions with the Zabantu-ven-120m model."""

from io import StringIO

import streamlit as st
from transformers import pipeline

# Page configuration is issued before anything else is rendered (including
# the spinner that the cached model loader below may display).
st.set_page_config(layout="wide")


@st.cache_resource
def _load_unmasker():
    """Create the fill-mask pipeline once and reuse it across script reruns.

    Streamlit re-executes this script on every widget interaction; without
    caching, the pipeline would be rebuilt each time.
    """
    return pipeline('fill-mask', model='dsfsi/zabantu-ven-120m')


unmasker = _load_unmasker()

# Ask the tokenizer for its mask token instead of hard-coding the literal.
MASK_TOKEN = unmasker.tokenizer.mask_token


def fill_mask(sentences):
    """Run the fill-mask pipeline on each sentence containing one mask token.

    Args:
        sentences: Iterable of sentence strings. Surrounding whitespace is
            stripped and blank entries are skipped.

    Returns:
        A ``(results, warnings)`` tuple. ``results`` maps each processed
        sentence to the pipeline output for it; ``warnings`` is a list of
        messages for sentences that were not processed.
    """
    results = {}
    warnings = []
    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            # Blank lines (e.g. a trailing newline) are not worth a warning.
            continue
        mask_count = sentence.count(MASK_TOKEN)
        if mask_count == 1:
            results[sentence] = unmasker(sentence)
        elif mask_count == 0:
            warnings.append(
                f"Warning: No `{MASK_TOKEN}` token found in sentence: {sentence}"
            )
        else:
            # With several masks the pipeline returns a nested list, which the
            # rendering code below does not handle.
            warnings.append(
                f"Warning: More than one `{MASK_TOKEN}` token found in sentence: {sentence}"
            )
    return results, warnings


def replace_mask(sentence, predicted_word):
    """Return ``sentence`` with the mask token replaced by the bolded word."""
    # Strip so that stray whitespace does not break the Markdown bold markers.
    return sentence.replace(MASK_TOKEN, f"**{predicted_word.strip()}**")


st.title("Fill Mask | Zabantu-ven-120m")
st.write("")

st.markdown("This is a variant of Zabantu pre-trained on a monolingual dataset of Tshivenda(ven) sentences on a transformer network with 120 million traininable parameters.")

col1, col2 = st.columns(2)

if 'text_input' not in st.session_state:
    st.session_state['text_input'] = ""

if 'warnings' not in st.session_state:
    st.session_state['warnings'] = []

# Predictions produced during this run, if any button was pressed.
result = None

with col1:
    with st.container(border=True):
        st.markdown("Input :clipboard:")

        select_options = ['Choose option', 'Enter text input', 'Upload a file(csv/txt)']
        # NOTE(review): the mask token was missing from this sentence in the
        # reviewed source; its position after "ya u" is a reconstruction --
        # confirm against the intended example.
        sample_sentence = (
            "Vhana vhane vha kha ḓi bva u bebwa vha kha khombo ya u "
            f"{MASK_TOKEN} nga Listeriosis."
        )

        option_selected = st.selectbox("Select an input option:", select_options, index=0)

        if option_selected == 'Enter text input':
            text_input = st.text_area(
                f"Enter sentences with `{MASK_TOKEN}` token(one sentence per line):",
                value=st.session_state['text_input']
            )
            if st.button("Submit", use_container_width=True):
                # splitlines() also copes with "\r\n" line endings.
                result, warnings = fill_mask(text_input.splitlines())
                st.session_state['warnings'] = warnings

        if option_selected == 'Upload a file(csv/txt)':
            uploaded_file = st.file_uploader("Choose a file-(one sentence per line)")
            if uploaded_file is not None:
                try:
                    string_data = uploaded_file.getvalue().decode("utf-8")
                except UnicodeDecodeError:
                    st.error("The uploaded file could not be decoded as UTF-8 text.")
                    string_data = None
                if string_data is not None and st.button("Submit", use_container_width=True):
                    result, warnings = fill_mask(string_data.splitlines())
                    st.session_state['warnings'] = warnings

        if st.session_state['warnings']:
            for warning in st.session_state['warnings']:
                st.warning(warning)

        st.markdown("Example")
        st.code(sample_sentence, wrap_lines=True)
        if st.button("Test Example", use_container_width=True):
            result, _ = fill_mask([sample_sentence])

with col2:
    with st.container(border=True):
        st.markdown("Output :bar_chart:")
        # NOTE(review): the Markdown bodies below are passed with
        # unsafe_allow_html=True but contain no markup in the reviewed source;
        # any HTML wrapper appears to have been lost -- confirm.
        if result:
            if len(result) == 1:
                # Single sentence: list every candidate the pipeline returned.
                for predictions in result.values():
                    for prediction in predictions:
                        predicted_word = prediction['token_str']
                        score = prediction['score'] * 100
                        st.markdown(
                            f"\n{predicted_word}\n{score:.2f}%\n",
                            unsafe_allow_html=True,
                        )
            else:
                # Several sentences: show only the first candidate for each.
                for index, predictions in enumerate(result.values(), start=1):
                    if predictions:
                        top_prediction = predictions[0]
                        predicted_word = top_prediction['token_str']
                        score = top_prediction['score'] * 100
                        st.markdown(
                            f"\n{predicted_word} (line {index})\n{score:.2f}%\n",
                            unsafe_allow_html=True,
                        )

# NOTE(review): original indentation was lost; this summary is placed below
# both columns -- confirm the intended placement.
if result:
    for line, (sentence, predictions) in enumerate(result.items(), start=1):
        if not predictions:
            continue
        predicted_word = predictions[0]['token_str']
        full_sentence = replace_mask(sentence, predicted_word)
        st.write(f"**Sentence {line}:** {full_sentence }")

css = """ """
st.markdown(css, unsafe_allow_html=True)