File size: 3,946 Bytes
c239b93 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import ast
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
from transformers import (
HfArgumentParser,
Wav2Vec2Config,
Wav2Vec2FeatureExtractor
)
# Module-level logger for this script; its level and handlers are set up in ``main``.
logger: logging.Logger = logging.getLogger(name=__name__)
@dataclass
class ConfigArguments:
    """
    Arguments to which config we are going to set up.

    ``config_params`` and ``feature_extractor_params`` arrive as strings
    (command line or JSON file) holding a Python literal, e.g.
    ``"{'vocab_size': 32, 'mask_time_prob': 0.05}"``. ``__post_init__``
    replaces each non-empty string with ``ast.literal_eval`` of it.

    If the string cannot be parsed, the raw string is kept and a warning is
    logged. A dict given directly (e.g. from a JSON config file) is kept as-is.
    """

    output_dir: str = field(
        default=".",
        metadata={"help": "The output directory where the config will be written."},
    )
    name_or_path: Optional[str] = field(
        default=None,
        metadata={
            # Trailing space: the two literals are concatenated into one sentence pair.
            "help": "The model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch."
        },
    )
    # Declared as ``str`` so the CLI parser accepts it; after __post_init__ it
    # normally holds the parsed dict instead.
    config_params: Optional[str] = field(
        default=None,
        metadata={"help": "Custom configuration for the specific `name_or_path`"}
    )
    # Same convention as ``config_params``.
    feature_extractor_params: Optional[str] = field(
        default=None,
        metadata={"help": "Custom feature extractor configuration for the specific `name_or_path`"}
    )

    def __post_init__(self):
        # Empty / missing values are left untouched.
        if self.config_params:
            self.config_params = self._parse_literal("config", self.config_params)
        if self.feature_extractor_params:
            self.feature_extractor_params = self._parse_literal(
                "feature_extractor", self.feature_extractor_params
            )

    @staticmethod
    def _parse_literal(label, raw):
        """Parse ``raw`` as a Python literal for the parameter group ``label``.

        Args:
            label: Name used in log messages (``"config"`` or ``"feature_extractor"``).
            raw: The user-supplied value. It is normally a string; a dict is
                accepted as-is.

        Returns:
            The parsed value. If parsing fails, ``raw`` is returned unchanged.
            Either way, a warning is logged when the value is not a dict,
            because only dicts are applied as custom parameters.
        """
        # Same Logger object as the module-level ``logger`` (same name).
        log = logging.getLogger(__name__)
        if isinstance(raw, dict):
            # Already structured (e.g. supplied through a JSON arguments file).
            return raw
        try:
            value = ast.literal_eval(raw)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
            # These are the exceptions ast.literal_eval documents for malformed input.
            log.warning("Your custom `%s` parameters are not acceptable due to %s", label, err)
            return raw
        if not isinstance(value, dict):
            log.warning(
                "Your custom `%s` parameters parsed to %s, not a dict; "
                "they will not be used as overrides",
                label,
                type(value).__name__,
            )
        return value
# Config overrides used when no usable (non-empty dict) `config_params` are given.
_DEFAULT_CONFIG_PARAMS = {
    "mask_time_length": 10,
    "mask_time_prob": 0.05,
    "diversity_loss_weight": 0.1,
    "num_negatives": 100,
    "do_stable_layer_norm": True,
    "feat_extract_norm": "layer",
    "vocab_size": 40,
}

# Feature-extractor overrides used when no usable `feature_extractor_params` are given.
_DEFAULT_FEATURE_EXTRACTOR_PARAMS = {"return_attention_mask": True}


def _select_params(custom, defaults):
    """Return ``custom`` if it is a non-empty dict, otherwise ``defaults``.

    Custom params replace the defaults wholesale; the two are not merged.
    """
    if isinstance(custom, dict) and custom:
        return custom
    return defaults


def _build(cls, name_or_path, params):
    """Create a ``cls`` instance with ``params`` as keyword overrides.

    When ``name_or_path`` is set, the instance is loaded via
    ``cls.from_pretrained``. Otherwise a fresh instance is constructed,
    which supports training from scratch.
    """
    if name_or_path:
        return cls.from_pretrained(name_or_path, **params)
    return cls(**params)


def main():
    """Build a Wav2Vec2 config and feature extractor and save both to ``output_dir``.

    Arguments are read from a single ``.json`` file path if that is the only
    CLI argument, otherwise from the command line (see ``ConfigArguments``).
    """
    # Configure logging before parsing so anything logged while the
    # arguments are built goes through the same stdout handler.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO)

    parser = HfArgumentParser([ConfigArguments])
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        config_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
    else:
        config_args = parser.parse_args_into_dataclasses()[0]

    logger.info(f"Setting up configuration {config_args.name_or_path} with extra params {config_args.config_params}")
    config = _build(
        Wav2Vec2Config,
        config_args.name_or_path,
        _select_params(config_args.config_params, _DEFAULT_CONFIG_PARAMS),
    )

    logger.info(f"Setting up feature_extractor {config_args.name_or_path} with extra params "
                f"{config_args.feature_extractor_params}")
    feature_extractor = _build(
        Wav2Vec2FeatureExtractor,
        config_args.name_or_path,
        _select_params(config_args.feature_extractor_params, _DEFAULT_FEATURE_EXTRACTOR_PARAMS),
    )

    # Report the location only after the save has actually succeeded.
    config.save_pretrained(config_args.output_dir)
    logger.info(f"Your `config` saved here {os.path.join(config_args.output_dir, 'config.json')}")
    feature_extractor.save_pretrained(config_args.output_dir)
    logger.info(
        f"Your `feature_extractor` saved here "
        f"{os.path.join(config_args.output_dir, 'preprocessor_config.json')}"
    )
# Entry point: run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|