| import torch |
| import numpy as np |
|
|
| from mobile_sam import sam_model_registry |
| from .onnx_image_encoder import ImageEncoderOnnxModel |
|
|
| import os |
| import argparse |
| import warnings |
|
|
# Probe for onnxruntime so the post-export sanity check (and quantization)
# can be skipped gracefully when it is not installed.
try:
    import onnxruntime
except ImportError:
    onnxruntime_exists = False
else:
    onnxruntime_exists = True
|
|
# Command-line interface for the export script.
parser = argparse.ArgumentParser(
    description="Export the SAM image encoder to an ONNX model."
)

# Required arguments: where the checkpoint lives, where the ONNX file goes,
# and which SAM variant the checkpoint corresponds to.
parser.add_argument(
    "--checkpoint", type=str, required=True,
    help="The path to the SAM model checkpoint.",
)
parser.add_argument(
    "--output", type=str, required=True,
    help="The filename to save the ONNX model to.",
)
parser.add_argument(
    "--model-type", type=str, required=True,
    help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.",
)

# Optional export tweaks.
parser.add_argument(
    "--use-preprocess", action="store_true",
    help="Whether to preprocess the image by resizing, standardizing, etc.",
)
parser.add_argument(
    "--opset", type=int, default=17,
    help="The ONNX opset version to use. Must be >=11",
)
parser.add_argument(
    "--quantize-out", type=str, default=None,
    help=(
        "If set, will quantize the model and save it with this name. "
        "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
    ),
)
parser.add_argument(
    "--gelu-approximate", action="store_true",
    help=(
        "Replace GELU operations with approximations using tanh. Useful "
        "for some runtimes that have slow or unimplemented erf ops, used in GELU."
    ),
)
|
|
|
|
def run_export(
    model_type: str,
    checkpoint: str,
    output: str,
    use_preprocess: bool,
    opset: int,
    gelu_approximate: bool = False,
):
    """Export the SAM image encoder to ONNX and sanity-check it with onnxruntime.

    Args:
        model_type: Key into ``sam_model_registry`` (e.g. 'vit_h', 'vit_b').
        checkpoint: Path to the SAM model checkpoint.
        output: Filename to write the ONNX model to.
        use_preprocess: If True, bake resize/standardization into the graph
            and accept a dynamically-sized HWC image as input.
        opset: ONNX opset version to export with (must be >= 11).
        gelu_approximate: If True, switch GELU modules to the tanh
            approximation for runtimes with slow or missing erf ops.
    """
    print("Loading model...")
    sam = sam_model_registry[model_type](checkpoint=checkpoint)

    onnx_model = ImageEncoderOnnxModel(
        model=sam,
        use_preprocess=use_preprocess,
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )

    if gelu_approximate:
        for _, m in onnx_model.named_modules():
            if isinstance(m, torch.nn.GELU):
                m.approximate = "tanh"

    image_size = sam.image_encoder.img_size
    if use_preprocess:
        # With preprocessing in the graph, the input is an HWC float image
        # whose spatial dims are dynamic.
        dummy_input = {
            "input_image": torch.randn((image_size, image_size, 3), dtype=torch.float)
        }
        dynamic_axes = {
            "input_image": {0: "image_height", 1: "image_width"},
        }
    else:
        # Without preprocessing the input is a fixed-size NCHW batch.
        dummy_input = {
            "input_image": torch.randn(
                (1, 3, image_size, image_size), dtype=torch.float
            )
        }
        dynamic_axes = None

    # Run the model once eagerly to surface shape/dtype errors before tracing.
    _ = onnx_model(**dummy_input)

    output_names = ["image_embeddings"]

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        print(f"Exporting onnx model to {output}...")
        if model_type == "vit_h":
            # vit_h is too large for a single protobuf, so weights end up as
            # external data files beside the model; make sure the destination
            # directory exists.
            output_dir, output_file = os.path.split(output)
            if output_dir:
                # Fix: os.makedirs("") raises when `output` is a bare filename.
                os.makedirs(output_dir, mode=0o777, exist_ok=True)
            torch.onnx.export(
                onnx_model,
                tuple(dummy_input.values()),
                output,
                export_params=True,
                verbose=False,
                opset_version=opset,
                do_constant_folding=True,
                input_names=list(dummy_input.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )
        else:
            with open(output, "wb") as f:
                torch.onnx.export(
                    onnx_model,
                    tuple(dummy_input.values()),
                    f,
                    export_params=True,
                    verbose=False,
                    opset_version=opset,
                    do_constant_folding=True,
                    input_names=list(dummy_input.keys()),
                    output_names=output_names,
                    dynamic_axes=dynamic_axes,
                )

    if onnxruntime_exists:
        ort_inputs = {k: to_numpy(v) for k, v in dummy_input.items()}
        providers = ["CPUExecutionProvider"]

        if model_type == "vit_h":
            # Load the externally-stored weight files as initializers BEFORE
            # creating the session, and actually pass the session options.
            # (The original created the session first and never used
            # session_option, so the initializers were silently ignored.)
            session_option = onnxruntime.SessionOptions()
            param_files = os.listdir(output_dir or os.curdir)
            param_files.remove(output_file)
            for layer in param_files:
                with open(os.path.join(output_dir, layer), "rb") as fp:
                    weights = np.frombuffer(fp.read(), dtype=np.float32)
                weights = onnxruntime.OrtValue.ortvalue_from_numpy(weights)
                session_option.add_initializer(layer, weights)
            ort_session = onnxruntime.InferenceSession(
                output, sess_options=session_option, providers=providers
            )
        else:
            ort_session = onnxruntime.InferenceSession(output, providers=providers)

        _ = ort_session.run(None, ort_inputs)
        print("Model has successfully been run with ONNXRuntime.")
|
|
|
|
def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array on the CPU.

    ``detach()`` first so tensors that require grad can be converted too —
    ``.numpy()`` raises a RuntimeError on grad-tracking tensors otherwise.
    """
    return tensor.detach().cpu().numpy()
|
|
|
|
if __name__ == "__main__":
    cli_args = parser.parse_args()

    # Export the encoder first; quantization (below) reads the exported file.
    run_export(
        model_type=cli_args.model_type,
        checkpoint=cli_args.checkpoint,
        output=cli_args.output,
        use_preprocess=cli_args.use_preprocess,
        opset=cli_args.opset,
        gelu_approximate=cli_args.gelu_approximate,
    )

    quantize_path = cli_args.quantize_out
    if quantize_path is not None:
        assert onnxruntime_exists, "onnxruntime is required to quantize the model."
        # Imported lazily: only needed when quantization was requested.
        from onnxruntime.quantization import QuantType
        from onnxruntime.quantization.quantize import quantize_dynamic

        print(f"Quantizing model and writing to {quantize_path}...")
        quantize_dynamic(
            model_input=cli_args.output,
            model_output=quantize_path,
            weight_type=QuantType.QUInt8,
            optimize_model=True,
            per_channel=False,
            reduce_range=False,
        )
        print("Done!")
|
|