diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 3812698419..bc74986321 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -424,6 +424,7 @@ def export(net, *inputs, file_name, file_format='GEIR'): if is_training: net.set_train(mode=False) # export model + net.init_parameters_data() if file_format == 'GEIR': _executor.compile(net, *inputs, phase='export') _executor.export(net, file_name, file_format)