alex6095 commited on
Commit
5ef0c66
β€’
1 Parent(s): 4e2e011

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -95,7 +95,7 @@ if name == 'Topic Classification':
95
  col2.bar_chart(chart_data)
96
 
97
  elif name == 'Date Prediction':
98
- st.markdown("## Predict Date")
99
  if text:
100
  with st.spinner('processing..'):
101
  text = RegexSubstitution(r'\([^()]+\)|[<>\'"β–³β–²β–‘β– ]')(text)
@@ -104,13 +104,13 @@ elif name == 'Date Prediction':
104
  raw_input_ids + [tokenizer.eos_token_id]
105
  outputs = model.generate(torch.tensor([input_ids]),
106
  early_stopping=True,
107
- repetition_penalty=2.0,
108
  do_sample=True, #μƒ˜ν”Œλ§ μ „λž΅ μ‚¬μš©
109
  max_length=50, # μ΅œλŒ€ λ””μ½”λ”© κΈΈμ΄λŠ” 50
110
  top_k=50, # ν™•λ₯  μˆœμœ„κ°€ 50μœ„ 밖인 토큰은 μƒ˜ν”Œλ§μ—μ„œ μ œμ™Έ
111
  top_p=0.95, # λˆ„μ  ν™•λ₯ μ΄ 95%인 ν›„λ³΄μ§‘ν•©μ—μ„œλ§Œ 생성
112
  num_return_sequences=3 #3개의 κ²°κ³Όλ₯Ό λ””μ½”λ”©ν•΄λ‚Έλ‹€
113
  )
 
114
  for output in outputs:
115
- pred_print = tokenizer.decode(output.squeeze().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True)
116
- st.write(pred_print)
 
95
  col2.bar_chart(chart_data)
96
 
97
  elif name == 'Date Prediction':
98
+ st.markdown("## Predict 3 possible Date")
99
  if text:
100
  with st.spinner('processing..'):
101
  text = RegexSubstitution(r'\([^()]+\)|[<>\'"β–³β–²β–‘β– ]')(text)
 
104
  raw_input_ids + [tokenizer.eos_token_id]
105
  outputs = model.generate(torch.tensor([input_ids]),
106
  early_stopping=True,
 
107
  do_sample=True, #μƒ˜ν”Œλ§ μ „λž΅ μ‚¬μš©
108
  max_length=50, # μ΅œλŒ€ λ””μ½”λ”© κΈΈμ΄λŠ” 50
109
  top_k=50, # ν™•λ₯  μˆœμœ„κ°€ 50μœ„ 밖인 토큰은 μƒ˜ν”Œλ§μ—μ„œ μ œμ™Έ
110
  top_p=0.95, # λˆ„μ  ν™•λ₯ μ΄ 95%인 ν›„λ³΄μ§‘ν•©μ—μ„œλ§Œ 생성
111
  num_return_sequences=3 #3개의 κ²°κ³Όλ₯Ό λ””μ½”λ”©ν•΄λ‚Έλ‹€
112
  )
113
+ pred_print = []
114
  for output in outputs:
115
+ pred_print.append(tokenizer.decode(output.squeeze().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True))
116
+ st.write(", ".join(pred_print))