""" |
# Adapting the model |
python train_asr_adapter.py \ |
--config-path="../conf/asr_adapters" \ |
--config-name="asr_adaptation.yaml" \ |
model.pretrained_model=null \ |
model.nemo_model=null \ |
model.adapter.adapter_name=<Unique adapter name> \ |
model.adapter.adapter_type="<linear, tiny_attn, or others from config sub-sections of `adapter`>" \ |
model.adapter.adapter_module_name=<null, or str module. Type: encoder, decoder, joint, or multiple with + between them> \ |
model.adapter.linear.in_features=<dimension of the layer outputs of the model> \ |
model.adapter.linear.dim=32 \ |
model.adapter.linear.dropout=0.0 \ |
model.train_ds.manifest_filepath=<Path to manifest> \ |
model.train_ds.batch_size=16 \ |
model.validation_ds.manifest_filepath=<Path to manifest> \ |
model.validation_ds.batch_size=16 \ |
model.optim.lr=0.001 \ |
model.optim.weight_decay=0.0 \ |
model.optim.sched.warmup_steps=100 \ |
trainer.max_steps=300 \ |
trainer.devices=1 \ |
trainer.precision=32 \ |
exp_manager.exp_dir=<Some directory for experiment manager> |
# Hyperparameter Search
python train_asr_adapter.py \ |
--config-path="../conf/asr_adapters" \ |
--config-name="asr_adaptation_hp.yaml" \ |
-m \ |
model.pretrained_model=null \ |
model.nemo_model=null \ |
model.adapter.adapter_name=<Unique adapter name> \ |
model.adapter.adapter_type="<linear, tiny_attn, or others from config sub-sections of `adapter`>" \ |
model.adapter.adapter_module_name=<null, or str module. Type: encoder, decoder, joint, or multiple with + between them> \ |
model.adapter.linear.in_features=<dimension of the layer outputs of the model> \ |
model.train_ds.manifest_filepath=<Path to manifest> \ |
model.train_ds.batch_size=16 \ |
model.validation_ds.manifest_filepath=<Path to manifest> \ |
model.validation_ds.batch_size=16 \ |
exp_manager.exp_dir="<some directory>" \ |
exp_manager.create_wandb_logger=true \ |
exp_manager.wandb_logger_kwargs.project="<Project Name>" \ |
++delete_ckpt_after_train=True |
# Fine-tune a model |
While adaptation is very efficient for low-resource datasets, it imposes several restrictions:
- The vocabulary of the new dataset must be supported by the pre-existing vocabulary or tokenizer. |
    If tokens exist outside this scope, the adapter will have to learn UNK tokens (or fail entirely
    for character-based models).
- As a consequence of the above, the language of the new dataset must be the same as that of the original model.
There is ongoing research to enable more sophisticated adapters for other languages. |
When adapters cannot be readily used due to the above limitations, fine-tuning may be a better alternative. |
For documentation on fine-tuning a model, please visit - |
https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations |
# Pretrained Models |
For documentation on existing pretrained models, please visit - |
https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/results.html |
""" |
import os |
from dataclasses import is_dataclass |
import pytorch_lightning as pl |
from omegaconf import DictConfig, OmegaConf, open_dict |
from nemo.collections.asr.models import ASRModel |
from nemo.core import adapter_mixins |
from nemo.core.config import hydra_runner |
from nemo.utils import logging |
from nemo.utils.exp_manager import clean_exp_ckpt, exp_manager |
def update_model_config_to_support_adapter(model_cfg, current_cfg):
    """Patch a restored model config in place so the instantiated model can host adapters.

    Args:
        model_cfg: Model config obtained via ``return_config=True`` from
            ``ASRModel.from_pretrained`` / ``ASRModel.restore_from``. Mutated in place.
        current_cfg: The full run config; only ``current_cfg.model`` is read.

    Side effects:
        * ``model_cfg.log_prediction`` is set from ``current_cfg.model.log_prediction``
          (``False`` when that key is absent).
        * When ``adapter_mixins`` has a registered adapter entry for the encoder's
          ``_target_``, the encoder ``_target_`` is replaced by that entry's
          ``adapter_class_path``; otherwise the encoder target is left untouched.
    """
    with open_dict(model_cfg):
        wants_log_prediction = current_cfg.model.get('log_prediction', False)
        model_cfg.log_prediction = wants_log_prediction

        # Ask the adapter registry whether this encoder class has an adapter-capable variant.
        registered = adapter_mixins.get_registered_adapter(model_cfg.encoder._target_)
        if registered is None:
            return
        model_cfg.encoder._target_ = registered.adapter_class_path
def update_model_cfg(original_cfg, new_cfg):
    """Merge user-supplied dataset overrides onto the model's stored dataset config.

    Args:
        original_cfg: The dataset section stored in the restored model's config.
            Whitelisted keys from ``new_cfg`` are written into it in place.
        new_cfg: The dataset section from the run config. Keys that ``original_cfg``
            does not define (after whitelisting) are popped from it in place.

    Returns:
        A new config produced by ``OmegaConf.merge(original_cfg, new_cfg)``.
    """
    with open_dict(original_cfg), open_dict(new_cfg):
        # These loader keys are forced through even if the stored config never had them.
        for forced_key in ('num_workers', 'pin_memory'):
            if forced_key not in new_cfg:
                continue
            original_cfg[forced_key] = new_cfg[forced_key]
            print(f"Injecting white listed key `{forced_key}` into config")

        # Discard override keys the stored config does not know about.
        unknown_keys = [k for k in new_cfg.keys() if k not in original_cfg]
        for unknown in unknown_keys:
            new_cfg.pop(unknown)
            print("Removing unavailable key from config :", unknown)

        # Merge while the struct flags are still relaxed, exactly as before.
        return OmegaConf.merge(original_cfg, new_cfg)
def add_global_adapter_cfg(model, global_adapter_cfg):
    """Register a global adapter config on ``model`` and push it to existing adapters.

    Args:
        model: The adapter-capable model; its ``cfg``, ``adapter_global_cfg_key`` and
            ``update_adapter_cfg`` are used.
        global_adapter_cfg: A dataclass, a ``DictConfig``, or anything ``DictConfig``
            accepts (e.g. a plain dict).

    Side effects:
        Creates ``model.cfg.adapters`` if missing, stores the global config under
        ``model.adapter_global_cfg_key``, then calls ``model.update_adapter_cfg`` with
        the whole ``adapters`` section.
    """
    normalized = global_adapter_cfg
    # Dataclasses are turned into structured configs first; anything else that is still
    # not a DictConfig is wrapped directly.
    if is_dataclass(normalized):
        normalized = OmegaConf.structured(normalized)
    if not isinstance(normalized, DictConfig):
        normalized = DictConfig(normalized)

    with open_dict(normalized), open_dict(model.cfg):
        if 'adapters' not in model.cfg:
            model.cfg.adapters = OmegaConf.create({})
        model.cfg.adapters[model.adapter_global_cfg_key] = normalized
        model.update_adapter_cfg(model.cfg.adapters)
@hydra_runner(config_path="../conf/asr_adapters", config_name="asr_adaptation.yaml")
def main(cfg):
    """Restore an ASR model, attach one adapter, and train only that adapter.

    Flow:
      1. Validate that exactly one of ``cfg.model.pretrained_model`` /
         ``cfg.model.nemo_model`` is set.
      2. Build the Lightning trainer and the experiment manager.
      3. Restore the model after patching its config to be adapter-compatible.
      4. Re-point train/validation data, optimizer and spec-augment at the run config.
      5. Add the adapter described by ``cfg.model.adapter``, freeze every other weight,
         unfreeze the enabled adapter, and fit.
      6. Optionally save the adapter weights and delete the PTL checkpoints.

    Raises:
        ValueError: if the model source is missing or ambiguous, or if the requested
            ``adapter_type`` has no config section under ``cfg.model.adapter``.
        RuntimeError: if the model reports no adapter after ``add_adapter``.
    """
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    if cfg.model.pretrained_model is None and cfg.model.nemo_model is None:
        raise ValueError("Either set `cfg.model.nemo_model` or `cfg.model.pretrained_model`")
    if cfg.model.pretrained_model is not None and cfg.model.nemo_model is not None:
        raise ValueError("Cannot set both `cfg.model.nemo_model` and `cfg.model.pretrained_model`. Select one only.")

    trainer = pl.Trainer(**cfg.trainer)
    exp_log_dir = exp_manager(trainer, cfg.get("exp_manager", None))

    # Fetch only the config first so the encoder target can be swapped for its
    # adapter-capable variant before the model itself is instantiated.
    if cfg.model.pretrained_model is not None:
        model_cfg = ASRModel.from_pretrained(cfg.model.pretrained_model, return_config=True)
        update_model_config_to_support_adapter(model_cfg, cfg)
        model = ASRModel.from_pretrained(cfg.model.pretrained_model, override_config_path=model_cfg, trainer=trainer)
    else:
        model_cfg = ASRModel.restore_from(cfg.model.nemo_model, return_config=True)
        update_model_config_to_support_adapter(model_cfg, cfg)
        model = ASRModel.restore_from(cfg.model.nemo_model, override_config_path=model_cfg, trainer=trainer)

    # Data loaders: user overrides are merged onto the model's stored dataset configs.
    cfg.model.train_ds = update_model_cfg(model.cfg.train_ds, cfg.model.train_ds)
    model.setup_training_data(cfg.model.train_ds)

    if 'validation_ds' in cfg.model:
        cfg.model.validation_ds = update_model_cfg(model.cfg.validation_ds, cfg.model.validation_ds)
        model.setup_multiple_validation_data(cfg.model.validation_ds)

    model.setup_optimization(cfg.model.optim)

    if 'spec_augment' in cfg.model:
        model.spec_augmentation = model.from_config_dict(cfg.model.spec_augment)
    else:
        model.spec_augmentation = None
        # `del` raises on a struct-mode DictConfig and on a missing key; relax struct
        # mode and pop with a default so both cases are handled.
        with open_dict(model.cfg):
            model.cfg.pop('spec_augment', None)

    # Pull the control keys out of the adapter section; what remains are the
    # per-adapter-type config sub-sections.
    with open_dict(cfg.model.adapter):
        adapter_name = cfg.model.adapter.pop("adapter_name")
        adapter_type = cfg.model.adapter.pop("adapter_type")
        adapter_module_name = cfg.model.adapter.pop("adapter_module_name", None)
        adapter_state_dict_name = cfg.model.adapter.pop("adapter_state_dict_name", None)

        if adapter_type not in cfg.model.adapter.keys():
            raise ValueError(
                f"Adapter type ({adapter_type}) config could not be found. Adapter setup config - \n"
                f"{OmegaConf.to_yaml(cfg.model.adapter)}"
            )

        adapter_type_cfg = cfg.model.adapter[adapter_type]
        logging.info(f"Found `{adapter_type}` config :\n" f"{OmegaConf.to_yaml(adapter_type_cfg)}")

        # Qualify the adapter name with its target module unless the user already did.
        if adapter_module_name is not None and ':' not in adapter_name:
            adapter_name = f'{adapter_module_name}:{adapter_name}'

        adapter_global_cfg = cfg.model.adapter.pop(model.adapter_global_cfg_key, None)
        if adapter_global_cfg is not None:
            add_global_adapter_cfg(model, adapter_global_cfg)

    model.add_adapter(adapter_name, cfg=adapter_type_cfg)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not model.is_adapter_available():
        raise RuntimeError(f"Adapter `{adapter_name}` was not registered on the model.")

    # Only the newly added adapter should be active during training.
    model.set_enabled_adapters(enabled=False)
    model.set_enabled_adapters(adapter_name, enabled=True)

    # Freeze all weights, switch to train mode (dropout etc.), then unfreeze only the
    # enabled adapter's parameters.
    model.freeze()
    model = model.train()
    model.unfreeze_enabled_adapters()

    # NOTE(review): self-assignment presumably re-runs the `cfg` property setter so the
    # adapter-augmented config is recorded before training — confirm against ModelPT.
    model.cfg = model.cfg

    trainer.fit(model)

    if adapter_state_dict_name is not None:
        # Prefer the experiment's `checkpoints` directory when it exists.
        state_path = exp_log_dir if exp_log_dir is not None else os.getcwd()
        ckpt_path = os.path.join(state_path, "checkpoints")
        if os.path.exists(ckpt_path):
            state_path = ckpt_path
        state_path = os.path.join(state_path, adapter_state_dict_name)
        model.save_adapters(str(state_path))

    if 'delete_ckpt_after_train' in cfg:
        delete_ckpt_after_train = cfg.delete_ckpt_after_train
        if delete_ckpt_after_train:
            if exp_log_dir is None:
                # Without an experiment directory there are no checkpoints to clean, and
                # passing None would fail inside the path handling.
                logging.warning("`delete_ckpt_after_train` is set but no experiment log directory exists; skipping.")
            else:
                clean_exp_ckpt(exp_log_dir, remove_ckpt=True, remove_nemo=False)
if __name__ == '__main__':
    # `main` is called without arguments: the `hydra_runner` decorator is expected to
    # build `cfg` from the YAML config and command-line overrides.
    main()