| @@ -34,5 +34,9 @@ MS_REG_GPU_KERNEL_ONE(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeFloat32).A | |||||
| ArrayReduceGpuKernel, float) | ArrayReduceGpuKernel, float) | ||||
| MS_REG_GPU_KERNEL_ONE(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | MS_REG_GPU_KERNEL_ONE(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | ||||
| ArrayReduceGpuKernel, half) | ArrayReduceGpuKernel, half) | ||||
| MS_REG_GPU_KERNEL_ONE(ReduceAny, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), | |||||
| ArrayReduceGpuKernel, bool) | |||||
| MS_REG_GPU_KERNEL_ONE(ReduceAll, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), | |||||
| ArrayReduceGpuKernel, bool) | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -27,10 +27,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| const std::map<std::string, cudnnReduceTensorOp_t> kReduceTypeMap = { | const std::map<std::string, cudnnReduceTensorOp_t> kReduceTypeMap = { | ||||
| {"ReduceMax", CUDNN_REDUCE_TENSOR_MAX}, | |||||
| {"ReduceMean", CUDNN_REDUCE_TENSOR_AVG}, | |||||
| {"ReduceSum", CUDNN_REDUCE_TENSOR_ADD}, | |||||
| {"ReduceMin", CUDNN_REDUCE_TENSOR_MIN}, | |||||
| {"ReduceMax", CUDNN_REDUCE_TENSOR_MAX}, {"ReduceMean", CUDNN_REDUCE_TENSOR_AVG}, | |||||
| {"ReduceSum", CUDNN_REDUCE_TENSOR_ADD}, {"ReduceMin", CUDNN_REDUCE_TENSOR_MIN}, | |||||
| {"ReduceAny", CUDNN_REDUCE_TENSOR_MAX}, {"ReduceAll", CUDNN_REDUCE_TENSOR_MUL}, | |||||
| }; | }; | ||||
| template <typename T> | template <typename T> | ||||
| class ArrayReduceGpuKernel : public GpuKernel { | class ArrayReduceGpuKernel : public GpuKernel { | ||||
| @@ -72,7 +71,14 @@ class ArrayReduceGpuKernel : public GpuKernel { | |||||
| bool Init(const CNodePtr &kernel_node) override { | bool Init(const CNodePtr &kernel_node) override { | ||||
| kernel_node_ = kernel_node; | kernel_node_ = kernel_node; | ||||
| InitResource(); | InitResource(); | ||||
| data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| auto type_id = TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)); | |||||
| auto node_name = AnfAlgo::GetCNodeName(kernel_node); | |||||
| if ((node_name == kReduceAnyOpName || node_name == kReduceAllOpName) && | |||||
| std::strncmp(type_id, "kNumberTypeBool", std::strlen(type_id)) != 0) { | |||||
| MS_LOG(ERROR) << "Input data type of ReduceAny or ReduceAll should be bool, but got " << type_id; | |||||
| return false; | |||||
| } | |||||
| data_type_ = GetCudnnDataType(type_id); | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | ||||
| if (input_num != 1) { | if (input_num != 1) { | ||||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but reduce op needs 1 inputs."; | MS_LOG(ERROR) << "Input number is " << input_num << ", but reduce op needs 1 inputs."; | ||||
| @@ -185,9 +191,8 @@ class ArrayReduceGpuKernel : public GpuKernel { | |||||
| auto iter = kReduceTypeMap.find(kernel_name); | auto iter = kReduceTypeMap.find(kernel_name); | ||||
| if (iter == kReduceTypeMap.end()) { | if (iter == kReduceTypeMap.end()) { | ||||
| MS_LOG(EXCEPTION) << "Array reduce kernel type " << kernel_name << " is not supported."; | MS_LOG(EXCEPTION) << "Array reduce kernel type " << kernel_name << " is not supported."; | ||||
| } else { | |||||
| reduce_tensor_op_ = iter->second; | |||||
| } | } | ||||
| reduce_tensor_op_ = iter->second; | |||||
| CHECK_CUDNN_RET_WITH_EXCEPT( | CHECK_CUDNN_RET_WITH_EXCEPT( | ||||
| kernel_node_, | kernel_node_, | ||||
| @@ -43,10 +43,10 @@ static constexpr char kAvgPoolingModeLowerCase[] = "avg"; | |||||
| static constexpr float kSignedMinFloat = -3.402823466e+38F; | static constexpr float kSignedMinFloat = -3.402823466e+38F; | ||||
| // Used by mixprecision, cudnn dtype select | // Used by mixprecision, cudnn dtype select | ||||
| static std::map<std::string, cudnnDataType_t> kCudnnDtypeMap = {{"kNumberTypeFloat32", CUDNN_DATA_FLOAT}, | |||||
| {"kNumberTypeFloat16", CUDNN_DATA_HALF}, | |||||
| {"kNumberTypeFloat64", CUDNN_DATA_DOUBLE}, | |||||
| {"kNumberTypeInt32", CUDNN_DATA_INT32}}; | |||||
| static std::map<std::string, cudnnDataType_t> kCudnnDtypeMap = { | |||||
| {"kNumberTypeFloat32", CUDNN_DATA_FLOAT}, {"kNumberTypeFloat16", CUDNN_DATA_HALF}, | |||||
| {"kNumberTypeFloat64", CUDNN_DATA_DOUBLE}, {"kNumberTypeInt32", CUDNN_DATA_INT32}, | |||||
| {"kNumberTypeBool", CUDNN_DATA_INT8}, {"kNumberTypeInt8", CUDNN_DATA_INT8}}; | |||||
| // Used by mixprecision, cuda dtype select | // Used by mixprecision, cuda dtype select | ||||
| static std::map<std::string, cudaDataType_t> kCudaDtypeMap = {{"kNumberTypeFloat32", CUDA_R_32F}, | static std::map<std::string, cudaDataType_t> kCudaDtypeMap = {{"kNumberTypeFloat32", CUDA_R_32F}, | ||||
| {"kNumberTypeFloat16", CUDA_R_16F}}; | {"kNumberTypeFloat16", CUDA_R_16F}}; | ||||
| @@ -225,6 +225,9 @@ constexpr auto kSelectOpName = "Select"; | |||||
| constexpr auto kReduceSumOpName = "ReduceSum"; | constexpr auto kReduceSumOpName = "ReduceSum"; | ||||
| constexpr auto kReduceMinOpName = "ReduceMin"; | constexpr auto kReduceMinOpName = "ReduceMin"; | ||||
| constexpr auto kReduceMaxOpName = "ReduceMax"; | constexpr auto kReduceMaxOpName = "ReduceMax"; | ||||
| constexpr auto kReduceMeanOpName = "ReduceMean"; | |||||
| constexpr auto kReduceAnyOpName = "ReduceAny"; | |||||
| constexpr auto kReduceAllOpName = "ReduceAll"; | |||||
| constexpr auto kFusedWeightScaleApplyMomentum = "FusedWeightScaleApplyMomentum"; | constexpr auto kFusedWeightScaleApplyMomentum = "FusedWeightScaleApplyMomentum"; | ||||
| constexpr auto kFusedScaleApplyMomentum = "FusedScaleApplyMomentum"; | constexpr auto kFusedScaleApplyMomentum = "FusedScaleApplyMomentum"; | ||||
| constexpr auto kBasicLSTMCellWeightGradOpName = "BasicLSTMCellWeightGrad"; | constexpr auto kBasicLSTMCellWeightGradOpName = "BasicLSTMCellWeightGrad"; | ||||
| @@ -0,0 +1,94 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.common.api import ms_function | |||||
| from mindspore.ops import operations as P | |||||
| x0 = np.array([[True, True], [True, False], [False, False]]) | |||||
| axis0 = 0 | |||||
| keep_dims0 = True | |||||
| x1 = np.array([[True, True], [True, False], [False, False]]) | |||||
| axis1 = 0 | |||||
| keep_dims1 = False | |||||
| x2 = np.array([[True, True], [True, False], [False, False]]) | |||||
| axis2 = 1 | |||||
| keep_dims2 = True | |||||
| x3 = np.array([[True, True], [True, False], [False, False]]) | |||||
| axis3 = 1 | |||||
| keep_dims3 = False | |||||
| context.set_context(device_target='GPU') | |||||
| class ReduceAll(nn.Cell): | |||||
| def __init__(self): | |||||
| super(ReduceAll, self).__init__() | |||||
| self.x0 = Tensor(x0) | |||||
| self.axis0 = axis0 | |||||
| self.keep_dims0 = keep_dims0 | |||||
| self.x1 = Tensor(x1) | |||||
| self.axis1 = axis1 | |||||
| self.keep_dims1 = keep_dims1 | |||||
| self.x2 = Tensor(x2) | |||||
| self.axis2 = axis2 | |||||
| self.keep_dims2 = keep_dims2 | |||||
| self.x3 = Tensor(x3) | |||||
| self.axis3 = axis3 | |||||
| self.keep_dims3 = keep_dims3 | |||||
| @ms_function | |||||
| def construct(self): | |||||
| return (P.ReduceAll(self.keep_dims0)(self.x0, self.axis0), | |||||
| P.ReduceAll(self.keep_dims1)(self.x1, self.axis1), | |||||
| P.ReduceAll(self.keep_dims2)(self.x2, self.axis2), | |||||
| P.ReduceAll(self.keep_dims3)(self.x3, self.axis3)) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_ReduceAll(): | |||||
| reduce_all = ReduceAll() | |||||
| output = reduce_all() | |||||
| expect0 = np.all(x0, axis=axis0, keepdims=keep_dims0) | |||||
| np.allclose(output[0].asnumpy(), expect0) | |||||
| assert output[0].shape == expect0.shape | |||||
| expect1 = np.all(x1, axis=axis1, keepdims=keep_dims1) | |||||
| np.allclose(output[1].asnumpy(), expect1) | |||||
| assert output[1].shape == expect1.shape | |||||
| expect2 = np.all(x2, axis=axis2, keepdims=keep_dims2) | |||||
| np.allclose(output[2].asnumpy(), expect2) | |||||
| assert output[2].shape == expect2.shape | |||||
| expect3 = np.all(x3, axis=axis3, keepdims=keep_dims3) | |||||
| np.allclose(output[3].asnumpy(), expect3) | |||||
| assert output[3].shape == expect3.shape | |||||
| @@ -0,0 +1,94 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.common.api import ms_function | |||||
| from mindspore.ops import operations as P | |||||
| x0 = np.array([[True, True], [True, False], [False, False]]) | |||||
| axis0 = 0 | |||||
| keep_dims0 = True | |||||
| x1 = np.array([[True, True], [True, False], [False, False]]) | |||||
| axis1 = 0 | |||||
| keep_dims1 = False | |||||
| x2 = np.array([[True, True], [True, False], [False, False]]) | |||||
| axis2 = 1 | |||||
| keep_dims2 = True | |||||
| x3 = np.array([[True, True], [True, False], [False, False]]) | |||||
| axis3 = 1 | |||||
| keep_dims3 = False | |||||
| context.set_context(device_target='GPU') | |||||
| class ReduceAny(nn.Cell): | |||||
| def __init__(self): | |||||
| super(ReduceAny, self).__init__() | |||||
| self.x0 = Tensor(x0) | |||||
| self.axis0 = axis0 | |||||
| self.keep_dims0 = keep_dims0 | |||||
| self.x1 = Tensor(x1) | |||||
| self.axis1 = axis1 | |||||
| self.keep_dims1 = keep_dims1 | |||||
| self.x2 = Tensor(x2) | |||||
| self.axis2 = axis2 | |||||
| self.keep_dims2 = keep_dims2 | |||||
| self.x3 = Tensor(x3) | |||||
| self.axis3 = axis3 | |||||
| self.keep_dims3 = keep_dims3 | |||||
| @ms_function | |||||
| def construct(self): | |||||
| return (P.ReduceAny(self.keep_dims0)(self.x0, self.axis0), | |||||
| P.ReduceAny(self.keep_dims1)(self.x1, self.axis1), | |||||
| P.ReduceAny(self.keep_dims2)(self.x2, self.axis2), | |||||
| P.ReduceAny(self.keep_dims3)(self.x3, self.axis3)) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_ReduceAny(): | |||||
| reduce_any = ReduceAny() | |||||
| output = reduce_any() | |||||
| expect0 = np.all(x0, axis=axis0, keepdims=keep_dims0) | |||||
| np.allclose(output[0].asnumpy(), expect0) | |||||
| assert output[0].shape == expect0.shape | |||||
| expect1 = np.all(x1, axis=axis1, keepdims=keep_dims1) | |||||
| np.allclose(output[1].asnumpy(), expect1) | |||||
| assert output[1].shape == expect1.shape | |||||
| expect2 = np.all(x2, axis=axis2, keepdims=keep_dims2) | |||||
| np.allclose(output[2].asnumpy(), expect2) | |||||
| assert output[2].shape == expect2.shape | |||||
| expect3 = np.all(x3, axis=axis3, keepdims=keep_dims3) | |||||
| np.allclose(output[3].asnumpy(), expect3) | |||||
| assert output[3].shape == expect3.shape | |||||