Merge pull request !1784 from SanjayChan/fakequant_bug_fix (tags/v0.5.0-beta)
| @@ -30,7 +30,9 @@ FakeQuantGradGpuKernel::FakeQuantGradGpuKernel() | |||
| quant_max_(0), | |||
| quant_size_(0), | |||
| quant_delay_(0), | |||
| global_step_(0) {} | |||
| global_step_(0), | |||
| narrow_range_(false), | |||
| symmetric_(false) {} | |||
| const std::vector<size_t> &FakeQuantGradGpuKernel::GetInputSizeList() const { return input_size_list_; } | |||
| @@ -59,8 +61,19 @@ bool FakeQuantGradGpuKernel::Init(const CNodePtr &kernel_node) { | |||
| MS_LOG(EXCEPTION) << "Attr \'quant_delay_\' " << quant_delay_ << " is less then 0, require larger than 0."; | |||
| } | |||
| quant_min_ = 0; | |||
| quant_max_ = (1 << num_bits_) - 1; | |||
| symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); | |||
| if (symmetric_) { | |||
| quant_min_ = 0 - (1 << (num_bits_ - 1)); | |||
| quant_max_ = (1 << (num_bits_ - 1)) - 1; | |||
| } else { | |||
| quant_min_ = 0; | |||
| quant_max_ = (1 << num_bits_) - 1; | |||
| } | |||
| narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); | |||
| if (narrow_range_) { | |||
| quant_min_++; | |||
| } | |||
| if (quant_size_ == 0) { | |||
| quant_size_ = 1; | |||
| @@ -54,6 +54,8 @@ class FakeQuantGradGpuKernel : public GpuKernel { | |||
| int quant_size_; | |||
| int quant_delay_; | |||
| int global_step_; | |||
| bool narrow_range_; | |||
| bool symmetric_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -35,6 +35,8 @@ fake_quant_grad_op_info = TBERegOp("FakeQuantWithMinMaxGrad") \ | |||
| .partial_flag(True) \ | |||
| .attr("num_bits", "optional", "int", "all") \ | |||
| .attr("quant_delay", "optional", "int", "all") \ | |||
| .attr("symmetric", "optional", "bool", "all") \ | |||
| .attr("narrow_range", "optional", "bool", "all") \ | |||
| .input(0, "dout", None, "required", None) \ | |||
| .input(1, "x", None, "required", None) \ | |||
| .input(2, "min", None, "required", None) \ | |||
| @@ -104,8 +106,9 @@ def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, q | |||
| return res | |||
| @util.check_input_type(dict, dict, dict, dict, dict, int, int, str) | |||
| def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx, num_bits, quant_delay, | |||
| @util.check_input_type(dict, dict, dict, dict, dict, int, int, bool, bool, str) | |||
| def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx, | |||
| num_bits, quant_delay, symmetric, narrow_range, | |||
| kernel_name="fake_quant_with_min_max_grad"): | |||
| """FakeQuantWithMinMaxGrad""" | |||
| input_shape = x.get("shape") | |||
| @@ -136,8 +139,15 @@ def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx, num_bits, quant_ | |||
| input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),) | |||
| shape_min, _, _ = util.produce_shapes(min_shape, input_shape) | |||
| quant_min = 0 | |||
| quant_max = 2 ** num_bits - 1 | |||
| if symmetric: | |||
| quant_min = 0 - 2 ** (num_bits - 1) | |||
| quant_max = 2 ** (num_bits - 1) - 1 | |||
| else: | |||
| quant_min = 0 | |||
| quant_max = 2 ** num_bits - 1 | |||
| if narrow_range: | |||
| quant_min = quant_min + 1 | |||
| dout_data = tvm.placeholder(input_shape, name="dout", dtype=x_dtype) | |||
| input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype) | |||
| min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) | |||
| @@ -23,10 +23,10 @@ from topi.cce import util | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| fake_quant_update5d_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \ | |||
| fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("fake_quant_with_min_max_update5d.so") \ | |||
| .binfile_name("fake_quant_with_min_max_update.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("fake_quant_with_min_max_update") \ | |||
| .partial_flag(True) \ | |||
| @@ -47,9 +47,9 @@ fake_quant_update5d_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \ | |||
| .get_op_info() | |||
| @op_info_register(fake_quant_update5d_op_info) | |||
| def _fake_quant_update5d_tbe(): | |||
| """_FakeQuantWithMinMaxUpdate5D TBE register""" | |||
| @op_info_register(fake_quant_update_op_info) | |||
| def _fake_quant_update_tbe(): | |||
| """_FakeQuantWithMinMaxUpdate TBE register""" | |||
| return | |||
| @@ -116,15 +116,17 @@ class FakeQuantWithMinMaxGrad(PrimitiveWithInfer): | |||
| >>> _max = Tensor(np.array([2]), mindspore.float32) | |||
| >>> result = fake_min_max_grad(dout, input_x, _min, _max) | |||
| """ | |||
| support_quant_bit = [4, 8] | |||
| support_quant_bit = [4, 7, 8] | |||
| @prim_attr_register | |||
| def __init__(self, num_bits=8, quant_delay=0): | |||
| def __init__(self, num_bits=8, quant_delay=0, symmetric=False, narrow_range=False): | |||
| if num_bits not in self.support_quant_bit: | |||
| raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.") | |||
| self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) | |||
| self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name) | |||
| self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) | |||
| self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name) | |||
| self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name) | |||
| self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx']) | |||
| def infer_shape(self, dout_shape, x_shape, min_shape, max_shape): | |||
| @@ -172,7 +174,7 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer): | |||
| >>> _max = Tensor(np.linspace(8, 12, 12).reshape(3, 2, 2), mindspore.float32) | |||
| >>> result = fake_quant(input_x, _min, _max) | |||
| """ | |||
| support_quant_bit = [4, 8] | |||
| support_quant_bit = [4, 7, 8] | |||
| channel_axis = 0 | |||
| @prim_attr_register | |||
| @@ -219,16 +221,18 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer): | |||
| >>> _max = Tensor(np.random.randint(-2, 8, (2, 3, 4)), mindspore.float32) | |||
| >>> result = fqmmpc_grad(dout, input_x, _min, _max) | |||
| """ | |||
| support_quant_bit = [4, 8] | |||
| support_quant_bit = [4, 7, 8] | |||
| @prim_attr_register | |||
| def __init__(self, num_bits=8, quant_delay=0): | |||
| def __init__(self, num_bits=8, quant_delay=0, symmetric=False, narrow_range=False): | |||
| """init FakeQuantWithMinMaxPerChannel Fill""" | |||
| if num_bits not in self.support_quant_bit: | |||
| raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.") | |||
| self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) | |||
| self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name) | |||
| self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) | |||
| self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name) | |||
| self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name) | |||
| self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx']) | |||
| def infer_shape(self, dout_shape, x_shape, min_shape, max_shape): | |||