
fix failure to obtain quant op info

pull/15976/head
Erpim, 4 years ago
commit 19c18eafba
2 changed files with 16 additions and 0 deletions
  1. mindspore/ops/_op_impl/_custom_op/__init__.py (+13, -0)
  2. mindspore/ops/operations/_quant_ops.py (+3, -0)

mindspore/ops/_op_impl/_custom_op/__init__.py (+13, -0)

@@ -14,9 +14,22 @@
 # ============================================================================
 
 """custom ops"""
+from .batchnorm_fold import _batchnorm_fold_tbe
+from .batchnorm_fold2 import _batchnorm_fold2_tbe
+from .batchnorm_fold2_grad import _batchnorm_fold2_grad_tbe
+from .batchnorm_fold2_grad_reduce import _batchnorm_fold2_grad_reduce_tbe
+from .batchnorm_fold_grad import _batchnorm_fold_grad_tbe
+from .correction_mul import _correction_mul_tbe
+from .correction_mul_grad import _correction_mul_grad_tbe
 from .fake_learned_scale_quant_perlayer import _fake_learned_scale_quant_perlayer_tbe
 from .fake_learned_scale_quant_perlayer_grad import _fake_learned_scale_quant_perlayer_grad_d_tbe
 from .fake_learned_scale_quant_perlayer_grad_reduce import _fake_learned_scale_quant_perlayer_grad_d_reduce_tbe
 from .fake_learned_scale_quant_perchannel import _fake_learned_scale_quant_perchannel_tbe
 from .fake_learned_scale_quant_perchannel_grad import _fake_learned_scale_quant_perchannel_grad_d_tbe
 from .fake_learned_scale_quant_perchannel_grad_reduce import _fake_learned_scale_quant_perchannel_grad_d_reduce_tbe
+from .fake_quant_perchannel import _fake_quant_perchannel_tbe
+from .fake_quant_perchannel_grad import _fake_quant_perchannel_grad_tbe
+from .fake_quant_perlayer import _fake_quant_per_layer_tbe
+from .fake_quant_perlayer_grad import _fake_quant_per_layer_grad_tbe
+from .minmax_update_perchannel import _minmax_update_perchannel_tbe
+from .minmax_update_perlayer import _minmax_update_perlayer_tbe
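
These imports matter because each _xxx_tbe module publishes its TBE op info as a side effect of being imported: until the package's __init__.py imports a module, the framework cannot find the op info for that quant op, which is the failure this commit fixes. A minimal sketch of that registration pattern follows; the registry dict and the op_info contents are illustrative stand-ins for MindSpore's internal op info database, not the real internals:

# Hypothetical registry standing in for MindSpore's internal op info database.
_OP_INFO_REGISTRY = {}

def op_info_register(op_info):
    """Record op_info for the decorated stub when its module is imported."""
    def decorator(func):
        _OP_INFO_REGISTRY[func.__name__] = op_info  # runs at import time
        return func
    return decorator

# Each _custom_op module decorates a stub like this, so merely importing the
# module is enough to publish the op info.
@op_info_register({"op_name": "FakeQuantPerLayer", "imply_type": "TBE"})
def _fake_quant_per_layer_tbe():
    """FakeQuantPerLayer TBE register"""
    return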

mindspore/ops/operations/_quant_ops.py (+3, -0)

@@ -22,6 +22,9 @@ from ..._checkparam import Rel
 from ..primitive import PrimitiveWithInfer, prim_attr_register
 from ...common import dtype as mstype
 
+if context.get_context('device_target') == "Ascend":
+    import mindspore.ops._op_impl._custom_op
+
 __all__ = ["MinMaxUpdatePerLayer",
            "MinMaxUpdatePerChannel",
            "FakeLearnedScaleQuantPerLayer",

