# Spaces:
# Runtime error
# Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging | |
import os, glob | |
import sys | |
from typing import Dict, List, Optional, Tuple | |
import numpy as np | |
from dataclasses import dataclass, field | |
from fairseq import metrics, search | |
from fairseq.data import Dictionary, encoders | |
from fairseq.dataclass.configs import FairseqDataclass | |
from fairseq.tasks import register_task | |
from fairseq.tasks.fairseq_task import FairseqTask | |
from omegaconf import MISSING, II | |
import numpy as np | |
from argparse import Namespace | |
# When the process was started with no command-line arguments (only the script
# name in sys.argv), the sibling modules are imported as top-level modules
# (presumably a local debugging run); otherwise they are imported relative to
# the enclosing package.
# NOTE(review): keying the import style on len(sys.argv) is fragile -- any
# launcher that passes no arguments takes the DBG path. Confirm this is wanted.
DBG = len(sys.argv) == 1
if DBG:
    from hubert_dataset import AVHubertDataset
    from sequence_generator import SequenceGenerator
else:
    from .hubert_dataset import AVHubertDataset
    from .sequence_generator import SequenceGenerator

logger = logging.getLogger(__name__)
class LabelEncoder(object):
    """Label processor that maps a label string to dictionary indices.

    No EOS index is appended, and tokens missing from the dictionary are not
    added to it.
    """

    def __init__(self, dictionary: Dictionary) -> None:
        self.dictionary = dictionary

    def __call__(self, label: str) -> List[str]:
        # NOTE: despite the ``List[str]`` annotation, the value returned is
        # whatever ``Dictionary.encode_line`` produces (a tensor of indices).
        encode_options = dict(append_eos=False, add_if_not_exist=False)
        return self.dictionary.encode_line(label, **encode_options)
class LabelEncoderS2SToken(object):
    """Encode/decode sequence-to-sequence labels through an optional BPE model.

    Args:
        dictionary: target vocabulary mapping tokens to indices.
        bpe_tokenizer: object exposing ``encode``/``decode`` (as returned by
            ``fairseq.data.encoders.build_bpe``), or a falsy value to skip BPE.
    """

    def __init__(self, dictionary: Dictionary, bpe_tokenizer) -> None:
        self.bpe_tokenizer = bpe_tokenizer
        self.dictionary = dictionary

    def __call__(self, label: str) -> List[str]:
        """Lower-case, BPE-encode (when a tokenizer is set) and index ``label``.

        An EOS index is appended and unknown tokens are not added to the
        dictionary. Despite the annotation, the result is ``.long()`` of
        ``Dictionary.encode_line``'s output, i.e. a tensor of indices.
        """
        label = label.lower()
        # Mirror ``decode``: BPE is optional. Previously a missing tokenizer
        # crashed here with AttributeError even though ``decode`` allowed it
        # (``build_bpe`` presumably yields None when no BPE name is configured).
        if self.bpe_tokenizer:
            label = self.bpe_tokenizer.encode(label)
        return self.dictionary.encode_line(
            label, append_eos=True, add_if_not_exist=False,
        ).long()

    def decode(self, tok, symbols_ignore=None):
        """Convert indices ``tok`` back to text, undoing BPE when a tokenizer is set.

        ``symbols_ignore`` is forwarded to ``Dictionary.string`` as
        ``extra_symbols_to_ignore``.
        """
        tok = self.dictionary.string(tok, extra_symbols_to_ignore=symbols_ignore)
        if self.bpe_tokenizer:
            tok = self.bpe_tokenizer.decode(tok)
        return tok
@dataclass
class AVHubertPretrainingConfig(FairseqDataclass):
    """Task configuration for AV-HuBERT pre-training and fine-tuning.

    The ``@dataclass`` decorator is required. Without it, the ``field(...)``
    declarations below stay as bare ``dataclasses.Field`` objects on the class
    instead of becoming configuration fields.
    """

    # NOTE(review): this interpolates ``task.input_modality``. If this config
    # is itself mounted at the ``task`` node, the value refers to itself.
    # Confirm the intended source of this setting (originally marked "??").
    input_modality: str = II("task.input_modality")
    data: str = field(
        default=MISSING, metadata={"help": "path to data directory"}
    )
    labels: List[str] = field(
        default_factory=lambda: ["ltr"],
        metadata={
            "help": (
                "extension of the label files to load, frame-level labels for"
                " pre-training, and sequence-level label for fine-tuning"
            )
        },
    )
    label_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "if set, looks for labels in this directory instead",
        },
    )
    label_rate: int = field(
        default=-1,
        metadata={"help": "label frame rate. -1 for sequence label"},
    )
    sample_rate: int = field(
        default=16_000,
        metadata={
            "help": "target sample rate. audio files will be up/down "
            "sampled to this rate"
        },
    )
    normalize: bool = field(
        default=False,
        metadata={
            "help": "if set, normalizes input to have 0 mean and unit variance"
        },
    )
    enable_padding: bool = field(
        default=False,
        metadata={"help": "pad shorter samples instead of cropping"},
    )
    max_sample_size: Optional[int] = field(
        default=None,
        metadata={"help": "max sample size to keep in training"},
    )
    min_sample_size: Optional[int] = field(
        default=None,
        metadata={"help": "min sample size to keep in training"},
    )
    # Defaults to the value of ``task.max_sample_size`` via interpolation.
    max_trim_sample_size: Optional[int] = field(
        default=II("task.max_sample_size"),
        metadata={"help": "max sample size to trim to for batching"},
    )
    single_target: Optional[bool] = field(
        default=False,
        metadata={
            "help": "if set, AddTargetDatasets outputs same keys "
            "as AddTargetDataset"
        },
    )
    random_crop: Optional[bool] = field(
        default=True,
        metadata={"help": "always crop from the beginning if false"},
    )
    pad_audio: Optional[bool] = field(
        default=False,
        metadata={"help": "pad audio to the longest one in the batch if true"},
    )
    # When true, the task's ``setup_task`` drops into ``pdb.set_trace()``.
    pdb: Optional[bool] = field(
        default=False,
        metadata={"help": "pdb"},
    )
    stack_order_audio: int = field(
        default=1,
        metadata={"help": "concatenate n consecutive audio frames for one step"},
    )
    skip_verify: Optional[bool] = field(
        default=False,
        metadata={"help": "skip verifying label-audio alignment"},
    )
    image_aug: bool = field(
        default=False, metadata={"help": "image data augmentation"}
    )
    image_crop_size: int = field(
        default=88, metadata={"help": "image ROI size"}
    )
    image_mean: float = field(
        default=0.421, metadata={"help": "image mean"}
    )
    image_std: float = field(
        default=0.165, metadata={"help": "image std"}
    )
    modalities: Optional[List[str]] = field(
        default_factory=lambda: ["audio", "video"],
        metadata={"help": "modalities to load"},
    )
    is_s2s: bool = field(
        default=False, metadata={"help": "seq2seq fine-tuning only"}
    )
    tokenizer_bpe_name: Optional[str] = field(
        default=None, metadata={"help": "tokenizer model name"}
    )
    tokenizer_bpe_model: Optional[str] = field(
        default=None, metadata={"help": "tokenizer model path"}
    )
    noise_wav: Optional[str] = field(
        default=None,
        metadata={"help": "manifest of noise wav files (one wav file path per line)"},
    )
    noise_prob: float = field(default=0, metadata={"help": "noise probability"})
    # Stored as a string holding a Python expression; the task's load_dataset
    # evaluates it with eval().
    noise_snr: Optional[str] = field(
        default="0", metadata={"help": "noise SNR in audio"}
    )
    noise_num: int = field(
        default=1, metadata={"help": "number of noise wav files to mix"}
    )
    fine_tuning: bool = field(
        default=False, metadata={"help": "set to true if fine-tuning AV-Hubert"}
    )
@register_task("av_hubert_pretraining", dataclass=AVHubertPretrainingConfig)
class AVHubertPretrainingTask(FairseqTask):
    """Fairseq task for AV-HuBERT pre-training and fine-tuning.

    Pre-training (``cfg.fine_tuning`` false) registers a ``dictionaries``
    factory yielding one dictionary per entry of ``cfg.labels``. Fine-tuning
    registers a ``target_dictionary`` factory and, for seq2seq fine-tuning
    (``cfg.is_s2s``), an ``s2s_tokenizer`` factory. All are registered on
    ``self.state``.
    """

    cfg: AVHubertPretrainingConfig

    def __init__(
        self,
        cfg: AVHubertPretrainingConfig,
    ) -> None:
        super().__init__(cfg)

        logger.info(f"current directory is {os.getcwd()}")
        logger.info(f"AVHubertPretrainingTask Config {cfg}")

        self.fine_tuning = cfg.fine_tuning
        if cfg.fine_tuning:
            self.state.add_factory("target_dictionary", self.load_dictionaries)
            if cfg.is_s2s:
                self.state.add_factory("s2s_tokenizer", self.load_tokenizer)
        else:
            self.state.add_factory("dictionaries", self.load_dictionaries)

        self.blank_symbol = "<s>"

    # These accessors must be properties: callers (including load_dataset
    # below and fairseq itself) read them as attributes, not as methods.
    @property
    def source_dictionary(self) -> Optional[Dictionary]:
        """This task has no source-side dictionary; always ``None``."""
        return None  # self._source_dictionary

    @property
    def target_dictionary(self) -> Optional[Dictionary]:
        """Target dictionary; only registered on ``self.state`` when fine-tuning."""
        return self.state.target_dictionary  # self._target_dictionary

    @property
    def dictionaries(self) -> List[Dictionary]:
        """Per-label dictionaries; only registered when not fine-tuning."""
        return self.state.dictionaries

    def load_dictionaries(self):
        """Load ``<label_dir>/dict.<label>.txt`` for every label in ``cfg.labels``.

        Returns:
            The first dictionary when fine-tuning, otherwise the full list.
        """
        label_dir = self.get_label_dir()
        dictionaries = [
            Dictionary.load(f"{label_dir}/dict.{label}.txt")
            for label in self.cfg.labels
        ]
        return dictionaries[0] if self.cfg.fine_tuning else dictionaries

    def load_tokenizer(self):
        """Build the BPE tokenizer named by ``cfg.tokenizer_bpe_name``.

        The model path from ``cfg.tokenizer_bpe_model`` is passed under the key
        ``<bpe name>_model``.
        """
        bpe_name = self.cfg.tokenizer_bpe_name
        bpe_args = Namespace(**{
            "bpe": bpe_name,
            f"{bpe_name}_model": self.cfg.tokenizer_bpe_model,
        })
        return encoders.build_bpe(bpe_args)

    @property
    def s2s_tokenizer(self):
        """BPE tokenizer; only registered for seq2seq fine-tuning."""
        return self.state.s2s_tokenizer

    @classmethod
    def setup_task(
        cls, cfg: AVHubertPretrainingConfig, **kwargs
    ) -> "AVHubertPretrainingTask":
        """Instantiate the task, first entering ``pdb`` when ``cfg.pdb`` is set."""
        if cfg.pdb:
            import pdb

            pdb.set_trace()
        return cls(cfg)

    def get_label_dir(self) -> str:
        """Return ``cfg.label_dir`` if set, otherwise ``cfg.data``."""
        if self.cfg.label_dir is None:
            return self.cfg.data
        return self.cfg.label_dir

    def load_dataset(self, split: str, **kwargs) -> None:
        """Build an ``AVHubertDataset`` for ``split`` and store it in ``self.datasets``.

        Reads the manifest ``<cfg.data>/<split>.tsv`` and the label files
        ``<label_dir>/<split>.<label>``. ``cfg.image_aug`` is honoured only for
        the ``train`` split.
        """
        manifest = f"{self.cfg.data}/{split}.tsv"
        dictionaries = [self.target_dictionary] if self.fine_tuning else self.dictionaries
        pad_list = [dictionary.pad() for dictionary in dictionaries]
        eos_list = [dictionary.eos() for dictionary in dictionaries]
        if not self.cfg.is_s2s:
            procs = [LabelEncoder(dictionary) for dictionary in dictionaries]
        else:
            logger.info("Using tokenizer")
            bpe_tokenizer = self.s2s_tokenizer
            procs = [
                LabelEncoderS2SToken(dictionary, bpe_tokenizer)
                for dictionary in dictionaries
            ]
        label_dir = self.get_label_dir()
        paths = [f"{label_dir}/{split}.{label}" for label in self.cfg.labels]
        image_aug = self.cfg.image_aug if split == "train" else False

        noise_fn = (
            f"{self.cfg.noise_wav}/{split}.tsv"
            if self.cfg.noise_wav is not None
            else None
        )
        # NOTE(review): eval() executes arbitrary code from the config string.
        # This is acceptable only for trusted configs; ast.literal_eval would be
        # safer if noise_snr is always a numeric literal or a list/tuple of them.
        noise_snr = eval(self.cfg.noise_snr)
        noise_num = self.cfg.noise_num
        self.datasets[split] = AVHubertDataset(
            manifest,
            sample_rate=self.cfg.sample_rate,
            label_paths=paths,
            label_rates=self.cfg.label_rate,
            pad_list=pad_list,
            eos_list=eos_list,
            label_processors=procs,
            max_keep_sample_size=self.cfg.max_sample_size,
            min_keep_sample_size=self.cfg.min_sample_size,
            max_sample_size=self.cfg.max_trim_sample_size,
            pad_audio=self.cfg.pad_audio,
            normalize=self.cfg.normalize,
            store_labels=False,
            random_crop=self.cfg.random_crop,
            single_target=self.cfg.single_target,
            stack_order_audio=self.cfg.stack_order_audio,
            skip_verify=self.cfg.skip_verify,
            image_mean=self.cfg.image_mean,
            image_std=self.cfg.image_std,
            image_crop_size=self.cfg.image_crop_size,
            image_aug=image_aug,
            modalities=self.cfg.modalities,
            is_s2s=self.cfg.is_s2s,
            noise_fn=noise_fn,
            noise_prob=self.cfg.noise_prob,
            noise_snr=noise_snr,
            noise_num=noise_num,
        )

    def max_positions(self) -> Tuple[int, int]:
        """No source/target length limit: both entries are ``sys.maxsize``."""
        return (sys.maxsize, sys.maxsize)

    def filter_indices_by_size(
        self, indices: np.array, *args, **kwargs
    ) -> np.array:
        """Return ``indices`` unchanged; no size-based filtering is applied here."""
        return indices

    def build_generator(
        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None,
    ):
        """
        Build a :class:`~fairseq.SequenceGenerator` instance for this
        task.
        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            args (fairseq.dataclass.configs.GenerationConfig):
                configuration object (dataclass) for generation
            seq_gen_cls (type, optional): generator class to instantiate;
                defaults to the module's ``SequenceGenerator`` (or fairseq's
                ``SequenceGeneratorWithAlignment`` when ``print_alignment``).
            extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass
                through to SequenceGenerator (the caller's dict is not mutated)
            prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]):
                If provided, this function constrains the beam search to
                allowed tokens only at each step. The provided function
                should take 2 arguments: the batch ID (`batch_id: int`)
                and a unidimensional tensor of token ids (`inputs_ids:
                torch.Tensor`). It has to return a `List[int]` with the
                allowed tokens for the next generation step conditioned
                on the previously generated tokens (`inputs_ids`) and
                the batch ID (`batch_id`). This argument is useful for
                constrained generation conditioned on the prefix, as
                described in "Autoregressive Entity Retrieval"
                (https://arxiv.org/abs/2010.00904) and
                https://github.com/facebookresearch/GENRE.
        """
        if getattr(args, "score_reference", False):
            from fairseq.sequence_scorer import SequenceScorer

            return SequenceScorer(
                self.target_dictionary,
                compute_alignment=getattr(args, "print_alignment", False),
            )

        # Choose search strategy. Defaults to Beam Search.
        sampling = getattr(args, "sampling", False)
        sampling_topk = getattr(args, "sampling_topk", -1)
        sampling_topp = getattr(args, "sampling_topp", -1.0)
        diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
        diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
        match_source_len = getattr(args, "match_source_len", False)
        diversity_rate = getattr(args, "diversity_rate", -1)
        constrained = getattr(args, "constraints", False)
        if prefix_allowed_tokens_fn is None:
            prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)
        if (
            sum(
                int(cond)
                for cond in [
                    sampling,
                    diverse_beam_groups > 0,
                    match_source_len,
                    diversity_rate > 0,
                ]
            )
            > 1
        ):
            raise ValueError("Provided Search parameters are mutually exclusive.")
        assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
        assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"

        if sampling:
            search_strategy = search.Sampling(
                self.target_dictionary, sampling_topk, sampling_topp
            )
        elif diverse_beam_groups > 0:
            search_strategy = search.DiverseBeamSearch(
                self.target_dictionary, diverse_beam_groups, diverse_beam_strength
            )
        elif match_source_len:
            # this is useful for tagging applications where the output
            # length should match the input length, so we hardcode the
            # length constraints for simplicity
            search_strategy = search.LengthConstrainedBeamSearch(
                self.target_dictionary,
                min_len_a=1,
                min_len_b=0,
                max_len_a=1,
                max_len_b=0,
            )
        elif diversity_rate > -1:
            search_strategy = search.DiverseSiblingsSearch(
                self.target_dictionary, diversity_rate
            )
        elif constrained:
            search_strategy = search.LexicallyConstrainedBeamSearch(
                self.target_dictionary, args.constraints
            )
        elif prefix_allowed_tokens_fn:
            search_strategy = search.PrefixConstrainedBeamSearch(
                self.target_dictionary, prefix_allowed_tokens_fn
            )
        else:
            search_strategy = search.BeamSearch(self.target_dictionary)

        # Copy so that adding "print_alignment" never mutates the caller's dict.
        extra_gen_cls_kwargs = dict(extra_gen_cls_kwargs or {})
        if seq_gen_cls is None:
            if getattr(args, "print_alignment", False):
                # Not imported at module level (was an undefined name before);
                # take fairseq's implementation.
                from fairseq.sequence_generator import SequenceGeneratorWithAlignment

                seq_gen_cls = SequenceGeneratorWithAlignment
                extra_gen_cls_kwargs["print_alignment"] = args.print_alignment
            else:
                seq_gen_cls = SequenceGenerator

        return seq_gen_cls(
            models,
            self.target_dictionary,
            beam_size=getattr(args, "beam", 5),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=match_source_len,
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            search_strategy=search_strategy,
            **extra_gen_cls_kwargs,
        )