import math

import torch
import torch.nn.functional as F
from fairseq.models.nat import (
    _apply_del_words,
    _apply_ins_masks,
    _apply_ins_words,
    _fill,
    _skip,
    _skip_encoder_out,
)
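
# Per-model scores in this module are combined in log space: for N models,
#
#     log((1/N) * sum_i p_i) = logsumexp_i(log p_i) - log(N)
#
# which is what the `torch.logsumexp(...) - math.log(len(self.models))`
# pattern in `EnsembleLevT` computes.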


class _EnsembleModelEncoder(object):
    """Wraps the encoders of an ensemble so that `reorder_encoder_out` is
    applied to each model's encoder output in lockstep."""

    def __init__(self, models):
        self.models = models

    def reorder_encoder_out(self, encoder_outs, new_order):
        encoder_outs = [
            model.encoder.reorder_encoder_out(encoder_out, new_order)
            for model, encoder_out in zip(self.models, encoder_outs)
        ]
        return encoder_outs


class BasicEnsembleModel(torch.nn.Module):
    """A wrapper around an ensemble of models."""

    def __init__(self, models):
        super().__init__()
        self.models = torch.nn.ModuleList(models)
        # special symbols come from the first model's dictionary; all models
        # in the ensemble are assumed to share the same vocabulary
        self.bos = self.models[0].decoder.dictionary.bos()
        self.eos = self.models[0].decoder.dictionary.eos()
        self.pad = self.models[0].decoder.dictionary.pad()
        self.unk = self.models[0].decoder.dictionary.unk()
        self.encoder = _EnsembleModelEncoder(self.models)

    def has_encoder(self):
        return hasattr(self.models[0], "encoder")

    def max_decoder_positions(self):
        return min(m.max_decoder_positions() for m in self.models)

    @torch.no_grad()
    def forward_encoder(self, encoder_input):
        if not self.has_encoder():
            return None
        return [model.forward_encoder(encoder_input) for model in self.models]

    @torch.no_grad()
    def forward_decoder(self, *inputs):
        raise NotImplementedError

    def initialize_output_tokens(self, *inputs):
        raise NotImplementedError
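

# Levenshtein Transformer (LevT) decoding refines a hypothesis iteratively.
# Each call to `EnsembleLevT.forward_decoder` below applies three ensembled
# edit operations in order: (1) delete tokens, (2) insert <unk> placeholders,
# (3) replace the placeholders with predicted words.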
class EnsembleLevT(BasicEnsembleModel):
    """An ensemble of Levenshtein Transformer models.

    Each refinement iteration averages the models' deletion and insertion
    scores before applying the corresponding edit operations.
    """

    @torch.no_grad()
    def forward_decoder(
        self, decoder_out, encoder_outs, eos_penalty=0.0, max_ratio=None, **kwargs
    ):
        output_tokens = decoder_out.output_tokens
        output_scores = decoder_out.output_scores
        attn = decoder_out.attn

        bsz = output_tokens.size(0)
        if max_ratio is None:
            # no length constraint: allow up to 255 tokens per sentence
            max_lens = output_tokens.new_full((bsz,), 255)
        else:
            if not encoder_outs[0]["encoder_padding_mask"]:
                # no padding: every source sentence has the full length;
                # `encoder_out` tensors are (src_len, bsz, dim)
                src_lens = (
                    encoder_outs[0]["encoder_out"][0]
                    .new(bsz)
                    .fill_(encoder_outs[0]["encoder_out"][0].size(0))
                )
            else:
                src_lens = (~encoder_outs[0]["encoder_padding_mask"][0]).sum(1)
            max_lens = (src_lens * max_ratio).clamp(min=10).long()

        # delete words
        # do not delete if the hypothesis is only <bos> </eos>
        can_del_word = output_tokens.ne(self.pad).sum(1) > 2
        if can_del_word.sum() != 0:
            output_tokens, output_scores, attn = self.forward_word_del(
                encoder_outs,
                output_tokens,
                output_scores,
                attn,
                can_del_word,
            )

        # insert <unk> placeholders where new words may be needed
        can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lens
        if can_ins_mask.sum() != 0:
            output_tokens, output_scores = self.forward_mask_ins(
                encoder_outs,
                output_tokens,
                output_scores,
                can_ins_mask,
                eos_penalty,
                max_lens,
            )

        # fill the placeholders with predicted words
        can_ins_word = output_tokens.eq(self.unk).sum(1) > 0
        if can_ins_word.sum() != 0:
            output_tokens, output_scores, attn = self.forward_word_ins(
                encoder_outs,
                output_tokens,
                output_scores,
                attn,
                can_ins_word,
            )

        # trim columns that are padding for every sentence in the batch
        cut_off = output_tokens.ne(self.pad).sum(1).max()
        output_tokens = output_tokens[:, :cut_off]
        output_scores = output_scores[:, :cut_off]
        attn = None if attn is None else attn[:, :cut_off, :]
        return decoder_out._replace(
            output_tokens=output_tokens,
            output_scores=output_scores,
            attn=attn,
            history=None,
        )

    def forward_word_del(
        self, encoder_outs, output_tokens, output_scores, attn, can_del_word
    ):
        word_del_score_avg = []
        word_del_attn_avg = []
        for model, encoder_out in zip(self.models, encoder_outs):
            word_del_out, word_del_attn = model.decoder.forward_word_del(
                _skip(output_tokens, can_del_word),
                _skip_encoder_out(model.encoder, encoder_out, can_del_word),
            )
            word_del_score = F.log_softmax(word_del_out, 2)
            word_del_score_avg.append(word_del_score)
            word_del_attn_avg.append(word_del_attn)
        # average the deletion log-probabilities in log space
        word_del_score_avg = torch.logsumexp(
            torch.stack(word_del_score_avg, dim=0), dim=0
        ) - math.log(len(self.models))
        word_del_pred = word_del_score_avg.max(-1)[1].bool()
        if word_del_attn_avg[0] is not None:
            # average attention across models (keeps the (bsz, tgt, src) shape)
            word_del_attn_avg = torch.stack(word_del_attn_avg, dim=0).mean(0)
        else:
            word_del_attn_avg = None

        _tokens, _scores, _attn = _apply_del_words(
            output_tokens[can_del_word],
            output_scores[can_del_word],
            word_del_attn_avg,
            word_del_pred,
            self.pad,
            self.bos,
            self.eos,
        )
        output_tokens = _fill(output_tokens, can_del_word, _tokens, self.pad)
        output_scores = _fill(output_scores, can_del_word, _scores, 0)
        attn = _fill(attn, can_del_word, _attn, 0.0)
        return output_tokens, output_scores, attn

    def forward_mask_ins(
        self,
        encoder_outs,
        output_tokens,
        output_scores,
        can_ins_mask,
        eos_penalty,
        max_lens,
    ):
        mask_ins_score_avg = []
        for model, encoder_out in zip(self.models, encoder_outs):
            mask_ins_out, _ = model.decoder.forward_mask_ins(
                _skip(output_tokens, can_ins_mask),
                _skip_encoder_out(model.encoder, encoder_out, can_ins_mask),
            )
            mask_ins_score = F.log_softmax(mask_ins_out, 2)
            if eos_penalty > 0.0:
                # penalize predicting zero insertions, which biases the
                # ensemble against producing overly short outputs
                mask_ins_score[:, :, 0] -= eos_penalty
            mask_ins_score_avg.append(mask_ins_score)
        # average the placeholder-insertion log-probabilities in log space
        mask_ins_score_avg = torch.logsumexp(
            torch.stack(mask_ins_score_avg, dim=0), dim=0
        ) - math.log(len(self.models))
        mask_ins_pred = mask_ins_score_avg.max(-1)[1]
        # cap the number of insertions so sentences respect max_lens
        mask_ins_pred = torch.min(
            mask_ins_pred, max_lens[can_ins_mask, None].expand_as(mask_ins_pred)
        )
        _tokens, _scores = _apply_ins_masks(
            output_tokens[can_ins_mask],
            output_scores[can_ins_mask],
            mask_ins_pred,
            self.pad,
            self.unk,
            self.eos,
        )
        output_tokens = _fill(output_tokens, can_ins_mask, _tokens, self.pad)
        output_scores = _fill(output_scores, can_ins_mask, _scores, 0)
        return output_tokens, output_scores

    def forward_word_ins(
        self, encoder_outs, output_tokens, output_scores, attn, can_ins_word
    ):
        word_ins_score_avg = []
        word_ins_attn_avg = []
        for model, encoder_out in zip(self.models, encoder_outs):
            word_ins_out, word_ins_attn = model.decoder.forward_word_ins(
                _skip(output_tokens, can_ins_word),
                _skip_encoder_out(model.encoder, encoder_out, can_ins_word),
            )
            word_ins_score = F.log_softmax(word_ins_out, 2)
            word_ins_score_avg.append(word_ins_score)
            word_ins_attn_avg.append(word_ins_attn)
        # average the word-prediction log-probabilities in log space
        word_ins_score_avg = torch.logsumexp(
            torch.stack(word_ins_score_avg, dim=0), dim=0
        ) - math.log(len(self.models))
        if word_ins_attn_avg[0] is not None:
            # average attention across models (keeps the (bsz, tgt, src) shape)
            word_ins_attn_avg = torch.stack(word_ins_attn_avg, dim=0).mean(0)
        else:
            word_ins_attn_avg = None
        word_ins_score_max, word_ins_pred = word_ins_score_avg.max(-1)

        _tokens, _scores = _apply_ins_words(
            output_tokens[can_ins_word],
            output_scores[can_ins_word],
            word_ins_pred,
            word_ins_score_max,
            self.unk,
        )

        output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad)
        output_scores = _fill(output_scores, can_ins_word, _scores, 0)
        # use the ensemble-averaged attention, not the last model's
        attn = _fill(attn, can_ins_word, word_ins_attn_avg, 0.0)
        return output_tokens, output_scores, attn

    def initialize_output_tokens(self, encoder_outs, src_tokens):
        # initialization is deterministic given src_tokens, so delegating
        # to the first model is sufficient for the whole ensemble
        return self.models[0].initialize_output_tokens(encoder_outs[0], src_tokens)
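

# A minimal usage sketch (illustrative only): `task`, `src_tokens`,
# `src_lengths`, `max_iter`, and the checkpoint names are hypothetical, and
# in practice fairseq's iterative refinement generator drives this loop.
#
#     from fairseq import checkpoint_utils
#
#     models, _ = checkpoint_utils.load_model_ensemble(
#         ["levt_a.pt", "levt_b.pt"], task=task
#     )
#     ensemble = EnsembleLevT(models)
#     encoder_outs = ensemble.forward_encoder([src_tokens, src_lengths])
#     decoder_out = ensemble.initialize_output_tokens(encoder_outs, src_tokens)
#     for _ in range(max_iter):
#         decoder_out = ensemble.forward_decoder(
#             decoder_out, encoder_outs, eos_penalty=0.0, max_ratio=2.0
#         )
#     tokens = decoder_out.output_tokens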