marisming committed
Commit 729bc74 · verified · 1 Parent(s): 1df7ad4

Upload folder using huggingface_hub

Files changed (41)
  1. 04-gene-sft/.ipynb_checkpoints/1-finetue-intro-checkpoint.ipynb +138 -8
  2. 04-gene-sft/.ipynb_checkpoints/2-gpt2-instruction-ft-checkpoint.ipynb +117 -0
  3. 04-gene-sft/.ipynb_checkpoints/3-llama-expand-dict-checkpoint.ipynb +78 -11
  4. 04-gene-sft/.ipynb_checkpoints/4-deepspeed-intro-checkpoint.ipynb +10 -0
  5. 04-gene-sft/.ipynb_checkpoints/5-peft-intro-checkpoint.ipynb +870 -0
  6. 04-gene-sft/.ipynb_checkpoints/6-llama-continue-train-checkpoint.ipynb +491 -0
  7. 04-gene-sft/.ipynb_checkpoints/7-llama-instruction-ft-checkpoint.ipynb +624 -0
  8. 04-gene-sft/.ipynb_checkpoints/gene_bpe_seg-checkpoint.vocab +0 -0
  9. 04-gene-sft/.ipynb_checkpoints/llama_sft_test-checkpoint.ipynb +1627 -0
  10. 04-gene-sft/.ipynb_checkpoints/merge_pt_model-checkpoint.sh +6 -0
  11. 04-gene-sft/.ipynb_checkpoints/merge_sft_model-checkpoint.sh +6 -0
  12. 04-gene-sft/.ipynb_checkpoints/run_clm_pt_with_peft-checkpoint.py +637 -0
  13. 04-gene-sft/.ipynb_checkpoints/run_clm_sft_with_peft-checkpoint.py +449 -0
  14. 04-gene-sft/.ipynb_checkpoints/run_pt-checkpoint.sh +55 -0
  15. 04-gene-sft/.ipynb_checkpoints/run_sft-checkpoint.sh +59 -0
  16. 04-gene-sft/1-finetue-intro.ipynb +187 -8
  17. 04-gene-sft/2-gpt2-instruction-ft.ipynb +117 -0
  18. 04-gene-sft/3-llama-expand-dict.ipynb +78 -11
  19. 04-gene-sft/4-deepspeed-intro.ipynb +10 -0
  20. 04-gene-sft/5-peft-intro.ipynb +870 -0
  21. 04-gene-sft/6-llama-continue-train.ipynb +491 -0
  22. 04-gene-sft/7-llama-instruction-ft.ipynb +624 -0
  23. 04-gene-sft/gene_bpe_seg.model +3 -0
  24. 04-gene-sft/gene_bpe_seg.vocab +0 -0
  25. 04-gene-sft/img/.ipynb_checkpoints/sft-checkpoint.png +0 -0
  26. 04-gene-sft/img/.ipynb_checkpoints/sft2-checkpoint.png +0 -0
  27. 04-gene-sft/img/deepspeed.png +0 -0
  28. 04-gene-sft/llama_sft_test.ipynb +1627 -0
  29. 04-gene-sft/merge_llama_with_dna_lora.py +367 -0
  30. 04-gene-sft/merge_pt_model.sh +6 -0
  31. 04-gene-sft/merge_sft_model.sh +6 -0
  32. 04-gene-sft/merged_gene_eng_tokenizer_hf/special_tokens_map.json +23 -0
  33. 04-gene-sft/merged_gene_eng_tokenizer_hf/tokenizer.model +3 -0
  34. 04-gene-sft/merged_gene_eng_tokenizer_hf/tokenizer_config.json +43 -0
  35. 04-gene-sft/merged_gene_eng_tokenizer_sp/gene_eng_llama_tokenizer.model +3 -0
  36. 04-gene-sft/run_clm_pt_with_peft.py +10 -2
  37. 04-gene-sft/run_clm_sft_with_peft.py +12 -2
  38. 04-gene-sft/run_sft.sh +1 -2
  39. 04-gene-sft/train_data/dna_1g.txt +3 -0
  40. 04-gene-sft/train_data/english_500m.txt +3 -0
  41. 04-gene-sft/train_data/protein_1g.txt +3 -0
04-gene-sft/.ipynb_checkpoints/1-finetue-intro-checkpoint.ipynb CHANGED
@@ -31,6 +31,12 @@
31
  "\"yuanzhoulvpi/gpt2_chinese\", num_labels=2\n",
32
  ")\n",
33
  "\n",
34
  "\n",
35
  "\n",
36
  "2 如果是把分类问题,改成指令微调的模式,就是像\n",
@@ -174,7 +180,7 @@
174
  },
175
  {
176
  "cell_type": "code",
177
- "execution_count": null,
178
  "id": "64312191-423f-4a18-aa0c-036374e93fb2",
179
  "metadata": {},
180
  "outputs": [],
@@ -192,10 +198,44 @@
192
  },
193
  {
194
  "cell_type": "code",
195
- "execution_count": null,
196
  "id": "32c16282-f9f1-4545-b522-daf2b39b4ead",
197
  "metadata": {},
198
- "outputs": [],
199
  "source": [
200
  "#原始模型\n",
201
  "from transformers import AutoModel\n",
@@ -205,10 +245,55 @@
205
  },
206
  {
207
  "cell_type": "code",
208
- "execution_count": null,
209
  "id": "1149163f-4d89-472e-8d45-ebcbb5f9575e",
210
  "metadata": {},
211
- "outputs": [],
212
  "source": [
213
  "#分类微调模型\n",
214
  "from transformers import AutoModelForSequenceClassification\n",
@@ -218,16 +303,61 @@
218
  },
219
  {
220
  "cell_type": "code",
221
- "execution_count": 1,
222
  "id": "09735059-507c-48c4-893f-ca0da21ce5e8",
223
  "metadata": {},
224
- "outputs": [],
225
  "source": [
226
  "#指令微调模型\n",
227
  "from transformers import AutoModelForCausalLM\n",
228
- "sft_model = AutoModelForMaskedLM.from_pretrained(\"gpt2\")\n",
229
  "sft_model"
230
  ]
231
  }
232
  ],
233
  "metadata": {
 
31
  "\"yuanzhoulvpi/gpt2_chinese\", num_labels=2\n",
32
  ")\n",
33
  "\n",
34
+ "对应的训练数据一般是这样的:\n",
35
+ "\n",
36
+ "| seq | label |\n",
37
+ "|------------------------------|-------|\n",
38
+ "| 他家的奶茶超级好喝。。。 | 1 |\n",
39
+ "| 他家的奶茶超级难喝。。。 | 0 |\n",
40
  "\n",
41
  "\n",
42
  "2 如果是把分类问题,改成指令微调的模式,就是像\n",
 
180
  },
181
  {
182
  "cell_type": "code",
183
+ "execution_count": 1,
184
  "id": "64312191-423f-4a18-aa0c-036374e93fb2",
185
  "metadata": {},
186
  "outputs": [],
 
198
  },
199
  {
200
  "cell_type": "code",
201
+ "execution_count": 2,
202
  "id": "32c16282-f9f1-4545-b522-daf2b39b4ead",
203
  "metadata": {},
204
+ "outputs": [
205
+ {
206
+ "data": {
207
+ "text/plain": [
208
+ "GPT2Model(\n",
209
+ " (wte): Embedding(50257, 768)\n",
210
+ " (wpe): Embedding(1024, 768)\n",
211
+ " (drop): Dropout(p=0.1, inplace=False)\n",
212
+ " (h): ModuleList(\n",
213
+ " (0-11): 12 x GPT2Block(\n",
214
+ " (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
215
+ " (attn): GPT2SdpaAttention(\n",
216
+ " (c_attn): Conv1D(nf=2304, nx=768)\n",
217
+ " (c_proj): Conv1D(nf=768, nx=768)\n",
218
+ " (attn_dropout): Dropout(p=0.1, inplace=False)\n",
219
+ " (resid_dropout): Dropout(p=0.1, inplace=False)\n",
220
+ " )\n",
221
+ " (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
222
+ " (mlp): GPT2MLP(\n",
223
+ " (c_fc): Conv1D(nf=3072, nx=768)\n",
224
+ " (c_proj): Conv1D(nf=768, nx=3072)\n",
225
+ " (act): NewGELUActivation()\n",
226
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
227
+ " )\n",
228
+ " )\n",
229
+ " )\n",
230
+ " (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
231
+ ")"
232
+ ]
233
+ },
234
+ "execution_count": 2,
235
+ "metadata": {},
236
+ "output_type": "execute_result"
237
+ }
238
+ ],
239
  "source": [
240
  "#原始模型\n",
241
  "from transformers import AutoModel\n",
 
245
  },
246
  {
247
  "cell_type": "code",
248
+ "execution_count": 3,
249
  "id": "1149163f-4d89-472e-8d45-ebcbb5f9575e",
250
  "metadata": {},
251
+ "outputs": [
252
+ {
253
+ "name": "stderr",
254
+ "output_type": "stream",
255
+ "text": [
256
+ "Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']\n",
257
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
258
+ ]
259
+ },
260
+ {
261
+ "data": {
262
+ "text/plain": [
263
+ "GPT2ForSequenceClassification(\n",
264
+ " (transformer): GPT2Model(\n",
265
+ " (wte): Embedding(50257, 768)\n",
266
+ " (wpe): Embedding(1024, 768)\n",
267
+ " (drop): Dropout(p=0.1, inplace=False)\n",
268
+ " (h): ModuleList(\n",
269
+ " (0-11): 12 x GPT2Block(\n",
270
+ " (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
271
+ " (attn): GPT2SdpaAttention(\n",
272
+ " (c_attn): Conv1D(nf=2304, nx=768)\n",
273
+ " (c_proj): Conv1D(nf=768, nx=768)\n",
274
+ " (attn_dropout): Dropout(p=0.1, inplace=False)\n",
275
+ " (resid_dropout): Dropout(p=0.1, inplace=False)\n",
276
+ " )\n",
277
+ " (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
278
+ " (mlp): GPT2MLP(\n",
279
+ " (c_fc): Conv1D(nf=3072, nx=768)\n",
280
+ " (c_proj): Conv1D(nf=768, nx=3072)\n",
281
+ " (act): NewGELUActivation()\n",
282
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
283
+ " )\n",
284
+ " )\n",
285
+ " )\n",
286
+ " (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
287
+ " )\n",
288
+ " (score): Linear(in_features=768, out_features=2, bias=False)\n",
289
+ ")"
290
+ ]
291
+ },
292
+ "execution_count": 3,
293
+ "metadata": {},
294
+ "output_type": "execute_result"
295
+ }
296
+ ],
297
  "source": [
298
  "#分类微调模型\n",
299
  "from transformers import AutoModelForSequenceClassification\n",
 
303
  },
304
  {
305
  "cell_type": "code",
306
+ "execution_count": 5,
307
  "id": "09735059-507c-48c4-893f-ca0da21ce5e8",
308
  "metadata": {},
309
+ "outputs": [
310
+ {
311
+ "data": {
312
+ "text/plain": [
313
+ "GPT2LMHeadModel(\n",
314
+ " (transformer): GPT2Model(\n",
315
+ " (wte): Embedding(50257, 768)\n",
316
+ " (wpe): Embedding(1024, 768)\n",
317
+ " (drop): Dropout(p=0.1, inplace=False)\n",
318
+ " (h): ModuleList(\n",
319
+ " (0-11): 12 x GPT2Block(\n",
320
+ " (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
321
+ " (attn): GPT2SdpaAttention(\n",
322
+ " (c_attn): Conv1D(nf=2304, nx=768)\n",
323
+ " (c_proj): Conv1D(nf=768, nx=768)\n",
324
+ " (attn_dropout): Dropout(p=0.1, inplace=False)\n",
325
+ " (resid_dropout): Dropout(p=0.1, inplace=False)\n",
326
+ " )\n",
327
+ " (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
328
+ " (mlp): GPT2MLP(\n",
329
+ " (c_fc): Conv1D(nf=3072, nx=768)\n",
330
+ " (c_proj): Conv1D(nf=768, nx=3072)\n",
331
+ " (act): NewGELUActivation()\n",
332
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
333
+ " )\n",
334
+ " )\n",
335
+ " )\n",
336
+ " (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
337
+ " )\n",
338
+ " (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
339
+ ")"
340
+ ]
341
+ },
342
+ "execution_count": 5,
343
+ "metadata": {},
344
+ "output_type": "execute_result"
345
+ }
346
+ ],
347
  "source": [
348
  "#指令微调模型\n",
349
  "from transformers import AutoModelForCausalLM\n",
350
+ "sft_model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n",
351
  "sft_model"
352
  ]
353
+ },
354
+ {
355
+ "cell_type": "code",
356
+ "execution_count": null,
357
+ "id": "d1407cbe-4996-4898-a135-e26d28da2a2a",
358
+ "metadata": {},
359
+ "outputs": [],
360
+ "source": []
361
  }
362
  ],
363
  "metadata": {
04-gene-sft/.ipynb_checkpoints/2-gpt2-instruction-ft-checkpoint.ipynb CHANGED
@@ -8,6 +8,123 @@
8
  "# 4.2 基于GPT2的指令微调"
9
  ]
10
  },
11
  {
12
  "cell_type": "code",
13
  "execution_count": null,
 
8
  "# 4.2 基于GPT2的指令微调"
9
  ]
10
  },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "02cd6e13-bbfb-413a-8236-ff092456fd1c",
14
+ "metadata": {},
15
+ "source": [
16
+ "我还是用第二章中的分类的例子,使用指令微调的形式,来再次解决分类问题。\n",
17
+ "\n",
18
+ "使用 GPT-2 进行文本分类的两种方法:**使用 GPT-2 的分类头(Classification Header)** 和 **将分类任务转换为指令微调**,在思路、实现、优劣势和适用场景上存在明显差异。以下是详细对比:\n",
19
+ "\n",
20
+ "---\n",
21
+ "\n",
22
+ "### **1. 核心思路**\n",
23
+ "\n",
24
+ "| **方法** | **使用 GPT-2 分类头** | **转换为指令微调** |\n",
25
+ "|-----------------------------|-------------------------------------------------------------|-------------------------------------------------------|\n",
26
+ "| **基本概念** | 在 GPT-2 顶部添加一个分类头(通常是一个线性层),直接预测分类标签。 | 将分类任务转化为自然语言指令,模型通过微调理解并完成指令形式的任务。 |\n",
27
+ "| **实现方式** | 修改 GPT-2 模型,添加 `num_labels` 分类头并定义分类损失函数。 | 构建任务指令数据(Instruction + Input + Output),然后微调模型。 |\n",
28
+ "| **数据形式** | 文本与其分类标签的直接映射。 | 文本通过指令转化为生成任务。例如:<br>`Input`: 文章内容<br>`Output`: 分类结果。 |\n",
29
+ "\n",
30
+ "---\n",
31
+ "\n",
32
+ "### **2. 数据格式**\n",
33
+ "\n",
34
+ "| **方法** | **使用 GPT-2 分类头** | **转换为指令微调** |\n",
35
+ "|-----------------------------|-------------------------------------------------------------|-------------------------------------------------------|\n",
36
+ "| **数据格式** | - 输入:文本 <br>- 标签:离散类别标签(如 0, 1, 2)。 | - 指令:自然语言描述任务(如 \"请分类以下文本\")。<br>- 输入:分类文本。<br>- 输出:分类结果(文本形式)。 |\n",
37
+ "| **示例** | 输入:`\"This is a happy day!\"`<br>标签:`1`(表示积极) | `Instruction`: \"请对以下文本进行情感分类\"<br>`Input`: `\"This is a happy day!\"`<br>`Output`: `\"积极\"` |\n",
38
+ "\n",
39
+ "---\n",
40
+ "\n",
41
+ "### **3. 模型结构**\n",
42
+ "\n",
43
+ "| **方法** | **使用 GPT-2 分类头** | **转换为指令微调** |\n",
44
+ "|-----------------------------|-------------------------------------------------------------|-------------------------------------------------------|\n",
45
+ "| **模型结构** | - GPT-2 + 分类头(线性层)。 | - GPT-2 原始结构,无需额外的分类头。 |\n",
46
+ "| **损失函数** | - 使用交叉熵损失(Cross Entropy Loss)。 | - 使用自回归的语言建模损失(Language Modeling Loss)。 |\n",
47
+ "\n",
48
+ "---\n",
49
+ "\n",
50
+ "### **4. 训练过程**\n",
51
+ "\n",
52
+ "| **方法** | **使用 GPT-2 分类头** | **转换为指令微调** |\n",
53
+ "|-----------------------------|-------------------------------------------------------------|-------------------------------------------------------|\n",
54
+ "| **微调对象** | 主要微调分类头部分的参数(可选择冻结 GPT-2 的主干部分)。 | 微调整个 GPT-2 模型(或使用参数高效微调如 LoRA)。 |\n",
55
+ "| **标签处理** | 离散化标签(如 0, 1, 2)。 | 标签转化为自然语言(如“积极”、“中立”、“消极”)。 |\n",
56
+ "| **训练难度** | - 简单,标准分类任务流程。<br>- 数据需求较小,适合小规模微调。 | - 复杂,需要构造高质量的指令数据集。<br>- 数据需求较大,适合多任务场景。 |\n",
57
+ "\n",
58
+ "---\n",
59
+ "\n",
60
+ "### **5. 优缺点分析**\n",
61
+ "\n",
62
+ "| **方法** | **使用 GPT-2 分类头** | **转换为指令微调** |\n",
63
+ "|-----------------------------|-------------------------------------------------------------|-------------------------------------------------------|\n",
64
+ "| **优点** | - 训练速度快,计算资源需求较低。<br>- 实现简单,适合单一任务。 | - 泛化能力��,支持多任务扩展。<br>- 与多任务微调和开放式生成兼容。 |\n",
65
+ "| **缺点** | - 只能处理分类任务,难以扩展为其他任务。<br>- 需要人工调整分类头和损失函数。 | - 数据构造复杂且对数据质量依赖较高。<br>- 训练资源需求较大,训练时间较长。 |\n",
66
+ "\n",
67
+ "---\n",
68
+ "\n",
69
+ "### **6. 适用场景**\n",
70
+ "\n",
71
+ "| **方法** | **使用 GPT-2 分类头** | **转换为指令微调** |\n",
72
+ "|-----------------------------|-------------------------------------------------------------|-------------------------------------------------------|\n",
73
+ "| **适用场景** | - 单任务文本分类,如情感分析、垃圾邮件检测等。 | - 多任务场景,支持分类、翻译、摘要等任务的统一处理。 |\n",
74
+ "| **数据规模** | 适合小数据集,数千到数万条数据即可训练效果良好。 | 适合大数据集,特别是多任务、多领域的数据集。 |\n",
75
+ "| **需求类型** | 专注于提高单一任务的分类准确率。 | 需要增强模型的多任务泛化能力,同时提升用户交互体验。 |\n",
76
+ "\n",
77
+ "---\n",
78
+ "\n",
79
+ "### **7. 综合对比总结**\n",
80
+ "\n",
81
+ "| **维度** | **使用 GPT-2 分类头** | **转换为指令微调** |\n",
82
+ "|-------------------------|--------------------------------------------------------------|-------------------------------------------------------|\n",
83
+ "| **实现复杂度** | 较低,直接添加分类头并使用标准分类流程即可完成。 | 较高,需要构造高质量指令数据,并调整训练流程。 |\n",
84
+ "| **资源需求** | 较低,仅需调整分类头部分,训练时间和显存消耗较少。 | 较高,需要微调整个模型,且对数据和算力需求更大。 |\n",
85
+ "| **性能表现** | 对单一分类任务效果较好,但泛化能力较弱。 | 在多任务、多样化分类场景中表现更强,且可扩展为其他任务类型。 |\n",
86
+ "| **扩展性** | 较差,仅适用于当前任务,难以迁移到其他任务。 | 较强,可适应多任务指令和开放式生成场景。 |\n",
87
+ "\n",
88
+ "---\n",
89
+ "\n",
90
+ "### **选择建议**\n",
91
+ "\n",
92
+ "1. **使用 GPT-2 分类头**:\n",
93
+ " - 如果任务是单一分类问题(如情感分析、垃圾邮件检测),并且数据量有限,推荐使用分类头方法。\n",
94
+ " - 适合快速实现和部署,无需复杂的预处理和指令数据集构建。\n",
95
+ "\n",
96
+ "2. **转换为指令微调**:\n",
97
+ " - 如果任务需要多样化(分类+生成+翻译等),或需要对未见任务有更好的泛化能力,推荐使用指令微调。\n",
98
+ " - 适合多任务、多场景部署,尤其是在 ChatGPT 风格的应用中更为适用。\n",
99
+ "\n",
100
+ "通过综合任务需求、数据规模和资源条件选择合适的方法,能够有效提升模型性能并实现更广泛的适用性。\n",
101
+ "\n",
102
+ "\n",
103
+ "原始的数据格式如下:\n",
104
+ "| sequence | label | label_name |\n",
105
+ "|--------------------------------------------------------|-------|----------------|\n",
106
+ "| TATATTTTCTCAGCTGAGTTAATTAGTTTCACTAGTTAACTGAGAATAAAAGAA | 1 | promoter |\n",
107
+ "| TGGGGAGGGTCCGGTGTTAGTTAGATACATCCCCAGACCCACACCCCGGATAGA | 0 | Non-promoter |\n",
108
+ "\n",
109
+ "转成指令的格式为:\n",
110
+ "```\n",
111
+ "{'instruction': 'Determine core promoter detection of following dna sequence, The result will be one of the following: Non-promoter, promoter.', \n",
112
+ "'input': 'CATGCGGGTCG...', \n",
113
+ "'output': 'Non-promoter'}\n",
114
+ "```\n",
115
+ "\n",
116
+ "然后写成指令微调数据格式,当做一般的文本进行训练:\n",
117
+ "```\n",
118
+ "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
119
+ "### Instruction:\n",
120
+ "Determine core promoter detection of following dna sequence, The result will be one of the following: Non-promoter, promoter.\n",
121
+ "### Input:\n",
122
+ "TCTTTCTCTTCTGTATCATTCTACTT...\n",
123
+ "### Response:\n",
124
+ "Non-promoter\n",
125
+ "```\n"
126
+ ]
127
+ },
128
  {
129
  "cell_type": "code",
130
  "execution_count": null,
04-gene-sft/.ipynb_checkpoints/3-llama-expand-dict-checkpoint.ipynb CHANGED
@@ -114,10 +114,18 @@
114
  },
115
  {
116
  "cell_type": "code",
117
- "execution_count": null,
118
  "id": "19a06b82-31b8-48cb-9c83-ec016da2da8a",
119
  "metadata": {},
120
- "outputs": [],
121
  "source": [
122
  "from sentencepiece import SentencePieceProcessor\n",
123
  "model_path = \"gene_bpe_seg.model\"\n",
@@ -147,7 +155,7 @@
147
  },
148
  {
149
  "cell_type": "code",
150
- "execution_count": null,
151
  "id": "3bafcc33-2923-4026-bc39-c6ec716d2e3c",
152
  "metadata": {},
153
  "outputs": [],
@@ -161,10 +169,28 @@
161
  },
162
  {
163
  "cell_type": "code",
164
- "execution_count": null,
165
  "id": "66cb86ed-3225-4bb0-8aca-6005bc918d03",
166
  "metadata": {},
167
- "outputs": [],
168
  "source": [
169
  "llama_tokenizer_dir = \"llama-7b-hf\" \n",
170
  "dna_sp_model_file = \"gene_bpe_seg.model\"\n",
@@ -188,10 +214,20 @@
188
  },
189
  {
190
  "cell_type": "code",
191
- "execution_count": null,
192
  "id": "7ba4240e-bc08-4be0-8ca3-c4e7a47fa055",
193
  "metadata": {},
194
- "outputs": [],
195
  "source": [
196
  "## Add dna tokens to LLaMA tokenizer\n",
197
  "llama_spm_tokens_set=set(p.piece for p in llama_spm.pieces)\n",
@@ -210,10 +246,18 @@
210
  },
211
  {
212
  "cell_type": "code",
213
- "execution_count": null,
214
  "id": "a240a7d8-c1a9-4473-a5c5-157a25f97c16",
215
  "metadata": {},
216
- "outputs": [],
217
  "source": [
218
  "## Save\n",
219
  "output_sp_dir = 'merged_gene_eng_tokenizer_sp'\n",
@@ -229,10 +273,25 @@
229
  },
230
  {
231
  "cell_type": "code",
232
- "execution_count": null,
233
  "id": "cbd1f648-f8a0-4f16-b516-2ce3e7c7cfee",
234
  "metadata": {},
235
- "outputs": [],
236
  "source": [
237
  "# Test\n",
238
  "llama_tokenizer = LlamaTokenizer.from_pretrained(llama_tokenizer_dir)\n",
@@ -246,6 +305,14 @@
246
  "print(f\"Tokenized by LLaMA tokenizer:{llama_tokenizer.tokenize(text)}\")\n",
247
  "print(f\"Tokenized by GENE-LLaMA tokenizer:{dna_llama_tokenizer.tokenize(text)}\")"
248
  ]
249
  }
250
  ],
251
  "metadata": {
 
114
  },
115
  {
116
  "cell_type": "code",
117
+ "execution_count": 1,
118
  "id": "19a06b82-31b8-48cb-9c83-ec016da2da8a",
119
  "metadata": {},
120
+ "outputs": [
121
+ {
122
+ "name": "stdout",
123
+ "output_type": "stream",
124
+ "text": [
125
+ "['▁TCG', 'ACGGC', 'ACGCG', 'ACAGC', 'AGCG', 'AGCCCC', 'GCGC', 'ACCCG', 'AGCGCG', 'AKCG', 'FVGP', 'MV', 'HLKV', 'HLE', 'ADV', 'ASSC', 'RS', 'AVI', 'YL', 'TS', 'EEP', 'FEG', 'VLGL', 'RLKE', 'GI', 'AI', 'TGC', 'WPR', 'WP', 'DEM', 'DE', 'RS', 'AVW', 'RV', 'EPY', 'TR', 'HFG', 'RVL', 'YS', 'FGV']\n"
126
+ ]
127
+ }
128
+ ],
129
  "source": [
130
  "from sentencepiece import SentencePieceProcessor\n",
131
  "model_path = \"gene_bpe_seg.model\"\n",
 
155
  },
156
  {
157
  "cell_type": "code",
158
+ "execution_count": 2,
159
  "id": "3bafcc33-2923-4026-bc39-c6ec716d2e3c",
160
  "metadata": {},
161
  "outputs": [],
 
169
  },
170
  {
171
  "cell_type": "code",
172
+ "execution_count": 3,
173
  "id": "66cb86ed-3225-4bb0-8aca-6005bc918d03",
174
  "metadata": {},
175
+ "outputs": [
176
+ {
177
+ "name": "stderr",
178
+ "output_type": "stream",
179
+ "text": [
180
+ "You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message\n"
181
+ ]
182
+ },
183
+ {
184
+ "name": "stdout",
185
+ "output_type": "stream",
186
+ "text": [
187
+ "32000 60000\n",
188
+ "['<s>', '</s>', '<unk>']\n",
189
+ "[1, 2, 0]\n",
190
+ "{'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}\n"
191
+ ]
192
+ }
193
+ ],
194
  "source": [
195
  "llama_tokenizer_dir = \"llama-7b-hf\" \n",
196
  "dna_sp_model_file = \"gene_bpe_seg.model\"\n",
 
214
  },
215
  {
216
  "cell_type": "code",
217
+ "execution_count": 4,
218
  "id": "7ba4240e-bc08-4be0-8ca3-c4e7a47fa055",
219
  "metadata": {},
220
+ "outputs": [
221
+ {
222
+ "name": "stdout",
223
+ "output_type": "stream",
224
+ "text": [
225
+ "32000\n",
226
+ "Before:32000\n",
227
+ "New model pieces: 91643\n"
228
+ ]
229
+ }
230
+ ],
231
  "source": [
232
  "## Add dna tokens to LLaMA tokenizer\n",
233
  "llama_spm_tokens_set=set(p.piece for p in llama_spm.pieces)\n",
 
246
  },
247
  {
248
  "cell_type": "code",
249
+ "execution_count": 5,
250
  "id": "a240a7d8-c1a9-4473-a5c5-157a25f97c16",
251
  "metadata": {},
252
+ "outputs": [
253
+ {
254
+ "name": "stdout",
255
+ "output_type": "stream",
256
+ "text": [
257
+ "gene-LLaMA tokenizer has been saved to merged_gene_eng_tokenizer_hf\n"
258
+ ]
259
+ }
260
+ ],
261
  "source": [
262
  "## Save\n",
263
  "output_sp_dir = 'merged_gene_eng_tokenizer_sp'\n",
 
273
  },
274
  {
275
  "cell_type": "code",
276
+ "execution_count": 6,
277
  "id": "cbd1f648-f8a0-4f16-b516-2ce3e7c7cfee",
278
  "metadata": {},
279
+ "outputs": [
280
+ {
281
+ "name": "stdout",
282
+ "output_type": "stream",
283
+ "text": [
284
+ "['<s>', '</s>', '<unk>']\n",
285
+ "[1, 2, 0]\n",
286
+ "{'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}\n",
287
+ "Test text:\n",
288
+ " TCGACGGCACGCGACAGCAGCGAGCCCCGCGCACCCGAGCGCGAKCGFVGPMVHLKVHLEADVASSCRSAVIYLTSEEPFEGVLGLRLKEGIAITGCWPRWPDEMDERSAVWRVEPYTRHFGRVLYSFGV,\n",
289
+ "The primary use of LLaMA is research on large language models, including\n",
290
+ "Tokenized by LLaMA tokenizer:['▁T', 'CG', 'AC', 'G', 'GC', 'AC', 'GC', 'G', 'AC', 'AG', 'CA', 'GC', 'G', 'AG', 'CC', 'CC', 'GC', 'GC', 'AC', 'CC', 'GA', 'GC', 'GC', 'GA', 'K', 'CG', 'F', 'V', 'G', 'PM', 'V', 'HL', 'K', 'V', 'H', 'LE', 'AD', 'VA', 'SS', 'CR', 'S', 'AV', 'I', 'Y', 'LT', 'SEE', 'PF', 'EG', 'V', 'L', 'GL', 'RL', 'KE', 'G', 'IA', 'IT', 'GC', 'W', 'PR', 'WP', 'DE', 'MD', 'ERS', 'AV', 'WR', 'VE', 'PY', 'TR', 'H', 'F', 'GR', 'V', 'LY', 'SF', 'GV', ',', '<0x0A>', 'The', '▁primary', '▁use', '▁of', '▁L', 'La', 'MA', '▁is', '▁research', '▁on', '▁large', '▁language', '▁models', ',', '▁including']\n",
291
+ "Tokenized by GENE-LLaMA tokenizer:['▁TCG', 'ACGGC', 'ACGCG', 'ACAG', 'CA', 'GCG', 'AGCCCC', 'GCGC', 'ACCCG', 'AGCGCG', 'AKCG', 'FVGP', 'MVHL', 'KV', 'HLE', 'ADV', 'ASSC', 'RSAV', 'I', 'YL', 'TSEE', 'P', 'FEG', 'VLGL', 'RLK', 'EGI', 'AI', 'TGC', 'W', 'PRW', 'P', 'DEM', 'DER', 'SAV', 'W', 'RVE', 'PY', 'TRH', 'FG', 'RVLY', 'SFGV', ',', '<0x0A>', 'The', '▁primary', '▁use', '▁of', '▁L', 'La', 'MA', '▁is', '▁research', '▁on', '▁large', '▁language', '▁models', ',', '▁including']\n"
292
+ ]
293
+ }
294
+ ],
295
  "source": [
296
  "# Test\n",
297
  "llama_tokenizer = LlamaTokenizer.from_pretrained(llama_tokenizer_dir)\n",
 
305
  "print(f\"Tokenized by LLaMA tokenizer:{llama_tokenizer.tokenize(text)}\")\n",
306
  "print(f\"Tokenized by GENE-LLaMA tokenizer:{dna_llama_tokenizer.tokenize(text)}\")"
307
  ]
308
+ },
309
+ {
310
+ "cell_type": "code",
311
+ "execution_count": null,
312
+ "id": "46ae7605-2ef8-4927-bff3-2c0325f8df0d",
313
+ "metadata": {},
314
+ "outputs": [],
315
+ "source": []
316
  }
317
  ],
318
  "metadata": {
04-gene-sft/.ipynb_checkpoints/4-deepspeed-intro-checkpoint.ipynb CHANGED
@@ -56,6 +56,8 @@
56
  "\n",
57
  "每个阶段都进一步减少显存需求,Stage 3 可支持超大规模模型(如 GPT-3)。\n",
58
  "\n",
59
  "#### **(2)混合精度训练**\n",
60
  "通过 FP16 或 BF16(半精度浮点数)计算,显著减少显存占用并提升计算效率。\n",
61
  "\n",
@@ -567,6 +569,14 @@
567
  "metadata": {},
568
  "outputs": [],
569
  "source": []
570
  }
571
  ],
572
  "metadata": {
 
56
  "\n",
57
  "每个阶段都进一步减少显存需求,Stage 3 可支持超大规模模型(如 GPT-3)。\n",
58
  "\n",
59
+ "<img src='img/deepspeed.png' width='600px' />\n",
60
+ "\n",
61
  "#### **(2)混合精度训练**\n",
62
  "通过 FP16 或 BF16(半精度浮点数)计算,显著减少显存占用并提升计算效率。\n",
63
  "\n",
 
569
  "metadata": {},
570
  "outputs": [],
571
  "source": []
572
+ },
573
+ {
574
+ "cell_type": "code",
575
+ "execution_count": null,
576
+ "id": "ce701aeb-c8c7-450a-bbf9-b793a19cd0c6",
577
+ "metadata": {},
578
+ "outputs": [],
579
+ "source": []
580
  }
581
  ],
582
  "metadata": {
04-gene-sft/.ipynb_checkpoints/5-peft-intro-checkpoint.ipynb ADDED
@@ -0,0 +1,870 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "963e9ae0-ac68-44be-8c7d-fb9842784362",
6
+ "metadata": {},
7
+ "source": [
8
+ "# 4.6 基于llama的基因大模型指令微调"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "182b82c4-d484-4c15-a600-03c3b51367ec",
14
+ "metadata": {},
15
+ "source": [
16
+ "**PEFT**(Parameter-Efficient Fine-Tuning,参数高效微调)是一种优化技术,旨在以最小的参数更新实现对大规模预训练模型(如 GPT、BERT 等)的微调。PEFT 技术通过减少微调所需的参数量,显著降低了存储和计算开销,同时保留模型的性能,特别适合资源受限的场景和领域特定任务的定制化。\n",
17
+ "\n",
18
+ "---\n",
19
+ "\n",
20
+ "### **1. 核心思想**\n",
21
+ "传统的微调方式需要更新整个预训练模型的所有参数,PEFT 技术通过只调整少量的参数(如特定层或额外添加的小型模块)实现微调目标,大幅减少了训练开销和存储需求。\n",
22
+ "\n",
23
+ "---\n",
24
+ "\n",
25
+ "### **2. 常见的 PEFT 方法**\n",
26
+ "\n",
27
+ "#### **(1)Adapter 模型**\n",
28
+ "- 在每一层 Transformer 的输出中插入小型适配器模块,仅训练适配器模块的参数。\n",
29
+ "- 原始模型参数保持冻结不变。\n",
30
+ "- 优点:适配器模块参数量小,能适应不同任务。\n",
31
+ "\n",
32
+ "示例方法:\n",
33
+ "- **AdapterFusion**\n",
34
+ "- **MAD-X**\n",
35
+ "\n",
36
+ "---\n",
37
+ "\n",
38
+ "#### **(2)Prefix Tuning**\n",
39
+ "- 在 Transformer 的输入前添加一组可学习的前缀向量,这些前缀与模型的注意力机制交互。\n",
40
+ "- 只调整前缀向量的参数,而不更新原始模型。\n",
41
+ "- 优点:对生成任务效果显著,参数量进一步减少。\n",
42
+ "\n",
43
+ "---\n",
44
+ "\n",
45
+ "#### **(3)LoRA(Low-Rank Adaptation)**\n",
46
+ "- 将预训练模型中的部分权重分解为两个低秩矩阵,仅调整这些低秩矩阵的参数。\n",
47
+ "- 原始权重保持冻结状态。\n",
48
+ "- 优点:参数量极小,计算高效。\n",
49
+ " \n",
50
+ "---\n",
51
+ "\n",
52
+ "#### **(4)Prompt Tuning**\n",
53
+ "- 在输入文本中添加可学习的提示(Prompt)。\n",
54
+ "- 适合 NLP 任务中的文本生成、分类等。\n",
55
+ "- 优点:实现简单,易于集成到现有框架。\n",
56
+ "\n",
57
+ "---\n",
58
+ "\n",
59
+ "### **3. PEFT 的优势**\n",
60
+ "\n",
61
+ "1. **显著减少参数更新量**:\n",
62
+ " - 微调传统的大模型(如 GPT-3)需要更新数百亿参数,而 PEFT 仅需更新百万级别甚至更少的参数。\n",
63
+ "\n",
64
+ "2. **高效存储**:\n",
65
+ " - 每个任务的微调结果只需存储少量额外参数,而不是整个模型。\n",
66
+ "\n",
67
+ "3. **适用多任务**:\n",
68
+ " - 同一预训练模型可以通过不同的 PEFT 模块适配多个任务,无需重新训练。\n",
69
+ "\n",
70
+ "4. **降低计算开销**:\n",
71
+ " - 训练所需的内存和计算显著减少,适合资源有限的环境。\n",
72
+ "\n",
73
+ "---\n",
74
+ "\n",
75
+ "### **4. 应用场景**\n",
76
+ "\n",
77
+ "1. **领域特定任务**:\n",
78
+ " - 医疗、法律、金融等领域微调预训练模型。\n",
79
+ "\n",
80
+ "2. **多任务学习**:\n",
81
+ " - 适配多个任务,复用同一模型的预训练权重。\n",
82
+ "\n",
83
+ "3. **资源受限场景**:\n",
84
+ " - 移动设备、边缘设备上的模型部署。\n",
85
+ "\n",
86
+ "---\n",
87
+ "\n",
88
+ "### **5. Hugging Face PEFT 库**\n",
89
+ "\n",
90
+ "Hugging Face 提供了专门的 PEFT 库,支持多种参数高效微调技术:\n",
91
+ "- **安装**:\n",
92
+ " ```bash\n",
93
+ " pip install peft\n",
94
+ " ```\n",
95
+ "- **使用 LoRA 微调示例**:\n",
96
+ " ```python\n",
97
+ " from transformers import AutoModelForCausalLM, AutoTokenizer\n",
98
+ " from peft import LoraConfig, get_peft_model, TaskType\n",
99
+ "\n",
100
+ " # 加载模型和分词器\n",
101
+ " model_name = \"gpt2\"\n",
102
+ " model = AutoModelForCausalLM.from_pretrained(model_name)\n",
103
+ " tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
104
+ "\n",
105
+ " # 配置 LoRA\n",
106
+ " lora_config = LoraConfig(\n",
107
+ " task_type=TaskType.CAUSAL_LM,\n",
108
+ " r=8,\n",
109
+ " lora_alpha=32,\n",
110
+ " target_modules=[\"q_proj\", \"v_proj\"],\n",
111
+ " lora_dropout=0.1,\n",
112
+ " bias=\"none\"\n",
113
+ " )\n",
114
+ "\n",
115
+ " # 使用 LoRA 微调模型\n",
116
+ " model = get_peft_model(model, lora_config)\n",
117
+ " model.print_trainable_parameters()\n",
118
+ "\n",
119
+ " # 微调代码...\n",
120
+ " ```\n",
121
+ "\n",
122
+ "---\n",
123
+ "\n",
124
+ "### **6. PEFT 的局限性**\n",
125
+ "1. **特定任务限制**:\n",
126
+ " - 在一些复杂任务中,PEFT 方法可能不如全量微调效果好。\n",
127
+ "\n",
128
+ "2. **需要设计合适的模块**:\n",
129
+ " - 不同任务需要选择和设计合��的 PEFT 技术。\n",
130
+ "\n",
131
+ "3. **与模型架构相关**:\n",
132
+ " - PEFT 技术可能需要对模型架构进行一定程度的修改。\n",
133
+ "\n",
134
+ "---\n",
135
+ "\n",
136
+ "### **7. 小结**\n",
137
+ "PEFT 是一个极具潜力的技术,特别适合在有限资源下对大模型进行微调。它在许多领域和任务中已显示出良好的效果,例如 LoRA 和 Adapter 模型已经成为高效微调的主流方法。\n",
138
+ "\n",
139
+ "如果您需要实现高效微调,可以结合 Hugging Face 的 PEFT 库快速上手。"
140
+ ]
141
+ },
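The low-rank decomposition LoRA relies on (described in section 2 above) can be checked numerically; a small sketch using GPT-2's `c_attn` shape (768 → 2304), with random values purely for illustration:

```python
import numpy as np

d_in, d_out, r, alpha = 768, 2304, 8, 32
W = np.random.randn(d_out, d_in)     # frozen pretrained weight
A = np.random.randn(r, d_in) * 0.01  # trainable low-rank factor
B = np.zeros((d_out, r))             # trainable, zero-initialized

# Effective weight: frozen W plus the scaled rank-<=r update B @ A.
W_eff = W + (alpha / r) * (B @ A)
print(W_eff.shape)                   # (2304, 768), identical to W
```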
142
+ {
143
+ "cell_type": "code",
144
+ "execution_count": 1,
145
+ "id": "5aa3d240-44e1-4811-8f61-d6ff2500a798",
146
+ "metadata": {},
147
+ "outputs": [],
148
+ "source": [
149
+ "import subprocess\n",
150
+ "import os\n",
151
+ "# 设置环境变量, autodl一般区域\n",
152
+ "result = subprocess.run('bash -c \"source /etc/network_turbo && env | grep proxy\"', shell=True, capture_output=True, text=True)\n",
153
+ "output = result.stdout\n",
154
+ "for line in output.splitlines():\n",
155
+ " if '=' in line:\n",
156
+ " var, value = line.split('=', 1)\n",
157
+ " os.environ[var] = value"
158
+ ]
159
+ },
160
+ {
161
+ "cell_type": "markdown",
162
+ "id": "17bdb69d-3f0f-465e-bd60-2047a088e264",
163
+ "metadata": {},
164
+ "source": [
165
+ "如果您不确定模型中有哪些模块可以微调,可以打印模型结构:"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "code",
170
+ "execution_count": 2,
171
+ "id": "41a0c049-9134-4d89-aad0-1aa2241a9fca",
172
+ "metadata": {},
173
+ "outputs": [
174
+ {
175
+ "data": {
176
+ "application/vnd.jupyter.widget-view+json": {
177
+ "model_id": "4becc479adbc472bb7672d49da16aafd",
178
+ "version_major": 2,
179
+ "version_minor": 0
180
+ },
181
+ "text/plain": [
182
+ "generation_config.json: 0%| | 0.00/124 [00:00<?, ?B/s]"
183
+ ]
184
+ },
185
+ "metadata": {},
186
+ "output_type": "display_data"
187
+ },
188
+ {
189
+ "name": "stdout",
190
+ "output_type": "stream",
191
+ "text": [
192
+ "\n",
193
+ "transformer\n",
194
+ "transformer.wte\n",
195
+ "transformer.wpe\n",
196
+ "transformer.drop\n",
197
+ "transformer.h\n",
198
+ "transformer.h.0\n",
199
+ "transformer.h.0.ln_1\n",
200
+ "transformer.h.0.attn\n",
201
+ "transformer.h.0.attn.c_attn\n",
202
+ "transformer.h.0.attn.c_proj\n",
203
+ "transformer.h.0.attn.attn_dropout\n",
204
+ "transformer.h.0.attn.resid_dropout\n",
205
+ "transformer.h.0.ln_2\n",
206
+ "transformer.h.0.mlp\n",
207
+ "transformer.h.0.mlp.c_fc\n",
208
+ "transformer.h.0.mlp.c_proj\n",
209
+ "transformer.h.0.mlp.act\n",
210
+ "transformer.h.0.mlp.dropout\n",
211
+ "transformer.h.1\n",
212
+ "transformer.h.1.ln_1\n",
213
+ "transformer.h.1.attn\n",
214
+ "transformer.h.1.attn.c_attn\n",
215
+ "transformer.h.1.attn.c_proj\n",
216
+ "transformer.h.1.attn.attn_dropout\n",
217
+ "transformer.h.1.attn.resid_dropout\n",
218
+ "transformer.h.1.ln_2\n",
219
+ "transformer.h.1.mlp\n",
220
+ "transformer.h.1.mlp.c_fc\n",
221
+ "transformer.h.1.mlp.c_proj\n",
222
+ "transformer.h.1.mlp.act\n",
223
+ "transformer.h.1.mlp.dropout\n",
224
+ "transformer.h.2\n",
225
+ "transformer.h.2.ln_1\n",
226
+ "transformer.h.2.attn\n",
227
+ "transformer.h.2.attn.c_attn\n",
228
+ "transformer.h.2.attn.c_proj\n",
229
+ "transformer.h.2.attn.attn_dropout\n",
230
+ "transformer.h.2.attn.resid_dropout\n",
231
+ "transformer.h.2.ln_2\n",
232
+ "transformer.h.2.mlp\n",
233
+ "transformer.h.2.mlp.c_fc\n",
234
+ "transformer.h.2.mlp.c_proj\n",
235
+ "transformer.h.2.mlp.act\n",
236
+ "transformer.h.2.mlp.dropout\n",
237
+ "transformer.h.3\n",
238
+ "transformer.h.3.ln_1\n",
239
+ "transformer.h.3.attn\n",
240
+ "transformer.h.3.attn.c_attn\n",
241
+ "transformer.h.3.attn.c_proj\n",
242
+ "transformer.h.3.attn.attn_dropout\n",
243
+ "transformer.h.3.attn.resid_dropout\n",
244
+ "transformer.h.3.ln_2\n",
245
+ "transformer.h.3.mlp\n",
246
+ "transformer.h.3.mlp.c_fc\n",
247
+ "transformer.h.3.mlp.c_proj\n",
248
+ "transformer.h.3.mlp.act\n",
249
+ "transformer.h.3.mlp.dropout\n",
250
+ "transformer.h.4\n",
251
+ "transformer.h.4.ln_1\n",
252
+ "transformer.h.4.attn\n",
253
+ "transformer.h.4.attn.c_attn\n",
254
+ "transformer.h.4.attn.c_proj\n",
255
+ "transformer.h.4.attn.attn_dropout\n",
256
+ "transformer.h.4.attn.resid_dropout\n",
257
+ "transformer.h.4.ln_2\n",
258
+ "transformer.h.4.mlp\n",
259
+ "transformer.h.4.mlp.c_fc\n",
260
+ "transformer.h.4.mlp.c_proj\n",
261
+ "transformer.h.4.mlp.act\n",
262
+ "transformer.h.4.mlp.dropout\n",
263
+ "transformer.h.5\n",
264
+ "transformer.h.5.ln_1\n",
265
+ "transformer.h.5.attn\n",
266
+ "transformer.h.5.attn.c_attn\n",
267
+ "transformer.h.5.attn.c_proj\n",
268
+ "transformer.h.5.attn.attn_dropout\n",
269
+ "transformer.h.5.attn.resid_dropout\n",
270
+ "transformer.h.5.ln_2\n",
271
+ "transformer.h.5.mlp\n",
272
+ "transformer.h.5.mlp.c_fc\n",
273
+ "transformer.h.5.mlp.c_proj\n",
274
+ "transformer.h.5.mlp.act\n",
275
+ "transformer.h.5.mlp.dropout\n",
276
+ "transformer.h.6\n",
277
+ "transformer.h.6.ln_1\n",
278
+ "transformer.h.6.attn\n",
279
+ "transformer.h.6.attn.c_attn\n",
280
+ "transformer.h.6.attn.c_proj\n",
281
+ "transformer.h.6.attn.attn_dropout\n",
282
+ "transformer.h.6.attn.resid_dropout\n",
283
+ "transformer.h.6.ln_2\n",
284
+ "transformer.h.6.mlp\n",
285
+ "transformer.h.6.mlp.c_fc\n",
286
+ "transformer.h.6.mlp.c_proj\n",
287
+ "transformer.h.6.mlp.act\n",
288
+ "transformer.h.6.mlp.dropout\n",
289
+ "transformer.h.7\n",
290
+ "transformer.h.7.ln_1\n",
291
+ "transformer.h.7.attn\n",
292
+ "transformer.h.7.attn.c_attn\n",
293
+ "transformer.h.7.attn.c_proj\n",
294
+ "transformer.h.7.attn.attn_dropout\n",
295
+ "transformer.h.7.attn.resid_dropout\n",
296
+ "transformer.h.7.ln_2\n",
297
+ "transformer.h.7.mlp\n",
298
+ "transformer.h.7.mlp.c_fc\n",
299
+ "transformer.h.7.mlp.c_proj\n",
300
+ "transformer.h.7.mlp.act\n",
301
+ "transformer.h.7.mlp.dropout\n",
302
+ "transformer.h.8\n",
303
+ "transformer.h.8.ln_1\n",
304
+ "transformer.h.8.attn\n",
305
+ "transformer.h.8.attn.c_attn\n",
306
+ "transformer.h.8.attn.c_proj\n",
307
+ "transformer.h.8.attn.attn_dropout\n",
308
+ "transformer.h.8.attn.resid_dropout\n",
309
+ "transformer.h.8.ln_2\n",
310
+ "transformer.h.8.mlp\n",
311
+ "transformer.h.8.mlp.c_fc\n",
312
+ "transformer.h.8.mlp.c_proj\n",
313
+ "transformer.h.8.mlp.act\n",
314
+ "transformer.h.8.mlp.dropout\n",
315
+ "transformer.h.9\n",
316
+ "transformer.h.9.ln_1\n",
317
+ "transformer.h.9.attn\n",
318
+ "transformer.h.9.attn.c_attn\n",
319
+ "transformer.h.9.attn.c_proj\n",
320
+ "transformer.h.9.attn.attn_dropout\n",
321
+ "transformer.h.9.attn.resid_dropout\n",
322
+ "transformer.h.9.ln_2\n",
323
+ "transformer.h.9.mlp\n",
324
+ "transformer.h.9.mlp.c_fc\n",
325
+ "transformer.h.9.mlp.c_proj\n",
326
+ "transformer.h.9.mlp.act\n",
327
+ "transformer.h.9.mlp.dropout\n",
328
+ "transformer.h.10\n",
329
+ "transformer.h.10.ln_1\n",
330
+ "transformer.h.10.attn\n",
331
+ "transformer.h.10.attn.c_attn\n",
332
+ "transformer.h.10.attn.c_proj\n",
333
+ "transformer.h.10.attn.attn_dropout\n",
334
+ "transformer.h.10.attn.resid_dropout\n",
335
+ "transformer.h.10.ln_2\n",
336
+ "transformer.h.10.mlp\n",
337
+ "transformer.h.10.mlp.c_fc\n",
338
+ "transformer.h.10.mlp.c_proj\n",
339
+ "transformer.h.10.mlp.act\n",
340
+ "transformer.h.10.mlp.dropout\n",
341
+ "transformer.h.11\n",
342
+ "transformer.h.11.ln_1\n",
343
+ "transformer.h.11.attn\n",
344
+ "transformer.h.11.attn.c_attn\n",
345
+ "transformer.h.11.attn.c_proj\n",
346
+ "transformer.h.11.attn.attn_dropout\n",
347
+ "transformer.h.11.attn.resid_dropout\n",
348
+ "transformer.h.11.ln_2\n",
349
+ "transformer.h.11.mlp\n",
350
+ "transformer.h.11.mlp.c_fc\n",
351
+ "transformer.h.11.mlp.c_proj\n",
352
+ "transformer.h.11.mlp.act\n",
353
+ "transformer.h.11.mlp.dropout\n",
354
+ "transformer.ln_f\n",
355
+ "lm_head\n"
356
+ ]
357
+ }
358
+ ],
359
+ "source": [
360
+ "from transformers import AutoModelForCausalLM\n",
361
+ "\n",
362
+ "model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n",
363
+ "\n",
364
+ "# 打印所有模块名称\n",
365
+ "for name, module in model.named_modules():\n",
366
+ " print(name)"
367
+ ]
368
+ },
369
+ {
370
+ "cell_type": "code",
371
+ "execution_count": null,
372
+ "id": "37aa6abb-ab1c-4e9c-b968-579dd74044db",
373
+ "metadata": {},
374
+ "outputs": [],
375
+ "source": []
376
+ },
377
+ {
378
+ "cell_type": "markdown",
379
+ "id": "0add2f79-f35c-4638-80bb-0d8a87a9b6a7",
380
+ "metadata": {},
381
+ "source": [
382
+ "在选择 `target_modules` 时,通常会根据模块的名称选择模型的特定部分,通常使用列表中最后一个点 `.` 后的字段名或整个路径名(如果需要更精确)。以下是对这些模块的详细分析和选择建议:\n",
383
+ "\n",
384
+ "---\n",
385
+ "\n",
386
+ "### **1. 分析模块结构**\n",
387
+ "\n",
388
+ "从列表中可以看出,GPT-2 的模块层次分为以下几类:\n",
389
+ "\n",
390
+ "1. **Embedding 层**:\n",
391
+ " - `transformer.wte`:词嵌入层(Word Token Embeddings)。\n",
392
+ " - `transformer.wpe`:位置嵌入层(Position Embeddings)。\n",
393
+ "\n",
394
+ "2. **Transformer 编码器层**:\n",
395
+ " - 每层编号为 `transformer.h.<层号>`(共 12 层)。\n",
396
+ " - 每层中包含:\n",
397
+ " - **层归一化**:\n",
398
+ " - `transformer.h.<层号>.ln_1`:第一层归一化。\n",
399
+ " - `transformer.h.<层号>.ln_2`:第二层归一化。\n",
400
+ " - **自注意力模块**:\n",
401
+ " - `transformer.h.<层号>.attn.c_attn`:注意力模块的 Query、Key 和 Value 投影。\n",
402
+ " - `transformer.h.<层号>.attn.c_proj`:注意力的输出投影。\n",
403
+ " - `transformer.h.<层号>.attn.attn_dropout`:注意力的 Dropout。\n",
404
+ " - `transformer.h.<层号>.attn.resid_dropout`:残差的 Dropout。\n",
405
+ " - **前馈网络模块(MLP)**:\n",
406
+ " - `transformer.h.<层号>.mlp.c_fc`:MLP 的第一层全连接。\n",
407
+ " - `transformer.h.<层号>.mlp.c_proj`:MLP 的第二层全连接(输出投影)。\n",
408
+ " - `transformer.h.<层号>.mlp.act`:激活函数(如 GELU)。\n",
409
+ " - `transformer.h.<层号>.mlp.dropout`:MLP 的 Dropout。\n",
410
+ "\n",
411
+ "3. **最终层**:\n",
412
+ " - `transformer.ln_f`:最终层归一化(LayerNorm)。\n",
413
+ " - `lm_head`:语言建模头,用于生成预测的 token 分布。\n",
414
+ "\n",
415
+ "---\n",
416
+ "\n",
417
+ "### **2. 如何选择 `target_modules`**\n",
418
+ "\n",
419
+ "#### **(1)常见目标模块**\n",
420
+ "- `transformer.h.<层号>.attn.c_attn`:对自注意力模块的 Query、Key 和 Value 投影层微调。\n",
421
+ "- `transformer.h.<层号>.attn.c_proj`:对注意力输出的投影层微调。\n",
422
+ "- `transformer.h.<层号>.mlp.c_fc`:对前馈网络的输入全连接层微调。\n",
423
+ "- `transformer.h.<层号>.mlp.c_proj`:对前馈网络的输出投影层微调。\n",
424
+ "\n",
425
+ "#### **(2)推荐设置**\n",
426
+ "- **文本生成任务**:\n",
427
+ " ```python\n",
428
+ " target_modules = [\"transformer.h.*.attn.c_attn\", \"transformer.h.*.attn.c_proj\"]\n",
429
+ " ```\n",
430
+ " 解释:\n",
431
+ " - `*.attn.c_attn`:调整 Query、Key、Value 的生成。\n",
432
+ " - `*.attn.c_proj`:调整注意力输出。\n",
433
+ "\n",
434
+ "- **文本分类任务**:\n",
435
+ " ```python\n",
436
+ " target_modules = [\"transformer.h.*.attn.c_attn\"]\n",
437
+ " ```\n",
438
+ " 解释:\n",
439
+ " - 微调自注意力模块最重要的部分即可。\n",
440
+ "\n",
441
+ "- **特定任务需要更细粒度控制**:\n",
442
+ " - 仅微调某几层:\n",
443
+ " ```python\n",
444
+ " target_modules = [\"transformer.h.0.attn.c_attn\", \"transformer.h.0.mlp.c_fc\"]\n",
445
+ " ```\n",
446
+ "\n",
447
+ "#### **(3)通配符选择**\n",
448
+ "使用 `*` 通配符可以指定所有层的某些模块:\n",
449
+ "- `transformer.h.*.attn.c_attn`:所有层的 Query、Key 和 Value 投影。\n",
450
+ "- `transformer.h.*.mlp.*`:所有层的 MLP 模块。\n",
451
+ "\n",
452
+ "---\n",
453
+ "\n",
454
+ "### **3. 示例:指定多个模块**\n",
455
+ "\n",
456
+ "```python\n",
457
+ "lora_config = LoraConfig(\n",
458
+ " task_type=TaskType.CAUSAL_LM,\n",
459
+ " r=8,\n",
460
+ " lora_alpha=32,\n",
461
+ " target_modules=[\n",
462
+ " \"transformer.h.*.attn.c_attn\",\n",
463
+ " \"transformer.h.*.mlp.c_fc\"\n",
464
+ " ],\n",
465
+ " lora_dropout=0.1,\n",
466
+ " bias=\"none\"\n",
467
+ ")\n",
468
+ "```\n",
469
+ "\n",
470
+ "- 这表示对所有层的 `attn.c_attn` 和 `mlp.c_fc` 模块进行 LoRA 微调。\n",
471
+ "\n",
472
+ "---\n",
473
+ "\n",
474
+ "### **4. 小提示:如何确定适合的模块**\n",
475
+ "\n",
476
+ "1. **任务相关性**:\n",
477
+ " - 文本生成:优先选择自注意力模块(如 `c_attn`)。\n",
478
+ " - 文本分类:通常需要全局语义表示,选择 `attn.c_attn` 或 `mlp.c_fc`。\n",
479
+ "\n",
480
+ "2. **性能与资源平衡**:\n",
481
+ " - 如果显存有限,可以只微调部分层。例如,仅选择浅层和深层的模块:\n",
482
+ " ```python\n",
483
+ " target_modules = [\"transformer.h.0.attn.c_attn\", \"transformer.h.11.attn.c_attn\"]\n",
484
+ " ```\n",
485
+ "\n",
486
+ "3. **打印模块名称以调试**:\n",
487
+ " - 确保选择的 `target_modules` 在模型中实际存在:\n",
488
+ " ```python\n",
489
+ " for name, _ in model.named_modules():\n",
490
+ " if \"c_attn\" in name:\n",
491
+ " print(name)\n",
492
+ " ```\n",
493
+ "\n",
494
+ "---\n",
495
+ "\n",
496
+ "### **建议**\n",
497
+ "- 一般情况下,`c_attn` 和 `c_proj` 是首选模块。\n",
498
+ "- 使用 `transformer.h.*` 通配符可以轻松指定多层。\n",
499
+ "- 根据任务需求和资源限制灵活调整目标模块,以实现最佳性能和效率。"
500
+ ]
501
+ },
502
+ {
503
+ "cell_type": "code",
504
+ "execution_count": null,
505
+ "id": "b4a41750-420f-49c4-845d-69db394794f9",
506
+ "metadata": {},
507
+ "outputs": [],
508
+ "source": []
509
+ },
510
+ {
511
+ "cell_type": "markdown",
512
+ "id": "10c99eb9-8007-4297-972e-7be71768c9c3",
513
+ "metadata": {},
514
+ "source": [
515
+ "以下是对 `LoraConfig` 配置的更详细解释,特别是如何设置微调哪些参数、冻结哪些参数,以及一般如何选择这些设置:\n",
516
+ "\n",
517
+ "---\n",
518
+ "\n",
519
+ "### **1. `LoraConfig` 参数解析**\n",
520
+ "\n",
521
+ "```python\n",
522
+ "lora_config = LoraConfig(\n",
523
+ " task_type=TaskType.SEQ_CLS, # 序列分类任务\n",
524
+ " r=8, # 降低矩阵秩\n",
525
+ " lora_alpha=32, # LoRA 的 alpha 超参数\n",
526
+ " target_modules=[\"c_attn\"], # GPT-2 中的自注意力模块\n",
527
+ " lora_dropout=0.1, # dropout 概率\n",
528
+ " bias=\"none\", # 是否微调偏置参数\n",
529
+ ")\n",
530
+ "```\n",
531
+ "\n",
532
+ "#### **(1)`task_type`**\n",
533
+ "- 定义任务类型,用于指导 PEFT 的具体行为。\n",
534
+ "- **常见选项**:\n",
535
+ " - `TaskType.CAUSAL_LM`:自回归语言建模(如 GPT 系列模型)。\n",
536
+ " - `TaskType.SEQ_CLS`:序列分类(如情感分析)。\n",
537
+ " - `TaskType.TOKEN_CLS`:标注任务(如命名实体识别)。\n",
538
+ " - `TaskType.SEQ_2_SEQ_LM`:序列到序列任务(如翻译、摘要)。\n",
539
+ "\n",
540
+ "**当前设置**:\n",
541
+ "- `TaskType.SEQ_CLS` 表示目标是文本分类任务。\n",
542
+ "\n",
543
+ "---\n",
544
+ "\n",
545
+ "#### **(2)`r`**\n",
546
+ "- 表示 LoRA 的 **秩**(rank),是降低矩阵秩的核心参数。\n",
547
+ "- LoRA 通过将模型的权重分解为两个低秩矩阵(`A` 和 `B`),只更新这两个矩阵。\n",
548
+ "- `r` 的值越大,微调能力越强,但需要的额外参数也越多。\n",
549
+ "- **典型范围**:`4` 至 `64`,大多数任务中 `8` 或 `16` 是常用值。\n",
550
+ "\n",
551
+ "**当前设置**:\n",
552
+ "- `r=8` 表示使用低秩分解,并微调 8 维的参数矩阵。\n",
553
+ "\n",
554
+ "---\n",
555
+ "\n",
556
+ "#### **(3)`lora_alpha`**\n",
557
+ "- 是 LoRA 的一个缩放因子,用于调节两个低秩矩阵的更新速率。\n",
558
+ "- **公式**:实际更新 = LoRA 输出 × `lora_alpha / r`\n",
559
+ "- **典型范围**:`16` 至 `128`,较大任务中可以选择更高的值。\n",
560
+ "\n",
561
+ "**当前设置**:\n",
562
+ "- `lora_alpha=32`,表示适中幅度的更新速率。\n",
563
+ "\n",
564
+ "---\n",
565
+ "\n",
566
+ "#### **(4)`target_modules`**\n",
567
+ "- 指定要应用 LoRA 微调的模块。\n",
568
+ "- **常见选择**:\n",
569
+ " - 对 Transformer 模型中的 **注意力模块**(如 `query`、`key`、`value`)进行微调,因为这些模块对任务性能影响较大。\n",
570
+ " - 对 GPT-2,通常选择 `c_attn`(GPT-2 中负责自注意力机制的组合模块)。\n",
571
+ "\n",
572
+ "**当前设置**:\n",
573
+ "- `target_modules=[\"c_attn\"]` 表示只对 GPT-2 的自注意力模块 `c_attn` 应用 LoRA。\n",
574
+ "\n",
575
+ "---\n",
576
+ "\n",
577
+ "#### **(5)`lora_dropout`**\n",
578
+ "- 表示 LoRA 层的 dropout 概率,用于防止过拟合。\n",
579
+ "- **典型范围**:`0.0` 至 `0.1`,视任务复杂性而定。\n",
580
+ "\n",
581
+ "**当前设置**:\n",
582
+ "- `lora_dropout=0.1`,表示有 10% 的概率随机丢弃 LoRA 层的输出。\n",
583
+ "\n",
584
+ "---\n",
585
+ "\n",
586
+ "#### **(6)`bias`**\n",
587
+ "- 决定是否微调偏置参数。\n",
588
+ "- **选项**:\n",
589
+ " - `\"none\"`:不微调任何偏置。\n",
590
+ " - `\"all\"`:微调所有偏置。\n",
591
+ " - `\"lora_only\"`:只微调 LoRA 层的偏置。\n",
592
+ "\n",
593
+ "**当前设置**:\n",
594
+ "- `bias=\"none\"`,表示所有偏置参数保持冻结。\n",
595
+ "\n",
596
+ "---\n",
597
+ "\n",
598
+ "### **2. 微调哪些参数,冻结哪些参数**\n",
599
+ "\n",
600
+ "LoRA 的核心思想是通过 **分解矩阵**,只更新少量参数,而冻结模型的大部分参数。以下是常见设置的说明:\n",
601
+ "\n",
602
+ "#### **微调的参数**\n",
603
+ "- LoRA 通过 `target_modules` 指定的模块,例如:\n",
604
+ " - GPT-2 的 `c_attn`(自注意力模块)。\n",
605
+ " - BERT 的 `query` 和 `key`。\n",
606
+ "- 这些模块是模型中对性能贡献最大的部分,通过微调这些模块,任务性能可以显著提升。\n",
607
+ "\n",
608
+ "#### **冻结的参数**\n",
609
+ "- 除了 `target_modules` 中指定的参数外,所有其他模型参数默认冻结,包括:\n",
610
+ " - 预训练权重的绝大部分。\n",
611
+ " - 偏置参数(如果 `bias=\"none\"`)。\n",
612
+ "\n",
613
+ "---\n",
614
+ "\n",
615
+ "### **3. 一般如何设置**\n",
616
+ "\n",
617
+ "#### **(1)针对不同任务调整**\n",
618
+ "- **文本分类任务**:\n",
619
+ " - 优先选择自注意力模块(如 `c_attn`)作为 `target_modules`。\n",
620
+ " - `r=8` 或 `r=16` 是常见选择,适中计算开销。\n",
621
+ " - 设置适当的 dropout(如 `lora_dropout=0.1`)以防止过拟合。\n",
622
+ " \n",
623
+ "- **语言生成任务**:\n",
624
+ " - 对 GPT-2 或 GPT-3,选择 `q_proj` 和 `v_proj`(query 和 value 投影模块)。\n",
625
+ " - `r=16` 或更高,适应生成任务的高复杂性。\n",
626
+ "\n",
627
+ "- **命名实体识别任务**:\n",
628
+ " - 优先选择 `q_proj` 和 `k_proj`(query 和 key 模块)。\n",
629
+ "\n",
630
+ "#### **(2)参数量与显存的权衡**\n",
631
+ "- 如果显存有限,减少 `r` 的值。\n",
632
+ "- 对小型任务,`r=4` 或 `r=8` 通常已经足够。\n",
633
+ "\n",
634
+ "#### **(3)偏置设置**\n",
635
+ "- 偏置参数的影响较小,在大多数情况下,可以选择 `bias=\"none\"` 保持冻结。\n",
636
+ "- 对非常依赖偏置的任务(如生成风格微调),可以尝试 `bias=\"lora_only\"`。\n",
637
+ "\n",
638
+ "---\n",
639
+ "\n",
640
+ "### **4. 示例:如何选择目标模块**\n",
641
+ "\n",
642
+ "#### **GPT-2**\n",
643
+ "对 GPT-2 来说,以下模块通常是微调的目标:\n",
644
+ "- **`c_attn`**:注意力模块的组��层。\n",
645
+ "- **`q_proj` 和 `v_proj`**:Query 和 Value 的线性投影。\n",
646
+ "\n",
647
+ "#### **BERT**\n",
648
+ "对 BERT 来说,以下模块通常是微调的目标:\n",
649
+ "- **`query`**:Attention 的 Query 模块。\n",
650
+ "- **`key`**:Attention 的 Key 模块。\n",
651
+ "\n",
652
+ "---\n",
653
+ "\n",
654
+ "### **5. 总结建议**\n",
655
+ "- **微调的参数**:优先选择模型中注意力相关模块。\n",
656
+ "- **冻结的参数**:大部分参数默认冻结以节省显存。\n",
657
+ "- **配置选择**:根据任务复杂性调整 `r` 和 `target_modules`。\n",
658
+ "- **推荐起点**:\n",
659
+ " - 文本分类:`target_modules=[\"c_attn\"]`, `r=8`, `lora_dropout=0.1`。\n",
660
+ " - 文本生成:`target_modules=[\"q_proj\", \"v_proj\"]`, `r=16`, `lora_dropout=0.1`。\n",
661
+ "\n",
662
+ "通过这些设置,LoRA 可以在参数量极小的情况下实现高效微调,适合各种任务场景。"
663
+ ]
664
+ },
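A quick back-of-envelope check of what `r=8` on GPT-2's `c_attn` actually trains (LoRA adds `r * (d_in + d_out)` parameters per wrapped matrix; dimensions as printed earlier in this diff, percentage approximate):

```python
d_in, d_out, r, n_layers = 768, 2304, 8, 12
per_layer = r * (d_in + d_out)  # 24,576 extra trainable weights per layer
total = per_layer * n_layers    # 294,912 in all, roughly 0.24% of GPT-2's 124M
print(per_layer, total)
```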
665
+ {
666
+ "cell_type": "code",
667
+ "execution_count": null,
668
+ "id": "26d9f362-18cc-471f-b208-f29a6933c06a",
669
+ "metadata": {},
670
+ "outputs": [],
671
+ "source": [
672
+ "from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments, Trainer\n",
673
+ "from peft import LoraConfig, get_peft_model, TaskType\n",
674
+ "from datasets import load_dataset\n",
675
+ "from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n",
676
+ "\n",
677
+ "# **1. 加载模型和分词器**\n",
678
+ "model_name = \"gpt2\" # 基础模型\n",
679
+ "num_labels = 2 # 二分类任务\n",
680
+ "model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)\n",
681
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
682
+ "tokenizer.pad_token = tokenizer.eos_token # 设置 pad_token 为 eos_token\n",
683
+ "\n",
684
+ "# **2. 定义数据集**\n",
685
+ "# 示例数据集:dna_promoter_300\n",
686
+ "dataset = load_dataset(\"dnagpt/dna_promoter_300\")['train'].train_test_split(test_size=0.1)\n",
687
+ "\n",
688
+ "# **3. 数据预处理**\n",
689
+ "def preprocess_function(examples):\n",
690
+ " examples['label'] = [int(item) for item in examples['label']]\n",
691
+ " return tokenizer(\n",
692
+ " examples[\"sequence\"], truncation=True, padding=\"max_length\", max_length=128\n",
693
+ " )\n",
694
+ "\n",
695
+ "tokenized_datasets = dataset.map(preprocess_function, batched=True)\n",
696
+ "tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\") # Hugging Face Trainer 要求标签列名为 'labels'\n",
697
+ "\n",
698
+ "# **4. 划分数据集**\n",
699
+ "train_dataset = tokenized_datasets[\"train\"]\n",
700
+ "test_dataset = tokenized_datasets[\"test\"]\n",
701
+ "\n",
702
+ "# **5. 配置 LoRA**\n",
703
+ "lora_config = LoraConfig(\n",
704
+ " task_type=TaskType.SEQ_CLS, # 序列分类任务\n",
705
+ " r=8, # 降低矩阵秩\n",
706
+ " lora_alpha=32, # LoRA 的 alpha 超参数\n",
707
+ " target_modules=[\"c_attn\"], # GPT-2 中的自注意力模块\n",
708
+ " lora_dropout=0.1, # dropout 概率\n",
709
+ " bias=\"none\", # 是否微调偏置参数\n",
710
+ ")\n",
711
+ "\n",
712
+ "# 使用 LoRA 包装模型\n",
713
+ "model = get_peft_model(model, lora_config)\n",
714
+ "model.print_trainable_parameters() # 打印可训练的参数信息\n",
715
+ "\n",
716
+ "# **6. 计算指标**\n",
717
+ "def compute_metrics(eval_pred):\n",
718
+ " predictions, labels = eval_pred\n",
719
+ " preds = predictions.argmax(axis=-1)\n",
720
+ " precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average=\"binary\")\n",
721
+ " acc = accuracy_score(labels, preds)\n",
722
+ " return {\"accuracy\": acc, \"precision\": precision, \"recall\": recall, \"f1\": f1}\n",
723
+ "\n",
724
+ "# **7. 定义训练参数**\n",
725
+ "training_args = TrainingArguments(\n",
726
+ " output_dir=\"./gpt2_lora_text_classification\", # 模型保存路径\n",
727
+ " evaluation_strategy=\"epoch\", # 每个 epoch 评估一次\n",
728
+ " save_strategy=\"epoch\", # 每个 epoch 保存一次\n",
729
+ " learning_rate=2e-5, # 学习率\n",
730
+ " per_device_train_batch_size=8, # 每设备的批量大小\n",
731
+ " per_device_eval_batch_size=8, # 每设备评估的批量大小\n",
732
+ " num_train_epochs=3, # 训练轮数\n",
733
+ " weight_decay=0.01, # 权重衰减\n",
734
+ " logging_dir=\"./logs\", # 日志路径\n",
735
+ " fp16=True, # 启用混合精度训练\n",
736
+ " save_total_limit=2, # 保留最多两个检查点\n",
737
+ " load_best_model_at_end=True, # 加载最佳模型\n",
738
+ " metric_for_best_model=\"accuracy\", # 根据准确率选择最佳模型\n",
739
+ " greater_is_better=True,\n",
740
+ ")\n",
741
+ "\n",
742
+ "# **8. 定义 Trainer**\n",
743
+ "trainer = Trainer(\n",
744
+ " model=model,\n",
745
+ " args=training_args,\n",
746
+ " train_dataset=train_dataset,\n",
747
+ " eval_dataset=test_dataset,\n",
748
+ " tokenizer=tokenizer,\n",
749
+ " compute_metrics=compute_metrics,\n",
750
+ ")\n",
751
+ "\n",
752
+ "# **9. 开始训练**\n",
753
+ "trainer.train()\n",
754
+ "\n",
755
+ "# **10. 保存模型**\n",
756
+ "model.save_pretrained(\"./gpt2_lora_text_classification\")\n",
757
+ "tokenizer.save_pretrained(\"./gpt2_lora_text_classification\")\n",
758
+ "\n",
759
+ "print(\"训练完成,模型已保存至 ./gpt2_lora_text_classification\")"
760
+ ]
761
+ },
762
+ {
763
+ "cell_type": "code",
764
+ "execution_count": null,
765
+ "id": "49a60fed-3a7d-4608-98b1-b4e313b94dbb",
766
+ "metadata": {},
767
+ "outputs": [],
768
+ "source": [
769
+ "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
770
+ "from peft import PeftModel\n",
771
+ "\n",
772
+ "# 加载分词器\n",
773
+ "model_path = \"./gpt2_lora_text_classification\"\n",
774
+ "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
775
+ "\n",
776
+ "# 加载微调后的 PEFT 模型\n",
777
+ "base_model = AutoModelForSequenceClassification.from_pretrained(\"gpt2\", num_labels=2)\n",
778
+ "model = PeftModel.from_pretrained(base_model, model_path)"
779
+ ]
780
+ },
781
+ {
782
+ "cell_type": "code",
783
+ "execution_count": null,
784
+ "id": "3c0d8f02-c3dc-4961-8b3a-50eefc5f9448",
785
+ "metadata": {},
786
+ "outputs": [],
787
+ "source": [
788
+ "import torch\n",
789
+ "\n",
790
+ "def predict(texts, model, tokenizer):\n",
791
+ " \"\"\"\n",
792
+ " 使用微调后的 PEFT 模型进行推理。\n",
793
+ " \n",
794
+ " Args:\n",
795
+ " texts (list of str): 待分类的文本列表。\n",
796
+ " model (PeftModel): 微调后的模型。\n",
797
+ " tokenizer (AutoTokenizer): 分词器。\n",
798
+ " \n",
799
+ " Returns:\n",
800
+ " list of dict: 每个文本的预测结果,包括 logits 和预测的类别标签。\n",
801
+ " \"\"\"\n",
802
+ " # 对输入文本进行分词和编码\n",
803
+ " inputs = tokenizer(\n",
804
+ " texts,\n",
805
+ " padding=True,\n",
806
+ " truncation=True,\n",
807
+ " max_length=512,\n",
808
+ " return_tensors=\"pt\"\n",
809
+ " )\n",
810
+ " \n",
811
+ " # 将输入数据移动到模型的设备上(CPU/GPU)\n",
812
+ " inputs = {key: value.to(model.device) for key, value in inputs.items()}\n",
813
+ " \n",
814
+ " # 模型推理\n",
815
+ " model.eval()\n",
816
+ " with torch.no_grad():\n",
817
+ " outputs = model(**inputs)\n",
818
+ " \n",
819
+ " # 获取 logits 并计算预测类别\n",
820
+ " logits = outputs.logits\n",
821
+ " probs = torch.nn.functional.softmax(logits, dim=-1)\n",
822
+ " predictions = torch.argmax(probs, dim=-1)\n",
823
+ " \n",
824
+ " # 返回每个文本的预测结果\n",
825
+ " results = [\n",
826
+ " {\"text\": text, \"logits\": logit.tolist(), \"predicted_class\": int(pred)}\n",
827
+ " for text, logit, pred in zip(texts, logits, predictions)\n",
828
+ " ]\n",
829
+ " return results\n"
830
+ ]
831
+ },
832
+ {
833
+ "cell_type": "code",
834
+ "execution_count": null,
835
+ "id": "9c0cfe65-f4f3-4274-a4f4-1ac13725b15a",
836
+ "metadata": {},
837
+ "outputs": [],
838
+ "source": [
839
+ "Text: This movie was fantastic! I loved every part of it.\n",
840
+ "Predicted Class: 1\n",
841
+ "Logits: [-2.345, 3.567]\n",
842
+ "\n",
843
+ "Text: The plot was terrible and the acting was worse.\n",
844
+ "Predicted Class: 0\n",
845
+ "Logits: [4.123, -1.234]\n"
846
+ ]
847
+ }
848
+ ],
849
+ "metadata": {
850
+ "kernelspec": {
851
+ "display_name": "Python 3 (ipykernel)",
852
+ "language": "python",
853
+ "name": "python3"
854
+ },
855
+ "language_info": {
856
+ "codemirror_mode": {
857
+ "name": "ipython",
858
+ "version": 3
859
+ },
860
+ "file_extension": ".py",
861
+ "mimetype": "text/x-python",
862
+ "name": "python",
863
+ "nbconvert_exporter": "python",
864
+ "pygments_lexer": "ipython3",
865
+ "version": "3.12.3"
866
+ }
867
+ },
868
+ "nbformat": 4,
869
+ "nbformat_minor": 5
870
+ }
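A hypothetical call to the `predict()` helper defined in the cells above (it assumes `model` and `tokenizer` from the preceding load cell have been created); the two sequences are the promoter / non-promoter examples quoted earlier in this commit:

```python
sequences = [
    "TATATTTTCTCAGCTGAGTTAATTAGTTTCACTAGTTAACTGAGAATAAAAGAA",  # promoter (1)
    "TGGGGAGGGTCCGGTGTTAGTTAGATACATCCCCAGACCCACACCCCGGATAGA",  # Non-promoter (0)
]
for result in predict(sequences, model, tokenizer):
    print(result["predicted_class"], result["text"][:30], "...")
```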
04-gene-sft/.ipynb_checkpoints/6-llama-continue-train-checkpoint.ipynb ADDED
@@ -0,0 +1,491 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "1e6d4978-4f0f-4268-aa23-d864857bd6c8",
6
+ "metadata": {},
7
+ "source": [
8
+ "# 4.6 基于llama的基因大模型持续预训练"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "2c201732-e736-463c-8446-637bf517479f",
14
+ "metadata": {},
15
+ "source": [
16
+ "LLaMA(**Large Language Model Meta AI**)是由 Meta(Facebook)开发的一系列大型语言模型,专注于提供高性能和高效的大语言模型,面向学术研究和开发社区。LLaMA 系列主要强调训练效率、模型性能和对计算资源的高效利用,是 GPT 系列模型的有力竞争者之一。\n",
17
+ "\n",
18
+ "---\n",
19
+ "\n",
20
+ "### **1. LLaMA 模型概述**\n",
21
+ "\n",
22
+ "#### **1.1 LLaMA 1**\n",
23
+ "- **发布**:2023 年 2 月。\n",
24
+ "- **模型参数规模**:\n",
25
+ " - 7B(70 亿)\n",
26
+ " - 13B(130 亿)\n",
27
+ " - 33B(330 亿)\n",
28
+ " - 65B(650 亿)\n",
29
+ "- **特点**:\n",
30
+ " - 专注于效率:与 GPT-3 等模型相比,LLaMA 在相同的训练成本下实现了更高的性能。\n",
31
+ " - 针对研究开放:提供预训练模型权重供研究使用。\n",
32
+ " - 使用高质量的数据:模型训练使用大量从网络中筛选的高质量文本数据,包括维基百科、书籍和其他高质量来源。\n",
33
+ "- **性能**:\n",
34
+ " - 在许多 NLP 任务中,LLaMA 的性能超过 GPT-3 和其他同类模型。\n",
35
+ " - 参数规模较小的版本(如 LLaMA-13B)性能可与 GPT-3(175B 参数)媲美。\n",
36
+ "\n",
37
+ "#### **1.2 LLaMA 2**\n",
38
+ "- **发布**:2023 年 7 月。\n",
39
+ "- **改进**:\n",
40
+ " - 增强的训练数据:相比 LLaMA 1,使用了更多的高质量数据。\n",
41
+ " - 引入微调版本:发布了开箱即用的对话模型(LLaMA 2-Chat)。\n",
42
+ " - 更好的开源支持:LLaMA 2 在商业用途上比 LLaMA 1 更加开放。\n",
43
+ "- **模型参数规模**:\n",
44
+ " - 7B(70 亿)\n",
45
+ " - 13B(130 亿)\n",
46
+ " - 70B(700 亿)\n",
47
+ "- **性能**:\n",
48
+ " - LLaMA 2 的性能相比 LLaMA 1 有显著提升。\n",
49
+ " - LLaMA 2-Chat 在对话任务中的表现优于许多现有开源模型。\n",
50
+ " - 在多个标准基准(如 MMLU)上领先于同期的开源模型,但与 GPT-4、Claude 等闭源模型仍有差距。\n",
51
+ "\n",
52
+ "---\n",
53
+ "\n",
54
+ "### **2. LLaMA 的关键技术特点**\n",
55
+ "\n",
56
+ "#### **2.1 高效的架构设计**\n",
57
+ "- 基于 Transformer 架构。\n",
58
+ "- 针对训练效率和推理速度进行了优化,适合研究和开发。\n",
59
+ "\n",
60
+ "#### **2.2 模型压缩**\n",
61
+ "- 提供更小的参数规模(如 7B 和 13B),以便在更低的计算资源上运行。\n",
62
+ "- 在性能与参数量之间实现了很好的平衡。\n",
63
+ "\n",
64
+ "#### **2.3 训练数据**\n",
65
+ "- 使用从互联网中提取的高质量数据,注重数据清洗和筛选,避免低质量文本对模型的负面影响。\n",
66
+ "\n",
67
+ "#### **2.4 微调能力**\n",
68
+ "- 支持指令微调(Instruction Tuning)和 RLHF(基于人类反馈的强化学习),特别是在 LLaMA 2-Chat 模型中表现优异。\n",
69
+ "\n",
70
+ "---\n",
71
+ "\n",
72
+ "### **3. LLaMA 的性能对比**\n",
73
+ "\n",
74
+ "#### **与 GPT-3 比较**\n",
75
+ "- LLaMA 1-13B 参数模型在许多任务上的性能接近 GPT-3-175B。\n",
76
+ "- LLaMA 2-70B 在多个任务上超过 GPT-3。\n",
77
+ "\n",
78
+ "#### **与其他开源模型比较**\n",
79
+ "- LLaMA 2 在许多基准测试中优于其他开源模型(如 Falcon 和 MPT)。\n",
80
+ "- LLaMA 2-Chat 提供了与 ChatGPT 类似的对话能力,适用于对话任务。\n",
81
+ "\n",
82
+ "---\n",
83
+ "\n",
84
+ "### **4. 应用场景**\n",
85
+ "\n",
86
+ "1. **研究**:\n",
87
+ " - 开源权重适合学术研究,推动了对大语言模型的进一步探索。\n",
88
+ "\n",
89
+ "2. **对话系统**:\n",
90
+ " - LLaMA 2-Chat 专为对话任务设计,适合开发智能客服、聊天机器人等应用。\n",
91
+ "\n",
92
+ "3. **生成任务**:\n",
93
+ " - 支持文本生成、补全、摘要等任务。\n",
94
+ "\n",
95
+ "4. **微调与定制**:\n",
96
+ " - 可以基于特定领域数据进行微调,如医学、法律、教育等领域的专用模型。\n",
97
+ "\n",
98
+ "---\n",
99
+ "\n",
100
+ "### **5. 开源与获取方式**\n",
101
+ "\n",
102
+ "#### **1. 开源**\n",
103
+ "- LLaMA 1:需要申请权限才能获得模型权重。\n",
104
+ "- LLaMA 2:更加开放,允许商业用途,模型和权重可以通过 Meta 的合作平台获取(如 Hugging Face 和 AWS)。\n",
105
+ "\n",
106
+ "#### **2. 下载与使用**\n",
107
+ "使用 Hugging Face 加载模型:\n",
108
+ "```python\n",
109
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
110
+ "\n",
111
+ "model_name = \"meta-llama/Llama-2-7b-hf\" # 替换为具体模型\n",
112
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
113
+ "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
114
+ "\n",
115
+ "# 使用模型生成文本\n",
116
+ "inputs = tokenizer(\"Hello, how are you?\", return_tensors=\"pt\")\n",
117
+ "outputs = model.generate(**inputs, max_length=50)\n",
118
+ "print(tokenizer.decode(outputs[0], skip_special_tokens=True))\n",
119
+ "```\n",
120
+ "\n",
121
+ "---\n",
122
+ "\n",
123
+ "### **6. 总结**\n",
124
+ "\n",
125
+ "#### **优势**\n",
126
+ "- **高性能**:在多个基准任务上表现出色。\n",
127
+ "- **高效训练**:小参数模型能与大模型媲美。\n",
128
+ "- **开放性**:LLaMA 2 提供了较为开放的商用许可。\n",
129
+ "\n",
130
+ "#### **局限**\n",
131
+ "- 模型需要高质量数据和强大算力训练,对推理设备也有一定要求。\n",
132
+ "\n",
133
+ "LLaMA 系列以其高效和开放的特点,为大模型研究和应用带来了强大动力,是当前大语言模型生态的重要组成部分。"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "markdown",
138
+ "id": "7fb0d648-f891-47b9-a644-af5263fa9718",
139
+ "metadata": {},
140
+ "source": [
141
+ "---\n",
142
+ "---"
143
+ ]
144
+ },
145
+ {
146
+ "cell_type": "markdown",
147
+ "id": "8b3c9ebb-213b-4dc4-a712-5a819fea3197",
148
+ "metadata": {},
149
+ "source": [
150
+ "**大模型的持续预训练**(Continual Pretraining of Large Models)是指在基础预训练模型(如 GPT、BERT 等)的基础上,通过引入新的数据或特定领域的数据继续进行预训练的过程。这一过程旨在让模型在特定场景或任务中表现更好,同时保留其通用能力。\n",
151
+ "\n",
152
+ "---\n",
153
+ "\n",
154
+ "### **1. 持续预训练的概念**\n",
155
+ "\n",
156
+ "持续预训练是一种在通用大模型的预训练基础上,进一步优化和适配模型的方法,主要包括以下两种场景:\n",
157
+ "1. **领域适配**:\n",
158
+ " - 将预训练模型在特定领域的数据上继续训练,使其对该领域的语料理解更深刻,例如法律、医学、金融等领域。\n",
159
+ "2. **性能优化**:\n",
160
+ " - 通过引入更多的通用数据或多样化的数据类型,扩展模型的通用能力,提高性能。\n",
161
+ "\n",
162
+ "---\n",
163
+ "\n",
164
+ "### **2. 持续预训练的目标**\n",
165
+ "\n",
166
+ "1. **提升领域性能**:\n",
167
+ " - 在特定领域任务上,模型能够更好地理解特定领域的语言模式和知识。\n",
168
+ " \n",
169
+ "2. **增强模型鲁棒性**:\n",
170
+ " - 通过引入新的数据或增强数据多样性,使模型对未见数据表现更稳定。\n",
171
+ "\n",
172
+ "3. **优化资源利用**:\n",
173
+ " - 通过复用已有的大模型权重,只需训练少量额外步骤,避免从零开始重新训练模型。\n",
174
+ "\n",
175
+ "---\n",
176
+ "\n",
177
+ "### **3. 持续预训练的步骤**\n",
178
+ "\n",
179
+ "#### **(1)数据准备**\n",
180
+ "- **领域数据**:针对特定领域(如医学、法律、科技)收集高质量语料。\n",
181
+ "- **新语料整合**:补充模型未见过的多样化语料。\n",
182
+ "- **数据清洗**:确保数据无噪声、语言风格一致(一个简单的清洗示意见下方代码)。\n",
183
+ "\n",
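+ "下面给出一个极简的数据清洗示意(仅演示去重与长度过滤;文件名沿用后文示例中的 domain_corpus.txt,长度阈值 20 为假设值,实际清洗通常还需去噪、语言检测等):\n",
+ "\n",
+ "```python\n",
+ "# 对原始语料做简单的去重与长度过滤\n",
+ "seen, cleaned = set(), []\n",
+ "with open(\"domain_corpus.txt\", encoding=\"utf-8\") as f:\n",
+ "    for line in f:\n",
+ "        line = line.strip()\n",
+ "        if len(line) >= 20 and line not in seen:  # 阈值 20 为假设值\n",
+ "            seen.add(line)\n",
+ "            cleaned.append(line)\n",
+ "\n",
+ "with open(\"domain_corpus_clean.txt\", \"w\", encoding=\"utf-8\") as f:\n",
+ "    f.write(\"\\n\".join(cleaned))\n",
+ "```\n",
+ "\n",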
184
+ "#### **(2)模型初始化**\n",
185
+ "- 使用现有的预训练模型作为初始权重,例如 Hugging Face 提供的 GPT-2 或 BERT 模型。\n",
186
+ "\n",
187
+ "#### **(3)训练设置**\n",
188
+ "- **超参数调整**:\n",
189
+ " - 通常使用较小的学习率(例如 `1e-5` 或 `2e-5`)以避免破坏已有的知识。\n",
190
+ "- **训练策略**:\n",
191
+ " - 冻结部分参数(如嵌入层或前几层)以保留通用能力,仅调整高层或新加入的部分(冻结方式见下方示意)。\n",
192
+ "\n",
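+ "一个冻结低层参数的最小示意(以 GPT-2 为例;冻结前 6 层仅为假设,需按具体模型的参数命名调整):\n",
+ "\n",
+ "```python\n",
+ "from transformers import AutoModelForCausalLM\n",
+ "\n",
+ "model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n",
+ "\n",
+ "# 冻结词嵌入、位置嵌入与前 6 个 Transformer 层,仅训练高层\n",
+ "for name, param in model.named_parameters():\n",
+ "    if name.startswith((\"transformer.wte\", \"transformer.wpe\")):\n",
+ "        param.requires_grad = False\n",
+ "    elif name.startswith(\"transformer.h.\") and int(name.split(\".\")[2]) < 6:\n",
+ "        param.requires_grad = False\n",
+ "\n",
+ "# 查看可训练参数占比\n",
+ "trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
+ "total = sum(p.numel() for p in model.parameters())\n",
+ "print(f\"trainable: {trainable}/{total} = {trainable/total:.2%}\")\n",
+ "```\n",
+ "\n",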
193
+ "#### **(4)评估和验证**\n",
194
+ "- 使用领域任务的数据集对模型进行评估,验证其在目标任务中的改进效果。\n",
195
+ "\n",
196
+ "---\n",
197
+ "\n",
198
+ "### **4. 持续预训练的常见方法**\n",
199
+ "\n",
200
+ "#### **(1)全量持续预训练**\n",
201
+ "- 对整个模型的参数进行调整。\n",
202
+ "- **优点**:适合较大规模的新数据训练,能显著提升领域性能。\n",
203
+ "- **缺点**:计算资源需求大,可能导致模型过拟合。\n",
204
+ "\n",
205
+ "#### **(2)冻结部分参数**\n",
206
+ "- 冻结低层参数,仅微调高层。\n",
207
+ "- **优点**:保留通用知识,减少计算开销。\n",
208
+ "- **缺点**:对领域特定知识的适配可能不足。\n",
209
+ "\n",
210
+ "#### **(3)参数高效微调(PEFT)**\n",
211
+ "- 使用 PEFT 方法(如 LoRA、Adapter)进行预训练:\n",
212
+ " - **LoRA**:通过低秩矩阵分解,微调部分关键模块。\n",
213
+ " - **Adapter**:在 Transformer 层中插入小型适配模块。\n",
214
+ "- **优点**:显著减少需要更新的参数量(LoRA 的基本用法见下方示意)。\n",
215
+ "\n",
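+ "以下是用 peft 库配置 LoRA 的最小示意(假设已安装 peft;target_modules 需按具体模型的模块名调整):\n",
+ "\n",
+ "```python\n",
+ "from transformers import AutoModelForCausalLM\n",
+ "from peft import LoraConfig, TaskType, get_peft_model\n",
+ "\n",
+ "model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n",
+ "\n",
+ "lora_config = LoraConfig(\n",
+ "    task_type=TaskType.CAUSAL_LM,\n",
+ "    r=8,                # 低秩分解的秩\n",
+ "    lora_alpha=32,      # 缩放系数\n",
+ "    lora_dropout=0.05,\n",
+ "    target_modules=[\"c_attn\"],  # GPT-2 的注意力投影;LLaMA 通常用 q_proj、v_proj\n",
+ ")\n",
+ "\n",
+ "model = get_peft_model(model, lora_config)\n",
+ "model.print_trainable_parameters()  # 仅更新极小比例的参数\n",
+ "```\n",
+ "\n",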
216
+ "---\n",
217
+ "\n",
218
+ "### **5. 持续预训练的典型应用**\n",
219
+ "\n",
220
+ "1. **领域适配**\n",
221
+ " - **医学**:将预训练模型在 PubMed 或生物医学数据集上进行持续预训练。\n",
222
+ " - **法律**:使用法律文档进一步训练基础模型。\n",
223
+ " - **金融**:通过金融新闻、报告语料提升模型在金融领域的表现。\n",
224
+ "\n",
225
+ "2. **多语言扩展**\n",
226
+ " - 引入多语言语料,扩展模型的多语言能力。\n",
227
+ "\n",
228
+ "3. **数据更新**\n",
229
+ " - 持续加入新数据(如时事新闻)以适配最新语言模式。\n",
230
+ "\n",
231
+ "4. **特殊任务优化**\n",
232
+ " - 针对特定任务(如代码生成、对话)引入专用数据进行训练。\n",
233
+ "\n",
234
+ "---\n",
235
+ "\n",
236
+ "### **6. 实现持续预训练的代码示例**\n",
237
+ "\n",
238
+ "以下示例基于 Hugging Face 实现 GPT-2 的持续预训练:\n",
239
+ "\n",
240
+ "```python\n",
241
+ "from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments\n",
242
+ "from datasets import load_dataset\n",
243
+ "\n",
244
+ "# 1. 加载预训练模型和分词器\n",
245
+ "model_name = \"gpt2\"\n",
246
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
247
+ "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
248
+ "\n",
249
+ "# 2. 加载新语料数据\n",
250
+ "dataset = load_dataset(\"text\", data_files={\"train\": \"domain_corpus.txt\"})\n",
251
+ "\n",
252
+ "# 3. 数据预处理\n",
253
+ "def tokenize_function(examples):\n",
254
+ " return tokenizer(examples[\"text\"], truncation=True, max_length=1024, padding=\"max_length\")\n",
255
+ "\n",
256
+ "tokenized_dataset = dataset.map(tokenize_function, batched=True)\n",
257
+ "\n",
258
+ "# 4. 设置训练参数\n",
259
+ "training_args = TrainingArguments(\n",
260
+ " output_dir=\"./gpt2_domain_adapted\",\n",
261
+ " overwrite_output_dir=True,\n",
262
+ " per_device_train_batch_size=4,\n",
263
+ " num_train_epochs=3,\n",
264
+ " learning_rate=5e-5,\n",
265
+ " save_steps=500,\n",
266
+ " save_total_limit=2,\n",
267
+ " logging_dir=\"./logs\",\n",
268
+ " evaluation_strategy=\"no\", # 评估策略可以根据需要调整\n",
269
+ " fp16=True, # 混合精度训练\n",
270
+ ")\n",
271
+ "\n",
272
+ "# 5. 定义 Trainer 并启动训练\n",
273
+ "trainer = Trainer(\n",
274
+ " model=model,\n",
275
+ " args=training_args,\n",
276
+ " train_dataset=tokenized_dataset[\"train\"],\n",
277
+ " tokenizer=tokenizer,\n",
278
+ ")\n",
279
+ "\n",
280
+ "trainer.train()\n",
281
+ "\n",
282
+ "# 6. 保存模型\n",
283
+ "model.save_pretrained(\"./gpt2_domain_adapted\")\n",
284
+ "tokenizer.save_pretrained(\"./gpt2_domain_adapted\")\n",
285
+ "```\n",
286
+ "\n",
287
+ "---\n",
288
+ "\n",
289
+ "### **7. 持续预训练的挑战**\n",
290
+ "\n",
291
+ "1. **灾难性遗忘**:\n",
292
+ " - 持续预训练可能导致模型丧失之前学到的知识。\n",
293
+ " - **解决方法**:使用少量原始数据进行联合训练(数据混合方式见本列表后的示意)。\n",
294
+ "\n",
295
+ "2. **计算资源需求**:\n",
296
+ " - 需要大量显存和算力,特别是对于大规模模型和数据。\n",
297
+ "\n",
298
+ "3. **数据质量和多样性**:\n",
299
+ " - 新引入的数据可能包含噪声,影响模型性能。\n",
300
+ "\n",
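+ "针对灾难性遗忘,下面给出一个混合领域数据与通用数据的最小示意(基于 datasets 库的 interleave_datasets;文件名对应本章 train_data 目录,8:2 的比例为假设值):\n",
+ "\n",
+ "```python\n",
+ "from datasets import load_dataset, interleave_datasets\n",
+ "\n",
+ "domain = load_dataset(\"text\", data_files={\"train\": \"train_data/dna_1g.txt\"})[\"train\"]\n",
+ "general = load_dataset(\"text\", data_files={\"train\": \"train_data/english_500m.txt\"})[\"train\"]\n",
+ "\n",
+ "# 按 8:2 交错采样,保留少量通用语料以缓解遗忘\n",
+ "mixed = interleave_datasets([domain, general], probabilities=[0.8, 0.2], seed=42)\n",
+ "```\n",
+ "\n",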
301
+ "---\n",
302
+ "\n",
303
+ "### **8. 持续预训练的优势**\n",
304
+ "\n",
305
+ "- 提高特定领域或任务的性能。\n",
306
+ "- 更高效地利用已有模型权重,避免从头训练。\n",
307
+ "- 保留原始模型的通用能力,同时增强领域适应性。\n",
308
+ "\n",
309
+ "---\n",
310
+ "\n",
311
+ "### **总结**\n",
312
+ "\n",
313
+ "持续预训练是适配领域任务和提升模型性能的重要方法,通过引入新数据或优化模型训练策略,可以让大模型在特定场景中表现更优。配合参数高效微调方法(如 LoRA),还可显著降低计算开销,提升训练效率。这种技术在学术研究、工业应用和前沿领域(如法律、医学等)中均具有广泛价值。"
314
+ ]
315
+ },
316
+ {
317
+ "cell_type": "code",
318
+ "execution_count": null,
319
+ "id": "ca41ad33-18fb-44da-8f79-0380b5c9dcaa",
320
+ "metadata": {},
321
+ "outputs": [],
322
+ "source": []
323
+ },
324
+ {
325
+ "cell_type": "markdown",
326
+ "id": "3038550c-cc92-45c9-8bb4-46c58688bfc5",
327
+ "metadata": {},
328
+ "source": [
329
+ "## 本节任务\n",
330
+ "本节任务是基于llama,训练一个能够处理dna和protein(蛋白质)数据的基础预训练大模型,训练数据为第一章中的预训练数据,包括英文数据。"
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "markdown",
335
+ "id": "b1bd33b8-2e05-4b59-9d8f-c48de194cfd6",
336
+ "metadata": {},
337
+ "source": [
338
+ "## 代码运行\n",
339
+ "\n",
340
+ "```\n",
341
+ "# 复制第一章训练数据,包括dna,protein,还有英文数据,添加英文数据是为了避免遗忘问题\n",
342
+ "mkdir train_data\n",
343
+ "cp ../01-data_env/data/*.txt train_data/\n",
344
+ "\n",
345
+ "#持续预训练\n",
346
+ "./run_pt.sh\n",
347
+ "\n",
348
+ "#合并模型\n",
349
+ "./merge_pt_model.sh\n",
350
+ "\n",
351
+ "```"
352
+ ]
353
+ },
354
+ {
355
+ "cell_type": "markdown",
356
+ "id": "4960a36c-7529-4db8-b91d-df91245f79d9",
357
+ "metadata": {},
358
+ "source": [
359
+ "## 模型验证"
360
+ ]
361
+ },
362
+ {
363
+ "cell_type": "code",
364
+ "execution_count": null,
365
+ "id": "69b3e97f-a801-4264-a651-a854bcfba9c6",
366
+ "metadata": {},
367
+ "outputs": [],
368
+ "source": [
369
+ "from transformers import AutoTokenizer, AutoConfig,AutoModel\n",
370
+ "from transformers import DataCollatorForLanguageModeling\n",
371
+ "from transformers import Trainer, TrainingArguments\n",
372
+ "from transformers import AutoConfig, AutoModelForCausalLM,LlamaForCausalLM,LlamaTokenizer\n",
373
+ "from tokenizers import Tokenizer\n",
374
+ "from datasets import load_dataset"
375
+ ]
376
+ },
377
+ {
378
+ "cell_type": "code",
379
+ "execution_count": null,
380
+ "id": "339435d9-9379-4b30-ae8b-50feee1ba714",
381
+ "metadata": {},
382
+ "outputs": [],
383
+ "source": [
384
+ "tokenizer = LlamaTokenizer.from_pretrained(\"dnahlm-merge-hf\")\n",
385
+ "tokenizer.pad_token = tokenizer.eos_token\n",
386
+ "tokenizer"
387
+ ]
388
+ },
389
+ {
390
+ "cell_type": "code",
391
+ "execution_count": null,
392
+ "id": "d0f154bb-b1ab-4611-a14c-9b403043fd96",
393
+ "metadata": {},
394
+ "outputs": [],
395
+ "source": [
396
+ "model = LlamaForCausalLM.from_pretrained(\"dnahlm-merge-hf\") #continue pretrain\n",
397
+ "model"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "execution_count": null,
403
+ "id": "792a9f78-1828-4695-9f6e-479a704ea7e8",
404
+ "metadata": {},
405
+ "outputs": [],
406
+ "source": [
407
+ "from transformers import AutoConfig\n",
408
+ "# 加载配置\n",
409
+ "config = AutoConfig.from_pretrained('dnahlm-merge-hf')\n",
410
+ "config"
411
+ ]
412
+ },
413
+ {
414
+ "cell_type": "code",
415
+ "execution_count": null,
416
+ "id": "49021c65-54bb-4a97-a96d-b030cc3dcd13",
417
+ "metadata": {},
418
+ "outputs": [],
419
+ "source": [
420
+ "text='''GCTGACTCTGCCAGGATGGAATGAAATTAGGTTGTTTTAATTATAATGTAAAGTCAGTTCTAGTCAGACATAGTCACATAGGCAAGTAAGGGAACCTAAAATTGCTTGGAAT,\n",
421
+ "KCGFVGPMVHLKVHLEADVASSCRSAVIYLTSEEPFEGVLGLRLKEGIAITGCWPRWPDEMDERSAVWRVEPYTRHFGRVLYSFGV,\n",
422
+ "The primary use of LLaMA is research on large language models, including'''\n",
423
+ "print(\"Test text:\\n\",text)\n",
424
+ "print(f\"Tokenized by DNA-LLaMA tokenizer:{tokenizer.tokenize(text)}\")"
425
+ ]
426
+ },
427
+ {
428
+ "cell_type": "code",
429
+ "execution_count": null,
430
+ "id": "ebf869c8-866d-4770-8f64-79d671f88663",
431
+ "metadata": {},
432
+ "outputs": [],
433
+ "source": [
434
+ "import torch\n",
435
+ "from transformers import pipeline\n",
436
+ "\n",
437
+ "model_id = \"dnahlm-merge-hf\"\n",
438
+ "\n",
439
+ "pipe = pipeline(\n",
440
+ " \"text-generation\", \n",
441
+ " model=model_id, \n",
442
+ " #torch_dtype=torch.bfloat16, \n",
443
+ " device_map=\"auto\",\n",
444
+ ")\n",
445
+ "\n",
446
+ "pipe(\"The key to life is\")"
447
+ ]
448
+ },
449
+ {
450
+ "cell_type": "code",
451
+ "execution_count": null,
452
+ "id": "40a22c70-f1c4-4cd5-a118-2f5db40790e6",
453
+ "metadata": {},
454
+ "outputs": [],
455
+ "source": [
456
+ "pipe(\"GGAATGAAATTAGGTTGTTTTAATTATAATGTAAAGTCAGTTCT\")"
457
+ ]
458
+ },
459
+ {
460
+ "cell_type": "code",
461
+ "execution_count": null,
462
+ "id": "aec95d0a-4269-4540-bf14-4ce157b9a194",
463
+ "metadata": {},
464
+ "outputs": [],
465
+ "source": [
466
+ "pipe(\"KCGFVGPMVHLKVHLEADVASSCRSAVIYLTSEEPFEGVLGLRLKEGIAITGCWPRWPDEMDERSAVWRVEPYTRHFGRVLYSFGV\")"
467
+ ]
468
+ }
469
+ ],
470
+ "metadata": {
471
+ "kernelspec": {
472
+ "display_name": "Python 3 (ipykernel)",
473
+ "language": "python",
474
+ "name": "python3"
475
+ },
476
+ "language_info": {
477
+ "codemirror_mode": {
478
+ "name": "ipython",
479
+ "version": 3
480
+ },
481
+ "file_extension": ".py",
482
+ "mimetype": "text/x-python",
483
+ "name": "python",
484
+ "nbconvert_exporter": "python",
485
+ "pygments_lexer": "ipython3",
486
+ "version": "3.12.3"
487
+ }
488
+ },
489
+ "nbformat": 4,
490
+ "nbformat_minor": 5
491
+ }
04-gene-sft/.ipynb_checkpoints/7-llama-instruction-ft-checkpoint.ipynb ADDED
@@ -0,0 +1,624 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "963e9ae0-ac68-44be-8c7d-fb9842784362",
6
+ "metadata": {},
7
+ "source": [
8
+ "# 4.7 基于llama的基因大模型指令微调"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "c844103d-4e27-41b9-9bf1-c6a577846ab6",
14
+ "metadata": {},
15
+ "source": [
16
+ "### **大模型的指令微调(Instruction Fine-Tuning)**\n",
17
+ "\n",
18
+ "指令微调是指通过对大语言模型(如 GPT、T5、LLaMA 等)进行微调,使其能够更好地理解和执行人类以指令形式表达的任务。这种技术是大模型适配实际应用和增强用户交互能力的关键手段。\n",
19
+ "\n",
20
+ "---\n",
21
+ "\n",
22
+ "### **1. 指令微调的核心概念**\n",
23
+ "\n",
24
+ "指令微调的目标是通过在包含指令的专用数据集上进行微调,让模型能够:\n",
25
+ "1. 理解用户的任务需求(以自然语言表达的指令形式)。\n",
26
+ "2. 根据指令内容生成符合预期的高质量响应。\n",
27
+ "3. 适应多任务场景,减少特定任务的单独训练需求。\n",
28
+ "\n",
29
+ "---\n",
30
+ "\n",
31
+ "### **2. 指令微调的关键特点**\n",
32
+ "\n",
33
+ "1. **多任务统一**:\n",
34
+ " - 不需要针对每个任务单独微调,而是通过指令微调使模型能适应多种任务。\n",
35
+ " \n",
36
+ "2. **自然语言交互**:\n",
37
+ " - 用户可以用自然语言指令与模型交互,无需提供特定格式的输入。\n",
38
+ "\n",
39
+ "3. **泛化能力**:\n",
40
+ " - 微调后的模型能够对未见过的任务产生合理的推断和响应。\n",
41
+ "\n",
42
+ "---\n",
43
+ "\n",
44
+ "### **3. 数据集的构建与使用**\n",
45
+ "\n",
46
+ "#### **(1)指令微调数据集的特点**\n",
47
+ "- 数据通常包含以下三部分:\n",
48
+ " 1. **指令(Instruction)**:任务描述或问题,例如“将以下文本翻译为法语”。\n",
49
+ " 2. **输入(Input)**:任务相关的上下文或数据,可以为空。\n",
50
+ " 3. **输出(Output)**:模型期望生成的结果。\n",
51
+ "\n",
52
+ "#### **(2)常用指令微调数据集**\n",
53
+ "- **FLAN**:包含多个 NLP 任务的指令数据集,用于 T5 等模型的微调。\n",
54
+ "- **OpenAI 提供的指令数据**:如 GPT 系列的 ChatGPT 调优数据集。\n",
55
+ "- **InstructGPT 数据**:通过人类标注的多任务指令数据,用于模型优化。\n",
56
+ "- **Self-Instruct**:通过模型自生成指令和回答,进一步扩展训练数据。\n",
57
+ "\n",
58
+ "#### **(3)构建自己的数据集**\n",
59
+ "- 如果需要特定领域的指令微调,可以自行构建数据集:\n",
60
+ " - 收集任务需求和示例。\n",
61
+ " - 设计多样化的指令。\n",
62
+ " - 使用专家标注或模型辅助生成高质量答案。\n",
63
+ "\n",
64
+ "---\n",
65
+ "\n",
66
+ "### **4. 微调的步骤**\n",
67
+ "\n",
68
+ "#### **(1)加载基础模型**\n",
69
+ "从 Hugging Face 或其他框架加载预训练的大语言模型,例如 GPT-2、T5、LLaMA。\n",
70
+ "\n",
71
+ "#### **(2)准备数据集**\n",
72
+ "将指令微调数据集格式化为:\n",
73
+ "```python\n",
74
+ "{\n",
75
+ " \"instruction\": \"Translate the following text to French\",\n",
76
+ " \"input\": \"Hello, how are you?\",\n",
77
+ " \"output\": \"Bonjour, comment ça va?\"\n",
78
+ "}\n",
79
+ "```\n",
80
+ "\n",
81
+ "#### **(3)定义微调方法**\n",
82
+ "使用 `Trainer` 或分布式框架(如 DeepSpeed、Accelerate)进行微调。\n",
83
+ "\n",
84
+ "---\n",
85
+ "\n",
86
+ "### **5. 示例代码:指令微调实现**\n",
87
+ "\n",
88
+ "以下是基于 Hugging Face 的指令微调代码示例:\n",
89
+ "\n",
90
+ "```python\n",
91
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer\n",
92
+ "from datasets import load_dataset\n",
93
+ "\n",
94
+ "# 1. 加载预训练模型和分词器\n",
95
+ "model_name = \"gpt2\"\n",
96
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
97
+ "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
98
+ "\n",
99
+ "# 2. 加载指令微调数据集\n",
100
+ "# 数据格式应包含 instruction, input, output 字段\n",
101
+ "dataset = load_dataset(\"path/to/instruction_dataset\")\n",
102
+ "\n",
103
+ "# 3. 数据预处理\n",
104
+ "def preprocess_function(example):\n",
105
+ " # 将指令和输入拼接成完整的提示\n",
106
+ " prompt = example[\"instruction\"]\n",
107
+ " if example[\"input\"]:\n",
108
+ " prompt += f\"\\n{example['input']}\"\n",
109
+ " labels = example[\"output\"]\n",
110
+ " tokenized = tokenizer(prompt, truncation=True, max_length=512, padding=\"max_length\")\n",
111
+ " with tokenizer.as_target_tokenizer():\n",
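+ "    # 注意:as_target_tokenizer 是 seq2seq 风格的用法(新版 transformers 已弃用,可改用 text_target 参数);\n",
+ "    # 对因果语言模型,更常见的做法是将 prompt 与 output 拼接后整体作为 labels,此处仅为演示\n",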
112
+ " tokenized_labels = tokenizer(labels, truncation=True, max_length=512, padding=\"max_length\")\n",
113
+ " tokenized[\"labels\"] = tokenized_labels[\"input_ids\"]\n",
114
+ " return tokenized\n",
115
+ "\n",
116
+ "tokenized_datasets = dataset.map(preprocess_function, batched=True)\n",
117
+ "\n",
118
+ "# 4. 设置训练参数\n",
119
+ "training_args = TrainingArguments(\n",
120
+ " output_dir=\"./instruction_finetuned_model\",\n",
121
+ " per_device_train_batch_size=4,\n",
122
+ " num_train_epochs=3,\n",
123
+ " evaluation_strategy=\"epoch\",\n",
124
+ " save_strategy=\"epoch\",\n",
125
+ " learning_rate=5e-5,\n",
126
+ " weight_decay=0.01,\n",
127
+ " logging_dir=\"./logs\",\n",
128
+ " fp16=True,\n",
129
+ ")\n",
130
+ "\n",
131
+ "# 5. 定义 Trainer\n",
132
+ "trainer = Trainer(\n",
133
+ " model=model,\n",
134
+ " args=training_args,\n",
135
+ " train_dataset=tokenized_datasets[\"train\"],\n",
136
+ " eval_dataset=tokenized_datasets[\"test\"],\n",
137
+ " tokenizer=tokenizer,\n",
138
+ ")\n",
139
+ "\n",
140
+ "# 6. 开始训练\n",
141
+ "trainer.train()\n",
142
+ "\n",
143
+ "# 7. 保存模型\n",
144
+ "model.save_pretrained(\"./instruction_finetuned_model\")\n",
145
+ "tokenizer.save_pretrained(\"./instruction_finetuned_model\")\n",
146
+ "```\n",
147
+ "\n",
148
+ "---\n",
149
+ "\n",
150
+ "### **6. 指令微调的挑战**\n",
151
+ "\n",
152
+ "1. **数据质量**:\n",
153
+ " - 低质量或噪声数据可能导致模型生成结果不符合指令。\n",
154
+ "\n",
155
+ "2. **指令覆盖范围**:\n",
156
+ " - 数据集指令种类不足会限制模型的泛化能力。\n",
157
+ "\n",
158
+ "3. **计算资源需求**:\n",
159
+ " - 大模型的微调需要高性能 GPU 和大容量存储。\n",
160
+ "\n",
161
+ "4. **灾难性遗忘**:\n",
162
+ " - 微调过程中可能导致模型丧失部分原始能力。\n",
163
+ "\n",
164
+ "---\n",
165
+ "\n",
166
+ "### **7. 指令微调的应用场景**\n",
167
+ "\n",
168
+ "1. **多任务问答**:\n",
169
+ " - 适配多任务场景,支持翻译、总结、推理等功能。\n",
170
+ "\n",
171
+ "2. **特定领域优化**:\n",
172
+ " - 在法律、医疗等特定领域的任务指令上进行微调。\n",
173
+ "\n",
174
+ "3. **用户交互优化**:\n",
175
+ " - 提升模型对自然语言指令的理解和响应能力。\n",
176
+ "\n",
177
+ "4. **开放式对话生成**:\n",
178
+ " - 优化模型在对话场景下的表现,例如 ChatGPT 的微调。\n",
179
+ "\n",
180
+ "---\n",
181
+ "\n",
182
+ "### **总结**\n",
183
+ "\n",
184
+ "指令微调通过在特定格式的数据集上进一步训练大模型,使其能够更好地理解和执行用户的自然语言指令。这种方法适合多任务场景,并能提升模型的交互能力和领域适应性。借助高质量的指令数据集和高效的微调技术,大模型在实际应用中的表现可以得到显著提升。"
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "markdown",
189
+ "id": "7be8b814-42f6-4fb6-bf4b-ae23292030f6",
190
+ "metadata": {},
191
+ "source": []
192
+ },
193
+ {
194
+ "cell_type": "markdown",
195
+ "id": "f9bed0ae-337d-49af-85f0-c8e6263d78db",
196
+ "metadata": {},
197
+ "source": [
198
+ "**大模型的持续预训练**和**指令微调**是两种针对大模型的后续优化策略,虽然它们的目标都是提升模型性能,但在应用场景、方法和效果等方面有明显区别。以下是它们的对比分析:\n",
199
+ "\n",
200
+ "---\n",
201
+ "\n",
202
+ "### **1. 概念与目标**\n",
203
+ "\n",
204
+ "| **特性** | **持续预训练** | **指令微调** |\n",
205
+ "|------------------------|-----------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|\n",
206
+ "| **定义** | 在通用预训练模型上,使用新的大规模语料(通用或领域特定数据)进行进一步预训练。 | 在包含指令任务的数据集上对大模型进行微调,以提升模型对人类指令的理解和执行能力。 |\n",
207
+ "| **目标** | 提升模型的通用能力或适应特定领域的语言理解与生成能力。 | 提高模型对多任务指令的泛化能力,让模型更好地理解和执行自然语言表达的具体任务。 |\n",
208
+ "| **典型应用** | 领域适配(医学、法律、金融)、性能优化、跨语言适配等。 | 多任务问答、开放式对话生成、翻译、推理等需要用户直接交互的场景。 |\n",
209
+ "\n",
210
+ "---\n",
211
+ "\n",
212
+ "### **2. 数据使用**\n",
213
+ "\n",
214
+ "| **特性** | **持续预训练** | **指令微调** |\n",
215
+ "|------------------------|-----------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|\n",
216
+ "| **数据类型** | 通用语料(如新闻、社交媒体文本)或领域特定语料(如 PubMed、法律文档、金融报告)。 | 任务指令数据集,包括指令(Instruction)、输入(Input)和输出(Output)。 |\n",
217
+ "| **数据构建** | 通常需要清洗和去重大规模语料数据,避免与原始预训练数据重叠。 | 通常由人工标注或模型生成的指令数据构成,例如 FLAN、InstructGPT 数据集。 |\n",
218
+ "| **多样性要求** | 数据应覆盖尽可能广的领域或目标领域的多种场景,以提升模型在这些场景的表现。 | 数据需要覆盖多种任务类型(如翻译、分类、摘要)和丰富的指令表达形式,以提高模型对多任务的适配能力。 |\n",
219
+ "\n",
220
+ "---\n",
221
+ "\n",
222
+ "### **3. 方法与技术**\n",
223
+ "\n",
224
+ "| **特性** | **持续预训练** | **指令微调** |\n",
225
+ "|------------------------|-----------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|\n",
226
+ "| **主要技术** | 继续使用自监督学习目标(如语言建模、掩码预测)进行训练。 | 使用监督学习,通常以任务输入和目标输出对为数据,通过微调适配特定任务需求。 |\n",
227
+ "| **模型调整** | - 可选择全量参数更新或冻结部分参数。<br>- 可结合参数高效微调技术(如 LoRA、Adapter)。 | - 通常使用监督训练方式,可能结合参数高效微调技术(如 LoRA)。 |\n",
228
+ "| **学习率** | 通常使用较小的学习率(如 `1e-5` 或更小),以防止破坏原始权重。 | 同样使用较小的学习率,但需要更加关注任务特定的标签对齐。 |\n",
229
+ "\n",
230
+ "---\n",
231
+ "\n",
232
+ "### **4. 模型能力与效果**\n",
233
+ "\n",
234
+ "| **特性** | **持续预训练** | **指令微调** |\n",
235
+ "|------------------------|-----------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|\n",
236
+ "| **提升的能力** | - 对领域特定语言模式和知识的适配性提升显著。<br>- 对未见过的通用场景生成能力增强(扩展模型知识广度)。 | - 显著提升模型对指令理解的能力,尤其是自然语言表达的任务需求。<br>- 对多任务和零样本任务的泛化能力有较大提升。 |\n",
237
+ "| **局限性** | - 对具体任务的直接适配能力较弱,可能需要额外的任务微调。<br>- 数据选择不当可能导致灾难性遗忘。 | - 依赖高质量的指令数据集,数据质量不高会导致模型生成结果不稳定。<br>- 对通用能力的提升有限。 |\n",
238
+ "\n",
239
+ "---\n",
240
+ "\n",
241
+ "### **5. 应用场景与示例**\n",
242
+ "\n",
243
+ "| **特性** | **持续预训练** | **指令微调** |\n",
244
+ "|------------------------|-----------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|\n",
245
+ "| **典型应用场景** | - 医学文献总结(通过 PubMed 语料持续预训练)。<br>- 法律条文分析(通过法律文档进一步训练)。<br>- 增强多语言生成能力(跨语言语料)。 | - ChatGPT 的多任务对话生成。<br>- 翻译、摘要、问答等用户交互任务的泛化处理。 |\n",
246
+ "| **实际示例** | - BioBERT:在 BERT 基础上使用生物医学语料持续预训练的模型。<br>- FinBERT:针对金融领域持续预训练的语言模型。 | - InstructGPT:在 GPT-3 基础上进行指令微调,用于多任务用户交互。<br>- FLAN-T5:通过 FLAN 数据集进行指令微调。 |\n",
247
+ "\n",
248
+ "---\n",
249
+ "\n",
250
+ "### **6. 持续预训练与指令微调的结合**\n",
251
+ "\n",
252
+ "持续预训练和指令微调可以结合使用,形成一个从领域适配到任务适配的完整流程:\n",
253
+ "1. **持续预训练**:\n",
254
+ " - 先在领域特定数据(如医学、法律、金融语料)上进行持续预训练,获取领域知识。\n",
255
+ "2. **指令微调**:\n",
256
+ " - 再利用多任务指令数据集对模型微调,使其能够高效执行领域内的多样化任务。\n",
257
+ "\n",
258
+ "这种结合方式特别适用于需要领域知识和任务适配的场景,例如医学问答系统或金融文本分析;对应本章脚本的流程见下方示意。\n",
259
+ "\n",
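+ "对应到本章的代码,这一两阶段流程大致如下(四个脚本均在本目录中):\n",
+ "\n",
+ "```\n",
+ "./run_pt.sh           # 阶段一:在 dna/protein/英文语料上持续预训练(LoRA)\n",
+ "./merge_pt_model.sh   # 合并持续预训练得到的 LoRA 权重\n",
+ "./run_sft.sh          # 阶段二:在指令数据上微调\n",
+ "./merge_sft_model.sh  # 合并指令微调得到的 LoRA 权重\n",
+ "```\n",
+ "\n",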
260
+ "---\n",
261
+ "\n",
262
+ "### **总结**\n",
263
+ "\n",
264
+ "| **维度** | **持续预训练** | **指令微调** |\n",
265
+ "|------------------------|-------------------------------------|----------------------------------|\n",
266
+ "| **目标** | 增强通用能力或适配特定领域。 | 提升对任务指令的理解和执行能力。 |\n",
267
+ "| **数据集** | 通用或领域语料。 | 指令数据集,包含输入和输出对。 |\n",
268
+ "| **方法** | 自监督学习,扩展语言建模能力。 | 监督学习,强化任务适配能力。 |\n",
269
+ "| **适用场景** | 领域特定任务(如医学、法律)。 | 多任务交互(如问答、对话生成)。 |\n",
270
+ "| **局限性** | 对具体任务适配较弱。 | 通用能力提升有限,依赖数据质量。 |\n",
271
+ "\n",
272
+ "两者各有侧重,且在许多场景下可以结合使用,形成一个强大的任务和领域适配框架。"
273
+ ]
274
+ },
275
+ {
276
+ "cell_type": "markdown",
277
+ "id": "f97a705a-b946-4dc1-a173-a9df033d6f2b",
278
+ "metadata": {},
279
+ "source": [
280
+ "## 本节任务\n",
281
+ "本节任务是基于上一节持续预训练得到的llama生物大模型,对一些生物学任务进行微调,包含多个不同类型的分类问题和多序列交换问题。具体可见sft_data下的数据。"
282
+ ]
283
+ },
284
+ {
285
+ "cell_type": "markdown",
286
+ "id": "9782db62-95bd-40a6-9759-966b9a0b362e",
287
+ "metadata": {},
288
+ "source": [
289
+ "## 代码运行\n",
290
+ "\n",
291
+ "```\n",
292
+ "\n",
293
+ "#微调\n",
294
+ "./run_sft.sh\n",
295
+ "\n",
296
+ "#合并模型\n",
297
+ "./merge_sft_model.sh\n",
298
+ "\n",
299
+ "```"
300
+ ]
301
+ },
302
+ {
303
+ "cell_type": "markdown",
304
+ "id": "182b82c4-d484-4c15-a600-03c3b51367ec",
305
+ "metadata": {},
306
+ "source": [
307
+ "## 模型验证"
308
+ ]
309
+ },
310
+ {
311
+ "cell_type": "code",
312
+ "execution_count": 1,
313
+ "id": "5aa3d240-44e1-4811-8f61-d6ff2500a798",
314
+ "metadata": {},
315
+ "outputs": [],
316
+ "source": [
317
+ "import subprocess\n",
318
+ "import os\n",
319
+ "# 设置环境变量, autodl一般区域\n",
320
+ "result = subprocess.run('bash -c \"source /etc/network_turbo && env | grep proxy\"', shell=True, capture_output=True, text=True)\n",
321
+ "output = result.stdout\n",
322
+ "for line in output.splitlines():\n",
323
+ " if '=' in line:\n",
324
+ " var, value = line.split('=', 1)\n",
325
+ " os.environ[var] = value"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "markdown",
330
+ "id": "17bdb69d-3f0f-465e-bd60-2047a088e264",
331
+ "metadata": {},
332
+ "source": [
333
+ "如果您不确定模型中有哪些模块可以微调,可以打印模型结构:"
334
+ ]
335
+ },
336
+ {
337
+ "cell_type": "code",
338
+ "execution_count": null,
339
+ "id": "054a2956-9045-4ad5-a878-1bfc84ad4ed8",
340
+ "metadata": {},
341
+ "outputs": [],
342
+ "source": [
343
+ "from transformers import AutoTokenizer, AutoConfig,AutoModel\n",
344
+ "from transformers import DataCollatorForLanguageModeling\n",
345
+ "from transformers import Trainer, TrainingArguments\n",
346
+ "from transformers import AutoConfig, AutoModelForCausalLM,LlamaForCausalLM,LlamaTokenizer\n",
347
+ "from tokenizers import Tokenizer\n",
348
+ "from datasets import load_dataset"
349
+ ]
350
+ },
351
+ {
352
+ "cell_type": "code",
353
+ "execution_count": null,
354
+ "id": "63c8bf16-9576-41bc-b27c-c92ba4289cf4",
355
+ "metadata": {},
356
+ "outputs": [],
357
+ "source": [
358
+ "from datasets import load_dataset\n",
359
+ "dna_ft_dataset = load_dataset('json', data_files='val_data.json')\n",
360
+ "dna_ft_dataset"
361
+ ]
362
+ },
363
+ {
364
+ "cell_type": "code",
365
+ "execution_count": null,
366
+ "id": "95928da3-ca64-4a17-80f4-945da395702c",
367
+ "metadata": {},
368
+ "outputs": [],
369
+ "source": [
370
+ "data = dna_ft_dataset[\"train\"].train_test_split(train_size=0.1, seed=42)\n",
371
+ "data"
372
+ ]
373
+ },
374
+ {
375
+ "cell_type": "code",
376
+ "execution_count": null,
377
+ "id": "a3e65bcd-85ce-4261-8ba6-7665c4ec60e2",
378
+ "metadata": {},
379
+ "outputs": [],
380
+ "source": [
381
+ "tokenizer = LlamaTokenizer.from_pretrained(\"dnahlm-llama-7b-sft-v0\") #dnagpt/dnahlm-llama-7b-sft-v0\n",
382
+ "tokenizer.pad_token = tokenizer.eos_token"
383
+ ]
384
+ },
385
+ {
386
+ "cell_type": "code",
387
+ "execution_count": null,
388
+ "id": "3d3fe49b-f48f-42b2-bc97-028e443111e4",
389
+ "metadata": {},
390
+ "outputs": [],
391
+ "source": [
392
+ "model = LlamaForCausalLM.from_pretrained(\"dnahlm-llama-7b-sft-v0\") #continue pretrain\n",
393
+ "model"
394
+ ]
395
+ },
396
+ {
397
+ "cell_type": "code",
398
+ "execution_count": null,
399
+ "id": "c54df9fe-86c4-4963-b313-b438894bf9dd",
400
+ "metadata": {},
401
+ "outputs": [],
402
+ "source": [
403
+ "#构建提示词\n",
404
+ "def format_input(entry):\n",
405
+ " instruction_text = (\n",
406
+ " f\"Below is an instruction that describes a task. \"\n",
407
+ " f\"Write a response that appropriately completes the request.\"\n",
408
+ " f\"\\n\\n### Instruction:\\n{entry['instruction']}\"\n",
409
+ " )\n",
410
+ "\n",
411
+ " input_text = f\"\\n\\n### Input:\\n{entry['input']}\" if entry[\"input\"] else \"\"\n",
412
+ "\n",
413
+ " return instruction_text + input_text + \"\\n\\n### Response:\\n\"\n",
414
+ "\n",
415
+ "#构建提示词\n",
416
+ "def build_prompt(entry):\n",
417
+ "\n",
418
+ " input_data = format_input(entry)\n",
419
+ "\n",
420
+ " desired_response = entry['output']\n",
421
+ "\n",
422
+ " return input_data + desired_response\n"
423
+ ]
424
+ },
425
+ {
426
+ "cell_type": "code",
427
+ "execution_count": null,
428
+ "id": "ee540cfb-1f6e-4e02-a3bc-c814e43685cb",
429
+ "metadata": {},
430
+ "outputs": [],
431
+ "source": [
432
+ "example = data[\"test\"][0]\n",
433
+ "example"
434
+ ]
435
+ },
436
+ {
437
+ "cell_type": "code",
438
+ "execution_count": null,
439
+ "id": "7ee35528-7b3f-4e60-b88b-1bc3e950012b",
440
+ "metadata": {},
441
+ "outputs": [],
442
+ "source": [
443
+ "prompt = build_prompt(example)\n",
444
+ "print(prompt)"
445
+ ]
446
+ },
447
+ {
448
+ "cell_type": "code",
449
+ "execution_count": null,
450
+ "id": "8aa6f38f-3bcc-4566-8a66-a541db91e031",
451
+ "metadata": {},
452
+ "outputs": [],
453
+ "source": [
454
+ "tokenizer.tokenize(prompt)"
455
+ ]
456
+ },
457
+ {
458
+ "cell_type": "code",
459
+ "execution_count": null,
460
+ "id": "11875339-4901-4912-86e5-afe8c74921d9",
461
+ "metadata": {},
462
+ "outputs": [],
463
+ "source": [
464
+ "def inference(text, model, tokenizer, max_input_tokens=1000, max_output_tokens=1000):\n",
465
+ " # Tokenize\n",
466
+ " input_ids = tokenizer.encode(\n",
467
+ " text,\n",
468
+ " return_tensors=\"pt\",\n",
469
+ " truncation=True,\n",
470
+ " max_length=max_input_tokens\n",
471
+ " # return_attention_mask=True,\n",
472
+ " )\n",
473
+ "\n",
474
+ " # Generate\n",
475
+ " device = model.device\n",
476
+ " generated_tokens_with_prompt = model.generate(\n",
477
+ " input_ids=input_ids.to(device),\n",
478
+ " #max_length=max_output_tokens,\n",
479
+ " max_new_tokens=8,\n",
480
+ " temperature=0.01 # 控制生成的多样性\n",
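+ "        # 注意:未设置 do_sample=True 时 temperature 不会生效(transformers 会给出 UserWarning);\n",
+ "        # 另外参数 max_output_tokens 在此实现中并未实际使用\n",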
481
+ " )\n",
482
+ "\n",
483
+ " # Decode\n",
484
+ " generated_text_with_prompt = tokenizer.decode(generated_tokens_with_prompt[0], skip_special_tokens=True)\n",
485
+ " generated_text_answer = generated_text_with_prompt[len(text):]\n",
486
+ "\n",
487
+ "\n",
488
+ " return generated_text_answer\n",
489
+ "\n",
490
+ "# 如果需要进一步清理\n",
491
+ "def clean_generated_text(text):\n",
492
+ " # 去除 'Ġ' 符号并替换为空格('Ġ' 是 GPT-2 BPE 的空格标记;LLaMA 分词器实际使用 '▁')\n",
493
+ " text = text.replace('Ġ', ' ')\n",
494
+ " # 去除多余的空格\n",
495
+ " text = ' '.join(text.split())\n",
496
+ " return text"
497
+ ]
498
+ },
499
+ {
500
+ "cell_type": "code",
501
+ "execution_count": null,
502
+ "id": "1b02644a-8b24-45aa-b22d-0f7ce2270dd9",
503
+ "metadata": {},
504
+ "outputs": [],
505
+ "source": [
506
+ "input_text = format_input(data[\"test\"][0])\n",
507
+ "\n",
508
+ "print(\"input (test):\", input_text)\n",
509
+ "\n",
510
+ "print(\"real answer:\", data[\"test\"][0][\"output\"])\n",
511
+ "\n",
512
+ "print(\"--------------------------\\n\")\n",
513
+ "\n",
514
+ "print(\"model's answer: \\n\")\n",
515
+ "print(inference(input_text, model, tokenizer))"
516
+ ]
517
+ },
518
+ {
519
+ "cell_type": "code",
520
+ "execution_count": null,
521
+ "id": "e2df1569-7f70-46ee-b93f-cbd879e32e14",
522
+ "metadata": {},
523
+ "outputs": [],
524
+ "source": [
525
+ "test_data = data[\"test\"].shuffle(seed=199).select(range(100))\n",
526
+ "\n",
527
+ "data_list = []\n",
528
+ "\n",
529
+ "for entry in test_data:\n",
530
+ " input_text = format_input(entry)\n",
531
+ " #print(input_text)\n",
532
+ " response_text = inference(input_text, model, tokenizer)\n",
533
+ " #print(response_text)\n",
534
+ " data = {\n",
535
+ " \"instruction\":entry[\"instruction\"],\n",
536
+ " \"input\":entry[\"input\"],\n",
537
+ " \"output\":entry[\"output\"],\n",
538
+ " \"model_response\":response_text\n",
539
+ " }\n",
540
+ "\n",
541
+ " data_list.append(data)"
542
+ ]
543
+ },
544
+ {
545
+ "cell_type": "code",
546
+ "execution_count": null,
547
+ "id": "0c6e47cb-1b64-4690-a51d-f1816b82f15f",
548
+ "metadata": {},
549
+ "outputs": [],
550
+ "source": [
551
+ "import json\n",
552
+ "\n",
553
+ "# 定义输出文件路径\n",
554
+ "output_file = 'llama-sft-2.json'\n",
555
+ "\n",
556
+ "# 将 Dataset 对象导出为 JSON 文件\n",
557
+ "# test_data.to_json(output_file)\n",
558
+ "with open(output_file, \"w\") as file:\n",
559
+ " json.dump(data_list, file, indent=4) # \"indent\" for pretty-printing\n",
560
+ "\n"
561
+ ]
562
+ },
563
+ {
564
+ "cell_type": "code",
565
+ "execution_count": null,
566
+ "id": "68831e19-5a99-46d8-9f40-e8bf6957dbfc",
567
+ "metadata": {},
568
+ "outputs": [],
569
+ "source": [
570
+ "import json\n",
571
+ "from tqdm import tqdm\n",
572
+ "\n",
573
+ "\n",
574
+ "\n",
575
+ "with open(output_file, \"r\") as file:\n",
576
+ " test_data = json.load(file)\n",
577
+ "\n",
578
+ "all_num = len(test_data)\n",
579
+ "right_sum = 0\n",
580
+ "same_sum = 0\n",
581
+ "for item in test_data:\n",
582
+ " output = item[\"output\"]\n",
583
+ " #output = \" \".join(tokenizer.tokenize(output))\n",
584
+ " model_response = item[\"model_response\"]\n",
585
+ "\n",
586
+ " print(output,\"||||||||||||\", model_response)\n",
587
+ "\n",
588
+ " if model_response == output: #same it\n",
589
+ " same_sum = same_sum + 1\n",
590
+ " \n",
591
+ " if output.find(\"Non\")==-1: # no Non\n",
592
+ " if model_response.find(output)!=-1 and model_response.find(\"Non\")==-1: #find it, but no Non\n",
593
+ " right_sum = right_sum + 1\n",
594
+ " else:\n",
595
+ " if model_response.find(output)!=-1: #find it\n",
596
+ " right_sum = right_sum + 1\n",
597
+ "\n",
598
+ "\n",
599
+ "print(\"precision\", right_sum/all_num, \"same\", same_sum/all_num)\n"
600
+ ]
601
+ }
602
+ ],
603
+ "metadata": {
604
+ "kernelspec": {
605
+ "display_name": "Python 3 (ipykernel)",
606
+ "language": "python",
607
+ "name": "python3"
608
+ },
609
+ "language_info": {
610
+ "codemirror_mode": {
611
+ "name": "ipython",
612
+ "version": 3
613
+ },
614
+ "file_extension": ".py",
615
+ "mimetype": "text/x-python",
616
+ "name": "python",
617
+ "nbconvert_exporter": "python",
618
+ "pygments_lexer": "ipython3",
619
+ "version": "3.12.3"
620
+ }
621
+ },
622
+ "nbformat": 4,
623
+ "nbformat_minor": 5
624
+ }
04-gene-sft/.ipynb_checkpoints/gene_bpe_seg-checkpoint.vocab ADDED
The diff for this file is too large to render. See raw diff
 
04-gene-sft/.ipynb_checkpoints/llama_sft_test-checkpoint.ipynb ADDED
@@ -0,0 +1,1627 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "81a2413e-8629-4016-aace-17d2f757f726",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "https://hf-mirror.com\n"
14
+ ]
15
+ }
16
+ ],
17
+ "source": [
18
+ "import os\n",
19
+ "\n",
20
+ "# 设置环境变量\n",
21
+ "os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\n",
22
+ "\n",
23
+ "# 打印环境变量以确认设置成功\n",
24
+ "print(os.environ.get('HF_ENDPOINT'))"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": 2,
30
+ "id": "89e2d33a-6d84-4ef3-b44e-daa57ac81e58",
31
+ "metadata": {},
32
+ "outputs": [
33
+ {
34
+ "name": "stderr",
35
+ "output_type": "stream",
36
+ "text": [
37
+ "2024-11-24 11:21:51.020375: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
38
+ "2024-11-24 11:21:51.036615: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
39
+ "2024-11-24 11:21:51.053557: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
40
+ "2024-11-24 11:21:51.058466: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
41
+ "2024-11-24 11:21:51.071840: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
42
+ "To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
43
+ "2024-11-24 11:21:51.923693: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
44
+ ]
45
+ }
46
+ ],
47
+ "source": [
48
+ "from transformers import AutoTokenizer, AutoConfig,AutoModel\n",
49
+ "from transformers import DataCollatorForLanguageModeling\n",
50
+ "from transformers import Trainer, TrainingArguments\n",
51
+ "from transformers import AutoConfig, AutoModelForCausalLM,LlamaForCausalLM,LlamaTokenizer\n",
52
+ "from tokenizers import Tokenizer\n",
53
+ "from datasets import load_dataset"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": 3,
59
+ "id": "68fc5c44-b444-402e-aaf2-0ba4e2000e42",
60
+ "metadata": {},
61
+ "outputs": [
62
+ {
63
+ "data": {
64
+ "text/plain": [
65
+ "DatasetDict({\n",
66
+ " train: Dataset({\n",
67
+ " features: ['instruction', 'input', 'output'],\n",
68
+ " num_rows: 19839\n",
69
+ " })\n",
70
+ "})"
71
+ ]
72
+ },
73
+ "execution_count": 3,
74
+ "metadata": {},
75
+ "output_type": "execute_result"
76
+ }
77
+ ],
78
+ "source": [
79
+ "from datasets import load_dataset\n",
80
+ "dna_ft_dataset = load_dataset('json', data_files='val_data.json')\n",
81
+ "dna_ft_dataset"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "code",
86
+ "execution_count": 4,
87
+ "id": "4ab4fd3e-5b59-470e-9b46-f0ffd7b9d1aa",
88
+ "metadata": {},
89
+ "outputs": [
90
+ {
91
+ "data": {
92
+ "text/plain": [
93
+ "DatasetDict({\n",
94
+ " train: Dataset({\n",
95
+ " features: ['instruction', 'input', 'output'],\n",
96
+ " num_rows: 1983\n",
97
+ " })\n",
98
+ " test: Dataset({\n",
99
+ " features: ['instruction', 'input', 'output'],\n",
100
+ " num_rows: 17856\n",
101
+ " })\n",
102
+ "})"
103
+ ]
104
+ },
105
+ "execution_count": 4,
106
+ "metadata": {},
107
+ "output_type": "execute_result"
108
+ }
109
+ ],
110
+ "source": [
111
+ "data = dna_ft_dataset[\"train\"].train_test_split(train_size=0.1, seed=42)\n",
112
+ "data"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": 5,
118
+ "id": "85ca97f5-6864-4d6f-944a-182ed1fa2f00",
119
+ "metadata": {},
120
+ "outputs": [],
121
+ "source": [
122
+ "tokenizer = LlamaTokenizer.from_pretrained(\"dnahlm-llama-7b-sft-v0\") #dnagpt/dnahlm-llama-7b-sft-v0\n",
123
+ "tokenizer.pad_token = tokenizer.eos_token"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": 6,
129
+ "id": "e904c0b2-bf21-4036-b510-8e57177c1767",
130
+ "metadata": {},
131
+ "outputs": [
132
+ {
133
+ "data": {
134
+ "application/vnd.jupyter.widget-view+json": {
135
+ "model_id": "99ce92d0373a498d929bed42f770ed16",
136
+ "version_major": 2,
137
+ "version_minor": 0
138
+ },
139
+ "text/plain": [
140
+ "Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]"
141
+ ]
142
+ },
143
+ "metadata": {},
144
+ "output_type": "display_data"
145
+ },
146
+ {
147
+ "data": {
148
+ "text/plain": [
149
+ "LlamaForCausalLM(\n",
150
+ " (model): LlamaModel(\n",
151
+ " (embed_tokens): Embedding(61973, 4096, padding_idx=0)\n",
152
+ " (layers): ModuleList(\n",
153
+ " (0-31): 32 x LlamaDecoderLayer(\n",
154
+ " (self_attn): LlamaSdpaAttention(\n",
155
+ " (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
156
+ " (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
157
+ " (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
158
+ " (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
159
+ " (rotary_emb): LlamaRotaryEmbedding()\n",
160
+ " )\n",
161
+ " (mlp): LlamaMLP(\n",
162
+ " (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
163
+ " (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
164
+ " (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n",
165
+ " (act_fn): SiLU()\n",
166
+ " )\n",
167
+ " (input_layernorm): LlamaRMSNorm((4096,), eps=1e-06)\n",
168
+ " (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-06)\n",
169
+ " )\n",
170
+ " )\n",
171
+ " (norm): LlamaRMSNorm((4096,), eps=1e-06)\n",
172
+ " (rotary_emb): LlamaRotaryEmbedding()\n",
173
+ " )\n",
174
+ " (lm_head): Linear(in_features=4096, out_features=61973, bias=False)\n",
175
+ ")"
176
+ ]
177
+ },
178
+ "execution_count": 6,
179
+ "metadata": {},
180
+ "output_type": "execute_result"
181
+ }
182
+ ],
183
+ "source": [
184
+ "model = LlamaForCausalLM.from_pretrained(\"dnahlm-llama-7b-sft-v0\") #continue pretrain\n",
185
+ "model"
186
+ ]
187
+ },
188
+ {
189
+ "cell_type": "code",
190
+ "execution_count": 7,
191
+ "id": "5b361c5c-c43f-4ed9-a5c7-c72403cd7a0a",
192
+ "metadata": {},
193
+ "outputs": [],
194
+ "source": [
195
+ "#构建提示词\n",
196
+ "def format_input(entry):\n",
197
+ " instruction_text = (\n",
198
+ " f\"Below is an instruction that describes a task. \"\n",
199
+ " f\"Write a response that appropriately completes the request.\"\n",
200
+ " f\"\\n\\n### Instruction:\\n{entry['instruction']}\"\n",
201
+ " )\n",
202
+ "\n",
203
+ " input_text = f\"\\n\\n### Input:\\n{entry['input']}\" if entry[\"input\"] else \"\"\n",
204
+ "\n",
205
+ " return instruction_text + input_text + \"\\n\\n### Response:\\n\"\n",
206
+ "\n",
207
+ "#构建提示词\n",
208
+ "def build_prompt(entry):\n",
209
+ "\n",
210
+ " input_data = format_input(entry)\n",
211
+ "\n",
212
+ " desired_response = entry['output']\n",
213
+ "\n",
214
+ " return input_data + desired_response\n",
215
+ "\n"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "code",
220
+ "execution_count": 8,
221
+ "id": "ed031a26-d79e-4f50-85d1-169ebd409c6d",
222
+ "metadata": {},
223
+ "outputs": [
224
+ {
225
+ "data": {
226
+ "text/plain": [
227
+ "{'instruction': 'Determine core promoter detection of following dna sequence, The result will be one of the following: Non-promoter, promoter.',\n",
228
+ " 'input': 'CCGTGCGACCGGAAGTGGGGCGGCGACCCCGGAAGTCCCCGCCGGGTGCAGCTTGGTCGGTTCGATCGCC',\n",
229
+ " 'output': 'promoter'}"
230
+ ]
231
+ },
232
+ "execution_count": 8,
233
+ "metadata": {},
234
+ "output_type": "execute_result"
235
+ }
236
+ ],
237
+ "source": [
238
+ "example = data[\"test\"][0]\n",
239
+ "example"
240
+ ]
241
+ },
242
+ {
243
+ "cell_type": "code",
244
+ "execution_count": 9,
245
+ "id": "31bd4bb5-86a6-4046-b510-492b0548323b",
246
+ "metadata": {},
247
+ "outputs": [
248
+ {
249
+ "name": "stdout",
250
+ "output_type": "stream",
251
+ "text": [
252
+ "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
253
+ "\n",
254
+ "### Instruction:\n",
255
+ "Determine core promoter detection of following dna sequence, The result will be one of the following: Non-promoter, promoter.\n",
256
+ "\n",
257
+ "### Input:\n",
258
+ "CCGTGCGACCGGAAGTGGGGCGGCGACCCCGGAAGTCCCCGCCGGGTGCAGCTTGGTCGGTTCGATCGCC\n",
259
+ "\n",
260
+ "### Response:\n",
261
+ "promoter\n"
262
+ ]
263
+ }
264
+ ],
265
+ "source": [
266
+ "prompt = build_prompt(example)\n",
267
+ "print(prompt)"
268
+ ]
269
+ },
270
+ {
271
+ "cell_type": "code",
272
+ "execution_count": 10,
273
+ "id": "ed0b5b8b-916c-499b-a6da-f1124b9add1c",
274
+ "metadata": {
275
+ "scrolled": true
276
+ },
277
+ "outputs": [
278
+ {
279
+ "data": {
280
+ "text/plain": [
281
+ "['▁Below',\n",
282
+ " '▁is',\n",
283
+ " '▁an',\n",
284
+ " '▁instruction',\n",
285
+ " '▁that',\n",
286
+ " '▁describes',\n",
287
+ " '▁a',\n",
288
+ " '▁task',\n",
289
+ " '.',\n",
290
+ " '▁Write',\n",
291
+ " '▁a',\n",
292
+ " '▁response',\n",
293
+ " '▁that',\n",
294
+ " '▁appropri',\n",
295
+ " 'ately',\n",
296
+ " '▁comple',\n",
297
+ " 'tes',\n",
298
+ " '▁the',\n",
299
+ " '▁request',\n",
300
+ " '.',\n",
301
+ " '<0x0A>',\n",
302
+ " '<0x0A>',\n",
303
+ " '##',\n",
304
+ " '#',\n",
305
+ " '▁Inst',\n",
306
+ " 'ruction',\n",
307
+ " ':',\n",
308
+ " '<0x0A>',\n",
309
+ " 'Det',\n",
310
+ " 'erm',\n",
311
+ " 'ine',\n",
312
+ " '▁core',\n",
313
+ " '▁prom',\n",
314
+ " 'oter',\n",
315
+ " '▁detection',\n",
316
+ " '▁of',\n",
317
+ " '▁following',\n",
318
+ " '▁d',\n",
319
+ " 'na',\n",
320
+ " '▁sequence',\n",
321
+ " ',',\n",
322
+ " '▁The',\n",
323
+ " '▁result',\n",
324
+ " '▁will',\n",
325
+ " '▁be',\n",
326
+ " '▁one',\n",
327
+ " '▁of',\n",
328
+ " '▁the',\n",
329
+ " '▁following',\n",
330
+ " ':',\n",
331
+ " '▁Non',\n",
332
+ " '-',\n",
333
+ " 'prom',\n",
334
+ " 'oter',\n",
335
+ " ',',\n",
336
+ " '▁prom',\n",
337
+ " 'oter',\n",
338
+ " '.',\n",
339
+ " '<0x0A>',\n",
340
+ " '<0x0A>',\n",
341
+ " '##',\n",
342
+ " '#',\n",
343
+ " '▁Input',\n",
344
+ " ':',\n",
345
+ " '<0x0A>',\n",
346
+ " 'CCGTG',\n",
347
+ " 'C',\n",
348
+ " 'GAC',\n",
349
+ " 'CGGAA',\n",
350
+ " 'GTG',\n",
351
+ " 'GGGC',\n",
352
+ " 'GGC',\n",
353
+ " 'GAC',\n",
354
+ " 'CCCGGAA',\n",
355
+ " 'GTCC',\n",
356
+ " 'CCGCC',\n",
357
+ " 'GGGTG',\n",
358
+ " 'CA',\n",
359
+ " 'GCT',\n",
360
+ " 'TG',\n",
361
+ " 'GTC',\n",
362
+ " 'GGT',\n",
363
+ " 'TC',\n",
364
+ " 'GATCGCC',\n",
365
+ " '<0x0A>',\n",
366
+ " '<0x0A>',\n",
367
+ " '##',\n",
368
+ " '#',\n",
369
+ " '▁Response',\n",
370
+ " ':',\n",
371
+ " '<0x0A>',\n",
372
+ " 'prom',\n",
373
+ " 'oter']"
374
+ ]
375
+ },
376
+ "execution_count": 10,
377
+ "metadata": {},
378
+ "output_type": "execute_result"
379
+ }
380
+ ],
381
+ "source": [
382
+ "tokenizer.tokenize(prompt)"
383
+ ]
384
+ },
385
+ {
386
+ "cell_type": "code",
387
+ "execution_count": 11,
388
+ "id": "f0449aee-1ac6-4db5-873f-afdfb0fc9691",
389
+ "metadata": {},
390
+ "outputs": [],
391
+ "source": [
392
+ "def inference(text, model, tokenizer, max_input_tokens=1000, max_output_tokens=1000):\n",
393
+ " # Tokenize\n",
394
+ " input_ids = tokenizer.encode(\n",
395
+ " text,\n",
396
+ " return_tensors=\"pt\",\n",
397
+ " truncation=True,\n",
398
+ " max_length=max_input_tokens\n",
399
+ " # return_attention_mask=True,\n",
400
+ " )\n",
401
+ "\n",
402
+ " # Generate\n",
403
+ " device = model.device\n",
404
+ " generated_tokens_with_prompt = model.generate(\n",
405
+ " input_ids=input_ids.to(device),\n",
406
+ " #max_length=max_output_tokens,\n",
407
+ " max_new_tokens=8,\n",
408
+ " temperature=0.01 # 控制生成的多样性\n",
409
+ " )\n",
410
+ "\n",
411
+ " # Decode\n",
412
+ " generated_text_with_prompt = tokenizer.decode(generated_tokens_with_prompt[0], skip_special_tokens=True)\n",
413
+ " generated_text_answer = generated_text_with_prompt[len(text):]\n",
414
+ "\n",
415
+ "\n",
416
+ " return generated_text_answer\n",
417
+ "\n",
418
+ "# 如果需要进一步清理\n",
419
+ "def clean_generated_text(text):\n",
420
+ " # 去除 'Ġ' 符号并替换为空格\n",
421
+ " text = text.replace('Ġ', ' ')\n",
422
+ " # 去除多余的空格\n",
423
+ " text = ' '.join(text.split())\n",
424
+ " return text"
425
+ ]
426
+ },
427
+ {
428
+ "cell_type": "code",
429
+ "execution_count": 12,
430
+ "id": "e9041426-eb59-4314-82dd-7b6d6d477783",
431
+ "metadata": {},
432
+ "outputs": [
433
+ {
434
+ "name": "stdout",
435
+ "output_type": "stream",
436
+ "text": [
437
+ "input (test): Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
438
+ "\n",
439
+ "### Instruction:\n",
440
+ "Determine core promoter detection of following dna sequence, The result will be one of the following: Non-promoter, promoter.\n",
441
+ "\n",
442
+ "### Input:\n",
443
+ "CCGTGCGACCGGAAGTGGGGCGGCGACCCCGGAAGTCCCCGCCGGGTGCAGCTTGGTCGGTTCGATCGCC\n",
444
+ "\n",
445
+ "### Response:\n",
446
+ "\n",
447
+ "real answer: promoter\n",
448
+ "--------------------------\n",
449
+ "\n",
450
+ "model's answer: \n",
451
+ "\n"
452
+ ]
453
+ },
454
+ {
455
+ "name": "stderr",
456
+ "output_type": "stream",
457
+ "text": [
458
+ "/root/miniconda3/lib/python3.12/site-packages/transformers/generation/configuration_utils.py:601: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.01` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
459
+ " warnings.warn(\n",
460
+ "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)\n"
461
+ ]
462
+ },
463
+ {
464
+ "name": "stdout",
465
+ "output_type": "stream",
466
+ "text": [
467
+ " Non-promoter\n"
468
+ ]
469
+ }
470
+ ],
471
+ "source": [
472
+ "input_text = format_input(data[\"test\"][0])\n",
473
+ "\n",
474
+ "print(\"input (test):\", input_text)\n",
475
+ "\n",
476
+ "print(\"real answer:\", data[\"test\"][0][\"output\"])\n",
477
+ "\n",
478
+ "print(\"--------------------------\\n\")\n",
479
+ "\n",
480
+ "print(\"model's answer: \\n\")\n",
481
+ "print(inference(input_text, model, tokenizer))"
482
+ ]
483
+ },
484
+ {
485
+ "cell_type": "code",
486
+ "execution_count": 13,
487
+ "id": "d1489173-84af-4c8e-b66b-0cdbe42c7ea7",
488
+ "metadata": {},
489
+ "outputs": [],
490
+ "source": [
491
+ "test_data = data[\"test\"].shuffle(seed=199).select(range(1000))\n",
492
+ "\n",
493
+ "data_list = []\n",
494
+ "\n",
495
+ "for entry in test_data:\n",
496
+ " input_text = format_input(entry)\n",
497
+ " #print(input_text)\n",
498
+ " response_text = inference(input_text, model, tokenizer)\n",
499
+ " #print(response_text)\n",
500
+ " data = {\n",
501
+ " \"instruction\":entry[\"instruction\"],\n",
502
+ " \"input\":entry[\"input\"],\n",
503
+ " \"output\":entry[\"output\"],\n",
504
+ " \"model_response\":response_text\n",
505
+ " }\n",
506
+ "\n",
507
+ " data_list.append(data)"
508
+ ]
509
+ },
510
+ {
511
+ "cell_type": "code",
512
+ "execution_count": 14,
513
+ "id": "39275fe6-ac3b-4558-9f4c-2853a41d48c4",
514
+ "metadata": {},
515
+ "outputs": [],
516
+ "source": [
517
+ "import json\n",
518
+ "\n",
519
+ "# 定义输出文件路径\n",
520
+ "output_file = 'llama-sft-2.json'\n",
521
+ "\n",
522
+ "# 将 Dataset 对象导出为 JSON 文件\n",
523
+ "# test_data.to_json(output_file)\n",
524
+ "with open(output_file, \"w\") as file:\n",
525
+ " json.dump(data_list, file, indent=4) # \"indent\" for pretty-printing\n",
526
+ "\n"
527
+ ]
528
+ },
529
+ {
530
+ "cell_type": "code",
531
+ "execution_count": 15,
532
+ "id": "7ffaba65-a270-4433-b234-932f5e288f7c",
533
+ "metadata": {},
534
+ "outputs": [
535
+ {
536
+ "data": {
537
+ "text/plain": [
538
+ "'▁prom oter'"
539
+ ]
540
+ },
541
+ "execution_count": 15,
542
+ "metadata": {},
543
+ "output_type": "execute_result"
544
+ }
545
+ ],
546
+ "source": [
547
+ "\" \".join(tokenizer.tokenize(\"promoter\"))"
548
+ ]
549
+ },
550
+ {
551
+ "cell_type": "code",
552
+ "execution_count": 16,
553
+ "id": "a7e373a4-6857-4874-b2da-58da2928925d",
554
+ "metadata": {},
555
+ "outputs": [
556
+ {
557
+ "name": "stdout",
558
+ "output_type": "stream",
559
+ "text": [
560
+ "Donor Sites |||||||||||| Donor Sites\n",
561
+ "promoter |||||||||||| promoter\n",
562
+ "promoter |||||||||||| promoter\n",
563
+ "promoter |||||||||||| promoter\n",
564
+ "promoter |||||||||||| promoter\n",
565
+ "Donor Sites |||||||||||| Non-Splice Sites\n",
566
+ "promoter |||||||||||| Non-promoter\n",
567
+ "promoter |||||||||||| Non-promoter\n",
568
+ "Non-promoter |||||||||||| promoter\n",
569
+ "Non-promoter |||||||||||| Non-promoter\n",
570
+ "Donor Sites |||||||||||| Donor Sites\n",
571
+ "Non-promoter |||||||||||| Non-promoter\n",
572
+ "Non-promoter |||||||||||| promoter\n",
573
+ "Non-promoter |||||||||||| promoter\n",
574
+ "promoter |||||||||||| Non-promoter\n",
575
+ "promoter |||||||||||| Non-promoter\n",
576
+ "Donor Sites |||||||||||| Donor Sites\n",
577
+ "Background Sequences |||||||||||| Background Sequences\n",
578
+ "Non-promoter |||||||||||| Non-promoter\n",
579
+ "Non-promoter |||||||||||| promoter\n",
580
+ "promoter |||||||||||| promoter\n",
581
+ "promoter |||||||||||| promoter\n",
582
+ "promoter |||||||||||| promoter\n",
583
+ "promoter |||||||||||| promoter\n",
584
+ "promoter |||||||||||| Non-promoter\n",
585
+ "promoter |||||||||||| promoter\n",
586
+ "Non-promoter |||||||||||| Non-promoter\n",
587
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
588
+ "Non-promoter |||||||||||| Non-promoter\n",
589
+ "promoter |||||||||||| promoter\n",
590
+ "Non-promoter |||||||||||| Non-promoter\n",
591
+ "Binding Sites |||||||||||| Background Sequences\n",
592
+ "Non-promoter |||||||||||| Non-promoter\n",
593
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
594
+ "Non-promoter |||||||||||| Non-promoter\n",
595
+ "Non-promoter |||||||||||| promoter\n",
596
+ "Non-promoter |||||||||||| Non-promoter\n",
597
+ "Donor Sites |||||||||||| Non-Splice Sites\n",
598
+ "Non-promoter |||||||||||| promoter\n",
599
+ "promoter |||||||||||| promoter\n",
600
+ "Background Sequences |||||||||||| Background Sequences\n",
601
+ "Non-promoter |||||||||||| Non-promoter\n",
602
+ "Binding Sites |||||||||||| Binding Sites\n",
603
+ "promoter |||||||||||| promoter\n",
604
+ "Non-promoter |||||||||||| Non-promoter\n",
605
+ "Non-promoter |||||||||||| Non-promoter\n",
606
+ "Non-promoter |||||||||||| Non-promoter\n",
607
+ "Non-promoter |||||||||||| Non-promoter\n",
608
+ "Donor Sites |||||||||||| Donor Sites\n",
609
+ "promoter |||||||||||| Non-promoter\n",
610
+ "promoter |||||||||||| Non-promoter\n",
611
+ "Non-promoter |||||||||||| promoter\n",
612
+ "Binding Sites |||||||||||| Binding Sites\n",
613
+ "promoter |||||||||||| Non-promoter\n",
614
+ "promoter |||||||||||| promoter\n",
615
+ "Background Sequences |||||||||||| Background Sequences\n",
616
+ "promoter |||||||||||| promoter\n",
617
+ "Non-promoter |||||||||||| Non-promoter\n",
618
+ "Background Sequences |||||||||||| Binding Sites\n",
619
+ "promoter |||||||||||| promoter\n",
620
+ "promoter |||||||||||| promoter\n",
621
+ "promoter |||||||||||| promoter\n",
622
+ "Donor Sites |||||||||||| Donor Sites\n",
623
+ "Binding Sites |||||||||||| Binding Sites\n",
624
+ "promoter |||||||||||| promoter\n",
625
+ "Donor Sites |||||||||||| Donor Sites\n",
626
+ "Non-promoter |||||||||||| Non-promoter\n",
627
+ "Binding Sites |||||||||||| Binding Sites\n",
628
+ "Donor Sites |||||||||||| Donor Sites\n",
629
+ "Non-promoter |||||||||||| promoter\n",
630
+ "Donor Sites |||||||||||| Donor Sites\n",
631
+ "Non-promoter |||||||||||| Non-promoter\n",
632
+ "promoter |||||||||||| promoter\n",
633
+ "promoter |||||||||||| promoter\n",
634
+ "promoter |||||||||||| promoter\n",
635
+ "Non-promoter |||||||||||| Non-promoter\n",
636
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
637
+ "promoter |||||||||||| Non-promoter\n",
638
+ "Donor Sites |||||||||||| Donor Sites\n",
639
+ "Donor Sites |||||||||||| Donor Sites\n",
640
+ "promoter |||||||||||| promoter\n",
641
+ "promoter |||||||||||| Non-promoter\n",
642
+ "promoter |||||||||||| promoter\n",
643
+ "Non-promoter |||||||||||| Non-promoter\n",
644
+ "Non-promoter |||||||||||| Non-promoter\n",
645
+ "promoter |||||||||||| promoter\n",
646
+ "Non-promoter |||||||||||| Non-promoter\n",
647
+ "promoter |||||||||||| promoter\n",
648
+ "Background Sequences |||||||||||| Binding Sites\n",
649
+ "Acceptor Sites |||||||||||| Donor Sites\n",
650
+ "Non-Splice Sites |||||||||||| Acceptor Sites\n",
651
+ "Donor Sites |||||||||||| Donor Sites\n",
652
+ "Donor Sites |||||||||||| Donor Sites\n",
653
+ "Non-promoter |||||||||||| promoter\n",
654
+ "promoter |||||||||||| Non-promoter\n",
655
+ "Background Sequences |||||||||||| Background Sequences\n",
656
+ "promoter |||||||||||| promoter\n",
657
+ "promoter |||||||||||| promoter\n",
658
+ "Acceptor Sites |||||||||||| Donor Sites\n",
659
+ "promoter |||||||||||| promoter\n",
660
+ "Donor Sites |||||||||||| Donor Sites\n",
661
+ "Binding Sites |||||||||||| Courses\n",
662
+ "promoter |||||||||||| promoter\n",
663
+ "Donor Sites |||||||||||| Donor Sites\n",
664
+ "Non-promoter |||||||||||| Non-promoter\n",
665
+ "Non-promoter |||||||||||| Non-promoter\n",
666
+ "Donor Sites |||||||||||| Donor Sites\n",
667
+ "Donor Sites |||||||||||| Donor Sites\n",
668
+ "Non-promoter |||||||||||| Non-promoter\n",
669
+ "Binding Sites |||||||||||| Binding Sites\n",
670
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
671
+ "Binding Sites |||||||||||| Court\n",
672
+ "Donor Sites |||||||||||| Donor Sites\n",
673
+ "Non-promoter |||||||||||| promoter\n",
674
+ "promoter |||||||||||| Non-promoter\n",
675
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
676
+ "promoter |||||||||||| promoter\n",
677
+ "Binding Sites |||||||||||| Background Sequences\n",
678
+ "promoter |||||||||||| promoter\n",
679
+ "Non-promoter |||||||||||| Non-promoter\n",
680
+ "Binding Sites |||||||||||| Binding Sites\n",
681
+ "Non-promoter |||||||||||| Non-promoter\n",
682
+ "Non-promoter |||||||||||| Non-promoter\n",
683
+ "Non-promoter |||||||||||| promoter\n",
684
+ "Non-promoter |||||||||||| Non-promoter\n",
685
+ "promoter |||||||||||| promoter\n",
686
+ "Non-Splice Sites |||||||||||| Acceptor Sites\n",
687
+ "promoter |||||||||||| promoter\n",
688
+ "promoter |||||||||||| Non-promoter\n",
689
+ "promoter |||||||||||| promoter\n",
690
+ "promoter |||||||||||| promoter\n",
691
+ "Donor Sites |||||||||||| Donor Sites\n",
692
+ "Background Sequences |||||||||||| Binding Sites\n",
693
+ "Background Sequences |||||||||||| Background Sequences\n",
694
+ "promoter |||||||||||| promoter\n",
695
+ "Non-promoter |||||||||||| Non-promoter\n",
696
+ "promoter |||||||||||| promoter\n",
697
+ "Donor Sites |||||||||||| Donor Sites\n",
698
+ "Non-promoter |||||||||||| promoter\n",
699
+ "Acceptor Sites |||||||||||| Non-Splice Sites\n",
700
+ "Non-promoter |||||||||||| promoter\n",
701
+ "Non-promoter |||||||||||| Non-promoter\n",
702
+ "Non-promoter |||||||||||| Non-promoter\n",
703
+ "promoter |||||||||||| promoter\n",
704
+ "promoter |||||||||||| promoter\n",
705
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
706
+ "Non-promoter |||||||||||| Non-promoter\n",
707
+ "Non-promoter |||||||||||| Non-promoter\n",
708
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
709
+ "Donor Sites |||||||||||| Donor Sites\n",
710
+ "Donor Sites |||||||||||| Donor Sites\n",
711
+ "Binding Sites |||||||||||| Background Sequences\n",
712
+ "Binding Sites |||||||||||| Binding Sites\n",
713
+ "promoter |||||||||||| promoter\n",
714
+ "Non-promoter |||||||||||| Non-promoter\n",
715
+ "Binding Sites |||||||||||| Background Sequences\n",
716
+ "Background Sequences |||||||||||| Background Sequences\n",
717
+ "Non-promoter |||||||||||| promoter\n",
718
+ "Non-promoter |||||||||||| Non-promoter\n",
719
+ "promoter |||||||||||| Non-promoter\n",
720
+ "Donor Sites |||||||||||| Donor Sites\n",
721
+ "Donor Sites |||||||||||| Non-Splice Sites\n",
722
+ "Binding Sites |||||||||||| Binding Sites\n",
723
+ "Donor Sites |||||||||||| Donor Sites\n",
724
+ "promoter |||||||||||| Non-promoter\n",
725
+ "Acceptor Sites |||||||||||| Donor Sites\n",
726
+ "Non-promoter |||||||||||| Non-promoter\n",
727
+ "Non-promoter |||||||||||| Non-promoter\n",
728
+ "Donor Sites |||||||||||| Donor Sites\n",
729
+ "Donor Sites |||||||||||| Donor Sites\n",
730
+ "Donor Sites |||||||||||| Donor Sites\n",
731
+ "promoter |||||||||||| promoter\n",
732
+ "promoter |||||||||||| promoter\n",
733
+ "promoter |||||||||||| promoter\n",
734
+ "promoter |||||||||||| promoter\n",
735
+ "promoter |||||||||||| promoter\n",
736
+ "promoter |||||||||||| promoter\n",
737
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
738
+ "promoter |||||||||||| promoter\n",
739
+ "Background Sequences |||||||||||| Background Sequences\n",
740
+ "Non-promoter |||||||||||| Non-promoter\n",
741
+ "promoter |||||||||||| promoter\n",
742
+ "Donor Sites |||||||||||| Donor Sites\n",
743
+ "Non-promoter |||||||||||| promoter\n",
744
+ "Donor Sites |||||||||||| Donor Sites\n",
745
+ "Binding Sites |||||||||||| Binding Sites\n",
746
+ "Donor Sites |||||||||||| Donor Sites\n",
747
+ "Binding Sites |||||||||||| Binding Sites\n",
748
+ "Non-promoter |||||||||||| promoter\n",
749
+ "Non-promoter |||||||||||| Non-promoter\n",
750
+ "Background Sequences |||||||||||| Binding Sites\n",
751
+ "Non-promoter |||||||||||| Non-promoter\n",
752
+ "promoter |||||||||||| promoter\n",
753
+ "Background Sequences |||||||||||| Background Sequences\n",
754
+ "Non-promoter |||||||||||| promoter\n",
755
+ "Non-promoter |||||||||||| Non-promoter\n",
756
+ "Non-promoter |||||||||||| Non-promoter\n",
757
+ "Background Sequences |||||||||||| Binding Sites\n",
758
+ "Background Sequences |||||||||||| Background Sequences\n",
759
+ "Non-promoter |||||||||||| Non-promoter\n",
760
+ "promoter |||||||||||| promoter\n",
761
+ "Background Sequences |||||||||||| Background Sequences\n",
762
+ "Non-promoter |||||||||||| promoter\n",
763
+ "Non-promoter |||||||||||| promoter\n",
764
+ "promoter |||||||||||| promoter\n",
765
+ "promoter |||||||||||| promoter\n",
766
+ "promoter |||||||||||| promoter\n",
767
+ "Non-promoter |||||||||||| Non-promoter\n",
768
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
769
+ "promoter |||||||||||| Non-promoter\n",
770
+ "promoter |||||||||||| promoter\n",
771
+ "Background Sequences |||||||||||| Background Sequences\n",
772
+ "Background Sequences |||||||||||| Background Sequences\n",
773
+ "Background Sequences |||||||||||| Background Sequences\n",
774
+ "Donor Sites |||||||||||| Donor Sites\n",
775
+ "Binding Sites |||||||||||| Binding Sites\n",
776
+ "Non-promoter |||||||||||| Non-promoter\n",
777
+ "Non-promoter |||||||||||| promoter\n",
778
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
779
+ "Binding Sites |||||||||||| Binding Sites\n",
780
+ "Donor Sites |||||||||||| Non-Splice Sites\n",
781
+ "Background Sequences |||||||||||| Binding Sites\n",
782
+ "promoter |||||||||||| promoter\n",
783
+ "Non-Splice Sites |||||||||||| Splice Sites\n",
784
+ "promoter |||||||||||| promoter\n",
785
+ "Donor Sites |||||||||||| Acceptor Sites\n",
786
+ "Binding Sites |||||||||||| Binding Sites\n",
787
+ "Non-promoter |||||||||||| promoter\n",
788
+ "promoter |||||||||||| promoter\n",
789
+ "Donor Sites |||||||||||| Acceptor Sites\n",
790
+ "Non-promoter |||||||||||| Non-promoter\n",
791
+ "promoter |||||||||||| promoter\n",
792
+ "Donor Sites |||||||||||| Donor Sites\n",
793
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
794
+ "Donor Sites |||||||||||| Donor Sites\n",
795
+ "Binding Sites |||||||||||| Binding Sites\n",
796
+ "promoter |||||||||||| promoter\n",
797
+ "Background Sequences |||||||||||| Background Sequences\n",
798
+ "promoter |||||||||||| promoter\n",
799
+ "Binding Sites |||||||||||| Coursing\n",
800
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
801
+ "Non-promoter |||||||||||| Non-promoter\n",
802
+ "Donor Sites |||||||||||| Donor Sites\n",
803
+ "Non-promoter |||||||||||| Non-promoter\n",
804
+ "Non-promoter |||||||||||| promoter\n",
805
+ "Binding Sites |||||||||||| Binding Sites\n",
806
+ "Binding Sites |||||||||||| Binding Sites\n",
807
+ "Background Sequences |||||||||||| Background Sequences\n",
808
+ "Non-promoter |||||||||||| Non-promoter\n",
809
+ "promoter |||||||||||| Non-promoter\n",
810
+ "promoter |||||||||||| promoter\n",
811
+ "Non-promoter |||||||||||| promoter\n",
812
+ "promoter |||||||||||| promoter\n",
813
+ "Non-promoter |||||||||||| promoter\n",
814
+ "Non-promoter |||||||||||| promoter\n",
815
+ "Non-promoter |||||||||||| promoter\n",
816
+ "promoter |||||||||||| Non-promoter\n",
817
+ "promoter |||||||||||| promoter\n",
818
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
819
+ "promoter |||||||||||| promoter\n",
820
+ "Non-promoter |||||||||||| Non-promoter\n",
821
+ "Acceptor Sites |||||||||||| Donor Sites\n",
822
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
823
+ "promoter |||||||||||| promoter\n",
824
+ "promoter |||||||||||| promoter\n",
825
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
826
+ "Donor Sites |||||||||||| Donor Sites\n",
827
+ "Non-promoter |||||||||||| Non-promoter\n",
828
+ "promoter |||||||||||| promoter\n",
829
+ "Acceptor Sites |||||||||||| Donor Sites\n",
830
+ "Non-promoter |||||||||||| Non-promoter\n",
831
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
832
+ "Acceptor Sites |||||||||||| Non-Splice Sites\n",
833
+ "Non-promoter |||||||||||| Non-promoter\n",
834
+ "Background Sequences |||||||||||| Background Sequences\n",
835
+ "Donor Sites |||||||||||| Donor Sites\n",
836
+ "promoter |||||||||||| promoter\n",
837
+ "promoter |||||||||||| promoter\n",
838
+ "Acceptor Sites |||||||||||| Donor Sites\n",
839
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
840
+ "promoter |||||||||||| promoter\n",
841
+ "Non-promoter |||||||||||| Non-promoter\n",
842
+ "promoter |||||||||||| promoter\n",
843
+ "Non-promoter |||||||||||| promoter\n",
844
+ "promoter |||||||||||| promoter\n",
845
+ "Non-promoter |||||||||||| Non-promoter\n",
846
+ "Donor Sites |||||||||||| Donor Sites\n",
847
+ "promoter |||||||||||| promoter\n",
848
+ "Donor Sites |||||||||||| Non-Splice Sites\n",
849
+ "Donor Sites |||||||||||| Donor Sites\n",
850
+ "Donor Sites |||||||||||| Donor Sites\n",
851
+ "Donor Sites |||||||||||| Donor Sites\n",
852
+ "promoter |||||||||||| promoter\n",
853
+ "Non-promoter |||||||||||| promoter\n",
854
+ "Binding Sites |||||||||||| Binding Sites\n",
855
+ "promoter |||||||||||| promoter\n",
856
+ "promoter |||||||||||| promoter\n",
857
+ "Binding Sites |||||||||||| Binding Sites\n",
858
+ "Binding Sites |||||||||||| Binding Sites\n",
859
+ "Non-promoter |||||||||||| Non-promoter\n",
860
+ "Non-promoter |||||||||||| Non-promoter\n",
861
+ "Non-promoter |||||||||||| Non-promoter\n",
862
+ "promoter |||||||||||| promoter\n",
863
+ "Background Sequences |||||||||||| Background Sequences\n",
864
+ "promoter |||||||||||| promoter\n",
865
+ "promoter |||||||||||| promoter\n",
866
+ "Background Sequences |||||||||||| Background Sequences\n",
867
+ "Binding Sites |||||||||||| Binding Sites\n",
868
+ "Binding Sites |||||||||||| Background Sequences\n",
869
+ "Non-promoter |||||||||||| Non-promoter\n",
870
+ "Non-promoter |||||||||||| promoter\n",
871
+ "Non-promoter |||||||||||| Non-promoter\n",
872
+ "Non-promoter |||||||||||| promoter\n",
873
+ "Donor Sites |||||||||||| Donor Sites\n",
874
+ "promoter |||||||||||| promoter\n",
875
+ "promoter |||||||||||| promoter\n",
876
+ "Non-promoter |||||||||||| Non-promoter\n",
877
+ "Donor Sites |||||||||||| Donor Sites\n",
878
+ "Donor Sites |||||||||||| Donor Sites\n",
879
+ "Non-Splice Sites |||||||||||| Acceptor Sites\n",
880
+ "promoter |||||||||||| promoter\n",
881
+ "Donor Sites |||||||||||| Donor Sites\n",
882
+ "promoter |||||||||||| promoter\n",
883
+ "Non-promoter |||||||||||| promoter\n",
884
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
885
+ "Binding Sites |||||||||||| Binding Sites\n",
886
+ "promoter |||||||||||| promoter\n",
887
+ "Donor Sites |||||||||||| Donor Sites\n",
888
+ "Donor Sites |||||||||||| Donor Sites\n",
889
+ "promoter |||||||||||| promoter\n",
890
+ "promoter |||||||||||| promoter\n",
891
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
892
+ "promoter |||||||||||| promoter\n",
893
+ "Binding Sites |||||||||||| Background Sequences\n",
894
+ "Non-promoter |||||||||||| Non-promoter\n",
895
+ "Donor Sites |||||||||||| Donor Sites\n",
896
+ "Non-promoter |||||||||||| promoter\n",
897
+ "promoter |||||||||||| promoter\n",
898
+ "Non-promoter |||||||||||| Non-promoter\n",
899
+ "promoter |||||||||||| promoter\n",
900
+ "promoter |||||||||||| promoter\n",
901
+ "Donor Sites |||||||||||| Donor Sites\n",
902
+ "Donor Sites |||||||||||| Donor Sites\n",
903
+ "Donor Sites |||||||||||| Donor Sites\n",
904
+ "Binding Sites |||||||||||| Binding Sites\n",
905
+ "Acceptor Sites |||||||||||| Donor Sites\n",
906
+ "Non-promoter |||||||||||| promoter\n",
907
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
908
+ "Background Sequences |||||||||||| Background Sequences\n",
909
+ "Donor Sites |||||||||||| Donor Sites\n",
910
+ "promoter |||||||||||| promoter\n",
911
+ "Donor Sites |||||||||||| Donor Sites\n",
912
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
913
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
914
+ "Background Sequences |||||||||||| Background Sequences\n",
915
+ "Non-promoter |||||||||||| Non-promoter\n",
916
+ "Non-promoter |||||||||||| Non-promoter\n",
917
+ "Non-promoter |||||||||||| Non-promoter\n",
918
+ "Non-promoter |||||||||||| promoter\n",
919
+ "Binding Sites |||||||||||| Binding Sites\n",
920
+ "promoter |||||||||||| promoter\n",
921
+ "promoter |||||||||||| Non-promoter\n",
922
+ "promoter |||||||||||| promoter\n",
923
+ "promoter |||||||||||| promoter\n",
924
+ "Non-promoter |||||||||||| Non-promoter\n",
925
+ "Donor Sites |||||||||||| Donor Sites\n",
926
+ "Non-promoter |||||||||||| Non-promoter\n",
927
+ "Non-promoter |||||||||||| Non-promoter\n",
928
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
929
+ "promoter |||||||||||| Non-promoter\n",
930
+ "Non-promoter |||||||||||| promoter\n",
931
+ "Binding Sites |||||||||||| Binding Sites\n",
932
+ "Binding Sites |||||||||||| Background Sequences\n",
933
+ "Donor Sites |||||||||||| D Donor Sites\n",
934
+ "promoter |||||||||||| promoter\n",
935
+ "Background Sequences |||||||||||| Background Sequences\n",
936
+ "Background Sequences |||||||||||| Background Sequences\n",
937
+ "Non-promoter |||||||||||| Non-promoter\n",
938
+ "promoter |||||||||||| promoter\n",
939
+ "Non-promoter |||||||||||| promoter\n",
940
+ "promoter |||||||||||| promoter\n",
941
+ "Non-promoter |||||||||||| promoter\n",
942
+ "Non-promoter |||||||||||| promoter\n",
943
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
944
+ "Non-promoter |||||||||||| Non-promoter\n",
945
+ "promoter |||||||||||| promoter\n",
946
+ "Donor Sites |||||||||||| Acceptor Sites\n",
947
+ "Donor Sites |||||||||||| Donor Sites\n",
948
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
949
+ "Background Sequences |||||||||||| Background Sequences\n",
950
+ "promoter |||||||||||| promoter\n",
951
+ "promoter |||||||||||| promoter\n",
952
+ "Non-promoter |||||||||||| Non-promoter\n",
953
+ "promoter |||||||||||| promoter\n",
954
+ "promoter |||||||||||| promoter\n",
955
+ "Background Sequences |||||||||||| Background Sequences\n",
956
+ "Donor Sites |||||||||||| Donor Sites\n",
957
+ "promoter |||||||||||| promoter\n",
958
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
959
+ "Binding Sites |||||||||||| Binding Sites\n",
960
+ "Non-promoter |||||||||||| Non-promoter\n",
961
+ "Non-promoter |||||||||||| Non-promoter\n",
962
+ "promoter |||||||||||| promoter\n",
963
+ "Non-promoter |||||||||||| promoter\n",
964
+ "Non-promoter |||||||||||| Non-promoter\n",
965
+ "Acceptor Sites |||||||||||| Donor Sites\n",
966
+ "promoter |||||||||||| promoter\n",
967
+ "Acceptor Sites |||||||||||| Donor Sites\n",
968
+ "promoter |||||||||||| promoter\n",
969
+ "promoter |||||||||||| promoter\n",
970
+ "Acceptor Sites |||||||||||| Donor Sites\n",
971
+ "promoter |||||||||||| promoter\n",
972
+ "promoter |||||||||||| promoter\n",
973
+ "promoter |||||||||||| Non-promoter\n",
974
+ "Non-promoter |||||||||||| promoter\n",
975
+ "promoter |||||||||||| promoter\n",
976
+ "Non-promoter |||||||||||| Non-promoter\n",
977
+ "Background Sequences |||||||||||| Background Sequences\n",
978
+ "Non-promoter |||||||||||| Non-promoter\n",
979
+ "Background Sequences |||||||||||| Background Sequences\n",
980
+ "Binding Sites |||||||||||| Binding Sites\n",
981
+ "Background Sequences |||||||||||| Background Sequences\n",
982
+ "promoter |||||||||||| promoter\n",
983
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
984
+ "Background Sequences |||||||||||| Background Sequences\n",
985
+ "Background Sequences |||||||||||| Background Sequences\n",
986
+ "Non-promoter |||||||||||| Non-promoter\n",
987
+ "Donor Sites |||||||||||| Donor Sites\n",
988
+ "Non-promoter |||||||||||| Non-promoter\n",
989
+ "Acceptor Sites |||||||||||| Donor Sites\n",
990
+ "Non-promoter |||||||||||| promoter\n",
991
+ "Non-promoter |||||||||||| Non-promoter\n",
992
+ "promoter |||||||||||| Non-promoter\n",
993
+ "Binding Sites |||||||||||| Background Sequences\n",
994
+ "Binding Sites |||||||||||| Background Sequences\n",
995
+ "Non-promoter |||||||||||| Non-promoter\n",
996
+ "Non-promoter |||||||||||| Non-promoter\n",
997
+ "Non-promoter |||||||||||| Non-promoter\n",
998
+ "Non-promoter |||||||||||| Non-promoter\n",
999
+ "Non-promoter |||||||||||| Non-promoter\n",
1000
+ "Non-promoter |||||||||||| promoter\n",
1001
+ "promoter |||||||||||| Non-promoter\n",
1002
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1003
+ "Non-promoter |||||||||||| Non-promoter\n",
1004
+ "Non-promoter |||||||||||| Non-promoter\n",
1005
+ "Non-promoter |||||||||||| Non-promoter\n",
1006
+ "Non-promoter |||||||||||| Non-promoter\n",
1007
+ "Binding Sites |||||||||||| Binding Sites\n",
1008
+ "Non-promoter |||||||||||| Non-promoter\n",
1009
+ "Donor Sites |||||||||||| Non-Splice Sites\n",
1010
+ "Donor Sites |||||||||||| Acceptor Sites\n",
1011
+ "Donor Sites |||||||||||| Non-Splice Sites\n",
1012
+ "promoter |||||||||||| promoter\n",
1013
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
1014
+ "promoter |||||||||||| promoter\n",
1015
+ "Non-promoter |||||||||||| Non-promoter\n",
1016
+ "Non-promoter |||||||||||| Non-promoter\n",
1017
+ "Donor Sites |||||||||||| Donor Sites\n",
1018
+ "promoter |||||||||||| Non-promoter\n",
1019
+ "promoter |||||||||||| promoter\n",
1020
+ "promoter |||||||||||| promoter\n",
1021
+ "Binding Sites |||||||||||| Binding Sites\n",
1022
+ "Donor Sites |||||||||||| Donor Sites\n",
1023
+ "Non-promoter |||||||||||| promoter\n",
1024
+ "Donor Sites |||||||||||| Donor Sites\n",
1025
+ "Non-promoter |||||||||||| promoter\n",
1026
+ "Background Sequences |||||||||||| Background Sequences\n",
1027
+ "Non-promoter |||||||||||| promoter\n",
1028
+ "Non-promoter |||||||||||| Non-promoter\n",
1029
+ "promoter |||||||||||| promoter\n",
1030
+ "Non-promoter |||||||||||| Non-promoter\n",
1031
+ "Binding Sites |||||||||||| Binding Sites\n",
1032
+ "Non-promoter |||||||||||| promoter\n",
1033
+ "Donor Sites |||||||||||| Donor Sites\n",
1034
+ "promoter |||||||||||| promoter\n",
1035
+ "promoter |||||||||||| promoter\n",
1036
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
1037
+ "Non-promoter |||||||||||| Non-promoter\n",
1038
+ "promoter |||||||||||| promoter\n",
1039
+ "promoter |||||||||||| promoter\n",
1040
+ "promoter |||||||||||| promoter\n",
1041
+ "Non-promoter |||||||||||| Non-promoter\n",
1042
+ "Non-promoter |||||||||||| promoter\n",
1043
+ "promoter |||||||||||| promoter\n",
1044
+ "Non-promoter |||||||||||| Non-promoter\n",
1045
+ "promoter |||||||||||| promoter\n",
1046
+ "Non-promoter |||||||||||| promoter\n",
1047
+ "promoter |||||||||||| promoter\n",
1048
+ "Donor Sites |||||||||||| Donor Sites\n",
1049
+ "promoter |||||||||||| promoter\n",
1050
+ "Binding Sites |||||||||||| Background Sequences\n",
1051
+ "promoter |||||||||||| promoter\n",
1052
+ "Non-promoter |||||||||||| promoter\n",
1053
+ "promoter |||||||||||| promoter\n",
1054
+ "Non-promoter |||||||||||| Non-promoter\n",
1055
+ "Non-promoter |||||||||||| promoter\n",
1056
+ "promoter |||||||||||| promoter\n",
1057
+ "promoter |||||||||||| Non-promoter\n",
1058
+ "Non-promoter |||||||||||| Non-promoter\n",
1059
+ "promoter |||||||||||| promoter\n",
1060
+ "Donor Sites |||||||||||| Acceptor Sites\n",
1061
+ "Non-promoter |||||||||||| Non-promoter\n",
1062
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
1063
+ "Binding Sites |||||||||||| Background Sequences\n",
1064
+ "promoter |||||||||||| promoter\n",
1065
+ "Donor Sites |||||||||||| Donor Sites\n",
1066
+ "Non-promoter |||||||||||| Non-promoter\n",
1067
+ "Non-promoter |||||||||||| Non-promoter\n",
1068
+ "Non-promoter |||||||||||| promoter\n",
1069
+ "Non-promoter |||||||||||| Non-promoter\n",
1070
+ "Binding Sites |||||||||||| Binding Sites\n",
1071
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1072
+ "Donor Sites |||||||||||| Donor Sites\n",
1073
+ "Donor Sites |||||||||||| Donor Sites\n",
1074
+ "promoter |||||||||||| promoter\n",
1075
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
1076
+ "Non-promoter |||||||||||| Non-promoter\n",
1077
+ "promoter |||||||||||| promoter\n",
1078
+ "promoter |||||||||||| promoter\n",
1079
+ "Donor Sites |||||||||||| Donor Sites\n",
1080
+ "Non-promoter |||||||||||| promoter\n",
1081
+ "Binding Sites |||||||||||| Background Sequences\n",
1082
+ "Background Sequences |||||||||||| Background Sequences\n",
1083
+ "promoter |||||||||||| Non-promoter\n",
1084
+ "promoter |||||||||||| promoter\n",
1085
+ "promoter |||||||||||| promoter\n",
1086
+ "promoter |||||||||||| promoter\n",
1087
+ "Non-promoter |||||||||||| Non-promoter\n",
1088
+ "Non-promoter |||||||||||| Non-promoter\n",
1089
+ "Donor Sites |||||||||||| Donor Sites\n",
1090
+ "Background Sequences |||||||||||| Background Sequences\n",
1091
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
1092
+ "Non-promoter |||||||||||| Non-promoter\n",
1093
+ "promoter |||||||||||| promoter\n",
1094
+ "Non-promoter |||||||||||| Non-promoter\n",
1095
+ "Non-promoter |||||||||||| C promoter\n",
1096
+ "promoter |||||||||||| Non-promoter\n",
1097
+ "promoter |||||||||||| promoter\n",
1098
+ "Non-promoter |||||||||||| Non-promoter\n",
1099
+ "Donor Sites |||||||||||| Donor Sites\n",
1100
+ "Donor Sites |||||||||||| Donor Sites\n",
1101
+ "Donor Sites |||||||||||| Donor Sites\n",
1102
+ "Background Sequences |||||||||||| Background Sequences\n",
1103
+ "promoter |||||||||||| promoter\n",
1104
+ "Non-promoter |||||||||||| Non-promoter\n",
1105
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
1106
+ "Binding Sites |||||||||||| Background Sequences\n",
1107
+ "Non-promoter |||||||||||| Non-promoter\n",
1108
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
1109
+ "Non-promoter |||||||||||| Non-promoter\n",
1110
+ "promoter |||||||||||| promoter\n",
1111
+ "Non-promoter |||||||||||| Non-promoter\n",
1112
+ "promoter |||||||||||| promoter\n",
1113
+ "Non-promoter |||||||||||| promoter\n",
1114
+ "promoter |||||||||||| Non-promoter\n",
1115
+ "Non-promoter |||||||||||| Non-promoter\n",
1116
+ "Binding Sites |||||||||||| Binding Sites\n",
1117
+ "Donor Sites |||||||||||| Donor Sites\n",
1118
+ "Non-promoter |||||||||||| promoter\n",
1119
+ "promoter |||||||||||| promoter\n",
1120
+ "Non-promoter |||||||||||| Non-promoter\n",
1121
+ "Background Sequences |||||||||||| Binding Sites\n",
1122
+ "Binding Sites |||||||||||| Binding Sites\n",
1123
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1124
+ "Non-promoter |||||||||||| Non-promoter\n",
1125
+ "Non-promoter |||||||||||| Non-promoter\n",
1126
+ "Non-promoter |||||||||||| Non-promoter\n",
1127
+ "Donor Sites |||||||||||| Donor Sites\n",
1128
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
1129
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
1130
+ "promoter |||||||||||| promoter\n",
1131
+ "Non-promoter |||||||||||| Non-promoter\n",
1132
+ "Non-promoter |||||||||||| Non-promoter\n",
1133
+ "Donor Sites |||||||||||| Donor Sites\n",
1134
+ "promoter |||||||||||| promoter\n",
1135
+ "promoter |||||||||||| promoter\n",
1136
+ "promoter |||||||||||| promoter\n",
1137
+ "Background Sequences |||||||||||| Background Sequences\n",
1138
+ "promoter |||||||||||| promoter\n",
1139
+ "Donor Sites |||||||||||| Donor Sites\n",
1140
+ "Background Sequences |||||||||||| Background Sequences\n",
1141
+ "Binding Sites |||||||||||| Binding Sites\n",
1142
+ "Non-promoter |||||||||||| promoter\n",
1143
+ "Non-promoter |||||||||||| Non-promoter\n",
1144
+ "promoter |||||||||||| promoter\n",
1145
+ "promoter |||||||||||| promoter\n",
1146
+ "promoter |||||||||||| promoter\n",
1147
+ "Binding Sites |||||||||||| Binding Sites\n",
1148
+ "Background Sequences |||||||||||| Background Sequences\n",
1149
+ "Non-promoter |||||||||||| Non-promoter\n",
1150
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1151
+ "Non-promoter |||||||||||| Non-promoter\n",
1152
+ "Non-promoter |||||||||||| promoter\n",
1153
+ "Background Sequences |||||||||||| Binding Sites\n",
1154
+ "promoter |||||||||||| promoter\n",
1155
+ "Non-promoter |||||||||||| Non-promoter\n",
1156
+ "promoter |||||||||||| Non-promoter\n",
1157
+ "Non-promoter |||||||||||| Non-promoter\n",
1158
+ "Non-promoter |||||||||||| Non-promoter\n",
1159
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1160
+ "Non-promoter |||||||||||| Non-promoter\n",
1161
+ "promoter |||||||||||| promoter\n",
1162
+ "Non-promoter |||||||||||| promoter\n",
1163
+ "Non-promoter |||||||||||| promoter\n",
1164
+ "promoter |||||||||||| promoter\n",
1165
+ "Non-promoter |||||||||||| Non-promoter\n",
1166
+ "Non-promoter |||||||||||| promoter\n",
1167
+ "Non-promoter |||||||||||| Non-promoter\n",
1168
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
1169
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
1170
+ "promoter |||||||||||| Non-promoter\n",
1171
+ "Binding Sites |||||||||||| Background Sequences\n",
1172
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1173
+ "Non-promoter |||||||||||| Non-promoter\n",
1174
+ "Donor Sites |||||||||||| Donor Sites\n",
1175
+ "Non-promoter |||||||||||| Non-promoter\n",
1176
+ "promoter |||||||||||| promoter\n",
1177
+ "Donor Sites |||||||||||| Donor Sites\n",
1178
+ "Donor Sites |||||||||||| Donor Sites\n",
1179
+ "Non-promoter |||||||||||| promoter\n",
1180
+ "Binding Sites |||||||||||| Binding Sites\n",
1181
+ "Non-promoter |||||||||||| Non-promoter\n",
1182
+ "Binding Sites |||||||||||| Binding Sites\n",
1183
+ "Donor Sites |||||||||||| Donor Sites\n",
1184
+ "Background Sequences |||||||||||| Background Sequences\n",
1185
+ "Donor Sites |||||||||||| Donor Sites\n",
1186
+ "Background Sequences |||||||||||| Binding Sites\n",
1187
+ "Binding Sites |||||||||||| Binding Sites\n",
1188
+ "promoter |||||||||||| promoter\n",
1189
+ "promoter |||||||||||| promoter\n",
1190
+ "promoter |||||||||||| promoter\n",
1191
+ "Binding Sites |||||||||||| Binding Sites\n",
1192
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
1193
+ "Non-promoter |||||||||||| Non-promoter\n",
1194
+ "Non-promoter |||||||||||| promoter\n",
1195
+ "promoter |||||||||||| promoter\n",
1196
+ "promoter |||||||||||| promoter\n",
1197
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
1198
+ "Binding Sites |||||||||||| Binding Sites\n",
1199
+ "Background Sequences |||||||||||| Background Sequences\n",
1200
+ "Donor Sites |||||||||||| Donor Sites\n",
1201
+ "Non-promoter |||||||||||| Non-promoter\n",
1202
+ "promoter |||||||||||| promoter\n",
1203
+ "Background Sequences |||||||||||| Background Sequences\n",
1204
+ "Donor Sites |||||||||||| Donor Sites\n",
1205
+ "promoter |||||||||||| promoter\n",
1206
+ "Non-promoter |||||||||||| Non-promoter\n",
1207
+ "Non-promoter |||||||||||| Non-promoter\n",
1208
+ "Non-promoter |||||||||||| Non-promoter\n",
1209
+ "promoter |||||||||||| promoter\n",
1210
+ "Binding Sites |||||||||||| Binding Sites\n",
1211
+ "promoter |||||||||||| Non-promoter\n",
1212
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1213
+ "promoter |||||||||||| promoter\n",
1214
+ "promoter |||||||||||| promoter\n",
1215
+ "Background Sequences |||||||||||| Background Sequences\n",
1216
+ "Background Sequences |||||||||||| Background Sequences\n",
1217
+ "Non-promoter |||||||||||| Non-promoter\n",
1218
+ "Binding Sites |||||||||||| Binding Sites\n",
1219
+ "Background Sequences |||||||||||| Background Sequences\n",
1220
+ "Non-promoter |||||||||||| Non-promoter\n",
1221
+ "Non-promoter |||||||||||| Non-promoter\n",
1222
+ "Donor Sites |||||||||||| Donor Sites\n",
1223
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
1224
+ "Non-promoter |||||||||||| Non-promoter\n",
1225
+ "Binding Sites |||||||||||| Binding Sites\n",
1226
+ "promoter |||||||||||| promoter\n",
1227
+ "Non-promoter |||||||||||| promoter\n",
1228
+ "promoter |||||||||||| Non-promoter\n",
1229
+ "Donor Sites |||||||||||| Donor Sites\n",
1230
+ "Non-promoter |||||||||||| Non-promoter\n",
1231
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1232
+ "Binding Sites |||||||||||| Background Sequences\n",
1233
+ "Background Sequences |||||||||||| Background Sequences\n",
1234
+ "Non-promoter |||||||||||| promoter\n",
1235
+ "Non-promoter |||||||||||| Non-promoter\n",
1236
+ "promoter |||||||||||| promoter\n",
1237
+ "Donor Sites |||||||||||| Donor Sites\n",
1238
+ "promoter |||||||||||| promoter\n",
1239
+ "Donor Sites |||||||||||| Donor Sites\n",
1240
+ "Donor Sites |||||||||||| Donor Sites\n",
1241
+ "promoter |||||||||||| Non-promoter\n",
1242
+ "Binding Sites |||||||||||| Background Sequences\n",
1243
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1244
+ "promoter |||||||||||| promoter\n",
1245
+ "promoter |||||||||||| promoter\n",
1246
+ "Non-promoter |||||||||||| Non-promoter\n",
1247
+ "Non-promoter |||||||||||| Non-promoter\n",
1248
+ "Background Sequences |||||||||||| Binding Sites\n",
1249
+ "Non-promoter |||||||||||| Non-promoter\n",
1250
+ "Non-promoter |||||||||||| promoter\n",
1251
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1252
+ "Non-promoter |||||||||||| Non-promoter\n",
1253
+ "promoter |||||||||||| promoter\n",
1254
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1255
+ "promoter |||||||||||| promoter\n",
1256
+ "Binding Sites |||||||||||| Binding Sites\n",
1257
+ "promoter |||||||||||| promoter\n",
1258
+ "promoter |||||||||||| promoter\n",
1259
+ "Non-promoter |||||||||||| promoter\n",
1260
+ "promoter |||||||||||| Non-promoter\n",
1261
+ "Non-promoter |||||||||||| Non-promoter\n",
1262
+ "promoter |||||||||||| promoter\n",
1263
+ "Donor Sites |||||||||||| Donor Sites\n",
1264
+ "Non-promoter |||||||||||| promoter\n",
1265
+ "Non-promoter |||||||||||| Non-promoter\n",
1266
+ "Donor Sites |||||||||||| Donor Sites\n",
1267
+ "promoter |||||||||||| promoter\n",
1268
+ "promoter |||||||||||| promoter\n",
1269
+ "promoter |||||||||||| promoter\n",
1270
+ "Donor Sites |||||||||||| Donor Sites\n",
1271
+ "Donor Sites |||||||||||| Donor Sites\n",
1272
+ "promoter |||||||||||| promoter\n",
1273
+ "Non-promoter |||||||||||| Non-promoter\n",
1274
+ "promoter |||||||||||| Non-promoter\n",
1275
+ "Non-promoter |||||||||||| Non-promoter\n",
1276
+ "Non-promoter |||||||||||| promoter\n",
1277
+ "promoter |||||||||||| promoter\n",
1278
+ "promoter |||||||||||| promoter\n",
1279
+ "Binding Sites |||||||||||| Background Sequences\n",
1280
+ "Non-promoter |||||||||||| Non-promoter\n",
1281
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1282
+ "Non-promoter |||||||||||| Non-promoter\n",
1283
+ "Non-promoter |||||||||||| Non-promoter\n",
1284
+ "Non-promoter |||||||||||| Non-promoter\n",
1285
+ "promoter |||||||||||| promoter\n",
1286
+ "promoter |||||||||||| promoter\n",
1287
+ "Donor Sites |||||||||||| Donor Sites\n",
1288
+ "Binding Sites |||||||||||| Binding Sites\n",
1289
+ "promoter |||||||||||| Non-promoter\n",
1290
+ "promoter |||||||||||| promoter\n",
1291
+ "Background Sequences |||||||||||| Binding Sites\n",
1292
+ "Non-promoter |||||||||||| Non-promoter\n",
1293
+ "Non-promoter |||||||||||| Non-promoter\n",
1294
+ "promoter |||||||||||| Non-promoter\n",
1295
+ "promoter |||||||||||| promoter\n",
1296
+ "Non-promoter |||||||||||| Non-promoter\n",
1297
+ "Background Sequences |||||||||||| Binding Sites\n",
1298
+ "Binding Sites |||||||||||| Binding Sites\n",
1299
+ "Non-promoter |||||||||||| Non-promoter\n",
1300
+ "Non-promoter |||||||||||| Non-promoter\n",
1301
+ "Binding Sites |||||||||||| Binding Sites\n",
1302
+ "promoter |||||||||||| promoter\n",
1303
+ "Non-promoter |||||||||||| Non-promoter\n",
1304
+ "promoter |||||||||||| Non-promoter\n",
1305
+ "promoter |||||||||||| promoter\n",
1306
+ "Non-promoter |||||||||||| Non-promoter\n",
1307
+ "promoter |||||||||||| Non-promoter\n",
1308
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1309
+ "Binding Sites |||||||||||| Background Sequences\n",
1310
+ "Background Sequences |||||||||||| Background Sequences\n",
1311
+ "promoter |||||||||||| Non-promoter\n",
1312
+ "Donor Sites |||||||||||| Donor Sites\n",
1313
+ "promoter |||||||||||| promoter\n",
1314
+ "Binding Sites |||||||||||| Binding Sites\n",
1315
+ "promoter |||||||||||| promoter\n",
1316
+ "Non-promoter |||||||||||| promoter\n",
1317
+ "Non-promoter |||||||||||| promoter\n",
1318
+ "Background Sequences |||||||||||| Background Sequences\n",
1319
+ "Non-promoter |||||||||||| Non-promoter\n",
1320
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1321
+ "promoter |||||||||||| promoter\n",
1322
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
1323
+ "Donor Sites |||||||||||| Donor Sites\n",
1324
+ "promoter |||||||||||| promoter\n",
1325
+ "Non-promoter |||||||||||| Non-promoter\n",
1326
+ "Non-promoter |||||||||||| promoter\n",
1327
+ "Acceptor Sites |||||||||||| Splice Sites\n",
1328
+ "Binding Sites |||||||||||| Binding Sites\n",
1329
+ "Non-promoter |||||||||||| Non-promoter\n",
1330
+ "promoter |||||||||||| promoter\n",
1331
+ "Binding Sites |||||||||||| Binding Sites\n",
1332
+ "promoter |||||||||||| Non-promoter\n",
1333
+ "Donor Sites |||||||||||| Donor Sites\n",
1334
+ "promoter |||||||||||| promoter\n",
1335
+ "promoter |||||||||||| promoter\n",
1336
+ "Donor Sites |||||||||||| Donor Sites\n",
1337
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
1338
+ "Non-promoter |||||||||||| promoter\n",
1339
+ "Non-promoter |||||||||||| Non-promoter\n",
1340
+ "promoter |||||||||||| promoter\n",
1341
+ "Non-promoter |||||||||||| Non-promoter\n",
1342
+ "Binding Sites |||||||||||| Background Sequences\n",
1343
+ "Non-promoter |||||||||||| Non-promoter\n",
1344
+ "Binding Sites |||||||||||| Binding Sites\n",
1345
+ "promoter |||||||||||| promoter\n",
1346
+ "promoter |||||||||||| promoter\n",
1347
+ "Non-promoter |||||||||||| Non-promoter\n",
1348
+ "Non-promoter |||||||||||| Non-promoter\n",
1349
+ "Donor Sites |||||||||||| Donor Sites\n",
1350
+ "Donor Sites |||||||||||| Donor Sites\n",
1351
+ "Background Sequences |||||||||||| Background Sequences\n",
1352
+ "promoter |||||||||||| promoter\n",
1353
+ "promoter |||||||||||| promoter\n",
1354
+ "Non-promoter |||||||||||| Non-promoter\n",
1355
+ "Binding Sites |||||||||||| Binding Sites\n",
1356
+ "promoter |||||||||||| promoter\n",
1357
+ "Binding Sites |||||||||||| Binding Sites\n",
1358
+ "promoter |||||||||||| promoter\n",
1359
+ "Donor Sites |||||||||||| Donor Sites\n",
1360
+ "promoter |||||||||||| promoter\n",
1361
+ "promoter |||||||||||| promoter\n",
1362
+ "Background Sequences |||||||||||| Background Sequences\n",
1363
+ "Non-promoter |||||||||||| Non-promoter\n",
1364
+ "promoter |||||||||||| promoter\n",
1365
+ "Non-promoter |||||||||||| Non-promoter\n",
1366
+ "Donor Sites |||||||||||| Donor Sites\n",
1367
+ "Background Sequences |||||||||||| Binding Sites\n",
1368
+ "Non-promoter |||||||||||| Non-promoter\n",
1369
+ "Donor Sites |||||||||||| Donor Sites\n",
1370
+ "promoter |||||||||||| promoter\n",
1371
+ "promoter |||||||||||| promoter\n",
1372
+ "promoter |||||||||||| promoter\n",
1373
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1374
+ "Background Sequences |||||||||||| Binding Sites\n",
1375
+ "promoter |||||||||||| Non-promoter\n",
1376
+ "Donor Sites |||||||||||| Donor Sites\n",
1377
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
1378
+ "Non-promoter |||||||||||| Non-promoter\n",
1379
+ "Background Sequences |||||||||||| Background Sequences\n",
1380
+ "promoter |||||||||||| promoter\n",
1381
+ "Non-promoter |||||||||||| promoter\n",
1382
+ "Non-promoter |||||||||||| Non-promoter\n",
1383
+ "promoter |||||||||||| promoter\n",
1384
+ "promoter |||||||||||| promoter\n",
1385
+ "promoter |||||||||||| promoter\n",
1386
+ "promoter |||||||||||| promoter\n",
1387
+ "Non-promoter |||||||||||| Non-promoter\n",
1388
+ "Non-promoter |||||||||||| promoter\n",
1389
+ "Non-promoter |||||||||||| Non-promoter\n",
1390
+ "promoter |||||||||||| Non-promoter\n",
1391
+ "promoter |||||||||||| promoter\n",
1392
+ "Non-promoter |||||||||||| Non-promoter\n",
1393
+ "promoter |||||||||||| promoter\n",
1394
+ "Non-promoter |||||||||||| promoter\n",
1395
+ "promoter |||||||||||| Non-promoter\n",
1396
+ "Non-promoter |||||||||||| promoter\n",
1397
+ "promoter |||||||||||| promoter\n",
1398
+ "Binding Sites |||||||||||| Binding Sites\n",
1399
+ "promoter |||||||||||| promoter\n",
1400
+ "Non-promoter |||||||||||| Non-promoter\n",
1401
+ "promoter |||||||||||| promoter\n",
1402
+ "promoter |||||||||||| Non-promoter\n",
1403
+ "Non-promoter |||||||||||| Non-promoter\n",
1404
+ "Background Sequences |||||||||||| Binding Sites\n",
1405
+ "Donor Sites |||||||||||| Donor Sites\n",
1406
+ "Donor Sites |||||||||||| Donor Sites\n",
1407
+ "Binding Sites |||||||||||| Binding Sites\n",
1408
+ "Non-promoter |||||||||||| promoter\n",
1409
+ "Non-promoter |||||||||||| Non-promoter\n",
1410
+ "Non-promoter |||||||||||| Non-promoter\n",
1411
+ "promoter |||||||||||| promoter\n",
1412
+ "promoter |||||||||||| promoter\n",
1413
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
1414
+ "Non-promoter |||||||||||| Non-promoter\n",
1415
+ "Non-promoter |||||||||||| promoter\n",
1416
+ "promoter |||||||||||| promoter\n",
1417
+ "Donor Sites |||||||||||| Donor Sites\n",
1418
+ "promoter |||||||||||| Non-promoter\n",
1419
+ "Non-promoter |||||||||||| promoter\n",
1420
+ "promoter |||||||||||| promoter\n",
1421
+ "Non-promoter |||||||||||| Non-promoter\n",
1422
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
1423
+ "Non-promoter |||||||||||| Non-promoter\n",
1424
+ "promoter |||||||||||| Non-promoter\n",
1425
+ "Donor Sites |||||||||||| Donor Sites\n",
1426
+ "Non-promoter |||||||||||| Non-promoter\n",
1427
+ "Background Sequences |||||||||||| Background Sequences\n",
1428
+ "promoter |||||||||||| promoter\n",
1429
+ "promoter |||||||||||| promoter\n",
1430
+ "Donor Sites |||||||||||| Donor Sites\n",
1431
+ "Donor Sites |||||||||||| Non-Splice Sites\n",
1432
+ "promoter |||||||||||| promoter\n",
1433
+ "promoter |||||||||||| promoter\n",
1434
+ "promoter |||||||||||| promoter\n",
1435
+ "promoter |||||||||||| promoter\n",
1436
+ "promoter |||||||||||| promoter\n",
1437
+ "promoter |||||||||||| Non-promoter\n",
1438
+ "promoter |||||||||||| promoter\n",
1439
+ "promoter |||||||||||| promoter\n",
1440
+ "promoter |||||||||||| promoter\n",
1441
+ "Background Sequences |||||||||||| Background Sequences\n",
1442
+ "Background Sequences |||||||||||| Background Sequences\n",
1443
+ "promoter |||||||||||| promoter\n",
1444
+ "promoter |||||||||||| promoter\n",
1445
+ "Non-promoter |||||||||||| Non-promoter\n",
1446
+ "Background Sequences |||||||||||| Background Sequences\n",
1447
+ "Non-promoter |||||||||||| Non-promoter\n",
1448
+ "Non-promoter |||||||||||| Non-promoter\n",
1449
+ "Non-promoter |||||||||||| promoter\n",
1450
+ "Non-Splice Sites |||||||||||| Acceptor Sites\n",
1451
+ "promoter |||||||||||| promoter\n",
1452
+ "Non-promoter |||||||||||| promoter\n",
1453
+ "Non-promoter |||||||||||| Non-promoter\n",
1454
+ "Background Sequences |||||||||||| Background Sequences\n",
1455
+ "promoter |||||||||||| Non-promoter\n",
1456
+ "promoter |||||||||||| Non-promoter\n",
1457
+ "Background Sequences |||||||||||| Background Sequences\n",
1458
+ "Background Sequences |||||||||||| Background Sequences\n",
1459
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1460
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
1461
+ "Non-promoter |||||||||||| Non-promoter\n",
1462
+ "Non-promoter |||||||||||| Non-promoter\n",
1463
+ "promoter |||||||||||| promoter\n",
1464
+ "Non-promoter |||||||||||| Non-promoter\n",
1465
+ "promoter |||||||||||| Non-promoter\n",
1466
+ "Binding Sites |||||||||||| Background Sequences\n",
1467
+ "Binding Sites |||||||||||| Binding Sites\n",
1468
+ "Non-promoter |||||||||||| Non-promoter\n",
1469
+ "promoter |||||||||||| promoter\n",
1470
+ "Non-promoter |||||||||||| Non-promoter\n",
1471
+ "promoter |||||||||||| promoter\n",
1472
+ "Binding Sites |||||||||||| Binding Sites\n",
1473
+ "Non-promoter |||||||||||| Non-promoter\n",
1474
+ "Non-promoter |||||||||||| Non-promoter\n",
1475
+ "promoter |||||||||||| promoter\n",
1476
+ "Non-promoter |||||||||||| Non-promoter\n",
1477
+ "Binding Sites |||||||||||| Background Sequences\n",
1478
+ "Donor Sites |||||||||||| Donor Sites\n",
1479
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1480
+ "Non-promoter |||||||||||| promoter\n",
1481
+ "Non-promoter |||||||||||| Non-promoter\n",
1482
+ "promoter |||||||||||| promoter\n",
1483
+ "Background Sequences |||||||||||| Background Sequences\n",
1484
+ "Donor Sites |||||||||||| Donor Sites\n",
1485
+ "Non-promoter |||||||||||| promoter\n",
1486
+ "promoter |||||||||||| promoter\n",
1487
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
1488
+ "Binding Sites |||||||||||| Binding Sites\n",
1489
+ "Non-promoter |||||||||||| promoter\n",
1490
+ "Donor Sites |||||||||||| Donor Sites\n",
1491
+ "promoter |||||||||||| promoter\n",
1492
+ "promoter |||||||||||| promoter\n",
1493
+ "Non-promoter |||||||||||| promoter\n",
1494
+ "Non-promoter |||||||||||| Non-promoter\n",
1495
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
1496
+ "Non-promoter |||||||||||| Non-promoter\n",
1497
+ "Background Sequences |||||||||||| Background Sequences\n",
1498
+ "promoter |||||||||||| Non-promoter\n",
1499
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
1500
+ "Donor Sites |||||||||||| Donor Sites\n",
1501
+ "promoter |||||||||||| promoter\n",
1502
+ "Binding Sites |||||||||||| Binding Sites\n",
1503
+ "promoter |||||||||||| promoter\n",
1504
+ "Donor Sites |||||||||||| Donor Sites\n",
1505
+ "Donor Sites |||||||||||| Acceptor Sites\n",
1506
+ "promoter |||||||||||| promoter\n",
1507
+ "Non-promoter |||||||||||| Non-promoter\n",
1508
+ "promoter |||||||||||| Non-promoter\n",
1509
+ "Binding Sites |||||||||||| Binding Sites\n",
1510
+ "Non-promoter |||||||||||| Non-promoter\n",
1511
+ "Non-promoter |||||||||||| promoter\n",
1512
+ "Non-promoter |||||||||||| Non-promoter\n",
1513
+ "Non-promoter |||||||||||| Non-promoter\n",
1514
+ "Non-promoter |||||||||||| Non-promoter\n",
1515
+ "Non-promoter |||||||||||| promoter\n",
1516
+ "promoter |||||||||||| promoter\n",
1517
+ "Background Sequences |||||||||||| Binding Sites\n",
1518
+ "Non-promoter |||||||||||| promoter\n",
1519
+ "Donor Sites |||||||||||| Non-Splice Sites\n",
1520
+ "Donor Sites |||||||||||| Donor Sites\n",
1521
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
1522
+ "Non-promoter |||||||||||| Non-promoter\n",
1523
+ "promoter |||||||||||| Non-promoter\n",
1524
+ "Non-promoter |||||||||||| promoter\n",
1525
+ "promoter |||||||||||| Non-promoter\n",
1526
+ "promoter |||||||||||| promoter\n",
1527
+ "promoter |||||||||||| promoter\n",
1528
+ "Donor Sites |||||||||||| Donor Sites\n",
1529
+ "promoter |||||||||||| promoter\n",
1530
+ "Donor Sites |||||||||||| Donor Sites\n",
1531
+ "Non-promoter |||||||||||| Non-promoter\n",
1532
+ "Donor Sites |||||||||||| Donor Sites\n",
1533
+ "Non-promoter |||||||||||| promoter\n",
1534
+ "Donor Sites |||||||||||| Acceptor Sites\n",
1535
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1536
+ "Non-promoter |||||||||||| Non-promoter\n",
1537
+ "promoter |||||||||||| Non-promoter\n",
1538
+ "promoter |||||||||||| promoter\n",
1539
+ "Non-promoter |||||||||||| Non-promoter\n",
1540
+ "Non-promoter |||||||||||| Non-promoter\n",
1541
+ "promoter |||||||||||| Non-promoter\n",
1542
+ "promoter |||||||||||| promoter\n",
1543
+ "promoter |||||||||||| promoter\n",
1544
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1545
+ "Non-promoter |||||||||||| Non-promoter\n",
1546
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1547
+ "promoter |||||||||||| promoter\n",
1548
+ "promoter |||||||||||| promoter\n",
1549
+ "Background Sequences |||||||||||| Background Sequences\n",
1550
+ "Binding Sites |||||||||||| Binding Sites\n",
1551
+ "Donor Sites |||||||||||| Donor Sites\n",
1552
+ "Binding Sites |||||||||||| Binding Sites\n",
1553
+ "Non-promoter |||||||||||| Non-promoter\n",
1554
+ "promoter |||||||||||| promoter\n",
1555
+ "Background Sequences |||||||||||| Binding Sites\n",
1556
+ "Non-promoter |||||||||||| Non-promoter\n",
1557
+ "Background Sequences |||||||||||| Background Sequences\n",
1558
+ "promoter |||||||||||| promoter\n",
1559
+ "Non-promoter |||||||||||| Non-promoter\n",
1560
+ "presicion 0.739 same 0.253\n"
1561
+ ]
1562
+ }
1563
+ ],
1564
+ "source": [
1565
+ "import json\n",
1566
+ "from tqdm import tqdm\n",
1567
+ "\n",
1568
+ "\n",
1569
+ "\n",
1570
+ "with open(output_file, \"r\") as file:\n",
1571
+ " test_data = json.load(file)\n",
1572
+ "\n",
1573
+ "all_num = len(test_data)\n",
1574
+ "right_sum = 0\n",
1575
+ "same_sum = 0\n",
1576
+ "for item in test_data:\n",
1577
+ " output = item[\"output\"]\n",
1578
+ " #output = \" \".join(tokenizer.tokenize(output))\n",
1579
+ " model_response = item[\"model_response\"]\n",
1580
+ "\n",
1581
+ " print(output,\"||||||||||||\", model_response)\n",
1582
+ "\n",
1583
+ " if model_response == output: #same it\n",
1584
+ " same_sum = same_sum + 1\n",
1585
+ " \n",
1586
+ " if output.find(\"Non\")==-1: # no Non\n",
1587
+ " if model_response.find(output)!=-1 and model_response.find(\"Non\")==-1: #find it, but no Non\n",
1588
+ " right_sum = right_sum + 1\n",
1589
+ " else:\n",
1590
+ " if model_response.find(output)!=-1: #find it\n",
1591
+ " right_sum = right_sum + 1\n",
1592
+ "\n",
1593
+ "\n",
1594
+ "print(\"presicion\", right_sum/all_num, \"same\", same_sum/all_num)\n"
1595
+ ]
1596
+ },
1597
+ {
1598
+ "cell_type": "code",
1599
+ "execution_count": null,
1600
+ "id": "294d46f3-2f5b-4e55-ae41-081d5195f5e2",
1601
+ "metadata": {},
1602
+ "outputs": [],
1603
+ "source": []
1604
+ }
1605
+ ],
1606
+ "metadata": {
1607
+ "kernelspec": {
1608
+ "display_name": "Python 3 (ipykernel)",
1609
+ "language": "python",
1610
+ "name": "python3"
1611
+ },
1612
+ "language_info": {
1613
+ "codemirror_mode": {
1614
+ "name": "ipython",
1615
+ "version": 3
1616
+ },
1617
+ "file_extension": ".py",
1618
+ "mimetype": "text/x-python",
1619
+ "name": "python",
1620
+ "nbconvert_exporter": "python",
1621
+ "pygments_lexer": "ipython3",
1622
+ "version": "3.12.3"
1623
+ }
1624
+ },
1625
+ "nbformat": 4,
1626
+ "nbformat_minor": 5
1627
+ }
04-gene-sft/.ipynb_checkpoints/merge_pt_model-checkpoint.sh ADDED
@@ -0,0 +1,6 @@
1
+ #!/bin/sh
2
+ python merge_llama_with_dna_lora.py \
3
+ --base_model llama-7b-hf \
4
+ --lora_model dnahlm_llama_7b/pt_lora_model \
5
+ --output_type huggingface \
6
+ --output_dir dnahlm-merge-hf
04-gene-sft/.ipynb_checkpoints/merge_sft_model-checkpoint.sh ADDED
@@ -0,0 +1,6 @@
1
+ #!/bin/sh
2
+ python merge_llama_with_dna_lora.py \
3
+ --base_model dnahlm-merge-hf \
4
+ --lora_model dnahlm-llama7b-sft/sft_lora_model \
5
+ --output_type huggingface \
6
+ --output_dir dnahlm-llama-7b-sft-v0
04-gene-sft/.ipynb_checkpoints/run_clm_pt_with_peft-checkpoint.py ADDED
@@ -0,0 +1,637 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2020 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
18
+
19
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
20
+ https://huggingface.co/models?filter=text-generation
21
+ """
22
+ # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
23
+
24
+ import logging
25
+ import numpy as np
26
+ import math
27
+ import os
28
+ import sys
29
+ from dataclasses import dataclass, field
30
+ from itertools import chain
31
+ from typing import Optional, List, Dict, Any, Mapping
32
+ from pathlib import Path
33
+ import datasets
34
+ import torch
35
+ from datasets import load_dataset, concatenate_datasets
36
+
37
+ import transformers
38
+ from transformers import (
39
+ CONFIG_MAPPING,
40
+ MODEL_FOR_CAUSAL_LM_MAPPING,
41
+ AutoConfig,
42
+ AutoModelForCausalLM,
43
+ LlamaForCausalLM,
44
+ LlamaTokenizer,
45
+ AutoTokenizer,
46
+ HfArgumentParser,
47
+ Trainer,
48
+ TrainingArguments,
49
+ is_torch_tpu_available,
50
+ set_seed,
51
+ )
52
+ from transformers.testing_utils import CaptureLogger
53
+ from transformers.trainer_utils import get_last_checkpoint
54
+ from transformers.utils import send_example_telemetry
55
+ from transformers.utils.versions import require_version
56
+
57
+ from sklearn.metrics import accuracy_score
58
+ from peft import LoraConfig, TaskType, get_peft_model, PeftModel, get_peft_model_state_dict
59
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
60
+
61
+
62
+ class SavePeftModelCallback(transformers.TrainerCallback):
63
+ def save_model(self, args, state, kwargs):
64
+ if state.best_model_checkpoint is not None:
65
+ checkpoint_folder = os.path.join(state.best_model_checkpoint, "pt_lora_model")
66
+ else:
67
+ checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
68
+
69
+ peft_model_path = os.path.join(checkpoint_folder, "pt_lora_model")
70
+ kwargs["model"].save_pretrained(peft_model_path)
71
+
72
+ if "tokenizer" in kwargs:
73
+ kwargs["tokenizer"].save_pretrained(peft_model_path)
74
+ else:
75
+ kwargs["processing_class"].save_pretrained(peft_model_path)
76
+
77
+ def on_save(self, args, state, control, **kwargs):
78
+ self.save_model(args, state, kwargs)
79
+ return control
80
+
81
+ def on_train_end(self, args, state, control, **kwargs):
82
+ peft_model_path = os.path.join(args.output_dir, "pt_lora_model")
83
+ kwargs["model"].save_pretrained(peft_model_path)
84
+
85
+ if "tokenizer" in kwargs:
86
+ kwargs["tokenizer"].save_pretrained(peft_model_path)
87
+ else:
88
+ kwargs["processing_class"].save_pretrained(peft_model_path)
89
+
90
+
91
+ def accuracy(predictions, references, normalize=True, sample_weight=None):
92
+ return {
93
+ "accuracy": float(
94
+ accuracy_score(references, predictions, normalize=normalize, sample_weight=sample_weight)
95
+ )
96
+ }
97
+
98
+
99
+ def compute_metrics(eval_preds):
100
+ preds, labels = eval_preds
101
+ # preds have the same shape as the labels, after the argmax(-1) has been calculated
102
+ # by preprocess_logits_for_metrics but we need to shift the labels
103
+ labels = labels[:, 1:].reshape(-1)
104
+ preds = preds[:, :-1].reshape(-1)
105
+ return accuracy(predictions=preds, references=labels)
106
+
107
+
108
+ def preprocess_logits_for_metrics(logits, labels):
109
+ if isinstance(logits, tuple):
110
+ # Depending on the model and config, logits may contain extra tensors,
111
+ # like past_key_values, but logits always come first
112
+ logits = logits[0]
113
+ return logits.argmax(dim=-1)
114
+
115
+
116
+ def fault_tolerance_data_collator(features: List) -> Dict[str, Any]:
117
+ if not isinstance(features[0], Mapping):
118
+ features = [vars(f) for f in features]
119
+ first = features[0]
120
+ batch = {}
121
+
122
+ # Special handling for labels.
123
+ # Ensure that tensor is created with the correct type
124
+ # (it should be automatically the case, but let's make sure of it.)
125
+ if "label" in first and first["label"] is not None:
126
+ label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
127
+ dtype = torch.long if isinstance(label, int) else torch.float
128
+ batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
129
+ elif "label_ids" in first and first["label_ids"] is not None:
130
+ if isinstance(first["label_ids"], torch.Tensor):
131
+ batch["labels"] = torch.stack([f["label_ids"] for f in features])
132
+ else:
133
+ dtype = torch.long if isinstance(first["label_ids"][0], int) else torch.float
134
+ batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)
135
+
136
+ # Handling of all other possible keys.
137
+ # Again, we will use the first element to figure out which key/values are not None for this model.
138
+
139
+ try:
140
+ for k, v in first.items():
141
+ if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
142
+ if isinstance(v, torch.Tensor):
143
+ batch[k] = torch.stack([f[k] for f in features])
144
+ elif isinstance(v, np.ndarray):
145
+ batch[k] = torch.tensor(np.stack([f[k] for f in features]))
146
+ else:
147
+ batch[k] = torch.tensor([f[k] for f in features])
148
+ except ValueError: # quick fix by simply take the first example
149
+ for k, v in first.items():
150
+ if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
151
+ if isinstance(v, torch.Tensor):
152
+ batch[k] = torch.stack([features[0][k]] * len(features))
153
+ elif isinstance(v, np.ndarray):
154
+ batch[k] = torch.tensor(np.stack([features[0][k]] * len(features)))
155
+ else:
156
+ batch[k] = torch.tensor([features[0][k]] * len(features))
157
+
158
+ return batch
159
+
160
+
161
+ MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
162
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
163
+
164
+
165
+ @dataclass
166
+ class ModelArguments:
167
+ """
168
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
169
+ """
170
+
171
+ model_name_or_path: Optional[str] = field(
172
+ default=None,
173
+ metadata={
174
+ "help": (
175
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
176
+ )
177
+ },
178
+ )
179
+ tokenizer_name_or_path: Optional[str] = field(
180
+ default=None,
181
+ metadata={
182
+ "help": (
183
+ "The tokenizer for weights initialization.Don't set if you want to train a model from scratch."
184
+ )
185
+ },
186
+ )
187
+ model_type: Optional[str] = field(
188
+ default=None,
189
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
190
+ )
191
+ config_overrides: Optional[str] = field(
192
+ default=None,
193
+ metadata={
194
+ "help": (
195
+ "Override some existing default config settings when a model is trained from scratch. Example: "
196
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
197
+ )
198
+ },
199
+ )
200
+ config_name: Optional[str] = field(
201
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
202
+ )
203
+ tokenizer_name: Optional[str] = field(
204
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
205
+ )
206
+ cache_dir: Optional[str] = field(
207
+ default=None,
208
+ metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
209
+ )
210
+ use_fast_tokenizer: bool = field(
211
+ default=True,
212
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
213
+ )
214
+ model_revision: str = field(
215
+ default="main",
216
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
217
+ )
218
+ use_auth_token: bool = field(
219
+ default=False,
220
+ metadata={
221
+ "help": (
222
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
223
+ "with private models)."
224
+ )
225
+ },
226
+ )
227
+ torch_dtype: Optional[str] = field(
228
+ default=None,
229
+ metadata={
230
+ "help": (
231
+ "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
232
+ "dtype will be automatically derived from the model's weights."
233
+ ),
234
+ "choices": ["auto", "bfloat16", "float16", "float32"],
235
+ },
236
+ )
237
+
238
+ def __post_init__(self):
239
+ if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
240
+ raise ValueError(
241
+ "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
242
+ )
243
+
244
+
245
+ @dataclass
246
+ class DataTrainingArguments:
247
+ """
248
+ Arguments pertaining to what data we are going to input our model for training and eval.
249
+ """
250
+
251
+ dataset_dir: Optional[str] = field(
252
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
253
+ )
254
+ dataset_config_name: Optional[str] = field(
255
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
256
+ )
257
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
258
+ validation_file: Optional[str] = field(
259
+ default=None,
260
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
261
+ )
262
+ max_train_samples: Optional[int] = field(
263
+ default=None,
264
+ metadata={
265
+ "help": (
266
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
267
+ "value if set."
268
+ )
269
+ },
270
+ )
271
+ max_eval_samples: Optional[int] = field(
272
+ default=None,
273
+ metadata={
274
+ "help": (
275
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
276
+ "value if set."
277
+ )
278
+ },
279
+ )
280
+ streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
281
+ block_size: Optional[int] = field(
282
+ default=None,
283
+ metadata={
284
+ "help": (
285
+ "Optional input sequence length after tokenization. "
286
+ "The training dataset will be truncated in block of this size for training. "
287
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
288
+ )
289
+ },
290
+ )
291
+ overwrite_cache: bool = field(
292
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
293
+ )
294
+ validation_split_percentage: Optional[float] = field(
295
+ default=0.05,
296
+ metadata={
297
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
298
+ },
299
+ )
300
+ preprocessing_num_workers: Optional[int] = field(
301
+ default=None,
302
+ metadata={"help": "The number of processes to use for the preprocessing."},
303
+ )
304
+ keep_linebreaks: bool = field(
305
+ default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
306
+ )
307
+ data_cache_dir: Optional[str] = field(default="./", metadata={"help": "The datasets processed stored"})
308
+
309
+ def __post_init__(self):
310
+ if self.streaming:
311
+ require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")
312
+
313
+
314
+ @dataclass
315
+ class MyTrainingArguments(TrainingArguments):
316
+ trainable : Optional[str] = field(default="q_proj,v_proj")
317
+ lora_rank : Optional[int] = field(default=8)
318
+ lora_dropout : Optional[float] = field(default=0.1)
319
+ lora_alpha : Optional[float] = field(default=32.)
320
+ modules_to_save : Optional[str] = field(default=None)
321
+ debug_mode : Optional[bool] = field(default=False)
322
+ peft_path : Optional[str] = field(default=None)
323
+
324
+
325
+ logger = logging.getLogger(__name__)
326
+
327
+
328
+ def main():
329
+
330
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, MyTrainingArguments))
331
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
332
+ # If we pass only one argument to the script and it's the path to a json file,
333
+ # let's parse it to get our arguments.
334
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
335
+ else:
336
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
337
+
338
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
339
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
340
+ send_example_telemetry("run_clm", model_args, data_args)
341
+
342
+ # Setup logging
343
+ logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",datefmt="%m/%d/%Y %H:%M:%S",
344
+ level=logging.INFO, # if training_args.local_rank in [-1, 0] else logging.WARN,
345
+ handlers=[logging.StreamHandler(sys.stdout)],)
346
+
347
+
348
+ if training_args.should_log:
349
+ # The default of training_args.log_level is passive, so we set log level at info here to have that default.
350
+ transformers.utils.logging.set_verbosity_info()
351
+
352
+ log_level = training_args.get_process_log_level()
353
+ logger.setLevel(log_level)
354
+ datasets.utils.logging.set_verbosity(log_level)
355
+ transformers.utils.logging.set_verbosity(log_level)
356
+ transformers.utils.logging.enable_default_handler()
357
+ transformers.utils.logging.enable_explicit_format()
358
+ # transformers.tokenization_utils.logging.set_verbosity_warning()
359
+
360
+ # Log on each process the small summary:
361
+ logger.warning(
362
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
363
+ + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
364
+ )
365
+
366
+ # Detecting last checkpoint.
367
+ last_checkpoint = None
368
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
369
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
370
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
371
+ raise ValueError(
372
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
373
+ "Use --overwrite_output_dir to overcome."
374
+ )
375
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
376
+ logger.info(
377
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
378
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
379
+ )
380
+
381
+ # Set seed before initializing model.
382
+ set_seed(training_args.seed)
383
+
384
+ config_kwargs = {
385
+ "cache_dir": model_args.cache_dir,
386
+ "revision": model_args.model_revision,
387
+ "use_auth_token": True if model_args.use_auth_token else None,
388
+ }
389
+ if model_args.config_name:
390
+ config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
391
+ elif model_args.model_name_or_path:
392
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
393
+ else:
394
+ config = CONFIG_MAPPING[model_args.model_type]()
395
+ logger.warning("You are instantiating a new config instance from scratch.")
396
+ if model_args.config_overrides is not None:
397
+ logger.info(f"Overriding config: {model_args.config_overrides}")
398
+ config.update_from_string(model_args.config_overrides)
399
+ logger.info(f"New config: {config}")
400
+
401
+ tokenizer_kwargs = {
402
+ "cache_dir": model_args.cache_dir,
403
+ "use_fast": model_args.use_fast_tokenizer,
404
+ "revision": model_args.model_revision,
405
+ "use_auth_token": True if model_args.use_auth_token else None,
406
+ }
407
+ if model_args.tokenizer_name:
408
+ tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
409
+ elif model_args.tokenizer_name_or_path:
410
+ tokenizer = LlamaTokenizer.from_pretrained(model_args.tokenizer_name_or_path, **tokenizer_kwargs)
411
+ else:
412
+ raise ValueError(
413
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
414
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
415
+ )
416
+
417
+ # Preprocessing the datasets.
418
+ # First we tokenize all the texts.
419
+ # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
420
+ tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
421
+
422
+ def tokenize_function(examples):
423
+ with CaptureLogger(tok_logger) as cl:
424
+ output = tokenizer(examples["text"])
425
+ # clm input could be much much longer than block_size
426
+ if "Token indices sequence length is longer than the" in cl.out:
427
+ tok_logger.warning(
428
+ "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
429
+ " before being passed to the model."
430
+ )
431
+ return output
432
+ if data_args.block_size is None:
433
+ block_size = tokenizer.model_max_length
434
+ if block_size > 1024:
435
+ logger.warning(
436
+ "The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value"
437
+ " of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can"
438
+ " override this default with `--block_size xxx`."
439
+ )
440
+ block_size = 1024
441
+ else:
442
+ if data_args.block_size > tokenizer.model_max_length:
443
+ logger.warning(
444
+ f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
445
+ f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
446
+ )
447
+ block_size = min(data_args.block_size, tokenizer.model_max_length)
448
+
449
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
450
+ def group_texts(examples):
451
+ # Concatenate all texts.
452
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
453
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
454
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
455
+ # customize this part to your needs.
456
+ if total_length >= block_size:
457
+ total_length = (total_length // block_size) * block_size
458
+ # Split by chunks of max_len.
459
+ result = {
460
+ k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
461
+ for k, t in concatenated_examples.items()
462
+ }
463
+ result["labels"] = result["input_ids"].copy()
464
+ return result
465
+ with training_args.main_process_first(desc="dataset map tokenization and grouping"):
466
+ lm_datasets = []
467
+ path = Path(data_args.dataset_dir)
468
+ files = [file.name for file in path.glob("*.txt")]
469
+ if training_args.debug_mode is True:
470
+ files = [files[0]]
471
+ for idx, file in enumerate(files):
472
+ data_file = os.path.join(path, file)
473
+ filename = ''.join(file.split(".")[:-1])
474
+ cache_path = os.path.join(data_args.data_cache_dir, filename)
475
+ os.makedirs(cache_path, exist_ok=True)
476
+ try:
477
+ processed_dataset = datasets.load_from_disk(cache_path, keep_in_memory=False)
478
+ logger.info(f'training datasets-{filename} has been loaded from disk')
479
+ except Exception:
480
+ cache_dir = os.path.join(data_args.data_cache_dir, filename+"_text")
481
+ os.makedirs(cache_dir, exist_ok=True)
482
+ raw_dataset = load_dataset("text", data_files=data_file, cache_dir=cache_dir, keep_in_memory=False)
483
+ logger.info(f"{file} has been loaded")
484
+ tokenized_dataset = raw_dataset.map(
485
+ tokenize_function,
486
+ batched=True,
487
+ num_proc=data_args.preprocessing_num_workers,
488
+ remove_columns="text",
489
+ load_from_cache_file=True,
490
+ keep_in_memory=False,
491
+ cache_file_names = {k: os.path.join(cache_dir, 'tokenized.arrow') for k in raw_dataset},
492
+ desc="Running tokenizer on dataset",
493
+ )
494
+ grouped_datasets = tokenized_dataset.map(
495
+ group_texts,
496
+ batched=True,
497
+ num_proc=data_args.preprocessing_num_workers,
498
+ load_from_cache_file=True,
499
+ keep_in_memory=False,
500
+ cache_file_names = {k: os.path.join(cache_dir, 'grouped.arrow') for k in tokenized_dataset},
501
+ desc=f"Grouping texts in chunks of {block_size}",
502
+ )
503
+ processed_dataset = grouped_datasets
504
+ processed_dataset.save_to_disk(cache_path)
505
+ if idx == 0:
506
+ lm_datasets = processed_dataset['train']
507
+ else:
508
+ assert lm_datasets.features.type == processed_dataset["train"].features.type
509
+ lm_datasets = concatenate_datasets([lm_datasets, processed_dataset["train"]])
510
+
511
+ lm_datasets = lm_datasets.train_test_split(test_size = data_args.validation_split_percentage)
512
+
513
+ if training_args.do_train:
514
+ train_dataset = lm_datasets['train']
515
+ if data_args.max_train_samples is not None:
516
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
517
+ train_dataset = train_dataset.select(range(max_train_samples))
518
+ logger.info(f"Num train_samples {len(train_dataset)}")
519
+ logger.info("training example:")
520
+ logger.info(tokenizer.decode(train_dataset[0]['input_ids']))
521
+ if training_args.do_eval:
522
+ eval_dataset = lm_datasets["test"]
523
+ if data_args.max_eval_samples is not None:
524
+ max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
525
+ eval_dataset = eval_dataset.select(range(max_eval_samples))
526
+ logger.info(f"Num eval_samples {len(eval_dataset)}")
527
+ logger.info("training example:")
528
+ logger.info(tokenizer.decode(eval_dataset[0]['input_ids']))
529
+
530
+
531
+
532
+ if model_args.model_name_or_path:
533
+ torch_dtype = (
534
+ model_args.torch_dtype
535
+ if model_args.torch_dtype in ["auto", None]
536
+ else getattr(torch, model_args.torch_dtype)
537
+ )
538
+ model = LlamaForCausalLM.from_pretrained(
539
+ model_args.model_name_or_path,
540
+ from_tf=bool(".ckpt" in model_args.model_name_or_path),
541
+ config=config,
542
+ cache_dir=model_args.cache_dir,
543
+ revision=model_args.model_revision,
544
+ use_auth_token=True if model_args.use_auth_token else None,
545
+ torch_dtype=torch_dtype,
546
+ low_cpu_mem_usage=True
547
+ )
548
+ else:
549
+ model = AutoModelForCausalLM.from_config(config)
550
+ n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
551
+ logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
552
+
553
+ model_vocab_size = model.get_output_embeddings().weight.size(0)
554
+
555
+ model.resize_token_embeddings(len(tokenizer))
556
+ if training_args.peft_path is not None:
557
+ logger.info("Peft from pre-trained model")
558
+ model = PeftModel.from_pretrained(model, training_args.peft_path)
559
+ else:
560
+ logger.info("Init new peft model")
561
+ target_modules = training_args.trainable.split(',')
562
+ modules_to_save = training_args.modules_to_save
563
+ if modules_to_save is not None:
564
+ modules_to_save = modules_to_save.split(',')
565
+ lora_rank = training_args.lora_rank
566
+ lora_dropout = training_args.lora_dropout
567
+ lora_alpha = training_args.lora_alpha
568
+ logger.info(f"target_modules: {target_modules}")
569
+ logger.info(f"lora_rank: {lora_rank}")
570
+ peft_config = LoraConfig(
571
+ task_type=TaskType.CAUSAL_LM,
572
+ target_modules=target_modules,
573
+ inference_mode=False,
574
+ r=lora_rank, lora_alpha=lora_alpha,
575
+ lora_dropout=lora_dropout,
576
+ modules_to_save=modules_to_save)
577
+ model = get_peft_model(model, peft_config)
578
+ model.print_trainable_parameters()
579
+ old_state_dict = model.state_dict
580
+ model.state_dict = (
581
+ lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
582
+ ).__get__(model, type(model))
583
+
584
+ # Initialize our Trainer
585
+ trainer = Trainer(
586
+ model=model,
587
+ args=training_args,
588
+ train_dataset=train_dataset if training_args.do_train else None,
589
+ eval_dataset=eval_dataset if training_args.do_eval else None,
590
+ tokenizer=tokenizer,
591
+ data_collator=fault_tolerance_data_collator,
592
+ compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
593
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics
594
+ if training_args.do_eval and not is_torch_tpu_available()
595
+ else None,
596
+ )
597
+ trainer.add_callback(SavePeftModelCallback)
598
+ # Training
599
+ if training_args.do_train:
600
+ checkpoint = None
601
+ if training_args.resume_from_checkpoint is not None:
602
+ checkpoint = training_args.resume_from_checkpoint
603
+ elif last_checkpoint is not None:
604
+ checkpoint = last_checkpoint
605
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
606
+
607
+ metrics = train_result.metrics
608
+
609
+ max_train_samples = (
610
+ data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
611
+ )
612
+ metrics["train_samples"] = min(max_train_samples, len(train_dataset))
613
+
614
+ trainer.log_metrics("train", metrics)
615
+ trainer.save_metrics("train", metrics)
616
+ trainer.save_state()
617
+
618
+ # Evaluation
619
+ if training_args.do_eval:
620
+ logger.info("*** Evaluate ***")
621
+
622
+ metrics = trainer.evaluate()
623
+
624
+ max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
625
+ metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
626
+ try:
627
+ perplexity = math.exp(metrics["eval_loss"])
628
+ except OverflowError:
629
+ perplexity = float("inf")
630
+ metrics["perplexity"] = perplexity
631
+
632
+ trainer.log_metrics("eval", metrics)
633
+ trainer.save_metrics("eval", metrics)
634
+
635
+
636
+ if __name__ == "__main__":
637
+ main()
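The callback above writes only the LoRA adapter and tokenizer into each `pt_lora_model` folder, so the base weights have to be re-attached before use. A minimal sketch of loading such an adapter and folding it into the dense weights — directory names follow run_pt.sh / run_sft.sh but are otherwise assumptions, and `merge_llama_with_dna_lora.py` / `merge_pt_model.sh` in this folder remain the authoritative merge route:

```python
# Sketch: reload the adapter saved by SavePeftModelCallback and merge it.
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import PeftModel

base = LlamaForCausalLM.from_pretrained(
    "llama-7b-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True
)
tokenizer = LlamaTokenizer.from_pretrained("dnahlm_llama_7b/pt_lora_model")
# The embedding was resized during training (expanded DNA/English vocab),
# so resize again before the adapter weights are loaded.
base.resize_token_embeddings(len(tokenizer))

model = PeftModel.from_pretrained(base, "dnahlm_llama_7b/pt_lora_model")
merged = model.merge_and_unload()           # fold LoRA deltas into dense weights
merged.save_pretrained("dnahlm-merge-hf")   # now loadable as a plain HF checkpoint
tokenizer.save_pretrained("dnahlm-merge-hf")
```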
04-gene-sft/.ipynb_checkpoints/run_clm_sft_with_peft-checkpoint.py ADDED
@@ -0,0 +1,449 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2020 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
18
+
19
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
20
+ https://huggingface.co/models?filter=text-generation
21
+ """
22
+ # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
23
+
24
+ import logging
25
+ import math
26
+ import os
27
+ import sys
28
+ from dataclasses import dataclass, field
29
+ from typing import Optional
30
+ from pathlib import Path
31
+ import datasets
32
+ import torch
33
+ from build_dataset import build_instruction_dataset, DataCollatorForSupervisedDataset
34
+ import transformers
35
+ from transformers import (
36
+ CONFIG_MAPPING,
37
+ AutoConfig,
38
+ AutoModelForCausalLM,
39
+ LlamaForCausalLM,
40
+ LlamaTokenizer,
41
+ AutoTokenizer,
42
+ HfArgumentParser,
43
+ Trainer,
44
+ TrainingArguments,
45
+ set_seed,
46
+ )
47
+ from transformers.trainer_utils import get_last_checkpoint
48
+ from transformers.utils import send_example_telemetry
49
+ from transformers.utils.versions import require_version
50
+
51
+ from peft import LoraConfig, TaskType, get_peft_model, PeftModel, get_peft_model_state_dict
52
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
53
+
54
+ IGNORE_INDEX = -100
55
+ DEFAULT_PAD_TOKEN = "[PAD]"
56
+ DEFAULT_EOS_TOKEN = "</s>"
57
+ DEFAULT_BOS_TOKEN = "<s>"
58
+ DEFAULT_UNK_TOKEN = "<unk>"
59
+
60
+ require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
61
+
62
+
63
+ class SavePeftModelCallback(transformers.TrainerCallback):
64
+ def save_model(self, args, state, kwargs):
65
+ if state.best_model_checkpoint is not None:
66
+ checkpoint_folder = os.path.join(state.best_model_checkpoint, "sft_lora_model")
67
+ else:
68
+ checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
69
+
70
+ peft_model_path = os.path.join(checkpoint_folder, "sft_lora_model")
71
+ kwargs["model"].save_pretrained(peft_model_path)
72
+
73
+ if "tokenizer" in kwargs:
74
+ kwargs["tokenizer"].save_pretrained(peft_model_path)
75
+ else:
76
+ kwargs["processing_class"].save_pretrained(peft_model_path)
77
+
78
+
79
+ def on_save(self, args, state, control, **kwargs):
80
+ self.save_model(args, state, kwargs)
81
+ return control
82
+
83
+ def on_train_end(self, args, state, control, **kwargs):
84
+ peft_model_path = os.path.join(args.output_dir, "sft_lora_model")
85
+ kwargs["model"].save_pretrained(peft_model_path)
86
+
87
+ if "tokenizer" in kwargs:
88
+ kwargs["tokenizer"].save_pretrained(peft_model_path)
89
+ else:
90
+ kwargs["processing_class"].save_pretrained(peft_model_path)
91
+
92
+
93
+
94
+ @dataclass
95
+ class ModelArguments:
96
+ """
97
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
98
+ """
99
+
100
+ model_name_or_path: Optional[str] = field(
101
+ default=None,
102
+ metadata={
103
+ "help": (
104
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
105
+ )
106
+ },
107
+ )
108
+ tokenizer_name_or_path: Optional[str] = field(
109
+ default=None,
110
+ metadata={
111
+ "help": (
112
+ "The tokenizer for weights initialization.Don't set if you want to train a model from scratch."
113
+ )
114
+ },
115
+ )
116
+
117
+ config_overrides: Optional[str] = field(
118
+ default=None,
119
+ metadata={
120
+ "help": (
121
+ "Override some existing default config settings when a model is trained from scratch. Example: "
122
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
123
+ )
124
+ },
125
+ )
126
+ config_name: Optional[str] = field(
127
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
128
+ )
129
+ tokenizer_name: Optional[str] = field(
130
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
131
+ )
132
+ cache_dir: Optional[str] = field(
133
+ default=None,
134
+ metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
135
+ )
136
+ use_fast_tokenizer: bool = field(
137
+ default=True,
138
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
139
+ )
140
+ model_revision: str = field(
141
+ default="main",
142
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
143
+ )
144
+ use_auth_token: bool = field(
145
+ default=False,
146
+ metadata={
147
+ "help": (
148
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
149
+ "with private models)."
150
+ )
151
+ },
152
+ )
153
+ torch_dtype: Optional[str] = field(
154
+ default=None,
155
+ metadata={
156
+ "help": (
157
+ "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
158
+ "dtype will be automatically derived from the model's weights."
159
+ ),
160
+ "choices": ["auto", "bfloat16", "float16", "float32"],
161
+ },
162
+ )
163
+
164
+ def __post_init__(self):
165
+ if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
166
+ raise ValueError(
167
+ "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
168
+ )
169
+
170
+
171
+ @dataclass
172
+ class DataTrainingArguments:
173
+ """
174
+ Arguments pertaining to what data we are going to input our model for training and eval.
175
+ """
176
+
177
+ dataset_dir: Optional[str] = field(
178
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
179
+ )
180
+
181
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
182
+ validation_file: Optional[str] = field(
183
+ default=None,
184
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
185
+ )
186
+
187
+ overwrite_cache: bool = field(
188
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
189
+ )
190
+ validation_split_percentage: Optional[float] = field(
191
+ default=0.05,
192
+ metadata={
193
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
194
+ },
195
+ )
196
+ preprocessing_num_workers: Optional[int] = field(
197
+ default=None,
198
+ metadata={"help": "The number of processes to use for the preprocessing."},
199
+ )
200
+ keep_linebreaks: bool = field(
201
+ default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
202
+ )
203
+ data_cache_dir: Optional[str] = field(default=None, metadata={"help": "The datasets processed stored"})
204
+
205
+ max_seq_length: Optional[int] = field(default=512)
206
+
207
+
208
+ @dataclass
209
+ class MyTrainingArguments(TrainingArguments):
210
+ trainable : Optional[str] = field(default="q_proj,v_proj")
211
+ lora_rank : Optional[int] = field(default=8)
212
+ lora_dropout : Optional[float] = field(default=0.1)
213
+ lora_alpha : Optional[float] = field(default=32.)
214
+ modules_to_save : Optional[str] = field(default=None)
215
+ peft_path : Optional[str] = field(default=None)
216
+ force_resize_embeddings: bool = field(default=False)
217
+
218
+
219
+ logger = logging.getLogger(__name__)
220
+
221
+
222
+ def main():
223
+
224
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, MyTrainingArguments))
225
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
226
+ # If we pass only one argument to the script and it's the path to a json file,
227
+ # let's parse it to get our arguments.
228
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
229
+ else:
230
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
231
+
232
+ send_example_telemetry("run_clm", model_args, data_args)
233
+
234
+ # Setup logging
235
+ logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",datefmt="%m/%d/%Y %H:%M:%S",
236
+ level=logging.INFO, # if training_args.local_rank in [-1, 0] else logging.WARN,
237
+ handlers=[logging.StreamHandler(sys.stdout)],)
238
+
239
+
240
+ if training_args.should_log:
241
+ # The default of training_args.log_level is passive, so we set log level at info here to have that default.
242
+ transformers.utils.logging.set_verbosity_info()
243
+
244
+ log_level = training_args.get_process_log_level()
245
+ logger.setLevel(log_level)
246
+ datasets.utils.logging.set_verbosity(log_level)
247
+ transformers.utils.logging.set_verbosity(log_level)
248
+ transformers.utils.logging.enable_default_handler()
249
+ transformers.utils.logging.enable_explicit_format()
250
+ # transformers.tokenization_utils.logging.set_verbosity_warning()
251
+
252
+ # Log on each process the small summary:
253
+ logger.warning(
254
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
255
+ + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
256
+ )
257
+
258
+ # Detecting last checkpoint.
259
+ last_checkpoint = None
260
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
261
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
262
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
263
+ raise ValueError(
264
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
265
+ "Use --overwrite_output_dir to overcome."
266
+ )
267
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
268
+ logger.info(
269
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
270
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
271
+ )
272
+
273
+ # Set seed before initializing model.
274
+ set_seed(training_args.seed)
275
+
276
+ config_kwargs = {
277
+ "cache_dir": model_args.cache_dir,
278
+ "revision": model_args.model_revision,
279
+ "use_auth_token": True if model_args.use_auth_token else None,
280
+ }
281
+ if model_args.config_name:
282
+ config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
283
+ elif model_args.model_name_or_path:
284
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
285
+ else:
286
+ config = CONFIG_MAPPING[model_args.model_type]()
287
+ logger.warning("You are instantiating a new config instance from scratch.")
288
+ if model_args.config_overrides is not None:
289
+ logger.info(f"Overriding config: {model_args.config_overrides}")
290
+ config.update_from_string(model_args.config_overrides)
291
+ logger.info(f"New config: {config}")
292
+
293
+ tokenizer_kwargs = {
294
+ "cache_dir": model_args.cache_dir,
295
+ "use_fast": model_args.use_fast_tokenizer,
296
+ "revision": model_args.model_revision,
297
+ "use_auth_token": True if model_args.use_auth_token else None,
298
+ }
299
+ if model_args.tokenizer_name:
300
+ tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
301
+ elif model_args.tokenizer_name_or_path:
302
+ tokenizer = LlamaTokenizer.from_pretrained(model_args.tokenizer_name_or_path, **tokenizer_kwargs)
303
+ else:
304
+ raise ValueError(
305
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
306
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
307
+ )
308
+
309
+
310
+ if tokenizer.pad_token is None:
311
+ print(f"Adding pad token {DEFAULT_PAD_TOKEN}")
312
+ tokenizer.add_special_tokens(dict(pad_token=DEFAULT_PAD_TOKEN))
313
+
314
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
315
+ eval_dataset=None
316
+ train_dataset = None
317
+
318
+ if training_args.do_train:
319
+ with training_args.main_process_first(desc="loading and tokenization"):
320
+ path = Path(data_args.dataset_dir)
321
+ files = [os.path.join(path,file.name) for file in path.glob("*.json")]
322
+ logger.info(f"Training files: {' '.join(files)}")
323
+ train_dataset = build_instruction_dataset(
324
+ data_path=files,
325
+ tokenizer=tokenizer,
326
+ max_seq_length=data_args.max_seq_length,
327
+ data_cache_dir = None,
328
+ preprocessing_num_workers = data_args.preprocessing_num_workers)
329
+ logger.info(f"Num train_samples {len(train_dataset)}")
330
+ logger.info("training example:")
331
+ logger.info(tokenizer.decode(train_dataset[0]['input_ids']))
332
+ if training_args.do_eval:
333
+ with training_args.main_process_first(desc="loading and tokenization"):
334
+ files = [data_args.validation_file]
335
+ logger.info(f"Evaluation files: {' '.join(files)}")
336
+ eval_dataset = build_instruction_dataset(
337
+ data_path=files,
338
+ tokenizer=tokenizer,
339
+ max_seq_length=data_args.max_seq_length,
340
+ data_cache_dir = None,
341
+ preprocessing_num_workers = data_args.preprocessing_num_workers)
342
+ logger.info(f"Num eval_samples {len(eval_dataset)}")
343
+ logger.info("eval example:")
344
+ logger.info(tokenizer.decode(eval_dataset[0]['input_ids']))
345
+
346
+ if model_args.model_name_or_path:
347
+ torch_dtype = (
348
+ model_args.torch_dtype
349
+ if model_args.torch_dtype in ["auto", None]
350
+ else getattr(torch, model_args.torch_dtype)
351
+ )
352
+ model = LlamaForCausalLM.from_pretrained(
353
+ model_args.model_name_or_path,
354
+ from_tf=bool(".ckpt" in model_args.model_name_or_path),
355
+ config=config,
356
+ cache_dir=model_args.cache_dir,
357
+ revision=model_args.model_revision,
358
+ use_auth_token=True if model_args.use_auth_token else None,
359
+ torch_dtype=torch_dtype,
360
+ low_cpu_mem_usage=True
361
+ )
362
+ else:
363
+ model = AutoModelForCausalLM.from_config(config)
364
+ n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
365
+ logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
366
+
367
+ logger.info(f"len(tokenizer):{len(tokenizer)}")
368
+ embedding_size = model.get_input_embeddings().weight.shape[0]
369
+ if len(tokenizer) != embedding_size:
370
+ logger.info("resize the embedding size by the size of the tokenizer")
371
+ model.resize_token_embeddings(len(tokenizer))
372
+
373
+ if training_args.peft_path is not None:
374
+ logger.info("Peft from pre-trained model")
375
+ model = PeftModel.from_pretrained(model, training_args.peft_path)
376
+ else:
377
+ logger.info("Init new peft model")
378
+ target_modules = training_args.trainable.split(',')
379
+ modules_to_save = training_args.modules_to_save
380
+ if modules_to_save is not None:
381
+ modules_to_save = modules_to_save.split(',')
382
+ lora_rank = training_args.lora_rank
383
+ lora_dropout = training_args.lora_dropout
384
+ lora_alpha = training_args.lora_alpha
385
+ logger.info(f"target_modules: {target_modules}")
386
+ logger.info(f"lora_rank: {lora_rank}")
387
+ peft_config = LoraConfig(
388
+ task_type=TaskType.CAUSAL_LM,
389
+ target_modules=target_modules,
390
+ inference_mode=False,
391
+ r=lora_rank, lora_alpha=lora_alpha,
392
+ lora_dropout=lora_dropout,
393
+ modules_to_save=modules_to_save)
394
+ model = get_peft_model(model, peft_config)
395
+
396
+ #model.base_model.tie_weights()
397
+ model.print_trainable_parameters()
398
+ logger.info(f"model.modules_to_save: {model.modules_to_save}")
399
+ old_state_dict = model.state_dict
400
+ model.state_dict = (
401
+ lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
402
+ ).__get__(model, type(model))
403
+
404
+ # Initialize our Trainer
405
+ trainer = Trainer(
406
+ model=model,
407
+ args=training_args,
408
+ train_dataset=train_dataset,
409
+ eval_dataset=eval_dataset,
410
+ tokenizer=tokenizer,
411
+ data_collator=data_collator,
412
+ )
413
+ trainer.add_callback(SavePeftModelCallback)
414
+
415
+ # Training
416
+ if training_args.do_train:
417
+ checkpoint = None
418
+ if training_args.resume_from_checkpoint is not None:
419
+ checkpoint = training_args.resume_from_checkpoint
420
+ elif last_checkpoint is not None:
421
+ checkpoint = last_checkpoint
422
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
423
+
424
+ metrics = train_result.metrics
425
+
426
+ metrics["train_samples"] = len(train_dataset)
427
+
428
+ trainer.log_metrics("train", metrics)
429
+ trainer.save_metrics("train", metrics)
430
+ trainer.save_state()
431
+
432
+ # Evaluation
433
+ if training_args.do_eval:
434
+ logger.info("*** Evaluate ***")
435
+
436
+ metrics = trainer.evaluate()
437
+ metrics["eval_samples"] =len(eval_dataset)
438
+ try:
439
+ perplexity = math.exp(metrics["eval_loss"])
440
+ except OverflowError:
441
+ perplexity = float("inf")
442
+ metrics["perplexity"] = perplexity
443
+
444
+ trainer.log_metrics("eval", metrics)
445
+ trainer.save_metrics("eval", metrics)
446
+
447
+
448
+ if __name__ == "__main__":
449
+ main()
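`build_instruction_dataset` and `DataCollatorForSupervisedDataset` come from `build_dataset.py`, which is not part of this commit. The sketch below shows the labeling convention the constants above imply — prompt tokens masked with `IGNORE_INDEX` so only the response is scored; the template and helper name are assumptions in the Alpaca style used elsewhere in this folder:

```python
# Sketch of the expected per-example encoding (not the actual build_dataset.py).
IGNORE_INDEX = -100

PROMPT = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n"
    "### Instruction:\n{instruction}\n### Input:\n{input}\n### Response:\n"
)

def encode_example(tokenizer, instruction, inp, output, max_seq_length=512):
    source = PROMPT.format(instruction=instruction, input=inp)
    src_ids = tokenizer(source, add_special_tokens=False)["input_ids"]
    tgt_ids = tokenizer(output, add_special_tokens=False)["input_ids"]
    input_ids = ([tokenizer.bos_token_id] + src_ids + tgt_ids
                 + [tokenizer.eos_token_id])[:max_seq_length]
    # Mask BOS + prompt so the loss covers only the response tokens.
    labels = ([IGNORE_INDEX] * (len(src_ids) + 1) + tgt_ids
              + [tokenizer.eos_token_id])[:max_seq_length]
    return {"input_ids": input_ids, "labels": labels}
```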
04-gene-sft/.ipynb_checkpoints/run_pt-checkpoint.sh ADDED
@@ -0,0 +1,55 @@
1
+ lr=2e-4
2
+ lora_rank=8
3
+ lora_alpha=32
4
+ lora_trainable="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj"
5
+ modules_to_save="embed_tokens,lm_head"
6
+ lora_dropout=0.05
7
+
8
+ pretrained_model=./llama-7b-hf
9
+ dna_eng_tokenizer_path=./merged_dna_eng_tokenizer_hf
10
+ dataset_dir=./train_data
11
+ data_cache=temp_data_cache_dir
12
+ per_device_train_batch_size=32
13
+ per_device_eval_batch_size=32
14
+ gradient_accumulation_steps=8
15
+ output_dir=dnahlm_llama_7b
16
+
17
+ deepspeed_config_file=ds_zero2_no_offload.json
18
+
19
+ torchrun --nnodes 1 --nproc_per_node 6 run_clm_pt_with_peft.py \
20
+ --deepspeed ${deepspeed_config_file} \
21
+ --model_name_or_path ${pretrained_model} \
22
+ --tokenizer_name_or_path ${dna_eng_tokenizer_path} \
23
+ --dataset_dir ${dataset_dir} \
24
+ --data_cache_dir ${data_cache} \
25
+ --validation_split_percentage 0.001 \
26
+ --per_device_train_batch_size ${per_device_train_batch_size} \
27
+ --per_device_eval_batch_size ${per_device_eval_batch_size} \
28
+ --do_train \
29
+ --seed $RANDOM \
30
+ --fp16 \
31
+ --num_train_epochs 1 \
32
+ --lr_scheduler_type cosine \
33
+ --learning_rate ${lr} \
34
+ --warmup_ratio 0.05 \
35
+ --weight_decay 0.01 \
36
+ --logging_strategy steps \
37
+ --logging_steps 10 \
38
+ --save_strategy steps \
39
+ --save_total_limit 3 \
40
+ --save_steps 200 \
41
+ --gradient_accumulation_steps ${gradient_accumulation_steps} \
42
+ --preprocessing_num_workers 128 \
43
+ --block_size 512 \
44
+ --output_dir ${output_dir} \
45
+ --overwrite_output_dir \
46
+ --ddp_timeout 30000 \
47
+ --logging_first_step True \
48
+ --lora_rank ${lora_rank} \
49
+ --lora_alpha ${lora_alpha} \
50
+ --trainable ${lora_trainable} \
51
+ --modules_to_save ${modules_to_save} \
52
+ --lora_dropout ${lora_dropout} \
53
+ --torch_dtype float16 \
54
+ --gradient_checkpointing \
55
+ --ddp_find_unused_parameters False
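`ds_zero2_no_offload.json` is referenced here but not included in the commit. For orientation only, a typical ZeRO stage-2 configuration without CPU offload looks roughly like the dict below (an assumption, not the project's actual file); the `"auto"` values defer to the `torchrun` arguments above:

```python
# Sketch of a plausible ds_zero2_no_offload.json, written as a dict and dumped.
import json

ds_config = {
    "zero_optimization": {
        "stage": 2,                  # shard optimizer state and gradients
        "allgather_partitions": True,
        "overlap_comm": True,
        "reduce_scatter": True,
        "contiguous_gradients": True,
    },
    "fp16": {"enabled": "auto"},     # follows --fp16 from the launch command
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
}

with open("ds_zero2_no_offload.json", "w") as f:
    json.dump(ds_config, f, indent=2)
```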
04-gene-sft/.ipynb_checkpoints/run_sft-checkpoint.sh ADDED
@@ -0,0 +1,59 @@
1
+ lr=1e-4
2
+ lora_rank=8
3
+ lora_alpha=32
4
+ lora_trainable="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj"
5
+ modules_to_save="embed_tokens,lm_head"
6
+ lora_dropout=0.05
7
+
8
+ pretrained_model=dnahlm-merge-hf
9
+ chinese_tokenizer_path=dnahlm-merge-hf
10
+ dataset_dir=sft_data
11
+ per_device_train_batch_size=32
12
+ per_device_eval_batch_size=32
13
+ gradient_accumulation_steps=8
14
+ output_dir=dnahlm-llama7b-sft
15
+ #peft_model=peft_model/dir
16
+ validation_file=val_data.json
17
+
18
+ deepspeed_config_file=ds_zero2_no_offload.json
19
+
20
+ torchrun --nnodes 1 --nproc_per_node 6 run_clm_sft_with_peft.py \
21
+ --deepspeed ${deepspeed_config_file} \
22
+ --model_name_or_path ${pretrained_model} \
23
+ --tokenizer_name_or_path ${chinese_tokenizer_path} \
24
+ --dataset_dir ${dataset_dir} \
25
+ --validation_split_percentage 0.001 \
26
+ --per_device_train_batch_size ${per_device_train_batch_size} \
27
+ --per_device_eval_batch_size ${per_device_eval_batch_size} \
28
+ --do_train \
29
+ --do_eval \
30
+ --seed $RANDOM \
31
+ --fp16 \
32
+ --num_train_epochs 8 \
33
+ --lr_scheduler_type cosine \
34
+ --learning_rate ${lr} \
35
+ --warmup_ratio 0.03 \
36
+ --weight_decay 0 \
37
+ --logging_strategy steps \
38
+ --logging_steps 10 \
39
+ --save_strategy steps \
40
+ --save_total_limit 3 \
41
+ --evaluation_strategy steps \
42
+ --eval_steps 100 \
43
+ --save_steps 200 \
44
+ --gradient_accumulation_steps ${gradient_accumulation_steps} \
45
+ --preprocessing_num_workers 4 \
46
+ --max_seq_length 512 \
47
+ --output_dir ${output_dir} \
48
+ --overwrite_output_dir \
49
+ --ddp_timeout 30000 \
50
+ --logging_first_step True \
51
+ --lora_rank ${lora_rank} \
52
+ --lora_alpha ${lora_alpha} \
53
+ --trainable ${lora_trainable} \
54
+ --modules_to_save ${modules_to_save} \
55
+ --lora_dropout ${lora_dropout} \
56
+ --torch_dtype float16 \
57
+ --validation_file ${validation_file} \
58
+ --gradient_checkpointing \
59
+ --ddp_find_unused_parameters False
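After this script finishes, a quick smoke test can load the merged base model plus the `sft_lora_model` adapter and complete one instruction. A sketch — paths mirror the variables above but are otherwise illustrative:

```python
# Post-SFT sanity check (sketch): generate a label for one promoter prompt.
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import PeftModel

tokenizer = LlamaTokenizer.from_pretrained("dnahlm-llama7b-sft/sft_lora_model")
model = LlamaForCausalLM.from_pretrained("dnahlm-merge-hf", torch_dtype=torch.float16)
model.resize_token_embeddings(len(tokenizer))  # a pad token may have been added
model = PeftModel.from_pretrained(model, "dnahlm-llama7b-sft/sft_lora_model")
model.eval()

prompt = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n"
    "### Instruction:\nDetermine core promoter detection of following dna sequence, "
    "The result will be one of the following: Non-promoter, promoter.\n"
    "### Input:\nTCTTTCTCTTCTGTATCATTCTACTT\n### Response:\n"
)
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```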
04-gene-sft/1-finetue-intro.ipynb CHANGED
@@ -31,6 +31,12 @@
31
  "\"yuanzhoulvpi/gpt2_chinese\", num_labels=2\n",
32
  ")\n",
33
  "\n",
34
  "\n",
35
  "\n",
36
  "2 如果是把分类问题,改成指令微调的模式,就是像\n",
@@ -174,7 +180,7 @@
174
  },
175
  {
176
  "cell_type": "code",
177
- "execution_count": null,
178
  "id": "64312191-423f-4a18-aa0c-036374e93fb2",
179
  "metadata": {},
180
  "outputs": [],
@@ -192,10 +198,44 @@
192
  },
193
  {
194
  "cell_type": "code",
195
- "execution_count": null,
196
  "id": "32c16282-f9f1-4545-b522-daf2b39b4ead",
197
  "metadata": {},
198
- "outputs": [],
199
  "source": [
200
  "#原始模型\n",
201
  "from transformers import AutoModel\n",
@@ -205,10 +245,55 @@
205
  },
206
  {
207
  "cell_type": "code",
208
- "execution_count": null,
209
  "id": "1149163f-4d89-472e-8d45-ebcbb5f9575e",
210
  "metadata": {},
211
- "outputs": [],
212
  "source": [
213
  "#分类微调模型\n",
214
  "from transformers import AutoModelForSequenceClassification\n",
@@ -218,16 +303,110 @@
218
  },
219
  {
220
  "cell_type": "code",
221
- "execution_count": 1,
222
  "id": "09735059-507c-48c4-893f-ca0da21ce5e8",
223
  "metadata": {},
224
- "outputs": [],
225
  "source": [
226
  "#指令微调模型\n",
227
  "from transformers import AutoModelForCausalLM\n",
228
- "sft_model = AutoModelForMaskedLM.from_pretrained(\"gpt2\")\n",
229
  "sft_model"
230
  ]
231
  }
232
  ],
233
  "metadata": {
 
31
  "\"yuanzhoulvpi/gpt2_chinese\", num_labels=2\n",
32
  ")\n",
33
  "\n",
34
+ "对应的训练数据一般是这样的:\n",
35
+ "\n",
36
+ "| seq | label |\n",
37
+ "|------------------------------|-------|\n",
38
+ "| 他家的奶茶超级好喝。。。 | 1 |\n",
39
+ "| 他家的奶茶超级难喝。。。 | 0 |\n",
40
  "\n",
41
  "\n",
42
  "2 如果是把分类问题,改成指令微调的模式,就是像\n",
 
180
  },
181
  {
182
  "cell_type": "code",
183
+ "execution_count": 1,
184
  "id": "64312191-423f-4a18-aa0c-036374e93fb2",
185
  "metadata": {},
186
  "outputs": [],
 
198
  },
199
  {
200
  "cell_type": "code",
201
+ "execution_count": 2,
202
  "id": "32c16282-f9f1-4545-b522-daf2b39b4ead",
203
  "metadata": {},
204
+ "outputs": [
205
+ {
206
+ "data": {
207
+ "text/plain": [
208
+ "GPT2Model(\n",
209
+ " (wte): Embedding(50257, 768)\n",
210
+ " (wpe): Embedding(1024, 768)\n",
211
+ " (drop): Dropout(p=0.1, inplace=False)\n",
212
+ " (h): ModuleList(\n",
213
+ " (0-11): 12 x GPT2Block(\n",
214
+ " (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
215
+ " (attn): GPT2SdpaAttention(\n",
216
+ " (c_attn): Conv1D(nf=2304, nx=768)\n",
217
+ " (c_proj): Conv1D(nf=768, nx=768)\n",
218
+ " (attn_dropout): Dropout(p=0.1, inplace=False)\n",
219
+ " (resid_dropout): Dropout(p=0.1, inplace=False)\n",
220
+ " )\n",
221
+ " (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
222
+ " (mlp): GPT2MLP(\n",
223
+ " (c_fc): Conv1D(nf=3072, nx=768)\n",
224
+ " (c_proj): Conv1D(nf=768, nx=3072)\n",
225
+ " (act): NewGELUActivation()\n",
226
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
227
+ " )\n",
228
+ " )\n",
229
+ " )\n",
230
+ " (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
231
+ ")"
232
+ ]
233
+ },
234
+ "execution_count": 2,
235
+ "metadata": {},
236
+ "output_type": "execute_result"
237
+ }
238
+ ],
239
  "source": [
240
  "#原始模型\n",
241
  "from transformers import AutoModel\n",
 
245
  },
246
  {
247
  "cell_type": "code",
248
+ "execution_count": 3,
249
  "id": "1149163f-4d89-472e-8d45-ebcbb5f9575e",
250
  "metadata": {},
251
+ "outputs": [
252
+ {
253
+ "name": "stderr",
254
+ "output_type": "stream",
255
+ "text": [
256
+ "Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']\n",
257
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
258
+ ]
259
+ },
260
+ {
261
+ "data": {
262
+ "text/plain": [
263
+ "GPT2ForSequenceClassification(\n",
264
+ " (transformer): GPT2Model(\n",
265
+ " (wte): Embedding(50257, 768)\n",
266
+ " (wpe): Embedding(1024, 768)\n",
267
+ " (drop): Dropout(p=0.1, inplace=False)\n",
268
+ " (h): ModuleList(\n",
269
+ " (0-11): 12 x GPT2Block(\n",
270
+ " (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
271
+ " (attn): GPT2SdpaAttention(\n",
272
+ " (c_attn): Conv1D(nf=2304, nx=768)\n",
273
+ " (c_proj): Conv1D(nf=768, nx=768)\n",
274
+ " (attn_dropout): Dropout(p=0.1, inplace=False)\n",
275
+ " (resid_dropout): Dropout(p=0.1, inplace=False)\n",
276
+ " )\n",
277
+ " (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
278
+ " (mlp): GPT2MLP(\n",
279
+ " (c_fc): Conv1D(nf=3072, nx=768)\n",
280
+ " (c_proj): Conv1D(nf=768, nx=3072)\n",
281
+ " (act): NewGELUActivation()\n",
282
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
283
+ " )\n",
284
+ " )\n",
285
+ " )\n",
286
+ " (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
287
+ " )\n",
288
+ " (score): Linear(in_features=768, out_features=2, bias=False)\n",
289
+ ")"
290
+ ]
291
+ },
292
+ "execution_count": 3,
293
+ "metadata": {},
294
+ "output_type": "execute_result"
295
+ }
296
+ ],
297
  "source": [
298
  "#分类微调模型\n",
299
  "from transformers import AutoModelForSequenceClassification\n",
 
303
  },
304
  {
305
  "cell_type": "code",
306
+ "execution_count": 5,
307
  "id": "09735059-507c-48c4-893f-ca0da21ce5e8",
308
  "metadata": {},
309
+ "outputs": [
310
+ {
311
+ "data": {
312
+ "text/plain": [
313
+ "GPT2LMHeadModel(\n",
314
+ " (transformer): GPT2Model(\n",
315
+ " (wte): Embedding(50257, 768)\n",
316
+ " (wpe): Embedding(1024, 768)\n",
317
+ " (drop): Dropout(p=0.1, inplace=False)\n",
318
+ " (h): ModuleList(\n",
319
+ " (0-11): 12 x GPT2Block(\n",
320
+ " (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
321
+ " (attn): GPT2SdpaAttention(\n",
322
+ " (c_attn): Conv1D(nf=2304, nx=768)\n",
323
+ " (c_proj): Conv1D(nf=768, nx=768)\n",
324
+ " (attn_dropout): Dropout(p=0.1, inplace=False)\n",
325
+ " (resid_dropout): Dropout(p=0.1, inplace=False)\n",
326
+ " )\n",
327
+ " (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
328
+ " (mlp): GPT2MLP(\n",
329
+ " (c_fc): Conv1D(nf=3072, nx=768)\n",
330
+ " (c_proj): Conv1D(nf=768, nx=3072)\n",
331
+ " (act): NewGELUActivation()\n",
332
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
333
+ " )\n",
334
+ " )\n",
335
+ " )\n",
336
+ " (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
337
+ " )\n",
338
+ " (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
339
+ ")"
340
+ ]
341
+ },
342
+ "execution_count": 5,
343
+ "metadata": {},
344
+ "output_type": "execute_result"
345
+ }
346
+ ],
347
  "source": [
348
  "#指令微调模型\n",
349
  "from transformers import AutoModelForCausalLM\n",
350
+ "sft_model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n",
351
  "sft_model"
352
  ]
353
+ },
354
+ {
355
+ "cell_type": "code",
356
+ "execution_count": 6,
357
+ "id": "d1407cbe-4996-4898-a135-e26d28da2a2a",
358
+ "metadata": {},
359
+ "outputs": [
360
+ {
361
+ "data": {
362
+ "text/plain": [
363
+ "GPT2LMHeadModel(\n",
364
+ " (transformer): GPT2Model(\n",
365
+ " (wte): Embedding(50257, 768)\n",
366
+ " (wpe): Embedding(1024, 768)\n",
367
+ " (drop): Dropout(p=0.1, inplace=False)\n",
368
+ " (h): ModuleList(\n",
369
+ " (0-11): 12 x GPT2Block(\n",
370
+ " (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
371
+ " (attn): GPT2SdpaAttention(\n",
372
+ " (c_attn): Conv1D(nf=2304, nx=768)\n",
373
+ " (c_proj): Conv1D(nf=768, nx=768)\n",
374
+ " (attn_dropout): Dropout(p=0.1, inplace=False)\n",
375
+ " (resid_dropout): Dropout(p=0.1, inplace=False)\n",
376
+ " )\n",
377
+ " (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
378
+ " (mlp): GPT2MLP(\n",
379
+ " (c_fc): Conv1D(nf=3072, nx=768)\n",
380
+ " (c_proj): Conv1D(nf=768, nx=3072)\n",
381
+ " (act): NewGELUActivation()\n",
382
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
383
+ " )\n",
384
+ " )\n",
385
+ " )\n",
386
+ " (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
387
+ " )\n",
388
+ " (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
389
+ ")"
390
+ ]
391
+ },
392
+ "execution_count": 6,
393
+ "metadata": {},
394
+ "output_type": "execute_result"
395
+ }
396
+ ],
397
+ "source": [
398
+ "from transformers import GPT2LMHeadModel\n",
399
+ "gpt2_model = GPT2LMHeadModel.from_pretrained(\"gpt2\")\n",
400
+ "gpt2_model"
401
+ ]
402
+ },
403
+ {
404
+ "cell_type": "code",
405
+ "execution_count": null,
406
+ "id": "92fc8e55-2d90-4694-b8df-90885d08d51a",
407
+ "metadata": {},
408
+ "outputs": [],
409
+ "source": []
410
  }
411
  ],
412
  "metadata": {
04-gene-sft/2-gpt2-instruction-ft.ipynb CHANGED
@@ -8,6 +8,123 @@
8
  "# 4.2 基于GPT2的指令微调"
9
  ]
10
  },
11
  {
12
  "cell_type": "code",
13
  "execution_count": null,
 
8
  "# 4.2 基于GPT2的指令微调"
9
  ]
10
  },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "02cd6e13-bbfb-413a-8236-ff092456fd1c",
14
+ "metadata": {},
15
+ "source": [
16
+ "我还是用第二章中的分类的例子,使用指令微调的形式,来再次解决分类问题。\n",
17
+ "\n",
18
+ "使用 GPT-2 进行文本分类的两种方法:**使用 GPT-2 的分类头(Classification Header)** 和 **将分类任务转换为指令微调**,在思路、实现、优劣势和适用场景上存在明显差异。以下是详细对比:\n",
19
+ "\n",
20
+ "---\n",
21
+ "\n",
22
+ "### **1. 核心思路**\n",
23
+ "\n",
24
+ "| **方法** | **使用 GPT-2 分类头** | **转换为指令微调** |\n",
25
+ "|-----------------------------|-------------------------------------------------------------|-------------------------------------------------------|\n",
26
+ "| **基本概念** | 在 GPT-2 顶部添加一个分类头(通常是一个线性层),直接预测分类标签。 | 将分类任务转化为自然语言指令,模型通过微调理解并完成指令形式的任务。 |\n",
27
+ "| **实现方式** | 修改 GPT-2 模型,添加 `num_labels` 分类头并定义分类损失函数。 | 构建任务指令数据(Instruction + Input + Output),然后微调模型。 |\n",
28
+ "| **数据形式** | 文本与其分类标签的直接映射。 | 文本通过指令转化为生成任务。例如:<br>`Input`: 文章内容<br>`Output`: 分类结果。 |\n",
29
+ "\n",
30
+ "---\n",
31
+ "\n",
32
+ "### **2. 数据格式**\n",
33
+ "\n",
34
+ "| **方法** | **使用 GPT-2 分类头** | **转换为指令微调** |\n",
35
+ "|-----------------------------|-------------------------------------------------------------|-------------------------------------------------------|\n",
36
+ "| **数据格式** | - 输入:文本 <br>- 标签:离散类别标签(如 0, 1, 2)。 | - 指令:自然语言描述任务(如 \"请分类以下文本\")。<br>- 输入:分类文本。<br>- 输出:分类结果(文本形式)。 |\n",
37
+ "| **示例** | 输入:`\"This is a happy day!\"`<br>标签:`1`(表示积极) | `Instruction`: \"请对以下文本进行情感分类\"<br>`Input`: `\"This is a happy day!\"`<br>`Output`: `\"积极\"` |\n",
38
+ "\n",
39
+ "---\n",
40
+ "\n",
41
+ "### **3. 模型结构**\n",
42
+ "\n",
43
+ "| **方法** | **使用 GPT-2 分类头** | **转换为指令微调** |\n",
44
+ "|-----------------------------|-------------------------------------------------------------|-------------------------------------------------------|\n",
45
+ "| **模型结构** | - GPT-2 + 分类头(线性层)。 | - GPT-2 原始结构,无需额外的分类头。 |\n",
46
+ "| **损失函数** | - 使用交叉熵损失(Cross Entropy Loss)。 | - 使用自回归的语言建模损失(Language Modeling Loss)。 |\n",
47
+ "\n",
48
+ "---\n",
49
+ "\n",
50
+ "### **4. 训练过程**\n",
51
+ "\n",
52
+ "| **方法** | **使用 GPT-2 分类头** | **转换为指令微调** |\n",
53
+ "|-----------------------------|-------------------------------------------------------------|-------------------------------------------------------|\n",
54
+ "| **微调对象** | 主要微调分类头部分的参数(可选择冻结 GPT-2 的主干部分)。 | 微调整个 GPT-2 模型(或使用参数高效微调如 LoRA)。 |\n",
55
+ "| **标签处理** | 离散化标签(如 0, 1, 2)。 | 标签转化为自然语言(如“积极”、“中立”、“消极”)。 |\n",
56
+ "| **训练难度** | - 简单,标准分类任务流程。<br>- 数据需求较小,适合小规模微调。 | - 复杂,需要构造高质量的指令数据集。<br>- 数据需求较大,适合多任务场景。 |\n",
57
+ "\n",
58
+ "---\n",
59
+ "\n",
60
+ "### **5. 优缺点分析**\n",
61
+ "\n",
62
+ "| **方法** | **使用 GPT-2 分类头** | **转换为指令微调** |\n",
63
+ "|-----------------------------|-------------------------------------------------------------|-------------------------------------------------------|\n",
64
+ "| **优点** | - 训练速度快,计算资源需求较低。<br>- 实现简单,适合单一任务。 | - 泛化能力��,支持多任务扩展。<br>- 与多任务微调和开放式生成兼容。 |\n",
65
+ "| **缺点** | - 只能处理分类任务,难以扩展为其他任务。<br>- 需要人工调整分类头和损失函数。 | - 数据构造复杂且对数据质量依赖较高。<br>- 训练资源需求较大,训练时间较长。 |\n",
66
+ "\n",
67
+ "---\n",
68
+ "\n",
69
+ "### **6. 适用场景**\n",
70
+ "\n",
71
+ "| **方法** | **使用 GPT-2 分类头** | **转换为指令微调** |\n",
72
+ "|-----------------------------|-------------------------------------------------------------|-------------------------------------------------------|\n",
73
+ "| **适用场景** | - 单任务文本分类,如情感分析、垃圾邮件检测等。 | - 多任务场景,支持分类、翻译、摘要等任务的统一处理。 |\n",
74
+ "| **数据规模** | 适合小数据集,数千到数万条数据即可训练效果良好。 | 适合大数据集,特别是多任务、多领域的数据集。 |\n",
75
+ "| **需求类型** | 专注于提高单一任务的分类准确率。 | 需要增强模型的多任务泛化能力,同时提升用户交互体验。 |\n",
76
+ "\n",
77
+ "---\n",
78
+ "\n",
79
+ "### **7. 综合对比总结**\n",
80
+ "\n",
81
+ "| **维度** | **使用 GPT-2 分类头** | **转换为指令微调** |\n",
82
+ "|-------------------------|--------------------------------------------------------------|-------------------------------------------------------|\n",
83
+ "| **实现复杂度** | 较低,直接添加分类头并使用标准分类流程即可完成。 | 较高,需要构造高质量指令数据,并调整训练流程。 |\n",
84
+ "| **资源需求** | 较低,仅需调整分类头部分,训练时间和显存消耗较少。 | 较高,需要微调整个模型,且对数据和算力需求更大。 |\n",
85
+ "| **性能表现** | 对单一分类任务效果较好,但泛化能力较弱。 | 在多任务、多样化分类场景中表现更强,且可扩展为其他任务类型。 |\n",
86
+ "| **扩展性** | 较差,仅适用于当前任务,难以迁移到其他任务。 | 较强,可适应多任务指令和开放式生成场景。 |\n",
87
+ "\n",
88
+ "---\n",
89
+ "\n",
90
+ "### **选择建议**\n",
91
+ "\n",
92
+ "1. **使用 GPT-2 分类头**:\n",
93
+ " - 如果任务是单一分类问题(如情感分析、垃圾邮件检测),并且数据量有限,推荐使用分类头方法。\n",
94
+ " - 适合快速实现和部署,无需复杂的预处理和指令数据集构建。\n",
95
+ "\n",
96
+ "2. **转换为指令微调**:\n",
97
+ " - 如果任务需要多样化(分类+生成+翻译等),或需要对未见任务有更好的泛化能力,推荐使用指令微调。\n",
98
+ " - 适合多任务、多场景部署,尤其是在 ChatGPT 风格的应用中更为适用。\n",
99
+ "\n",
100
+ "通过综合任务需求、数据规模和资源条件选择合适的方法,能够有效提升模型性能并实现更广泛的适用性。\n",
101
+ "\n",
102
+ "\n",
103
+ "原始的数据格式如下:\n",
104
+ "| sequence | label | label_name |\n",
105
+ "|--------------------------------------------------------|-------|----------------|\n",
106
+ "| TATATTTTCTCAGCTGAGTTAATTAGTTTCACTAGTTAACTGAGAATAAAAGAA | 1 | promoter |\n",
107
+ "| TGGGGAGGGTCCGGTGTTAGTTAGATACATCCCCAGACCCACACCCCGGATAGA | 0 | Non-promoter |\n",
108
+ "\n",
109
+ "转成指令的格式为:\n",
110
+ "```\n",
111
+ "{'instruction': 'Determine core promoter detection of following dna sequence, The result will be one of the following: Non-promoter, promoter.', \n",
112
+ "'input': 'CATGCGGGTCG...', \n",
113
+ "'output': 'Non-promoter'}\n",
114
+ "```\n",
115
+ "\n",
116
+ "然后写成指令微调数据格式,当做一般的文本进行训练:\n",
117
+ "```\n",
118
+ "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
119
+ "### Instruction:\n",
120
+ "Determine core promoter detection of following dna sequence, The result will be one of the following: Non-promoter, promoter.\n",
121
+ "### Input:\n",
122
+ "TCTTTCTCTTCTGTATCATTCTACTT...\n",
123
+ "### Response:\n",
124
+ "Non-promoter\n",
125
+ "```\n"
126
+ ]
127
+ },
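+ {
+ "cell_type": "markdown",
+ "id": "3f1c2b7a-8d4e-4c21-9a05-6b2f1e0d9c88",
+ "metadata": {},
+ "source": [
+ "下面给出把一条原始分类样本拼成上述指令文本的最小示意(假设字段名为 `sequence`/`label_name`,模板与上文一致):\n",
+ "\n",
+ "```python\n",
+ "TEMPLATE = (\n",
+ " \"Below is an instruction that describes a task. \"\n",
+ " \"Write a response that appropriately completes the request.\\n\"\n",
+ " \"### Instruction:\\n{instruction}\\n\"\n",
+ " \"### Input:\\n{input}\\n\"\n",
+ " \"### Response:\\n{output}\"\n",
+ ")\n",
+ "\n",
+ "def to_instruction_text(example):\n",
+ " # 拼装 Alpaca 风格提示,instruction 固定为启动子判别的任务描述\n",
+ " return TEMPLATE.format(\n",
+ " instruction=(\"Determine core promoter detection of following dna sequence, \"\n",
+ " \"The result will be one of the following: Non-promoter, promoter.\"),\n",
+ " input=example[\"sequence\"],\n",
+ " output=example[\"label_name\"],\n",
+ " )\n",
+ "\n",
+ "print(to_instruction_text({\"sequence\": \"TATATTTTCTCAGCTGAGTT\", \"label_name\": \"promoter\"}))\n",
+ "```\n"
+ ]
+ },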
128
  {
129
  "cell_type": "code",
130
  "execution_count": null,
04-gene-sft/3-llama-expand-dict.ipynb CHANGED
@@ -114,10 +114,18 @@
114
  },
115
  {
116
  "cell_type": "code",
117
- "execution_count": null,
118
  "id": "19a06b82-31b8-48cb-9c83-ec016da2da8a",
119
  "metadata": {},
120
- "outputs": [],
 
 
 
 
 
 
 
 
121
  "source": [
122
  "from sentencepiece import SentencePieceProcessor\n",
123
  "model_path = \"gene_bpe_seg.model\"\n",
@@ -147,7 +155,7 @@
147
  },
148
  {
149
  "cell_type": "code",
150
- "execution_count": null,
151
  "id": "3bafcc33-2923-4026-bc39-c6ec716d2e3c",
152
  "metadata": {},
153
  "outputs": [],
@@ -161,10 +169,28 @@
161
  },
162
  {
163
  "cell_type": "code",
164
- "execution_count": null,
165
  "id": "66cb86ed-3225-4bb0-8aca-6005bc918d03",
166
  "metadata": {},
167
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  "source": [
169
  "llama_tokenizer_dir = \"llama-7b-hf\" \n",
170
  "dna_sp_model_file = \"gene_bpe_seg.model\"\n",
@@ -188,10 +214,20 @@
188
  },
189
  {
190
  "cell_type": "code",
191
- "execution_count": null,
192
  "id": "7ba4240e-bc08-4be0-8ca3-c4e7a47fa055",
193
  "metadata": {},
194
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
195
  "source": [
196
  "## Add dna tokens to LLaMA tokenizer\n",
197
  "llama_spm_tokens_set=set(p.piece for p in llama_spm.pieces)\n",
@@ -210,10 +246,18 @@
210
  },
211
  {
212
  "cell_type": "code",
213
- "execution_count": null,
214
  "id": "a240a7d8-c1a9-4473-a5c5-157a25f97c16",
215
  "metadata": {},
216
- "outputs": [],
 
 
 
 
 
 
 
 
217
  "source": [
218
  "## Save\n",
219
  "output_sp_dir = 'merged_gene_eng_tokenizer_sp'\n",
@@ -229,10 +273,25 @@
229
  },
230
  {
231
  "cell_type": "code",
232
- "execution_count": null,
233
  "id": "cbd1f648-f8a0-4f16-b516-2ce3e7c7cfee",
234
  "metadata": {},
235
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  "source": [
237
  "# Test\n",
238
  "llama_tokenizer = LlamaTokenizer.from_pretrained(llama_tokenizer_dir)\n",
@@ -246,6 +305,14 @@
246
  "print(f\"Tokenized by LLaMA tokenizer:{llama_tokenizer.tokenize(text)}\")\n",
247
  "print(f\"Tokenized by GENE-LLaMA tokenizer:{dna_llama_tokenizer.tokenize(text)}\")"
248
  ]
 
 
 
 
 
 
 
 
249
  }
250
  ],
251
  "metadata": {
 
114
  },
115
  {
116
  "cell_type": "code",
117
+ "execution_count": 1,
118
  "id": "19a06b82-31b8-48cb-9c83-ec016da2da8a",
119
  "metadata": {},
120
+ "outputs": [
121
+ {
122
+ "name": "stdout",
123
+ "output_type": "stream",
124
+ "text": [
125
+ "['▁TCG', 'ACGGC', 'ACGCG', 'ACAGC', 'AGCG', 'AGCCCC', 'GCGC', 'ACCCG', 'AGCGCG', 'AKCG', 'FVGP', 'MV', 'HLKV', 'HLE', 'ADV', 'ASSC', 'RS', 'AVI', 'YL', 'TS', 'EEP', 'FEG', 'VLGL', 'RLKE', 'GI', 'AI', 'TGC', 'WPR', 'WP', 'DEM', 'DE', 'RS', 'AVW', 'RV', 'EPY', 'TR', 'HFG', 'RVL', 'YS', 'FGV']\n"
126
+ ]
127
+ }
128
+ ],
129
  "source": [
130
  "from sentencepiece import SentencePieceProcessor\n",
131
  "model_path = \"gene_bpe_seg.model\"\n",
 
155
  },
156
  {
157
  "cell_type": "code",
158
+ "execution_count": 2,
159
  "id": "3bafcc33-2923-4026-bc39-c6ec716d2e3c",
160
  "metadata": {},
161
  "outputs": [],
 
169
  },
170
  {
171
  "cell_type": "code",
172
+ "execution_count": 3,
173
  "id": "66cb86ed-3225-4bb0-8aca-6005bc918d03",
174
  "metadata": {},
175
+ "outputs": [
176
+ {
177
+ "name": "stderr",
178
+ "output_type": "stream",
179
+ "text": [
180
+ "You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message\n"
181
+ ]
182
+ },
183
+ {
184
+ "name": "stdout",
185
+ "output_type": "stream",
186
+ "text": [
187
+ "32000 60000\n",
188
+ "['<s>', '</s>', '<unk>']\n",
189
+ "[1, 2, 0]\n",
190
+ "{'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}\n"
191
+ ]
192
+ }
193
+ ],
194
  "source": [
195
  "llama_tokenizer_dir = \"llama-7b-hf\" \n",
196
  "dna_sp_model_file = \"gene_bpe_seg.model\"\n",
 
214
  },
215
  {
216
  "cell_type": "code",
217
+ "execution_count": 4,
218
  "id": "7ba4240e-bc08-4be0-8ca3-c4e7a47fa055",
219
  "metadata": {},
220
+ "outputs": [
221
+ {
222
+ "name": "stdout",
223
+ "output_type": "stream",
224
+ "text": [
225
+ "32000\n",
226
+ "Before:32000\n",
227
+ "New model pieces: 91643\n"
228
+ ]
229
+ }
230
+ ],
231
  "source": [
232
  "## Add dna tokens to LLaMA tokenizer\n",
233
  "llama_spm_tokens_set=set(p.piece for p in llama_spm.pieces)\n",
 
246
  },
247
  {
248
  "cell_type": "code",
249
+ "execution_count": 5,
250
  "id": "a240a7d8-c1a9-4473-a5c5-157a25f97c16",
251
  "metadata": {},
252
+ "outputs": [
253
+ {
254
+ "name": "stdout",
255
+ "output_type": "stream",
256
+ "text": [
257
+ "gene-LLaMA tokenizer has been saved to merged_gene_eng_tokenizer_hf\n"
258
+ ]
259
+ }
260
+ ],
261
  "source": [
262
  "## Save\n",
263
  "output_sp_dir = 'merged_gene_eng_tokenizer_sp'\n",
 
273
  },
274
  {
275
  "cell_type": "code",
276
+ "execution_count": 6,
277
  "id": "cbd1f648-f8a0-4f16-b516-2ce3e7c7cfee",
278
  "metadata": {},
279
+ "outputs": [
280
+ {
281
+ "name": "stdout",
282
+ "output_type": "stream",
283
+ "text": [
284
+ "['<s>', '</s>', '<unk>']\n",
285
+ "[1, 2, 0]\n",
286
+ "{'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}\n",
287
+ "Test text:\n",
288
+ " TCGACGGCACGCGACAGCAGCGAGCCCCGCGCACCCGAGCGCGAKCGFVGPMVHLKVHLEADVASSCRSAVIYLTSEEPFEGVLGLRLKEGIAITGCWPRWPDEMDERSAVWRVEPYTRHFGRVLYSFGV,\n",
289
+ "The primary use of LLaMA is research on large language models, including\n",
290
+ "Tokenized by LLaMA tokenizer:['▁T', 'CG', 'AC', 'G', 'GC', 'AC', 'GC', 'G', 'AC', 'AG', 'CA', 'GC', 'G', 'AG', 'CC', 'CC', 'GC', 'GC', 'AC', 'CC', 'GA', 'GC', 'GC', 'GA', 'K', 'CG', 'F', 'V', 'G', 'PM', 'V', 'HL', 'K', 'V', 'H', 'LE', 'AD', 'VA', 'SS', 'CR', 'S', 'AV', 'I', 'Y', 'LT', 'SEE', 'PF', 'EG', 'V', 'L', 'GL', 'RL', 'KE', 'G', 'IA', 'IT', 'GC', 'W', 'PR', 'WP', 'DE', 'MD', 'ERS', 'AV', 'WR', 'VE', 'PY', 'TR', 'H', 'F', 'GR', 'V', 'LY', 'SF', 'GV', ',', '<0x0A>', 'The', '▁primary', '▁use', '▁of', '▁L', 'La', 'MA', '▁is', '▁research', '▁on', '▁large', '▁language', '▁models', ',', '▁including']\n",
291
+ "Tokenized by GENE-LLaMA tokenizer:['▁TCG', 'ACGGC', 'ACGCG', 'ACAG', 'CA', 'GCG', 'AGCCCC', 'GCGC', 'ACCCG', 'AGCGCG', 'AKCG', 'FVGP', 'MVHL', 'KV', 'HLE', 'ADV', 'ASSC', 'RSAV', 'I', 'YL', 'TSEE', 'P', 'FEG', 'VLGL', 'RLK', 'EGI', 'AI', 'TGC', 'W', 'PRW', 'P', 'DEM', 'DER', 'SAV', 'W', 'RVE', 'PY', 'TRH', 'FG', 'RVLY', 'SFGV', ',', '<0x0A>', 'The', '▁primary', '▁use', '▁of', '▁L', 'La', 'MA', '▁is', '▁research', '▁on', '▁large', '▁language', '▁models', ',', '▁including']\n"
292
+ ]
293
+ }
294
+ ],
295
  "source": [
296
  "# Test\n",
297
  "llama_tokenizer = LlamaTokenizer.from_pretrained(llama_tokenizer_dir)\n",
 
305
  "print(f\"Tokenized by LLaMA tokenizer:{llama_tokenizer.tokenize(text)}\")\n",
306
  "print(f\"Tokenized by GENE-LLaMA tokenizer:{dna_llama_tokenizer.tokenize(text)}\")"
307
  ]
308
+ },
309
+ {
310
+ "cell_type": "code",
311
+ "execution_count": null,
312
+ "id": "46ae7605-2ef8-4927-bff3-2c0325f8df0d",
313
+ "metadata": {},
314
+ "outputs": [],
315
+ "source": []
316
  }
317
  ],
318
  "metadata": {
04-gene-sft/4-deepspeed-intro.ipynb CHANGED
@@ -56,6 +56,8 @@
56
  "\n",
57
  "每个阶段都进一步减少显存需求,Stage 3 可支持超大规模模型(如 GPT-3)。\n",
58
  "\n",
 
 
59
  "#### **(2)混合精度训练**\n",
60
  "通过 FP16 或 BF16(半精度浮点数)计算,显著减少显存占用并提升计算效率。\n",
61
  "\n",
@@ -567,6 +569,14 @@
567
  "metadata": {},
568
  "outputs": [],
569
  "source": []
 
 
 
 
 
 
 
 
570
  }
571
  ],
572
  "metadata": {
 
56
  "\n",
57
  "每个阶段都进一步减少显存需求,Stage 3 可支持超大规模模型(如 GPT-3)。\n",
58
  "\n",
59
+ "<img src='img/deepspeed.png' width='600px' />\n",
60
+ "\n",
61
  "#### **(2)混合精度训练**\n",
62
  "通过 FP16 或 BF16(半精度浮点数)计算,显著减少显存占用并提升计算效率。\n",
63
  "\n",
 
569
  "metadata": {},
570
  "outputs": [],
571
  "source": []
572
+ },
573
+ {
574
+ "cell_type": "code",
575
+ "execution_count": null,
576
+ "id": "ce701aeb-c8c7-450a-bbf9-b793a19cd0c6",
577
+ "metadata": {},
578
+ "outputs": [],
579
+ "source": []
580
  }
581
  ],
582
  "metadata": {
04-gene-sft/5-peft-intro.ipynb ADDED
@@ -0,0 +1,870 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "963e9ae0-ac68-44be-8c7d-fb9842784362",
6
+ "metadata": {},
7
+ "source": [
8
+ "# 4.5 peft简介"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "182b82c4-d484-4c15-a600-03c3b51367ec",
14
+ "metadata": {},
15
+ "source": [
16
+ "**PEFT**(Parameter-Efficient Fine-Tuning,参数高效微调)是一种优化技术,旨在以最小的参数更新实现对大规模预训练模型(如 GPT、BERT 等)的微调。PEFT 技术通过减少微调所需的参数量,显著降低了存储和计算开销,同时保留模型的性能,特别适合资源受限的场景和领域特定任务的定制化。\n",
17
+ "\n",
18
+ "---\n",
19
+ "\n",
20
+ "### **1. 核心思想**\n",
21
+ "传统的微调方式需要更新整个预训练模型的所有参数,PEFT 技术通过只调整少量的参数(如特定层或额外添加的小型模块)实现微调目标,大幅减少了训练开销和存储需求。\n",
22
+ "\n",
23
+ "---\n",
24
+ "\n",
25
+ "### **2. 常见的 PEFT 方法**\n",
26
+ "\n",
27
+ "#### **(1)Adapter 模型**\n",
28
+ "- 在每一层 Transformer 的输出中插入小型适配器模块,仅训练适配器模块的参数。\n",
29
+ "- 原始模型参数保持冻结不变。\n",
30
+ "- 优点:适配器模块参数量小,能适应不同任务。\n",
31
+ "\n",
32
+ "示例方法:\n",
33
+ "- **AdapterFusion**\n",
34
+ "- **MAD-X**\n",
35
+ "\n",
36
+ "---\n",
37
+ "\n",
38
+ "#### **(2)Prefix Tuning**\n",
39
+ "- 在 Transformer 的输入前添加一组可学习的前缀向量,这些前缀与模型的注意力机制交互。\n",
40
+ "- 只调整前缀向量的参数,而不更新原始模型。\n",
41
+ "- 优点:对生成任务效果显著,参数量进一步减少。\n",
42
+ "\n",
43
+ "---\n",
44
+ "\n",
45
+ "#### **(3)LoRA(Low-Rank Adaptation)**\n",
46
+ "- 将预训练模型中的部分权重分解为两个低秩矩阵,仅调整这些低秩矩阵的参数。\n",
47
+ "- 原始权重保持冻结状态。\n",
48
+ "- 优点:参数量极小,计算高效(公式示意见下)。\n",
49
+ " \n",
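+ "用公式示意 LoRA 的前向计算(最小数学 sketch,维度记号为本文假设;$\\alpha$ 即 `lora_alpha`,$r$ 即秩):\n",
+ "\n",
+ "$$h = W_0 x + \\frac{\\alpha}{r}\\, B A x, \\qquad A \\in \\mathbb{R}^{r \\times k},\\; B \\in \\mathbb{R}^{d \\times r}$$\n",
+ "\n",
+ "$W_0 \\in \\mathbb{R}^{d \\times k}$ 保持冻结,训练时仅更新 $A$、$B$;推理时可把 $B A$ 合并回 $W_0$,不增加额外延迟。\n",
+ "\n",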
50
+ "---\n",
51
+ "\n",
52
+ "#### **(4)Prompt Tuning**\n",
53
+ "- 在输入文本中添加可学习的提示(Prompt)。\n",
54
+ "- 适合 NLP 任务中的文本生成、分类等。\n",
55
+ "- 优点:实现简单,易于集成到现有框架(最小配置示意见下)。\n",
56
+ "\n",
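+ "一个最小的 peft 配置示意(`num_virtual_tokens` 取值为假设):\n",
+ "\n",
+ "```python\n",
+ "from peft import PromptTuningConfig, TaskType, get_peft_model\n",
+ "\n",
+ "prompt_config = PromptTuningConfig(\n",
+ " task_type=TaskType.CAUSAL_LM,\n",
+ " num_virtual_tokens=20, # 在输入前拼接 20 个可学习的虚拟 token\n",
+ ")\n",
+ "# model = get_peft_model(base_model, prompt_config)\n",
+ "```\n",
+ "\n",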
57
+ "---\n",
58
+ "\n",
59
+ "### **3. PEFT 的优势**\n",
60
+ "\n",
61
+ "1. **显著减少参数更新量**:\n",
62
+ " - 微调传统的大模型(如 GPT-3)需要更新数百亿参数,而 PEFT 仅需更新百万级别甚至更少的参数。\n",
63
+ "\n",
64
+ "2. **高效存储**:\n",
65
+ " - 每个任务的微调结果只需存储少量额外参数,而不是整个模型。\n",
66
+ "\n",
67
+ "3. **适用多任务**:\n",
68
+ " - 同一预训练模型可以通过不同的 PEFT 模块适配多个任务,无需重新训练。\n",
69
+ "\n",
70
+ "4. **降低计算开销**:\n",
71
+ " - 训练所需的内存和计算显著减少,适合资源有限的环境。\n",
72
+ "\n",
73
+ "---\n",
74
+ "\n",
75
+ "### **4. 应用场景**\n",
76
+ "\n",
77
+ "1. **领域特定任务**:\n",
78
+ " - 医疗、法律、金融等领域微调预训练模型。\n",
79
+ "\n",
80
+ "2. **多任务学习**:\n",
81
+ " - 适配多个任务,复用同一模型的预训练权重。\n",
82
+ "\n",
83
+ "3. **资源受限场景**:\n",
84
+ " - 移动设备、边缘设备上的模型部署。\n",
85
+ "\n",
86
+ "---\n",
87
+ "\n",
88
+ "### **5. Hugging Face PEFT 库**\n",
89
+ "\n",
90
+ "Hugging Face 提供了专门的 PEFT 库,支持多种参数高效微调技术:\n",
91
+ "- **安装**:\n",
92
+ " ```bash\n",
93
+ " pip install peft\n",
94
+ " ```\n",
95
+ "- **使用 LoRA 微调示例**:\n",
96
+ " ```python\n",
97
+ " from transformers import AutoModelForCausalLM, AutoTokenizer\n",
98
+ " from peft import LoraConfig, get_peft_model, TaskType\n",
99
+ "\n",
100
+ " # 加载模型和分词器\n",
101
+ " model_name = \"gpt2\"\n",
102
+ " model = AutoModelForCausalLM.from_pretrained(model_name)\n",
103
+ " tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
104
+ "\n",
105
+ " # 配置 LoRA\n",
106
+ " lora_config = LoraConfig(\n",
107
+ " task_type=TaskType.CAUSAL_LM,\n",
108
+ " r=8,\n",
109
+ " lora_alpha=32,\n",
110
+ " target_modules=[\"q_proj\", \"v_proj\"],\n",
111
+ " lora_dropout=0.1,\n",
112
+ " bias=\"none\"\n",
113
+ " )\n",
114
+ "\n",
115
+ " # 使用 LoRA 微调模型\n",
116
+ " model = get_peft_model(model, lora_config)\n",
117
+ " model.print_trainable_parameters()\n",
118
+ "\n",
119
+ " # 微调代码...\n",
120
+ " ```\n",
121
+ "\n",
122
+ "---\n",
123
+ "\n",
124
+ "### **6. PEFT 的局限性**\n",
125
+ "1. **特定任务限制**:\n",
126
+ " - 在一些复杂任务中,PEFT 方法可能不如全量微调效果好。\n",
127
+ "\n",
128
+ "2. **需要设计合适的模块**:\n",
129
+ " - 不同任务需要选择和设计合适的 PEFT 技术。\n",
130
+ "\n",
131
+ "3. **与模型架构相关**:\n",
132
+ " - PEFT 技术可能需要对模型架构进行一定程度的修改。\n",
133
+ "\n",
134
+ "---\n",
135
+ "\n",
136
+ "### **7. 小结**\n",
137
+ "PEFT 是一个极具潜力的技术,特别适合在有限资源下对大模型进行微调。它在许多领域和任务中已显示出良好的效果,例如 LoRA 和 Adapter 模型已经成为高效微调的主流方法。\n",
138
+ "\n",
139
+ "如果您需要实现高效微调,可以结合 Hugging Face 的 PEFT 库快速上手。"
140
+ ]
141
+ },
142
+ {
143
+ "cell_type": "code",
144
+ "execution_count": 1,
145
+ "id": "5aa3d240-44e1-4811-8f61-d6ff2500a798",
146
+ "metadata": {},
147
+ "outputs": [],
148
+ "source": [
149
+ "import subprocess\n",
150
+ "import os\n",
151
+ "# 设置环境变量, autodl一般区域\n",
152
+ "result = subprocess.run('bash -c \"source /etc/network_turbo && env | grep proxy\"', shell=True, capture_output=True, text=True)\n",
153
+ "output = result.stdout\n",
154
+ "for line in output.splitlines():\n",
155
+ " if '=' in line:\n",
156
+ " var, value = line.split('=', 1)\n",
157
+ " os.environ[var] = value"
158
+ ]
159
+ },
160
+ {
161
+ "cell_type": "markdown",
162
+ "id": "17bdb69d-3f0f-465e-bd60-2047a088e264",
163
+ "metadata": {},
164
+ "source": [
165
+ "如果您不确定模型中有哪些模块可以微调,可以打印模型结构:"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "code",
170
+ "execution_count": 2,
171
+ "id": "41a0c049-9134-4d89-aad0-1aa2241a9fca",
172
+ "metadata": {},
173
+ "outputs": [
174
+ {
175
+ "data": {
176
+ "application/vnd.jupyter.widget-view+json": {
177
+ "model_id": "4becc479adbc472bb7672d49da16aafd",
178
+ "version_major": 2,
179
+ "version_minor": 0
180
+ },
181
+ "text/plain": [
182
+ "generation_config.json: 0%| | 0.00/124 [00:00<?, ?B/s]"
183
+ ]
184
+ },
185
+ "metadata": {},
186
+ "output_type": "display_data"
187
+ },
188
+ {
189
+ "name": "stdout",
190
+ "output_type": "stream",
191
+ "text": [
192
+ "\n",
193
+ "transformer\n",
194
+ "transformer.wte\n",
195
+ "transformer.wpe\n",
196
+ "transformer.drop\n",
197
+ "transformer.h\n",
198
+ "transformer.h.0\n",
199
+ "transformer.h.0.ln_1\n",
200
+ "transformer.h.0.attn\n",
201
+ "transformer.h.0.attn.c_attn\n",
202
+ "transformer.h.0.attn.c_proj\n",
203
+ "transformer.h.0.attn.attn_dropout\n",
204
+ "transformer.h.0.attn.resid_dropout\n",
205
+ "transformer.h.0.ln_2\n",
206
+ "transformer.h.0.mlp\n",
207
+ "transformer.h.0.mlp.c_fc\n",
208
+ "transformer.h.0.mlp.c_proj\n",
209
+ "transformer.h.0.mlp.act\n",
210
+ "transformer.h.0.mlp.dropout\n",
211
+ "transformer.h.1\n",
212
+ "transformer.h.1.ln_1\n",
213
+ "transformer.h.1.attn\n",
214
+ "transformer.h.1.attn.c_attn\n",
215
+ "transformer.h.1.attn.c_proj\n",
216
+ "transformer.h.1.attn.attn_dropout\n",
217
+ "transformer.h.1.attn.resid_dropout\n",
218
+ "transformer.h.1.ln_2\n",
219
+ "transformer.h.1.mlp\n",
220
+ "transformer.h.1.mlp.c_fc\n",
221
+ "transformer.h.1.mlp.c_proj\n",
222
+ "transformer.h.1.mlp.act\n",
223
+ "transformer.h.1.mlp.dropout\n",
224
+ "transformer.h.2\n",
225
+ "transformer.h.2.ln_1\n",
226
+ "transformer.h.2.attn\n",
227
+ "transformer.h.2.attn.c_attn\n",
228
+ "transformer.h.2.attn.c_proj\n",
229
+ "transformer.h.2.attn.attn_dropout\n",
230
+ "transformer.h.2.attn.resid_dropout\n",
231
+ "transformer.h.2.ln_2\n",
232
+ "transformer.h.2.mlp\n",
233
+ "transformer.h.2.mlp.c_fc\n",
234
+ "transformer.h.2.mlp.c_proj\n",
235
+ "transformer.h.2.mlp.act\n",
236
+ "transformer.h.2.mlp.dropout\n",
237
+ "transformer.h.3\n",
238
+ "transformer.h.3.ln_1\n",
239
+ "transformer.h.3.attn\n",
240
+ "transformer.h.3.attn.c_attn\n",
241
+ "transformer.h.3.attn.c_proj\n",
242
+ "transformer.h.3.attn.attn_dropout\n",
243
+ "transformer.h.3.attn.resid_dropout\n",
244
+ "transformer.h.3.ln_2\n",
245
+ "transformer.h.3.mlp\n",
246
+ "transformer.h.3.mlp.c_fc\n",
247
+ "transformer.h.3.mlp.c_proj\n",
248
+ "transformer.h.3.mlp.act\n",
249
+ "transformer.h.3.mlp.dropout\n",
250
+ "transformer.h.4\n",
251
+ "transformer.h.4.ln_1\n",
252
+ "transformer.h.4.attn\n",
253
+ "transformer.h.4.attn.c_attn\n",
254
+ "transformer.h.4.attn.c_proj\n",
255
+ "transformer.h.4.attn.attn_dropout\n",
256
+ "transformer.h.4.attn.resid_dropout\n",
257
+ "transformer.h.4.ln_2\n",
258
+ "transformer.h.4.mlp\n",
259
+ "transformer.h.4.mlp.c_fc\n",
260
+ "transformer.h.4.mlp.c_proj\n",
261
+ "transformer.h.4.mlp.act\n",
262
+ "transformer.h.4.mlp.dropout\n",
263
+ "transformer.h.5\n",
264
+ "transformer.h.5.ln_1\n",
265
+ "transformer.h.5.attn\n",
266
+ "transformer.h.5.attn.c_attn\n",
267
+ "transformer.h.5.attn.c_proj\n",
268
+ "transformer.h.5.attn.attn_dropout\n",
269
+ "transformer.h.5.attn.resid_dropout\n",
270
+ "transformer.h.5.ln_2\n",
271
+ "transformer.h.5.mlp\n",
272
+ "transformer.h.5.mlp.c_fc\n",
273
+ "transformer.h.5.mlp.c_proj\n",
274
+ "transformer.h.5.mlp.act\n",
275
+ "transformer.h.5.mlp.dropout\n",
276
+ "transformer.h.6\n",
277
+ "transformer.h.6.ln_1\n",
278
+ "transformer.h.6.attn\n",
279
+ "transformer.h.6.attn.c_attn\n",
280
+ "transformer.h.6.attn.c_proj\n",
281
+ "transformer.h.6.attn.attn_dropout\n",
282
+ "transformer.h.6.attn.resid_dropout\n",
283
+ "transformer.h.6.ln_2\n",
284
+ "transformer.h.6.mlp\n",
285
+ "transformer.h.6.mlp.c_fc\n",
286
+ "transformer.h.6.mlp.c_proj\n",
287
+ "transformer.h.6.mlp.act\n",
288
+ "transformer.h.6.mlp.dropout\n",
289
+ "transformer.h.7\n",
290
+ "transformer.h.7.ln_1\n",
291
+ "transformer.h.7.attn\n",
292
+ "transformer.h.7.attn.c_attn\n",
293
+ "transformer.h.7.attn.c_proj\n",
294
+ "transformer.h.7.attn.attn_dropout\n",
295
+ "transformer.h.7.attn.resid_dropout\n",
296
+ "transformer.h.7.ln_2\n",
297
+ "transformer.h.7.mlp\n",
298
+ "transformer.h.7.mlp.c_fc\n",
299
+ "transformer.h.7.mlp.c_proj\n",
300
+ "transformer.h.7.mlp.act\n",
301
+ "transformer.h.7.mlp.dropout\n",
302
+ "transformer.h.8\n",
303
+ "transformer.h.8.ln_1\n",
304
+ "transformer.h.8.attn\n",
305
+ "transformer.h.8.attn.c_attn\n",
306
+ "transformer.h.8.attn.c_proj\n",
307
+ "transformer.h.8.attn.attn_dropout\n",
308
+ "transformer.h.8.attn.resid_dropout\n",
309
+ "transformer.h.8.ln_2\n",
310
+ "transformer.h.8.mlp\n",
311
+ "transformer.h.8.mlp.c_fc\n",
312
+ "transformer.h.8.mlp.c_proj\n",
313
+ "transformer.h.8.mlp.act\n",
314
+ "transformer.h.8.mlp.dropout\n",
315
+ "transformer.h.9\n",
316
+ "transformer.h.9.ln_1\n",
317
+ "transformer.h.9.attn\n",
318
+ "transformer.h.9.attn.c_attn\n",
319
+ "transformer.h.9.attn.c_proj\n",
320
+ "transformer.h.9.attn.attn_dropout\n",
321
+ "transformer.h.9.attn.resid_dropout\n",
322
+ "transformer.h.9.ln_2\n",
323
+ "transformer.h.9.mlp\n",
324
+ "transformer.h.9.mlp.c_fc\n",
325
+ "transformer.h.9.mlp.c_proj\n",
326
+ "transformer.h.9.mlp.act\n",
327
+ "transformer.h.9.mlp.dropout\n",
328
+ "transformer.h.10\n",
329
+ "transformer.h.10.ln_1\n",
330
+ "transformer.h.10.attn\n",
331
+ "transformer.h.10.attn.c_attn\n",
332
+ "transformer.h.10.attn.c_proj\n",
333
+ "transformer.h.10.attn.attn_dropout\n",
334
+ "transformer.h.10.attn.resid_dropout\n",
335
+ "transformer.h.10.ln_2\n",
336
+ "transformer.h.10.mlp\n",
337
+ "transformer.h.10.mlp.c_fc\n",
338
+ "transformer.h.10.mlp.c_proj\n",
339
+ "transformer.h.10.mlp.act\n",
340
+ "transformer.h.10.mlp.dropout\n",
341
+ "transformer.h.11\n",
342
+ "transformer.h.11.ln_1\n",
343
+ "transformer.h.11.attn\n",
344
+ "transformer.h.11.attn.c_attn\n",
345
+ "transformer.h.11.attn.c_proj\n",
346
+ "transformer.h.11.attn.attn_dropout\n",
347
+ "transformer.h.11.attn.resid_dropout\n",
348
+ "transformer.h.11.ln_2\n",
349
+ "transformer.h.11.mlp\n",
350
+ "transformer.h.11.mlp.c_fc\n",
351
+ "transformer.h.11.mlp.c_proj\n",
352
+ "transformer.h.11.mlp.act\n",
353
+ "transformer.h.11.mlp.dropout\n",
354
+ "transformer.ln_f\n",
355
+ "lm_head\n"
356
+ ]
357
+ }
358
+ ],
359
+ "source": [
360
+ "from transformers import AutoModelForCausalLM\n",
361
+ "\n",
362
+ "model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n",
363
+ "\n",
364
+ "# 打印所有模块名称\n",
365
+ "for name, module in model.named_modules():\n",
366
+ " print(name)"
367
+ ]
368
+ },
369
+ {
370
+ "cell_type": "code",
371
+ "execution_count": null,
372
+ "id": "37aa6abb-ab1c-4e9c-b968-579dd74044db",
373
+ "metadata": {},
374
+ "outputs": [],
375
+ "source": []
376
+ },
377
+ {
378
+ "cell_type": "markdown",
379
+ "id": "0add2f79-f35c-4638-80bb-0d8a87a9b6a7",
380
+ "metadata": {},
381
+ "source": [
382
+ "在选择 `target_modules` 时,通常根据模块名称来选定模型的特定部分:一般使用最后一个点 `.` 之后的字段名,需要更精确控制时再使用完整路径名。以下是对这些模块的详细分析和选择建议:\n",
383
+ "\n",
384
+ "---\n",
385
+ "\n",
386
+ "### **1. 分析模块结构**\n",
387
+ "\n",
388
+ "从列表中可以看出,GPT-2 的模块层次分为以下几类:\n",
389
+ "\n",
390
+ "1. **Embedding 层**:\n",
391
+ " - `transformer.wte`:词嵌入层(Word Token Embeddings)。\n",
392
+ " - `transformer.wpe`:位置嵌入层(Position Embeddings)。\n",
393
+ "\n",
394
+ "2. **Transformer 编码器层**:\n",
395
+ " - 每层编号为 `transformer.h.<层号>`(共 12 层)。\n",
396
+ " - 每层中包含:\n",
397
+ " - **层归一化**:\n",
398
+ " - `transformer.h.<层号>.ln_1`:第一层归一化。\n",
399
+ " - `transformer.h.<层号>.ln_2`:第二层归一化。\n",
400
+ " - **自注意力模块**:\n",
401
+ " - `transformer.h.<层号>.attn.c_attn`:注意力模块的 Query、Key 和 Value 投影。\n",
402
+ " - `transformer.h.<层号>.attn.c_proj`:注意力的输出投影。\n",
403
+ " - `transformer.h.<层号>.attn.attn_dropout`:注意力的 Dropout。\n",
404
+ " - `transformer.h.<层号>.attn.resid_dropout`:残差的 Dropout。\n",
405
+ " - **前馈网络模块(MLP)**:\n",
406
+ " - `transformer.h.<层号>.mlp.c_fc`:MLP 的第一层全连接。\n",
407
+ " - `transformer.h.<层号>.mlp.c_proj`:MLP 的第二层全连接(输出投影)。\n",
408
+ " - `transformer.h.<层号>.mlp.act`:激活函数(如 GELU)。\n",
409
+ " - `transformer.h.<层号>.mlp.dropout`:MLP 的 Dropout。\n",
410
+ "\n",
411
+ "3. **最终层**:\n",
412
+ " - `transformer.ln_f`:最终层归一化(LayerNorm)。\n",
413
+ " - `lm_head`:语言建模头,用于生成预测的 token 分布。\n",
414
+ "\n",
415
+ "---\n",
416
+ "\n",
417
+ "### **2. 如何选择 `target_modules`**\n",
418
+ "\n",
419
+ "#### **(1)常见目标模块**\n",
420
+ "- `transformer.h.<层号>.attn.c_attn`:对自注意力模块的 Query、Key 和 Value 投影层微调。\n",
421
+ "- `transformer.h.<层号>.attn.c_proj`:对注意力输出的投影层微调。\n",
422
+ "- `transformer.h.<层号>.mlp.c_fc`:对前馈网络的输入全连接层微调。\n",
423
+ "- `transformer.h.<层号>.mlp.c_proj`:对前馈网络的输出投影层微调。\n",
424
+ "\n",
425
+ "#### **(2)推荐设置**\n",
426
+ "- **文本生成任务**:\n",
427
+ " ```python\n",
428
+ " target_modules = [\"c_attn\", \"c_proj\"] # 列表项按模块名后缀匹配,自动覆盖所有层\n",
429
+ " ```\n",
430
+ " 解释:\n",
431
+ " - `c_attn`:调整 Query、Key、Value 的生成。\n",
+ " - `c_proj`:调整注意力输出(该后缀也会命中 `mlp.c_proj`,如需排除可改用正则字符串)。\n",
433
+ "\n",
434
+ "- **文本分类任务**:\n",
435
+ " ```python\n",
436
+ " target_modules = [\"c_attn\"]\n",
437
+ " ```\n",
438
+ " 解释:\n",
439
+ " - 微调自注意力模块最重要的部分即可。\n",
440
+ "\n",
441
+ "- **特定任务需要更细粒度控制**:\n",
442
+ " - 仅微调某几层:\n",
443
+ " ```python\n",
444
+ " target_modules = [\"transformer.h.0.attn.c_attn\", \"transformer.h.0.mlp.c_fc\"]\n",
445
+ " ```\n",
446
+ "\n",
447
+ "#### **(3)一次匹配所有层**\n",
+ "peft 的 `target_modules` 传入字符串列表时按模块名后缀匹配,传入单个字符串时则按正则表达式匹配,因此无需逐层列举:\n",
+ "- `[\"c_attn\"]`:所有层的 Query、Key、Value 组合投影(后缀匹配)。\n",
+ "- `r\"transformer\\\\.h\\\\.\\\\d+\\\\.mlp\\\\..*\"`:所有层的 MLP 模块(正则字符串)。\n",
451
+ "\n",
452
+ "---\n",
453
+ "\n",
454
+ "### **3. 示例:指定多个模块**\n",
455
+ "\n",
456
+ "```python\n",
457
+ "lora_config = LoraConfig(\n",
458
+ " task_type=TaskType.CAUSAL_LM,\n",
459
+ " r=8,\n",
460
+ " lora_alpha=32,\n",
461
+ " target_modules=[\"c_attn\", \"c_fc\"], # 列表按后缀匹配,自动覆盖所有层\n",
465
+ " lora_dropout=0.1,\n",
466
+ " bias=\"none\"\n",
467
+ ")\n",
468
+ "```\n",
469
+ "\n",
470
+ "- 这表示对所有层中名称以 `c_attn` 或 `c_fc` 结尾的模块进行 LoRA 微调。\n",
471
+ "\n",
472
+ "---\n",
473
+ "\n",
474
+ "### **4. 小提示:如何确定适合的模块**\n",
475
+ "\n",
476
+ "1. **任务相关性**:\n",
477
+ " - 文本生成:优先选择自注意力模块(如 `c_attn`)。\n",
478
+ " - 文本分类:通常需要全局语义表示,选择 `attn.c_attn` 或 `mlp.c_fc`。\n",
479
+ "\n",
480
+ "2. **性能与资源平衡**:\n",
481
+ " - 如果显存有限,可以只微调部分层。例如,仅选择浅层和深层的模块:\n",
482
+ " ```python\n",
483
+ " target_modules = [\"transformer.h.0.attn.c_attn\", \"transformer.h.11.attn.c_attn\"]\n",
484
+ " ```\n",
485
+ "\n",
486
+ "3. **打印模块名称以调试**:\n",
487
+ " - 确保选择的 `target_modules` 在模型中实际存在:\n",
488
+ " ```python\n",
489
+ " for name, _ in model.named_modules():\n",
490
+ " if \"c_attn\" in name:\n",
491
+ " print(name)\n",
492
+ " ```\n",
493
+ "\n",
494
+ "---\n",
495
+ "\n",
496
+ "### **建议**\n",
497
+ "- 一般情况下,`c_attn` 和 `c_proj` 是首选模块。\n",
498
+ "- 使用 `transformer.h.*` 通配符可以轻松指定多层。\n",
499
+ "- 根据任务需求和资源限制灵活调整目标模块,以实现最佳性能和效率。"
500
+ ]
501
+ },
502
+ {
503
+ "cell_type": "code",
504
+ "execution_count": null,
505
+ "id": "b4a41750-420f-49c4-845d-69db394794f9",
506
+ "metadata": {},
507
+ "outputs": [],
508
+ "source": []
509
+ },
510
+ {
511
+ "cell_type": "markdown",
512
+ "id": "10c99eb9-8007-4297-972e-7be71768c9c3",
513
+ "metadata": {},
514
+ "source": [
515
+ "以下是对 `LoraConfig` 配置的更详细解释,特别是如何设置微调哪些参数、冻结哪些参数,以及一般如何选择这些设置:\n",
516
+ "\n",
517
+ "---\n",
518
+ "\n",
519
+ "### **1. `LoraConfig` 参数解析**\n",
520
+ "\n",
521
+ "```python\n",
522
+ "lora_config = LoraConfig(\n",
523
+ " task_type=TaskType.SEQ_CLS, # 序列分类任务\n",
524
+ " r=8, # 降低矩阵秩\n",
525
+ " lora_alpha=32, # LoRA 的 alpha 超参数\n",
526
+ " target_modules=[\"c_attn\"], # GPT-2 中的自注意力模块\n",
527
+ " lora_dropout=0.1, # dropout 概率\n",
528
+ " bias=\"none\", # 是否微调偏置参数\n",
529
+ ")\n",
530
+ "```\n",
531
+ "\n",
532
+ "#### **(1)`task_type`**\n",
533
+ "- 定义任务类型,用于指导 PEFT 的具体行为。\n",
534
+ "- **常见选项**:\n",
535
+ " - `TaskType.CAUSAL_LM`:自回归语言建模(如 GPT 系列模型)。\n",
536
+ " - `TaskType.SEQ_CLS`:序列分类(如情感分析)。\n",
537
+ " - `TaskType.TOKEN_CLS`:标注任务(如命名实体识别)。\n",
538
+ " - `TaskType.SEQ_2_SEQ_LM`:序列到序列任务(如翻译、摘要)。\n",
539
+ "\n",
540
+ "**当前设置**:\n",
541
+ "- `TaskType.SEQ_CLS` 表示目标是文本分类任务。\n",
542
+ "\n",
543
+ "---\n",
544
+ "\n",
545
+ "#### **(2)`r`**\n",
546
+ "- 表示 LoRA 的 **秩**(rank),是降低矩阵秩的核心参数。\n",
547
+ "- LoRA 通过将模型的权重分解为两个低秩矩阵(`A` 和 `B`),只更新这两个矩阵。\n",
548
+ "- `r` 的值越大,微调能力越强,但需要的额外参数也越多。\n",
549
+ "- **典型范围**:`4` 至 `64`,大多数任务中 `8` 或 `16` 是常用值。\n",
550
+ "\n",
551
+ "**当前设置**:\n",
552
+ "- `r=8` 表示使用低秩分解,并微调 8 维的参数矩阵。\n",
553
+ "\n",
554
+ "---\n",
555
+ "\n",
556
+ "#### **(3)`lora_alpha`**\n",
557
+ "- 是 LoRA 的一个缩放因子,用于调节两个低秩矩阵的更新速率。\n",
558
+ "- **公式**:实际更新 = LoRA 输出 × `lora_alpha / r`\n",
559
+ "- **典型范围**:`16` 至 `128`,较大任务中可以选择更高的值。\n",
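+ "- **算例**:按本文设置 `r=8`、`lora_alpha=32`,等效缩放系数为 32 / 8 = 4。\n",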
560
+ "\n",
561
+ "**当前设置**:\n",
562
+ "- `lora_alpha=32`,表示适中幅度的更新速率。\n",
563
+ "\n",
564
+ "---\n",
565
+ "\n",
566
+ "#### **(4)`target_modules`**\n",
567
+ "- 指定要应用 LoRA 微调的模块。\n",
568
+ "- **常见选择**:\n",
569
+ " - 对 Transformer 模型中的 **注意力模块**(如 `query`、`key`、`value`)进行微调,因为这些模块对任务性能影响较大。\n",
570
+ " - 对 GPT-2,通常选择 `c_attn`(GPT-2 中负责自注意力机制的组合模块)。\n",
571
+ "\n",
572
+ "**当前设置**:\n",
573
+ "- `target_modules=[\"c_attn\"]` 表示只对 GPT-2 的自注意力模块 `c_attn` 应用 LoRA。\n",
574
+ "\n",
575
+ "---\n",
576
+ "\n",
577
+ "#### **(5)`lora_dropout`**\n",
578
+ "- 表示 LoRA 层的 dropout 概率,用于防止过拟合。\n",
579
+ "- **典型范围**:`0.0` 至 `0.1`,视任务复杂性而定。\n",
580
+ "\n",
581
+ "**当前设置**:\n",
582
+ "- `lora_dropout=0.1`,表示有 10% 的概率随机丢弃 LoRA 层的输出。\n",
583
+ "\n",
584
+ "---\n",
585
+ "\n",
586
+ "#### **(6)`bias`**\n",
587
+ "- 决定是否微调偏置参数。\n",
588
+ "- **选项**:\n",
589
+ " - `\"none\"`:不微调任何偏置。\n",
590
+ " - `\"all\"`:微调所有偏置。\n",
591
+ " - `\"lora_only\"`:只微调 LoRA 层的偏置。\n",
592
+ "\n",
593
+ "**当前设置**:\n",
594
+ "- `bias=\"none\"`,表示所有偏置参数保持冻结。\n",
595
+ "\n",
596
+ "---\n",
597
+ "\n",
598
+ "### **2. 微调哪些参数,冻结哪些参数**\n",
599
+ "\n",
600
+ "LoRA 的核心思想是通过 **分解矩阵**,只更新少量参数,而冻结模型的大部分参数。以下是常见设置的说明:\n",
601
+ "\n",
602
+ "#### **微调的参数**\n",
603
+ "- LoRA 通过 `target_modules` 指定的模块,例如:\n",
604
+ " - GPT-2 的 `c_attn`(自注意力模块)。\n",
605
+ " - BERT 的 `query` 和 `key`。\n",
606
+ "- 这些模块是模型中对性能贡献最大的部分,通过微调这些模块,任务性能可以显著提升。\n",
607
+ "\n",
608
+ "#### **冻结的参数**\n",
609
+ "- 除了 `target_modules` 中指定的参数外,所有其他模型参数默认冻结,包括:\n",
610
+ " - 预训练权重的绝大部分。\n",
611
+ " - 偏置参数(如果 `bias=\"none\"`)。\n",
612
+ "\n",
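+ "可以用下面的小片段核对哪些参数真正可训练(假设 `model` 已被 `get_peft_model` 包装):\n",
+ "\n",
+ "```python\n",
+ "trainable = [n for n, p in model.named_parameters() if p.requires_grad]\n",
+ "print(f\"可训练张量数: {len(trainable)}\")\n",
+ "print(trainable[:5]) # 名称形如 ...attn.c_attn.lora_A.default.weight\n",
+ "```\n",
+ "\n",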
613
+ "---\n",
614
+ "\n",
615
+ "### **3. 一般如何设置**\n",
616
+ "\n",
617
+ "#### **(1)针对不同任务调整**\n",
618
+ "- **文本分类任务**:\n",
619
+ " - 优先选择自注意力模块(如 `c_attn`)作为 `target_modules`。\n",
620
+ " - `r=8` 或 `r=16` 是常见选择,适中计算开销。\n",
621
+ " - 设置适当的 dropout(如 `lora_dropout=0.1`)以防止过拟合。\n",
622
+ " \n",
623
+ "- **语言生成任务**:\n",
624
+ " - 对 GPT-2 或 GPT-3,选择 `q_proj` 和 `v_proj`(query 和 value 投影模块)。\n",
625
+ " - `r=16` 或更高,适应生成任务的高复杂性。\n",
626
+ "\n",
627
+ "- **命名实体识别任务**:\n",
628
+ " - 优先选择 `q_proj` 和 `k_proj`(query 和 key 模块)。\n",
629
+ "\n",
630
+ "#### **(2)参数量与显存的权衡**\n",
631
+ "- 如果显存有限,减少 `r` 的值。\n",
632
+ "- 对小型任务,`r=4` 或 `r=8` 通常已经足够。\n",
633
+ "\n",
634
+ "#### **(3)偏置设置**\n",
635
+ "- 偏置参数的影响较小,在大多数情况下,可以选择 `bias=\"none\"` 保持冻结。\n",
636
+ "- 对非常依赖偏置的任务(如生成风格微调),可以尝试 `bias=\"lora_only\"`。\n",
637
+ "\n",
638
+ "---\n",
639
+ "\n",
640
+ "### **4. 示例:如何选择目标模块**\n",
641
+ "\n",
642
+ "#### **GPT-2**\n",
643
+ "对 GPT-2 来说,以下模块通常是微调的目标:\n",
644
+ "- **`c_attn`**:注意力模块的组合层。\n",
645
+ "- **`q_proj` 和 `v_proj`**:Query 和 Value 的线性投影。\n",
646
+ "\n",
647
+ "#### **BERT**\n",
648
+ "对 BERT 来说,以下模块通常是微调的目标:\n",
649
+ "- **`query`**:Attention 的 Query 模块。\n",
650
+ "- **`key`**:Attention 的 Key 模块。\n",
651
+ "\n",
652
+ "---\n",
653
+ "\n",
654
+ "### **5. 总结建议**\n",
655
+ "- **微调的参数**:优先选择模型中注意力相关模块。\n",
656
+ "- **冻结的参数**:大部分参数默认冻结以节省显存。\n",
657
+ "- **配置选择**:根据任务复杂性调整 `r` 和 `target_modules`。\n",
658
+ "- **推荐起点**:\n",
659
+ " - 文本分类:`target_modules=[\"c_attn\"]`, `r=8`, `lora_dropout=0.1`。\n",
660
+ " - 文本生成:`target_modules=[\"q_proj\", \"v_proj\"]`, `r=16`, `lora_dropout=0.1`。\n",
661
+ "\n",
662
+ "通过这些设置,LoRA 可以在参数量极小的情况下实现高效微调,适合各种任务场景。"
663
+ ]
664
+ },
665
+ {
666
+ "cell_type": "code",
667
+ "execution_count": null,
668
+ "id": "26d9f362-18cc-471f-b208-f29a6933c06a",
669
+ "metadata": {},
670
+ "outputs": [],
671
+ "source": [
672
+ "from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments, Trainer\n",
673
+ "from peft import LoraConfig, get_peft_model, TaskType\n",
674
+ "from datasets import load_dataset\n",
675
+ "from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n",
676
+ "\n",
677
+ "# **1. 加载模型和分词器**\n",
678
+ "model_name = \"gpt2\" # 基础模型\n",
679
+ "num_labels = 2 # 二分类任务\n",
680
+ "model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)\n",
681
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
682
+ "tokenizer.pad_token = tokenizer.eos_token # 设置 pad_token 为 eos_token\n",
+ "model.config.pad_token_id = tokenizer.pad_token_id # GPT-2 默认没有 pad_token,批量分类前必须设置\n",
683
+ "\n",
684
+ "# **2. 定义数据集**\n",
685
+ "# 示例数据集:dna_promoter_300\n",
686
+ "dataset = load_dataset(\"dnagpt/dna_promoter_300\")['train'].train_test_split(test_size=0.1)\n",
687
+ "\n",
688
+ "# **3. 数据预处理**\n",
689
+ "def preprocess_function(examples):\n",
690
+ " examples['label'] = [int(item) for item in examples['label']]\n",
691
+ " return tokenizer(\n",
692
+ " examples[\"sequence\"], truncation=True, padding=\"max_length\", max_length=128\n",
693
+ " )\n",
694
+ "\n",
695
+ "tokenized_datasets = dataset.map(preprocess_function, batched=True)\n",
696
+ "tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\") # Hugging Face Trainer 要求标签列名为 'labels'\n",
697
+ "\n",
698
+ "# **4. 划分数据集**\n",
699
+ "train_dataset = tokenized_datasets[\"train\"]\n",
700
+ "test_dataset = tokenized_datasets[\"test\"]\n",
701
+ "\n",
702
+ "# **5. 配置 LoRA**\n",
703
+ "lora_config = LoraConfig(\n",
704
+ " task_type=TaskType.SEQ_CLS, # 序列分类任务\n",
705
+ " r=8, # 降低矩阵秩\n",
706
+ " lora_alpha=32, # LoRA 的 alpha 超参数\n",
707
+ " target_modules=[\"c_attn\"], # GPT-2 中的自注意力模块\n",
708
+ " lora_dropout=0.1, # dropout 概率\n",
709
+ " bias=\"none\", # 是否微调偏置参数\n",
710
+ ")\n",
711
+ "\n",
712
+ "# 使用 LoRA 包装模型\n",
713
+ "model = get_peft_model(model, lora_config)\n",
714
+ "model.print_trainable_parameters() # 打印可训练的参数信息\n",
715
+ "\n",
716
+ "# **6. 计算指标**\n",
717
+ "def compute_metrics(eval_pred):\n",
718
+ " predictions, labels = eval_pred\n",
719
+ " preds = predictions.argmax(axis=-1)\n",
720
+ " precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average=\"binary\")\n",
721
+ " acc = accuracy_score(labels, preds)\n",
722
+ " return {\"accuracy\": acc, \"precision\": precision, \"recall\": recall, \"f1\": f1}\n",
723
+ "\n",
724
+ "# **7. 定义训练参数**\n",
725
+ "training_args = TrainingArguments(\n",
726
+ " output_dir=\"./gpt2_lora_text_classification\", # 模型保存路径\n",
727
+ " evaluation_strategy=\"epoch\", # 每个 epoch 评估一次\n",
728
+ " save_strategy=\"epoch\", # 每个 epoch 保存一次\n",
729
+ " learning_rate=2e-5, # 学习率\n",
730
+ " per_device_train_batch_size=8, # 每设备的批量大小\n",
731
+ " per_device_eval_batch_size=8, # 每设备评估的批量大小\n",
732
+ " num_train_epochs=3, # 训练轮数\n",
733
+ " weight_decay=0.01, # 权重衰减\n",
734
+ " logging_dir=\"./logs\", # 日志路径\n",
735
+ " fp16=True, # 启用混合精度训练\n",
736
+ " save_total_limit=2, # 保留最多两个检查点\n",
737
+ " load_best_model_at_end=True, # 加载最佳模型\n",
738
+ " metric_for_best_model=\"accuracy\", # 根据准确率选择最佳模型\n",
739
+ " greater_is_better=True,\n",
740
+ ")\n",
741
+ "\n",
742
+ "# **8. 定义 Trainer**\n",
743
+ "trainer = Trainer(\n",
744
+ " model=model,\n",
745
+ " args=training_args,\n",
746
+ " train_dataset=train_dataset,\n",
747
+ " eval_dataset=test_dataset,\n",
748
+ " tokenizer=tokenizer,\n",
749
+ " compute_metrics=compute_metrics,\n",
750
+ ")\n",
751
+ "\n",
752
+ "# **9. 开始训练**\n",
753
+ "trainer.train()\n",
754
+ "\n",
755
+ "# **10. 保存模型**\n",
756
+ "model.save_pretrained(\"./gpt2_lora_text_classification\")\n",
757
+ "tokenizer.save_pretrained(\"./gpt2_lora_text_classification\")\n",
758
+ "\n",
759
+ "print(\"训练完成,模型已保存至 ./gpt2_lora_text_classification\")"
760
+ ]
761
+ },
762
+ {
763
+ "cell_type": "code",
764
+ "execution_count": null,
765
+ "id": "49a60fed-3a7d-4608-98b1-b4e313b94dbb",
766
+ "metadata": {},
767
+ "outputs": [],
768
+ "source": [
769
+ "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
770
+ "from peft import PeftModel\n",
771
+ "\n",
772
+ "# 加载分词器\n",
773
+ "model_path = \"./gpt2_lora_text_classification\"\n",
774
+ "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
775
+ "\n",
776
+ "# 加载微调后的 PEFT 模型\n",
777
+ "base_model = AutoModelForSequenceClassification.from_pretrained(\"gpt2\", num_labels=2)\n",
778
+ "base_model.config.pad_token_id = tokenizer.pad_token_id # 与训练时一致,批量推理需要 pad_token\n",
+ "model = PeftModel.from_pretrained(base_model, model_path)"
779
+ ]
780
+ },
781
+ {
782
+ "cell_type": "code",
783
+ "execution_count": null,
784
+ "id": "3c0d8f02-c3dc-4961-8b3a-50eefc5f9448",
785
+ "metadata": {},
786
+ "outputs": [],
787
+ "source": [
788
+ "import torch\n",
789
+ "\n",
790
+ "def predict(texts, model, tokenizer):\n",
791
+ " \"\"\"\n",
792
+ " 使用微调后的 PEFT 模型进行推理。\n",
793
+ " \n",
794
+ " Args:\n",
795
+ " texts (list of str): 待分类的文本列表。\n",
796
+ " model (PeftModel): 微调后的模型。\n",
797
+ " tokenizer (AutoTokenizer): 分词器。\n",
798
+ " \n",
799
+ " Returns:\n",
800
+ " list of dict: 每个文本的预测结果,包括 logits 和预测的类别标签。\n",
801
+ " \"\"\"\n",
802
+ " # 对输入文本进行分词和编码\n",
803
+ " inputs = tokenizer(\n",
804
+ " texts,\n",
805
+ " padding=True,\n",
806
+ " truncation=True,\n",
807
+ " max_length=512,\n",
808
+ " return_tensors=\"pt\"\n",
809
+ " )\n",
810
+ " \n",
811
+ " # 将输入数据移动到模型的设备上(CPU/GPU)\n",
812
+ " inputs = {key: value.to(model.device) for key, value in inputs.items()}\n",
813
+ " \n",
814
+ " # 模型推理\n",
815
+ " model.eval()\n",
816
+ " with torch.no_grad():\n",
817
+ " outputs = model(**inputs)\n",
818
+ " \n",
819
+ " # 获取 logits 并计算预测类别\n",
820
+ " logits = outputs.logits\n",
821
+ " probs = torch.nn.functional.softmax(logits, dim=-1)\n",
822
+ " predictions = torch.argmax(probs, dim=-1)\n",
823
+ " \n",
824
+ " # 返回每个文本的预测结果\n",
825
+ " results = [\n",
826
+ " {\"text\": text, \"logits\": logit.tolist(), \"predicted_class\": int(pred)}\n",
827
+ " for text, logit, pred in zip(texts, logits, predictions)\n",
828
+ " ]\n",
829
+ " return results\n"
830
+ ]
831
+ },
832
+ {
833
+ "cell_type": "code",
834
+ "execution_count": null,
835
+ "id": "9c0cfe65-f4f3-4274-a4f4-1ac13725b15a",
836
+ "metadata": {},
837
+ "outputs": [],
838
+ "source": [
839
+ "# 示例输出(仅为格式示意,logits 数值为假设值,非实际运行结果):\n",
+ "# Text: This movie was fantastic! I loved every part of it.\n",
+ "# Predicted Class: 1\n",
+ "# Logits: [-2.345, 3.567]\n",
+ "#\n",
+ "# Text: The plot was terrible and the acting was worse.\n",
+ "# Predicted Class: 0\n",
+ "# Logits: [4.123, -1.234]\n"
846
+ ]
847
+ }
848
+ ],
849
+ "metadata": {
850
+ "kernelspec": {
851
+ "display_name": "Python 3 (ipykernel)",
852
+ "language": "python",
853
+ "name": "python3"
854
+ },
855
+ "language_info": {
856
+ "codemirror_mode": {
857
+ "name": "ipython",
858
+ "version": 3
859
+ },
860
+ "file_extension": ".py",
861
+ "mimetype": "text/x-python",
862
+ "name": "python",
863
+ "nbconvert_exporter": "python",
864
+ "pygments_lexer": "ipython3",
865
+ "version": "3.12.3"
866
+ }
867
+ },
868
+ "nbformat": 4,
869
+ "nbformat_minor": 5
870
+ }
04-gene-sft/6-llama-continue-train.ipynb ADDED
@@ -0,0 +1,491 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "1e6d4978-4f0f-4268-aa23-d864857bd6c8",
6
+ "metadata": {},
7
+ "source": [
8
+ "# 4.6 基于llama的基因大模型持续预训练"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "2c201732-e736-463c-8446-637bf517479f",
14
+ "metadata": {},
15
+ "source": [
16
+ "LLaMA(**Large Language Model Meta AI**)是由 Meta(Facebook)开发的一系列大型语言模型,专注于提供高性能和高效的大语言模型,面向学术研究和开发社区。LLaMA 系列主要强调训练效率、模型性能和对计算资源的高效利用,是 GPT 系列模型的有力竞争者之一。\n",
17
+ "\n",
18
+ "---\n",
19
+ "\n",
20
+ "### **1. LLaMA 模型概述**\n",
21
+ "\n",
22
+ "#### **1.1 LLaMA 1**\n",
23
+ "- **发布**:2023 年 2 月。\n",
24
+ "- **模型参数规模**:\n",
25
+ " - 7B(70 亿)\n",
26
+ " - 13B(130 亿)\n",
27
+ " - 33B(330 亿)\n",
28
+ " - 65B(650 亿)\n",
29
+ "- **特点**:\n",
30
+ " - 专注于效率:与 GPT-3 等模型相比,LLaMA 在相同的训练成本下实现了更高的性能。\n",
31
+ " - 针对研究开放:提供预训练模型权重供研究使用。\n",
32
+ " - 使用高质量的数据:模型训练使用大量从网络中筛选的高质量文本数据,包括维基百科、书籍和其他高质量来源。\n",
33
+ "- **性能**:\n",
34
+ " - 在许多 NLP 任务中,LLaMA 的性能超过 GPT-3 和其他同类模型。\n",
35
+ " - 参数规模较小的版本(如 LLaMA-13B)性能可与 GPT-3(175B 参数)媲美。\n",
36
+ "\n",
37
+ "#### **1.2 LLaMA 2**\n",
38
+ "- **发布**:2023 年 7 月。\n",
39
+ "- **改进**:\n",
40
+ " - 增强的训练数据:相比 LLaMA 1,使用了更多的高质量数据。\n",
41
+ " - 引入微调版本:发布了开箱即用的对话模型(LLaMA 2-Chat)。\n",
42
+ " - 更好的开源支持:LLaMA 2 在商业用途上比 LLaMA 1 更加开放。\n",
43
+ "- **模型参数规模**:\n",
44
+ " - 7B(70 亿)\n",
45
+ " - 13B(130 亿)\n",
46
+ " - 70B(700 亿)\n",
47
+ "- **性能**:\n",
48
+ " - LLaMA 2 的性能相比 LLaMA 1 有显著提升。\n",
49
+ " - LLaMA 2-Chat 在对话任务中的表现优于许多现有开源模型。\n",
50
+ " - 在多个标准基准(如 MMLU)上领先于同期的主流开源模型。\n",
51
+ "\n",
52
+ "---\n",
53
+ "\n",
54
+ "### **2. LLaMA 的关键技术特点**\n",
55
+ "\n",
56
+ "#### **2.1 高效的架构设计**\n",
57
+ "- 基于 Transformer 架构。\n",
58
+ "- 针对训练效率和推理速度进行了优化,适合研究和开发。\n",
59
+ "\n",
60
+ "#### **2.2 模型压缩**\n",
61
+ "- 提供更小的参数规模(如 7B 和 13B),以便在更低的计算资源上运行。\n",
62
+ "- 在性能与参数量之间实现了很好的平衡。\n",
63
+ "\n",
64
+ "#### **2.3 训练数据**\n",
65
+ "- 使用从互联网中提取的高质量数据,注重数据清洗和筛选,避免低质量文本对模型的负面影响。\n",
66
+ "\n",
67
+ "#### **2.4 微调能力**\n",
68
+ "- 支持指令微调(Instruction Tuning)和 RLHF(基于人类反馈的强化学习),特别是在 LLaMA 2-Chat 模型中表现优异。\n",
69
+ "\n",
70
+ "---\n",
71
+ "\n",
72
+ "### **3. LLaMA 的性能对比**\n",
73
+ "\n",
74
+ "#### **与 GPT-3 比较**\n",
75
+ "- LLaMA 1-13B 参数模型在许多任务上的性能接近 GPT-3-175B。\n",
76
+ "- LLaMA 2-70B 在多个任务上超过 GPT-3。\n",
77
+ "\n",
78
+ "#### **与其他开源模型比较**\n",
79
+ "- LLaMA 2 在许多基准测试中优于其他开源模型(如 Falcon 和 MPT)。\n",
80
+ "- LLaMA 2-Chat 提供了与 ChatGPT 类似的对话能力,适用于对话任务。\n",
81
+ "\n",
82
+ "---\n",
83
+ "\n",
84
+ "### **4. 应用场景**\n",
85
+ "\n",
86
+ "1. **研究**:\n",
87
+ " - 开源权重适合学术研究,推动了对大语言模型的进一步探索。\n",
88
+ "\n",
89
+ "2. **对话系统**:\n",
90
+ " - LLaMA 2-Chat 专为对话任务设计,适合开发智能客服、聊天机器人等应用。\n",
91
+ "\n",
92
+ "3. **生成任务**:\n",
93
+ " - 支持文本生成、补全、摘要等任务。\n",
94
+ "\n",
95
+ "4. **微调与定制**:\n",
96
+ " - 可以基于特定领域数据进行微调,如医学、法律、教育等领域的专用模型。\n",
97
+ "\n",
98
+ "---\n",
99
+ "\n",
100
+ "### **5. 开源与获取方式**\n",
101
+ "\n",
102
+ "#### **1. 开源**\n",
103
+ "- LLaMA 1:需要申请权限才能获得模型权重。\n",
104
+ "- LLaMA 2:更加开放,允许商业用途,模型和权重可以通过 Meta 的合作平台获取(如 Hugging Face 和 AWS)。\n",
105
+ "\n",
106
+ "#### **2. 下载与使用**\n",
107
+ "使用 Hugging Face 加载模型:\n",
108
+ "```python\n",
109
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
110
+ "\n",
111
+ "model_name = \"meta-llama/Llama-2-7b-hf\" # 替换为具体模型\n",
112
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
113
+ "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
114
+ "\n",
115
+ "# 使用模型生成文本\n",
116
+ "inputs = tokenizer(\"Hello, how are you?\", return_tensors=\"pt\")\n",
117
+ "outputs = model.generate(**inputs, max_length=50)\n",
118
+ "print(tokenizer.decode(outputs[0], skip_special_tokens=True))\n",
119
+ "```\n",
120
+ "\n",
121
+ "---\n",
122
+ "\n",
123
+ "### **6. 总结**\n",
124
+ "\n",
125
+ "#### **优势**\n",
126
+ "- **高性能**:在多个基准任务上表现出色。\n",
127
+ "- **高效训练**:小参数模型能与大模型媲美。\n",
128
+ "- **开放性**:LLaMA 2 提供了较为开放的商用许可。\n",
129
+ "\n",
130
+ "#### **局限**\n",
131
+ "- 模型需要高质量数据和强大算力训练,对推理设备也有一定要求。\n",
132
+ "\n",
133
+ "LLaMA 系列以其高效和开放的特点,为大模型研究和应用带来了强大动力,是当前大语言模型生态的重要组成部分。"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "markdown",
138
+ "id": "7fb0d648-f891-47b9-a644-af5263fa9718",
139
+ "metadata": {},
140
+ "source": [
141
+ "---\n",
142
+ "---"
143
+ ]
144
+ },
145
+ {
146
+ "cell_type": "markdown",
147
+ "id": "8b3c9ebb-213b-4dc4-a712-5a819fea3197",
148
+ "metadata": {},
149
+ "source": [
150
+ "**大模型的持续预训练**(Continual Pretraining of Large Models)是指在基础预训练模型(如 GPT、BERT 等)的基础上,通过引入新的数据或特定领域的数据继续进行预训练的过程。这一过程旨在让模型在特定场景或任务中表现更好,同时保留其通用能力。\n",
151
+ "\n",
152
+ "---\n",
153
+ "\n",
154
+ "### **1. 持续预训练的概念**\n",
155
+ "\n",
156
+ "持续预训练是一种在通用大模型的预训练基础上,进一步优化和适配模型的方法,主要包括以下两种场景:\n",
157
+ "1. **领域适配**:\n",
158
+ " - 将预训练模型在特定领域的数据上继续训练,使其对该领域的语料理解更深刻,例如法律、医学、金融等领域。\n",
159
+ "2. **性能优化**:\n",
160
+ " - 通过引入更多的通用数据或多样化的数据类型,扩展模型的通用能力,提高性能。\n",
161
+ "\n",
162
+ "---\n",
163
+ "\n",
164
+ "### **2. 持续预训练的目标**\n",
165
+ "\n",
166
+ "1. **提升领域性能**:\n",
167
+ " - 在特定领域任务上,模型能够更好地理解特定领域的语言模式和知识。\n",
168
+ " \n",
169
+ "2. **增强模型鲁棒性**:\n",
170
+ " - 通过引入新的数据或增强数据多样性,使模型对未见数据表现更稳定。\n",
171
+ "\n",
172
+ "3. **优化资源利用**:\n",
173
+ " - 通过复用已有的大模型权重,只需训练少量额外步骤,避免从零开始重新训练模型。\n",
174
+ "\n",
175
+ "---\n",
176
+ "\n",
177
+ "### **3. 持续预训练的步骤**\n",
178
+ "\n",
179
+ "#### **(1)数据准备**\n",
180
+ "- **领域数据**:针对特定领域(如医学、法律、科技)收集高质量语料。\n",
181
+ "- **新语料整合**:补充模型未见过的多样化语料。\n",
182
+ "- **数据清洗**:确保数据无噪声、语言风格一致。\n",
183
+ "\n",
184
+ "#### **(2)模型初始化**\n",
185
+ "- 使用现有的预训练模型作为初始权重,例如 Hugging Face 提供的 GPT-2 或 BERT 模型。\n",
186
+ "\n",
187
+ "#### **(3)训练设置**\n",
188
+ "- **超参数调整**:\n",
189
+ " - 通常使用较小的学习率(例如 `1e-5` 或 `2e-5`)以避免破坏已有的知识。\n",
190
+ "- **训练策略**:\n",
191
+ " - 冻结部分参数(如嵌入层或前几层)以保留通用能力,仅调整高层或新加入的部分。\n",
192
+ "\n",
193
+ "#### **(4)评估和验证**\n",
194
+ "- 使用领域任务的数据集对模型进行评估,验证其在目标任务中的改进效果。\n",
195
+ "\n",
196
+ "---\n",
197
+ "\n",
198
+ "### **4. 持续预训练的常见方法**\n",
199
+ "\n",
200
+ "#### **(1)全量持续预训练**\n",
201
+ "- 对整个模型的参数进行调整。\n",
202
+ "- **优点**:适合较大规模的新数据训练,能显著提升领域性能。\n",
203
+ "- **缺点**:计算资源需求大,可能导致模型过拟合。\n",
204
+ "\n",
205
+ "#### **(2)冻结部分参数**\n",
206
+ "- 冻结低层参数,仅微调高层(实现示意见下方代码)。\n",
207
+ "- **优点**:保留通用知识,减少计算开销。\n",
208
+ "- **缺点**:对领域特定知识的适配可能不足。\n",
209
+ "\n",
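+ "上面提到的“冻结低层”可按如下最小示意实现(以 HF 版 LLaMA 的参数命名为例,`N_FROZEN` 为假设值):\n",
+ "\n",
+ "```python\n",
+ "N_FROZEN = 16 # 假设冻结前 16 个解码层(7B 共 32 层)\n",
+ "for name, param in model.named_parameters():\n",
+ " if name.startswith(\"model.embed_tokens.\"):\n",
+ " param.requires_grad = False # 冻结词嵌入\n",
+ " elif name.startswith(\"model.layers.\"):\n",
+ " layer_id = int(name.split(\".\")[2]) # 名称形如 model.layers.<i>.xxx\n",
+ " if layer_id < N_FROZEN:\n",
+ " param.requires_grad = False # 冻结低层\n",
+ "```\n",
+ "\n",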
210
+ "#### **(3)参数高效微调(PEFT)**\n",
211
+ "- 使用 PEFT 方法(如 LoRA、Adapter)进行预训练:\n",
212
+ " - **LoRA**:通过低秩矩阵分解,微调部分关键模块。\n",
213
+ " - **Adapter**:在 Transformer 层中插入小型适配模块。\n",
214
+ "- **优点**:显著减少需要更新的参数量。\n",
215
+ "\n",
216
+ "---\n",
217
+ "\n",
218
+ "### **5. 持续预训练的典型应用**\n",
219
+ "\n",
220
+ "1. **领域适配**\n",
221
+ " - **医学**:将预训练模型在 PubMed 或生物医学数据集上进行持续预训练。\n",
222
+ " - **法律**:使用法律文档进一步训练基础模型。\n",
223
+ " - **金融**:通过金融新闻、报告语料提升模型在金融领域的表现。\n",
224
+ "\n",
225
+ "2. **多语言扩展**\n",
226
+ " - 引入多语言语料,扩展模型的多语言能力。\n",
227
+ "\n",
228
+ "3. **数据更新**\n",
229
+ " - 持续加入新数据(如时事新闻)以适配最新语言模式。\n",
230
+ "\n",
231
+ "4. **特殊任务优化**\n",
232
+ " - 针对特定任务(如代码生成、对话)引入专用数据进行训练。\n",
233
+ "\n",
234
+ "---\n",
235
+ "\n",
236
+ "### **6. 实现持续预训练的代码示例**\n",
237
+ "\n",
238
+ "以下示例基于 Hugging Face 实现 GPT-2 的持续预训练:\n",
239
+ "\n",
240
+ "```python\n",
241
+ "from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments\n",
242
+ "from datasets import load_dataset\n",
243
+ "\n",
244
+ "# 1. 加载预训练模型和分词器\n",
245
+ "model_name = \"gpt2\"\n",
246
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
247
+ "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
248
+ "\n",
249
+ "# 2. 加载新语料数据\n",
250
+ "dataset = load_dataset(\"text\", data_files={\"train\": \"domain_corpus.txt\"})\n",
251
+ "\n",
252
+ "# 3. 数据预处理\n",
253
+ "def tokenize_function(examples):\n",
254
+ " return tokenizer(examples[\"text\"], truncation=True, max_length=1024, padding=\"max_length\")\n",
255
+ "\n",
256
+ "tokenized_dataset = dataset.map(tokenize_function, batched=True)\n",
257
+ "\n",
258
+ "# 4. 设置训练参数\n",
259
+ "training_args = TrainingArguments(\n",
260
+ " output_dir=\"./gpt2_domain_adapted\",\n",
261
+ " overwrite_output_dir=True,\n",
262
+ " per_device_train_batch_size=4,\n",
263
+ " num_train_epochs=3,\n",
264
+ " learning_rate=5e-5,\n",
265
+ " save_steps=500,\n",
266
+ " save_total_limit=2,\n",
267
+ " logging_dir=\"./logs\",\n",
268
+ " evaluation_strategy=\"no\", # 评估策略可以根据需要调整\n",
269
+ " fp16=True, # 混合精度训练\n",
270
+ ")\n",
271
+ "\n",
272
+ "# 5. 定义 Trainer 并启动训练\n",
273
+ "trainer = Trainer(\n",
274
+ " model=model,\n",
275
+ " args=training_args,\n",
276
+ " train_dataset=tokenized_dataset[\"train\"],\n",
277
+ " tokenizer=tokenizer,\n",
278
+ ")\n",
279
+ "\n",
280
+ "trainer.train()\n",
281
+ "\n",
282
+ "# 6. 保存模型\n",
283
+ "model.save_pretrained(\"./gpt2_domain_adapted\")\n",
284
+ "tokenizer.save_pretrained(\"./gpt2_domain_adapted\")\n",
285
+ "```\n",
286
+ "\n",
287
+ "---\n",
288
+ "\n",
289
+ "### **7. 持续预训练的挑战**\n",
290
+ "\n",
291
+ "1. **灾难性遗忘**:\n",
292
+ " - 持续预训练可能导致模型丧失之前学到的知识。\n",
293
+ " - **解决方法**:使用少量原始数据进行联合训练(混合方式见本列表后的示意)。\n",
294
+ "\n",
295
+ "2. **计算资源需求**:\n",
296
+ " - 需要大量显存和算力,特别是对于大规模模型和数据。\n",
297
+ "\n",
298
+ "3. **数据质量和多样性**:\n",
299
+ " - 新引入的数据可能包含噪声,影响模型性能。\n",
300
+ "\n",
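+ "缓解遗忘的数据混合最小示意(文件名取自本章 `train_data/`,混合比例为假设值):\n",
+ "\n",
+ "```python\n",
+ "from datasets import load_dataset, concatenate_datasets\n",
+ "\n",
+ "domain = load_dataset(\"text\", data_files={\"train\": \"train_data/dna_1g.txt\"})[\"train\"]\n",
+ "general = load_dataset(\"text\", data_files={\"train\": \"train_data/english_500m.txt\"})[\"train\"]\n",
+ "\n",
+ "# 假设按约 9:1 混合领域语料与通用英文语料\n",
+ "n_general = min(len(general), len(domain) // 9)\n",
+ "mixed = concatenate_datasets([\n",
+ " domain,\n",
+ " general.shuffle(seed=42).select(range(n_general)),\n",
+ "]).shuffle(seed=42)\n",
+ "```\n",
+ "\n",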
301
+ "---\n",
302
+ "\n",
303
+ "### **8. 持续预训练的优势**\n",
304
+ "\n",
305
+ "- 提高特定领域或任务的性能。\n",
306
+ "- 更高效地利用已有模型权重,避免从头训练。\n",
307
+ "- 保留原始模型的通用能力,同时增强领域适应性。\n",
308
+ "\n",
309
+ "---\n",
310
+ "\n",
311
+ "### **总结**\n",
312
+ "\n",
313
+ "持续预训练是适配领域任务和提升模型性能的重要方法,通过引入新数据或优化模型训练策略,可以让大模型在特定场景中表现更优。配合参数高效微调方法(如 LoRA),还可显著降低计算开销,提升训练效率。这种技术在学术研究、工业应用和前沿领域(如法律、医学等)中均具有广泛价值。"
314
+ ]
315
+ },
316
+ {
317
+ "cell_type": "code",
318
+ "execution_count": null,
319
+ "id": "ca41ad33-18fb-44da-8f79-0380b5c9dcaa",
320
+ "metadata": {},
321
+ "outputs": [],
322
+ "source": []
323
+ },
324
+ {
325
+ "cell_type": "markdown",
326
+ "id": "3038550c-cc92-45c9-8bb4-46c58688bfc5",
327
+ "metadata": {},
328
+ "source": [
329
+ "## 本节任务\n",
330
+ "本节任务是基于 LLaMA 的持续预训练:训练一个能够同时处理 DNA 和蛋白质(protein)序列的基础大模型。训练数据为第一章中的预训练数据,并包含英文数据(用于缓解灾难性遗忘)。"
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "markdown",
335
+ "id": "b1bd33b8-2e05-4b59-9d8f-c48de194cfd6",
336
+ "metadata": {},
337
+ "source": [
338
+ "## 代码运行\n",
339
+ "\n",
340
+ "```\n",
341
+ "# 复制第一章训练数据,包括dna,protein,还有英文数据,添加英文数据是为了避免遗忘问题\n",
342
+ "mkdir train_data\n",
343
+ "cp ../01-data_env/data/*.txt train_data/\n",
344
+ "\n",
345
+ "#持续预训练\n",
346
+ "./run_pt.sh\n",
347
+ "\n",
348
+ "#合并模型\n",
349
+ "./merge_pt_model.sh\n",
350
+ "\n",
351
+ "```"
352
+ ]
353
+ },
354
+ {
355
+ "cell_type": "markdown",
356
+ "id": "4960a36c-7529-4db8-b91d-df91245f79d9",
357
+ "metadata": {},
358
+ "source": [
359
+ "## 模型验证"
360
+ ]
361
+ },
362
+ {
363
+ "cell_type": "code",
364
+ "execution_count": null,
365
+ "id": "69b3e97f-a801-4264-a651-a854bcfba9c6",
366
+ "metadata": {},
367
+ "outputs": [],
368
+ "source": [
369
+ "from transformers import (\n",
+ " AutoConfig,\n",
+ " AutoModel,\n",
+ " AutoModelForCausalLM,\n",
+ " AutoTokenizer,\n",
+ " DataCollatorForLanguageModeling,\n",
+ " LlamaForCausalLM,\n",
+ " LlamaTokenizer,\n",
+ " Trainer,\n",
+ " TrainingArguments,\n",
+ ")\n",
373
+ "from tokenizers import Tokenizer\n",
374
+ "from datasets import load_dataset"
375
+ ]
376
+ },
377
+ {
378
+ "cell_type": "code",
379
+ "execution_count": null,
380
+ "id": "339435d9-9379-4b30-ae8b-50feee1ba714",
381
+ "metadata": {},
382
+ "outputs": [],
383
+ "source": [
384
+ "tokenizer = LlamaTokenizer.from_pretrained(\"dnahlm-merge-hf\")\n",
385
+ "tokenizer.pad_token = tokenizer.eos_token\n",
386
+ "tokenizer"
387
+ ]
388
+ },
389
+ {
390
+ "cell_type": "code",
391
+ "execution_count": null,
392
+ "id": "d0f154bb-b1ab-4611-a14c-9b403043fd96",
393
+ "metadata": {},
394
+ "outputs": [],
395
+ "source": [
396
+ "model = LlamaForCausalLM.from_pretrained(\"dnahlm-merge-hf\") #continue pretrain\n",
397
+ "model"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "execution_count": null,
403
+ "id": "792a9f78-1828-4695-9f6e-479a704ea7e8",
404
+ "metadata": {},
405
+ "outputs": [],
406
+ "source": [
407
+ "from transformers import AutoConfig\n",
408
+ "# 加载配置\n",
409
+ "config = AutoConfig.from_pretrained('dnahlm-merge-hf')\n",
410
+ "config"
411
+ ]
412
+ },
413
+ {
414
+ "cell_type": "code",
415
+ "execution_count": null,
416
+ "id": "49021c65-54bb-4a97-a96d-b030cc3dcd13",
417
+ "metadata": {},
418
+ "outputs": [],
419
+ "source": [
420
+ "text='''GCTGACTCTGCCAGGATGGAATGAAATTAGGTTGTTTTAATTATAATGTAAAGTCAGTTCTAGTCAGACATAGTCACATAGGCAAGTAAGGGAACCTAAAATTGCTTGGAAT,\n",
421
+ "KCGFVGPMVHLKVHLEADVASSCRSAVIYLTSEEPFEGVLGLRLKEGIAITGCWPRWPDEMDERSAVWRVEPYTRHFGRVLYSFGV,\n",
422
+ "The primary use of LLaMA is research on large language models, including'''\n",
423
+ "print(\"Test text:\\n\",text)\n",
424
+ "print(f\"Tokenized by DNA-LLaMA tokenizer:{tokenizer.tokenize(text)}\")"
425
+ ]
426
+ },
427
+ {
428
+ "cell_type": "code",
429
+ "execution_count": null,
430
+ "id": "ebf869c8-866d-4770-8f64-79d671f88663",
431
+ "metadata": {},
432
+ "outputs": [],
433
+ "source": [
434
+ "import torch\n",
435
+ "from transformers import pipeline\n",
436
+ "\n",
437
+ "model_id = \"dnahlm-merge-hf\"\n",
438
+ "\n",
439
+ "pipe = pipeline(\n",
440
+ " \"text-generation\", \n",
441
+ " model=model_id, \n",
442
+ " #torch_dtype=torch.bfloat16, \n",
443
+ " device_map=\"auto\",\n",
444
+ ")\n",
445
+ "\n",
446
+ "pipe(\"The key to life is\")"
447
+ ]
448
+ },
449
+ {
450
+ "cell_type": "code",
451
+ "execution_count": null,
452
+ "id": "40a22c70-f1c4-4cd5-a118-2f5db40790e6",
453
+ "metadata": {},
454
+ "outputs": [],
455
+ "source": [
456
+ "pipe(\"GGAATGAAATTAGGTTGTTTTAATTATAATGTAAAGTCAGTTCT\")"
457
+ ]
458
+ },
459
+ {
460
+ "cell_type": "code",
461
+ "execution_count": null,
462
+ "id": "aec95d0a-4269-4540-bf14-4ce157b9a194",
463
+ "metadata": {},
464
+ "outputs": [],
465
+ "source": [
466
+ "pipe(\"KCGFVGPMVHLKVHLEADVASSCRSAVIYLTSEEPFEGVLGLRLKEGIAITGCWPRWPDEMDERSAVWRVEPYTRHFGRVLYSFGV\")"
467
+ ]
468
+ }
469
+ ],
470
+ "metadata": {
471
+ "kernelspec": {
472
+ "display_name": "Python 3 (ipykernel)",
473
+ "language": "python",
474
+ "name": "python3"
475
+ },
476
+ "language_info": {
477
+ "codemirror_mode": {
478
+ "name": "ipython",
479
+ "version": 3
480
+ },
481
+ "file_extension": ".py",
482
+ "mimetype": "text/x-python",
483
+ "name": "python",
484
+ "nbconvert_exporter": "python",
485
+ "pygments_lexer": "ipython3",
486
+ "version": "3.12.3"
487
+ }
488
+ },
489
+ "nbformat": 4,
490
+ "nbformat_minor": 5
491
+ }
04-gene-sft/7-llama-instruction-ft.ipynb ADDED
@@ -0,0 +1,624 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "963e9ae0-ac68-44be-8c7d-fb9842784362",
6
+ "metadata": {},
7
+ "source": [
8
+ "# 4.7 Instruction Fine-Tuning of the LLaMA-Based Gene Model"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "c844103d-4e27-41b9-9bf1-c6a577846ab6",
14
+ "metadata": {},
15
+ "source": [
16
+ "### **Instruction Fine-Tuning of Large Language Models**\n",
17
+ "\n",
18
+ "Instruction fine-tuning means further training a large language model (such as GPT, T5, or LLaMA) so that it better understands and executes tasks that humans express as instructions. It is a key technique for adapting large models to real applications and strengthening user interaction.\n",
19
+ "\n",
20
+ "---\n",
21
+ "\n",
22
+ "### **1. Core Concepts**\n",
23
+ "\n",
24
+ "The goal is to fine-tune on a dedicated instruction dataset so that the model can:\n",
25
+ "1. Understand the user's task, expressed as a natural-language instruction.\n",
26
+ "2. Generate high-quality responses that match the instruction.\n",
27
+ "3. Adapt to multi-task settings, reducing the need for separate task-specific training.\n",
28
+ "\n",
29
+ "---\n",
30
+ "\n",
31
+ "### **2. Key Characteristics**\n",
32
+ "\n",
33
+ "1. **Unified multi-task handling**:\n",
34
+ "   - No separate fine-tuning per task; one instruction-tuned model adapts to many tasks.\n",
35
+ "   \n",
36
+ "2. **Natural-language interaction**:\n",
37
+ "   - Users interact through natural-language instructions, with no special input format required.\n",
38
+ "\n",
39
+ "3. **Generalization**:\n",
40
+ "   - The tuned model can produce reasonable inferences and responses even for unseen tasks.\n",
41
+ "\n",
42
+ "---\n",
43
+ "\n",
44
+ "### **3. Building and Using Datasets**\n",
45
+ "\n",
46
+ "#### **(1) Structure of an instruction dataset**\n",
47
+ "- Each record usually has three parts:\n",
48
+ "  1. **Instruction**: the task description or question, e.g. \"Translate the following text to French\".\n",
49
+ "  2. **Input**: task-related context or data; may be empty.\n",
50
+ "  3. **Output**: the response the model is expected to generate.\n",
51
+ "\n",
52
+ "#### **(2) Common instruction datasets**\n",
53
+ "- **FLAN**: an instruction dataset covering many NLP tasks, used to fine-tune models such as T5.\n",
54
+ "- **OpenAI instruction data**: e.g. the tuning datasets behind ChatGPT in the GPT series.\n",
55
+ "- **InstructGPT data**: human-annotated multi-task instruction data used for model optimization.\n",
56
+ "- **Self-Instruct**: instructions and answers generated by the model itself to further expand the training data.\n",
57
+ "\n",
58
+ "#### **(3) Building your own dataset**\n",
59
+ "- For domain-specific instruction tuning you can build a dataset yourself:\n",
60
+ "  - Collect task requirements and examples.\n",
61
+ "  - Design diverse instructions.\n",
62
+ "  - Produce high-quality answers via expert annotation or model assistance.\n",
63
+ "\n",
64
+ "---\n",
65
+ "\n",
66
+ "### **4. Fine-Tuning Steps**\n",
67
+ "\n",
68
+ "#### **(1) Load the base model**\n",
69
+ "Load a pretrained large language model such as GPT-2, T5, or LLaMA from Hugging Face or another framework.\n",
70
+ "\n",
71
+ "#### **(2) Prepare the dataset**\n",
72
+ "Format each record of the instruction dataset as:\n",
73
+ "```python\n",
74
+ "{\n",
75
+ "    \"instruction\": \"Translate the following text to French\",\n",
76
+ "    \"input\": \"Hello, how are you?\",\n",
77
+ "    \"output\": \"Bonjour, comment ça va?\"\n",
78
+ "}\n",
79
+ "```\n",
80
+ "\n",
81
+ "#### **(3) Define the fine-tuning procedure**\n",
82
+ "Fine-tune with `Trainer` or a distributed framework such as DeepSpeed or Accelerate.\n",
83
+ "\n",
84
+ "---\n",
85
+ "\n",
86
+ "### **5. Example Code: An Instruction Fine-Tuning Run**\n",
87
+ "\n",
88
+ "An instruction fine-tuning example based on Hugging Face:\n",
89
+ "\n",
90
+ "```python\n",
91
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer\n",
92
+ "from datasets import load_dataset\n",
93
+ "\n",
94
+ "# 1. Load the pretrained model and tokenizer\n",
95
+ "model_name = \"gpt2\"\n",
96
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
97
+ "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
98
+ "\n",
99
+ "# 2. Load the instruction dataset\n",
100
+ "# Each record should contain instruction, input, and output fields\n",
101
+ "dataset = load_dataset(\"path/to/instruction_dataset\")\n",
102
+ "\n",
103
+ "# 3. Preprocess the data\n",
104
+ "def preprocess_function(example):\n",
105
+ "    # Concatenate instruction and input into a full prompt\n",
106
+ "    prompt = example[\"instruction\"]\n",
107
+ "    if example[\"input\"]:\n",
108
+ "        prompt += f\"\\n{example['input']}\"\n",
109
+ "    labels = example[\"output\"]\n",
110
+ "    tokenized = tokenizer(prompt, truncation=True, max_length=512, padding=\"max_length\")\n",
111
+ "    with tokenizer.as_target_tokenizer():\n",
112
+ "        tokenized_labels = tokenizer(labels, truncation=True, max_length=512, padding=\"max_length\")\n",
113
+ "    tokenized[\"labels\"] = tokenized_labels[\"input_ids\"]\n",
114
+ "    return tokenized\n",
115
+ "\n",
116
+ "tokenized_datasets = dataset.map(preprocess_function, batched=True)\n",
117
+ "\n",
118
+ "# 4. Set the training arguments\n",
119
+ "training_args = TrainingArguments(\n",
120
+ "    output_dir=\"./instruction_finetuned_model\",\n",
121
+ "    per_device_train_batch_size=4,\n",
122
+ "    num_train_epochs=3,\n",
123
+ "    evaluation_strategy=\"epoch\",\n",
124
+ "    save_strategy=\"epoch\",\n",
125
+ "    learning_rate=5e-5,\n",
126
+ "    weight_decay=0.01,\n",
127
+ "    logging_dir=\"./logs\",\n",
128
+ "    fp16=True,\n",
129
+ ")\n",
130
+ "\n",
131
+ "# 5. Define the Trainer\n",
132
+ "trainer = Trainer(\n",
133
+ "    model=model,\n",
134
+ "    args=training_args,\n",
135
+ "    train_dataset=tokenized_datasets[\"train\"],\n",
136
+ "    eval_dataset=tokenized_datasets[\"test\"],\n",
137
+ "    tokenizer=tokenizer,\n",
138
+ ")\n",
139
+ "\n",
140
+ "# 6. Train\n",
141
+ "trainer.train()\n",
142
+ "\n",
143
+ "# 7. Save the model\n",
144
+ "model.save_pretrained(\"./instruction_finetuned_model\")\n",
145
+ "tokenizer.save_pretrained(\"./instruction_finetuned_model\")\n",
146
+ "```\n",
147
+ "\n",
148
+ "---\n",
149
+ "\n",
150
+ "### **6. Challenges**\n",
151
+ "\n",
152
+ "1. **Data quality**:\n",
153
+ "   - Low-quality or noisy data can make the model's output deviate from the instructions.\n",
154
+ "\n",
155
+ "2. **Instruction coverage**:\n",
156
+ "   - Too few instruction types in the dataset limits the model's generalization.\n",
157
+ "\n",
158
+ "3. **Compute requirements**:\n",
159
+ "   - Fine-tuning large models needs high-performance GPUs and large storage.\n",
160
+ "\n",
161
+ "4. **Catastrophic forgetting**:\n",
162
+ "   - Fine-tuning can erode part of the model's original capabilities.\n",
163
+ "\n",
164
+ "---\n",
165
+ "\n",
166
+ "### **7. Application Scenarios**\n",
167
+ "\n",
168
+ "1. **Multi-task question answering**:\n",
169
+ "   - One model supports translation, summarization, reasoning, and more.\n",
170
+ "\n",
171
+ "2. **Domain-specific optimization**:\n",
172
+ "   - Fine-tune on task instructions from specific domains such as law or medicine.\n",
173
+ "\n",
174
+ "3. **Better user interaction**:\n",
175
+ "   - Improve how well the model understands and responds to natural-language instructions.\n",
176
+ "\n",
177
+ "4. **Open-ended dialogue generation**:\n",
178
+ "   - Optimize performance in dialogue settings, e.g. the tuning behind ChatGPT.\n",
179
+ "\n",
180
+ "---\n",
181
+ "\n",
182
+ "### **Summary**\n",
183
+ "\n",
184
+ "Instruction fine-tuning further trains a large model on data in a specific instruction format so that it better understands and executes natural-language instructions. It suits multi-task scenarios and improves interactivity and domain adaptability; with high-quality instruction data and efficient fine-tuning techniques, a model's practical performance can improve markedly."
185
+ ]
186
+ },
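+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "causal-lm-label-sketch",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Note on the markdown example above: for decoder-only models (GPT-2, LLaMA) the more\n",
+ "# common recipe is to concatenate prompt and answer into ONE sequence and mask the\n",
+ "# prompt positions in the labels with -100. A minimal illustrative sketch:\n",
+ "def build_causal_lm_features(tokenizer, prompt, answer, max_length=512):\n",
+ "    prompt_ids = tokenizer(prompt, add_special_tokens=False)[\"input_ids\"]\n",
+ "    answer_ids = tokenizer(answer, add_special_tokens=False)[\"input_ids\"]\n",
+ "    input_ids = (prompt_ids + answer_ids)[:max_length]\n",
+ "    # -100 makes the cross-entropy loss ignore prompt tokens, so training\n",
+ "    # focuses on reproducing the answer\n",
+ "    labels = ([-100] * len(prompt_ids) + answer_ids)[:max_length]\n",
+ "    return {\"input_ids\": input_ids, \"labels\": labels}"
+ ]
+ },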
187
+ {
188
+ "cell_type": "markdown",
189
+ "id": "7be8b814-42f6-4fb6-bf4b-ae23292030f6",
190
+ "metadata": {},
191
+ "source": []
192
+ },
193
+ {
194
+ "cell_type": "markdown",
195
+ "id": "f9bed0ae-337d-49af-85f0-c8e6263d78db",
196
+ "metadata": {},
197
+ "source": [
198
+ "**Continued pretraining** and **instruction fine-tuning** are two follow-up optimization strategies for large models. Both aim to improve performance, but they differ clearly in application scenarios, methods, and effects. A comparison:\n",
199
+ "\n",
200
+ "---\n",
201
+ "\n",
202
+ "### **1. Concept and Goal**\n",
203
+ "\n",
204
+ "| **Aspect** | **Continued Pretraining** | **Instruction Fine-Tuning** |\n",
205
+ "|------------|---------------------------|-----------------------------|\n",
206
+ "| **Definition** | Further pretraining of a general model on new large-scale corpora (general or domain-specific). | Fine-tuning on instruction-task datasets to improve the model's understanding and execution of human instructions. |\n",
207
+ "| **Goal** | Improve general capability, or adapt language understanding and generation to a specific domain. | Improve multi-task instruction generalization so the model better executes concrete tasks expressed in natural language. |\n",
208
+ "| **Typical uses** | Domain adaptation (medicine, law, finance), performance optimization, cross-lingual adaptation. | Multi-task QA, open-ended dialogue, translation, reasoning, and other scenarios with direct user interaction. |\n",
209
+ "\n",
210
+ "---\n",
211
+ "\n",
212
+ "### **2. Data Usage**\n",
213
+ "\n",
214
+ "| **Aspect** | **Continued Pretraining** | **Instruction Fine-Tuning** |\n",
215
+ "|------------|---------------------------|-----------------------------|\n",
216
+ "| **Data type** | General corpora (news, social media text) or domain corpora (PubMed, legal documents, financial reports). | Instruction-task datasets with Instruction, Input, and Output fields. |\n",
217
+ "| **Data construction** | Large corpora usually need cleaning and deduplication to avoid overlap with the original pretraining data. | Usually human-annotated or model-generated instruction data, e.g. the FLAN and InstructGPT datasets. |\n",
218
+ "| **Diversity requirements** | Cover as many domains, or as many scenarios within the target domain, as possible. | Cover many task types (translation, classification, summarization) and varied instruction phrasings to improve multi-task adaptability. |\n",
219
+ "\n",
220
+ "---\n",
221
+ "\n",
222
+ "### **3. Methods and Techniques**\n",
223
+ "\n",
224
+ "| **Aspect** | **Continued Pretraining** | **Instruction Fine-Tuning** |\n",
225
+ "|------------|---------------------------|-----------------------------|\n",
226
+ "| **Main technique** | Keep training with self-supervised objectives (language modeling, masked prediction). | Supervised learning on task input/output pairs, fine-tuned to specific task needs. |\n",
227
+ "| **Model updates** | - Full-parameter updates, or partially frozen parameters.<br>- Can combine with parameter-efficient methods (LoRA, Adapter). | - Usually supervised training, possibly combined with parameter-efficient methods (LoRA). |\n",
228
+ "| **Learning rate** | Usually small (e.g. `1e-5` or lower) to avoid damaging the original weights. | Also small, though instruction tuning may need extra care in aligning task-specific labels. |\n",
229
+ "\n",
230
+ "---\n",
231
+ "\n",
232
+ "### **4. Capabilities and Effects**\n",
233
+ "\n",
234
+ "| **Aspect** | **Continued Pretraining** | **Instruction Fine-Tuning** |\n",
235
+ "|------------|---------------------------|-----------------------------|\n",
236
+ "| **Gains** | - Markedly better fit to domain-specific language patterns and knowledge.<br>- Stronger generation for unseen general scenarios (broader knowledge). | - Much better instruction understanding, especially for tasks stated in natural language.<br>- Large gains in multi-task and zero-shot generalization. |\n",
237
+ "| **Limitations** | - Weak direct adaptation to concrete tasks; extra task fine-tuning may be needed.<br>- Poor data selection can cause catastrophic forgetting. | - Depends on high-quality instruction data; low quality makes outputs unstable.<br>- Limited gains in general capability. |\n",
238
+ "\n",
239
+ "---\n",
240
+ "\n",
241
+ "### **5. Scenarios and Examples**\n",
242
+ "\n",
243
+ "| **Aspect** | **Continued Pretraining** | **Instruction Fine-Tuning** |\n",
244
+ "|------------|---------------------------|-----------------------------|\n",
245
+ "| **Typical scenarios** | - Medical literature summarization (continued pretraining on PubMed).<br>- Legal text analysis (further training on legal documents).<br>- Stronger multilingual generation (cross-lingual corpora). | - ChatGPT-style multi-task dialogue generation.<br>- Generalized handling of user-facing tasks such as translation, summarization, and QA. |\n",
246
+ "| **Real examples** | - BioBERT: BERT continued-pretrained on biomedical corpora.<br>- FinBERT: a language model continued-pretrained for the financial domain. | - InstructGPT: GPT-3 instruction-tuned for multi-task user interaction.<br>- FLAN-T5: instruction-tuned on the FLAN dataset. |\n",
247
+ "\n",
248
+ "---\n",
249
+ "\n",
250
+ "### **6. Combining Continued Pretraining and Instruction Fine-Tuning**\n",
251
+ "\n",
252
+ "The two can be combined into a complete pipeline from domain adaptation to task adaptation:\n",
253
+ "1. **Continued pretraining**:\n",
254
+ "   - First continue pretraining on domain data (medical, legal, financial corpora) to acquire domain knowledge.\n",
255
+ "2. **Instruction fine-tuning**:\n",
256
+ "   - Then fine-tune on a multi-task instruction dataset so the model can efficiently execute diverse in-domain tasks.\n",
257
+ "\n",
258
+ "This combination suits scenarios that need both domain knowledge and task adaptation, such as medical QA systems or financial text analysis.\n",
259
+ "\n",
260
+ "---\n",
261
+ "\n",
262
+ "### **Summary**\n",
263
+ "\n",
264
+ "| **Dimension** | **Continued Pretraining** | **Instruction Fine-Tuning** |\n",
265
+ "|---------------|---------------------------|-----------------------------|\n",
266
+ "| **Goal** | Strengthen general capability or adapt to a domain. | Improve instruction understanding and execution. |\n",
267
+ "| **Dataset** | General or domain corpora. | Instruction datasets with input/output pairs. |\n",
268
+ "| **Method** | Self-supervised learning; extends language-modeling ability. | Supervised learning; strengthens task adaptation. |\n",
269
+ "| **Scenarios** | Domain-specific tasks (medicine, law). | Multi-task interaction (QA, dialogue generation). |\n",
270
+ "| **Limitations** | Weak direct task adaptation. | Limited general-capability gains; depends on data quality. |\n",
271
+ "\n",
272
+ "Each has its own emphasis, and in many scenarios the two can be combined into a strong framework for task and domain adaptation."
273
+ ]
274
+ },
275
+ {
276
+ "cell_type": "markdown",
277
+ "id": "f97a705a-b946-4dc1-a173-a9df033d6f2b",
278
+ "metadata": {},
279
+ "source": [
280
+ "## Task for This Section\n",
281
+ "The task here builds on the LLaMA-based biological model continued-pretrained in the previous section: we fine-tune it on a set of biology tasks, covering several kinds of classification problems and multi-sequence tasks. See the data under sft_data for details (an illustrative record is shown below)."
282
+ ]
283
+ },
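+ {
+ "cell_type": "markdown",
+ "id": "sft-record-example-sketch",
+ "metadata": {},
+ "source": [
+ "Each record in sft_data follows the Alpaca-style instruction/input/output layout used throughout this notebook. A representative record (matching the validation data used later) looks like:\n",
+ "\n",
+ "```python\n",
+ "{\n",
+ "    \"instruction\": \"Determine core promoter detection of following dna sequence, The result will be one of the following: Non-promoter, promoter.\",\n",
+ "    \"input\": \"CCGTGCGACCGGAAGTGGGGCGGCGACCCCGGAAGTCCCCGCCGGGTGCAGCTTGGTCGGTTCGATCGCC\",\n",
+ "    \"output\": \"promoter\"\n",
+ "}\n",
+ "```"
+ ]
+ },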
284
+ {
285
+ "cell_type": "markdown",
286
+ "id": "9782db62-95bd-40a6-9759-966b9a0b362e",
287
+ "metadata": {},
288
+ "source": [
289
+ "## Running the Code\n",
290
+ "\n",
291
+ "```\n",
292
+ "\n",
293
+ "# run instruction fine-tuning\n",
294
+ "./run_sft.sh\n",
295
+ "\n",
296
+ "# merge the LoRA weights into the base model (conceptual sketch below)\n",
297
+ "./merge_sft_model.sh\n",
298
+ "\n",
299
+ "```"
300
+ ]
301
+ },
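+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "merge-concept-sketch",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Conceptual sketch of what the merge step does. The repo's actual logic lives in\n",
+ "# merge_llama_with_dna_lora.py (driven by merge_sft_model.sh); the paths below are\n",
+ "# placeholders, not real files from this project.\n",
+ "from peft import PeftModel\n",
+ "from transformers import LlamaForCausalLM\n",
+ "\n",
+ "base = LlamaForCausalLM.from_pretrained(\"path/to/base_model\")\n",
+ "merged = PeftModel.from_pretrained(base, \"path/to/lora_adapter\").merge_and_unload()\n",
+ "merged.save_pretrained(\"path/to/merged_model\")"
+ ]
+ },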
302
+ {
303
+ "cell_type": "markdown",
304
+ "id": "182b82c4-d484-4c15-a600-03c3b51367ec",
305
+ "metadata": {},
306
+ "source": [
307
+ "## Model Validation"
308
+ ]
309
+ },
310
+ {
311
+ "cell_type": "code",
312
+ "execution_count": 1,
313
+ "id": "5aa3d240-44e1-4811-8f61-d6ff2500a798",
314
+ "metadata": {},
315
+ "outputs": [],
316
+ "source": [
317
+ "import subprocess\n",
318
+ "import os\n",
319
+ "# Set proxy environment variables (for AutoDL standard-region instances)\n",
320
+ "result = subprocess.run('bash -c \"source /etc/network_turbo && env | grep proxy\"', shell=True, capture_output=True, text=True)\n",
321
+ "output = result.stdout\n",
322
+ "for line in output.splitlines():\n",
323
+ " if '=' in line:\n",
324
+ " var, value = line.split('=', 1)\n",
325
+ " os.environ[var] = value"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "markdown",
330
+ "id": "17bdb69d-3f0f-465e-bd60-2047a088e264",
331
+ "metadata": {},
332
+ "source": [
333
+ "If you are unsure which modules of the model can be fine-tuned, print the model structure; a short sketch after the model is loaded below also lists its Linear modules (the usual LoRA targets):"
334
+ ]
335
+ },
336
+ {
337
+ "cell_type": "code",
338
+ "execution_count": null,
339
+ "id": "054a2956-9045-4ad5-a878-1bfc84ad4ed8",
340
+ "metadata": {},
341
+ "outputs": [],
342
+ "source": [
343
+ "from transformers import AutoTokenizer, AutoConfig,AutoModel\n",
344
+ "from transformers import DataCollatorForLanguageModeling\n",
345
+ "from transformers import Trainer, TrainingArguments\n",
346
+ "from transformers import AutoConfig, AutoModelForCausalLM,LlamaForCausalLM,LlamaTokenizer\n",
347
+ "from tokenizers import Tokenizer\n",
348
+ "from datasets import load_dataset"
349
+ ]
350
+ },
351
+ {
352
+ "cell_type": "code",
353
+ "execution_count": null,
354
+ "id": "63c8bf16-9576-41bc-b27c-c92ba4289cf4",
355
+ "metadata": {},
356
+ "outputs": [],
357
+ "source": [
358
+ "from datasets import load_dataset\n",
359
+ "dna_ft_dataset = load_dataset('json', data_files='val_data.json')\n",
360
+ "dna_ft_dataset"
361
+ ]
362
+ },
363
+ {
364
+ "cell_type": "code",
365
+ "execution_count": null,
366
+ "id": "95928da3-ca64-4a17-80f4-945da395702c",
367
+ "metadata": {},
368
+ "outputs": [],
369
+ "source": [
370
+ "data = dna_ft_dataset[\"train\"].train_test_split(train_size=0.1, seed=42)\n",
371
+ "data"
372
+ ]
373
+ },
374
+ {
375
+ "cell_type": "code",
376
+ "execution_count": null,
377
+ "id": "a3e65bcd-85ce-4261-8ba6-7665c4ec60e2",
378
+ "metadata": {},
379
+ "outputs": [],
380
+ "source": [
381
+ "tokenizer = LlamaTokenizer.from_pretrained(\"dnahlm-llama-7b-sft-v0\") #dnagpt/dnahlm-llama-7b-sft-v0\n",
382
+ "tokenizer.pad_token = tokenizer.eos_token"
383
+ ]
384
+ },
385
+ {
386
+ "cell_type": "code",
387
+ "execution_count": null,
388
+ "id": "3d3fe49b-f48f-42b2-bc97-028e443111e4",
389
+ "metadata": {},
390
+ "outputs": [],
391
+ "source": [
392
+ "model = LlamaForCausalLM.from_pretrained(\"dnahlm-llama-7b-sft-v0\") #continue pretrain\n",
393
+ "model"
394
+ ]
395
+ },
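+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "list-linear-modules-sketch",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch: enumerate the nn.Linear submodules of the loaded model. For LLaMA-style\n",
+ "# models these names (q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj)\n",
+ "# are the usual choices for lora_target_modules.\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "linear_names = {name.split('.')[-1] for name, module in model.named_modules()\n",
+ "                if isinstance(module, nn.Linear)}\n",
+ "print(linear_names)"
+ ]
+ },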
396
+ {
397
+ "cell_type": "code",
398
+ "execution_count": null,
399
+ "id": "c54df9fe-86c4-4963-b313-b438894bf9dd",
400
+ "metadata": {},
401
+ "outputs": [],
402
+ "source": [
403
+ "# Build the instruction prompt (Alpaca style)\n",
404
+ "def format_input(entry):\n",
405
+ " instruction_text = (\n",
406
+ " f\"Below is an instruction that describes a task. \"\n",
407
+ " f\"Write a response that appropriately completes the request.\"\n",
408
+ " f\"\\n\\n### Instruction:\\n{entry['instruction']}\"\n",
409
+ " )\n",
410
+ "\n",
411
+ " input_text = f\"\\n\\n### Input:\\n{entry['input']}\" if entry[\"input\"] else \"\"\n",
412
+ "\n",
413
+ " return instruction_text + input_text + \"\\n\\n### Response:\\n\"\n",
414
+ "\n",
415
+ "# Build the full prompt: formatted input plus the desired response\n",
416
+ "def build_prompt(entry):\n",
417
+ "\n",
418
+ " input_data = format_input(entry)\n",
419
+ "\n",
420
+ " desired_response = entry['output']\n",
421
+ "\n",
422
+ " return input_data + desired_response\n"
423
+ ]
424
+ },
425
+ {
426
+ "cell_type": "code",
427
+ "execution_count": null,
428
+ "id": "ee540cfb-1f6e-4e02-a3bc-c814e43685cb",
429
+ "metadata": {},
430
+ "outputs": [],
431
+ "source": [
432
+ "example = data[\"test\"][0]\n",
433
+ "example"
434
+ ]
435
+ },
436
+ {
437
+ "cell_type": "code",
438
+ "execution_count": null,
439
+ "id": "7ee35528-7b3f-4e60-b88b-1bc3e950012b",
440
+ "metadata": {},
441
+ "outputs": [],
442
+ "source": [
443
+ "prompt = build_prompt(example)\n",
444
+ "print(prompt)"
445
+ ]
446
+ },
447
+ {
448
+ "cell_type": "code",
449
+ "execution_count": null,
450
+ "id": "8aa6f38f-3bcc-4566-8a66-a541db91e031",
451
+ "metadata": {},
452
+ "outputs": [],
453
+ "source": [
454
+ "tokenizer.tokenize(prompt)"
455
+ ]
456
+ },
457
+ {
458
+ "cell_type": "code",
459
+ "execution_count": null,
460
+ "id": "11875339-4901-4912-86e5-afe8c74921d9",
461
+ "metadata": {},
462
+ "outputs": [],
463
+ "source": [
464
+ "def inference(text, model, tokenizer, max_input_tokens=1000, max_output_tokens=1000):\n",
465
+ " # Tokenize\n",
466
+ " input_ids = tokenizer.encode(\n",
467
+ " text,\n",
468
+ " return_tensors=\"pt\",\n",
469
+ " truncation=True,\n",
470
+ " max_length=max_input_tokens\n",
471
+ " # return_attention_mask=True,\n",
472
+ " )\n",
473
+ "\n",
474
+ " # Generate\n",
475
+ " device = model.device\n",
476
+ " generated_tokens_with_prompt = model.generate(\n",
477
+ " input_ids=input_ids.to(device),\n",
478
+ " #max_length=max_output_tokens,\n",
479
+ " max_new_tokens=8,\n",
480
+ "        do_sample=True,  # enable sampling so the temperature below takes effect (transformers otherwise warns and ignores it)\n",
+ "        temperature=0.01 # very low temperature: near-deterministic output\n",
481
+ " )\n",
482
+ "\n",
483
+ " # Decode\n",
484
+ " generated_text_with_prompt = tokenizer.decode(generated_tokens_with_prompt[0], skip_special_tokens=True)\n",
485
+ " generated_text_answer = generated_text_with_prompt[len(text):]\n",
486
+ "\n",
487
+ "\n",
488
+ " return generated_text_answer\n",
489
+ "\n",
490
+ "# Optional extra cleanup of tokenizer artifacts\n",
491
+ "def clean_generated_text(text):\n",
492
+ "    # Replace the BPE marker 'Ġ' with a real space\n",
493
+ " text = text.replace('Ġ', ' ')\n",
494
+ "    # Collapse redundant whitespace\n",
495
+ " text = ' '.join(text.split())\n",
496
+ " return text"
497
+ ]
498
+ },
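+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "clean-text-demo-sketch",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# clean_generated_text() above is never called in this notebook; a quick illustrative\n",
+ "# use on a made-up raw BPE string. Note that the 'Ġ' marker comes from GPT-2-style\n",
+ "# BPE; the LLaMA tokenizer used here marks spaces with '▁' instead.\n",
+ "clean_generated_text('ĠNon-prom oter')"
+ ]
+ },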
499
+ {
500
+ "cell_type": "code",
501
+ "execution_count": null,
502
+ "id": "1b02644a-8b24-45aa-b22d-0f7ce2270dd9",
503
+ "metadata": {},
504
+ "outputs": [],
505
+ "source": [
506
+ "input_text = format_input(data[\"test\"][0])\n",
507
+ "\n",
508
+ "print(\"input (test):\", input_text)\n",
509
+ "\n",
510
+ "print(\"real answer:\", data[\"test\"][0][\"output\"])\n",
511
+ "\n",
512
+ "print(\"--------------------------\\n\")\n",
513
+ "\n",
514
+ "print(\"model's answer: \\n\")\n",
515
+ "print(inference(input_text, model, tokenizer))"
516
+ ]
517
+ },
518
+ {
519
+ "cell_type": "code",
520
+ "execution_count": null,
521
+ "id": "e2df1569-7f70-46ee-b93f-cbd879e32e14",
522
+ "metadata": {},
523
+ "outputs": [],
524
+ "source": [
525
+ "test_data = data[\"test\"].shuffle(seed=199).select(range(100))\n",
526
+ "\n",
527
+ "data_list = []\n",
528
+ "\n",
529
+ "for entry in test_data:\n",
530
+ " input_text = format_input(entry)\n",
531
+ " #print(input_text)\n",
532
+ " response_text = inference(input_text, model, tokenizer)\n",
533
+ " #print(response_text)\n",
534
+ "    record = {  # use a new name so we don't shadow the 'data' DatasetDict\n",
535
+ " \"instruction\":entry[\"instruction\"],\n",
536
+ " \"input\":entry[\"input\"],\n",
537
+ " \"output\":entry[\"output\"],\n",
538
+ " \"model_response\":response_text\n",
539
+ " }\n",
540
+ "\n",
541
+ "    data_list.append(record)"
542
+ ]
543
+ },
544
+ {
545
+ "cell_type": "code",
546
+ "execution_count": null,
547
+ "id": "0c6e47cb-1b64-4690-a51d-f1816b82f15f",
548
+ "metadata": {},
549
+ "outputs": [],
550
+ "source": [
551
+ "import json\n",
552
+ "\n",
553
+ "# Output file path\n",
554
+ "output_file = 'llama-sft-2.json'\n",
555
+ "\n",
556
+ "# Dump the collected records to a JSON file\n",
557
+ "# test_data.to_json(output_file)\n",
558
+ "with open(output_file, \"w\") as file:\n",
559
+ " json.dump(data_list, file, indent=4) # \"indent\" for pretty-printing\n",
560
+ "\n"
561
+ ]
562
+ },
563
+ {
564
+ "cell_type": "code",
565
+ "execution_count": null,
566
+ "id": "68831e19-5a99-46d8-9f40-e8bf6957dbfc",
567
+ "metadata": {},
568
+ "outputs": [],
569
+ "source": [
570
+ "import json\n",
571
+ "from tqdm import tqdm\n",
572
+ "\n",
573
+ "\n",
574
+ "\n",
575
+ "with open(output_file, \"r\") as file:\n",
576
+ " test_data = json.load(file)\n",
577
+ "\n",
578
+ "all_num = len(test_data)\n",
579
+ "right_sum = 0\n",
580
+ "same_sum = 0\n",
581
+ "for item in test_data:\n",
582
+ " output = item[\"output\"]\n",
583
+ " #output = \" \".join(tokenizer.tokenize(output))\n",
584
+ " model_response = item[\"model_response\"]\n",
585
+ "\n",
586
+ " print(output,\"||||||||||||\", model_response)\n",
587
+ "\n",
588
+ "    if model_response == output:  # exact string match\n",
589
+ " same_sum = same_sum + 1\n",
590
+ " \n",
591
+ "    if output.find(\"Non\")==-1:  # expected label is a positive class (no \"Non\" prefix)\n",
592
+ "        if model_response.find(output)!=-1 and model_response.find(\"Non\")==-1:  # label found, no spurious \"Non\"\n",
593
+ " right_sum = right_sum + 1\n",
594
+ " else:\n",
595
+ "        if model_response.find(output)!=-1:  # negative label found in the response\n",
596
+ " right_sum = right_sum + 1\n",
597
+ "\n",
598
+ "\n",
599
+ "print(\"accuracy\", right_sum/all_num, \"exact match\", same_sum/all_num)\n"
600
+ ]
601
+ }
602
+ ],
603
+ "metadata": {
604
+ "kernelspec": {
605
+ "display_name": "Python 3 (ipykernel)",
606
+ "language": "python",
607
+ "name": "python3"
608
+ },
609
+ "language_info": {
610
+ "codemirror_mode": {
611
+ "name": "ipython",
612
+ "version": 3
613
+ },
614
+ "file_extension": ".py",
615
+ "mimetype": "text/x-python",
616
+ "name": "python",
617
+ "nbconvert_exporter": "python",
618
+ "pygments_lexer": "ipython3",
619
+ "version": "3.12.3"
620
+ }
621
+ },
622
+ "nbformat": 4,
623
+ "nbformat_minor": 5
624
+ }
04-gene-sft/gene_bpe_seg.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:433c302f3c6642f4400e95e5143e08d3cf1a102fd34e4143b2c837550b13e8a6
3
+ size 1102702
04-gene-sft/gene_bpe_seg.vocab ADDED
The diff for this file is too large to render. See raw diff
 
04-gene-sft/img/.ipynb_checkpoints/sft-checkpoint.png ADDED
04-gene-sft/img/.ipynb_checkpoints/sft2-checkpoint.png ADDED
04-gene-sft/img/deepspeed.png ADDED
04-gene-sft/llama_sft_test.ipynb ADDED
@@ -0,0 +1,1627 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "81a2413e-8629-4016-aace-17d2f757f726",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "https://hf-mirror.com\n"
14
+ ]
15
+ }
16
+ ],
17
+ "source": [
18
+ "import os\n",
19
+ "\n",
20
+ "# Set the environment variable\n",
21
+ "os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\n",
22
+ "\n",
23
+ "# Print it back to confirm the setting took effect\n",
24
+ "print(os.environ.get('HF_ENDPOINT'))"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": 2,
30
+ "id": "89e2d33a-6d84-4ef3-b44e-daa57ac81e58",
31
+ "metadata": {},
32
+ "outputs": [
33
+ {
34
+ "name": "stderr",
35
+ "output_type": "stream",
36
+ "text": [
37
+ "2024-11-24 11:21:51.020375: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
38
+ "2024-11-24 11:21:51.036615: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
39
+ "2024-11-24 11:21:51.053557: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
40
+ "2024-11-24 11:21:51.058466: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
41
+ "2024-11-24 11:21:51.071840: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
42
+ "To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
43
+ "2024-11-24 11:21:51.923693: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
44
+ ]
45
+ }
46
+ ],
47
+ "source": [
48
+ "from transformers import AutoTokenizer, AutoConfig,AutoModel\n",
49
+ "from transformers import DataCollatorForLanguageModeling\n",
50
+ "from transformers import Trainer, TrainingArguments\n",
51
+ "from transformers import AutoConfig, AutoModelForCausalLM,LlamaForCausalLM,LlamaTokenizer\n",
52
+ "from tokenizers import Tokenizer\n",
53
+ "from datasets import load_dataset"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": 3,
59
+ "id": "68fc5c44-b444-402e-aaf2-0ba4e2000e42",
60
+ "metadata": {},
61
+ "outputs": [
62
+ {
63
+ "data": {
64
+ "text/plain": [
65
+ "DatasetDict({\n",
66
+ " train: Dataset({\n",
67
+ " features: ['instruction', 'input', 'output'],\n",
68
+ " num_rows: 19839\n",
69
+ " })\n",
70
+ "})"
71
+ ]
72
+ },
73
+ "execution_count": 3,
74
+ "metadata": {},
75
+ "output_type": "execute_result"
76
+ }
77
+ ],
78
+ "source": [
79
+ "from datasets import load_dataset\n",
80
+ "dna_ft_dataset = load_dataset('json', data_files='val_data.json')\n",
81
+ "dna_ft_dataset"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "code",
86
+ "execution_count": 4,
87
+ "id": "4ab4fd3e-5b59-470e-9b46-f0ffd7b9d1aa",
88
+ "metadata": {},
89
+ "outputs": [
90
+ {
91
+ "data": {
92
+ "text/plain": [
93
+ "DatasetDict({\n",
94
+ " train: Dataset({\n",
95
+ " features: ['instruction', 'input', 'output'],\n",
96
+ " num_rows: 1983\n",
97
+ " })\n",
98
+ " test: Dataset({\n",
99
+ " features: ['instruction', 'input', 'output'],\n",
100
+ " num_rows: 17856\n",
101
+ " })\n",
102
+ "})"
103
+ ]
104
+ },
105
+ "execution_count": 4,
106
+ "metadata": {},
107
+ "output_type": "execute_result"
108
+ }
109
+ ],
110
+ "source": [
111
+ "data = dna_ft_dataset[\"train\"].train_test_split(train_size=0.1, seed=42)\n",
112
+ "data"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": 5,
118
+ "id": "85ca97f5-6864-4d6f-944a-182ed1fa2f00",
119
+ "metadata": {},
120
+ "outputs": [],
121
+ "source": [
122
+ "tokenizer = LlamaTokenizer.from_pretrained(\"dnahlm-llama-7b-sft-v0\") #dnagpt/dnahlm-llama-7b-sft-v0\n",
123
+ "tokenizer.pad_token = tokenizer.eos_token"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": 6,
129
+ "id": "e904c0b2-bf21-4036-b510-8e57177c1767",
130
+ "metadata": {},
131
+ "outputs": [
132
+ {
133
+ "data": {
134
+ "application/vnd.jupyter.widget-view+json": {
135
+ "model_id": "99ce92d0373a498d929bed42f770ed16",
136
+ "version_major": 2,
137
+ "version_minor": 0
138
+ },
139
+ "text/plain": [
140
+ "Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]"
141
+ ]
142
+ },
143
+ "metadata": {},
144
+ "output_type": "display_data"
145
+ },
146
+ {
147
+ "data": {
148
+ "text/plain": [
149
+ "LlamaForCausalLM(\n",
150
+ " (model): LlamaModel(\n",
151
+ " (embed_tokens): Embedding(61973, 4096, padding_idx=0)\n",
152
+ " (layers): ModuleList(\n",
153
+ " (0-31): 32 x LlamaDecoderLayer(\n",
154
+ " (self_attn): LlamaSdpaAttention(\n",
155
+ " (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
156
+ " (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
157
+ " (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
158
+ " (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
159
+ " (rotary_emb): LlamaRotaryEmbedding()\n",
160
+ " )\n",
161
+ " (mlp): LlamaMLP(\n",
162
+ " (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
163
+ " (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
164
+ " (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n",
165
+ " (act_fn): SiLU()\n",
166
+ " )\n",
167
+ " (input_layernorm): LlamaRMSNorm((4096,), eps=1e-06)\n",
168
+ " (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-06)\n",
169
+ " )\n",
170
+ " )\n",
171
+ " (norm): LlamaRMSNorm((4096,), eps=1e-06)\n",
172
+ " (rotary_emb): LlamaRotaryEmbedding()\n",
173
+ " )\n",
174
+ " (lm_head): Linear(in_features=4096, out_features=61973, bias=False)\n",
175
+ ")"
176
+ ]
177
+ },
178
+ "execution_count": 6,
179
+ "metadata": {},
180
+ "output_type": "execute_result"
181
+ }
182
+ ],
183
+ "source": [
184
+ "model = LlamaForCausalLM.from_pretrained(\"dnahlm-llama-7b-sft-v0\") #continue pretrain\n",
185
+ "model"
186
+ ]
187
+ },
188
+ {
189
+ "cell_type": "code",
190
+ "execution_count": 7,
191
+ "id": "5b361c5c-c43f-4ed9-a5c7-c72403cd7a0a",
192
+ "metadata": {},
193
+ "outputs": [],
194
+ "source": [
195
+ "# Build the instruction prompt (Alpaca style)\n",
196
+ "def format_input(entry):\n",
197
+ " instruction_text = (\n",
198
+ " f\"Below is an instruction that describes a task. \"\n",
199
+ " f\"Write a response that appropriately completes the request.\"\n",
200
+ " f\"\\n\\n### Instruction:\\n{entry['instruction']}\"\n",
201
+ " )\n",
202
+ "\n",
203
+ " input_text = f\"\\n\\n### Input:\\n{entry['input']}\" if entry[\"input\"] else \"\"\n",
204
+ "\n",
205
+ " return instruction_text + input_text + \"\\n\\n### Response:\\n\"\n",
206
+ "\n",
207
+ "# Build the full prompt: formatted input plus the desired response\n",
208
+ "def build_prompt(entry):\n",
209
+ "\n",
210
+ " input_data = format_input(entry)\n",
211
+ "\n",
212
+ " desired_response = entry['output']\n",
213
+ "\n",
214
+ " return input_data + desired_response\n",
215
+ "\n"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "code",
220
+ "execution_count": 8,
221
+ "id": "ed031a26-d79e-4f50-85d1-169ebd409c6d",
222
+ "metadata": {},
223
+ "outputs": [
224
+ {
225
+ "data": {
226
+ "text/plain": [
227
+ "{'instruction': 'Determine core promoter detection of following dna sequence, The result will be one of the following: Non-promoter, promoter.',\n",
228
+ " 'input': 'CCGTGCGACCGGAAGTGGGGCGGCGACCCCGGAAGTCCCCGCCGGGTGCAGCTTGGTCGGTTCGATCGCC',\n",
229
+ " 'output': 'promoter'}"
230
+ ]
231
+ },
232
+ "execution_count": 8,
233
+ "metadata": {},
234
+ "output_type": "execute_result"
235
+ }
236
+ ],
237
+ "source": [
238
+ "example = data[\"test\"][0]\n",
239
+ "example"
240
+ ]
241
+ },
242
+ {
243
+ "cell_type": "code",
244
+ "execution_count": 9,
245
+ "id": "31bd4bb5-86a6-4046-b510-492b0548323b",
246
+ "metadata": {},
247
+ "outputs": [
248
+ {
249
+ "name": "stdout",
250
+ "output_type": "stream",
251
+ "text": [
252
+ "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
253
+ "\n",
254
+ "### Instruction:\n",
255
+ "Determine core promoter detection of following dna sequence, The result will be one of the following: Non-promoter, promoter.\n",
256
+ "\n",
257
+ "### Input:\n",
258
+ "CCGTGCGACCGGAAGTGGGGCGGCGACCCCGGAAGTCCCCGCCGGGTGCAGCTTGGTCGGTTCGATCGCC\n",
259
+ "\n",
260
+ "### Response:\n",
261
+ "promoter\n"
262
+ ]
263
+ }
264
+ ],
265
+ "source": [
266
+ "prompt = build_prompt(example)\n",
267
+ "print(prompt)"
268
+ ]
269
+ },
270
+ {
271
+ "cell_type": "code",
272
+ "execution_count": 10,
273
+ "id": "ed0b5b8b-916c-499b-a6da-f1124b9add1c",
274
+ "metadata": {
275
+ "scrolled": true
276
+ },
277
+ "outputs": [
278
+ {
279
+ "data": {
280
+ "text/plain": [
281
+ "['▁Below',\n",
282
+ " '▁is',\n",
283
+ " '▁an',\n",
284
+ " '▁instruction',\n",
285
+ " '▁that',\n",
286
+ " '▁describes',\n",
287
+ " '▁a',\n",
288
+ " '▁task',\n",
289
+ " '.',\n",
290
+ " '▁Write',\n",
291
+ " '▁a',\n",
292
+ " '▁response',\n",
293
+ " '▁that',\n",
294
+ " '▁appropri',\n",
295
+ " 'ately',\n",
296
+ " '▁comple',\n",
297
+ " 'tes',\n",
298
+ " '▁the',\n",
299
+ " '▁request',\n",
300
+ " '.',\n",
301
+ " '<0x0A>',\n",
302
+ " '<0x0A>',\n",
303
+ " '##',\n",
304
+ " '#',\n",
305
+ " '▁Inst',\n",
306
+ " 'ruction',\n",
307
+ " ':',\n",
308
+ " '<0x0A>',\n",
309
+ " 'Det',\n",
310
+ " 'erm',\n",
311
+ " 'ine',\n",
312
+ " '▁core',\n",
313
+ " '▁prom',\n",
314
+ " 'oter',\n",
315
+ " '▁detection',\n",
316
+ " '▁of',\n",
317
+ " '▁following',\n",
318
+ " '▁d',\n",
319
+ " 'na',\n",
320
+ " '▁sequence',\n",
321
+ " ',',\n",
322
+ " '▁The',\n",
323
+ " '▁result',\n",
324
+ " '▁will',\n",
325
+ " '▁be',\n",
326
+ " '▁one',\n",
327
+ " '▁of',\n",
328
+ " '▁the',\n",
329
+ " '▁following',\n",
330
+ " ':',\n",
331
+ " '▁Non',\n",
332
+ " '-',\n",
333
+ " 'prom',\n",
334
+ " 'oter',\n",
335
+ " ',',\n",
336
+ " '▁prom',\n",
337
+ " 'oter',\n",
338
+ " '.',\n",
339
+ " '<0x0A>',\n",
340
+ " '<0x0A>',\n",
341
+ " '##',\n",
342
+ " '#',\n",
343
+ " '▁Input',\n",
344
+ " ':',\n",
345
+ " '<0x0A>',\n",
346
+ " 'CCGTG',\n",
347
+ " 'C',\n",
348
+ " 'GAC',\n",
349
+ " 'CGGAA',\n",
350
+ " 'GTG',\n",
351
+ " 'GGGC',\n",
352
+ " 'GGC',\n",
353
+ " 'GAC',\n",
354
+ " 'CCCGGAA',\n",
355
+ " 'GTCC',\n",
356
+ " 'CCGCC',\n",
357
+ " 'GGGTG',\n",
358
+ " 'CA',\n",
359
+ " 'GCT',\n",
360
+ " 'TG',\n",
361
+ " 'GTC',\n",
362
+ " 'GGT',\n",
363
+ " 'TC',\n",
364
+ " 'GATCGCC',\n",
365
+ " '<0x0A>',\n",
366
+ " '<0x0A>',\n",
367
+ " '##',\n",
368
+ " '#',\n",
369
+ " '▁Response',\n",
370
+ " ':',\n",
371
+ " '<0x0A>',\n",
372
+ " 'prom',\n",
373
+ " 'oter']"
374
+ ]
375
+ },
376
+ "execution_count": 10,
377
+ "metadata": {},
378
+ "output_type": "execute_result"
379
+ }
380
+ ],
381
+ "source": [
382
+ "tokenizer.tokenize(prompt)"
383
+ ]
384
+ },
385
+ {
386
+ "cell_type": "code",
387
+ "execution_count": 11,
388
+ "id": "f0449aee-1ac6-4db5-873f-afdfb0fc9691",
389
+ "metadata": {},
390
+ "outputs": [],
391
+ "source": [
392
+ "def inference(text, model, tokenizer, max_input_tokens=1000, max_output_tokens=1000):\n",
393
+ " # Tokenize\n",
394
+ " input_ids = tokenizer.encode(\n",
395
+ " text,\n",
396
+ " return_tensors=\"pt\",\n",
397
+ " truncation=True,\n",
398
+ " max_length=max_input_tokens\n",
399
+ " # return_attention_mask=True,\n",
400
+ " )\n",
401
+ "\n",
402
+ " # Generate\n",
403
+ " device = model.device\n",
404
+ " generated_tokens_with_prompt = model.generate(\n",
405
+ " input_ids=input_ids.to(device),\n",
406
+ " #max_length=max_output_tokens,\n",
407
+ " max_new_tokens=8,\n",
408
+ "        temperature=0.01 # meant to control diversity; ignored here since do_sample is not set (see the warning in the output below)\n",
409
+ " )\n",
410
+ "\n",
411
+ " # Decode\n",
412
+ " generated_text_with_prompt = tokenizer.decode(generated_tokens_with_prompt[0], skip_special_tokens=True)\n",
413
+ " generated_text_answer = generated_text_with_prompt[len(text):]\n",
414
+ "\n",
415
+ "\n",
416
+ " return generated_text_answer\n",
417
+ "\n",
418
+ "# Optional extra cleanup of tokenizer artifacts\n",
419
+ "def clean_generated_text(text):\n",
420
+ "    # Replace the BPE marker 'Ġ' with a real space\n",
421
+ " text = text.replace('Ġ', ' ')\n",
422
+ "    # Collapse redundant whitespace\n",
423
+ " text = ' '.join(text.split())\n",
424
+ " return text"
425
+ ]
426
+ },
427
+ {
428
+ "cell_type": "code",
429
+ "execution_count": 12,
430
+ "id": "e9041426-eb59-4314-82dd-7b6d6d477783",
431
+ "metadata": {},
432
+ "outputs": [
433
+ {
434
+ "name": "stdout",
435
+ "output_type": "stream",
436
+ "text": [
437
+ "input (test): Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
438
+ "\n",
439
+ "### Instruction:\n",
440
+ "Determine core promoter detection of following dna sequence, The result will be one of the following: Non-promoter, promoter.\n",
441
+ "\n",
442
+ "### Input:\n",
443
+ "CCGTGCGACCGGAAGTGGGGCGGCGACCCCGGAAGTCCCCGCCGGGTGCAGCTTGGTCGGTTCGATCGCC\n",
444
+ "\n",
445
+ "### Response:\n",
446
+ "\n",
447
+ "real answer: promoter\n",
448
+ "--------------------------\n",
449
+ "\n",
450
+ "model's answer: \n",
451
+ "\n"
452
+ ]
453
+ },
454
+ {
455
+ "name": "stderr",
456
+ "output_type": "stream",
457
+ "text": [
458
+ "/root/miniconda3/lib/python3.12/site-packages/transformers/generation/configuration_utils.py:601: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.01` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
459
+ " warnings.warn(\n",
460
+ "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)\n"
461
+ ]
462
+ },
463
+ {
464
+ "name": "stdout",
465
+ "output_type": "stream",
466
+ "text": [
467
+ " Non-promoter\n"
468
+ ]
469
+ }
470
+ ],
471
+ "source": [
472
+ "input_text = format_input(data[\"test\"][0])\n",
473
+ "\n",
474
+ "print(\"input (test):\", input_text)\n",
475
+ "\n",
476
+ "print(\"real answer:\", data[\"test\"][0][\"output\"])\n",
477
+ "\n",
478
+ "print(\"--------------------------\\n\")\n",
479
+ "\n",
480
+ "print(\"model's answer: \\n\")\n",
481
+ "print(inference(input_text, model, tokenizer))"
482
+ ]
483
+ },
484
+ {
485
+ "cell_type": "code",
486
+ "execution_count": 13,
487
+ "id": "d1489173-84af-4c8e-b66b-0cdbe42c7ea7",
488
+ "metadata": {},
489
+ "outputs": [],
490
+ "source": [
491
+ "test_data = data[\"test\"].shuffle(seed=199).select(range(1000))\n",
492
+ "\n",
493
+ "data_list = []\n",
494
+ "\n",
495
+ "for entry in test_data:\n",
496
+ " input_text = format_input(entry)\n",
497
+ " #print(input_text)\n",
498
+ " response_text = inference(input_text, model, tokenizer)\n",
499
+ " #print(response_text)\n",
500
+ "    record = {  # use a new name so we don't shadow the 'data' DatasetDict\n",
501
+ " \"instruction\":entry[\"instruction\"],\n",
502
+ " \"input\":entry[\"input\"],\n",
503
+ " \"output\":entry[\"output\"],\n",
504
+ " \"model_response\":response_text\n",
505
+ " }\n",
506
+ "\n",
507
+ "    data_list.append(record)"
508
+ ]
509
+ },
510
+ {
511
+ "cell_type": "code",
512
+ "execution_count": 14,
513
+ "id": "39275fe6-ac3b-4558-9f4c-2853a41d48c4",
514
+ "metadata": {},
515
+ "outputs": [],
516
+ "source": [
517
+ "import json\n",
518
+ "\n",
519
+ "# Output file path\n",
520
+ "output_file = 'llama-sft-2.json'\n",
521
+ "\n",
522
+ "# Dump the collected records to a JSON file\n",
523
+ "# test_data.to_json(output_file)\n",
524
+ "with open(output_file, \"w\") as file:\n",
525
+ " json.dump(data_list, file, indent=4) # \"indent\" for pretty-printing\n",
526
+ "\n"
527
+ ]
528
+ },
529
+ {
530
+ "cell_type": "code",
531
+ "execution_count": 15,
532
+ "id": "7ffaba65-a270-4433-b234-932f5e288f7c",
533
+ "metadata": {},
534
+ "outputs": [
535
+ {
536
+ "data": {
537
+ "text/plain": [
538
+ "'▁prom oter'"
539
+ ]
540
+ },
541
+ "execution_count": 15,
542
+ "metadata": {},
543
+ "output_type": "execute_result"
544
+ }
545
+ ],
546
+ "source": [
547
+ "\" \".join(tokenizer.tokenize(\"promoter\"))"
548
+ ]
549
+ },
550
+ {
551
+ "cell_type": "code",
552
+ "execution_count": 16,
553
+ "id": "a7e373a4-6857-4874-b2da-58da2928925d",
554
+ "metadata": {},
555
+ "outputs": [
556
+ {
557
+ "name": "stdout",
558
+ "output_type": "stream",
559
+ "text": [
560
+ "Donor Sites |||||||||||| Donor Sites\n",
561
+ "promoter |||||||||||| promoter\n",
562
+ "promoter |||||||||||| promoter\n",
563
+ "promoter |||||||||||| promoter\n",
564
+ "promoter |||||||||||| promoter\n",
565
+ "Donor Sites |||||||||||| Non-Splice Sites\n",
566
+ "promoter |||||||||||| Non-promoter\n",
567
+ "promoter |||||||||||| Non-promoter\n",
568
+ "Non-promoter |||||||||||| promoter\n",
569
+ "Non-promoter |||||||||||| Non-promoter\n",
570
+ "Donor Sites |||||||||||| Donor Sites\n",
571
+ "Non-promoter |||||||||||| Non-promoter\n",
572
+ "Non-promoter |||||||||||| promoter\n",
573
+ "Non-promoter |||||||||||| promoter\n",
574
+ "promoter |||||||||||| Non-promoter\n",
575
+ "promoter |||||||||||| Non-promoter\n",
576
+ "Donor Sites |||||||||||| Donor Sites\n",
577
+ "Background Sequences |||||||||||| Background Sequences\n",
578
+ "Non-promoter |||||||||||| Non-promoter\n",
579
+ "Non-promoter |||||||||||| promoter\n",
580
+ "promoter |||||||||||| promoter\n",
581
+ "promoter |||||||||||| promoter\n",
582
+ "promoter |||||||||||| promoter\n",
583
+ "promoter |||||||||||| promoter\n",
584
+ "promoter |||||||||||| Non-promoter\n",
585
+ "promoter |||||||||||| promoter\n",
586
+ "Non-promoter |||||||||||| Non-promoter\n",
587
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
588
+ "Non-promoter |||||||||||| Non-promoter\n",
589
+ "promoter |||||||||||| promoter\n",
590
+ "Non-promoter |||||||||||| Non-promoter\n",
591
+ "Binding Sites |||||||||||| Background Sequences\n",
592
+ "Non-promoter |||||||||||| Non-promoter\n",
593
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
594
+ "Non-promoter |||||||||||| Non-promoter\n",
595
+ "Non-promoter |||||||||||| promoter\n",
596
+ "Non-promoter |||||||||||| Non-promoter\n",
597
+ "Donor Sites |||||||||||| Non-Splice Sites\n",
598
+ "Non-promoter |||||||||||| promoter\n",
599
+ "promoter |||||||||||| promoter\n",
600
+ "Background Sequences |||||||||||| Background Sequences\n",
601
+ "Non-promoter |||||||||||| Non-promoter\n",
602
+ "Binding Sites |||||||||||| Binding Sites\n",
603
+ "promoter |||||||||||| promoter\n",
604
+ "Non-promoter |||||||||||| Non-promoter\n",
605
+ "Non-promoter |||||||||||| Non-promoter\n",
606
+ "Non-promoter |||||||||||| Non-promoter\n",
607
+ "Non-promoter |||||||||||| Non-promoter\n",
608
+ "Donor Sites |||||||||||| Donor Sites\n",
609
+ "promoter |||||||||||| Non-promoter\n",
610
+ "promoter |||||||||||| Non-promoter\n",
611
+ "Non-promoter |||||||||||| promoter\n",
612
+ "Binding Sites |||||||||||| Binding Sites\n",
613
+ "promoter |||||||||||| Non-promoter\n",
614
+ "promoter |||||||||||| promoter\n",
615
+ "Background Sequences |||||||||||| Background Sequences\n",
616
+ "promoter |||||||||||| promoter\n",
617
+ "Non-promoter |||||||||||| Non-promoter\n",
618
+ "Background Sequences |||||||||||| Binding Sites\n",
619
+ "promoter |||||||||||| promoter\n",
620
+ "promoter |||||||||||| promoter\n",
621
+ "promoter |||||||||||| promoter\n",
622
+ "Donor Sites |||||||||||| Donor Sites\n",
623
+ "Binding Sites |||||||||||| Binding Sites\n",
624
+ "promoter |||||||||||| promoter\n",
625
+ "Donor Sites |||||||||||| Donor Sites\n",
626
+ "Non-promoter |||||||||||| Non-promoter\n",
627
+ "Binding Sites |||||||||||| Binding Sites\n",
628
+ "Donor Sites |||||||||||| Donor Sites\n",
629
+ "Non-promoter |||||||||||| promoter\n",
630
+ "Donor Sites |||||||||||| Donor Sites\n",
631
+ "Non-promoter |||||||||||| Non-promoter\n",
632
+ "promoter |||||||||||| promoter\n",
633
+ "promoter |||||||||||| promoter\n",
634
+ "promoter |||||||||||| promoter\n",
635
+ "Non-promoter |||||||||||| Non-promoter\n",
636
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
637
+ "promoter |||||||||||| Non-promoter\n",
638
+ "Donor Sites |||||||||||| Donor Sites\n",
639
+ "Donor Sites |||||||||||| Donor Sites\n",
640
+ "promoter |||||||||||| promoter\n",
641
+ "promoter |||||||||||| Non-promoter\n",
642
+ "promoter |||||||||||| promoter\n",
643
+ "Non-promoter |||||||||||| Non-promoter\n",
644
+ "Non-promoter |||||||||||| Non-promoter\n",
645
+ "promoter |||||||||||| promoter\n",
646
+ "Non-promoter |||||||||||| Non-promoter\n",
647
+ "promoter |||||||||||| promoter\n",
648
+ "Background Sequences |||||||||||| Binding Sites\n",
649
+ "Acceptor Sites |||||||||||| Donor Sites\n",
650
+ "Non-Splice Sites |||||||||||| Acceptor Sites\n",
651
+ "Donor Sites |||||||||||| Donor Sites\n",
652
+ "Donor Sites |||||||||||| Donor Sites\n",
653
+ "Non-promoter |||||||||||| promoter\n",
654
+ "promoter |||||||||||| Non-promoter\n",
655
+ "Background Sequences |||||||||||| Background Sequences\n",
656
+ "promoter |||||||||||| promoter\n",
657
+ "promoter |||||||||||| promoter\n",
658
+ "Acceptor Sites |||||||||||| Donor Sites\n",
659
+ "promoter |||||||||||| promoter\n",
660
+ "Donor Sites |||||||||||| Donor Sites\n",
661
+ "Binding Sites |||||||||||| Courses\n",
662
+ "promoter |||||||||||| promoter\n",
663
+ "Donor Sites |||||||||||| Donor Sites\n",
664
+ "Non-promoter |||||||||||| Non-promoter\n",
665
+ "Non-promoter |||||||||||| Non-promoter\n",
666
+ "Donor Sites |||||||||||| Donor Sites\n",
667
+ "Donor Sites |||||||||||| Donor Sites\n",
668
+ "Non-promoter |||||||||||| Non-promoter\n",
669
+ "Binding Sites |||||||||||| Binding Sites\n",
670
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
671
+ "Binding Sites |||||||||||| Court\n",
672
+ "Donor Sites |||||||||||| Donor Sites\n",
673
+ "Non-promoter |||||||||||| promoter\n",
674
+ "promoter |||||||||||| Non-promoter\n",
675
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
676
+ "promoter |||||||||||| promoter\n",
677
+ "Binding Sites |||||||||||| Background Sequences\n",
678
+ "promoter |||||||||||| promoter\n",
679
+ "Non-promoter |||||||||||| Non-promoter\n",
680
+ "Binding Sites |||||||||||| Binding Sites\n",
681
+ "Non-promoter |||||||||||| Non-promoter\n",
682
+ "Non-promoter |||||||||||| Non-promoter\n",
683
+ "Non-promoter |||||||||||| promoter\n",
684
+ "Non-promoter |||||||||||| Non-promoter\n",
685
+ "promoter |||||||||||| promoter\n",
686
+ "Non-Splice Sites |||||||||||| Acceptor Sites\n",
687
+ "promoter |||||||||||| promoter\n",
688
+ "promoter |||||||||||| Non-promoter\n",
689
+ "promoter |||||||||||| promoter\n",
690
+ "promoter |||||||||||| promoter\n",
691
+ "Donor Sites |||||||||||| Donor Sites\n",
692
+ "Background Sequences |||||||||||| Binding Sites\n",
693
+ "Background Sequences |||||||||||| Background Sequences\n",
694
+ "promoter |||||||||||| promoter\n",
695
+ "Non-promoter |||||||||||| Non-promoter\n",
696
+ "promoter |||||||||||| promoter\n",
697
+ "Donor Sites |||||||||||| Donor Sites\n",
698
+ "Non-promoter |||||||||||| promoter\n",
699
+ "Acceptor Sites |||||||||||| Non-Splice Sites\n",
700
+ "Non-promoter |||||||||||| promoter\n",
701
+ "Non-promoter |||||||||||| Non-promoter\n",
702
+ "Non-promoter |||||||||||| Non-promoter\n",
703
+ "promoter |||||||||||| promoter\n",
704
+ "promoter |||||||||||| promoter\n",
705
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
706
+ "Non-promoter |||||||||||| Non-promoter\n",
707
+ "Non-promoter |||||||||||| Non-promoter\n",
708
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
709
+ "Donor Sites |||||||||||| Donor Sites\n",
710
+ "Donor Sites |||||||||||| Donor Sites\n",
711
+ "Binding Sites |||||||||||| Background Sequences\n",
712
+ "Binding Sites |||||||||||| Binding Sites\n",
713
+ "promoter |||||||||||| promoter\n",
714
+ "Non-promoter |||||||||||| Non-promoter\n",
715
+ "Binding Sites |||||||||||| Background Sequences\n",
716
+ "Background Sequences |||||||||||| Background Sequences\n",
717
+ "Non-promoter |||||||||||| promoter\n",
718
+ "Non-promoter |||||||||||| Non-promoter\n",
719
+ "promoter |||||||||||| Non-promoter\n",
720
+ "Donor Sites |||||||||||| Donor Sites\n",
721
+ "Donor Sites |||||||||||| Non-Splice Sites\n",
722
+ "Binding Sites |||||||||||| Binding Sites\n",
723
+ "Donor Sites |||||||||||| Donor Sites\n",
724
+ "promoter |||||||||||| Non-promoter\n",
725
+ "Acceptor Sites |||||||||||| Donor Sites\n",
726
+ "Non-promoter |||||||||||| Non-promoter\n",
727
+ "Non-promoter |||||||||||| Non-promoter\n",
728
+ "Donor Sites |||||||||||| Donor Sites\n",
729
+ "Donor Sites |||||||||||| Donor Sites\n",
730
+ "Donor Sites |||||||||||| Donor Sites\n",
731
+ "promoter |||||||||||| promoter\n",
732
+ "promoter |||||||||||| promoter\n",
733
+ "promoter |||||||||||| promoter\n",
734
+ "promoter |||||||||||| promoter\n",
735
+ "promoter |||||||||||| promoter\n",
736
+ "promoter |||||||||||| promoter\n",
737
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
738
+ "promoter |||||||||||| promoter\n",
739
+ "Background Sequences |||||||||||| Background Sequences\n",
740
+ "Non-promoter |||||||||||| Non-promoter\n",
741
+ "promoter |||||||||||| promoter\n",
742
+ "Donor Sites |||||||||||| Donor Sites\n",
743
+ "Non-promoter |||||||||||| promoter\n",
744
+ "Donor Sites |||||||||||| Donor Sites\n",
745
+ "Binding Sites |||||||||||| Binding Sites\n",
746
+ "Donor Sites |||||||||||| Donor Sites\n",
747
+ "Binding Sites |||||||||||| Binding Sites\n",
748
+ "Non-promoter |||||||||||| promoter\n",
749
+ "Non-promoter |||||||||||| Non-promoter\n",
750
+ "Background Sequences |||||||||||| Binding Sites\n",
751
+ "Non-promoter |||||||||||| Non-promoter\n",
752
+ "promoter |||||||||||| promoter\n",
753
+ "Background Sequences |||||||||||| Background Sequences\n",
754
+ "Non-promoter |||||||||||| promoter\n",
755
+ "Non-promoter |||||||||||| Non-promoter\n",
756
+ "Non-promoter |||||||||||| Non-promoter\n",
757
+ "Background Sequences |||||||||||| Binding Sites\n",
758
+ "Background Sequences |||||||||||| Background Sequences\n",
759
+ "Non-promoter |||||||||||| Non-promoter\n",
760
+ "promoter |||||||||||| promoter\n",
761
+ "Background Sequences |||||||||||| Background Sequences\n",
762
+ "Non-promoter |||||||||||| promoter\n",
763
+ "Non-promoter |||||||||||| promoter\n",
764
+ "promoter |||||||||||| promoter\n",
765
+ "promoter |||||||||||| promoter\n",
766
+ "promoter |||||||||||| promoter\n",
767
+ "Non-promoter |||||||||||| Non-promoter\n",
768
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
769
+ "promoter |||||||||||| Non-promoter\n",
770
+ "promoter |||||||||||| promoter\n",
771
+ "Background Sequences |||||||||||| Background Sequences\n",
772
+ "Background Sequences |||||||||||| Background Sequences\n",
773
+ "Background Sequences |||||||||||| Background Sequences\n",
774
+ "Donor Sites |||||||||||| Donor Sites\n",
775
+ "Binding Sites |||||||||||| Binding Sites\n",
776
+ "Non-promoter |||||||||||| Non-promoter\n",
777
+ "Non-promoter |||||||||||| promoter\n",
778
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
779
+ "Binding Sites |||||||||||| Binding Sites\n",
780
+ "Donor Sites |||||||||||| Non-Splice Sites\n",
781
+ "Background Sequences |||||||||||| Binding Sites\n",
782
+ "promoter |||||||||||| promoter\n",
783
+ "Non-Splice Sites |||||||||||| Splice Sites\n",
784
+ "promoter |||||||||||| promoter\n",
785
+ "Donor Sites |||||||||||| Acceptor Sites\n",
786
+ "Binding Sites |||||||||||| Binding Sites\n",
787
+ "Non-promoter |||||||||||| promoter\n",
788
+ "promoter |||||||||||| promoter\n",
789
+ "Donor Sites |||||||||||| Acceptor Sites\n",
790
+ "Non-promoter |||||||||||| Non-promoter\n",
791
+ "promoter |||||||||||| promoter\n",
792
+ "Donor Sites |||||||||||| Donor Sites\n",
793
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
794
+ "Donor Sites |||||||||||| Donor Sites\n",
795
+ "Binding Sites |||||||||||| Binding Sites\n",
796
+ "promoter |||||||||||| promoter\n",
797
+ "Background Sequences |||||||||||| Background Sequences\n",
798
+ "promoter |||||||||||| promoter\n",
799
+ "Binding Sites |||||||||||| Coursing\n",
800
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
801
+ "Non-promoter |||||||||||| Non-promoter\n",
802
+ "Donor Sites |||||||||||| Donor Sites\n",
803
+ "Non-promoter |||||||||||| Non-promoter\n",
804
+ "Non-promoter |||||||||||| promoter\n",
805
+ "Binding Sites |||||||||||| Binding Sites\n",
806
+ "Binding Sites |||||||||||| Binding Sites\n",
807
+ "Background Sequences |||||||||||| Background Sequences\n",
808
+ "Non-promoter |||||||||||| Non-promoter\n",
809
+ "promoter |||||||||||| Non-promoter\n",
810
+ "promoter |||||||||||| promoter\n",
811
+ "Non-promoter |||||||||||| promoter\n",
812
+ "promoter |||||||||||| promoter\n",
813
+ "Non-promoter |||||||||||| promoter\n",
814
+ "Non-promoter |||||||||||| promoter\n",
815
+ "Non-promoter |||||||||||| promoter\n",
816
+ "promoter |||||||||||| Non-promoter\n",
817
+ "promoter |||||||||||| promoter\n",
818
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
819
+ "promoter |||||||||||| promoter\n",
820
+ "Non-promoter |||||||||||| Non-promoter\n",
821
+ "Acceptor Sites |||||||||||| Donor Sites\n",
822
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
823
+ "promoter |||||||||||| promoter\n",
824
+ "promoter |||||||||||| promoter\n",
825
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
826
+ "Donor Sites |||||||||||| Donor Sites\n",
827
+ "Non-promoter |||||||||||| Non-promoter\n",
828
+ "promoter |||||||||||| promoter\n",
829
+ "Acceptor Sites |||||||||||| Donor Sites\n",
830
+ "Non-promoter |||||||||||| Non-promoter\n",
831
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
832
+ "Acceptor Sites |||||||||||| Non-Splice Sites\n",
833
+ "Non-promoter |||||||||||| Non-promoter\n",
834
+ "Background Sequences |||||||||||| Background Sequences\n",
835
+ "Donor Sites |||||||||||| Donor Sites\n",
836
+ "promoter |||||||||||| promoter\n",
837
+ "promoter |||||||||||| promoter\n",
838
+ "Acceptor Sites |||||||||||| Donor Sites\n",
839
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
840
+ "promoter |||||||||||| promoter\n",
841
+ "Non-promoter |||||||||||| Non-promoter\n",
842
+ "promoter |||||||||||| promoter\n",
843
+ "Non-promoter |||||||||||| promoter\n",
844
+ "promoter |||||||||||| promoter\n",
845
+ "Non-promoter |||||||||||| Non-promoter\n",
846
+ "Donor Sites |||||||||||| Donor Sites\n",
847
+ "promoter |||||||||||| promoter\n",
848
+ "Donor Sites |||||||||||| Non-Splice Sites\n",
849
+ "Donor Sites |||||||||||| Donor Sites\n",
850
+ "Donor Sites |||||||||||| Donor Sites\n",
851
+ "Donor Sites |||||||||||| Donor Sites\n",
852
+ "promoter |||||||||||| promoter\n",
853
+ "Non-promoter |||||||||||| promoter\n",
854
+ "Binding Sites |||||||||||| Binding Sites\n",
855
+ "promoter |||||||||||| promoter\n",
856
+ "promoter |||||||||||| promoter\n",
857
+ "Binding Sites |||||||||||| Binding Sites\n",
858
+ "Binding Sites |||||||||||| Binding Sites\n",
859
+ "Non-promoter |||||||||||| Non-promoter\n",
860
+ "Non-promoter |||||||||||| Non-promoter\n",
861
+ "Non-promoter |||||||||||| Non-promoter\n",
862
+ "promoter |||||||||||| promoter\n",
863
+ "Background Sequences |||||||||||| Background Sequences\n",
864
+ "promoter |||||||||||| promoter\n",
865
+ "promoter |||||||||||| promoter\n",
866
+ "Background Sequences |||||||||||| Background Sequences\n",
867
+ "Binding Sites |||||||||||| Binding Sites\n",
868
+ "Binding Sites |||||||||||| Background Sequences\n",
869
+ "Non-promoter |||||||||||| Non-promoter\n",
870
+ "Non-promoter |||||||||||| promoter\n",
871
+ "Non-promoter |||||||||||| Non-promoter\n",
872
+ "Non-promoter |||||||||||| promoter\n",
873
+ "Donor Sites |||||||||||| Donor Sites\n",
874
+ "promoter |||||||||||| promoter\n",
875
+ "promoter |||||||||||| promoter\n",
876
+ "Non-promoter |||||||||||| Non-promoter\n",
877
+ "Donor Sites |||||||||||| Donor Sites\n",
878
+ "Donor Sites |||||||||||| Donor Sites\n",
879
+ "Non-Splice Sites |||||||||||| Acceptor Sites\n",
880
+ "promoter |||||||||||| promoter\n",
881
+ "Donor Sites |||||||||||| Donor Sites\n",
882
+ "promoter |||||||||||| promoter\n",
883
+ "Non-promoter |||||||||||| promoter\n",
884
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
885
+ "Binding Sites |||||||||||| Binding Sites\n",
886
+ "promoter |||||||||||| promoter\n",
887
+ "Donor Sites |||||||||||| Donor Sites\n",
888
+ "Donor Sites |||||||||||| Donor Sites\n",
889
+ "promoter |||||||||||| promoter\n",
890
+ "promoter |||||||||||| promoter\n",
891
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
892
+ "promoter |||||||||||| promoter\n",
893
+ "Binding Sites |||||||||||| Background Sequences\n",
894
+ "Non-promoter |||||||||||| Non-promoter\n",
895
+ "Donor Sites |||||||||||| Donor Sites\n",
896
+ "Non-promoter |||||||||||| promoter\n",
897
+ "promoter |||||||||||| promoter\n",
898
+ "Non-promoter |||||||||||| Non-promoter\n",
899
+ "promoter |||||||||||| promoter\n",
900
+ "promoter |||||||||||| promoter\n",
901
+ "Donor Sites |||||||||||| Donor Sites\n",
902
+ "Donor Sites |||||||||||| Donor Sites\n",
903
+ "Donor Sites |||||||||||| Donor Sites\n",
904
+ "Binding Sites |||||||||||| Binding Sites\n",
905
+ "Acceptor Sites |||||||||||| Donor Sites\n",
906
+ "Non-promoter |||||||||||| promoter\n",
907
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
908
+ "Background Sequences |||||||||||| Background Sequences\n",
909
+ "Donor Sites |||||||||||| Donor Sites\n",
910
+ "promoter |||||||||||| promoter\n",
911
+ "Donor Sites |||||||||||| Donor Sites\n",
912
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
913
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
914
+ "Background Sequences |||||||||||| Background Sequences\n",
915
+ "Non-promoter |||||||||||| Non-promoter\n",
916
+ "Non-promoter |||||||||||| Non-promoter\n",
917
+ "Non-promoter |||||||||||| Non-promoter\n",
918
+ "Non-promoter |||||||||||| promoter\n",
919
+ "Binding Sites |||||||||||| Binding Sites\n",
920
+ "promoter |||||||||||| promoter\n",
921
+ "promoter |||||||||||| Non-promoter\n",
922
+ "promoter |||||||||||| promoter\n",
923
+ "promoter |||||||||||| promoter\n",
924
+ "Non-promoter |||||||||||| Non-promoter\n",
925
+ "Donor Sites |||||||||||| Donor Sites\n",
926
+ "Non-promoter |||||||||||| Non-promoter\n",
927
+ "Non-promoter |||||||||||| Non-promoter\n",
928
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
929
+ "promoter |||||||||||| Non-promoter\n",
930
+ "Non-promoter |||||||||||| promoter\n",
931
+ "Binding Sites |||||||||||| Binding Sites\n",
932
+ "Binding Sites |||||||||||| Background Sequences\n",
933
+ "Donor Sites |||||||||||| D Donor Sites\n",
934
+ "promoter |||||||||||| promoter\n",
935
+ "Background Sequences |||||||||||| Background Sequences\n",
936
+ "Background Sequences |||||||||||| Background Sequences\n",
937
+ "Non-promoter |||||||||||| Non-promoter\n",
938
+ "promoter |||||||||||| promoter\n",
939
+ "Non-promoter |||||||||||| promoter\n",
940
+ "promoter |||||||||||| promoter\n",
941
+ "Non-promoter |||||||||||| promoter\n",
942
+ "Non-promoter |||||||||||| promoter\n",
943
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
944
+ "Non-promoter |||||||||||| Non-promoter\n",
945
+ "promoter |||||||||||| promoter\n",
946
+ "Donor Sites |||||||||||| Acceptor Sites\n",
947
+ "Donor Sites |||||||||||| Donor Sites\n",
948
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
949
+ "Background Sequences |||||||||||| Background Sequences\n",
950
+ "promoter |||||||||||| promoter\n",
951
+ "promoter |||||||||||| promoter\n",
952
+ "Non-promoter |||||||||||| Non-promoter\n",
953
+ "promoter |||||||||||| promoter\n",
954
+ "promoter |||||||||||| promoter\n",
955
+ "Background Sequences |||||||||||| Background Sequences\n",
956
+ "Donor Sites |||||||||||| Donor Sites\n",
957
+ "promoter |||||||||||| promoter\n",
958
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
959
+ "Binding Sites |||||||||||| Binding Sites\n",
960
+ "Non-promoter |||||||||||| Non-promoter\n",
961
+ "Non-promoter |||||||||||| Non-promoter\n",
962
+ "promoter |||||||||||| promoter\n",
963
+ "Non-promoter |||||||||||| promoter\n",
964
+ "Non-promoter |||||||||||| Non-promoter\n",
965
+ "Acceptor Sites |||||||||||| Donor Sites\n",
966
+ "promoter |||||||||||| promoter\n",
967
+ "Acceptor Sites |||||||||||| Donor Sites\n",
968
+ "promoter |||||||||||| promoter\n",
969
+ "promoter |||||||||||| promoter\n",
970
+ "Acceptor Sites |||||||||||| Donor Sites\n",
971
+ "promoter |||||||||||| promoter\n",
972
+ "promoter |||||||||||| promoter\n",
973
+ "promoter |||||||||||| Non-promoter\n",
974
+ "Non-promoter |||||||||||| promoter\n",
975
+ "promoter |||||||||||| promoter\n",
976
+ "Non-promoter |||||||||||| Non-promoter\n",
977
+ "Background Sequences |||||||||||| Background Sequences\n",
978
+ "Non-promoter |||||||||||| Non-promoter\n",
979
+ "Background Sequences |||||||||||| Background Sequences\n",
980
+ "Binding Sites |||||||||||| Binding Sites\n",
981
+ "Background Sequences |||||||||||| Background Sequences\n",
982
+ "promoter |||||||||||| promoter\n",
983
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
984
+ "Background Sequences |||||||||||| Background Sequences\n",
985
+ "Background Sequences |||||||||||| Background Sequences\n",
986
+ "Non-promoter |||||||||||| Non-promoter\n",
987
+ "Donor Sites |||||||||||| Donor Sites\n",
988
+ "Non-promoter |||||||||||| Non-promoter\n",
989
+ "Acceptor Sites |||||||||||| Donor Sites\n",
990
+ "Non-promoter |||||||||||| promoter\n",
991
+ "Non-promoter |||||||||||| Non-promoter\n",
992
+ "promoter |||||||||||| Non-promoter\n",
993
+ "Binding Sites |||||||||||| Background Sequences\n",
994
+ "Binding Sites |||||||||||| Background Sequences\n",
995
+ "Non-promoter |||||||||||| Non-promoter\n",
996
+ "Non-promoter |||||||||||| Non-promoter\n",
997
+ "Non-promoter |||||||||||| Non-promoter\n",
998
+ "Non-promoter |||||||||||| Non-promoter\n",
999
+ "Non-promoter |||||||||||| Non-promoter\n",
1000
+ "Non-promoter |||||||||||| promoter\n",
1001
+ "promoter |||||||||||| Non-promoter\n",
1002
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1003
+ "Non-promoter |||||||||||| Non-promoter\n",
1004
+ "Non-promoter |||||||||||| Non-promoter\n",
1005
+ "Non-promoter |||||||||||| Non-promoter\n",
1006
+ "Non-promoter |||||||||||| Non-promoter\n",
1007
+ "Binding Sites |||||||||||| Binding Sites\n",
1008
+ "Non-promoter |||||||||||| Non-promoter\n",
1009
+ "Donor Sites |||||||||||| Non-Splice Sites\n",
1010
+ "Donor Sites |||||||||||| Acceptor Sites\n",
1011
+ "Donor Sites |||||||||||| Non-Splice Sites\n",
1012
+ "promoter |||||||||||| promoter\n",
1013
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
1014
+ "promoter |||||||||||| promoter\n",
1015
+ "Non-promoter |||||||||||| Non-promoter\n",
1016
+ "Non-promoter |||||||||||| Non-promoter\n",
1017
+ "Donor Sites |||||||||||| Donor Sites\n",
1018
+ "promoter |||||||||||| Non-promoter\n",
1019
+ "promoter |||||||||||| promoter\n",
1020
+ "promoter |||||||||||| promoter\n",
1021
+ "Binding Sites |||||||||||| Binding Sites\n",
1022
+ "Donor Sites |||||||||||| Donor Sites\n",
1023
+ "Non-promoter |||||||||||| promoter\n",
1024
+ "Donor Sites |||||||||||| Donor Sites\n",
1025
+ "Non-promoter |||||||||||| promoter\n",
1026
+ "Background Sequences |||||||||||| Background Sequences\n",
1027
+ "Non-promoter |||||||||||| promoter\n",
1028
+ "Non-promoter |||||||||||| Non-promoter\n",
1029
+ "promoter |||||||||||| promoter\n",
1030
+ "Non-promoter |||||||||||| Non-promoter\n",
1031
+ "Binding Sites |||||||||||| Binding Sites\n",
1032
+ "Non-promoter |||||||||||| promoter\n",
1033
+ "Donor Sites |||||||||||| Donor Sites\n",
1034
+ "promoter |||||||||||| promoter\n",
1035
+ "promoter |||||||||||| promoter\n",
1036
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
1037
+ "Non-promoter |||||||||||| Non-promoter\n",
1038
+ "promoter |||||||||||| promoter\n",
1039
+ "promoter |||||||||||| promoter\n",
1040
+ "promoter |||||||||||| promoter\n",
1041
+ "Non-promoter |||||||||||| Non-promoter\n",
1042
+ "Non-promoter |||||||||||| promoter\n",
1043
+ "promoter |||||||||||| promoter\n",
1044
+ "Non-promoter |||||||||||| Non-promoter\n",
1045
+ "promoter |||||||||||| promoter\n",
1046
+ "Non-promoter |||||||||||| promoter\n",
1047
+ "promoter |||||||||||| promoter\n",
1048
+ "Donor Sites |||||||||||| Donor Sites\n",
1049
+ "promoter |||||||||||| promoter\n",
1050
+ "Binding Sites |||||||||||| Background Sequences\n",
1051
+ "promoter |||||||||||| promoter\n",
1052
+ "Non-promoter |||||||||||| promoter\n",
1053
+ "promoter |||||||||||| promoter\n",
1054
+ "Non-promoter |||||||||||| Non-promoter\n",
1055
+ "Non-promoter |||||||||||| promoter\n",
1056
+ "promoter |||||||||||| promoter\n",
1057
+ "promoter |||||||||||| Non-promoter\n",
1058
+ "Non-promoter |||||||||||| Non-promoter\n",
1059
+ "promoter |||||||||||| promoter\n",
1060
+ "Donor Sites |||||||||||| Acceptor Sites\n",
1061
+ "Non-promoter |||||||||||| Non-promoter\n",
1062
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
1063
+ "Binding Sites |||||||||||| Background Sequences\n",
1064
+ "promoter |||||||||||| promoter\n",
1065
+ "Donor Sites |||||||||||| Donor Sites\n",
1066
+ "Non-promoter |||||||||||| Non-promoter\n",
1067
+ "Non-promoter |||||||||||| Non-promoter\n",
1068
+ "Non-promoter |||||||||||| promoter\n",
1069
+ "Non-promoter |||||||||||| Non-promoter\n",
1070
+ "Binding Sites |||||||||||| Binding Sites\n",
1071
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1072
+ "Donor Sites |||||||||||| Donor Sites\n",
1073
+ "Donor Sites |||||||||||| Donor Sites\n",
1074
+ "promoter |||||||||||| promoter\n",
1075
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
1076
+ "Non-promoter |||||||||||| Non-promoter\n",
1077
+ "promoter |||||||||||| promoter\n",
1078
+ "promoter |||||||||||| promoter\n",
1079
+ "Donor Sites |||||||||||| Donor Sites\n",
1080
+ "Non-promoter |||||||||||| promoter\n",
1081
+ "Binding Sites |||||||||||| Background Sequences\n",
1082
+ "Background Sequences |||||||||||| Background Sequences\n",
1083
+ "promoter |||||||||||| Non-promoter\n",
1084
+ "promoter |||||||||||| promoter\n",
1085
+ "promoter |||||||||||| promoter\n",
1086
+ "promoter |||||||||||| promoter\n",
1087
+ "Non-promoter |||||||||||| Non-promoter\n",
1088
+ "Non-promoter |||||||||||| Non-promoter\n",
1089
+ "Donor Sites |||||||||||| Donor Sites\n",
1090
+ "Background Sequences |||||||||||| Background Sequences\n",
1091
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
1092
+ "Non-promoter |||||||||||| Non-promoter\n",
1093
+ "promoter |||||||||||| promoter\n",
1094
+ "Non-promoter |||||||||||| Non-promoter\n",
1095
+ "Non-promoter |||||||||||| C promoter\n",
1096
+ "promoter |||||||||||| Non-promoter\n",
1097
+ "promoter |||||||||||| promoter\n",
1098
+ "Non-promoter |||||||||||| Non-promoter\n",
1099
+ "Donor Sites |||||||||||| Donor Sites\n",
1100
+ "Donor Sites |||||||||||| Donor Sites\n",
1101
+ "Donor Sites |||||||||||| Donor Sites\n",
1102
+ "Background Sequences |||||||||||| Background Sequences\n",
1103
+ "promoter |||||||||||| promoter\n",
1104
+ "Non-promoter |||||||||||| Non-promoter\n",
1105
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
1106
+ "Binding Sites |||||||||||| Background Sequences\n",
1107
+ "Non-promoter |||||||||||| Non-promoter\n",
1108
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
1109
+ "Non-promoter |||||||||||| Non-promoter\n",
1110
+ "promoter |||||||||||| promoter\n",
1111
+ "Non-promoter |||||||||||| Non-promoter\n",
1112
+ "promoter |||||||||||| promoter\n",
1113
+ "Non-promoter |||||||||||| promoter\n",
1114
+ "promoter |||||||||||| Non-promoter\n",
1115
+ "Non-promoter |||||||||||| Non-promoter\n",
1116
+ "Binding Sites |||||||||||| Binding Sites\n",
1117
+ "Donor Sites |||||||||||| Donor Sites\n",
1118
+ "Non-promoter |||||||||||| promoter\n",
1119
+ "promoter |||||||||||| promoter\n",
1120
+ "Non-promoter |||||||||||| Non-promoter\n",
1121
+ "Background Sequences |||||||||||| Binding Sites\n",
1122
+ "Binding Sites |||||||||||| Binding Sites\n",
1123
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1124
+ "Non-promoter |||||||||||| Non-promoter\n",
1125
+ "Non-promoter |||||||||||| Non-promoter\n",
1126
+ "Non-promoter |||||||||||| Non-promoter\n",
1127
+ "Donor Sites |||||||||||| Donor Sites\n",
1128
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
1129
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
1130
+ "promoter |||||||||||| promoter\n",
1131
+ "Non-promoter |||||||||||| Non-promoter\n",
1132
+ "Non-promoter |||||||||||| Non-promoter\n",
1133
+ "Donor Sites |||||||||||| Donor Sites\n",
1134
+ "promoter |||||||||||| promoter\n",
1135
+ "promoter |||||||||||| promoter\n",
1136
+ "promoter |||||||||||| promoter\n",
1137
+ "Background Sequences |||||||||||| Background Sequences\n",
1138
+ "promoter |||||||||||| promoter\n",
1139
+ "Donor Sites |||||||||||| Donor Sites\n",
1140
+ "Background Sequences |||||||||||| Background Sequences\n",
1141
+ "Binding Sites |||||||||||| Binding Sites\n",
1142
+ "Non-promoter |||||||||||| promoter\n",
1143
+ "Non-promoter |||||||||||| Non-promoter\n",
1144
+ "promoter |||||||||||| promoter\n",
1145
+ "promoter |||||||||||| promoter\n",
1146
+ "promoter |||||||||||| promoter\n",
1147
+ "Binding Sites |||||||||||| Binding Sites\n",
1148
+ "Background Sequences |||||||||||| Background Sequences\n",
1149
+ "Non-promoter |||||||||||| Non-promoter\n",
1150
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1151
+ "Non-promoter |||||||||||| Non-promoter\n",
1152
+ "Non-promoter |||||||||||| promoter\n",
1153
+ "Background Sequences |||||||||||| Binding Sites\n",
1154
+ "promoter |||||||||||| promoter\n",
1155
+ "Non-promoter |||||||||||| Non-promoter\n",
1156
+ "promoter |||||||||||| Non-promoter\n",
1157
+ "Non-promoter |||||||||||| Non-promoter\n",
1158
+ "Non-promoter |||||||||||| Non-promoter\n",
1159
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1160
+ "Non-promoter |||||||||||| Non-promoter\n",
1161
+ "promoter |||||||||||| promoter\n",
1162
+ "Non-promoter |||||||||||| promoter\n",
1163
+ "Non-promoter |||||||||||| promoter\n",
1164
+ "promoter |||||||||||| promoter\n",
1165
+ "Non-promoter |||||||||||| Non-promoter\n",
1166
+ "Non-promoter |||||||||||| promoter\n",
1167
+ "Non-promoter |||||||||||| Non-promoter\n",
1168
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
1169
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
1170
+ "promoter |||||||||||| Non-promoter\n",
1171
+ "Binding Sites |||||||||||| Background Sequences\n",
1172
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1173
+ "Non-promoter |||||||||||| Non-promoter\n",
1174
+ "Donor Sites |||||||||||| Donor Sites\n",
1175
+ "Non-promoter |||||||||||| Non-promoter\n",
1176
+ "promoter |||||||||||| promoter\n",
1177
+ "Donor Sites |||||||||||| Donor Sites\n",
1178
+ "Donor Sites |||||||||||| Donor Sites\n",
1179
+ "Non-promoter |||||||||||| promoter\n",
1180
+ "Binding Sites |||||||||||| Binding Sites\n",
1181
+ "Non-promoter |||||||||||| Non-promoter\n",
1182
+ "Binding Sites |||||||||||| Binding Sites\n",
1183
+ "Donor Sites |||||||||||| Donor Sites\n",
1184
+ "Background Sequences |||||||||||| Background Sequences\n",
1185
+ "Donor Sites |||||||||||| Donor Sites\n",
1186
+ "Background Sequences |||||||||||| Binding Sites\n",
1187
+ "Binding Sites |||||||||||| Binding Sites\n",
1188
+ "promoter |||||||||||| promoter\n",
1189
+ "promoter |||||||||||| promoter\n",
1190
+ "promoter |||||||||||| promoter\n",
1191
+ "Binding Sites |||||||||||| Binding Sites\n",
1192
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
1193
+ "Non-promoter |||||||||||| Non-promoter\n",
1194
+ "Non-promoter |||||||||||| promoter\n",
1195
+ "promoter |||||||||||| promoter\n",
1196
+ "promoter |||||||||||| promoter\n",
1197
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
1198
+ "Binding Sites |||||||||||| Binding Sites\n",
1199
+ "Background Sequences |||||||||||| Background Sequences\n",
1200
+ "Donor Sites |||||||||||| Donor Sites\n",
1201
+ "Non-promoter |||||||||||| Non-promoter\n",
1202
+ "promoter |||||||||||| promoter\n",
1203
+ "Background Sequences |||||||||||| Background Sequences\n",
1204
+ "Donor Sites |||||||||||| Donor Sites\n",
1205
+ "promoter |||||||||||| promoter\n",
1206
+ "Non-promoter |||||||||||| Non-promoter\n",
1207
+ "Non-promoter |||||||||||| Non-promoter\n",
1208
+ "Non-promoter |||||||||||| Non-promoter\n",
1209
+ "promoter |||||||||||| promoter\n",
1210
+ "Binding Sites |||||||||||| Binding Sites\n",
1211
+ "promoter |||||||||||| Non-promoter\n",
1212
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1213
+ "promoter |||||||||||| promoter\n",
1214
+ "promoter |||||||||||| promoter\n",
1215
+ "Background Sequences |||||||||||| Background Sequences\n",
1216
+ "Background Sequences |||||||||||| Background Sequences\n",
1217
+ "Non-promoter |||||||||||| Non-promoter\n",
1218
+ "Binding Sites |||||||||||| Binding Sites\n",
1219
+ "Background Sequences |||||||||||| Background Sequences\n",
1220
+ "Non-promoter |||||||||||| Non-promoter\n",
1221
+ "Non-promoter |||||||||||| Non-promoter\n",
1222
+ "Donor Sites |||||||||||| Donor Sites\n",
1223
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
1224
+ "Non-promoter |||||||||||| Non-promoter\n",
1225
+ "Binding Sites |||||||||||| Binding Sites\n",
1226
+ "promoter |||||||||||| promoter\n",
1227
+ "Non-promoter |||||||||||| promoter\n",
1228
+ "promoter |||||||||||| Non-promoter\n",
1229
+ "Donor Sites |||||||||||| Donor Sites\n",
1230
+ "Non-promoter |||||||||||| Non-promoter\n",
1231
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1232
+ "Binding Sites |||||||||||| Background Sequences\n",
1233
+ "Background Sequences |||||||||||| Background Sequences\n",
1234
+ "Non-promoter |||||||||||| promoter\n",
1235
+ "Non-promoter |||||||||||| Non-promoter\n",
1236
+ "promoter |||||||||||| promoter\n",
1237
+ "Donor Sites |||||||||||| Donor Sites\n",
1238
+ "promoter |||||||||||| promoter\n",
1239
+ "Donor Sites |||||||||||| Donor Sites\n",
1240
+ "Donor Sites |||||||||||| Donor Sites\n",
1241
+ "promoter |||||||||||| Non-promoter\n",
1242
+ "Binding Sites |||||||||||| Background Sequences\n",
1243
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1244
+ "promoter |||||||||||| promoter\n",
1245
+ "promoter |||||||||||| promoter\n",
1246
+ "Non-promoter |||||||||||| Non-promoter\n",
1247
+ "Non-promoter |||||||||||| Non-promoter\n",
1248
+ "Background Sequences |||||||||||| Binding Sites\n",
1249
+ "Non-promoter |||||||||||| Non-promoter\n",
1250
+ "Non-promoter |||||||||||| promoter\n",
1251
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1252
+ "Non-promoter |||||||||||| Non-promoter\n",
1253
+ "promoter |||||||||||| promoter\n",
1254
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1255
+ "promoter |||||||||||| promoter\n",
1256
+ "Binding Sites |||||||||||| Binding Sites\n",
1257
+ "promoter |||||||||||| promoter\n",
1258
+ "promoter |||||||||||| promoter\n",
1259
+ "Non-promoter |||||||||||| promoter\n",
1260
+ "promoter |||||||||||| Non-promoter\n",
1261
+ "Non-promoter |||||||||||| Non-promoter\n",
1262
+ "promoter |||||||||||| promoter\n",
1263
+ "Donor Sites |||||||||||| Donor Sites\n",
1264
+ "Non-promoter |||||||||||| promoter\n",
1265
+ "Non-promoter |||||||||||| Non-promoter\n",
1266
+ "Donor Sites |||||||||||| Donor Sites\n",
1267
+ "promoter |||||||||||| promoter\n",
1268
+ "promoter |||||||||||| promoter\n",
1269
+ "promoter |||||||||||| promoter\n",
1270
+ "Donor Sites |||||||||||| Donor Sites\n",
1271
+ "Donor Sites |||||||||||| Donor Sites\n",
1272
+ "promoter |||||||||||| promoter\n",
1273
+ "Non-promoter |||||||||||| Non-promoter\n",
1274
+ "promoter |||||||||||| Non-promoter\n",
1275
+ "Non-promoter |||||||||||| Non-promoter\n",
1276
+ "Non-promoter |||||||||||| promoter\n",
1277
+ "promoter |||||||||||| promoter\n",
1278
+ "promoter |||||||||||| promoter\n",
1279
+ "Binding Sites |||||||||||| Background Sequences\n",
1280
+ "Non-promoter |||||||||||| Non-promoter\n",
1281
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1282
+ "Non-promoter |||||||||||| Non-promoter\n",
1283
+ "Non-promoter |||||||||||| Non-promoter\n",
1284
+ "Non-promoter |||||||||||| Non-promoter\n",
1285
+ "promoter |||||||||||| promoter\n",
1286
+ "promoter |||||||||||| promoter\n",
1287
+ "Donor Sites |||||||||||| Donor Sites\n",
1288
+ "Binding Sites |||||||||||| Binding Sites\n",
1289
+ "promoter |||||||||||| Non-promoter\n",
1290
+ "promoter |||||||||||| promoter\n",
1291
+ "Background Sequences |||||||||||| Binding Sites\n",
1292
+ "Non-promoter |||||||||||| Non-promoter\n",
1293
+ "Non-promoter |||||||||||| Non-promoter\n",
1294
+ "promoter |||||||||||| Non-promoter\n",
1295
+ "promoter |||||||||||| promoter\n",
1296
+ "Non-promoter |||||||||||| Non-promoter\n",
1297
+ "Background Sequences |||||||||||| Binding Sites\n",
1298
+ "Binding Sites |||||||||||| Binding Sites\n",
1299
+ "Non-promoter |||||||||||| Non-promoter\n",
1300
+ "Non-promoter |||||||||||| Non-promoter\n",
1301
+ "Binding Sites |||||||||||| Binding Sites\n",
1302
+ "promoter |||||||||||| promoter\n",
1303
+ "Non-promoter |||||||||||| Non-promoter\n",
1304
+ "promoter |||||||||||| Non-promoter\n",
1305
+ "promoter |||||||||||| promoter\n",
1306
+ "Non-promoter |||||||||||| Non-promoter\n",
1307
+ "promoter |||||||||||| Non-promoter\n",
1308
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1309
+ "Binding Sites |||||||||||| Background Sequences\n",
1310
+ "Background Sequences |||||||||||| Background Sequences\n",
1311
+ "promoter |||||||||||| Non-promoter\n",
1312
+ "Donor Sites |||||||||||| Donor Sites\n",
1313
+ "promoter |||||||||||| promoter\n",
1314
+ "Binding Sites |||||||||||| Binding Sites\n",
1315
+ "promoter |||||||||||| promoter\n",
1316
+ "Non-promoter |||||||||||| promoter\n",
1317
+ "Non-promoter |||||||||||| promoter\n",
1318
+ "Background Sequences |||||||||||| Background Sequences\n",
1319
+ "Non-promoter |||||||||||| Non-promoter\n",
1320
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1321
+ "promoter |||||||||||| promoter\n",
1322
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
1323
+ "Donor Sites |||||||||||| Donor Sites\n",
1324
+ "promoter |||||||||||| promoter\n",
1325
+ "Non-promoter |||||||||||| Non-promoter\n",
1326
+ "Non-promoter |||||||||||| promoter\n",
1327
+ "Acceptor Sites |||||||||||| Splice Sites\n",
1328
+ "Binding Sites |||||||||||| Binding Sites\n",
1329
+ "Non-promoter |||||||||||| Non-promoter\n",
1330
+ "promoter |||||||||||| promoter\n",
1331
+ "Binding Sites |||||||||||| Binding Sites\n",
1332
+ "promoter |||||||||||| Non-promoter\n",
1333
+ "Donor Sites |||||||||||| Donor Sites\n",
1334
+ "promoter |||||||||||| promoter\n",
1335
+ "promoter |||||||||||| promoter\n",
1336
+ "Donor Sites |||||||||||| Donor Sites\n",
1337
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
1338
+ "Non-promoter |||||||||||| promoter\n",
1339
+ "Non-promoter |||||||||||| Non-promoter\n",
1340
+ "promoter |||||||||||| promoter\n",
1341
+ "Non-promoter |||||||||||| Non-promoter\n",
1342
+ "Binding Sites |||||||||||| Background Sequences\n",
1343
+ "Non-promoter |||||||||||| Non-promoter\n",
1344
+ "Binding Sites |||||||||||| Binding Sites\n",
1345
+ "promoter |||||||||||| promoter\n",
1346
+ "promoter |||||||||||| promoter\n",
1347
+ "Non-promoter |||||||||||| Non-promoter\n",
1348
+ "Non-promoter |||||||||||| Non-promoter\n",
1349
+ "Donor Sites |||||||||||| Donor Sites\n",
1350
+ "Donor Sites |||||||||||| Donor Sites\n",
1351
+ "Background Sequences |||||||||||| Background Sequences\n",
1352
+ "promoter |||||||||||| promoter\n",
1353
+ "promoter |||||||||||| promoter\n",
1354
+ "Non-promoter |||||||||||| Non-promoter\n",
1355
+ "Binding Sites |||||||||||| Binding Sites\n",
1356
+ "promoter |||||||||||| promoter\n",
1357
+ "Binding Sites |||||||||||| Binding Sites\n",
1358
+ "promoter |||||||||||| promoter\n",
1359
+ "Donor Sites |||||||||||| Donor Sites\n",
1360
+ "promoter |||||||||||| promoter\n",
1361
+ "promoter |||||||||||| promoter\n",
1362
+ "Background Sequences |||||||||||| Background Sequences\n",
1363
+ "Non-promoter |||||||||||| Non-promoter\n",
1364
+ "promoter |||||||||||| promoter\n",
1365
+ "Non-promoter |||||||||||| Non-promoter\n",
1366
+ "Donor Sites |||||||||||| Donor Sites\n",
1367
+ "Background Sequences |||||||||||| Binding Sites\n",
1368
+ "Non-promoter |||||||||||| Non-promoter\n",
1369
+ "Donor Sites |||||||||||| Donor Sites\n",
1370
+ "promoter |||||||||||| promoter\n",
1371
+ "promoter |||||||||||| promoter\n",
1372
+ "promoter |||||||||||| promoter\n",
1373
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1374
+ "Background Sequences |||||||||||| Binding Sites\n",
1375
+ "promoter |||||||||||| Non-promoter\n",
1376
+ "Donor Sites |||||||||||| Donor Sites\n",
1377
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
1378
+ "Non-promoter |||||||||||| Non-promoter\n",
1379
+ "Background Sequences |||||||||||| Background Sequences\n",
1380
+ "promoter |||||||||||| promoter\n",
1381
+ "Non-promoter |||||||||||| promoter\n",
1382
+ "Non-promoter |||||||||||| Non-promoter\n",
1383
+ "promoter |||||||||||| promoter\n",
1384
+ "promoter |||||||||||| promoter\n",
1385
+ "promoter |||||||||||| promoter\n",
1386
+ "promoter |||||||||||| promoter\n",
1387
+ "Non-promoter |||||||||||| Non-promoter\n",
1388
+ "Non-promoter |||||||||||| promoter\n",
1389
+ "Non-promoter |||||||||||| Non-promoter\n",
1390
+ "promoter |||||||||||| Non-promoter\n",
1391
+ "promoter |||||||||||| promoter\n",
1392
+ "Non-promoter |||||||||||| Non-promoter\n",
1393
+ "promoter |||||||||||| promoter\n",
1394
+ "Non-promoter |||||||||||| promoter\n",
1395
+ "promoter |||||||||||| Non-promoter\n",
1396
+ "Non-promoter |||||||||||| promoter\n",
1397
+ "promoter |||||||||||| promoter\n",
1398
+ "Binding Sites |||||||||||| Binding Sites\n",
1399
+ "promoter |||||||||||| promoter\n",
1400
+ "Non-promoter |||||||||||| Non-promoter\n",
1401
+ "promoter |||||||||||| promoter\n",
1402
+ "promoter |||||||||||| Non-promoter\n",
1403
+ "Non-promoter |||||||||||| Non-promoter\n",
1404
+ "Background Sequences |||||||||||| Binding Sites\n",
1405
+ "Donor Sites |||||||||||| Donor Sites\n",
1406
+ "Donor Sites |||||||||||| Donor Sites\n",
1407
+ "Binding Sites |||||||||||| Binding Sites\n",
1408
+ "Non-promoter |||||||||||| promoter\n",
1409
+ "Non-promoter |||||||||||| Non-promoter\n",
1410
+ "Non-promoter |||||||||||| Non-promoter\n",
1411
+ "promoter |||||||||||| promoter\n",
1412
+ "promoter |||||||||||| promoter\n",
1413
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
1414
+ "Non-promoter |||||||||||| Non-promoter\n",
1415
+ "Non-promoter |||||||||||| promoter\n",
1416
+ "promoter |||||||||||| promoter\n",
1417
+ "Donor Sites |||||||||||| Donor Sites\n",
1418
+ "promoter |||||||||||| Non-promoter\n",
1419
+ "Non-promoter |||||||||||| promoter\n",
1420
+ "promoter |||||||||||| promoter\n",
1421
+ "Non-promoter |||||||||||| Non-promoter\n",
1422
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
1423
+ "Non-promoter |||||||||||| Non-promoter\n",
1424
+ "promoter |||||||||||| Non-promoter\n",
1425
+ "Donor Sites |||||||||||| Donor Sites\n",
1426
+ "Non-promoter |||||||||||| Non-promoter\n",
1427
+ "Background Sequences |||||||||||| Background Sequences\n",
1428
+ "promoter |||||||||||| promoter\n",
1429
+ "promoter |||||||||||| promoter\n",
1430
+ "Donor Sites |||||||||||| Donor Sites\n",
1431
+ "Donor Sites |||||||||||| Non-Splice Sites\n",
1432
+ "promoter |||||||||||| promoter\n",
1433
+ "promoter |||||||||||| promoter\n",
1434
+ "promoter |||||||||||| promoter\n",
1435
+ "promoter |||||||||||| promoter\n",
1436
+ "promoter |||||||||||| promoter\n",
1437
+ "promoter |||||||||||| Non-promoter\n",
1438
+ "promoter |||||||||||| promoter\n",
1439
+ "promoter |||||||||||| promoter\n",
1440
+ "promoter |||||||||||| promoter\n",
1441
+ "Background Sequences |||||||||||| Background Sequences\n",
1442
+ "Background Sequences |||||||||||| Background Sequences\n",
1443
+ "promoter |||||||||||| promoter\n",
1444
+ "promoter |||||||||||| promoter\n",
1445
+ "Non-promoter |||||||||||| Non-promoter\n",
1446
+ "Background Sequences |||||||||||| Background Sequences\n",
1447
+ "Non-promoter |||||||||||| Non-promoter\n",
1448
+ "Non-promoter |||||||||||| Non-promoter\n",
1449
+ "Non-promoter |||||||||||| promoter\n",
1450
+ "Non-Splice Sites |||||||||||| Acceptor Sites\n",
1451
+ "promoter |||||||||||| promoter\n",
1452
+ "Non-promoter |||||||||||| promoter\n",
1453
+ "Non-promoter |||||||||||| Non-promoter\n",
1454
+ "Background Sequences |||||||||||| Background Sequences\n",
1455
+ "promoter |||||||||||| Non-promoter\n",
1456
+ "promoter |||||||||||| Non-promoter\n",
1457
+ "Background Sequences |||||||||||| Background Sequences\n",
1458
+ "Background Sequences |||||||||||| Background Sequences\n",
1459
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1460
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
1461
+ "Non-promoter |||||||||||| Non-promoter\n",
1462
+ "Non-promoter |||||||||||| Non-promoter\n",
1463
+ "promoter |||||||||||| promoter\n",
1464
+ "Non-promoter |||||||||||| Non-promoter\n",
1465
+ "promoter |||||||||||| Non-promoter\n",
1466
+ "Binding Sites |||||||||||| Background Sequences\n",
1467
+ "Binding Sites |||||||||||| Binding Sites\n",
1468
+ "Non-promoter |||||||||||| Non-promoter\n",
1469
+ "promoter |||||||||||| promoter\n",
1470
+ "Non-promoter |||||||||||| Non-promoter\n",
1471
+ "promoter |||||||||||| promoter\n",
1472
+ "Binding Sites |||||||||||| Binding Sites\n",
1473
+ "Non-promoter |||||||||||| Non-promoter\n",
1474
+ "Non-promoter |||||||||||| Non-promoter\n",
1475
+ "promoter |||||||||||| promoter\n",
1476
+ "Non-promoter |||||||||||| Non-promoter\n",
1477
+ "Binding Sites |||||||||||| Background Sequences\n",
1478
+ "Donor Sites |||||||||||| Donor Sites\n",
1479
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1480
+ "Non-promoter |||||||||||| promoter\n",
1481
+ "Non-promoter |||||||||||| Non-promoter\n",
1482
+ "promoter |||||||||||| promoter\n",
1483
+ "Background Sequences |||||||||||| Background Sequences\n",
1484
+ "Donor Sites |||||||||||| Donor Sites\n",
1485
+ "Non-promoter |||||||||||| promoter\n",
1486
+ "promoter |||||||||||| promoter\n",
1487
+ "Non-Splice Sites |||||||||||| Donor Sites\n",
1488
+ "Binding Sites |||||||||||| Binding Sites\n",
1489
+ "Non-promoter |||||||||||| promoter\n",
1490
+ "Donor Sites |||||||||||| Donor Sites\n",
1491
+ "promoter |||||||||||| promoter\n",
1492
+ "promoter |||||||||||| promoter\n",
1493
+ "Non-promoter |||||||||||| promoter\n",
1494
+ "Non-promoter |||||||||||| Non-promoter\n",
1495
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
1496
+ "Non-promoter |||||||||||| Non-promoter\n",
1497
+ "Background Sequences |||||||||||| Background Sequences\n",
1498
+ "promoter |||||||||||| Non-promoter\n",
1499
+ "Acceptor Sites |||||||||||| Acceptor Sites\n",
1500
+ "Donor Sites |||||||||||| Donor Sites\n",
1501
+ "promoter |||||||||||| promoter\n",
1502
+ "Binding Sites |||||||||||| Binding Sites\n",
1503
+ "promoter |||||||||||| promoter\n",
1504
+ "Donor Sites |||||||||||| Donor Sites\n",
1505
+ "Donor Sites |||||||||||| Acceptor Sites\n",
1506
+ "promoter |||||||||||| promoter\n",
1507
+ "Non-promoter |||||||||||| Non-promoter\n",
1508
+ "promoter |||||||||||| Non-promoter\n",
1509
+ "Binding Sites |||||||||||| Binding Sites\n",
1510
+ "Non-promoter |||||||||||| Non-promoter\n",
1511
+ "Non-promoter |||||||||||| promoter\n",
1512
+ "Non-promoter |||||||||||| Non-promoter\n",
1513
+ "Non-promoter |||||||||||| Non-promoter\n",
1514
+ "Non-promoter |||||||||||| Non-promoter\n",
1515
+ "Non-promoter |||||||||||| promoter\n",
1516
+ "promoter |||||||||||| promoter\n",
1517
+ "Background Sequences |||||||||||| Binding Sites\n",
1518
+ "Non-promoter |||||||||||| promoter\n",
1519
+ "Donor Sites |||||||||||| Non-Splice Sites\n",
1520
+ "Donor Sites |||||||||||| Donor Sites\n",
1521
+ "Non-Splice Sites |||||||||||| Non-Splice Sites\n",
1522
+ "Non-promoter |||||||||||| Non-promoter\n",
1523
+ "promoter |||||||||||| Non-promoter\n",
1524
+ "Non-promoter |||||||||||| promoter\n",
1525
+ "promoter |||||||||||| Non-promoter\n",
1526
+ "promoter |||||||||||| promoter\n",
1527
+ "promoter |||||||||||| promoter\n",
1528
+ "Donor Sites |||||||||||| Donor Sites\n",
1529
+ "promoter |||||||||||| promoter\n",
1530
+ "Donor Sites |||||||||||| Donor Sites\n",
1531
+ "Non-promoter |||||||||||| Non-promoter\n",
1532
+ "Donor Sites |||||||||||| Donor Sites\n",
1533
+ "Non-promoter |||||||||||| promoter\n",
1534
+ "Donor Sites |||||||||||| Acceptor Sites\n",
1535
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1536
+ "Non-promoter |||||||||||| Non-promoter\n",
1537
+ "promoter |||||||||||| Non-promoter\n",
1538
+ "promoter |||||||||||| promoter\n",
1539
+ "Non-promoter |||||||||||| Non-promoter\n",
1540
+ "Non-promoter |||||||||||| Non-promoter\n",
1541
+ "promoter |||||||||||| Non-promoter\n",
1542
+ "promoter |||||||||||| promoter\n",
1543
+ "promoter |||||||||||| promoter\n",
1544
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1545
+ "Non-promoter |||||||||||| Non-promoter\n",
1546
+ "Acceptor Sites |||||||||||| Donor Sites\n",
1547
+ "promoter |||||||||||| promoter\n",
1548
+ "promoter |||||||||||| promoter\n",
1549
+ "Background Sequences |||||||||||| Background Sequences\n",
1550
+ "Binding Sites |||||||||||| Binding Sites\n",
1551
+ "Donor Sites |||||||||||| Donor Sites\n",
1552
+ "Binding Sites |||||||||||| Binding Sites\n",
1553
+ "Non-promoter |||||||||||| Non-promoter\n",
1554
+ "promoter |||||||||||| promoter\n",
1555
+ "Background Sequences |||||||||||| Binding Sites\n",
1556
+ "Non-promoter |||||||||||| Non-promoter\n",
1557
+ "Background Sequences |||||||||||| Background Sequences\n",
1558
+ "promoter |||||||||||| promoter\n",
1559
+ "Non-promoter |||||||||||| Non-promoter\n",
1560
+ "presicion 0.739 same 0.253\n"
1561
+ ]
1562
+ }
1563
+ ],
1564
+ "source": [
1565
+ "import json\n",
1566
+ "from tqdm import tqdm\n",
1567
+ "\n",
1568
+ "\n",
1569
+ "\n",
1570
+ "with open(output_file, \"r\") as file:\n",
1571
+ " test_data = json.load(file)\n",
1572
+ "\n",
1573
+ "all_num = len(test_data)\n",
1574
+ "right_sum = 0\n",
1575
+ "same_sum = 0\n",
1576
+ "for item in test_data:\n",
1577
+ " output = item[\"output\"]\n",
1578
+ " #output = \" \".join(tokenizer.tokenize(output))\n",
1579
+ " model_response = item[\"model_response\"]\n",
1580
+ "\n",
1581
+ " print(output,\"||||||||||||\", model_response)\n",
1582
+ "\n",
1583
+ " if model_response == output: #same it\n",
1584
+ " same_sum = same_sum + 1\n",
1585
+ " \n",
1586
+ " if output.find(\"Non\")==-1: # no Non\n",
1587
+ " if model_response.find(output)!=-1 and model_response.find(\"Non\")==-1: #find it, but no Non\n",
1588
+ " right_sum = right_sum + 1\n",
1589
+ " else:\n",
1590
+ " if model_response.find(output)!=-1: #find it\n",
1591
+ " right_sum = right_sum + 1\n",
1592
+ "\n",
1593
+ "\n",
1594
+ "print(\"presicion\", right_sum/all_num, \"same\", same_sum/all_num)\n"
1595
+ ]
1596
+ },
1597
+ {
1598
+ "cell_type": "code",
1599
+ "execution_count": null,
1600
+ "id": "294d46f3-2f5b-4e55-ae41-081d5195f5e2",
1601
+ "metadata": {},
1602
+ "outputs": [],
1603
+ "source": []
1604
+ }
1605
+ ],
1606
+ "metadata": {
1607
+ "kernelspec": {
1608
+ "display_name": "Python 3 (ipykernel)",
1609
+ "language": "python",
1610
+ "name": "python3"
1611
+ },
1612
+ "language_info": {
1613
+ "codemirror_mode": {
1614
+ "name": "ipython",
1615
+ "version": 3
1616
+ },
1617
+ "file_extension": ".py",
1618
+ "mimetype": "text/x-python",
1619
+ "name": "python",
1620
+ "nbconvert_exporter": "python",
1621
+ "pygments_lexer": "ipython3",
1622
+ "version": "3.12.3"
1623
+ }
1624
+ },
1625
+ "nbformat": 4,
1626
+ "nbformat_minor": 5
1627
+ }
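
The evaluation above pools several distinct label sets (promoter vs. Non-promoter, Donor/Acceptor/Non-Splice Sites, Binding Sites vs. Background Sequences) into two global numbers (precision 0.739, exact match 0.253). A per-class breakdown makes the systematic confusions visible, e.g. Acceptor Sites being predicted as Donor Sites. A minimal sketch, assuming the same "output"/"model_response" JSON layout as the cell above; the input path is a placeholder:

import json
from collections import Counter

# Placeholder path: point this at whatever `output_file` held in the notebook.
with open("test_data_with_responses.json", "r") as f:
    test_data = json.load(f)

totals, exact = Counter(), Counter()
for item in test_data:
    label = item["output"]
    totals[label] += 1
    if item["model_response"] == label:
        exact[label] += 1

# Exact-match rate per class, to see which tasks drag the averages down.
for label in sorted(totals):
    print(f"{label}: {exact[label]}/{totals[label]} = {exact[label] / totals[label]:.3f}")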
04-gene-sft/merge_llama_with_dna_lora.py ADDED
@@ -0,0 +1,367 @@
1
+ """
2
+ Usage:
3
+ python merge_llama_with_dna_lora.py \
4
+ --base_model path/to/llama/model \
5
+ --lora_model path/to/first/lora/model[,path/to/second/lora/model] \
6
+ --output_type [pth|huggingface] \
7
+ --output_dir path/to/output/dir
8
+ """
9
+
10
+ import os
11
+
12
+ # Set the environment variable (use the HF mirror endpoint)
13
+ os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
14
+
15
+ # Print the environment variable to confirm it took effect
16
+ print(os.environ.get('HF_ENDPOINT'))
17
+
18
+ import argparse
19
+ import json
20
+ import os
21
+ import gc
22
+ import torch
23
+ import peft
24
+ from peft import PeftModel
25
+ from transformers import LlamaForCausalLM, LlamaTokenizer
26
+ from huggingface_hub import hf_hub_download
27
+
28
+ parser = argparse.ArgumentParser()
29
+ parser.add_argument('--base_model', default=None, required=True,
30
+ type=str, help="Please specify a base_model")
31
+ parser.add_argument('--lora_model', default=None, required=True,
32
+ type=str, help="Please specify LoRA models to be merged (ordered); use commas to separate multiple LoRA models.")
33
+ parser.add_argument('--offload_dir', default=None, type=str,
34
+ help="(Optional) Please specify a temp folder for offloading (useful for low-RAM machines). Default None (disable offload).")
35
+ parser.add_argument('--output_type', default='pth',choices=['pth','huggingface'], type=str,
36
+ help="save the merged model in pth or huggingface format.")
37
+ parser.add_argument('--output_dir', default='./', type=str)
38
+
39
+
40
+ emb_to_model_size = {
41
+ 4096 : '7B',
42
+ 5120 : '13B',
43
+ 6656 : '33B',
44
+ 8192 : '65B',
45
+ }
46
+ num_shards_of_models = {'7B': 1, '13B': 2, '33B': 4, '65B': 8}
47
+ params_of_models = {
48
+ '7B':
49
+ {
50
+ "dim": 4096,
51
+ "multiple_of": 256,
52
+ "n_heads": 32,
53
+ "n_layers": 32,
54
+ "norm_eps": 1e-06,
55
+ "vocab_size": -1,
56
+ },
57
+ '13B':
58
+ {
59
+ "dim": 5120,
60
+ "multiple_of": 256,
61
+ "n_heads": 40,
62
+ "n_layers": 40,
63
+ "norm_eps": 1e-06,
64
+ "vocab_size": -1,
65
+ },
66
+ '33B':
67
+ {
68
+ "dim": 6656,
69
+ "multiple_of": 256,
70
+ "n_heads": 52,
71
+ "n_layers": 60,
72
+ "norm_eps": 1e-06,
73
+ "vocab_size": -1,
74
+ },
75
+ '65B':
76
+ {
77
+ "dim": 8192,
78
+ "multiple_of": 256,
79
+ "n_heads": 64,
80
+ "n_layers": 80,
81
+ "norm_eps": 1e-05,
82
+ "vocab_size": -1,
83
+ },
84
+ }
85
+
86
+ def transpose(weight, fan_in_fan_out):
87
+ return weight.T if fan_in_fan_out else weight
88
+
89
+ # Borrowed and modified from https://github.com/tloen/alpaca-lora
90
+ def translate_state_dict_key(k):
91
+ k = k.replace("base_model.model.", "")
92
+ if k == "model.embed_tokens.weight":
93
+ return "tok_embeddings.weight"
94
+ elif k == "model.norm.weight":
95
+ return "norm.weight"
96
+ elif k == "lm_head.weight":
97
+ return "output.weight"
98
+ elif k.startswith("model.layers."):
99
+ layer = k.split(".")[2]
100
+ if k.endswith(".self_attn.q_proj.weight"):
101
+ return f"layers.{layer}.attention.wq.weight"
102
+ elif k.endswith(".self_attn.k_proj.weight"):
103
+ return f"layers.{layer}.attention.wk.weight"
104
+ elif k.endswith(".self_attn.v_proj.weight"):
105
+ return f"layers.{layer}.attention.wv.weight"
106
+ elif k.endswith(".self_attn.o_proj.weight"):
107
+ return f"layers.{layer}.attention.wo.weight"
108
+ elif k.endswith(".mlp.gate_proj.weight"):
109
+ return f"layers.{layer}.feed_forward.w1.weight"
110
+ elif k.endswith(".mlp.down_proj.weight"):
111
+ return f"layers.{layer}.feed_forward.w2.weight"
112
+ elif k.endswith(".mlp.up_proj.weight"):
113
+ return f"layers.{layer}.feed_forward.w3.weight"
114
+ elif k.endswith(".input_layernorm.weight"):
115
+ return f"layers.{layer}.attention_norm.weight"
116
+ elif k.endswith(".post_attention_layernorm.weight"):
117
+ return f"layers.{layer}.ffn_norm.weight"
118
+ elif k.endswith("rotary_emb.inv_freq") or "lora" in k:
119
+ return None
120
+ else:
121
+ print(layer, k)
122
+ raise NotImplementedError
123
+ else:
124
+ print(k)
125
+ raise NotImplementedError
126
+
127
+
128
+ def unpermute(w):
129
+ return (
130
+ w.view(n_heads, 2, dim // n_heads // 2, dim).transpose(1, 2).reshape(dim, dim)
131
+ )
132
+
133
+
134
+ def save_shards(model_sd, num_shards: int):
135
+ # Run everything under no_grad: this function only slices and copies weights
136
+ with torch.no_grad():
137
+ if num_shards == 1:
138
+ new_state_dict = {}
139
+ for k, v in model_sd.items():
140
+ new_k = translate_state_dict_key(k)
141
+ if new_k is not None:
142
+ if "wq" in new_k or "wk" in new_k:
143
+ new_state_dict[new_k] = unpermute(v)
144
+ else:
145
+ new_state_dict[new_k] = v
146
+
147
+ os.makedirs(output_dir, exist_ok=True)
148
+ print(f"Saving shard 1 of {num_shards} into {output_dir}/consolidated.00.pth")
149
+ torch.save(new_state_dict, output_dir + "/consolidated.00.pth")
150
+ with open(output_dir + "/params.json", "w") as f:
151
+ json.dump(params, f)
152
+ else:
153
+ new_state_dicts = [dict() for _ in range(num_shards)]
154
+ for k in list(model_sd.keys()):
155
+ v = model_sd[k]
156
+ new_k = translate_state_dict_key(k)
157
+ if new_k is not None:
158
+ if new_k=='tok_embeddings.weight':
159
+ print(f"Processing {new_k}")
160
+ assert v.size(1)%num_shards==0
161
+ splits = v.split(v.size(1)//num_shards,dim=1)
162
+ elif new_k=='output.weight':
163
+ print(f"Processing {new_k}")
164
+ if v.size(0)%num_shards==0:
165
+ splits = v.split(v.size(0)//num_shards,dim=0)
166
+ else:
167
+ size_list = [v.size(0)//num_shards] * num_shards
168
+ size_list[-1] += v.size(0)%num_shards
169
+ splits = v.split(size_list, dim=0) # 13B: size_list == [24976,24977]
170
+ elif new_k=='norm.weight':
171
+ print(f"Processing {new_k}")
172
+ splits = [v] * num_shards
173
+ elif 'ffn_norm.weight' in new_k:
174
+ print(f"Processing {new_k}")
175
+ splits = [v] * num_shards
176
+ elif 'attention_norm.weight' in new_k:
177
+ print(f"Processing {new_k}")
178
+ splits = [v] * num_shards
179
+
180
+
181
+ elif 'w1.weight' in new_k:
182
+ print(f"Processing {new_k}")
183
+ splits = v.split(v.size(0)//num_shards,dim=0)
184
+ elif 'w2.weight' in new_k:
185
+ print(f"Processing {new_k}")
186
+ splits = v.split(v.size(1)//num_shards,dim=1)
187
+ elif 'w3.weight' in new_k:
188
+ print(f"Processing {new_k}")
189
+ splits = v.split(v.size(0)//num_shards,dim=0)
190
+
191
+
192
+ elif 'wo.weight' in new_k:
193
+ print(f"Processing {new_k}")
194
+ splits = v.split(v.size(1)//num_shards,dim=1)
195
+
196
+ elif 'wv.weight' in new_k:
197
+ print(f"Processing {new_k}")
198
+ splits = v.split(v.size(0)//num_shards,dim=0)
199
+
200
+ elif "wq.weight" in new_k or "wk.weight" in new_k:
201
+ print(f"Processing {new_k}")
202
+ v = unpermute(v)
203
+ splits = v.split(v.size(0)//num_shards,dim=0)
204
+ else:
205
+ print(f"Unexpected key {new_k}")
206
+ raise ValueError
207
+ for sd,split in zip(new_state_dicts,splits):
208
+ sd[new_k] = split.clone()
209
+ del split
210
+ del splits
211
+ del model_sd[k],v
212
+ gc.collect() # Effectively enforce garbage collection
213
+
214
+ os.makedirs(output_dir, exist_ok=True)
215
+ for i,new_state_dict in enumerate(new_state_dicts):
216
+ print(f"Saving shard {i+1} of {num_shards} into {output_dir}/consolidated.0{i}.pth")
217
+ torch.save(new_state_dict, output_dir + f"/consolidated.0{i}.pth")
218
+ with open(output_dir + "/params.json", "w") as f:
219
+ print(f"Saving params.json into {output_dir}/params.json")
220
+ json.dump(params, f)
221
+
222
+
223
+ if __name__=='__main__':
224
+
225
+ args = parser.parse_args()
226
+ base_model_path = args.base_model
227
+ lora_model_paths = [s.strip() for s in args.lora_model.split(',') if len(s.strip())!=0]
228
+ output_dir = args.output_dir
229
+ output_type = args.output_type
230
+ offload_dir = args.offload_dir
231
+
232
+ print(f"Base model: {base_model_path}")
233
+ print(f"LoRA model(s) {lora_model_paths}:")
234
+
235
+ if offload_dir is not None:
236
+ # Load with offloading, which is useful for low-RAM machines.
237
+ # Note that if you have enough RAM, please use original method instead, as it is faster.
238
+ base_model = LlamaForCausalLM.from_pretrained(
239
+ base_model_path,
240
+ load_in_8bit=False,
241
+ torch_dtype=torch.float16,
242
+ offload_folder=offload_dir,
243
+ offload_state_dict=True,
244
+ low_cpu_mem_usage=True,
245
+ device_map={"": "cpu"},
246
+ )
247
+ else:
248
+ # Original method without offloading
249
+ base_model = LlamaForCausalLM.from_pretrained(
250
+ base_model_path,
251
+ load_in_8bit=False,
252
+ torch_dtype=torch.float16,
253
+ device_map={"": "cpu"},
254
+ cache_dir=None, # do not use a cache directory
255
+ force_download=False, # never download from the remote hub
256
+ local_files_only=True # force loading from local files only
257
+ )
258
+
259
+ ## infer the model size from the checkpoint
260
+ embedding_size = base_model.get_input_embeddings().weight.size(1)
261
+ model_size = emb_to_model_size[embedding_size]
262
+ print(f"Peft version: {peft.__version__}")
263
+ print(f"Loading LoRA for {model_size} model")
264
+
265
+ lora_model = None
266
+ lora_model_sd = None
267
+ for lora_index, lora_model_path in enumerate(lora_model_paths):
268
+ print(f"Loading LoRA {lora_model_path}...")
269
+ tokenizer = LlamaTokenizer.from_pretrained(lora_model_path,
270
+ cache_dir=None, # do not use a cache directory
271
+ force_download=False, # never download from the remote hub
272
+ local_files_only=True # force loading from local files only
273
+ )
274
+
275
+ print(f"base_model vocab size: {base_model.get_input_embeddings().weight.size(0)}")
276
+ print(f"tokenizer vocab size: {len(tokenizer)}")
277
+
278
+ model_vocab_size = base_model.get_input_embeddings().weight.size(0)
279
+ assert len(tokenizer) >= model_vocab_size, \
280
+ (f"The vocab size of the tokenizer {len(tokenizer)} is smaller than the vocab size of the base model {model_vocab_size}\n"
281
+ "This is not the intended use. Please check your model and tokenizer.")
282
+ if model_vocab_size != len(tokenizer):
283
+ base_model.resize_token_embeddings(len(tokenizer))
284
+ print(f"Extended vocabulary size to {len(tokenizer)}")
285
+
286
+ first_weight = base_model.model.layers[0].self_attn.q_proj.weight
287
+ first_weight_old = first_weight.clone()
288
+
289
+ print(f"Loading LoRA weights")
290
+ if hasattr(peft.LoraModel,'merge_and_unload'):
291
+ try:
292
+ lora_model = PeftModel.from_pretrained(
293
+ base_model,
294
+ lora_model_path,
295
+ device_map={"": "cpu"},
296
+ torch_dtype=torch.float16,
297
+ local_files_only=True
298
+ )
299
+ except RuntimeError as e:
300
+ if '[49953, 4096]' in str(e):
301
+ print("The vocab size of the tokenizer does not match the vocab size of the LoRA weight. \n"
302
+ "Did you misuse the LLaMA tokenizer with the Alpaca-LoRA weight?\n"
303
+ "Make sure that you use LLaMA tokenizer with the LLaMA-LoRA weight and Alpaca tokenizer with the Alpaca-LoRA weight!")
304
+ raise e
305
+ assert torch.allclose(first_weight_old, first_weight)
306
+ print(f"Merging with merge_and_unload...")
307
+ base_model = lora_model.merge_and_unload()
308
+ else:
309
+ base_model_sd = base_model.state_dict()
310
+ try:
311
+ lora_model_sd = torch.load(os.path.join(lora_model_path,'adapter_model.bin'),map_location='cpu')
312
+ except FileNotFoundError:
313
+ print("Cannot find lora model on the disk. Downloading lora model from hub...")
314
+ filename = hf_hub_download(repo_id=lora_model_path,filename='adapter_model.bin')
315
+ lora_model_sd = torch.load(filename,map_location='cpu')
316
+ if 'base_model.model.model.embed_tokens.weight' in lora_model_sd:
317
+ assert lora_model_sd['base_model.model.model.embed_tokens.weight'].shape[0]==len(tokenizer), \
318
+ ("The vocab size of the tokenizer does not match the vocab size of the LoRA weight. \n"
319
+ "Did you misuse the LLaMA tokenizer with the Alpaca-LoRA weight?\n"
320
+ "Make sure that you use LLaMA tokenizer with the LLaMA-LoRA weight and Alpaca tokenizer with the Alpaca-LoRA weight!")
321
+
322
+ lora_config = peft.LoraConfig.from_pretrained(lora_model_path)
323
+ lora_scaling = lora_config.lora_alpha / lora_config.r
324
+ fan_in_fan_out = lora_config.fan_in_fan_out
325
+ lora_keys = [k for k in lora_model_sd if 'lora_A' in k]
326
+ non_lora_keys = [k for k in lora_model_sd if not 'lora_' in k]
327
+
328
+ for k in non_lora_keys:
329
+ print(f"merging {k}")
330
+ original_k = k.replace('base_model.model.','')
331
+ base_model_sd[original_k].copy_(lora_model_sd[k])
332
+
333
+ for k in lora_keys:
334
+ print(f"merging {k}")
335
+ original_key = k.replace('.lora_A','').replace('base_model.model.','')
336
+ assert original_key in base_model_sd
337
+ lora_a_key = k
338
+ lora_b_key = k.replace('lora_A','lora_B')
339
+ base_model_sd[original_key] += (
340
+ transpose(lora_model_sd[lora_b_key].float() @ lora_model_sd[lora_a_key].float(),fan_in_fan_out) * lora_scaling
341
+ )
342
+ assert base_model_sd[original_key].dtype == torch.float16
343
+
344
+ # did we do anything?
345
+ assert not torch.allclose(first_weight_old, first_weight)
346
+
347
+ tokenizer.save_pretrained(output_dir)
348
+
349
+ if output_type=='huggingface':
350
+ print("Saving to Hugging Face format...")
351
+ LlamaForCausalLM.save_pretrained(base_model, output_dir) #, state_dict=deloreanized_sd)
352
+ else: # output_type=='pth'
353
+ print("Saving to pth format...")
354
+
355
+ base_model_sd = base_model.state_dict()
356
+ del lora_model, base_model, lora_model_sd
357
+
358
+ params = params_of_models[model_size]
359
+ num_shards = num_shards_of_models[model_size]
360
+ n_layers = params["n_layers"]
361
+ n_heads = params["n_heads"]
362
+ dim = params["dim"]
363
+ dims_per_head = dim // n_heads
364
+ base = 10000.0
365
+ inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
366
+
367
+ save_shards(model_sd=base_model_sd, num_shards=num_shards)
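
The manual merge branch above applies the LoRA update W <- W + transpose(B @ A) * (lora_alpha / r) directly to the base state dict. A tiny self-contained check of that identity (illustrative shapes only, not this script's API):

import torch

d, r, lora_alpha = 16, 4, 8   # illustrative sizes
W = torch.randn(d, d)         # base weight
A = torch.randn(r, d)         # lora_A: (r, in_features)
B = torch.randn(d, r)         # lora_B: (out_features, r)
scaling = lora_alpha / r      # same formula as lora_config.lora_alpha / lora_config.r

W_merged = W + (B @ A) * scaling
x = torch.randn(d)
# The merged weight behaves like the base plus the scaled adapter applied separately.
assert torch.allclose(W_merged @ x, W @ x + scaling * (B @ (A @ x)), atol=1e-4)
print("LoRA merge identity holds")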
04-gene-sft/merge_pt_model.sh ADDED
@@ -0,0 +1,6 @@
+ #!/bin/sh
+ python merge_llama_with_dna_lora.py \
+     --base_model llama-7b-hf \
+     --lora_model dnahlm_llama_7b/pt_lora_model \
+     --output_type huggingface \
+     --output_dir dnahlm-merge-hf
04-gene-sft/merge_sft_model.sh ADDED
@@ -0,0 +1,6 @@
+ #!/bin/sh
+ python merge_llama_with_dna_lora.py \
+     --base_model dnahlm-merge-hf \
+     --lora_model dnahlm-llama7b-sft/sft_lora_model \
+     --output_type huggingface \
+     --output_dir dnahlm-llama-7b-sft-v0
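
Run in order, the two scripts chain the adapters: llama-7b-hf plus the pre-training LoRA yields dnahlm-merge-hf, and that plus the SFT LoRA yields dnahlm-llama-7b-sft-v0. A quick smoke test of the final checkpoint (a sketch; the path is the second script's output, and the DNA prompt is arbitrary):

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("dnahlm-llama-7b-sft-v0")
model = LlamaForCausalLM.from_pretrained(
    "dnahlm-llama-7b-sft-v0", torch_dtype=torch.float16, device_map="auto"
)
inputs = tokenizer("GGTGGCAGAGTGGG", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(out[0], skip_special_tokens=True))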
04-gene-sft/merged_gene_eng_tokenizer_hf/special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "bos_token": {
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "</s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
04-gene-sft/merged_gene_eng_tokenizer_hf/tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8f9bfd4fff4bf6132695295a6443cf0c9fdf923ba58ea628e5efbeb25ce95aed
+ size 1360570
04-gene-sft/merged_gene_eng_tokenizer_hf/tokenizer_config.json ADDED
@@ -0,0 +1,43 @@
+ {
+   "add_bos_token": true,
+   "add_eos_token": false,
+   "add_prefix_space": true,
+   "added_tokens_decoder": {
+     "0": {
+       "content": "<unk>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "<s>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "</s>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "bos_token": "<s>",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "</s>",
+   "extra_special_tokens": {},
+   "legacy": true,
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_token": null,
+   "sp_model_kwargs": {},
+   "spaces_between_special_tokens": false,
+   "tokenizer_class": "LlamaTokenizer",
+   "unk_token": "<unk>",
+   "use_default_system_prompt": false
+ }
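
The config above declares LlamaTokenizer as the tokenizer class, so the merged gene+English vocabulary loads like any other LLaMA tokenizer. A short sketch (directory name as committed here; the sample strings are arbitrary):

from transformers import LlamaTokenizer

tok = LlamaTokenizer.from_pretrained("04-gene-sft/merged_gene_eng_tokenizer_hf")
print(tok.tokenize("Determine the core promoter of the following DNA sequence"))
# If the merged BPE vocabulary took effect, DNA should split into
# multi-base pieces rather than single letters.
print(tok.tokenize("TTGGTGAAGAGGACCCAGCG"))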
04-gene-sft/merged_gene_eng_tokenizer_sp/gene_eng_llama_tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8f9bfd4fff4bf6132695295a6443cf0c9fdf923ba58ea628e5efbeb25ce95aed
+ size 1360570
04-gene-sft/run_clm_pt_with_peft.py CHANGED
@@ -68,7 +68,11 @@ class SavePeftModelCallback(transformers.TrainerCallback):
 
         peft_model_path = os.path.join(checkpoint_folder, "pt_lora_model")
         kwargs["model"].save_pretrained(peft_model_path)
-        kwargs["tokenizer"].save_pretrained(peft_model_path)
+
+        if "tokenizer" in kwargs:
+            kwargs["tokenizer"].save_pretrained(peft_model_path)
+        else:
+            kwargs["processing_class"].save_pretrained(peft_model_path)
 
     def on_save(self, args, state, control, **kwargs):
         self.save_model(args, state, kwargs)
@@ -77,7 +81,11 @@ class SavePeftModelCallback(transformers.TrainerCallback):
     def on_train_end(self, args, state, control, **kwargs):
         peft_model_path = os.path.join(args.output_dir, "pt_lora_model")
        kwargs["model"].save_pretrained(peft_model_path)
-        kwargs["tokenizer"].save_pretrained(peft_model_path)
+
+        if "tokenizer" in kwargs:
+            kwargs["tokenizer"].save_pretrained(peft_model_path)
+        else:
+            kwargs["processing_class"].save_pretrained(peft_model_path)
 
 
 def accuracy(predictions, references, normalize=True, sample_weight=None):
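
This patch, and the matching one in run_clm_sft_with_peft.py below, guard against an upstream rename: newer transformers releases hand TrainerCallback hooks the tokenizer under the "processing_class" key, where older releases used "tokenizer", so an unguarded kwargs["tokenizer"] raises KeyError. The pattern as a standalone sketch (the folder name is illustrative):

import os
import transformers

class SavePeftModelCallback(transformers.TrainerCallback):
    def on_train_end(self, args, state, control, **kwargs):
        peft_model_path = os.path.join(args.output_dir, "lora_model")  # illustrative name
        kwargs["model"].save_pretrained(peft_model_path)
        # Prefer the old "tokenizer" key, fall back to the new "processing_class".
        saver = kwargs.get("tokenizer") or kwargs.get("processing_class")
        if saver is not None:
            saver.save_pretrained(peft_model_path)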
04-gene-sft/run_clm_sft_with_peft.py CHANGED
@@ -69,7 +69,12 @@ class SavePeftModelCallback(transformers.TrainerCallback):
 
         peft_model_path = os.path.join(checkpoint_folder, "sft_lora_model")
         kwargs["model"].save_pretrained(peft_model_path)
-        kwargs["tokenizer"].save_pretrained(peft_model_path)
+
+        if "tokenizer" in kwargs:
+            kwargs["tokenizer"].save_pretrained(peft_model_path)
+        else:
+            kwargs["processing_class"].save_pretrained(peft_model_path)
+
 
     def on_save(self, args, state, control, **kwargs):
         self.save_model(args, state, kwargs)
@@ -78,7 +83,12 @@ class SavePeftModelCallback(transformers.TrainerCallback):
     def on_train_end(self, args, state, control, **kwargs):
         peft_model_path = os.path.join(args.output_dir, "sft_lora_model")
         kwargs["model"].save_pretrained(peft_model_path)
-        kwargs["tokenizer"].save_pretrained(peft_model_path)
+
+        if "tokenizer" in kwargs:
+            kwargs["tokenizer"].save_pretrained(peft_model_path)
+        else:
+            kwargs["processing_class"].save_pretrained(peft_model_path)
+
 
 
 @dataclass
04-gene-sft/run_sft.sh CHANGED
@@ -56,5 +56,4 @@ torchrun --nnodes 1 --nproc_per_node 6 run_clm_sft_with_peft.py \
     --torch_dtype float16 \
     --validation_file ${validation_file} \
     --gradient_checkpointing \
-    --ddp_find_unused_parameters False \
-    --resume_from_checkpoint dnahlm-llama7b-sft/checkpoint-464
+    --ddp_find_unused_parameters False
04-gene-sft/train_data/dna_1g.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:32d950f86ccdb368f4652795117d23898dbccfce5a18a0ee84f78aebc43e8742
+ size 1080669660
04-gene-sft/train_data/english_500m.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:085ebb9d461cae266410953bcd2d07d9a08d50cd93db24d5c3e15d38275cd8cd
+ size 541727453
04-gene-sft/train_data/protein_1g.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a1c361441538520a5501605fa483970b80d72b5dbb28dbe5276890c8632ba1d4
+ size 1059750637