diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..3e7e46f8cc5ff83646d60b5797be9ae3cca4a50f 100644 --- a/.gitattributes +++ b/.gitattributes @@ -21,6 +21,8 @@ *.pkl filter=lfs diff=lfs merge=lfs -text *.pt filter=lfs diff=lfs merge=lfs -text *.pth filter=lfs diff=lfs merge=lfs -text +*.t7 filter=lfs diff=lfs merge=lfs -text +OOD_texts.txt filter=lfs diff=lfs merge=lfs -text *.rar filter=lfs diff=lfs merge=lfs -text *.safetensors filter=lfs diff=lfs merge=lfs -text saved_model/**/* filter=lfs diff=lfs merge=lfs -text @@ -32,4 +34,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.xz filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text +*.wav filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.mp3 filter=lfs diff=lfs merge=lfs -text diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..d43535ffc3eacc55701f9b02b6c888797c6ab1e1 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,55 @@ +# Use specific version of nvidia cuda image +FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04 + +# Remove any third-party apt sources to avoid issues with expiring keys. +RUN rm -f /etc/apt/sources.list.d/*.list + +# Set shell and noninteractive environment variables +SHELL ["/bin/bash", "-c"] +ENV DEBIAN_FRONTEND=noninteractive +ENV SHELL=/bin/bash + +# Set working directory +WORKDIR / + +# Update and upgrade the system packages (Worker Template) +RUN apt-get update -y && \ + apt-get upgrade -y && \ + apt-get install --yes --no-install-recommends sudo ca-certificates git wget curl bash libgl1 libx11-6 software-properties-common ffmpeg build-essential -y &&\ + apt-get autoremove -y && \ + apt-get clean -y && \ + rm -rf /var/lib/apt/lists/* + +# Add the deadsnakes PPA and install Python 3.10 +RUN add-apt-repository ppa:deadsnakes/ppa -y && \ + apt-get install python3.10-dev python3.10-venv python3-pip -y --no-install-recommends && \ + ln -s /usr/bin/python3.10 /usr/bin/python && \ + rm /usr/bin/python3 && \ + ln -s /usr/bin/python3.10 /usr/bin/python3 && \ + apt-get autoremove -y && \ + apt-get clean -y && \ + rm -rf /var/lib/apt/lists/* + +# Download and install pip +RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && \ + python get-pip.py && \ + rm get-pip.py + +# Install Python dependencies (Worker Template) +COPY builder/requirements.txt /requirements.txt +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install --upgrade pip && \ + pip install -r /requirements.txt --no-cache-dir && \ + rm /requirements.txt +# Copy source code into image +COPY src . 
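+
+# Optional sanity check (illustrative sketch, kept commented out): torch and its CUDA
+# runtime wheels are installed from builder/requirements.txt rather than taken from the
+# base image, so a quick way to verify the stack inside the built image would be:
+# RUN python -c "import torch; print(torch.__version__, torch.version.cuda, torch.cuda.is_available())"
+# Note that torch.cuda.is_available() is expected to be False at build time unless the
+# builder exposes a GPU; it should report True at runtime on a GPU worker.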
+ +# Copy and run script to fetch models +COPY builder/fetch_models.py /fetch_models.py +RUN python /fetch_models.py && \ + rm /fetch_models.py + + + +# Set default command +CMD python -u /rp_handler.py \ No newline at end of file diff --git a/builder/fetch_models.py b/builder/fetch_models.py new file mode 100644 index 0000000000000000000000000000000000000000..e35d38306264e8f5a6d9ad78bcd6294fb09deee7 --- /dev/null +++ b/builder/fetch_models.py @@ -0,0 +1,13 @@ +import se_extractor as se + +_ = se.generate_voice_segments('openai_source_output.mp3',vad=False) +_ = se.generate_voice_segments('openai_source_output.mp3',vad=True) + +from resemble_enhance.enhancer.inference import denoise, enhance +import torchaudio + + +dwav, sr = torchaudio.load('openai_source_output.mp3') +dwav = dwav.mean(dim=0) + +wav1, new_sr = enhance(dwav, sr, 'cuda:0', nfe=32, solver='midpoint', lambd=0.9, tau=0.5) \ No newline at end of file diff --git a/builder/requirements.txt b/builder/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..09920c29bbd4441eba828a115a0b872c81b5f227 --- /dev/null +++ b/builder/requirements.txt @@ -0,0 +1,280 @@ +accelerate==0.25.0 +aiofiles==23.2.1 +altair==5.2.0 +amqp==5.2.0 +annotated-types==0.6.0 +antlr4-python3-runtime==4.9.3 +anyio==4.2.0 +asttokens==2.0.5 +astunparse==1.6.3 +async-timeout==4.0.3 +attrs==23.1.0 +audioread==3.0.1 +av==10.0.0 +Babel==2.14.0 +backcall==0.2.0 +beartype==0.16.4 +beautifulsoup4==4.12.2 +bibtexparser==2.0.0b4 +billiard==4.2.0 +boltons==23.0.0 +boto3==1.34.11 +botocore==1.34.11 +brotlipy==0.7.0 +cached-path==1.5.1 +cachetools==5.3.2 +celery==5.3.6 +celluloid==0.2.0 +certifi==2023.7.22 +cffi==1.15.1 +chardet==4.0.0 +charset-normalizer==2.0.4 +click==8.1.7 +click-didyoumean==0.3.0 +click-plugins==1.1.1 +click-repl==0.3.0 +clldutils==3.22.1 +cloudpickle==3.0.0 +cn2an==0.5.22 +colorama==0.4.6 +coloredlogs==15.0.1 +colorlog==6.8.0 +conda==23.9.0 +conda-build==3.27.0 +conda-content-trust==0.2.0 +conda_index==0.3.0 +conda-libmamba-solver==23.7.0 +conda-package-handling==2.2.0 +conda_package_streaming==0.9.0 +contourpy==1.2.0 +cryptography==41.0.3 +csvw==3.2.1 +ctranslate2==3.23.0 +cycler==0.12.1 +Cython==3.0.7 +dateparser==1.1.8 +decorator==5.1.1 +deepspeed==0.12.4 +distro==1.9.0 +dlinfo==1.2.1 +dnspython==2.4.2 +docopt==0.6.2 +dtw-python==1.3.1 +einops==0.7.0 +einops-exts==0.0.4 +email-validator==2.1.0.post1 +eng-to-ipa==0.0.2 +eventlet==0.34.2 +exceptiongroup==1.0.4 +executing==0.8.3 +expecttest==0.1.6 +fastapi==0.108.0 +faster-whisper==0.10.0 +ffmpy==0.3.1 +filelock==3.9.0 +flatbuffers==23.5.26 +fonttools==4.47.0 +fsspec==2023.9.2 +gmpy2==2.1.2 +google-api-core==2.15.0 +google-auth==2.25.2 +google-cloud-core==2.4.1 +google-cloud-storage==2.14.0 +google-crc32c==1.5.0 +google-resumable-media==2.7.0 +googleapis-common-protos==1.62.0 +gradio==4.8.0 +gradio_client==0.7.1 +greenlet==3.0.3 +gruut==2.3.4 +gruut-ipa==0.13.0 +gruut-lang-en==2.0.0 +h11==0.14.0 +hjson==3.1.0 +httpcore==1.0.2 +httptools==0.6.1 +httpx==0.26.0 +huggingface-hub==0.19.4 +humanfriendly==10.0 +hypothesis==6.87.1 +icontract==2.6.6 +idna==3.4 +importlib-resources==6.1.1 +inflect==7.0.0 +interegular==0.3.2 +ipython==8.15.0 +isodate==0.6.1 +itsdangerous==2.1.2 +jedi==0.18.1 +jieba==0.42.1 +Jinja2==3.1.2 +jmespath==1.0.1 +joblib==1.3.2 +jsonlines==1.2.0 +jsonpatch==1.32 +jsonpointer==2.1 +jsonschema==4.20.0 +jsonschema-specifications==2023.12.1 +kiwisolver==1.4.5 +kombu==5.3.4 +language-tags==1.2.0 +lark==1.1.8 +lazy_loader==0.3 +libarchive-c==2.9 +libmambapy==1.4.1 
+librosa==0.10.1 +llvmlite==0.41.1 +lxml==5.0.0 +Markdown==3.5.1 +markdown-it-py==3.0.0 +MarkupSafe==2.1.1 +matplotlib==3.8.1 +matplotlib-inline==0.1.6 +mdurl==0.1.2 +mkl-fft==1.3.8 +mkl-random==1.2.4 +mkl-service==2.4.0 +monotonic_align==1.2 +more-itertools==8.12.0 +mpmath==1.3.0 +msgpack==1.0.7 +munch==4.0.0 +nest-asyncio==1.5.8 +networkx==2.8.8 +ninja==1.11.1.1 +nltk==3.8.1 +num2words==0.5.13 +numba==0.58.1 +numpy==1.26.2 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==8.9.2.26 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.18.1 +nvidia-nvjitlink-cu12==12.3.101 +nvidia-nvtx-cu12==12.1.105 +omegaconf==2.3.0 +onnxruntime==1.16.3 +openai==1.6.1 +openai-whisper==20231117 +orjson==3.9.10 +outlines==0.0.21 +packaging==23.1 +pandas==2.1.3 +parso==0.8.3 +perscache==0.6.1 +pexpect==4.8.0 +phonemizer==3.2.1 +pickleshare==0.7.5 +Pillow==9.4.0 +pip==23.2.1 +pkginfo==1.9.6 +platformdirs==4.1.0 +pluggy==1.0.0 +pooch==1.8.0 +proces==0.1.7 +progressbar==2.5 +prompt-toolkit==3.0.36 +protobuf==4.25.1 +psutil==5.9.0 +ptflops==0.7.1.2 +ptyprocess==0.7.0 +pure-eval==0.2.2 +py-cpuinfo==9.0.0 +pyasn1==0.5.1 +pyasn1-modules==0.3.0 +pycosat==0.6.4 +pycparser==2.21 +pydantic==2.5.3 +pydantic_core==2.14.6 +pydantic-extra-types==2.3.0 +pydantic-settings==2.1.0 +pydub==0.25.1 +Pygments==2.15.1 +pylatexenc==2.10 +pynvml==11.5.0 +pyOpenSSL==23.2.0 +pyparsing==3.1.1 +pypinyin==0.50.0 +PySocks==1.7.1 +python-crfsuite==0.9.10 +python-dateutil==2.8.2 +python-dotenv==1.0.0 +python-etcd==0.4.5 +python-multipart==0.0.6 +pytz==2023.3.post1 +PyYAML==6.0 +rdflib==7.0.0 +redis==5.0.1 +referencing==0.32.0 +regex==2023.12.25 +requests==2.31.0 +resampy==0.4.2 +resemble-enhance==0.0.1 +rfc3986==1.5.0 +rich==13.7.0 +rotary-embedding-torch==0.5.3 +rpds-py==0.16.2 +rsa==4.9 +ruamel.yaml==0.17.21 +ruamel.yaml.clib==0.2.6 +s3transfer==0.10.0 +safetensors==0.4.1 +scikit-learn==1.3.2 +scipy==1.11.4 +segments==2.2.1 +semantic-version==2.10.0 +setuptools==68.0.0 +shellingham==1.5.4 +six==1.16.0 +sniffio==1.3.0 +sortedcontainers==2.4.0 +soundfile==0.12.1 +soupsieve==2.5 +sox==1.4.1 +soxr==0.3.7 +stack-data==0.2.0 +starlette==0.32.0.post1 +sympy==1.11.1 +tabulate==0.8.10 +threadpoolctl==3.2.0 +tiktoken==0.5.2 +tokenizers==0.13.3 +tomli==2.0.1 +tomlkit==0.12.0 +toolz==0.12.0 +torch==2.1.1 +torchaudio==2.1.1 +torchelastic==0.2.2 +torchvision==0.16.1 +tortoise-tts==3.0.0 +tqdm==4.66.1 +traitlets==5.7.1 +transformers==4.31.0 +triton==2.1.0 +truststore==0.8.0 +typer==0.9.0 +types-dataclasses==0.6.6 +typing==3.7.4.3 +typing_extensions==4.8.0 +tzdata==2023.4 +tzlocal==5.2 +ujson==5.9.0 +Unidecode==1.3.7 +uritemplate==4.1.1 +urllib3==1.26.16 +uuid==1.30 +uvicorn==0.25.0 +uvloop==0.19.0 +vine==5.1.0 +watchfiles==0.21.0 +wcwidth==0.2.5 +websockets==11.0.3 +wheel==0.41.2 +whisper-timestamped==1.14.2 +zstandard==0.19.0 diff --git a/src/.gitattributes b/src/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..3e88d4a57b579a66963f7e5df240b276e7a75bd6 --- /dev/null +++ b/src/.gitattributes @@ -0,0 +1,2 @@ +*.txt filter=lfs diff=lfs merge=lfs -text +*.t7 filter=lfs diff=lfs merge=lfs -text diff --git a/src/Configs/config.yml b/src/Configs/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..b74b8ee8a72f28f33edfaa3ea992342467e801cb --- /dev/null +++ b/src/Configs/config.yml @@ -0,0 +1,116 @@ 
+log_dir: "Models/LJSpeech"
+first_stage_path: "first_stage.pth"
+save_freq: 2
+log_interval: 10
+device: "cuda"
+epochs_1st: 200 # number of epochs for first stage training (pre-training)
+epochs_2nd: 100 # number of epochs for second stage training (joint training)
+batch_size: 16
+max_len: 400 # maximum number of frames
+pretrained_model: ""
+second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
+load_only_params: false # set to true if you do not want to load epoch numbers and optimizer parameters
+
+F0_path: "Utils/JDC/bst.t7"
+ASR_config: "Utils/ASR/config.yml"
+ASR_path: "Utils/ASR/epoch_00080.pth"
+PLBERT_dir: 'Utils/PLBERT/'
+
+data_params:
+  train_data: "Data/train_list.txt"
+  val_data: "Data/val_list.txt"
+  root_path: "/local/LJSpeech-1.1/wavs"
+  OOD_data: "Data/OOD_texts.txt"
+  min_length: 50 # sample until texts with this size are obtained for OOD texts
+
+preprocess_params:
+  sr: 24000
+  spect_params:
+    n_fft: 2048
+    win_length: 1200
+    hop_length: 300
+
+model_params:
+  multispeaker: false
+
+  dim_in: 64
+  hidden_dim: 512
+  max_conv_dim: 512
+  n_layer: 3
+  n_mels: 80
+
+  n_token: 178 # number of phoneme tokens
+  max_dur: 50 # maximum duration of a single phoneme
+  style_dim: 128 # style vector size
+
+  dropout: 0.2
+
+  # config for decoder
+  decoder:
+    type: 'istftnet' # either hifigan or istftnet
+    resblock_kernel_sizes: [3,7,11]
+    upsample_rates : [10, 6]
+    upsample_initial_channel: 512
+    resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
+    upsample_kernel_sizes: [20, 12]
+    gen_istft_n_fft: 20
+    gen_istft_hop_size: 5
+
+  # speech language model config
+  slm:
+    model: 'microsoft/wavlm-base-plus'
+    sr: 16000 # sampling rate of SLM
+    hidden: 768 # hidden size of SLM
+    nlayers: 13 # number of layers of SLM
+    initial_channel: 64 # initial channels of SLM discriminator head
+
+  # style diffusion model config
+  diffusion:
+    embedding_mask_proba: 0.1
+    # transformer config
+    transformer:
+      num_layers: 3
+      num_heads: 8
+      head_features: 64
+      multiplier: 2
+
+    # diffusion distribution config
+    dist:
+      sigma_data: 0.2 # placeholder for estimate_sigma_data set to false
+      estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
+      mean: -3.0
+      std: 1.0
+
+loss_params:
+  lambda_mel: 5. # mel reconstruction loss
+  lambda_gen: 1. # generator loss
+  lambda_slm: 1. # slm feature matching loss
+
+  lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
+  lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
+  TMA_epoch: 50 # TMA starting epoch (1st stage)
+
+  lambda_F0: 1. # F0 reconstruction loss (2nd stage)
+  lambda_norm: 1. # norm reconstruction loss (2nd stage)
+  lambda_dur: 1. # duration loss (2nd stage)
+  lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
+  lambda_sty: 1. # style reconstruction loss (2nd stage)
+  lambda_diff: 1.
# score matching loss (2nd stage) + + diff_epoch: 20 # style diffusion starting epoch (2nd stage) + joint_epoch: 50 # joint training starting epoch (2nd stage) + +optimizer_params: + lr: 0.0001 # general learning rate + bert_lr: 0.00001 # learning rate for PLBERT + ft_lr: 0.00001 # learning rate for acoustic modules + +slmadv_params: + min_len: 400 # minimum length of samples + max_len: 500 # maximum length of samples + batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size + iter: 10 # update the discriminator every this iterations of generator update + thresh: 5 # gradient norm above which the gradient is scaled + scale: 0.01 # gradient scaling factor for predictors from SLM discriminators + sig: 1.5 # sigma for differentiable duration modeling + \ No newline at end of file diff --git a/src/Configs/config_ft.yml b/src/Configs/config_ft.yml new file mode 100644 index 0000000000000000000000000000000000000000..00ae95fc734934a4639681ec9981222fc462c270 --- /dev/null +++ b/src/Configs/config_ft.yml @@ -0,0 +1,111 @@ +log_dir: "Models/LJSpeech" +save_freq: 5 +log_interval: 10 +device: "cuda" +epochs: 50 # number of finetuning epoch (1 hour of data) +batch_size: 8 +max_len: 400 # maximum number of frames +pretrained_model: "Models/LibriTTS/epochs_2nd_00020.pth" +second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage +load_only_params: true # set to true if do not want to load epoch numbers and optimizer parameters + +F0_path: "Utils/JDC/bst.t7" +ASR_config: "Utils/ASR/config.yml" +ASR_path: "Utils/ASR/epoch_00080.pth" +PLBERT_dir: 'Utils/PLBERT/' + +data_params: + train_data: "Data/train_list.txt" + val_data: "Data/val_list.txt" + root_path: "/local/LJSpeech-1.1/wavs" + OOD_data: "Data/OOD_texts.txt" + min_length: 50 # sample until texts with this size are obtained for OOD texts + +preprocess_params: + sr: 24000 + spect_params: + n_fft: 2048 + win_length: 1200 + hop_length: 300 + +model_params: + multispeaker: true + + dim_in: 64 + hidden_dim: 512 + max_conv_dim: 512 + n_layer: 3 + n_mels: 80 + + n_token: 178 # number of phoneme tokens + max_dur: 50 # maximum duration of a single phoneme + style_dim: 128 # style vector size + + dropout: 0.2 + + # config for decoder + decoder: + type: 'hifigan' # either hifigan or istftnet + resblock_kernel_sizes: [3,7,11] + upsample_rates : [10,5,3,2] + upsample_initial_channel: 512 + resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]] + upsample_kernel_sizes: [20,10,6,4] + + # speech language model config + slm: + model: 'microsoft/wavlm-base-plus' + sr: 16000 # sampling rate of SLM + hidden: 768 # hidden size of SLM + nlayers: 13 # number of layers of SLM + initial_channel: 64 # initial channels of SLM discriminator head + + # style diffusion model config + diffusion: + embedding_mask_proba: 0.1 + # transformer config + transformer: + num_layers: 3 + num_heads: 8 + head_features: 64 + multiplier: 2 + + # diffusion distribution config + dist: + sigma_data: 0.2 # placeholder for estimate_sigma_data set to false + estimate_sigma_data: true # estimate sigma_data from the current batch if set to true + mean: -3.0 + std: 1.0 + +loss_params: + lambda_mel: 5. # mel reconstruction loss + lambda_gen: 1. # generator loss + lambda_slm: 1. # slm feature matching loss + + lambda_mono: 1. # monotonic alignment loss (TMA) + lambda_s2s: 1. # sequence-to-sequence loss (TMA) + + lambda_F0: 1. # F0 reconstruction loss + lambda_norm: 1. # norm reconstruction loss + lambda_dur: 1. 
# duration loss + lambda_ce: 20. # duration predictor probability output CE loss + lambda_sty: 1. # style reconstruction loss + lambda_diff: 1. # score matching loss + + diff_epoch: 10 # style diffusion starting epoch + joint_epoch: 30 # joint training starting epoch + +optimizer_params: + lr: 0.0001 # general learning rate + bert_lr: 0.00001 # learning rate for PLBERT + ft_lr: 0.0001 # learning rate for acoustic modules + +slmadv_params: + min_len: 400 # minimum length of samples + max_len: 500 # maximum length of samples + batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size + iter: 10 # update the discriminator every this iterations of generator update + thresh: 5 # gradient norm above which the gradient is scaled + scale: 0.01 # gradient scaling factor for predictors from SLM discriminators + sig: 1.5 # sigma for differentiable duration modeling + diff --git a/src/Configs/config_libritts.yml b/src/Configs/config_libritts.yml new file mode 100644 index 0000000000000000000000000000000000000000..135d87260aa53cfb18b665333d44744ce5b4152a --- /dev/null +++ b/src/Configs/config_libritts.yml @@ -0,0 +1,113 @@ +log_dir: "Models/LibriTTS" +first_stage_path: "first_stage.pth" +save_freq: 1 +log_interval: 10 +device: "cuda" +epochs_1st: 50 # number of epochs for first stage training (pre-training) +epochs_2nd: 30 # number of peochs for second stage training (joint training) +batch_size: 16 +max_len: 300 # maximum number of frames +pretrained_model: "" +second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage +load_only_params: false # set to true if do not want to load epoch numbers and optimizer parameters + +F0_path: "Utils/JDC/bst.t7" +ASR_config: "Utils/ASR/config.yml" +ASR_path: "Utils/ASR/epoch_00080.pth" +PLBERT_dir: 'Utils/PLBERT/' + +data_params: + train_data: "Data/train_list.txt" + val_data: "Data/val_list.txt" + root_path: "" + OOD_data: "Data/OOD_texts.txt" + min_length: 50 # sample until texts with this size are obtained for OOD texts + +preprocess_params: + sr: 24000 + spect_params: + n_fft: 2048 + win_length: 1200 + hop_length: 300 + +model_params: + multispeaker: true + + dim_in: 64 + hidden_dim: 512 + max_conv_dim: 512 + n_layer: 3 + n_mels: 80 + + n_token: 178 # number of phoneme tokens + max_dur: 50 # maximum duration of a single phoneme + style_dim: 128 # style vector size + + dropout: 0.2 + + # config for decoder + decoder: + type: 'hifigan' # either hifigan or istftnet + resblock_kernel_sizes: [3,7,11] + upsample_rates : [10,5,3,2] + upsample_initial_channel: 512 + resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]] + upsample_kernel_sizes: [20,10,6,4] + + # speech language model config + slm: + model: 'microsoft/wavlm-base-plus' + sr: 16000 # sampling rate of SLM + hidden: 768 # hidden size of SLM + nlayers: 13 # number of layers of SLM + initial_channel: 64 # initial channels of SLM discriminator head + + # style diffusion model config + diffusion: + embedding_mask_proba: 0.1 + # transformer config + transformer: + num_layers: 3 + num_heads: 8 + head_features: 64 + multiplier: 2 + + # diffusion distribution config + dist: + sigma_data: 0.2 # placeholder for estimate_sigma_data set to false + estimate_sigma_data: true # estimate sigma_data from the current batch if set to true + mean: -3.0 + std: 1.0 + +loss_params: + lambda_mel: 5. # mel reconstruction loss + lambda_gen: 1. # generator loss + lambda_slm: 1. # slm feature matching loss + + lambda_mono: 1. 
# monotonic alignment loss (1st stage, TMA) + lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA) + TMA_epoch: 5 # TMA starting epoch (1st stage) + + lambda_F0: 1. # F0 reconstruction loss (2nd stage) + lambda_norm: 1. # norm reconstruction loss (2nd stage) + lambda_dur: 1. # duration loss (2nd stage) + lambda_ce: 20. # duration predictor probability output CE loss (2nd stage) + lambda_sty: 1. # style reconstruction loss (2nd stage) + lambda_diff: 1. # score matching loss (2nd stage) + + diff_epoch: 10 # style diffusion starting epoch (2nd stage) + joint_epoch: 15 # joint training starting epoch (2nd stage) + +optimizer_params: + lr: 0.0001 # general learning rate + bert_lr: 0.00001 # learning rate for PLBERT + ft_lr: 0.00001 # learning rate for acoustic modules + +slmadv_params: + min_len: 400 # minimum length of samples + max_len: 500 # maximum length of samples + batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size + iter: 20 # update the discriminator every this iterations of generator update + thresh: 5 # gradient norm above which the gradient is scaled + scale: 0.01 # gradient scaling factor for predictors from SLM discriminators + sig: 1.5 # sigma for differentiable duration modeling diff --git a/src/Configs/hg.yml b/src/Configs/hg.yml new file mode 100644 index 0000000000000000000000000000000000000000..8463e93b297bade9ce60032f3006faa1376f0fec --- /dev/null +++ b/src/Configs/hg.yml @@ -0,0 +1,21 @@ +{ASR_config: Utils/ASR/config.yml, ASR_path: Utils/ASR/epoch_00080.pth, F0_path: Utils/JDC/bst.t7, + PLBERT_dir: Utils/PLBERT/, batch_size: 8, data_params: {OOD_data: Data/OOD_texts.txt, + min_length: 50, root_path: '', train_data: Data/train_list.txt, val_data: Data/val_list.txt}, + device: cuda, epochs_1st: 40, epochs_2nd: 25, first_stage_path: first_stage.pth, + load_only_params: false, log_dir: Models/LibriTTS, log_interval: 10, loss_params: { + TMA_epoch: 4, diff_epoch: 0, joint_epoch: 0, lambda_F0: 1.0, lambda_ce: 20.0, + lambda_diff: 1.0, lambda_dur: 1.0, lambda_gen: 1.0, lambda_mel: 5.0, lambda_mono: 1.0, + lambda_norm: 1.0, lambda_s2s: 1.0, lambda_slm: 1.0, lambda_sty: 1.0}, max_len: 300, + model_params: {decoder: {resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, + 5]], resblock_kernel_sizes: [3, 7, 11], type: hifigan, upsample_initial_channel: 512, + upsample_kernel_sizes: [20, 10, 6, 4], upsample_rates: [10, 5, 3, 2]}, diffusion: { + dist: {estimate_sigma_data: true, mean: -3.0, sigma_data: 0.19926648961191362, + std: 1.0}, embedding_mask_proba: 0.1, transformer: {head_features: 64, multiplier: 2, + num_heads: 8, num_layers: 3}}, dim_in: 64, dropout: 0.2, hidden_dim: 512, + max_conv_dim: 512, max_dur: 50, multispeaker: true, n_layer: 3, n_mels: 80, n_token: 178, + slm: {hidden: 768, initial_channel: 64, model: microsoft/wavlm-base-plus, nlayers: 13, + sr: 16000}, style_dim: 128}, optimizer_params: {bert_lr: 1.0e-05, ft_lr: 1.0e-05, + lr: 0.0001}, preprocess_params: {spect_params: {hop_length: 300, n_fft: 2048, + win_length: 1200}, sr: 24000}, pretrained_model: Models/LibriTTS/epoch_2nd_00002.pth, + save_freq: 1, second_stage_load_pretrained: true, slmadv_params: {batch_percentage: 0.5, + iter: 20, max_len: 500, min_len: 400, scale: 0.01, sig: 1.5, thresh: 5}} \ No newline at end of file diff --git a/src/Data/OOD_texts.txt b/src/Data/OOD_texts.txt new file mode 100644 index 0000000000000000000000000000000000000000..00d56ffe1e3a1140d0a93965ba2364056d68a205 --- /dev/null +++ b/src/Data/OOD_texts.txt @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:e0989ef6a9873b711befefcbe60660ced7a65532359277f766f4db504c558a72 +size 31758898 diff --git a/src/Data/train_list.txt b/src/Data/train_list.txt new file mode 100644 index 0000000000000000000000000000000000000000..7ea7b8da250cad73352d08675a2a31ad40044b03 --- /dev/null +++ b/src/Data/train_list.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a02392d09b88cb0dd5794d5aef056068b9741cde680c37fb34c607de83d77da0 +size 2195448 diff --git a/src/Data/val_list.txt b/src/Data/val_list.txt new file mode 100644 index 0000000000000000000000000000000000000000..5bd97dbdff13ac01444d731edd95b2d53e1d1d23 --- /dev/null +++ b/src/Data/val_list.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7e2a6f76b7698ce50ba199dfba60c784b758b3ef4981c05dffd0768db2934208 +size 17203 diff --git a/src/Models/epochs_2nd_00020.pth b/src/Models/epochs_2nd_00020.pth new file mode 100644 index 0000000000000000000000000000000000000000..b1508d92da82e15837d59ff8c2c9d63b591dd61e --- /dev/null +++ b/src/Models/epochs_2nd_00020.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1164ffe19a17449d2c722234cecaf2836b35a698fb8ffd42562d2663657dca0a +size 771390526 diff --git a/src/Modules/__init__.py b/src/Modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/src/Modules/__init__.py @@ -0,0 +1 @@ + diff --git a/src/Modules/__pycache__/__init__.cpython-310.pyc b/src/Modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e6ed27c919cb31a5cf942b48efa68561a6b6744 Binary files /dev/null and b/src/Modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/Modules/__pycache__/discriminators.cpython-310.pyc b/src/Modules/__pycache__/discriminators.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91a7e60ed39913dfdf1c01560ff4960cce63226c Binary files /dev/null and b/src/Modules/__pycache__/discriminators.cpython-310.pyc differ diff --git a/src/Modules/__pycache__/hifigan.cpython-310.pyc b/src/Modules/__pycache__/hifigan.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae7e39a3849b12e55e972c134e938e3ed1f51a23 Binary files /dev/null and b/src/Modules/__pycache__/hifigan.cpython-310.pyc differ diff --git a/src/Modules/__pycache__/utils.cpython-310.pyc b/src/Modules/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b915a59a21b7f467c0044bfa7214d5b0a823be91 Binary files /dev/null and b/src/Modules/__pycache__/utils.cpython-310.pyc differ diff --git a/src/Modules/diffusion/__init__.py b/src/Modules/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/src/Modules/diffusion/__init__.py @@ -0,0 +1 @@ + diff --git a/src/Modules/diffusion/__pycache__/__init__.cpython-310.pyc b/src/Modules/diffusion/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f938c601bbeabc858d27ee14b256453d07042065 Binary files /dev/null and b/src/Modules/diffusion/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/Modules/diffusion/__pycache__/diffusion.cpython-310.pyc b/src/Modules/diffusion/__pycache__/diffusion.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..ab0a59f6c05417d72b966293c95cf3b2296a0544 Binary files /dev/null and b/src/Modules/diffusion/__pycache__/diffusion.cpython-310.pyc differ diff --git a/src/Modules/diffusion/__pycache__/modules.cpython-310.pyc b/src/Modules/diffusion/__pycache__/modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef58b8f8077359081c1a604b80b0c21b561e9e0e Binary files /dev/null and b/src/Modules/diffusion/__pycache__/modules.cpython-310.pyc differ diff --git a/src/Modules/diffusion/__pycache__/sampler.cpython-310.pyc b/src/Modules/diffusion/__pycache__/sampler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bc649815cc06e6bc35e22c9b44d16633cc30da9 Binary files /dev/null and b/src/Modules/diffusion/__pycache__/sampler.cpython-310.pyc differ diff --git a/src/Modules/diffusion/__pycache__/utils.cpython-310.pyc b/src/Modules/diffusion/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bcdddd7e3e511e28cb9db88efa4526821649caec Binary files /dev/null and b/src/Modules/diffusion/__pycache__/utils.cpython-310.pyc differ diff --git a/src/Modules/diffusion/diffusion.py b/src/Modules/diffusion/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..d930ea37f463e661ddd0c0e2d3bb02cb8f55f67e --- /dev/null +++ b/src/Modules/diffusion/diffusion.py @@ -0,0 +1,92 @@ +from math import pi +from random import randint +from typing import Any, Optional, Sequence, Tuple, Union + +import torch +from einops import rearrange +from torch import Tensor, nn +from tqdm import tqdm + +from .utils import * +from .sampler import * + +""" +Diffusion Classes (generic for 1d data) +""" + + +class Model1d(nn.Module): + def __init__(self, unet_type: str = "base", **kwargs): + super().__init__() + diffusion_kwargs, kwargs = groupby("diffusion_", kwargs) + self.unet = None + self.diffusion = None + + def forward(self, x: Tensor, **kwargs) -> Tensor: + return self.diffusion(x, **kwargs) + + def sample(self, *args, **kwargs) -> Tensor: + return self.diffusion.sample(*args, **kwargs) + + +""" +Audio Diffusion Classes (specific for 1d audio data) +""" + + +def get_default_model_kwargs(): + return dict( + channels=128, + patch_size=16, + multipliers=[1, 2, 4, 4, 4, 4, 4], + factors=[4, 4, 4, 2, 2, 2], + num_blocks=[2, 2, 2, 2, 2, 2], + attentions=[0, 0, 0, 1, 1, 1, 1], + attention_heads=8, + attention_features=64, + attention_multiplier=2, + attention_use_rel_pos=False, + diffusion_type="v", + diffusion_sigma_distribution=UniformDistribution(), + ) + + +def get_default_sampling_kwargs(): + return dict(sigma_schedule=LinearSchedule(), sampler=VSampler(), clamp=True) + + +class AudioDiffusionModel(Model1d): + def __init__(self, **kwargs): + super().__init__(**{**get_default_model_kwargs(), **kwargs}) + + def sample(self, *args, **kwargs): + return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs}) + + +class AudioDiffusionConditional(Model1d): + def __init__( + self, + embedding_features: int, + embedding_max_length: int, + embedding_mask_proba: float = 0.1, + **kwargs, + ): + self.embedding_mask_proba = embedding_mask_proba + default_kwargs = dict( + **get_default_model_kwargs(), + unet_type="cfg", + context_embedding_features=embedding_features, + context_embedding_max_length=embedding_max_length, + ) + super().__init__(**{**default_kwargs, **kwargs}) + + def forward(self, *args, **kwargs): + default_kwargs = 
dict(embedding_mask_proba=self.embedding_mask_proba) + return super().forward(*args, **{**default_kwargs, **kwargs}) + + def sample(self, *args, **kwargs): + default_kwargs = dict( + **get_default_sampling_kwargs(), + embedding_scale=5.0, + ) + return super().sample(*args, **{**default_kwargs, **kwargs}) diff --git a/src/Modules/diffusion/modules.py b/src/Modules/diffusion/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..bca6001a1efd58a500c07f345c44d3a8270c950a --- /dev/null +++ b/src/Modules/diffusion/modules.py @@ -0,0 +1,700 @@ +from math import floor, log, pi +from typing import Any, List, Optional, Sequence, Tuple, Union + +from .utils import * + +import torch +import torch.nn as nn +from einops import rearrange, reduce, repeat +from einops.layers.torch import Rearrange +from einops_exts import rearrange_many +from torch import Tensor, einsum + + +""" +Utils +""" + + +class AdaLayerNorm(nn.Module): + def __init__(self, style_dim, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.fc = nn.Linear(style_dim, channels * 2) + + def forward(self, x, s): + x = x.transpose(-1, -2) + x = x.transpose(1, -1) + + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1) + + x = F.layer_norm(x, (self.channels,), eps=self.eps) + x = (1 + gamma) * x + beta + return x.transpose(1, -1).transpose(-1, -2) + + +class StyleTransformer1d(nn.Module): + def __init__( + self, + num_layers: int, + channels: int, + num_heads: int, + head_features: int, + multiplier: int, + use_context_time: bool = True, + use_rel_pos: bool = False, + context_features_multiplier: int = 1, + rel_pos_num_buckets: Optional[int] = None, + rel_pos_max_distance: Optional[int] = None, + context_features: Optional[int] = None, + context_embedding_features: Optional[int] = None, + embedding_max_length: int = 512, + ): + super().__init__() + + self.blocks = nn.ModuleList( + [ + StyleTransformerBlock( + features=channels + context_embedding_features, + head_features=head_features, + num_heads=num_heads, + multiplier=multiplier, + style_dim=context_features, + use_rel_pos=use_rel_pos, + rel_pos_num_buckets=rel_pos_num_buckets, + rel_pos_max_distance=rel_pos_max_distance, + ) + for i in range(num_layers) + ] + ) + + self.to_out = nn.Sequential( + Rearrange("b t c -> b c t"), + nn.Conv1d( + in_channels=channels + context_embedding_features, + out_channels=channels, + kernel_size=1, + ), + ) + + use_context_features = exists(context_features) + self.use_context_features = use_context_features + self.use_context_time = use_context_time + + if use_context_time or use_context_features: + context_mapping_features = channels + context_embedding_features + + self.to_mapping = nn.Sequential( + nn.Linear(context_mapping_features, context_mapping_features), + nn.GELU(), + nn.Linear(context_mapping_features, context_mapping_features), + nn.GELU(), + ) + + if use_context_time: + assert exists(context_mapping_features) + self.to_time = nn.Sequential( + TimePositionalEmbedding( + dim=channels, out_features=context_mapping_features + ), + nn.GELU(), + ) + + if use_context_features: + assert exists(context_features) and exists(context_mapping_features) + self.to_features = nn.Sequential( + nn.Linear( + in_features=context_features, out_features=context_mapping_features + ), + nn.GELU(), + ) + + self.fixed_embedding = FixedEmbedding( + max_length=embedding_max_length, 
features=context_embedding_features + ) + + def get_mapping( + self, time: Optional[Tensor] = None, features: Optional[Tensor] = None + ) -> Optional[Tensor]: + """Combines context time features and features into mapping""" + items, mapping = [], None + # Compute time features + if self.use_context_time: + assert_message = "use_context_time=True but no time features provided" + assert exists(time), assert_message + items += [self.to_time(time)] + # Compute features + if self.use_context_features: + assert_message = "context_features exists but no features provided" + assert exists(features), assert_message + items += [self.to_features(features)] + + # Compute joint mapping + if self.use_context_time or self.use_context_features: + mapping = reduce(torch.stack(items), "n b m -> b m", "sum") + mapping = self.to_mapping(mapping) + + return mapping + + def run(self, x, time, embedding, features): + mapping = self.get_mapping(time, features) + x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1) + mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1) + + for block in self.blocks: + x = x + mapping + x = block(x, features) + + x = x.mean(axis=1).unsqueeze(1) + x = self.to_out(x) + x = x.transpose(-1, -2) + + return x + + def forward( + self, + x: Tensor, + time: Tensor, + embedding_mask_proba: float = 0.0, + embedding: Optional[Tensor] = None, + features: Optional[Tensor] = None, + embedding_scale: float = 1.0, + ) -> Tensor: + b, device = embedding.shape[0], embedding.device + fixed_embedding = self.fixed_embedding(embedding) + if embedding_mask_proba > 0.0: + # Randomly mask embedding + batch_mask = rand_bool( + shape=(b, 1, 1), proba=embedding_mask_proba, device=device + ) + embedding = torch.where(batch_mask, fixed_embedding, embedding) + + if embedding_scale != 1.0: + # Compute both normal and fixed embedding outputs + out = self.run(x, time, embedding=embedding, features=features) + out_masked = self.run(x, time, embedding=fixed_embedding, features=features) + # Scale conditional output using classifier-free guidance + return out_masked + (out - out_masked) * embedding_scale + else: + return self.run(x, time, embedding=embedding, features=features) + + return x + + +class StyleTransformerBlock(nn.Module): + def __init__( + self, + features: int, + num_heads: int, + head_features: int, + style_dim: int, + multiplier: int, + use_rel_pos: bool, + rel_pos_num_buckets: Optional[int] = None, + rel_pos_max_distance: Optional[int] = None, + context_features: Optional[int] = None, + ): + super().__init__() + + self.use_cross_attention = exists(context_features) and context_features > 0 + + self.attention = StyleAttention( + features=features, + style_dim=style_dim, + num_heads=num_heads, + head_features=head_features, + use_rel_pos=use_rel_pos, + rel_pos_num_buckets=rel_pos_num_buckets, + rel_pos_max_distance=rel_pos_max_distance, + ) + + if self.use_cross_attention: + self.cross_attention = StyleAttention( + features=features, + style_dim=style_dim, + num_heads=num_heads, + head_features=head_features, + context_features=context_features, + use_rel_pos=use_rel_pos, + rel_pos_num_buckets=rel_pos_num_buckets, + rel_pos_max_distance=rel_pos_max_distance, + ) + + self.feed_forward = FeedForward(features=features, multiplier=multiplier) + + def forward( + self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None + ) -> Tensor: + x = self.attention(x, s) + x + if self.use_cross_attention: + x = self.cross_attention(x, s, context=context) + x + x = 
self.feed_forward(x) + x + return x + + +class StyleAttention(nn.Module): + def __init__( + self, + features: int, + *, + style_dim: int, + head_features: int, + num_heads: int, + context_features: Optional[int] = None, + use_rel_pos: bool, + rel_pos_num_buckets: Optional[int] = None, + rel_pos_max_distance: Optional[int] = None, + ): + super().__init__() + self.context_features = context_features + mid_features = head_features * num_heads + context_features = default(context_features, features) + + self.norm = AdaLayerNorm(style_dim, features) + self.norm_context = AdaLayerNorm(style_dim, context_features) + self.to_q = nn.Linear( + in_features=features, out_features=mid_features, bias=False + ) + self.to_kv = nn.Linear( + in_features=context_features, out_features=mid_features * 2, bias=False + ) + self.attention = AttentionBase( + features, + num_heads=num_heads, + head_features=head_features, + use_rel_pos=use_rel_pos, + rel_pos_num_buckets=rel_pos_num_buckets, + rel_pos_max_distance=rel_pos_max_distance, + ) + + def forward( + self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None + ) -> Tensor: + assert_message = "You must provide a context when using context_features" + assert not self.context_features or exists(context), assert_message + # Use context if provided + context = default(context, x) + # Normalize then compute q from input and k,v from context + x, context = self.norm(x, s), self.norm_context(context, s) + + q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1)) + # Compute and return attention + return self.attention(q, k, v) + + +class Transformer1d(nn.Module): + def __init__( + self, + num_layers: int, + channels: int, + num_heads: int, + head_features: int, + multiplier: int, + use_context_time: bool = True, + use_rel_pos: bool = False, + context_features_multiplier: int = 1, + rel_pos_num_buckets: Optional[int] = None, + rel_pos_max_distance: Optional[int] = None, + context_features: Optional[int] = None, + context_embedding_features: Optional[int] = None, + embedding_max_length: int = 512, + ): + super().__init__() + + self.blocks = nn.ModuleList( + [ + TransformerBlock( + features=channels + context_embedding_features, + head_features=head_features, + num_heads=num_heads, + multiplier=multiplier, + use_rel_pos=use_rel_pos, + rel_pos_num_buckets=rel_pos_num_buckets, + rel_pos_max_distance=rel_pos_max_distance, + ) + for i in range(num_layers) + ] + ) + + self.to_out = nn.Sequential( + Rearrange("b t c -> b c t"), + nn.Conv1d( + in_channels=channels + context_embedding_features, + out_channels=channels, + kernel_size=1, + ), + ) + + use_context_features = exists(context_features) + self.use_context_features = use_context_features + self.use_context_time = use_context_time + + if use_context_time or use_context_features: + context_mapping_features = channels + context_embedding_features + + self.to_mapping = nn.Sequential( + nn.Linear(context_mapping_features, context_mapping_features), + nn.GELU(), + nn.Linear(context_mapping_features, context_mapping_features), + nn.GELU(), + ) + + if use_context_time: + assert exists(context_mapping_features) + self.to_time = nn.Sequential( + TimePositionalEmbedding( + dim=channels, out_features=context_mapping_features + ), + nn.GELU(), + ) + + if use_context_features: + assert exists(context_features) and exists(context_mapping_features) + self.to_features = nn.Sequential( + nn.Linear( + in_features=context_features, out_features=context_mapping_features + ), + nn.GELU(), + ) + + 
self.fixed_embedding = FixedEmbedding( + max_length=embedding_max_length, features=context_embedding_features + ) + + def get_mapping( + self, time: Optional[Tensor] = None, features: Optional[Tensor] = None + ) -> Optional[Tensor]: + """Combines context time features and features into mapping""" + items, mapping = [], None + # Compute time features + if self.use_context_time: + assert_message = "use_context_time=True but no time features provided" + assert exists(time), assert_message + items += [self.to_time(time)] + # Compute features + if self.use_context_features: + assert_message = "context_features exists but no features provided" + assert exists(features), assert_message + items += [self.to_features(features)] + + # Compute joint mapping + if self.use_context_time or self.use_context_features: + mapping = reduce(torch.stack(items), "n b m -> b m", "sum") + mapping = self.to_mapping(mapping) + + return mapping + + def run(self, x, time, embedding, features): + mapping = self.get_mapping(time, features) + x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1) + mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1) + + for block in self.blocks: + x = x + mapping + x = block(x) + + x = x.mean(axis=1).unsqueeze(1) + x = self.to_out(x) + x = x.transpose(-1, -2) + + return x + + def forward( + self, + x: Tensor, + time: Tensor, + embedding_mask_proba: float = 0.0, + embedding: Optional[Tensor] = None, + features: Optional[Tensor] = None, + embedding_scale: float = 1.0, + ) -> Tensor: + b, device = embedding.shape[0], embedding.device + fixed_embedding = self.fixed_embedding(embedding) + if embedding_mask_proba > 0.0: + # Randomly mask embedding + batch_mask = rand_bool( + shape=(b, 1, 1), proba=embedding_mask_proba, device=device + ) + embedding = torch.where(batch_mask, fixed_embedding, embedding) + + if embedding_scale != 1.0: + # Compute both normal and fixed embedding outputs + out = self.run(x, time, embedding=embedding, features=features) + out_masked = self.run(x, time, embedding=fixed_embedding, features=features) + # Scale conditional output using classifier-free guidance + return out_masked + (out - out_masked) * embedding_scale + else: + return self.run(x, time, embedding=embedding, features=features) + + return x + + +""" +Attention Components +""" + + +class RelativePositionBias(nn.Module): + def __init__(self, num_buckets: int, max_distance: int, num_heads: int): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.num_heads = num_heads + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + + @staticmethod + def _relative_position_bucket( + relative_position: Tensor, num_buckets: int, max_distance: int + ): + num_buckets //= 2 + ret = (relative_position >= 0).to(torch.long) * num_buckets + n = torch.abs(relative_position) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = ( + max_exact + + ( + torch.log(n.float() / max_exact) + / log(max_distance / max_exact) + * (num_buckets - max_exact) + ).long() + ) + val_if_large = torch.min( + val_if_large, torch.full_like(val_if_large, num_buckets - 1) + ) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, num_queries: int, num_keys: int) -> Tensor: + i, j, device = num_queries, num_keys, self.relative_attention_bias.weight.device + q_pos = torch.arange(j - i, j, dtype=torch.long, device=device) + k_pos = torch.arange(j, dtype=torch.long, device=device) + rel_pos = rearrange(k_pos, "j 
-> 1 j") - rearrange(q_pos, "i -> i 1") + + relative_position_bucket = self._relative_position_bucket( + rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance + ) + + bias = self.relative_attention_bias(relative_position_bucket) + bias = rearrange(bias, "m n h -> 1 h m n") + return bias + + +def FeedForward(features: int, multiplier: int) -> nn.Module: + mid_features = features * multiplier + return nn.Sequential( + nn.Linear(in_features=features, out_features=mid_features), + nn.GELU(), + nn.Linear(in_features=mid_features, out_features=features), + ) + + +class AttentionBase(nn.Module): + def __init__( + self, + features: int, + *, + head_features: int, + num_heads: int, + use_rel_pos: bool, + out_features: Optional[int] = None, + rel_pos_num_buckets: Optional[int] = None, + rel_pos_max_distance: Optional[int] = None, + ): + super().__init__() + self.scale = head_features**-0.5 + self.num_heads = num_heads + self.use_rel_pos = use_rel_pos + mid_features = head_features * num_heads + + if use_rel_pos: + assert exists(rel_pos_num_buckets) and exists(rel_pos_max_distance) + self.rel_pos = RelativePositionBias( + num_buckets=rel_pos_num_buckets, + max_distance=rel_pos_max_distance, + num_heads=num_heads, + ) + if out_features is None: + out_features = features + + self.to_out = nn.Linear(in_features=mid_features, out_features=out_features) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Split heads + q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads) + # Compute similarity matrix + sim = einsum("... n d, ... m d -> ... n m", q, k) + sim = (sim + self.rel_pos(*sim.shape[-2:])) if self.use_rel_pos else sim + sim = sim * self.scale + # Get attention matrix with softmax + attn = sim.softmax(dim=-1) + # Compute values + out = einsum("... n m, ... m d -> ... 
n d", attn, v) + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + + +class Attention(nn.Module): + def __init__( + self, + features: int, + *, + head_features: int, + num_heads: int, + out_features: Optional[int] = None, + context_features: Optional[int] = None, + use_rel_pos: bool, + rel_pos_num_buckets: Optional[int] = None, + rel_pos_max_distance: Optional[int] = None, + ): + super().__init__() + self.context_features = context_features + mid_features = head_features * num_heads + context_features = default(context_features, features) + + self.norm = nn.LayerNorm(features) + self.norm_context = nn.LayerNorm(context_features) + self.to_q = nn.Linear( + in_features=features, out_features=mid_features, bias=False + ) + self.to_kv = nn.Linear( + in_features=context_features, out_features=mid_features * 2, bias=False + ) + + self.attention = AttentionBase( + features, + out_features=out_features, + num_heads=num_heads, + head_features=head_features, + use_rel_pos=use_rel_pos, + rel_pos_num_buckets=rel_pos_num_buckets, + rel_pos_max_distance=rel_pos_max_distance, + ) + + def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor: + assert_message = "You must provide a context when using context_features" + assert not self.context_features or exists(context), assert_message + # Use context if provided + context = default(context, x) + # Normalize then compute q from input and k,v from context + x, context = self.norm(x), self.norm_context(context) + q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1)) + # Compute and return attention + return self.attention(q, k, v) + + +""" +Transformer Blocks +""" + + +class TransformerBlock(nn.Module): + def __init__( + self, + features: int, + num_heads: int, + head_features: int, + multiplier: int, + use_rel_pos: bool, + rel_pos_num_buckets: Optional[int] = None, + rel_pos_max_distance: Optional[int] = None, + context_features: Optional[int] = None, + ): + super().__init__() + + self.use_cross_attention = exists(context_features) and context_features > 0 + + self.attention = Attention( + features=features, + num_heads=num_heads, + head_features=head_features, + use_rel_pos=use_rel_pos, + rel_pos_num_buckets=rel_pos_num_buckets, + rel_pos_max_distance=rel_pos_max_distance, + ) + + if self.use_cross_attention: + self.cross_attention = Attention( + features=features, + num_heads=num_heads, + head_features=head_features, + context_features=context_features, + use_rel_pos=use_rel_pos, + rel_pos_num_buckets=rel_pos_num_buckets, + rel_pos_max_distance=rel_pos_max_distance, + ) + + self.feed_forward = FeedForward(features=features, multiplier=multiplier) + + def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor: + x = self.attention(x) + x + if self.use_cross_attention: + x = self.cross_attention(x, context=context) + x + x = self.feed_forward(x) + x + return x + + +""" +Time Embeddings +""" + + +class SinusoidalEmbedding(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: Tensor) -> Tensor: + device, half_dim = x.device, self.dim // 2 + emb = torch.tensor(log(10000) / (half_dim - 1), device=device) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j") + return torch.cat((emb.sin(), emb.cos()), dim=-1) + + +class LearnedPositionalEmbedding(nn.Module): + """Used for continuous time""" + + def __init__(self, dim: int): + super().__init__() + assert (dim 
% 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x: Tensor) -> Tensor: + x = rearrange(x, "b -> b 1") + freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) + return fouriered + + +def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module: + return nn.Sequential( + LearnedPositionalEmbedding(dim), + nn.Linear(in_features=dim + 1, out_features=out_features), + ) + + +class FixedEmbedding(nn.Module): + def __init__(self, max_length: int, features: int): + super().__init__() + self.max_length = max_length + self.embedding = nn.Embedding(max_length, features) + + def forward(self, x: Tensor) -> Tensor: + batch_size, length, device = *x.shape[0:2], x.device + assert_message = "Input sequence length must be <= max_length" + assert length <= self.max_length, assert_message + position = torch.arange(length, device=device) + fixed_embedding = self.embedding(position) + fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size) + return fixed_embedding diff --git a/src/Modules/diffusion/sampler.py b/src/Modules/diffusion/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..8f16d0ffef33c690ed4927cc71c063871c474a93 --- /dev/null +++ b/src/Modules/diffusion/sampler.py @@ -0,0 +1,685 @@ +from math import atan, cos, pi, sin, sqrt +from typing import Any, Callable, List, Optional, Tuple, Type + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, reduce +from torch import Tensor + +from .utils import * + +""" +Diffusion Training +""" + +""" Distributions """ + + +class Distribution: + def __call__(self, num_samples: int, device: torch.device): + raise NotImplementedError() + + +class LogNormalDistribution(Distribution): + def __init__(self, mean: float, std: float): + self.mean = mean + self.std = std + + def __call__( + self, num_samples: int, device: torch.device = torch.device("cpu") + ) -> Tensor: + normal = self.mean + self.std * torch.randn((num_samples,), device=device) + return normal.exp() + + +class UniformDistribution(Distribution): + def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")): + return torch.rand(num_samples, device=device) + + +class VKDistribution(Distribution): + def __init__( + self, + min_value: float = 0.0, + max_value: float = float("inf"), + sigma_data: float = 1.0, + ): + self.min_value = min_value + self.max_value = max_value + self.sigma_data = sigma_data + + def __call__( + self, num_samples: int, device: torch.device = torch.device("cpu") + ) -> Tensor: + sigma_data = self.sigma_data + min_cdf = atan(self.min_value / sigma_data) * 2 / pi + max_cdf = atan(self.max_value / sigma_data) * 2 / pi + u = (max_cdf - min_cdf) * torch.randn((num_samples,), device=device) + min_cdf + return torch.tan(u * pi / 2) * sigma_data + + +""" Diffusion Classes """ + + +def pad_dims(x: Tensor, ndim: int) -> Tensor: + # Pads additional ndims to the right of the tensor + return x.view(*x.shape, *((1,) * ndim)) + + +def clip(x: Tensor, dynamic_threshold: float = 0.0): + if dynamic_threshold == 0.0: + return x.clamp(-1.0, 1.0) + else: + # Dynamic thresholding + # Find dynamic threshold quantile for each batch + x_flat = rearrange(x, "b ... 
-> b (...)") + scale = torch.quantile(x_flat.abs(), dynamic_threshold, dim=-1) + # Clamp to a min of 1.0 + scale.clamp_(min=1.0) + # Clamp all values and scale + scale = pad_dims(scale, ndim=x.ndim - scale.ndim) + x = x.clamp(-scale, scale) / scale + return x + + +def to_batch( + batch_size: int, + device: torch.device, + x: Optional[float] = None, + xs: Optional[Tensor] = None, +) -> Tensor: + assert exists(x) ^ exists(xs), "Either x or xs must be provided" + # If x provided use the same for all batch items + if exists(x): + xs = torch.full(size=(batch_size,), fill_value=x).to(device) + assert exists(xs) + return xs + + +class Diffusion(nn.Module): + alias: str = "" + + """Base diffusion class""" + + def denoise_fn( + self, + x_noisy: Tensor, + sigmas: Optional[Tensor] = None, + sigma: Optional[float] = None, + **kwargs, + ) -> Tensor: + raise NotImplementedError("Diffusion class missing denoise_fn") + + def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor: + raise NotImplementedError("Diffusion class missing forward function") + + +class VDiffusion(Diffusion): + alias = "v" + + def __init__(self, net: nn.Module, *, sigma_distribution: Distribution): + super().__init__() + self.net = net + self.sigma_distribution = sigma_distribution + + def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]: + angle = sigmas * pi / 2 + alpha = torch.cos(angle) + beta = torch.sin(angle) + return alpha, beta + + def denoise_fn( + self, + x_noisy: Tensor, + sigmas: Optional[Tensor] = None, + sigma: Optional[float] = None, + **kwargs, + ) -> Tensor: + batch_size, device = x_noisy.shape[0], x_noisy.device + sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device) + return self.net(x_noisy, sigmas, **kwargs) + + def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor: + batch_size, device = x.shape[0], x.device + + # Sample amount of noise to add for each batch element + sigmas = self.sigma_distribution(num_samples=batch_size, device=device) + sigmas_padded = rearrange(sigmas, "b -> b 1 1") + + # Get noise + noise = default(noise, lambda: torch.randn_like(x)) + + # Combine input and noise weighted by half-circle + alpha, beta = self.get_alpha_beta(sigmas_padded) + x_noisy = x * alpha + noise * beta + x_target = noise * alpha - x * beta + + # Denoise and return loss + x_denoised = self.denoise_fn(x_noisy, sigmas, **kwargs) + return F.mse_loss(x_denoised, x_target) + + +class KDiffusion(Diffusion): + """Elucidated Diffusion (Karras et al. 
2022): https://arxiv.org/abs/2206.00364""" + + alias = "k" + + def __init__( + self, + net: nn.Module, + *, + sigma_distribution: Distribution, + sigma_data: float, # data distribution standard deviation + dynamic_threshold: float = 0.0, + ): + super().__init__() + self.net = net + self.sigma_data = sigma_data + self.sigma_distribution = sigma_distribution + self.dynamic_threshold = dynamic_threshold + + def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]: + sigma_data = self.sigma_data + c_noise = torch.log(sigmas) * 0.25 + sigmas = rearrange(sigmas, "b -> b 1 1") + c_skip = (sigma_data**2) / (sigmas**2 + sigma_data**2) + c_out = sigmas * sigma_data * (sigma_data**2 + sigmas**2) ** -0.5 + c_in = (sigmas**2 + sigma_data**2) ** -0.5 + return c_skip, c_out, c_in, c_noise + + def denoise_fn( + self, + x_noisy: Tensor, + sigmas: Optional[Tensor] = None, + sigma: Optional[float] = None, + **kwargs, + ) -> Tensor: + batch_size, device = x_noisy.shape[0], x_noisy.device + sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device) + + # Predict network output and add skip connection + c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas) + x_pred = self.net(c_in * x_noisy, c_noise, **kwargs) + x_denoised = c_skip * x_noisy + c_out * x_pred + + return x_denoised + + def loss_weight(self, sigmas: Tensor) -> Tensor: + # Computes weight depending on data distribution + return (sigmas**2 + self.sigma_data**2) * (sigmas * self.sigma_data) ** -2 + + def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor: + batch_size, device = x.shape[0], x.device + from einops import rearrange, reduce + + # Sample amount of noise to add for each batch element + sigmas = self.sigma_distribution(num_samples=batch_size, device=device) + sigmas_padded = rearrange(sigmas, "b -> b 1 1") + + # Add noise to input + noise = default(noise, lambda: torch.randn_like(x)) + x_noisy = x + sigmas_padded * noise + + # Compute denoised values + x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs) + + # Compute weighted loss + losses = F.mse_loss(x_denoised, x, reduction="none") + losses = reduce(losses, "b ... 
-> b", "mean") + losses = losses * self.loss_weight(sigmas) + loss = losses.mean() + return loss + + +class VKDiffusion(Diffusion): + alias = "vk" + + def __init__(self, net: nn.Module, *, sigma_distribution: Distribution): + super().__init__() + self.net = net + self.sigma_distribution = sigma_distribution + + def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]: + sigma_data = 1.0 + sigmas = rearrange(sigmas, "b -> b 1 1") + c_skip = (sigma_data**2) / (sigmas**2 + sigma_data**2) + c_out = -sigmas * sigma_data * (sigma_data**2 + sigmas**2) ** -0.5 + c_in = (sigmas**2 + sigma_data**2) ** -0.5 + return c_skip, c_out, c_in + + def sigma_to_t(self, sigmas: Tensor) -> Tensor: + return sigmas.atan() / pi * 2 + + def t_to_sigma(self, t: Tensor) -> Tensor: + return (t * pi / 2).tan() + + def denoise_fn( + self, + x_noisy: Tensor, + sigmas: Optional[Tensor] = None, + sigma: Optional[float] = None, + **kwargs, + ) -> Tensor: + batch_size, device = x_noisy.shape[0], x_noisy.device + sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device) + + # Predict network output and add skip connection + c_skip, c_out, c_in = self.get_scale_weights(sigmas) + x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs) + x_denoised = c_skip * x_noisy + c_out * x_pred + return x_denoised + + def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor: + batch_size, device = x.shape[0], x.device + + # Sample amount of noise to add for each batch element + sigmas = self.sigma_distribution(num_samples=batch_size, device=device) + sigmas_padded = rearrange(sigmas, "b -> b 1 1") + + # Add noise to input + noise = default(noise, lambda: torch.randn_like(x)) + x_noisy = x + sigmas_padded * noise + + # Compute model output + c_skip, c_out, c_in = self.get_scale_weights(sigmas) + x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs) + + # Compute v-objective target + v_target = (x - c_skip * x_noisy) / (c_out + 1e-7) + + # Compute loss + loss = F.mse_loss(x_pred, v_target) + return loss + + +""" +Diffusion Sampling +""" + +""" Schedules """ + + +class Schedule(nn.Module): + """Interface used by different sampling schedules""" + + def forward(self, num_steps: int, device: torch.device) -> Tensor: + raise NotImplementedError() + + +class LinearSchedule(Schedule): + def forward(self, num_steps: int, device: Any) -> Tensor: + sigmas = torch.linspace(1, 0, num_steps + 1)[:-1] + return sigmas + + +class KarrasSchedule(Schedule): + """https://arxiv.org/abs/2206.00364 equation 5""" + + def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0): + super().__init__() + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.rho = rho + + def forward(self, num_steps: int, device: Any) -> Tensor: + rho_inv = 1.0 / self.rho + steps = torch.arange(num_steps, device=device, dtype=torch.float32) + sigmas = ( + self.sigma_max**rho_inv + + (steps / (num_steps - 1)) + * (self.sigma_min**rho_inv - self.sigma_max**rho_inv) + ) ** self.rho + sigmas = F.pad(sigmas, pad=(0, 1), value=0.0) + return sigmas + + +""" Samplers """ + + +class Sampler(nn.Module): + diffusion_types: List[Type[Diffusion]] = [] + + def forward( + self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int + ) -> Tensor: + raise NotImplementedError() + + def inpaint( + self, + source: Tensor, + mask: Tensor, + fn: Callable, + sigmas: Tensor, + num_steps: int, + num_resamples: int, + ) -> Tensor: + raise NotImplementedError("Inpainting not available with current sampler") + + +class 
VSampler(Sampler): + diffusion_types = [VDiffusion] + + def get_alpha_beta(self, sigma: float) -> Tuple[float, float]: + angle = sigma * pi / 2 + alpha = cos(angle) + beta = sin(angle) + return alpha, beta + + def forward( + self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int + ) -> Tensor: + x = sigmas[0] * noise + alpha, beta = self.get_alpha_beta(sigmas[0].item()) + + for i in range(num_steps - 1): + is_last = i == num_steps - 1 + + x_denoised = fn(x, sigma=sigmas[i]) + x_pred = x * alpha - x_denoised * beta + x_eps = x * beta + x_denoised * alpha + + if not is_last: + alpha, beta = self.get_alpha_beta(sigmas[i + 1].item()) + x = x_pred * alpha + x_eps * beta + + return x_pred + + +class KarrasSampler(Sampler): + """https://arxiv.org/abs/2206.00364 algorithm 1""" + + diffusion_types = [KDiffusion, VKDiffusion] + + def __init__( + self, + s_tmin: float = 0, + s_tmax: float = float("inf"), + s_churn: float = 0.0, + s_noise: float = 1.0, + ): + super().__init__() + self.s_tmin = s_tmin + self.s_tmax = s_tmax + self.s_noise = s_noise + self.s_churn = s_churn + + def step( + self, x: Tensor, fn: Callable, sigma: float, sigma_next: float, gamma: float + ) -> Tensor: + """Algorithm 2 (step)""" + # Select temporarily increased noise level + sigma_hat = sigma + gamma * sigma + # Add noise to move from sigma to sigma_hat + epsilon = self.s_noise * torch.randn_like(x) + x_hat = x + sqrt(sigma_hat**2 - sigma**2) * epsilon + # Evaluate ∂x/∂sigma at sigma_hat + d = (x_hat - fn(x_hat, sigma=sigma_hat)) / sigma_hat + # Take euler step from sigma_hat to sigma_next + x_next = x_hat + (sigma_next - sigma_hat) * d + # Second order correction + if sigma_next != 0: + model_out_next = fn(x_next, sigma=sigma_next) + d_prime = (x_next - model_out_next) / sigma_next + x_next = x_hat + 0.5 * (sigma - sigma_hat) * (d + d_prime) + return x_next + + def forward( + self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int + ) -> Tensor: + x = sigmas[0] * noise + # Compute gammas + gammas = torch.where( + (sigmas >= self.s_tmin) & (sigmas <= self.s_tmax), + min(self.s_churn / num_steps, sqrt(2) - 1), + 0.0, + ) + # Denoise to sample + for i in range(num_steps - 1): + x = self.step( + x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1], gamma=gammas[i] # type: ignore # noqa + ) + + return x + + +class AEulerSampler(Sampler): + diffusion_types = [KDiffusion, VKDiffusion] + + def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float]: + sigma_up = sqrt(sigma_next**2 * (sigma**2 - sigma_next**2) / sigma**2) + sigma_down = sqrt(sigma_next**2 - sigma_up**2) + return sigma_up, sigma_down + + def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor: + # Sigma steps + sigma_up, sigma_down = self.get_sigmas(sigma, sigma_next) + # Derivative at sigma (∂x/∂sigma) + d = (x - fn(x, sigma=sigma)) / sigma + # Euler method + x_next = x + d * (sigma_down - sigma) + # Add randomness + x_next = x_next + torch.randn_like(x) * sigma_up + return x_next + + def forward( + self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int + ) -> Tensor: + x = sigmas[0] * noise + # Denoise to sample + for i in range(num_steps - 1): + x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa + return x + + +class ADPM2Sampler(Sampler): + """https://www.desmos.com/calculator/jbxjlqd9mb""" + + diffusion_types = [KDiffusion, VKDiffusion] + + def __init__(self, rho: float = 1.0): + super().__init__() + self.rho = rho + + def get_sigmas(self, sigma: 
float, sigma_next: float) -> Tuple[float, float, float]: + r = self.rho + sigma_up = sqrt(sigma_next**2 * (sigma**2 - sigma_next**2) / sigma**2) + sigma_down = sqrt(sigma_next**2 - sigma_up**2) + sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r + return sigma_up, sigma_down, sigma_mid + + def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor: + # Sigma steps + sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next) + # Derivative at sigma (∂x/∂sigma) + d = (x - fn(x, sigma=sigma)) / sigma + # Denoise to midpoint + x_mid = x + d * (sigma_mid - sigma) + # Derivative at sigma_mid (∂x_mid/∂sigma_mid) + d_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid + # Denoise to next + x = x + d_mid * (sigma_down - sigma) + # Add randomness + x_next = x + torch.randn_like(x) * sigma_up + return x_next + + def forward( + self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int + ) -> Tensor: + x = sigmas[0] * noise + # Denoise to sample + for i in range(num_steps - 1): + x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa + return x + + def inpaint( + self, + source: Tensor, + mask: Tensor, + fn: Callable, + sigmas: Tensor, + num_steps: int, + num_resamples: int, + ) -> Tensor: + x = sigmas[0] * torch.randn_like(source) + + for i in range(num_steps - 1): + # Noise source to current noise level + source_noisy = source + sigmas[i] * torch.randn_like(source) + for r in range(num_resamples): + # Merge noisy source and current then denoise + x = source_noisy * mask + x * ~mask + x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa + # Renoise if not last resample step + if r < num_resamples - 1: + sigma = sqrt(sigmas[i] ** 2 - sigmas[i + 1] ** 2) + x = x + sigma * torch.randn_like(x) + + return source * mask + x * ~mask + + +""" Main Classes """ + + +class DiffusionSampler(nn.Module): + def __init__( + self, + diffusion: Diffusion, + *, + sampler: Sampler, + sigma_schedule: Schedule, + num_steps: Optional[int] = None, + clamp: bool = True, + ): + super().__init__() + self.denoise_fn = diffusion.denoise_fn + self.sampler = sampler + self.sigma_schedule = sigma_schedule + self.num_steps = num_steps + self.clamp = clamp + + # Check sampler is compatible with diffusion type + sampler_class = sampler.__class__.__name__ + diffusion_class = diffusion.__class__.__name__ + message = f"{sampler_class} incompatible with {diffusion_class}" + assert diffusion.alias in [t.alias for t in sampler.diffusion_types], message + + def forward( + self, noise: Tensor, num_steps: Optional[int] = None, **kwargs + ) -> Tensor: + device = noise.device + num_steps = default(num_steps, self.num_steps) # type: ignore + assert exists(num_steps), "Parameter `num_steps` must be provided" + # Compute sigmas using schedule + sigmas = self.sigma_schedule(num_steps, device) + # Append additional kwargs to denoise function (used e.g. 
for conditional unet) + fn = lambda *a, **ka: self.denoise_fn(*a, **{**ka, **kwargs}) # noqa + # Sample using sampler + x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps) + x = x.clamp(-1.0, 1.0) if self.clamp else x + return x + + +class DiffusionInpainter(nn.Module): + def __init__( + self, + diffusion: Diffusion, + *, + num_steps: int, + num_resamples: int, + sampler: Sampler, + sigma_schedule: Schedule, + ): + super().__init__() + self.denoise_fn = diffusion.denoise_fn + self.num_steps = num_steps + self.num_resamples = num_resamples + self.inpaint_fn = sampler.inpaint + self.sigma_schedule = sigma_schedule + + @torch.no_grad() + def forward(self, inpaint: Tensor, inpaint_mask: Tensor) -> Tensor: + x = self.inpaint_fn( + source=inpaint, + mask=inpaint_mask, + fn=self.denoise_fn, + sigmas=self.sigma_schedule(self.num_steps, inpaint.device), + num_steps=self.num_steps, + num_resamples=self.num_resamples, + ) + return x + + +def sequential_mask(like: Tensor, start: int) -> Tensor: + length, device = like.shape[2], like.device + mask = torch.ones_like(like, dtype=torch.bool) + mask[:, :, start:] = torch.zeros((length - start,), device=device) + return mask + + +class SpanBySpanComposer(nn.Module): + def __init__( + self, + inpainter: DiffusionInpainter, + *, + num_spans: int, + ): + super().__init__() + self.inpainter = inpainter + self.num_spans = num_spans + + def forward(self, start: Tensor, keep_start: bool = False) -> Tensor: + half_length = start.shape[2] // 2 + + spans = list(start.chunk(chunks=2, dim=-1)) if keep_start else [] + # Inpaint second half from first half + inpaint = torch.zeros_like(start) + inpaint[:, :, :half_length] = start[:, :, half_length:] + inpaint_mask = sequential_mask(like=start, start=half_length) + + for i in range(self.num_spans): + # Inpaint second half + span = self.inpainter(inpaint=inpaint, inpaint_mask=inpaint_mask) + # Replace first half with generated second half + second_half = span[:, :, half_length:] + inpaint[:, :, :half_length] = second_half + # Save generated span + spans.append(second_half) + + return torch.cat(spans, dim=2) + + +class XDiffusion(nn.Module): + def __init__(self, type: str, net: nn.Module, **kwargs): + super().__init__() + + diffusion_classes = [VDiffusion, KDiffusion, VKDiffusion] + aliases = [t.alias for t in diffusion_classes] # type: ignore + message = f"type='{type}' must be one of {*aliases,}" + assert type in aliases, message + self.net = net + + for XDiffusion in diffusion_classes: + if XDiffusion.alias == type: # type: ignore + self.diffusion = XDiffusion(net=net, **kwargs) + + def forward(self, *args, **kwargs) -> Tensor: + return self.diffusion(*args, **kwargs) + + def sample( + self, + noise: Tensor, + num_steps: int, + sigma_schedule: Schedule, + sampler: Sampler, + clamp: bool, + **kwargs, + ) -> Tensor: + diffusion_sampler = DiffusionSampler( + diffusion=self.diffusion, + sampler=sampler, + sigma_schedule=sigma_schedule, + num_steps=num_steps, + clamp=clamp, + ) + return diffusion_sampler(noise, **kwargs) diff --git a/src/Modules/diffusion/utils.py b/src/Modules/diffusion/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a25cb57e89f97c78bd7295963169b3f36671bede --- /dev/null +++ b/src/Modules/diffusion/utils.py @@ -0,0 +1,83 @@ +from functools import reduce +from inspect import isfunction +from math import ceil, floor, log2, pi +from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union + +import torch +import torch.nn.functional as F +from 
einops import rearrange +from torch import Generator, Tensor +from typing_extensions import TypeGuard + +T = TypeVar("T") + + +def exists(val: Optional[T]) -> TypeGuard[T]: + return val is not None + + +def iff(condition: bool, value: T) -> Optional[T]: + return value if condition else None + + +def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]: + return isinstance(obj, list) or isinstance(obj, tuple) + + +def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T: + if exists(val): + return val + return d() if isfunction(d) else d + + +def to_list(val: Union[T, Sequence[T]]) -> List[T]: + if isinstance(val, tuple): + return list(val) + if isinstance(val, list): + return val + return [val] # type: ignore + + +def prod(vals: Sequence[int]) -> int: + return reduce(lambda x, y: x * y, vals) + + +def closest_power_2(x: float) -> int: + exponent = log2(x) + distance_fn = lambda z: abs(x - 2**z) # noqa + exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn) + return 2 ** int(exponent_closest) + + +def rand_bool(shape, proba, device=None): + if proba == 1: + return torch.ones(shape, device=device, dtype=torch.bool) + elif proba == 0: + return torch.zeros(shape, device=device, dtype=torch.bool) + else: + return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool) + + +""" +Kwargs Utils +""" + + +def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]: + return_dicts: Tuple[Dict, Dict] = ({}, {}) + for key in d.keys(): + no_prefix = int(not key.startswith(prefix)) + return_dicts[no_prefix][key] = d[key] + return return_dicts + + +def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]: + kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d) + if keep_prefix: + return kwargs_with_prefix, kwargs + kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()} + return kwargs_no_prefix, kwargs + + +def prefix_dict(prefix: str, d: Dict) -> Dict: + return {prefix + str(k): v for k, v in d.items()} diff --git a/src/Modules/discriminators.py b/src/Modules/discriminators.py new file mode 100644 index 0000000000000000000000000000000000000000..f7d428117dfa8f02f59cfaecabf1d525c8ecb49a --- /dev/null +++ b/src/Modules/discriminators.py @@ -0,0 +1,267 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, spectral_norm + +from .utils import get_padding + +LRELU_SLOPE = 0.1 + + +def stft(x, fft_size, hop_size, win_length, window): + """Perform STFT and convert to magnitude spectrogram. + Args: + x (Tensor): Input signal tensor (B, T). + fft_size (int): FFT size. + hop_size (int): Hop size. + win_length (int): Window length. + window (str): Window function type. + Returns: + Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). 
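+        Note: torch.stft is called with return_complex=True, so the magnitude is
+        taken directly as torch.abs(x_stft); the intermediate real/imag slices
+        below are not used in the returned value.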
+ """ + x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True) + real = x_stft[..., 0] + imag = x_stft[..., 1] + + return torch.abs(x_stft).transpose(2, 1) + + +class SpecDiscriminator(nn.Module): + """docstring for Discriminator.""" + + def __init__( + self, + fft_size=1024, + shift_size=120, + win_length=600, + window="hann_window", + use_spectral_norm=False, + ): + super(SpecDiscriminator, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.fft_size = fft_size + self.shift_size = shift_size + self.win_length = win_length + self.window = getattr(torch, window)(win_length) + self.discriminators = nn.ModuleList( + [ + norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))), + norm_f( + nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4)) + ), + norm_f( + nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4)) + ), + norm_f( + nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4)) + ), + norm_f( + nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + ), + ] + ) + + self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1)) + + def forward(self, y): + fmap = [] + y = y.squeeze(1) + y = stft( + y, + self.fft_size, + self.shift_size, + self.win_length, + self.window.to(y.get_device()), + ) + y = y.unsqueeze(1) + for i, d in enumerate(self.discriminators): + y = d(y) + y = F.leaky_relu(y, LRELU_SLOPE) + fmap.append(y) + + y = self.out(y) + fmap.append(y) + + return torch.flatten(y, 1, -1), fmap + + +class MultiResSpecDiscriminator(torch.nn.Module): + def __init__( + self, + fft_sizes=[1024, 2048, 512], + hop_sizes=[120, 240, 50], + win_lengths=[600, 1200, 240], + window="hann_window", + ): + super(MultiResSpecDiscriminator, self).__init__() + self.discriminators = nn.ModuleList( + [ + SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window), + SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window), + SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window), + ] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f( + Conv2d( + 1, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + 32, + 128, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + 128, + 512, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + 512, + 1024, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ] + ) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + 
x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiPeriodDiscriminator, self).__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorP(2), + DiscriminatorP(3), + DiscriminatorP(5), + DiscriminatorP(7), + DiscriminatorP(11), + ] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class WavLMDiscriminator(nn.Module): + """docstring for Discriminator.""" + + def __init__( + self, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False + ): + super(WavLMDiscriminator, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.pre = norm_f( + Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0) + ) + + self.convs = nn.ModuleList( + [ + norm_f( + nn.Conv1d( + initial_channel, initial_channel * 2, kernel_size=5, padding=2 + ) + ), + norm_f( + nn.Conv1d( + initial_channel * 2, + initial_channel * 4, + kernel_size=5, + padding=2, + ) + ), + norm_f( + nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2) + ), + ] + ) + + self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1)) + + def forward(self, x): + x = self.pre(x) + + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + x = torch.flatten(x, 1, -1) + + return x diff --git a/src/Modules/hifigan.py b/src/Modules/hifigan.py new file mode 100644 index 0000000000000000000000000000000000000000..9162e3bf5e6723b77978ce901565e29f07f75245 --- /dev/null +++ b/src/Modules/hifigan.py @@ -0,0 +1,643 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +from .utils import init_weights, get_padding + +import math +import random +import numpy as np + +LRELU_SLOPE = 0.1 + + +class AdaIN1d(nn.Module): + def __init__(self, style_dim, num_features): + super().__init__() + self.norm = nn.InstanceNorm1d(num_features, affine=False) + self.fc = nn.Linear(style_dim, num_features * 2) + + def forward(self, x, s): + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + return (1 + gamma) * self.norm(x) + beta + + +class AdaINResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64): + super(AdaINResBlock1, self).__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + 
padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + self.adain1 = nn.ModuleList( + [ + AdaIN1d(style_dim, channels), + AdaIN1d(style_dim, channels), + AdaIN1d(style_dim, channels), + ] + ) + + self.adain2 = nn.ModuleList( + [ + AdaIN1d(style_dim, channels), + AdaIN1d(style_dim, channels), + AdaIN1d(style_dim, channels), + ] + ) + + self.alpha1 = nn.ParameterList( + [nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))] + ) + self.alpha2 = nn.ParameterList( + [nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))] + ) + + def forward(self, x, s): + for c1, c2, n1, n2, a1, a2 in zip( + self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2 + ): + xt = n1(x, s) + xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D + xt = c1(xt) + xt = n2(xt, s) + xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class SineGen(torch.nn.Module): + """Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__( + self, + samp_rate, + upsample_scale, + harmonic_num=0, + sine_amp=0.1, + noise_std=0.003, + voiced_threshold=0, + flag_for_pulse=False, + ): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.dim = self.harmonic_num + 1 + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + self.flag_for_pulse = flag_for_pulse + self.upsample_scale = upsample_scale + + def _f02uv(self, f0): + # generate uv signal + uv = (f0 > self.voiced_threshold).type(torch.float32) + return uv + + def _f02sine(self, f0_values): + """f0_values: (batchsize, length, dim) + where dim indicates fundamental tone and overtones + """ + # convert to F0 in rad. The interger part n can be ignored + # because 2 * np.pi * n doesn't affect phase + rad_values = (f0_values / self.sampling_rate) % 1 + + # initial phase noise (no noise for fundamental component) + rand_ini = torch.rand( + f0_values.shape[0], f0_values.shape[2], device=f0_values.device + ) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + + # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad) + if not self.flag_for_pulse: + # # for normal case + + # # To prevent torch.cumsum numerical overflow, + # # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1. + # # Buffer tmp_over_one_idx indicates the time step to add -1. 
+ # # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi + # tmp_over_one = torch.cumsum(rad_values, 1) % 1 + # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0 + # cumsum_shift = torch.zeros_like(rad_values) + # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + + # phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi + rad_values = torch.nn.functional.interpolate( + rad_values.transpose(1, 2), + scale_factor=1 / self.upsample_scale, + mode="linear", + ).transpose(1, 2) + + # tmp_over_one = torch.cumsum(rad_values, 1) % 1 + # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0 + # cumsum_shift = torch.zeros_like(rad_values) + # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + + phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi + phase = torch.nn.functional.interpolate( + phase.transpose(1, 2) * self.upsample_scale, + scale_factor=self.upsample_scale, + mode="linear", + ).transpose(1, 2) + sines = torch.sin(phase) + + else: + # If necessary, make sure that the first time step of every + # voiced segments is sin(pi) or cos(0) + # This is used for pulse-train generation + + # identify the last time step in unvoiced segments + uv = self._f02uv(f0_values) + uv_1 = torch.roll(uv, shifts=-1, dims=1) + uv_1[:, -1, :] = 1 + u_loc = (uv < 1) * (uv_1 > 0) + + # get the instantanouse phase + tmp_cumsum = torch.cumsum(rad_values, dim=1) + # different batch needs to be processed differently + for idx in range(f0_values.shape[0]): + temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :] + temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :] + # stores the accumulation of i.phase within + # each voiced segments + tmp_cumsum[idx, :, :] = 0 + tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum + + # rad_values - tmp_cumsum: remove the accumulation of i.phase + # within the previous voiced segment. + i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1) + + # get the sines + sines = torch.cos(i_phase * 2 * np.pi) + return sines + + def forward(self, f0): + """sine_tensor, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output sine_tensor: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + """ + f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) + # fundamental component + fn = torch.multiply( + f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device) + ) + + # generate sine waveforms + sine_waves = self._f02sine(fn) * self.sine_amp + + # generate uv signal + # uv = torch.ones(f0.shape) + # uv = uv * (f0 > self.voiced_threshold) + uv = self._f02uv(f0) + + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # . 
for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class SourceModuleHnNSF(torch.nn.Module): + """SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__( + self, + sampling_rate, + upsample_scale, + harmonic_num=0, + sine_amp=0.1, + add_noise_std=0.003, + voiced_threshod=0, + ): + super(SourceModuleHnNSF, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + + # to produce sine waveforms + self.l_sin_gen = SineGen( + sampling_rate, + upsample_scale, + harmonic_num, + sine_amp, + add_noise_std, + voiced_threshod, + ) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + + def forward(self, x): + """ + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + """ + # source for harmonic branch + with torch.no_grad(): + sine_wavs, uv, _ = self.l_sin_gen(x) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + + # source for noise branch, in the same shape as uv + noise = torch.randn_like(uv) * self.sine_amp / 3 + return sine_merge, noise, uv + + +def padDiff(x): + return F.pad( + F.pad(x, (0, 0, -1, 1), "constant", 0) - x, (0, 0, 0, -1), "constant", 0 + ) + + +class Generator(torch.nn.Module): + def __init__( + self, + style_dim, + resblock_kernel_sizes, + upsample_rates, + upsample_initial_channel, + resblock_dilation_sizes, + upsample_kernel_sizes, + ): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + resblock = AdaINResBlock1 + + self.m_source = SourceModuleHnNSF( + sampling_rate=24000, + upsample_scale=np.prod(upsample_rates), + harmonic_num=8, + voiced_threshod=10, + ) + + self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates)) + self.noise_convs = nn.ModuleList() + self.ups = nn.ModuleList() + self.noise_res = nn.ModuleList() + + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + c_cur = upsample_initial_channel // (2 ** (i + 1)) + + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(u // 2 + u % 2), + output_padding=u % 2, + ) + ) + ) + + if i + 1 < len(upsample_rates): # + stride_f0 = np.prod(upsample_rates[i + 1 :]) + self.noise_convs.append( + Conv1d( + 1, + c_cur, + kernel_size=stride_f0 * 2, + stride=stride_f0, + padding=(stride_f0 + 1) // 2, + ) + ) + self.noise_res.append(resblock(c_cur, 7, [1, 3, 5], style_dim)) + else: + self.noise_convs.append(Conv1d(1, 
c_cur, kernel_size=1)) + self.noise_res.append(resblock(c_cur, 11, [1, 3, 5], style_dim)) + + self.resblocks = nn.ModuleList() + + self.alphas = nn.ParameterList() + self.alphas.append(nn.Parameter(torch.ones(1, upsample_initial_channel, 1))) + + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + self.alphas.append(nn.Parameter(torch.ones(1, ch, 1))) + + for j, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes) + ): + self.resblocks.append(resblock(ch, k, d, style_dim)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x, s, f0): + f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + + har_source, noi_source, uv = self.m_source(f0) + har_source = har_source.transpose(1, 2) + + for i in range(self.num_upsamples): + x = x + (1 / self.alphas[i]) * (torch.sin(self.alphas[i] * x) ** 2) + x_source = self.noise_convs[i](har_source) + x_source = self.noise_res[i](x_source, s) + + x = self.ups[i](x) + x = x + x_source + + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x, s) + else: + xs += self.resblocks[i * self.num_kernels + j](x, s) + x = xs / self.num_kernels + x = x + (1 / self.alphas[i + 1]) * (torch.sin(self.alphas[i + 1] * x) ** 2) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class AdainResBlk1d(nn.Module): + def __init__( + self, + dim_in, + dim_out, + style_dim=64, + actv=nn.LeakyReLU(0.2), + upsample="none", + dropout_p=0.0, + ): + super().__init__() + self.actv = actv + self.upsample_type = upsample + self.upsample = UpSample1d(upsample) + self.learned_sc = dim_in != dim_out + self._build_weights(dim_in, dim_out, style_dim) + self.dropout = nn.Dropout(dropout_p) + + if upsample == "none": + self.pool = nn.Identity() + else: + self.pool = weight_norm( + nn.ConvTranspose1d( + dim_in, + dim_in, + kernel_size=3, + stride=2, + groups=dim_in, + padding=1, + output_padding=1, + ) + ) + + def _build_weights(self, dim_in, dim_out, style_dim): + self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1)) + self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1)) + self.norm1 = AdaIN1d(style_dim, dim_in) + self.norm2 = AdaIN1d(style_dim, dim_out) + if self.learned_sc: + self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False)) + + def _shortcut(self, x): + x = self.upsample(x) + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.norm1(x, s) + x = self.actv(x) + x = self.pool(x) + x = self.conv1(self.dropout(x)) + x = self.norm2(x, s) + x = self.actv(x) + x = self.conv2(self.dropout(x)) + return x + + def forward(self, x, s): + out = self._residual(x, s) + out = (out + self._shortcut(x)) / math.sqrt(2) + return out + + +class UpSample1d(nn.Module): + def __init__(self, layer_type): + super().__init__() + self.layer_type = layer_type + + def forward(self, x): + if self.layer_type == "none": + return x + else: + return F.interpolate(x, scale_factor=2, mode="nearest") + + +class Decoder(nn.Module): + def __init__( + self, + dim_in=512, + F0_channel=512, + style_dim=64, + dim_out=80, + resblock_kernel_sizes=[3, 7, 11], + upsample_rates=[10, 5, 3, 2], + 
upsample_initial_channel=512, + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_kernel_sizes=[20, 10, 6, 4], + ): + super().__init__() + + self.decode = nn.ModuleList() + + self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim) + + self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim)) + self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim)) + self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim)) + self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True)) + + self.F0_conv = weight_norm( + nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1) + ) + + self.N_conv = weight_norm( + nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1) + ) + + self.asr_res = nn.Sequential( + weight_norm(nn.Conv1d(512, 64, kernel_size=1)), + ) + + self.generator = Generator( + style_dim, + resblock_kernel_sizes, + upsample_rates, + upsample_initial_channel, + resblock_dilation_sizes, + upsample_kernel_sizes, + ) + + def forward(self, asr, F0_curve, N, s): + if self.training: + downlist = [0, 3, 7] + F0_down = downlist[random.randint(0, 2)] + downlist = [0, 3, 7, 15] + N_down = downlist[random.randint(0, 3)] + if F0_down: + F0_curve = ( + nn.functional.conv1d( + F0_curve.unsqueeze(1), + torch.ones(1, 1, F0_down).to("cuda"), + padding=F0_down // 2, + ).squeeze(1) + / F0_down + ) + if N_down: + N = ( + nn.functional.conv1d( + N.unsqueeze(1), + torch.ones(1, 1, N_down).to("cuda"), + padding=N_down // 2, + ).squeeze(1) + / N_down + ) + + F0 = self.F0_conv(F0_curve.unsqueeze(1)) + N = self.N_conv(N.unsqueeze(1)) + + x = torch.cat([asr, F0, N], axis=1) + x = self.encode(x, s) + + asr_res = self.asr_res(asr) + + res = True + for block in self.decode: + if res: + x = torch.cat([x, asr_res, F0, N], axis=1) + x = block(x, s) + if block.upsample_type != "none": + res = False + + x = self.generator(x, s, F0_curve) + return x diff --git a/src/Modules/istftnet.py b/src/Modules/istftnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c91550d7917dfd5e420e36d2e50389b67154e12d --- /dev/null +++ b/src/Modules/istftnet.py @@ -0,0 +1,720 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +from .utils import init_weights, get_padding + +import math +import random +import numpy as np +from scipy.signal import get_window + +LRELU_SLOPE = 0.1 + + +class AdaIN1d(nn.Module): + def __init__(self, style_dim, num_features): + super().__init__() + self.norm = nn.InstanceNorm1d(num_features, affine=False) + self.fc = nn.Linear(style_dim, num_features * 2) + + def forward(self, x, s): + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + return (1 + gamma) * self.norm(x) + beta + + +class AdaINResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64): + super(AdaINResBlock1, self).__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), 
+ ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + self.adain1 = nn.ModuleList( + [ + AdaIN1d(style_dim, channels), + AdaIN1d(style_dim, channels), + AdaIN1d(style_dim, channels), + ] + ) + + self.adain2 = nn.ModuleList( + [ + AdaIN1d(style_dim, channels), + AdaIN1d(style_dim, channels), + AdaIN1d(style_dim, channels), + ] + ) + + self.alpha1 = nn.ParameterList( + [nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))] + ) + self.alpha2 = nn.ParameterList( + [nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))] + ) + + def forward(self, x, s): + for c1, c2, n1, n2, a1, a2 in zip( + self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2 + ): + xt = n1(x, s) + xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D + xt = c1(xt) + xt = n2(xt, s) + xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class TorchSTFT(torch.nn.Module): + def __init__( + self, filter_length=800, hop_length=200, win_length=800, window="hann" + ): + super().__init__() + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length + self.window = torch.from_numpy( + get_window(window, win_length, fftbins=True).astype(np.float32) + ) + + def transform(self, input_data): + forward_transform = torch.stft( + input_data, + self.filter_length, + self.hop_length, + self.win_length, + window=self.window.to(input_data.device), + return_complex=True, + ) + + return torch.abs(forward_transform), torch.angle(forward_transform) + + def inverse(self, magnitude, phase): + inverse_transform = torch.istft( + magnitude * torch.exp(phase * 1j), + self.filter_length, + self.hop_length, + self.win_length, + window=self.window.to(magnitude.device), + ) + + return inverse_transform.unsqueeze( + -2 + ) # unsqueeze to stay consistent with conv_transpose1d implementation + + def forward(self, input_data): + self.magnitude, self.phase = self.transform(input_data) + reconstruction = self.inverse(self.magnitude, self.phase) + return reconstruction + + +class SineGen(torch.nn.Module): + """Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__( + self, + samp_rate, + upsample_scale, + harmonic_num=0, + sine_amp=0.1, + noise_std=0.003, + voiced_threshold=0, + flag_for_pulse=False, + ): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std 
= noise_std + self.harmonic_num = harmonic_num + self.dim = self.harmonic_num + 1 + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + self.flag_for_pulse = flag_for_pulse + self.upsample_scale = upsample_scale + + def _f02uv(self, f0): + # generate uv signal + uv = (f0 > self.voiced_threshold).type(torch.float32) + return uv + + def _f02sine(self, f0_values): + """f0_values: (batchsize, length, dim) + where dim indicates fundamental tone and overtones + """ + # convert to F0 in rad. The interger part n can be ignored + # because 2 * np.pi * n doesn't affect phase + rad_values = (f0_values / self.sampling_rate) % 1 + + # initial phase noise (no noise for fundamental component) + rand_ini = torch.rand( + f0_values.shape[0], f0_values.shape[2], device=f0_values.device + ) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + + # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad) + if not self.flag_for_pulse: + # # for normal case + + # # To prevent torch.cumsum numerical overflow, + # # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1. + # # Buffer tmp_over_one_idx indicates the time step to add -1. + # # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi + # tmp_over_one = torch.cumsum(rad_values, 1) % 1 + # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0 + # cumsum_shift = torch.zeros_like(rad_values) + # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + + # phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi + rad_values = torch.nn.functional.interpolate( + rad_values.transpose(1, 2), + scale_factor=1 / self.upsample_scale, + mode="linear", + ).transpose(1, 2) + + # tmp_over_one = torch.cumsum(rad_values, 1) % 1 + # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0 + # cumsum_shift = torch.zeros_like(rad_values) + # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + + phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi + phase = torch.nn.functional.interpolate( + phase.transpose(1, 2) * self.upsample_scale, + scale_factor=self.upsample_scale, + mode="linear", + ).transpose(1, 2) + sines = torch.sin(phase) + + else: + # If necessary, make sure that the first time step of every + # voiced segments is sin(pi) or cos(0) + # This is used for pulse-train generation + + # identify the last time step in unvoiced segments + uv = self._f02uv(f0_values) + uv_1 = torch.roll(uv, shifts=-1, dims=1) + uv_1[:, -1, :] = 1 + u_loc = (uv < 1) * (uv_1 > 0) + + # get the instantanouse phase + tmp_cumsum = torch.cumsum(rad_values, dim=1) + # different batch needs to be processed differently + for idx in range(f0_values.shape[0]): + temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :] + temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :] + # stores the accumulation of i.phase within + # each voiced segments + tmp_cumsum[idx, :, :] = 0 + tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum + + # rad_values - tmp_cumsum: remove the accumulation of i.phase + # within the previous voiced segment. 
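+            # this pulse-train branch is only reached when flag_for_pulse=True;
+            # SourceModuleHnNSF constructs SineGen with the default False, so the
+            # interpolation-based path above is the one used during synthesis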
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1) + + # get the sines + sines = torch.cos(i_phase * 2 * np.pi) + return sines + + def forward(self, f0): + """sine_tensor, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output sine_tensor: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + """ + f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) + # fundamental component + fn = torch.multiply( + f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device) + ) + + # generate sine waveforms + sine_waves = self._f02sine(fn) * self.sine_amp + + # generate uv signal + # uv = torch.ones(f0.shape) + # uv = uv * (f0 > self.voiced_threshold) + uv = self._f02uv(f0) + + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # . for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class SourceModuleHnNSF(torch.nn.Module): + """SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__( + self, + sampling_rate, + upsample_scale, + harmonic_num=0, + sine_amp=0.1, + add_noise_std=0.003, + voiced_threshod=0, + ): + super(SourceModuleHnNSF, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + + # to produce sine waveforms + self.l_sin_gen = SineGen( + sampling_rate, + upsample_scale, + harmonic_num, + sine_amp, + add_noise_std, + voiced_threshod, + ) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + + def forward(self, x): + """ + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + """ + # source for harmonic branch + with torch.no_grad(): + sine_wavs, uv, _ = self.l_sin_gen(x) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + + # source for noise branch, in the same shape as uv + noise = torch.randn_like(uv) * self.sine_amp / 3 + return sine_merge, noise, uv + + +def padDiff(x): + return F.pad( + F.pad(x, (0, 0, -1, 1), "constant", 0) - x, (0, 0, 0, -1), "constant", 0 + ) + + +class Generator(torch.nn.Module): + def __init__( + self, + style_dim, + resblock_kernel_sizes, + upsample_rates, + upsample_initial_channel, + resblock_dilation_sizes, + upsample_kernel_sizes, + gen_istft_n_fft, + gen_istft_hop_size, + ): + super(Generator, self).__init__() + + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + resblock = AdaINResBlock1 + + self.m_source = SourceModuleHnNSF( + 
sampling_rate=24000, + upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size, + harmonic_num=8, + voiced_threshod=10, + ) + self.f0_upsamp = torch.nn.Upsample( + scale_factor=np.prod(upsample_rates) * gen_istft_hop_size + ) + self.noise_convs = nn.ModuleList() + self.noise_res = nn.ModuleList() + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes) + ): + self.resblocks.append(resblock(ch, k, d, style_dim)) + + c_cur = upsample_initial_channel // (2 ** (i + 1)) + + if i + 1 < len(upsample_rates): # + stride_f0 = np.prod(upsample_rates[i + 1 :]) + self.noise_convs.append( + Conv1d( + gen_istft_n_fft + 2, + c_cur, + kernel_size=stride_f0 * 2, + stride=stride_f0, + padding=(stride_f0 + 1) // 2, + ) + ) + self.noise_res.append(resblock(c_cur, 7, [1, 3, 5], style_dim)) + else: + self.noise_convs.append( + Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1) + ) + self.noise_res.append(resblock(c_cur, 11, [1, 3, 5], style_dim)) + + self.post_n_fft = gen_istft_n_fft + self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + self.reflection_pad = torch.nn.ReflectionPad1d((1, 0)) + self.stft = TorchSTFT( + filter_length=gen_istft_n_fft, + hop_length=gen_istft_hop_size, + win_length=gen_istft_n_fft, + ) + + def forward(self, x, s, f0): + with torch.no_grad(): + f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + + har_source, noi_source, uv = self.m_source(f0) + har_source = har_source.transpose(1, 2).squeeze(1) + har_spec, har_phase = self.stft.transform(har_source) + har = torch.cat([har_spec, har_phase], dim=1) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x_source = self.noise_convs[i](har) + x_source = self.noise_res[i](x_source, s) + + x = self.ups[i](x) + if i == self.num_upsamples - 1: + x = self.reflection_pad(x) + + x = x + x_source + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x, s) + else: + xs += self.resblocks[i * self.num_kernels + j](x, s) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + spec = torch.exp(x[:, : self.post_n_fft // 2 + 1, :]) + phase = torch.sin(x[:, self.post_n_fft // 2 + 1 :, :]) + return self.stft.inverse(spec, phase) + + def fw_phase(self, x, s): + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x, s) + else: + xs += self.resblocks[i * self.num_kernels + j](x, s) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.reflection_pad(x) + x = self.conv_post(x) + spec = torch.exp(x[:, : self.post_n_fft // 2 + 1, :]) + phase = torch.sin(x[:, self.post_n_fft // 2 + 1 :, :]) + return spec, phase + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class 
AdainResBlk1d(nn.Module): + def __init__( + self, + dim_in, + dim_out, + style_dim=64, + actv=nn.LeakyReLU(0.2), + upsample="none", + dropout_p=0.0, + ): + super().__init__() + self.actv = actv + self.upsample_type = upsample + self.upsample = UpSample1d(upsample) + self.learned_sc = dim_in != dim_out + self._build_weights(dim_in, dim_out, style_dim) + self.dropout = nn.Dropout(dropout_p) + + if upsample == "none": + self.pool = nn.Identity() + else: + self.pool = weight_norm( + nn.ConvTranspose1d( + dim_in, + dim_in, + kernel_size=3, + stride=2, + groups=dim_in, + padding=1, + output_padding=1, + ) + ) + + def _build_weights(self, dim_in, dim_out, style_dim): + self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1)) + self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1)) + self.norm1 = AdaIN1d(style_dim, dim_in) + self.norm2 = AdaIN1d(style_dim, dim_out) + if self.learned_sc: + self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False)) + + def _shortcut(self, x): + x = self.upsample(x) + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.norm1(x, s) + x = self.actv(x) + x = self.pool(x) + x = self.conv1(self.dropout(x)) + x = self.norm2(x, s) + x = self.actv(x) + x = self.conv2(self.dropout(x)) + return x + + def forward(self, x, s): + out = self._residual(x, s) + out = (out + self._shortcut(x)) / math.sqrt(2) + return out + + +class UpSample1d(nn.Module): + def __init__(self, layer_type): + super().__init__() + self.layer_type = layer_type + + def forward(self, x): + if self.layer_type == "none": + return x + else: + return F.interpolate(x, scale_factor=2, mode="nearest") + + +class Decoder(nn.Module): + def __init__( + self, + dim_in=512, + F0_channel=512, + style_dim=64, + dim_out=80, + resblock_kernel_sizes=[3, 7, 11], + upsample_rates=[10, 6], + upsample_initial_channel=512, + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_kernel_sizes=[20, 12], + gen_istft_n_fft=20, + gen_istft_hop_size=5, + ): + super().__init__() + + self.decode = nn.ModuleList() + + self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim) + + self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim)) + self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim)) + self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim)) + self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True)) + + self.F0_conv = weight_norm( + nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1) + ) + + self.N_conv = weight_norm( + nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1) + ) + + self.asr_res = nn.Sequential( + weight_norm(nn.Conv1d(512, 64, kernel_size=1)), + ) + + self.generator = Generator( + style_dim, + resblock_kernel_sizes, + upsample_rates, + upsample_initial_channel, + resblock_dilation_sizes, + upsample_kernel_sizes, + gen_istft_n_fft, + gen_istft_hop_size, + ) + + def forward(self, asr, F0_curve, N, s): + if self.training: + downlist = [0, 3, 7] + F0_down = downlist[random.randint(0, 2)] + downlist = [0, 3, 7, 15] + N_down = downlist[random.randint(0, 3)] + if F0_down: + F0_curve = ( + nn.functional.conv1d( + F0_curve.unsqueeze(1), + torch.ones(1, 1, F0_down).to("cuda"), + padding=F0_down // 2, + ).squeeze(1) + / F0_down + ) + if N_down: + N = ( + nn.functional.conv1d( + N.unsqueeze(1), + torch.ones(1, 1, N_down).to("cuda"), + padding=N_down // 2, + ).squeeze(1) + / N_down + ) + + F0 = self.F0_conv(F0_curve.unsqueeze(1)) + N = self.N_conv(N.unsqueeze(1)) + + x = 
torch.cat([asr, F0, N], axis=1) + x = self.encode(x, s) + + asr_res = self.asr_res(asr) + + res = True + for block in self.decode: + if res: + x = torch.cat([x, asr_res, F0, N], axis=1) + x = block(x, s) + if block.upsample_type != "none": + res = False + + x = self.generator(x, s, F0_curve) + return x diff --git a/src/Modules/slmadv.py b/src/Modules/slmadv.py new file mode 100644 index 0000000000000000000000000000000000000000..4f60d67f1a0dd3f52e3d06336f2aed674ef20e8a --- /dev/null +++ b/src/Modules/slmadv.py @@ -0,0 +1,256 @@ +import torch +import numpy as np +import torch.nn.functional as F + + +class SLMAdversarialLoss(torch.nn.Module): + def __init__( + self, + model, + wl, + sampler, + min_len, + max_len, + batch_percentage=0.5, + skip_update=10, + sig=1.5, + ): + super(SLMAdversarialLoss, self).__init__() + self.model = model + self.wl = wl + self.sampler = sampler + + self.min_len = min_len + self.max_len = max_len + self.batch_percentage = batch_percentage + + self.sig = sig + self.skip_update = skip_update + + def forward( + self, + iters, + y_rec_gt, + y_rec_gt_pred, + waves, + mel_input_length, + ref_text, + ref_lengths, + use_ind, + s_trg, + ref_s=None, + ): + text_mask = length_to_mask(ref_lengths).to(ref_text.device) + bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int()) + d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2) + + if use_ind and np.random.rand() < 0.5: + s_preds = s_trg + else: + num_steps = np.random.randint(3, 5) + if ref_s is not None: + s_preds = self.sampler( + noise=torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device), + embedding=bert_dur, + embedding_scale=1, + features=ref_s, # reference from the same speaker as the embedding + embedding_mask_proba=0.1, + num_steps=num_steps, + ).squeeze(1) + else: + s_preds = self.sampler( + noise=torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device), + embedding=bert_dur, + embedding_scale=1, + embedding_mask_proba=0.1, + num_steps=num_steps, + ).squeeze(1) + + s_dur = s_preds[:, 128:] + s = s_preds[:, :128] + + d, _ = self.model.predictor( + d_en, + s_dur, + ref_lengths, + torch.randn(ref_lengths.shape[0], ref_lengths.max(), 2).to(ref_text.device), + text_mask, + ) + + bib = 0 + + output_lengths = [] + attn_preds = [] + + # differentiable duration modeling + for _s2s_pred, _text_length in zip(d, ref_lengths): + _s2s_pred_org = _s2s_pred[:_text_length, :] + + _s2s_pred = torch.sigmoid(_s2s_pred_org) + _dur_pred = _s2s_pred.sum(axis=-1) + + l = int(torch.round(_s2s_pred.sum()).item()) + t = torch.arange(0, l).expand(l) + + t = ( + torch.arange(0, l) + .unsqueeze(0) + .expand((len(_s2s_pred), l)) + .to(ref_text.device) + ) + loc = torch.cumsum(_dur_pred, dim=0) - _dur_pred / 2 + + h = torch.exp( + -0.5 * torch.square(t - (l - loc.unsqueeze(-1))) / (self.sig) ** 2 + ) + + out = torch.nn.functional.conv1d( + _s2s_pred_org.unsqueeze(0), + h.unsqueeze(1), + padding=h.shape[-1] - 1, + groups=int(_text_length), + )[..., :l] + attn_preds.append(F.softmax(out.squeeze(), dim=0)) + + output_lengths.append(l) + + max_len = max(output_lengths) + + with torch.no_grad(): + t_en = self.model.text_encoder(ref_text, ref_lengths, text_mask) + + s2s_attn = torch.zeros(len(ref_lengths), int(ref_lengths.max()), max_len).to( + ref_text.device + ) + for bib in range(len(output_lengths)): + s2s_attn[bib, : ref_lengths[bib], : output_lengths[bib]] = attn_preds[bib] + + asr_pred = t_en @ s2s_attn + + _, p_pred = self.model.predictor(d_en, s_dur, ref_lengths, s2s_attn, text_mask) + + mel_len = 
max(int(min(output_lengths) / 2 - 1), self.min_len // 2) + mel_len = min(mel_len, self.max_len // 2) + + # get clips + + en = [] + p_en = [] + sp = [] + + F0_fakes = [] + N_fakes = [] + + wav = [] + + for bib in range(len(output_lengths)): + mel_length_pred = output_lengths[bib] + mel_length_gt = int(mel_input_length[bib].item() / 2) + if mel_length_gt <= mel_len or mel_length_pred <= mel_len: + continue + + sp.append(s_preds[bib]) + + random_start = np.random.randint(0, mel_length_pred - mel_len) + en.append(asr_pred[bib, :, random_start : random_start + mel_len]) + p_en.append(p_pred[bib, :, random_start : random_start + mel_len]) + + # get ground truth clips + random_start = np.random.randint(0, mel_length_gt - mel_len) + y = waves[bib][ + (random_start * 2) * 300 : ((random_start + mel_len) * 2) * 300 + ] + wav.append(torch.from_numpy(y).to(ref_text.device)) + + if len(wav) >= self.batch_percentage * len( + waves + ): # prevent OOM due to longer lengths + break + + if len(sp) <= 1: + return None + + sp = torch.stack(sp) + wav = torch.stack(wav).float() + en = torch.stack(en) + p_en = torch.stack(p_en) + + F0_fake, N_fake = self.model.predictor.F0Ntrain(p_en, sp[:, 128:]) + y_pred = self.model.decoder(en, F0_fake, N_fake, sp[:, :128]) + + # discriminator loss + if (iters + 1) % self.skip_update == 0: + if np.random.randint(0, 2) == 0: + wav = y_rec_gt_pred + use_rec = True + else: + use_rec = False + + crop_size = min(wav.size(-1), y_pred.size(-1)) + if ( + use_rec + ): # use reconstructed (shorter lengths), do length invariant regularization + if wav.size(-1) > y_pred.size(-1): + real_GP = wav[:, :, :crop_size] + out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze()) + out_org = self.wl.discriminator_forward(wav.detach().squeeze()) + loss_reg = F.l1_loss(out_crop, out_org[..., : out_crop.size(-1)]) + + if np.random.randint(0, 2) == 0: + d_loss = self.wl.discriminator( + real_GP.detach().squeeze(), y_pred.detach().squeeze() + ).mean() + else: + d_loss = self.wl.discriminator( + wav.detach().squeeze(), y_pred.detach().squeeze() + ).mean() + else: + real_GP = y_pred[:, :, :crop_size] + out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze()) + out_org = self.wl.discriminator_forward(y_pred.detach().squeeze()) + loss_reg = F.l1_loss(out_crop, out_org[..., : out_crop.size(-1)]) + + if np.random.randint(0, 2) == 0: + d_loss = self.wl.discriminator( + wav.detach().squeeze(), real_GP.detach().squeeze() + ).mean() + else: + d_loss = self.wl.discriminator( + wav.detach().squeeze(), y_pred.detach().squeeze() + ).mean() + + # regularization (ignore length variation) + d_loss += loss_reg + + out_gt = self.wl.discriminator_forward(y_rec_gt.detach().squeeze()) + out_rec = self.wl.discriminator_forward( + y_rec_gt_pred.detach().squeeze() + ) + + # regularization (ignore reconstruction artifacts) + d_loss += F.l1_loss(out_gt, out_rec) + + else: + d_loss = self.wl.discriminator( + wav.detach().squeeze(), y_pred.detach().squeeze() + ).mean() + else: + d_loss = 0 + + # generator loss + gen_loss = self.wl.generator(y_pred.squeeze()) + + gen_loss = gen_loss.mean() + + return d_loss, gen_loss, y_pred.detach().cpu().numpy() + + +def length_to_mask(lengths): + mask = ( + torch.arange(lengths.max()) + .unsqueeze(0) + .expand(lengths.shape[0], -1) + .type_as(lengths) + ) + mask = torch.gt(mask + 1, lengths.unsqueeze(1)) + return mask diff --git a/src/Modules/utils.py b/src/Modules/utils.py new file mode 100644 index 
0000000000000000000000000000000000000000..c0e7b1e323404c29c292ce2f08e23ed212d46562 --- /dev/null +++ b/src/Modules/utils.py @@ -0,0 +1,14 @@ +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) diff --git a/src/Utils/ASR/__init__.py b/src/Utils/ASR/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/src/Utils/ASR/__init__.py @@ -0,0 +1 @@ + diff --git a/src/Utils/ASR/__pycache__/__init__.cpython-310.pyc b/src/Utils/ASR/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d204f43014ea73014850b7fa2effdb3ca9b0b36 Binary files /dev/null and b/src/Utils/ASR/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/Utils/ASR/__pycache__/layers.cpython-310.pyc b/src/Utils/ASR/__pycache__/layers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78cb2efe4dd0668c45e51cab133f6f2fb13246ff Binary files /dev/null and b/src/Utils/ASR/__pycache__/layers.cpython-310.pyc differ diff --git a/src/Utils/ASR/__pycache__/models.cpython-310.pyc b/src/Utils/ASR/__pycache__/models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4aeefaf2a6adf012a27b6840b953eef6c7bbb0d8 Binary files /dev/null and b/src/Utils/ASR/__pycache__/models.cpython-310.pyc differ diff --git a/src/Utils/ASR/config.yml b/src/Utils/ASR/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..ca334a10a12dd3ef12497eb12651fa5f508106dc --- /dev/null +++ b/src/Utils/ASR/config.yml @@ -0,0 +1,29 @@ +log_dir: "logs/20201006" +save_freq: 5 +device: "cuda" +epochs: 180 +batch_size: 64 +pretrained_model: "" +train_data: "ASRDataset/train_list.txt" +val_data: "ASRDataset/val_list.txt" + +dataset_params: + data_augmentation: false + +preprocess_parasm: + sr: 24000 + spect_params: + n_fft: 2048 + win_length: 1200 + hop_length: 300 + mel_params: + n_mels: 80 + +model_params: + input_dim: 80 + hidden_dim: 256 + n_token: 178 + token_embedding_dim: 512 + +optimizer_params: + lr: 0.0005 \ No newline at end of file diff --git a/src/Utils/ASR/epoch_00080.pth b/src/Utils/ASR/epoch_00080.pth new file mode 100644 index 0000000000000000000000000000000000000000..121895323ccef067212af46e08111eef37fd7a82 --- /dev/null +++ b/src/Utils/ASR/epoch_00080.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fedd55a1234b0c56e1e8b509c74edf3a5e2f27106a66038a4a946047a775bd6c +size 94552811 diff --git a/src/Utils/ASR/layers.py b/src/Utils/ASR/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..044e383d2ecf1251bd92e289db8c145a4f591b98 --- /dev/null +++ b/src/Utils/ASR/layers.py @@ -0,0 +1,455 @@ +import math +import torch +from torch import nn +from typing import Optional, Any +from torch import Tensor +import torch.nn.functional as F +import torchaudio +import torchaudio.functional as audio_F + +import random + +random.seed(0) + + +def _get_activation_fn(activ): + if activ == "relu": + return nn.ReLU() + elif activ == "lrelu": + return nn.LeakyReLU(0.2) + elif activ == "swish": + return lambda x: x * torch.sigmoid(x) + else: + raise RuntimeError( + "Unexpected activ type %s, expected [relu, 
lrelu, swish]" % activ + ) + + +class LinearNorm(torch.nn.Module): + def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"): + super(LinearNorm, self).__init__() + self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) + + torch.nn.init.xavier_uniform_( + self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain) + ) + + def forward(self, x): + return self.linear_layer(x) + + +class ConvNorm(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=None, + dilation=1, + bias=True, + w_init_gain="linear", + param=None, + ): + super(ConvNorm, self).__init__() + if padding is None: + assert kernel_size % 2 == 1 + padding = int(dilation * (kernel_size - 1) / 2) + + self.conv = torch.nn.Conv1d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) + + torch.nn.init.xavier_uniform_( + self.conv.weight, + gain=torch.nn.init.calculate_gain(w_init_gain, param=param), + ) + + def forward(self, signal): + conv_signal = self.conv(signal) + return conv_signal + + +class CausualConv(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=1, + dilation=1, + bias=True, + w_init_gain="linear", + param=None, + ): + super(CausualConv, self).__init__() + if padding is None: + assert kernel_size % 2 == 1 + padding = int(dilation * (kernel_size - 1) / 2) * 2 + else: + self.padding = padding * 2 + self.conv = nn.Conv1d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=self.padding, + dilation=dilation, + bias=bias, + ) + + torch.nn.init.xavier_uniform_( + self.conv.weight, + gain=torch.nn.init.calculate_gain(w_init_gain, param=param), + ) + + def forward(self, x): + x = self.conv(x) + x = x[:, :, : -self.padding] + return x + + +class CausualBlock(nn.Module): + def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="lrelu"): + super(CausualBlock, self).__init__() + self.blocks = nn.ModuleList( + [ + self._get_conv( + hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p + ) + for i in range(n_conv) + ] + ) + + def forward(self, x): + for block in self.blocks: + res = x + x = block(x) + x += res + return x + + def _get_conv(self, hidden_dim, dilation, activ="lrelu", dropout_p=0.2): + layers = [ + CausualConv( + hidden_dim, + hidden_dim, + kernel_size=3, + padding=dilation, + dilation=dilation, + ), + _get_activation_fn(activ), + nn.BatchNorm1d(hidden_dim), + nn.Dropout(p=dropout_p), + CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1), + _get_activation_fn(activ), + nn.Dropout(p=dropout_p), + ] + return nn.Sequential(*layers) + + +class ConvBlock(nn.Module): + def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="relu"): + super().__init__() + self._n_groups = 8 + self.blocks = nn.ModuleList( + [ + self._get_conv( + hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p + ) + for i in range(n_conv) + ] + ) + + def forward(self, x): + for block in self.blocks: + res = x + x = block(x) + x += res + return x + + def _get_conv(self, hidden_dim, dilation, activ="relu", dropout_p=0.2): + layers = [ + ConvNorm( + hidden_dim, + hidden_dim, + kernel_size=3, + padding=dilation, + dilation=dilation, + ), + _get_activation_fn(activ), + nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim), + nn.Dropout(p=dropout_p), + ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1), + 
_get_activation_fn(activ), + nn.Dropout(p=dropout_p), + ] + return nn.Sequential(*layers) + + +class LocationLayer(nn.Module): + def __init__(self, attention_n_filters, attention_kernel_size, attention_dim): + super(LocationLayer, self).__init__() + padding = int((attention_kernel_size - 1) / 2) + self.location_conv = ConvNorm( + 2, + attention_n_filters, + kernel_size=attention_kernel_size, + padding=padding, + bias=False, + stride=1, + dilation=1, + ) + self.location_dense = LinearNorm( + attention_n_filters, attention_dim, bias=False, w_init_gain="tanh" + ) + + def forward(self, attention_weights_cat): + processed_attention = self.location_conv(attention_weights_cat) + processed_attention = processed_attention.transpose(1, 2) + processed_attention = self.location_dense(processed_attention) + return processed_attention + + +class Attention(nn.Module): + def __init__( + self, + attention_rnn_dim, + embedding_dim, + attention_dim, + attention_location_n_filters, + attention_location_kernel_size, + ): + super(Attention, self).__init__() + self.query_layer = LinearNorm( + attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh" + ) + self.memory_layer = LinearNorm( + embedding_dim, attention_dim, bias=False, w_init_gain="tanh" + ) + self.v = LinearNorm(attention_dim, 1, bias=False) + self.location_layer = LocationLayer( + attention_location_n_filters, attention_location_kernel_size, attention_dim + ) + self.score_mask_value = -float("inf") + + def get_alignment_energies(self, query, processed_memory, attention_weights_cat): + """ + PARAMS + ------ + query: decoder output (batch, n_mel_channels * n_frames_per_step) + processed_memory: processed encoder outputs (B, T_in, attention_dim) + attention_weights_cat: cumulative and prev. att weights (B, 2, max_time) + RETURNS + ------- + alignment (batch, max_time) + """ + + processed_query = self.query_layer(query.unsqueeze(1)) + processed_attention_weights = self.location_layer(attention_weights_cat) + energies = self.v( + torch.tanh(processed_query + processed_attention_weights + processed_memory) + ) + + energies = energies.squeeze(-1) + return energies + + def forward( + self, + attention_hidden_state, + memory, + processed_memory, + attention_weights_cat, + mask, + ): + """ + PARAMS + ------ + attention_hidden_state: attention rnn last output + memory: encoder outputs + processed_memory: processed encoder outputs + attention_weights_cat: previous and cummulative attention weights + mask: binary mask for padded data + """ + alignment = self.get_alignment_energies( + attention_hidden_state, processed_memory, attention_weights_cat + ) + + if mask is not None: + alignment.data.masked_fill_(mask, self.score_mask_value) + + attention_weights = F.softmax(alignment, dim=1) + attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) + attention_context = attention_context.squeeze(1) + + return attention_context, attention_weights + + +class ForwardAttentionV2(nn.Module): + def __init__( + self, + attention_rnn_dim, + embedding_dim, + attention_dim, + attention_location_n_filters, + attention_location_kernel_size, + ): + super(ForwardAttentionV2, self).__init__() + self.query_layer = LinearNorm( + attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh" + ) + self.memory_layer = LinearNorm( + embedding_dim, attention_dim, bias=False, w_init_gain="tanh" + ) + self.v = LinearNorm(attention_dim, 1, bias=False) + self.location_layer = LocationLayer( + attention_location_n_filters, attention_location_kernel_size, attention_dim + ) 
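+ # large negative constant used to mask alignment energies at padded timesteps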
+ self.score_mask_value = -float(1e20) + + def get_alignment_energies(self, query, processed_memory, attention_weights_cat): + """ + PARAMS + ------ + query: decoder output (batch, n_mel_channels * n_frames_per_step) + processed_memory: processed encoder outputs (B, T_in, attention_dim) + attention_weights_cat: prev. and cumulative att weights (B, 2, max_time) + RETURNS + ------- + alignment (batch, max_time) + """ + + processed_query = self.query_layer(query.unsqueeze(1)) + processed_attention_weights = self.location_layer(attention_weights_cat) + energies = self.v( + torch.tanh(processed_query + processed_attention_weights + processed_memory) + ) + + energies = energies.squeeze(-1) + return energies + + def forward( + self, + attention_hidden_state, + memory, + processed_memory, + attention_weights_cat, + mask, + log_alpha, + ): + """ + PARAMS + ------ + attention_hidden_state: attention rnn last output + memory: encoder outputs + processed_memory: processed encoder outputs + attention_weights_cat: previous and cummulative attention weights + mask: binary mask for padded data + """ + log_energy = self.get_alignment_energies( + attention_hidden_state, processed_memory, attention_weights_cat + ) + + # log_energy = + + if mask is not None: + log_energy.data.masked_fill_(mask, self.score_mask_value) + + # attention_weights = F.softmax(alignment, dim=1) + + # content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME] + # log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1] + + # log_total_score = log_alpha + content_score + + # previous_attention_weights = attention_weights_cat[:,0,:] + + log_alpha_shift_padded = [] + max_time = log_energy.size(1) + for sft in range(2): + shifted = log_alpha[:, : max_time - sft] + shift_padded = F.pad(shifted, (sft, 0), "constant", self.score_mask_value) + log_alpha_shift_padded.append(shift_padded.unsqueeze(2)) + + biased = torch.logsumexp(torch.cat(log_alpha_shift_padded, 2), 2) + + log_alpha_new = biased + log_energy + + attention_weights = F.softmax(log_alpha_new, dim=1) + + attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) + attention_context = attention_context.squeeze(1) + + return attention_context, attention_weights, log_alpha_new + + +class PhaseShuffle2d(nn.Module): + def __init__(self, n=2): + super(PhaseShuffle2d, self).__init__() + self.n = n + self.random = random.Random(1) + + def forward(self, x, move=None): + # x.size = (B, C, M, L) + if move is None: + move = self.random.randint(-self.n, self.n) + + if move == 0: + return x + else: + left = x[:, :, :, :move] + right = x[:, :, :, move:] + shuffled = torch.cat([right, left], dim=3) + return shuffled + + +class PhaseShuffle1d(nn.Module): + def __init__(self, n=2): + super(PhaseShuffle1d, self).__init__() + self.n = n + self.random = random.Random(1) + + def forward(self, x, move=None): + # x.size = (B, C, M, L) + if move is None: + move = self.random.randint(-self.n, self.n) + + if move == 0: + return x + else: + left = x[:, :, :move] + right = x[:, :, move:] + shuffled = torch.cat([right, left], dim=2) + + return shuffled + + +class MFCC(nn.Module): + def __init__(self, n_mfcc=40, n_mels=80): + super(MFCC, self).__init__() + self.n_mfcc = n_mfcc + self.n_mels = n_mels + self.norm = "ortho" + dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm) + self.register_buffer("dct_mat", dct_mat) + + def forward(self, mel_specgram): + if len(mel_specgram.shape) == 2: + mel_specgram = mel_specgram.unsqueeze(0) + unsqueezed = True + else: + 
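# input already has a leading batch/channel dimension +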
unsqueezed = False + # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc) + # -> (channel, time, n_mfcc).tranpose(...) + mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2) + + # unpack batch + if unsqueezed: + mfcc = mfcc.squeeze(0) + return mfcc diff --git a/src/Utils/ASR/models.py b/src/Utils/ASR/models.py new file mode 100644 index 0000000000000000000000000000000000000000..1f5fd83dda801c5edfe00400d5accc3d69195ad4 --- /dev/null +++ b/src/Utils/ASR/models.py @@ -0,0 +1,217 @@ +import math +import torch +from torch import nn +from torch.nn import TransformerEncoder +import torch.nn.functional as F +from .layers import MFCC, Attention, LinearNorm, ConvNorm, ConvBlock + + +class ASRCNN(nn.Module): + def __init__( + self, + input_dim=80, + hidden_dim=256, + n_token=35, + n_layers=6, + token_embedding_dim=256, + ): + super().__init__() + self.n_token = n_token + self.n_down = 1 + self.to_mfcc = MFCC() + self.init_cnn = ConvNorm( + input_dim // 2, hidden_dim, kernel_size=7, padding=3, stride=2 + ) + self.cnns = nn.Sequential( + *[ + nn.Sequential( + ConvBlock(hidden_dim), + nn.GroupNorm(num_groups=1, num_channels=hidden_dim), + ) + for n in range(n_layers) + ] + ) + self.projection = ConvNorm(hidden_dim, hidden_dim // 2) + self.ctc_linear = nn.Sequential( + LinearNorm(hidden_dim // 2, hidden_dim), + nn.ReLU(), + LinearNorm(hidden_dim, n_token), + ) + self.asr_s2s = ASRS2S( + embedding_dim=token_embedding_dim, + hidden_dim=hidden_dim // 2, + n_token=n_token, + ) + + def forward(self, x, src_key_padding_mask=None, text_input=None): + x = self.to_mfcc(x) + x = self.init_cnn(x) + x = self.cnns(x) + x = self.projection(x) + x = x.transpose(1, 2) + ctc_logit = self.ctc_linear(x) + if text_input is not None: + _, s2s_logit, s2s_attn = self.asr_s2s(x, src_key_padding_mask, text_input) + return ctc_logit, s2s_logit, s2s_attn + else: + return ctc_logit + + def get_feature(self, x): + x = self.to_mfcc(x.squeeze(1)) + x = self.init_cnn(x) + x = self.cnns(x) + x = self.projection(x) + return x + + def length_to_mask(self, lengths): + mask = ( + torch.arange(lengths.max()) + .unsqueeze(0) + .expand(lengths.shape[0], -1) + .type_as(lengths) + ) + mask = torch.gt(mask + 1, lengths.unsqueeze(1)).to(lengths.device) + return mask + + def get_future_mask(self, out_length, unmask_future_steps=0): + """ + Args: + out_length (int): returned mask shape is (out_length, out_length). + unmask_futre_steps (int): unmasking future step size. 
+ Return: + mask (torch.BoolTensor): mask future timesteps mask[i, j] = True if i > j + unmask_future_steps else False + """ + index_tensor = torch.arange(out_length).unsqueeze(0).expand(out_length, -1) + mask = torch.gt(index_tensor, index_tensor.T + unmask_future_steps) + return mask + + +class ASRS2S(nn.Module): + def __init__( + self, + embedding_dim=256, + hidden_dim=512, + n_location_filters=32, + location_kernel_size=63, + n_token=40, + ): + super(ASRS2S, self).__init__() + self.embedding = nn.Embedding(n_token, embedding_dim) + val_range = math.sqrt(6 / hidden_dim) + self.embedding.weight.data.uniform_(-val_range, val_range) + + self.decoder_rnn_dim = hidden_dim + self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token) + self.attention_layer = Attention( + self.decoder_rnn_dim, + hidden_dim, + hidden_dim, + n_location_filters, + location_kernel_size, + ) + self.decoder_rnn = nn.LSTMCell( + self.decoder_rnn_dim + embedding_dim, self.decoder_rnn_dim + ) + self.project_to_hidden = nn.Sequential( + LinearNorm(self.decoder_rnn_dim * 2, hidden_dim), nn.Tanh() + ) + self.sos = 1 + self.eos = 2 + + def initialize_decoder_states(self, memory, mask): + """ + moemory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim) + """ + B, L, H = memory.shape + self.decoder_hidden = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory) + self.decoder_cell = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory) + self.attention_weights = torch.zeros((B, L)).type_as(memory) + self.attention_weights_cum = torch.zeros((B, L)).type_as(memory) + self.attention_context = torch.zeros((B, H)).type_as(memory) + self.memory = memory + self.processed_memory = self.attention_layer.memory_layer(memory) + self.mask = mask + self.unk_index = 3 + self.random_mask = 0.1 + + def forward(self, memory, memory_mask, text_input): + """ + moemory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim) + moemory_mask.shape = (B, L, ) + texts_input.shape = (B, T) + """ + self.initialize_decoder_states(memory, memory_mask) + # text random mask + random_mask = (torch.rand(text_input.shape) < self.random_mask).to( + text_input.device + ) + _text_input = text_input.clone() + _text_input.masked_fill_(random_mask, self.unk_index) + decoder_inputs = self.embedding(_text_input).transpose( + 0, 1 + ) # -> [T, B, channel] + start_embedding = self.embedding( + torch.LongTensor([self.sos] * decoder_inputs.size(1)).to( + decoder_inputs.device + ) + ) + decoder_inputs = torch.cat( + (start_embedding.unsqueeze(0), decoder_inputs), dim=0 + ) + + hidden_outputs, logit_outputs, alignments = [], [], [] + while len(hidden_outputs) < decoder_inputs.size(0): + decoder_input = decoder_inputs[len(hidden_outputs)] + hidden, logit, attention_weights = self.decode(decoder_input) + hidden_outputs += [hidden] + logit_outputs += [logit] + alignments += [attention_weights] + + hidden_outputs, logit_outputs, alignments = self.parse_decoder_outputs( + hidden_outputs, logit_outputs, alignments + ) + + return hidden_outputs, logit_outputs, alignments + + def decode(self, decoder_input): + cell_input = torch.cat((decoder_input, self.attention_context), -1) + self.decoder_hidden, self.decoder_cell = self.decoder_rnn( + cell_input, (self.decoder_hidden, self.decoder_cell) + ) + + attention_weights_cat = torch.cat( + ( + self.attention_weights.unsqueeze(1), + self.attention_weights_cum.unsqueeze(1), + ), + dim=1, + ) + + self.attention_context, self.attention_weights = self.attention_layer( + self.decoder_hidden, + self.memory, + 
self.processed_memory, + attention_weights_cat, + self.mask, + ) + + self.attention_weights_cum += self.attention_weights + + hidden_and_context = torch.cat( + (self.decoder_hidden, self.attention_context), -1 + ) + hidden = self.project_to_hidden(hidden_and_context) + + # dropout to increasing g + logit = self.project_to_n_symbols(F.dropout(hidden, 0.5, self.training)) + + return hidden, logit, self.attention_weights + + def parse_decoder_outputs(self, hidden, logit, alignments): + # -> [B, T_out + 1, max_time] + alignments = torch.stack(alignments).transpose(0, 1) + # [T_out + 1, B, n_symbols] -> [B, T_out + 1, n_symbols] + logit = torch.stack(logit).transpose(0, 1).contiguous() + hidden = torch.stack(hidden).transpose(0, 1).contiguous() + + return hidden, logit, alignments diff --git a/src/Utils/JDC/__init__.py b/src/Utils/JDC/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/src/Utils/JDC/__init__.py @@ -0,0 +1 @@ + diff --git a/src/Utils/JDC/__pycache__/__init__.cpython-310.pyc b/src/Utils/JDC/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7273748648d93338f50c4f504727ba32b0ae5cd Binary files /dev/null and b/src/Utils/JDC/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/Utils/JDC/__pycache__/model.cpython-310.pyc b/src/Utils/JDC/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d87caf53bdcd868595a7a920cb1ff6b924f15cd Binary files /dev/null and b/src/Utils/JDC/__pycache__/model.cpython-310.pyc differ diff --git a/src/Utils/JDC/bst.t7 b/src/Utils/JDC/bst.t7 new file mode 100644 index 0000000000000000000000000000000000000000..5aa5a7b89991a3ecce2fd13447d6cb65740d2a9b --- /dev/null +++ b/src/Utils/JDC/bst.t7 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:54dc94364b97e18ac1dfa6287714ed121248cfaac4cfd39d061c6e0a089ef169 +size 21029926 diff --git a/src/Utils/JDC/model.py b/src/Utils/JDC/model.py new file mode 100644 index 0000000000000000000000000000000000000000..01416d7345e667d89f6d77ec349d6e7c6b4fc1b7 --- /dev/null +++ b/src/Utils/JDC/model.py @@ -0,0 +1,212 @@ +""" +Implementation of model from: +Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using +Convolutional Recurrent Neural Networks" (2019) +Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d +""" +import torch +from torch import nn + + +class JDCNet(nn.Module): + """ + Joint Detection and Classification Network model for singing voice melody. 
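+ Takes spectrogram patches of shape (batch, 1, seq_len, 513); the classifier head predicts a pitch class per frame and the detector head a binary voice/non-voice estimate per frame.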
+ """ + + def __init__(self, num_class=722, seq_len=31, leaky_relu_slope=0.01): + super().__init__() + self.num_class = num_class + + # input = (b, 1, 31, 513), b = batch size + self.conv_block = nn.Sequential( + nn.Conv2d( + in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False + ), # out: (b, 64, 31, 513) + nn.BatchNorm2d(num_features=64), + nn.LeakyReLU(leaky_relu_slope, inplace=True), + nn.Conv2d(64, 64, 3, padding=1, bias=False), # (b, 64, 31, 513) + ) + + # res blocks + self.res_block1 = ResBlock( + in_channels=64, out_channels=128 + ) # (b, 128, 31, 128) + self.res_block2 = ResBlock( + in_channels=128, out_channels=192 + ) # (b, 192, 31, 32) + self.res_block3 = ResBlock(in_channels=192, out_channels=256) # (b, 256, 31, 8) + + # pool block + self.pool_block = nn.Sequential( + nn.BatchNorm2d(num_features=256), + nn.LeakyReLU(leaky_relu_slope, inplace=True), + nn.MaxPool2d(kernel_size=(1, 4)), # (b, 256, 31, 2) + nn.Dropout(p=0.2), + ) + + # maxpool layers (for auxiliary network inputs) + # in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2) + self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 40)) + # in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2) + self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 20)) + # in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2) + self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 10)) + + # in = (b, 640, 31, 2), out = (b, 256, 31, 2) + self.detector_conv = nn.Sequential( + nn.Conv2d(640, 256, 1, bias=False), + nn.BatchNorm2d(256), + nn.LeakyReLU(leaky_relu_slope, inplace=True), + nn.Dropout(p=0.2), + ) + + # input: (b, 31, 512) - resized from (b, 256, 31, 2) + self.bilstm_classifier = nn.LSTM( + input_size=512, hidden_size=256, batch_first=True, bidirectional=True + ) # (b, 31, 512) + + # input: (b, 31, 512) - resized from (b, 256, 31, 2) + self.bilstm_detector = nn.LSTM( + input_size=512, hidden_size=256, batch_first=True, bidirectional=True + ) # (b, 31, 512) + + # input: (b * 31, 512) + self.classifier = nn.Linear( + in_features=512, out_features=self.num_class + ) # (b * 31, num_class) + + # input: (b * 31, 512) + self.detector = nn.Linear( + in_features=512, out_features=2 + ) # (b * 31, 2) - binary classifier + + # initialize weights + self.apply(self.init_weights) + + def get_feature_GAN(self, x): + seq_len = x.shape[-2] + x = x.float().transpose(-1, -2) + + convblock_out = self.conv_block(x) + + resblock1_out = self.res_block1(convblock_out) + resblock2_out = self.res_block2(resblock1_out) + resblock3_out = self.res_block3(resblock2_out) + poolblock_out = self.pool_block[0](resblock3_out) + poolblock_out = self.pool_block[1](poolblock_out) + + return poolblock_out.transpose(-1, -2) + + def get_feature(self, x): + seq_len = x.shape[-2] + x = x.float().transpose(-1, -2) + + convblock_out = self.conv_block(x) + + resblock1_out = self.res_block1(convblock_out) + resblock2_out = self.res_block2(resblock1_out) + resblock3_out = self.res_block3(resblock2_out) + poolblock_out = self.pool_block[0](resblock3_out) + poolblock_out = self.pool_block[1](poolblock_out) + + return self.pool_block[2](poolblock_out) + + def forward(self, x): + """ + Returns: + classification_prediction, detection_prediction + sizes: (b, 31, 722), (b, 31, 2) + """ + ############################### + # forward pass for classifier # + ############################### + seq_len = x.shape[-1] + x = x.float().transpose(-1, -2) + + convblock_out = self.conv_block(x) + + resblock1_out = self.res_block1(convblock_out) + resblock2_out = 
self.res_block2(resblock1_out) + resblock3_out = self.res_block3(resblock2_out) + + poolblock_out = self.pool_block[0](resblock3_out) + poolblock_out = self.pool_block[1](poolblock_out) + GAN_feature = poolblock_out.transpose(-1, -2) + poolblock_out = self.pool_block[2](poolblock_out) + + # (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512) + classifier_out = ( + poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512)) + ) + classifier_out, _ = self.bilstm_classifier( + classifier_out + ) # ignore the hidden states + + classifier_out = classifier_out.contiguous().view((-1, 512)) # (b * 31, 512) + classifier_out = self.classifier(classifier_out) + classifier_out = classifier_out.view( + (-1, seq_len, self.num_class) + ) # (b, 31, num_class) + + # sizes: (b, 31, 722), (b, 31, 2) + # classifier output consists of predicted pitch classes per frame + # detector output consists of: (isvoice, notvoice) estimates per frame + return torch.abs(classifier_out.squeeze()), GAN_feature, poolblock_out + + @staticmethod + def init_weights(m): + if isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight) + elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell): + for p in m.parameters(): + if p.data is None: + continue + + if len(p.shape) >= 2: + nn.init.orthogonal_(p.data) + else: + nn.init.normal_(p.data) + + +class ResBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01): + super().__init__() + self.downsample = in_channels != out_channels + + # BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper + self.pre_conv = nn.Sequential( + nn.BatchNorm2d(num_features=in_channels), + nn.LeakyReLU(leaky_relu_slope, inplace=True), + nn.MaxPool2d(kernel_size=(1, 2)), # apply downsampling on the y axis only + ) + + # conv layers + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + bias=False, + ), + nn.BatchNorm2d(out_channels), + nn.LeakyReLU(leaky_relu_slope, inplace=True), + nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False), + ) + + # 1 x 1 convolution layer to match the feature dimensions + self.conv1by1 = None + if self.downsample: + self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False) + + def forward(self, x): + x = self.pre_conv(x) + if self.downsample: + x = self.conv(x) + self.conv1by1(x) + else: + x = self.conv(x) + x + return x diff --git a/src/Utils/PLBERT/__pycache__/util.cpython-310.pyc b/src/Utils/PLBERT/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87d8e32a8dfd991747d71de2e2768af12d09070b Binary files /dev/null and b/src/Utils/PLBERT/__pycache__/util.cpython-310.pyc differ diff --git a/src/Utils/PLBERT/config.yml b/src/Utils/PLBERT/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..75f60d1eccfee7e253649a1ae02105b5245e2473 --- /dev/null +++ b/src/Utils/PLBERT/config.yml @@ -0,0 +1,30 @@ +log_dir: "Checkpoint" +mixed_precision: "fp16" +data_folder: "wikipedia_20220301.en.processed" +batch_size: 192 +save_interval: 5000 +log_interval: 10 +num_process: 1 # number of GPUs +num_steps: 1000000 + +dataset_params: + tokenizer: "transfo-xl-wt103" + token_separator: " " # token used for phoneme separator (space) + token_mask: "M" # token used for phoneme mask (M) + word_separator: 3039 # token used for word 
separator () + token_maps: "token_maps.pkl" # token map path + + max_mel_length: 512 # max phoneme length + + word_mask_prob: 0.15 # probability to mask the entire word + phoneme_mask_prob: 0.1 # probability to mask each phoneme + replace_prob: 0.2 # probablity to replace phonemes + +model_params: + vocab_size: 178 + hidden_size: 768 + num_attention_heads: 12 + intermediate_size: 2048 + max_position_embeddings: 512 + num_hidden_layers: 12 + dropout: 0.1 \ No newline at end of file diff --git a/src/Utils/PLBERT/step_1000000.t7 b/src/Utils/PLBERT/step_1000000.t7 new file mode 100644 index 0000000000000000000000000000000000000000..b28ef5c0556a97bc31b0ee996945c4e566bcf6c0 --- /dev/null +++ b/src/Utils/PLBERT/step_1000000.t7 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0714ff85804db43e06b3b0ac5749bf90cf206257c6c5916e8a98c5933b4c21e0 +size 25185187 diff --git a/src/Utils/PLBERT/util.py b/src/Utils/PLBERT/util.py new file mode 100644 index 0000000000000000000000000000000000000000..743f3849f55b9a1deb9a91b4b9369f1046f98716 --- /dev/null +++ b/src/Utils/PLBERT/util.py @@ -0,0 +1,49 @@ +import os +import yaml +import torch +from transformers import AlbertConfig, AlbertModel + + +class CustomAlbert(AlbertModel): + def forward(self, *args, **kwargs): + # Call the original forward method + outputs = super().forward(*args, **kwargs) + + # Only return the last_hidden_state + return outputs.last_hidden_state + + +def load_plbert(log_dir): + config_path = os.path.join(log_dir, "config.yml") + plbert_config = yaml.safe_load(open(config_path)) + + albert_base_configuration = AlbertConfig(**plbert_config["model_params"]) + bert = CustomAlbert(albert_base_configuration) + + files = os.listdir(log_dir) + ckpts = [] + for f in os.listdir(log_dir): + if f.startswith("step_"): + ckpts.append(f) + + iters = [ + int(f.split("_")[-1].split(".")[0]) + for f in ckpts + if os.path.isfile(os.path.join(log_dir, f)) + ] + iters = sorted(iters)[-1] + + checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".t7", map_location="cpu") + state_dict = checkpoint["net"] + from collections import OrderedDict + + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[7:] # remove `module.` + if name.startswith("encoder."): + name = name[8:] # remove `encoder.` + new_state_dict[name] = v + del new_state_dict["embeddings.position_ids"] + bert.load_state_dict(new_state_dict, strict=False) + + return bert diff --git a/src/Utils/__init__.py b/src/Utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/src/Utils/__init__.py @@ -0,0 +1 @@ + diff --git a/src/Utils/__pycache__/__init__.cpython-310.pyc b/src/Utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1dd5d47220223b18549bb93365650671a24f2c93 Binary files /dev/null and b/src/Utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/__pycache__/api.cpython-310.pyc b/src/__pycache__/api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f25826c46f4557ba673d3621a3cb1b73296cb64 Binary files /dev/null and b/src/__pycache__/api.cpython-310.pyc differ diff --git a/src/__pycache__/attentions.cpython-310.pyc b/src/__pycache__/attentions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..693d474329f090c88e6a64f96acac7786462f117 Binary files /dev/null and b/src/__pycache__/attentions.cpython-310.pyc differ diff --git 
a/src/__pycache__/commons.cpython-310.pyc b/src/__pycache__/commons.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..deca31e73d41863cc8f2a2bd92d19ba8300a58ea Binary files /dev/null and b/src/__pycache__/commons.cpython-310.pyc differ diff --git a/src/__pycache__/mel_processing.cpython-310.pyc b/src/__pycache__/mel_processing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d47757fc641f6c679e0dd369eb0db33745c73a7 Binary files /dev/null and b/src/__pycache__/mel_processing.cpython-310.pyc differ diff --git a/src/__pycache__/models.cpython-310.pyc b/src/__pycache__/models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1eb55d8daa7cefb73b37e9a7b994a4933ac38a9f Binary files /dev/null and b/src/__pycache__/models.cpython-310.pyc differ diff --git a/src/__pycache__/modules.cpython-310.pyc b/src/__pycache__/modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9d63254deb368d8da0807cc859e30acd95a2e5f Binary files /dev/null and b/src/__pycache__/modules.cpython-310.pyc differ diff --git a/src/__pycache__/predict.cpython-310.pyc b/src/__pycache__/predict.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..511cbe07abb9e06e65c9370334c3a80cce257226 Binary files /dev/null and b/src/__pycache__/predict.cpython-310.pyc differ diff --git a/src/__pycache__/rp_schema.cpython-310.pyc b/src/__pycache__/rp_schema.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..502c1ef237c0e63bc4d7d077bbeb8506cb2cc64d Binary files /dev/null and b/src/__pycache__/rp_schema.cpython-310.pyc differ diff --git a/src/__pycache__/se_extractor.cpython-310.pyc b/src/__pycache__/se_extractor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02663502f29965a082c7b37b48e8f941fb27ce71 Binary files /dev/null and b/src/__pycache__/se_extractor.cpython-310.pyc differ diff --git a/src/__pycache__/text_utils.cpython-310.pyc b/src/__pycache__/text_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..140d2a5d520fe7a90e1211553483bd2ba14f62fa Binary files /dev/null and b/src/__pycache__/text_utils.cpython-310.pyc differ diff --git a/src/__pycache__/transforms.cpython-310.pyc b/src/__pycache__/transforms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da1c22bed1bf3963b7d89fb2b0eac12ea1ac0d22 Binary files /dev/null and b/src/__pycache__/transforms.cpython-310.pyc differ diff --git a/src/__pycache__/utils.cpython-310.pyc b/src/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38a38d2f919331c103585c32fe7fc4e508794cc0 Binary files /dev/null and b/src/__pycache__/utils.cpython-310.pyc differ diff --git a/src/api.py b/src/api.py new file mode 100644 index 0000000000000000000000000000000000000000..29d7362b0a5b4fc2ed9f40c4d97d2d674a4a9b8a --- /dev/null +++ b/src/api.py @@ -0,0 +1,218 @@ +import torch +import numpy as np +import re +import soundfile +import utils +import commons +import os +import librosa +from text import text_to_sequence +from mel_processing import spectrogram_torch +from models import SynthesizerTrn +from openai import OpenAI + +class OpenVoiceBaseClass(object): + def __init__(self, + config_path, + device='cuda:0'): + if 'cuda' in device: + assert torch.cuda.is_available() + + hps = utils.get_hparams_from_file(config_path) + + + + model = SynthesizerTrn( + 
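# vocabulary size: number of text symbols listed in the config (0 if 'symbols' is absent) +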
len(getattr(hps, 'symbols', [])), + hps.data.filter_length // 2 + 1, + n_speakers=hps.data.n_speakers, + **hps.model, + ).to(device) + + model.eval() + self.model = model + self.hps = hps + self.device = device + + def load_ckpt(self, ckpt_path): + checkpoint_dict = torch.load(ckpt_path) + a, b = self.model.load_state_dict(checkpoint_dict['model'], strict=False) + print("Loaded checkpoint '{}'".format(ckpt_path)) + print('missing/unexpected keys:', a, b) + + +class BaseSpeakerTTS(OpenVoiceBaseClass): + language_marks = { + "english": "EN", + "chinese": "ZH", + } + + client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) + @staticmethod + def get_text(text, hps, is_symbol): + text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners) + if hps.data.add_blank: + text_norm = commons.intersperse(text_norm, 0) + text_norm = torch.LongTensor(text_norm) + return text_norm + + @staticmethod + def audio_numpy_concat(segment_data_list, sr, speed=1.): + audio_segments = [] + for segment_data in segment_data_list: + audio_segments += segment_data.reshape(-1).tolist() + audio_segments += [0] * int((sr * 0.05)/speed) + audio_segments = np.array(audio_segments).astype(np.float32) + return audio_segments + + @staticmethod + def split_sentences_into_pieces(text, language_str): + texts = utils.split_sentence(text, language_str=language_str) + print(" > Text splitted to sentences.") + print('\n'.join(texts)) + print(" > ===========================") + return texts + + def tts(self, text, output_path, speaker, language='English', speed=1.0,use_emotions=False): + mark = self.language_marks.get(language.lower(), None) + assert mark is not None, f"language {language} is not supported" + + texts = self.split_sentences_into_pieces(text, mark) + + audio_list = [] + for t in texts: + speaker_id = self.hps.speakers[speaker] + try: + if(use_emotions): + print(f"finding emotion for {t}") + response = self.client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", "content": "You are an expert in identifying the emotion of a sentence. Use noemotion if no emotions. 
Possible answers are noemotion,friendly,cheerful,excited,sad,angry,terrified,shouting,whispering.Just reply 1 word"}, + {"role": "user", "content": t}, + ] + ) + result = response.choices[0].message.content + print(result) + if result not in ["friendly", "cheerful", "excited", "sad", "angry", "terrified", "shouting", "whispering"]: + result = 'default' + print(f"CHOSEN {result}") + speaker_id = self.hps.speakers[result] + except Exception as e: + print(f"Exception {e}") + speaker_id = self.hps.speakers['default'] + t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t) + t = f'[{mark}]{t}[{mark}]' + stn_tst = self.get_text(t, self.hps, False) + device = self.device + + with torch.no_grad(): + x_tst = stn_tst.unsqueeze(0).to(device) + x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device) + sid = torch.LongTensor([speaker_id]).to(device) + audio = self.model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.6, + length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy() + audio_list.append(audio) + audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed) + + if output_path is None: + return audio + else: + soundfile.write(output_path, audio, self.hps.data.sampling_rate) + + +class ToneColorConverter(OpenVoiceBaseClass): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.watermark_model = None + + def extract_se(self, ref_wav_list, se_save_path=None): + if isinstance(ref_wav_list, str): + ref_wav_list = [ref_wav_list] + + device = self.device + hps = self.hps + gs = [] + + for fname in ref_wav_list: + audio_ref, sr = librosa.load(fname, sr=hps.data.sampling_rate) + y = torch.FloatTensor(audio_ref) + y = y.to(device) + y = y.unsqueeze(0) + y = spectrogram_torch(y, hps.data.filter_length, + hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, + center=False).to(device) + with torch.no_grad(): + g = self.model.ref_enc(y.transpose(1, 2)).unsqueeze(-1) + gs.append(g.detach()) + gs = torch.stack(gs).mean(0) + + if se_save_path is not None: + os.makedirs(os.path.dirname(se_save_path), exist_ok=True) + torch.save(gs.cpu(), se_save_path) + + return gs + + def convert(self, audio_src_path, src_se, tgt_se, output_path=None, tau=0.3, message="default"): + hps = self.hps + # load audio + audio, sample_rate = librosa.load(audio_src_path, sr=hps.data.sampling_rate) + audio = torch.tensor(audio).float() + + with torch.no_grad(): + y = torch.FloatTensor(audio).to(self.device) + y = y.unsqueeze(0) + spec = spectrogram_torch(y, hps.data.filter_length, + hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, + center=False).to(self.device) + spec_lengths = torch.LongTensor([spec.size(-1)]).to(self.device) + audio = self.model.voice_conversion(spec, spec_lengths, sid_src=src_se, sid_tgt=tgt_se, tau=tau)[0][ + 0, 0].data.cpu().float().numpy() + audio = self.add_watermark(audio, message) + if output_path is None: + return audio + else: + soundfile.write(output_path, audio, hps.data.sampling_rate) + + def add_watermark(self, audio, message): + if self.watermark_model is None: + return audio + device = self.device + bits = utils.string_to_bits(message).reshape(-1) + n_repeat = len(bits) // 32 + + K = 16000 + coeff = 2 + for n in range(n_repeat): + trunck = audio[(coeff * n) * K: (coeff * n + 1) * K] + if len(trunck) != K: + print('Audio too short, fail to add watermark') + break + message_npy = bits[n * 32: (n + 1) * 32] + + with torch.no_grad(): + signal = torch.FloatTensor(trunck).to(device)[None] + 
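# embed this 32-bit slice of the message into the current audio chunk +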
message_tensor = torch.FloatTensor(message_npy).to(device)[None] + signal_wmd_tensor = self.watermark_model.encode(signal, message_tensor) + signal_wmd_npy = signal_wmd_tensor.detach().cpu().squeeze() + audio[(coeff * n) * K: (coeff * n + 1) * K] = signal_wmd_npy + return audio + + def detect_watermark(self, audio, n_repeat): + bits = [] + K = 16000 + coeff = 2 + for n in range(n_repeat): + trunck = audio[(coeff * n) * K: (coeff * n + 1) * K] + if len(trunck) != K: + print('Audio too short, fail to detect watermark') + return 'Fail' + with torch.no_grad(): + signal = torch.FloatTensor(trunck).to(self.device).unsqueeze(0) + message_decoded_npy = (self.watermark_model.decode(signal) >= 0.5).int().detach().cpu().numpy().squeeze() + bits.append(message_decoded_npy) + bits = np.stack(bits).reshape(-1, 8) + message = utils.bits_to_string(bits) + return message + diff --git a/src/attentions.py b/src/attentions.py new file mode 100644 index 0000000000000000000000000000000000000000..355a72e5172fa86354fd848954b521dfb2bb158f --- /dev/null +++ b/src/attentions.py @@ -0,0 +1,465 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +import commons +import logging + +logger = logging.getLogger(__name__) + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +class Encoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + window_size=4, + isflow=True, + **kwargs + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + # if isflow: + # cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1) + # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1) + # self.cond_layer = weight_norm(cond_layer, name='weight') + # self.gin_channels = 256 + self.cond_layer_idx = self.n_layers + if "gin_channels" in kwargs: + self.gin_channels = kwargs["gin_channels"] + if self.gin_channels != 0: + self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels) + # vits2 says 3rd block, so idx is 2 by default + self.cond_layer_idx = ( + kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2 + ) + # logging.debug(self.gin_channels, self.cond_layer_idx) + assert ( + self.cond_layer_idx < self.n_layers + ), "cond_layer_idx should be less than n_layers" + self.drop = nn.Dropout(p_dropout) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + + for i in range(self.n_layers): + self.attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + window_size=window_size, + ) + ) + 
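# each attention layer is followed by a LayerNorm over the residual sum (see forward) +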
self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, g=None): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + if i == self.cond_layer_idx and g is not None: + g = self.spk_emb_linear(g.transpose(1, 2)) + g = g.transpose(1, 2) + x = x + g + x = x * x_mask + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class Decoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + proximal_bias=False, + proximal_init=True, + **kwargs + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + + self.drop = nn.Dropout(p_dropout) + self.self_attn_layers = nn.ModuleList() + self.norm_layers_0 = nn.ModuleList() + self.encdec_attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.self_attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + proximal_bias=proximal_bias, + proximal_init=proximal_init, + ) + ) + self.norm_layers_0.append(LayerNorm(hidden_channels)) + self.encdec_attn_layers.append( + MultiHeadAttention( + hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + causal=True, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, h, h_mask): + """ + x: decoder input + h: encoder output + """ + self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( + device=x.device, dtype=x.dtype + ) + encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.self_attn_layers[i](x, x, self_attn_mask) + y = self.drop(y) + x = self.norm_layers_0[i](x + y) + + y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels, + out_channels, + n_heads, + p_dropout=0.0, + window_size=None, + heads_share=True, + block_length=None, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.p_dropout = p_dropout + self.window_size = window_size + self.heads_share = heads_share + self.block_length = block_length + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = nn.Conv1d(channels, channels, 
1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) + + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels**-0.5 + self.emb_rel_k = nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + self.emb_rel_v = nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + nn.init.xavier_uniform_(self.conv_v.weight) + if proximal_init: + with torch.no_grad(): + self.conv_k.weight.copy_(self.conv_q.weight) + self.conv_k.bias.copy_(self.conv_q.bias) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) + if self.window_size is not None: + assert ( + t_s == t_t + ), "Relative attention is only available for self-attention." + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys( + query / math.sqrt(self.k_channels), key_relative_embeddings + ) + scores_local = self._relative_position_to_absolute_position(rel_logits) + scores = scores + scores_local + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to( + device=scores.device, dtype=scores.dtype + ) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + if self.block_length is not None: + assert ( + t_s == t_t + ), "Local attention is only available for self-attention." + block_mask = ( + torch.ones_like(scores) + .triu(-self.block_length) + .tril(self.block_length) + ) + scores = scores.masked_fill(block_mask == 0, -1e4) + p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings( + self.emb_rel_v, t_s + ) + output = output + self._matmul_with_relative_values( + relative_weights, value_relative_embeddings + ) + output = ( + output.transpose(2, 3).contiguous().view(b, d, t_t) + ) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + """ + x: [b, h, l, m] + y: [h or 1, m, d] + ret: [b, h, l, d] + """ + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + """ + x: [b, h, l, d] + y: [h or 1, m, d] + ret: [b, h, l, m] + """ + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + 2 * self.window_size + 1 + # Pad first before slice to avoid using cond ops. 
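+ # The learned table holds 2*window_size + 1 relative-position embeddings; the bare
+ # `2 * self.window_size + 1` expression above has no effect. For a sequence of
+ # length L the table is zero-padded to 2*L - 1 rows and sliced, so offsets outside
+ # the window land on the zero padding. E.g. with window_size=4 and L=6:
+ # pad_length=1, slice_start=0, slice_end=11, giving the 2*6 - 1 = 11 rows used.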
+ pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad( + relative_embeddings, + commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), + ) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[ + :, slice_start_position:slice_end_position + ] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + """ + x: [b, h, l, 2*l-1] + ret: [b, h, l, l] + """ + batch, heads, length, _ = x.size() + # Concat columns of pad to shift from relative to absolute indexing. + x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) + + # Concat extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad( + x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]) + ) + + # Reshape and slice out the padded elements. + x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ + :, :, :length, length - 1 : + ] + return x_final + + def _absolute_position_to_relative_position(self, x): + """ + x: [b, h, l, l] + ret: [b, h, l, 2*l-1] + """ + batch, heads, length, _ = x.size() + # pad along column + x = F.pad( + x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]) + ) + x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + def _attention_bias_proximal(self, length): + """Bias for self-attention to encourage attention to close positions. + Args: + length: an integer scalar. 
+ Returns: + a Tensor with shape [1, 1, length, length] + """ + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__( + self, + in_channels, + out_channels, + filter_channels, + kernel_size, + p_dropout=0.0, + activation=None, + causal=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.activation = activation + self.causal = causal + + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) + self.drop = nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + if self.activation == "gelu": + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask + + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x diff --git a/src/checkpoints/base_speakers/EN/checkpoint.pth b/src/checkpoints/base_speakers/EN/checkpoint.pth new file mode 100644 index 0000000000000000000000000000000000000000..fb7c26af57011437a02ebb1c4fe8ed307cc30f21 --- /dev/null +++ b/src/checkpoints/base_speakers/EN/checkpoint.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1db1ae1a5c8ded049bd1536051489aefbfad4a5077c01c2257e9e88fa1bb8422 +size 160467309 diff --git a/src/checkpoints/base_speakers/EN/config.json b/src/checkpoints/base_speakers/EN/config.json new file mode 100644 index 0000000000000000000000000000000000000000..f7309ad10eae3c160ea0ef44261372c4f3364587 --- /dev/null +++ b/src/checkpoints/base_speakers/EN/config.json @@ -0,0 +1,145 @@ +{ + "data": { + "text_cleaners": [ + "cjke_cleaners2" + ], + "sampling_rate": 22050, + "filter_length": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mel_channels": 80, + "add_blank": true, + "cleaned_text": true, + "n_speakers": 10 + }, + "model": { + "inter_channels": 192, + "hidden_channels": 192, + "filter_channels": 768, + "n_heads": 2, + "n_layers": 6, + "n_layers_trans_flow": 3, + "kernel_size": 3, + "p_dropout": 0.1, + "resblock": "1", + "resblock_kernel_sizes": [ + 3, + 7, + 11 + ], + "resblock_dilation_sizes": [ + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ] + ], + "upsample_rates": [ + 8, + 8, + 2, + 2 + ], + "upsample_initial_channel": 512, + "upsample_kernel_sizes": [ + 16, + 16, + 4, + 4 + ], + "n_layers_q": 3, + "use_spectral_norm": false, + "gin_channels": 256 + }, + "symbols": [ + "_", + ",", + ".", + "!", + "?", + "-", + "~", + "\u2026", + "N", + "Q", + "a", + "b", + "d", + "e", + "f", + "g", + "h", + "i", + "j", + "k", + "l", + "m", + "n", + "o", + "p", + "s", + "t", + "u", + "v", + "w", + "x", + "y", + "z", + "\u0251", + "\u00e6", + "\u0283", + "\u0291", + "\u00e7", + "\u026f", + 
"\u026a", + "\u0254", + "\u025b", + "\u0279", + "\u00f0", + "\u0259", + "\u026b", + "\u0265", + "\u0278", + "\u028a", + "\u027e", + "\u0292", + "\u03b8", + "\u03b2", + "\u014b", + "\u0266", + "\u207c", + "\u02b0", + "`", + "^", + "#", + "*", + "=", + "\u02c8", + "\u02cc", + "\u2192", + "\u2193", + "\u2191", + " " + ], + "speakers": { + "default": 1, + "whispering": 2, + "shouting": 3, + "excited": 4, + "cheerful": 5, + "terrified": 6, + "angry": 7, + "sad": 8, + "friendly": 9 + } +} \ No newline at end of file diff --git a/src/checkpoints/base_speakers/EN/en_default_se.pth b/src/checkpoints/base_speakers/EN/en_default_se.pth new file mode 100644 index 0000000000000000000000000000000000000000..319d7eb4bee7b785a47f4e6191c2132dec12abcf --- /dev/null +++ b/src/checkpoints/base_speakers/EN/en_default_se.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9cab24002eec738d0fe72cb73a34e57fbc3999c1bd4a1670a7b56ee4e3590ac9 +size 1789 diff --git a/src/checkpoints/base_speakers/EN/en_style_se.pth b/src/checkpoints/base_speakers/EN/en_style_se.pth new file mode 100644 index 0000000000000000000000000000000000000000..c2fd50abf058f6ab65879395b62fb7e3c0289b47 --- /dev/null +++ b/src/checkpoints/base_speakers/EN/en_style_se.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6f698153be5004b90a8642d1157c89cae7dd296752a3276450ced6a17b8b98a9 +size 1783 diff --git a/src/checkpoints/base_speakers/ZH/checkpoint.pth b/src/checkpoints/base_speakers/ZH/checkpoint.pth new file mode 100644 index 0000000000000000000000000000000000000000..fcadb5c222e9ea92fc9ada4920249fc65cad1692 --- /dev/null +++ b/src/checkpoints/base_speakers/ZH/checkpoint.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de9fb0eb749f3254130fe0172fcbb20e75f88a9b16b54dd0b73cac0dc40da7d9 +size 160467309 diff --git a/src/checkpoints/base_speakers/ZH/config.json b/src/checkpoints/base_speakers/ZH/config.json new file mode 100644 index 0000000000000000000000000000000000000000..130256092fb8ad00f938149bf8aa1a62aae30023 --- /dev/null +++ b/src/checkpoints/base_speakers/ZH/config.json @@ -0,0 +1,137 @@ +{ + "data": { + "text_cleaners": [ + "cjke_cleaners2" + ], + "sampling_rate": 22050, + "filter_length": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mel_channels": 80, + "add_blank": true, + "cleaned_text": true, + "n_speakers": 10 + }, + "model": { + "inter_channels": 192, + "hidden_channels": 192, + "filter_channels": 768, + "n_heads": 2, + "n_layers": 6, + "n_layers_trans_flow": 3, + "kernel_size": 3, + "p_dropout": 0.1, + "resblock": "1", + "resblock_kernel_sizes": [ + 3, + 7, + 11 + ], + "resblock_dilation_sizes": [ + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ] + ], + "upsample_rates": [ + 8, + 8, + 2, + 2 + ], + "upsample_initial_channel": 512, + "upsample_kernel_sizes": [ + 16, + 16, + 4, + 4 + ], + "n_layers_q": 3, + "use_spectral_norm": false, + "gin_channels": 256 + }, + "symbols": [ + "_", + ",", + ".", + "!", + "?", + "-", + "~", + "\u2026", + "N", + "Q", + "a", + "b", + "d", + "e", + "f", + "g", + "h", + "i", + "j", + "k", + "l", + "m", + "n", + "o", + "p", + "s", + "t", + "u", + "v", + "w", + "x", + "y", + "z", + "\u0251", + "\u00e6", + "\u0283", + "\u0291", + "\u00e7", + "\u026f", + "\u026a", + "\u0254", + "\u025b", + "\u0279", + "\u00f0", + "\u0259", + "\u026b", + "\u0265", + "\u0278", + "\u028a", + "\u027e", + "\u0292", + "\u03b8", + "\u03b2", + "\u014b", + "\u0266", + "\u207c", + "\u02b0", + "`", + "^", + "#", + "*", + "=", + "\u02c8", + 
"\u02cc", + "\u2192", + "\u2193", + "\u2191", + " " + ], + "speakers": { + "default": 0 + } +} \ No newline at end of file diff --git a/src/checkpoints/base_speakers/ZH/zh_default_se.pth b/src/checkpoints/base_speakers/ZH/zh_default_se.pth new file mode 100644 index 0000000000000000000000000000000000000000..471841ae84a31aae1c8e25c1ef4548b3e87a32bb --- /dev/null +++ b/src/checkpoints/base_speakers/ZH/zh_default_se.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3b62e8264962059b8a84dd00b29e2fcccc92f5d3be90eec67dfa082c0cf58ccf +size 1789 diff --git a/src/checkpoints/converter/checkpoint.pth b/src/checkpoints/converter/checkpoint.pth new file mode 100644 index 0000000000000000000000000000000000000000..c38ff17666bae2bae4236f85bfe2284f4885b31a --- /dev/null +++ b/src/checkpoints/converter/checkpoint.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:89ae83aa4e3668fef64b388b789ff7b0ce0def9f801069edfc18a00ea420748d +size 131327338 diff --git a/src/checkpoints/converter/config.json b/src/checkpoints/converter/config.json new file mode 100644 index 0000000000000000000000000000000000000000..a163d4254b637e9fd489712db40c15aeacda169e --- /dev/null +++ b/src/checkpoints/converter/config.json @@ -0,0 +1,57 @@ +{ + "data": { + "sampling_rate": 22050, + "filter_length": 1024, + "hop_length": 256, + "win_length": 1024, + "n_speakers": 0 + }, + "model": { + "inter_channels": 192, + "hidden_channels": 192, + "filter_channels": 768, + "n_heads": 2, + "n_layers": 6, + "kernel_size": 3, + "p_dropout": 0.1, + "resblock": "1", + "resblock_kernel_sizes": [ + 3, + 7, + 11 + ], + "resblock_dilation_sizes": [ + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ] + ], + "upsample_rates": [ + 8, + 8, + 2, + 2 + ], + "upsample_initial_channel": 512, + "upsample_kernel_sizes": [ + 16, + 16, + 4, + 4 + ], + "n_layers_q": 3, + "use_spectral_norm": false, + "gin_channels": 256 + } +} \ No newline at end of file diff --git a/src/commons.py b/src/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..d3fa07f65b1681e1f469b04b2fe689b7c174eaaa --- /dev/null +++ b/src/commons.py @@ -0,0 +1,160 @@ +import math +import torch +from torch.nn import functional as F + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def convert_pad_shape(pad_shape): + layer = pad_shape[::-1] + pad_shape = [item for sublist in layer for item in sublist] + return pad_shape + + +def intersperse(lst, item): + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def kl_divergence(m_p, logs_p, m_q, logs_q): + """KL(P||Q)""" + kl = (logs_q - logs_p) - 0.5 + kl += ( + 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) + ) + return kl + + +def rand_gumbel(shape): + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) + + +def rand_gumbel_like(x): + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def rand_slice_segments(x, 
x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( + num_timescales - 1 + ) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) + + +def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + + +def subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def convert_pad_shape(pad_shape): + layer = pad_shape[::-1] + pad_shape = [item for sublist in layer for item in sublist] + return pad_shape + + +def shift_1d(x): + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2, 3) * mask + return path + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1.0 / norm_type) + return total_norm diff --git a/src/compute.py b/src/compute.py new file mode 100644 index 
0000000000000000000000000000000000000000..ea68abba8d3eb349051d83fd2edfcd5d07d4897f --- /dev/null +++ b/src/compute.py @@ -0,0 +1,138 @@ +from cached_path import cached_path + +# from dp.phonemizer import Phonemizer +print("NLTK") +import nltk +nltk.download('punkt') +print("SCIPY") +from scipy.io.wavfile import write +print("TORCH STUFF") +import torch +print("START") +torch.manual_seed(0) +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True + +import random +random.seed(0) + +import numpy as np +np.random.seed(0) + +# load packages +import time +import random +import yaml +from munch import Munch +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +import torchaudio +import librosa +from nltk.tokenize import word_tokenize + +from models import * +from utils import * +from text_utils import TextCleaner +textclenaer = TextCleaner() + + +to_mel = torchaudio.transforms.MelSpectrogram( + n_mels=80, n_fft=2048, win_length=1200, hop_length=300) +mean, std = -4, 4 + +def length_to_mask(lengths): + mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) + mask = torch.gt(mask+1, lengths.unsqueeze(1)) + return mask + +def preprocess(wave): + wave_tensor = torch.from_numpy(wave).float() + mel_tensor = to_mel(wave_tensor) + mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std + return mel_tensor + +def compute_style(path): + wave, sr = librosa.load(path, sr=24000) + audio, index = librosa.effects.trim(wave, top_db=30) + if sr != 24000: + audio = librosa.resample(audio, sr, 24000) + mel_tensor = preprocess(audio).to(device) + + with torch.no_grad(): + ref_s = model.style_encoder(mel_tensor.unsqueeze(1)) + ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1)) + + return torch.cat([ref_s, ref_p], dim=1) + +device = 'cpu' +if torch.cuda.is_available(): + device = 'cuda' +elif torch.backends.mps.is_available(): + print("MPS would be available but cannot be used rn") + # device = 'mps' + + + +# config = yaml.safe_load(open("Models/LibriTTS/config.yml")) +config = yaml.safe_load(open(str(cached_path("hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/config.yml")))) + +# load pretrained ASR model +ASR_config = config.get('ASR_config', False) +ASR_path = config.get('ASR_path', False) +text_aligner = load_ASR_models(ASR_path, ASR_config) + +# load pretrained F0 model +F0_path = config.get('F0_path', False) +pitch_extractor = load_F0_models(F0_path) + +# load BERT model +from Utils.PLBERT.util import load_plbert +BERT_path = config.get('PLBERT_dir', False) +plbert = load_plbert(BERT_path) + +model_params = recursive_munch(config['model_params']) +model = build_model(model_params, text_aligner, pitch_extractor, plbert) +_ = [model[key].eval() for key in model] +_ = [model[key].to(device) for key in model] + +# params_whole = torch.load("Models/LibriTTS/epochs_2nd_00020.pth", map_location='cpu') +params_whole = torch.load(str(cached_path("hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth")), map_location='cpu') +params = params_whole['net'] + +for key in model: + if key in params: + print('%s loaded' % key) + try: + model[key].load_state_dict(params[key]) + except: + from collections import OrderedDict + state_dict = params[key] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[7:] # remove `module.` + new_state_dict[name] = v + # load params + model[key].load_state_dict(new_state_dict, strict=False) +# except: +# _load(params[key], model[key]) +_ 
= [model[key].eval() for key in model] + +from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule + +sampler = DiffusionSampler( + model.diffusion.diffusion, + sampler=ADPM2Sampler(), + sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters + clamp=False +) +voicelist = ['f-us-1', 'f-us-2', 'f-us-3', 'f-us-4', 'm-us-1', 'm-us-2', 'm-us-3', 'm-us-4'] +voices = {} +# todo: cache computed style, load using pickle +for v in voicelist: + print(f"Loading voice {v}") + voices[v] = compute_style(f'voices/{v}.wav') +import pickle +with open('voices.pkl', 'wb') as f: + pickle.dump(voices, f) \ No newline at end of file diff --git a/src/gruut_phonemize.py b/src/gruut_phonemize.py new file mode 100644 index 0000000000000000000000000000000000000000..d3514c39ac6e9752e20a0259e02d434e109a5afb --- /dev/null +++ b/src/gruut_phonemize.py @@ -0,0 +1,10 @@ +from gruut import sentences + + +def gphonemize(text): + phonemes = '' + for sent in sentences(text, lang="en-us"): + for word in sent: + if word.phonemes: + phonemes += ''.join(word.phonemes) + return phonemes \ No newline at end of file diff --git a/src/inference.py b/src/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..378846d50d03518eff37b729d5e9f7a7cd9f15c4 --- /dev/null +++ b/src/inference.py @@ -0,0 +1,375 @@ +import torch +torch.manual_seed(0) +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +import nltk +import time +from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule +nltk.download('punkt') +from tortoise.utils.text import split_and_recombine_text + +import random +random.seed(0) + +import numpy as np +np.random.seed(0) + +import time +import random +import yaml +from munch import Munch +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +import torchaudio +import librosa +from nltk.tokenize import word_tokenize + +from models import * +from utils import * +from text_utils import TextCleaner +import soundfile as sf + +textclenaer = TextCleaner() + + +start_time = time.time() +from resemble_enhance.enhancer.inference import denoise, enhance + +if torch.cuda.is_available(): + device = "cuda" +else: + device = "cpu" + +''' +def _fn(path, solver, nfe, tau, denoising): + if path is None: + return None, None + + solver = solver.lower() + nfe = int(nfe) + lambd = 0.9 if denoising else 0.1 + + dwav, sr = torchaudio.load(path) + dwav = dwav.mean(dim=0) + + wav1, new_sr = denoise(dwav, sr, device) + wav2, new_sr = enhance(dwav, sr, device, nfe=nfe, solver=solver, lambd=lambd, tau=tau) + + wav1 = wav1.cpu().numpy() + wav2 = wav2.cpu().numpy() + + sf.write('output_wav1.wav', wav1, new_sr) + sf.write('output_wav2.wav', wav2, new_sr) + return (new_sr, wav1), (new_sr, wav2) + +(new_sr, wav1), (new_sr, wav2) = _fn('/root/src/hf/videly/voices/huberman_clone.wav',"Midpoint",32,0.5,True) + +end_time = time.time() +elapsed_time = end_time - start_time +print(f"Loop took {elapsed_time} seconds to complete.") + +''' +start_time = time.time() + + + + +to_mel = torchaudio.transforms.MelSpectrogram( + n_mels=80, n_fft=2048, win_length=1200, hop_length=300) +mean, std = -4, 4 + +def length_to_mask(lengths): + mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) + mask = torch.gt(mask+1, lengths.unsqueeze(1)) + return mask + +def preprocess(wave): + wave_tensor = torch.from_numpy(wave).float() + mel_tensor = to_mel(wave_tensor) + 
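+ # Convert the mel power spectrogram to log scale and normalize with the fixed
+ # statistics (mean=-4, std=4) defined above; compute_style() feeds the result to
+ # model.style_encoder and model.predictor_encoder.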
mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std + return mel_tensor + +def compute_style(path): + wave, sr = librosa.load(path, sr=24000) + audio, index = librosa.effects.trim(wave, top_db=30) + if sr != 24000: + audio = librosa.resample(audio, sr, 24000) + mel_tensor = preprocess(audio).to(device) + + with torch.no_grad(): + ref_s = model.style_encoder(mel_tensor.unsqueeze(1)) + ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1)) + + return torch.cat([ref_s, ref_p], dim=1) + +device = 'cuda' if torch.cuda.is_available() else 'cpu' + +print(device) +import phonemizer +global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True) + +config = yaml.safe_load(open("Configs/hg.yml")) +print(config) + +# load pretrained ASR model +ASR_config = config.get('ASR_config', False) +ASR_path = config.get('ASR_path', False) +text_aligner = load_ASR_models(ASR_path, ASR_config) + +# load pretrained F0 model +F0_path = config.get('F0_path', False) +pitch_extractor = load_F0_models(F0_path) + +# load BERT model +from Utils.PLBERT.util import load_plbert +BERT_path = config.get('PLBERT_dir', False) +plbert = load_plbert(BERT_path) + +model_params = recursive_munch(config['model_params']) +model = build_model(model_params, text_aligner, pitch_extractor, plbert) +_ = [model[key].eval() for key in model] +_ = [model[key].to(device) for key in model] + +params_whole = torch.load("Models/epochs_2nd_00020.pth", map_location='cpu') +params = params_whole['net'] + +for key in model: + if key in params: + print('%s loaded' % key) + try: + model[key].load_state_dict(params[key]) + except: + from collections import OrderedDict + state_dict = params[key] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[7:] # remove `module.` + new_state_dict[name] = v + # load params + model[key].load_state_dict(new_state_dict, strict=False) +# except: +# _load(params[key], model[key]) +_ = [model[key].eval() for key in model] + + + + +sampler = DiffusionSampler( + model.diffusion.diffusion, + sampler=ADPM2Sampler(), + sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters + clamp=False +) + + +def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1): + text = text.strip() + ps = global_phonemizer.phonemize([text]) + ps = word_tokenize(ps[0]) + ps = ' '.join(ps) + tokens = textclenaer(ps) + tokens.insert(0, 0) + tokens = torch.LongTensor(tokens).to(device).unsqueeze(0) + + with torch.no_grad(): + input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device) + text_mask = length_to_mask(input_lengths).to(device) + + t_en = model.text_encoder(tokens, input_lengths, text_mask) + bert_dur = model.bert(tokens, attention_mask=(~text_mask).int()) + d_en = model.bert_encoder(bert_dur).transpose(-1, -2) + + s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device), + embedding=bert_dur, + embedding_scale=embedding_scale, + features=ref_s, # reference from the same speaker as the embedding + num_steps=diffusion_steps).squeeze(1) + + + s = s_pred[:, 128:] + ref = s_pred[:, :128] + + ref = alpha * ref + (1 - alpha) * ref_s[:, :128] + s = beta * s + (1 - beta) * ref_s[:, 128:] + + d = model.predictor.text_encoder(d_en, + s, input_lengths, text_mask) + + x, _ = model.predictor.lstm(d) + duration = model.predictor.duration_proj(x) + + duration = torch.sigmoid(duration).sum(axis=-1) + pred_dur = torch.round(duration.squeeze()).clamp(min=1) + + + 
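+ # pred_dur holds an integer number of output frames per input token. The loop
+ # below expands it into a hard monotonic alignment matrix of shape
+ # [n_tokens, total_frames]; e.g. durations [2, 3, 1] give a 3x6 matrix whose rows
+ # are ones over frames 0-1, 2-4 and 5. That matrix is then used to upsample the
+ # prosody (d) and text (t_en) encodings to frame rate before decoding.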
pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data)) + c_frame = 0 + for i in range(pred_aln_trg.size(0)): + pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1 + c_frame += int(pred_dur[i].data) + + # encode prosody + en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)) + if model_params.decoder.type == "hifigan": + asr_new = torch.zeros_like(en) + asr_new[:, :, 0] = en[:, :, 0] + asr_new[:, :, 1:] = en[:, :, 0:-1] + en = asr_new + + F0_pred, N_pred = model.predictor.F0Ntrain(en, s) + + asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device)) + if model_params.decoder.type == "hifigan": + asr_new = torch.zeros_like(asr) + asr_new[:, :, 0] = asr[:, :, 0] + asr_new[:, :, 1:] = asr[:, :, 0:-1] + asr = asr_new + + out = model.decoder(asr, + F0_pred, N_pred, ref.squeeze().unsqueeze(0)) + + + return out.squeeze().cpu().numpy()[..., :-50] + + +def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1): + text = text.strip() + ps = global_phonemizer.phonemize([text]) + ps = word_tokenize(ps[0]) + ps = ' '.join(ps) + ps = ps.replace('``', '"') + ps = ps.replace("''", '"') + + tokens = textclenaer(ps) + tokens.insert(0, 0) + tokens = torch.LongTensor(tokens).to(device).unsqueeze(0) + + with torch.no_grad(): + input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device) + text_mask = length_to_mask(input_lengths).to(device) + + t_en = model.text_encoder(tokens, input_lengths, text_mask) + bert_dur = model.bert(tokens, attention_mask=(~text_mask).int()) + d_en = model.bert_encoder(bert_dur).transpose(-1, -2) + + s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device), + embedding=bert_dur, + embedding_scale=embedding_scale, + features=ref_s, # reference from the same speaker as the embedding + num_steps=diffusion_steps).squeeze(1) + + if s_prev is not None: + # convex combination of previous and current style + s_pred = t * s_prev + (1 - t) * s_pred + + s = s_pred[:, 128:] + ref = s_pred[:, :128] + + ref = alpha * ref + (1 - alpha) * ref_s[:, :128] + s = beta * s + (1 - beta) * ref_s[:, 128:] + + s_pred = torch.cat([ref, s], dim=-1) + + d = model.predictor.text_encoder(d_en, + s, input_lengths, text_mask) + + x, _ = model.predictor.lstm(d) + duration = model.predictor.duration_proj(x) + + duration = torch.sigmoid(duration).sum(axis=-1) + pred_dur = torch.round(duration.squeeze()).clamp(min=1) + + + pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data)) + c_frame = 0 + for i in range(pred_aln_trg.size(0)): + pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1 + c_frame += int(pred_dur[i].data) + + # encode prosody + en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)) + if model_params.decoder.type == "hifigan": + asr_new = torch.zeros_like(en) + asr_new[:, :, 0] = en[:, :, 0] + asr_new[:, :, 1:] = en[:, :, 0:-1] + en = asr_new + + F0_pred, N_pred = model.predictor.F0Ntrain(en, s) + + asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device)) + if model_params.decoder.type == "hifigan": + asr_new = torch.zeros_like(asr) + asr_new[:, :, 0] = asr[:, :, 0] + asr_new[:, :, 1:] = asr[:, :, 0:-1] + asr = asr_new + + out = model.decoder(asr, + F0_pred, N_pred, ref.squeeze().unsqueeze(0)) + + + return out.squeeze().cpu().numpy()[..., :-100], s_pred # + +passage = ''' +Psychology, a field as fascinating as it is complex, delves into the intricate workings of the human mind. At its core, it seeks to understand and explain how we think, feel, and behave. 
The journey into the realms of psychology embarks from the fundamental belief that human behavior is not random, but instead driven by internal and external factors, often intertwined in an intricate dance of cause and effect. + +The study of psychology branches out into various specializations, each focusing on different aspects of human behavior and mental processes. Clinical psychologists, for instance, explore the depths of mental health, working to diagnose, treat, and prevent mental disorders. Their work is pivotal in helping individuals navigate the often challenging waters of mental illness, offering therapies and interventions that can significantly improve the quality of life. + +On the other hand, developmental psychology provides insight into the growth and change that occur throughout a person's life. From the first words of a toddler to the wisdom of the elderly, developmental psychologists study how we evolve over time, shaping our understanding of the various stages of life. This specialization is crucial in understanding how early experiences influence behavior and personality in later years. + +Social psychology, another intriguing branch, examines how individuals are influenced by others. It uncovers the subtle yet powerful ways in which societal norms, group dynamics, and interpersonal relationships shape our actions and beliefs. Understanding these social factors is essential in addressing broader societal issues, from discrimination and prejudice to conflict and cooperation. +''' + +path = "output_wav2.wav" +s_ref = compute_style(path) +#sentences = passage.split('.') # simple split by comma +#sentences = passage +sentences = split_and_recombine_text(passage) +wavs = [] +s_prev = None +for text in sentences: + if text.strip() == "": continue + text += '.' 
# add it back + wav, s_prev = LFinference(text, + s_prev, + s_ref, + alpha = 0, + beta = 0.3, # make it more suitable for the text + t = 0.7, + diffusion_steps=10, embedding_scale=1) + wavs.append(wav) + +audio_arrays = [] +for wav_file in wavs: + audio_arrays.append(wav_file) +concatenated_audio = np.concatenate(audio_arrays) +print('Synthesized: ') + +sf.write('huberman_clone_after_resemble.wav', concatenated_audio, 24000) + +end_time = time.time() +elapsed_time = end_time - start_time +print(f"Loop took {elapsed_time} seconds to complete.") + +def _fn(path, solver, nfe, tau, denoising): + if path is None: + return None, None + + solver = solver.lower() + nfe = int(nfe) + lambd = 0.9 if denoising else 0.1 + + dwav, sr = torchaudio.load(path) + dwav = dwav.mean(dim=0) + + wav1, new_sr = denoise(dwav, sr, device) + wav2, new_sr = enhance(dwav, sr, device, nfe=nfe, solver=solver, lambd=lambd, tau=tau) + + wav1 = wav1.cpu().numpy() + wav2 = wav2.cpu().numpy() + + sf.write('enhanced.wav', wav2, new_sr) + return (new_sr, wav1), (new_sr, wav2) + +(new_sr, wav1), (new_sr, wav2) = _fn('/root/src/hf/videly/huberman_clone_after_resemble.wav',"Midpoint",32,0.5,True) \ No newline at end of file diff --git a/src/ljspeechimportable.py b/src/ljspeechimportable.py new file mode 100644 index 0000000000000000000000000000000000000000..171174bba6bf4286350f7316f753eff180b00cc6 --- /dev/null +++ b/src/ljspeechimportable.py @@ -0,0 +1,225 @@ +from cached_path import cached_path + + +import torch +torch.manual_seed(0) +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True + +import random +random.seed(0) + +import numpy as np +np.random.seed(0) + +import nltk +nltk.download('punkt') + +# load packages +import time +import random +import yaml +from munch import Munch +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +import torchaudio +import librosa +from nltk.tokenize import word_tokenize + +from models import * +from utils import * +from text_utils import TextCleaner +textclenaer = TextCleaner() + + +device = 'cuda' if torch.cuda.is_available() else 'cpu' + +to_mel = torchaudio.transforms.MelSpectrogram( + n_mels=80, n_fft=2048, win_length=1200, hop_length=300) +mean, std = -4, 4 + +def length_to_mask(lengths): + mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) + mask = torch.gt(mask+1, lengths.unsqueeze(1)) + return mask + +def preprocess(wave): + wave_tensor = torch.from_numpy(wave).float() + mel_tensor = to_mel(wave_tensor) + mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std + return mel_tensor + +def compute_style(ref_dicts): + reference_embeddings = {} + for key, path in ref_dicts.items(): + wave, sr = librosa.load(path, sr=24000) + audio, index = librosa.effects.trim(wave, top_db=30) + if sr != 24000: + audio = librosa.resample(audio, sr, 24000) + mel_tensor = preprocess(audio).to(device) + + with torch.no_grad(): + ref = model.style_encoder(mel_tensor.unsqueeze(1)) + reference_embeddings[key] = (ref.squeeze(1), audio) + + return reference_embeddings + +# load phonemizer +import phonemizer +global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True, words_mismatch='ignore') + +# phonemizer = Phonemizer.from_checkpoint(str(cached_path('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt'))) + + +config = 
yaml.safe_load(open(str(cached_path('hf://yl4579/StyleTTS2-LJSpeech/Models/LJSpeech/config.yml')))) + +# load pretrained ASR model +ASR_config = config.get('ASR_config', False) +ASR_path = config.get('ASR_path', False) +text_aligner = load_ASR_models(ASR_path, ASR_config) + +# load pretrained F0 model +F0_path = config.get('F0_path', False) +pitch_extractor = load_F0_models(F0_path) + +# load BERT model +from Utils.PLBERT.util import load_plbert +BERT_path = config.get('PLBERT_dir', False) +plbert = load_plbert(BERT_path) + +model = build_model(recursive_munch(config['model_params']), text_aligner, pitch_extractor, plbert) +_ = [model[key].eval() for key in model] +_ = [model[key].to(device) for key in model] + +# params_whole = torch.load("Models/LJSpeech/epoch_2nd_00100.pth", map_location='cpu') +params_whole = torch.load(str(cached_path('hf://yl4579/StyleTTS2-LJSpeech/Models/LJSpeech/epoch_2nd_00100.pth')), map_location='cpu') +params = params_whole['net'] + +for key in model: + if key in params: + print('%s loaded' % key) + try: + model[key].load_state_dict(params[key]) + except: + from collections import OrderedDict + state_dict = params[key] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[7:] # remove `module.` + new_state_dict[name] = v + # load params + model[key].load_state_dict(new_state_dict, strict=False) +# except: +# _load(params[key], model[key]) +_ = [model[key].eval() for key in model] + +from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule + +sampler = DiffusionSampler( + model.diffusion.diffusion, + sampler=ADPM2Sampler(), + sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters + clamp=False +) + +def inference(text, noise, diffusion_steps=5, embedding_scale=1): + text = text.strip() + text = text.replace('"', '') + ps = global_phonemizer.phonemize([text]) + ps = word_tokenize(ps[0]) + ps = ' '.join(ps) + + tokens = textclenaer(ps) + tokens.insert(0, 0) + tokens = torch.LongTensor(tokens).to(device).unsqueeze(0) + + with torch.no_grad(): + input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device) + text_mask = length_to_mask(input_lengths).to(tokens.device) + + t_en = model.text_encoder(tokens, input_lengths, text_mask) + bert_dur = model.bert(tokens, attention_mask=(~text_mask).int()) + d_en = model.bert_encoder(bert_dur).transpose(-1, -2) + + s_pred = sampler(noise, + embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps, + embedding_scale=embedding_scale).squeeze(0) + + s = s_pred[:, 128:] + ref = s_pred[:, :128] + + d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask) + + x, _ = model.predictor.lstm(d) + duration = model.predictor.duration_proj(x) + duration = torch.sigmoid(duration).sum(axis=-1) + pred_dur = torch.round(duration.squeeze()).clamp(min=1) + + pred_dur[-1] += 5 + + pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data)) + c_frame = 0 + for i in range(pred_aln_trg.size(0)): + pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1 + c_frame += int(pred_dur[i].data) + + # encode prosody + en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)) + F0_pred, N_pred = model.predictor.F0Ntrain(en, s) + out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)), + F0_pred, N_pred, ref.squeeze().unsqueeze(0)) + + return out.squeeze().cpu().numpy() + +def LFinference(text, s_prev, noise, alpha=0.7, diffusion_steps=5, embedding_scale=1): + text = text.strip() + text = text.replace('"', '') + ps 
= global_phonemizer.phonemize([text]) + ps = word_tokenize(ps[0]) + ps = ' '.join(ps) + + tokens = textclenaer(ps) + tokens.insert(0, 0) + tokens = torch.LongTensor(tokens).to(device).unsqueeze(0) + + with torch.no_grad(): + input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device) + text_mask = length_to_mask(input_lengths).to(tokens.device) + + t_en = model.text_encoder(tokens, input_lengths, text_mask) + bert_dur = model.bert(tokens, attention_mask=(~text_mask).int()) + d_en = model.bert_encoder(bert_dur).transpose(-1, -2) + + s_pred = sampler(noise, + embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps, + embedding_scale=embedding_scale).squeeze(0) + + if s_prev is not None: + # convex combination of previous and current style + s_pred = alpha * s_prev + (1 - alpha) * s_pred + + s = s_pred[:, 128:] + ref = s_pred[:, :128] + + d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask) + + x, _ = model.predictor.lstm(d) + duration = model.predictor.duration_proj(x) + duration = torch.sigmoid(duration).sum(axis=-1) + pred_dur = torch.round(duration.squeeze()).clamp(min=1) + + pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data)) + c_frame = 0 + for i in range(pred_aln_trg.size(0)): + pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1 + c_frame += int(pred_dur[i].data) + + # encode prosody + en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)) + F0_pred, N_pred = model.predictor.F0Ntrain(en, s) + out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)), + F0_pred, N_pred, ref.squeeze().unsqueeze(0)) + + return out.squeeze().cpu().numpy(), s_pred \ No newline at end of file diff --git a/src/losses.py b/src/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..135807afaf698162e3d5bb577072246f5835ae80 --- /dev/null +++ b/src/losses.py @@ -0,0 +1,303 @@ +import torch +from torch import nn +import torch.nn.functional as F +import torchaudio +from transformers import AutoModel + + +class SpectralConvergengeLoss(torch.nn.Module): + """Spectral convergence loss module.""" + + def __init__(self): + """Initilize spectral convergence loss module.""" + super(SpectralConvergengeLoss, self).__init__() + + def forward(self, x_mag, y_mag): + """Calculate forward propagation. + Args: + x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). + y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). + Returns: + Tensor: Spectral convergence loss value. + """ + return torch.norm(y_mag - x_mag, p=1) / torch.norm(y_mag, p=1) + + +class STFTLoss(torch.nn.Module): + """STFT loss module.""" + + def __init__( + self, fft_size=1024, shift_size=120, win_length=600, window=torch.hann_window + ): + """Initialize STFT loss module.""" + super(STFTLoss, self).__init__() + self.fft_size = fft_size + self.shift_size = shift_size + self.win_length = win_length + self.to_mel = torchaudio.transforms.MelSpectrogram( + sample_rate=24000, + n_fft=fft_size, + win_length=win_length, + hop_length=shift_size, + window_fn=window, + ) + + self.spectral_convergenge_loss = SpectralConvergengeLoss() + + def forward(self, x, y): + """Calculate forward propagation. + Args: + x (Tensor): Predicted signal (B, T). + y (Tensor): Groundtruth signal (B, T). + Returns: + Tensor: Spectral convergence loss value. + Tensor: Log STFT magnitude loss value. 
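+ Note: both spectrograms are normalized log-mel spectrograms here, and only
+ the spectral convergence term is computed and returned.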
+ """ + x_mag = self.to_mel(x) + mean, std = -4, 4 + x_mag = (torch.log(1e-5 + x_mag) - mean) / std + + y_mag = self.to_mel(y) + mean, std = -4, 4 + y_mag = (torch.log(1e-5 + y_mag) - mean) / std + + sc_loss = self.spectral_convergenge_loss(x_mag, y_mag) + return sc_loss + + +class MultiResolutionSTFTLoss(torch.nn.Module): + """Multi resolution STFT loss module.""" + + def __init__( + self, + fft_sizes=[1024, 2048, 512], + hop_sizes=[120, 240, 50], + win_lengths=[600, 1200, 240], + window=torch.hann_window, + ): + """Initialize Multi resolution STFT loss module. + Args: + fft_sizes (list): List of FFT sizes. + hop_sizes (list): List of hop sizes. + win_lengths (list): List of window lengths. + window (str): Window function type. + """ + super(MultiResolutionSTFTLoss, self).__init__() + assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) + self.stft_losses = torch.nn.ModuleList() + for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): + self.stft_losses += [STFTLoss(fs, ss, wl, window)] + + def forward(self, x, y): + """Calculate forward propagation. + Args: + x (Tensor): Predicted signal (B, T). + y (Tensor): Groundtruth signal (B, T). + Returns: + Tensor: Multi resolution spectral convergence loss value. + Tensor: Multi resolution log STFT magnitude loss value. + """ + sc_loss = 0.0 + for f in self.stft_losses: + sc_l = f(x, y) + sc_loss += sc_l + sc_loss /= len(self.stft_losses) + + return sc_loss + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg**2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + +""" https://dl.acm.org/doi/abs/10.1145/3573834.3574506 """ + + +def discriminator_TPRLS_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + tau = 0.04 + m_DG = torch.median((dr - dg)) + L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG]) + loss += tau - F.relu(tau - L_rel) + return loss + + +def generator_TPRLS_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + for dg, dr in zip(disc_real_outputs, disc_generated_outputs): + tau = 0.04 + m_DG = torch.median((dr - dg)) + L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG]) + loss += tau - F.relu(tau - L_rel) + return loss + + +class GeneratorLoss(torch.nn.Module): + def __init__(self, mpd, msd): + super(GeneratorLoss, self).__init__() + self.mpd = mpd + self.msd = msd + + def forward(self, y, y_hat): + y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, y_hat) + y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, y_hat) + loss_fm_f = feature_loss(fmap_f_r, fmap_f_g) + loss_fm_s = feature_loss(fmap_s_r, fmap_s_g) + loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g) + loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g) + + loss_rel = generator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + generator_TPRLS_loss( + y_ds_hat_r, y_ds_hat_g + ) + + loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + 
loss_fm_f + loss_rel + + return loss_gen_all.mean() + + +class DiscriminatorLoss(torch.nn.Module): + def __init__(self, mpd, msd): + super(DiscriminatorLoss, self).__init__() + self.mpd = mpd + self.msd = msd + + def forward(self, y, y_hat): + # MPD + y_df_hat_r, y_df_hat_g, _, _ = self.mpd(y, y_hat) + loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss( + y_df_hat_r, y_df_hat_g + ) + # MSD + y_ds_hat_r, y_ds_hat_g, _, _ = self.msd(y, y_hat) + loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss( + y_ds_hat_r, y_ds_hat_g + ) + + loss_rel = discriminator_TPRLS_loss( + y_df_hat_r, y_df_hat_g + ) + discriminator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g) + + d_loss = loss_disc_s + loss_disc_f + loss_rel + + return d_loss.mean() + + +class WavLMLoss(torch.nn.Module): + def __init__(self, model, wd, model_sr, slm_sr=16000): + super(WavLMLoss, self).__init__() + self.wavlm = AutoModel.from_pretrained(model) + self.wd = wd + self.resample = torchaudio.transforms.Resample(model_sr, slm_sr) + + def forward(self, wav, y_rec): + with torch.no_grad(): + wav_16 = self.resample(wav) + wav_embeddings = self.wavlm( + input_values=wav_16, output_hidden_states=True + ).hidden_states + y_rec_16 = self.resample(y_rec) + y_rec_embeddings = self.wavlm( + input_values=y_rec_16.squeeze(), output_hidden_states=True + ).hidden_states + + floss = 0 + for er, eg in zip(wav_embeddings, y_rec_embeddings): + floss += torch.mean(torch.abs(er - eg)) + + return floss.mean() + + def generator(self, y_rec): + y_rec_16 = self.resample(y_rec) + y_rec_embeddings = self.wavlm( + input_values=y_rec_16, output_hidden_states=True + ).hidden_states + y_rec_embeddings = ( + torch.stack(y_rec_embeddings, dim=1) + .transpose(-1, -2) + .flatten(start_dim=1, end_dim=2) + ) + y_df_hat_g = self.wd(y_rec_embeddings) + loss_gen = torch.mean((1 - y_df_hat_g) ** 2) + + return loss_gen + + def discriminator(self, wav, y_rec): + with torch.no_grad(): + wav_16 = self.resample(wav) + wav_embeddings = self.wavlm( + input_values=wav_16, output_hidden_states=True + ).hidden_states + y_rec_16 = self.resample(y_rec) + y_rec_embeddings = self.wavlm( + input_values=y_rec_16, output_hidden_states=True + ).hidden_states + + y_embeddings = ( + torch.stack(wav_embeddings, dim=1) + .transpose(-1, -2) + .flatten(start_dim=1, end_dim=2) + ) + y_rec_embeddings = ( + torch.stack(y_rec_embeddings, dim=1) + .transpose(-1, -2) + .flatten(start_dim=1, end_dim=2) + ) + + y_d_rs = self.wd(y_embeddings) + y_d_gs = self.wd(y_rec_embeddings) + + y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs + + r_loss = torch.mean((1 - y_df_hat_r) ** 2) + g_loss = torch.mean((y_df_hat_g) ** 2) + + loss_disc_f = r_loss + g_loss + + return loss_disc_f.mean() + + def discriminator_forward(self, wav): + with torch.no_grad(): + wav_16 = self.resample(wav) + wav_embeddings = self.wavlm( + input_values=wav_16, output_hidden_states=True + ).hidden_states + y_embeddings = ( + torch.stack(wav_embeddings, dim=1) + .transpose(-1, -2) + .flatten(start_dim=1, end_dim=2) + ) + + y_d_rs = self.wd(y_embeddings) + + return y_d_rs diff --git a/src/mel_processing.py b/src/mel_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..822d7f19062497b198ae54554a3ab828c10147ad --- /dev/null +++ b/src/mel_processing.py @@ -0,0 +1,183 @@ +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + +MAX_WAV_VALUE = 32768.0 + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + 
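+ # Clamp to clip_val so the log never sees zero, then apply natural-log
+ # compression; dynamic_range_decompression_torch below inverts this with exp.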
return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): + if torch.min(y) < -1.1: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.1: + print("max value is ", torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_size) + "_" + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( + dtype=y.dtype, device=y.device + ) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +def spectrogram_torch_conv(y, n_fft, sampling_rate, hop_size, win_size, center=False): + # if torch.min(y) < -1.: + # print('min value is ', torch.min(y)) + # if torch.max(y) > 1.: + # print('max value is ', torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + '_' + str(y.device) + wnsize_dtype_device = str(win_size) + '_' + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') + + # ******************** original ************************# + # y = y.squeeze(1) + # spec1 = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], + # center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) + + # ******************** ConvSTFT ************************# + freq_cutoff = n_fft // 2 + 1 + fourier_basis = torch.view_as_real(torch.fft.fft(torch.eye(n_fft))) + forward_basis = fourier_basis[:freq_cutoff].permute(2, 0, 1).reshape(-1, 1, fourier_basis.shape[1]) + forward_basis = forward_basis * torch.as_tensor(librosa.util.pad_center(torch.hann_window(win_size), size=n_fft)).float() + + import torch.nn.functional as F + + # if center: + # signal = F.pad(y[:, None, None, :], (n_fft // 2, n_fft // 2, 0, 0), mode = 'reflect').squeeze(1) + assert center is False + + forward_transform_squared = F.conv1d(y, forward_basis.to(y.device), stride = hop_size) + spec2 = torch.stack([forward_transform_squared[:, :freq_cutoff, :], forward_transform_squared[:, freq_cutoff:, :]], dim = -1) + + + # ******************** Verification ************************# + spec1 = torch.stft(y.squeeze(1), n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], + center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) + assert torch.allclose(spec1, spec2, atol=1e-4) + + spec = torch.sqrt(spec2.pow(2).sum(-1) + 1e-6) + return spec + + +def 
spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): + global mel_basis + dtype_device = str(spec.dtype) + "_" + str(spec.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( + dtype=spec.dtype, device=spec.device + ) + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = spectral_normalize_torch(spec) + return spec + + +def mel_spectrogram_torch( + y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False +): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + wnsize_dtype_device = str(win_size) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( + dtype=y.dtype, device=y.device + ) + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( + dtype=y.dtype, device=y.device + ) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = spectral_normalize_torch(spec) + + return spec \ No newline at end of file diff --git a/src/meldataset.py b/src/meldataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e5e0c09728942f1ae77a41c9476cc7a0d52e2a25 --- /dev/null +++ b/src/meldataset.py @@ -0,0 +1,294 @@ +# coding: utf-8 +import os +import os.path as osp +import time +import random +import numpy as np +import random +import soundfile as sf +import librosa + +import torch +from torch import nn +import torch.nn.functional as F +import torchaudio +from torch.utils.data import DataLoader + +import logging + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +import pandas as pd + +_pad = "$" +_punctuation = ';:,.!?¡¿—…"«»“” ' +_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" + +# Export all symbols: +symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + +dicts = {} +for i in range(len((symbols))): + dicts[symbols[i]] = i + + +class TextCleaner: + def __init__(self, dummy=None): + self.word_index_dictionary = dicts + + def __call__(self, text): + indexes = [] + for char in text: + try: + indexes.append(self.word_index_dictionary[char]) + except KeyError: + print(text) + return indexes + + +np.random.seed(1) +random.seed(1) +SPECT_PARAMS = {"n_fft": 2048, "win_length": 1200, "hop_length": 300} +MEL_PARAMS = { + "n_mels": 80, +} + +to_mel = torchaudio.transforms.MelSpectrogram( + n_mels=80, n_fft=2048, win_length=1200, hop_length=300 +) +mean, std = -4, 4 + + +def preprocess(wave): + wave_tensor = torch.from_numpy(wave).float() + 
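+ # Same fixed log-mel front end as in inference.py and compute.py (80 mels,
+ # n_fft=2048, win_length=1200, hop_length=300, mean=-4, std=4); it should stay
+ # consistent with SPECT_PARAMS / MEL_PARAMS defined above.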
mel_tensor = to_mel(wave_tensor) + mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std + return mel_tensor + + +class FilePathDataset(torch.utils.data.Dataset): + def __init__( + self, + data_list, + root_path, + sr=24000, + data_augmentation=False, + validation=False, + OOD_data="Data/OOD_texts.txt", + min_length=50, + ): + spect_params = SPECT_PARAMS + mel_params = MEL_PARAMS + + _data_list = [l[:-1].split("|") for l in data_list] + self.data_list = [data if len(data) == 3 else (*data, 0) for data in _data_list] + self.text_cleaner = TextCleaner() + self.sr = sr + + self.df = pd.DataFrame(self.data_list) + + self.to_melspec = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS) + + self.mean, self.std = -4, 4 + self.data_augmentation = data_augmentation and (not validation) + self.max_mel_length = 192 + + self.min_length = min_length + with open(OOD_data, "r") as f: + tl = f.readlines() + idx = 1 if ".wav" in tl[0].split("|")[0] else 0 + self.ptexts = [t.split("|")[idx] for t in tl] + + self.root_path = root_path + + def __len__(self): + return len(self.data_list) + + def __getitem__(self, idx): + data = self.data_list[idx] + path = data[0] + + wave, text_tensor, speaker_id = self._load_tensor(data) + + mel_tensor = preprocess(wave).squeeze() + + acoustic_feature = mel_tensor.squeeze() + length_feature = acoustic_feature.size(1) + acoustic_feature = acoustic_feature[:, : (length_feature - length_feature % 2)] + + # get reference sample + ref_data = (self.df[self.df[2] == str(speaker_id)]).sample(n=1).iloc[0].tolist() + ref_mel_tensor, ref_label = self._load_data(ref_data[:3]) + + # get OOD text + + ps = "" + + while len(ps) < self.min_length: + rand_idx = np.random.randint(0, len(self.ptexts) - 1) + ps = self.ptexts[rand_idx] + + text = self.text_cleaner(ps) + text.insert(0, 0) + text.append(0) + + ref_text = torch.LongTensor(text) + + return ( + speaker_id, + acoustic_feature, + text_tensor, + ref_text, + ref_mel_tensor, + ref_label, + path, + wave, + ) + + def _load_tensor(self, data): + wave_path, text, speaker_id = data + speaker_id = int(speaker_id) + wave, sr = sf.read(osp.join(self.root_path, wave_path)) + if wave.shape[-1] == 2: + wave = wave[:, 0].squeeze() + if sr != 24000: + wave = librosa.resample(wave, orig_sr=sr, target_sr=24000) + print(wave_path, sr) + + wave = np.concatenate([np.zeros([5000]), wave, np.zeros([5000])], axis=0) + + text = self.text_cleaner(text) + + text.insert(0, 0) + text.append(0) + + text = torch.LongTensor(text) + + return wave, text, speaker_id + + def _load_data(self, data): + wave, text_tensor, speaker_id = self._load_tensor(data) + mel_tensor = preprocess(wave).squeeze() + + mel_length = mel_tensor.size(1) + if mel_length > self.max_mel_length: + random_start = np.random.randint(0, mel_length - self.max_mel_length) + mel_tensor = mel_tensor[ + :, random_start : random_start + self.max_mel_length + ] + + return mel_tensor, speaker_id + + +class Collater(object): + """ + Args: + adaptive_batch_size (bool): if true, decrease batch size when long data comes. 
+ """ + + def __init__(self, return_wave=False): + self.text_pad_index = 0 + self.min_mel_length = 192 + self.max_mel_length = 192 + self.return_wave = return_wave + + def __call__(self, batch): + # batch[0] = wave, mel, text, f0, speakerid + batch_size = len(batch) + + # sort by mel length + lengths = [b[1].shape[1] for b in batch] + batch_indexes = np.argsort(lengths)[::-1] + batch = [batch[bid] for bid in batch_indexes] + + nmels = batch[0][1].size(0) + max_mel_length = max([b[1].shape[1] for b in batch]) + max_text_length = max([b[2].shape[0] for b in batch]) + max_rtext_length = max([b[3].shape[0] for b in batch]) + + labels = torch.zeros((batch_size)).long() + mels = torch.zeros((batch_size, nmels, max_mel_length)).float() + texts = torch.zeros((batch_size, max_text_length)).long() + ref_texts = torch.zeros((batch_size, max_rtext_length)).long() + + input_lengths = torch.zeros(batch_size).long() + ref_lengths = torch.zeros(batch_size).long() + output_lengths = torch.zeros(batch_size).long() + ref_mels = torch.zeros((batch_size, nmels, self.max_mel_length)).float() + ref_labels = torch.zeros((batch_size)).long() + paths = ["" for _ in range(batch_size)] + waves = [None for _ in range(batch_size)] + + for bid, ( + label, + mel, + text, + ref_text, + ref_mel, + ref_label, + path, + wave, + ) in enumerate(batch): + mel_size = mel.size(1) + text_size = text.size(0) + rtext_size = ref_text.size(0) + labels[bid] = label + mels[bid, :, :mel_size] = mel + texts[bid, :text_size] = text + ref_texts[bid, :rtext_size] = ref_text + input_lengths[bid] = text_size + ref_lengths[bid] = rtext_size + output_lengths[bid] = mel_size + paths[bid] = path + ref_mel_size = ref_mel.size(1) + ref_mels[bid, :, :ref_mel_size] = ref_mel + + ref_labels[bid] = ref_label + waves[bid] = wave + + return ( + waves, + texts, + input_lengths, + ref_texts, + ref_lengths, + mels, + output_lengths, + ref_mels, + ) + + +def build_dataloader( + path_list, + root_path, + validation=False, + OOD_data="Data/OOD_texts.txt", + min_length=50, + batch_size=4, + num_workers=1, + device="cpu", + collate_config={}, + dataset_config={}, +): + dataset = FilePathDataset( + path_list, + root_path, + OOD_data=OOD_data, + min_length=min_length, + validation=validation, + **dataset_config + ) + collate_fn = Collater(**collate_config) + data_loader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=(not validation), + num_workers=num_workers, + drop_last=(not validation), + collate_fn=collate_fn, + pin_memory=(device != "cpu"), + ) + + return data_loader diff --git a/src/models.py b/src/models.py new file mode 100644 index 0000000000000000000000000000000000000000..e2d050d5a70783e46655f28e7d82ddba5c0dc974 --- /dev/null +++ b/src/models.py @@ -0,0 +1,1379 @@ +# coding:utf-8 + +import os +import os.path as osp + +import copy +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm + +from Utils.ASR.models import ASRCNN +from Utils.JDC.model import JDCNet + +from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution +from Modules.diffusion.modules import Transformer1d, StyleTransformer1d +from Modules.diffusion.diffusion import AudioDiffusionConditional + +from Modules.discriminators import ( + MultiPeriodDiscriminator, + MultiResSpecDiscriminator, + WavLMDiscriminator, +) + +from munch import Munch +import yaml + +import math +import torch +from torch import nn +from torch.nn import functional as F + 
+import commons +import modules +import attentions + +from torch.nn import Conv1d, ConvTranspose1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm + +from commons import init_weights, get_padding + + +class LearnedDownSample(nn.Module): + def __init__(self, layer_type, dim_in): + super().__init__() + self.layer_type = layer_type + + if self.layer_type == "none": + self.conv = nn.Identity() + elif self.layer_type == "timepreserve": + self.conv = spectral_norm( + nn.Conv2d( + dim_in, + dim_in, + kernel_size=(3, 1), + stride=(2, 1), + groups=dim_in, + padding=(1, 0), + ) + ) + elif self.layer_type == "half": + self.conv = spectral_norm( + nn.Conv2d( + dim_in, + dim_in, + kernel_size=(3, 3), + stride=(2, 2), + groups=dim_in, + padding=1, + ) + ) + else: + raise RuntimeError( + "Got unexpected donwsampletype %s, expected is [none, timepreserve, half]" + % self.layer_type + ) + + def forward(self, x): + return self.conv(x) + + +class LearnedUpSample(nn.Module): + def __init__(self, layer_type, dim_in): + super().__init__() + self.layer_type = layer_type + + if self.layer_type == "none": + self.conv = nn.Identity() + elif self.layer_type == "timepreserve": + self.conv = nn.ConvTranspose2d( + dim_in, + dim_in, + kernel_size=(3, 1), + stride=(2, 1), + groups=dim_in, + output_padding=(1, 0), + padding=(1, 0), + ) + elif self.layer_type == "half": + self.conv = nn.ConvTranspose2d( + dim_in, + dim_in, + kernel_size=(3, 3), + stride=(2, 2), + groups=dim_in, + output_padding=1, + padding=1, + ) + else: + raise RuntimeError( + "Got unexpected upsampletype %s, expected is [none, timepreserve, half]" + % self.layer_type + ) + + def forward(self, x): + return self.conv(x) + + +class DownSample(nn.Module): + def __init__(self, layer_type): + super().__init__() + self.layer_type = layer_type + + def forward(self, x): + if self.layer_type == "none": + return x + elif self.layer_type == "timepreserve": + return F.avg_pool2d(x, (2, 1)) + elif self.layer_type == "half": + if x.shape[-1] % 2 != 0: + x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1) + return F.avg_pool2d(x, 2) + else: + raise RuntimeError( + "Got unexpected donwsampletype %s, expected is [none, timepreserve, half]" + % self.layer_type + ) + + +class UpSample(nn.Module): + def __init__(self, layer_type): + super().__init__() + self.layer_type = layer_type + + def forward(self, x): + if self.layer_type == "none": + return x + elif self.layer_type == "timepreserve": + return F.interpolate(x, scale_factor=(2, 1), mode="nearest") + elif self.layer_type == "half": + return F.interpolate(x, scale_factor=2, mode="nearest") + else: + raise RuntimeError( + "Got unexpected upsampletype %s, expected is [none, timepreserve, half]" + % self.layer_type + ) + + +class ResBlk(nn.Module): + def __init__( + self, + dim_in, + dim_out, + actv=nn.LeakyReLU(0.2), + normalize=False, + downsample="none", + ): + super().__init__() + self.actv = actv + self.normalize = normalize + self.downsample = DownSample(downsample) + self.downsample_res = LearnedDownSample(downsample, dim_in) + self.learned_sc = dim_in != dim_out + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1)) + self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1)) + if self.normalize: + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_in, affine=True) + if self.learned_sc: + self.conv1x1 = spectral_norm( + nn.Conv2d(dim_in, 
dim_out, 1, 1, 0, bias=False) + ) + + def _shortcut(self, x): + if self.learned_sc: + x = self.conv1x1(x) + if self.downsample: + x = self.downsample(x) + return x + + def _residual(self, x): + if self.normalize: + x = self.norm1(x) + x = self.actv(x) + x = self.conv1(x) + x = self.downsample_res(x) + if self.normalize: + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + x = self._shortcut(x) + self._residual(x) + return x / math.sqrt(2) # unit variance + + +class StyleEncoder(nn.Module): + def __init__(self, dim_in=48, style_dim=48, max_conv_dim=384): + super().__init__() + blocks = [] + blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))] + + repeat_num = 4 + for _ in range(repeat_num): + dim_out = min(dim_in * 2, max_conv_dim) + blocks += [ResBlk(dim_in, dim_out, downsample="half")] + dim_in = dim_out + + blocks += [nn.LeakyReLU(0.2)] + blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))] + blocks += [nn.AdaptiveAvgPool2d(1)] + blocks += [nn.LeakyReLU(0.2)] + self.shared = nn.Sequential(*blocks) + + self.unshared = nn.Linear(dim_out, style_dim) + + def forward(self, x): + h = self.shared(x) + h = h.view(h.size(0), -1) + s = self.unshared(h) + + return s + + +class LinearNorm(torch.nn.Module): + def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"): + super(LinearNorm, self).__init__() + self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) + + torch.nn.init.xavier_uniform_( + self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain) + ) + + def forward(self, x): + return self.linear_layer(x) + + +class Discriminator2d(nn.Module): + def __init__(self, dim_in=48, num_domains=1, max_conv_dim=384, repeat_num=4): + super().__init__() + blocks = [] + blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))] + + for lid in range(repeat_num): + dim_out = min(dim_in * 2, max_conv_dim) + blocks += [ResBlk(dim_in, dim_out, downsample="half")] + dim_in = dim_out + + blocks += [nn.LeakyReLU(0.2)] + blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))] + blocks += [nn.LeakyReLU(0.2)] + blocks += [nn.AdaptiveAvgPool2d(1)] + blocks += [spectral_norm(nn.Conv2d(dim_out, num_domains, 1, 1, 0))] + self.main = nn.Sequential(*blocks) + + def get_feature(self, x): + features = [] + for l in self.main: + x = l(x) + features.append(x) + out = features[-1] + out = out.view(out.size(0), -1) # (batch, num_domains) + return out, features + + def forward(self, x): + out, features = self.get_feature(x) + out = out.squeeze() # (batch) + return out, features + + +class ResBlk1d(nn.Module): + def __init__( + self, + dim_in, + dim_out, + actv=nn.LeakyReLU(0.2), + normalize=False, + downsample="none", + dropout_p=0.2, + ): + super().__init__() + self.actv = actv + self.normalize = normalize + self.downsample_type = downsample + self.learned_sc = dim_in != dim_out + self._build_weights(dim_in, dim_out) + self.dropout_p = dropout_p + + if self.downsample_type == "none": + self.pool = nn.Identity() + else: + self.pool = weight_norm( + nn.Conv1d( + dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1 + ) + ) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_in, 3, 1, 1)) + self.conv2 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1)) + if self.normalize: + self.norm1 = nn.InstanceNorm1d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm1d(dim_in, affine=True) + if self.learned_sc: + self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False)) + 
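+    # Shortcut path: optional 1x1 conv (when dim_in != dim_out) followed by the
+    # parameter-free average pooling below; the residual path instead downsamples
+    # with the strided depthwise conv stored in self.pool.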
+ def downsample(self, x): + if self.downsample_type == "none": + return x + else: + if x.shape[-1] % 2 != 0: + x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1) + return F.avg_pool1d(x, 2) + + def _shortcut(self, x): + if self.learned_sc: + x = self.conv1x1(x) + x = self.downsample(x) + return x + + def _residual(self, x): + if self.normalize: + x = self.norm1(x) + x = self.actv(x) + x = F.dropout(x, p=self.dropout_p, training=self.training) + + x = self.conv1(x) + x = self.pool(x) + if self.normalize: + x = self.norm2(x) + + x = self.actv(x) + x = F.dropout(x, p=self.dropout_p, training=self.training) + + x = self.conv2(x) + return x + + def forward(self, x): + x = self._shortcut(x) + self._residual(x) + return x / math.sqrt(2) # unit variance + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class TextEncoder(nn.Module): + def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)): + super().__init__() + self.embedding = nn.Embedding(n_symbols, channels) + + padding = (kernel_size - 1) // 2 + self.cnn = nn.ModuleList() + for _ in range(depth): + self.cnn.append( + nn.Sequential( + weight_norm( + nn.Conv1d( + channels, channels, kernel_size=kernel_size, padding=padding + ) + ), + LayerNorm(channels), + actv, + nn.Dropout(0.2), + ) + ) + # self.cnn = nn.Sequential(*self.cnn) + + self.lstm = nn.LSTM( + channels, channels // 2, 1, batch_first=True, bidirectional=True + ) + + def forward(self, x, input_lengths, m): + x = self.embedding(x) # [B, T, emb] + x = x.transpose(1, 2) # [B, emb, T] + m = m.to(input_lengths.device).unsqueeze(1) + x.masked_fill_(m, 0.0) + + for c in self.cnn: + x = c(x) + x.masked_fill_(m, 0.0) + + x = x.transpose(1, 2) # [B, T, chn] + + input_lengths = input_lengths.cpu().numpy() + x = nn.utils.rnn.pack_padded_sequence( + x, input_lengths, batch_first=True, enforce_sorted=False + ) + + self.lstm.flatten_parameters() + x, _ = self.lstm(x) + x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True) + + x = x.transpose(-1, -2) + x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]]) + + x_pad[:, :, : x.shape[-1]] = x + x = x_pad.to(x.device) + + x.masked_fill_(m, 0.0) + + return x + + def inference(self, x): + x = self.embedding(x) + x = x.transpose(1, 2) + x = self.cnn(x) + x = x.transpose(1, 2) + self.lstm.flatten_parameters() + x, _ = self.lstm(x) + return x + + def length_to_mask(self, lengths): + mask = ( + torch.arange(lengths.max()) + .unsqueeze(0) + .expand(lengths.shape[0], -1) + .type_as(lengths) + ) + mask = torch.gt(mask + 1, lengths.unsqueeze(1)) + return mask + + +class AdaIN1d(nn.Module): + def __init__(self, style_dim, num_features): + super().__init__() + self.norm = nn.InstanceNorm1d(num_features, affine=False) + self.fc = nn.Linear(style_dim, num_features * 2) + + def forward(self, x, s): + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + return (1 + gamma) * self.norm(x) + beta + + +class UpSample1d(nn.Module): + def __init__(self, layer_type): + super().__init__() + self.layer_type = layer_type + + def forward(self, x): + if self.layer_type == "none": + return x + else: + return F.interpolate(x, scale_factor=2, 
mode="nearest") + + +class AdainResBlk1d(nn.Module): + def __init__( + self, + dim_in, + dim_out, + style_dim=64, + actv=nn.LeakyReLU(0.2), + upsample="none", + dropout_p=0.0, + ): + super().__init__() + self.actv = actv + self.upsample_type = upsample + self.upsample = UpSample1d(upsample) + self.learned_sc = dim_in != dim_out + self._build_weights(dim_in, dim_out, style_dim) + self.dropout = nn.Dropout(dropout_p) + + if upsample == "none": + self.pool = nn.Identity() + else: + self.pool = weight_norm( + nn.ConvTranspose1d( + dim_in, + dim_in, + kernel_size=3, + stride=2, + groups=dim_in, + padding=1, + output_padding=1, + ) + ) + + def _build_weights(self, dim_in, dim_out, style_dim): + self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1)) + self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1)) + self.norm1 = AdaIN1d(style_dim, dim_in) + self.norm2 = AdaIN1d(style_dim, dim_out) + if self.learned_sc: + self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False)) + + def _shortcut(self, x): + x = self.upsample(x) + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.norm1(x, s) + x = self.actv(x) + x = self.pool(x) + x = self.conv1(self.dropout(x)) + x = self.norm2(x, s) + x = self.actv(x) + x = self.conv2(self.dropout(x)) + return x + + def forward(self, x, s): + out = self._residual(x, s) + out = (out + self._shortcut(x)) / math.sqrt(2) + return out + + +class AdaLayerNorm(nn.Module): + def __init__(self, style_dim, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.fc = nn.Linear(style_dim, channels * 2) + + def forward(self, x, s): + x = x.transpose(-1, -2) + x = x.transpose(1, -1) + + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1) + + x = F.layer_norm(x, (self.channels,), eps=self.eps) + x = (1 + gamma) * x + beta + return x.transpose(1, -1).transpose(-1, -2) + + +class ProsodyPredictor(nn.Module): + def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1): + super().__init__() + + self.text_encoder = DurationEncoder( + sty_dim=style_dim, d_model=d_hid, nlayers=nlayers, dropout=dropout + ) + + self.lstm = nn.LSTM( + d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True + ) + self.duration_proj = LinearNorm(d_hid, max_dur) + + self.shared = nn.LSTM( + d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True + ) + self.F0 = nn.ModuleList() + self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout)) + self.F0.append( + AdainResBlk1d( + d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout + ) + ) + self.F0.append( + AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout) + ) + + self.N = nn.ModuleList() + self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout)) + self.N.append( + AdainResBlk1d( + d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout + ) + ) + self.N.append( + AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout) + ) + + self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0) + self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0) + + def forward(self, texts, style, text_lengths, alignment, m): + d = self.text_encoder(texts, style, text_lengths, m) + + batch_size = d.shape[0] + text_size = d.shape[1] + + # predict duration + input_lengths = text_lengths.cpu().numpy() + x = nn.utils.rnn.pack_padded_sequence( + d, input_lengths, 
batch_first=True, enforce_sorted=False + ) + + m = m.to(text_lengths.device).unsqueeze(1) + + self.lstm.flatten_parameters() + x, _ = self.lstm(x) + x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True) + + x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]]) + + x_pad[:, : x.shape[1], :] = x + x = x_pad.to(x.device) + + duration = self.duration_proj( + nn.functional.dropout(x, 0.5, training=self.training) + ) + + en = d.transpose(-1, -2) @ alignment + + return duration.squeeze(-1), en + + def F0Ntrain(self, x, s): + x, _ = self.shared(x.transpose(-1, -2)) + + F0 = x.transpose(-1, -2) + for block in self.F0: + F0 = block(F0, s) + F0 = self.F0_proj(F0) + + N = x.transpose(-1, -2) + for block in self.N: + N = block(N, s) + N = self.N_proj(N) + + return F0.squeeze(1), N.squeeze(1) + + def length_to_mask(self, lengths): + mask = ( + torch.arange(lengths.max()) + .unsqueeze(0) + .expand(lengths.shape[0], -1) + .type_as(lengths) + ) + mask = torch.gt(mask + 1, lengths.unsqueeze(1)) + return mask + + +class DurationEncoder(nn.Module): + def __init__(self, sty_dim, d_model, nlayers, dropout=0.1): + super().__init__() + self.lstms = nn.ModuleList() + for _ in range(nlayers): + self.lstms.append( + nn.LSTM( + d_model + sty_dim, + d_model // 2, + num_layers=1, + batch_first=True, + bidirectional=True, + dropout=dropout, + ) + ) + self.lstms.append(AdaLayerNorm(sty_dim, d_model)) + + self.dropout = dropout + self.d_model = d_model + self.sty_dim = sty_dim + + def forward(self, x, style, text_lengths, m): + masks = m.to(text_lengths.device) + + x = x.permute(2, 0, 1) + s = style.expand(x.shape[0], x.shape[1], -1) + x = torch.cat([x, s], axis=-1) + x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0) + + x = x.transpose(0, 1) + input_lengths = text_lengths.cpu().numpy() + x = x.transpose(-1, -2) + + for block in self.lstms: + if isinstance(block, AdaLayerNorm): + x = block(x.transpose(-1, -2), style).transpose(-1, -2) + x = torch.cat([x, s.permute(1, -1, 0)], axis=1) + x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0) + else: + x = x.transpose(-1, -2) + x = nn.utils.rnn.pack_padded_sequence( + x, input_lengths, batch_first=True, enforce_sorted=False + ) + block.flatten_parameters() + x, _ = block(x) + x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True) + x = F.dropout(x, p=self.dropout, training=self.training) + x = x.transpose(-1, -2) + + x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]]) + + x_pad[:, :, : x.shape[-1]] = x + x = x_pad.to(x.device) + + return x.transpose(-1, -2) + + def inference(self, x, style): + x = self.embedding(x.transpose(-1, -2)) * math.sqrt(self.d_model) + style = style.expand(x.shape[0], x.shape[1], -1) + x = torch.cat([x, style], axis=-1) + src = self.pos_encoder(x) + output = self.transformer_encoder(src).transpose(0, 1) + return output + + def length_to_mask(self, lengths): + mask = ( + torch.arange(lengths.max()) + .unsqueeze(0) + .expand(lengths.shape[0], -1) + .type_as(lengths) + ) + mask = torch.gt(mask + 1, lengths.unsqueeze(1)) + return mask + + +def load_F0_models(path): + # load F0 model + + F0_model = JDCNet(num_class=1, seq_len=192) + params = torch.load(path, map_location="cpu")["net"] + F0_model.load_state_dict(params) + _ = F0_model.train() + + return F0_model + + +def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG): + # load ASR model + def _load_config(path): + with open(path) as f: + config = yaml.safe_load(f) + model_config = config["model_params"] + return model_config + + def _load_model(model_config, 
model_path): + model = ASRCNN(**model_config) + params = torch.load(model_path, map_location="cpu")["model"] + model.load_state_dict(params) + return model + + asr_model_config = _load_config(ASR_MODEL_CONFIG) + asr_model = _load_model(asr_model_config, ASR_MODEL_PATH) + _ = asr_model.train() + + return asr_model + + +def build_model(args, text_aligner, pitch_extractor, bert): + assert args.decoder.type in ["istftnet", "hifigan"], "Decoder type unknown" + + if args.decoder.type == "istftnet": + from Modules.istftnet import Decoder + + decoder = Decoder( + dim_in=args.hidden_dim, + style_dim=args.style_dim, + dim_out=args.n_mels, + resblock_kernel_sizes=args.decoder.resblock_kernel_sizes, + upsample_rates=args.decoder.upsample_rates, + upsample_initial_channel=args.decoder.upsample_initial_channel, + resblock_dilation_sizes=args.decoder.resblock_dilation_sizes, + upsample_kernel_sizes=args.decoder.upsample_kernel_sizes, + gen_istft_n_fft=args.decoder.gen_istft_n_fft, + gen_istft_hop_size=args.decoder.gen_istft_hop_size, + ) + else: + from Modules.hifigan import Decoder + + decoder = Decoder( + dim_in=args.hidden_dim, + style_dim=args.style_dim, + dim_out=args.n_mels, + resblock_kernel_sizes=args.decoder.resblock_kernel_sizes, + upsample_rates=args.decoder.upsample_rates, + upsample_initial_channel=args.decoder.upsample_initial_channel, + resblock_dilation_sizes=args.decoder.resblock_dilation_sizes, + upsample_kernel_sizes=args.decoder.upsample_kernel_sizes, + ) + + text_encoder = TextEncoder( + channels=args.hidden_dim, + kernel_size=5, + depth=args.n_layer, + n_symbols=args.n_token, + ) + + predictor = ProsodyPredictor( + style_dim=args.style_dim, + d_hid=args.hidden_dim, + nlayers=args.n_layer, + max_dur=args.max_dur, + dropout=args.dropout, + ) + + style_encoder = StyleEncoder( + dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim + ) # acoustic style encoder + predictor_encoder = StyleEncoder( + dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim + ) # prosodic style encoder + + # define diffusion model + if args.multispeaker: + transformer = StyleTransformer1d( + channels=args.style_dim * 2, + context_embedding_features=bert.config.hidden_size, + context_features=args.style_dim * 2, + **args.diffusion.transformer + ) + else: + transformer = Transformer1d( + channels=args.style_dim * 2, + context_embedding_features=bert.config.hidden_size, + **args.diffusion.transformer + ) + + diffusion = AudioDiffusionConditional( + in_channels=1, + embedding_max_length=bert.config.max_position_embeddings, + embedding_features=bert.config.hidden_size, + embedding_mask_proba=args.diffusion.embedding_mask_proba, # Conditional dropout of batch elements, + channels=args.style_dim * 2, + context_features=args.style_dim * 2, + ) + + diffusion.diffusion = KDiffusion( + net=diffusion.unet, + sigma_distribution=LogNormalDistribution( + mean=args.diffusion.dist.mean, std=args.diffusion.dist.std + ), + sigma_data=args.diffusion.dist.sigma_data, # a placeholder, will be changed dynamically when start training diffusion model + dynamic_threshold=0.0, + ) + diffusion.diffusion.net = transformer + diffusion.unet = transformer + + nets = Munch( + bert=bert, + bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim), + predictor=predictor, + decoder=decoder, + text_encoder=text_encoder, + predictor_encoder=predictor_encoder, + style_encoder=style_encoder, + diffusion=diffusion, + text_aligner=text_aligner, + pitch_extractor=pitch_extractor, + 
mpd=MultiPeriodDiscriminator(), + msd=MultiResSpecDiscriminator(), + # slm discriminator head + wd=WavLMDiscriminator( + args.slm.hidden, args.slm.nlayers, args.slm.initial_channel + ), + ) + + return nets + + +def load_checkpoint(model, optimizer, path, load_only_params=True, ignore_modules=[]): + state = torch.load(path, map_location="cpu") + params = state["net"] + for key in model: + if key in params and key not in ignore_modules: + print("%s loaded" % key) + model[key].load_state_dict(params[key], strict=False) + _ = [model[key].eval() for key in model] + + if not load_only_params: + epoch = state["epoch"] + iters = state["iters"] + optimizer.load_state_dict(state["optimizer"]) + else: + epoch = 0 + iters = 0 + + return model, optimizer, epoch, iters + + +class TextEncoderOpenVoice(nn.Module): + def __init__(self, + n_vocab, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout): + super().__init__() + self.n_vocab = n_vocab + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.emb = nn.Embedding(n_vocab, hidden_channels) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) + + self.encoder = attentions.Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout) + self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths): + x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + + x = self.encoder(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + + m, logs = torch.split(stats, self.out_channels, dim=1) + return x, m, logs, x_mask + + +class DurationPredictor(nn.Module): + def __init__( + self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0 + ): + super().__init__() + + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.gin_channels = gin_channels + + self.drop = nn.Dropout(p_dropout) + self.conv_1 = nn.Conv1d( + in_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) + self.norm_1 = modules.LayerNorm(filter_channels) + self.conv_2 = nn.Conv1d( + filter_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) + self.norm_2 = modules.LayerNorm(filter_channels) + self.proj = nn.Conv1d(filter_channels, 1, 1) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, in_channels, 1) + + def forward(self, x, x_mask, g=None): + x = torch.detach(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + +class StochasticDurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): + super().__init__() + filter_channels = in_channels # it needs to be removed from future version. 
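+        # The main flow (ElementwiseAffine followed by n_flows ConvFlow/Flip pairs) models
+        # log-durations; the post_* layers form the variational posterior used only in the
+        # training branch of forward(), while reverse=True samples durations at inference.
+        # Inference-time usage (simplified from SynthesizerTrn.infer below, which also
+        # blends in DurationPredictor via sdp_ratio):
+        #   logw = sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
+        #   w = torch.exp(logw) * x_mask * length_scale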
+ self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.log_flow = modules.Log() + self.flows = nn.ModuleList() + self.flows.append(modules.ElementwiseAffine(2)) + for i in range(n_flows): + self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.flows.append(modules.Flip()) + + self.post_pre = nn.Conv1d(1, filter_channels, 1) + self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + self.post_flows = nn.ModuleList() + self.post_flows.append(modules.ElementwiseAffine(2)) + for i in range(4): + self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.post_flows.append(modules.Flip()) + + self.pre = nn.Conv1d(in_channels, filter_channels, 1) + self.proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, filter_channels, 1) + + def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): + x = torch.detach(x) + x = self.pre(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.convs(x, x_mask) + x = self.proj(x) * x_mask + + if not reverse: + flows = self.flows + assert w is not None + + logdet_tot_q = 0 + h_w = self.post_pre(w) + h_w = self.post_convs(h_w, x_mask) + h_w = self.post_proj(h_w) * x_mask + e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask + z_q = e_q + for flow in self.post_flows: + z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) + logdet_tot_q += logdet_q + z_u, z1 = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (w - u) * x_mask + logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2]) + logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q + + logdet_tot = 0 + z0, logdet = self.log_flow(z0, x_mask) + logdet_tot += logdet + z = torch.cat([z0, z1], 1) + for flow in flows: + z, logdet = flow(z, x_mask, g=x, reverse=reverse) + logdet_tot = logdet_tot + logdet + nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot + return nll + logq # [b] + else: + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale + for flow in flows: + z = flow(z, x_mask, g=x, reverse=reverse) + z0, z1 = torch.split(z, [1, 1], 1) + logw = z0 + return logw + +class PosteriorEncoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None, tau=1.0): + x_mask = 
torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( + x.dtype + ) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * tau * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + +class Generator(torch.nn.Module): + def __init__( + self, + initial_channel, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=0, + ): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = Conv1d( + initial_channel, upsample_initial_channel, 7, 1, padding=3 + ) + resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes) + ): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x, g=None): + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, modules.LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for layer in self.ups: + remove_weight_norm(layer) + for layer in self.resblocks: + layer.remove_weight_norm() + + +class ReferenceEncoder(nn.Module): + """ + inputs --- [N, Ty/r, n_mels*r] mels + outputs --- [N, ref_enc_gru_size] + """ + + def __init__(self, spec_channels, gin_channels=0, layernorm=True): + super().__init__() + self.spec_channels = spec_channels + ref_enc_filters = [32, 32, 64, 64, 128, 128] + K = len(ref_enc_filters) + filters = [1] + ref_enc_filters + convs = [ + weight_norm( + nn.Conv2d( + in_channels=filters[i], + out_channels=filters[i + 1], + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1), + ) + ) + for i in range(K) + ] + self.convs = nn.ModuleList(convs) + + out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K) + self.gru = nn.GRU( + input_size=ref_enc_filters[-1] * out_channels, + hidden_size=256 // 2, + batch_first=True, + ) + self.proj = nn.Linear(128, gin_channels) + if layernorm: + self.layernorm = nn.LayerNorm(self.spec_channels) + else: + self.layernorm = None + + def forward(self, inputs, mask=None): + N = inputs.size(0) + + out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs] + if self.layernorm is not None: + out = self.layernorm(out) + + for conv in self.convs: + out = conv(out) + # out = wn(out) + out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K] + + out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K] + T = 
out.size(1) + N = out.size(0) + out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K] + + self.gru.flatten_parameters() + memory, out = self.gru(out) # out --- [1, N, 128] + + return self.proj(out.squeeze(0)) + + def calculate_channels(self, L, kernel_size, stride, pad, n_convs): + for i in range(n_convs): + L = (L - kernel_size + 2 * pad) // stride + 1 + return L + + +class ResidualCouplingBlock(nn.Module): + def __init__(self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows=4, + gin_channels=0): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True)) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + +class SynthesizerTrn(nn.Module): + """ + Synthesizer for Training + """ + + def __init__( + self, + n_vocab, + spec_channels, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + n_speakers=256, + gin_channels=256, + **kwargs + ): + super().__init__() + + self.dec = Generator( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + ) + self.enc_q = PosteriorEncoder( + spec_channels, + inter_channels, + hidden_channels, + 5, + 1, + 16, + gin_channels=gin_channels, + ) + + self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) + + self.n_speakers = n_speakers + if n_speakers == 0: + self.ref_enc = ReferenceEncoder(spec_channels, gin_channels) + else: + self.enc_p = TextEncoderOpenVoice(n_vocab, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout) + self.sdp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels) + self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels) + self.emb_g = nn.Embedding(n_speakers, gin_channels) + + def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., sdp_ratio=0.2, max_len=None): + x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) + if self.n_speakers > 0: + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + else: + g = None + + logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * sdp_ratio \ + + self.dp(x, x_mask, g=g) * (1 - sdp_ratio) + + w = torch.exp(logw) * x_mask * length_scale + w_ceil = torch.ceil(w) + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + attn = commons.generate_path(w_ceil, attn_mask) + + m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) 
# [b, t', t], [b, t, d] -> [b, d, t'] + logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] + + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + z = self.flow(z_p, y_mask, g=g, reverse=True) + o = self.dec((z * y_mask)[:,:,:max_len], g=g) + return o, attn, y_mask, (z, z_p, m_p, logs_p) + + def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0): + g_src = sid_src + g_tgt = sid_tgt + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src, tau=tau) + z_p = self.flow(z, y_mask, g=g_src) + z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) + o_hat = self.dec(z_hat * y_mask, g=g_tgt) + return o_hat, y_mask, (z, z_p, z_hat) diff --git a/src/modules.py b/src/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..a0461728edee2171f43b0a99f082d1ca6ab7cb3f --- /dev/null +++ b/src/modules.py @@ -0,0 +1,598 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +from torch.nn import Conv1d +from torch.nn.utils import weight_norm, remove_weight_norm + +import commons +from commons import init_weights, get_padding +from transforms import piecewise_rational_quadratic_transform +from attentions import Encoder + +LRELU_SLOPE = 0.1 + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class ConvReluNorm(nn.Module): + def __init__( + self, + in_channels, + hidden_channels, + out_channels, + kernel_size, + n_layers, + p_dropout, + ): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + assert n_layers > 1, "Number of layers should be larger than 0." 
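+        # Builds n_layers of (Conv1d -> LayerNorm -> ReLU -> Dropout); forward() adds the
+        # zero-initialized 1x1 projection back onto the input as a residual.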
+ + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append( + nn.Conv1d( + in_channels, hidden_channels, kernel_size, padding=kernel_size // 2 + ) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append( + nn.Conv1d( + hidden_channels, + hidden_channels, + kernel_size, + padding=kernel_size // 2, + ) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DDSConv(nn.Module): + """ + Dilated and Depth-Separable Convolution + """ + + def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.drop = nn.Dropout(p_dropout) + self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(n_layers): + dilation = kernel_size**i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append( + nn.Conv1d( + channels, + channels, + kernel_size, + groups=channels, + dilation=dilation, + padding=padding, + ) + ) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm(channels)) + self.norms_2.append(LayerNorm(channels)) + + def forward(self, x, x_mask, g=None): + if g is not None: + x = x + g + for i in range(self.n_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.drop(y) + x = x + y + return x * x_mask + + +class WN(torch.nn.Module): + def __init__( + self, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + p_dropout=0, + ): + super(WN, self).__init__() + assert kernel_size % 2 == 1 + self.hidden_channels = hidden_channels + self.kernel_size = (kernel_size,) + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.drop = nn.Dropout(p_dropout) + + if gin_channels != 0: + cond_layer = torch.nn.Conv1d( + gin_channels, 2 * hidden_channels * n_layers, 1 + ) + self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") + + for i in range(n_layers): + dilation = dilation_rate**i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = torch.nn.Conv1d( + hidden_channels, + 2 * hidden_channels, + kernel_size, + dilation=dilation, + padding=padding, + ) + in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x, x_mask, g=None, **kwargs): + output = torch.zeros_like(x) + 
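+        # A minimal sketch of what commons.fused_add_tanh_sigmoid_multiply is assumed to
+        # compute per layer, with C = hidden_channels carried in n_channels_tensor:
+        #   in_act = x_in + g_l
+        #   acts = torch.tanh(in_act[:, :C, :]) * torch.sigmoid(in_act[:, C:, :])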
n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + if g is not None: + g = self.cond_layer(g) + + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] + else: + g_l = torch.zeros_like(x_in) + + acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) + acts = self.drop(acts) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:, : self.hidden_channels, :] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:, self.hidden_channels :, :] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.gin_channels != 0: + torch.nn.utils.remove_weight_norm(self.cond_layer) + for l in self.in_layers: + torch.nn.utils.remove_weight_norm(l) + for l in self.res_skip_layers: + torch.nn.utils.remove_weight_norm(l) + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x, x_mask=None): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c2(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x, x_mask=None): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Log(nn.Module): + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = 
torch.log(torch.clamp_min(x, 1e-5)) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = torch.exp(x) * x_mask + return x + + +class Flip(nn.Module): + def forward(self, x, *args, reverse=False, **kwargs): + x = torch.flip(x, [1]) + if not reverse: + logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) + return x, logdet + else: + return x + + +class ElementwiseAffine(nn.Module): + def __init__(self, channels): + super().__init__() + self.channels = channels + self.m = nn.Parameter(torch.zeros(channels, 1)) + self.logs = nn.Parameter(torch.zeros(channels, 1)) + + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = self.m + torch.exp(self.logs) * x + y = y * x_mask + logdet = torch.sum(self.logs * x_mask, [1, 2]) + return y, logdet + else: + x = (x - self.m) * torch.exp(-self.logs) * x_mask + return x + + +class ResidualCouplingLayer(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=0, + gin_channels=0, + mean_only=False, + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=p_dropout, + gin_channels=gin_channels, + ) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x + + +class ConvFlow(nn.Module): + def __init__( + self, + in_channels, + filter_channels, + kernel_size, + n_layers, + num_bins=10, + tail_bound=5.0, + ): + super().__init__() + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.num_bins = num_bins + self.tail_bound = tail_bound + self.half_channels = in_channels // 2 + + self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) + self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0) + self.proj = nn.Conv1d( + filter_channels, self.half_channels * (num_bins * 3 - 1), 1 + ) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) + h = self.convs(h, x_mask, g=g) + h = self.proj(h) * x_mask + + b, c, t = x0.shape + h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] 
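+        # h now has shape [b, half_channels, t, num_bins * 3 - 1]; the slices below become
+        # the unnormalized widths, heights, and derivatives of the rational-quadratic spline.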
+ + unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels) + unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt( + self.filter_channels + ) + unnormalized_derivatives = h[..., 2 * self.num_bins :] + + x1, logabsdet = piecewise_rational_quadratic_transform( + x1, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=reverse, + tails="linear", + tail_bound=self.tail_bound, + ) + + x = torch.cat([x0, x1], 1) * x_mask + logdet = torch.sum(logabsdet * x_mask, [1, 2]) + if not reverse: + return x, logdet + else: + return x + + +class TransformerCouplingLayer(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + n_layers, + n_heads, + p_dropout=0, + filter_channels=0, + mean_only=False, + wn_sharing_parameter=None, + gin_channels=0, + ): + assert n_layers == 3, n_layers + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = ( + Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + isflow=True, + gin_channels=gin_channels, + ) + if wn_sharing_parameter is None + else wn_sharing_parameter + ) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x + + x1, logabsdet = piecewise_rational_quadratic_transform( + x1, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=reverse, + tails="linear", + tail_bound=self.tail_bound, + ) + + x = torch.cat([x0, x1], 1) * x_mask + logdet = torch.sum(logabsdet * x_mask, [1, 2]) + if not reverse: + return x, logdet + else: + return x diff --git a/src/mytest.py b/src/mytest.py new file mode 100644 index 0000000000000000000000000000000000000000..7ce8a2ae5c86724d74fd3b78ac7393053312cb36 --- /dev/null +++ b/src/mytest.py @@ -0,0 +1,82 @@ +''' +import os +import torch +import se_extractor +from api import ToneColorConverter + +ckpt_converter = 'checkpoints/converter' +device = 'cuda:0' +output_dir = 'outputs' + +tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device) +tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth') + +os.makedirs(output_dir, exist_ok=True) + +from openai import OpenAI + +client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) + +response = client.audio.speech.create( + model="tts-1", + voice="nova", + input="This audio will be used to extract the base speaker tone color embedding. " + \ + "Typically a very short audio should be sufficient, but increasing the audio " + \ + "length will also improve the output audio quality." 
+) + +response.stream_to_file(f"{output_dir}/openai_source_output.mp3") + +base_speaker = f"{output_dir}/openai_source_output.mp3" +source_se, audio_name = se_extractor.get_se(base_speaker, tone_color_converter) + +reference_speaker = 'resources/example_reference.mp3' +target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter) + +text = [ + "MyShell is a decentralized and comprehensive platform for discovering, creating, and staking AI-native apps.", +] +src_path = f'{output_dir}/tmp.wav' + +for i, t in enumerate(text): + + response = client.audio.speech.create( + model="tts-1", + voice="alloy", + input=t, + ) + + response.stream_to_file(src_path) + + save_path = f'{output_dir}/output_crosslingual_{i}.wav' + + tone_color_converter.convert( + audio_src_path=src_path, + src_se=source_se, + tgt_se=target_se, + output_path=save_path, + message='') + + + +model = models.openai("gpt-3.5-turbo",system_prompt='You are an expert in identifying the emotion of a sentence') +result = model.generate_choice("Harry's mind was racing with thoughts of the recent events at Hogwarts", ["friendly", "cheerful", "excited", "sad", "angry", "terrified", "shouting", "whispering"]) +print(result) +from openai import OpenAI +import os +client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) + + +response = client.audio.speech.create( + model="tts-1", + voice="fable", + input="This audio will be used to extract the base speaker tone color embedding. " + \ + "Typically a very short audio should be sufficient, but increasing the audio " + \ + "length will also improve the output audio quality." +) + +response.stream_to_file(f"openai_source_output.mp3") +''' +import boto3 +s3_client = boto3.client('s3',aws_access_key_id='AKIAW7WTE5RKJY2WJ55F', aws_secret_access_key='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') +response = s3_client.upload_file('/root/src/videly/openai_source_output.mp3', 'demovidelyusergenerations', 'test.mp3') \ No newline at end of file diff --git a/src/openai_source_output.mp3 b/src/openai_source_output.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..fcfb1c1bf1248b106c9ca487b418e7ff676f1797 --- /dev/null +++ b/src/openai_source_output.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8275bdb21514d1b01accecb27fdf3fbe0d27ae61dbc14840e875157d9733bbd9 +size 252960 diff --git a/src/optimizers.py b/src/optimizers.py new file mode 100644 index 0000000000000000000000000000000000000000..e192bc1f2d2c21f63c9c7d3bf6195715c8713278 --- /dev/null +++ b/src/optimizers.py @@ -0,0 +1,86 @@ +# coding:utf-8 +import os, sys +import os.path as osp +import numpy as np +import torch +from torch import nn +from torch.optim import Optimizer +from functools import reduce +from torch.optim import AdamW + + +class MultiOptimizer: + def __init__(self, optimizers={}, schedulers={}): + self.optimizers = optimizers + self.schedulers = schedulers + self.keys = list(optimizers.keys()) + self.param_groups = reduce( + lambda x, y: x + y, [v.param_groups for v in self.optimizers.values()] + ) + + def state_dict(self): + state_dicts = [(key, self.optimizers[key].state_dict()) for key in self.keys] + return state_dicts + + def load_state_dict(self, state_dict): + for key, val in state_dict: + try: + self.optimizers[key].load_state_dict(val) + except: + print("Unloaded %s" % key) + + def step(self, key=None, scaler=None): + keys = [key] if key is not None else self.keys + _ = [self._step(key, scaler) for key in keys] + + def _step(self, key, scaler=None): + if 
scaler is not None: + scaler.step(self.optimizers[key]) + scaler.update() + else: + self.optimizers[key].step() + + def zero_grad(self, key=None): + if key is not None: + self.optimizers[key].zero_grad() + else: + _ = [self.optimizers[key].zero_grad() for key in self.keys] + + def scheduler(self, *args, key=None): + if key is not None: + self.schedulers[key].step(*args) + else: + _ = [self.schedulers[key].step(*args) for key in self.keys] + + +def define_scheduler(optimizer, params): + scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer, + max_lr=params.get("max_lr", 2e-4), + epochs=params.get("epochs", 200), + steps_per_epoch=params.get("steps_per_epoch", 1000), + pct_start=params.get("pct_start", 0.0), + div_factor=1, + final_div_factor=1, + ) + + return scheduler + + +def build_optimizer(parameters_dict, scheduler_params_dict, lr): + optim = dict( + [ + (key, AdamW(params, lr=lr, weight_decay=1e-4, betas=(0.0, 0.99), eps=1e-9)) + for key, params in parameters_dict.items() + ] + ) + + schedulers = dict( + [ + (key, define_scheduler(opt, scheduler_params_dict[key])) + for key, opt in optim.items() + ] + ) + + multi_optim = MultiOptimizer(optim, schedulers) + return multi_optim diff --git a/src/predict.py b/src/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..a0ddc4fe284177af939a04383308de19de3df080 --- /dev/null +++ b/src/predict.py @@ -0,0 +1,364 @@ +""" +This file contains the Predictor class, which is used to run predictions on the +Whisper model. It is based on the Predictor class from the original Whisper +repository, with some modifications to make it work with the RP platform. +""" + +from concurrent.futures import ThreadPoolExecutor +import numpy as np + +from runpod.serverless.utils import rp_cuda +import boto3 +import random +random.seed(0) +from glob import glob +import subprocess + +import io + +import numpy as np +np.random.seed(0) +import subprocess +import se_extractor + +import yaml +from munch import Munch +import uuid +import shutil +from openai import OpenAI + + +import time +import os +import phonemizer +import torch +torch.manual_seed(0) +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +from torch import nn +import torch.nn.functional as F +import torchaudio +import librosa +from nltk.tokenize import word_tokenize +import nltk +from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule +nltk.download('punkt') +from models import * +from utils import * +import soundfile as sf +from tortoise.utils.text import split_and_recombine_text +from resemble_enhance.enhancer.inference import denoise, enhance +from text_utils import TextCleaner +from pydantic import BaseModel, HttpUrl +from api import BaseSpeakerTTS, ToneColorConverter + +class Predictor: + def __init__(self): + self.model = None + self.sampler = None + self.to_mel = None + self.global_phonemizer = None + self.model_params = None + self.textclenaer = None + self.mean = 0 + self.std = 0 + self.device = 'cuda:0' + + self.ckpt_base = 'checkpoints/base_speakers/EN' + self.ckpt_converter = 'checkpoints/converter' + self.base_speaker_tts = None + self.tone_color_converter = None + self.output_dir = 'outputs' + self.processed_dir = 'processed' + os.makedirs(self.processed_dir, exist_ok=True) + os.makedirs(self.output_dir, exist_ok=True) + self.s3_client = boto3.client('s3',aws_access_key_id=os.getenv('AWS_ACCESS_KEY'), aws_secret_access_key=os.getenv('AWS_SECRET_KEY')) + print(os.getenv("AWS_ACCESS_KEY")) + 
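# A minimal sketch of how the MultiOptimizer/build_optimizer pair from src/optimizers.py
# (above) can be driven. The two toy linear modules and the scheduler numbers are
# placeholders for illustration, not values used by this repo, and it assumes the snippet
# runs from src/ so that `optimizers` is importable as a flat module.
import torch
from torch import nn
from optimizers import build_optimizer

nets = {"decoder": nn.Linear(10, 10), "predictor": nn.Linear(10, 1)}
params = {k: m.parameters() for k, m in nets.items()}
sched_params = {k: {"max_lr": 1e-4, "epochs": 2, "steps_per_epoch": 5} for k in nets}
optim = build_optimizer(params, sched_params, lr=1e-4)

loss = nets["decoder"](torch.randn(8, 10)).pow(2).mean()
optim.zero_grad()                  # zeroes every sub-optimizer
loss.backward()
optim.step("decoder")              # steps only the decoder's AdamW
optim.scheduler(key="decoder")     # advances its OneCycleLR schedule
# Keeping a separate AdamW + OneCycleLR per sub-network is what lets the training code
# step and schedule individual components independently.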
print(os.getenv("AWS_SECRET_KEY")) + + + def setup(self): + self.global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True) + self.textclenaer = TextCleaner() + self.to_mel = torchaudio.transforms.MelSpectrogram( + n_mels=80, n_fft=2048, win_length=1200, hop_length=300) + self.mean, self.std = -4, 4 + + config = yaml.safe_load(open("Configs/hg.yml")) + print(config) + + ASR_config = config.get('ASR_config', False) + ASR_path = config.get('ASR_path', False) + text_aligner = load_ASR_models(ASR_path, ASR_config) + + F0_path = config.get('F0_path', False) + pitch_extractor = load_F0_models(F0_path) + + from Utils.PLBERT.util import load_plbert + BERT_path = config.get('PLBERT_dir', False) + plbert = load_plbert(BERT_path) + + self.model_params = recursive_munch(config['model_params']) + self.model = build_model(self.model_params, text_aligner, pitch_extractor, plbert) + _ = [self.model[key].eval() for key in self.model] + _ = [self.model[key].to(self.device) for key in self.model] + + params_whole = torch.load("Models/epochs_2nd_00020.pth", map_location='cpu') + params = params_whole['net'] + + for key in self.model: + if key in params: + print('%s loaded' % key) + try: + self.model[key].load_state_dict(params[key]) + except: + from collections import OrderedDict + state_dict = params[key] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[7:] # remove `module.` + new_state_dict[name] = v + # load params + self.model[key].load_state_dict(new_state_dict, strict=False) + # except: + # _load(params[key], model[key]) + _ = [self.model[key].eval() for key in self.model] + self.sampler = DiffusionSampler( + self.model.diffusion.diffusion, + sampler=ADPM2Sampler(), + sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters + clamp=False + ) + + + def predict(self,s3_url,passage,method_type='voice_clone'): + output_dir = 'processed' + gen_id = str(uuid.uuid4()) + os.makedirs(output_dir,exist_ok=True) + raw_dir = os.path.join(output_dir,gen_id,'raw') + segments_dir = os.path.join(output_dir,gen_id,'segments') + results_dir = os.path.join(output_dir,gen_id,'results') + openvoice_dir = os.path.join(output_dir,gen_id,'openvoice') + os.makedirs(raw_dir) + os.makedirs(segments_dir) + os.makedirs(results_dir) + + + s3_key = s3_url.split('/')[-1] + bucket_name = 'demovidelyuseruploads' + local_file_path = os.path.join(raw_dir,s3_key) + self.download_file_from_s3(self.s3_client,bucket_name,s3_key,local_file_path) + + se_extractor.generate_voice_segments(local_file_path,segments_dir,vad=True) + + if method_type == 'voice_clone': + #voice_clone with styletts2 + model,sampler = self.model,self.sampler + processed_seg_dir = os.path.join(segments_dir,s3_key.split('.')[0],'wavs') + result = self.process_audio_file(processed_seg_dir,passage,model,sampler) + final_output = os.path.join(results_dir,f"{gen_id}-voice-clone-1.wav") + sf.write(final_output,result,24000) + + + mp3_final_output_1 = str(final_output).replace('wav','mp3') + self.convert_wav_to_mp3(final_output,mp3_final_output_1) + print(mp3_final_output_1) + self.upload_file_to_s3(mp3_final_output_1,'demovidelyusergenerations',f"{gen_id}-voice-clone-1.mp3") + return {"voice_clone_1":f"https://demovidelyusergenerations.s3.amazonaws.com/{gen_id}-voice-clone-1.mp3"} + + + def _fn(self,path, solver, nfe, tau): + if path is None: + return None, None + + solver = solver.lower() + nfe = int(nfe) + lambd = 0.9 + + dwav, sr = 
torchaudio.load(path) + dwav = dwav.mean(dim=0) + + wav1, new_sr = enhance(dwav, sr, self.device, nfe=nfe, solver=solver, lambd=lambd, tau=tau) + + wav1 = wav1.cpu().numpy() + + return (new_sr, wav1) + + def _fn_denoise(self,path, solver, nfe, tau): + if path is None: + return None + print(torch.cuda.is_available()) + print("Going to denoise") + solver = solver.lower() + nfe = int(nfe) + lambd = 0.9 + + dwav, sr = torchaudio.load(path) + dwav = dwav.mean(dim=0) + + wav1, new_sr = denoise(dwav, sr, self.device) + + wav1 = wav1.cpu().numpy() + print("Done noising") + + return (new_sr, wav1) + + def LFinference(self,model,sampler,text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1): + text = text.strip() + ps = self.global_phonemizer.phonemize([text]) + ps = word_tokenize(ps[0]) + ps = ' '.join(ps) + ps = ps.replace('``', '"') + ps = ps.replace("''", '"') + + tokens = self.textclenaer(ps) + tokens.insert(0, 0) + tokens = torch.LongTensor(tokens).to(self.device).unsqueeze(0) + + with torch.no_grad(): + input_lengths = torch.LongTensor([tokens.shape[-1]]).to(self.device) + text_mask = self.length_to_mask(input_lengths).to(self.device) + + t_en = model.text_encoder(tokens, input_lengths, text_mask) + bert_dur = model.bert(tokens, attention_mask=(~text_mask).int()) + d_en = model.bert_encoder(bert_dur).transpose(-1, -2) + + s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(self.device), + embedding=bert_dur, + embedding_scale=embedding_scale, + features=ref_s, # reference from the same speaker as the embedding + num_steps=diffusion_steps).squeeze(1) + + if s_prev is not None: + # convex combination of previous and current style + s_pred = t * s_prev + (1 - t) * s_pred + + s = s_pred[:, 128:] + ref = s_pred[:, :128] + + ref = alpha * ref + (1 - alpha) * ref_s[:, :128] + s = beta * s + (1 - beta) * ref_s[:, 128:] + + s_pred = torch.cat([ref, s], dim=-1) + + d = model.predictor.text_encoder(d_en, + s, input_lengths, text_mask) + + x, _ = model.predictor.lstm(d) + duration = model.predictor.duration_proj(x) + + duration = torch.sigmoid(duration).sum(axis=-1) + pred_dur = torch.round(duration.squeeze()).clamp(min=1) + + + pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data)) + c_frame = 0 + for i in range(pred_aln_trg.size(0)): + pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1 + c_frame += int(pred_dur[i].data) + + # encode prosody + en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(self.device)) + if self.model_params.decoder.type == "hifigan": + asr_new = torch.zeros_like(en) + asr_new[:, :, 0] = en[:, :, 0] + asr_new[:, :, 1:] = en[:, :, 0:-1] + en = asr_new + + F0_pred, N_pred = model.predictor.F0Ntrain(en, s) + + asr = (t_en @ pred_aln_trg.unsqueeze(0).to(self.device)) + if self.model_params.decoder.type == "hifigan": + asr_new = torch.zeros_like(asr) + asr_new[:, :, 0] = asr[:, :, 0] + asr_new[:, :, 1:] = asr[:, :, 0:-1] + asr = asr_new + + out = model.decoder(asr, + F0_pred, N_pred, ref.squeeze().unsqueeze(0)) + + + return out.squeeze().cpu().numpy()[..., :-100], s_pred # + + def length_to_mask(self,lengths): + mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) + mask = torch.gt(mask+1, lengths.unsqueeze(1)) + return mask + + def preprocess(self,wave): + wave_tensor = torch.from_numpy(wave).float() + mel_tensor = self.to_mel(wave_tensor) + mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - self.mean) / self.std + return mel_tensor + + def 
compute_style(self,path,model): + wave, sr = librosa.load(path, sr=24000) + audio, index = librosa.effects.trim(wave, top_db=30) + if sr != 24000: + audio = librosa.resample(audio, sr, 24000) + mel_tensor = self.preprocess(audio).to(self.device) + + with torch.no_grad(): + ref_s = model.style_encoder(mel_tensor.unsqueeze(1)) + ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1)) + + return torch.cat([ref_s, ref_p], dim=1) + + def process_audio_file(self,file_dir,passage,model,sampler): + print(file_dir) + audio_segs = glob(f'{file_dir}/*.wav') + print(audio_segs) + if len(audio_segs) >= 1: + s_ref = self.compute_style(audio_segs[0], model) + else: + raise NotImplementedError('No audio segments found!') + sentences = split_and_recombine_text(passage) + wavs = [] + s_prev = None + for text in sentences: + if text.strip() == "": continue + text += '.' + wav, s_prev = self.LFinference(model,sampler,text, + s_prev, + s_ref, + alpha = 0, + beta = 0.3, # make it more suitable for the text + t = 0.7, + diffusion_steps=10, embedding_scale=1) + wavs.append(wav) + + audio_arrays = [] + for wav_file in wavs: + audio_arrays.append(wav_file) + concatenated_audio = np.concatenate(audio_arrays) + return concatenated_audio + + def download_file_from_s3(self,s3_client,bucket_name, s3_key, local_file_path): + try: + s3_client.download_file(bucket_name, s3_key, local_file_path) + print(f"File downloaded successfully: {local_file_path}") + except Exception as e: + print(f"Error downloading file: {e}") + + + def convert_wav_to_mp3(self,wav_file, mp3_file): + command = ['ffmpeg', '-i', wav_file, '-q:a', '0', '-map', 'a', mp3_file] + subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + + def upload_file_to_s3(self,file_name, bucket, object_name=None, content_type="audio/mpeg"): + + if object_name is None: + object_name = file_name + + try: + with open(file_name, 'rb') as file_data: + self.s3_client.put_object(Bucket=bucket, Key=object_name, Body=file_data, ContentType=content_type) + print("File uploaded successfully") + return True + except NoCredentialsError: + print("Error: No AWS credentials found") + return False + except Exception as e: + print(f"Error uploading file: {e}") + return False \ No newline at end of file diff --git a/src/processed/9ac5dfd2-1477-4903-adfc-1cc4d0351977/raw/039cf8da-75b8-474d-affa-fc84066c3fa3.wav b/src/processed/9ac5dfd2-1477-4903-adfc-1cc4d0351977/raw/039cf8da-75b8-474d-affa-fc84066c3fa3.wav new file mode 100644 index 0000000000000000000000000000000000000000..fc15971244fc0a99abf29f0ccf032cb17cf8ed98 --- /dev/null +++ b/src/processed/9ac5dfd2-1477-4903-adfc-1cc4d0351977/raw/039cf8da-75b8-474d-affa-fc84066c3fa3.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:68f1fdaa436c8072a3d58f8234507be22e12302c77b78ee19b1a911168f96d33 +size 3098668 diff --git a/src/processed/9ac5dfd2-1477-4903-adfc-1cc4d0351977/results/9ac5dfd2-1477-4903-adfc-1cc4d0351977-voice-clone-1.mp3 b/src/processed/9ac5dfd2-1477-4903-adfc-1cc4d0351977/results/9ac5dfd2-1477-4903-adfc-1cc4d0351977-voice-clone-1.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..f80b5b1de828297047c90a4cb7c24bf3e65158d9 --- /dev/null +++ b/src/processed/9ac5dfd2-1477-4903-adfc-1cc4d0351977/results/9ac5dfd2-1477-4903-adfc-1cc4d0351977-voice-clone-1.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b639c38bce187bdaa36582877934f9896fe87fa4a05731ed2e5f1ce9bf794820 +size 1261173 diff --git 
a/src/processed/9ac5dfd2-1477-4903-adfc-1cc4d0351977/results/9ac5dfd2-1477-4903-adfc-1cc4d0351977-voice-clone-1.wav b/src/processed/9ac5dfd2-1477-4903-adfc-1cc4d0351977/results/9ac5dfd2-1477-4903-adfc-1cc4d0351977-voice-clone-1.wav new file mode 100644 index 0000000000000000000000000000000000000000..41880fc6a9999dc6f66df1cd9f6f47ab1816606a --- /dev/null +++ b/src/processed/9ac5dfd2-1477-4903-adfc-1cc4d0351977/results/9ac5dfd2-1477-4903-adfc-1cc4d0351977-voice-clone-1.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6dd8186db2bab0a68d1b6c432f1226c26b38bafb993a77734b0d79f4b32433c3 +size 4954644 diff --git a/src/processed/9ac5dfd2-1477-4903-adfc-1cc4d0351977/segments/039cf8da-75b8-474d-affa-fc84066c3fa3/wavs/039cf8da-75b8-474d-affa-fc84066c3fa3_seg0.wav b/src/processed/9ac5dfd2-1477-4903-adfc-1cc4d0351977/segments/039cf8da-75b8-474d-affa-fc84066c3fa3/wavs/039cf8da-75b8-474d-affa-fc84066c3fa3_seg0.wav new file mode 100644 index 0000000000000000000000000000000000000000..b37399c9c0e32265c257b45de375886e3575b315 --- /dev/null +++ b/src/processed/9ac5dfd2-1477-4903-adfc-1cc4d0351977/segments/039cf8da-75b8-474d-affa-fc84066c3fa3/wavs/039cf8da-75b8-474d-affa-fc84066c3fa3_seg0.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:312b3f6534a92d88ea1d1fdbe12ded45c29360d602a74864d2787329b4dbeddd +size 774616 diff --git a/src/processed/9ac5dfd2-1477-4903-adfc-1cc4d0351977/segments/039cf8da-75b8-474d-affa-fc84066c3fa3/wavs/039cf8da-75b8-474d-affa-fc84066c3fa3_seg1.wav b/src/processed/9ac5dfd2-1477-4903-adfc-1cc4d0351977/segments/039cf8da-75b8-474d-affa-fc84066c3fa3/wavs/039cf8da-75b8-474d-affa-fc84066c3fa3_seg1.wav new file mode 100644 index 0000000000000000000000000000000000000000..672d03296103f189a67dd160e96b3cfac0286e4f --- /dev/null +++ b/src/processed/9ac5dfd2-1477-4903-adfc-1cc4d0351977/segments/039cf8da-75b8-474d-affa-fc84066c3fa3/wavs/039cf8da-75b8-474d-affa-fc84066c3fa3_seg1.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dbff348a94a69ed3875f4a89e589b542da96af938248748de0e62b416fe76aa4 +size 774704 diff --git a/src/processed/9ac5dfd2-1477-4903-adfc-1cc4d0351977/segments/039cf8da-75b8-474d-affa-fc84066c3fa3/wavs/039cf8da-75b8-474d-affa-fc84066c3fa3_seg2.wav b/src/processed/9ac5dfd2-1477-4903-adfc-1cc4d0351977/segments/039cf8da-75b8-474d-affa-fc84066c3fa3/wavs/039cf8da-75b8-474d-affa-fc84066c3fa3_seg2.wav new file mode 100644 index 0000000000000000000000000000000000000000..56d8eb1cd7f90757817fad46c33762f877c89f2c --- /dev/null +++ b/src/processed/9ac5dfd2-1477-4903-adfc-1cc4d0351977/segments/039cf8da-75b8-474d-affa-fc84066c3fa3/wavs/039cf8da-75b8-474d-affa-fc84066c3fa3_seg2.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dcca412270a4dd74f260396a72772cdfbcdc557a9c66cf20b9bfec2b350778f1 +size 774704 diff --git a/src/processed/9ac5dfd2-1477-4903-adfc-1cc4d0351977/segments/039cf8da-75b8-474d-affa-fc84066c3fa3/wavs/039cf8da-75b8-474d-affa-fc84066c3fa3_seg3.wav b/src/processed/9ac5dfd2-1477-4903-adfc-1cc4d0351977/segments/039cf8da-75b8-474d-affa-fc84066c3fa3/wavs/039cf8da-75b8-474d-affa-fc84066c3fa3_seg3.wav new file mode 100644 index 0000000000000000000000000000000000000000..6a9e645ba5c1bb16b4c937f44bfcb10d8d7b1180 --- /dev/null +++ b/src/processed/9ac5dfd2-1477-4903-adfc-1cc4d0351977/segments/039cf8da-75b8-474d-affa-fc84066c3fa3/wavs/039cf8da-75b8-474d-affa-fc84066c3fa3_seg3.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:d0ee46e6a9974651b2f2f5c350f7ea2a3e3590b5197394292ad31e51b1fc4bca +size 774618 diff --git a/src/resources/framework.jpg b/src/resources/framework.jpg new file mode 100644 index 0000000000000000000000000000000000000000..26fb4b9b0fbdcc22183b18a5ed21e24651e6ae0b Binary files /dev/null and b/src/resources/framework.jpg differ diff --git a/src/resources/lepton.jpg b/src/resources/lepton.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5bd5601481f38589f2eb2954396911bd0aae74ad Binary files /dev/null and b/src/resources/lepton.jpg differ diff --git a/src/resources/myshell.jpg b/src/resources/myshell.jpg new file mode 100644 index 0000000000000000000000000000000000000000..501d7ab6b02fe714ab25e9c0d954d2a0684268fa Binary files /dev/null and b/src/resources/myshell.jpg differ diff --git a/src/resources/openvoicelogo.jpg b/src/resources/openvoicelogo.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1bc9b9e38bae8ee38e998f5136a1e7a9ed967e80 Binary files /dev/null and b/src/resources/openvoicelogo.jpg differ diff --git a/src/rp_handler.py b/src/rp_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..bb73d95a2ac0f2d50119f04aa44d98bf6273f384 --- /dev/null +++ b/src/rp_handler.py @@ -0,0 +1,33 @@ +""" +rp_handler.py for runpod worker + +rp_debugger: +- Utility that provides additional debugging information. +The handler must be called with --rp_debugger flag to enable it. +""" +import base64 +import tempfile + +from rp_schema import INPUT_VALIDATIONS +from runpod.serverless.utils import download_files_from_urls, rp_cleanup, rp_debugger +from runpod.serverless.utils.rp_validator import validate +import runpod +import predict + +MODEL = predict.Predictor() +MODEL.setup() + + +@rp_debugger.FunctionTimer +def run_voice_clone_job(job): + job_input = job['input'] + method_type = job_input['method_type'] + assert method_type in ["create_voice","voice_clone","voice_clone_with_emotions","voice_clone_with_multi_lang"] + s3_url = job_input['s3_url'] + passage = job_input['passage'] + processed_urls = MODEL.predict(s3_url,passage) + + return processed_urls + + +runpod.serverless.start({"handler": run_voice_clone_job}) \ No newline at end of file diff --git a/src/rp_schema.py b/src/rp_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..a6f16f8eac0c563fbb0997cf8dc477a1d5f959ac --- /dev/null +++ b/src/rp_schema.py @@ -0,0 +1,18 @@ +INPUT_VALIDATIONS = { + 'method_type': { + 'type': str, + 'required': True, + 'default': 'voice_clone' + }, + 's3_url': { + 'type': str, + 'required': False, + 'default': 'None' + }, + 'passage': { + 'type': str, + 'required': False, + 'default': 'None' + }, + +} \ No newline at end of file diff --git a/src/se_extractor.py b/src/se_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..16fc679003d2fadfb4c21635720936b07b56d29b --- /dev/null +++ b/src/se_extractor.py @@ -0,0 +1,153 @@ +import os +import glob +import torch +from glob import glob +import numpy as np +from pydub import AudioSegment +from faster_whisper import WhisperModel +from whisper_timestamped.transcribe import get_audio_tensor, get_vad_segments + +model_size = "medium" +# Run on GPU with FP16 +model = None +def split_audio_whisper(audio_path, target_dir='processed'): + global model + if model is None: + model = WhisperModel(model_size, device="cuda", compute_type="float16") + audio = AudioSegment.from_file(audio_path) + max_len = len(audio) + + audio_name = 
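# Sketch of the job payload that rp_handler.py (above) receives from the RunPod queue;
# the s3_url below is a made-up placeholder. The object key is taken from the end of
# s3_url and downloaded from the 'demovidelyuseruploads' bucket, and the handler returns
# the public URL of the generated mp3 in the 'demovidelyusergenerations' bucket.
example_job = {
    "input": {
        "method_type": "voice_clone",
        "s3_url": "https://demovidelyuseruploads.s3.amazonaws.com/reference-recording.wav",
        "passage": "Text that should be spoken in the cloned voice.",
    }
}
# processed_urls = run_voice_clone_job(example_job)
# -> {"voice_clone_1": "https://demovidelyusergenerations.s3.amazonaws.com/<gen_id>-voice-clone-1.mp3"}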
os.path.basename(audio_path).rsplit('.', 1)[0] + target_folder = os.path.join(target_dir, audio_name) + + segments, info = model.transcribe(audio_path, beam_size=5, word_timestamps=True) + segments = list(segments) + + # create directory + os.makedirs(target_folder, exist_ok=True) + wavs_folder = os.path.join(target_folder, 'wavs') + os.makedirs(wavs_folder, exist_ok=True) + + # segments + s_ind = 0 + start_time = None + + for k, w in enumerate(segments): + # process with the time + if k == 0: + start_time = max(0, w.start) + + end_time = w.end + + # calculate confidence + if len(w.words) > 0: + confidence = sum([s.probability for s in w.words]) / len(w.words) + else: + confidence = 0. + # clean text + text = w.text.replace('...', '') + + # left 0.08s for each audios + audio_seg = audio[int( start_time * 1000) : min(max_len, int(end_time * 1000) + 80)] + + # segment file name + fname = f"{audio_name}_seg{s_ind}.wav" + + # filter out the segment shorter than 1.5s and longer than 20s + save = audio_seg.duration_seconds > 1.5 and \ + audio_seg.duration_seconds < 20. and \ + len(text) >= 2 and len(text) < 200 + + if save: + output_file = os.path.join(wavs_folder, fname) + audio_seg.export(output_file, format='wav') + + if k < len(segments) - 1: + start_time = max(0, segments[k+1].start - 0.08) + + s_ind = s_ind + 1 + return wavs_folder + + +def split_audio_vad(audio_path, target_dir, split_seconds=10.0): + SAMPLE_RATE = 16000 + audio_vad = get_audio_tensor(audio_path) + segments = get_vad_segments( + audio_vad, + output_sample=True, + min_speech_duration=0.1, + min_silence_duration=1, + method="silero", + ) + segments = [(seg["start"], seg["end"]) for seg in segments] + segments = [(float(s) / SAMPLE_RATE, float(e) / SAMPLE_RATE) for s,e in segments] + print(segments) + audio_active = AudioSegment.silent(duration=0) + audio = AudioSegment.from_file(audio_path) + + for start_time, end_time in segments: + audio_active += audio[int( start_time * 1000) : int(end_time * 1000)] + + audio_dur = audio_active.duration_seconds + print(f'after vad: dur = {audio_dur}') + audio_name = os.path.basename(audio_path).rsplit('.', 1)[0] + target_folder = os.path.join(target_dir, audio_name) + wavs_folder = os.path.join(target_folder, 'wavs') + os.makedirs(wavs_folder, exist_ok=True) + start_time = 0. 
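# The voiced audio concatenated above is now cut into num_splits roughly equal chunks of
# about split_seconds each (e.g. 32 s of speech with split_seconds=10.0 gives 3 chunks of
# ~10.7 s) before each chunk is exported as its own wav segment.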
+ count = 0 + num_splits = int(np.round(audio_dur / split_seconds)) + assert num_splits > 0, 'input audio is too short' + interval = audio_dur / num_splits + + for i in range(num_splits): + end_time = min(start_time + interval, audio_dur) + if i == num_splits - 1: + end_time = audio_dur + output_file = f"{wavs_folder}/{audio_name}_seg{count}.wav" + audio_seg = audio_active[int(start_time * 1000): int(end_time * 1000)] + audio_seg.export(output_file, format='wav') + start_time = end_time + count += 1 + return wavs_folder + + + + + +def get_se(audio_path, vc_model, target_dir='processed', vad=True): + device = vc_model.device + + audio_name = os.path.basename(audio_path).rsplit('.', 1)[0] + se_path = os.path.join(target_dir, audio_name, 'se.pth') + + if os.path.isfile(se_path): + se = torch.load(se_path).to(device) + return se, audio_name + if os.path.isdir(audio_path): + wavs_folder = audio_path + elif vad: + wavs_folder = split_audio_vad(audio_path, target_dir) + else: + wavs_folder = split_audio_whisper(audio_path, target_dir) + + audio_segs = glob(f'{wavs_folder}/*.wav') + if len(audio_segs) == 0: + raise NotImplementedError('No audio segments found!') + + return vc_model.extract_se(audio_segs, se_save_path=se_path), audio_name + + + +def generate_voice_segments(audio_path, target_dir='processed', vad=True): + audio_name = os.path.basename(audio_path).rsplit('.', 1)[0] + + if vad: + wavs_folder = split_audio_vad(audio_path, target_dir) + else: + wavs_folder = split_audio_whisper(audio_path, target_dir) + + audio_segs = glob(f'{wavs_folder}/*.wav') + if len(audio_segs) == 0: + raise NotImplementedError('No audio segments found!') + diff --git a/src/styletts2importable.py b/src/styletts2importable.py new file mode 100644 index 0000000000000000000000000000000000000000..e5ed57e41efac2b73cbc43e277b99b5bfc9cf669 --- /dev/null +++ b/src/styletts2importable.py @@ -0,0 +1,362 @@ +# print("GRUUT") +# from gruut_phonemize import gphonemize + +# from dp.phonemizer import Phonemizer +print("NLTK") +import nltk +nltk.download('punkt') +print("SCIPY") +from scipy.io.wavfile import write +print("TORCH STUFF") +import torch +print("START") +torch.manual_seed(0) +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True + +import random +random.seed(0) + +import numpy as np +np.random.seed(0) + +# load packages +import time +import random +import yaml +from munch import Munch +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +import torchaudio +import librosa +from nltk.tokenize import word_tokenize + +from models import * +from utils import * +from text_utils import TextCleaner +textclenaer = TextCleaner() + + +to_mel = torchaudio.transforms.MelSpectrogram( + n_mels=80, n_fft=2048, win_length=1200, hop_length=300) +mean, std = -4, 4 + +def length_to_mask(lengths): + mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) + mask = torch.gt(mask+1, lengths.unsqueeze(1)) + return mask + +def preprocess(wave): + wave_tensor = torch.from_numpy(wave).float() + mel_tensor = to_mel(wave_tensor) + mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std + return mel_tensor + +def compute_style(path): + wave, sr = librosa.load(path, sr=24000) + audio, index = librosa.effects.trim(wave, top_db=30) + if sr != 24000: + audio = librosa.resample(audio, sr, 24000) + mel_tensor = preprocess(audio).to(device) + + with torch.no_grad(): + ref_s = model.style_encoder(mel_tensor.unsqueeze(1)) + ref_p = 
model.predictor_encoder(mel_tensor.unsqueeze(1)) + + return torch.cat([ref_s, ref_p], dim=1) + +device = 'cpu' +if torch.cuda.is_available(): + device = 'cuda' +elif torch.backends.mps.is_available(): + print("MPS would be available but cannot be used rn") + # device = 'mps' + +import phonemizer +global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True) +# phonemizer = Phonemizer.from_checkpoint(str(cached_path('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt'))) + + +# config = yaml.safe_load(open("Models/LibriTTS/config.yml")) +config = yaml.safe_load(open("Configs/hg.yml")) + +# load pretrained ASR model +ASR_config = config.get('ASR_config', False) +ASR_path = config.get('ASR_path', False) +text_aligner = load_ASR_models(ASR_path, ASR_config) + +# load pretrained F0 model +F0_path = config.get('F0_path', False) +pitch_extractor = load_F0_models(F0_path) + +# load BERT model +from Utils.PLBERT.util import load_plbert +BERT_path = config.get('PLBERT_dir', False) +plbert = load_plbert(BERT_path) + +model_params = recursive_munch(config['model_params']) +model = build_model(model_params, text_aligner, pitch_extractor, plbert) +_ = [model[key].eval() for key in model] +_ = [model[key].to(device) for key in model] + +# params_whole = torch.load("Models/LibriTTS/epochs_2nd_00020.pth", map_location='cpu') +params_whole = torch.load("Models/epochs_2nd_00020.pth", map_location='cpu') +params = params_whole['net'] + +for key in model: + if key in params: + print('%s loaded' % key) + try: + model[key].load_state_dict(params[key]) + except: + from collections import OrderedDict + state_dict = params[key] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[7:] # remove `module.` + new_state_dict[name] = v + # load params + model[key].load_state_dict(new_state_dict, strict=False) +# except: +# _load(params[key], model[key]) +_ = [model[key].eval() for key in model] + +from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule + +sampler = DiffusionSampler( + model.diffusion.diffusion, + sampler=ADPM2Sampler(), + sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters + clamp=False +) + +def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False): + text = text.strip() + ps = global_phonemizer.phonemize([text]) + ps = word_tokenize(ps[0]) + ps = ' '.join(ps) + tokens = textclenaer(ps) + tokens.insert(0, 0) + tokens = torch.LongTensor(tokens).to(device).unsqueeze(0) + + with torch.no_grad(): + input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device) + text_mask = length_to_mask(input_lengths).to(device) + + t_en = model.text_encoder(tokens, input_lengths, text_mask) + bert_dur = model.bert(tokens, attention_mask=(~text_mask).int()) + d_en = model.bert_encoder(bert_dur).transpose(-1, -2) + + s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device), + embedding=bert_dur, + embedding_scale=embedding_scale, + features=ref_s, # reference from the same speaker as the embedding + num_steps=diffusion_steps).squeeze(1) + + + s = s_pred[:, 128:] + ref = s_pred[:, :128] + + ref = alpha * ref + (1 - alpha) * ref_s[:, :128] + s = beta * s + (1 - beta) * ref_s[:, 128:] + + d = model.predictor.text_encoder(d_en, + s, input_lengths, text_mask) + + x, _ = model.predictor.lstm(d) + duration = model.predictor.duration_proj(x) + + duration = 
torch.sigmoid(duration).sum(axis=-1) + pred_dur = torch.round(duration.squeeze()).clamp(min=1) + + + pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data)) + c_frame = 0 + for i in range(pred_aln_trg.size(0)): + pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1 + c_frame += int(pred_dur[i].data) + + # encode prosody + en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)) + if model_params.decoder.type == "hifigan": + asr_new = torch.zeros_like(en) + asr_new[:, :, 0] = en[:, :, 0] + asr_new[:, :, 1:] = en[:, :, 0:-1] + en = asr_new + + F0_pred, N_pred = model.predictor.F0Ntrain(en, s) + + asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device)) + if model_params.decoder.type == "hifigan": + asr_new = torch.zeros_like(asr) + asr_new[:, :, 0] = asr[:, :, 0] + asr_new[:, :, 1:] = asr[:, :, 0:-1] + asr = asr_new + + out = model.decoder(asr, + F0_pred, N_pred, ref.squeeze().unsqueeze(0)) + + + return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later + +def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False): + text = text.strip() + ps = global_phonemizer.phonemize([text]) + ps = word_tokenize(ps[0]) + ps = ' '.join(ps) + ps = ps.replace('``', '"') + ps = ps.replace("''", '"') + + tokens = textclenaer(ps) + tokens.insert(0, 0) + tokens = torch.LongTensor(tokens).to(device).unsqueeze(0) + + with torch.no_grad(): + input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device) + text_mask = length_to_mask(input_lengths).to(device) + + t_en = model.text_encoder(tokens, input_lengths, text_mask) + bert_dur = model.bert(tokens, attention_mask=(~text_mask).int()) + d_en = model.bert_encoder(bert_dur).transpose(-1, -2) + + s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device), + embedding=bert_dur, + embedding_scale=embedding_scale, + features=ref_s, # reference from the same speaker as the embedding + num_steps=diffusion_steps).squeeze(1) + + if s_prev is not None: + # convex combination of previous and current style + s_pred = t * s_prev + (1 - t) * s_pred + + s = s_pred[:, 128:] + ref = s_pred[:, :128] + + ref = alpha * ref + (1 - alpha) * ref_s[:, :128] + s = beta * s + (1 - beta) * ref_s[:, 128:] + + s_pred = torch.cat([ref, s], dim=-1) + + d = model.predictor.text_encoder(d_en, + s, input_lengths, text_mask) + + x, _ = model.predictor.lstm(d) + duration = model.predictor.duration_proj(x) + + duration = torch.sigmoid(duration).sum(axis=-1) + pred_dur = torch.round(duration.squeeze()).clamp(min=1) + + + pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data)) + c_frame = 0 + for i in range(pred_aln_trg.size(0)): + pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1 + c_frame += int(pred_dur[i].data) + + # encode prosody + en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)) + if model_params.decoder.type == "hifigan": + asr_new = torch.zeros_like(en) + asr_new[:, :, 0] = en[:, :, 0] + asr_new[:, :, 1:] = en[:, :, 0:-1] + en = asr_new + + F0_pred, N_pred = model.predictor.F0Ntrain(en, s) + + asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device)) + if model_params.decoder.type == "hifigan": + asr_new = torch.zeros_like(asr) + asr_new[:, :, 0] = asr[:, :, 0] + asr_new[:, :, 1:] = asr[:, :, 0:-1] + asr = asr_new + + out = model.decoder(asr, + F0_pred, N_pred, ref.squeeze().unsqueeze(0)) + + + return out.squeeze().cpu().numpy()[..., :-100], s_pred # weird pulse at the end of the model, need to be fixed later 
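# A minimal sketch of how these module-level helpers might be called once Configs/hg.yml
# and the checkpoint under Models/ are in place. The reference clip path is the one used
# in src/mytest.py and may need to point at your own recording; output filenames are
# arbitrary.
if __name__ == "__main__":
    import soundfile as sf

    ref_s = compute_style("resources/example_reference.mp3")    # 256-dim style vector from a reference clip
    wav = inference("Hello from StyleTTS2.", ref_s,
                    alpha=0.3, beta=0.7, diffusion_steps=5, embedding_scale=1)
    sf.write("sample.wav", wav, 24000)                           # the model operates at 24 kHz

    # For longer passages, LFinference threads the previous style through s_prev so that
    # consecutive sentences keep a consistent prosody:
    s_prev, chunks = None, []
    for sent in ["First sentence of a longer passage.", "And the sentence after it."]:
        wav, s_prev = LFinference(sent, s_prev, ref_s, alpha=0.3, beta=0.7, t=0.7,
                                  diffusion_steps=5, embedding_scale=1)
        chunks.append(wav)
    sf.write("passage.wav", np.concatenate(chunks), 24000)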
+ +def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False): + text = text.strip() + ps = global_phonemizer.phonemize([text]) + ps = word_tokenize(ps[0]) + ps = ' '.join(ps) + + tokens = textclenaer(ps) + tokens.insert(0, 0) + tokens = torch.LongTensor(tokens).to(device).unsqueeze(0) + + ref_text = ref_text.strip() + ps = global_phonemizer.phonemize([ref_text]) + ps = word_tokenize(ps[0]) + ps = ' '.join(ps) + + ref_tokens = textclenaer(ps) + ref_tokens.insert(0, 0) + ref_tokens = torch.LongTensor(ref_tokens).to(device).unsqueeze(0) + + + with torch.no_grad(): + input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device) + text_mask = length_to_mask(input_lengths).to(device) + + t_en = model.text_encoder(tokens, input_lengths, text_mask) + bert_dur = model.bert(tokens, attention_mask=(~text_mask).int()) + d_en = model.bert_encoder(bert_dur).transpose(-1, -2) + + ref_input_lengths = torch.LongTensor([ref_tokens.shape[-1]]).to(device) + ref_text_mask = length_to_mask(ref_input_lengths).to(device) + ref_bert_dur = model.bert(ref_tokens, attention_mask=(~ref_text_mask).int()) + s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device), + embedding=bert_dur, + embedding_scale=embedding_scale, + features=ref_s, # reference from the same speaker as the embedding + num_steps=diffusion_steps).squeeze(1) + + + s = s_pred[:, 128:] + ref = s_pred[:, :128] + + ref = alpha * ref + (1 - alpha) * ref_s[:, :128] + s = beta * s + (1 - beta) * ref_s[:, 128:] + + d = model.predictor.text_encoder(d_en, + s, input_lengths, text_mask) + + x, _ = model.predictor.lstm(d) + duration = model.predictor.duration_proj(x) + + duration = torch.sigmoid(duration).sum(axis=-1) + pred_dur = torch.round(duration.squeeze()).clamp(min=1) + + + pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data)) + c_frame = 0 + for i in range(pred_aln_trg.size(0)): + pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1 + c_frame += int(pred_dur[i].data) + + # encode prosody + en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)) + if model_params.decoder.type == "hifigan": + asr_new = torch.zeros_like(en) + asr_new[:, :, 0] = en[:, :, 0] + asr_new[:, :, 1:] = en[:, :, 0:-1] + en = asr_new + + F0_pred, N_pred = model.predictor.F0Ntrain(en, s) + + asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device)) + if model_params.decoder.type == "hifigan": + asr_new = torch.zeros_like(asr) + asr_new[:, :, 0] = asr[:, :, 0] + asr_new[:, :, 1:] = asr[:, :, 0:-1] + asr = asr_new + + out = model.decoder(asr, + F0_pred, N_pred, ref.squeeze().unsqueeze(0)) + + + return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later \ No newline at end of file diff --git a/src/test.py b/src/test.py new file mode 100644 index 0000000000000000000000000000000000000000..6d6b8bbd36e3532bdc6093d2c1e78523fd6c457a --- /dev/null +++ b/src/test.py @@ -0,0 +1,14 @@ +from gruut import sentences + + +text = input("> ") +phonemes = '' +for sent in sentences(text, lang="en-us"): + for word in sent: + if word.phonemes: + print(word.text + ":" + ''.join(word.phonemes)) + phonemes += ''.join(word.phonemes) + else: + print(word.text + ": NoPhonemes") +print("--") +print(phonemes) diff --git a/src/text/__init__.py b/src/text/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1d51bec341abde9a898bb617711edae4de9ec11e --- /dev/null +++ b/src/text/__init__.py @@ -0,0 +1,79 @@ +""" from 
https://github.com/keithito/tacotron """ +from text import cleaners +from text.symbols import symbols + + +# Mappings from symbol to numeric ID and vice versa: +_symbol_to_id = {s: i for i, s in enumerate(symbols)} +_id_to_symbol = {i: s for i, s in enumerate(symbols)} + + +def text_to_sequence(text, symbols, cleaner_names): + '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + Args: + text: string to convert to a sequence + cleaner_names: names of the cleaner functions to run the text through + Returns: + List of integers corresponding to the symbols in the text + ''' + sequence = [] + symbol_to_id = {s: i for i, s in enumerate(symbols)} + clean_text = _clean_text(text, cleaner_names) + print(clean_text) + print(f" length:{len(clean_text)}") + for symbol in clean_text: + if symbol not in symbol_to_id.keys(): + continue + symbol_id = symbol_to_id[symbol] + sequence += [symbol_id] + print(f" length:{len(sequence)}") + return sequence + + +def cleaned_text_to_sequence(cleaned_text, symbols): + '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + Args: + text: string to convert to a sequence + Returns: + List of integers corresponding to the symbols in the text + ''' + symbol_to_id = {s: i for i, s in enumerate(symbols)} + sequence = [symbol_to_id[symbol] for symbol in cleaned_text if symbol in symbol_to_id.keys()] + return sequence + + + +from text.symbols import language_tone_start_map +def cleaned_text_to_sequence_vits2(cleaned_text, tones, language, symbols, languages): + """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + Args: + text: string to convert to a sequence + Returns: + List of integers corresponding to the symbols in the text + """ + symbol_to_id = {s: i for i, s in enumerate(symbols)} + language_id_map = {s: i for i, s in enumerate(languages)} + phones = [symbol_to_id[symbol] for symbol in cleaned_text] + tone_start = language_tone_start_map[language] + tones = [i + tone_start for i in tones] + lang_id = language_id_map[language] + lang_ids = [lang_id for i in phones] + return phones, tones, lang_ids + + +def sequence_to_text(sequence): + '''Converts a sequence of IDs back to a string''' + result = '' + for symbol_id in sequence: + s = _id_to_symbol[symbol_id] + result += s + return result + + +def _clean_text(text, cleaner_names): + for name in cleaner_names: + cleaner = getattr(cleaners, name) + if not cleaner: + raise Exception('Unknown cleaner: %s' % name) + text = cleaner(text) + return text diff --git a/src/text/__pycache__/__init__.cpython-310.pyc b/src/text/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c57bc53b14d07c93b77e0cbde794c5ef95cd38d6 Binary files /dev/null and b/src/text/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/text/__pycache__/cleaners.cpython-310.pyc b/src/text/__pycache__/cleaners.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58b557d75a2bbd14018eed69322e94b6ea72c85b Binary files /dev/null and b/src/text/__pycache__/cleaners.cpython-310.pyc differ diff --git a/src/text/__pycache__/english.cpython-310.pyc b/src/text/__pycache__/english.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0634adfeef5e38883e0337ee8badc0f2d537184 Binary files /dev/null and b/src/text/__pycache__/english.cpython-310.pyc differ diff --git a/src/text/__pycache__/mandarin.cpython-310.pyc 
b/src/text/__pycache__/mandarin.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1a2ca6519ab0573ab4fd6dfa0a7987b8cf5d091 Binary files /dev/null and b/src/text/__pycache__/mandarin.cpython-310.pyc differ diff --git a/src/text/__pycache__/symbols.cpython-310.pyc b/src/text/__pycache__/symbols.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31c09a38e64e9149ebe2575f69401ef81a5a1444 Binary files /dev/null and b/src/text/__pycache__/symbols.cpython-310.pyc differ diff --git a/src/text/cleaners.py b/src/text/cleaners.py new file mode 100644 index 0000000000000000000000000000000000000000..619ad47956fe3bbf3f76f09164768307645366b7 --- /dev/null +++ b/src/text/cleaners.py @@ -0,0 +1,16 @@ +import re +from text.english import english_to_lazy_ipa, english_to_ipa2, english_to_lazy_ipa2 +from text.mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo, chinese_to_romaji, chinese_to_lazy_ipa, chinese_to_ipa, chinese_to_ipa2 + +def cjke_cleaners2(text): + text = re.sub(r'\[ZH\](.*?)\[ZH\]', + lambda x: chinese_to_ipa(x.group(1))+' ', text) + text = re.sub(r'\[JA\](.*?)\[JA\]', + lambda x: japanese_to_ipa2(x.group(1))+' ', text) + text = re.sub(r'\[KO\](.*?)\[KO\]', + lambda x: korean_to_ipa(x.group(1))+' ', text) + text = re.sub(r'\[EN\](.*?)\[EN\]', + lambda x: english_to_ipa2(x.group(1))+' ', text) + text = re.sub(r'\s+$', '', text) + text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) + return text \ No newline at end of file diff --git a/src/text/english.py b/src/text/english.py new file mode 100644 index 0000000000000000000000000000000000000000..736a53a7bc66cfdd776aa1fa01439f1e6e46f1c9 --- /dev/null +++ b/src/text/english.py @@ -0,0 +1,188 @@ +""" from https://github.com/keithito/tacotron """ + +''' +Cleaners are transformations that run over the input text at both training and eval time. + +Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" +hyperparameter. Some cleaners are English-specific. You'll typically want to use: + 1. "english_cleaners" for English text + 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using + the Unidecode library (https://pypi.python.org/pypi/Unidecode) + 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update + the symbols in symbols.py to match your data). +''' + + +# Regular expression matching whitespace: + + +import re +import inflect +from unidecode import unidecode +import eng_to_ipa as ipa +_inflect = inflect.engine() +_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') +_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') +_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') +_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') +_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') +_number_re = re.compile(r'[0-9]+') + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ + ('mrs', 'misess'), + ('mr', 'mister'), + ('dr', 'doctor'), + ('st', 'saint'), + ('co', 'company'), + ('jr', 'junior'), + ('maj', 'major'), + ('gen', 'general'), + ('drs', 'doctors'), + ('rev', 'reverend'), + ('lt', 'lieutenant'), + ('hon', 'honorable'), + ('sgt', 'sergeant'), + ('capt', 'captain'), + ('esq', 'esquire'), + ('ltd', 'limited'), + ('col', 'colonel'), + ('ft', 'fort'), +]] + + +# List of (ipa, lazy ipa) pairs: +_lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ + ('r', 'ɹ'), + ('æ', 'e'), + ('ɑ', 'a'), + ('ɔ', 'o'), + ('ð', 'z'), + ('θ', 's'), + ('ɛ', 'e'), + ('ɪ', 'i'), + ('ʊ', 'u'), + ('ʒ', 'ʥ'), + ('ʤ', 'ʥ'), + ('ˈ', '↓'), +]] + +# List of (ipa, lazy ipa2) pairs: +_lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ + ('r', 'ɹ'), + ('ð', 'z'), + ('θ', 's'), + ('ʒ', 'ʑ'), + ('ʤ', 'dʑ'), + ('ˈ', '↓'), +]] + +# List of (ipa, ipa2) pairs +_ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ + ('r', 'ɹ'), + ('ʤ', 'dʒ'), + ('ʧ', 'tʃ') +]] + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def collapse_whitespace(text): + return re.sub(r'\s+', ' ', text) + + +def _remove_commas(m): + return m.group(1).replace(',', '') + + +def _expand_decimal_point(m): + return m.group(1).replace('.', ' point ') + + +def _expand_dollars(m): + match = m.group(1) + parts = match.split('.') + if len(parts) > 2: + return match + ' dollars' # Unexpected format + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + return '%s %s' % (dollars, dollar_unit) + elif cents: + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s' % (cents, cent_unit) + else: + return 'zero dollars' + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return 'two thousand' + elif num > 2000 and num < 2010: + return 'two thousand ' + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + ' hundred' + else: + return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') + else: + return _inflect.number_to_words(num, andword='') + + +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r'\1 pounds', text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text + + +def mark_dark_l(text): + return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text) + + +def english_to_ipa(text): + text = unidecode(text).lower() + text = expand_abbreviations(text) + text = normalize_numbers(text) + phonemes = ipa.convert(text) + phonemes = collapse_whitespace(phonemes) + return phonemes + + +def english_to_lazy_ipa(text): + text = english_to_ipa(text) + for regex, replacement in _lazy_ipa: + text = re.sub(regex, replacement, text) + return text + + +def english_to_ipa2(text): + text = english_to_ipa(text) + text = 
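# Worked example of the normalization chain above (output shown for orientation; the
# exact wording comes from inflect's number_to_words):
#   normalize_numbers("I paid $3.50 on the 2nd")
#   -> "I paid three dollars, fifty cents on the second"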
mark_dark_l(text) + for regex, replacement in _ipa_to_ipa2: + text = re.sub(regex, replacement, text) + return text.replace('...', '…') + + +def english_to_lazy_ipa2(text): + text = english_to_ipa(text) + for regex, replacement in _lazy_ipa2: + text = re.sub(regex, replacement, text) + return text diff --git a/src/text/mandarin.py b/src/text/mandarin.py new file mode 100644 index 0000000000000000000000000000000000000000..162e1b912dabec4b448ccd3d00d56306f82ce076 --- /dev/null +++ b/src/text/mandarin.py @@ -0,0 +1,326 @@ +import os +import sys +import re +from pypinyin import lazy_pinyin, BOPOMOFO +import jieba +import cn2an +import logging + + +# List of (Latin alphabet, bopomofo) pairs: +_latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ + ('a', 'ㄟˉ'), + ('b', 'ㄅㄧˋ'), + ('c', 'ㄙㄧˉ'), + ('d', 'ㄉㄧˋ'), + ('e', 'ㄧˋ'), + ('f', 'ㄝˊㄈㄨˋ'), + ('g', 'ㄐㄧˋ'), + ('h', 'ㄝˇㄑㄩˋ'), + ('i', 'ㄞˋ'), + ('j', 'ㄐㄟˋ'), + ('k', 'ㄎㄟˋ'), + ('l', 'ㄝˊㄛˋ'), + ('m', 'ㄝˊㄇㄨˋ'), + ('n', 'ㄣˉ'), + ('o', 'ㄡˉ'), + ('p', 'ㄆㄧˉ'), + ('q', 'ㄎㄧㄡˉ'), + ('r', 'ㄚˋ'), + ('s', 'ㄝˊㄙˋ'), + ('t', 'ㄊㄧˋ'), + ('u', 'ㄧㄡˉ'), + ('v', 'ㄨㄧˉ'), + ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'), + ('x', 'ㄝˉㄎㄨˋㄙˋ'), + ('y', 'ㄨㄞˋ'), + ('z', 'ㄗㄟˋ') +]] + +# List of (bopomofo, romaji) pairs: +_bopomofo_to_romaji = [(re.compile('%s' % x[0]), x[1]) for x in [ + ('ㄅㄛ', 'p⁼wo'), + ('ㄆㄛ', 'pʰwo'), + ('ㄇㄛ', 'mwo'), + ('ㄈㄛ', 'fwo'), + ('ㄅ', 'p⁼'), + ('ㄆ', 'pʰ'), + ('ㄇ', 'm'), + ('ㄈ', 'f'), + ('ㄉ', 't⁼'), + ('ㄊ', 'tʰ'), + ('ㄋ', 'n'), + ('ㄌ', 'l'), + ('ㄍ', 'k⁼'), + ('ㄎ', 'kʰ'), + ('ㄏ', 'h'), + ('ㄐ', 'ʧ⁼'), + ('ㄑ', 'ʧʰ'), + ('ㄒ', 'ʃ'), + ('ㄓ', 'ʦ`⁼'), + ('ㄔ', 'ʦ`ʰ'), + ('ㄕ', 's`'), + ('ㄖ', 'ɹ`'), + ('ㄗ', 'ʦ⁼'), + ('ㄘ', 'ʦʰ'), + ('ㄙ', 's'), + ('ㄚ', 'a'), + ('ㄛ', 'o'), + ('ㄜ', 'ə'), + ('ㄝ', 'e'), + ('ㄞ', 'ai'), + ('ㄟ', 'ei'), + ('ㄠ', 'au'), + ('ㄡ', 'ou'), + ('ㄧㄢ', 'yeNN'), + ('ㄢ', 'aNN'), + ('ㄧㄣ', 'iNN'), + ('ㄣ', 'əNN'), + ('ㄤ', 'aNg'), + ('ㄧㄥ', 'iNg'), + ('ㄨㄥ', 'uNg'), + ('ㄩㄥ', 'yuNg'), + ('ㄥ', 'əNg'), + ('ㄦ', 'əɻ'), + ('ㄧ', 'i'), + ('ㄨ', 'u'), + ('ㄩ', 'ɥ'), + ('ˉ', '→'), + ('ˊ', '↑'), + ('ˇ', '↓↑'), + ('ˋ', '↓'), + ('˙', ''), + (',', ','), + ('。', '.'), + ('!', '!'), + ('?', '?'), + ('—', '-') +]] + +# List of (romaji, ipa) pairs: +_romaji_to_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ + ('ʃy', 'ʃ'), + ('ʧʰy', 'ʧʰ'), + ('ʧ⁼y', 'ʧ⁼'), + ('NN', 'n'), + ('Ng', 'ŋ'), + ('y', 'j'), + ('h', 'x') +]] + +# List of (bopomofo, ipa) pairs: +_bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ + ('ㄅㄛ', 'p⁼wo'), + ('ㄆㄛ', 'pʰwo'), + ('ㄇㄛ', 'mwo'), + ('ㄈㄛ', 'fwo'), + ('ㄅ', 'p⁼'), + ('ㄆ', 'pʰ'), + ('ㄇ', 'm'), + ('ㄈ', 'f'), + ('ㄉ', 't⁼'), + ('ㄊ', 'tʰ'), + ('ㄋ', 'n'), + ('ㄌ', 'l'), + ('ㄍ', 'k⁼'), + ('ㄎ', 'kʰ'), + ('ㄏ', 'x'), + ('ㄐ', 'tʃ⁼'), + ('ㄑ', 'tʃʰ'), + ('ㄒ', 'ʃ'), + ('ㄓ', 'ts`⁼'), + ('ㄔ', 'ts`ʰ'), + ('ㄕ', 's`'), + ('ㄖ', 'ɹ`'), + ('ㄗ', 'ts⁼'), + ('ㄘ', 'tsʰ'), + ('ㄙ', 's'), + ('ㄚ', 'a'), + ('ㄛ', 'o'), + ('ㄜ', 'ə'), + ('ㄝ', 'ɛ'), + ('ㄞ', 'aɪ'), + ('ㄟ', 'eɪ'), + ('ㄠ', 'ɑʊ'), + ('ㄡ', 'oʊ'), + ('ㄧㄢ', 'jɛn'), + ('ㄩㄢ', 'ɥæn'), + ('ㄢ', 'an'), + ('ㄧㄣ', 'in'), + ('ㄩㄣ', 'ɥn'), + ('ㄣ', 'ən'), + ('ㄤ', 'ɑŋ'), + ('ㄧㄥ', 'iŋ'), + ('ㄨㄥ', 'ʊŋ'), + ('ㄩㄥ', 'jʊŋ'), + ('ㄥ', 'əŋ'), + ('ㄦ', 'əɻ'), + ('ㄧ', 'i'), + ('ㄨ', 'u'), + ('ㄩ', 'ɥ'), + ('ˉ', '→'), + ('ˊ', '↑'), + ('ˇ', '↓↑'), + ('ˋ', '↓'), + ('˙', ''), + (',', ','), + ('。', '.'), + ('!', '!'), + ('?', '?'), + ('—', '-') +]] + +# List of (bopomofo, ipa2) pairs: +_bopomofo_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ + ('ㄅㄛ', 'pwo'), + ('ㄆㄛ', 'pʰwo'), + ('ㄇㄛ', 'mwo'), + ('ㄈㄛ', 'fwo'), + ('ㄅ', 'p'), + ('ㄆ', 
'pʰ'), + ('ㄇ', 'm'), + ('ㄈ', 'f'), + ('ㄉ', 't'), + ('ㄊ', 'tʰ'), + ('ㄋ', 'n'), + ('ㄌ', 'l'), + ('ㄍ', 'k'), + ('ㄎ', 'kʰ'), + ('ㄏ', 'h'), + ('ㄐ', 'tɕ'), + ('ㄑ', 'tɕʰ'), + ('ㄒ', 'ɕ'), + ('ㄓ', 'tʂ'), + ('ㄔ', 'tʂʰ'), + ('ㄕ', 'ʂ'), + ('ㄖ', 'ɻ'), + ('ㄗ', 'ts'), + ('ㄘ', 'tsʰ'), + ('ㄙ', 's'), + ('ㄚ', 'a'), + ('ㄛ', 'o'), + ('ㄜ', 'ɤ'), + ('ㄝ', 'ɛ'), + ('ㄞ', 'aɪ'), + ('ㄟ', 'eɪ'), + ('ㄠ', 'ɑʊ'), + ('ㄡ', 'oʊ'), + ('ㄧㄢ', 'jɛn'), + ('ㄩㄢ', 'yæn'), + ('ㄢ', 'an'), + ('ㄧㄣ', 'in'), + ('ㄩㄣ', 'yn'), + ('ㄣ', 'ən'), + ('ㄤ', 'ɑŋ'), + ('ㄧㄥ', 'iŋ'), + ('ㄨㄥ', 'ʊŋ'), + ('ㄩㄥ', 'jʊŋ'), + ('ㄥ', 'ɤŋ'), + ('ㄦ', 'əɻ'), + ('ㄧ', 'i'), + ('ㄨ', 'u'), + ('ㄩ', 'y'), + ('ˉ', '˥'), + ('ˊ', '˧˥'), + ('ˇ', '˨˩˦'), + ('ˋ', '˥˩'), + ('˙', ''), + (',', ','), + ('。', '.'), + ('!', '!'), + ('?', '?'), + ('—', '-') +]] + + +def number_to_chinese(text): + numbers = re.findall(r'\d+(?:\.?\d+)?', text) + for number in numbers: + text = text.replace(number, cn2an.an2cn(number), 1) + return text + + +def chinese_to_bopomofo(text): + text = text.replace('、', ',').replace(';', ',').replace(':', ',') + words = jieba.lcut(text, cut_all=False) + text = '' + for word in words: + bopomofos = lazy_pinyin(word, BOPOMOFO) + if not re.search('[\u4e00-\u9fff]', word): + text += word + continue + for i in range(len(bopomofos)): + bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\1ˉ', bopomofos[i]) + if text != '': + text += ' ' + text += ''.join(bopomofos) + return text + + +def latin_to_bopomofo(text): + for regex, replacement in _latin_to_bopomofo: + text = re.sub(regex, replacement, text) + return text + + +def bopomofo_to_romaji(text): + for regex, replacement in _bopomofo_to_romaji: + text = re.sub(regex, replacement, text) + return text + + +def bopomofo_to_ipa(text): + for regex, replacement in _bopomofo_to_ipa: + text = re.sub(regex, replacement, text) + return text + + +def bopomofo_to_ipa2(text): + for regex, replacement in _bopomofo_to_ipa2: + text = re.sub(regex, replacement, text) + return text + + +def chinese_to_romaji(text): + text = number_to_chinese(text) + text = chinese_to_bopomofo(text) + text = latin_to_bopomofo(text) + text = bopomofo_to_romaji(text) + text = re.sub('i([aoe])', r'y\1', text) + text = re.sub('u([aoəe])', r'w\1', text) + text = re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)', + r'\1ɹ`\2', text).replace('ɻ', 'ɹ`') + text = re.sub('([ʦs][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text) + return text + + +def chinese_to_lazy_ipa(text): + text = chinese_to_romaji(text) + for regex, replacement in _romaji_to_ipa: + text = re.sub(regex, replacement, text) + return text + + +def chinese_to_ipa(text): + text = number_to_chinese(text) + text = chinese_to_bopomofo(text) + text = latin_to_bopomofo(text) + text = bopomofo_to_ipa(text) + text = re.sub('i([aoe])', r'j\1', text) + text = re.sub('u([aoəe])', r'w\1', text) + text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)', + r'\1ɹ`\2', text).replace('ɻ', 'ɹ`') + text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text) + return text + + +def chinese_to_ipa2(text): + text = number_to_chinese(text) + text = chinese_to_bopomofo(text) + text = latin_to_bopomofo(text) + text = bopomofo_to_ipa2(text) + text = re.sub(r'i([aoe])', r'j\1', text) + text = re.sub(r'u([aoəe])', r'w\1', text) + text = re.sub(r'([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)', r'\1ʅ\2', text) + text = re.sub(r'(sʰ?)([˩˨˧˦˥ ]+|$)', r'\1ɿ\2', text) + return text diff --git a/src/text/symbols.py b/src/text/symbols.py new file mode 100644 index 0000000000000000000000000000000000000000..1231728d35b1f76b9da3f81a60fc46649c91501e --- /dev/null +++ b/src/text/symbols.py @@ -0,0 +1,88 @@ +''' 
+Defines the set of symbols used in text input to the model. +''' + +# japanese_cleaners +# _pad = '_' +# _punctuation = ',.!?-' +# _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ ' + + +'''# japanese_cleaners2 +_pad = '_' +_punctuation = ',.!?-~…' +_letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ ' +''' + + +'''# korean_cleaners +_pad = '_' +_punctuation = ',.!?…~' +_letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ ' +''' + +'''# chinese_cleaners +_pad = '_' +_punctuation = ',。!?—…' +_letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ ' +''' + +# # zh_ja_mixture_cleaners +# _pad = '_' +# _punctuation = ',.!?-~…' +# _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ ' + + +'''# sanskrit_cleaners +_pad = '_' +_punctuation = '।' +_letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ ' +''' + +'''# cjks_cleaners +_pad = '_' +_punctuation = ',.!?-~…' +_letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ ' +''' + +'''# thai_cleaners +_pad = '_' +_punctuation = '.!? ' +_letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์' +''' + +# # cjke_cleaners2 +_pad = '_' +_punctuation = ',.!?-~…' +_letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ ' + + +'''# shanghainese_cleaners +_pad = '_' +_punctuation = ',.!?…' +_letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 ' +''' + +'''# chinese_dialect_cleaners +_pad = '_' +_punctuation = ',.!?~…─' +_letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚ᴀᴇ↑↓∅ⱼ ' +''' + +# Export all symbols: +symbols = [_pad] + list(_punctuation) + list(_letters) + +# Special symbol ids +SPACE_ID = symbols.index(" ") + +num_ja_tones = 1 +num_kr_tones = 1 +num_zh_tones = 6 +num_en_tones = 4 + +language_tone_start_map = { + "ZH": 0, + "JP": num_zh_tones, + "EN": num_zh_tones + num_ja_tones, + 'KR': num_zh_tones + num_ja_tones + num_en_tones, +} \ No newline at end of file diff --git a/src/text_utils.py b/src/text_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5c4bb9929c1165bb3928a6224f4040dd4636ec90 --- /dev/null +++ b/src/text_utils.py @@ -0,0 +1,28 @@ +# IPA Phonemizer: https://github.com/bootphon/phonemizer + +_pad = "$" +_punctuation = ';:,.!?¡¿—…"«»“” ' +_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" + +# Export all symbols: +symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + +dicts = {} +for i in range(len((symbols))): + dicts[symbols[i]] = i + + +class TextCleaner: + def __init__(self, dummy=None): + self.word_index_dictionary = dicts + print(len(dicts)) + + def __call__(self, text): + indexes = [] + for char in text: + try: + indexes.append(self.word_index_dictionary[char]) + except KeyError: + print(text) + return indexes diff --git a/src/train_finetune.py b/src/train_finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..b204994adf1dd3cdf4c22d6ab843b6cf01c73720 --- /dev/null +++ b/src/train_finetune.py @@ -0,0 +1,839 @@ +# load packages +import random +import yaml +import time +from munch import Munch +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +import torchaudio +import librosa +import click +import shutil +import warnings + +warnings.simplefilter("ignore") +from torch.utils.tensorboard import SummaryWriter + +from meldataset import build_dataloader + +from Utils.ASR.models import 
ASRCNN +from Utils.JDC.model import JDCNet +from Utils.PLBERT.util import load_plbert + +from models import * +from losses import * +from utils import * + +from Modules.slmadv import SLMAdversarialLoss +from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule + +from optimizers import build_optimizer + + +# simple fix for dataparallel that allows access to class attributes +class MyDataParallel(torch.nn.DataParallel): + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.module, name) + + +import logging +from logging import StreamHandler + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) +handler = StreamHandler() +handler.setLevel(logging.DEBUG) +logger.addHandler(handler) + + +@click.command() +@click.option("-p", "--config_path", default="Configs/config_ft.yml", type=str) +def main(config_path): + config = yaml.safe_load(open(config_path)) + + log_dir = config["log_dir"] + if not osp.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path))) + writer = SummaryWriter(log_dir + "/tensorboard") + + # write logs + file_handler = logging.FileHandler(osp.join(log_dir, "train.log")) + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter( + logging.Formatter("%(levelname)s:%(asctime)s: %(message)s") + ) + logger.addHandler(file_handler) + + batch_size = config.get("batch_size", 10) + + epochs = config.get("epochs", 200) + save_freq = config.get("save_freq", 2) + log_interval = config.get("log_interval", 10) + saving_epoch = config.get("save_freq", 2) + + data_params = config.get("data_params", None) + sr = config["preprocess_params"].get("sr", 24000) + train_path = data_params["train_data"] + val_path = data_params["val_data"] + root_path = data_params["root_path"] + min_length = data_params["min_length"] + OOD_data = data_params["OOD_data"] + + max_len = config.get("max_len", 200) + + loss_params = Munch(config["loss_params"]) + diff_epoch = loss_params.diff_epoch + joint_epoch = loss_params.joint_epoch + + optimizer_params = Munch(config["optimizer_params"]) + + train_list, val_list = get_data_path_list(train_path, val_path) + device = "cuda" + + train_dataloader = build_dataloader( + train_list, + root_path, + OOD_data=OOD_data, + min_length=min_length, + batch_size=batch_size, + num_workers=2, + dataset_config={}, + device=device, + ) + + val_dataloader = build_dataloader( + val_list, + root_path, + OOD_data=OOD_data, + min_length=min_length, + batch_size=batch_size, + validation=True, + num_workers=0, + device=device, + dataset_config={}, + ) + + # load pretrained ASR model + ASR_config = config.get("ASR_config", False) + ASR_path = config.get("ASR_path", False) + text_aligner = load_ASR_models(ASR_path, ASR_config) + + # load pretrained F0 model + F0_path = config.get("F0_path", False) + pitch_extractor = load_F0_models(F0_path) + + # load PL-BERT model + BERT_path = config.get("PLBERT_dir", False) + plbert = load_plbert(BERT_path) + + # build model + model_params = recursive_munch(config["model_params"]) + multispeaker = model_params.multispeaker + model = build_model(model_params, text_aligner, pitch_extractor, plbert) + _ = [model[key].to(device) for key in model] + + # DP + for key in model: + if key != "mpd" and key != "msd" and key != "wd": + model[key] = MyDataParallel(model[key]) + + start_epoch = 0 + iters = 0 + + load_pretrained = config.get("pretrained_model", "") != "" and 
config.get( + "second_stage_load_pretrained", False + ) + + if not load_pretrained: + if config.get("first_stage_path", "") != "": + first_stage_path = osp.join( + log_dir, config.get("first_stage_path", "first_stage.pth") + ) + print("Loading the first stage model at %s ..." % first_stage_path) + model, _, start_epoch, iters = load_checkpoint( + model, + None, + first_stage_path, + load_only_params=True, + ignore_modules=[ + "bert", + "bert_encoder", + "predictor", + "predictor_encoder", + "msd", + "mpd", + "wd", + "diffusion", + ], + ) # keep starting epoch for tensorboard log + + # these epochs should be counted from the start epoch + diff_epoch += start_epoch + joint_epoch += start_epoch + epochs += start_epoch + + model.predictor_encoder = copy.deepcopy(model.style_encoder) + else: + raise ValueError("You need to specify the path to the first stage model.") + + gl = GeneratorLoss(model.mpd, model.msd).to(device) + dl = DiscriminatorLoss(model.mpd, model.msd).to(device) + wl = WavLMLoss(model_params.slm.model, model.wd, sr, model_params.slm.sr).to(device) + + gl = MyDataParallel(gl) + dl = MyDataParallel(dl) + wl = MyDataParallel(wl) + + sampler = DiffusionSampler( + model.diffusion.diffusion, + sampler=ADPM2Sampler(), + sigma_schedule=KarrasSchedule( + sigma_min=0.0001, sigma_max=3.0, rho=9.0 + ), # empirical parameters + clamp=False, + ) + + scheduler_params = { + "max_lr": optimizer_params.lr, + "pct_start": float(0), + "epochs": epochs, + "steps_per_epoch": len(train_dataloader), + } + scheduler_params_dict = {key: scheduler_params.copy() for key in model} + scheduler_params_dict["bert"]["max_lr"] = optimizer_params.bert_lr * 2 + scheduler_params_dict["decoder"]["max_lr"] = optimizer_params.ft_lr * 2 + scheduler_params_dict["style_encoder"]["max_lr"] = optimizer_params.ft_lr * 2 + + optimizer = build_optimizer( + {key: model[key].parameters() for key in model}, + scheduler_params_dict=scheduler_params_dict, + lr=optimizer_params.lr, + ) + + # adjust BERT learning rate + for g in optimizer.optimizers["bert"].param_groups: + g["betas"] = (0.9, 0.99) + g["lr"] = optimizer_params.bert_lr + g["initial_lr"] = optimizer_params.bert_lr + g["min_lr"] = 0 + g["weight_decay"] = 0.01 + + # adjust acoustic module learning rate + for module in ["decoder", "style_encoder"]: + for g in optimizer.optimizers[module].param_groups: + g["betas"] = (0.0, 0.99) + g["lr"] = optimizer_params.ft_lr + g["initial_lr"] = optimizer_params.ft_lr + g["min_lr"] = 0 + g["weight_decay"] = 1e-4 + + # load models if there is a model + if load_pretrained: + model, optimizer, start_epoch, iters = load_checkpoint( + model, + optimizer, + config["pretrained_model"], + load_only_params=config.get("load_only_params", True), + ) + + n_down = model.text_aligner.n_down + + best_loss = float("inf") # best test loss + loss_train_record = list([]) + loss_test_record = list([]) + iters = 0 + + criterion = nn.L1Loss() # F0 loss (regression) + torch.cuda.empty_cache() + + stft_loss = MultiResolutionSTFTLoss().to(device) + + print("BERT", optimizer.optimizers["bert"]) + print("decoder", optimizer.optimizers["decoder"]) + + start_ds = False + + running_std = [] + + slmadv_params = Munch(config["slmadv_params"]) + slmadv = SLMAdversarialLoss( + model, + wl, + sampler, + slmadv_params.min_len, + slmadv_params.max_len, + batch_percentage=slmadv_params.batch_percentage, + skip_update=slmadv_params.iter, + sig=slmadv_params.sig, + ) + + for epoch in range(start_epoch, epochs): + running_loss = 0 + start_time = time.time() + + _ = 
[model[key].eval() for key in model] + + model.text_aligner.train() + model.text_encoder.train() + + model.predictor.train() + model.bert_encoder.train() + model.bert.train() + model.msd.train() + model.mpd.train() + + for i, batch in enumerate(train_dataloader): + waves = batch[0] + batch = [b.to(device) for b in batch[1:]] + ( + texts, + input_lengths, + ref_texts, + ref_lengths, + mels, + mel_input_length, + ref_mels, + ) = batch + with torch.no_grad(): + mask = length_to_mask(mel_input_length // (2**n_down)).to(device) + mel_mask = length_to_mask(mel_input_length).to(device) + text_mask = length_to_mask(input_lengths).to(texts.device) + + # compute reference styles + if multispeaker and epoch >= diff_epoch: + ref_ss = model.style_encoder(ref_mels.unsqueeze(1)) + ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1)) + ref = torch.cat([ref_ss, ref_sp], dim=1) + + try: + ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts) + s2s_attn = s2s_attn.transpose(-1, -2) + s2s_attn = s2s_attn[..., 1:] + s2s_attn = s2s_attn.transpose(-1, -2) + except: + continue + + mask_ST = mask_from_lens( + s2s_attn, input_lengths, mel_input_length // (2**n_down) + ) + s2s_attn_mono = maximum_path(s2s_attn, mask_ST) + + # encode + t_en = model.text_encoder(texts, input_lengths, text_mask) + + # 50% of chance of using monotonic version + if bool(random.getrandbits(1)): + asr = t_en @ s2s_attn + else: + asr = t_en @ s2s_attn_mono + + d_gt = s2s_attn_mono.sum(axis=-1).detach() + + # compute the style of the entire utterance + # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool) + ss = [] + gs = [] + for bib in range(len(mel_input_length)): + mel_length = int(mel_input_length[bib].item()) + mel = mels[bib, :, : mel_input_length[bib]] + s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1)) + ss.append(s) + s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1)) + gs.append(s) + + s_dur = torch.stack(ss).squeeze() # global prosodic styles + gs = torch.stack(gs).squeeze() # global acoustic styles + s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser + + bert_dur = model.bert(texts, attention_mask=(~text_mask).int()) + d_en = model.bert_encoder(bert_dur).transpose(-1, -2) + + # denoiser training + if epoch >= diff_epoch: + num_steps = np.random.randint(3, 5) + + if model_params.diffusion.dist.estimate_sigma_data: + model.diffusion.module.diffusion.sigma_data = ( + s_trg.std(axis=-1).mean().item() + ) # batch-wise std estimation + running_std.append(model.diffusion.module.diffusion.sigma_data) + + if multispeaker: + s_preds = sampler( + noise=torch.randn_like(s_trg).unsqueeze(1).to(device), + embedding=bert_dur, + embedding_scale=1, + features=ref, # reference from the same speaker as the embedding + embedding_mask_proba=0.1, + num_steps=num_steps, + ).squeeze(1) + loss_diff = model.diffusion( + s_trg.unsqueeze(1), embedding=bert_dur, features=ref + ).mean() # EDM loss + loss_sty = F.l1_loss( + s_preds, s_trg.detach() + ) # style reconstruction loss + else: + s_preds = sampler( + noise=torch.randn_like(s_trg).unsqueeze(1).to(device), + embedding=bert_dur, + embedding_scale=1, + embedding_mask_proba=0.1, + num_steps=num_steps, + ).squeeze(1) + loss_diff = model.diffusion.module.diffusion( + s_trg.unsqueeze(1), embedding=bert_dur + ).mean() # EDM loss + loss_sty = F.l1_loss( + s_preds, s_trg.detach() + ) # style reconstruction loss + else: + loss_sty = 0 + loss_diff = 0 + + s_loss = 0 + + d, p = model.predictor(d_en, s_dur, 
input_lengths, s2s_attn_mono, text_mask) + + mel_len_st = int(mel_input_length.min().item() / 2 - 1) + mel_len = min(int(mel_input_length.min().item() / 2 - 1), max_len // 2) + en = [] + gt = [] + p_en = [] + wav = [] + st = [] + + for bib in range(len(mel_input_length)): + mel_length = int(mel_input_length[bib].item() / 2) + + random_start = np.random.randint(0, mel_length - mel_len) + en.append(asr[bib, :, random_start : random_start + mel_len]) + p_en.append(p[bib, :, random_start : random_start + mel_len]) + gt.append( + mels[bib, :, (random_start * 2) : ((random_start + mel_len) * 2)] + ) + + y = waves[bib][ + (random_start * 2) * 300 : ((random_start + mel_len) * 2) * 300 + ] + wav.append(torch.from_numpy(y).to(device)) + + # style reference (better to be different from the GT) + random_start = np.random.randint(0, mel_length - mel_len_st) + st.append( + mels[bib, :, (random_start * 2) : ((random_start + mel_len_st) * 2)] + ) + + wav = torch.stack(wav).float().detach() + + en = torch.stack(en) + p_en = torch.stack(p_en) + gt = torch.stack(gt).detach() + st = torch.stack(st).detach() + + if gt.size(-1) < 80: + continue + + s = model.style_encoder(gt.unsqueeze(1)) + s_dur = model.predictor_encoder(gt.unsqueeze(1)) + + with torch.no_grad(): + F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1)) + F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze() + + N_real = log_norm(gt.unsqueeze(1)).squeeze(1) + + y_rec_gt = wav.unsqueeze(1) + y_rec_gt_pred = model.decoder(en, F0_real, N_real, s) + + wav = y_rec_gt + + F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur) + + y_rec = model.decoder(en, F0_fake, N_fake, s) + + loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10 + loss_norm_rec = F.smooth_l1_loss(N_real, N_fake) + + optimizer.zero_grad() + d_loss = dl(wav.detach(), y_rec.detach()).mean() + d_loss.backward() + optimizer.step("msd") + optimizer.step("mpd") + + # generator loss + optimizer.zero_grad() + + loss_mel = stft_loss(y_rec, wav) + loss_gen_all = gl(wav, y_rec).mean() + loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean() + + loss_ce = 0 + loss_dur = 0 + for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths): + _s2s_pred = _s2s_pred[:_text_length, :] + _text_input = _text_input[:_text_length].long() + _s2s_trg = torch.zeros_like(_s2s_pred) + for p in range(_s2s_trg.shape[0]): + _s2s_trg[p, : _text_input[p]] = 1 + _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1) + + loss_dur += F.l1_loss( + _dur_pred[1 : _text_length - 1], _text_input[1 : _text_length - 1] + ) + loss_ce += F.binary_cross_entropy_with_logits( + _s2s_pred.flatten(), _s2s_trg.flatten() + ) + + loss_ce /= texts.size(0) + loss_dur /= texts.size(0) + + loss_s2s = 0 + for _s2s_pred, _text_input, _text_length in zip( + s2s_pred, texts, input_lengths + ): + loss_s2s += F.cross_entropy( + _s2s_pred[:_text_length], _text_input[:_text_length] + ) + loss_s2s /= texts.size(0) + + loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10 + + g_loss = ( + loss_params.lambda_mel * loss_mel + + loss_params.lambda_F0 * loss_F0_rec + + loss_params.lambda_ce * loss_ce + + loss_params.lambda_norm * loss_norm_rec + + loss_params.lambda_dur * loss_dur + + loss_params.lambda_gen * loss_gen_all + + loss_params.lambda_slm * loss_lm + + loss_params.lambda_sty * loss_sty + + loss_params.lambda_diff * loss_diff + + loss_params.lambda_mono * loss_mono + + loss_params.lambda_s2s * loss_s2s + ) + + running_loss += loss_mel.item() + g_loss.backward() + if torch.isnan(g_loss): + from 
IPython.core.debugger import set_trace + + set_trace() + + optimizer.step("bert_encoder") + optimizer.step("bert") + optimizer.step("predictor") + optimizer.step("predictor_encoder") + optimizer.step("style_encoder") + optimizer.step("decoder") + + optimizer.step("text_encoder") + optimizer.step("text_aligner") + + if epoch >= diff_epoch: + optimizer.step("diffusion") + + if epoch >= joint_epoch: + # randomly pick whether to use in-distribution text + if np.random.rand() < 0.5: + use_ind = True + else: + use_ind = False + + if use_ind: + ref_lengths = input_lengths + ref_texts = texts + + slm_out = slmadv( + i, + y_rec_gt, + y_rec_gt_pred, + waves, + mel_input_length, + ref_texts, + ref_lengths, + use_ind, + s_trg.detach(), + ref if multispeaker else None, + ) + + if slm_out is None: + continue + + d_loss_slm, loss_gen_lm, y_pred = slm_out + + # SLM discriminator loss + if d_loss_slm != 0: + optimizer.zero_grad() + d_loss_slm.backward() + optimizer.step("wd") + + # SLM generator loss + optimizer.zero_grad() + loss_gen_lm.backward() + + # compute the gradient norm + total_norm = {} + for key in model.keys(): + total_norm[key] = 0 + parameters = [ + p + for p in model[key].parameters() + if p.grad is not None and p.requires_grad + ] + for p in parameters: + param_norm = p.grad.detach().data.norm(2) + total_norm[key] += param_norm.item() ** 2 + total_norm[key] = total_norm[key] ** 0.5 + + # gradient scaling + if total_norm["predictor"] > slmadv_params.thresh: + for key in model.keys(): + for p in model[key].parameters(): + if p.grad is not None: + p.grad *= 1 / total_norm["predictor"] + + for p in model.predictor.duration_proj.parameters(): + if p.grad is not None: + p.grad *= slmadv_params.scale + + for p in model.predictor.lstm.parameters(): + if p.grad is not None: + p.grad *= slmadv_params.scale + + for p in model.diffusion.parameters(): + if p.grad is not None: + p.grad *= slmadv_params.scale + + optimizer.step("bert_encoder") + optimizer.step("bert") + optimizer.step("predictor") + optimizer.step("diffusion") + + else: + d_loss_slm, loss_gen_lm = 0, 0 + + iters = iters + 1 + + if (i + 1) % log_interval == 0: + logger.info( + "Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f, SLoss: %.5f, S2S Loss: %.5f, Mono Loss: %.5f" + % ( + epoch + 1, + epochs, + i + 1, + len(train_list) // batch_size, + running_loss / log_interval, + d_loss, + loss_dur, + loss_ce, + loss_norm_rec, + loss_F0_rec, + loss_lm, + loss_gen_all, + loss_sty, + loss_diff, + d_loss_slm, + loss_gen_lm, + s_loss, + loss_s2s, + loss_mono, + ) + ) + + writer.add_scalar("train/mel_loss", running_loss / log_interval, iters) + writer.add_scalar("train/gen_loss", loss_gen_all, iters) + writer.add_scalar("train/d_loss", d_loss, iters) + writer.add_scalar("train/ce_loss", loss_ce, iters) + writer.add_scalar("train/dur_loss", loss_dur, iters) + writer.add_scalar("train/slm_loss", loss_lm, iters) + writer.add_scalar("train/norm_loss", loss_norm_rec, iters) + writer.add_scalar("train/F0_loss", loss_F0_rec, iters) + writer.add_scalar("train/sty_loss", loss_sty, iters) + writer.add_scalar("train/diff_loss", loss_diff, iters) + writer.add_scalar("train/d_loss_slm", d_loss_slm, iters) + writer.add_scalar("train/gen_loss_slm", loss_gen_lm, iters) + + running_loss = 0 + + print("Time elasped:", time.time() - start_time) + + loss_test = 0 + loss_align = 0 + loss_f = 0 + _ = 
[model[key].eval() for key in model] + + with torch.no_grad(): + iters_test = 0 + for batch_idx, batch in enumerate(val_dataloader): + optimizer.zero_grad() + + try: + waves = batch[0] + batch = [b.to(device) for b in batch[1:]] + ( + texts, + input_lengths, + ref_texts, + ref_lengths, + mels, + mel_input_length, + ref_mels, + ) = batch + with torch.no_grad(): + mask = length_to_mask(mel_input_length // (2**n_down)).to( + "cuda" + ) + text_mask = length_to_mask(input_lengths).to(texts.device) + + _, _, s2s_attn = model.text_aligner(mels, mask, texts) + s2s_attn = s2s_attn.transpose(-1, -2) + s2s_attn = s2s_attn[..., 1:] + s2s_attn = s2s_attn.transpose(-1, -2) + + mask_ST = mask_from_lens( + s2s_attn, input_lengths, mel_input_length // (2**n_down) + ) + s2s_attn_mono = maximum_path(s2s_attn, mask_ST) + + # encode + t_en = model.text_encoder(texts, input_lengths, text_mask) + asr = t_en @ s2s_attn_mono + + d_gt = s2s_attn_mono.sum(axis=-1).detach() + + ss = [] + gs = [] + + for bib in range(len(mel_input_length)): + mel_length = int(mel_input_length[bib].item()) + mel = mels[bib, :, : mel_input_length[bib]] + s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1)) + ss.append(s) + s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1)) + gs.append(s) + + s = torch.stack(ss).squeeze() + gs = torch.stack(gs).squeeze() + s_trg = torch.cat([s, gs], dim=-1).detach() + + bert_dur = model.bert(texts, attention_mask=(~text_mask).int()) + d_en = model.bert_encoder(bert_dur).transpose(-1, -2) + d, p = model.predictor( + d_en, s, input_lengths, s2s_attn_mono, text_mask + ) + # get clips + mel_len = int(mel_input_length.min().item() / 2 - 1) + en = [] + gt = [] + + p_en = [] + wav = [] + + for bib in range(len(mel_input_length)): + mel_length = int(mel_input_length[bib].item() / 2) + + random_start = np.random.randint(0, mel_length - mel_len) + en.append(asr[bib, :, random_start : random_start + mel_len]) + p_en.append(p[bib, :, random_start : random_start + mel_len]) + + gt.append( + mels[ + bib, + :, + (random_start * 2) : ((random_start + mel_len) * 2), + ] + ) + y = waves[bib][ + (random_start * 2) + * 300 : ((random_start + mel_len) * 2) + * 300 + ] + wav.append(torch.from_numpy(y).to(device)) + + wav = torch.stack(wav).float().detach() + + en = torch.stack(en) + p_en = torch.stack(p_en) + gt = torch.stack(gt).detach() + s = model.predictor_encoder(gt.unsqueeze(1)) + + F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s) + + loss_dur = 0 + for _s2s_pred, _text_input, _text_length in zip( + d, (d_gt), input_lengths + ): + _s2s_pred = _s2s_pred[:_text_length, :] + _text_input = _text_input[:_text_length].long() + _s2s_trg = torch.zeros_like(_s2s_pred) + for bib in range(_s2s_trg.shape[0]): + _s2s_trg[bib, : _text_input[bib]] = 1 + _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1) + loss_dur += F.l1_loss( + _dur_pred[1 : _text_length - 1], + _text_input[1 : _text_length - 1], + ) + + loss_dur /= texts.size(0) + + s = model.style_encoder(gt.unsqueeze(1)) + + y_rec = model.decoder(en, F0_fake, N_fake, s) + loss_mel = stft_loss(y_rec.squeeze(), wav.detach()) + + F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1)) + + loss_F0 = F.l1_loss(F0_real, F0_fake) / 10 + + loss_test += (loss_mel).mean() + loss_align += (loss_dur).mean() + loss_f += (loss_F0).mean() + + iters_test += 1 + except: + continue + + print("Epochs:", epoch + 1) + logger.info( + "Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f" + % (loss_test / iters_test, loss_align / iters_test, loss_f / iters_test) + + "\n\n\n" + ) + 
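# Note: length_to_mask(), used throughout the training and validation loops
# above, is imported from the project's utils module (via `from utils import *`)
# and is not part of this diff. A minimal, self-contained sketch of the assumed
# behaviour -- a boolean padding mask with True on positions past each
# sequence's length, matching the later use of `(~text_mask)` -- could look like:
import torch

def length_to_mask_sketch(lengths: torch.Tensor) -> torch.Tensor:
    # lengths: (batch,) tensor of valid sequence lengths.
    # Returns a (batch, max_len) bool mask, True on padded positions.
    max_len = int(lengths.max().item())
    positions = torch.arange(max_len, device=lengths.device)
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)

# Example: length_to_mask_sketch(torch.tensor([2, 4]))
# -> [[False, False,  True,  True],
#     [False, False, False, False]]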
print("\n\n\n") + writer.add_scalar("eval/mel_loss", loss_test / iters_test, epoch + 1) + writer.add_scalar("eval/dur_loss", loss_test / iters_test, epoch + 1) + writer.add_scalar("eval/F0_loss", loss_f / iters_test, epoch + 1) + + if (epoch + 1) % save_freq == 0: + if (loss_test / iters_test) < best_loss: + best_loss = loss_test / iters_test + print("Saving..") + state = { + "net": {key: model[key].state_dict() for key in model}, + "optimizer": optimizer.state_dict(), + "iters": iters, + "val_loss": loss_test / iters_test, + "epoch": epoch, + } + save_path = osp.join(log_dir, "epoch_2nd_%05d.pth" % epoch) + torch.save(state, save_path) + + # if estimate sigma, save the estimated simga + if model_params.diffusion.dist.estimate_sigma_data: + config["model_params"]["diffusion"]["dist"]["sigma_data"] = float( + np.mean(running_std) + ) + + with open(osp.join(log_dir, osp.basename(config_path)), "w") as outfile: + yaml.dump(config, outfile, default_flow_style=True) + + +if __name__ == "__main__": + main() diff --git a/src/train_first.py b/src/train_first.py new file mode 100644 index 0000000000000000000000000000000000000000..daab4553b56628227e32e7b803e6e3279ce1f75f --- /dev/null +++ b/src/train_first.py @@ -0,0 +1,540 @@ +import os +import os.path as osp +import re +import sys +import yaml +import shutil +import numpy as np +import torch +import click +import warnings + +warnings.simplefilter("ignore") + +# load packages +import random +import yaml +from munch import Munch +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +import torchaudio +import librosa + +from models import * +from meldataset import build_dataloader +from utils import * +from losses import * +from optimizers import build_optimizer +import time + +from accelerate import Accelerator +from accelerate.utils import LoggerType +from accelerate import DistributedDataParallelKwargs + +from torch.utils.tensorboard import SummaryWriter + +import logging +from accelerate.logging import get_logger + +logger = get_logger(__name__, log_level="DEBUG") + + +@click.command() +@click.option("-p", "--config_path", default="Configs/config.yml", type=str) +def main(config_path): + config = yaml.safe_load(open(config_path)) + + log_dir = config["log_dir"] + if not osp.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path))) + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + project_dir=log_dir, split_batches=True, kwargs_handlers=[ddp_kwargs] + ) + if accelerator.is_main_process: + writer = SummaryWriter(log_dir + "/tensorboard") + + # write logs + file_handler = logging.FileHandler(osp.join(log_dir, "train.log")) + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter( + logging.Formatter("%(levelname)s:%(asctime)s: %(message)s") + ) + logger.logger.addHandler(file_handler) + + batch_size = config.get("batch_size", 10) + device = accelerator.device + + epochs = config.get("epochs_1st", 200) + save_freq = config.get("save_freq", 2) + log_interval = config.get("log_interval", 10) + saving_epoch = config.get("save_freq", 2) + + data_params = config.get("data_params", None) + sr = config["preprocess_params"].get("sr", 24000) + train_path = data_params["train_data"] + val_path = data_params["val_data"] + root_path = data_params["root_path"] + min_length = data_params["min_length"] + OOD_data = data_params["OOD_data"] + + max_len = config.get("max_len", 200) + + # load data + 
train_list, val_list = get_data_path_list(train_path, val_path) + + train_dataloader = build_dataloader( + train_list, + root_path, + OOD_data=OOD_data, + min_length=min_length, + batch_size=batch_size, + num_workers=2, + dataset_config={}, + device=device, + ) + + val_dataloader = build_dataloader( + val_list, + root_path, + OOD_data=OOD_data, + min_length=min_length, + batch_size=batch_size, + validation=True, + num_workers=0, + device=device, + dataset_config={}, + ) + + with accelerator.main_process_first(): + # load pretrained ASR model + ASR_config = config.get("ASR_config", False) + ASR_path = config.get("ASR_path", False) + text_aligner = load_ASR_models(ASR_path, ASR_config) + + # load pretrained F0 model + F0_path = config.get("F0_path", False) + pitch_extractor = load_F0_models(F0_path) + + # load BERT model + from Utils.PLBERT.util import load_plbert + + BERT_path = config.get("PLBERT_dir", False) + plbert = load_plbert(BERT_path) + + scheduler_params = { + "max_lr": float(config["optimizer_params"].get("lr", 1e-4)), + "pct_start": float(config["optimizer_params"].get("pct_start", 0.0)), + "epochs": epochs, + "steps_per_epoch": len(train_dataloader), + } + + model_params = recursive_munch(config["model_params"]) + multispeaker = model_params.multispeaker + model = build_model(model_params, text_aligner, pitch_extractor, plbert) + + best_loss = float("inf") # best test loss + loss_train_record = list([]) + loss_test_record = list([]) + + loss_params = Munch(config["loss_params"]) + TMA_epoch = loss_params.TMA_epoch + + for k in model: + model[k] = accelerator.prepare(model[k]) + + train_dataloader, val_dataloader = accelerator.prepare( + train_dataloader, val_dataloader + ) + + _ = [model[key].to(device) for key in model] + + # initialize optimizers after preparing models for compatibility with FSDP + optimizer = build_optimizer( + {key: model[key].parameters() for key in model}, + scheduler_params_dict={key: scheduler_params.copy() for key in model}, + lr=float(config["optimizer_params"].get("lr", 1e-4)), + ) + + for k, v in optimizer.optimizers.items(): + optimizer.optimizers[k] = accelerator.prepare(optimizer.optimizers[k]) + optimizer.schedulers[k] = accelerator.prepare(optimizer.schedulers[k]) + + with accelerator.main_process_first(): + if config.get("pretrained_model", "") != "": + model, optimizer, start_epoch, iters = load_checkpoint( + model, + optimizer, + config["pretrained_model"], + load_only_params=config.get("load_only_params", True), + ) + else: + start_epoch = 0 + iters = 0 + + # in case not distributed + try: + n_down = model.text_aligner.module.n_down + except: + n_down = model.text_aligner.n_down + + # wrapped losses for compatibility with mixed precision + stft_loss = MultiResolutionSTFTLoss().to(device) + gl = GeneratorLoss(model.mpd, model.msd).to(device) + dl = DiscriminatorLoss(model.mpd, model.msd).to(device) + wl = WavLMLoss(model_params.slm.model, model.wd, sr, model_params.slm.sr).to(device) + + for epoch in range(start_epoch, epochs): + running_loss = 0 + start_time = time.time() + + _ = [model[key].train() for key in model] + + for i, batch in enumerate(train_dataloader): + waves = batch[0] + batch = [b.to(device) for b in batch[1:]] + texts, input_lengths, _, _, mels, mel_input_length, _ = batch + + with torch.no_grad(): + mask = length_to_mask(mel_input_length // (2**n_down)).to("cuda") + text_mask = length_to_mask(input_lengths).to(texts.device) + + ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts) + + s2s_attn = 
s2s_attn.transpose(-1, -2) + s2s_attn = s2s_attn[..., 1:] + s2s_attn = s2s_attn.transpose(-1, -2) + + with torch.no_grad(): + attn_mask = ( + (~mask) + .unsqueeze(-1) + .expand(mask.shape[0], mask.shape[1], text_mask.shape[-1]) + .float() + .transpose(-1, -2) + ) + attn_mask = ( + attn_mask.float() + * (~text_mask) + .unsqueeze(-1) + .expand(text_mask.shape[0], text_mask.shape[1], mask.shape[-1]) + .float() + ) + attn_mask = attn_mask < 1 + + s2s_attn.masked_fill_(attn_mask, 0.0) + + with torch.no_grad(): + mask_ST = mask_from_lens( + s2s_attn, input_lengths, mel_input_length // (2**n_down) + ) + s2s_attn_mono = maximum_path(s2s_attn, mask_ST) + + # encode + t_en = model.text_encoder(texts, input_lengths, text_mask) + + # 50% of chance of using monotonic version + if bool(random.getrandbits(1)): + asr = t_en @ s2s_attn + else: + asr = t_en @ s2s_attn_mono + + # get clips + mel_input_length_all = accelerator.gather( + mel_input_length + ) # for balanced load + mel_len = min( + [int(mel_input_length_all.min().item() / 2 - 1), max_len // 2] + ) + mel_len_st = int(mel_input_length.min().item() / 2 - 1) + + en = [] + gt = [] + wav = [] + st = [] + + for bib in range(len(mel_input_length)): + mel_length = int(mel_input_length[bib].item() / 2) + + random_start = np.random.randint(0, mel_length - mel_len) + en.append(asr[bib, :, random_start : random_start + mel_len]) + gt.append( + mels[bib, :, (random_start * 2) : ((random_start + mel_len) * 2)] + ) + + y = waves[bib][ + (random_start * 2) * 300 : ((random_start + mel_len) * 2) * 300 + ] + wav.append(torch.from_numpy(y).to(device)) + + # style reference (better to be different from the GT) + random_start = np.random.randint(0, mel_length - mel_len_st) + st.append( + mels[bib, :, (random_start * 2) : ((random_start + mel_len_st) * 2)] + ) + + en = torch.stack(en) + gt = torch.stack(gt).detach() + st = torch.stack(st).detach() + + wav = torch.stack(wav).float().detach() + + # clip too short to be used by the style encoder + if gt.shape[-1] < 80: + continue + + with torch.no_grad(): + real_norm = log_norm(gt.unsqueeze(1)).squeeze(1).detach() + F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1)) + + s = model.style_encoder( + st.unsqueeze(1) if multispeaker else gt.unsqueeze(1) + ) + + y_rec = model.decoder(en, F0_real, real_norm, s) + + # discriminator loss + + if epoch >= TMA_epoch: + optimizer.zero_grad() + d_loss = dl(wav.detach().unsqueeze(1).float(), y_rec.detach()).mean() + accelerator.backward(d_loss) + optimizer.step("msd") + optimizer.step("mpd") + else: + d_loss = 0 + + # generator loss + optimizer.zero_grad() + loss_mel = stft_loss(y_rec.squeeze(), wav.detach()) + + if epoch >= TMA_epoch: # start TMA training + loss_s2s = 0 + for _s2s_pred, _text_input, _text_length in zip( + s2s_pred, texts, input_lengths + ): + loss_s2s += F.cross_entropy( + _s2s_pred[:_text_length], _text_input[:_text_length] + ) + loss_s2s /= texts.size(0) + + loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10 + + loss_gen_all = gl(wav.detach().unsqueeze(1).float(), y_rec).mean() + loss_slm = wl(wav.detach(), y_rec).mean() + + g_loss = ( + loss_params.lambda_mel * loss_mel + + loss_params.lambda_mono * loss_mono + + loss_params.lambda_s2s * loss_s2s + + loss_params.lambda_gen * loss_gen_all + + loss_params.lambda_slm * loss_slm + ) + + else: + loss_s2s = 0 + loss_mono = 0 + loss_gen_all = 0 + loss_slm = 0 + g_loss = loss_mel + + running_loss += accelerator.gather(loss_mel).mean().item() + + accelerator.backward(g_loss) + + optimizer.step("text_encoder") + 
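# The optimizer object here is produced by build_optimizer() (optimizers.py,
# not shown in this diff) and is stepped per sub-module by key, as in the
# optimizer.step("...") calls around this point. A minimal sketch of the kind
# of wrapper this usage implies -- one torch optimizer per named module -- is
# given below; the real build_optimizer also attaches per-key LR schedulers
# (see the optimizer.optimizers / optimizer.schedulers accesses elsewhere in
# this file). Names and hyper-parameters here are illustrative assumptions.
import torch

class MultiOptimizerSketch:
    def __init__(self, params_by_key: dict, lr: float = 1e-4):
        # One AdamW instance per named parameter group (e.g. "decoder", "bert").
        self.optimizers = {
            key: torch.optim.AdamW(params, lr=lr)
            for key, params in params_by_key.items()
        }

    def step(self, key: str):
        # Update only the requested sub-module.
        self.optimizers[key].step()

    def zero_grad(self):
        # Clear gradients for every sub-module before the next backward pass.
        for opt in self.optimizers.values():
            opt.zero_grad()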
optimizer.step("style_encoder") + optimizer.step("decoder") + + if epoch >= TMA_epoch: + optimizer.step("text_aligner") + optimizer.step("pitch_extractor") + + iters = iters + 1 + + if (i + 1) % log_interval == 0 and accelerator.is_main_process: + log_print( + "Epoch [%d/%d], Step [%d/%d], Mel Loss: %.5f, Gen Loss: %.5f, Disc Loss: %.5f, Mono Loss: %.5f, S2S Loss: %.5f, SLM Loss: %.5f" + % ( + epoch + 1, + epochs, + i + 1, + len(train_list) // batch_size, + running_loss / log_interval, + loss_gen_all, + d_loss, + loss_mono, + loss_s2s, + loss_slm, + ), + logger, + ) + + writer.add_scalar("train/mel_loss", running_loss / log_interval, iters) + writer.add_scalar("train/gen_loss", loss_gen_all, iters) + writer.add_scalar("train/d_loss", d_loss, iters) + writer.add_scalar("train/mono_loss", loss_mono, iters) + writer.add_scalar("train/s2s_loss", loss_s2s, iters) + writer.add_scalar("train/slm_loss", loss_slm, iters) + + running_loss = 0 + + print("Time elasped:", time.time() - start_time) + + loss_test = 0 + + _ = [model[key].eval() for key in model] + + with torch.no_grad(): + iters_test = 0 + for batch_idx, batch in enumerate(val_dataloader): + optimizer.zero_grad() + + waves = batch[0] + batch = [b.to(device) for b in batch[1:]] + texts, input_lengths, _, _, mels, mel_input_length, _ = batch + + with torch.no_grad(): + mask = length_to_mask(mel_input_length // (2**n_down)).to("cuda") + ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts) + + s2s_attn = s2s_attn.transpose(-1, -2) + s2s_attn = s2s_attn[..., 1:] + s2s_attn = s2s_attn.transpose(-1, -2) + + text_mask = length_to_mask(input_lengths).to(texts.device) + attn_mask = ( + (~mask) + .unsqueeze(-1) + .expand(mask.shape[0], mask.shape[1], text_mask.shape[-1]) + .float() + .transpose(-1, -2) + ) + attn_mask = ( + attn_mask.float() + * (~text_mask) + .unsqueeze(-1) + .expand(text_mask.shape[0], text_mask.shape[1], mask.shape[-1]) + .float() + ) + attn_mask = attn_mask < 1 + s2s_attn.masked_fill_(attn_mask, 0.0) + + # encode + t_en = model.text_encoder(texts, input_lengths, text_mask) + + asr = t_en @ s2s_attn + + # get clips + mel_input_length_all = accelerator.gather( + mel_input_length + ) # for balanced load + mel_len = min( + [int(mel_input_length.min().item() / 2 - 1), max_len // 2] + ) + + en = [] + gt = [] + wav = [] + for bib in range(len(mel_input_length)): + mel_length = int(mel_input_length[bib].item() / 2) + + random_start = np.random.randint(0, mel_length - mel_len) + en.append(asr[bib, :, random_start : random_start + mel_len]) + gt.append( + mels[ + bib, :, (random_start * 2) : ((random_start + mel_len) * 2) + ] + ) + y = waves[bib][ + (random_start * 2) * 300 : ((random_start + mel_len) * 2) * 300 + ] + wav.append(torch.from_numpy(y).to("cuda")) + + wav = torch.stack(wav).float().detach() + + en = torch.stack(en) + gt = torch.stack(gt).detach() + + F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1)) + s = model.style_encoder(gt.unsqueeze(1)) + real_norm = log_norm(gt.unsqueeze(1)).squeeze(1) + y_rec = model.decoder(en, F0_real, real_norm, s) + + loss_mel = stft_loss(y_rec.squeeze(), wav.detach()) + + loss_test += accelerator.gather(loss_mel).mean().item() + iters_test += 1 + + if accelerator.is_main_process: + print("Epochs:", epoch + 1) + log_print( + "Validation loss: %.3f" % (loss_test / iters_test) + "\n\n\n\n", logger + ) + print("\n\n\n") + writer.add_scalar("eval/mel_loss", loss_test / iters_test, epoch + 1) + attn_image = get_image(s2s_attn[0].cpu().numpy().squeeze()) + writer.add_figure("eval/attn", 
attn_image, epoch) + + with torch.no_grad(): + for bib in range(len(asr)): + mel_length = int(mel_input_length[bib].item()) + gt = mels[bib, :, :mel_length].unsqueeze(0) + en = asr[bib, :, : mel_length // 2].unsqueeze(0) + + F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1)) + F0_real = F0_real.unsqueeze(0) + s = model.style_encoder(gt.unsqueeze(1)) + real_norm = log_norm(gt.unsqueeze(1)).squeeze(1) + + y_rec = model.decoder(en, F0_real, real_norm, s) + + writer.add_audio( + "eval/y" + str(bib), + y_rec.cpu().numpy().squeeze(), + epoch, + sample_rate=sr, + ) + if epoch == 0: + writer.add_audio( + "gt/y" + str(bib), + waves[bib].squeeze(), + epoch, + sample_rate=sr, + ) + + if bib >= 6: + break + + if epoch % saving_epoch == 0: + if (loss_test / iters_test) < best_loss: + best_loss = loss_test / iters_test + print("Saving..") + state = { + "net": {key: model[key].state_dict() for key in model}, + "optimizer": optimizer.state_dict(), + "iters": iters, + "val_loss": loss_test / iters_test, + "epoch": epoch, + } + save_path = osp.join(log_dir, "epoch_1st_%05d.pth" % epoch) + torch.save(state, save_path) + + if accelerator.is_main_process: + print("Saving..") + state = { + "net": {key: model[key].state_dict() for key in model}, + "optimizer": optimizer.state_dict(), + "iters": iters, + "val_loss": loss_test / iters_test, + "epoch": epoch, + } + save_path = osp.join(log_dir, config.get("first_stage_path", "first_stage.pth")) + torch.save(state, save_path) + + +if __name__ == "__main__": + main() diff --git a/src/train_second.py b/src/train_second.py new file mode 100644 index 0000000000000000000000000000000000000000..4f36b799c21cf63ceb48d7428d76ad1eed473150 --- /dev/null +++ b/src/train_second.py @@ -0,0 +1,958 @@ +# load packages +import random +import yaml +import time +from munch import Munch +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +import torchaudio +import librosa +import click +import shutil +import warnings + +warnings.simplefilter("ignore") +from torch.utils.tensorboard import SummaryWriter + +from meldataset import build_dataloader + +from Utils.ASR.models import ASRCNN +from Utils.JDC.model import JDCNet +from Utils.PLBERT.util import load_plbert + +from models import * +from losses import * +from utils import * + +from Modules.slmadv import SLMAdversarialLoss +from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule + +from optimizers import build_optimizer + + +# simple fix for dataparallel that allows access to class attributes +class MyDataParallel(torch.nn.DataParallel): + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.module, name) + + +import logging +from logging import StreamHandler + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) +handler = StreamHandler() +handler.setLevel(logging.DEBUG) +logger.addHandler(handler) + + +@click.command() +@click.option("-p", "--config_path", default="Configs/config.yml", type=str) +def main(config_path): + config = yaml.safe_load(open(config_path)) + + log_dir = config["log_dir"] + if not osp.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path))) + writer = SummaryWriter(log_dir + "/tensorboard") + + # write logs + file_handler = logging.FileHandler(osp.join(log_dir, "train.log")) + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter( + logging.Formatter("%(levelname)s:%(asctime)s: 
%(message)s") + ) + logger.addHandler(file_handler) + + batch_size = config.get("batch_size", 10) + + epochs = config.get("epochs_2nd", 200) + save_freq = config.get("save_freq", 2) + log_interval = config.get("log_interval", 10) + saving_epoch = config.get("save_freq", 2) + + data_params = config.get("data_params", None) + sr = config["preprocess_params"].get("sr", 24000) + train_path = data_params["train_data"] + val_path = data_params["val_data"] + root_path = data_params["root_path"] + min_length = data_params["min_length"] + OOD_data = data_params["OOD_data"] + + max_len = config.get("max_len", 200) + + loss_params = Munch(config["loss_params"]) + diff_epoch = loss_params.diff_epoch + joint_epoch = loss_params.joint_epoch + + optimizer_params = Munch(config["optimizer_params"]) + + train_list, val_list = get_data_path_list(train_path, val_path) + device = "cuda" + + train_dataloader = build_dataloader( + train_list, + root_path, + OOD_data=OOD_data, + min_length=min_length, + batch_size=batch_size, + num_workers=2, + dataset_config={}, + device=device, + ) + + val_dataloader = build_dataloader( + val_list, + root_path, + OOD_data=OOD_data, + min_length=min_length, + batch_size=batch_size, + validation=True, + num_workers=0, + device=device, + dataset_config={}, + ) + + # load pretrained ASR model + ASR_config = config.get("ASR_config", False) + ASR_path = config.get("ASR_path", False) + text_aligner = load_ASR_models(ASR_path, ASR_config) + + # load pretrained F0 model + F0_path = config.get("F0_path", False) + pitch_extractor = load_F0_models(F0_path) + + # load PL-BERT model + BERT_path = config.get("PLBERT_dir", False) + plbert = load_plbert(BERT_path) + + # build model + model_params = recursive_munch(config["model_params"]) + multispeaker = model_params.multispeaker + model = build_model(model_params, text_aligner, pitch_extractor, plbert) + _ = [model[key].to(device) for key in model] + + # DP + for key in model: + if key != "mpd" and key != "msd" and key != "wd": + model[key] = MyDataParallel(model[key]) + + start_epoch = 0 + iters = 0 + + load_pretrained = config.get("pretrained_model", "") != "" and config.get( + "second_stage_load_pretrained", False + ) + + if not load_pretrained: + if config.get("first_stage_path", "") != "": + first_stage_path = osp.join( + log_dir, config.get("first_stage_path", "first_stage.pth") + ) + print("Loading the first stage model at %s ..." 
% first_stage_path) + model, _, start_epoch, iters = load_checkpoint( + model, + None, + first_stage_path, + load_only_params=True, + ignore_modules=[ + "bert", + "bert_encoder", + "predictor", + "predictor_encoder", + "msd", + "mpd", + "wd", + "diffusion", + ], + ) # keep starting epoch for tensorboard log + + # these epochs should be counted from the start epoch + diff_epoch += start_epoch + joint_epoch += start_epoch + epochs += start_epoch + + model.predictor_encoder = copy.deepcopy(model.style_encoder) + else: + raise ValueError("You need to specify the path to the first stage model.") + + gl = GeneratorLoss(model.mpd, model.msd).to(device) + dl = DiscriminatorLoss(model.mpd, model.msd).to(device) + wl = WavLMLoss(model_params.slm.model, model.wd, sr, model_params.slm.sr).to(device) + + gl = MyDataParallel(gl) + dl = MyDataParallel(dl) + wl = MyDataParallel(wl) + + sampler = DiffusionSampler( + model.diffusion.diffusion, + sampler=ADPM2Sampler(), + sigma_schedule=KarrasSchedule( + sigma_min=0.0001, sigma_max=3.0, rho=9.0 + ), # empirical parameters + clamp=False, + ) + + scheduler_params = { + "max_lr": optimizer_params.lr, + "pct_start": float(0), + "epochs": epochs, + "steps_per_epoch": len(train_dataloader), + } + scheduler_params_dict = {key: scheduler_params.copy() for key in model} + scheduler_params_dict["bert"]["max_lr"] = optimizer_params.bert_lr * 2 + scheduler_params_dict["decoder"]["max_lr"] = optimizer_params.ft_lr * 2 + scheduler_params_dict["style_encoder"]["max_lr"] = optimizer_params.ft_lr * 2 + + optimizer = build_optimizer( + {key: model[key].parameters() for key in model}, + scheduler_params_dict=scheduler_params_dict, + lr=optimizer_params.lr, + ) + + # adjust BERT learning rate + for g in optimizer.optimizers["bert"].param_groups: + g["betas"] = (0.9, 0.99) + g["lr"] = optimizer_params.bert_lr + g["initial_lr"] = optimizer_params.bert_lr + g["min_lr"] = 0 + g["weight_decay"] = 0.01 + + # adjust acoustic module learning rate + for module in ["decoder", "style_encoder"]: + for g in optimizer.optimizers[module].param_groups: + g["betas"] = (0.0, 0.99) + g["lr"] = optimizer_params.ft_lr + g["initial_lr"] = optimizer_params.ft_lr + g["min_lr"] = 0 + g["weight_decay"] = 1e-4 + + # load models if there is a model + if load_pretrained: + model, optimizer, start_epoch, iters = load_checkpoint( + model, + optimizer, + config["pretrained_model"], + load_only_params=config.get("load_only_params", True), + ) + + n_down = model.text_aligner.n_down + + best_loss = float("inf") # best test loss + loss_train_record = list([]) + loss_test_record = list([]) + iters = 0 + + criterion = nn.L1Loss() # F0 loss (regression) + torch.cuda.empty_cache() + + stft_loss = MultiResolutionSTFTLoss().to(device) + + print("BERT", optimizer.optimizers["bert"]) + print("decoder", optimizer.optimizers["decoder"]) + + start_ds = False + + running_std = [] + + slmadv_params = Munch(config["slmadv_params"]) + slmadv = SLMAdversarialLoss( + model, + wl, + sampler, + slmadv_params.min_len, + slmadv_params.max_len, + batch_percentage=slmadv_params.batch_percentage, + skip_update=slmadv_params.iter, + sig=slmadv_params.sig, + ) + + for epoch in range(start_epoch, epochs): + running_loss = 0 + start_time = time.time() + + _ = [model[key].eval() for key in model] + + model.predictor.train() + model.bert_encoder.train() + model.bert.train() + model.msd.train() + model.mpd.train() + + if epoch >= diff_epoch: + start_ds = True + + for i, batch in enumerate(train_dataloader): + waves = batch[0] + batch = 
[b.to(device) for b in batch[1:]] + ( + texts, + input_lengths, + ref_texts, + ref_lengths, + mels, + mel_input_length, + ref_mels, + ) = batch + + with torch.no_grad(): + mask = length_to_mask(mel_input_length // (2**n_down)).to(device) + mel_mask = length_to_mask(mel_input_length).to(device) + text_mask = length_to_mask(input_lengths).to(texts.device) + + try: + _, _, s2s_attn = model.text_aligner(mels, mask, texts) + s2s_attn = s2s_attn.transpose(-1, -2) + s2s_attn = s2s_attn[..., 1:] + s2s_attn = s2s_attn.transpose(-1, -2) + except: + continue + + mask_ST = mask_from_lens( + s2s_attn, input_lengths, mel_input_length // (2**n_down) + ) + s2s_attn_mono = maximum_path(s2s_attn, mask_ST) + + # encode + t_en = model.text_encoder(texts, input_lengths, text_mask) + asr = t_en @ s2s_attn_mono + + d_gt = s2s_attn_mono.sum(axis=-1).detach() + + # compute reference styles + if multispeaker and epoch >= diff_epoch: + ref_ss = model.style_encoder(ref_mels.unsqueeze(1)) + ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1)) + ref = torch.cat([ref_ss, ref_sp], dim=1) + + # compute the style of the entire utterance + # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool) + ss = [] + gs = [] + for bib in range(len(mel_input_length)): + mel_length = int(mel_input_length[bib].item()) + mel = mels[bib, :, : mel_input_length[bib]] + s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1)) + ss.append(s) + s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1)) + gs.append(s) + + s_dur = torch.stack(ss).squeeze() # global prosodic styles + gs = torch.stack(gs).squeeze() # global acoustic styles + s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser + + bert_dur = model.bert(texts, attention_mask=(~text_mask).int()) + d_en = model.bert_encoder(bert_dur).transpose(-1, -2) + + # denoiser training + if epoch >= diff_epoch: + num_steps = np.random.randint(3, 5) + + if model_params.diffusion.dist.estimate_sigma_data: + model.diffusion.module.diffusion.sigma_data = ( + s_trg.std(axis=-1).mean().item() + ) # batch-wise std estimation + running_std.append(model.diffusion.module.diffusion.sigma_data) + + if multispeaker: + s_preds = sampler( + noise=torch.randn_like(s_trg).unsqueeze(1).to(device), + embedding=bert_dur, + embedding_scale=1, + features=ref, # reference from the same speaker as the embedding + embedding_mask_proba=0.1, + num_steps=num_steps, + ).squeeze(1) + loss_diff = model.diffusion( + s_trg.unsqueeze(1), embedding=bert_dur, features=ref + ).mean() # EDM loss + loss_sty = F.l1_loss( + s_preds, s_trg.detach() + ) # style reconstruction loss + else: + s_preds = sampler( + noise=torch.randn_like(s_trg).unsqueeze(1).to(device), + embedding=bert_dur, + embedding_scale=1, + embedding_mask_proba=0.1, + num_steps=num_steps, + ).squeeze(1) + loss_diff = model.diffusion.module.diffusion( + s_trg.unsqueeze(1), embedding=bert_dur + ).mean() # EDM loss + loss_sty = F.l1_loss( + s_preds, s_trg.detach() + ) # style reconstruction loss + else: + loss_sty = 0 + loss_diff = 0 + + d, p = model.predictor(d_en, s_dur, input_lengths, s2s_attn_mono, text_mask) + + mel_len = min(int(mel_input_length.min().item() / 2 - 1), max_len // 2) + mel_len_st = int(mel_input_length.min().item() / 2 - 1) + en = [] + gt = [] + st = [] + p_en = [] + wav = [] + + for bib in range(len(mel_input_length)): + mel_length = int(mel_input_length[bib].item() / 2) + + random_start = np.random.randint(0, mel_length - mel_len) + en.append(asr[bib, :, random_start : 
random_start + mel_len]) + p_en.append(p[bib, :, random_start : random_start + mel_len]) + gt.append( + mels[bib, :, (random_start * 2) : ((random_start + mel_len) * 2)] + ) + + y = waves[bib][ + (random_start * 2) * 300 : ((random_start + mel_len) * 2) * 300 + ] + wav.append(torch.from_numpy(y).to(device)) + + # style reference (better to be different from the GT) + random_start = np.random.randint(0, mel_length - mel_len_st) + st.append( + mels[bib, :, (random_start * 2) : ((random_start + mel_len_st) * 2)] + ) + + wav = torch.stack(wav).float().detach() + + en = torch.stack(en) + p_en = torch.stack(p_en) + gt = torch.stack(gt).detach() + st = torch.stack(st).detach() + + if gt.size(-1) < 80: + continue + + s_dur = model.predictor_encoder( + st.unsqueeze(1) if multispeaker else gt.unsqueeze(1) + ) + s = model.style_encoder( + st.unsqueeze(1) if multispeaker else gt.unsqueeze(1) + ) + + with torch.no_grad(): + F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1)) + F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze() + + asr_real = model.text_aligner.get_feature(gt) + + N_real = log_norm(gt.unsqueeze(1)).squeeze(1) + + y_rec_gt = wav.unsqueeze(1) + y_rec_gt_pred = model.decoder(en, F0_real, N_real, s) + + if epoch >= joint_epoch: + # ground truth from recording + wav = y_rec_gt # use recording since decoder is tuned + else: + # ground truth from reconstruction + wav = y_rec_gt_pred # use reconstruction since decoder is fixed + + F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur) + + y_rec = model.decoder(en, F0_fake, N_fake, s) + + loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10 + loss_norm_rec = F.smooth_l1_loss(N_real, N_fake) + + if start_ds: + optimizer.zero_grad() + d_loss = dl(wav.detach(), y_rec.detach()).mean() + d_loss.backward() + optimizer.step("msd") + optimizer.step("mpd") + else: + d_loss = 0 + + # generator loss + optimizer.zero_grad() + + loss_mel = stft_loss(y_rec, wav) + if start_ds: + loss_gen_all = gl(wav, y_rec).mean() + else: + loss_gen_all = 0 + loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean() + + loss_ce = 0 + loss_dur = 0 + for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths): + _s2s_pred = _s2s_pred[:_text_length, :] + _text_input = _text_input[:_text_length].long() + _s2s_trg = torch.zeros_like(_s2s_pred) + for p in range(_s2s_trg.shape[0]): + _s2s_trg[p, : _text_input[p]] = 1 + _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1) + + loss_dur += F.l1_loss( + _dur_pred[1 : _text_length - 1], _text_input[1 : _text_length - 1] + ) + loss_ce += F.binary_cross_entropy_with_logits( + _s2s_pred.flatten(), _s2s_trg.flatten() + ) + + loss_ce /= texts.size(0) + loss_dur /= texts.size(0) + + g_loss = ( + loss_params.lambda_mel * loss_mel + + loss_params.lambda_F0 * loss_F0_rec + + loss_params.lambda_ce * loss_ce + + loss_params.lambda_norm * loss_norm_rec + + loss_params.lambda_dur * loss_dur + + loss_params.lambda_gen * loss_gen_all + + loss_params.lambda_slm * loss_lm + + loss_params.lambda_sty * loss_sty + + loss_params.lambda_diff * loss_diff + ) + + running_loss += loss_mel.item() + g_loss.backward() + if torch.isnan(g_loss): + from IPython.core.debugger import set_trace + + set_trace() + + optimizer.step("bert_encoder") + optimizer.step("bert") + optimizer.step("predictor") + optimizer.step("predictor_encoder") + + if epoch >= diff_epoch: + optimizer.step("diffusion") + + if epoch >= joint_epoch: + optimizer.step("style_encoder") + optimizer.step("decoder") + + # randomly pick whether to use 
in-distribution text + if np.random.rand() < 0.5: + use_ind = True + else: + use_ind = False + + if use_ind: + ref_lengths = input_lengths + ref_texts = texts + + slm_out = slmadv( + i, + y_rec_gt, + y_rec_gt_pred, + waves, + mel_input_length, + ref_texts, + ref_lengths, + use_ind, + s_trg.detach(), + ref if multispeaker else None, + ) + + if slm_out is None: + continue + + d_loss_slm, loss_gen_lm, y_pred = slm_out + + # SLM generator loss + optimizer.zero_grad() + loss_gen_lm.backward() + + # SLM discriminator loss + if d_loss_slm != 0: + optimizer.zero_grad() + d_loss_slm.backward(retain_graph=True) + optimizer.step("wd") + + # compute the gradient norm + total_norm = {} + for key in model.keys(): + total_norm[key] = 0 + parameters = [ + p + for p in model[key].parameters() + if p.grad is not None and p.requires_grad + ] + for p in parameters: + param_norm = p.grad.detach().data.norm(2) + total_norm[key] += param_norm.item() ** 2 + total_norm[key] = total_norm[key] ** 0.5 + + # gradient scaling + if total_norm["predictor"] > slmadv_params.thresh: + for key in model.keys(): + for p in model[key].parameters(): + if p.grad is not None: + p.grad *= 1 / total_norm["predictor"] + + for p in model.predictor.duration_proj.parameters(): + if p.grad is not None: + p.grad *= slmadv_params.scale + + for p in model.predictor.lstm.parameters(): + if p.grad is not None: + p.grad *= slmadv_params.scale + + for p in model.diffusion.parameters(): + if p.grad is not None: + p.grad *= slmadv_params.scale + + optimizer.step("bert_encoder") + optimizer.step("bert") + optimizer.step("predictor") + optimizer.step("diffusion") + else: + d_loss_slm, loss_gen_lm = 0, 0 + + iters = iters + 1 + + if (i + 1) % log_interval == 0: + logger.info( + "Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f" + % ( + epoch + 1, + epochs, + i + 1, + len(train_list) // batch_size, + running_loss / log_interval, + d_loss, + loss_dur, + loss_ce, + loss_norm_rec, + loss_F0_rec, + loss_lm, + loss_gen_all, + loss_sty, + loss_diff, + d_loss_slm, + loss_gen_lm, + ) + ) + + writer.add_scalar("train/mel_loss", running_loss / log_interval, iters) + writer.add_scalar("train/gen_loss", loss_gen_all, iters) + writer.add_scalar("train/d_loss", d_loss, iters) + writer.add_scalar("train/ce_loss", loss_ce, iters) + writer.add_scalar("train/dur_loss", loss_dur, iters) + writer.add_scalar("train/slm_loss", loss_lm, iters) + writer.add_scalar("train/norm_loss", loss_norm_rec, iters) + writer.add_scalar("train/F0_loss", loss_F0_rec, iters) + writer.add_scalar("train/sty_loss", loss_sty, iters) + writer.add_scalar("train/diff_loss", loss_diff, iters) + writer.add_scalar("train/d_loss_slm", d_loss_slm, iters) + writer.add_scalar("train/gen_loss_slm", loss_gen_lm, iters) + + running_loss = 0 + + print("Time elasped:", time.time() - start_time) + + loss_test = 0 + loss_align = 0 + loss_f = 0 + _ = [model[key].eval() for key in model] + + with torch.no_grad(): + iters_test = 0 + for batch_idx, batch in enumerate(val_dataloader): + optimizer.zero_grad() + + try: + waves = batch[0] + batch = [b.to(device) for b in batch[1:]] + ( + texts, + input_lengths, + ref_texts, + ref_lengths, + mels, + mel_input_length, + ref_mels, + ) = batch + with torch.no_grad(): + mask = length_to_mask(mel_input_length // (2**n_down)).to( + "cuda" + ) + text_mask = length_to_mask(input_lengths).to(texts.device) + 
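# mask_from_lens(), called a few lines below before maximum_path(), comes from
# the alignment utilities (imported via `from utils import *`) and is not shown
# in this diff. A minimal sketch of the assumed semantics -- a
# (batch, text_len, mel_len) boolean mask that is True only where both the text
# token and the downsampled mel frame lie inside the unpadded region:
import torch

def mask_from_lens_sketch(attn: torch.Tensor,
                          in_lens: torch.Tensor,
                          out_lens: torch.Tensor) -> torch.Tensor:
    # attn is only used for its shape; in_lens/out_lens are the valid lengths
    # per example (here: input_lengths and mel_input_length // (2 ** n_down)).
    _, text_len, mel_len = attn.shape
    in_lens = in_lens.to(attn.device)
    out_lens = out_lens.to(attn.device)
    text_valid = torch.arange(text_len, device=attn.device)[None, :] < in_lens[:, None]
    mel_valid = torch.arange(mel_len, device=attn.device)[None, :] < out_lens[:, None]
    return text_valid[:, :, None] & mel_valid[:, None, :]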
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts) + s2s_attn = s2s_attn.transpose(-1, -2) + s2s_attn = s2s_attn[..., 1:] + s2s_attn = s2s_attn.transpose(-1, -2) + + mask_ST = mask_from_lens( + s2s_attn, input_lengths, mel_input_length // (2**n_down) + ) + s2s_attn_mono = maximum_path(s2s_attn, mask_ST) + + # encode + t_en = model.text_encoder(texts, input_lengths, text_mask) + asr = t_en @ s2s_attn_mono + + d_gt = s2s_attn_mono.sum(axis=-1).detach() + + ss = [] + gs = [] + + for bib in range(len(mel_input_length)): + mel_length = int(mel_input_length[bib].item()) + mel = mels[bib, :, : mel_input_length[bib]] + s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1)) + ss.append(s) + s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1)) + gs.append(s) + + s = torch.stack(ss).squeeze() + gs = torch.stack(gs).squeeze() + s_trg = torch.cat([s, gs], dim=-1).detach() + + bert_dur = model.bert(texts, attention_mask=(~text_mask).int()) + d_en = model.bert_encoder(bert_dur).transpose(-1, -2) + d, p = model.predictor( + d_en, s, input_lengths, s2s_attn_mono, text_mask + ) + # get clips + mel_len = int(mel_input_length.min().item() / 2 - 1) + en = [] + gt = [] + p_en = [] + wav = [] + + for bib in range(len(mel_input_length)): + mel_length = int(mel_input_length[bib].item() / 2) + + random_start = np.random.randint(0, mel_length - mel_len) + en.append(asr[bib, :, random_start : random_start + mel_len]) + p_en.append(p[bib, :, random_start : random_start + mel_len]) + + gt.append( + mels[ + bib, + :, + (random_start * 2) : ((random_start + mel_len) * 2), + ] + ) + + y = waves[bib][ + (random_start * 2) + * 300 : ((random_start + mel_len) * 2) + * 300 + ] + wav.append(torch.from_numpy(y).to(device)) + + wav = torch.stack(wav).float().detach() + + en = torch.stack(en) + p_en = torch.stack(p_en) + gt = torch.stack(gt).detach() + + s = model.predictor_encoder(gt.unsqueeze(1)) + + F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s) + + loss_dur = 0 + for _s2s_pred, _text_input, _text_length in zip( + d, (d_gt), input_lengths + ): + _s2s_pred = _s2s_pred[:_text_length, :] + _text_input = _text_input[:_text_length].long() + _s2s_trg = torch.zeros_like(_s2s_pred) + for bib in range(_s2s_trg.shape[0]): + _s2s_trg[bib, : _text_input[bib]] = 1 + _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1) + loss_dur += F.l1_loss( + _dur_pred[1 : _text_length - 1], + _text_input[1 : _text_length - 1], + ) + + loss_dur /= texts.size(0) + + s = model.style_encoder(gt.unsqueeze(1)) + + y_rec = model.decoder(en, F0_fake, N_fake, s) + loss_mel = stft_loss(y_rec.squeeze(), wav.detach()) + + F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1)) + + loss_F0 = F.l1_loss(F0_real, F0_fake) / 10 + + loss_test += (loss_mel).mean() + loss_align += (loss_dur).mean() + loss_f += (loss_F0).mean() + + iters_test += 1 + except: + continue + + print("Epochs:", epoch + 1) + logger.info( + "Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f" + % (loss_test / iters_test, loss_align / iters_test, loss_f / iters_test) + + "\n\n\n" + ) + print("\n\n\n") + writer.add_scalar("eval/mel_loss", loss_test / iters_test, epoch + 1) + writer.add_scalar("eval/dur_loss", loss_test / iters_test, epoch + 1) + writer.add_scalar("eval/F0_loss", loss_f / iters_test, epoch + 1) + + if epoch < joint_epoch: + # generating reconstruction examples with GT duration + + with torch.no_grad(): + for bib in range(len(asr)): + mel_length = int(mel_input_length[bib].item()) + gt = mels[bib, :, :mel_length].unsqueeze(0) + en = asr[bib, :, : mel_length // 
2].unsqueeze(0) + + F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1)) + F0_real = F0_real.unsqueeze(0) + s = model.style_encoder(gt.unsqueeze(1)) + real_norm = log_norm(gt.unsqueeze(1)).squeeze(1) + + y_rec = model.decoder(en, F0_real, real_norm, s) + + writer.add_audio( + "eval/y" + str(bib), + y_rec.cpu().numpy().squeeze(), + epoch, + sample_rate=sr, + ) + + s_dur = model.predictor_encoder(gt.unsqueeze(1)) + p_en = p[bib, :, : mel_length // 2].unsqueeze(0) + + F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur) + + y_pred = model.decoder(en, F0_fake, N_fake, s) + + writer.add_audio( + "pred/y" + str(bib), + y_pred.cpu().numpy().squeeze(), + epoch, + sample_rate=sr, + ) + + if epoch == 0: + writer.add_audio( + "gt/y" + str(bib), + waves[bib].squeeze(), + epoch, + sample_rate=sr, + ) + + if bib >= 5: + break + else: + # generating sampled speech from text directly + with torch.no_grad(): + # compute reference styles + if multispeaker and epoch >= diff_epoch: + ref_ss = model.style_encoder(ref_mels.unsqueeze(1)) + ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1)) + ref_s = torch.cat([ref_ss, ref_sp], dim=1) + + for bib in range(len(d_en)): + if multispeaker: + s_pred = sampler( + noise=torch.randn((1, 256)).unsqueeze(1).to(texts.device), + embedding=bert_dur[bib].unsqueeze(0), + embedding_scale=1, + features=ref_s[bib].unsqueeze( + 0 + ), # reference from the same speaker as the embedding + num_steps=5, + ).squeeze(1) + else: + s_pred = sampler( + noise=torch.randn((1, 256)).unsqueeze(1).to(texts.device), + embedding=bert_dur[bib].unsqueeze(0), + embedding_scale=1, + num_steps=5, + ).squeeze(1) + + s = s_pred[:, 128:] + ref = s_pred[:, :128] + + d = model.predictor.text_encoder( + d_en[bib, :, : input_lengths[bib]].unsqueeze(0), + s, + input_lengths[bib, ...].unsqueeze(0), + text_mask[bib, : input_lengths[bib]].unsqueeze(0), + ) + + x, _ = model.predictor.lstm(d) + duration = model.predictor.duration_proj(x) + + duration = torch.sigmoid(duration).sum(axis=-1) + pred_dur = torch.round(duration.squeeze()).clamp(min=1) + + pred_dur[-1] += 5 + + pred_aln_trg = torch.zeros( + input_lengths[bib], int(pred_dur.sum().data) + ) + c_frame = 0 + for i in range(pred_aln_trg.size(0)): + pred_aln_trg[i, c_frame : c_frame + int(pred_dur[i].data)] = 1 + c_frame += int(pred_dur[i].data) + + # encode prosody + en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to( + texts.device + ) + F0_pred, N_pred = model.predictor.F0Ntrain(en, s) + out = model.decoder( + ( + t_en[bib, :, : input_lengths[bib]].unsqueeze(0) + @ pred_aln_trg.unsqueeze(0).to(texts.device) + ), + F0_pred, + N_pred, + ref.squeeze().unsqueeze(0), + ) + + writer.add_audio( + "pred/y" + str(bib), + out.cpu().numpy().squeeze(), + epoch, + sample_rate=sr, + ) + + if bib >= 5: + break + + if epoch % saving_epoch == 0: + if (loss_test / iters_test) < best_loss: + best_loss = loss_test / iters_test + print("Saving..") + state = { + "net": {key: model[key].state_dict() for key in model}, + "optimizer": optimizer.state_dict(), + "iters": iters, + "val_loss": loss_test / iters_test, + "epoch": epoch, + } + save_path = osp.join(log_dir, "epoch_2nd_%05d.pth" % epoch) + torch.save(state, save_path) + + # if estimate sigma, save the estimated simga + if model_params.diffusion.dist.estimate_sigma_data: + config["model_params"]["diffusion"]["dist"]["sigma_data"] = float( + np.mean(running_std) + ) + + with open(osp.join(log_dir, osp.basename(config_path)), "w") as outfile: + yaml.dump(config, outfile, default_flow_style=True) + + +if __name__ 
== "__main__": + main() diff --git a/src/transforms.py b/src/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..a11f799e023864ff7082c1f49c0cc18351a13b47 --- /dev/null +++ b/src/transforms.py @@ -0,0 +1,209 @@ +import torch +from torch.nn import functional as F + +import numpy as np + + +DEFAULT_MIN_BIN_WIDTH = 1e-3 +DEFAULT_MIN_BIN_HEIGHT = 1e-3 +DEFAULT_MIN_DERIVATIVE = 1e-3 + + +def piecewise_rational_quadratic_transform( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails=None, + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if tails is None: + spline_fn = rational_quadratic_spline + spline_kwargs = {} + else: + spline_fn = unconstrained_rational_quadratic_spline + spline_kwargs = {"tails": tails, "tail_bound": tail_bound} + + outputs, logabsdet = spline_fn( + inputs=inputs, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + **spline_kwargs + ) + return outputs, logabsdet + + +def searchsorted(bin_locations, inputs, eps=1e-6): + bin_locations[..., -1] += eps + return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 + + +def unconstrained_rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails="linear", + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = torch.zeros_like(inputs) + logabsdet = torch.zeros_like(inputs) + + if tails == "linear": + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) + constant = np.log(np.exp(1 - min_derivative) - 1) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + outputs[outside_interval_mask] = inputs[outside_interval_mask] + logabsdet[outside_interval_mask] = 0 + else: + raise RuntimeError("{} tails are not implemented.".format(tails)) + + ( + outputs[inside_interval_mask], + logabsdet[inside_interval_mask], + ) = rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + unnormalized_widths=unnormalized_widths[inside_interval_mask, :], + unnormalized_heights=unnormalized_heights[inside_interval_mask, :], + unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], + inverse=inverse, + left=-tail_bound, + right=tail_bound, + bottom=-tail_bound, + top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + ) + + return outputs, logabsdet + + +def rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + left=0.0, + right=1.0, + bottom=0.0, + top=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if torch.min(inputs) < left or torch.max(inputs) > right: + raise ValueError("Input to a transform is not within its domain") + + num_bins = unnormalized_widths.shape[-1] + + if min_bin_width * num_bins > 1.0: + raise ValueError("Minimal bin width too large for the 
number of bins") + if min_bin_height * num_bins > 1.0: + raise ValueError("Minimal bin height too large for the number of bins") + + widths = F.softmax(unnormalized_widths, dim=-1) + widths = min_bin_width + (1 - min_bin_width * num_bins) * widths + cumwidths = torch.cumsum(widths, dim=-1) + cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) + cumwidths = (right - left) * cumwidths + left + cumwidths[..., 0] = left + cumwidths[..., -1] = right + widths = cumwidths[..., 1:] - cumwidths[..., :-1] + + derivatives = min_derivative + F.softplus(unnormalized_derivatives) + + heights = F.softmax(unnormalized_heights, dim=-1) + heights = min_bin_height + (1 - min_bin_height * num_bins) * heights + cumheights = torch.cumsum(heights, dim=-1) + cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) + cumheights = (top - bottom) * cumheights + bottom + cumheights[..., 0] = bottom + cumheights[..., -1] = top + heights = cumheights[..., 1:] - cumheights[..., :-1] + + if inverse: + bin_idx = searchsorted(cumheights, inputs)[..., None] + else: + bin_idx = searchsorted(cumwidths, inputs)[..., None] + + input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] + input_bin_widths = widths.gather(-1, bin_idx)[..., 0] + + input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] + delta = heights / widths + input_delta = delta.gather(-1, bin_idx)[..., 0] + + input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] + input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] + + input_heights = heights.gather(-1, bin_idx)[..., 0] + + if inverse: + a = (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + input_heights * (input_delta - input_derivatives) + b = input_heights * input_derivatives - (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + c = -input_delta * (inputs - input_cumheights) + + discriminant = b.pow(2) - 4 * a * c + assert (discriminant >= 0).all() + + root = (2 * c) / (-b - torch.sqrt(discriminant)) + outputs = root * input_bin_widths + input_cumwidths + + theta_one_minus_theta = root * (1 - root) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta + ) + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, -logabsdet + else: + theta = (inputs - input_cumwidths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + + numerator = input_heights * ( + input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta + ) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta + ) + outputs = input_cumheights + numerator / denominator + + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, logabsdet diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0b048fbb203d3ac662f670d2fdcd8b310d5593a7 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,277 @@ +from monotonic_align import maximum_path +from 
monotonic_align import mask_from_lens +from monotonic_align.core import maximum_path_c +import numpy as np +import torch +import copy +from torch import nn +import torch.nn.functional as F +import torchaudio +import librosa +import matplotlib.pyplot as plt +from munch import Munch + +import re +import json +import numpy as np + + +def maximum_path(neg_cent, mask): + """Cython optimized version. + neg_cent: [b, t_t, t_s] + mask: [b, t_t, t_s] + """ + device = neg_cent.device + dtype = neg_cent.dtype + neg_cent = np.ascontiguousarray(neg_cent.data.cpu().numpy().astype(np.float32)) + path = np.ascontiguousarray(np.zeros(neg_cent.shape, dtype=np.int32)) + + t_t_max = np.ascontiguousarray( + mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32) + ) + t_s_max = np.ascontiguousarray( + mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32) + ) + maximum_path_c(path, neg_cent, t_t_max, t_s_max) + return torch.from_numpy(path).to(device=device, dtype=dtype) + + +def get_data_path_list(train_path=None, val_path=None): + if train_path is None: + train_path = "Data/train_list.txt" + if val_path is None: + val_path = "Data/val_list.txt" + + with open(train_path, "r", encoding="utf-8", errors="ignore") as f: + train_list = f.readlines() + with open(val_path, "r", encoding="utf-8", errors="ignore") as f: + val_list = f.readlines() + + return train_list, val_list + + +def length_to_mask(lengths): + mask = ( + torch.arange(lengths.max()) + .unsqueeze(0) + .expand(lengths.shape[0], -1) + .type_as(lengths) + ) + mask = torch.gt(mask + 1, lengths.unsqueeze(1)) + return mask + + +# for norm consistency loss +def log_norm(x, mean=-4, std=4, dim=2): + """ + normalized log mel -> mel -> norm -> log(norm) + """ + x = torch.log(torch.exp(x * std + mean).norm(dim=dim)) + return x + + +def get_image(arrs): + plt.switch_backend("agg") + fig = plt.figure() + ax = plt.gca() + ax.imshow(arrs) + + return fig + + +def recursive_munch(d): + if isinstance(d, dict): + return Munch((k, recursive_munch(v)) for k, v in d.items()) + elif isinstance(d, list): + return [recursive_munch(v) for v in d] + else: + return d + + +def log_print(message, logger): + logger.info(message) + print(message) + + + + +def get_hparams_from_file(config_path): + with open(config_path, "r", encoding="utf-8") as f: + data = f.read() + config = json.loads(data) + + hparams = HParams(**config) + return hparams + +class HParams: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if type(v) == dict: + v = HParams(**v) + self[k] = v + + def keys(self): + return self.__dict__.keys() + + def items(self): + return self.__dict__.items() + + def values(self): + return self.__dict__.values() + + def __len__(self): + return len(self.__dict__) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + return setattr(self, key, value) + + def __contains__(self, key): + return key in self.__dict__ + + def __repr__(self): + return self.__dict__.__repr__() + + +def string_to_bits(string, pad_len=8): + # Convert each character to its ASCII value + ascii_values = [ord(char) for char in string] + + # Convert ASCII values to binary representation + binary_values = [bin(value)[2:].zfill(8) for value in ascii_values] + + # Convert binary strings to integer arrays + bit_arrays = [[int(bit) for bit in binary] for binary in binary_values] + + # Convert list of arrays to NumPy array + numpy_array = np.array(bit_arrays) + numpy_array_full = np.zeros((pad_len, 8), dtype=numpy_array.dtype) + numpy_array_full[:, 2] = 1 + max_len 
= min(pad_len, len(numpy_array)) + numpy_array_full[:max_len] = numpy_array[:max_len] + return numpy_array_full + + +def bits_to_string(bits_array): + # Convert each row of the array to a binary string + binary_values = [''.join(str(bit) for bit in row) for row in bits_array] + + # Convert binary strings to ASCII values + ascii_values = [int(binary, 2) for binary in binary_values] + + # Convert ASCII values to characters + output_string = ''.join(chr(value) for value in ascii_values) + + return output_string + + +def split_sentence(text, min_len=10, language_str='[EN]'): + if language_str in ['EN']: + sentences = split_sentences_latin(text, min_len=min_len) + else: + sentences = split_sentences_zh(text, min_len=min_len) + return sentences + +def split_sentences_latin(text, min_len=10): + """Split Long sentences into list of short ones + + Args: + str: Input sentences. + + Returns: + List[str]: list of output sentences. + """ + # deal with dirty sentences + text = re.sub('[。!?;]', '.', text) + text = re.sub('[,]', ',', text) + text = re.sub('[“”]', '"', text) + text = re.sub('[‘’]', "'", text) + text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text) + text = re.sub('[\n\t ]+', ' ', text) + text = re.sub('([,.!?;])', r'\1 $#!', text) + # split + sentences = [s.strip() for s in text.split('$#!')] + if len(sentences[-1]) == 0: del sentences[-1] + + new_sentences = [] + new_sent = [] + count_len = 0 + for ind, sent in enumerate(sentences): + # print(sent) + new_sent.append(sent) + count_len += len(sent.split(" ")) + if count_len > min_len or ind == len(sentences) - 1: + count_len = 0 + new_sentences.append(' '.join(new_sent)) + new_sent = [] + return merge_short_sentences_latin(new_sentences) + + +def merge_short_sentences_latin(sens): + sens_out = [] + for s in sens: + # If the previous sentense is too short, merge them with + # the current sentence. + if len(sens_out) > 0 and len(sens_out[-1].split(" ")) <= 2: + sens_out[-1] = sens_out[-1] + " " + s + else: + sens_out.append(s) + try: + if len(sens_out[-1].split(" ")) <= 2: + sens_out[-2] = sens_out[-2] + " " + sens_out[-1] + sens_out.pop(-1) + except: + pass + return sens_out + +def split_sentences_zh(text, min_len=10): + text = re.sub('[。!?;]', '.', text) + text = re.sub('[,]', ',', text) + # 将文本中的换行符、空格和制表符替换为空格 + text = re.sub('[\n\t ]+', ' ', text) + # 在标点符号后添加一个空格 + text = re.sub('([,.!?;])', r'\1 $#!', text) + # 分隔句子并去除前后空格 + # sentences = [s.strip() for s in re.split('(。|!|?|;)', text)] + sentences = [s.strip() for s in text.split('$#!')] + if len(sentences[-1]) == 0: del sentences[-1] + + new_sentences = [] + new_sent = [] + count_len = 0 + for ind, sent in enumerate(sentences): + new_sent.append(sent) + count_len += len(sent) + if count_len > min_len or ind == len(sentences) - 1: + count_len = 0 + new_sentences.append(' '.join(new_sent)) + new_sent = [] + return merge_short_sentences_zh(new_sentences) + +def merge_short_sentences_zh(sens): + # return sens + """Avoid short sentences by merging them with the following sentence. + + Args: + List[str]: list of input sentences. + + Returns: + List[str]: list of output sentences. + """ + sens_out = [] + for s in sens: + # If the previous sentense is too short, merge them with + # the current sentence. 
+ if len(sens_out) > 0 and len(sens_out[-1]) <= 2: + sens_out[-1] = sens_out[-1] + " " + s + else: + sens_out.append(s) + try: + if len(sens_out[-1]) <= 2: + sens_out[-2] = sens_out[-2] + " " + sens_out[-1] + sens_out.pop(-1) + except: + pass + return sens_out \ No newline at end of file diff --git a/src/vitstest.py b/src/vitstest.py new file mode 100644 index 0000000000000000000000000000000000000000..df0757dd316449e3494221e4b7f14c46c8e6eaae --- /dev/null +++ b/src/vitstest.py @@ -0,0 +1,43 @@ +import os +import torch +import se_extractor +from api import BaseSpeakerTTS, ToneColorConverter +ckpt_base = 'checkpoints/base_speakers/EN' +ckpt_converter = 'checkpoints/converter' +device = 'cuda:0' +output_dir = 'outputs' + +base_speaker_tts = BaseSpeakerTTS(f'{ckpt_base}/config.json', device=device) +base_speaker_tts.load_ckpt(f'{ckpt_base}/checkpoint.pth') + +tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device) +tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth') + +os.makedirs(output_dir, exist_ok=True) + +source_se = torch.load(f'{ckpt_base}/en_default_se.pth').to(device) + +reference_speaker = '/root/src/videly/voices/m-us-3.wav' +target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter, target_dir='processed', vad=True) +save_path = f'{output_dir}/output_en_default.wav' + +# Run the base speaker tts +text = ''' +Harry Potter stood silently at the edge of the Forbidden Forest, his wand gripped tightly in his hand. The moon cast a silvery glow over the ancient trees, creating a scene both eerie and beautiful. Harry's mind was racing with thoughts of the recent events at Hogwarts. The castle, usually a place of learning and friendship, had transformed into a battleground where the forces of good and evil clashed. + +In the distance, the unmistakable silhouette of Hogwarts Castle stood against the starry sky, its spires reaching up like fingers grasping for hope. Harry felt a deep connection to the place, not just as a student, but as someone who had faced and overcome great trials within its walls. He thought of his friends Ron and Hermione, who had been with him through thick and thin, their loyalty never wavering. + +As an owl hooted in the distance, Harry remembered the letters that had started his journey into the wizarding world. He had been just a boy then, unaware of his lineage and the future that awaited him. Now, he stood as a symbol of courage and resilience, a young wizard who had faced Lord Voldemort and lived to tell the tale. + +The events of the Triwizard Tournament flashed through his mind. The excitement, the danger, the tragic loss of Cedric Diggory - it was a reminder of the thin line between life and death in their world. Harry felt a pang of sadness, knowing that the path ahead was fraught with peril. Yet, he also felt a strong sense of purpose. He was not just fighting for himself, but for the entire wizarding community, for a world free from the tyranny of Voldemort. 
+''' +src_path = f'{output_dir}/tmp.wav' +base_speaker_tts.tts(text, src_path, speaker='default', language='English', speed=1.0) + +# Run the tone color converter +tone_color_converter.convert( + audio_src_path=src_path, + src_se=source_se, + tgt_se=target_se, + output_path=save_path, + message='') \ No newline at end of file diff --git a/src/voices.pkl b/src/voices.pkl new file mode 100644 index 0000000000000000000000000000000000000000..770533f59bcdf28996c0cf7112fbc0f2fdc34746 --- /dev/null +++ b/src/voices.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:58e11e1d6726c8992f5325aca8b381ad37facbd7380ebb5f5e04d77a017b4ee3 +size 10739 diff --git a/src/voices/andrew_huberman.wav b/src/voices/andrew_huberman.wav new file mode 100644 index 0000000000000000000000000000000000000000..e7a89fe7d7e006f631cb36c445275240dac8e4b4 --- /dev/null +++ b/src/voices/andrew_huberman.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a7f329506a6c5dff1a3210e3eaa4eec6c2ddf589a1207b4f60f774be30d7fab +size 723534 diff --git a/src/voices/f-us-1.wav b/src/voices/f-us-1.wav new file mode 100644 index 0000000000000000000000000000000000000000..bdc060809b7fcbae90ae0ec1bdd750024037ca1e --- /dev/null +++ b/src/voices/f-us-1.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6602240f8c2789447d3a4a37b3ea3202c384164c7151cbe252dc7773144ab59c +size 1780846 diff --git a/src/voices/f-us-2.wav b/src/voices/f-us-2.wav new file mode 100644 index 0000000000000000000000000000000000000000..0afd467e58ef75c2e1baa822c5059c0f1f95b6b0 --- /dev/null +++ b/src/voices/f-us-2.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:673cc52f812b3f91389500e1e81cb1e4ca4660348cba3145296c758d38a8f45d +size 903884 diff --git a/src/voices/f-us-3.wav b/src/voices/f-us-3.wav new file mode 100644 index 0000000000000000000000000000000000000000..c4253d89fb2e4af0f26f1d9b24f905ada3e6e7c3 --- /dev/null +++ b/src/voices/f-us-3.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c9925afe15657cfd6e38f0bbc1c40610c8e389b70c3a301b3214358e185d347b +size 691724 diff --git a/src/voices/f-us-4.wav b/src/voices/f-us-4.wav new file mode 100644 index 0000000000000000000000000000000000000000..0c5a563f7937dc81c25c08c4d56e0731cf380768 --- /dev/null +++ b/src/voices/f-us-4.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:80add6d3f14d2dbf3c0046ed0bb65c0857831d71143c2d24d86555918dcde125 +size 1493326 diff --git a/src/voices/huberman_clone.wav b/src/voices/huberman_clone.wav new file mode 100644 index 0000000000000000000000000000000000000000..0b2dd8bd459eb4439c3e3e82b2e611746399dce6 --- /dev/null +++ b/src/voices/huberman_clone.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e1fd3618660ac523d64533348b449a3c31e1850f752e6152acbfa4da53a7a20b +size 1857102 diff --git a/src/voices/m-us-1.wav b/src/voices/m-us-1.wav new file mode 100644 index 0000000000000000000000000000000000000000..905d5d967878c503ffee89311715bff26fe61a44 --- /dev/null +++ b/src/voices/m-us-1.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:62348b20144d9471350dcaa88d76a3f4992070a056da51968b6500e873053fb8 +size 346124 diff --git a/src/voices/m-us-2.wav b/src/voices/m-us-2.wav new file mode 100644 index 0000000000000000000000000000000000000000..e6a02318eb7e52f9bcc27d71547536fb9f2d5120 --- /dev/null +++ b/src/voices/m-us-2.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:f7cc4b86e4b1a1c2f12dc16652b61676bfff949d3ff33199cd9bf1cf42b6db2b +size 1047882 diff --git a/src/voices/m-us-3.wav b/src/voices/m-us-3.wav new file mode 100644 index 0000000000000000000000000000000000000000..61fc70890baab75c23af83b0d78845f6750d10c5 --- /dev/null +++ b/src/voices/m-us-3.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8937a2581208baf15e80ce7f648edaf58b77547e1c64acc2a4fe6be294163fa9 +size 1060370 diff --git a/src/voices/m-us-4.wav b/src/voices/m-us-4.wav new file mode 100644 index 0000000000000000000000000000000000000000..c9ee9a9d263f88b2856163e478cf991d2183d2e9 --- /dev/null +++ b/src/voices/m-us-4.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5ad56d8f95c00d7a0d8bfd26ae11806638024166015f63d211ce79531ccf16b0 +size 640856 diff --git a/src/voices/obama_clone.wav b/src/voices/obama_clone.wav new file mode 100644 index 0000000000000000000000000000000000000000..44193b795510f79c0b05df0be4639b941fd5e6cc --- /dev/null +++ b/src/voices/obama_clone.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa15ad073b810405ddf5fcbd435388a76ef2f524a394e896b7d8ee4cb229e898 +size 576224 diff --git a/src/voices/obama_clone_ref_after_resemble.wav b/src/voices/obama_clone_ref_after_resemble.wav new file mode 100644 index 0000000000000000000000000000000000000000..1f324801dd35ac1296c8d62542cb61e24c12a05b --- /dev/null +++ b/src/voices/obama_clone_ref_after_resemble.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:729e7cc1938be57fc53a289e4daff0ab99bf8fbd9612d7ade8c49367a1115584 +size 1058480
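The SLM-adversarial branch of the training loop above computes a per-module L2 gradient norm and, whenever the predictor's norm exceeds slmadv_params.thresh, rescales every gradient by that norm and further damps selected modules by slmadv_params.scale. Below is a minimal, self-contained sketch of that scheme on toy modules; the module names, thresh and scale values are stand-ins, not the repo's configuration.

import torch
import torch.nn as nn


def total_grad_norm(module: nn.Module) -> float:
    # Overall L2 norm across all parameters that received a gradient,
    # matching the per-module accumulation in the training loop.
    norms = [p.grad.detach().norm(2) for p in module.parameters()
             if p.grad is not None and p.requires_grad]
    return torch.norm(torch.stack(norms), 2).item() if norms else 0.0


# Toy stand-ins for model["predictor"], model["diffusion"], ...
model = {"predictor": nn.Linear(8, 8), "diffusion": nn.Linear(8, 8)}
thresh, scale = 5.0, 0.01  # stand-ins for slmadv_params.thresh / .scale

x = torch.randn(4, 8)
loss = sum(m(x).pow(2).mean() for m in model.values())
loss.backward()

pred_norm = total_grad_norm(model["predictor"])
if pred_norm > thresh:
    for m in model.values():
        for p in m.parameters():
            if p.grad is not None:
                p.grad *= 1.0 / pred_norm      # normalise everything by the predictor's norm
    for p in model["diffusion"].parameters():
        if p.grad is not None:
            p.grad *= scale                    # then damp selected modules further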
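In the sampled-speech branch above, the predicted per-token durations are rounded and expanded into a hard monotonic alignment matrix (pred_aln_trg) that upsamples text features to frame rate before decoding. A standalone sketch of that expansion, with made-up durations:

import torch


def durations_to_alignment(pred_dur: torch.Tensor) -> torch.Tensor:
    # pred_dur: rounded frame counts per text token, shape [T_text]
    # returns a [T_text, T_frames] 0/1 matrix with one contiguous run of 1s per token
    n_frames = int(pred_dur.sum().item())
    aln = torch.zeros(pred_dur.shape[0], n_frames)
    frame = 0
    for i, d in enumerate(pred_dur.tolist()):
        aln[i, frame:frame + int(d)] = 1
        frame += int(d)
    return aln


pred_dur = torch.tensor([2, 1, 3])        # hypothetical rounded durations
aln = durations_to_alignment(pred_dur)    # shape [3, 6]
t_en = torch.randn(1, 512, 3)             # fake text-encoder output [B, C, T_text]
frames = t_en @ aln.unsqueeze(0)          # upsampled to [1, 512, 6], as in the decoder call
print(aln)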
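src/transforms.py is the usual neural-spline-flow (piecewise rational quadratic) implementation found in VITS-style models. The sketch below only checks tensor shapes and the forward/inverse round trip; the parameter tensors are arbitrary, and it assumes src/ is on the Python path so the module imports as transforms.

import torch
from transforms import piecewise_rational_quadratic_transform  # src/transforms.py

num_bins = 10
x = torch.rand(2, 5) * 2 - 1                 # inputs roughly in [-1, 1]
w = torch.randn(2, 5, num_bins)              # unnormalized bin widths
h = torch.randn(2, 5, num_bins)              # unnormalized bin heights
d = torch.randn(2, 5, num_bins - 1)          # unnormalized knot derivatives (padded internally)

y, logdet = piecewise_rational_quadratic_transform(
    x, w, h, d, inverse=False, tails="linear", tail_bound=1.0
)
x_rec, inv_logdet = piecewise_rational_quadratic_transform(
    y, w, h, d, inverse=True, tails="linear", tail_bound=1.0
)
assert torch.allclose(x, x_rec, atol=1e-4)             # bijective round trip
assert torch.allclose(logdet, -inv_logdet, atol=1e-4)  # log-determinants cancel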
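The sentence-splitting helpers in src/utils.py chunk long prompts before synthesis. Note that only the exact value "EN" for language_str selects the Latin splitter; anything else, including the default "[EN]", falls through to the Chinese splitter. A usage sketch with made-up text, assuming src/ is on the path and the module-level imports of utils.py (monotonic_align, torchaudio, librosa, munch) are installed:

from utils import split_sentence  # src/utils.py

text = (
    "This is a long narration that will be synthesized in pieces. "
    "Short fragments get merged with their neighbours, while longer runs "
    "are split on punctuation. OK."
)
for i, chunk in enumerate(split_sentence(text, min_len=10, language_str="EN")):
    print(i, chunk)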
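length_to_mask in src/utils.py returns a padding mask that is True on padded positions, which is why the training code passes (~text_mask).int() to model.bert as the attention mask. A small example (same path and dependency caveats as the previous sketch):

import torch
from utils import length_to_mask  # src/utils.py

lengths = torch.tensor([3, 5])
mask = length_to_mask(lengths)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])
attention_mask = (~mask).int()   # 1 on real tokens, 0 on padding, as fed to model.bert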