# Compact_Facts/utils/nn_utils.py
import functools
import logging
import torch
import torch.nn.functional as F
import math
import numpy as np
logger = logging.getLogger(__name__)
def get_device_of(tensor):
"""This function returns the device of the tensor
refer to https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py
Arguments:
tensor {tensor} -- tensor
Returns:
int -- device
"""
if not tensor.is_cuda:
return -1
else:
return tensor.get_device()
def get_range_vector(size, device):
"""This function returns a range vector with the desired size, starting at 0
    the CUDA implementation avoids copying data from the CPU to the GPU
refer to https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py
Arguments:
size {int} -- the size of range
device {int} -- device
Returns:
torch.Tensor -- range vector
"""
if device > -1:
return torch.cuda.LongTensor(size, device=device).fill_(1).cumsum(0) - 1
else:
return torch.arange(0, size, dtype=torch.long)
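# Quick usage check (a minimal sketch, not part of the original module):
def _demo_get_range_vector():
    # CPU path: equivalent to torch.arange(4)
    assert get_range_vector(4, device=-1).tolist() == [0, 1, 2, 3]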
def flatten_and_batch_shift_indices(indices, sequence_length):
"""This function returns a vector that correctly indexes into the flattened target,
the sequence length of the target must be provided to compute the appropriate offsets.
refer to https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py
Arguments:
indices {tensor} -- index tensor
sequence_length {int} -- sequence length
Returns:
tensor -- offset index tensor
"""
# Shape: (batch_size)
if torch.max(indices) >= sequence_length or torch.min(indices) < 0:
raise RuntimeError("All elements in indices should be in range (0, {})".format(sequence_length - 1))
offsets = get_range_vector(indices.size(0), get_device_of(indices)) * sequence_length
for _ in range(len(indices.size()) - 1):
offsets = offsets.unsqueeze(1)
# Shape: (batch_size, d_1, ..., d_n)
offset_indices = indices + offsets
# Shape: (batch_size * d_1 * ... * d_n)
offset_indices = offset_indices.view(-1)
return offset_indices
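# Worked example (a minimal sketch, not part of the original module): with
# batch_size 2 and sequence_length 5, the per-batch offsets are [0, 5], so
# index 3 in the second sequence maps to 5 + 3 = 8 in the flattened target.
def _demo_flatten_and_batch_shift_indices():
    indices = torch.tensor([[1, 2], [3, 0]])
    flat = flatten_and_batch_shift_indices(indices, sequence_length=5)
    assert flat.tolist() == [1, 2, 8, 5]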
def batched_index_select(target, indices, flattened_indices=None):
"""This function returns selected values in the target with respect to the provided indices,
which have size ``(batch_size, d_1, ..., d_n, embedding_size)``
refer to https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py
Arguments:
target {torch.Tensor} -- target tensor
indices {torch.LongTensor} -- index tensor
Keyword Arguments:
flattened_indices {Optional[torch.LongTensor]} -- flattened index tensor (default: {None})
Returns:
torch.Tensor -- selected tensor
"""
if flattened_indices is None:
# Shape: (batch_size * d_1 * ... * d_n)
flattened_indices = flatten_and_batch_shift_indices(indices, target.size(1))
# Shape: (batch_size * sequence_length, embedding_size)
flattened_target = target.view(-1, target.size(-1))
# Shape: (batch_size * d_1 * ... * d_n, embedding_size)
flattened_selected = flattened_target.index_select(0, flattened_indices)
selected_shape = list(indices.size()) + [target.size(-1)]
# Shape: (batch_size, d_1, ..., d_n, embedding_size)
selected_targets = flattened_selected.view(*selected_shape)
return selected_targets
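# Usage sketch (hypothetical shapes, not part of the original module): gather
# one embedding per index from each sequence in the batch.
def _demo_batched_index_select():
    target = torch.arange(24, dtype=torch.float).view(2, 3, 4)  # (batch, seq_len, emb)
    indices = torch.tensor([[0, 2], [1, 1]])                    # (batch, num_indices)
    selected = batched_index_select(target, indices)
    assert selected.size() == (2, 2, 4)
    assert torch.equal(selected[1, 0], target[1, 1])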
def get_padding_vector(size, dtype, device):
"""This function initializes padding unit
Arguments:
size {int} -- padding unit size
dtype {torch.dtype} -- dtype
device {int} -- device = -1 if cpu, device >= 0 if gpu
Returns:
tensor -- padding tensor
"""
pad = torch.zeros(size, dtype=dtype)
if device > -1:
pad = pad.cuda(device=device, non_blocking=True)
return pad
def array2tensor(array, dtype, device):
"""This function transforms numpy array to tensor
Arguments:
array {numpy.array} -- numpy array
dtype {torch.dtype} -- torch dtype
device {int} -- device = -1 if cpu, device >= 0 if gpu
Returns:
tensor -- tensor
"""
tensor = torch.as_tensor(array, dtype=dtype)
if device > -1:
tensor = tensor.cuda(device=device, non_blocking=True)
return tensor
def gelu(x):
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
Also see https://arxiv.org/abs/1606.08415
refer to: https://github.com/huggingface/pytorch-transformers/blob/master/pytorch_transformers/modeling_bert.py
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
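# Sanity check (a minimal sketch, not part of the original module): the erf
# form above matches PyTorch's exact GELU, while the tanh form quoted in the
# docstring is only an approximation.
def _demo_gelu():
    x = torch.randn(8)
    assert torch.allclose(gelu(x), F.gelu(x), atol=1e-6)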
def pad_vecs(vecs, padding_size, dtype, device):
"""This function pads vectors for batch
Arguments:
vecs {list} -- vector list
padding_size {int} -- padding dims
dtype {torch.dtype} -- dtype
device {int} -- device = -1 if cpu, device >= 0 if gpu
Returns:
tensor -- padded vectors
"""
max_length = max(len(vec) for vec in vecs)
if max_length == 0:
pad_vecs = torch.cat([get_padding_vector((1, padding_size), dtype, device).unsqueeze(0) for _ in vecs], 0)
return pad_vecs
pad_vecs = []
for vec in vecs:
pad_vec = torch.cat(vec + [get_padding_vector((1, padding_size), dtype, device)] * (max_length - len(vec)),
0).unsqueeze(0)
assert pad_vec.size() == (1, max_length, padding_size), "the size of pad vector is not correct"
pad_vecs.append(pad_vec)
return torch.cat(pad_vecs, 0)
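# Usage sketch (hypothetical values, not part of the original module): each
# element of vecs is a list of (1, padding_size) tensors; shorter lists are
# right-padded with zero rows up to the longest list.
def _demo_pad_vecs():
    a = [torch.ones(1, 4), torch.ones(1, 4)]   # span of length 2
    b = [torch.ones(1, 4)]                     # span of length 1, gets one pad row
    padded = pad_vecs([a, b], padding_size=4, dtype=torch.float, device=-1)
    assert padded.size() == (2, 2, 4)
    assert padded[1, 1].sum().item() == 0.0    # the padding row is all zeros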
def get_bilstm_minus(batch_seq_encoder_repr, span_list, seq_lens):
"""This function gets span representation using bilstm minus
Arguments:
        batch_seq_encoder_repr {torch.Tensor} -- batch sequence encoder representation, shape (batch_size, seq_len, hidden_size)
span_list {list} -- span list
seq_lens {list} -- sequence length list
Returns:
tensor -- span representation vector
"""
assert len(batch_seq_encoder_repr) == len(
span_list), "the length of batch seq encoder repr is not equal to span list's length"
assert len(span_list) == len(seq_lens), "the length of span list is not equal to batch seq lens's length"
hidden_size = batch_seq_encoder_repr.size(-1)
span_vecs = []
for seq_encoder_repr, (s, e), seq_len in zip(batch_seq_encoder_repr, span_list, seq_lens):
rnn_output = seq_encoder_repr[:seq_len]
forward_rnn_output, backward_rnn_output = rnn_output.split(hidden_size // 2, 1)
forward_span_vec = get_forward_segment(forward_rnn_output, s, e, get_device_of(forward_rnn_output))
backward_span_vec = get_backward_segment(backward_rnn_output, s, e, get_device_of(backward_rnn_output))
span_vec = torch.cat([forward_span_vec, backward_span_vec], 0).unsqueeze(0)
span_vecs.append(span_vec)
return torch.cat(span_vecs, 0)
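# Usage sketch (hypothetical shapes, not part of the original module): the last
# dimension is split in half into forward/backward states, so it must be even.
def _demo_get_bilstm_minus():
    batch_repr = torch.randn(2, 6, 8)          # (batch, seq_len, hidden)
    span_repr = get_bilstm_minus(batch_repr, [(1, 4), (0, 2)], [6, 5])
    assert span_repr.size() == (2, 8)          # one fused span vector per example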
def get_forward_segment(forward_rnn_output, s, e, device):
"""This function gets span representaion in forward rnn
Arguments:
forward_rnn_output {tensor} -- forward rnn output
s {int} -- span start
e {int} -- span end
device {int} -- device
Returns:
        tensor -- span representation vector
"""
seq_len, hidden_size = forward_rnn_output.size()
if s >= e:
vec = torch.zeros(hidden_size, dtype=forward_rnn_output.dtype)
if device > -1:
vec = vec.cuda(device=device, non_blocking=True)
return vec
if s == 0:
return forward_rnn_output[e - 1]
return forward_rnn_output[e - 1] - forward_rnn_output[s - 1]
def get_backward_segment(backward_rnn_output, s, e, device):
"""This function gets span representaion in backward rnn
Arguments:
        backward_rnn_output {tensor} -- backward rnn output
s {int} -- span start
e {int} -- span end
device {int} -- device
Returns:
        tensor -- span representation vector
"""
seq_len, hidden_size = backward_rnn_output.size()
if s >= e:
vec = torch.zeros(hidden_size, dtype=backward_rnn_output.dtype)
if device > -1:
vec = vec.cuda(device=device, non_blocking=True)
return vec
if e == seq_len:
return backward_rnn_output[s]
return backward_rnn_output[s] - backward_rnn_output[e]
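# Sketch of the "LSTM-minus" idea behind the two helpers above (toy values,
# not part of the original module): if each forward state were the running sum
# of the inputs, subtracting two states would recover the sum over the span.
def _demo_segment_minus():
    inputs = torch.ones(5, 2)                        # (seq_len, hidden)
    forward = inputs.cumsum(0)                       # stand-in for forward RNN states
    span = get_forward_segment(forward, 1, 4, -1)
    assert torch.equal(span, inputs[1:4].sum(0))     # h_3 - h_0 == sum over [1, 4)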
def get_dist_vecs(span_list, max_sent_len, device):
"""This function gets distance embedding
Arguments:
span_list {list} -- span list
Returns:
tensor -- distance embedding vector
"""
dist_vecs = []
for s, e in span_list:
assert s <= e, "span start is greater than end"
vec = torch.Tensor(np.eye(max_sent_len)[e - s])
if device > -1:
vec = vec.cuda(device=device, non_blocking=True)
dist_vecs.append(vec)
return torch.stack(dist_vecs)
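# Usage sketch (hypothetical values, not part of the original module): each
# span (s, e) is encoded as a one-hot vector at index e - s (the span width).
def _demo_get_dist_vecs():
    vecs = get_dist_vecs([(0, 3), (2, 2)], max_sent_len=5, device=-1)
    assert vecs.size() == (2, 5)
    assert vecs[0].argmax().item() == 3   # width 3 -> one-hot at index 3
    assert vecs[1].argmax().item() == 0   # empty span -> one-hot at index 0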
def get_conv_vecs(batch_token_repr, span_list, span_batch_size, conv_layer):
"""This funciton gets span vector representation through convolution layer
Arguments:
batch_token_repr {list} -- batch token representation
span_list {list} -- span list
        span_batch_size {int} -- span convolution batch size
conv_layer {nn.Module} -- convolution layer
Returns:
tensor -- conv vectors
"""
assert len(batch_token_repr) == len(span_list), "the length of batch token repr is not equal to span list's length"
span_vecs = []
for token_repr, (s, e) in zip(batch_token_repr, span_list):
if s == e:
span_vecs.append([])
continue
span_vecs.append(list(token_repr[s:e].split(1)))
span_conv_vecs = []
    for start in range(0, len(span_vecs), span_batch_size):
        span_pad_vecs = pad_vecs(span_vecs[start:start + span_batch_size], conv_layer.get_input_dims(),
                                 batch_token_repr[0].dtype, get_device_of(batch_token_repr[0]))
span_conv_vecs.append(conv_layer(span_pad_vecs))
return torch.cat(span_conv_vecs, dim=0)
def get_n_trainable_parameters(model):
"""This function calculates the number of trainable parameters
of the model
Arguments:
model {nn.Module} -- model
Returns:
int -- the number of trainable parameters of the model
"""
cnt = 0
for param in list(model.parameters()):
if param.requires_grad:
cnt += functools.reduce(lambda x, y: x * y, list(param.size()), 1)
return cnt
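# Quick check (a minimal sketch, not part of the original module):
def _demo_get_n_trainable_parameters():
    layer = torch.nn.Linear(4, 3)                      # 12 weights + 3 biases
    assert get_n_trainable_parameters(layer) == 15
    layer.bias.requires_grad = False                   # frozen parameters are excluded
    assert get_n_trainable_parameters(layer) == 12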
def js_div(p, q, reduction='batchmean'):
"""js_div caculate Jensen Shannon Divergence (JSD).
Args:
p (tensor): distribution p
q (tensor): distribution q
reduction (str, optional): reduction. Defaults to 'batchmean'.
Returns:
tensor: JS divergence
"""
m = 0.5 * (p + q)
return (F.kl_div(p, m, reduction=reduction) + F.kl_div(q, m, reduction=reduction)) * 0.5
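# Sanity check (a minimal sketch, not part of the original module): the JSD of
# a distribution with itself is zero, and it is positive otherwise.
def _demo_js_div():
    p = torch.tensor([[0.25, 0.75]])
    q = torch.tensor([[0.75, 0.25]])
    assert js_div(p, p).item() == 0.0
    assert js_div(p, q).item() > 0.0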
def load_weight_from_pretrained_model(model, pretrained_state_dict, prefix=""):
"""load_weight_from_pretrained_model This function loads weight from pretrained model.
Arguments:
model {nn.Module} -- model
pretrained_state_dict {dict} -- state dict of pretrained model
Keyword Arguments:
prefix {str} -- prefix for pretrained model (default: {""})
"""
model_state_dict = model.state_dict()
filtered_state_dict = {}
for k, v in model_state_dict.items():
if 'decoder' in k:
continue
k = k.split('.')
for candi_name in ['.'.join(k), '.'.join(k[1:]), '.'.join(k[2:])]:
if candi_name in pretrained_state_dict and v.size() == pretrained_state_dict[candi_name].size():
filtered_state_dict['.'.join(k)] = pretrained_state_dict[candi_name]
break
candi_name = prefix + candi_name
if candi_name in pretrained_state_dict and v.size() == pretrained_state_dict[candi_name].size():
filtered_state_dict['.'.join(k)] = pretrained_state_dict[candi_name]
break
logger.info("Load weights parameters:")
for name in filtered_state_dict:
logger.info(name)
model_state_dict.update(filtered_state_dict)
model.load_state_dict(model_state_dict)
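# Usage sketch (hypothetical modules, not part of the original file): copy the
# matching weights of a pretrained layer into a wrapped model; the candidate
# names let "0.weight" in the wrapper match "weight" in the checkpoint.
def _demo_load_weight_from_pretrained_model():
    pretrained = torch.nn.Linear(4, 4)
    model = torch.nn.Sequential(torch.nn.Linear(4, 4))
    load_weight_from_pretrained_model(model, pretrained.state_dict())
    assert torch.equal(model[0].weight, pretrained.weight)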
def clone_weights(first_module, second_module):
"""This function clones(ties) weights from first module to second module
refers to: https://huggingface.co./transformers/v1.2.0/_modules/pytorch_transformers/modeling_utils.html#PreTrainedModel
Arguments:
first_module {nn.Module} -- first module
second_module {nn.Module} -- second module
"""
first_module.weight = second_module.weight
if hasattr(first_module, 'bias') and first_module.bias is not None:
first_module.bias.data = torch.nn.functional.pad(first_module.bias.data,
(0, first_module.weight.shape[0] - first_module.bias.shape[0]),
'constant', 0)
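# Usage sketch (hypothetical modules, not part of the original file): tie a
# decoder projection to an embedding matrix, as in weight-tied language models.
def _demo_clone_weights():
    embedding = torch.nn.Embedding(10, 4)
    decoder = torch.nn.Linear(4, 10, bias=False)
    clone_weights(decoder, embedding)
    assert decoder.weight is embedding.weight   # the parameters are now shared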