# Hugging Face Spaces status: Running
import gradio as gr | |
from sqlalchemy.exc import SQLAlchemyError | |
from utils import InnovativeIdea, init_db, get_db, SessionLocal, get_llm_response | |
from data_models import IdeaForm | |
from chatbot import InnovativeIdeaChatbot | |
from config import MODELS, DEFAULT_SYSTEM_PROMPT, STAGES | |
import logging | |
import os | |
from typing import Dict, Any | |
import re | |
import time | |
# Configure the root logger once, at import time: INFO and above,
# formatted as "timestamp - LEVEL - message".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
def _load_initial_field_values():
    """Initialise the database and return the stored value of every stage field.

    Ensures the schema exists (``init_db``) and that at least one
    ``InnovativeIdea`` row is present, creating an empty one if needed. The
    values are copied into a plain dict while the session is still open, so
    no ORM instance is touched after the session has been closed.

    Returns:
        dict mapping each ``stage["field"]`` in ``STAGES`` to the row's value
        for that attribute (``""`` if the row has no such attribute).

    Raises:
        RuntimeError: if a database operation raises ``SQLAlchemyError``
            (the original error is chained as the cause).
    """
    db = None
    try:
        init_db()
        db = next(get_db())
        idea = db.query(InnovativeIdea).first()
        if idea is None:
            logging.info("No initial idea found in the database. Creating a new one.")
            idea = InnovativeIdea()
            db.add(idea)
            db.commit()
            db.refresh(idea)
        return {stage["field"]: getattr(idea, stage["field"], "") for stage in STAGES}
    except SQLAlchemyError as e:
        logging.error(f"Database initialization failed: {str(e)}")
        raise RuntimeError(f"Failed to initialize database: {str(e)}") from e
    finally:
        # Closing in ``finally`` also releases the session when a query fails.
        if db is not None:
            db.close()


def _persist_form_data(form_data):
    """Copy ``form_data`` onto the first ``InnovativeIdea`` row and commit.

    Args:
        form_data: mapping of model attribute name -> value. A list stored
            under ``team_roles`` is joined with commas before assignment.

    The row is created if the table is empty. The session is always closed,
    even when the query or commit fails; the exception propagates to the caller.
    """
    db = SessionLocal()
    try:
        idea = db.query(InnovativeIdea).first()
        if idea is None:
            idea = InnovativeIdea()
            db.add(idea)
        for key, value in form_data.items():
            if key == 'team_roles' and isinstance(value, list):
                value = ','.join(value)  # Convert list to string for database storage
            setattr(idea, key, value)
        db.commit()
    finally:
        db.close()


def create_gradio_interface():
    """Build and return the Gradio Blocks app for the Innovative Idea Generator.

    The stage textboxes are pre-filled from the first ``InnovativeIdea`` row
    (created if missing). Chat replies, "Fill out Form", "Clear Chat" and
    "Start Over" all write to the full set of stage textboxes, and the form
    data produced by a chat turn is persisted back to the database.

    Raises:
        RuntimeError: if database initialisation fails.
    """
    innovative_chatbot = InnovativeIdeaChatbot()
    default_stage = STAGES[0]["name"]
    initial_values = _load_initial_field_values()

    def _field_updates(form_data):
        """Return one ``gr.update`` per stage textbox, in ``STAGES`` order.

        Fields absent from ``form_data`` get a bare ``gr.update()`` so their
        current content is left untouched instead of being cleared.
        """
        return [
            gr.update(value=form_data[cfg["field"]]) if cfg["field"] in form_data else gr.update()
            for cfg in STAGES
        ]

    def chatbot_function(message, history, model, system_prompt, thinking_budget, current_stage):
        """Stream the chatbot's reply for ``current_stage`` into the conversation.

        Yields ``(history, *one_update_per_stage_field, msg_value)`` tuples,
        matching ``chat_outputs`` below. With an empty history only the
        initial greeting is shown (``message`` is not processed).
        """
        history = list(history or [])
        untouched = [gr.update() for _ in STAGES]
        try:
            # If this is the first message, get the initial greeting
            if not history:
                history.append((None, innovative_chatbot.get_initial_greeting()))
                yield history, *untouched, ""
                return
            form_data = {}
            history.append((message, ""))
            for chat_history, form_data in innovative_chatbot.process_stage_input_stream(
                current_stage, message, model, system_prompt, thinking_budget
            ):
                # Each partial result carries the reply so far: overwrite this
                # turn rather than appending a new turn per streamed chunk.
                history[-1] = (message, chat_history[-1][1])
                yield history, *_field_updates(form_data), ""
            # Update the database with the final form data of this turn
            _persist_form_data(form_data)
        except Exception as e:
            logging.error(f"An error occurred in chatbot_function: {str(e)}", exc_info=True)
            yield history + [(None, f"An error occurred: {str(e)}")], *untouched, ""

    def fill_form(stage, model, thinking_budget):
        """Ask the chatbot to fill the form; return one value per stage field."""
        form_data = innovative_chatbot.fill_out_form(stage, model, thinking_budget)
        return [form_data.get(cfg["field"], "") for cfg in STAGES]

    def clear_chat():
        """Reset the stored idea to an empty ``IdeaForm`` and reset the chatbot."""
        _persist_form_data(IdeaForm().dict())
        chat_history, form_data = innovative_chatbot.reset()
        return chat_history, *[form_data.get(cfg["field"], "") for cfg in STAGES]

    def start_over():
        """Restart the chatbot; clear the input, refill fields, reset the stage."""
        chat_history, form_data, initial_stage = innovative_chatbot.start_over()
        return (
            chat_history,  # Update the chatbot with the new chat history
            "",  # Clear the message input
            *[form_data.get(cfg["field"], "") for cfg in STAGES],  # Reset all form fields
            gr.update(value=initial_stage),  # Reset the stage selection
        )

    def toggle_mode(new_mode):
        """Show chat controls in "Chatbot" mode; make fields editable in "Direct Input"."""
        direct = new_mode == "Direct Input"
        return (
            [gr.update(visible=not direct)] * 3
            + [gr.update(interactive=direct)] * len(STAGES)
            + [gr.update(visible=direct)]
        )

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Innovative Idea Generator")
        mode = gr.Radio(["Chatbot", "Direct Input"], label="Mode", value="Chatbot")
        with gr.Row():
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(label="Conversation", height=500)
                msg = gr.Textbox(label="Your input", placeholder="Type your brilliant idea here...")
                with gr.Row():
                    submit = gr.Button("Submit")
                    clear = gr.Button("Clear Chat")
                    start_over_btn = gr.Button("Start Over")
            with gr.Column(scale=1):
                stages = gr.Radio(
                    choices=[cfg["name"] for cfg in STAGES],
                    label="Ideation Stages",
                    value=default_stage,
                )
                form_fields = {
                    cfg["name"]: gr.Textbox(
                        label=cfg["question"],
                        placeholder=cfg["example"],
                        value=initial_values[cfg["field"]],
                        visible=(cfg["name"] == default_stage),
                        interactive=False,
                    )
                    for cfg in STAGES
                }
                fill_form_btn = gr.Button("Fill out Form")
                submit_form_btn = gr.Button("Submit Form", visible=False)
        with gr.Accordion("Advanced Settings", open=False):
            model = gr.Dropdown(choices=MODELS, label="Select Model", value=MODELS[0])
            system_prompt = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=5)
            thinking_budget = gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max New Tokens")
            api_key = gr.Textbox(label="Hugging Face API Key", type="password")

        # Event handlers. The field list is in STAGES order, which matches the
        # order of the values returned by the handlers above.
        field_components = list(form_fields.values())
        chat_inputs = [msg, chatbot, model, system_prompt, thinking_budget, stages]
        chat_outputs = [chatbot] + field_components + [msg]
        msg.submit(chatbot_function, inputs=chat_inputs, outputs=chat_outputs)
        submit.click(chatbot_function, inputs=chat_inputs, outputs=chat_outputs)
        fill_form_btn.click(fill_form,
                            inputs=[stages, model, thinking_budget],
                            outputs=field_components)
        clear.click(clear_chat, outputs=[chatbot] + field_components)
        # Update form field visibility based on selected stage
        stages.change(
            lambda s: [gr.update(visible=(cfg["name"] == s)) for cfg in STAGES],
            inputs=[stages],
            outputs=field_components,
        )
        # Update API key when changed
        api_key.change(innovative_chatbot.set_api_key, inputs=[api_key])
        # Toggle between chatbot and direct input mode
        mode.change(
            toggle_mode,
            inputs=[mode],
            outputs=[chatbot, msg, submit] + field_components + [submit_form_btn],
        )
        # Handle direct form submission (echoes the values back into the fields)
        submit_form_btn.click(
            lambda *values: values,
            inputs=field_components,
            outputs=field_components,
        )
        start_over_btn.click(
            start_over,
            outputs=[chatbot, msg] + field_components + [stages],
        )
        # Display the initial greeting when the interface loads
        demo.load(lambda: ([[None, innovative_chatbot.get_initial_greeting()]], ""),
                  outputs=[chatbot, msg])
        # Push every field edit to the chatbot.
        # NOTE(review): the return value of update_form_field is written back
        # into the textbox; confirm it returns the value, since a None return
        # may clear the field.
        for cfg in STAGES:
            field = form_fields[cfg["name"]]
            field.change(
                lambda value, s=cfg["name"]: innovative_chatbot.update_form_field(s, value),
                inputs=[field],
                outputs=[field],
            )
    return demo
def main():
    """Build the Gradio app and return it, or return None if construction fails.

    Any failure is logged with its traceback and summarised on stdout. An
    ImportError additionally gets a hint about import and circular-import problems.
    """
    try:
        return create_gradio_interface()
    except ImportError as err:
        logging.exception(f"Import error: {str(err)}")
        print(f"An import error occurred: {str(err)}")
        print("Please check your import statements and ensure there are no circular dependencies.")
    except Exception as err:
        logging.exception(f"Failed to initialize application: {str(err)}")
        print(f"An unexpected error occurred: {str(err)}")
        print("Please check the log file for more details.")
    return None
if __name__ == "__main__":
    # Script entry point: build the app, serve it if construction succeeded,
    # and report any error raised while starting it.
    try:
        demo = main()
        if demo:
            demo.launch()
    except Exception as exc:
        logging.exception(f"Failed to start the application: {str(exc)}")
        print(f"An error occurred while starting the application: {str(exc)}")
        print("Please check the log file for more details.")