AYYasaswini committed on
Commit
ccdb496
·
verified ·
1 Parent(s): 606f11c

Upload gpt_dev.ipynb

Files changed (1)
  1. gpt_dev.ipynb +1556 -0
gpt_dev.ipynb ADDED
@@ -0,0 +1,1556 @@
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ }
15
+ },
16
+ "cells": [
17
+ {
18
+ "cell_type": "markdown",
19
+ "source": [
20
+ "## Building a GPT\n",
21
+ "\n",
22
+ "Companion notebook to the [Zero To Hero](https://karpathy.ai/zero-to-hero.html) video on GPT."
23
+ ],
24
+ "metadata": {
25
+ "id": "wJpXpmjEYC_T"
26
+ }
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": 3,
31
+ "metadata": {
32
+ "colab": {
33
+ "base_uri": "https://localhost:8080/"
34
+ },
35
+ "id": "h5hjCcLDr2WC",
36
+ "outputId": "24b008b5-5eb3-4882-a553-1ef45aaaf782"
37
+ },
38
+ "outputs": [
39
+ {
40
+ "output_type": "stream",
41
+ "name": "stdout",
42
+ "text": [
43
+ "--2024-06-11 13:37:04-- https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n",
44
+ "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
45
+ "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
46
+ "HTTP request sent, awaiting response... 200 OK\n",
47
+ "Length: 1115394 (1.1M) [text/plain]\n",
48
+ "Saving to: ‘input.txt.1’\n",
49
+ "\n",
50
+ "\rinput.txt.1 0%[ ] 0 --.-KB/s \rinput.txt.1 100%[===================>] 1.06M --.-KB/s in 0.05s \n",
51
+ "\n",
52
+ "2024-06-11 13:37:04 (21.7 MB/s) - ‘input.txt.1’ saved [1115394/1115394]\n",
53
+ "\n"
54
+ ]
55
+ }
56
+ ],
57
+ "source": [
58
+ "# We always start with a dataset to train on. Let's download the tiny shakespeare dataset\n",
59
+ "!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "source": [
65
+ "# read it in to inspect it\n",
66
+ "with open('input.txt', 'r', encoding='utf-8') as f:\n",
67
+ " text = f.read()"
68
+ ],
69
+ "metadata": {
70
+ "id": "O6medjfRsLD9"
71
+ },
72
+ "execution_count": 4,
73
+ "outputs": []
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "source": [
78
+ "print(\"length of dataset in characters: \", len(text))"
79
+ ],
80
+ "metadata": {
81
+ "colab": {
82
+ "base_uri": "https://localhost:8080/"
83
+ },
84
+ "id": "6xWI_VyAsN8F",
85
+ "outputId": "68d2ea04-26cd-4ce8-f31e-10868b38f7d0"
86
+ },
87
+ "execution_count": 5,
88
+ "outputs": [
89
+ {
90
+ "output_type": "stream",
91
+ "name": "stdout",
92
+ "text": [
93
+ "length of dataset in characters: 1115394\n"
94
+ ]
95
+ }
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "source": [
101
+ "# let's look at the first 1000 characters\n",
102
+ "print(text[:1000])"
103
+ ],
104
+ "metadata": {
105
+ "colab": {
106
+ "base_uri": "https://localhost:8080/"
107
+ },
108
+ "id": "2c5V0FvqseE0",
109
+ "outputId": "5306e25a-cad6-4ac6-9d34-8138bbaa34a4"
110
+ },
111
+ "execution_count": 6,
112
+ "outputs": [
113
+ {
114
+ "output_type": "stream",
115
+ "name": "stdout",
116
+ "text": [
117
+ "First Citizen:\n",
118
+ "Before we proceed any further, hear me speak.\n",
119
+ "\n",
120
+ "All:\n",
121
+ "Speak, speak.\n",
122
+ "\n",
123
+ "First Citizen:\n",
124
+ "You are all resolved rather to die than to famish?\n",
125
+ "\n",
126
+ "All:\n",
127
+ "Resolved. resolved.\n",
128
+ "\n",
129
+ "First Citizen:\n",
130
+ "First, you know Caius Marcius is chief enemy to the people.\n",
131
+ "\n",
132
+ "All:\n",
133
+ "We know't, we know't.\n",
134
+ "\n",
135
+ "First Citizen:\n",
136
+ "Let us kill him, and we'll have corn at our own price.\n",
137
+ "Is't a verdict?\n",
138
+ "\n",
139
+ "All:\n",
140
+ "No more talking on't; let it be done: away, away!\n",
141
+ "\n",
142
+ "Second Citizen:\n",
143
+ "One word, good citizens.\n",
144
+ "\n",
145
+ "First Citizen:\n",
146
+ "We are accounted poor citizens, the patricians good.\n",
147
+ "What authority surfeits on would relieve us: if they\n",
148
+ "would yield us but the superfluity, while it were\n",
149
+ "wholesome, we might guess they relieved us humanely;\n",
150
+ "but they think we are too dear: the leanness that\n",
151
+ "afflicts us, the object of our misery, is as an\n",
152
+ "inventory to particularise their abundance; our\n",
153
+ "sufferance is a gain to them Let us revenge this with\n",
154
+ "our pikes, ere we become rakes: for the gods know I\n",
155
+ "speak this in hunger for bread, not in thirst for revenge.\n",
156
+ "\n",
157
+ "\n"
158
+ ]
159
+ }
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "source": [
165
+ "# here are all the unique characters that occur in this text\n",
166
+ "chars = sorted(list(set(text)))\n",
167
+ "vocab_size = len(chars)\n",
168
+ "print(''.join(chars))\n",
169
+ "print(vocab_size)"
170
+ ],
171
+ "metadata": {
172
+ "colab": {
173
+ "base_uri": "https://localhost:8080/"
174
+ },
175
+ "id": "0e-Rbyr8sfM8",
176
+ "outputId": "3cfb92f5-e9dc-4a4d-bc24-01c34e91fe2c"
177
+ },
178
+ "execution_count": 7,
179
+ "outputs": [
180
+ {
181
+ "output_type": "stream",
182
+ "name": "stdout",
183
+ "text": [
184
+ "\n",
185
+ " !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\n",
186
+ "65\n"
187
+ ]
188
+ }
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "source": [
194
+ "# create a mapping from characters to integers\n",
195
+ "stoi = { ch:i for i,ch in enumerate(chars) }\n",
196
+ "itos = { i:ch for i,ch in enumerate(chars) }\n",
197
+ "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n",
198
+ "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n",
199
+ "\n",
200
+ "print(encode(\"hii there\"))\n",
201
+ "print(decode(encode(\"hii there\")))"
202
+ ],
203
+ "metadata": {
204
+ "colab": {
205
+ "base_uri": "https://localhost:8080/"
206
+ },
207
+ "id": "Yw1LKNCgwjj1",
208
+ "outputId": "b32844f8-99ed-4eb8-c569-06196f56051f"
209
+ },
210
+ "execution_count": 8,
211
+ "outputs": [
212
+ {
213
+ "output_type": "stream",
214
+ "name": "stdout",
215
+ "text": [
216
+ "[46, 47, 47, 1, 58, 46, 43, 56, 43]\n",
217
+ "hii there\n"
218
+ ]
219
+ }
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "code",
224
+ "source": [
225
+ "# let's now encode the entire text dataset and store it into a torch.Tensor\n",
226
+ "import torch # we use PyTorch: https://pytorch.org\n",
227
+ "data = torch.tensor(encode(text), dtype=torch.long)\n",
228
+ "print(data.shape, data.dtype)\n",
229
+ "print(data[:1000]) # the 1000 characters we looked at earlier will look like this to the GPT"
230
+ ],
231
+ "metadata": {
232
+ "colab": {
233
+ "base_uri": "https://localhost:8080/"
234
+ },
235
+ "id": "YJb0OXPwzvqg",
236
+ "outputId": "7081b874-3ef5-4e65-ee10-acbc24ac9f9b"
237
+ },
238
+ "execution_count": 9,
239
+ "outputs": [
240
+ {
241
+ "output_type": "stream",
242
+ "name": "stdout",
243
+ "text": [
244
+ "torch.Size([1115394]) torch.int64\n",
245
+ "tensor([18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 14, 43, 44,\n",
246
+ " 53, 56, 43, 1, 61, 43, 1, 54, 56, 53, 41, 43, 43, 42, 1, 39, 52, 63,\n",
247
+ " 1, 44, 59, 56, 58, 46, 43, 56, 6, 1, 46, 43, 39, 56, 1, 51, 43, 1,\n",
248
+ " 57, 54, 43, 39, 49, 8, 0, 0, 13, 50, 50, 10, 0, 31, 54, 43, 39, 49,\n",
249
+ " 6, 1, 57, 54, 43, 39, 49, 8, 0, 0, 18, 47, 56, 57, 58, 1, 15, 47,\n",
250
+ " 58, 47, 64, 43, 52, 10, 0, 37, 53, 59, 1, 39, 56, 43, 1, 39, 50, 50,\n",
251
+ " 1, 56, 43, 57, 53, 50, 60, 43, 42, 1, 56, 39, 58, 46, 43, 56, 1, 58,\n",
252
+ " 53, 1, 42, 47, 43, 1, 58, 46, 39, 52, 1, 58, 53, 1, 44, 39, 51, 47,\n",
253
+ " 57, 46, 12, 0, 0, 13, 50, 50, 10, 0, 30, 43, 57, 53, 50, 60, 43, 42,\n",
254
+ " 8, 1, 56, 43, 57, 53, 50, 60, 43, 42, 8, 0, 0, 18, 47, 56, 57, 58,\n",
255
+ " 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 18, 47, 56, 57, 58, 6, 1, 63,\n",
256
+ " 53, 59, 1, 49, 52, 53, 61, 1, 15, 39, 47, 59, 57, 1, 25, 39, 56, 41,\n",
257
+ " 47, 59, 57, 1, 47, 57, 1, 41, 46, 47, 43, 44, 1, 43, 52, 43, 51, 63,\n",
258
+ " 1, 58, 53, 1, 58, 46, 43, 1, 54, 43, 53, 54, 50, 43, 8, 0, 0, 13,\n",
259
+ " 50, 50, 10, 0, 35, 43, 1, 49, 52, 53, 61, 5, 58, 6, 1, 61, 43, 1,\n",
260
+ " 49, 52, 53, 61, 5, 58, 8, 0, 0, 18, 47, 56, 57, 58, 1, 15, 47, 58,\n",
261
+ " 47, 64, 43, 52, 10, 0, 24, 43, 58, 1, 59, 57, 1, 49, 47, 50, 50, 1,\n",
262
+ " 46, 47, 51, 6, 1, 39, 52, 42, 1, 61, 43, 5, 50, 50, 1, 46, 39, 60,\n",
263
+ " 43, 1, 41, 53, 56, 52, 1, 39, 58, 1, 53, 59, 56, 1, 53, 61, 52, 1,\n",
264
+ " 54, 56, 47, 41, 43, 8, 0, 21, 57, 5, 58, 1, 39, 1, 60, 43, 56, 42,\n",
265
+ " 47, 41, 58, 12, 0, 0, 13, 50, 50, 10, 0, 26, 53, 1, 51, 53, 56, 43,\n",
266
+ " 1, 58, 39, 50, 49, 47, 52, 45, 1, 53, 52, 5, 58, 11, 1, 50, 43, 58,\n",
267
+ " 1, 47, 58, 1, 40, 43, 1, 42, 53, 52, 43, 10, 1, 39, 61, 39, 63, 6,\n",
268
+ " 1, 39, 61, 39, 63, 2, 0, 0, 31, 43, 41, 53, 52, 42, 1, 15, 47, 58,\n",
269
+ " 47, 64, 43, 52, 10, 0, 27, 52, 43, 1, 61, 53, 56, 42, 6, 1, 45, 53,\n",
270
+ " 53, 42, 1, 41, 47, 58, 47, 64, 43, 52, 57, 8, 0, 0, 18, 47, 56, 57,\n",
271
+ " 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 35, 43, 1, 39, 56, 43, 1,\n",
272
+ " 39, 41, 41, 53, 59, 52, 58, 43, 42, 1, 54, 53, 53, 56, 1, 41, 47, 58,\n",
273
+ " 47, 64, 43, 52, 57, 6, 1, 58, 46, 43, 1, 54, 39, 58, 56, 47, 41, 47,\n",
274
+ " 39, 52, 57, 1, 45, 53, 53, 42, 8, 0, 35, 46, 39, 58, 1, 39, 59, 58,\n",
275
+ " 46, 53, 56, 47, 58, 63, 1, 57, 59, 56, 44, 43, 47, 58, 57, 1, 53, 52,\n",
276
+ " 1, 61, 53, 59, 50, 42, 1, 56, 43, 50, 47, 43, 60, 43, 1, 59, 57, 10,\n",
277
+ " 1, 47, 44, 1, 58, 46, 43, 63, 0, 61, 53, 59, 50, 42, 1, 63, 47, 43,\n",
278
+ " 50, 42, 1, 59, 57, 1, 40, 59, 58, 1, 58, 46, 43, 1, 57, 59, 54, 43,\n",
279
+ " 56, 44, 50, 59, 47, 58, 63, 6, 1, 61, 46, 47, 50, 43, 1, 47, 58, 1,\n",
280
+ " 61, 43, 56, 43, 0, 61, 46, 53, 50, 43, 57, 53, 51, 43, 6, 1, 61, 43,\n",
281
+ " 1, 51, 47, 45, 46, 58, 1, 45, 59, 43, 57, 57, 1, 58, 46, 43, 63, 1,\n",
282
+ " 56, 43, 50, 47, 43, 60, 43, 42, 1, 59, 57, 1, 46, 59, 51, 39, 52, 43,\n",
283
+ " 50, 63, 11, 0, 40, 59, 58, 1, 58, 46, 43, 63, 1, 58, 46, 47, 52, 49,\n",
284
+ " 1, 61, 43, 1, 39, 56, 43, 1, 58, 53, 53, 1, 42, 43, 39, 56, 10, 1,\n",
285
+ " 58, 46, 43, 1, 50, 43, 39, 52, 52, 43, 57, 57, 1, 58, 46, 39, 58, 0,\n",
286
+ " 39, 44, 44, 50, 47, 41, 58, 57, 1, 59, 57, 6, 1, 58, 46, 43, 1, 53,\n",
287
+ " 40, 48, 43, 41, 58, 1, 53, 44, 1, 53, 59, 56, 1, 51, 47, 57, 43, 56,\n",
288
+ " 63, 6, 1, 47, 57, 1, 39, 57, 1, 39, 52, 0, 47, 52, 60, 43, 52, 58,\n",
289
+ " 53, 56, 63, 1, 58, 53, 1, 54, 39, 56, 58, 47, 41, 59, 50, 39, 56, 47,\n",
290
+ " 57, 43, 1, 58, 46, 43, 47, 56, 1, 39, 40, 59, 52, 42, 39, 52, 41, 43,\n",
291
+ " 11, 1, 53, 59, 56, 0, 57, 59, 44, 44, 43, 56, 39, 52, 41, 43, 1, 47,\n",
292
+ " 57, 1, 39, 1, 45, 39, 47, 52, 1, 58, 53, 1, 58, 46, 43, 51, 1, 24,\n",
293
+ " 43, 58, 1, 59, 57, 1, 56, 43, 60, 43, 52, 45, 43, 1, 58, 46, 47, 57,\n",
294
+ " 1, 61, 47, 58, 46, 0, 53, 59, 56, 1, 54, 47, 49, 43, 57, 6, 1, 43,\n",
295
+ " 56, 43, 1, 61, 43, 1, 40, 43, 41, 53, 51, 43, 1, 56, 39, 49, 43, 57,\n",
296
+ " 10, 1, 44, 53, 56, 1, 58, 46, 43, 1, 45, 53, 42, 57, 1, 49, 52, 53,\n",
297
+ " 61, 1, 21, 0, 57, 54, 43, 39, 49, 1, 58, 46, 47, 57, 1, 47, 52, 1,\n",
298
+ " 46, 59, 52, 45, 43, 56, 1, 44, 53, 56, 1, 40, 56, 43, 39, 42, 6, 1,\n",
299
+ " 52, 53, 58, 1, 47, 52, 1, 58, 46, 47, 56, 57, 58, 1, 44, 53, 56, 1,\n",
300
+ " 56, 43, 60, 43, 52, 45, 43, 8, 0, 0])\n"
301
+ ]
302
+ }
303
+ ]
304
+ },
305
+ {
306
+ "cell_type": "code",
307
+ "source": [
308
+ "# Let's now split up the data into train and validation sets\n",
309
+ "n = int(0.9*len(data)) # first 90% will be train, rest val\n",
310
+ "train_data = data[:n]\n",
311
+ "val_data = data[n:]"
312
+ ],
313
+ "metadata": {
314
+ "id": "f_WIXqxz0lU5"
315
+ },
316
+ "execution_count": 10,
317
+ "outputs": []
318
+ },
319
+ {
320
+ "cell_type": "code",
321
+ "source": [
322
+ "block_size = 8\n",
323
+ "train_data[:block_size+1]"
324
+ ],
325
+ "metadata": {
326
+ "colab": {
327
+ "base_uri": "https://localhost:8080/"
328
+ },
329
+ "id": "TD5Bj8Y6IAD4",
330
+ "outputId": "44a45420-f035-40e7-a089-7685ca25d361"
331
+ },
332
+ "execution_count": 11,
333
+ "outputs": [
334
+ {
335
+ "output_type": "execute_result",
336
+ "data": {
337
+ "text/plain": [
338
+ "tensor([18, 47, 56, 57, 58, 1, 15, 47, 58])"
339
+ ]
340
+ },
341
+ "metadata": {},
342
+ "execution_count": 11
343
+ }
344
+ ]
345
+ },
346
+ {
347
+ "cell_type": "code",
348
+ "source": [
349
+ "x = train_data[:block_size]\n",
350
+ "y = train_data[1:block_size+1]\n",
351
+ "for t in range(block_size):\n",
352
+ " context = x[:t+1]\n",
353
+ " target = y[t]\n",
354
+ " print(f\"when input is {context} the target: {target}\")"
355
+ ],
356
+ "metadata": {
357
+ "colab": {
358
+ "base_uri": "https://localhost:8080/"
359
+ },
360
+ "id": "9HXDe8vGJCEn",
361
+ "outputId": "96af3b4e-7307-4949-c0f9-05c892514196"
362
+ },
363
+ "execution_count": 12,
364
+ "outputs": [
365
+ {
366
+ "output_type": "stream",
367
+ "name": "stdout",
368
+ "text": [
369
+ "when input is tensor([18]) the target: 47\n",
370
+ "when input is tensor([18, 47]) the target: 56\n",
371
+ "when input is tensor([18, 47, 56]) the target: 57\n",
372
+ "when input is tensor([18, 47, 56, 57]) the target: 58\n",
373
+ "when input is tensor([18, 47, 56, 57, 58]) the target: 1\n",
374
+ "when input is tensor([18, 47, 56, 57, 58, 1]) the target: 15\n",
375
+ "when input is tensor([18, 47, 56, 57, 58, 1, 15]) the target: 47\n",
376
+ "when input is tensor([18, 47, 56, 57, 58, 1, 15, 47]) the target: 58\n"
377
+ ]
378
+ }
379
+ ]
380
+ },
381
+ {
382
+ "cell_type": "code",
383
+ "source": [
384
+ "torch.manual_seed(1337)\n",
385
+ "batch_size = 4 # how many independent sequences will we process in parallel?\n",
386
+ "block_size = 8 # what is the maximum context length for predictions?\n",
387
+ "\n",
388
+ "def get_batch(split):\n",
389
+ " # generate a small batch of data of inputs x and targets y\n",
390
+ " data = train_data if split == 'train' else val_data\n",
391
+ " ix = torch.randint(len(data) - block_size, (batch_size,))\n",
392
+ " x = torch.stack([data[i:i+block_size] for i in ix])\n",
393
+ " y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
394
+ " return x, y\n",
395
+ "\n",
396
+ "xb, yb = get_batch('train')\n",
397
+ "print('inputs:')\n",
398
+ "print(xb.shape)\n",
399
+ "print(xb)\n",
400
+ "print('targets:')\n",
401
+ "print(yb.shape)\n",
402
+ "print(yb)\n",
403
+ "\n",
404
+ "print('----')\n",
405
+ "\n",
406
+ "for b in range(batch_size): # batch dimension\n",
407
+ " for t in range(block_size): # time dimension\n",
408
+ " context = xb[b, :t+1]\n",
409
+ " target = yb[b,t]\n",
410
+ " print(f\"when input is {context.tolist()} the target: {target}\")"
411
+ ],
412
+ "metadata": {
413
+ "colab": {
414
+ "base_uri": "https://localhost:8080/"
415
+ },
416
+ "id": "Q3k1Czf7LuA9",
417
+ "outputId": "e7e206dc-1cae-4f95-a82d-5faa6fd1c627"
418
+ },
419
+ "execution_count": 13,
420
+ "outputs": [
421
+ {
422
+ "output_type": "stream",
423
+ "name": "stdout",
424
+ "text": [
425
+ "inputs:\n",
426
+ "torch.Size([4, 8])\n",
427
+ "tensor([[24, 43, 58, 5, 57, 1, 46, 43],\n",
428
+ " [44, 53, 56, 1, 58, 46, 39, 58],\n",
429
+ " [52, 58, 1, 58, 46, 39, 58, 1],\n",
430
+ " [25, 17, 27, 10, 0, 21, 1, 54]])\n",
431
+ "targets:\n",
432
+ "torch.Size([4, 8])\n",
433
+ "tensor([[43, 58, 5, 57, 1, 46, 43, 39],\n",
434
+ " [53, 56, 1, 58, 46, 39, 58, 1],\n",
435
+ " [58, 1, 58, 46, 39, 58, 1, 46],\n",
436
+ " [17, 27, 10, 0, 21, 1, 54, 39]])\n",
437
+ "----\n",
438
+ "when input is [24] the target: 43\n",
439
+ "when input is [24, 43] the target: 58\n",
440
+ "when input is [24, 43, 58] the target: 5\n",
441
+ "when input is [24, 43, 58, 5] the target: 57\n",
442
+ "when input is [24, 43, 58, 5, 57] the target: 1\n",
443
+ "when input is [24, 43, 58, 5, 57, 1] the target: 46\n",
444
+ "when input is [24, 43, 58, 5, 57, 1, 46] the target: 43\n",
445
+ "when input is [24, 43, 58, 5, 57, 1, 46, 43] the target: 39\n",
446
+ "when input is [44] the target: 53\n",
447
+ "when input is [44, 53] the target: 56\n",
448
+ "when input is [44, 53, 56] the target: 1\n",
449
+ "when input is [44, 53, 56, 1] the target: 58\n",
450
+ "when input is [44, 53, 56, 1, 58] the target: 46\n",
451
+ "when input is [44, 53, 56, 1, 58, 46] the target: 39\n",
452
+ "when input is [44, 53, 56, 1, 58, 46, 39] the target: 58\n",
453
+ "when input is [44, 53, 56, 1, 58, 46, 39, 58] the target: 1\n",
454
+ "when input is [52] the target: 58\n",
455
+ "when input is [52, 58] the target: 1\n",
456
+ "when input is [52, 58, 1] the target: 58\n",
457
+ "when input is [52, 58, 1, 58] the target: 46\n",
458
+ "when input is [52, 58, 1, 58, 46] the target: 39\n",
459
+ "when input is [52, 58, 1, 58, 46, 39] the target: 58\n",
460
+ "when input is [52, 58, 1, 58, 46, 39, 58] the target: 1\n",
461
+ "when input is [52, 58, 1, 58, 46, 39, 58, 1] the target: 46\n",
462
+ "when input is [25] the target: 17\n",
463
+ "when input is [25, 17] the target: 27\n",
464
+ "when input is [25, 17, 27] the target: 10\n",
465
+ "when input is [25, 17, 27, 10] the target: 0\n",
466
+ "when input is [25, 17, 27, 10, 0] the target: 21\n",
467
+ "when input is [25, 17, 27, 10, 0, 21] the target: 1\n",
468
+ "when input is [25, 17, 27, 10, 0, 21, 1] the target: 54\n",
469
+ "when input is [25, 17, 27, 10, 0, 21, 1, 54] the target: 39\n"
470
+ ]
471
+ }
472
+ ]
473
+ },
474
+ {
475
+ "cell_type": "code",
476
+ "source": [
477
+ "print(xb) # our input to the transformer"
478
+ ],
479
+ "metadata": {
480
+ "colab": {
481
+ "base_uri": "https://localhost:8080/"
482
+ },
483
+ "id": "qpyyAeIzQjlO",
484
+ "outputId": "febd3181-36c8-4567-f33c-dbfc4cbc99d5"
485
+ },
486
+ "execution_count": 14,
487
+ "outputs": [
488
+ {
489
+ "output_type": "stream",
490
+ "name": "stdout",
491
+ "text": [
492
+ "tensor([[24, 43, 58, 5, 57, 1, 46, 43],\n",
493
+ " [44, 53, 56, 1, 58, 46, 39, 58],\n",
494
+ " [52, 58, 1, 58, 46, 39, 58, 1],\n",
495
+ " [25, 17, 27, 10, 0, 21, 1, 54]])\n"
496
+ ]
497
+ }
498
+ ]
499
+ },
500
+ {
501
+ "cell_type": "code",
502
+ "source": [
503
+ "import torch\n",
504
+ "import torch.nn as nn\n",
505
+ "from torch.nn import functional as F\n",
506
+ "torch.manual_seed(1337)\n",
507
+ "\n",
508
+ "class BigramLanguageModel(nn.Module):\n",
509
+ "\n",
510
+ " def __init__(self, vocab_size):\n",
511
+ " super().__init__()\n",
512
+ " # each token directly reads off the logits for the next token from a lookup table\n",
513
+ " self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)\n",
514
+ "\n",
515
+ " def forward(self, idx, targets=None):\n",
516
+ "\n",
517
+ " # idx and targets are both (B,T) tensor of integers\n",
518
+ " logits = self.token_embedding_table(idx) # (B,T,C)\n",
519
+ "\n",
520
+ " if targets is None:\n",
521
+ " loss = None\n",
522
+ " else:\n",
523
+ " B, T, C = logits.shape\n",
524
+ " logits = logits.view(B*T, C)\n",
525
+ " targets = targets.view(B*T)\n",
526
+ " loss = F.cross_entropy(logits, targets)\n",
527
+ "\n",
528
+ " return logits, loss\n",
529
+ "\n",
530
+ " def generate(self, idx, max_new_tokens):\n",
531
+ " # idx is (B, T) array of indices in the current context\n",
532
+ " for _ in range(max_new_tokens):\n",
533
+ " # get the predictions\n",
534
+ " logits, loss = self(idx)\n",
535
+ " # focus only on the last time step\n",
536
+ " logits = logits[:, -1, :] # becomes (B, C)\n",
537
+ " # apply softmax to get probabilities\n",
538
+ " probs = F.softmax(logits, dim=-1) # (B, C)\n",
539
+ " # sample from the distribution\n",
540
+ " idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
541
+ " # append sampled index to the running sequence\n",
542
+ " idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
543
+ " return idx\n",
544
+ "\n",
545
+ "m = BigramLanguageModel(vocab_size)\n",
546
+ "logits, loss = m(xb, yb)\n",
547
+ "print(logits.shape)\n",
548
+ "print(loss)\n",
549
+ "\n",
550
+ "print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))\n"
551
+ ],
552
+ "metadata": {
553
+ "colab": {
554
+ "base_uri": "https://localhost:8080/"
555
+ },
556
+ "id": "nql_1ER53oCf",
557
+ "outputId": "7b1620c9-3bf2-45a2-8e08-d6ca73d09528"
558
+ },
559
+ "execution_count": 15,
560
+ "outputs": [
561
+ {
562
+ "output_type": "stream",
563
+ "name": "stdout",
564
+ "text": [
565
+ "torch.Size([32, 65])\n",
566
+ "tensor(4.8786, grad_fn=<NllLossBackward0>)\n",
567
+ "\n",
568
+ "Sr?qP-QWktXoL&jLDJgOLVz'RIoDqHdhsV&vLLxatjscMpwLERSPyao.qfzs$Ys$zF-w,;eEkzxjgCKFChs!iWW.ObzDnxA Ms$3\n"
569
+ ]
570
+ }
571
+ ]
572
+ },
573
+ {
574
+ "cell_type": "code",
575
+ "source": [
576
+ "# create a PyTorch optimizer\n",
577
+ "optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)"
578
+ ],
579
+ "metadata": {
580
+ "id": "eTyJ8qAaDdiF"
581
+ },
582
+ "execution_count": 16,
583
+ "outputs": []
584
+ },
585
+ {
586
+ "cell_type": "code",
587
+ "source": [
588
+ "batch_size = 32\n",
589
+ "for steps in range(100): # increase number of steps for good results...\n",
590
+ "\n",
591
+ " # sample a batch of data\n",
592
+ " xb, yb = get_batch('train')\n",
593
+ "\n",
594
+ " # evaluate the loss\n",
595
+ " logits, loss = m(xb, yb)\n",
596
+ " optimizer.zero_grad(set_to_none=True)\n",
597
+ " loss.backward()\n",
598
+ " optimizer.step()\n",
599
+ "\n",
600
+ "print(loss.item())\n"
601
+ ],
602
+ "metadata": {
603
+ "colab": {
604
+ "base_uri": "https://localhost:8080/"
605
+ },
606
+ "id": "Hs4kI8YdEkQj",
607
+ "outputId": "234b1d99-e1d5-4394-ca9a-964027301d48"
608
+ },
609
+ "execution_count": 17,
610
+ "outputs": [
611
+ {
612
+ "output_type": "stream",
613
+ "name": "stdout",
614
+ "text": [
615
+ "4.587916374206543\n"
616
+ ]
617
+ }
618
+ ]
619
+ },
620
+ {
621
+ "cell_type": "code",
622
+ "source": [
623
+ "print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)[0].tolist()))"
624
+ ],
625
+ "metadata": {
626
+ "colab": {
627
+ "base_uri": "https://localhost:8080/"
628
+ },
629
+ "id": "EcVIDWAZEtjN",
630
+ "outputId": "13e7e5a8-e382-4610-aecb-ce274d466533"
631
+ },
632
+ "execution_count": 18,
633
+ "outputs": [
634
+ {
635
+ "output_type": "stream",
636
+ "name": "stdout",
637
+ "text": [
638
+ "\n",
639
+ "xiKi-RJ:CgqVuUa!U?qMH.uk!sCuMXvv!CJFfx;LgRyJknOEti.?I&-gPlLyulId?XlaInQ'q,lT$\n",
640
+ "3Q&sGlvHQ?mqSq-eON\n",
641
+ "x?SP fUAfCAuCX:bOlgiRQWN:Mphaw\n",
642
+ "tRLKuYXEaAXxrcq-gCUzeh3w!AcyaylgYWjmJM?Uzw:inaY,:C&OECW:vmGGJAn3onAuMgia!ms$Vb q-gCOcPcUhOnxJGUGSPJWT:.?ujmJFoiNL&A'DxY,prZ?qdT;hoo'dHooXXlxf'WkHK&u3Q?rqUi.kz;?Yx?C&u3Qbfzxlyh'Vl:zyxjKXgC?\n",
643
+ "lv'QKFiBeviNxO'm!Upm$srm&TqViqiBD3HBP!juEOpmZJyF$Fwfy!PlvWPFC\n",
644
+ "&WDdP!Ko,px\n",
645
+ "x\n",
646
+ "tREOE;AJ.BeXkylOVD3KHp$e?nD,.SFbWWI'ubcL!q-tU;aXmJ&uGXHxJXI&Z!gHRpajj;l.\n",
647
+ "pTErIBjx;JKIgoCnLGXrJSP!AU-AcbczR?\n"
648
+ ]
649
+ }
650
+ ]
651
+ },
652
+ {
653
+ "cell_type": "markdown",
654
+ "source": [
655
+ "## The mathematical trick in self-attention"
656
+ ],
657
+ "metadata": {
658
+ "id": "XinV8nmAnmKN"
659
+ }
660
+ },
661
+ {
662
+ "cell_type": "code",
663
+ "source": [
664
+ "# toy example illustrating how matrix multiplication can be used for a \"weighted aggregation\"\n",
665
+ "torch.manual_seed(42)\n",
666
+ "a = torch.tril(torch.ones(3, 3))\n",
667
+ "a = a / torch.sum(a, 1, keepdim=True)\n",
668
+ "b = torch.randint(0,10,(3,2)).float()\n",
669
+ "c = a @ b\n",
670
+ "print('a=')\n",
671
+ "print(a)\n",
672
+ "print('--')\n",
673
+ "print('b=')\n",
674
+ "print(b)\n",
675
+ "print('--')\n",
676
+ "print('c=')\n",
677
+ "print(c)"
678
+ ],
679
+ "metadata": {
680
+ "colab": {
681
+ "base_uri": "https://localhost:8080/"
682
+ },
683
+ "id": "tukiH-NbRBhA",
684
+ "outputId": "4de5f70a-e12c-4c6a-d591-5d0720e9de8c"
685
+ },
686
+ "execution_count": 19,
687
+ "outputs": [
688
+ {
689
+ "output_type": "stream",
690
+ "name": "stdout",
691
+ "text": [
692
+ "a=\n",
693
+ "tensor([[1.0000, 0.0000, 0.0000],\n",
694
+ " [0.5000, 0.5000, 0.0000],\n",
695
+ " [0.3333, 0.3333, 0.3333]])\n",
696
+ "--\n",
697
+ "b=\n",
698
+ "tensor([[2., 7.],\n",
699
+ " [6., 4.],\n",
700
+ " [6., 5.]])\n",
701
+ "--\n",
702
+ "c=\n",
703
+ "tensor([[2.0000, 7.0000],\n",
704
+ " [4.0000, 5.5000],\n",
705
+ " [4.6667, 5.3333]])\n"
706
+ ]
707
+ }
708
+ ]
709
+ },
710
+ {
711
+ "cell_type": "code",
712
+ "source": [
713
+ "# consider the following toy example:\n",
714
+ "\n",
715
+ "torch.manual_seed(1337)\n",
716
+ "B,T,C = 4,8,2 # batch, time, channels\n",
717
+ "x = torch.randn(B,T,C)\n",
718
+ "x.shape"
719
+ ],
720
+ "metadata": {
721
+ "colab": {
722
+ "base_uri": "https://localhost:8080/"
723
+ },
724
+ "id": "Hs_E24uRE8kr",
725
+ "outputId": "f1591218-d10f-420e-8d5a-456a0f90aed9"
726
+ },
727
+ "execution_count": 20,
728
+ "outputs": [
729
+ {
730
+ "output_type": "execute_result",
731
+ "data": {
732
+ "text/plain": [
733
+ "torch.Size([4, 8, 2])"
734
+ ]
735
+ },
736
+ "metadata": {},
737
+ "execution_count": 20
738
+ }
739
+ ]
740
+ },
741
+ {
742
+ "cell_type": "code",
743
+ "source": [
744
+ "# We want x[b,t] = mean_{i<=t} x[b,i]\n",
745
+ "xbow = torch.zeros((B,T,C))\n",
746
+ "for b in range(B):\n",
747
+ " for t in range(T):\n",
748
+ " xprev = x[b,:t+1] # (t,C)\n",
749
+ " xbow[b,t] = torch.mean(xprev, 0)\n"
750
+ ],
751
+ "metadata": {
752
+ "id": "86NuXX0fn7ps"
753
+ },
754
+ "execution_count": 21,
755
+ "outputs": []
756
+ },
757
+ {
758
+ "cell_type": "code",
759
+ "source": [
760
+ "# version 2: using matrix multiply for a weighted aggregation\n",
761
+ "wei = torch.tril(torch.ones(T, T))\n",
762
+ "wei = wei / wei.sum(1, keepdim=True)\n",
763
+ "xbow2 = wei @ x # (B, T, T) @ (B, T, C) ----> (B, T, C)\n",
764
+ "torch.allclose(xbow, xbow2)"
765
+ ],
766
+ "metadata": {
767
+ "colab": {
768
+ "base_uri": "https://localhost:8080/"
769
+ },
770
+ "id": "yhdOAd6-wXkZ",
771
+ "outputId": "c7313d9b-d406-46ce-e2cd-f28c10ef41c2"
772
+ },
773
+ "execution_count": 22,
774
+ "outputs": [
775
+ {
776
+ "output_type": "execute_result",
777
+ "data": {
778
+ "text/plain": [
779
+ "False"
780
+ ]
781
+ },
782
+ "metadata": {},
783
+ "execution_count": 22
784
+ }
785
+ ]
786
+ },
787
+ {
788
+ "cell_type": "code",
789
+ "source": [
790
+ "# version 3: use Softmax\n",
791
+ "tril = torch.tril(torch.ones(T, T))\n",
792
+ "wei = torch.zeros((T,T))\n",
793
+ "wei = wei.masked_fill(tril == 0, float('-inf'))\n",
794
+ "wei = F.softmax(wei, dim=-1)\n",
795
+ "xbow3 = wei @ x\n",
796
+ "torch.allclose(xbow, xbow3)\n"
797
+ ],
798
+ "metadata": {
799
+ "colab": {
800
+ "base_uri": "https://localhost:8080/"
801
+ },
802
+ "id": "wOURrfG-ysoL",
+ "outputId": "40a4a993-5a9b-419c-e558-b935fd843dbf"
+ },
+ "execution_count": 23,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "False"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 23
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# version 4: self-attention!\n",
+ "torch.manual_seed(1337)\n",
+ "B,T,C = 4,8,32 # batch, time, channels\n",
+ "x = torch.randn(B,T,C)\n",
+ "\n",
+ "# let's see a single Head perform self-attention\n",
+ "head_size = 16\n",
+ "key = nn.Linear(C, head_size, bias=False)\n",
+ "query = nn.Linear(C, head_size, bias=False)\n",
+ "value = nn.Linear(C, head_size, bias=False)\n",
+ "k = key(x)   # (B, T, 16)\n",
+ "q = query(x) # (B, T, 16)\n",
+ "wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)\n",
+ "\n",
+ "tril = torch.tril(torch.ones(T, T))\n",
+ "#wei = torch.zeros((T,T))\n",
+ "wei = wei.masked_fill(tril == 0, float('-inf'))\n",
+ "wei = F.softmax(wei, dim=-1)\n",
+ "\n",
+ "v = value(x)\n",
+ "out = wei @ v\n",
+ "#out = wei @ x\n",
+ "\n",
+ "out.shape"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "EDarxEWIRMKq",
+ "outputId": "6fee2aa4-4ab6-4d89-c8ca-7463ee54962b"
+ },
+ "execution_count": 24,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "torch.Size([4, 8, 16])"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 24
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "wei[0]"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "vT1hdtzXCjgL",
+ "outputId": "c664020c-c9dd-4c85-84a4-fae0320453f8"
+ },
+ "execution_count": 25,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ "        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ "        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ "        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ "        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],\n",
+ "        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],\n",
+ "        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],\n",
+ "        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],\n",
+ "       grad_fn=<SelectBackward0>)"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 25
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Notes:\n",
+ "- Attention is a **communication mechanism**. It can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.\n",
+ "- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.\n",
+ "- Each example across the batch dimension is of course processed completely independently; examples never \"talk\" to each other.\n",
+ "- In an \"encoder\" attention block, just delete the single line that does masking with `tril`, allowing all tokens to communicate. The block here is called a \"decoder\" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.\n",
+ "- \"Self-attention\" just means that the keys and values are produced from the same source as the queries. In \"cross-attention\", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module).\n",
+ "- \"Scaled\" attention additionally divides `wei` by sqrt(head_size). This way, when the inputs Q,K are unit variance, `wei` will be unit variance too, and Softmax will stay diffuse and not saturate too much. Illustration below."
+ ],
+ "metadata": {
+ "id": "M5CvobiQ0pLr"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "k = torch.randn(B,T,head_size)\n",
+ "q = torch.randn(B,T,head_size)\n",
+ "wei = q @ k.transpose(-2, -1) * head_size**-0.5"
+ ],
+ "metadata": {
+ "id": "4SNbLq5z3oBw"
+ },
+ "execution_count": 26,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "k.var()"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Nl6I9n9IRTSo",
+ "outputId": "162aab09-b860-4b73-c0ae-394451367460"
+ },
+ "execution_count": 27,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "tensor(1.0449)"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 27
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "q.var()"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "T1tQx7oeRvtc",
+ "outputId": "20aacd2d-d414-4268-981e-86a5fd8afcc8"
+ },
+ "execution_count": 28,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "tensor(1.0700)"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 28
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "wei.var()"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "MLb_odHU3iKM",
+ "outputId": "5d6ca0fd-51df-42ec-daf8-7fb2ff9f640f"
+ },
+ "execution_count": 29,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "tensor(1.0918)"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 29
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "JB82yzt44REI",
+ "outputId": "df0211e7-a2b0-46c7-9fd2-c5a8cc185ed7"
+ },
+ "execution_count": 30,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 30
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])*8, dim=-1) # gets too peaky, converges to one-hot"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Mpt8569BB9_f",
+ "outputId": "cf991a1e-7072-4944-d578-886a270f57de"
+ },
+ "execution_count": 31,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 31
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "class LayerNorm1d: # (used to be BatchNorm1d)\n",
+ "\n",
+ "  def __init__(self, dim, eps=1e-5):\n",
+ "    self.eps = eps\n",
+ "    self.gamma = torch.ones(dim)\n",
+ "    self.beta = torch.zeros(dim)\n",
+ "\n",
+ "  def __call__(self, x):\n",
+ "    # calculate the forward pass\n",
+ "    xmean = x.mean(1, keepdim=True) # per-example mean over the features\n",
+ "    xvar = x.var(1, keepdim=True) # per-example variance over the features\n",
+ "    xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance\n",
+ "    self.out = self.gamma * xhat + self.beta\n",
+ "    return self.out\n",
+ "\n",
+ "  def parameters(self):\n",
+ "    return [self.gamma, self.beta]\n",
+ "\n",
+ "torch.manual_seed(1337)\n",
+ "module = LayerNorm1d(100)\n",
+ "x = torch.randn(32, 100) # batch size 32 of 100-dimensional vectors\n",
+ "x = module(x)\n",
+ "x.shape"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "2Num7sX9CKOH",
+ "outputId": "14c48660-c741-4cb8-ac79-53d2bf094a63"
+ },
+ "execution_count": 32,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "torch.Size([32, 100])"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 32
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "x[:,0].mean(), x[:,0].std() # mean,std of one feature across all batch inputs"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "633T2cmnW1uk",
+ "outputId": "2a6e887c-6b82-454f-8f32-aefde73777c5"
+ },
+ "execution_count": 33,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "(tensor(0.1469), tensor(0.8803))"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 33
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "x[0,:].mean(), x[0,:].std() # mean,std of a single input from the batch, of its features"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "LN9cK9BoXCYb",
+ "outputId": "4c81f68e-b1d2-4a04-d38d-09583f104ea7"
+ },
+ "execution_count": 34,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "(tensor(-9.5367e-09), tensor(1.0000))"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 34
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# French to English translation example:\n",
+ "\n",
+ "# <--------- ENCODE ------------------><--------------- DECODE ----------------->\n",
+ "# les réseaux de neurones sont géniaux! <START> neural networks are awesome!<END>\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "dRJH6wM_XFfU"
+ },
+ "execution_count": 35,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Full finished code, for reference\n",
+ "\n",
+ "You may want to refer directly to the git repo instead though."
+ ],
+ "metadata": {
+ "id": "ZcvKeBXoZFOY"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from torch.nn import functional as F\n",
+ "\n",
+ "# hyperparameters\n",
+ "batch_size = 16 # how many independent sequences will we process in parallel?\n",
+ "block_size = 32 # what is the maximum context length for predictions?\n",
+ "max_iters = 5000\n",
+ "eval_interval = 100\n",
+ "learning_rate = 1e-3\n",
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+ "eval_iters = 200\n",
+ "n_embd = 64\n",
+ "n_head = 4\n",
+ "n_layer = 4\n",
+ "dropout = 0.0\n",
+ "# ------------\n",
+ "\n",
+ "torch.manual_seed(1337)\n",
+ "\n",
+ "# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n",
+ "with open('input.txt', 'r', encoding='utf-8') as f:\n",
+ "    text = f.read()\n",
+ "\n",
+ "# here are all the unique characters that occur in this text\n",
+ "chars = sorted(list(set(text)))\n",
+ "vocab_size = len(chars)\n",
+ "# create a mapping from characters to integers\n",
+ "stoi = { ch:i for i,ch in enumerate(chars) }\n",
+ "itos = { i:ch for i,ch in enumerate(chars) }\n",
+ "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n",
+ "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n",
+ "\n",
+ "# Train and test splits\n",
+ "data = torch.tensor(encode(text), dtype=torch.long)\n",
+ "n = int(0.9*len(data)) # first 90% will be train, rest val\n",
+ "train_data = data[:n]\n",
+ "val_data = data[n:]\n",
+ "\n",
+ "# data loading\n",
+ "def get_batch(split):\n",
+ "    # generate a small batch of data of inputs x and targets y\n",
+ "    data = train_data if split == 'train' else val_data\n",
+ "    ix = torch.randint(len(data) - block_size, (batch_size,))\n",
+ "    x = torch.stack([data[i:i+block_size] for i in ix])\n",
+ "    y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
+ "    x, y = x.to(device), y.to(device)\n",
+ "    return x, y\n",
+ "\n",
+ "@torch.no_grad()\n",
+ "def estimate_loss():\n",
+ "    out = {}\n",
+ "    model.eval()\n",
+ "    for split in ['train', 'val']:\n",
+ "        losses = torch.zeros(eval_iters)\n",
+ "        for k in range(eval_iters):\n",
+ "            X, Y = get_batch(split)\n",
+ "            logits, loss = model(X, Y)\n",
+ "            losses[k] = loss.item()\n",
+ "        out[split] = losses.mean()\n",
+ "    model.train()\n",
+ "    return out\n",
+ "\n",
+ "class Head(nn.Module):\n",
+ "    \"\"\" one head of self-attention \"\"\"\n",
+ "\n",
+ "    def __init__(self, head_size):\n",
+ "        super().__init__()\n",
+ "        self.key = nn.Linear(n_embd, head_size, bias=False)\n",
+ "        self.query = nn.Linear(n_embd, head_size, bias=False)\n",
+ "        self.value = nn.Linear(n_embd, head_size, bias=False)\n",
+ "        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n",
+ "\n",
+ "        self.dropout = nn.Dropout(dropout)\n",
+ "\n",
+ "    def forward(self, x):\n",
+ "        B,T,C = x.shape\n",
+ "        k = self.key(x)   # (B,T,C)\n",
+ "        q = self.query(x) # (B,T,C)\n",
+ "        # compute attention scores (\"affinities\")\n",
+ "        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)\n",
+ "        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)\n",
+ "        wei = F.softmax(wei, dim=-1) # (B, T, T)\n",
+ "        wei = self.dropout(wei)\n",
+ "        # perform the weighted aggregation of the values\n",
+ "        v = self.value(x) # (B,T,C)\n",
+ "        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)\n",
+ "        return out\n",
+ "\n",
+ "class MultiHeadAttention(nn.Module):\n",
+ "    \"\"\" multiple heads of self-attention in parallel \"\"\"\n",
+ "\n",
+ "    def __init__(self, num_heads, head_size):\n",
+ "        super().__init__()\n",
+ "        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])\n",
+ "        self.proj = nn.Linear(n_embd, n_embd)\n",
+ "        self.dropout = nn.Dropout(dropout)\n",
+ "\n",
+ "    def forward(self, x):\n",
+ "        out = torch.cat([h(x) for h in self.heads], dim=-1)\n",
+ "        out = self.dropout(self.proj(out))\n",
+ "        return out\n",
+ "\n",
+ "class FeedForward(nn.Module):\n",
+ "    \"\"\" a simple linear layer followed by a non-linearity \"\"\"\n",
+ "\n",
+ "    def __init__(self, n_embd):\n",
+ "        super().__init__()\n",
+ "        self.net = nn.Sequential(\n",
+ "            nn.Linear(n_embd, 4 * n_embd),\n",
+ "            nn.ReLU(),\n",
+ "            nn.Linear(4 * n_embd, n_embd),\n",
+ "            nn.Dropout(dropout),\n",
+ "        )\n",
+ "\n",
+ "    def forward(self, x):\n",
+ "        return self.net(x)\n",
+ "\n",
+ "class Block(nn.Module):\n",
+ "    \"\"\" Transformer block: communication followed by computation \"\"\"\n",
+ "\n",
+ "    def __init__(self, n_embd, n_head):\n",
+ "        # n_embd: embedding dimension, n_head: the number of heads we'd like\n",
+ "        super().__init__()\n",
+ "        head_size = n_embd // n_head\n",
+ "        self.sa = MultiHeadAttention(n_head, head_size)\n",
+ "        self.ffwd = FeedForward(n_embd)\n",
+ "        self.ln1 = nn.LayerNorm(n_embd)\n",
+ "        self.ln2 = nn.LayerNorm(n_embd)\n",
+ "\n",
+ "    def forward(self, x):\n",
+ "        x = x + self.sa(self.ln1(x))\n",
+ "        x = x + self.ffwd(self.ln2(x))\n",
+ "        return x\n",
+ "\n",
+ "# super simple bigram model\n",
+ "class BigramLanguageModel(nn.Module):\n",
+ "\n",
+ "    def __init__(self):\n",
+ "        super().__init__()\n",
+ "        # each token directly reads off the logits for the next token from a lookup table\n",
+ "        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n",
+ "        self.position_embedding_table = nn.Embedding(block_size, n_embd)\n",
+ "        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])\n",
+ "        self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n",
+ "        self.lm_head = nn.Linear(n_embd, vocab_size)\n",
+ "\n",
+ "    def forward(self, idx, targets=None):\n",
+ "        B, T = idx.shape\n",
+ "\n",
+ "        # idx and targets are both (B,T) tensor of integers\n",
+ "        tok_emb = self.token_embedding_table(idx) # (B,T,C)\n",
+ "        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n",
+ "        x = tok_emb + pos_emb # (B,T,C)\n",
+ "        x = self.blocks(x) # (B,T,C)\n",
+ "        x = self.ln_f(x) # (B,T,C)\n",
+ "        logits = self.lm_head(x) # (B,T,vocab_size)\n",
+ "\n",
+ "        if targets is None:\n",
+ "            loss = None\n",
+ "        else:\n",
+ "            B, T, C = logits.shape\n",
+ "            logits = logits.view(B*T, C)\n",
+ "            targets = targets.view(B*T)\n",
+ "            loss = F.cross_entropy(logits, targets)\n",
+ "\n",
+ "        return logits, loss\n",
+ "\n",
+ "    def generate(self, idx, max_new_tokens):\n",
+ "        # idx is (B, T) array of indices in the current context\n",
+ "        for _ in range(max_new_tokens):\n",
+ "            # crop idx to the last block_size tokens\n",
+ "            idx_cond = idx[:, -block_size:]\n",
+ "            # get the predictions\n",
+ "            logits, loss = self(idx_cond)\n",
+ "            # focus only on the last time step\n",
+ "            logits = logits[:, -1, :] # becomes (B, C)\n",
+ "            # apply softmax to get probabilities\n",
+ "            probs = F.softmax(logits, dim=-1) # (B, C)\n",
+ "            # sample from the distribution\n",
+ "            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
+ "            # append sampled index to the running sequence\n",
+ "            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
+ "        return idx\n",
+ "\n",
+ "model = BigramLanguageModel()\n",
+ "m = model.to(device)\n",
+ "# print the number of parameters in the model\n",
+ "print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')\n",
+ "\n",
+ "# create a PyTorch optimizer\n",
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
+ "\n",
+ "for iter in range(max_iters):\n",
+ "\n",
+ "    # every once in a while evaluate the loss on train and val sets\n",
+ "    if iter % eval_interval == 0 or iter == max_iters - 1:\n",
+ "        losses = estimate_loss()\n",
+ "        print(f\"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n",
+ "\n",
+ "    # sample a batch of data\n",
+ "    xb, yb = get_batch('train')\n",
+ "\n",
+ "    # evaluate the loss\n",
+ "    logits, loss = model(xb, yb)\n",
+ "    optimizer.zero_grad(set_to_none=True)\n",
+ "    loss.backward()\n",
+ "    optimizer.step()\n",
+ "\n",
+ "# generate from the model\n",
+ "context = torch.zeros((1, 1), dtype=torch.long, device=device)\n",
+ "print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))\n"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "hoelkOrFY8bN",
+ "outputId": "4f7e6e13-879e-469d-dcdb-0d3c48e263c5"
+ },
+ "execution_count": 37,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "0.209729 M parameters\n",
+ "step 0: train loss 4.4116, val loss 4.4022\n",
+ "step 100: train loss 2.6568, val loss 2.6670\n",
+ "step 200: train loss 2.5090, val loss 2.5059\n",
+ "step 300: train loss 2.4196, val loss 2.4338\n",
+ "step 400: train loss 2.3504, val loss 2.3566\n",
+ "step 500: train loss 2.2965, val loss 2.3129\n",
+ "step 600: train loss 2.2410, val loss 2.2500\n",
+ "step 700: train loss 2.2057, val loss 2.2191\n",
+ "step 800: train loss 2.1633, val loss 2.1864\n",
+ "step 900: train loss 2.1244, val loss 2.1510\n",
+ "step 1000: train loss 2.1038, val loss 2.1308\n",
+ "step 1100: train loss 2.0707, val loss 2.1197\n",
+ "step 1200: train loss 2.0377, val loss 2.0800\n",
+ "step 1300: train loss 2.0268, val loss 2.0650\n",
+ "step 1400: train loss 1.9918, val loss 2.0356\n",
+ "step 1500: train loss 1.9697, val loss 2.0293\n",
+ "step 1600: train loss 1.9645, val loss 2.0499\n",
+ "step 1700: train loss 1.9404, val loss 2.0129\n",
+ "step 1800: train loss 1.9095, val loss 1.9951\n",
+ "step 1900: train loss 1.9067, val loss 1.9855\n",
+ "step 2000: train loss 1.8854, val loss 1.9948\n",
+ "step 2100: train loss 1.8727, val loss 1.9766\n",
+ "step 2200: train loss 1.8597, val loss 1.9631\n",
+ "step 2300: train loss 1.8530, val loss 1.9516\n",
+ "step 2400: train loss 1.8428, val loss 1.9464\n",
+ "step 2500: train loss 1.8161, val loss 1.9424\n",
+ "step 2600: train loss 1.8283, val loss 1.9406\n",
+ "step 2700: train loss 1.8101, val loss 1.9322\n",
+ "step 2800: train loss 1.8050, val loss 1.9233\n",
+ "step 2900: train loss 1.8033, val loss 1.9289\n",
+ "step 3000: train loss 1.7955, val loss 1.9216\n",
+ "step 3100: train loss 1.7697, val loss 1.9184\n",
+ "step 3200: train loss 1.7541, val loss 1.9088\n",
+ "step 3300: train loss 1.7567, val loss 1.9034\n",
+ "step 3400: train loss 1.7573, val loss 1.9000\n",
+ "step 3500: train loss 1.7398, val loss 1.8925\n",
+ "step 3600: train loss 1.7270, val loss 1.8869\n",
+ "step 3700: train loss 1.7283, val loss 1.8814\n",
+ "step 3800: train loss 1.7210, val loss 1.8918\n",
+ "step 3900: train loss 1.7219, val loss 1.8732\n",
+ "step 4000: train loss 1.7146, val loss 1.8576\n",
+ "step 4100: train loss 1.7136, val loss 1.8720\n",
+ "step 4200: train loss 1.7060, val loss 1.8653\n",
+ "step 4300: train loss 1.7032, val loss 1.8499\n",
+ "step 4400: train loss 1.7057, val loss 1.8656\n",
+ "step 4500: train loss 1.6907, val loss 1.8477\n",
+ "step 4600: train loss 1.6878, val loss 1.8371\n",
+ "step 4700: train loss 1.6808, val loss 1.8415\n",
+ "step 4800: train loss 1.6689, val loss 1.8457\n",
+ "step 4900: train loss 1.6716, val loss 1.8415\n",
+ "step 4999: train loss 1.6658, val loss 1.8275\n",
+ "\n",
+ "ROTCUMER:\n",
+ "Tyburforth, bloody,\n",
+ "WhIs migute: you duke I use list. WIthon of where's grande will! savist tought!\n",
+ "Why room upwor alond, liegle. I hone, Iell thou sudd have then strue thus mind,\n",
+ "His by blow, Virdom tow, glingien, yithre spees ssince them Those not.\n",
+ "\n",
+ "LUCIO:\n",
+ "Look,----\n",
+ "But thou sging them this my freceimmsed,\n",
+ "By thou sovor conursion that thou sade but grove\n",
+ "the tage encond:\n",
+ "It will Rament me; an your touther,\n",
+ "And havis like to-does, and little spright.\n",
+ "\n",
+ "GLOUCESTER:\n",
+ "Rewards thou for Panfessira's bigguards such ways!\n",
+ "What curfort his\n",
+ "will havolss you, as I have the cervirs arled,\n",
+ "Dear my love and pitace unto duly son.\n",
+ "\n",
+ "Secome:\n",
+ "Offolk, even thy whose my late all that you by jotly us belies!\n",
+ "Lord, we a-montencry! I\n",
+ "\n",
+ "SLARNE:\n",
+ "Day, mave from out prrive And orculing\n",
+ "What confess, temimelyour and stropt;\n",
+ "Secumfospet the gatieus I'll that confence-sting,\n",
+ "But; man't, Rolget\n",
+ "would garnion'd live in which, you, prothre?\n",
+ "\n",
+ "CORIOLANUS:\n",
+ "What bonum stravoing, not out be seemmed with\n",
+ "That the boly noll to.\n",
+ "Bently, which in on my not tomberven why, fortune,\n",
+ "And that wark you, banot thus orl'ld groves viles.\n",
+ "\n",
+ "PUMNIUS:\n",
+ "It thou addow less, proth-straing.\n",
+ "Mutwing your contrant stomfe, whom they\n",
+ "is by this famestle; and of the loves my not Mercarcious to the stord; thesoo, in thus my nome are:\n",
+ "Will fuch, have there enplience your gone, ho's,\n",
+ "And gentleman, my beged lind to be am\n",
+ "in That ant:\n",
+ "In I sugner murded! I play's,\n",
+ "If not sume the confity will reasur slord:\n",
+ "That get because at that his say\n",
+ "and to beepts guarst you lom if then.\n",
+ "\n",
+ "MENEN MARGARUS:\n",
+ "I but aftelence! made yoour never.\n",
+ "\n",
+ "KING RICHARD II:\n",
+ "Who too near?\n",
+ "\n",
+ "LORDIUS:\n",
+ "Or as madaw brird, tou thee?\n",
+ "\n",
+ "Sirightly the haste's beforempt.\n",
+ "\n",
+ "First:\n",
+ "Is though.\n",
+ "Fell, whose toes with requmpts, up I make\n",
+ "Here figUS verean that I will, by the wateon.\n",
+ "\n",
+ "MOWIDIUS:\n",
+ "How, while, more is in meep.\n",
+ "twan be the fless this countrens platcar merperter sure make Giventled,\n",
+ "At not your must to reason togs,\n",
+ "And what you gue;--\n",
+ "\n",
+ "RUKE ESFiren; gravent,\n",
+ "Apol\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [],
+ "metadata": {
+ "id": "fjjvMifYZf7x"
+ },
+ "execution_count": 36,
+ "outputs": []
+ }
+ ]
+ }