|
|
|
@@ -15,6 +15,7 @@ |
|
|
|
|
|
|
|
"""Operators for quantization.""" |
|
|
|
|
|
|
|
import mindspore.context as context |
|
|
|
from ..._checkparam import Validator as validator |
|
|
|
from ..._checkparam import Rel |
|
|
|
from ..primitive import PrimitiveWithInfer, prim_attr_register |
|
|
|
@@ -82,6 +83,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer): |
|
|
|
narrow_range=False, |
|
|
|
training=True): |
|
|
|
"""init FakeQuantPerLayer OP""" |
|
|
|
if context.get_context('device_target') == "Ascend": |
|
|
|
from mindspore.ops._op_impl._custom_op import fake_quant_perlayer |
|
|
|
if num_bits not in self.support_quant_bit: |
|
|
|
raise ValueError( |
|
|
|
f"For '{self.name}' attr \'num_bits\' is not support.") |
|
|
|
@@ -143,6 +146,8 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer): |
|
|
|
quant_delay=0, |
|
|
|
symmetric=False, |
|
|
|
narrow_range=False): |
|
|
|
if context.get_context('device_target') == "Ascend": |
|
|
|
from mindspore.ops._op_impl._custom_op import fake_quant_perlayer_grad |
|
|
|
if num_bits not in self.support_quant_bit: |
|
|
|
raise ValueError( |
|
|
|
f"For '{self.name}' attr \'num_bits\' is not support.") |
|
|
|
@@ -222,6 +227,8 @@ class FakeQuantPerChannel(PrimitiveWithInfer): |
|
|
|
training=True, |
|
|
|
channel_axis=1): |
|
|
|
"""init FakeQuantPerChannel OP""" |
|
|
|
if context.get_context('device_target') == "Ascend": |
|
|
|
from mindspore.ops._op_impl._custom_op import fake_quant_perchannel |
|
|
|
if num_bits not in self.support_quant_bit: |
|
|
|
raise ValueError( |
|
|
|
f"For '{self.name}' Attr \'num_bits\' is not support.") |
|
|
|
@@ -286,6 +293,8 @@ class FakeQuantPerChannelGrad(PrimitiveWithInfer): |
|
|
|
narrow_range=False, |
|
|
|
channel_axis=1): |
|
|
|
"""init FakeQuantPerChannelGrad Fill""" |
|
|
|
if context.get_context('device_target') == "Ascend": |
|
|
|
from mindspore.ops._op_impl._custom_op import fake_quant_perchannel_grad |
|
|
|
if num_bits not in self.support_quant_bit: |
|
|
|
raise ValueError( |
|
|
|
f"For '{self.name}' attr \'num_bits\' is not support.") |
|
|
|
@@ -454,6 +463,8 @@ class CorrectionMul(PrimitiveWithInfer): |
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, channel_axis=0): |
|
|
|
"""init correction mul layer""" |
|
|
|
if context.get_context('device_target') == "Ascend": |
|
|
|
from mindspore.ops._op_impl._custom_op import correction_mul |
|
|
|
self.channel_axis = channel_axis |
|
|
|
self.init_prim_io_names(inputs=['x', 'batch_std', 'running_std'], |
|
|
|
outputs=['out']) |
|
|
|
@@ -486,6 +497,8 @@ class CorrectionMulGrad(PrimitiveWithInfer): |
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, channel_axis=0): |
|
|
|
"""init correction mul layer""" |
|
|
|
if context.get_context('device_target') == "Ascend": |
|
|
|
from mindspore.ops._op_impl._custom_op import correction_mul_grad |
|
|
|
self.channel_axis = channel_axis |
|
|
|
self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'], |
|
|
|
outputs=['dx', 'd_gamma']) |
|
|
|
@@ -847,9 +860,8 @@ class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer): |
|
|
|
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False, |
|
|
|
training=True): |
|
|
|
"""init FakeQuantMinMaxPerLayerUpdate OP""" |
|
|
|
from mindspore.ops._op_impl._custom_op import correction_mul, correction_mul_grad |
|
|
|
from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max, fake_quant_with_min_max_grad |
|
|
|
from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max_update |
|
|
|
if context.get_context('device_target') == "Ascend": |
|
|
|
from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update |
|
|
|
if num_bits not in self.support_quant_bit: |
|
|
|
raise ValueError( |
|
|
|
f"For '{self.name}' attr \'num_bits\' is not support.") |
|
|
|
@@ -922,6 +934,8 @@ class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer): |
|
|
|
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False, |
|
|
|
training=True, channel_axis=1): |
|
|
|
"""init FakeQuantPerChannelUpdate OP for Ascend""" |
|
|
|
if context.get_context('device_target') == "Ascend": |
|
|
|
from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update |
|
|
|
if num_bits not in self.support_quant_bit: |
|
|
|
raise ValueError( |
|
|
|
f"For '{self.name}' attr \'num_bits\' is not support.") |
|
|
|
|