import gradio as gr
from huggingface_hub import InferenceClient
import base64
import io
from PIL import Image

# Initialize the inference client with the specified model
client = InferenceClient("mistralai/Pixtral-Large-Instruct-2411")


def image_to_base64(image_path):
    """Convert an image file to a base64 string."""
    with open(image_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
    return encoded_string


def base64_to_image(base64_string):
    """Convert a base64 string to an image."""
    image_data = base64.b64decode(base64_string)
    image = Image.open(io.BytesIO(image_data))
    return image
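

# Illustrative only: a quick round-trip through the two helpers above, assuming a
# local file "example.jpg" exists. The app itself never calls this function.
def _example_base64_roundtrip(path="example.jpg"):
    encoded = image_to_base64(path)      # file on disk -> base64 string
    restored = base64_to_image(encoded)  # base64 string -> PIL.Image
    return restored.size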


def describe_image(image, system_message, max_tokens, temperature, top_p):
    """Describe an image using the model."""
    if image is None:
        return "No image uploaded.", []
    # Convert the PIL image to a base64-encoded JPEG (force RGB so JPEG saving works)
    buffered = io.BytesIO()
    image.convert("RGB").save(buffered, format="JPEG")
    image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
    # Pass the image as an OpenAI-style multimodal message (data URL) so the
    # vision model receives actual image content instead of raw base64 text
    messages = [
        {"role": "system", "content": system_message},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe the following image:"},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}},
            ],
        },
    ]
    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        if token:  # the final stream chunk may carry no content
            response += token
    return response, [("User: Describe the following image:", response)]
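

# Illustrative only: how describe_image could be called outside the Gradio UI,
# assuming "example.jpg" exists and the Inference API is reachable. Not invoked by the app.
def _example_describe(path="example.jpg"):
    img = Image.open(path)
    description, history = describe_image(
        img,
        system_message="You are a friendly Chatbot.",
        max_tokens=256,
        temperature=0.7,
        top_p=0.95,
    )
    return description, history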


def respond(
    user_message: str,
    chat_history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """
    Generate a response based on the chat history and the configuration parameters.

    Args:
        user_message (str): The user's message.
        chat_history (list[tuple[str, str]]): The chat history.
        system_message (str): System message that defines the chatbot's behavior.
        max_tokens (int): Maximum number of tokens to generate.
        temperature (float): Temperature for text sampling.
        top_p (float): Top-p parameter for text sampling.

    Yields:
        str: The response generated by the model so far.
    """
    # Build the message list from the system prompt and the chat history
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in chat_history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": user_message})

    response = ""
    try:
        # Stream the model's response and yield the accumulated text after each chunk
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = chunk.choices[0].delta.content
            if token:  # the final stream chunk may carry no content
                response += token
            yield response
    except Exception as e:
        yield f"Error while getting a response: {str(e)}"


def main():
    """
    Build and launch the chat interface.
    """
    def update_chat(user_message, image, chat_history, system_message, max_tokens, temperature, top_p):
        if image is not None:
            # Describe the uploaded image and feed the description back in as the user message
            description, new_history = describe_image(image, system_message, max_tokens, temperature, top_p)
            chat_history.extend(new_history)
            user_message = description
        if user_message:
            response_generator = respond(
                user_message,
                chat_history,
                system_message,
                max_tokens,
                temperature,
                top_p,
            )
            # Append a single entry and keep updating it while the response streams,
            # instead of adding a new pair for every partial chunk
            chat_history.append((user_message, ""))
            for response in response_generator:
                chat_history[-1] = (user_message, response)
                yield "", chat_history, chat_history
        else:
            yield "", chat_history, chat_history

    with gr.Blocks(title="Chatbot with MistralAI", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Chatbot with MistralAI")
        gr.Markdown("A friendly chatbot based on the MistralAI Pixtral-Large-Instruct-2411 model that can describe images and keep a chat history.")
        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(label="Conversation")
                user_message = gr.Textbox(label="User Message", placeholder="Type your message here...")
                with gr.Row():
                    submit_button = gr.Button("Send")
                    clear_button = gr.Button("Clear")
            with gr.Column(scale=2):
                image_input = gr.Image(label="Upload Image", type="pil")
                image_description = gr.Textbox(label="Image Description", interactive=False)
        with gr.Row():
            system_message = gr.Textbox(
                value="You are a friendly Chatbot.",
                label="System Message",
                placeholder="Defines the chatbot's behavior."
            )
            max_tokens = gr.Slider(
                minimum=1,
                maximum=2048,
                value=512,
                step=1,
                label="Max New Tokens",
                info="Maximum number of tokens to generate."
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=4.0,
                value=0.7,
                step=0.1,
                label="Temperature",
                info="Controls the creativity of the response."
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (Nucleus Sampling)",
                info="Nucleus sampling parameter."
            )
        chat_history = gr.State([])

        submit_button.click(
            fn=update_chat,
            inputs=[user_message, image_input, chat_history, system_message, max_tokens, temperature, top_p],
            outputs=[user_message, chatbot, chat_history]
        )
        # Reset the textbox to an empty string and clear both the chatbot display and the stored history
        clear_button.click(
            fn=lambda: ("", [], []),
            inputs=[],
            outputs=[user_message, chatbot, chat_history]
        )
        image_input.upload(
            fn=describe_image,
            inputs=[image_input, system_message, max_tokens, temperature, top_p],
            outputs=[image_description, chat_history]
        )

    demo.launch()


if __name__ == "__main__":
    main()