Spaces:
Running
Running
import logging
import re
from typing import Any, Callable, Dict, List, Optional, Tuple

from data_models import IdeaForm, IDEA_STAGES
from config import DEFAULT_SYSTEM_PROMPT, STAGES
from utils import (
    get_llm_response, extract_form_data, save_idea_to_database,
    load_idea_from_database, update_idea_in_database,
    get_db, clear_database, init_db, create_tables,
    perform_web_search, optimize_search_query,
    SessionLocal, InnovativeIdea,
)
class InnovativeIdeaChatbot:
    """Chat assistant that guides a user through the idea-development stages.

    Instance state:
      * ``idea_form``     -- the ``IdeaForm`` filled in as the conversation progresses.
      * ``chat_history``  -- list of ``(role, message)`` tuples; the roles written
        here are "System", "Human" and "AI".
      * ``idea_id``       -- id of the persisted idea, or ``None`` before the first save.
      * ``current_stage`` -- name of the stage currently being discussed, or ``None``.
      * ``api_key``       -- forwarded to ``get_llm_response``; ``None`` until
        ``set_api_key`` is called.
    """

    # Sections of an LLM reply that are internal scaffolding; they are removed,
    # together with their content, before the reply is shown to the user.
    _HIDDEN_TAGS = ("form_data", "reflection", "analysis", "summary", "step")

    def __init__(self):
        """Ensure the database exists, reset all state and post the greeting."""
        create_tables()
        init_db()
        self.idea_form = IdeaForm()
        self.chat_history = []
        self.idea_id = None
        self.current_stage = None
        self.api_key = None
        self.add_system_message(self.get_initial_greeting())

    def get_initial_greeting(self) -> str:
        """Return the welcome text and record that the user has been greeted."""
        greeting = """
Welcome to the Innovative Idea Generator! I'm Myamoto, your AI assistant designed to help you refine and develop your innovative ideas.
Here's how we'll work together:
1. We'll go through 10 stages to explore different aspects of your idea.
2. At each stage, I'll ask you questions and provide feedback to help you think deeper about your concept.
3. You can ask me questions at any time or request more information on a topic.
4. If you want to perform a web search for additional information, just start your message with '@' followed by your search query.
5. When you're ready to move to the next stage, simply type 'next'.
Let's start by exploring your innovative idea! What's the name of your idea, or would you like help coming up with one?
"""
        self.greeted = True
        return greeting

    def add_system_message(self, message: str):
        """Append ``message`` to the chat history under the "System" role."""
        self.chat_history.append(("System", message))

    def set_api_key(self, api_key: str):
        """Store the API key passed to every subsequent ``get_llm_response`` call."""
        self.api_key = api_key

    def activate_stage(self, stage_name: str) -> Optional[str]:
        """Make ``stage_name`` the current stage and return its opening question.

        Returns ``None`` (and leaves ``current_stage`` untouched) when no entry
        in ``STAGES`` has that name, so an invalid stage is never treated as
        active by later calls.
        """
        for stage in STAGES:
            if stage["name"] == stage_name:
                self.current_stage = stage_name
                return f"Let's work on the '{stage_name}' stage. {stage['question']}"
        return None

    def process_stage_input(self, stage_name: str, message: str, model: str, system_prompt: str, thinking_budget: int) -> Tuple[List[Tuple[str, str]], Dict[str, Any]]:
        """Handle one user message for ``stage_name``.

        Switches to the stage first if needed. A message starting with '@' is
        treated as a web-search request; anything else is sent to the LLM, the
        user-facing part of the reply is added to the history, and any
        ``<form_data>`` for this stage is copied into ``idea_form``.

        Args:
            stage_name: stage the message belongs to.
            message: raw user input.
            model: model identifier forwarded to the LLM / search helpers.
            system_prompt: accepted for interface compatibility but currently
                ignored -- the prompt is always built from ``DEFAULT_SYSTEM_PROMPT``.
            thinking_budget: forwarded to ``get_llm_response``.

        Returns:
            ``(chat_history, idea_form.dict())``.
        """
        if self.current_stage != stage_name:
            activation_message = self.activate_stage(stage_name)
            if activation_message is None:
                error_message = f"Error: Unable to activate stage '{stage_name}'. Please check if the stage name is correct."
                self.chat_history.append(("System", error_message))
                return self.chat_history, self.idea_form.dict()
            self.chat_history.append(("System", activation_message))

        if message.startswith('@'):
            self._handle_web_search(message, model)
            return self.chat_history, self.idea_form.dict()

        combined_prompt = f"{self._format_system_prompt(stage_name)}\n\nUser input: {message}"
        llm_response = get_llm_response(combined_prompt, model, thinking_budget, self.api_key)
        parsed_response = self.parse_llm_response(llm_response)
        self.chat_history.append(("Human", message))
        self.chat_history.append(("AI", parsed_response))
        # Form data is extracted from the *raw* reply: the parsed one has the
        # <form_data> section stripped out.
        self._apply_stage_form_data(stage_name, llm_response)
        return self.chat_history, self.idea_form.dict()

    def parse_llm_response(self, response: str) -> str:
        """Return only the user-facing text of an LLM reply.

        Removes every ``_HIDDEN_TAGS`` section with its content, then any other
        ``<...>`` tag (keeping its inner text), then collapses all whitespace,
        newlines included, to single spaces.
        """
        for tag in self._HIDDEN_TAGS:
            response = re.sub(rf'<{tag}>.*?</{tag}>', '', response, flags=re.DOTALL)
        response = re.sub(r'<[^>]+>', '', response)
        return re.sub(r'\s+', ' ', response).strip()

    def fill_out_form(self, current_stage: str, model: str, thinking_budget: int) -> Dict[str, str]:
        """Regenerate the entry for ``current_stage`` and persist the whole form.

        Other stages keep their existing ``idea_form`` values.

        Returns:
            Mapping of each ``STAGES`` field name to its (possibly new) value.
        """
        form_data = {}
        for stage in STAGES:
            field_name = stage["field"]
            if stage["name"] == current_stage:
                form_data[field_name] = self.generate_form_data(stage["name"], model, thinking_budget)
            else:
                form_data[field_name] = getattr(self.idea_form, field_name, "")
        for field_name, value in form_data.items():
            setattr(self.idea_form, field_name, value)
        self._persist_idea("Error saving idea to database")
        return form_data

    def generate_prompt_for_stage(self, stage: str) -> str:
        """Describe ``stage`` for the system prompt, using its question from ``IDEA_STAGES`` if known."""
        for s in IDEA_STAGES:
            if s.name == stage:
                return f"We are currently working on the '{stage}' stage. {s.question}"
        return f"We are currently working on the '{stage}' stage. Please provide relevant information."

    def reset(self):
        """Clear all in-memory state and the database.

        Returns:
            ``(chat_history, idea_form.dict())`` for the fresh state.
        """
        self.chat_history = []
        self.idea_form = IdeaForm()
        self.idea_id = None
        self.current_stage = None
        self.add_system_message(self.get_initial_greeting())
        self._run_in_session(clear_database, "Error clearing database")
        return self.chat_history, self.idea_form.dict()

    def start_over(self):
        """Reset the conversation, clear the database and create one empty idea.

        ``idea_id`` is updated only if the new idea was committed successfully.

        Returns:
            ``(chat_history, idea_form.dict(), name_of_first_stage)``.
        """
        self.chat_history = []
        self.idea_form = IdeaForm()
        self.current_stage = None
        self.add_system_message(self.get_initial_greeting())

        def recreate(session):
            clear_database(session)
            new_idea = InnovativeIdea()
            session.add(new_idea)
            session.commit()
            session.refresh(new_idea)  # load the database-assigned id
            return new_idea.id

        new_id = self._run_in_session(recreate, "Error in start_over")
        if new_id is not None:
            self.idea_id = new_id
        return self.chat_history, self.idea_form.dict(), STAGES[0]["name"]

    def update_idea_form(self, stage_name: str, form_data: str):
        """Set the field for ``stage_name`` to ``form_data`` and persist the form."""
        setattr(self.idea_form, self._field_name(stage_name), form_data)
        self._persist_idea("Error updating idea form")

    def generate_form_data(self, stage: str, model: str, thinking_budget: int) -> str:
        """Ask the LLM to summarise the conversation into a form entry for ``stage``.

        Returns:
            The extracted entry for ``stage``, or "" if the reply had none.
        """
        prompt = self._build_form_data_prompt(stage)
        llm_response = get_llm_response(prompt, model, thinking_budget, self.api_key)
        return extract_form_data(llm_response).get(stage, "")

    def process_stage_input_stream(self, stage_name: str, message: str, model: str, system_prompt: str, thinking_budget: int):
        """Generator form of ``process_stage_input``.

        Yields exactly one ``(chat_history, idea_form.dict())`` pair.
        """
        yield self.process_stage_input(stage_name, message, model, system_prompt, thinking_budget)

    def fill_out_form_stream(self, current_stage: str, model: str, thinking_budget: int):
        """Generator form of ``fill_out_form``, keyed by stage *name*.

        After each stage is resolved it yields a snapshot dict (stage name to
        value) of the stages processed so far. Once the consumer exhausts the
        generator, the values are written to ``idea_form`` and persisted.
        """
        form_data = {}
        for stage in IDEA_STAGES:
            if stage.name == current_stage:
                form_data[stage.name] = self.generate_form_data(stage.name, model, thinking_budget)
            else:
                form_data[stage.name] = getattr(self.idea_form, stage.field, "")
            # Yield a copy so earlier snapshots are not mutated by later stages.
            yield dict(form_data)
        for stage in IDEA_STAGES:
            setattr(self.idea_form, stage.field, form_data[stage.name])
        self._persist_idea("Error saving idea to database")

    def generate_form_data_stream(self, stage: str, model: str, thinking_budget: int):
        """Same as ``generate_form_data``; kept for callers that use this name."""
        return self.generate_form_data(stage, model, thinking_budget)

    def update_form_field(self, stage_name: str, value: str):
        """Set one form field from UI input and persist the form.

        For the ``team_roles`` field a string is split on commas into a list of
        stripped, non-empty roles ("CEO, CTO" -> ["CEO", "CTO"]).

        Returns:
            The value actually stored.
        """
        field_name = self._field_name(stage_name)
        if field_name == 'team_roles' and isinstance(value, str):
            value = [role.strip() for role in value.split(',') if role.strip()]
        setattr(self.idea_form, field_name, value)
        self._persist_idea("Error updating form field")
        return value

    # ----------------------------------------------------------------- helpers

    @staticmethod
    def _field_name(stage_name: str) -> str:
        """Map a display stage name to its attribute name ("Idea Name" -> "idea_name")."""
        return stage_name.lower().replace(" ", "_")

    def _format_system_prompt(self, stage: str) -> str:
        """Fill ``DEFAULT_SYSTEM_PROMPT`` with the stage name and its description."""
        return DEFAULT_SYSTEM_PROMPT.format(
            current_stage=stage,
            stage_prompt=self.generate_prompt_for_stage(stage),
        )

    def _build_form_data_prompt(self, stage: str) -> str:
        """Build the prompt asking the LLM to turn the chat so far into a form entry."""
        conversation = "\n".join(f"{role}: {message}" for role, message in self.chat_history)
        formatted_system_prompt = self._format_system_prompt(stage)
        return f"""
{formatted_system_prompt}
Based on the following conversation, extract the relevant information for the '{stage}' stage of the innovative idea:
{conversation}
Please provide a concise summary for the '{stage}' stage, focusing only on the information relevant to this stage.
Your response should be structured as follows:
1. A brief analysis of the conversation related to this stage.
2. A concise summary of the key points relevant to this stage.
3. A suggested form entry for this stage, enclosed in <form_data></form_data> tags.
The form entry should be in the format: "{stage}: Content"
Remember to keep the form entry concise and directly related to the '{stage}' stage. Do not include information from other stages in the form entry.
"""

    def _handle_web_search(self, message: str, model: str) -> None:
        """Run the '@'-prefixed search in ``message`` and append the exchange to the history."""
        search_query = message[1:].strip()
        optimized_query = optimize_search_query(search_query, model)
        search_results = perform_web_search(optimized_query)
        self.chat_history.append(("Human", message))
        self.chat_history.append(("AI", f"Here are the search results for '{optimized_query}':\n\n{search_results}"))

    def _apply_stage_form_data(self, stage_name: str, llm_response: str) -> None:
        """Copy this stage's entry from an LLM reply's form data into ``idea_form``, if present."""
        form_data = extract_form_data(llm_response)
        if stage_name in form_data:
            setattr(self.idea_form, self._field_name(stage_name), form_data[stage_name])

    def _run_in_session(self, operation: Callable[[Any], Any], error_message: str) -> Optional[Any]:
        """Run ``operation(session)`` in a fresh DB session and commit it.

        Best-effort by design: any failure, including failing to open the
        session, is logged as ``"<error_message>: <exception>"`` and the work is
        rolled back. The session is always closed.

        Returns:
            ``operation``'s result on success, ``None`` on failure.
        """
        session = None
        try:
            session = SessionLocal()
            result = operation(session)
            session.commit()
            return result
        except Exception as e:
            logging.error(f"{error_message}: {str(e)}")
            # Guard: if SessionLocal() itself failed there is nothing to roll back.
            if session is not None:
                session.rollback()
            return None
        finally:
            if session is not None:
                session.close()

    def _persist_idea(self, error_message: str) -> None:
        """Insert or update ``idea_form`` in the database.

        ``idea_id`` is set only after a successful commit, so a failed save
        never leaves it pointing at a rolled-back row.
        """
        def save(session):
            if self.idea_id:
                update_idea_in_database(self.idea_id, self.idea_form, session)
                return self.idea_id
            return save_idea_to_database(self.idea_form, session)

        saved_id = self._run_in_session(save, error_message)
        if saved_id is not None:
            self.idea_id = saved_id
# Module-level handler for the "Fill out form" button click (a plain function, not a method).
def fill_out_form_button(chatbot: InnovativeIdeaChatbot, current_stage: str, model: str, thinking_budget: int):
    """Regenerate the current stage's form entry and return every form field.

    Delegates to ``chatbot.fill_out_form`` and returns a dict mapping each
    ``STAGES`` field name, in ``STAGES`` order, to its value.
    """
    filled = chatbot.fill_out_form(current_stage, model, thinking_budget)
    result = {}
    for stage in STAGES:
        field_name = stage["field"]
        result[field_name] = filled[field_name]
    return result