UnarineLeo commited on
Commit
4eb39ad
1 Parent(s): d75ba48

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -42
app.py CHANGED
@@ -8,11 +8,11 @@ st.set_page_config(layout="wide")
8
  def fill_mask(sentences):
9
  results = {}
10
  warnings = []
11
- for key, (language, sentence) in sentences.items():
12
  if "<mask>" in sentence:
13
  masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
14
  unmasked = unmasker(masked_sentence)
15
- results[key] = (language, unmasked)
16
  else:
17
  warnings.append(f"Warning: No <mask> token found in sentence: {sentence}")
18
  return results, warnings
@@ -25,48 +25,51 @@ st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combi
25
 
26
  col1, col2 = st.columns(2)
27
 
28
- if 'text_input' not in st.session_state:
29
- st.session_state['text_input'] = ""
30
-
31
  if 'warnings' not in st.session_state:
32
  st.session_state['warnings'] = []
33
 
34
- if 'result' not in st.session_state:
35
- st.session_state['result'] = {}
36
-
37
  language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']
38
 
39
- input_sentences = {}
40
-
41
  with col1:
42
  with st.container():
43
  st.markdown("Input :clipboard:")
44
 
45
  input1, input2 = st.columns(2)
46
 
 
47
  for i in range(5):
48
- with input1:
49
  language = st.selectbox(f"Select language for sentence {i+1}:", language_options, key=f'language_{i}')
50
  with input2:
 
51
  disabled = True if language == "Choose language" else False
52
  sentence = st.text_input(f"Enter sentence for {language} (with <mask>):", key=f'text_input_{i}', disabled=disabled)
53
  if not disabled and sentence:
54
  input_sentences[language.lower()] = sentence
55
 
56
  button1, button2, _ = st.columns([2, 2, 4])
57
-
58
- if st.button("Test Example"):
59
- sample_sentences = {
60
- 'zulu_1': ('zulu', "Le ndoda ithi izo <mask> ukudla."),
61
- 'tshivenda_2': ('tshivenda', "Vhana vhane vha kha ḓi bva u bebwa vha kha khombo ya u <mask> nga Listeriosis."),
62
- 'tshivenda_3': ('tshivenda', "Rabulasi wa <mask> u khou bvelela nga u lima"),
63
- 'tswana_4': ('tswana', "Monna o <mask> tsamaya."),
64
- 'tsonga_5': ('tsonga', "N'wana wa xisati u <mask> ku tsaka.")
65
- }
66
- st.session_state['result'], st.session_state['warnings'] = fill_mask(sample_sentences)
67
-
68
- if st.button("Submit"):
69
- st.session_state['result'], st.session_state['warnings'] = fill_mask(input_sentences)
 
 
 
 
 
 
70
 
71
  if st.session_state['warnings']:
72
  for warning in st.session_state['warnings']:
@@ -84,25 +87,29 @@ with col1:
84
  with col2:
85
  with st.container():
86
  st.markdown("Output :bar_chart:")
87
- if st.session_state['result']:
88
- for key, (language, predictions) in st.session_state['result'].items():
89
- original_sentence = input_sentences[key][1]
90
- predicted_word = predictions[0]['token_str']
91
- score = predictions[0]['score'] * 100
92
-
93
- st.markdown(f"""
94
- <div class="bar">
95
- <div class="bar-fill" style="width: {score}%;"></div>
96
- </div>
97
- <div class="container">
98
- <div style="align-items: left;">{predicted_word} ({language})</div>
99
- <div style="align-items: right;">{score:.2f}%</div>
100
- </div>
101
- """, unsafe_allow_html=True)
102
-
103
- predicted_sentence = replace_mask(original_sentence, predicted_word)
104
- st.write(f"{language}: {predicted_sentence}\n")
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  css = """
107
  <style>
108
  footer {display:none !important;}
 
8
  def fill_mask(sentences):
9
  results = {}
10
  warnings = []
11
+ for language, sentence in sentences.items():
12
  if "<mask>" in sentence:
13
  masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
14
  unmasked = unmasker(masked_sentence)
15
+ results[language] = unmasked
16
  else:
17
  warnings.append(f"Warning: No <mask> token found in sentence: {sentence}")
18
  return results, warnings
 
25
 
26
  col1, col2 = st.columns(2)
27
 
28
+ # Initialize session states
29
+ if 'submit_clicked' not in st.session_state:
30
+ st.session_state['submit_clicked'] = False
31
  if 'warnings' not in st.session_state:
32
  st.session_state['warnings'] = []
33
 
 
 
 
34
  language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']
35
 
 
 
36
  with col1:
37
  with st.container():
38
  st.markdown("Input :clipboard:")
39
 
40
  input1, input2 = st.columns(2)
41
 
42
+ input_sentences = {}
43
  for i in range(5):
44
+ with input1:
45
  language = st.selectbox(f"Select language for sentence {i+1}:", language_options, key=f'language_{i}')
46
  with input2:
47
+ # Disable text input if language is not selected
48
  disabled = True if language == "Choose language" else False
49
  sentence = st.text_input(f"Enter sentence for {language} (with <mask>):", key=f'text_input_{i}', disabled=disabled)
50
  if not disabled and sentence:
51
  input_sentences[language.lower()] = sentence
52
 
53
  button1, button2, _ = st.columns([2, 2, 4])
54
+
55
+ with button1:
56
+ if st.button("Test Example"):
57
+ sample_sentence = {
58
+ 'zulu': "Le ndoda ithi izo <mask> ukudla.",
59
+ 'tshivenda': "Vhana vhane vha kha ḓi bva u bebwa vha kha khombo ya u <mask> nga Listeriosis.",
60
+ 'tshivenda': "Rabulasi wa <mask> u khou bvelela nga u lima",
61
+ 'tswana': "Monna o <mask> tsamaya.",
62
+ 'tsonga': "N'wana wa xisati u <mask> ku tsaka."
63
+ }
64
+ input_sentences = sample_sentence
65
+ result, warnings = fill_mask(input_sentences)
66
+
67
+ with button2:
68
+ # Set session state when "Submit" is clicked
69
+ if st.button("Submit"):
70
+ st.session_state['submit_clicked'] = True
71
+ result, warnings = fill_mask(input_sentences)
72
+ st.session_state['warnings'] = warnings
73
 
74
  if st.session_state['warnings']:
75
  for warning in st.session_state['warnings']:
 
87
  with col2:
88
  with st.container():
89
  st.markdown("Output :bar_chart:")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
+ # Ensure output only runs after "Submit" is clicked
92
+ if st.session_state['submit_clicked'] and input_sentences:
93
+ for language, sentence in input_sentences.items():
94
+ masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
95
+ predictions = unmasker(masked_sentence)
96
+
97
+ if predictions:
98
+ top_prediction = predictions[0]
99
+ predicted_word = top_prediction['token_str']
100
+ score = top_prediction['score'] * 100
101
+
102
+ st.markdown(f"""
103
+ <div class="bar">
104
+ <div class="bar-fill" style="width: {score}%;"></div>
105
+ </div>
106
+ <div class="container">
107
+ <div style="align-items: left;">{predicted_word} ({language})</div>
108
+ <div style="align-items: right;">{score:.2f}%</div>
109
+ </div>
110
+ """, unsafe_allow_html=True)
111
+
112
+ # CSS to hide footer and style the output
113
  css = """
114
  <style>
115
  footer {display:none !important;}