import os
import json

from tqdm import tqdm

# run with paths relative to the directory this file lives in
current_dir = os.path.dirname(os.path.realpath(__file__))
os.chdir(current_dir)

class KmerPairTokenizer:
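  """
    BPE-style tokenizer for nucleotide sequences: input text is first chunked into
    fixed-size k-mers (k=4), each unique k-mer is assigned an integer id, and the
    most frequent consecutive id pairs are then merged into new tokens,
    byte-pair-encoding style, for the requested number of merges
  """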
  def __init__(self):
    self.k_mers = 4
    self.vocab = {}
    self.merges = {}
    self.vocab_size = 0
    self.kmer_to_id = {}  # k-mer -> id table learned during training, reused by encode
    # fixed single-character table; currently unused, the training path builds
    # its own k-mer table in get_ids
    self.init_vocab = {"\n": 1, "A": 2, "T": 3, "G": 4, "C": 5, "P": 6, "M": 7, "U": 8, " ": 9}
  
  def _tokenize_seq(self, sequence):
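    """
      chops a sequence into non-overlapping chunks of k_mers characters
      (the last chunk may be shorter when the length is not a multiple of k_mers)
    """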
    kmers = [sequence[i:i+self.k_mers] for i in tqdm(range(0, len(sequence), self.k_mers), desc="tokenizing k-mers")]
    return kmers
  
  def _get_stats(self, ids, counts=None):
    """
      takes a list of integers and returns a dictionary of counts of consecutive pairs
      eg: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
      optionally updates an existing dictionary of counts in place
    """
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]):
      counts[pair] = counts.get(pair, 0) + 1
    return counts

  def _merge(self, ids, pair, idx):
    """
      in the list of integers, replaces every occurrence of the consecutive pair with the new integer token idx
      eg: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    new_ids = []
    i = 0
    while i < len(ids):
      if i+1 < len(ids) and ids[i] == pair[0] and ids[i+1] == pair[1]:
        new_ids.append(idx)
        i += 2
      else:
        new_ids.append(ids[i])
        i += 1
    return new_ids
  
  def get_ids(self, data):
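    """
      splits every sequence in data into k-mers and assigns each unique k-mer an
      integer id (starting at 1, in order of first appearance)
      returns the list of ids and the k-mer -> id mapping
    """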
    all_kmers = []
    seq_to_no = {}
    ass_no = []
    i = 1
    for seq in data:
      all_kmers.extend(self._tokenize_seq(seq))

    for kmer in all_kmers:
      if kmer not in seq_to_no:
        seq_to_no[kmer] = i
        i += 1
      ass_no.append(seq_to_no[kmer])

    return ass_no, seq_to_no

  def train_tokenizer(self, data: str, max_vocab: int):
    n_merges = max_vocab
    ids, kmer_to_id = self.get_ids([data])
    base_size = len(kmer_to_id)

    merges = {}
    for i in tqdm(range(n_merges), desc="training the tokenizer"):
      stats = self._get_stats(ids)
      if not stats:
        # nothing left to merge, the whole text has collapsed into a single token
        break
      pair = max(stats, key=stats.get)
      idx = base_size + i + 1
      ids = self._merge(ids, pair, idx)
      merges[pair] = idx

    # id -> string table: base k-mers first, then every merge spelled out
    vocab = {value: key for key, value in kmer_to_id.items()}
    for (p0, p1), idx in merges.items():
      vocab[idx] = vocab[p0] + vocab[p1]

    self.kmer_to_id = kmer_to_id  # reused by encode so ids line up with the trained merges
    self.vocab = vocab
    self.merges = merges
    self.vocab_size = len(self.vocab)
  
  def encode(self, text):
    # map k-mers through the table learned during training so the ids line up
    # with the trained merges; a k-mer never seen during training raises a KeyError
    kmers = self._tokenize_seq(text)
    ids = [self.kmer_to_id[kmer] for kmer in kmers]
    total_pairs = len(ids) - 1

    with tqdm(total=total_pairs, desc="encoding text") as pbar:
      while len(ids) >= 2:
        stats = self._get_stats(ids)
        # prefer the pair that was merged earliest during training
        pair = min(stats, key=lambda p: self.merges.get(p, float('inf')))
        if pair not in self.merges:
          break
        idx = self.merges[pair]
        ids = self._merge(ids, pair, idx)
        pbar.update(1)
    return ids

  def decode(self, ids):
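    """
      maps each id back to its string via the vocab table and joins them into a sequence
    """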
    tokens = [self.vocab[idx] for idx in ids]
    sequence = ''.join(tokens)
    return sequence
  
  def save_model(self, file_path):
    model_file = os.path.join(file_path, "base_kmer.model")
    vocab_file = os.path.join(file_path, "base_kmer.json")

    # merge pairs are written in training order; load() relies on this order to
    # rebuild the merge indices
    with open(model_file, 'w', encoding='utf-8') as f:
      for ids1, ids2 in self.merges:
        f.write(f"{ids1} {ids2}\n")
    with open(vocab_file, 'w') as f:
      json.dump(self.vocab, f)
    print('model file saved successfully!')
  
  def load(self, model_path, vocab_path):
    assert model_path.endswith('.model')
    assert vocab_path.endswith('.json')

    with open(vocab_path, 'r') as f:
      vocab_data = json.load(f)

    # json stores the integer token ids as strings, convert them back
    self.vocab = {int(idx): token for idx, token in vocab_data.items()}
    self.vocab_size = len(self.vocab)

    with open(model_path, 'r', encoding='utf-8') as fread:
      lines = [line for line in fread.read().splitlines() if line.strip()]

    # merge ids continue from the base k-mer table, so the first merge id is
    # base_size + 1 where base_size = vocab_size - number_of_merges
    base_size = self.vocab_size - len(lines)
    self.kmer_to_id = {token: idx for idx, token in self.vocab.items() if idx <= base_size}

    merges = {}
    idx = base_size + 1
    for line in lines:
      idx1, idx2 = map(int, line.split())
      merges[(idx1, idx2)] = idx
      idx += 1
    self.merges = merges
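
if __name__ == "__main__":
  # minimal usage sketch: the sequence, merge count, and save path below are
  # illustrative placeholders, not values from any real dataset
  tokenizer = KmerPairTokenizer()
  sample = "ATGCGGCCTAATGCGGCCTA" * 8  # hypothetical toy nucleotide string
  tokenizer.train_tokenizer(sample, max_vocab=10)

  encoded = tokenizer.encode(sample)
  decoded = tokenizer.decode(encoded)
  print(f"vocab size: {tokenizer.vocab_size}")
  print(f"encoded length: {len(encoded)}")
  print(f"round trip ok: {decoded == sample}")

  # tokenizer.save_model(".")  # uncomment to write base_kmer.model / base_kmer.json here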