You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
165 lines
5.8 KiB
165 lines
5.8 KiB
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import argparse
|
|
import warnings
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from mmpose.apis import init_pose_model
|
|
|
|
try:
|
|
import onnx
|
|
import onnxruntime as rt
|
|
except ImportError as e:
|
|
raise ImportError(f'Please install onnx and onnxruntime first. {e}')
|
|
|
|
try:
|
|
from mmcv.onnx.symbolic import register_extra_symbolics
|
|
except ModuleNotFoundError:
|
|
raise NotImplementedError('please update mmcv to version>=1.0.4')
|
|
|
|
|
|
class _BatchNormXd(torch.nn.BatchNorm3d):
    """BatchNorm3d variant that accepts inputs of any dimensionality.

    ``torch.nn.BatchNorm3d`` rejects everything except 5D tensors, whereas
    ``SyncBatchNorm`` normalizes 2D/3D/4D/5D inputs alike. Skipping the
    dimension check lets the replacement layer be dropped into 2D models
    (4D feature maps) as well as 3D ones, while ``isinstance`` checks against
    ``BatchNorm3d`` keep working.
    """

    def _check_input_dim(self, x):
        # Intentionally permissive: the wrapped SyncBatchNorm did not
        # restrict the layer to 5D inputs, so neither does its replacement.
        return


def _convert_batchnorm(module):
    """Convert the syncBNs into normal BN3ds.

    Every ``torch.nn.SyncBatchNorm`` found in ``module`` (including
    ``module`` itself) is replaced by a non-synchronized batch norm layer
    that carries over the affine parameters, the running statistics and the
    train/eval mode. The replacement is a ``BatchNorm3d`` subclass without
    the 5D-input restriction, so converted 2D models can still run.

    Args:
        module (:obj:`nn.Module`): The module to convert. Non-SyncBN modules
            are modified in place: their children are re-registered with the
            converted versions.

    Returns:
        :obj:`nn.Module`: The converted module. This is a new layer when
        ``module`` is a SyncBatchNorm, otherwise ``module`` itself.
    """
    module_output = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        module_output = _BatchNormXd(module.num_features, module.eps,
                                     module.momentum, module.affine,
                                     module.track_running_stats)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        # A freshly built layer is in train mode; mirror the source layer so
        # an eval-mode model stays in eval mode after conversion.
        module_output.training = module.training
    for name, child in module.named_children():
        module_output.add_module(name, _convert_batchnorm(child))
    del module
    return module_output
|
|
|
|
|
|
def pytorch2onnx(model,
                 input_shape,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False):
    """Export a pytorch model to an ONNX file and optionally validate it.

    The model is moved to CPU, switched to eval mode and exported with a
    random input of shape ``input_shape``. With ``verify`` enabled, the
    written file is checked by the onnx checker and its onnxruntime outputs
    are compared against the pytorch outputs for the same random input.

    Args:
        model (:obj:`nn.Module`): The pytorch model to be exported.
        input_shape (tuple[int]): The input tensor shape of the model.
        opset_version (int): Opset version of onnx used. Default: 11.
        show (bool): Determines whether to print the onnx model architecture.
            Default: False.
        output_file (str): Output onnx model name. Default: 'tmp.onnx'.
        verify (bool): Determines whether to verify the onnx model.
            Default: False.
    """
    model.cpu().eval()

    dummy_input = torch.randn(input_shape)

    register_extra_symbolics(opset_version)
    export_options = dict(
        export_params=True,
        keep_initializers_as_inputs=True,
        verbose=show,
        opset_version=opset_version)
    torch.onnx.export(model, dummy_input, output_file, **export_options)

    print(f'Successfully exported ONNX model: {output_file}')
    if not verify:
        return

    # structural check of the exported file by onnx
    exported = onnx.load(output_file)
    onnx.checker.check_model(exported)

    # reference outputs from pytorch, normalized to a sequence of tensors
    torch_outputs = model(dummy_input)
    if not isinstance(torch_outputs, (list, tuple)):
        assert isinstance(torch_outputs, torch.Tensor)
        torch_outputs = [torch_outputs]

    # Initializers are also listed as graph inputs (the export keeps them as
    # inputs), so the real feed input is whatever remains after removing them.
    graph = exported.graph
    initializer_names = {init.name for init in graph.initializer}
    feed_names = list({inp.name for inp in graph.input} - initializer_names)
    assert len(feed_names) == 1
    session = rt.InferenceSession(output_file)
    feed = {feed_names[0]: dummy_input.detach().numpy()}
    onnx_outputs = session.run(None, feed)

    # element-wise comparison of both backends
    assert len(torch_outputs) == len(onnx_outputs)
    for torch_out, onnx_out in zip(torch_outputs, onnx_outputs):
        matches = np.allclose(torch_out.detach().cpu(), onnx_out, atol=1.e-5)
        assert matches, 'The outputs are different between Pytorch and ONNX'
    print('The numerical values are same between Pytorch and ONNX')
|
|
|
|
|
|
def parse_args():
    """Parse the command-line options of the MMPose-to-ONNX converter.

    Positional arguments are the config path and the checkpoint path. The
    optional flags select graph printing, the output file, the opset version,
    output verification and the input shape.

    Returns:
        argparse.Namespace: The parsed arguments, read from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description='Convert MMPose models to ONNX')

    # (flags, options) pairs, registered in the order they appear in --help
    argument_specs = (
        (('config', ), dict(help='test config file path')),
        (('checkpoint', ), dict(help='checkpoint file')),
        (('--show', ), dict(action='store_true', help='show onnx graph')),
        (('--output-file', ), dict(type=str, default='tmp.onnx')),
        (('--opset-version', ), dict(type=int, default=11)),
        (('--verify', ),
         dict(
             action='store_true',
             help='verify the onnx model output against pytorch output')),
        (('--shape', ),
         dict(
             type=int,
             nargs='+',
             default=[1, 3, 256, 192],
             help='input size')),
    )
    for flags, options in argument_specs:
        cli.add_argument(*flags, **options)

    return cli.parse_args()
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the CLI, warn about the deprecation, build
    # the pose model on CPU and export it to ONNX.
    args = parse_args()

    assert args.opset_version == 11, 'MMPose only supports opset 11 now'

    # Following strings of text style are from colorama package
    bright_style, reset_style = '\x1b[1m', '\x1b[0m'
    red_text, blue_text = '\x1b[31m', '\x1b[34m'
    white_background = '\x1b[107m'

    # bold red notice on white background, followed by a blue pointer to
    # MMDeploy, then a style reset
    msg = ''.join([
        white_background,
        bright_style,
        red_text,
        'DeprecationWarning: This tool will be deprecated in future. ',
        blue_text,
        'Welcome to use the unified model deployment toolbox ',
        'MMDeploy: https://github.com/open-mmlab/mmdeploy',
        reset_style,
    ])
    warnings.warn(msg)

    model = _convert_batchnorm(
        init_pose_model(args.config, args.checkpoint, device='cpu'))

    # onnx.export does not support kwargs, so the export goes through the
    # model's forward_dummy, which must exist
    if not hasattr(model, 'forward_dummy'):
        raise NotImplementedError(
            'Please implement the forward method for exporting.')
    model.forward = model.forward_dummy

    # convert model to onnx file
    export_kwargs = dict(
        opset_version=args.opset_version,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify)
    pytorch2onnx(model, args.shape, **export_kwargs)
|
|
|