Added model and tokenizer

- .DS_Store +0 -0
- .gitattributes +1 -0
- .gitignore +3 -1
- app.py +18 -5
- model/.DS_Store +0 -0
- model/model_1000_.bin +3 -0
- my_gpt.py +186 -0
- tokenizer/__init__.py +0 -0
- tokenizer/__pycache__/__init__.cpython-38.pyc +0 -0
- tokenizer/__pycache__/base.cpython-311.pyc +0 -0
- tokenizer/__pycache__/base.cpython-38.pyc +0 -0
- tokenizer/__pycache__/base.cpython-39.pyc +0 -0
- tokenizer/__pycache__/tokenizer.cpython-38.pyc +0 -0
- tokenizer/__pycache__/tokenizer.cpython-39.pyc +0 -0
- tokenizer/base.py +110 -0
- tokenizer/tokenizer.py +72 -0
.DS_Store
ADDED
Binary file (6.15 kB)
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+model/model_1000_.bin filter=lfs diff=lfs merge=lfs -text
.gitignore
CHANGED
@@ -1 +1,3 @@
-*.txt
+*.txt
+tokenizer/README.md
+__pycache__
app.py
CHANGED
@@ -1,11 +1,24 @@
 import gradio as gr
-import
-
+import torch
+from my_gpt import my_gpt
+from tokenizer.tokenizer import BPE
 
+##Load model
+model = my_gpt.load_pretrained("model/model_1000_.bin")
+tokenizer = BPE()
 
 
-def greet(name):
-    return "Hello " + name + "!!"
 
-
+def generate(input_text):
+    tokens = tokenizer.encode(input_text)
+    gen_ids = model.generate(torch.tensor([tokens]))
+    output = tokenizer.decode(gen_ids[0].tolist())
+    return output
+
+iface = gr.Interface(fn=generate,
+                     inputs="text",
+                     outputs="text",
+                     title="GPT - 1000 steps",
+                     description="""This model is trained for 1000 steps only. It is not
+able to generate perfect sentences/words. However, it has learnt a gist of the English language""")
 iface.launch()
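One behavioral note on generate(): my_gpt.generate returns the prompt ids with the sampled ids appended, so the decoded output repeats the input text before the continuation. A hypothetical variant (not part of this commit) that strips the echoed prompt before decoding:

def generate_continuation_only(input_text):
    # hypothetical helper, same scope as app.py's generate (model, tokenizer, torch)
    tokens = tokenizer.encode(input_text)
    gen_ids = model.generate(torch.tensor([tokens]))
    new_ids = gen_ids[0].tolist()[len(tokens):]  # drop the echoed prompt ids
    return tokenizer.decode(new_ids)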
model/.DS_Store
ADDED
Binary file (6.15 kB)
model/model_1000_.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:566dd60c0869b306eda84fd1d64dbacce02b256492da9ea626d227b046125bd2
+size 56950645
my_gpt.py
ADDED
@@ -0,0 +1,186 @@
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+import json
+import logging
+
+
+block_size = 256
+vocab_size = 500
+n_embed = 384
+dropout = 0.2
+n_head = 6
+n_layer = 6
+
+class Head(nn.Module):
+    def __init__(self, head_size=16):
+        super().__init__()
+        self.query = nn.Linear(n_embed, head_size, bias=False)
+        self.key = nn.Linear(n_embed, head_size, bias=False)
+        self.value = nn.Linear(n_embed, head_size, bias=False)
+        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        B, T, C = x.shape
+
+        q = self.query(x)
+        k = self.key(x)
+
+        wei = (q @ k.transpose(-2, -1)) * (k.shape[-1] ** (-0.5))
+        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
+        wei = F.softmax(wei, dim=-1)
+        wei = self.dropout(wei)
+
+        v = self.value(x)
+
+        out = wei @ v  ## (B,T,HS)
+
+        return out
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, num_heads, head_size):
+        super().__init__()
+
+        self.heads = nn.ModuleList(Head(head_size=head_size) for _ in range(num_heads))
+        self.proj = nn.Linear(head_size * num_heads, n_embed)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        out = torch.cat([h(x) for h in self.heads], dim=-1)
+        out = self.dropout(self.proj(out))
+        return out
+
+class FeedForward(nn.Module):
+    def __init__(self, n_embed) -> None:
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(n_embed, 4 * n_embed),
+            nn.ReLU(),
+            nn.Linear(4 * n_embed, n_embed),
+            nn.Dropout(dropout),
+        )
+
+    def forward(self, x):
+        x = self.net(x)
+        return x
+
+class decoder_block(nn.Module):
+    def __init__(self, n_embed, n_heads):
+        super().__init__()
+        self.sa = MultiHeadAttention(n_heads, n_embed // n_heads)
+        self.ln1 = nn.LayerNorm(n_embed)
+        self.ln2 = nn.LayerNorm(n_embed)
+        self.ffwd = FeedForward(n_embed)
+
+    def forward(self, x):
+        x = x + self.sa(self.ln1(x))
+        x = x + self.ffwd(self.ln2(x))
+        return x
+
+
+class my_gpt(nn.Module):
+    def __init__(self, block_size=256):
+        super().__init__()
+        self.block_size = block_size  ## context window size
+        self.token_embed = nn.Embedding(vocab_size, n_embed)
+        self.pos_embed = nn.Embedding(vocab_size, n_embed)
+        self.lm_head = nn.Linear(n_embed, vocab_size)
+        self.sa_head = Head(vocab_size)
+        self.d_blocks = nn.Sequential(*[decoder_block(n_embed=n_embed, n_heads=n_head) for _ in range(n_layer)])
+        self.ln_f = nn.LayerNorm(n_embed)  # final layer norm
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+    def forward(self, idx, targets=None):
+        """
+        Args:
+            idx: int (B,T) token ids
+            targets: int (B,T) next-token targets, optional
+
+        Returns:
+            logits, loss
+        """
+        B, T = idx.shape
+        tok_emd = self.token_embed(idx)  ## (B,T,C)
+        pos_emd = self.pos_embed(idx)
+
+        x = tok_emd + pos_emd
+
+        x = self.d_blocks(x)
+        x = self.ln_f(x)  # (B,T,C)
+
+        logits = self.lm_head(x)  ## (B,T,vocab_size)
+        if targets is None:
+            loss = None
+        else:
+            B, T, C = logits.shape
+            logits = logits.view(B * T, C)
+            targets = targets.view(B * T)
+
+            loss = F.cross_entropy(logits, targets)
+
+        return logits, loss
+
+    def generate(self, context: torch.Tensor, max_new_tokens: int = 46, use_cache=False):
+        """
+        Generates the next "max_new_tokens" number of tokens.
+
+        Args:
+            context (B,T): token ids to condition on
+            max_new_tokens (int): number of tokens to sample
+
+        Returns:
+            (B, T + max_new_tokens) tensor of token ids, prompt included.
+        """
+        for _ in range(max_new_tokens):
+            ## Take only the last block_size tokens as context
+            idx_tokens = context[:, -self.block_size:]
+
+            ## generate logits for the next token
+            logits, loss = self(idx_tokens)
+
+            ## keep only the logits at the last position
+            logits = logits[:, -1, :]  ## (B,vocab_size)
+
+            probs = F.softmax(logits, dim=-1)
+            idx_next = torch.multinomial(probs, num_samples=1)  ## (B,1)
+
+            context = torch.concatenate([context, idx_next], dim=1)
+
+        return context
+
+    def save_pretrained(self, path):
+        torch.save(self.state_dict(), path)
+        print("Saved pretrained Successfully")
+
+    @classmethod
+    def load_pretrained(cls, path):
+        print("Loading pretrained model...")
+        model = cls()
+        model.load_state_dict(torch.load(path))
+        return model
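Two reading notes on my_gpt. First, pos_embed is an nn.Embedding(vocab_size, n_embed) indexed with the token ids rather than with positions, so it effectively acts as a second token embedding; the shipped checkpoint was trained that way, so nothing below changes it. Second, inference reduces to a few lines; a minimal sketch mirroring app.py (the eval() call is an assumption on my part, it disables dropout during sampling and is not in app.py):

import torch
from my_gpt import my_gpt
from tokenizer.tokenizer import BPE

model = my_gpt.load_pretrained("model/model_1000_.bin")
model.eval()  # assumption: disable dropout for sampling; app.py skips this step
tokenizer = BPE()

ids = tokenizer.encode("Once upon a time")
out = model.generate(torch.tensor([ids]), max_new_tokens=46)
print(tokenizer.decode(out[0].tolist()))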
tokenizer/__init__.py
ADDED
File without changes
tokenizer/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (186 Bytes)

tokenizer/__pycache__/base.cpython-311.pyc
ADDED
Binary file (4.35 kB)

tokenizer/__pycache__/base.cpython-38.pyc
ADDED
Binary file (2.8 kB)

tokenizer/__pycache__/base.cpython-39.pyc
ADDED
Binary file (2.02 kB)

tokenizer/__pycache__/tokenizer.cpython-38.pyc
ADDED
Binary file (2.34 kB)

tokenizer/__pycache__/tokenizer.cpython-39.pyc
ADDED
Binary file (1.97 kB)
tokenizer/base.py
ADDED
@@ -0,0 +1,110 @@
+import json
+import sys
+import os
+
+
+sys.path.append("./")
+
+def render_token(t: bytes) -> str:
+    # pretty print a token (control-character escaping is left out here)
+    s = t.decode('utf-8', errors='replace')
+    return s
+
+
+def get_freq_pairs(inp_toks):
+    """Returns a count of the adjacent pairs."""
+    count = {}
+    for pair in zip(inp_toks, inp_toks[1:]):
+        count[pair] = count.get(pair, 0) + 1
+    return count
+
+
+def merge(id_list, pair, replace_with_idx):
+    """
+    Replace each occurrence of 'pair' in 'id_list' with 'replace_with_idx'.
+
+    id_list : list of tokens
+    pair : tuple of 2 token ids
+    replace_with_idx : int value
+
+    Returns a new list with the pair replaced.
+    """
+    i = 0
+    new_ids_list = []
+    while i < len(id_list):
+        if i < len(id_list) - 1 and id_list[i] == pair[0] and id_list[i + 1] == pair[1]:
+            new_ids_list.append(replace_with_idx)
+            i += 2
+        else:
+            new_ids_list.append(id_list[i])
+            i += 1
+
+    return new_ids_list
+
+class Tokenizer():
+    def __init__(self):
+        self.merges = {}
+        ## vocab -> (int) : bytes, for all ints (0-255, 256+ from new merges)
+        self.vocab = {}
+        self.load()
+
+    def save(self):
+        with open('merges.txt', 'w') as f:
+            ## Write only the pairs, not the index of the merged pairs.
+            ## When the tokenizer is loaded, allow the user to specify the index.
+            for p1, p2 in self.merges.keys():
+                f.write(f"{p1} {p2}\n")
+
+        with open('vocab.txt', 'w') as f:
+            for idx, byte in self.vocab.items():
+                s = render_token(byte)
+                f.write(f"{idx} {s}\n")
+
+    def _build_vocab(self):
+        self.vocab = {idx: bytes([idx]) for idx in range(256)}
+        try:
+            for (tok0, tok1), idx in self.merges.items():
+                self.vocab[idx] = self.vocab[tok0] + self.vocab[tok1]
+        except Exception as e:
+            print(e)
+
+    def load(self):
+        try:
+            with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'merges.txt'), 'r') as file:
+                idx = 256
+                for line in file:
+                    tok0, tok1 = map(int, line.split())
+                    self.merges[(tok0, tok1)] = idx
+                    idx += 1
+        except Exception as e:
+            print(e)
+
+        ## Build the vocab even if merges.txt is missing,
+        ## so the 256 base byte tokens always decode.
+        self._build_vocab()
+
+
+if __name__ == '__main__':
+    # print(merge([5, 6, 6, 7, 9, 1], (6, 7), 99))
+    tokenizer = Tokenizer()
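The commented-out call in __main__ hints at how merge works; expanded with get_freq_pairs, a single BPE merge step on a toy id list looks like this (outputs worked out by hand):

from tokenizer.base import get_freq_pairs, merge

ids = [5, 6, 6, 7, 9, 1]
print(get_freq_pairs(ids))     # {(5, 6): 1, (6, 6): 1, (6, 7): 1, (7, 9): 1, (9, 1): 1}
print(merge(ids, (6, 7), 99))  # [5, 6, 99, 9, 1], the pair (6, 7) becomes token 99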
tokenizer/tokenizer.py
ADDED
@@ -0,0 +1,72 @@
+from .base import get_freq_pairs, merge, Tokenizer
+
+class BPE(Tokenizer):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def train(self, vocab_size, text):
+        ## Vocabulary must contain at least the 256 single-byte tokens
+        assert vocab_size >= 256
+
+        num_merges = vocab_size - 256
+        tokens = list(text.encode('utf-8'))
+        merges = {}
+        vocab = {idx: bytes([idx]) for idx in range(256)}
+
+        for i in range(num_merges):
+            stats = get_freq_pairs(tokens)
+            max_pair = max(stats, key=stats.get)
+            idx = 256 + i
+            tokens = merge(tokens, max_pair, idx)
+            merges[max_pair] = idx
+            vocab[idx] = vocab[max_pair[0]] + vocab[max_pair[1]]
+
+        self.merges = merges
+        self.vocab = vocab
+
+        self.save()
+
+    def encode(self, text):
+        ids = list(text.encode('utf-8'))
+        ## keep merging while at least one pair remains to be looked up
+        while len(ids) >= 2:
+            pair_counts = get_freq_pairs(ids)
+
+            ## pick the pair that was merged earliest during training
+            min_index_pair = min(pair_counts, key=lambda x: self.merges.get(x, float('inf')))
+            if min_index_pair not in self.merges:
+                break
+
+            idx = self.merges.get(min_index_pair)
+            ids = merge(ids, min_index_pair, idx)
+        return ids
+
+    def decode(self, ids):
+        # given ids (list of integers), return a Python string
+        text_bytes = b"".join(self.vocab[idx] for idx in ids)
+        text = text_bytes.decode("utf-8", errors="replace")
+        return text
+
+
+if __name__ == "__main__":
+    tokenizer = BPE()
+
+    with open('cindrella_stories.txt', 'r') as f:
+        text = f.read()
+
+    tokenizer.train(500, text)
+
+    s = "😁"
+    print("String is", s)
+
+    ids = tokenizer.encode(s)
+    print("Encoded string ", ids)
+    decoded_string = tokenizer.decode(ids)
+    print("Decoded string ", decoded_string)
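A round-trip sketch using a pre-trained merge table. It assumes a merges.txt file sits next to base.py (Tokenizer.load reads it from there; no merges.txt is included in this commit). With load() building the base vocab as above, the round trip also works when merges.txt is absent, in which case the ids are plain UTF-8 bytes:

from tokenizer.tokenizer import BPE

bpe = BPE()                      # loads tokenizer/merges.txt if present, then builds the vocab
ids = bpe.encode("hello world")  # str -> UTF-8 bytes -> merged BPE ids
text = bpe.decode(ids)           # ids -> bytes -> str
assert text == "hello world"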