UnarineLeo commited on
Commit
15eff6a
1 Parent(s): 91470f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +151 -70
app.py CHANGED
@@ -1,20 +1,18 @@
1
  import streamlit as st
2
  from transformers import pipeline
 
3
 
4
  unmasker = pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')
5
 
6
  st.set_page_config(layout="wide")
7
 
8
- st.stop_rerun = True
9
-
10
  def fill_mask(sentences):
11
  results = {}
12
  warnings = []
13
- for key, (language, sentence) in sentences.items():
14
  if "<mask>" in sentence:
15
- masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
16
- unmasked = unmasker(masked_sentence)
17
- results[key] = (language, unmasked)
18
  else:
19
  warnings.append(f"Warning: No <mask> token found in sentence: {sentence}")
20
  return results, warnings
@@ -23,7 +21,9 @@ def replace_mask(sentence, predicted_word):
23
  return sentence.replace("<mask>", f"**{predicted_word}**")
24
 
25
  st.title("Fill Mask | Zabantu-XLM-Roberta")
26
- st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages.")
 
 
27
 
28
  col1, col2 = st.columns(2)
29
 
@@ -33,87 +33,165 @@ if 'text_input' not in st.session_state:
33
  if 'warnings' not in st.session_state:
34
  st.session_state['warnings'] = []
35
 
36
- if 'result' not in st.session_state:
37
- st.session_state['result'] = {}
38
-
39
- language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']
40
-
41
- input_sentences = {}
42
-
43
  with col1:
44
- with st.container():
45
  st.markdown("Input :clipboard:")
46
 
47
- input1, input2 = st.columns(2)
48
-
49
- for i in range(5):
50
- with input1:
51
- language = st.selectbox(f"Select language for sentence {i+1}:", language_options, key=f'language_{i}', index=0)
52
- with input2:
53
- sentence = st.text_input(f"Enter sentence for {language} (with <mask>):", key=f'text_input_{i}')
54
- if sentence:
55
- input_sentences[f'{language.lower()}_{i+1}'] = (language.lower(), sentence)
56
-
57
- button1, button2, _ = st.columns([2, 2, 4])
58
-
59
- if st.button("Test Example"):
60
- sample_sentences = {
61
- 'zulu': "Le ndoda ithi izo <mask> ukudla.",
62
- 'tshivenda': "Mufana uyo <mask> vhukuma.",
63
- 'sepedi': "Mosadi o <mask> pheka.",
64
- 'tswana': "Monna o <mask> tsamaya.",
65
- 'tsonga': "N'wana wa xisati u <mask> ku tsaka."
66
- }
67
- st.session_state['result'], st.session_state['warnings'] = fill_mask(sample_sentences)
68
-
69
- if st.button("Submit"):
70
- st.session_state['result'], st.session_state['warnings'] = fill_mask(input_sentences)
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  if st.session_state['warnings']:
73
  for warning in st.session_state['warnings']:
74
  st.warning(warning)
75
 
76
  st.markdown("Example")
77
- st.code({
78
- 'zulu': "Le ndoda ithi izo <mask> ukudla.",
79
- 'tshivenda': "Mufana uyo <mask> vhukuma.",
80
- 'sepedi': "Mosadi o <mask> pheka.",
81
- 'tswana': "Monna o <mask> tsamaya.",
82
- 'tsonga': "N'wana wa xisati u <mask> ku tsaka."
83
- }, wrap_lines=True)
84
 
85
  with col2:
86
- with st.container():
87
  st.markdown("Output :bar_chart:")
88
- if st.session_state['result']:
89
- for key, (language, predictions) in st.session_state['result'].items():
90
- original_sentence = input_sentences[key][1] if key in input_sentences else ""
91
- if predictions:
92
- top_prediction = predictions[0]
93
- predicted_word = top_prediction['token_str']
94
- score = top_prediction['score'] * 100
95
-
96
- st.markdown(f"""
97
- <div class="bar">
98
- <div class="bar-fill" style="width: {score}%;"></div>
99
- </div>
100
- <div class="container">
101
- <div style="align-items: left;">{predicted_word} ({language})</div>
102
- <div style="align-items: right;">{score:.2f}%</div>
103
- </div>
104
- """, unsafe_allow_html=True)
105
-
106
- if 'predictions' in locals():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  if result:
108
- for language, language_predictions in result.items():
109
- original_sentence = sample_sentence[language]
110
- predicted_sentence = replace_mask(original_sentence, language_predictions[0]['token_str'])
111
- st.write(f"{language}: {predicted_sentence}\n")
 
 
112
 
113
  css = """
114
  <style>
115
  footer {display:none !important;}
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  .container {
118
  display: flex;
119
  justify-content: space-between;
@@ -122,6 +200,7 @@ footer {display:none !important;}
122
  width: 100%;
123
  }
124
  .bar {
 
125
  background-color: #e6e6e6;
126
  border-radius: 12px;
127
  overflow: hidden;
@@ -133,6 +212,8 @@ footer {display:none !important;}
133
  height: 100%;
134
  border-radius: 12px;
135
  }
 
136
  </style>
137
  """
 
138
  st.markdown(css, unsafe_allow_html=True)
 
1
  import streamlit as st
2
  from transformers import pipeline
3
+ from io import StringIO
4
 
5
  unmasker = pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')
6
 
7
  st.set_page_config(layout="wide")
8
 
 
 
9
  def fill_mask(sentences):
10
  results = {}
11
  warnings = []
12
+ for sentence in sentences:
13
  if "<mask>" in sentence:
14
+ unmasked = unmasker(sentence)
15
+ results[sentence] = unmasked
 
16
  else:
17
  warnings.append(f"Warning: No <mask> token found in sentence: {sentence}")
18
  return results, warnings
 
21
  return sentence.replace("<mask>", f"**{predicted_word}**")
22
 
23
  st.title("Fill Mask | Zabantu-XLM-Roberta")
24
+ st.write(f"")
25
+
26
+ st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages. It supports the following languages Tshivenda, Nguni languages (Zulu, Xhosa, Swati), Sotho languages (Northern Sotho, Southern Sotho, Setswana), and Xitsonga.")
27
 
28
  col1, col2 = st.columns(2)
29
 
 
33
  if 'warnings' not in st.session_state:
34
  st.session_state['warnings'] = []
35
 
 
 
 
 
 
 
 
36
  with col1:
37
+ with st.container(border=True):
38
  st.markdown("Input :clipboard:")
39
 
40
+ select_options = ['Choose option', 'Enter text input', 'Upload a file(csv/txt)']
41
+ sample_sentence = "Rabulasi wa <mask> u khou bvelela nga u lima."
42
+
43
+ option_selected = st.selectbox(f"Select an input option:", select_options, index=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
+ if option_selected == 'Enter text input':
46
+ text_input = st.text_area(
47
+ "Enter sentences with <mask> token(one sentence per line):",
48
+ value=st.session_state['text_input']
49
+ )
50
+
51
+ input_sentences = text_input.split("\n")
52
+
53
+ if st.button("Submit",use_container_width=True):
54
+ result, warnings = fill_mask(input_sentences)
55
+ st.session_state['warnings'] = warnings
56
+
57
+ if option_selected == 'Upload a file(csv/txt)':
58
+
59
+ uploaded_file = st.file_uploader("Choose a file-(one sentence per line)")
60
+ if uploaded_file is not None:
61
+
62
+ stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
63
+ string_data = stringio.read()
64
+
65
+ input_sentences = string_data.split("\n")
66
+
67
+ if st.button("Submit",use_container_width=True):
68
+ result, warnings = fill_mask(input_sentences)
69
+ st.session_state['warnings'] = warnings
70
+
71
  if st.session_state['warnings']:
72
  for warning in st.session_state['warnings']:
73
  st.warning(warning)
74
 
75
  st.markdown("Example")
76
+ st.code(sample_sentence, wrap_lines=True)
77
+ if st.button("Test Example",use_container_width=True):
78
+ result, warnings = fill_mask(sample_sentence.split("\n"))
 
 
 
 
79
 
80
  with col2:
81
+ with st.container(border=True):
82
  st.markdown("Output :bar_chart:")
83
+ if 'result' in locals() and result:
84
+ if len(result) == 1:
85
+ for sentence, predictions in result.items():
86
+ for prediction in predictions:
87
+ predicted_word = prediction['token_str']
88
+ score = prediction['score'] * 100
89
+
90
+ st.markdown(f"""
91
+ <div class="bar">
92
+ <div class="bar-fill" style="width: {score}%;"></div>
93
+ </div>
94
+ <div class="container">
95
+ <div style="align-items: left;">{predicted_word}</div>
96
+ <div style="align-items: center;">{score:.2f}%</div>
97
+ </div>
98
+ """, unsafe_allow_html=True)
99
+
100
+ else:
101
+ index = 0
102
+ for sentence, predictions in result.items():
103
+ index += 1
104
+ if predictions:
105
+ top_prediction = predictions[0]
106
+ predicted_word = top_prediction['token_str']
107
+ score = top_prediction['score'] * 100
108
+
109
+ st.markdown(f"""
110
+ <div class="bar">
111
+ <div class="bar-fill" style="width: {score}%;"></div>
112
+ </div>
113
+ <div class="container">
114
+ <div style="align-items: left;">{predicted_word} (line {index})</div>
115
+ <div style="align-items: right;">{score:.2f}%</div>
116
+ </div>
117
+ """, unsafe_allow_html=True)
118
+
119
+
120
+ if 'result' in locals():
121
  if result:
122
+ line = 0
123
+ for sentence, predictions in result.items():
124
+ line += 1
125
+ predicted_word = predictions[0]['token_str']
126
+ full_sentence = replace_mask(sentence, predicted_word)
127
+ st.write(f"**Sentence {line}:** {full_sentence }")
128
 
129
  css = """
130
  <style>
131
  footer {display:none !important;}
132
 
133
+ .gr-button-primary {
134
+ z-index: 14;
135
+ height: 43px;
136
+ width: 130px;
137
+ left: 0px;
138
+ top: 0px;
139
+ padding: 0px;
140
+ cursor: pointer !important;
141
+ background: none rgb(17, 20, 45) !important;
142
+ border: none !important;
143
+ text-align: center !important;
144
+ font-family: Poppins !important;
145
+ font-size: 14px !important;
146
+ font-weight: 500 !important;
147
+ color: rgb(255, 255, 255) !important;
148
+ line-height: 1 !important;
149
+ border-radius: 12px !important;
150
+ transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
151
+ box-shadow: none !important;
152
+ }
153
+ .gr-button-primary:hover{
154
+ z-index: 14;
155
+ height: 43px;
156
+ width: 130px;
157
+ left: 0px;
158
+ top: 0px;
159
+ padding: 0px;
160
+ cursor: pointer !important;
161
+ background: none rgb(66, 133, 244) !important;
162
+ border: none !important;
163
+ text-align: center !important;
164
+ font-family: Poppins !important;
165
+ font-size: 14px !important;
166
+ font-weight: 500 !important;
167
+ color: rgb(255, 255, 255) !important;
168
+ line-height: 1 !important;
169
+ border-radius: 12px !important;
170
+ transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
171
+ box-shadow: rgb(0 0 0 / 23%) 0px 1px 7px 0px !important;
172
+ }
173
+ .hover\:bg-orange-50:hover {
174
+ --tw-bg-opacity: 1 !important;
175
+ background-color: rgb(229,225,255) !important;
176
+ }
177
+ .to-orange-200 {
178
+ --tw-gradient-to: rgb(37 56 133 / 37%) !important;
179
+ }
180
+ .from-orange-400 {
181
+ --tw-gradient-from: rgb(17, 20, 45) !important;
182
+ --tw-gradient-to: rgb(255 150 51 / 0);
183
+ --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to) !important;
184
+ }
185
+ .group-hover\:from-orange-500{
186
+ --tw-gradient-from:rgb(17, 20, 45) !important;
187
+ --tw-gradient-to: rgb(37 56 133 / 37%);
188
+ --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to) !important;
189
+ }
190
+ .group:hover .group-hover\:text-orange-500{
191
+ --tw-text-opacity: 1 !important;
192
+ color:rgb(37 56 133 / var(--tw-text-opacity)) !important;
193
+ }
194
+
195
  .container {
196
  display: flex;
197
  justify-content: space-between;
 
200
  width: 100%;
201
  }
202
  .bar {
203
+ # width: 70%;
204
  background-color: #e6e6e6;
205
  border-radius: 12px;
206
  overflow: hidden;
 
212
  height: 100%;
213
  border-radius: 12px;
214
  }
215
+
216
  </style>
217
  """
218
+
219
  st.markdown(css, unsafe_allow_html=True)