|
|
|
@@ -29,8 +29,7 @@ from mindspore.common.api import _executor |
|
|
|
from mindspore.common import dtype as mstype |
|
|
|
from mindspore._checkparam import check_input_data |
|
|
|
|
|
|
|
|
|
|
|
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export"] |
|
|
|
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print"] |
|
|
|
|
|
|
|
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16, |
|
|
|
"Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64, |
|
|
|
@@ -513,6 +512,13 @@ def parse_print(print_file_name): |
|
|
|
tensor_list.append(Tensor(param_value, ms_type)) |
|
|
|
# Scale type |
|
|
|
else: |
|
|
|
data_type_ = data_type.lower() |
|
|
|
if 'float' in data_type_: |
|
|
|
param_data = float(param_data[0]) |
|
|
|
elif 'int' in data_type_: |
|
|
|
param_data = int(param_data[0]) |
|
|
|
elif 'bool' in data_type_: |
|
|
|
param_data = bool(param_data[0]) |
|
|
|
tensor_list.append(Tensor(param_data, ms_type)) |
|
|
|
|
|
|
|
except BaseException as e: |
|
|
|
|