import gradio as gr

from conversation import Conversation

# Order of the generation sliders; handlers receive slider values in this
# order and models expose their defaults under these keys.
_GENERATION_PARAM_KEYS = (
    "temperature",
    "repetition_penalty",
    "max_new_tokens",
    "top_k",
    "top_p",
)


def _format_bot_config(bot_config):
    """Render the bot's ``memory`` and ``prompt`` fields as Markdown."""
    return f"# Memory\n{bot_config['memory']}\n# Prompt\n{bot_config['prompt']}"


def _initial_history(bot_config):
    """Return a new chat history holding only the bot's ``firstMessage``."""
    return [(None, bot_config["firstMessage"])]


def _build_generation_params(*slider_values):
    """Pair slider values with ``_GENERATION_PARAM_KEYS``, in that order."""
    return dict(zip(_GENERATION_PARAM_KEYS, slider_values))


def get_tab_playground(download_bot_config, get_bot_profile, model_mapping):
    """Build the "Playground" tab: a chat UI against a selectable model and bot.

    Creates Gradio components and registers their event handlers as a side
    effect; nothing is returned. Presumably meant to be called inside an
    active ``gr.Blocks``/tab context -- confirm against the caller.

    Args:
        download_bot_config: Callable taking a bot ID and returning a mapping
            with at least the keys ``memory``, ``prompt`` and ``firstMessage``.
        get_bot_profile: Callable taking a bot config and returning the value
            displayed in the ``gr.HTML`` profile component.
        model_mapping: Mapping from model tag to a model object exposing a
            ``generation_params`` mapping and a
            ``generate_response(conversation, params)`` method. The first key
            is used as the default selection, so the mapping must not be empty.
    """
    gr.Markdown(
        "# 🎢 Playground 🎢\n"
        "## Rules\n"
        "* Chat with any model you would like with any bot from the Chai app.\n"
        "* Click “Clear” to start a new conversation.\n"
    )

    default_bot_id = "_bot_e21de304-6151-4a04-b025-4c553ae8cbca"
    bot_config = download_bot_config(default_bot_id)
    # The currently loaded bot config is kept in gr.State so handlers can
    # read it; it is replaced when a bot is reloaded.
    user_state = gr.State(bot_config)

    with gr.Row():
        bot_id = gr.Textbox(label="Chai bot ID", value=default_bot_id, interactive=True)
        reload_bot_button = gr.Button("Reload bot")

    bot_profile = gr.HTML(get_bot_profile(bot_config))
    with gr.Accordion("Bot config:", open=False):
        bot_config_text = gr.Markdown(_format_bot_config(bot_config))

    chatbot = gr.Chatbot(_initial_history(bot_config))
    msg = gr.Textbox(label="Message", value="Hi there!")
    with gr.Row():
        send = gr.Button("Send")
        regenerate = gr.Button("Regenerate")
        clear = gr.Button("Clear")

    values = list(model_mapping.keys())
    model_tag = gr.Dropdown(values, value=values[0], label="Model version")
    # Sliders start from the default model's own generation parameters.
    defaults = model_mapping[model_tag.value].generation_params

    with gr.Accordion("Generation parameters", open=False):
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=defaults["temperature"],
            interactive=True,
            label="Temperature",
        )
        repetition_penalty = gr.Slider(
            minimum=0.0,
            maximum=2.0,
            value=defaults["repetition_penalty"],
            interactive=True,
            label="Repetition penalty",
        )
        max_new_tokens = gr.Slider(
            minimum=1,
            maximum=512,
            value=defaults["max_new_tokens"],
            interactive=True,
            label="Max new tokens",
        )
        top_k = gr.Slider(
            minimum=1,
            maximum=100,
            value=defaults["top_k"],
            interactive=True,
            label="Top-K",
        )
        top_p = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=defaults["top_p"],
            interactive=True,
            label="Top-P",
        )

    def respond(message, chat_history, user_state, model_tag, temperature,
                repetition_penalty, max_new_tokens, top_k, top_p):
        """Send ``message`` to the selected model and append the exchange.

        Returns an empty string (to clear the message box) and the updated
        chat history.
        """
        custom_generation_params = _build_generation_params(
            temperature, repetition_penalty, max_new_tokens, top_k, top_p
        )
        conv = Conversation(user_state)
        conv.set_chat_history(chat_history)
        conv.add_user_message(message)
        model = model_mapping[model_tag]
        bot_message = model.generate_response(conv, custom_generation_params)
        chat_history.append((message, bot_message))
        return "", chat_history

    def clear_chat(chat_history, user_state):
        """Reset the chat to the loaded bot's first message only."""
        return _initial_history(user_state)

    def regenerate_response(chat_history, user_state, model_tag, temperature,
                            repetition_penalty, max_new_tokens, top_k, top_p):
        """Replace the bot reply in the last chat row with a new generation.

        Returns the updated chat history. An empty history is returned
        unchanged, since there is no row to regenerate.
        """
        if not chat_history:
            return chat_history

        custom_generation_params = _build_generation_params(
            temperature, repetition_penalty, max_new_tokens, top_k, top_p
        )
        user_message = chat_history[-1][0]
        # Blank out the previous reply before handing the history to the
        # conversation, so the old answer is not part of the model's input.
        chat_history[-1] = (user_message, None)
        model = model_mapping[model_tag]
        conv = Conversation(user_state)
        conv.set_chat_history(chat_history)
        bot_message = model.generate_response(conv, custom_generation_params)
        chat_history[-1] = (user_message, bot_message)
        return chat_history

    def reload_bot(bot_id, bot_profile, chat_history):
        """Download the config for ``bot_id`` and reset the UI to that bot.

        ``bot_profile`` and ``chat_history`` are received because they are
        wired as inputs, but their current values are not used.

        Returns the new profile, a fresh chat history, the new bot config
        (stored into the state) and the config rendered as Markdown.
        """
        bot_config = download_bot_config(bot_id)
        return (
            get_bot_profile(bot_config),
            _initial_history(bot_config),
            bot_config,
            _format_bot_config(bot_config),
        )

    def get_generation_args(model_tag):
        """Return the selected model's default generation params, in slider order."""
        params = model_mapping[model_tag].generation_params
        return tuple(params[key] for key in _GENERATION_PARAM_KEYS)

    # Component lists shared by the handlers below; slider order must match
    # _GENERATION_PARAM_KEYS.
    generation_sliders = [temperature, repetition_penalty, max_new_tokens, top_k, top_p]
    generation_inputs = [model_tag, *generation_sliders]
    respond_inputs = [msg, chatbot, user_state, *generation_inputs]

    model_tag.change(get_generation_args, [model_tag], generation_sliders, queue=False)
    send.click(respond, respond_inputs, [msg, chatbot], queue=False)
    msg.submit(respond, respond_inputs, [msg, chatbot], queue=False)
    clear.click(clear_chat, [chatbot, user_state], [chatbot], queue=False)
    regenerate.click(
        regenerate_response,
        [chatbot, user_state, *generation_inputs],
        [chatbot],
        queue=False,
    )
    reload_bot_button.click(
        reload_bot,
        [bot_id, bot_profile, chatbot],
        [bot_profile, chatbot, user_state, bot_config_text],
        queue=False,
    )