import os
import json

from tqdm import tqdm

class KMerTokenizer:
  """Tokenizes sequences into fixed-length, non-overlapping k-mers."""

  def __init__(self, k_mers: int = 4):
    self.k_mers = k_mers
    self.vocab = {}
    self.id_to_token = []
    self.token_to_id = {}

  def tokenize_sequence(self, sequence):
    # split into non-overlapping chunks of length k; the last chunk may be shorter
    kmers = [sequence[i:i + self.k_mers] for i in tqdm(range(0, len(sequence), self.k_mers), desc="tokenizing k-mers")]
    return kmers

  def build_vocab(self, sequences):
    # collect k-mers from every sequence and count their frequencies
    all_kmers = []
    for sequence in sequences:
      all_kmers.extend(self.tokenize_sequence(sequence))
    token_count = {}
    for kmer in all_kmers:
      token_count[kmer] = token_count.get(kmer, 0) + 1
    # assign ids in descending frequency order, most frequent k-mer first
    sorted_tokens = sorted(token_count.items(), key=lambda x: x[1], reverse=True)
    for token, _ in sorted_tokens:
      self.token_to_id[token] = len(self.token_to_id)
      self.id_to_token.append(token)
    self.vocab = self.token_to_id

  def encode(self, sequence):
    encoded_sequence = []
    kmers = self.tokenize_sequence(sequence)
    for kmer in tqdm(kmers, desc="encoding sequences"):
      if kmer in self.token_to_id:
        encoded_sequence.append(self.token_to_id[kmer])
      else:
        # every out-of-vocabulary k-mer shares a single id, one past the vocab
        encoded_sequence.append(len(self.vocab))
    return encoded_sequence

  def decode(self, encoded_sequence):
    # ids at or beyond the vocab size (the OOV id emitted by encode) decode to a placeholder
    decoded_sequence = [self.id_to_token[token_id] if token_id < len(self.id_to_token) else '<unk>' for token_id in encoded_sequence]
    return decoded_sequence

  def save_model(self, model_path):
    vocab_file = os.path.join(model_path, f"base_{self.k_mers}k.json")
    with open(vocab_file, 'w') as f:
      json.dump(self.vocab, f)

  def load_model(self, path):
    assert path.endswith('.json'), "expected a .json vocab file"
    with open(path, 'r') as f:
      vocab = json.load(f)
    self.vocab = vocab
    self.token_to_id = self.vocab
    # rebuild the reverse mapping so decode works after loading
    self.id_to_token = [token for token, _ in sorted(self.token_to_id.items(), key=lambda x: x[1])]
    self.vocab_size = len(vocab)
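
# --- hypothetical usage sketch (not part of the original file) ---
# The sequences and the "weights" output directory below are made-up examples
# for illustration; any DNA-style strings and writable directory will do.
if __name__ == "__main__":
  tokenizer = KMerTokenizer(k_mers=4)
  tokenizer.build_vocab(["ATGCGTACGTTAGC", "GGCATCGATCGTAA"])

  ids = tokenizer.encode("ATGCGTAC")
  print(ids)                    # ids depend on the frequency order of the vocab
  print(tokenizer.decode(ids))  # ['ATGC', 'GTAC']

  os.makedirs("weights", exist_ok=True)
  tokenizer.save_model("weights")  # writes weights/base_4k.json
  tokenizer.load_model(os.path.join("weights", f"base_{tokenizer.k_mers}k.json"))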