From e9aca0162019ffbf4a85c76f7360d94e83c83e25 Mon Sep 17 00:00:00 2001 From: zhouyuanshen Date: Tue, 8 Dec 2020 21:30:57 +0800 Subject: [PATCH] add support to reduceAny and reduceAll on gpu --- .../gpu/arrays/array_reduce_gpu_kernel.cc | 4 + .../gpu/arrays/array_reduce_gpu_kernel.h | 19 ++-- .../kernel_compiler/gpu/kernel_constants.h | 8 +- mindspore/ccsrc/utils/utils.h | 3 + tests/st/ops/gpu/test_reduce_all_op.py | 94 +++++++++++++++++++ tests/st/ops/gpu/test_reduce_any_op.py | 94 +++++++++++++++++++ 6 files changed, 211 insertions(+), 11 deletions(-) create mode 100644 tests/st/ops/gpu/test_reduce_all_op.py create mode 100644 tests/st/ops/gpu/test_reduce_any_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.cc index 3e7cb788ea..9af961647b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.cc @@ -34,5 +34,9 @@ MS_REG_GPU_KERNEL_ONE(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeFloat32).A ArrayReduceGpuKernel, float) MS_REG_GPU_KERNEL_ONE(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), ArrayReduceGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(ReduceAny, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), + ArrayReduceGpuKernel, bool) +MS_REG_GPU_KERNEL_ONE(ReduceAll, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), + ArrayReduceGpuKernel, bool) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h index 48cd0cd36c..49d2ef9b8c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h @@ -27,10 +27,9 @@ namespace mindspore { namespace kernel { const std::map kReduceTypeMap = { - {"ReduceMax", CUDNN_REDUCE_TENSOR_MAX}, - {"ReduceMean", CUDNN_REDUCE_TENSOR_AVG}, - {"ReduceSum", CUDNN_REDUCE_TENSOR_ADD}, - {"ReduceMin", CUDNN_REDUCE_TENSOR_MIN}, + {"ReduceMax", CUDNN_REDUCE_TENSOR_MAX}, {"ReduceMean", CUDNN_REDUCE_TENSOR_AVG}, + {"ReduceSum", CUDNN_REDUCE_TENSOR_ADD}, {"ReduceMin", CUDNN_REDUCE_TENSOR_MIN}, + {"ReduceAny", CUDNN_REDUCE_TENSOR_MAX}, {"ReduceAll", CUDNN_REDUCE_TENSOR_MUL}, }; template class ArrayReduceGpuKernel : public GpuKernel { @@ -72,7 +71,14 @@ class ArrayReduceGpuKernel : public GpuKernel { bool Init(const CNodePtr &kernel_node) override { kernel_node_ = kernel_node; InitResource(); - data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + auto type_id = TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)); + auto node_name = AnfAlgo::GetCNodeName(kernel_node); + if ((node_name == kReduceAnyOpName || node_name == kReduceAllOpName) && + std::strncmp(type_id, "kNumberTypeBool", std::strlen(type_id)) != 0) { + MS_LOG(ERROR) << "Input data type of ReduceAny or ReduceAll should be bool, but got " << type_id; + return false; + } + data_type_ = GetCudnnDataType(type_id); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); if (input_num != 1) { MS_LOG(ERROR) << "Input number is " << input_num << ", but reduce op needs 1 inputs."; @@ -185,9 +191,8 @@ class ArrayReduceGpuKernel : public GpuKernel { auto iter = kReduceTypeMap.find(kernel_name); if (iter == kReduceTypeMap.end()) { MS_LOG(EXCEPTION) << "Array reduce kernel type " << kernel_name << " is not supported."; - } else { - reduce_tensor_op_ = iter->second; } + reduce_tensor_op_ = iter->second; CHECK_CUDNN_RET_WITH_EXCEPT( kernel_node_, diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h index 9dce244774..8648cfac6b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h @@ -43,10 +43,10 @@ static constexpr char kAvgPoolingModeLowerCase[] = "avg"; static constexpr float kSignedMinFloat = -3.402823466e+38F; // Used by mixprecision, cudnn dtype select -static std::map kCudnnDtypeMap = {{"kNumberTypeFloat32", CUDNN_DATA_FLOAT}, - {"kNumberTypeFloat16", CUDNN_DATA_HALF}, - {"kNumberTypeFloat64", CUDNN_DATA_DOUBLE}, - {"kNumberTypeInt32", CUDNN_DATA_INT32}}; +static std::map kCudnnDtypeMap = { + {"kNumberTypeFloat32", CUDNN_DATA_FLOAT}, {"kNumberTypeFloat16", CUDNN_DATA_HALF}, + {"kNumberTypeFloat64", CUDNN_DATA_DOUBLE}, {"kNumberTypeInt32", CUDNN_DATA_INT32}, + {"kNumberTypeBool", CUDNN_DATA_INT8}, {"kNumberTypeInt8", CUDNN_DATA_INT8}}; // Used by mixprecision, cuda dtype select static std::map kCudaDtypeMap = {{"kNumberTypeFloat32", CUDA_R_32F}, {"kNumberTypeFloat16", CUDA_R_16F}}; diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index dd50c95c53..0478f4145b 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -225,6 +225,9 @@ constexpr auto kSelectOpName = "Select"; constexpr auto kReduceSumOpName = "ReduceSum"; constexpr auto kReduceMinOpName = "ReduceMin"; constexpr auto kReduceMaxOpName = "ReduceMax"; +constexpr auto kReduceMeanOpName = "ReduceMean"; +constexpr auto kReduceAnyOpName = "ReduceAny"; +constexpr auto kReduceAllOpName = "ReduceAll"; constexpr auto kFusedWeightScaleApplyMomentum = "FusedWeightScaleApplyMomentum"; constexpr auto kFusedScaleApplyMomentum = "FusedScaleApplyMomentum"; constexpr auto kBasicLSTMCellWeightGradOpName = "BasicLSTMCellWeightGrad"; diff --git a/tests/st/ops/gpu/test_reduce_all_op.py b/tests/st/ops/gpu/test_reduce_all_op.py new file mode 100644 index 0000000000..2b87063bb9 --- /dev/null +++ b/tests/st/ops/gpu/test_reduce_all_op.py @@ -0,0 +1,94 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.api import ms_function +from mindspore.ops import operations as P + +x0 = np.array([[True, True], [True, False], [False, False]]) +axis0 = 0 +keep_dims0 = True + +x1 = np.array([[True, True], [True, False], [False, False]]) +axis1 = 0 +keep_dims1 = False + +x2 = np.array([[True, True], [True, False], [False, False]]) +axis2 = 1 +keep_dims2 = True + +x3 = np.array([[True, True], [True, False], [False, False]]) +axis3 = 1 +keep_dims3 = False + +context.set_context(device_target='GPU') + + +class ReduceAll(nn.Cell): + def __init__(self): + super(ReduceAll, self).__init__() + + self.x0 = Tensor(x0) + self.axis0 = axis0 + self.keep_dims0 = keep_dims0 + + self.x1 = Tensor(x1) + self.axis1 = axis1 + self.keep_dims1 = keep_dims1 + + self.x2 = Tensor(x2) + self.axis2 = axis2 + self.keep_dims2 = keep_dims2 + + self.x3 = Tensor(x3) + self.axis3 = axis3 + self.keep_dims3 = keep_dims3 + + + @ms_function + def construct(self): + return (P.ReduceAll(self.keep_dims0)(self.x0, self.axis0), + P.ReduceAll(self.keep_dims1)(self.x1, self.axis1), + P.ReduceAll(self.keep_dims2)(self.x2, self.axis2), + P.ReduceAll(self.keep_dims3)(self.x3, self.axis3)) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_ReduceAll(): + reduce_all = ReduceAll() + output = reduce_all() + + expect0 = np.all(x0, axis=axis0, keepdims=keep_dims0) + np.allclose(output[0].asnumpy(), expect0) + assert output[0].shape == expect0.shape + + expect1 = np.all(x1, axis=axis1, keepdims=keep_dims1) + np.allclose(output[1].asnumpy(), expect1) + assert output[1].shape == expect1.shape + + expect2 = np.all(x2, axis=axis2, keepdims=keep_dims2) + np.allclose(output[2].asnumpy(), expect2) + assert output[2].shape == expect2.shape + + expect3 = np.all(x3, axis=axis3, keepdims=keep_dims3) + np.allclose(output[3].asnumpy(), expect3) + assert output[3].shape == expect3.shape diff --git a/tests/st/ops/gpu/test_reduce_any_op.py b/tests/st/ops/gpu/test_reduce_any_op.py new file mode 100644 index 0000000000..c2e579bc40 --- /dev/null +++ b/tests/st/ops/gpu/test_reduce_any_op.py @@ -0,0 +1,94 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.api import ms_function +from mindspore.ops import operations as P + +x0 = np.array([[True, True], [True, False], [False, False]]) +axis0 = 0 +keep_dims0 = True + +x1 = np.array([[True, True], [True, False], [False, False]]) +axis1 = 0 +keep_dims1 = False + +x2 = np.array([[True, True], [True, False], [False, False]]) +axis2 = 1 +keep_dims2 = True + +x3 = np.array([[True, True], [True, False], [False, False]]) +axis3 = 1 +keep_dims3 = False + +context.set_context(device_target='GPU') + + +class ReduceAny(nn.Cell): + def __init__(self): + super(ReduceAny, self).__init__() + + self.x0 = Tensor(x0) + self.axis0 = axis0 + self.keep_dims0 = keep_dims0 + + self.x1 = Tensor(x1) + self.axis1 = axis1 + self.keep_dims1 = keep_dims1 + + self.x2 = Tensor(x2) + self.axis2 = axis2 + self.keep_dims2 = keep_dims2 + + self.x3 = Tensor(x3) + self.axis3 = axis3 + self.keep_dims3 = keep_dims3 + + + @ms_function + def construct(self): + return (P.ReduceAny(self.keep_dims0)(self.x0, self.axis0), + P.ReduceAny(self.keep_dims1)(self.x1, self.axis1), + P.ReduceAny(self.keep_dims2)(self.x2, self.axis2), + P.ReduceAny(self.keep_dims3)(self.x3, self.axis3)) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_ReduceAny(): + reduce_any = ReduceAny() + output = reduce_any() + + expect0 = np.all(x0, axis=axis0, keepdims=keep_dims0) + np.allclose(output[0].asnumpy(), expect0) + assert output[0].shape == expect0.shape + + expect1 = np.all(x1, axis=axis1, keepdims=keep_dims1) + np.allclose(output[1].asnumpy(), expect1) + assert output[1].shape == expect1.shape + + expect2 = np.all(x2, axis=axis2, keepdims=keep_dims2) + np.allclose(output[2].asnumpy(), expect2) + assert output[2].shape == expect2.shape + + expect3 = np.all(x3, axis=axis3, keepdims=keep_dims3) + np.allclose(output[3].asnumpy(), expect3) + assert output[3].shape == expect3.shape