import rootutils
import torch
from torch import nn
from torch.nn import BatchNorm1d, Linear, Module, ReLU, Sequential
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter
# setup root dir and pythonpath
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from src.data.components.pinder_dataset import PinderDataset
from src.models.components.utils import (
compute_euler_angles_from_rotation_matrices,
compute_rotation_matrix_from_ortho6d,
)
class EquivariantMPNNLayer(MessagePassing):
def __init__(self, emb_dim=64, out_dim=128, aggr="add"):
r"""Message Passing Neural Network Layer
This layer is equivariant to 3D rotations and translations.
Args:
emb_dim: (int) - hidden dimension d
edge_dim: (int) - edge feature dimension d_e
aggr: (str) - aggregation function \oplus (sum/mean/max)
"""
# Set the aggregation function
super().__init__(aggr=aggr)
self.emb_dim = emb_dim
        # Message MLP: maps the concatenated (h_i, h_j, dist_ij) of size 2d + 1 to a d-dimensional message
self.mlp_msg = Sequential(
Linear(2 * emb_dim + 1, emb_dim),
BatchNorm1d(emb_dim),
ReLU(),
Linear(emb_dim, emb_dim),
BatchNorm1d(emb_dim),
ReLU(),
)
        self.mlp_pos = Sequential(
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(), Linear(emb_dim, 1)
        )  # MLP \psi: predicts a scalar weight for each relative position vector
self.mlp_upd = Sequential(
Linear(2 * emb_dim, emb_dim),
BatchNorm1d(emb_dim),
ReLU(),
Linear(emb_dim, emb_dim),
BatchNorm1d(emb_dim),
ReLU(),
        )  # MLP \phi: node feature update function
        # Final projection of the updated node features to the output dimension
        self.lin_out = Linear(emb_dim, out_dim)
def forward(self, data):
"""
The forward pass updates node features h via one round of message passing.
Args:
h: (n, d) - initial node features
pos: (n, 3) - initial node coordinates
edge_index: (e, 2) - pairs of edges (i, j)
edge_attr: (e, d_e) - edge features
Returns:
out: [(n, d),(n,3)] - updated node features
"""
#
h, pos, edge_index = data
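        # propagate() internally calls message(), aggregate() and update()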
h_out, pos_out = self.propagate(edge_index=edge_index, h=h, pos=pos)
h_out = self.lin_out(h_out)
return h_out, pos_out, edge_index
    def message(self, h_i, h_j, pos_i, pos_j):
        # Euclidean distance between nodes i and j
        pos_diff = pos_i - pos_j  # (e, 3)
        dists = torch.norm(pos_diff, dim=-1, keepdim=True)  # (e, 1)
        # Concatenate the two node features with the inter-node distance
        msg = torch.cat([h_i, h_j, dists], dim=-1)  # (e, 2d + 1)
        msg = self.mlp_msg(msg)  # (e, d)
        # Scale the relative position by a learned scalar weight
        pos_diff = pos_diff * self.mlp_pos(msg)  # (e, 3)
        return msg, pos_diff
def aggregate(self, inputs, index):
"""The aggregate function aggregates the messages from neighboring nodes,
according to the chosen aggregation function ('sum' by default).
Args:
inputs: (e, d) - messages m_ij from destination to source nodes
index: (e, 1) - list of source nodes for each edge/message in input
Returns:
aggr_out: (n, d) - aggregated messages m_i
"""
msgs, pos_diffs = inputs
msg_aggr = scatter(msgs, index, dim=self.node_dim, reduce=self.aggr)
pos_aggr = scatter(pos_diffs, index, dim=self.node_dim, reduce="mean")
return msg_aggr, pos_aggr
    def update(self, aggr_out, h, pos):
        msg_aggr, pos_aggr = aggr_out
        # Update node features from the previous features and the aggregated messages
        upd_out = self.mlp_upd(torch.cat((h, msg_aggr), dim=-1))
        # Update node coordinates with the aggregated (weighted) relative positions
        upd_pos = pos + pos_aggr
        return upd_out, upd_pos
def __repr__(self) -> str:
return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})"
class PinderMPNNModel(Module):
    def __init__(self, input_dim=1, emb_dim=64, num_heads=4):
        """Message Passing Neural Network model that predicts receptor and ligand
        rigid-body transformations (rotations and translations).

        The model uses both node features and coordinates as inputs; the constituent
        MPNN layers are equivariant to 3D rotations and translations.

        Args:
            input_dim: (int) - initial node feature dimension d_n
            emb_dim: (int) - hidden dimension d
            num_heads: (int) - number of cross-attention heads (must evenly divide
                the 256-dimensional attention embedding)
        """
super().__init__()
# Linear projection for initial node features
self.lin_in_rec = Linear(input_dim, emb_dim)
self.lin_in_lig = Linear(input_dim, emb_dim)
# Stack of MPNN layers
self.receptor_mpnn = Sequential(
EquivariantMPNNLayer(emb_dim, 128, aggr="mean"),
EquivariantMPNNLayer(128, 256, aggr="mean"),
# EquivariantMPNNLayer(256, 512, aggr="mean"),
# EquivariantMPNNLayer(512, 512, aggr="mean"),
)
self.ligand_mpnn = Sequential(
EquivariantMPNNLayer(64, 128, aggr="mean"),
EquivariantMPNNLayer(128, 256, aggr="mean"),
# EquivariantMPNNLayer(256, 512, aggr="mean"),
# EquivariantMPNNLayer(512, 512, aggr="mean"),
)
# Cross-attention layer
self.rec_cross_attention = nn.MultiheadAttention(256, num_heads, batch_first=True)
self.lig_cross_attention = nn.MultiheadAttention(256, num_heads, batch_first=True)
# MLPs for translation prediction
self.fc_translation_rec = nn.Linear(256 + 3, 3)
self.fc_translation_lig = nn.Linear(256 + 3, 3)
def forward(self, batch):
"""
The main forward pass of the model.
Args:
batch: Same as in forward_rot_trans.
Returns:
transformed_ligands: List of tensors, each of shape (1, num_ligand_atoms, 3)
representing the transformed ligand coordinates after applying the predicted
rotation and translation.
"""
h_receptor = self.lin_in_rec(batch["receptor"].x)
h_ligand = self.lin_in_lig(batch["ligand"].x)
pos_receptor = batch["receptor"].pos
pos_ligand = batch["ligand"].pos
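        # 2) Run the equivariant MPNN stacks on the receptor and ligand graphs separately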
h_receptor, pos_receptor, _ = self.receptor_mpnn(
(h_receptor, pos_receptor, batch["receptor", "receptor"].edge_index)
)
h_ligand, pos_ligand, _ = self.ligand_mpnn(
(h_ligand, pos_ligand, batch["ligand", "ligand"].edge_index)
)
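        # 3) Exchange information between receptor and ligand via cross-attention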
attn_output_rec, _ = self.rec_cross_attention(h_receptor, h_ligand, h_ligand)
attn_output_lig, _ = self.lig_cross_attention(h_ligand, h_receptor, h_receptor)
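        # 4) Predict per-node translation vectors from attended features and coordinates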
emb_features_receptor = torch.cat((attn_output_rec, pos_receptor), dim=-1)
emb_features_ligand = torch.cat((attn_output_lig, pos_ligand), dim=-1)
translation_vector_r = self.fc_translation_rec(emb_features_receptor)
translation_vector_l = self.fc_translation_lig(emb_features_ligand)
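        # 5) Predict rotations: 6D-parameterised rotation matrices, converted to Euler angles in degrees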
        rot_mat_rec = compute_rotation_matrix_from_ortho6d(attn_output_rec)
        rot_mat_lig = compute_rotation_matrix_from_ortho6d(attn_output_lig)
        receptor_coords = (
            compute_euler_angles_from_rotation_matrices(rot_mat_rec) * 180 / torch.pi
        )
        ligand_coords = compute_euler_angles_from_rotation_matrices(rot_mat_lig) * 180 / torch.pi
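        # 6) Combine the predicted Euler angles and translation vectors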
receptor_coords = receptor_coords + translation_vector_r
ligand_coords = ligand_coords + translation_vector_l
return receptor_coords, ligand_coords
if __name__ == "__main__":
file_paths = ["./data/processed/apo/test/1a19__A1_P11540--1a19__B1_P11540.pt"]
dataset = PinderDataset(file_paths=file_paths * 3)
loader = DataLoader(dataset, batch_size=3, shuffle=False)
batch = next(iter(loader))
model = PinderMPNNModel()
print("Number of parameters:", sum(p.numel() for p in model.parameters()))
receptor_coords, ligand_coords = model(batch)
print(receptor_coords.shape)
print(ligand_coords.shape)
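    # --- Optional sanity check (illustrative sketch, not part of the original pipeline) ---
    # One way to probe the claimed rotation equivariance of a single EquivariantMPNNLayer
    # is to compare its outputs on the same random graph before and after rotating the
    # coordinates. The random tensors, ring graph, and tolerances below are assumptions
    # made for illustration only.
    layer = EquivariantMPNNLayer(emb_dim=64, out_dim=64, aggr="mean").eval()
    h = torch.randn(10, 64)
    pos = torch.randn(10, 3)
    src = torch.arange(10)
    dst = (src + 1) % 10
    edge_index = torch.stack([torch.cat([src, dst]), torch.cat([dst, src])])  # bidirectional ring
    rot, _ = torch.linalg.qr(torch.randn(3, 3))  # random orthogonal matrix
    h_out, pos_out, _ = layer((h, pos, edge_index))
    h_rot, pos_rot, _ = layer((h, pos @ rot.T, edge_index))
    print("features invariant:     ", torch.allclose(h_out, h_rot, atol=1e-4))
    print("coordinates equivariant:", torch.allclose(pos_out @ rot.T, pos_rot, atol=1e-4))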