{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "T4", "authorship_tag": "ABX9TyNiDU9ykIeYxO86Lmuid+ph", "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "markdown", "source": [ "### Install packages and download models" ], "metadata": { "id": "yLqBa4uYPrqE" } }, { "cell_type": "code", "source": [ "%%shell\n", "git clone https://github.com/yl4579/StyleTTS2.git\n", "cd StyleTTS2\n", "pip install SoundFile torchaudio munch torch pydub pyyaml librosa nltk matplotlib accelerate transformers phonemizer einops einops-exts tqdm typing-extensions git+https://github.com/resemble-ai/monotonic_align.git\n", "sudo apt-get install espeak-ng\n", "git-lfs clone https://huggingface.co./yl4579/StyleTTS2-LibriTTS\n", "mv StyleTTS2-LibriTTS/Models ." ], "metadata": { "id": "H72WF06ZPrTF" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### Download dataset (LJSpeech, 200 samples, ~15 minutes of data)\n", "\n", "You can definitely do it with fewer samples. This is just a proof of concept with 200 smaples." ], "metadata": { "id": "G398sL8wPzTB" } }, { "cell_type": "code", "source": [ "%cd StyleTTS2\n", "!rm -rf Data" ], "metadata": { "id": "kJuQUBrEPy5C" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "!gdown --id 1vqz26D3yn7OXS2vbfYxfSnpLS6m6tOFP\n", "!unzip Data.zip" ], "metadata": { "id": "mDXW8ZZePuSb" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### Change the finetuning config\n", "\n", "Depending on the GPU you got, you may want to change the bacth size, max audio length, epiochs and so on." ], "metadata": { "id": "_AlBQREWU8ud" } }, { "cell_type": "code", "source": [ "config_path = \"Configs/config_ft.yml\"\n", "\n", "import yaml\n", "config = yaml.safe_load(open(config_path))" ], "metadata": { "id": "7uEITi0hU4I2" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "config['data_params']['root_path'] = \"Data/wavs\"\n", "\n", "config['batch_size'] = 2 # not enough RAM\n", "config['max_len'] = 100 # not enough RAM\n", "config['loss_params']['joint_epoch'] = 110 # we do not do SLM adversarial training due to not enough RAM\n", "\n", "with open(config_path, 'w') as outfile:\n", " yaml.dump(config, outfile, default_flow_style=True)" ], "metadata": { "id": "TPTRgOKSVT4K" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### Start finetuning\n" ], "metadata": { "id": "uUuB_19NWj2Y" } }, { "cell_type": "code", "source": [ "!python train_finetune.py --config_path ./Configs/config_ft.yml" ], "metadata": { "id": "HZVAD5GKWm-O" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### Test the model quality\n", "\n", "Note that this mainly serves as a proof of concept due to RAM limitation of free Colab instances. A lot of settings are suboptimal. In the future when DDP works for train_second.py, we will also add mixed precision finetuning to save time and RAM. You can also add SLM adversarial training run if you have paid Colab services (such as A100 with 40G of RAM)." 
], "metadata": { "id": "I0_7wsGkXGfc" } }, { "cell_type": "code", "source": [ "import nltk\n", "nltk.download('punkt')" ], "metadata": { "id": "OPLphjbncE7p" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "import torch\n", "torch.manual_seed(0)\n", "torch.backends.cudnn.benchmark = False\n", "torch.backends.cudnn.deterministic = True\n", "\n", "import random\n", "random.seed(0)\n", "\n", "import numpy as np\n", "np.random.seed(0)\n", "\n", "# load packages\n", "import time\n", "import random\n", "import yaml\n", "from munch import Munch\n", "import numpy as np\n", "import torch\n", "from torch import nn\n", "import torch.nn.functional as F\n", "import torchaudio\n", "import librosa\n", "from nltk.tokenize import word_tokenize\n", "\n", "from models import *\n", "from utils import *\n", "from text_utils import TextCleaner\n", "textclenaer = TextCleaner()\n", "\n", "%matplotlib inline\n", "\n", "to_mel = torchaudio.transforms.MelSpectrogram(\n", " n_mels=80, n_fft=2048, win_length=1200, hop_length=300)\n", "mean, std = -4, 4\n", "\n", "def length_to_mask(lengths):\n", " mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)\n", " mask = torch.gt(mask+1, lengths.unsqueeze(1))\n", " return mask\n", "\n", "def preprocess(wave):\n", " wave_tensor = torch.from_numpy(wave).float()\n", " mel_tensor = to_mel(wave_tensor)\n", " mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n", " return mel_tensor\n", "\n", "def compute_style(path):\n", " wave, sr = librosa.load(path, sr=24000)\n", " audio, index = librosa.effects.trim(wave, top_db=30)\n", " if sr != 24000:\n", " audio = librosa.resample(audio, sr, 24000)\n", " mel_tensor = preprocess(audio).to(device)\n", "\n", " with torch.no_grad():\n", " ref_s = model.style_encoder(mel_tensor.unsqueeze(1))\n", " ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))\n", "\n", " return torch.cat([ref_s, ref_p], dim=1)\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "\n", "# load phonemizer\n", "import phonemizer\n", "global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)\n", "\n", "config = yaml.safe_load(open(\"Models/LJSpeech/config_ft.yml\"))\n", "\n", "# load pretrained ASR model\n", "ASR_config = config.get('ASR_config', False)\n", "ASR_path = config.get('ASR_path', False)\n", "text_aligner = load_ASR_models(ASR_path, ASR_config)\n", "\n", "# load pretrained F0 model\n", "F0_path = config.get('F0_path', False)\n", "pitch_extractor = load_F0_models(F0_path)\n", "\n", "# load BERT model\n", "from Utils.PLBERT.util import load_plbert\n", "BERT_path = config.get('PLBERT_dir', False)\n", "plbert = load_plbert(BERT_path)\n", "\n", "model_params = recursive_munch(config['model_params'])\n", "model = build_model(model_params, text_aligner, pitch_extractor, plbert)\n", "_ = [model[key].eval() for key in model]\n", "_ = [model[key].to(device) for key in model]" ], "metadata": { "id": "jIIAoDACXJL0" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "files = [f for f in os.listdir(\"Models/LJSpeech/\") if f.endswith('.pth')]\n", "sorted_files = sorted(files, key=lambda x: int(x.split('_')[-1].split('.')[0]))" ], "metadata": { "id": "eKXRAyyzcMpQ" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "params_whole = torch.load(\"Models/LJSpeech/\" + sorted_files[-1], map_location='cpu')\n", "params = params_whole['net']" ], "metadata": { 
"id": "ULuU9-VDb9Pk" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "for key in model:\n", " if key in params:\n", " print('%s loaded' % key)\n", " try:\n", " model[key].load_state_dict(params[key])\n", " except:\n", " from collections import OrderedDict\n", " state_dict = params[key]\n", " new_state_dict = OrderedDict()\n", " for k, v in state_dict.items():\n", " name = k[7:] # remove `module.`\n", " new_state_dict[name] = v\n", " # load params\n", " model[key].load_state_dict(new_state_dict, strict=False)\n", "# except:\n", "# _load(params[key], model[key])\n", "_ = [model[key].eval() for key in model]" ], "metadata": { "id": "J-U29yIYc2ea" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule" ], "metadata": { "id": "jrPQ_Yrwc3n6" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "sampler = DiffusionSampler(\n", " model.diffusion.diffusion,\n", " sampler=ADPM2Sampler(),\n", " sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters\n", " clamp=False\n", ")" ], "metadata": { "id": "n2CWYNoqc455" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):\n", " text = text.strip()\n", " ps = global_phonemizer.phonemize([text])\n", " ps = word_tokenize(ps[0])\n", " ps = ' '.join(ps)\n", " tokens = textclenaer(ps)\n", " tokens.insert(0, 0)\n", " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n", "\n", " with torch.no_grad():\n", " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n", " text_mask = length_to_mask(input_lengths).to(device)\n", "\n", " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n", " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n", " d_en = model.bert_encoder(bert_dur).transpose(-1, -2)\n", "\n", " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),\n", " embedding=bert_dur,\n", " embedding_scale=embedding_scale,\n", " features=ref_s, # reference from the same speaker as the embedding\n", " num_steps=diffusion_steps).squeeze(1)\n", "\n", "\n", " s = s_pred[:, 128:]\n", " ref = s_pred[:, :128]\n", "\n", " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n", " s = beta * s + (1 - beta) * ref_s[:, 128:]\n", "\n", " d = model.predictor.text_encoder(d_en,\n", " s, input_lengths, text_mask)\n", "\n", " x, _ = model.predictor.lstm(d)\n", " duration = model.predictor.duration_proj(x)\n", "\n", " duration = torch.sigmoid(duration).sum(axis=-1)\n", " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n", "\n", " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n", " c_frame = 0\n", " for i in range(pred_aln_trg.size(0)):\n", " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n", " c_frame += int(pred_dur[i].data)\n", "\n", " # encode prosody\n", " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n", " if model_params.decoder.type == \"hifigan\":\n", " asr_new = torch.zeros_like(en)\n", " asr_new[:, :, 0] = en[:, :, 0]\n", " asr_new[:, :, 1:] = en[:, :, 0:-1]\n", " en = asr_new\n", "\n", " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n", "\n", " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n", " if model_params.decoder.type == \"hifigan\":\n", " asr_new = torch.zeros_like(asr)\n", " asr_new[:, :, 0] = asr[:, :, 0]\n", " 
asr_new[:, :, 1:] = asr[:, :, 0:-1]\n", " asr = asr_new\n", "\n", " out = model.decoder(asr,\n", " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n", "\n", "\n", " return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later" ], "metadata": { "id": "2x5kVb3nc_eY" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### Synthesize speech" ], "metadata": { "id": "O159JnwCc6CC" } }, { "cell_type": "code", "source": [ "text = '''Maltby and Company would issue warrants on them deliverable to the importer, and the goods were then passed to be stored in neighboring warehouses.\n", "'''" ], "metadata": { "id": "ThciXQ6rc9Eq" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# get a random reference in the training set, note that it doesn't matter which one you use\n", "path = \"Data/wavs/LJ001-0110.wav\"\n", "# this style vector ref_s can be saved as a parameter together with the model weights\n", "ref_s = compute_style(path)" ], "metadata": { "id": "jldPkJyCc83a" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "start = time.time()\n", "wav = inference(text, ref_s, alpha=0.9, beta=0.9, diffusion_steps=10, embedding_scale=1)\n", "rtf = (time.time() - start) / (len(wav) / 24000)\n", "print(f\"RTF = {rtf:5f}\")\n", "import IPython.display as ipd\n", "display(ipd.Audio(wav, rate=24000, normalize=False))" ], "metadata": { "id": "_mIU0jqDdQ-c" }, "execution_count": null, "outputs": [] } ] }