import ast
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

from transformers import (
    HfArgumentParser,
    Wav2Vec2Config,
    Wav2Vec2FeatureExtractor,
)

logger = logging.getLogger(__name__)

# Config overrides used when no valid `config_params` dict is supplied.
DEFAULT_CONFIG_PARAMS: Dict[str, object] = {
    "mask_time_length": 10,
    "mask_time_prob": 0.05,
    "diversity_loss_weight": 0.1,
    "num_negatives": 100,
    "do_stable_layer_norm": True,
    "feat_extract_norm": "layer",
    "vocab_size": 40,
}

# Feature-extractor overrides used when no valid `feature_extractor_params` dict is supplied.
DEFAULT_FEATURE_EXTRACTOR_PARAMS: Dict[str, object] = {
    "return_attention_mask": True,
}


def _parse_params(raw, label: str) -> Optional[dict]:
    """Turn a user-supplied parameter override into a dict, or None.

    Args:
        raw: A Python dict literal as a string (command-line usage), an
            already-built dict (possible when arguments come from a JSON
            file), or an empty/None value.
        label: Human-readable name used in warning messages.

    Returns:
        The parsed dict, or None when ``raw`` is empty, cannot be parsed,
        or does not describe a dict. In the failure cases a warning is
        logged and the caller falls back to its defaults.
    """
    if not raw:
        return None
    if isinstance(raw, dict):
        # JSON argument files can already hold a real object here.
        return raw
    try:
        parsed = ast.literal_eval(raw)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
        logger.warning(
            "Your custom `%s` parameters are not acceptable due to %r; falling back to defaults.",
            label,
            err,
        )
        return None
    if not isinstance(parsed, dict):
        logger.warning(
            "Your custom `%s` parameters must be a dict literal, got %s; falling back to defaults.",
            label,
            type(parsed).__name__,
        )
        return None
    return parsed


@dataclass
class ConfigArguments:
    """
    Arguments describing which config / feature extractor to set up and where to write it.
    """

    output_dir: str = field(
        default=".",
        metadata={"help": "The output directory where the config will be written."},
    )
    name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch."
        },
    )
    # Given on the command line as a string holding a Python dict literal;
    # replaced by a dict (or None) in __post_init__.
    config_params: Optional[str] = field(
        default=None,
        metadata={"help": "Custom configuration for the specific `name_or_path`"},
    )
    # Same convention as `config_params`, for the feature extractor.
    feature_extractor_params: Optional[str] = field(
        default=None,
        metadata={"help": "Custom feature extractor configuration for the specific `name_or_path`"},
    )

    def __post_init__(self):
        # After this, both fields are either a dict or None.
        self.config_params = _parse_params(self.config_params, "config")
        self.feature_extractor_params = _parse_params(
            self.feature_extractor_params, "feature_extractor"
        )


def main():
    """Build a Wav2Vec2 config and feature extractor, then save both to `output_dir`.

    When `name_or_path` is set, both objects are loaded with `from_pretrained`
    and the user overrides (or built-in defaults) are applied on top. When it
    is unset, both objects are constructed from scratch with those parameters.
    """
    # Configure logging first so warnings emitted while parsing arguments use
    # the same handler and format as the rest of the run.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO)

    parser = HfArgumentParser([ConfigArguments])
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        config_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
    else:
        config_args = parser.parse_args_into_dataclasses()[0]

    name_or_path = config_args.name_or_path

    logger.info(
        "Setting up configuration %s with extra params %s",
        name_or_path,
        config_args.config_params,
    )
    config_params = config_args.config_params or DEFAULT_CONFIG_PARAMS
    if name_or_path is None:
        # Training from scratch: there is no checkpoint to load from.
        config = Wav2Vec2Config(**config_params)
    else:
        config = Wav2Vec2Config.from_pretrained(name_or_path, **config_params)

    logger.info(
        "Setting up feature_extractor %s with extra params %s",
        name_or_path,
        config_args.feature_extractor_params,
    )
    feature_extractor_params = (
        config_args.feature_extractor_params or DEFAULT_FEATURE_EXTRACTOR_PARAMS
    )
    if name_or_path is None:
        feature_extractor = Wav2Vec2FeatureExtractor(**feature_extractor_params)
    else:
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            name_or_path, **feature_extractor_params
        )

    logger.info(
        "Your `config` saved here %s",
        os.path.join(config_args.output_dir, "config.json"),
    )
    config.save_pretrained(config_args.output_dir)

    logger.info(
        "Your `feature_extractor` saved here %s",
        os.path.join(config_args.output_dir, "preprocessor_config.json"),
    )
    feature_extractor.save_pretrained(config_args.output_dir)


if __name__ == '__main__':
    main()