# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
class LatencyMetric(object):
@staticmethod
def length_from_padding_mask(padding_mask, batch_first: bool = False):
dim = 1 if batch_first else 0
return padding_mask.size(dim) - padding_mask.sum(dim=dim, keepdim=True)
def prepare_latency_metric(
self,
delays,
src_lens,
target_padding_mask=None,
batch_first: bool = False,
start_from_zero: bool = True,
):
assert len(delays.size()) == 2
assert len(src_lens.size()) == 2
if start_from_zero:
delays = delays + 1
if batch_first:
# convert to batch_last
delays = delays.t()
src_lens = src_lens.t()
tgt_len, bsz = delays.size()
_, bsz_1 = src_lens.size()
if target_padding_mask is not None:
target_padding_mask = target_padding_mask.t()
tgt_len_1, bsz_2 = target_padding_mask.size()
assert tgt_len == tgt_len_1
assert bsz == bsz_2
assert bsz == bsz_1
if target_padding_mask is None:
tgt_lens = tgt_len * delays.new_ones([1, bsz]).float()
else:
# 1, batch_size
tgt_lens = self.length_from_padding_mask(target_padding_mask, False).float()
delays = delays.masked_fill(target_padding_mask, 0)
return delays, src_lens, tgt_lens, target_padding_mask
def __call__(
self,
delays,
src_lens,
target_padding_mask=None,
batch_first: bool = False,
start_from_zero: bool = True,
):
delays, src_lens, tgt_lens, target_padding_mask = self.prepare_latency_metric(
delays, src_lens, target_padding_mask, batch_first, start_from_zero
)
return self.cal_metric(delays, src_lens, tgt_lens, target_padding_mask)
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
"""
Expected sizes:
delays: tgt_len, batch_size
src_lens: 1, batch_size
tgt_lens: 1, batch_size
target_padding_mask: tgt_len, batch_size or None
"""
raise NotImplementedError
class AverageProportion(LatencyMetric):
"""
Function to calculate Average Proportion from
"Can neural machine translation do simultaneous translation?"
(https://arxiv.org/abs/1606.02012)
Delays are monotonic steps, ranging from 1 to src_len.
Given source x and target y, AP is calculated as:
AP = 1 / (|x||y|) * sum_i^|y| delays_i
"""
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
if target_padding_mask is not None:
AP = torch.sum(
delays.masked_fill(target_padding_mask, 0), dim=0, keepdim=True
)
else:
AP = torch.sum(delays, dim=0, keepdim=True)
AP = AP / (src_lens * tgt_lens)
return AP
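# Usage sketch (illustrative, not part of the original module): Average
# Proportion for a single toy sentence. The delay values below are
# hypothetical 1-based read counts, hence start_from_zero=False.
def _example_average_proportion():
    metric = AverageProportion()
    # |x| = 4 source tokens, |y| = 3 target tokens, delays = [2, 3, 4]
    delays = torch.tensor([[2.0, 3.0, 4.0]])  # bsz x tgt_len
    src_lens = torch.tensor([[4.0]])  # bsz x 1
    # AP = (2 + 3 + 4) / (4 * 3) = 0.75, returned with shape 1 x bsz
    return metric(delays, src_lens, batch_first=True, start_from_zero=False)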
class AverageLagging(LatencyMetric):
"""
Function to calculate Average Lagging from
"STACL: Simultaneous Translation with Implicit Anticipation
and Controllable Latency using Prefix-to-Prefix Framework"
(https://arxiv.org/abs/1810.08398)
Delays are monotonic steps, ranging from 1 to src_len.
Given source x and target y, AL is calculated as:
AL = 1 / tau * sum_i^tau (delays_i - (i - 1) / gamma)
Where
gamma = |y| / |x|
tau = argmin_i(delays_i = |x|)
"""
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
# tau = argmin_i(delays_i = |x|)
tgt_len, bsz = delays.size()
lagging_padding_mask = delays >= src_lens
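# Shift the mask down one step so that the first target position whose delay
# reaches the end of the source is still counted; only positions after it are
# excluded from the sum.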
lagging_padding_mask = torch.nn.functional.pad(
lagging_padding_mask.t(), (1, 0)
).t()[:-1, :]
gamma = tgt_lens / src_lens
lagging = (
delays
- torch.arange(delays.size(0))
.unsqueeze(1)
.type_as(delays)
.expand_as(delays)
/ gamma
)
lagging.masked_fill_(lagging_padding_mask, 0)
tau = (1 - lagging_padding_mask.type_as(lagging)).sum(dim=0, keepdim=True)
AL = lagging.sum(dim=0, keepdim=True) / tau
return AL
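# Usage sketch (illustrative, not part of the original module): Average
# Lagging for a hypothetical wait-3 schedule on a length-4 pair (gamma = 1),
# where AL should equal the wait value.
def _example_average_lagging():
    metric = AverageLagging()
    delays = torch.tensor([[3.0, 4.0, 4.0, 4.0]])  # bsz x tgt_len, 1-based
    src_lens = torch.tensor([[4.0]])  # bsz x 1
    # tau = 2 (the first step reaching the full source ends the sum), so
    # AL = ((3 - 0) + (4 - 1)) / 2 = 3, returned with shape 1 x bsz
    return metric(delays, src_lens, batch_first=True, start_from_zero=False)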
class DifferentiableAverageLagging(LatencyMetric):
"""
Function to calculate Differentiable Average Lagging from
"Monotonic Infinite Lookback Attention for Simultaneous Machine Translation"
(https://arxiv.org/abs/1906.05218)
Delays are monotonic steps, ranging from 0 to src_len - 1.
(In the original paper they range from 1 to src_len.)
Given source x and target y, DAL is calculated as:
DAL = 1 / |y| * sum_i^|y| (delays'_i - (i - 1) / gamma)
Where
gamma = |y| / |x|
delays'_i =
1. delays_i                                  if i == 1
2. max(delays_i, delays'_{i-1} + 1 / gamma)  otherwise
"""
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
tgt_len, bsz = delays.size()
gamma = tgt_lens / src_lens
new_delays = torch.zeros_like(delays)
for i in range(delays.size(0)):
if i == 0:
new_delays[i] = delays[i]
else:
new_delays[i] = torch.cat(
[
new_delays[i - 1].unsqueeze(0) + 1 / gamma,
delays[i].unsqueeze(0),
],
dim=0,
).max(dim=0)[0]
DAL = (
new_delays
- torch.arange(delays.size(0))
.unsqueeze(1)
.type_as(delays)
.expand_as(delays)
/ gamma
)
if target_padding_mask is not None:
DAL = DAL.masked_fill(target_padding_mask, 0)
DAL = DAL.sum(dim=0, keepdim=True) / tgt_lens
return DAL
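# Usage sketch (illustrative, not part of the original module): Differentiable
# Average Lagging on the same hypothetical wait-3 schedule as above. The
# delays are already 1-based, hence start_from_zero=False.
def _example_differentiable_average_lagging():
    metric = DifferentiableAverageLagging()
    delays = torch.tensor([[3.0, 4.0, 4.0, 4.0]])  # bsz x tgt_len
    src_lens = torch.tensor([[4.0]])  # bsz x 1
    # delays' = [3, 4, 5, 6] (each step at least 1/gamma after the previous),
    # so DAL = ((3 - 0) + (4 - 1) + (5 - 2) + (6 - 3)) / 4 = 3
    return metric(delays, src_lens, batch_first=True, start_from_zero=False)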
class LatencyMetricVariance(LatencyMetric):
def prepare_latency_metric(
self,
delays,
src_lens,
target_padding_mask=None,
batch_first: bool = True,
start_from_zero: bool = True,
):
assert batch_first
assert len(delays.size()) == 3
assert len(src_lens.size()) == 2
if start_from_zero:
delays = delays + 1
# delays: bsz, num_heads_x_layers, tgt_len (kept batch-first)
bsz, num_heads_x_layers, tgt_len = delays.size()
bsz_1, _ = src_lens.size()
assert bsz == bsz_1
if target_padding_mask is not None:
bsz_2, tgt_len_1 = target_padding_mask.size()
assert tgt_len == tgt_len_1
assert bsz == bsz_2
if target_padding_mask is None:
tgt_lens = tgt_len * delays.new_ones([bsz, tgt_len]).float()
else:
# batch_size, 1
tgt_lens = self.length_from_padding_mask(target_padding_mask, True).float()
delays = delays.masked_fill(target_padding_mask.unsqueeze(1), 0)
return delays, src_lens, tgt_lens, target_padding_mask
class VarianceDelay(LatencyMetricVariance):
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
"""
Expected sizes:
delays: bsz, num_heads_x_layers, tgt_len
src_lens: bsz, 1
tgt_lens: bsz, 1
target_padding_mask: bsz, tgt_len or None
"""
if delays.size(1) == 1:
return delays.new_zeros([1])
variance_delays = delays.var(dim=1)
if target_padding_mask is not None:
variance_delays.masked_fill_(target_padding_mask, 0)
return variance_delays.sum(dim=1, keepdim=True) / tgt_lens
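# Usage sketch (illustrative, not part of the original module): variance of
# expected delays across attention heads/layers for a toy batch. The random
# delays and the padding position below are hypothetical.
def _example_variance_delay():
    metric = VarianceDelay()
    bsz, num_heads_x_layers, tgt_len = 2, 4, 5
    delays = torch.rand(bsz, num_heads_x_layers, tgt_len) * 10.0
    src_lens = torch.full((bsz, 1), 10.0)
    target_padding_mask = torch.zeros(bsz, tgt_len, dtype=torch.bool)
    target_padding_mask[1, -1] = True  # last target position of sentence 2 is padding
    # Returns a bsz x 1 tensor: per-sentence variance of the delays across
    # heads/layers, averaged over the non-padded target positions.
    return metric(
        delays, src_lens, target_padding_mask, batch_first=True, start_from_zero=False
    )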
class LatencyInference(object):
def __init__(self, start_from_zero=True):
self.metric_calculator = {
"differentiable_average_lagging": DifferentiableAverageLagging(),
"average_lagging": AverageLagging(),
"average_proportion": AverageProportion(),
}
self.start_from_zero = start_from_zero
def __call__(self, monotonic_step, src_lens):
"""
monotonic_step ranges from 0 to src_len; a value of src_len means eos has been read.
monotonic_step: bsz, ..., tgt_len (reduced to delays of shape bsz, tgt_len
by taking the max over any head/layer dimensions)
src_lens: bsz, 1
"""
if not self.start_from_zero:
monotonic_step -= 1
delays = monotonic_step.view(
monotonic_step.size(0), -1, monotonic_step.size(-1)
).max(dim=1)[0]
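# Replace any delay that already reaches the end of the source with
# src_len - 1 before computing the metrics.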
delays = delays.masked_fill(delays >= src_lens, 0) + (src_lens - 1).expand_as(
delays
).masked_fill(delays < src_lens, 0)
return_dict = {}
for key, func in self.metric_calculator.items():
return_dict[key] = func(
delays.float(),
src_lens.float(),
target_padding_mask=None,
batch_first=True,
start_from_zero=True,
).t()
return return_dict
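# Usage sketch (illustrative, not part of the original module): latency
# metrics from hard, 0-based monotonic read positions produced at inference
# time. The step values below are hypothetical.
def _example_latency_inference():
    inference = LatencyInference(start_from_zero=True)
    # bsz=1, tgt_len=4: the model had read 1, 2, 3, then all 4 source tokens
    # (the internal +1 shift turns these 0-based steps into 1-based delays).
    monotonic_step = torch.tensor([[0, 1, 2, 3]])
    src_lens = torch.tensor([[4]])
    # Returns a dict with keys "differentiable_average_lagging",
    # "average_lagging" and "average_proportion", each of shape bsz x 1
    # (here DAL = AL = 1.0 and AP = 0.625).
    return inference(monotonic_step, src_lens)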
class LatencyTraining(object):
def __init__(
self,
avg_weight,
var_weight,
avg_type,
var_type,
stay_on_last_token,
average_method,
):
self.avg_weight = avg_weight
self.var_weight = var_weight
self.avg_type = avg_type
self.var_type = var_type
self.stay_on_last_token = stay_on_last_token
self.average_method = average_method
self.metric_calculator = {
"differentiable_average_lagging": DifferentiableAverageLagging(),
"average_lagging": AverageLagging(),
"average_proportion": AverageProportion(),
}
self.variance_calculator = {
"variance_delay": VarianceDelay(),
}
def expected_delays_from_attention(
self, attention, source_padding_mask=None, target_padding_mask=None
):
if type(attention) == list:
# bsz, num_heads, tgt_len, src_len
bsz, num_heads, tgt_len, src_len = attention[0].size()
attention = torch.cat(attention, dim=1)
bsz, num_heads_x_layers, tgt_len, src_len = attention.size()
# bsz * num_heads * num_layers, tgt_len, src_len
attention = attention.view(-1, tgt_len, src_len)
else:
# bsz, tgt_len, src_len (a single attention matrix, treated as one head/layer)
bsz, tgt_len, src_len = attention.size()
num_heads_x_layers = 1
attention = attention.view(-1, tgt_len, src_len)
if not self.stay_on_last_token:
residual_attention = 1 - attention[:, :, :-1].sum(dim=2, keepdim=True)
attention = torch.cat([attention[:, :, :-1], residual_attention], dim=2)
# bsz * num_heads_x_num_layers, tgt_len, src_len for MMA
steps = (
torch.arange(1, 1 + src_len)
.unsqueeze(0)
.unsqueeze(1)
.expand_as(attention)
.type_as(attention)
)
if source_padding_mask is not None:
src_offset = (
source_padding_mask.type_as(attention)
.sum(dim=1, keepdim=True)
.expand(bsz, num_heads_x_layers)
.contiguous()
.view(-1, 1)
)
src_lens = src_len - src_offset
if source_padding_mask[:, 0].any():
# Pad left
src_offset = src_offset.view(-1, 1, 1)
steps = steps - src_offset
steps = steps.masked_fill(steps <= 0, 0)
else:
src_lens = attention.new_ones([bsz, num_heads_x_layers]) * src_len
src_lens = src_lens.view(-1, 1)
# expected_delays: bsz, num_heads_x_layers, tgt_len
expected_delays = (
(steps * attention).sum(dim=2).view(bsz, num_heads_x_layers, tgt_len)
)
if target_padding_mask is not None:
expected_delays.masked_fill_(target_padding_mask.unsqueeze(1), 0)
return expected_delays, src_lens
def avg_loss(self, expected_delays, src_lens, target_padding_mask):
bsz, num_heads_x_layers, tgt_len = expected_delays.size()
target_padding_mask = (
target_padding_mask.unsqueeze(1)
.expand_as(expected_delays)
.contiguous()
.view(-1, tgt_len)
)
if self.average_method == "average":
# bsz, tgt_len
expected_delays = expected_delays.mean(dim=1)
elif self.average_method == "weighted_average":
weights = torch.nn.functional.softmax(expected_delays, dim=1)
expected_delays = torch.sum(expected_delays * weights, dim=1)
elif self.average_method == "max":
# bsz, tgt_len
expected_delays = expected_delays.max(dim=1)[0]
else:
raise RuntimeError(f"{self.average_method} is not supported")
src_lens = src_lens.view(bsz, -1)[:, :1]
target_padding_mask = target_padding_mask.view(bsz, -1, tgt_len)[:, 0]
if self.avg_weight > 0.0:
if self.avg_type in self.metric_calculator:
average_delays = self.metric_calculator[self.avg_type](
expected_delays,
src_lens,
target_padding_mask,
batch_first=True,
start_from_zero=False,
)
else:
raise RuntimeError(f"{self.avg_type} is not supported.")
# average_delays: 1, bsz
return self.avg_weight * average_delays.sum()
else:
return 0.0
def var_loss(self, expected_delays, src_lens, target_padding_mask):
src_lens = src_lens.view(expected_delays.size(0), expected_delays.size(1))[
:, :1
]
if self.var_weight > 0.0:
if self.var_type in self.variance_calculator:
variance_delays = self.variance_calculator[self.var_type](
expected_delays,
src_lens,
target_padding_mask,
batch_first=True,
start_from_zero=False,
)
else:
raise RuntimeError(f"{self.var_type} is not supported.")
return self.var_weight * variance_delays.sum()
else:
return 0.0
def loss(self, attention, source_padding_mask=None, target_padding_mask=None):
expected_delays, src_lens = self.expected_delays_from_attention(
attention, source_padding_mask, target_padding_mask
)
latency_loss = 0
latency_loss += self.avg_loss(expected_delays, src_lens, target_padding_mask)
latency_loss += self.var_loss(expected_delays, src_lens, target_padding_mask)
return latency_loss
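# Usage sketch (illustrative, not part of the original module): differentiable
# latency loss from a toy multi-head soft-attention tensor. The weights,
# metric types and shapes below are hypothetical choices, not defaults.
def _example_latency_training():
    latency = LatencyTraining(
        avg_weight=1.0,
        var_weight=1.0,
        avg_type="differentiable_average_lagging",
        var_type="variance_delay",
        stay_on_last_token=True,
        average_method="weighted_average",
    )
    bsz, num_heads, tgt_len, src_len = 2, 4, 5, 6
    # A single decoder layer of attention weights, normalized over the source.
    attention = [torch.softmax(torch.randn(bsz, num_heads, tgt_len, src_len), dim=-1)]
    target_padding_mask = torch.zeros(bsz, tgt_len, dtype=torch.bool)
    target_padding_mask[1, -1] = True  # last target position of sentence 2 is padding
    # Scalar regularizer: avg_weight * summed DAL + var_weight * summed
    # per-sentence delay variance across heads/layers.
    return latency.loss(attention, target_padding_mask=target_padding_mask)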