File size: 5,215 Bytes
85033dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d255ef
85033dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# main_menu.py

import argparse
import sys
import os
from train_agent import train_agent
from test_agent import TestAgent, run_test_session
from lightbulb import main as world_model_main
from lightbulb_inf import main as inference_main  # Import the inference main function
from twisted.internet import reactor, task

def parse_main_args(argv=None):
    """Build the main-menu argument parser and parse the command line.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse falls back to ``sys.argv[1:]``, which
            matches the previous behavior.

    Returns:
        argparse.Namespace holding the selected ``task`` plus the training
        and inference options defined below.

    Raises:
        SystemExit: raised by argparse on invalid or missing arguments
            (e.g. no ``--task`` or a task outside the allowed choices).
    """
    parser = argparse.ArgumentParser(description="Main Menu for Selecting Tasks")
    parser.add_argument(
        '--task',
        type=str,
        choices=[
            'train_llm_world',
            'train_agent',
            'test_agent',
            'inference_llm',
            'inference_world_model',
            'advanced_inference',
        ],
        required=True,
        help='Choose task to execute: train_llm_world, train_agent, test_agent, inference_llm, inference_world_model, advanced_inference',
    )
    # Optional arguments for more granular control
    parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name for LLM')
    parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name for training')
    parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
    parser.add_argument('--batch_size', type=int, default=4, help='Batch size for training')
    parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs for training')
    parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length for training')
    parser.add_argument('--mode', type=str, choices=['train', 'inference'], default='train', help='Train or inference mode for LLM')
    parser.add_argument('--query', type=str, default='', help='Query for the test_agent or inference tasks')
    # Inference options forwarded to lightbulb_inf.py. main() previously read
    # these via getattr() fallbacks because the parser never defined them,
    # so they could not be set from the command line. The defaults here are
    # identical to those fallbacks, so omitting the flags changes nothing.
    parser.add_argument('--beam_size', type=int, default=5, help='Beam size (forwarded to lightbulb_inf.py)')
    parser.add_argument('--n_tokens_predict', type=int, default=3, help='Number of tokens to predict (forwarded to lightbulb_inf.py)')
    parser.add_argument('--mcts_iterations', type=int, default=10, help='MCTS iterations (forwarded to lightbulb_inf.py)')
    parser.add_argument('--mcts_exploration_constant', type=float, default=1.414, help='MCTS exploration constant (forwarded to lightbulb_inf.py)')
    # None means "not given": main() only forwards --load_model when truthy.
    parser.add_argument('--load_model', type=str, default=None, help='Value passed to lightbulb_inf.py as --load_model')
    return parser.parse_args(argv)

def _run_with_argv(argv, entry_point):
    """Call ``entry_point()`` with ``sys.argv`` temporarily replaced by ``argv``.

    The wrapped entry points are invoked after rewriting ``sys.argv``,
    presumably because they parse their own command line from it. The
    original value is restored afterwards, even if the entry point raises
    or calls ``sys.exit``.
    """
    saved_argv = sys.argv
    sys.argv = list(argv)
    try:
        return entry_point()
    finally:
        sys.argv = saved_argv


def _run_train_llm_world(args):
    """Forward the training options to lightbulb.py's ``main``."""
    print("Starting LLM and World Model Training...")
    lightbulb_args = [
        'lightbulb.py',
        '--mode', args.mode,
        '--model_name', args.model_name,
        '--dataset_name', args.dataset_name,
        '--dataset_config', args.dataset_config,
        '--batch_size', str(args.batch_size),
        '--num_epochs', str(args.num_epochs),
        '--max_length', str(args.max_length),
    ]
    _run_with_argv(lightbulb_args, world_model_main)


def _report_failure(failure):
    """Errback: print a Twisted ``Failure`` and consume it.

    The previous errback called ``print(..., exc_info=True)``. ``print``
    has no ``exc_info`` parameter, so the errback itself raised
    ``TypeError`` and the real error was never shown.
    """
    print(f"An error occurred: {failure.getErrorMessage()}")
    print(failure.getTraceback())
    # Returning None marks the failure as handled, so the following addBoth
    # callback receives None instead of the Failure.


def _run_train_agent(args):
    """Run ``train_agent`` on the Twisted reactor, stopping it when done.

    ``args`` is unused; it is accepted for a uniform dispatch signature.
    """
    print("Starting Agent Training...")
    d = task.deferLater(reactor, 0, train_agent)
    d.addErrback(_report_failure)
    # Stop the reactor whether training succeeded or failed.
    d.addBoth(lambda _: reactor.stop())
    reactor.run()


def _run_test_agent(args):
    """Answer ``args.query`` once, or start an interactive session if it is empty."""
    print("Starting Test Agent...")
    agent = TestAgent()
    if args.query:
        # Directly process a single query
        result = agent.process_query(args.query)
        print("\nAgent's response:")
        print(result)
    else:
        # Run the interactive session
        reactor.callWhenRunning(run_test_session)
        reactor.run()


# Maps main_menu task names to lightbulb_inf.py's --inference_mode values.
_INFERENCE_MODE_MAP = {
    'inference_llm': 'without_world_model',
    'inference_world_model': 'world_model',
    'advanced_inference': 'world_model_tree_of_thought',
}


def _run_inference(args):
    """Build the argument list for lightbulb_inf.py and call its ``main``."""
    print("Starting Inference Task...")
    # Only tasks present in _INFERENCE_MODE_MAP are dispatched here.
    selected_inference_mode = _INFERENCE_MODE_MAP[args.task]

    # getattr() fallbacks keep this working even if the parser does not
    # define these inference options.
    lightbulb_inf_args = [
        'lightbulb_inf.py',
        '--mode', 'inference',
        '--model_name', args.model_name,
        '--query', args.query,
        '--max_length', str(args.max_length),
        '--inference_mode', selected_inference_mode,
        '--beam_size', str(getattr(args, 'beam_size', 5)),
        '--n_tokens_predict', str(getattr(args, 'n_tokens_predict', 3)),
        '--mcts_iterations', str(getattr(args, 'mcts_iterations', 10)),
        '--mcts_exploration_constant', str(getattr(args, 'mcts_exploration_constant', 1.414)),
    ]

    # Forward --load_model only when it was given a non-empty value.
    load_model = getattr(args, 'load_model', None)
    if load_model:
        lightbulb_inf_args += ['--load_model', load_model]

    _run_with_argv(lightbulb_inf_args, inference_main)


# Task name -> handler taking the parsed argparse Namespace.
_TASK_HANDLERS = {
    'train_llm_world': _run_train_llm_world,
    'train_agent': _run_train_agent,
    'test_agent': _run_test_agent,
    'inference_llm': _run_inference,
    'inference_world_model': _run_inference,
    'advanced_inference': _run_inference,
}


def main():
    """Parse the menu arguments and run the selected task.

    Exits with status 1 if the task has no handler. argparse's ``choices``
    normally rejects such values before this point, so this is a defensive
    fallback.
    """
    args = parse_main_args()
    handler = _TASK_HANDLERS.get(args.task)
    if handler is None:
        print(f"Unknown task: {args.task}")
        sys.exit(1)
    handler(args)


# Run the menu only when executed as a script, not when imported.
if __name__ == "__main__":
    main()