From: @xcnick Reviewed-by: @liangchenghui,@tom__chen Signed-off-by: @liangchenghuipull/15647/MERGE
| @@ -99,6 +99,16 @@ void Floor(const T *in, T *out, size_t size) { | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void Rint(const T *in, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = static_cast<T>(rint(in[i])); | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void Reciprocal(const T *in, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| @@ -240,6 +250,7 @@ static const std::map<std::string, OperateType> kArithmeticOpTypeMap = {{prim::k | |||
| {prim::kPrimLogicalNot->name(), LOGICALNOT}, | |||
| {prim::kPrimSign->name(), SIGN}, | |||
| {prim::kPrimFloor->name(), FLOOR}, | |||
| {prim::kPrimRint->name(), RINT}, | |||
| {prim::kPrimReciprocal->name(), RECIPROCAL}, | |||
| {prim::kPrimGeLU->name(), GELU}, | |||
| {prim::kPrimAsin->name(), ASIN}, | |||
| @@ -305,7 +316,8 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs | |||
| {ASIN, Asin<T>}, {ACOS, ACos<T>}, | |||
| {ATAN, Atan<T>}, {SINH, Sinh<T>}, | |||
| {COSH, Cosh<T>}, {ASINH, Asinh<T>}, | |||
| {ACOSH, Acosh<T>}, {ATANH, Atanh<T>}}; | |||
| {ACOSH, Acosh<T>}, {ATANH, Atanh<T>}, | |||
| {RINT, Rint<T>}}; | |||
| if (kArithmeticOpFuncMap.find(operate_type_) != kArithmeticOpFuncMap.end()) { | |||
| kArithmeticOpFuncMap.at(operate_type_)(input, output, lens); | |||
| } else { | |||
| @@ -65,6 +65,8 @@ MS_REG_CPU_KERNEL(Sign, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAtt | |||
| ArithmeticSelfCPUKernel); | |||
| MS_REG_CPU_KERNEL(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| ArithmeticSelfCPUKernel); | |||
| MS_REG_CPU_KERNEL(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| ArithmeticSelfCPUKernel); | |||
| MS_REG_CPU_KERNEL(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| ArithmeticSelfCPUKernel); | |||
| MS_REG_CPU_KERNEL(GeLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| @@ -113,6 +113,7 @@ enum OperateType { | |||
| ASINHGRAD, | |||
| ACOSHGRAD, | |||
| ATAN2, | |||
| RINT, | |||
| }; | |||
| class CPUKernel : public kernel::KernelMod { | |||
| @@ -225,6 +225,20 @@ __global__ void FloorKernel(const half *input, half *output, const size_t count) | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void RintKernel(const T *input, T *output, const size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| output[i] = rint(input[i]); | |||
| } | |||
| return; | |||
| } | |||
| template <> | |||
| __global__ void RintKernel(const half *input, half *output, const size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| output[i] = hrint(input[i]); | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void Exponential(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) { | |||
| ExponentialKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count); | |||
| return; | |||
| @@ -329,6 +343,11 @@ void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stre | |||
| FloorKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count); | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void Rint(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) { | |||
| RintKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count); | |||
| return; | |||
| } | |||
| // double | |||
| template void Exponential<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| @@ -351,6 +370,7 @@ template void Acosh<double>(const double *input, double *output, const size_t co | |||
| template void Rsqrt<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Abs<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Floor<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Rint<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| // float | |||
| @@ -374,6 +394,7 @@ template void Acosh<float>(const float *input, float *output, const size_t count | |||
| template void Rsqrt<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Abs<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Floor<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Rint<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); | |||
| // half | |||
| template void Exponential<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); | |||
| @@ -396,3 +417,4 @@ template void Acosh<half>(const half *input, half *output, const size_t count, c | |||
| template void Rsqrt<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Abs<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Floor<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Rint<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); | |||
| @@ -58,5 +58,7 @@ template <typename T> | |||
| void Abs(const T *input, T *output, const size_t count, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void Rint(const T *input, T *output, const size_t count, cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_ | |||
| @@ -108,5 +108,11 @@ MS_REG_GPU_KERNEL_ONE(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOu | |||
| UnaryOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| UnaryOpGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||
| UnaryOpGpuKernel, double) | |||
| MS_REG_GPU_KERNEL_ONE(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| UnaryOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| UnaryOpGpuKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -48,6 +48,7 @@ enum UnaryOptype { | |||
| UNARY_OP_ACOSH, | |||
| UNARY_OP_ABS, | |||
| UNARY_OP_FLOOR, | |||
| UNARY_OP_RINT, | |||
| UNARY_OP_INVALID_TYPE = 255 | |||
| }; | |||
| @@ -61,7 +62,8 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = { | |||
| {"Cos", UNARY_OP_COS}, {"Asin", UNARY_OP_ASIN}, | |||
| {"ACos", UNARY_OP_ACOS}, {"Atan", UNARY_OP_ATAN}, | |||
| {"Asinh", UNARY_OP_ASINH}, {"Acosh", UNARY_OP_ACOSH}, | |||
| {"Abs", UNARY_OP_ABS}, {"Floor", UNARY_OP_FLOOR}}; | |||
| {"Abs", UNARY_OP_ABS}, {"Floor", UNARY_OP_FLOOR}, | |||
| {"Rint", UNARY_OP_RINT}}; | |||
| template <typename T> | |||
| class UnaryOpGpuKernel : public GpuKernel { | |||
| @@ -159,6 +161,10 @@ class UnaryOpGpuKernel : public GpuKernel { | |||
| Floor(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| break; | |||
| } | |||
| case UNARY_OP_RINT: { | |||
| Rint(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| break; | |||
| } | |||
| default: { | |||
| MS_LOG(EXCEPTION) << "Unary operation " << unary_op_type_ << " is not supported."; | |||
| } | |||
| @@ -395,6 +395,7 @@ inline const PrimitivePtr kPrimSqrtGrad = std::make_shared<Primitive>("SqrtGrad" | |||
| inline const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal"); | |||
| inline const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims"); | |||
| inline const PrimitivePtr kPrimAbs = std::make_shared<Primitive>("Abs"); | |||
| inline const PrimitivePtr kPrimRint = std::make_shared<Primitive>("Rint"); | |||
| inline const PrimitivePtr kPrimRound = std::make_shared<Primitive>("Round"); | |||
| inline const PrimitivePtr kPrimExp = std::make_shared<Primitive>("Exp"); | |||
| inline const PrimitivePtr kPrimLog = std::make_shared<Primitive>("Log"); | |||
| @@ -2705,7 +2705,7 @@ class Rint(PrimitiveWithInfer): | |||
| TypeError: If dtype of `input_x` is neither float16 nor float32. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> input_x = Tensor(np.array([-1.6, -0.1, 1.5, 2.0]), mindspore.float32) | |||
| @@ -50,6 +50,15 @@ class ReciprocalNet(nn.Cell): | |||
| return self.reciprocal(x) | |||
| class RintNet(nn.Cell): | |||
| def __init__(self): | |||
| super(RintNet, self).__init__() | |||
| self.rint = P.Rint() | |||
| def construct(self, x): | |||
| return self.rint(x) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| @@ -118,6 +127,23 @@ def test_floor(): | |||
| assert np.all(output.asnumpy() == expect_output) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_rint(): | |||
| net = RintNet() | |||
| prop = 100 if np.random.random() > 0.5 else -100 | |||
| x = np.random.randn(3, 4, 5, 6).astype(np.float16) * prop | |||
| output = net(Tensor(x)) | |||
| expect_output = np.rint(x).astype(np.float16) | |||
| np.testing.assert_almost_equal(output.asnumpy(), expect_output) | |||
| x = np.random.randn(3, 4, 5, 6).astype(np.float32) * prop | |||
| output = net(Tensor(x)) | |||
| expect_output = np.rint(x).astype(np.float32) | |||
| np.testing.assert_almost_equal(output.asnumpy(), expect_output) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| @@ -137,7 +163,3 @@ def test_reciprocal(): | |||
| diff = output.asnumpy() - expect_output | |||
| error = np.ones(shape=expect_output.shape) * 1.0e-5 | |||
| assert np.all(np.abs(diff) < error) | |||
| test_square() | |||
| test_floor() | |||
| test_reciprocal() | |||
| @@ -0,0 +1,60 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor, ops | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.rint = ops.Rint() | |||
| def construct(self, x): | |||
| return self.rint(x) | |||
| def generate_testcases(nptype): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| x = np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]).astype(nptype) | |||
| net = Net() | |||
| output = net(Tensor(x)) | |||
| expect = np.rint(x).astype(nptype) | |||
| np.testing.assert_almost_equal(output.asnumpy(), expect) | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| x = np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]).astype(nptype) | |||
| net = Net() | |||
| output = net(Tensor(x)) | |||
| expect = np.rint(x).astype(nptype) | |||
| np.testing.assert_almost_equal(output.asnumpy(), expect) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sign_float32(): | |||
| generate_testcases(np.float32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sign_float16(): | |||
| generate_testcases(np.float16) | |||