UnarineLeo commited on
Commit
14d6210
1 Parent(s): ec6cd8b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -22
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import streamlit as st
2
  from transformers import pipeline
 
3
 
4
  unmasker = pipeline('fill-mask', model='dsfsi/zabantu-ven-120m')
5
 
@@ -35,26 +36,37 @@ if 'warnings' not in st.session_state:
35
  with col1:
36
  with st.container(border=True):
37
  st.markdown("Input :clipboard:")
 
 
38
  sample_sentence = "Vhana vhane vha kha ḓi bva u bebwa vha kha khombo ya u <mask> nga Listeriosis."
39
-
40
- text_input = st.text_area(
41
- "Enter sentences with <mask> token:",
42
- value=st.session_state['text_input']
43
- )
44
-
45
- input_sentences = text_input.split("\n")
46
-
47
- button1, button2, _ = st.columns([2, 2, 4])
48
- with button1:
49
- if st.button("Test Example"):
50
- # st.rerun()
51
- result, warnings = fill_mask(sample_sentence.split("\n"))
52
- # st.session_state['text_input'] = sample_sentence
53
 
54
- with button2:
55
- if st.button("Submit"):
 
56
  result, warnings = fill_mask(input_sentences)
57
  st.session_state['warnings'] = warnings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  if st.session_state['warnings']:
60
  for warning in st.session_state['warnings']:
@@ -62,12 +74,14 @@ with col1:
62
 
63
  st.markdown("Example")
64
  st.code(sample_sentence, wrap_lines=True)
 
 
65
 
66
  with col2:
67
  with st.container(border=True):
68
  st.markdown("Output :bar_chart:")
69
  if 'result' in locals() and result:
70
- if result:
71
  for sentence, predictions in result.items():
72
  for prediction in predictions:
73
  predicted_word = prediction['token_str']
@@ -83,12 +97,34 @@ with col2:
83
  </div>
84
  """, unsafe_allow_html=True)
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  if 'result' in locals():
87
- if result:
88
- for sentence, predictions in result.items():
89
- predicted_word = predictions[0]['token_str']
90
- full_sentence = replace_mask(sentence, predicted_word)
91
- st.write(f"**Sentence:** {full_sentence }")
 
 
92
 
93
  css = """
94
  <style>
 
1
  import streamlit as st
2
  from transformers import pipeline
3
+ from io import StringIO
4
 
5
  unmasker = pipeline('fill-mask', model='dsfsi/zabantu-ven-120m')
6
 
 
36
  with col1:
37
  with st.container(border=True):
38
  st.markdown("Input :clipboard:")
39
+
40
+ select_options = ['Choose option', 'Enter text input', 'Upload a file(csv/txt)']
41
  sample_sentence = "Vhana vhane vha kha ḓi bva u bebwa vha kha khombo ya u <mask> nga Listeriosis."
42
+
43
+ option_selected = st.selectbox(f"Select an input option:", select_options, index=0)
44
+
45
+ if option_selected == 'Enter text input':
46
+ text_input = st.text_area(
47
+ "Enter sentences with <mask> token:",
48
+ value=st.session_state['text_input']
49
+ )
 
 
 
 
 
 
50
 
51
+ input_sentences = text_input.split("\n")
52
+
53
+ if st.button("Submit",use_container_width=True):
54
  result, warnings = fill_mask(input_sentences)
55
  st.session_state['warnings'] = warnings
56
+
57
+ if option_selected == 'Upload a file(csv/txt)':
58
+
59
+ uploaded_file = st.file_uploader("Choose a file")
60
+ if uploaded_file is not None:
61
+
62
+ stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
63
+ string_data = stringio.read()
64
+
65
+ input_sentences = string_data.split("\n")
66
+
67
+ if st.button("Submit",use_container_width=True):
68
+ result, warnings = fill_mask(input_sentences)
69
+ st.session_state['warnings'] = warnings
70
 
71
  if st.session_state['warnings']:
72
  for warning in st.session_state['warnings']:
 
74
 
75
  st.markdown("Example")
76
  st.code(sample_sentence, wrap_lines=True)
77
+ if st.button("Test Example",use_container_width=True):
78
+ result, warnings = fill_mask(sample_sentence.split("\n"))
79
 
80
  with col2:
81
  with st.container(border=True):
82
  st.markdown("Output :bar_chart:")
83
  if 'result' in locals() and result:
84
+ if len(result) == 1:
85
  for sentence, predictions in result.items():
86
  for prediction in predictions:
87
  predicted_word = prediction['token_str']
 
97
  </div>
98
  """, unsafe_allow_html=True)
99
 
100
+ else:
101
+ index = 0
102
+ for sentence, predictions in result.items():
103
+ index += 1
104
+ if predictions:
105
+ top_prediction = predictions[0]
106
+ predicted_word = top_prediction['token_str']
107
+ score = top_prediction['score'] * 100
108
+
109
+ st.markdown(f"""
110
+ <div class="bar">
111
+ <div class="bar-fill" style="width: {score}%;"></div>
112
+ </div>
113
+ <div class="container">
114
+ <div style="align-items: left;">{predicted_word} (line {index})</div>
115
+ <div style="align-items: right;">{score:.2f}%</div>
116
+ </div>
117
+ """, unsafe_allow_html=True)
118
+
119
+
120
  if 'result' in locals():
121
+ if result:
122
+ line = 0
123
+ for sentence, predictions in result.items():
124
+ line += 1
125
+ predicted_word = predictions[0]['token_str']
126
+ full_sentence = replace_mask(sentence, predicted_word)
127
+ st.write(f"**Sentence {line}:** {full_sentence }")
128
 
129
  css = """
130
  <style>