shivendrra committed
Commit: 484d56b
Parent(s): 82ea75a

added model files

- enigma/EnBERT.py +181 -0
- enigma/TrainEnigma.ipynb +470 -0
- enigma/config_enigma.json +13 -0
- enigma/enigma.cpp +364 -0
- enigma/generate.py +126 -0
- enigma/model.py +388 -0
- enigma/run.py +100 -0
enigma/EnBERT.py
ADDED
@@ -0,0 +1,181 @@
"""
this isn't a BERT-based model, i just liked the name and named it

--> decoder-only model, uses RMS normalization and GELU activation function
--> one masked-attention layer and one unmasked
--> attention layers have relative positional embeddings
"""

import json

import torch
import torch.nn as nn
from torch.nn import functional as F

with open('config.json', 'r', encoding='utf-8') as file:
  params = json.load(file)

# required parameters
block_size = params['block_size']
d_model = params['d_model']
n_head = params['n_heads']
n_layers = params['n_layers']
learning_rate = params['learning_rate']
dropout = params['dropout']
norm_eps = params['norm_eps']

device = 'cuda' if torch.cuda.is_available() else 'cpu'

class RMSNorm(nn.Module):
  def __init__(self, dim: int, eps: float = 1e-6):
    super().__init__()
    self.eps = eps
    self.weight = nn.Parameter(torch.ones(dim))

  def _norm(self, x):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

  def forward(self, x):
    output = self._norm(x.float()).type_as(x)
    return output * self.weight

class SingleHead(nn.Module):
  def __init__(self,
               head_size: int,
               d_model: int,
               block_size: int,
               dropout: float):
    super().__init__()
    self.key = nn.Linear(d_model, head_size, bias=True)
    self.query = nn.Linear(d_model, head_size, bias=True)
    self.value = nn.Linear(d_model, head_size, bias=True)
    self.dropout = nn.Dropout(dropout)
    self.rel_pos_embd = nn.Parameter(torch.randn(block_size, block_size, head_size))
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

  def forward(self, x: torch.Tensor, mask: bool = False):
    B, T, C = x.shape
    key = self.key(x)
    query = self.query(x)
    # scaled dot-product attention: multiply by 1/sqrt(head_size)
    # (the original divided by head_size**-0.5, which scales the wrong way)
    scores = torch.matmul(query, key.transpose(-2, -1)) * (key.shape[-1] ** -0.5)

    # add the relative positional scores before masking, so that masked
    # positions stay at -inf (the original masked first, then added)
    rel_pos_scores = torch.einsum('btc,tvc->btv', query, self.rel_pos_embd[:T, :T])
    scores = scores + rel_pos_scores

    if mask:
      scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))

    att_mat = F.softmax(scores, dim=-1)
    att_mat = self.dropout(att_mat)
    value = self.value(x)
    output = torch.matmul(att_mat, value)
    return output

class MultiHeadAttention(nn.Module):
  def __init__(self,
               d_model: int,
               block_size: int,
               n_head: int,
               dropout: float):
    super().__init__()
    head_size = d_model // n_head
    self.heads = nn.ModuleList([SingleHead(d_model=d_model, dropout=dropout, block_size=block_size, head_size=head_size) for _ in range(n_head)])
    self.projection = nn.Linear(d_model, d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x: torch.Tensor, mask: bool):
    out = torch.cat([h(x, mask) for h in self.heads], dim=-1)
    out = self.dropout(self.projection(out))
    return out

class FeedForward(nn.Module):
  def __init__(self, d_model, dropout):
    super().__init__()
    self.net = nn.Sequential(
      nn.Linear(d_model, 5 * d_model),
      nn.GELU(),
      nn.Linear(5 * d_model, d_model),
      nn.Dropout(dropout),
    )

  def forward(self, x: torch.Tensor):
    return self.net(x)

class DecoderBlock(nn.Module):
  def __init__(self, d_model: int,
               block_size: int,
               n_head: int,
               norm_eps: float,
               dropout: float):
    super().__init__()
    self.self_att = MultiHeadAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)
    self.ffwd = FeedForward(d_model, dropout)
    self.dropout = nn.Dropout(dropout)
    self.norm = RMSNorm(d_model, eps=norm_eps)

  def forward(self, x: torch.Tensor):
    # pre-norm residual block: masked attention, unmasked attention, feed-forward
    x_out = self.self_att(self.norm(x), mask=True)
    x = x + self.dropout(x_out)

    # the original passed mask=False to self.norm by mistake; it belongs to self_att
    x_out = self.self_att(self.norm(x), mask=False)
    x = x + self.dropout(x_out)

    x_out = self.ffwd(self.norm(x))
    x = x + self.dropout(x_out)

    return x

class Transformer(nn.Module):
  def __init__(self, vocab_size: int):
    super().__init__()
    self.block_size = block_size
    self.token_embeddings = nn.Embedding(vocab_size, d_model)
    self.decoder = nn.Sequential(*[DecoderBlock(n_head=n_head, d_model=d_model, dropout=dropout, norm_eps=norm_eps, block_size=block_size) for _ in range(n_layers)])
    self.norm_final = RMSNorm(d_model, eps=norm_eps)
    self.linear_final = nn.Linear(d_model, vocab_size)
    self.dropout = nn.Dropout(dropout)
    self.apply(self._init_weights)

  def _init_weights(self, module):
    if isinstance(module, nn.Linear):
      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
      if module.bias is not None:
        torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

  def forward(self, idx, targets=None):
    B, T = idx.shape
    x = self.token_embeddings(idx)
    x = self.decoder(x)
    logits = self.linear_final(self.norm_final(x))

    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)

    return logits, loss

  @torch.no_grad()
  def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
    self.eval()
    for _ in range(max_new_tokens):
      # crop the running context to the last block_size tokens
      idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
      logits, _ = self(idx_cond)
      logits = logits[:, -1, :] / temperature

      if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits[logits < v[:, [-1]]] = -float('Inf')

      probs = F.softmax(logits, dim=-1)
      idx_next = torch.multinomial(probs, num_samples=1)
      idx = torch.cat((idx, idx_next), dim=1)

    return idx
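A minimal usage sketch for the Transformer defined above (not part of the commit). It assumes a config.json next to the script with the keys the module reads at import time (block_size, d_model, n_heads, n_layers, learning_rate, dropout, norm_eps), and uses a made-up vocab size of 1200:

import torch
from EnBERT import Transformer  # loads config.json at import time

model = Transformer(vocab_size=1200)
device = next(model.parameters()).device

# forward pass: (B, T) token ids -> logits, plus a loss when targets are given
idx = torch.randint(0, 1200, (2, 64), device=device)
logits, loss = model(idx, targets=idx)

# autoregressive sampling from a single-token prompt
prompt = idx[:1, :1]
out = model.generate(prompt, max_new_tokens=20, temperature=0.8, top_k=5)
print(out.shape)  # torch.Size([1, 21])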
enigma/TrainEnigma.ipynb
ADDED
@@ -0,0 +1,470 @@
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vXIGN6PAuZWg"
      },
      "source": [
        "### Train file for enigma model\n",
        "\n",
        "- Contains K-mer tokenizer, k=4, can be changed though\n",
        "- Train data is available on huggingface repo: [hf/enigma-1.5b](https://huggingface.co/shivendrra/enigma-1.5b)\n",
        "- For now, training decoder-based model only\n",
        "- More about this on github repo: [github/enigma-1.5b](https://github.com/shivendrra/enigma-1.5b)\n",
        "- Saves model after training in '.pth' & '.safetensors' file for later use\n",
        "- Generate function works fine"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WXpJBLyr30Rx"
      },
      "outputs": [],
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/drive')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "r7WUm0VL4bN4"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "# importing the data\n",
        "file_path = '/content/drive/MyDrive/consolidated_dna.txt'\n",
        "with open(file_path, 'r', encoding='utf-8') as file:\n",
        "  dna_seq = file.read()\n",
        "\n",
        "print(f\"{(len(dna_seq)/1e6):.2f} million letters\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Cdhybhz9owTK"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from tqdm import tqdm\n",
        "import json\n",
        "\n",
        "class KMerTokenizer:\n",
        "  def __init__(self, k_mers: int=4):\n",
        "    self.k_mers = k_mers\n",
        "    self.vocab = {}\n",
        "    self.id_to_token = []\n",
        "    self.token_to_id = {}\n",
        "\n",
        "  def tokenize_sequence(self, sequence):\n",
        "    kmers = [sequence[i:i+self.k_mers] for i in tqdm(range(0, len(sequence), self.k_mers), desc=\"tokenizing k-mers\")]\n",
        "    return kmers\n",
        "\n",
        "  def build_vocab(self, sequences):\n",
        "    all_kmers = []\n",
        "    for sequence in sequences:\n",
        "      all_kmers.extend(self.tokenize_sequence(sequence))\n",
        "    token_count = {}\n",
        "    for kmer in all_kmers:\n",
        "      if kmer in token_count:\n",
        "        token_count[kmer] += 1\n",
        "      else:\n",
        "        token_count[kmer] = 1\n",
        "    sorted_tokens = sorted(token_count.items(), key=lambda x: x[1], reverse=True)\n",
        "    for token, _ in sorted_tokens:\n",
        "      self.token_to_id[token] = len(self.token_to_id)\n",
        "      self.id_to_token.append(token)\n",
        "    self.vocab = self.token_to_id\n",
        "\n",
        "  def encode(self, sequence):\n",
        "    encoded_sequence = []\n",
        "    kmers = self.tokenize_sequence(sequence)\n",
        "    for kmer in tqdm(kmers, desc=\"encoding sequences\"):\n",
        "      if kmer in self.token_to_id:\n",
        "        encoded_sequence.append(self.token_to_id[kmer])\n",
        "      else:\n",
        "        encoded_sequence.append(len(self.vocab))\n",
        "    return encoded_sequence\n",
        "\n",
        "  def decode(self, encoded_sequence):\n",
        "    decoded_sequence = [self.id_to_token[token_id] for token_id in encoded_sequence]\n",
        "    return decoded_sequence\n",
        "\n",
        "  def save_model(self, model_path):\n",
        "    vocab_file = f\"{model_path}/base_{self.k_mers}k.json\"\n",
        "    with open(vocab_file, 'w') as f:\n",
        "      json.dump(self.vocab, f)\n",
        "\n",
        "  def load_model(self, path):\n",
        "    assert path.endswith('.json')\n",
        "    with open(path, 'r') as f:\n",
        "      vocab = json.load(f)\n",
        "\n",
        "    self.vocab = vocab\n",
        "    self.token_to_id = self.vocab\n",
        "    # rebuild the id -> token list so decode() works after loading\n",
        "    self.id_to_token = list(self.vocab.keys())\n",
        "    self.vocab_size = len(vocab)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6BCpjdi5rjU4"
      },
      "outputs": [],
      "source": [
        "token = KMerTokenizer()\n",
        "token.build_vocab([dna_seq])\n",
        "print(f\"vocab size: {len(token.vocab)}\")\n",
        "print(token.id_to_token[:10])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6Ou9txgmAdIB"
      },
      "outputs": [],
      "source": [
        "# Train and test splits\n",
        "data = torch.tensor(token.encode(dna_seq), dtype=torch.long)\n",
        "print(f\"{(len(data)/1e6):.1f} million tokens\")\n",
        "n = int(0.9*len(data)) # first 90% will be train, rest val\n",
        "train_data = data[:n]\n",
        "val_data = data[n:]\n",
        "print(f\"train data {(len(train_data)/1e6):.1f} million, val data {(len(val_data)/1e6):.1f} million\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ebFKQQ9NAq4e"
      },
      "outputs": [],
      "source": [
        "# hyperparams\n",
        "batch_size = 10\n",
        "block_size = 256\n",
        "max_iters = 5000\n",
        "eval_interval = 100\n",
        "learning_rate = 3e-5\n",
        "eval_iters = 100\n",
        "d_model = 512\n",
        "n_layers = 12\n",
        "n_head = 16  # must divide d_model evenly (the original value of 18 does not)\n",
        "dropout = 0.25\n",
        "norm_eps = 1e-5"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dZMiYkr37cmU"
      },
      "outputs": [],
      "source": [
        "import torch.nn as nn\n",
        "from torch.nn import functional as F\n",
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "\n",
        "class RMSNorm(nn.Module):\n",
        "  def __init__(self, dim: int, eps: float = 1e-6):\n",
        "    super().__init__()\n",
        "    self.eps = eps\n",
        "    self.weight = nn.Parameter(torch.ones(dim))\n",
        "\n",
        "  def _norm(self, x):\n",
        "    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n",
        "\n",
        "  def forward(self, x):\n",
        "    output = self._norm(x.float()).type_as(x)\n",
        "    return output * self.weight\n",
        "\n",
        "class SingleHead(nn.Module):\n",
        "  def __init__(self,\n",
        "               head_size: int,\n",
        "               d_model: int,\n",
        "               block_size: int,\n",
        "               dropout: float):\n",
        "    super().__init__()\n",
        "    self.key = nn.Linear(d_model, head_size, bias=True)\n",
        "    self.query = nn.Linear(d_model, head_size, bias=True)\n",
        "    self.value = nn.Linear(d_model, head_size, bias=True)\n",
        "    self.dropout = nn.Dropout(dropout)\n",
        "    self.rel_pos_embd = nn.Parameter(torch.randn(block_size, block_size, head_size))\n",
        "    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n",
        "\n",
        "  def forward(self, x: torch.Tensor, mask: bool = False):\n",
        "    B, T, C = x.shape\n",
        "    key = self.key(x)\n",
        "    query = self.query(x)\n",
        "    # scale by 1/sqrt(head_size); the original divided by head_size**-0.5\n",
        "    scores = torch.matmul(query, key.transpose(-2, -1)) * (key.shape[-1] ** -0.5)\n",
        "\n",
        "    # add relative positional scores before masking so masked positions stay -inf\n",
        "    rel_pos_scores = torch.einsum('btc,tvc->btv', query, self.rel_pos_embd[:T, :T])\n",
        "    scores = scores + rel_pos_scores\n",
        "\n",
        "    if mask:\n",
        "      scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))\n",
        "\n",
        "    att_mat = F.softmax(scores, dim=-1)\n",
        "    att_mat = self.dropout(att_mat)\n",
        "    value = self.value(x)\n",
        "    output = torch.matmul(att_mat, value)\n",
        "    return output\n",
        "\n",
        "class MultiHeadAttention(nn.Module):\n",
        "  def __init__(self,\n",
        "               d_model: int,\n",
        "               block_size: int,\n",
        "               n_head: int,\n",
        "               dropout: float):\n",
        "    super().__init__()\n",
        "    head_size = d_model // n_head\n",
        "    self.heads = nn.ModuleList([SingleHead(d_model=d_model, dropout=dropout, block_size=block_size, head_size=head_size) for _ in range(n_head)])\n",
        "    self.projection = nn.Linear(d_model, d_model)\n",
        "    self.dropout = nn.Dropout(dropout)\n",
        "\n",
        "  def forward(self, x: torch.Tensor, mask: bool):\n",
        "    out = torch.cat([h(x, mask) for h in self.heads], dim=-1)\n",
        "    out = self.dropout(self.projection(out))\n",
        "    return out\n",
        "\n",
        "class FeedForward(nn.Module):\n",
        "  def __init__(self, d_model, dropout):\n",
        "    super().__init__()\n",
        "    self.net = nn.Sequential(\n",
        "      nn.Linear(d_model, 5 * d_model),\n",
        "      nn.GELU(),\n",
        "      nn.Linear(5 * d_model, d_model),\n",
        "      nn.Dropout(dropout),\n",
        "    )\n",
        "\n",
        "  def forward(self, x: torch.Tensor):\n",
        "    return self.net(x)\n",
        "\n",
        "class DecoderBlock(nn.Module):\n",
        "  def __init__(self, d_model: int,\n",
        "               block_size: int,\n",
        "               n_head: int,\n",
        "               norm_eps: float,\n",
        "               dropout: float):\n",
        "    super().__init__()\n",
        "    self.self_att = MultiHeadAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)\n",
        "    self.ffwd = FeedForward(d_model, dropout)\n",
        "    self.dropout = nn.Dropout(dropout)\n",
        "    self.norm = RMSNorm(d_model, eps=norm_eps)\n",
        "\n",
        "  def forward(self, x: torch.Tensor):\n",
        "    x_out = self.self_att(self.norm(x), mask=True)\n",
        "    x = x + self.dropout(x_out)\n",
        "\n",
        "    # the original passed mask=False to self.norm by mistake\n",
        "    x_out = self.self_att(self.norm(x), mask=False)\n",
        "    x = x + self.dropout(x_out)\n",
        "\n",
        "    x_out = self.ffwd(self.norm(x))\n",
        "    x = x + self.dropout(x_out)\n",
        "\n",
        "    return x\n",
        "\n",
        "class Transformer(nn.Module):\n",
        "  def __init__(self, vocab_size: int):\n",
        "    super().__init__()\n",
        "    self.block_size = block_size\n",
        "    self.token_embeddings = nn.Embedding(vocab_size, d_model)\n",
        "    self.decoder = nn.Sequential(*[DecoderBlock(n_head=n_head, d_model=d_model, dropout=dropout, norm_eps=norm_eps, block_size=block_size) for _ in range(n_layers)])\n",
        "    self.norm_final = RMSNorm(d_model, eps=norm_eps)\n",
        "    self.linear_final = nn.Linear(d_model, vocab_size)\n",
        "    self.dropout = nn.Dropout(dropout)\n",
        "    self.apply(self._init_weights)\n",
        "\n",
        "  def _init_weights(self, module):\n",
        "    if isinstance(module, nn.Linear):\n",
        "      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
        "      if module.bias is not None:\n",
        "        torch.nn.init.zeros_(module.bias)\n",
        "    elif isinstance(module, nn.Embedding):\n",
        "      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
        "\n",
        "  def forward(self, idx, targets=None):\n",
        "    B, T = idx.shape\n",
        "    x = self.token_embeddings(idx)\n",
        "    x = self.decoder(x)\n",
        "    logits = self.linear_final(self.norm_final(x))\n",
        "\n",
        "    if targets is None:\n",
        "      loss = None\n",
        "    else:\n",
        "      B, T, C = logits.shape\n",
        "      logits = logits.view(B*T, C)\n",
        "      targets = targets.view(B*T)\n",
        "      loss = F.cross_entropy(logits, targets)\n",
        "\n",
        "    return logits, loss\n",
        "\n",
        "  @torch.no_grad()\n",
        "  def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):\n",
        "    self.eval()\n",
        "    for _ in range(max_new_tokens):\n",
        "      idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]\n",
        "      logits, _ = self(idx_cond)\n",
        "      logits = logits[:, -1, :] / temperature\n",
        "\n",
        "      if top_k is not None:\n",
        "        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))\n",
        "        logits[logits < v[:, [-1]]] = -float('Inf')\n",
        "\n",
        "      probs = F.softmax(logits, dim=-1)\n",
        "      idx_next = torch.multinomial(probs, num_samples=1)\n",
        "      idx = torch.cat((idx, idx_next), dim=1)\n",
        "\n",
        "    return idx"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X9VOBZFr7g3W"
      },
      "outputs": [],
      "source": [
        "import timeit\n",
        "start_time = timeit.default_timer()\n",
        "\n",
        "def get_batch(split):\n",
        "  data = train_data if split == 'train' else val_data\n",
        "  ix = torch.randint(len(data) - block_size, (batch_size,))\n",
        "  x = torch.stack([data[i:i+block_size] for i in ix])\n",
        "  y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
        "  x, y = x.to(device), y.to(device)\n",
        "  return x, y\n",
        "\n",
        "@torch.no_grad()\n",
        "def estimate_loss():\n",
        "  out = {}\n",
        "  model.eval()\n",
        "  for split in ['train', 'val']:\n",
        "    losses = torch.zeros(eval_iters)\n",
        "    for k in range(eval_iters):\n",
        "      X, Y = get_batch(split)\n",
        "      logits, loss = model(X, Y)\n",
        "      losses[k] = loss.item()\n",
        "    out[split] = losses.mean()\n",
        "  model.train()\n",
        "  return out\n",
        "\n",
        "vocab_size = len(token.vocab)\n",
        "model = Transformer(vocab_size)\n",
        "# checkpoint_path = '/content/drive/MyDrive/enigma-2.5b.pth'\n",
        "# checkpoint = torch.load(checkpoint_path)\n",
        "# model.load_state_dict(checkpoint)\n",
        "m = model.to(device)\n",
        "\n",
        "# no of parameters\n",
        "n_param = sum(p.numel() for p in m.parameters())/1e6\n",
        "print(f\"{n_param:.1f} million parameters\")\n",
        "\n",
        "# optimizer\n",
        "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
        "steps = []\n",
        "train_losses = []\n",
        "val_losses = []\n",
        "\n",
        "for iter in range(max_iters):\n",
        "\n",
        "  if iter % eval_interval == 0 or iter == max_iters - 1:\n",
        "    losses = estimate_loss()\n",
        "    print(f\"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n",
        "\n",
        "    steps.append(iter)\n",
        "    train_losses.append(losses['train'])\n",
        "    val_losses.append(losses['val'])\n",
        "\n",
        "  xb, yb = get_batch('train')\n",
        "  logits, loss = model(xb, yb)\n",
        "  optimizer.zero_grad(set_to_none=True)\n",
        "  loss.backward()\n",
        "  optimizer.step()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tzJMKoA35uIV"
      },
      "outputs": [],
      "source": [
        "end_time = timeit.default_timer()\n",
        "print(f\"total parameters: {n_param:.1f} million\")\n",
        "print(f\"trained in {((end_time - start_time)/3600):.2f}hrs\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eB47Yn9aNrrO"
      },
      "outputs": [],
      "source": [
        "model_save_name = 'consolidated_00.pth'\n",
        "path = f\"/content/drive/MyDrive/{model_save_name}\"\n",
        "torch.save(model.state_dict(), path)\n",
        "\n",
        "# saving safe-tensors\n",
        "from safetensors.torch import save_file\n",
        "\n",
        "model_save_name = 'consolidated_00.safetensors'\n",
        "path = f\"/content/drive/MyDrive/{model_save_name}\"\n",
        "save_file(model.state_dict(), path)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "89TNah_89CRB"
      },
      "outputs": [],
      "source": [
        "!nvidia-smi"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "machine_shape": "hm",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
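For reference, a small round-trip sketch of the notebook's KMerTokenizer (illustrative only; run it where the class is in scope): the sequence is chopped into non-overlapping k-mers, ids are assigned by descending frequency, and decode maps ids back to k-mer strings. Note one caveat visible in the code: encode maps any unseen k-mer to len(vocab), an id that decode cannot resolve.

token = KMerTokenizer(k_mers=4)
token.build_vocab(["ATGCATGCCGTA"])  # 4-mers: ATGC x2, CGTA x1 -> ids 0 and 1
ids = token.encode("ATGCCGTA")       # -> [0, 1]
print(''.join(token.decode(ids)))    # -> 'ATGCCGTA'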
enigma/config_enigma.json
ADDED
@@ -0,0 +1,13 @@
{
  "batch_size": 10,
  "block_size": 512,
  "max_iters": 5000,
  "eval_interval": 50,
  "learning_rate": 3e-5,
  "eval_iters": 100,
  "d_model": 384,
  "n_head": 12,
  "n_layer": 12,
  "dropout": 0.2,
  "norm_eps": 1e-5
}
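A hypothetical sanity check (not in the commit) that loads config_enigma.json the way model.py does and verifies that the keys model.py reads are present and mutually consistent:

import json

REQUIRED = ['batch_size', 'block_size', 'n_head', 'd_model', 'n_layer', 'dropout', 'norm_eps']

with open('config_enigma.json', 'r', encoding='utf-8') as f:
  params = json.load(f)

missing = [k for k in REQUIRED if k not in params]
assert not missing, f"missing keys: {missing}"
# each attention head gets d_model // n_head dims, so this must divide evenly
assert params['d_model'] % params['n_head'] == 0, "n_head must divide d_model"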
enigma/enigma.cpp
ADDED
@@ -0,0 +1,364 @@
#include <torch/torch.h>
#include <iostream>
#include <vector>
#include <map>
#include <cmath>
#include <limits>
#include <algorithm>

// Define device
torch::Device device(torch::kCUDA);

// Define constants
const int batch_size = 8;
const int block_size = 32;
const int max_iters = 1000;
const int eval_interval = 50;
const int eval_iters = 5;
const int d_model = 256;
const int n_layer = 16;
const int n_head = 12;
const float dropout = 0.2;
const float norm_eps = 1e-5;
const int vocab_size = 5;
const float learning_rate = 3e-4;  // referenced by the optimizer below but never defined in the original; value assumed

// sample data: random token streams standing in for a real tokenized dataset
// (the original used torch::rand matrices, but the embedding layer needs integer ids)
torch::Tensor train_data = torch::randint(vocab_size, {100000});
torch::Tensor val_data = torch::randint(vocab_size, {50000});

// Data loading function
std::pair<torch::Tensor, torch::Tensor> get_batch(const std::string& split) {
  torch::Tensor data = (split == "train") ? train_data : val_data;
  torch::Tensor ix = torch::randint(data.size(0) - block_size, {batch_size});
  torch::Tensor x = torch::empty({batch_size, block_size}, torch::kLong);
  torch::Tensor y = torch::empty({batch_size, block_size}, torch::kLong);
  for (int i = 0; i < batch_size; ++i) {
    // contiguous slice of block_size tokens, with targets shifted by one
    int64_t s = ix[i].item<int64_t>();
    x[i] = data.index({torch::indexing::Slice(s, s + block_size)});
    y[i] = data.index({torch::indexing::Slice(s + 1, s + block_size + 1)});
  }
  return std::make_pair(x.to(device), y.to(device));
}

// Custom classes and functions
class SWiGLU : public torch::nn::Module {  // defined but not used below
public:
  SWiGLU() {}

  torch::Tensor forward(torch::Tensor x) {
    torch::Tensor sigmoid_output = torch::sigmoid(x);
    torch::Tensor relu_output = torch::relu(x);
    torch::Tensor out = sigmoid_output * relu_output + (1 - sigmoid_output) * x;
    return out;
  }
};

class UnMaskedHeadImpl : public torch::nn::Module {
public:
  UnMaskedHeadImpl(int d_model, int head_size, float dropout)
      : key(register_module("key", torch::nn::Linear(d_model, head_size))),
        query(register_module("query", torch::nn::Linear(d_model, head_size))),
        value(register_module("value", torch::nn::Linear(d_model, head_size))),
        dropout(register_module("dropout", torch::nn::Dropout(dropout))) {}

  torch::Tensor forward(torch::Tensor x) {
    torch::Tensor key_out = key->forward(x);
    torch::Tensor query_out = query->forward(x);

    // scaled dot-product attention: divide by sqrt(head_size)
    // (the original multiplied by sqrt, which scales the wrong way)
    torch::Tensor weights = query_out.matmul(key_out.transpose(-2, -1)) / std::sqrt((double)key_out.size(-1));
    weights = torch::softmax(weights, -1);
    weights = dropout(weights);

    torch::Tensor value_out = value->forward(x);
    torch::Tensor out = weights.matmul(value_out);
    return out;
  }

private:
  torch::nn::Linear key, query, value;
  torch::nn::Dropout dropout;
};

TORCH_MODULE(UnMaskedHead);

class MaskedHeadImpl : public torch::nn::Module {
public:
  MaskedHeadImpl(int head_size, float dropout, int d_model)
      : key(register_module("key", torch::nn::Linear(d_model, head_size))),
        query(register_module("query", torch::nn::Linear(d_model, head_size))),
        value(register_module("value", torch::nn::Linear(d_model, head_size))),
        dropout(register_module("dropout", torch::nn::Dropout(dropout))) {
    // keep the causal mask as a registered buffer so it follows the module's device
    tril = register_buffer("tril", torch::tril(torch::ones({block_size, block_size})));
  }

  torch::Tensor forward(torch::Tensor x) {
    int64_t T = x.size(1);
    torch::Tensor key_out = key->forward(x);
    torch::Tensor query_out = query->forward(x);

    torch::Tensor weights = query_out.matmul(key_out.transpose(-2, -1)) / std::sqrt((double)key_out.size(-1));
    // causal masking (the original used Python slice syntax here, which is not C++)
    weights = weights.masked_fill(
        tril.index({torch::indexing::Slice(0, T), torch::indexing::Slice(0, T)}) == 0,
        -std::numeric_limits<float>::infinity());
    weights = torch::softmax(weights, -1);
    weights = dropout(weights);

    torch::Tensor value_out = value->forward(x);
    torch::Tensor out = weights.matmul(value_out);
    return out;
  }

private:
  torch::nn::Linear key, query, value;
  torch::nn::Dropout dropout;
  torch::Tensor tril;
};

TORCH_MODULE(MaskedHead);

class MultiUnMaskedImpl : public torch::nn::Module {
public:
  MultiUnMaskedImpl(int d_model, int n_head, float dropout)
      : proj(register_module("proj", torch::nn::Linear(n_head * (d_model / n_head), d_model))),
        dropout(register_module("dropout", torch::nn::Dropout(dropout))) {
    for (int i = 0; i < n_head; ++i) {
      heads.push_back(register_module("head" + std::to_string(i), UnMaskedHead(d_model, d_model / n_head, dropout)));
    }
  }

  torch::Tensor forward(torch::Tensor x) {
    std::vector<torch::Tensor> head_outputs;
    for (auto& head : heads) {
      head_outputs.push_back(head->forward(x));
    }
    torch::Tensor out = torch::cat(head_outputs, -1);
    out = dropout(out);
    out = proj(out);
    return out;
  }

private:
  torch::nn::Linear proj;
  torch::nn::Dropout dropout;
  std::vector<UnMaskedHead> heads;
};

TORCH_MODULE(MultiUnMasked);

class MultiMaskedImpl : public torch::nn::Module {
public:
  MultiMaskedImpl(int d_model, int n_head, float dropout)
      : proj(register_module("proj", torch::nn::Linear(n_head * (d_model / n_head), d_model))),
        dropout(register_module("dropout", torch::nn::Dropout(dropout))) {
    for (int i = 0; i < n_head; ++i) {
      // MaskedHeadImpl's parameter order is (head_size, dropout, d_model);
      // the original passed the arguments in the wrong order
      heads.push_back(register_module("head" + std::to_string(i), MaskedHead(d_model / n_head, dropout, d_model)));
    }
  }

  torch::Tensor forward(torch::Tensor x) {
    std::vector<torch::Tensor> head_outputs;
    for (auto& head : heads) {
      head_outputs.push_back(head->forward(x));
    }
    torch::Tensor out = torch::cat(head_outputs, -1);
    out = dropout(out);
    out = proj(out);
    return out;
  }

private:
  torch::nn::Linear proj;
  torch::nn::Dropout dropout;
  std::vector<MaskedHead> heads;
};

TORCH_MODULE(MultiMasked);

class FeedForwardImpl : public torch::nn::Module {
public:
  FeedForwardImpl(int d_model, float dropout)
      : net(register_module("net", torch::nn::Sequential(
            torch::nn::Linear(d_model, 4 * d_model),
            torch::nn::GELU(),
            torch::nn::Linear(4 * d_model, d_model),
            torch::nn::Dropout(dropout)
        ))) {}

  torch::Tensor forward(torch::Tensor x) {
    return net->forward(x);
  }

private:
  torch::nn::Sequential net;
};

TORCH_MODULE(FeedForward);

class BlockImpl : public torch::nn::Module {
public:
  BlockImpl(int d_model, int n_head, float norm_eps, float dropout)
      : sa_masked(register_module("sa_masked", MultiMasked(d_model, n_head, dropout))),
        sa_unmasked(register_module("sa_unmasked", MultiUnMasked(d_model, n_head, dropout))),
        ffwd(register_module("ffwd", FeedForward(d_model, dropout))),
        norm1(register_module("norm1", torch::nn::LayerNorm(torch::nn::LayerNormOptions({d_model}).eps(norm_eps)))),
        norm2(register_module("norm2", torch::nn::LayerNorm(torch::nn::LayerNormOptions({d_model}).eps(norm_eps)))) {}

  torch::Tensor forward(torch::Tensor x) {
    torch::Tensor x2 = x + sa_unmasked->forward(norm1->forward(x));
    x = x2 + ffwd->forward(norm2->forward(x2));

    x2 = x + sa_masked->forward(norm1->forward(x));
    x = x2 + ffwd->forward(norm2->forward(x2));

    return x;
  }

private:
  MultiMasked sa_masked;
  MultiUnMasked sa_unmasked;
  FeedForward ffwd;
  torch::nn::LayerNorm norm1, norm2;
};

TORCH_MODULE(Block);

class EnigmaImpl : public torch::nn::Module {
public:
  EnigmaImpl(int vocab_size, int block_size, int d_model, int n_layer, int n_head, float dropout, float norm_eps)
      : toked_model(register_module("toked_model", torch::nn::Embedding(vocab_size, d_model))),
        pos_encod(register_module("pos_encod", torch::nn::Embedding(block_size, d_model))),
        norm_final(register_module("norm_final", torch::nn::LayerNorm(torch::nn::LayerNormOptions({d_model}).eps(norm_eps)))),
        linear_final(register_module("linear_final", torch::nn::Linear(d_model, vocab_size))),
        block_size(block_size) {
    for (int i = 0; i < n_layer; ++i) {
      block_layers.push_back(register_module("block" + std::to_string(i), Block(d_model, n_head, norm_eps, dropout)));
    }
    _init_weights(this);
  }

  void _init_weights(torch::nn::Module* module) {
    auto parameters = module->named_parameters();
    for (auto& param : parameters) {
      if (param.key().find("weight") != std::string::npos) {
        torch::nn::init::normal_(param.value(), 0.0, 0.02);
      } else if (param.key().find("bias") != std::string::npos) {
        torch::nn::init::zeros_(param.value());
      }
    }
  }

  std::pair<torch::Tensor, torch::Tensor> forward(torch::Tensor idx, torch::Tensor targets = torch::Tensor()) {
    torch::Tensor toked_model_out = toked_model->forward(idx);
    // positional ids must live on the same device as the input
    torch::Tensor pos_encod_out = pos_encod->forward(
        torch::arange(idx.size(1), torch::TensorOptions().dtype(torch::kLong).device(idx.device())));
    torch::Tensor x = toked_model_out + pos_encod_out;

    for (auto& block : block_layers) {
      x = block->forward(x);
    }

    torch::Tensor logits = linear_final->forward(norm_final->forward(x));

    if (!targets.defined()) {
      return {logits, torch::Tensor()};
    } else {
      logits = logits.view({-1, logits.size(-1)});
      targets = targets.view({-1});
      torch::Tensor loss = torch::nn::functional::cross_entropy(logits, targets);
      return {logits, loss};
    }
  }

  // fixed: this returns a filtered logits tensor (the original declared a beam-list type)
  torch::Tensor top_k_filtering(torch::Tensor logits, int top_k) {
    torch::Tensor top_values, top_indices;
    std::tie(top_values, top_indices) = torch::topk(logits, top_k, -1);

    torch::Tensor min_value = top_values.index({torch::indexing::Slice(), top_k - 1}).unsqueeze(-1);
    torch::Tensor filtered_logits = torch::where(
        logits < min_value, torch::full_like(logits, -std::numeric_limits<float>::infinity()), logits);
    return filtered_logits;
  }

  // beam-style sampling; assumes a batch of one prompt
  std::vector<std::vector<std::pair<torch::Tensor, float>>> complex_generate(torch::Tensor idx, int max_new_tokens, float temperature = 1.0, int top_k = 3, int beam_width = 5) {
    std::vector<std::vector<std::pair<torch::Tensor, float>>> completed_beams;
    std::vector<std::pair<torch::Tensor, float>> beam = {std::make_pair(idx.clone(), 0.0f)};

    for (int i = 0; i < max_new_tokens; ++i) {
      std::vector<std::pair<torch::Tensor, float>> new_beam;

      for (auto& beam_item : beam) {
        torch::Tensor logits, loss;
        std::tie(logits, loss) = forward(beam_item.first);
        logits = logits.index({torch::indexing::Slice(), -1});  // last-token predictions

        // temperature scaling and top-k filtering happen on logits, before softmax
        // (the original filtered the probabilities after softmax, leaving -inf entries)
        logits = logits / temperature;
        if (top_k > 0) {
          logits = top_k_filtering(logits, top_k);
        }
        torch::Tensor probs = torch::nn::functional::softmax(logits, -1);

        // sample beam_width candidate continuations
        torch::Tensor sampled_idx = torch::multinomial(probs, beam_width, true);

        for (int j = 0; j < beam_width; ++j) {
          torch::Tensor next_tok = sampled_idx.index({torch::indexing::Slice(), j}).unsqueeze(-1);
          torch::Tensor new_idx = torch::cat({beam_item.first, next_tok}, 1);
          float new_log_prob = beam_item.second +
              std::log(probs.index({0, next_tok.index({0, 0}).item<int64_t>()}).item<float>());
          new_beam.push_back(std::make_pair(new_idx, new_log_prob));
        }
      }

      // Sort new beam by log probabilities and only keep the top beams
      std::sort(new_beam.begin(), new_beam.end(), [](const std::pair<torch::Tensor, float>& a, const std::pair<torch::Tensor, float>& b) {
        return a.second > b.second;
      });
      beam.assign(new_beam.begin(), new_beam.begin() + beam_width);
    }

    completed_beams.push_back(beam);
    return completed_beams;
  }

private:
  torch::nn::Embedding toked_model, pos_encod;
  std::vector<Block> block_layers;
  torch::nn::LayerNorm norm_final;
  torch::nn::Linear linear_final;
  int block_size;
};

TORCH_MODULE(Enigma);

// estimate_loss() was called in main() but never defined in the original file;
// a minimal stand-in that averages the loss over a few held-out batches:
std::map<std::string, float> estimate_loss(Enigma& model) {
  torch::NoGradGuard no_grad;
  std::map<std::string, float> out;
  for (const auto& split : {std::string("train"), std::string("val")}) {
    float total = 0.0f;
    for (int k = 0; k < eval_iters; ++k) {
      auto [X, Y] = get_batch(split);
      auto [logits, loss] = model->forward(X, Y);
      total += loss.item<float>();
    }
    out[split] = total / eval_iters;
  }
  return out;
}

int main() {
  // Set seed
  torch::manual_seed(1400);

  // Create model
  Enigma model(vocab_size, block_size, d_model, n_layer, n_head, dropout, norm_eps);
  model->to(device);

  // Define optimizer
  torch::optim::AdamW optimizer(model->parameters(), torch::optim::AdamWOptions(learning_rate));

  // Training loop
  std::vector<float> train_losses, val_losses;
  for (int iter = 0; iter < max_iters; ++iter) {
    if (iter % eval_interval == 0 || iter == max_iters - 1) {
      // Evaluate and print losses
      auto losses = estimate_loss(model);
      std::cout << "step " << iter << ": train loss " << losses["train"] << ", val loss " << losses["val"] << std::endl;

      // Save losses for plotting
      train_losses.push_back(losses["train"]);
      val_losses.push_back(losses["val"]);
    }

    // Get batch, forward pass, loss calculation, backward pass, optimizer step
    auto [xb, yb] = get_batch("train");
    torch::Tensor logits, loss;
    std::tie(logits, loss) = model->forward(xb, yb);

    optimizer.zero_grad();
    loss.backward();
    optimizer.step();
  }

  return 0;
}
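The top_k_filtering above keeps the k largest logits per row and floors the rest at negative infinity, so the subsequent softmax assigns them zero probability. A compact PyTorch sketch of the same idea (it mirrors generate.py's _top_k_filtering):

import torch

def top_k_filtering(logits: torch.Tensor, top_k: int) -> torch.Tensor:
  # keep the k largest logits per row; everything else becomes -inf,
  # so softmax gives those tokens zero probability mass
  values, _ = torch.topk(logits, top_k, dim=-1)
  min_value = values[:, -1].unsqueeze(-1)
  return torch.where(logits < min_value, torch.full_like(logits, float('-inf')), logits)

probs = torch.softmax(top_k_filtering(torch.randn(1, 10), top_k=3), dim=-1)
print((probs > 0).sum().item())  # 3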
enigma/generate.py
ADDED
@@ -0,0 +1,126 @@
import os
current_directory = os.path.dirname(os.path.abspath(__file__))
os.chdir(current_directory)

with open('../parquet files/new_dna.txt', 'r', encoding='utf-8') as file:
  captions = file.read()

print(f"{(len(captions)/1e6):.2f} million letters")

from tokenizer import PerCharTokenizer

tokenizer = PerCharTokenizer()
vocab_size = tokenizer.vocab_size

import torch
from torch.nn import functional as F
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from model import Transformer

class Generator(Transformer):
  def __init__(self, vocab_size):
    # Transformer.__init__ requires vocab_size (the original called it with no
    # arguments); it also sets self.block_size
    super().__init__(vocab_size)
    self.vocab_size = vocab_size

  def generate(self, idx, max_new_tokens, temperature=1.0, top_k=0):
    """
    generate new tokens using the trained model

    Args:
      - idx (Tensor): input tensor representing initial token indices
      - max_new_tokens (int): max no of new tokens to generate
      - temperature (float): softmax temperature for sampling
      - top_k (int): no of top tokens to consider in sampling

    Returns:
      - generated_tokens (list): list of generated token indices
    """
    generated_tokens = []

    for _ in range(max_new_tokens):
      idx_cond = idx[:, -self.block_size:]
      logits, _ = self(idx_cond)
      logits = logits[:, -1, :]

      scaled_logits = logits / temperature
      if top_k > 0:
        scaled_logits = self._top_k_filtering(scaled_logits, top_k)

      probs = F.softmax(scaled_logits, dim=-1)
      sampled_idx = torch.multinomial(probs, num_samples=1)
      generated_tokens.append(sampled_idx.item())
      idx = torch.cat((idx, sampled_idx), dim=1)

    return generated_tokens

  def generate_masked_tokens(self, idx, masked_indices, temperature=1.0, top_k=0):
    """
    Generate predictions for masked tokens using the trained model.

    Args:
      - idx (Tensor): input tensor representing token indices
      - masked_indices (Tensor): tensor of indices indicating masked positions
      - temperature (float): softmax temperature for sampling
      - top_k (int): no of top tokens to consider in sampling

    Returns:
      - predicted_tokens (Tensor): tensor of predicted token indices
    """
    B, T = idx.shape

    toked_model = self.toked_model(idx)
    pos_encod = self.pos_encod(torch.arange(T, device=device))
    x = toked_model + pos_encod

    for layer in self.enc_layer:
      x_out = layer(x)

    for layer in self.dec_layer:
      x_final = layer(x, x_out)

    x_masked = x_final.clone()
    x_masked[masked_indices] = self.toked_model(torch.tensor([6], device=device))

    x_masked = self.norm_final(x_masked)
    logits = self.linear_final(x_masked)

    masked_logits = logits[masked_indices].view(-1, logits.size(-1))
    scaled_logits = masked_logits / temperature
    if top_k > 0:
      scaled_logits = self._top_k_filtering(scaled_logits, top_k)

    probs = F.softmax(scaled_logits, dim=-1)
    predicted_indices = torch.argmax(probs, dim=-1)

    return predicted_indices

  def _top_k_filtering(self, logits, top_k):
    """
    filter logits to keep only the top-k tokens

    Args:
      - logits (Tensor): input tensor representing unscaled logits
      - top_k (int): no of top tokens to keep

    Returns:
      - filtered_logits (Tensor): filtered logits with only top-k tokens remaining
    """
    values, indices = torch.topk(logits, top_k, dim=-1)
    min_value = values[:, -1].unsqueeze(-1).expand_as(logits)
    filtered_logits = torch.where(logits < min_value, torch.ones_like(logits) * -float('inf'), logits)

    return filtered_logits

# generate() is an instance method, so build a Generator and load the checkpoint
# into it (the original loaded the weights into a bare Transformer and then
# called Generator.generate on the class itself)
model = Generator(vocab_size=vocab_size)
checkpoint_path = '../trained models/enigma_47m.pth'
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint)
m = model.to(device)

target_text = "AGTTCTGCGAT"
context = torch.tensor([tokenizer.encode(target_text)], dtype=torch.long, device=device)
generated_output = tokenizer.decode(m.generate(context, max_new_tokens=10, temperature=0.5, top_k=5))
print(f"{target_text}{generated_output}")
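generate() divides the logits by temperature before the softmax; a tiny standalone illustration of what that does to the sampling distribution (illustrative values only):

import torch
from torch.nn import functional as F

logits = torch.tensor([[2.0, 1.0, 0.5]])
for t in (0.5, 1.0, 2.0):
  # lower temperature sharpens the distribution toward the argmax,
  # higher temperature flattens it toward uniform
  print(t, F.softmax(logits / t, dim=-1))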
enigma/model.py
ADDED
@@ -0,0 +1,388 @@
1 |
+
"""
|
2 |
+
transformer based model, but with few minimal tweaks
|
3 |
+
trained a 2.5billion parameters model with current set configurations
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
current_directory = os.path.dirname(os.path.abspath(__file__))
|
10 |
+
os.chdir(current_directory)
|
11 |
+
|
12 |
+
import torch.nn as nn
|
13 |
+
from torch.nn import functional as F
|
14 |
+
|
15 |
+
with open('config_enigma.json', 'r', encoding='utf-8') as file:
|
16 |
+
params = json.load(file)
|
17 |
+
|
18 |
+
batch_size = params['batch_size']
|
19 |
+
block_size = params['block_size']
|
20 |
+
n_head = params['n_head']
|
21 |
+
d_model = params['d_model']
|
22 |
+
n_layers = params['n_layer']
|
23 |
+
dropout = params['dropout']
|
24 |
+
norm_eps = params['norm_eps']
|
25 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
26 |
+
|
27 |
+
class AttentionHead(nn.Module):
|
28 |
+
"""
|
29 |
+
initialize a single head of self attention.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
- d_model (int): dimensionality of the model's hidden layers
|
33 |
+
- head_size (int): dimensionality of each attention head
|
34 |
+
- dropout (float): dropout probability
|
35 |
+
- block_size (int): the maximum sequence length for positional encoding
|
36 |
+
"""
|
37 |
+
def __init__(self, d_model, head_size, dropout, block_size):
|
38 |
+
super().__init__()
|
39 |
+
self.key = nn.Linear(d_model, head_size, bias=True)
|
40 |
+
self.query = nn.Linear(d_model, head_size, bias=True)
|
41 |
+
self.value = nn.Linear(d_model, head_size, bias=False)
|
42 |
+
self.dropout = nn.Dropout(dropout)
|
43 |
+
self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
|
44 |
+
|
45 |
+
self.rel_pos_emb = nn.Parameter(torch.randn(block_size, block_size, head_size))
|
46 |
+
|
47 |
+
def forward(self, x, mask=False):
|
48 |
+
"""
|
49 |
+
forward pass of a single attention head.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
- x (Tensor): input tensor.
|
53 |
+
- mask (bool): flag indicating whether to apply masking
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
- out (Tensor): output tensor after self attention
|
57 |
+
"""
|
58 |
+
B, T, C = x.shape
|
59 |
+
key = self.key(x)
|
60 |
+
query = self.query(x)
|
61 |
+
|
62 |
+
scores = torch.matmul(query, key.transpose(-2, -1)) / (key.shape[-1] ** -0.5)
|
63 |
+
rel_pos_scores = torch.einsum('btc,tvc->btv', query, self.rel_pos_emb[:T, :T])
|
64 |
+
scores += rel_pos_scores
|
65 |
+
|
66 |
+
if mask:
|
67 |
+
scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
|
68 |
+
|
69 |
+
weights = F.softmax(scores, dim=-1)
|
70 |
+
weights = self.dropout(weights)
|
71 |
+
|
72 |
+
value = self.value(x)
|
73 |
+
out = torch.matmul(weights, value)
|
74 |
+
return out
|
75 |
+
|
76 |
+
class MultiHeadAttention(nn.Module):
|
77 |
+
"""
|
78 |
+
initialize a multi-head attention module.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
- d_model (int): dimensionality of the model's hidden layers
|
82 |
+
- n_head (int): no of attention heads
|
83 |
+
- dropout (float): dropout probability
|
84 |
+
- block_size (int): context length
|
85 |
+
"""
|
86 |
+
def __init__(self, d_model, n_head, dropout, block_size):
|
87 |
+
head_size = d_model // n_head
|
88 |
+
super().__init__()
|
89 |
+
self.heads = nn.ModuleList([AttentionHead(d_model=d_model, dropout=dropout, head_size=head_size, block_size=block_size) for _ in range(n_head)])
|
90 |
+
self.proj = nn.Linear(n_head * head_size, d_model)
|
91 |
+
self.dropout = nn.Dropout(dropout)
|
92 |
+
|
93 |
+
def forward(self, x, mask):
|
94 |
+
"""
|
95 |
+
forward pass of the multi-head attention module
|
96 |
+
|
97 |
+
Args:
|
98 |
+
- x (Tensor): input tensor
|
99 |
+
- mask (bool): flag indicating whether to apply masking
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
- out (Tensor): output tensor after multi-head attention
|
103 |
+
|
104 |
+
"""
|
105 |
+
out = torch.cat([h(x, mask=mask) for h in self.heads], dim=-1)
|
106 |
+
out = self.dropout(self.proj(out))
|
107 |
+
return out
|
108 |
+
|
109 |
+
class FeedForward(nn.Module):
|
110 |
+
"""
|
111 |
+
initialize a feedforward network module
|
112 |
+
|
113 |
+
Args:
|
114 |
+
- d_model (int): the dimensionality of the model's hidden layers
|
115 |
+
- dropout (float): dropout probability
|
116 |
+
|
117 |
+
"""
|
118 |
+
def __init__(self, d_model, dropout):
|
119 |
+
super().__init__()
|
120 |
+
self.net = nn.Sequential(
|
121 |
+
nn.Linear(d_model, 10*d_model),
|
122 |
+
nn.GELU(),
|
123 |
+
nn.Linear(10*d_model, d_model),
|
124 |
+
nn.Dropout(dropout)
|
125 |
+
)
|
126 |
+
|
127 |
+
def forward(self, x):
|
128 |
+
"""
|
129 |
+
forward pass of the feedforward network module
|
130 |
+
|
131 |
+
Args:
|
132 |
+
- x (Tensor): input tensor
|
133 |
+
|
134 |
+
Returns:
|
135 |
+
- out (Tensor): output tensor after passing through the feedforward network
|
136 |
+
"""
|
137 |
+
return self.net(x)

class EncoderNetwork(nn.Module):
  """
  initialize an encoder network module

  Args:
  - d_model (int): dimensionality of the model's hidden layers
  - n_head (int): number of attention heads in multi-head attention layers
  - norm_eps (float): epsilon value for layer normalization
  - dropout (float): dropout probability
  - block_size (int): the maximum sequence length for positional encoding
  """
  def __init__(self, d_model, n_head, norm_eps, dropout, block_size):
    super().__init__()
    self.s_att = MultiHeadAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)
    self.ffwd = FeedForward(d_model, dropout)
    self.dropout = nn.Dropout(dropout)
    self.norm1 = nn.LayerNorm(d_model, eps=norm_eps)
    self.norm2 = nn.LayerNorm(d_model, eps=norm_eps)

  def forward(self, src):
    """
    forward pass of the encoder network module.

    Args:
    - src (Tensor): input tensor representing source data

    Returns:
    - src (Tensor): output tensor after passing through the encoder network
    """
    src2 = self.s_att(src, mask=False)   # unmasked self-attention
    src = src + self.dropout(src2)
    src = self.norm1(src)

    src2 = self.ffwd(src)                # position-wise feed-forward
    src = src + self.dropout(src2)
    src = self.norm2(src)

    return src
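
# note (descriptive, not part of the original file): both sub-blocks above are
# post-norm -- the residual is added first and LayerNorm is applied after,
# matching the original "Attention Is All You Need" ordering rather than the
# pre-norm style used in GPT-like blocks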

class DecoderNetwork(nn.Module):
  """
  initialize a decoder network module

  Args:
  - d_model (int): dimensionality of the model's hidden layers
  - n_head (int): number of attention heads in multi-head attention layers
  - norm_eps (float): epsilon value for layer normalization
  - dropout (float): dropout probability
  - block_size (int): the maximum sequence length for positional encoding
  """
  def __init__(self, d_model, n_head, norm_eps, dropout, block_size):
    super().__init__()
    self.s_att = MultiHeadAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)
    self.ffwd = FeedForward(d_model, dropout)
    self.dropout = nn.Dropout(dropout)
    self.norm1 = nn.LayerNorm(d_model, eps=norm_eps)
    self.norm2 = nn.LayerNorm(d_model, eps=norm_eps)

  def forward(self, src, att):
    """
    forward pass of the decoder network module.

    Args:
    - src (Tensor): input tensor, same as the encoder's inputs
    - att (Tensor): encoder's attention output

    Returns:
    - src_f (Tensor): final output tensor
    """
    src2 = self.s_att(src, mask=True)    # masked self-attention over the inputs
    src = src + self.dropout(src2)
    src = src + self.norm1(src)

    att = src + att                      # mix in the encoder's output
    att2 = self.s_att(att, mask=False)   # unmasked attention over the mixed stream
    att2 = att + self.dropout(att2)
    trg = att2 + self.norm1(att2)

    src_f2 = self.ffwd(self.norm2(trg))
    src_f = trg + self.dropout(src_f2)   # residual connection around the feed-forward
    src_f = self.norm2(src_f)

    return src_f
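
# note (descriptive, not part of the original file): instead of a separate
# cross-attention module, this decoder reuses the same multi-head layer twice --
# first masked over `src`, then unmasked over `src + att`, where `att` is the
# encoder output handed in by the Transformer below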

class Transformer(nn.Module):
  """
  initialize a Transformer model

  Args:
  - vocab_size (int): size of the vocabulary
  - d_model (int): dimensionality of the model's hidden layers
  - block_size (int): maximum sequence length for positional encoding/context length
  - n_layers (int): number of encoder and decoder layers in the Transformer
  - n_head (int): number of attention heads in multi-head attention layers
  - norm_eps (float): epsilon value for layer normalization
  - dropout (float): dropout probability
  """
  def __init__(self, vocab_size):
    super().__init__()
    self.block_size = block_size
    self.toked_model = nn.Embedding(vocab_size, d_model)
    self.pos_encod = nn.Embedding(block_size, d_model)
    self.enc_layer = nn.ModuleList([EncoderNetwork(n_head=n_head, norm_eps=norm_eps, block_size=block_size, dropout=dropout, d_model=d_model) for _ in range(n_layers)])
    self.dec_layer = nn.ModuleList([DecoderNetwork(n_head=n_head, norm_eps=norm_eps, block_size=block_size, dropout=dropout, d_model=d_model) for _ in range(n_layers)])

    self.norm_final = nn.LayerNorm(d_model)
    self.linear_final = nn.Linear(d_model, vocab_size)
    self.dropout = nn.Dropout(dropout)
    self.apply(self._init_weights)

  def _init_weights(self, module):
    """
    initialize weights of linear and embedding layers

    Args:
    - module (nn.Module): the module to initialize weights for
    """
    if isinstance(module, nn.Linear):
      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
      if module.bias is not None:
        torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

  def forward(self, idx, targets=None):
    """
    forward pass of the transformer model

    Args:
    - idx (Tensor): input tensor representing token indices
    - targets (Tensor): target tensor for computing loss during training

    Returns:
    - logits (Tensor): output logits from the final linear layer
    - loss (Tensor): optional. computed cross-entropy loss if targets are provided, else None
    """
    B, T = idx.shape

    toked_model = self.toked_model(idx)
    pos_encod = self.pos_encod(torch.arange(T, device=device))
    x = toked_model + pos_encod

    # stack the encoder layers, feeding each layer's output into the next
    x_out = x
    for layer in self.enc_layer:
      x_out = layer(x_out)

    # stack the decoder layers the same way, conditioning on the encoder output
    x_final = x
    for layer in self.dec_layer:
      x_final = layer(x_final, x_out)

    x_final = self.norm_final(x_final)
    logits = self.linear_final(x_final)

    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)

    return logits, loss
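
  # shape sketch for the loss computation above (the numbers are illustrative):
  # with B=4, T=8 and vocab size C=10, logits (4, 8, 10) flatten to (32, 10) and
  # targets (4, 8) to (32,), the pair F.cross_entropy expects for per-token
  # next-token prediction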

  def generate(self, idx, max_new_tokens, temperature=1.0, top_k=0):
    """
    generate new tokens using the trained model

    Args:
    - idx (Tensor): input tensor representing initial token indices
    - max_new_tokens (int): max number of new tokens to generate
    - temperature (float): softmax temperature for sampling
    - top_k (int): number of top tokens to consider in sampling

    Returns:
    - generated_tokens (list): list of generated token indices
    """
    generated_tokens = []

    for _ in range(max_new_tokens):
      idx_cond = idx[:, -self.block_size:]   # crop the context to the block size
      logits, _ = self(idx_cond)
      logits = logits[:, -1, :]              # keep only the last time step

      scaled_logits = logits / temperature
      if top_k > 0:
        scaled_logits = self._top_k_filtering(scaled_logits, top_k)

      probs = F.softmax(scaled_logits, dim=-1)
      sampled_idx = torch.multinomial(probs, num_samples=1)
      generated_tokens.append(sampled_idx.item())
      idx = torch.cat((idx, sampled_idx), dim=1)

    return generated_tokens
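
  # usage sketch (assumes a trained model on `device`; the values are made up):
  #   ctx = torch.zeros((1, 1), dtype=torch.long, device=device)   # seed context
  #   out = model.generate(ctx, max_new_tokens=50, temperature=0.8, top_k=5)
  #   # `out` is a plain python list of sampled token ids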

  def generate_masked_tokens(self, idx, masked_indices, temperature=1.0, top_k=0):
    """
    generate predictions for masked tokens using the trained model.

    Args:
    - idx (Tensor): input tensor representing token indices
    - masked_indices (Tensor): tensor of indices indicating masked positions
    - temperature (float): softmax temperature for sampling
    - top_k (int): number of top tokens to consider in sampling

    Returns:
    - predicted_tokens (Tensor): tensor of predicted token indices
    """
    B, T = idx.shape

    toked_model = self.toked_model(idx)
    pos_encod = self.pos_encod(torch.arange(T, device=device))
    x = toked_model + pos_encod

    x_out = x
    for layer in self.enc_layer:
      x_out = layer(x_out)

    x_final = x
    for layer in self.dec_layer:
      x_final = layer(x_final, x_out)

    # embed a hard-coded mask-token id (6) at the masked positions
    x_masked = x_final.clone()
    x_masked[masked_indices] = self.toked_model(torch.tensor([6], device=device))

    x_masked = self.norm_final(x_masked)
    logits = self.linear_final(x_masked)

    masked_logits = logits[masked_indices].view(-1, logits.size(-1))
    scaled_logits = masked_logits / temperature
    if top_k > 0:
      scaled_logits = self._top_k_filtering(scaled_logits, top_k)

    probs = F.softmax(scaled_logits, dim=-1)
    predicted_indices = torch.argmax(probs, dim=-1)

    return predicted_indices

  def _top_k_filtering(self, logits, top_k):
    """
    filter logits to keep only the top-k tokens

    Args:
    - logits (Tensor): input tensor representing unscaled logits
    - top_k (int): number of top tokens to keep

    Returns:
    - filtered_logits (Tensor): filtered logits with only top-k tokens remaining
    """
    values, indices = torch.topk(logits, top_k, dim=-1)
    min_value = values[:, -1].unsqueeze(-1).expand_as(logits)   # smallest kept logit per row
    filtered_logits = torch.where(logits < min_value, torch.ones_like(logits) * -float('inf'), logits)

    return filtered_logits
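
# worked example for _top_k_filtering (illustrative, not part of the original
# file): with logits [[1.0, 3.0, 2.0, 0.5]] and top_k=2, torch.topk keeps
# [3.0, 2.0]; min_value broadcasts 2.0 across the row, and torch.where maps
# everything below it to -inf, giving [[-inf, 3.0, 2.0, -inf]] -- softmax then
# assigns zero probability outside the top-k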
enigma/run.py
ADDED
@@ -0,0 +1,100 @@
"""
use this file to train the model

working:
- imports various dependencies first, and then loads the training data
- tokenizes it on a per-character basis
- loads the required hyper-parameters and the model file
- trains it till 'max_iters' and saves the model state, and generates outputs

with the current configuration, the model can reach up to ~60 million parameters
and approach ~99% accuracy on next-token prediction
"""

import torch
import json
import os
current_directory = os.path.dirname(os.path.abspath(__file__))
os.chdir(current_directory)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

with open('../parquet files/new_dna.txt', 'r', encoding='utf-8') as file:
  captions = file.read()

print(f"{(len(captions)/1e6):.2f} million letters")

from ..tokenizer import PerCharTokenizer

tokenizer = PerCharTokenizer()
vocab_size = tokenizer.vocab_size
# Train and test splits
data = torch.tensor(tokenizer.encode(captions), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

with open('config_enigma.json', 'r', encoding='utf-8') as file:
  params = json.load(file)

# required parameters
batch_size = params['batch_size']
block_size = params['block_size']
max_iters = params['max_iters']
eval_interval = params['eval_interval']
eval_iters = params['eval_iters']
learning_rate = params['learning_rate']

torch.manual_seed(1400)
# data loading
def get_batch(split):
  # generate a small batch of data of inputs x and targets y
  data = train_data if split == 'train' else val_data
  ix = torch.randint(len(data) - block_size, (batch_size,))
  x = torch.stack([data[i:i+block_size] for i in ix])
  y = torch.stack([data[i+1:i+block_size+1] for i in ix])
  x, y = x.to(device), y.to(device)
  return x, y
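
# illustrative batch (hypothetical data, not part of the original file): if
# data were [5, 1, 4, 2, 9, ...] and block_size were 4, index i=0 would give
#   x = [5, 1, 4, 2]   and   y = [1, 4, 2, 9]
# so y[t] is always the next token the model must predict from x[:t+1]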

@torch.no_grad()
def estimate_loss():
  out = {}
  model.eval()
  for split in ['train', 'val']:
    losses = torch.zeros(eval_iters)
    for k in range(eval_iters):
      X, Y = get_batch(split)
      logits, loss = model(X, Y)
      losses[k] = loss.item()
    out[split] = losses.mean()
  model.train()
  return out

from model import Transformer
model = Transformer(vocab_size=vocab_size)
m = model.to(device)

# number of parameters
n_param = sum(p.numel() for p in m.parameters())/1e6
print(f"{n_param:.2f} million")
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
steps = []
train_losses = []
val_losses = []

for iter in range(max_iters):

  if iter % eval_interval == 0 or iter == max_iters - 1:
    losses = estimate_loss()
    print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    steps.append(iter)
    train_losses.append(losses['train'])
    val_losses.append(losses['val'])

  xb, yb = get_batch('train')
  logits, loss = model(xb, yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

torch.save(model.state_dict(), f'enigma_{n_param:.0f}m.pth')
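
# minimal reload sketch for the checkpoint above (assumes the same config and
# vocab used during training):
#   model = Transformer(vocab_size=vocab_size)
#   model.load_state_dict(torch.load(f'enigma_{n_param:.0f}m.pth'))
#   model.eval()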