"""Main training script."""
import os
from pathlib import Path
import torch
from cliport import agents
from cliport.dataset import RavensDataset, RavensMultiTaskDataset, RavenMultiTaskDatasetBalance
import hydra
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
import IPython
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
import datetime
import time
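
# Example invocation with Hydra overrides (task/agent names and values below are
# illustrative assumptions, not taken from this repo's docs):
#   python train.py train.task=stack-block-pyramid train.agent=cliport \
#       train.n_demos=100 train.gpu=[0] train.log=False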
@hydra.main(config_path="./cfg", config_name='train', version_base="1.2")
def main(cfg):
    # Logger: fall back to no logging if W&B cannot be initialized.
    wandb_logger = None
    if cfg['train']['log']:
        try:
            wandb_logger = WandbLogger(name=cfg['tag'])
        except Exception:
            pass
    # Checkpoint saver
    hydra_dir = Path(os.getcwd())
    checkpoint_path = os.path.join(cfg['train']['train_dir'], 'checkpoints')
    last_checkpoint_path = os.path.join(checkpoint_path, 'last.ckpt')
    last_checkpoint = (last_checkpoint_path
                       if os.path.exists(last_checkpoint_path) and cfg['train']['load_from_last_ckpt']
                       else None)
    checkpoint_callback = [ModelCheckpoint(
        # monitor=cfg['wandb']['saver']['monitor'],
        dirpath=os.path.join(checkpoint_path, 'best'),
        save_top_k=1,
        every_n_epochs=3,
        save_last=True,
        # every_n_train_steps=100
    )]
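    # Note: ModelCheckpoint writes last.ckpt into its dirpath, i.e.
    # checkpoints/best/last.ckpt here, while the resume logic above looks for
    # checkpoints/last.ckpt; verify these two paths agree before relying on
    # auto-resume.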
    # Trainer: derive the epoch budget from the step budget, batch size, and demo count.
    max_epochs = cfg['train']['n_steps'] * cfg['train']['batch_size'] // cfg['train']['n_demos']
    if cfg['train']['training_step_scale'] > 0:
        # Scale training time depending on the task to ensure coverage.
        max_epochs = cfg['train']['training_step_scale']  # // cfg['train']['batch_size']
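    # Worked example (illustrative numbers only): n_steps=20000, batch_size=4,
    # n_demos=100 -> max_epochs = 20000 * 4 // 100 = 800 passes over the demos.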
    trainer = Trainer(
        accelerator='gpu',
        devices=cfg['train']['gpu'],
        fast_dev_run=cfg['debug'],
        logger=wandb_logger,
        callbacks=checkpoint_callback,
        max_epochs=max_epochs,
        # check_val_every_n_epoch=max_epochs // 50,
        # resume_from_checkpoint=last_checkpoint,
        sync_batchnorm=True,
        log_every_n_steps=30,
    )
    print(f"max epochs: {max_epochs}!")
    # Resume epoch and global step
    if last_checkpoint:
        print(f"Resuming: {last_checkpoint}")

    # Config
    data_dir = cfg['train']['data_dir']
    task = cfg['train']['task']
    agent_type = cfg['train']['agent']
    n_demos = cfg['train']['n_demos']
    n_val = cfg['train']['n_val']
    name = '{}-{}-{}'.format(task, agent_type, n_demos)
    # Datasets
    dataset_type = cfg['dataset']['type']
    if 'multi' in dataset_type:
        train_ds = RavensMultiTaskDataset(data_dir, cfg, group=task, mode='train',
                                          n_demos=n_demos, augment=True)
        val_ds = RavensMultiTaskDataset(data_dir, cfg, group=task, mode='val',
                                        n_demos=n_val, augment=False)
    elif 'weighted' in dataset_type:
        train_ds = RavenMultiTaskDatasetBalance(data_dir, cfg, group=task, mode='train',
                                                n_demos=n_demos, augment=True)
        val_ds = RavenMultiTaskDatasetBalance(data_dir, cfg, group=task, mode='val',
                                              n_demos=n_val, augment=False)
    else:
        train_ds = RavensDataset(os.path.join(data_dir, '{}-train'.format(task)), cfg,
                                 n_demos=n_demos, augment=True)
        val_ds = RavensDataset(os.path.join(data_dir, '{}-val'.format(task)), cfg,
                               n_demos=n_val, augment=False)
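    # Branch order matters: a type string containing both substrings (e.g. a
    # hypothetical 'multi-weighted') takes the 'multi' branch first, so the
    # balanced dataset is only used for purely 'weighted' types.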
    # Data loaders
    train_loader = DataLoader(train_ds, shuffle=True,
                              pin_memory=True,
                              batch_size=cfg['train']['batch_size'],
                              num_workers=1)
    val_loader = DataLoader(val_ds, shuffle=False,
                            num_workers=1,
                            batch_size=cfg['train']['batch_size'],
                            pin_memory=True)
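    # num_workers=1 keeps data loading nearly single-threaded; raising it (with
    # pin_memory=True when training on GPU) typically speeds up the input
    # pipeline at the cost of extra worker processes.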
    # Initialize agent
    agent = agents.names[agent_type](name, cfg, train_loader, val_loader)
    dt_string = datetime.datetime.now().strftime("%d_%m_%Y_%H:%M:%S")
    print("start time:", dt_string)
    start_time = time.time()

    # Main training loop
    trainer.fit(agent, ckpt_path=last_checkpoint)
    print("elapsed time (s):", time.time() - start_time)

if __name__ == '__main__':
    main()