#include #include #include // Define device torch::Device device(torch::kCUDA); // Define constants const int batch_size = 8; const int block_size = 32; const int max_iters = 1000; const int eval_interval = 50; const int eval_iters = 5; const int d_model = 256; const int n_layer = 16; const int n_head = 12; const float dropout = 0.2; const float norm_eps = 1e-5; const int vocab_size = 5; // sample data torch::Tensor train_data = torch::rand({1000, block_size}); torch::Tensor val_data = torch::rand({500, block_size}); // Data loading function std::pair get_batch(const std::string& split) { torch::Tensor data = (split == "train") ? train_data : val_data; torch::Tensor ix = torch::randint(data.size(0) - block_size, {batch_size}); torch::Tensor x = torch::empty({batch_size, block_size}); torch::Tensor y = torch::empty({batch_size, block_size}); for (int i = 0; i < batch_size; ++i) { x[i] = data.index({ix[i], ix[i] + block_size}); y[i] = data.index({ix[i] + 1, ix[i] + block_size + 1}); } return std::make_pair(x.to(device), y.to(device)); } // Custom classes and functions class SWiGLU : public torch::nn::Module { public: SWiGLU() {} torch::Tensor forward(torch::Tensor x) { torch::Tensor sigmoid_output = torch::sigmoid(x); torch::Tensor relu_output = torch::relu(x); torch::Tensor out = sigmoid_output * relu_output + (1 - sigmoid_output) * x; return out; } }; class UnMaskedHeadImpl : public torch::nn::Module { public: UnMaskedHeadImpl(int d_model, int head_size, float dropout) : key(register_module("key", torch::nn::Linear(d_model, head_size))), query(register_module("query", torch::nn::Linear(d_model, head_size))), value(register_module("value", torch::nn::Linear(d_model, head_size))), dropout(torch::nn::Dropout(dropout)) { register_module("dropout", dropout); } torch::Tensor forward(torch::Tensor x) { torch::Tensor key_out = key->forward(x); torch::Tensor query_out = query->forward(x); torch::Tensor weights = query_out.matmul(key_out.transpose(-2, -1)) * std::sqrt(key_out.size(-1)); weights = torch::softmax(weights, -1); weights = dropout(weights); torch::Tensor value_out = value->forward(x); torch::Tensor out = weights.matmul(value_out); return out; } private: torch::nn::Linear key, query, value; torch::nn::Dropout dropout; }; TORCH_MODULE(UnMaskedHead); class MaskedHeadImpl : public torch::nn::Module { public: MaskedHeadImpl(int head_size, float dropout, int d_model) : key(register_module("key", torch::nn::Linear(d_model, head_size))), query(register_module("query", torch::nn::Linear(d_model, head_size))), value(register_module("value", torch::nn::Linear(d_model, head_size))), dropout(torch::nn::Dropout(dropout)) { register_buffer("tril", torch::tril(torch::ones(block_size, block_size))); } torch::Tensor forward(torch::Tensor x) { torch::Tensor key_out = key->forward(x); torch::Tensor query_out = query->forward(x); torch::Tensor weights = query_out.matmul(key_out.transpose(-2, -1)) * std::sqrt(key_out.size(-1)); weights = weights.masked_fill(tril[:x.size(1), :x.size(1)] == 0, std::numeric_limits::lowest()); weights = torch::softmax(weights, -1); weights = dropout(weights); torch::Tensor value_out = value->forward(x); torch::Tensor out = weights.matmul(value_out); return out; } private: torch::nn::Linear key, query, value; torch::nn::Dropout dropout; torch::Tensor tril; }; TORCH_MODULE(MaskedHead); class MultiUnMaskedImpl : public torch::nn::Module { public: MultiUnMaskedImpl(int d_model, int n_head, float dropout) : proj(register_module("proj", torch::nn::Linear(n_head * (d_model / n_head), d_model))), dropout(torch::nn::Dropout(dropout)) { for (int i = 0; i < n_head; ++i) { heads.push_back(register_module("head" + std::to_string(i), UnMaskedHead(d_model, d_model / n_head, dropout))); } } torch::Tensor forward(torch::Tensor x) { std::vector head_outputs; for (auto& head : heads) { head_outputs.push_back(head->forward(x)); } torch::Tensor out = torch::cat(head_outputs, -1); out = dropout(out); out = proj(out); return out; } private: torch::nn::Linear proj; torch::nn::Dropout dropout; std::vector heads; }; TORCH_MODULE(MultiUnMasked); class MultiMaskedImpl : public torch::nn::Module { public: MultiMaskedImpl(int d_model, int n_head, float dropout) : proj(register_module("proj", torch::nn::Linear(n_head * (d_model / n_head), d_model))), dropout(torch::nn::Dropout(dropout)) { for (int i = 0; i < n_head; ++i) { heads.push_back(register_module("head" + std::to_string(i), MaskedHead(d_model, d_model / n_head, dropout))); } } torch::Tensor forward(torch::Tensor x) { std::vector head_outputs; for (auto& head : heads) { head_outputs.push_back(head->forward(x)); } torch::Tensor out = torch::cat(head_outputs, -1); out = dropout(out); out = proj(out); return out; } private: torch::nn::Linear proj; torch::nn::Dropout dropout; std::vector heads; }; TORCH_MODULE(MultiMasked); class FeedForwardImpl : public torch::nn::Module { public: FeedForwardImpl(int d_model, float dropout) : net(register_module("net", torch::nn::Sequential( torch::nn::Linear(d_model, 4 * d_model), torch::nn::GELU(), torch::nn::Linear(4 * d_model, d_model), torch::nn::Dropout(dropout) ))) {} torch::Tensor forward(torch::Tensor x) { return net->forward(x); } private: torch::nn::Sequential net; }; TORCH_MODULE(FeedForward); class BlockImpl : public torch::nn::Module { public: BlockImpl(int d_model, int n_head, float norm_eps, float dropout) : sa_masked(MultiMasked(d_model, n_head, dropout)), sa_unmasked(MultiUnMasked(d_model, n_head, dropout)), ffwd(FeedForward(d_model, dropout)), norm1(torch::nn::LayerNorm(torch::nn::LayerNormOptions({d_model}).eps(norm_eps))), norm2(torch::nn::LayerNorm(torch::nn::LayerNormOptions({d_model}).eps(norm_eps))) {} torch::Tensor forward(torch::Tensor x) { torch::Tensor x2 = x + sa_unmasked->forward(norm1->forward(x)); x = x2 + ffwd->forward(norm2->forward(x2)); x2 = x + sa_masked->forward(norm1->forward(x)); x = x2 + ffwd->forward(norm2->forward(x2)); return x; } private: MultiMasked sa_masked; MultiUnMasked sa_unmasked; FeedForward ffwd; torch::nn::LayerNorm norm1, norm2; }; TORCH_MODULE(Block); class EnigmaImpl : public torch::nn::Module { public: EnigmaImpl(int vocab_size, int block_size, int d_model, int n_layer, int n_head, float dropout, float norm_eps) : toked_model(register_module("toked_model", torch::nn::Embedding(vocab_size, d_model))), pos_encod(register_module("pos_encod", torch::nn::Embedding(block_size, d_model))), norm_final(torch::nn::LayerNorm(torch::nn::LayerNormOptions({d_model}).eps(norm_eps))), linear_final(register_module("linear_final", torch::nn::Linear(d_model, vocab_size))) { for (int i = 0; i < n_layer; ++i) { block_layers.push_back(register_module("block" + std::to_string(i), Block(d_model, n_head, norm_eps, dropout))); } register_buffer("block_size", torch::tensor(block_size)); _init_weights(this); } void _init_weights(torch::nn::Module* module) { auto parameters = module->named_parameters(); for (auto& param : parameters) { if (param.key().find("weight") != std::string::npos) { torch::nn::init::normal_(param.value(), 0.0, 0.02); } else if (param.key().find("bias") != std::string::npos) { torch::nn::init::zeros_(param.value()); } } } std::pair forward(torch::Tensor idx, torch::Tensor targets=torch::Tensor()) { torch::Tensor toked_model_out = toked_model->forward(idx); torch::Tensor pos_encod_out = pos_encod->forward(torch::arange(idx.size(1))); torch::Tensor x = toked_model_out + pos_encod_out; for (auto& block : block_layers) { x = block->forward(x); } torch::Tensor logits = linear_final->forward(norm_final->forward(x)); if (!targets.numel()) { return {logits, torch::Tensor()}; } else { logits = logits.view({-1, logits.size(-1)}); targets = targets.view({-1}); torch::Tensor loss = torch::nn::functional::cross_entropy(logits, targets); return {logits, loss}; } } std::vector>> complex_generate(torch::Tensor idx, int max_new_tokens, float temperature=1.0, int top_k=3, int beam_width=5) { std::vector>> completed_beams; torch::Tensor current_idx = idx.clone(); std::vector> beam = {std::make_pair(current_idx, 0.0)}; for (int i = 0; i < max_new_tokens; ++i) { std::vector> new_beam; for (auto& beam_item : beam) { torch::Tensor& current_idx = beam_item.first; torch::Tensor logits, loss; std::tie(logits, loss) = forward(current_idx); logits = logits.index({torch::indexing::Slice(), -1}); // Get last token predictions // Apply softmax and temperature torch::Tensor probs = torch::nn::functional::softmax(logits / temperature, -1); // Top-k sampling if (top_k > 0) { probs = top_k_filtering(probs, top_k); } // Sample from the distribution torch::Tensor sampled_idx = torch::multinomial(probs, beam_width, true); for (int j = 0; j < beam_width; ++j) { torch::Tensor new_idx = torch::cat({current_idx, sampled_idx.index({torch::indexing::Slice(), j})}, 1); torch::Tensor new_log_prob = beam_item.second + torch::log(probs.index({torch::indexing::Slice(), sampled_idx.index({torch::indexing::Slice(), j})})); new_beam.push_back(std::make_pair(new_idx, new_log_prob.item())); } } // Sort new beam by log probabilities std::sort(new_beam.begin(), new_beam.end(), [](const std::pair& a, const std::pair& b) { return a.second > b.second; }); // Only keep top beams beam = std::vector>(new_beam.begin(), new_beam.begin() + beam_width); } completed_beams.push_back(beam); return completed_beams; } std::vector>> top_k_filtering(torch::Tensor logits, int top_k) { torch::Tensor top_values, top_indices; std::tie(top_values, top_indices) = torch::topk(logits, top_k, -1); torch::Tensor min_value = torch::index_select(top_values, -1, torch::tensor({top_k-1})); torch::Tensor filtered_logits = torch::where(logits < min_value, torch::full_like(logits, -std::numeric_limits::infinity()), logits); return filtered_logits; } private: torch::nn::Embedding toked_model, pos_encod; std::vector block_layers; torch::nn::LayerNorm norm_final; torch::nn::Linear linear_final; int block_size; }; TORCH_MODULE(Enigma); int main() { // Set seed torch::manual_seed(1400); // Create model Enigma model(vocab_size, block_size, d_model, n_layer, n_head, dropout, norm_eps); model->to(device); // Define optimizer torch::optim::AdamW optimizer(model->parameters(), torch::optim::AdamWOptions(learning_rate)); // Training loop std::vector train_losses, val_losses; for (int iter = 0; iter < max_iters; ++iter) { if (iter % eval_interval == 0 || iter == max_iters - 1) { // Evaluate and print losses auto losses = estimate_loss(); std::cout << "step " << iter << ": train loss " << losses["train"] << ", val loss " << losses["val"] << std::endl; // Save losses for plotting train_losses.push_back(losses["train"]); val_losses.push_back(losses["val"]); } // Get batch, forward pass, loss calculation, backward pass, optimizer step auto [xb, yb] = get_batch("train"); torch::Tensor logits, loss; std::tie(logits, loss) = model->forward(xb, yb); optimizer.zero_grad(); loss.backward(); optimizer.step(); } return 0; }