import torch from torch.nn import functional as F, Parameter from torch.autograd import Variable from torch.nn.init import xavier_normal_, xavier_uniform_ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence class Distmult(torch.nn.Module): def __init__(self, args, num_entities, num_relations): super(Distmult, self).__init__() if args.max_norm: self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, max_norm=1.0) self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim) else: self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, padding_idx=None) self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim, padding_idx=None) self.inp_drop = torch.nn.Dropout(args.input_drop) self.loss = torch.nn.CrossEntropyLoss() self.init() def init(self): xavier_normal_(self.emb_e.weight) xavier_normal_(self.emb_rel.weight) def score_sr(self, sub, rel, sigmoid = False): sub_emb = self.emb_e(sub).squeeze(dim=1) rel_emb = self.emb_rel(rel).squeeze(dim=1) #sub_emb = self.inp_drop(sub_emb) #rel_emb = self.inp_drop(rel_emb) pred = torch.mm(sub_emb*rel_emb, self.emb_e.weight.transpose(1,0)) if sigmoid: pred = torch.sigmoid(pred) return pred def score_or(self, obj, rel, sigmoid = False): obj_emb = self.emb_e(obj).squeeze(dim=1) rel_emb = self.emb_rel(rel).squeeze(dim=1) #obj_emb = self.inp_drop(obj_emb) #rel_emb = self.inp_drop(rel_emb) pred = torch.mm(obj_emb*rel_emb, self.emb_e.weight.transpose(1,0)) if sigmoid: pred = torch.sigmoid(pred) return pred def forward(self, sub_emb, rel_emb, mode='rhs', sigmoid=False): ''' When mode is 'rhs' we expect (s,r); for 'lhs', we expect (o,r) For distmult, computations for both modes are equivalent, so we do not need if-else block ''' sub_emb = self.inp_drop(sub_emb) rel_emb = self.inp_drop(rel_emb) pred = torch.mm(sub_emb*rel_emb, self.emb_e.weight.transpose(1,0)) if sigmoid: pred = torch.sigmoid(pred) return pred def score_triples(self, sub, rel, obj, sigmoid=False): ''' Inputs - subject, relation, object Return - score ''' sub_emb = self.emb_e(sub).squeeze(dim=1) rel_emb = self.emb_rel(rel).squeeze(dim=1) obj_emb = self.emb_e(obj).squeeze(dim=1) pred = torch.sum(sub_emb*rel_emb*obj_emb, dim=-1) if sigmoid: pred = torch.sigmoid(pred) return pred def score_emb(self, emb_s, emb_r, emb_o, sigmoid=False): ''' Inputs - embeddings of subject, relation, object Return - score ''' pred = torch.sum(emb_s*emb_r*emb_o, dim=-1) if sigmoid: pred = torch.sigmoid(pred) return pred def score_triples_vec(self, sub, rel, obj, sigmoid=False): ''' Inputs - subject, relation, object Return - a vector score for the triple instead of reducing over the embedding dimension ''' sub_emb = self.emb_e(sub).squeeze(dim=1) rel_emb = self.emb_rel(rel).squeeze(dim=1) obj_emb = self.emb_e(obj).squeeze(dim=1) pred = sub_emb*rel_emb*obj_emb if sigmoid: pred = torch.sigmoid(pred) return pred class Complex(torch.nn.Module): def __init__(self, args, num_entities, num_relations): super(Complex, self).__init__() if args.max_norm: self.emb_e = torch.nn.Embedding(num_entities, 2*args.embedding_dim, max_norm=1.0) self.emb_rel = torch.nn.Embedding(num_relations, 2*args.embedding_dim) else: self.emb_e = torch.nn.Embedding(num_entities, 2*args.embedding_dim, padding_idx=None) self.emb_rel = torch.nn.Embedding(num_relations, 2*args.embedding_dim, padding_idx=None) self.inp_drop = torch.nn.Dropout(args.input_drop) self.loss = torch.nn.CrossEntropyLoss() self.init() def init(self): xavier_normal_(self.emb_e.weight) xavier_normal_(self.emb_rel.weight) def score_sr(self, sub, rel, sigmoid = False): sub_emb = self.emb_e(sub).squeeze(dim=1) rel_emb = self.emb_rel(rel).squeeze(dim=1) s_real, s_img = torch.chunk(rel_emb, 2, dim=-1) rel_real, rel_img = torch.chunk(sub_emb, 2, dim=-1) emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1) realo_realreal = s_real*rel_real realo_imgimg = s_img*rel_img realo = realo_realreal - realo_imgimg real = torch.mm(realo, emb_e_real.transpose(1,0)) imgo_realimg = s_real*rel_img imgo_imgreal = s_img*rel_real imgo = imgo_realimg + imgo_imgreal img = torch.mm(imgo, emb_e_img.transpose(1,0)) pred = real + img if sigmoid: pred = torch.sigmoid(pred) return pred def score_or(self, obj, rel, sigmoid = False): obj_emb = self.emb_e(obj).squeeze(dim=1) rel_emb = self.emb_rel(rel).squeeze(dim=1) rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1) o_real, o_img = torch.chunk(obj_emb, 2, dim=-1) emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1) #rel_real = self.inp_drop(rel_real) #rel_img = self.inp_drop(rel_img) #o_real = self.inp_drop(o_real) #o_img = self.inp_drop(o_img) # complex space bilinear product (equivalent to HolE) # realrealreal = torch.mm(rel_real*o_real, emb_e_real.transpose(1,0)) # realimgimg = torch.mm(rel_img*o_img, emb_e_real.transpose(1,0)) # imgrealimg = torch.mm(rel_real*o_img, emb_e_img.transpose(1,0)) # imgimgreal = torch.mm(rel_img*o_real, emb_e_img.transpose(1,0)) # pred = realrealreal + realimgimg + imgrealimg - imgimgreal reals_realreal = rel_real*o_real reals_imgimg = rel_img*o_img reals = reals_realreal + reals_imgimg real = torch.mm(reals, emb_e_real.transpose(1,0)) imgs_realimg = rel_real*o_img imgs_imgreal = rel_img*o_real imgs = imgs_realimg - imgs_imgreal img = torch.mm(imgs, emb_e_img.transpose(1,0)) pred = real + img if sigmoid: pred = torch.sigmoid(pred) return pred def forward(self, sub_emb, rel_emb, mode='rhs', sigmoid=False): ''' When mode is 'rhs' we expect (s,r); for 'lhs', we expect (o,r) ''' if mode == 'lhs': rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1) o_real, o_img = torch.chunk(sub_emb, 2, dim=-1) emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1) rel_real = self.inp_drop(rel_real) rel_img = self.inp_drop(rel_img) o_real = self.inp_drop(o_real) o_img = self.inp_drop(o_img) reals_realreal = rel_real*o_real reals_imgimg = rel_img*o_img reals = reals_realreal + reals_imgimg real = torch.mm(reals, emb_e_real.transpose(1,0)) imgs_realimg = rel_real*o_img imgs_imgreal = rel_img*o_real imgs = imgs_realimg - imgs_imgreal img = torch.mm(imgs, emb_e_img.transpose(1,0)) pred = real + img else: s_real, s_img = torch.chunk(rel_emb, 2, dim=-1) rel_real, rel_img = torch.chunk(sub_emb, 2, dim=-1) emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1) s_real = self.inp_drop(s_real) s_img = self.inp_drop(s_img) rel_real = self.inp_drop(rel_real) rel_img = self.inp_drop(rel_img) realo_realreal = s_real*rel_real realo_imgimg = s_img*rel_img realo = realo_realreal - realo_imgimg real = torch.mm(realo, emb_e_real.transpose(1,0)) imgo_realimg = s_real*rel_img imgo_imgreal = s_img*rel_real imgo = imgo_realimg + imgo_imgreal img = torch.mm(imgo, emb_e_img.transpose(1,0)) pred = real + img if sigmoid: pred = torch.sigmoid(pred) return pred def score_triples(self, sub, rel, obj, sigmoid=False): ''' Inputs - subject, relation, object Return - score ''' sub_emb = self.emb_e(sub).squeeze(dim=1) rel_emb = self.emb_rel(rel).squeeze(dim=1) obj_emb = self.emb_e(obj).squeeze(dim=1) s_real, s_img = torch.chunk(sub_emb, 2, dim=-1) rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1) o_real, o_img = torch.chunk(obj_emb, 2, dim=-1) realrealreal = torch.sum(s_real*rel_real*o_real, dim=-1) realimgimg = torch.sum(s_real*rel_img*o_img, axis=-1) imgrealimg = torch.sum(s_img*rel_real*o_img, axis=-1) imgimgreal = torch.sum(s_img*rel_img*o_real, axis=-1) pred = realrealreal + realimgimg + imgrealimg - imgimgreal if sigmoid: pred = torch.sigmoid(pred) return pred def score_emb(self, emb_s, emb_r, emb_o, sigmoid=False): ''' Inputs - embeddings of subject, relation, object Return - score ''' s_real, s_img = torch.chunk(emb_s, 2, dim=-1) rel_real, rel_img = torch.chunk(emb_r, 2, dim=-1) o_real, o_img = torch.chunk(emb_o, 2, dim=-1) realrealreal = torch.sum(s_real*rel_real*o_real, dim=-1) realimgimg = torch.sum(s_real*rel_img*o_img, axis=-1) imgrealimg = torch.sum(s_img*rel_real*o_img, axis=-1) imgimgreal = torch.sum(s_img*rel_img*o_real, axis=-1) pred = realrealreal + realimgimg + imgrealimg - imgimgreal if sigmoid: pred = torch.sigmoid(pred) return pred def score_triples_vec(self, sub, rel, obj, sigmoid=False): ''' Inputs - subject, relation, object Return - a vector score for the triple instead of reducing over the embedding dimension ''' sub_emb = self.emb_e(sub).squeeze(dim=1) rel_emb = self.emb_rel(rel).squeeze(dim=1) obj_emb = self.emb_e(obj).squeeze(dim=1) s_real, s_img = torch.chunk(sub_emb, 2, dim=-1) rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1) o_real, o_img = torch.chunk(obj_emb, 2, dim=-1) realrealreal = s_real*rel_real*o_real realimgimg = s_real*rel_img*o_img imgrealimg = s_img*rel_real*o_img imgimgreal = s_img*rel_img*o_real pred = realrealreal + realimgimg + imgrealimg - imgimgreal if sigmoid: pred = torch.sigmoid(pred) return pred class Conve(torch.nn.Module): #Too slow !!!! def __init__(self, args, num_entities, num_relations): super(Conve, self).__init__() if args.max_norm: self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, max_norm=1.0) self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim) else: self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, padding_idx=None) self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim, padding_idx=None) self.inp_drop = torch.nn.Dropout(args.input_drop) self.hidden_drop = torch.nn.Dropout(args.hidden_drop) self.feature_drop = torch.nn.Dropout2d(args.feat_drop) self.embedding_dim = args.embedding_dim #default is 200 self.num_filters = args.num_filters # default is 32 self.kernel_size = args.kernel_size # default is 3 self.stack_width = args.stack_width # default is 20 self.stack_height = args.embedding_dim // self.stack_width self.bn0 = torch.nn.BatchNorm2d(1) self.bn1 = torch.nn.BatchNorm2d(self.num_filters) self.bn2 = torch.nn.BatchNorm1d(args.embedding_dim) self.conv1 = torch.nn.Conv2d(1, out_channels=self.num_filters, kernel_size=(self.kernel_size, self.kernel_size), stride=1, padding=0, bias=args.use_bias) #self.conv1 = torch.nn.Conv2d(1, 32, (3, 3), 1, 0, bias=args.use_bias) # <-- default flat_sz_h = int(2*self.stack_width) - self.kernel_size + 1 flat_sz_w = self.stack_height - self.kernel_size + 1 self.flat_sz = flat_sz_h*flat_sz_w*self.num_filters self.fc = torch.nn.Linear(self.flat_sz, args.embedding_dim) self.register_parameter('b', Parameter(torch.zeros(num_entities))) self.loss = torch.nn.CrossEntropyLoss() self.init() def init(self): xavier_normal_(self.emb_e.weight) xavier_normal_(self.emb_rel.weight) def concat(self, e1_embed, rel_embed, form='plain'): if form == 'plain': e1_embed = e1_embed. view(-1, 1, self.stack_width, self.stack_height) rel_embed = rel_embed.view(-1, 1, self.stack_width, self.stack_height) stack_inp = torch.cat([e1_embed, rel_embed], 2) elif form == 'alternate': e1_embed = e1_embed. view(-1, 1, self.embedding_dim) rel_embed = rel_embed.view(-1, 1, self.embedding_dim) stack_inp = torch.cat([e1_embed, rel_embed], 1) stack_inp = torch.transpose(stack_inp, 2, 1).reshape((-1, 1, 2*self.stack_width, self.stack_height)) else: raise NotImplementedError return stack_inp def conve_architecture(self, sub_emb, rel_emb): stacked_inputs = self.concat(sub_emb, rel_emb) stacked_inputs = self.bn0(stacked_inputs) x = self.inp_drop(stacked_inputs) x = self.conv1(x) x = self.bn1(x) x = F.relu(x) x = self.feature_drop(x) #x = x.view(x.shape[0], -1) x = x.view(-1, self.flat_sz) x = self.fc(x) x = self.hidden_drop(x) x = self.bn2(x) x = F.relu(x) return x def score_sr(self, sub, rel, sigmoid = False): sub_emb = self.emb_e(sub) rel_emb = self.emb_rel(rel) x = self.conve_architecture(sub_emb, rel_emb) pred = torch.mm(x, self.emb_e.weight.transpose(1,0)) pred += self.b.expand_as(pred) if sigmoid: pred = torch.sigmoid(pred) return pred def score_or(self, obj, rel, sigmoid = False): obj_emb = self.emb_e(obj) rel_emb = self.emb_rel(rel) x = self.conve_architecture(obj_emb, rel_emb) pred = torch.mm(x, self.emb_e.weight.transpose(1,0)) pred += self.b.expand_as(pred) if sigmoid: pred = torch.sigmoid(pred) return pred def forward(self, sub_emb, rel_emb, mode='rhs', sigmoid=False): ''' When mode is 'rhs' we expect (s,r); for 'lhs', we expect (o,r) For conve, computations for both modes are equivalent, so we do not need if-else block ''' x = self.conve_architecture(sub_emb, rel_emb) pred = torch.mm(x, self.emb_e.weight.transpose(1,0)) pred += self.b.expand_as(pred) if sigmoid: pred = torch.sigmoid(pred) return pred def score_triples(self, sub, rel, obj, sigmoid=False): ''' Inputs - subject, relation, object Return - score ''' sub_emb = self.emb_e(sub) rel_emb = self.emb_rel(rel) obj_emb = self.emb_e(obj) x = self.conve_architecture(sub_emb, rel_emb) pred = torch.mm(x, obj_emb.transpose(1,0)) #print(pred.shape) pred += self.b[obj].expand_as(pred) #taking the bias value for object embedding # above works fine for single input triples; # but if input is batch of triples, then this is a matrix of (num_trip x num_trip) where diagonal is scores # so use torch.diagonal() after calling this function pred = torch.diagonal(pred) # or could have used : pred= torch.sum(x*obj_emb, dim=-1) if sigmoid: pred = torch.sigmoid(pred) return pred def score_emb(self, emb_s, emb_r, emb_o, sigmoid=False): ''' Inputs - embeddings of subject, relation, object Return - score ''' x = self.conve_architecture(emb_s, emb_r) pred = torch.mm(x, emb_o.transpose(1,0)) #pred += self.b[obj].expand_as(pred) #taking the bias value for object embedding - don't know which obj # above works fine for single input triples; # but if input is batch of triples, then this is a matrix of (num_trip x num_trip) where diagonal is scores # so use torch.diagonal() after calling this function pred = torch.diagonal(pred) # or could have used : pred= torch.sum(x*obj_emb, dim=-1) if sigmoid: pred = torch.sigmoid(pred) return pred def score_triples_vec(self, sub, rel, obj, sigmoid=False): ''' Inputs - subject, relation, object Return - a vector score for the triple instead of reducing over the embedding dimension ''' sub_emb = self.emb_e(sub) rel_emb = self.emb_rel(rel) obj_emb = self.emb_e(obj) x = self.conve_architecture(sub_emb, rel_emb) #pred = torch.mm(x, obj_emb.transpose(1,0)) pred = x*obj_emb #print(pred.shape, self.b[obj].shape) #shapes are [7,200] and [7] #pred += self.b[obj].expand_as(pred) #taking the bias value for object embedding - can't add scalar to vector #pred = sub_emb*rel_emb*obj_emb if sigmoid: pred = torch.sigmoid(pred) return pred