sagawa committed on
Commit
acb2f9c
1 Parent(s): 3787ba6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +192 -0
app.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import numpy as np
4
+ import warnings
5
+ import pandas as pd
6
+ import torch
7
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
8
+ from torch.utils.data import Dataset, DataLoader
9
+ import gc
10
+ import streamlit as st
11
+
12
# Suppress every Python warning for the remainder of the script run.
warnings.filterwarnings("ignore")


# Page header. The description below is about predicting reactants from
# products (retrosynthesis), so the title names that task rather than the
# forward-prediction task.
st.title('ReactionT5_task_retrosynthesis')
st.markdown('''
##### At this space, you can predict the reactants of reactions from their products.
##### The code expects input_data as a string or CSV file that contains an "input" column.
##### The format of the string or contents of the column should be smiles generated by RDKit.
##### For multiple compounds, concatenate them with ".".
##### The output contains SMILES of predicted reactants and the sum of log-likelihood for each prediction, ordered by their log-likelihood (0th is the most probable reactant).
''')
23
+
24
# Label shown above the free-text SMILES input box.
display_text = 'input the product smiles (e.g. CCN(CC)CCNC(=S)NC1CCCc2cc(C)cnc21)'

# Offer the bundled example CSV for download. The file is read with pandas
# and re-serialised without the index column before being handed to Streamlit.
st.download_button(
    label="Download demo_input.csv",
    file_name='demo_input.csv',
    mime='text/csv',
    data=pd.read_csv('demo_input.csv').to_csv(index=False),
)
32
+
33
class CFG:
    """Run configuration: user-controlled widget values plus fixed settings.

    The Streamlit widgets are created while this class body executes, so the
    order of the assignments below is also the order in which the widgets are
    created.
    """

    # --- values chosen by the user in the UI ---
    num_beams = st.number_input(
        label='num beams',
        min_value=1,
        max_value=10,
        value=5,
        step=1,
    )
    # Every beam is returned, so the two settings always share one value.
    num_return_sequences = num_beams
    uploaded_file = st.file_uploader("Choose a CSV file")
    input_data = st.text_area(display_text)

    # --- fixed settings ---
    model_name_or_path = 'sagawa/ReactionT5v2-retrosynthesis'
    input_column = 'input'
    input_max_length = 100
    model = 't5'  # not referenced anywhere else in this file
    seed = 42
    batch_size = 1
44
+
45
def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    random.seed(seed)
    # Exported for any child process; it does not change hashing in the
    # already-running interpreter.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one. This is a no-op when
    # CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # With the autotuner on, cuDNN may pick a different algorithm per run, so
    # switch it off explicitly alongside the deterministic flag.
    torch.backends.cudnn.benchmark = False
52
+
53
+
54
+
55
def prepare_input(cfg, text):
    """Tokenize one input string into fixed-length model inputs.

    Uses the module-level ``tokenizer``, which must be assigned before this
    function is called.

    Args:
        cfg: Configuration object providing ``input_max_length``.
        text: The input string to tokenize.

    Returns:
        A dict with keys ``"input_ids"`` and ``"attention_mask"``; each value
        is a one-element list holding a ``torch.long`` tensor.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        max_length=cfg.input_max_length,
        padding="max_length",
        truncation=True,
    )
    # Take only the two fields the model needs, so an extra tokenizer output
    # cannot raise KeyError. ``detach().clone()`` is the recommended way to
    # copy an existing tensor; ``torch.tensor(tensor)`` warns.
    # Index 0 selects the first row of the tokenizer output
    # (presumably the only one for a single string - TODO confirm).
    return {
        key: [encoded[key][0].detach().clone().to(torch.long)]
        for key in ("input_ids", "attention_mask")
    }
67
+
68
+
69
class ProductDataset(Dataset):
    """Dataset that yields one tokenized entry per row of the input column."""

    def __init__(self, cfg, df):
        """Store the config and the raw values of ``df[cfg.input_column]``."""
        self.cfg = cfg
        column = df[cfg.input_column]
        self.inputs = column.values

    def __len__(self):
        """Return how many input rows are available."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return the ``idx``-th input after passing it to ``prepare_input``."""
        raw_text = self.inputs[idx]
        return prepare_input(self.cfg, raw_text)
79
+
80
+
81
def predict_single_input(input_compound):
    """Run generation for a single input string.

    Relies on the module-level ``tokenizer``, ``model`` and ``device`` and on
    the beam settings stored in ``CFG``.

    Args:
        input_compound: The input string to tokenize and feed to the model.

    Returns:
        The object returned by ``model.generate`` (requested with
        ``return_dict_in_generate=True`` and ``output_scores=True``).
    """
    encoded = tokenizer(input_compound, return_tensors="pt").to(device)
    generation_kwargs = {
        "num_beams": CFG.num_beams,
        "num_return_sequences": CFG.num_return_sequences,
        "return_dict_in_generate": True,
        "output_scores": True,
    }
    # Inference only: no gradients are needed.
    with torch.no_grad():
        return model.generate(**encoded, **generation_kwargs)
92
+
93
+
94
def decode_output(output):
    """Turn a generation result into cleaned strings and their scores.

    Each sequence is decoded with the module-level ``tokenizer`` (special
    tokens skipped), stripped of all spaces and of trailing ``"."``.

    Args:
        output: Mapping-like generation result with a ``"sequences"`` entry
            and, when ``CFG.num_beams > 1``, a ``"sequences_scores"`` entry.

    Returns:
        ``(sequences, scores)`` where ``scores`` is a list of floats, or
        ``None`` when ``CFG.num_beams`` is 1.
    """
    sequences = []
    for token_ids in output["sequences"]:
        decoded = tokenizer.decode(token_ids, skip_special_tokens=True)
        sequences.append(decoded.replace(" ", "").rstrip("."))

    # Sequence scores are only read when more than one beam is used.
    scores = output["sequences_scores"].tolist() if CFG.num_beams > 1 else None
    return sequences, scores
103
+
104
+
105
def save_single_prediction(input_compound, output, scores):
    """Build a one-row DataFrame for a single prediction.

    The column headers are derived from the data actually supplied, so the
    number of columns always matches the number of values in the row.

    Args:
        input_compound: The input string; stored in the ``"input"`` column.
        output: Sequence of predicted strings, best first.
        scores: Sequence of scores matching ``output``, or ``None``/empty when
            no scores are available.

    Returns:
        A ``pandas.DataFrame`` with one row and the columns ``"input"``,
        ``"0th"``, ``"1th"``, ... followed by ``"0th score"``, ``"1th score"``,
        ... when scores are given.
    """
    predictions = list(output)
    score_values = list(scores) if scores else []

    row = [input_compound, *predictions, *score_values]
    columns = (
        ["input"]
        + [f"{i}th" for i in range(len(predictions))]
        + [f"{i}th score" for i in range(len(score_values))]
    )
    return pd.DataFrame([row], columns=columns)
114
+
115
+
116
def save_multiple_predictions(input_data, sequences, scores):
    """Build a DataFrame with one row of predictions per input row.

    ``sequences`` and ``scores`` are flat lists in which each consecutive
    group of ``CFG.num_return_sequences`` entries belongs to one input row.

    Args:
        input_data: DataFrame containing the ``CFG.input_column`` column.
        sequences: Flat list of predicted strings for all inputs.
        scores: Flat list of scores aligned with ``sequences``; may be empty
            or ``None`` when no scores were produced.

    Returns:
        A ``pandas.DataFrame`` with the columns ``"input"``, ``"0th"``, ...
        and, when scores are present, ``"0th score"``, ...
    """
    group_size = CFG.num_return_sequences
    scores = scores or []  # tolerate None in addition to an empty list
    # Positional lookup: the group counter is a row position, not an index
    # label, so this also works when the DataFrame index is not 0..n-1.
    input_values = input_data[CFG.input_column].tolist()

    output_list = [
        [
            input_values[start // group_size],
            *sequences[start : start + group_size],
            *scores[start : start + group_size],
        ]
        for start in range(0, len(sequences), group_size)
    ]
    columns = (
        ["input"]
        + [f"{i}th" for i in range(group_size)]
        + ([f"{i}th score" for i in range(group_size)] if scores else [])
    )
    return pd.DataFrame(output_list, columns=columns)
130
+
131
+
132
# Main action: runs on the script rerun triggered by clicking "predict".
if st.button('predict'):
    with st.spinner('Now processing. If num beams=5, this process takes about 15 seconds per reaction.'):

        # Validate the request first so that an unusable input fails fast,
        # before the model is loaded.
        if CFG.uploaded_file is None:
            input_data = None
            input_compound = CFG.input_data.strip()
            if not input_compound:
                st.error('Please input a product SMILES or upload a CSV file.')
                st.stop()
        else:
            input_data = pd.read_csv(CFG.uploaded_file)
            if CFG.input_column not in input_data.columns:
                st.error(f'The uploaded CSV must contain an "{CFG.input_column}" column.')
                st.stop()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        seed_everything(seed=CFG.seed)

        # tokenizer, model and device are module-level names: the helper
        # functions in this file read them as globals.
        tokenizer = AutoTokenizer.from_pretrained(CFG.model_name_or_path, return_tensors="pt")
        model = AutoModelForSeq2SeqLM.from_pretrained(CFG.model_name_or_path).to(device)
        model.eval()

        if input_data is None:
            # Single compound typed into the text area.
            output = predict_single_input(input_compound)
            sequences, scores = decode_output(output)
            output_df = save_single_prediction(input_compound, sequences, scores)
        else:
            # Batch mode: one prediction group per row of the uploaded CSV.
            dataset = ProductDataset(CFG, input_data)
            dataloader = DataLoader(
                dataset,
                batch_size=CFG.batch_size,
                shuffle=False,
                num_workers=4,
                pin_memory=True,
                drop_last=False,
            )

            all_sequences, all_scores = [], []
            for inputs in dataloader:
                # Each dataset value is a one-element list, so v[0] is the
                # tensor for this batch.
                inputs = {k: v[0].to(device) for k, v in inputs.items()}
                with torch.no_grad():
                    output = model.generate(
                        **inputs,
                        num_beams=CFG.num_beams,
                        num_return_sequences=CFG.num_return_sequences,
                        return_dict_in_generate=True,
                        output_scores=True,
                    )
                sequences, scores = decode_output(output)
                all_sequences.extend(sequences)
                if scores:
                    all_scores.extend(scores)
                # Release the generation result before the next batch to keep
                # peak memory down.
                del output
                torch.cuda.empty_cache()
                gc.collect()

            output_df = save_multiple_predictions(input_data, all_sequences, all_scores)

    # Serialise directly: the previous ``@st.cache`` wrapper used a deprecated
    # API and decorated a function that was redefined on every rerun.
    csv = output_df.to_csv(index=False)

    st.download_button(
        label="Download data as CSV",
        data=csv,
        file_name='output.csv',
        mime='text/csv',
    )