
!1900 bug fix in fake quant ops

Merge pull request !1900 from chenzhongming/master
Tag: tags/v0.5.0-beta
Committed by mindspore-ci-bot, 5 years ago
Commit 31ecc13b59
1 changed file with 17 additions and 3 deletions

mindspore/ops/operations/_quant_ops.py  (+17, -3)

@@ -15,6 +15,7 @@
 
 """Operators for quantization."""
 
+import mindspore.context as context
 from ..._checkparam import Validator as validator
 from ..._checkparam import Rel
 from ..primitive import PrimitiveWithInfer, prim_attr_register
@@ -82,6 +83,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
                  narrow_range=False,
                  training=True):
         """init FakeQuantPerLayer OP"""
+        if context.get_context('device_target') == "Ascend":
+            from mindspore.ops._op_impl._custom_op import fake_quant_perlayer
         if num_bits not in self.support_quant_bit:
             raise ValueError(
                 f"For '{self.name}' attr \'num_bits\' is not support.")
@@ -143,6 +146,8 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):
                  quant_delay=0,
                  symmetric=False,
                  narrow_range=False):
+        if context.get_context('device_target') == "Ascend":
+            from mindspore.ops._op_impl._custom_op import fake_quant_perlayer_grad
         if num_bits not in self.support_quant_bit:
             raise ValueError(
                 f"For '{self.name}' attr \'num_bits\' is not support.")
@@ -222,6 +227,8 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
                  training=True,
                  channel_axis=1):
         """init FakeQuantPerChannel OP"""
+        if context.get_context('device_target') == "Ascend":
+            from mindspore.ops._op_impl._custom_op import fake_quant_perchannel
         if num_bits not in self.support_quant_bit:
             raise ValueError(
                 f"For '{self.name}' Attr \'num_bits\' is not support.")
@@ -286,6 +293,8 @@ class FakeQuantPerChannelGrad(PrimitiveWithInfer):
                  narrow_range=False,
                  channel_axis=1):
         """init FakeQuantPerChannelGrad Fill"""
+        if context.get_context('device_target') == "Ascend":
+            from mindspore.ops._op_impl._custom_op import fake_quant_perchannel_grad
         if num_bits not in self.support_quant_bit:
             raise ValueError(
                 f"For '{self.name}' attr \'num_bits\' is not support.")
@@ -454,6 +463,8 @@ class CorrectionMul(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, channel_axis=0):
         """init correction mul layer"""
+        if context.get_context('device_target') == "Ascend":
+            from mindspore.ops._op_impl._custom_op import correction_mul
         self.channel_axis = channel_axis
         self.init_prim_io_names(inputs=['x', 'batch_std', 'running_std'],
                                 outputs=['out'])
@@ -486,6 +497,8 @@ class CorrectionMulGrad(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, channel_axis=0):
         """init correction mul layer"""
+        if context.get_context('device_target') == "Ascend":
+            from mindspore.ops._op_impl._custom_op import correction_mul_grad
         self.channel_axis = channel_axis
         self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'],
                                 outputs=['dx', 'd_gamma'])
@@ -847,9 +860,8 @@ class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer):
     def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
                  training=True):
         """init FakeQuantMinMaxPerLayerUpdate OP"""
-        from mindspore.ops._op_impl._custom_op import correction_mul, correction_mul_grad
-        from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max, fake_quant_with_min_max_grad
-        from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max_update
+        if context.get_context('device_target') == "Ascend":
+            from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update
         if num_bits not in self.support_quant_bit:
             raise ValueError(
                 f"For '{self.name}' attr \'num_bits\' is not support.")
@@ -922,6 +934,8 @@ class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer):
     def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
                  training=True, channel_axis=1):
         """init FakeQuantPerChannelUpdate OP for Ascend"""
+        if context.get_context('device_target') == "Ascend":
+            from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update
         if num_bits not in self.support_quant_bit:
             raise ValueError(
                 f"For '{self.name}' attr \'num_bits\' is not support.")

