UnarineLeo commited on
Commit
4d5270d
1 Parent(s): b92b795

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -51
app.py CHANGED
@@ -3,72 +3,111 @@ from transformers import pipeline
3
 
4
  unmasker = pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')
5
 
6
- sample_sentences = {
7
- 'zulu': "Le ndoda ithi izo <mask> ukudla.",
8
- 'tshivenda': "Mufana uyo <mask> vhukuma.",
9
- 'sepedi': "Mosadi o <mask> pheka.",
10
- 'tswana': "Monna o <mask> tsamaya.",
11
- 'tsonga': "N'wana wa xisati u <mask> ku tsaka."
12
- }
13
 
14
- def fill_mask_for_languages(sentences):
15
  results = {}
 
16
  for language, sentence in sentences.items():
17
- masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
18
- unmasked = unmasker(masked_sentence)
19
- results[language] = unmasked
20
- return results
 
 
 
21
 
22
  def replace_mask(sentence, predicted_word):
23
  return sentence.replace("<mask>", f"**{predicted_word}**")
24
 
25
- st.title("Fill Mask| Zabantu-XLM-Roberta")
26
  st.write(f"")
27
 
 
 
28
  col1, col2 = st.columns(2)
29
 
30
- with col1:
31
- user_sentence = st.text_area("Enter your own sentence with a masked word (use '____'):", "\n".join(
32
- f"'{lang}': '{sentence}'," for lang, sentence in sample_sentences.items()
33
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- if st.button("Submit"):
36
- user_masked_sentence = user_sentence.replace('<mask>', unmasker.tokenizer.mask_token)
37
 
38
  with col2:
39
- if 'user_masked_sentence' in locals():
40
- if user_masked_sentence:
41
- user_predictions = unmasker(user_masked_sentence)
 
 
 
 
42
 
43
- # st.write(user_predictions)
44
-
45
- if len(user_predictions) > 0:
46
- # st.write(f"Top prediction for the masked token: {user_predictions[0]['sequence']}")
47
-
48
- predictions = fill_mask_for_languages(sample_sentences)
49
- for language, language_predictions in predictions.items():
50
- predicted_word = language_predictions[0]['token_str']
51
- score = language_predictions[0]['score'] * 100
52
-
53
- st.markdown(f"""
54
- <div class="bar">
55
- <div class="bar-fill" style="width: {score}%;"></div>
56
- </div>
57
- <div class="container">
58
- <div style="align-items: left;">{predicted_word}({language})</div>
59
- <div style="align-items: right;">{score:.2f}%</div>
60
- </div>
61
- """, unsafe_allow_html=True)
62
 
63
  if 'predictions' in locals():
64
- if predictions:
65
- for language, language_predictions in predictions.items():
66
- original_sentence = sample_sentences[language]
67
- predicted_sentence = replace_mask(original_sentence, language_predictions[0]['token_str'])
68
- # st.write(language_predictions)
69
- # st.write(f"Original sentence ({language}): {original_sentence}")
70
- st.write(f"{language}: {predicted_sentence}\n")
71
-
72
 
73
  css = """
74
  <style>
@@ -135,6 +174,7 @@ footer {display:none !important;}
135
  --tw-text-opacity: 1 !important;
136
  color:rgb(37 56 133 / var(--tw-text-opacity)) !important;
137
  }
 
138
  .container {
139
  display: flex;
140
  justify-content: space-between;
@@ -143,7 +183,7 @@ footer {display:none !important;}
143
  width: 100%;
144
  }
145
  .bar {
146
- width: 70%;
147
  background-color: #e6e6e6;
148
  border-radius: 12px;
149
  overflow: hidden;
@@ -155,6 +195,8 @@ footer {display:none !important;}
155
  height: 100%;
156
  border-radius: 12px;
157
  }
 
158
  </style>
159
  """
160
- st.markdown(css, unsafe_allow_html=True)
 
 
3
 
4
  unmasker = pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')
5
 
6
+ st.set_page_config(layout="wide")
 
 
 
 
 
 
7
 
8
+ def fill_mask(sentences):
9
  results = {}
10
+ warnings = []
11
  for language, sentence in sentences.items():
12
+ if "<mask>" in sentence:
13
+ masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
14
+ unmasked = unmasker(masked_sentence)
15
+ results[language] = unmasked
16
+ else:
17
+ warnings.append(f"Warning: No <mask> token found in sentence: {sentence}")
18
+ return results, warnings
19
 
20
  def replace_mask(sentence, predicted_word):
21
  return sentence.replace("<mask>", f"**{predicted_word}**")
22
 
23
+ st.title("Fill Mask | Zabantu-XLM-Roberta")
24
  st.write(f"")
25
 
26
+ st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages. These include: Zabantu-VEN, Zabantu-NSO, Zabantu-NSO+VEN, Zabantu-SOT+VEN, Zabantu-BANTU(from 9 South African Bantu languages)")
27
+
28
  col1, col2 = st.columns(2)
29
 
30
+ if 'text_input' not in st.session_state:
31
+ st.session_state['text_input'] = ""
32
+
33
+ if 'warnings' not in st.session_state:
34
+ st.session_state['warnings'] = []
35
+
36
+ with col1:
37
+ with st.container(border=True):
38
+ st.markdown("Input :clipboard:")
39
+ sample_sentence = {
40
+ 'zulu': "Le ndoda ithi izo <mask> ukudla.",
41
+ 'tshivenda': "Mufana uyo <mask> vhukuma.",
42
+ 'sepedi': "Mosadi o <mask> pheka.",
43
+ 'tswana': "Monna o <mask> tsamaya.",
44
+ 'tsonga': "N'wana wa xisati u <mask> ku tsaka."
45
+ }
46
+
47
+ text_input = st.text_area(
48
+ "Enter sentences with <mask> token:",
49
+ value=st.session_state['text_input']
50
+ )
51
+
52
+ input_sentences = text_input.split("\n")
53
+
54
+ button1, button2, _ = st.columns([2, 2, 4])
55
+ with button1:
56
+ if st.button("Test Example"):
57
+ user_sentence = f"'{lang}': '{sentence}'," for lang, sentence in sample_sentences.items()
58
+ user_masked_sentence = user_sentence.replace('<mask>', unmasker.tokenizer.mask_token)
59
+ # st.rerun()
60
+ # result, warnings = fill_mask(sample_sentence.split("\n"))
61
+ # st.session_state['text_input'] = sample_sentence
62
+
63
+ with button2:
64
+ if st.button("Submit"):
65
+ user_masked_sentence = input_sentences.replace('<mask>', unmasker.tokenizer.mask_token)
66
+ # result, warnings = fill_mask(input_sentences)
67
+ # st.session_state['warnings'] = warnings
68
+
69
+ if st.session_state['warnings']:
70
+ for warning in st.session_state['warnings']:
71
+ st.warning(warning)
72
 
73
+ st.markdown("Example")
74
+ st.code(sample_sentence, wrap_lines=True)
75
 
76
  with col2:
77
+ with st.container(border=True):
78
+ st.markdown("Output :bar_chart:")
79
+ if 'user_masked_sentence' in locals():
80
+ if user_masked_sentence:
81
+ user_predictions = unmasker(user_masked_sentence)
82
+
83
+ # st.write(user_predictions)
84
 
85
+ if len(user_predictions) > 0:
86
+ # st.write(f"Top prediction for the masked token: {user_predictions[0]['sequence']}")
87
+
88
+ predictions = fill_mask_for_languages(sample_sentences)
89
+ for language, language_predictions in predictions.items():
90
+ predicted_word = language_predictions[0]['token_str']
91
+ score = language_predictions[0]['score'] * 100
92
+
93
+ st.markdown(f"""
94
+ <div class="bar">
95
+ <div class="bar-fill" style="width: {score}%;"></div>
96
+ </div>
97
+ <div class="container">
98
+ <div style="align-items: left;">{predicted_word}({language})</div>
99
+ <div style="align-items: right;">{score:.2f}%</div>
100
+ </div>
101
+ """, unsafe_allow_html=True)
 
 
102
 
103
  if 'predictions' in locals():
104
+ if predictions:
105
+ for language, language_predictions in predictions.items():
106
+ original_sentence = sample_sentences[language]
107
+ predicted_sentence = replace_mask(original_sentence, language_predictions[0]['token_str'])
108
+ # st.write(language_predictions)
109
+ # st.write(f"Original sentence ({language}): {original_sentence}")
110
+ st.write(f"{language}: {predicted_sentence}\n")
 
111
 
112
  css = """
113
  <style>
 
174
  --tw-text-opacity: 1 !important;
175
  color:rgb(37 56 133 / var(--tw-text-opacity)) !important;
176
  }
177
+
178
  .container {
179
  display: flex;
180
  justify-content: space-between;
 
183
  width: 100%;
184
  }
185
  .bar {
186
+ # width: 70%;
187
  background-color: #e6e6e6;
188
  border-radius: 12px;
189
  overflow: hidden;
 
195
  height: 100%;
196
  border-radius: 12px;
197
  }
198
+
199
  </style>
200
  """
201
+
202
+ st.markdown(css, unsafe_allow_html=True)