allow the model to train on the search agent's experience (RAG cache, knowledge base, search data)
3d255ef
# main_menu.py

import argparse
import sys

from train_agent import train_agent
from test_agent import TestAgent, run_test_session
from lightbulb import main as world_model_main
from lightbulb_inf import main as inference_main  # Import the inference main function
from twisted.internet import reactor, task

def parse_main_args():
    parser = argparse.ArgumentParser(description="Main Menu for Selecting Tasks")
    parser.add_argument(
        '--task', type=str, required=True,
        choices=[
            'train_llm_world',
            'train_agent',
            'test_agent',
            'inference_llm',
            'inference_world_model',
            'advanced_inference'
        ],
        help='Choose the task to execute: train_llm_world, train_agent, test_agent, '
             'inference_llm, inference_world_model, advanced_inference'
    )
    # Optional arguments for more granular control
    parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name for LLM')
    parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name for training')
    parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
    parser.add_argument('--batch_size', type=int, default=4, help='Batch size for training')
    parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs for training')
    parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length for training')
    parser.add_argument('--mode', type=str, choices=['train', 'inference'], default='train', help='Train or inference mode for LLM')
    parser.add_argument('--query', type=str, default='', help='Query for the test_agent or inference tasks')
    # Additional arguments specific to inference can be added here if needed
    return parser.parse_args()

def main():
    # Parse arguments for the main function
    args = parse_main_args()

    # Execute tasks based on user input
    if args.task == 'train_llm_world':
        print("Starting LLM and World Model Training...")
        # Directly call the world model main function with appropriate arguments
        sys.argv = [
            'lightbulb.py',
            '--mode', args.mode,
            '--model_name', args.model_name,
            '--dataset_name', args.dataset_name,
            '--dataset_config', args.dataset_config,
            '--batch_size', str(args.batch_size),
            '--num_epochs', str(args.num_epochs),
            '--max_length', str(args.max_length)
        ]
        world_model_main()
    elif args.task == 'train_agent':
        print("Starting Agent Training...")
        # Call the train_agent function from train_agent.py using the Twisted reactor
        d = task.deferLater(reactor, 0, train_agent)
        # print() does not accept logging's exc_info keyword; report the full
        # traceback via the Failure object instead
        d.addErrback(lambda failure: print(f"An error occurred:\n{failure.getTraceback()}"))
        d.addBoth(lambda _: reactor.stop())
        reactor.run()
    elif args.task == 'test_agent':
        print("Starting Test Agent...")
        test_agent = TestAgent()
        if args.query:
            # Directly process a single query
            result = test_agent.process_query(args.query)
            print("\nAgent's response:")
            print(result)
        else:
            # Run the interactive session
            reactor.callWhenRunning(run_test_session)
            reactor.run()
    elif args.task in ['inference_llm', 'inference_world_model', 'advanced_inference']:
        print("Starting Inference Task...")
        # Map the main_menu task to lightbulb_inf.py's inference_mode
        inference_mode_map = {
            'inference_llm': 'without_world_model',
            'inference_world_model': 'world_model',
            'advanced_inference': 'world_model_tree_of_thought'
        }
        selected_inference_mode = inference_mode_map.get(args.task, 'world_model_tree_of_thought')

        # Construct sys.argv for lightbulb_inf.py
        lightbulb_inf_args = [
            'lightbulb_inf.py',
            '--mode', 'inference',
            '--model_name', args.model_name,
            '--query', args.query,
            '--max_length', str(args.max_length),
            '--inference_mode', selected_inference_mode,
            '--beam_size', str(getattr(args, 'beam_size', 5)),
            '--n_tokens_predict', str(getattr(args, 'n_tokens_predict', 3)),
            '--mcts_iterations', str(getattr(args, 'mcts_iterations', 10)),
            '--mcts_exploration_constant', str(getattr(args, 'mcts_exploration_constant', 1.414))
        ]
        # Include additional arguments if they exist
        if getattr(args, 'load_model', None):
            lightbulb_inf_args += ['--load_model', args.load_model]

        # Update sys.argv and call the inference main function
        sys.argv = lightbulb_inf_args
        inference_main()
    else:
        print(f"Unknown task: {args.task}")
        sys.exit(1)


if __name__ == "__main__":
    main()
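
Note that the inference branch reads beam_size, n_tokens_predict, mcts_iterations, and mcts_exploration_constant through getattr() with fallback defaults, because parse_main_args() never registers those flags; as written, they can never be changed from the command line. A minimal sketch of registering them on the parser, with defaults copied from the getattr() fallbacks (the help strings are assumptions, not from the original):

    # Sketch: additional arguments for parse_main_args(), so the inference
    # settings can be overridden from the command line instead of always
    # taking the getattr() fallback values. Help strings are assumed.
    parser.add_argument('--beam_size', type=int, default=5,
                        help='Beam width for beam-search decoding')
    parser.add_argument('--n_tokens_predict', type=int, default=3,
                        help='Tokens to predict per expansion step')
    parser.add_argument('--mcts_iterations', type=int, default=10,
                        help='Number of MCTS iterations')
    parser.add_argument('--mcts_exploration_constant', type=float, default=1.414,
                        help='MCTS exploration constant c (approx. sqrt(2))')
    parser.add_argument('--load_model', type=str, default=None,
                        help='Optional path to a saved model checkpoint')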
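
Example invocations, assuming the file is saved as main_menu.py and the sibling modules (train_agent.py, test_agent.py, lightbulb.py, lightbulb_inf.py) are importable from the same directory:

    python main_menu.py --task train_llm_world --model_name gpt2 --num_epochs 3
    python main_menu.py --task test_agent --query "What is a world model?"
    python main_menu.py --task advanced_inference --model_name gpt2 --query "Summarize the RAG cache."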