RobbiePasquale committed on
Commit 904b97c · verified · 1 Parent(s): 7ba0a0a

Upload lightbulb_wm.py

Files changed (1)
  1. lightbulb_wm.py +1292 -0
lightbulb_wm.py ADDED
@@ -0,0 +1,1292 @@
import argparse
import math
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import copy
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.amp import autocast, GradScaler
from datasets import load_dataset
from transformers import AutoTokenizer


# Set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def parse_args():
    parser = argparse.ArgumentParser(description='Train World Model with Transformer outputs.')
    parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name or path')
    parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name from HuggingFace Datasets')
    parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
    parser.add_argument('--batch_size', type=int, default=2, help='Batch size')
    parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs')
    parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length')
    parser.add_argument('--mcts_iterations', type=int, default=5, help='Number of MCTS iterations')
    parser.add_argument('--mcts_exploration_constant', type=float, default=1.414, help='Exploration constant for MCTS (UCB)')
    parser.add_argument('--accumulation_steps', type=int, default=4, help='Gradient accumulation steps')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay')
    parser.add_argument('--alpha', type=float, default=0.1, help='Entropy regularization weight')
    parser.add_argument('--beta', type=float, default=0.1, help='Variance regularization weight')
    parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
    parser.add_argument('--save_dir', type=str, default='./models', help='Directory to save the models')
    parser.add_argument('--temperature', type=float, default=1.0, help='Temperature parameter for entropy and variance')
    parser.add_argument('--transformer_model_path', type=str, required=True, help='Path to the saved Transformer model')
    args = parser.parse_args()
    return args


def load_data(args, tokenizer):
    # Load the dataset
    dataset = load_dataset(args.dataset_name, args.dataset_config)

    # Ensure the tokenizer has a padding token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(examples):
        return tokenizer(examples['text'], truncation=True, max_length=args.max_length)

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=4,
        remove_columns=dataset['train'].column_names,
    )

    # Build inputs and labels for language modeling
    block_size = args.max_length

    def group_texts(examples):
        # Concatenate all texts
        concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
        total_length = len(concatenated_examples['input_ids'])
        # We drop the small remainder
        total_length = (total_length // block_size) * block_size
        # Split by chunks of block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result['labels'] = result['input_ids'].copy()
        return result

    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=4,
    )

    # Create DataLoader
    train_dataset = lm_datasets['train']
    eval_dataset = lm_datasets['validation'] if 'validation' in lm_datasets else lm_datasets['test']

    data_collator = lambda data: {
        'input_ids': torch.tensor([f['input_ids'] for f in data], dtype=torch.long),
        'labels': torch.tensor([f['labels'] for f in data], dtype=torch.long)
    }

    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, collate_fn=data_collator)
    eval_loader = DataLoader(eval_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=data_collator)

    return train_loader, eval_loader


def save_all_models(transformer_model, representation_network, dynamics_network, prediction_network, action_encoder, save_dir, epoch):
    """
    Save all models to the specified directory.

    Args:
        transformer_model (nn.Module): Transformer model.
        representation_network (nn.Module): Representation network.
        dynamics_network (nn.Module): Dynamics network.
        prediction_network (nn.Module): Prediction network.
        action_encoder (nn.Module): Action encoder.
        save_dir (str): Directory to save the models.
        epoch (int): Current epoch number.
    """
    os.makedirs(save_dir, exist_ok=True)

    torch.save(transformer_model.state_dict(), os.path.join(save_dir, f'transformer_model_epoch_{epoch}.pt'))
    torch.save(representation_network.state_dict(), os.path.join(save_dir, f'representation_network_epoch_{epoch}.pt'))
    torch.save(dynamics_network.state_dict(), os.path.join(save_dir, f'dynamics_network_epoch_{epoch}.pt'))
    torch.save(prediction_network.state_dict(), os.path.join(save_dir, f'prediction_network_epoch_{epoch}.pt'))
    torch.save(action_encoder.state_dict(), os.path.join(save_dir, f'action_encoder_epoch_{epoch}.pt'))

    print(f"All models saved for epoch {epoch}.")


class RotaryPositionalEncoding(nn.Module):
    def __init__(self, d_model):
        super(RotaryPositionalEncoding, self).__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2).float() / d_model))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x):
        seq_len, batch_size, _ = x.size()
        t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        sinusoid_inp = torch.einsum("i,j->ij", t, self.inv_freq)
        sin = sinusoid_inp.sin().unsqueeze(1)  # (seq_len, 1, d_model/2)
        cos = sinusoid_inp.cos().unsqueeze(1)  # (seq_len, 1, d_model/2)

        x1 = x[..., 0::2]
        x2 = x[..., 1::2]

        # Apply rotation
        x_rotated = torch.zeros_like(x)
        x_rotated[..., 0::2] = x1 * cos - x2 * sin
        x_rotated[..., 1::2] = x1 * sin + x2 * cos

        return x_rotated

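# Illustrative shape check for RotaryPositionalEncoding (a sketch, not part of the original training
# flow): the module expects input of shape (seq_len, batch, d_model), which is why the Transformer
# below transposes before and after applying it.
#   rope = RotaryPositionalEncoding(d_model=512)
#   out = rope(torch.randn(16, 2, 512))   # -> torch.Size([16, 2, 512])
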

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.linear_out = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        query = self.linear_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        key = self.linear_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        value = self.linear_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
        attn = F.softmax(scores, dim=-1)
        output = torch.matmul(attn, value)

        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        return self.linear_out(output)


class MoE(nn.Module):
    def __init__(self, d_model, num_experts, d_ff, top_k=2, dropout=0.1):
        super(MoE, self).__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.GELU() if i % 2 == 0 else nn.SiLU(),
                nn.Linear(d_ff, d_model)
            )
            for i in range(num_experts)
        ])
        self.gate = nn.Linear(d_model, num_experts)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        batch_size, seq_len, d_model = x.size()
        # Compute gating scores
        gate_scores = self.gate(x)  # (batch_size, seq_len, num_experts)
        top_k_scores, top_k_indices = torch.topk(gate_scores, self.top_k, dim=-1)  # (batch_size, seq_len, top_k)
        top_k_scores = F.softmax(top_k_scores, dim=-1)  # (batch_size, seq_len, top_k)

        # Initialize output
        output = torch.zeros_like(x)

        # Flatten batch and sequence dimensions
        x_flat = x.view(-1, d_model)  # (batch_size * seq_len, d_model)
        output_flat = output.view(-1, d_model)
        top_k_indices_flat = top_k_indices.view(-1, self.top_k)  # (batch_size * seq_len, top_k)
        top_k_scores_flat = top_k_scores.view(-1, self.top_k)  # (batch_size * seq_len, top_k)

        for k in range(self.top_k):
            expert_idx_flat = top_k_indices_flat[:, k]  # (batch_size * seq_len)
            expert_scores_flat = top_k_scores_flat[:, k]  # (batch_size * seq_len)
            for e in range(self.num_experts):
                mask = (expert_idx_flat == e)  # Boolean mask
                if mask.any():
                    x_masked = x_flat[mask]  # Select tokens for expert e
                    expert_output = self.experts[e](x_masked)  # Apply expert e
                    output_flat[mask] += expert_scores_flat[mask].unsqueeze(-1) * expert_output

        output = output_flat.view(batch_size, seq_len, d_model)
        return self.dropout(output)

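# Minimal MoE usage sketch (dimensions are arbitrary, chosen only for illustration): each token's
# gate picks its top_k experts, and the expert outputs are blended with the softmaxed gate scores.
#   moe = MoE(d_model=512, num_experts=4, d_ff=2048, top_k=2)
#   y = moe(torch.randn(2, 128, 512))     # -> torch.Size([2, 128, 512])
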

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, num_experts, dropout=0.1, top_k=2):
        super(TransformerBlock, self).__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.cross_attention = MultiHeadAttention(d_model, num_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.moe = MoE(d_model, num_experts, d_ff, top_k, dropout)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None, enc_output=None, enc_mask=None):
        # Self-attention
        attn_output = self.self_attention(x, x, x, mask)
        x = self.norm1(x + attn_output)
        # Cross-attention (only in decoder)
        if enc_output is not None:
            cross_attn_output = self.cross_attention(x, enc_output, enc_output, enc_mask)
            x = self.norm2(x + cross_attn_output)
        # Feedforward/MoE
        moe_output = self.moe(x)
        return self.norm3(x + moe_output)


class Transformer(nn.Module):
    def __init__(self, input_dim, d_model, num_heads, num_layers, d_ff, num_experts, output_dim, dropout=0.1, top_k=2):
        super(Transformer, self).__init__()
        self.embedding = nn.Embedding(input_dim, d_model, padding_idx=input_dim - 1)
        self.rotary_positional_encoding = RotaryPositionalEncoding(d_model)
        self.encoder_layers = nn.ModuleList(
            [TransformerBlock(d_model, num_heads, d_ff, num_experts, dropout, top_k) for _ in range(num_layers)]
        )
        self.decoder_layers = nn.ModuleList(
            [TransformerBlock(d_model, num_heads, d_ff, num_experts, dropout, top_k) for _ in range(num_layers)]
        )
        self.output_layer = nn.Linear(d_model, output_dim)
        self.d_model = d_model

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # Encoder
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = src.transpose(0, 1)  # (batch_size, seq_len, d_model) -> (seq_len, batch_size, d_model)
        src = self.rotary_positional_encoding(src)
        src = src.transpose(0, 1)  # (seq_len, batch_size, d_model) -> (batch_size, seq_len, d_model)
        for layer in self.encoder_layers:
            src = layer(src, src_mask)

        # Decoder
        tgt = self.embedding(tgt) * math.sqrt(self.d_model)
        tgt = tgt.transpose(0, 1)
        tgt = self.rotary_positional_encoding(tgt)
        tgt = tgt.transpose(0, 1)
        for layer in self.decoder_layers:
            tgt = layer(tgt, tgt_mask, src, src_mask)
        output = self.output_layer(tgt)
        return output

    def generate(self, src, tokenizer, max_length=20, temperature=1.0):
        """
        Generate sequences using differentiable sampling (Gumbel-Softmax).

        Args:
            src (torch.Tensor): Source input tensor of shape (batch_size, seq_len)
            tokenizer (transformers.PreTrainedTokenizer): Tokenizer to access special tokens
            max_length (int): Maximum length of the generated sequence
            temperature (float): Temperature parameter for Gumbel-Softmax

        Returns:
            torch.Tensor: Generated sequences of shape (batch_size, max_length)
            torch.Tensor: Entropy values for each time step
            torch.Tensor: Variance values for each time step
        """
        batch_size = src.size(0)

        # Encode the source
        src_enc = self.embedding(src) * math.sqrt(self.d_model)
        src_enc = src_enc.transpose(0, 1)
        src_enc = self.rotary_positional_encoding(src_enc)
        src_enc = src_enc.transpose(0, 1)
        for layer in self.encoder_layers:
            src_enc = layer(src_enc)

        # Initialize decoder input with <sos> tokens
        tgt_seq = torch.full((batch_size, 1), tokenizer.bos_token_id, dtype=torch.long, device=src.device)
        entropies = []
        variances = []

        for _ in range(max_length):
            tgt_emb = self.embedding(tgt_seq) * math.sqrt(self.d_model)
            tgt_emb = tgt_emb.transpose(0, 1)
            tgt_emb = self.rotary_positional_encoding(tgt_emb)
            tgt_emb = tgt_emb.transpose(0, 1)
            tgt_dec = tgt_emb
            for layer in self.decoder_layers:
                tgt_dec = layer(tgt_dec, None, src_enc, None)
            output = self.output_layer(tgt_dec)  # (batch_size, seq_len, vocab_size)
            logits = output[:, -1, :]  # Get logits for the last time step

            # Compute token probabilities
            probs = F.softmax(logits / temperature, dim=-1)  # (batch_size, vocab_size)

            # Compute entropy
            entropy = -torch.sum(probs * torch.log(probs + 1e-9), dim=-1)  # (batch_size)
            entropies.append(entropy)

            # Sample token using Gumbel-Softmax
            gumbel_noise = -torch.log(-torch.log(torch.rand_like(probs) + 1e-9) + 1e-9)
            y = (logits + gumbel_noise) / temperature
            y = F.softmax(y, dim=-1)  # (batch_size, vocab_size)

            # Compute variance
            variance = torch.var(y, dim=-1)  # (batch_size)
            variances.append(variance)

            # Get token indices (argmax for hard selection)
            next_tokens = torch.argmax(y, dim=-1, keepdim=True)  # (batch_size, 1)
            tgt_seq = torch.cat([tgt_seq, next_tokens], dim=1)

        # Stack entropies and variances
        entropies = torch.stack(entropies, dim=1)  # (batch_size, max_length)
        variances = torch.stack(variances, dim=1)  # (batch_size, max_length)

        return tgt_seq[:, 1:], entropies, variances  # Exclude the initial <sos> token

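# Illustrative call to Transformer.generate (a sketch; it assumes the tokenizer exposes a
# bos_token_id, and the batch/length values are placeholders rather than values from this script):
#   tokens, entropies, variances = model.generate(src_batch, tokenizer, max_length=20, temperature=1.0)
#   # tokens: (batch, 20); entropies and variances: (batch, 20), one value per generated position
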

# Objective Functions

class InfoNCE_Loss(nn.Module):
    def __init__(self, temperature=0.07):
        super(InfoNCE_Loss, self).__init__()
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, z_i, z_j):
        """
        Args:
            z_i (torch.Tensor): Flattened representations from view i, shape (n, embed_dim)
            z_j (torch.Tensor): Flattened representations from view j, shape (n, embed_dim)

        Returns:
            torch.Tensor: InfoNCE loss
        """
        n = z_i.size(0)
        z = torch.cat([z_i, z_j], dim=0)  # Shape: (2n, embed_dim)

        z = F.normalize(z, dim=1)
        similarity_matrix = torch.matmul(z, z.T)  # Shape: (2n, 2n)

        # Create a mask to exclude self-similarity
        mask = torch.eye(2 * n, device=z.device, dtype=torch.bool)
        similarity_matrix = similarity_matrix.masked_fill(mask, -1e4)  # Use a manageable negative value

        # Create labels for contrastive learning
        labels = torch.arange(n, device=z.device)
        labels = torch.cat([labels + n, labels], dim=0)  # Shape: (2n,)

        # Apply temperature scaling
        similarity_matrix /= self.temperature

        # Compute cross-entropy loss
        loss = self.cross_entropy(similarity_matrix, labels)
        return loss


class CovarianceRegularization(nn.Module):
    def __init__(self, lambda_reg=1e-3):
        super(CovarianceRegularization, self).__init__()
        self.lambda_reg = lambda_reg

    def forward(self, embeddings):
        """
        Args:
            embeddings (torch.Tensor): Embedding tensor, shape (batch_size, embed_dim)

        Returns:
            torch.Tensor: Covariance regularization loss
        """
        batch_size, embed_dim = embeddings.size()
        mean = embeddings.mean(dim=0)
        embeddings_centered = embeddings - mean
        cov = (embeddings_centered.T @ embeddings_centered) / (batch_size - 1)
        cov_loss = torch.sum(cov ** 2) - torch.sum(torch.diag(cov) ** 2)
        return self.lambda_reg * cov_loss


class DynamicsPerformanceLoss(nn.Module):
    def __init__(self, lambda_var=1e-3):
        super(DynamicsPerformanceLoss, self).__init__()
        self.lambda_var = lambda_var

    def forward(self, true_next_state, predicted_next_state):
        """
        Args:
            true_next_state (torch.Tensor): Ground truth next state, shape (batch_size, state_dim)
            predicted_next_state (torch.Tensor): Predicted next state, shape (batch_size, state_dim)

        Returns:
            torch.Tensor: Dynamics performance loss
        """
        mse_loss = F.mse_loss(predicted_next_state, true_next_state)
        variance_loss = torch.var(predicted_next_state, dim=0).mean()
        return mse_loss + self.lambda_var * variance_loss


class ThoughtConsistencyLoss(nn.Module):
    def __init__(self):
        super(ThoughtConsistencyLoss, self).__init__()

    def forward(self, true_next_state, perturbed_next_state):
        """
        Args:
            true_next_state (torch.Tensor): Ground truth next state, shape (batch_size, state_dim)
            perturbed_next_state (torch.Tensor): Perturbed next state, shape (batch_size, state_dim)

        Returns:
            torch.Tensor: Thought-consistency loss
        """
        return F.mse_loss(true_next_state, perturbed_next_state)


class PolicyValueJointLoss(nn.Module):
    def __init__(self, lambda_value=0.5):
        super(PolicyValueJointLoss, self).__init__()
        self.lambda_value = lambda_value
        self.cross_entropy = nn.CrossEntropyLoss()
        self.mse_loss = nn.MSELoss()

    def forward(self, policy_logits, true_policy, value_pred, true_value):
        """
        Args:
            policy_logits (torch.Tensor): Logits from the policy network, shape (batch_size * seq_len, num_actions)
            true_policy (torch.Tensor): Ground truth policy, shape (batch_size * seq_len, num_actions)
            value_pred (torch.Tensor): Predicted values, shape (batch_size * seq_len)
            true_value (torch.Tensor): Ground truth values, shape (batch_size * seq_len)

        Returns:
            torch.Tensor: Combined policy and value loss
        """
        policy_logits = policy_logits.view(-1, policy_logits.size(-1))
        true_policy = true_policy.view(-1, true_policy.size(-1))
        value_pred = value_pred.view(-1)
        true_value = true_value.view(-1)

        policy_loss = self.cross_entropy(policy_logits, true_policy.argmax(dim=1))
        value_loss = self.mse_loss(value_pred, true_value)
        return policy_loss + self.lambda_value * value_loss


class ActionDiversityReward(nn.Module):
    def __init__(self, lambda_div=1e-3):
        super(ActionDiversityReward, self).__init__()
        self.lambda_div = lambda_div

    def forward(self, action_embeddings):
        """
        Args:
            action_embeddings (torch.Tensor): Embeddings of actions, shape (batch_size, embed_dim)

        Returns:
            torch.Tensor: Action diversity loss
        """
        similarity_matrix = F.cosine_similarity(action_embeddings.unsqueeze(1), action_embeddings.unsqueeze(0), dim=2)
        # Zero out self-similarity
        similarity_matrix = similarity_matrix - torch.eye(similarity_matrix.size(0)).to(action_embeddings.device)
        diversity_loss = torch.sum(similarity_matrix ** 2)
        return self.lambda_div * diversity_loss


class ExpectedThoughtValueLoss(nn.Module):
    def __init__(self):
        super(ExpectedThoughtValueLoss, self).__init__()

    def forward(self, mcts_best_values):
        """
        Args:
            mcts_best_values (torch.Tensor): Best values from MCTS, shape (batch_size)

        Returns:
            torch.Tensor: ETV loss
        """
        return -mcts_best_values.mean()


class ExplorationRegularization(nn.Module):
    def __init__(self, lambda_expl=1e-3):
        super(ExplorationRegularization, self).__init__()
        self.lambda_expl = lambda_expl

    def forward(self, visit_counts):
        """
        Args:
            visit_counts (torch.Tensor): Visit counts for actions, shape (batch_size, num_actions)

        Returns:
            torch.Tensor: Exploration regularization loss
        """
        reward = torch.sum(1.0 / (visit_counts + 1), dim=-1)
        return self.lambda_expl * reward.mean()


class KL_DivergenceLoss(nn.Module):
    def __init__(self):
        super(KL_DivergenceLoss, self).__init__()

    def forward(self, old_policy, new_policy):
        """
        Args:
            old_policy (torch.Tensor): Old policy probabilities, shape (batch_size, num_actions)
            new_policy (torch.Tensor): New policy probabilities, shape (batch_size, num_actions)

        Returns:
            torch.Tensor: KL divergence loss
        """
        kl_div = F.kl_div(new_policy.log(), old_policy, reduction='batchmean')
        return kl_div

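# Illustrative InfoNCE usage (a sketch mirroring how the training loop below builds positive pairs:
# the second view is simply a dropout-perturbed copy of the first; sizes are arbitrary):
#   z_i = torch.randn(32, 128)
#   z_j = F.dropout(z_i, p=0.1, training=True)
#   loss = InfoNCE_Loss(temperature=0.07)(z_i, z_j)
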
# MuZero

class ActionEncoder(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super(ActionEncoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)

    def forward(self, action_sequences):
        """
        Args:
            action_sequences (torch.Tensor): Tensor of shape (batch_size, seq_len)

        Returns:
            torch.Tensor: Encoded actions of shape (batch_size, seq_len, embed_dim)
        """
        return self.embedding(action_sequences)  # .half()  # Convert to half-precision


class RepresentationNetwork(nn.Module):
    def __init__(self, vocab_dim, d_model, state_dim):
        super(RepresentationNetwork, self).__init__()
        self.proj = nn.Linear(vocab_dim, d_model)  # Project from vocab_dim to d_model
        self.linear = nn.Linear(d_model, state_dim)  # Project from d_model to state_dim
        self.norm = nn.LayerNorm(state_dim)

    def forward(self, transformer_output):
        """
        Args:
            transformer_output (torch.Tensor): Shape (batch_size, seq_len, vocab_dim)

        Returns:
            torch.Tensor: Encoded state of shape (batch_size, seq_len, state_dim)
        """
        # First project down from vocab_dim to d_model
        projected_output = self.proj(transformer_output)
        # Then project down from d_model to state_dim
        state = self.linear(projected_output)
        state = self.norm(state)
        return state


class DynamicsNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim):
        super(DynamicsNetwork, self).__init__()
        self.rms_norm = nn.LayerNorm(state_dim)
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.activation = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, state_dim)

    def forward(self, state, action):
        """
        Args:
            state (torch.Tensor): Current state, shape (batch_size, seq_len, state_dim)
            action (torch.Tensor): Action embedding, shape (batch_size, seq_len, action_dim)

        Returns:
            torch.Tensor: Predicted next state, shape (batch_size, seq_len, state_dim)
        """
        norm_state = self.rms_norm(state)
        combined = torch.cat([norm_state, action], dim=-1)
        hidden = self.activation(self.fc1(combined))
        next_state = self.fc2(hidden)
        return next_state


class PredictionNetwork(nn.Module):
    def __init__(self, state_dim, policy_dim, value_dim):
        super(PredictionNetwork, self).__init__()
        self.state_dim = state_dim
        self.rms_norm = nn.LayerNorm(state_dim)
        self.policy_head = nn.Linear(state_dim, policy_dim)
        self.value_head = nn.Linear(state_dim, value_dim)

    def forward(self, state):
        """
        Args:
            state (torch.Tensor): Predicted state, shape (batch_size, seq_len, state_dim)

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Policy logits and value estimates
        """
        norm_state = self.rms_norm(state)
        policy_logits = self.policy_head(norm_state)
        value_estimates = self.value_head(norm_state)
        return policy_logits, value_estimates

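# Rough shape flow through the world-model components, using the dimensions set in main() below
# (vocab = len(tokenizer), state_dim = 128, action_dim = d_model = 512); B and T are batch and
# sequence length:
#   transformer_output (B, T, vocab)  --RepresentationNetwork-->  state (B, T, 128)
#   state (B, T, 128) + action embedding (B, T, 512)  --DynamicsNetwork-->  next state (B, T, 128)
#   next state (B, T, 128)  --PredictionNetwork-->  policy logits (B, T, vocab), value (B, T, 1)
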

class MCTSNode:
    def __init__(self, state, parent=None, action=None):
        """
        Initialize an MCTS node.

        Args:
            state (State): The current state representation.
            parent (MCTSNode, optional): The parent node. Defaults to None.
            action (int, optional): The action taken to reach this node. Defaults to None.
        """
        self.state = state  # Instance of State class
        self.parent = parent  # Parent MCTSNode
        self.action = action  # Action taken to reach this node
        self.children = {}  # Dict mapping actions to MCTSNode
        self.visit_count = 0
        self.value_sum = 0.0
        self.prior = 0.0  # Prior probability from policy network

    def expand(self, actions, priors):
        """
        Expand the node with possible actions and their priors.

        Args:
            actions (list): List of possible actions (action indices).
            priors (list): List of prior probabilities corresponding to actions.
        """
        for action, prior in zip(actions, priors):
            if action not in self.children:
                child_state = self.state.apply_action(action)  # Apply action to get new state
                child_node = MCTSNode(state=child_state, parent=self, action=action)
                child_node.prior = float(prior)  # Ensure that prior is a float value
                self.children[action] = child_node

    def is_leaf(self):
        """
        Check if the node is a leaf node (i.e., has no children).

        Returns:
            bool: True if leaf, False otherwise.
        """
        return len(self.children) == 0

    def ucb_score(self, total_visits, exploration_constant=math.sqrt(2)):
        """
        Calculate the UCB (Upper Confidence Bound) score for the node.

        Args:
            total_visits (int): Total number of visits to the parent node.
            exploration_constant (float, optional): Exploration parameter. Defaults to math.sqrt(2).

        Returns:
            float: The UCB score.
        """
        if self.visit_count == 0:
            return float('inf')
        average_value = self.value_sum / self.visit_count
        exploration_term = exploration_constant * self.prior * math.sqrt(total_visits) / (1 + self.visit_count)
        return average_value + exploration_term


class MCTS:
    def __init__(self, prediction_network, dynamics_network, action_encoder, num_iterations=10, exploration_constant=math.sqrt(2)):
        """
        Initialize the MCTS.

        Args:
            prediction_network (nn.Module): The Prediction Network.
            dynamics_network (nn.Module): The Dynamics Network.
            action_encoder (nn.Module): The Action Encoder used when expanding states.
            num_iterations (int): Number of MCTS iterations per search.
            exploration_constant (float): Exploration parameter for UCB.
        """
        self.action_encoder = action_encoder
        self.prediction_network = prediction_network
        self.dynamics_network = dynamics_network
        self.num_iterations = num_iterations
        self.exploration_constant = exploration_constant

    def search(self, root_state):
        """
        Perform MCTS starting from the root_state.

        Args:
            root_state: The initial state from which to start MCTS.

        Returns:
            The best action determined by MCTS.
        """
        self.root = MCTSNode(state=root_state)

        for _ in range(self.num_iterations):
            node = self.select(self.root)
            value = self.evaluate(node)
            self.backpropagate(node, value)

        return self.best_action()

    def select(self, node):
        """
        Traverse the tree to select a node for evaluation.

        Args:
            node: The starting node for selection.

        Returns:
            The node selected for evaluation.
        """
        while not node.is_leaf():
            best_action, best_node = max(node.children.items(),
                                         key=lambda item: item[1].ucb_score(node.visit_count, self.exploration_constant))
            node = best_node
        return node

    def evaluate(self, node):
        """
        Evaluate the node by expanding it and predicting its value.

        Args:
            node: The node to evaluate.

        Returns:
            The value estimate of the node.
        """
        # Use the prediction network to get policy logits and value estimate
        policy_logits, value_estimate = self.prediction_network(node.state.representation)

        # Convert logits to probabilities
        policy = F.softmax(policy_logits, dim=-1).detach().cpu().numpy()

        # Expand the node with possible actions and their priors
        actions = list(range(policy.shape[-1]))  # Assuming actions are indexed from 0 to num_actions-1
        priors = policy[0].flatten().tolist()  # Convert to a 1D list of floats

        node.expand(actions, priors)

        return value_estimate.mean().item()

    def backpropagate(self, node, value):
        """
        Backpropagate the value up the tree.

        Args:
            node: The node to start backpropagation from.
            value (float): The value to backpropagate.
        """
        while node is not None:
            node.visit_count += 1
            node.value_sum += value
            node = node.parent

    def best_action(self):
        """
        Choose the action with the highest visit count.

        Returns:
            The best action.
        """
        best_child = max(self.root.children.values(), key=lambda n: n.visit_count)
        return best_child.action


class State:
    def __init__(self, representation, dynamics_network, action_encoder):
        """
        Initialize the State.

        Args:
            representation (torch.Tensor): Encoded state representation, shape (batch_size, seq_len, state_dim)
            dynamics_network (nn.Module): The Dynamics Network to predict next states
            action_encoder (nn.Module): The Action Encoder to encode actions
        """
        self.representation = representation  # Shape: (batch_size, seq_len, state_dim)
        self.dynamics_network = dynamics_network  # Reference to Dynamics Network
        self.action_encoder = action_encoder

    def apply_action(self, action):
        """
        Apply an action to the current state to get a new state.

        Args:
            action (int): The action to apply (e.g., token index)

        Returns:
            State: The new state after applying the action
        """
        # Create action sequence filled with action index
        batch_size, seq_len, _ = self.representation.size()
        action_sequence = torch.full((batch_size, seq_len), action, dtype=torch.long, device=self.representation.device)
        # Encode action
        action_embedding = self.action_encoder(action_sequence)
        # Predict the next state using the Dynamics Network
        with torch.no_grad():
            next_state_representation = self.dynamics_network(self.representation, action_embedding)
        return State(next_state_representation, self.dynamics_network, self.action_encoder)


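# Minimal MCTS usage sketch (mirrors the per-sample call made in train_epoch_world_model below;
# the tensors here are placeholders, not values produced by this script):
#   mcts = MCTS(prediction_network, dynamics_network, action_encoder, num_iterations=5)
#   root = State(state_representation[0:1], dynamics_network, action_encoder)
#   best_action = mcts.search(root)   # token index of the child with the highest visit count
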

class PPOAgent:
    def __init__(self, policy_network, optimizer, clip_epsilon=0.2, entropy_coef=0.01, value_coef=0.5):
        self.policy_network = policy_network
        self.optimizer = optimizer
        self.clip_epsilon = clip_epsilon
        self.entropy_coef = entropy_coef
        self.value_coef = value_coef

    def compute_loss(self, states, old_log_probs, actions, returns, advantages):
        # Get policy logits and value estimates
        policy_logits, value_estimates = self.policy_network(states)
        batch_size, seq_len, num_actions = policy_logits.size()

        # Flatten tensors
        policy_logits = policy_logits.view(-1, num_actions)  # Shape: (batch_size * seq_len, num_actions)
        value_estimates = value_estimates.view(-1)  # Shape: (batch_size * seq_len)
        actions = actions.view(-1)  # Shape: (batch_size * seq_len)
        old_log_probs = old_log_probs.view(-1)  # Shape: (batch_size * seq_len)
        returns = returns.view(-1)  # Shape: (batch_size * seq_len)
        advantages = advantages.view(-1)  # Shape: (batch_size * seq_len)

        # Compute new log probabilities
        new_log_probs_all = F.log_softmax(policy_logits, dim=-1)  # Shape: (batch_size * seq_len, num_actions)
        new_log_probs = new_log_probs_all.gather(1, actions.unsqueeze(-1)).squeeze(-1)  # Shape: (batch_size * seq_len)

        # Compute ratios
        ratios = torch.exp(new_log_probs - old_log_probs)

        # PPO surrogate loss
        surr1 = ratios * advantages
        surr2 = torch.clamp(ratios, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages
        policy_loss = -torch.min(surr1, surr2).mean()

        # Value loss
        value_loss = F.mse_loss(value_estimates, returns)

        # Entropy bonus (approximated from the log-probabilities of the selected actions)
        entropy = -(new_log_probs * torch.exp(new_log_probs)).mean()

        # Total loss
        total_loss = policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy
        return total_loss


def compute_loss_world_model(predicted_next_state, true_next_state, policy_logits, true_policy, value_estimates, true_value,
                             alpha, beta, temperature, lambda_reg, lambda_var, lambda_div, lambda_expl):
    """
    Compute the combined loss for the World Model.

    Args:
        predicted_next_state (torch.Tensor): Predicted next state, shape (batch_size, state_dim)
        true_next_state (torch.Tensor): Ground truth next state, shape (batch_size, state_dim)
        policy_logits (torch.Tensor): Policy logits, shape (batch_size, num_actions)
        true_policy (torch.Tensor): Ground truth policy, shape (batch_size, num_actions)
        value_estimates (torch.Tensor): Value estimates, shape (batch_size)
        true_value (torch.Tensor): Ground truth value, shape (batch_size)
        alpha (float): Entropy regularization weight
        beta (float): Variance regularization weight
        temperature (float): Temperature parameter
        lambda_reg (float): Covariance regularization weight
        lambda_var (float): Dynamics variance loss weight
        lambda_div (float): Action diversity reward weight
        lambda_expl (float): Exploration regularization weight

    Returns:
        torch.Tensor: Combined loss
    """
    # Cross-entropy loss
    ce_loss = F.cross_entropy(policy_logits, true_policy.argmax(dim=1))

    # Entropy loss
    probs = F.softmax(policy_logits / temperature, dim=-1)  # (batch_size, num_actions)
    entropy = -torch.sum(probs * torch.log(probs + 1e-9), dim=-1)  # (batch_size)
    entropy_loss = -alpha * torch.mean(entropy)

    # Variance loss
    variance = torch.var(probs, dim=-1)  # (batch_size)
    variance_loss = -beta * torch.mean(variance)

    # Covariance Regularization
    cov_reg = CovarianceRegularization(lambda_reg)(predicted_next_state)

    # Dynamics Performance Loss
    dynamics_loss = DynamicsPerformanceLoss(lambda_var)(true_next_state, predicted_next_state)

    # Thought-Consistency Loss
    perturbed_next_state = predicted_next_state + torch.randn_like(predicted_next_state) * 0.01
    thought_loss = ThoughtConsistencyLoss()(true_next_state, perturbed_next_state)

    # Policy-Value Joint Loss
    pv_loss = PolicyValueJointLoss()(policy_logits, true_policy, value_estimates, true_value)

    # Action Diversity Reward
    action_embeddings = predicted_next_state  # Assuming actions are derived from state
    action_diversity = ActionDiversityReward(lambda_div)(action_embeddings)

    # Expected Thought Value (ETV) Loss
    # Placeholder: Replace with actual MCTS best values
    mcts_best_values = torch.zeros(value_estimates.size(0)).to(device)
    etv = ExpectedThoughtValueLoss()(mcts_best_values)

    # Exploration Regularization
    # Placeholder: Replace with actual visit counts (one count per action)
    visit_counts = torch.ones(predicted_next_state.size(0), policy_logits.size(-1)).to(device)
    exploration = ExplorationRegularization(lambda_expl)(visit_counts)

    # KL Divergence Regularization
    # Placeholder: Replace with actual old and new policies
    old_policy = F.softmax(policy_logits.detach(), dim=-1)
    new_policy = F.softmax(policy_logits, dim=-1)
    kl_loss = KL_DivergenceLoss()(old_policy, new_policy)

    # Total Loss
    total_loss = (
        ce_loss +
        entropy_loss +
        variance_loss +
        cov_reg +
        dynamics_loss +
        thought_loss +
        pv_loss +
        action_diversity +
        etv +
        exploration +
        kl_loss
    )

    return total_loss


def train_epoch_world_model(world_model_components, train_loader, optimizer, scheduler, scaler, args, model_transformer, state_dim, embed_dim, input_dim):
    representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent = world_model_components
    representation_network.train()
    dynamics_network.train()
    prediction_network.train()
    action_encoder.train()
    ppo_agent.policy_network.train()

    mcts = MCTS(prediction_network, dynamics_network, action_encoder, num_iterations=args.mcts_iterations, exploration_constant=args.mcts_exploration_constant)

    total_loss = 0.0
    optimizer.zero_grad()
    print(f"Starting World Model training epoch with {len(train_loader)} batches...")

    for i, batch in enumerate(train_loader):
        print(f"Processing batch {i+1}/{len(train_loader)}...")

        # Ensure batches are on the appropriate device for the Transformer
        src_batch = batch['input_ids'].to('cpu')  # Move to CPU for Transformer model
        tgt_batch = batch['labels'].to('cpu')  # Move to CPU for Transformer model

        with autocast(device_type='cuda'):
            print("Forward pass through Transformer (frozen)...")
            with torch.no_grad():
                transformer_output = model_transformer(src_batch, tgt_batch[:, :-1])

            # Move transformer output to the GPU for further processing
            transformer_output = transformer_output.to(device)

            # Encode actions directly on the GPU
            encoded_actions = action_encoder(tgt_batch[:, :-1].to(device))  # Move labels to GPU for encoding

            # World Model - Representation
            state_representation = representation_network(transformer_output)  # On GPU

            batch_size, seq_len, _ = state_representation.size()

            # Initialize list to collect predicted next states for the batch
            predicted_next_states = []

            # Iterate over each sample in the batch for MCTS
            for b in range(batch_size):
                # Create a State instance for the current sample
                current_state = State(state_representation[b].unsqueeze(0), dynamics_network, action_encoder)

                # Perform MCTS to find the best action
                best_action = mcts.search(current_state)

                # Create action sequence filled with best_action
                action_sequence = torch.full((1, seq_len), best_action, dtype=torch.long, device=device)

                # Get action embedding
                action_embedding = action_encoder(action_sequence)

                # Apply dynamics network
                predicted_next_state = dynamics_network(current_state.representation, action_embedding)

                predicted_next_states.append(predicted_next_state)

            # Concatenate predicted next states to form a batch
            predicted_next_state_batch = torch.cat(predicted_next_states, dim=0)

            # Prediction Network - Policy logits and value
            policy_logits, value_estimates = prediction_network(predicted_next_state_batch)

            # Define true_policy and true_value as placeholders on the GPU
            true_policy = torch.zeros_like(policy_logits).to(device)
            true_value = torch.zeros_like(value_estimates).to(device)

            # Compute PPO loss
            actions = torch.argmax(policy_logits, dim=-1)
            old_log_probs = torch.zeros_like(actions, dtype=torch.float32).to(device)
            returns = torch.zeros_like(actions, dtype=torch.float32).to(device)
            advantages = torch.zeros_like(actions, dtype=torch.float32).to(device)

            # Compute PPO loss using states
            ppo_loss = ppo_agent.compute_loss(state_representation, old_log_probs, actions, returns, advantages)

            # Compute InfoNCE Loss
            z_i = state_representation.view(batch_size * seq_len, state_dim)  # Shape: (batch_size * seq_len, state_dim)
            z_j = F.dropout(z_i, p=0.1, training=True)
            info_nce = InfoNCE_Loss()(z_i, z_j)

            # Compute other losses
            covariance = CovarianceRegularization()(predicted_next_state_batch.view(-1, predicted_next_state_batch.size(-1)))
            dynamics_loss = DynamicsPerformanceLoss()(torch.zeros_like(predicted_next_state_batch).to(device), predicted_next_state_batch)
            perturbed_next_state = predicted_next_state_batch + torch.randn_like(predicted_next_state_batch) * 0.01
            thought_loss = ThoughtConsistencyLoss()(torch.zeros_like(predicted_next_state_batch).to(device), perturbed_next_state)
            pv_loss = PolicyValueJointLoss()(policy_logits, true_policy, value_estimates.squeeze(-1), true_value.squeeze(-1))
            action_diversity = ActionDiversityReward()(encoded_actions.view(-1, embed_dim))
            mcts_best_values = torch.zeros(actions.size(0)).to(device)
            etv = ExpectedThoughtValueLoss()(mcts_best_values)
            visit_counts = torch.ones(actions.size(0), policy_logits.size(-1)).to(device)
            exploration = ExplorationRegularization()(visit_counts)
            old_policy = F.softmax(policy_logits.detach(), dim=-1)
            new_policy = F.softmax(policy_logits, dim=-1)
            kl_loss = KL_DivergenceLoss()(old_policy, new_policy)

            # Total Loss
            loss = (
                ppo_loss +
                info_nce +
                covariance +
                dynamics_loss +
                thought_loss +
                pv_loss +
                action_diversity +
                etv +
                exploration +
                kl_loss
            )
            loss = loss / args.accumulation_steps

        print("Backward pass...")
        scaler.scale(loss).backward()

        if (i + 1) % args.accumulation_steps == 0:
            print("Gradient clipping...")
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(
                [param for group in optimizer.param_groups for param in group['params']],
                args.max_grad_norm
            )

            print("Optimizer step...")
            scaler.step(optimizer)
            scaler.update()

            print("Zeroing gradients...")
            optimizer.zero_grad()

            print("Updating learning rate...")
            scheduler.step()

        total_loss += loss.item() * args.accumulation_steps
        print(f"Batch {i+1} completed. Current loss: {loss.item():.4f}")

    avg_loss = total_loss / len(train_loader)
    print(f"World Model training epoch completed. Average loss: {avg_loss:.4f}")
    return avg_loss


def evaluate_world_model(world_model_components, model_transformer, eval_loader, args):
    representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent = world_model_components
    representation_network.eval()
    dynamics_network.eval()
    prediction_network.eval()
    action_encoder.eval()
    ppo_agent.policy_network.eval()

    total_loss = 0.0
    with torch.no_grad():
        for batch in eval_loader:
            # Keep inputs on CPU for the CPU-resident Transformer, matching the training loop
            src_batch = batch['input_ids'].to('cpu')
            tgt_batch = batch['labels'].to('cpu')

            # Forward pass through Transformer (on CPU)
            transformer_output = model_transformer(src_batch, tgt_batch[:, :-1])

            # Encode actions
            encoded_actions = action_encoder(tgt_batch[:, :-1].to(device))  # Move to GPU if necessary

            # World Model - Representation
            state = representation_network(transformer_output.to(device))

            # Dynamics Network - Predict next state
            predicted_next_state = dynamics_network(state, encoded_actions)

            # Prediction Network - Policy logits and value
            policy_logits, value_estimates = prediction_network(predicted_next_state)

            # Placeholder: Define true_policy and true_value
            # Replace these with actual targets from your environment or dataset
            true_policy = torch.zeros_like(policy_logits).to(device)
            true_value = torch.zeros_like(value_estimates).to(device)

            # Compute PPO loss
            # Placeholder: Replace with actual old_log_probs, actions, returns, and advantages
            actions = torch.argmax(policy_logits, dim=-1)
            old_log_probs = torch.zeros_like(actions, dtype=torch.float32).to(device)
            returns = torch.zeros_like(actions, dtype=torch.float32).to(device)
            advantages = torch.zeros_like(actions, dtype=torch.float32).to(device)

            ppo_loss = ppo_agent.compute_loss(state, old_log_probs, actions, returns, advantages)

            # Compute other losses
            # Placeholder positive pairs: flatten the states and use a dropout-perturbed copy as the second view
            z_i = state.view(-1, state.size(-1))
            z_j = F.dropout(z_i, p=0.1, training=True)
            info_nce = InfoNCE_Loss()(z_i, z_j)
            covariance = CovarianceRegularization()(predicted_next_state.view(-1, predicted_next_state.size(-1)))
            dynamics_loss = DynamicsPerformanceLoss()(torch.zeros_like(predicted_next_state).to(device), predicted_next_state)
            perturbed_next_state = predicted_next_state + torch.randn_like(predicted_next_state) * 0.01
            thought_loss = ThoughtConsistencyLoss()(torch.zeros_like(predicted_next_state).to(device), perturbed_next_state)
            pv_loss = PolicyValueJointLoss()(policy_logits, true_policy, value_estimates.squeeze(-1), true_value.squeeze(-1))
            action_diversity = ActionDiversityReward()(encoded_actions.view(-1, encoded_actions.size(-1)))
            mcts_best_values = torch.zeros(actions.size(0)).to(device)  # Placeholder: replace with actual MCTS values
            etv = ExpectedThoughtValueLoss()(mcts_best_values)
            visit_counts = torch.ones(actions.size(0), policy_logits.size(-1)).to(device)  # Placeholder: replace with actual visit counts
            exploration = ExplorationRegularization()(visit_counts)
            old_policy = F.softmax(policy_logits.detach(), dim=-1)
            new_policy = F.softmax(policy_logits, dim=-1)
            kl_loss = KL_DivergenceLoss()(old_policy, new_policy)

            # Total Loss
            loss = (
                ppo_loss +
                info_nce +
                covariance +
                dynamics_loss +
                thought_loss +
                pv_loss +
                action_diversity +
                etv +
                exploration +
                kl_loss
            )

            total_loss += loss.item()

    avg_loss = total_loss / len(eval_loader)
    print(f"World Model evaluation completed. Average loss: {avg_loss:.4f}")
    return avg_loss


def main():
    args = parse_args()
    print("Arguments parsed successfully.")

    # Create save directory
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
        print(f"Save directory created: {args.save_dir}")

    # Load tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print("Tokenizer loaded successfully.")

    # Define padding_idx and input dimension based on tokenizer
    padding_idx = tokenizer.pad_token_id
    input_dim = len(tokenizer)

    # Load data
    print("Loading and preprocessing data...")
    train_loader, eval_loader = load_data(args, tokenizer)
    print("Data loaded and preprocessed successfully.")

    # Define model parameters
    d_model = 512  # half the usual width, to save memory
    num_heads = 8
    num_layers = 6
    d_ff = 2048
    num_experts = 4
    output_dim = input_dim
    dropout = 0.1
    top_k = 2
    state_dim = 128
    action_dim = d_model
    hidden_dim = 512
    vocab_dim = len(tokenizer)

    # Initialize and load the Transformer model (on CPU)
    print("Initializing and loading Transformer model...")
    model_transformer = Transformer(input_dim, d_model, num_heads, num_layers, d_ff, num_experts, output_dim, dropout, top_k)
    model_transformer.load_state_dict(torch.load(args.transformer_model_path, map_location='cpu'))
    model_transformer.eval()
    model_transformer.to('cpu')
    print("Transformer model loaded and moved to CPU.")

    # Define World Model components
    representation_network = RepresentationNetwork(vocab_dim, d_model, state_dim).to(device)
    dynamics_network = DynamicsNetwork(state_dim, action_dim, hidden_dim).to(device)
    prediction_network = PredictionNetwork(state_dim, input_dim, 1).to(device)
    action_encoder = ActionEncoder(input_dim, action_dim).to(device)

    # Define Optimizers and Schedulers
    optimizer = optim.AdamW(
        list(representation_network.parameters()) +
        list(dynamics_network.parameters()) +
        list(prediction_network.parameters()) +
        list(action_encoder.parameters()),
        lr=args.learning_rate, weight_decay=args.weight_decay
    )
    scheduler = CosineAnnealingLR(optimizer, T_max=args.num_epochs)
    scaler = GradScaler()

    # Initialize PPO Agent
    ppo_agent = PPOAgent(
        policy_network=prediction_network,
        optimizer=optim.AdamW(prediction_network.parameters(), lr=args.learning_rate),
        clip_epsilon=0.2,
        entropy_coef=0.01,
        value_coef=0.5
    )

    # Bundle World Model components
    world_model_components = (representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent)

    print("Setup complete. Starting training...")

    for epoch in range(args.num_epochs):
        print(f"Epoch {epoch + 1}/{args.num_epochs} started.")

        # Train World Model
        avg_train_loss = train_epoch_world_model(
            world_model_components,
            train_loader,
            optimizer,
            scheduler,
            scaler,
            args,
            model_transformer,
            state_dim,
            d_model,  # this is the embedding dimension
            input_dim
        )

        print(f"World Model training epoch {epoch + 1} completed. Average loss: {avg_train_loss:.4f}")

        # Evaluate World Model
        avg_eval_loss = evaluate_world_model(
            world_model_components,
            model_transformer,
            eval_loader,
            args
        )
        print(f"Evaluation for epoch {epoch + 1} completed. Average loss: {avg_eval_loss:.4f}")

        print(f"Epoch {epoch + 1}/{args.num_epochs}, Train Loss: {avg_train_loss:.4f}, Eval Loss: {avg_eval_loss:.4f}")

        # Save Models
        save_all_models(model_transformer, representation_network, dynamics_network, prediction_network, action_encoder, args.save_dir, epoch + 1)
        print(f"Models saved for epoch {epoch + 1}")

    print("Training completed.")


if __name__ == '__main__':
    main()

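# Example invocation (illustrative; the checkpoint path is hypothetical and must point to a
# state_dict saved from the Transformer class defined above, the other values are the argparse defaults):
#   python lightbulb_wm.py --transformer_model_path ./models/transformer_model.pt \
#       --dataset_name wikitext --dataset_config wikitext-2-raw-v1 --batch_size 2 --num_epochs 3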