"""Text-to-speech inference with an MMS-TTS VITS checkpoint.

Reads a local checkpoint for ``LANG`` (vocab.txt, config.json, generator
weights) and synthesizes speech with the VITS ``SynthesizerTrn`` model. The
VITS source tree is put on ``sys.path`` so its ``commons``/``utils``/
``data_utils``/``models`` modules can be imported.
"""
import argparse
import glob
import json
import locale
import math
import os
import re
import subprocess
import sys
import tempfile

# Force UTF-8 as the value reported by locale.getpreferredencoding().
# The stdlib signature is getpreferredencoding(do_setlocale=True); keeping the
# optional parameter means callers that pass False do not get a TypeError.
locale.getpreferredencoding = lambda do_setlocale=True: "UTF-8"

LANG = "spa"
MMS_TTS_URL = "https://dl.fbaipublicfiles.com/mms/tts/{lang}.tar.gz"
# HTTPS clones work anonymously; an SSH URL would require a GitHub SSH key.
UROMAN_REPO_URL = "https://github.com/isi-nlp/uroman.git"
VITS_DIR = "/workspaces/text_to_speach/vits"


def download(lang, tgt_dir="./"):
    """Download and unpack the MMS-TTS checkpoint archive for ``lang``.

    Requires ``wget`` and ``tar`` on PATH.

    Args:
        lang: language code used in the archive name (e.g. "spa").
        tgt_dir: directory that receives the archive and its contents.

    Returns:
        The path ``<tgt_dir>/<lang>`` the archive is expected to unpack into.

    Raises:
        subprocess.CalledProcessError: if the download or extraction fails.
    """
    lang_fn = os.path.join(tgt_dir, f"{lang}.tar.gz")
    lang_dir = os.path.join(tgt_dir, lang)
    print(f"Download model for language: {lang}")
    subprocess.run(["wget", MMS_TTS_URL.format(lang=lang), "-O", lang_fn], check=True)
    # -C unpacks into tgt_dir, so lang_dir is correct even when tgt_dir is not the cwd.
    subprocess.run(["tar", "zxvf", lang_fn, "-C", tgt_dir], check=True)
    print(f"Model checkpoints in {lang_dir}: {os.listdir(lang_dir)}")
    return lang_dir


# Pre-downloaded checkpoint directory (vocab.txt, config.json, G_100000.pth).
# Replace with download(LANG) to fetch the checkpoint instead.
ckpt_dir = "/workspaces/text_to_speach/spa"

# Make the VITS sources (commons, utils, data_utils, models) importable.
# Caution: path[0] is reserved for the script path (or '' in the REPL).
sys.path.insert(1, VITS_DIR)

import numpy as np  # noqa: E402
import torch  # noqa: E402
from scipy.io.wavfile import write  # noqa: E402
from torch import nn  # noqa: E402
from torch.nn import functional as F  # noqa: E402
from torch.utils.data import DataLoader  # noqa: E402

import commons  # noqa: E402
import utils  # noqa: E402
from data_utils import (  # noqa: E402
    TextAudioCollate,
    TextAudioLoader,
    TextAudioSpeakerCollate,
    TextAudioSpeakerLoader,
)
from models import SynthesizerTrn  # noqa: E402


def preprocess_char(text, lang=None):
    """Apply language-specific character normalization to ``text``.

    Only Romanian (``lang == 'ron'``) is handled: "ț" is replaced with "ţ",
    presumably to match the checkpoint vocabulary. Other languages pass through
    unchanged.
    """
    if lang == 'ron':
        text = text.replace("ț", "ţ")
    return text


class TextMapper(object):
    """Map between text and the integer symbol IDs of a checkpoint vocabulary.

    Args:
        vocab_file: UTF-8 text file with one symbol per line; a symbol's ID is
            its line index. It must contain a line holding a single space,
            otherwise ``list.index`` raises ValueError.
    """

    def __init__(self, vocab_file):
        # `with` closes the vocabulary file instead of leaking the handle.
        with open(vocab_file, encoding="utf-8") as f:
            self.symbols = [line.replace("\n", "") for line in f]
        self.SPACE_ID = self.symbols.index(" ")
        self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
        self._id_to_symbol = dict(enumerate(self.symbols))

    def text_to_sequence(self, text, cleaner_names):
        '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

        Args:
          text: string to convert to a sequence; leading/trailing whitespace is stripped
          cleaner_names: accepted for interface compatibility; not used here
        Returns:
          List of integers corresponding to the symbols in the text
        Raises:
          KeyError: if a character is not in the vocabulary (see filter_oov)
        '''
        return [self._symbol_to_id[symbol] for symbol in text.strip()]

    def uromanize(self, text, uroman_pl):
        """Romanize ``text`` with the uroman Perl script at ``uroman_pl``.

        Returns:
            The first output line, with runs of whitespace collapsed to a single
            space and surrounding whitespace stripped.

        Raises:
            subprocess.CalledProcessError: if perl exits with a non-zero status.
        """
        # Placeholder language code; presumably selects no language-specific rules.
        iso = "xxx"
        # Argument list (no shell) so paths with spaces work. UTF-8 is explicit in
        # both directions rather than depending on the locale.
        result = subprocess.run(
            ["perl", uroman_pl, "-l", iso],
            input=text,
            stdout=subprocess.PIPE,
            encoding="utf-8",
            check=True,
        )
        first_line = result.stdout.split("\n", 1)[0]
        return re.sub(r"\s+", " ", first_line).strip()

    def get_text(self, text, hps):
        """Encode ``text`` as a LongTensor of symbol IDs for the model.

        When ``hps.data.add_blank`` is set, the ID list is passed through
        ``commons.intersperse(..., 0)`` (presumably inserting blank ID 0 between
        symbols).
        """
        text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
        if hps.data.add_blank:
            text_norm = commons.intersperse(text_norm, 0)
        return torch.LongTensor(text_norm)

    def filter_oov(self, text):
        """Return ``text`` with every character missing from the vocabulary removed."""
        txt_filt = "".join(ch for ch in text if ch in self._symbol_to_id)
        print(f"text after filtering OOV: {txt_filt}")
        return txt_filt


def preprocess_text(txt, text_mapper, hps, uroman_dir=None, lang=None):
    """Normalize raw text into the character set of ``text_mapper``.

    Steps:
        1. Apply language-specific character fixes.
        2. Romanize with uroman, but only when the last '.'-separated component
           of ``hps.data.training_files`` is 'uroman'.
        3. Lowercase the text.
        4. Drop out-of-vocabulary characters.

    Args:
        txt: raw input text.
        text_mapper: TextMapper for the checkpoint vocabulary.
        hps: hyper-parameters loaded from the checkpoint config.
        uroman_dir: existing checkout of isi-nlp/uroman. When None and
            romanization is needed, the repository is cloned into a temporary
            directory for this call.
        lang: language code forwarded to preprocess_char.

    Raises:
        subprocess.CalledProcessError: if cloning or running uroman fails.
    """
    txt = preprocess_char(txt, lang=lang)
    is_uroman = hps.data.training_files.split('.')[-1] == 'uroman'
    if is_uroman:
        with tempfile.TemporaryDirectory() as tmp_dir:
            if uroman_dir is None:
                cmd = ["git", "clone", UROMAN_REPO_URL, tmp_dir]
                print(" ".join(cmd))
                subprocess.check_output(cmd)
                uroman_dir = tmp_dir
            uroman_pl = os.path.join(uroman_dir, "bin", "uroman.pl")
            print("uromanize")
            # Must run inside the `with`: a cloned tmp_dir is deleted on exit.
            txt = text_mapper.uromanize(txt, uroman_pl)
            print(f"uroman text: {txt}")
    txt = txt.lower()
    txt = text_mapper.filter_oov(txt)
    return txt


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Run inference with {device}")

vocab_file = f"{ckpt_dir}/vocab.txt"
config_file = f"{ckpt_dir}/config.json"
# Explicit check instead of `assert`, which is stripped under `python -O`.
if not os.path.isfile(config_file):
    raise FileNotFoundError(f"{config_file} doesn't exist")
hps = utils.get_hparams_from_file(config_file)

text_mapper = TextMapper(vocab_file)
net_g = SynthesizerTrn(
    len(text_mapper.symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    **hps.model)
net_g.to(device)
_ = net_g.eval()

g_pth = f"{ckpt_dir}/G_100000.pth"
print(f"load {g_pth}")
_ = utils.load_checkpoint(g_pth, net_g, None)


def generate_audio_mms(text, noise_scale=0.667, noise_scale_w=0.8, length_scale=1.0):
    """Synthesize speech for ``text`` with the loaded model ``net_g``.

    Args:
        text: raw input text; normalized with preprocess_text for LANG.
        noise_scale: forwarded to ``net_g.infer`` (default 0.667).
        noise_scale_w: forwarded to ``net_g.infer`` (default 0.8).
        length_scale: forwarded to ``net_g.infer`` (default 1.0).

    Returns:
        ``(sampling_rate, audio)``:
            sampling_rate: ``hps.data.sampling_rate``.
            audio: float32 NumPy array taken from element ``[0, 0]`` of the
                first inference output (presumably a 1-D waveform; confirm
                against SynthesizerTrn.infer).
    """
    txt = preprocess_text(text, text_mapper, hps, lang=LANG)
    stn_tst = text_mapper.get_text(txt, hps)
    with torch.no_grad():
        # Leading dimension of size 1 with a matching single-entry lengths tensor.
        x_tst = stn_tst.unsqueeze(0).to(device)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
        hyp = net_g.infer(
            x_tst, x_tst_lengths,
            noise_scale=noise_scale,
            noise_scale_w=noise_scale_w,
            length_scale=length_scale,
        )[0][0, 0].cpu().float().numpy()
    return hps.data.sampling_rate, hyp