joaoalvarenga committed on
Commit
30df0c3
1 Parent(s): 2598b16

Upload fine-tuning-example.ipynb

Files changed (1)
  1. fine-tuning-example.ipynb +325 -0
fine-tuning-example.ipynb ADDED
@@ -0,0 +1,325 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "e13eff4e-c134-4dac-9523-07b297164250",
+ "metadata": {},
+ "source": [
+ "# Fine-tuning the 176-billion-parameter BLOOM with 8-bit weights\n",
+ "\n",
+ "This notebook shows how to fine-tune BLOOM with low-rank adapters (LoRA). Heavily inspired by [Hivemind's work](https://colab.research.google.com/drive/1ft6wQU0BhqG5PRlwgaZJv2VukKKjU4Es)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "699e94eb-3ce1-4788-999b-fb6d593ba7e9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install transformers==4.20.1\n",
+ "!pip install bitsandbytes-cuda110\n",
+ "!pip install datasets"
+ ]
+ },
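+ {
+ "cell_type": "markdown",
+ "id": "1c9a7d2e-4b3f-4a8c-9e5d-0f6b8a2c1d3e",
+ "metadata": {},
+ "source": [
+ "The `bitsandbytes-cuda110` wheel targets CUDA 11.0, so it is worth confirming what this machine actually provides. A minimal check (assuming a single visible GPU):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2d8b6c1f-5a4e-4b9d-8c7e-1a2b3c4d5e6f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "# CUDA version PyTorch was built against, and the visible device (if any)\n",
+ "print(torch.version.cuda)\n",
+ "print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'no GPU available')"
+ ]
+ },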
+ {
+ "cell_type": "markdown",
+ "id": "0afea72c-691d-4719-a84a-663f1891af6e",
+ "metadata": {},
+ "source": [
+ "### Load and convert the original Bloom structure to 8-bit LoRA\n",
+ "\n",
+ "You can load an already compressed 8-bit version of Bloom from [joaoalvarenga/bloom-8bit](https://huggingface.co/joaoalvarenga/bloom-8bit), but first we need to make some adaptations to the original model structure. Some of the following code is adapted from [Hivemind's GPT-J 8-bit fine-tuning notebook](https://colab.research.google.com/drive/1ft6wQU0BhqG5PRlwgaZJv2VukKKjU4Es)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "aa5f4118-d4d9-474f-ac36-acaadb920c1f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import transformers\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "from torch import nn\n",
+ "from torch.cuda.amp import custom_fwd, custom_bwd\n",
+ "\n",
+ "from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise\n",
+ "\n",
+ "from tqdm.auto import tqdm"
+ ]
+ },
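+ {
+ "cell_type": "markdown",
+ "id": "b2c3d4e5-f6a7-4b8c-9d0e-1f2a3b4c5d6e",
+ "metadata": {},
+ "source": [
+ "Before touching the model, it helps to see what blockwise 8-bit quantization does to a tensor. The sketch below (illustrative only, on a random tensor) round-trips through the two `bitsandbytes` functions just imported and prints the reconstruction error:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c3d4e5f6-a7b8-4c9d-0e1f-2a3b4c5d6e7f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Use the GPU if available; some bitsandbytes builds only ship CUDA kernels\n",
+ "x = torch.randn(4096, 64, device='cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "\n",
+ "x_q, (absmax, code) = quantize_blockwise(x)\n",
+ "x_deq = dequantize_blockwise(x_q, absmax=absmax, code=code)\n",
+ "\n",
+ "# Reconstruction error should be small relative to the unit variance of the input\n",
+ "print('mean abs error:', (x - x_deq).abs().mean().item())"
+ ]
+ },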
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "cc4f262e-70de-4a06-a5a6-52d1cd5223d3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class FrozenBNBLinear(nn.Module):\n",
+ "    def __init__(self, weight, absmax, code, bias=None):\n",
+ "        assert isinstance(bias, nn.Parameter) or bias is None\n",
+ "        super().__init__()\n",
+ "        self.out_features, self.in_features = weight.shape\n",
+ "        self.register_buffer(\"weight\", weight.requires_grad_(False))\n",
+ "        self.register_buffer(\"absmax\", absmax.requires_grad_(False))\n",
+ "        self.register_buffer(\"code\", code.requires_grad_(False))\n",
+ "        self.adapter = None\n",
+ "        self.bias = bias\n",
+ "\n",
+ "    def forward(self, input):\n",
+ "        output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)\n",
+ "        if self.adapter:\n",
+ "            output += self.adapter(input)\n",
+ "        return output\n",
+ "\n",
+ "    @classmethod\n",
+ "    def from_linear(cls, linear: nn.Linear) -> \"FrozenBNBLinear\":\n",
+ "        weights_int8, state = quantize_blockise_lowmemory(linear.weight)\n",
+ "        return cls(weights_int8, *state, linear.bias)\n",
+ "\n",
+ "    def __repr__(self):\n",
+ "        return f\"{self.__class__.__name__}({self.in_features}, {self.out_features})\"\n",
+ "\n",
+ "\n",
+ "class DequantizeAndLinear(torch.autograd.Function):\n",
+ "    @staticmethod\n",
+ "    @custom_fwd\n",
+ "    def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,\n",
+ "                absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):\n",
+ "        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)\n",
+ "        ctx.save_for_backward(input, weights_quantized, absmax, code)\n",
+ "        ctx._has_bias = bias is not None\n",
+ "        return F.linear(input, weights_deq, bias)\n",
+ "\n",
+ "    @staticmethod\n",
+ "    @custom_bwd\n",
+ "    def backward(ctx, grad_output: torch.Tensor):\n",
+ "        assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]\n",
+ "        input, weights_quantized, absmax, code = ctx.saved_tensors\n",
+ "        # grad_output: [*batch, out_features]\n",
+ "        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)\n",
+ "        grad_input = grad_output @ weights_deq\n",
+ "        grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None\n",
+ "        return grad_input, None, None, None, grad_bias\n",
+ "\n",
+ "\n",
+ "class FrozenBNBEmbedding(nn.Module):\n",
+ "    def __init__(self, weight, absmax, code):\n",
+ "        super().__init__()\n",
+ "        self.num_embeddings, self.embedding_dim = weight.shape\n",
+ "        self.register_buffer(\"weight\", weight.requires_grad_(False))\n",
+ "        self.register_buffer(\"absmax\", absmax.requires_grad_(False))\n",
+ "        self.register_buffer(\"code\", code.requires_grad_(False))\n",
+ "        self.adapter = None\n",
+ "\n",
+ "    def forward(self, input, **kwargs):\n",
+ "        with torch.no_grad():\n",
+ "            # note: both quantized weights and input indices are *not* differentiable\n",
+ "            weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)\n",
+ "            output = F.embedding(input, weight_deq, **kwargs)\n",
+ "        if self.adapter:\n",
+ "            output += self.adapter(input)\n",
+ "        return output\n",
+ "\n",
+ "    @classmethod\n",
+ "    def from_embedding(cls, embedding: nn.Embedding) -> \"FrozenBNBEmbedding\":\n",
+ "        weights_int8, state = quantize_blockise_lowmemory(embedding.weight)\n",
+ "        return cls(weights_int8, *state)\n",
+ "\n",
+ "    def __repr__(self):\n",
+ "        return f\"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})\"\n",
+ "\n",
+ "\n",
+ "def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):\n",
+ "    assert chunk_size % 4096 == 0\n",
+ "    code = None\n",
+ "    chunks = []\n",
+ "    absmaxes = []\n",
+ "    flat_tensor = matrix.view(-1)\n",
+ "    for i in range((matrix.numel() - 1) // chunk_size + 1):\n",
+ "        input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()\n",
+ "        quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)\n",
+ "        chunks.append(quantized_chunk)\n",
+ "        absmaxes.append(absmax_chunk)\n",
+ "\n",
+ "    matrix_i8 = torch.cat(chunks).reshape_as(matrix)\n",
+ "    absmax = torch.cat(absmaxes)\n",
+ "    return matrix_i8, (absmax, code)\n",
+ "\n",
+ "\n",
+ "def convert_to_int8(model):\n",
+ "    \"\"\"Convert linear and embedding modules to 8-bit with optional adapters\"\"\"\n",
+ "    for module in list(model.modules()):\n",
+ "        for name, child in module.named_children():\n",
+ "            if isinstance(child, nn.Linear):\n",
+ "                print(name, child)\n",
+ "                setattr(\n",
+ "                    module,\n",
+ "                    name,\n",
+ "                    FrozenBNBLinear(\n",
+ "                        weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),\n",
+ "                        absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),\n",
+ "                        code=torch.zeros(256),\n",
+ "                        bias=child.bias,\n",
+ "                    ),\n",
+ "                )\n",
+ "            elif isinstance(child, nn.Embedding):\n",
+ "                setattr(\n",
+ "                    module,\n",
+ "                    name,\n",
+ "                    FrozenBNBEmbedding(\n",
+ "                        weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),\n",
+ "                        absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),\n",
+ "                        code=torch.zeros(256),\n",
+ "                    )\n",
+ "                )"
+ ]
+ },
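+ {
+ "cell_type": "markdown",
+ "id": "d4e5f6a7-b8c9-4d0e-9f2a-3b4c5d6e7f8a",
+ "metadata": {},
+ "source": [
+ "As a quick check that the frozen 8-bit layer behaves like the layer it replaces, we can convert a small `nn.Linear` with `FrozenBNBLinear.from_linear` and compare outputs. This is a minimal sketch on random data; any difference comes only from 8-bit rounding:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e5f6a7b8-c9d0-4e1f-8a2b-4c5d6e7f8a9b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Run on GPU if available, since the bitsandbytes kernels target CUDA\n",
+ "demo_device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+ "\n",
+ "linear = nn.Linear(64, 32).to(demo_device)\n",
+ "frozen = FrozenBNBLinear.from_linear(linear).to(demo_device)\n",
+ "\n",
+ "x = torch.randn(2, 64, device=demo_device)\n",
+ "print('max abs difference:', (linear(x) - frozen(x)).abs().max().item())"
+ ]
+ },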
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f4673d4c-0f4e-482e-ac04-b7389397af6e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class BloomBlock(transformers.models.bloom.modeling_bloom.BloomBlock):\n",
+ "    def __init__(self, config, layer_number=None):\n",
+ "        super().__init__(config, layer_number)\n",
+ "\n",
+ "        convert_to_int8(self.self_attention)\n",
+ "        convert_to_int8(self.mlp)\n",
+ "\n",
+ "\n",
+ "class BloomModel(transformers.models.bloom.modeling_bloom.BloomModel):\n",
+ "    def __init__(self, config):\n",
+ "        super().__init__(config)\n",
+ "        convert_to_int8(self)\n",
+ "\n",
+ "\n",
+ "class BloomForCausalLM(transformers.models.bloom.modeling_bloom.BloomForCausalLM):\n",
+ "    def __init__(self, config):\n",
+ "        super().__init__(config)\n",
+ "        convert_to_int8(self)\n",
+ "\n",
+ "\n",
+ "transformers.models.bloom.modeling_bloom.BloomBlock = BloomBlock"
+ ]
+ },
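+ {
+ "cell_type": "markdown",
+ "id": "f6a7b8c9-d0e1-4f2a-9b3c-5d6e7f8a9b0c",
+ "metadata": {},
+ "source": [
+ "The last line monkey-patches `transformers` so that any Bloom model built from here on constructs the int8 `BloomBlock`. A one-line check (illustrative) that the patch took effect:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a7b8c9d0-e1f2-4a3b-8c4d-6e7f8a9b0c1d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Any Bloom model instantiated after this point will build int8 blocks\n",
+ "assert transformers.models.bloom.modeling_bloom.BloomBlock is BloomBlock"
+ ]
+ },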
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "eca11b11-9b0b-4958-89f4-401f7a2cac0e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer = transformers.AutoTokenizer.from_pretrained('joaoalvarenga/bloom-8bit')\n",
+ "\n",
+ "# Use the patched int8 BloomForCausalLM defined above, not the stock transformers class\n",
+ "model = BloomForCausalLM.from_pretrained('joaoalvarenga/bloom-8bit', low_cpu_mem_usage=True)\n",
+ "\n",
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+ "model.to(device)"
+ ]
+ },
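+ {
+ "cell_type": "markdown",
+ "id": "b8c9d0e1-f2a3-4b4c-9d5e-7f8a9b0c1d2e",
+ "metadata": {},
+ "source": [
+ "Before fine-tuning, a quick generation smoke test (the prompt and generation settings here are arbitrary) confirms that the quantized model produces sensible text:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c9d0e1f2-a3b4-4c5d-8e6f-8a9b0c1d2e3f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "prompt = tokenizer('Machine learning is', return_tensors='pt').to(device)\n",
+ "\n",
+ "with torch.no_grad():\n",
+ "    out = model.generate(**prompt, max_new_tokens=20)\n",
+ "print(tokenizer.decode(out[0]))"
+ ]
+ },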
+ {
+ "cell_type": "markdown",
+ "id": "82ea942b-7fcf-4bbc-adb9-be0bbd98b9f8",
+ "metadata": {},
+ "source": [
+ "### Fine-tune and save model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "26cacf36-56f7-4f9c-b975-33dd34b1ff9c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def add_adapters(model, adapter_dim=16):\n",
+ "    assert adapter_dim > 0\n",
+ "\n",
+ "    for module in model.modules():\n",
+ "        if isinstance(module, FrozenBNBLinear):\n",
+ "            module.adapter = nn.Sequential(\n",
+ "                nn.Linear(module.in_features, adapter_dim, bias=False),\n",
+ "                nn.Linear(adapter_dim, module.out_features, bias=False),\n",
+ "            )\n",
+ "            nn.init.zeros_(module.adapter[1].weight)\n",
+ "        elif isinstance(module, FrozenBNBEmbedding):\n",
+ "            module.adapter = nn.Sequential(\n",
+ "                nn.Embedding(module.num_embeddings, adapter_dim),\n",
+ "                nn.Linear(adapter_dim, module.embedding_dim, bias=False),\n",
+ "            )\n",
+ "            nn.init.zeros_(module.adapter[1].weight)\n",
+ "\n",
+ "add_adapters(model)\n",
+ "model.to(device)"
+ ]
+ },
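+ {
+ "cell_type": "markdown",
+ "id": "d0e1f2a3-b4c5-4d6e-9f7a-9b0c1d2e3f4a",
+ "metadata": {},
+ "source": [
+ "The frozen 8-bit weights are registered as buffers, so `model.parameters()` now yields only float parameters: the newly added adapters plus anything that was never converted (e.g. LayerNorms). A rough count (illustrative) of what the optimizer will actually update:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e1f2a3b4-c5d6-4e7f-8a8b-0c1d2e3f4a5b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 8-bit weights live in buffers, so they don't show up in parameters()\n",
+ "trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
+ "print(f'float parameters to optimize: {trainable:,}')"
+ ]
+ },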
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4e293eb3-979a-46d7-97b8-cde296f45da8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from datasets import load_dataset\n",
+ "from bitsandbytes.optim import Adam8bit\n",
+ "\n",
+ "model.gradient_checkpointing_enable()\n",
+ "\n",
+ "wikisql = load_dataset(\"wikisql\", streaming=True)\n",
+ "optimizer = Adam8bit(model.parameters(), lr=1e-5)\n",
+ "\n",
+ "with torch.cuda.amp.autocast():\n",
+ "    for row in tqdm(wikisql['train']):\n",
+ "\n",
+ "        batch = tokenizer(row['question'] + row['sql']['human_readable'], truncation=True, max_length=128, return_tensors='pt')\n",
+ "        batch = {k: v.to(device) for k, v in batch.items()}\n",
+ "\n",
+ "        out = model(**batch)\n",
+ "\n",
+ "        # Shift logits and labels by one position for next-token prediction\n",
+ "        loss = F.cross_entropy(out.logits[:, :-1, :].flatten(0, -2), batch['input_ids'][:, 1:].flatten(),\n",
+ "                               reduction='mean')\n",
+ "        print(loss)\n",
+ "        loss.backward()\n",
+ "\n",
+ "        optimizer.step()\n",
+ "        optimizer.zero_grad()"
+ ]
+ },
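+ {
+ "cell_type": "markdown",
+ "id": "f2a3b4c5-d6e7-4f8a-9b9c-1d2e3f4a5b6c",
+ "metadata": {},
+ "source": [
+ "After (or during) training, we can spot-check the adapters on a WikiSQL-style prompt. A minimal sketch; the question is made up and greedy decoding is used for simplicity:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a3b4c5d6-e7f8-4a9b-8c0d-2e3f4a5b6c7d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "question = 'What is the name of the department with the most employees?'\n",
+ "inputs = tokenizer(question, return_tensors='pt').to(device)\n",
+ "\n",
+ "with torch.no_grad():\n",
+ "    out = model.generate(**inputs, max_new_tokens=32)\n",
+ "print(tokenizer.decode(out[0]))"
+ ]
+ },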
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4e2251f6-1a5c-4193-b971-0840d6d59c32",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model.save_pretrained('bloom-8bit-fine-tuned')"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }