Browse Source

!14263 tensorprint adapt to print scalar

From: @yepei6
Reviewed-by: @kingxian,@kisnwang
Signed-off-by: @kingxian
pull/14263/MERGE
mindspore-ci-bot Gitee 5 years ago
parent
commit
476066dc82
1 changed files with 3 additions and 6 deletions
  1. +3
    -6
      mindspore/train/serialization.py

+ 3
- 6
mindspore/train/serialization.py View File

@@ -830,13 +830,10 @@ def parse_print(print_file_name):
np_type = tensor_to_np_type[data_type]
param_data = np.fromstring(data, np_type)
ms_type = tensor_to_ms_type[data_type]
param_dim = []
for dim in dims:
param_dim.append(dim)
if param_dim:
param_value = param_data.reshape(param_dim)
if dims and dims != [0]:
param_value = param_data.reshape(dims)
tensor_list.append(Tensor(param_value, ms_type))
# Scale type
# Scalar type
else:
data_type_ = data_type.lower()
if 'float' in data_type_:


Loading…
Cancel
Save