Spaces:
Sleeping
Sleeping
import os | |
import warnings | |
import logging | |
import random | |
import numpy as np | |
import torch.nn as nn | |
from transformers import AutoConfig, PreTrainedModel, T5ForConditionalGeneration | |
import pandas as pd | |
import torch | |
from torch.utils.data import Dataset, DataLoader | |
from transformers import AutoTokenizer | |
from datasets.utils.logging import disable_progress_bar | |
# Keep the app output quiet: ignore Python warnings, disable log records at
# WARNING level and below, and switch off the `datasets` progress bars.
warnings.filterwarnings(action="ignore")
logging.disable(logging.WARNING)
disable_progress_bar()
# Export TOKENIZERS_PARALLELISM=false for the rest of the process.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})
import streamlit as st | |
# Page title followed by the usage notes, rendered in order as markdown.
st.title('ReactionT5 task yield')
for _usage_note in (
    '##### At this space, you can predict the yields of reactions from their inputs.',
    '##### The code expects input_data as a string or CSV file that contains an "input" column. The format of the string or contents of the column are like "REACTANT:{reactants of the reaction}REAGENT:{reagents, catalysts, or solvents of the reaction}PRODUCT:{products of the reaction}".',
    '##### If there are no reagents or catalysts, fill the blank with a space. And if there are multiple reactants, concatenate them with "."',
):
    st.markdown(_usage_note)

# Label shown on the free-text input box (used when the CFG class creates it).
display_text = 'input the reaction smiles (e.g. REACTANT:CC(C)n1ncnc1-c1cn2c(n1)-c1cnc(O)cc1OCC2.CCN(C(C)C)C(C)C.Cl.NC(=O)[C@@H]1C[C@H](F)CN1REAGENT: PRODUCT:O=C(NNC(=O)C(F)(F)F)C(F)(F)F )'

# Offer the bundled demo file for download; it is read with pandas and
# re-serialized without the index column.
_demo_frame = pd.read_csv('demo_input.csv')
st.download_button(
    label="Download demo_input.csv",
    data=_demo_frame.to_csv(index=False),
    file_name='demo_input.csv',
    mime='text/csv',
)
class CFG:
    """Run configuration for the app, stored as class attributes.

    The two Streamlit input widgets are created while the class body is
    evaluated, so their current values are exposed as ``CFG.uploaded_file``
    and ``CFG.data``.
    """

    # --- user inputs (Streamlit widgets; creation order is kept) ---
    uploaded_file = st.file_uploader("Choose a CSV file")
    data = st.text_area(display_text)

    # --- model ---
    model = 't5'
    model_name_or_path = 'sagawa/ReactionT5v2-yield'
    fc_dropout = 0.0

    # --- tokenization and batching ---
    max_len = 400
    batch_size = 5
    num_workers = 1

    # --- reproducibility ---
    seed = 42
def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also exports the seed as ``PYTHONHASHSEED`` and turns on cuDNN's
    deterministic mode.

    Args:
        seed (int): Seed value applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
def prepare_input(cfg, text):
    """
    Tokenize one input string into fixed-length tensors.
    Args:
        cfg: Configuration object providing ``tokenizer`` and ``max_len``.
        text (str): Input text.
    Returns:
        dict: Every field returned by the tokenizer, converted to a
            ``torch.long`` tensor.
    """
    # Pad and truncate to exactly cfg.max_len tokens.
    tokenizer_options = dict(
        add_special_tokens=True,
        max_length=cfg.max_len,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
    )
    encoded = cfg.tokenizer(text, **tokenizer_options)

    tensors = {}
    for field, values in encoded.items():
        tensors[field] = torch.tensor(values, dtype=torch.long)
    return tensors
def inference_fn(test_loader, model, cfg):
    """
    Run the model over every batch and collect the outputs.
    Args:
        test_loader (DataLoader): Iterable yielding dicts of tensors.
        model (nn.Module): Model to evaluate; called with the batch dict.
        cfg: Configuration object providing ``device``.
    Returns:
        np.ndarray: Per-batch outputs concatenated along the first axis.
    """
    target_device = cfg.device
    model.eval()
    model.to(target_device)

    collected = []
    for batch in test_loader:
        on_device = {}
        for name, tensor in batch.items():
            on_device[name] = tensor.to(target_device)
        # Gradients are not needed for prediction.
        with torch.no_grad():
            batch_output = model(on_device)
        collected.append(batch_output.to("cpu").numpy())
    return np.concatenate(collected)
class TestDataset(Dataset):
    """
    Dataset over the "input" column of a DataFrame; each item is tokenized
    on access with ``prepare_input``.
    """

    def __init__(self, cfg, df):
        # Keep only the raw column values; tokenization happens per item.
        self.inputs = df["input"].values
        self.cfg = cfg

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, item):
        raw_text = self.inputs[item]
        return prepare_input(self.cfg, raw_text)
class ReactionT5Yield(PreTrainedModel):
    """T5 backbone with a fully connected head that outputs one value per input.

    The head combines the decoder's hidden state for a single start token
    with the encoder's hidden state at position 0 and maps them to a single
    number, which ``forward`` multiplies by 100.
    """

    config_class = AutoConfig

    def __init__(self, config):
        """Build the T5 backbone and the five linear layers of the head.

        Args:
            config: Model configuration; ``_name_or_path``, ``vocab_size``,
                ``hidden_size`` and ``decoder_start_token_id`` are read.
        """
        super().__init__(config)
        self.config = config
        self.model = T5ForConditionalGeneration.from_pretrained(self.config._name_or_path)
        self.model.resize_token_embeddings(self.config.vocab_size)

        hidden_size = self.config.hidden_size
        half_size = hidden_size // 2
        # fc1 takes the decoder state, fc2 the encoder state; fc3 consumes
        # their concatenation (2 * half_size features).
        self.fc1 = nn.Linear(hidden_size, half_size)
        self.fc2 = nn.Linear(hidden_size, half_size)
        self.fc3 = nn.Linear(half_size * 2, hidden_size)
        self.fc4 = nn.Linear(hidden_size, hidden_size)
        self.fc5 = nn.Linear(hidden_size, 1)

        for head_layer in (self.fc1, self.fc2, self.fc3, self.fc4, self.fc5):
            self._init_weights(head_layer)

    def _init_weights(self, module):
        """Initialize ``module`` in place.

        Linear and Embedding weights are drawn from N(0, 0.01); Linear biases
        and the Embedding padding row are zeroed. LayerNorm is reset to
        weight 1 and bias 0. Other module types are left untouched.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, inputs):
        """Predict one value per input sequence.

        Args:
            inputs (dict): Keyword arguments for the T5 encoder; must contain
                ``input_ids``.

        Returns:
            torch.Tensor: Output of the final linear layer multiplied by 100
                (presumably shape ``(batch, 1)`` -- confirm against the
                decoder's output shape).
        """
        encoder_outputs = self.model.encoder(**inputs)
        encoder_hidden_states = encoder_outputs[0]

        input_ids = inputs['input_ids']
        # One start token per sequence. It must be created on the same device
        # as the inputs; without ``device=`` it is allocated on the CPU and
        # the decoder call fails when the model runs on a GPU.
        decoder_input_ids = torch.full(
            (input_ids.size(0), 1),
            self.config.decoder_start_token_id,
            dtype=torch.long,
            device=input_ids.device,
        )
        outputs = self.model.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden_states,
        )
        last_hidden_states = outputs[0]

        output1 = self.fc1(last_hidden_states.view(-1, self.config.hidden_size))
        output2 = self.fc2(encoder_hidden_states[:, 0, :].view(-1, self.config.hidden_size))
        output = self.fc3(torch.hstack((output1, output2)))
        output = self.fc4(output)
        output = self.fc5(output)
        return output*100
# Run a prediction when the user presses the button: read the input, validate
# it, run the model and offer the result as a CSV download.
if st.button('predict'):
    with st.spinner('Now processing. This process takes about 4 seconds per reaction.'):
        # Read the input first so bad input is reported before the model is
        # loaded. An uploaded file takes precedence over the text box.
        if CFG.uploaded_file is not None:
            test_ds = pd.read_csv(CFG.uploaded_file)
        else:
            test_ds = pd.DataFrame({"input": [CFG.data]})

        no_text_given = CFG.uploaded_file is None and not str(CFG.data or "").strip()

        if "input" not in test_ds.columns:
            # TestDataset reads df["input"]; report this instead of a KeyError.
            st.error('The uploaded CSV file must contain an "input" column.')
        elif test_ds.empty or no_text_given:
            # Nothing to predict; inference on zero batches would fail in
            # np.concatenate.
            st.warning('Please enter a reaction or upload a CSV file with at least one row.')
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            CFG.device = device
            seed_everything(seed=CFG.seed)
            CFG.tokenizer = AutoTokenizer.from_pretrained(CFG.model_name_or_path, return_tensors='pt')
            model = ReactionT5Yield.from_pretrained(CFG.model_name_or_path)

            test_dataset = TestDataset(CFG, test_ds)
            test_loader = DataLoader(
                test_dataset,
                batch_size=CFG.batch_size,
                shuffle=False,
                num_workers=CFG.num_workers,
                # Pinned host memory only helps host-to-GPU copies.
                pin_memory=(device.type == 'cuda'),
                drop_last=False,
            )

            prediction = inference_fn(test_loader, model, CFG)
            # Flatten to one value per row before storing it as a column,
            # then keep the values within the 0-100 range.
            test_ds["prediction"] = prediction.reshape(-1)
            test_ds["prediction"] = test_ds["prediction"].clip(0, 100)

            csv = test_ds.to_csv(index=False)
            st.download_button(
                label="Download data as CSV",
                data=csv,
                file_name='output.csv',
                mime='text/csv'
            )