|
import ast |
|
import logging |
|
import os |
|
import sys |
|
from dataclasses import dataclass, field |
|
from typing import Dict, List, Optional, Tuple |
|
from transformers import ( |
|
HfArgumentParser, |
|
Wav2Vec2Config, |
|
Wav2Vec2FeatureExtractor |
|
) |
|
|
|
# Module-level logger named after this module; main() attaches the stdout
# handler and raises the level to INFO.
logger = logging.getLogger(__name__)
|
|
|
|
|
@dataclass
class ConfigArguments:
    """Arguments selecting which config and feature extractor to set up.

    ``config_params`` and ``feature_extractor_params`` are given as strings
    holding a Python literal (for example ``"{'vocab_size': 40}"``) and are
    converted with ``ast.literal_eval`` in ``__post_init__``.  A value that is
    already a dict is kept unchanged.  Parsing is best-effort: when a string
    cannot be evaluated, a message is printed and the raw value is left in
    place instead of raising.
    """

    output_dir: str = field(
        default=".",
        metadata={"help": "The output directory where the config will be written."},
    )
    name_or_path: Optional[str] = field(
        default=None,
        metadata={
            # The two literals are concatenated, so the first one must end
            # with a space to keep the sentences apart.
            "help": "The model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch."
        },
    )
    # Annotated as ``Optional[str]`` because that is the form given on the
    # command line; after ``__post_init__`` the attribute normally holds a dict.
    config_params: Optional[str] = field(
        default=None,
        metadata={"help": "Custom configuration for the specific `name_or_path`"},
    )
    feature_extractor_params: Optional[str] = field(
        default=None,
        metadata={"help": "Custom feature extractor configuration for the specific `name_or_path`"},
    )

    def __post_init__(self):
        """Convert both ``*_params`` fields from literal strings to objects."""
        self.config_params = self._parse_params(self.config_params, "config")
        self.feature_extractor_params = self._parse_params(
            self.feature_extractor_params, "feature_extractor"
        )

    @staticmethod
    def _parse_params(raw, label):
        """Evaluate ``raw`` as a Python literal, best-effort.

        Args:
            raw: The value supplied for the field: ``None``, a string holding
                a Python literal, or an already-built dict.
            label: Name used in the printed messages (``config`` or
                ``feature_extractor``).

        Returns:
            ``raw`` unchanged when it is empty, already a dict, or cannot be
            evaluated; otherwise the evaluated literal.
        """
        # Nothing to parse: unset/empty values and ready-made dicts
        # (presumably what a JSON argument file would supply) pass through.
        if not raw or isinstance(raw, dict):
            return raw
        try:
            parsed = ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
            # These are the errors ``ast.literal_eval`` documents for
            # malformed input.  Keep the raw value so the caller can decide.
            print(f"Your custom `{label}` parameters are not acceptable due to {e}")
            return raw
        if not isinstance(parsed, dict):
            # A valid literal of the wrong type would otherwise go unnoticed.
            print(
                f"Your custom `{label}` parameters should be a dict, "
                f"got {type(parsed).__name__}"
            )
        return parsed
|
|
|
|
|
def _resolve_params(custom, defaults, label):
    """Return ``custom`` when it is a non-empty dict, otherwise ``defaults``."""
    if custom and isinstance(custom, dict):
        return custom
    if custom:
        # A truthy non-dict is a value that could not be turned into a dict;
        # say so instead of silently switching to the defaults.
        logger.warning(
            "Ignoring custom `%s` parameters %r (not a dict); using the defaults", label, custom
        )
    return defaults


def main():
    """Build a Wav2Vec2 config and feature extractor and save both.

    Arguments are read from a JSON file when the script is given exactly one
    argument ending in ``.json``, otherwise from the command line.  Custom
    parameters replace the built-in defaults only when they are a non-empty
    dict.  When ``name_or_path`` is not set, both objects are constructed
    directly from the parameters instead of being loaded from a checkpoint.
    Both objects are written to ``output_dir``.
    """
    parser = HfArgumentParser([ConfigArguments])
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # A single ``*.json`` argument is a file holding all the arguments.
        config_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
    else:
        config_args = parser.parse_args_into_dataclasses()[0]

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO)

    name_or_path = config_args.name_or_path
    output_dir = config_args.output_dir

    logger.info(
        "Setting up configuration %s with extra params %s",
        name_or_path,
        config_args.config_params,
    )
    config_kwargs = _resolve_params(
        config_args.config_params,
        {
            "mask_time_length": 10,
            "mask_time_prob": 0.05,
            "diversity_loss_weight": 0.1,
            "num_negatives": 100,
            "do_stable_layer_norm": True,
            "feat_extract_norm": "layer",
            "vocab_size": 40,
        },
        "config",
    )
    if name_or_path is None:
        # No checkpoint given: start from the library defaults (from scratch).
        config = Wav2Vec2Config(**config_kwargs)
    else:
        config = Wav2Vec2Config.from_pretrained(name_or_path, **config_kwargs)

    logger.info(
        "Setting up feature_extractor %s with extra params %s",
        name_or_path,
        config_args.feature_extractor_params,
    )
    extractor_kwargs = _resolve_params(
        config_args.feature_extractor_params,
        {"return_attention_mask": True},
        "feature_extractor",
    )
    if name_or_path is None:
        feature_extractor = Wav2Vec2FeatureExtractor(**extractor_kwargs)
    else:
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            name_or_path, **extractor_kwargs
        )

    # Report each location only after the corresponding save has succeeded.
    config.save_pretrained(output_dir)
    logger.info("Your `config` saved here %s", os.path.join(output_dir, "config.json"))

    feature_extractor.save_pretrained(output_dir)
    logger.info(
        "Your `feature_extractor` saved here %s",
        os.path.join(output_dir, "preprocessor_config.json"),
    )
|
|
|
|
|
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|