|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import inspect |
|
from typing import Callable, List, Optional, Set, Tuple, Union |
|
|
|
import torch |
|
from packaging import version |
|
from torch import nn |
|
|
|
from .utils import logging |
|
|
|
|
|
# Layer-normalization module classes known to this module (currently only `nn.LayerNorm`).
ALL_LAYERNORM_LAYERS = [nn.LayerNorm]

logger = logging.get_logger(__name__)

# Installed torch version reduced to its release part (local/dev/pre-release suffixes such as "+cu117" are
# dropped), so the feature flags below compare plain release numbers.
parsed_torch_version_base = version.parse(
    version.parse(torch.__version__).base_version
)

# Version feature flags used to pick between torch API variants.
is_torch_greater_or_equal_than_1_10 = version.parse("1.10") <= parsed_torch_version_base
is_torch_less_than_1_11 = version.parse("1.11") > parsed_torch_version_base
|
|
|
|
|
def softmax_backward_data(parent, grad_output, output, dim, self):
    """
    Call PyTorch's internal `_softmax_backward_data`, adapting its last argument to the installed torch version.

    Args:
        parent: Object whose `.dim` attribute is forwarded as the softmax dimension (presumably the autograd
            context of a custom softmax `Function` -- confirm against callers).
        grad_output: Forwarded unchanged as the first argument.
        output: Forwarded unchanged as the second argument.
        dim: Not read by this function; `parent.dim` is used instead. NOTE(review): kept only for signature
            compatibility with existing callers.
        self: Passed as-is on torch<1.11; on newer torch only its `.dtype` is passed.

    Returns:
        Whatever `torch._softmax_backward_data` returns.
    """
    from torch import _softmax_backward_data

    reduce_dim = parent.dim
    # torch<1.11 expects a tensor as the final argument; later releases expect a dtype instead.
    last_arg = self if is_torch_less_than_1_11 else self.dtype
    return _softmax_backward_data(grad_output, output, reduce_dim, last_arg)
|
|
|
|
|
def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
    """
    Prune a linear layer to keep only entries in index.

    Used to remove heads.

    Args:
        layer (`torch.nn.Linear`): The layer to prune.
        index (`torch.LongTensor`): The indices to keep in the layer.
        dim (`int`, *optional*, defaults to 0): The dimension of `layer.weight` on which to keep the indices.
            With `dim=0` (output features) the bias is pruned with the same indices; with `dim=1` (input
            features) the bias is copied unchanged.

    Returns:
        `torch.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`, placed on the same device
        and using the same dtype as `layer.weight`.
    """
    index = index.to(layer.weight.device)
    W = layer.weight.index_select(dim, index).detach().clone()
    if layer.bias is not None:
        # The bias has one entry per output feature, so it only shrinks when rows (dim 0) are pruned.
        if dim == 1:
            b = layer.bias.detach().clone()
        else:
            b = layer.bias[index].detach().clone()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    # nn.Linear stores its weight as (out_features, in_features). Match dtype as well as device so that pruning
    # e.g. an fp16 layer does not silently return an fp32 one.
    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(
        device=layer.weight.device, dtype=layer.weight.dtype
    )
    # Overwrite the fresh parameters without recording the copy in autograd; they keep requires_grad=True.
    with torch.no_grad():
        new_layer.weight.copy_(W)
        if layer.bias is not None:
            new_layer.bias.copy_(b)
    return new_layer
|
|
|
|
|
class Conv1D(nn.Module):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).

    Basically works like a linear layer but the weights are transposed: `weight` has shape `(nx, nf)` and the
    layer computes `x @ weight + bias` over the last dimension of `x`.

    Args:
        nf (`int`): The number of output features.
        nx (`int`): The number of input features.
    """

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        self.weight = nn.Parameter(torch.empty(nx, nf))
        self.bias = nn.Parameter(torch.zeros(nf))
        nn.init.normal_(self.weight, std=0.02)

    def extra_repr(self) -> str:
        # Derive nx from the weight so the repr stays accurate even if the weight is replaced.
        return f"nf={self.nf}, nx={self.weight.shape[0]}"

    def forward(self, x):
        # Flatten all leading dimensions so one addmm computes bias + x @ weight, then restore them.
        size_out = x.size()[:-1] + (self.nf,)
        # reshape (unlike view) also accepts non-contiguous inputs and still returns a view when possible.
        x = torch.addmm(self.bias, x.reshape(-1, x.size(-1)), self.weight)
        x = x.view(size_out)
        return x
|
|
|
|
|
def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D:
    """
    Prune a Conv1D layer to keep only entries in index. A Conv1D works like a Linear layer (see e.g. BERT) but the
    weights are transposed.

    Used to remove heads.

    Args:
        layer ([`~pytorch_utils.Conv1D`]): The layer to prune.
        index (`torch.LongTensor`): The indices to keep in the layer.
        dim (`int`, *optional*, defaults to 1): The dimension of `layer.weight` on which to keep the indices.
            With `dim=1` (output features) the bias is pruned with the same indices; with `dim=0` (input
            features) the bias is copied unchanged.

    Returns:
        [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`, placed on the same
        device and using the same dtype as `layer.weight`.
    """
    index = index.to(layer.weight.device)
    W = layer.weight.index_select(dim, index).detach().clone()
    # Conv1D weights are (nx, nf) and the bias has nf entries, so it only shrinks when columns (dim 1) are pruned.
    if dim == 0:
        b = layer.bias.detach().clone()
    else:
        b = layer.bias[index].detach().clone()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    # Match dtype as well as device so that pruning e.g. an fp16 layer does not silently return an fp32 one.
    new_layer = Conv1D(new_size[1], new_size[0]).to(device=layer.weight.device, dtype=layer.weight.dtype)
    # Overwrite the fresh parameters without recording the copy in autograd; they keep requires_grad=True.
    with torch.no_grad():
        new_layer.weight.copy_(W)
        new_layer.bias.copy_(b)
    return new_layer
|
|
|
|
|
def prune_layer(
    layer: Union[nn.Linear, Conv1D], index: torch.LongTensor, dim: Optional[int] = None
) -> Union[nn.Linear, Conv1D]:
    """
    Prune an `nn.Linear` or `Conv1D` layer so only the entries listed in `index` are kept.

    Used to remove heads.

    Args:
        layer (`Union[torch.nn.Linear, Conv1D]`): The layer to prune.
        index (`torch.LongTensor`): The indices to keep in the layer.
        dim (`int`, *optional*): The weight dimension to prune. When omitted, 0 is used for `nn.Linear` and 1 for
            `Conv1D` (the output-feature dimension of each layout).

    Returns:
        `torch.nn.Linear` or [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.

    Raises:
        ValueError: If `layer` is neither an `nn.Linear` nor a `Conv1D`.
    """
    if isinstance(layer, nn.Linear):
        pruner, default_dim = prune_linear_layer, 0
    elif isinstance(layer, Conv1D):
        pruner, default_dim = prune_conv1d_layer, 1
    else:
        raise ValueError(f"Can't prune layer of class {layer.__class__}")
    return pruner(layer, index, dim=default_dim if dim is None else dim)
|
|
|
|
|
def apply_chunking_to_forward(
    forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
) -> torch.Tensor:
    """
    This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension
    `chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory.

    If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly
    applying `forward_fn` to `input_tensors`.

    Args:
        forward_fn (`Callable[..., torch.Tensor]`):
            The forward function of the model. It must declare exactly one parameter per input tensor.
        chunk_size (`int`):
            The chunk size of a chunked tensor: `num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size`.
            A value `<= 0` disables chunking and calls `forward_fn` on the full inputs.
        chunk_dim (`int`):
            The dimension over which the `input_tensors` should be chunked.
        input_tensors (`Tuple[torch.Tensor]`):
            The input tensors of `forward_fn` which will be chunked.

    Returns:
        `torch.Tensor`: A tensor with the same shape as the `forward_fn` would have given if applied directly.

    Raises:
        ValueError: If no input tensors are given, if `forward_fn` does not take one parameter per input tensor, or
            (when chunking) if the inputs differ in size along `chunk_dim` or that size is not a multiple of
            `chunk_size`.

    Examples:

    ```python
    # rename the usual forward() fn to forward_chunk()
    def forward_chunk(self, hidden_states):
        hidden_states = self.decoder(hidden_states)
        return hidden_states


    # implement a chunked forward function
    def forward(self, hidden_states):
        return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
    ```"""
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if len(input_tensors) == 0:
        raise ValueError(f"{input_tensors} has to be a tuple/list of tensors")

    # inspect.signature counts every declared parameter of forward_fn.
    num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
    if num_args_in_forward_chunk_fn != len(input_tensors):
        raise ValueError(
            f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but {len(input_tensors)} input "
            "tensors are given"
        )

    if chunk_size > 0:
        tensor_shape = input_tensors[0].shape[chunk_dim]
        for input_tensor in input_tensors:
            if input_tensor.shape[chunk_dim] != tensor_shape:
                raise ValueError(
                    f"All input tensors have to be of the same shape: {tensor_shape}, "
                    f"found shape {input_tensor.shape[chunk_dim]}"
                )

        if tensor_shape % chunk_size != 0:
            raise ValueError(
                f"The dimension to be chunked {tensor_shape} has to be a multiple of the chunk size {chunk_size}"
            )

        num_chunks = tensor_shape // chunk_size

        # An empty chunk dimension yields num_chunks == 0, which torch.chunk rejects; there is nothing to split,
        # so fall through to a single direct call.
        if num_chunks > 0:
            # Because tensor_shape is a multiple of chunk_size, chunk() produces pieces of exactly chunk_size.
            input_tensors_chunks = tuple(
                input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors
            )
            output_chunks = tuple(
                forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks)
            )
            return torch.cat(output_chunks, dim=chunk_dim)

    return forward_fn(*input_tensors)
|
|
|
|
|
def find_pruneable_heads_and_indices(
    heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
) -> Tuple[Set[int], torch.LongTensor]:
    """
    Work out which requested heads still need pruning, and which flattened weight indices survive.

    Head ids in `heads` and `already_pruned_heads` are in the original (unpruned) numbering; each id is shifted
    down by the number of smaller ids already pruned to locate its row among the `n_heads` current heads.

    Args:
        heads (`List[int]`): Indices of the heads to prune.
        n_heads (`int`): The number of heads currently in the layer.
        head_size (`int`): The size of each head.
        already_pruned_heads (`Set[int]`): Heads that were pruned earlier.

    Returns:
        `Tuple[Set[int], torch.LongTensor]`: The requested heads that were not already pruned, and the flattened
        `(head, position)` indices to keep.
    """
    keep = torch.ones(n_heads, head_size)
    heads = set(heads) - already_pruned_heads
    for head in heads:
        # Every earlier head that is already gone moves this head's current row one step up.
        shift = sum(pruned < head for pruned in already_pruned_heads)
        keep[head - shift] = 0
    keep = keep.view(-1).contiguous().eq(1)
    index: torch.LongTensor = torch.arange(len(keep))[keep].long()
    return heads, index
|
|
|
|
|
def meshgrid(
    *tensors: Union[torch.Tensor, List[torch.Tensor]], indexing: Optional[str] = None
) -> Tuple[torch.Tensor, ...]:
    """
    Wrapper around torch.meshgrid to avoid warning messages about the introduced `indexing` argument.

    Args:
        *tensors: Tensors (or a single list of tensors) forwarded unchanged to `torch.meshgrid`.
        indexing (`str`, *optional*):
            `"ij"` or, on torch>=1.10, `"xy"`. `None` is treated as `"ij"`, the indexing `torch.meshgrid` uses when
            the argument is omitted. The value is always passed explicitly on torch>=1.10, so torch does not warn.

    Returns:
        `Tuple[torch.Tensor, ...]`: The grids returned by `torch.meshgrid`.

    Raises:
        ValueError: On torch<1.10 if `indexing` resolves to anything other than `"ij"`.

    Reference: https://pytorch.org/docs/1.13/generated/torch.meshgrid.html
    """
    # Resolve the default here: forwarding indexing=None would still trigger torch's missing-argument warning on
    # torch>=1.10, and would be rejected by the check below on older torch.
    if indexing is None:
        indexing = "ij"
    if is_torch_greater_or_equal_than_1_10:
        return torch.meshgrid(*tensors, indexing=indexing)
    if indexing != "ij":
        raise ValueError('torch.meshgrid only supports `indexing="ij"` for torch<1.10.')
    return torch.meshgrid(*tensors)
|
|