import logging

from collections.abc import Iterable
from itertools import repeat
from typing import List, Optional, Tuple

import torch
from torch import Tensor


def assert_equal(value1, value2, name1=None, name2=None):
    """Assert that two values are equal; otherwise raise a ValueError."""

    str_name1 = "" if name1 is None else "{} ".format(name1)
    str_name2 = "" if name2 is None else "{} ".format(name2)
    if value1 != value2:
        str_value1 = "{}" if name1 is None else "({})"
        str_value1 = str_value1.format(value1)
        str_value2 = "{}" if name2 is None else "({})"
        str_value2 = str_value2.format(value2)
        raise ValueError(
            "Expected {}{} == {}{}".format(str_name1, str_value1, str_name2, str_value2)
        )


def fill_config(config, key, value):
    """Set config[key] to value if it is unset; otherwise assert consistency."""
    if value is not None:
        if key not in config or config[key] is None:
            config[key] = value
        assert_equal(value, config[key], "value", f'config["{key}"]')


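# A minimal usage sketch for `fill_config` (the config dict here is made up,
# not part of this module): the first call fills the unset key, a repeated
# consistent call passes, and a conflicting value raises.
#
#   >>> cfg = {"dim": None}
#   >>> fill_config(cfg, "dim", 256)
#   >>> cfg["dim"]
#   256
#   >>> fill_config(cfg, "dim", 256)   # consistent, no error
#   >>> fill_config(cfg, "dim", 512)
#   Traceback (most recent call last):
#       ...
#   ValueError: Expected value (512) == config["dim"] (256)

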
def check_and_return_expected(value, undefined_value, expected_value, name=None):
    """
    Return the expected value while checking if the given value is undefined or
    equal to the expected value.
    """
    if (undefined_value is None and value is None) or (undefined_value == value):
        return expected_value
    if value != expected_value:
        str_name = "" if name is None else "{} ".format(name)
        str_value = "{}" if name is None else "({})"
        str_value = str_value.format(value)
        raise ValueError(
            "Expected {}{} == {}".format(str_name, str_value, expected_value)
        )
    return expected_value


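# A short behavior sketch (values are illustrative): an undefined input falls
# back to the expected value, a matching input is accepted, anything else
# raises.
#
#   >>> check_and_return_expected(None, None, 5, name="dim")
#   5
#   >>> check_and_return_expected(5, None, 5, name="dim")
#   5
#   >>> check_and_return_expected(3, None, 5, name="dim")
#   Traceback (most recent call last):
#       ...
#   ValueError: Expected dim (3) == 5

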
def get_time_axis(layout):
    """
    Extract the time axis from the layout, for example for breaking a sequence
    into segments.
    """
    if layout in ["TB", "TBD"]:
        return 0
    if layout in ["BT", "BTD"]:
        return 1
    if layout in ["BCTD"]:
        return 2
    raise ValueError("Unsupported layout = {}".format(layout))


def get_batch_axis(layout):
    """
    Extract the batch axis from the layout.
    """
    if layout in ["TB", "TBD"]:
        return 1
    if layout in ["BT", "BTD", "BCTD"]:
        return 0
    raise ValueError("Unsupported layout = {}".format(layout))


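# Layout strings appear to name one axis per character (T = time, B = batch,
# C = channels, D = features); a quick sanity check of both helpers:
#
#   >>> get_time_axis("BTD"), get_batch_axis("BTD")
#   (1, 0)
#   >>> get_time_axis("BCTD"), get_batch_axis("BCTD")
#   (2, 0)

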
def monotonically_increasing_and_bounded(iterable, min=None, max=None):
    """
    Check if the elements in the given iterable are strictly increasing and
    bounded by the optional lower/upper bounds.
    """
    if not isinstance(iterable, Iterable):
        raise TypeError(
            "Expected iterable to be of type Iterable, got ({})".format(
                iterable.__class__.__name__
            )
        )
    # Materialize first: generic iterables (e.g. generators) support neither
    # len() nor indexing.
    values = list(iterable)
    for i, value in enumerate(values):
        if min is not None and value < min:
            return False
        if max is not None and value > max:
            return False
        if i > 0 and value <= values[i - 1]:
            return False
    return True


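# Behavior sketch (values are illustrative): the check is strict, so repeated
# elements fail, as do out-of-bound ones.
#
#   >>> monotonically_increasing_and_bounded([1, 2, 5], min=0, max=10)
#   True
#   >>> monotonically_increasing_and_bounded([1, 1, 5])
#   False
#   >>> monotonically_increasing_and_bounded([1, 2, 50], max=10)
#   False

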
def to_pair(value, name):
    """Make a pair (of type tuple) of the given value."""
    if isinstance(value, Iterable):
        # Materialize so that len() also works on generic iterables, and the
        # result is a tuple as documented.
        value = tuple(value)
        if len(value) != 2:
            raise ValueError(
                "Expected `{}` to have exactly 2 elements, got: ({})".format(
                    name, value
                )
            )
        return value
    return tuple(repeat(value, 2))


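# Typical use (argument names are illustrative): scalars are duplicated,
# two-element iterables pass through as tuples.
#
#   >>> to_pair(3, "kernel_size")
#   (3, 3)
#   >>> to_pair([3, 5], "kernel_size")
#   (3, 5)

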
def infer_conv_output_attrs(
    module, input_channels, input_dim, batch_size=1, max_length=8
):
    """Infer the output channels and feature dim of a module from a dummy input."""
    sample_input = torch.randn(batch_size, input_channels, max_length, input_dim)
    output = module(sample_input)
    output_channels = output.shape[1]
    output_dim = output.shape[-1]
    return output_channels, output_dim


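# A minimal sketch with a standard Conv2d (the layer here is an arbitrary
# example, not one used by this module): 80 input features shrink to
# floor((80 - 3) / 2) + 1 = 39 with kernel 3 and stride 2.
#
#   >>> conv = torch.nn.Conv2d(1, 32, kernel_size=3, stride=2)
#   >>> infer_conv_output_attrs(conv, input_channels=1, input_dim=80)
#   (32, 39)

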
class NoOp(torch.nn.Module):
    """
    NoOp simply passes the input as the output.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input: Tensor) -> Tensor:
        return input


class Permute(torch.nn.Module):
    """Permute the input dimensions, returning a contiguous tensor."""

    def __init__(self, dims):
        super().__init__()
        self.dims = dims

    def forward(self, input: Tensor) -> Tensor:
        return input.permute(self.dims).contiguous()


def lengths_to_padding_mask(lengths: Tensor) -> Tensor:
    """Convert lengths of shape (B,) to a (B, T) padding mask."""
    batch_size = lengths.shape[0]
    max_length = int(torch.max(lengths).item())
    # True marks padded positions: position t of row b is padding iff
    # t >= lengths[b].
    padding_mask = torch.arange(
        max_length, device=lengths.device, dtype=lengths.dtype
    ).expand(batch_size, max_length) >= lengths.unsqueeze(1)

    return padding_mask


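# Behavior sketch for two sequences of lengths 1 and 3:
#
#   >>> lengths_to_padding_mask(torch.tensor([1, 3]))
#   tensor([[False,  True,  True],
#           [False, False, False]])

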
def lengths_to_attention_mask(
    lengths: Tensor,
    left_context: Optional[int] = None,
    right_context: Optional[int] = None,
) -> Optional[Tensor]:
    """
    Generate attention mask based on (lengths, left_context, right_context).
    left_context is None means unlimited left context.
    right_context is None means unlimited right context.
    """

    if left_context is None and right_context is None:
        return None

    max_length = int(torch.max(lengths).item())

    # indices[t, s] = s - t, the signed distance of key position s from query
    # position t. For example, with max_length == 3:
    #   indices = [[ 0,  1,  2],
    #              [-1,  0,  1],
    #              [-2, -1,  0]]
    indices = torch.arange(
        max_length, device=lengths.device, dtype=lengths.dtype
    ).expand(max_length, max_length) - torch.arange(
        max_length, device=lengths.device
    ).view(
        max_length, -1
    )

    # Start from an all-True mask (True = may attend) and intersect it with
    # the left/right context constraints.
    bool_mask = torch.ones(
        max_length, max_length, dtype=torch.bool, device=lengths.device
    )

    # Keep keys at most `left_context` positions before the query.
    if left_context is not None:
        left_mask = indices >= -left_context
        bool_mask = bool_mask & left_mask

    # Keep keys at most `right_context` positions after the query.
    if right_context is not None:
        right_mask = indices <= right_context
        bool_mask = bool_mask & right_mask

    # Invert so that True marks positions that must NOT be attended.
    bool_mask = ~bool_mask
    return bool_mask


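# Behavior sketch for max length 3 with one frame of left context and no
# right context (True = masked out, i.e. not attendable):
#
#   >>> lengths_to_attention_mask(torch.tensor([3]), left_context=1, right_context=0)
#   tensor([[False,  True,  True],
#           [False, False,  True],
#           [ True, False, False]])

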
def infer_output_norm(module, output_norm=None):
    """
    Infer the output norm (string and module) needed on the module given the
    desired output normalization.
    """
    if output_norm == module.output_norm():
        # The module already applies the requested normalization.
        return (None, NoOp())

    if output_norm is None and module.output_norm() is not None:
        logger = logging.getLogger("infer_output_norm()")
        logger.warning(
            "trying to set output_norm ({}) ".format(output_norm)
            + "but got module.output_norm() ({}), ".format(module.output_norm())
            + "the combined output_norm() will be ({})".format(module.output_norm())
        )
        return (None, NoOp())

    if output_norm == "log_softmax":
        if module.output_norm() is not None:
            raise ValueError(
                "incompatible output_norm ({}) ".format(output_norm)
                + "and module.output_norm() ({})".format(module.output_norm())
            )
        else:
            return ("log_softmax", torch.nn.LogSoftmax(dim=-1))

    if output_norm == "softmax":
        if module.output_norm() is not None:
            raise ValueError(
                "incompatible output_norm ({}) ".format(output_norm)
                + "and module.output_norm() ({})".format(module.output_norm())
            )
        else:
            return ("softmax", torch.nn.Softmax(dim=-1))

    raise ValueError(
        "output_norm ({}) not in ".format(output_norm)
        + "supported list = [None, softmax, log_softmax]"
    )


def infer_channels_from_layout(layout, channels):
    """Extract the number of channels from the layout."""
    if layout in ("TBD", "BTD"):
        if channels is not None and channels != 1:
            raise ValueError(
                "Expected channels ({}) to be 1 for layout = {}".format(
                    channels, layout
                )
            )
        if channels is None:
            return 1
    return channels


@torch.jit.export
def pad_sequence(
    sequence: Tensor,
    time_axis: int,
    extra_left_context: int = 0,
    extra_right_context: int = 0,
) -> Tensor:
    """Pad extra left/right contexts to the sequence."""

    if extra_left_context == 0 and extra_right_context == 0:
        return sequence

    tensors_to_concat = []

    if extra_left_context:
        # Left padding replicates the first frame `extra_left_context` times
        # (index_select with all-zero indices).
        size = (extra_left_context,)
        fill_value = 0
        indices = torch.full(
            size=size,
            fill_value=fill_value,
            dtype=torch.long,
            device=sequence.device,
        )
        left_padding = torch.index_select(sequence, time_axis, indices)
        tensors_to_concat.append(left_padding)

    tensors_to_concat.append(sequence)

    if extra_right_context:
        # Right padding appends zeros along the time axis.
        size = list(sequence.shape)
        size[time_axis] = extra_right_context
        right_padding = torch.zeros(size, dtype=sequence.dtype, device=sequence.device)
        tensors_to_concat.append(right_padding)

    padded_sequence = torch.cat(tensors_to_concat, dim=time_axis)
    return padded_sequence


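# Shape-level sketch for a (T, B, D) sequence: 5 frames grow to
# 2 (left) + 5 + 3 (right) = 10 frames.
#
#   >>> seq = torch.randn(5, 2, 8)
#   >>> pad_sequence(seq, time_axis=0, extra_left_context=2, extra_right_context=3).shape
#   torch.Size([10, 2, 8])

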
@torch.jit.export
def sequence_to_segments(
    sequence: Tensor,
    time_axis: int,
    lengths: Tensor,
    segment_size: Optional[int] = None,
    extra_left_context: int = 0,
    extra_right_context: int = 0,
) -> List[Tuple[Tensor, Tensor]]:
    """Break a sequence into a list of (segment, lengths) pairs."""

    sequence = pad_sequence(
        sequence=sequence,
        time_axis=time_axis,
        extra_left_context=extra_left_context,
        extra_right_context=extra_right_context,
    )

    lengths = lengths + extra_left_context + extra_right_context

    segments: List[Tuple[Tensor, Tensor]] = []

    if segment_size is None:
        segments.append((sequence, lengths))
        return segments

    offset = 0
    end = sequence.shape[time_axis]
    step = segment_size
    size = extra_left_context + segment_size + extra_right_context

    # Each segment spans `segment_size` new frames plus the extra contexts, so
    # consecutive segments overlap by the combined context size.
    while offset + extra_left_context + extra_right_context < end:
        clamped_size = min(size, end - offset)
        segment_lengths = torch.clamp(lengths - offset, min=0, max=clamped_size)
        indices = torch.arange(
            start=offset,
            end=(offset + clamped_size),
            step=1,
            dtype=torch.long,
            device=sequence.device,
        )
        segment_tensor = torch.index_select(sequence, time_axis, indices)
        segments.append((segment_tensor, segment_lengths))
        offset = offset + step

    return segments


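# Behavior sketch: a 10-frame (T, B, D) sequence with segment_size=4 and no
# extra context yields segments of 4, 4, and 2 frames.
#
#   >>> seq = torch.randn(10, 1, 4)
#   >>> segs = sequence_to_segments(seq, time_axis=0, lengths=torch.tensor([10]), segment_size=4)
#   >>> [(t.shape[0], int(l)) for t, l in segs]
#   [(4, 4), (4, 4), (2, 2)]

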
@torch.jit.export
def segments_to_sequence(
    segments: List[Tuple[Tensor, Tensor]], time_axis: int
) -> Tuple[Tensor, Tensor]:
    """Concatenate segments into a full sequence."""
    if len(segments) == 1:
        return segments[0]

    tensors_to_concat: List[Tensor] = []
    lengths_to_stack: List[Tensor] = []

    for tensor, lengths in segments:
        tensors_to_concat.append(tensor)
        lengths_to_stack.append(lengths)

    sequence = torch.cat(tensors_to_concat, dim=time_axis)
    lengths = torch.stack(lengths_to_stack, dim=0)
    lengths = torch.sum(lengths, dim=0)

    return sequence, lengths


def lengths_to_encoder_padding_mask(lengths, batch_first: bool = False):
    """
    Convert lengths (a 1-D Long/Int tensor) to a 2-D binary tensor.

    Args:
        lengths: a (B, )-shaped tensor
        batch_first: whether to return a (B, T) tensor

    Return:
        encoder_padding_mask: a (max_length, B) binary mask, where
            [t, b] = False for t < lengths[b] and True otherwise
        max_length: maximum length of B sequences

    TODO:
        kernelize this function if benchmarking shows this function is slow
    """
    max_lengths = torch.max(lengths).item()
    bsz = lengths.size(0)
    # Position t of row b is padding iff t >= lengths[b].
    encoder_padding_mask = (
        torch.arange(max_lengths)
        .to(lengths.device)
        .view(1, max_lengths)
        .expand(bsz, -1)
        >= lengths.view(bsz, 1).expand(-1, max_lengths)
    )
    if not batch_first:
        return encoder_padding_mask.t(), max_lengths
    else:
        return encoder_padding_mask, max_lengths


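# Behavior sketch, mirroring lengths_to_padding_mask but also returning the
# max length:
#
#   >>> mask, max_len = lengths_to_encoder_padding_mask(torch.tensor([1, 3]), batch_first=True)
#   >>> mask
#   tensor([[False,  True,  True],
#           [False, False, False]])
#   >>> max_len
#   3

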
def attention_suppression(attention_weights: Tensor, scale: float):
    """
    Set attention weights to -inf wherever the attention probability falls
    more than `scale` standard deviations below the mean over nonzero entries.
    """
    attention_prob = torch.nn.functional.softmax(attention_weights.float(), dim=-1)
    attention_nozeros = attention_prob.to(torch.bool)
    nozeros_sum = torch.sum(attention_nozeros.to(torch.float), dim=-1, keepdim=True)

    # Mean over the nonzero entries; the epsilon guards against all-zero rows.
    key_sum = torch.sum(attention_prob, dim=-1, keepdim=True)
    key_mean = key_sum / (nozeros_sum + 1e-8)

    # Sample variance over the nonzero entries only.
    dis = (attention_prob - key_mean) * (attention_prob - key_mean)
    dis_masked = torch.where(
        attention_nozeros, dis, attention_prob.new_zeros(attention_prob.size())
    )

    key_var = torch.sum(dis_masked, dim=-1, keepdim=True)
    key_var = key_var / (nozeros_sum - 1.0 + 1e-8)
    key_std = torch.sqrt(key_var)
    key_threshold = key_mean - scale * key_std

    # Suppress weights whose probability is below the threshold by setting
    # them to -inf; the rest are kept unchanged.
    inf_tensor = attention_prob.new_zeros(attention_prob.size()).detach()
    inf_tensor[:] = float("-inf")
    attention_weights_float = torch.where(
        attention_prob < key_threshold,
        inf_tensor,
        attention_weights.float(),
    )

    return attention_weights_float.type_as(attention_weights)


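# Behavior sketch (made-up weights): the clear outlier among four keys is
# pushed to -inf while the rest are kept unchanged.
#
#   >>> w = torch.tensor([[2.0, 2.0, 2.0, -8.0]])
#   >>> attention_suppression(w, scale=0.5)
#   tensor([[2., 2., 2., -inf]])

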
def layer_norm_backward_hook(module, grad_input, grad_output, clamp_value):
    """Clamp the input gradients to [-clamp_value, clamp_value]."""
    return tuple(torch.clamp(v, min=-clamp_value, max=clamp_value) for v in grad_input)