diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index dd1076a094..e4eec594aa 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -830,13 +830,10 @@ def parse_print(print_file_name): np_type = tensor_to_np_type[data_type] param_data = np.fromstring(data, np_type) ms_type = tensor_to_ms_type[data_type] - param_dim = [] - for dim in dims: - param_dim.append(dim) - if param_dim: - param_value = param_data.reshape(param_dim) + if dims and dims != [0]: + param_value = param_data.reshape(dims) tensor_list.append(Tensor(param_value, ms_type)) - # Scale type + # Scalar type else: data_type_ = data_type.lower() if 'float' in data_type_: