# fairseq/dataclass/configs.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import sys
from dataclasses import _MISSING_TYPE, dataclass, field
from typing import Any, List, Optional

import torch

from fairseq.dataclass.constants import (
    DATASET_IMPL_CHOICES,
    DDP_BACKEND_CHOICES,
    DDP_COMM_HOOK_CHOICES,
    GENERATION_CONSTRAINTS_CHOICES,
    GENERATION_DECODING_FORMAT_CHOICES,
    LOG_FORMAT_CHOICES,
    PIPELINE_CHECKPOINT_CHOICES,
    PRINT_ALIGNMENT_CHOICES,
    ZERO_SHARDING_CHOICES,
)
from omegaconf import II, MISSING


@dataclass
class FairseqDataclass:
    """fairseq base dataclass that supports fetching attributes and metas"""

    _name: Optional[str] = None

    @staticmethod
    def name():
        return None

    def _get_all_attributes(self) -> List[str]:
        return [k for k in self.__dataclass_fields__.keys()]

    def _get_meta(
        self, attribute_name: str, meta: str, default: Optional[Any] = None
    ) -> Any:
        return self.__dataclass_fields__[attribute_name].metadata.get(meta, default)

    def _get_name(self, attribute_name: str) -> str:
        return self.__dataclass_fields__[attribute_name].name

    def _get_default(self, attribute_name: str) -> Any:
        if hasattr(self, attribute_name):
            if str(getattr(self, attribute_name)).startswith("${"):
                return str(getattr(self, attribute_name))
            elif str(self.__dataclass_fields__[attribute_name].default).startswith(
                "${"
            ):
                return str(self.__dataclass_fields__[attribute_name].default)
            elif (
                getattr(self, attribute_name)
                != self.__dataclass_fields__[attribute_name].default
            ):
                return getattr(self, attribute_name)

        f = self.__dataclass_fields__[attribute_name]
        if not isinstance(f.default_factory, _MISSING_TYPE):
            return f.default_factory()
        return f.default

    def _get_type(self, attribute_name: str) -> Any:
        return self.__dataclass_fields__[attribute_name].type

    def _get_help(self, attribute_name: str) -> Any:
        return self._get_meta(attribute_name, "help")

    def _get_argparse_const(self, attribute_name: str) -> Any:
        return self._get_meta(attribute_name, "argparse_const")

    def _get_argparse_alias(self, attribute_name: str) -> Any:
        return self._get_meta(attribute_name, "argparse_alias")

    def _get_choices(self, attribute_name: str) -> Any:
        return self._get_meta(attribute_name, "choices")
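
# Illustrative sketch (not part of the original module): the helper methods
# above expose the argparse-oriented metadata stored on each field, e.g. with
# CommonConfig defined below:
#
#     cfg = CommonConfig()
#     cfg._get_help("log_interval")       # "log progress every N batches ..."
#     cfg._get_default("seed")            # 1
#     cfg._get_argparse_alias("cpu")      # None (no alias declared)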


@dataclass
class CommonConfig(FairseqDataclass):
    # This is the core dataclass, holding common parameters shared by all jobs.
    # Please add params to one of the other dataclasses if they are used for a
    # particular purpose or task, such as those dedicated to `distributed training`,
    # `optimization`, etc.
no_progress_bar: bool = field(
default=False, metadata={"help": "disable progress bar"}
)
log_interval: int = field(
default=100,
metadata={
"help": "log progress every N batches (when progress bar is disabled)"
},
)
log_format: Optional[LOG_FORMAT_CHOICES] = field(
default=None, metadata={"help": "log format to use"}
)
log_file: Optional[str] = field(
default=None, metadata={"help": "log file to copy metrics to."}
)
tensorboard_logdir: Optional[str] = field(
default=None,
metadata={
"help": "path to save logs for tensorboard, should match --logdir "
"of running tensorboard (default: no tensorboard logging)"
},
)
wandb_project: Optional[str] = field(
default=None,
metadata={"help": "Weights and Biases project name to use for logging"},
)
azureml_logging: Optional[bool] = field(
default=False, metadata={"help": "Log scalars to AzureML context"},
)
seed: int = field(
default=1, metadata={"help": "pseudo random number generator seed"}
)
cpu: bool = field(default=False, metadata={"help": "use CPU instead of CUDA"})
tpu: bool = field(default=False, metadata={"help": "use TPU instead of CUDA"})
bf16: bool = field(default=False, metadata={"help": "use bfloat16; implies --tpu"})
memory_efficient_bf16: bool = field(
default=False,
metadata={
"help": "use a memory-efficient version of BF16 training; implies --bf16"
},
)
fp16: bool = field(default=False, metadata={"help": "use FP16"})
memory_efficient_fp16: bool = field(
default=False,
metadata={
"help": "use a memory-efficient version of FP16 training; implies --fp16"
},
)
fp16_no_flatten_grads: bool = field(
default=False, metadata={"help": "don't flatten FP16 grads tensor"}
)
fp16_init_scale: int = field(
default=2 ** 7, metadata={"help": "default FP16 loss scale"}
)
fp16_scale_window: Optional[int] = field(
default=None,
metadata={"help": "number of updates before increasing loss scale"},
)
fp16_scale_tolerance: float = field(
default=0.0,
metadata={
"help": "pct of updates that can overflow before decreasing the loss scale"
},
)
on_cpu_convert_precision: bool = field(
default=False,
metadata={
"help": "if set, the floating point conversion to fp16/bf16 runs on CPU. "
"This reduces bus transfer time and GPU memory usage."
}
)
min_loss_scale: float = field(
default=1e-4,
metadata={"help": "minimum FP16/AMP loss scale, after which training is stopped"},
)
threshold_loss_scale: Optional[float] = field(
default=None, metadata={"help": "threshold FP16 loss scale from below"}
)
amp: bool = field(default=False, metadata={"help": "use automatic mixed precision"})
amp_batch_retries: int = field(
default=2,
metadata={"help": "number of retries of same batch after reducing loss scale with AMP"},
)
amp_init_scale: int = field(
default=2 ** 7, metadata={"help": "default AMP loss scale"}
)
amp_scale_window: Optional[int] = field(
default=None,
metadata={"help": "number of updates before increasing AMP loss scale"},
)
user_dir: Optional[str] = field(
default=None,
metadata={
"help": "path to a python module containing custom extensions (tasks and/or architectures)"
},
)
empty_cache_freq: int = field(
default=0,
metadata={"help": "how often to clear the PyTorch CUDA cache (0 to disable)"},
)
all_gather_list_size: int = field(
default=16384,
metadata={"help": "number of bytes reserved for gathering stats from workers"},
)
model_parallel_size: int = field(
default=1, metadata={"help": "total number of GPUs to parallelize model over"}
)
quantization_config_path: Optional[str] = field(
default=None, metadata={"help": "path to quantization config file"}
)
profile: bool = field(
default=False, metadata={"help": "enable autograd profiler emit_nvtx"}
)
reset_logging: bool = field(
default=False,
metadata={
"help": "when using Hydra, reset the logging at the beginning of training"
},
)
suppress_crashes: bool = field(
default=False,
metadata={
"help": "suppress crashes when training with the hydra_train entry point so that the "
"main method can return a value (useful for sweeps)"
},
)
use_plasma_view: bool = field(
default=False, metadata={"help": "Store indices and sizes in shared memory"}
)
plasma_path: Optional[str] = field(
default="/tmp/plasma",
metadata={
"help": "path to run plasma_store, defaults to /tmp/plasma. Paths outside /tmp tend to fail."
},
)
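
# A minimal sketch of how these dataclasses are typically materialized as
# omegaconf structured configs (assumes omegaconf 2.x, as imported above):
#
#     from omegaconf import OmegaConf
#     cfg = OmegaConf.structured(CommonConfig(fp16=True))
#     cfg.log_interval             # 100
#     print(OmegaConf.to_yaml(cfg))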


@dataclass
class DistributedTrainingConfig(FairseqDataclass):
distributed_world_size: int = field(
default=max(1, torch.cuda.device_count()),
metadata={
"help": "total number of GPUs across all nodes (default: all visible GPUs)"
},
)
distributed_num_procs: Optional[int] = field(
default=max(1, torch.cuda.device_count()),
metadata={
"help": "total number of processes to fork (default: all visible GPUs)"
},
)
distributed_rank: Optional[int] = field(
default=0, metadata={"help": "rank of the current worker"}
)
distributed_backend: str = field(
default="nccl", metadata={"help": "distributed backend"}
)
distributed_init_method: Optional[str] = field(
default=None,
metadata={
"help": "typically tcp://hostname:port that will be used to "
"establish initial connetion"
},
)
distributed_port: int = field(
default=-1,
metadata={
"help": "port number (not required if using --distributed-init-method)"
},
)
device_id: int = field(
default=0,
metadata={
"help": "which GPU to use (usually configured automatically)",
"argparse_alias": "--local_rank",
},
)
distributed_no_spawn: bool = field(
default=False,
metadata={
"help": "do not spawn multiple processes even if multiple GPUs are visible"
},
)
ddp_backend: DDP_BACKEND_CHOICES = field(
default="pytorch_ddp", metadata={"help": "DistributedDataParallel backend"}
)
ddp_comm_hook: DDP_COMM_HOOK_CHOICES = field(
default="none", metadata={"help": "communication hook"}
)
bucket_cap_mb: int = field(
default=25, metadata={"help": "bucket size for reduction"}
)
fix_batches_to_gpus: bool = field(
default=False,
metadata={
"help": "don't shuffle batches between GPUs; this reduces overall "
"randomness and may affect precision but avoids the cost of re-reading the data"
},
)
find_unused_parameters: bool = field(
default=False,
metadata={
"help": "disable unused parameter detection (not applicable to "
"--ddp-backend=legacy_ddp)"
},
)
fast_stat_sync: bool = field(
default=False,
metadata={"help": "[deprecated] this is now defined per Criterion"},
)
heartbeat_timeout: int = field(
default=-1,
metadata={
"help": "kill the job if no progress is made in N seconds; "
"set to -1 to disable"
},
)
broadcast_buffers: bool = field(
default=False,
metadata={
"help": "Copy non-trainable parameters between GPUs, such as "
"batchnorm population statistics"
},
)
slowmo_momentum: Optional[float] = field(
default=None,
metadata={
"help": "SlowMo momentum term; by default use 0.0 for 16 GPUs, "
"0.2 for 32 GPUs; 0.5 for 64 GPUs, 0.6 for > 64 GPUs"
},
)
slowmo_algorithm: str = field(
default="LocalSGD", metadata={"help": "whether to use LocalSGD or SGP"}
)
localsgd_frequency: int = field(
default=3, metadata={"help": "Local SGD allreduce frequency"}
)
nprocs_per_node: int = field(
default=max(1, torch.cuda.device_count()),
metadata={
"help": "number of GPUs in each node. An allreduce operation across GPUs in "
"a node is very fast. Hence, we do allreduce across GPUs in a node, "
"and gossip across different nodes"
},
)
pipeline_model_parallel: bool = field(
default=False,
metadata={"help": "if set, use pipeline model parallelism across GPUs"},
)
pipeline_balance: Optional[str] = field(
default=None,
metadata={
"help": "partition the model into N_K pieces, where each piece "
"contains N_i layers. The sum(args.pipeline_balance) "
"should equal the total number of layers in the model"
},
)
pipeline_devices: Optional[str] = field(
default=None,
metadata={
"help": "a list of device indices indicating which device to place "
"each of the N_K partitions. The length of this list should "
"equal the length of the --pipeline-balance argument"
},
)
pipeline_chunks: Optional[int] = field(
default=0, metadata={"help": "microbatch count for pipeline model parallelism"}
)
pipeline_encoder_balance: Optional[str] = field(
default=None,
metadata={
"help": "partition the pipeline parallel encoder into N_K pieces, where each piece "
"contains N_i layers. The sum(args.pipeline_encoder_balance) "
"should equal the total number of encoder layers in the model"
},
)
pipeline_encoder_devices: Optional[str] = field(
default=None,
metadata={
"help": "a list of device indices indicating which device to place "
"each of the N_K partitions. The length of this list should "
"equal the length of the --pipeline-encoder-balance argument"
},
)
pipeline_decoder_balance: Optional[str] = field(
default=None,
metadata={
"help": "partition the pipeline parallel decoder into N_K pieces, where each piece "
"contains N_i layers. The sum(args.pipeline_decoder_balance) "
"should equal the total number of decoder layers in the model"
},
)
pipeline_decoder_devices: Optional[str] = field(
default=None,
metadata={
"help": "a list of device indices indicating which device to place "
"each of the N_K partitions. The length of this list should "
"equal the length of the --pipeline-decoder-balance argument"
},
)
pipeline_checkpoint: PIPELINE_CHECKPOINT_CHOICES = field(
default="never",
metadata={"help": "checkpointing mode for pipeline model parallelism"},
)
zero_sharding: ZERO_SHARDING_CHOICES = field(
default="none", metadata={"help": "ZeRO sharding"}
)
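
    # II("path.to.field") creates an omegaconf interpolation ("${path.to.field}"):
    # when this dataclass is nested inside the top-level FairseqConfig, these
    # fields resolve to the values of their counterparts under `common`.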
fp16: bool = II("common.fp16")
memory_efficient_fp16: bool = II("common.memory_efficient_fp16")
tpu: bool = II("common.tpu")
# configuration for --ddp-backend=fully_sharded
no_reshard_after_forward: bool = field(
default=False, metadata={"help": "don't reshard parameters after forward pass"},
)
fp32_reduce_scatter: bool = field(
default=False, metadata={"help": "reduce-scatter grads in FP32"},
)
cpu_offload: bool = field(
default=False, metadata={"help": "offload FP32 params to CPU"}
)
use_sharded_state: bool = field(
default=False, metadata={"help": "use sharded checkpoint files"},
)


@dataclass
class DatasetConfig(FairseqDataclass):
num_workers: int = field(
default=1, metadata={"help": "how many subprocesses to use for data loading"}
)
skip_invalid_size_inputs_valid_test: bool = field(
default=False,
metadata={"help": "ignore too long or too short lines in valid and test set"},
)
max_tokens: Optional[int] = field(
default=None, metadata={"help": "maximum number of tokens in a batch"}
)
batch_size: Optional[int] = field(
default=None,
metadata={
"help": "number of examples in a batch",
"argparse_alias": "--max-sentences",
},
)
required_batch_size_multiple: int = field(
        default=8, metadata={"help": "batch size will be a multiple of this value"}
)
required_seq_len_multiple: int = field(
default=1,
metadata={
"help": "maximum sequence length in batch will be a multiplier of this value"
},
)
dataset_impl: Optional[DATASET_IMPL_CHOICES] = field(
default=None, metadata={"help": "output dataset implementation"}
)
data_buffer_size: int = field(
default=10, metadata={"help": "Number of batches to preload"}
)
train_subset: str = field(
default="train",
metadata={"help": "data subset to use for training (e.g. train, valid, test)"},
)
valid_subset: str = field(
default="valid",
metadata={
"help": "comma separated list of data subsets to use for validation"
" (e.g. train, valid, test)"
},
)
combine_valid_subsets: Optional[bool] = field(
default=None,
metadata={
"help": "comma separated list of data subsets to use for validation"
" (e.g. train, valid, test)",
"argparse_alias": "--combine-val",
},
)
ignore_unused_valid_subsets: Optional[bool] = field(
default=False,
metadata={"help": "do not raise error if valid subsets are ignored"},
)
validate_interval: int = field(
default=1, metadata={"help": "validate every N epochs"}
)
validate_interval_updates: int = field(
default=0, metadata={"help": "validate every N updates"}
)
validate_after_updates: int = field(
default=0, metadata={"help": "dont validate until reaching this many updates"}
)
fixed_validation_seed: Optional[int] = field(
default=None, metadata={"help": "specified random seed for validation"}
)
disable_validation: bool = field(
default=False, metadata={"help": "disable validation"}
)
max_tokens_valid: Optional[int] = field(
default=II("dataset.max_tokens"),
metadata={
"help": "maximum number of tokens in a validation batch"
" (defaults to --max-tokens)"
},
)
batch_size_valid: Optional[int] = field(
default=II("dataset.batch_size"),
metadata={
"help": "batch size of the validation batch (defaults to --batch-size)",
"argparse_alias": "--max-sentences-valid",
},
)
    max_valid_steps: Optional[int] = field(
        default=None,
        metadata={"help": "how many batches to evaluate", "argparse_alias": "--nval"},
    )
curriculum: int = field(
default=0, metadata={"help": "don't shuffle batches for first N epochs"}
)
gen_subset: str = field(
default="test",
metadata={"help": "data subset to generate (train, valid, test)"},
)
num_shards: int = field(
default=1, metadata={"help": "shard generation over N shards"}
)
shard_id: int = field(
default=0, metadata={"help": "id of the shard to generate (id < num_shards)"}
)
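
# Note (illustrative, not in the original file): batches are sized either by
# token count or by example count, e.g.
#
#     DatasetConfig(max_tokens=4096)   # dynamic batches of up to 4096 tokens
#     DatasetConfig(batch_size=32)     # fixed 32 examples per batch
#
# max_tokens_valid and batch_size_valid default to ${dataset.max_tokens} and
# ${dataset.batch_size} via II, so validation inherits the training settings
# unless explicitly overridden.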


@dataclass
class OptimizationConfig(FairseqDataclass):
max_epoch: int = field(
default=0, metadata={"help": "force stop training at specified epoch"}
)
max_update: int = field(
default=0, metadata={"help": "force stop training at specified update"}
)
stop_time_hours: float = field(
default=0,
metadata={
"help": "force stop training after specified cumulative time (if >0)"
},
)
clip_norm: float = field(
default=0.0, metadata={"help": "clip threshold of gradients"}
)
sentence_avg: bool = field(
default=False,
metadata={
"help": "normalize gradients by the number of sentences in a batch"
" (default is to normalize by number of tokens)"
},
)
update_freq: List[int] = field(
default_factory=lambda: [1],
metadata={"help": "update parameters every N_i batches, when in epoch i"},
)
lr: List[float] = field(
default_factory=lambda: [0.25],
metadata={
"help": "learning rate for the first N epochs; all epochs >N using LR_N"
" (note: this may be interpreted differently depending on --lr-scheduler)"
},
)
stop_min_lr: float = field(
default=-1.0,
metadata={"help": "stop training when the learning rate reaches this minimum"},
)
use_bmuf: bool = field(
default=False,
metadata={
"help": "specify global optimizer for syncing models on different GPUs/shards"
},
)


@dataclass
class CheckpointConfig(FairseqDataclass):
save_dir: str = field(
default="checkpoints", metadata={"help": "path to save checkpoints"}
)
restore_file: str = field(
default="checkpoint_last.pt",
metadata={
"help": "filename from which to load checkpoint "
"(default: <save-dir>/checkpoint_last.pt"
},
)
finetune_from_model: Optional[str] = field(
default=None,
metadata={
"help": "finetune from a pretrained model; note that meters and lr scheduler will be reset"
},
)
reset_dataloader: bool = field(
default=False,
metadata={
"help": "if set, does not reload dataloader state from the checkpoint"
},
)
reset_lr_scheduler: bool = field(
default=False,
metadata={
"help": "if set, does not load lr scheduler state from the checkpoint"
},
)
reset_meters: bool = field(
default=False,
metadata={"help": "if set, does not load meters from the checkpoint"},
)
reset_optimizer: bool = field(
default=False,
metadata={"help": "if set, does not load optimizer state from the checkpoint"},
)
optimizer_overrides: str = field(
default="{}",
metadata={
"help": "a dictionary used to override optimizer args when loading a checkpoint"
},
)
save_interval: int = field(
default=1, metadata={"help": "save a checkpoint every N epochs"}
)
save_interval_updates: int = field(
default=0, metadata={"help": "save a checkpoint (and validate) every N updates"}
)
keep_interval_updates: int = field(
default=-1,
metadata={
"help": "keep the last N checkpoints saved with --save-interval-updates"
},
)
keep_interval_updates_pattern: int = field(
default=-1,
metadata={
"help": "when used with --keep-interval-updates, skips deleting "
"any checkpoints with update X where "
"X %% keep_interval_updates_pattern == 0"
},
)
keep_last_epochs: int = field(
default=-1, metadata={"help": "keep last N epoch checkpoints"}
)
keep_best_checkpoints: int = field(
default=-1, metadata={"help": "keep best N checkpoints based on scores"}
)
no_save: bool = field(
default=False, metadata={"help": "don't save models or checkpoints"}
)
no_epoch_checkpoints: bool = field(
default=False, metadata={"help": "only store last and best checkpoints"}
)
no_last_checkpoints: bool = field(
default=False, metadata={"help": "don't store last checkpoints"}
)
no_save_optimizer_state: bool = field(
default=False,
metadata={"help": "don't save optimizer-state as part of checkpoint"},
)
best_checkpoint_metric: str = field(
default="loss", metadata={"help": 'metric to use for saving "best" checkpoints'}
)
maximize_best_checkpoint_metric: bool = field(
default=False,
metadata={
"help": 'select the largest metric value for saving "best" checkpoints'
},
)
patience: int = field(
default=-1,
metadata={
"help": (
"early stop training if valid performance doesn't "
"improve for N consecutive validation runs; note "
"that this is influenced by --validate-interval"
)
},
)
checkpoint_suffix: str = field(
default="", metadata={"help": "suffix to add to the checkpoint file name"}
)
checkpoint_shard_count: int = field(
default=1,
metadata={
"help": "Number of shards containing the checkpoint - "
"if the checkpoint is over 300GB, it is preferable "
"to split it into shards to prevent OOM on CPU while loading "
"the checkpoint"
},
)
load_checkpoint_on_all_dp_ranks: bool = field(
default=False,
metadata={
"help": "load checkpoints on all data parallel devices "
"(default: only load on rank 0 and broadcast to other devices)"
},
)
write_checkpoints_asynchronously: bool = field(
default=False,
metadata={
"help": (
"Write checkpoints asynchronously in a separate "
"thread. NOTE: This feature is currently being tested."
),
"argparse_alias": "--save-async",
},
)
model_parallel_size: int = II("common.model_parallel_size")


@dataclass
class FairseqBMUFConfig(FairseqDataclass):
block_lr: float = field(
default=1, metadata={"help": "block learning rate for bmuf"}
)
block_momentum: float = field(
default=0.875, metadata={"help": "block momentum for bmuf"}
)
global_sync_iter: int = field(
        default=50, metadata={"help": "iteration interval for syncing the global model"}
)
warmup_iterations: int = field(
        default=500, metadata={"help": "number of warmup iterations during which the model is broadcast"}
)
use_nbm: bool = field(
default=False,
metadata={"help": "Specify whether you want to use classical BM / Nesterov BM"},
)
average_sync: bool = field(
default=False,
metadata={
"help": "Specify whether you want to average the local momentum after each sync"
},
)
distributed_world_size: int = II("distributed_training.distributed_world_size")


@dataclass
class GenerationConfig(FairseqDataclass):
beam: int = field(
default=5, metadata={"help": "beam size"},
)
nbest: int = field(
default=1, metadata={"help": "number of hypotheses to output"},
)
max_len_a: float = field(
default=0,
metadata={
"help": "generate sequences of maximum length ax + b, where x is the source length"
},
)
max_len_b: int = field(
default=200,
metadata={
"help": "generate sequences of maximum length ax + b, where x is the source length"
},
)
min_len: int = field(
default=1, metadata={"help": "minimum generation length"},
)
match_source_len: bool = field(
default=False, metadata={"help": "generations should match the source length"},
)
unnormalized: bool = field(
default=False, metadata={"help": "compare unnormalized hypothesis scores"},
)
no_early_stop: bool = field(
default=False, metadata={"help": "deprecated"},
)
no_beamable_mm: bool = field(
default=False, metadata={"help": "don't use BeamableMM in attention layers"},
)
lenpen: float = field(
default=1,
metadata={
"help": "length penalty: <1.0 favors shorter, >1.0 favors longer sentences"
},
)
unkpen: float = field(
default=0,
metadata={
"help": "unknown word penalty: <0 produces more unks, >0 produces fewer"
},
)
replace_unk: Optional[str] = field(
default=None,
metadata={
"help": "perform unknown replacement (optionally with alignment dictionary)",
"argparse_const": "@@ ",
},
)
sacrebleu: bool = field(
default=False, metadata={"help": "score with sacrebleu"},
)
score_reference: bool = field(
default=False, metadata={"help": "just score the reference translation"},
)
prefix_size: int = field(
default=0,
metadata={"help": "initialize generation by target prefix of given length"},
)
no_repeat_ngram_size: int = field(
default=0,
metadata={
"help": "ngram blocking such that this size ngram cannot be repeated in the generation"
},
)
sampling: bool = field(
default=False,
metadata={"help": "sample hypotheses instead of using beam search"},
)
sampling_topk: int = field(
default=-1,
metadata={"help": "sample from top K likely next words instead of all words"},
)
sampling_topp: float = field(
default=-1.0,
metadata={
"help": "sample from the smallest set whose cumulative probability mass exceeds p for next words"
},
)
constraints: Optional[GENERATION_CONSTRAINTS_CHOICES] = field(
default=None,
metadata={
"help": "enables lexically constrained decoding",
"argparse_const": "ordered",
},
)
temperature: float = field(
default=1.0, metadata={"help": "temperature for generation"},
)
diverse_beam_groups: int = field(
default=-1, metadata={"help": "number of groups for Diverse Beam Search"},
)
diverse_beam_strength: float = field(
default=0.5,
metadata={"help": "strength of diversity penalty for Diverse Beam Search"},
)
diversity_rate: float = field(
default=-1.0,
metadata={"help": "strength of diversity penalty for Diverse Siblings Search"},
)
print_alignment: Optional[PRINT_ALIGNMENT_CHOICES] = field(
default=None,
metadata={
"help": "if set, uses attention feedback to compute and print alignment to source tokens "
"(valid options are: hard, soft, otherwise treated as hard alignment)",
"argparse_const": "hard",
},
)
print_step: bool = field(
default=False, metadata={"help": "print steps"},
)
lm_path: Optional[str] = field(
default=None, metadata={"help": "path to lm checkpoint for lm fusion"},
)
lm_weight: float = field(
default=0.0, metadata={"help": "weight for lm probs for lm fusion"},
)
# arguments for iterative refinement generator
iter_decode_eos_penalty: float = field(
default=0.0,
metadata={"help": "if > 0.0, it penalized early-stopping in decoding."},
)
iter_decode_max_iter: int = field(
default=10, metadata={"help": "maximum iterations for iterative refinement."},
)
iter_decode_force_max_iter: bool = field(
default=False,
metadata={
"help": "if set, run exact the maximum number of iterations without early stop"
},
)
iter_decode_with_beam: int = field(
default=1,
metadata={
"help": "if > 1, model will generate translations varying by the lengths."
},
)
iter_decode_with_external_reranker: bool = field(
default=False,
metadata={
"help": "if set, the last checkpoint are assumed to be a reranker to rescore the translations"
},
)
retain_iter_history: bool = field(
default=False,
metadata={
"help": "if set, decoding returns the whole history of iterative refinement"
},
)
retain_dropout: bool = field(
default=False, metadata={"help": "Use dropout at inference time"},
)
# temporarily set to Any until https://github.com/facebookresearch/hydra/issues/1117 is fixed
# retain_dropout_modules: Optional[List[str]] = field(
retain_dropout_modules: Any = field(
default=None,
metadata={
"help": "if set, only retain dropout for the specified modules; "
"if not set, then dropout will be retained for all modules"
},
)
# special decoding format for advanced decoding.
decoding_format: Optional[GENERATION_DECODING_FORMAT_CHOICES] = field(
default=None,
metadata={"help": "special decoding format for advanced decoding."},
)
no_seed_provided: bool = field(
default=False,
metadata={"help": "if set, dont use seed for initializing random generators"},
)


@dataclass
class CommonEvalConfig(FairseqDataclass):
path: Optional[str] = field(
default=None, metadata={"help": "path(s) to model file(s), colon separated"},
)
post_process: Optional[str] = field(
default=None,
metadata={
"help": (
"post-process text by removing BPE, letter segmentation, etc. "
"Valid options can be found in fairseq.data.utils.post_process."
),
"argparse_const": "subword_nmt",
"argparse_alias": "--remove-bpe",
},
)
quiet: bool = field(default=False, metadata={"help": "only print final scores"})
model_overrides: str = field(
default="{}",
metadata={
"help": "a dictionary used to override model args at generation that were used during model training"
},
)
results_path: Optional[str] = field(
default=None, metadata={"help": "path to save eval results (optional)"}
)


@dataclass
class EvalLMConfig(FairseqDataclass):
output_word_probs: bool = field(
default=False,
metadata={
"help": "if set, outputs words and their predicted log probabilities to standard output"
},
)
output_word_stats: bool = field(
default=False,
metadata={
"help": "if set, outputs word statistics such as word count, average probability, etc"
},
)
context_window: int = field(
default=0,
metadata={
"help": "ensures that every evaluated token has access to a context of at least this size, if possible"
},
)
softmax_batch: int = field(
default=sys.maxsize,
metadata={
"help": "if BxT is more than this, will batch the softmax over vocab to this amount of tokens, in order to fit into GPU memory"
},
)


@dataclass
class InteractiveConfig(FairseqDataclass):
buffer_size: int = field(
default=0,
metadata={
"help": "read this many sentences into a buffer before processing them"
},
)
input: str = field(
default="-", metadata={"help": "file to read from; use - for stdin"},
)


@dataclass
class FairseqConfig(FairseqDataclass):
    # default_factory avoids sharing a single mutable default instance across
    # configs (and is required for dataclass defaults on Python >= 3.11)
    common: CommonConfig = field(default_factory=CommonConfig)
    common_eval: CommonEvalConfig = field(default_factory=CommonEvalConfig)
    distributed_training: DistributedTrainingConfig = field(
        default_factory=DistributedTrainingConfig
    )
    dataset: DatasetConfig = field(default_factory=DatasetConfig)
    optimization: OptimizationConfig = field(default_factory=OptimizationConfig)
    checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig)
    bmuf: FairseqBMUFConfig = field(default_factory=FairseqBMUFConfig)
    generation: GenerationConfig = field(default_factory=GenerationConfig)
    eval_lm: EvalLMConfig = field(default_factory=EvalLMConfig)
    interactive: InteractiveConfig = field(default_factory=InteractiveConfig)
    model: Any = MISSING
    task: Any = None
    criterion: Any = None
    optimizer: Any = None
    lr_scheduler: Any = None
    scoring: Any = None
    bpe: Any = None
    tokenizer: Any = None
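

if __name__ == "__main__":
    # Minimal demo (a sketch; this guard is not part of the original module):
    # build the full structured config and show that the II(...) fields above
    # resolve against their source values.
    from omegaconf import OmegaConf

    cfg = OmegaConf.structured(FairseqConfig())
    print(cfg.common.fp16)                # False
    print(cfg.distributed_training.fp16)  # False, resolved via ${common.fp16}
    cfg.common.fp16 = True
    print(cfg.distributed_training.fp16)  # True: the interpolation tracks its source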