import os

current_dir = os.path.dirname(os.path.realpath(__file__))
os.chdir(current_dir)

from tqdm import tqdm
import json


class KmerPairTokenizer:

    def __init__(self):
        self.k_mers = 4        # length of each non-overlapping k-mer
        self.vocab = {}        # token id -> string
        self.merges = {}       # (id, id) -> merged token id
        self.base_vocab = {}   # base k-mer string -> id, filled by train_tokenizer/load
        self.vocab_size = 0
        # single-character base alphabet; kept for reference, not used by the k-mer pipeline below
        self.init_vocab = {"\n": 1, "A": 2, "T": 3, "G": 4, "C": 5, "P": 6, "M": 7, "U": 8, " ": 9}

    def _tokenize_seq(self, sequence):
        # split the sequence into non-overlapping k-mers of length self.k_mers;
        # the final k-mer may be shorter if the length is not a multiple of k
        kmers = [sequence[i:i+self.k_mers] for i in tqdm(range(0, len(sequence), self.k_mers), desc="tokenizing k-mers")]
        return kmers

    def _get_stats(self, ids, counts=None):
        """
        Takes a list of integers and returns a dictionary of counts of consecutive pairs.
        Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
        Optionally updates an existing dictionary of counts in place.
        """
        counts = {} if counts is None else counts
        for pair in zip(ids, ids[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        return counts

    def _merge(self, ids, pair, idx):
        """
        In the list of integers, replaces every consecutive occurrence of pair with the new token idx.
        Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
        """
        new_ids = []
        i = 0
        while i < len(ids):
            if i+1 < len(ids) and ids[i] == pair[0] and ids[i+1] == pair[1]:
                new_ids.append(idx)
                i += 2
            else:
                new_ids.append(ids[i])
                i += 1
        return new_ids

    def get_ids(self, data):
        # assign a unique integer id to every distinct k-mer, in order of first
        # appearance, and return the id sequence plus the k-mer -> id mapping
        all_kmers = []
        seq_to_no = {}
        ass_no = []
        i = 1
        for seq in data:
            all_kmers.extend(self._tokenize_seq(seq))

        for kmer in all_kmers:
            if kmer not in seq_to_no:
                seq_to_no[kmer] = i
                i += 1
            ass_no.append(seq_to_no[kmer])

        del all_kmers, i
        return ass_no, seq_to_no

    def train_tokenizer(self, data: str, max_vocab: int):
        n_merges = max_vocab
        text_pairs, init_vocab = self.get_ids([data])
        ids = list(text_pairs)

        del text_pairs, max_vocab
        merges = {}
        base_size = len(init_vocab)

        for i in tqdm(range(n_merges), desc="training the tokenizer"):
            stats = self._get_stats(ids)
            if not stats:                # nothing left to merge
                break
            pair = max(stats, key=stats.get)   # most frequent adjacent pair
            idx = base_size + i + 1            # new token id, after the base k-mer ids
            ids = self._merge(ids, pair, idx)
            merges[pair] = idx

        # build id -> token string, starting from the base k-mers
        vocab = {value: key for key, value in init_vocab.items()}
        for (p0, p1), idx in merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]

        self.vocab = vocab
        self.merges = merges
        self.base_vocab = init_vocab   # encode() needs this to reproduce training ids
        self.vocab_size = len(self.vocab)

        del vocab, merges, ids

    def encode(self, text):
        # map the text's k-mers through the id table learned during training so
        # that the ids line up with self.merges; a k-mer never seen in training
        # has no id here and would need an explicit fallback
        kmers = self._tokenize_seq(text)
        ids = [self.base_vocab[kmer] for kmer in kmers]
        total_pairs = len(ids) - 1

        with tqdm(total=total_pairs, desc="Encoding text") as pbar:
            while len(ids) >= 2:
                stats = self._get_stats(ids)
                # pick the pair that was merged earliest during training
                pair = min(stats, key=lambda p: self.merges.get(p, float('inf')))
                if pair not in self.merges:
                    break   # no mergeable pair left
                idx = self.merges[pair]
                ids = self._merge(ids, pair, idx)
                pbar.update(1)
        return ids

    def decode(self, ids):
        # look up each token id and concatenate the underlying strings
        tokens = [self.vocab[idx] for idx in ids]
        sequence = ''.join(tokens)
        return sequence

    def save_model(self, file_path):
        model_file = os.path.join(file_path, "base_kmer.model")
        vocab_file = os.path.join(file_path, "base_kmer.json")

        # store each merge as "id1 id2 merged_id" so load() can rebuild
        # self.merges exactly, whatever the base vocabulary size was
        with open(model_file, 'w', encoding='utf-8') as f:
            for (ids1, ids2), idx in self.merges.items():
                f.write(f"{ids1} {ids2} {idx}\n")
        with open(vocab_file, 'w') as f:
            json.dump(self.vocab, f)
        print('model file saved successfully!')

    def load(self, model_path, vocab_path):
        assert model_path.endswith('.model')
        assert vocab_path.endswith('.json')

        with open(vocab_path, 'r') as f:
            vocab_data = json.load(f)

        # JSON keys come back as strings; decode() indexes vocab by int id
        self.vocab = {int(idx): token for idx, token in vocab_data.items()}
        self.vocab_size = len(self.vocab)

        merges = {}
        with open(model_path, 'r', encoding='utf-8') as fread:
            for line in fread:
                idx1, idx2, idx = map(int, line.split())
                merges[(idx1, idx2)] = idx
        self.merges = merges

        # base k-mer ids are the vocab entries that no merge produced;
        # encode() needs this string -> id table
        merged_ids = set(merges.values())
        self.base_vocab = {token: idx for idx, token in self.vocab.items() if idx not in merged_ids}
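
# A minimal usage sketch (not part of the class itself): trains on a tiny
# synthetic sequence, checks the encode/decode round trip, and exercises the
# save/load path. The sequence, the merge count, and the output directory
# below are illustrative assumptions, not values the class prescribes.
if __name__ == "__main__":
    tokenizer = KmerPairTokenizer()
    dna = "ATGCGTACGTTAGCCGTA" * 40 + "\n"   # hypothetical toy training data
    tokenizer.train_tokenizer(dna, max_vocab=12)

    encoded = tokenizer.encode(dna)
    decoded = tokenizer.decode(encoded)
    assert decoded == dna, "encode/decode should round-trip the training text"
    print(f"vocab size: {tokenizer.vocab_size}, tokens: {len(encoded)}")

    # save, then reload into a fresh tokenizer and re-check decoding
    tokenizer.save_model(".")
    reloaded = KmerPairTokenizer()
    reloaded.load("./base_kmer.model", "./base_kmer.json")
    assert reloaded.decode(encoded) == dna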