| @@ -477,9 +477,9 @@ def export(net, *inputs, file_name, file_format='AIR'): | |||||
| supported_formats = ['AIR', 'ONNX', 'MINDIR'] | supported_formats = ['AIR', 'ONNX', 'MINDIR'] | ||||
| if file_format not in supported_formats: | if file_format not in supported_formats: | ||||
| raise ValueError(f'Illegal file format {file_format}, it must be one of {supported_formats}') | raise ValueError(f'Illegal file format {file_format}, it must be one of {supported_formats}') | ||||
| # switch network mode to infer when it is training | |||||
| is_training = net.training | |||||
| if is_training: | |||||
| # When dumping ONNX file, switch network mode to infer when it is training(NOTE: ONNX only designed for prediction) | |||||
| is_dump_onnx_in_training = net.training and file_format == 'ONNX' | |||||
| if is_dump_onnx_in_training: | |||||
| net.set_train(mode=False) | net.set_train(mode=False) | ||||
| # export model | # export model | ||||
| net.init_parameters_data() | net.init_parameters_data() | ||||
| @@ -502,7 +502,7 @@ def export(net, *inputs, file_name, file_format='AIR'): | |||||
| os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) | os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) | ||||
| f.write(onnx_stream) | f.write(onnx_stream) | ||||
| # restore network training mode | # restore network training mode | ||||
| if is_training: | |||||
| if is_dump_onnx_in_training: | |||||
| net.set_train(mode=True) | net.set_train(mode=True) | ||||