""" |
|
# Adapting the model |
|
|
|
python train_asr_adapter.py \ |
|
--config-path="../conf/asr_adapters" \ |
|
--config-name="asr_adaptation.yaml" \ |
|
model.pretrained_model=null \ |
|
model.nemo_model=null \ |
|
model.adapter.adapter_name=<Unique adapter name> \ |
|
model.adapter.adapter_type="<linear, tiny_attn, or others from config sub-sections of `adapter`>" \ |
|
model.adapter.adapter_module_name=<null, or str module. Type: encoder, decoder, joint, or multiple with + between them> \ |
|
model.adapter.linear.in_features=<dimension of the layer outputs of the model> \ |
|
model.adapter.linear.dim=32 \ |
|
model.adapter.linear.dropout=0.0 \ |
|
model.train_ds.manifest_filepath=<Path to manifest> \ |
|
model.train_ds.batch_size=16 \ |
|
model.validation_ds.manifest_filepath=<Path to manifest> \ |
|
model.validation_ds.batch_size=16 \ |
|
model.optim.lr=0.001 \ |
|
model.optim.weight_decay=0.0 \ |
|
model.optim.sched.warmup_steps=100 \ |
|
trainer.max_steps=300 \ |
|
trainer.devices=1 \ |
|
trainer.precision=32 \ |
|
exp_manager.exp_dir=<Some directory for experiment manager> |
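
The train and validation manifests use the standard NeMo ASR manifest format: a plain-text file in which
each line is a JSON object describing one utterance. The filenames, durations, and transcripts below are
purely illustrative placeholders.

{"audio_filepath": "/data/audio/utt_0001.wav", "duration": 3.2, "text": "transcript of the first utterance"}
{"audio_filepath": "/data/audio/utt_0002.wav", "duration": 1.7, "text": "transcript of the second utterance"}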

# Hyper Parameter Search

python train_asr_adapter.py \
    --config-path="../conf/asr_adapters" \
    --config-name="asr_adaptation_hp.yaml" \
    -m \
    model.pretrained_model=null \
    model.nemo_model=null \
    model.adapter.adapter_name=<Unique adapter name> \
    model.adapter.adapter_type="<linear, tiny_attn, or others from config sub-sections of `adapter`>" \
    model.adapter.adapter_module_name=<null, or str module. Type: encoder, decoder, joint, or multiple with + between them> \
    model.adapter.linear.in_features=<dimension of the layer outputs of the model> \
    model.train_ds.manifest_filepath=<Path to manifest> \
    model.train_ds.batch_size=16 \
    model.validation_ds.manifest_filepath=<Path to manifest> \
    model.validation_ds.batch_size=16 \
    exp_manager.exp_dir="<some directory>" \
    exp_manager.create_wandb_logger=true \
    exp_manager.wandb_logger_kwargs.project="<Project Name>" \
    ++delete_ckpt_after_train=True
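
Because the command above passes Hydra's multirun flag (-m), a simple grid search can be expressed by
giving comma-separated values to any override; Hydra then launches one run per combination. The sweep
values below are illustrative, not defaults from the config.

python train_asr_adapter.py \
    --config-path="../conf/asr_adapters" \
    --config-name="asr_adaptation_hp.yaml" \
    -m \
    model.adapter.linear.dim=16,32,64 \
    model.optim.lr=0.0001,0.001 \
    <remaining overrides as above>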

# Fine-tune a model

While adaptation is very efficient for low-resource datasets, it imposes several restrictions -

- The vocabulary of the new dataset must be supported by the pre-existing vocabulary or tokenizer.
  If tokens exist outside this scope, the adapter will have to learn UNK tokens (or fail entirely
  for character-based models).

- As a consequence of the above, the language of the new dataset must be the same as that of the original
  model. There is ongoing research to enable more sophisticated adapters for other languages.

When adapters cannot be readily used due to the above limitations, fine-tuning may be a better alternative.

For documentation on fine-tuning a model, please visit -
https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations

# Pretrained Models

For documentation on existing pretrained models, please visit -
https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/results.html
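
# Using the trained adapter

After training, this script saves the adapter weights via `model.save_adapters(...)` (see the end of `main()`).
A minimal sketch of restoring them for inference is shown below. It assumes the base model's modules support
adapters and that `load_adapters()` (the counterpart of the `save_adapters()` call used in this script) is
available on the model; all names in angle brackets are placeholders.

from nemo.collections.asr.models import ASRModel

model = ASRModel.from_pretrained("<same pretrained model used for adaptation>")
model.load_adapters("<exp_dir>/checkpoints/<adapter_state_dict_name>", name="<Unique adapter name>")
model.set_enabled_adapters("<Unique adapter name>", enabled=True)
model.eval()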
""" |
|
import os |
|
from dataclasses import is_dataclass |
|
|
|
import pytorch_lightning as pl |
|
from omegaconf import DictConfig, OmegaConf, open_dict |
|
|
|
from nemo.collections.asr.models import ASRModel |
|
from nemo.core import adapter_mixins |
|
from nemo.core.config import hydra_runner |
|
from nemo.utils import logging |
|
from nemo.utils.exp_manager import clean_exp_ckpt, exp_manager |
|
|
|
|
|


def update_model_config_to_support_adapter(model_cfg, current_cfg):
    with open_dict(model_cfg):
        # Override prediction logging in the restored config
        model_cfg.log_prediction = current_cfg.model.get('log_prediction', False)

        # Swap the encoder target for its registered adapter-compatible class, if one exists
        adapter_metadata = adapter_mixins.get_registered_adapter(model_cfg.encoder._target_)
        if adapter_metadata is not None:
            model_cfg.encoder._target_ = adapter_metadata.adapter_class_path


def update_model_cfg(original_cfg, new_cfg):
    with open_dict(original_cfg), open_dict(new_cfg):
        # Force-inject certain dataloader keys into the original config
        whitelist_keys = ['num_workers', 'pin_memory']
        for wkey in whitelist_keys:
            if wkey in new_cfg:
                original_cfg[wkey] = new_cfg[wkey]
                print(f"Injecting white listed key `{wkey}` into config")

        # Drop keys from the new config that the original config does not know about
        new_keys = list(new_cfg.keys())
        for key in new_keys:
            if key not in original_cfg:
                new_cfg.pop(key)
                print("Removing unavailable key from config :", key)

        new_cfg = OmegaConf.merge(original_cfg, new_cfg)

    return new_cfg


def add_global_adapter_cfg(model, global_adapter_cfg):
    # Convert the global adapter config to a DictConfig if it is a dataclass or plain dict
    if is_dataclass(global_adapter_cfg):
        global_adapter_cfg = OmegaConf.structured(global_adapter_cfg)

    if not isinstance(global_adapter_cfg, DictConfig):
        global_adapter_cfg = DictConfig(global_adapter_cfg)

    # Update the model config with information about the new global adapter config
    with open_dict(global_adapter_cfg), open_dict(model.cfg):
        if 'adapters' not in model.cfg:
            model.cfg.adapters = OmegaConf.create({})

        # Register the global config for adapters under the reserved global key
        model.cfg.adapters[model.adapter_global_cfg_key] = global_adapter_cfg

        # Propagate the updated adapter config to all existing adapter modules
        model.update_adapter_cfg(model.cfg.adapters)


@hydra_runner(config_path="../conf/asr_adapters", config_name="asr_adaptation.yaml")
def main(cfg):
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    if cfg.model.pretrained_model is None and cfg.model.nemo_model is None:
        raise ValueError("Either set `cfg.model.nemo_model` or `cfg.model.pretrained_model`")
    if cfg.model.pretrained_model is not None and cfg.model.nemo_model is not None:
        raise ValueError("Cannot set both `cfg.model.nemo_model` and `cfg.model.pretrained_model`. Select one only.")

    trainer = pl.Trainer(**cfg.trainer)
    exp_log_dir = exp_manager(trainer, cfg.get("exp_manager", None))

    # Restore the base model, overriding its config so that the encoder supports adapters
    if cfg.model.pretrained_model is not None:
        model_cfg = ASRModel.from_pretrained(cfg.model.pretrained_model, return_config=True)
        update_model_config_to_support_adapter(model_cfg, cfg)
        model = ASRModel.from_pretrained(cfg.model.pretrained_model, override_config_path=model_cfg, trainer=trainer)

    else:
        model_cfg = ASRModel.restore_from(cfg.model.nemo_model, return_config=True)
        update_model_config_to_support_adapter(model_cfg, cfg)
        model = ASRModel.restore_from(cfg.model.nemo_model, override_config_path=model_cfg, trainer=trainer)

    # Setup the train and validation dataloaders, merging the user overrides into the model's own config
    cfg.model.train_ds = update_model_cfg(model.cfg.train_ds, cfg.model.train_ds)
    model.setup_training_data(cfg.model.train_ds)

    if 'validation_ds' in cfg.model:
        cfg.model.validation_ds = update_model_cfg(model.cfg.validation_ds, cfg.model.validation_ds)
        model.setup_multiple_validation_data(cfg.model.validation_ds)

    # Setup the optimizer and scheduler
    model.setup_optimization(cfg.model.optim)

    # Setup spec augmentation: use the provided config, or disable it entirely
    if 'spec_augment' in cfg.model:
        model.spec_augmentation = model.from_config_dict(cfg.model.spec_augment)
    else:
        model.spec_augmentation = None
        del model.cfg.spec_augment

    # Setup the adapter config
    with open_dict(cfg.model.adapter):
        # Extract the adapter metadata keys; the remaining sub-configs describe the adapter types themselves
        adapter_name = cfg.model.adapter.pop("adapter_name")
        adapter_type = cfg.model.adapter.pop("adapter_type")
        adapter_module_name = cfg.model.adapter.pop("adapter_module_name", None)
        adapter_state_dict_name = cfg.model.adapter.pop("adapter_state_dict_name", None)

        # Resolve the config of the selected `adapter_type`
        if adapter_type not in cfg.model.adapter.keys():
            raise ValueError(
                f"Adapter type ({adapter_type}) config could not be found. Adapter setup config - \n"
                f"{OmegaConf.to_yaml(cfg.model.adapter)}"
            )

        adapter_type_cfg = cfg.model.adapter[adapter_type]
        print(f"Found `{adapter_type}` config :\n" f"{OmegaConf.to_yaml(adapter_type_cfg)}")

        # Prepend the module name to the adapter name, if it was not provided by the user
        if adapter_module_name is not None and ':' not in adapter_name:
            adapter_name = f'{adapter_module_name}:{adapter_name}'

        # Extract the global adapter config, if provided
        adapter_global_cfg = cfg.model.adapter.pop(model.adapter_global_cfg_key, None)
        if adapter_global_cfg is not None:
            add_global_adapter_cfg(model, adapter_global_cfg)

    # Add the adapter to the model and make sure it was registered
    model.add_adapter(adapter_name, cfg=adapter_type_cfg)
    assert model.is_adapter_available()

    # Disable all other adapters, then enable just the newly added one
    model.set_enabled_adapters(enabled=False)
    model.set_enabled_adapters(adapter_name, enabled=True)

    # Freeze all model weights, switch to train mode, then unfreeze only the enabled adapter parameters
    model.freeze()
    model = model.train()
    model.unfreeze_enabled_adapters()

    # Update the model config prior to training (re-assignment triggers NeMo's config setter)
    model.cfg = model.cfg

    # Finally, train the adapter
    trainer.fit(model)

    # Save just the adapter weights in a separate file
    if adapter_state_dict_name is not None:
        state_path = exp_log_dir if exp_log_dir is not None else os.getcwd()
        ckpt_path = os.path.join(state_path, "checkpoints")
        if os.path.exists(ckpt_path):
            state_path = ckpt_path
        state_path = os.path.join(state_path, adapter_state_dict_name)

        model.save_adapters(str(state_path))

    # Optionally remove the PTL checkpoints to conserve storage space
    if 'delete_ckpt_after_train' in cfg:
        delete_ckpt_after_train = cfg.delete_ckpt_after_train
        if delete_ckpt_after_train:
            clean_exp_ckpt(exp_log_dir, remove_ckpt=True, remove_nemo=False)


if __name__ == '__main__':
    main()