ZebangCheng committed on
Commit 691ef95 · 1 Parent(s): e266a77
minigpt4/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import os
9
+ import sys
10
+
11
+ from omegaconf import OmegaConf
12
+
13
+ from minigpt4.common.registry import registry
14
+
15
+ from minigpt4.datasets.builders import *
16
+ from minigpt4.models import *
17
+ from minigpt4.processors import *
18
+ from minigpt4.tasks import *
19
+
20
+
21
+ root_dir = os.path.dirname(os.path.abspath(__file__))
22
+ default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
23
+
24
+ registry.register_path("library_root", root_dir)
25
+ repo_root = os.path.join(root_dir, "..")
26
+ registry.register_path("repo_root", repo_root)
27
+ cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
28
+ registry.register_path("cache_root", cache_root)
29
+
30
+ registry.register("MAX_INT", sys.maxsize)
31
+ registry.register("SPLIT_NAMES", ["train", "val", "test"])
minigpt4/runners/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from minigpt4.runners.runner_base import RunnerBase
9
+
10
+ __all__ = ["RunnerBase"]
minigpt4/runners/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (473 Bytes).
 
minigpt4/runners/__pycache__/runner_base.cpython-39.pyc ADDED
Binary file (17.6 kB).
 
minigpt4/runners/runner_base.py ADDED
@@ -0,0 +1,665 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import json
10
+ import logging
11
+ import os
12
+ import time
13
+ from pathlib import Path
14
+
15
+ import torch
16
+ import torch.distributed as dist
17
+ import webdataset as wds
18
+ from minigpt4.common.dist_utils import (
19
+ download_cached_file,
20
+ get_rank,
21
+ get_world_size,
22
+ is_main_process,
23
+ main_process,
24
+ )
25
+ from minigpt4.common.registry import registry
26
+ from minigpt4.common.utils import is_url
27
+ from minigpt4.datasets.data_utils import concat_datasets, reorg_datasets_by_split, ChainDataset
28
+ from minigpt4.datasets.datasets.dataloader_utils import (
29
+ IterLoader,
30
+ MultiIterLoader,
31
+ PrefetchLoader,
32
+ )
33
+ from torch.nn.parallel import DistributedDataParallel as DDP
34
+ from torch.utils.data import DataLoader, DistributedSampler
35
+
36
+
37
+ @registry.register_runner("runner_base")
38
+ class RunnerBase:
39
+ """
40
+ A runner class to train and evaluate a model given a task and datasets.
41
+
42
+ The runner uses PyTorch DistributedDataParallel by default. A future release
43
+ will support other distributed frameworks.
44
+ """
45
+
46
+ def __init__(self, cfg, task, model, datasets, job_id):
47
+ self.config = cfg
48
+ self.job_id = job_id
49
+
50
+ self.task = task
51
+ self.datasets = datasets
52
+
53
+ self._model = model
54
+
55
+ self._wrapped_model = None
56
+ self._device = None
57
+ self._optimizer = None
58
+ self._scaler = None
59
+ self._dataloaders = None
60
+ self._lr_sched = None
61
+
62
+ self.start_epoch = 0
63
+
64
+ # self.setup_seeds()
65
+ self.setup_output_dir()
66
+
67
+ @property
68
+ def device(self):
69
+ if self._device is None:
70
+ self._device = torch.device(self.config.run_cfg.device)
71
+
72
+ return self._device
73
+
74
+ @property
75
+ def use_distributed(self):
76
+ return self.config.run_cfg.distributed
77
+
78
+ @property
79
+ def model(self):
80
+ """
81
+ A property to get the DDP-wrapped model on the device.
82
+ """
83
+ # move model to device
84
+ if self._model.device != self.device:
85
+ self._model = self._model.to(self.device)
86
+
87
+ # distributed training wrapper
88
+ if self.use_distributed:
89
+ if self._wrapped_model is None:
90
+ self._wrapped_model = DDP(
91
+ self._model, device_ids=[self.config.run_cfg.gpu], find_unused_parameters=True
92
+ )
93
+ else:
94
+ self._wrapped_model = self._model
95
+
96
+ return self._wrapped_model
97
+
98
+ @property
99
+ def optimizer(self):
100
+ # TODO make optimizer class and configurations
101
+ if self._optimizer is None:
102
+ num_parameters = 0
103
+ p_wd, p_non_wd = [], []
104
+ attention = []
105
+ for n, p in self.model.named_parameters():
106
+ if not p.requires_grad:
107
+ continue # frozen weights
108
+ print(n)
109
+ if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n:
110
+ p_non_wd.append(p)
111
+ else:
112
+ p_wd.append(p)
113
+ num_parameters += p.data.nelement()
114
+
115
+ logging.info("number of trainable parameters: %d" % num_parameters)
116
+ optim_params = [
117
+ {
118
+ "params": p_wd,
119
+ "weight_decay": float(self.config.run_cfg.weight_decay),
120
+ "lr": float(self.config.run_cfg.init_lr)
121
+ },
122
+ {"params": p_non_wd, "weight_decay": 0, "lr": float(self.config.run_cfg.init_lr)},
123
+ ]
124
+
125
+ beta2 = self.config.run_cfg.get("beta2", 0.999)
126
+ self._optimizer = torch.optim.AdamW(
127
+ optim_params,
128
+ lr=float(self.config.run_cfg.init_lr),
129
+ weight_decay=float(self.config.run_cfg.weight_decay),
130
+ betas=(0.9, beta2),
131
+ )
132
+
133
+ return self._optimizer
134
+
135
+ @property
136
+ def scaler(self):
137
+ amp = self.config.run_cfg.get("amp", False)
138
+
139
+ if amp:
140
+ if self._scaler is None:
141
+ self._scaler = torch.cuda.amp.GradScaler()
142
+
143
+ return self._scaler
144
+
145
+ @property
146
+ def lr_scheduler(self):
147
+ """
148
+ A property that lazily creates the learning rate scheduler when first needed.
149
+ """
150
+ if self._lr_sched is None:
151
+ lr_sched_cls = registry.get_lr_scheduler_class(self.config.run_cfg.lr_sched)
152
+
153
+ # max_epoch = self.config.run_cfg.max_epoch
154
+ max_epoch = self.max_epoch
155
+ # min_lr = self.config.run_cfg.min_lr
156
+ min_lr = self.min_lr
157
+ # init_lr = self.config.run_cfg.init_lr
158
+ init_lr = self.init_lr
159
+
160
+ # optional parameters
161
+ decay_rate = self.config.run_cfg.get("lr_decay_rate", None)
162
+ warmup_start_lr = self.config.run_cfg.get("warmup_lr", -1)
163
+ warmup_steps = self.config.run_cfg.get("warmup_steps", 0)
164
+ iters_per_epoch = self.config.run_cfg.get("iters_per_epoch", None)
165
+
166
+ if iters_per_epoch is None:
167
+ try:
168
+ iters_per_epoch = len(self.dataloaders['train'])
169
+ except (AttributeError, TypeError):
170
+ iters_per_epoch = 10000
171
+
172
+ self._lr_sched = lr_sched_cls(
173
+ optimizer=self.optimizer,
174
+ max_epoch=max_epoch,
175
+ iters_per_epoch=iters_per_epoch,
176
+ min_lr=min_lr,
177
+ init_lr=init_lr,
178
+ decay_rate=decay_rate,
179
+ warmup_start_lr=warmup_start_lr,
180
+ warmup_steps=warmup_steps,
181
+ )
182
+
183
+ return self._lr_sched
184
+
185
+ @property
186
+ def dataloaders(self) -> dict:
187
+ """
188
+ A property that lazily creates dataloaders for each split when first needed.
189
+
190
+ If no train_dataset_ratio is provided, concatenate map-style datasets and
191
+ chain wds.DataPipe datasets separately. Training set becomes a tuple
192
+ (ConcatDataset, ChainDataset), both are optional but at least one of them is
193
+ required. The resultant ConcatDataset and ChainDataset will be sampled evenly.
194
+
195
+ If train_dataset_ratio is provided, create a MultiIterLoader to sample
196
+ each dataset by ratios during training.
197
+
198
+ Currently do not support multiple datasets for validation and test.
199
+
200
+ Returns:
201
+ dict: {split_name: (tuples of) dataloader}
202
+ """
203
+ if self._dataloaders is None:
204
+
205
+ # concatenate map-style datasets and chain wds.DataPipe datasets separately
206
+ # training set becomes a tuple (ConcatDataset, ChainDataset), both are
207
+ # optional but at least one of them is required. The resultant ConcatDataset
208
+ # and ChainDataset will be sampled evenly.
209
+ logging.info(
210
+ "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)."
211
+ )
212
+
213
+ batch_sizes = {dataset_name: getattr(self.config.datasets_cfg, dataset_name).batch_size
214
+ for dataset_name in self.datasets.keys()}
215
+ datasets, batch_sizes = reorg_datasets_by_split(self.datasets, batch_sizes)
216
+ self.datasets = datasets
217
+ # self.datasets = concat_datasets(datasets)
218
+
219
+ # print dataset statistics after concatenation/chaining
220
+ for split_name in self.datasets:
221
+ if isinstance(self.datasets[split_name], tuple) or isinstance(
222
+ self.datasets[split_name], list
223
+ ):
224
+ # mixed wds.DataPipeline and torch.utils.data.Dataset
225
+ num_records = sum(
226
+ [
227
+ len(d)
228
+ if not type(d) in [wds.DataPipeline, ChainDataset]
229
+ else 0
230
+ for d in self.datasets[split_name]
231
+ ]
232
+ )
233
+
234
+ else:
235
+ if hasattr(self.datasets[split_name], "__len__"):
236
+ # a single map-style dataset
237
+ num_records = len(self.datasets[split_name])
238
+ else:
239
+ # a single wds.DataPipeline
240
+ num_records = -1
241
+ logging.info(
242
+ "Only a single wds.DataPipeline dataset, no __len__ attribute."
243
+ )
244
+
245
+ if num_records >= 0:
246
+ logging.info(
247
+ "Loaded {} records for {} split from the dataset.".format(
248
+ num_records, split_name
249
+ )
250
+ )
251
+
252
+ # create dataloaders
253
+ split_names = sorted(self.datasets.keys())
254
+
255
+ datasets = [self.datasets[split] for split in split_names]
256
+ batch_sizes = [batch_sizes[split] for split in split_names]
257
+ is_trains = [split in self.train_splits for split in split_names]
258
+
259
+ print("batch sizes", batch_sizes)
260
+
261
+ collate_fns = []
262
+ for dataset in datasets:
263
+ if isinstance(dataset, tuple) or isinstance(dataset, list):
264
+ collate_fns.append([getattr(d, "collater", None) for d in dataset])
265
+ else:
266
+ collate_fns.append(getattr(dataset, "collater", None))
267
+
268
+ dataloaders = self.create_loaders(
269
+ datasets=datasets,
270
+ num_workers=self.config.run_cfg.num_workers,
271
+ batch_sizes=batch_sizes,
272
+ is_trains=is_trains,
273
+ collate_fns=collate_fns,
274
+ )
275
+
276
+ self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)}
277
+
278
+ return self._dataloaders
279
+
280
+ @property
281
+ def cuda_enabled(self):
282
+ return self.device.type == "cuda"
283
+
284
+ @property
285
+ def max_epoch(self):
286
+ return int(self.config.run_cfg.max_epoch)
287
+
288
+ @property
289
+ def log_freq(self):
290
+ log_freq = self.config.run_cfg.get("log_freq", 50)
291
+ return int(log_freq)
292
+
293
+ @property
294
+ def init_lr(self):
295
+ return float(self.config.run_cfg.init_lr)
296
+
297
+ @property
298
+ def min_lr(self):
299
+ return float(self.config.run_cfg.min_lr)
300
+
301
+ @property
302
+ def accum_grad_iters(self):
303
+ return int(self.config.run_cfg.get("accum_grad_iters", 1))
304
+
305
+ @property
306
+ def valid_splits(self):
307
+ valid_splits = self.config.run_cfg.get("valid_splits", [])
308
+
309
+ if len(valid_splits) == 0:
310
+ logging.info("No validation splits found.")
311
+
312
+ return valid_splits
313
+
314
+ @property
315
+ def test_splits(self):
316
+ test_splits = self.config.run_cfg.get("test_splits", [])
317
+
318
+ return test_splits
319
+
320
+ @property
321
+ def train_splits(self):
322
+ train_splits = self.config.run_cfg.get("train_splits", [])
323
+
324
+ if len(train_splits) == 0:
325
+ logging.info("Empty train splits.")
326
+
327
+ return train_splits
328
+
329
+ @property
330
+ def evaluate_only(self):
331
+ """
332
+ Set to True to skip training.
333
+ """
334
+ return self.config.run_cfg.evaluate
335
+
336
+ @property
337
+ def use_dist_eval_sampler(self):
338
+ return self.config.run_cfg.get("use_dist_eval_sampler", True)
339
+
340
+ @property
341
+ def resume_ckpt_path(self):
342
+ return self.config.run_cfg.get("resume_ckpt_path", None)
343
+
344
+ @property
345
+ def train_loader(self):
346
+ train_dataloader = self.dataloaders["train"]
347
+
348
+ return train_dataloader
349
+
350
+ def setup_output_dir(self):
351
+ lib_root = Path(registry.get_path("library_root"))
352
+
353
+ output_dir = lib_root / self.config.run_cfg.output_dir / self.job_id
354
+ # output_dir = lib_root / self.config.run_cfg.output_dir
355
+ result_dir = output_dir / "result"
356
+
357
+ output_dir.mkdir(parents=True, exist_ok=True)
358
+ result_dir.mkdir(parents=True, exist_ok=True)
359
+
360
+ registry.register_path("result_dir", str(result_dir))
361
+ registry.register_path("output_dir", str(output_dir))
362
+
363
+ self.result_dir = result_dir
364
+ self.output_dir = output_dir
365
+
366
+ def train(self):
367
+ start_time = time.time()
368
+ best_agg_metric = 0
369
+ best_epoch = 0
370
+
371
+ self.log_config()
372
+
373
+ # resume from checkpoint if specified
374
+ if not self.evaluate_only and self.resume_ckpt_path is not None:
375
+ self._load_checkpoint(self.resume_ckpt_path)
376
+
377
+ for cur_epoch in range(self.start_epoch, self.max_epoch):
378
+ # training phase
379
+ if not self.evaluate_only:
380
+ logging.info("Start training")
381
+ train_stats = self.train_epoch(cur_epoch)
382
+ self.log_stats(split_name="train", stats=train_stats)
383
+
384
+ # evaluation phase
385
+ if len(self.valid_splits) > 0:
386
+ for split_name in self.valid_splits:
387
+ logging.info("Evaluating on {}.".format(split_name))
388
+
389
+ val_log = self.eval_epoch(
390
+ split_name=split_name, cur_epoch=cur_epoch
391
+ )
392
+ if val_log is not None:
393
+ if is_main_process():
394
+ assert (
395
+ "agg_metrics" in val_log
396
+ ), "No agg_metrics found in validation log."
397
+
398
+ agg_metrics = val_log["agg_metrics"]
399
+ if agg_metrics > best_agg_metric and split_name == "val":
400
+ best_epoch, best_agg_metric = cur_epoch, agg_metrics
401
+
402
+ self._save_checkpoint(cur_epoch, is_best=True)
403
+
404
+ val_log.update({"best_epoch": best_epoch})
405
+ self.log_stats(val_log, split_name)
406
+
407
+ else:
408
+ # if no validation split is provided, we just save the checkpoint at the end of each epoch.
409
+ if not self.evaluate_only:
410
+ self._save_checkpoint(cur_epoch, is_best=False)
411
+
412
+ if self.evaluate_only:
413
+ break
414
+
415
+ if self.config.run_cfg.distributed:
416
+ dist.barrier()
417
+
418
+ # testing phase
419
+ test_epoch = "best" if len(self.valid_splits) > 0 else cur_epoch
420
+ self.evaluate(cur_epoch=test_epoch, skip_reload=self.evaluate_only)
421
+
422
+ total_time = time.time() - start_time
423
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
424
+ logging.info("Training time {}".format(total_time_str))
425
+
426
+ def evaluate(self, cur_epoch="best", skip_reload=False):
427
+ test_logs = dict()
428
+
429
+ if len(self.test_splits) > 0:
430
+ for split_name in self.test_splits:
431
+ test_logs[split_name] = self.eval_epoch(
432
+ split_name=split_name, cur_epoch=cur_epoch, skip_reload=skip_reload
433
+ )
434
+
435
+ return test_logs
436
+
437
+ def train_epoch(self, epoch):
438
+ # train
439
+ self.model.train()
440
+
441
+ return self.task.train_epoch(
442
+ epoch=epoch,
443
+ model=self.model,
444
+ data_loader=self.train_loader,
445
+ optimizer=self.optimizer,
446
+ scaler=self.scaler,
447
+ lr_scheduler=self.lr_scheduler,
448
+ cuda_enabled=self.cuda_enabled,
449
+ log_freq=self.log_freq,
450
+ accum_grad_iters=self.accum_grad_iters,
451
+ )
452
+
453
+ @torch.no_grad()
454
+ def eval_epoch(self, split_name, cur_epoch, skip_reload=False):
455
+ """
456
+ Evaluate the model on a given split.
457
+
458
+ Args:
459
+ split_name (str): name of the split to evaluate on.
460
+ cur_epoch (int): current epoch.
461
+ skip_reload (bool): whether to skip reloading the best checkpoint.
462
+ During training, we will reload the best checkpoint for validation.
463
+ During testing, we use the provided weights and skip reloading the best checkpoint.
464
+ """
465
+ data_loader = self.dataloaders.get(split_name, None)
466
+ assert data_loader, "data_loader for split {} is None.".format(split_name)
467
+
468
+ # TODO In validation, you need to compute loss as well as metrics
469
+ # TODO consider moving to model.before_evaluation()
470
+ model = self.unwrap_dist_model(self.model)
471
+ if not skip_reload and cur_epoch == "best":
472
+ model = self._reload_best_model(model)
473
+ model.eval()
474
+
475
+ self.task.before_evaluation(
476
+ model=model,
477
+ dataset=self.datasets[split_name],
478
+ )
479
+ results = self.task.evaluation(model, data_loader)
480
+
481
+ if results is not None:
482
+ return self.task.after_evaluation(
483
+ val_result=results,
484
+ split_name=split_name,
485
+ epoch=cur_epoch,
486
+ )
487
+
488
+ def unwrap_dist_model(self, model):
489
+ if self.use_distributed:
490
+ return model.module
491
+ else:
492
+ return model
493
+
494
+ def create_loaders(
495
+ self,
496
+ datasets,
497
+ num_workers,
498
+ batch_sizes,
499
+ is_trains,
500
+ collate_fns,
501
+ dataset_ratios=None,
502
+ ):
503
+ """
504
+ Create dataloaders for training and validation.
505
+ """
506
+
507
+ def _create_loader(dataset, num_workers, bsz, is_train, collate_fn):
508
+ # create a single dataloader for each split
509
+ if isinstance(dataset, ChainDataset) or isinstance(
510
+ dataset, wds.DataPipeline
511
+ ):
512
+ # wds.WebDataset instances are chained together
513
+ # webdataset.DataPipeline has its own sampler and collate_fn
514
+ loader = iter(
515
+ DataLoader(
516
+ dataset,
517
+ batch_size=bsz,
518
+ num_workers=num_workers,
519
+ pin_memory=True,
520
+ )
521
+ )
522
+ else:
523
+ # map-style datasets are concatenated together
524
+ # setup distributed sampler
525
+
526
+ if self.use_distributed:
527
+ sampler = DistributedSampler(
528
+ dataset,
529
+ shuffle=is_train,
530
+ num_replicas=get_world_size(),
531
+ rank=get_rank(),
532
+ )
533
+ if not self.use_dist_eval_sampler:
534
+ # e.g. retrieval evaluation
535
+ sampler = sampler if is_train else None
536
+ else:
537
+ sampler = None
538
+
539
+ loader = DataLoader(
540
+ dataset,
541
+ batch_size=bsz,
542
+ num_workers=num_workers,
543
+ pin_memory=True,
544
+ sampler=sampler,
545
+ shuffle=sampler is None and is_train,
546
+ collate_fn=collate_fn,
547
+ drop_last=True if is_train else False,
548
+ )
549
+ loader = PrefetchLoader(loader)
550
+
551
+ if is_train:
552
+ loader = IterLoader(loader, use_distributed=self.use_distributed)
553
+
554
+ return loader
555
+
556
+ loaders = []
557
+
558
+ for dataset, bsz, is_train, collate_fn in zip(
559
+ datasets, batch_sizes, is_trains, collate_fns
560
+ ):
561
+ if isinstance(dataset, list) or isinstance(dataset, tuple):
562
+ if hasattr(dataset[0], 'sample_ratio') and dataset_ratios is None:
563
+ dataset_ratios = [d.sample_ratio for d in dataset]
564
+ loader = MultiIterLoader(
565
+ loaders=[
566
+ _create_loader(d, num_workers, bsz[i], is_train, collate_fn[i])
567
+ for i, d in enumerate(dataset)
568
+ ],
569
+ ratios=dataset_ratios,
570
+ )
571
+ else:
572
+ loader = _create_loader(dataset, num_workers, bsz, is_train, collate_fn)
573
+
574
+ loaders.append(loader)
575
+
576
+ return loaders
577
+
578
+ @main_process
579
+ def _save_checkpoint(self, cur_epoch, is_best=False):
580
+ """
581
+ Save the checkpoint at the current epoch.
582
+ """
583
+
584
+ model_no_ddp = self.unwrap_dist_model(self.model)
585
+ param_grad_dic = {
586
+ k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()
587
+ }
588
+ state_dict = model_no_ddp.state_dict()
589
+ for k in list(state_dict.keys()):
590
+ if k in param_grad_dic.keys() and not param_grad_dic[k]:
591
+ # delete parameters that do not require gradient
592
+ del state_dict[k]
593
+
594
+ save_obj = {
595
+ "model": state_dict,
596
+ "optimizer": self.optimizer.state_dict(),
597
+ "config": self.config.to_dict(),
598
+ "scaler": self.scaler.state_dict() if self.scaler else None,
599
+ "epoch": cur_epoch,
600
+ }
601
+ save_to = os.path.join(
602
+ self.output_dir,
603
+ "checkpoint_{}.pth".format("best" if is_best else cur_epoch),
604
+ )
605
+ logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to))
606
+ torch.save(save_obj, save_to)
607
+
608
+ def _reload_best_model(self, model):
609
+ """
610
+ Load the best checkpoint for evaluation.
611
+ """
612
+ checkpoint_path = os.path.join(self.output_dir, "checkpoint_best.pth")
613
+
614
+ logging.info("Loading checkpoint from {}.".format(checkpoint_path))
615
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
616
+ try:
617
+ model.load_state_dict(checkpoint["model"])
618
+ except RuntimeError as e:
619
+ logging.warning(
620
+ """
621
+ Key mismatch when loading checkpoint. This is expected if only part of the model is saved.
622
+ Trying to load the model with strict=False.
623
+ """
624
+ )
625
+ model.load_state_dict(checkpoint["model"], strict=False)
626
+ return model
627
+
628
+ def _load_checkpoint(self, url_or_filename):
629
+ """
630
+ Resume from a checkpoint.
631
+ """
632
+ if is_url(url_or_filename):
633
+ cached_file = download_cached_file(
634
+ url_or_filename, check_hash=False, progress=True
635
+ )
636
+ checkpoint = torch.load(cached_file, map_location=self.device)
637
+ elif os.path.isfile(url_or_filename):
638
+ checkpoint = torch.load(url_or_filename, map_location=self.device)
639
+ else:
640
+ raise RuntimeError("checkpoint url or path is invalid")
641
+
642
+ state_dict = checkpoint["model"]
643
+ message = self.unwrap_dist_model(self.model).load_state_dict(state_dict, strict=False)
644
+
645
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
646
+ if self.scaler and "scaler" in checkpoint:
647
+ self.scaler.load_state_dict(checkpoint["scaler"])
648
+
649
+ self.start_epoch = checkpoint["epoch"] + 1
650
+ print("resume the checkpoint")
651
+ logging.info("Resume checkpoint from {}".format(url_or_filename))
652
+
653
+ @main_process
654
+ def log_stats(self, stats, split_name):
655
+ if isinstance(stats, dict):
656
+ log_stats = {**{f"{split_name}_{k}": v for k, v in stats.items()}}
657
+ with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
658
+ f.write(json.dumps(log_stats) + "\n")
659
+ elif isinstance(stats, list):
660
+ pass
661
+
662
+ @main_process
663
+ def log_config(self):
664
+ with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
665
+ f.write(json.dumps(self.config.to_dict(), indent=4) + "\n")
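Taken together, `RunnerBase` expects a parsed config, a task, a model and a dict of datasets, then drives training, periodic evaluation and checkpointing from `train()`. A hedged sketch of the usual driver flow; `Config` and the argument parsing come from the surrounding MiniGPT-4 code base and are not part of this commit:

```python
# Illustrative training driver (a sketch, not the committed entry point).
from minigpt4.common.config import Config        # assumed, as in MiniGPT-4's train.py
from minigpt4.common.registry import registry
from minigpt4.tasks import setup_task

cfg = Config(args)                               # args parsed by the training script (assumed)
task = setup_task(cfg)                           # resolves cfg.run_cfg.task via the registry
datasets = task.build_datasets(cfg)
model = task.build_model(cfg)

runner_cls = registry.get_runner_class(cfg.run_cfg.get("runner", "runner_base"))
runner = runner_cls(cfg=cfg, task=task, model=model, datasets=datasets, job_id="local_run")
runner.train()   # checkpoints and log.txt land under <library_root>/<run_cfg.output_dir>/<job_id>
```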
minigpt4/tasks/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from minigpt4.common.registry import registry
9
+ from minigpt4.tasks.base_task import BaseTask
10
+ from minigpt4.tasks.image_text_pretrain import ImageTextPretrainTask
11
+
12
+
13
+ def setup_task(cfg):
14
+ assert "task" in cfg.run_cfg, "Task name must be provided."
15
+
16
+ task_name = cfg.run_cfg.task
17
+ task = registry.get_task_class(task_name).setup_task(cfg=cfg)
18
+ assert task is not None, "Task {} not properly registered.".format(task_name)
19
+
20
+ return task
21
+
22
+
23
+ __all__ = [
24
+ "BaseTask",
25
+ "ImageTextPretrainTask",
26
+ ]
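`setup_task` only needs `run_cfg.task` to name a registered task class; everything else on the config is consumed later by the task itself. A minimal, hedged illustration using OmegaConf (already a dependency of the package `__init__` above):

```python
# Minimal config shape accepted by setup_task; any fields beyond
# run_cfg.task are whatever the chosen task later expects.
from omegaconf import OmegaConf
from minigpt4.tasks import setup_task

cfg = OmegaConf.create({"run_cfg": {"task": "image_text_pretrain"}})
task = setup_task(cfg)   # -> ImageTextPretrainTask, registered further down in this commit
```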
minigpt4/tasks/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (931 Bytes).
 
minigpt4/tasks/__pycache__/base_task.cpython-39.pyc ADDED
Binary file (7.54 kB).
 
minigpt4/tasks/__pycache__/image_text_pretrain.cpython-39.pyc ADDED
Binary file (1.12 kB).
 
minigpt4/tasks/base_task.py ADDED
@@ -0,0 +1,315 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import os
10
+
11
+ import torch
12
+ import torch.distributed as dist
13
+ from minigpt4.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized
14
+ from minigpt4.common.logger import MetricLogger, SmoothedValue
15
+ from minigpt4.common.registry import registry
16
+ from minigpt4.datasets.data_utils import prepare_sample
17
+ import wandb
18
+
19
+ class BaseTask:
20
+ def __init__(self, **kwargs):
21
+ super().__init__()
22
+
23
+ self.inst_id_key = "instance_id"
24
+ self.cfg = ""
25
+
26
+ @classmethod
27
+ def setup_task(cls, **kwargs):
28
+ return cls()
29
+
30
+ def build_model(self, cfg):
31
+ self.cfg = cfg
32
+ model_config = cfg.model_cfg
33
+
34
+ model_cls = registry.get_model_class(model_config.arch)
35
+ return model_cls.from_config(model_config)
36
+
37
+ def build_datasets(self, cfg):
38
+ """
39
+ Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'.
40
+ Datasets and annotations are downloaded automatically if they do not exist.
41
+
42
+ Args:
43
+ cfg (common.config.Config): configuration object; its datasets_cfg node lists the datasets to build.
44
+
45
+ Returns:
46
+ dict: Dictionary of torch.utils.data.Dataset objects by split.
47
+ """
48
+
49
+ datasets = dict()
50
+
51
+ datasets_config = cfg.datasets_cfg
52
+
53
+ assert len(datasets_config) > 0, "At least one dataset has to be specified."
54
+
55
+ for name in datasets_config:
56
+ dataset_config = datasets_config[name]
57
+
58
+ builder = registry.get_builder_class(name)(dataset_config)
59
+ dataset = builder.build_datasets()
60
+
61
+ dataset['train'].name = name
62
+ if 'sample_ratio' in dataset_config:
63
+ dataset['train'].sample_ratio = dataset_config.sample_ratio
64
+
65
+ datasets[name] = dataset
66
+
67
+ return datasets
68
+
69
+ def train_step(self, model, samples):
70
+ outputs = model(samples)
71
+ # loss = outputs["loss"] + outputs["emos_loss"]
72
+ loss = outputs["emos_loss"]
73
+ # print(outputs["loss"], outputs["emos_loss"], torch.argmax(outputs['emos_pred'], dim=1), outputs["emotion"])
74
+
75
+ return loss
76
+
77
+ def valid_step(self, model, samples):
78
+ raise NotImplementedError
79
+
80
+ def before_evaluation(self, model, dataset, **kwargs):
81
+ model.before_evaluation(dataset=dataset, task_type=type(self))
82
+
83
+ def after_evaluation(self, **kwargs):
84
+ pass
85
+
86
+ def inference_step(self):
87
+ raise NotImplementedError
88
+
89
+ def evaluation(self, model, data_loader, cuda_enabled=True):
90
+ metric_logger = MetricLogger(delimiter=" ")
91
+ header = "Evaluation"
92
+ # TODO make it configurable
93
+ print_freq = 10
94
+
95
+ results = []
96
+
97
+ for samples in metric_logger.log_every(data_loader, print_freq, header):
98
+ samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
99
+
100
+ eval_output = self.valid_step(model=model, samples=samples)
101
+ results.extend(eval_output)
102
+
103
+ if is_dist_avail_and_initialized():
104
+ dist.barrier()
105
+
106
+ return results
107
+
108
+ def train_epoch(
109
+ self,
110
+ epoch,
111
+ model,
112
+ data_loader,
113
+ optimizer,
114
+ lr_scheduler,
115
+ scaler=None,
116
+ cuda_enabled=False,
117
+ log_freq=50,
118
+ accum_grad_iters=1,
119
+ ):
120
+ return self._train_inner_loop(
121
+ epoch=epoch,
122
+ iters_per_epoch=lr_scheduler.iters_per_epoch,
123
+ model=model,
124
+ data_loader=data_loader,
125
+ optimizer=optimizer,
126
+ scaler=scaler,
127
+ lr_scheduler=lr_scheduler,
128
+ log_freq=log_freq,
129
+ cuda_enabled=cuda_enabled,
130
+ accum_grad_iters=accum_grad_iters,
131
+ )
132
+
133
+ def train_iters(
134
+ self,
135
+ epoch,
136
+ start_iters,
137
+ iters_per_inner_epoch,
138
+ model,
139
+ data_loader,
140
+ optimizer,
141
+ lr_scheduler,
142
+ scaler=None,
143
+ cuda_enabled=False,
144
+ log_freq=50,
145
+ accum_grad_iters=1,
146
+ ):
147
+ return self._train_inner_loop(
148
+ epoch=epoch,
149
+ start_iters=start_iters,
150
+ iters_per_epoch=iters_per_inner_epoch,
151
+ model=model,
152
+ data_loader=data_loader,
153
+ optimizer=optimizer,
154
+ scaler=scaler,
155
+ lr_scheduler=lr_scheduler,
156
+ log_freq=log_freq,
157
+ cuda_enabled=cuda_enabled,
158
+ accum_grad_iters=accum_grad_iters,
159
+ )
160
+
161
+ def _train_inner_loop(
162
+ self,
163
+ epoch,
164
+ iters_per_epoch,
165
+ model,
166
+ data_loader,
167
+ optimizer,
168
+ lr_scheduler,
169
+ scaler=None,
170
+ start_iters=None,
171
+ log_freq=50,
172
+ cuda_enabled=False,
173
+ accum_grad_iters=1,
174
+ ):
175
+ """
176
+ An inner training loop compatible with both epoch-based and iter-based training.
177
+
178
+ When using epoch-based, training stops after one epoch; when using iter-based,
179
+ training stops after #iters_per_epoch iterations.
180
+ """
181
+ use_amp = scaler is not None
182
+
183
+ if not hasattr(data_loader, "__next__"):
184
+ # convert to iterator if not already
185
+ data_loader = iter(data_loader)
186
+
187
+ metric_logger = MetricLogger(delimiter=" ")
188
+ metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
189
+ metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))
190
+
191
+ # if iter-based runner, schedule lr based on inner epoch.
192
+ logging.info(
193
+ "Start training epoch {}, {} iters per inner epoch.".format(
194
+ epoch, iters_per_epoch
195
+ )
196
+ )
197
+ header = "Train: data epoch: [{}]".format(epoch)
198
+ if start_iters is None:
199
+ # epoch-based runner
200
+ inner_epoch = epoch
201
+ else:
202
+ # In iter-based runner, we schedule the learning rate based on iterations.
203
+ inner_epoch = start_iters // iters_per_epoch
204
+ header = header + "; inner epoch [{}]".format(inner_epoch)
205
+
206
+ image_list = []
207
+ caption_list = []
208
+ for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):
209
+ # if using iter-based runner, we stop after iters_per_epoch iterations.
210
+ if i >= iters_per_epoch:
211
+ break
212
+
213
+ samples = next(data_loader)
214
+ image_list.append(samples['image_id'])
215
+ caption_list.append(samples['answer'])
216
+
217
+ samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
218
+ samples.update(
219
+ {
220
+ "epoch": inner_epoch,
221
+ "num_iters_per_epoch": iters_per_epoch,
222
+ "iters": i,
223
+ }
224
+ )
225
+
226
+ lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i)
227
+
228
+ with torch.cuda.amp.autocast(enabled=use_amp):
229
+ loss = self.train_step(model=model, samples=samples)
230
+
231
+ # after_train_step()
232
+ if use_amp:
233
+ scaler.scale(loss).backward()
234
+ else:
235
+ loss.backward()
236
+
237
+ # update gradients every accum_grad_iters iterations
238
+ if (i + 1) % accum_grad_iters == 0:
239
+ if use_amp:
240
+ scaler.step(optimizer)
241
+ scaler.update()
242
+ else:
243
+ optimizer.step()
244
+ optimizer.zero_grad()
245
+ # if self.cfg.wandb_log:
246
+ if self.cfg.run_cfg.wandb_log:
247
+ wandb.log({"epoch": inner_epoch, "loss": loss})
248
+ metric_logger.update(loss=loss.item())
249
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
250
+
251
+ # Print the learning rate for attention parameters
252
+ for param_group in optimizer.param_groups:
253
+ if "attention" in param_group.get("params", []):
254
+ print("Attention LR:", param_group["lr"])
255
+
256
+ # save random samples' name
257
+ save_dir = "/home/user/project/Emotion-LLaMA/checkpoints/run_samples"
258
+ save_to = os.path.join(
259
+ save_dir,
260
+ "epoch_{}.txt".format(epoch),
261
+ )
262
+ with open(save_to, 'w') as file:
263
+ for i in range(len(image_list)):
264
+ name = image_list[i]
265
+ caption = caption_list[i]
266
+ file.write(name[0] + " " + caption[0] + '\n')
267
+
268
+ # after train_epoch()
269
+ # gather the stats from all processes
270
+ metric_logger.synchronize_between_processes()
271
+ logging.info("Averaged stats: " + str(metric_logger.global_avg()))
272
+ return {
273
+ k: "{:.6f}".format(meter.global_avg)
274
+ for k, meter in metric_logger.meters.items()
275
+ }
276
+
277
+ @staticmethod
278
+ def save_result(result, result_dir, filename, remove_duplicate=""):
279
+ import json
280
+
281
+ result_file = os.path.join(
282
+ result_dir, "%s_rank%d.json" % (filename, get_rank())
283
+ )
284
+ final_result_file = os.path.join(result_dir, "%s.json" % filename)
285
+
286
+ json.dump(result, open(result_file, "w"))
287
+
288
+ if is_dist_avail_and_initialized():
289
+ dist.barrier()
290
+
291
+ if is_main_process():
292
+ logging.warning("rank %d starts merging results." % get_rank())
293
+ # combine results from all processes
294
+ result = []
295
+
296
+ for rank in range(get_world_size()):
297
+ result_file = os.path.join(
298
+ result_dir, "%s_rank%d.json" % (filename, rank)
299
+ )
300
+ res = json.load(open(result_file, "r"))
301
+ result += res
302
+
303
+ if remove_duplicate:
304
+ result_new = []
305
+ id_list = []
306
+ for res in result:
307
+ if res[remove_duplicate] not in id_list:
308
+ id_list.append(res[remove_duplicate])
309
+ result_new.append(res)
310
+ result = result_new
311
+
312
+ json.dump(result, open(final_result_file, "w"))
313
+ print("result file saved to %s" % final_result_file)
314
+
315
+ return final_result_file
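`save_result` encodes the distributed result-gathering contract: every rank writes `<filename>_rank<r>.json`, a barrier synchronises the ranks, and the main process merges the shards (optionally de-duplicating on a key) into `<filename>.json`. A hedged single-process example; the record fields are illustrative:

```python
# Sketch of save_result's contract with made-up records.
from minigpt4.common.registry import registry
from minigpt4.tasks.base_task import BaseTask

results = [{"image_id": "sample_0001", "caption": "a person smiling"}]  # illustrative
final_file = BaseTask.save_result(
    result=results,
    result_dir=registry.get_path("result_dir"),  # registered by RunnerBase.setup_output_dir
    filename="val_epoch0",
    remove_duplicate="image_id",
)
# final_file == <result_dir>/val_epoch0.json
```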
minigpt4/tasks/image_text_pretrain.py ADDED
@@ -0,0 +1,19 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from minigpt4.common.registry import registry
9
+ from minigpt4.tasks.base_task import BaseTask
10
+
11
+
12
+ @registry.register_task("image_text_pretrain")
13
+ class ImageTextPretrainTask(BaseTask):
14
+ def __init__(self):
15
+ super().__init__()
16
+
17
+ def evaluation(self, model, data_loader, cuda_enabled=True):
18
+ print("-----evaluation----")
19
+ # pass
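The pretraining task relies entirely on `BaseTask` for its training loop and only stubs out `evaluation`. New tasks are added by following the same decorator pattern; a hedged sketch (the name `video_emotion` is purely illustrative):

```python
# Illustrative new task following the registration pattern above.
from minigpt4.common.registry import registry
from minigpt4.tasks.base_task import BaseTask


@registry.register_task("video_emotion")
class VideoEmotionTask(BaseTask):
    def valid_step(self, model, samples):
        # Per-batch evaluation logic would go here; return a list of records.
        return []

    def evaluation(self, model, data_loader, cuda_enabled=True):
        # Delegate to BaseTask.evaluation, which iterates the loader and
        # extends the result list with each valid_step output.
        return super().evaluation(model, data_loader, cuda_enabled)
```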