import rootutils
import torch
from torch import nn
from torch.nn import BatchNorm1d, Linear, Module, ReLU, Sequential
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter

# set up the project root and PYTHONPATH
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from src.data.components.pinder_dataset import PinderDataset
from src.models.components.utils import (
    compute_euler_angles_from_rotation_matrices,
    compute_rotation_matrix_from_ortho6d,
)


class EquivariantMPNNLayer(MessagePassing):
    def __init__(self, emb_dim=64, out_dim=128, aggr="add"):
        r"""Message passing neural network layer that is equivariant to 3D
        rotations and translations.

        Args:
            emb_dim: (int) - hidden dimension d
            out_dim: (int) - output node feature dimension
            aggr: (str) - aggregation function \oplus (add/mean/max)
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim

        # MLP for computing messages from invariant inputs (features + distance)
        self.mlp_msg = Sequential(
            Linear(2 * emb_dim + 1, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
            Linear(emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
        )
        # MLP \psi: predicts a scalar weight for each relative position
        self.mlp_pos = Sequential(
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(), Linear(emb_dim, 1)
        )
        # MLP \phi: updates node features from the aggregated messages
        self.mlp_upd = Sequential(
            Linear(2 * emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
            Linear(emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
        )
        self.lin_out = Linear(emb_dim, out_dim)

    def forward(self, data):
        """Update node features and coordinates via one round of message passing.

        Args:
            data: tuple of
                h: (n, d) - initial node features
                pos: (n, 3) - initial node coordinates
                edge_index: (2, e) - pairs of edges (i, j)

        Returns:
            h_out: (n, out_dim) - updated node features
            pos_out: (n, 3) - updated node coordinates
            edge_index: (2, e) - passed through unchanged so layers can be stacked
        """
        h, pos, edge_index = data
        h_out, pos_out = self.propagate(edge_index=edge_index, h=h, pos=pos)
        h_out = self.lin_out(h_out)
        return h_out, pos_out, edge_index

    def message(self, h_i, h_j, pos_i, pos_j):
        # Relative positions and their Euclidean norms; distances are invariant
        # to rotations and translations, so the message MLP sees only invariants
        pos_diff = pos_i - pos_j  # (e, 3)
        dists = torch.norm(pos_diff, dim=-1, keepdim=True)  # (e, 1)

        # Concatenate source/destination node features with the distance
        msg = torch.cat([h_i, h_j, dists], dim=-1)  # (e, 2d + 1)
        msg = self.mlp_msg(msg)  # (e, d)
        # Scale the relative positions by a learned scalar weight
        pos_diff = pos_diff * self.mlp_pos(msg)  # (e, 3)
        return msg, pos_diff

    def aggregate(self, inputs, index):
        """Aggregate the messages from neighboring nodes, according to the
        chosen aggregation function ('add' by default).

        Args:
            inputs: ((e, d), (e, 3)) - messages m_ij and scaled position differences
            index: (e,) - destination node of each edge/message in inputs

        Returns:
            aggr_out: ((n, d), (n, 3)) - aggregated messages and position updates
        """
        msgs, pos_diffs = inputs
        msg_aggr = scatter(msgs, index, dim=self.node_dim, reduce=self.aggr)
        # Position updates are averaged to keep them stable w.r.t. node degree
        pos_aggr = scatter(pos_diffs, index, dim=self.node_dim, reduce="mean")
        return msg_aggr, pos_aggr
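
    # Worked example of the scatter call above (illustrative; dim 0 matches
    # self.node_dim for these 2-D inputs). Three messages, the first two sent
    # to node 0 and the third to node 1:
    #   scatter(torch.tensor([[1.0], [2.0], [4.0]]), torch.tensor([0, 0, 1]),
    #           dim=0, reduce="add")
    #   -> tensor([[3.0], [4.0]])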

    def update(self, aggr_out, h, pos):
        # Update features via MLP \phi; update positions residually by adding
        # the aggregated position deltas
        msg_aggr, pos_aggr = aggr_out
        upd_out = self.mlp_upd(torch.cat((h, msg_aggr), dim=-1))
        upd_pos = pos + pos_aggr
        return upd_out, upd_pos

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})"


class PinderMPNNModel(Module):
    def __init__(self, input_dim=1, emb_dim=64, num_heads=4):
        """Message passing neural network model for protein-protein docking
        pose prediction.

        The model uses both node features and coordinates as inputs; the
        constituent MPNN layers are equivariant to 3D rotations and
        translations.

        Args:
            input_dim: (int) - initial node feature dimension d_n
            emb_dim: (int) - hidden dimension d
            num_heads: (int) - number of cross-attention heads; must divide the
                256-dimensional attention embedding
        """
        super().__init__()

        # Linear projections for the initial node features
        self.lin_in_rec = Linear(input_dim, emb_dim)
        self.lin_in_lig = Linear(input_dim, emb_dim)

        # Stacks of MPNN layers
        self.receptor_mpnn = Sequential(
            EquivariantMPNNLayer(emb_dim, 128, aggr="mean"),
            EquivariantMPNNLayer(128, 256, aggr="mean"),
        )
        self.ligand_mpnn = Sequential(
            EquivariantMPNNLayer(emb_dim, 128, aggr="mean"),
            EquivariantMPNNLayer(128, 256, aggr="mean"),
        )

        # Cross-attention layers between receptor and ligand embeddings
        self.rec_cross_attention = nn.MultiheadAttention(256, num_heads, batch_first=True)
        self.lig_cross_attention = nn.MultiheadAttention(256, num_heads, batch_first=True)

        # MLPs for translation prediction
        self.fc_translation_rec = nn.Linear(256 + 3, 3)
        self.fc_translation_lig = nn.Linear(256 + 3, 3)

    def forward(self, batch):
        """The main forward pass of the model.

        Args:
            batch: heterogeneous graph batch with "receptor" and "ligand" node
                stores (node features .x and coordinates .pos) and the
                corresponding intra-protein edge indices.

        Returns:
            receptor_coords: (num_receptor_atoms, 3) - predicted per-atom Euler
                angles (in degrees) plus the predicted translation.
            ligand_coords: (num_ligand_atoms, 3) - the same quantities for the
                ligand.
        """
        # Project input features and run the per-protein MPNN stacks
        h_receptor = self.lin_in_rec(batch["receptor"].x)
        h_ligand = self.lin_in_lig(batch["ligand"].x)

        pos_receptor = batch["receptor"].pos
        pos_ligand = batch["ligand"].pos

        h_receptor, pos_receptor, _ = self.receptor_mpnn(
            (h_receptor, pos_receptor, batch["receptor", "receptor"].edge_index)
        )
        h_ligand, pos_ligand, _ = self.ligand_mpnn(
            (h_ligand, pos_ligand, batch["ligand", "ligand"].edge_index)
        )

        # Exchange information between the two proteins via cross-attention
        attn_output_rec, _ = self.rec_cross_attention(h_receptor, h_ligand, h_ligand)
        attn_output_lig, _ = self.lig_cross_attention(h_ligand, h_receptor, h_receptor)

        # Predict translations from the attended features and coordinates
        emb_features_receptor = torch.cat((attn_output_rec, pos_receptor), dim=-1)
        emb_features_ligand = torch.cat((attn_output_lig, pos_ligand), dim=-1)
        translation_vector_r = self.fc_translation_rec(emb_features_receptor)
        translation_vector_l = self.fc_translation_lig(emb_features_ligand)

        # Predict rotations via the continuous 6D representation, then convert
        # them to Euler angles in degrees
        rot_mat_rec = compute_rotation_matrix_from_ortho6d(attn_output_rec)
        rot_mat_lig = compute_rotation_matrix_from_ortho6d(attn_output_lig)
        receptor_coords = (
            compute_euler_angles_from_rotation_matrices(rot_mat_rec) * 180 / torch.pi
        )
        ligand_coords = compute_euler_angles_from_rotation_matrices(rot_mat_lig) * 180 / torch.pi

        receptor_coords = receptor_coords + translation_vector_r
        ligand_coords = ligand_coords + translation_vector_l
        return receptor_coords, ligand_coords
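

# Illustrative helper (an assumption about downstream use, not part of the
# model): given a predicted (3, 3) rotation matrix and a (3,) translation,
# apply the rigid transform to an (n, 3) coordinate tensor using the
# row-vector convention.
def apply_rigid_transform(pos: torch.Tensor, rot: torch.Tensor, trans: torch.Tensor) -> torch.Tensor:
    """Rotate then translate coordinates: pos @ rot.T + trans."""
    return pos @ rot.T + trans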


if __name__ == "__main__":
    file_paths = ["./data/processed/apo/test/1a19__A1_P11540--1a19__B1_P11540.pt"]
    dataset = PinderDataset(file_paths=file_paths * 3)
    loader = DataLoader(dataset, batch_size=3, shuffle=False)
    batch = next(iter(loader))

    model = PinderMPNNModel()
    print("Number of parameters:", sum(p.numel() for p in model.parameters()))

    receptor_coords, ligand_coords = model(batch)
    print(receptor_coords.shape)
    print(ligand_coords.shape)
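
    # Additional smoke test (a sketch, useful when the preprocessed .pt file
    # above is unavailable; in that case move this block before the dataset
    # demo): build a synthetic HeteroData batch with the same node stores and
    # edge types as forward() expects. Shapes are assumptions matching
    # input_dim=1; the ring graph guarantees every node receives a message.
    from torch_geometric.data import HeteroData

    def make_synthetic_pair(num_nodes: int = 16) -> HeteroData:
        data = HeteroData()
        src = torch.arange(num_nodes)
        dst = (src + 1) % num_nodes
        for store in ("receptor", "ligand"):
            data[store].x = torch.randn(num_nodes, 1)
            data[store].pos = torch.randn(num_nodes, 3)
            data[store, store].edge_index = torch.stack(
                [torch.cat([src, dst]), torch.cat([dst, src])]
            )
        return data

    synthetic_loader = DataLoader([make_synthetic_pair() for _ in range(3)], batch_size=3)
    synthetic_batch = next(iter(synthetic_loader))
    rec_coords, lig_coords = model(synthetic_batch)
    print("synthetic:", rec_coords.shape, lig_coords.shape)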