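# Streamlit voice assistant: captures microphone audio in the browser with
# streamlit-webrtc, transcribes it with SpeechRecognition, queries
# mistralai/Mixtral-8x7B-Instruct-v0.1 through the Hugging Face InferenceClient,
# and reads the reply back with gTTS (Spanish voice).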
import streamlit as st
from huggingface_hub import InferenceClient
import base64
import queue
from io import BytesIO
from gtts import gTTS
from streamlit_webrtc import webrtc_streamer, WebRtcMode
import speech_recognition as sr
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

pre_prompt = ""
pre_prompt_sent = False
webrtc_ctx = None

def take_user_input():
    """Capture microphone audio from the browser and return the transcribed text."""
    global webrtc_ctx
    # SENDONLY: the browser only sends audio; nothing is streamed back to it.
    webrtc_ctx = webrtc_streamer(
        key="microphone",
        mode=WebRtcMode.SENDONLY,
        audio_receiver_size=1024,
        media_stream_constraints={"audio": True, "video": False},
        async_processing=True,
    )

    if not webrtc_ctx.state.playing:
        st.warning("Por favor, habilita el micrófono.")
        return 'None'

    st.info('Escuchando...')
    sound_chunks = []
    sample_rate = None
    sample_width = None
    try:
        # Accumulate a batch of incoming frames before transcribing: a single
        # WebRTC frame (10-20 ms) is far too short to recognize speech.
        # Assumes mono audio from the browser; stereo frames would need downmixing.
        while len(sound_chunks) < 100:
            try:
                audio_frames = webrtc_ctx.audio_receiver.get_frames(timeout=3)
            except queue.Empty:
                break
            for frame in audio_frames:
                sample_rate = frame.sample_rate
                sample_width = frame.format.bytes
                sound_chunks.append(frame.to_ndarray().tobytes())
        if not sound_chunks:
            return 'None'

        audio = sr.AudioData(b"".join(sound_chunks), sample_rate=sample_rate, sample_width=sample_width)
        st.info('Reconociendo...')
        query = transcribe_speech(audio)
        if 'salir' in query or 'detener' in query:
            speak("Hasta luego.")
            st.stop()
        return query
    except sr.UnknownValueError:
        speak('No se ha reconocido nada. Intenta de nuevo...')
    except sr.RequestError as e:
        st.error(f"Error en la solicitud al reconocimiento de voz: {e}")
    return 'None'
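
# transcribe_speech() is referenced above but not defined elsewhere in this file;
# the version below is a minimal sketch using SpeechRecognition's Google Web Speech
# recognizer. The "es-ES" language code is an assumption based on the app's
# Spanish prompts and gTTS voice.
def transcribe_speech(audio):
    recognizer = sr.Recognizer()
    # Lowercase the transcript so the 'salir' / 'detener' checks above match.
    return recognizer.recognize_google(audio, language="es-ES").lower()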

def format_prompt(message, history):
    global pre_prompt_sent
    prompt = "<s>"
    if not pre_prompt_sent and all(f"[INST] {pre_prompt} [/INST]" not in user_prompt for user_prompt, _ in history):
        prompt += f"[INST] {pre_prompt} [/INST]"
        pre_prompt_sent = True
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt
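
# For illustration, with pre_prompt already sent and one hypothetical prior
# exchange in history, format_prompt("¿qué hora es?", history) returns:
#   <s>[INST] hola [/INST] ¡Hola! ¿En qué puedo ayudarte?</s> [INST] ¿qué hora es? [/INST]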

def generate(user_input, history, temperature=None, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    temperature = float(temperature) if temperature is not None else 0.9
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(user_input, history)

    try:
        # return_full_text=False so only the newly generated tokens (not the prompt) are spoken.
        stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
        response = ""
        for response_token in stream:
            response += response_token.token.text

        response = ' '.join(response.split()).replace('</s>', '')
        audio_bytes = text_to_speech(response)
        return response, audio_bytes
    except Exception as e:
        return str(e), None

def text_to_speech(text):
    tts = gTTS(text=text, lang='es')
    audio_stream = BytesIO()
    # gTTS.save() expects a filename; write_to_fp() writes the MP3 bytes to a file-like object.
    tts.write_to_fp(audio_stream)
    audio_stream.seek(0)
    return audio_stream.read()

def speak(text):
    audio_bytes = text_to_speech(text)
    st.audio(audio_bytes, format="audio/mp3", start_time=0)
if "history" not in st.session_state: | |
st.session_state.history = [] | |
user_input = take_user_input() | |
output, audio_bytes = generate(user_input, history=st.session_state.history) | |
with st.container(width=900, height=400): | |
user_input_container = st.text_input("Tu entrada de usuario", value=user_input) | |
st.text_area("Respuesta", value=output, key="output_text", disabled=True) | |
if audio_bytes is not None: | |
st.markdown( | |
f""" | |
<audio autoplay="autoplay" controls="controls" src="data:audio/mp3;base64,{base64.b64encode(audio_bytes).decode()}" type="audio/mp3" id="audio_player"></audio> | |
""", | |
unsafe_allow_html=True | |
) |