| @@ -25,6 +25,7 @@ from mindspore import log as logger | |||||
| from mindspore.common.api import _executor | from mindspore.common.api import _executor | ||||
| from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model | from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model | ||||
| from mindspore.train.anf_ir_pb2 import ModelProto as anf_model | from mindspore.train.anf_ir_pb2 import ModelProto as anf_model | ||||
| from mindspore.train.checkpoint_pb2 import Checkpoint | |||||
| from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo | from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo | ||||
| @@ -208,13 +209,14 @@ def check_value_type(arg_name, arg_value, valid_types): | |||||
| f'but got {type(arg_value).__name__}.') | f'but got {type(arg_value).__name__}.') | ||||
| def read_proto(file_name, proto_format="MINDIR"): | |||||
| def read_proto(file_name, proto_format="MINDIR", display_data=False): | |||||
| """ | """ | ||||
| Read protobuf file. | Read protobuf file. | ||||
| Args: | Args: | ||||
| file_name (str): File name. | file_name (str): File name. | ||||
| proto_format (str): Proto format. | |||||
| proto_format (str): Proto format {MINDIR, ANF, CKPT}. Default: MINDIR. | |||||
| display_data (bool): Whether display data. Default: False. | |||||
| Returns: | Returns: | ||||
| Object, proto object. | Object, proto object. | ||||
| @@ -222,8 +224,10 @@ def read_proto(file_name, proto_format="MINDIR"): | |||||
| if proto_format == "MINDIR": | if proto_format == "MINDIR": | ||||
| model = mindir_model() | model = mindir_model() | ||||
| elif model_format == "ANF": | |||||
| elif proto_format == "ANF": | |||||
| model = anf_model() | model = anf_model() | ||||
| elif proto_format == "CKPT": | |||||
| model = Checkpoint() | |||||
| else: | else: | ||||
| raise ValueError("Unsupported proto format.") | raise ValueError("Unsupported proto format.") | ||||
| @@ -234,4 +238,13 @@ def read_proto(file_name, proto_format="MINDIR"): | |||||
| except BaseException as e: | except BaseException as e: | ||||
| logger.error("Failed to read the file `%s`, please check the correct of the file.", file_name) | logger.error("Failed to read the file `%s`, please check the correct of the file.", file_name) | ||||
| raise ValueError(e.__str__()) | raise ValueError(e.__str__()) | ||||
| if proto_format == "MINDIR" and not display_data: | |||||
| for param_proto in model.graph.parameter: | |||||
| param_proto.raw_data = b'\0' | |||||
| if proto_format == "CKPT" and not display_data: | |||||
| for element in model.value: | |||||
| element.tensor.tensor_content = b'\0' | |||||
| return model | return model | ||||