import os
from typing import Optional

import torch
from transformers import Trainer

# Steps between persisted checkpoints; _save_checkpoint below rounds the global
# step down to a multiple of this value when naming the checkpoint folder.
LOG_INTERVAL = 5000


def get_state(model):
    """Collect the entries worth checkpointing: trainable weights plus buffers."""
    trainable_state_dict = dict()
    for name, param in model.state_dict().items():
        try:
            # Keep only parameters that are actually being trained.
            if model.get_parameter(name).requires_grad:
                trainable_state_dict[name] = param
        except AttributeError:
            # state_dict entries that are buffers rather than nn.Parameters raise here;
            # keep them so the partial checkpoint can be reloaded on its own.
            trainable_state_dict[name] = param
    return trainable_state_dict
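
# Reload sketch (assumption: checkpoints written by SALMONNTrainer below contain only
# this partial state dict, so they must be loaded non-strictly on top of a fully
# constructed model; the path here is illustrative only):
#
#   partial_state = torch.load("outputs/checkpoint-5000/salomnn_7b.bin", map_location="cpu")
#   missing, unexpected = model.load_state_dict(partial_state, strict=False)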


class SALMONNTrainer(Trainer):

    def _save_checkpoint(self, model, trial, metrics=None):
        from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

        # Name the folder by the global step rounded down to the nearest LOG_INTERVAL,
        # so saves within the same interval overwrite the same checkpoint folder.
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{(self.state.global_step // LOG_INTERVAL) * LOG_INTERVAL}"

        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)
        os.makedirs(output_dir, exist_ok=True)

        # Persist only the trainable weights (and buffers) returned by get_state,
        # which keeps checkpoints small for adapter-style fine-tuning.
        weight_to_save = get_state(self.model)
        torch.save(weight_to_save, os.path.join(output_dir, 'salomnn_7b.bin'))

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        super()._save(output_dir, state_dict)
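
# Minimal usage sketch (assumptions: `model`, `train_dataset`, and `data_collator` are
# hypothetical placeholders built elsewhere; only the trainer wiring below follows the
# standard Hugging Face Trainer API):
#
#   from transformers import TrainingArguments
#
#   args = TrainingArguments(output_dir="outputs", save_strategy="steps",
#                            save_steps=LOG_INTERVAL, per_device_train_batch_size=1)
#   trainer = SALMONNTrainer(model=model, args=args,
#                            train_dataset=train_dataset, data_collator=data_collator)
#   trainer.train()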