Browse Source

!7561 mode_export_v3

Merge pull request !7561 from baiyangfan/mode_export_v3
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
ce4f64021e
2 changed files with 9 additions and 7 deletions
  1. +7
    -5
      mindspore/train/quant/quant.py
  2. +2
    -2
      mindspore/train/serialization.py

+ 7
- 5
mindspore/train/quant/quant.py View File

@@ -391,14 +391,16 @@ class ExportToQuantInferNetwork:


scale_w, zp_w, param_dict["filter_maxq"], param_dict["filter_minq"] = \ scale_w, zp_w, param_dict["filter_maxq"], param_dict["filter_minq"] = \
quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type) quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
_, _, param_dict["output_maxq"], param_dict["output_minq"] = \
quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type)
if fake_quant_a_out is not None:
_, _, param_dict["output_maxq"], param_dict["output_minq"] = \
quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type)


info = self.quant_info_table.get(w_minq_name, None) info = self.quant_info_table.get(w_minq_name, None)
if info: if info:
fake_quant_a_in_op, minq_name = info fake_quant_a_in_op, minq_name = info
if minq_name == 'input': if minq_name == 'input':
scale_a_in, zp_a_in = self.input_scale, self.input_zero_point
scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
self.input_scale, self.input_zero_point, 'None', 'None'
else: else:
maxq = self.all_parameters[minq_name[:-4] + "maxq"] maxq = self.all_parameters[minq_name[:-4] + "maxq"]
minq = self.all_parameters[minq_name] minq = self.all_parameters[minq_name]
@@ -483,11 +485,11 @@ class ExportToQuantInferNetwork:
if isinstance(subcell, quant.Conv2dBnAct): if isinstance(subcell, quant.Conv2dBnAct):
cell_core = subcell.conv cell_core = subcell.conv
activation = subcell.activation activation = subcell.activation
fake_quant_act = activation.fake_quant_act
fake_quant_act = activation.fake_quant_act if hasattr(activation, "fake_quant_act") else None
elif isinstance(subcell, quant.DenseBnAct): elif isinstance(subcell, quant.DenseBnAct):
cell_core = subcell.dense cell_core = subcell.dense
activation = subcell.activation activation = subcell.activation
fake_quant_act = activation.fake_quant_act
fake_quant_act = activation.fake_quant_act if hasattr(activation, "fake_quant_act") else None
if cell_core is not None: if cell_core is not None:
new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act) new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
if new_subcell: if new_subcell:


+ 2
- 2
mindspore/train/serialization.py View File

@@ -519,7 +519,7 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
logger.info("exporting model file:%s format:%s.", file_name, file_format) logger.info("exporting model file:%s format:%s.", file_name, file_format)
check_input_data(*inputs, data_class=Tensor) check_input_data(*inputs, data_class=Tensor)


net = _quant_export(net, *inputs, file_format='AIR', **kwargs)
net = _quant_export(net, *inputs, file_format=file_format, **kwargs)
_export(net, file_name, file_format, *inputs) _export(net, file_name, file_format, *inputs)




@@ -566,7 +566,7 @@ def _export(net, file_name, file_format, *inputs):
net.set_train(mode=True) net.set_train(mode=True)




def _quant_export(network, *inputs, file_format='AIR', **kwargs):
def _quant_export(network, *inputs, file_format, **kwargs):
""" """
Exports MindSpore quantization predict model to deploy with AIR and MINDIR. Exports MindSpore quantization predict model to deploy with AIR and MINDIR.
""" """


Loading…
Cancel
Save