Browse Source

fake quant debug

tags/v0.3.1-alpha
chenzomi 5 years ago
parent
commit
5a26546b56
2 changed files with 7 additions and 17 deletions
  1. +7
    -11
      mindspore/nn/layer/quant.py
  2. +0
    -6
      mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py

+ 7
- 11
mindspore/nn/layer/quant.py View File

@@ -179,23 +179,19 @@ class FakeQuantWithMinMax(Cell):
if self.is_ascend:
self.fake_quant = quant_fun(num_bits=self.num_bits,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=self.training)
narrow_range=self.narrow_range)
else:
self.fake_quant = quant_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=ema_decay,
quant_delay=quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=self.training)
if self.ema:
self.ema_update = ema_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=self.training)
narrow_range=self.narrow_range)
self.ema_update = ema_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
symmetric=self.symmetric,
narrow_range=self.narrow_range)

def extend_repr(self):
s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \


+ 0
- 6
mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py View File

@@ -38,12 +38,6 @@ correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \
.input(3, "running_std", None, "required", None) \
.output(0, "dx", True, "required", "all") \
.output(1, "d_batch_std", True, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()


Loading…
Cancel
Save