lightbulb / main_menu.py
RobbiePasquale's picture
Allow the model to train on the search agent's experience (RAG cache, knowledge base, search data)
3d255ef verified
raw
history blame
5.22 kB
# main_menu.py
import argparse
import sys
import os
from train_agent import train_agent
from test_agent import TestAgent, run_test_session
from lightbulb import main as world_model_main
from lightbulb_inf import main as inference_main # Import the inference main function
from twisted.internet import reactor, task
def parse_main_args(argv=None):
    """Parse the command-line options for the main menu.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]`` as before; passing
            a list allows programmatic use and testing.

    Returns:
        argparse.Namespace holding the selected ``task`` plus the training
        and inference options declared below.

    Raises:
        SystemExit: raised by argparse when ``--task`` is missing or invalid,
            or any option fails to parse.
    """
    parser = argparse.ArgumentParser(description="Main Menu for Selecting Tasks")
    parser.add_argument('--task', type=str, choices=[
        'train_llm_world',
        'train_agent',
        'test_agent',
        'inference_llm',
        'inference_world_model',
        'advanced_inference'
    ],
        required=True,
        help='Choose task to execute: train_llm_world, train_agent, test_agent, inference_llm, inference_world_model, advanced_inference')

    # Optional arguments for more granular control
    parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name for LLM')
    parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name for training')
    parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
    parser.add_argument('--batch_size', type=int, default=4, help='Batch size for training')
    parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs for training')
    parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length for training')
    parser.add_argument('--mode', type=str, choices=['train', 'inference'], default='train', help='Train or inference mode for LLM')
    parser.add_argument('--query', type=str, default='', help='Query for the test_agent or inference tasks')

    # Options forwarded to lightbulb_inf.py by the inference tasks. main()
    # reads them with getattr() fallbacks; the defaults here equal those
    # fallbacks so omitting a flag keeps the previous behaviour, while
    # declaring them lets users actually override the values.
    parser.add_argument('--beam_size', type=int, default=5,
                        help='Beam size passed to the inference script')
    parser.add_argument('--n_tokens_predict', type=int, default=3,
                        help='Value passed to the inference script as --n_tokens_predict')
    parser.add_argument('--mcts_iterations', type=int, default=10,
                        help='MCTS iteration count passed to the inference script')
    parser.add_argument('--mcts_exploration_constant', type=float, default=1.414,
                        help='MCTS exploration constant passed to the inference script')
    parser.add_argument('--load_model', type=str, default=None,
                        help='Model to load, passed to the inference script as --load_model')
    return parser.parse_args(argv)
def _call_with_argv(argv, entry_point):
    """Call ``entry_point()`` with ``sys.argv`` temporarily replaced by ``argv``.

    The sub-scripts' ``main`` functions are called without parameters, so
    their options are handed over by swapping ``sys.argv`` (presumably they
    parse it with argparse -- confirm in lightbulb.py / lightbulb_inf.py).
    The previous value is restored afterwards so the global is not left
    clobbered for whoever called :func:`main`.

    Returns:
        Whatever ``entry_point()`` returns.
    """
    saved_argv = sys.argv
    sys.argv = list(argv)
    try:
        return entry_point()
    finally:
        sys.argv = saved_argv


def _run_train_llm_world(args):
    """Run lightbulb.py's main with the training options from ``args``."""
    print("Starting LLM and World Model Training...")
    lightbulb_args = [
        'lightbulb.py',
        '--mode', args.mode,
        '--model_name', args.model_name,
        '--dataset_name', args.dataset_name,
        '--dataset_config', args.dataset_config,
        '--batch_size', str(args.batch_size),
        '--num_epochs', str(args.num_epochs),
        '--max_length', str(args.max_length)
    ]
    _call_with_argv(lightbulb_args, world_model_main)


def _run_train_agent():
    """Run ``train_agent`` under the Twisted reactor, stopping it afterwards."""
    print("Starting Agent Training...")

    def _report_failure(failure):
        # print() has no exc_info keyword (passing it raised TypeError and
        # hid the real error); a Twisted Failure prints its own traceback.
        print(f"An error occurred: {failure.getErrorMessage()}")
        failure.printTraceback()

    # Schedule train_agent on the first reactor iteration; whether it
    # succeeds or fails, stop the reactor so reactor.run() returns.
    d = task.deferLater(reactor, 0, train_agent)
    d.addErrback(_report_failure)
    d.addBoth(lambda _: reactor.stop())
    reactor.run()


def _run_test_agent(args):
    """Answer ``args.query`` once, or start the interactive session if empty."""
    print("Starting Test Agent...")
    test_agent = TestAgent()
    if args.query:
        # Directly process a single query
        result = test_agent.process_query(args.query)
        print("\nAgent's response:")
        print(result)
    else:
        # Run the interactive session inside the reactor
        reactor.callWhenRunning(run_test_session)
        reactor.run()


def _run_inference(args):
    """Run lightbulb_inf.py's main in the inference mode selected by ``args.task``."""
    print("Starting Inference Task...")
    # Map the main_menu task to lightbulb_inf.py's --inference_mode value
    inference_mode_map = {
        'inference_llm': 'without_world_model',
        'inference_world_model': 'world_model',
        'advanced_inference': 'world_model_tree_of_thought'
    }
    selected_inference_mode = inference_mode_map.get(args.task, 'world_model_tree_of_thought')

    # getattr fallbacks keep this working even if the parser does not
    # declare the search options.
    lightbulb_inf_args = [
        'lightbulb_inf.py',
        '--mode', 'inference',
        '--model_name', args.model_name,
        '--query', args.query,
        '--max_length', str(args.max_length),
        '--inference_mode', selected_inference_mode,
        '--beam_size', str(getattr(args, 'beam_size', 5)),
        '--n_tokens_predict', str(getattr(args, 'n_tokens_predict', 3)),
        '--mcts_iterations', str(getattr(args, 'mcts_iterations', 10)),
        '--mcts_exploration_constant', str(getattr(args, 'mcts_exploration_constant', 1.414))
    ]
    # Only forward --load_model when a non-empty value was supplied
    load_model = getattr(args, 'load_model', None)
    if load_model:
        lightbulb_inf_args += ['--load_model', load_model]
    _call_with_argv(lightbulb_inf_args, inference_main)


def main():
    """Parse the command line and dispatch to the selected task.

    Exits with status 1 if the task is unrecognised (argparse's ``choices``
    normally rejects such values before this point).
    """
    args = parse_main_args()
    if args.task == 'train_llm_world':
        _run_train_llm_world(args)
    elif args.task == 'train_agent':
        _run_train_agent()
    elif args.task == 'test_agent':
        _run_test_agent(args)
    elif args.task in ('inference_llm', 'inference_world_model', 'advanced_inference'):
        _run_inference(args)
    else:
        print(f"Unknown task: {args.task}")
        sys.exit(1)
# Script entry point: dispatch only when run directly, not when imported.
if __name__ == "__main__":
    main()