RobbiePasquale commited on
Commit
a1e5744
·
verified ·
1 Parent(s): 3d255ef

Update train_agent.py

Browse files
Files changed (1) hide show
  1. train_agent.py +125 -116
train_agent.py CHANGED
@@ -1,116 +1,125 @@
1
- # train_agent.py
2
-
3
- from twisted.internet import reactor, defer, task
4
- from agent import AutonomousWebAgent
5
- import random
6
- import logging
7
- import sys
8
- import time
9
- import codecs
10
-
11
- # Configure logging
12
- logging.basicConfig(level=logging.INFO,
13
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
14
- handlers=[
15
- logging.FileHandler("agent_training.log", encoding='utf-8'),
16
- logging.StreamHandler(codecs.getwriter('utf-8')(sys.stdout.buffer))
17
- ])
18
-
19
- logger = logging.getLogger(__name__)
20
-
21
- # List of diverse queries
22
- QUERIES = [
23
- "machine learning", "climate change", "renewable energy", "artificial intelligence",
24
- "quantum computing", "blockchain technology", "gene editing", "virtual reality",
25
- "space exploration", "cybersecurity", "autonomous vehicles", "Internet of Things",
26
- "3D printing", "nanotechnology", "bioinformatics", "augmented reality", "robotics",
27
- "data science", "neural networks", "cloud computing", "edge computing", "5G technology",
28
- "cryptocurrency", "natural language processing", "computer vision"
29
- ]
30
-
31
- @defer.inlineCallbacks
32
- def train_agent():
33
- # Updated state_size to 7 to match the feature extraction in AutonomousWebAgent
34
- state_size = 7 # word_count, link_count, header_count, semantic_similarity, image_count, script_count, css_count
35
- action_size = 3 # 0: Click Link, 1: Summarize, 2: RAG Generate
36
- num_options = 3 # 0: Search, 1: Summarize, 2: RAG Generate
37
-
38
- # Initialize the AutonomousWebAgent with the required arguments
39
- agent = AutonomousWebAgent(
40
- state_size=state_size,
41
- action_size=action_size,
42
- num_options=num_options, # Added parameter for HRL
43
- hidden_size=64,
44
- learning_rate=0.001,
45
- gamma=0.99,
46
- epsilon=1.0,
47
- epsilon_decay=0.995,
48
- epsilon_min=0.01,
49
- knowledge_base_path='knowledge_base.json'
50
- )
51
- logger.info(f"Initialized AutonomousWebAgent with state_size={state_size}, action_size={action_size}, num_options={num_options}")
52
-
53
- num_episodes = 10 # Adjust as needed
54
- total_training_reward = 0
55
- start_time = time.time()
56
-
57
- for episode in range(num_episodes):
58
- query = random.choice(QUERIES)
59
- logger.info(f"Starting episode {episode + 1}/{num_episodes} with query: {query}")
60
- episode_start_time = time.time()
61
-
62
- try:
63
- # Initiate the search process
64
- search_deferred = agent.search(query)
65
- search_deferred.addTimeout(300, reactor) # 5-minute timeout
66
- total_reward = yield search_deferred
67
- total_training_reward += total_reward
68
- episode_duration = time.time() - episode_start_time
69
- logger.info(f"Episode {episode + 1}/{num_episodes}, Query: {query}, Total Reward: {total_reward}, Duration: {episode_duration:.2f} seconds")
70
- except defer.TimeoutError:
71
- logger.error(f"Episode {episode + 1} timed out")
72
- total_reward = -1 # Assign a negative reward for timeout
73
- total_training_reward += total_reward
74
- except Exception as e:
75
- logger.error(f"Error in episode {episode + 1}: {str(e)}", exc_info=True)
76
- total_reward = -1 # Assign a negative reward for errors
77
- total_training_reward += total_reward
78
-
79
- # Update target models periodically
80
- if (episode + 1) % 10 == 0:
81
- logger.info(f"Updating target models at episode {episode + 1}")
82
- agent.update_worker_target_model()
83
- agent.update_manager_target_model()
84
- agent.manager.update_target_model()
85
-
86
- # Log overall progress
87
- progress = (episode + 1) / num_episodes
88
- elapsed_time = time.time() - start_time
89
- estimated_total_time = elapsed_time / progress if progress > 0 else 0
90
- remaining_time = estimated_total_time - elapsed_time
91
- logger.info(f"Overall progress: {progress:.2%}, Elapsed time: {elapsed_time:.2f}s, Estimated remaining time: {remaining_time:.2f}s")
92
-
93
- total_training_time = time.time() - start_time
94
- average_reward = total_training_reward / num_episodes
95
- logger.info(f"Training completed. Total reward: {total_training_reward}, Average reward per episode: {average_reward:.2f}")
96
- logger.info(f"Total training time: {total_training_time:.2f} seconds")
97
- logger.info("Saving models.")
98
-
99
- # Save both Worker and Manager models
100
- agent.save_worker_model("worker_model.pth")
101
- agent.save_manager_model("manager_model.pth")
102
- agent.save("web_agent_model.pth") # Assuming this saves additional components if needed
103
-
104
- if reactor.running:
105
- logger.info("Stopping reactor")
106
- reactor.stop()
107
-
108
- def main():
109
- logger.info("Starting agent training")
110
- d = task.deferLater(reactor, 0, train_agent)
111
- d.addErrback(lambda failure: logger.error(f"An error occurred: {failure}", exc_info=True))
112
- d.addBoth(lambda _: reactor.stop())
113
- reactor.run()
114
-
115
- if __name__ == "__main__":
116
- main()
 
 
 
 
 
 
 
 
 
 
1
+ # train_agent.py
2
+
3
+ from twisted.internet import reactor, defer, task
4
+ from agent import AutonomousWebAgent
5
+ import random
6
+ import logging
7
+ import sys
8
+ import time
9
+ import codecs
10
+
11
+ IS_COLAB = 'google.colab' in sys.modules
12
+
13
+
14
+ # Configure logging
15
+ if IS_COLAB:
16
+ logging.basicConfig(level=logging.INFO,
17
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
18
+ else:
19
+ logging.basicConfig(level=logging.INFO,
20
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
21
+ handlers=[
22
+ logging.FileHandler("agent_training.log", encoding='utf-8'),
23
+ logging.StreamHandler(codecs.getwriter('utf-8')(sys.stdout.buffer))
24
+ ])
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+ # List of diverse queries
29
+ QUERIES = [
30
+ "machine learning", "climate change", "renewable energy", "artificial intelligence",
31
+ "quantum computing", "blockchain technology", "gene editing", "virtual reality",
32
+ "space exploration", "cybersecurity", "autonomous vehicles", "Internet of Things",
33
+ "3D printing", "nanotechnology", "bioinformatics", "augmented reality", "robotics",
34
+ "data science", "neural networks", "cloud computing", "edge computing", "5G technology",
35
+ "cryptocurrency", "natural language processing", "computer vision"
36
+ ]
37
+
38
+ @defer.inlineCallbacks
39
+ def train_agent():
40
+ # Updated state_size to 7 to match the feature extraction in AutonomousWebAgent
41
+ state_size = 7 # word_count, link_count, header_count, semantic_similarity, image_count, script_count, css_count
42
+ action_size = 3 # 0: Click Link, 1: Summarize, 2: RAG Generate
43
+ num_options = 3 # 0: Search, 1: Summarize, 2: RAG Generate
44
+
45
+ # Initialize the AutonomousWebAgent with the required arguments
46
+ agent = AutonomousWebAgent(
47
+ state_size=state_size,
48
+ action_size=action_size,
49
+ num_options=num_options, # Added parameter for HRL
50
+ hidden_size=64,
51
+ learning_rate=0.001,
52
+ gamma=0.99,
53
+ epsilon=1.0,
54
+ epsilon_decay=0.995,
55
+ epsilon_min=0.01,
56
+ knowledge_base_path='knowledge_base.json'
57
+ )
58
+ logger.info(f"Initialized AutonomousWebAgent with state_size={state_size}, action_size={action_size}, num_options={num_options}")
59
+
60
+ num_episodes = 10 # Adjust as needed
61
+ total_training_reward = 0
62
+ start_time = time.time()
63
+
64
+ for episode in range(num_episodes):
65
+ query = random.choice(QUERIES)
66
+ logger.info(f"Starting episode {episode + 1}/{num_episodes} with query: {query}")
67
+ episode_start_time = time.time()
68
+
69
+ try:
70
+ # Initiate the search process
71
+ search_deferred = agent.search(query)
72
+ search_deferred.addTimeout(300, reactor) # 5-minute timeout
73
+ total_reward = yield search_deferred
74
+ total_training_reward += total_reward
75
+ episode_duration = time.time() - episode_start_time
76
+ logger.info(f"Episode {episode + 1}/{num_episodes}, Query: {query}, Total Reward: {total_reward}, Duration: {episode_duration:.2f} seconds")
77
+ except defer.TimeoutError:
78
+ logger.error(f"Episode {episode + 1} timed out")
79
+ total_reward = -1 # Assign a negative reward for timeout
80
+ total_training_reward += total_reward
81
+ except Exception as e:
82
+ logger.error(f"Error in episode {episode + 1}: {str(e)}", exc_info=True)
83
+ total_reward = -1 # Assign a negative reward for errors
84
+ total_training_reward += total_reward
85
+
86
+ # Update target models periodically
87
+ if (episode + 1) % 10 == 0:
88
+ logger.info(f"Updating target models at episode {episode + 1}")
89
+ agent.update_worker_target_model()
90
+ agent.update_manager_target_model()
91
+ agent.manager.update_target_model()
92
+
93
+ # Log overall progress
94
+ progress = (episode + 1) / num_episodes
95
+ elapsed_time = time.time() - start_time
96
+ estimated_total_time = elapsed_time / progress if progress > 0 else 0
97
+ remaining_time = estimated_total_time - elapsed_time
98
+ logger.info(f"Overall progress: {progress:.2%}, Elapsed time: {elapsed_time:.2f}s, Estimated remaining time: {remaining_time:.2f}s")
99
+
100
+ total_training_time = time.time() - start_time
101
+ average_reward = total_training_reward / num_episodes
102
+ logger.info(f"Training completed. Total reward: {total_training_reward}, Average reward per episode: {average_reward:.2f}")
103
+ logger.info(f"Total training time: {total_training_time:.2f} seconds")
104
+ logger.info("Saving models.")
105
+
106
+ # Save both Worker and Manager models
107
+ agent.save_worker_model("worker_model.pth")
108
+ agent.save_manager_model("manager_model.pth")
109
+ agent.save("web_agent_model.pth") # Assuming this saves additional components if needed
110
+
111
+ if reactor.running:
112
+ logger.info("Stopping reactor")
113
+ reactor.stop()
114
+
115
+ def main(is_colab=False):
116
+ global IS_COLAB
117
+ IS_COLAB = is_colab
118
+ logger.info("Starting agent training")
119
+ d = task.deferLater(reactor, 0, train_agent)
120
+ d.addErrback(lambda failure: logger.error(f"An error occurred: {failure}", exc_info=True))
121
+ d.addBoth(lambda _: reactor.stop())
122
+ reactor.run()
123
+
124
+ if __name__ == "__main__":
125
+ main(IS_COLAB)