UnarineLeo committed on
Commit
e28273f
1 Parent(s): 27c53ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -31
app.py CHANGED
@@ -21,8 +21,6 @@ def replace_mask(sentence, predicted_word):
21
  return sentence.replace("<mask>", f"**{predicted_word}**")
22
 
23
  st.title("Fill Mask | Zabantu-XLM-Roberta")
24
- st.write("")
25
-
26
  st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages.")
27
 
28
  col1, col2 = st.columns(2)
@@ -33,56 +31,63 @@ if 'text_input' not in st.session_state:
33
  if 'warnings' not in st.session_state:
34
  st.session_state['warnings'] = []
35
 
 
 
36
  with col1:
37
- with st.container(border=True):
38
- st.markdown("Input :clipboard:")
39
- sample_sentence = {
40
- 'zulu': "Le ndoda ithi izo <mask> ukudla.",
41
- 'tshivenda': "Mufana uyo <mask> vhukuma.",
42
- 'sepedi': "Mosadi o <mask> pheka.",
43
- 'tswana': "Monna o <mask> tsamaya.",
44
- 'tsonga': "N'wana wa xisati u <mask> ku tsaka."
45
- }
46
-
47
- text_input = st.text_area(
48
- "Enter sentences with <mask> token:",
49
- value=st.session_state['text_input']
50
- )
51
-
52
- input_sentences = {f"sentence_{i}": sentence for i, sentence in enumerate(text_input.split("\n")) if sentence}
53
 
54
  button1, button2, _ = st.columns([2, 2, 4])
55
  with button1:
56
  if st.button("Test Example"):
 
 
 
 
 
 
 
57
  input_sentences = sample_sentence
58
- result, warnings = fill_mask(input_sentences)
59
- # st.session_state['text_input'] = "\n".join([f"{lang}: {sentence}" for lang, sentence in sample_sentence.items()])
60
 
61
  with button2:
62
  if st.button("Submit"):
63
  result, warnings = fill_mask(input_sentences)
64
- st.session_state['warnings'] = warnings
65
-
66
  if st.session_state['warnings']:
67
  for warning in st.session_state['warnings']:
68
  st.warning(warning)
69
 
70
- st.markdown("Example")
71
- st.code(sample_sentence, wrap_lines=True)
 
 
 
 
 
 
72
 
73
  with col2:
74
- with st.container(border=True):
75
- st.markdown("Output :bar_chart:")
76
  if input_sentences:
77
  for language, sentence in input_sentences.items():
78
  masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
79
  predictions = unmasker(masked_sentence)
80
-
81
  if predictions:
82
  top_prediction = predictions[0]
83
  predicted_word = top_prediction['token_str']
84
  score = top_prediction['score'] * 100
85
-
86
  st.markdown(f"""
87
  <div class="bar">
88
  <div class="bar-fill" style="width: {score}%;"></div>
@@ -96,8 +101,8 @@ with col2:
96
  if 'predictions' in locals():
97
  if result:
98
  for language, language_predictions in result.items():
99
- original_sentence = sample_sentence[language]
100
- predicted_sentence = replace_mask(original_sentence, language_predictions[0]['token_str'])
101
  st.write(f"{language}: {predicted_sentence}\n")
102
 
103
  css = """
@@ -125,5 +130,4 @@ footer {display:none !important;}
125
  }
126
  </style>
127
  """
128
-
129
  st.markdown(css, unsafe_allow_html=True)
 
21
  return sentence.replace("<mask>", f"**{predicted_word}**")
22
 
23
  st.title("Fill Mask | Zabantu-XLM-Roberta")
 
 
24
  st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages.")
25
 
26
  col1, col2 = st.columns(2)
 
31
  if 'warnings' not in st.session_state:
32
  st.session_state['warnings'] = []
33
 
34
+ language_options = ['Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']
35
+
36
  with col1:
37
+ with st.container():
38
+ st.markdown("### Input :clipboard:")
39
+
40
+ input_sentences = {}
41
+ for i in range(5):
42
+ language = st.selectbox(f"Select language for sentence {i+1}:", language_options, key=f'language_{i}')
43
+ sentence = st.text_area(f"Enter sentence for {language} (with <mask>):", key=f'text_input_{i}')
44
+ if sentence:
45
+ input_sentences[language.lower()] = sentence
 
 
 
 
 
 
 
46
 
47
  button1, button2, _ = st.columns([2, 2, 4])
48
  with button1:
49
  if st.button("Test Example"):
50
+ sample_sentence = {
51
+ 'zulu': "Le ndoda ithi izo <mask> ukudla.",
52
+ 'tshivenda': "Mufana uyo <mask> vhukuma.",
53
+ 'sepedi': "Mosadi o <mask> pheka.",
54
+ 'tswana': "Monna o <mask> tsamaya.",
55
+ 'tsonga': "N'wana wa xisati u <mask> ku tsaka."
56
+ }
57
  input_sentences = sample_sentence
58
+ result, warnings = fill_mask(input_sentences)
 
59
 
60
  with button2:
61
  if st.button("Submit"):
62
  result, warnings = fill_mask(input_sentences)
63
+ st.session_state['warnings'] = warnings
64
+
65
  if st.session_state['warnings']:
66
  for warning in st.session_state['warnings']:
67
  st.warning(warning)
68
 
69
+ st.markdown("### Example")
70
+ st.code({
71
+ 'zulu': "Le ndoda ithi izo <mask> ukudla.",
72
+ 'tshivenda': "Mufana uyo <mask> vhukuma.",
73
+ 'sepedi': "Mosadi o <mask> pheka.",
74
+ 'tswana': "Monna o <mask> tsamaya.",
75
+ 'tsonga': "N'wana wa xisati u <mask> ku tsaka."
76
+ }, wrap_lines=True)
77
 
78
  with col2:
79
+ with st.container():
80
+ st.markdown("### Output :bar_chart:")
81
  if input_sentences:
82
  for language, sentence in input_sentences.items():
83
  masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
84
  predictions = unmasker(masked_sentence)
85
+
86
  if predictions:
87
  top_prediction = predictions[0]
88
  predicted_word = top_prediction['token_str']
89
  score = top_prediction['score'] * 100
90
+
91
  st.markdown(f"""
92
  <div class="bar">
93
  <div class="bar-fill" style="width: {score}%;"></div>
 
101
  if 'predictions' in locals():
102
  if result:
103
  for language, language_predictions in result.items():
104
+ original_sentence = input_sentences[language]
105
+ predicted_sentence = replace_mask(original_sentence, language_predictions[0]['token_str'])
106
  st.write(f"{language}: {predicted_sentence}\n")
107
 
108
  css = """
 
130
  }
131
  </style>
132
  """
 
133
  st.markdown(css, unsafe_allow_html=True)