import os

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig,
    TrainerCallback,
)
from datasets import load_from_disk
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from accelerate import Accelerator

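# Run configuration: per-device batch size, base checkpoint, dataset and output paths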
batch_size = 2

checkpoint = "google/gemma-2b"
data_dir = "dataset_ro_small_v1/"
save_dir = "gemma-2b-romanian-1.6gb-finetuned-qlora"
log_dir = "training_logs/"

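# Load the pre-tokenized dataset from disk and shuffle it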
tokenized_datasets = load_from_disk(f'tokenized_{data_dir}')
tokenized_datasets = tokenized_datasets.shuffle(seed=42)

print(tokenized_datasets)

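# QLoRA step 1: quantize the frozen base weights to 4-bit NF4 and compute in float16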
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

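# Load the 4-bit base model; device_map pins the whole model to this process's GPU,
# so the script also works when launched with multiple processes (DDP)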
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    quantization_config=bnb_config,
    device_map={"": Accelerator().process_index},
    torch_dtype=torch.float16,
    trust_remote_code=True,
    attn_implementation='sdpa',
    use_cache=False,
)
model = prepare_model_for_kbit_training(model)

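# QLoRA step 2: attach LoRA adapters to all attention and MLP projection layers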
lora_config = LoraConfig(
    lora_alpha=32,
    lora_dropout=0.1,
    r=8,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
)
model = get_peft_model(model, lora_config)

model.print_trainable_parameters()

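# Pad with the EOS token and collate batches for causal language modeling (mlm=False)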
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

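# Training setup: effective per-GPU batch size is batch_size * gradient_accumulation_steps = 8,
# cosine LR schedule with warmup, fp16 mixed precision, checkpoint saved every 100 steps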
args = TrainingArguments(
    output_dir='training_checkpoints/',
    logging_dir=log_dir,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    evaluation_strategy='no',
    logging_steps=100,
    save_strategy='steps',
    save_steps=100,
    save_total_limit=10,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    num_train_epochs=1,
    warmup_steps=1_000,
    weight_decay=0.001,
    lr_scheduler_type='cosine',
    learning_rate=1e-4,
    max_grad_norm=0.3,
    fp16=True,
    ddp_find_unused_parameters=False,
)

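# Request a training stop every 1,000 global steps; the run can then be resumed
# from the latest checkpoint via TRAIN_CHECKPOINT (see below)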
class StopCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step != 0 and state.global_step % 1000 == 0:
            control.should_training_stop = True

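# The eval split is passed for completeness but never used, since evaluation_strategy='no'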
trainer = Trainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['test'],
    tokenizer=tokenizer,
)
trainer.add_callback(StopCallback)

print("Starting training...") |
|
|
|
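# Resume from the checkpoint path in the TRAIN_CHECKPOINT environment variable, if set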
train_checkpoint = os.getenv("TRAIN_CHECKPOINT")
if train_checkpoint is not None:
    trainer.train(resume_from_checkpoint=train_checkpoint)
else:
    trainer.train()

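# Persist the training log history, then the adapter and tokenizer;
# save_pretrained on a PEFT model writes only the LoRA adapter weights, not the base model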
torch.save(trainer.state.log_history, "trainer_log_history.pth")

model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)