unimer_demo / unimernet /runners /runner_iter.py
wufan's picture
Upload 111 files
18e4b60 verified
raw
history blame
11.8 kB
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import datetime
import logging
import os
import time
import torch
import torch.distributed as dist
import webdataset as wds
from unimernet.common.dist_utils import download_cached_file, is_main_process, main_process
from unimernet.common.registry import registry
from unimernet.common.utils import is_url
from unimernet.datasets.data_utils import reorg_datasets_by_split
from unimernet.runners.runner_base import RunnerBase
from torch.utils.data.dataset import ChainDataset
@registry.register_runner("runner_iter")
class RunnerIter(RunnerBase):
    """
    Run training based on the number of iterations. This is common when
    the training dataset size is large. Under the hood, the logic is similar
    to epoch-based training: every #iters_per_inner_epoch iterations are
    treated as one "inner epoch".

    In the iter-based runner, after every #iters_per_inner_epoch steps, we
        1) do a validation epoch;
        2) schedule the learning rate;
        3) save the checkpoint.
    """

    def __init__(self, cfg, task, model, datasets, job_id):
        """
        Read the iteration budget from the run config.

        Both ``max_iters`` and ``iters_per_inner_epoch`` must be present in
        ``run_cfg`` and strictly positive; otherwise an AssertionError is
        raised.
        """
        super().__init__(cfg, task, model, datasets, job_id)

        # First iteration of the next inner epoch; overwritten on resume.
        self.start_iters = 0

        self.max_iters = int(self.config.run_cfg.get("max_iters", -1))
        assert self.max_iters > 0, "max_iters must be greater than 0."

        self.iters_per_inner_epoch = int(
            self.config.run_cfg.get("iters_per_inner_epoch", -1)
        )
        assert (
            self.iters_per_inner_epoch > 0
        ), "iters_per_inner_epoch must be greater than 0."

    @property
    def max_epoch(self):
        """Number of whole inner epochs that fit into ``max_iters``."""
        # Both operands are validated as positive ints in __init__, so floor
        # division equals the former int(a / b) without a float round-trip.
        return self.max_iters // self.iters_per_inner_epoch

    @property
    def cur_epoch(self):
        """Epoch reported by the train loader, or 0 if it exposes none."""
        try:
            return self.train_loader.epoch
        except AttributeError:
            # pipeline data (e.g. LAION) is streaming, have no concept of epoch
            return 0

    def _progress(self, cur_iters):
        """Return a progress tag of the form ``"<epoch>_iters=<cur_iters>"``."""
        return "{}_iters={}".format(self.cur_epoch, cur_iters)

    def train(self):
        """
        Run the train / validate / checkpoint loop, then the final evaluation.

        Each pass of the loop covers one inner epoch. When ``evaluate_only``
        is set, only the validation part of the first pass runs and no
        checkpoint other than a possible "best" one is written.
        """
        start_time = time.time()
        best_agg_metric = 0
        best_iters = 0

        self.log_config()

        # resume from checkpoint if specified
        if not self.evaluate_only and self.resume_ckpt_path is not None:
            self._load_checkpoint(self.resume_ckpt_path)

        # Counts inner epochs completed in *this* process (restarts at 0 on
        # resume); only used for the milestone check below.
        cur_epoch = 0
        for start_iters in range(
            self.start_iters, self.max_iters, self.iters_per_inner_epoch
        ):
            end_iters = start_iters + self.iters_per_inner_epoch

            # training phase
            if not self.evaluate_only:
                logging.info(
                    "Start training, max_iters={}, in total {} inner epochs.".format(
                        self.max_iters, self.max_epoch
                    )
                )

                train_stats = self.train_iters(self.cur_epoch, start_iters)
                self.log_stats(split_name="train", stats=train_stats)

            # evaluation phase
            for split_name in self.valid_splits:
                logging.info("Evaluating on {}.".format(split_name))

                val_log = self.eval_epoch(
                    split_name=split_name, cur_epoch=self._progress(end_iters)
                )
                if val_log is not None:
                    if is_main_process():
                        assert (
                            "agg_metrics" in val_log
                        ), "No agg_metrics found in validation log."

                        agg_metrics = val_log["agg_metrics"]
                        # Only the "eval" split drives best-checkpoint selection.
                        if agg_metrics > best_agg_metric and split_name == "eval":
                            best_iters, best_agg_metric = end_iters, agg_metrics

                            self._save_checkpoint(end_iters, is_best=True)

                        val_log.update({"best_iters": best_iters})
                        self.log_stats(val_log, split_name)

                        # print evaluation metric
                        print(f"bleu:{val_log['bleu']:.6f}, edit_distance:{val_log['edit_distance']:.6f}, token_accuracy:{val_log['token_accuracy']:.6f} ")
                        print("=" * 80)

            if self.evaluate_only:
                break

            if self.milestone and cur_epoch + 1 in self.milestone:
                # NOTE(review): this stores the inner-epoch index (not an
                # iteration count) as "iters" and in the file name, so a
                # milestone checkpoint is not a reliable resume point.
                self._save_checkpoint(cur_epoch)
            self._save_checkpoint(end_iters, latest=True)

            # Synchronise ranks only when a process group exists; calling
            # barrier() without one raises in single-process runs.
            if dist.is_available() and dist.is_initialized():
                dist.barrier()
            cur_epoch += 1

        # testing phase
        self.evaluate(cur_epoch=self.cur_epoch)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        logging.info("Training time {}".format(total_time_str))

    def train_iters(self, epoch, start_iters):
        """Put the model in train mode and delegate one inner epoch to the task."""
        # train by iterations
        self.model.train()

        return self.task.train_iters(
            epoch=epoch,
            start_iters=start_iters,
            iters_per_inner_epoch=self.iters_per_inner_epoch,
            model=self.model,
            data_loader=self.train_loader,
            optimizer=self.optimizer,
            scaler=self.scaler,
            lr_scheduler=self.lr_scheduler,
            cuda_enabled=self.cuda_enabled,
            log_freq=self.log_freq,
            accum_grad_iters=self.accum_grad_iters,
        )

    @main_process
    def _save_checkpoint(self, cur_iters, is_best=False, latest=False):
        """
        Save model, optimizer, config, scaler and ``cur_iters`` to ``output_dir``.

        The file is named ``checkpoint_best.pth`` when ``is_best`` is set,
        ``checkpoint_latest.pth`` when ``latest`` is set, and
        ``checkpoint_<cur_iters>.pth`` otherwise. ``is_best`` and ``latest``
        are mutually exclusive.
        """
        assert not (is_best and latest), "You can't set 'is_best' and 'latest' the same time."

        unwrapped_model = self.unwrap_dist_model(self.model)
        param_grad_dic = {
            k: v.requires_grad for (k, v) in unwrapped_model.named_parameters()
        }

        # only save the params that require gradient; entries that are not
        # parameters (e.g. buffers) are kept as they are.
        state_dict = unwrapped_model.state_dict()
        for k in list(state_dict.keys()):
            if k in param_grad_dic and not param_grad_dic[k]:
                del state_dict[k]

        save_obj = {
            "model": state_dict,
            "optimizer": self.optimizer.state_dict(),
            "config": self.config.to_dict(),
            "scaler": self.scaler.state_dict() if self.scaler else None,
            "iters": cur_iters,
        }

        if is_best:
            tag = "best"
        elif latest:
            tag = "latest"
        else:
            tag = cur_iters
        save_to = os.path.join(self.output_dir, "checkpoint_{}.pth".format(tag))

        logging.info("Saving checkpoint at iters {} to {}.".format(cur_iters, save_to))
        torch.save(save_obj, save_to)

    def _load_checkpoint(self, url_or_filename):
        """
        Resume from a checkpoint.

        Restores model weights, optimizer state and (if both sides have one)
        the grad-scaler state, then sets ``start_iters`` to the stored
        iteration count.

        Raises:
            RuntimeError: if ``url_or_filename`` is neither a URL nor an
                existing file, or if the checkpoint's model keys do not match
                the model (ignoring frozen parameters, which
                ``_save_checkpoint`` never stores).
        """
        # NOTE(review): torch.load unpickles the file; only resume from
        # checkpoints (especially URLs) that come from a trusted source.
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location=self.device)
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location=self.device)
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        # _save_checkpoint drops parameters with requires_grad=False, so a
        # strict load would always fail for models with frozen weights. Load
        # non-strictly, then verify that nothing else is missing or unexpected.
        unwrapped_model = self.unwrap_dist_model(self.model)
        load_result = unwrapped_model.load_state_dict(
            checkpoint["model"], strict=False
        )
        frozen_params = {
            name
            for name, param in unwrapped_model.named_parameters()
            if not param.requires_grad
        }
        missing_keys = [
            k for k in load_result.missing_keys if k not in frozen_params
        ]
        unexpected_keys = list(load_result.unexpected_keys)
        if missing_keys or unexpected_keys:
            raise RuntimeError(
                "checkpoint does not match the model: missing keys {}, "
                "unexpected keys {}".format(missing_keys, unexpected_keys)
            )

        self.optimizer.load_state_dict(checkpoint["optimizer"])
        # The stored value is None when the checkpoint was written without a
        # scaler, so test the value rather than just the key.
        if self.scaler and checkpoint.get("scaler") is not None:
            self.scaler.load_state_dict(checkpoint["scaler"])

        # "iters" holds the exclusive end of the last finished inner epoch,
        # i.e. exactly the iteration at which training has to continue.
        self.start_iters = checkpoint["iters"]
        logging.info("Resume checkpoint from {}".format(url_or_filename))

    @property
    def dataloaders(self) -> dict:
        """
        A property to get and create dataloaders by split just in need.

        If no train_dataset_ratio is provided, concatenate map-style datasets and
        chain wds.DataPipe datasets separately. Training set becomes a tuple
        (ConcatDataset, ChainDataset), both are optional but at least one of them is
        required. The resultant ConcatDataset and ChainDataset will be sampled evenly.

        If train_dataset_ratio is provided, create a MultiIterLoader to sample
        each dataset by ratios during training.

        Currently do not support multiple datasets for validation and test.

        Returns:
            dict: {split_name: (tuples of) dataloader}
        """
        if self._dataloaders is None:
            # reorganize datasets by split and concatenate/chain if necessary
            self.datasets = reorg_datasets_by_split(self.datasets)
            # to keep the same structure as return value of concat_datasets
            self.datasets = {
                k: v[0] if len(v) == 1 else v for k, v in self.datasets.items()
            }

            # print dataset statistics after concatenation/chaining
            for split_name, split_data in self.datasets.items():
                if isinstance(split_data, (tuple, list)):
                    # mixed wds.DataPipeline and torch.utils.data.Dataset;
                    # the exact-type test is deliberate so that subclasses
                    # which do provide __len__ are still counted.
                    num_records = sum(
                        len(d)
                        if type(d) not in (wds.DataPipeline, ChainDataset)
                        else 0
                        for d in split_data
                    )
                else:
                    try:
                        # a single map-style dataset
                        num_records = len(split_data)
                    except TypeError:
                        # a single wds.DataPipeline or ChainDataset
                        num_records = -1
                        logging.info(
                            "Only a single wds.DataPipeline dataset, no __len__ attribute."
                        )

                if num_records >= 0:
                    logging.info(
                        "Loaded {} records for {} split from the dataset.".format(
                            num_records, split_name
                        )
                    )

            # create dataloaders
            split_names = sorted(self.datasets.keys())

            datasets = [self.datasets[split] for split in split_names]
            is_trains = [split in self.train_splits for split in split_names]

            batch_sizes = [
                self.config.run_cfg.batch_size_train
                if split == "train"
                else self.config.run_cfg.batch_size_eval
                for split in split_names
            ]

            # One collate function per dataset, mirroring the (possibly
            # nested) structure of ``datasets``.
            collate_fns = []
            for dataset in datasets:
                if isinstance(dataset, (tuple, list)):
                    collate_fns.append([getattr(d, "collater", None) for d in dataset])
                else:
                    collate_fns.append(getattr(dataset, "collater", None))

            dataloaders = self.create_loaders(
                datasets=datasets,
                num_workers=self.config.run_cfg.num_workers,
                batch_sizes=batch_sizes,
                is_trains=is_trains,
                collate_fns=collate_fns,
            )

            self._dataloaders = dict(zip(split_names, dataloaders))

        return self._dataloaders