Merge pull request !1784 from SanjayChan/fakequant_bug_fix
@@ -30,7 +30,9 @@ FakeQuantGradGpuKernel::FakeQuantGradGpuKernel()
       quant_max_(0),
       quant_size_(0),
       quant_delay_(0),
-      global_step_(0) {}
+      global_step_(0),
+      narrow_range_(false),
+      symmetric_(false) {}
 
 const std::vector<size_t> &FakeQuantGradGpuKernel::GetInputSizeList() const { return input_size_list_; }
@@ -59,8 +61,19 @@ bool FakeQuantGradGpuKernel::Init(const CNodePtr &kernel_node) {
     MS_LOG(EXCEPTION) << "Attr \'quant_delay_\' " << quant_delay_ << " is less than 0, require larger than 0.";
   }
-  quant_min_ = 0;
-  quant_max_ = (1 << num_bits_) - 1;
+  symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
+  if (symmetric_) {
+    quant_min_ = 0 - (1 << (num_bits_ - 1));
+    quant_max_ = (1 << (num_bits_ - 1)) - 1;
+  } else {
+    quant_min_ = 0;
+    quant_max_ = (1 << num_bits_) - 1;
+  }
+  narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
+  if (narrow_range_) {
+    quant_min_++;
+  }
   if (quant_size_ == 0) {
     quant_size_ = 1;
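The new range computation reduces to a small pure function. A minimal Python sketch mirroring the kernel logic above (the quant_range helper is hypothetical, for illustration only), with the 8-bit cases spelled out:

# Hypothetical helper mirroring the kernel's range computation above.
def quant_range(num_bits, symmetric, narrow_range):
    if symmetric:
        quant_min = -(2 ** (num_bits - 1))   # e.g. -128 for 8 bits
        quant_max = 2 ** (num_bits - 1) - 1  # e.g. 127
    else:
        quant_min = 0
        quant_max = 2 ** num_bits - 1        # e.g. 255
    if narrow_range:
        quant_min += 1                       # drop the lowest code to symmetrize the range
    return quant_min, quant_max

assert quant_range(8, False, False) == (0, 255)
assert quant_range(8, True, False) == (-128, 127)
assert quant_range(8, True, True) == (-127, 127)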
@@ -54,6 +54,8 @@ class FakeQuantGradGpuKernel : public GpuKernel {
   int quant_size_;
   int quant_delay_;
   int global_step_;
+  bool narrow_range_;
+  bool symmetric_;
 };
 }  // namespace kernel
 }  // namespace mindspore
@@ -35,6 +35,8 @@ fake_quant_grad_op_info = TBERegOp("FakeQuantWithMinMaxGrad") \
     .partial_flag(True) \
     .attr("num_bits", "optional", "int", "all") \
     .attr("quant_delay", "optional", "int", "all") \
+    .attr("symmetric", "optional", "bool", "all") \
+    .attr("narrow_range", "optional", "bool", "all") \
     .input(0, "dout", None, "required", None) \
     .input(1, "x", None, "required", None) \
     .input(2, "min", None, "required", None) \
@@ -104,8 +106,9 @@ def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, q
     return res
 
-@util.check_input_type(dict, dict, dict, dict, dict, int, int, str)
-def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx, num_bits, quant_delay,
+@util.check_input_type(dict, dict, dict, dict, dict, int, int, bool, bool, str)
+def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx,
+                                 num_bits, quant_delay, symmetric, narrow_range,
                                  kernel_name="fake_quant_with_min_max_grad"):
     """FakeQuantWithMinMaxGrad"""
     input_shape = x.get("shape")
@@ -136,8 +139,15 @@ def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx, num_bits, quant_
     input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
     shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
 
-    quant_min = 0
-    quant_max = 2 ** num_bits - 1
+    if symmetric:
+        quant_min = 0 - 2 ** (num_bits - 1)
+        quant_max = 2 ** (num_bits - 1) - 1
+    else:
+        quant_min = 0
+        quant_max = 2 ** num_bits - 1
+    if narrow_range:
+        quant_min = quant_min + 1
 
     dout_data = tvm.placeholder(input_shape, name="dout", dtype=x_dtype)
     input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
     min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
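The gradient this op computes is the usual straight-through estimator: dout passes through where x lies inside the clipping range and is zeroed elsewhere. A NumPy reference sketch, assuming standard fake-quant STE behavior (the compute body is not shown in this diff, and this ignores any min/max nudging the real kernel may perform):

import numpy as np

def fake_quant_grad_ref(dout, x, min_val, max_val):
    # Straight-through estimator: gradient flows only inside [min_val, max_val].
    inside = np.logical_and(x >= min_val, x <= max_val)
    return dout * inside.astype(dout.dtype)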
@@ -23,10 +23,10 @@ from topi.cce import util
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
 
-fake_quant_update5d_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
+fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
     .fusion_type("OPAQUE") \
     .async_flag(False) \
-    .binfile_name("fake_quant_with_min_max_update5d.so") \
+    .binfile_name("fake_quant_with_min_max_update.so") \
     .compute_cost(10) \
     .kernel_name("fake_quant_with_min_max_update") \
     .partial_flag(True) \
@@ -47,9 +47,9 @@ fake_quant_update5d_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
     .get_op_info()
 
-@op_info_register(fake_quant_update5d_op_info)
-def _fake_quant_update5d_tbe():
-    """_FakeQuantWithMinMaxUpdate5D TBE register"""
+@op_info_register(fake_quant_update_op_info)
+def _fake_quant_update_tbe():
+    """_FakeQuantWithMinMaxUpdate TBE register"""
     return
@@ -116,15 +116,17 @@ class FakeQuantWithMinMaxGrad(PrimitiveWithInfer):
     >>> _max = Tensor(np.array([2]), mindspore.float32)
     >>> result = fake_min_max_grad(dout, input_x, _min, _max)
     """
-    support_quant_bit = [4, 8]
+    support_quant_bit = [4, 7, 8]
 
     @prim_attr_register
-    def __init__(self, num_bits=8, quant_delay=0):
+    def __init__(self, num_bits=8, quant_delay=0, symmetric=False, narrow_range=False):
         if num_bits not in self.support_quant_bit:
             raise ValueError(f"For '{self.name}' attr \'num_bits\' is not supported.")
-        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
         self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
+        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
+        self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
+        self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
         self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
 
     def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
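After this change the primitive accepts the two new attributes. A hypothetical usage sketch based on the docstring example above (the _quant_ops import path is an assumption, not confirmed by this diff):

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops.operations import _quant_ops as Q  # assumed module path

# symmetric and narrow_range are the attributes added in this PR.
fake_min_max_grad = Q.FakeQuantWithMinMaxGrad(num_bits=8, quant_delay=0,
                                              symmetric=True, narrow_range=True)
dout = Tensor(np.ones((2, 2)), mindspore.float32)
input_x = Tensor(np.random.rand(2, 2).astype(np.float32), mindspore.float32)
_min = Tensor(np.array([-1]), mindspore.float32)
_max = Tensor(np.array([2]), mindspore.float32)
result = fake_min_max_grad(dout, input_x, _min, _max)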
@@ -172,7 +174,7 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
     >>> _max = Tensor(np.linspace(8, 12, 12).reshape(3, 2, 2), mindspore.float32)
     >>> result = fake_quant(input_x, _min, _max)
     """
-    support_quant_bit = [4, 8]
+    support_quant_bit = [4, 7, 8]
     channel_axis = 0
 
     @prim_attr_register
@@ -219,16 +221,18 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
     >>> _max = Tensor(np.random.randint(-2, 8, (2, 3, 4)), mindspore.float32)
     >>> result = fqmmpc_grad(dout, input_x, _min, _max)
     """
-    support_quant_bit = [4, 8]
+    support_quant_bit = [4, 7, 8]
 
     @prim_attr_register
-    def __init__(self, num_bits=8, quant_delay=0):
+    def __init__(self, num_bits=8, quant_delay=0, symmetric=False, narrow_range=False):
         """init FakeQuantWithMinMaxPerChannel Fill"""
         if num_bits not in self.support_quant_bit:
             raise ValueError(f"For '{self.name}' attr \'num_bits\' is not supported.")
-        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
         self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
+        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
+        self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
+        self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
         self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
 
     def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):