from typing import List, Tuple, Dict, Any, Optional, Iterator
import logging
import re

from data_models import IdeaForm, IDEA_STAGES
from config import DEFAULT_SYSTEM_PROMPT, STAGES
from utils import (
    get_llm_response,
    extract_form_data,
    save_idea_to_database,
    load_idea_from_database,
    update_idea_in_database,
    get_db,
    clear_database,
    init_db,
    create_tables,
    perform_web_search,
    optimize_search_query,
    SessionLocal,
    InnovativeIdea,
)

logger = logging.getLogger(__name__)

# Matches any markup tag such as ``<b>`` or ``</section>``; parse_llm_response
# removes the tag itself but keeps the text between tags.
_MARKUP_TAG_RE = re.compile(r"<[^>]+>")
# Runs of whitespace (including newlines) that parse_llm_response collapses
# into a single space.
_WHITESPACE_RE = re.compile(r"\s+")

# Chat history as shown in the UI: a list of (role, message) pairs where role
# is one of "System", "Human" or "AI" in this module.
ChatHistory = List[Tuple[str, str]]


class InnovativeIdeaChatbot:
    """Guide a user through the idea-development stages listed in ``config.STAGES``.

    Per-instance state:
      * ``idea_form``     -- the ``IdeaForm`` being filled in.
      * ``chat_history``  -- list of ``(role, message)`` tuples.
      * ``idea_id``       -- id of the persisted idea, or ``None`` if not saved yet.
      * ``current_stage`` -- name of the currently active stage, if any.
      * ``api_key``       -- forwarded to ``get_llm_response``.
    """

    # Tags whose *entire* content (not just the tag markup) parse_llm_response
    # hides from the user, e.g. internal reasoning or form-data sections.
    # NOTE(review): the tag names this method originally stripped appear to
    # have been lost (the old patterns had degenerated into the no-op ``.*?``).
    # Populate this to match the tags DEFAULT_SYSTEM_PROMPT asks the model to
    # emit -- TODO confirm against config.py.
    HIDDEN_RESPONSE_TAGS: Tuple[str, ...] = ()

    def __init__(self):
        # Make sure the schema exists before any session is opened.
        create_tables()
        init_db()
        self.idea_form = IdeaForm()
        self.chat_history: ChatHistory = []
        self.idea_id = None
        self.current_stage: Optional[str] = None
        self.api_key: Optional[str] = None
        self.add_system_message(self.get_initial_greeting())

    # ------------------------------------------------------------------ #
    # Greeting / chat-history helpers
    # ------------------------------------------------------------------ #

    def get_initial_greeting(self) -> str:
        """Return the welcome message and record that the user was greeted."""
        greeting = f"""
Welcome to the Innovative Idea Generator! I'm Myamoto, your AI assistant designed to help you refine and develop your innovative ideas.

Here's how we'll work together:

1. We'll go through {len(STAGES)} stages to explore different aspects of your idea.
2. At each stage, I'll ask you questions and provide feedback to help you think deeper about your concept.
3. You can ask me questions at any time or request more information on a topic.
4. If you want to perform a web search for additional information, just start your message with '@' followed by your search query.
5. When you're ready to move to the next stage, simply type 'next'.

Let's start by exploring your innovative idea! What's the name of your idea, or would you like help coming up with one?
"""
        self.greeted = True
        return greeting.strip()

    def add_system_message(self, message: str):
        """Append a ``("System", message)`` entry to the chat history."""
        self.chat_history.append(("System", message))

    def set_api_key(self, api_key: str):
        """Store the API key passed to every ``get_llm_response`` call."""
        self.api_key = api_key

    # ------------------------------------------------------------------ #
    # Stage lookup
    # ------------------------------------------------------------------ #

    @staticmethod
    def _find_stage(stage_name: str) -> Optional[Dict[str, Any]]:
        """Return the ``STAGES`` entry whose ``"name"`` equals *stage_name*, or None."""
        return next((stage for stage in STAGES if stage["name"] == stage_name), None)

    @classmethod
    def _field_for_stage(cls, stage_name: str) -> str:
        """Return the ``IdeaForm`` attribute that stores *stage_name*'s data.

        Prefers the ``"field"`` declared in ``STAGES`` (the same source
        ``fill_out_form`` uses) and falls back to the snake_cased stage name
        for stages not listed there.
        """
        stage = cls._find_stage(stage_name)
        if stage is not None and stage.get("field"):
            return stage["field"]
        return stage_name.lower().replace(" ", "_")

    def activate_stage(self, stage_name: str) -> Optional[str]:
        """Make *stage_name* the current stage and return its intro message.

        Returns ``None`` (and leaves ``current_stage`` unchanged) when the
        stage is not defined in ``STAGES``. Only updating on success ensures a
        retry with the same bad name is rejected again instead of silently
        proceeding.
        """
        stage = self._find_stage(stage_name)
        if stage is None:
            return None
        self.current_stage = stage_name
        return f"Let's work on the '{stage_name}' stage.\n{stage['question']}"

    def generate_prompt_for_stage(self, stage: str) -> str:
        """Return the stage-specific prompt fragment built from ``IDEA_STAGES``."""
        for s in IDEA_STAGES:
            if s.name == stage:
                return f"We are currently working on the '{stage}' stage. {s.question}"
        return f"We are currently working on the '{stage}' stage. Please provide relevant information."

    def _format_system_prompt(self, stage_name: str) -> str:
        """Fill ``DEFAULT_SYSTEM_PROMPT`` with the stage name and its prompt."""
        return DEFAULT_SYSTEM_PROMPT.format(
            current_stage=stage_name,
            stage_prompt=self.generate_prompt_for_stage(stage_name),
        )

    # ------------------------------------------------------------------ #
    # Conversation handling
    # ------------------------------------------------------------------ #

    def _handle_web_search(self, message: str, model: str) -> None:
        """Run a web search for an ``@query`` message and record it in the history."""
        search_query = message[1:].strip()
        optimized_query = optimize_search_query(search_query, model)
        search_results = perform_web_search(optimized_query)
        self.chat_history.append(("Human", message))
        self.chat_history.append(("AI", f"Here are the search results for '{optimized_query}':\n\n{search_results}"))

    def process_stage_input(self, stage_name: str, message: str, model: str, system_prompt: str, thinking_budget: int) -> Tuple[List[Tuple[str, str]], Dict[str, Any]]:
        """Handle one user message for *stage_name*.

        Activates the stage if needed, routes ``@``-prefixed messages to a web
        search, and otherwise queries the LLM and stores any extracted form
        data for this stage on ``idea_form``. Nothing is persisted here.

        Args:
            stage_name: Stage the message belongs to.
            message: Raw user input.
            model: Model identifier forwarded to the utils helpers.
            system_prompt: Currently unused; ``DEFAULT_SYSTEM_PROMPT`` is
                always used instead. Kept for interface compatibility.
            thinking_budget: Forwarded to ``get_llm_response``.

        Returns:
            ``(chat_history, idea_form.dict())``.
        """
        if self.current_stage != stage_name:
            activation_message = self.activate_stage(stage_name)
            if activation_message is None:
                error_message = f"Error: Unable to activate stage '{stage_name}'. Please check if the stage name is correct."
                self.chat_history.append(("System", error_message))
                return self.chat_history, self.idea_form.dict()
            self.chat_history.append(("System", activation_message))

        # Web search request
        if message.startswith('@'):
            self._handle_web_search(message, model)
            return self.chat_history, self.idea_form.dict()

        combined_prompt = f"{self._format_system_prompt(stage_name)}\n\nUser input: {message}"
        llm_response = get_llm_response(combined_prompt, model, thinking_budget, self.api_key)

        self.chat_history.append(("Human", message))
        self.chat_history.append(("AI", self.parse_llm_response(llm_response)))

        # The raw (unparsed) response is used here so form sections survive.
        form_data = extract_form_data(llm_response)
        if stage_name in form_data:
            setattr(self.idea_form, self._field_for_stage(stage_name), form_data[stage_name])

        return self.chat_history, self.idea_form.dict()

    def process_stage_input_stream(self, stage_name: str, message: str, model: str, system_prompt: str, thinking_budget: int) -> Iterator[Tuple[List[Tuple[str, str]], Dict[str, Any]]]:
        """Generator wrapper around :meth:`process_stage_input`.

        Yields exactly one ``(chat_history, idea_form.dict())`` tuple.
        """
        yield self.process_stage_input(stage_name, message, model, system_prompt, thinking_budget)

    def parse_llm_response(self, response: str) -> str:
        """Return the user-facing text of an LLM response.

        Removes whole ``<tag>...</tag>`` sections for each tag in
        ``HIDDEN_RESPONSE_TAGS``, then strips any remaining markup tags
        (keeping their inner text), then collapses all whitespace, including
        newlines, into single spaces.
        """
        if not response:
            return ""
        for tag in self.HIDDEN_RESPONSE_TAGS:
            name = re.escape(tag)
            response = re.sub(rf"<{name}(?:\s[^>]*)?>.*?</{name}\s*>", "", response, flags=re.DOTALL)
        response = _MARKUP_TAG_RE.sub("", response)
        return _WHITESPACE_RE.sub(" ", response).strip()

    # ------------------------------------------------------------------ #
    # Form generation
    # ------------------------------------------------------------------ #

    def _build_form_extraction_prompt(self, stage: str) -> str:
        """Build the prompt asking the LLM to summarise the chat for *stage*."""
        conversation = "\n".join(f"{role}: {message}" for role, message in self.chat_history)
        # NOTE(review): "enclosed in tags" looks like the tag name was lost
        # from this prompt; it must match whatever extract_form_data parses --
        # TODO confirm against utils.extract_form_data.
        return f"""
{self._format_system_prompt(stage)}

Based on the following conversation, extract the relevant information for the '{stage}' stage of the innovative idea:

{conversation}

Please provide a concise summary for the '{stage}' stage, focusing only on the information relevant to this stage. Your response should be structured as follows:

1. A brief analysis of the conversation related to this stage.
2. A concise summary of the key points relevant to this stage.
3. A suggested form entry for this stage, enclosed in tags.

The form entry should be in the format: "{stage}: Content"

Remember to keep the form entry concise and directly related to the '{stage}' stage. Do not include information from other stages in the form entry.
"""

    def generate_form_data(self, stage: str, model: str, thinking_budget: int) -> str:
        """Ask the LLM to derive the form entry for *stage* from the chat history.

        Returns the extracted entry, or ``""`` if none was found.
        """
        prompt = self._build_form_extraction_prompt(stage)
        llm_response = get_llm_response(prompt, model, thinking_budget, self.api_key)
        return extract_form_data(llm_response).get(stage, "")

    def generate_form_data_stream(self, stage: str, model: str, thinking_budget: int) -> str:
        """Same as :meth:`generate_form_data` (despite the name, not a generator)."""
        return self.generate_form_data(stage, model, thinking_budget)

    def fill_out_form(self, current_stage: str, model: str, thinking_budget: int) -> Dict[str, str]:
        """Regenerate the entry for *current_stage*, keep the rest, and persist.

        Returns a dict keyed by each ``STAGES`` entry's ``"field"``.
        """
        form_data: Dict[str, Any] = {}
        for stage in STAGES:
            field = stage["field"]
            if stage["name"] == current_stage:
                form_data[field] = self.generate_form_data(stage["name"], model, thinking_budget)
            else:
                form_data[field] = getattr(self.idea_form, field, "")

        for field, value in form_data.items():
            setattr(self.idea_form, field, value)

        self._persist_idea("Error saving idea to database")
        return form_data

    def fill_out_form_stream(self, current_stage: str, model: str, thinking_budget: int) -> Iterator[Dict[str, Any]]:
        """Streaming variant of :meth:`fill_out_form`, driven by ``IDEA_STAGES``.

        After each stage, yields a snapshot of the partial form keyed by
        stage *name*. A fresh dict is yielded each time, so earlier snapshots
        are not mutated later. The form is applied and persisted only after
        the generator is fully consumed.
        """
        form_data: Dict[str, Any] = {}
        for stage in IDEA_STAGES:
            if stage.name == current_stage:
                form_data[stage.name] = self.generate_form_data(stage.name, model, thinking_budget)
            else:
                form_data[stage.name] = getattr(self.idea_form, stage.field, "")
            yield dict(form_data)

        for stage in IDEA_STAGES:
            setattr(self.idea_form, stage.field, form_data[stage.name])

        self._persist_idea("Error saving idea to database")

    # ------------------------------------------------------------------ #
    # Form updates
    # ------------------------------------------------------------------ #

    def update_idea_form(self, stage_name: str, form_data: str):
        """Set *stage_name*'s field to *form_data* and persist the idea."""
        setattr(self.idea_form, self._field_for_stage(stage_name), form_data)
        self._persist_idea("Error updating idea form")

    def update_form_field(self, stage_name: str, value: str):
        """Set *stage_name*'s field to *value*, persist, and return the stored value.

        For the ``team_roles`` field a comma-separated string is turned into a
        list of trimmed, non-empty role names. A value that is already a
        non-string (e.g. a list) is stored as-is.
        """
        field_name = self._field_for_stage(stage_name)
        if field_name == 'team_roles' and isinstance(value, str):
            value = [role.strip() for role in value.split(',') if role.strip()]
        setattr(self.idea_form, field_name, value)
        self._persist_idea("Error updating form field")
        return value

    # ------------------------------------------------------------------ #
    # Reset
    # ------------------------------------------------------------------ #

    def reset(self):
        """Clear chat, form and id, then wipe the database (best effort).

        Returns ``(chat_history, idea_form.dict())``.
        """
        self.chat_history = []
        self.idea_form = IdeaForm()
        self.idea_id = None
        self.current_stage = None
        self.add_system_message(self.get_initial_greeting())

        session = self._open_session("Error clearing database")
        if session is not None:
            try:
                clear_database(session)
                session.commit()
            except Exception as e:
                logger.exception("Error clearing database: %s", e)
                session.rollback()
            finally:
                session.close()

        return self.chat_history, self.idea_form.dict()

    def start_over(self):
        """Reset state, wipe the database and create a fresh empty idea row.

        Returns ``(chat_history, idea_form.dict(), first_stage_name)``. If the
        database step fails, ``idea_id`` stays ``None`` rather than pointing at
        a row that may have been cleared.
        """
        self.chat_history = []
        self.idea_form = IdeaForm()
        self.current_stage = None
        self.idea_id = None
        self.add_system_message(self.get_initial_greeting())

        session = self._open_session("Error in start_over")
        if session is not None:
            try:
                clear_database(session)
                new_idea = InnovativeIdea()
                session.add(new_idea)
                session.commit()
                session.refresh(new_idea)
                self.idea_id = new_idea.id
            except Exception as e:
                logger.exception("Error in start_over: %s", e)
                session.rollback()
            finally:
                session.close()

        return self.chat_history, self.idea_form.dict(), STAGES[0]["name"]

    # ------------------------------------------------------------------ #
    # Persistence helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _open_session(error_context: str):
        """Return a new ``SessionLocal()`` session, or ``None`` if opening fails.

        The failure is logged with *error_context* as the message prefix.
        Opening the session outside the caller's ``try`` avoids the
        ``UnboundLocalError`` the old ``rollback()``/``close()`` calls
        raised when ``SessionLocal()`` itself failed.
        """
        try:
            return SessionLocal()
        except Exception as e:
            logger.exception("%s: %s", error_context, e)
            return None

    def _persist_idea(self, error_context: str) -> None:
        """Insert or update ``idea_form`` in the database (best effort).

        Errors are logged with *error_context* as the prefix and rolled back.
        A newly created id is stored only after a successful commit.
        """
        session = self._open_session(error_context)
        if session is None:
            return
        try:
            if self.idea_id:
                update_idea_in_database(self.idea_id, self.idea_form, session)
                session.commit()
            else:
                new_id = save_idea_to_database(self.idea_form, session)
                session.commit()
                self.idea_id = new_id
        except Exception as e:
            logger.exception("%s: %s", error_context, e)
            session.rollback()
        finally:
            session.close()


def fill_out_form_button(chatbot: InnovativeIdeaChatbot, current_stage: str, model: str, thinking_budget: int):
    """Handler for the "Fill out form" button.

    Returns a dict mapping every ``STAGES`` field to its (possibly
    regenerated) value.
    """
    form_data = chatbot.fill_out_form(current_stage, model, thinking_budget)
    return {stage["field"]: form_data[stage["field"]] for stage in STAGES}