| @@ -15,6 +15,7 @@ | |||
| """Operators for quantization.""" | |||
| import mindspore.context as context | |||
| from ..._checkparam import Validator as validator | |||
| from ..._checkparam import Rel | |||
| from ..primitive import PrimitiveWithInfer, prim_attr_register | |||
| @@ -82,6 +83,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer): | |||
| narrow_range=False, | |||
| training=True): | |||
| """init FakeQuantPerLayer OP""" | |||
| if context.get_context('device_target') == "Ascend": | |||
| from mindspore.ops._op_impl._custom_op import fake_quant_perlayer | |||
| if num_bits not in self.support_quant_bit: | |||
| raise ValueError( | |||
| f"For '{self.name}' attr \'num_bits\' is not support.") | |||
| @@ -143,6 +146,8 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer): | |||
| quant_delay=0, | |||
| symmetric=False, | |||
| narrow_range=False): | |||
| if context.get_context('device_target') == "Ascend": | |||
| from mindspore.ops._op_impl._custom_op import fake_quant_perlayer_grad | |||
| if num_bits not in self.support_quant_bit: | |||
| raise ValueError( | |||
| f"For '{self.name}' attr \'num_bits\' is not support.") | |||
| @@ -222,6 +227,8 @@ class FakeQuantPerChannel(PrimitiveWithInfer): | |||
| training=True, | |||
| channel_axis=1): | |||
| """init FakeQuantPerChannel OP""" | |||
| if context.get_context('device_target') == "Ascend": | |||
| from mindspore.ops._op_impl._custom_op import fake_quant_perchannel | |||
| if num_bits not in self.support_quant_bit: | |||
| raise ValueError( | |||
| f"For '{self.name}' Attr \'num_bits\' is not support.") | |||
| @@ -286,6 +293,8 @@ class FakeQuantPerChannelGrad(PrimitiveWithInfer): | |||
| narrow_range=False, | |||
| channel_axis=1): | |||
| """init FakeQuantPerChannelGrad Fill""" | |||
| if context.get_context('device_target') == "Ascend": | |||
| from mindspore.ops._op_impl._custom_op import fake_quant_perchannel_grad | |||
| if num_bits not in self.support_quant_bit: | |||
| raise ValueError( | |||
| f"For '{self.name}' attr \'num_bits\' is not support.") | |||
| @@ -454,6 +463,8 @@ class CorrectionMul(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self, channel_axis=0): | |||
| """init correction mul layer""" | |||
| if context.get_context('device_target') == "Ascend": | |||
| from mindspore.ops._op_impl._custom_op import correction_mul | |||
| self.channel_axis = channel_axis | |||
| self.init_prim_io_names(inputs=['x', 'batch_std', 'running_std'], | |||
| outputs=['out']) | |||
| @@ -486,6 +497,8 @@ class CorrectionMulGrad(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self, channel_axis=0): | |||
| """init correction mul layer""" | |||
| if context.get_context('device_target') == "Ascend": | |||
| from mindspore.ops._op_impl._custom_op import correction_mul_grad | |||
| self.channel_axis = channel_axis | |||
| self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'], | |||
| outputs=['dx', 'd_gamma']) | |||
| @@ -847,9 +860,8 @@ class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer): | |||
| def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False, | |||
| training=True): | |||
| """init FakeQuantMinMaxPerLayerUpdate OP""" | |||
| from mindspore.ops._op_impl._custom_op import correction_mul, correction_mul_grad | |||
| from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max, fake_quant_with_min_max_grad | |||
| from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max_update | |||
| if context.get_context('device_target') == "Ascend": | |||
| from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update | |||
| if num_bits not in self.support_quant_bit: | |||
| raise ValueError( | |||
| f"For '{self.name}' attr \'num_bits\' is not support.") | |||
| @@ -922,6 +934,8 @@ class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer): | |||
| def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False, | |||
| training=True, channel_axis=1): | |||
| """init FakeQuantPerChannelUpdate OP for Ascend""" | |||
| if context.get_context('device_target') == "Ascend": | |||
| from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update | |||
| if num_bits not in self.support_quant_bit: | |||
| raise ValueError( | |||
| f"For '{self.name}' attr \'num_bits\' is not support.") | |||