UnarineLeo commited on
Commit
91470f6
1 Parent(s): 5b5eeef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -14
app.py CHANGED
@@ -5,7 +5,6 @@ unmasker = pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')
5
 
6
  st.set_page_config(layout="wide")
7
 
8
- # Disable auto-rerun when selecting options or typing
9
  st.stop_rerun = True
10
 
11
  def fill_mask(sentences):
@@ -47,33 +46,29 @@ with col1:
47
 
48
  input1, input2 = st.columns(2)
49
 
50
- # Loop to gather input sentences
51
  for i in range(5):
52
  with input1:
53
  language = st.selectbox(f"Select language for sentence {i+1}:", language_options, key=f'language_{i}', index=0)
54
  with input2:
55
  sentence = st.text_input(f"Enter sentence for {language} (with <mask>):", key=f'text_input_{i}')
56
  if sentence:
57
- # Use a unique key for each sentence (even if languages are the same)
58
  input_sentences[f'{language.lower()}_{i+1}'] = (language.lower(), sentence)
59
 
60
  button1, button2, _ = st.columns([2, 2, 4])
61
 
62
- # Call fill_mask on button click, not on form input
63
  if st.button("Test Example"):
64
  sample_sentences = {
65
- 'zulu_1': ('zulu', "Le ndoda ithi izo <mask> ukudla."),
66
- 'tshivenda_2': ('tshivenda', "Vhana vhane vha kha ḓi bva u bebwa vha kha khombo ya u <mask> nga Listeriosis."),
67
- 'tshivenda_3': ('tshivenda', "Rabulasi wa <mask> u khou bvelela nga u lima"),
68
- 'tswana_4': ('tswana', "Monna o <mask> tsamaya."),
69
- 'tsonga_5': ('tsonga', "N'wana wa xisati u <mask> ku tsaka.")
70
  }
71
  st.session_state['result'], st.session_state['warnings'] = fill_mask(sample_sentences)
72
 
73
  if st.button("Submit"):
74
  st.session_state['result'], st.session_state['warnings'] = fill_mask(input_sentences)
75
 
76
- # Display warnings
77
  if st.session_state['warnings']:
78
  for warning in st.session_state['warnings']:
79
  st.warning(warning)
@@ -90,7 +85,6 @@ with col1:
90
  with col2:
91
  with st.container():
92
  st.markdown("Output :bar_chart:")
93
- # Check for the result in session_state and display predictions
94
  if st.session_state['result']:
95
  for key, (language, predictions) in st.session_state['result'].items():
96
  original_sentence = input_sentences[key][1] if key in input_sentences else ""
@@ -109,8 +103,12 @@ with col2:
109
  </div>
110
  """, unsafe_allow_html=True)
111
 
112
- predicted_sentence = replace_mask(original_sentence, predicted_word)
113
- st.write(f"{language}: {predicted_sentence}\n")
 
 
 
 
114
 
115
  css = """
116
  <style>
@@ -137,4 +135,4 @@ footer {display:none !important;}
137
  }
138
  </style>
139
  """
140
- st.markdown(css, unsafe_allow_html=True)
 
5
 
6
  st.set_page_config(layout="wide")
7
 
 
8
  st.stop_rerun = True
9
 
10
  def fill_mask(sentences):
 
46
 
47
  input1, input2 = st.columns(2)
48
 
 
49
  for i in range(5):
50
  with input1:
51
  language = st.selectbox(f"Select language for sentence {i+1}:", language_options, key=f'language_{i}', index=0)
52
  with input2:
53
  sentence = st.text_input(f"Enter sentence for {language} (with <mask>):", key=f'text_input_{i}')
54
  if sentence:
 
55
  input_sentences[f'{language.lower()}_{i+1}'] = (language.lower(), sentence)
56
 
57
  button1, button2, _ = st.columns([2, 2, 4])
58
 
 
59
  if st.button("Test Example"):
60
  sample_sentences = {
61
+ 'zulu': "Le ndoda ithi izo <mask> ukudla.",
62
+ 'tshivenda': "Mufana uyo <mask> vhukuma.",
63
+ 'sepedi': "Mosadi o <mask> pheka.",
64
+ 'tswana': "Monna o <mask> tsamaya.",
65
+ 'tsonga': "N'wana wa xisati u <mask> ku tsaka."
66
  }
67
  st.session_state['result'], st.session_state['warnings'] = fill_mask(sample_sentences)
68
 
69
  if st.button("Submit"):
70
  st.session_state['result'], st.session_state['warnings'] = fill_mask(input_sentences)
71
 
 
72
  if st.session_state['warnings']:
73
  for warning in st.session_state['warnings']:
74
  st.warning(warning)
 
85
  with col2:
86
  with st.container():
87
  st.markdown("Output :bar_chart:")
 
88
  if st.session_state['result']:
89
  for key, (language, predictions) in st.session_state['result'].items():
90
  original_sentence = input_sentences[key][1] if key in input_sentences else ""
 
103
  </div>
104
  """, unsafe_allow_html=True)
105
 
106
+ if 'predictions' in locals():
107
+ if result:
108
+ for language, language_predictions in result.items():
109
+ original_sentence = sample_sentence[language]
110
+ predicted_sentence = replace_mask(original_sentence, language_predictions[0]['token_str'])
111
+ st.write(f"{language}: {predicted_sentence}\n")
112
 
113
  css = """
114
  <style>
 
135
  }
136
  </style>
137
  """
138
+ st.markdown(css, unsafe_allow_html=True)