diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py index e6ee62f6ec..fad50318a9 100644 --- a/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py @@ -50,7 +50,7 @@ def _fake_quant_perchannel_tbe(): @fusion_manager.register("fake_quant_perchannel") -def fake_quant_perchannel_compute(x, min_val, max_val, y, quant_min, quant_max, +def fake_quant_perchannel_compute(x, min_val, max_val, y, quant_min, quant_max, symmetric, kernel_name="fake_quant_perchannel"): """FakeQuantPerChannel""" x_shape = te.lang.cce.util.shape_to_list(x.shape) @@ -59,6 +59,9 @@ def fake_quant_perchannel_compute(x, min_val, max_val, y, quant_min, quant_max, quant_max = tvm.const(quant_max, x.dtype) quant_min = te.lang.cce.broadcast(quant_min, minmax_shape, x.dtype) quant_max = te.lang.cce.broadcast(quant_max, minmax_shape, x.dtype) + if symmetric: + max_val = te.lang.cce.vmax(te.lang.cce.vmuls(min_val, -1.), max_val) + min_val = te.lang.cce.vmuls(max_val, -1.) scale = te.lang.cce.vdiv(te.lang.cce.vsub( max_val, min_val), te.lang.cce.vsub(quant_max, quant_min)) @@ -119,12 +122,8 @@ def fake_quant_perchannel(x, min_val, max_val, y, util.check_dtype_rule(min_dtype, check_list) util.check_dtype_rule(max_dtype, check_list) - if symmetric: - quant_min = 0 - 2 ** (num_bits - 1) - quant_max = 2 ** (num_bits - 1) - 1 - else: - quant_min = 0 - quant_max = 2 ** num_bits - 1 + quant_min = 0 + quant_max = 2 ** num_bits - 1 if narrow_range: quant_min = quant_min + 1 @@ -136,7 +135,7 @@ def fake_quant_perchannel(x, min_val, max_val, y, min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype) max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype) res = fake_quant_perchannel_compute(input_data, min_data, max_data, y, - quant_min, quant_max, kernel_name) + quant_min, quant_max, symmetric, kernel_name) with tvm.target.cce(): sch = generic.auto_schedule(res)