|
|
|
@@ -50,7 +50,7 @@ def _fake_quant_perchannel_tbe(): |
|
|
|
|
|
|
|
|
|
|
|
@fusion_manager.register("fake_quant_perchannel") |
|
|
|
def fake_quant_perchannel_compute(x, min_val, max_val, y, quant_min, quant_max, |
|
|
|
def fake_quant_perchannel_compute(x, min_val, max_val, y, quant_min, quant_max, symmetric, |
|
|
|
kernel_name="fake_quant_perchannel"): |
|
|
|
"""FakeQuantPerChannel""" |
|
|
|
x_shape = te.lang.cce.util.shape_to_list(x.shape) |
|
|
|
@@ -59,6 +59,9 @@ def fake_quant_perchannel_compute(x, min_val, max_val, y, quant_min, quant_max, |
|
|
|
quant_max = tvm.const(quant_max, x.dtype) |
|
|
|
quant_min = te.lang.cce.broadcast(quant_min, minmax_shape, x.dtype) |
|
|
|
quant_max = te.lang.cce.broadcast(quant_max, minmax_shape, x.dtype) |
|
|
|
if symmetric: |
|
|
|
max_val = te.lang.cce.vmax(te.lang.cce.vmuls(min_val, -1.), max_val) |
|
|
|
min_val = te.lang.cce.vmuls(max_val, -1.) |
|
|
|
|
|
|
|
scale = te.lang.cce.vdiv(te.lang.cce.vsub( |
|
|
|
max_val, min_val), te.lang.cce.vsub(quant_max, quant_min)) |
|
|
|
@@ -119,12 +122,8 @@ def fake_quant_perchannel(x, min_val, max_val, y, |
|
|
|
util.check_dtype_rule(min_dtype, check_list) |
|
|
|
util.check_dtype_rule(max_dtype, check_list) |
|
|
|
|
|
|
|
if symmetric: |
|
|
|
quant_min = 0 - 2 ** (num_bits - 1) |
|
|
|
quant_max = 2 ** (num_bits - 1) - 1 |
|
|
|
else: |
|
|
|
quant_min = 0 |
|
|
|
quant_max = 2 ** num_bits - 1 |
|
|
|
quant_min = 0 |
|
|
|
quant_max = 2 ** num_bits - 1 |
|
|
|
if narrow_range: |
|
|
|
quant_min = quant_min + 1 |
|
|
|
|
|
|
|
@@ -136,7 +135,7 @@ def fake_quant_perchannel(x, min_val, max_val, y, |
|
|
|
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype) |
|
|
|
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype) |
|
|
|
res = fake_quant_perchannel_compute(input_data, min_data, max_data, y, |
|
|
|
quant_min, quant_max, kernel_name) |
|
|
|
quant_min, quant_max, symmetric, kernel_name) |
|
|
|
|
|
|
|
with tvm.target.cce(): |
|
|
|
sch = generic.auto_schedule(res) |
|
|
|
|