import json
import sys
import os

sys.path.append("./")


def render_token(t: bytes) -> str:
    """Pretty-print a token: decode as UTF-8, replacing invalid byte runs."""
    s = t.decode('utf-8', errors='replace')
    # s = replace_control_characters(s)
    return s


def get_freq_pairs(inp_toks):
    """Return a dict counting each adjacent (left, right) pair in inp_toks.

    inp_toks : sequence of token ids.
    Returns {} for empty or single-element input (zip yields nothing).
    """
    count = {}
    for pair in zip(inp_toks, inp_toks[1:]):
        count[pair] = count.get(pair, 0) + 1
    return count


def merge(id_list, pair, replace_with_idx):
    """
    Replace each occurrence of 'pair' in 'id_list' with 'replace_with_idx'.

    id_list          : list of token ids
    pair             : tuple/list of 2 ids
    replace_with_idx : int id substituted for every matched pair

    Returns a new list with the pairs replaced (input is not mutated).
    """
    # NOTE(review): the loop body was garbled in the original file; it has been
    # reconstructed as the standard BPE merge step, consistent with the example
    # in __main__: merge([5, 6, 6, 7, 9, 1], (6, 7), 99) -> [5, 6, 99, 9, 1].
    i = 0
    new_ids_list = []
    while i < len(id_list):
        # Only match when a full pair remains; matches never overlap because
        # a successful match advances past both elements.
        if i < len(id_list) - 1 and id_list[i] == pair[0] and id_list[i + 1] == pair[1]:
            new_ids_list.append(replace_with_idx)
            i += 2
        else:
            new_ids_list.append(id_list[i])
            i += 1
    return new_ids_list


class Tokenizer:
    """Byte-pair-encoding tokenizer state: learned merges plus derived vocab."""

    def __init__(self):
        # self.merges : (int, int) -> int, the id assigned to each merged pair
        self.merges = {}
        # self.vocab : int -> bytes, for all ids (0-255 raw bytes, 256+ from new merges)
        self.vocab = {}
        self.load()

    def save(self):
        """Write merges.txt (pairs only) and vocab.txt (id -> rendered bytes).

        NOTE(review): save() writes to the current working directory while
        load() reads relative to this file's directory — confirm intended.
        """
        with open('merges.txt', 'w') as f:
            ##Write only the pairs. Not the index of the merged pairs.
            ##When the tokenizer is loaded, allow the user to specify the index
            for p1, p2 in self.merges.keys():
                f.write(f"{p1} {p2}\n")
        with open('vocab.txt', 'w') as f:
            for idx, byte in self.vocab.items():
                # render_token is lossy (errors='replace'), so vocab.txt is
                # for inspection only, not round-tripping.
                s = render_token(byte)
                f.write(f"{idx} {s}\n")

    def _build_vocab(self):
        """Rebuild self.vocab from self.merges.

        Relies on dict insertion order: each merge's operands must already be
        in the vocab when it is processed.
        """
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        try:
            for (tok0, tok1), idx in self.merges.items():
                self.vocab[idx] = self.vocab[tok0] + self.vocab[tok1]
        except Exception as e:
            print(e)

    def load(self):
        """Best-effort load of merges.txt (next to this file) into self.merges.

        Pairs are assigned consecutive ids starting at 256, in file order.
        Errors (e.g. missing file) are printed, leaving the tokenizer empty.
        """
        try:
            # print("Loading", os.getcwd(), "hey" , __file__)
            with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'merges.txt'), 'r') as file:
                idx = 256
                for line in file:
                    tok0, tok1 = map(int, line.split())
                    self.merges[(tok0, tok1)] = idx
                    idx += 1
                # print(self.merges)
                self._build_vocab()
        except Exception as e:
            print(e)


if __name__ == '__main__':
    # print(merge([5, 6, 6, 7, 9, 1], (6, 7), 99))
    tokenizer = Tokenizer()