| @@ -228,10 +228,12 @@ tensor::TensorPtr MSANFModelParser::BuildTensorInfoForFuncGraph(const mind_ir::T | |||||
| } | } | ||||
| if (!tensor_proto.has_data_type()) { | if (!tensor_proto.has_data_type()) { | ||||
| MS_LOG(EXCEPTION) << "mind_ir TensorProto has no data_type or name!"; | |||||
| MS_LOG(ERROR) << "mind_ir build tensor: " << tensor_proto.name() << " failed"; | |||||
| MS_LOG(EXCEPTION) << "mind_ir TensorProto has no data_type."; | |||||
| } | } | ||||
| if (kDefaultValueSwitchMap.find(tensor_proto.data_type()) == kDefaultValueSwitchMap.end()) { | if (kDefaultValueSwitchMap.find(tensor_proto.data_type()) == kDefaultValueSwitchMap.end()) { | ||||
| MS_LOG(EXCEPTION) << "mind_ir TensorProto data_type is not support yet!"; | |||||
| MS_LOG(ERROR) << "mind_ir build tensor: " << tensor_proto.name() << " failed"; | |||||
| MS_LOG(EXCEPTION) << "mind_ir TensorProto data_type: " << tensor_proto.data_type() << " is not support yet!"; | |||||
| } | } | ||||
| tensor::TensorPtr tensor_info = | tensor::TensorPtr tensor_info = | ||||
| @@ -24,8 +24,8 @@ from mindspore.common import dtype as mstype | |||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from mindspore.common.api import _executor | from mindspore.common.api import _executor | ||||
| from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model | from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model | ||||
| from mindspore.train.anf_ir_pb2 import ModelProto as anf_model | |||||
| from mindspore.train.checkpoint_pb2 import Checkpoint | from mindspore.train.checkpoint_pb2 import Checkpoint | ||||
| from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy | |||||
| from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo | from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo | ||||
| @@ -215,7 +215,7 @@ def read_proto(file_name, proto_format="MINDIR", display_data=False): | |||||
| Args: | Args: | ||||
| file_name (str): File name. | file_name (str): File name. | ||||
| proto_format (str): Proto format {MINDIR, ANF, CKPT}. Default: MINDIR. | |||||
| proto_format (str): Proto format {MINDIR, CKPT, CKPT_STRATEGY}. Default: MINDIR. | |||||
| display_data (bool): Whether display data. Default: False. | display_data (bool): Whether display data. Default: False. | ||||
| Returns: | Returns: | ||||
| @@ -224,10 +224,10 @@ def read_proto(file_name, proto_format="MINDIR", display_data=False): | |||||
| if proto_format == "MINDIR": | if proto_format == "MINDIR": | ||||
| model = mindir_model() | model = mindir_model() | ||||
| elif proto_format == "ANF": | |||||
| model = anf_model() | |||||
| elif proto_format == "CKPT": | elif proto_format == "CKPT": | ||||
| model = Checkpoint() | model = Checkpoint() | ||||
| elif proto_format == "CKPT_STRATEGY": | |||||
| model = ckpt_strategy() | |||||
| else: | else: | ||||
| raise ValueError("Unsupported proto format.") | raise ValueError("Unsupported proto format.") | ||||
| @@ -658,13 +658,18 @@ def _save_mindir(net, file_name, *inputs): | |||||
| net_dict = net.parameters_dict() | net_dict = net.parameters_dict() | ||||
| data_total = 0 | data_total = 0 | ||||
| save_together = True | save_together = True | ||||
| for i in net_dict.values(): | |||||
| data_total += sys.getsizeof(i.data.asnumpy().tobytes())/1024 | |||||
| model.ParseFromString(mindir_stream) | |||||
| for param_proto in model.graph.parameter: | |||||
| name = param_proto.name[param_proto.name.find(":") + 1:] | |||||
| if name in net_dict.keys(): | |||||
| data_total += sys.getsizeof(net_dict[name].data.asnumpy().tobytes()) / 1024 | |||||
| else: | |||||
| raise RuntimeError('Graph parameter: {} Undefined in network.'.format(param_proto.name)) | |||||
| if data_total > TOTAL_SAVE: | if data_total > TOTAL_SAVE: | ||||
| save_together = False | save_together = False | ||||
| break | break | ||||
| model.ParseFromString(mindir_stream) | |||||
| if save_together: | if save_together: | ||||
| for param_proto in model.graph.parameter: | for param_proto in model.graph.parameter: | ||||
| param_name = param_proto.name[param_proto.name.find(":")+1:] | param_name = param_proto.name[param_proto.name.find(":")+1:] | ||||
| @@ -697,19 +702,19 @@ def _save_mindir(net, file_name, *inputs): | |||||
| index = 0 | index = 0 | ||||
| graphproto = graph_proto() | graphproto = graph_proto() | ||||
| data_size = 0 | data_size = 0 | ||||
| for name, param in net_dict.items(): | for name, param in net_dict.items(): | ||||
| byte_data = param.data.asnumpy().tobytes() | |||||
| data_size += sys.getsizeof(byte_data)/1024 | |||||
| parameter = graphproto.parameter.add() | |||||
| for param_proto in model.graph.parameter: | for param_proto in model.graph.parameter: | ||||
| if name == param_proto.name[param_proto.name.find(":") + 1:]: | if name == param_proto.name[param_proto.name.find(":") + 1:]: | ||||
| parameter = graphproto.parameter.add() | |||||
| parameter.name = param_proto.name | parameter.name = param_proto.name | ||||
| parameter.data_type = param_proto.data_type | parameter.data_type = param_proto.data_type | ||||
| for dim in param_proto.dims: | for dim in param_proto.dims: | ||||
| parameter.dims.append(dim) | parameter.dims.append(dim) | ||||
| byte_data = param.data.asnumpy().tobytes() | |||||
| parameter.raw_data = byte_data | |||||
| data_size += sys.getsizeof(byte_data) / 1024 | |||||
| break | break | ||||
| parameter.raw_data = byte_data | |||||
| if data_size > TOTAL_SAVE: | if data_size > TOTAL_SAVE: | ||||
| data_file_name = data_path + "/" + "data_" + str(index) | data_file_name = data_path + "/" + "data_" + str(index) | ||||
| with open(data_file_name, "ab") as f: | with open(data_file_name, "ab") as f: | ||||