From ac725dcfd9a99a27b89929f1cdb2ac70a6b84858 Mon Sep 17 00:00:00 2001 From: fary86 Date: Sat, 22 Aug 2020 10:52:10 +0800 Subject: [PATCH] Fix bug of export interface, AIR and MINDIR file format support trainning mode --- mindspore/train/serialization.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 91f976cb0b..b668f348b2 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -477,9 +477,9 @@ def export(net, *inputs, file_name, file_format='AIR'): supported_formats = ['AIR', 'ONNX', 'MINDIR'] if file_format not in supported_formats: raise ValueError(f'Illegal file format {file_format}, it must be one of {supported_formats}') - # switch network mode to infer when it is training - is_training = net.training - if is_training: + # When dumping ONNX file, switch network mode to infer when it is training(NOTE: ONNX only designed for prediction) + is_dump_onnx_in_training = net.training and file_format == 'ONNX' + if is_dump_onnx_in_training: net.set_train(mode=False) # export model net.init_parameters_data() @@ -502,7 +502,7 @@ def export(net, *inputs, file_name, file_format='AIR'): os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) f.write(onnx_stream) # restore network training mode - if is_training: + if is_dump_onnx_in_training: net.set_train(mode=True)