"""Train a SegFormer model for sidewalk semantic segmentation.

Minimal command:
    python training_loop.py --hub_dir "segments/sidewalk-semantic" --push_to_hub

Maximal command:
    python training_loop.py --hub_dir "segments/sidewalk-semantic" --batch_size 32 --learning_rate 6e-5 --model_flavor 0 --seed 42 --split train --push_to_hub
"""
import json

import torch
from pytorch_lightning import Trainer, callbacks, seed_everything
from pytorch_lightning.loggers import WandbLogger
from transformers import AutoConfig, SegformerForSemanticSegmentation, SegformerFeatureExtractor

from dataloader import SidewalkSegmentationDataLoader
from model import SidewalkSegmentationModel


def _load_id2label(path: str = "id2label.json") -> dict:
    """Read the class-id -> label-name map from *path*.

    JSON object keys are always strings, so they are converted to ``int`` here.
    """
    with open(path, "r", encoding="utf-8") as file:
        raw_map = json.load(file)
    return {int(key): value for key, value in raw_map.items()}


def _load_segformer_from_lightning_checkpoint(checkpoint_path: str, config) -> SegformerForSemanticSegmentation:
    """Rebuild a plain SegFormer model from a PyTorch Lightning ``.ckpt`` file.

    A Lightning checkpoint stores the module weights under ``"state_dict"``.
    Each key is prefixed with the attribute path under which the
    LightningModule holds its submodules. The prefix is found by matching one
    known SegFormer parameter name, stripped, and the result is loaded strictly
    so that any mismatch fails loudly.

    Raises:
        KeyError: if no checkpoint key matches the SegFormer parameter names.
    """
    hf_model = SegformerForSemanticSegmentation(config)
    target_keys = list(hf_model.state_dict().keys())

    # The checkpoint was written by this script, so unpickling it is trusted.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    lightning_state = checkpoint["state_dict"]

    anchor = target_keys[0]
    prefix = next(
        (
            key[: len(key) - len(anchor)]
            for key in lightning_state
            if key == anchor or key.endswith("." + anchor)
        ),
        None,
    )
    if prefix is None:
        raise KeyError(
            f"Could not find SegFormer weights (e.g. '{anchor}') in checkpoint {checkpoint_path!r}"
        )

    hf_state = {
        key[len(prefix):]: tensor
        for key, tensor in lightning_state.items()
        if key.startswith(prefix)
    }
    hf_model.load_state_dict(hf_state)
    return hf_model


def _push_best_model_to_hub(checkpoint_path: str, model_flavor: int, id2label: dict) -> None:
    """Push the config and weights of the best checkpoint to the Hugging Face Hub.

    Raises:
        RuntimeError: if training never saved a checkpoint (empty path).
    """
    if not checkpoint_path:
        raise RuntimeError("No checkpoint was saved during training; nothing to push to the Hub.")

    config = AutoConfig.from_pretrained(f"nvidia/mit-b{model_flavor}")
    config.num_labels = len(id2label)
    config.id2label = id2label
    # Build from the int-keyed map so label2id values agree with id2label keys.
    config.label2id = {label: class_id for class_id, label in id2label.items()}

    # Load the weights before pushing anything, so a bad checkpoint does not
    # leave a config-only commit on the Hub.
    hf_model = _load_segformer_from_lightning_checkpoint(checkpoint_path, config)

    repo_path = f"flavors/b{model_flavor}"
    repo_url = f"https://huggingface.co/ChainYo/segformer-{model_flavor}-sidewalk"
    config.push_to_hub(repo_path, repo_url=repo_url)
    hf_model.push_to_hub(repo_path, repo_url=repo_url)


def main(
    hub_dir: str,
    batch_size: int = 32,
    learning_rate: float = 6e-5,
    model_flavor: int = 0,
    seed: int = 42,
    split: str = "train",
    push_to_hub: bool = False,
):
    """Train the sidewalk segmentation model and optionally publish it.

    Args:
        hub_dir: Dataset identifier passed to ``SidewalkSegmentationDataLoader``.
        batch_size: Training batch size.
        learning_rate: Optimizer learning rate passed to the model.
        model_flavor: SegFormer size index; used to select ``nvidia/mit-b{model_flavor}``.
        seed: Global random seed.
        split: Dataset split passed to the data loader.
        push_to_hub: If True, push the best checkpoint (by ``val_mean_iou``) to the Hub.
    """
    seed_everything(seed)

    logger = WandbLogger(project="sidewalk-segmentation")
    gpu_value = 1 if torch.cuda.is_available() else 0

    id2label = _load_id2label()
    num_labels = len(id2label)

    model = SidewalkSegmentationModel(
        num_labels=num_labels,
        id2label=id2label,
        model_flavor=model_flavor,
        learning_rate=learning_rate,
    )
    data_module = SidewalkSegmentationDataLoader(
        hub_dir=hub_dir,
        batch_size=batch_size,
        split=split,
    )
    data_module.setup()

    checkpoint_callback = callbacks.ModelCheckpoint(
        dirpath="checkpoints",
        save_top_k=1,
        verbose=True,
        monitor="val_mean_iou",
        mode="max",
    )
    early_stopping_callback = callbacks.EarlyStopping(
        monitor="val_mean_iou",
        patience=5,
        verbose=True,
        mode="max",
    )

    trainer = Trainer(
        max_epochs=200,
        progress_bar_refresh_rate=10,
        gpus=gpu_value,
        logger=logger,
        callbacks=[checkpoint_callback, early_stopping_callback],
        deterministic=False,
    )
    trainer.fit(model, data_module)

    if push_to_hub:
        # ModelCheckpoint exposes the best checkpoint as `best_model_path`.
        _push_best_model_to_hub(checkpoint_callback.best_model_path, model_flavor, id2label)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--hub_dir", type=str, required=True)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--learning_rate", type=float, default=6e-5)
    parser.add_argument("--model_flavor", type=int, default=0)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--split", type=str, default="train")
    parser.add_argument("--push_to_hub", action="store_true")
    args = parser.parse_args()

    main(
        hub_dir=args.hub_dir,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        model_flavor=args.model_flavor,
        seed=args.seed,
        split=args.split,
        push_to_hub=args.push_to_hub,
    )