|
|
|
@@ -14,16 +14,21 @@ |
|
|
|
# ============================================================================ |
|
|
|
"""Model and parameters serialization.""" |
|
|
|
import os |
|
|
|
import sys |
|
|
|
import stat |
|
|
|
import math |
|
|
|
import shutil |
|
|
|
from threading import Thread, Lock |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
import mindspore.nn as nn |
|
|
|
import mindspore.context as context |
|
|
|
from mindspore import log as logger |
|
|
|
from mindspore.train.checkpoint_pb2 import Checkpoint |
|
|
|
from mindspore.train.print_pb2 import Print |
|
|
|
from mindspore.train.node_strategy_pb2 import ParallelStrategyMap, ParallelLayouts |
|
|
|
from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model |
|
|
|
from mindspore.train.mind_ir_pb2 import GraphProto as graph_proto |
|
|
|
from mindspore.common.tensor import Tensor |
|
|
|
from mindspore.common.initializer import initializer |
|
|
|
from mindspore.common.parameter import Parameter |
|
|
|
@@ -46,6 +51,7 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin |
|
|
|
|
|
|
|
_ckpt_mutex = Lock() |
|
|
|
SLICE_SIZE = 512 * 1024 * 1024 |
|
|
|
TOTAL_SAVE = 1024 * 1024 |
|
|
|
|
|
|
|
|
|
|
|
def _special_process_par(par, new_par): |
|
|
|
@@ -423,10 +429,10 @@ def _save_graph(network, file_name): |
|
|
|
""" |
|
|
|
logger.info("Execute the process of saving graph.") |
|
|
|
|
|
|
|
graph_proto = network.get_func_graph_proto() |
|
|
|
if graph_proto: |
|
|
|
graph_pb = network.get_func_graph_proto() |
|
|
|
if graph_pb: |
|
|
|
with open(file_name, "wb") as f: |
|
|
|
f.write(graph_proto) |
|
|
|
f.write(graph_pb) |
|
|
|
os.chmod(file_name, stat.S_IRUSR) |
|
|
|
|
|
|
|
|
|
|
|
@@ -569,7 +575,6 @@ def _export(net, file_name, file_format, *inputs): |
|
|
|
if is_dump_onnx_in_training: |
|
|
|
net.set_train(mode=False) |
|
|
|
|
|
|
|
net.init_parameters_data() |
|
|
|
if file_format == 'AIR': |
|
|
|
phase_name = 'export.air' |
|
|
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name) |
|
|
|
@@ -586,17 +591,94 @@ def _export(net, file_name, file_format, *inputs): |
|
|
|
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) |
|
|
|
f.write(onnx_stream) |
|
|
|
elif file_format == 'MINDIR': |
|
|
|
_save_mindir(net, file_name, *inputs) |
|
|
|
|
|
|
|
if is_dump_onnx_in_training: |
|
|
|
net.set_train(mode=True) |
|
|
|
|
|
|
|
|
|
|
|
def _save_mindir(net, file_name, *inputs): |
|
|
|
"""Save MindIR format file.""" |
|
|
|
model = mindir_model() |
|
|
|
if net._auto_parallel_mode: |
|
|
|
phase_name = "predict" |
|
|
|
else: |
|
|
|
phase_name = 'export.mindir' |
|
|
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) |
|
|
|
onnx_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir') |
|
|
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, |
|
|
|
do_convert=False, auto_parallel_mode=net._auto_parallel_mode) |
|
|
|
mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir') |
|
|
|
|
|
|
|
net_dict = net.parameters_dict() |
|
|
|
data_total = 0 |
|
|
|
save_together = True |
|
|
|
for i in net_dict.values(): |
|
|
|
data_total += sys.getsizeof(i.data.asnumpy().tobytes())/1024 |
|
|
|
if data_total > TOTAL_SAVE: |
|
|
|
save_together = False |
|
|
|
break |
|
|
|
|
|
|
|
model.ParseFromString(mindir_stream) |
|
|
|
if save_together: |
|
|
|
for param_proto in model.graph.parameter: |
|
|
|
param_name = param_proto.name[param_proto.name.find(":")+1:] |
|
|
|
if param_name in net_dict.keys(): |
|
|
|
param_data = net_dict[param_name].data.asnumpy().tobytes() |
|
|
|
param_proto.raw_data = param_data |
|
|
|
else: |
|
|
|
logger.error("The parameter %s in the graph are not in the network.", param_name) |
|
|
|
raise ValueError("The parameter in the graph must in the network.") |
|
|
|
if not file_name.endswith('.mindir'): |
|
|
|
file_name += ".mindir" |
|
|
|
current_path = os.path.abspath(file_name) |
|
|
|
dirname = os.path.dirname(current_path) |
|
|
|
os.makedirs(dirname, exist_ok=True) |
|
|
|
with open(file_name, 'wb') as f: |
|
|
|
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) |
|
|
|
f.write(onnx_stream) |
|
|
|
f.write(model.SerializeToString()) |
|
|
|
else: |
|
|
|
logger.warning("Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.") |
|
|
|
# save parameter |
|
|
|
current_path = os.path.abspath(file_name) |
|
|
|
dirname = os.path.dirname(current_path) |
|
|
|
data_path = dirname + "/variables" |
|
|
|
if os.path.exists(data_path): |
|
|
|
shutil.rmtree(data_path) |
|
|
|
os.makedirs(data_path, exist_ok=True) |
|
|
|
index = 0 |
|
|
|
graphproto = graph_proto() |
|
|
|
data_size = 0 |
|
|
|
for name, param in net_dict.items(): |
|
|
|
byte_data = param.data.asnumpy().tobytes() |
|
|
|
data_size += sys.getsizeof(byte_data)/1024 |
|
|
|
parameter = graphproto.parameter.add() |
|
|
|
for param_proto in model.graph.parameter: |
|
|
|
if name == param_proto.name[param_proto.name.find(":") + 1:]: |
|
|
|
parameter.name = param_proto.name |
|
|
|
parameter.data_type = param_proto.data_type |
|
|
|
for dim in param_proto.dims: |
|
|
|
parameter.dims.append(dim) |
|
|
|
break |
|
|
|
|
|
|
|
if is_dump_onnx_in_training: |
|
|
|
net.set_train(mode=True) |
|
|
|
parameter.raw_data = byte_data |
|
|
|
if data_size > TOTAL_SAVE: |
|
|
|
data_file_name = data_path + "/" + "data_" + str(index) |
|
|
|
with open(data_file_name, "ab") as f: |
|
|
|
f.write(graphproto.SerializeToString()) |
|
|
|
index += 1 |
|
|
|
data_size = 0 |
|
|
|
del graphproto.parameter[:] |
|
|
|
|
|
|
|
if graphproto.parameter: |
|
|
|
data_file_name = data_path + "/" + "data_" + str(index) |
|
|
|
with open(data_file_name, "ab") as f: |
|
|
|
f.write(graphproto.SerializeToString()) |
|
|
|
|
|
|
|
# save graph |
|
|
|
del model.graph.parameter[:] |
|
|
|
graph_file_name = file_name + "_graph.mindir" |
|
|
|
with open(graph_file_name, 'wb') as f: |
|
|
|
os.chmod(graph_file_name, stat.S_IWUSR | stat.S_IRUSR) |
|
|
|
f.write(model.SerializeToString()) |
|
|
|
|
|
|
|
|
|
|
|
def _quant_export(network, *inputs, file_format, **kwargs): |
|
|
|
|