{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vXIGN6PAuZWg"
   },
   "source": [
    "### Train file for enigma model\n",
    "\n",
    "- Contains a K-mer tokenizer with k=4 by default (configurable)\n",
    "- Training data is available on the Hugging Face repo: [hf/enigma-1.5b](https://huggingface.co/shivendrra/enigma-1.5b)\n",
    "- For now, only a decoder-based model is trained\n",
    "- More about this on the GitHub repo: [github/enigma-1.5b](https://github.com/shivendrra/enigma-1.5b)\n",
    "- Saves the model after training as `.pth` and `.safetensors` files for later use\n",
    "- The `generate` function works fine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WXpJBLyr30Rx"
   },
   "outputs": [],
   "source": [
    "from google.colab import drive\n",
    "drive.mount('/content/drive')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "r7WUm0VL4bN4"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# importing the data (the `with` block closes the file)\n",
    "file_path = '/content/drive/MyDrive/consolidated_dna.txt'\n",
    "with open(file_path, 'r', encoding='utf-8') as file:\n",
    "    dna_seq = file.read()\n",
    "\n",
    "print(f\"{(len(dna_seq)/1e6):.2f} million letters\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Cdhybhz9owTK"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "from collections import Counter\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "class KMerTokenizer:\n",
    "    \"\"\"Non-overlapping k-mer tokenizer for DNA strings.\n",
    "\n",
    "    A sequence is cut into consecutive chunks of `k_mers` letters and every\n",
    "    distinct chunk is mapped to an integer id.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, k_mers: int=4):\n",
    "        self.k_mers = k_mers\n",
    "        self.vocab = {}        # k-mer -> id (same dict as token_to_id once built/loaded)\n",
    "        self.id_to_token = []  # id -> k-mer\n",
    "        self.token_to_id = {}  # k-mer -> id\n",
    "\n",
    "    def tokenize_sequence(self, sequence):\n",
    "        \"\"\"Split `sequence` into consecutive, non-overlapping k-mers.\n",
    "\n",
    "        The last k-mer is shorter than `k_mers` when len(sequence) is not a multiple of it.\n",
    "        \"\"\"\n",
    "        kmers = [sequence[i:i+self.k_mers] for i in tqdm(range(0, len(sequence), self.k_mers), desc=\"tokenizing k-mers\")]\n",
    "        return kmers\n",
    "\n",
    "    def build_vocab(self, sequences):\n",
    "        \"\"\"Build the vocabulary from an iterable of sequences, replacing any previous one.\n",
    "\n",
    "        Ids are assigned by descending k-mer frequency (ties keep first-seen order),\n",
    "        so id 0 is the most frequent k-mer.\n",
    "        \"\"\"\n",
    "        token_count = Counter()\n",
    "        for sequence in sequences:\n",
    "            token_count.update(self.tokenize_sequence(sequence))\n",
    "        # rebuild from scratch: appending to an existing vocab on a second call would\n",
    "        # renumber already-known k-mers and leave duplicates in id_to_token\n",
    "        self.id_to_token = [token for token, _ in token_count.most_common()]\n",
    "        self.token_to_id = {token: idx for idx, token in enumerate(self.id_to_token)}\n",
    "        self.vocab = self.token_to_id\n",
    "\n",
    "    def encode(self, sequence):\n",
    "        \"\"\"Return the list of token ids for `sequence`.\n",
    "\n",
    "        K-mers missing from the vocabulary map to len(self.vocab); note that this id is\n",
    "        outside an embedding table of size len(self.vocab).\n",
    "        \"\"\"\n",
    "        unknown_id = len(self.vocab)\n",
    "        kmers = self.tokenize_sequence(sequence)\n",
    "        return [self.token_to_id.get(kmer, unknown_id) for kmer in tqdm(kmers, desc=\"encoding sequences\")]\n",
    "\n",
    "    def decode(self, encoded_sequence):\n",
    "        \"\"\"Map token ids back to k-mers; returns a list (''.join(...) gives the sequence).\"\"\"\n",
    "        return [self.id_to_token[token_id] for token_id in encoded_sequence]\n",
    "\n",
    "    def save_model(self, model_path):\n",
    "        \"\"\"Write the k-mer -> id vocabulary to `<model_path>/base_<k>k.json`.\"\"\"\n",
    "        vocab_file = os.path.join(model_path, f\"base_{self.k_mers}k.json\")\n",
    "        with open(vocab_file, 'w', encoding='utf-8') as f:\n",
    "            json.dump(self.vocab, f)\n",
    "\n",
    "    def load_model(self, path):\n",
    "        \"\"\"Load a vocabulary written by `save_model` and rebuild the id -> k-mer table.\"\"\"\n",
    "        assert path.endswith('.json')\n",
    "        with open(path, 'r', encoding='utf-8') as f:\n",
    "            vocab = json.load(f)\n",
    "\n",
    "        self.vocab = vocab\n",
    "        self.token_to_id = self.vocab\n",
    "        # only k-mer -> id is saved; rebuild id_to_token so decode() works after loading\n",
    "        self.id_to_token = sorted(vocab, key=vocab.get)\n",
    "        self.vocab_size = len(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6BCpjdi5rjU4"
   },
   "outputs": [],
   "source": [
    "token = KMerTokenizer()\n",
    "token.build_vocab([dna_seq])\n",
    "print(f\"vocab size: {len(token.vocab)}\")\n",
    "print(token.id_to_token[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6Ou9txgmAdIB"
   },
   "outputs": [],
   "source": [
    "# Train and test splits\n",
    "data = torch.tensor(token.encode(dna_seq), dtype=torch.long)\n",
    "print(f\"{(len(data)/1e6):.2f} million tokens\")\n",
    "n = int(0.9*len(data)) # first 90% will be train, rest val\n",
    "train_data = data[:n]\n",
    "val_data = data[n:]\n",
    "print(f\"train data {(len(train_data)/1e6):.0f}million, val data {(len(val_data)/1e6):.0f}million\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
"metadata": {
    "id": "ebFKQQ9NAq4e"
   },
   "outputs": [],
   "source": [
    "# hyperparams\n",
    "batch_size = 10\n",
    "block_size = 256\n",
    "max_iters = 5000\n",
    "eval_interval = 100\n",
    "learning_rate = 3e-5\n",
    "eval_iters = 100\n",
    "d_model = 512\n",
    "n_layers = 12\n",
    "n_head = 18  # d_model need not divide evenly: each head gets d_model // n_head dims\n",
    "dropout = 0.25\n",
    "norm_eps = 1e-5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dZMiYkr37cmU"
   },
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "from torch.nn import functional as F\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "class RMSNorm(nn.Module):\n",
    "    \"\"\"Root-mean-square norm over the last dim with a learned per-feature scale (no bias).\"\"\"\n",
    "    def __init__(self, dim: int, eps: float = 1e-6):\n",
    "        super().__init__()\n",
    "        self.eps = eps\n",
    "        self.weight = nn.Parameter(torch.ones(dim))\n",
    "\n",
    "    def _norm(self, x):\n",
    "        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # normalise in float32, then cast back to the input dtype\n",
    "        output = self._norm(x.float()).type_as(x)\n",
    "        return output * self.weight\n",
    "\n",
    "class SingleHead(nn.Module):\n",
    "    \"\"\"One attention head with a learned relative-position term added to the scores.\"\"\"\n",
    "    def __init__(self,\n",
    "                 head_size: int,\n",
    "                 d_model: int,\n",
    "                 block_size: int,\n",
    "                 dropout: float):\n",
    "        super().__init__()\n",
    "        self.key = nn.Linear(d_model, head_size, bias=True)\n",
    "        self.query = nn.Linear(d_model, head_size, bias=True)\n",
    "        self.value = nn.Linear(d_model, head_size, bias=True)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        # one learned head_size vector per (query position, key position) pair\n",
    "        self.rel_pos_embd = nn.Parameter(torch.randn(block_size, block_size, head_size))\n",
    "        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n",
    "\n",
    "    def forward(self, x: torch.Tensor, mask: bool = False):\n",
    "        \"\"\"x: (B, T, d_model) with T <= block_size; returns (B, T, head_size).\"\"\"\n",
    "        B, T, C = x.shape\n",
    "        key = self.key(x)\n",
    "        query = self.query(x)\n",
    "        # scaled dot-product attention: scores are scaled by 1/sqrt(head_size)\n",
    "        scores = torch.matmul(query, key.transpose(-2, -1)) * (key.shape[-1] ** -0.5)\n",
    "\n",
    "        # relative-position term: query at t dotted with rel_pos_embd[t, v] for every key v\n",
    "        rel_pos_scores = torch.einsum('btc,tvc->btv', query, self.rel_pos_embd[:T, :T])\n",
    "        scores = scores + rel_pos_scores\n",
    "\n",
    "        if mask:\n",
    "            # causal mask: position t may only attend to positions <= t\n",
    "            scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))\n",
    "\n",
    "        att_mat = F.softmax(scores, dim=-1)\n",
    "        att_mat = self.dropout(att_mat)\n",
    "        value = self.value(x)\n",
    "        return torch.matmul(att_mat, value)\n",
    "\n",
    "class MultiHeadAttention(nn.Module):\n",
    "    \"\"\"n_head SingleHeads run in parallel, concatenated and projected back to d_model.\"\"\"\n",
    "    def __init__(self,\n",
    "                 d_model: int,\n",
    "                 block_size: int,\n",
    "                 n_head : int,\n",
    "                 dropout: float):\n",
    "        super().__init__()\n",
    "        head_size = d_model // n_head\n",
    "        self.heads = nn.ModuleList([SingleHead(d_model=d_model, dropout=dropout, block_size=block_size, head_size=head_size) for _ in range(n_head)])\n",
    "        # the concatenated heads have n_head * head_size features, which is smaller than\n",
    "        # d_model when d_model % n_head != 0 (e.g. 512 // 18 = 28 -> 504), so project from that width\n",
    "        self.projection = nn.Linear(n_head * head_size, d_model)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, x: torch.Tensor, mask: bool = False):\n",
    "        out = torch.cat([h(x, mask) for h in self.heads], dim=-1)\n",
    "        out = self.dropout(self.projection(out))\n",
    "        return out\n",
    "\n",
    "class FeedForward(nn.Module):\n",
    "    \"\"\"Position-wise MLP: d_model -> 5*d_model -> d_model with GELU, followed by dropout.\"\"\"\n",
    "    def __init__(self, d_model, dropout):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(d_model, 5 * d_model),\n",
    "            nn.GELU(),\n",
    "            nn.Linear(5 * d_model, d_model),\n",
    "            nn.Dropout(dropout),\n",
    "        )\n",
    "\n",
    "    def forward(self, x: torch.Tensor):\n",
    "        return self.net(x)\n",
    "\n",
    "class DecoderBlock(nn.Module):\n",
    "    \"\"\"Pre-norm decoder block: two causal self-attention sub-layers, then a feed-forward sub-layer.\n",
    "\n",
    "    Both attention sub-layers reuse the same MultiHeadAttention weights, and all three\n",
    "    sub-layers share one RMSNorm.\n",
    "    \"\"\"\n",
    "    def __init__(self, d_model: int,\n",
    "                 block_size: int,\n",
    "                 n_head: int,\n",
    "                 norm_eps: float,\n",
    "                 dropout: float):\n",
    "        super().__init__()\n",
    "        self.self_att = MultiHeadAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)\n",
    "        self.ffwd = FeedForward(d_model, dropout)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.norm = RMSNorm(d_model, eps=norm_eps)\n",
    "\n",
    "    def forward(self, x: torch.Tensor):\n",
    "        x = x + self.dropout(self.self_att(self.norm(x), mask=True))\n",
    "        # `mask` belongs to self_att (RMSNorm takes no mask). This pass must also be causal:\n",
    "        # unmasked, every position could see the next token it is trained to predict\n",
    "        x = x + self.dropout(self.self_att(self.norm(x), mask=True))\n",
    "        x = x + self.dropout(self.ffwd(self.norm(x)))\n",
    "        return x\n",
"\n",
    "class Transformer(nn.Module):\n",
    "    \"\"\"Decoder-only language model over token ids.\n",
    "\n",
    "    Reads block_size, d_model, n_head, n_layers, dropout and norm_eps from the\n",
    "    hyperparameter cell (notebook globals).\n",
    "    \"\"\"\n",
    "    def __init__(self, vocab_size: int):\n",
    "        super().__init__()\n",
    "        self.block_size = block_size\n",
    "        self.token_embeddings = nn.Embedding(vocab_size, d_model)\n",
    "        self.decoder = nn.Sequential(*[DecoderBlock(n_head=n_head, d_model=d_model, dropout=dropout, norm_eps=norm_eps, block_size=block_size) for _ in range(n_layers)])\n",
    "        self.norm_final = RMSNorm(d_model, eps=norm_eps)\n",
    "        self.linear_final = nn.Linear(d_model, vocab_size)\n",
    "        self.dropout = nn.Dropout(dropout)  # NOTE(review): not applied anywhere in forward()\n",
    "        self.apply(self._init_weights)\n",
    "\n",
    "    def _init_weights(self, module):\n",
    "        # N(0, 0.02) weights and zero biases for Linear/Embedding; other parameters\n",
    "        # (RMSNorm.weight, SingleHead.rel_pos_embd) keep their own initialisation\n",
    "        if isinstance(module, nn.Linear):\n",
    "            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
    "            if module.bias is not None:\n",
    "                torch.nn.init.zeros_(module.bias.data)\n",
    "        elif isinstance(module, nn.Embedding):\n",
    "            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
    "\n",
    "    def forward(self, idx, targets=None):\n",
    "        \"\"\"idx: (B, T) token ids. Returns (logits, loss).\n",
    "\n",
    "        Without targets, logits is (B, T, vocab_size) and loss is None. With targets,\n",
    "        logits is flattened to (B*T, vocab_size) and loss is the mean cross-entropy.\n",
    "        \"\"\"\n",
    "        x = self.token_embeddings(idx)\n",
    "        x = self.decoder(x)\n",
    "        logits = self.linear_final(self.norm_final(x))\n",
    "\n",
    "        if targets is None:\n",
    "            return logits, None\n",
    "\n",
    "        B, T, C = logits.shape\n",
    "        logits = logits.view(B*T, C)\n",
    "        targets = targets.view(B*T)\n",
    "        loss = F.cross_entropy(logits, targets)\n",
    "        return logits, loss\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):\n",
    "        \"\"\"Append max_new_tokens sampled ids to idx (B, T) and return the extended tensor.\n",
    "\n",
    "        Sampling runs in eval mode; the module's previous train/eval mode is restored afterwards.\n",
    "        \"\"\"\n",
    "        was_training = self.training\n",
    "        self.eval()\n",
    "        try:\n",
    "            for _ in range(max_new_tokens):\n",
    "                # crop the context to the last block_size tokens\n",
    "                idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]\n",
    "                logits, _ = self(idx_cond)\n",
    "                logits = logits[:, -1, :] / temperature\n",
    "\n",
    "                if top_k is not None:\n",
    "                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))\n",
    "                    logits[logits < v[:, [-1]]] = -float('Inf')\n",
    "\n",
    "                probs = F.softmax(logits, dim=-1)\n",
    "                idx_next = torch.multinomial(probs, num_samples=1)\n",
    "                idx = torch.cat((idx, idx_next), dim=1)\n",
    "        finally:\n",
    "            # calling generate() mid-training must not leave dropout disabled\n",
    "            self.train(was_training)\n",
    "        return idx"
   ]
  },
  {
   "cell_type":
"code",
   "execution_count": null,
   "metadata": {
    "id": "X9VOBZFr7g3W"
   },
   "outputs": [],
   "source": [
    "import timeit\n",
    "start_time = timeit.default_timer()\n",
    "\n",
    "def get_batch(split):\n",
    "    \"\"\"Sample batch_size random windows of block_size tokens; y is x shifted one token ahead.\"\"\"\n",
    "    data = train_data if split == 'train' else val_data\n",
    "    ix = torch.randint(len(data) - block_size, (batch_size,))\n",
    "    x = torch.stack([data[i:i+block_size] for i in ix])\n",
    "    y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
    "    x, y = x.to(device), y.to(device)\n",
    "    return x, y\n",
    "\n",
    "@torch.no_grad()\n",
    "def estimate_loss():\n",
    "    \"\"\"Mean loss (as a float) over eval_iters random batches for each split; leaves the model in train mode.\"\"\"\n",
    "    out = {}\n",
    "    model.eval()\n",
    "    for split in ['train', 'val']:\n",
    "        losses = torch.zeros(eval_iters)\n",
    "        for k in range(eval_iters):\n",
    "            X, Y = get_batch(split)\n",
    "            logits, loss = model(X, Y)\n",
    "            losses[k] = loss.item()\n",
    "        out[split] = losses.mean().item()\n",
    "    model.train()\n",
    "    return out\n",
    "\n",
    "vocab_size = len(token.vocab)\n",
    "model = Transformer(vocab_size)\n",
    "# to resume training, load a saved state dict here:\n",
    "# checkpoint_path = '/content/drive/MyDrive/enigma-2.5b.pth'\n",
    "# checkpoint = torch.load(checkpoint_path, map_location=device)\n",
    "# model.load_state_dict(checkpoint)\n",
    "m = model.to(device)\n",
    "\n",
    "# no of parameters\n",
    "n_param = sum(p.numel() for p in m.parameters())/1e6\n",
    "print(f\"{n_param:.1f} million parameters\")\n",
    "\n",
    "# optimizer\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
    "steps = []\n",
    "train_losses = []\n",
    "val_losses = []\n",
    "\n",
    "# loop variable is `step`, not `iter`, so the builtin iter() stays usable in the notebook\n",
    "for step in range(max_iters):\n",
    "\n",
    "    if step % eval_interval == 0 or step == max_iters - 1:\n",
    "        losses = estimate_loss()\n",
    "        print(f\"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n",
    "\n",
    "        steps.append(step)\n",
    "        train_losses.append(losses['train'])\n",
    "        val_losses.append(losses['val'])\n",
    "\n",
    "    xb, yb = get_batch('train')\n",
    "    logits, loss = model(xb, yb)\n",
    "    optimizer.zero_grad(set_to_none=True)\n",
    "    loss.backward()\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tzJMKoA35uIV"
},
   "outputs": [],
   "source": [
    "end_time = timeit.default_timer()\n",
    "# n_param was computed in millions (numel / 1e6) in the training cell\n",
    "print(f\"total parameters: {n_param:.1f} million\")\n",
    "print(f\"trained in {((end_time - start_time)/3600):.2f}hrs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "eB47Yn9aNrrO"
   },
   "outputs": [],
   "source": [
    "model_save_name = 'consolidated_00.pth'\n",
    "path = f\"/content/drive/MyDrive/{model_save_name}\"\n",
    "torch.save(model.state_dict(), path)\n",
    "\n",
    "# saving safe-tensors (imported here so the .pth above is written even if safetensors is missing)\n",
    "from safetensors.torch import save_file\n",
    "\n",
    "model_save_name = 'consolidated_00.safetensors'\n",
    "path = f\"/content/drive/MyDrive/{model_save_name}\"\n",
    "save_file(model.state_dict(), path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "89TNah_89CRB"
   },
   "outputs": [],
   "source": [
    "!nvidia-smi"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "machine_shape": "hm",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}