|
from typing import Literal |
|
import warnings |
|
import time |
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
import torch.utils.data |
|
from torch.utils.data import Dataset as TorchDataset |
|
from torch_geometric.data import Dataset, Data |
|
from torch_geometric.data.data import BaseData |
|
from torch_geometric.utils import remove_isolated_nodes |
|
from itertools import cycle |
|
from multiprocessing import Pool |
|
from multiprocessing import get_context |
|
from typing import Any, List |
|
import data.utils as utils |
|
import h5py |
|
import lmdb |
|
import pickle |
|
from datetime import datetime |
|
import os |
|
# Worker count handed to every multiprocessing.Pool(...) created by the
# GraphMutationDataset loaders below. Despite the name these are worker
# processes (multiprocessing.Pool), not threads.
NUM_THREADS = 42
|
|
|
|
|
class GraphMutationDataset(Dataset): |
|
""" |
|
MutationDataSet dataset, input a file of mutations, output a star graph and KNN graph |
|
Can be either single mutation or multiple mutations. |
|
|
|
Args: |
|
data_file (string or pd.DataFrame): Path or pd.DataFrame for a csv file for a list of mutations |
|
data_type (string): Type of this data, 'ClinVar', 'DMS', etc |
|
""" |
|
|
|
def __init__(self, data_file, data_type: str,
             radius: float = None, max_neighbors: int = None,
             loop: bool = False, shuffle: bool = False, gpu_id: int = None,
             node_embedding_type: Literal['esm', 'one-hot-idx', 'one-hot', 'aa-5dim', 'esm1b'] = 'esm',
             graph_type: Literal['af2', '1d-neighbor'] = 'af2',
             add_plddt: bool = False,
             scale_plddt: bool = False,
             add_conservation: bool = False,
             add_position: bool = False,
             add_sidechain: bool = False,
             local_coord_transform: bool = False,
             use_cb: bool = False,
             add_msa_contacts: bool = True,
             add_dssp: bool = False,
             add_msa: bool = False,
             add_confidence: bool = False,
             loaded_confidence: bool = False,
             loaded_esm: bool = False,
             add_ptm: bool = False,
             data_augment: bool = False,
             score_transfer: bool = False,
             alt_type: Literal['alt', 'concat', 'diff', 'zero', 'orig'] = 'alt',
             computed_graph: bool = True,
             loaded_msa: bool = False,
             neighbor_type: Literal['KNN', 'radius', 'radius-KNN'] = 'KNN',
             max_len = 2251,
             add_af2_single: bool = False,
             add_af2_pairwise: bool = False,
             loaded_af2_single: bool = False,
             loaded_af2_pairwise: bool = False,
             use_lmdb: bool = False,
             ):
    """Read the mutation table, build Mutation objects and preload features.

    Args:
        data_file: A ``pd.DataFrame`` (used as is, without copying) or a path
            to a CSV file read with ``index_col=0``. When the CSV cannot be
            decoded, it is re-read with ``encoding='ISO-8859-1'``.
        data_type: Free-form label stored on ``self.data_type``.
        radius, max_neighbors, loop, gpu_id: Forwarded to
            ``_load_af2_features`` and stored for graph construction.
        shuffle: When true, rows and mutations are permuted together using
            ``np.random.seed(0)`` (this reseeds NumPy's global RNG).
        node_embedding_type, graph_type, neighbor_type, alt_type: Mode
            switches stored on the instance and read by the ``get_*`` methods.
        add_* / scale_plddt / use_cb: Feature switches stored on the instance.
        loaded_*: When true, the matching features are loaded here (via the
            ``_load_*`` helpers) rather than on each item access.
        data_augment: When true and the set of ``score`` values is a strict
            superset of ``{0, 1}``, rows scoring 0 or 1 are appended again
            with ``ref``/``alt`` swapped and the score negated.
        score_transfer: When true and every ``score`` is 0 or 1, the column
            is multiplied by 3; otherwise a warning is issued.
        computed_graph: Whether ``_load_af2_features`` should precompute one
            graph per structure file.
        max_len: Stored on ``self.max_len``; passed to ``utils.get_mutations``
            and used to scale positional edge features.
        local_coord_transform, use_lmdb: Accepted for interface
            compatibility; not referenced in this constructor.

    Raises:
        ValueError: If ``data_file`` is neither a ``str`` nor a DataFrame.
    """
    super(GraphMutationDataset, self).__init__()
    if isinstance(data_file, pd.DataFrame):
        # NOTE(review): the frame is not copied, so later column assignments
        # on self.data can be visible on the caller's DataFrame.
        self.data = data_file
        self.data_file = 'pd.DataFrame'
    elif isinstance(data_file, str):
        try:
            self.data = pd.read_csv(data_file, index_col=0, low_memory=False)
        except UnicodeDecodeError:
            self.data = pd.read_csv(data_file, index_col=0, encoding='ISO-8859-1')
        self.data_file = data_file
    else:
        # Fixed: the message used to name a non-existent "data_path" argument.
        raise ValueError("data_file must be a string or a pandas.DataFrame")
    self.data_type = data_type
    # Target columns start with 'score'; optional per-target masks start with
    # 'confidence.score'.
    self._y_columns = self.data.columns[self.data.columns.str.startswith('score')]
    self._y_mask_columns = self.data.columns[self.data.columns.str.startswith('confidence.score')]
    self.node_embedding_type = node_embedding_type
    self.graph_type = graph_type
    self.neighbor_type = neighbor_type
    self.add_plddt = add_plddt
    self.scale_plddt = scale_plddt
    self.add_conservation = add_conservation
    self.add_position = add_position
    self.use_cb = use_cb
    self.add_sidechain = add_sidechain
    self.add_msa_contacts = add_msa_contacts
    self.add_dssp = add_dssp
    self.add_msa = add_msa
    self.add_af2_single = add_af2_single
    self.add_af2_pairwise = add_af2_pairwise
    self.loaded_af2_single = loaded_af2_single
    self.loaded_af2_pairwise = loaded_af2_pairwise
    self.add_confidence = add_confidence
    self.loaded_confidence = loaded_confidence
    self.add_ptm = add_ptm
    self.loaded_msa = loaded_msa
    self.loaded_esm = loaded_esm
    self.alt_type = alt_type
    self.max_len = max_len
    self.loop = loop
    self.data_augment = data_augment

    # Feature caches; filled by the _load_* helpers when requested.
    self.af2_file_dict = None
    self.af2_coord_dict = None
    self.af2_plddt_dict = None
    self.af2_confidence_dict = None
    self.af2_dssp_dict = None
    self.af2_graph_dict = None
    self.esm_file_dict = None
    self.esm_dict = None
    self.msa_file_dict = None
    self.msa_dict = None
    self._check_embedding_files()
    if score_transfer:
        # Rescale only when the labels are purely binary.
        if set(self.data['score'].unique()) <= {0, 1}:
            self.data['score'] = self.data['score'] * 3
        else:
            warnings.warn("score_transfer is only applied when score is 0 or 1")
    if data_augment and set(self.data['score'].unique()) > {0, 1}:
        # Append ref/alt-swapped copies of the rows scoring 0 or 1, with the
        # score negated. The swapped columns are aligned by index label.
        reverse_data = self.data.copy()
        reverse_data = reverse_data.loc[(reverse_data['score'] == 1) | (reverse_data['score'] == 0), :]
        reverse_data['ref'] = self.data['alt']
        reverse_data['alt'] = self.data['ref']
        reverse_data['score'] = -reverse_data['score']
        self.data = pd.concat([self.data, reverse_data], ignore_index=True)
    self._set_mutations()
    self.computed_graph = computed_graph
    self._load_af2_features(radius=radius, max_neighbors=max_neighbors, loop=loop, gpu_id=gpu_id)
    if (self.add_msa or self.add_conservation) and self.loaded_msa:
        self._load_msa_features()
    if self.loaded_esm:
        self._load_esm_features()
    if self.loaded_af2_pairwise or self.loaded_af2_single:
        self._load_af2_reps()
    self._set_node_embeddings()
    self._set_edge_embeddings()
    # Running count of sequences whose MSA did not match; see get_one_mutation.
    self.unmatched_msa = 0

    if shuffle:
        # Fixed seed so the permutation is reproducible; rows and mutations
        # are permuted with the same index to stay aligned.
        np.random.seed(0)
        shuffle_index = np.random.permutation(len(self.mutations))
        self.data = self.data.iloc[shuffle_index].reset_index(drop=True)
        self.mutations = list(map(self.mutations.__getitem__, shuffle_index))
    if self.add_ptm:
        self.ptm_ref = pd.read_csv('./data.files/ptm.small.csv', index_col=0)
    # Selects the backend used by __getitem__.
    self.get_method = 'default'
|
|
|
|
|
|
|
def _check_embedding_files(self):
    """Drop mutations whose wild-type sequence has no ESM-2 embedding.

    The check only runs when ``node_embedding_type`` is ``'esm'``. One row
    per ``uniprotID`` is probed through ``utils.get_embedding_from_esm2``;
    rows of ``self.data`` whose ``wt.orig`` belongs to a failed probe are
    removed.
    """
    print(f"read in {len(self.data)} mutations from {self.data_file}")

    one_row_per_protein = self.data.drop_duplicates(subset=['uniprotID'])
    print(f"found {len(one_row_per_protein)} unique wt sequences")

    if self.node_embedding_type != 'esm':
        print(f"skip checking embedding files for {self.node_embedding_type}")
        return

    probe_args = zip(one_row_per_protein['uniprotID'], cycle([True]))
    with Pool(NUM_THREADS) as workers:
        probe_result = workers.starmap(utils.get_embedding_from_esm2, probe_args)

    probe_failed = ~np.array(probe_result, dtype=bool)
    to_drop = one_row_per_protein['wt.orig'].loc[probe_failed]
    print(f"drop {np.sum(self.data['wt.orig'].isin(to_drop))} mutations that do not have embedding or msa")
    self.data = self.data[~self.data['wt.orig'].isin(to_drop)]
|
|
|
def _set_mutations(self):
    """Build one ``utils.Mutation`` per row and drop rows that yield none.

    Each row of ``self.data`` is passed to ``utils.get_mutations`` in a
    worker pool. Rows whose result is falsy are removed from ``self.data``;
    the truthy results are stored, in order, on ``self.mutations``.
    """
    # Guarantee the column exists so it can always be passed positionally.
    if 'af2_file' not in self.data.columns:
        self.data['af2_file'] = pd.NA
    # 'ENST' is optional; substitute an endless stream of None when absent.
    enst = self.data['ENST'] if 'ENST' in self.data.columns else cycle([None])
    with Pool(NUM_THREADS) as p:
        point_mutations = p.starmap(utils.get_mutations, zip(self.data['uniprotID'],
                                                             enst,
                                                             self.data['wt.orig'],
                                                             self.data['sequence.len.orig'],
                                                             self.data['pos.orig'],
                                                             self.data['ref'],
                                                             self.data['alt'],
                                                             cycle([self.max_len]),
                                                             # Always present here, so the former
                                                             # cycle([None]) fallback was dead code.
                                                             self.data['af2_file'],))

    # Build the keep-mask once; it drives both the message and the filter.
    has_mutation = np.array(point_mutations, dtype=bool)
    print(f"drop {np.sum(~has_mutation)} mutations that don't have coordinates")
    self.data = self.data.loc[has_mutation]
    self.mutations = list(filter(bool, point_mutations))
    print(f'Finished loading {len(self.mutations)} mutations')
|
|
|
def _load_af2_features(self, radius, max_neighbors, loop, gpu_id):
    """Load per-structure features and, optionally, precompute graphs.

    The unique ``af2_file`` values of the mutations index every cache
    filled here; each mutation is told its index via ``set_af2_seq_index``.
    Coordinates are always loaded; pLDDT, confidence and DSSP only when
    their switches are on (otherwise the cache is ``None``).

    Graphs are precomputed when ``self.computed_graph`` is true, except for
    graph_type ``'af2'`` with a neighbor_type other than ``'KNN'``, where
    ``self.computed_graph`` is switched off instead.
    """
    structure_files = [m.af2_file for m in self.mutations]
    self.af2_file_dict, file_idx_of_mutation = np.unique(structure_files, return_inverse=True)
    for m, file_idx in zip(self.mutations, file_idx_of_mutation):
        m.set_af2_seq_index(file_idx)

    with Pool(NUM_THREADS) as workers:
        self.af2_coord_dict = workers.starmap(
            utils.get_coords_from_af2, zip(self.af2_file_dict, cycle([self.add_sidechain])))
        print(f'Finished loading {len(self.af2_coord_dict)} af2 coords')

        if self.add_plddt:
            self.af2_plddt_dict = workers.starmap(utils.get_plddt_from_af2, zip(self.af2_file_dict))
        else:
            self.af2_plddt_dict = None
        print('Finished loading plddt')

        if self.add_plddt and self.add_confidence and self.loaded_confidence:
            self.af2_confidence_dict = workers.starmap(
                utils.get_confidence_from_af2file, zip(self.af2_file_dict, self.af2_plddt_dict))
        else:
            self.af2_confidence_dict = None
        print('Finished loading confidence')

        if self.add_dssp:
            self.af2_dssp_dict = workers.starmap(utils.get_dssp_from_af2, zip(self.af2_file_dict))
        else:
            self.af2_dssp_dict = None
        print('Finished loading dssp')

    if not self.computed_graph:
        print('Do not construct graphs from af2 files to save RAM')
    elif self.graph_type == 'af2':
        if self.neighbor_type == 'KNN':
            self.af2_graph_dict = [
                utils.get_knn_graphs_from_af2(structure_coords, radius, max_neighbors, loop, gpu_id)
                for structure_coords in self.af2_coord_dict
            ]
            print(f'Finished constructing {len(self.af2_graph_dict)} af2 graphs')
        else:
            # Other neighbor types are built per item in get_graph_and_mask.
            self.computed_graph = False
            print('Do not construct graphs from af2 files to save RAM')
    elif self.graph_type == '1d-neighbor':
        self.af2_graph_dict = [
            utils.get_graphs_from_neighbor(structure_coords, max_neighbors, loop)
            for structure_coords in self.af2_coord_dict
        ]
        print(f'Finished constructing {len(self.af2_graph_dict)} af2 graphs')

    # Kept for on-the-fly graph construction in get_graph_and_mask.
    self.radius = radius
    self.max_neighbors = max_neighbors
    self.loop = loop
    self.gpu_id = gpu_id
|
|
|
def _load_esm_features(self):
    """Preload one ESM entry per unique ``ESM_prefix`` into ``self.esm_dict``.

    Each mutation is told the index of its prefix in ``self.esm_file_dict``
    via ``set_esm_seq_index``.
    """
    prefixes = [m.ESM_prefix for m in self.mutations]
    self.esm_file_dict, prefix_idx_of_mutation = np.unique(prefixes, return_inverse=True)
    for m, prefix_idx in zip(self.mutations, prefix_idx_of_mutation):
        m.set_esm_seq_index(prefix_idx)
    with Pool(NUM_THREADS) as workers:
        self.esm_dict = workers.starmap(utils.get_esm_dict_from_uniprot, zip(self.esm_file_dict))
    print(f'Finished loading {len(self.esm_file_dict)} esm embeddings')
|
|
|
def _load_af2_reps(self):
    """Preload AlphaFold2 single representations, one per unique prefix.

    Each mutation is told the index of its ``af2_rep_file_prefix`` via
    ``set_af2_rep_index``.

    Raises:
        ValueError: If preloading of pairwise representations is requested
            (both ``add_af2_pairwise`` and ``loaded_af2_pairwise``).
    """
    rep_prefixes = [m.af2_rep_file_prefix for m in self.mutations]
    self.af2_rep_file_prefix_dict, prefix_idx_of_mutation = np.unique(rep_prefixes, return_inverse=True)
    for m, prefix_idx in zip(self.mutations, prefix_idx_of_mutation):
        m.set_af2_rep_index(prefix_idx)

    want_single = self.add_af2_single and self.loaded_af2_single
    want_pairwise = self.add_af2_pairwise and self.loaded_af2_pairwise
    with Pool(NUM_THREADS) as workers:
        if want_single:
            self.af2_single_dict = workers.starmap(
                utils.get_af2_single_rep_dict_from_prefix, zip(self.af2_rep_file_prefix_dict))
            print(f'Finished loading {len(self.af2_rep_file_prefix_dict)} alphafold2 single representations')

        if want_pairwise:
            raise ValueError("Not implemented in this version")
|
|
|
def _load_msa_features(self):
    """Preload one MSA entry per unique ``uniprot_id`` into ``self.msa_dict``.

    Each mutation is told the index of its id in ``self.msa_file_dict`` via
    ``set_msa_seq_index``.
    """
    uniprot_ids = [m.uniprot_id for m in self.mutations]
    self.msa_file_dict, id_idx_of_mutation = np.unique(uniprot_ids, return_inverse=True)
    for m, id_idx in zip(self.mutations, id_idx_of_mutation):
        m.set_msa_seq_index(id_idx)
    with Pool(NUM_THREADS) as workers:
        self.msa_dict = workers.starmap(utils.get_msa_dict_from_transcript, zip(self.msa_file_dict))
    print(f'Finished loading {len(self.msa_dict)} msa seqs')
|
|
|
def _set_node_embeddings(self): |
|
pass |
|
|
|
def _set_edge_embeddings(self): |
|
pass |
|
|
|
def get_mask(self, mutation: utils.Mutation): |
|
return mutation.pos - 1, mutation |
|
|
|
def get_graph_and_mask(self, mutation: utils.Mutation):
    """Assemble coordinates, both edge sets and their attributes for one mutation.

    Args:
        mutation: The mutation to build graphs for.

    Returns:
        Tuple ``(coords, edge_index, edge_index_star, edge_attr,
        edge_attr_star, mask_idx, mutation)``. ``edge_index`` is the
        neighbourhood graph (precomputed or built here); ``edge_index_star``
        links the masked position(s) to other nodes; the ``edge_attr*``
        arrays are looked up from the pairwise feature tensor for each edge.
    """
    coords: np.ndarray = self.af2_coord_dict[mutation.af2_seq_index]
    if self.computed_graph:
        edge_index = self.af2_graph_dict[mutation.af2_seq_index]
    else:
        if self.graph_type == 'af2':
            if self.neighbor_type == 'KNN':
                edge_index = utils.get_knn_graphs_from_af2(coords, self.radius, self.max_neighbors, self.loop, self.gpu_id)
            elif self.neighbor_type == 'radius':
                edge_index = utils.get_radius_graphs_from_af2(coords, self.radius, self.loop, self.gpu_id)
                # Keep only edges touching a node that is directly linked to
                # the mutated position(s).
                connected_nodes = edge_index[:, np.isin(edge_index[0], mutation.pos - 1)].flatten()
                edge_index = edge_index[:, np.isin(edge_index[0], connected_nodes) | np.isin(edge_index[1], connected_nodes)]
            else:
                edge_index = utils.get_radius_knn_graphs_from_af2(coords, mutation.pos - 1, self.radius, self.max_neighbors, self.loop)
        elif self.graph_type == '1d-neighbor':
            edge_index = utils.get_graphs_from_neighbor(coords, self.max_neighbors, self.loop)

    if mutation.crop:
        # Restrict to the [seq_start, seq_end] window (1-based, inclusive)
        # and renumber the nodes from zero. The boolean indexing produces a
        # new array, so the in-place shift does not touch a cached graph.
        coords = coords[mutation.seq_start - 1:mutation.seq_end, :]
        edge_index = edge_index[:, (edge_index[0, :] >= mutation.seq_start - 1) &
                                   (edge_index[1, :] >= mutation.seq_start - 1) &
                                   (edge_index[0, :] < mutation.seq_end) &
                                   (edge_index[1, :] < mutation.seq_end)]
        edge_index[0, :] -= mutation.seq_start - 1
        edge_index[1, :] -= mutation.seq_start - 1

    mask_idx, mutation = self.get_mask(mutation)

    # Star graph: every pair with the masked position(s) as one endpoint.
    edge_matrix_star = np.zeros((coords.shape[0], coords.shape[0]))
    edge_matrix_star[:, mask_idx] = 1
    edge_matrix_star[mask_idx, :] = 1
    edge_index_star = np.array(np.where(edge_matrix_star == 1))

    # Restrict the star graph to nodes that appear in the neighbourhood graph
    # (plus the masked position(s) themselves for 'radius-KNN').
    graph_nodes = edge_index.flatten()
    if self.neighbor_type == 'radius' or self.neighbor_type == 'KNN':
        edge_index_star = edge_index_star[:, np.isin(edge_index_star[0], graph_nodes) &
                                             np.isin(edge_index_star[1], graph_nodes)]
    elif self.neighbor_type == 'radius-KNN':
        star_nodes = np.concatenate((graph_nodes, mask_idx))
        edge_index_star = edge_index_star[:, np.isin(edge_index_star[0], star_nodes) &
                                             np.isin(edge_index_star[1], star_nodes)]

    if not self.loop:
        # Drop self loops.
        edge_index_star = edge_index_star[:, edge_index_star[0] != edge_index_star[1]]

    seq_len = mutation.seq_end - mutation.seq_start + 1
    if self.add_msa_contacts:
        coevo_strength = utils.get_contacts_from_msa(mutation, False)
        # An int return value is treated as "no contacts available".
        if isinstance(coevo_strength, int):
            coevo_strength = np.zeros([seq_len, seq_len, 1])
    else:
        # Zero-width feature axis so the concatenations below still work.
        coevo_strength = np.zeros([seq_len, seq_len, 0])

    start = time.time()
    if self.add_af2_pairwise:
        if self.loaded_af2_pairwise:
            # NOTE(review): self.af2_pairwise_txn is not assigned anywhere in
            # the visible part of this class - confirm where it is opened.
            # NOTE(security): pickle.loads executes arbitrary code; the store
            # must only contain trusted data.
            key = mutation.af2_rep_file_prefix.split('/')[-1].encode('ascii')
            byteflow = self.af2_pairwise_txn.get(key)
            # Fixed: test the raw lookup before unpickling. Previously a
            # missing key reached pickle.loads(None) and raised TypeError, so
            # the fallback below could never handle it.
            pairwise_rep = None if byteflow is None else pickle.loads(byteflow)
            if pairwise_rep is None:
                pairwise_rep = utils.get_af2_pairwise_rep_dict_from_prefix(mutation.af2_rep_file_prefix)
        else:
            pairwise_rep = utils.get_af2_pairwise_rep_dict_from_prefix(mutation.af2_rep_file_prefix)

        # Prefixes without '-F' are first cut to the original sequence window.
        if mutation.af2_rep_file_prefix.find('-F') == -1:
            pairwise_rep = pairwise_rep[mutation.seq_start_orig - 1: mutation.seq_end_orig,
                                        mutation.seq_start_orig - 1: mutation.seq_end_orig]
        if mutation.crop:
            pairwise_rep = pairwise_rep[mutation.seq_start - 1: mutation.seq_end,
                                        mutation.seq_start - 1: mutation.seq_end]
        coevo_strength = np.concatenate([coevo_strength, pairwise_rep], axis=2)
        end = time.time()
        print(f'Finished loading pairwise in {end - start:.2f} seconds')

    edge_attr = coevo_strength[edge_index[0], edge_index[1], :]
    edge_attr_star = coevo_strength[edge_index_star[0], edge_index_star[1], :]

    if self.add_position:
        # Append a signed, sine-scaled sequence-separation feature per edge.
        edge_attr = np.concatenate(
            (edge_attr, np.sin(np.pi / 2 * (edge_index[1] - edge_index[0]) / self.max_len).reshape(-1, 1)),
            axis=1)
        edge_attr_star = np.concatenate(
            (edge_attr_star, np.sin(np.pi / 2 * (edge_index_star[1] - edge_index_star[0]) / self.max_len).reshape(-1, 1)),
            axis=1)
    return coords, edge_index, edge_index_star, edge_attr, edge_attr_star, mask_idx, mutation
|
|
|
def get_one_mutation(self, idx):
    """Build the NumPy feature dictionary for the mutation at ``idx``.

    Node features start from the embedding selected by
    ``node_embedding_type`` and are extended column-wise by every enabled
    ``add_*`` block. ``to_alt`` (and, for ``alt_type == 'diff'``,
    ``to_ref``) collect the matching feature rows for the mutated
    position(s).

    Returns:
        dict with keys ``embed_logits``, ``one_hot_mat``, ``mask_idx``,
        ``embed_data``, ``embed_data_mask``, ``alt_embed_data``, ``coords``,
        ``CA_coord``, ``CB_coord``, ``edge_index``, ``edge_index_star``,
        ``edge_attr``, ``edge_attr_star`` and ``nodes_vector``; plus
        ``plddt``, ``edge_confidence`` and ``edge_confidence_star`` when
        ``add_confidence`` is set.

    Raises:
        ValueError: If ``alt_type`` is not one of the supported values.
    """
    mutation: utils.Mutation = self.mutations[idx]

    coords, edge_index, edge_index_star, edge_attr, edge_attr_star, mask_idx, mutation = self.get_graph_and_mask(mutation)

    # ---- base node embedding and per-residue alt/ref rows ----
    if self.node_embedding_type == 'esm':
        if self.loaded_esm:
            embed_data = utils.get_embedding_from_esm2(self.esm_dict[mutation.esm_seq_index], False,
                                                       mutation.seq_start, mutation.seq_end)
        else:
            embed_data = utils.get_embedding_from_esm2(mutation.ESM_prefix, False,
                                                       mutation.seq_start, mutation.seq_end)
        to_alt = np.concatenate([utils.ESM_AA_EMBEDDING_DICT[alt_aa].reshape(1, -1) for alt_aa in mutation.alt_aa])
        to_ref = np.concatenate([utils.ESM_AA_EMBEDDING_DICT[ref_aa].reshape(1, -1) for ref_aa in mutation.ref_aa])
    elif self.node_embedding_type == 'one-hot-idx':
        assert not self.add_conservation and not self.add_plddt
        embed_logits, embed_data, one_hot_mat = utils.get_embedding_from_onehot_nonzero(mutation.seq, return_idx=True, return_onehot_mat=True)
        to_alt = np.concatenate([np.array(utils.AA_DICT.index(alt_aa)).reshape(1, -1) for alt_aa in mutation.alt_aa])
        to_ref = np.concatenate([np.array(utils.AA_DICT.index(ref_aa)).reshape(1, -1) for ref_aa in mutation.ref_aa])
    elif self.node_embedding_type == 'one-hot':
        embed_data, one_hot_mat = utils.get_embedding_from_onehot(mutation.seq, return_idx=False, return_onehot_mat=True)
        to_alt = np.concatenate([np.eye(len(utils.AA_DICT))[utils.AA_DICT.index(alt_aa)].reshape(1, -1) for alt_aa in mutation.alt_aa])
        to_ref = np.concatenate([np.eye(len(utils.AA_DICT))[utils.AA_DICT.index(ref_aa)].reshape(1, -1) for ref_aa in mutation.ref_aa])
    elif self.node_embedding_type == 'aa-5dim':
        embed_data = utils.get_embedding_from_5dim(mutation.seq)
        to_alt = np.concatenate([np.array(utils.AA_5DIM_EMBED[alt_aa]).reshape(1, -1) for alt_aa in mutation.alt_aa])
        to_ref = np.concatenate([np.array(utils.AA_5DIM_EMBED[ref_aa]).reshape(1, -1) for ref_aa in mutation.ref_aa])
    elif self.node_embedding_type == 'esm1b':
        embed_data = utils.get_embedding_from_esm1b(mutation.ESM_prefix, False,
                                                    mutation.seq_start, mutation.seq_end)
        to_alt = np.concatenate([utils.ESM1b_AA_EMBEDDING_DICT[alt_aa].reshape(1, -1) for alt_aa in mutation.alt_aa])
        to_ref = np.concatenate([utils.ESM1b_AA_EMBEDDING_DICT[ref_aa].reshape(1, -1) for ref_aa in mutation.ref_aa])
    if self.alt_type == "zero":
        # A single all-zero row replaces the alternate embedding.
        to_alt = np.zeros_like(to_alt)[[0]]

    # ---- MSA / conservation source (preloaded or read on demand) ----
    if self.loaded_msa and (self.add_msa or self.add_conservation):
        msa_seq = self.msa_dict[mutation.msa_seq_index][0]
        conservation_data = self.msa_dict[mutation.msa_seq_index][1]
        msa_data = self.msa_dict[mutation.msa_seq_index][2]
    else:
        if self.add_conservation or self.add_msa:
            msa_seq, conservation_data, msa_data = utils.get_msa_dict_from_transcript(mutation.uniprot_id)
    if self.add_conservation:
        if conservation_data.shape[0] == 0:
            conservation_data = np.zeros((embed_data.shape[0], 20))
        else:
            msa_seq_check = msa_seq[mutation.seq_start_orig - 1: mutation.seq_end_orig]
            conservation_data = conservation_data[mutation.seq_start_orig - 1: mutation.seq_end_orig]
            if mutation.crop:
                msa_seq_check = msa_seq_check[mutation.seq_start - 1: mutation.seq_end]
                conservation_data = conservation_data[mutation.seq_start - 1: mutation.seq_end]
            if msa_seq_check != mutation.seq:
                # MSA sequence disagrees with the mutation's sequence: count
                # it and fall back to zeros.
                self.unmatched_msa += 1
                print(f'Unmatched MSA: {self.unmatched_msa}')
                conservation_data = np.zeros((embed_data.shape[0], 20))
        embed_data = np.concatenate([embed_data, conservation_data], axis=1)
        to_alt = np.concatenate([to_alt, conservation_data[mask_idx]], axis=1)
        if self.alt_type == 'diff':
            to_ref = np.concatenate([to_ref, conservation_data[mask_idx]], axis=1)

    # ---- pLDDT ----
    if self.add_plddt:
        plddt_data = self.af2_plddt_dict[mutation.af2_seq_index]
        if mutation.crop:
            plddt_data = plddt_data[mutation.seq_start - 1: mutation.seq_end]
        if self.add_confidence:
            confidence_data = plddt_data / 100
        if plddt_data.shape[0] != embed_data.shape[0]:
            warnings.warn(f'pLDDT {plddt_data.shape[0]} does not match embedding {embed_data.shape[0]}, '
                          f'pLDDT file: {mutation.af2_file}, '
                          f'ESM prefix: {mutation.ESM_prefix}')
            # Length mismatch: substitute a constant mid-range value.
            plddt_data = np.ones_like(embed_data[:, 0]) * 50
            if self.add_confidence:
                confidence_data = np.ones_like(embed_data[:, 0]) / 2
        if self.scale_plddt:
            plddt_data = plddt_data / 100
        embed_data = np.concatenate([embed_data, plddt_data[:, None]], axis=1)
        to_alt = np.concatenate([to_alt, plddt_data[mask_idx, None]], axis=1)
        if self.alt_type == 'diff':
            # Fixed: add the trailing axis, as for to_alt above. The 1-D
            # plddt_data[mask_idx] cannot be concatenated on axis=1.
            to_ref = np.concatenate([to_ref, plddt_data[mask_idx, None]], axis=1)

    # ---- DSSP ----
    if self.add_dssp:
        dssp_data = self.af2_dssp_dict[mutation.af2_seq_index]
        if mutation.crop:
            dssp_data = dssp_data[mutation.seq_start - 1: mutation.seq_end]
        if dssp_data.shape[0] != embed_data.shape[0]:
            warnings.warn(f'DSSP {dssp_data.shape[0]} does not match embedding {embed_data.shape[0]}, '
                          f'DSSP file: {mutation.af2_file}, '
                          f'ESM prefix: {mutation.ESM_prefix}')
            dssp_data = np.zeros_like(embed_data[:, 0])

        # Promote a 1-D vector to a single feature column.
        if len(dssp_data.shape) == 1:
            dssp_data = dssp_data[:, None]
        embed_data = np.concatenate([embed_data, dssp_data], axis=1)
        to_alt = np.concatenate([to_alt, dssp_data[mask_idx]], axis=1)
        if self.alt_type == 'diff':
            to_ref = np.concatenate([to_ref, dssp_data[mask_idx]], axis=1)

    # ---- PTM ----
    if self.add_ptm:
        ptm_data = utils.get_ptm_from_mutation(mutation, self.ptm_ref)
        embed_data = np.concatenate([embed_data, ptm_data], axis=1)
        to_alt = np.concatenate([to_alt, ptm_data[mask_idx]], axis=1)
        if self.alt_type == 'diff':
            to_ref = np.concatenate([to_ref, ptm_data[mask_idx]], axis=1)

    # ---- AlphaFold2 single representation ----
    if self.add_af2_single:
        if self.loaded_af2_single:
            single_rep = self.af2_single_dict[mutation.af2_rep_index]
        else:
            single_rep = utils.get_af2_single_rep_dict_from_prefix(mutation.af2_rep_file_prefix)

        # Prefixes without '-F' are first cut to the original sequence window.
        if mutation.af2_rep_file_prefix.find('-F') == -1:
            single_rep = single_rep[mutation.seq_start_orig - 1: mutation.seq_end_orig]
        if mutation.crop:
            single_rep = single_rep[mutation.seq_start - 1: mutation.seq_end]
        embed_data = np.concatenate([embed_data, single_rep], axis=1)
        to_alt = np.concatenate([to_alt, single_rep[mask_idx]], axis=1)
        if self.alt_type == 'diff':
            to_ref = np.concatenate([to_ref, single_rep[mask_idx]], axis=1)

    # ---- MSA features ----
    if self.add_msa:
        if msa_data.shape[0] == 0:
            msa_data = np.zeros((embed_data.shape[0], 199))
        else:
            msa_seq_check = msa_seq[mutation.seq_start_orig - 1: mutation.seq_end_orig]
            msa_data = msa_data[mutation.seq_start_orig - 1: mutation.seq_end_orig]
            if mutation.crop:
                msa_seq_check = msa_seq_check[mutation.seq_start - 1: mutation.seq_end]
                msa_data = msa_data[mutation.seq_start - 1: mutation.seq_end]
            if msa_seq_check != mutation.seq:
                print(f'Unmatched MSA: {self.unmatched_msa}')
                msa_data = np.zeros((embed_data.shape[0], 199))
        embed_data = np.concatenate([embed_data, msa_data], axis=1)
        # NOTE(review): unlike the blocks above, to_alt is extended only for
        # 'alt'/'zero' here - confirm this is intended for 'diff'.
        if self.alt_type == 'alt' or self.alt_type == 'zero':
            to_alt = np.concatenate([to_alt, msa_data[mask_idx]], axis=1)
        if self.alt_type == 'diff':
            to_ref = np.concatenate([to_ref, msa_data[mask_idx]], axis=1)

    # ---- mask and alternate-embedding matrix ----
    embed_data_mask = np.ones_like(embed_data)
    embed_data_mask[mask_idx] = 0
    if self.alt_type == 'alt' or self.alt_type == 'zero':
        alt_embed_data = np.zeros_like(embed_data)
        alt_embed_data[mask_idx] = to_alt
    elif self.alt_type == 'concat':
        alt_embed_data = np.zeros((embed_data.shape[0], to_alt.shape[1] + to_ref.shape[1]))
        alt_embed_data[mask_idx] = np.concatenate([to_alt, to_ref], axis=1)
    elif self.alt_type == 'diff':
        alt_embed_data = np.zeros_like(embed_data)
        alt_embed_data[mask_idx] = to_alt
        embed_data[mask_idx] = to_ref
    elif self.alt_type == 'orig':
        # The alternate embedding is the unmodified node embedding itself.
        alt_embed_data = embed_data
    else:
        raise ValueError(f'alt_type {self.alt_type} not supported')

    # ---- coordinates and per-node direction vectors ----
    # Columns used below: 0 -> N, 1 -> C, 2 -> O, 3 -> CA, 4 -> CB, 5+ -> side
    # chain (inferred from the variable names; confirm against utils).
    # NOTE(review): these are basic-index slices of `coords`, so if the cache
    # holds plain ndarrays the in-place writes below also modify the entry in
    # self.af2_coord_dict - confirm that is acceptable.
    CA_coord = coords[:, 3]
    CB_coord = coords[:, 4]

    # Missing CB positions fall back to the CA position.
    CB_coord[np.isnan(CB_coord)] = CA_coord[np.isnan(CB_coord)]
    if self.graph_type == '1d-neighbor':
        # Replace the first coordinate by the residue index and blank out
        # all atom coordinates used for the direction vectors.
        CA_coord[:, 0] = np.arange(coords.shape[0])
        CB_coord[:, 0] = np.arange(coords.shape[0])
        coords = np.zeros_like(coords)
    CA_CB = coords[:, [4]] - coords[:, [3]]
    CA_CB[np.isnan(CA_CB)] = 0

    CA_C = coords[:, [1]] - coords[:, [3]]
    CA_O = coords[:, [2]] - coords[:, [3]]
    CA_N = coords[:, [0]] - coords[:, [3]]
    nodes_vector = np.transpose(np.concatenate([CA_CB, CA_C, CA_O, CA_N], axis=1), (0, 2, 1))
    if self.add_sidechain:
        sidechain_nodes_vector = coords[:, 5:] - coords[:, [3]]
        sidechain_nodes_vector[np.isnan(sidechain_nodes_vector)] = 0
        sidechain_nodes_vector = np.transpose(sidechain_nodes_vector, (0, 2, 1))
        nodes_vector = np.concatenate([nodes_vector, sidechain_nodes_vector], axis=2)

    features = dict(
        embed_logits=embed_logits if self.node_embedding_type == 'one-hot-idx' else None,
        one_hot_mat=one_hot_mat if self.node_embedding_type.startswith('one-hot') else None,
        mask_idx=mask_idx,
        embed_data=embed_data,
        embed_data_mask=embed_data_mask,
        alt_embed_data=alt_embed_data,
        coords=coords,
        CA_coord=CA_coord,
        CB_coord=CB_coord,
        edge_index=edge_index,
        edge_index_star=edge_index_star,
        edge_attr=edge_attr,
        edge_attr_star=edge_attr_star,
        nodes_vector=nodes_vector,
    )

    # ---- confidence (per-node pLDDT and per-edge PAE) ----
    if self.add_confidence:
        if self.add_plddt:
            # confidence_data was prepared in the pLDDT block above.
            features['plddt'] = confidence_data
            if self.loaded_confidence:
                pae = self.af2_confidence_dict[mutation.af2_seq_index]
            else:
                pae = utils.get_confidence_from_af2file(mutation.af2_file, self.af2_plddt_dict[mutation.af2_seq_index])
            if mutation.crop:
                pae = pae[mutation.seq_start - 1: mutation.seq_end, mutation.seq_start - 1: mutation.seq_end]
        else:
            # pLDDT was not preloaded: read it now.
            plddt_data = utils.get_plddt_from_af2(mutation.af2_file)
            pae = utils.get_confidence_from_af2file(mutation.af2_file, plddt_data)
            if mutation.crop:
                confidence_data = plddt_data[mutation.seq_start - 1: mutation.seq_end] / 100
                pae = pae[mutation.seq_start - 1: mutation.seq_end, mutation.seq_start - 1: mutation.seq_end]
            else:
                # Fixed: confidence_data used to be assigned only when
                # cropping, so uncropped proteins raised NameError below.
                confidence_data = plddt_data / 100
            if confidence_data.shape[0] != embed_data.shape[0]:
                warnings.warn(f'pLDDT {confidence_data.shape[0]} does not match embedding {embed_data.shape[0]}, '
                              f'pLDDT file: {mutation.af2_file}, '
                              f'ESM prefix: {mutation.ESM_prefix}')
                confidence_data = np.ones_like(embed_data[:, 0]) * 0.8
            features['plddt'] = confidence_data

        features['edge_confidence'] = pae[edge_index[0], edge_index[1]]
        features['edge_confidence_star'] = pae[edge_index_star[0], edge_index_star[1]]
    return features
|
|
|
def get(self, idx):
    """Return the mutation at ``idx`` as a ``torch_geometric`` ``Data`` object.

    The NumPy features from ``get_one_mutation`` are converted to tensors,
    isolated nodes are removed with ``remove_isolated_nodes``, and the
    node-level tensors are filtered with the resulting mask. ``y`` holds the
    row's ``score*`` columns; ``score_mask`` is added when the table has
    ``confidence.score*`` columns.
    """
    raw = self.get_one_mutation(idx)

    def as_tensor(key, dtype):
        return torch.from_numpy(raw[key]).to(dtype)

    # Index embeddings are integer ids; all other embeddings are floats.
    if self.node_embedding_type == 'one-hot-idx':
        x = as_tensor('embed_data', torch.long)
    else:
        x = as_tensor('embed_data', torch.float32)

    if self.use_cb:
        node_pos = as_tensor('CB_coord', torch.float32)
    else:
        node_pos = as_tensor('CA_coord', torch.float32)

    row = int(idx)
    features = dict(
        x=x,
        x_mask=as_tensor('embed_data_mask', torch.bool),
        x_alt=as_tensor('alt_embed_data', torch.float32),
        pos=node_pos,
        edge_index=as_tensor('edge_index', torch.long),
        edge_index_star=as_tensor('edge_index_star', torch.long),
        edge_attr=as_tensor('edge_attr', torch.float32),
        edge_attr_star=as_tensor('edge_attr_star', torch.float32),
        node_vec_attr=as_tensor('nodes_vector', torch.float32),
        y=torch.tensor([self.data[self._y_columns].iloc[row]]).to(torch.float32),
    )
    if self.add_confidence:
        for key in ('plddt', 'edge_confidence', 'edge_confidence_star'):
            features[key] = as_tensor(key, torch.float32)

    num_nodes = x.shape[0]
    if self.neighbor_type in ('radius', 'radius-KNN'):
        # Relabel both edge sets together so they share one node numbering,
        # then split them back at the original boundary.
        n_edges = features["edge_index"].shape[1]
        n_edge_attr = features["edge_attr"].shape[0]
        joint_index = torch.cat((features["edge_index"], features["edge_index_star"]), dim=1)
        joint_attr = torch.cat((features["edge_attr"], features["edge_attr_star"]), dim=0)

        joint_index, joint_attr, mask = remove_isolated_nodes(joint_index, joint_attr, num_nodes)

        features["edge_index"] = joint_index[:, :n_edges]
        features["edge_index_star"] = joint_index[:, n_edges:]
        features["edge_attr"] = joint_attr[:n_edge_attr]
        features["edge_attr_star"] = joint_attr[n_edge_attr:]
    else:
        features["edge_index"], features["edge_attr"], mask = remove_isolated_nodes(
            features["edge_index"], features["edge_attr"], num_nodes)
        # The node mask kept for filtering is the one from this second call.
        features["edge_index_star"], features["edge_attr_star"], mask = remove_isolated_nodes(
            features["edge_index_star"], features["edge_attr_star"], num_nodes)

    for node_key in ("x", "x_mask", "x_alt", "pos", "node_vec_attr"):
        features[node_key] = features[node_key][mask]

    if len(self._y_mask_columns) > 0:
        features['score_mask'] = torch.tensor([self.data[self._y_mask_columns].iloc[row]]).to(torch.float)
    return Data(**features)
|
|
|
def get_from_hdf5(self, idx): |
|
if not hasattr(self, 'hdf5_keys') or self.hdf5_file is None: |
|
raise ValueError('hdf5 file is not set') |
|
else: |
|
features = {} |
|
with h5py.File(self.hdf5_file, 'r') as f: |
|
for key in self.hdf5_keys: |
|
features[key] = torch.tensor(f[f'{self.hdf5_idx_map[idx]}/{key}']) |
|
return Data(**features) |
|
|
|
def open_lmdb(self):
    """Open ``self.lmdb_path`` read-only and start a read transaction.

    The environment is stored on ``self.env`` and the transaction on
    ``self.txn``.
    """
    env_options = dict(
        subdir=False,
        readonly=True,
        lock=False,
        readahead=False,
        meminit=False,
    )
    self.env = lmdb.open(self.lmdb_path, **env_options)
    self.txn = self.env.begin(write=False, buffers=True)
|
|
|
def get_from_lmdb(self, idx): |
|
if not hasattr(self, 'txn'): |
|
self.open_lmdb() |
|
byteflow = self.txn.get(u'{}'.format(self.lmdb_idx_map[idx]).encode('ascii')) |
|
unpacked = pickle.loads(byteflow) |
|
return unpacked |
|
|
|
def __getitem__(self, idx): |
|
|
|
start = time.time() |
|
if self.get_method == 'default': |
|
data = self.get(idx) |
|
print(f'default Finished loading {idx} in {time.time() - start:.2f} seconds') |
|
elif self.get_method == 'hdf5': |
|
data = self.get_from_hdf5(idx) |
|
print(f'hdf5 Finished loading {idx} in {time.time() - start:.2f} seconds') |
|
elif self.get_method == 'lmdb': |
|
data = self.get_from_lmdb(idx) |
|
print(f'lmdb Finished loading {idx} in {time.time() - start:.2f} seconds') |
|
elif self.get_method == 'memory': |
|
data = self.parsed_data[idx] |
|
print(f'memory Finished loading {idx} in {time.time() - start:.2f} seconds') |
|
return data |
|
|
|
def __len__(self): |
|
return len(self.mutations) |
|
|
|
def len(self) -> int: |
|
return len(self.mutations) |
|
|
|
def subset(self, idxs):
    """Restrict the dataset, in place, to the samples selected by ``idxs``.

    ``self.data`` and ``self.mutations`` are reindexed.  For each of the
    AF2 / ESM / MSA lookup tables that is present and not None, only the
    entries still referenced are kept and every mutation gets its new table
    index.  HDF5 / LMDB index maps and ``parsed_data`` are reindexed when
    present.

    Args:
        idxs: Positions to keep, in the order they should appear.

    Returns:
        ``self``.
    """
    def _take(items, positions):
        # None stays None; anything else is gathered into a plain list.
        if items is None:
            return None
        return [items[i] for i in positions]

    self.data = self.data.iloc[idxs].reset_index(drop=True)
    self.mutations = [self.mutations[i] for i in idxs]

    kept_af2_files, mutation_idx = np.unique(
        [m.af2_file for m in self.mutations], return_inverse=True)

    if hasattr(self, 'af2_file_dict') and self.af2_file_dict is not None:
        old_rows = np.array(
            [np.where(self.af2_file_dict == name)[0][0] for name in kept_af2_files])
        self.af2_file_dict = kept_af2_files
        for attr in ('af2_coord_dict', 'af2_plddt_dict', 'af2_confidence_dict',
                     'af2_dssp_dict', 'af2_graph_dict'):
            setattr(self, attr, _take(getattr(self, attr), old_rows))
        for mutation, new_index in zip(self.mutations, mutation_idx):
            mutation.set_af2_seq_index(new_index)

    if hasattr(self, 'esm_file_dict') and self.esm_file_dict is not None:
        kept_esm_files, mutation_idx = np.unique(
            [m.ESM_prefix for m in self.mutations], return_inverse=True)
        old_rows = np.array(
            [np.where(self.esm_file_dict == name)[0][0] for name in kept_esm_files])
        self.esm_file_dict = kept_esm_files
        self.esm_dict = _take(self.esm_dict, old_rows)
        for mutation, new_index in zip(self.mutations, mutation_idx):
            mutation.set_esm_seq_index(new_index)

    if hasattr(self, 'msa_file_dict') and self.msa_file_dict is not None:
        kept_msa_files, mutation_idx = np.unique(
            [m.uniprot_id for m in self.mutations], return_inverse=True)
        old_rows = np.array(
            [np.where(self.msa_file_dict == name)[0][0] for name in kept_msa_files])
        self.msa_file_dict = kept_msa_files
        self.msa_dict = _take(self.msa_dict, old_rows)
        for mutation, new_index in zip(self.mutations, mutation_idx):
            mutation.set_msa_seq_index(new_index)

    for map_attr in ('hdf5_idx_map', 'lmdb_idx_map'):
        if hasattr(self, map_attr) and getattr(self, map_attr) is not None:
            setattr(self, map_attr, getattr(self, map_attr)[idxs])

    if hasattr(self, 'parsed_data') and self.parsed_data is not None:
        self.parsed_data = [self.parsed_data[i] for i in idxs]
    return self
|
|
|
def shuffle(self, idxs):
    """Reorder the dataset in place according to ``idxs``.

    ``self.data`` and ``self.mutations`` are reindexed, together with the
    HDF5 / LMDB index maps and the in-memory ``parsed_data`` cache when they
    exist.  The per-protein lookup tables are left untouched.

    Args:
        idxs: Permutation (or any index sequence) to apply.
    """
    self.data = self.data.iloc[idxs].reset_index(drop=True)
    self.mutations = [self.mutations[i] for i in idxs]

    # Guard with getattr like subset() does: these maps only exist once the
    # corresponding cache has been set up.
    for map_attr in ('hdf5_idx_map', 'lmdb_idx_map'):
        idx_map = getattr(self, map_attr, None)
        if idx_map is not None:
            setattr(self, map_attr, idx_map[idxs])

    # Keep the 'memory' backend aligned with self.mutations; without this
    # __getitem__ would return samples in the pre-shuffle order.
    parsed = getattr(self, 'parsed_data', None)
    if parsed is not None:
        self.parsed_data = [parsed[i] for i in idxs]
|
|
|
def get_label_counts(self) -> np.ndarray:
    """Count the samples per label in the ``score`` column.

    Returns:
        ``[0, 0]`` when there is no ``score`` column.  When ``-1`` occurs in
        the column and both ``-1`` and ``1`` have non-zero counts, the counts
        for ``-1, 0, 1, 3``; when ``-1`` occurs but ``1`` does not, the counts
        for ``0`` and ``3``; otherwise the counts for ``0`` and ``1``.
    """
    if not self.data.columns.isin(['score']).any():
        return np.array([0, 0])

    scores = self.data['score']

    def n_equal(label):
        return (scores == label).sum()

    if -1 not in scores.values:
        return np.array([n_equal(0), n_equal(1)])

    lof, benign, gof, patho = (n_equal(label) for label in (-1, 0, 1, 3))
    if lof != 0 and gof != 0:
        return np.array([lof, benign, gof, patho])
    return np.array([benign, patho])
|
|
|
|
|
def create_hdf5(self):
    """Write every sample to a timestamped HDF5 file and switch to it.

    Sample ``i`` is stored under group ``i`` with one dataset per feature
    key.  On success ``self.get_method`` becomes ``'hdf5'``;
    ``self.hdf5_keys`` is taken from the first sample.
    """
    # Strip only a trailing '.csv'.  str.replace left non-csv names (for
    # example the 'pd.DataFrame' placeholder) unchanged, so the cache would
    # have been written onto that exact path.
    suffix = '.csv'
    if self.data_file.endswith(suffix):
        stem = self.data_file[:-len(suffix)]
    else:
        stem = self.data_file
    hdf5_file = f'{stem}.{datetime.now()}.hdf5'

    self.hdf5_file = hdf5_file
    self.hdf5_keys = None
    self.hdf5_idx_map = np.arange(len(self))
    with h5py.File(hdf5_file, 'w') as f:
        for i in range(len(self)):
            features = self.get(i)
            if self.hdf5_keys is None:
                self.hdf5_keys = list(features.keys())
            for key in features.keys():
                f.create_dataset(f'{i}/{key}', data=features[key])
    # Only route reads through the file once it has been written completely.
    self.get_method = 'hdf5'
    return
|
|
|
|
|
def create_lmdb(self, write_frequency=1000):
    """Pickle every sample into a timestamped LMDB file and switch to it.

    Sample ``idx`` is stored under the ASCII key ``str(idx)``.  On success
    ``self.lmdb_path``, ``self.lmdb_idx_map`` and ``self.get_method`` are
    updated.

    Args:
        write_frequency: Commit the write transaction after this many samples.
    """
    # Strip only a trailing '.csv'.  str.replace left non-csv names (for
    # example the 'pd.DataFrame' placeholder) unchanged, so the database
    # would have been created on that exact path.
    suffix = '.csv'
    if self.data_file.endswith(suffix):
        stem = self.data_file[:-len(suffix)]
    else:
        stem = self.data_file
    lmdb_path = f'{stem}.{datetime.now()}.lmdb'

    # map_size must be an integer; the literal 5e12 is a float.
    map_size = int(5e12)
    db = lmdb.open(lmdb_path, subdir=False, map_size=map_size, readonly=False,
                   meminit=False, map_async=True)
    try:
        print(f"Begin loading {len(self)} points into lmdb")
        txn = db.begin(write=True)
        for idx in range(len(self)):
            d = self.get(idx)
            txn.put(str(idx).encode('ascii'), pickle.dumps(d))
            print(f'Finished loading {idx}')
            if (idx + 1) % write_frequency == 0:
                txn.commit()
                txn = db.begin(write=True)
        txn.commit()
        print(f"Finished loading {len(self)} points into lmdb")
        print("Flushing database ...")
        db.sync()
    finally:
        # Release the environment even when building a sample fails.
        db.close()

    self.lmdb_path = lmdb_path
    self.lmdb_idx_map = np.arange(len(self))
    self.get_method = 'lmdb'
    return
|
|
|
def load_all_to_memory(self):
    """Build every sample once, keep it in ``self.parsed_data`` and free the tables.

    Switches ``self.get_method`` to ``'memory'``.  Samples are collected in
    batches of 200 with a progress print per batch.  Afterwards the AF2 /
    ESM / MSA lookup attributes are deleted from the instance if present.
    """
    self.get_method = 'memory'
    self.parsed_data = []
    tick = time.time()
    pending = []
    for i in range(len(self)):
        pending.append(self.get(i))
        if (i+1) % 200 != 0:
            continue
        print(f'rank {self.gpu_id} Finished loading {i+1} points in {time.time() - tick:.2f} seconds')
        tick = time.time()
        self.parsed_data.extend(pending)
        pending = []
        print(f'rank {self.gpu_id} Extended {i+1} points in {time.time() - tick:.2f} seconds')
    # Flush whatever is left from the last, partial batch.
    self.parsed_data.extend(pending)

    raw_tables = (
        'af2_file_dict', 'af2_coord_dict', 'af2_plddt_dict',
        'af2_confidence_dict', 'af2_dssp_dict', 'af2_graph_dict',
        'esm_file_dict', 'esm_dict', 'msa_file_dict', 'msa_dict',
        'af2_single_dict', 'af2_pairwise_dict',
    )
    for attr in raw_tables:
        if hasattr(self, attr):
            delattr(self, attr)
    return
|
|
|
|
|
def clean_up(self):
    """Close open LMDB environments and delete the cache files on disk.

    Environments are closed before their files are removed, and the read
    transaction that belonged to ``self.env`` is dropped so a later access
    does not use a handle of a closed environment.  Removes
    ``self.hdf5_file``, ``self.lmdb_path`` and every path in
    ``self.af2_pair_dict_lmdb_path`` (plus the ``-lock`` companion file LMDB
    keeps next to a ``subdir=False`` database) when they exist.
    """
    env = getattr(self, 'env', None)
    if env is not None:
        env.close()
        self.env = None
        if hasattr(self, 'txn'):
            # Closing the environment invalidates its transactions.
            del self.txn
    pairwise_env = getattr(self, 'af2_pairwise_env', None)
    if pairwise_env is not None:
        pairwise_env.close()

    doomed = []
    if getattr(self, 'hdf5_file', None) is not None:
        doomed.append(self.hdf5_file)
    lmdb_files = []
    if getattr(self, 'lmdb_path', None) is not None:
        lmdb_files.append(self.lmdb_path)
    if getattr(self, 'af2_pair_dict_lmdb_path', None) is not None:
        lmdb_files.extend(self.af2_pair_dict_lmdb_path)
    for db_file in lmdb_files:
        doomed.append(db_file)
        doomed.append(f'{db_file}-lock')

    for path in doomed:
        if os.path.exists(path):
            os.remove(path)
    return
|
|
|
|
|
class FullGraphMutationDataset(TorchDataset): |
|
""" |
|
MutationDataSet dataset, input a file of mutations, output a star graph and KNN graph |
|
Can be either single mutation or multiple mutations. |
|
|
|
Args: |
|
data_file (string or pd.DataFrame): Path or pd.DataFrame for a csv file for a list of mutations |
|
data_type (string): Type of this data, 'ClinVar', 'DMS', etc |
|
""" |
|
|
|
def __init__(self, data_file, data_type: str,
             radius: float = None, max_neighbors: int = None,
             loop: bool = False, shuffle: bool = False, gpu_id: int = None,
             node_embedding_type: Literal['esm', 'one-hot-idx', 'one-hot', 'aa-5dim', 'esm1b'] = 'esm',
             graph_type: Literal['af2', '1d-neighbor'] = 'af2',
             add_plddt: bool = False,
             scale_plddt: bool = False,
             add_conservation: bool = False,
             add_position: bool = False,
             add_sidechain: bool = False,
             local_coord_transform: bool = False,
             use_cb: bool = False,
             add_msa_contacts: bool = True,
             add_dssp: bool = False,
             add_msa: bool = False,
             add_confidence: bool = False,
             loaded_confidence: bool = False,
             loaded_esm: bool = False,
             add_ptm: bool = False,
             data_augment: bool = False,
             score_transfer: bool = False,
             alt_type: Literal['alt', 'concat', 'diff'] = 'alt',
             computed_graph: bool = False,
             loaded_msa: bool = False,
             neighbor_type: Literal['KNN', 'radius', 'radius-KNN'] = 'KNN',
             max_len: int = 2251,
             convert_to_onesite: bool = False,
             add_af2_single: bool = False,
             add_af2_pairwise: bool = False,
             loaded_af2_single: bool = False,
             loaded_af2_pairwise: bool = False,
             use_lmdb: bool = False
             ):
    """Load the mutation table, resolve the mutations and preload feature tables.

    Args:
        data_file: Path to a csv file, or a ``pd.DataFrame``, listing mutations.
        data_type: Label for this data set; stored as ``self.data_type``.
        radius, max_neighbors, loop, gpu_id: Stored on the instance by
            ``_load_af2_features``.
        shuffle: Permute samples with a fixed NumPy seed (0) after loading.
        node_embedding_type, graph_type, neighbor_type, alt_type, max_len:
            Stored as-is and consulted when samples are built.
        add_* / loaded_* / scale_plddt / use_cb: Feature switches stored
            as-is; the ``loaded_*`` flags also trigger the matching
            ``_load_*`` call here.
        data_augment: When the score set is a strict superset of ``{0, 1}``,
            append ref/alt-swapped copies of the score-0 and score-1 rows with
            negated score.
        score_transfer: Multiply scores by 3 when all scores are 0 or 1.
        convert_to_onesite: Pass the table through ``utils.convert_to_onesite``.
        use_lmdb: Read samples from ``data_file`` with ``.csv`` replaced by
            ``.lmdb``.  Requires ``data_file`` to be a path.
        local_coord_transform: Accepted but not used in this constructor.

    Raises:
        ValueError: If ``data_file`` is neither a string nor a DataFrame, or
            if ``use_lmdb`` is requested for a DataFrame input.
    """
    super(FullGraphMutationDataset, self).__init__()
    if isinstance(data_file, pd.DataFrame):
        if use_lmdb:
            # The LMDB path is derived from the csv path.  With a DataFrame
            # the old code called DataFrame.replace and stored a DataFrame as
            # lmdb_path, which only failed later when the database was opened.
            raise ValueError("use_lmdb requires data_file to be a path to a csv file")
        self.data = data_file
        self.data_file = 'pd.DataFrame'
    elif isinstance(data_file, str):
        try:
            self.data = pd.read_csv(data_file, index_col=0, low_memory=False)
        except UnicodeDecodeError:
            self.data = pd.read_csv(data_file, index_col=0, encoding='ISO-8859-1')
        self.data_file = data_file
    else:
        raise ValueError("data_path must be a string or a pandas.DataFrame")
    if convert_to_onesite:
        self.data = utils.convert_to_onesite(self.data)

    # Plain configuration.
    self.data_type = data_type
    self._y_columns = self.data.columns[self.data.columns.str.startswith('score')]
    self.node_embedding_type = node_embedding_type
    self.graph_type = graph_type
    self.neighbor_type = neighbor_type
    self.add_plddt = add_plddt
    self.scale_plddt = scale_plddt
    self.add_conservation = add_conservation
    self.add_position = add_position
    self.use_cb = use_cb
    self.add_sidechain = add_sidechain
    self.add_msa_contacts = add_msa_contacts
    self.add_dssp = add_dssp
    self.add_msa = add_msa
    self.add_confidence = add_confidence
    self.add_af2_single = add_af2_single
    self.add_af2_pairwise = add_af2_pairwise
    self.loaded_af2_single = loaded_af2_single
    self.loaded_af2_pairwise = loaded_af2_pairwise
    self.loaded_confidence = loaded_confidence
    self.add_ptm = add_ptm
    self.loaded_msa = loaded_msa
    self.loaded_esm = loaded_esm
    self.alt_type = alt_type
    self.max_len = max_len
    self.loop = loop
    self.data_augment = data_augment

    # Lookup tables filled by the _load_* helpers below.
    self.af2_file_dict = None
    self.af2_coord_dict = None
    self.af2_plddt_dict = None
    self.af2_confidence_dict = None
    self.af2_dssp_dict = None
    self.af2_graph_dict = None
    self.esm_file_dict = None
    self.esm_dict = None
    self.msa_file_dict = None
    self.msa_dict = None

    self._check_embedding_files()
    if score_transfer:
        if set(self.data['score'].unique()) <= {0, 1}:
            self.data['score'] = self.data['score'] * 3
        else:
            warnings.warn("score_transfer is only applied when score is 0 or 1")
    if data_augment and set(self.data['score'].unique()) > {0, 1}:
        reverse_data = self.data.copy()
        reverse_data = reverse_data.loc[(reverse_data['score'] == 1) | (reverse_data['score'] == 0), :]
        # Swap ref and alt (aligned on the row index) and negate the score.
        reverse_data['ref'] = self.data['alt']
        reverse_data['alt'] = self.data['ref']
        reverse_data['score'] = -reverse_data['score']
        self.data = pd.concat([self.data, reverse_data], ignore_index=True)

    self._set_mutations()
    self.computed_graph = computed_graph
    self._load_af2_features(radius=radius, max_neighbors=max_neighbors, loop=loop, gpu_id=gpu_id)
    if (self.add_msa or self.add_conservation) and self.loaded_msa:
        self._load_msa_features()
    if self.loaded_esm:
        self._load_esm_features()
    if self.loaded_af2_pairwise or self.loaded_af2_single:
        self._load_af2_reps()
    self._set_node_embeddings()
    self._set_edge_embeddings()
    self.unmatched_msa = 0

    if shuffle:
        np.random.seed(0)
        shuffle_index = np.random.permutation(len(self.mutations))
        self.data = self.data.iloc[shuffle_index].reset_index(drop=True)
        self.mutations = list(map(self.mutations.__getitem__, shuffle_index))
    if self.add_ptm:
        self.ptm_ref = pd.read_csv('./data.files/ptm.small.csv', index_col=0)

    self.get_method = 'default'
    if use_lmdb:
        self.get_method = 'lmdb'
        self.lmdb_path = data_file.replace('.csv', '.lmdb')
        self.lmdb_idx_map = np.arange(len(self))
|
|
|
def _check_embedding_files(self):
    """Drop mutations whose wild-type sequence has no ESM embedding.

    Only runs the check when ``self.node_embedding_type`` is ``'esm'``: one
    ``utils.get_embedding_from_esm2(uniprotID, True)`` call per unique
    ``uniprotID`` is dispatched to a process pool, and rows whose ``wt.orig``
    belongs to a failing entry are removed from ``self.data``.
    """
    print(f"read in {len(self.data)} mutations from {self.data_file}")
    unique_data = self.data.drop_duplicates(subset=['uniprotID'])
    print(f"found {len(unique_data)} unique wt sequences")

    if self.node_embedding_type != 'esm':
        print(f"skip checking embedding files for {self.node_embedding_type}")
        return

    with Pool(NUM_THREADS) as workers:
        embedding_exist = workers.starmap(
            utils.get_embedding_from_esm2,
            zip(unique_data['uniprotID'], cycle([True])))

    has_embedding = np.array(embedding_exist, dtype=bool)
    to_drop = unique_data['wt.orig'].loc[~has_embedding]
    drop_rows = self.data['wt.orig'].isin(to_drop)
    print(f"drop {np.sum(drop_rows)} mutations that do not have embedding or msa")
    self.data = self.data[~drop_rows]
|
|
|
def _set_mutations(self):
    """Resolve every table row into a mutation object and drop failed rows.

    ``utils.get_mutations`` is called once per row through a process pool.
    Rows whose result is falsy are removed from ``self.data``; the truthy
    results become ``self.mutations``.  An ``af2_file`` column filled with
    ``pd.NA`` is added when the table has none.
    """
    if 'af2_file' not in self.data.columns:
        self.data['af2_file'] = pd.NA

    frame = self.data
    with Pool(NUM_THREADS) as workers:
        point_mutations = workers.starmap(
            utils.get_mutations,
            zip(
                frame['uniprotID'],
                frame['ENST'] if 'ENST' in frame.columns else cycle([None]),
                frame['wt.orig'],
                frame['sequence.len.orig'],
                frame['pos.orig'],
                frame['ref'],
                frame['alt'],
                cycle([self.max_len]),
                frame['af2_file'],
            ))

    resolved = np.array(point_mutations, dtype=bool)
    print(f"drop {np.sum(~resolved)} mutations that don't have coordinates")
    self.data = self.data.loc[resolved]
    self.mutations = [mutation for mutation in point_mutations if mutation]
    print(f'Finished loading {len(self.mutations)} mutations')
|
|
|
def _load_af2_features(self, radius, max_neighbors, loop, gpu_id):
    """Load per-structure AlphaFold2 tables for all unique ``af2_file`` values.

    Each mutation receives the index of its structure in
    ``self.af2_file_dict``.  Coordinates are always loaded; pLDDT, confidence
    and DSSP tables are loaded only when their flags are set and are None
    otherwise.  The four arguments are stored on the instance unchanged.
    """
    self.af2_file_dict, mutation_idx = np.unique(
        [m.af2_file for m in self.mutations], return_inverse=True)
    for mutation, seq_index in zip(self.mutations, mutation_idx):
        mutation.set_af2_seq_index(seq_index)

    with Pool(NUM_THREADS) as workers:
        self.af2_coord_dict = workers.starmap(
            utils.get_coords_from_af2,
            zip(self.af2_file_dict, cycle([self.add_sidechain])))
        print(f'Finished loading {len(self.af2_coord_dict)} af2 coords')

        if self.add_plddt:
            self.af2_plddt_dict = workers.starmap(
                utils.get_plddt_from_af2, zip(self.af2_file_dict))
        else:
            self.af2_plddt_dict = None
        print('Finished loading plddt')

        if self.add_plddt and self.add_confidence and self.loaded_confidence:
            self.af2_confidence_dict = workers.starmap(
                utils.get_confidence_from_af2file,
                zip(self.af2_file_dict, self.af2_plddt_dict))
        else:
            self.af2_confidence_dict = None
        print('Finished loading confidence')

        if self.add_dssp:
            self.af2_dssp_dict = workers.starmap(
                utils.get_dssp_from_af2, zip(self.af2_file_dict))
        else:
            self.af2_dssp_dict = None
        print('Finished loading dssp')

    self.radius = radius
    self.max_neighbors = max_neighbors
    self.loop = loop
    self.gpu_id = gpu_id
|
|
|
def _load_esm_features(self):
    """Preload ESM embeddings for every unique ``ESM_prefix``.

    Each mutation receives the index of its entry in ``self.esm_file_dict``;
    the embeddings themselves are stored in ``self.esm_dict``.
    """
    prefixes = [m.ESM_prefix for m in self.mutations]
    self.esm_file_dict, mutation_idx = np.unique(prefixes, return_inverse=True)
    for mutation, seq_index in zip(self.mutations, mutation_idx):
        mutation.set_esm_seq_index(seq_index)

    with Pool(NUM_THREADS) as workers:
        self.esm_dict = workers.starmap(
            utils.get_esm_dict_from_uniprot, zip(self.esm_file_dict))
    print(f'Finished loading {len(self.esm_file_dict)} esm embeddings')
|
|
|
def _load_af2_reps(self):
    """Preload AlphaFold2 single representations for every unique file prefix.

    Each mutation receives the index of its prefix in
    ``self.af2_rep_file_prefix_dict``.  Single representations are loaded into
    ``self.af2_single_dict`` only when both ``add_af2_single`` and
    ``loaded_af2_single`` are set.

    Raises:
        ValueError: If preloaded pairwise representations are requested; that
            path is not implemented.
    """
    # Reject the unsupported configuration before doing any expensive work
    # (previously this was raised only after the single reps were loaded).
    if self.add_af2_pairwise and self.loaded_af2_pairwise:
        raise ValueError("Not implemented in this version")

    prefixes = [m.af2_rep_file_prefix for m in self.mutations]
    self.af2_rep_file_prefix_dict, mutation_idx = np.unique(prefixes, return_inverse=True)
    for mutation, rep_index in zip(self.mutations, mutation_idx):
        mutation.set_af2_rep_index(rep_index)

    # Only start worker processes when there is something to load.
    if self.add_af2_single and self.loaded_af2_single:
        with Pool(NUM_THREADS) as p:
            self.af2_single_dict = p.starmap(
                utils.get_af2_single_rep_dict_from_prefix,
                zip(self.af2_rep_file_prefix_dict))
        print(f'Finished loading {len(self.af2_rep_file_prefix_dict)} alphafold2 single representations')
|
|
|
def _load_msa_features(self):
    """Preload MSA entries for every unique ``uniprot_id``.

    Each mutation receives the index of its entry in ``self.msa_file_dict``;
    the loaded entries go to ``self.msa_dict``.  The pool is created from the
    ``'spawn'`` multiprocessing context.
    """
    uniprot_ids = [m.uniprot_id for m in self.mutations]
    self.msa_file_dict, mutation_idx = np.unique(uniprot_ids, return_inverse=True)
    for mutation, seq_index in zip(self.mutations, mutation_idx):
        mutation.set_msa_seq_index(seq_index)

    with get_context('spawn').Pool(NUM_THREADS) as workers:
        self.msa_dict = workers.starmap(
            utils.get_msa_dict_from_transcript, zip(self.msa_file_dict))
    print(f'Finished loading {len(self.msa_dict)} msa seqs')
|
|
|
def _set_node_embeddings(self):
    """Placeholder called from ``__init__``; currently does nothing."""
    return None
|
|
|
def _set_edge_embeddings(self):
    """Placeholder called from ``__init__``; currently does nothing."""
    return None
|
|
|
def get_mask(self, mutation: utils.Mutation):
    """Return ``(mutation.pos - 1, mutation)``: the shifted position(s) and the mutation itself."""
    shifted_pos = mutation.pos - 1
    return shifted_pos, mutation
|
|
|
def get_graph_and_mask(self, mutation: utils.Mutation):
    """Collect coordinates, dense pairwise edge features and the mask index.

    Edge features start from the MSA contact map (or a zero-width
    placeholder), optionally extended with the AlphaFold2 pairwise
    representation and a sinusoidal relative-position channel.

    Args:
        mutation: The mutation to build features for.

    Returns:
        ``(coords, None, None, edge_attr, None, mask_idx, mutation)``; the
        None slots are the edge-index / star-graph entries, which this
        method does not fill.
    """
    coords: np.ndarray = self.af2_coord_dict[mutation.af2_seq_index]
    if mutation.crop:
        coords = coords[mutation.seq_start - 1:mutation.seq_end, :]

    mask_idx, mutation = self.get_mask(mutation)

    window = mutation.seq_end - mutation.seq_start + 1
    if self.add_msa_contacts:
        coevo_strength = utils.get_contacts_from_msa(mutation, False)
        # An int return value is replaced by an all-zero contact map.
        if isinstance(coevo_strength, int):
            coevo_strength = np.zeros([window, window, 1])
    else:
        coevo_strength = np.zeros([window, window, 0])

    start = time.time()
    if self.add_af2_pairwise:
        if self.loaded_af2_pairwise:
            key = mutation.af2_rep_file_prefix.split('/')[-1].encode('ascii')
            byteflow = self.af2_pairwise_txn.get(key)
            # Check for a missing key before unpickling: pickle.loads(None)
            # raises TypeError, which made the fallback below unreachable.
            pairwise_rep = None if byteflow is None else pickle.loads(byteflow)
            if pairwise_rep is None:
                pairwise_rep = utils.get_af2_pairwise_rep_dict_from_prefix(mutation.af2_rep_file_prefix)
        else:
            pairwise_rep = utils.get_af2_pairwise_rep_dict_from_prefix(mutation.af2_rep_file_prefix)

        # Prefixes without '-F' are first cut to the original sequence window.
        if mutation.af2_rep_file_prefix.find('-F') == -1:
            pairwise_rep = pairwise_rep[mutation.seq_start_orig - 1: mutation.seq_end_orig,
                                        mutation.seq_start_orig - 1: mutation.seq_end_orig]
        if mutation.crop:
            pairwise_rep = pairwise_rep[mutation.seq_start - 1: mutation.seq_end,
                                        mutation.seq_start - 1: mutation.seq_end]
        coevo_strength = np.concatenate([coevo_strength, pairwise_rep], axis=2)
    end = time.time()
    print(f'Finished loading pairwise in {end - start:.2f} seconds')
    edge_attr = coevo_strength

    if self.add_position:
        # Signed sequence separation, squashed with a sine scaled by max_len.
        edge_position = np.arange(coords.shape[0])[:, None] - np.arange(coords.shape[0])[None, :]
        edge_attr = np.concatenate(
            (edge_attr, np.sin(np.pi / 2 * edge_position / self.max_len)[:, :, None]),
            axis=2)
    return coords, None, None, edge_attr, None, mask_idx, mutation
|
|
|
def get_one_mutation(self, idx):
    """Build the un-padded NumPy feature dictionary for one mutation.

    The base per-residue embedding is chosen by ``self.node_embedding_type``.
    Conservation, pLDDT, DSSP, PTM, AlphaFold2 single-representation and MSA
    columns are appended when the matching ``add_*`` flag is set, and the
    rows at the mutated positions are written according to ``self.alt_type``.

    Args:
        idx: Index into ``self.mutations``.

    Returns:
        dict with keys ``embed_logits``, ``one_hot_mat``, ``mask_idx``,
        ``embed_data``, ``embed_data_mask``, ``alt_embed_data``, ``coords``,
        ``CA_coord``, ``CB_coord``, ``edge_index``, ``edge_index_star``,
        ``edge_attr``, ``edge_attr_star``, ``nodes_vector`` and, when
        ``self.add_confidence`` is set, ``plddt`` and ``edge_confidence``.

    Raises:
        ValueError: If ``node_embedding_type`` or ``alt_type`` is unsupported.
    """
    mutation: utils.Mutation = self.mutations[idx]
    coords, _, _, edge_attr, _, mask_idx, mutation = self.get_graph_and_mask(mutation)

    # Base embedding plus one alt row and one ref row per mutated residue.
    if self.node_embedding_type == 'esm':
        if self.loaded_esm:
            # NOTE(review): this slice is shifted by one compared with the
            # other per-residue tracks -- presumably the cached array has a
            # leading token; confirm against utils.get_esm_dict_from_uniprot.
            embed_data = self.esm_dict[mutation.esm_seq_index][mutation.seq_start:mutation.seq_end + 1]
        else:
            embed_data = utils.get_embedding_from_esm2(mutation.ESM_prefix, False,
                                                       mutation.seq_start, mutation.seq_end)
        to_alt = np.concatenate([utils.ESM_AA_EMBEDDING_DICT[alt_aa].reshape(1, -1) for alt_aa in mutation.alt_aa])
        to_ref = np.concatenate([utils.ESM_AA_EMBEDDING_DICT[ref_aa].reshape(1, -1) for ref_aa in mutation.ref_aa])
    elif self.node_embedding_type == 'one-hot-idx':
        assert not self.add_conservation and not self.add_plddt
        embed_logits, embed_data, one_hot_mat = utils.get_embedding_from_onehot_nonzero(mutation.seq, return_idx=True, return_onehot_mat=True)
        to_alt = np.concatenate([np.array(utils.AA_DICT.index(alt_aa)).reshape(1, -1) for alt_aa in mutation.alt_aa])
        to_ref = np.concatenate([np.array(utils.AA_DICT.index(ref_aa)).reshape(1, -1) for ref_aa in mutation.ref_aa])
    elif self.node_embedding_type == 'one-hot':
        embed_data, one_hot_mat = utils.get_embedding_from_onehot(mutation.seq, return_idx=False, return_onehot_mat=True)
        to_alt = np.concatenate([np.eye(len(utils.AA_DICT))[utils.AA_DICT.index(alt_aa)].reshape(1, -1) for alt_aa in mutation.alt_aa])
        to_ref = np.concatenate([np.eye(len(utils.AA_DICT))[utils.AA_DICT.index(ref_aa)].reshape(1, -1) for ref_aa in mutation.ref_aa])
    elif self.node_embedding_type == 'aa-5dim':
        embed_data = utils.get_embedding_from_5dim(mutation.seq)
        to_alt = np.concatenate([np.array(utils.AA_5DIM_EMBED[alt_aa]).reshape(1, -1) for alt_aa in mutation.alt_aa])
        to_ref = np.concatenate([np.array(utils.AA_5DIM_EMBED[ref_aa]).reshape(1, -1) for ref_aa in mutation.ref_aa])
    elif self.node_embedding_type == 'esm1b':
        embed_data = utils.get_embedding_from_esm1b(mutation.ESM_prefix, False,
                                                    mutation.seq_start, mutation.seq_end)
        to_alt = np.concatenate([utils.ESM1b_AA_EMBEDDING_DICT[alt_aa].reshape(1, -1) for alt_aa in mutation.alt_aa])
        to_ref = np.concatenate([utils.ESM1b_AA_EMBEDDING_DICT[ref_aa].reshape(1, -1) for ref_aa in mutation.ref_aa])
    else:
        # Previously an unknown type surfaced later as a NameError on embed_data.
        raise ValueError(f'node_embedding_type {self.node_embedding_type} not supported')

    # MSA-derived tracks: preloaded table, or read on demand.
    if self.loaded_msa and (self.add_msa or self.add_conservation):
        msa_seq = self.msa_dict[mutation.msa_seq_index][0]
        conservation_data = self.msa_dict[mutation.msa_seq_index][1]
        msa_data = self.msa_dict[mutation.msa_seq_index][2]
    else:
        if self.add_conservation or self.add_msa:
            msa_seq, conservation_data, msa_data = utils.get_msa_dict_from_transcript(mutation.uniprot_id)

    if self.add_conservation:
        if conservation_data.shape[0] == 0:
            conservation_data = np.zeros((embed_data.shape[0], 20))
        else:
            msa_seq_check = msa_seq[mutation.seq_start_orig - 1: mutation.seq_end_orig]
            conservation_data = conservation_data[mutation.seq_start_orig - 1: mutation.seq_end_orig]
            if mutation.crop:
                msa_seq_check = msa_seq_check[mutation.seq_start - 1: mutation.seq_end]
                conservation_data = conservation_data[mutation.seq_start - 1: mutation.seq_end]
            if msa_seq_check != mutation.seq:
                # MSA sequence disagrees with the mutation sequence: use zeros.
                self.unmatched_msa += 1
                print(f'Unmatched MSA: {self.unmatched_msa}')
                conservation_data = np.zeros((embed_data.shape[0], 20))
        embed_data = np.concatenate([embed_data, conservation_data], axis=1)
        to_alt = np.concatenate([to_alt, conservation_data[mask_idx]], axis=1)
        if self.alt_type == 'diff':
            to_ref = np.concatenate([to_ref, conservation_data[mask_idx]], axis=1)

    if self.add_plddt:
        plddt_data = self.af2_plddt_dict[mutation.af2_seq_index]
        if mutation.crop:
            plddt_data = plddt_data[mutation.seq_start - 1: mutation.seq_end]
        if self.add_confidence:
            confidence_data = plddt_data / 100
        if plddt_data.shape[0] != embed_data.shape[0]:
            warnings.warn(f'pLDDT {plddt_data.shape[0]} does not match embedding {embed_data.shape[0]}, '
                          f'pLDDT file: {mutation.af2_file}, '
                          f'ESM prefix: {mutation.ESM_prefix}')
            plddt_data = np.ones_like(embed_data[:, 0]) * 50
            if self.add_confidence:
                confidence_data = np.ones_like(embed_data[:, 0]) / 2
        if self.scale_plddt:
            plddt_data = plddt_data / 100
        embed_data = np.concatenate([embed_data, plddt_data[:, None]], axis=1)
        to_alt = np.concatenate([to_alt, plddt_data[mask_idx, None]], axis=1)
        if self.alt_type == 'diff':
            # plddt_data is 1-D, so it needs the same trailing axis as for
            # to_alt; concatenating the bare 1-D slice onto the 2-D to_ref
            # raised a dimension error.
            to_ref = np.concatenate([to_ref, plddt_data[mask_idx, None]], axis=1)

    if self.add_dssp:
        dssp_data = self.af2_dssp_dict[mutation.af2_seq_index]
        if mutation.crop:
            dssp_data = dssp_data[mutation.seq_start - 1: mutation.seq_end]
        if dssp_data.shape[0] != embed_data.shape[0]:
            warnings.warn(f'DSSP {dssp_data.shape[0]} does not match embedding {embed_data.shape[0]}, '
                          f'DSSP file: {mutation.af2_file}, '
                          f'ESM prefix: {mutation.ESM_prefix}')
            dssp_data = np.zeros_like(embed_data[:, 0])
        if len(dssp_data.shape) == 1:
            dssp_data = dssp_data[:, None]
        embed_data = np.concatenate([embed_data, dssp_data], axis=1)
        to_alt = np.concatenate([to_alt, dssp_data[mask_idx]], axis=1)
        if self.alt_type == 'diff':
            to_ref = np.concatenate([to_ref, dssp_data[mask_idx]], axis=1)

    if self.add_ptm:
        ptm_data = utils.get_ptm_from_mutation(mutation, self.ptm_ref)
        embed_data = np.concatenate([embed_data, ptm_data], axis=1)
        to_alt = np.concatenate([to_alt, ptm_data[mask_idx]], axis=1)
        if self.alt_type == 'diff':
            to_ref = np.concatenate([to_ref, ptm_data[mask_idx]], axis=1)

    if self.add_af2_single:
        if self.loaded_af2_single:
            single_rep = self.af2_single_dict[mutation.af2_rep_index]
        else:
            single_rep = utils.get_af2_single_rep_dict_from_prefix(mutation.af2_rep_file_prefix)
        # Prefixes without '-F' are first cut to the original sequence window.
        if mutation.af2_rep_file_prefix.find('-F') == -1:
            single_rep = single_rep[mutation.seq_start_orig - 1: mutation.seq_end_orig]
        if mutation.crop:
            single_rep = single_rep[mutation.seq_start - 1: mutation.seq_end]
        embed_data = np.concatenate([embed_data, single_rep], axis=1)
        to_alt = np.concatenate([to_alt, single_rep[mask_idx]], axis=1)
        if self.alt_type == 'diff':
            to_ref = np.concatenate([to_ref, single_rep[mask_idx]], axis=1)

    if self.add_msa:
        if msa_data.shape[0] == 0:
            msa_data = np.zeros((embed_data.shape[0], 199))
        else:
            msa_seq_check = msa_seq[mutation.seq_start_orig - 1: mutation.seq_end_orig]
            msa_data = msa_data[mutation.seq_start_orig - 1: mutation.seq_end_orig]
            if mutation.crop:
                msa_seq_check = msa_seq_check[mutation.seq_start - 1: mutation.seq_end]
                msa_data = msa_data[mutation.seq_start - 1: mutation.seq_end]
            if msa_seq_check != mutation.seq:
                msa_data = np.zeros((embed_data.shape[0], 199))
        embed_data = np.concatenate([embed_data, msa_data], axis=1)
        # 'diff' writes to_alt into a row of embed_data's width further down,
        # so it needs the MSA columns as well; without them that assignment
        # failed with a shape mismatch.  'concat' keeps to_alt unextended.
        if self.alt_type in ('alt', 'diff'):
            to_alt = np.concatenate([to_alt, msa_data[mask_idx]], axis=1)
        if self.alt_type == 'diff':
            to_ref = np.concatenate([to_ref, msa_data[mask_idx]], axis=1)

    # Mask is 1 everywhere except at the mutated rows.
    embed_data_mask = np.ones_like(embed_data)
    embed_data_mask[mask_idx] = 0
    if self.alt_type == 'alt':
        alt_embed_data = np.zeros_like(embed_data)
        alt_embed_data[mask_idx] = to_alt
    elif self.alt_type == 'concat':
        alt_embed_data = np.zeros((embed_data.shape[0], to_alt.shape[1] + to_ref.shape[1]))
        alt_embed_data[mask_idx] = np.concatenate([to_alt, to_ref], axis=1)
    elif self.alt_type == 'diff':
        alt_embed_data = np.zeros_like(embed_data)
        alt_embed_data[mask_idx] = to_alt
        embed_data[mask_idx] = to_ref
    else:
        raise ValueError(f'alt_type {self.alt_type} not supported')

    # coords is indexed as [residue, atom, xyz]; atom slots 0..4 are used
    # below as N, C, O, CA, CB (inferred from the variable names).
    CA_coord = coords[:, 3]
    CB_coord = coords[:, 4]
    # Where CB is NaN, fall back to CA.  CA_coord/CB_coord are views, so this
    # writes into coords as well.
    CB_coord[np.isnan(CB_coord)] = CA_coord[np.isnan(CB_coord)]
    if self.graph_type == '1d-neighbor':
        CA_coord[:, 0] = np.arange(coords.shape[0])
        CB_coord[:, 0] = np.arange(coords.shape[0])
        coords = np.zeros_like(coords)
    CA_CB = coords[:, [4]] - coords[:, [3]]
    CA_CB[np.isnan(CA_CB)] = 0
    CA_C = coords[:, [1]] - coords[:, [3]]
    CA_O = coords[:, [2]] - coords[:, [3]]
    CA_N = coords[:, [0]] - coords[:, [3]]
    nodes_vector = np.transpose(np.concatenate([CA_CB, CA_C, CA_O, CA_N], axis=1), (0, 2, 1))
    if self.add_sidechain:
        sidechain_nodes_vector = coords[:, 5:] - coords[:, [3]]
        sidechain_nodes_vector[np.isnan(sidechain_nodes_vector)] = 0
        sidechain_nodes_vector = np.transpose(sidechain_nodes_vector, (0, 2, 1))
        nodes_vector = np.concatenate([nodes_vector, sidechain_nodes_vector], axis=2)

    features = dict(
        embed_logits=embed_logits if self.node_embedding_type == 'one-hot-idx' else None,
        one_hot_mat=one_hot_mat if self.node_embedding_type.startswith('one-hot') else None,
        mask_idx=mask_idx,
        embed_data=embed_data,
        embed_data_mask=embed_data_mask,
        alt_embed_data=alt_embed_data,
        coords=coords,
        CA_coord=CA_coord,
        CB_coord=CB_coord,
        edge_index=None,
        edge_index_star=None,
        edge_attr=edge_attr,
        edge_attr_star=None,
        nodes_vector=nodes_vector,
    )

    if self.add_confidence:
        if self.add_plddt:
            # confidence_data was prepared in the pLDDT section above.
            features['plddt'] = confidence_data
            if self.loaded_confidence:
                pae = self.af2_confidence_dict[mutation.af2_seq_index]
            else:
                pae = utils.get_confidence_from_af2file(mutation.af2_file, self.af2_plddt_dict[mutation.af2_seq_index])
            if mutation.crop:
                pae = pae[mutation.seq_start - 1: mutation.seq_end, mutation.seq_start - 1: mutation.seq_end]
        else:
            plddt_data = utils.get_plddt_from_af2(mutation.af2_file)
            pae = utils.get_confidence_from_af2file(mutation.af2_file, plddt_data)
            # Define the confidence track for uncropped proteins too; it used
            # to be assigned only inside the crop branch, which left it
            # undefined otherwise.
            confidence_data = plddt_data / 100
            if mutation.crop:
                confidence_data = confidence_data[mutation.seq_start - 1: mutation.seq_end]
                pae = pae[mutation.seq_start - 1: mutation.seq_end, mutation.seq_start - 1: mutation.seq_end]
            if confidence_data.shape[0] != embed_data.shape[0]:
                warnings.warn(f'pLDDT {confidence_data.shape[0]} does not match embedding {embed_data.shape[0]}, '
                              f'pLDDT file: {mutation.af2_file}, '
                              f'ESM prefix: {mutation.ESM_prefix}')
                confidence_data = np.ones_like(embed_data[:, 0]) * 0.8
            features['plddt'] = confidence_data
        features['edge_confidence'] = pae
    return features
|
|
|
def get(self, idx):
    """Build sample ``idx`` as a dict of tensors padded to ``self.max_len``.

    The NumPy features from ``get_one_mutation`` are converted to tensors.
    When the sequence is shorter than ``self.max_len`` every per-residue and
    pairwise tensor is zero-padded along its residue axes; ``x_mask`` is
    padded with True and ``x_padding_mask`` marks the padded rows.

    Returns:
        dict with ``x``, ``x_padding_mask``, ``x_mask``, ``x_alt``, ``pos``,
        ``edge_attr``, ``node_vec_attr``, ``y`` and, when
        ``self.add_confidence`` is set, ``plddt`` and ``edge_confidence``.
    """
    start_time = time.time()
    features_np = self.get_one_mutation(idx)

    def _as_float(key):
        return torch.from_numpy(features_np[key]).to(torch.float32)

    x_dtype = torch.long if self.node_embedding_type == 'one-hot-idx' else torch.float32
    x = torch.from_numpy(features_np['embed_data']).to(x_dtype)
    x_padding_mask = torch.zeros(self.max_len, dtype=torch.bool)
    pos = _as_float('CB_coord') if self.use_cb else _as_float('CA_coord')
    node_vec_attr = _as_float('nodes_vector')
    edge_attr = _as_float('edge_attr')
    x_mask = torch.from_numpy(features_np['embed_data_mask'][:, 0]).to(torch.bool)
    x_alt = _as_float('alt_embed_data')
    if self.add_confidence:
        plddt = _as_float('plddt')
        edge_confidence = _as_float('edge_confidence')

    if x.shape[0] < self.max_len:
        pad = torch.nn.functional.pad

        def _missing(tensor):
            # Rows still needed to reach max_len along the first axis.
            return self.max_len - tensor.shape[0]

        x_padding_mask[x.shape[0]:] = True
        x = pad(x, (0, 0, 0, _missing(x)))
        pos = pad(pos, (0, 0, 0, _missing(pos)))
        node_vec_attr = pad(node_vec_attr, (0, 0, 0, 0, 0, _missing(node_vec_attr)))
        edge_attr = pad(edge_attr, (0, 0, 0, _missing(edge_attr), 0, _missing(edge_attr)))
        x_alt = pad(x_alt, (0, 0, 0, _missing(x_alt)))
        x_mask = pad(x_mask, (0, _missing(x_mask)), 'constant', True)
        if self.add_confidence:
            edge_confidence = pad(
                edge_confidence,
                (0, _missing(edge_confidence), 0, _missing(edge_confidence)))
            plddt = pad(plddt, (0, _missing(plddt)))

    target = torch.tensor([self.data[self._y_columns].iloc[int(idx)]]).to(torch.float32).unsqueeze(0)
    features = {
        'x': x,
        'x_padding_mask': x_padding_mask,
        'x_mask': x_mask,
        'x_alt': x_alt,
        'pos': pos,
        'edge_attr': edge_attr,
        'node_vec_attr': node_vec_attr,
        'y': target,
    }
    if self.add_confidence:
        features['plddt'] = plddt
        features['edge_confidence'] = edge_confidence
    print(f'Finished loading {idx}th mutation in {time.time() - start_time} seconds')
    return features
|
|
|
def get_from_hdf5(self, idx):
    """Load one cached sample from the HDF5 file and wrap it in a ``Data`` object.

    Args:
        idx: Position of the sample; it is translated to the on-disk group
            name through ``self.hdf5_idx_map``.

    Returns:
        Data: built from one tensor per name listed in ``self.hdf5_keys``.

    Raises:
        ValueError: if ``hdf5_keys`` was never set or ``hdf5_file`` is None.
    """
    cache_ready = hasattr(self, 'hdf5_keys') and self.hdf5_file is not None
    if not cache_ready:
        raise ValueError('hdf5 file is not set')

    # Each sample lives in its own group "<mapped idx>/<key>".
    with h5py.File(self.hdf5_file, 'r') as handle:
        tensors = {
            name: torch.tensor(handle[f'{self.hdf5_idx_map[idx]}/{name}'])
            for name in self.hdf5_keys
        }
    return Data(**tensors)
|
|
|
def open_lmdb(self):
    """Open the LMDB file at ``self.lmdb_path`` read-only and start a read transaction.

    Stores the environment on ``self.env`` and the transaction on ``self.txn``.
    """
    env_options = dict(
        subdir=False,     # lmdb_path is a single file, not a directory
        readonly=True,
        lock=False,
        readahead=False,
        meminit=False,
    )
    self.env = lmdb.open(self.lmdb_path, **env_options)
    # buffers=True makes txn.get() return buffer objects instead of bytes.
    self.txn = self.env.begin(write=False, buffers=True)
|
|
|
def get_from_lmdb(self, idx):
    """Fetch one sample from the LMDB cache, opening the database lazily.

    Args:
        idx: Position of the sample; ``self.lmdb_idx_map[idx]`` is the key
            stored in the database (ASCII-encoded).

    Returns:
        The unpickled cached object, or ``self.get(idx)`` when the key is
        not present in the database.
    """
    if getattr(self, 'txn', None) is None:
        self.open_lmdb()

    key = f'{self.lmdb_idx_map[idx]}'.encode('ascii')
    payload = self.txn.get(key)
    if payload is None:
        # Cache miss: build the sample the regular way.
        return self.get(idx)
    # NOTE: pickle.loads executes arbitrary code from the payload; only point
    # lmdb_path at databases written by create_lmdb / trusted sources.
    return pickle.loads(payload)
|
|
|
def __getitem__(self, idx):
    """Return sample ``idx`` using the backend selected by ``self.get_method``.

    Args:
        idx: Index of the sample to load.

    Returns:
        Whatever the selected loader returns (``get``, ``get_from_hdf5`` or
        ``get_from_lmdb``).

    Raises:
        ValueError: if ``self.get_method`` is not 'default', 'hdf5' or 'lmdb'
            (previously this surfaced as an UnboundLocalError on ``data``).
    """
    start = time.time()
    method = self.get_method
    if method == 'default':
        data = self.get(idx)
    elif method == 'hdf5':
        data = self.get_from_hdf5(idx)
    elif method == 'lmdb':
        data = self.get_from_lmdb(idx)
    else:
        raise ValueError(
            f"Unknown get_method {method!r}; expected 'default', 'hdf5' or 'lmdb'")
    # Same timing message as before, emitted once for whichever backend ran.
    print(f'{method} Finished loading {idx} in {time.time() - start:.2f} seconds')
    return data
|
|
|
def __len__(self):
    """Return the number of mutations currently held in ``self.mutations``."""
    return len(self.mutations)
|
|
|
def len(self) -> int:
    """Return the number of mutations currently held in ``self.mutations``."""
    return len(self.mutations)
|
|
|
def subset(self, idxs):
    """Restrict the dataset in place to the rows in ``idxs`` and return ``self``.

    Besides ``self.data`` and ``self.mutations``, every per-file cache
    (AF2, ESM, MSA) that is not None is trimmed to the files still referenced,
    and each remaining mutation is given its new index into the trimmed cache.

    Args:
        idxs: Positional indices of the samples to keep, in the desired order.

    Returns:
        self
    """
    self.data = self.data.iloc[idxs].reset_index(drop=True)
    self.mutations = [self.mutations[i] for i in idxs]

    # np.unique gives the surviving file names plus, for every mutation, its
    # position inside that (sorted) list of names.
    kept_af2_files, af2_positions = np.unique(
        [m.af2_file for m in self.mutations], return_inverse=True)

    if self.af2_file_dict is not None:
        # Row of each surviving file inside the old per-file caches.
        af2_rows = np.array(
            [np.where(self.af2_file_dict == name)[0][0] for name in kept_af2_files])
        self.af2_file_dict = kept_af2_files

        for cache_name in ('af2_coord_dict', 'af2_plddt_dict', 'af2_confidence_dict',
                           'af2_dssp_dict', 'af2_graph_dict'):
            cache = getattr(self, cache_name)
            if cache is not None:
                setattr(self, cache_name, [cache[row] for row in af2_rows])

        for mutation, new_index in zip(self.mutations, af2_positions):
            mutation.set_af2_seq_index(new_index)

    if self.esm_file_dict is not None:
        kept_esm_files, esm_positions = np.unique(
            [m.ESM_prefix for m in self.mutations], return_inverse=True)
        esm_rows = np.array(
            [np.where(self.esm_file_dict == name)[0][0] for name in kept_esm_files])
        self.esm_file_dict = kept_esm_files

        if self.esm_dict is not None:
            self.esm_dict = [self.esm_dict[row] for row in esm_rows]

        for mutation, new_index in zip(self.mutations, esm_positions):
            mutation.set_esm_seq_index(new_index)

    if self.msa_file_dict is not None:
        kept_msa_files, msa_positions = np.unique(
            [m.uniprot_id for m in self.mutations], return_inverse=True)
        msa_rows = np.array(
            [np.where(self.msa_file_dict == name)[0][0] for name in kept_msa_files])
        self.msa_file_dict = kept_msa_files

        if self.msa_dict is not None:
            self.msa_dict = [self.msa_dict[row] for row in msa_rows]

        for mutation, new_index in zip(self.mutations, msa_positions):
            mutation.set_msa_seq_index(new_index)
    return self
|
|
|
def shuffle(self, idxs):
    """Reorder ``self.data`` and ``self.mutations`` in place by ``idxs``.

    Args:
        idxs: Positional indices giving the new order; the DataFrame index is
            reset afterwards so rows are numbered 0..n-1 again.
    """
    reordered_rows = self.data.iloc[idxs]
    self.data = reordered_rows.reset_index(drop=True)
    self.mutations = [self.mutations[i] for i in idxs]
|
|
|
def get_label_counts(self) -> np.ndarray:
    """Count how many rows of ``self.data`` carry each value of the 'score' column.

    Returns:
        np.ndarray:
            * ``[0, 0]`` when there is no 'score' column;
            * ``[#(-1), #(0), #(1), #(3)]`` when -1 occurs and both the -1 and
              the 1 counts are non-zero;
            * ``[#(0), #(3)]`` when -1 occurs but 1 does not;
            * ``[#(0), #(1)]`` when -1 does not occur.
    """
    if not self.data.columns.isin(['score']).any():
        return np.array([0, 0])

    scores = self.data['score']

    def occurrences(label):
        return (scores == label).sum()

    if -1 not in scores.values:
        # Two-class labelling: 0 vs 1.
        return np.array([occurrences(0), occurrences(1)])

    # Four-value labelling (-1 / 0 / 1 / 3) -- names follow the original code.
    lof = occurrences(-1)
    benign = occurrences(0)
    gof = occurrences(1)
    patho = occurrences(3)
    if lof != 0 and gof != 0:
        return np.array([lof, benign, gof, patho])
    return np.array([benign, patho])
|
|
|
|
|
def create_hdf5(self):
    """Materialise every sample into an HDF5 file next to ``self.data_file``.

    The output path is ``self.data_file`` with '.csv' replaced by '.hdf5'.
    Each sample ``i`` returned by ``self.get(i)`` is written as one dataset per
    key under ``"<i>/<key>"``. The method also switches ``self.get_method`` to
    'hdf5' and records ``hdf5_file``, ``hdf5_keys`` (keys of the first sample)
    and an identity ``hdf5_idx_map``.
    """
    target_path = self.data_file.replace('.csv', '.hdf5')
    self.hdf5_file = target_path
    self.get_method = 'hdf5'
    self.hdf5_keys = None

    n_samples = len(self)
    self.hdf5_idx_map = np.arange(n_samples)
    with h5py.File(target_path, 'w') as store:
        for sample_idx in range(n_samples):
            sample = self.get(sample_idx)
            # Remember the key set from the first sample written.
            if self.hdf5_keys is None:
                self.hdf5_keys = list(sample.keys())
            for name in sample.keys():
                store.create_dataset(f'{sample_idx}/{name}', data=sample[name])
    return
|
|
|
|
|
def create_lmdb(self, write_frequency=1000):
    """Pickle every sample into a fresh LMDB file and switch the dataset to it.

    The database path is ``self.data_file`` with '.csv' replaced by
    '.<current datetime>.lmdb'. Sample ``i`` (from ``self.get(i)``) is stored
    under the ASCII key ``str(i)``. On success ``self.lmdb_path``,
    ``self.lmdb_idx_map`` (identity map) and ``self.get_method = 'lmdb'`` are set.

    Args:
        write_frequency: Commit the write transaction after this many samples.
    """
    lmdb_path = self.data_file.replace('.csv', f'.{datetime.now()}.lmdb')
    # Explicit int: do not rely on the binding coercing a float size.
    map_size = int(5e12)
    db = lmdb.open(lmdb_path, subdir=False, map_size=map_size, readonly=False, meminit=False, map_async=True)
    try:
        n_samples = len(self)
        print(f"Begin loading {n_samples} points into lmdb")
        txn = db.begin(write=True)
        for idx in range(n_samples):
            d = self.get(idx)
            txn.put(u'{}'.format(idx).encode('ascii'), pickle.dumps(d))
            print(f'Finished loading {idx}')
            # Commit in batches so a single transaction does not grow unbounded.
            if (idx + 1) % write_frequency == 0:
                txn.commit()
                txn = db.begin(write=True)
        txn.commit()
        print(f"Finished loading {n_samples} points into lmdb")
        print("Flushing database ...")
        db.sync()
    finally:
        # Always release the environment, even if building a sample failed.
        db.close()

    # Only advertise the cache once it has been written and flushed.
    self.lmdb_path = lmdb_path
    self.lmdb_idx_map = np.arange(n_samples)
    self.get_method = 'lmdb'
    return
|
|
|
|
|
def clean_up(self):
    """Delete the on-disk HDF5 and LMDB cache files, if they were created.

    An attribute that is missing, None, or points at a non-existent path is
    silently skipped.
    """
    for path_attr in ('hdf5_file', 'lmdb_path'):
        cache_path = getattr(self, path_attr, None)
        if cache_path is None:
            continue
        if os.path.exists(cache_path):
            os.remove(cache_path)
    return
|
|
|
|
|
class MutationDataset(GraphMutationDataset):
    """
    Mutation dataset that returns plain feature dictionaries without a graph.
    Can be either single mutation or multiple mutations.

    The graph-related entries of each sample (``edge_index``, ``edge_attr``,
    ``node_vec_attr`` and their ``*_star`` variants) are single-NaN
    placeholder tensors.

    Args:
        data_file (string or pd.DataFrame): Path or pd.DataFrame for a csv file for a list of mutations
        data_type (string): Type of this data, 'ClinVar', 'DMS', etc
        max_len (int): Length that samples are padded to when ``padding`` is True
        padding (bool): Zero-pad per-residue arrays up to ``max_len``
    """

    def __init__(self, data_file, data_type: str,
                 radius: float = None, max_neighbors: int = 50,
                 loop: bool = False, shuffle: bool = False, gpu_id: int = None,
                 node_embedding_type: Literal['esm', 'one-hot-idx', 'one-hot', 'aa-5dim'] = 'esm',
                 graph_type: Literal['af2', '1d-neighbor'] = 'af2',
                 precomputed_graph: bool = False,
                 add_plddt: bool = False,
                 add_conservation: bool = False,
                 add_position: bool = False,
                 add_msa_contacts: bool = True,
                 max_len: int = 700,
                 padding: bool = False,
                 ):
        self.padding = padding
        # The parent signature has grown extra parameters (scale_plddt,
        # add_sidechain, local_coord_transform, ...) between graph_type and
        # max_len, so the flags must be passed by keyword; passing them
        # positionally silently routes them to the wrong parent parameters.
        # NOTE(review): `precomputed_graph` is forwarded to the parent's
        # `computed_graph`, and `max_len` is assumed to be the parent's
        # parameter name -- confirm against GraphMutationDataset.__init__.
        super(MutationDataset, self).__init__(
            data_file, data_type, radius, max_neighbors, loop, shuffle, gpu_id,
            node_embedding_type, graph_type,
            add_plddt=add_plddt,
            add_conservation=add_conservation,
            add_position=add_position,
            add_msa_contacts=add_msa_contacts,
            computed_graph=precomputed_graph,
            max_len=max_len,
        )

    def __getitem__(self, idx):
        """Build the feature dictionary for sample ``idx``.

        When ``self.padding`` is set and the sample is shorter than
        ``self.max_len``, the embedding, coordinate, alt-embedding and mask
        arrays are zero-padded along the first axis, and ``y_mask`` marks the
        original positions with True and the padded ones with False.

        Returns:
            dict: tensors keyed by ``x``, ``x_mask``, ``x_alt``, ``pos``,
            the NaN graph placeholders, ``y`` and ``y_mask``.
        """
        features_np = self.get_one_mutation(idx)
        orig_len = features_np['embed_data'].shape[0]
        if self.padding and orig_len < self.max_len:
            pad_rows = self.max_len - orig_len
            # 2-D per-residue arrays: pad rows only.
            for key in ('embed_data', 'alt_embed_data', 'embed_data_mask'):
                features_np[key] = np.pad(features_np[key], ((0, pad_rows), (0, 0)), 'constant')
            # coords carries one extra trailing axis.
            features_np['coords'] = np.pad(features_np['coords'], ((0, pad_rows), (0, 0), (0, 0)), 'constant')
            y_mask = np.concatenate((np.ones(orig_len), np.zeros(pad_rows)))
        else:
            y_mask = np.ones(orig_len)

        # 'one-hot-idx' embeddings are integer indices; everything else is float.
        if self.node_embedding_type == 'one-hot-idx':
            x = torch.from_numpy(features_np['embed_data']).to(torch.long)
        else:
            x = torch.from_numpy(features_np['embed_data']).to(torch.float32)
        features = dict(
            x=x,
            x_mask=torch.from_numpy(features_np['embed_data_mask']).to(torch.bool),
            x_alt=torch.from_numpy(features_np['alt_embed_data']).to(torch.float32),
            pos=torch.from_numpy(features_np['coords']).to(torch.float32),
            edge_index=torch.tensor([torch.nan]),
            edge_index_star=torch.tensor([torch.nan]),
            edge_attr=torch.tensor([torch.nan]),
            edge_attr_star=torch.tensor([torch.nan]),
            node_vec_attr=torch.tensor([torch.nan]),
            y=torch.tensor(self.data[self._y_columns].iloc[int(idx)]).to(torch.float32),
            y_mask=torch.from_numpy(y_mask).to(torch.bool),
        )
        return features

    def get(self, idx):
        """Alias of ``__getitem__`` so both access paths return the same dict."""
        return self.__getitem__(idx)
|
|
|
|
|
class GraphMaskPredictMutationDataset(GraphMutationDataset):
    """
    Masked-residue prediction dataset that returns graph ``Data`` objects.
    Can be either single mutation or multiple mutations.

    Args:
        data_file (string or pd.DataFrame): Path or pd.DataFrame for a csv file for a list of mutations
        data_type (string): Type of this data, 'ClinVar', 'DMS', etc
        mask_percentage (float): Fraction of residues selected by ``get_mask``
    """

    def __init__(self, data_file, data_type: str,
                 radius: float = None, max_neighbors: int = 50,
                 loop: bool = False, shuffle: bool = False, gpu_id: int = None,
                 node_embedding_type: Literal['one-hot-idx', 'one-hot'] = 'one-hot-idx',
                 graph_type: Literal['af2', '1d-neighbor'] = 'af2',
                 add_plddt: bool = False,
                 add_conservation: bool = False,
                 add_position: bool = False,
                 add_msa_contacts: bool = True,
                 computed_graph: bool = True,
                 neighbor_type: Literal['KNN', 'radius'] = 'KNN',
                 max_len: int = 700,
                 mask_percentage: float = 0.15,
                 ):
        self.mask_percentage = mask_percentage
        # Keyword arguments: the parent signature has extra parameters
        # (scale_plddt, add_sidechain, ...) interleaved with these flags, so a
        # positional call shifts every value after add_plddt.
        # NOTE(review): `neighbor_type` and `max_len` are assumed to be the
        # parent's parameter names -- confirm against GraphMutationDataset.__init__.
        super(GraphMaskPredictMutationDataset, self).__init__(
            data_file, data_type, radius, max_neighbors, loop, shuffle, gpu_id,
            node_embedding_type, graph_type,
            add_plddt=add_plddt,
            add_conservation=add_conservation,
            add_position=add_position,
            add_msa_contacts=add_msa_contacts,
            computed_graph=computed_graph,
            neighbor_type=neighbor_type,
            max_len=max_len,
        )

    def get_mask(self, mutation: utils.Mutation):
        """Pick the residue positions to mask and rewrite the mutation accordingly.

        ``int(seq_len * self.mask_percentage)`` positions are masked. When the
        mutation has an ``alt_aa``, its own position (``pos - 1``) is always
        included and the remaining ones are drawn at random; otherwise all
        positions are random. ``mutation.ref_aa`` becomes the residues at the
        masked positions and ``mutation.alt_aa`` an array of '<mask>' tokens.

        Returns:
            tuple: ``(mask_idx, mutation)``.
        """
        seq_len = mutation.seq_end - mutation.seq_start + 1
        # Use the configured fraction everywhere (was hard-coded to 0.15 in
        # the sampling calls while the threshold used mask_percentage).
        points_to_mask = int(seq_len * self.mask_percentage)
        if not pd.isna(mutation.alt_aa):
            if points_to_mask > 1:
                # NOTE: the random draw may already contain pos - 1, in which
                # case the appended index is a duplicate.
                mask_idx = np.random.choice(seq_len, points_to_mask - 1, replace=False)
                mask_idx = np.append(mask_idx, mutation.pos - 1)
            else:
                mask_idx = np.array([mutation.pos - 1])
        else:
            mask_idx = np.random.choice(seq_len, points_to_mask, replace=False)
        mutation.ref_aa = np.array(list(mutation.seq))[mask_idx]
        mutation.alt_aa = np.array(['<mask>'] * (len(mask_idx)))
        return mask_idx, mutation

    def get(self, idx):
        """Build the graph sample for ``idx`` with per-residue logits as target ``y``.

        For a mutation with an ``alt_aa``, the target row at the mutated
        position is replaced by the average of the one-hot rows of the
        reference and alternate residues.

        Returns:
            Data: node features, both edge sets, and ``y``.
        """
        features_np = self.get_one_mutation(idx)
        embed_logits = features_np['embed_logits']
        one_hot_mat = features_np['one_hot_mat']
        # Fixed attribute typo (`self.mutaions`) that raised AttributeError.
        mutation: utils.Mutation = self.mutations[idx]

        if not pd.isna(mutation.alt_aa):
            embed_logits[mutation.pos - 1] = (one_hot_mat[utils.AA_DICT.index(mutation.ref_aa)]
                                              + one_hot_mat[utils.AA_DICT.index(mutation.alt_aa)]) / 2

        if self.node_embedding_type == 'one-hot-idx':
            x = torch.from_numpy(features_np['embed_data']).to(torch.long)
        else:
            x = torch.from_numpy(features_np['embed_data']).to(torch.float32)
        features = dict(
            x=x,
            x_mask=torch.from_numpy(features_np['embed_data_mask']).to(torch.bool),
            x_alt=torch.from_numpy(features_np['alt_embed_data']).to(torch.float32),
            pos=torch.from_numpy(features_np['CA_coord']).to(torch.float32),
            edge_index=torch.from_numpy(features_np['edge_index']).to(torch.long),
            edge_index_star=torch.from_numpy(features_np['edge_index_star']).to(torch.long),
            edge_attr=torch.from_numpy(features_np['edge_attr']).to(torch.float32),
            edge_attr_star=torch.from_numpy(features_np['edge_attr_star']).to(torch.float32),
            node_vec_attr=torch.from_numpy(features_np['nodes_vector']).to(torch.float32),
            y=torch.from_numpy(embed_logits).to(torch.float32),
        )
        features["edge_index"], features["edge_attr"], mask = \
            remove_isolated_nodes(features["edge_index"], features["edge_attr"], x.shape[0])
        # NOTE(review): `mask` from the star graph overwrites the one above and
        # is the one used to filter node features -- confirm this is intended.
        features["edge_index_star"], features["edge_attr_star"], mask = \
            remove_isolated_nodes(features["edge_index_star"], features["edge_attr_star"], x.shape[0])
        for node_key in ("x", "x_mask", "x_alt", "pos", "node_vec_attr"):
            features[node_key] = features[node_key][mask]
        return Data(**features)
|
|
|
|
|
class MaskPredictMutationDataset(GraphMaskPredictMutationDataset):
    """
    Masked-residue prediction dataset that returns plain feature dictionaries
    without a graph. Can be either single mutation or multiple mutations.

    ``get_mask`` is inherited from ``GraphMaskPredictMutationDataset``; the
    override that used to live here was a verbatim copy of the parent method.

    Args:
        data_file (string or pd.DataFrame): Path or pd.DataFrame for a csv file for a list of mutations
        data_type (string): Type of this data, 'ClinVar', 'DMS', etc
        max_len (int): Length that samples are padded to when ``padding`` is True
        padding (bool): Zero-pad per-residue arrays up to ``max_len``
        mask_percentage (float): Fraction of residues selected for masking
    """

    def __init__(self, data_file, data_type: str,
                 radius: float = None, max_neighbors: int = 50,
                 loop: bool = False, shuffle: bool = False, gpu_id: int = None,
                 node_embedding_type: Literal['one-hot-idx', 'one-hot'] = 'one-hot-idx',
                 graph_type: Literal['af2', '1d-neighbor'] = 'af2',
                 precomputed_graph: bool = False,
                 add_plddt: bool = False,
                 add_conservation: bool = False,
                 add_position: bool = False,
                 add_msa_contacts: bool = True,
                 max_len: int = 700,
                 padding: bool = False,
                 mask_percentage: float = 0.15,
                 ):
        self.padding = padding
        # Pass by keyword: positionally, `precomputed_graph` landed in the
        # parent's `add_plddt`, `max_len` in `neighbor_type` and
        # `mask_percentage` in `max_len`.
        # NOTE(review): `precomputed_graph` is forwarded to the parent's
        # `computed_graph` -- confirm this is the intended mapping.
        super(MaskPredictMutationDataset, self).__init__(
            data_file, data_type, radius, max_neighbors, loop, shuffle, gpu_id,
            node_embedding_type, graph_type,
            add_plddt=add_plddt,
            add_conservation=add_conservation,
            add_position=add_position,
            add_msa_contacts=add_msa_contacts,
            computed_graph=precomputed_graph,
            max_len=max_len,
            mask_percentage=mask_percentage,
        )

    def __getitem__(self, idx):
        """Build the feature dictionary for sample ``idx`` with logits as target ``y``.

        For a mutation with an ``alt_aa``, the target row at the mutated
        position is the average of the one-hot rows of the reference and
        alternate residues. With ``self.padding`` set, per-residue arrays are
        zero-padded to ``self.max_len`` and ``y_mask`` flags the real positions.

        Returns:
            dict: tensors keyed by ``x``, ``x_mask``, ``x_alt``, ``pos``,
            single-NaN graph placeholders, ``y`` and ``y_mask``.
        """
        features_np = self.get_one_mutation(idx)
        embed_logits = features_np['embed_logits']
        one_hot_mat = features_np['one_hot_mat']
        # Fixed attribute typo (`self.mutaions`) that raised AttributeError.
        mutation: utils.Mutation = self.mutations[idx]

        if not pd.isna(mutation.alt_aa):
            embed_logits[mutation.pos - 1] = (one_hot_mat[utils.AA_DICT.index(mutation.ref_aa)]
                                              + one_hot_mat[utils.AA_DICT.index(mutation.alt_aa)]) / 2

        orig_len = features_np['embed_data'].shape[0]
        if self.padding and orig_len < self.max_len:
            pad_rows = self.max_len - orig_len
            features_np['embed_data'] = np.pad(features_np['embed_data'], ((0, pad_rows), (0, 0)), 'constant')
            embed_logits = np.pad(embed_logits, ((0, pad_rows), (0, 0)), 'constant')
            features_np['coords'] = np.pad(features_np['coords'], ((0, pad_rows), (0, 0), (0, 0)), 'constant')
            features_np['alt_embed_data'] = np.pad(features_np['alt_embed_data'], ((0, pad_rows), (0, 0)), 'constant')
            features_np['embed_data_mask'] = np.pad(features_np['embed_data_mask'], ((0, pad_rows), (0, 0)), 'constant')
            y_mask = np.concatenate((np.ones(orig_len), np.zeros(pad_rows)))
        else:
            y_mask = np.ones(orig_len)

        if self.node_embedding_type == 'one-hot-idx':
            x = torch.from_numpy(features_np['embed_data']).to(torch.long)
        else:
            x = torch.from_numpy(features_np['embed_data']).to(torch.float32)
        features = dict(
            x=x,
            x_mask=torch.from_numpy(features_np['embed_data_mask']).to(torch.bool),
            x_alt=torch.from_numpy(features_np['alt_embed_data']).to(torch.float32),
            # NOTE(review): 'coords' is padded above but `pos` is taken from
            # the unpadded 'CA_coord' -- confirm which one is intended.
            pos=torch.from_numpy(features_np['CA_coord']).to(torch.float32),
            edge_index=torch.tensor([torch.nan]),
            edge_index_star=torch.tensor([torch.nan]),
            edge_attr=torch.tensor([torch.nan]),
            edge_attr_star=torch.tensor([torch.nan]),
            node_vec_attr=torch.tensor([torch.nan]),
            y=torch.from_numpy(embed_logits).to(torch.float32),
            y_mask=torch.from_numpy(y_mask).to(torch.bool),
        )
        return features
|
|
|
|
|
class GraphMultiOnesiteMutationDataset(GraphMutationDataset):
    """
    Graph dataset where one sample carries scores for several alternate amino
    acids at once.

    ``get`` returns a ``Data`` object whose target ``y`` has shape
    ``[1, len(utils.AA_DICT_HUMAN), n_score_columns]``; ``score_mask`` has the
    same shape and is filled from the 'confidence.score*' columns when present
    (ones otherwise).
    """

    def __init__(self, data_file, data_type: str,
                 radius: float = None, max_neighbors: int = None,
                 loop: bool = False, shuffle: bool = False, gpu_id: int = None,
                 node_embedding_type: Literal['esm', 'one-hot-idx', 'one-hot', 'aa-5dim', 'esm1b'] = 'esm',
                 graph_type: Literal['af2', '1d-neighbor'] = 'af2',
                 add_plddt: bool = False,
                 scale_plddt: bool = False,
                 add_conservation: bool = False,
                 add_position: bool = False,
                 add_sidechain: bool = False,
                 local_coord_transform: bool = False,
                 use_cb: bool = False,
                 add_msa_contacts: bool = True,
                 add_dssp: bool = False,
                 add_msa: bool = False,
                 add_confidence: bool = False,
                 loaded_confidence: bool = False,
                 loaded_esm: bool = False,
                 add_ptm: bool = False,
                 data_augment: bool = False,
                 score_transfer: bool = False,
                 alt_type: Literal['alt', 'concat', 'diff'] = 'alt',
                 computed_graph: bool = True,
                 loaded_msa: bool = False,
                 neighbor_type: Literal['KNN', 'radius', 'radius-KNN'] = 'KNN',
                 max_len = 2251,
                 convert_to_onesite: bool = False,
                 add_af2_single: bool = False,
                 add_af2_pairwise: bool = False,
                 loaded_af2_single: bool = False,
                 loaded_af2_pairwise: bool = False,
                 ):
        # NOTE(review): convert_to_onesite and the *_af2_single/_pairwise flags
        # are accepted but not forwarded to the parent -- confirm intended.
        super(GraphMultiOnesiteMutationDataset, self).__init__(
            data_file, data_type, radius, max_neighbors, loop, shuffle, gpu_id,
            node_embedding_type, graph_type, add_plddt, scale_plddt,
            add_conservation, add_position, add_sidechain,
            local_coord_transform, use_cb, add_msa_contacts, add_dssp,
            add_msa, add_confidence, loaded_confidence, loaded_esm,
            add_ptm, data_augment, score_transfer, alt_type,
            computed_graph, loaded_msa, neighbor_type, max_len)
        # Optional per-score mask columns.
        self._y_mask_columns = self.data.columns[self.data.columns.str.startswith('confidence.score')]

    def get_one_mutation(self, idx):
        """Assemble all numpy features for mutation ``idx``.

        Node embeddings are built according to ``self.node_embedding_type`` and
        then extended column-wise with conservation, pLDDT, DSSP, MSA and PTM
        features when the matching ``add_*`` flag is set. No alternate
        embedding is produced (``alt_embed_data`` is None).

        Returns:
            dict: numpy arrays for embeddings, masks, coordinates, both edge
            sets and node vectors; plus ``plddt``, ``edge_confidence`` and
            ``edge_confidence_star`` when ``self.add_confidence`` is set.
        """
        mutation: utils.Mutation = self.mutations[idx]

        coords, edge_index, edge_index_star, edge_attr, edge_attr_star, mask_idx, mutation = self.get_graph_and_mask(mutation)

        # ---- node embedding ------------------------------------------------
        if self.node_embedding_type == 'esm':
            if self.loaded_esm:
                # NOTE(review): slice is shifted by one relative to the other
                # 1-based slices below; presumably the cached ESM array has a
                # leading start token -- confirm.
                embed_data = self.esm_dict[mutation.esm_seq_index][mutation.seq_start:mutation.seq_end + 1]
            else:
                embed_data = utils.get_embedding_from_esm2(mutation.ESM_prefix, False,
                                                           mutation.seq_start, mutation.seq_end)
        elif self.node_embedding_type == 'one-hot-idx':
            assert not self.add_conservation and not self.add_plddt
            embed_logits, embed_data, one_hot_mat = utils.get_embedding_from_onehot_nonzero(mutation.seq, return_idx=True, return_onehot_mat=True)
        elif self.node_embedding_type == 'one-hot':
            embed_data, one_hot_mat = utils.get_embedding_from_onehot(mutation.seq, return_idx=False, return_onehot_mat=True)
        elif self.node_embedding_type == 'aa-5dim':
            embed_data = utils.get_embedding_from_5dim(mutation.seq)
        elif self.node_embedding_type == 'esm1b':
            embed_data = utils.get_embedding_from_esm1b(mutation.ESM_prefix, False,
                                                        mutation.seq_start, mutation.seq_end)

        # ---- MSA / conservation source ---------------------------------------
        if self.loaded_msa and (self.add_msa or self.add_conservation):
            msa_seq = self.msa_dict[mutation.msa_seq_index][0]
            conservation_data = self.msa_dict[mutation.msa_seq_index][1]
            msa_data = self.msa_dict[mutation.msa_seq_index][2]
        else:
            if self.add_conservation or self.add_msa:
                msa_seq, conservation_data, msa_data = utils.get_msa_dict_from_transcript(mutation.uniprot_id)

        if self.add_conservation:
            if conservation_data.shape[0] == 0:
                conservation_data = np.zeros((embed_data.shape[0], 20))
            else:
                msa_seq_check = msa_seq[mutation.seq_start_orig - 1: mutation.seq_end_orig]
                conservation_data = conservation_data[mutation.seq_start_orig - 1: mutation.seq_end_orig]
                if mutation.crop:
                    msa_seq_check = msa_seq_check[mutation.seq_start - 1: mutation.seq_end]
                    conservation_data = conservation_data[mutation.seq_start - 1: mutation.seq_end]
                if msa_seq_check != mutation.seq:
                    # MSA sequence disagrees with the mutation sequence:
                    # count it and fall back to all-zero conservation.
                    self.unmatched_msa += 1
                    print(f'Unmatched MSA: {self.unmatched_msa}')
                    conservation_data = np.zeros((embed_data.shape[0], 20))
            embed_data = np.concatenate([embed_data, conservation_data], axis=1)

        # ---- pLDDT -------------------------------------------------------------
        if self.add_plddt:
            plddt_data = self.af2_plddt_dict[mutation.af2_seq_index]
            if mutation.crop:
                plddt_data = plddt_data[mutation.seq_start - 1: mutation.seq_end]
            if self.add_confidence:
                confidence_data = plddt_data / 100
            if plddt_data.shape[0] != embed_data.shape[0]:
                warnings.warn(f'pLDDT {plddt_data.shape[0]} does not match embedding {embed_data.shape[0]}, '
                              f'pLDDT file: {mutation.af2_file}, '
                              f'ESM prefix: {mutation.ESM_prefix}')
                # Length mismatch: substitute a constant pLDDT of 50.
                plddt_data = np.ones_like(embed_data[:, 0]) * 50
                if self.add_confidence:
                    confidence_data = np.ones_like(embed_data[:, 0]) / 2
            if self.scale_plddt:
                plddt_data = plddt_data / 100
            embed_data = np.concatenate([embed_data, plddt_data[:, None]], axis=1)

        # ---- DSSP --------------------------------------------------------------
        if self.add_dssp:
            dssp_data = self.af2_dssp_dict[mutation.af2_seq_index]
            if mutation.crop:
                dssp_data = dssp_data[mutation.seq_start - 1: mutation.seq_end]
            if dssp_data.shape[0] != embed_data.shape[0]:
                warnings.warn(f'DSSP {dssp_data.shape[0]} does not match embedding {embed_data.shape[0]}, '
                              f'DSSP file: {mutation.af2_file}, '
                              f'ESM prefix: {mutation.ESM_prefix}')
                dssp_data = np.zeros_like(embed_data[:, 0])
            # Make 1-D DSSP a column so it can be concatenated.
            if len(dssp_data.shape) == 1:
                dssp_data = dssp_data[:, None]
            embed_data = np.concatenate([embed_data, dssp_data], axis=1)

        # ---- MSA ---------------------------------------------------------------
        if self.add_msa:
            if msa_data.shape[0] == 0:
                msa_data = np.zeros((embed_data.shape[0], 199))
            else:
                msa_seq_check = msa_seq[mutation.seq_start_orig - 1: mutation.seq_end_orig]
                msa_data = msa_data[mutation.seq_start_orig - 1: mutation.seq_end_orig]
                if mutation.crop:
                    msa_seq_check = msa_seq_check[mutation.seq_start - 1: mutation.seq_end]
                    msa_data = msa_data[mutation.seq_start - 1: mutation.seq_end]
                if msa_seq_check != mutation.seq:
                    warnings.warn(f'MSA file: {mutation.transcript_id} does not match mutation sequence')
                    msa_data = np.zeros((embed_data.shape[0], 199))
            embed_data = np.concatenate([embed_data, msa_data], axis=1)

        # ---- PTM ---------------------------------------------------------------
        if self.add_ptm:
            ptm_data = utils.get_ptm_from_mutation(mutation, self.ptm_ref)
            embed_data = np.concatenate([embed_data, ptm_data], axis=1)

        # The embedding itself is left untouched; masked rows are flagged with 0.
        embed_data_mask = np.ones_like(embed_data)
        embed_data_mask[mask_idx] = 0

        # ---- coordinates and node vectors ----------------------------------------
        # Atom slots along axis 1 as used below: 0=N, 1=C, 2=O, 3=CA, 4=CB, 5+=side chain.
        CA_coord = coords[:, 3]
        CB_coord = coords[:, 4]
        # Residues without a CB (NaN) reuse the CA position.
        CB_coord[np.isnan(CB_coord)] = CA_coord[np.isnan(CB_coord)]
        if self.graph_type == '1d-neighbor':
            # Sequence-only graph: the first coordinate is the residue index,
            # and all vector features derived from coords become zero.
            CA_coord[:, 0] = np.arange(coords.shape[0])
            CB_coord[:, 0] = np.arange(coords.shape[0])
            coords = np.zeros_like(coords)
        CA_CB = coords[:, [4]] - coords[:, [3]]
        CA_CB[np.isnan(CA_CB)] = 0

        CA_C = coords[:, [1]] - coords[:, [3]]
        CA_O = coords[:, [2]] - coords[:, [3]]
        CA_N = coords[:, [0]] - coords[:, [3]]
        nodes_vector = np.transpose(np.concatenate([CA_CB, CA_C, CA_O, CA_N], axis=1), (0, 2, 1))

        sidechain_nodes_vector = coords[:, 5:] - coords[:, [3]]
        sidechain_nodes_vector[np.isnan(sidechain_nodes_vector)] = 0
        sidechain_nodes_vector = np.transpose(sidechain_nodes_vector, (0, 2, 1))
        nodes_vector = np.concatenate([nodes_vector, sidechain_nodes_vector], axis=2)

        features = dict(
            embed_logits=embed_logits if self.node_embedding_type == 'one-hot-idx' else None,
            one_hot_mat=one_hot_mat if self.node_embedding_type.startswith('one-hot') else None,
            mask_idx=mask_idx,
            embed_data=embed_data,
            embed_data_mask=embed_data_mask,
            alt_embed_data=None,
            coords=coords,
            CA_coord=CA_coord,
            CB_coord=CB_coord,
            edge_index=edge_index,
            edge_index_star=edge_index_star,
            edge_attr=edge_attr,
            edge_attr_star=edge_attr_star,
            nodes_vector=nodes_vector,
        )

        # ---- confidence features ---------------------------------------------------
        if self.add_confidence:
            if self.add_plddt:
                # Per-residue confidence was already derived in the pLDDT step.
                features['plddt'] = confidence_data
                if self.loaded_confidence:
                    pae = self.af2_confidence_dict[mutation.af2_seq_index]
                else:
                    pae = utils.get_confidence_from_af2file(mutation.af2_file, self.af2_plddt_dict[mutation.af2_seq_index])
                if mutation.crop:
                    pae = pae[mutation.seq_start - 1: mutation.seq_end, mutation.seq_start - 1: mutation.seq_end]
            else:
                plddt_data = utils.get_plddt_from_af2(mutation.af2_file)
                pae = utils.get_confidence_from_af2file(mutation.af2_file, plddt_data)
                # Always define confidence_data; previously it was only set
                # for cropped mutations and was unbound otherwise.
                confidence_data = plddt_data / 100
                if mutation.crop:
                    confidence_data = confidence_data[mutation.seq_start - 1: mutation.seq_end]
                    pae = pae[mutation.seq_start - 1: mutation.seq_end, mutation.seq_start - 1: mutation.seq_end]
                if confidence_data.shape[0] != embed_data.shape[0]:
                    warnings.warn(f'pLDDT {confidence_data.shape[0]} does not match embedding {embed_data.shape[0]}, '
                                  f'pLDDT file: {mutation.af2_file}, '
                                  f'ESM prefix: {mutation.ESM_prefix}')
                    confidence_data = np.ones_like(embed_data[:, 0]) * 0.8
                features['plddt'] = confidence_data

            # Pairwise confidence looked up per edge.
            features['edge_confidence'] = pae[edge_index[0], edge_index[1]]
            features['edge_confidence_star'] = pae[edge_index_star[0], edge_index_star[1]]
        return features

    def get(self, idx):
        """Build the ``Data`` object for sample ``idx``.

        Isolated nodes are removed from the graph and node features are
        filtered with the resulting mask. The target ``y`` is a zero tensor of
        shape ``[1, len(utils.AA_DICT_HUMAN), n_score_columns]`` in which the
        rows of the mutation's alternate amino acids receive the scores;
        ``score_mask`` is filled the same way from the mask columns (or ones).
        Score and mask cells may hold a single number or a ';'-separated
        string with one value per alternate amino acid.

        Returns:
            Data: graph sample with ``y`` and ``score_mask`` attached.
        """
        features_np = self.get_one_mutation(idx)
        if self.node_embedding_type == 'one-hot-idx':
            x = torch.from_numpy(features_np['embed_data']).to(torch.long)
        else:
            x = torch.from_numpy(features_np['embed_data']).to(torch.float32)
        features = dict(
            x=x,
            x_mask=torch.from_numpy(features_np['embed_data_mask']).to(torch.bool),
            x_alt=torch.zeros_like(x),
            pos=torch.from_numpy(features_np['CA_coord']).to(torch.float32) if not self.use_cb else torch.from_numpy(features_np['CB_coord']).to(torch.float32),
            edge_index=torch.from_numpy(features_np['edge_index']).to(torch.long),
            edge_index_star=torch.from_numpy(features_np['edge_index_star']).to(torch.long),
            edge_attr=torch.from_numpy(features_np['edge_attr']).to(torch.float32),
            edge_attr_star=torch.from_numpy(features_np['edge_attr_star']).to(torch.float32),
            node_vec_attr=torch.from_numpy(features_np['nodes_vector']).to(torch.float32),
        )
        if self.add_confidence:
            features['plddt'] = torch.from_numpy(features_np['plddt']).to(torch.float32)
            features['edge_confidence'] = torch.from_numpy(features_np['edge_confidence']).to(torch.float32)
            features['edge_confidence_star'] = torch.from_numpy(features_np['edge_confidence_star']).to(torch.float32)

        if self.neighbor_type == 'radius' or self.neighbor_type == 'radius-KNN':
            # Treat both edge sets as one graph when looking for isolated
            # nodes, then split the result back at the original boundaries.
            n_edges = features["edge_index"].shape[1]
            n_edge_attr = features["edge_attr"].shape[0]
            concat_edge_index = torch.cat((features["edge_index"], features["edge_index_star"]), dim=1)
            concat_edge_attr = torch.cat((features["edge_attr"], features["edge_attr_star"]), dim=0)

            concat_edge_index, concat_edge_attr, mask = \
                remove_isolated_nodes(concat_edge_index, concat_edge_attr, x.shape[0])

            features["edge_index"] = concat_edge_index[:, :n_edges]
            features["edge_index_star"] = concat_edge_index[:, n_edges:]
            features["edge_attr"] = concat_edge_attr[:n_edge_attr]
            features["edge_attr_star"] = concat_edge_attr[n_edge_attr:]
        else:
            features["edge_index"], features["edge_attr"], mask = \
                remove_isolated_nodes(features["edge_index"], features["edge_attr"], x.shape[0])
            # NOTE(review): the mask of the star graph overwrites the one
            # above and is the one applied to node features -- confirm.
            features["edge_index_star"], features["edge_attr_star"], mask = \
                remove_isolated_nodes(features["edge_index_star"], features["edge_attr_star"], x.shape[0])
        for node_key in ("x", "x_mask", "x_alt", "pos", "node_vec_attr"):
            features[node_key] = features[node_key][mask]

        # Convert the row to plain lists so the loop below indexes by position;
        # integer `[]` on a label-indexed Series is deprecated in pandas.
        y_scores = self.data[self._y_columns].iloc[int(idx)].tolist()
        if len(self._y_mask_columns) > 0:
            y_masks = self.data[self._y_mask_columns].iloc[int(idx)].tolist()
        else:
            y_masks = [None] * len(y_scores)

        y = torch.zeros([1, len(utils.AA_DICT_HUMAN), len(y_scores)]).to(torch.float32)
        y_mask = torch.zeros_like(y)

        def _as_float_array(cell):
            # "a;b;c" -> array of floats; a scalar -> one-element array.
            if isinstance(cell, str):
                return np.array(cell.split(';')).astype(np.float32)
            return np.array([cell]).astype(np.float32)

        if y_scores:
            # Row index of every alternate amino acid; 'X' is mapped to row 19.
            alt_aa_idxs = [utils.AA_DICT_HUMAN.index(aa) if aa != 'X' else 19 for aa in self.mutations[idx].alt_aa]

        for i, raw_score in enumerate(y_scores):
            y_scores_i = _as_float_array(raw_score)
            if y_masks[i] is not None:
                y_masks_i = _as_float_array(y_masks[i])
            else:
                y_masks_i = np.ones_like(y_scores_i)
            y[0, alt_aa_idxs, i] = torch.from_numpy(y_scores_i)
            y_mask[0, alt_aa_idxs, i] = torch.from_numpy(y_masks_i)
        features["y"] = y.to(torch.float32)
        features["score_mask"] = y_mask.to(torch.float32)
        return Data(**features)
|
|
|
|
|
class GraphESMMutationDataset(GraphMutationDataset):
    """Graph mutation dataset that pairs each sample with an ESM-token mask.

    The constructor mirrors the parent signature (``alt_type`` defaults to
    ``'orig'`` here).  ``get`` returns a ``Data`` object whose ``esm_mask``
    holds +1 at the alternative amino-acid tokens and -1 at the reference
    amino-acid tokens of ``utils.ESM_TOKENS``.
    """

    def __init__(self, data_file, data_type: str,
                 radius: float = None, max_neighbors: int = None,
                 loop: bool = False, shuffle: bool = False, gpu_id: int = None,
                 node_embedding_type: Literal['esm', 'one-hot-idx', 'one-hot', 'aa-5dim', 'esm1b'] = 'esm',
                 graph_type: Literal['af2', '1d-neighbor'] = 'af2',
                 add_plddt: bool = False,
                 scale_plddt: bool = False,
                 add_conservation: bool = False,
                 add_position: bool = False,
                 add_sidechain: bool = False,
                 local_coord_transform: bool = False,
                 use_cb: bool = False,
                 add_msa_contacts: bool = True,
                 add_dssp: bool = False,
                 add_msa: bool = False,
                 add_confidence: bool = False,
                 loaded_confidence: bool = False,
                 loaded_esm: bool = False,
                 add_ptm: bool = False,
                 data_augment: bool = False,
                 score_transfer: bool = False,
                 alt_type: Literal['alt', 'concat', 'diff', 'orig'] = 'orig',
                 computed_graph: bool = True,
                 loaded_msa: bool = False,
                 neighbor_type: Literal['KNN', 'radius', 'radius-KNN'] = 'KNN',
                 max_len=2251,
                 convert_to_onesite: bool = False,
                 add_af2_single: bool = False,
                 add_af2_pairwise: bool = False,
                 loaded_af2_single: bool = False,
                 loaded_af2_pairwise: bool = False,
                 ):
        """Initialise the parent dataset and record the confidence-score columns.

        All arguments up to and including ``max_len`` are forwarded
        positionally to the parent constructor.

        NOTE(review): ``convert_to_onesite``, ``add_af2_single``,
        ``add_af2_pairwise``, ``loaded_af2_single`` and
        ``loaded_af2_pairwise`` are accepted but not forwarded to the parent
        constructor -- confirm this is intended.
        """
        super().__init__(
            data_file, data_type, radius, max_neighbors, loop, shuffle, gpu_id,
            node_embedding_type, graph_type, add_plddt, scale_plddt,
            add_conservation, add_position, add_sidechain,
            local_coord_transform, use_cb, add_msa_contacts, add_dssp,
            add_msa, add_confidence, loaded_confidence, loaded_esm,
            add_ptm, data_augment, score_transfer, alt_type,
            computed_graph, loaded_msa, neighbor_type, max_len)
        # Columns of the input table whose name starts with 'confidence.score'.
        self._y_mask_columns = self.data.columns[self.data.columns.str.startswith('confidence.score')]

    def get_one_mutation(self, idx):
        """Assemble the numpy feature dictionary for mutation ``idx``.

        Steps, in order: fetch the graph and mask via ``get_graph_and_mask``;
        load the node embedding selected by ``node_embedding_type``; append
        the optional per-residue feature blocks (conservation, pLDDT, DSSP,
        MSA, PTM) as extra columns; build the embedding mask; derive the
        per-residue vectors from the atom coordinates; and, when
        ``add_confidence`` is set, attach pLDDT and per-edge confidence.

        Args:
            idx: index into ``self.mutations``.

        Returns:
            dict of features (numpy arrays or ``None``).

        Raises:
            ValueError: if ``node_embedding_type`` is not a supported value.
        """
        mutation: utils.Mutation = self.mutations[idx]
        coords, edge_index, edge_index_star, edge_attr, edge_attr_star, mask_idx, mutation = \
            self.get_graph_and_mask(mutation)

        # ---- node embedding ----
        embed_logits = None
        one_hot_mat = None
        if self.node_embedding_type == 'esm':
            if self.loaded_esm:
                embed_data = self.esm_dict[mutation.esm_seq_index][mutation.seq_start:mutation.seq_end + 1]
            else:
                embed_data = utils.get_embedding_from_esm2(mutation.ESM_prefix, False,
                                                           mutation.seq_start, mutation.seq_end)
        elif self.node_embedding_type == 'one-hot-idx':
            assert not self.add_conservation and not self.add_plddt
            embed_logits, embed_data, one_hot_mat = utils.get_embedding_from_onehot_nonzero(
                mutation.seq, return_idx=True, return_onehot_mat=True)
        elif self.node_embedding_type == 'one-hot':
            embed_data, one_hot_mat = utils.get_embedding_from_onehot(
                mutation.seq, return_idx=False, return_onehot_mat=True)
        elif self.node_embedding_type == 'aa-5dim':
            embed_data = utils.get_embedding_from_5dim(mutation.seq)
        elif self.node_embedding_type == 'esm1b':
            embed_data = utils.get_embedding_from_esm1b(mutation.ESM_prefix, False,
                                                        mutation.seq_start, mutation.seq_end)
        else:
            # Previously this fell through and failed later with an unbound
            # ``embed_data``; fail early with a clear message instead.
            raise ValueError(f'Unsupported node_embedding_type: {self.node_embedding_type!r}')

        # ---- MSA-derived inputs (only fetched when needed) ----
        if self.loaded_msa and (self.add_msa or self.add_conservation):
            msa_entry = self.msa_dict[mutation.msa_seq_index]
            msa_seq = msa_entry[0]
            conservation_data = msa_entry[1]
            msa_data = msa_entry[2]
        elif self.add_conservation or self.add_msa:
            msa_seq, conservation_data, msa_data = utils.get_msa_dict_from_transcript(mutation.uniprot_id)

        if self.add_conservation:
            if conservation_data.shape[0] == 0:
                conservation_data = np.zeros((embed_data.shape[0], 20))
            else:
                msa_seq_check = msa_seq[mutation.seq_start_orig - 1: mutation.seq_end_orig]
                conservation_data = conservation_data[mutation.seq_start_orig - 1: mutation.seq_end_orig]
                if mutation.crop:
                    msa_seq_check = msa_seq_check[mutation.seq_start - 1: mutation.seq_end]
                    conservation_data = conservation_data[mutation.seq_start - 1: mutation.seq_end]
                if msa_seq_check != mutation.seq:
                    # MSA sequence disagrees with the mutation sequence:
                    # count it and fall back to zeros.
                    self.unmatched_msa += 1
                    print(f'Unmatched MSA: {self.unmatched_msa}')
                    conservation_data = np.zeros((embed_data.shape[0], 20))
            embed_data = np.concatenate([embed_data, conservation_data], axis=1)

        # ---- pLDDT column ----
        if self.add_plddt:
            plddt_data = self.af2_plddt_dict[mutation.af2_seq_index]
            if mutation.crop:
                plddt_data = plddt_data[mutation.seq_start - 1: mutation.seq_end]
            if self.add_confidence:
                confidence_data = plddt_data / 100
            if plddt_data.shape[0] != embed_data.shape[0]:
                warnings.warn(f'pLDDT {plddt_data.shape[0]} does not match embedding {embed_data.shape[0]}, '
                              f'pLDDT file: {mutation.af2_file}, '
                              f'ESM prefix: {mutation.ESM_prefix}')
                # Length mismatch: substitute a constant pLDDT of 50.
                plddt_data = np.ones_like(embed_data[:, 0]) * 50
                if self.add_confidence:
                    confidence_data = np.ones_like(embed_data[:, 0]) / 2
            if self.scale_plddt:
                plddt_data = plddt_data / 100
            embed_data = np.concatenate([embed_data, plddt_data[:, None]], axis=1)

        # ---- DSSP columns ----
        if self.add_dssp:
            dssp_data = self.af2_dssp_dict[mutation.af2_seq_index]
            if mutation.crop:
                dssp_data = dssp_data[mutation.seq_start - 1: mutation.seq_end]
            if dssp_data.shape[0] != embed_data.shape[0]:
                warnings.warn(f'DSSP {dssp_data.shape[0]} does not match embedding {embed_data.shape[0]}, '
                              f'DSSP file: {mutation.af2_file}, '
                              f'ESM prefix: {mutation.ESM_prefix}')
                dssp_data = np.zeros_like(embed_data[:, 0])
            if len(dssp_data.shape) == 1:
                dssp_data = dssp_data[:, None]
            embed_data = np.concatenate([embed_data, dssp_data], axis=1)

        # ---- MSA columns ----
        if self.add_msa:
            if msa_data.shape[0] == 0:
                msa_data = np.zeros((embed_data.shape[0], 199))
            else:
                msa_seq_check = msa_seq[mutation.seq_start_orig - 1: mutation.seq_end_orig]
                msa_data = msa_data[mutation.seq_start_orig - 1: mutation.seq_end_orig]
                if mutation.crop:
                    msa_seq_check = msa_seq_check[mutation.seq_start - 1: mutation.seq_end]
                    msa_data = msa_data[mutation.seq_start - 1: mutation.seq_end]
                if msa_seq_check != mutation.seq:
                    warnings.warn(f'MSA file: {mutation.transcript_id} does not match mutation sequence')
                    msa_data = np.zeros((embed_data.shape[0], 199))
            embed_data = np.concatenate([embed_data, msa_data], axis=1)

        # ---- PTM columns ----
        if self.add_ptm:
            ptm_data = utils.get_ptm_from_mutation(mutation, self.ptm_ref)
            embed_data = np.concatenate([embed_data, ptm_data], axis=1)

        # Mask is 1 everywhere except at the rows selected by mask_idx.
        embed_data_mask = np.ones_like(embed_data)
        embed_data_mask[mask_idx] = 0

        # Atom slots 3 and 4 are treated as CA and CB (per the names used here).
        # NOTE(review): these are views into ``coords``, so the in-place writes
        # below also modify the array returned by get_graph_and_mask -- confirm
        # it is not a shared cache entry.
        CA_coord = coords[:, 3]
        CB_coord = coords[:, 4]
        # Where CB is missing (NaN), fall back to the CA position.
        CB_coord[np.isnan(CB_coord)] = CA_coord[np.isnan(CB_coord)]
        if self.graph_type == '1d-neighbor':
            # Replace the first coordinate with the residue index and zero the
            # full coordinate array, so every vector derived below is zero.
            CA_coord[:, 0] = np.arange(coords.shape[0])
            CB_coord[:, 0] = np.arange(coords.shape[0])
            coords = np.zeros_like(coords)

        # Vectors from CA to the atoms in slots 4, 1, 2, 0 (named CB, C, O, N here).
        CA_CB = coords[:, [4]] - coords[:, [3]]
        CA_CB[np.isnan(CA_CB)] = 0
        CA_C = coords[:, [1]] - coords[:, [3]]
        CA_O = coords[:, [2]] - coords[:, [3]]
        CA_N = coords[:, [0]] - coords[:, [3]]
        nodes_vector = np.transpose(np.concatenate([CA_CB, CA_C, CA_O, CA_N], axis=1), (0, 2, 1))

        # Vectors from CA to the remaining atom slots (5 onward); NaN -> 0.
        sidechain_nodes_vector = coords[:, 5:] - coords[:, [3]]
        sidechain_nodes_vector[np.isnan(sidechain_nodes_vector)] = 0
        sidechain_nodes_vector = np.transpose(sidechain_nodes_vector, (0, 2, 1))
        nodes_vector = np.concatenate([nodes_vector, sidechain_nodes_vector], axis=2)

        features = dict(
            embed_logits=embed_logits,
            one_hot_mat=one_hot_mat,
            mask_idx=mask_idx,
            embed_data=embed_data,
            embed_data_mask=embed_data_mask,
            alt_embed_data=None,
            coords=coords,
            CA_coord=CA_coord,
            CB_coord=CB_coord,
            edge_index=edge_index,
            edge_index_star=edge_index_star,
            edge_attr=edge_attr,
            edge_attr_star=edge_attr_star,
            nodes_vector=nodes_vector,
        )

        if self.add_confidence:
            if self.add_plddt:
                # confidence_data was computed together with the pLDDT column.
                features['plddt'] = confidence_data
                if self.loaded_confidence:
                    pae = self.af2_confidence_dict[mutation.af2_seq_index]
                else:
                    pae = utils.get_confidence_from_af2file(
                        mutation.af2_file, self.af2_plddt_dict[mutation.af2_seq_index])
                if mutation.crop:
                    pae = pae[mutation.seq_start - 1: mutation.seq_end, mutation.seq_start - 1: mutation.seq_end]
            else:
                plddt_data = utils.get_plddt_from_af2(mutation.af2_file)
                pae = utils.get_confidence_from_af2file(mutation.af2_file, plddt_data)
                if mutation.crop:
                    confidence_data = plddt_data[mutation.seq_start - 1: mutation.seq_end] / 100
                    pae = pae[mutation.seq_start - 1: mutation.seq_end, mutation.seq_start - 1: mutation.seq_end]
                else:
                    # Fix: confidence_data used to be left unassigned for
                    # un-cropped mutations, which raised on the check below.
                    confidence_data = plddt_data / 100
                if confidence_data.shape[0] != embed_data.shape[0]:
                    warnings.warn(f'pLDDT {confidence_data.shape[0]} does not match embedding {embed_data.shape[0]}, '
                                  f'pLDDT file: {mutation.af2_file}, '
                                  f'ESM prefix: {mutation.ESM_prefix}')
                    confidence_data = np.ones_like(embed_data[:, 0]) * 0.8
                features['plddt'] = confidence_data
            # Look up the pairwise confidence for each edge of both graphs.
            features['edge_confidence'] = pae[edge_index[0], edge_index[1]]
            features['edge_confidence_star'] = pae[edge_index_star[0], edge_index_star[1]]
        return features

    def get(self, idx):
        """Return sample ``idx`` as a ``torch_geometric`` ``Data`` object.

        Converts the output of ``get_one_mutation`` to tensors, drops isolated
        nodes with ``remove_isolated_nodes``, and attaches the target ``y``
        plus ``esm_mask`` (+1 at alternative, -1 at reference amino-acid
        tokens in ``utils.ESM_TOKENS``).
        """
        features_np = self.get_one_mutation(idx)
        # Index embeddings are integer typed; every other embedding is float.
        embed_dtype = torch.long if self.node_embedding_type == 'one-hot-idx' else torch.float32
        x = torch.from_numpy(features_np['embed_data']).to(embed_dtype)
        pos_key = 'CB_coord' if self.use_cb else 'CA_coord'
        features = dict(
            x=x,
            x_mask=torch.from_numpy(features_np['embed_data_mask']).to(torch.bool),
            x_alt=x.clone(),
            pos=torch.from_numpy(features_np[pos_key]).to(torch.float32),
            edge_index=torch.from_numpy(features_np['edge_index']).to(torch.long),
            edge_index_star=torch.from_numpy(features_np['edge_index_star']).to(torch.long),
            edge_attr=torch.from_numpy(features_np['edge_attr']).to(torch.float32),
            edge_attr_star=torch.from_numpy(features_np['edge_attr_star']).to(torch.float32),
            node_vec_attr=torch.from_numpy(features_np['nodes_vector']).to(torch.float32),
        )
        if self.add_confidence:
            features['plddt'] = torch.from_numpy(features_np['plddt']).to(torch.float32)
            features['edge_confidence'] = torch.from_numpy(features_np['edge_confidence']).to(torch.float32)
            features['edge_confidence_star'] = torch.from_numpy(features_np['edge_confidence_star']).to(torch.float32)

        if self.neighbor_type == 'radius' or self.neighbor_type == 'radius-KNN':
            # Treat both edge sets as one graph so a node counts as isolated
            # only if it has no edge in either, then split them back apart.
            num_edges = features["edge_index"].shape[1]
            num_edge_attrs = features["edge_attr"].shape[0]
            concat_edge_index = torch.cat((features["edge_index"], features["edge_index_star"]), dim=1)
            concat_edge_attr = torch.cat((features["edge_attr"], features["edge_attr_star"]), dim=0)
            concat_edge_index, concat_edge_attr, mask = \
                remove_isolated_nodes(concat_edge_index, concat_edge_attr, x.shape[0])
            features["edge_index"] = concat_edge_index[:, :num_edges]
            features["edge_index_star"] = concat_edge_index[:, num_edges:]
            features["edge_attr"] = concat_edge_attr[:num_edge_attrs]
            features["edge_attr_star"] = concat_edge_attr[num_edge_attrs:]
        else:
            features["edge_index"], features["edge_attr"], mask = \
                remove_isolated_nodes(features["edge_index"], features["edge_attr"], x.shape[0])
            # NOTE(review): the node mask from the first call is overwritten
            # here; only the star-graph mask is applied to the node tensors.
            features["edge_index_star"], features["edge_attr_star"], mask = \
                remove_isolated_nodes(features["edge_index_star"], features["edge_attr_star"], x.shape[0])
        # NOTE(review): 'plddt' (when present) is not filtered by ``mask``.
        for node_key in ("x", "x_mask", "x_alt", "pos", "node_vec_attr"):
            features[node_key] = features[node_key][mask]

        mutation = self.mutations[idx]
        alt_aa_idxs = [utils.ESM_TOKENS.index(aa) for aa in mutation.alt_aa]
        ref_aa_idxs = [utils.ESM_TOKENS.index(aa) for aa in mutation.ref_aa]
        y_mask = torch.zeros([1, len(utils.ESM_TOKENS)]).to(torch.float32)
        y_mask[0, alt_aa_idxs] = 1
        # Reference tokens are written last, so they win on any overlap.
        y_mask[0, ref_aa_idxs] = -1
        features["y"] = torch.tensor([self.data[self._y_columns].iloc[int(idx)]]).to(torch.float32)
        features["esm_mask"] = y_mask.to(torch.float32)
        return Data(**features)
|
|
|
|
|
class FullGraphESMMutationDataset(FullGraphMutationDataset):
    """Full-graph mutation dataset that feeds ESM token indices.

    ``get_graph_and_mask`` returns dense pairwise edge attributes instead of
    an edge list, ``get_one_mutation`` embeds the sequence with
    ``utils.get_embedding_from_esm_onehot``, and ``get`` returns a plain
    ``dict`` of tensors padded to ``max_len + 2``.
    """

    def __init__(self, data_file, data_type: str,
                 radius: float = None, max_neighbors: int = None,
                 loop: bool = False, shuffle: bool = False, gpu_id: int = None,
                 node_embedding_type: Literal['esm', 'one-hot-idx', 'one-hot', 'aa-5dim', 'esm1b'] = 'esm',
                 graph_type: Literal['af2', '1d-neighbor'] = 'af2',
                 add_plddt: bool = False,
                 scale_plddt: bool = False,
                 add_conservation: bool = False,
                 add_position: bool = False,
                 add_sidechain: bool = False,
                 local_coord_transform: bool = False,
                 use_cb: bool = False,
                 add_msa_contacts: bool = True,
                 add_dssp: bool = False,
                 add_msa: bool = False,
                 add_confidence: bool = False,
                 loaded_confidence: bool = False,
                 loaded_esm: bool = False,
                 add_ptm: bool = False,
                 data_augment: bool = False,
                 score_transfer: bool = False,
                 alt_type: Literal['alt', 'concat', 'diff'] = 'alt',
                 computed_graph: bool = True,
                 loaded_msa: bool = False,
                 neighbor_type: Literal['KNN', 'radius', 'radius-KNN'] = 'KNN',
                 max_len=2251,
                 convert_to_onesite: bool = False,
                 add_af2_single: bool = False,
                 add_af2_pairwise: bool = False,
                 loaded_af2_single: bool = False,
                 loaded_af2_pairwise: bool = False,
                 ):
        """Initialise the parent dataset and record the confidence-score columns.

        All arguments up to and including ``convert_to_onesite`` are forwarded
        positionally to the parent constructor.

        NOTE(review): ``add_af2_single``, ``add_af2_pairwise``,
        ``loaded_af2_single`` and ``loaded_af2_pairwise`` are accepted but not
        forwarded -- confirm this is intended.
        """
        super().__init__(
            data_file, data_type, radius, max_neighbors, loop, shuffle, gpu_id,
            node_embedding_type, graph_type, add_plddt, scale_plddt,
            add_conservation, add_position, add_sidechain,
            local_coord_transform, use_cb, add_msa_contacts, add_dssp,
            add_msa, add_confidence, loaded_confidence, loaded_esm,
            add_ptm, data_augment, score_transfer, alt_type,
            computed_graph, loaded_msa, neighbor_type, max_len, convert_to_onesite)
        # Columns of the input table whose name starts with 'confidence.score'.
        self._y_mask_columns = self.data.columns[self.data.columns.str.startswith('confidence.score')]

    def get_graph_and_mask(self, mutation: utils.Mutation):
        """Return coordinates, dense edge attributes and the mask for ``mutation``.

        The edge attribute is a single channel ``sin(pi/2 * (i - j) / max_len)``
        over all residue pairs ``(i, j)``; no edge indices are produced.

        Returns:
            ``(coords, None, None, edge_attr, None, mask_idx, mutation)`` where
            ``mask_idx`` and ``mutation`` come from ``self.get_mask``.
        """
        coords: np.ndarray = self.af2_coord_dict[mutation.af2_seq_index]
        if mutation.crop:
            coords = coords[mutation.seq_start - 1:mutation.seq_end, :]
        mask_idx, mutation = self.get_mask(mutation)

        # Empty (zero-channel) pairwise tensor that the position channel is
        # concatenated onto.
        seq_len = mutation.seq_end - mutation.seq_start + 1
        edge_attr = np.zeros([seq_len, seq_len, 0])

        # Signed index offset between every pair of residues.
        residue_idx = np.arange(coords.shape[0])
        edge_position = residue_idx[:, None] - residue_idx[None, :]
        position_channel = np.sin(np.pi / 2 * edge_position / self.max_len)[:, :, None]
        edge_attr = np.concatenate((edge_attr, position_channel), axis=2)
        return coords, None, None, edge_attr, None, mask_idx, mutation

    def get_one_mutation(self, idx):
        """Assemble the numpy feature dictionary for mutation ``idx``.

        The node embedding always comes from
        ``utils.get_embedding_from_esm_onehot`` and ``mask_idx`` is shifted by
        one afterwards.  Optional feature blocks (conservation, pLDDT, DSSP,
        MSA, PTM) are appended as extra columns, then the per-residue vectors
        and (if ``add_confidence``) the confidence entries are added.

        Args:
            idx: index into ``self.mutations``.

        Returns:
            dict of features (numpy arrays or ``None``); ``edge_index``,
            ``edge_index_star`` and ``edge_attr_star`` are always ``None``.
        """
        mutation: utils.Mutation = self.mutations[idx]
        coords, _, _, edge_attr, _, mask_idx, mutation = self.get_graph_and_mask(mutation)

        # Only the index embedding is used; logits and one-hot matrix are dropped.
        _, embed_data, _ = utils.get_embedding_from_esm_onehot(
            mutation.seq, return_idx=True, return_onehot_mat=True)
        # Shift the masked positions by one -- presumably to account for a
        # leading special token in the ESM encoding (TODO confirm).
        mask_idx += 1

        # ---- MSA-derived inputs (only fetched when needed) ----
        if self.loaded_msa and (self.add_msa or self.add_conservation):
            msa_entry = self.msa_dict[mutation.msa_seq_index]
            msa_seq = msa_entry[0]
            conservation_data = msa_entry[1]
            msa_data = msa_entry[2]
        elif self.add_conservation or self.add_msa:
            msa_seq, conservation_data, msa_data = utils.get_msa_dict_from_transcript(mutation.uniprot_id)

        if self.add_conservation:
            if conservation_data.shape[0] == 0:
                conservation_data = np.zeros((embed_data.shape[0], 20))
            else:
                msa_seq_check = msa_seq[mutation.seq_start_orig - 1: mutation.seq_end_orig]
                conservation_data = conservation_data[mutation.seq_start_orig - 1: mutation.seq_end_orig]
                if mutation.crop:
                    msa_seq_check = msa_seq_check[mutation.seq_start - 1: mutation.seq_end]
                    conservation_data = conservation_data[mutation.seq_start - 1: mutation.seq_end]
                if msa_seq_check != mutation.seq:
                    # MSA sequence disagrees with the mutation sequence:
                    # count it and fall back to zeros.
                    self.unmatched_msa += 1
                    print(f'Unmatched MSA: {self.unmatched_msa}')
                    conservation_data = np.zeros((embed_data.shape[0], 20))
            embed_data = np.concatenate([embed_data, conservation_data], axis=1)

        # ---- pLDDT column ----
        if self.add_plddt:
            plddt_data = self.af2_plddt_dict[mutation.af2_seq_index]
            if mutation.crop:
                plddt_data = plddt_data[mutation.seq_start - 1: mutation.seq_end]
            if self.add_confidence:
                confidence_data = plddt_data / 100
            if plddt_data.shape[0] != embed_data.shape[0]:
                warnings.warn(f'pLDDT {plddt_data.shape[0]} does not match embedding {embed_data.shape[0]}, '
                              f'pLDDT file: {mutation.af2_file}, '
                              f'ESM prefix: {mutation.ESM_prefix}')
                # Length mismatch: substitute a constant pLDDT of 50.
                plddt_data = np.ones_like(embed_data[:, 0]) * 50
                if self.add_confidence:
                    confidence_data = np.ones_like(embed_data[:, 0]) / 2
            if self.scale_plddt:
                plddt_data = plddt_data / 100
            embed_data = np.concatenate([embed_data, plddt_data[:, None]], axis=1)

        # ---- DSSP columns ----
        if self.add_dssp:
            dssp_data = self.af2_dssp_dict[mutation.af2_seq_index]
            if mutation.crop:
                dssp_data = dssp_data[mutation.seq_start - 1: mutation.seq_end]
            if dssp_data.shape[0] != embed_data.shape[0]:
                warnings.warn(f'DSSP {dssp_data.shape[0]} does not match embedding {embed_data.shape[0]}, '
                              f'DSSP file: {mutation.af2_file}, '
                              f'ESM prefix: {mutation.ESM_prefix}')
                dssp_data = np.zeros_like(embed_data[:, 0])
            if len(dssp_data.shape) == 1:
                dssp_data = dssp_data[:, None]
            embed_data = np.concatenate([embed_data, dssp_data], axis=1)

        # ---- MSA columns ----
        if self.add_msa:
            if msa_data.shape[0] == 0:
                msa_data = np.zeros((embed_data.shape[0], 199))
            else:
                msa_seq_check = msa_seq[mutation.seq_start_orig - 1: mutation.seq_end_orig]
                msa_data = msa_data[mutation.seq_start_orig - 1: mutation.seq_end_orig]
                if mutation.crop:
                    msa_seq_check = msa_seq_check[mutation.seq_start - 1: mutation.seq_end]
                    msa_data = msa_data[mutation.seq_start - 1: mutation.seq_end]
                if msa_seq_check != mutation.seq:
                    warnings.warn(f'MSA file: {mutation.transcript_id} does not match mutation sequence')
                    msa_data = np.zeros((embed_data.shape[0], 199))
            embed_data = np.concatenate([embed_data, msa_data], axis=1)

        # ---- PTM columns ----
        if self.add_ptm:
            ptm_data = utils.get_ptm_from_mutation(mutation, self.ptm_ref)
            embed_data = np.concatenate([embed_data, ptm_data], axis=1)

        # Mask is 1 everywhere except at the entries selected by mask_idx.
        embed_data_mask = np.ones_like(embed_data)
        embed_data_mask[mask_idx] = 0

        # Atom slots 3 and 4 are treated as CA and CB (per the names used here).
        # NOTE(review): these are views into ``coords``, which comes from
        # ``self.af2_coord_dict``; the in-place writes below therefore modify
        # the stored array -- confirm this is acceptable.
        CA_coord = coords[:, 3]
        CB_coord = coords[:, 4]
        # Where CB is missing (NaN), fall back to the CA position.
        CB_coord[np.isnan(CB_coord)] = CA_coord[np.isnan(CB_coord)]
        if self.graph_type == '1d-neighbor':
            # Replace the first coordinate with the residue index and zero the
            # full coordinate array, so every vector derived below is zero.
            CA_coord[:, 0] = np.arange(coords.shape[0])
            CB_coord[:, 0] = np.arange(coords.shape[0])
            coords = np.zeros_like(coords)

        # Vectors from CA to the atoms in slots 4, 1, 2, 0 (named CB, C, O, N here).
        CA_CB = coords[:, [4]] - coords[:, [3]]
        CA_CB[np.isnan(CA_CB)] = 0
        CA_C = coords[:, [1]] - coords[:, [3]]
        CA_O = coords[:, [2]] - coords[:, [3]]
        CA_N = coords[:, [0]] - coords[:, [3]]
        nodes_vector = np.transpose(np.concatenate([CA_CB, CA_C, CA_O, CA_N], axis=1), (0, 2, 1))

        # Vectors from CA to the remaining atom slots (5 onward); NaN -> 0.
        sidechain_nodes_vector = coords[:, 5:] - coords[:, [3]]
        sidechain_nodes_vector[np.isnan(sidechain_nodes_vector)] = 0
        sidechain_nodes_vector = np.transpose(sidechain_nodes_vector, (0, 2, 1))
        nodes_vector = np.concatenate([nodes_vector, sidechain_nodes_vector], axis=2)

        features = dict(
            embed_logits=None,
            one_hot_mat=None,
            mask_idx=mask_idx,
            embed_data=embed_data,
            embed_data_mask=embed_data_mask,
            alt_embed_data=None,
            coords=coords,
            CA_coord=CA_coord,
            CB_coord=CB_coord,
            edge_index=None,
            edge_index_star=None,
            edge_attr=edge_attr,
            edge_attr_star=None,
            nodes_vector=nodes_vector,
        )

        if self.add_confidence:
            if self.add_plddt:
                # confidence_data was computed together with the pLDDT column.
                features['plddt'] = confidence_data
                if self.loaded_confidence:
                    pae = self.af2_confidence_dict[mutation.af2_seq_index]
                else:
                    pae = utils.get_confidence_from_af2file(
                        mutation.af2_file, self.af2_plddt_dict[mutation.af2_seq_index])
                if mutation.crop:
                    pae = pae[mutation.seq_start - 1: mutation.seq_end, mutation.seq_start - 1: mutation.seq_end]
            else:
                plddt_data = utils.get_plddt_from_af2(mutation.af2_file)
                pae = utils.get_confidence_from_af2file(mutation.af2_file, plddt_data)
                if mutation.crop:
                    confidence_data = plddt_data[mutation.seq_start - 1: mutation.seq_end] / 100
                    pae = pae[mutation.seq_start - 1: mutation.seq_end, mutation.seq_start - 1: mutation.seq_end]
                else:
                    # Fix: confidence_data used to be left unassigned for
                    # un-cropped mutations, which raised on the check below.
                    confidence_data = plddt_data / 100
                if confidence_data.shape[0] != embed_data.shape[0]:
                    warnings.warn(f'pLDDT {confidence_data.shape[0]} does not match embedding {embed_data.shape[0]}, '
                                  f'pLDDT file: {mutation.af2_file}, '
                                  f'ESM prefix: {mutation.ESM_prefix}')
                    confidence_data = np.ones_like(embed_data[:, 0]) * 0.8
                features['plddt'] = confidence_data
            # Full-graph variant: keep the whole pairwise confidence matrix.
            features['edge_confidence'] = pae
        return features

    def get(self, idx):
        """Return sample ``idx`` as a plain ``dict`` of tensors.

        Keys: ``x`` (token indices, padded to ``max_len + 2`` with the
        ``'<pad>'`` token), ``x_mask`` (boolean mask, padded with ``True``),
        ``x_alt`` (same shape as ``x``, filled with the ``'<mask>'`` token),
        ``y`` (target scores) and ``esm_mask`` (+1 at alternative, -1 at
        reference amino-acid tokens in ``utils.ESM_TOKENS``).

        Prints the elapsed wall-clock time of the call.
        """
        start = time.time()
        features_np = self.get_one_mutation(idx)
        # Both embedding modes were cast to long in the original code, so a
        # single conversion suffices.  Coordinate, vector, edge and pLDDT
        # tensors are not part of the returned dict and are no longer built.
        x = torch.from_numpy(features_np['embed_data']).to(torch.long)
        x_mask = torch.from_numpy(features_np['embed_data_mask']).to(torch.bool)

        padded_len = self.max_len + 2
        if x.shape[0] < padded_len:
            x = torch.nn.functional.pad(
                x, (0, padded_len - x.shape[0]), 'constant', utils.ESM_TOKENS.index('<pad>'))
            x_mask = torch.nn.functional.pad(
                x_mask, (0, padded_len - x_mask.shape[0]), 'constant', True)

        mutation = self.mutations[idx]
        alt_aa_idxs = [utils.ESM_TOKENS.index(aa) for aa in mutation.alt_aa]
        ref_aa_idxs = [utils.ESM_TOKENS.index(aa) for aa in mutation.ref_aa]
        y_mask = torch.zeros([len(utils.ESM_TOKENS)]).to(torch.float32)
        y_mask[alt_aa_idxs] = 1
        # Reference tokens are written last, so they win on any overlap.
        y_mask[ref_aa_idxs] = -1

        features = dict(
            x=x,
            x_mask=x_mask,
            x_alt=torch.ones_like(x) * utils.ESM_TOKENS.index('<mask>'),
            y=torch.tensor([self.data[self._y_columns].iloc[int(idx)]]).to(torch.float32).squeeze(1),
            esm_mask=y_mask.to(torch.float32),
        )
        end = time.time()
        # NOTE(review): per-sample timing print looks like debug output;
        # kept to preserve existing behaviour.
        print(f'get time: {end - start}')
        return features
|
|
|
|
|
class FullGraphMultiOnesiteMutationDataset(FullGraphMutationDataset): |
|
def __init__(self, data_file, data_type: str,
             radius: float = None, max_neighbors: int = None,
             loop: bool = False, shuffle: bool = False, gpu_id: int = None,
             node_embedding_type: Literal['esm', 'one-hot-idx', 'one-hot', 'aa-5dim', 'esm1b'] = 'esm',
             graph_type: Literal['af2', '1d-neighbor'] = 'af2',
             add_plddt: bool = False,
             scale_plddt: bool = False,
             add_conservation: bool = False,
             add_position: bool = False,
             add_sidechain: bool = False,
             local_coord_transform: bool = False,
             use_cb: bool = False,
             add_msa_contacts: bool = True,
             add_dssp: bool = False,
             add_msa: bool = False,
             add_confidence: bool = False,
             loaded_confidence: bool = False,
             loaded_esm: bool = False,
             add_ptm: bool = False,
             data_augment: bool = False,
             score_transfer: bool = False,
             alt_type: Literal['alt', 'concat', 'diff'] = 'alt',
             computed_graph: bool = True,
             loaded_msa: bool = False,
             neighbor_type: Literal['KNN', 'radius', 'radius-KNN'] = 'KNN',
             max_len=2251,
             convert_to_onesite: bool = False,
             add_af2_single: bool = False,
             add_af2_pairwise: bool = False,
             loaded_af2_single: bool = False,
             loaded_af2_pairwise: bool = False,
             ):
    """Initialise the parent dataset and record the confidence-score columns.

    Every argument up to and including ``convert_to_onesite`` is passed
    positionally, in declaration order, to the parent constructor.

    NOTE(review): ``add_af2_single``, ``add_af2_pairwise``,
    ``loaded_af2_single`` and ``loaded_af2_pairwise`` are accepted but not
    forwarded -- confirm this is intended.
    """
    super(FullGraphMultiOnesiteMutationDataset, self).__init__(
        data_file, data_type,
        radius, max_neighbors,
        loop, shuffle, gpu_id,
        node_embedding_type, graph_type,
        add_plddt, scale_plddt,
        add_conservation, add_position, add_sidechain,
        local_coord_transform, use_cb,
        add_msa_contacts, add_dssp, add_msa,
        add_confidence, loaded_confidence, loaded_esm,
        add_ptm, data_augment, score_transfer, alt_type,
        computed_graph, loaded_msa, neighbor_type,
        max_len, convert_to_onesite)
    # Keep the columns of the input table named 'confidence.score...'.
    is_confidence_score = self.data.columns.str.startswith('confidence.score')
    self._y_mask_columns = self.data.columns[is_confidence_score]
|
|
|
def get_graph_and_mask(self, mutation: utils.Mutation):
    """Return coordinates, dense edge attributes and the mask for ``mutation``.

    The edge attribute is a single channel ``sin(pi/2 * (i - j) / max_len)``
    over all residue pairs ``(i, j)``; no edge indices are produced.

    Returns:
        ``(coords, None, None, edge_attr, None, mask_idx, mutation)`` where
        ``mask_idx`` and ``mutation`` come from ``self.get_mask``.
    """
    coords: np.ndarray = self.af2_coord_dict[mutation.af2_seq_index]
    if mutation.crop:
        coords = coords[mutation.seq_start - 1:mutation.seq_end, :]

    mask_idx, mutation = self.get_mask(mutation)

    # Zero-channel pairwise tensor that the position channel is appended to.
    seq_len = mutation.seq_end - mutation.seq_start + 1
    empty_pairwise = np.zeros([seq_len, seq_len, 0])

    # Signed index offset between every pair of residues.
    residue_idx = np.arange(coords.shape[0])
    offsets = residue_idx[:, None] - residue_idx[None, :]
    position_channel = np.sin(np.pi / 2 * offsets / self.max_len)[:, :, None]

    edge_attr = np.concatenate((empty_pairwise, position_channel), axis=2)
    return coords, None, None, edge_attr, None, mask_idx, mutation
|
|
|
def get_one_mutation(self, idx):
    """Assemble the numpy feature dictionary for mutation ``idx``.

    Steps, in order: fetch coordinates, dense edge attributes and mask via
    ``get_graph_and_mask``; load the node embedding selected by
    ``node_embedding_type``; append the optional per-residue feature blocks
    (conservation, pLDDT, DSSP, MSA, PTM) as extra columns; build the
    embedding mask; derive the per-residue vectors from the atom
    coordinates; and, when ``add_confidence`` is set, attach pLDDT and the
    pairwise confidence matrix.

    Args:
        idx: index into ``self.mutations``.

    Returns:
        dict of features (numpy arrays or ``None``); ``embed_logits``,
        ``one_hot_mat``, ``edge_index``, ``edge_index_star`` and
        ``edge_attr_star`` are always ``None``.

    Raises:
        ValueError: if ``node_embedding_type`` is not a supported value.
    """
    mutation: utils.Mutation = self.mutations[idx]
    coords, _, _, edge_attr, _, mask_idx, mutation = self.get_graph_and_mask(mutation)

    # ---- node embedding ----
    if self.node_embedding_type == 'esm':
        if self.loaded_esm:
            embed_data = self.esm_dict[mutation.esm_seq_index][mutation.seq_start:mutation.seq_end + 1]
        else:
            embed_data = utils.get_embedding_from_esm2(mutation.ESM_prefix, False,
                                                       mutation.seq_start, mutation.seq_end)
    elif self.node_embedding_type == 'one-hot-idx':
        assert not self.add_conservation and not self.add_plddt
        # Logits and one-hot matrix are not exported by this class.
        _, embed_data, _ = utils.get_embedding_from_onehot_nonzero(
            mutation.seq, return_idx=True, return_onehot_mat=True)
    elif self.node_embedding_type == 'one-hot':
        embed_data, _ = utils.get_embedding_from_onehot(
            mutation.seq, return_idx=False, return_onehot_mat=True)
    elif self.node_embedding_type == 'aa-5dim':
        embed_data = utils.get_embedding_from_5dim(mutation.seq)
    elif self.node_embedding_type == 'esm1b':
        embed_data = utils.get_embedding_from_esm1b(mutation.ESM_prefix, False,
                                                    mutation.seq_start, mutation.seq_end)
    else:
        # Previously this fell through and failed later with an unbound
        # ``embed_data``; fail early with a clear message instead.
        raise ValueError(f'Unsupported node_embedding_type: {self.node_embedding_type!r}')

    # ---- MSA-derived inputs (only fetched when needed) ----
    if self.loaded_msa and (self.add_msa or self.add_conservation):
        msa_entry = self.msa_dict[mutation.msa_seq_index]
        msa_seq = msa_entry[0]
        conservation_data = msa_entry[1]
        msa_data = msa_entry[2]
    elif self.add_conservation or self.add_msa:
        msa_seq, conservation_data, msa_data = utils.get_msa_dict_from_transcript(mutation.uniprot_id)

    if self.add_conservation:
        if conservation_data.shape[0] == 0:
            conservation_data = np.zeros((embed_data.shape[0], 20))
        else:
            msa_seq_check = msa_seq[mutation.seq_start_orig - 1: mutation.seq_end_orig]
            conservation_data = conservation_data[mutation.seq_start_orig - 1: mutation.seq_end_orig]
            if mutation.crop:
                msa_seq_check = msa_seq_check[mutation.seq_start - 1: mutation.seq_end]
                conservation_data = conservation_data[mutation.seq_start - 1: mutation.seq_end]
            if msa_seq_check != mutation.seq:
                # MSA sequence disagrees with the mutation sequence:
                # count it and fall back to zeros.
                self.unmatched_msa += 1
                print(f'Unmatched MSA: {self.unmatched_msa}')
                conservation_data = np.zeros((embed_data.shape[0], 20))
        embed_data = np.concatenate([embed_data, conservation_data], axis=1)

    # ---- pLDDT column ----
    if self.add_plddt:
        plddt_data = self.af2_plddt_dict[mutation.af2_seq_index]
        if mutation.crop:
            plddt_data = plddt_data[mutation.seq_start - 1: mutation.seq_end]
        if self.add_confidence:
            confidence_data = plddt_data / 100
        if plddt_data.shape[0] != embed_data.shape[0]:
            warnings.warn(f'pLDDT {plddt_data.shape[0]} does not match embedding {embed_data.shape[0]}, '
                          f'pLDDT file: {mutation.af2_file}, '
                          f'ESM prefix: {mutation.ESM_prefix}')
            # Length mismatch: substitute a constant pLDDT of 50.
            plddt_data = np.ones_like(embed_data[:, 0]) * 50
            if self.add_confidence:
                confidence_data = np.ones_like(embed_data[:, 0]) / 2
        if self.scale_plddt:
            plddt_data = plddt_data / 100
        embed_data = np.concatenate([embed_data, plddt_data[:, None]], axis=1)

    # ---- DSSP columns ----
    if self.add_dssp:
        dssp_data = self.af2_dssp_dict[mutation.af2_seq_index]
        if mutation.crop:
            dssp_data = dssp_data[mutation.seq_start - 1: mutation.seq_end]
        if dssp_data.shape[0] != embed_data.shape[0]:
            warnings.warn(f'DSSP {dssp_data.shape[0]} does not match embedding {embed_data.shape[0]}, '
                          f'DSSP file: {mutation.af2_file}, '
                          f'ESM prefix: {mutation.ESM_prefix}')
            dssp_data = np.zeros_like(embed_data[:, 0])
        if len(dssp_data.shape) == 1:
            dssp_data = dssp_data[:, None]
        embed_data = np.concatenate([embed_data, dssp_data], axis=1)

    # ---- MSA columns ----
    if self.add_msa:
        if msa_data.shape[0] == 0:
            msa_data = np.zeros((embed_data.shape[0], 199))
        else:
            msa_seq_check = msa_seq[mutation.seq_start_orig - 1: mutation.seq_end_orig]
            msa_data = msa_data[mutation.seq_start_orig - 1: mutation.seq_end_orig]
            if mutation.crop:
                msa_seq_check = msa_seq_check[mutation.seq_start - 1: mutation.seq_end]
                msa_data = msa_data[mutation.seq_start - 1: mutation.seq_end]
            if msa_seq_check != mutation.seq:
                warnings.warn(f'MSA file: {mutation.transcript_id} does not match mutation sequence')
                msa_data = np.zeros((embed_data.shape[0], 199))
        embed_data = np.concatenate([embed_data, msa_data], axis=1)

    # ---- PTM columns ----
    if self.add_ptm:
        ptm_data = utils.get_ptm_from_mutation(mutation, self.ptm_ref)
        embed_data = np.concatenate([embed_data, ptm_data], axis=1)

    # Mask is 1 everywhere except at the rows selected by mask_idx.
    embed_data_mask = np.ones_like(embed_data)
    embed_data_mask[mask_idx] = 0

    # Atom slots 3 and 4 are treated as CA and CB (per the names used here).
    # NOTE(review): these are views into ``coords``, so the in-place writes
    # below also modify the array returned by get_graph_and_mask -- confirm
    # it is not a shared cache entry.
    CA_coord = coords[:, 3]
    CB_coord = coords[:, 4]
    # Where CB is missing (NaN), fall back to the CA position.
    CB_coord[np.isnan(CB_coord)] = CA_coord[np.isnan(CB_coord)]
    if self.graph_type == '1d-neighbor':
        # Replace the first coordinate with the residue index and zero the
        # full coordinate array, so every vector derived below is zero.
        CA_coord[:, 0] = np.arange(coords.shape[0])
        CB_coord[:, 0] = np.arange(coords.shape[0])
        coords = np.zeros_like(coords)

    # Vectors from CA to the atoms in slots 4, 1, 2, 0 (named CB, C, O, N here).
    CA_CB = coords[:, [4]] - coords[:, [3]]
    CA_CB[np.isnan(CA_CB)] = 0
    CA_C = coords[:, [1]] - coords[:, [3]]
    CA_O = coords[:, [2]] - coords[:, [3]]
    CA_N = coords[:, [0]] - coords[:, [3]]
    nodes_vector = np.transpose(np.concatenate([CA_CB, CA_C, CA_O, CA_N], axis=1), (0, 2, 1))

    # Vectors from CA to the remaining atom slots (5 onward); NaN -> 0.
    sidechain_nodes_vector = coords[:, 5:] - coords[:, [3]]
    sidechain_nodes_vector[np.isnan(sidechain_nodes_vector)] = 0
    sidechain_nodes_vector = np.transpose(sidechain_nodes_vector, (0, 2, 1))
    nodes_vector = np.concatenate([nodes_vector, sidechain_nodes_vector], axis=2)

    features = dict(
        embed_logits=None,
        one_hot_mat=None,
        mask_idx=mask_idx,
        embed_data=embed_data,
        embed_data_mask=embed_data_mask,
        alt_embed_data=None,
        coords=coords,
        CA_coord=CA_coord,
        CB_coord=CB_coord,
        edge_index=None,
        edge_index_star=None,
        edge_attr=edge_attr,
        edge_attr_star=None,
        nodes_vector=nodes_vector,
    )

    if self.add_confidence:
        if self.add_plddt:
            # confidence_data was computed together with the pLDDT column.
            features['plddt'] = confidence_data
            if self.loaded_confidence:
                pae = self.af2_confidence_dict[mutation.af2_seq_index]
            else:
                pae = utils.get_confidence_from_af2file(
                    mutation.af2_file, self.af2_plddt_dict[mutation.af2_seq_index])
            if mutation.crop:
                pae = pae[mutation.seq_start - 1: mutation.seq_end, mutation.seq_start - 1: mutation.seq_end]
        else:
            plddt_data = utils.get_plddt_from_af2(mutation.af2_file)
            pae = utils.get_confidence_from_af2file(mutation.af2_file, plddt_data)
            if mutation.crop:
                confidence_data = plddt_data[mutation.seq_start - 1: mutation.seq_end] / 100
                pae = pae[mutation.seq_start - 1: mutation.seq_end, mutation.seq_start - 1: mutation.seq_end]
            else:
                # Fix: confidence_data used to be left unassigned for
                # un-cropped mutations, which raised on the check below.
                confidence_data = plddt_data / 100
            if confidence_data.shape[0] != embed_data.shape[0]:
                warnings.warn(f'pLDDT {confidence_data.shape[0]} does not match embedding {embed_data.shape[0]}, '
                              f'pLDDT file: {mutation.af2_file}, '
                              f'ESM prefix: {mutation.ESM_prefix}')
                confidence_data = np.ones_like(embed_data[:, 0]) * 0.8
            features['plddt'] = confidence_data
        # Full-graph variant: keep the whole pairwise confidence matrix.
        features['edge_confidence'] = pae
    return features
|
|
|
def get(self, idx):
    """Return the mutation at ``idx`` as a dict of tensors padded to ``self.max_len``.

    The numpy features come from ``self.get_one_mutation(idx)``. When the
    sequence is shorter than ``self.max_len`` every per-residue tensor is
    padded at the end (with zeros, except ``x_mask`` which is padded with
    ``True``) and the padded positions are flagged ``True`` in
    ``x_padding_mask``. Sequences that are not shorter than ``self.max_len``
    are returned as they are (no padding, no truncation).

    Args:
        idx: Index of the sample. Used as a positional row index into
            ``self.data`` (through ``int(idx)``) and as an index into
            ``self.mutations``.

    Returns:
        dict with the keys ``x``, ``x_padding_mask``, ``x_mask``, ``x_alt``
        (all zeros, shaped like ``x``), ``pos``, ``edge_attr``,
        ``node_vec_attr``, ``y`` and ``score_mask``; ``plddt`` and
        ``edge_confidence`` are added when ``self.add_confidence`` is set.
    """
    features_np = self.get_one_mutation(idx)

    # Index embeddings stay integral; every other embedding type becomes float32.
    embed_dtype = torch.long if self.node_embedding_type == 'one-hot-idx' else torch.float32
    x = torch.from_numpy(features_np['embed_data']).to(embed_dtype)

    x_padding_mask = torch.zeros(self.max_len, dtype=torch.bool)
    coord_key = 'CB_coord' if self.use_cb else 'CA_coord'
    pos = torch.from_numpy(features_np[coord_key]).to(torch.float32)
    node_vec_attr = torch.from_numpy(features_np['nodes_vector']).to(torch.float32)
    edge_attr = torch.from_numpy(features_np['edge_attr']).to(torch.float32)
    # Only the first column of the embedding mask is kept.
    x_mask = torch.from_numpy(features_np['embed_data_mask'][:, 0]).to(torch.bool)
    if self.add_confidence:
        plddt = torch.from_numpy(features_np['plddt']).to(torch.float32)
        edge_confidence = torch.from_numpy(features_np['edge_confidence']).to(torch.float32)

    if x.shape[0] < self.max_len:
        pad = torch.nn.functional.pad
        max_len = self.max_len
        x_padding_mask[x.shape[0]:] = True
        # F.pad takes (before, after) pairs starting from the LAST dimension,
        # so the trailing pair(s) below are the ones that grow the leading axes.
        x = pad(x, (0, 0, 0, max_len - x.shape[0]))
        pos = pad(pos, (0, 0, 0, max_len - pos.shape[0]))
        node_vec_attr = pad(node_vec_attr, (0, 0, 0, 0, 0, max_len - node_vec_attr.shape[0]))
        edge_pad = max_len - edge_attr.shape[0]
        edge_attr = pad(edge_attr, (0, 0, 0, edge_pad, 0, edge_pad))
        x_mask = pad(x_mask, (0, max_len - x_mask.shape[0]), 'constant', True)
        if self.add_confidence:
            conf_pad = max_len - edge_confidence.shape[0]
            edge_confidence = pad(edge_confidence, (0, conf_pad, 0, conf_pad))
            plddt = pad(plddt, (0, max_len - plddt.shape[0]))

    row = int(idx)
    y_scores = self.data[self._y_columns].iloc[row]
    if len(self._y_mask_columns) > 0:
        mask_values = list(self.data[self._y_mask_columns].iloc[row])
    else:
        # No mask columns: every score counts (handled as "ones" below).
        mask_values = [None] * len(y_scores)

    def _as_float32(value):
        # Text cells hold ';'-separated numbers; anything else is a single scalar.
        parts = value.split(';') if isinstance(value, str) else [value]
        return np.array(parts).astype(np.float32)

    y = torch.zeros([len(utils.AA_DICT_HUMAN), len(y_scores)], dtype=torch.float32)
    y_mask = torch.zeros_like(y)
    # Rows of `y` to fill: one per alternative amino acid of this mutation.
    # The lookup does not depend on the score column, so it is done once.
    alt_aa_idxs = [utils.AA_DICT_HUMAN.index(aa) for aa in self.mutations[idx].alt_aa]

    # Iterate the row's values positionally. Integer keys on a Series that is
    # labelled by column name (`y_scores[i]`) rely on the positional fallback
    # of `Series.__getitem__`, which pandas has deprecated.
    for col, score in enumerate(y_scores):
        score_arr = _as_float32(score)
        mask = mask_values[col]
        mask_arr = np.ones_like(score_arr) if mask is None else _as_float32(mask)
        y[alt_aa_idxs, col] = torch.from_numpy(score_arr)
        y_mask[alt_aa_idxs, col] = torch.from_numpy(mask_arr)

    features = dict(
        x=x,
        x_padding_mask=x_padding_mask,
        x_mask=x_mask,
        x_alt=torch.zeros_like(x),
        pos=pos,
        edge_attr=edge_attr,
        node_vec_attr=node_vec_attr,
        y=y,
        score_mask=y_mask,
    )
    if self.add_confidence:
        features['plddt'] = plddt
        features['edge_confidence'] = edge_confidence
    return features
|
|
|
|
|
|
|
def collate(
    data_list: List[BaseData],
    increment: bool = True,
    add_batch: bool = True,
    follow_batch: List[str] = None,
    exclude_keys: List[str] = None,
    cls: Any = None,
) -> BaseData:
    """Collate a sequence of data objects into a single data object.

    The stores of all objects in ``data_list`` are grouped by their ``_key``
    and, for every attribute of the first store in each group, the values are
    merged with ``_collate`` and written to the matching store of the output.

    NOTE(review): ``_collate``, ``_batch_and_ptr``, ``repeat_interleave``,
    ``cumsum`` and ``NodeStorage`` must be resolvable at module scope; they
    are not among the imports visible at the top of this file -- confirm
    where they come from.

    Args:
        data_list: Objects to merge. Any iterable is accepted; it is
            materialised into a list first. Must contain at least one item.
        increment: Forwarded unchanged to ``_collate``.
        add_batch: When true, ``batch`` and ``ptr`` are added to every output
            store whose source store is a ``NodeStorage`` that can infer its
            number of nodes.
        follow_batch: Attribute names for which ``<attr>_batch`` and
            ``<attr>_ptr`` are additionally stored. ``None`` means none.
        exclude_keys: Attribute names that are left out of the output.
            ``None`` means none.
        cls: Class of the returned object. Defaults to the class of the first
            element of ``data_list``. When it differs from that class it is
            instantiated as ``cls(_base_cls=<class of first element>)``.

    Returns:
        The collated object (an instance of ``cls``).
    """
    # Local import: `defaultdict` is not among the imports visible at the top
    # of this file.
    from collections import defaultdict

    if not isinstance(data_list, (list, tuple)):
        # The input is indexed and iterated more than once below.
        data_list = list(data_list)

    element_cls = data_list[0].__class__
    if cls is None:
        # Previously `cls` was read without ever being defined in this function.
        cls = element_cls
    out = cls() if cls == element_cls else cls(_base_cls=element_cls)

    # Give the output the same stores as the first input object.
    out.stores_as(data_list[0])

    # These used to be read before assignment (UnboundLocalError); they are
    # now proper optional parameters.
    follow_batch = set(follow_batch or [])
    exclude_keys = set(exclude_keys or [])

    # {store key: [store of data_list[0], store of data_list[1], ...]}
    key_to_stores = defaultdict(list)
    for data in data_list:
        for store in data.stores:
            key_to_stores[store._key].append(store)

    device = None
    for out_store in out.stores:
        stores = key_to_stores[out_store._key]
        for attr in stores[0].keys():
            if attr in exclude_keys:
                continue

            values = [store[attr] for store in stores]

            # `num_nodes` is summed instead of collated; the per-object
            # counts are kept in `_num_nodes`.
            if attr == 'num_nodes':
                out_store._num_nodes = values
                out_store.num_nodes = sum(values)
                continue

            # `ptr` is not collated here; it is rebuilt below when
            # `add_batch` applies.
            if attr == 'ptr':
                continue

            # The third result (increments) was only stored in a dict that
            # was never returned, so it is discarded here.
            value, slices, _ = _collate(attr, values, data_list, stores,
                                        increment)

            # Remember a CUDA device so the batch/ptr tensors created below
            # are placed on it as well.
            if isinstance(value, torch.Tensor) and value.is_cuda:
                device = value.device

            out_store[attr] = value

            if attr in follow_batch:
                batch, ptr = _batch_and_ptr(slices, device)
                out_store[f'{attr}_batch'] = batch
                out_store[f'{attr}_ptr'] = ptr

        if (add_batch and isinstance(stores[0], NodeStorage)
                and stores[0].can_infer_num_nodes):
            repeats = [store.num_nodes for store in stores]
            out_store.batch = repeat_interleave(repeats, device=device)
            out_store.ptr = cumsum(torch.tensor(repeats, device=device))

    return out
|
|
|
|
|
def my_collate_fn(data_list: List[Any]) -> Any:
    """Collate ``data_list`` with ``collate`` and record how many items it held.

    ``collate`` is called with ``increment=True`` and ``add_batch=True``; the
    length of ``data_list`` is then stored on the result as ``_num_graphs``.

    Args:
        data_list: The samples to merge into one batch object.

    Returns:
        The object returned by ``collate``, with ``_num_graphs`` set.
    """
    merged = collate(data_list=data_list, increment=True, add_batch=True)
    merged._num_graphs = len(data_list)
    return merged