import functools
import logging
import math

import numpy as np
import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)


def get_device_of(tensor):
    """This function returns the device of the tensor
    refer to https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py

    Arguments:
        tensor {tensor} -- tensor

    Returns:
        int -- device
    """

    if not tensor.is_cuda:
        return -1
    else:
        return tensor.get_device()


def get_range_vector(size, device):
    """This function returns a range vector with the desired size, starting at 0
    the CUDA implementation is meant to avoid copy data from CPU to GPU
    refer to https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py

    Arguments:
        size {int} -- the size of range
        device {int} -- device

    Returns:
        torch.Tensor -- range vector
    """

    if device > -1:
        return torch.cuda.LongTensor(size, device=device).fill_(1).cumsum(0) - 1
    else:
        return torch.arange(0, size, dtype=torch.long)


def flatten_and_batch_shift_indices(indices, sequence_length):
    """This function returns a vector that correctly indexes into the flattened target,
    the sequence length of the target must be provided to compute the appropriate offsets.
    refer to https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py

    Arguments:
        indices {tensor} -- index tensor
        sequence_length {int} -- sequence length

    Returns:
        tensor -- offset index tensor
    """

    if torch.max(indices) >= sequence_length or torch.min(indices) < 0:
        raise RuntimeError("All elements in indices should be in the range [0, {}]".format(sequence_length - 1))
    # Shape: (batch_size,)
    offsets = get_range_vector(indices.size(0), get_device_of(indices)) * sequence_length
    for _ in range(len(indices.size()) - 1):
        offsets = offsets.unsqueeze(1)

    # Shape: (batch_size, d_1, ..., d_n)
    offset_indices = indices + offsets

    # Shape: (batch_size * d_1 * ... * d_n)
    offset_indices = offset_indices.view(-1)
    return offset_indices

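# A worked sketch of the offset logic above (values are illustrative): with
# sequence_length=3, batch item 0 gets offset 0 and batch item 1 gets offset
# 3, so per-batch indices become positions in the flattened (batch * seq) axis.
#
#   >>> indices = torch.tensor([[1, 2], [0, 1]])
#   >>> flatten_and_batch_shift_indices(indices, sequence_length=3)
#   tensor([1, 2, 3, 4])
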

def batched_index_select(target, indices, flattened_indices=None):
    """This function returns selected values in the target with respect to the provided indices,
    which have size ``(batch_size, d_1, ..., d_n, embedding_size)``
    refer to https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py

    Arguments:
        target {torch.Tensor} -- target tensor
        indices {torch.LongTensor} -- index tensor

    Keyword Arguments:
        flattened_indices {Optional[torch.LongTensor]} -- flattened index tensor (default: {None})

    Returns:
        torch.Tensor -- selected tensor
    """

    if flattened_indices is None:
        # Shape: (batch_size * d_1 * ... * d_n)
        flattened_indices = flatten_and_batch_shift_indices(indices, target.size(1))

    # Shape: (batch_size * sequence_length, embedding_size)
    flattened_target = target.view(-1, target.size(-1))

    # Shape: (batch_size * d_1 * ... * d_n, embedding_size)
    flattened_selected = flattened_target.index_select(0, flattened_indices)
    selected_shape = list(indices.size()) + [target.size(-1)]
    # Shape: (batch_size, d_1, ..., d_n, embedding_size)
    selected_targets = flattened_selected.view(*selected_shape)
    return selected_targets

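# A minimal usage sketch (shapes are illustrative): gather one embedding per
# span endpoint from a batch of encoder outputs.
#
#   >>> target = torch.randn(2, 5, 8)            # (batch, seq_len, emb)
#   >>> indices = torch.tensor([[0, 4], [2, 3]])
#   >>> batched_index_select(target, indices).size()
#   torch.Size([2, 2, 8])
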

def get_padding_vector(size, dtype, device):
    """This function initializes padding unit

    Arguments:
        size {int} -- padding unit size
        dtype {torch.dtype} -- dtype
        device {int} -- device = -1 if cpu, device >= 0 if gpu

    Returns:
        tensor -- padding tensor
    """

    pad = torch.zeros(size, dtype=dtype)
    if device > -1:
        pad = pad.cuda(device=device, non_blocking=True)
    return pad


def array2tensor(array, dtype, device):
    """This function transforms numpy array to tensor

    Arguments:
        array {numpy.array} -- numpy array
        dtype {torch.dtype} -- torch dtype
        device {int} -- device = -1 if cpu, device >= 0 if gpu

    Returns:
        tensor -- tensor
    """
    tensor = torch.as_tensor(array, dtype=dtype)
    if device > -1:
        tensor = tensor.cuda(device=device, non_blocking=True)
    return tensor


def gelu(x):
    """Implementation of the gelu activation function.
        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
        Also see https://arxiv.org/abs/1606.08415
        refer to: https://github.com/huggingface/pytorch-transformers/blob/master/pytorch_transformers/modeling_bert.py
    """

    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

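# Example values: gelu(x) = x * Phi(x), where Phi is the standard normal CDF,
# so gelu(torch.tensor([0.0, 1.0])) is approximately tensor([0.0000, 0.8413]).
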

def pad_vecs(vecs, padding_size, dtype, device):
    """This function pads lists of span vectors to a common length for batching

    Arguments:
        vecs {list} -- list of lists of (1, padding_size) tensors
        padding_size {int} -- padding dims
        dtype {torch.dtype} -- dtype
        device {int} -- device = -1 if cpu, device >= 0 if gpu

    Returns:
        tensor -- padded vectors of shape (batch_size, max_length, padding_size)
    """
    max_length = max(len(vec) for vec in vecs)

    if max_length == 0:
        # all spans are empty: return a single zero padding step per item
        padded_vecs = torch.cat([get_padding_vector((1, padding_size), dtype, device).unsqueeze(0) for _ in vecs], 0)
        return padded_vecs

    padded_vecs = []
    for vec in vecs:
        padded_vec = torch.cat(vec + [get_padding_vector((1, padding_size), dtype, device)] * (max_length - len(vec)),
                               0).unsqueeze(0)

        assert padded_vec.size() == (1, max_length, padding_size), "the size of pad vector is not correct"

        padded_vecs.append(padded_vec)
    return torch.cat(padded_vecs, 0)

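# A minimal sketch (sizes are illustrative): two spans of lengths 2 and 0 with
# hidden size 4; the empty span is padded entirely with zero vectors.
#
#   >>> vecs = [[torch.randn(1, 4), torch.randn(1, 4)], []]
#   >>> pad_vecs(vecs, padding_size=4, dtype=torch.float, device=-1).size()
#   torch.Size([2, 2, 4])
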

def get_bilstm_minus(batch_seq_encoder_repr, span_list, seq_lens):
    """This function gets span representation using bilstm minus

    Arguments:
        batch_seq_encoder_repr {list} -- batch sequence encoder representation
        span_list {list} -- span list
        seq_lens {list} -- sequence length list

    Returns:
        tensor -- span representation vector
    """

    assert len(batch_seq_encoder_repr) == len(
        span_list), "the length of batch_seq_encoder_repr is not equal to the length of span_list"

    assert len(span_list) == len(seq_lens), "the length of span_list is not equal to the length of seq_lens"

    hidden_size = batch_seq_encoder_repr.size(-1)
    span_vecs = []
    for seq_encoder_repr, (s, e), seq_len in zip(batch_seq_encoder_repr, span_list, seq_lens):
        rnn_output = seq_encoder_repr[:seq_len]
        forward_rnn_output, backward_rnn_output = rnn_output.split(hidden_size // 2, 1)
        forward_span_vec = get_forward_segment(forward_rnn_output, s, e, get_device_of(forward_rnn_output))
        backward_span_vec = get_backward_segment(backward_rnn_output, s, e, get_device_of(backward_rnn_output))
        span_vec = torch.cat([forward_span_vec, backward_span_vec], 0).unsqueeze(0)
        span_vecs.append(span_vec)
    return torch.cat(span_vecs, 0)

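# A sketch of the BiLSTM-minus idea (shapes are illustrative): for a span
# (s, e), the forward half is h_fwd[e - 1] - h_fwd[s - 1] and the backward
# half is h_bwd[s] - h_bwd[e], so each span vector keeps the full hidden size.
#
#   >>> encoder_repr = torch.randn(2, 6, 8)  # (batch, seq_len, hidden)
#   >>> get_bilstm_minus(encoder_repr, [(1, 3), (0, 2)], [6, 5]).size()
#   torch.Size([2, 8])
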

def get_forward_segment(forward_rnn_output, s, e, device):
    """This function gets span representaion in forward rnn

    Arguments:
        forward_rnn_output {tensor} -- forward rnn output
        s {int} -- span start
        e {int} -- span end
        device {int} -- device

    Returns:
        tensor -- span representaion vector
    """

    seq_len, hidden_size = forward_rnn_output.size()
    if s >= e:
        vec = torch.zeros(hidden_size, dtype=forward_rnn_output.dtype)

        if device > -1:
            vec = vec.cuda(device=device, non_blocking=True)
        return vec

    if s == 0:
        return forward_rnn_output[e - 1]
    return forward_rnn_output[e - 1] - forward_rnn_output[s - 1]


def get_backward_segment(backward_rnn_output, s, e, device):
    """This function gets span representaion in backward rnn

    Arguments:
        forward_rnn_output {tensor} -- backward rnn output
        s {int} -- span start
        e {int} -- span end
        device {int} -- device

    Returns:
        tensor -- span representaion vector
    """

    seq_len, hidden_size = backward_rnn_output.size()
    if s >= e:
        vec = torch.zeros(hidden_size, dtype=backward_rnn_output.dtype)

        if device > -1:
            vec = vec.cuda(device=device, non_blocking=True)
        return vec

    if e == seq_len:
        return backward_rnn_output[s]
    return backward_rnn_output[s] - backward_rnn_output[e]


def get_dist_vecs(span_list, max_sent_len, device):
    """This function gets distance embedding

    Arguments:
        span_list {list} -- span list

    Returns:
        tensor -- distance embedding vector
    """

    dist_vecs = []
    for s, e in span_list:
        assert s <= e, "span start is greater than end"

        vec = torch.Tensor(np.eye(max_sent_len)[e - s])
        if device > -1:
            vec = vec.cuda(device=device, non_blocking=True)

        dist_vecs.append(vec)

    return torch.stack(dist_vecs)

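# Example: spans (0, 2) and (1, 1) with max_sent_len=4 give the one-hot width
# vectors [0, 0, 1, 0] and [1, 0, 0, 0], stacked into a (2, 4) tensor.
#
#   >>> get_dist_vecs([(0, 2), (1, 1)], max_sent_len=4, device=-1)
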

def get_conv_vecs(batch_token_repr, span_list, span_batch_size, conv_layer):
    """This funciton gets span vector representation through convolution layer

    Arguments:
        batch_token_repr {list} -- batch token representation
        span_list {list} -- span list
        span_batch_size {int} -- span convolutuion batch size
        conv_layer {nn.Module} -- convolution layer

    Returns:
        tensor -- conv vectors
    """

    assert len(batch_token_repr) == len(span_list), "the length of batch token repr is not equal to span list's length"

    span_vecs = []
    for token_repr, (s, e) in zip(batch_token_repr, span_list):
        if s == e:
            span_vecs.append([])
            continue

        span_vecs.append(list(token_repr[s:e].split(1)))

    span_conv_vecs = []
    for i in range(0, len(span_vecs), span_batch_size):
        span_pad_vecs = pad_vecs(span_vecs[i:i + span_batch_size], conv_layer.get_input_dims(),
                                 batch_token_repr[0].dtype, get_device_of(batch_token_repr[0]))
        span_conv_vecs.append(conv_layer(span_pad_vecs))
    return torch.cat(span_conv_vecs, dim=0)

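# Usage sketch (`conv` is a hypothetical module exposing get_input_dims() and
# mapping (batch, max_span_len, dims) to (batch, out_dims), as assumed above):
#
#   >>> token_repr = [torch.randn(6, 16), torch.randn(5, 16)]
#   >>> get_conv_vecs(token_repr, [(1, 3), (0, 2)], span_batch_size=32, conv_layer=conv)
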

def get_n_trainable_parameters(model):
    """This function calculates the number of trainable parameters
    of the model
    
    Arguments:
        model {nn.Module} -- model
    
    Returns:
        int -- the number of trainable parameters of the model
    """

    cnt = 0
    for param in list(model.parameters()):
        if param.requires_grad:
            cnt += functools.reduce(lambda x, y: x * y, list(param.size()), 1)
    return cnt

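# Example: a bare torch.nn.Linear(3, 2) has 3 * 2 weights plus 2 biases, so
# get_n_trainable_parameters(torch.nn.Linear(3, 2)) returns 8.
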

def js_div(p, q, reduction='batchmean'):
    """js_div caculate Jensen Shannon Divergence (JSD).

    Args:
        p (tensor): distribution p
        q (tensor): distribution q
        reduction (str, optional): reduction. Defaults to 'batchmean'.

    Returns:
        tensor: JS divergence
    """

    m = 0.5 * (p + q)
    return (F.kl_div(p, m, reduction=reduction) + F.kl_div(q, m, reduction=reduction)) * 0.5

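# Example properties (p and q are probability rows): the divergence is
# symmetric, and zero when the distributions coincide.
#
#   >>> p = torch.tensor([[0.5, 0.5]])
#   >>> q = torch.tensor([[0.9, 0.1]])
#   >>> torch.allclose(js_div(p, q), js_div(q, p))
#   True
#   >>> js_div(p, p)
#   tensor(0.)
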

def load_weight_from_pretrained_model(model, pretrained_state_dict, prefix=""):
    """load_weight_from_pretrained_model This function loads weight from pretrained model.
    
    Arguments:
        model {nn.Module} -- model
        pretrained_state_dict {dict} -- state dict of pretrained model

    Keyword Arguments:
        prefix {str} -- prefix for pretrained model (default: {""})
    """

    model_state_dict = model.state_dict()

    # For each model parameter (skipping decoder weights), try progressively
    # shorter key suffixes, each optionally prepended with `prefix`, against
    # the pretrained state dict; copy the first shape-compatible match.
    filtered_state_dict = {}
    for key, v in model_state_dict.items():
        if 'decoder' in key:
            continue
        parts = key.split('.')
        for candi_name in ['.'.join(parts), '.'.join(parts[1:]), '.'.join(parts[2:])]:
            if candi_name in pretrained_state_dict and v.size() == pretrained_state_dict[candi_name].size():
                filtered_state_dict[key] = pretrained_state_dict[candi_name]
                break

            candi_name = prefix + candi_name
            if candi_name in pretrained_state_dict and v.size() == pretrained_state_dict[candi_name].size():
                filtered_state_dict[key] = pretrained_state_dict[candi_name]
                break

    logger.info("Load weights parameters:")
    for name in filtered_state_dict:
        logger.info(name)

    model_state_dict.update(filtered_state_dict)
    model.load_state_dict(model_state_dict)

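# Usage sketch (checkpoint path and prefix are hypothetical): load matching
# encoder weights from a BERT-style checkpoint whose keys carry a "bert." prefix.
#
#   >>> pretrained = torch.load("pretrained.ckpt", map_location="cpu")
#   >>> load_weight_from_pretrained_model(model, pretrained, prefix="bert.")
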

def clone_weights(first_module, second_module):
    """This function clones(ties) weights from first module to second module
    refers to: https://huggingface.co./transformers/v1.2.0/_modules/pytorch_transformers/modeling_utils.html#PreTrainedModel
    
    Arguments:
        first_module {nn.Module} -- first module
        second_module {nn.Module} -- second module
    """

    first_module.weight = second_module.weight

    if hasattr(first_module, 'bias') and first_module.bias is not None:
        first_module.bias.data = torch.nn.functional.pad(first_module.bias.data,
                                                         (0, first_module.weight.shape[0] - first_module.bias.shape[0]),
                                                         'constant', 0)
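

# Usage sketch (module names are hypothetical): tie an output projection to
# the input embedding, as in weight-tied language models.
#
#   >>> clone_weights(model.decoder, model.embeddings.word_embeddings)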