import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import numpy as np
from sklearn.metrics import *
from omegaconf import OmegaConf
import os
import random
from mcts import MCTS
import esm
from encoders import AptaBLE
from utils import get_scores, API_Dataset, get_nt_esm_dataset
from accelerate import Accelerator
import glob
import requests
from transformers import AutoTokenizer, AutoModelForMaskedLM

# accelerator = Accelerator(kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])  # NOTE: Buggy | Disables unused parameter issue
accelerator = Accelerator()


class AptaBLE_Pipeline:
    """In-house aptamer-protein interaction (API) prediction score pipeline."""

    def __init__(self, lr, dropout, weight_decay, epochs, model_type, model_version,
                 model_save_path, accelerate_save_path, tensorboard_logdir, *args, **kwargs):
        self.device = accelerator.device
        self.lr = lr
        self.weight_decay = weight_decay
        self.epochs = epochs
        self.model_type = model_type
        self.model_version = model_version
        self.model_save_path = model_save_path
        self.accelerate_save_path = accelerate_save_path
        self.tensorboard_logdir = tensorboard_logdir

        esm_prot_encoder, self.esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()  # ESM-2 encoder
        # Freeze ESM-2, then unfreeze the last three transformer layers for fine-tuning
        for name, param in esm_prot_encoder.named_parameters():
            param.requires_grad = False
        for name, param in esm_prot_encoder.named_parameters():
            if "layers.30" in name or "layers.31" in name or "layers.32" in name:
                param.requires_grad = True
        self.batch_converter = self.esm_alphabet.get_batch_converter(truncation_seq_length=1678)

        # self.nt_tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-2.5b-1000g")
        # nt_encoder = AutoModelForMaskedLM.from_pretrained("InstaDeepAI/nucleotide-transformer-2.5b-1000g")
        self.nt_tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-v2-50m-multi-species", trust_remote_code=True)
        nt_encoder = AutoModelForMaskedLM.from_pretrained("InstaDeepAI/nucleotide-transformer-v2-50m-multi-species", trust_remote_code=True)

        self.model = AptaBLE(
            apta_encoder=nt_encoder,
            prot_encoder=esm_prot_encoder,
            dropout=dropout,
        ).to(self.device)
        self.criterion = torch.nn.BCELoss().to(self.device)

    def train(self):
        print('Training the model!')
        # Initialize writer instance
        writer = SummaryWriter(log_dir=f"log/{self.model_type}/{self.model_version}")
        # Initialize early stopping
        self.early_stopper = EarlyStopper(3, 3)
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, [4, 7, 10], 0.1)
        # Configure pytorch objects for distributed environment (i.e. sharded dataloader, multiple copies of model, etc.)
        self.model, self.optimizer, self.train_loader, self.test_loader, self.bench_loader, self.scheduler = accelerator.prepare(
            self.model, self.optimizer, self.train_loader, self.test_loader, self.bench_loader, self.scheduler
        )

        best_loss = 100
        for epoch in range(1, self.epochs + 1):
            self.model.train()
            loss_train, _, _ = self.batch_step(self.train_loader, train_mode=True)
            self.model.eval()
            self.scheduler.step()
            with torch.no_grad():
                loss_test, pred_test, target_test = self.batch_step(self.test_loader, train_mode=False)
            test_scores = get_scores(target_test, pred_test)
            print("\tTrain Loss: {: .6f}\tTest Loss: {: .6f}\tTest ACC: {:.6f}\tTest AUC: {:.6f}\tTest MCC: {:.6f}\tTest PR_AUC: {:.6f}\tF1: {:.6f}\n".format(
                loss_train, loss_test, test_scores['acc'], test_scores['roc_auc'], test_scores['mcc'], test_scores['pr_auc'], test_scores['f1']))

            # stop_early = self.early_stopper.early_stop(loss_test)
            # Early stop - model has not improved on eval set.
            # if stop_early:
            #     break

            # Only do checkpointing after near-convergence
            if epoch > 2:
                with torch.no_grad():
                    loss_bench, pred_bench, target_bench = self.batch_step(self.bench_loader, train_mode=False)
                bench_scores = get_scores(target_bench, pred_bench)
                print("\tBench Loss: {: .6f}\tBench ACC: {:.6f}\tBench AUC: {:.6f}\tBench MCC: {:.6f}\tBench PR_AUC: {:.6f}\tBench F1: {:.6f}\n".format(
                    loss_bench, bench_scores['acc'], bench_scores['roc_auc'], bench_scores['mcc'], bench_scores['pr_auc'], bench_scores['f1']))
                writer.add_scalar("Loss/bench", loss_bench, epoch)
                for k, v in bench_scores.items():
                    if isinstance(v, float):
                        writer.add_scalar(f'{k}/bench', bench_scores[k], epoch)

                # Checkpoint based off of benchmark criteria
                # If model has improved and early stopping patience counter was just reset:
                if bench_scores['mcc'] > 0.5 and test_scores['mcc'] > 0.5 and loss_bench < 0.9 and accelerator.is_main_process:
                    best_loss = loss_test
                    # Remove all other files
                    # for f in glob.glob(f'{self.model_save_path}/model*.pt'):
                    #     os.remove(f)
                    accelerator.save_state(self.accelerate_save_path)
                    model = accelerator.unwrap_model(self.model)
                    torch.save(model.state_dict(), f'{self.model_save_path}/model_epoch={epoch}.pt')
                    print(f'Model saved at {self.model_save_path}')
                    print(f'Accelerate statistics saved at {self.accelerate_save_path}!')  # Access via accelerator.load_state("./output")

            # logging
            writer.add_scalar("Loss/train", loss_train, epoch)
            writer.add_scalar("Loss/test", loss_test, epoch)
            for k, v in test_scores.items():
                if isinstance(v, float):
                    writer.add_scalar(f'{k}/test', test_scores[k], epoch)

        print("Training finished | access tensorboard via 'tensorboard --logdir=log'.")
        writer.flush()
        writer.close()
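    # Note: batch_step() below runs a single pass over `loader`. With train_mode=True it also
    # backpropagates and steps the optimizer per batch; with train_mode=False it only scores the
    # batches. It returns the mean per-batch loss together with flattened prediction and target
    # arrays, which train() feeds to get_scores() and inference() returns directly.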
    def batch_step(self, loader, train_mode=True):
        loss_total = 0
        pred = np.array([])
        target = np.array([])
        for batch_idx, (apta, esm_prot, y, apta_attn, prot_attn) in enumerate(loader):
            if train_mode:
                self.optimizer.zero_grad()
            y_pred = self.predict(apta, esm_prot, apta_attn, prot_attn)
            y_true = torch.tensor(y, dtype=torch.float32).to(self.device)  # .to() not strictly needed since accelerator modifies the dataloader to automatically map input objects to the correct device
            loss = self.criterion(torch.flatten(y_pred), y_true)
            if train_mode:
                accelerator.backward(loss)  # Accelerate backward() method scales gradients and uses the appropriate backward method as configured across devices
                self.optimizer.step()
            loss_total += loss.item()
            pred = np.append(pred, torch.flatten(y_pred).clone().detach().cpu().numpy())
            target = np.append(target, torch.flatten(y_true).clone().detach().cpu().numpy())
            mode = 'train' if train_mode else 'eval'
            print(mode + "[{}/{}({:.0f}%)]".format(batch_idx, len(loader), 100. * batch_idx / len(loader)), end="\r", flush=True)
        loss_total /= len(loader)
        return loss_total, pred, target

    def predict(self, apta, esm_prot, apta_attn, prot_attn):
        y_pred, _, _, _ = self.model(apta, esm_prot, apta_attn, prot_attn)
        return y_pred

    def inference(self, apta, prot, labels):
        """Perform inference on a batch of aptamer/protein pairs."""
        self.model.eval()
        max_length = 275  # nt_tokenizer.model_max_length
        inputs = [(i, j) for i, j in zip(labels, prot)]
        _, _, prot_tokens = self.batch_converter(inputs)
        apta_toks = self.nt_tokenizer.batch_encode_plus(apta, return_tensors='pt', padding='max_length', max_length=max_length)['input_ids']
        apta_attention_mask = apta_toks != self.nt_tokenizer.pad_token_id
        # truncating
        prot_tokenized = prot_tokens[:, :1680]
        # padding
        prot_ex = torch.ones((prot_tokenized.shape[0], 1680), dtype=torch.int64) * self.esm_alphabet.padding_idx
        prot_ex[:, :prot_tokenized.shape[1]] = prot_tokenized
        prot_attention_mask = prot_ex != self.esm_alphabet.padding_idx

        loader = DataLoader(API_Dataset(apta_toks, prot_ex, labels, apta_attention_mask, prot_attention_mask), batch_size=1, shuffle=False)
        self.model, loader = accelerator.prepare(self.model, loader)
        with torch.no_grad():
            _, pred, _ = self.batch_step(loader, train_mode=False)
        return pred

    def recommend(self, target, n_aptamers, depth, iteration, verbose=True):
        candidates = []
        _, _, prot_tokens = self.batch_converter([(1, target)])
        prot_tokenized = torch.tensor(prot_tokens, dtype=torch.int64)
        # adjusting for max protein sequence length during model training
        encoded_targetprotein = torch.ones((prot_tokenized.shape[0], 1678), dtype=torch.int64) * self.esm_alphabet.padding_idx
        encoded_targetprotein[:, :prot_tokenized.shape[1]] = prot_tokenized
        encoded_targetprotein = encoded_targetprotein.to(self.device)

        mcts = MCTS(encoded_targetprotein, depth=depth, iteration=iteration, states=8, target_protein=target, device=self.device, esm_alphabet=self.esm_alphabet)
        for _ in range(n_aptamers):
            mcts.make_candidate(self.model)
            candidates.append(mcts.get_candidate())
            self.model.eval()
            with torch.no_grad():
                sim_seq = np.array([mcts.get_candidate()])
                print('first candidate: ', sim_seq)
                # apta = torch.tensor(rna2vec(sim_seq), dtype=torch.int64).to(self.device)
                apta = self.nt_tokenizer.batch_encode_plus(sim_seq, return_tensors='pt', padding='max_length', max_length=275)['input_ids']
                apta_attn = apta != self.nt_tokenizer.pad_token_id
                prot_attn = encoded_targetprotein != self.esm_alphabet.padding_idx
                score, _, _, _ = self.model(apta.to(self.device), encoded_targetprotein.to(self.device), apta_attn.to(self.device), prot_attn.to(self.device))
            if verbose:
                candidate = mcts.get_candidate()
                print("candidate:\t", candidate, "\tscore:\t", score)
                print("*" * 80)
            mcts.reset()

    def set_data_for_training(self, filepath, batch_size):
        ds_train, ds_test, ds_bench = get_nt_esm_dataset(filepath, self.nt_tokenizer, self.batch_converter, self.esm_alphabet)
        self.train_loader = DataLoader(API_Dataset(ds_train[0], ds_train[1], ds_train[2], ds_train[3], ds_train[4]), batch_size=batch_size, shuffle=True)
        self.test_loader = DataLoader(API_Dataset(ds_test[0], ds_test[1], ds_test[2], ds_test[3], ds_test[4]), batch_size=batch_size, shuffle=False)
        self.bench_loader = DataLoader(API_Dataset(ds_bench[0], ds_bench[1], ds_bench[2], ds_bench[3], ds_bench[4]), batch_size=batch_size, shuffle=False)

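# -----------------------------------------------------------------------------
# Illustrative sketch (not part of the original pipeline): scoring a few
# aptamer/protein pairs with a previously trained checkpoint via
# AptaBLE_Pipeline.inference(). The checkpoint path, sequences and constructor
# values below are placeholders/assumptions; only the class API shown above is
# taken from this file.
#
#   pipeline = AptaBLE_Pipeline(
#       lr=1e-4, dropout=0.1, weight_decay=1e-5, epochs=0,
#       model_type='AptaBLE', model_version='demo',
#       model_save_path='./models', accelerate_save_path='./output',
#       tensorboard_logdir='./runs',
#   )
#   state = torch.load('./models/model_epoch=3.pt', map_location=pipeline.device)  # hypothetical checkpoint
#   pipeline.model.load_state_dict(state)
#   preds = pipeline.inference(
#       apta=['ACGTACGTACGTACGTACGT'],                 # candidate aptamer sequence(s), placeholder
#       prot=['MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ'],    # target protein sequence(s), placeholder
#       labels=[1],                                    # ground-truth labels (or dummies)
#   )
# -----------------------------------------------------------------------------
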
class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False


def seed_torch(seed=5471):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def main():
    conf = OmegaConf.load('config.yaml')
    hyperparameters = conf.hyperparameters
    logging = conf.logging

    lr = hyperparameters['lr']
    wd = hyperparameters['weight_decay']
    dropout = hyperparameters['dropout']
    batch_size = hyperparameters['batch_size']
    epochs = hyperparameters['epochs']
    seed = hyperparameters['seed']

    model_type = logging['model_type']
    model_version = logging['model_version']
    model_save_path = logging['model_save_path']
    accelerate_save_path = logging['accelerate_save_path']
    tensorboard_logdir = logging['tensorboard_logdir']

    if not os.path.exists(model_save_path):
        os.makedirs(model_save_path)

    seed_torch(seed=seed)

    pipeline = AptaBLE_Pipeline(
        lr=lr,
        weight_decay=wd,
        epochs=epochs,
        model_type=model_type,
        model_version=model_version,
        model_save_path=model_save_path,
        accelerate_save_path=accelerate_save_path,
        tensorboard_logdir=tensorboard_logdir,
        d_model=128,
        d_ff=512,
        n_layers=6,
        n_heads=8,
        dropout=dropout,
        load_best_pt=True,  # already loads the pretrained model using the datasets included in repo -- no need to run the bottom two cells
        device='cuda',
        seed=seed)

    datapath = "./data/ABW_real_dna_aptamers_HC_v6.pkl"
    # datapath = './data/ABW_real_dna_aptamers_HC_neg_scrambles_neg_homology.pkl'
    pipeline.set_data_for_training(datapath, batch_size=batch_size)
    pipeline.train()

    endpoint = 'https://slack.atombioworks.com/hooks/t3y99qu6pi81frhwrhef1849wh'
    msg = {"text": "Model has finished training."}
    _ = requests.post(endpoint, json=msg, headers={"Content-Type": "application/json"})
    return


if __name__ == "__main__":
    # launch training w/ the following: "accelerate launch api_prediction.py [args]"
    main()
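
# -----------------------------------------------------------------------------
# For reference, main() expects a config.yaml with the sections and keys it
# reads above. A minimal sketch; every value below is an illustrative
# placeholder, not a setting taken from any released model:
#
#   hyperparameters:
#     lr: 1.0e-4
#     weight_decay: 1.0e-5
#     dropout: 0.1
#     batch_size: 32
#     epochs: 10
#     seed: 5471
#   logging:
#     model_type: AptaBLE
#     model_version: v1
#     model_save_path: ./models
#     accelerate_save_path: ./output
#     tensorboard_logdir: ./runs
# -----------------------------------------------------------------------------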