
add support for ReduceAny and ReduceAll on GPU

zhouyuanshen committed 5 years ago · commit e9aca01620 · tag v1.1.0
6 changed files with 211 additions and 11 deletions
  1. mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.cc (+4 -0)
  2. mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h (+12 -7)
  3. mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h (+4 -4)
  4. mindspore/ccsrc/utils/utils.h (+3 -0)
  5. tests/st/ops/gpu/test_reduce_all_op.py (+94 -0)
  6. tests/st/ops/gpu/test_reduce_any_op.py (+94 -0)

mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.cc (+4 -0)

@@ -34,5 +34,9 @@ MS_REG_GPU_KERNEL_ONE(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeFloat32).A
                       ArrayReduceGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       ArrayReduceGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(ReduceAny, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
+                      ArrayReduceGpuKernel, bool)
+MS_REG_GPU_KERNEL_ONE(ReduceAll, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
+                      ArrayReduceGpuKernel, bool)
 }  // namespace kernel
 }  // namespace mindspore
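With these registrations in place, bool tensors dispatch to the shared ArrayReduceGpuKernel. A minimal usage sketch, assuming a GPU build of MindSpore that includes this commit (op names and call signatures follow the test files below):

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

x = Tensor(np.array([[True, True], [True, False], [False, False]]))

# ReduceAny: logical OR along axis 0 -> [True, True]
print(P.ReduceAny(keep_dims=False)(x, 0))
# ReduceAll: logical AND along axis 0 -> [False, False]
print(P.ReduceAll(keep_dims=False)(x, 0))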

mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h (+12 -7)

@@ -27,10 +27,9 @@
 namespace mindspore {
 namespace kernel {
 const std::map<std::string, cudnnReduceTensorOp_t> kReduceTypeMap = {
-  {"ReduceMax", CUDNN_REDUCE_TENSOR_MAX},
-  {"ReduceMean", CUDNN_REDUCE_TENSOR_AVG},
-  {"ReduceSum", CUDNN_REDUCE_TENSOR_ADD},
-  {"ReduceMin", CUDNN_REDUCE_TENSOR_MIN},
+  {"ReduceMax", CUDNN_REDUCE_TENSOR_MAX}, {"ReduceMean", CUDNN_REDUCE_TENSOR_AVG},
+  {"ReduceSum", CUDNN_REDUCE_TENSOR_ADD}, {"ReduceMin", CUDNN_REDUCE_TENSOR_MIN},
+  {"ReduceAny", CUDNN_REDUCE_TENSOR_MAX}, {"ReduceAll", CUDNN_REDUCE_TENSOR_MUL},
 };
 template <typename T>
 class ArrayReduceGpuKernel : public GpuKernel {
@@ -72,7 +71,14 @@ class ArrayReduceGpuKernel : public GpuKernel {
   bool Init(const CNodePtr &kernel_node) override {
     kernel_node_ = kernel_node;
     InitResource();
-    data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
+    auto type_id = TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0));
+    auto node_name = AnfAlgo::GetCNodeName(kernel_node);
+    if ((node_name == kReduceAnyOpName || node_name == kReduceAllOpName) &&
+        std::strncmp(type_id, "kNumberTypeBool", std::strlen(type_id)) != 0) {
+      MS_LOG(ERROR) << "Input data type of ReduceAny or ReduceAll should be bool, but got " << type_id;
+      return false;
+    }
+    data_type_ = GetCudnnDataType(type_id);
     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
     if (input_num != 1) {
       MS_LOG(ERROR) << "Input number is " << input_num << ", but reduce op needs 1 inputs.";
@@ -185,9 +191,8 @@ class ArrayReduceGpuKernel : public GpuKernel {
     auto iter = kReduceTypeMap.find(kernel_name);
     if (iter == kReduceTypeMap.end()) {
       MS_LOG(EXCEPTION) << "Array reduce kernel type " << kernel_name << " is not supported.";
-    } else {
-      reduce_tensor_op_ = iter->second;
     }
+    reduce_tensor_op_ = iter->second;
 
     CHECK_CUDNN_RET_WITH_EXCEPT(
       kernel_node_,
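In kReduceTypeMap above, the two new ops reuse existing cuDNN reductions: over booleans viewed as {0, 1}, the maximum is 1 exactly when any element is true, and the product is 1 exactly when all elements are true. A quick NumPy check of this equivalence (illustrative sketch only, not part of the commit):

import numpy as np

x = np.array([[True, True], [True, False], [False, False]])
ints = x.astype(np.int8)  # booleans as {0, 1}

# Any == (max over the axis is 1); All == (product over the axis is 1)
assert np.array_equal(x.any(axis=0), ints.max(axis=0).astype(bool))
assert np.array_equal(x.all(axis=0), ints.prod(axis=0).astype(bool))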


mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h (+4 -4)

@@ -43,10 +43,10 @@ static constexpr char kAvgPoolingModeLowerCase[] = "avg";
 static constexpr float kSignedMinFloat = -3.402823466e+38F;
 
 // Used by mixprecision, cudnn dtype select
-static std::map<std::string, cudnnDataType_t> kCudnnDtypeMap = {{"kNumberTypeFloat32", CUDNN_DATA_FLOAT},
-                                                                {"kNumberTypeFloat16", CUDNN_DATA_HALF},
-                                                                {"kNumberTypeFloat64", CUDNN_DATA_DOUBLE},
-                                                                {"kNumberTypeInt32", CUDNN_DATA_INT32}};
+static std::map<std::string, cudnnDataType_t> kCudnnDtypeMap = {
+  {"kNumberTypeFloat32", CUDNN_DATA_FLOAT}, {"kNumberTypeFloat16", CUDNN_DATA_HALF},
+  {"kNumberTypeFloat64", CUDNN_DATA_DOUBLE}, {"kNumberTypeInt32", CUDNN_DATA_INT32},
+  {"kNumberTypeBool", CUDNN_DATA_INT8}, {"kNumberTypeInt8", CUDNN_DATA_INT8}};
 // Used by mixprecision, cuda dtype select
 static std::map<std::string, cudaDataType_t> kCudaDtypeMap = {{"kNumberTypeFloat32", CUDA_R_32F},
                                                               {"kNumberTypeFloat16", CUDA_R_16F}};


mindspore/ccsrc/utils/utils.h (+3 -0)

@@ -225,6 +225,9 @@ constexpr auto kSelectOpName = "Select";
 constexpr auto kReduceSumOpName = "ReduceSum";
 constexpr auto kReduceMinOpName = "ReduceMin";
 constexpr auto kReduceMaxOpName = "ReduceMax";
+constexpr auto kReduceMeanOpName = "ReduceMean";
+constexpr auto kReduceAnyOpName = "ReduceAny";
+constexpr auto kReduceAllOpName = "ReduceAll";
 constexpr auto kFusedWeightScaleApplyMomentum = "FusedWeightScaleApplyMomentum";
 constexpr auto kFusedScaleApplyMomentum = "FusedScaleApplyMomentum";
 constexpr auto kBasicLSTMCellWeightGradOpName = "BasicLSTMCellWeightGrad";


tests/st/ops/gpu/test_reduce_all_op.py (+94 -0, new file)

@@ -0,0 +1,94 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.ops import operations as P

x0 = np.array([[True, True], [True, False], [False, False]])
axis0 = 0
keep_dims0 = True

x1 = np.array([[True, True], [True, False], [False, False]])
axis1 = 0
keep_dims1 = False

x2 = np.array([[True, True], [True, False], [False, False]])
axis2 = 1
keep_dims2 = True

x3 = np.array([[True, True], [True, False], [False, False]])
axis3 = 1
keep_dims3 = False

context.set_context(device_target='GPU')


class ReduceAll(nn.Cell):
    def __init__(self):
        super(ReduceAll, self).__init__()

        self.x0 = Tensor(x0)
        self.axis0 = axis0
        self.keep_dims0 = keep_dims0

        self.x1 = Tensor(x1)
        self.axis1 = axis1
        self.keep_dims1 = keep_dims1

        self.x2 = Tensor(x2)
        self.axis2 = axis2
        self.keep_dims2 = keep_dims2

        self.x3 = Tensor(x3)
        self.axis3 = axis3
        self.keep_dims3 = keep_dims3

    @ms_function
    def construct(self):
        return (P.ReduceAll(self.keep_dims0)(self.x0, self.axis0),
                P.ReduceAll(self.keep_dims1)(self.x1, self.axis1),
                P.ReduceAll(self.keep_dims2)(self.x2, self.axis2),
                P.ReduceAll(self.keep_dims3)(self.x3, self.axis3))


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ReduceAll():
    reduce_all = ReduceAll()
    output = reduce_all()

    expect0 = np.all(x0, axis=axis0, keepdims=keep_dims0)
    assert np.allclose(output[0].asnumpy(), expect0)
    assert output[0].shape == expect0.shape

    expect1 = np.all(x1, axis=axis1, keepdims=keep_dims1)
    assert np.allclose(output[1].asnumpy(), expect1)
    assert output[1].shape == expect1.shape

    expect2 = np.all(x2, axis=axis2, keepdims=keep_dims2)
    assert np.allclose(output[2].asnumpy(), expect2)
    assert output[2].shape == expect2.shape

    expect3 = np.all(x3, axis=axis3, keepdims=keep_dims3)
    assert np.allclose(output[3].asnumpy(), expect3)
    assert output[3].shape == expect3.shape

tests/st/ops/gpu/test_reduce_any_op.py (+94 -0, new file)

@@ -0,0 +1,94 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.ops import operations as P

x0 = np.array([[True, True], [True, False], [False, False]])
axis0 = 0
keep_dims0 = True

x1 = np.array([[True, True], [True, False], [False, False]])
axis1 = 0
keep_dims1 = False

x2 = np.array([[True, True], [True, False], [False, False]])
axis2 = 1
keep_dims2 = True

x3 = np.array([[True, True], [True, False], [False, False]])
axis3 = 1
keep_dims3 = False

context.set_context(device_target='GPU')


class ReduceAny(nn.Cell):
    def __init__(self):
        super(ReduceAny, self).__init__()

        self.x0 = Tensor(x0)
        self.axis0 = axis0
        self.keep_dims0 = keep_dims0

        self.x1 = Tensor(x1)
        self.axis1 = axis1
        self.keep_dims1 = keep_dims1

        self.x2 = Tensor(x2)
        self.axis2 = axis2
        self.keep_dims2 = keep_dims2

        self.x3 = Tensor(x3)
        self.axis3 = axis3
        self.keep_dims3 = keep_dims3

    @ms_function
    def construct(self):
        return (P.ReduceAny(self.keep_dims0)(self.x0, self.axis0),
                P.ReduceAny(self.keep_dims1)(self.x1, self.axis1),
                P.ReduceAny(self.keep_dims2)(self.x2, self.axis2),
                P.ReduceAny(self.keep_dims3)(self.x3, self.axis3))


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ReduceAny():
    reduce_any = ReduceAny()
    output = reduce_any()

    expect0 = np.any(x0, axis=axis0, keepdims=keep_dims0)
    assert np.allclose(output[0].asnumpy(), expect0)
    assert output[0].shape == expect0.shape

    expect1 = np.any(x1, axis=axis1, keepdims=keep_dims1)
    assert np.allclose(output[1].asnumpy(), expect1)
    assert output[1].shape == expect1.shape

    expect2 = np.any(x2, axis=axis2, keepdims=keep_dims2)
    assert np.allclose(output[2].asnumpy(), expect2)
    assert output[2].shape == expect2.shape

    expect3 = np.any(x3, axis=axis3, keepdims=keep_dims3)
    assert np.allclose(output[3].asnumpy(), expect3)
    assert output[3].shape == expect3.shape
