"""Fusion model components for drug-target interaction prediction.

Contains the top-level ``FusionDTI`` classifier, a wrapper around the two
pre-trained encoders, the CAN and BAN fusion layers, the MLP decoders and a
dataset that reads pre-saved batch files.
"""
import logging
import os
import sys

sys.path.append("../")

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import autocast
from torch.nn import Module
from torch.nn.utils.weight_norm import weight_norm
from torch.utils.data import Dataset

try:
    from tqdm import tqdm
except ImportError:
    # No definition visible in this module calls tqdm, so a missing progress
    # bar package should not stop the model classes from being importable.
    # NOTE(review): if code elsewhere relies on ``tqdm`` from here, install it.
    tqdm = None

LOGGER = logging.getLogger(__name__)


class FusionDTI(nn.Module):
    """Score a protein/drug pair from token-level embeddings.

    The fusion strategy is selected by ``args.fusion``:

    * ``"CAN"`` - project both inputs to 512 dims and fuse with ``CAN_Layer``.
    * ``"BAN"`` - project both inputs to 512 dims and fuse with ``BANLayer``.
    * ``"Nan"`` - no fusion layer: mean-pool each input over dim 1 and
      concatenate the pooled vectors.
    """

    def __init__(self, prot_out_dim, disease_out_dim, args):
        """Build the projection layers and the fusion-specific modules.

        Args:
            prot_out_dim: Feature size of the protein embeddings.
            disease_out_dim: Feature size of the drug embeddings.
            args: Namespace providing ``fusion`` (and, for ``"CAN"``, the
                attributes read by ``CAN_Layer``).
        """
        super().__init__()
        self.fusion = args.fusion
        self.drug_reg = nn.Linear(disease_out_dim, 512)
        self.prot_reg = nn.Linear(prot_out_dim, 512)

        if self.fusion == "CAN":
            self.can_layer = CAN_Layer(hidden_dim=512, num_heads=8, args=args)
            self.mlp_classifier = MlPdecoder_CAN(input_dim=1024)
        elif self.fusion == "BAN":
            self.ban_layer = weight_norm(BANLayer(512, 512, 256, 2), name='h_mat', dim=None)
            self.mlp_classifier = MlPdecoder_CAN(input_dim=256)
        elif self.fusion == "Nan":
            # NOTE(review): 1214 is hard-coded; presumably it equals
            # prot_out_dim + disease_out_dim for the encoders in use - confirm.
            self.mlp_classifier_nan = MlPdecoder_CAN(input_dim=1214)

    def forward(self, prot_embed, drug_embed, prot_mask, drug_mask):
        """Return ``(score, att)`` for a batch of protein/drug embeddings.

        Args:
            prot_embed: Protein embeddings; pooled or projected along the
                last dim, so expected as (batch, length, prot_out_dim).
            drug_embed: Drug embeddings, analogous to ``prot_embed``.
            prot_mask: Protein mask, only used by the ``"CAN"`` fusion.
            drug_mask: Drug mask, only used by the ``"CAN"`` fusion.

        Returns:
            A tuple ``(score, att)``. ``score`` is the sigmoid output of the
            classifier. ``att`` is the attention tensor produced by the
            fusion layer, or ``None`` for the ``"Nan"`` mode, which has no
            attention.

        Raises:
            ValueError: If ``self.fusion`` is not a supported mode.
        """
        if self.fusion == "Nan":
            pooled_prot = prot_embed.mean(1)  # [batch_size, hidden]
            pooled_drug = drug_embed.mean(1)  # [batch_size, hidden]
            joint_embed = torch.cat([pooled_prot, pooled_drug], dim=1)
            # There is no attention map in this mode; previously ``att`` was
            # left unbound here and the return statement raised.
            return self.mlp_classifier_nan(joint_embed), None

        if self.fusion not in ("CAN", "BAN"):
            raise ValueError(f"Unsupported fusion mode: {self.fusion!r}")

        prot_embed = self.prot_reg(prot_embed)
        drug_embed = self.drug_reg(drug_embed)

        if self.fusion == "CAN":
            joint_embed, att = self.can_layer(prot_embed, drug_embed, prot_mask, drug_mask)
        else:
            joint_embed, att = self.ban_layer(prot_embed, drug_embed)

        score = self.mlp_classifier(joint_embed)
        return score, att


class Pre_encoded(nn.Module):
    """Thin wrapper that runs the protein and drug encoders."""

    def __init__(self, prot_encoder, drug_encoder, args):
        """Constructor for the model.

        Args:
            prot_encoder: Protein structure-aware sequence encoder.
            drug_encoder: Drug SELFIES encoder.
            args: Not read by this class.
        """
        super().__init__()
        self.prot_encoder = prot_encoder
        self.drug_encoder = drug_encoder

    def encoding(self, prot_input_ids, prot_attention_mask, drug_input_ids, drug_attention_mask):
        """Encode both inputs and return ``(prot_embed, drug_embed)``.

        The protein embedding is the ``.logits`` attribute of the protein
        encoder output; the drug embedding is the ``.last_hidden_state``
        attribute of the drug encoder output.
        """
        prot_embed = self.prot_encoder(
            input_ids=prot_input_ids,
            attention_mask=prot_attention_mask,
            return_dict=True,
        ).logits

        drug_embed = self.drug_encoder(
            input_ids=drug_input_ids,
            attention_mask=drug_attention_mask,
            return_dict=True,
        ).last_hidden_state

        return prot_embed, drug_embed


class CAN_Layer(nn.Module):
    """Multi-head attention fusion over grouped protein and drug tokens.

    Both sequences are first averaged in groups of ``args.group_size``
    tokens. Each modality then attends to itself and to the other modality,
    and the two results are averaged.
    """

    def __init__(self, hidden_dim, num_heads, args):
        """Create the query/key/value projections.

        Args:
            hidden_dim: Feature size of both inputs.
            num_heads: Number of attention heads; ``hidden_dim`` is split
                into ``hidden_dim // num_heads`` channels per head.
            args: Namespace providing ``agg_mode`` and ``group_size``.
        """
        super().__init__()
        self.agg_mode = args.agg_mode
        self.group_size = args.group_size  # Control Fusion Scale
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_size = hidden_dim // num_heads

        self.query_p = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key_p = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value_p = nn.Linear(hidden_dim, hidden_dim, bias=False)

        self.query_d = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key_d = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value_d = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def alpha_logits(self, logits, mask_row, mask_col, inf=1e6):
        """Turn raw attention logits into masked attention weights.

        Args:
            logits: Tensor of shape (N, L1, L2, H).
            mask_row: Mask over the L1 axis, viewable as (N, L1).
            mask_col: Mask over the L2 axis, viewable as (N, L2).
            inf: Value subtracted from logits of masked-out pairs.

        Returns:
            Tensor of shape (N, L1, L2, H), softmax-normalised over L2 and
            zeroed on rows where ``mask_row`` is false.
        """
        # NOTE(review): masks are used as ``torch.where`` conditions, so they
        # are presumably boolean tensors - confirm against the caller.
        N, L1, L2, H = logits.shape
        mask_row = mask_row.view(N, L1, 1).repeat(1, 1, H)
        mask_col = mask_col.view(N, L2, 1).repeat(1, 1, H)
        mask_pair = torch.einsum('blh, bkh->blkh', mask_row, mask_col)

        logits = torch.where(mask_pair, logits, logits - inf)
        alpha = torch.softmax(logits, dim=2)
        mask_row = mask_row.view(N, L1, 1, H).repeat(1, 1, L2, 1)
        alpha = torch.where(mask_row, alpha, torch.zeros_like(alpha))
        return alpha

    def apply_heads(self, x, n_heads, n_ch):
        """Reshape the last dim of ``x`` into ``(n_heads, n_ch)``."""
        s = list(x.size())[:-1] + [n_heads, n_ch]
        return x.view(*s)

    def group_embeddings(self, x, mask, group_size):
        """Average consecutive tokens in groups of ``group_size``.

        Args:
            x: Tensor of shape (N, L, D). ``L`` must be divisible by
                ``group_size`` for the ``view`` calls to succeed.
            mask: Mask viewable as (N, L).
            group_size: Number of consecutive tokens per group.

        Returns:
            ``(x_grouped, mask_grouped)`` with shapes (N, L // group_size, D)
            and (N, L // group_size). A group's mask is true when any token
            in the group is unmasked.
        """
        N, L, D = x.shape
        groups = L // group_size
        x_grouped = x.view(N, groups, group_size, D).mean(dim=2)
        mask_grouped = mask.view(N, groups, group_size).any(dim=2)
        return x_grouped, mask_grouped

    def forward(self, protein, drug, mask_prot, mask_drug):
        """Fuse protein and drug embeddings.

        Args:
            protein: Protein embeddings of shape (N, Lp, hidden_dim).
            drug: Drug embeddings of shape (N, Ld, hidden_dim).
            mask_prot: Protein mask viewable as (N, Lp).
            mask_drug: Drug mask viewable as (N, Ld).

        Returns:
            ``(query_embed, att)``. ``query_embed`` is the concatenation of
            the aggregated protein and drug vectors along dim 1. ``att`` has
            shape (N, 1, P + Q, P + Q), where P and Q are the grouped protein
            and drug lengths; it holds the head-averaged attention weights in
            four quadrants (protein->protein, protein->drug, drug->protein,
            drug->drug).

        Raises:
            NotImplementedError: If ``agg_mode`` is not one of ``"cls"``,
                ``"mean_all_tok"`` or ``"mean"``.
        """
        # Group embeddings before applying multi-head attention
        protein_grouped, mask_prot_grouped = self.group_embeddings(protein, mask_prot, self.group_size)
        drug_grouped, mask_drug_grouped = self.group_embeddings(drug, mask_drug, self.group_size)

        # Compute queries, keys, values for both protein and drug after grouping
        query_prot = self.apply_heads(self.query_p(protein_grouped), self.num_heads, self.head_size)
        key_prot = self.apply_heads(self.key_p(protein_grouped), self.num_heads, self.head_size)
        value_prot = self.apply_heads(self.value_p(protein_grouped), self.num_heads, self.head_size)

        query_drug = self.apply_heads(self.query_d(drug_grouped), self.num_heads, self.head_size)
        key_drug = self.apply_heads(self.key_d(drug_grouped), self.num_heads, self.head_size)
        value_drug = self.apply_heads(self.value_d(drug_grouped), self.num_heads, self.head_size)

        # Compute attention scores
        logits_pp = torch.einsum('blhd, bkhd->blkh', query_prot, key_prot)
        logits_pd = torch.einsum('blhd, bkhd->blkh', query_prot, key_drug)
        logits_dp = torch.einsum('blhd, bkhd->blkh', query_drug, key_prot)
        logits_dd = torch.einsum('blhd, bkhd->blkh', query_drug, key_drug)

        alpha_pp = self.alpha_logits(logits_pp, mask_prot_grouped, mask_prot_grouped)
        alpha_pd = self.alpha_logits(logits_pd, mask_prot_grouped, mask_drug_grouped)
        alpha_dp = self.alpha_logits(logits_dp, mask_drug_grouped, mask_prot_grouped)
        alpha_dd = self.alpha_logits(logits_dd, mask_drug_grouped, mask_drug_grouped)

        # Each modality averages its self-attended and cross-attended values.
        prot_embedding = (
            torch.einsum('blkh, bkhd->blhd', alpha_pp, value_prot).flatten(-2)
            + torch.einsum('blkh, bkhd->blhd', alpha_pd, value_drug).flatten(-2)
        ) / 2
        drug_embedding = (
            torch.einsum('blkh, bkhd->blhd', alpha_dp, value_prot).flatten(-2)
            + torch.einsum('blkh, bkhd->blhd', alpha_dd, value_drug).flatten(-2)
        ) / 2

        # Continue as usual with the aggregation mode
        if self.agg_mode == "cls":
            prot_embed = prot_embedding[:, 0]  # query : [batch_size, hidden]
            drug_embed = drug_embedding[:, 0]  # query : [batch_size, hidden]
        elif self.agg_mode == "mean_all_tok":
            prot_embed = prot_embedding.mean(1)  # query : [batch_size, hidden]
            drug_embed = drug_embedding.mean(1)  # query : [batch_size, hidden]
        elif self.agg_mode == "mean":
            prot_embed = (prot_embedding * mask_prot_grouped.unsqueeze(-1)).sum(1) / mask_prot_grouped.sum(-1).unsqueeze(-1)
            drug_embed = (drug_embedding * mask_drug_grouped.unsqueeze(-1)).sum(1) / mask_drug_grouped.sum(-1).unsqueeze(-1)
        else:
            raise NotImplementedError()

        query_embed = torch.cat([prot_embed, drug_embed], dim=1)

        # Assemble the head-averaged attention maps into one square canvas.
        # The canvas is sized from the actual batch and grouped lengths; the
        # earlier fixed (1, 1, 1024, 1024) canvas with 512 splits only worked
        # for batch size 1 and grouped lengths of 512. For that case the
        # result is identical. As before, the canvas is created with
        # ``torch.zeros`` defaults (device and dtype).
        batch_size = prot_embedding.size(0)
        n_prot = prot_embedding.size(1)
        n_drug = drug_embedding.size(1)
        total = n_prot + n_drug
        att = torch.zeros(batch_size, 1, total, total)
        att[:, :, :n_prot, :n_prot] = alpha_pp.mean(dim=-1).unsqueeze(1)  # Protein to Protein
        att[:, :, :n_prot, n_prot:] = alpha_pd.mean(dim=-1).unsqueeze(1)  # Protein to Drug
        att[:, :, n_prot:, :n_prot] = alpha_dp.mean(dim=-1).unsqueeze(1)  # Drug to Protein
        att[:, :, n_prot:, n_prot:] = alpha_dd.mean(dim=-1).unsqueeze(1)  # Drug to Drug

        return query_embed, att


class MlPdecoder_CAN(nn.Module):
    """Three hidden layers (ReLU then BatchNorm) and a sigmoid output unit.

    Layer widths are ``input_dim``, ``input_dim // 2`` and
    ``input_dim // 4``.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, input_dim)
        self.bn1 = nn.BatchNorm1d(input_dim)
        self.fc2 = nn.Linear(input_dim, input_dim // 2)
        self.bn2 = nn.BatchNorm1d(input_dim // 2)
        self.fc3 = nn.Linear(input_dim // 2, input_dim // 4)
        self.bn3 = nn.BatchNorm1d(input_dim // 4)
        self.output = nn.Linear(input_dim // 4, 1)

    def forward(self, x):
        """Map ``x`` of shape (batch, input_dim) to scores of shape (batch, 1)."""
        x = self.bn1(torch.relu(self.fc1(x)))
        x = self.bn2(torch.relu(self.fc2(x)))
        x = self.bn3(torch.relu(self.fc3(x)))
        x = torch.sigmoid(self.output(x))
        return x


class MLPdecoder_BAN(nn.Module):
    """MLP decoder with explicit hidden/output widths and a sigmoid output."""

    def __init__(self, in_dim, hidden_dim, out_dim, binary=1):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, out_dim)
        self.bn3 = nn.BatchNorm1d(out_dim)
        self.fc4 = nn.Linear(out_dim, binary)

    def forward(self, x):
        """Map ``x`` of shape (batch, in_dim) to scores of shape (batch, binary)."""
        x = self.bn1(F.relu(self.fc1(x)))
        x = self.bn2(F.relu(self.fc2(x)))
        x = self.bn3(F.relu(self.fc3(x)))
        x = torch.sigmoid(self.fc4(x))
        return x


class BANLayer(nn.Module):
    """ Bilinear attention network
    Modified from https://github.com/peizhenbai/DrugBAN/blob/main/ban.py
    """

    def __init__(self, v_dim, q_dim, h_dim, h_out, act='ReLU', dropout=0.2, k=3):
        """Build the bilinear attention layer.

        Args:
            v_dim: Feature size of the first input.
            q_dim: Feature size of the second input.
            h_dim: Size of the fused output vector.
            h_out: Number of attention maps (heads).
            act: Name of the ``torch.nn`` activation used inside ``FCNet``.
            dropout: Dropout probability used inside ``FCNet``.
            k: Pooling factor; both inputs are projected to ``h_dim * k``
                features and pooled back to ``h_dim``.
        """
        super().__init__()

        self.c = 32
        self.k = k
        self.v_dim = v_dim
        self.q_dim = q_dim
        self.h_dim = h_dim
        self.h_out = h_out

        self.v_net = FCNet([v_dim, h_dim * self.k], act=act, dropout=dropout)
        self.q_net = FCNet([q_dim, h_dim * self.k], act=act, dropout=dropout)
        if 1 < k:
            self.p_net = nn.AvgPool1d(self.k, stride=self.k)

        if h_out <= self.c:
            self.h_mat = nn.Parameter(torch.Tensor(1, h_out, 1, h_dim * self.k).normal_())
            self.h_bias = nn.Parameter(torch.Tensor(1, h_out, 1, 1).normal_())
        else:
            self.h_net = weight_norm(nn.Linear(h_dim * self.k, h_out), dim=None)

        self.bn = nn.BatchNorm1d(h_dim)

    def attention_pooling(self, v, q, att_map):
        """Pool ``v`` and ``q`` under one attention map into a (b, h_dim) vector."""
        fusion_logits = torch.einsum('bvk,bvq,bqk->bk', (v, att_map, q))
        if 1 < self.k:
            fusion_logits = fusion_logits.unsqueeze(1)  # b x 1 x d
            # AvgPool1d averages over k features; multiplying by k makes it a sum.
            fusion_logits = self.p_net(fusion_logits).squeeze(1) * self.k  # sum-pooling
        return fusion_logits

    def forward(self, v, q, softmax=False):
        """Compute the fused representation and the attention maps.

        Args:
            v: First input, shape (b, v_num, v_dim).
            q: Second input, shape (b, q_num, q_dim).
            softmax: If true, normalise each attention map with a softmax
                over all ``v_num * q_num`` positions.

        Returns:
            ``(logits, att_maps)`` where ``logits`` is the batch-normalised
            sum of the pooled vectors over all ``h_out`` maps and
            ``att_maps`` has shape (b, h_out, v_num, q_num).
        """
        v_num = v.size(1)
        q_num = q.size(1)
        if self.h_out <= self.c:
            v_ = self.v_net(v)
            q_ = self.q_net(q)
            att_maps = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias
        else:
            # NOTE(review): in this branch v_ and q_ are 4-D, which does not
            # match the 3-D subscripts in attention_pooling - confirm this
            # path is exercised before relying on it.
            v_ = self.v_net(v).transpose(1, 2).unsqueeze(3)
            q_ = self.q_net(q).transpose(1, 2).unsqueeze(2)
            d_ = torch.matmul(v_, q_)  # b x h_dim x v x q
            att_maps = self.h_net(d_.transpose(1, 2).transpose(2, 3))  # b x v x q x h_out
            att_maps = att_maps.transpose(2, 3).transpose(1, 2)  # b x h_out x v x q
        if softmax:
            p = nn.functional.softmax(att_maps.view(-1, self.h_out, v_num * q_num), 2)
            att_maps = p.view(-1, self.h_out, v_num, q_num)
        logits = self.attention_pooling(v_, q_, att_maps[:, 0, :, :])
        for i in range(1, self.h_out):
            logits_i = self.attention_pooling(v_, q_, att_maps[:, i, :, :])
            logits += logits_i
        logits = self.bn(logits)
        return logits, att_maps


class FCNet(nn.Module):
    """Simple class for non-linear fully connect network
    Modified from https://github.com/jnhwkim/ban-vqa/blob/master/fc.py
    """

    def __init__(self, dims, act='ReLU', dropout=0):
        """Stack weight-normalised linear layers.

        Args:
            dims: Layer widths; consecutive pairs define the linear layers.
            act: Name of a ``torch.nn`` activation class appended after each
                linear layer, or ``''`` for none.
            dropout: Dropout probability applied before each linear layer
                when greater than zero.
        """
        super().__init__()

        layers = []
        # All layers except the last one.
        for i in range(len(dims) - 2):
            in_dim = dims[i]
            out_dim = dims[i + 1]
            if 0 < dropout:
                layers.append(nn.Dropout(dropout))
            layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None))
            if '' != act:
                layers.append(getattr(nn, act)())
        # Final layer.
        if 0 < dropout:
            layers.append(nn.Dropout(dropout))
        layers.append(weight_norm(nn.Linear(dims[-2], dims[-1]), dim=None))
        if '' != act:
            layers.append(getattr(nn, act)())

        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        return self.main(x)


class BatchFileDataset_Case(Dataset):
    """Dataset whose items are whole batches stored one per file."""

    def __init__(self, file_list):
        """Store the list of batch file paths."""
        self.file_list = file_list

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load batch file ``idx`` and return its fields as a tuple.

        Returns:
            ``(prot, drug, prot_ids, drug_ids, prot_mask, drug_mask, y)``
            read from the loaded dictionary.
        """
        batch_file = self.file_list[idx]
        # NOTE(security): torch.load unpickles the file and can execute
        # arbitrary code; only point this dataset at trusted batch files.
        data = torch.load(batch_file)
        return (
            data['prot'],
            data['drug'],
            data['prot_ids'],
            data['drug_ids'],
            data['prot_mask'],
            data['drug_mask'],
            data['y'],
        )