|
|
|
@@ -30,6 +30,9 @@ from mindspore.common.parameter import Parameter |
|
|
|
from mindspore.common.api import _executor |
|
|
|
from mindspore.common import dtype as mstype |
|
|
|
from mindspore._checkparam import check_input_data |
|
|
|
from mindspore.train.quant import quant |
|
|
|
import mindspore.context as context |
|
|
|
from .._checkparam import Validator |
|
|
|
|
|
|
|
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print", |
|
|
|
"build_searched_strategy", "merge_sliced_parameter"] |
|
|
|
@@ -460,7 +463,7 @@ def _fill_param_into_net(net, parameter_list): |
|
|
|
load_param_into_net(net, parameter_dict) |
|
|
|
|
|
|
|
|
|
|
|
def export(net, *inputs, file_name, file_format='AIR'): |
|
|
|
def export(net, *inputs, file_name, file_format='AIR', quant_export=None, **kwargs): |
|
|
|
""" |
|
|
|
Export the MindSpore prediction model to a file in the specified format. |
|
|
|
|
|
|
|
@@ -469,7 +472,6 @@ def export(net, *inputs, file_name, file_format='AIR'): |
|
|
|
inputs (Tensor): Inputs of the `net`. |
|
|
|
file_name (str): File name of the model to be exported. |
|
|
|
file_format (str): MindSpore currently supports 'AIR', 'ONNX' and 'MINDIR' format for exported model. |
|
|
|
|
|
|
|
- AIR: Ascend Intermidiate Representation. An intermidiate representation format of Ascend model. |
|
|
|
Recommended suffix for output file is '.air'. |
|
|
|
- ONNX: Open Neural Network eXchange. An open format built to represent machine learning models. |
|
|
|
@@ -477,44 +479,103 @@ def export(net, *inputs, file_name, file_format='AIR'): |
|
|
|
- MINDIR: MindSpore Native Intermidiate Representation for Anf. An intermidiate representation format |
|
|
|
for MindSpore models. |
|
|
|
Recommended suffix for output file is '.mindir'. |
|
|
|
quant_export (str): Quantitative export choise. Default: None. |
|
|
|
""" |
|
|
|
if quant_export == 'MANUAL': |
|
|
|
mean = kwargs.get('mean', None) |
|
|
|
std_dev = kwargs.get('std_dev', None) |
|
|
|
QuantExport(net, *inputs, file_name, mean, std_dev, file_format='AIR', quant_manual_export=True) |
|
|
|
elif quant_export == 'AUTO': |
|
|
|
mean = kwargs.get('mean', None) |
|
|
|
std_dev = kwargs.get('std_dev', None) |
|
|
|
QuantExport(net, *inputs, file_name, mean, std_dev, file_format='AIR') |
|
|
|
else: |
|
|
|
logger.info("exporting model file:%s format:%s.", file_name, file_format) |
|
|
|
check_input_data(*inputs, data_class=Tensor) |
|
|
|
|
|
|
|
if file_format == 'GEIR': |
|
|
|
logger.warning(f"Format 'GEIR' is deprecated, it would be removed in future release, use 'AIR' instead.") |
|
|
|
file_format = 'AIR' |
|
|
|
|
|
|
|
supported_formats = ['AIR', 'ONNX', 'MINDIR'] |
|
|
|
if file_format not in supported_formats: |
|
|
|
raise ValueError(f'Illegal file format {file_format}, it must be one of {supported_formats}') |
|
|
|
# When dumping ONNX file, switch network mode to infer when it is training(NOTE: ONNX only designed for prediction) |
|
|
|
is_dump_onnx_in_training = net.training and file_format == 'ONNX' |
|
|
|
if is_dump_onnx_in_training: |
|
|
|
net.set_train(mode=False) |
|
|
|
# export model |
|
|
|
net.init_parameters_data() |
|
|
|
if file_format == 'AIR': |
|
|
|
phase_name = 'export.air' |
|
|
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name) |
|
|
|
_executor.export(file_name, graph_id) |
|
|
|
elif file_format == 'ONNX': # file_format is 'ONNX' |
|
|
|
phase_name = 'export.onnx' |
|
|
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) |
|
|
|
onnx_stream = _executor._get_func_graph_proto(graph_id) |
|
|
|
with open(file_name, 'wb') as f: |
|
|
|
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) |
|
|
|
f.write(onnx_stream) |
|
|
|
elif file_format == 'MINDIR': # file_format is 'MINDIR' |
|
|
|
phase_name = 'export.mindir' |
|
|
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) |
|
|
|
onnx_stream = _executor._get_func_graph_proto(graph_id, 'mind_ir') |
|
|
|
with open(file_name, 'wb') as f: |
|
|
|
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) |
|
|
|
f.write(onnx_stream) |
|
|
|
# restore network training mode |
|
|
|
if is_dump_onnx_in_training: |
|
|
|
net.set_train(mode=True) |
|
|
|
|
|
|
|
def QuantExport(network, file_name, mean, std_dev, *inputs, file_format='AIR', quant_manual_export=False): |
|
|
|
""" |
|
|
|
Exports MindSpore quantization predict model to deploy with AIR and MINDIR. |
|
|
|
|
|
|
|
Args: |
|
|
|
network (Cell): MindSpore network produced by `convert_quant_network`. |
|
|
|
file_name (str): File name of model to export. |
|
|
|
mean (int, float): Input data mean. Default: 127.5. |
|
|
|
std_dev (int, float): Input data variance. Default: 127.5. |
|
|
|
inputs (Tensor): Inputs of the `quantization aware training network`. |
|
|
|
file_format (str): MindSpore currently supports 'AIR' and 'MINDIR' format for exported |
|
|
|
quantization aware model. Default: 'AIR'. |
|
|
|
|
|
|
|
- AIR: Graph Engine Intermidiate Representation. An intermidiate representation format of |
|
|
|
Ascend model. |
|
|
|
- MINDIR: MindSpore Native Intermidiate Representation for Anf. An intermidiate representation format |
|
|
|
for MindSpore models. |
|
|
|
Recommended suffix for output file is '.mindir'. |
|
|
|
quant_manual_export (bool): Is it manual quantitative export. Default: False. |
|
|
|
""" |
|
|
|
logger.info("exporting model file:%s format:%s.", file_name, file_format) |
|
|
|
check_input_data(*inputs, data_class=Tensor) |
|
|
|
supported_device = ["Ascend", "GPU"] |
|
|
|
supported_formats = ['AIR', 'MINDIR'] |
|
|
|
|
|
|
|
mean = mean if mean else 127.5 |
|
|
|
std_dev = std_dev if std_dev else 127.5 |
|
|
|
|
|
|
|
if file_format == 'GEIR': |
|
|
|
logger.warning(f"Format 'GEIR' is deprecated, it would be removed in future release, use 'AIR' instead.") |
|
|
|
file_format = 'AIR' |
|
|
|
mean = Validator.check_type("mean", mean, (int, float)) |
|
|
|
std_dev = Validator.check_type("std_dev", std_dev, (int, float)) |
|
|
|
|
|
|
|
if context.get_context('device_target') not in supported_device: |
|
|
|
raise KeyError("Unsupported {} device target.".format(context.get_context('device_target'))) |
|
|
|
|
|
|
|
supported_formats = ['AIR', 'ONNX', 'MINDIR'] |
|
|
|
if file_format not in supported_formats: |
|
|
|
raise ValueError(f'Illegal file format {file_format}, it must be one of {supported_formats}') |
|
|
|
# When dumping ONNX file, switch network mode to infer when it is training(NOTE: ONNX only designed for prediction) |
|
|
|
is_dump_onnx_in_training = net.training and file_format == 'ONNX' |
|
|
|
if is_dump_onnx_in_training: |
|
|
|
net.set_train(mode=False) |
|
|
|
# export model |
|
|
|
net.init_parameters_data() |
|
|
|
if file_format == 'AIR': |
|
|
|
phase_name = 'export.air' |
|
|
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name) |
|
|
|
_executor.export(file_name, graph_id) |
|
|
|
elif file_format == 'ONNX': # file_format is 'ONNX' |
|
|
|
phase_name = 'export.onnx' |
|
|
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) |
|
|
|
onnx_stream = _executor._get_func_graph_proto(graph_id) |
|
|
|
with open(file_name, 'wb') as f: |
|
|
|
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) |
|
|
|
f.write(onnx_stream) |
|
|
|
elif file_format == 'MINDIR': # file_format is 'MINDIR' |
|
|
|
phase_name = 'export.mindir' |
|
|
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) |
|
|
|
onnx_stream = _executor._get_func_graph_proto(graph_id, 'mind_ir') |
|
|
|
with open(file_name, 'wb') as f: |
|
|
|
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) |
|
|
|
f.write(onnx_stream) |
|
|
|
# restore network training mode |
|
|
|
if is_dump_onnx_in_training: |
|
|
|
net.set_train(mode=True) |
|
|
|
raise ValueError('Illegal file format {}.'.format(file_format)) |
|
|
|
|
|
|
|
network.set_train(False) |
|
|
|
if file_format == "MINDIR": |
|
|
|
if quant_manual_export: |
|
|
|
exporter = quant.ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True) |
|
|
|
else: |
|
|
|
exporter = quant.ExportToQuantInferNetwork(network, mean, std_dev, *inputs, is_mindir=True) |
|
|
|
else: |
|
|
|
if quant_manual_export: |
|
|
|
exporter = quant.ExportManualQuantNetwork(network, mean, std_dev, *inputs) |
|
|
|
else: |
|
|
|
exporter = quant.ExportToQuantInferNetwork(network, mean, std_dev, *inputs) |
|
|
|
deploy_net = exporter.run() |
|
|
|
export(deploy_net, *inputs, file_name=file_name, file_format=file_format) |
|
|
|
|
|
|
|
|
|
|
|
def parse_print(print_file_name): |
|
|
|
|