
!1965 format init

Merge pull request !1965 from chenzhongming/master

tags/v0.5.0-beta
mindspore-ci-bot committed 5 years ago
commit 6cfdfb0e90

3 changed files with 4 additions and 12 deletions
  1. +3 -5  mindspore/nn/layer/quant.py
  2. +0 -6  mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py
  3. +1 -1  mindspore/train/quant/quant.py

+3 -5  mindspore/nn/layer/quant.py

@@ -141,8 +141,7 @@ class FakeQuantWithMinMax(Cell):
                  out_channels=1,
                  quant_delay=0,
                  symmetric=False,
-                 narrow_range=False,
-                 training=True):
+                 narrow_range=False):
         """init FakeQuantWithMinMax layer"""
         super(FakeQuantWithMinMax, self).__init__()
         self.min_init = min_init
@@ -156,7 +155,6 @@ class FakeQuantWithMinMax(Cell):
         self.quant_delay = quant_delay
         self.symmetric = symmetric
         self.narrow_range = narrow_range
-        self.training = training
         self.is_ascend = context.get_context('device_target') == "Ascend"

         # init tensor min and max for fake quant op
@@ -190,7 +188,7 @@ class FakeQuantWithMinMax(Cell):
                                        symmetric=self.symmetric,
                                        narrow_range=self.narrow_range,
                                        training=self.training)
-            if self.ema:
+            if self.training:
                 self.ema_update = ema_fun(num_bits=self.num_bits,
                                           ema=self.ema,
                                           ema_decay=self.ema_decay,
@@ -206,7 +204,7 @@ class FakeQuantWithMinMax(Cell):
         return s

     def construct(self, x):
-        if self.ema and self.is_ascend:
+        if self.is_ascend and self.training:
             min_up, max_up = self.ema_update(x, self.minq, self.maxq)
             out = self.fake_quant(x, min_up, max_up)
             P.Assign()(self.minq, min_up)
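
The dropped `training` constructor argument duplicated state that `nn.Cell` already tracks: every Cell carries a `training` attribute that `Cell.set_train()` flips for the cell and all its children, so the fake-quant layer can branch on the inherited flag instead of a hand-managed one. A minimal sketch of that pattern, assuming nothing beyond the stock `nn.Cell` API (the class below is illustrative, not the real FakeQuantWithMinMax):

import mindspore.nn as nn

class ObservingCell(nn.Cell):
    """Toy cell that branches on the inherited `training` flag."""

    def __init__(self):
        super(ObservingCell, self).__init__()
        # No explicit training=True argument: nn.Cell already
        # maintains self.training, toggled by set_train().

    def construct(self, x):
        if self.training:
            # Training path: e.g. update running min/max statistics here.
            return x
        # Inference path: rely on frozen statistics.
        return x

net = ObservingCell()
net.set_train()       # self.training -> True, recursively
net.set_train(False)  # back to inference mode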


+0 -6  mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py

@@ -38,12 +38,6 @@ correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \
     .input(3, "running_std", None, "required", None) \
     .output(0, "dx", True, "required", "all") \
     .output(1, "d_batch_std", True, "required", "all") \
-    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_Default, DataType.F16_Default) \
-    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
-                  DataType.F16_5HD, DataType.F16_5HD) \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default, DataType.F32_Default) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                   DataType.F32_5HD, DataType.F32_5HD) \
     .get_op_info()
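
Each `.dtype_format(...)` row advertises one supported dtype/format combination, with one entry per input and output in registration order; with the first three rows gone, float32 in 5HD (NC1HWC0) layout is the only combination this TBE kernel still registers. A hedged sketch of the registration pattern, with an invented op name and tensors (only the TBERegOp/DataType builder API is taken from the source):

from mindspore.ops.op_info_register import TBERegOp, DataType

# Hypothetical two-input, one-output op: each dtype_format row must
# list one DataType per tensor, in input/output declaration order.
toy_mul_op_info = TBERegOp("ToyMul") \
    .kernel_name("toy_mul") \
    .input(0, "x", None, "required", None) \
    .input(1, "y", None, "required", None) \
    .output(0, "z", True, "required", "all") \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()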


+1 -1  mindspore/train/quant/quant.py

@@ -247,7 +247,7 @@ def convert_quant_network(network,
         network (Cell): Obtain a pipeline through network for saving graph summary.
         quant_delay (int): Number of steps after which weights and activations are quantized during eval. Default: 0.
         bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False.
-        freeze_bn (bool): Number of steps after which BN parameters used total mean and variance. Default: 0.
+        freeze_bn (int): Number of steps after which BN parameters used total mean and variance. Default: 0.
         weight_bits (int): Number of bits to use for quantizing weights. Default: 8.
         act_bits (int): Number of bits to use for quantizing activations. Default: 8.
         per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
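
After the fix, `freeze_bn` reads as what it is: a step count like `quant_delay`, not a switch like `bn_fold`. A hedged usage sketch of the documented signature (the `network` variable is a placeholder for any already-built model; the argument values are illustrative):

from mindspore.train.quant import quant

# `network` is any nn.Cell model you have already constructed.
quant_net = quant.convert_quant_network(network,
                                        bn_fold=True,
                                        quant_delay=900,
                                        freeze_bn=10000,  # int: step count, not a bool
                                        weight_bits=8,
                                        act_bits=8,
                                        per_channel=False)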

