from __future__ import annotations
import os

# Startup for the voice-chat demo: download/load the local XTTS model and define
# the UI/LLM constants used further down. Statement order matters here: the Coqui
# license env var is set before TTS is used, and the checkpoint is downloaded via
# the high-level TTS API before its files are loaded directly with Xtts.

# By using XTTS you agree to CPML license https://coqui.ai/cpml
# Presumably this env var lets the TTS download below proceed without an
# interactive license prompt — keep it before any TTS usage.
os.environ["COQUI_TOS_AGREED"] = "1"

import gradio as gr
import numpy as np
import torch
import nltk  # we'll use this to split into sentences
# 'punkt' backs nltk.sent_tokenize, which splits bot replies into sentences for TTS.
nltk.download('punkt')
import uuid  # NOTE(review): uuid and np look unused in this file — confirm before removing.

import librosa
import torchaudio

from TTS.api import TTS
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from TTS.utils.generic_utils import get_user_data_dir

# This will trigger downloading model: instantiating the high-level TTS API fetches
# the XTTS v1 checkpoint if it is not cached yet. The instance itself is discarded
# because inference below uses the Xtts model class directly.
print("Downloading if not downloaded Coqui XTTS V1")
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1")
del tts
print("XTTS downloaded")

print("Loading XTTS")
# Below will use model directly for inference.
# Presumably the directory name the TTS API used when caching the model above.
model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1")
config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json"))
model = Xtts.init_from_config(config)
model.load_checkpoint(
    config,
    checkpoint_path=os.path.join(model_path, "model.pth"),
    vocab_path=os.path.join(model_path, "vocab.json"),
    eval=True,
    # NOTE(review): presumably requires the deepspeed package to be installed — confirm.
    use_deepspeed=True
)
# Inference runs on the GPU, so a CUDA device is required.
model.cuda()
print("Done loading TTS")

# Browser tab title for the Gradio app.
title = "Voice chat with Mistral 7B Instruct"

# Markdown header rendered at the top of the page.
DESCRIPTION = """# Voice chat with Mistral 7B Instruct"""
# Presumably hides Gradio toast notifications.
# NOTE(review): gr.Blocks is created with title only, so this css is not applied.
css = """.toast-wrap { display: none !important } """

from huggingface_hub import HfApi
HF_TOKEN = os.environ.get("HF_TOKEN")
# will use api to restart space on a unrecoverable error
# (api.restart_space(repo_id=repo_id) after a CUDA device-side assert during TTS).
api = HfApi(token=HF_TOKEN)

repo_id = "ylacombe/voice-chat-with-lama"

# Default system prompt for the assistant.
# NOTE(review): no visible code sends this to the model — confirm whether intended.
system_message = "\nYou are a helpful, respectful and honest assistant. Your answers are short, ideally a few words long, if it is possible. Always answer as helpfully as possible, while being safe.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
# NOTE(review): these sampling settings are not passed anywhere — bot() calls
# generate() with its own defaults (temperature=0.9, top_p=0.95, repetition_penalty=1.0).
temperature = 0.9
top_p = 0.6
repetition_penalty = 1.2

import gradio as gr
import os
import time

import gradio as gr
from transformers import pipeline

import numpy as np

from gradio_client import Client
from huggingface_hub import InferenceClient

# Sample rate used for every wav written by get_voice (and therefore for the
# combined recording built from those clips).
OUTPUT_SAMPLE_RATE = 24000

# The original whisper-large-v2 Space client
# (https://sanchit-gandhi-whisper-large-v2.hf.space/) is down; whisper-jax is the
# replacement. It may be time limited.
whisper_client = Client("https://sanchit-gandhi-whisper-jax.hf.space")

# Remote chat model, queried through the Hugging Face Inference API.
text_client = InferenceClient(
    "mistralai/Mistral-7B-Instruct-v0.1"
)


###### COQUI TTS FUNCTIONS ######

def get_latents(speaker_wav):
    """Compute XTTS conditioning latents for the reference voice in *speaker_wav*.

    Kept as a function so voice cleanup/filtering can be added here later.

    Returns:
        ``(gpt_cond_latent, diffusion_conditioning, speaker_embedding)``, the tuple
        that get_voice expects as ``latent_tuple``.
    """
    gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
    return gpt_cond_latent, diffusion_conditioning, speaker_embedding


def format_prompt(message, history):
    """Build a Mistral-instruct prompt from past ``(user, bot)`` turns plus *message*.

    Each past user turn is wrapped in ``[INST] ... [/INST]`` and followed by the
    bot's reply; the new *message* is appended as the final instruction.
    """
    prompt = ""
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response} "
    prompt += f"[INST] {message} [/INST]"
    return prompt


def generate(
    prompt,
    history,
    temperature=0.9,
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.0,
):
    """Stream a Mistral completion for *prompt* given the earlier chat *history*.

    Yields the accumulated reply text after every streamed token. If the inference
    call fails, a warning is shown and a single fallback message is yielded instead.
    It must be yielded: a ``return value`` inside a generator never reaches the
    consumer.

    Args:
        prompt: The new user message.
        history: Earlier ``(user, bot)`` turns, formatted by format_prompt.
        temperature: Sampling temperature, clamped to at least 0.01.
        max_new_tokens: Upper bound on generated tokens.
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Penalty applied to repeated tokens.
    """
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(prompt, history)

    try:
        stream = text_client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
        output = ""
        for response in stream:
            output += response.token.text
            yield output
    except Exception as e:
        if "Too Many Requests" in str(e):
            print("ERROR: Too many requests on mistral client")
            gr.Warning("Unfortunately Mistral is unable to process")
            output = "Unfortunately I am not able to process your request now !"
        else:
            print("Unhandled Exception: ", str(e))
            gr.Warning("Unfortunately Mistral is unable to process")
            output = "I do not know what happened but I could not understand you ."
        # Surface the fallback text in the chat (was a dropped generator return value).
        yield output


def transcribe(wav_path):
    """Transcribe the audio file at *wav_path* with the remote whisper-jax Space.

    Returns the first element of the Space's prediction (the text), stripped of
    leading/trailing whitespace.
    """
    return whisper_client.predict(
        wav_path,  # str (filepath or URL to file) in 'inputs' Audio component
        "transcribe",  # str in 'Task' Radio component
        False,  # return_timestamps=False for whisper-jax https://gist.github.com/sanchit-gandhi/781dd7003c5b201bfe16d28634c8d4cf#file-whisper_jax_endpoint-py
        api_name="/predict",
    )[0].strip()


# Chat handlers: append the user turn (typed or spoken), then stream the reply.

def add_text(history, text):
    """Append typed *text* as a new user turn, then clear and lock the textbox.

    The textbox is re-enabled by a follow-up event once the reply has been spoken.
    Turns are stored as ``[user, bot]`` lists so bot() can fill the reply in place.
    """
    history = [] if history is None else history
    history = history + [[text, None]]
    return history, gr.update(value="", interactive=False)


def add_file(history, file):
    """Transcribe the recorded audio *file* and append it as a new user turn.

    On transcription failure a warning is shown and a canned prompt is used, so
    the conversation can still continue.
    """
    history = [] if history is None else history
    try:
        text = transcribe(file)
        print("Transcribed text:", text)
    except Exception as e:
        print(str(e))
        gr.Warning("There was an issue with transcription, please try writing for now")
        # Apply a null text on error
        text = "Transcription seems failed, please tell me a joke about chickens"

    return history + [[text, None]]


def bot(history, system_prompt=""):
    """Stream the model's reply to the last user turn into *history*.

    Yields *history* after each streamed chunk so the Chatbot updates live.
    ``system_prompt`` is accepted for interface compatibility but is not sent to
    the model: format_prompt has no slot for a system message.
    """
    history = [] if history is None else history
    if not history:
        # Nothing to answer; emit the (empty) history unchanged.
        yield history
        return

    user_message = history[-1][0]
    # Replace the last turn with a mutable list: callers may pass tuples,
    # which cannot be filled in place.
    history[-1] = [user_message, ""]
    for partial_reply in generate(user_message, history[:-1]):
        history[-1][1] = partial_reply
        yield history


# Precompute the speaker latents once at startup so each sentence only runs inference.
latent_map = {}
latent_map["Female_Voice"] = get_latents("examples/female.wav")


def get_voice(prompt, language, latent_tuple, suffix="0"):
    """Synthesize *prompt* with XTTS and save it as ``output_{suffix}.wav``.

    Args:
        prompt: Text to speak.
        language: Language code passed to XTTS (e.g. "en").
        latent_tuple: ``(gpt_cond_latent, diffusion_conditioning, speaker_embedding)``
            as returned by get_latents.
        suffix: Appended to the file name so each sentence gets its own file.

    Returns:
        Path of the written wav file.
    """
    gpt_cond_latent, diffusion_conditioning, speaker_embedding = latent_tuple

    # Direct (non-streaming) inference.
    t0 = time.time()
    out = model.inference(
        prompt,
        language,
        gpt_cond_latent,
        speaker_embedding,
        diffusion_conditioning,
    )
    inference_time = time.time() - t0
    print(f"I: Time to generate audio: {round(inference_time*1000)} milliseconds")
    # Real-time factor: seconds spent generating per second of audio produced.
    real_time_factor = inference_time / out["wav"].shape[-1] * OUTPUT_SAMPLE_RATE
    print(f"Real-time factor (RTF): {real_time_factor}")

    wav_filename = f"output_{suffix}.wav"
    torchaudio.save(wav_filename, torch.as_tensor(out["wav"]).unsqueeze(0), OUTPUT_SAMPLE_RATE)
    return wav_filename


def _concat_wavs(wav_paths, output_path):
    """Join the wav files in *wav_paths* end to end and write them to *output_path*.

    All inputs must share one sample rate and channel count (get_voice always
    writes mono clips at OUTPUT_SAMPLE_RATE). Returns *output_path*.
    """
    waveforms = []
    sample_rate = None
    for path in wav_paths:
        waveform, rate = torchaudio.load(path)
        if sample_rate is None:
            sample_rate = rate
        elif rate != sample_rate:
            raise ValueError(f"Cannot concatenate {path}: sample rate {rate} != {sample_rate}")
        waveforms.append(waveform)
    torchaudio.save(output_path, torch.cat(waveforms, dim=-1), sample_rate)
    return output_path


def generate_speech(history):
    """Speak the latest bot reply sentence by sentence, then publish the full recording.

    Generator used as a Gradio event handler. For each sentence it yields an
    autoplaying clip and sleeps for the clip's duration, so the next clip does not
    cut it off. Finally it yields all clips joined into ``combined.wav`` with
    autoplay disabled, so the user can replay the whole answer.

    On a CUDA "device-side assert" RuntimeError the Space is restarted through
    ``api`` and generation stops. Other RuntimeErrors propagate.
    """
    if not history or not history[-1][1]:
        # Empty chat or empty reply: nothing to say.
        return

    text_to_generate = history[-1][1].replace("\n", " ").strip()
    sentences = nltk.sent_tokenize(text_to_generate)
    language = "en"

    wav_list = []
    for i, sentence in enumerate(sentences):
        # Mistral's end-of-sequence marker "</s>" sometimes leaks into the reply text.
        sentence = sentence.replace("</s>", "").strip()
        if not sentence:
            continue

        # A fast fix for the last character: split trailing punctuation off the
        # final word (may otherwise produce weird sounds).
        if sentence[-1] in ["!", "?", ".", ","]:
            sentence = sentence[:-1] + " " + sentence[-1]
        print("Sentence:", sentence)

        try:
            # Not streaming, but fast thanks to the precomputed speaker latents.
            # The sentence index is the file suffix so clips can be merged at the end.
            wav = get_voice(sentence, language, latent_map["Female_Voice"], suffix=i)
        except RuntimeError as e:
            if "device-side assert" not in str(e):
                print("RuntimeError: non device-side assert error:", str(e))
                raise
            # Cannot do anything on a CUDA device-side error; the process must restart.
            print(f"Exit due to: Unrecoverable exception caused by prompt:{sentence}", flush=True)
            gr.Warning("Unhandled Exception encounter, please retry in a minute")
            print("Cuda device-assert Runtime encountered need restart")
            # HF Space specific: this error is unrecoverable, so restart the Space.
            api.restart_space(repo_id=repo_id)
            return

        wav_list.append(wav)
        # autoplay=True is explicit because the final combined update turns it off.
        # Note: mobile browsers do not support autoplay.
        yield gr.Audio.update(value=wav, autoplay=True)
        wait_time = librosa.get_duration(path=wav)
        print("Sleeping till audio end")
        time.sleep(wait_time)

    if not wav_list:
        return

    # Every sentence was already spoken via autoplay; now expose the whole answer
    # as one file, without autoplay so it is not spoken twice.
    combined_file_name = _concat_wavs(wav_list, "combined.wav")
    yield gr.Audio.update(value=combined_file_name, autoplay=False)


with gr.Blocks(title=title) as demo:
    gr.Markdown(DESCRIPTION)

    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        avatar_images=('examples/lama.jpeg', 'examples/lama2.jpeg'),
        bubble_full_width=False,
    )

    with gr.Row():
        txt = gr.Textbox(
            scale=3,
            show_label=False,
            placeholder="Enter text and press enter, or speak to your microphone",
            container=False,
        )
        txt_btn = gr.Button(value="Submit text", scale=1)
        btn = gr.Audio(source="microphone", type="filepath", scale=4)

    with gr.Row():
        audio = gr.Audio(type="numpy", streaming=False, autoplay=True, label="Generated audio response", show_label=True)

    clear_btn = gr.ClearButton([chatbot, audio])

    # The "Submit text" button and pressing Enter share one pipeline:
    # add user turn -> stream reply -> speak it -> re-enable the textbox.
    for submit_event in (txt_btn.click, txt.submit):
        txt_msg = submit_event(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
            bot, chatbot, chatbot
        ).then(generate_speech, chatbot, audio)
        txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)

    # Spoken input: transcribe when recording stops, then the same reply/speech chain.
    file_msg = btn.stop_recording(add_file, [chatbot, btn], [chatbot], queue=False).then(
        bot, chatbot, chatbot
    ).then(generate_speech, chatbot, audio)

    gr.Markdown("""
This Space demonstrates how to speak to a chatbot, based solely on open-source models.
It relies on 3 models:
1. [Whisper-large-v2](https://huggingface.co/spaces/sanchit-gandhi/whisper-jax) as an ASR model, to transcribe recorded audio to text. It is called through a [gradio client](https://www.gradio.app/docs/client).
2. [Mistral-7b-instruct](https://huggingface.co/spaces/osanseviero/mistral-super-fast) as the chat model, the actual chat model. It is called from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference).
3. [Coqui's XTTS](https://huggingface.co/spaces/coqui/xtts) as a TTS model, to generate the chatbot answers. This time, the model is hosted locally.

Note:
- By using this demo you agree to the terms of the Coqui Public Model License at https://coqui.ai/cpml""")

demo.queue()
demo.launch(debug=True)