import os
import dotenv
import gradio as gr # type: ignore
from mistralai.client import MistralClient # type: ignore
from mistralai.models.chat_completion import ChatMessage # type: ignore
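# Load variables from a local .env file so MISTRAL_API_KEY can be supplied there.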
dotenv.load_dotenv()
MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY")
TITLE = """
MistralAI Playground 💬
"""
DUPLICATE = """
"""
AVATAR_IMAGES = (None, "https://media.roboflow.com/spaces/gemini-icon.png")
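# UI components are defined up front and rendered later inside the Blocks layout.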
chatbot_component = gr.Chatbot(
label="MistralAI", bubble_full_width=False, avatar_images=AVATAR_IMAGES, scale=2, height=400
)
text_prompt_component = gr.Textbox(placeholder="Hi there! [press Enter]", show_label=False, autofocus=True, scale=8)
run_button_component = gr.Button(value="Run", variant="primary", scale=1)
clear_button_component = gr.ClearButton(value="Clear", variant="secondary", scale=1)
mistral_key_component = gr.Textbox(
label="MISTRAL API KEY",
value="",
type="password",
placeholder="...",
info="You have to provide your own MISTRAL_API_KEY for this app to function properly",
visible=MISTRAL_API_KEY is None,
)
model_component = gr.Dropdown(
choices=["mistral-tiny", "mistral-small", "mistral-medium"],
label="Model",
value="mistral-small",
scale=1,
type="value",
)
temperature_component = gr.Slider(
minimum=0,
maximum=1.0,
value=0.7,
step=0.05,
label="Temperature",
    info=(
        "What sampling temperature to use, between 0.0 and 1.0. "
        "Higher values like 0.8 will make the output more random, "
        "while lower values like 0.2 will make it more focused and deterministic. "
        "We generally recommend altering this or top_p but not both."
    ),
)
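# Component lists wired as inputs to the user() and bot() callbacks below.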
user_inputs = [
text_prompt_component,
chatbot_component,
]
bot_inputs = [
mistral_key_component,
model_component,
temperature_component,
chatbot_component,
]
client: MistralClient | None = None  # created lazily on the first request
def preprocess_chat_history(history):
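    """Convert Gradio's tuple-style history into MistralAI ChatMessage objects."""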
chat_history = []
for human, assistant in history:
if human:
chat_history.append(ChatMessage(role="user", content=human))
if assistant:
chat_history.append(ChatMessage(role="assistant", content=assistant))
return chat_history
def bot(
mistral_key: str | None,
model: str,
temperature: float,
history,
):
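    """Stream the model's reply into the last history entry, yielding partial updates."""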
    if not history:
        yield history
        return
mistral_key = mistral_key or MISTRAL_API_KEY
if not mistral_key:
raise ValueError("MISTRAL_API_KEY is not set. Please follow the instructions in the README to set it up.")
    global client
    if client is None:
        # The client is created once, with the first key provided.
        # TODO: guard this with a lock if running free-threaded (no GIL).
        client = MistralClient(api_key=mistral_key)
chat_history = preprocess_chat_history(history)
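    # Start the assistant turn with an empty string and fill it as tokens stream in.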
history[-1][1] = ""
for chunk in client.chat_stream(model=model, messages=chat_history, temperature=temperature):
print("chunk", chunk)
if chunk.choices and chunk.choices[0].delta.content:
history[-1][1] += chunk.choices[0].delta.content
yield history
def user(text_prompt: str, history):
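    """Append the user's prompt to the chat history and clear the textbox."""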
    if text_prompt:
        # Append as a list (not a tuple) so bot() can fill in the reply in place.
        history.append([text_prompt, None])
return "", history
with gr.Blocks() as demo:
gr.HTML(TITLE)
with gr.Column():
mistral_key_component.render()
chatbot_component.render()
with gr.Row():
text_prompt_component.render()
run_button_component.render()
clear_button_component.render()
with gr.Accordion("Parameters", open=False):
model_component.render()
temperature_component.render()
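    # The Run button and pressing Enter in the textbox trigger the same user -> bot pipeline.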
run_button_component.click(
fn=user, inputs=user_inputs, outputs=[text_prompt_component, chatbot_component], queue=False
).then(
fn=bot,
inputs=bot_inputs,
outputs=[chatbot_component],
)
clear_button_component.click(lambda: (None, None), outputs=[text_prompt_component, chatbot_component], queue=False)
text_prompt_component.submit(
fn=user, inputs=user_inputs, outputs=[text_prompt_component, chatbot_component], queue=False
).then(
fn=bot,
inputs=bot_inputs,
outputs=[chatbot_component],
)
demo.queue(max_size=99).launch(debug=False, show_error=True)