|
|
|
@@ -25,6 +25,7 @@ from mindspore import log as logger |
|
|
|
from mindspore.common.api import _executor |
|
|
|
from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model |
|
|
|
from mindspore.train.anf_ir_pb2 import ModelProto as anf_model |
|
|
|
from mindspore.train.checkpoint_pb2 import Checkpoint |
|
|
|
|
|
|
|
from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo |
|
|
|
|
|
|
|
@@ -208,13 +209,14 @@ def check_value_type(arg_name, arg_value, valid_types): |
|
|
|
f'but got {type(arg_value).__name__}.') |
|
|
|
|
|
|
|
|
|
|
|
def read_proto(file_name, proto_format="MINDIR"): |
|
|
|
def read_proto(file_name, proto_format="MINDIR", display_data=False): |
|
|
|
""" |
|
|
|
Read protobuf file. |
|
|
|
|
|
|
|
Args: |
|
|
|
file_name (str): File name. |
|
|
|
proto_format (str): Proto format. |
|
|
|
proto_format (str): Proto format {MINDIR, ANF, CKPT}. Default: MINDIR. |
|
|
|
display_data (bool): Whether display data. Default: False. |
|
|
|
|
|
|
|
Returns: |
|
|
|
Object, proto object. |
|
|
|
@@ -222,8 +224,10 @@ def read_proto(file_name, proto_format="MINDIR"): |
|
|
|
|
|
|
|
if proto_format == "MINDIR": |
|
|
|
model = mindir_model() |
|
|
|
elif model_format == "ANF": |
|
|
|
elif proto_format == "ANF": |
|
|
|
model = anf_model() |
|
|
|
elif proto_format == "CKPT": |
|
|
|
model = Checkpoint() |
|
|
|
else: |
|
|
|
raise ValueError("Unsupported proto format.") |
|
|
|
|
|
|
|
@@ -234,4 +238,13 @@ def read_proto(file_name, proto_format="MINDIR"): |
|
|
|
except BaseException as e: |
|
|
|
logger.error("Failed to read the file `%s`, please check the correct of the file.", file_name) |
|
|
|
raise ValueError(e.__str__()) |
|
|
|
|
|
|
|
if proto_format == "MINDIR" and not display_data: |
|
|
|
for param_proto in model.graph.parameter: |
|
|
|
param_proto.raw_data = b'\0' |
|
|
|
|
|
|
|
if proto_format == "CKPT" and not display_data: |
|
|
|
for element in model.value: |
|
|
|
element.tensor.tensor_content = b'\0' |
|
|
|
|
|
|
|
return model |