diff --git a/mindspore/core/load_mindir/anf_model_parser.cc b/mindspore/core/load_mindir/anf_model_parser.cc index 4d06d53a48..7cd8133df2 100644 --- a/mindspore/core/load_mindir/anf_model_parser.cc +++ b/mindspore/core/load_mindir/anf_model_parser.cc @@ -228,10 +228,12 @@ tensor::TensorPtr MSANFModelParser::BuildTensorInfoForFuncGraph(const mind_ir::T } if (!tensor_proto.has_data_type()) { - MS_LOG(EXCEPTION) << "mind_ir TensorProto has no data_type or name!"; + MS_LOG(ERROR) << "mind_ir build tensor: " << tensor_proto.name() << " failed"; + MS_LOG(EXCEPTION) << "mind_ir TensorProto has no data_type."; } if (kDefaultValueSwitchMap.find(tensor_proto.data_type()) == kDefaultValueSwitchMap.end()) { - MS_LOG(EXCEPTION) << "mind_ir TensorProto data_type is not support yet!"; + MS_LOG(ERROR) << "mind_ir build tensor: " << tensor_proto.name() << " failed"; + MS_LOG(EXCEPTION) << "mind_ir TensorProto data_type: " << tensor_proto.data_type() << " is not support yet!"; } tensor::TensorPtr tensor_info = diff --git a/mindspore/train/_utils.py b/mindspore/train/_utils.py index 8d42ad0f02..f4d5e343b6 100644 --- a/mindspore/train/_utils.py +++ b/mindspore/train/_utils.py @@ -24,8 +24,8 @@ from mindspore.common import dtype as mstype from mindspore import log as logger from mindspore.common.api import _executor from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model -from mindspore.train.anf_ir_pb2 import ModelProto as anf_model from mindspore.train.checkpoint_pb2 import Checkpoint +from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo @@ -215,7 +215,7 @@ def read_proto(file_name, proto_format="MINDIR", display_data=False): Args: file_name (str): File name. - proto_format (str): Proto format {MINDIR, ANF, CKPT}. Default: MINDIR. + proto_format (str): Proto format {MINDIR, CKPT, CKPT_STRATEGY}. Default: MINDIR. display_data (bool): Whether display data. Default: False. Returns: @@ -224,10 +224,10 @@ def read_proto(file_name, proto_format="MINDIR", display_data=False): if proto_format == "MINDIR": model = mindir_model() - elif proto_format == "ANF": - model = anf_model() elif proto_format == "CKPT": model = Checkpoint() + elif proto_format == "CKPT_STRATEGY": + model = ckpt_strategy() else: raise ValueError("Unsupported proto format.") diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 283ad875b2..7d87478855 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -658,13 +658,18 @@ def _save_mindir(net, file_name, *inputs): net_dict = net.parameters_dict() data_total = 0 save_together = True - for i in net_dict.values(): - data_total += sys.getsizeof(i.data.asnumpy().tobytes())/1024 + + model.ParseFromString(mindir_stream) + for param_proto in model.graph.parameter: + name = param_proto.name[param_proto.name.find(":") + 1:] + if name in net_dict.keys(): + data_total += sys.getsizeof(net_dict[name].data.asnumpy().tobytes()) / 1024 + else: + raise RuntimeError('Graph parameter: {} Undefined in network.'.format(param_proto.name)) if data_total > TOTAL_SAVE: save_together = False break - model.ParseFromString(mindir_stream) if save_together: for param_proto in model.graph.parameter: param_name = param_proto.name[param_proto.name.find(":")+1:] @@ -697,19 +702,19 @@ def _save_mindir(net, file_name, *inputs): index = 0 graphproto = graph_proto() data_size = 0 + for name, param in net_dict.items(): - byte_data = param.data.asnumpy().tobytes() - data_size += sys.getsizeof(byte_data)/1024 - parameter = graphproto.parameter.add() for param_proto in model.graph.parameter: if name == param_proto.name[param_proto.name.find(":") + 1:]: + parameter = graphproto.parameter.add() parameter.name = param_proto.name parameter.data_type = param_proto.data_type for dim in param_proto.dims: parameter.dims.append(dim) + byte_data = param.data.asnumpy().tobytes() + parameter.raw_data = byte_data + data_size += sys.getsizeof(byte_data) / 1024 break - - parameter.raw_data = byte_data if data_size > TOTAL_SAVE: data_file_name = data_path + "/" + "data_" + str(index) with open(data_file_name, "ab") as f: