UnarineLeo committed on
Commit
cff14a8
1 Parent(s): 6cbe2dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +206 -80
app.py CHANGED
@@ -3,7 +3,7 @@ from transformers import pipeline
3
  from io import StringIO
4
 
5
  unmasker = pipeline('fill-mask', model='dsfsi/zabantu-bantu-250m')
6
- st.stop_rerun = True
7
  st.set_page_config(layout="wide")
8
 
9
  def fill_mask(sentences):
@@ -21,101 +21,227 @@ def fill_mask(sentences):
21
  def replace_mask(sentence, predicted_word):
22
  return sentence.replace("<mask>", f"**{predicted_word}**")
23
 
24
- # Set up title and description
25
  st.title("Fill Mask | Zabantu-XLM-Roberta")
26
- st.markdown("Zabantu-XLMR refers to a fleet of models trained on South African Bantu languages...")
27
 
28
- # Initialize session state
29
- if 'warnings' not in st.session_state:
30
- st.session_state['warnings'] = []
31
- if 'results' not in st.session_state:
32
- st.session_state['results'] = {}
33
 
34
- # Define layout
35
  col1, col2 = st.columns(2)
36
 
37
- with col1:
38
- st.markdown("### Input :clipboard:")
39
 
40
- select_options = ['Choose option', 'Enter text input', 'Upload a file(csv/txt)']
41
- sample_sentence = {
42
- 'tshivenda': "Rabulasi wa <mask> u khou bvelela nga u lima.",
43
- "tsonga": "N'wana wa xisati u <mask> ku tsaka."
44
- }
45
 
46
- option_selected = st.selectbox("Select an input option:", select_options, index=0)
47
- input_sentences = {}
 
48
 
49
- if option_selected == 'Enter text input':
50
- st.session_state['warnings'].clear() # Clear warnings before new input
 
 
51
  language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']
52
 
53
- for i in range(5):
54
- language = st.selectbox(f"Select language for input {i+1}:", language_options, key=f'language_{i}')
55
- sentence = st.text_input(f"Enter sentence for input {i+1} (with <mask>):", key=f'sentence_{i}')
 
 
 
 
 
 
56
 
57
- # Only process filled language and sentence pairs
58
- if language != 'Choose language' and sentence:
59
- input_sentences[language.lower()] = sentence
60
-
61
- if st.button("Submit"):
62
- if input_sentences:
63
- results, warnings = fill_mask(input_sentences)
64
- st.session_state['results'] = results
65
- st.session_state['warnings'] = warnings
66
- else:
67
- st.warning("Please fill at least one language and sentence.")
68
-
69
- elif option_selected == 'Upload a file(csv/txt)':
70
- uploaded_file = st.file_uploader("Choose a file (one sentence per line)")
71
- if uploaded_file:
72
- stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
73
- sentences = stringio.read().splitlines()
74
-
75
- for i, sentence in enumerate(sentences):
76
- # Here, you might need to define how to assign a language to each sentence
77
- # Assuming all sentences are in the same language for simplicity
78
- input_sentences[f'input_{i+1}'] = sentence
79
-
80
- if st.button("Submit"):
81
- results, warnings = fill_mask(input_sentences)
82
- st.session_state['results'] = results
83
- st.session_state['warnings'] = warnings
84
-
85
- st.markdown("### Example")
86
- st.code(sample_sentence)
87
- if st.button("Test Example"):
88
- result, warnings = fill_mask(sample_sentence)
89
- st.session_state['results'] = result
90
- st.session_state['warnings'] = warnings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  with col2:
93
- st.markdown("### Output :bar_chart:")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
- if st.session_state['results']:
96
- # Use st.fragment for dynamic content
97
- with st.container():
98
- for language, predictions in st.session_state['results'].items():
99
- if predictions:
100
- top_prediction = predictions[0]
101
- predicted_word = top_prediction['token_str']
102
- score = top_prediction['score'] * 100
103
-
104
- # Displaying the prediction with fragment
105
- st.markdown(f"**{language.capitalize()} Prediction:** {predicted_word} ({score:.2f}%)")
106
- st.markdown(f"<div class='bar'><div class='bar-fill' style='width:{score}%;'></div></div>", unsafe_allow_html=True)
107
-
108
- if st.session_state['warnings']:
109
- for warning in st.session_state['warnings']:
110
- st.warning(warning)
111
-
112
- # CSS for styling
 
 
113
  css = """
114
  <style>
115
  footer {display:none !important;}
116
- .bar {width: 70%; background-color: #e6e6e6; border-radius: 12px; height: 5px;}
117
- .bar-fill {background-color: #17152e; height: 100%; border-radius: 12px;}
118
- .container {display: flex; justify-content: space-between; align-items: center; margin-bottom: 5px;}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  </style>
120
  """
121
- st.markdown(css, unsafe_allow_html=True)
 
 
3
  from io import StringIO
4
 
5
  unmasker = pipeline('fill-mask', model='dsfsi/zabantu-bantu-250m')
6
+
7
  st.set_page_config(layout="wide")
8
 
9
  def fill_mask(sentences):
 
21
  def replace_mask(sentence, predicted_word):
22
  return sentence.replace("<mask>", f"**{predicted_word}**")
23
 
 
24
  st.title("Fill Mask | Zabantu-XLM-Roberta")
25
+ st.write(f"")
26
 
27
+ st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages. It supports the following languages Tshivenda, Nguni languages (Zulu, Xhosa, Swati), Sotho languages (Northern Sotho, Southern Sotho, Setswana), and Xitsonga.")
 
 
 
 
28
 
 
29
  col1, col2 = st.columns(2)
30
 
31
+ if 'text_input' not in st.session_state:
32
+ st.session_state['text_input'] = ""
33
 
34
+ if 'warnings' not in st.session_state:
35
+ st.session_state['warnings'] = []
 
 
 
36
 
37
+ with col1:
38
+ with st.container(border=True):
39
+ st.markdown("Input :clipboard:")
40
 
41
+ select_options = ['Choose option', 'Enter text input', 'Upload a file(csv/txt)']
42
+ sample_sentence = {'tshivenda': "Rabulasi wa <mask> u khou bvelela nga u lima.",
43
+ "tsonga": "N'wana wa xisati u <mask> ku tsaka."
44
+ }
45
  language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']
46
 
47
+ option_selected = st.selectbox(f"Select an input option:", select_options, index=0)
48
+ input_sentences = {}
49
+
50
if option_selected == 'Enter text input':
    # Warnings are rebuilt from scratch on every run of this branch.
    st.session_state['warnings'].clear()

    @st.fragment
    def choose_language(i):
        """Render the language selector for input row ``i`` and return the choice.

        Returns the selected language name, or ``None`` while the placeholder
        entry (the first item of ``language_options``) is still selected, so the
        caller can test the result for truthiness.

        NOTE(review): Streamlit documents that fragment return values are only
        observed on full script runs, not on fragment-only reruns; the Submit
        button triggers a full run, which is when the value is consumed.
        """
        language = st.selectbox(f"Select language for input {i+1}:",
                                language_options, key=f'language_{i}', index=0)
        if language == language_options[0]:
            return None
        return language

    input1, input2 = st.columns(2)
    for i in range(5):
        with input1:
            language = choose_language(i)
            st.write(f"lang : {language}")
        with input2:
            sentence = st.text_input(f"Enter sentence for input {i+1} (with <mask>):", key=f'text_input_{i}')
        if not sentence:
            # Empty rows are simply ignored.
            continue
        if language:
            # NOTE(review): values are (language, sentence) tuples here, whereas
            # sample_sentence maps to plain strings -- confirm what fill_mask expects.
            input_sentences[f'{language.lower()}_{i+1}'] = (language.lower(), sentence)
        else:
            # Accumulate one warning per incomplete row instead of keeping only the last.
            st.session_state['warnings'].append(f"Warning: Choose the language for input {i+1}")

    if st.button("Submit", use_container_width=True):
        # Pending input warnings are shown by the block below; only query the
        # model when every filled-in row also has a language.
        if not st.session_state['warnings']:
            result, warnings = fill_mask(input_sentences)
            st.session_state['warnings'] = warnings

    if st.session_state['warnings']:
        for warning in st.session_state['warnings']:
            st.warning(warning)
82
+
83
if option_selected == 'Upload a file(csv/txt)':
    # Drop warnings left over from an earlier run, as the text-input branch does;
    # otherwise a single warning from fill_mask would block every later Submit.
    st.session_state['warnings'].clear()

    uploaded_file = st.file_uploader("Choose a file-(one sentence per line)")
    if uploaded_file is not None:
        string_data = uploaded_file.getvalue().decode("utf-8")
        # splitlines() copes with both "\n" and "\r\n" endings; blank lines are
        # dropped because they carry no sentence to unmask.
        # NOTE(review): this passes a list while the other callers pass dicts --
        # confirm what fill_mask expects.
        input_sentences = [line.strip() for line in string_data.splitlines() if line.strip()]

    if st.button("Submit", use_container_width=True):
        result, warnings = fill_mask(input_sentences)
        st.session_state['warnings'] = warnings

    # Single place where warnings are rendered, so they are not shown twice.
    if st.session_state['warnings']:
        for warning in st.session_state['warnings']:
            st.warning(warning)
104
+
105
+ st.markdown("Example")
106
+ st.code(sample_sentence, wrap_lines=True)
107
+ if st.button("Test Example",use_container_width=True):
108
+ result, warnings = fill_mask(sample_sentence)
109
 
110
  with col2:
111
+ with st.container(border=True):
112
+ st.markdown("Output :bar_chart:")
113
+ if 'result' in locals() and result:
114
+ if len(result) == 1:
115
+ for language, predictions in result.items():
116
+ for prediction in predictions:
117
+ predicted_word = prediction['token_str']
118
+ score = prediction['score'] * 100
119
+
120
+ st.markdown(f"""
121
+ <div class="bar">
122
+ <div class="bar-fill" style="width: {score}%;"></div>
123
+ </div>
124
+ <div class="container">
125
+ <div style="align-items: left;">{predicted_word}</div>
126
+ <div style="align-items: center;">{score:.2f}%</div>
127
+ </div>
128
+ """, unsafe_allow_html=True)
129
+
130
+ else:
131
+ for language, predictions in result.items():
132
+ if predictions:
133
+ top_prediction = predictions[0]
134
+ predicted_word = top_prediction['token_str']
135
+ score = top_prediction['score'] * 100
136
 
137
+ st.markdown(f"""
138
+ <div class="bar">
139
+ <div class="bar-fill" style="width: {score}%;"></div>
140
+ </div>
141
+ <div class="container">
142
+ <div style="align-items: left;">{predicted_word} ({language})</div>
143
+ <div style="align-items: right;">{score:.2f}%</div>
144
+ </div>
145
+ """, unsafe_allow_html=True)
146
+
147
+
148
# `result` is only bound on a run in which one of the button handlers called
# fill_mask, hence the existence check before using it.
if 'result' in locals() and result:
    for line, (sentence, predictions) in enumerate(result.items(), start=1):
        if not predictions:
            # No candidates for this entry; skip it rather than fail on predictions[0].
            continue
        # Show the sentence with the top-ranked candidate substituted for the mask.
        predicted_word = predictions[0]['token_str']
        full_sentence = replace_mask(sentence, predicted_word)
        st.write(f"**Sentence {line}:** {full_sentence}")
156
+
157
  css = """
158
  <style>
159
  footer {display:none !important;}
160
+
161
+ .gr-button-primary {
162
+ z-index: 14;
163
+ height: 43px;
164
+ width: 130px;
165
+ left: 0px;
166
+ top: 0px;
167
+ padding: 0px;
168
+ cursor: pointer !important;
169
+ background: none rgb(17, 20, 45) !important;
170
+ border: none !important;
171
+ text-align: center !important;
172
+ font-family: Poppins !important;
173
+ font-size: 14px !important;
174
+ font-weight: 500 !important;
175
+ color: rgb(255, 255, 255) !important;
176
+ line-height: 1 !important;
177
+ border-radius: 12px !important;
178
+ transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
179
+ box-shadow: none !important;
180
+ }
181
+ .gr-button-primary:hover{
182
+ z-index: 14;
183
+ height: 43px;
184
+ width: 130px;
185
+ left: 0px;
186
+ top: 0px;
187
+ padding: 0px;
188
+ cursor: pointer !important;
189
+ background: none rgb(66, 133, 244) !important;
190
+ border: none !important;
191
+ text-align: center !important;
192
+ font-family: Poppins !important;
193
+ font-size: 14px !important;
194
+ font-weight: 500 !important;
195
+ color: rgb(255, 255, 255) !important;
196
+ line-height: 1 !important;
197
+ border-radius: 12px !important;
198
+ transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
199
+ box-shadow: rgb(0 0 0 / 23%) 0px 1px 7px 0px !important;
200
+ }
201
+ .hover\:bg-orange-50:hover {
202
+ --tw-bg-opacity: 1 !important;
203
+ background-color: rgb(229,225,255) !important;
204
+ }
205
+ .to-orange-200 {
206
+ --tw-gradient-to: rgb(37 56 133 / 37%) !important;
207
+ }
208
+ .from-orange-400 {
209
+ --tw-gradient-from: rgb(17, 20, 45) !important;
210
+ --tw-gradient-to: rgb(255 150 51 / 0);
211
+ --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to) !important;
212
+ }
213
+ .group-hover\:from-orange-500{
214
+ --tw-gradient-from:rgb(17, 20, 45) !important;
215
+ --tw-gradient-to: rgb(37 56 133 / 37%);
216
+ --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to) !important;
217
+ }
218
+ .group:hover .group-hover\:text-orange-500{
219
+ --tw-text-opacity: 1 !important;
220
+ color:rgb(37 56 133 / var(--tw-text-opacity)) !important;
221
+ }
222
+
223
+ .container {
224
+ display: flex;
225
+ justify-content: space-between;
226
+ align-items: center;
227
+ margin-bottom: 5px;
228
+ width: 100%;
229
+ }
230
+ .bar {
231
+ # width: 70%;
232
+ background-color: #e6e6e6;
233
+ border-radius: 12px;
234
+ overflow: hidden;
235
+ margin-right: 10px;
236
+ height: 5px;
237
+ }
238
+ .bar-fill {
239
+ background-color: #17152e;
240
+ height: 100%;
241
+ border-radius: 12px;
242
+ }
243
+
244
  </style>
245
  """
246
+
247
+ st.markdown(css, unsafe_allow_html=True)