|
|
@@ -530,6 +530,7 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs): |
|
|
if not isinstance(file_name, str): |
|
|
if not isinstance(file_name, str): |
|
|
raise ValueError("Args file_name {} must be string, please check it".format(file_name)) |
|
|
raise ValueError("Args file_name {} must be string, please check it".format(file_name)) |
|
|
|
|
|
|
|
|
|
|
|
Validator.check_file_name_by_regular(file_name) |
|
|
net = _quant_export(net, *inputs, file_format=file_format, **kwargs) |
|
|
net = _quant_export(net, *inputs, file_format=file_format, **kwargs) |
|
|
_export(net, file_name, file_format, *inputs) |
|
|
_export(net, file_name, file_format, *inputs) |
|
|
|
|
|
|
|
|
@@ -552,14 +553,14 @@ def _export(net, file_name, file_format, *inputs): |
|
|
is_dump_onnx_in_training = net.training and file_format == 'ONNX' |
|
|
is_dump_onnx_in_training = net.training and file_format == 'ONNX' |
|
|
if is_dump_onnx_in_training: |
|
|
if is_dump_onnx_in_training: |
|
|
net.set_train(mode=False) |
|
|
net.set_train(mode=False) |
|
|
# export model |
|
|
|
|
|
|
|
|
|
|
|
net.init_parameters_data() |
|
|
net.init_parameters_data() |
|
|
if file_format == 'AIR': |
|
|
if file_format == 'AIR': |
|
|
phase_name = 'export.air' |
|
|
phase_name = 'export.air' |
|
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name) |
|
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name) |
|
|
file_name += ".air" |
|
|
file_name += ".air" |
|
|
_executor.export(file_name, graph_id) |
|
|
_executor.export(file_name, graph_id) |
|
|
elif file_format == 'ONNX': # file_format is 'ONNX' |
|
|
|
|
|
|
|
|
elif file_format == 'ONNX': |
|
|
phase_name = 'export.onnx' |
|
|
phase_name = 'export.onnx' |
|
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) |
|
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) |
|
|
onnx_stream = _executor._get_func_graph_proto(net, graph_id) |
|
|
onnx_stream = _executor._get_func_graph_proto(net, graph_id) |
|
|
@@ -567,7 +568,7 @@ def _export(net, file_name, file_format, *inputs): |
|
|
with open(file_name, 'wb') as f: |
|
|
with open(file_name, 'wb') as f: |
|
|
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) |
|
|
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) |
|
|
f.write(onnx_stream) |
|
|
f.write(onnx_stream) |
|
|
elif file_format == 'MINDIR': # file_format is 'MINDIR' |
|
|
|
|
|
|
|
|
elif file_format == 'MINDIR': |
|
|
phase_name = 'export.mindir' |
|
|
phase_name = 'export.mindir' |
|
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) |
|
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) |
|
|
onnx_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir') |
|
|
onnx_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir') |
|
|
@@ -575,7 +576,7 @@ def _export(net, file_name, file_format, *inputs): |
|
|
with open(file_name, 'wb') as f: |
|
|
with open(file_name, 'wb') as f: |
|
|
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) |
|
|
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) |
|
|
f.write(onnx_stream) |
|
|
f.write(onnx_stream) |
|
|
# restore network training mode |
|
|
|
|
|
|
|
|
|
|
|
if is_dump_onnx_in_training: |
|
|
if is_dump_onnx_in_training: |
|
|
net.set_train(mode=True) |
|
|
net.set_train(mode=True) |
|
|
|
|
|
|
|
|
|