@@ -141,8 +141,7 @@ class FakeQuantWithMinMax(Cell):
                  out_channels=1,
                  quant_delay=0,
                  symmetric=False,
-                 narrow_range=False,
-                 training=True):
+                 narrow_range=False):
         """init FakeQuantWithMinMax layer"""
         super(FakeQuantWithMinMax, self).__init__()
         self.min_init = min_init
@@ -156,7 +155,6 @@ class FakeQuantWithMinMax(Cell):
         self.quant_delay = quant_delay
         self.symmetric = symmetric
         self.narrow_range = narrow_range
-        self.training = training
         self.is_ascend = context.get_context('device_target') == "Ascend"
 
         # init tensor min and max for fake quant op
@@ -190,7 +188,7 @@ class FakeQuantWithMinMax(Cell):
                                     symmetric=self.symmetric,
                                     narrow_range=self.narrow_range,
                                     training=self.training)
-        if self.ema:
+        if self.training:
             self.ema_update = ema_fun(num_bits=self.num_bits,
                                       ema=self.ema,
                                       ema_decay=self.ema_decay,
@@ -206,7 +204,7 @@ class FakeQuantWithMinMax(Cell):
         return s
 
     def construct(self, x):
-        if self.ema and self.is_ascend:
+        if self.is_ascend and self.training:
             min_up, max_up = self.ema_update(x, self.minq, self.maxq)
             out = self.fake_quant(x, min_up, max_up)
             P.Assign()(self.minq, min_up)
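
Note on the two hunks above: the explicit `training` constructor argument is dropped, and the Ascend branch of `construct` now keys off the `training` flag that the `Cell` base class already carries, so the running `minq`/`maxq` statistics are only updated while training. As context for what `ema_update` maintains, here is a minimal NumPy sketch of EMA min/max tracking; the function `ema_min_max` and its standalone formulation are illustrative only and are not part of the MindSpore API.

import numpy as np

def ema_min_max(x, running_min, running_max, ema_decay=0.999):
    """Illustrative EMA tracking of activation extrema (not the real Ascend op).

    Each call nudges the running min/max toward the current batch's extrema,
    which is roughly the bookkeeping expected of
    `self.ema_update(x, self.minq, self.maxq)` before `P.Assign()` writes the
    updated values back into the parameters.
    """
    batch_min = float(np.min(x))
    batch_max = float(np.max(x))
    new_min = ema_decay * running_min + (1.0 - ema_decay) * batch_min
    new_max = ema_decay * running_max + (1.0 - ema_decay) * batch_max
    return new_min, new_max
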
@@ -38,12 +38,6 @@ correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \
     .input(3, "running_std", None, "required", None) \
     .output(0, "dx", True, "required", "all") \
     .output(1, "d_batch_std", True, "required", "all") \
-    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_Default, DataType.F16_Default) \
-    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
-                  DataType.F16_5HD, DataType.F16_5HD) \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default, DataType.F32_Default) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                   DataType.F32_5HD, DataType.F32_5HD) \
     .get_op_info()
@@ -247,7 +247,7 @@ def convert_quant_network(network,
         network (Cell): Obtain a pipeline through network for saving graph summary.
         quant_delay (int): Number of steps after which weights and activations are quantized during eval. Default: 0.
         bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False.
-        freeze_bn (bool): Number of steps after which BN parameters used total mean and variance. Default: 0.
+        freeze_bn (int): Number of steps after which BN parameters used total mean and variance. Default: 0.
         weight_bits (int): Number of bits to use for quantizing weights. Default: 8.
         act_bits (int): Number of bits to use for quantizing activations. Default: 8.
         per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
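
Since the docstring hunk above documents the `convert_quant_network` entry point, a hedged usage sketch follows. The import path `mindspore.train.quant` and the `TinyNet` placeholder are assumptions for the MindSpore version this patch targets; only keyword arguments named in the docstring are used. `freeze_bn` is passed as a step count, which is why its declared type changes from `bool` to `int`.

from mindspore import nn
from mindspore.train.quant import quant  # assumed import path for this MindSpore version

class TinyNet(nn.Cell):
    """Minimal placeholder network, used only to illustrate the conversion call."""
    def __init__(self):
        super(TinyNet, self).__init__()
        # Combined conv+bn+activation cell that the quantization converter
        # is designed to rewrite (assumed available in this version).
        self.conv = nn.Conv2dBnAct(1, 6, 5, pad_mode='valid', activation='relu')

    def construct(self, x):
        return self.conv(x)

# Rewrite the float network with fake-quantization aware layers.
float_net = TinyNet()
quant_net = quant.convert_quant_network(float_net,
                                        quant_delay=900,
                                        bn_fold=False,
                                        freeze_bn=10000,  # step count (int), per the corrected docstring
                                        weight_bits=8,
                                        act_bits=8,
                                        per_channel=False)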