Scorpius_HF / model.py
yjwtheonly
modification
2d06d0e
raw
history blame
No virus
18.5 kB
import torch
from torch.nn import functional as F, Parameter
from torch.autograd import Variable
from torch.nn.init import xavier_normal_, xavier_uniform_
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
class Distmult(torch.nn.Module):
    """DistMult scorer over entity and relation embedding tables.

    Every score is built from the element-wise product of a subject (or
    object) embedding, a relation embedding and an entity embedding.  The
    ``score_sr`` / ``score_or`` / ``forward`` methods rank one
    (entity, relation) query against every row of the entity table; the
    ``score_triples*`` / ``score_emb`` methods score explicit triples.

    ``args`` must provide ``max_norm``, ``embedding_dim`` and ``input_drop``.
    """

    def __init__(self, args, num_entities, num_relations):
        super().__init__()
        dim = args.embedding_dim
        # With ``args.max_norm`` set, only the entity table is norm-clipped
        # (max_norm=1.0); the relation table never is.
        entity_opts = {'max_norm': 1.0} if args.max_norm else {}
        self.emb_e = torch.nn.Embedding(num_entities, dim, **entity_opts)
        self.emb_rel = torch.nn.Embedding(num_relations, dim)

        self.inp_drop = torch.nn.Dropout(args.input_drop)
        self.loss = torch.nn.CrossEntropyLoss()

        self.init()

    def init(self):
        """Re-initialise both embedding tables with Xavier-normal weights."""
        for table in (self.emb_e, self.emb_rel):
            xavier_normal_(table.weight)

    def _against_all_entities(self, query, sigmoid):
        """Dot ``query`` rows with every entity embedding; optionally squash."""
        scores = torch.mm(query, self.emb_e.weight.t())
        return torch.sigmoid(scores) if sigmoid else scores

    def score_sr(self, sub, rel, sigmoid = False):
        """Score (subject, relation) index pairs against all entities.

        The looked-up embeddings have axis 1 squeezed, which removes a
        singleton index axis (presumably ids arrive shaped (batch, 1) --
        confirm against callers).  No input dropout is applied here.
        """
        head = self.emb_e(sub).squeeze(dim=1)
        relation = self.emb_rel(rel).squeeze(dim=1)
        return self._against_all_entities(head * relation, sigmoid)

    def score_or(self, obj, rel, sigmoid = False):
        """Score (object, relation) index pairs against all entities.

        Same computation as ``score_sr``; no input dropout is applied.
        """
        tail = self.emb_e(obj).squeeze(dim=1)
        relation = self.emb_rel(rel).squeeze(dim=1)
        return self._against_all_entities(tail * relation, sigmoid)

    def forward(self, sub_emb, rel_emb, mode='rhs', sigmoid=False):
        '''
        Rank an (entity, relation) query, given as embeddings, against all entities.

        When mode is 'rhs' we expect (s,r); for 'lhs', we expect (o,r).
        ``mode`` is accepted but unused: the computation is the same for both.
        Input dropout is applied to the entity embedding first, then to the
        relation embedding.
        '''
        entity = self.inp_drop(sub_emb)
        relation = self.inp_drop(rel_emb)
        return self._against_all_entities(entity * relation, sigmoid)

    def score_triples(self, sub, rel, obj, sigmoid=False):
        '''
        Inputs - subject, relation, object (indices)
        Return - one score per triple (product summed over the last axis)
        '''
        head = self.emb_e(sub).squeeze(dim=1)
        relation = self.emb_rel(rel).squeeze(dim=1)
        tail = self.emb_e(obj).squeeze(dim=1)
        return self.score_emb(head, relation, tail, sigmoid=sigmoid)

    def score_emb(self, emb_s, emb_r, emb_o, sigmoid=False):
        '''
        Inputs - embeddings of subject, relation, object
        Return - one score per triple (product summed over the last axis)
        '''
        scores = (emb_s * emb_r * emb_o).sum(dim=-1)
        return torch.sigmoid(scores) if sigmoid else scores

    def score_triples_vec(self, sub, rel, obj, sigmoid=False):
        '''
        Inputs - subject, relation, object (indices)
        Return - the element-wise product itself, i.e. a vector per triple
                 rather than a value reduced over the embedding dimension
        '''
        head = self.emb_e(sub).squeeze(dim=1)
        relation = self.emb_rel(rel).squeeze(dim=1)
        tail = self.emb_e(obj).squeeze(dim=1)
        per_dim = head * relation * tail
        return torch.sigmoid(per_dim) if sigmoid else per_dim
class Complex(torch.nn.Module):
    """ComplEx-style scorer over entity and relation embedding tables.

    Each table row has width ``2 * args.embedding_dim``; the first half of
    the last axis is used as the real part and the second half as the
    imaginary part.  A triple score is

        sum(s_re*r_re*o_re) + sum(s_re*r_im*o_im)
        + sum(s_im*r_re*o_im) - sum(s_im*r_im*o_re)

    ``args`` must provide ``max_norm``, ``embedding_dim`` and ``input_drop``.
    """

    def __init__(self, args, num_entities, num_relations):
        super().__init__()
        width = 2 * args.embedding_dim
        # With ``args.max_norm`` set, only the entity table is norm-clipped
        # (max_norm=1.0); the relation table never is.
        entity_opts = {'max_norm': 1.0} if args.max_norm else {}
        self.emb_e = torch.nn.Embedding(num_entities, width, **entity_opts)
        self.emb_rel = torch.nn.Embedding(num_relations, width)

        self.inp_drop = torch.nn.Dropout(args.input_drop)
        self.loss = torch.nn.CrossEntropyLoss()

        self.init()

    def init(self):
        """Re-initialise both embedding tables with Xavier-normal weights."""
        for table in (self.emb_e, self.emb_rel):
            xavier_normal_(table.weight)

    @staticmethod
    def _halves(emb):
        """Split the last axis into its (real, imaginary) halves."""
        return torch.chunk(emb, 2, dim=-1)

    @staticmethod
    def _query_parts(ent_real, ent_img, rel_real, rel_img, entity_is_object):
        """Combine an entity and a relation into (real, imaginary) query rows.

        The two directions differ only in the sign of the cross terms: the
        known entity enters conjugated when it is the object of the triple.
        """
        if entity_is_object:
            real_part = rel_real * ent_real + rel_img * ent_img
            img_part = rel_real * ent_img - rel_img * ent_real
        else:
            real_part = rel_real * ent_real - rel_img * ent_img
            img_part = rel_real * ent_img + rel_img * ent_real
        return real_part, img_part

    def _against_all_entities(self, real_part, img_part, sigmoid):
        """Dot the query parts with the matching halves of every entity row."""
        all_real, all_img = self._halves(self.emb_e.weight)
        scores = torch.mm(real_part, all_real.t()) + torch.mm(img_part, all_img.t())
        return torch.sigmoid(scores) if sigmoid else scores

    def score_sr(self, sub, rel, sigmoid = False):
        """Score (subject, relation) index pairs against all candidate objects.

        Axis 1 of the looked-up embeddings is squeezed (presumably ids arrive
        shaped (batch, 1) -- confirm against callers).  No dropout is applied.
        """
        s_real, s_img = self._halves(self.emb_e(sub).squeeze(dim=1))
        rel_real, rel_img = self._halves(self.emb_rel(rel).squeeze(dim=1))
        parts = self._query_parts(s_real, s_img, rel_real, rel_img, False)
        return self._against_all_entities(parts[0], parts[1], sigmoid)

    def score_or(self, obj, rel, sigmoid = False):
        """Score (object, relation) index pairs against all candidate subjects.

        Axis 1 of the looked-up embeddings is squeezed; no dropout is applied.
        """
        o_real, o_img = self._halves(self.emb_e(obj).squeeze(dim=1))
        rel_real, rel_img = self._halves(self.emb_rel(rel).squeeze(dim=1))
        parts = self._query_parts(o_real, o_img, rel_real, rel_img, True)
        return self._against_all_entities(parts[0], parts[1], sigmoid)

    def forward(self, sub_emb, rel_emb, mode='rhs', sigmoid=False):
        '''
        Rank an (entity, relation) query, given as embeddings, against all entities.

        When mode is 'rhs' we expect (s,r); for 'lhs', we expect (o,r).
        Any mode other than 'lhs' is handled as 'rhs'.
        '''
        # Dropout is drawn per half, relation halves first and entity halves
        # second, in both modes.
        rel_real, rel_img = (self.inp_drop(half) for half in self._halves(rel_emb))
        ent_real, ent_img = (self.inp_drop(half) for half in self._halves(sub_emb))
        parts = self._query_parts(ent_real, ent_img, rel_real, rel_img, mode == 'lhs')
        return self._against_all_entities(parts[0], parts[1], sigmoid)

    def _triple_terms(self, emb_s, emb_r, emb_o):
        """Return the four element-wise products that make up a triple score."""
        s_real, s_img = self._halves(emb_s)
        rel_real, rel_img = self._halves(emb_r)
        o_real, o_img = self._halves(emb_o)
        return (
            s_real * rel_real * o_real,
            s_real * rel_img * o_img,
            s_img * rel_real * o_img,
            s_img * rel_img * o_real,
        )

    def score_triples(self, sub, rel, obj, sigmoid=False):
        '''
        Inputs - subject, relation, object (indices)
        Return - one score per triple
        '''
        head = self.emb_e(sub).squeeze(dim=1)
        relation = self.emb_rel(rel).squeeze(dim=1)
        tail = self.emb_e(obj).squeeze(dim=1)
        return self.score_emb(head, relation, tail, sigmoid=sigmoid)

    def score_emb(self, emb_s, emb_r, emb_o, sigmoid=False):
        '''
        Inputs - embeddings of subject, relation, object
        Return - one score per triple
        '''
        # Each term is reduced over the last axis before the terms are combined.
        rrr, rii, iri, iir = (
            term.sum(dim=-1) for term in self._triple_terms(emb_s, emb_r, emb_o)
        )
        scores = rrr + rii + iri - iir
        return torch.sigmoid(scores) if sigmoid else scores

    def score_triples_vec(self, sub, rel, obj, sigmoid=False):
        '''
        Inputs - subject, relation, object (indices)
        Return - the per-dimension contributions, i.e. a vector per triple
                 rather than a value reduced over the embedding dimension
        '''
        head = self.emb_e(sub).squeeze(dim=1)
        relation = self.emb_rel(rel).squeeze(dim=1)
        tail = self.emb_e(obj).squeeze(dim=1)
        rrr, rii, iri, iir = self._triple_terms(head, relation, tail)
        per_dim = rrr + rii + iri - iir
        return torch.sigmoid(per_dim) if sigmoid else per_dim
class Conve(torch.nn.Module):
    """ConvE-style scorer: a 2-D convolution over stacked (entity, relation) grids.

    The entity and relation embeddings are each reshaped to a
    ``stack_width x stack_height`` grid, stacked, and passed through
    BatchNorm -> dropout -> Conv2d -> BatchNorm -> ReLU -> feature dropout ->
    Linear -> dropout -> BatchNorm -> ReLU.  The resulting vector is dotted
    with entity embeddings, and a per-entity bias ``b`` is added where the
    entity index is known.

    ``args`` must provide ``max_norm``, ``embedding_dim``, ``input_drop``,
    ``hidden_drop``, ``feat_drop``, ``num_filters``, ``kernel_size``,
    ``stack_width`` and ``use_bias``.
    """
    # Too slow !!!!  (original author's note)

    def __init__(self, args, num_entities, num_relations):
        """Build the embedding tables and the convolutional encoder.

        Raises:
            ValueError: if ``args.embedding_dim`` is not a multiple of
                ``args.stack_width``.  The embeddings could then not be laid
                out as a ``stack_width x stack_height`` grid, and the
                ``view(-1, ...)`` in ``concat`` might silently regroup values
                across batch rows instead of failing.
        """
        super().__init__()

        if args.embedding_dim % args.stack_width != 0:
            raise ValueError(
                "embedding_dim (%d) must be divisible by stack_width (%d)"
                % (args.embedding_dim, args.stack_width)
            )

        # With ``args.max_norm`` set, only the entity table is norm-clipped.
        if args.max_norm:
            self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, max_norm=1.0)
            self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim)
        else:
            self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, padding_idx=None)
            self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim, padding_idx=None)

        self.inp_drop = torch.nn.Dropout(args.input_drop)
        self.hidden_drop = torch.nn.Dropout(args.hidden_drop)
        self.feature_drop = torch.nn.Dropout2d(args.feat_drop)

        # "Default" values below are the original author's notes about the
        # argument parser; they are not enforced here.
        self.embedding_dim = args.embedding_dim  # default is 200
        self.num_filters = args.num_filters  # default is 32
        self.kernel_size = args.kernel_size  # default is 3
        self.stack_width = args.stack_width  # default is 20
        self.stack_height = args.embedding_dim // self.stack_width

        self.bn0 = torch.nn.BatchNorm2d(1)
        self.bn1 = torch.nn.BatchNorm2d(self.num_filters)
        self.bn2 = torch.nn.BatchNorm1d(args.embedding_dim)

        self.conv1 = torch.nn.Conv2d(1, out_channels=self.num_filters,
                                     kernel_size=(self.kernel_size, self.kernel_size),
                                     stride=1, padding=0, bias=args.use_bias)

        # Size of the conv output with stride 1 and no padding on the
        # (2*stack_width, stack_height) stacked input.
        flat_sz_h = int(2*self.stack_width) - self.kernel_size + 1
        flat_sz_w = self.stack_height - self.kernel_size + 1
        self.flat_sz = flat_sz_h*flat_sz_w*self.num_filters
        self.fc = torch.nn.Linear(self.flat_sz, args.embedding_dim)

        # One learnable bias per entity, initialised to zero.
        self.register_parameter('b', Parameter(torch.zeros(num_entities)))
        self.loss = torch.nn.CrossEntropyLoss()

        self.init()

    def init(self):
        """Re-initialise both embedding tables with Xavier-normal weights."""
        xavier_normal_(self.emb_e.weight)
        xavier_normal_(self.emb_rel.weight)

    def concat(self, e1_embed, rel_embed, form='plain'):
        """Stack entity and relation embeddings into one single-channel grid.

        ``form='plain'`` reshapes each embedding to a
        ``stack_width x stack_height`` grid and concatenates the two grids
        along the height axis (axis 2).  ``form='alternate'`` interleaves the
        two embeddings element by element before reshaping.  Both produce a
        tensor reshaped to ``(-1, 1, 2*stack_width, stack_height)``.

        Raises:
            NotImplementedError: for any other ``form``.
        """
        if form == 'plain':
            e1_embed = e1_embed.view(-1, 1, self.stack_width, self.stack_height)
            rel_embed = rel_embed.view(-1, 1, self.stack_width, self.stack_height)
            return torch.cat([e1_embed, rel_embed], 2)
        if form == 'alternate':
            e1_embed = e1_embed.view(-1, 1, self.embedding_dim)
            rel_embed = rel_embed.view(-1, 1, self.embedding_dim)
            stack_inp = torch.cat([e1_embed, rel_embed], 1)
            return torch.transpose(stack_inp, 2, 1).reshape(
                (-1, 1, 2*self.stack_width, self.stack_height))
        raise NotImplementedError("unknown stacking form: %r" % (form,))

    def conve_architecture(self, sub_emb, rel_emb):
        """Encode an (entity, relation) pair into a vector of size ``embedding_dim``."""
        stacked_inputs = self.concat(sub_emb, rel_emb)
        stacked_inputs = self.bn0(stacked_inputs)
        x = self.inp_drop(stacked_inputs)
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_drop(x)
        x = x.view(-1, self.flat_sz)
        x = self.fc(x)
        x = self.hidden_drop(x)
        x = self.bn2(x)
        x = F.relu(x)
        return x

    def _rank_all(self, x, sigmoid):
        """Dot encoded queries with every entity row and add the entity bias."""
        pred = torch.mm(x, self.emb_e.weight.transpose(1, 0))
        pred = pred + self.b  # broadcasts the per-entity bias over the batch
        if sigmoid:
            pred = torch.sigmoid(pred)
        return pred

    def score_sr(self, sub, rel, sigmoid = False):
        """Score (subject, relation) index pairs against all entities."""
        sub_emb = self.emb_e(sub)
        rel_emb = self.emb_rel(rel)
        return self._rank_all(self.conve_architecture(sub_emb, rel_emb), sigmoid)

    def score_or(self, obj, rel, sigmoid = False):
        """Score (object, relation) index pairs against all entities.

        Same computation as ``score_sr`` with the object in the entity slot.
        """
        obj_emb = self.emb_e(obj)
        rel_emb = self.emb_rel(rel)
        return self._rank_all(self.conve_architecture(obj_emb, rel_emb), sigmoid)

    def forward(self, sub_emb, rel_emb, mode='rhs', sigmoid=False):
        '''
        Rank an (entity, relation) query, given as embeddings, against all entities.

        When mode is 'rhs' we expect (s,r); for 'lhs', we expect (o,r).
        ``mode`` is accepted but unused: the computation is the same for both.
        '''
        return self._rank_all(self.conve_architecture(sub_emb, rel_emb), sigmoid)

    def score_triples(self, sub, rel, obj, sigmoid=False):
        '''
        Inputs - subject, relation, object (indices; sub/rel/obj are expected
                 to describe the same number of triples)
        Return - one score per triple, including the bias of the object entity
        '''
        sub_emb = self.emb_e(sub)
        rel_emb = self.emb_rel(rel)
        # Flatten to one row per triple so ids shaped (batch,) and (batch, 1)
        # are both accepted, matching what ``concat`` does for the query side.
        obj_emb = self.emb_e(obj).reshape(-1, self.embedding_dim)
        x = self.conve_architecture(sub_emb, rel_emb)
        # Row-wise dot product: equal to the diagonal of x @ obj_emb.T without
        # materialising the (batch x batch) matrix.
        pred = torch.sum(x * obj_emb, dim=-1) + self.b[obj].reshape(-1)
        if sigmoid:
            pred = torch.sigmoid(pred)
        return pred

    def score_emb(self, emb_s, emb_r, emb_o, sigmoid=False):
        '''
        Inputs - embeddings of subject, relation, object
        Return - one score per triple; no entity bias is added because the
                 object index is not known from its embedding
        '''
        x = self.conve_architecture(emb_s, emb_r)
        emb_o = emb_o.reshape(-1, self.embedding_dim)
        # Row-wise dot product instead of diag(x @ emb_o.T).
        pred = torch.sum(x * emb_o, dim=-1)
        if sigmoid:
            pred = torch.sigmoid(pred)
        return pred

    def score_triples_vec(self, sub, rel, obj, sigmoid=False):
        '''
        Inputs - subject, relation, object (indices)
        Return - a vector per triple (encoded query times object embedding),
                 not reduced over the embedding dimension; the scalar entity
                 bias is not added to the vector
        '''
        sub_emb = self.emb_e(sub)
        rel_emb = self.emb_rel(rel)
        obj_emb = self.emb_e(obj).reshape(-1, self.embedding_dim)
        x = self.conve_architecture(sub_emb, rel_emb)
        pred = x * obj_emb
        if sigmoid:
            pred = torch.sigmoid(pred)
        return pred