import argparse | |
import torch | |
from common import flops_calculation_function | |
if __name__ == '__main__': | |
parser = argparse.ArgumentParser() | |
parser.add_argument( | |
"--model-path", | |
type=str, | |
help="Path to models checkpoint (.pth file).", | |
) | |
args = parser.parse_args() | |
checkpoint = torch.load(args.model_path, map_location="cpu") | |
model = checkpoint["model"] | |
flops = flops_calculation_function(model, torch.ones(1, 3, 480, 480)) | |
print(f"MMACs = {flops}") |