SB-Test-v2 / main.py
Severian's picture
Update main.py
3782906 verified
raw
history blame contribute delete
No virus
9.79 kB
import gradio as gr
from sqlalchemy.exc import SQLAlchemyError
from utils import InnovativeIdea, init_db, get_db, SessionLocal, get_llm_response
from data_models import IdeaForm
from chatbot import InnovativeIdeaChatbot
from config import MODELS, DEFAULT_SYSTEM_PROMPT, STAGES
import logging
import os
from typing import Dict, Any
import re
import time
# Logging: root logger at INFO, each record rendered as "time - level - message".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
def _load_initial_values():
    """Initialise the database and return the stored idea's value per stage field.

    Calls ``init_db()``, reads the first ``InnovativeIdea`` row (inserting and
    committing an empty one when none exists) and returns
    ``{stage["field"]: value}`` for every entry in ``STAGES``. Values are read
    while the session is still open, and the session is closed on every path,
    including errors.

    Raises:
        SQLAlchemyError: propagated unchanged from any database call.
    """
    init_db()
    db = next(get_db())
    try:
        idea = db.query(InnovativeIdea).first()
        if idea is None:
            logging.info("No initial idea found in the database. Creating a new one.")
            idea = InnovativeIdea()
            db.add(idea)
            db.commit()
            db.refresh(idea)
        return {stage["field"]: getattr(idea, stage["field"], "") for stage in STAGES}
    finally:
        db.close()


def _save_idea_fields(values):
    """Write ``values`` onto the first ``InnovativeIdea`` row and commit.

    A ``team_roles`` list is joined with commas before storage. If no row
    exists yet, one is created. The session is always closed, so nothing is
    left open if an error is raised.
    """
    db = SessionLocal()
    try:
        idea = db.query(InnovativeIdea).first()
        if idea is None:
            idea = InnovativeIdea()
            db.add(idea)
        for key, value in values.items():
            if key == 'team_roles' and isinstance(value, list):
                value = ','.join(value)  # Convert list to string for database storage
            setattr(idea, key, value)
        db.commit()
    finally:
        db.close()


def create_gradio_interface():
    """Build the Gradio Blocks UI for the Innovative Idea Generator.

    Initialises the database, seeds the per-stage form textboxes from the
    stored idea, lays out the chat / form / settings components and wires all
    event handlers.

    Returns:
        gr.Blocks: the assembled (not yet launched) app.

    Raises:
        RuntimeError: if database initialisation raises ``SQLAlchemyError``
            (the original error is chained as ``__cause__``).
    """
    innovative_chatbot = InnovativeIdeaChatbot()
    default_stage = STAGES[0]["name"]
    # The stage Radio reports stage *names*; form data and DB columns use
    # stage *fields*. Map one to the other.
    stage_fields = {stage["name"]: stage["field"] for stage in STAGES}

    try:
        initial_values = _load_initial_values()
    except SQLAlchemyError as e:
        logging.error(f"Database initialization failed: {str(e)}")
        raise RuntimeError(f"Failed to initialize database: {str(e)}") from e

    def _stage_value(form_data, stage_name):
        """Return ``form_data``'s value for ``stage_name`` ("" if absent).

        Looks the stage up by its field name first, falling back to the stage
        name itself.
        """
        field = stage_fields.get(stage_name, stage_name)
        return form_data.get(field, form_data.get(stage_name, ""))

    def _form_field_updates(current_stage=None, value=""):
        """Return one ``gr.update`` per form textbox, in ``form_fields`` order.

        The textbox of ``current_stage`` receives ``value``; every other
        textbox gets a no-op update. With ``current_stage=None`` all updates
        are no-ops.
        """
        return [
            gr.update(value=value) if name == current_stage else gr.update()
            for name in form_fields
        ]

    def chatbot_function(message, history, model, system_prompt, thinking_budget, current_stage):
        """Stream the chatbot's reply to ``message`` and persist the form data.

        Yields tuples ``(history, *form_field_updates, msg_value)`` matching
        the ``chat_outputs`` list wired below: the updated conversation, one
        update per form textbox (only ``current_stage``'s box changes), and ""
        to clear the input box. When the stream finishes, the last form data
        is written to the database. Any exception is logged and shown as an
        error message in the chat instead of propagating.
        """
        # Guard against a None history and avoid mutating the caller's list.
        history = list(history or [])
        try:
            # A brand-new conversation starts with the greeting.
            if not history:
                history.append((None, innovative_chatbot.get_initial_greeting()))
            if not message or not message.strip():
                # Nothing to send to the model: refresh the chat, clear the box.
                yield (history, *_form_field_updates(), "")
                return

            # Add the user's turn once; streamed chunks then fill in the reply.
            history.append((message, None))
            form_data = None
            stream = innovative_chatbot.process_stage_input_stream(
                current_stage, message, model, system_prompt, thinking_budget
            )
            for chat_history, form_data in stream:
                # NOTE(review): assumes chat_history[-1][1] is the reply so far
                # (cumulative), so this turn is overwritten instead of a new
                # chat row being appended for every chunk.
                history[-1] = (message, chat_history[-1][1])
                stage_value = _stage_value(form_data, current_stage)
                yield (history, *_form_field_updates(current_stage, stage_value), "")

            if form_data is not None:
                _save_idea_fields(form_data)
        except Exception as e:
            logging.error(f"An error occurred in chatbot_function: {str(e)}", exc_info=True)
            yield (history + [(None, f"An error occurred: {str(e)}")], *_form_field_updates(), "")

    def fill_form(stage, model, thinking_budget):
        """Have the chatbot fill the whole form; return one value per stage."""
        form_data = innovative_chatbot.fill_out_form(stage, model, thinking_budget)
        return [form_data.get(s["field"], "") for s in STAGES]

    def clear_chat():
        """Reset the stored idea to an empty form, then reset the chatbot.

        Returns the chatbot's reset history followed by one value per stage.
        """
        _save_idea_fields(IdeaForm().dict())
        chat_history, form_data = innovative_chatbot.reset()
        return (chat_history, *[form_data.get(s["field"], "") for s in STAGES])

    def start_over():
        """Restart the conversation: new history, empty input, reset form and stage."""
        chat_history, form_data, initial_stage = innovative_chatbot.start_over()
        return (
            chat_history,  # Update the chatbot with the new chat history
            "",  # Clear the message input
            *[form_data.get(s["field"], "") for s in STAGES],  # Reset all form fields
            gr.update(value=initial_stage),  # Reset the stage selection
        )

    def toggle_mode(new_mode):
        """Switch between chat mode and direct form editing.

        Returns updates for, in order: chatbot, msg, submit (shown only in chat
        mode), every form textbox (editable only in direct-input mode) and the
        Submit Form button (shown only in direct-input mode).
        """
        direct = new_mode == "Direct Input"
        return (
            [gr.update(visible=not direct)] * 3
            + [gr.update(interactive=direct)] * len(STAGES)
            + [gr.update(visible=direct)]
        )

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Innovative Idea Generator")
        mode = gr.Radio(["Chatbot", "Direct Input"], label="Mode", value="Chatbot")
        with gr.Row():
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(label="Conversation", height=500)
                msg = gr.Textbox(label="Your input", placeholder="Type your brilliant idea here...")
                with gr.Row():
                    submit = gr.Button("Submit")
                    clear = gr.Button("Clear Chat")
                    start_over_btn = gr.Button("Start Over")
            with gr.Column(scale=1):
                stages = gr.Radio(
                    choices=[stage["name"] for stage in STAGES],
                    label="Ideation Stages",
                    value=default_stage,
                )
                form_fields = {
                    stage["name"]: gr.Textbox(
                        label=stage["question"],
                        placeholder=stage["example"],
                        value=initial_values[stage["field"]],
                        visible=(stage["name"] == default_stage),
                        interactive=False,
                    )
                    for stage in STAGES
                }
                fill_form_btn = gr.Button("Fill out Form")
                submit_form_btn = gr.Button("Submit Form", visible=False)
        with gr.Accordion("Advanced Settings", open=False):
            model = gr.Dropdown(choices=MODELS, label="Select Model", value=MODELS[0])
            system_prompt = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=5)
            # NOTE(review): 4098 looks like a typo for 4096 - confirm the intended cap.
            thinking_budget = gr.Slider(minimum=1, maximum=4098, value=2048, step=1, label="Max New Tokens")
            api_key = gr.Textbox(label="Hugging Face API Key", type="password")

        field_components = list(form_fields.values())
        chat_inputs = [msg, chatbot, model, system_prompt, thinking_budget, stages]
        # One output per form textbox so the *selected* stage's box is updated.
        chat_outputs = [chatbot, *field_components, msg]

        # Event handlers
        msg.submit(chatbot_function, inputs=chat_inputs, outputs=chat_outputs)
        submit.click(chatbot_function, inputs=chat_inputs, outputs=chat_outputs)
        fill_form_btn.click(fill_form, inputs=[stages, model, thinking_budget], outputs=field_components)
        clear.click(clear_chat, outputs=[chatbot, *field_components])

        # Show only the textbox of the selected stage.
        stages.change(
            lambda s: [gr.update(visible=(stage["name"] == s)) for stage in STAGES],
            inputs=[stages],
            outputs=field_components,
        )

        # Update API key when changed
        api_key.change(innovative_chatbot.set_api_key, inputs=[api_key])

        # Toggle between chatbot and direct input mode
        mode.change(
            toggle_mode,
            inputs=[mode],
            outputs=[chatbot, msg, submit, *field_components, submit_form_btn],
        )

        # NOTE(review): this only echoes the edited values back into the
        # textboxes; nothing is persisted. Confirm whether it should save.
        submit_form_btn.click(
            lambda *values: values,
            inputs=[form_fields[stage["name"]] for stage in STAGES],
            outputs=[form_fields[stage["name"]] for stage in STAGES],
        )

        start_over_btn.click(
            start_over,
            outputs=[chatbot, msg, *[form_fields[stage["name"]] for stage in STAGES], stages],
        )

        # Show the greeting as soon as the page loads.
        demo.load(
            lambda: ([[None, innovative_chatbot.get_initial_greeting()]], ""),
            outputs=[chatbot, msg],
        )

        # Forward edits of each textbox to the chatbot.
        # NOTE(review): the handler's return value is written back into the
        # same textbox, which can fire .change again; confirm update_form_field
        # returns the value unchanged (a None return would likely clear the box).
        for stage in STAGES:
            form_fields[stage["name"]].change(
                lambda value, s=stage["name"]: innovative_chatbot.update_form_field(s, value),
                inputs=[form_fields[stage["name"]]],
                outputs=[form_fields[stage["name"]]],
            )

    return demo
def main():
    """Build the app, turning any construction failure into a ``None`` result.

    Returns:
        The object produced by ``create_gradio_interface()``, or ``None`` if
        building it raised. Import errors get a dedicated hint about circular
        dependencies; any other exception gets a generic message. In both
        cases the traceback is logged and a short summary is printed.
    """
    try:
        return create_gradio_interface()
    except ImportError as exc:
        logging.exception(f"Import error: {str(exc)}")
        print(
            f"An import error occurred: {str(exc)}",
            "Please check your import statements and ensure there are no circular dependencies.",
            sep="\n",
        )
    except Exception as exc:
        logging.exception(f"Failed to initialize application: {str(exc)}")
        print(
            f"An unexpected error occurred: {str(exc)}",
            "Please check the log file for more details.",
            sep="\n",
        )
    return None
if __name__ == "__main__":
    # Script entry point: build the interface and serve it. A failure while
    # building or launching is logged with its traceback and summarised on
    # stdout instead of propagating.
    try:
        demo = main()
        if demo:
            demo.launch()
    except Exception as e:
        logging.exception(f"Failed to start the application: {str(e)}")
        print(
            f"An error occurred while starting the application: {str(e)}",
            "Please check the log file for more details.",
            sep="\n",
        )
        # TODO: consider surfacing a friendlier error (e.g. a minimal UI) here.