train model
- scripts/requirements.in +2 -1
- scripts/train_model.py +100 -11
scripts/requirements.in
CHANGED
@@ -3,4 +3,5 @@ datasets
 jinja2
 transformers
 jsonlines
-litgpt[all]
+litgpt[all]
+litdata
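The commit pins two new dependencies, litgpt[all] and litdata, on top of the existing datasets/jinja2/transformers/jsonlines entries. As a hypothetical sanity check (not part of the committed scripts), the installed versions can be confirmed before running the training script:

    # Hypothetical check that the new dependencies resolve; not part of the commit.
    from importlib.metadata import version

    for pkg in ('litgpt', 'litdata'):
        print(pkg, version(pkg))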
scripts/train_model.py
CHANGED
@@ -1,17 +1,31 @@
 import gc
+
+import torch
+from torch.optim import AdamW
+import bitsandbytes as bnb
 from datasets import load_dataset, Dataset
 
+from transformers import (
+    AutoConfig,
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    TrainingArguments,
+    Trainer,
+    DataCollatorForLanguageModeling,
+)
 
-
-
-#
-
-
-
-
-
-
-
+
+def _batch_iterator():
+    # code
+    dataset = load_dataset('bigcode/programming-languages-keywords', split='train')
+
+    for row in dataset:
+        for n in row['keywords']:
+            yield n
+
+    del dataset
+    gc.collect()
+    return
 
 # code
 dataset = (
@@ -166,4 +180,79 @@ def batch_iterator():
         yield f'{row["character"]}\n{row["unicode"]}\n{row["short description"]}\n{row["tags"]}\n{row["LLM description"]}'
 
     del dataset
-    gc.collect()
+    gc.collect()
+
+
+def batch_iterator():
+    for text in _batch_iterator():
+        for i in range(0, len(text), 2048):
+            chunk = text[i:i + 2048]
+            yield {'text': chunk}
+
+
+tokenizer = AutoTokenizer.from_pretrained('../')
+print(tokenizer)
+
+config = AutoConfig.from_pretrained('mistralai/Mistral-7B-Instruct-v0.3')
+config.bos_token_id = tokenizer.bos_token_id
+config.eos_token_id = tokenizer.eos_token_id
+config.unk_token_id = tokenizer.unk_token_id
+config.pad_token_id = tokenizer.pad_token_id
+config.hidden_size = 512
+config.intermediate_size = 1792  # int(512 * 3.5)
+config.max_position_embeddings = 32768  # 32 * 1024
+config.num_attention_heads = 12
+config.num_hidden_layers = 10
+config.num_key_value_heads = 4
+config.rope_theta = 1_000_000.0
+config.sliding_window = 4096
+config.torch_dtype = torch.bfloat16
+config.use_cache = False
+print(config)
+
+model = AutoModelForCausalLM.from_config(config)
+print(model)
+
+dataset = Dataset.from_generator(batch_iterator)
+print(dataset)
+
+data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+print(data_collator)
+
+optimizer = bnb.optim.AdamW8bit(
+    model.parameters(),
+    lr=1e-5,
+    betas=(0.9, 0.95),
+    weight_decay=0.1,
+)
+print(optimizer)
+
+training_args = TrainingArguments(
+    output_dir='./mistral-custom',
+    num_train_epochs=3,
+    per_device_train_batch_size=1,
+    gradient_accumulation_steps=8,
+    warmup_steps=500,
+    learning_rate=1e-5,
+    fp16=False,
+    bf16=True,
+    logging_dir='./logs',
+    logging_steps=10,
+    evaluation_strategy='no',
+    save_strategy='epoch',
+    torch_compile=True,
+    remove_unused_columns=False,
+)
+print(training_args)
+
+trainer = Trainer(
+    model=model,
+    args=training_args,
+    train_dataset=dataset,
+    data_collator=data_collator,
+    optimizers=(optimizer, None)
+)
+print(trainer)
+
+trainer.train()
+trainer.save_model('./mistral-custom-final')
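The second hunk wraps the text stream in batch_iterator, which re-yields it as 2048-character chunks ({'text': chunk}), builds a Dataset with Dataset.from_generator, and trains a freshly initialized, downsized Mistral-style model with a custom bitsandbytes AdamW8bit optimizer. One thing to note: DataCollatorForLanguageModeling pads already-tokenized features (input_ids) and, with mlm=False, copies them into labels for causal LM training; it does not tokenize raw strings. As committed, the dataset only carries a raw 'text' column, so a tokenization step would typically sit between Dataset.from_generator and Trainer. A minimal sketch of that step, assuming the tokenizer loaded from '../' and a 2048-token cap (the helper, the map call, and the max_length value are assumptions, not part of the commit):

    # Hedged sketch: tokenize the raw-text dataset before handing it to Trainer.
    # The helper name, max_length, and map arguments are assumptions.
    def tokenize(batch):
        return tokenizer(batch['text'], truncation=True, max_length=2048)

    tokenized_dataset = dataset.map(tokenize, batched=True, remove_columns=['text'])

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,  # tokenized features instead of raw strings
        data_collator=data_collator,
        optimizers=(optimizer, None),
    )

Two smaller details: because a custom optimizer is supplied via optimizers=(optimizer, None), TrainingArguments.learning_rate is not used to build the optimizer (only scheduler settings such as warmup_steps still apply; here both values happen to be 1e-5), and from torch.optim import AdamW is left unused once AdamW8bit is chosen.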