# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/4B. Semantic to acoustic token modeling.ipynb.

# %% auto 0
__all__ = ['load_datasets', 'CMLMVisual', 'Rotary', 'rotate_half', 'apply_rotary_pos_emb', 'ResidualAttentionBlock',
           'MultiHeadAttention', 'DelSumDecoder', 'EmbeddingProjector', 'rand', 'Tunables', 'SADelARTransformer']

# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 1
import io
import time
import math
import random
import dataclasses

# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.profiler import profile, record_function, ProfilerActivity, schedule
from fastcore.basics import store_attr
from huggingface_hub import hf_hub_download

# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 3
from pathlib import Path
import json
from fastprogress import progress_bar, master_bar
import webdataset as wds

# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 4
from .train import *
from .modules import *
from . import vq_stoks

# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 8
def rand(start, end):
    """Return a float drawn uniformly from [start, end) using the `random` module."""
    span = end - start
    return start + span * random.random()

# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 9
def random_trunc(random_trunc_p, atoks_len = 2250, stoks_len = 750):
    """Build a sample-stream stage that randomly shortens samples.

    With probability `random_trunc_p`, a sample's acoustic tokens ('atoks.npy', sliced
    along the last axis) are cut to a random duration between 0.3 and 30 seconds. Its
    semantic tokens ('stoks.npy', sliced along the first axis) are then cut to the
    proportional length, rounded up. `atoks_len` is taken as the acoustic token count
    of a full 30-second sample.
    """
    atoks_per_second = atoks_len / 30

    def _trunc(samples):
        for sample in samples:
            if not random.random() < random_trunc_p:
                yield sample
                continue
            n_atoks = math.ceil(rand(0.3, 30) * atoks_per_second)
            sample['atoks.npy'] = sample['atoks.npy'][:, :n_atoks]
            kept_atoks = sample['atoks.npy'].shape[-1]
            sample['stoks.npy'] = sample['stoks.npy'][:math.ceil(kept_atoks / atoks_len * stoks_len)]
            yield sample

    return _trunc

def pad_samples(atoks_len = 2250, stoks_len = 750, stoks_pad_token = 4096):
    """Build a sample-stream stage that right-pads token arrays to fixed lengths.

    Semantic tokens are padded with `stoks_pad_token`. Acoustic tokens are padded
    with -100, the value the model's loss/accuracy code treats as "no target".
    Both arrays are converted to torch tensors.
    """
    def _pad(samples):
        for sample in samples:
            stoks = torch.tensor(sample['stoks.npy'])
            sample['stoks.npy'] = F.pad(stoks, (0, stoks_len - stoks.shape[-1]), value=stoks_pad_token)
            atoks = torch.tensor(sample['atoks.npy'])
            sample['atoks.npy'] = F.pad(atoks, (0, atoks_len - atoks.shape[-1]), value=-100)
            yield sample

    return _pad
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 10
def speaker_id_extractor(speaker_map):
    """Build a sample-stream stage that attaches a numeric speaker id to each sample.

    The speaker name is the second '/'-separated component of the sample's `__key__`.
    It is looked up in `speaker_map`, and the resulting id is stored under 'speaker'
    as a torch tensor.
    """
    def _extractor(samples):
        for s in samples:
            s['speaker'] = torch.tensor(speaker_map[s['__key__'].split("/")[1]])
            yield s
    return _extractor

# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 14
def load_datasets(
        input:str, # webdataset folder
        samples:int, # samples per epoch
        subsample:float=1, # use a fraction of the files
        val_samples:int=512,
        random_trunc_p:float=0,# probability of truncating the input to less than 30 seconds
        stoks_pad_token=4096,
    ):
    """Create the (train, validation) webdatasets for semantic-to-acoustic training.

    `input` is a directory (shards matching '*-s2a-*.tar.gz'), a path whose last
    component is a glob pattern, or an explicit list of shard paths. Every shard needs
    a sidecar '<shard>.speakers.txt' file listing speaker names, one per line. The
    union of those names, sorted, defines the speaker-id mapping, which is exposed
    as `.speakers` on both datasets.

    The first shard is reserved for validation and the rest are used for training.
    Both datasets resample shards indefinitely and yield batches of 64
    (stoks, atoks, speaker) tuples. One epoch is `length // 64` batches.

    NOTE(review): `subsample` is accepted but not used anywhere in this function.

    Raises:
        TypeError: if `input` is neither a path/str nor a list.
    """
    if isinstance(input, (Path, str)):
        path = Path(input)
        if path.is_dir():
            glob = '*-s2a-*.tar.gz'
        else:
            glob = path.name
            path = path.parent
        # sort so the validation shard (shards[0]) does not depend on directory listing order
        input = sorted(path.glob(glob))
    elif isinstance(input, list):
        pass
    else:
        # the original raised an undefined name (`ArgumentError`), which surfaced as a NameError
        raise TypeError("input should be either a list or a path with an optional glob specifier")

    shards = [str(x) for x in input]

    speaker_names = set()
    for shard in shards:
        with open(shard+'.speakers.txt') as f:
            speaker_names.update(x.strip() for x in f)
    speakers = {name: i for i, name in enumerate(sorted(speaker_names))}

    def _make_ds(shards, length):
        dataset = wds.WebDataset(wds.ResampledShards(shards)).compose(
            wds.decode(),
            speaker_id_extractor(speakers),
            random_trunc(random_trunc_p) if random_trunc_p > 0 else lambda x: x,
            pad_samples(stoks_pad_token=stoks_pad_token),
            wds.to_tuple('stoks.npy', 'atoks.npy', 'speaker'),
            wds.batched(64),
        )
        dataset.speakers = speakers
        dataset.total_samples = length
        return dataset.compose(wds.slice(length // 64)).with_epoch(length // 64).with_length(length // 64)

    return (
        _make_ds(shards[1:], samples),
        _make_ds(shards[:1], val_samples),
    )
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 33
import pylab as plt
import fastprogress
import IPython
import numpy as np

class CMLMVisual:
    """Visualize training progress: loss/accuracy/learning-rate plots plus a results table."""
    def __init__ (self, model, masterbar, total_steps):
        self.model = model
        self.masterbar = masterbar
        self.total_steps = total_steps
        self.epochs = total_steps // masterbar.main_bar.total

        gs = plt.GridSpec(3, 1, height_ratios=[2,2,1])
        graph_fig = plt.figure(figsize=(10,6))
        self.graph_fig = graph_fig
        self.loss_p = graph_fig.add_subplot(gs[0])
        self.acc_p = graph_fig.add_subplot(gs[1], sharex=self.loss_p)
        self.acc_p.tick_params('x', labelbottom=False)
        self.lr_p = graph_fig.add_subplot(gs[2], sharex=self.loss_p)
        self.lr_p.tick_params('x', labelbottom=False)
        self.graph_out = None

        self.its = []
        self.train_losses = []
        self.val_losses = []
        self.lr_history = []
        self.acc = np.nan
        self.acc_history = []
        self.pacc_history = []

    def show(self):
        self.start_t = time.time()
        self.masterbar.write(["samples", "train", "val", "time"], table=True)
        # `display` is only an implicit builtin inside IPython sessions; reference it explicitly
        self.graph_out = IPython.display.display(self.graph_fig, display_id=True)
        self.acc_out = IPython.display.display(IPython.display.HTML(''), display_id=True)

    def hide(self):
        if self.graph_out is not None:
            self.graph_out.update(IPython.display.HTML(''))

    def plot(self):
        loss_p, acc_p, lr_p = self.loss_p, self.acc_p, self.lr_p
        loss_p.clear()
        loss_p.plot(self.its, self.train_losses)
        loss_p.plot(self.its, self.val_losses)
        loss_p.set_xlim(0, self.total_steps)
        loss_p.set_yscale('log')
        acc_p.clear()
        # one dotted line per metric name reported by the most recent get_metrics() call
        for k in self.acc_history[-1].keys():
            acc_p.plot(self.its, [x[k] for x in self.acc_history], ':')
        lr_p.clear()
        lrs = np.array(self.lr_history)
        lr_p.plot(self.its, lrs)
        self.graph_out.update(self.graph_fig)

    def add_data(self, it, lr, train_loss, val_los):
        """Record one evaluation point, refresh the accuracy table and redraw the plots.

        Expects `show()` to have been called first (it creates the display handles).
        """
        self.its.append(it)
        self.train_losses.append(train_loss)
        self.val_losses.append(val_los)
        self.lr_history.append(lr)
        # with SADelARTransformer, get_metrics() also resets its validation accumulators
        metrics = self.model.get_metrics()
        self.acc_history.append(metrics)
        # NOTE(review): the HTML tags of this table were lost from the source and are reconstructed here
        header = ''.join(f"<td>{k}</td>" for k in metrics)
        values = ''.join(f"<td>{x*100:.1f}%</td>" for x in metrics.values())
        html = f"<h5>Accuracies:</h5><table><thead><tr>{header}</tr></thead><tr>{values}</tr></table>"
        self.acc_out.update(IPython.display.HTML(html))
        self.plot()

    def add_table_row(self, it, avg_train_loss, val_loss):
        elapsed_t = time.time() - self.start_t
        self.masterbar.write([it, f"{avg_train_loss:.5f}", f"{val_loss:.5f}", fastprogress.core.format_time(elapsed_t)], table=True)

    def on_iter(self, bar, it, avg_train_loss, val_loss):
        epoch = math.ceil(it / self.total_steps * self.epochs)
        bar.comment = f"#{epoch}/{self.epochs} loss: {avg_train_loss:.3f} / {val_loss:.3f}"

# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 34
# modified from https://blog.eleuther.ai/rotary-embeddings/
import torch

class Rotary(torch.nn.Module):
    """Compute (and cache) rotary position embedding cos/sin tables.

    `forward(x, seq_dim)` returns `(cos, sin)` shaped (1, seq_len, 1, 2*len(inv_freq)),
    i.e. (1, seq_len, 1, dim) for even `dim`, where seq_len = x.shape[seq_dim].
    """
    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        seq_len = x.shape[seq_dim]
        # rebuild on a length change, and also when x is on another device (e.g. after
        # model.to(...)); otherwise the cached tables would stay on the old device
        if seq_len != self.seq_len_cached or self.cos_cached.device != x.device:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[None, :, None, :]
            self.sin_cached = emb.sin()[None, :, None, :]
        return self.cos_cached, self.sin_cached


# rotary pos emb helpers:

def rotate_half(x):
    """Split the last dimension in two halves (x1, x2) and return cat(-x2, x1)."""
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat(
        (-x2, x1), dim=-1
    )

#@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate q and k by the rotary tables.

    The tables are sized for k. They are truncated along dim 1 to q's length, because
    in cross-attention q can be shorter than k.
    """
    return (q * cos[:,:q.shape[1]]) + (rotate_half(q) * sin[:,:q.shape[1]]), (k * cos) + (rotate_half(k) * sin)
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 35
from torch import Tensor, nn
import torch.nn.functional as F
from typing import Dict, Iterable, Optional

class ResidualAttentionBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention, optional cross-attention, then an MLP.

    Each sub-layer is applied to a LayerNorm-ed input and its output is added back residually.
    `LayerNorm` comes from the star imports at the top of this module.
    """
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, rope: bool = False, qk_scale: float = 1, ffn_mult: int = 4):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * ffn_mult
        self.mlp = nn.Sequential(
            nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        causal = False,
        kv_cache: Optional[dict] = None,
    ):
        x = x + self.attn(self.attn_ln(x), causal=causal, kv_cache=kv_cache)[0]
        if self.cross_attn is not None:
            # cross-attention over xa is never causal
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x

class MultiHeadAttention(nn.Module):
    """Multi-head (self- or cross-) attention with optional rotary position embeddings.

    Queries and keys are each multiplied by sqrt(qk_scale). The attention logits are
    therefore scaled by qk_scale, on top of scaled_dot_product_attention's default
    1/sqrt(head_dim). `QueryHead` comes from the star imports at the top of this module.
    """
    def __init__(self, n_state: int, n_head: int, qk_scale: float = 1, rope: bool = False):
        super().__init__()
        self.n_head = n_head
        self.sqrt_qk_scale = math.sqrt(qk_scale)
        self.query = QueryHead(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)

        self.rotary = None
        if rope:
            self.rotary = Rotary(n_state // n_head)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        causal = False,
        kv_cache: Optional[dict] = None,
    ):
        """Attend from x to itself, or to xa when given.

        Returns `(output, None)`; the second element is kept for interface compatibility.
        """
        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        if self.sqrt_qk_scale != 1:
            # out of place on purpose: k may be a tensor owned by kv_cache, and scaling it
            # in place would compound the scale every time the cache is reused
            q = q * self.sqrt_qk_scale
            k = k * self.sqrt_qk_scale

        wv, qk = self.qkv_attention_pth20(q, k, v, causal)
#         wv, qk = self.qkv_attention_xformers(q, k, v, causal)

        return self.out(wv), qk

    def qkv_attention_pth20(
        self, q: Tensor, k: Tensor, v: Tensor, causal = False
    ):
        """Split heads, apply rotary embeddings if enabled, and run PyTorch 2.0 SDPA.

        q, k, v are (batch, seq, n_state); the result is (batch, q_seq, n_state).
        """
        q = q.view(*q.shape[:2], self.n_head, -1)
        k = k.view(*k.shape[:2], self.n_head, -1)
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        # rotary embeddings are applied in the (batch, seq, head, head_dim) layout
        if self.rotary is not None:
            q, k = apply_rotary_pos_emb(q, k, *self.rotary(k))

        k = k.permute(0, 2, 1, 3)
        q = q.permute(0, 2, 1, 3)
        # modified for better performance under PyTorch 2.0
        wv = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0, is_causal=causal)

        # previously we've returned q@k which we don't have now
        # since it's not actually used anywhere else, let's just keep two return values for compatibility
        return wv.permute(0, 2, 1, 3).flatten(start_dim=2), None

    def qkv_attention_xformers(
        self, q: Tensor, k: Tensor, v: Tensor, causal = False
    ):
        """Alternative attention backend (currently unused).

        NOTE(review): `xops` (presumably `xformers.ops`) is not imported by this module,
        so calling this method raises NameError as-is.
        """
        q = q.view(*q.shape[:2], self.n_head, -1)
        k = k.view(*k.shape[:2], self.n_head, -1)
        v = v.view(*v.shape[:2], self.n_head, -1)

        if self.rotary is not None:
            q, k = apply_rotary_pos_emb(q, k, *self.rotary(k))

        bias = xops.LowerTriangularMask() if causal else None
        wv = xops.memory_efficient_attention(q,k,v, attn_bias=bias)

        # previously we've returned q@k which we don't have now
        # since it's not actually used anywhere else, let's just keep two return values for compatibility
        return wv.flatten(start_dim=2), None
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 36
class DelSumDecoder(nn.Module):
    """Causal decoder over acoustic tokens laid out with a per-quantizer delay.

    Input position p sums, for every quantizer i, the embedding of that quantizer's
    token p-1-i. Quantizer i therefore lags quantizer 0 by i steps. Positions
    0..i get quantizer i's extra start code (index `codes`) instead.
    Returns logits shaped (batch, quantizers, positions, codes+1).
    """
    def __init__(self, depth=6, n_head=6, head_width=64, qk_scale=1, ffn_mult=4, length=2250, codes=1024, quantizers=8, linear_heads=True, rope=False, pos_embs=None):
        super().__init__()
        self.length = length
        width = n_head * head_width
        self.width = width
        self.codes = codes
        self.quantizers = quantizers
        self.linear_heads = linear_heads

        # codes+1 rows: the extra row is the start/padding code `codes`
        self.embeddings = nn.ModuleList([nn.Embedding(codes+1, width) for _ in range(quantizers)])
        if pos_embs is not None:
            # NOTE(review): registered (and saved in the state_dict) but never read by forward()
            self.register_buffer("positional_embedding", pos_embs)

        self.layers = nn.ModuleList([
            ResidualAttentionBlock(width, n_head,
                                   qk_scale=qk_scale, ffn_mult=ffn_mult, cross_attention=True, rope=rope) for _ in range(math.floor(depth))
        ])

        self.ln_post = LayerNorm(width)

        if self.linear_heads:
            self.heads = LinearHead(width, (codes+1) * quantizers, bias=False)
        else:
            self.splitter = nn.Sequential(
                nn.Linear(width, width * quantizers),
                nn.GELU(),
            )
            self.heads = nn.ModuleList([
                LinearHead(width, codes+1, bias=True) for _ in range(quantizers)
            ])

    def forward(self, toks, xenc):
        """toks: (batch, quantizers, n) token ids; xenc: encoder output attended to by cross-attention."""
        b,_,n = toks.shape
        newn = min(n+1, self.length)
        embs = torch.zeros((b,newn,self.width), dtype=xenc.dtype, device=xenc.device)
        start_code = torch.tensor([self.codes], device=xenc.device)
        for i in range(self.quantizers):
            embs[:,:i+1] += self.embeddings[i](start_code)
            if i < n:
                embs[:,i+1:] += self.embeddings[i](toks[:,i,:newn-i-1])

        # embs was already allocated with xenc.dtype
        x = embs

        for l in self.layers:
            x = l(x, xenc, causal=True)
        x = self.ln_post(x)

        if self.linear_heads:
            logits = self.heads(x).view(b,newn,self.quantizers,self.codes+1).permute(0,2,1,3)
        else:
            split = self.splitter(x).view(b,newn,self.quantizers,self.width)
            logits = torch.stack([self.heads[q](split[:,:,q]) for q in range(self.quantizers)], dim=1)

        return logits

class EmbeddingProjector(nn.Linear):
    """nn.Linear subclass used as a marker so init_transformer gives it its own init and LR scale."""
    pass

# `rand` (uniform float in [start, end)) is defined once, earlier in this module (notebook cell 8);
# the identical duplicate definition that used to live here was removed.

@dataclasses.dataclass
class Tunables:
    """Training/initialization hyperparameters; `random=True` samples several of them instead."""
    init_std: float = 9
    embeddings_std: float = 0.2
    embeddings_lr_scale: float = 10
    output_mult: float = 5.6
    # FIXME: try separate mults for self and cross attention
    query_mult: float = .3
    encoder_depth_ratio: float = 0.25
    linear_heads: bool = False
    rope: bool = True

    lr0: float = 3e-3
    clip_gradient_norm: float = 2
    weight_decay: float = 1e-3
    warmup_steps: float = 2000

    random: bool = False

    def __post_init__(self):
        # randomize the hyperparams if requested
        if self.random:
            self.init_std = 2*10**rand(0,1)
            self.embeddings_std = 10**rand(-1.7,-0.22)
            self.embeddings_lr_scale = 2**rand(2,4)
            self.output_mult = 2**rand(1.5,3)
            self.query_mult = 2**rand(-3,-1.3)
            self.encoder_depth_ratio = random.choice([0.25,0.5])
            self.linear_heads = False
            self.rope = True

            self.lr0 = 3e-3
            self.clip_gradient_norm = 10**rand(-1,1)
            self.warmup_steps = 100*(10**rand(1.18,1.3))

    @staticmethod
    def upgrade(args):
        """Return a copy of saved tunables with defaults filled in for fields added later.

        `rope` defaults to False and `linear_heads` to True, unlike the current dataclass
        defaults, so checkpoints saved before these fields existed keep their architecture.
        """
        args = {k:v for k,v in args.items()}
        def old_default(name, value):
            if name not in args: args[name] = value
        old_default('rope', False)
        old_default('linear_heads', True)
        return args

class SADelARTransformer(nn.Module):
    """Semantic-to-acoustic token model.

    A transformer encoder runs over (upsampled) semantic tokens. A causal DelSumDecoder
    cross-attends to the encoder output, plus a per-speaker embedding, and predicts
    `quantizers` streams of acoustic tokens.
    """
    def __init__(self, depth=3, ctx_n=2250,
                 stoks_len=750, stoks_codes=4097, stoks_width=None,
                 spk_width=None,
                 n_head=3, head_width=64, ffn_mult=4,
                 quantizers=8, speaker_map={"1":0},
                 tunables=Tunables()):
        super().__init__()
        self.quantizers = quantizers
        width = n_head * head_width
        # records these constructor args in self.__stored_args__ (used by save_model)
        store_attr("depth,ctx_n,stoks_len,stoks_codes,stoks_width,spk_width,n_head,head_width,ffn_mult,quantizers,speaker_map")
        self.width = width
        self.base_width = 3 * head_width
        self.tunables = tunables

        if stoks_width is None: stoks_width = width
        if spk_width is None: spk_width = width
        self.emb_factor = width != stoks_width
        self.spk_factor = width != spk_width

        if tunables.rope:
            self.positional_embeddings = None
        else:
            self.register_buffer('positional_embeddings', sinusoids(ctx_n, width))

        self.speaker_embedding = nn.Embedding(len(speaker_map), width)
        self.semantic_embedding = nn.Embedding(stoks_codes, stoks_width)
        if self.emb_factor:
            self.emb_to_hidden = nn.Linear(stoks_width, width)

        if self.spk_factor:
            self.spk_to_hidden = EmbeddingProjector(spk_width, width)

        qk_scale = self.tunables.query_mult * 8 / math.sqrt(head_width)

        encoder_depth = int(depth * 2 * tunables.encoder_depth_ratio)
        decoder_depth = depth * 2 - encoder_depth
        self.encoder = nn.Sequential(*[
            ResidualAttentionBlock(width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, rope=tunables.rope) for _ in range(encoder_depth)
        ])
        self.ln_post = LayerNorm(width)

        self.decoder = DelSumDecoder(pos_embs=self.positional_embeddings, qk_scale=qk_scale,
                                     length=ctx_n, n_head=n_head, head_width=head_width, ffn_mult=ffn_mult,
                                     depth=decoder_depth, quantizers=quantizers,
                                     linear_heads=tunables.linear_heads, rope=tunables.rope)

        # per-quantizer validation accuracy accumulators (filled in forward, read by get_metrics).
        # No hard-coded .cuda(): buffers follow the module through .to()/.cuda(), and construction
        # no longer fails on machines without CUDA.
        self.register_buffer('val_true', torch.zeros(self.quantizers))
        self.register_buffer('val_total', torch.zeros(self.quantizers))
        self.apply(self.init_transformer)

    def setup(self, device):
        pass

    def load_frozen_semantic_embeddings(self, vqmodel):
        """Copy the first RQ codebook of `vqmodel` into the semantic embedding and freeze its LR."""
        with torch.no_grad():
            self.semantic_embedding.weight[:] = vqmodel.rq.layers[0]._codebook.embed[0]
            self.semantic_embedding.lr_scale = 0

    def init_transformer(self, m):
        """Per-module initialization plus `lr_scale`/`no_weight_decay` tags for the optimizer setup.

        Branch order matters: the LinearHead/QueryHead/EmbeddingProjector checks must run
        before the generic nn.Linear one.
        """
        if isinstance(m, LinearHead):
            m.no_weight_decay = True
            torch.nn.init.constant_(m.weight, 0)
        elif isinstance(m, QueryHead):
            m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
            torch.nn.init.constant_(m.weight, 0)
        elif isinstance(m, nn.Embedding):
            m.no_weight_decay = True
            m.lr_scale = self.tunables.embeddings_lr_scale
            std = self.tunables.embeddings_std
            torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
        elif isinstance(m, EmbeddingProjector):
            m.lr_scale = self.tunables.embeddings_lr_scale/2
            std = self.tunables.init_std
            torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
        elif isinstance(m, nn.Linear):
            m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
            std = self.tunables.init_std / m.weight.shape[1]
            torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
            if m.bias is not None:
                torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std)
        elif isinstance(m, nn.LayerNorm):
            m.no_weight_decay = True
            torch.nn.init.constant_(m.bias, 0)
            torch.nn.init.constant_(m.weight, 1)

    def embed_stoks(self, Stoks):
        """Embed (batch, n) semantic tokens, upsampling them to the acoustic token rate."""
        b,n = Stoks.shape
        if self.stoks_len == 1500:
            # converts 50 toks/s to 75 toks/s by adding padding between every two tokens
            x = Stoks.reshape(b,n//2,2)
            x = x.repeat_interleave(2, -1)[:,:,:3]
            x[:,:,1] = 1024
            x = x.reshape(b,n//2*3)
        else:
            # it's a lot easier with 25 toks/s
            x = Stoks.repeat_interleave(3, -1)
        # embed semantic tokens
        Sembs = self.semantic_embedding(x.to(torch.long))
        if self.emb_factor:
            Sembs = self.emb_to_hidden(Sembs)
        return Sembs

    def forward(self, Stoks, Atoks, speakers, noloss=False):
        """Return `logits` if `noloss`, else `(logits, loss)`.

        Atoks is (batch, quantizers, N) with -100 marking padding. The loss compares
        quantizer i's logits from position i onward with its targets, undoing the decoder
        delay. NOTE(review): the slices only line up when N >= ctx_n (e.g. inputs padded
        by pad_samples); confirm for other callers.
        """
        Atoks = Atoks.to(torch.long)
        semb = self.embed_stoks(Stoks)
        with record_function("encoder"):
            if self.positional_embeddings is not None: semb = semb + self.positional_embeddings
            xenc = self.ln_post(self.encoder(semb))
#             xenc = torch.zeros_like(xenc)
        with record_function("decoder"):
            Atoks_gt = Atoks.clone()
            # padding (-100) is fed to the decoder as the start/padding code
            Atoks_gt[Atoks == -100] = 1024
            # we can randomize speaker ids during validation to measure
            # the importance of the speaker embedding vs. just the acoustic prompt/prefix
#             if not self.training: speakers = speakers[torch.randperm(speakers.nelement())]
            spk_embs = self.speaker_embedding(speakers)
            if self.spk_factor: spk_embs = self.spk_to_hidden(spk_embs)
            logits = self.decoder(Atoks_gt, xenc + spk_embs.unsqueeze(1))
            logits *= self.tunables.output_mult / (self.width / self.base_width)

        if noloss:
            return logits

        with record_function("loss"):
            N = Atoks.shape[-1]
            loss = 0
            for i in range(self.quantizers):
                loss += F.cross_entropy(logits[:,i,i:].reshape(-1,logits.shape[-1]), Atoks[:,i,:N-i].reshape(-1))
            loss /= self.quantizers

        if not self.training:
            for i in range(self.quantizers):
                Atoks_i = Atoks[:,i,:N-i]
                valid_Atoks = Atoks_i != -100
                self.val_true[i] += (logits[:,i,i:].argmax(-1)[valid_Atoks] == Atoks_i[valid_Atoks]).float().sum()
                self.val_total[i] += valid_Atoks.float().sum()

        return logits, loss

    def get_metrics(self):
        """Return per-quantizer validation accuracy {'acc_i': float} and reset the accumulators."""
        metrics = {
            f'acc_{i}':x.item() for i,x in enumerate(self.val_true / self.val_total)
        }
        self.val_true[:] = 0
        self.val_total[:] = 0
        return metrics

    #
    # inference
    #
    @classmethod
    def load_model(cls, repo_id="collabora/whisperspeech", filename="s2a_up_wds.model", local_filename=None):
        """Load a model written by `save_model`, downloading it from the HF Hub unless `local_filename` is set.

        The model is built on the default device (normally the CPU) and put in eval mode.
        """
        if not local_filename:
            local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
        # map to CPU so GPU-saved checkpoints also load on CPU-only machines; load_state_dict
        # copies the weights into the freshly built module either way
        spec = torch.load(local_filename, map_location='cpu')
        if '_extra_state' not in spec['state_dict']:
            spec['state_dict']['_extra_state'] = { 'speaker_map': spec['config']['speaker_map'] }
        model = cls(**spec['config'], tunables=Tunables(**Tunables.upgrade(spec['tunables'])))
        model.load_state_dict(spec['state_dict'])
        model.eval()
        return model

    def get_extra_state(self):
        return { 'speaker_map': self.speaker_map }

    def set_extra_state(self, st):
        self.speaker_map = st['speaker_map']

    def load_checkpoint(self, local_filename):
        """Load weights from a PyTorch Lightning checkpoint whose keys are prefixed with 'model.'."""
        spec = torch.load(local_filename, map_location='cpu')
        assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint'
        prefix = 'model.'
        # strip only a leading prefix (str.replace would also rewrite 'model.' inside a key)
        state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in spec['state_dict'].items()}
        self.load_state_dict(state_dict)
        return self

    def save_model(self, fname):
        torch.save(dict(config = self.__stored_args__,
                        tunables = dataclasses.asdict(self.tunables),
                        state_dict = self.state_dict()), fname)

    @property
    def device(self):
        return next(self.parameters()).device

    @torch.no_grad()
    def generate(self, stoks, speakers, N=None, T=0.7, top_k=None, show_progress_bar=True):
        """Autoregressively sample acoustic tokens for a 1-D tensor of semantic tokens.

        `speakers` are names looked up in `speaker_map`. Sampling uses temperature `T` and
        optional top-k filtering. Returns a (quantizers, n) long tensor, cut short when
        quantizer 0 emits code 1024.
        """
        dev = self.device
        if self.stoks_len == 1500:
            N = N or len(stoks) * 3 // 2
        else:
            N = N or len(stoks) * 3
        stoks = F.pad(stoks.to(dev), (0, self.stoks_len - len(stoks)), value=self.stoks_codes-1).unsqueeze(0)
        speakers = torch.tensor([self.speaker_map[spk] for spk in speakers], device=dev)
        toks = torch.zeros((1,self.quantizers,N), dtype=torch.long, device=dev)
        it = range(0,N)
        if show_progress_bar: it = progress_bar(it)
        for i in it:
            p = self(stoks, toks[:,:,:i], speakers, noloss=True)
            last_p = p[0,:,-1]
            if top_k:
                last_p[last_p < torch.topk(last_p, top_k).values[:,-1,None]] = -torch.inf
            for j,tok in enumerate(torch.multinomial((last_p / float(T)).softmax(-1), 1)):
                # quantizer j is delayed by j steps
                toks[0,j,max(0,i-j)] = tok
            if toks[0,0,i] == 1024: return toks[0,:,:i]
        return toks[0]
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 37
# architecture presets accepted by `_make_model` / `make_model` (size name -> constructor args)
_MODEL_SIZES = {
    'micro':       dict(depth=4,  n_head=3,  ffn_mult=2),
    'tiny-narrow': dict(depth=4,  n_head=6,  ffn_mult=1),
    'tiny':        dict(depth=4,  n_head=6),
    'base':        dict(depth=6,  n_head=8),
    'base-deep':   dict(depth=9,  n_head=8),
    'base-wide':   dict(depth=6,  n_head=12),
    'small/2':     dict(depth=9,  n_head=12),
    'small':       dict(depth=12, n_head=12),
    'medium':      dict(depth=24, n_head=16),
}

def _make_model(size:str, quantizers:int=4, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None, **kwargs):
    """Build a SADelARTransformer of the named `size` preset.

    The speaker map is taken from `dataset.speakers`. Extra `kwargs` are forwarded to
    the constructor.

    Raises:
        ValueError: if `size` is not a known preset (previously this silently returned None).
    """
    assert(dataset is not None)
    if size not in _MODEL_SIZES:
        raise ValueError(f"unknown model size {size!r}, expected one of: {', '.join(_MODEL_SIZES)}")
    kwargs = dict(speaker_map=dataset.speakers, quantizers=quantizers, tunables=tunables, **kwargs)
    return SADelARTransformer(**_MODEL_SIZES[size], **kwargs)

def make_model(size:str, quantizers:int=4, frozen_embeddings_model:str=None, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None):
    """Build a model, optionally initializing frozen semantic embeddings from a VQ model.

    When `frozen_embeddings_model` is given, the semantic vocabulary size and embedding
    width are taken from that RQBottleneckTransformer, and its first codebook is copied
    into the (LR-frozen) semantic embedding.
    """
    if frozen_embeddings_model:
        vqmodel = vq_stoks.RQBottleneckTransformer.load_model(frozen_embeddings_model)
        model = _make_model(size, quantizers, tunables, dataset,
                            stoks_codes=vqmodel.vq_codes+1, stoks_width=vqmodel.rq.layers[0]._codebook.embed[0].shape[-1])
        model.load_frozen_semantic_embeddings(vqmodel)
    else:
        model = _make_model(size, quantizers, tunables, dataset)
    return model