import json
import math
import os
import shutil
import sys
import time

import torch
import torch.distributed as dist
from torch import nn
import numpy as np

from torch.utils.data import Sampler
from packaging import version

from transformers import Trainer, TrainerState, is_torch_tpu_available, is_apex_available
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.integrations import hp_params
from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint

from transformers.trainer import (
    is_sagemaker_mp_enabled,
    get_parameter_names,
    has_length,
    ALL_LAYERNORM_LAYERS,
    ShardedDDPOption,
    logger,
    TRAINER_STATE_NAME,
)
from typing import List, Optional

from transformers.trainer_pt_utils import get_model_param_count
from transformers.trainer_utils import HPSearchBackend, speed_metrics, TrainOutput
from transformers.training_args import ParallelMode
from transformers.utils import is_accelerate_available

if is_accelerate_available():
    from accelerate import Accelerator, skip_first_batches
    from accelerate import __version__ as accelerate_version
    from accelerate.utils import DistributedDataParallelKwargs, GradientAccumulationPlugin

    if version.parse(accelerate_version) > version.parse("0.20.3"):
        from accelerate.utils import (
            load_fsdp_model,
            load_fsdp_optimizer,
            save_fsdp_model,
            save_fsdp_optimizer,
        )

if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met

if is_apex_available():
    from apex import amp


# Optional per-token loss weighting: when `token_weights` is None, the standard
# unweighted language-modeling loss is used in `LLaVATrainer.compute_loss`.
token_weights = None
print("no weighting")

if token_weights is not None:
    min_weight = min(token_weights.values())
    extra_token_weight = min_weight / 100
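
# Illustrative only (hypothetical values): to enable weighting, `token_weights`
# would map token strings to positive loss weights, e.g.
#   token_weights = {"yes": 2.0, "no": 2.0}
# Every token not listed then receives `extra_token_weight` (1/100 of the
# smallest explicit weight) when the weight vector is built in `compute_loss`.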


def maybe_zero_3(param, ignore_status=False, name=None):
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                print(name, 'no ignore status')
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
    to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
    return to_return
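
# Illustrative usage sketch: called in `_save_checkpoint` below with, e.g.,
#   get_mm_adapter_state_maybe_zero_3(model.named_parameters(),
#                                     keys_to_match=['mm_projector'])
# to gather just the multimodal projector weights onto CPU, even when the
# parameters are partitioned by DeepSpeed ZeRO-3.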


def split_to_even_chunks(indices, lengths, num_chunks):
    """
    Split a list of indices into `num_chunks` chunks of roughly equal total length.
    """
    if len(indices) % num_chunks != 0:
        return [indices[i::num_chunks] for i in range(num_chunks)]

    num_indices_per_chunk = len(indices) // num_chunks

    chunks = [[] for _ in range(num_chunks)]
    chunks_lengths = [0 for _ in range(num_chunks)]
    for index in indices:
        shortest_chunk = chunks_lengths.index(min(chunks_lengths))
        chunks[shortest_chunk].append(index)
        chunks_lengths[shortest_chunk] += lengths[index]
        if len(chunks[shortest_chunk]) == num_indices_per_chunk:
            chunks_lengths[shortest_chunk] = float("inf")

    return chunks
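
# Illustrative sketch: split_to_even_chunks([0, 1, 2, 3], lengths=[10, 9, 2, 1],
# num_chunks=2) greedily sends each index to the currently shortest chunk and
# returns [[0, 3], [1, 2]], i.e. two chunks with total lengths 11 and 11.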


def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
    # Convention: positive lengths mark multimodal samples, negative lengths mark
    # language-only samples (the sign is only used to separate the two groups).
    assert all(l != 0 for l in lengths), "Should not have zero length."
    if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
        # All samples share one modality; fall back to plain length grouping.
        return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
    mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
    lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])

    mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
    lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
    megabatch_size = world_size * batch_size
    mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
    lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]

    last_mm = mm_megabatches[-1]
    last_lang = lang_megabatches[-1]
    additional_batch = last_mm + last_lang
    megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
    megabatch_indices = torch.randperm(len(megabatches), generator=generator)
    megabatches = [megabatches[i] for i in megabatch_indices]

    if len(additional_batch) > 0:
        megabatches.append(sorted(additional_batch))

    return [i for megabatch in megabatches for i in megabatch]
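
# Illustrative sketch: with a mixed dataset, multimodal samples (positive
# lengths) and language-only samples (negative lengths) are first grouped into
# modality-pure megabatches of world_size * batch_size indices each; the
# megabatch order is then shuffled, and the two leftover partial megabatches
# are merged into a final combined batch.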


def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
    indices = torch.randperm(len(lengths), generator=generator)
    megabatch_size = world_size * batch_size
    megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
    megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
    megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]

    return [i for megabatch in megabatches for batch in megabatch for i in batch]
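
# Illustrative sketch: indices are randomly permuted, cut into megabatches of
# world_size * batch_size, sorted by length (longest first) within each
# megabatch, and then split via `split_to_even_chunks` into per-rank chunks of
# similar total length, so every rank sees comparably long sequences each step.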


class LengthGroupedSampler(Sampler):
    r"""
    Sampler that samples indices in a way that groups together features of the dataset of roughly the same length
    while keeping a bit of randomness.
    """

    def __init__(
        self,
        batch_size: int,
        world_size: int,
        lengths: Optional[List[int]] = None,
        generator=None,
        group_by_modality: bool = False,
    ):
        if lengths is None:
            raise ValueError("Lengths must be provided.")

        self.batch_size = batch_size
        self.world_size = world_size
        self.lengths = lengths
        self.generator = generator
        self.group_by_modality = group_by_modality

    def __len__(self):
        return len(self.lengths)

    def __iter__(self):
        if self.group_by_modality:
            indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
        else:
            indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
        return iter(indices)
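
# Illustrative usage sketch (names are hypothetical): a dataset exposing
# per-sample lengths could be wrapped as
#   sampler = LengthGroupedSampler(batch_size=4, world_size=2,
#                                  lengths=dataset.modality_lengths,
#                                  group_by_modality=True)
# and passed to a DataLoader; `LLaVATrainer._get_train_sampler` below builds
# exactly this sampler when `group_by_modality_length` is enabled.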


class LLaVATrainer(Trainer):

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
        """
        outputs = model(**inputs)

        # Save past state if it exists
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if token_weights is not None:
            # Build the per-token weight vector once and cache it on the trainer.
            if not hasattr(self, 'vocab_weight'):
                vocab = self.tokenizer.get_vocab()
                self.vocab_weight = torch.ones(len(vocab)) * extra_token_weight
                for k, v in token_weights.items():
                    self.vocab_weight[vocab[k]] = v
                self.vocab_weight = self.vocab_weight.to(self.args.device)

            # Shift so that tokens < n predict token n, then flatten for the weighted loss.
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = outputs.modified_labels[..., 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss(weight=self.vocab_weight)
            shift_logits = shift_logits.view(-1, self.model.config.vocab_size)
            shift_labels = shift_labels.view(-1)

            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

            return (loss, outputs) if return_outputs else loss

        else:
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
            return (loss, outputs) if return_outputs else loss
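
    # Illustrative note (only active if token weighting is enabled above): with
    # token_weights = {"yes": 2.0}, min_weight is 2.0 and extra_token_weight is
    # 0.02, so loss terms for the "yes" token are scaled by 2.0 while every
    # unlisted token is scaled by 0.02. The weighted branch assumes the wrapped
    # LLaVA model returns `modified_labels` alongside `logits` in its output.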


    def _inner_training_loop(
        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
    ):
        self.accelerator.free_memory()
        self._train_batch_size = batch_size
        logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")

        train_dataloader = self.get_train_dataloader()

        total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size

        len_dataloader = None
        if has_length(train_dataloader):
            len_dataloader = len(train_dataloader)
            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
            num_examples = self.num_examples(train_dataloader)
            if args.max_steps > 0:
                max_steps = args.max_steps
                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                    args.max_steps % num_update_steps_per_epoch > 0
                )
                num_train_samples = args.max_steps * total_train_batch_size
            else:
                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
                num_train_epochs = math.ceil(args.num_train_epochs)
                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
        elif args.max_steps > 0:
            max_steps = args.max_steps
            num_train_epochs = sys.maxsize
            num_update_steps_per_epoch = max_steps
            num_examples = total_train_batch_size * args.max_steps
            num_train_samples = args.max_steps * total_train_batch_size
        else:
            raise ValueError(
                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
                f" {args.max_steps}"
            )

        if args.logging_steps and args.logging_steps < 1:
            args.logging_steps = math.ceil(max_steps * args.logging_steps)
        if args.eval_steps and args.eval_steps < 1:
            args.eval_steps = math.ceil(max_steps * args.eval_steps)
        if args.save_steps and args.save_steps < 1:
            args.save_steps = math.ceil(max_steps * args.save_steps)

        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
            if self.args.n_gpu > 1:
                raise ValueError(
                    "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
                    " (torch.distributed.launch)."
                )
            else:
                debug_overflow = DebugUnderflowOverflow(self.model)

        delay_optimizer_creation = (
            self.sharded_ddp is not None
            and self.sharded_ddp != ShardedDDPOption.SIMPLE
            or is_sagemaker_mp_enabled()
            or self.fsdp is not None
        )

        # Reset a scheduler the Trainer created earlier so it is rebuilt for this run.
        if self._created_lr_scheduler:
            self.lr_scheduler = None
            self._created_lr_scheduler = False

        if self.is_deepspeed_enabled:
            self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)

        if not delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        self.state = TrainerState()
        self.state.is_hyper_param_search = trial is not None

        if args.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

        model = self._wrap_model(self.model_wrapped)

        if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
            self._load_from_checkpoint(resume_from_checkpoint, model)

        # Only run `accelerator.prepare` when `_wrap_model` returned the model unchanged.
        use_accelerator_prepare = True if model is self.model else False

        if delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        if use_accelerator_prepare:
            self.model.train()
            if hasattr(self.lr_scheduler, "step"):
                if self.use_apex:
                    model = self.accelerator.prepare(self.model)
                else:
                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
            else:
                # The scheduler doesn't expose `step` (e.g. a DeepSpeed dummy scheduler),
                # so let accelerate prepare it as well.
                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
                    self.model, self.optimizer, self.lr_scheduler
                )

        if self.is_fsdp_enabled:
            self.model = model

        # For the rest of this function `model` is the outside (possibly wrapped) model.
        if model is not self.model:
            self.model_wrapped = model

        # Keep the legacy `self.deepspeed` attribute for backward compatibility.
        if self.is_deepspeed_enabled:
            self.deepspeed = self.model_wrapped

        # DeepSpeed checkpoint loading.
        if resume_from_checkpoint is not None and self.is_deepspeed_enabled:
            print(f"DeepSpeed info: Loading model from {resume_from_checkpoint}")
            deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)

            # Fast-forward the freshly created LR scheduler to the optimizer's recorded
            # step count so the learning rate resumes where the checkpoint left off.
            step_value = 0  # default in case the optimizer has no state yet
            for param_tensor, state in self.lr_scheduler.optimizer.state.items():
                step_tensor = state['step']
                step_value = step_tensor.item()
                print(f"Step value for a parameter tensor: {step_value}")
                # All parameter tensors share the same step count, so one entry is enough.
                break

            for _ in range(int(step_value)):
                self.lr_scheduler.step()

        # Check if saved optimizer or scheduler states exist.
        self._load_optimizer_and_scheduler(resume_from_checkpoint)

        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {num_examples:,}")
        logger.info(f"  Num Epochs = {num_train_epochs:,}")
        logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
        if self.args.per_device_train_batch_size != self._train_batch_size:
            logger.info(f"  Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
        logger.info(f"  Total optimization steps = {max_steps:,}")
        logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")

        self.state.epoch = 0
        start_time = time.time()
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        steps_trained_progress_bar = None

        if resume_from_checkpoint is not None and os.path.isfile(
            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
        ):
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
            epochs_trained = self.state.global_step // num_update_steps_per_epoch
            if not args.ignore_data_skip:
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
            else:
                steps_trained_in_current_epoch = 0

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info(f"  Continuing training from epoch {epochs_trained}")
            logger.info(f"  Continuing training from global step {self.state.global_step}")
            if not args.ignore_data_skip:
                logger.info(
                    f"  Will skip the first {epochs_trained} epochs then the first"
                    f" {steps_trained_in_current_epoch} batches in the first epoch."
                )

        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
        if self.hp_name is not None and self._trial is not None:
            self.state.trial_name = self.hp_name(self._trial)
        if trial is not None:
            assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
            self.state.trial_params = hp_params(assignments)
        else:
            self.state.trial_params = None

        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

        tr_loss = torch.tensor(0.0).to(args.device)

        self._total_loss_scalar = 0.0
        self._globalstep_last_logged = self.state.global_step
        model.zero_grad()

        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

        if not args.ignore_data_skip:
            for epoch in range(epochs_trained):
                for _ in train_dataloader:
                    break

        total_batched_samples = 0
        for epoch in range(epochs_trained, num_train_epochs):
            epoch_iterator = train_dataloader

            if args.past_index >= 0:
                self._past = None

            steps_in_epoch = (
                len(epoch_iterator)
                if len_dataloader is not None
                else args.max_steps * args.gradient_accumulation_steps
            )
            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                self._load_rng_state(resume_from_checkpoint)

            rng_to_sync = False
            steps_skipped = 0
            if steps_trained_in_current_epoch > 0:
                epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
                steps_skipped = steps_trained_in_current_epoch
                steps_trained_in_current_epoch = 0
                rng_to_sync = True

            step = -1
            for step, inputs in enumerate(epoch_iterator):
                total_batched_samples += 1
                if rng_to_sync:
                    self._load_rng_state(resume_from_checkpoint)
                    rng_to_sync = False

                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    if steps_trained_progress_bar is not None:
                        steps_trained_progress_bar.update(1)
                    if steps_trained_in_current_epoch == 0:
                        self._load_rng_state(resume_from_checkpoint)
                    continue
                elif steps_trained_progress_bar is not None:
                    steps_trained_progress_bar.close()
                    steps_trained_progress_bar = None

                if step % args.gradient_accumulation_steps == 0:
                    self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

                with self.accelerator.accumulate(model):
                    tr_loss_step = self.training_step(model, inputs)

                if (
                    args.logging_nan_inf_filter
                    and not is_torch_tpu_available()
                    and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                ):
                    tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
                else:
                    tr_loss += tr_loss_step

                self.current_flos += float(self.floating_point_ops(inputs))

                is_last_step_and_steps_less_than_grad_acc = (
                    steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
                )

                if (
                    total_batched_samples % args.gradient_accumulation_steps == 0
                    or is_last_step_and_steps_less_than_grad_acc
                ):
                    if is_last_step_and_steps_less_than_grad_acc or (
                        version.parse(accelerate_version) <= version.parse("0.20.3")
                    ):
                        self.accelerator.gradient_state._set_sync_gradients(True)

                    if args.max_grad_norm is not None and args.max_grad_norm > 0:

                        if self.do_grad_scaling:

                            if is_torch_tpu_available():
                                gradients = xm._fetch_gradients(self.optimizer)
                                xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())

                            self.scaler.unscale_(self.optimizer)

                        if is_sagemaker_mp_enabled() and args.fp16:
                            self.optimizer.clip_master_grads(args.max_grad_norm)
                        elif hasattr(self.optimizer, "clip_grad_norm"):
                            self.optimizer.clip_grad_norm(args.max_grad_norm)
                        elif hasattr(model, "clip_grad_norm_"):
                            model.clip_grad_norm_(args.max_grad_norm)
                        elif self.use_apex:
                            nn.utils.clip_grad_norm_(
                                amp.master_params(self.optimizer),
                                args.max_grad_norm,
                            )
                        else:
                            self.accelerator.clip_grad_norm_(
                                model.parameters(),
                                args.max_grad_norm,
                            )

                    optimizer_was_run = True
                    if is_torch_tpu_available():
                        if self.do_grad_scaling:
                            self.scaler.step(self.optimizer)
                            self.scaler.update()
                        else:
                            self.optimizer.step()
                    elif self.do_grad_scaling:
                        scale_before = self.scaler.get_scale()
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                        scale_after = self.scaler.get_scale()
                        optimizer_was_run = scale_before <= scale_after
                    else:
                        self.optimizer.step()
                        optimizer_was_run = not self.accelerator.optimizer_step_was_skipped

                    if optimizer_was_run:
                        if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                            self.lr_scheduler.step()

                    model.zero_grad()
                    self.state.global_step += 1
                    self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                    self.control = self.callback_handler.on_step_end(args, self.state, self.control)

                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
                else:
                    self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

                if self.control.should_epoch_stop or self.control.should_training_stop:
                    break
            if step < 0:
                logger.warning(
                    "There seems to be not a single sample in your epoch_iterator, stopping training at step"
                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
                    f" num_steps ({max_steps}) higher than the number of available samples."
                )
                self.control.should_training_stop = True

            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)

            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                if is_torch_tpu_available():
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
            if self.control.should_training_stop:
                break

        if args.past_index and hasattr(self, "_past"):
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
            if is_torch_tpu_available():
                xm.rendezvous("load_best_model_at_end")
            elif args.parallel_mode == ParallelMode.DISTRIBUTED:
                dist.barrier()

            self._load_best_model()

        self._total_loss_scalar += tr_loss.item()
        train_loss = self._total_loss_scalar / self.state.global_step

        metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
        self.store_flos()
        metrics["total_flos"] = self.state.total_flos
        metrics["train_loss"] = train_loss

        self.is_in_train = False

        self._memory_tracker.stop_and_update_metrics(metrics)

        self.log(metrics)

        run_dir = self._get_output_dir(trial)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)

        if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
            for checkpoint in checkpoints_sorted:
                if checkpoint != self.state.best_model_checkpoint:
                    logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
                    shutil.rmtree(checkpoint)

        self.control = self.callback_handler.on_train_end(args, self.state, self.control)

        return TrainOutput(self.state.global_step, train_loss, metrics)

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None

        if self.args.group_by_modality_length:
            lengths = self.train_dataset.modality_lengths
            return LengthGroupedSampler(
                self.args.train_batch_size,
                world_size=self.args.world_size * self.args.gradient_accumulation_steps,
                lengths=lengths,
                group_by_modality=True,
            )
        else:
            return super()._get_train_sampler()
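
    # Note: `world_size` is scaled by `gradient_accumulation_steps`, presumably so
    # that one length-grouped megabatch covers everything consumed across all
    # devices between two optimizer updates.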

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in
        the Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        if is_sagemaker_mp_enabled():
            return super().create_optimizer()
        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
            return super().create_optimizer()

        opt_model = self.model

        if self.optimizer is None:
            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            decay_parameters = [name for name in decay_parameters if "bias" not in name]
            if self.args.mm_projector_lr is not None:
                projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name]
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                        "lr": self.args.mm_projector_lr,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                        "lr": self.args.mm_projector_lr,
                    },
                ]
            else:
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                # Unreachable given the early return above; kept for parity with the stock Trainer.
                from fairscale.optim import OSS

                self.optimizer = OSS(
                    params=optimizer_grouped_parameters,
                    optim=optimizer_cls,
                    **optimizer_kwargs,
                )
            else:
                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
                if optimizer_cls.__name__ == "Adam8bit":
                    import bitsandbytes

                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                    skipped = 0
                    for module in opt_model.modules():
                        if isinstance(module, nn.Embedding):
                            skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                            logger.info(f"skipped {module}: {skipped/2**20}M params")
                            manager.register_module_override(module, "weight", {"optim_bits": 32})
                            logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                    logger.info(f"skipped: {skipped/2**20}M params")

        return self.optimizer
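
    # Illustrative sketch (hypothetical values): with `--mm_projector_lr 2e-5` the
    # optimizer receives four parameter groups: non-projector weights with weight
    # decay, non-projector biases/LayerNorms without decay, projector weights with
    # decay at lr=2e-5, and projector biases/LayerNorms without decay at lr=2e-5;
    # without `--mm_projector_lr` only the first two groups are created.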

    def _save_checkpoint(self, model, trial, metrics=None):
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

            run_dir = self._get_output_dir(trial=trial)
            output_dir = os.path.join(run_dir, checkpoint_folder)

            keys_to_match = ['mm_projector', 'vision_resampler']
            if getattr(self.args, "use_im_start_end", False):
                keys_to_match.extend(['embed_tokens', 'embed_in'])

            weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)

            if self.args.local_rank == 0 or self.args.local_rank == -1:
                self.model.config.save_pretrained(output_dir)
                torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
        else:
            super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)
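
    # Note: under `tune_mm_mlp_adapter`, each checkpoint-<global_step> directory
    # ends up holding only the model config plus `mm_projector.bin` (the gathered
    # adapter weights) rather than a full model checkpoint.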

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            # Adapter weights are already written by `_save_checkpoint` above.
            pass
        else:
            super(LLaVATrainer, self)._save(output_dir, state_dict)