import os
import json

from tqdm import tqdm


class KMerTokenizer:
    """Simple k-mer tokenizer: splits sequences into fixed-length,
    non-overlapping chunks and maps each chunk to an integer id."""

    def __init__(self, k_mers: int = 4):
        self.k_mers = k_mers
        self.vocab = {}
        self.id_to_token = []
        self.token_to_id = {}

    def tokenize_sequence(self, sequence):
        # Non-overlapping k-mers (step == k); a trailing chunk shorter
        # than k is kept as its own token.
        kmers = [
            sequence[i:i + self.k_mers]
            for i in tqdm(range(0, len(sequence), self.k_mers), desc="tokenizing k-mers")
        ]
        return kmers

    def build_vocab(self, sequences):
        # Count k-mer frequencies across all sequences.
        all_kmers = []
        for sequence in sequences:
            all_kmers.extend(self.tokenize_sequence(sequence))
        token_count = {}
        for kmer in all_kmers:
            token_count[kmer] = token_count.get(kmer, 0) + 1
        # Assign ids in descending frequency order, so the most common
        # k-mers get the smallest ids.
        sorted_tokens = sorted(token_count.items(), key=lambda x: x[1], reverse=True)
        for token, _ in sorted_tokens:
            self.token_to_id[token] = len(self.token_to_id)
            self.id_to_token.append(token)
        self.vocab = self.token_to_id

    def encode(self, sequence):
        encoded_sequence = []
        kmers = self.tokenize_sequence(sequence)
        for kmer in tqdm(kmers, desc="encoding sequences"):
            if kmer in self.token_to_id:
                encoded_sequence.append(self.token_to_id[kmer])
            else:
                # Out-of-vocabulary k-mers share a single id just past the
                # known vocabulary; note that decode() cannot recover them.
                encoded_sequence.append(len(self.vocab))
        return encoded_sequence

    def decode(self, encoded_sequence):
        decoded_sequence = [self.id_to_token[token_id] for token_id in encoded_sequence]
        return decoded_sequence

    def save_model(self, model_path):
        # Write the vocab as JSON, e.g. <model_path>/base_4k.json for k=4.
        os.makedirs(model_path, exist_ok=True)
        vocab_file = os.path.join(model_path, f"base_{self.k_mers}k.json")
        with open(vocab_file, 'w') as f:
            json.dump(self.vocab, f)

    def load_model(self, path):
        assert path.endswith('.json'), "expected a .json vocab file"
        with open(path, 'r') as f:
            vocab = json.load(f)
        self.vocab = vocab
        self.token_to_id = self.vocab
        # Rebuild the id -> token list so decode() works after loading.
        self.id_to_token = [token for token, _ in sorted(vocab.items(), key=lambda kv: kv[1])]
        self.vocab_size = len(vocab)
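

if __name__ == "__main__":
    # Minimal usage sketch. The DNA strings and the "weights" directory
    # below are made-up examples, not part of the tokenizer itself.
    tokenizer = KMerTokenizer(k_mers=4)

    sequences = ["ATGCGTACGTTAGC", "GGCATCGATCGTAC"]  # hypothetical inputs
    tokenizer.build_vocab(sequences)

    ids = tokenizer.encode(sequences[0])
    print("encoded:", ids)
    print("decoded:", tokenizer.decode(ids))

    tokenizer.save_model("weights")               # writes weights/base_4k.json
    tokenizer.load_model("weights/base_4k.json")  # restores vocab and id map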