gagan001 committed
Commit 9cdcbb3 · Parent(s): 838f2d4

Added model and tokenizer
.DS_Store ADDED
Binary file (6.15 kB)
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+model/model_1000_.bin filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1 +1,3 @@
-*.txt
+*.txt
+tokenizer/README.md
+__pycache__
app.py CHANGED
@@ -1,11 +1,24 @@
 import gradio as gr
-import sys
-sys.path.append("../Transformers")
+import torch
+from my_gpt import my_gpt
+from tokenizer.tokenizer import BPE
 
+## Load the pretrained model and tokenizer
+model = my_gpt.load_pretrained("model/model_1000_.bin")
+tokenizer = BPE()
 
 
-def greet(name):
-    return "Hello " + name + "!!"
 
-iface = gr.Interface(fn=greet, inputs="text", outputs="text")
+def generate(input_text):
+    tokens = tokenizer.encode(input_text)
+    gen_ids = model.generate(torch.tensor([tokens]))
+    output = tokenizer.decode(gen_ids[0].tolist())
+    return output
+
+iface = gr.Interface(fn=generate,
+                     inputs="text",
+                     outputs="text",
+                     title="GPT - 1000 steps",
+                     description="""This model was trained for only 1000 steps, so it cannot
+generate perfect sentences/words, but it has learned the gist of the English language.""")
 iface.launch()
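The same pipeline can be exercised without launching the Gradio UI. A minimal sketch, assuming it is run from the repository root with model/model_1000_.bin, my_gpt.py, and tokenizer/ in place (the prompt string is an arbitrary example; app.py itself does not call eval() or no_grad()):

    import torch
    from my_gpt import my_gpt
    from tokenizer.tokenizer import BPE

    model = my_gpt.load_pretrained("model/model_1000_.bin")
    model.eval()                      ## disable dropout for inference
    tokenizer = BPE()

    ## Encode, sample, decode: the same steps app.py wires into gr.Interface
    tokens = tokenizer.encode("Once upon a time")
    with torch.no_grad():
        gen_ids = model.generate(torch.tensor([tokens]), max_new_tokens=46)
    print(tokenizer.decode(gen_ids[0].tolist()))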
model/.DS_Store ADDED
Binary file (6.15 kB)
model/model_1000_.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:566dd60c0869b306eda84fd1d64dbacce02b256492da9ea626d227b046125bd2
+size 56950645
my_gpt.py ADDED
@@ -0,0 +1,186 @@
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+## Model hyperparameters
+block_size = 256   ## context window size
+vocab_size = 500   ## BPE vocabulary size: 256 raw bytes + 244 merges
+n_embed = 384      ## embedding dimension
+dropout = 0.2
+n_head = 6
+n_layer = 6
+
+
+class Head(nn.Module):
+    """A single head of masked (causal) self-attention."""
+    def __init__(self, head_size=16):
+        super().__init__()
+        self.query = nn.Linear(n_embed, head_size, bias=False)
+        self.key = nn.Linear(n_embed, head_size, bias=False)
+        self.value = nn.Linear(n_embed, head_size, bias=False)
+        ## Lower-triangular mask so each position attends only to earlier positions
+        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        B, T, C = x.shape
+
+        q = self.query(x)
+        k = self.key(x)
+
+        ## Scaled dot-product attention scores
+        wei = (q @ k.transpose(-2, -1)) * (k.shape[-1] ** -0.5)
+        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
+        wei = F.softmax(wei, dim=-1)
+        wei = self.dropout(wei)
+
+        v = self.value(x)
+        out = wei @ v  ## (B, T, head_size)
+        return out
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, num_heads, head_size):
+        super().__init__()
+        self.heads = nn.ModuleList(Head(head_size=head_size) for _ in range(num_heads))
+        self.proj = nn.Linear(head_size * num_heads, n_embed)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        ## Run the heads in parallel, then project the concatenation back to n_embed
+        out = torch.cat([h(x) for h in self.heads], dim=-1)
+        out = self.dropout(self.proj(out))
+        return out
+
+
+class FeedForward(nn.Module):
+    def __init__(self, n_embed):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(n_embed, 4 * n_embed),
+            nn.ReLU(),
+            nn.Linear(4 * n_embed, n_embed),
+            nn.Dropout(dropout),
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+class decoder_block(nn.Module):
+    """Pre-norm transformer block: self-attention, then a feed-forward layer."""
+    def __init__(self, n_embed, n_heads):
+        super().__init__()
+        self.sa = MultiHeadAttention(n_heads, n_embed // n_heads)
+        self.ln1 = nn.LayerNorm(n_embed)
+        self.ln2 = nn.LayerNorm(n_embed)
+        self.ffwd = FeedForward(n_embed)
+
+    def forward(self, x):
+        x = x + self.sa(self.ln1(x))
+        x = x + self.ffwd(self.ln2(x))
+        return x
+
+
+class my_gpt(nn.Module):
+    def __init__(self, block_size=256):
+        super().__init__()
+        self.block_size = block_size  ## context window size
+        self.token_embed = nn.Embedding(vocab_size, n_embed)
+        ## Sized vocab_size rather than block_size; kept that way for checkpoint compatibility
+        self.pos_embed = nn.Embedding(vocab_size, n_embed)
+        self.lm_head = nn.Linear(n_embed, vocab_size)
+        self.sa_head = Head(vocab_size)  ## unused in forward(); kept for checkpoint compatibility
+        self.d_blocks = nn.Sequential(*[decoder_block(n_embed=n_embed, n_heads=n_head) for _ in range(n_layer)])
+        self.ln_f = nn.LayerNorm(n_embed)  ## final layer norm
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+    def forward(self, idx, targets=None):
+        """
+        Args:
+            idx: (B, T) int tensor of token ids
+            targets: optional (B, T) int tensor of next-token targets
+
+        Returns:
+            (logits, loss); loss is None when targets is None
+        """
+        B, T = idx.shape
+        tok_emd = self.token_embed(idx)  ## (B, T, C)
+        ## Position embeddings are indexed by position 0..T-1, not by token id
+        pos_emd = self.pos_embed(torch.arange(T, device=idx.device))
+
+        x = tok_emd + pos_emd
+        x = self.d_blocks(x)
+        x = self.ln_f(x)  ## (B, T, C)
+
+        logits = self.lm_head(x)  ## (B, T, vocab_size)
+        if targets is None:
+            loss = None
+        else:
+            B, T, C = logits.shape
+            logits = logits.view(B * T, C)
+            targets = targets.view(B * T)
+            loss = F.cross_entropy(logits, targets)
+
+        return logits, loss
+
+    def generate(self, context: torch.Tensor, max_new_tokens: int = 46, use_cache=False):
+        """
+        Generates the next `max_new_tokens` tokens.
+
+        Args:
+            context: (B, T) int tensor of prompt token ids
+            max_new_tokens (int): number of tokens to sample
+
+        Returns:
+            (B, T + max_new_tokens) tensor: the prompt followed by the generated tokens
+        """
+        for _ in range(max_new_tokens):
+            ## Crop the context to the last block_size tokens
+            idx_tokens = context[:, -self.block_size:]
+
+            ## Predict logits for every position; keep only the last one
+            logits, _ = self(idx_tokens)
+            logits = logits[:, -1, :]  ## (B, vocab_size)
+
+            ## Sample the next token from the predicted distribution
+            probs = F.softmax(logits, dim=-1)
+            idx_next = torch.multinomial(probs, num_samples=1)  ## (B, 1)
+
+            context = torch.cat([context, idx_next], dim=1)
+
+        return context
+
+    def save_pretrained(self, path):
+        torch.save(self.state_dict(), path)
+        print("Saved pretrained model successfully")
+
+    @classmethod
+    def load_pretrained(cls, path):
+        print("Loading pretrained model...")
+        model = cls()
+        model.load_state_dict(torch.load(path))
+        return model
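A quick shape check for the module above, independent of the checkpoint. A sketch assuming only the hyperparameters defined at the top of my_gpt.py (vocab_size = 500, block_size = 256):

    import torch
    from my_gpt import my_gpt

    model = my_gpt()
    idx = torch.randint(0, 500, (2, 16))         ## batch of 2 sequences of 16 token ids
    logits, loss = model(idx)                    ## no targets, so loss is None
    assert logits.shape == (2, 16, 500) and loss is None

    logits, loss = model(idx, targets=idx)       ## with targets: scalar cross-entropy
    print(loss.item())

    out = model.generate(idx, max_new_tokens=5)  ## appends 5 sampled tokens
    assert out.shape == (2, 21)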
tokenizer/__init__.py ADDED
File without changes
tokenizer/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (186 Bytes)
tokenizer/__pycache__/base.cpython-311.pyc ADDED
Binary file (4.35 kB)
tokenizer/__pycache__/base.cpython-38.pyc ADDED
Binary file (2.8 kB)
tokenizer/__pycache__/base.cpython-39.pyc ADDED
Binary file (2.02 kB)
tokenizer/__pycache__/tokenizer.cpython-38.pyc ADDED
Binary file (2.34 kB)
tokenizer/__pycache__/tokenizer.cpython-39.pyc ADDED
Binary file (1.97 kB)
tokenizer/base.py ADDED
@@ -0,0 +1,110 @@
+import os
+
+
+def render_token(t: bytes) -> str:
+    ## Pretty-print a token, replacing undecodable byte sequences
+    s = t.decode('utf-8', errors='replace')
+    return s
+
+
+def get_freq_pairs(inp_toks):
+    """Returns a count of each adjacent pair of tokens."""
+    count = {}
+    for pair in zip(inp_toks, inp_toks[1:]):
+        count[pair] = count.get(pair, 0) + 1
+    return count
+
+
+def merge(id_list, pair, replace_with_idx):
+    """
+    Replace every occurrence of 'pair' in 'id_list' with 'replace_with_idx'.
+
+    id_list : list of token ids
+    pair : tuple of 2 token ids
+    replace_with_idx : int
+
+    Returns a new list with the pair replaced.
+    """
+    i = 0
+    new_ids_list = []
+    while i < len(id_list):
+        if i < len(id_list) - 1 and id_list[i] == pair[0] and id_list[i + 1] == pair[1]:
+            ## Found the pair: emit the merged token and skip both elements
+            new_ids_list.append(replace_with_idx)
+            i += 2
+        else:
+            new_ids_list.append(id_list[i])
+            i += 1
+    return new_ids_list
+
+
+class Tokenizer():
+    def __init__(self):
+        ## merges -> (int, int) : int, mapping a token pair to its merged id
+        self.merges = {}
+        ## vocab -> int : bytes, for all ids (0-255 raw bytes, 256+ from merges)
+        self.vocab = {}
+        self.load()
+
+    def save(self):
+        with open('merges.txt', 'w') as f:
+            ## Write only the pairs, not the merged indices; load() reassigns
+            ## indices sequentially from 256 in the same order.
+            for p1, p2 in self.merges.keys():
+                f.write(f"{p1} {p2}\n")
+
+        with open('vocab.txt', 'w') as f:
+            for idx, byte in self.vocab.items():
+                s = render_token(byte)
+                f.write(f"{idx} {s}\n")
+
+    def _build_vocab(self):
+        self.vocab = {idx: bytes([idx]) for idx in range(256)}
+        try:
+            ## Each merged token's bytes are the concatenation of its parents' bytes
+            for (tok0, tok1), idx in self.merges.items():
+                self.vocab[idx] = self.vocab[tok0] + self.vocab[tok1]
+        except Exception as e:
+            print(e)
+
+    def load(self):
+        try:
+            ## merges.txt lives next to this file, independent of the working directory
+            with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'merges.txt'), 'r') as file:
+                idx = 256
+                for line in file:
+                    tok0, tok1 = map(int, line.split())
+                    self.merges[(tok0, tok1)] = idx
+                    idx += 1
+
+            self._build_vocab()
+        except Exception as e:
+            print(e)
+
+
+if __name__ == '__main__':
+    tokenizer = Tokenizer()
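get_freq_pairs and merge are the two primitives the BPE loop is built on: one counts adjacent pairs, the other rewrites a matched pair into a single new token. A small worked example (the token values are arbitrary):

    from tokenizer.base import get_freq_pairs, merge

    ids = [5, 6, 6, 7, 9, 1]
    print(get_freq_pairs(ids))     ## {(5, 6): 1, (6, 6): 1, (6, 7): 1, (7, 9): 1, (9, 1): 1}
    print(merge(ids, (6, 7), 99))  ## [5, 6, 99, 9, 1]: the pair (6, 7) becomes token 99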
tokenizer/tokenizer.py ADDED
@@ -0,0 +1,72 @@
+from .base import get_freq_pairs, merge, Tokenizer
+
+
+class BPE(Tokenizer):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def train(self, vocab_size, text):
+        ## The vocabulary must contain at least the 256 raw byte values
+        assert vocab_size >= 256
+
+        num_merges = vocab_size - 256
+        tokens = list(text.encode('utf-8'))
+        merges = {}
+        vocab = {idx: bytes([idx]) for idx in range(256)}
+
+        for i in range(num_merges):
+            ## Merge the most frequent adjacent pair into a new token
+            stats = get_freq_pairs(tokens)
+            max_pair = max(stats, key=stats.get)
+            idx = 256 + i
+            tokens = merge(tokens, max_pair, idx)
+            merges[max_pair] = idx
+            vocab[idx] = vocab[max_pair[0]] + vocab[max_pair[1]]
+
+        self.merges = merges
+        self.vocab = vocab
+
+        self.save()
+
+    def encode(self, text):
+        ids = list(text.encode('utf-8'))
+        ## Repeatedly apply the earliest-learned applicable merge;
+        ## a sequence shorter than 2 tokens has no pairs to merge.
+        while len(ids) >= 2:
+            pair_counts = get_freq_pairs(ids)
+            min_index_pair = min(pair_counts, key=lambda x: self.merges.get(x, float('inf')))
+            if min_index_pair not in self.merges:
+                break
+            idx = self.merges.get(min_index_pair)
+            ids = merge(ids, min_index_pair, idx)
+        return ids
+
+    def decode(self, ids):
+        ## Given ids (list of integers), return a Python string
+        text_bytes = b"".join(self.vocab[idx] for idx in ids)
+        text = text_bytes.decode("utf-8", errors="replace")
+        return text
+
+
+if __name__ == "__main__":
+    tokenizer = BPE()
+
+    with open('cindrella_stories.txt', 'r') as f:
+        text = f.read()
+
+    tokenizer.train(500, text)
+
+    s = "😁"
+    print("String is", s)
+
+    ids = tokenizer.encode(s)
+    print("Encoded string ", ids)
+    decoded_string = tokenizer.decode(ids)
+    print("Decoded string ", decoded_string)