""" |
|
PLEASE NOTE THAT THIS IMPLEMENTATION INCLUDES SOURCE CODE ADAPTED FROM

THE PUBLICLY AVAILABLE MHG IMPLEMENTATION BY HIROSHI KAJINO AT IBM TRL,

E.G., GRUEncoder/GRUDecoder, GrammarSeq2SeqVAE, AND SOME METHODS OF GrammarGINVAE.

THIS MAY AFFECT THE CHOICE OF THE FINAL LICENSE, SO A CAREFUL CHECK NEEDS TO BE DONE.
|
""" |
|
|
|
import numpy as np |
|
import logging |
|
|
|
import torch |
|
from torch.autograd import Variable |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.nn.modules.loss import _Loss |
|
|
|
from torch_geometric.nn import MessagePassing |
|
from torch_geometric.nn import global_add_pool |
|
|
|
|
|
from ..graph_grammar.graph_grammar.symbols import NTSymbol |
|
from ..graph_grammar.nn.encoder import EncoderBase |
|
from ..graph_grammar.nn.decoder import DecoderBase |
|
|
|
def get_atom_edge_feature_dims(): |
|
from torch_geometric.utils.smiles import x_map, e_map |
|
    return [len(v) for v in x_map.values()], [len(v) for v in e_map.values()]
|
|
|
|
|
class FeatureEmbedding(nn.Module): |
|
def __init__(self, input_dims, embedded_dim): |
|
super().__init__() |
|
self.embedding_list = nn.ModuleList() |
|
for dim in input_dims: |
|
embedding = nn.Embedding(dim, embedded_dim) |
|
self.embedding_list.append(embedding) |
|
|
|
def forward(self, x): |
|
output = 0 |
|
        device = next(self.parameters()).device

        for i in range(x.shape[1]):

            indices = x[:, i].to(torch.int)

            if indices.device != device:

                indices = indices.to(device)

            output = output + self.embedding_list[i](indices)
|
return output |
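
# A usage sketch for FeatureEmbedding (hypothetical values, not from the original code):
#
#   emb = FeatureEmbedding([10, 5], 8)         # two categorical features with 10 and 5 levels
#   out = emb(torch.tensor([[1, 3], [2, 4]]))  # shape (2, 8); per-feature embeddings are summed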
|
|
|
|
|
class GRUEncoder(EncoderBase): |
|
|
|
def __init__(self, input_dim: int, hidden_dim: int, num_layers: int, |
|
bidirectional: bool, dropout: float, batch_size: int, rank: int=-1, |
|
no_dropout: bool=False): |
|
super().__init__() |
|
self.input_dim = input_dim |
|
self.hidden_dim = hidden_dim |
|
self.num_layers = num_layers |
|
self.bidirectional = bidirectional |
|
self.dropout = dropout |
|
self.batch_size = batch_size |
|
self.rank = rank |
|
self.model = nn.GRU(input_size=self.input_dim, |
|
hidden_size=self.hidden_dim, |
|
num_layers=self.num_layers, |
|
batch_first=True, |
|
bidirectional=self.bidirectional, |
|
dropout=self.dropout if not no_dropout else 0) |
|
if self.rank >= 0: |
|
if torch.cuda.is_available(): |
|
self.model = self.model.to(rank) |
|
else: |
|
|
|
self.model = self.model.to(torch.device("mps", rank)) |
|
self.init_hidden(self.batch_size) |
|
|
|
def init_hidden(self, bsize): |
|
self.h0 = torch.zeros(((self.bidirectional + 1) * self.num_layers, |
|
min(self.batch_size, bsize), |
|
self.hidden_dim), |
|
requires_grad=False) |
|
if self.rank >= 0: |
|
if torch.cuda.is_available(): |
|
self.h0 = self.h0.to(self.rank) |
|
else: |
|
|
|
self.h0 = self.h0.to(torch.device("mps", self.rank)) |
|
|
|
def to(self, device): |
|
newself = super().to(device) |
|
newself.model = newself.model.to(device) |
|
newself.h0 = newself.h0.to(device) |
|
newself.rank = next(newself.parameters()).get_device() |
|
return newself |
|
|
|
def forward(self, in_seq_emb): |
|
''' forward model |
|
|
|
Parameters |
|
---------- |
|
in_seq_emb : Tensor, shape (batch_size, max_len, input_dim) |
|
|
|
Returns |
|
------- |
|
hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim) |
|
''' |
|
|
|
self.init_hidden(in_seq_emb.size(0)) |
|
max_len = in_seq_emb.size(1) |
|
hidden_seq_emb, self.h0 = self.model( |
|
in_seq_emb, self.h0) |
|
|
|
|
|
|
|
|
|
|
|
|
|
hidden_seq_emb = hidden_seq_emb.view(-1, |
|
max_len, |
|
1 + self.bidirectional, |
|
self.hidden_dim) |
|
return hidden_seq_emb |
|
|
|
|
|
class GRUDecoder(DecoderBase): |
|
|
|
def __init__(self, input_dim: int, hidden_dim: int, num_layers: int, |
|
dropout: float, batch_size: int, rank: int=-1, |
|
no_dropout: bool=False): |
|
super().__init__() |
|
self.input_dim = input_dim |
|
self.hidden_dim = hidden_dim |
|
self.num_layers = num_layers |
|
self.dropout = dropout |
|
self.batch_size = batch_size |
|
self.rank = rank |
|
self.model = nn.GRU(input_size=self.input_dim, |
|
hidden_size=self.hidden_dim, |
|
num_layers=self.num_layers, |
|
batch_first=True, |
|
bidirectional=False, |
|
dropout=self.dropout if not no_dropout else 0 |
|
) |
|
if self.rank >= 0: |
|
if torch.cuda.is_available(): |
|
self.model = self.model.to(self.rank) |
|
else: |
|
|
|
self.model = self.model.to(torch.device("mps", self.rank)) |
|
self.init_hidden(self.batch_size) |
|
|
|
def init_hidden(self, bsize): |
|
self.hidden_dict['h'] = torch.zeros((self.num_layers, |
|
min(self.batch_size, bsize), |
|
self.hidden_dim), |
|
requires_grad=False) |
|
if self.rank >= 0: |
|
if torch.cuda.is_available(): |
|
self.hidden_dict['h'] = self.hidden_dict['h'].to(self.rank) |
|
else: |
|
self.hidden_dict['h'] = self.hidden_dict['h'].to(torch.device("mps", self.rank)) |
|
|
|
def to(self, device): |
|
newself = super().to(device) |
|
newself.model = newself.model.to(device) |
|
for k in self.hidden_dict.keys(): |
|
newself.hidden_dict[k] = newself.hidden_dict[k].to(device) |
|
newself.rank = next(newself.parameters()).get_device() |
|
return newself |
|
|
|
def forward_one_step(self, tgt_emb_in): |
|
''' one-step forward model |
|
|
|
Parameters |
|
---------- |
|
tgt_emb_in : Tensor, shape (batch_size, input_dim) |
|
|
|
Returns |
|
------- |
|
        Tensor, shape (batch_size, 1, hidden_dim)
|
''' |
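        # The GRU hidden state is carried in self.hidden_dict['h'] between calls, so
        # successive calls decode a sequence one step at a time; init_hidden() resets it
        # and feed_hidden() seeds it from the latent code.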
|
bsize = tgt_emb_in.size(0) |
|
tgt_emb_out, self.hidden_dict['h'] \ |
|
= self.model(tgt_emb_in.view(bsize, 1, -1), |
|
self.hidden_dict['h']) |
|
return tgt_emb_out |
|
|
|
|
|
class NodeMLP(nn.Module): |
|
def __init__(self, input_size, output_size, hidden_size): |
|
super().__init__() |
|
self.lin1 = nn.Linear(input_size, hidden_size) |
|
self.nbat = nn.BatchNorm1d(hidden_size) |
|
self.lin2 = nn.Linear(hidden_size, output_size) |
|
|
|
def forward(self, x): |
|
x = self.lin1(x) |
|
x = self.nbat(x) |
|
x = x.relu() |
|
x = self.lin2(x) |
|
return x |
|
|
|
|
|
class GINLayer(MessagePassing): |
|
def __init__(self, node_input_size, node_output_size, node_hidden_size, edge_input_size): |
|
super().__init__() |
|
self.node_mlp = NodeMLP(node_input_size, node_output_size, node_hidden_size) |
|
self.edge_mlp = FeatureEmbedding(edge_input_size, node_output_size) |
|
self.eps = nn.Parameter(torch.tensor([0.0])) |
|
|
|
def forward(self, x, edge_index, edge_attr): |
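        # GIN-style update with edge features (similar to a GINE convolution): messages
        # ReLU(x_j + edge_embedding) are summed over neighbours, (1 + eps) * x_i is added,
        # and the result is passed through ReLU and a two-layer MLP.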
|
msg = self.propagate(edge_index, x=x ,edge_attr=edge_attr) |
|
x = (1.0 + self.eps) * x + msg |
|
x = x.relu() |
|
x = self.node_mlp(x) |
|
return x |
|
|
|
def message(self, x_j, edge_attr): |
|
edge_attr = self.edge_mlp(edge_attr) |
|
x_j = x_j + edge_attr |
|
x_j = x_j.relu() |
|
return x_j |
|
|
|
def update(self, aggr_out): |
|
return aggr_out |
|
|
|
|
|
|
|
class GIN(torch.nn.Module): |
|
def __init__(self, node_feature_size, edge_feature_size, hidden_channels=64, |
|
proximity_size=3, dropout=0.1): |
|
super().__init__() |
|
|
|
hsize = hidden_channels * 2 |
|
atom_dim, edge_dim = get_atom_edge_feature_dims() |
|
self.trans = FeatureEmbedding(atom_dim, hidden_channels) |
|
ml = [] |
|
for _ in range(proximity_size): |
|
ml.append(GINLayer(hidden_channels, hidden_channels, hsize, edge_dim)) |
|
self.mlist = nn.ModuleList(ml) |
|
|
|
|
|
self.dropout = dropout |
|
self.proximity_size = proximity_size |
|
|
|
def forward(self, x, edge_index, edge_attr, batch_size): |
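        # Note: despite its name, `batch_size` is the per-node graph-assignment vector
        # (as produced by torch_geometric's DataLoader) expected by global_add_pool.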
|
x = x.to(torch.float) |
|
|
|
edge_attr = edge_attr.to(torch.float) |
|
|
|
x = self.trans(x) |
|
|
|
hlist = [global_add_pool(x, batch_size)] |
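        # hlist collects a graph-level readout after the input embedding and after each
        # GIN layer; concatenating them below (jumping-knowledge style) gives a graph
        # representation of size hidden_channels * (1 + proximity_size).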
|
        for m in self.mlist:
|
x = m(x, edge_index=edge_index, edge_attr=edge_attr) |
|
|
|
|
|
x = x.relu() |
|
x = F.dropout(x, p=self.dropout, training=self.training) |
|
|
|
h = global_add_pool(x, batch_size) |
|
hlist.append(h) |
|
|
|
|
|
|
|
|
|
|
|
x = torch.cat(hlist, dim=1) |
|
|
|
x = F.dropout(x, p=self.dropout, training=self.training) |
|
|
|
return x |
|
|
|
|
|
|
|
class GrammarSeq2SeqVAE(nn.Module): |
|
|
|
''' |
|
Variational seq2seq with grammar. |
|
TODO: rewrite this class using mixin |
|
''' |
|
|
|
def __init__(self, hrg, rank=-1, latent_dim=64, max_len=80, |
|
batch_size=64, padding_idx=-1, |
|
encoder_params={'hidden_dim': 384, 'num_layers': 3, 'bidirectional': True, |
|
'dropout': 0.1}, |
|
decoder_params={'hidden_dim': 384, |
|
'num_layers': 3, |
|
'dropout': 0.1}, |
|
prod_rule_embed_params={'out_dim': 128}, |
|
no_dropout=False): |
|
|
|
super().__init__() |
|
|
|
self.hrg = hrg |
|
self.rank = rank |
|
self.prod_rule_corpus = hrg.prod_rule_corpus |
|
self.prod_rule_embed_params = prod_rule_embed_params |
|
|
|
self.vocab_size = hrg.num_prod_rule + 1 |
|
self.batch_size = batch_size |
|
self.padding_idx = np.mod(padding_idx, self.vocab_size) |
|
self.no_dropout = no_dropout |
|
|
|
self.latent_dim = latent_dim |
|
self.max_len = max_len |
|
self.encoder_params = encoder_params |
|
self.decoder_params = decoder_params |
|
|
|
|
|
embed_out_dim = self.prod_rule_embed_params['out_dim'] |
|
|
|
self.src_embedding = nn.Embedding(self.vocab_size, embed_out_dim, |
|
padding_idx=self.padding_idx) |
|
self.tgt_embedding = nn.Embedding(self.vocab_size, embed_out_dim, |
|
padding_idx=self.padding_idx) |
|
|
|
|
|
self.encoder = GRUEncoder(input_dim=embed_out_dim, batch_size=self.batch_size, |
|
rank=self.rank, no_dropout=self.no_dropout, |
|
**self.encoder_params) |
|
|
|
lin_dim = (self.encoder_params.get('bidirectional', False) + 1) * self.encoder_params['hidden_dim'] |
|
lin_out_dim = self.latent_dim |
|
self.hidden2mean = nn.Linear(lin_dim, lin_out_dim, bias=False) |
|
self.hidden2logvar = nn.Linear(lin_dim, lin_out_dim) |
|
|
|
|
|
self.decoder = GRUDecoder(input_dim=embed_out_dim, batch_size=self.batch_size, |
|
rank=self.rank, no_dropout=self.no_dropout, **self.decoder_params) |
|
self.latent2tgt_emb = nn.Linear(self.latent_dim, embed_out_dim) |
|
self.latent2hidden_dict = nn.ModuleDict() |
|
dec_lin_out_dim = self.decoder_params['hidden_dim'] |
|
for each_hidden in self.decoder.hidden_dict.keys(): |
|
self.latent2hidden_dict[each_hidden] = nn.Linear(self.latent_dim, dec_lin_out_dim) |
|
if self.rank >= 0: |
|
if torch.cuda.is_available(): |
|
self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(self.rank) |
|
else: |
|
|
|
self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(torch.device("mps", self.rank)) |
|
|
|
self.dec2vocab = nn.Linear(dec_lin_out_dim, self.vocab_size) |
|
self.encoder.init_hidden(self.batch_size) |
|
self.decoder.init_hidden(self.batch_size) |
|
|
|
|
|
if hasattr(self.src_embedding, 'weight'): |
|
self.src_embedding.weight.data.uniform_(-0.1, 0.1) |
|
if hasattr(self.tgt_embedding, 'weight'): |
|
self.tgt_embedding.weight.data.uniform_(-0.1, 0.1) |
|
|
|
self.encoder.init_hidden(self.batch_size) |
|
self.decoder.init_hidden(self.batch_size) |
|
|
|
def to(self, device): |
|
newself = super().to(device) |
|
newself.src_embedding = newself.src_embedding.to(device) |
|
newself.tgt_embedding = newself.tgt_embedding.to(device) |
|
newself.encoder = newself.encoder.to(device) |
|
newself.decoder = newself.decoder.to(device) |
|
newself.dec2vocab = newself.dec2vocab.to(device) |
|
newself.hidden2mean = newself.hidden2mean.to(device) |
|
newself.hidden2logvar = newself.hidden2logvar.to(device) |
|
newself.latent2tgt_emb = newself.latent2tgt_emb.to(device) |
|
newself.latent2hidden_dict = newself.latent2hidden_dict.to(device) |
|
return newself |
|
|
|
def forward(self, in_seq, out_seq): |
|
''' forward model |
|
|
|
Parameters |
|
---------- |
|
in_seq : Variable, shape (batch_size, length) |
|
each element corresponds to word index. |
|
where the index should be less than `vocab_size` |
|
|
|
Returns |
|
------- |
|
        logit : Tensor, shape (batch_size, length, vocab_size)

            logit of each production rule (applying softmax yields the probability)

        mu, logvar : Tensor, shape (batch_size, latent_dim)
|
''' |
|
mu, logvar = self.encode(in_seq) |
|
z = self.reparameterize(mu, logvar) |
|
return self.decode(z, out_seq), mu, logvar |
|
|
|
def encode(self, in_seq): |
|
src_emb = self.src_embedding(in_seq) |
|
src_h = self.encoder.forward(src_emb) |
|
if self.encoder_params.get('bidirectional', False): |
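            # For a bidirectional encoder, concatenate the forward direction's last hidden
            # state with the backward direction's hidden state at the first position.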
|
concat_src_h = torch.cat((src_h[:, -1, 0, :], src_h[:, 0, 1, :]), dim=1) |
|
return self.hidden2mean(concat_src_h), self.hidden2logvar(concat_src_h) |
|
else: |
|
return self.hidden2mean(src_h[:, -1, :]), self.hidden2logvar(src_h[:, -1, :]) |
|
|
|
def reparameterize(self, mu, logvar, training=True): |
|
if training: |
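            # Reparameterisation trick: z = mu + std * eps with std = exp(0.5 * logvar)
            # and eps ~ N(0, I), keeping the sampling step differentiable w.r.t. mu and logvar.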
|
std = logvar.mul(0.5).exp_() |
|
device = next(self.parameters()).device |
|
eps = Variable(std.data.new(std.size()).normal_()) |
|
            if eps.device != device:

                eps = eps.to(device)
|
return eps.mul(std).add_(mu) |
|
else: |
|
return mu |
|
|
|
|
|
def sample(self, sample_size=-1, deterministic=True, return_z=False): |
|
self.eval() |
|
        self.encoder.init_hidden(self.batch_size)

        self.decoder.init_hidden(self.batch_size)
|
if sample_size == -1: |
|
sample_size = self.batch_size |
|
|
|
num_iter = int(np.ceil(sample_size / self.batch_size)) |
|
hg_list = [] |
|
z_list = [] |
|
for _ in range(num_iter): |
|
            z = Variable(torch.normal(

                torch.zeros(self.batch_size, self.latent_dim),

                torch.ones(self.batch_size, self.latent_dim))).to(next(self.parameters()).device)

            _, _, each_hg_list = self.decode(z, deterministic=deterministic)
|
z_list.append(z) |
|
hg_list += each_hg_list |
|
z = torch.cat(z_list)[:sample_size] |
|
hg_list = hg_list[:sample_size] |
|
if return_z: |
|
return hg_list, z.cpu().detach().numpy() |
|
else: |
|
return hg_list |
|
|
|
def decode(self, z=None, out_seq=None, deterministic=True): |
|
if z is None: |
|
z = Variable(torch.normal( |
|
torch.zeros(self.batch_size, self.latent_dim), |
|
                torch.ones(self.batch_size, self.latent_dim)))
|
if self.rank >= 0: |
|
z = z.to(next(self.parameters()).device) |
|
|
|
hidden_dict_0 = {} |
|
for each_hidden in self.latent2hidden_dict.keys(): |
|
hidden_dict_0[each_hidden] = self.latent2hidden_dict[each_hidden](z) |
|
bsize = z.size(0) |
|
self.decoder.init_hidden(bsize) |
|
self.decoder.feed_hidden(hidden_dict_0) |
|
|
|
if out_seq is not None: |
|
tgt_emb0 = self.latent2tgt_emb(z) |
|
tgt_emb0 = tgt_emb0.view(tgt_emb0.shape[0], 1, tgt_emb0.shape[1]) |
|
out_seq_emb = self.tgt_embedding(out_seq) |
|
tgt_emb = torch.cat((tgt_emb0, out_seq_emb), dim=1)[:, :-1, :] |
|
tgt_emb_pred_list = [] |
|
for each_idx in range(self.max_len): |
|
tgt_emb_pred = self.decoder.forward_one_step(tgt_emb[:, each_idx, :].view(bsize, 1, -1)) |
|
tgt_emb_pred_list.append(tgt_emb_pred) |
|
vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1)) |
|
return vocab_logit |
|
else: |
|
with torch.no_grad(): |
|
tgt_emb = self.latent2tgt_emb(z) |
|
tgt_emb = tgt_emb.view(tgt_emb.shape[0], 1, tgt_emb.shape[1]) |
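                # Grammar-constrained decoding: each sample keeps a stack of nonterminal
                # edges; at every step a production rule compatible with the current
                # nonterminal is sampled and applied to the hypergraph, and a sample is
                # finished once its stack becomes empty.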
|
tgt_emb_pred_list = [] |
|
stack_list = [] |
|
hg_list = [] |
|
nt_symbol_list = [] |
|
nt_edge_list = [] |
|
gen_finish_list = [] |
|
for _ in range(bsize): |
|
stack_list.append([]) |
|
hg_list.append(None) |
|
nt_symbol_list.append(NTSymbol(degree=0, |
|
is_aromatic=False, |
|
bond_symbol_list=[])) |
|
nt_edge_list.append(None) |
|
gen_finish_list.append(False) |
|
|
|
                for _ in range(self.max_len):
|
tgt_emb_pred = self.decoder.forward_one_step(tgt_emb) |
|
tgt_emb_pred_list.append(tgt_emb_pred) |
|
vocab_logit = self.dec2vocab(tgt_emb_pred) |
|
for each_batch_idx in range(bsize): |
|
if not gen_finish_list[each_batch_idx]: |
|
|
|
prod_rule = self.hrg.prod_rule_corpus.sample(vocab_logit[each_batch_idx, :, :-1].squeeze().cpu().numpy(), |
|
nt_symbol_list[each_batch_idx], |
|
deterministic=deterministic) |
|
|
|
tgt_id = self.hrg.prod_rule_list.index(prod_rule) |
|
|
|
hg_list[each_batch_idx], nt_edges = prod_rule.applied_to(hg_list[each_batch_idx], nt_edge_list[each_batch_idx]) |
|
|
|
stack_list[each_batch_idx].extend(nt_edges[::-1]) |
|
|
|
if len(stack_list[each_batch_idx]) == 0: |
|
gen_finish_list[each_batch_idx] = True |
|
else: |
|
nt_edge_list[each_batch_idx] = stack_list[each_batch_idx].pop() |
|
nt_symbol_list[each_batch_idx] = hg_list[each_batch_idx].edge_attr(nt_edge_list[each_batch_idx])['symbol'] |
|
else: |
|
tgt_id = np.mod(self.padding_idx, self.vocab_size) |
|
indice_tensor = torch.LongTensor([tgt_id]) |
|
device = next(self.parameters()).device |
|
if indice_tensor.device != device: |
|
indice_tensor = indice_tensor.to(device) |
|
tgt_emb[each_batch_idx, :] = self.tgt_embedding(indice_tensor) |
|
vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1)) |
|
|
|
|
|
|
|
return gen_finish_list, vocab_logit, hg_list |
|
|
|
|
|
|
|
class GrammarGINVAE(nn.Module): |
|
|
|
''' |
|
Variational autoencoder based on GIN and grammar |
|
''' |
|
|
|
def __init__(self, hrg, rank=-1, max_len=80, |
|
batch_size=64, padding_idx=-1, |
|
encoder_params={'node_feature_size': 4, 'edge_feature_size': 3, |
|
'hidden_channels': 64, 'proximity_size': 3, |
|
'dropout': 0.1}, |
|
decoder_params={'hidden_dim': 384, 'num_layers': 3, |
|
'dropout': 0.1}, |
|
prod_rule_embed_params={'out_dim': 128}, |
|
no_dropout=False): |
|
|
|
super().__init__() |
|
|
|
self.hrg = hrg |
|
self.rank = rank |
|
self.prod_rule_corpus = hrg.prod_rule_corpus |
|
self.prod_rule_embed_params = prod_rule_embed_params |
|
|
|
self.vocab_size = hrg.num_prod_rule + 1 |
|
self.batch_size = batch_size |
|
self.padding_idx = np.mod(padding_idx, self.vocab_size) |
|
self.no_dropout = no_dropout |
|
self.max_len = max_len |
|
self.encoder_params = encoder_params |
|
self.decoder_params = decoder_params |
|
|
|
|
|
embed_out_dim = self.prod_rule_embed_params['out_dim'] |
|
|
|
self.tgt_embedding = nn.Embedding(self.vocab_size, embed_out_dim, |
|
padding_idx=self.padding_idx) |
|
|
|
self.encoder = GIN(**self.encoder_params) |
|
self.latent_dim = self.encoder_params['hidden_channels'] |
|
self.proximity_size = self.encoder_params['proximity_size'] |
|
hidden_dim = self.decoder_params['hidden_dim'] |
|
self.hidden2mean = nn.Linear(self.latent_dim * (1 + self.proximity_size), self.latent_dim, bias=False) |
|
self.hidden2logvar = nn.Linear(self.latent_dim * (1 + self.proximity_size), self.latent_dim) |
|
|
|
self.decoder = GRUDecoder(input_dim=embed_out_dim, batch_size=self.batch_size, |
|
rank=self.rank, no_dropout=self.no_dropout, **self.decoder_params) |
|
self.latent2tgt_emb = nn.Linear(self.latent_dim, embed_out_dim) |
|
self.latent2hidden_dict = nn.ModuleDict() |
|
for each_hidden in self.decoder.hidden_dict.keys(): |
|
self.latent2hidden_dict[each_hidden] = nn.Linear(self.latent_dim, hidden_dim) |
|
if self.rank >= 0: |
|
if torch.cuda.is_available(): |
|
self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(self.rank) |
|
else: |
|
|
|
self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(torch.device("mps", self.rank)) |
|
|
|
self.dec2vocab = nn.Linear(hidden_dim, self.vocab_size) |
|
self.decoder.init_hidden(self.batch_size) |
|
|
|
|
|
if hasattr(self.tgt_embedding, 'weight'): |
|
self.tgt_embedding.weight.data.uniform_(-0.1, 0.1) |
|
self.decoder.init_hidden(self.batch_size) |
|
|
|
def to(self, device): |
|
newself = super().to(device) |
|
newself.encoder = newself.encoder.to(device) |
|
newself.decoder = newself.decoder.to(device) |
|
newself.rank = next(newself.encoder.parameters()).get_device() |
|
return newself |
|
|
|
    def forward(self, x, edge_index, edge_attr, batch_size, out_seq=None, sched_prob=None):
|
mu, logvar = self.encode(x, edge_index, edge_attr, batch_size) |
|
z = self.reparameterize(mu, logvar) |
|
return self.decode(z, out_seq, sched_prob=sched_prob), mu, logvar |
|
|
|
|
|
def sample(self, sample_size=-1, deterministic=True, return_z=False): |
|
self.eval() |
|
        self.decoder.init_hidden(self.batch_size)
|
if sample_size == -1: |
|
sample_size = self.batch_size |
|
|
|
num_iter = int(np.ceil(sample_size / self.batch_size)) |
|
hg_list = [] |
|
z_list = [] |
|
for _ in range(num_iter): |
|
            z = Variable(torch.normal(

                torch.zeros(self.batch_size, self.latent_dim),

                torch.ones(self.batch_size, self.latent_dim))).to(next(self.parameters()).device)

            _, _, each_hg_list = self.decode(z, deterministic=deterministic)
|
z_list.append(z) |
|
hg_list += each_hg_list |
|
z = torch.cat(z_list)[:sample_size] |
|
hg_list = hg_list[:sample_size] |
|
if return_z: |
|
return hg_list, z.cpu().detach().numpy() |
|
else: |
|
return hg_list |
|
|
|
def decode(self, z=None, out_seq=None, deterministic=True, sched_prob=None): |
|
if z is None: |
|
z = Variable(torch.normal( |
|
torch.zeros(self.batch_size, self.latent_dim), |
|
                torch.ones(self.batch_size, self.latent_dim)))
|
if self.rank >= 0: |
|
z = z.to(next(self.parameters()).device) |
|
|
|
hidden_dict_0 = {} |
|
for each_hidden in self.latent2hidden_dict.keys(): |
|
hidden_dict_0[each_hidden] = self.latent2hidden_dict[each_hidden](z) |
|
bsize = z.size(0) |
|
self.decoder.init_hidden(bsize) |
|
self.decoder.feed_hidden(hidden_dict_0) |
|
|
|
if out_seq is not None: |
|
tgt_emb0 = self.latent2tgt_emb(z) |
|
tgt_emb0 = tgt_emb0.view(tgt_emb0.shape[0], 1, tgt_emb0.shape[1]) |
|
out_seq_emb = self.tgt_embedding(out_seq) |
|
tgt_emb = torch.cat((tgt_emb0, out_seq_emb), dim=1)[:, :-1, :] |
|
tgt_emb_pred_list = [] |
|
tgt_emb_pred = None |
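            # Scheduled sampling: with probability sched_prob (or always, when sched_prob
            # is None) the ground-truth embedding is fed to the decoder; otherwise the
            # decoder's own greedy prediction from the previous step is fed back in.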
|
for each_idx in range(self.max_len): |
|
if tgt_emb_pred is None or sched_prob is None or torch.rand(1)[0] <= sched_prob: |
|
inp = tgt_emb[:, each_idx, :].view(bsize, 1, -1) |
|
else: |
|
cur_logit = self.dec2vocab(tgt_emb_pred) |
|
yi = torch.argmax(cur_logit, dim=2) |
|
inp = self.tgt_embedding(yi) |
|
tgt_emb_pred = self.decoder.forward_one_step(inp) |
|
tgt_emb_pred_list.append(tgt_emb_pred) |
|
vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1)) |
|
return vocab_logit |
|
else: |
|
with torch.no_grad(): |
|
tgt_emb = self.latent2tgt_emb(z) |
|
tgt_emb = tgt_emb.view(tgt_emb.shape[0], 1, tgt_emb.shape[1]) |
|
tgt_emb_pred_list = [] |
|
stack_list = [] |
|
hg_list = [] |
|
nt_symbol_list = [] |
|
nt_edge_list = [] |
|
gen_finish_list = [] |
|
for _ in range(bsize): |
|
stack_list.append([]) |
|
hg_list.append(None) |
|
nt_symbol_list.append(NTSymbol(degree=0, |
|
is_aromatic=False, |
|
bond_symbol_list=[])) |
|
nt_edge_list.append(None) |
|
gen_finish_list.append(False) |
|
|
|
for _ in range(self.max_len): |
|
tgt_emb_pred = self.decoder.forward_one_step(tgt_emb) |
|
tgt_emb_pred_list.append(tgt_emb_pred) |
|
vocab_logit = self.dec2vocab(tgt_emb_pred) |
|
for each_batch_idx in range(bsize): |
|
if not gen_finish_list[each_batch_idx]: |
|
|
|
prod_rule = self.hrg.prod_rule_corpus.sample(vocab_logit[each_batch_idx, :, :-1].squeeze().cpu().numpy(), |
|
nt_symbol_list[each_batch_idx], |
|
deterministic=deterministic) |
|
|
|
tgt_id = self.hrg.prod_rule_list.index(prod_rule) |
|
|
|
hg_list[each_batch_idx], nt_edges = prod_rule.applied_to(hg_list[each_batch_idx], nt_edge_list[each_batch_idx]) |
|
|
|
stack_list[each_batch_idx].extend(nt_edges[::-1]) |
|
|
|
if len(stack_list[each_batch_idx]) == 0: |
|
gen_finish_list[each_batch_idx] = True |
|
else: |
|
nt_edge_list[each_batch_idx] = stack_list[each_batch_idx].pop() |
|
nt_symbol_list[each_batch_idx] = hg_list[each_batch_idx].edge_attr(nt_edge_list[each_batch_idx])['symbol'] |
|
else: |
|
tgt_id = np.mod(self.padding_idx, self.vocab_size) |
|
indice_tensor = torch.LongTensor([tgt_id]) |
|
if self.rank >= 0: |
|
indice_tensor = indice_tensor.to(next(self.parameters()).device) |
|
tgt_emb[each_batch_idx, :] = self.tgt_embedding(indice_tensor) |
|
vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1)) |
|
return gen_finish_list, vocab_logit, hg_list |
|
|
|
|
|
def conditional_distribution(self, z, tgt_id_list): |
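        # Replays each partial production-rule sequence through the decoder and the
        # hypergraph construction, then returns, per sample, the masked log-probabilities
        # over the next applicable rule (None for samples that have already finished).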
|
self.eval() |
|
        self.decoder.init_hidden(self.batch_size)

        z = z.to(next(self.parameters()).device)
|
|
|
hidden_dict_0 = {} |
|
for each_hidden in self.latent2hidden_dict.keys(): |
|
hidden_dict_0[each_hidden] = self.latent2hidden_dict[each_hidden](z) |
|
self.decoder.feed_hidden(hidden_dict_0) |
|
|
|
with torch.no_grad(): |
|
tgt_emb = self.latent2tgt_emb(z) |
|
tgt_emb = tgt_emb.view(tgt_emb.shape[0], 1, tgt_emb.shape[1]) |
|
nt_symbol_list = [] |
|
stack_list = [] |
|
hg_list = [] |
|
nt_edge_list = [] |
|
gen_finish_list = [] |
|
for _ in range(self.batch_size): |
|
nt_symbol_list.append(NTSymbol(degree=0, |
|
is_aromatic=False, |
|
bond_symbol_list=[])) |
|
stack_list.append([]) |
|
hg_list.append(None) |
|
nt_edge_list.append(None) |
|
gen_finish_list.append(False) |
|
|
|
for each_position in range(len(tgt_id_list[0])): |
|
tgt_emb_pred = self.decoder.forward_one_step(tgt_emb) |
|
for each_batch_idx in range(self.batch_size): |
|
if not gen_finish_list[each_batch_idx]: |
|
|
|
tgt_id = tgt_id_list[each_batch_idx][each_position] |
|
prod_rule = self.hrg.prod_rule_list[tgt_id] |
|
|
|
hg_list[each_batch_idx], nt_edges = prod_rule.applied_to(hg_list[each_batch_idx], nt_edge_list[each_batch_idx]) |
|
|
|
stack_list[each_batch_idx].extend(nt_edges[::-1]) |
|
|
|
if len(stack_list[each_batch_idx]) == 0: |
|
gen_finish_list[each_batch_idx] = True |
|
else: |
|
nt_edge_list[each_batch_idx] = stack_list[each_batch_idx].pop() |
|
nt_symbol_list[each_batch_idx] = hg_list[each_batch_idx].edge_attr(nt_edge_list[each_batch_idx])['symbol'] |
|
else: |
|
tgt_id = np.mod(self.padding_idx, self.vocab_size) |
|
indice_tensor = torch.LongTensor([tgt_id]) |
|
                        indice_tensor = indice_tensor.to(next(self.parameters()).device)
|
tgt_emb[each_batch_idx, :] = self.tgt_embedding(indice_tensor) |
|
|
|
|
|
conditional_logprob_list = [] |
|
tgt_emb_pred = self.decoder.forward_one_step(tgt_emb) |
|
vocab_logit = self.dec2vocab(tgt_emb_pred) |
|
for each_batch_idx in range(self.batch_size): |
|
if not gen_finish_list[each_batch_idx]: |
|
|
|
masked_logprob = self.hrg.prod_rule_corpus.masked_logprob( |
|
vocab_logit[each_batch_idx, :, :-1].squeeze().cpu().numpy(), |
|
nt_symbol_list[each_batch_idx]) |
|
conditional_logprob_list.append(masked_logprob) |
|
else: |
|
conditional_logprob_list.append(None) |
|
return conditional_logprob_list |
|
|
|
|
|
def decode_with_beam_search(self, z, beam_width=1): |
|
''' Decode a latent vector using beam search. |
|
|
|
Parameters |
|
---------- |
|
z |
|
latent vector |
|
beam_width : int |
|
parameter for beam search |
|
|
|
Returns |
|
------- |
|
List of Hypergraphs |
|
''' |
|
if self.batch_size != 1: |
|
raise ValueError('this method works only under batch_size=1') |
|
        if self.padding_idx != self.vocab_size - 1:
|
raise ValueError('this method works only under padding_idx=-1') |
|
        top_k_tgt_id_list = [[] for _ in range(beam_width)]
|
logprob_list = [0.] * beam_width |
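        # Candidates are ranked by length-normalised log-probability; the last vocabulary
        # index plays the role of the padding / end-of-sequence symbol.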
|
|
|
for each_len in range(self.max_len): |
|
expanded_logprob_list = np.repeat(logprob_list, self.vocab_size) |
|
expanded_length_list = np.array([0] * (beam_width * self.vocab_size)) |
|
for each_beam_idx, each_candidate in enumerate(top_k_tgt_id_list): |
|
conditional_logprob = self.conditional_distribution(z, [each_candidate])[0] |
|
if conditional_logprob is None: |
|
expanded_logprob_list[(each_beam_idx + 1) * self.vocab_size - 1]\ |
|
= logprob_list[each_beam_idx] |
|
expanded_logprob_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size - 1]\ |
|
= -np.inf |
|
expanded_length_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size]\ |
|
= len(each_candidate) |
|
else: |
|
expanded_logprob_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size - 1]\ |
|
= logprob_list[each_beam_idx] + conditional_logprob |
|
expanded_logprob_list[(each_beam_idx + 1) * self.vocab_size - 1]\ |
|
= -np.inf |
|
expanded_length_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size]\ |
|
= len(each_candidate) + 1 |
|
score_list = np.array(expanded_logprob_list) / np.array(expanded_length_list) |
|
if each_len == 0: |
|
top_k_list = np.argsort(score_list[:self.vocab_size])[::-1][:beam_width] |
|
else: |
|
top_k_list = np.argsort(score_list)[::-1][:beam_width] |
|
next_top_k_tgt_id_list = [] |
|
next_logprob_list = [] |
|
for each_top_k in top_k_list: |
|
beam_idx = each_top_k // self.vocab_size |
|
vocab_idx = each_top_k % self.vocab_size |
|
if vocab_idx == self.vocab_size - 1: |
|
next_top_k_tgt_id_list.append(top_k_tgt_id_list[beam_idx]) |
|
next_logprob_list.append(expanded_logprob_list[each_top_k]) |
|
else: |
|
next_top_k_tgt_id_list.append(top_k_tgt_id_list[beam_idx] + [vocab_idx]) |
|
next_logprob_list.append(expanded_logprob_list[each_top_k]) |
|
top_k_tgt_id_list = next_top_k_tgt_id_list |
|
logprob_list = next_logprob_list |
|
|
|
|
|
hg_list = [] |
|
for each_tgt_id_list in top_k_tgt_id_list: |
|
hg = None |
|
stack = [] |
|
nt_edge = None |
|
for each_idx, each_prod_rule_id in enumerate(each_tgt_id_list): |
|
prod_rule = self.hrg.prod_rule_list[each_prod_rule_id] |
|
hg, nt_edges = prod_rule.applied_to(hg, nt_edge) |
|
stack.extend(nt_edges[::-1]) |
|
try: |
|
nt_edge = stack.pop() |
|
except IndexError: |
|
if each_idx == len(each_tgt_id_list) - 1: |
|
break |
|
else: |
|
                        raise ValueError('nonterminal edge stack exhausted before the production-rule sequence ended')
|
hg_list.append(hg) |
|
return hg_list |
|
|
|
def graph_embed(self, x, edge_index, edge_attr, batch_size): |
|
src_h = self.encoder.forward(x, edge_index, edge_attr, batch_size) |
|
return src_h |
|
|
|
def encode(self, x, edge_index, edge_attr, batch_size): |
|
|
|
|
|
src_h = self.graph_embed(x, edge_index, edge_attr, batch_size) |
|
mu, lv = self.get_mean_var(src_h) |
|
return mu, lv |
|
|
|
def get_mean_var(self, src_h): |
|
|
|
mu = self.hidden2mean(src_h) |
|
lv = self.hidden2logvar(src_h) |
|
mu = torch.tanh(mu) |
|
lv = torch.tanh(lv) |
|
return mu, lv |
|
|
|
def reparameterize(self, mu, logvar, training=True): |
|
if training: |
|
std = logvar.mul(0.5).exp_() |
|
eps = Variable(std.data.new(std.size()).normal_()) |
|
if self.rank >= 0: |
|
eps = eps.to(next(self.parameters()).device) |
|
return eps.mul(std).add_(mu) |
|
else: |
|
return mu |
|
|
|
|
|
class GrammarVAELoss(_Loss): |
|
|
|
''' |
|
a loss function for Grammar VAE |
|
|
|
Attributes |
|
---------- |
|
hrg : HyperedgeReplacementGrammar |
|
beta : float |
|
coefficient of KL divergence |
|
''' |
|
|
|
def __init__(self, rank, hrg, beta=1.0, **kwargs): |
|
super().__init__(**kwargs) |
|
self.hrg = hrg |
|
self.beta = beta |
|
self.rank = rank |
|
|
|
def forward(self, mu, logvar, in_seq_pred, in_seq): |
|
''' compute VAE loss |
|
|
|
Parameters |
|
---------- |
|
in_seq_pred : torch.Tensor, shape (batch_size, max_len, vocab_size) |
|
logit |
|
in_seq : torch.Tensor, shape (batch_size, max_len) |
|
each element corresponds to a word id in vocabulary. |
|
mu : torch.Tensor, shape (batch_size, hidden_dim) |
|
logvar : torch.Tensor, shape (batch_size, hidden_dim) |
|
mean and log variance of the normal distribution |
|
''' |
|
batch_size = in_seq_pred.shape[0] |
|
max_len = in_seq_pred.shape[1] |
|
vocab_size = in_seq_pred.shape[2] |
|
mask = torch.zeros(in_seq_pred.shape) |
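        # Mask the logits so that only production rules whose left-hand side matches the
        # current nonterminal (plus the padding symbol at the last index) contribute to
        # the cross-entropy term.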
|
|
|
for each_batch in range(batch_size): |
|
flag = True |
|
for each_idx in range(max_len): |
|
prod_rule_idx = in_seq[each_batch, each_idx] |
|
if prod_rule_idx == vocab_size - 1: |
|
|
|
mask[each_batch, each_idx, prod_rule_idx] = 1 |
|
|
|
continue |
|
lhs = self.hrg.prod_rule_corpus.prod_rule_list[prod_rule_idx].lhs_nt_symbol |
|
lhs_idx = self.hrg.prod_rule_corpus.nt_symbol_list.index(lhs) |
|
mask[each_batch, each_idx, :-1] = torch.FloatTensor(self.hrg.prod_rule_corpus.lhs_in_prod_rule[lhs_idx]) |
|
        if self.rank >= 0:

            mask = mask.to(in_seq_pred.device)
|
in_seq_pred = mask * in_seq_pred |
|
|
|
cross_entropy = F.cross_entropy( |
|
in_seq_pred.view(-1, vocab_size), |
|
in_seq.view(-1), |
|
reduction='sum', |
|
|
|
) |
|
kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) |
|
return cross_entropy + self.beta * kl_div |
|
|
|
|
|
class VAELoss(_Loss): |
|
def __init__(self, beta=0.01): |
|
super().__init__() |
|
self.beta = beta |
|
|
|
def forward(self, mean, log_var, dec_outputs, targets): |
|
|
|
        if targets.device != mean.device:

            targets = targets.to(mean.device)
|
reconstruction = F.cross_entropy(dec_outputs.view(-1, dec_outputs.size(2)), targets.view(-1), reduction='sum') |
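
        # Note: KL as computed here is the *negative* KL divergence,
        # 0.5 * sum(1 + log_var - mean^2 - exp(log_var)), so subtracting beta * KL below
        # adds the usual beta-weighted KL penalty to the reconstruction loss.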
|
|
|
KL = 0.5 * torch.sum(1 + log_var - mean ** 2 - torch.exp(log_var)) |
|
loss = - self.beta * KL + reconstruction |
|
return loss |
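

if __name__ == "__main__":
    # Minimal smoke test for the graph encoder (a sketch: only FeatureEmbedding and
    # GIN are exercised here, because the grammar-based classes require an HRG object).
    # The feature layout follows torch_geometric.utils.smiles; the graph itself and
    # all sizes below are made-up example values.
    atom_dims, edge_dims = get_atom_edge_feature_dims()
    num_nodes, num_edges = 6, 10
    x = torch.zeros(num_nodes, len(atom_dims), dtype=torch.long)
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    edge_attr = torch.zeros(num_edges, len(edge_dims), dtype=torch.long)
    batch = torch.zeros(num_nodes, dtype=torch.long)  # all nodes belong to graph 0
    gin = GIN(node_feature_size=len(atom_dims), edge_feature_size=len(edge_dims), hidden_channels=8, proximity_size=2)
    out = gin(x, edge_index, edge_attr, batch)
    print(out.shape)  # expected: (1, hidden_channels * (1 + proximity_size)) = (1, 24)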
|
|