Browse Source

modify read proto

tags/v1.2.0-rc1
changzherui 4 years ago
parent
commit
0d6db9a0a4
1 changed files with 16 additions and 3 deletions
  1. +16
    -3
      mindspore/train/_utils.py

+ 16
- 3
mindspore/train/_utils.py View File

@@ -25,6 +25,7 @@ from mindspore import log as logger
from mindspore.common.api import _executor
from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
from mindspore.train.anf_ir_pb2 import ModelProto as anf_model
from mindspore.train.checkpoint_pb2 import Checkpoint

from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo

@@ -208,13 +209,14 @@ def check_value_type(arg_name, arg_value, valid_types):
f'but got {type(arg_value).__name__}.')


def read_proto(file_name, proto_format="MINDIR"):
def read_proto(file_name, proto_format="MINDIR", display_data=False):
"""
Read protobuf file.

Args:
file_name (str): File name.
proto_format (str): Proto format.
proto_format (str): Proto format {MINDIR, ANF, CKPT}. Default: MINDIR.
display_data (bool): Whether display data. Default: False.

Returns:
Object, proto object.
@@ -222,8 +224,10 @@ def read_proto(file_name, proto_format="MINDIR"):

if proto_format == "MINDIR":
model = mindir_model()
elif model_format == "ANF":
elif proto_format == "ANF":
model = anf_model()
elif proto_format == "CKPT":
model = Checkpoint()
else:
raise ValueError("Unsupported proto format.")

@@ -234,4 +238,13 @@ def read_proto(file_name, proto_format="MINDIR"):
except BaseException as e:
logger.error("Failed to read the file `%s`, please check the correct of the file.", file_name)
raise ValueError(e.__str__())

if proto_format == "MINDIR" and not display_data:
for param_proto in model.graph.parameter:
param_proto.raw_data = b'\0'

if proto_format == "CKPT" and not display_data:
for element in model.value:
element.tensor.tensor_content = b'\0'

return model

Loading…
Cancel
Save