|
|
|
@@ -477,9 +477,9 @@ def export(net, *inputs, file_name, file_format='AIR'): |
|
|
|
supported_formats = ['AIR', 'ONNX', 'MINDIR'] |
|
|
|
if file_format not in supported_formats: |
|
|
|
raise ValueError(f'Illegal file format {file_format}, it must be one of {supported_formats}') |
|
|
|
# switch network mode to infer when it is training |
|
|
|
is_training = net.training |
|
|
|
if is_training: |
|
|
|
# When dumping ONNX file, switch network mode to infer when it is training(NOTE: ONNX only designed for prediction) |
|
|
|
is_dump_onnx_in_training = net.training and file_format == 'ONNX' |
|
|
|
if is_dump_onnx_in_training: |
|
|
|
net.set_train(mode=False) |
|
|
|
# export model |
|
|
|
net.init_parameters_data() |
|
|
|
@@ -502,7 +502,7 @@ def export(net, *inputs, file_name, file_format='AIR'): |
|
|
|
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) |
|
|
|
f.write(onnx_stream) |
|
|
|
# restore network training mode |
|
|
|
if is_training: |
|
|
|
if is_dump_onnx_in_training: |
|
|
|
net.set_train(mode=True) |
|
|
|
|
|
|
|
|
|
|
|
|