{ "cells": [ { "cell_type": "code", "execution_count": 90, "metadata": { "id": "dbsnrDKKVarI", "colab": { "base_uri": "https://localhost:8080/", "height": 1000, "referenced_widgets": [ "4940bbc312654a479a2d006fb193a1dc", "cb166763984b4dc9ac578363ee7a68a0", "8980f01ea76b42bf9c187fa231ea2032", "10e961c20ac04aacbe65379bfe88ffde", "9843b879634248a6a75bc04927868ebf", "d1c98c8bd2584a2a951cc70c36424dde", "60e19650cb264161ae853de734324ed3", "9d4909c2f7894859b494b153da044720", "77c7c0ccfcd9446e8017dbe72b2852f2", "cc755a63cfa0414cba5369e1c29952d8", "a8ef20fdbb0f44e78cc5f9c4e6fdb23c", "821b194c386e483b9eb472e04dd61b59", "fe16b47bf80b4ce3ba4f4233c7994e0c", "de3a8b4206b54a5d86f556895a7a8261", "1335110c980b4d8894cf3d38c7020c20", "b31e53ab714f405fb38f1202b43b950f", "33d06740ad07403aa39b935896cf624d", "b814e9cbda7d4b9b8967d71718980994", "dba5a7d271ef429ebaed65d3d1773616", "59308a01d3034b42977b2df4cd9faacf", "9c9c7af2bb4a4b92a69987dd6a5ad09f", "53d6dfc9923c4c1b8866f7e03cbdca53", "e5bf5d3476654592a16213663e4a2bf8", "3fc3505da3664d95823ee2deb60237ea", "3484f174bd734930b367bc53bfa00d0d", "2372cce16b6a43a097bcb621a38e6e81", "01476e05861d463bb6155f078c250ab5", "d8da5a2dbae24c06881c5b6caa68398d", "ae3581adbc404f9fb28806779623ee7b", "80eeedade7124e25acc3b82cf2cd3913", "ef0ca9862e5642fc941f5ad986d7ac23", "342b62d5cd464efbab98c46f712ccaf0", "ba1b482addfb4af5906564b9e6a33cbe", "53b3d56551c74a928a6340744564e4b1", "d0246c0c4fbc4280892242c4bbe2d534", "e176c733436e445691c99519d4afae5d", "3c5318c4d6bd4c98876cfa24557fe04e", "f1766d28b0a3444f8c4aab56431341ba", "8185d6e8df4b42239632418dd73fa52c", "3e0786124dff4ad99fa1760d8652da77", "860f065014494f41940b580478f6edc9", "ce734f6e638548d19eb3e8424b7cfe39", "096689e22c204388983b8f9711afa836", "7c70cb93f3a3463192b8c33ddb53179f", "c77e344106cc4936a615aa2e29db011c", "9edb8202f5a146bb8c625529db891359", "edffa04c9dce4eea8f496f8090f26cbf", "393cfd777a3449798e4f3ed331c325fa", "56b1a5dcb4514f35a58e9bf1130a46ee", "09c26bf8735a431d9e5d867b4201c3cc", 
"1c860651d30a42489f1816db4a2edd90", "98a128a931f6483696126b3eb7ad7f80", "3f854219ee394259b2ae3198427a821c", "e94e3433d4a442938e218f45947b007e", "ab90b0e8d2804218ae3b29828404d0c8", "3cf559f2135144dcb9de4b7a1f4d0e0c", "65cf73e666844d1685943e1c7b9c202f", "9fac853bef854b79ae97ef061349441d", "8733866f7ac1438f9a70166647e17216", "42f8ee052c814e62bfad7048ff9521c2", "d33c1c4011e840e4983ec9562e37606d", "8d25a070cf1d4cbcb2a2a014df01dd2b", "f40104e4b3a544908ad4a3ed54e610aa", "2c51ed4764a34a998ff674f6202de391", "a97e5ef737d34460991a0827dae059af", "77cab8fdf36e4c26b995f919dbbfd3df", "a2bbf27484b8448b85a3814c28b6b0e0", "aaed9115254148a78ce2f4e23105260d", "c9c353fc75d641d883a2373f53f9b2f5", "43caad4161444864974cd05836d51b15", "7858e452142c4285a9f88ba20b91e851", "67b7376057fd46e89a2dcd295ff1682b", "f3a763184fe9468c8a89465fa3bae703", "603c8de331f24fb38c512c549a1f4770", "f297fa9f28344c1988b498064c9d779e", "4670a05116e84358b5505adfaace7cb0", "0c717191af2643a588ba4149a2f2c6e7", "5d1e9fa2170949c497c86105a041bab0", "1f13f740fcf64fe0ba0d510fce11e87f", "c337381b3b5246ad9034f71f7fa77d2f", "2247e86b596d47c79a8ea0febb316925", "84f9e148be4d464a8548a38910ca4141", "6d21eeb340414cd897ceb8043402cb2c", "75256bd58af04da7bb1e21e7783937a6", "b810ca727b7048cdbcf87214a1348581", "16f8cf2352da4849bda87a5e9970c46e", "c9e6c8f02a5c494aafa98432510dfc1d", "e59d30fc2eaf4f8fb07c77c1d9b95d77", "76220d3230244c909e16a0612a296f0a", "3ac15cd823044417b36dce730ab8f184", "d560df2b7b414049a4927041bc371337", "d2c1fb4d6d064ee7a1a6d168ebf3a8dd", "66e0730542614559a16b04e9e8974576", "c449d3e341c241efbca470080f9702ab", "fde0dc4e5c1e485d979c2635e88a7eca", "485c3146b9684f4cbf2b267f1afc13ac", "6087db2f7d7d4b2483a53ca40ca7dfc0", "5bba74bb6bd749d4a5752e35bbb2bfd8", "91c769901f784d3ba2d6f40a96cf176c", "09a68a426e16470c9765be71c5525926", "a48eb5d08cdb4bd68c732ec8764d029d", "7b968692a15a4590ba5b6f9e957b0005", "16517b6ea6d64b69abe72c76dbd38e70", "370c9bd31fd2495ca5e06f10e9889a11", "0a0a53ce7f4c4848877b40be2589fc27", 
"7dde8d78da5c4576ade994a0dad4b991", "81409b94afaf4666b79ca98b64b40b27", "a5b09b70a1d646cea7ff1143a784b349", "e2804b0d3b3d40c2bbef5e33e46e9ead", "6f8761d7de4e4b0f8c054f86399543aa" ] }, "outputId": "5808189b-e624-42d7-856f-bc3b0201fab9", "ExecuteTime": { "end_time": "2024-04-16T22:59:34.679348Z", "start_time": "2024-04-16T22:59:19.314163Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: datasets in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (2.18.0)\n", "Requirement already satisfied: wandb in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (0.16.6)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "wandb: WARNING Calling wandb.login() after wandb.init() has no effect.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: accelerate in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (0.28.0)\n", "Token has not been saved to git credential helper. 
Pass `add_to_git_credential=True` if you want to set the git credential as well.Requirement already satisfied: filelock in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (3.9.0)\n", "\n", "Requirement already satisfied: numpy>=1.17 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (1.23.5)\n", "Requirement already satisfied: pyarrow>=12.0.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (15.0.2)\n", "Requirement already satisfied: pyarrow-hotfix in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (0.6)\n", "Requirement already satisfied: dill<0.3.9,>=0.3.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (0.3.8)\n", "Requirement already satisfied: pandas in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (2.2.1)\n", "Requirement already satisfied: requests>=2.19.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (2.31.0)\n", "Requirement already satisfied: tqdm>=4.62.1 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (4.65.0)\n", "Requirement already satisfied: xxhash in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (3.4.1)\n", "Requirement already satisfied: multiprocess in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (0.70.16)\n", "Requirement already satisfied: fsspec<=2024.2.0,>=2023.1.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from fsspec[http]<=2024.2.0,>=2023.1.0->datasets) 
(2024.2.0)\n", "Requirement already satisfied: aiohttp in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (3.9.3)\n", "Requirement already satisfied: huggingface-hub>=0.19.4 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (0.21.4)\n", "Requirement already satisfied: packaging in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (24.0)\n", "Requirement already satisfied: pyyaml>=5.1 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (6.0.1)\n", "Requirement already satisfied: Click!=8.0.0,>=7.1 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from wandb) (8.1.7)\n", "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from wandb) (3.1.43)\n", "Requirement already satisfied: psutil>=5.0.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from wandb) (5.9.8)\n", "Requirement already satisfied: sentry-sdk>=1.0.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from wandb) (1.45.0)\n", "Requirement already satisfied: docker-pycreds>=0.4.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from wandb) (0.4.0)\n", "Requirement already satisfied: setproctitle in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from wandb) (1.3.3)\n", "Requirement already satisfied: setuptools in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from wandb) (69.2.0)\n", "Requirement already satisfied: appdirs>=1.4.3 in 
c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from wandb) (1.4.4)\n", "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from wandb) (4.25.3)\n", "Requirement already satisfied: torch>=1.10.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from accelerate) (2.0.1+cu118)\n", "Requirement already satisfied: safetensors>=0.3.1 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from accelerate) (0.4.2)\n", "Requirement already satisfied: colorama in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from Click!=8.0.0,>=7.1->wandb) (0.4.6)\n", "Requirement already satisfied: six>=1.4.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n", "Requirement already satisfied: aiosignal>=1.1.2 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement already satisfied: attrs>=17.3.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from aiohttp->datasets) (23.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from aiohttp->datasets) (1.4.1)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from aiohttp->datasets) (6.0.5)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from aiohttp->datasets) (1.9.4)\n", "Requirement already satisfied: 
async-timeout<5.0,>=4.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from aiohttp->datasets) (4.0.3)\n", "Requirement already satisfied: gitdb<5,>=4.0.1 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.11)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from huggingface-hub>=0.19.4->datasets) (4.11.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from requests>=2.19.0->datasets) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from requests>=2.19.0->datasets) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from requests>=2.19.0->datasets) (2.0.2)\n", "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from requests>=2.19.0->datasets) (2024.2.2)\n", "Requirement already satisfied: sympy in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from torch>=1.10.0->accelerate) (1.12)\n", "Requirement already satisfied: networkx in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from torch>=1.10.0->accelerate) (3.2.1)\n", "Requirement already satisfied: jinja2 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from torch>=1.10.0->accelerate) (3.1.2)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in 
c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from pandas->datasets) (2.9.0)\n", "Requirement already satisfied: pytz>=2020.1 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from pandas->datasets) (2024.1)\n", "Requirement already satisfied: tzdata>=2022.7 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from pandas->datasets) (2024.1)\n", "Requirement already satisfied: smmap<6,>=3.0.1 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.1)\n", "Requirement already satisfied: MarkupSafe>=2.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.3)\n", "Requirement already satisfied: mpmath>=0.19 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n", "Token is valid (permission: write).\n", "Your token has been saved to C:\\Users\\saad.naeem\\.cache\\huggingface\\token\n", "Login successful\n" ] } ], "source": [
    "# @title # 🌊 AutoBitnet\n",
    "\n",
    "# @markdown ---\n",
    "\n",
    "# @markdown ### ✨ Model Parameters\n",
    "\n",
    "MODEL_CONFIG = \"NousResearch/Nous-Hermes-llama-2-7b\" # @param {type:\"string\"}\n",
    "HEADS = 6 # @param {type: \"number\"}\n",
    "DIMENSIONS = 768 # @param {type: \"number\"}\n",
    "LAYERS = 6 # @param {type: \"number\"}\n",
    "INTERMEDIATE_SIZE= 1024 # @param {type: \"number\"}\n",
    "CONTEXT_LENGTH = 256 # @param {type: \"number\"}\n",
    "HUGGINGFACE_ID = \"saadnaeem\" # @param {type:\"string\"}\n",
    "NEW_MODEL = \"Llama2-70M-Cosmopedia-100k-Pretrained\" # @param {type:\"string\"}\n",
    "# Read credentials from the environment instead of pasting them into the notebook,\n",
    "# so they never end up in a committed .ipynb. When a variable is unset the value is\n",
    "# None (not an empty string), and the libraries fall back to their own login flow.\n",
    "import os\n",
    "WANDB_TOKEN = os.environ.get(\"WANDB_API_KEY\")\n",
    "HF_TOKEN = os.environ.get(\"HF_TOKEN\")\n",
    "\n",
    "# @markdown ---\n",
    "\n",
    "# @markdown ### 💥 Training Parameters\n",
    "\n",
    "DATASET = \"abideen/Cosmopedia-100k-pretrain\" # @param {type:\"string\"}\n",
    "BATCH_SIZE = 32 # @param {type:\"number\"}\n",
    "LEARNING_RATE = 1.5e-4 # @param {type:\"number\"}\n",
    "EPOCHS = 1 # @param {type:\"number\"}\n",
    "!pip install datasets wandb accelerate\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "from transformers.models.llama.modeling_llama import *\n",
    "# Import the Llama building blocks used below explicitly rather than relying on the wildcard.\n",
    "from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaMLP, LlamaRMSNorm\n",
    "from transformers import (AutoTokenizer, AutoConfig, LlamaForCausalLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments, AutoModel)\n",
    "from datasets import load_dataset\n",
    "from huggingface_hub import login\n",
    "import wandb\n",
    "# wandb.ai/saadnaeem-dev\n",
    "\n",
    "from huggingface_hub import create_repo, HfApi\n",
    "\n",
    "def activation_quant(x):\n",
    "    # Absmax fake-quantization along the last dimension: scale each row so its largest\n",
    "    # magnitude maps to 127, round and clamp to [-128, 127], then scale back.\n",
    "    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)\n",
    "    y = (x * scale).round().clamp_(-128, 127) / scale\n",
    "    return y\n",
    "def weight_quant(w):\n",
    "    # Ternary (absmean) fake-quantization: divide by mean|w| over the whole tensor, round\n",
    "    # and clamp to {-1, 0, 1}, then multiply back by mean|w|.\n",
    "    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)\n",
    "    u = (w * scale).round().clamp_(-1, 1) / scale\n",
    "    return u\n",
    "\n",
    "class BitLinear(nn.Linear):\n",
    "    # nn.Linear whose forward pass RMS-normalizes the input, fake-quantizes activations\n",
    "    # (activation_quant) and weights (weight_quant), and uses a straight-through\n",
    "    # estimator so gradients still reach the full-precision weights.\n",
    "    def forward(self, x):\n",
    "        w = self.weight # a weight tensor with shape [d, k]\n",
    "        x = x.to(w.device)\n",
    "        # NOTE(review): a new LlamaRMSNorm is constructed on every call, so its scale is\n",
    "        # re-allocated each step and is never seen by the optimizer (it is not a registered\n",
    "        # submodule). Register one in __init__ if a learnable scale is intended.\n",
    "        RMSNorm = LlamaRMSNorm(x.shape[-1]).to(w.device)\n",
    "        x_norm = RMSNorm(x)\n",
    "        # A trick for implementing Straight-Through Estimator (STE) using detach()\n",
    "        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()\n",
    "        w_quant = w + (weight_quant(w) - w).detach()\n",
    "        # Pass the bias through; it is None for bias-free layers, so those are unchanged.\n",
    "        y = F.linear(x_quant, w_quant, self.bias)\n",
    "        return y\n",
    "\n",
    "def convert_to_bitnet(model, copy_weights):\n",
    "    # In place: swap every nn.Linear inside attention/MLP blocks for a BitLinear and\n",
    "    # replace each decoder layer's input_layernorm with nn.Identity (BitLinear\n",
    "    # normalizes its own input). copy_weights=True reuses the original parameters.\n",
    "    for name, module in model.named_modules():\n",
    "        # Replace linear layers with BitNet. Matching the LlamaAttention base class (which\n",
    "        # LlamaSdpaAttention subclasses) covers every attention implementation, not only SDPA.\n",
    "        if isinstance(module, (LlamaAttention, LlamaMLP)):\n",
    "            for child_name, child_module in module.named_children():\n",
    "                if isinstance(child_module, nn.Linear):\n",
    "                    # Build the replacement on the same device as the layer it replaces instead\n",
    "                    # of hard-coding cuda:0, so the model never ends up split across devices.\n",
    "                    bitlinear = BitLinear(child_module.in_features, child_module.out_features, child_module.bias is not None).to(device=child_module.weight.device)\n",
    "                    if copy_weights:\n",
    "                        bitlinear.weight = child_module.weight\n",
    "                        if child_module.bias is not None:\n",
    "                            bitlinear.bias = child_module.bias\n",
    "                    setattr(module, child_name, bitlinear)\n",
    "        # Remove redundant input_layernorms\n",
    "        elif isinstance(module, LlamaDecoderLayer):\n",
    "            for child_name, child_module in module.named_children():\n",
    "                if isinstance(child_module, LlamaRMSNorm) and child_name == \"input_layernorm\":\n",
    "                    # nn.Identity holds no tensors, so no device move is needed.\n",
    "                    setattr(module, child_name, nn.Identity())\n",
    "\n",
    "\n",
    "wandb.login(key=WANDB_TOKEN)\n",
    "login(token=HF_TOKEN)\n",
    "data = load_dataset(DATASET)"
   ]
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_CONFIG)\n",
    "\n",
    "def tokenize(element):\n",
    "    # Tokenize every document whole (no truncation), join them into one token stream with\n",
    "    # an EOS token after each document, and cut the stream into CONTEXT_LENGTH-token rows.\n",
    "    # A trailing remainder shorter than CONTEXT_LENGTH is dropped.\n",
    "    outputs = tokenizer(element[\"text\"], truncation=False)\n",
    "    # Combine all tokens\n",
    "    combined = []\n",
    "    for tokenized_doc in outputs['input_ids']:\n",
    "        combined.extend(tokenized_doc)\n",
    "        combined.append(tokenizer.eos_token_id)\n",
    "    # Chunk. The stop value is len - CONTEXT_LENGTH + 1 so the last full chunk is kept\n",
    "    # when len(combined) is an exact multiple of CONTEXT_LENGTH.\n",
    "    input_batch = [\n",
    "        combined[i:i + CONTEXT_LENGTH]\n",
    "        for i in range(0, len(combined) - CONTEXT_LENGTH + 1, CONTEXT_LENGTH)\n",
    "    ]\n",
    "    return {\"input_ids\": input_batch}\n",
    "\n",
    "\n",
    "\n",
    "tokenized_data = data.map(\n",
    "    tokenize, batched=True, remove_columns=data[\"train\"].column_names,\n",
    ")"
   ],
   "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-04-16T22:55:40.522736Z", "start_time": "2024-04-16T22:55:39.803487Z" } },
   "execution_count": 80
  },
  {
   "cell_type": "code",
   "outputs": [ { "data": { "text/plain": "DatasetDict({\n    train: Dataset({\n        features: ['input_ids'],\n        num_rows: 476702\n    })\n})" }, "execution_count": 81, "metadata": {}, "output_type": "execute_result" } ],
   "source": [ "tokenized_data" ],
   "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-04-16T22:55:41.976375Z", "start_time": "2024-04-16T22:55:41.955375Z" } },
"execution_count": 81 }, { "cell_type": "code", "outputs": [], "source": [ "from datasets import DatasetDict\n", "\n", "# Set the number of rows\n", "tokenized_data['train'].set_format(type='pandas')" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-04-16T22:55:43.097378Z", "start_time": "2024-04-16T22:55:43.076377Z" } }, "execution_count": 82 }, { "cell_type": "code", "outputs": [], "source": [ "sampled_dataset = tokenized_data['train'].select(range(1000))\n", "sampled_dataset_dict = DatasetDict({\n", " 'train': sampled_dataset\n", "})" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-04-16T22:55:44.559805Z", "start_time": "2024-04-16T22:55:44.537823Z" } }, "execution_count": 83 }, { "cell_type": "code", "outputs": [ { "data": { "text/plain": " input_ids\n0 [1, 2266, 338, 385, 6597, 515, 263, 24499, 299...", "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n
input_ids
0[1, 2266, 338, 385, 6597, 515, 263, 24499, 299...
\n
" }, "execution_count": 85, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sampled_dataset_dict['train'][0]" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-04-16T22:56:05.152275Z", "start_time": "2024-04-16T22:56:05.132254Z" } }, "execution_count": 85 }, { "cell_type": "code", "outputs": [ { "data": { "text/plain": "DatasetDict({\n train: Dataset({\n features: ['input_ids'],\n num_rows: 1000\n })\n})" }, "execution_count": 86, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenized_data = sampled_dataset_dict\n", "tokenized_data" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-04-16T22:56:06.004257Z", "start_time": "2024-04-16T22:56:05.990254Z" } }, "execution_count": 86 }, { "cell_type": "code", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training on 256_000 tokens\n", "Model size: 77.5M parameters\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\saad.naeem\\AppData\\Local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages\\accelerate\\accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). 
Please pass an `accelerate.DataLoaderConfiguration` instead: \n", "dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\n", "  warnings.warn(\n" ] } ], "source": [
    "total_tokens = tokenized_data['train'].num_rows * CONTEXT_LENGTH\n",
    "print(f\"Training on {total_tokens:_} tokens\")\n",
    "\n",
    "config = AutoConfig.from_pretrained(\n",
    "    MODEL_CONFIG,\n",
    "    vocab_size=len(tokenizer),\n",
    "    n_ctx=CONTEXT_LENGTH,\n",
    "    bos_token_id=tokenizer.bos_token_id,\n",
    "    eos_token_id=tokenizer.eos_token_id,\n",
    ")\n",
    "\n",
    "config.hidden_size = DIMENSIONS\n",
    "# NOTE(review): this ties the position limit to the hidden size (DIMENSIONS) rather than\n",
    "# to CONTEXT_LENGTH, the length the data is chunked to; confirm which one is intended.\n",
    "config.max_position_embeddings = DIMENSIONS\n",
    "config.num_attention_heads = HEADS\n",
    "config.num_hidden_layers = LAYERS\n",
    "config.num_key_value_heads = HEADS\n",
    "config.intermediate_size = INTERMEDIATE_SIZE\n",
    "\n",
    "### Create the llama model with our custom config. Convert it to bitnet.\n",
    "model = LlamaForCausalLM(config)\n",
    "convert_to_bitnet(model, copy_weights=False)\n",
    "model_size = sum(t.numel() for t in model.parameters())\n",
    "print(f\"Model size: {model_size/1000**2:.1f}M parameters\")\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)\n",
    "\n",
    "# Derive the local output folder from NEW_MODEL so it always matches the Hub repo name.\n",
    "output_path = f\"./{NEW_MODEL}\"\n",
    "args = TrainingArguments(\n",
    "    output_dir=output_path,\n",
    "    per_device_train_batch_size=BATCH_SIZE,\n",
    "    logging_steps=100,\n",
    "    gradient_accumulation_steps=2,\n",
    "    num_train_epochs=EPOCHS,\n",
    "    weight_decay=0.01,\n",
    "    # Warm up over the first 10% of optimizer steps. A fraction must go in warmup_ratio:\n",
    "    # warmup_steps is an integer step count, so warmup_steps=0.1 gave no real warmup.\n",
    "    warmup_ratio=0.1,\n",
    "    lr_scheduler_type=\"cosine\",\n",
    "    learning_rate=LEARNING_RATE,\n",
    "    # max_steps=5000,\n",
    "    save_steps=0.25,\n",
    "    fp16=True,\n",
    "    report_to=\"wandb\"\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    args=args,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=tokenized_data[\"train\"],\n",
    ")"
   ], "metadata": { "collapsed": false, "ExecuteTime": {
"end_time": "2024-04-16T22:56:11.670768Z", "start_time": "2024-04-16T22:56:09.804760Z" } }, "execution_count": 87 }, { "cell_type": "code", "outputs": [ { "data": { "text/plain": "", "text/html": "\n
\n \n \n [16/16 01:46, Epoch 1/1]\n
\n \n \n \n \n \n \n \n \n \n
StepTraining Loss

" }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": "TrainOutput(global_step=16, training_loss=9.391032218933105, metrics={'train_runtime': 110.6973, 'train_samples_per_second': 9.034, 'train_steps_per_second': 0.145, 'total_flos': 81244717056000.0, 'train_loss': 9.391032218933105, 'epoch': 1.0})" }, "execution_count": 88, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.train()" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-04-16T22:58:19.694121Z", "start_time": "2024-04-16T22:56:28.225808Z" } }, "execution_count": 88 }, { "cell_type": "code", "source": [ "trainer.save_model(f\"{output_path}\")\n", "folder = f\"{output_path}\"\n", "api = HfApi()\n", "create_repo(\n", " repo_id = f\"{HUGGINGFACE_ID}/{NEW_MODEL}\",\n", " repo_type=\"model\",\n", " exist_ok=True,\n", " token=HF_TOKEN,\n", ")\n", "\n", "api.upload_folder(\n", " folder_path=folder,\n", " repo_type=\"model\",\n", " repo_id=f\"{HUGGINGFACE_ID}/{NEW_MODEL}\",\n", " token=HF_TOKEN,\n", ")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 180, "referenced_widgets": [ "a848b25d4d39481c820f8c3f23bcd42a", "cc902988873a4ab9834a24e0af8b2d20", "2a0b1217926d4275b60372115cc98865", "13fb9a000a534fa48ab42645e176fac9", "c5702e0bfb7a48238578de1cb745e840", "1296ba56040e4edf9ccc2e2e8e95013b", "1ae401d7df7a4e1bb830620147dbacfe", "acedbeec367b4b35bb6057215eeb15b4", "c359f46995fb47d7821ae766307da9b3", "85566c08eb3f48f1b8be446cd2d6a317", "ae482357a4da4167b1f03a0e66c1c2ba", "23bef7e85efc401cb585c4f152f90e3c", "16b1e6fc66e0460fa96d754a14be00e0", "e6df2bc2ab824f569ffddcc1da9b2f1e", "0804fd95faad462f9f357d9ac29803a4", "69b053193a7a403285aa81ed0bc2e58d", "74e7862905644de287d15b0f9edac963", "18cc938d01f64e368cec6d191c82752b", "16ad95ade9c74a37a1bc22a77ce09b7b", "cc533389f6d44dbdbd6db5987f27b899", "09a0af0767fc493a8fb538ebc1999729", "8a9793e95bc34719aaf321dbfc29fc49", "034a1b93dd274c079ab78534d9735514", 
"17b2a21e1b5d44099c5b2621077b0d5d", "736074bd252b42b5ad50e292b5b4867f", "3eebb06c6444454f807c93da83cfd7db", "5f2f9367d4c34006b0e4676caf88a9f4", "4d0998c6a38f455495adb68f6ac89caf", "00c4e4ebe0404d2fa26861092a29ac25", "cf28eca638474e2fa2667e53185d2a17", "127e34d4e8e44071b7a6a3c1463dab97", "781a1633fe07463d9e764b7708bfbd47", "671280bc2bdb4d97973e8d4bc99ac36f", "a7dee397e405418a84d122fd308818d8", "26e8629200724daf91feeed8938ca836", "9670e65c8f144a3a92255030862150cf", "864f9d376f2a47ba9b71f738db56d5f9", "fdf5d01ee85f4ecfa3d1741f454a4516", "b7df02dc187b40e99293f182a8084b72", "a005ae2fef0e44b5aeefb967542631de", "43e9230138fe45adba98d89e051319fe", "f099fe879f404002abd817bd5fa32d43", "d3038606a04844ba810073f7d6e027cc", "740c3d5338af41dfa2d7e1d90823cb6c" ] }, "id": "mnHZU06l5tG3", "outputId": "bfa63618-ae11-4415-a695-0349dfecf4ad", "is_executing": true, "ExecuteTime": { "start_time": "2024-04-16T22:59:36.439911Z" } }, "execution_count": null, "outputs": [ { "data": { "text/plain": "Upload 9 LFS files: 0%| | 0/9 [00:00