Browse Source

modify export mindir bug

tags/v1.2.0-rc1
changzherui 4 years ago
parent
commit
8d1ab6e60b
3 changed files with 21 additions and 14 deletions
  1. +4
    -2
      mindspore/core/load_mindir/anf_model_parser.cc
  2. +4
    -4
      mindspore/train/_utils.py
  3. +13
    -8
      mindspore/train/serialization.py

+ 4
- 2
mindspore/core/load_mindir/anf_model_parser.cc View File

@@ -228,10 +228,12 @@ tensor::TensorPtr MSANFModelParser::BuildTensorInfoForFuncGraph(const mind_ir::T
}

if (!tensor_proto.has_data_type()) {
MS_LOG(EXCEPTION) << "mind_ir TensorProto has no data_type or name!";
MS_LOG(ERROR) << "mind_ir build tensor: " << tensor_proto.name() << " failed";
MS_LOG(EXCEPTION) << "mind_ir TensorProto has no data_type.";
}
if (kDefaultValueSwitchMap.find(tensor_proto.data_type()) == kDefaultValueSwitchMap.end()) {
MS_LOG(EXCEPTION) << "mind_ir TensorProto data_type is not support yet!";
MS_LOG(ERROR) << "mind_ir build tensor: " << tensor_proto.name() << " failed";
MS_LOG(EXCEPTION) << "mind_ir TensorProto data_type: " << tensor_proto.data_type() << " is not support yet!";
}

tensor::TensorPtr tensor_info =


+ 4
- 4
mindspore/train/_utils.py View File

@@ -24,8 +24,8 @@ from mindspore.common import dtype as mstype
from mindspore import log as logger
from mindspore.common.api import _executor
from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
from mindspore.train.anf_ir_pb2 import ModelProto as anf_model
from mindspore.train.checkpoint_pb2 import Checkpoint
from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy

from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo

@@ -215,7 +215,7 @@ def read_proto(file_name, proto_format="MINDIR", display_data=False):

Args:
file_name (str): File name.
proto_format (str): Proto format {MINDIR, ANF, CKPT}. Default: MINDIR.
proto_format (str): Proto format {MINDIR, CKPT, CKPT_STRATEGY}. Default: MINDIR.
display_data (bool): Whether display data. Default: False.

Returns:
@@ -224,10 +224,10 @@ def read_proto(file_name, proto_format="MINDIR", display_data=False):

if proto_format == "MINDIR":
model = mindir_model()
elif proto_format == "ANF":
model = anf_model()
elif proto_format == "CKPT":
model = Checkpoint()
elif proto_format == "CKPT_STRATEGY":
model = ckpt_strategy()
else:
raise ValueError("Unsupported proto format.")



+ 13
- 8
mindspore/train/serialization.py View File

@@ -658,13 +658,18 @@ def _save_mindir(net, file_name, *inputs):
net_dict = net.parameters_dict()
data_total = 0
save_together = True
for i in net_dict.values():
data_total += sys.getsizeof(i.data.asnumpy().tobytes())/1024

model.ParseFromString(mindir_stream)
for param_proto in model.graph.parameter:
name = param_proto.name[param_proto.name.find(":") + 1:]
if name in net_dict.keys():
data_total += sys.getsizeof(net_dict[name].data.asnumpy().tobytes()) / 1024
else:
raise RuntimeError('Graph parameter: {} Undefined in network.'.format(param_proto.name))
if data_total > TOTAL_SAVE:
save_together = False
break

model.ParseFromString(mindir_stream)
if save_together:
for param_proto in model.graph.parameter:
param_name = param_proto.name[param_proto.name.find(":")+1:]
@@ -697,19 +702,19 @@ def _save_mindir(net, file_name, *inputs):
index = 0
graphproto = graph_proto()
data_size = 0

for name, param in net_dict.items():
byte_data = param.data.asnumpy().tobytes()
data_size += sys.getsizeof(byte_data)/1024
parameter = graphproto.parameter.add()
for param_proto in model.graph.parameter:
if name == param_proto.name[param_proto.name.find(":") + 1:]:
parameter = graphproto.parameter.add()
parameter.name = param_proto.name
parameter.data_type = param_proto.data_type
for dim in param_proto.dims:
parameter.dims.append(dim)
byte_data = param.data.asnumpy().tobytes()
parameter.raw_data = byte_data
data_size += sys.getsizeof(byte_data) / 1024
break

parameter.raw_data = byte_data
if data_size > TOTAL_SAVE:
data_file_name = data_path + "/" + "data_" + str(index)
with open(data_file_name, "ab") as f:


Loading…
Cancel
Save