Browse Source

!2839 support vm for PopulationCount

Merge pull request !2839 from jiangjinsheng/vm_population_count
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
3d377c51b9
5 changed files with 77 additions and 3 deletions
  1. +1
    -0
      mindspore/ops/_op_impl/tbe/__init__.py
  2. +38
    -0
      mindspore/ops/_op_impl/tbe/population_count.py
  3. +3
    -2
      mindspore/ops/operations/__init__.py
  4. +31
    -0
      mindspore/ops/operations/other_ops.py
  5. +4
    -1
      tests/ut/python/ops/test_ops.py

+ 1
- 0
mindspore/ops/_op_impl/tbe/__init__.py View File

@@ -284,3 +284,4 @@ from .scatter_div import _scatter_div_tbe
from .mod import _mod_tbe from .mod import _mod_tbe
from .max_pool_grad_grad import _max_pool_grad_grad_tbe from .max_pool_grad_grad import _max_pool_grad_grad_tbe
from .max_pool_grad_grad_with_argmax import _max_pool_grad_grad_with_argmax_tbe from .max_pool_grad_grad_with_argmax import _max_pool_grad_grad_with_argmax_tbe
from .population_count import _population_count_tbe

+ 38
- 0
mindspore/ops/_op_impl/tbe/population_count.py View File

@@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""PopulationCount op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

population_count_op_info = TBERegOp("PopulationCount") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("population_count.so") \
.compute_cost(10) \
.kernel_name("population_count") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I16_5HD, DataType.U8_5HD) \
.dtype_format(DataType.I16_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_5HD, DataType.U8_5HD) \
.dtype_format(DataType.U16_Default, DataType.U8_Default) \
.get_op_info()


@op_info_register(population_count_op_info)
def _population_count_tbe():
"""PopulationCount TBE register"""
return

+ 3
- 2
mindspore/ops/operations/__init__.py View File

@@ -76,7 +76,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2,
ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent, ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent,
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK) ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK)
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
CheckValid, MakeRefKey, Partial, Depend, CheckBprop) CheckValid, MakeRefKey, Partial, Depend, CheckBprop)
from .thor_ops import * from .thor_ops import *


@@ -328,7 +328,8 @@ __all__ = [
"InplaceUpdate", "InplaceUpdate",
"InTopK", "InTopK",
"LRN", "LRN",
"Mod"
"Mod",
"PopulationCount"
] ]


__all__.sort() __all__.sort()

+ 31
- 0
mindspore/ops/operations/other_ops.py View File

@@ -51,6 +51,7 @@ class Assign(PrimitiveWithInfer):
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), ('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) ('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
) )

@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output']) self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output'])
@@ -324,6 +325,7 @@ class Partial(Primitive):
partial_func = functools.partial(func, *args[1:]) partial_func = functools.partial(func, *args[1:])
return partial_func return partial_func



class Depend(Primitive): class Depend(Primitive):
""" """
Depend is used for process side-effect operations. Depend is used for process side-effect operations.
@@ -457,3 +459,32 @@ class ConfusionMatrix(PrimitiveWithInfer):
args = {"labels": labels, "predictions": predictions} args = {"labels": labels, "predictions": predictions}
validator.check_tensor_type_same(args, (mstype.number_type), self.name) validator.check_tensor_type_same(args, (mstype.number_type), self.name)
return labels return labels


class PopulationCount(PrimitiveWithInfer):
r"""
Calculate population count.

Inputs:
- **input** (Tensor) - The data type should be int16 or uint16.

Outputs:
Tensor, with shape same as the input.

Examples:
>>> population_count = P.PopulationCount()
>>> x_input = Tensor([0, 1, 3], mindspore.int16)
>>> population_count(x_input)
"""

@prim_attr_register
def __init__(self):
pass

def infer_shape(self, x_shape):
return x_shape

def infer_dtype(self, x_dtype):
args = {"x": x_dtype}
validator.check_tensor_type_same(args, (mstype.int16, mstype.uint16,), self.name)
return mstype.tensor_type(mstype.uint8)

+ 4
- 1
tests/ut/python/ops/test_ops.py View File

@@ -2143,7 +2143,10 @@ test_case_other_ops = [
'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)), 'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)),
Tensor(np.array([1.2]).astype(np.float32))], Tensor(np.array([1.2]).astype(np.float32))],
'skip': ['backward']}), 'skip': ['backward']}),

('PopulationCount', {
'block': P.PopulationCount(),
'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.int16))],
'skip': ['backward']}),
] ]


test_case_quant_ops = [ test_case_quant_ops = [


Loading…
Cancel
Save