from typing import Optional, Tuple

import torch
from torch import Tensor, nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax
from torch_scatter import scatter
from torch_sparse import SparseTensor
import loralib as lora
from esm.multihead_attention import MultiheadAttention
import math

from torch import _dynamo
_dynamo.config.suppress_errors = True

from ..module.utils import (
    CosineCutoff,
    act_class_mapping,
    get_template_fn,
    gelu
)


# original torchmd-net attention layer
class EquivariantMultiHeadAttention(MessagePassing):
    """Equivariant multi-head attention layer."""

    def __init__(
        self,
        x_channels,
        x_hidden_channels,
        vec_channels,
        vec_hidden_channels,
        share_kv,
        edge_attr_channels,
        distance_influence,
        num_heads,
        activation,
        attn_activation,
        cutoff_lower,
        cutoff_upper,
        use_lora=None,
    ):
        super(EquivariantMultiHeadAttention, self).__init__(
            aggr="mean", node_dim=0)
        assert x_hidden_channels % num_heads == 0 \
            and vec_channels % num_heads == 0, (
                f"The number of hidden channels x_hidden_channels ({x_hidden_channels}) "
                f"and vec_channels ({vec_channels}) "
                f"must be evenly divisible by the number of "
                f"attention heads ({num_heads})"
            )
        assert vec_hidden_channels == x_channels, (
            f"The number of hidden channels x_channels ({x_channels}) "
            f"and vec_hidden_channels ({vec_hidden_channels}) "
            f"must be equal"
        )

        self.distance_influence = distance_influence
        self.num_heads = num_heads
        self.x_channels = x_channels
        self.x_hidden_channels = x_hidden_channels
        self.x_head_dim = x_hidden_channels // num_heads
        self.vec_channels = vec_channels
        self.vec_hidden_channels = vec_hidden_channels
        # important, not vec_hidden_channels // num_heads
        self.vec_head_dim = vec_channels // num_heads
        self.share_kv = share_kv

        self.layernorm = nn.LayerNorm(x_channels)
        self.act = activation()
        self.attn_activation = act_class_mapping[attn_activation]()
        self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper)

        if use_lora is not None:
            self.q_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora)
            self.k_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) if not share_kv else None
            self.v_proj = lora.Linear(
                x_channels, x_hidden_channels + vec_channels * 2, r=use_lora)
            self.o_proj = lora.Linear(
                x_hidden_channels, x_channels * 2 + vec_channels, r=use_lora)
            self.vec_proj = lora.Linear(
                vec_channels, vec_hidden_channels * 2 + vec_channels, bias=False, r=use_lora)
        else:
            self.q_proj = nn.Linear(x_channels, x_hidden_channels)
            self.k_proj = nn.Linear(x_channels, x_hidden_channels) if not share_kv else None
            self.v_proj = nn.Linear(
                x_channels, x_hidden_channels + vec_channels * 2)
            self.o_proj = nn.Linear(
                x_hidden_channels, x_channels * 2 + vec_channels)
            self.vec_proj = nn.Linear(
                vec_channels, vec_hidden_channels * 2 + vec_channels, bias=False)

        self.dk_proj = None
        if distance_influence in ["keys", "both"]:
            if use_lora is not None:
                self.dk_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora)
            else:
                self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels)

        self.dv_proj = None
        if distance_influence in ["values", "both"]:
            if use_lora is not None:
                self.dv_proj = lora.Linear(edge_attr_channels, x_hidden_channels + vec_channels * 2, r=use_lora)
            else:
                self.dv_proj = nn.Linear(edge_attr_channels, x_hidden_channels + vec_channels * 2)

        self.reset_parameters()

    def reset_parameters(self):
        self.layernorm.reset_parameters()
        nn.init.xavier_uniform_(self.q_proj.weight)
        self.q_proj.bias.data.fill_(0)
        # k_proj is None when share_kv is True, so only initialize it if it exists
        if self.k_proj is not None:
            nn.init.xavier_uniform_(self.k_proj.weight)
            self.k_proj.bias.data.fill_(0)
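        # Illustration of the share_kv option (example dimensions only, assuming
        # x_hidden_channels=128, vec_channels=64, num_heads=8, so x_head_dim=16 and
        # vec_head_dim=8): when share_kv is True there is no separate key projection;
        # forward() slices the keys out of the value projection instead, i.e.
        #     v = self.v_proj(x).reshape(-1, 8, 16 + 2 * 8)  # (N, heads, x_head_dim + 2 * vec_head_dim)
        #     k = v[:, :, :16]                               # first x_head_dim channels reused as keys
        # which drops one x_channels -> x_hidden_channels linear layer per block.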
        nn.init.xavier_uniform_(self.v_proj.weight)
        self.v_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.vec_proj.weight)
        if self.dk_proj:
            nn.init.xavier_uniform_(self.dk_proj.weight)
            self.dk_proj.bias.data.fill_(0)
        if self.dv_proj:
            nn.init.xavier_uniform_(self.dv_proj.weight)
            self.dv_proj.bias.data.fill_(0)

    def forward(self, x, vec, edge_index, r_ij, f_ij, d_ij, return_attn=False):
        x = self.layernorm(x)
        q = self.q_proj(x).reshape(-1, self.num_heads, self.x_head_dim)
        v = self.v_proj(x).reshape(-1, self.num_heads, self.x_head_dim + self.vec_head_dim * 2)
        if self.share_kv:
            k = v[:, :, :self.x_head_dim]
        else:
            k = self.k_proj(x).reshape(-1, self.num_heads, self.x_head_dim)

        vec1, vec2, vec3 = torch.split(self.vec_proj(vec),
                                       [self.vec_hidden_channels, self.vec_hidden_channels, self.vec_channels],
                                       dim=-1)
        vec = vec.reshape(-1, 3, self.num_heads, self.vec_head_dim)
        vec_dot = (vec1 * vec2).sum(dim=1)

        dk = (
            self.act(self.dk_proj(f_ij)).reshape(-1, self.num_heads, self.x_head_dim)
            if self.dk_proj is not None
            else None
        )
        dv = (
            self.act(self.dv_proj(f_ij)).reshape(-1, self.num_heads, self.x_head_dim + self.vec_head_dim * 2)
            if self.dv_proj is not None
            else None
        )

        # propagate_type: (q: Tensor, k: Tensor, v: Tensor, vec: Tensor, dk: Tensor, dv: Tensor, r_ij: Tensor,
        # d_ij: Tensor)
        x, vec, attn = self.propagate(
            edge_index,
            q=q,
            k=k,
            v=v,
            vec=vec,
            dk=dk,
            dv=dv,
            r_ij=r_ij,
            d_ij=d_ij,
            size=None,
        )

        x = x.reshape(-1, self.x_hidden_channels)
        vec = vec.reshape(-1, 3, self.vec_channels)

        o1, o2, o3 = torch.split(self.o_proj(x),
                                 [self.vec_channels, self.x_channels, self.x_channels],
                                 dim=1)
        dx = vec_dot * o2 + o3
        dvec = vec3 * o1.unsqueeze(1) + vec
        if return_attn:
            return dx, dvec, torch.concat((edge_index.T, attn), dim=1)
        else:
            return dx, dvec, None

    def message(self, q_i, k_j, v_j, vec_j, dk, dv, r_ij, d_ij):
        # attention mechanism
        if dk is None:
            attn = (q_i * k_j).sum(dim=-1)
        else:
            # TODO: consider add or multiply dk
            attn = (q_i * k_j * dk).sum(dim=-1)

        # attention activation function
        attn = self.attn_activation(attn) * self.cutoff(r_ij).unsqueeze(1)

        # value pathway
        if dv is not None:
            v_j = v_j * dv
        x, vec1, vec2 = torch.split(
            v_j, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=2)

        # update scalar features
        x = x * attn.unsqueeze(2)
        # update vector features
        vec = vec_j * vec1.unsqueeze(1) + vec2.unsqueeze(1) * \
            d_ij.unsqueeze(2).unsqueeze(3)
        return x, vec, attn

    def aggregate(
        self,
        features: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        index: torch.Tensor,
        ptr: Optional[torch.Tensor],
        dim_size: Optional[int],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x, vec, attn = features
        x = scatter(x, index, dim=self.node_dim, dim_size=dim_size)
        vec = scatter(vec, index, dim=self.node_dim, dim_size=dim_size)
        return x, vec, attn

    def update(
        self, inputs: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return inputs

    def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor:
        pass

    def edge_update(self) -> Tensor:
        pass


# ESM multi-head attention layer, added LoRA
class ESMMultiheadAttention(MultiheadAttention):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
""" def __init__( self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0.0, bias=True, add_bias_kv: bool = False, add_zero_attn: bool = False, self_attention: bool = False, encoder_decoder_attention: bool = False, use_rotary_embeddings: bool = False, ): super().__init__(embed_dim, num_heads, kdim, vdim, dropout, bias, add_bias_kv, add_zero_attn, self_attention, encoder_decoder_attention, use_rotary_embeddings) # change the projection to LoRA self.k_proj = lora.Linear(self.kdim, embed_dim, bias=bias, r=16) self.v_proj = lora.Linear(self.vdim, embed_dim, bias=bias, r=16) self.q_proj = lora.Linear(embed_dim, embed_dim, bias=bias, r=16) self.out_proj = lora.Linear(embed_dim, embed_dim, bias=bias, r=16) # original torchmd-net attention layer, add pair-wise confidence of PAE class EquivariantPAEMultiHeadAttention(EquivariantMultiHeadAttention): """Equivariant multi-head attention layer.""" def __init__( self, x_channels, x_hidden_channels, vec_channels, vec_hidden_channels, share_kv, edge_attr_channels, edge_attr_dist_channels, distance_influence, num_heads, activation, attn_activation, cutoff_lower, cutoff_upper, use_lora=None, ): super(EquivariantPAEMultiHeadAttention, self).__init__( x_channels=x_channels, x_hidden_channels=x_hidden_channels, vec_channels=vec_channels, vec_hidden_channels=vec_hidden_channels, share_kv=share_kv, edge_attr_channels=edge_attr_channels, distance_influence=distance_influence, num_heads=num_heads, activation=activation, attn_activation=attn_activation, cutoff_lower=cutoff_lower, cutoff_upper=cutoff_upper, use_lora=use_lora) # we cancel the cutoff function self.cutoff = None # we set separate projection for distance influence self.dk_dist_proj = None if distance_influence in ["keys", "both"]: if use_lora is not None: self.dk_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels, r=use_lora) else: self.dk_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels) self.dv_dist_proj = None if distance_influence in ["values", "both"]: if use_lora is not None: self.dv_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2, r=use_lora) else: self.dv_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2) if self.dk_dist_proj: nn.init.xavier_uniform_(self.dk_dist_proj.weight) self.dk_dist_proj.bias.data.fill_(0) if self.dv_dist_proj: nn.init.xavier_uniform_(self.dv_dist_proj.weight) self.dv_dist_proj.bias.data.fill_(0) def forward(self, x, vec, edge_index, w_ij, f_dist_ij, f_ij, d_ij, plddt, return_attn=False): # we replaced r_ij to w_ij as pair-wise confidence # we add plddt as position-wise confidence x = self.layernorm(x) q = self.q_proj(x).reshape(-1, self.num_heads, self.x_head_dim) v = self.v_proj(x).reshape(-1, self.num_heads, self.x_head_dim + self.vec_head_dim * 2) if self.share_kv: k = v[:, :, :self.x_head_dim] else: k = self.k_proj(x).reshape(-1, self.num_heads, self.x_head_dim) vec1, vec2, vec3 = torch.split(self.vec_proj(vec), [self.vec_hidden_channels, self.vec_hidden_channels, self.vec_channels], dim=-1) vec = vec.reshape(-1, 3, self.num_heads, self.vec_head_dim) vec_dot = (vec1 * vec2).sum(dim=1) dk = ( self.act(self.dk_proj(f_ij)).reshape(-1, self.num_heads, self.x_head_dim) if self.dk_proj is not None else None ) dk_dist = ( self.act(self.dk_dist_proj(f_dist_ij)).reshape(-1, self.num_heads, self.x_head_dim) if self.dk_dist_proj is not None else None ) dv = ( self.act(self.dv_proj(f_ij)).reshape(-1, self.num_heads, self.x_head_dim + self.vec_head_dim * 2) if self.dv_proj 
is not None else None ) dv_dist = ( self.act(self.dv_dist_proj(f_dist_ij)).reshape(-1, self.num_heads, self.x_head_dim + self.vec_head_dim * 2) if self.dv_dist_proj is not None else None ) # propagate_type: (q: Tensor, k: Tensor, v: Tensor, vec: Tensor, dk: Tensor, dv: Tensor, r_ij: Tensor, # d_ij: Tensor) x, vec, attn = self.propagate( edge_index, q=q, k=k, v=v, vec=vec, dk=dk, dk_dist=dk_dist, dv=dv, dv_dist=dv_dist, d_ij=d_ij, w_ij=w_ij, size=None, ) x = x.reshape(-1, self.x_hidden_channels) vec = vec.reshape(-1, 3, self.vec_channels) o1, o2, o3 = torch.split(self.o_proj( x), [self.vec_channels, self.x_channels, self.x_channels], dim=1) dx = vec_dot * o2 * plddt.unsqueeze(1) + o3 dvec = vec3 * o1.unsqueeze(1) * plddt.unsqueeze(1).unsqueeze(2) + vec if return_attn: return dx, dvec, torch.concat((edge_index.T, attn), dim=1) else: return dx, dvec, None def message(self, q_i, k_j, v_j, vec_j, dk, dk_dist, dv, dv_dist, d_ij, w_ij): # attention mechanism attn = (q_i * k_j) if dk is not None: attn += dk if dk_dist is not None: attn += dk_dist * w_ij.unsqueeze(1).unsqueeze(2) attn = attn.sum(dim=-1) # attention activation function attn = self.attn_activation(attn) # value pathway, add dv, but apply w_ij to dv if dv is not None: v_j += dv if dv_dist is not None: v_j += dv_dist * w_ij.unsqueeze(1).unsqueeze(2) x, vec1, vec2 = torch.split( v_j, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=2) # update scalar features x = x * attn.unsqueeze(2) # update vector features vec = vec_j * vec1.unsqueeze(1) + vec2.unsqueeze(1) * \ d_ij.unsqueeze(2).unsqueeze(3) return x, vec, attn # original torchmd-net attention layer, add pair-wise confidence of PAE class EquivariantWeightedPAEMultiHeadAttention(EquivariantMultiHeadAttention): """Equivariant multi-head attention layer.""" def __init__( self, x_channels, x_hidden_channels, vec_channels, vec_hidden_channels, share_kv, edge_attr_channels, edge_attr_dist_channels, distance_influence, num_heads, activation, attn_activation, cutoff_lower, cutoff_upper, use_lora=None, ): super(EquivariantWeightedPAEMultiHeadAttention, self).__init__( x_channels=x_channels, x_hidden_channels=x_hidden_channels, vec_channels=vec_channels, vec_hidden_channels=vec_hidden_channels, share_kv=share_kv, edge_attr_channels=edge_attr_channels, distance_influence=distance_influence, num_heads=num_heads, activation=activation, attn_activation=attn_activation, cutoff_lower=cutoff_lower, cutoff_upper=cutoff_upper, use_lora=use_lora) # we cancel the cutoff function self.cutoff = None # we set a separate weight for distance influence self.pae_weight = nn.Linear(1, 1, bias=True) self.pae_weight.weight.data.fill_(-0.5) self.pae_weight.bias.data.fill_(7.5) # we set separate projection for distance influence self.dk_dist_proj = None if distance_influence in ["keys", "both"]: if use_lora is not None: self.dk_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels, r=use_lora) else: self.dk_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels) self.dv_dist_proj = None if distance_influence in ["values", "both"]: if use_lora is not None: self.dv_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2, r=use_lora) else: self.dv_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2) if self.dk_dist_proj: nn.init.xavier_uniform_(self.dk_dist_proj.weight) self.dk_dist_proj.bias.data.fill_(0) if self.dv_dist_proj: nn.init.xavier_uniform_(self.dv_dist_proj.weight) self.dv_dist_proj.bias.data.fill_(0) def 
forward(self, x, vec, edge_index, w_ij, f_dist_ij, f_ij, d_ij, plddt, return_attn=False): # we replaced r_ij to w_ij as pair-wise confidence # we add plddt as position-wise confidence x = self.layernorm(x) q = self.q_proj(x).reshape(-1, self.num_heads, self.x_head_dim) v = self.v_proj(x).reshape(-1, self.num_heads, self.x_head_dim + self.vec_head_dim * 2) if self.share_kv: k = v[:, :, :self.x_head_dim] else: k = self.k_proj(x).reshape(-1, self.num_heads, self.x_head_dim) vec1, vec2, vec3 = torch.split(self.vec_proj(vec), [self.vec_hidden_channels, self.vec_hidden_channels, self.vec_channels], dim=-1) vec = vec.reshape(-1, 3, self.num_heads, self.vec_head_dim) vec_dot = (vec1 * vec2).sum(dim=1) dk = ( self.act(self.dk_proj(f_ij)).reshape(-1, self.num_heads, self.x_head_dim) if self.dk_proj is not None else None ) dk_dist = ( self.act(self.dk_dist_proj(f_dist_ij)).reshape(-1, self.num_heads, self.x_head_dim) if self.dk_dist_proj is not None else None ) dv = ( self.act(self.dv_proj(f_ij)).reshape(-1, self.num_heads, self.x_head_dim + self.vec_head_dim * 2) if self.dv_proj is not None else None ) dv_dist = ( self.act(self.dv_dist_proj(f_dist_ij)).reshape(-1, self.num_heads, self.x_head_dim + self.vec_head_dim * 2) if self.dv_dist_proj is not None else None ) # propagate_type: (q: Tensor, k: Tensor, v: Tensor, vec: Tensor, dk: Tensor, dv: Tensor, r_ij: Tensor, # d_ij: Tensor) x, vec, attn = self.propagate( edge_index, q=q, k=k, v=v, vec=vec, dk=dk, dk_dist=dk_dist, dv=dv, dv_dist=dv_dist, d_ij=d_ij, w_ij=nn.functional.sigmoid(self.pae_weight(w_ij.unsqueeze(-1)).squeeze(-1)), size=None, ) x = x.reshape(-1, self.x_hidden_channels) vec = vec.reshape(-1, 3, self.vec_channels) o1, o2, o3 = torch.split(self.o_proj( x), [self.vec_channels, self.x_channels, self.x_channels], dim=1) dx = vec_dot * o2 * plddt.unsqueeze(1) + o3 dvec = vec3 * o1.unsqueeze(1) * plddt.unsqueeze(1).unsqueeze(2) + vec if return_attn: return dx, dvec, torch.concat((edge_index.T, attn), dim=1) else: return dx, dvec, None def message(self, q_i, k_j, v_j, vec_j, dk, dk_dist, dv, dv_dist, d_ij, w_ij): # attention mechanism attn = (q_i * k_j) if dk_dist is not None: if dk is not None: attn *= (dk + dk_dist * w_ij.unsqueeze(1).unsqueeze(2)) else: attn *= dk_dist * w_ij else: if dk is not None: attn *= dk attn = attn.sum(dim=-1) # attention activation function attn = self.attn_activation(attn) # value pathway, add dv, but apply w_ij to dv if dv is not None: v_j += dv if dv_dist is not None: v_j += dv_dist * w_ij.unsqueeze(1).unsqueeze(2) x, vec1, vec2 = torch.split( v_j, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=2) # update scalar features x = x * attn.unsqueeze(2) # update vector features vec = vec_j * vec1.unsqueeze(1) + vec2.unsqueeze(1) * \ d_ij.unsqueeze(2).unsqueeze(3) return x, vec, attn class EquivariantPAEMultiHeadAttentionSoftMaxFullGraph(nn.Module): """Equivariant multi-head attention layer with softmax, apply attention on full graph by default""" def __init__( self, x_channels, x_hidden_channels, vec_channels, vec_hidden_channels, share_kv, edge_attr_channels, edge_attr_dist_channels, distance_influence, num_heads, activation, attn_activation, cutoff_lower, cutoff_upper, use_lora=None, ): # same as EquivariantPAEMultiHeadAttentionSoftMax, but apply attention on full graph by default super(EquivariantPAEMultiHeadAttentionSoftMaxFullGraph, self).__init__() assert x_hidden_channels % num_heads == 0 \ and vec_channels % num_heads == 0, ( f"The number of hidden channels x_hidden_channels ({x_hidden_channels}) 
" f"and vec_channels ({vec_channels}) " f"must be evenly divisible by the number of " f"attention heads ({num_heads})" ) assert vec_hidden_channels == x_channels, ( f"The number of hidden channels x_channels ({x_channels}) " f"and vec_hidden_channels ({vec_hidden_channels}) " f"must be equal" ) self.distance_influence = distance_influence self.num_heads = num_heads self.x_channels = x_channels self.x_hidden_channels = x_hidden_channels self.x_head_dim = x_hidden_channels // num_heads self.vec_channels = vec_channels self.vec_hidden_channels = vec_hidden_channels # important, not vec_hidden_channels // num_heads self.vec_head_dim = vec_channels // num_heads self.share_kv = share_kv self.layernorm = nn.LayerNorm(x_channels) self.act = activation() self.cutoff = None self.scaling = self.x_head_dim**-0.5 if use_lora is not None: self.q_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) self.k_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) if not share_kv else None self.v_proj = lora.Linear(x_channels, x_hidden_channels + vec_channels * 2, r=use_lora) self.o_proj = lora.Linear(x_hidden_channels, x_channels * 2 + vec_channels, r=use_lora) self.vec_proj = lora.Linear(vec_channels, vec_hidden_channels * 2 + vec_channels, bias=False, r=use_lora) else: self.q_proj = nn.Linear(x_channels, x_hidden_channels) self.k_proj = nn.Linear(x_channels, x_hidden_channels) if not share_kv else None self.v_proj = nn.Linear(x_channels, x_hidden_channels + vec_channels * 2) self.o_proj = nn.Linear(x_hidden_channels, x_channels * 2 + vec_channels) self.vec_proj = nn.Linear(vec_channels, vec_hidden_channels * 2 + vec_channels, bias=False) self.dk_proj = None self.dk_dist_proj = None self.dv_proj = None self.dv_dist_proj = None if distance_influence in ["keys", "both"]: if use_lora is not None: self.dk_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora) self.dk_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels, r=use_lora) else: self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels) self.dk_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels) if distance_influence in ["values", "both"]: if use_lora is not None: self.dv_proj = lora.Linear(edge_attr_channels, x_hidden_channels + vec_channels * 2, r=use_lora) self.dv_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2, r=use_lora) else: self.dv_proj = nn.Linear(edge_attr_channels, x_hidden_channels + vec_channels * 2) self.dv_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2) # set PAE weight as a learnable parameter, basiclly a sigmoid function self.pae_weight = nn.Linear(1, 1, bias=True) self.reset_parameters() def reset_parameters(self): self.layernorm.reset_parameters() nn.init.xavier_uniform_(self.q_proj.weight) self.q_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.k_proj.weight) self.k_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.v_proj.weight) self.v_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.o_proj.weight) self.o_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.vec_proj.weight) self.pae_weight.weight.data.fill_(-0.5) self.pae_weight.bias.data.fill_(7.5) if self.dk_proj: nn.init.xavier_uniform_(self.dk_proj.weight) self.dk_proj.bias.data.fill_(0) if self.dv_proj: nn.init.xavier_uniform_(self.dv_proj.weight) self.dv_proj.bias.data.fill_(0) if self.dk_dist_proj: nn.init.xavier_uniform_(self.dk_dist_proj.weight) self.dk_dist_proj.bias.data.fill_(0) if self.dv_dist_proj: 
nn.init.xavier_uniform_(self.dv_dist_proj.weight) self.dv_dist_proj.bias.data.fill_(0) def forward(self, x, vec, edge_index, w_ij, f_dist_ij, f_ij, d_ij, plddt, key_padding_mask, return_attn=False): # we replaced r_ij to w_ij as pair-wise confidence # we add plddt as position-wise confidence # edge_index is unused x = self.layernorm(x) q = self.q_proj(x) * self.scaling v = self.v_proj(x) # if self.share_kv: # k = v[:, :, :self.x_head_dim] # else: k = self.k_proj(x) vec1, vec2, vec3 = torch.split(self.vec_proj(vec), [self.vec_hidden_channels, self.vec_hidden_channels, self.vec_channels], dim=-1) vec_dot = (vec1 * vec2).sum(dim=-2) dk = self.act(self.dk_proj(f_ij)) dk_dist = self.act(self.dk_dist_proj(f_dist_ij)) dv = self.act(self.dv_proj(f_ij)) dv_dist = self.act(self.dv_dist_proj(f_dist_ij)) # full graph attention x, vec, attn = self.attention( q=q, k=k, v=v, vec=vec, dk=dk, dk_dist=dk_dist, dv=dv, dv_dist=dv_dist, d_ij=d_ij, w_ij=nn.functional.sigmoid(self.pae_weight(w_ij.unsqueeze(-1)).squeeze(-1)), key_padding_mask=key_padding_mask, ) o1, o2, o3 = torch.split(self.o_proj(x), [self.vec_channels, self.x_channels, self.x_channels], dim=-1) dx = vec_dot * o2 * plddt.unsqueeze(-1) + o3 dvec = vec3 * o1.unsqueeze(-2) * plddt.unsqueeze(-1).unsqueeze(-2) + vec # apply key_padding_mask to dx dx = dx.masked_fill(key_padding_mask.unsqueeze(-1), 0) dvec = dvec.masked_fill(key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0) if return_attn: return dx, dvec, attn else: return dx, dvec, None def attention(self, q, k, v, vec, dk, dk_dist, dv, dv_dist, d_ij, w_ij, key_padding_mask=None, need_head_weights=False): # note that q is of shape (bsz, tgt_len, num_heads * head_dim) # k, v is of shape (bsz, src_len, num_heads * head_dim) # vec is of shape (bsz, src_len, 3, num_heads * head_dim) # dk, dk_dist, dv, dv_dist is of shape (bsz, tgt_len, src_len, num_heads * head_dim) # d_ij is of shape (bsz, tgt_len, src_len, 3) # w_ij is of shape (bsz, tgt_len, src_len) # key_padding_mask is of shape (bsz, src_len) bsz, tgt_len, _ = q.size() src_len = k.size(1) # change q size to (bsz * num_heads, tgt_len, head_dim) # change k,v size to (bsz * num_heads, src_len, head_dim) q = q.transpose(0, 1).reshape(tgt_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous() k = k.transpose(0, 1).reshape(src_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous() v = v.transpose(0, 1).reshape(src_len, bsz * self.num_heads, self.x_head_dim + 2 * self.vec_head_dim).transpose(0, 1).contiguous() # change vec to (bsz * num_heads, src_len, 3, head_dim) vec = vec.permute(1, 2, 0, 3).reshape(src_len, 3, bsz * self.num_heads, self.vec_head_dim).permute(2, 0, 1, 3).contiguous() # dk size is (bsz, tgt_len, src_len, num_heads * head_dim) # if dk is not None: # change dk to (bsz * num_heads, tgt_len, src_len, head_dim) dk = dk.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous() # if dk_dist is not None: # change dk_dist to (bsz * num_heads, tgt_len, src_len, head_dim) dk_dist = dk_dist.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous() # dv size is (bsz, tgt_len, src_len, num_heads * head_dim) # if dv is not None: # change dv to (bsz * num_heads, tgt_len, src_len, head_dim) dv = dv.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim + 2 * self.vec_head_dim).permute(2, 0, 1, 3).contiguous() # if dv_dist is not None: # change dv_dist to (bsz * 
num_heads, tgt_len, src_len, head_dim) dv_dist = dv_dist.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim + 2 * self.vec_head_dim).permute(2, 0, 1, 3).contiguous() # if key_padding_mask is not None: # key_padding_mask should be (bsz, src_len) assert key_padding_mask.size(0) == bsz assert key_padding_mask.size(1) == src_len # attn_weights size is (bsz * num_heads, tgt_len, src_len, head_dim) attn_weights = torch.multiply(q[:, :, None, :], k[:, None, :, :]) # w_ij is PAE confidence # w_ij size is (bsz, tgt_len, src_len) # change dimension of w_ij to (bsz * num_heads, tgt_len, src_len, head_dim) # if dk_dist is not None: assert w_ij is not None # if dk is not None: attn_weights *= (dk + dk_dist * w_ij[:, :, :, None].repeat(self.num_heads, 1, 1, self.x_head_dim)) # add dv and dv_dist v = v.unsqueeze(1) + dv + dv_dist * w_ij[:, :, :, None].repeat(self.num_heads, 1, 1, self.x_head_dim + 2 * self.vec_head_dim) # else: # attn_weights *= dk_dist * w_ij # else: # if dk is not None: # attn_weights *= dk # attn_weights size is (bsz * num_heads, tgt_len, src_len) attn_weights = attn_weights.sum(dim=-1) # apply key_padding_mask to attn_weights # if key_padding_mask is not None: # don't attend to padding symbols attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).contiguous() attn_weights = attn_weights.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf") ) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len).contiguous() # apply softmax to attn_weights attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32) # x, vec1, vec2 are of shape (bsz * num_heads, src_len, head_dim) x, vec1, vec2 = torch.split(v, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=-1) # first get invariant feature outputs x, size is (bsz * num_heads, tgt_len, head_dim) x_out = torch.einsum('bts,btsh->bth', attn_weights, x) # next get equivariant feature outputs vec_out_1, size is (bsz * num_heads, tgt_len, 3, head_dim) vec_out_1 = torch.einsum('bsih,btsh->btih', vec, vec1) # next get equivariant feature outputs vec_out_2, size is (bsz * num_heads, tgt_len, src_len, 3, head_dim) vec_out_2 = torch.einsum('btsi,btsh->btih', d_ij, vec2) # adds up vec_out_1 and vec_out_2, get vec_out, size is (bsz * num_heads, tgt_len, 3, head_dim) vec_out = vec_out_1 + vec_out_2 attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0) # if not need_head_weights: # average attention weights over heads attn_weights = attn_weights.mean(dim=0) # reshape x_out to (bsz, tgt_len, num_heads * head_dim) x_out = x_out.transpose(1, 0).reshape(tgt_len, bsz, self.num_heads * self.x_head_dim).transpose(1, 0).contiguous() # reshape vec_out to (bsz, tgt_len, 3, num_heads * head_dim) vec_out = vec_out.permute(1, 2, 0, 3).reshape(tgt_len, 3, bsz, self.num_heads * self.vec_head_dim).permute(2, 0, 1, 3).contiguous() return x_out, vec_out, attn_weights class MultiHeadAttentionSoftMaxFullGraph(nn.Module): """ Multi-head attention layer with softmax, apply attention on full graph by default No equivariant property, but can take structure information as input, just didn't use it """ def __init__( self, x_channels, x_hidden_channels, vec_channels, vec_hidden_channels, share_kv, edge_attr_channels, edge_attr_dist_channels, distance_influence, num_heads, activation, attn_activation, cutoff_lower, cutoff_upper, use_lora=None, ): # same as EquivariantPAEMultiHeadAttentionSoftMax, but apply attention on 
full graph by default super(MultiHeadAttentionSoftMaxFullGraph, self).__init__() assert x_hidden_channels % num_heads == 0 \ and vec_channels % num_heads == 0, ( f"The number of hidden channels x_hidden_channels ({x_hidden_channels}) " f"and vec_channels ({vec_channels}) " f"must be evenly divisible by the number of " f"attention heads ({num_heads})" ) assert vec_hidden_channels == x_channels, ( f"The number of hidden channels x_channels ({x_channels}) " f"and vec_hidden_channels ({vec_hidden_channels}) " f"must be equal" ) self.distance_influence = distance_influence self.num_heads = num_heads self.x_channels = x_channels self.x_hidden_channels = x_hidden_channels self.x_head_dim = x_hidden_channels // num_heads self.vec_channels = vec_channels self.vec_hidden_channels = vec_hidden_channels # important, not vec_hidden_channels // num_heads self.vec_head_dim = vec_channels // num_heads self.share_kv = share_kv self.layernorm = nn.LayerNorm(x_channels) self.act = activation() self.cutoff = None self.scaling = self.x_head_dim**-0.5 if use_lora is not None: self.q_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) self.k_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) if not share_kv else None self.v_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) self.o_proj = lora.Linear(x_hidden_channels, x_channels, r=use_lora) # self.vec_proj = lora.Linear(vec_channels, vec_hidden_channels * 2 + vec_channels, bias=False, r=use_lora) else: self.q_proj = nn.Linear(x_channels, x_hidden_channels) self.k_proj = nn.Linear(x_channels, x_hidden_channels) if not share_kv else None self.v_proj = nn.Linear(x_channels, x_hidden_channels) self.o_proj = nn.Linear(x_hidden_channels, x_channels) # self.vec_proj = nn.Linear(vec_channels, vec_hidden_channels * 2 + vec_channels, bias=False) self.dk_proj = None self.dk_dist_proj = None self.dv_proj = None self.dv_dist_proj = None if distance_influence in ["keys", "both"]: if use_lora is not None: self.dk_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora) # self.dk_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels, r=use_lora) else: self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels) # self.dk_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels) if distance_influence in ["values", "both"]: if use_lora is not None: self.dv_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora) # self.dv_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2, r=use_lora) else: self.dv_proj = nn.Linear(edge_attr_channels, x_hidden_channels) # self.dv_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2) # set PAE weight as a learnable parameter, basiclly a sigmoid function # self.pae_weight = nn.Linear(1, 1, bias=True) self.reset_parameters() def reset_parameters(self): self.layernorm.reset_parameters() nn.init.xavier_uniform_(self.q_proj.weight) self.q_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.k_proj.weight) self.k_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.v_proj.weight) self.v_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.o_proj.weight) self.o_proj.bias.data.fill_(0) # nn.init.xavier_uniform_(self.vec_proj.weight) # self.pae_weight.weight.data.fill_(-0.5) # self.pae_weight.bias.data.fill_(7.5) if self.dk_proj: nn.init.xavier_uniform_(self.dk_proj.weight) self.dk_proj.bias.data.fill_(0) if self.dv_proj: nn.init.xavier_uniform_(self.dv_proj.weight) self.dv_proj.bias.data.fill_(0) def 
forward(self, x, vec, edge_index, w_ij, f_dist_ij, f_ij, d_ij, plddt, key_padding_mask, return_attn=False): # we replaced r_ij to w_ij as pair-wise confidence # we add plddt as position-wise confidence # edge_index is unused x = self.layernorm(x) q = self.q_proj(x) * self.scaling v = self.v_proj(x) # if self.share_kv: # k = v[:, :, :self.x_head_dim] # else: k = self.k_proj(x) # vec1, vec2, vec3 = torch.split(self.vec_proj(vec), # [self.vec_hidden_channels, self.vec_hidden_channels, self.vec_channels], dim=-1) # vec_dot = (vec1 * vec2).sum(dim=-2) dk = self.act(self.dk_proj(f_ij)) # dk_dist = self.act(self.dk_dist_proj(f_dist_ij)) dv = self.act(self.dv_proj(f_ij)) # dv_dist = self.act(self.dv_dist_proj(f_dist_ij)) # full graph attention x, vec, attn = self.attention( q=q, k=k, v=v, vec=vec, dk=dk, # dk_dist=dk_dist, dv=dv, # dv_dist=dv_dist, # d_ij=d_ij, # w_ij=nn.functional.sigmoid(self.pae_weight(w_ij.unsqueeze(-1)).squeeze(-1)), key_padding_mask=key_padding_mask, ) # o1, o2, o3 = torch.split(self.o_proj(x), [self.vec_channels, self.x_channels, self.x_channels], dim=-1) # dx = vec_dot * o2 * plddt.unsqueeze(-1) + o3 dx = self.o_proj(x) # dvec = vec3 * o1.unsqueeze(-2) * plddt.unsqueeze(-1).unsqueeze(-2) + vec # apply key_padding_mask to dx dx = dx.masked_fill(key_padding_mask.unsqueeze(-1), 0) # dvec = dvec.masked_fill(key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0) if return_attn: return dx, vec, attn else: return dx, vec, None def attention(self, q, k, v, vec, dk, dv, key_padding_mask=None, need_head_weights=False): # note that q is of shape (bsz, tgt_len, num_heads * head_dim) # k, v is of shape (bsz, src_len, num_heads * head_dim) # vec is of shape (bsz, src_len, 3, num_heads * head_dim) # dk, dk_dist, dv, dv_dist is of shape (bsz, tgt_len, src_len, num_heads * head_dim) # d_ij is of shape (bsz, tgt_len, src_len, 3) # w_ij is of shape (bsz, tgt_len, src_len) # key_padding_mask is of shape (bsz, src_len) bsz, tgt_len, _ = q.size() src_len = k.size(1) # change q size to (bsz * num_heads, tgt_len, head_dim) # change k,v size to (bsz * num_heads, src_len, head_dim) q = q.transpose(0, 1).reshape(tgt_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous() k = k.transpose(0, 1).reshape(src_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous() v = v.transpose(0, 1).reshape(src_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous() # change vec to (bsz * num_heads, src_len, 3, head_dim) # vec = vec.permute(1, 2, 0, 3).reshape(src_len, 3, bsz * self.num_heads, self.vec_head_dim).permute(2, 0, 1, 3).contiguous() # dk size is (bsz, tgt_len, src_len, num_heads * head_dim) # if dk is not None: # change dk to (bsz * num_heads, tgt_len, src_len, head_dim) dk = dk.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous() # if dk_dist is not None: # change dk_dist to (bsz * num_heads, tgt_len, src_len, head_dim) # dk_dist = dk_dist.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous() # dv size is (bsz, tgt_len, src_len, num_heads * head_dim) # if dv is not None: # change dv to (bsz * num_heads, tgt_len, src_len, head_dim) dv = dv.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous() # if dv_dist is not None: # change dv_dist to (bsz * num_heads, tgt_len, src_len, head_dim) # dv_dist = dv_dist.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, 
self.x_head_dim + 2 * self.vec_head_dim).permute(2, 0, 1, 3).contiguous() # if key_padding_mask is not None: # key_padding_mask should be (bsz, src_len) assert key_padding_mask.size(0) == bsz assert key_padding_mask.size(1) == src_len # attn_weights size is (bsz * num_heads, tgt_len, src_len, head_dim) attn_weights = torch.multiply(q[:, :, None, :], k[:, None, :, :]) # w_ij is PAE confidence # w_ij size is (bsz, tgt_len, src_len) # change dimension of w_ij to (bsz * num_heads, tgt_len, src_len, head_dim) # if dk_dist is not None: # assert w_ij is not None # if dk is not None: attn_weights *= dk # add dv and dv_dist v = v.unsqueeze(1) + dv # else: # attn_weights *= dk_dist * w_ij # else: # if dk is not None: # attn_weights *= dk # attn_weights size is (bsz * num_heads, tgt_len, src_len) attn_weights = attn_weights.sum(dim=-1) # apply key_padding_mask to attn_weights # if key_padding_mask is not None: # don't attend to padding symbols attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).contiguous() attn_weights = attn_weights.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf") ) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len).contiguous() # apply softmax to attn_weights attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32) # x, vec1, vec2 are of shape (bsz * num_heads, src_len, head_dim) # x, vec1, vec2 = torch.split(v, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=-1) # first get invariant feature outputs x, size is (bsz * num_heads, tgt_len, head_dim) x_out = torch.einsum('bts,btsh->bth', attn_weights, v) # next get equivariant feature outputs vec_out_1, size is (bsz * num_heads, tgt_len, 3, head_dim) # vec_out_1 = torch.einsum('bsih,btsh->btih', vec, vec1) # next get equivariant feature outputs vec_out_2, size is (bsz * num_heads, tgt_len, src_len, 3, head_dim) # vec_out_2 = torch.einsum('btsi,btsh->btih', d_ij, vec2) # adds up vec_out_1 and vec_out_2, get vec_out, size is (bsz * num_heads, tgt_len, 3, head_dim) # vec_out = vec_out_1 + vec_out_2 attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0) # if not need_head_weights: # average attention weights over heads attn_weights = attn_weights.mean(dim=0) # reshape x_out to (bsz, tgt_len, num_heads * head_dim) x_out = x_out.transpose(1, 0).reshape(tgt_len, bsz, self.num_heads * self.x_head_dim).transpose(1, 0).contiguous() # reshape vec_out to (bsz, tgt_len, 3, num_heads * head_dim) # vec_out = vec_out.permute(1, 2, 0, 3).reshape(tgt_len, 3, bsz, self.num_heads * self.vec_head_dim).permute(2, 0, 1, 3).contiguous() return x_out, vec, attn_weights class PAEMultiHeadAttentionSoftMaxStarGraph(nn.Module): """Equivariant multi-head attention layer with softmax, apply attention on full graph by default""" def __init__( self, x_channels, x_hidden_channels, vec_channels, vec_hidden_channels, share_kv, edge_attr_channels, edge_attr_dist_channels, distance_influence, num_heads, activation, cutoff_lower, cutoff_upper, use_lora=None, ): # same as EquivariantPAEMultiHeadAttentionSoftMax, but apply attention on full graph by default super(PAEMultiHeadAttentionSoftMaxStarGraph, self).__init__() assert x_hidden_channels % num_heads == 0 \ and vec_channels % num_heads == 0, ( f"The number of hidden channels x_hidden_channels ({x_hidden_channels}) " f"and vec_channels ({vec_channels}) " f"must be evenly divisible by the number of " f"attention heads ({num_heads})" ) assert vec_hidden_channels == 
x_channels, ( f"The number of hidden channels x_channels ({x_channels}) " f"and vec_hidden_channels ({vec_hidden_channels}) " f"must be equal" ) self.distance_influence = distance_influence self.num_heads = num_heads self.x_channels = x_channels self.x_hidden_channels = x_hidden_channels self.x_head_dim = x_hidden_channels // num_heads self.vec_channels = vec_channels self.vec_hidden_channels = vec_hidden_channels # important, not vec_hidden_channels // num_heads self.vec_head_dim = vec_channels // num_heads self.share_kv = share_kv self.layernorm = nn.LayerNorm(x_channels) self.act = activation() self.cutoff = None self.scaling = self.x_head_dim**-0.5 if use_lora is not None: self.q_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) self.k_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) if not share_kv else None self.v_proj = lora.Linear(x_channels, x_hidden_channels + vec_channels * 2, r=use_lora) else: self.q_proj = nn.Linear(x_channels, x_hidden_channels) self.k_proj = nn.Linear(x_channels, x_hidden_channels) if not share_kv else None self.v_proj = nn.Linear(x_channels, x_hidden_channels) self.dk_proj = None self.dk_dist_proj = None self.dv_proj = None self.dv_dist_proj = None if distance_influence in ["keys", "both"]: if use_lora is not None: self.dk_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora) self.dk_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels, r=use_lora) else: self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels) self.dk_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels) if distance_influence in ["values", "both"]: if use_lora is not None: self.dv_proj = lora.Linear(edge_attr_channels, x_hidden_channels + vec_channels * 2, r=use_lora) self.dv_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2, r=use_lora) else: self.dv_proj = nn.Linear(edge_attr_channels, x_hidden_channels + vec_channels * 2) self.dv_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2) # set PAE weight as a learnable parameter, basiclly a sigmoid function self.pae_weight = nn.Linear(1, 1, bias=True) self.reset_parameters() def reset_parameters(self): self.layernorm.reset_parameters() nn.init.xavier_uniform_(self.q_proj.weight) self.q_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.k_proj.weight) self.k_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.v_proj.weight) self.v_proj.bias.data.fill_(0) self.pae_weight.weight.data.fill_(-0.5) self.pae_weight.bias.data.fill_(7.5) if self.dk_proj: nn.init.xavier_uniform_(self.dk_proj.weight) self.dk_proj.bias.data.fill_(0) if self.dv_proj: nn.init.xavier_uniform_(self.dv_proj.weight) self.dv_proj.bias.data.fill_(0) if self.dk_dist_proj: nn.init.xavier_uniform_(self.dk_dist_proj.weight) self.dk_dist_proj.bias.data.fill_(0) if self.dv_dist_proj: nn.init.xavier_uniform_(self.dv_dist_proj.weight) self.dv_dist_proj.bias.data.fill_(0) def forward(self, x, x_center_index, w_ij, f_dist_ij, f_ij, key_padding_mask, return_attn=False): # we replaced r_ij to w_ij as pair-wise confidence # we add plddt as position-wise confidence # edge_index is unused x = self.layernorm(x) q = self.q_proj(x[x_center_index].unsqueeze(1)) * self.scaling v = self.v_proj(x) # if self.share_kv: # k = v[:, :, :self.x_head_dim] # else: k = self.k_proj(x) dk = self.act(self.dk_proj(f_ij)) dk_dist = self.act(self.dk_dist_proj(f_dist_ij)) dv = self.act(self.dv_proj(f_ij)) dv_dist = self.act(self.dv_dist_proj(f_dist_ij)) # full graph 
attention x, attn = self.attention( q=q, k=k, v=v, dk=dk, dk_dist=dk_dist, dv=dv, dv_dist=dv_dist, w_ij=nn.functional.sigmoid(self.pae_weight(w_ij.unsqueeze(-1)).squeeze(-1)), key_padding_mask=key_padding_mask, ) if return_attn: return x, attn else: return x, None def attention(self, q, k, v, dk, dk_dist, dv, dv_dist, w_ij, key_padding_mask=None, need_head_weights=False): # note that q is of shape (bsz, tgt_len, num_heads * head_dim) # k, v is of shape (bsz, src_len, num_heads * head_dim) # vec is of shape (bsz, src_len, 3, num_heads * head_dim) # dk, dk_dist, dv, dv_dist is of shape (bsz, tgt_len, src_len, num_heads * head_dim) # d_ij is of shape (bsz, tgt_len, src_len, 3) # w_ij is of shape (bsz, tgt_len, src_len) # key_padding_mask is of shape (bsz, src_len) bsz, tgt_len, _ = q.size() src_len = k.size(1) # change q size to (bsz * num_heads, tgt_len, head_dim) # change k,v size to (bsz * num_heads, src_len, head_dim) q = q.transpose(0, 1).reshape(tgt_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous() k = k.transpose(0, 1).reshape(src_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous() v = v.transpose(0, 1).reshape(src_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous() # dk size is (bsz, tgt_len, src_len, num_heads * head_dim) # if dk is not None: # change dk to (bsz * num_heads, tgt_len, src_len, head_dim) dk = dk.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous() # if dk_dist is not None: # change dk_dist to (bsz * num_heads, tgt_len, src_len, head_dim) dk_dist = dk_dist.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous() # dv size is (bsz, tgt_len, src_len, num_heads * head_dim) # if dv is not None: # change dv to (bsz * num_heads, tgt_len, src_len, head_dim) dv = dv.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim + 2 * self.vec_head_dim).permute(2, 0, 1, 3).contiguous() # if dv_dist is not None: # change dv_dist to (bsz * num_heads, tgt_len, src_len, head_dim) dv_dist = dv_dist.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim + 2 * self.vec_head_dim).permute(2, 0, 1, 3).contiguous() # if key_padding_mask is not None: # key_padding_mask should be (bsz, src_len) assert key_padding_mask.size(0) == bsz assert key_padding_mask.size(1) == src_len # attn_weights size is (bsz * num_heads, tgt_len, src_len, head_dim) attn_weights = torch.multiply(q[:, :, None, :], k[:, None, :, :]) # w_ij is PAE confidence # w_ij size is (bsz, tgt_len, src_len) # change dimension of w_ij to (bsz * num_heads, tgt_len, src_len, head_dim) # if dk_dist is not None: assert w_ij is not None # if dk is not None: attn_weights *= (dk + dk_dist * w_ij[:, :, :, None].repeat(self.num_heads, 1, 1, self.x_head_dim)) # add dv and dv_dist v = v.unsqueeze(1) + dv + dv_dist * w_ij[:, :, :, None].repeat(self.num_heads, 1, 1, self.x_head_dim + 2 * self.vec_head_dim) # else: # attn_weights *= dk_dist * w_ij # else: # if dk is not None: # attn_weights *= dk # attn_weights size is (bsz * num_heads, tgt_len, src_len) attn_weights = attn_weights.sum(dim=-1) # apply key_padding_mask to attn_weights # if key_padding_mask is not None: # don't attend to padding symbols attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).contiguous() attn_weights = attn_weights.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf") ) 
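        # Note on the w_ij gate used above: by the time attention() is called, forward()
        # has already squashed the raw pair-wise PAE through the learnable gate
        # sigmoid(self.pae_weight(w_ij)). With the initial values set in reset_parameters()
        # (weight=-0.5, bias=7.5) this starts out as roughly
        #     gate(PAE) = torch.sigmoid(7.5 - 0.5 * PAE)
        # so, assuming w_ij holds the raw PAE in Angstrom,
        #     PAE ~  0  -> gate ~ 0.999  (dk_dist / dv_dist contribute fully)
        #     PAE = 15  -> gate = 0.5
        #     PAE ~ 30  -> gate ~ 0.001  (distance features effectively switched off)
        # and the slope and offset remain free to move during training.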
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len).contiguous() # apply softmax to attn_weights attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32) # first get invariant feature outputs x, size is (bsz * num_heads, tgt_len, head_dim) x_out = torch.einsum('bts,btsh->bth', attn_weights, v) attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0) # if not need_head_weights: # average attention weights over heads attn_weights = attn_weights.mean(dim=0) # reshape x_out to (bsz, tgt_len, num_heads * head_dim) x_out = x_out.transpose(1, 0).reshape(tgt_len, bsz, self.num_heads * self.x_head_dim).transpose(1, 0).contiguous() return x_out, attn_weights class MultiHeadAttentionSoftMaxStarGraph(nn.Module): """Equivariant multi-head attention layer with softmax, apply attention on full graph by default""" def __init__( self, x_channels, x_hidden_channels, vec_channels, vec_hidden_channels, share_kv, edge_attr_channels, edge_attr_dist_channels, distance_influence, num_heads, activation, cutoff_lower, cutoff_upper, use_lora=None, ): # same as EquivariantPAEMultiHeadAttentionSoftMax, but apply attention on full graph by default super(MultiHeadAttentionSoftMaxStarGraph, self).__init__() assert x_hidden_channels % num_heads == 0 \ and vec_channels % num_heads == 0, ( f"The number of hidden channels x_hidden_channels ({x_hidden_channels}) " f"and vec_channels ({vec_channels}) " f"must be evenly divisible by the number of " f"attention heads ({num_heads})" ) assert vec_hidden_channels == x_channels, ( f"The number of hidden channels x_channels ({x_channels}) " f"and vec_hidden_channels ({vec_hidden_channels}) " f"must be equal" ) self.distance_influence = distance_influence self.num_heads = num_heads self.x_channels = x_channels self.x_hidden_channels = x_hidden_channels self.x_head_dim = x_hidden_channels // num_heads self.vec_channels = vec_channels self.vec_hidden_channels = vec_hidden_channels # important, not vec_hidden_channels // num_heads self.vec_head_dim = vec_channels // num_heads self.share_kv = share_kv self.layernorm = nn.LayerNorm(x_channels) self.act = activation() self.cutoff = None self.scaling = self.x_head_dim**-0.5 if use_lora is not None: self.q_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) self.k_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) if not share_kv else None self.v_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) else: self.q_proj = nn.Linear(x_channels, x_hidden_channels) self.k_proj = nn.Linear(x_channels, x_hidden_channels) if not share_kv else None self.v_proj = nn.Linear(x_channels, x_hidden_channels) self.dk_proj = None # self.dk_dist_proj = None self.dv_proj = None # self.dv_dist_proj = None if distance_influence in ["keys", "both"]: if use_lora is not None: self.dk_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora) # self.dk_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels, r=use_lora) else: self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels) # self.dk_dist_proj = nn.Linear(edge_attr_dist_channels, x_hidden_channels) if distance_influence in ["values", "both"]: if use_lora is not None: self.dv_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora) # self.dv_dist_proj = lora.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2, r=use_lora) else: self.dv_proj = nn.Linear(edge_attr_channels, x_hidden_channels) # self.dv_dist_proj = 
nn.Linear(edge_attr_dist_channels, x_hidden_channels + vec_channels * 2) # set PAE weight as a learnable parameter, basiclly a sigmoid function # self.pae_weight = nn.Linear(1, 1, bias=True) self.reset_parameters() def reset_parameters(self): self.layernorm.reset_parameters() nn.init.xavier_uniform_(self.q_proj.weight) self.q_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.k_proj.weight) self.k_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.v_proj.weight) self.v_proj.bias.data.fill_(0) # self.pae_weight.weight.data.fill_(-0.5) # self.pae_weight.bias.data.fill_(7.5) if self.dk_proj: nn.init.xavier_uniform_(self.dk_proj.weight) self.dk_proj.bias.data.fill_(0) if self.dv_proj: nn.init.xavier_uniform_(self.dv_proj.weight) self.dv_proj.bias.data.fill_(0) # if self.dk_dist_proj: # nn.init.xavier_uniform_(self.dk_dist_proj.weight) # self.dk_dist_proj.bias.data.fill_(0) # if self.dv_dist_proj: # nn.init.xavier_uniform_(self.dv_dist_proj.weight) # self.dv_dist_proj.bias.data.fill_(0) def forward(self, x, x_center_index, w_ij, f_dist_ij, f_ij, key_padding_mask, return_attn=False): # we replaced r_ij to w_ij as pair-wise confidence # we add plddt as position-wise confidence # edge_index is unused x = self.layernorm(x) q = self.q_proj(x[x_center_index].unsqueeze(1)) * self.scaling v = self.v_proj(x) # if self.share_kv: # k = v[:, :, :self.x_head_dim] # else: k = self.k_proj(x) dk = self.act(self.dk_proj(f_ij)) # dk_dist = self.act(self.dk_dist_proj(f_dist_ij)) dv = self.act(self.dv_proj(f_ij)) # dv_dist = self.act(self.dv_dist_proj(f_dist_ij)) # full graph attention x, attn = self.attention( q=q, k=k, v=v, dk=dk, # dk_dist=dk_dist, dv=dv, # dv_dist=dv_dist, # w_ij=nn.functional.sigmoid(self.pae_weight(w_ij.unsqueeze(-1)).squeeze(-1)), key_padding_mask=key_padding_mask, ) if return_attn: return x, attn else: return x, None def attention(self, q, k, v, dk, dv, key_padding_mask=None, need_head_weights=False): # note that q is of shape (bsz, tgt_len, num_heads * head_dim) # k, v is of shape (bsz, src_len, num_heads * head_dim) # vec is of shape (bsz, src_len, 3, num_heads * head_dim) # dk, dk_dist, dv, dv_dist is of shape (bsz, tgt_len, src_len, num_heads * head_dim) # d_ij is of shape (bsz, tgt_len, src_len, 3) # w_ij is of shape (bsz, tgt_len, src_len) # key_padding_mask is of shape (bsz, src_len) bsz, tgt_len, _ = q.size() src_len = k.size(1) # change q size to (bsz * num_heads, tgt_len, head_dim) # change k,v size to (bsz * num_heads, src_len, head_dim) q = q.transpose(0, 1).reshape(tgt_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous() k = k.transpose(0, 1).reshape(src_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous() v = v.transpose(0, 1).reshape(src_len, bsz * self.num_heads, self.x_head_dim).transpose(0, 1).contiguous() # dk size is (bsz, tgt_len, src_len, num_heads * head_dim) # if dk is not None: # change dk to (bsz * num_heads, tgt_len, src_len, head_dim) dk = dk.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous() # if dk_dist is not None: # change dk_dist to (bsz * num_heads, tgt_len, src_len, head_dim) # dk_dist = dk_dist.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, self.x_head_dim).permute(2, 0, 1, 3).contiguous() # dv size is (bsz, tgt_len, src_len, num_heads * head_dim) # if dv is not None: # change dv to (bsz * num_heads, tgt_len, src_len, head_dim) dv = dv.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads, 
                       self.x_head_dim).permute(2, 0, 1, 3).contiguous()
        # if dv_dist is not None:
        # change dv_dist to (bsz * num_heads, tgt_len, src_len, head_dim)
        # dv_dist = dv_dist.permute(1, 2, 0, 3).reshape(tgt_len, src_len, bsz * self.num_heads,
        #                                               self.x_head_dim + 2 * self.vec_head_dim).permute(2, 0, 1, 3).contiguous()

        # if key_padding_mask is not None:
        # key_padding_mask should be (bsz, src_len)
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

        # attn_weights size is (bsz * num_heads, tgt_len, src_len, head_dim)
        attn_weights = torch.multiply(q[:, :, None, :], k[:, None, :, :])

        # w_ij is PAE confidence
        # w_ij size is (bsz, tgt_len, src_len)
        # change dimension of w_ij to (bsz * num_heads, tgt_len, src_len, head_dim)
        # if dk_dist is not None:
        # assert w_ij is not None
        # if dk is not None:
        attn_weights *= dk
        # add dv and dv_dist
        v = v.unsqueeze(1) + dv
        # else:
        #     attn_weights *= dk_dist * w_ij
        # else:
        #     if dk is not None:
        #         attn_weights *= dk

        # attn_weights size is (bsz * num_heads, tgt_len, src_len)
        attn_weights = attn_weights.sum(dim=-1)

        # apply key_padding_mask to attn_weights
        # if key_padding_mask is not None:
        # don't attend to padding symbols
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).contiguous()
        attn_weights = attn_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
            float("-inf")
        )
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len).contiguous()

        # apply softmax to attn_weights
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)

        # first get invariant feature outputs x, size is (bsz * num_heads, tgt_len, head_dim)
        x_out = torch.einsum('bts,btsh->bth', attn_weights, v)

        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
        # if not need_head_weights:
        # average attention weights over heads
        attn_weights = attn_weights.mean(dim=0)

        # reshape x_out to (bsz, tgt_len, num_heads * head_dim)
        x_out = x_out.transpose(1, 0).reshape(tgt_len, bsz, self.num_heads * self.x_head_dim).transpose(1, 0).contiguous()

        return x_out, attn_weights


# original torchmd-net attention layer, let k, v share the same projection
class EquivariantProMultiHeadAttention(MessagePassing):
    """Equivariant multi-head attention layer."""

    def __init__(
        self,
        x_channels,
        x_hidden_channels,
        vec_channels,
        vec_hidden_channels,
        edge_attr_channels,
        distance_influence,
        num_heads,
        activation,
        attn_activation,
        cutoff_lower,
        cutoff_upper,
    ):
        super(EquivariantProMultiHeadAttention, self).__init__(
            aggr="mean", node_dim=0)
        assert x_hidden_channels % num_heads == 0 \
            and vec_channels % num_heads == 0, (
                f"The number of hidden channels x_hidden_channels ({x_hidden_channels}) "
                f"and vec_channels ({vec_channels}) "
                f"must be evenly divisible by the number of "
                f"attention heads ({num_heads})"
            )
        assert vec_hidden_channels == x_channels, (
            f"The number of hidden channels x_channels ({x_channels}) "
            f"and vec_hidden_channels ({vec_hidden_channels}) "
            f"must be equal"
        )

        self.distance_influence = distance_influence
        self.num_heads = num_heads
        self.x_channels = x_channels
        self.x_hidden_channels = x_hidden_channels
        self.x_head_dim = x_hidden_channels // num_heads
        self.vec_channels = vec_channels
        self.vec_hidden_channels = vec_hidden_channels
        # important, not vec_hidden_channels // num_heads
        self.vec_head_dim = vec_channels // num_heads

        self.layernorm = nn.LayerNorm(x_channels)
        self.act = activation()
        self.attn_activation = act_class_mapping[attn_activation]()
        self.cutoff = CosineCutoff(cutoff_lower,
cutoff_upper) self.q_proj = nn.Linear(x_channels, x_hidden_channels) # self.k_proj = nn.Linear(x_channels, x_hidden_channels) self.kv_proj = nn.Linear( x_channels, x_hidden_channels + vec_channels * 2) self.o_proj = nn.Linear( x_hidden_channels, x_channels * 2 + vec_channels) self.vec_proj = nn.Linear( vec_channels, vec_hidden_channels * 2 + vec_channels, bias=False) self.dk_proj = None if distance_influence in ["keys", "both"]: self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels) self.dv_proj = None if distance_influence in ["values", "both"]: self.dv_proj = nn.Linear( edge_attr_channels, x_hidden_channels + vec_channels * 2) self.reset_parameters() def reset_parameters(self): self.layernorm.reset_parameters() nn.init.xavier_uniform_(self.q_proj.weight) self.q_proj.bias.data.fill_(0) # nn.init.xavier_uniform_(self.k_proj.weight) # self.k_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.kv_proj.weight) self.kv_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.o_proj.weight) self.o_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.vec_proj.weight) if self.dk_proj: nn.init.xavier_uniform_(self.dk_proj.weight) self.dk_proj.bias.data.fill_(0) if self.dv_proj: nn.init.xavier_uniform_(self.dv_proj.weight) self.dv_proj.bias.data.fill_(0) def forward(self, x, vec, edge_index, r_ij, f_ij, d_ij, return_attn=False): x = self.layernorm(x) q = self.q_proj(x).reshape(-1, self.num_heads, self.x_head_dim) # k = self.k_proj(x).reshape(-1, self.num_heads, self.x_head_dim) v = self.kv_proj(x).reshape(-1, self.num_heads, self.x_head_dim + self.vec_head_dim * 2) k = v[:, :, :self.x_head_dim] vec1, vec2, vec3 = torch.split(self.vec_proj(vec), [self.vec_hidden_channels, self.vec_hidden_channels, self.vec_channels], dim=-1) vec = vec.reshape(-1, 3, self.num_heads, self.vec_head_dim) vec_dot = (vec1 * vec2).sum(dim=1) dk = ( self.act(self.dk_proj(f_ij)).reshape(-1, self.num_heads, self.x_head_dim) if self.dk_proj is not None else None ) dv = ( self.act(self.dv_proj(f_ij)).reshape(-1, self.num_heads, self.x_head_dim + self.vec_head_dim * 2) if self.dv_proj is not None else None ) # propagate_type: (q: Tensor, k: Tensor, v: Tensor, vec: Tensor, dk: Tensor, dv: Tensor, r_ij: Tensor, # d_ij: Tensor) x, vec, attn = self.propagate( edge_index, q=q, k=k, v=v, vec=vec, dk=dk, dv=dv, r_ij=r_ij, d_ij=d_ij, size=None, ) x = x.reshape(-1, self.x_hidden_channels) vec = vec.reshape(-1, 3, self.vec_channels) o1, o2, o3 = torch.split(self.o_proj( x), [self.vec_channels, self.x_channels, self.x_channels], dim=1) dx = vec_dot * o2 + o3 dvec = vec3 * o1.unsqueeze(1) + vec if return_attn: return dx, dvec, torch.concat((edge_index.T, attn), dim=1) else: return dx, dvec, None def message(self, q_i, k_j, v_j, vec_j, dk, dv, r_ij, d_ij): # attention mechanism if dk is None: attn = (q_i * k_j).sum(dim=-1) else: # TODO: consider add or multiply dk attn = (q_i * k_j * dk).sum(dim=-1) # attention activation function attn = self.attn_activation(attn) * self.cutoff(r_ij).unsqueeze(1) # value pathway if dv is not None: v_j = v_j * dv x, vec1, vec2 = torch.split( v_j, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=2) # update scalar features x = x * attn.unsqueeze(2) # update vector features vec = vec_j * vec1.unsqueeze(1) + vec2.unsqueeze(1) * \ d_ij.unsqueeze(2).unsqueeze(3) return x, vec, attn def aggregate( self, features: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], index: torch.Tensor, ptr: Optional[torch.Tensor], dim_size: Optional[int], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: x, vec, 
attn = features x = scatter(x, index, dim=self.node_dim, dim_size=dim_size) vec = scatter(vec, index, dim=self.node_dim, dim_size=dim_size) return x, vec, attn def update( self, inputs: Tuple[torch.Tensor, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor]: return inputs def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor: pass def edge_update(self) -> Tensor: pass # softmax version of torchmd-net attention layer class EquivariantMultiHeadAttentionSoftMax(EquivariantMultiHeadAttention): """Equivariant multi-head attention layer with softmax""" def __init__( self, x_channels, x_hidden_channels, vec_channels, vec_hidden_channels, share_kv, edge_attr_channels, distance_influence, num_heads, activation, attn_activation, cutoff_lower, cutoff_upper, use_lora=None, ): super(EquivariantMultiHeadAttentionSoftMax, self).__init__(x_channels=x_channels, x_hidden_channels=x_hidden_channels, vec_channels=vec_channels, vec_hidden_channels=vec_hidden_channels, share_kv=share_kv, edge_attr_channels=edge_attr_channels, distance_influence=distance_influence, num_heads=num_heads, activation=activation, attn_activation=attn_activation, cutoff_lower=cutoff_lower, cutoff_upper=cutoff_upper, use_lora=use_lora) self.attn_activation = nn.LeakyReLU(0.2) def message(self, q_i, k_j, v_j, vec_j, dk, dv, r_ij, d_ij, index: Tensor, ptr: Optional[Tensor], size_i: Optional[int]): # attention mechanism if dk is None: attn = (q_i * k_j).sum(dim=-1) else: # TODO: consider add or multiply dk attn = (q_i * k_j * dk).sum(dim=-1) # attention activation function attn = self.attn_activation(attn) * self.cutoff(r_ij).unsqueeze(1) attn = softmax(attn, index, ptr, size_i) # TODO: consider drop out attn or not. # attn = F.dropout(attn, p=self.dropout, training=self.training) # value pathway if dv is not None: v_j = v_j * dv x, vec1, vec2 = torch.split( v_j, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=2) # update scalar features x = x * attn.unsqueeze(2) # update vector features vec = (vec1.unsqueeze(1) * vec_j + vec2.unsqueeze(1) * d_ij.unsqueeze(2).unsqueeze(3)) \ * attn.unsqueeze(1).unsqueeze(3) return x, vec, attn # softmax version of torchmd-net attention layer, add pair-wise confidence of PAE class EquivariantPAEMultiHeadAttentionSoftMax(EquivariantPAEMultiHeadAttention): """Equivariant multi-head attention layer with softmax""" def __init__( self, x_channels, x_hidden_channels, vec_channels, vec_hidden_channels, share_kv, edge_attr_channels, edge_attr_dist_channels, distance_influence, num_heads, activation, attn_activation, cutoff_lower, cutoff_upper, use_lora=None, ): super(EquivariantPAEMultiHeadAttentionSoftMax, self).__init__( x_channels=x_channels, x_hidden_channels=x_hidden_channels, vec_channels=vec_channels, vec_hidden_channels=vec_hidden_channels, share_kv=share_kv, edge_attr_channels=edge_attr_channels, edge_attr_dist_channels=edge_attr_dist_channels, distance_influence=distance_influence, num_heads=num_heads, activation=activation, attn_activation=attn_activation, cutoff_lower=cutoff_lower, cutoff_upper=cutoff_upper, use_lora=use_lora) self.attn_activation = nn.LeakyReLU(0.2) def message(self, q_i, k_j, v_j, vec_j, dk, dk_dist, dv, dv_dist, d_ij, w_ij, index: Tensor, ptr: Optional[Tensor], size_i: Optional[int]): # attention mechanism attn = (q_i * k_j) if dk is not None: attn += dk if dk_dist is not None: attn += dk_dist * w_ij.unsqueeze(1).unsqueeze(2) attn = attn.sum(dim=-1) # attention activation function attn = self.attn_activation(attn) attn = softmax(attn, index, ptr, size_i) # 
TODO: consider drop out attn or not. # attn = F.dropout(attn, p=self.dropout, training=self.training) # value pathway if dv is not None: v_j += dv if dv_dist is not None: v_j += dv_dist * w_ij.unsqueeze(1).unsqueeze(2) x, vec1, vec2 = torch.split( v_j, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=2) # update scalar features x = x * attn.unsqueeze(2) # update vector features vec = (vec1.unsqueeze(1) * vec_j + vec2.unsqueeze(1) * d_ij.unsqueeze(2).unsqueeze(3)) \ * attn.unsqueeze(1).unsqueeze(3) return x, vec, attn # softmax version of torchmd-net attention layer, add pair-wise confidence of PAE (weighted variant) class EquivariantWeightedPAEMultiHeadAttentionSoftMax(EquivariantWeightedPAEMultiHeadAttention): """Equivariant multi-head attention layer with softmax""" def __init__( self, x_channels, x_hidden_channels, vec_channels, vec_hidden_channels, share_kv, edge_attr_channels, edge_attr_dist_channels, distance_influence, num_heads, activation, attn_activation, cutoff_lower, cutoff_upper, use_lora=None, ): super(EquivariantWeightedPAEMultiHeadAttentionSoftMax, self).__init__( x_channels=x_channels, x_hidden_channels=x_hidden_channels, vec_channels=vec_channels, vec_hidden_channels=vec_hidden_channels, share_kv=share_kv, edge_attr_channels=edge_attr_channels, edge_attr_dist_channels=edge_attr_dist_channels, distance_influence=distance_influence, num_heads=num_heads, activation=activation, attn_activation=attn_activation, cutoff_lower=cutoff_lower, cutoff_upper=cutoff_upper, use_lora=use_lora) self.attn_activation = nn.LeakyReLU(0.2) def message(self, q_i, k_j, v_j, vec_j, dk, dk_dist, dv, dv_dist, d_ij, w_ij, index: Tensor, ptr: Optional[Tensor], size_i: Optional[int]): # attention mechanism attn = (q_i * k_j) if dk_dist is not None: if dk is not None: attn *= (dk + dk_dist * w_ij.unsqueeze(1).unsqueeze(2)) else: attn *= dk_dist * w_ij.unsqueeze(1).unsqueeze(2) else: if dk is not None: attn *= dk attn = attn.sum(dim=-1) # attention activation function attn = self.attn_activation(attn) attn = softmax(attn, index, ptr, size_i) # TODO: consider drop out attn or not. # attn = F.dropout(attn, p=self.dropout, training=self.training) # value pathway if dv is not None: v_j += dv if dv_dist is not None: v_j += dv_dist * w_ij.unsqueeze(1).unsqueeze(2) x, vec1, vec2 = torch.split( v_j, [self.x_head_dim, self.vec_head_dim, self.vec_head_dim], dim=2) # update scalar features x = x * attn.unsqueeze(2) # update vector features vec = (vec1.unsqueeze(1) * vec_j + vec2.unsqueeze(1) * d_ij.unsqueeze(2).unsqueeze(3)) \ * attn.unsqueeze(1).unsqueeze(3) return x, vec, attn # MSA encoder adapted from gMVP class MSAEncoder(nn.Module): def __init__(self, num_species, pairwise_type, weighting_schema): """MSA encoder adapted from gMVP. Args: num_species (int): Number of species to use from the MSA, in [1, 200]; gMVP uses 200 by default. pairwise_type ([str]): method for calculating pairwise coevolution.
only "cov" supported weighting_schema ([str]): species weighting type; "spe" -> use dense layer to weight speices "none" -> uniformly weight species Raises: NotImplementedError: [description] """ super(MSAEncoder, self).__init__() self.num_species = num_species self.pairwise_type = pairwise_type self.weighting_schema = weighting_schema if self.weighting_schema == 'spe': self.W = nn.parameter.Parameter( torch.zeros((1, num_species)), requires_grad=True) elif self.weighting_schema == 'none': self.W = torch.tensor(1.0 / self.num_species).repeat(self.num_species) else: raise NotImplementedError def forward(self, x, edge_index): # x: L nodes x N num_species shape = x.shape L, N = shape[0], shape[1] E = edge_index.shape[1] A = 21 # number of amino acids x = x[:, :self.num_species] if self.weighting_schema == 'spe': sm = torch.nn.Softmax(dim=-1) W = sm(self.W) else: W = self.W x = nn.functional.one_hot(x.type(torch.int64), A).type(torch.float32) # L x N x A x1 = torch.matmul(W[:, None], x) # L x 1 x A if self.pairwise_type == 'fre': x2 = torch.matmul(x[edge_index[0], :, :, None], x[edge_index[1], :, None, :]) # E x N x A x A x2 = x2.reshape((E, N, A * A)) # E x N x (A x A) x2 = (W[:, :, None] * x2).sum(dim=1) # E x (A x A) elif self.pairwise_type == 'cov': #numerical stability x2 = torch.matmul(x[edge_index[0], :, :, None], x[edge_index[1], :, None, :]) # E x N x A x A x2 = (W[:, :, None, None] * x2).sum(dim=1) # E x A x A x2_t = x1[edge_index[0], 0, :, None] * x1[edge_index[1], 0, None, :] # E x A x A x2 = (x2 - x2_t).reshape(E, A * A) # E x (A x A) x2 = x2.reshape(E, A * A) # E x (A x A) norm = torch.sqrt(torch.sum(torch.square(x2), dim=-1, keepdim=True) + 1e-12) x2 = torch.cat([x2, norm], dim=-1) # E x (A x A + 1) elif self.pairwise_type == 'cov_all': print('cov_all not implemented in EvolEncoder2') raise NotImplementedError elif self.pairwise_type == 'inv_cov': print('in_cov not implemented in EvolEncoder2') raise NotImplementedError elif self.pairwise_type == 'none': x2 = None else: raise NotImplementedError( f'pairwise_type {self.pairwise_type} not implemented') x1 = torch.squeeze(x1, dim=1) # L x A return x1, x2 # MSA encoder adapted from gMVP class MSAEncoderFullGraph(nn.Module): def __init__(self, num_species, pairwise_type, weighting_schema): """[summary] Args: num_species (int): Number of species to use from MSA. [1,200] // 200 used in default gMVP pairwise_type ([str]): method for calculating pairwise coevolution. 
only "cov" supported weighting_schema ([str]): species weighting type; "spe" -> use dense layer to weight speices "none" -> uniformly weight species Raises: NotImplementedError: [description] """ super(MSAEncoderFullGraph, self).__init__() self.num_species = num_species self.pairwise_type = pairwise_type self.weighting_schema = weighting_schema if self.weighting_schema == 'spe': self.W = nn.parameter.Parameter( torch.zeros((num_species)), requires_grad=True) elif self.weighting_schema == 'none': self.W = torch.tensor(1.0 / self.num_species).repeat(self.num_species) else: raise NotImplementedError def forward(self, x): # x: B batch size x L lenth x N num_species shape = x.shape B, L, N = shape[0], shape[1], shape[2] A = 21 # number of amino acids x = x[:, :, :self.num_species] if self.weighting_schema == 'spe': W = torch.nn.functional.softmax(self.W, dim=-1) else: W = self.W x = nn.functional.one_hot(x.type(torch.int64), A).type(torch.float32) # B x L x N x A x1 = torch.einsum('blna,n->bla', x, W) # B x L x A if self.pairwise_type == 'cov': #numerical stability # x2 = torch.einsum('bLnA,blna,n->bLlAa', x, x, W) # B x L x L x A x A, check if ram supports this # x2_t = x1[:, :, None, :, None] * x1[:, None, :, None, :] # B x L x L x A x A # x2 = (x2 - x2_t).reshape(B, L, L, A * A) # B x L x L x (A x A) # complete that in one line to save memory x2 = (torch.einsum('bLnA,blna,n->bLlAa', x, x, W) - x1[:, :, None, :, None] * x1[:, None, :, None, :]).reshape(B, L, L, A * A) norm = torch.sqrt(torch.sum(torch.square(x2), dim=-1, keepdim=True) + 1e-12) # B x L x L x 1 x2 = torch.cat([x2, norm], dim=-1) # B x L x L x (A x A + 1) elif self.pairwise_type == 'cov_all': print('cov_all not implemented in EvolEncoder2') raise NotImplementedError elif self.pairwise_type == 'inv_cov': print('in_cov not implemented in EvolEncoder2') raise NotImplementedError elif self.pairwise_type == 'none': x2 = None else: raise NotImplementedError( f'pairwise_type {self.pairwise_type} not implemented') return x1, x2 class NodeToEdgeAttr(nn.Module): def __init__(self, node_channel, hidden_channel, edge_attr_channel, use_lora=None, layer_norm=False): super().__init__() self.layer_norm = layer_norm if layer_norm: self.layernorm = nn.LayerNorm(node_channel) if use_lora is not None: self.proj = lora.Linear(node_channel, hidden_channel * 2, bias=True, r=use_lora) self.o_proj = lora.Linear(2 * hidden_channel, edge_attr_channel, r=use_lora) else: self.proj = nn.Linear(node_channel, hidden_channel * 2, bias=True) self.o_proj = nn.Linear(2 * hidden_channel, edge_attr_channel, bias=True) torch.nn.init.zeros_(self.proj.bias) torch.nn.init.zeros_(self.o_proj.bias) def forward(self, x, edge_index): """ Inputs: x: N x sequence_state_dim Output: edge_attr: edge_index.shape[0] x pairwise_state_dim Intermediate state: B x L x L x 2*inner_dim """ x = self.layernorm(x) if self.layer_norm else x q, k = self.proj(x).chunk(2, dim=-1) prod = q[edge_index[0], :] * k[edge_index[1], :] diff = q[edge_index[0], :] - k[edge_index[1], :] edge_attr = torch.cat([prod, diff], dim=-1) edge_attr = self.o_proj(edge_attr) return edge_attr class MultiplicativeUpdate(MessagePassing): def __init__(self, vec_in_channel, hidden_channel, hidden_vec_channel, ee_channels=None, use_lora=None, layer_norm=True) -> None: super(MultiplicativeUpdate, self).__init__(aggr="mean") self.vec_in_channel = vec_in_channel self.hidden_channel = hidden_channel self.hidden_vec_channel = hidden_vec_channel if use_lora is not None: self.linear_a_p = lora.Linear(self.vec_in_channel, 
self.hidden_vec_channel, bias=False, r=use_lora) self.linear_b_p = lora.Linear(self.vec_in_channel, self.hidden_vec_channel, bias=False, r=use_lora) self.linear_g = lora.Linear(self.hidden_vec_channel, self.hidden_channel, r=use_lora) else: self.linear_a_p = nn.Linear(self.vec_in_channel, self.hidden_vec_channel, bias=False) self.linear_b_p = nn.Linear(self.vec_in_channel, self.hidden_vec_channel, bias=False) self.linear_g = nn.Linear(self.hidden_vec_channel, self.hidden_channel) if ee_channels is not None: if use_lora is not None: self.linear_ee = lora.Linear(ee_channels, self.hidden_channel, r=use_lora) else: self.linear_ee = nn.Linear(ee_channels, self.hidden_channel) else: self.linear_ee = None self.layer_norm = layer_norm if layer_norm: self.layer_norm_in = nn.LayerNorm(self.hidden_channel) self.layer_norm_out = nn.LayerNorm(self.hidden_channel) self.sigmoid = nn.Sigmoid() def forward(self, edge_attr: torch.Tensor, edge_vec: torch.Tensor, edge_edge_index: torch.Tensor, edge_edge_attr: torch.Tensor, ) -> torch.Tensor: """ Args: edge_vec: [*, 3, in_channel] input tensor edge_attr: [*, hidden_channel] input mask Returns: [*, hidden_channel] output tensor """ if self.layer_norm: x = self.layer_norm_in(edge_attr) x = self.propagate(edge_index=edge_edge_index, a=self.linear_a_p(edge_vec).reshape(edge_attr.shape[0], -1), b=self.linear_b_p(edge_vec).reshape(edge_attr.shape[0], -1), edge_attr=x, ee_ij=edge_edge_attr, ) if self.layer_norm: x = self.layer_norm_out(x) edge_attr = edge_attr + x return edge_attr def message(self, a_i: Tensor, b_j: Tensor, edge_attr_j: Tensor, ee_ij: Tensor,) -> Tensor: # a: [*, 3, hidden_channel] # b: [*, 3, hidden_channel] s = (a_i.reshape(-1, 3, self.hidden_vec_channel).permute(0, 2, 1) \ * b_j.reshape(-1, 3, self.hidden_vec_channel).permute(0, 2, 1)).sum(dim=-1) if ee_ij is not None and self.linear_ee is not None: s = self.sigmoid(self.linear_ee(ee_ij) + self.linear_g(s)) else: s = self.sigmoid(self.linear_g(s)) return s * edge_attr_j # let k v share the same weight class EquivariantTriAngularMultiHeadAttention(MessagePassing): """Equivariant multi-head attention layer. 
Add Triangular update between edges.""" def __init__( self, x_channels, x_hidden_channels, vec_channels, vec_hidden_channels, edge_attr_channels, distance_influence, num_heads, activation, attn_activation, cutoff_lower, cutoff_upper, triangular_update=False, ee_channels=None, ): super(EquivariantTriAngularMultiHeadAttention, self).__init__(aggr="mean", node_dim=0) self.distance_influence = distance_influence self.num_heads = num_heads self.x_channels = x_channels self.x_hidden_channels = x_hidden_channels self.x_head_dim = x_hidden_channels // num_heads self.vec_channels = vec_channels self.vec_hidden_channels = vec_hidden_channels self.ee_channels = ee_channels # important, not vec_hidden_channels // num_heads self.layernorm_in = nn.LayerNorm(x_channels) self.layernorm_out = nn.LayerNorm(x_hidden_channels) self.act = activation() self.attn_activation = act_class_mapping[attn_activation]() self.q_proj = nn.Linear(x_channels, x_hidden_channels) self.kv_proj = nn.Linear(x_channels, x_hidden_channels) # self.v_proj = nn.Linear(x_channels, x_hidden_channels) self.o_proj = nn.Linear(x_hidden_channels, x_hidden_channels) self.out = nn.Linear(x_hidden_channels, x_channels) # add residue to x # self.residue_hidden = nn.Linear(x_channels, x_hidden_channels) self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels) self.triangular_update = triangular_update if self.triangular_update: self.edge_triangle_start_update = MultiplicativeUpdate(vec_in_channel=vec_channels, hidden_channel=edge_attr_channels, hidden_vec_channel=vec_hidden_channels, ee_channels=ee_channels, ) self.edge_triangle_end_update = MultiplicativeUpdate(vec_in_channel=vec_channels, hidden_channel=edge_attr_channels, hidden_vec_channel=vec_hidden_channels, ee_channels=ee_channels, ) self.node_to_edge_attr = NodeToEdgeAttr(node_channel=x_channels, hidden_channel=x_hidden_channels, edge_attr_channel=edge_attr_channels) self.reset_parameters() def reset_parameters(self): self.layernorm_in.reset_parameters() self.layernorm_out.reset_parameters() nn.init.xavier_uniform_(self.q_proj.weight) self.q_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.kv_proj.weight) self.kv_proj.bias.data.fill_(0) # nn.init.xavier_uniform_(self.v_proj.weight) # self.v_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.o_proj.weight) self.o_proj.bias.data.fill_(0) if self.dk_proj: nn.init.xavier_uniform_(self.dk_proj.weight) self.dk_proj.bias.data.fill_(0) def get_start_index(self, edge_index): edge_start_index = [] start_node_count = edge_index[0].unique(return_counts=True) start_nodes = start_node_count[0][start_node_count[1] > 1] for i in start_nodes: node_start_index = torch.where(edge_index[0] == i)[0] candidates = torch.combinations(node_start_index, r=2).T edge_start_index.append(torch.cat([candidates, candidates.flip(0)], dim=1)) edge_start_index = torch.concat(edge_start_index, dim=1) edge_start_index = edge_start_index[:, edge_start_index[0] != edge_start_index[1]] return edge_start_index def get_end_index(self, edge_index): edge_end_index = [] end_node_count = edge_index[1].unique(return_counts=True) end_nodes = end_node_count[0][end_node_count[1] > 1] for i in end_nodes: node_end_index = torch.where(edge_index[1] == i)[0] candidates = torch.combinations(node_end_index, r=2).T edge_end_index.append(torch.cat([candidates, candidates.flip(0)], dim=1)) edge_end_index = torch.concat(edge_end_index, dim=1) edge_end_index = edge_end_index[:, edge_end_index[0] != edge_end_index[1]] return edge_end_index def forward(self, x, coords, edge_index, 
edge_attr, edge_vec, return_attn=False): residue = x x = self.layernorm_in(x) q = self.q_proj(x).reshape(-1, self.num_heads, self.x_head_dim) k = self.kv_proj(x).reshape(-1, self.num_heads, self.x_head_dim) v = k # point ettr to edge_attr if self.triangular_update: edge_attr += self.node_to_edge_attr(x, edge_index) # Triangular edge update # TODO: Add drop out layers here edge_edge_index = self.get_start_index(edge_index) if self.ee_channels is not None: edge_edge_attr = coords[edge_index[1][edge_edge_index[0]], :, [0]] - coords[edge_index[1][edge_edge_index[1]], :, [0]] edge_edge_attr = torch.norm(edge_edge_attr, dim=-1, keepdim=True) else: edge_edge_attr = None edge_attr = self.edge_triangle_start_update( edge_attr, edge_vec, edge_edge_index, edge_edge_attr ) edge_edge_index = self.get_end_index(edge_index) if self.ee_channels is not None: edge_edge_attr = coords[edge_index[0][edge_edge_index[0]], :, [0]] - coords[edge_index[0][edge_edge_index[1]], :, [0]] edge_edge_attr = torch.norm(edge_edge_attr, dim=-1, keepdim=True) else: edge_edge_attr = None edge_attr = self.edge_triangle_end_update( edge_attr, edge_vec, edge_edge_index, edge_edge_attr ) del edge_edge_attr, edge_edge_index dk = ( self.act(self.dk_proj(edge_attr)).reshape(-1, self.num_heads, self.x_head_dim) if self.dk_proj is not None else None ) # propagate_type: (q: Tensor, k: Tensor, v: Tensor, vec: Tensor, dk: Tensor, dv: Tensor, r_ij: Tensor, # d_ij: Tensor) x, attn = self.propagate( edge_index, q=q, k=k, v=v, dk=dk, size=None, ) x = x.reshape(-1, self.x_hidden_channels) x = residue + x x = self.layernorm_out(x) x = gelu(self.o_proj(x)) x = self.out(x) del residue, q, k, v, dk if return_attn: return x, edge_attr, torch.concat((edge_index.T, attn), dim=1) else: return x, edge_attr, None def message(self, q_i, k_j, v_j, dk): # attention mechanism if dk is None: attn = (q_i * k_j).sum(dim=-1) else: # TODO: consider add or multiply dk attn = (q_i * k_j * dk).sum(dim=-1) # attention activation function attn = self.attn_activation(attn) # update scalar features x = v_j * attn.unsqueeze(2) return x, attn def aggregate( self, features: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], index: torch.Tensor, ptr: Optional[torch.Tensor], dim_size: Optional[int], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: x, attn = features x = scatter(x, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr) return x, attn def update( self, inputs: Tuple[torch.Tensor, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor]: return inputs def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor: pass def edge_update(self) -> Tensor: pass # let k v share the same weight, dropout attention weights, with option LoRA class EquivariantTriAngularDropMultiHeadAttention(MessagePassing): """Equivariant multi-head attention layer. 
Add Triangular update between edges.""" def __init__( self, x_channels, x_hidden_channels, vec_channels, vec_hidden_channels, edge_attr_channels, distance_influence, num_heads, activation, attn_activation, rbf_channels, triangular_update=False, ee_channels=None, drop_out_rate=0.0, use_lora=None, layer_norm=True, ): super(EquivariantTriAngularDropMultiHeadAttention, self).__init__(aggr="mean", node_dim=0) self.distance_influence = distance_influence self.num_heads = num_heads self.x_channels = x_channels self.x_hidden_channels = x_hidden_channels self.x_head_dim = x_hidden_channels // num_heads self.vec_channels = vec_channels self.vec_hidden_channels = vec_hidden_channels self.ee_channels = ee_channels self.rbf_channels = rbf_channels self.layer_norm = layer_norm # important, not vec_hidden_channels // num_heads if layer_norm: self.layernorm_in = nn.LayerNorm(x_channels) self.layernorm_out = nn.LayerNorm(x_hidden_channels) self.act = activation() self.attn_activation = act_class_mapping[attn_activation]() if use_lora is not None: self.q_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) self.kv_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) self.dk_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora) self.o_proj = lora.Linear(x_hidden_channels, x_hidden_channels, r=use_lora) else: self.q_proj = nn.Linear(x_channels, x_hidden_channels) self.kv_proj = nn.Linear(x_channels, x_hidden_channels) self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels) self.o_proj = nn.Linear(x_hidden_channels, x_hidden_channels) self.triangular_drop = nn.Dropout(drop_out_rate) self.rbf_drop = nn.Dropout(drop_out_rate) self.dense_drop = nn.Dropout(drop_out_rate) self.dropout = nn.Dropout(drop_out_rate) self.triangular_update = triangular_update if self.triangular_update: self.edge_triangle_end_update = MultiplicativeUpdate(vec_in_channel=vec_channels, hidden_channel=edge_attr_channels, hidden_vec_channel=vec_hidden_channels, ee_channels=ee_channels, layer_norm=layer_norm, use_lora=use_lora) self.node_to_edge_attr = NodeToEdgeAttr(node_channel=x_channels, hidden_channel=x_hidden_channels, edge_attr_channel=edge_attr_channels, use_lora=use_lora) self.triangle_update_dropout = nn.Dropout(0.5) self.reset_parameters() def reset_parameters(self): if self.layer_norm: self.layernorm_in.reset_parameters() self.layernorm_out.reset_parameters() nn.init.xavier_uniform_(self.q_proj.weight) self.q_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.kv_proj.weight) self.kv_proj.bias.data.fill_(0) # nn.init.xavier_uniform_(self.v_proj.weight) # self.v_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.o_proj.weight) self.o_proj.bias.data.fill_(0) if self.dk_proj: nn.init.xavier_uniform_(self.dk_proj.weight) self.dk_proj.bias.data.fill_(0) def get_start_index(self, edge_index): edge_start_index = [] start_node_count = edge_index[0].unique(return_counts=True) start_nodes = start_node_count[0][start_node_count[1] > 1] for i in start_nodes: node_start_index = torch.where(edge_index[0] == i)[0] candidates = torch.combinations(node_start_index, r=2).T edge_start_index.append(torch.cat([candidates, candidates.flip(0)], dim=1)) edge_start_index = torch.concat(edge_start_index, dim=1) edge_start_index = edge_start_index[:, edge_start_index[0] != edge_start_index[1]] return edge_start_index def get_end_index(self, edge_index): edge_end_index = [] end_node_count = edge_index[1].unique(return_counts=True) end_nodes = end_node_count[0][end_node_count[1] > 1] for i in end_nodes: 
node_end_index = torch.where(edge_index[1] == i)[0] candidates = torch.combinations(node_end_index, r=2).T edge_end_index.append(torch.cat([candidates, candidates.flip(0)], dim=1)) edge_end_index = torch.concat(edge_end_index, dim=1) edge_end_index = edge_end_index[:, edge_end_index[0] != edge_end_index[1]] return edge_end_index def forward(self, x, coords, edge_index, edge_attr, edge_vec, return_attn=False): residue = x if self.layer_norm: x = self.layernorm_in(x) q = self.q_proj(x).reshape(-1, self.num_heads, self.x_head_dim) k = self.kv_proj(x).reshape(-1, self.num_heads, self.x_head_dim) v = k # point ettr to edge_attr if self.triangular_update: edge_attr += self.node_to_edge_attr(x, edge_index) # Triangular edge update # TODO: Add drop out layers here edge_edge_index = self.get_end_index(edge_index) edge_edge_index = edge_edge_index[:, self.triangular_drop( torch.ones(edge_edge_index.shape[1], device=edge_edge_index.device) ).to(torch.bool)] if self.ee_channels is not None: edge_edge_attr = coords[edge_index[0][edge_edge_index[0]], :, [0]] - coords[edge_index[0][edge_edge_index[1]], :, [0]] edge_edge_attr = torch.norm(edge_edge_attr, dim=-1, keepdim=True) else: edge_edge_attr = None edge_attr = self.edge_triangle_end_update( edge_attr, edge_vec, edge_edge_index, edge_edge_attr ) del edge_edge_attr, edge_edge_index # drop rbfs edge_attr = torch.cat((edge_attr[:, :-self.rbf_channels], self.rbf_drop(edge_attr[:, -self.rbf_channels:])), dim=-1) dk = ( self.act(self.dk_proj(edge_attr)).reshape(-1, self.num_heads, self.x_head_dim) if self.dk_proj is not None else None ) # propagate_type: (q: Tensor, k: Tensor, v: Tensor, vec: Tensor, dk: Tensor, dv: Tensor, r_ij: Tensor, # d_ij: Tensor) x, attn = self.propagate( edge_index, q=q, k=k, v=v, dk=dk, size=None, ) x = x.reshape(-1, self.x_hidden_channels) if self.layer_norm: x = self.layernorm_out(x) x = self.dense_drop(x) x = residue + gelu(x) x = self.o_proj(x) x = self.dropout(x) del residue, q, k, v, dk if return_attn: return x, edge_attr, torch.concat((edge_index.T, attn), dim=1) else: return x, edge_attr, None def message(self, q_i, k_j, v_j, dk): # attention mechanism if dk is None: attn = (q_i * k_j).sum(dim=-1) else: # TODO: consider add or multiply dk attn = (q_i * k_j * dk).sum(dim=-1) # attention activation function attn = self.attn_activation(attn) # update scalar features x = v_j * attn.unsqueeze(2) return x, attn def aggregate( self, features: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], index: torch.Tensor, ptr: Optional[torch.Tensor], dim_size: Optional[int], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: x, attn = features x = scatter(x, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr) return x, attn def update( self, inputs: Tuple[torch.Tensor, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor]: return inputs def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor: pass def edge_update(self) -> Tensor: pass # let k v share the same weight class EquivariantTriAngularStarMultiHeadAttention(MessagePassing): """ Equivariant multi-head attention layer. Add Triangular update between edges. Only update the center node. 
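
    Example (an illustrative sketch, not a prescribed configuration: the channel
    sizes and the "silu" key of act_class_mapping are assumptions; note that the
    GRU update as written requires x_hidden_channels == x_channels):

        layer = EquivariantTriAngularStarMultiHeadAttention(
            x_channels=64, x_hidden_channels=64, vec_channels=16,
            vec_hidden_channels=64, edge_attr_channels=32,
            distance_influence="both", num_heads=8,
            activation=nn.SiLU, attn_activation="silu",
            cutoff_lower=0.0, cutoff_upper=5.0,
            triangular_update=True, ee_channels=None)
        # assumed shapes: x [n_nodes, 64], coords [n_nodes, 3, n_atoms],
        # edge_index [2, n_edges] with edges pointing into the center nodes,
        # edge_attr [n_edges, 32], edge_vec [n_edges, 3, 16]
        x_center, edge_attr, _ = layer(x, coords, edge_index, edge_attr, edge_vec)
        # x_center holds updated features for the center nodes only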
""" def __init__( self, x_channels, x_hidden_channels, vec_channels, vec_hidden_channels, edge_attr_channels, distance_influence, num_heads, activation, attn_activation, cutoff_lower, cutoff_upper, triangular_update=False, ee_channels=None, ): super(EquivariantTriAngularStarMultiHeadAttention, self).__init__(aggr="mean", node_dim=0) self.distance_influence = distance_influence self.num_heads = num_heads self.x_channels = x_channels self.x_hidden_channels = x_hidden_channels self.x_head_dim = x_hidden_channels // num_heads self.vec_channels = vec_channels self.vec_hidden_channels = vec_hidden_channels self.ee_channels = ee_channels # important, not vec_hidden_channels // num_heads # self.layernorm_in = nn.LayerNorm(x_channels) self.layernorm_out = nn.LayerNorm(x_hidden_channels) self.act = activation() self.attn_activation = act_class_mapping[attn_activation]() self.q_proj = nn.Linear(x_channels, x_hidden_channels) self.kv_proj = nn.Linear(x_channels, x_hidden_channels) # self.v_proj = nn.Linear(x_channels, x_hidden_channels) # self.o_proj = nn.Linear(x_hidden_channels, x_hidden_channels) # self.out = nn.Linear(x_hidden_channels, x_channels) # add residue to x # self.residue_hidden = nn.Linear(x_channels, x_hidden_channels) self.gru = nn.GRUCell(x_channels, x_channels) self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels) self.triangular_update = triangular_update if self.triangular_update: # self.edge_triangle_start_update = MultiplicativeUpdate(vec_in_channel=vec_channels, # hidden_channel=edge_attr_channels, # hidden_vec_channel=vec_hidden_channels, # ee_channels=ee_channels, ) self.edge_triangle_end_update = MultiplicativeUpdate(vec_in_channel=vec_channels, hidden_channel=edge_attr_channels, hidden_vec_channel=vec_hidden_channels, ee_channels=ee_channels, ) self.node_to_edge_attr = NodeToEdgeAttr(node_channel=x_channels, hidden_channel=x_hidden_channels, edge_attr_channel=edge_attr_channels) self.reset_parameters() def reset_parameters(self): # self.layernorm_in.reset_parameters() self.layernorm_out.reset_parameters() nn.init.xavier_uniform_(self.q_proj.weight) self.q_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.kv_proj.weight) self.kv_proj.bias.data.fill_(0) # nn.init.xavier_uniform_(self.v_proj.weight) # self.v_proj.bias.data.fill_(0) # nn.init.xavier_uniform_(self.o_proj.weight) # self.o_proj.bias.data.fill_(0) if self.dk_proj: nn.init.xavier_uniform_(self.dk_proj.weight) self.dk_proj.bias.data.fill_(0) def get_start_index(self, edge_index): edge_start_index = [] start_node_count = edge_index[0].unique(return_counts=True) start_nodes = start_node_count[0][start_node_count[1] > 1] for i in start_nodes: node_start_index = torch.where(edge_index[0] == i)[0] candidates = torch.combinations(node_start_index, r=2).T edge_start_index.append(torch.cat([candidates, candidates.flip(0)], dim=1)) edge_start_index = torch.concat(edge_start_index, dim=1) edge_start_index = edge_start_index[:, edge_start_index[0] != edge_start_index[1]] return edge_start_index def get_end_index(self, edge_index): edge_end_index = [] end_node_count = edge_index[1].unique(return_counts=True) end_nodes = end_node_count[0][end_node_count[1] > 1] for i in end_nodes: node_end_index = torch.where(edge_index[1] == i)[0] candidates = torch.combinations(node_end_index, r=2).T edge_end_index.append(torch.cat([candidates, candidates.flip(0)], dim=1)) edge_end_index = torch.concat(edge_end_index, dim=1) edge_end_index = edge_end_index[:, edge_end_index[0] != edge_end_index[1]] return edge_end_index def 
forward(self, x, coords, edge_index, edge_attr, edge_vec, return_attn=False): # perform topK pooling end_node_count = edge_index[1].unique(return_counts=True) center_nodes = end_node_count[0][end_node_count[1] > 1] other_nodes = end_node_count[0][end_node_count[1] <= 1] residue = x[center_nodes] # batch_size * x_channels # filter edge_index and edge_attr to from context to center only edge_attr = edge_attr[torch.isin(edge_index[1], center_nodes), :] edge_vec = edge_vec[torch.isin(edge_index[1], center_nodes), :] edge_index = edge_index[:, torch.isin(edge_index[1], center_nodes)] # x itself is q, k and v q = self.q_proj(residue).reshape(-1, self.num_heads, self.x_head_dim) kv = self.kv_proj(x[other_nodes]).reshape(-1, self.num_heads, self.x_head_dim) qkv = torch.zeros(x.shape[0], self.num_heads, self.x_head_dim).to(x.device, non_blocking=True) qkv[center_nodes] = q qkv[other_nodes] = kv # point ettr to edge_attr if self.triangular_update: edge_attr += self.node_to_edge_attr(x, edge_index) # Triangular edge update # TODO: Add drop out layers here edge_edge_index = self.get_end_index(edge_index) if self.ee_channels is not None: edge_edge_attr = coords[edge_index[0][edge_edge_index[0]], :, [0]] - coords[edge_index[0][edge_edge_index[1]], :, [0]] edge_edge_attr = torch.norm(edge_edge_attr, dim=-1, keepdim=True) else: edge_edge_attr = None edge_attr = self.edge_triangle_end_update( edge_attr, edge_vec, edge_edge_index, edge_edge_attr ) del edge_edge_attr, edge_edge_index dk = ( self.act(self.dk_proj(edge_attr)).reshape(-1, self.num_heads, self.x_head_dim) if self.dk_proj is not None else None ) # TODO: check self.act # propagate_type: (q: Tensor, k: Tensor, v: Tensor, vec: Tensor, dk: Tensor, dv: Tensor, r_ij: Tensor, # d_ij: Tensor) x, attn = self.propagate( edge_index, q=qkv, k=qkv, v=qkv, dk=dk, size=None, ) x = x.reshape(-1, self.x_hidden_channels) # only get the center nodes x = x[center_nodes] x = self.layernorm_out(x) x = self.gru(residue, x) del residue, dk if return_attn: return x, edge_attr, torch.concat((edge_index.T, attn), dim=1) else: return x, edge_attr, None def message(self, q_i, k_j, v_j, dk): # attention mechanism if dk is None: attn = (q_i * k_j).sum(dim=-1) else: # TODO: consider add or multiply dk attn = (q_i * k_j + dk).sum(dim=-1) # attention activation function attn = self.attn_activation(attn) / self.x_head_dim # update scalar features x = v_j * attn.unsqueeze(2) return x, attn def aggregate( self, features: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], index: torch.Tensor, ptr: Optional[torch.Tensor], dim_size: Optional[int], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: x, attn = features x = scatter(x, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr) return x, attn def update( self, inputs: Tuple[torch.Tensor, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor]: return inputs def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor: pass def edge_update(self) -> Tensor: pass # let k v share the same weight, dropout attention weights, with option LoRA class EquivariantTriAngularStarDropMultiHeadAttention(MessagePassing): """ Equivariant multi-head attention layer. Add Triangular update between edges. Only update the center node. 
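
    Example (an illustrative sketch; the values are assumptions, not defaults.
    use_lora replaces the q/kv/dk projections with loralib.Linear adapters of the
    given rank, and drop_out_rate is applied to the triangle edges, the RBF part
    of edge_attr, and the dense output):

        layer = EquivariantTriAngularStarDropMultiHeadAttention(
            x_channels=64, x_hidden_channels=64, vec_channels=16,
            vec_hidden_channels=64, edge_attr_channels=32,
            distance_influence="both", num_heads=8,
            activation=nn.SiLU, attn_activation="silu", rbf_channels=16,
            triangular_update=True, drop_out_rate=0.1, use_lora=8)
        # To train only the adapters, freeze the base weights at the model level,
        # e.g. with lora.mark_only_lora_as_trainable(model).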
""" def __init__( self, x_channels, x_hidden_channels, vec_channels, vec_hidden_channels, edge_attr_channels, distance_influence, num_heads, activation, attn_activation, rbf_channels, triangular_update=False, ee_channels=None, drop_out_rate=0.0, use_lora=None, ): super(EquivariantTriAngularStarDropMultiHeadAttention, self).__init__(aggr="mean", node_dim=0) self.distance_influence = distance_influence self.num_heads = num_heads self.x_channels = x_channels self.x_hidden_channels = x_hidden_channels self.x_head_dim = x_hidden_channels // num_heads self.vec_channels = vec_channels self.vec_hidden_channels = vec_hidden_channels self.ee_channels = ee_channels self.rbf_channels = rbf_channels # important, not vec_hidden_channels // num_heads # self.layernorm_in = nn.LayerNorm(x_channels) self.layernorm_out = nn.LayerNorm(x_hidden_channels) self.act = activation() self.attn_activation = act_class_mapping[attn_activation]() if use_lora is not None: self.q_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) self.kv_proj = lora.Linear(x_channels, x_hidden_channels, r=use_lora) self.dk_proj = lora.Linear(edge_attr_channels, x_hidden_channels, r=use_lora) else: self.q_proj = nn.Linear(x_channels, x_hidden_channels) self.kv_proj = nn.Linear(x_channels, x_hidden_channels) self.dk_proj = nn.Linear(edge_attr_channels, x_hidden_channels) # self.v_proj = nn.Linear(x_channels, x_hidden_channels) # self.o_proj = nn.Linear(x_hidden_channels, x_hidden_channels) # self.out = nn.Linear(x_hidden_channels, x_channels) # add residue to x # self.residue_hidden = nn.Linear(x_channels, x_hidden_channels) self.gru = nn.GRUCell(x_channels, x_channels) self.triangular_drop = nn.Dropout(drop_out_rate) self.rbf_drop = nn.Dropout(drop_out_rate) self.dense_drop = nn.Dropout(drop_out_rate) self.dropout = nn.Dropout(drop_out_rate) self.triangular_update = triangular_update if self.triangular_update: self.edge_triangle_end_update = MultiplicativeUpdate(vec_in_channel=vec_channels, hidden_channel=edge_attr_channels, hidden_vec_channel=vec_hidden_channels, ee_channels=ee_channels, use_lora=use_lora) self.node_to_edge_attr = NodeToEdgeAttr(node_channel=x_channels, hidden_channel=x_hidden_channels, edge_attr_channel=edge_attr_channels, use_lora=use_lora) self.triangle_update_dropout = nn.Dropout(0.5) self.reset_parameters() def reset_parameters(self): # self.layernorm_in.reset_parameters() self.layernorm_out.reset_parameters() nn.init.xavier_uniform_(self.q_proj.weight) self.q_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.kv_proj.weight) self.kv_proj.bias.data.fill_(0) # nn.init.xavier_uniform_(self.v_proj.weight) # self.v_proj.bias.data.fill_(0) # nn.init.xavier_uniform_(self.o_proj.weight) # self.o_proj.bias.data.fill_(0) if self.dk_proj: nn.init.xavier_uniform_(self.dk_proj.weight) self.dk_proj.bias.data.fill_(0) def get_start_index(self, edge_index): edge_start_index = [] start_node_count = edge_index[0].unique(return_counts=True) start_nodes = start_node_count[0][start_node_count[1] > 1] for i in start_nodes: node_start_index = torch.where(edge_index[0] == i)[0] candidates = torch.combinations(node_start_index, r=2).T edge_start_index.append(torch.cat([candidates, candidates.flip(0)], dim=1)) edge_start_index = torch.concat(edge_start_index, dim=1) edge_start_index = edge_start_index[:, edge_start_index[0] != edge_start_index[1]] return edge_start_index def get_end_index(self, edge_index): edge_end_index = [] end_node_count = edge_index[1].unique(return_counts=True) end_nodes = 
end_node_count[0][end_node_count[1] > 1] for i in end_nodes: node_end_index = torch.where(edge_index[1] == i)[0] candidates = torch.combinations(node_end_index, r=2).T edge_end_index.append(torch.cat([candidates, candidates.flip(0)], dim=1)) edge_end_index = torch.concat(edge_end_index, dim=1) edge_end_index = edge_end_index[:, edge_end_index[0] != edge_end_index[1]] return edge_end_index def forward(self, x, coords, edge_index, edge_attr, edge_vec, return_attn=False): # perform topK pooling end_node_count = edge_index[1].unique(return_counts=True) center_nodes = end_node_count[0][end_node_count[1] > 1] other_nodes = end_node_count[0][end_node_count[1] <= 1] residue = x[center_nodes] # batch_size * x_channels # filter edge_index and edge_attr to from context to center only edge_attr = edge_attr[torch.isin(edge_index[1], center_nodes), :] edge_vec = edge_vec[torch.isin(edge_index[1], center_nodes), :] edge_index = edge_index[:, torch.isin(edge_index[1], center_nodes)] # x itself is q, k and v q = self.q_proj(residue).reshape(-1, self.num_heads, self.x_head_dim) kv = self.kv_proj(x[other_nodes]).reshape(-1, self.num_heads, self.x_head_dim) qkv = torch.zeros(x.shape[0], self.num_heads, self.x_head_dim).to(x.device, non_blocking=True) qkv[center_nodes] = q qkv[other_nodes] = kv # point ettr to edge_attr if self.triangular_update: edge_attr += self.node_to_edge_attr(x, edge_index) # Triangular edge update # TODO: Add drop out layers here edge_edge_index = self.get_end_index(edge_index) edge_edge_index = edge_edge_index[:, self.triangular_drop( torch.ones(edge_edge_index.shape[1], device=edge_edge_index.device) ).to(torch.bool)] if self.ee_channels is not None: edge_edge_attr = coords[edge_index[0][edge_edge_index[0]], :, [0]] - coords[edge_index[0][edge_edge_index[1]], :, [0]] edge_edge_attr = torch.norm(edge_edge_attr, dim=-1, keepdim=True) else: edge_edge_attr = None edge_attr = self.edge_triangle_end_update( edge_attr, edge_vec, edge_edge_index, edge_edge_attr ) del edge_edge_attr, edge_edge_index # drop rbfs edge_attr = torch.cat((edge_attr[:, :-self.rbf_channels], self.rbf_drop(edge_attr[:, -self.rbf_channels:])), dim=-1) dk = ( self.act(self.dk_proj(edge_attr)).reshape(-1, self.num_heads, self.x_head_dim) if self.dk_proj is not None else None ) # TODO: check self.act # propagate_type: (q: Tensor, k: Tensor, v: Tensor, vec: Tensor, dk: Tensor, dv: Tensor, r_ij: Tensor, # d_ij: Tensor) x, attn = self.propagate( edge_index, q=qkv, k=qkv, v=qkv, dk=dk, size=None, ) x = x.reshape(-1, self.x_hidden_channels) # only get the center nodes x = x[center_nodes] x = self.layernorm_out(x) x = self.dense_drop(x) x = self.gru(residue, x) x = self.dropout(x) del residue, dk if return_attn: return x, edge_attr, torch.concat((edge_index.T, attn), dim=1) else: return x, edge_attr, None def message(self, q_i, k_j, v_j, dk): # attention mechanism if dk is None: attn = (q_i * k_j).sum(dim=-1) else: # TODO: consider add or multiply dk attn = (q_i * k_j + dk).sum(dim=-1) # attention activation function attn = self.attn_activation(attn) / self.x_head_dim # update scalar features x = v_j * attn.unsqueeze(2) return x, attn def aggregate( self, features: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], index: torch.Tensor, ptr: Optional[torch.Tensor], dim_size: Optional[int], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: x, attn = features x = scatter(x, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr) return x, attn def update( self, inputs: Tuple[torch.Tensor, torch.Tensor] ) -> 
Tuple[torch.Tensor, torch.Tensor]: return inputs def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor: pass def edge_update(self) -> Tensor: pass # Transform sequence, structure, and relative position into a pair feature class PairFeatureNet(nn.Module): def __init__(self, c_s, c_p, relpos_k=32, template_type="exp-normal-smearing-distance"): super(PairFeatureNet, self).__init__() self.c_s = c_s self.c_p = c_p self.linear_s_p_i = nn.Linear(c_s, c_p) self.linear_s_p_j = nn.Linear(c_s, c_p) self.relpos_k = relpos_k self.n_bin = 2 * relpos_k + 1 self.linear_relpos = nn.Linear(self.n_bin, c_p) # TODO: implement structure to pairwise feature function self.template_fn, c_template = get_template_fn(template_type) self.linear_template = nn.Linear(c_template, c_p) def relpos(self, r): # AlphaFold 2 Algorithm 4 & 5 # Based on OpenFold utils/tensor_utils.py # Input: [b, n_res] # [b, n_res, n_res] d = r[:, :, None] - r[:, None, :] # [n_bin] v = torch.arange(-self.relpos_k, self.relpos_k + 1).to(r.device, non_blocking=True) # [1, 1, 1, n_bin] v_reshaped = v.view(*((1,) * len(d.shape) + (len(v),))) # [b, n_res, n_res] b = torch.argmin(torch.abs(d[:, :, :, None] - v_reshaped), dim=-1) # [b, n_res, n_res, n_bin] oh = nn.functional.one_hot(b, num_classes=len(v)).float() # [b, n_res, n_res, c_p] p = self.linear_relpos(oh) return p def template(self, t): return self.linear_template(self.template_fn(t)) def forward(self, s, t, r, mask): # Input: [b, n_res, c_s] p_mask = mask.unsqueeze(1) * mask.unsqueeze(2) # [b, n_res, c_p] p_i = self.linear_s_p_i(s) p_j = self.linear_s_p_j(s) # [b, n_res, n_res, c_p] p = p_i[:, :, None, :] + p_j[:, None, :, :] # [b, n_res, n_res, c_p] p += self.relpos(r) # upper bound is 64 AA p += self.template(t) # upper bound is 100 A # [b, n_res, n_res, c_p] p *= p_mask.unsqueeze(-1) return p # AF2's TriangularSelfAttentionBlock, but I removed the pairwise attention because of memory issues. # Genie does the same.
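# The pair representation p returned by PairFeatureNet ([b, n_res, n_res, c_p]) can be fed
# to the block below as pairwise_state, with c_p matching pairwise_state_dim. Illustrative
# wiring under assumed dimensions (a sketch, not code exercised in this file):
#   pair_net = PairFeatureNet(c_s=384, c_p=128)
#   block = TriangularSelfAttentionBlock(sequence_state_dim=384, pairwise_state_dim=128,
#                                        sequence_head_width=32, pairwise_head_width=32)
#   p = pair_net(s, t, r, mask)    # s: [b, n_res, 384], t: template input, r: residue indices
#   s, p = block(s, p, mask=mask)  # one trunk-style update of sequence and pair states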
class TriangularSelfAttentionBlock(nn.Module): def __init__( self, sequence_state_dim, pairwise_state_dim, sequence_head_width, pairwise_head_width, dropout=0, **__kwargs, ): super().__init__() from openfold.model.triangular_multiplicative_update import ( TriangleMultiplicationIncoming, TriangleMultiplicationOutgoing, ) from esm.esmfold.v1.misc import ( Attention, Dropout, PairToSequence, ResidueMLP, SequenceToPair, ) assert sequence_state_dim % sequence_head_width == 0 assert pairwise_state_dim % pairwise_head_width == 0 sequence_num_heads = sequence_state_dim // sequence_head_width pairwise_num_heads = pairwise_state_dim // pairwise_head_width assert sequence_state_dim == sequence_num_heads * sequence_head_width assert pairwise_state_dim == pairwise_num_heads * pairwise_head_width assert pairwise_state_dim % 2 == 0 self.sequence_state_dim = sequence_state_dim self.pairwise_state_dim = pairwise_state_dim self.layernorm_1 = nn.LayerNorm(sequence_state_dim) self.sequence_to_pair = SequenceToPair( sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim ) self.pair_to_sequence = PairToSequence( pairwise_state_dim, sequence_num_heads) self.seq_attention = Attention( sequence_state_dim, sequence_num_heads, sequence_head_width, gated=True ) self.tri_mul_out = TriangleMultiplicationOutgoing( pairwise_state_dim, pairwise_state_dim, ) self.tri_mul_in = TriangleMultiplicationIncoming( pairwise_state_dim, pairwise_state_dim, ) self.mlp_seq = ResidueMLP( sequence_state_dim, 4 * sequence_state_dim, dropout=dropout) self.mlp_pair = ResidueMLP( pairwise_state_dim, 4 * pairwise_state_dim, dropout=dropout) assert dropout < 0.4 self.drop = nn.Dropout(dropout) self.row_drop = Dropout(dropout * 2, 2) self.col_drop = Dropout(dropout * 2, 1) torch.nn.init.zeros_(self.tri_mul_in.linear_z.weight) torch.nn.init.zeros_(self.tri_mul_in.linear_z.bias) torch.nn.init.zeros_(self.tri_mul_out.linear_z.weight) torch.nn.init.zeros_(self.tri_mul_out.linear_z.bias) torch.nn.init.zeros_(self.sequence_to_pair.o_proj.weight) torch.nn.init.zeros_(self.sequence_to_pair.o_proj.bias) torch.nn.init.zeros_(self.pair_to_sequence.linear.weight) torch.nn.init.zeros_(self.seq_attention.o_proj.weight) torch.nn.init.zeros_(self.seq_attention.o_proj.bias) torch.nn.init.zeros_(self.mlp_seq.mlp[-2].weight) torch.nn.init.zeros_(self.mlp_seq.mlp[-2].bias) torch.nn.init.zeros_(self.mlp_pair.mlp[-2].weight) torch.nn.init.zeros_(self.mlp_pair.mlp[-2].bias) def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs): """ Inputs: sequence_state: B x L x sequence_state_dim pairwise_state: B x L x L x pairwise_state_dim mask: B x L boolean tensor of valid positions Output: sequence_state: B x L x sequence_state_dim pairwise_state: B x L x L x pairwise_state_dim """ assert len(sequence_state.shape) == 3 assert len(pairwise_state.shape) == 4 if mask is not None: assert len(mask.shape) == 2 batch_dim, seq_dim, sequence_state_dim = sequence_state.shape pairwise_state_dim = pairwise_state.shape[3] assert sequence_state_dim == self.sequence_state_dim assert pairwise_state_dim == self.pairwise_state_dim assert batch_dim == pairwise_state.shape[0] assert seq_dim == pairwise_state.shape[1] assert seq_dim == pairwise_state.shape[2] # Update sequence state bias = self.pair_to_sequence(pairwise_state) # Self attention with bias + mlp. 
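        # The pair_to_sequence projection above turns the pairwise state into a per-head
        # attention bias; it is added inside the gated self-attention below. The attended
        # sequence state then flows back into the pairwise state through the outer-product
        # style sequence_to_pair update, the row/column-dropped triangular multiplicative
        # updates, and finally the pairwise MLP.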
y = self.layernorm_1(sequence_state) y, _ = self.seq_attention(y, mask=mask, bias=bias) sequence_state = sequence_state + self.drop(y) sequence_state = self.mlp_seq(sequence_state) # Update pairwise state pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state) # Axial attention with triangular bias. tri_mask = mask.unsqueeze( 2) * mask.unsqueeze(1) if mask is not None else None pairwise_state = pairwise_state + self.row_drop( self.tri_mul_out(pairwise_state, mask=tri_mask) ) pairwise_state = pairwise_state + self.col_drop( self.tri_mul_in(pairwise_state, mask=tri_mask) ) # MLP over pairs. pairwise_state = self.mlp_pair(pairwise_state) return sequence_state, pairwise_state # A Self-Attention Pooling Block class SeqPairAttentionOutput(nn.Module): def __init__(self, seq_state_dim, pairwise_state_dim, num_heads, output_dim, dropout): super(SeqPairAttentionOutput, self).__init__() from esm.esmfold.v1.misc import ( Attention, PairToSequence, ResidueMLP, ) self.seq_state_dim = seq_state_dim self.pairwise_state_dim = pairwise_state_dim self.output_dim = output_dim seq_head_width = seq_state_dim // num_heads self.layernorm = nn.LayerNorm(seq_state_dim) self.seq_attention = Attention( seq_state_dim, num_heads, seq_head_width, gated=True ) self.pair_to_sequence = PairToSequence(pairwise_state_dim, num_heads) self.mlp_seq = ResidueMLP( seq_state_dim, 4 * seq_state_dim, dropout=dropout) self.drop = nn.Dropout(dropout) def forward(self, sequence_state, pairwise_state, mask=None): # Update sequence state bias = self.pair_to_sequence(pairwise_state) # Self attention with bias + mlp. y = self.layernorm(sequence_state) y, _ = self.seq_attention(y, mask=mask, bias=bias) sequence_state = sequence_state + self.drop(y) sequence_state = self.mlp_seq(sequence_state) return sequence_state
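

# Usage sketch for SeqPairAttentionOutput (illustrative only; the shapes below are
# assumptions, and output_dim is stored by __init__ but not used in forward).
def _seq_pair_attention_output_example():
    pool = SeqPairAttentionOutput(seq_state_dim=384, pairwise_state_dim=128,
                                  num_heads=12, output_dim=384, dropout=0.1)
    seq = torch.randn(2, 100, 384)          # B x L x seq_state_dim
    pair = torch.randn(2, 100, 100, 128)    # B x L x L x pairwise_state_dim
    mask = torch.ones(2, 100)               # B x L
    return pool(seq, pair, mask=mask)       # B x L x seq_state_dim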