| @@ -228,10 +228,12 @@ tensor::TensorPtr MSANFModelParser::BuildTensorInfoForFuncGraph(const mind_ir::T | |||
| } | |||
| if (!tensor_proto.has_data_type()) { | |||
| MS_LOG(EXCEPTION) << "mind_ir TensorProto has no data_type or name!"; | |||
| MS_LOG(ERROR) << "mind_ir build tensor: " << tensor_proto.name() << " failed"; | |||
| MS_LOG(EXCEPTION) << "mind_ir TensorProto has no data_type."; | |||
| } | |||
| if (kDefaultValueSwitchMap.find(tensor_proto.data_type()) == kDefaultValueSwitchMap.end()) { | |||
| MS_LOG(EXCEPTION) << "mind_ir TensorProto data_type is not support yet!"; | |||
| MS_LOG(ERROR) << "mind_ir build tensor: " << tensor_proto.name() << " failed"; | |||
| MS_LOG(EXCEPTION) << "mind_ir TensorProto data_type: " << tensor_proto.data_type() << " is not support yet!"; | |||
| } | |||
| tensor::TensorPtr tensor_info = | |||
| @@ -24,8 +24,8 @@ from mindspore.common import dtype as mstype | |||
| from mindspore import log as logger | |||
| from mindspore.common.api import _executor | |||
| from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model | |||
| from mindspore.train.anf_ir_pb2 import ModelProto as anf_model | |||
| from mindspore.train.checkpoint_pb2 import Checkpoint | |||
| from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy | |||
| from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo | |||
| @@ -215,7 +215,7 @@ def read_proto(file_name, proto_format="MINDIR", display_data=False): | |||
| Args: | |||
| file_name (str): File name. | |||
| proto_format (str): Proto format {MINDIR, ANF, CKPT}. Default: MINDIR. | |||
| proto_format (str): Proto format {MINDIR, CKPT, CKPT_STRATEGY}. Default: MINDIR. | |||
| display_data (bool): Whether display data. Default: False. | |||
| Returns: | |||
| @@ -224,10 +224,10 @@ def read_proto(file_name, proto_format="MINDIR", display_data=False): | |||
| if proto_format == "MINDIR": | |||
| model = mindir_model() | |||
| elif proto_format == "ANF": | |||
| model = anf_model() | |||
| elif proto_format == "CKPT": | |||
| model = Checkpoint() | |||
| elif proto_format == "CKPT_STRATEGY": | |||
| model = ckpt_strategy() | |||
| else: | |||
| raise ValueError("Unsupported proto format.") | |||
| @@ -658,13 +658,18 @@ def _save_mindir(net, file_name, *inputs): | |||
| net_dict = net.parameters_dict() | |||
| data_total = 0 | |||
| save_together = True | |||
| for i in net_dict.values(): | |||
| data_total += sys.getsizeof(i.data.asnumpy().tobytes())/1024 | |||
| model.ParseFromString(mindir_stream) | |||
| for param_proto in model.graph.parameter: | |||
| name = param_proto.name[param_proto.name.find(":") + 1:] | |||
| if name in net_dict.keys(): | |||
| data_total += sys.getsizeof(net_dict[name].data.asnumpy().tobytes()) / 1024 | |||
| else: | |||
| raise RuntimeError('Graph parameter: {} Undefined in network.'.format(param_proto.name)) | |||
| if data_total > TOTAL_SAVE: | |||
| save_together = False | |||
| break | |||
| model.ParseFromString(mindir_stream) | |||
| if save_together: | |||
| for param_proto in model.graph.parameter: | |||
| param_name = param_proto.name[param_proto.name.find(":")+1:] | |||
| @@ -697,19 +702,19 @@ def _save_mindir(net, file_name, *inputs): | |||
| index = 0 | |||
| graphproto = graph_proto() | |||
| data_size = 0 | |||
| for name, param in net_dict.items(): | |||
| byte_data = param.data.asnumpy().tobytes() | |||
| data_size += sys.getsizeof(byte_data)/1024 | |||
| parameter = graphproto.parameter.add() | |||
| for param_proto in model.graph.parameter: | |||
| if name == param_proto.name[param_proto.name.find(":") + 1:]: | |||
| parameter = graphproto.parameter.add() | |||
| parameter.name = param_proto.name | |||
| parameter.data_type = param_proto.data_type | |||
| for dim in param_proto.dims: | |||
| parameter.dims.append(dim) | |||
| byte_data = param.data.asnumpy().tobytes() | |||
| parameter.raw_data = byte_data | |||
| data_size += sys.getsizeof(byte_data) / 1024 | |||
| break | |||
| parameter.raw_data = byte_data | |||
| if data_size > TOTAL_SAVE: | |||
| data_file_name = data_path + "/" + "data_" + str(index) | |||
| with open(data_file_name, "ab") as f: | |||