print("WARNING: You are running this unofficial E2/F5 TTS demo locally, it may not be as up-to-date as the hosted version (https://huggingface.co/spaces/mrfakename/E2-F5-TTS)")

import gc
import os
import re
import tempfile

import torch
import torchaudio
import gradio as gr
import numpy as np
from einops import rearrange
from ema_pytorch import EMA
from vocos import Vocos
from pydub import AudioSegment
from model import CFM, UNetT, DiT, MMDiT
from cached_path import cached_path
from model.utils import (
    get_tokenizer,
    convert_char_to_pinyin,
    save_spectrogram,  # shadowed by the local save_spectrogram defined below
)
from transformers import pipeline
import librosa
import librosa.display  # provides specshow, used by save_spectrogram below
import matplotlib.pyplot as plt
import devicetorch

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

gc.collect()
devicetorch.empty_cache(torch)

print(f"Using {device} device")

# --------------------- Settings -------------------- #

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
target_rms = 0.1
nfe_step = 32  # 16, 32
cfg_strength = 2.0
ode_method = 'euler'
sway_sampling_coef = -1.0
speed = 1.0
# fix_duration = 27  # None or float (duration in seconds)
fix_duration = None


def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
    """Build a CFM model and load the EMA weights of a pretrained SWivid checkpoint.

    Args:
        repo_name: Accepted for call compatibility but not used; the checkpoint is
            always fetched from the ``SWivid/F5-TTS`` Hugging Face repo.
            NOTE(review): confirm whether ``SWivid/{repo_name}`` was intended.
        exp_name: Experiment folder within the repo, e.g. ``"F5TTS_Base"``.
        model_cls: Backbone class instantiated as the CFM transformer.
        model_cfg: Keyword arguments forwarded to ``model_cls``.
        ckpt_step: Training step in the checkpoint file name ``model_{ckpt_step}.pt``.

    Returns:
        ``(ema_model, model)``: the ``EMA`` wrapper and the ``CFM`` model whose
        parameters have been overwritten with the EMA weights.
    """
    # Fetch and load a single checkpoint.
    checkpoint = torch.load(
        str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.pt")),
        map_location=device,
    )
    vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
    model = CFM(
        transformer=model_cls(
            **model_cfg,
            text_num_embeds=vocab_size,
            mel_dim=n_mel_channels,
        ),
        mel_spec_kwargs=dict(
            target_sample_rate=target_sample_rate,
            n_mel_channels=n_mel_channels,
            hop_length=hop_length,
        ),
        odeint_kwargs=dict(
            method=ode_method,
        ),
        vocab_char_map=vocab_char_map,
    ).to(device)

    ema_model = EMA(model, include_online_model=False).to(device)
    ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
    # infer() samples from `model`, so copy the EMA weights into it.
    ema_model.copy_params_from_ema_to_model()

    return ema_model, model


# load models
F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)

# load_model returns (ema_model, base_model); both names are needed by infer().
F5TTS_ema_model, F5TTS_base_model = load_model("F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
E2TTS_ema_model, E2TTS_base_model = load_model("E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)


def chunk_text(text, max_chars=200):
    """Split ``text`` into chunks on sentence boundaries.

    A sentence ends at ``.``, ``!`` or ``?`` followed by whitespace. Sentences are
    packed greedily so each chunk is at most ``max_chars`` characters, except that a
    single sentence longer than ``max_chars`` becomes its own oversized chunk.

    Returns:
        A list of stripped, non-empty chunks; empty for empty or whitespace-only text.
    """
    chunks = []
    current_chunk = ""
    for sentence in re.split(r'(?<=[.!?])\s+', text):
        # Skip blank pieces so empty input yields [] rather than [""].
        if not sentence.strip():
            continue
        if len(current_chunk) + len(sentence) <= max_chars:
            current_chunk += sentence + " "
        else:
            if current_chunk.strip():
                chunks.append(current_chunk.strip())
            current_chunk = sentence + " "
    if current_chunk.strip():
        chunks.append(current_chunk.strip())
    return chunks


def save_spectrogram(y, sr, path):
    """Render a dB-scaled STFT magnitude spectrogram of waveform ``y`` (rate ``sr``) to the image file ``path``."""
    plt.figure(figsize=(10, 4))
    D = librosa.amplitude_to_db(np.abs(librosa.stft(y)), ref=np.max)
    librosa.display.specshow(D, sr=sr, x_axis='time', y_axis='hz')
    plt.colorbar(format='%+2.0f dB')
    plt.title('Spectrogram')
    plt.tight_layout()
    plt.savefig(path)
    plt.close()


def _select_model(exp_name):
    """Return the EMA-weighted CFM model for the UI choice ``exp_name``."""
    if exp_name == "F5-TTS":
        return F5TTS_base_model
    if exp_name == "E2-TTS":
        return E2TTS_base_model
    raise gr.Error(f"Unknown model: {exp_name}")


def _transcribe_reference(ref_audio_path):
    """Transcribe the reference WAV with Whisper, then release the pipeline to free memory."""
    pipe = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-large-v3-Turbo",  # You can set this to large-V3 if you want better quality, but VRAM then goes to 10 GB
        # float16 on GPU/MPS; float32 on CPU, where half-precision kernels are limited.
        torch_dtype=torch.float16 if device != "cpu" else torch.float32,
        device=device,
    )
    try:
        return pipe(
            ref_audio_path,
            chunk_length_s=30,
            batch_size=128,
            generate_kwargs={"task": "transcribe"},
            return_timestamps=False,
        )['text'].strip()
    finally:
        del pipe
        devicetorch.empty_cache(torch)
        gc.collect()


def _remove_silences(wave, top_db=30):
    """Keep only the non-silent intervals found by ``librosa.effects.split``.

    If no interval is found, ``wave`` is returned unchanged rather than emptied.
    """
    intervals = librosa.effects.split(wave, top_db=top_db)
    if len(intervals) == 0:
        return wave
    # A single concatenate avoids re-copying the growing output per interval.
    return np.concatenate([wave[start:end] for start, end in intervals])


def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence):
    """Clone the voice in a reference clip and synthesize ``gen_text`` with it.

    Args:
        ref_audio_orig: Path to the reference audio (Gradio ``type="filepath"``);
            only the first 15 s are used.
        ref_text: Transcript of the reference audio; blank means transcribe with Whisper.
        gen_text: Text to synthesize; split into chunks by ``chunk_text``.
        exp_name: ``"F5-TTS"`` or ``"E2-TTS"``.
        remove_silence: If true, drop silent stretches from the combined output.

    Returns:
        ``((target_sample_rate, waveform), spectrogram_png_path)``.

    Raises:
        gr.Error: On empty ``gen_text``, an unknown model, or an empty reference transcript.
    """
    print(gen_text)
    chunks = chunk_text(gen_text)
    if not chunks:
        raise gr.Error("Please enter some text to generate.")

    # Select model
    base_model = _select_model(exp_name)

    # Convert reference audio to a mono WAV, clipped to 15 s (pydub lengths are in ms).
    gr.Info("Converting reference audio...")
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        aseg = AudioSegment.from_file(ref_audio_orig)
        aseg = aseg.set_channels(1)
        if len(aseg) > 15000:
            gr.Warning("Audio is over 15s, clipping to only first 15s.")
            aseg = aseg[:15000]
        aseg.export(f.name, format="wav")
        ref_audio = f.name

    try:
        # Transcribe reference audio if needed
        if not (ref_text or "").strip():
            gr.Info("No reference text provided, transcribing reference audio...")
            ref_text = _transcribe_reference(ref_audio)
            print("\nTranscribed text: ", ref_text)  # Debug transcription quality
            gr.Info("\nFinished transcription")
        else:
            gr.Info("Using custom reference text...")
        audio, sr = torchaudio.load(ref_audio)
    finally:
        # The temporary WAV is not needed once it has been transcribed and loaded.
        os.remove(ref_audio)

    if not ref_text.strip():
        raise gr.Error("Reference text is empty: the reference audio could not be transcribed. Please enter the reference text manually.")
    # Separate the reference transcript from the generated text so words do not merge.
    if not ref_text.endswith(" "):
        ref_text += " "

    # Preprocess reference audio
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)  # convert to mono
    rms = torch.sqrt(torch.mean(torch.square(audio))).item()
    # Quiet references are boosted to target_rms and the output is scaled back by the same
    # factor. A fully silent clip (rms == 0) is left alone to avoid dividing by zero.
    boosted = 0 < rms < target_rms
    if boosted:
        audio = audio * target_rms / rms
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
        audio = resampler(audio)
    audio = audio.to(device)

    # Loop-invariant quantities for the duration estimate.
    ref_audio_len = audio.shape[-1] // hop_length  # reference length in mel frames
    zh_pause_punc = r"。,、;:?!"
    # NOTE(review): as a regex this matches only the whole literal sequence, not each mark
    # individually (a "[...]" character class would); kept as-is to preserve durations.
    ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))

    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")  # load once, reuse for every chunk

    # Process each chunk
    results = []
    for i, chunk in enumerate(chunks):
        gr.Info(f"Processing chunk {i+1}/{len(chunks)}: {chunk[:30]}...")

        final_text_list = convert_char_to_pinyin([ref_text + chunk])

        # Total frames = reference frames + frames for the chunk, extrapolated from the
        # reference's frames-per-character rate and scaled by `speed`.
        gen_text_len = len(chunk) + len(re.findall(zh_pause_punc, chunk))
        duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)

        # Inference
        gr.Info(f"Generating audio using {exp_name}")
        with torch.inference_mode():
            generated, _ = base_model.sample(
                cond=audio,
                text=final_text_list,
                duration=duration,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=sway_sampling_coef,
            )
        # Drop the reference prefix (first ref_audio_len frames) and move mel dim before time.
        generated_mel_spec = rearrange(generated[:, ref_audio_len:, :], '1 n d -> 1 d n')
        del generated
        devicetorch.empty_cache(torch)

        gr.Info("Running vocoder")
        with torch.inference_mode():
            generated_wave = vocos.decode(generated_mel_spec.cpu())
            if boosted:
                generated_wave = generated_wave * rms / target_rms
            generated_wave = generated_wave.squeeze().cpu().numpy()
        del generated_mel_spec
        results.append(generated_wave)

        # Clear cache after processing each chunk
        gc.collect()
        devicetorch.empty_cache(torch)

    # Combine all audio chunks
    combined_audio = np.concatenate(results)

    if remove_silence:
        gr.Info("Removing audio silences... This may take a moment")
        combined_audio = _remove_silences(combined_audio)

    # Generate final spectrogram (written after the temp handle is closed).
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
        final_spectrogram_path = tmp_spectrogram.name
    save_spectrogram(combined_audio, target_sample_rate, final_spectrogram_path)

    # Final cleanup
    gc.collect()
    devicetorch.empty_cache(torch)

    # Return combined audio and the final spectrogram
    return (target_sample_rate, combined_audio), final_spectrogram_path


with gr.Blocks() as app:
    ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
    gen_text_input = gr.Textbox(label="Text to Generate (for longer than 200 chars the app uses chunking)", lines=4)
    model_choice = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
    generate_btn = gr.Button("Synthesize", variant="primary")
    with gr.Accordion("Advanced Settings", open=False):
        ref_text_input = gr.Textbox(label="Reference Text", info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.", lines=2)
        remove_silence = gr.Checkbox(label="Remove Silences", info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.", value=True)
    audio_output = gr.Audio(label="Synthesized Audio")
    spectrogram_output = gr.Image(label="Spectrogram")
    generate_btn.click(infer, inputs=[ref_audio_input, ref_text_input, gen_text_input, model_choice, remove_silence], outputs=[audio_output, spectrogram_output])
    gr.Markdown("Unofficial demo by [mrfakename](https://x.com/realmrfakename)")

app.queue().launch()