shivendrra commited on
Commit
f44acc0
1 Parent(s): 122f99f

added training and model files

Browse files
enigma/EnBERT.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ simple BERT architecture model, paired with one more layer of
3
+ masked self-attention, to predict next token
4
+ """
5
+
6
+ import torch
7
+ import os
8
+ current_directory = os.path.dirname(os.path.abspath(__file__))
9
+ os.chdir(current_directory)
10
+
11
+ import torch.nn as nn
12
+ from torch.nn import functional as F
13
+
14
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
+
16
+ # hyperparams: module-level defaults read by the classes defined below
17
+ batch_size = 8  # not read by any class in this file
18
+ block_size = 32  # side of MaskedHead's causal mask; rows in EnigmaBERT's position-embedding table
19
+ max_iters = 10  # not read by any class in this file
20
+ eval_interval = 10  # not read by any class in this file
21
+ learning_rate = 3e-4  # not read by any class in this file
22
+ eval_iters = 5  # not read by any class in this file
23
+ d_model = 256  # embedding / hidden width used by EnigmaBERT
24
+ n_layer = 16  # number of Block modules stacked by EnigmaBERT
25
+ n_head = 12  # heads per attention module; head_size = d_model // n_head = 21, so heads cover 252 of 256 dims
26
+ dropout = 0.2  # dropout probability handed to every Block
27
+ norm_eps = 1e-5  # eps passed to the LayerNorm layers
28
+
29
class SWiGLU(nn.Module):
    """ element-wise gate: sigmoid(x) * relu(x) + (1 - sigmoid(x)) * x """

    def forward(self, x):
        # The sigmoid acts as a mixing weight between the rectified input
        # and the raw input; the output has the same shape as x.
        gate = torch.sigmoid(x)
        rectified = F.relu(x)
        return gate * rectified + (1 - gate) * x
38
+
39
class UnMaskedHead(nn.Module):
    """ one self-attention head with no causal mask (every position attends to all positions) """

    def __init__(self, d_model, head_size, dropout):
        super().__init__()
        # key / query / value are created in this order so seeded
        # initialisation stays reproducible.
        self.key = nn.Linear(d_model, head_size, bias=True)
        self.query = nn.Linear(d_model, head_size, bias=True)
        self.value = nn.Linear(d_model, head_size, bias=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # The unpack requires a 3-D input; the names themselves are unused.
        _batch, _steps, _channels = x.shape
        k = self.key(x)
        q = self.query(x)

        # Scaled dot-product scores, scaled by head_size ** -0.5.
        scale = k.shape[-1] ** -0.5
        scores = q @ k.transpose(-2, -1) * scale
        attn = self.dropout(F.softmax(scores, dim=-1))

        v = self.value(x)
        return attn @ v
60
+
61
class MaskedHead(nn.Module):
    """ one causal self-attention head (each position attends only to itself and earlier positions) """

    def __init__(self, head_size, dropout, d_model):
        super().__init__()
        self.key = nn.Linear(d_model, head_size, bias=True)
        self.query = nn.Linear(d_model, head_size, bias=True)
        self.value = nn.Linear(d_model, head_size, bias=True)

        # Lower-triangular ones matrix sized by the module-level block_size;
        # stored as a buffer so it moves with the module but is not trained.
        lower_triangle = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', lower_triangle)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # assumes T <= block_size so the mask slice matches the score matrix -- TODO confirm with callers
        _batch, steps, _channels = x.shape
        k = self.key(x)
        q = self.query(x)

        scale = k.shape[-1] ** -0.5
        scores = q @ k.transpose(-2, -1) * scale  # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        future = self.tril[:steps, :steps] == 0
        scores = scores.masked_fill(future, float('-inf'))  # hide later positions
        attn = self.dropout(F.softmax(scores, dim=-1))  # (B, T, T)

        v = self.value(x)
        return attn @ v
85
+
86
class MultiUnMasked(nn.Module):
    """ n_head UnMaskedHead modules run side by side, concatenated and projected to d_model """

    def __init__(self, d_model, n_head, dropout):
        super().__init__()
        # Integer division: when d_model is not a multiple of n_head the
        # concatenated width is n_head * head_size, which proj maps to d_model.
        head_size = d_model // n_head
        heads = []
        for _ in range(n_head):
            heads.append(UnMaskedHead(d_model=d_model, dropout=dropout, head_size=head_size))
        self.heads = nn.ModuleList(heads)
        self.proj = nn.Linear(n_head * head_size, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)
98
+
99
class MultiMasked(nn.Module):
    """ n_head MaskedHead modules run side by side, concatenated and projected to d_model """

    def __init__(self, d_model, n_head, dropout):
        super().__init__()
        # Integer division: when d_model is not a multiple of n_head the
        # concatenated width is n_head * head_size, which proj maps to d_model.
        head_size = d_model // n_head
        heads = []
        for _ in range(n_head):
            heads.append(MaskedHead(d_model=d_model, dropout=dropout, head_size=head_size))
        self.heads = nn.ModuleList(heads)
        self.proj = nn.Linear(n_head * head_size, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)
111
+
112
class FeedForward(nn.Module):
    """ position-wise MLP: Linear(d_model -> 4*d_model), GELU, Linear(4*d_model -> d_model), Dropout """

    def __init__(self, d_model, dropout):
        super().__init__()
        hidden = 4 * d_model
        # Stage order fixes the Sequential indices (net.0 .. net.3) that
        # appear in state_dict keys.
        stages = [
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        transformed = self.net(x)
        return transformed
124
+
125
class Block(nn.Module):
    """ transformer block: an unmasked-attention pass followed by a masked-attention pass

    Both passes share the same feed-forward network and the same two
    LayerNorm modules (norm1 before attention, norm2 on the feed-forward output).
    """

    def __init__(self, d_model, n_head, norm_eps, dropout):
        super().__init__()
        self.sa_masked = MultiMasked(n_head=n_head, d_model=d_model, dropout=dropout)
        self.sa_unmasked = MultiUnMasked(n_head=n_head, d_model=d_model, dropout=dropout)
        self.ffwd = FeedForward(d_model, dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model, eps=norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=norm_eps)

    def forward(self, x):
        # Unmasked attention runs first, then masked attention; each pass is
        # residual attention on norm1(x) followed by a residual add of the
        # normalised feed-forward output.
        for attention in (self.sa_unmasked, self.sa_masked):
            attended = x + attention(self.norm1(x))
            x = attended + self.norm2(self.ffwd(attended))
        return x
141
+
142
class EnigmaBERT(nn.Module):
    """ token + learned position embeddings, a stack of n_layer Blocks, final LayerNorm and a vocab projection

    Args:
      - vocab_size (int): number of token ids; sizes the embedding table and the output layer

    The remaining sizes (d_model, block_size, n_layer, n_head, dropout, norm_eps)
    are read from the module-level hyperparameters when the model is constructed.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # Keep the context length on the instance so generate() uses the value
        # the position table was built with, not whatever the global is later.
        self.block_size = block_size
        self.toked_model = nn.Embedding(vocab_size, d_model)
        self.pos_encod = nn.Embedding(block_size, d_model)
        self.block = nn.Sequential(*[Block(d_model=d_model, dropout=dropout, norm_eps=norm_eps, n_head=n_head) for _ in range(n_layer)])
        self.norm_final = nn.LayerNorm(d_model, eps=norm_eps)
        self.linear_final = nn.Linear(d_model, vocab_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """ normal(0, 0.02) init for Linear and Embedding weights, zeros for Linear biases """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias.data)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """ run the model on a (B, T) tensor of token ids

        Args:
          - idx (Tensor): (B, T) token ids; T must not exceed the position table size
          - targets (Tensor, optional): (B, T) target ids for the cross-entropy loss

        Returns:
          - logits (Tensor): (B, T, vocab) when targets is None, else flattened to (B*T, vocab)
          - loss (Tensor or None): cross-entropy loss when targets is given, else None
        """
        B, T = idx.shape

        toked_model = self.toked_model(idx)
        # Build the position ids on the input's own device; relying on a
        # module-level device string breaks when the model lives elsewhere.
        pos_encod = self.pos_encod(torch.arange(T, device=idx.device))
        x = toked_model + pos_encod
        x = self.block(x)
        x = self.norm_final(x)
        logits = self.linear_final(x)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        # reshape (not view) so non-contiguous target slices are accepted too
        loss = F.cross_entropy(logits, targets.reshape(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=0):
        """ sample max_new_tokens token ids autoregressively

        Args:
          - idx (Tensor): (1, T) prompt token ids; .item() below requires batch size 1
          - max_new_tokens (int): number of tokens to sample
          - temperature (float): divisor applied to the logits before softmax
          - top_k (int): when > 0, sample only from the top_k highest logits

        Returns:
          - generated_tokens (list[int]): the sampled ids, prompt not included

        Runs under torch.no_grad() because only Python ints are returned.
        The module's train/eval mode is left untouched.
        """
        generated_tokens = []

        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens.
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            scaled_logits = logits[:, -1, :] / temperature
            if top_k > 0:
                scaled_logits = self._top_k_filtering(scaled_logits, top_k)

            probs = F.softmax(scaled_logits, dim=-1)
            sampled_idx = torch.multinomial(probs, num_samples=1)
            generated_tokens.append(sampled_idx.item())
            idx = torch.cat((idx, sampled_idx), dim=1)

        return generated_tokens

    def _top_k_filtering(self, logits, top_k):
        """ set every logit below the k-th largest of its row to -inf

        Args:
          - logits (Tensor): (rows, vocab) logits
          - top_k (int): number of entries to keep per row; clamped to the vocab size

        Returns:
          - filtered_logits (Tensor): same shape as logits
        """
        # torch.topk raises when k exceeds the dimension size, so clamp first.
        k = min(top_k, logits.size(-1))
        values, _ = torch.topk(logits, k, dim=-1)
        min_value = values[:, -1].unsqueeze(-1)
        filtered_logits = logits.masked_fill(logits < min_value, float('-inf'))
        return filtered_logits
enigma/TrainEnigma.ipynb ADDED
@@ -0,0 +1,919 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {
7
+ "colab": {
8
+ "base_uri": "https://localhost:8080/"
9
+ },
10
+ "id": "WXpJBLyr30Rx",
11
+ "outputId": "2806070a-648f-42ca-fa8a-9aeb8f99ceb7"
12
+ },
13
+ "outputs": [
14
+ {
15
+ "output_type": "stream",
16
+ "name": "stdout",
17
+ "text": [
18
+ "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
19
+ ]
20
+ }
21
+ ],
22
+ "source": [
23
+ "from google.colab import drive\n",
24
+ "drive.mount('/content/drive')"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": 2,
30
+ "metadata": {
31
+ "colab": {
32
+ "base_uri": "https://localhost:8080/"
33
+ },
34
+ "id": "r7WUm0VL4bN4",
35
+ "outputId": "bfdefb82-479e-4f91-9a01-299ff76756e9"
36
+ },
37
+ "outputs": [
38
+ {
39
+ "output_type": "stream",
40
+ "name": "stdout",
41
+ "text": [
42
+ "485.52 million letters\n"
43
+ ]
44
+ }
45
+ ],
46
+ "source": [
47
+ "import torch\n",
48
+ "\n",
49
+ "# importing the data\n",
50
+ "file_path = '/content/drive/MyDrive/train2.txt'\n",
51
+ "with open(file_path, 'r', encoding='utf-8') as file:\n",
52
+ " dna_seq = file.read()\n",
53
+ "file.close()\n",
54
+ "\n",
55
+ "print(f\"{(len(dna_seq)/1e6):.2f} million letters\")"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": 3,
61
+ "metadata": {
62
+ "id": "Cdhybhz9owTK"
63
+ },
64
+ "outputs": [],
65
+ "source": [
66
+ "class PerCharTokenizer:\n",
67
+ " \"\"\"\n",
68
+ " Args:\n",
69
+ " - chars (list): all bases along with special tokens represented as characters\n",
70
+ " - vocab_size (int): size of vocabulary\n",
71
+ "\n",
72
+ " Working:\n",
73
+ " - vocab contains all the bases and ['P', 'M', 'U'] as padding, mask and unknown token\n",
74
+ " - encode(): iterates over the string one character at a time, looks up each character's position in vocab\n",
75
+ " and returns that position as an integer\n",
76
+ " - decode(): takes input of a list of integers and returns the specific item from vocab\n",
77
+ " \"\"\"\n",
78
+ " def __init__(self):\n",
79
+ " super().__init__()\n",
80
+ " self.chars = ['\\n', 'A', 'T', 'G', 'C', 'P', 'M', 'U', ' ']\n",
81
+ " self.vocab_size = len(self.chars)\n",
82
+ " self.string_to_index = {ch: i for i, ch in enumerate(self.chars)}\n",
83
+ " self.index_to_string = {i: ch for i, ch in enumerate(self.chars)}\n",
84
+ "\n",
85
+ " def encode(self, string):\n",
86
+ " encoded = []\n",
87
+ " for char in string:\n",
88
+ " if char in self.string_to_index:\n",
89
+ " encoded.append(self.string_to_index[char])\n",
90
+ " else:\n",
91
+ " special_index = len(self.string_to_index)\n",
92
+ " self.string_to_index[char] = special_index\n",
93
+ " self.index_to_string[special_index] = char\n",
94
+ " encoded.append(special_index)\n",
95
+ " return encoded\n",
96
+ "\n",
97
+ " def decode(self, integer):\n",
98
+ " decoded = []\n",
99
+ " for i in integer:\n",
100
+ " if i in self.index_to_string:\n",
101
+ " decoded.append(self.index_to_string[i])\n",
102
+ " else:\n",
103
+ " continue\n",
104
+ " return ''.join(decoded)"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": 4,
110
+ "metadata": {
111
+ "colab": {
112
+ "base_uri": "https://localhost:8080/"
113
+ },
114
+ "id": "6Ou9txgmAdIB",
115
+ "outputId": "cb5dd462-8b2a-445a-9524-1b484f288c64"
116
+ },
117
+ "outputs": [
118
+ {
119
+ "output_type": "stream",
120
+ "name": "stdout",
121
+ "text": [
122
+ "train data 436.97million, val data 48.55million\n"
123
+ ]
124
+ }
125
+ ],
126
+ "source": [
127
+ "token = PerCharTokenizer()\n",
128
+ "data = torch.tensor(token.encode(dna_seq), dtype=torch.long)\n",
129
+ "\n",
130
+ "# Train and test splits\n",
131
+ "n = int(0.9*len(data)) # first 90% will be train, rest val\n",
132
+ "train_data = data[:n]\n",
133
+ "val_data = data[n:]\n",
134
+ "print(f\"train data {(len(train_data)/1e6):.2f}million, val data {(len(val_data)/1e6):.2f}million\")"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": 5,
140
+ "metadata": {
141
+ "id": "ebFKQQ9NAq4e"
142
+ },
143
+ "outputs": [],
144
+ "source": [
145
+ "# hyperparams\n",
146
+ "batch_size = 10\n",
147
+ "block_size = 512\n",
148
+ "max_iters = 5000\n",
149
+ "eval_interval = 100\n",
150
+ "learning_rate = 3e-4\n",
151
+ "eval_iters = 100\n",
152
+ "d_model = 384\n",
153
+ "n_layers = 12\n",
154
+ "n_head = 12\n",
155
+ "dropout = 0.25\n",
156
+ "norm_eps = 1e-4"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": 6,
162
+ "metadata": {
163
+ "id": "dZMiYkr37cmU"
164
+ },
165
+ "outputs": [],
166
+ "source": [
167
+ "import math\n",
168
+ "import torch.nn as nn\n",
169
+ "from torch.nn import functional as F\n",
170
+ "\n",
171
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
172
+ "\n",
173
+ "class AttentionHead(nn.Module):\n",
174
+ " \"\"\"\n",
175
+ " initialize a single head of self attention.\n",
176
+ "\n",
177
+ " Args:\n",
178
+ " - d_model (int): dimensionality of the model's hidden layers\n",
179
+ " - head_size (int): dimensionality of each attention head\n",
180
+ " - dropout (float): dropout probability\n",
181
+ " - block_size (int): the maximum sequence length for positional encoding\n",
182
+ " \"\"\"\n",
183
+ " def __init__(self, d_model, head_size, dropout, block_size):\n",
184
+ " super().__init__()\n",
185
+ " self.key = nn.Linear(d_model, head_size, bias=True)\n",
186
+ " self.query = nn.Linear(d_model, head_size, bias=True)\n",
187
+ " self.value = nn.Linear(d_model, head_size, bias=False)\n",
188
+ " self.dropout = nn.Dropout(dropout)\n",
189
+ " self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n",
190
+ "\n",
191
+ " self.rel_pos_emb = nn.Parameter(torch.randn(block_size, block_size, head_size))\n",
192
+ "\n",
193
+ " def forward(self, x, mask=False):\n",
194
+ " \"\"\"\n",
195
+ " forward pass of a single attention head.\n",
196
+ "\n",
197
+ " Args:\n",
198
+ " - x (Tensor): input tensor.\n",
199
+ " - mask (bool): flag indicating whether to apply masking\n",
200
+ " Returns:\n",
201
+ " - out (Tensor): output tensor after self attention\n",
202
+ " \"\"\"\n",
203
+ " B, T, C = x.shape\n",
204
+ " key = self.key(x)\n",
205
+ " query = self.query(x)\n",
206
+ " scores = torch.matmul(query, key.transpose(-2, -1)) / (key.shape[-1] ** -0.5)\n",
207
+ "\n",
208
+ " rel_pos_scores = torch.einsum('btc,tvc->btv', query, self.rel_pos_emb[:T, :T])\n",
209
+ " scores += rel_pos_scores\n",
210
+ "\n",
211
+ " if mask:\n",
212
+ " scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))\n",
213
+ " weights = F.softmax(scores, dim=-1)\n",
214
+ " weights = self.dropout(weights)\n",
215
+ "\n",
216
+ " value = self.value(x)\n",
217
+ " out = torch.matmul(weights, value)\n",
218
+ " return out\n",
219
+ "\n",
220
+ "class MultiHeadAttention(nn.Module):\n",
221
+ " \"\"\"\n",
222
+ " initialize a multi-head attention module.\n",
223
+ "\n",
224
+ " Args:\n",
225
+ " - d_model (int): dimensionality of the model's hidden layers\n",
226
+ " - n_head (int): no of attention heads\n",
227
+ " - dropout (float): dropout probability\n",
228
+ " - block_size (int): context length\n",
229
+ " \"\"\"\n",
230
+ " def __init__(self, d_model, n_head, dropout, block_size):\n",
231
+ " head_size = d_model // n_head\n",
232
+ " super().__init__()\n",
233
+ " self.heads = nn.ModuleList([AttentionHead(d_model=d_model, dropout=dropout, head_size=head_size, block_size=block_size) for _ in range(n_head)])\n",
234
+ " self.proj = nn.Linear(n_head * head_size, d_model)\n",
235
+ " self.dropout = nn.Dropout(dropout)\n",
236
+ "\n",
237
+ " def forward(self, x, mask):\n",
238
+ " \"\"\"\n",
239
+ " forward pass of the multi-head attention module\n",
240
+ "\n",
241
+ " Args:\n",
242
+ " - x (Tensor): input tensor\n",
243
+ " - mask (bool): flag indicating whether to apply masking\n",
244
+ "\n",
245
+ " Returns:\n",
246
+ " - out (Tensor): output tensor after multi-head attention\n",
247
+ "\n",
248
+ " \"\"\"\n",
249
+ " out = torch.cat([h(x, mask=mask) for h in self.heads], dim=-1)\n",
250
+ " out = self.dropout(self.proj(out))\n",
251
+ " return out\n",
252
+ "\n",
253
+ "class FeedForward(nn.Module):\n",
254
+ " \"\"\"\n",
255
+ " initialize a feedforward network module\n",
256
+ "\n",
257
+ " Args:\n",
258
+ " - d_model (int): the dimensionality of the model's hidden layers\n",
259
+ " - dropout (float): dropout probability\n",
260
+ "\n",
261
+ " \"\"\"\n",
262
+ " def __init__(self, d_model, dropout):\n",
263
+ " super().__init__()\n",
264
+ " self.net = nn.Sequential(\n",
265
+ " nn.Linear(d_model, 5*d_model),\n",
266
+ " nn.GELU(),\n",
267
+ " nn.Linear(5*d_model, d_model),\n",
268
+ " nn.Dropout(dropout)\n",
269
+ " )\n",
270
+ "\n",
271
+ " def forward(self, x):\n",
272
+ " \"\"\"\n",
273
+ " forward pass of the feedforward network module\n",
274
+ "\n",
275
+ " Args:\n",
276
+ " - x (Tensor): input tensor\n",
277
+ "\n",
278
+ " Returns:\n",
279
+ " - out (Tensor): output tensor after passing through the feedforward network\n",
280
+ " \"\"\"\n",
281
+ " return self.net(x)\n",
282
+ "\n",
283
+ "class EncoderNetwork(nn.Module):\n",
284
+ " \"\"\"\n",
285
+ " initialize an encoder network module\n",
286
+ "\n",
287
+ " Args:\n",
288
+ " - d_model (int): dimensionality of the model's hidden layers\n",
289
+ " - n_head (int): no of attention heads in multi-head attention layers\n",
290
+ " - norm_eps (float): epsilon value for layer normalization\n",
291
+ " - dropout (float): dropout probability\n",
292
+ " - block_size (int): the maximum sequence length for positional encoding\n",
293
+ " \"\"\"\n",
294
+ " def __init__(self, d_model, n_head, norm_eps, dropout, block_size):\n",
295
+ " super().__init__()\n",
296
+ " self.s_att = MultiHeadAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)\n",
297
+ " self.ffwd = FeedForward(d_model, dropout)\n",
298
+ " self.dropout = nn.Dropout(dropout)\n",
299
+ " self.norm1 = nn.LayerNorm(d_model, eps=norm_eps)\n",
300
+ " self.norm2 = nn.LayerNorm(d_model, eps=norm_eps)\n",
301
+ "\n",
302
+ " def forward(self, src):\n",
303
+ " \"\"\"\n",
304
+ " forward pass of the encoder network module.\n",
305
+ "\n",
306
+ " Args:\n",
307
+ " - src (Tensor): input tensor representing source data\n",
308
+ "\n",
309
+ " Returns:\n",
310
+ " - src (Tensor): output tensor after passing through the encoder network\n",
311
+ " \"\"\"\n",
312
+ " src2 = self.s_att(src, mask=False)\n",
313
+ " src = src + self.dropout(src2)\n",
314
+ " src = self.norm1(src)\n",
315
+ "\n",
316
+ " src2 = self.ffwd(src)\n",
317
+ " src = src + self.dropout(src2)\n",
318
+ " src = self.norm2(src)\n",
319
+ "\n",
320
+ " return src\n",
321
+ "\n",
322
+ "class DecoderNetwork(nn.Module):\n",
323
+ " \"\"\"\n",
324
+ " initialize a decoder network module\n",
325
+ "\n",
326
+ " Args:\n",
327
+ " - d_model (int): dimensionality of the model's hidden layers\n",
328
+ " - n_head (int): no of attention heads in multi-head attention layers\n",
329
+ " - norm_eps (float): epsilon value for layer normalization\n",
330
+ " - dropout (float): dropout probability\n",
331
+ " - block_size (int): the maximum sequence length for positional encoding\n",
332
+ " \"\"\"\n",
333
+ " def __init__(self, d_model, n_head, norm_eps, dropout, block_size):\n",
334
+ " super().__init__()\n",
335
+ " self.s_att = MultiHeadAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)\n",
336
+ " self.ffwd = FeedForward(d_model, dropout)\n",
337
+ " self.dropout = nn.Dropout(dropout)\n",
338
+ " self.norm1 = nn.LayerNorm(d_model, eps=norm_eps)\n",
339
+ " self.norm2 = nn.LayerNorm(d_model, eps=norm_eps)\n",
340
+ "\n",
341
+ " def forward(self, src, att):\n",
342
+ " \"\"\"\n",
343
+ " forward pass of the decoder network module.\n",
344
+ "\n",
345
+ " Args:\n",
346
+ " - src (Tensor): input tensor, same as the encoder's inputs\n",
347
+ " - att (Tensor): output of the encoder stack\n",
348
+ "\n",
349
+ " Returns:\n",
350
+ " - src_f (Tensor): final output tensor\n",
351
+ " \"\"\"\n",
352
+ " src2 = self.s_att(src, mask=True)\n",
353
+ " src = src + self.dropout(src2)\n",
354
+ " src = src + self.norm1(src)\n",
355
+ "\n",
356
+ " att = src + att\n",
357
+ " att2 = self.s_att(att, mask=False)\n",
358
+ " att2 = att + self.dropout(att2)\n",
359
+ " trg = att2 + self.norm1(att2)\n",
360
+ "\n",
361
+ " src_f2 = self.ffwd(self.norm2(trg))\n",
362
+ " src_f = src_f2 + self.dropout(src_f2)\n",
363
+ " src_f = self.norm2(src_f)\n",
364
+ "\n",
365
+ " return src_f\n",
366
+ "\n",
367
+ "class Transformer(nn.Module):\n",
368
+ " \"\"\"\n",
369
+ " initialize a Transformer model\n",
370
+ "\n",
371
+ " Args:\n",
372
+ " - vocab_size (int): size of the vocabulary\n",
373
+ " - d_model (int): dimensionality of the model's hidden layers\n",
374
+ " - block_size (int): maximum sequence length for positional encoding/context length\n",
375
+ " - n_layers (int): number of encoder and decoder layers in the Transformer\n",
376
+ " - n_head (int): number of attention heads in multi-head attention layers\n",
377
+ " - norm_eps (float): epsilon value for layer normalization\n",
378
+ " - dropout (float): dropout probability\n",
379
+ " \"\"\"\n",
380
+ " def __init__(self, vocab_size):\n",
381
+ " super().__init__()\n",
382
+ " self.block_size = block_size\n",
383
+ " self.toked_model = nn.Embedding(vocab_size, d_model)\n",
384
+ " self.pos_encod = nn.Embedding(block_size, d_model)\n",
385
+ " self.enc_layer = nn.ModuleList([EncoderNetwork(n_head=n_head, norm_eps=norm_eps, block_size=block_size, dropout=dropout, d_model=d_model) for _ in range(n_layers)])\n",
386
+ " self.dec_layer = nn.ModuleList([DecoderNetwork(n_head=n_head, norm_eps=norm_eps, block_size=block_size, dropout=dropout, d_model=d_model) for _ in range(n_layers)])\n",
387
+ "\n",
388
+ " self.norm_final = nn.LayerNorm(d_model)\n",
389
+ " self.linear_final = nn.Linear(d_model, vocab_size)\n",
390
+ " self.dropout = nn.Dropout(dropout)\n",
391
+ " self.apply(self._init_weights)\n",
392
+ "\n",
393
+ " def _init_weights(self, module):\n",
394
+ " \"\"\"\n",
395
+ " initialize weights of linear and embedding layers\n",
396
+ "\n",
397
+ " Args:\n",
398
+ " - module (nn.Module): the module to initialize weights for\n",
399
+ " \"\"\"\n",
400
+ " if isinstance(module, nn.Linear):\n",
401
+ " torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
402
+ " if module.bias is not None:\n",
403
+ " torch.nn.init.zeros_(module.bias.data)\n",
404
+ " elif isinstance(module, nn.Embedding):\n",
405
+ " torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
406
+ "\n",
407
+ " def forward(self, idx, targets=None):\n",
408
+ " \"\"\"\n",
409
+ " forward pass of the transformer model\n",
410
+ "\n",
411
+ " Args:\n",
412
+ " - idx (Tensor): input tensor representing token indices\n",
413
+ " - targets (Tensor): target tensor for computing loss during training\n",
414
+ "\n",
415
+ " Returns:\n",
416
+ " - logits (Tensor): output logits from the final linear layer\n",
417
+ " - loss (Tensor): optional. computed cross-entropy loss if targets are provided, else None\n",
418
+ " \"\"\"\n",
419
+ " B, T = idx.shape\n",
420
+ "\n",
421
+ " toked_model = self.toked_model(idx)\n",
422
+ " pos_encod = self.pos_encod(torch.arange(T, device=device))\n",
423
+ " x = toked_model + pos_encod\n",
424
+ "\n",
425
+ " for layer in self.enc_layer:\n",
426
+ " x_out = layer(x)\n",
427
+ "\n",
428
+ " for layer in self.dec_layer:\n",
429
+ " x_final = layer(x, x_out)\n",
430
+ "\n",
431
+ " x_final = self.norm_final(x_final)\n",
432
+ " logits = self.linear_final(x_final)\n",
433
+ "\n",
434
+ " if targets is None:\n",
435
+ " loss = None\n",
436
+ "\n",
437
+ " else:\n",
438
+ " B, T, C = logits.shape\n",
439
+ " logits = logits.view(B*T, C)\n",
440
+ " targets = targets.view(B*T)\n",
441
+ " loss = F.cross_entropy(logits, targets)\n",
442
+ "\n",
443
+ " return logits, loss\n",
444
+ " def generate(self, idx, max_new_tokens, temperature=1.0, top_k=0):\n",
445
+ " \"\"\"\n",
446
+ " generate new tokens using the trained model\n",
447
+ "\n",
448
+ " Args:\n",
449
+ " - idx (Tensor): input tensor representing initial token indices\n",
450
+ " - max_new_tokens (int): max no of new tokens to generate\n",
451
+ " - temperature (float): softmax temperature for sampling\n",
452
+ " - top_k (int): no of top tokens to consider in sampling\n",
453
+ "\n",
454
+ " Returns:\n",
455
+ " - generated_tokens (list): list of generated token indices\n",
456
+ " \"\"\"\n",
457
+ " generated_tokens = []\n",
458
+ "\n",
459
+ " for _ in range(max_new_tokens):\n",
460
+ " idx_cond = idx[:, -self.block_size:]\n",
461
+ " logits, _ = self(idx_cond)\n",
462
+ " logits = logits[:, -1, :]\n",
463
+ "\n",
464
+ " scaled_logits = logits / temperature\n",
465
+ " if top_k > 0:\n",
466
+ " scaled_logits = self._top_k_filtering(scaled_logits, top_k)\n",
467
+ "\n",
468
+ " probs = F.softmax(scaled_logits, dim=-1)\n",
469
+ " sampled_idx = torch.multinomial(probs, num_samples=1)\n",
470
+ " generated_tokens.append(sampled_idx.item())\n",
471
+ " idx = torch.cat((idx, sampled_idx), dim=1)\n",
472
+ "\n",
473
+ " return generated_tokens\n",
474
+ "\n",
475
+ " def generate_masked_tokens(self, idx, masked_indices, temperature=1.0, top_k=0):\n",
476
+ " \"\"\"\n",
477
+ " Generate predictions for masked tokens using the trained model.\n",
478
+ "\n",
479
+ " Args:\n",
480
+ " - idx (Tensor): input tensor representing token indices\n",
481
+ " - masked_indices (Tensor): tensor of indices indicating masked positions\n",
482
+ " - temperature (float): softmax temperature for sampling\n",
483
+ " - top_k (int): no of top tokens to consider in sampling\n",
484
+ "\n",
485
+ " Returns:\n",
486
+ " - predicted_tokens (Tensor): tensor of predicted token indices\n",
487
+ " \"\"\"\n",
488
+ " B, T = idx.shape\n",
489
+ "\n",
490
+ " toked_model = self.toked_model(idx)\n",
491
+ " pos_encod = self.pos_encod(torch.arange(T, device=device))\n",
492
+ " x = toked_model + pos_encod\n",
493
+ "\n",
494
+ " for layer in self.enc_layer:\n",
495
+ " x_out = layer(x)\n",
496
+ "\n",
497
+ " for layer in self.dec_layer:\n",
498
+ " x_final = layer(x, x_out)\n",
499
+ "\n",
500
+ " x_masked = x_final.clone()\n",
501
+ " x_masked[masked_indices] = self.toked_model(torch.tensor([6], device=device))\n",
502
+ "\n",
503
+ " x_masked = self.norm_final(x_masked)\n",
504
+ " logits = self.linear_final(x_masked)\n",
505
+ "\n",
506
+ " masked_logits = logits[masked_indices].view(-1, logits.size(-1))\n",
507
+ " scaled_logits = masked_logits / temperature\n",
508
+ " if top_k > 0:\n",
509
+ " scaled_logits = self._top_k_filtering(scaled_logits, top_k)\n",
510
+ "\n",
511
+ " probs = F.softmax(scaled_logits, dim=-1)\n",
512
+ " predicted_indices = torch.argmax(probs, dim=-1)\n",
513
+ "\n",
514
+ " return predicted_indices\n",
515
+ "\n",
516
+ " def _top_k_filtering(self, logits, top_k):\n",
517
+ " \"\"\"\n",
518
+ " filter logits to keep only the top-k tokens\n",
519
+ "\n",
520
+ " Args:\n",
521
+ " - logits (Tensor): input tensor representing unscaled logits\n",
522
+ " - top_k (int): no of top tokens to keep\n",
523
+ "\n",
524
+ " Returns:\n",
525
+ " - filtered_logits (Tensor): filtered logits with only top-k tokens remaining\n",
526
+ " \"\"\"\n",
527
+ " values, indices = torch.topk(logits, top_k, dim=-1)\n",
528
+ " min_value = values[:, -1].unsqueeze(-1).expand_as(logits)\n",
529
+ " filtered_logits = torch.where(logits < min_value, torch.ones_like(logits) * -float('inf'), logits)\n",
530
+ "\n",
531
+ " return filtered_logits"
532
+ ]
533
+ },
534
+ {
535
+ "cell_type": "code",
536
+ "execution_count": 7,
537
+ "metadata": {
538
+ "colab": {
539
+ "base_uri": "https://localhost:8080/",
540
+ "height": 816
541
+ },
542
+ "id": "X9VOBZFr7g3W",
543
+ "outputId": "aa376025-0a37-4b93-e90a-9d95c6ef2c11"
544
+ },
545
+ "outputs": [
546
+ {
547
+ "output_type": "stream",
548
+ "name": "stdout",
549
+ "text": [
550
+ "2.5 billion parameters\n",
551
+ "step 0: train loss 2.2869, val loss 2.2884\n",
552
+ "step 100: train loss 1.3312, val loss 1.3281\n",
553
+ "step 200: train loss 1.3233, val loss 1.3181\n",
554
+ "step 300: train loss 1.3209, val loss 1.3196\n",
555
+ "step 400: train loss 1.3215, val loss 1.3203\n",
556
+ "step 500: train loss 1.1974, val loss 1.1994\n",
557
+ "step 600: train loss 0.3350, val loss 0.3365\n",
558
+ "step 700: train loss 0.0703, val loss 0.0702\n",
559
+ "step 800: train loss 0.0143, val loss 0.0143\n",
560
+ "step 900: train loss 0.0049, val loss 0.0047\n",
561
+ "step 1000: train loss 0.0041, val loss 0.0037\n",
562
+ "step 1100: train loss 0.0035, val loss 0.0036\n",
563
+ "step 1200: train loss 0.0038, val loss 0.0035\n",
564
+ "step 1300: train loss 0.0035, val loss 0.0033\n",
565
+ "step 1400: train loss 0.0035, val loss 0.0033\n",
566
+ "step 1500: train loss 0.0033, val loss 0.0033\n",
567
+ "step 1600: train loss 0.0033, val loss 0.0034\n",
568
+ "step 1700: train loss 0.0033, val loss 0.0033\n",
569
+ "step 1800: train loss 0.0033, val loss 0.0031\n",
570
+ "step 1900: train loss 0.0031, val loss 0.0031\n",
571
+ "step 2000: train loss 0.0032, val loss 0.0032\n"
572
+ ]
573
+ },
574
+ {
575
+ "output_type": "error",
576
+ "ename": "KeyboardInterrupt",
577
+ "evalue": "",
578
+ "traceback": [
579
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
580
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
581
+ "\u001b[0;32m<ipython-input-7-44818790f2dc>\u001b[0m in \u001b[0;36m<cell line: 45>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m \u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mset_to_none\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
582
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1510\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1511\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1512\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1513\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
583
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1518\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
584
+ "\u001b[0;32m<ipython-input-6-b2af72f89b89>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, idx, targets)\u001b[0m\n\u001b[1;32m 261\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mlayer\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdec_layer\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 263\u001b[0;31m \u001b[0mx_final\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlayer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_out\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 264\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[0mx_final\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm_final\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_final\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
585
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1510\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1511\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1512\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1513\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
586
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1518\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
587
+ "\u001b[0;32m<ipython-input-6-b2af72f89b89>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, src, att)\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[0matt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msrc\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0matt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 191\u001b[0;31m \u001b[0matt2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0ms_att\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0matt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 192\u001b[0m \u001b[0matt2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0matt\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0matt2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 193\u001b[0m \u001b[0mtrg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0matt2\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0matt2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
588
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1510\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1511\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1512\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1513\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
589
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1518\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
590
+ "\u001b[0;32m<ipython-input-6-b2af72f89b89>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, mask)\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 82\u001b[0m \"\"\"\n\u001b[0;32m---> 83\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mh\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mheads\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 84\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproj\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
591
+ "\u001b[0;32m<ipython-input-6-b2af72f89b89>\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 82\u001b[0m \"\"\"\n\u001b[0;32m---> 83\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mh\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mheads\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 84\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproj\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
592
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1510\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1511\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1512\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1513\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
593
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1518\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
594
+ "\u001b[0;32m<ipython-input-6-b2af72f89b89>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, mask)\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0mweights\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweights\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 50\u001b[0;31m \u001b[0mvalue\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 51\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweights\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
595
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1510\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1511\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1512\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1513\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
596
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1518\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
597
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 116\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 117\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
598
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
599
+ ]
600
+ }
601
+ ],
602
+ "source": [
603
+ "import timeit\n",
604
+ "\n",
605
+ "start_time = timeit.default_timer()\n",
606
+ "# data loading\n",
607
+ "def get_batch(split):\n",
608
+ "\n",
609
+ " data = train_data if split == 'train' else val_data\n",
610
+ " ix = torch.randint(len(data) - block_size, (batch_size,))\n",
611
+ " x = torch.stack([data[i:i+block_size] for i in ix])\n",
612
+ " y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
613
+ " x, y = x.to(device), y.to(device)\n",
614
+ " return x, y\n",
615
+ "\n",
616
+ "@torch.no_grad()\n",
617
+ "def estimate_loss():\n",
618
+ " out = {}\n",
619
+ " model.eval()\n",
620
+ " for split in ['train', 'val']:\n",
621
+ " losses = torch.zeros(eval_iters)\n",
622
+ " for k in range(eval_iters):\n",
623
+ " X, Y = get_batch(split)\n",
624
+ " logits, loss = model(X, Y)\n",
625
+ " losses[k] = loss.item()\n",
626
+ " out[split] = losses.mean()\n",
627
+ " model.train()\n",
628
+ " return out\n",
629
+ "\n",
630
+ "vocab_size = token.vocab_size\n",
631
+ "model = Transformer(vocab_size)\n",
632
+ "# checkpoint_path = '/content/drive/MyDrive/enigma-2.5b.pth'\n",
633
+ "# checkpoint = torch.load(checkpoint_path)\n",
634
+ "# model.load_state_dict(checkpoint)\n",
635
+ "m = model.to(device)\n",
636
+ "\n",
637
+ "# number of parameters, in billions\n",
638
+ "n_param = sum(p.numel() for p in m.parameters())/1e9\n",
639
+ "print(f\"{n_param:.1f} billion parameters\")\n",
640
+ "\n",
641
+ "# optimizer\n",
642
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
643
+ "steps = []\n",
644
+ "train_losses = []\n",
645
+ "val_losses = []\n",
646
+ "\n",
647
+ "for iter in range(max_iters):\n",
648
+ "\n",
649
+ " if iter % eval_interval == 0 or iter == max_iters - 1:\n",
650
+ " losses = estimate_loss()\n",
651
+ " print(f\"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n",
652
+ "\n",
653
+ " steps.append(iter)\n",
654
+ " train_losses.append(losses['train'])\n",
655
+ " val_losses.append(losses['val'])\n",
656
+ "\n",
657
+ " xb, yb = get_batch('train')\n",
658
+ " logits, loss = model(xb, yb)\n",
659
+ " optimizer.zero_grad(set_to_none=True)\n",
660
+ " loss.backward()\n",
661
+ " optimizer.step()"
662
+ ]
663
+ },
664
+ {
665
+ "cell_type": "code",
666
+ "execution_count": 8,
667
+ "metadata": {
668
+ "id": "tzJMKoA35uIV",
669
+ "colab": {
670
+ "base_uri": "https://localhost:8080/"
671
+ },
672
+ "outputId": "ba527bf5-695c-4a8f-acc4-bd60d549eaad"
673
+ },
674
+ "outputs": [
675
+ {
676
+ "output_type": "stream",
677
+ "name": "stdout",
678
+ "text": [
679
+ "total parameters: 2.5 billion\n",
680
+ "trained in 1.82hrs\n"
681
+ ]
682
+ }
683
+ ],
684
+ "source": [
685
+ "end_time = timeit.default_timer()\n",
686
+ "print(f\"total parameters: {n_param:.1f} billion\")\n",
687
+ "print(f\"trained in {((end_time - start_time)/3600):.2f}hrs\")"
688
+ ]
689
+ },
690
+ {
691
+ "cell_type": "code",
692
+ "source": [
693
+ "model_save_name = f'enigma-{n_param:.1f}b_v1.pth'\n",
694
+ "path = f\"/content/drive/MyDrive/{model_save_name}\"\n",
695
+ "torch.save(model.state_dict(), path)"
696
+ ],
697
+ "metadata": {
698
+ "id": "eB47Yn9aNrrO"
699
+ },
700
+ "execution_count": 10,
701
+ "outputs": []
702
+ },
703
+ {
704
+ "cell_type": "code",
705
+ "source": [
706
+ "# 8-bit quantization\n",
707
+ "\n",
708
+ "import torch\n",
709
+ "import torch.quantization\n",
710
+ "\n",
711
+ "checkpoint_path = '/content/drive/MyDrive/enigma-2.5b.pth'\n",
712
+ "checkpoint = torch.load(checkpoint_path)\n",
713
+ "model.load_state_dict(checkpoint)\n",
714
+ "model = model.to(device)\n",
715
+ "\n",
716
+ "quantized_model = torch.quantization.quantize_dynamic(\n",
717
+ " model,\n",
718
+ " dtype=torch.qint8\n",
719
+ ")\n",
720
+ "quantized_model_file = f'/content/drive/MyDrive/enigma-2.5b-quant.pth'\n",
721
+ "torch.save(quantized_model.state_dict(), quantized_model_file)\n",
722
+ "\n",
723
+ "print(\"Quantized model saved successfully.\")"
724
+ ],
725
+ "metadata": {
726
+ "id": "7iGQdNHgms_U"
727
+ },
728
+ "execution_count": null,
729
+ "outputs": []
730
+ },
731
+ {
732
+ "cell_type": "code",
733
+ "source": [
734
+ "# pruning\n",
735
+ "\n",
736
+ "import torch\n",
737
+ "from torch import nn\n",
738
+ "from torch.utils.model_zoo import load_url\n",
739
+ "import torch.nn.utils.prune as prune\n",
740
+ "\n",
741
+ "parameters_to_prune = [(model.encoder.self_attn, 'weight'), (model.encoder.linear1, 'weight')]\n",
742
+ "prune.global_unstructured(\n",
743
+ " parameters_to_prune,\n",
744
+ " pruning_method=prune.L1Unstructured,\n",
745
+ " amount=0.15,\n",
746
+ ")\n",
747
+ "\n",
748
+ "torch.save(model.state_dict(), 'enigma-2.5b_pruned.pth')"
749
+ ],
750
+ "metadata": {
751
+ "id": "YTJ19n4OFvZj"
752
+ },
753
+ "execution_count": null,
754
+ "outputs": []
755
+ },
756
+ {
757
+ "cell_type": "code",
758
+ "execution_count": null,
759
+ "metadata": {
760
+ "id": "K2FDOp7Quibq"
761
+ },
762
+ "outputs": [],
763
+ "source": [
764
+ "class Generator(Transformer):\n",
765
+ " def __init__(self, vocab_size, block_size):\n",
766
+ " super().__init__(vocab_size)\n",
767
+ " self.vocab_size = vocab_size\n",
768
+ " self.block_size = block_size\n",
769
+ "\n",
770
+ " def generate(self, idx, max_new_tokens, temperature=1.0, top_k=0):\n",
771
+ " \"\"\"\n",
772
+ " generate new tokens using the trained model\n",
773
+ "\n",
774
+ " Args:\n",
775
+ " - idx (Tensor): input tensor representing initial token indices\n",
776
+ " - max_new_tokens (int): maximum number of new tokens to generate\n",
777
+ " - temperature (float): softmax temperature for sampling\n",
778
+ " - top_k (int): number of top tokens to keep when sampling; 0 disables top-k filtering\n",
779
+ "\n",
780
+ " Returns:\n",
781
+ " - generated_tokens (list): list of newly generated token indices (prompt excluded; each sample is read with .item(), so idx must hold a single sequence)\n",
782
+ " \"\"\"\n",
783
+ " generated_tokens = []\n",
784
+ "\n",
785
+ " for _ in range(max_new_tokens):\n",
786
+ " idx_cond = idx[:, -self.block_size:]\n",
787
+ " logits, _ = self(idx_cond)\n",
788
+ " logits = logits[:, -1, :]\n",
789
+ "\n",
790
+ " scaled_logits = logits / temperature\n",
791
+ " if top_k > 0:\n",
792
+ " scaled_logits = self._top_k_filtering(scaled_logits, top_k)\n",
793
+ "\n",
794
+ " probs = F.softmax(scaled_logits, dim=-1)\n",
795
+ " sampled_idx = torch.multinomial(probs, num_samples=1)\n",
796
+ " generated_tokens.append(sampled_idx.item())\n",
797
+ " idx = torch.cat((idx, sampled_idx), dim=1)\n",
798
+ "\n",
799
+ " return generated_tokens\n",
800
+ "\n",
801
+ " def generate_masked_tokens(self, idx, masked_indices, temperature=1.0, top_k=0):\n",
802
+ " \"\"\"\n",
803
+ " Generate predictions for masked tokens using the trained model.\n",
804
+ "\n",
805
+ " Args:\n",
806
+ " - idx (Tensor): input tensor representing token indices\n",
807
+ " - masked_indices (Tensor): tensor of indices indicating masked positions\n",
808
+ " - temperature (float): softmax temperature for sampling\n",
809
+ " - top_k (int): no of top tokens to consider in sampling\n",
810
+ "\n",
811
+ " Returns:\n",
812
+ " - predicted_tokens (Tensor): tensor of predicted token indices\n",
813
+ " \"\"\"\n",
814
+ " B, T = idx.shape\n",
815
+ "\n",
816
+ " toked_model = self.toked_model(idx)\n",
817
+ " pos_encod = self.pos_encod(torch.arange(T, device=device))\n",
818
+ " x = toked_model + pos_encod\n",
819
+ "\n",
820
+ " for layer in self.enc_layer:\n",
821
+ " x_out = layer(x)\n",
822
+ "\n",
823
+ " for layer in self.dec_layer:\n",
824
+ " x_final = layer(x, x_out)\n",
825
+ "\n",
826
+ " x_masked = x_final.clone()\n",
827
+ " x_masked[masked_indices] = self.toked_model(torch.tensor([6], device=device))\n",
828
+ "\n",
829
+ " x_masked = self.norm_final(x_masked)\n",
830
+ " logits = self.linear_final(x_masked)\n",
831
+ "\n",
832
+ " masked_logits = logits[masked_indices].view(-1, logits.size(-1))\n",
833
+ " scaled_logits = masked_logits / temperature\n",
834
+ " if top_k > 0:\n",
835
+ " scaled_logits = self._top_k_filtering(scaled_logits, top_k)\n",
836
+ "\n",
837
+ " probs = F.softmax(scaled_logits, dim=-1)\n",
838
+ " predicted_indices = torch.argmax(probs, dim=-1)\n",
839
+ "\n",
840
+ " return predicted_indices\n",
841
+ "\n",
842
+ " def _top_k_filtering(self, logits, top_k):\n",
843
+ " \"\"\"\n",
844
+ " filter logits to keep only the top-k tokens\n",
845
+ "\n",
846
+ " Args:\n",
847
+ " - logits (Tensor): input tensor representing unscaled logits\n",
848
+ " - top_k (int): no of top tokens to keep\n",
849
+ "\n",
850
+ " Returns:\n",
851
+ " - filtered_logits (Tensor): filtered logits with only top-k tokens remaining\n",
852
+ " \"\"\"\n",
853
+ " values, indices = torch.topk(logits, top_k, dim=-1)\n",
854
+ " min_value = values[:, -1].unsqueeze(-1).expand_as(logits)\n",
855
+ " filtered_logits = torch.where(logits < min_value, torch.ones_like(logits) * -float('inf'), logits)\n",
856
+ "\n",
857
+ " return filtered_logits"
858
+ ]
859
+ },
860
+ {
861
+ "cell_type": "code",
862
+ "execution_count": null,
863
+ "metadata": {
864
+ "colab": {
865
+ "base_uri": "https://localhost:8080/",
866
+ "height": 429
867
+ },
868
+ "id": "c5CknylV4S2m",
869
+ "outputId": "12314d78-9147-4e60-f8b5-84207b97a1c7"
870
+ },
871
+ "outputs": [
872
+ {
873
+ "output_type": "error",
874
+ "ename": "RuntimeError",
875
+ "evalue": "Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)",
876
+ "traceback": [
877
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
878
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
879
+ "\u001b[0;32m<ipython-input-17-db17ec37b06c>\u001b[0m in \u001b[0;36m<cell line: 5>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mtarget_text\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"AGTTCTGCGAT\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mcontext\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtoken\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtarget_text\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mgenerated_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtoken\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgenerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgenerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_new_tokens\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtemperature\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.9\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtop_k\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"{target_text}{generated_output}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
880
+ "\u001b[0;32m<ipython-input-16-39da0e3e4598>\u001b[0m in \u001b[0;36mgenerate\u001b[0;34m(self, idx, max_new_tokens, temperature, top_k)\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmax_new_tokens\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0midx_cond\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mblock_size\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 24\u001b[0;31m \u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0midx_cond\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 25\u001b[0m \u001b[0mlogits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlogits\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
881
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1510\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1511\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1512\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1513\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
882
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1518\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
883
+ "\u001b[0;32m<ipython-input-7-b2af72f89b89>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, idx, targets)\u001b[0m\n\u001b[1;32m 253\u001b[0m \u001b[0mB\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mT\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 254\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 255\u001b[0;31m \u001b[0mtoked_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtoked_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 256\u001b[0m \u001b[0mpos_encod\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpos_encod\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 257\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtoked_model\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mpos_encod\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
884
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1510\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1511\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1512\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1513\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
885
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1518\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
886
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/sparse.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 163\u001b[0;31m return F.embedding(\n\u001b[0m\u001b[1;32m 164\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpadding_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_norm\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 165\u001b[0m self.norm_type, self.scale_grad_by_freq, self.sparse)\n",
887
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36membedding\u001b[0;34m(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)\u001b[0m\n\u001b[1;32m 2235\u001b[0m \u001b[0;31m# remove once script supports set_grad_enabled\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2236\u001b[0m \u001b[0m_no_grad_embedding_renorm_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_norm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnorm_type\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2237\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membedding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpadding_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscale_grad_by_freq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msparse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2238\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2239\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
888
+ "\u001b[0;31mRuntimeError\u001b[0m: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)"
889
+ ]
890
+ }
891
+ ],
892
+ "source": [
893
+ "generator = Generator(vocab_size, block_size)\n",
894
+ "\n",
895
+ "target_text = \"AGTTCTGCGAT\"\n",
896
+ "context = torch.tensor([token.encode(target_text)], dtype=torch.long, device=device)\n",
897
+ "generated_output = token.decode(generator.generate(context, max_new_tokens=100, temperature=0.9, top_k=5))\n",
898
+ "print(f\"{target_text}{generated_output}\")"
899
+ ]
900
+ }
901
+ ],
902
+ "metadata": {
903
+ "accelerator": "GPU",
904
+ "colab": {
905
+ "gpuType": "T4",
906
+ "machine_shape": "hm",
907
+ "provenance": []
908
+ },
909
+ "kernelspec": {
910
+ "display_name": "Python 3",
911
+ "name": "python3"
912
+ },
913
+ "language_info": {
914
+ "name": "python"
915
+ }
916
+ },
917
+ "nbformat": 4,
918
+ "nbformat_minor": 0
919
+ }
enigma/config_enigma.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "batch_size": 10,
3
+ "block_size": 512,
4
+ "max_iters": 5000,
5
+ "eval_interval": 50,
6
+ "learning_rate": 3e-5,
7
+ "eval_iters": 100,
8
+ "d_model": 384,
9
+ "n_head": 12,
10
+ "n_layer": 12,
11
+ "dropout": 0.2,
12
+ "norm_eps": 1e-5
13
+ }
enigma/enigma.cpp ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#include <torch/torch.h>

#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <map>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
4
+
5
+ // Define device
6
+ torch::Device device(torch::kCUDA);
7
+
8
+ // Define constants
9
+ const int batch_size = 8;
10
+ const int block_size = 32;
11
+ const int max_iters = 1000;
12
+ const int eval_interval = 50;
13
+ const int eval_iters = 5;
14
+ const int d_model = 256;
15
+ const int n_layer = 16;
16
+ const int n_head = 12;
17
+ const float dropout = 0.2;
18
+ const float norm_eps = 1e-5;
19
+ const int vocab_size = 5;
20
+
21
+ // sample data
22
+ torch::Tensor train_data = torch::rand({1000, block_size});
23
+ torch::Tensor val_data = torch::rand({500, block_size});
24
+
25
+ // Data loading function
26
+ std::pair<torch::Tensor, torch::Tensor> get_batch(const std::string& split) {
27
+ torch::Tensor data = (split == "train") ? train_data : val_data;
28
+ torch::Tensor ix = torch::randint(data.size(0) - block_size, {batch_size});
29
+ torch::Tensor x = torch::empty({batch_size, block_size});
30
+ torch::Tensor y = torch::empty({batch_size, block_size});
31
+ for (int i = 0; i < batch_size; ++i) {
32
+ x[i] = data.index({ix[i], ix[i] + block_size});
33
+ y[i] = data.index({ix[i] + 1, ix[i] + block_size + 1});
34
+ }
35
+ return std::make_pair(x.to(device), y.to(device));
36
+ }
37
+
38
+ // Custom classes and functions
39
+ class SWiGLU : public torch::nn::Module {
40
+ public:
41
+ SWiGLU() {}
42
+
43
+ torch::Tensor forward(torch::Tensor x) {
44
+ torch::Tensor sigmoid_output = torch::sigmoid(x);
45
+ torch::Tensor relu_output = torch::relu(x);
46
+ torch::Tensor out = sigmoid_output * relu_output + (1 - sigmoid_output) * x;
47
+ return out;
48
+ }
49
+ };
50
+
51
+ class UnMaskedHeadImpl : public torch::nn::Module {
52
+ public:
53
+ UnMaskedHeadImpl(int d_model, int head_size, float dropout)
54
+ : key(register_module("key", torch::nn::Linear(d_model, head_size))),
55
+ query(register_module("query", torch::nn::Linear(d_model, head_size))),
56
+ value(register_module("value", torch::nn::Linear(d_model, head_size))),
57
+ dropout(torch::nn::Dropout(dropout)) {
58
+ register_module("dropout", dropout);
59
+ }
60
+
61
+ torch::Tensor forward(torch::Tensor x) {
62
+ torch::Tensor key_out = key->forward(x);
63
+ torch::Tensor query_out = query->forward(x);
64
+
65
+ torch::Tensor weights = query_out.matmul(key_out.transpose(-2, -1)) * std::sqrt(key_out.size(-1));
66
+ weights = torch::softmax(weights, -1);
67
+ weights = dropout(weights);
68
+
69
+ torch::Tensor value_out = value->forward(x);
70
+ torch::Tensor out = weights.matmul(value_out);
71
+ return out;
72
+ }
73
+
74
+ private:
75
+ torch::nn::Linear key, query, value;
76
+ torch::nn::Dropout dropout;
77
+ };
78
+
79
+ TORCH_MODULE(UnMaskedHead);
80
+
81
+ class MaskedHeadImpl : public torch::nn::Module {
82
+ public:
83
+ MaskedHeadImpl(int head_size, float dropout, int d_model)
84
+ : key(register_module("key", torch::nn::Linear(d_model, head_size))),
85
+ query(register_module("query", torch::nn::Linear(d_model, head_size))),
86
+ value(register_module("value", torch::nn::Linear(d_model, head_size))),
87
+ dropout(torch::nn::Dropout(dropout)) {
88
+ register_buffer("tril", torch::tril(torch::ones(block_size, block_size)));
89
+ }
90
+
91
+ torch::Tensor forward(torch::Tensor x) {
92
+ torch::Tensor key_out = key->forward(x);
93
+ torch::Tensor query_out = query->forward(x);
94
+
95
+ torch::Tensor weights = query_out.matmul(key_out.transpose(-2, -1)) * std::sqrt(key_out.size(-1));
96
+ weights = weights.masked_fill(tril[:x.size(1), :x.size(1)] == 0, std::numeric_limits<float>::lowest());
97
+ weights = torch::softmax(weights, -1);
98
+ weights = dropout(weights);
99
+
100
+ torch::Tensor value_out = value->forward(x);
101
+ torch::Tensor out = weights.matmul(value_out);
102
+ return out;
103
+ }
104
+
105
+ private:
106
+ torch::nn::Linear key, query, value;
107
+ torch::nn::Dropout dropout;
108
+ torch::Tensor tril;
109
+ };
110
+
111
+ TORCH_MODULE(MaskedHead);
112
+
113
+ class MultiUnMaskedImpl : public torch::nn::Module {
114
+ public:
115
+ MultiUnMaskedImpl(int d_model, int n_head, float dropout)
116
+ : proj(register_module("proj", torch::nn::Linear(n_head * (d_model / n_head), d_model))),
117
+ dropout(torch::nn::Dropout(dropout)) {
118
+ for (int i = 0; i < n_head; ++i) {
119
+ heads.push_back(register_module("head" + std::to_string(i), UnMaskedHead(d_model, d_model / n_head, dropout)));
120
+ }
121
+ }
122
+
123
+ torch::Tensor forward(torch::Tensor x) {
124
+ std::vector<torch::Tensor> head_outputs;
125
+ for (auto& head : heads) {
126
+ head_outputs.push_back(head->forward(x));
127
+ }
128
+ torch::Tensor out = torch::cat(head_outputs, -1);
129
+ out = dropout(out);
130
+ out = proj(out);
131
+ return out;
132
+ }
133
+
134
+ private:
135
+ torch::nn::Linear proj;
136
+ torch::nn::Dropout dropout;
137
+ std::vector<UnMaskedHead> heads;
138
+ };
139
+
140
+ TORCH_MODULE(MultiUnMasked);
141
+
142
+ class MultiMaskedImpl : public torch::nn::Module {
143
+ public:
144
+ MultiMaskedImpl(int d_model, int n_head, float dropout)
145
+ : proj(register_module("proj", torch::nn::Linear(n_head * (d_model / n_head), d_model))),
146
+ dropout(torch::nn::Dropout(dropout)) {
147
+ for (int i = 0; i < n_head; ++i) {
148
+ heads.push_back(register_module("head" + std::to_string(i), MaskedHead(d_model, d_model / n_head, dropout)));
149
+ }
150
+ }
151
+
152
+ torch::Tensor forward(torch::Tensor x) {
153
+ std::vector<torch::Tensor> head_outputs;
154
+ for (auto& head : heads) {
155
+ head_outputs.push_back(head->forward(x));
156
+ }
157
+ torch::Tensor out = torch::cat(head_outputs, -1);
158
+ out = dropout(out);
159
+ out = proj(out);
160
+ return out;
161
+ }
162
+
163
+ private:
164
+ torch::nn::Linear proj;
165
+ torch::nn::Dropout dropout;
166
+ std::vector<MaskedHead> heads;
167
+ };
168
+
169
+ TORCH_MODULE(MultiMasked);
170
+
171
+ class FeedForwardImpl : public torch::nn::Module {
172
+ public:
173
+ FeedForwardImpl(int d_model, float dropout)
174
+ : net(register_module("net", torch::nn::Sequential(
175
+ torch::nn::Linear(d_model, 4 * d_model),
176
+ torch::nn::GELU(),
177
+ torch::nn::Linear(4 * d_model, d_model),
178
+ torch::nn::Dropout(dropout)
179
+ ))) {}
180
+
181
+ torch::Tensor forward(torch::Tensor x) {
182
+ return net->forward(x);
183
+ }
184
+
185
+ private:
186
+ torch::nn::Sequential net;
187
+ };
188
+
189
+ TORCH_MODULE(FeedForward);
190
+
191
+ class BlockImpl : public torch::nn::Module {
192
+ public:
193
+ BlockImpl(int d_model, int n_head, float norm_eps, float dropout)
194
+ : sa_masked(MultiMasked(d_model, n_head, dropout)),
195
+ sa_unmasked(MultiUnMasked(d_model, n_head, dropout)),
196
+ ffwd(FeedForward(d_model, dropout)),
197
+ norm1(torch::nn::LayerNorm(torch::nn::LayerNormOptions({d_model}).eps(norm_eps))),
198
+ norm2(torch::nn::LayerNorm(torch::nn::LayerNormOptions({d_model}).eps(norm_eps))) {}
199
+
200
+ torch::Tensor forward(torch::Tensor x) {
201
+ torch::Tensor x2 = x + sa_unmasked->forward(norm1->forward(x));
202
+ x = x2 + ffwd->forward(norm2->forward(x2));
203
+
204
+ x2 = x + sa_masked->forward(norm1->forward(x));
205
+ x = x2 + ffwd->forward(norm2->forward(x2));
206
+
207
+ return x;
208
+ }
209
+
210
+ private:
211
+ MultiMasked sa_masked;
212
+ MultiUnMasked sa_unmasked;
213
+ FeedForward ffwd;
214
+ torch::nn::LayerNorm norm1, norm2;
215
+ };
216
+
217
+ TORCH_MODULE(Block);
218
+
219
+ class EnigmaImpl : public torch::nn::Module {
220
+ public:
221
+ EnigmaImpl(int vocab_size, int block_size, int d_model, int n_layer, int n_head, float dropout, float norm_eps)
222
+ : toked_model(register_module("toked_model", torch::nn::Embedding(vocab_size, d_model))),
223
+ pos_encod(register_module("pos_encod", torch::nn::Embedding(block_size, d_model))),
224
+ norm_final(torch::nn::LayerNorm(torch::nn::LayerNormOptions({d_model}).eps(norm_eps))),
225
+ linear_final(register_module("linear_final", torch::nn::Linear(d_model, vocab_size))) {
226
+ for (int i = 0; i < n_layer; ++i) {
227
+ block_layers.push_back(register_module("block" + std::to_string(i), Block(d_model, n_head, norm_eps, dropout)));
228
+ }
229
+ register_buffer("block_size", torch::tensor(block_size));
230
+ _init_weights(this);
231
+ }
232
+
233
+ void _init_weights(torch::nn::Module* module) {
234
+ auto parameters = module->named_parameters();
235
+ for (auto& param : parameters) {
236
+ if (param.key().find("weight") != std::string::npos) {
237
+ torch::nn::init::normal_(param.value(), 0.0, 0.02);
238
+ } else if (param.key().find("bias") != std::string::npos) {
239
+ torch::nn::init::zeros_(param.value());
240
+ }
241
+ }
242
+ }
243
+
244
+ std::pair<torch::Tensor, torch::Tensor> forward(torch::Tensor idx, torch::Tensor targets=torch::Tensor()) {
245
+ torch::Tensor toked_model_out = toked_model->forward(idx);
246
+ torch::Tensor pos_encod_out = pos_encod->forward(torch::arange(idx.size(1)));
247
+ torch::Tensor x = toked_model_out + pos_encod_out;
248
+
249
+ for (auto& block : block_layers) {
250
+ x = block->forward(x);
251
+ }
252
+
253
+ torch::Tensor logits = linear_final->forward(norm_final->forward(x));
254
+
255
+ if (!targets.numel()) {
256
+ return {logits, torch::Tensor()};
257
+ } else {
258
+ logits = logits.view({-1, logits.size(-1)});
259
+ targets = targets.view({-1});
260
+ torch::Tensor loss = torch::nn::functional::cross_entropy(logits, targets);
261
+ return {logits, loss};
262
+ }
263
+ }
264
+
265
+ std::vector<std::vector<std::pair<torch::Tensor, float>>> complex_generate(torch::Tensor idx, int max_new_tokens, float temperature=1.0, int top_k=3, int beam_width=5) {
266
+ std::vector<std::vector<std::pair<torch::Tensor, float>>> completed_beams;
267
+ torch::Tensor current_idx = idx.clone();
268
+ std::vector<std::pair<torch::Tensor, float>> beam = {std::make_pair(current_idx, 0.0)};
269
+
270
+ for (int i = 0; i < max_new_tokens; ++i) {
271
+ std::vector<std::pair<torch::Tensor, float>> new_beam;
272
+
273
+ for (auto& beam_item : beam) {
274
+ torch::Tensor& current_idx = beam_item.first;
275
+ torch::Tensor logits, loss;
276
+ std::tie(logits, loss) = forward(current_idx);
277
+ logits = logits.index({torch::indexing::Slice(), -1}); // Get last token predictions
278
+
279
+ // Apply softmax and temperature
280
+ torch::Tensor probs = torch::nn::functional::softmax(logits / temperature, -1);
281
+
282
+ // Top-k sampling
283
+ if (top_k > 0) {
284
+ probs = top_k_filtering(probs, top_k);
285
+ }
286
+
287
+ // Sample from the distribution
288
+ torch::Tensor sampled_idx = torch::multinomial(probs, beam_width, true);
289
+
290
+ for (int j = 0; j < beam_width; ++j) {
291
+ torch::Tensor new_idx = torch::cat({current_idx, sampled_idx.index({torch::indexing::Slice(), j})}, 1);
292
+ torch::Tensor new_log_prob = beam_item.second + torch::log(probs.index({torch::indexing::Slice(), sampled_idx.index({torch::indexing::Slice(), j})}));
293
+ new_beam.push_back(std::make_pair(new_idx, new_log_prob.item()));
294
+ }
295
+ }
296
+
297
+ // Sort new beam by log probabilities
298
+ std::sort(new_beam.begin(), new_beam.end(), [](const std::pair<torch::Tensor, float>& a, const std::pair<torch::Tensor, float>& b) {
299
+ return a.second > b.second;
300
+ });
301
+
302
+ // Only keep top beams
303
+ beam = std::vector<std::pair<torch::Tensor, float>>(new_beam.begin(), new_beam.begin() + beam_width);
304
+ }
305
+
306
+ completed_beams.push_back(beam);
307
+ return completed_beams;
308
+ }
309
+
310
+ std::vector<std::vector<std::pair<torch::Tensor, float>>> top_k_filtering(torch::Tensor logits, int top_k) {
311
+ torch::Tensor top_values, top_indices;
312
+ std::tie(top_values, top_indices) = torch::topk(logits, top_k, -1);
313
+
314
+ torch::Tensor min_value = torch::index_select(top_values, -1, torch::tensor({top_k-1}));
315
+ torch::Tensor filtered_logits = torch::where(logits < min_value, torch::full_like(logits, -std::numeric_limits<float>::infinity()), logits);
316
+ return filtered_logits;
317
+ }
318
+
319
+ private:
320
+ torch::nn::Embedding toked_model, pos_encod;
321
+ std::vector<Block> block_layers;
322
+ torch::nn::LayerNorm norm_final;
323
+ torch::nn::Linear linear_final;
324
+ int block_size;
325
+ };
326
+
327
+ TORCH_MODULE(Enigma);
328
+
329
+ int main() {
330
+ // Set seed
331
+ torch::manual_seed(1400);
332
+
333
+ // Create model
334
+ Enigma model(vocab_size, block_size, d_model, n_layer, n_head, dropout, norm_eps);
335
+ model->to(device);
336
+
337
+ // Define optimizer
338
+ torch::optim::AdamW optimizer(model->parameters(), torch::optim::AdamWOptions(learning_rate));
339
+
340
+ // Training loop
341
+ std::vector<float> train_losses, val_losses;
342
+ for (int iter = 0; iter < max_iters; ++iter) {
343
+ if (iter % eval_interval == 0 || iter == max_iters - 1) {
344
+ // Evaluate and print losses
345
+ auto losses = estimate_loss();
346
+ std::cout << "step " << iter << ": train loss " << losses["train"] << ", val loss " << losses["val"] << std::endl;
347
+
348
+ // Save losses for plotting
349
+ train_losses.push_back(losses["train"]);
350
+ val_losses.push_back(losses["val"]);
351
+ }
352
+
353
+ // Get batch, forward pass, loss calculation, backward pass, optimizer step
354
+ auto [xb, yb] = get_batch("train");
355
+ torch::Tensor logits, loss;
356
+ std::tie(logits, loss) = model->forward(xb, yb);
357
+
358
+ optimizer.zero_grad();
359
+ loss.backward();
360
+ optimizer.step();
361
+ }
362
+
363
+ return 0;
364
+ }
enigma/generate.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

# run relative to this file so the '../' paths below resolve from any working directory
current_directory = os.path.dirname(os.path.abspath(__file__))
os.chdir(current_directory)

# load the raw corpus; it is only used here to report its size
with open('../parquet files/new_dna.txt', 'r', encoding='utf-8') as file:
    captions = file.read()

print(f"{(len(captions)/1e6):.2f} million letters")

from tokenizer import PerCharTokenizer

tokenizer = PerCharTokenizer()
vocab_size = tokenizer.vocab_size

import torch
import torch.nn as nn
from torch.nn import functional as F

# prefer the GPU when one is available
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

from model import Transformer

model = Transformer(vocab_size=vocab_size)
22
+
23
class Generator(Transformer):
    """
    Transformer with sampling helpers for inference

    Args:
      - vocab_size (int): no of tokens in the vocabulary
      - block_size (int, optional): max context length fed to the model; when omitted it is
          taken from `Transformer.block_size` if that exists, else from `model.block_size`
    """
    def __init__(self, vocab_size, block_size=None):
        # the file itself builds the base model with `Transformer(vocab_size=...)`,
        # so the vocabulary size has to be forwarded to the parent constructor
        super().__init__(vocab_size=vocab_size)
        self.vocab_size = vocab_size

        if block_size is None:
            block_size = getattr(Transformer, 'block_size', None)
        if block_size is None:
            # module-level hyperparameter that model.py loads from config_enigma.json
            from model import block_size as default_block_size
            block_size = default_block_size
        self.block_size = block_size

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=0):
        """
        generate new tokens using the trained model

        Args:
          - idx (Tensor): input tensor of initial token indices, shape (1, T);
              only a single sequence is supported because each sample is read with `.item()`
          - max_new_tokens (int): max no of new tokens to generate
          - temperature (float): softmax temperature for sampling
          - top_k (int): no of top tokens to consider in sampling (0 disables filtering)

        Returns:
          - generated_tokens (list): list of generated token indices (prompt not included)
        """
        generated_tokens = []

        for _ in range(max_new_tokens):
            # never feed more than block_size tokens of context
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            scaled_logits = logits[:, -1, :] / temperature

            if top_k > 0:
                scaled_logits = self._top_k_filtering(scaled_logits, top_k)

            probs = F.softmax(scaled_logits, dim=-1)
            sampled_idx = torch.multinomial(probs, num_samples=1)
            generated_tokens.append(sampled_idx.item())
            idx = torch.cat((idx, sampled_idx), dim=1)

        return generated_tokens

    def generate_masked_tokens(self, idx, masked_indices, temperature=1.0, top_k=0):
        """
        Generate predictions for masked tokens using the trained model.

        Args:
          - idx (Tensor): input tensor representing token indices, shape (B, T)
          - masked_indices (Tensor): tensor of indices indicating masked positions
          - temperature (float): softmax temperature
          - top_k (int): no of top tokens to consider (0 disables filtering)

        Returns:
          - predicted_tokens (Tensor): tensor of predicted token indices
              (argmax per masked position, so the result is deterministic)
        """
        B, T = idx.shape

        # build helper tensors on the input's device rather than the global `device`
        toked_model = self.toked_model(idx)
        pos_encod = self.pos_encod(torch.arange(T, device=idx.device))
        x = toked_model + pos_encod

        # NOTE(review): every layer receives the same `x`, so only the last encoder and
        # decoder layers affect the result - confirm against Transformer.forward
        for layer in self.enc_layer:
            x_out = layer(x)

        for layer in self.dec_layer:
            x_final = layer(x, x_out)

        x_masked = x_final.clone()
        # NOTE(review): token id 6 is presumably the mask token - confirm against the tokenizer
        x_masked[masked_indices] = self.toked_model(torch.tensor([6], device=idx.device))

        x_masked = self.norm_final(x_masked)
        logits = self.linear_final(x_masked)

        masked_logits = logits[masked_indices].view(-1, logits.size(-1))
        scaled_logits = masked_logits / temperature
        if top_k > 0:
            scaled_logits = self._top_k_filtering(scaled_logits, top_k)

        probs = F.softmax(scaled_logits, dim=-1)
        predicted_indices = torch.argmax(probs, dim=-1)

        return predicted_indices

    def _top_k_filtering(self, logits, top_k):
        """
        filter logits to keep only the top-k tokens

        Args:
          - logits (Tensor): input tensor representing unscaled logits, vocabulary on the last axis
          - top_k (int): no of top tokens to keep (clamped to the vocabulary size)

        Returns:
          - filtered_logits (Tensor): logits with everything below the k-th largest value set to -inf
        """
        # torch.topk raises when k exceeds the size of the axis
        k = min(top_k, logits.size(-1))
        values, _ = torch.topk(logits, k, dim=-1)
        min_value = values[..., -1, None]
        filtered_logits = logits.masked_fill(logits < min_value, float('-inf'))

        return filtered_logits
117
+
118
checkpoint_path = '../trained models/enigma_47m.pth'
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint)
m = model.to(device)

# generate() is an instance method, so it needs a Generator carrying the trained weights;
# Generator adds no parameters of its own, so the Transformer checkpoint loads directly
generator = Generator(vocab_size)
generator.load_state_dict(checkpoint)
generator.to(device)
generator.eval()  # disable dropout while sampling

target_text = "AGTTCTGCGAT"
context = torch.tensor([tokenizer.encode(target_text)], dtype=torch.long, device=device)
generated_output = tokenizer.decode(generator.generate(context, max_new_tokens=10, temperature=0.5, top_k=5))
print(f"{target_text}{generated_output}")
enigma/model.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ transformer based model, but with few minimal tweaks
3
+ trained a 2.5billion parameters model with current set configurations
4
+ """
5
+
6
import torch
import json
import os

# run relative to this file so 'config_enigma.json' is found from any working directory
current_directory = os.path.dirname(os.path.abspath(__file__))
os.chdir(current_directory)

import torch.nn as nn
from torch.nn import functional as F

# hyperparameters are kept in a JSON file next to this module
with open('config_enigma.json', 'r', encoding='utf-8') as file:
    params = json.load(file)

batch_size = params['batch_size']
block_size = params['block_size']
n_head = params['n_head']
d_model = params['d_model']
n_layers = params['n_layer']
dropout = params['dropout']
norm_eps = params['norm_eps']

# prefer the GPU when one is available
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
26
+
27
class AttentionHead(nn.Module):
    """
    a single head of self attention with learned relative position scores.

    Args:
      - d_model (int): dimensionality of the model's hidden layers
      - head_size (int): dimensionality of each attention head
      - dropout (float): dropout probability
      - block_size (int): the maximum sequence length; sizes the causal mask and
          the relative position table
    """
    def __init__(self, d_model, head_size, dropout, block_size):
        super().__init__()
        self.key = nn.Linear(d_model, head_size, bias=True)
        self.query = nn.Linear(d_model, head_size, bias=True)
        self.value = nn.Linear(d_model, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        # lower-triangular matrix used as the causal mask
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        # one head_size vector per (query position, key position) pair
        self.rel_pos_emb = nn.Parameter(torch.randn(block_size, block_size, head_size))

    def forward(self, x, mask=False):
        """
        forward pass of a single attention head.

        Args:
          - x (Tensor): input tensor of shape (B, T, d_model) with T <= block_size
          - mask (bool): flag indicating whether to apply causal masking

        Returns:
          - out (Tensor): output tensor after self attention, shape (B, T, head_size)
        """
        _, T, _ = x.shape
        key = self.key(x)
        query = self.query(x)

        # scaled dot-product attention: multiply by head_size ** -0.5, i.e. divide by
        # sqrt(head_size); dividing by head_size ** -0.5 would scale the scores *up*
        scores = torch.matmul(query, key.transpose(-2, -1)) * (key.shape[-1] ** -0.5)
        rel_pos_scores = torch.einsum('btc,tvc->btv', query, self.rel_pos_emb[:T, :T])
        scores = scores + rel_pos_scores

        if mask:
            # hide future positions; -inf becomes exactly 0 after the softmax
            scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))

        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        value = self.value(x)
        out = torch.matmul(weights, value)
        return out
75
+
76
class MultiHeadAttention(nn.Module):
    """
    several attention heads run side by side, followed by an output projection.

    Args:
    - d_model (int): dimensionality of the model's hidden layers
    - n_head (int): no of attention heads
    - dropout (float): dropout probability
    - block_size (int): context length
    """
    def __init__(self, d_model, n_head, dropout, block_size):
        super().__init__()
        # floor division: n_head * head_size can be smaller than d_model when
        # d_model is not a multiple of n_head; the projection maps back to d_model
        head_size = d_model // n_head
        head_modules = []
        for _ in range(n_head):
            head_modules.append(
                AttentionHead(d_model=d_model, dropout=dropout, head_size=head_size, block_size=block_size)
            )
        self.heads = nn.ModuleList(head_modules)
        self.proj = nn.Linear(n_head * head_size, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """
        forward pass of the multi-head attention module

        Args:
        - x (Tensor): input tensor
        - mask (bool): flag forwarded to every head, indicating whether to apply masking

        Returns:
        - out (Tensor): output tensor after concatenating the heads, projecting and dropout
        """
        per_head = [head(x, mask=mask) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)
108
+
109
class FeedForward(nn.Module):
    """
    position-wise feedforward network: expand, GELU, contract, dropout

    Args:
    - d_model (int): the dimensionality of the model's hidden layers
    - dropout (float): dropout probability
    """
    def __init__(self, d_model, dropout):
        super().__init__()
        hidden_dim = 10 * d_model  # hidden layer is ten times wider than the model
        expand = nn.Linear(d_model, hidden_dim)
        contract = nn.Linear(hidden_dim, d_model)
        self.net = nn.Sequential(expand, nn.GELU(), contract, nn.Dropout(dropout))

    def forward(self, x):
        """
        forward pass of the feedforward network module

        Args:
        - x (Tensor): input tensor whose last dimension is d_model

        Returns:
        - out (Tensor): output tensor after passing through the feedforward network
        """
        out = self.net(x)
        return out
138
+
139
class EncoderNetwork(nn.Module):
    """
    one encoder layer: unmasked self attention and a feedforward network,
    each followed by a residual connection and layer normalization

    Args:
    - d_model (int): dimensionality of the model's hidden layers
    - n_head (int): no of attention heads in multi-head attention layers
    - norm_eps (float): epsilon value for layer normalization
    - dropout (float): dropout probability
    - block_size (int): the maximum sequence length for positional encoding
    """
    def __init__(self, d_model, n_head, norm_eps, dropout, block_size):
        super().__init__()
        self.s_att = MultiHeadAttention(
            n_head=n_head,
            d_model=d_model,
            dropout=dropout,
            block_size=block_size,
        )
        self.ffwd = FeedForward(d_model, dropout)
        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model, eps=norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=norm_eps)

    def forward(self, src):
        """
        forward pass of the encoder network module.

        Args:
        - src (Tensor): input tensor representing source data

        Returns:
        - src (Tensor): output tensor after passing through the encoder network
        """
        # attention sub-layer: residual add, then normalize
        attended = self.s_att(src, mask=False)
        src = self.norm1(src + self.dropout(attended))

        # feedforward sub-layer: residual add, then normalize
        transformed = self.ffwd(src)
        return self.norm2(src + self.dropout(transformed))
177
+
178
class DecoderNetwork(nn.Module):
    """
    one decoder layer: masked self attention, a second unmasked attention pass
    over the sum of the decoder state and the encoder output, then a
    feedforward network

    Args:
    - d_model (int): dimensionality of the model's hidden layers
    - n_head (int): no of attention heads in multi-head attention layers
    - norm_eps (float): epsilon value for layer normalization
    - dropout (float): dropout probability
    - block_size (int): the maximum sequence length for positional encoding
    """
    def __init__(self, d_model, n_head, norm_eps, dropout, block_size):
        super().__init__()
        self.s_att = MultiHeadAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)
        self.ffwd = FeedForward(d_model, dropout)
        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model, eps=norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=norm_eps)

    def forward(self, src, att):
        """
        forward pass of the decoder network module.

        Args:
        - src (Tensor): input tensor for the decoder
        - att (Tensor): encoder output; must be addable to src

        Returns:
        - src_f (Tensor): final output tensor
        """
        # stage 1: causal (masked) self attention with a residual connection
        src2 = self.s_att(src, mask=True)
        src = src + self.dropout(src2)
        # NOTE(review): the normalized tensor is added back onto its own input
        # here, unlike the encoder, which replaces the tensor with its
        # normalized form - confirm this is intended
        src = src + self.norm1(src)

        # stage 2: unmasked attention over decoder state + encoder output.
        # The same s_att and norm1 modules (and weights) as stage 1 are reused.
        att = src + att
        att2 = self.s_att(att, mask=False)
        att2 = att + self.dropout(att2)
        trg = att2 + self.norm1(att2)

        # stage 3: feedforward on the normalized stage-2 output, with the
        # residual taken from that stage-2 output
        src_f2 = self.ffwd(self.norm2(trg))
        src_f = trg + self.dropout(src_f2)
        src_f = self.norm2(src_f)

        return src_f
222
+
223
class Transformer(nn.Module):
    """
    encoder-decoder Transformer language model

    Args:
    - vocab_size (int): size of the vocabulary

    Every other hyper-parameter (d_model, block_size, n_layers, n_head,
    norm_eps, dropout) is read from the module-level configuration values.
    """
    def __init__(self, vocab_size):
        super().__init__()
        self.block_size = block_size
        self.toked_model = nn.Embedding(vocab_size, d_model)
        self.pos_encod = nn.Embedding(block_size, d_model)
        self.enc_layer = nn.ModuleList([EncoderNetwork(n_head=n_head, norm_eps=norm_eps, block_size=block_size, dropout=dropout, d_model=d_model) for _ in range(n_layers)])
        self.dec_layer = nn.ModuleList([DecoderNetwork(n_head=n_head, norm_eps=norm_eps, block_size=block_size, dropout=dropout, d_model=d_model) for _ in range(n_layers)])

        self.norm_final = nn.LayerNorm(d_model)
        self.linear_final = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """
        initialize weights of linear and embedding layers

        Args:
        - module (nn.Module): the module to initialize weights for
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias.data)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _run_layers(self, idx):
        """
        embed the tokens and pass them through the encoder and decoder stacks

        Args:
        - idx (Tensor): token indices of shape (B, T), with T <= block_size

        Returns:
        - x_final (Tensor): output of the last decoder layer, before the final norm
        """
        _, T = idx.shape
        # build the positions on the same device as the input so the model
        # works wherever it (and its input) has been moved
        positions = torch.arange(T, device=idx.device)
        x = self.toked_model(idx) + self.pos_encod(positions)

        # each encoder layer consumes the output of the previous one
        x_out = x
        for layer in self.enc_layer:
            x_out = layer(x_out)

        # each decoder layer consumes the previous decoder output together
        # with the final encoder output
        x_final = x
        for layer in self.dec_layer:
            x_final = layer(x_final, x_out)

        return x_final

    def forward(self, idx, targets=None):
        """
        forward pass of the transformer model

        Args:
        - idx (Tensor): input tensor representing token indices, shape (B, T)
        - targets (Tensor): target tensor for computing loss during training

        Returns:
        - logits (Tensor): output logits from the final linear layer; shape
          (B, T, vocab_size) without targets, (B*T, vocab_size) with targets
        - loss (Tensor): optional. computed cross-entropy loss if targets are provided, else None
        """
        x_final = self.norm_final(self._run_layers(idx))
        logits = self.linear_final(x_final)

        if targets is None:
            return logits, None

        # flatten batch and time so cross_entropy sees one row per position
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.view(B * T)
        loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=0):
        """
        generate new tokens using the trained model

        Args:
        - idx (Tensor): input tensor representing initial token indices; must
          hold a single sequence (batch size 1) because each sampled token is
          read back with .item()
        - max_new_tokens (int): max no of new tokens to generate
        - temperature (float): softmax temperature for sampling
        - top_k (int): no of top tokens to consider in sampling; 0 disables filtering

        Returns:
        - generated_tokens (list): list of generated token indices
        """
        generated_tokens = []

        # only python ints are returned, so no autograd graph is needed
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # crop the context to the last block_size tokens
                idx_cond = idx[:, -self.block_size:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :]

                scaled_logits = logits / temperature
                if top_k > 0:
                    scaled_logits = self._top_k_filtering(scaled_logits, top_k)

                probs = F.softmax(scaled_logits, dim=-1)
                sampled_idx = torch.multinomial(probs, num_samples=1)
                generated_tokens.append(sampled_idx.item())
                idx = torch.cat((idx, sampled_idx), dim=1)

        return generated_tokens

    def generate_masked_tokens(self, idx, masked_indices, temperature=1.0, top_k=0, mask_token_id=6):
        """
        Generate predictions for masked tokens using the trained model.

        Args:
        - idx (Tensor): input tensor representing token indices
        - masked_indices (Tensor): tensor of indices indicating masked positions
        - temperature (float): divisor applied to the logits
        - top_k (int): no of top tokens to keep before the softmax; 0 disables filtering
        - mask_token_id (int): id of the mask token whose embedding is written
          at the masked positions (defaults to 6)

        Returns:
        - predicted_indices (Tensor): most probable token index (argmax) for each masked position
        """
        # only integer indices are returned, so no autograd graph is needed
        with torch.no_grad():
            x_final = self._run_layers(idx)

            # NOTE(review): the mask embedding overwrites the decoder *outputs*
            # at the masked positions, so the logits read there depend only on
            # that embedding and not on the surrounding context. Masking the
            # input tokens instead may be what was intended - confirm.
            x_masked = x_final.clone()
            mask_embedding = self.toked_model(torch.tensor([mask_token_id], device=idx.device))
            x_masked[masked_indices] = mask_embedding

            x_masked = self.norm_final(x_masked)
            logits = self.linear_final(x_masked)

            masked_logits = logits[masked_indices].view(-1, logits.size(-1))
            scaled_logits = masked_logits / temperature
            if top_k > 0:
                scaled_logits = self._top_k_filtering(scaled_logits, top_k)

            probs = F.softmax(scaled_logits, dim=-1)
            predicted_indices = torch.argmax(probs, dim=-1)

        return predicted_indices

    def _top_k_filtering(self, logits, top_k):
        """
        filter logits to keep only the top-k tokens

        Args:
        - logits (Tensor): input logits; filtering is done along the last dimension
        - top_k (int): no of top tokens to keep; values larger than the last
          dimension keep everything, values <= 0 disable filtering

        Returns:
        - filtered_logits (Tensor): logits with every entry below the k-th largest set to -inf
        """
        if top_k <= 0:
            return logits

        # torch.topk raises when k exceeds the dimension size, so clamp it
        top_k = min(top_k, logits.size(-1))
        values, _ = torch.topk(logits, top_k, dim=-1)
        # smallest of the kept values, kept as a trailing dim for broadcasting
        min_value = values[..., -1:]
        filtered_logits = logits.masked_fill(logits < min_value, float('-inf'))

        return filtered_logits
enigma/run.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ use this file to train the model
3
+
4
+ working:
5
+ - imports various dependencies first, and then loads the training data
6
+ - tokenizes it, per-character basis
7
+ - loads the required hyper-parameters and the model file
8
+ - trains it till 'max_iters' and saves the model state, and generates outputs
9
+
10
+ with the current configuration, the model can reach up to ~60 million parameters
11
+ and can become ~99% accurate with next token prediction
12
+ """
13
+
14
+ import torch
15
+ import json
16
+ import os
17
+ current_directory = os.path.dirname(os.path.abspath(__file__))
18
+ os.chdir(current_directory)
19
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
20
+
21
# load the raw training text; the path is relative to this script's directory
# because of the os.chdir call above
with open('../parquet files/new_dna.txt', 'r', encoding='utf-8') as file:
    captions = file.read()

print(f"{(len(captions)/1e6):.2f} million letters")

# NOTE(review): a relative import only resolves when this file is executed as
# part of a package (e.g. with `python -m`); running it directly as a script
# raises ImportError - confirm how this script is meant to be launched
from ..tokenizer import PerCharTokenizer

tokenizer = PerCharTokenizer()
vocab_size = tokenizer.vocab_size
# Train and test splits
data = torch.tensor(tokenizer.encode(captions), dtype=torch.long)
n = int(0.9*len(data))  # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# the config path is relative to this script's directory, matching how the
# model module opens the same file
with open('config_enigma.json', 'r', encoding='utf-8') as file:
    params = json.load(file)

# required parameters
batch_size = params['batch_size']
block_size = params['block_size']
max_iters = params['max_iters']
eval_interval = params['eval_interval']
eval_iters = params['eval_iters']
learning_rate = params['learning_rate']

# fixed seed so batch sampling and weight init are repeatable
torch.manual_seed(1400)
48
+ # data loading
49
def get_batch(split):
    """
    sample a random batch of input/target windows from one data split

    Args:
    - split (str): 'train' selects the training data; any other value selects
      the validation data

    Returns:
    - x (Tensor): batch_size windows of block_size tokens, moved to `device`
    - y (Tensor): the same windows shifted one token to the right (next-token targets)
    """
    source = val_data if split != 'train' else train_data
    # random start offsets, leaving room for a full window plus the shifted target
    starts = torch.randint(len(source) - block_size, (batch_size,))

    inputs, labels = [], []
    for start in starts:
        inputs.append(source[start:start + block_size])
        labels.append(source[start + 1:start + block_size + 1])

    x = torch.stack(inputs).to(device)
    y = torch.stack(labels).to(device)
    return x, y
57
+
58
@torch.no_grad()
def estimate_loss():
    """
    average the model's loss over eval_iters random batches for each split

    The model is put in eval mode while measuring and back in train mode afterwards.

    Returns:
    - out (dict): maps 'train' and 'val' to the mean loss (a 0-dim tensor)
    """
    def _mean_loss(split):
        # one loss value per sampled batch, averaged at the end
        per_batch = torch.zeros(eval_iters)
        for step in range(eval_iters):
            xb, yb = get_batch(split)
            _, batch_loss = model(xb, yb)
            per_batch[step] = batch_loss.item()
        return per_batch.mean()

    model.eval()
    out = {split: _mean_loss(split) for split in ['train', 'val']}
    model.train()
    return out
71
+
72
from model import Transformer
model = Transformer(vocab_size=vocab_size)
m = model.to(device)

# no of parameters, in millions
n_param = sum(p.numel() for p in m.parameters())/1e6
print(f"{n_param:.2f} million")
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# history of evaluation steps and the losses measured at each of them
steps = []
train_losses = []
val_losses = []

# `step` rather than `iter`, so the builtin iter() is not shadowed
for step in range(max_iters):

    # evaluate periodically and always on the very last iteration
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        steps.append(step)
        train_losses.append(losses['train'])
        val_losses.append(losses['val'])

    # one optimisation step on a fresh training batch
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# the checkpoint name carries the parameter count rounded to whole millions
torch.save(model.state_dict(), f'enigma_{n_param:.0f}m.pth')