From: @jonwe
Reviewed-by: @tom__chen, @robingrosman
Signed-off-by: @tom__chen
tags/v1.1.0
@@ -27,7 +27,7 @@ namespace kernel {
 template <typename S, typename T>
 class CastGpuKernel : public GpuKernel {
  public:
-  CastGpuKernel() : input_size_(1), output_size_(1) {}
+  CastGpuKernel() { ResetResource(); }
   ~CastGpuKernel() = default;

   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -42,6 +42,7 @@ class CastGpuKernel : public GpuKernel {
     Cast(input_size_, input_addr, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
+
   bool Init(const CNodePtr &kernel_node) override {
     auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
     auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);
@@ -62,6 +63,14 @@ class CastGpuKernel : public GpuKernel {
     return true;
   }

+  void ResetResource() noexcept override {
+    input_size_ = 1;
+    output_size_ = 1;
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+  }
+
 protected:
   void InitSizeLists() override {
     input_size_list_.push_back(input_size_ * sizeof(T));
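Note on this hunk: under dynamic shapes the runtime re-initializes the same kernel instance for every newly inferred shape, so the size lists would keep accumulating entries unless cleared first; routing the constructor through ResetResource() keeps a single reset path. A minimal standalone C++ sketch of the pattern (hypothetical names, not the MindSpore API):

    #include <cstddef>
    #include <vector>

    // Illustration of the reset-before-reinit protocol: every Init() starts
    // from a clean slate so size lists never grow across dynamic-shape steps.
    class ResizableKernel {
     public:
      ResizableKernel() { ResetResource(); }

      void ResetResource() noexcept {
        input_size_ = 1;
        input_size_list_.clear();
        output_size_list_.clear();
      }

      void Init(const std::vector<size_t> &shape) {
        ResetResource();  // idempotent: safe even if the caller already reset
        for (size_t d : shape) input_size_ *= d;
        input_size_list_.push_back(input_size_ * sizeof(float));
        output_size_list_.push_back(input_size_ * sizeof(float));
      }

     private:
      size_t input_size_;
      std::vector<size_t> input_size_list_;
      std::vector<size_t> output_size_list_;
    };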
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T>
+__global__ void CalReLUKernel(int size, T *input_addr, T *output_addr) {
+  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
+    output_addr[pos] = input_addr[pos] > static_cast<T>(0) ? input_addr[pos] : static_cast<T>(0);
+  }
+}
+
+template <typename T>
+void CalReLU(int size, T *input_addr, T *output_addr, cudaStream_t cuda_stream) {
+  CalReLUKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_addr, output_addr);
+  return;
+}
+
+template void CalReLU(int size, float *input_addr, float *output_addr, cudaStream_t cuda_stream);
+template void CalReLU(int size, half *input_addr, half *output_addr, cudaStream_t cuda_stream);
+template void CalReLU(int size, int32_t *input_addr, int32_t *output_addr, cudaStream_t cuda_stream);
+template void CalReLU(int size, int64_t *input_addr, int64_t *output_addr, cudaStream_t cuda_stream);
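The kernel above uses a grid-stride loop, so any grid size covers any element count. A self-contained host-side check of the same logic (standalone nvcc program; the in-tree launch uses GET_BLOCKS/GET_THREADS from cuda_common.h, while the 1x256 configuration below is ad hoc):

    #include <cstdio>
    #include <cuda_runtime.h>

    // Same grid-stride ReLU as CalReLUKernel, duplicated here so this sketch
    // compiles without the MindSpore headers.
    __global__ void ReluKernel(int size, const float *in, float *out) {
      for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size;
           pos += blockDim.x * gridDim.x) {
        out[pos] = in[pos] > 0.0f ? in[pos] : 0.0f;
      }
    }

    int main() {
      const int n = 8;
      float host_in[n] = {-2.f, -1.f, 0.f, 1.f, 2.f, -3.f, 3.f, -0.5f};
      float host_out[n] = {0};
      float *dev_in = nullptr, *dev_out = nullptr;
      cudaMalloc(&dev_in, n * sizeof(float));
      cudaMalloc(&dev_out, n * sizeof(float));
      cudaMemcpy(dev_in, host_in, n * sizeof(float), cudaMemcpyHostToDevice);
      ReluKernel<<<1, 256>>>(n, dev_in, dev_out);  // grid-stride: 1 block suffices
      cudaMemcpy(host_out, dev_out, n * sizeof(float), cudaMemcpyDeviceToHost);
      for (int i = 0; i < n; ++i) printf("%g ", host_out[i]);  // 0 0 0 1 2 0 3 0
      printf("\n");
      cudaFree(dev_in);
      cudaFree(dev_out);
      return 0;
    }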
@@ -0,0 +1,23 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+template <typename T>
+void CalReLU(int input_size, T *input_addr, T *output_addr, cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_H_
@@ -45,6 +45,7 @@ static constexpr float kSignedMinFloat = -3.402823466e+38F;
 // Used by mixprecision, cudnn dtype select
 static std::map<std::string, cudnnDataType_t> kCudnnDtypeMap = {{"kNumberTypeFloat32", CUDNN_DATA_FLOAT},
                                                                 {"kNumberTypeFloat16", CUDNN_DATA_HALF},
+                                                                {"kNumberTypeInt64", CUDNN_DATA_DOUBLE},
                                                                 {"kNumberTypeInt32", CUDNN_DATA_INT32}};
 // Used by mixprecision, cuda dtype select
 static std::map<std::string, cudaDataType_t> kCudaDtypeMap = {{"kNumberTypeFloat32", CUDA_R_32F},
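Note: cuDNN has no 64-bit integer data type, so kNumberTypeInt64 is mapped to CUDNN_DATA_DOUBLE. This appears intended only to let cuDNN descriptor creation succeed for int64 tensors; the new ReLU launch path below bypasses cuDNN entirely. Worth keeping in mind if data ever flowed through this mapping: a double has a 53-bit significand, so int64 values above 2^53 do not round-trip exactly, as this small standalone check shows:

    #include <cstdint>
    #include <cstdio>

    // int64 -> double -> int64 is exact only up to 2^53.
    int main() {
      int64_t exact = int64_t{1} << 53;
      int64_t inexact = exact + 1;  // nearest double is 2^53, so the +1 is lost
      printf("%lld -> %lld\n", (long long)exact, (long long)(int64_t)(double)exact);
      printf("%lld -> %lld\n", (long long)inexact, (long long)(int64_t)(double)inexact);
      return 0;
    }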
@@ -22,6 +22,10 @@ MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
                       ActivationGpuFwdKernel, float)
 MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       ActivationGpuFwdKernel, half)
+MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      ActivationGpuFwdKernel, int32_t)
+MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+                      ActivationGpuFwdKernel, int64_t)
 MS_REG_GPU_KERNEL_ONE(ReLU6, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       ActivationGpuFwdKernel, float)
@@ -23,6 +23,7 @@
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
 #include "backend/kernel_compiler/gpu/kernel_constants.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh"

 namespace mindspore {
 namespace kernel {
@@ -36,18 +37,23 @@ class ActivationGpuFwdKernel : public GpuKernel {
   const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
-              const std::vector<AddressPtr> &outputs, void *) override {
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
     if (is_null_input_) {
       return true;
     }
     T *input = GetDeviceAddress<T>(inputs, 0);
     T *output = GetDeviceAddress<T>(outputs, 0);

-    const float alpha = 1;
-    const float beta = 0;
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, input,
-                                                       &beta, data_descriptor_, output),
-                                "cudnnActivationForward failed");
+    if (mode_ == CUDNN_ACTIVATION_RELU) {
+      const int size = input_size_ / sizeof(T);
+      CalReLU(size, input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
+    } else {
+      const float alpha = 1;
+      const float beta = 0;
+      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_,
+                                                         input, &beta, data_descriptor_, output),
+                                  "cudnnActivationForward failed");
+    }
     return true;
   }
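The dispatch above routes ReLU to the handwritten CUDA kernel, which is what makes the new int32/int64 registrations work: cuDNN's activation forward does not support 32/64-bit integer tensors, so only non-ReLU activations keep the cuDNN path. Since input_size_ is tracked in bytes, the element count handed to CalReLU is recovered as input_size_ / sizeof(T); a trivial standalone check of that arithmetic (hypothetical helper, not MindSpore code):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // Byte size -> element count, as computed in Launch().
    template <typename T>
    int ElementCount(size_t bytes) { return static_cast<int>(bytes / sizeof(T)); }

    int main() {
      // The 4x3x2 int64 tensor from the tests is 192 bytes, i.e. 24 elements.
      size_t bytes = 4 * 3 * 2 * sizeof(int64_t);
      printf("%d\n", ElementCount<int64_t>(bytes));  // prints 24
      return 0;
    }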
@@ -291,7 +291,7 @@ class Softsign(PrimitiveWithInfer):
         return input_x


-class ReLU(PrimitiveWithInfer):
+class ReLU(PrimitiveWithCheck):
     r"""
     Computes ReLU (Rectified Linear Unit) of input tensors element-wise.
@@ -320,12 +320,11 @@ class ReLU(PrimitiveWithInfer):
         """Initialize ReLU"""
         self.init_prim_io_names(inputs=['x'], outputs=['output'])

-    def infer_shape(self, input_x):
-        return input_x
+    def check_shape(self, input_x):
+        pass

-    def infer_dtype(self, input_x):
+    def check_dtype(self, input_x):
         validator.check_tensor_dtype_valid('input_x', input_x, mstype.number_type, self.name)
-        return input_x


 class ReLU6(PrimitiveWithInfer):
@@ -21,6 +21,7 @@ import mindspore.context as context
 from mindspore.common.tensor import Tensor
 from mindspore.nn import Cell
 from mindspore.ops import operations as P
+from mindspore.ops.operations import _inner_ops as inner


 class Net(Cell):
@@ -36,6 +37,22 @@ class Net(Cell):
         return output


+class NetDynamic(Cell):
+    def __init__(self, type0, type1):
+        super(NetDynamic, self).__init__()
+        self.conv = inner.GpuConvertToDynamicShape()
+        self.Cast = P.Cast()
+        self.type0 = type0
+        self.type1 = type1
+
+    def construct(self, x0, x1):
+        x0_conv = self.conv(x0)
+        x1_conv = self.conv(x1)
+        output = (self.Cast(x0_conv, self.type0),
+                  self.Cast(x1_conv, self.type1))
+        return output
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -563,3 +580,20 @@ def test_cast30():
     assert type0 == 'uint16'
     type1 = output[1].asnumpy().dtype
     assert type1 == 'uint32'
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_cast31():
+    x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.float32))
+    t0 = mstype.uint16
+    x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.float32))
+    t1 = mstype.uint32
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    net = NetDynamic(t0, t1)
+    output = net(x0, x1)
+    type0 = output[0].asnumpy().dtype
+    assert type0 == 'uint16'
+    type1 = output[1].asnumpy().dtype
+    assert type1 == 'uint32'
@@ -20,6 +20,7 @@ import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.ops import operations as P
+from mindspore.ops.operations import _inner_ops as inner


 class NetRelu(nn.Cell):
@@ -31,10 +32,21 @@ class NetRelu(nn.Cell):
         return self.relu(x)


+class NetReluDynamic(nn.Cell):
+    def __init__(self):
+        super(NetReluDynamic, self).__init__()
+        self.conv = inner.GpuConvertToDynamicShape()
+        self.relu = P.ReLU()
+
+    def construct(self, x):
+        x_conv = self.conv(x)
+        return self.relu(x_conv)
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-def test_relu():
+def test_relu_float32():
     x = Tensor(np.array([[[[-1, 1, 10],
                           [1, -1, 1],
                           [10, 1, -1]]]]).astype(np.float32))
@@ -51,3 +63,65 @@ def test_relu():
     relu = NetRelu()
     output = relu(x)
     assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_relu_int32():
+    x = Tensor(np.array([[[[-1, 1, 10],
+                           [1, -1, 1],
+                           [10, 1, -1]]]]).astype(np.int32))
+    expect = np.array([[[[0, 1, 10],
+                         [1, 0, 1],
+                         [10, 1, 0]]]]).astype(np.int32)
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    relu = NetRelu()
+    output = relu(x)
+    assert (output.asnumpy() == expect).all()
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    relu = NetRelu()
+    output = relu(x)
+    assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_relu_int64():
+    x = Tensor(np.array([[[[-1, 1, 10],
+                           [1, -1, 1],
+                           [10, 1, -1]]]]).astype(np.int64))
+    expect = np.array([[[[0, 1, 10],
+                         [1, 0, 1],
+                         [10, 1, 0]]]]).astype(np.int64)
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    relu = NetRelu()
+    output = relu(x)
+    assert (output.asnumpy() == expect).all()
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    relu = NetRelu()
+    output = relu(x)
+    assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_relu_int64_dynamic_shape():
+    x = Tensor(np.array([[[[-1, 1, 10],
+                           [1, -1, 1],
+                           [10, 1, -1]]]]).astype(np.int64))
+    expect = np.array([[[[0, 1, 10],
+                         [1, 0, 1],
+                         [10, 1, 0]]]]).astype(np.int64)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    relu_dynamic = NetReluDynamic()
+    output = relu_dynamic(x)
+    assert (output.asnumpy() == expect).all()