"""Launch Mask2Former training or evaluation with BBNet model code on the path.

Thin wrapper around detectron2's ``launch``. It does three things:

1. Prepares process environment variables.
2. Extends ``sys.path`` so the Mask2Former ``train_net`` module and the BBNet
   modules are importable.
3. Either evaluates a checkpoint (``args.eval_only``) or trains.
"""
import os

# Configure the environment BEFORE importing torch. OpenMP then sees the thread
# count in this process too; processes spawned by ``launch`` inherit it as well.
# GPU selection is left to the caller, e.g.:
#     CUDA_VISIBLE_DEVICES=2,3,4,5,6 python <this script> ...
os.environ['OMP_NUM_THREADS'] = '1'
# Respect a dataset root the caller already exported; otherwise use the default.
os.environ.setdefault('DETECTRON2_DATASETS', '/ccn2/u/honglinc/datasets')

import argparse  # noqa: E402,F401  (not used below; kept for compatibility)
import sys  # noqa: E402
import warnings  # noqa: E402

import torch  # noqa: E402

# Suppresses ALL Python warnings process-wide.
warnings.filterwarnings("ignore")
# NOTE(review): presumably chosen to avoid running out of file descriptors when
# many DataLoader workers share tensors — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')

# Locations of the external checkouts; each can be overridden by an
# environment variable of the same name.
MASK2FORMER_PATH = os.environ.get('MASK2FORMER_PATH', '/ccn2/u/honglinc/Mask2Former')
BBNET_PATH = os.environ.get('BBNET_PATH', '/home/honglinc/BBNet')

sys.path.append(os.path.join(BBNET_PATH, 'bbnet/models/VideoMAE-main/'))
sys.path.append(BBNET_PATH)
sys.path.append(MASK2FORMER_PATH)

# BBNet imports. Neither name is referenced below. They are presumably imported
# for their registration side effects (model / meta-architecture registries).
# TODO confirm before removing.
import modeling_pretrain as vmae_tranformers  # noqa: E402,F401
from evaluate_segmentation_readout_helper_v2 import CWMSegmentPredictorV2  # noqa: E402,F401

import detectron2.utils.comm as comm  # noqa: E402
from detectron2.evaluation import verify_results  # noqa: E402
from train_net import setup, Trainer, DetectionCheckpointer  # noqa: E402
from detectron2.engine import default_argument_parser, launch  # noqa: E402


def main(args):
    """Evaluate or train the model described by ``args``.

    Args:
        args: Namespace produced by detectron2's ``default_argument_parser``.
            Reads ``eval_only`` and ``resume``; ``setup(args)`` builds the
            config from it.

    Returns:
        If ``args.eval_only``: the results dict from ``Trainer.test``. When
        ``cfg.TEST.AUG.ENABLED`` is set, it also includes the results of
        ``Trainer.test_with_TTA``.
        Otherwise: whatever ``trainer.train()`` returns.
    """
    cfg = setup(args)

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        # Only the main process verifies the results.
        if comm.is_main_process():
            verify_results(cfg, res)
        return res

    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )