| @@ -154,7 +154,9 @@ class ConvertToQuantNetwork: | |||||
| per_channel=self.act_channel, | per_channel=self.act_channel, | ||||
| symmetric=self.act_symmetric, | symmetric=self.act_symmetric, | ||||
| narrow_range=self.act_range) | narrow_range=self.act_range) | ||||
| prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)]) | |||||
| prefix = self._convert_op_name(prim_op.name) | |||||
| if network.param_prefix: | |||||
| prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)]) | |||||
| add_quant.update_parameters_name(prefix + '.') | add_quant.update_parameters_name(prefix + '.') | ||||
| del network.__dict__[name] | del network.__dict__[name] | ||||
| network.insert_child_to_cell(name, add_quant) | network.insert_child_to_cell(name, add_quant) | ||||
| @@ -125,7 +125,7 @@ def scale_zp_from_fack_quant_cell(cell, data_type): | |||||
| """ | """ | ||||
| minq = cell.minq.data.asnumpy() | minq = cell.minq.data.asnumpy() | ||||
| maxq = cell.maxq.data.asnumpy() | maxq = cell.maxq.data.asnumpy() | ||||
| op = cell.fake_quant | |||||
| op = cell.fake_quant_infer | |||||
| scale, zp = cal_quantization_params( | scale, zp = cal_quantization_params( | ||||
| minq, maxq, data_type, | minq, maxq, data_type, | ||||
| @@ -67,7 +67,7 @@ def test_qat_lenet(): | |||||
| img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32)) | img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32)) | ||||
| net = LeNet5() | net = LeNet5() | ||||
| net = qat.convert_quant_network( | net = qat.convert_quant_network( | ||||
| net, quant_delay=0, bn_fold=False, freeze_bn=10000, num_bits=8) | |||||
| net, freeze_bn=10000, num_bits=8) | |||||
| # should load the checkpoint. mock here | # should load the checkpoint. mock here | ||||
| for param in net.get_parameters(): | for param in net.get_parameters(): | ||||
| param.init_data() | param.init_data() | ||||