# SB-Test-v2 / chatbot.py
from typing import List, Tuple, Dict, Any, Optional
import logging
import re

from data_models import IdeaForm, IDEA_STAGES
from config import DEFAULT_SYSTEM_PROMPT, STAGES
from utils import (
    get_llm_response, extract_form_data, save_idea_to_database,
    load_idea_from_database, update_idea_in_database,
    get_db, clear_database, init_db, create_tables,
    perform_web_search, optimize_search_query,
    SessionLocal, InnovativeIdea
)


class InnovativeIdeaChatbot:
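    """Stateful chatbot that guides a user through the idea-development stages.

    It keeps the running chat history, the structured IdeaForm, and the database id of
    the saved idea, and delegates LLM calls, web search, and persistence to utils.py.
    """
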
    def __init__(self):
        # Ensure the database schema exists before any reads or writes.
        create_tables()
        init_db()
        self.idea_form = IdeaForm()
        self.chat_history = []
        self.idea_id = None
        self.current_stage = None
        self.api_key = None
        self.greeted = False
        self.add_system_message(self.get_initial_greeting())

    def get_initial_greeting(self) -> str:
        greeting = """
Welcome to the Innovative Idea Generator! I'm Myamoto, your AI assistant designed to help you refine and develop your innovative ideas.

Here's how we'll work together:
1. We'll go through 10 stages to explore different aspects of your idea.
2. At each stage, I'll ask you questions and provide feedback to help you think deeper about your concept.
3. You can ask me questions at any time or request more information on a topic.
4. If you want to perform a web search for additional information, just start your message with '@' followed by your search query.
5. When you're ready to move to the next stage, simply type 'next'.

Let's start by exploring your innovative idea! What's the name of your idea, or would you like help coming up with one?
"""
        self.greeted = True
        return greeting

    def add_system_message(self, message: str):
        self.chat_history.append(("System", message))

    def set_api_key(self, api_key: str):
        self.api_key = api_key

    def activate_stage(self, stage_name: str) -> Optional[str]:
        self.current_stage = stage_name
        for stage in STAGES:
            if stage["name"] == stage_name:
                return f"Let's work on the '{stage_name}' stage. {stage['question']}"
        return None

    def process_stage_input(self, stage_name: str, message: str, model: str, system_prompt: str, thinking_budget: int) -> Tuple[List[Tuple[str, str]], Dict[str, Any]]:
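        """Handle one user message for the given stage and return (chat_history, form_dict).

        Messages beginning with '@' (e.g. "@solar desalination costs" -- an illustrative query)
        are treated as web search requests instead of being sent to the LLM.
        """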
        if self.current_stage != stage_name:
            activation_message = self.activate_stage(stage_name)
            if activation_message is None:
                error_message = f"Error: Unable to activate stage '{stage_name}'. Please check if the stage name is correct."
                self.chat_history.append(("System", error_message))
                return self.chat_history, self.idea_form.dict()
            self.chat_history.append(("System", activation_message))
        # Check for a web search request ('@' prefix)
        if message.startswith('@'):
            search_query = message[1:].strip()
            optimized_query = optimize_search_query(search_query, model)
            search_results = perform_web_search(optimized_query)
            self.chat_history.append(("Human", message))
            self.chat_history.append(("AI", f"Here are the search results for '{optimized_query}':\n\n{search_results}"))
            return self.chat_history, self.idea_form.dict()
        # Generate the prompt for the current stage
        stage_prompt = self.generate_prompt_for_stage(stage_name)
        # Use the DEFAULT_SYSTEM_PROMPT from config.py (the `system_prompt` argument is currently unused)
        formatted_system_prompt = DEFAULT_SYSTEM_PROMPT.format(
            current_stage=stage_name,
            stage_prompt=stage_prompt
        )
        # Combine the formatted system prompt and the user's input
        combined_prompt = f"{formatted_system_prompt}\n\nUser input: {message}"
        # Get the LLM response
        llm_response = get_llm_response(combined_prompt, model, thinking_budget, self.api_key)
        # Parse the LLM response to extract only the user-facing content
        parsed_response = self.parse_llm_response(llm_response)
        # Add the interaction to the chat history
        self.chat_history.append(("Human", message))
        self.chat_history.append(("AI", parsed_response))
        # Extract form data from the LLM response
        form_data = extract_form_data(llm_response)
        # Update the idea form
        if stage_name in form_data:
            setattr(self.idea_form, stage_name.lower().replace(" ", "_"), form_data[stage_name])
        return self.chat_history, self.idea_form.dict()

    def parse_llm_response(self, response: str) -> str:
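        """Strip internal tags and markup from an LLM response, keeping only user-facing text.

        Illustrative example (hypothetical input):
            parse_llm_response("<analysis>internal notes</analysis>Great idea! <form_data>Idea Name: X</form_data>")
            -> "Great idea!"
        """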
        # Remove content within <form_data> tags
        response = re.sub(r'<form_data>.*?</form_data>', '', response, flags=re.DOTALL)
        # Remove content within <reflection> tags
        response = re.sub(r'<reflection>.*?</reflection>', '', response, flags=re.DOTALL)
        # Remove content within <analysis> tags
        response = re.sub(r'<analysis>.*?</analysis>', '', response, flags=re.DOTALL)
        # Remove content within <summary> tags
        response = re.sub(r'<summary>.*?</summary>', '', response, flags=re.DOTALL)
        # Remove content within <step> tags
        response = re.sub(r'<step>.*?</step>', '', response, flags=re.DOTALL)
        # Remove any remaining HTML-like tags
        response = re.sub(r'<[^>]+>', '', response)
        # Remove extra whitespace and newlines
        response = re.sub(r'\s+', ' ', response).strip()
        return response

    def fill_out_form(self, current_stage: str, model: str, thinking_budget: int) -> Dict[str, str]:
        form_data = {}
        for stage in STAGES:
            stage_name = stage["name"]
            if stage_name == current_stage:
                # Generate new data for the current stage
                form_data[stage["field"]] = self.generate_form_data(stage_name, model, thinking_budget)
            else:
                # Use existing data for other stages
                form_data[stage["field"]] = getattr(self.idea_form, stage["field"], "")
        # Update the idea form
        for stage in STAGES:
            setattr(self.idea_form, stage["field"], form_data[stage["field"]])
        # Save to the database; create the session outside the try block so it is
        # always defined in the except/finally clauses.
        new_session = SessionLocal()
        try:
            if self.idea_id:
                update_idea_in_database(self.idea_id, self.idea_form, new_session)
            else:
                self.idea_id = save_idea_to_database(self.idea_form, new_session)
            new_session.commit()
        except Exception as e:
            logging.error(f"Error saving idea to database: {str(e)}")
            new_session.rollback()
        finally:
            new_session.close()
        return form_data

    def generate_prompt_for_stage(self, stage: str) -> str:
        for s in IDEA_STAGES:
            if s.name == stage:
                return f"We are currently working on the '{stage}' stage. {s.question}"
        return f"We are currently working on the '{stage}' stage. Please provide relevant information."

    def reset(self):
        self.chat_history = []
        self.idea_form = IdeaForm()
        self.idea_id = None
        self.current_stage = None
        self.add_system_message(self.get_initial_greeting())
        # Clear stored ideas; create the session outside the try block so rollback/close
        # always have a valid handle.
        new_session = SessionLocal()
        try:
            clear_database(new_session)
            new_session.commit()
        except Exception as e:
            logging.error(f"Error clearing database: {str(e)}")
            new_session.rollback()
        finally:
            new_session.close()
        return self.chat_history, self.idea_form.dict()

    def start_over(self):
        self.chat_history = []
        self.idea_form = IdeaForm()
        self.current_stage = None
        self.add_system_message(self.get_initial_greeting())
        new_session = SessionLocal()
        try:
            # Clear the existing database
            clear_database(new_session)
            # Create a new empty idea and remember its id
            new_idea = InnovativeIdea()
            new_session.add(new_idea)
            new_session.commit()
            new_session.refresh(new_idea)
            self.idea_id = new_idea.id
        except Exception as e:
            logging.error(f"Error in start_over: {str(e)}")
            new_session.rollback()
        finally:
            new_session.close()
        return self.chat_history, self.idea_form.dict(), STAGES[0]["name"]

    def update_idea_form(self, stage_name: str, form_data: str):
        setattr(self.idea_form, stage_name.lower().replace(" ", "_"), form_data)
        # Persist the change; create the session outside the try block so rollback/close
        # always have a valid handle.
        new_session = SessionLocal()
        try:
            if self.idea_id:
                update_idea_in_database(self.idea_id, self.idea_form, new_session)
            else:
                self.idea_id = save_idea_to_database(self.idea_form, new_session)
            new_session.commit()
        except Exception as e:
            logging.error(f"Error updating idea form: {str(e)}")
            new_session.rollback()
        finally:
            new_session.close()

    def generate_form_data(self, stage: str, model: str, thinking_budget: int) -> str:
        # Prepare the conversation history for the LLM
        conversation = "\n".join([f"{role}: {message}" for role, message in self.chat_history])
        stage_prompt = self.generate_prompt_for_stage(stage)
        formatted_system_prompt = DEFAULT_SYSTEM_PROMPT.format(
            current_stage=stage,
            stage_prompt=stage_prompt
        )
        prompt = f"""
{formatted_system_prompt}

Based on the following conversation, extract the relevant information for the '{stage}' stage of the innovative idea:

{conversation}

Please provide a concise summary for the '{stage}' stage, focusing only on the information relevant to this stage.
Your response should be structured as follows:
1. A brief analysis of the conversation related to this stage.
2. A concise summary of the key points relevant to this stage.
3. A suggested form entry for this stage, enclosed in <form_data></form_data> tags.
The form entry should be in the format: "{stage}: Content"
Remember to keep the form entry concise and directly related to the '{stage}' stage. Do not include information from other stages in the form entry.
"""
        # Get LLM response
        llm_response = get_llm_response(prompt, model, thinking_budget, self.api_key)
        # Extract form data from the LLM response
        form_data = extract_form_data(llm_response)
        return form_data.get(stage, "")

    def process_stage_input_stream(self, stage_name: str, message: str, model: str, system_prompt: str, thinking_budget: int):
        # Generator variant of process_stage_input. Because get_llm_response returns a
        # complete response, this currently yields a single final update rather than
        # streaming tokens incrementally.
        if self.current_stage != stage_name:
            activation_message = self.activate_stage(stage_name)
            if activation_message is None:
                error_message = f"Error: Unable to activate stage '{stage_name}'. Please check if the stage name is correct."
                self.chat_history.append(("System", error_message))
                yield self.chat_history, self.idea_form.dict()
                return
            self.chat_history.append(("System", activation_message))
        # Check for a web search request ('@' prefix)
        if message.startswith('@'):
            search_query = message[1:].strip()
            optimized_query = optimize_search_query(search_query, model)
            search_results = perform_web_search(optimized_query)
            self.chat_history.append(("Human", message))
            self.chat_history.append(("AI", f"Here are the search results for '{optimized_query}':\n\n{search_results}"))
            yield self.chat_history, self.idea_form.dict()
            return
        # Generate the prompt for the current stage
        stage_prompt = self.generate_prompt_for_stage(stage_name)
        formatted_system_prompt = DEFAULT_SYSTEM_PROMPT.format(
            current_stage=stage_name,
            stage_prompt=stage_prompt
        )
        combined_prompt = f"{formatted_system_prompt}\n\nUser input: {message}"
        # Get the LLM response and keep only the user-facing content
        llm_response = get_llm_response(combined_prompt, model, thinking_budget, self.api_key)
        parsed_response = self.parse_llm_response(llm_response)
        self.chat_history.append(("Human", message))
        self.chat_history.append(("AI", parsed_response))
        form_data = extract_form_data(llm_response)
        if stage_name in form_data:
            setattr(self.idea_form, stage_name.lower().replace(" ", "_"), form_data[stage_name])
        yield self.chat_history, self.idea_form.dict()

    def fill_out_form_stream(self, current_stage: str, model: str, thinking_budget: int):
        form_data = {}
        for stage in IDEA_STAGES:
            stage_name = stage.name
            if stage_name == current_stage:
                form_data[stage_name] = self.generate_form_data(stage_name, model, thinking_budget)
            else:
                form_data[stage_name] = getattr(self.idea_form, stage.field, "")
            # Yield the partially filled form after each stage is processed
            yield form_data
        # Update the idea form
        for stage in IDEA_STAGES:
            setattr(self.idea_form, stage.field, form_data[stage.name])
        # Save to the database; create the session outside the try block so rollback/close
        # always have a valid handle.
        new_session = SessionLocal()
        try:
            if self.idea_id:
                update_idea_in_database(self.idea_id, self.idea_form, new_session)
            else:
                self.idea_id = save_idea_to_database(self.idea_form, new_session)
            new_session.commit()
        except Exception as e:
            logging.error(f"Error saving idea to database: {str(e)}")
            new_session.rollback()
        finally:
            new_session.close()

    def generate_form_data_stream(self, stage: str, model: str, thinking_budget: int):
        # Despite the name, this mirrors generate_form_data and returns the extracted value
        # directly rather than yielding incremental chunks.
        conversation = "\n".join([f"{role}: {message}" for role, message in self.chat_history])
        stage_prompt = self.generate_prompt_for_stage(stage)
        formatted_system_prompt = DEFAULT_SYSTEM_PROMPT.format(
            current_stage=stage,
            stage_prompt=stage_prompt
        )
        prompt = f"""
{formatted_system_prompt}

Based on the following conversation, extract the relevant information for the '{stage}' stage of the innovative idea:

{conversation}

Please provide a concise summary for the '{stage}' stage, focusing only on the information relevant to this stage.
Your response should be structured as follows:
1. A brief analysis of the conversation related to this stage.
2. A concise summary of the key points relevant to this stage.
3. A suggested form entry for this stage, enclosed in <form_data></form_data> tags.
The form entry should be in the format: "{stage}: Content"
Remember to keep the form entry concise and directly related to the '{stage}' stage. Do not include information from other stages in the form entry.
"""
        llm_response = get_llm_response(prompt, model, thinking_budget, self.api_key)
        form_data = extract_form_data(llm_response)
        return form_data.get(stage, "")

    def update_form_field(self, stage_name: str, value: str):
        field_name = stage_name.lower().replace(" ", "_")
        if field_name == 'team_roles':
            value = value.split(',')  # Convert the comma-separated string to a list for team_roles
        setattr(self.idea_form, field_name, value)
        # Persist the change; create the session outside the try block so rollback/close
        # always have a valid handle.
        new_session = SessionLocal()
        try:
            if self.idea_id:
                update_idea_in_database(self.idea_id, self.idea_form, new_session)
            else:
                self.idea_id = save_idea_to_database(self.idea_form, new_session)
            new_session.commit()
        except Exception as e:
            logging.error(f"Error updating form field: {str(e)}")
            new_session.rollback()
        finally:
            new_session.close()
        return value


# Module-level helper used by the UI's "Fill out form" button.
def fill_out_form_button(chatbot: InnovativeIdeaChatbot, current_stage: str, model: str, thinking_budget: int):
    form_data = chatbot.fill_out_form(current_stage, model, thinking_budget)
    return {stage["field"]: form_data[stage["field"]] for stage in STAGES}
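

# A minimal usage sketch, not part of the original application: it assumes that
# utils.get_llm_response accepts the placeholder model name and API key below, which
# are illustrative only. The real app presumably drives this class from a UI (see the
# button helper above), so this block is only a quick way to exercise the chatbot.
if __name__ == "__main__":
    bot = InnovativeIdeaChatbot()
    bot.set_api_key("YOUR_API_KEY")  # placeholder; supply a real key for the configured LLM provider
    first_stage = STAGES[0]["name"]
    history, form = bot.process_stage_input(
        stage_name=first_stage,
        message="My idea is a solar-powered water purifier for remote villages.",
        model="example-model-name",  # placeholder model identifier
        system_prompt=DEFAULT_SYSTEM_PROMPT,
        thinking_budget=1024,  # illustrative value
    )
    print(history[-1][1])  # the assistant's latest reply
    print(form)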