diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py index e1aa5630ba..6aa30ab2f3 100644 --- a/mindspore/ops/operations/_quant_ops.py +++ b/mindspore/ops/operations/_quant_ops.py @@ -15,6 +15,7 @@ """Operators for quantization.""" +import mindspore.context as context from ..._checkparam import Validator as validator from ..._checkparam import Rel from ..primitive import PrimitiveWithInfer, prim_attr_register @@ -82,6 +83,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer): narrow_range=False, training=True): """init FakeQuantPerLayer OP""" + if context.get_context('device_target') == "Ascend": + from mindspore.ops._op_impl._custom_op import fake_quant_perlayer if num_bits not in self.support_quant_bit: raise ValueError( f"For '{self.name}' attr \'num_bits\' is not support.") @@ -143,6 +146,8 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer): quant_delay=0, symmetric=False, narrow_range=False): + if context.get_context('device_target') == "Ascend": + from mindspore.ops._op_impl._custom_op import fake_quant_perlayer_grad if num_bits not in self.support_quant_bit: raise ValueError( f"For '{self.name}' attr \'num_bits\' is not support.") @@ -222,6 +227,8 @@ class FakeQuantPerChannel(PrimitiveWithInfer): training=True, channel_axis=1): """init FakeQuantPerChannel OP""" + if context.get_context('device_target') == "Ascend": + from mindspore.ops._op_impl._custom_op import fake_quant_perchannel if num_bits not in self.support_quant_bit: raise ValueError( f"For '{self.name}' Attr \'num_bits\' is not support.") @@ -286,6 +293,8 @@ class FakeQuantPerChannelGrad(PrimitiveWithInfer): narrow_range=False, channel_axis=1): """init FakeQuantPerChannelGrad Fill""" + if context.get_context('device_target') == "Ascend": + from mindspore.ops._op_impl._custom_op import fake_quant_perchannel_grad if num_bits not in self.support_quant_bit: raise ValueError( f"For '{self.name}' attr \'num_bits\' is not support.") @@ -454,6 +463,8 @@ class CorrectionMul(PrimitiveWithInfer): @prim_attr_register def __init__(self, channel_axis=0): """init correction mul layer""" + if context.get_context('device_target') == "Ascend": + from mindspore.ops._op_impl._custom_op import correction_mul self.channel_axis = channel_axis self.init_prim_io_names(inputs=['x', 'batch_std', 'running_std'], outputs=['out']) @@ -486,6 +497,8 @@ class CorrectionMulGrad(PrimitiveWithInfer): @prim_attr_register def __init__(self, channel_axis=0): """init correction mul layer""" + if context.get_context('device_target') == "Ascend": + from mindspore.ops._op_impl._custom_op import correction_mul_grad self.channel_axis = channel_axis self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'], outputs=['dx', 'd_gamma']) @@ -847,9 +860,8 @@ class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer): def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False, training=True): """init FakeQuantMinMaxPerLayerUpdate OP""" - from mindspore.ops._op_impl._custom_op import correction_mul, correction_mul_grad - from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max, fake_quant_with_min_max_grad - from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max_update + if context.get_context('device_target') == "Ascend": + from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update if num_bits not in self.support_quant_bit: raise ValueError( f"For '{self.name}' attr \'num_bits\' is not support.") @@ -922,6 +934,8 @@ class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer): def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False, training=True, channel_axis=1): """init FakeQuantPerChannelUpdate OP for Ascend""" + if context.get_context('device_target') == "Ascend": + from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update if num_bits not in self.support_quant_bit: raise ValueError( f"For '{self.name}' attr \'num_bits\' is not support.")