# enigma-1.5b/tokenizer/kmer_bpe.py
import os
import json
from tqdm import tqdm

# work relative to this file's directory so relative save/load paths resolve here
current_dir = os.path.dirname(os.path.realpath(__file__))
os.chdir(current_dir)

class KmerPairTokenizer:
    """BPE-style tokenizer that first chunks a sequence into fixed-length k-mers,
    then learns pair merges over the resulting k-mer ids."""

    def __init__(self):
        self.k_mers = 4          # length of each k-mer chunk
        self.vocab = {}          # token id -> token string
        self.merges = {}         # (id, id) pair -> merged token id
        self.vocab_size = 0
        # base character set; not used by the k-mer pipeline below
        self.init_vocab = {"\n": 1, "A": 2, "T": 3, "G": 4, "C": 5, "P": 6, "M": 7, "U": 8, " ": 9}

    def _tokenize_seq(self, sequence):
        # split the sequence into consecutive, non-overlapping chunks of length k_mers;
        # the final chunk may be shorter if the length is not a multiple of k_mers
        kmers = [sequence[i:i+self.k_mers] for i in tqdm(range(0, len(sequence), self.k_mers), desc="tokenizing k-mers")]
        return kmers

    def _get_stats(self, ids, counts=None):
        """
        takes a list of integers and returns a dictionary of counts of consecutive pairs
        eg: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
        optionally updates an existing dictionary of counts
        """
        counts = {} if counts is None else counts
        for pair in zip(ids, ids[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        return counts

    def _merge(self, ids, pair, idx):
        """
        in the list of integers, replaces all consecutive occurrences of pair with the new integer token idx
        eg: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
        """
        new_ids = []
        i = 0
        while i < len(ids):
            if i+1 < len(ids) and ids[i] == pair[0] and ids[i+1] == pair[1]:
                new_ids.append(idx)
                i += 2
            else:
                new_ids.append(ids[i])
                i += 1
        return new_ids

    def get_ids(self, data):
        # chunk every sequence in `data` into k-mers, assign each distinct k-mer an
        # integer id (in order of first appearance), and return the id sequence
        # together with the k-mer -> id mapping
        all_kmers = []
        seq_to_no = {}
        ass_no = []
        i = 1
        for seq in data:
            all_kmers.extend(self._tokenize_seq(seq))
        for kmer in all_kmers:
            if kmer not in seq_to_no:
                seq_to_no[kmer] = i
                i += 1
            ass_no.append(seq_to_no[kmer])
        del all_kmers, i
        return ass_no, seq_to_no

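    # Illustration, assuming the default k_mers=4:
    #   get_ids(["ATGCATGCAA"]) splits the text into ["ATGC", "ATGC", "AA"]
    #   and returns ([1, 1, 2], {"ATGC": 1, "AA": 2})
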
    def train_tokenizer(self, data: str, max_vocab: int):
        # max_vocab is interpreted as the number of merge operations to perform
        n_merges = max_vocab
        text_pairs, init_vocab = self.get_ids([data])
        ids = list(text_pairs)
        del text_pairs, max_vocab
        merges = {}
        ids_len = len(init_vocab)   # ids 1..ids_len are taken by the base k-mers
        for i in tqdm(range(n_merges), desc="training the tokenizer"):
            stats = self._get_stats(ids)
            if not stats:
                break   # nothing left to merge
            pair = max(stats, key=stats.get)
            idx = ids_len + i + 1
            ids = self._merge(ids, pair, idx)
            merges[pair] = idx
        # id -> token string, starting from the base k-mers and expanding every merge
        vocab = {value: key for key, value in init_vocab.items()}
        for (p0, p1), idx in merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        self.vocab = vocab
        self.merges = merges
        self.vocab_size = len(self.vocab)

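    # After training, merges maps an id pair to the new id created for it, e.g.
    # {(1, 1): 7, (7, 2): 8, ...}, and vocab holds the expanded string for every id,
    # so vocab[7] == vocab[1] + vocab[1] (the ids shown here are only illustrative).
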
    def encode(self, text):
        # map the text's k-mers with the id table fixed at training time; rebuilding the
        # mapping from the new text (as get_ids would) would not line up with the trained vocab
        merge_ids = set(self.merges.values())
        base_map = {tok: i for i, tok in self.vocab.items() if i not in merge_ids}
        ids = [base_map[km] for km in self._tokenize_seq(text)]   # raises KeyError for k-mers unseen in training
        total_pairs = len(ids) - 1
        with tqdm(total=total_pairs, desc="Encoding text") as pbar:
            while len(ids) >= 2:
                stats = self._get_stats(ids)
                # pick the pair that was merged earliest during training
                pair = min(stats, key=lambda p: self.merges.get(p, float('inf')))
                if pair not in self.merges:
                    break
                idx = self.merges[pair]
                ids = self._merge(ids, pair, idx)
                pbar.update(1)
        return ids

    def decode(self, ids):
        tokens = [self.vocab[idx] for idx in ids]
        sequence = ''.join(tokens)
        return sequence

    def save_model(self, file_path):
        model_file = file_path + "/base_mer.model"
        vocab_file = file_path + "/base_kmer.json"
        # one merge pair per line, in merge order, so load() can re-derive each pair's id
        with open(model_file, 'w', encoding='utf-8') as f:
            for ids1, ids2 in self.merges:
                f.write(f"{ids1} {ids2}\n")
        with open(vocab_file, 'w') as f:
            json.dump(self.vocab, f)
        print('model file saved successfully!')

    def load(self, model_path, vocab_path):
        assert model_path.endswith('.model')
        assert vocab_path.endswith('.json')
        with open(vocab_path, 'r') as f:
            vocab_data = json.load(f)
        # json stores dict keys as strings; token ids must be ints for encode/decode
        self.vocab = {int(k): v for k, v in vocab_data.items()}
        self.vocab_size = len(self.vocab)
        with open(model_path, 'r', encoding='utf-8') as fread:
            pairs = [tuple(map(int, line.split())) for line in fread if line.strip()]
        # merge ids were assigned consecutively right after the base k-mer ids at
        # training time, so the first merge id is (total vocab size - number of merges) + 1
        merges = {}
        idx = self.vocab_size - len(pairs) + 1
        for idx1, idx2 in pairs:
            merges[(idx1, idx2)] = idx
            idx += 1
        self.merges = merges
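

if __name__ == "__main__":
    # Minimal usage sketch on made-up toy data; max_vocab below is just an
    # illustrative number of merges, not a recommended setting.
    tokenizer = KmerPairTokenizer()
    toy_seq = "ATGCATGCATGCTTGGAAGGCCAATGATGC"
    tokenizer.train_tokenizer(toy_seq, max_vocab=5)
    encoded = tokenizer.encode(toy_seq)
    decoded = tokenizer.decode(encoded)
    print("encoded ids:", encoded)
    print("round-trip ok:", decoded == toy_seq)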