Upload 4 files
Browse files- data_prep.ipynb +1013 -0
- data_prep.pdf +0 -0
- training.ipynb +472 -0
- training.pdf +0 -0
data_prep.ipynb
ADDED
@@ -0,0 +1,1013 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "8198fee9-000e-4ef9-bb13-82c649c2e816",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"## Data prep for retrieving beliefs for dialogs\n",
|
9 |
+
"\n",
|
10 |
+
"**Goal:** Create a dataset to match dialogs with (possibly) relevant facts \n",
|
11 |
+
" \n",
|
12 |
+
"**Method:**\n",
|
13 |
+
"- [x] Use stacked_samsum as training dataset\n",
|
14 |
+
"- [x] Prepare datasets\n",
|
15 |
+
" - [x] remove unnecessary columns\n",
|
16 |
+
" - [x] expand the stacked dataset\n",
|
17 |
+
" - [x] truncate on the right to create dangling examples\n",
|
18 |
+
" - [x] augment dialog using openai to make longer"
|
19 |
+
]
|
20 |
+
},
|
21 |
+
{
|
22 |
+
"cell_type": "markdown",
|
23 |
+
"id": "fe53fc09-0942-4e9a-921c-3804a1ede8ac",
|
24 |
+
"metadata": {},
|
25 |
+
"source": [
|
26 |
+
"### Constants"
|
27 |
+
]
|
28 |
+
},
|
29 |
+
{
|
30 |
+
"cell_type": "code",
|
31 |
+
"execution_count": 2,
|
32 |
+
"id": "94dea7bd-f87b-4559-bd82-dadf3dfd6025",
|
33 |
+
"metadata": {},
|
34 |
+
"outputs": [],
|
35 |
+
"source": [
|
36 |
+
"model_name = \"BAAI/bge-small-en-v1.5\"\n",
|
37 |
+
"max_len = 512\n",
|
38 |
+
"next_concept_sep = \"\\n[NEXT_CONCEPT]\\n\"\n",
|
39 |
+
"training_input_file = \"./data/train-soft.jsonl\"\n",
|
40 |
+
"eval_input_file = \"./data/eval.jsonl\"\n",
|
41 |
+
"training_hn_file = \"./data/train.jsonl\"\n",
|
42 |
+
"eval_size = 12_500\n",
|
43 |
+
"seed = 42\n",
|
44 |
+
"query_prefix = \"Represent this sentence for searching relevant passages: \"\n",
|
45 |
+
"hf_repo_name = \"julep-ai/dfe-stacked_samsum\""
|
46 |
+
]
|
47 |
+
},
|
48 |
+
{
|
49 |
+
"cell_type": "markdown",
|
50 |
+
"id": "6a1ec397-3b13-4e2b-8e0f-9cf127378b8f",
|
51 |
+
"metadata": {},
|
52 |
+
"source": [
|
53 |
+
"### Imports and utils"
|
54 |
+
]
|
55 |
+
},
|
56 |
+
{
|
57 |
+
"cell_type": "code",
|
58 |
+
"execution_count": 3,
|
59 |
+
"id": "7b69b396-1ef2-41f7-aea8-76cf902dec8b",
|
60 |
+
"metadata": {},
|
61 |
+
"outputs": [],
|
62 |
+
"source": [
|
63 |
+
"from functools import partial\n",
|
64 |
+
"import os\n",
|
65 |
+
"import random\n",
|
66 |
+
"import time\n",
|
67 |
+
"\n",
|
68 |
+
"from datasets import load_dataset, load_from_disk\n",
|
69 |
+
"from FlagEmbedding import FlagModel\n",
|
70 |
+
"from FlagEmbedding.baai_general_embedding.finetune.hn_mine import find_knn_neg\n",
|
71 |
+
"from huggingface_hub import HfApi\n",
|
72 |
+
"import jsonlines as jsonl\n",
|
73 |
+
"import langchain\n",
|
74 |
+
"from langchain.cache import SQLiteCache\n",
|
75 |
+
"from langchain.llms import OpenAI\n",
|
76 |
+
"from langchain.prompts import PromptTemplate\n",
|
77 |
+
"from math import ceil\n",
|
78 |
+
"from numpy import cumsum, dot\n",
|
79 |
+
"from numpy.linalg import norm\n",
|
80 |
+
"from tqdm.auto import tqdm\n",
|
81 |
+
"from transformers import AutoTokenizer"
|
82 |
+
]
|
83 |
+
},
|
84 |
+
{
|
85 |
+
"cell_type": "markdown",
|
86 |
+
"id": "8b7b4bfb-5b60-4a76-903d-cb528731745a",
|
87 |
+
"metadata": {},
|
88 |
+
"source": [
|
89 |
+
"#### Tokenizer"
|
90 |
+
]
|
91 |
+
},
|
92 |
+
{
|
93 |
+
"cell_type": "code",
|
94 |
+
"execution_count": 3,
|
95 |
+
"id": "7656e742-9baa-4acc-b536-b2a861fd1d75",
|
96 |
+
"metadata": {},
|
97 |
+
"outputs": [],
|
98 |
+
"source": [
|
99 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_name)"
|
100 |
+
]
|
101 |
+
},
|
102 |
+
{
|
103 |
+
"cell_type": "markdown",
|
104 |
+
"id": "5473558d-45bb-430a-9d0d-9679ea6e2bcd",
|
105 |
+
"metadata": {},
|
106 |
+
"source": [
|
107 |
+
"#### LLM"
|
108 |
+
]
|
109 |
+
},
|
110 |
+
{
|
111 |
+
"cell_type": "code",
|
112 |
+
"execution_count": 5,
|
113 |
+
"id": "7dedef47-411d-4803-a2a5-4789f668e4ad",
|
114 |
+
"metadata": {},
|
115 |
+
"outputs": [],
|
116 |
+
"source": [
|
117 |
+
"langchain.llm_cache = SQLiteCache(database_path=\".langchain.db\")\n",
|
118 |
+
"llm = OpenAI(model_name=\"gpt-3.5-turbo-instruct\", temperature=0.7)"
|
119 |
+
]
|
120 |
+
},
|
121 |
+
{
|
122 |
+
"cell_type": "code",
|
123 |
+
"execution_count": 6,
|
124 |
+
"id": "552f665a-4d32-40d2-8269-ed6031473aec",
|
125 |
+
"metadata": {},
|
126 |
+
"outputs": [],
|
127 |
+
"source": [
|
128 |
+
"prompt_template = PromptTemplate.from_template(\n",
|
129 |
+
"\"\"\"\\\n",
|
130 |
+
"You are a dialog writer. Given a dialog continue it for {n} more turns in the same style as the original speakers. You can be creative in coming up with the next turns as long as you make sure that the new dialog is consistent with the previous messages.\n",
|
131 |
+
"\n",
|
132 |
+
"### Example Dialog\n",
|
133 |
+
"\n",
|
134 |
+
"Ken: Hi, how are you?\n",
|
135 |
+
"Ang: Just peachy! You?\n",
|
136 |
+
"Ken: I'm okay...\n",
|
137 |
+
"Ang: Just okay? What's wrong?\n",
|
138 |
+
"Ken: Just stressed; work stuff, fighting with Brad, too much going on at mom's.\n",
|
139 |
+
"Ang: Hang in there, it will get better!\n",
|
140 |
+
"Ken: I know, but it's a lot.\n",
|
141 |
+
"Ang: Can I do anything to help?\n",
|
142 |
+
"Ken: You are! Listening to me vent! LOL!\n",
|
143 |
+
"Ang: Are you at least doing anything fun this weekend?\n",
|
144 |
+
"Ken: Show Saturday night, then seeing the grandkids on Sunday at the zoo.\n",
|
145 |
+
"\n",
|
146 |
+
"### Continuation\n",
|
147 |
+
"\n",
|
148 |
+
"Ang: Sounds great! That will cheer you up!\n",
|
149 |
+
"Ken: Gotta run, work calls. Love you!\n",
|
150 |
+
"Ang: Love you too! Have a fantastic day!\n",
|
151 |
+
"Ken: You too!\n",
|
152 |
+
"\n",
|
153 |
+
"### Input Dialog\n",
|
154 |
+
"\n",
|
155 |
+
"{input_dialog}\n",
|
156 |
+
"\n",
|
157 |
+
"### Continuation\n",
|
158 |
+
"\"\"\"\n",
|
159 |
+
")\n",
|
160 |
+
"\n",
|
161 |
+
"def gen_continuation(input_dialog, n=4):\n",
|
162 |
+
" wait = round(random.uniform(0.3, 1.2), 3)\n",
|
163 |
+
" time.sleep(wait)\n",
|
164 |
+
"\n",
|
165 |
+
" prompt = prompt_template.format(n=n, input_dialog=input_dialog)\n",
|
166 |
+
" continuation = llm(prompt).strip()\n",
|
167 |
+
" \n",
|
168 |
+
" return continuation"
|
169 |
+
]
|
170 |
+
},
|
171 |
+
{
|
172 |
+
"cell_type": "markdown",
|
173 |
+
"id": "2eb6f55d-ec09-4bc5-8f1a-31e521ad3121",
|
174 |
+
"metadata": {},
|
175 |
+
"source": [
|
176 |
+
"#### Dataset load"
|
177 |
+
]
|
178 |
+
},
|
179 |
+
{
|
180 |
+
"cell_type": "code",
|
181 |
+
"execution_count": 7,
|
182 |
+
"id": "3f5420aa-d327-4d3a-8e02-90473dcca1be",
|
183 |
+
"metadata": {},
|
184 |
+
"outputs": [],
|
185 |
+
"source": [
|
186 |
+
"# Get everything, we'll split it later\n",
|
187 |
+
"dataset = load_dataset(\n",
|
188 |
+
" \"stacked-summaries/stacked-samsum-1024\", \n",
|
189 |
+
")\n",
|
190 |
+
"\n",
|
191 |
+
"\n",
|
192 |
+
"# Remove unnecessary columns\n",
|
193 |
+
"dataset = dataset.remove_columns(['chapter_length', 'summary_length', 'is_stacked',])\n",
|
194 |
+
"\n",
|
195 |
+
"# Remove empty/null dialogs\n",
|
196 |
+
"dataset = dataset.filter(\n",
|
197 |
+
" lambda row: row[\"dialogue\"]\n",
|
198 |
+
")\n",
|
199 |
+
"\n",
|
200 |
+
"# Convert windows-style line endings to unix-style\n",
|
201 |
+
"dataset = dataset.map(\n",
|
202 |
+
" lambda row: dict(dialogue=row[\"dialogue\"].replace(\"\\r\\n\", '\\n'))\n",
|
203 |
+
")"
|
204 |
+
]
|
205 |
+
},
|
206 |
+
{
|
207 |
+
"cell_type": "markdown",
|
208 |
+
"id": "1d728969-c3bc-42e5-8a49-2e8fb16f582c",
|
209 |
+
"metadata": {},
|
210 |
+
"source": [
|
211 |
+
"#### Dataset prep"
|
212 |
+
]
|
213 |
+
},
|
214 |
+
{
|
215 |
+
"cell_type": "code",
|
216 |
+
"execution_count": 8,
|
217 |
+
"id": "c56780b7-1e2f-458d-b370-82b6c95f5173",
|
218 |
+
"metadata": {},
|
219 |
+
"outputs": [],
|
220 |
+
"source": [
|
221 |
+
"def count_tokens(row):\n",
|
222 |
+
" \"\"\"Count tokens using the tokenizer\"\"\"\n",
|
223 |
+
"\n",
|
224 |
+
" dialogue = row[\"dialogue\"]\n",
|
225 |
+
" tokens = tokenizer.encode(dialogue, add_special_tokens=False)\n",
|
226 |
+
"\n",
|
227 |
+
" return dict(token_count=len(tokens))"
|
228 |
+
]
|
229 |
+
},
|
230 |
+
{
|
231 |
+
"cell_type": "code",
|
232 |
+
"execution_count": 9,
|
233 |
+
"id": "416b074f-9660-40c3-9774-7ea17bfae5bb",
|
234 |
+
"metadata": {},
|
235 |
+
"outputs": [],
|
236 |
+
"source": [
|
237 |
+
"# Add token count to every row in dataset\n",
|
238 |
+
"dataset = dataset.map(count_tokens)"
|
239 |
+
]
|
240 |
+
},
|
241 |
+
{
|
242 |
+
"cell_type": "code",
|
243 |
+
"execution_count": 10,
|
244 |
+
"id": "5c3666f1-0457-4304-aeff-10060405f72e",
|
245 |
+
"metadata": {},
|
246 |
+
"outputs": [],
|
247 |
+
"source": [
|
248 |
+
"def offset_left(\n",
|
249 |
+
" dialogue: str,\n",
|
250 |
+
" split_offset=0,\n",
|
251 |
+
" splits=1,\n",
|
252 |
+
" max_len=max_len,\n",
|
253 |
+
"):\n",
|
254 |
+
" # Split dialog lines\n",
|
255 |
+
" lines = dialogue.split(\"\\n\")\n",
|
256 |
+
"\n",
|
257 |
+
" # Count tokens per line\n",
|
258 |
+
" toks_by_line = [\n",
|
259 |
+
" len(tokenizer.encode(line, add_special_tokens=False))\n",
|
260 |
+
" for line in lines\n",
|
261 |
+
" ]\n",
|
262 |
+
"\n",
|
263 |
+
" # Cumulative sum of tokens per line\n",
|
264 |
+
" cum_toks_by_line = cumsum(toks_by_line)\n",
|
265 |
+
"\n",
|
266 |
+
" # Total no. of tokens\n",
|
267 |
+
" total_tokens = sum(toks_by_line)\n",
|
268 |
+
"\n",
|
269 |
+
" # Return as is if total tokens is less than max len of model\n",
|
270 |
+
" if total_tokens <= max_len:\n",
|
271 |
+
" return dialogue\n",
|
272 |
+
"\n",
|
273 |
+
" # Calculate step size\n",
|
274 |
+
" step_size = ceil(total_tokens / (splits * 2))\n",
|
275 |
+
"\n",
|
276 |
+
" # Calculate left index\n",
|
277 |
+
" left_index = 0\n",
|
278 |
+
" for cum_toks in cum_toks_by_line:\n",
|
279 |
+
" if cum_toks > (split_offset * step_size):\n",
|
280 |
+
" break\n",
|
281 |
+
" \n",
|
282 |
+
" left_index += 1\n",
|
283 |
+
"\n",
|
284 |
+
" # Calculate right index\n",
|
285 |
+
" right_index = 0\n",
|
286 |
+
" for last_cum_toks in cum_toks_by_line[::-1]:\n",
|
287 |
+
" if last_cum_toks < max_len:\n",
|
288 |
+
" break\n",
|
289 |
+
" \n",
|
290 |
+
" right_index -= 1\n",
|
291 |
+
"\n",
|
292 |
+
" # Calc final section\n",
|
293 |
+
" if right_index == 0:\n",
|
294 |
+
" lines = lines[left_index:]\n",
|
295 |
+
" else:\n",
|
296 |
+
" lines = lines[left_index:right_index]\n",
|
297 |
+
"\n",
|
298 |
+
" return \"\\n\".join(lines)"
|
299 |
+
]
|
300 |
+
},
|
301 |
+
{
|
302 |
+
"cell_type": "code",
|
303 |
+
"execution_count": 11,
|
304 |
+
"id": "580d654b-ed6a-4cf5-b81a-886905d0bd30",
|
305 |
+
"metadata": {},
|
306 |
+
"outputs": [],
|
307 |
+
"source": [
|
308 |
+
"def truncate_lines(dialog, num=3, min=5):\n",
|
309 |
+
" \"\"\"\n",
|
310 |
+
" Split dialog into lines and then drop the last `num` lines,\n",
|
311 |
+
" making sure there are at least `min` lines remaining.\n",
|
312 |
+
" \"\"\"\n",
|
313 |
+
" \n",
|
314 |
+
" lines = dialog.split(\"\\n\")\n",
|
315 |
+
"\n",
|
316 |
+
" # If too short, return as is\n",
|
317 |
+
" if len(lines) - num < min:\n",
|
318 |
+
" return dialog\n",
|
319 |
+
"\n",
|
320 |
+
" if num > 0:\n",
|
321 |
+
" return \"\\n\".join(lines[:-num])\n",
|
322 |
+
" else:\n",
|
323 |
+
" return \"\\n\".join(lines[-num:])\n"
|
324 |
+
]
|
325 |
+
},
|
326 |
+
{
|
327 |
+
"cell_type": "code",
|
328 |
+
"execution_count": 12,
|
329 |
+
"id": "6f8b5214-1f51-4974-8c20-b3e4a6aa33ab",
|
330 |
+
"metadata": {},
|
331 |
+
"outputs": [],
|
332 |
+
"source": [
|
333 |
+
"def expand_stacked(rows):\n",
|
334 |
+
" \"\"\"Expand stacked samsum dataset by splitting concepts in every summary per dialog\"\"\"\n",
|
335 |
+
" \n",
|
336 |
+
" # Get fields by batch\n",
|
337 |
+
" dialogues = rows[\"dialogue\"]\n",
|
338 |
+
" summaries = rows[\"summary\"]\n",
|
339 |
+
"\n",
|
340 |
+
" # Containers for final results\n",
|
341 |
+
" is_augmented = []\n",
|
342 |
+
" is_truncated = []\n",
|
343 |
+
" final_dialogues = []\n",
|
344 |
+
" final_summaries = []\n",
|
345 |
+
"\n",
|
346 |
+
" # Process every dialog and summary\n",
|
347 |
+
" for dialogue, summary in tqdm(zip(dialogues, summaries)):\n",
|
348 |
+
" # Split the summary by the NEXT_CONCEPT separator from the dataset\n",
|
349 |
+
" ss = summary.split(next_concept_sep)\n",
|
350 |
+
"\n",
|
351 |
+
" # Split different conversations within the sample\n",
|
352 |
+
" # offset on the left to try to match relevance\n",
|
353 |
+
" dd = [\n",
|
354 |
+
" offset_left(d, split_offset=1) for d in dialogue.split(\"\\n\\n\")\n",
|
355 |
+
" ]\n",
|
356 |
+
"\n",
|
357 |
+
" is_truncated += [False] * len(dd)\n",
|
358 |
+
" is_augmented += [False] * len(dd)\n",
|
359 |
+
" final_dialogues += dd\n",
|
360 |
+
" final_summaries += ss\n",
|
361 |
+
"\n",
|
362 |
+
" # ---\n",
|
363 |
+
" # Now truncate and add\n",
|
364 |
+
" truncated = [truncate_lines(d) for d in dd]\n",
|
365 |
+
"\n",
|
366 |
+
" is_augmented += [False] * len(dd)\n",
|
367 |
+
" is_truncated += [t != d for t, d in zip(truncated, dd)]\n",
|
368 |
+
" final_dialogues += truncated\n",
|
369 |
+
" final_summaries += ss\n",
|
370 |
+
"\n",
|
371 |
+
" # ---\n",
|
372 |
+
" # Now augment and add\n",
|
373 |
+
"\n",
|
374 |
+
" augmented = [\n",
|
375 |
+
" truncate_lines(d + gen_continuation(d), num=-4)\n",
|
376 |
+
" for d in dd\n",
|
377 |
+
" ]\n",
|
378 |
+
" \n",
|
379 |
+
" is_truncated += [False] * len(dd)\n",
|
380 |
+
" is_augmented += [True] * len(dd)\n",
|
381 |
+
" final_dialogues += augmented\n",
|
382 |
+
" final_summaries += ss\n",
|
383 |
+
"\n",
|
384 |
+
" return dict(\n",
|
385 |
+
" dialogue=final_dialogues,\n",
|
386 |
+
" summary=final_summaries,\n",
|
387 |
+
" is_truncated=is_truncated,\n",
|
388 |
+
" token_count=[None]*len(final_summaries),\n",
|
389 |
+
" )"
|
390 |
+
]
|
391 |
+
},
|
392 |
+
{
|
393 |
+
"cell_type": "code",
|
394 |
+
"execution_count": 13,
|
395 |
+
"id": "e79f4bb3-614a-4a5a-9135-fda2dce33c55",
|
396 |
+
"metadata": {
|
397 |
+
"scrolled": true
|
398 |
+
},
|
399 |
+
"outputs": [
|
400 |
+
{
|
401 |
+
"name": "stderr",
|
402 |
+
"output_type": "stream",
|
403 |
+
"text": [
|
404 |
+
"Parameter 'function'=<function expand_stacked at 0x7f0a3a68eef0> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.\n"
|
405 |
+
]
|
406 |
+
},
|
407 |
+
{
|
408 |
+
"data": {
|
409 |
+
"application/vnd.jupyter.widget-view+json": {
|
410 |
+
"model_id": "091a1ff1b3c34d1b8cee91d5468e48a8",
|
411 |
+
"version_major": 2,
|
412 |
+
"version_minor": 0
|
413 |
+
},
|
414 |
+
"text/plain": [
|
415 |
+
"Map (num_proc=75): 0%| | 0/29441 [00:00<?, ? examples/s]"
|
416 |
+
]
|
417 |
+
},
|
418 |
+
"metadata": {},
|
419 |
+
"output_type": "display_data"
|
420 |
+
},
|
421 |
+
{
|
422 |
+
"data": {
|
423 |
+
"application/vnd.jupyter.widget-view+json": {
|
424 |
+
"model_id": "a41b9133bf5a4974b5525a1406590bc0",
|
425 |
+
"version_major": 2,
|
426 |
+
"version_minor": 0
|
427 |
+
},
|
428 |
+
"text/plain": [
|
429 |
+
"Map (num_proc=75): 0%| | 0/1633 [00:00<?, ? examples/s]"
|
430 |
+
]
|
431 |
+
},
|
432 |
+
"metadata": {},
|
433 |
+
"output_type": "display_data"
|
434 |
+
},
|
435 |
+
{
|
436 |
+
"data": {
|
437 |
+
"application/vnd.jupyter.widget-view+json": {
|
438 |
+
"model_id": "036f4b46482141bb89cc7924767c8427",
|
439 |
+
"version_major": 2,
|
440 |
+
"version_minor": 0
|
441 |
+
},
|
442 |
+
"text/plain": [
|
443 |
+
"Map (num_proc=75): 0%| | 0/1637 [00:00<?, ? examples/s]"
|
444 |
+
]
|
445 |
+
},
|
446 |
+
"metadata": {},
|
447 |
+
"output_type": "display_data"
|
448 |
+
}
|
449 |
+
],
|
450 |
+
"source": [
|
451 |
+
"# Use batched mode to be able to expand the size of the dataset\n",
|
452 |
+
"dataset = dataset.map(expand_stacked, batch_size=10, batched=True, num_proc=75)\n",
|
453 |
+
"dataset = dataset.remove_columns([\"token_count\"])"
|
454 |
+
]
|
455 |
+
},
|
456 |
+
{
|
457 |
+
"cell_type": "code",
|
458 |
+
"execution_count": 14,
|
459 |
+
"id": "22beb7aa-f191-4660-a860-ef4169c229b1",
|
460 |
+
"metadata": {},
|
461 |
+
"outputs": [
|
462 |
+
{
|
463 |
+
"data": {
|
464 |
+
"application/vnd.jupyter.widget-view+json": {
|
465 |
+
"model_id": "44e4d3202e914fcf9b388e47c70d5e28",
|
466 |
+
"version_major": 2,
|
467 |
+
"version_minor": 0
|
468 |
+
},
|
469 |
+
"text/plain": [
|
470 |
+
"Pushing dataset shards to the dataset hub: 0%| | 0/1 [00:00<?, ?it/s]"
|
471 |
+
]
|
472 |
+
},
|
473 |
+
"metadata": {},
|
474 |
+
"output_type": "display_data"
|
475 |
+
},
|
476 |
+
{
|
477 |
+
"data": {
|
478 |
+
"application/vnd.jupyter.widget-view+json": {
|
479 |
+
"model_id": "d02fecde242d4109a9f88fbcaf55ec6b",
|
480 |
+
"version_major": 2,
|
481 |
+
"version_minor": 0
|
482 |
+
},
|
483 |
+
"text/plain": [
|
484 |
+
"Creating parquet from Arrow format: 0%| | 0/339 [00:00<?, ?ba/s]"
|
485 |
+
]
|
486 |
+
},
|
487 |
+
"metadata": {},
|
488 |
+
"output_type": "display_data"
|
489 |
+
},
|
490 |
+
{
|
491 |
+
"data": {
|
492 |
+
"application/vnd.jupyter.widget-view+json": {
|
493 |
+
"model_id": "95cdcc57a7b94830af4a5661d087df9a",
|
494 |
+
"version_major": 2,
|
495 |
+
"version_minor": 0
|
496 |
+
},
|
497 |
+
"text/plain": [
|
498 |
+
"Deleting unused files from dataset repository: 0%| | 0/1 [00:00<?, ?it/s]"
|
499 |
+
]
|
500 |
+
},
|
501 |
+
"metadata": {},
|
502 |
+
"output_type": "display_data"
|
503 |
+
},
|
504 |
+
{
|
505 |
+
"data": {
|
506 |
+
"application/vnd.jupyter.widget-view+json": {
|
507 |
+
"model_id": "a1289ee9a1ac4a8fa39ac9e26cc2b360",
|
508 |
+
"version_major": 2,
|
509 |
+
"version_minor": 0
|
510 |
+
},
|
511 |
+
"text/plain": [
|
512 |
+
"Pushing dataset shards to the dataset hub: 0%| | 0/1 [00:00<?, ?it/s]"
|
513 |
+
]
|
514 |
+
},
|
515 |
+
"metadata": {},
|
516 |
+
"output_type": "display_data"
|
517 |
+
},
|
518 |
+
{
|
519 |
+
"data": {
|
520 |
+
"application/vnd.jupyter.widget-view+json": {
|
521 |
+
"model_id": "4a2c5e8ccbc04b56a41071401d493b66",
|
522 |
+
"version_major": 2,
|
523 |
+
"version_minor": 0
|
524 |
+
},
|
525 |
+
"text/plain": [
|
526 |
+
"Creating parquet from Arrow format: 0%| | 0/20 [00:00<?, ?ba/s]"
|
527 |
+
]
|
528 |
+
},
|
529 |
+
"metadata": {},
|
530 |
+
"output_type": "display_data"
|
531 |
+
},
|
532 |
+
{
|
533 |
+
"data": {
|
534 |
+
"application/vnd.jupyter.widget-view+json": {
|
535 |
+
"model_id": "d887773bdb354eec83ef2d2a7f135c97",
|
536 |
+
"version_major": 2,
|
537 |
+
"version_minor": 0
|
538 |
+
},
|
539 |
+
"text/plain": [
|
540 |
+
"Deleting unused files from dataset repository: 0%| | 0/1 [00:00<?, ?it/s]"
|
541 |
+
]
|
542 |
+
},
|
543 |
+
"metadata": {},
|
544 |
+
"output_type": "display_data"
|
545 |
+
},
|
546 |
+
{
|
547 |
+
"data": {
|
548 |
+
"application/vnd.jupyter.widget-view+json": {
|
549 |
+
"model_id": "2405e0ad01fb4117b4a90a21b764a91e",
|
550 |
+
"version_major": 2,
|
551 |
+
"version_minor": 0
|
552 |
+
},
|
553 |
+
"text/plain": [
|
554 |
+
"Pushing dataset shards to the dataset hub: 0%| | 0/1 [00:00<?, ?it/s]"
|
555 |
+
]
|
556 |
+
},
|
557 |
+
"metadata": {},
|
558 |
+
"output_type": "display_data"
|
559 |
+
},
|
560 |
+
{
|
561 |
+
"data": {
|
562 |
+
"application/vnd.jupyter.widget-view+json": {
|
563 |
+
"model_id": "d0583b97bdb4403185692e70f2e3eb8e",
|
564 |
+
"version_major": 2,
|
565 |
+
"version_minor": 0
|
566 |
+
},
|
567 |
+
"text/plain": [
|
568 |
+
"Creating parquet from Arrow format: 0%| | 0/19 [00:00<?, ?ba/s]"
|
569 |
+
]
|
570 |
+
},
|
571 |
+
"metadata": {},
|
572 |
+
"output_type": "display_data"
|
573 |
+
},
|
574 |
+
{
|
575 |
+
"data": {
|
576 |
+
"application/vnd.jupyter.widget-view+json": {
|
577 |
+
"model_id": "c33b9b1c9cde4256a5b3840830a29628",
|
578 |
+
"version_major": 2,
|
579 |
+
"version_minor": 0
|
580 |
+
},
|
581 |
+
"text/plain": [
|
582 |
+
"Deleting unused files from dataset repository: 0%| | 0/1 [00:00<?, ?it/s]"
|
583 |
+
]
|
584 |
+
},
|
585 |
+
"metadata": {},
|
586 |
+
"output_type": "display_data"
|
587 |
+
},
|
588 |
+
{
|
589 |
+
"data": {
|
590 |
+
"application/vnd.jupyter.widget-view+json": {
|
591 |
+
"model_id": "02687da16cf0401ea4f19a89e2e7ac9c",
|
592 |
+
"version_major": 2,
|
593 |
+
"version_minor": 0
|
594 |
+
},
|
595 |
+
"text/plain": [
|
596 |
+
"Downloading metadata: 0%| | 0.00/752 [00:00<?, ?B/s]"
|
597 |
+
]
|
598 |
+
},
|
599 |
+
"metadata": {},
|
600 |
+
"output_type": "display_data"
|
601 |
+
}
|
602 |
+
],
|
603 |
+
"source": [
|
604 |
+
"dataset.push_to_hub(hf_repo_name)"
|
605 |
+
]
|
606 |
+
},
|
607 |
+
{
|
608 |
+
"cell_type": "markdown",
|
609 |
+
"id": "767a4251-fab6-47ce-8cdc-e2416d70b440",
|
610 |
+
"metadata": {},
|
611 |
+
"source": [
|
612 |
+
"### Prepare dataset for finetuning\n",
|
613 |
+
"[Docs](https://github.com/FlagOpen/FlagEmbedding/tree/master/examples/finetune)\n",
|
614 |
+
"\n",
|
615 |
+
"Format:\n",
|
616 |
+
"```json\n",
|
617 |
+
"{\"query\": str, \"pos\": List[str], \"neg\":List[str]}\n",
|
618 |
+
"```\n",
|
619 |
+
"\n",
|
620 |
+
"Keys:\n",
|
621 |
+
"- query: belief\n",
|
622 |
+
"- pos: list of matching conversations\n",
|
623 |
+
"- neg: list of random conversations from dataset"
|
624 |
+
]
|
625 |
+
},
|
626 |
+
{
|
627 |
+
"cell_type": "code",
|
628 |
+
"execution_count": 4,
|
629 |
+
"id": "ea1f2c3c-211d-4740-be1b-5eac3f57416c",
|
630 |
+
"metadata": {},
|
631 |
+
"outputs": [
|
632 |
+
{
|
633 |
+
"data": {
|
634 |
+
"application/vnd.jupyter.widget-view+json": {
|
635 |
+
"model_id": "bc180cbde424436193fbaef12800d924",
|
636 |
+
"version_major": 2,
|
637 |
+
"version_minor": 0
|
638 |
+
},
|
639 |
+
"text/plain": [
|
640 |
+
"Downloading readme: 0%| | 0.00/752 [00:00<?, ?B/s]"
|
641 |
+
]
|
642 |
+
},
|
643 |
+
"metadata": {},
|
644 |
+
"output_type": "display_data"
|
645 |
+
},
|
646 |
+
{
|
647 |
+
"data": {
|
648 |
+
"application/vnd.jupyter.widget-view+json": {
|
649 |
+
"model_id": "4e8a26389c2b4ed6ba6be892cf0c594d",
|
650 |
+
"version_major": 2,
|
651 |
+
"version_minor": 0
|
652 |
+
},
|
653 |
+
"text/plain": [
|
654 |
+
"Downloading data files: 0%| | 0/3 [00:00<?, ?it/s]"
|
655 |
+
]
|
656 |
+
},
|
657 |
+
"metadata": {},
|
658 |
+
"output_type": "display_data"
|
659 |
+
},
|
660 |
+
{
|
661 |
+
"data": {
|
662 |
+
"application/vnd.jupyter.widget-view+json": {
|
663 |
+
"model_id": "dfbe638ec7424b6d99552ccf00d3703b",
|
664 |
+
"version_major": 2,
|
665 |
+
"version_minor": 0
|
666 |
+
},
|
667 |
+
"text/plain": [
|
668 |
+
"Downloading data: 0%| | 0.00/81.5M [00:00<?, ?B/s]"
|
669 |
+
]
|
670 |
+
},
|
671 |
+
"metadata": {},
|
672 |
+
"output_type": "display_data"
|
673 |
+
},
|
674 |
+
{
|
675 |
+
"data": {
|
676 |
+
"application/vnd.jupyter.widget-view+json": {
|
677 |
+
"model_id": "b88de79e28d643f38ab0b3def530c008",
|
678 |
+
"version_major": 2,
|
679 |
+
"version_minor": 0
|
680 |
+
},
|
681 |
+
"text/plain": [
|
682 |
+
"Downloading data: 0%| | 0.00/3.91M [00:00<?, ?B/s]"
|
683 |
+
]
|
684 |
+
},
|
685 |
+
"metadata": {},
|
686 |
+
"output_type": "display_data"
|
687 |
+
},
|
688 |
+
{
|
689 |
+
"data": {
|
690 |
+
"application/vnd.jupyter.widget-view+json": {
|
691 |
+
"model_id": "e9daf99846404246ab60f02caabf66ef",
|
692 |
+
"version_major": 2,
|
693 |
+
"version_minor": 0
|
694 |
+
},
|
695 |
+
"text/plain": [
|
696 |
+
"Downloading data: 0%| | 0.00/3.84M [00:00<?, ?B/s]"
|
697 |
+
]
|
698 |
+
},
|
699 |
+
"metadata": {},
|
700 |
+
"output_type": "display_data"
|
701 |
+
},
|
702 |
+
{
|
703 |
+
"data": {
|
704 |
+
"application/vnd.jupyter.widget-view+json": {
|
705 |
+
"model_id": "6c441707e6b04ff8ba0a3d68790eced7",
|
706 |
+
"version_major": 2,
|
707 |
+
"version_minor": 0
|
708 |
+
},
|
709 |
+
"text/plain": [
|
710 |
+
"Extracting data files: 0%| | 0/3 [00:00<?, ?it/s]"
|
711 |
+
]
|
712 |
+
},
|
713 |
+
"metadata": {},
|
714 |
+
"output_type": "display_data"
|
715 |
+
},
|
716 |
+
{
|
717 |
+
"data": {
|
718 |
+
"application/vnd.jupyter.widget-view+json": {
|
719 |
+
"model_id": "29fe824823394171b01394b813cabb1e",
|
720 |
+
"version_major": 2,
|
721 |
+
"version_minor": 0
|
722 |
+
},
|
723 |
+
"text/plain": [
|
724 |
+
"Generating train split: 0%| | 0/338127 [00:00<?, ? examples/s]"
|
725 |
+
]
|
726 |
+
},
|
727 |
+
"metadata": {},
|
728 |
+
"output_type": "display_data"
|
729 |
+
},
|
730 |
+
{
|
731 |
+
"data": {
|
732 |
+
"application/vnd.jupyter.widget-view+json": {
|
733 |
+
"model_id": "04895f3438bd40e6b3e34c2e7934f920",
|
734 |
+
"version_major": 2,
|
735 |
+
"version_minor": 0
|
736 |
+
},
|
737 |
+
"text/plain": [
|
738 |
+
"Generating validation split: 0%| | 0/19131 [00:00<?, ? examples/s]"
|
739 |
+
]
|
740 |
+
},
|
741 |
+
"metadata": {},
|
742 |
+
"output_type": "display_data"
|
743 |
+
},
|
744 |
+
{
|
745 |
+
"data": {
|
746 |
+
"application/vnd.jupyter.widget-view+json": {
|
747 |
+
"model_id": "4c21386349cc4e15b259f3b462fe8d9a",
|
748 |
+
"version_major": 2,
|
749 |
+
"version_minor": 0
|
750 |
+
},
|
751 |
+
"text/plain": [
|
752 |
+
"Generating test split: 0%| | 0/18381 [00:00<?, ? examples/s]"
|
753 |
+
]
|
754 |
+
},
|
755 |
+
"metadata": {},
|
756 |
+
"output_type": "display_data"
|
757 |
+
}
|
758 |
+
],
|
759 |
+
"source": [
|
760 |
+
"dataset = load_dataset(hf_repo_name)"
|
761 |
+
]
|
762 |
+
},
|
763 |
+
{
|
764 |
+
"cell_type": "code",
|
765 |
+
"execution_count": 5,
|
766 |
+
"id": "10817e24-a6b5-49da-b1e7-6101b32a9135",
|
767 |
+
"metadata": {},
|
768 |
+
"outputs": [],
|
769 |
+
"source": [
|
770 |
+
"def pick_random(dataset, split=\"train\", far_from=0):\n",
|
771 |
+
" ds = dataset[split]\n",
|
772 |
+
" ds_len = len(ds)\n",
|
773 |
+
" mid = ds_len // 2\n",
|
774 |
+
" which_half = far_from // mid\n",
|
775 |
+
" \n",
|
776 |
+
" start = (1 - which_half) * mid\n",
|
777 |
+
" end = ds_len - which_half * mid\n",
|
778 |
+
" idx = random.randrange(start, end)\n",
|
779 |
+
" \n",
|
780 |
+
" return ds[idx]"
|
781 |
+
]
|
782 |
+
},
|
783 |
+
{
|
784 |
+
"cell_type": "code",
|
785 |
+
"execution_count": 6,
|
786 |
+
"id": "9bf3bf97-86c4-41f4-ab07-7de94ed72344",
|
787 |
+
"metadata": {},
|
788 |
+
"outputs": [
|
789 |
+
{
|
790 |
+
"data": {
|
791 |
+
"application/vnd.jupyter.widget-view+json": {
|
792 |
+
"model_id": "6c031292641d4e668428e227d0cb22e5",
|
793 |
+
"version_major": 2,
|
794 |
+
"version_minor": 0
|
795 |
+
},
|
796 |
+
"text/plain": [
|
797 |
+
" 0%| | 0/338127 [00:00<?, ?it/s]"
|
798 |
+
]
|
799 |
+
},
|
800 |
+
"metadata": {},
|
801 |
+
"output_type": "display_data"
|
802 |
+
}
|
803 |
+
],
|
804 |
+
"source": [
|
805 |
+
"with jsonl.open(training_input_file, mode='w') as writer:\n",
|
806 |
+
" for i, row in enumerate(tqdm(dataset[\"train\"], total=len(dataset[\"train\"]))):\n",
|
807 |
+
" query = row[\"summary\"]\n",
|
808 |
+
" pos = [row[\"dialogue\"]]\n",
|
809 |
+
" \n",
|
810 |
+
" neg = [\n",
|
811 |
+
" pick_random(dataset, split=\"train\", far_from=i)[\"dialogue\"]\n",
|
812 |
+
" for _ in range(3)\n",
|
813 |
+
" ]\n",
|
814 |
+
" \n",
|
815 |
+
" writer.write(dict(query=query, pos=pos, neg=neg))"
|
816 |
+
]
|
817 |
+
},
|
818 |
+
{
|
819 |
+
"cell_type": "code",
|
820 |
+
"execution_count": 7,
|
821 |
+
"id": "e07bc44f-302c-4c7c-b7c6-62c9cd9db3e4",
|
822 |
+
"metadata": {},
|
823 |
+
"outputs": [
|
824 |
+
{
|
825 |
+
"data": {
|
826 |
+
"application/vnd.jupyter.widget-view+json": {
|
827 |
+
"model_id": "5394a57e7aca4e9d9c36b2f7f3b9b0f3",
|
828 |
+
"version_major": 2,
|
829 |
+
"version_minor": 0
|
830 |
+
},
|
831 |
+
"text/plain": [
|
832 |
+
" 0%| | 0/12500 [00:00<?, ?it/s]"
|
833 |
+
]
|
834 |
+
},
|
835 |
+
"metadata": {},
|
836 |
+
"output_type": "display_data"
|
837 |
+
}
|
838 |
+
],
|
839 |
+
"source": [
|
840 |
+
"with jsonl.open(eval_input_file, mode='w') as writer:\n",
|
841 |
+
" for i, row in enumerate(tqdm(dataset[\"validation\"], total=eval_size)):\n",
|
842 |
+
" if i > eval_size:\n",
|
843 |
+
" break\n",
|
844 |
+
"\n",
|
845 |
+
" query = row[\"summary\"]\n",
|
846 |
+
" pos = [row[\"dialogue\"]]\n",
|
847 |
+
" \n",
|
848 |
+
" neg = [\n",
|
849 |
+
" pick_random(dataset, split=\"validation\", far_from=i)[\"dialogue\"]\n",
|
850 |
+
" for _ in range(3)\n",
|
851 |
+
" ]\n",
|
852 |
+
" \n",
|
853 |
+
" writer.write(dict(query=query, pos=pos, neg=neg))"
|
854 |
+
]
|
855 |
+
},
|
856 |
+
{
|
857 |
+
"cell_type": "markdown",
|
858 |
+
"id": "b6c895f9-9ef4-4edc-b65d-722188eaa8bd",
|
859 |
+
"metadata": {},
|
860 |
+
"source": [
|
861 |
+
"### Mine hard negatives"
|
862 |
+
]
|
863 |
+
},
|
864 |
+
{
|
865 |
+
"cell_type": "code",
|
866 |
+
"execution_count": 9,
|
867 |
+
"id": "b73cf693-4138-429f-8188-0a72b36ed44b",
|
868 |
+
"metadata": {},
|
869 |
+
"outputs": [],
|
870 |
+
"source": [
|
871 |
+
"model = FlagModel(\n",
|
872 |
+
" model_name,\n",
|
873 |
+
" query_instruction_for_retrieval=query_prefix,\n",
|
874 |
+
")"
|
875 |
+
]
|
876 |
+
},
|
877 |
+
{
|
878 |
+
"cell_type": "code",
|
879 |
+
"execution_count": 10,
|
880 |
+
"id": "adc677e6-c28f-49f9-a812-5cd4e93084b3",
|
881 |
+
"metadata": {},
|
882 |
+
"outputs": [
|
883 |
+
{
|
884 |
+
"name": "stdout",
|
885 |
+
"output_type": "stream",
|
886 |
+
"text": [
|
887 |
+
"inferencing embedding for corpus (number=37361)--------------\n"
|
888 |
+
]
|
889 |
+
},
|
890 |
+
{
|
891 |
+
"name": "stderr",
|
892 |
+
"output_type": "stream",
|
893 |
+
"text": [
|
894 |
+
"Inference Embeddings: 100%|██████████| 146/146 [00:37<00:00, 3.87it/s]\n"
|
895 |
+
]
|
896 |
+
},
|
897 |
+
{
|
898 |
+
"name": "stdout",
|
899 |
+
"output_type": "stream",
|
900 |
+
"text": [
|
901 |
+
"inferencing embedding for queries (number=338127)--------------\n"
|
902 |
+
]
|
903 |
+
},
|
904 |
+
{
|
905 |
+
"name": "stderr",
|
906 |
+
"output_type": "stream",
|
907 |
+
"text": [
|
908 |
+
"Inference Embeddings: 100%|██████████| 1321/1321 [00:52<00:00, 25.34it/s]\n"
|
909 |
+
]
|
910 |
+
},
|
911 |
+
{
|
912 |
+
"name": "stdout",
|
913 |
+
"output_type": "stream",
|
914 |
+
"text": [
|
915 |
+
"create index and search------------------\n"
|
916 |
+
]
|
917 |
+
},
|
918 |
+
{
|
919 |
+
"name": "stderr",
|
920 |
+
"output_type": "stream",
|
921 |
+
"text": [
|
922 |
+
"Batches: 100%|██████████| 5284/5284 [00:07<00:00, 740.63it/s]\n"
|
923 |
+
]
|
924 |
+
}
|
925 |
+
],
|
926 |
+
"source": [
|
927 |
+
"find_knn_neg(\n",
|
928 |
+
" model,\n",
|
929 |
+
" input_file=training_input_file,\n",
|
930 |
+
" candidate_pool=None,\n",
|
931 |
+
" output_file=training_hn_file,\n",
|
932 |
+
" sample_range=list(range(2, 200)),\n",
|
933 |
+
" negative_number=10,\n",
|
934 |
+
" use_gpu=True,\n",
|
935 |
+
")"
|
936 |
+
]
|
937 |
+
},
|
938 |
+
{
|
939 |
+
"cell_type": "markdown",
|
940 |
+
"id": "d408f52e-d8b8-4e6a-86bc-234d2b862a86",
|
941 |
+
"metadata": {},
|
942 |
+
"source": [
|
943 |
+
"### Add processed files to hf dataset"
|
944 |
+
]
|
945 |
+
},
|
946 |
+
{
|
947 |
+
"cell_type": "code",
|
948 |
+
"execution_count": 11,
|
949 |
+
"id": "fd79a43e-7add-4037-9b5f-5bf60db89158",
|
950 |
+
"metadata": {},
|
951 |
+
"outputs": [
|
952 |
+
{
|
953 |
+
"data": {
|
954 |
+
"application/vnd.jupyter.widget-view+json": {
|
955 |
+
"model_id": "a8be41b80f2b42c8800eb12d0ec57bf9",
|
956 |
+
"version_major": 2,
|
957 |
+
"version_minor": 0
|
958 |
+
},
|
959 |
+
"text/plain": [
|
960 |
+
"train.jsonl: 0%| | 0.00/2.42G [00:00<?, ?B/s]"
|
961 |
+
]
|
962 |
+
},
|
963 |
+
"metadata": {},
|
964 |
+
"output_type": "display_data"
|
965 |
+
}
|
966 |
+
],
|
967 |
+
"source": [
|
968 |
+
"hf_api = HfApi()\n",
|
969 |
+
"\n",
|
970 |
+
"for path in [\n",
|
971 |
+
" training_input_file,\n",
|
972 |
+
" eval_input_file,\n",
|
973 |
+
" training_hn_file,\n",
|
974 |
+
"]:\n",
|
975 |
+
" hf_api.upload_file(\n",
|
976 |
+
" path_or_fileobj=path,\n",
|
977 |
+
" path_in_repo=path.split('/')[-1],\n",
|
978 |
+
" repo_id=hf_repo_name,\n",
|
979 |
+
" repo_type=\"dataset\",\n",
|
980 |
+
" )\n"
|
981 |
+
]
|
982 |
+
},
|
983 |
+
{
|
984 |
+
"cell_type": "code",
|
985 |
+
"execution_count": null,
|
986 |
+
"id": "78410dd0-80c9-4f27-9a95-4b8a34604e1e",
|
987 |
+
"metadata": {},
|
988 |
+
"outputs": [],
|
989 |
+
"source": []
|
990 |
+
}
|
991 |
+
],
|
992 |
+
"metadata": {
|
993 |
+
"kernelspec": {
|
994 |
+
"display_name": "Python 3 (ipykernel)",
|
995 |
+
"language": "python",
|
996 |
+
"name": "python3"
|
997 |
+
},
|
998 |
+
"language_info": {
|
999 |
+
"codemirror_mode": {
|
1000 |
+
"name": "ipython",
|
1001 |
+
"version": 3
|
1002 |
+
},
|
1003 |
+
"file_extension": ".py",
|
1004 |
+
"mimetype": "text/x-python",
|
1005 |
+
"name": "python",
|
1006 |
+
"nbconvert_exporter": "python",
|
1007 |
+
"pygments_lexer": "ipython3",
|
1008 |
+
"version": "3.10.6"
|
1009 |
+
}
|
1010 |
+
},
|
1011 |
+
"nbformat": 4,
|
1012 |
+
"nbformat_minor": 5
|
1013 |
+
}
|
data_prep.pdf
ADDED
Binary file (63.1 kB). View file
|
|
training.ipynb
ADDED
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "ec403ba5-1356-46b7-a14f-86bf7db0c5b4",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"## Train Dialog-Fact Encoder\n",
|
9 |
+
"\n",
|
10 |
+
"**Goal:** Train an embedding model to match dialogs with (possibly) relevant facts "
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "markdown",
|
15 |
+
"id": "723a9f8a-800a-4de0-ab89-e4d984271a5b",
|
16 |
+
"metadata": {},
|
17 |
+
"source": [
|
18 |
+
"### Constants"
|
19 |
+
]
|
20 |
+
},
|
21 |
+
{
|
22 |
+
"cell_type": "code",
|
23 |
+
"execution_count": 1,
|
24 |
+
"id": "7167d6e4-7a7f-4f7f-b4e7-92b9613afed8",
|
25 |
+
"metadata": {},
|
26 |
+
"outputs": [],
|
27 |
+
"source": [
|
28 |
+
"model_name = \"BAAI/bge-base-en-v1.5\"\n",
|
29 |
+
"query_prefix = \"Represent this sentence for searching relevant passages: \"\n",
|
30 |
+
"max_len = 512\n",
|
31 |
+
"training_hn_file = \"./data/train.jsonl\"\n",
|
32 |
+
"eval_file = \"./data/eval.jsonl\"\n",
|
33 |
+
"batch_size = 1350\n",
|
34 |
+
"output_model_path = \"./dfe-base-en\"\n",
|
35 |
+
"hf_repo_name = \"julep-ai/dfe-base-en\""
|
36 |
+
]
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"cell_type": "markdown",
|
40 |
+
"id": "22aad488-38c3-40b9-8e5b-6d47b41d49cf",
|
41 |
+
"metadata": {},
|
42 |
+
"source": [
|
43 |
+
"### Imports"
|
44 |
+
]
|
45 |
+
},
|
46 |
+
{
|
47 |
+
"cell_type": "code",
|
48 |
+
"execution_count": null,
|
49 |
+
"id": "98d5e97e-df3b-43e4-b82c-2f4768a217b6",
|
50 |
+
"metadata": {},
|
51 |
+
"outputs": [],
|
52 |
+
"source": [
|
53 |
+
"import itertools as it\n",
|
54 |
+
"\n",
|
55 |
+
"import graphviz\n",
|
56 |
+
"import jsonlines as jsonl\n",
|
57 |
+
"from lion_pytorch import Lion\n",
|
58 |
+
"from sentence_transformers import InputExample, SentenceTransformer, losses as ls, models as ml, util\n",
|
59 |
+
"from sentence_transformers.evaluation import SimilarityFunction, TripletEvaluator\n",
|
60 |
+
"import torch\n",
|
61 |
+
"from torch.utils.data import DataLoader, IterableDataset\n",
|
62 |
+
"from tqdm.auto import tqdm"
|
63 |
+
]
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"cell_type": "markdown",
|
67 |
+
"id": "72ee0c6c-2785-49ff-85ec-600b76af11b8",
|
68 |
+
"metadata": {},
|
69 |
+
"source": [
|
70 |
+
"### Dataset"
|
71 |
+
]
|
72 |
+
},
|
73 |
+
{
|
74 |
+
"cell_type": "code",
|
75 |
+
"execution_count": 3,
|
76 |
+
"id": "b17def02-f756-4973-a29f-dd628da34e58",
|
77 |
+
"metadata": {},
|
78 |
+
"outputs": [],
|
79 |
+
"source": [
|
80 |
+
"def hn_output(file):\n",
|
81 |
+
" with jsonl.open(file) as reader:\n",
|
82 |
+
" for entry in reader:\n",
|
83 |
+
" query = entry[\"query\"]\n",
|
84 |
+
" pos = [dict(dialog=dialog) for dialog in entry[\"pos\"]]\n",
|
85 |
+
" neg = [dict(dialog=dialog) for dialog in entry[\"neg\"]]\n",
|
86 |
+
"\n",
|
87 |
+
" for combined in it.product(\n",
|
88 |
+
" [dict(fact=query)],\n",
|
89 |
+
" pos,\n",
|
90 |
+
" neg,\n",
|
91 |
+
" ):\n",
|
92 |
+
" yield InputExample(texts=list(combined))"
|
93 |
+
]
|
94 |
+
},
|
95 |
+
{
|
96 |
+
"cell_type": "code",
|
97 |
+
"execution_count": 4,
|
98 |
+
"id": "34649f83-5bc3-4b1b-a1b2-3d406b84979d",
|
99 |
+
"metadata": {},
|
100 |
+
"outputs": [
|
101 |
+
{
|
102 |
+
"data": {
|
103 |
+
"application/vnd.jupyter.widget-view+json": {
|
104 |
+
"model_id": "01107f542dec483a9a48ed4b9e4b9a76",
|
105 |
+
"version_major": 2,
|
106 |
+
"version_minor": 0
|
107 |
+
},
|
108 |
+
"text/plain": [
|
109 |
+
"0it [00:00, ?it/s]"
|
110 |
+
]
|
111 |
+
},
|
112 |
+
"metadata": {},
|
113 |
+
"output_type": "display_data"
|
114 |
+
},
|
115 |
+
{
|
116 |
+
"data": {
|
117 |
+
"application/vnd.jupyter.widget-view+json": {
|
118 |
+
"model_id": "039f46c46d724fa0aac242492248dbff",
|
119 |
+
"version_major": 2,
|
120 |
+
"version_minor": 0
|
121 |
+
},
|
122 |
+
"text/plain": [
|
123 |
+
"0it [00:00, ?it/s]"
|
124 |
+
]
|
125 |
+
},
|
126 |
+
"metadata": {},
|
127 |
+
"output_type": "display_data"
|
128 |
+
}
|
129 |
+
],
|
130 |
+
"source": [
|
131 |
+
"training_data = list(tqdm(hn_output(training_hn_file)))\n",
|
132 |
+
"eval_data = list(tqdm(hn_output(eval_file)))"
|
133 |
+
]
|
134 |
+
},
|
135 |
+
{
|
136 |
+
"cell_type": "code",
|
137 |
+
"execution_count": 5,
|
138 |
+
"id": "8e817f20-4e80-4842-bf45-f7439a5e2b7a",
|
139 |
+
"metadata": {},
|
140 |
+
"outputs": [],
|
141 |
+
"source": [
|
142 |
+
"dataloader = DataLoader(training_data, shuffle=True, batch_size=batch_size)\n",
|
143 |
+
"eval_dataloader = DataLoader(eval_data, shuffle=True, batch_size=batch_size // 10)"
|
144 |
+
]
|
145 |
+
},
|
146 |
+
{
|
147 |
+
"cell_type": "markdown",
|
148 |
+
"id": "be0a103c-1c3d-41fa-933c-f0b843087658",
|
149 |
+
"metadata": {},
|
150 |
+
"source": [
|
151 |
+
"### DFE Model Architecture"
|
152 |
+
]
|
153 |
+
},
|
154 |
+
{
|
155 |
+
"cell_type": "code",
|
156 |
+
"execution_count": 6,
|
157 |
+
"id": "c8eea066-1f4e-4184-9215-0b5fdd1cdf16",
|
158 |
+
"metadata": {},
|
159 |
+
"outputs": [],
|
160 |
+
"source": [
|
161 |
+
"# Base model\n",
|
162 |
+
"base_model = SentenceTransformer(model_name)"
|
163 |
+
]
|
164 |
+
},
|
165 |
+
{
|
166 |
+
"cell_type": "code",
|
167 |
+
"execution_count": 7,
|
168 |
+
"id": "7f31eda8-d224-4d30-8a6b-ed4cb32a2c12",
|
169 |
+
"metadata": {},
|
170 |
+
"outputs": [],
|
171 |
+
"source": [
|
172 |
+
"# Freeze base transformer layers\n",
|
173 |
+
"for param in base_model.parameters():\n",
|
174 |
+
" param.requires_grad = False"
|
175 |
+
]
|
176 |
+
},
|
177 |
+
{
|
178 |
+
"cell_type": "code",
|
179 |
+
"execution_count": 8,
|
180 |
+
"id": "721c3897-9ef0-409f-9e9d-a693975486bf",
|
181 |
+
"metadata": {},
|
182 |
+
"outputs": [],
|
183 |
+
"source": [
|
184 |
+
"device = torch.device(\"cuda:0\")\n",
|
185 |
+
"\n",
|
186 |
+
"# Note that we must also set _target_device, or any SentenceTransformer.fit() call will reset\n",
|
187 |
+
"# the body location\n",
|
188 |
+
"base_model._target_device = device\n",
|
189 |
+
"base_model = base_model.to(device)"
|
190 |
+
]
|
191 |
+
},
|
192 |
+
{
|
193 |
+
"cell_type": "code",
|
194 |
+
"execution_count": 9,
|
195 |
+
"id": "6115d96b-fe35-4a23-9a21-f3da52304f3a",
|
196 |
+
"metadata": {},
|
197 |
+
"outputs": [],
|
198 |
+
"source": [
|
199 |
+
"emb_dims = base_model._first_module().get_word_embedding_dimension() # 768\n",
|
200 |
+
"\n",
|
201 |
+
"def dense_projector(dims: int):\n",
|
202 |
+
" proj_dims = dims * 2 # 1536\n",
|
203 |
+
" \n",
|
204 |
+
" return [\n",
|
205 |
+
" ml.Dense(dims, proj_dims), # 768 -> 1536\n",
|
206 |
+
" ml.Dense(proj_dims, proj_dims), # 1536 -> 1536\n",
|
207 |
+
" ml.Dropout(0.1),\n",
|
208 |
+
" ml.Dense(proj_dims, proj_dims), # 1536 -> 1536\n",
|
209 |
+
" ml.Dense(proj_dims, dims), # 1536 -> 768\n",
|
210 |
+
" ]\n",
|
211 |
+
"\n",
|
212 |
+
"def asym_module(dims: int, keys: list[str], allow_empty_key: bool = False):\n",
|
213 |
+
" return ml.Asym(\n",
|
214 |
+
" {\n",
|
215 |
+
" key: dense_projector(dims)\n",
|
216 |
+
" for key in keys\n",
|
217 |
+
" },\n",
|
218 |
+
" allow_empty_key=allow_empty_key,\n",
|
219 |
+
" )"
|
220 |
+
]
|
221 |
+
},
|
222 |
+
{
|
223 |
+
"cell_type": "code",
|
224 |
+
"execution_count": 10,
|
225 |
+
"id": "2b273b52-b3b1-4f29-9d9a-1fe00d29c686",
|
226 |
+
"metadata": {},
|
227 |
+
"outputs": [],
|
228 |
+
"source": [
|
229 |
+
"base_model._modules[\"2\"] = asym_module(emb_dims, [\"dialog\", \"fact\"])"
|
230 |
+
]
|
231 |
+
},
|
232 |
+
{
|
233 |
+
"cell_type": "code",
|
234 |
+
"execution_count": 11,
|
235 |
+
"id": "03004002-b9d1-4b71-8ea5-bd2a2072c751",
|
236 |
+
"metadata": {},
|
237 |
+
"outputs": [
|
238 |
+
{
|
239 |
+
"data": {
|
240 |
+
"text/plain": [
|
241 |
+
"OrderedDict([('0',\n",
|
242 |
+
" Transformer({'max_seq_length': 512, 'do_lower_case': True}) with Transformer model: BertModel ),\n",
|
243 |
+
" ('1',\n",
|
244 |
+
" Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})),\n",
|
245 |
+
" ('2',\n",
|
246 |
+
" Asym(\n",
|
247 |
+
" (dialog-0): Dense({'in_features': 768, 'out_features': 1536, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})\n",
|
248 |
+
" (dialog-1): Dense({'in_features': 1536, 'out_features': 1536, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})\n",
|
249 |
+
" (dialog-2): Dropout(\n",
|
250 |
+
" (dropout_layer): Dropout(p=0.1, inplace=False)\n",
|
251 |
+
" )\n",
|
252 |
+
" (dialog-3): Dense({'in_features': 1536, 'out_features': 1536, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})\n",
|
253 |
+
" (dialog-4): Dense({'in_features': 1536, 'out_features': 768, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})\n",
|
254 |
+
" (fact-0): Dense({'in_features': 768, 'out_features': 1536, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})\n",
|
255 |
+
" (fact-1): Dense({'in_features': 1536, 'out_features': 1536, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})\n",
|
256 |
+
" (fact-2): Dropout(\n",
|
257 |
+
" (dropout_layer): Dropout(p=0.1, inplace=False)\n",
|
258 |
+
" )\n",
|
259 |
+
" (fact-3): Dense({'in_features': 1536, 'out_features': 1536, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})\n",
|
260 |
+
" (fact-4): Dense({'in_features': 1536, 'out_features': 768, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})\n",
|
261 |
+
" ))])"
|
262 |
+
]
|
263 |
+
},
|
264 |
+
"execution_count": 11,
|
265 |
+
"metadata": {},
|
266 |
+
"output_type": "execute_result"
|
267 |
+
}
|
268 |
+
],
|
269 |
+
"source": [
|
270 |
+
"base_model._modules"
|
271 |
+
]
|
272 |
+
},
|
273 |
+
{
|
274 |
+
"cell_type": "markdown",
|
275 |
+
"id": "6ea33246-2612-443d-a5c0-4179eea1a126",
|
276 |
+
"metadata": {},
|
277 |
+
"source": [
|
278 |
+
"### Prepare training loss and evaluator"
|
279 |
+
]
|
280 |
+
},
|
281 |
+
{
|
282 |
+
"cell_type": "code",
|
283 |
+
"execution_count": 12,
|
284 |
+
"id": "e0008a08-a08d-4523-b477-212083a93aa8",
|
285 |
+
"metadata": {},
|
286 |
+
"outputs": [],
|
287 |
+
"source": [
|
288 |
+
"train_loss = ls.TripletLoss(model=base_model)"
|
289 |
+
]
|
290 |
+
},
|
291 |
+
{
|
292 |
+
"cell_type": "code",
|
293 |
+
"execution_count": 13,
|
294 |
+
"id": "53b0aba9-a279-4c90-8949-e0096b5ed4c7",
|
295 |
+
"metadata": {},
|
296 |
+
"outputs": [],
|
297 |
+
"source": [
|
298 |
+
"triplet_evaluator = TripletEvaluator.from_input_examples(\n",
|
299 |
+
" eval_data, # Triplet is ({dialog: <some_dialog>}, {fact: <relevant_fact>}, [{fact: <negative_irrelevant_fact>}])\n",
|
300 |
+
" batch_size=batch_size // 10,\n",
|
301 |
+
" main_distance_function=SimilarityFunction.COSINE,\n",
|
302 |
+
" show_progress_bar=True,\n",
|
303 |
+
" write_csv=True,\n",
|
304 |
+
")"
|
305 |
+
]
|
306 |
+
},
|
307 |
+
{
|
308 |
+
"cell_type": "markdown",
|
309 |
+
"id": "a6ea59f8-c1e1-404b-ba84-95c8199cd1df",
|
310 |
+
"metadata": {},
|
311 |
+
"source": [
|
312 |
+
"### Train model"
|
313 |
+
]
|
314 |
+
},
|
315 |
+
{
|
316 |
+
"cell_type": "code",
|
317 |
+
"execution_count": null,
|
318 |
+
"id": "dbf3b8c9-8ef8-4198-b284-910c57f2cbca",
|
319 |
+
"metadata": {},
|
320 |
+
"outputs": [
|
321 |
+
{
|
322 |
+
"data": {
|
323 |
+
"application/vnd.jupyter.widget-view+json": {
|
324 |
+
"model_id": "ea0ed014f83b4651b810c0abd317add9",
|
325 |
+
"version_major": 2,
|
326 |
+
"version_minor": 0
|
327 |
+
},
|
328 |
+
"text/plain": [
|
329 |
+
"Epoch: 0%| | 0/15 [00:00<?, ?it/s]"
|
330 |
+
]
|
331 |
+
},
|
332 |
+
"metadata": {},
|
333 |
+
"output_type": "display_data"
|
334 |
+
},
|
335 |
+
{
|
336 |
+
"data": {
|
337 |
+
"application/vnd.jupyter.widget-view+json": {
|
338 |
+
"model_id": "5690514fe3ac4e3a84fedb128a687ec1",
|
339 |
+
"version_major": 2,
|
340 |
+
"version_minor": 0
|
341 |
+
},
|
342 |
+
"text/plain": [
|
343 |
+
"Iteration: 0%| | 0/2505 [00:00<?, ?it/s]"
|
344 |
+
]
|
345 |
+
},
|
346 |
+
"metadata": {},
|
347 |
+
"output_type": "display_data"
|
348 |
+
},
|
349 |
+
{
|
350 |
+
"data": {
|
351 |
+
"application/vnd.jupyter.widget-view+json": {
|
352 |
+
"model_id": "ef19638fe2504ec095fa9f6aed3d5069",
|
353 |
+
"version_major": 2,
|
354 |
+
"version_minor": 0
|
355 |
+
},
|
356 |
+
"text/plain": [
|
357 |
+
"Batches: 0%| | 0/278 [00:00<?, ?it/s]"
|
358 |
+
]
|
359 |
+
},
|
360 |
+
"metadata": {},
|
361 |
+
"output_type": "display_data"
|
362 |
+
},
|
363 |
+
{
|
364 |
+
"data": {
|
365 |
+
"application/vnd.jupyter.widget-view+json": {
|
366 |
+
"model_id": "62af9e297e044f0bbb4d9b903b229db4",
|
367 |
+
"version_major": 2,
|
368 |
+
"version_minor": 0
|
369 |
+
},
|
370 |
+
"text/plain": [
|
371 |
+
"Batches: 0%| | 0/278 [00:00<?, ?it/s]"
|
372 |
+
]
|
373 |
+
},
|
374 |
+
"metadata": {},
|
375 |
+
"output_type": "display_data"
|
376 |
+
},
|
377 |
+
{
|
378 |
+
"data": {
|
379 |
+
"application/vnd.jupyter.widget-view+json": {
|
380 |
+
"model_id": "4007c6d76e1445fe8263561886ab3196",
|
381 |
+
"version_major": 2,
|
382 |
+
"version_minor": 0
|
383 |
+
},
|
384 |
+
"text/plain": [
|
385 |
+
"Batches: 0%| | 0/278 [00:00<?, ?it/s]"
|
386 |
+
]
|
387 |
+
},
|
388 |
+
"metadata": {},
|
389 |
+
"output_type": "display_data"
|
390 |
+
}
|
391 |
+
],
|
392 |
+
"source": [
|
393 |
+
"base_model.fit(\n",
|
394 |
+
" train_objectives=[(dataloader, train_loss)],\n",
|
395 |
+
" evaluator=triplet_evaluator,\n",
|
396 |
+
" checkpoint_save_steps=600,\n",
|
397 |
+
" evaluation_steps=600,\n",
|
398 |
+
" checkpoint_path=f\"{output_model_path}/ckpts\",\n",
|
399 |
+
" scheduler=\"WarmupCosine\",\n",
|
400 |
+
" save_best_model=True,\n",
|
401 |
+
" epochs=15,\n",
|
402 |
+
" warmup_steps=200,\n",
|
403 |
+
" optimizer_class=Lion,\n",
|
404 |
+
" optimizer_params=dict(lr=1e-4, weight_decay=1e-2),\n",
|
405 |
+
" use_amp=True,\n",
|
406 |
+
" output_path=output_model_path,\n",
|
407 |
+
" checkpoint_save_total_limit=4,\n",
|
408 |
+
")"
|
409 |
+
]
|
410 |
+
},
|
411 |
+
{
|
412 |
+
"cell_type": "code",
|
413 |
+
"execution_count": null,
|
414 |
+
"id": "21c91b44-4c0a-4fda-a72c-91dac70e72ae",
|
415 |
+
"metadata": {},
|
416 |
+
"outputs": [],
|
417 |
+
"source": [
|
418 |
+
"base_model.push_to_hub(hf_repo_name)"
|
419 |
+
]
|
420 |
+
},
|
421 |
+
{
|
422 |
+
"cell_type": "code",
|
423 |
+
"execution_count": null,
|
424 |
+
"id": "85e7c7cd-6636-42d2-aec4-56f292ea8ba9",
|
425 |
+
"metadata": {},
|
426 |
+
"outputs": [],
|
427 |
+
"source": [
|
428 |
+
"graphviz.set_jupyter_format('png')"
|
429 |
+
]
|
430 |
+
},
|
431 |
+
{
|
432 |
+
"cell_type": "code",
|
433 |
+
"execution_count": null,
|
434 |
+
"id": "bb0419cc-beb7-443e-b733-47c5b6cb267c",
|
435 |
+
"metadata": {},
|
436 |
+
"outputs": [],
|
437 |
+
"source": [
|
438 |
+
"model_graph = draw_graph(base_model, input_size=(1, 512), device='meta')\n",
|
439 |
+
"model_graph.visual_graph"
|
440 |
+
]
|
441 |
+
},
|
442 |
+
{
|
443 |
+
"cell_type": "code",
|
444 |
+
"execution_count": null,
|
445 |
+
"id": "0e478f64-f687-40e5-a315-225de31d6df6",
|
446 |
+
"metadata": {},
|
447 |
+
"outputs": [],
|
448 |
+
"source": []
|
449 |
+
}
|
450 |
+
],
|
451 |
+
"metadata": {
|
452 |
+
"kernelspec": {
|
453 |
+
"display_name": "Python 3 (ipykernel)",
|
454 |
+
"language": "python",
|
455 |
+
"name": "python3"
|
456 |
+
},
|
457 |
+
"language_info": {
|
458 |
+
"codemirror_mode": {
|
459 |
+
"name": "ipython",
|
460 |
+
"version": 3
|
461 |
+
},
|
462 |
+
"file_extension": ".py",
|
463 |
+
"mimetype": "text/x-python",
|
464 |
+
"name": "python",
|
465 |
+
"nbconvert_exporter": "python",
|
466 |
+
"pygments_lexer": "ipython3",
|
467 |
+
"version": "3.10.6"
|
468 |
+
}
|
469 |
+
},
|
470 |
+
"nbformat": 4,
|
471 |
+
"nbformat_minor": 5
|
472 |
+
}
|
training.pdf
ADDED
Binary file (46.1 kB). View file
|
|