From e9a67efc6b277cc835c73b2d9871c09e5eb08c2b Mon Sep 17 00:00:00 2001
From: chenzomi
Date: Tue, 2 Jun 2020 09:24:56 +0800
Subject: [PATCH] fix bug in fake quant grad

---
 .../gpu/quant/fake_quant_grad_gpu_kernel.cc | 19 ++++++++++++++++---
 .../gpu/quant/fake_quant_grad_gpu_kernel.h  |  2 ++
 .../fake_quant_with_min_max_grad.py         | 18 ++++++++++++++----
 .../fake_quant_with_min_max_update.py       | 10 +++++-----
 mindspore/ops/operations/_quant_ops.py      | 18 +++++++++++-------
 5 files changed, 48 insertions(+), 19 deletions(-)

diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.cc
index 7b7e3f1737..d92696d1bd 100644
--- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.cc
+++ b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.cc
@@ -30,7 +30,9 @@ FakeQuantGradGpuKernel::FakeQuantGradGpuKernel()
       quant_max_(0),
       quant_size_(0),
       quant_delay_(0),
-      global_step_(0) {}
+      global_step_(0),
+      narrow_range_(false),
+      symmetric_(false) {}
 
 const std::vector<size_t> &FakeQuantGradGpuKernel::GetInputSizeList() const { return input_size_list_; }
 
@@ -59,8 +61,19 @@ bool FakeQuantGradGpuKernel::Init(const CNodePtr &kernel_node) {
     MS_LOG(EXCEPTION) << "Attr \'quant_delay_\' " << quant_delay_ << " is less then 0, require larger than 0.";
   }
 
-  quant_min_ = 0;
-  quant_max_ = (1 << num_bits_) - 1;
+  symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
+  if (symmetric_) {
+    quant_min_ = 0 - (1 << (num_bits_ - 1));
+    quant_max_ = (1 << (num_bits_ - 1)) - 1;
+  } else {
+    quant_min_ = 0;
+    quant_max_ = (1 << num_bits_) - 1;
+  }
+
+  narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
+  if (narrow_range_) {
+    quant_min_++;
+  }
 
   if (quant_size_ == 0) {
     quant_size_ = 1;
diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.h
index 04c505d2bd..cfde98355c 100644
--- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.h
@@ -54,6 +54,8 @@ class FakeQuantGradGpuKernel : public GpuKernel {
   int quant_size_;
   int quant_delay_;
   int global_step_;
+  bool narrow_range_;
+  bool symmetric_;
 };
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max_grad.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max_grad.py
index be5dcb6591..5137f7c42b 100644
--- a/mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max_grad.py
+++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max_grad.py
@@ -35,6 +35,8 @@ fake_quant_grad_op_info = TBERegOp("FakeQuantWithMinMaxGrad") \
     .partial_flag(True) \
     .attr("num_bits", "optional", "int", "all") \
     .attr("quant_delay", "optional", "int", "all") \
+    .attr("symmetric", "optional", "bool", "all") \
+    .attr("narrow_range", "optional", "bool", "all") \
     .input(0, "dout", None, "required", None) \
     .input(1, "x", None, "required", None) \
     .input(2, "min", None, "required", None) \
@@ -104,8 +106,9 @@ def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, q
     return res
 
 
-@util.check_input_type(dict, dict, dict, dict, dict, int, int, str)
-def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx, num_bits, quant_delay,
+@util.check_input_type(dict, dict, dict, dict, dict, int, int, bool, bool, str)
+def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx,
+                                 num_bits, quant_delay, symmetric, narrow_range,
                                 kernel_name="fake_quant_with_min_max_grad"):
     """FakeQuantWithMinMaxGrad"""
     input_shape = x.get("shape")
@@ -136,8 +139,15 @@ def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx, num_bits, quant_
         input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
     shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
 
-    quant_min = 0
-    quant_max = 2 ** num_bits - 1
+    if symmetric:
+        quant_min = 0 - 2 ** (num_bits - 1)
+        quant_max = 2 ** (num_bits - 1) - 1
+    else:
+        quant_min = 0
+        quant_max = 2 ** num_bits - 1
+    if narrow_range:
+        quant_min = quant_min + 1
+
     dout_data = tvm.placeholder(input_shape, name="dout", dtype=x_dtype)
     input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
     min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max_update.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max_update.py
index e5c932aa0f..58eeeda9fb 100644
--- a/mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max_update.py
+++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max_update.py
@@ -23,10 +23,10 @@ from topi.cce import util
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
 
 
-fake_quant_update5d_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
+fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
     .fusion_type("OPAQUE") \
     .async_flag(False) \
-    .binfile_name("fake_quant_with_min_max_update5d.so") \
+    .binfile_name("fake_quant_with_min_max_update.so") \
     .compute_cost(10) \
     .kernel_name("fake_quant_with_min_max_update") \
     .partial_flag(True) \
@@ -47,9 +47,9 @@ fake_quant_update5d_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
     .get_op_info()
 
 
-@op_info_register(fake_quant_update5d_op_info)
-def _fake_quant_update5d_tbe():
-    """_FakeQuantWithMinMaxUpdate5D TBE register"""
+@op_info_register(fake_quant_update_op_info)
+def _fake_quant_update_tbe():
+    """_FakeQuantWithMinMaxUpdate TBE register"""
     return
 
 
diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py
index 705968be65..af7b979392 100644
--- a/mindspore/ops/operations/_quant_ops.py
+++ b/mindspore/ops/operations/_quant_ops.py
@@ -113,15 +113,17 @@ class FakeQuantWithMinMaxGrad(PrimitiveWithInfer):
         >>> _max = Tensor(np.array([2]), mindspore.float32)
         >>> result = fake_min_max_grad(dout, input_x, _min, _max)
     """
-    support_quant_bit = [4, 8]
+    support_quant_bit = [4, 7, 8]
 
     @prim_attr_register
-    def __init__(self, num_bits=8, quant_delay=0):
+    def __init__(self, num_bits=8, quant_delay=0, symmetric=False, narrow_range=False):
         if num_bits not in self.support_quant_bit:
             raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
 
-        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
         self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
+        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
+        self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
+        self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
         self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
 
     def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
@@ -169,7 +171,7 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
         >>> _max = Tensor(np.linspace(8, 12, 12).reshape(3, 2, 2), mindspore.float32)
         >>> result = fake_quant(input_x, _min, _max)
     """
-    support_quant_bit = [4, 8]
+    support_quant_bit = [4, 7, 8]
     channel_axis = 0
 
     @prim_attr_register
@@ -216,16 +218,18 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
         >>> _max = Tensor(np.random.randint(-2, 8, (2, 3, 4)), mindspore.float32)
         >>> result = fqmmpc_grad(dout, input_x, _min, _max)
     """
-    support_quant_bit = [4, 8]
+    support_quant_bit = [4, 7, 8]
 
     @prim_attr_register
-    def __init__(self, num_bits=8, quant_delay=0):
+    def __init__(self, num_bits=8, quant_delay=0, symmetric=False, narrow_range=False):
         """init FakeQuantWithMinMaxPerChannel Fill"""
         if num_bits not in self.support_quant_bit:
             raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
 
-        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
         self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
+        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
+        self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
+        self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
         self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
 
     def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
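
A note on the range logic this patch introduces: the same [quant_min, quant_max]
derivation now appears twice, once in C++ (fake_quant_grad_gpu_kernel.cc) and once
in Python (fake_quant_with_min_max_grad.py). The sketch below is a minimal
standalone restatement of that arithmetic for review purposes only; quant_range is
an illustrative name, not a MindSpore API.

def quant_range(num_bits, symmetric=False, narrow_range=False):
    # Mirrors the patched logic: symmetric mode picks a signed range,
    # asymmetric mode an unsigned one; narrow_range drops the lowest code.
    if symmetric:
        quant_min = 0 - 2 ** (num_bits - 1)
        quant_max = 2 ** (num_bits - 1) - 1
    else:
        quant_min = 0
        quant_max = 2 ** num_bits - 1
    if narrow_range:
        quant_min = quant_min + 1
    return quant_min, quant_max

# Spot checks for num_bits=8:
assert quant_range(8) == (0, 255)
assert quant_range(8, symmetric=True) == (-128, 127)
assert quant_range(8, symmetric=True, narrow_range=True) == (-127, 127)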