You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
30 lines
767 B
30 lines
767 B
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os.path as osp
|
|
import tempfile
|
|
|
|
import torch.nn as nn
|
|
|
|
from tools.deployment.pytorch2onnx import _convert_batchnorm, pytorch2onnx
|
|
|
|
|
|
class DummyModel(nn.Module):
    """Minimal 3D conv + sync batch-norm network used as an export fixture.

    The model is deliberately tiny: a single ``nn.Conv3d`` (1 input channel,
    2 output channels, kernel size 1) followed by an ``nn.SyncBatchNorm``
    over those 2 channels.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv3d(1, 2, 1)
        # SyncBatchNorm is used on purpose so that the test has a layer for
        # ``_convert_batchnorm`` to act on before exporting.
        self.bn = nn.SyncBatchNorm(2)

    def forward(self, x):
        """Apply the convolution and then the batch-norm layer to ``x``."""
        features = self.conv(x)
        return self.bn(features)

    def forward_dummy(self, x):
        """Return the result of :meth:`forward` wrapped in a 1-element tuple."""
        output = self.forward(x)
        return (output, )
|
|
|
|
|
|
def test_onnx_exporting():
    """Export ``DummyModel`` to ONNX inside a throwaway directory.

    The model is first passed through ``_convert_batchnorm`` and the result
    is handed to ``pytorch2onnx`` together with the tuple ``(1, 1, 1, 1, 1)``
    (presumably the input shape; confirm against ``pytorch2onnx``).
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        out_file = osp.join(tmpdir, 'tmp.onnx')
        model = _convert_batchnorm(DummyModel())

        # test exporting
        pytorch2onnx(model, (1, 1, 1, 1, 1), output_file=out_file)

        # Without this check the test only proves that no exception was
        # raised; make sure something was actually written to ``out_file``.
        assert osp.exists(out_file)
|
|
|