From a90713b19485851c47d53571a6b2402ae34d1c2e Mon Sep 17 00:00:00 2001
From: xcnick
Date: Sun, 25 Apr 2021 17:28:25 +0800
Subject: [PATCH] Add Rint op kernel for CPU and GPU

---
 .../cpu/arithmetic_self_cpu_kernel.cc         | 14 ++++-
 .../cpu/arithmetic_self_cpu_kernel.h          |  2 +
 .../backend/kernel_compiler/cpu/cpu_kernel.h  |  1 +
 .../gpu/cuda_impl/unary_op_impl.cu            | 22 +++++++
 .../gpu/cuda_impl/unary_op_impl.cuh           |  2 +
 .../gpu/math/unary_op_gpu_kernel.cc           |  6 ++
 .../gpu/math/unary_op_gpu_kernel.h            |  8 ++-
 mindspore/core/base/core_ops.h                |  1 +
 mindspore/ops/operations/array_ops.py         |  2 +-
 tests/st/ops/cpu/test_arithmetic_self_op.py   | 30 ++++++++--
 tests/st/ops/gpu/test_rint_op.py              | 60 +++++++++++++++++++
 11 files changed, 141 insertions(+), 7 deletions(-)
 create mode 100644 tests/st/ops/gpu/test_rint_op.py

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc
index a654effa4c..8af583bc80 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc
@@ -99,6 +99,16 @@ void Floor(const T *in, T *out, size_t size) {
   CPUKernelUtils::ParallelFor(task, size);
 }
 
+template <typename T>
+void Rint(const T *in, T *out, size_t size) {
+  auto task = [&](size_t start, size_t end) {
+    for (size_t i = start; i < end; i++) {
+      out[i] = static_cast<T>(rint(in[i]));
+    }
+  };
+  CPUKernelUtils::ParallelFor(task, size);
+}
+
 template <typename T>
 void Reciprocal(const T *in, T *out, size_t size) {
   auto task = [&](size_t start, size_t end) {
@@ -240,6 +250,7 @@ static const std::map<std::string, OperateType> kArithmeticOpTypeMap = {{prim::k
     {prim::kPrimLogicalNot->name(), LOGICALNOT},
     {prim::kPrimSign->name(), SIGN},
     {prim::kPrimFloor->name(), FLOOR},
+    {prim::kPrimRint->name(), RINT},
     {prim::kPrimReciprocal->name(), RECIPROCAL},
     {prim::kPrimGeLU->name(), GELU},
     {prim::kPrimAsin->name(), ASIN},
@@ -305,7 +316,8 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
     {ASIN, Asin<T>},     {ACOS, ACos<T>},
     {ATAN, Atan<T>},     {SINH, Sinh<T>},
     {COSH, Cosh<T>},     {ASINH, Asinh<T>},
-    {ACOSH, Acosh<T>},   {ATANH, Atanh<T>}};
+    {ACOSH, Acosh<T>},   {ATANH, Atanh<T>},
+    {RINT, Rint<T>}};
   if (kArithmeticOpFuncMap.find(operate_type_) != kArithmeticOpFuncMap.end()) {
     kArithmeticOpFuncMap.at(operate_type_)(input, output, lens);
   } else {
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h
index 3820610934..2e47dda977 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h
@@ -65,6 +65,8 @@ MS_REG_CPU_KERNEL(Sign, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAtt
                   ArithmeticSelfCPUKernel);
 MS_REG_CPU_KERNEL(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ArithmeticSelfCPUKernel);
+MS_REG_CPU_KERNEL(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                  ArithmeticSelfCPUKernel);
 MS_REG_CPU_KERNEL(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ArithmeticSelfCPUKernel);
 MS_REG_CPU_KERNEL(GeLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ArithmeticSelfCPUKernel);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
index 9ddf86a4ef..da8b528d1e 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
@@ -113,6 +113,7 @@ enum OperateType {
   ASINHGRAD,
   ACOSHGRAD,
   ATAN2,
+  RINT,
 };
 
 class CPUKernel : public kernel::KernelMod {
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu
index 06759a050e..d3647c6f24 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu
@@ -225,6 +225,20 @@ __global__ void FloorKernel(const half *input, half *output, const size_t count)
   return;
 }
 template <typename T>
+__global__ void RintKernel(const T *input, T *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = rint(input[i]);
+  }
+  return;
+}
+template <>
+__global__ void RintKernel(const half *input, half *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = hrint(input[i]);
+  }
+  return;
+}
+template <typename T>
 void Exponential(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
   ExponentialKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
   return;
 }
@@ -329,6 +343,11 @@ void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stre
   FloorKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
   return;
 }
+template <typename T>
+void Rint(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
+  RintKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
+  return;
+}
 
 // double
 template void Exponential<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
@@ -351,6 +370,7 @@ template void Acosh<double>(const double *input, double *output, const size_t co
 template void Rsqrt<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
 template void Abs<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
 template void Floor<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
+template void Rint<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
 
 // float
 template void Exponential<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
@@ -374,6 +394,7 @@ template void Acosh<float>(const float *input, float *output, const size_t count
 template void Rsqrt<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Abs<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Floor<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
+template void Rint<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 
 // half
 template void Exponential<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
@@ -396,3 +417,4 @@ template void Acosh<half>(const half *input, half *output, const size_t count, c
 template void Rsqrt<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Abs<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Floor<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
+template void Rint<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh
index fffe32ca52..4f0d92d81d 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh
@@ -58,5 +58,7 @@ template <typename T>
 void Abs(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
 template <typename T>
 void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
+template <typename T>
+void Rint(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
 
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc
index 3079fd1ca6..73a4ea9372 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc
@@ -108,5 +108,11 @@ MS_REG_GPU_KERNEL_ONE(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOu
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       UnaryOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+                      UnaryOpGpuKernel, double)
+MS_REG_GPU_KERNEL_ONE(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      UnaryOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      UnaryOpGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h
index 8c859def14..3d80f116d4 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h
@@ -48,6 +48,7 @@ enum UnaryOptype {
   UNARY_OP_ACOSH,
   UNARY_OP_ABS,
   UNARY_OP_FLOOR,
+  UNARY_OP_RINT,
   UNARY_OP_INVALID_TYPE = 255
 };
 
@@ -61,7 +62,8 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {
   {"Cos", UNARY_OP_COS},       {"Asin", UNARY_OP_ASIN},
   {"ACos", UNARY_OP_ACOS},     {"Atan", UNARY_OP_ATAN},
   {"Asinh", UNARY_OP_ASINH},   {"Acosh", UNARY_OP_ACOSH},
-  {"Abs", UNARY_OP_ABS},       {"Floor", UNARY_OP_FLOOR}};
+  {"Abs", UNARY_OP_ABS},       {"Floor", UNARY_OP_FLOOR},
+  {"Rint", UNARY_OP_RINT}};
 
 template <typename T>
 class UnaryOpGpuKernel : public GpuKernel {
@@ -159,6 +161,10 @@ class UnaryOpGpuKernel : public GpuKernel {
         Floor(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
         break;
       }
+      case UNARY_OP_RINT: {
+        Rint(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+      }
       default: {
         MS_LOG(EXCEPTION) << "Unary operation " << unary_op_type_ << " is not supported.";
       }
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index e88d2c7721..ff91059f64 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -394,6 +394,7 @@ inline const PrimitivePtr kPrimSqrtGrad = std::make_shared<Primitive>("SqrtGrad"
 inline const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal");
 inline const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims");
 inline const PrimitivePtr kPrimAbs = std::make_shared<Primitive>("Abs");
+inline const PrimitivePtr kPrimRint = std::make_shared<Primitive>("Rint");
 inline const PrimitivePtr kPrimRound = std::make_shared<Primitive>("Round");
 inline const PrimitivePtr kPrimExp = std::make_shared<Primitive>("Exp");
std::make_shared("Exp"); inline const PrimitivePtr kPrimLog = std::make_shared("Log"); diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index ae251d332d..cdd7699aff 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -2705,7 +2705,7 @@ class Rint(PrimitiveWithInfer): TypeError: If dtype of `input_x` is neither float16 nor float32. Supported Platforms: - ``Ascend`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> input_x = Tensor(np.array([-1.6, -0.1, 1.5, 2.0]), mindspore.float32) diff --git a/tests/st/ops/cpu/test_arithmetic_self_op.py b/tests/st/ops/cpu/test_arithmetic_self_op.py index 02e520839e..a7a00f5e9a 100644 --- a/tests/st/ops/cpu/test_arithmetic_self_op.py +++ b/tests/st/ops/cpu/test_arithmetic_self_op.py @@ -50,6 +50,15 @@ class ReciprocalNet(nn.Cell): return self.reciprocal(x) +class RintNet(nn.Cell): + def __init__(self): + super(RintNet, self).__init__() + self.rint = P.Rint() + + def construct(self, x): + return self.rint(x) + + @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @@ -118,6 +127,23 @@ def test_floor(): assert np.all(output.asnumpy() == expect_output) +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_rint(): + net = RintNet() + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(3, 4, 5, 6).astype(np.float16) * prop + output = net(Tensor(x)) + expect_output = np.rint(x).astype(np.float16) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + x = np.random.randn(3, 4, 5, 6).astype(np.float32) * prop + output = net(Tensor(x)) + expect_output = np.rint(x).astype(np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @@ -137,7 +163,3 @@ def test_reciprocal(): diff = output.asnumpy() - expect_output error = np.ones(shape=expect_output.shape) * 1.0e-5 assert np.all(np.abs(diff) < error) - -test_square() -test_floor() -test_reciprocal() diff --git a/tests/st/ops/gpu/test_rint_op.py b/tests/st/ops/gpu/test_rint_op.py new file mode 100644 index 0000000000..e1e05c6e06 --- /dev/null +++ b/tests/st/ops/gpu/test_rint_op.py @@ -0,0 +1,60 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor, ops
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.rint = ops.Rint()
+
+    def construct(self, x):
+        return self.rint(x)
+
+
+def generate_testcases(nptype):
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    x = np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]).astype(nptype)
+    net = Net()
+    output = net(Tensor(x))
+    expect = np.rint(x).astype(nptype)
+    np.testing.assert_almost_equal(output.asnumpy(), expect)
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    x = np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]).astype(nptype)
+    net = Net()
+    output = net(Tensor(x))
+    expect = np.rint(x).astype(nptype)
+    np.testing.assert_almost_equal(output.asnumpy(), expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_rint_float32():
+    generate_testcases(np.float32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_rint_float16():
+    generate_testcases(np.float16)