
!15647 Add Rint op for CPU and GPU

From: @xcnick
Reviewed-by: @liangchenghui, @tom__chen
Signed-off-by: @liangchenghui
pull/15647/MERGE
mindspore-ci-bot committed 4 years ago (commit 94ed3b89a3)
11 changed files with 141 additions and 7 deletions
  1. mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc (+13, -1)
  2. mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h (+2, -0)
  3. mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h (+1, -0)
  4. mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu (+22, -0)
  5. mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh (+2, -0)
  6. mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc (+6, -0)
  7. mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h (+7, -1)
  8. mindspore/core/base/core_ops.h (+1, -0)
  9. mindspore/ops/operations/array_ops.py (+1, -1)
  10. tests/st/ops/cpu/test_arithmetic_self_op.py (+26, -4)
  11. tests/st/ops/gpu/test_rint_op.py (+60, -0)
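For orientation, a minimal usage sketch of what this change enables (assuming a build with both CPU and GPU backends available; the RintNet cell below is illustrative and mirrors the new tests, and the expected values follow round-half-to-even, which is what rint implements and np.rint reproduces):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, ops


class RintNet(nn.Cell):
    def __init__(self):
        super(RintNet, self).__init__()
        self.rint = ops.Rint()

    def construct(self, x):
        return self.rint(x)


# Same values as the new GPU test; np.rint is the round-half-to-even reference.
x = np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0], dtype=np.float32)
for target in ("CPU", "GPU"):
    context.set_context(mode=context.GRAPH_MODE, device_target=target)
    out = RintNet()(Tensor(x))
    np.testing.assert_almost_equal(out.asnumpy(), np.rint(x))  # [-2. -2. -0. 0. 2. 2. 2.]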

mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc (+13, -1)

@@ -99,6 +99,16 @@ void Floor(const T *in, T *out, size_t size) {
  CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void Rint(const T *in, T *out, size_t size) {
  auto task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; i++) {
      out[i] = static_cast<T>(rint(in[i]));
    }
  };
  CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void Reciprocal(const T *in, T *out, size_t size) {
  auto task = [&](size_t start, size_t end) {
@@ -240,6 +250,7 @@ static const std::map<std::string, OperateType> kArithmeticOpTypeMap = {{prim::k
{prim::kPrimLogicalNot->name(), LOGICALNOT},
{prim::kPrimSign->name(), SIGN},
{prim::kPrimFloor->name(), FLOOR},
{prim::kPrimRint->name(), RINT},
{prim::kPrimReciprocal->name(), RECIPROCAL},
{prim::kPrimGeLU->name(), GELU},
{prim::kPrimAsin->name(), ASIN},
@@ -305,7 +316,8 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
{ASIN, Asin<T>}, {ACOS, ACos<T>},
{ATAN, Atan<T>}, {SINH, Sinh<T>},
{COSH, Cosh<T>}, {ASINH, Asinh<T>},
{ACOSH, Acosh<T>}, {ATANH, Atanh<T>}};
{ACOSH, Acosh<T>}, {ATANH, Atanh<T>},
{RINT, Rint<T>}};
if (kArithmeticOpFuncMap.find(operate_type_) != kArithmeticOpFuncMap.end()) {
kArithmeticOpFuncMap.at(operate_type_)(input, output, lens);
} else {


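The CPU path delegates to C's rint, which rounds to the nearest integer and sends halfway cases to the even neighbour (under the default rounding mode), rather than away from zero. A small NumPy sketch of the reference behaviour the new tests compare against (the away-from-zero line is only for contrast and is not part of this change):

import numpy as np

x = np.array([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5], dtype=np.float32)

# np.rint matches C rint: ties go to the nearest even integer.
print(np.rint(x))                          # [-2. -2. -0.  0.  2.  2.]

# Rounding ties away from zero gives a different result on the same input.
print(np.trunc(x + np.copysign(0.5, x)))   # [-3. -2. -1.  1.  2.  3.]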
mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h (+2, -0)

@@ -65,6 +65,8 @@ MS_REG_CPU_KERNEL(Sign, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAtt
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(GeLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),


mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h (+1, -0)

@@ -113,6 +113,7 @@ enum OperateType {
ASINHGRAD,
ACOSHGRAD,
ATAN2,
RINT,
};
class CPUKernel : public kernel::KernelMod {


mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu (+22, -0)

@@ -225,6 +225,20 @@ __global__ void FloorKernel(const half *input, half *output, const size_t count)
  return;
}
template <typename T>
__global__ void RintKernel(const T *input, T *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
    output[i] = rint(input[i]);
  }
  return;
}
template <>
__global__ void RintKernel(const half *input, half *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
    output[i] = hrint(input[i]);
  }
  return;
}
template <typename T>
void Exponential(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
  ExponentialKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
  return;
@@ -329,6 +343,11 @@ void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stre
  FloorKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
  return;
}
template <typename T>
void Rint(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
  RintKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
  return;
}

// double
template void Exponential<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
@@ -351,6 +370,7 @@ template void Acosh<double>(const double *input, double *output, const size_t co
template void Rsqrt<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Abs<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Floor<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Rint<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);


// float
@@ -374,6 +394,7 @@ template void Acosh<float>(const float *input, float *output, const size_t count
template void Rsqrt<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Abs<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Floor<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Rint<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);

// half
template void Exponential<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
@@ -396,3 +417,4 @@ template void Acosh<half>(const half *input, half *output, const size_t count, c
template void Rsqrt<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Abs<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Floor<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Rint<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);

mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh (+2, -0)

@@ -58,5 +58,7 @@ template <typename T>
void Abs(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Rint(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);

#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_

mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc (+6, -0)

@@ -108,5 +108,11 @@ MS_REG_GPU_KERNEL_ONE(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOu
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
UnaryOpGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
} // namespace kernel
} // namespace mindspore
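
Per these registrations, the GPU backend accepts float64, float32, and float16 inputs for Rint, while the CPU registration above covers float32. A hedged sketch of exercising the GPU float16 path (assuming a GPU build; float16 inputs go through the half specialization that calls hrint):

import numpy as np
import mindspore.context as context
from mindspore import Tensor, ops

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

x = np.array([-1.5, -0.3, 0.3, 1.5, 2.5], dtype=np.float16)
out = ops.Rint()(Tensor(x))
np.testing.assert_almost_equal(out.asnumpy(), np.rint(x).astype(np.float16))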

mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h (+7, -1)

@@ -48,6 +48,7 @@ enum UnaryOptype {
UNARY_OP_ACOSH,
UNARY_OP_ABS,
UNARY_OP_FLOOR,
UNARY_OP_RINT,
UNARY_OP_INVALID_TYPE = 255
};

@@ -61,7 +62,8 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {
{"Cos", UNARY_OP_COS}, {"Asin", UNARY_OP_ASIN},
{"ACos", UNARY_OP_ACOS}, {"Atan", UNARY_OP_ATAN},
{"Asinh", UNARY_OP_ASINH}, {"Acosh", UNARY_OP_ACOSH},
{"Abs", UNARY_OP_ABS}, {"Floor", UNARY_OP_FLOOR}};
{"Abs", UNARY_OP_ABS}, {"Floor", UNARY_OP_FLOOR},
{"Rint", UNARY_OP_RINT}};

template <typename T>
class UnaryOpGpuKernel : public GpuKernel {
@@ -159,6 +161,10 @@ class UnaryOpGpuKernel : public GpuKernel {
        Floor(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
        break;
      }
      case UNARY_OP_RINT: {
        Rint(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
        break;
      }
      default: {
        MS_LOG(EXCEPTION) << "Unary operation " << unary_op_type_ << " is not supported.";
      }


mindspore/core/base/core_ops.h (+1, -0)

@@ -395,6 +395,7 @@ inline const PrimitivePtr kPrimSqrtGrad = std::make_shared<Primitive>("SqrtGrad"
inline const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal");
inline const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims");
inline const PrimitivePtr kPrimAbs = std::make_shared<Primitive>("Abs");
inline const PrimitivePtr kPrimRint = std::make_shared<Primitive>("Rint");
inline const PrimitivePtr kPrimRound = std::make_shared<Primitive>("Round");
inline const PrimitivePtr kPrimExp = std::make_shared<Primitive>("Exp");
inline const PrimitivePtr kPrimLog = std::make_shared<Primitive>("Log");


mindspore/ops/operations/array_ops.py (+1, -1)

@@ -2705,7 +2705,7 @@ class Rint(PrimitiveWithInfer):
TypeError: If dtype of `input_x` is neither float16 nor float32.

Supported Platforms:
``Ascend``
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> input_x = Tensor(np.array([-1.6, -0.1, 1.5, 2.0]), mindspore.float32)

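For the example values shown in that docstring, round-half-to-even gives the following (a sketch; the printed result is inferred from rint's semantics rather than quoted from the source, and the explicit CPU context is only there to make the snippet self-contained):

import numpy as np
import mindspore
import mindspore.context as context
from mindspore import Tensor, ops

context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
input_x = Tensor(np.array([-1.6, -0.1, 1.5, 2.0]), mindspore.float32)
print(ops.Rint()(input_x))
# Inferred output: [-2. -0.  2.  2.]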

tests/st/ops/cpu/test_arithmetic_self_op.py (+26, -4)

@@ -50,6 +50,15 @@ class ReciprocalNet(nn.Cell):
        return self.reciprocal(x)


class RintNet(nn.Cell):
    def __init__(self):
        super(RintNet, self).__init__()
        self.rint = P.Rint()

    def construct(self, x):
        return self.rint(x)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@@ -118,6 +127,23 @@ def test_floor():
    assert np.all(output.asnumpy() == expect_output)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_rint():
    net = RintNet()
    prop = 100 if np.random.random() > 0.5 else -100
    x = np.random.randn(3, 4, 5, 6).astype(np.float16) * prop
    output = net(Tensor(x))
    expect_output = np.rint(x).astype(np.float16)
    np.testing.assert_almost_equal(output.asnumpy(), expect_output)

    x = np.random.randn(3, 4, 5, 6).astype(np.float32) * prop
    output = net(Tensor(x))
    expect_output = np.rint(x).astype(np.float32)
    np.testing.assert_almost_equal(output.asnumpy(), expect_output)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@@ -137,7 +163,3 @@ def test_reciprocal():
    diff = output.asnumpy() - expect_output
    error = np.ones(shape=expect_output.shape) * 1.0e-5
    assert np.all(np.abs(diff) < error)

test_square()
test_floor()
test_reciprocal()

tests/st/ops/gpu/test_rint_op.py (+60, -0)

@@ -0,0 +1,60 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, ops


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.rint = ops.Rint()

    def construct(self, x):
        return self.rint(x)


def generate_testcases(nptype):
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    x = np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]).astype(nptype)
    net = Net()
    output = net(Tensor(x))
    expect = np.rint(x).astype(nptype)
    np.testing.assert_almost_equal(output.asnumpy(), expect)

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    x = np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]).astype(nptype)
    net = Net()
    output = net(Tensor(x))
    expect = np.rint(x).astype(nptype)
    np.testing.assert_almost_equal(output.asnumpy(), expect)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_rint_float32():
    generate_testcases(np.float32)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_rint_float16():
    generate_testcases(np.float16)
