""" |
|
Translate pre-processed data with a trained model. |
|
""" |
|
import ast
import logging
import math
import os
import sys
from argparse import Namespace
from itertools import chain

import numpy as np
import torch

from fairseq import checkpoint_utils, options, scoring, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar
from fairseq.logging.meters import StopwatchMeter, TimeMeter
from omegaconf import DictConfig
|
|
def main(cfg: DictConfig):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    assert cfg.common_eval.path is not None, "--path required for generation!"
    assert (
        not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        cfg.generation.replace_unk is None or cfg.dataset.dataset_impl == "raw"
    ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)"

    if cfg.common_eval.results_path is not None:
        os.makedirs(cfg.common_eval.results_path, exist_ok=True)
        output_path = os.path.join(
            cfg.common_eval.results_path,
            "generate-{}.txt".format(cfg.dataset.gen_subset),
        )
        with open(output_path, "w", buffering=1, encoding="utf-8") as h:
            return _main(cfg, h)
    else:
        return _main(cfg, sys.stdout)
|
|
def get_symbols_to_strip_from_output(generator):
    if hasattr(generator, "symbols_to_strip_from_output"):
        return generator.symbols_to_strip_from_output
    else:
        return {generator.eos}
|
|
def _main(cfg: DictConfig, output_file):
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "INFO").upper(),
        stream=output_file,
    )
    logger = logging.getLogger("fairseq_cli.generate")

    utils.import_user_module(cfg.common)

    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
        cfg.dataset.max_tokens = 12000
    logger.info(cfg)

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu
|
    # Load dataset splits
    task = tasks.setup_task(cfg.task)

    # Set dictionaries
    try:
        src_dict = getattr(task, "source_dictionary", None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    overrides = ast.literal_eval(cfg.common_eval.model_overrides)

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )
|
    # Load the dataset after the checkpoint so the saved task config can be used
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

    if cfg.generation.lm_path is not None:
        overrides["data"] = cfg.task.data

        try:
            lms, _ = checkpoint_utils.load_model_ensemble(
                [cfg.generation.lm_path], arg_overrides=overrides, task=None
            )
        except Exception:
            logger.warning(
                f"Failed to load language model! Please make sure that the language model dict is the same "
                f"as target dict and is located in the data dir ({cfg.task.data})"
            )
            raise

        assert len(lms) == 1
    else:
        lms = [None]
|
    # Optimize ensemble for generation
    for model in chain(models, lms):
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to alignment dictionary)
    align_dict = utils.load_align_dict(cfg.generation.replace_unk)
|
    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]
        ),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    gen_timer = StopwatchMeter()
|
    # Initialize generator
    extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight}
    generator = task.build_generator(
        models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs
    )

    # Handle tokenization and BPE
    tokenizer = task.build_tokenizer(cfg.tokenizer)
    bpe = task.build_bpe(cfg.bpe)
|
    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    scorer = scoring.build_scorer(cfg.scoring, tgt_dict)
|
    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if "net_input" not in sample:
            continue

        prefix_tokens = None
        if cfg.generation.prefix_size > 0:
            prefix_tokens = sample["target"][:, : cfg.generation.prefix_size]

        constraints = None
        if "constraints" in sample:
            constraints = sample["constraints"]

        gen_timer.start()
        hypos = task.inference_step(
            generator,
            models,
            sample,
            prefix_tokens=prefix_tokens,
            constraints=constraints,
        )
        num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
        gen_timer.stop(num_generated_tokens)
|
        for i, sample_id in enumerate(sample["id"].tolist()):
            has_target = sample["target"] is not None

            # Remove padding
            if "src_tokens" in sample["net_input"]:
                src_tokens = utils.strip_pad(
                    sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
                )
            else:
                src_tokens = None

            target_tokens = None
            if has_target:
                target_tokens = (
                    utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu()
                )

            # Either retrieve the original sentences or regenerate them from tokens
            if align_dict is not None:
                src_str = task.dataset(cfg.dataset.gen_subset).src.get_original_text(
                    sample_id
                )
                target_str = task.dataset(cfg.dataset.gen_subset).tgt.get_original_text(
                    sample_id
                )
            else:
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens, cfg.common_eval.post_process)
                else:
                    src_str = ""
                if has_target:
                    target_str = tgt_dict.string(
                        target_tokens,
                        cfg.common_eval.post_process,
                        escape_unk=True,
                        extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                            generator
                        ),
                    )

            src_str = decode_fn(src_str)
            if has_target:
                target_str = decode_fn(target_str)

            if not cfg.common_eval.quiet:
                if src_dict is not None:
                    print("S-{}\t{}".format(sample_id, src_str), file=output_file)
                if has_target:
                    print("T-{}\t{}".format(sample_id, target_str), file=output_file)
|
            # Process top predictions
            for j, hypo in enumerate(hypos[i][: cfg.generation.nbest]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo["tokens"].int().cpu(),
                    src_str=src_str,
                    alignment=hypo["alignment"],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=cfg.common_eval.post_process,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
                )
                detok_hypo_str = decode_fn(hypo_str)
                if not cfg.common_eval.quiet:
                    score = hypo["score"] / math.log(2)  # convert to base 2
                    # original hypothesis (after tokenization and BPE)
                    print(
                        "H-{}\t{}\t{}".format(sample_id, score, hypo_str),
                        file=output_file,
                    )
                    # detokenized hypothesis
                    print(
                        "D-{}\t{}\t{}".format(sample_id, score, detok_hypo_str),
                        file=output_file,
                    )
                    print(
                        "P-{}\t{}".format(
                            sample_id,
                            " ".join(
                                map(
                                    lambda x: "{:.4f}".format(x),
                                    # convert from base e to base 2
                                    hypo["positional_scores"]
                                    .div_(math.log(2))
                                    .tolist(),
                                )
                            ),
                        ),
                        file=output_file,
                    )

                    if cfg.generation.print_alignment == "hard":
                        print(
                            "A-{}\t{}".format(
                                sample_id,
                                " ".join(
                                    [
                                        "{}-{}".format(src_idx, tgt_idx)
                                        for src_idx, tgt_idx in alignment
                                    ]
                                ),
                            ),
                            file=output_file,
                        )
                    if cfg.generation.print_alignment == "soft":
                        print(
                            "A-{}\t{}".format(
                                sample_id,
                                " ".join(
                                    [",".join(src_probs) for src_probs in alignment]
                                ),
                            ),
                            file=output_file,
                        )

                    if cfg.generation.print_step:
                        print(
                            "I-{}\t{}".format(sample_id, hypo["steps"]),
                            file=output_file,
                        )

                    if cfg.generation.retain_iter_history:
                        for step, h in enumerate(hypo["history"]):
                            _, h_str, _ = utils.post_process_prediction(
                                hypo_tokens=h["tokens"].int().cpu(),
                                src_str=src_str,
                                alignment=None,
                                align_dict=None,
                                tgt_dict=tgt_dict,
                                remove_bpe=None,
                            )
                            print(
                                "E-{}_{}\t{}".format(sample_id, step, h_str),
                                file=output_file,
                            )

                # Score only the top hypothesis
                if has_target and j == 0:
                    if align_dict is not None or cfg.common_eval.post_process is not None:
                        # Convert back to tokens for evaluation with unk replacement
                        # and/or without BPE
                        target_tokens = tgt_dict.encode_line(
                            target_str, add_if_not_exist=True
                        )
                        hypo_tokens = tgt_dict.encode_line(
                            detok_hypo_str, add_if_not_exist=True
                        )
                    if hasattr(scorer, "add_string"):
                        scorer.add_string(target_str, detok_hypo_str)
                    else:
                        scorer.add(target_tokens, hypo_tokens)
|
        wps_meter.update(num_generated_tokens)
        progress.log({"wps": round(wps_meter.avg)})
        num_sentences += (
            sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
        )

    logger.info("NOTE: hypothesis and token scores are output in base 2")
    logger.info(
        "Translated {:,} sentences ({:,} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format(
            num_sentences,
            gen_timer.n,
            gen_timer.sum,
            num_sentences / gen_timer.sum,
            1.0 / gen_timer.avg,
        )
    )
|
    if has_target:
        if cfg.bpe and not cfg.generation.sacrebleu:
            if cfg.common_eval.post_process:
                logger.warning(
                    "BLEU score is being computed by splitting the detokenized string on spaces; "
                    "this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
                )
            else:
                logger.warning(
                    "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, "
                    "not on proper words. Use --sacrebleu for standard 13a BLEU tokenization"
                )

        # use print to be consistent with the other main outputs (S-, T-, H-, D-, ...)
        print(
            "Generate {} with beam={}: {}".format(
                cfg.dataset.gen_subset, cfg.generation.beam, scorer.result_string()
            ),
            file=output_file,
        )

    return scorer
|
|
def cli_main():
    parser = options.get_generation_parser()
    args = options.parse_args_and_arch(parser)
    main(args)
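
# Programmatic use is also possible by building args with the same parser that
# cli_main() uses (a sketch; the data-bin and checkpoint paths are placeholders):
#
#   from fairseq import options
#   parser = options.get_generation_parser()
#   args = options.parse_args_and_arch(
#       parser,
#       input_args=["data-bin/example", "--path", "checkpoints/checkpoint_best.pt"],
#   )
#   scorer = main(args)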
|
if __name__ == "__main__":
    cli_main()