|
from .base import get_freq_pairs, merge, Tokenizer |
|
|
|
class BPE(Tokenizer):
    """Byte-pair-encoding tokenizer operating on raw UTF-8 bytes.

    Base vocabulary is the 256 single-byte tokens; `train` learns additional
    merged tokens by repeatedly fusing the most frequent adjacent pair.
    Relies on `get_freq_pairs` / `merge` helpers and the `Tokenizer` base
    class from `.base` (presumably `save()` persists merges/vocab — confirm
    against the base class).
    """

    def __init__(self) -> None:
        super().__init__()

    def train(self, vocab_size, text):
        """Learn up to `vocab_size - 256` merges from `text`.

        Args:
            vocab_size: target vocabulary size; must be >= 256 since the
                256 single-byte tokens are always present.
            text: training corpus as a str (encoded to UTF-8 bytes).

        Raises:
            ValueError: if vocab_size < 256.

        Side effects: sets `self.merges` / `self.vocab` and calls
        `self.save()`.
        """
        # Explicit validation instead of `assert` (asserts vanish under -O).
        if vocab_size < 256:
            raise ValueError("vocab_size must be at least 256")

        num_merges = vocab_size - 256
        tokens = list(text.encode('utf-8'))
        merges = {}                                        # (int, int) -> merged token id
        vocab = {idx: bytes([idx]) for idx in range(256)}  # token id -> raw bytes

        for i in range(num_merges):
            stats = get_freq_pairs(tokens)
            # Guard: a short corpus can run out of pairs before num_merges
            # is reached; the original max() would raise on an empty dict.
            if not stats:
                break
            max_pair = max(stats, key=stats.get)
            idx = 256 + i
            tokens = merge(tokens, max_pair, idx)
            merges[max_pair] = idx
            # New token's byte string is the concatenation of its parts.
            vocab[idx] = vocab[max_pair[0]] + vocab[max_pair[1]]

        self.merges = merges
        self.vocab = vocab

        self.save()

    def encode(self, text):
        """Encode `text` into a list of token ids using the learned merges."""
        ids = list(text.encode('utf-8'))

        # Guard: with fewer than 2 ids there are no pairs, and the min()
        # below would raise ValueError on an empty sequence.
        while len(ids) >= 2:
            pair_counts = get_freq_pairs(ids)

            # Apply merges in training order: the pair with the lowest
            # merge index must be fused first; unknown pairs rank last.
            min_index_pair = min(pair_counts, key=lambda p: self.merges.get(p, float('inf')))

            if min_index_pair not in self.merges:
                break  # no remaining pair is mergeable

            idx = self.merges[min_index_pair]
            ids = merge(ids, min_index_pair, idx)
        return ids

    def decode(self, ids):
        """Decode a list of token ids back into a string.

        Invalid UTF-8 byte sequences are replaced (errors="replace") rather
        than raising, since arbitrary token id sequences need not align
        with character boundaries.
        """
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text = text_bytes.decode("utf-8", errors="replace")
        return text
|
|
|
|
|
if __name__ == "__main__":

    # Bug fix: the original `tokenizer = tokenizer()` referenced an
    # undefined name; the class defined above is `BPE`.
    tokenizer = BPE()

    # Explicit encoding so the script behaves the same on every platform.
    with open('cindrella_stories.txt', 'r', encoding='utf-8') as f:
        text = f.read()

    tokenizer.train(500, text)

    s = "π"
    print("String is", s)

    ids = tokenizer.encode(s)
    print("Encoded string ", ids)
    decoded_string = tokenizer.decode(ids)
    print("Decoded string ", decoded_string)