Browse Source

fix bug in quant deploy export

tags/v0.6.0-beta
Wei Luning 5 years ago
parent
commit
dcd5773f64
3 changed files with 5 additions and 3 deletions
  1. +3
    -1
      mindspore/train/quant/quant.py
  2. +1
    -1
      mindspore/train/quant/quant_utils.py
  3. +1
    -1
      tests/ut/python/train/quant/test_quant.py

+ 3
- 1
mindspore/train/quant/quant.py View File

@@ -154,7 +154,9 @@ class ConvertToQuantNetwork:
per_channel=self.act_channel, per_channel=self.act_channel,
symmetric=self.act_symmetric, symmetric=self.act_symmetric,
narrow_range=self.act_range) narrow_range=self.act_range)
prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)])
prefix = self._convert_op_name(prim_op.name)
if network.param_prefix:
prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)])
add_quant.update_parameters_name(prefix + '.') add_quant.update_parameters_name(prefix + '.')
del network.__dict__[name] del network.__dict__[name]
network.insert_child_to_cell(name, add_quant) network.insert_child_to_cell(name, add_quant)


+ 1
- 1
mindspore/train/quant/quant_utils.py View File

@@ -125,7 +125,7 @@ def scale_zp_from_fack_quant_cell(cell, data_type):
""" """
minq = cell.minq.data.asnumpy() minq = cell.minq.data.asnumpy()
maxq = cell.maxq.data.asnumpy() maxq = cell.maxq.data.asnumpy()
op = cell.fake_quant
op = cell.fake_quant_infer


scale, zp = cal_quantization_params( scale, zp = cal_quantization_params(
minq, maxq, data_type, minq, maxq, data_type,


+ 1
- 1
tests/ut/python/train/quant/test_quant.py View File

@@ -67,7 +67,7 @@ def test_qat_lenet():
img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32)) img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32))
net = LeNet5() net = LeNet5()
net = qat.convert_quant_network( net = qat.convert_quant_network(
net, quant_delay=0, bn_fold=False, freeze_bn=10000, num_bits=8)
net, freeze_bn=10000, num_bits=8)
# should load the checkpoint. mock here # should load the checkpoint. mock here
for param in net.get_parameters(): for param in net.get_parameters():
param.init_data() param.init_data()


Loading…
Cancel
Save