Browse Source

!12400 Add float64 support to ZerosLike, remove duplicate ZerosLike gpu kernel

From: @peilin-wang
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
50bf5033f2
6 changed files with 28 additions and 56 deletions
  1. +3
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/zeroslike_gpu_kernel.cc
  2. +0
    -16
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu
  3. +1
    -3
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh
  4. +0
    -4
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc
  5. +14
    -27
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h
  6. +10
    -5
      tests/st/ops/gpu/test_zeroslike_op.py

+ 3
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/zeroslike_gpu_kernel.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -37,5 +37,7 @@ MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat16).A
MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ZerosLikeGpuKernel, float)

MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ZerosLikeGpuKernel, double)
} // namespace kernel
} // namespace mindspore

+ 0
- 16
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu View File

@@ -196,14 +196,6 @@ __global__ void AtanKernel(const T *input, T *output, const size_t count) {
return;
}
template <typename T>
__global__ void ZeroslikeKernel(T *output, const size_t count) {
T zero = 0.0;
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = zero;
}
return;
}
template <typename T>
__global__ void AbsKernel(const T *input, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = abs(input[i]);
@@ -328,11 +320,6 @@ void Rsqrt(const T *input, T *output, const size_t count, cudaStream_t cuda_stre
return;
}
template <typename T>
void Zeroslike(T *output, const size_t count, cudaStream_t cuda_stream) {
ZeroslikeKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(output, count);
return;
}
template <typename T>
void Abs(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
AbsKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
@@ -362,7 +349,6 @@ template void Atan<double>(const double *input, double *output, const size_t cou
template void Asinh<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Acosh<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Rsqrt<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Zeroslike<double>(double *output, const size_t count, cudaStream_t cuda_stream);
template void Abs<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Floor<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);

@@ -386,7 +372,6 @@ template void Atan<float>(const float *input, float *output, const size_t count,
template void Asinh<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Acosh<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Rsqrt<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Zeroslike<float>(float *output, const size_t count, cudaStream_t cuda_stream);
template void Abs<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Floor<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);

@@ -409,6 +394,5 @@ template void Atan<half>(const half *input, half *output, const size_t count, cu
template void Asinh<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Acosh<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Rsqrt<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Zeroslike<half>(half *output, const size_t count, cudaStream_t cuda_stream);
template void Abs<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Floor<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);

+ 1
- 3
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -55,8 +55,6 @@ void Asinh(const T *input, T *output, const size_t count, cudaStream_t cuda_stre
template <typename T>
void Acosh(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Zeroslike(T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Abs(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);


+ 0
- 4
mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc View File

@@ -52,10 +52,6 @@ MS_REG_GPU_KERNEL_ONE(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),


+ 14
- 27
mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -36,7 +36,6 @@ enum UnaryOptype {
UNARY_OP_ERFC,
UNARY_OP_NEG,
UNARY_OP_RECIPROCAL,
UNARY_OP_ZEROSLIKE,
UNARY_OP_SQUARE,
UNARY_OP_SQRT,
UNARY_OP_RSQRT,
@@ -51,27 +50,19 @@ enum UnaryOptype {
UNARY_OP_FLOOR,
UNARY_OP_INVALID_TYPE = 255
};
static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {{"Exp", UNARY_OP_EXP},
{"Expm1", UNARY_OP_EXPM1},
{"Log", UNARY_OP_LOG},
{"Log1p", UNARY_OP_LOG1P},
{"Erf", UNARY_OP_ERF},
{"Erfc", UNARY_OP_ERFC},
{"Neg", UNARY_OP_NEG},
{"Reciprocal", UNARY_OP_RECIPROCAL},
{"ZerosLike", UNARY_OP_ZEROSLIKE},
{"Square", UNARY_OP_SQUARE},
{"Sqrt", UNARY_OP_SQRT},
{"Rsqrt", UNARY_OP_RSQRT},
{"Sin", UNARY_OP_SIN},
{"Cos", UNARY_OP_COS},
{"Asin", UNARY_OP_ASIN},
{"ACos", UNARY_OP_ACOS},
{"Atan", UNARY_OP_ATAN},
{"Asinh", UNARY_OP_ASINH},
{"Acosh", UNARY_OP_ACOSH},
{"Abs", UNARY_OP_ABS},
{"Floor", UNARY_OP_FLOOR}};

static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {
{"Exp", UNARY_OP_EXP}, {"Expm1", UNARY_OP_EXPM1},
{"Log", UNARY_OP_LOG}, {"Log1p", UNARY_OP_LOG1P},
{"Erf", UNARY_OP_ERF}, {"Erfc", UNARY_OP_ERFC},
{"Neg", UNARY_OP_NEG}, {"Reciprocal", UNARY_OP_RECIPROCAL},
{"Square", UNARY_OP_SQUARE}, {"Sqrt", UNARY_OP_SQRT},
{"Rsqrt", UNARY_OP_RSQRT}, {"Sin", UNARY_OP_SIN},
{"Cos", UNARY_OP_COS}, {"Asin", UNARY_OP_ASIN},
{"ACos", UNARY_OP_ACOS}, {"Atan", UNARY_OP_ATAN},
{"Asinh", UNARY_OP_ASINH}, {"Acosh", UNARY_OP_ACOSH},
{"Abs", UNARY_OP_ABS}, {"Floor", UNARY_OP_FLOOR}};

template <typename T>
class UnaryOpGpuKernel : public GpuKernel {
public:
@@ -160,10 +151,6 @@ class UnaryOpGpuKernel : public GpuKernel {
Acosh(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case UNARY_OP_ZEROSLIKE: {
Zeroslike(output_addr, output_size_ / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
case UNARY_OP_ABS: {
Abs(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
break;


+ 10
- 5
tests/st/ops/gpu/test_zeroslike_op.py View File

@@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,9 +22,6 @@ from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")


class NetZerosLike(nn.Cell):
def __init__(self):
super(NetZerosLike, self).__init__()
@@ -109,7 +106,6 @@ def test_zeros_like_dynamic_int8():
x = Tensor(np.arange(24).reshape(1, 4, 1, 6).astype(np.int8))
output = zeros_like_dynamic(x)
expected = np.zeros([1, 4, 1, 6])
print(output)
np.testing.assert_array_equal(output.asnumpy(), expected)

@pytest.mark.level0
@@ -148,6 +144,15 @@ def test_zeros_like_dynamic_float32():
expected = np.zeros([3, 7, 3])
np.testing.assert_array_almost_equal(output.asnumpy(), expected)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_zeros_like_dynamic_float64():
x = Tensor(np.arange(2).reshape(2, 1, 1).astype(np.float64))
output = zeros_like_dynamic(x)
expected = np.zeros([2, 1, 1])
np.testing.assert_array_almost_equal(output.asnumpy(), expected)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard


Loading…
Cancel
Save