Browse Source

modify mindir file authority

tags/v1.2.0-rc1
changzherui 4 years ago
parent
commit
04947baf39
1 changed files with 7 additions and 4 deletions
  1. +7
    -4
      mindspore/train/serialization.py

+ 7
- 4
mindspore/train/serialization.py View File

@@ -481,8 +481,8 @@ def _save_graph(network, file_name):
graph_pb = network.get_func_graph_proto()
if graph_pb:
with open(file_name, "wb") as f:
os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
f.write(graph_pb)
os.chmod(file_name, stat.S_IRUSR)


def _get_merged_param_data(net, param_name, param_data, integrated_save):
@@ -637,7 +637,7 @@ def _export(net, file_name, file_format, *inputs):
if not file_name.endswith('.onnx'):
file_name += ".onnx"
with open(file_name, 'wb') as f:
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
f.write(onnx_stream)
elif file_format == 'MINDIR':
_save_mindir(net, file_name, *inputs)
@@ -687,7 +687,7 @@ def _save_mindir(net, file_name, *inputs):
dirname = os.path.dirname(current_path)
os.makedirs(dirname, exist_ok=True)
with open(file_name, 'wb') as f:
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
f.write(model.SerializeToString())
else:
logger.warning("Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.")
@@ -701,6 +701,7 @@ def _save_mindir(net, file_name, *inputs):
if os.path.exists(data_path):
shutil.rmtree(data_path)
os.makedirs(data_path, exist_ok=True)
os.chmod(data_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
index = 0
graphproto = graph_proto()
data_size = 0
@@ -720,6 +721,7 @@ def _save_mindir(net, file_name, *inputs):
if data_size > TOTAL_SAVE:
data_file_name = data_path + "/" + "data_" + str(index)
with open(data_file_name, "ab") as f:
os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
f.write(graphproto.SerializeToString())
index += 1
data_size = 0
@@ -728,13 +730,14 @@ def _save_mindir(net, file_name, *inputs):
if graphproto.parameter:
data_file_name = data_path + "/" + "data_" + str(index)
with open(data_file_name, "ab") as f:
os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
f.write(graphproto.SerializeToString())

# save graph
del model.graph.parameter[:]
graph_file_name = dirname + "/" + file_prefix + "_graph.mindir"
with open(graph_file_name, 'wb') as f:
os.chmod(graph_file_name, stat.S_IWUSR | stat.S_IRUSR)
os.chmod(graph_file_name, stat.S_IRUSR | stat.S_IWUSR)
f.write(model.SerializeToString())




Loading…
Cancel
Save