Atom Bioworks committed
Commit 2616ade
Parent: 24c4c83

Create utils.py

Files changed (1):
1. utils.py +199 -0
utils.py ADDED
import numpy as np
import random
import math

from sklearn.metrics import (accuracy_score, average_precision_score,
                             classification_report, f1_score,
                             matthews_corrcoef, roc_auc_score)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import pickle


def word2idx(word, words):
    # Look up a word's integer index in the vocabulary; unknown words map to 0.
    if word in words:
        return int(words[word])

    return 0

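# Example (illustrative; the toy vocabulary below is hypothetical):
# >>> word2idx("AUG", {"AUG": 5, "CGA": 6})
# 5
# >>> word2idx("NNN", {"AUG": 5, "CGA": 6})
# 0
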
def pad_seq(dataset, max_len):
    # Zero-pad each sequence on the right to max_len, truncating any
    # sequence that is longer.
    output = []
    for seq in dataset:
        pad = np.zeros(max_len)
        length = min(len(seq), max_len)
        pad[:length] = np.asarray(seq)[:length]
        output.append(pad)

    return np.array(output)

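# Example (illustrative, toy sequences):
# >>> pad_seq([[4, 7, 1], [2, 5]], max_len=5)
# array([[4., 7., 1., 0., 0.],
#        [2., 5., 0., 0., 0.]])
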
def str2bool(seq):
    # Map string labels to binary targets: "positive" -> 1, "negative" -> 0.
    # Any other label is silently skipped, so the output may be shorter
    # than the input.
    out = []
    for s in seq:
        if s == "positive":
            out.append(1)
        elif s == "negative":
            out.append(0)

    return np.array(out)

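# Example (illustrative):
# >>> str2bool(["positive", "negative", "positive"])
# array([1, 0, 1])
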
class API_Dataset(Dataset):
    # Wraps pre-tokenized aptamer/protein pairs, binary labels, and the
    # corresponding attention masks as a PyTorch Dataset.
    def __init__(self, apta, esm_prot, y, apta_attn_mask, prot_attn_mask):
        super().__init__()

        self.apta = np.array(apta, dtype=np.int64)
        self.esm_prot = np.array(esm_prot, dtype=np.int64)
        self.y = np.array(y, dtype=np.int64)
        self.apta_attn_mask = np.array(apta_attn_mask)
        self.prot_attn_mask = np.array(prot_attn_mask)
        self.len = len(self.apta)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        return (torch.tensor(self.apta[index], dtype=torch.int64),
                torch.tensor(self.esm_prot[index], dtype=torch.int64),
                torch.tensor(self.y[index], dtype=torch.int64),
                torch.tensor(self.apta_attn_mask[index], dtype=torch.int64),
                torch.tensor(self.prot_attn_mask[index], dtype=torch.int64))

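# Sketch (assumption): wiring the dataset into a DataLoader; the variable
# names come from get_nt_esm_dataset's return value and are illustrative.
# from torch.utils.data import DataLoader
# apta, prot, y, apta_mask, prot_mask = datasets_train
# loader = DataLoader(API_Dataset(apta, prot, y, apta_mask, prot_mask),
#                     batch_size=32, shuffle=True)
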
def find_opt_threshold(target, pred):
    # Grid-search 1,000 thresholds in [0, 1) and return the one that
    # maximizes F1 against the binary targets.
    result = 0
    best = 0

    for i in range(0, 1000):
        pred_threshold = np.where(pred > i / 1000, 1, 0)
        now = f1_score(target, pred_threshold)
        if now > best:
            result = i / 1000
            best = now

    return result

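# Example (illustrative, toy values): 0.2 is the first grid threshold that
# separates the two classes perfectly.
# >>> find_opt_threshold(np.array([0, 1, 1]), np.array([0.2, 0.6, 0.9]))
# 0.2
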
def argument_seqset(seqset):
    # Augment a set of (sequence, structure) pairs: keep each pair and
    # append a reversed copy of both strings.
    arg_seqset = []
    for s, ss in seqset:
        arg_seqset.append([s, ss])

        arg_seqset.append([s[::-1], ss[::-1]])

    return arg_seqset

def augment_apis(apta, prot, ys):
    # Double the interaction set: keep each aptamer-protein pair and add
    # one with the aptamer reversed (same protein, same label).
    aug_apta = []
    aug_prot = []
    aug_y = []
    for a, p, y in zip(apta, prot, ys):
        aug_apta.append(a)
        aug_prot.append(p)
        aug_y.append(y)

        aug_apta.append(a[::-1])
        aug_prot.append(p)
        aug_y.append(y)

    return np.array(aug_apta), np.array(aug_prot), np.array(aug_y)


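# Example (illustrative, toy sequences): one labeled pair becomes two.
# >>> augment_apis(["ACGU"], ["MKT"], ["positive"])
# (array(['ACGU', 'UGCA'], dtype='<U4'),
#  array(['MKT', 'MKT'], dtype='<U3'),
#  array(['positive', 'positive'], dtype='<U8'))
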
def load_data_source(filepath):
    # Load a pickled DataFrame with a "dataset" column and split it into
    # training / test / benchmark partitions as raw numpy arrays.
    with open(filepath, "rb") as fr:
        dataset = pickle.load(fr)
    dataset_train = np.array(dataset[dataset["dataset"] == "training dataset"])
    dataset_test = np.array(dataset[dataset["dataset"] == "test dataset"])
    dataset_bench = np.array(dataset[dataset["dataset"] == "benchmark dataset"])

    return dataset_train, dataset_test, dataset_bench


def get_dataset(filepath, prot_max_len, n_prot_vocabs, prot_words):
    # Note: rna2vec and tokenize_sequences are not defined in this file;
    # they are expected to be importable from elsewhere in the project.
    dataset_train, dataset_test, dataset_bench = load_data_source(filepath)

    arg_apta, arg_prot, arg_y = augment_apis(dataset_train[:, 0], dataset_train[:, 1], dataset_train[:, 2])
    datasets_train = [rna2vec(arg_apta), tokenize_sequences(arg_prot, prot_max_len, n_prot_vocabs, prot_words), str2bool(arg_y)]
    datasets_test = [rna2vec(dataset_test[:, 0]), tokenize_sequences(dataset_test[:, 1], prot_max_len, n_prot_vocabs, prot_words), str2bool(dataset_test[:, 2])]
    datasets_bench = [rna2vec(dataset_bench[:, 0]), tokenize_sequences(dataset_bench[:, 1], prot_max_len, n_prot_vocabs, prot_words), str2bool(dataset_bench[:, 2])]

    return datasets_train, datasets_test, datasets_bench


def get_esm_dataset(filepath, batch_converter, alphabet):
    dataset_train, dataset_test, dataset_bench = load_data_source(filepath)

    arg_apta, arg_prot, arg_y = dataset_train[:, 0], dataset_train[:, 1], dataset_train[:, 2]
    arg_apta, arg_prot, arg_y = augment_apis(arg_apta, arg_prot, arg_y)

    # arg_prot is a 1-D array of protein strings; the ESM batch_converter
    # takes (label, sequence) pairs and returns (labels, strings, tokens).
    train_inputs = [(i, j) for i, j in zip(arg_y, arg_prot)]
    _, _, prot_tokens = batch_converter(train_inputs)
    datasets_train = [rna2vec(arg_apta), prot_tokens, str2bool(arg_y)]

    test_inputs = [(i, j) for i, j in enumerate(dataset_test[:, 1])]
    _, _, test_prot_tokens = batch_converter(test_inputs)
    datasets_test = [rna2vec(dataset_test[:, 0]), test_prot_tokens, str2bool(dataset_test[:, 2])]

    bench_inputs = [(i, j) for i, j in enumerate(dataset_bench[:, 1])]
    _, _, bench_prot_tokens = batch_converter(bench_inputs)
    # Truncate the benchmark tokens to 1678 positions, then pad every row
    # to that fixed width with the alphabet's padding index.
    bench_prot_tokenized = bench_prot_tokens[:, :1678]
    prot_ex = torch.ones((bench_prot_tokenized.shape[0], 1678), dtype=torch.int64) * alphabet.padding_idx
    prot_ex[:, :bench_prot_tokenized.shape[1]] = bench_prot_tokenized
    datasets_bench = [rna2vec(dataset_bench[:, 0]), prot_ex, str2bool(dataset_bench[:, 2])]

    return datasets_train, datasets_test, datasets_bench

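# Sketch (assumption): batch_converter and alphabet as produced by fair-esm.
# The checkpoint and file name below are illustrative.
# import esm
# model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# batch_converter = alphabet.get_batch_converter()
# train, test, bench = get_esm_dataset("api_data.pkl", batch_converter, alphabet)
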
def get_nt_esm_dataset(filepath, nt_tokenizer, batch_converter, alphabet):
    dataset_train, dataset_test, dataset_bench = load_data_source(filepath)

    arg_apta, arg_prot, arg_y = augment_apis(dataset_train[:, 0], dataset_train[:, 1], dataset_train[:, 2])
    max_length = 275  # fixed aptamer length; nt_tokenizer.model_max_length is the alternative

    # batch_converter expects (label, sequence) pairs.
    train_inputs = [(i, j) for i, j in zip(arg_y, arg_prot)]
    _, _, prot_tokens = batch_converter(train_inputs)
    apta_toks = nt_tokenizer.batch_encode_plus(arg_apta, return_tensors='pt', padding='max_length', max_length=max_length)['input_ids']
    apta_attention_mask = apta_toks != nt_tokenizer.pad_token_id
    prot_attention_mask = prot_tokens != alphabet.padding_idx
    datasets_train = [apta_toks, prot_tokens, str2bool(arg_y), apta_attention_mask, prot_attention_mask]

    test_inputs = [(i, j) for i, j in enumerate(dataset_test[:, 1])]
    _, _, test_prot_tokens = batch_converter(test_inputs)
    # Pad the protein tokens to a fixed width of 1680 with the padding index.
    prot_ex = torch.ones((test_prot_tokens.shape[0], 1680), dtype=torch.int64) * alphabet.padding_idx
    prot_ex[:, :test_prot_tokens.shape[1]] = test_prot_tokens
    apta_toks = nt_tokenizer.batch_encode_plus(dataset_test[:, 0], return_tensors='pt', padding='max_length', max_length=max_length)['input_ids']
    apta_attention_mask = apta_toks != nt_tokenizer.pad_token_id
    prot_attention_mask = prot_ex != alphabet.padding_idx
    datasets_test = [apta_toks, prot_ex, str2bool(dataset_test[:, 2]), apta_attention_mask, prot_attention_mask]

    bench_inputs = [(i, j) for i, j in enumerate(dataset_bench[:, 1])]
    _, _, bench_prot_tokens = batch_converter(bench_inputs)
    prot_ex = torch.ones((bench_prot_tokens.shape[0], 1680), dtype=torch.int64) * alphabet.padding_idx
    prot_ex[:, :bench_prot_tokens.shape[1]] = bench_prot_tokens
    apta_toks = nt_tokenizer.batch_encode_plus(dataset_bench[:, 0], return_tensors='pt', padding='max_length', max_length=max_length)['input_ids']
    apta_attention_mask = apta_toks != nt_tokenizer.pad_token_id
    prot_attention_mask = prot_ex != alphabet.padding_idx
    datasets_bench = [apta_toks, prot_ex, str2bool(dataset_bench[:, 2]), apta_attention_mask, prot_attention_mask]

    return datasets_train, datasets_test, datasets_bench

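# Sketch (assumption): a nucleotide-transformer tokenizer for the aptamers;
# the checkpoint name is illustrative.
# from transformers import AutoTokenizer
# nt_tokenizer = AutoTokenizer.from_pretrained(
#     "InstaDeepAI/nucleotide-transformer-v2-50m-multi-species", trust_remote_code=True)
# train, test, bench = get_nt_esm_dataset("api_data.pkl", nt_tokenizer, batch_converter, alphabet)
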
def get_scores(target, pred):
    # Threshold predictions at the F1-optimal cutoff, then report both
    # thresholded metrics (acc, MCC, F1) and threshold-free ones (ROC/PR AUC).
    threshold = find_opt_threshold(target, pred)
    pred_threshold = np.where(pred > threshold, 1, 0)
    acc = accuracy_score(target, pred_threshold)
    roc_auc = roc_auc_score(target, pred)
    mcc = matthews_corrcoef(target, pred_threshold)
    f1 = f1_score(target, pred_threshold)
    pr_auc = average_precision_score(target, pred)
    cls_report = classification_report(target, pred_threshold)
    scores = {
        'threshold': threshold,
        'acc': acc,
        'roc_auc': roc_auc,
        'mcc': mcc,
        'f1': f1,
        'pr_auc': pr_auc,
        'cls_report': cls_report,
    }
    return scores
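
# Example (illustrative, toy values): a perfectly separable case.
# >>> scores = get_scores(np.array([0, 1, 1, 0]), np.array([0.1, 0.8, 0.7, 0.3]))
# >>> scores['threshold'], scores['acc'], scores['roc_auc']
# (0.3, 1.0, 1.0)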