sagawa commited on
Commit
3aff373
·
1 Parent(s): 94883b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -22
app.py CHANGED
@@ -21,6 +21,7 @@ st.markdown('### If there are no catalyst or reagent, fill the blank with a spac
21
  display_text = 'input the reaction smiles (e.g. REACTANT:CNc1nc(SC)ncc1CO.O.O=[Cr](=O)([O-])O[Cr](=O)(=O)[O-].[Na+]CATALYST: REAGENT: SOLVENT:CC(=O)O)'
22
 
23
  class CFG():
 
24
  input_data = st.text_area(display_text)
25
  model_name_or_path = 'sagawa/ZINC-t5-productpredicition'
26
  model = 't5'
@@ -48,25 +49,66 @@ if CFG.model == 't5':
48
  elif CFG.model == 'deberta':
49
  model = EncoderDecoderModel.from_pretrained(CFG.model_name_or_path).to(device)
50
 
51
- input_compound = CFG.input_data
52
- min_length = min(input_compound.find('CATALYST') - input_compound.find(':') - 10, 0)
53
- inp = tokenizer(input_compound, return_tensors='pt').to(device)
54
- output = model.generate(**inp, min_length=min_length, max_length=min_length+50, num_beams=CFG.num_beams, num_return_sequences=CFG.num_return_sequences, return_dict_in_generate=True, output_scores=True)
55
- scores = output['sequences_scores'].tolist()
56
- output = [tokenizer.decode(i, skip_special_tokens=True).replace('. ', '.').rstrip('.') for i in output['sequences']]
57
- for ith, out in enumerate(output):
58
- mol = Chem.MolFromSmiles(out.rstrip('.'))
59
- if type(mol) == rdkit.Chem.rdchem.Mol:
60
- output.append(out.rstrip('.'))
61
- scores.append(scores[ith])
62
- break
63
- if type(mol) == None:
64
- output.append(None)
65
- scores.append(None)
66
- output += scores
67
- output = [input_compound] + output
68
- try:
69
- output_df = pd.DataFrame(np.array(output).reshape(1, -1), columns=['input'] + [f'{i}th' for i in range(CFG.num_beams)] + ['valid compound'] + [f'{i}th score' for i in range(CFG.num_beams)] + ['valid compound score'])
70
- st.table(output_df)
71
- except:
72
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  display_text = 'input the reaction smiles (e.g. REACTANT:CNc1nc(SC)ncc1CO.O.O=[Cr](=O)([O-])O[Cr](=O)(=O)[O-].[Na+]CATALYST: REAGENT: SOLVENT:CC(=O)O)'
22
 
23
  class CFG():
24
+ uploaded_file = st.file_uploader("Choose a CSV file")
25
  input_data = st.text_area(display_text)
26
  model_name_or_path = 'sagawa/ZINC-t5-productpredicition'
27
  model = 't5'
 
49
  elif CFG.model == 'deberta':
50
  model = EncoderDecoderModel.from_pretrained(CFG.model_name_or_path).to(device)
51
 
52
+
53
+ if CFG.uploaded_file is not None:
54
+ input_data = pd.read_csv(CFG.uploaded_file)
55
+ outputs = []
56
+ for idx, row in input_data.iterrows():
57
+ input_compound = row['input']
58
+ min_length = min(input_compound.find('CATALYST') - input_compound.find(':') - 10, 0)
59
+ inp = tokenizer(input_compound, return_tensors='pt').to(device)
60
+ output = model.generate(**inp, min_length=min_length, max_length=min_length+50, num_beams=CFG.num_beams, num_return_sequences=CFG.num_return_sequences, return_dict_in_generate=True, output_scores=True)
61
+ scores = output['sequences_scores'].tolist()
62
+ output = [tokenizer.decode(i, skip_special_tokens=True).replace('. ', '.').rstrip('.') for i in output['sequences']]
63
+ for ith, out in enumerate(output):
64
+ mol = Chem.MolFromSmiles(out.rstrip('.'))
65
+ if type(mol) == rdkit.Chem.rdchem.Mol:
66
+ output.append(out.rstrip('.'))
67
+ scores.append(scores[ith])
68
+ break
69
+ if type(mol) == None:
70
+ output.append(None)
71
+ scores.append(None)
72
+ output += scores
73
+ output = [input_compound] + output
74
+ outputs.append(output)
75
+
76
+ output_df = pd.DataFrame(outputs, columns=['input'] + [f'{i}th' for i in range(CFG.num_beams)] + ['valid compound'] + [f'{i}th score' for i in range(CFG.num_beams)] + ['valid compound score'])
77
+
78
+ @st.cache
79
+ def convert_df(df):
80
+ # IMPORTANT: Cache the conversion to prevent computation on every rerun
81
+ return df.to_csv(index=False).encode('utf-8')
82
+
83
+ output_df = convert_df(output_df)
84
+
85
+ st.download_button(
86
+ label="Download data as CSV",
87
+ data=output_df,
88
+ file_name=input_data + '_result.csv',
89
+ mime='text/csv',
90
+ )
91
+
92
+ else:
93
+ input_compound = CFG.input_data
94
+ min_length = min(input_compound.find('CATALYST') - input_compound.find(':') - 10, 0)
95
+ inp = tokenizer(input_compound, return_tensors='pt').to(device)
96
+ output = model.generate(**inp, min_length=min_length, max_length=min_length+50, num_beams=CFG.num_beams, num_return_sequences=CFG.num_return_sequences, return_dict_in_generate=True, output_scores=True)
97
+ scores = output['sequences_scores'].tolist()
98
+ output = [tokenizer.decode(i, skip_special_tokens=True).replace('. ', '.').rstrip('.') for i in output['sequences']]
99
+ for ith, out in enumerate(output):
100
+ mol = Chem.MolFromSmiles(out.rstrip('.'))
101
+ if type(mol) == rdkit.Chem.rdchem.Mol:
102
+ output.append(out.rstrip('.'))
103
+ scores.append(scores[ith])
104
+ break
105
+ if type(mol) == None:
106
+ output.append(None)
107
+ scores.append(None)
108
+ output += scores
109
+ output = [input_compound] + output
110
+ try:
111
+ output_df = pd.DataFrame(np.array(output).reshape(1, -1), columns=['input'] + [f'{i}th' for i in range(CFG.num_beams)] + ['valid compound'] + [f'{i}th score' for i in range(CFG.num_beams)] + ['valid compound score'])
112
+ st.table(output_df)
113
+ except:
114
+ pass