import argparse
import inspect

import torch

from common import flops_calculation_function
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this script.

    Options:
        --model-path  (required) checkpoint file to load.
        --input-size  HEIGHT WIDTH of the dummy input (default 480 480).
        --channels    channel count of the dummy input (default 3).
    """
    parser = argparse.ArgumentParser(
        description="Print the cost reported by flops_calculation_function "
        "for a model stored in a checkpoint.",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,  # without it torch.load(None) fails with an obscure error
        help="Path to models checkpoint (.pth file).",
    )
    parser.add_argument(
        "--input-size",
        type=int,
        nargs=2,
        default=[480, 480],
        metavar=("HEIGHT", "WIDTH"),
        help="Spatial size of the dummy input tensor (default: 480 480).",
    )
    parser.add_argument(
        "--channels",
        type=int,
        default=3,
        help="Number of channels of the dummy input tensor (default: 3).",
    )
    return parser


def load_checkpoint(path: str):
    """Load a checkpoint file onto the CPU and return the unpickled object.

    The checkpoint is expected to store a full model object (not only a
    state_dict), so full unpickling is required. PyTorch >= 2.6 defaults to
    ``weights_only=True``, which rejects such objects; we opt out explicitly
    when the installed torch supports the flag.

    WARNING: ``weights_only=False`` unpickles arbitrary Python objects and can
    execute code. Only load checkpoints from trusted sources.
    """
    load_kwargs = {"map_location": "cpu"}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    return torch.load(path, **load_kwargs)


def main(argv=None) -> None:
    """Load the model from ``--model-path`` and print its computed cost.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)
    checkpoint = load_checkpoint(args.model_path)
    model = checkpoint["model"]
    # Measure inference cost: disable training-only behavior (e.g. dropout).
    if isinstance(model, torch.nn.Module):
        model.eval()
    dummy_input = torch.ones(1, args.channels, *args.input_size)
    # No autograd graph is needed just to count operations.
    with torch.no_grad():
        flops = flops_calculation_function(model, dummy_input)
    print(f"MMACs = {flops}")


if __name__ == '__main__':
    main()