# coding=utf-8
# Copyright 2022 IDEA-CCNL The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch TransfoXLDenoise model. """ | |
import math | |
import torch | |
import torch.utils.checkpoint as checkpoint | |
import torch.nn.functional as F | |
from dataclasses import dataclass | |
from typing import Optional, Tuple | |
from transformers.modeling_utils import ( | |
PreTrainedModel | |
) | |
from transformers.modeling_outputs import ModelOutput | |
from .configuration_transfo_xl_denoise import TransfoXLDenoiseConfig | |
_CHECKPOINT_FOR_DOC = "transformer-xl-1b-base" | |
_CONFIG_FOR_DOC = "TransfoXLDenoiseConfig" | |
_TOKENIZER_FOR_DOC = "TransfoXLDenoiseTokenizer" | |


Transfo_XL_Denoise_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config ([`~TransfoXLDenoiseConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


Transfo_XL_Denoise_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`TransfoXLDenoiseTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in
            `[0, 1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range
            `[0, config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert *input_ids* indices into associated vectors
            than the model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


Transfo_XL_Denoise_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "transformer-xl-1b-base",
]


@dataclass
class TransfoXLDenoiseModelOutput(ModelOutput):
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None


class PositionalEmbedding(torch.nn.Module):
    def __init__(self, hidden_size):
        super(PositionalEmbedding, self).__init__()

        self.hidden_size = hidden_size

        inv_freq = 1 / (10000 ** (torch.arange(0.0, hidden_size, 2.0) / hidden_size))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, pos_seq, bsz=None):
        sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)

        if bsz is not None:
            return pos_emb[None, :, :].expand(bsz, -1, -1)
        else:
            return pos_emb[None, :, :]
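

# Shape sketch for PositionalEmbedding (illustrative helper, not used by the model):
# a descending position sequence of length klen yields a [1, klen, hidden_size]
# sinusoidal table, or [bsz, klen, hidden_size] when `bsz` is given.
def _positional_embedding_shape_example():
    pos_emb_module = PositionalEmbedding(hidden_size=8)
    pos_seq = torch.arange(4 - 1, -1, -1.0)       # [3., 2., 1., 0.]
    emb = pos_emb_module(pos_seq)                 # shape [1, 4, 8]
    emb_batched = pos_emb_module(pos_seq, bsz=2)  # shape [2, 4, 8]
    return emb.shape, emb_batched.shape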


def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, '{} is not divisible by {}'.format(
        numerator, denominator)


def divide(numerator, denominator):
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator


def scaled_init_method(sigma, num_layers):
    """Init method based on N(0, sigma/sqrt(2*num_layers))."""
    std = sigma / math.sqrt(2.0 * num_layers)

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_


def unscaled_init_method(sigma):
    """Init method based on N(0, sigma)."""

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)

    return init_


def gelu_impl(x):
    """OpenAI's gelu implementation."""
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x
                                       * (1.0 + 0.044715 * x * x)))


def gelu(x):
    return gelu_impl(x)
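

# `gelu_impl` is the tanh approximation of GELU:
# 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))).
# Illustrative check (not used by the model); recent PyTorch versions expose the
# same approximation as F.gelu(x, approximate="tanh").
def _gelu_matches_tanh_approximation(x: torch.Tensor) -> bool:
    return bool(torch.allclose(gelu(x), F.gelu(x, approximate="tanh"), atol=1e-6))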


class GPT2SelfAttention(torch.nn.Module):
    """Parallel self-attention layer for GPT2.

    The self-attention layer takes input of size [b, s, h], where b is
    the batch size, s is the sequence length, and h is the hidden size,
    and produces output of the same size.

    Arguments:
        hidden_size: total hidden size of the layer (h).
        num_attention_heads: number of attention heads (n). Note that we
                             require n to be divisible by the number of GPUs
                             used to parallelize the model. Also, we
                             require the hidden size to be divisible by n.
        dropout_prob: dropout probability for the attention scores.
        init_method: weight initialization.
        output_layer_init_method: output layer initialization. If None, use
                                  `init_method`.

    We use the following notation:
        h: hidden_size
        n: num_attention_heads
        p: number of partitions
        np: n/p
        hp: h/p
        hn: h/n
        b: batch size
        s: sequence length
    """

    def __init__(self, hidden_size, num_attention_heads,
                 attention_dropout_prob, output_dropout_prob,
                 init_method, output_layer_init_method=None, relative_encoding=False):
        super(GPT2SelfAttention, self).__init__()
        # Set output layer initialization if not provided.
        if output_layer_init_method is None:
            output_layer_init_method = init_method
        # Per attention head and per partition values.
        self.hidden_size_per_partition = hidden_size
        self.hidden_size_per_attention_head = divide(hidden_size,
                                                     num_attention_heads)
        self.num_attention_heads_per_partition = num_attention_heads
        self.relative_encoding = relative_encoding
        # Strided linear layer.
        self.query_key_value = torch.nn.Linear(hidden_size,
                                               3 * hidden_size, bias=True)
        if relative_encoding:
            self.relative = torch.nn.Linear(hidden_size, hidden_size, bias=True)
        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
        # Output.
        self.dense = torch.nn.Linear(hidden_size, hidden_size, bias=True)
        self.output_dropout = torch.nn.Dropout(output_dropout_prob)

    def _transpose_for_scores(self, tensor):
        """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
        size [b, np, s, hn].
        """
        new_tensor_shape = tensor.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head)
        tensor = tensor.view(*new_tensor_shape)
        return tensor.permute(0, 2, 1, 3)

    def _rel_shift(self, x, zero_triu=False):
        # x: [b, n_head, qlen, klen] (the original Transformer-XL layout was qlen x klen x bsz x h)
        zero_pad = torch.zeros((*x.size()[:-2], x.size(-2), 1),
                               device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)
        x_padded = x_padded.view(*x.size()[:-2], x.size(-1) + 1, x.size(-2))
        x = x_padded[:, :, 1:].view_as(x)
        if zero_triu:
            ones = torch.ones((x.size(0), x.size(1)))
            x = x * torch.tril(ones, x.size(1) - x.size(0))[:, :, None, None]
        return x

    def _rel_shift_latest(self, x: torch.Tensor):
        ndims = x.dim()
        x_shape = x.size()
        row_dim = 2
        col_dim = row_dim + 1
        assert col_dim < ndims
        tgt_shape_1, tgt_shape_2 = [], []
        for i in range(ndims):
            if i == row_dim:
                tgt_shape_1.append(x_shape[col_dim])
                tgt_shape_2.append(x_shape[row_dim])
            elif i == col_dim:
                tgt_shape_1.append(x_shape[row_dim])
                tgt_shape_2.append(x_shape[col_dim] - 1)
            else:
                tgt_shape_1.append(x_shape[i])
                tgt_shape_2.append(x_shape[i])
        x = x.view(*tgt_shape_1)
        x = x[:, :, 1:, :]
        x = x.view(*tgt_shape_2)
        return x
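
    # Relative-shift walkthrough (illustrative note, not original code): `_rel_shift`
    # receives the raw scores between each query and the klen relative-position
    # embeddings (ordered from distance klen-1 down to 0), pads a zero column,
    # reinterprets the last two dimensions with their roles swapped, and drops the
    # first row before viewing back to the original shape. The net effect is that
    # entry (q, k) of the returned tensor holds the score for the relative distance
    # between query position q and key position k, which the forward pass below
    # adds to the content-based attention scores.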

    def forward(self, hidden_states, ltor_mask, position_embeddings=None, r_w_bias=None, r_r_bias=None, mem=None):
        # hidden_states: [b, s, h]
        # ltor_mask: [1, 1, s, s]

        # Attention heads. [b, s, hp]
        query_length = hidden_states.size(1)

        if mem is None:
            mixed_x_layer = self.query_key_value(hidden_states)
            (mixed_query_layer,
             mixed_key_layer,
             mixed_value_layer) = torch.chunk(mixed_x_layer, 3, dim=-1)
        else:
            cat = torch.cat((mem, hidden_states), 1)
            mixed_x_layer = self.query_key_value(cat)
            (mixed_query_layer,
             mixed_key_layer,
             mixed_value_layer) = torch.chunk(mixed_x_layer, 3, dim=-1)
            mixed_query_layer = mixed_query_layer[:, -query_length:]

        # Reshape and transpose [b, np, s, hn]
        query_layer = self._transpose_for_scores(mixed_query_layer)
        key_layer = self._transpose_for_scores(mixed_key_layer)
        value_layer = self._transpose_for_scores(mixed_value_layer)

        if self.relative_encoding:
            relative_layer = self.relative(position_embeddings)
            relative_layer = self._transpose_for_scores(
                relative_layer)  # 1 (bsz) x n_head x klen x d_head
            # Raw attention scores. [b, np, qs, ks]
            rw_head_q = query_layer + r_w_bias.unsqueeze(1)
            ac_score = torch.matmul(rw_head_q, key_layer.transpose(-1, -2))
            rr_head_q = query_layer + r_r_bias.unsqueeze(1)
            bd_score = torch.matmul(rr_head_q, relative_layer.transpose(-1, -2))
            bd_score = self._rel_shift(bd_score)  # [b, n_head, qlen, klen]
            # bd_score = bd_score.permute(2, 3, 0, 1)  # bsz n_head qlen klen
            attention_scores = ac_score + bd_score
        else:
            # Raw attention scores. [b, np, s, s]
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(
            self.hidden_size_per_attention_head)
        # Apply the left to right attention mask.
        attention_scores = torch.mul(attention_scores, ltor_mask) - \
            10000.0 * (1.0 - ltor_mask)

        # Attention probabilities. [b, np, s, s]
        attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        # with get_cuda_rng_tracker().fork():
        #     attention_probs = self.attention_dropout(attention_probs)

        # Context layer. [b, np, s, hn]
        context_layer = torch.matmul(attention_probs, value_layer)
        # [b, s, np, hn]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        # [b, s, hp]
        context_layer = context_layer.view(*new_context_layer_shape)

        # Output. [b, s, h]
        output = self.dense(context_layer)
        output = self.output_dropout(output)
        return output
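

# Illustrative shape check for GPT2SelfAttention (a hedged sketch, not part of the
# original module): with absolute position encoding the layer maps [b, s, h] plus a
# [1, 1, s, s] left-to-right mask to an output of the same [b, s, h] shape.
def _self_attention_shape_example():
    layer = GPT2SelfAttention(
        hidden_size=16,
        num_attention_heads=4,
        attention_dropout_prob=0.0,
        output_dropout_prob=0.0,
        init_method=unscaled_init_method(0.02),
    )
    hidden = torch.randn(2, 5, 16)
    ltor_mask = torch.tril(torch.ones(1, 1, 5, 5))
    return layer(hidden, ltor_mask).shape  # torch.Size([2, 5, 16])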


class GPT2MLP(torch.nn.Module):
    """MLP for GPT2.

    The MLP takes input with hidden size h, projects it to a 4*h
    hidden dimension, applies a gelu transformation, and projects the
    state back to hidden size h. At the end, dropout is also applied.

    Arguments:
        hidden_size: The hidden size of the self attention.
        output_dropout_prob: dropout probability for the outputs
                             after self attention and final output.
        init_method: initialization method used for the weights. Note
                     that all biases are initialized to zero and
                     layernorm weights are initialized to one.
        output_layer_init_method: output layer initialization. If None,
                                  use `init_method`.
    """

    def __init__(self, hidden_size, output_dropout_prob, init_method,
                 output_layer_init_method=None):
        super(GPT2MLP, self).__init__()
        # Set output layer initialization if not provided.
        if output_layer_init_method is None:
            output_layer_init_method = init_method
        # Project to 4h.
        self.dense_h_to_4h = torch.nn.Linear(hidden_size, 4 * hidden_size)
        # Project back to h.
        self.dense_4h_to_h = torch.nn.Linear(4 * hidden_size, hidden_size)
        self.dropout = torch.nn.Dropout(output_dropout_prob)

    def forward(self, hidden_states):
        # [b, s, 4hp]
        intermediate_parallel = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = gelu(intermediate_parallel)
        # [b, s, h]
        output = self.dense_4h_to_h(intermediate_parallel)
        output = self.dropout(output)
        return output
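

# Illustrative shape check for GPT2MLP (a hedged sketch, not part of the original
# module): the block expands [b, s, h] to [b, s, 4h], applies GELU, and projects
# back to [b, s, h].
def _mlp_shape_example():
    mlp = GPT2MLP(hidden_size=16, output_dropout_prob=0.0,
                  init_method=unscaled_init_method(0.02))
    return mlp(torch.randn(2, 5, 16)).shape  # torch.Size([2, 5, 16])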


class GPT2TransformerLayer(torch.nn.Module):
    """A single layer transformer for GPT2.

    We use the following notation:
        h: hidden size
        n: number of attention heads
        b: batch size
        s: sequence length
    The transformer layer takes input of size [b, s, h] and returns an
    output of the same size.

    Arguments:
        hidden_size: The hidden size of the self attention.
        num_attention_heads: number of attention heads in the self
                             attention.
        attention_dropout_prob: dropout probability of the attention
                                score in self attention.
        output_dropout_prob: dropout probability for the outputs
                             after self attention and final output.
        layernorm_epsilon: epsilon used in layernorm to avoid
                           division by zero.
        init_method: initialization method used for the weights. Note
                     that all biases are initialized to zero and
                     layernorm weights are initialized to one.
        output_layer_init_method: output layers (attention output and
                                  mlp output) initialization. If None,
                                  use `init_method`.
    """

    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 attention_dropout_prob,
                 output_dropout_prob,
                 layernorm_epsilon,
                 init_method,
                 output_layer_init_method=None,
                 relative_encoding=False):
        super(GPT2TransformerLayer, self).__init__()
        # Set output layer initialization if not provided.
        if output_layer_init_method is None:
            output_layer_init_method = init_method
        # Layernorm on the input data.
        self.input_layernorm = torch.nn.LayerNorm(hidden_size, eps=layernorm_epsilon)
        # Self attention.
        self.attention = GPT2SelfAttention(
            hidden_size,
            num_attention_heads,
            attention_dropout_prob,
            output_dropout_prob,
            init_method,
            output_layer_init_method=output_layer_init_method,
            relative_encoding=relative_encoding)
        # Layernorm on the attention output.
        self.post_attention_layernorm = torch.nn.LayerNorm(hidden_size,
                                                           eps=layernorm_epsilon)
        # MLP
        self.mlp = GPT2MLP(
            hidden_size,
            output_dropout_prob,
            init_method,
            output_layer_init_method=output_layer_init_method)

    def forward(self, hidden_states, ltor_mask, position_embeddings=None, r_w_bias=None, r_r_bias=None, mem=None):
        # hidden_states: [b, s, h]
        # ltor_mask: [1, 1, s, s]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        mem = self.input_layernorm(mem) if mem is not None else None
        # Self attention.
        attention_output = self.attention(
            layernorm_output, ltor_mask, position_embeddings, r_w_bias, r_r_bias, mem)
        # Residual connection.
        # print(f'hz {hidden_states.shape}, attn {attention_output.shape}')
        layernorm_input = hidden_states + attention_output
        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)
        # MLP.
        mlp_output = self.mlp(layernorm_output)
        # Second residual connection.
        output = layernorm_input + mlp_output
        return output


class GPT2Transformer(torch.nn.Module):
    """GPT-2 transformer.

    This module takes input from the embedding layer and its output can
    be used directly by a logit layer. It consists of L (num-layers)
    blocks of:
        layer norm
        self attention
        residual connection
        layer norm
        mlp
        residual connection
    followed by a final layer norm.

    Arguments:
        num_layers: Number of transformer layers.
        hidden_size: The hidden size of the self attention.
        num_attention_heads: number of attention heads in the self
                             attention.
        attention_dropout_prob: dropout probability of the attention
                                score in self attention.
        output_dropout_prob: dropout probability for the outputs
                             after self attention and final output.
        checkpoint_activations: if True, checkpoint activations.
        checkpoint_num_layers: number of layers to checkpoint. This
                               is basically the chunk size in checkpointing.
        layernorm_epsilon: epsilon used in layernorm to avoid
                           division by zero.
        init_method_std: standard deviation of the init method which has
                         the form N(0, std).
        use_scaled_init_for_output_weights: If True, use 1/sqrt(2*num_layers)
                                            scaling for the output weights
                                            (output of self attention and mlp).
    """

    def __init__(self,
                 num_layers,
                 hidden_size,
                 num_attention_heads,
                 max_sequence_length,
                 max_memory_length,
                 embedding_dropout_prob,
                 attention_dropout_prob,
                 output_dropout_prob,
                 checkpoint_activations,
                 checkpoint_num_layers=1,
                 layernorm_epsilon=1.0e-5,
                 init_method_std=0.02,
                 use_scaled_init_for_output_weights=True,
                 relative_encoding=False):
        super(GPT2Transformer, self).__init__()
        # Store activation checkpointing flag.
        self.checkpoint_activations = checkpoint_activations
        self.checkpoint_num_layers = checkpoint_num_layers
        self.max_memory_length = max_memory_length

        output_layer_init_method = None
        if use_scaled_init_for_output_weights:
            output_layer_init_method = scaled_init_method(init_method_std,
                                                          num_layers)
        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
        self.relative_encoding = relative_encoding
        if relative_encoding:
            # Relative position embedding
            self.position_embeddings = PositionalEmbedding(hidden_size)
            # Per attention head and per partition values.
            self.hidden_size_per_attention_head = divide(hidden_size,
                                                         num_attention_heads)
            self.num_attention_heads_per_partition = num_attention_heads
            self.r_w_bias = torch.nn.Parameter(
                torch.Tensor(self.num_attention_heads_per_partition, self.hidden_size_per_attention_head))
            self.r_r_bias = torch.nn.Parameter(
                torch.Tensor(self.num_attention_heads_per_partition, self.hidden_size_per_attention_head))
            # Always initialize bias to zero.
            with torch.no_grad():
                self.r_w_bias.zero_()
                self.r_r_bias.zero_()
        else:
            # Position embedding (serial).
            self.position_embeddings = torch.nn.Embedding(max_sequence_length,
                                                          hidden_size)
            # Initialize the position embeddings.
            torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)

        def get_layer():
            return GPT2TransformerLayer(
                hidden_size,
                num_attention_heads,
                attention_dropout_prob,
                output_dropout_prob,
                layernorm_epsilon,
                unscaled_init_method(init_method_std),
                output_layer_init_method=output_layer_init_method,
                relative_encoding=relative_encoding)

        # Transformer layers.
        self.layers = torch.nn.ModuleList(
            [get_layer() for _ in range(num_layers)])

        # Final layer norm before output.
        self.final_layernorm = torch.nn.LayerNorm(hidden_size, eps=layernorm_epsilon)

    def forward(self, hidden_states, position_ids, attention_mask, *mems):
        batch_size, query_length = hidden_states.size()[:2]
        memory_length = mems[0].size(1) if mems else 0
        key_length = query_length + memory_length
        attention_mask = attention_mask[:, :, :, -query_length - memory_length:]
        if self.relative_encoding:
            # why drop twice here
            # hidden_states = self.embedding_dropout(hidden_states)
            position_sequence = torch.arange(key_length - 1, -1, -1.0, device=hidden_states.device,
                                             dtype=hidden_states.dtype)
            position_embeddings = self.position_embeddings(position_sequence)
            # Apply dropout
            position_embeddings = self.embedding_dropout(position_embeddings)
            hidden_states = self.embedding_dropout(hidden_states)
        else:
            position_embeddings = self.position_embeddings(position_ids)
            hidden_states = hidden_states + position_embeddings
            hidden_states = self.embedding_dropout(hidden_states)

        if self.max_memory_length > 0:
            mem_layers = [hidden_states.detach()]
        else:
            mem_layers = []

        def custom(start, end):
            def custom_forward(*inputs):
                layers_ = self.layers[start:end]
                x_, inputs = inputs[0], inputs[1:]
                if self.relative_encoding:
                    inputs, mems_ = inputs[:4], inputs[4:]
                else:
                    inputs, mems_ = inputs[:1], inputs[1:]
                for i, layer in enumerate(layers_):
                    mem_i_ = mems_[i] if mems_ else None
                    x_ = layer(x_, *inputs, mem=mem_i_)
                    if self.max_memory_length > 0:
                        mem_layers.append(x_.detach())
                return x_
            return custom_forward

        if self.checkpoint_activations:
            la = 0
            num_layers = len(self.layers)
            chunk_length = self.checkpoint_num_layers
            while la < num_layers:
                args = [hidden_states, attention_mask]
                if self.relative_encoding:
                    args += [position_embeddings, self.r_w_bias, self.r_r_bias]
                if mems:
                    args += mems[la: la + chunk_length]
                hidden_states = checkpoint(custom(la, la + chunk_length), *args)
                la += chunk_length
        else:
            for i, layer in enumerate(self.layers):
                args = [hidden_states, attention_mask]
                if self.relative_encoding:
                    args += [position_embeddings, self.r_w_bias, self.r_r_bias]
                mem_i = mems[i] if mems else None
                hidden_states = layer(*args, mem=mem_i)
                if self.max_memory_length > 0:
                    mem_layers.append(hidden_states.detach())

        # Final layer norm.
        output = self.final_layernorm(hidden_states)
        if self.max_memory_length > 0:
            mem_layers = self.update_mems(mem_layers, mems)

        return (output, *mem_layers)

    def update_mems(self, hiddens, mems):
        memory_length = mems[0].size(1) if mems else 0
        query_length = hiddens[0].size(1)
        new_memory_length = min(self.max_memory_length, memory_length + query_length)
        new_mems = []
        with torch.no_grad():
            for i in range(len(hiddens)):
                if new_memory_length <= query_length:
                    new_mems.append(hiddens[i][:, -new_memory_length:])
                else:
                    new_mems.append(
                        torch.cat(
                            (mems[i][:, -new_memory_length + query_length:], hiddens[i]), dim=1))
        return new_mems
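

# Memory-update arithmetic (illustrative note): with max_memory_length = 8, an
# existing per-layer memory of length 5 and a new segment of length 4,
# `update_mems` keeps min(8, 5 + 4) = 8 positions per layer, namely the last 4
# tokens of the old memory concatenated with the 4 freshly computed hidden states.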


class TransfoXLDenoisePreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    config_class = TransfoXLDenoiseConfig
    base_model_prefix = "transfo_xl_denoise"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights."""
        pass  # no-op to bypass the base class's NotImplementedError; weights come from the checkpoint


class TransfoXLDenoiseModel(TransfoXLDenoisePreTrainedModel):
    """GPT-2 style language model.

    The forward method returns the logits (parallel or serial,
    depending on the `parallel_output` flag).
    """

    def __init__(self, config: TransfoXLDenoiseConfig):
        super().__init__(config)
        self.config = config
        # Word embeddings (parallel).
        self.word_embeddings = torch.nn.Embedding(config.vocab_size, config.hidden_size)
        # Transformer
        self.transformer = GPT2Transformer(config.num_layers,
                                           config.hidden_size,
                                           config.num_attention_heads,
                                           config.max_sequence_length,
                                           config.max_memory_length,
                                           config.embedding_dropout_prob,
                                           config.attention_dropout_prob,
                                           config.output_dropout_prob,
                                           config.checkpoint_activations,
                                           config.checkpoint_num_layers,
                                           relative_encoding=config.relative_encoding)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        hidden_states=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **unused,
    ):
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            if the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask
            is used in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4
            tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up
            decoding. If `past_key_values` are used, the user can optionally input only the last
            `decoder_input_ids` (those that don't have their past key value states given to this model) of shape
            `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up
            decoding (see `past_key_values`).
        """
        # Embeddings.
        # one-hot batch_size * seq_len * vocab_size, can use gradient
        # if input_ids.shape[-1] == self.word_embeddings.weight.shape[0]:
        #     words_embeddings = torch.einsum("ijk,kl->ijl", input_ids, self.word_embeddings.weight)
        # else:
        #     print(f'input_ids {input_ids.device}, word_embedding {self.word_embeddings.weight.device}')
        #     words_embeddings = self.word_embeddings(input_ids)
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        assert input_ids is not None and attention_mask is not None and position_ids is not None, \
            "You have to specify input_ids, attention_mask, and position_ids. Check tokenizer.encode_plus for details"
        if not hidden_states:
            hidden_states = []

        embeddings = self.word_embeddings(input_ids)
        # Transformer.
        transformer_output = self.transformer(
            embeddings, position_ids, attention_mask, *hidden_states)
        logits, *hidden_states = transformer_output
        logits = F.linear(logits, self.word_embeddings.weight)
        if not return_dict:
            return logits, hidden_states
        return TransfoXLDenoiseModelOutput(
            logits=logits,
            hidden_states=hidden_states
        )
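

# Minimal usage sketch (illustrative, not part of the original module): build a tiny
# randomly initialized model from a config and run one forward pass. The keyword
# names mirror the attributes read in `TransfoXLDenoiseModel.__init__`; whether
# `TransfoXLDenoiseConfig` accepts them all as constructor kwargs is an assumption,
# and real checkpoints (e.g. "transformer-xl-1b-base") should instead be loaded with
# `TransfoXLDenoiseModel.from_pretrained`.
def _tiny_forward_example():
    config = TransfoXLDenoiseConfig(
        vocab_size=128,
        hidden_size=32,
        num_layers=2,
        num_attention_heads=4,
        max_sequence_length=64,
        max_memory_length=0,
        embedding_dropout_prob=0.0,
        attention_dropout_prob=0.0,
        output_dropout_prob=0.0,
        checkpoint_activations=False,
        checkpoint_num_layers=1,
        relative_encoding=True,
    )
    model = TransfoXLDenoiseModel(config)
    seq_len = 8
    input_ids = torch.randint(0, config.vocab_size, (1, seq_len))
    attention_mask = torch.tril(torch.ones(1, 1, seq_len, seq_len))
    position_ids = torch.arange(seq_len).unsqueeze(0)
    outputs = model(input_ids=input_ids, attention_mask=attention_mask,
                    position_ids=position_ids, return_dict=True)
    return outputs.logits.shape  # expected: torch.Size([1, 8, 128])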