Browse Source

!4966 fix bug of exporting AIR/MINDIR

Merge pull request !4966 from fary86/fix_bug_of_export_interface
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
e72cd8867e
1 changed files with 4 additions and 4 deletions
  1. +4
    -4
      mindspore/train/serialization.py

+ 4
- 4
mindspore/train/serialization.py View File

@@ -478,9 +478,9 @@ def export(net, *inputs, file_name, file_format='AIR'):
supported_formats = ['AIR', 'ONNX', 'MINDIR']
if file_format not in supported_formats:
raise ValueError(f'Illegal file format {file_format}, it must be one of {supported_formats}')
# switch network mode to infer when it is training
is_training = net.training
if is_training:
# When dumping ONNX file, switch network mode to infer when it is training(NOTE: ONNX only designed for prediction)
is_dump_onnx_in_training = net.training and file_format == 'ONNX'
if is_dump_onnx_in_training:
net.set_train(mode=False)
# export model
net.init_parameters_data()
@@ -503,7 +503,7 @@ def export(net, *inputs, file_name, file_format='AIR'):
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
f.write(onnx_stream)
# restore network training mode
if is_training:
if is_dump_onnx_in_training:
net.set_train(mode=True)




Loading…
Cancel
Save