sagawa commited on
Commit
cf8038a
·
verified ·
1 Parent(s): 21bf012

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -36,12 +36,11 @@ st.download_button(
36
  class CFG():
37
  uploaded_file = st.file_uploader("Choose a CSV file")
38
  data = st.text_area(display_text)
39
- pretrained_model_name_or_path = 'sagawa/ZINC-t5'
40
  model = 't5'
41
- model_name_or_path = './'
42
- max_len = 512
43
  batch_size = 5
44
- fc_dropout = 0.1
45
  seed = 42
46
  num_workers=1
47
 
@@ -206,7 +205,7 @@ if st.button('predict'):
206
  if CFG.uploaded_file is not None:
207
  test_ds = pd.read_csv(CFG.data)
208
  if "input" not in test_ds.columns:
209
- test_ds = preprocess(test_ds, CFG)
210
  else:
211
  test_ds = pd.DataFrame.from_dict({"input": [CFG.data]}, orient="index").T
212
 
 
36
  class CFG():
37
  uploaded_file = st.file_uploader("Choose a CSV file")
38
  data = st.text_area(display_text)
 
39
  model = 't5'
40
+ model_name_or_path = 'sagawa/ReactionT5v2-yield'
41
+ max_len = 400
42
  batch_size = 5
43
+ fc_dropout = 0.0
44
  seed = 42
45
  num_workers=1
46
 
 
205
  if CFG.uploaded_file is not None:
206
  test_ds = pd.read_csv(CFG.data)
207
  if "input" not in test_ds.columns:
208
+ test_ds = preprocess(test_ds)
209
  else:
210
  test_ds = pd.DataFrame.from_dict({"input": [CFG.data]}, orient="index").T
211