Spaces:
Runtime error
Runtime error
import fengshen.data.hubert.hubert_dataset as datasets | |
from fengshen.data.universal_datamodule import UniversalDataModule | |
from transformers import HubertConfig, HubertModel | |
# from transformers.models.hubert.modeling_hubert import _compute_mask_indices | |
import argparse | |
from fairseq.data import Dictionary | |
from pytorch_lightning import ( | |
LightningModule, | |
Trainer, | |
loggers, | |
) | |
from pytorch_lightning.callbacks import LearningRateMonitor | |
import torch | |
import os | |
import torch.nn.functional as F | |
import torch.nn as nn | |
class LabelEncoder:
    """Callable that maps a label line to ids using a fairseq ``Dictionary``.

    Encoding goes through ``dictionary.encode_line`` with no EOS appended and
    without adding unseen symbols to the dictionary.
    """

    def __init__(self, dictionary: Dictionary):
        self.dictionary = dictionary

    def __call__(self, label: str):
        encode = self.dictionary.encode_line
        return encode(label, add_if_not_exist=False, append_eos=False)
class HubertPretrainDataLoader:
    """Build HuBERT pretraining datasets from a manifest directory and label files.

    The attributes read from ``args`` (stored as ``self.cfg``) are: ``data``,
    ``label_dir``, ``labels``, ``sample_rate``, ``label_rate``, ``max_keep_size``,
    ``min_sample_size``, ``max_sample_size``, ``pad_audio``, ``normalize``,
    ``random_crop`` and ``single_target``.
    """

    def __init__(self, args):
        self.cfg = args
        self.dictionaries = self.load_dictionaries()
        # split name -> dataset, filled in by load_dataset()
        self.load_datasets = {}

    # TODO: replace the fairseq Dictionary with a HuggingFace tokenizer.
    def load_dictionaries(self):
        """Load one fairseq ``Dictionary`` per label type.

        Each dictionary is read from ``<label_dir>/dict.<label>.txt``, one per
        entry of ``cfg.labels`` and in the same order.
        """
        # Reuse get_label_dir() so the cfg.data fallback lives in one place.
        label_dir = self.get_label_dir()
        return [
            Dictionary.load(f"{label_dir}/dict.{label}.txt")
            for label in self.cfg.labels
        ]

    def get_label_dir(self):
        """Return ``cfg.label_dir``, falling back to ``cfg.data`` when it is None."""
        if self.cfg.label_dir is None:
            return self.cfg.data
        return self.cfg.label_dir

    def datasets(self):
        """Return the ``{split: dataset}`` dict populated by ``load_dataset``.

        NOTE: this is a plain method; callers that need the dict must call it
        (``loader.datasets()``) or read ``loader.load_datasets`` directly.
        """
        return self.load_datasets

    def load_dataset(self, split: str, **kwargs):
        """Build a ``HubertDataset`` for ``split`` and store it under that key.

        The audio manifest is ``<cfg.data>/<split>.tsv``. The label files are
        ``<label_dir>/<split>.<label>``, one per entry of ``cfg.labels``.
        ``kwargs`` is accepted but ignored.
        """
        manifest = f"{self.cfg.data}/{split}.tsv"
        dictionaries = self.dictionaries
        pad_list = [dictionary.pad() for dictionary in dictionaries]
        eos_list = [dictionary.eos() for dictionary in dictionaries]
        procs = [LabelEncoder(dictionary) for dictionary in dictionaries]
        label_dir = self.get_label_dir()
        paths = [f"{label_dir}/{split}.{lb}" for lb in self.cfg.labels]
        # hubert v1: pad_audio=True, random_crop=False;
        self.load_datasets[split] = datasets.HubertDataset(
            manifest,
            sample_rate=self.cfg.sample_rate,
            label_paths=paths,
            label_rates=self.cfg.label_rate,
            pad_list=pad_list,
            eos_list=eos_list,
            label_processors=procs,
            max_keep_sample_size=self.cfg.max_keep_size,
            min_keep_sample_size=self.cfg.min_sample_size,
            max_sample_size=self.cfg.max_sample_size,
            pad_audio=self.cfg.pad_audio,
            normalize=self.cfg.normalize,
            store_labels=False,
            random_crop=self.cfg.random_crop,
            single_target=self.cfg.single_target,
        )
def perpare_data(args):
    """Create a ``HubertPretrainDataLoader`` and load its 'train' and 'valid' splits.

    The returned loader holds the built datasets keyed by split name.
    """
    hubert_loader = HubertPretrainDataLoader(args)
    for split_name in ('train', 'valid'):
        hubert_loader.load_dataset(split_name)
    return hubert_loader
class HubertLightning(LightningModule):
    """Pretrain a HuggingFace ``HubertModel`` with a masked-prediction NCE loss.

    One learnable embedding table is kept per label set. Hidden states at
    masked positions are projected and scored against every class embedding by
    cosine similarity (see ``compute_nce``).
    """

    @staticmethod
    def add_module_specific_args(parent_parser):
        """Register this module's CLI options on ``parent_parser`` and return it."""
        parser = parent_parser.add_argument_group('HuBert Lightning')
        parser.add_argument('--pred_masked_weight', type=float, default=1.0)
        parser.add_argument('--logit_temp', type=float, default=1.0)
        parser.add_argument('--loss_weights', type=float, nargs='+')
        # NOTE: masking options (mask_prob, mask_length, mask_selection, ...)
        # are not exposed here.
        return parent_parser

    def __init__(self, args, loader, **kwargs) -> None:
        """Build the model from ``args.model_path``'s config and size heads from ``loader``.

        Args:
            args: parsed CLI namespace; saved as hyperparameters.
            loader: object exposing ``dictionaries`` (one per label set).
            kwargs: accepted but ignored.
        """
        super().__init__()
        self.save_hyperparameters(args)
        config = HubertConfig.from_pretrained(args.model_path)
        self.config = config
        # Built from the config only, i.e. pretrained weights are not loaded here.
        self.model = HubertModel(config=config)
        # Number of classes for each label set (dictionary size).
        self.num_classes = [len(d) for d in loader.dictionaries]
        emb_dim = self.config.conv_dim[-1] // 2
        # All label-set embeddings stacked into one table; split per set in forward().
        self.label_embs_concat = nn.Parameter(
            torch.FloatTensor(sum(self.num_classes), emb_dim)
        )
        # One emb_dim-sized output chunk per label set.
        self.final_proj = nn.Linear(
            self.config.hidden_size, emb_dim * len(loader.dictionaries)
        )
        nn.init.uniform_(self.label_embs_concat)

    def setup(self, stage) -> None:
        """When fitting, compute ``self.total_steps``.

        ``total_steps`` is presumably consumed by fengshen's
        ``configure_optimizers``; confirm.
        """
        if stage == 'fit':
            train_loader = self.trainer._data_connector._train_dataloader_source.dataloader()
            # Calculate total steps
            if self.trainer.max_epochs > 0:
                world_size = self.trainer.world_size
                tb_size = self.hparams.train_batchsize * max(1, world_size)
                ab_size = self.trainer.accumulate_grad_batches
                self.total_steps = (len(train_loader.dataset) *
                                    self.trainer.max_epochs // tb_size) // ab_size
            else:
                self.total_steps = self.trainer.max_steps // self.trainer.accumulate_grad_batches
            print('Total steps: {}' .format(self.total_steps))

    def configure_optimizers(self):
        """Delegate optimizer/scheduler construction to fengshen's shared helper."""
        from fengshen.models.model_utils import configure_optimizers
        return configure_optimizers(self)

    def compute_nce(self, x, pos, negs):
        """Return temperature-scaled cosine logits with the positive at column 0.

        Args:
            x: projected features, (S, D).
            pos: positive embeddings, (S, D).
            negs: candidate embeddings, (Neg, S, D).

        Returns:
            Logits of shape (S, Neg + 1). Candidates identical to the positive
            are set to -inf so the positive is not counted twice.
        """
        neg_is_pos = (pos == negs).all(-1)
        pos = pos.unsqueeze(0)
        targets = torch.cat([pos, negs], dim=0)
        logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
        logits /= self.hparams.logit_temp
        if neg_is_pos.any():
            logits[1:][neg_is_pos] = float("-inf")
        logits = logits.transpose(0, 1)  # (num_x, num_cls+1)
        return logits

    def forward(self, **batch):
        """Run the model and return ``{'loss': total_loss}``.

        ``batch`` must contain ``target_list`` and
        ``net_input['source' | 'padding_mask']``.
        """
        target_list = batch['target_list']
        padding_mask = batch['net_input']['padding_mask']
        input_values = batch['net_input']['source']
        # NOTE(review): stock HF HubertModel takes no ``target_list`` argument,
        # and its ``attention_mask`` uses 1 = keep, while a fairseq-style
        # ``padding_mask`` is True at padded positions. This presumably relies
        # on a patched HubertModel; confirm.
        output = self.model(input_values=input_values,
                            attention_mask=padding_mask,
                            target_list=target_list,
                            mask_time_indices=None,
                            return_dict=False)

        def compute_pred(proj_x, target, label_embs):
            # compute logits for the i-th label set
            # proj_x: (S, D); y: (S, D); negs: (Neg, S, D)
            y = torch.index_select(label_embs, 0, target.long())
            negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
            return self.compute_nce(proj_x, y, negs)

        label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
        # Output tuple is unpacked positionally: first entry plus the last four.
        x, extra_losses, target_list, mask_indices, padding_mask = (
            output[0], output[-4], output[-3], output[-2], output[-1])
        # Only positions that are masked and not padding contribute to the loss.
        masked_indices = torch.logical_and(~padding_mask, mask_indices)
        proj_x_m = self.final_proj(x[masked_indices])
        proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1)
        logp_m_list = [
            compute_pred(proj_chunk, t[masked_indices], label_embs_list[i])
            for i, (proj_chunk, t) in enumerate(zip(proj_x_m_list, target_list))
        ]
        # The positive class is always column 0 of the NCE logits.
        targ_m_list = [logp.new_zeros(logp.size(0), dtype=torch.long) for logp in logp_m_list]

        loss = 0.0
        loss_m_list = []
        for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
            loss_m = F.cross_entropy(logp_m, targ_m)
            loss_m_list.append(loss_m)
            self.log(f"loss_m_{i}", loss_m.detach().item())
        loss += self.hparams.pred_masked_weight * sum(loss_m_list)

        loss_weights = self.hparams.loss_weights
        if loss_weights is not None:
            if torch.is_tensor(extra_losses):
                extra_losses = [extra_losses]
                names = ['extra']
            else:
                # `names` must be bound for list/tuple extra losses too;
                # previously this path raised UnboundLocalError below.
                names = [f'extra_{i}' for i in range(len(extra_losses))]
            if len(loss_weights) == 1 and len(extra_losses) != 1:
                loss_weights = [loss_weights[0]] * len(extra_losses)
            assert len(extra_losses) == len(
                loss_weights
            ), f"{len(extra_losses)}, {len(loss_weights)}"
            for p, n, coef in zip(extra_losses, names, loss_weights):
                if coef != 0 and p is not None:
                    p = coef * p.float()
                    loss += p
                    self.log(f"loss_{n}", p.item())
        return {'loss': loss}

    def training_step(self, batch, batch_idx):
        """Compute the loss for ``batch`` and log it as ``train_loss``."""
        output = self(**batch)
        self.log('train_loss', output['loss'])
        return output

    def comput_metrix(self, logits, labels):
        """Return the fraction of positions where argmax(logits) equals ``labels``."""
        y_pred = torch.argmax(logits, dim=-1)
        y_pred = y_pred.view(size=(-1,))
        y_true = labels.view(size=(-1,)).float()
        corr = torch.eq(y_pred, y_true)
        acc = torch.sum(corr.float()) / y_true.size()[0]
        return acc

    def validation_step(self, batch, batch_idx):
        """Compute the loss for ``batch`` and log it as ``val_loss``."""
        output = self(**batch)
        # ``output`` is a dict, so the loss is read by key (not ``output.loss``).
        self.log('val_loss', output['loss'], sync_dist=True)
        return output

    def on_save_checkpoint(self, checkpoint) -> None:
        """On rank 0, also export the HF model via ``save_pretrained``.

        The export goes under the checkpoint dir as
        ``hf_pretrained_epoch{epoch}_step{step}``.
        """
        # Save the current loop info in the mid of epoch
        # if you lightning <= 1.6.0 uncomment the line below
        # checkpoint['loops'] = self.trainer.checkpoint_connector._get_loops_state_dict()
        if self.trainer.global_rank == 0:
            self.model.save_pretrained(os.path.join(
                self.trainer.checkpoint_callback.dirpath,
                'hf_pretrained_epoch{}_step{}'.format(self.trainer.current_epoch, self.trainer.global_step)))

    def on_load_checkpoint(self, checkpoint) -> None:
        """Restore ``consumed_samples`` (if saved) and the fit loop's stepped-batch counter."""
        global_step_offset = checkpoint["global_step"]
        if 'global_samples' in checkpoint:
            self.consumed_samples = checkpoint['global_samples']
        # Presumably keeps step counting consistent after resuming; confirm.
        self.trainer.fit_loop.epoch_loop._batches_that_stepped = global_step_offset
if __name__ == '__main__':
    # Assemble the CLI from every component's argument group.
    args_parser = argparse.ArgumentParser()
    from fengshen.utils import UniversalCheckpoint
    from fengshen.models.model_utils import add_module_args
    args_parser = add_module_args(args_parser)
    args_parser = datasets.add_data_specific_args(args_parser)
    args_parser = UniversalDataModule.add_data_specific_args(args_parser)
    args_parser = Trainer.add_argparse_args(args_parser)
    args_parser = HubertLightning.add_module_specific_args(args_parser)
    args_parser = UniversalCheckpoint.add_argparse_args(args_parser)
    args_parser.add_argument('--ckpt_path', type=str, )
    args = args_parser.parse_args()

    data_module = UniversalDataModule(args=args, tokenizer=None, collate_fn=None)
    data_loader = perpare_data(args)
    # Hand over the {split: dataset} dict itself. `data_loader.datasets` is a
    # method, so assigning that attribute would store the bound method instead.
    data_module.datasets = data_loader.load_datasets

    module = HubertLightning(args, loader=data_loader)
    lr_monitor = LearningRateMonitor(logging_interval='step')
    logger = loggers.TensorBoardLogger(save_dir=os.path.join(
        args.default_root_dir, 'logs/'),
        name=os.path.basename(os.path.dirname(args.model_path)))
    checkpoint_callback = UniversalCheckpoint(args).callbacks

    # Fall back to training from scratch when the requested checkpoint is missing.
    if args.ckpt_path is not None and \
            not os.path.exists(args.ckpt_path):
        print('--------warning no checkpoint found--------, remove args')
        args.ckpt_path = None

    trainer = Trainer.from_argparse_args(args,
                                         logger=logger,
                                         callbacks=[
                                             lr_monitor,
                                             checkpoint_callback])
    trainer.fit(module, data_module, ckpt_path=args.ckpt_path)