From: @peilin-wang Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -37,5 +37,7 @@ MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat16).A | |||
| MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| ZerosLikeGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||
| ZerosLikeGpuKernel, double) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -196,14 +196,6 @@ __global__ void AtanKernel(const T *input, T *output, const size_t count) { | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void ZeroslikeKernel(T *output, const size_t count) { | |||
| T zero = 0.0; | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| output[i] = zero; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void AbsKernel(const T *input, T *output, const size_t count) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||
| output[i] = abs(input[i]); | |||
| @@ -328,11 +320,6 @@ void Rsqrt(const T *input, T *output, const size_t count, cudaStream_t cuda_stre | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void Zeroslike(T *output, const size_t count, cudaStream_t cuda_stream) { | |||
| ZeroslikeKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(output, count); | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void Abs(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) { | |||
| AbsKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count); | |||
| return; | |||
| @@ -362,7 +349,6 @@ template void Atan<double>(const double *input, double *output, const size_t cou | |||
| template void Asinh<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Acosh<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Rsqrt<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Zeroslike<double>(double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Abs<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Floor<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| @@ -386,7 +372,6 @@ template void Atan<float>(const float *input, float *output, const size_t count, | |||
| template void Asinh<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Acosh<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Rsqrt<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Zeroslike<float>(float *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Abs<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Floor<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); | |||
| @@ -409,6 +394,5 @@ template void Atan<half>(const half *input, half *output, const size_t count, cu | |||
| template void Asinh<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Acosh<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Rsqrt<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Zeroslike<half>(half *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Abs<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Floor<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -55,8 +55,6 @@ void Asinh(const T *input, T *output, const size_t count, cudaStream_t cuda_stre | |||
| template <typename T> | |||
| void Acosh(const T *input, T *output, const size_t count, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void Zeroslike(T *output, const size_t count, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void Abs(const T *input, T *output, const size_t count, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stream); | |||
| @@ -52,10 +52,6 @@ MS_REG_GPU_KERNEL_ONE(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32). | |||
| UnaryOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| UnaryOpGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| UnaryOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| UnaryOpGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| UnaryOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -36,7 +36,6 @@ enum UnaryOptype { | |||
| UNARY_OP_ERFC, | |||
| UNARY_OP_NEG, | |||
| UNARY_OP_RECIPROCAL, | |||
| UNARY_OP_ZEROSLIKE, | |||
| UNARY_OP_SQUARE, | |||
| UNARY_OP_SQRT, | |||
| UNARY_OP_RSQRT, | |||
| @@ -51,27 +50,19 @@ enum UnaryOptype { | |||
| UNARY_OP_FLOOR, | |||
| UNARY_OP_INVALID_TYPE = 255 | |||
| }; | |||
| static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {{"Exp", UNARY_OP_EXP}, | |||
| {"Expm1", UNARY_OP_EXPM1}, | |||
| {"Log", UNARY_OP_LOG}, | |||
| {"Log1p", UNARY_OP_LOG1P}, | |||
| {"Erf", UNARY_OP_ERF}, | |||
| {"Erfc", UNARY_OP_ERFC}, | |||
| {"Neg", UNARY_OP_NEG}, | |||
| {"Reciprocal", UNARY_OP_RECIPROCAL}, | |||
| {"ZerosLike", UNARY_OP_ZEROSLIKE}, | |||
| {"Square", UNARY_OP_SQUARE}, | |||
| {"Sqrt", UNARY_OP_SQRT}, | |||
| {"Rsqrt", UNARY_OP_RSQRT}, | |||
| {"Sin", UNARY_OP_SIN}, | |||
| {"Cos", UNARY_OP_COS}, | |||
| {"Asin", UNARY_OP_ASIN}, | |||
| {"ACos", UNARY_OP_ACOS}, | |||
| {"Atan", UNARY_OP_ATAN}, | |||
| {"Asinh", UNARY_OP_ASINH}, | |||
| {"Acosh", UNARY_OP_ACOSH}, | |||
| {"Abs", UNARY_OP_ABS}, | |||
| {"Floor", UNARY_OP_FLOOR}}; | |||
| static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = { | |||
| {"Exp", UNARY_OP_EXP}, {"Expm1", UNARY_OP_EXPM1}, | |||
| {"Log", UNARY_OP_LOG}, {"Log1p", UNARY_OP_LOG1P}, | |||
| {"Erf", UNARY_OP_ERF}, {"Erfc", UNARY_OP_ERFC}, | |||
| {"Neg", UNARY_OP_NEG}, {"Reciprocal", UNARY_OP_RECIPROCAL}, | |||
| {"Square", UNARY_OP_SQUARE}, {"Sqrt", UNARY_OP_SQRT}, | |||
| {"Rsqrt", UNARY_OP_RSQRT}, {"Sin", UNARY_OP_SIN}, | |||
| {"Cos", UNARY_OP_COS}, {"Asin", UNARY_OP_ASIN}, | |||
| {"ACos", UNARY_OP_ACOS}, {"Atan", UNARY_OP_ATAN}, | |||
| {"Asinh", UNARY_OP_ASINH}, {"Acosh", UNARY_OP_ACOSH}, | |||
| {"Abs", UNARY_OP_ABS}, {"Floor", UNARY_OP_FLOOR}}; | |||
| template <typename T> | |||
| class UnaryOpGpuKernel : public GpuKernel { | |||
| public: | |||
| @@ -160,10 +151,6 @@ class UnaryOpGpuKernel : public GpuKernel { | |||
| Acosh(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| break; | |||
| } | |||
| case UNARY_OP_ZEROSLIKE: { | |||
| Zeroslike(output_addr, output_size_ / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| case UNARY_OP_ABS: { | |||
| Abs(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| break; | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -22,9 +22,6 @@ from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations import _inner_ops as inner | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| class NetZerosLike(nn.Cell): | |||
| def __init__(self): | |||
| super(NetZerosLike, self).__init__() | |||
| @@ -109,7 +106,6 @@ def test_zeros_like_dynamic_int8(): | |||
| x = Tensor(np.arange(24).reshape(1, 4, 1, 6).astype(np.int8)) | |||
| output = zeros_like_dynamic(x) | |||
| expected = np.zeros([1, 4, 1, 6]) | |||
| print(output) | |||
| np.testing.assert_array_equal(output.asnumpy(), expected) | |||
| @pytest.mark.level0 | |||
| @@ -148,6 +144,15 @@ def test_zeros_like_dynamic_float32(): | |||
| expected = np.zeros([3, 7, 3]) | |||
| np.testing.assert_array_almost_equal(output.asnumpy(), expected) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_zeros_like_dynamic_float64(): | |||
| x = Tensor(np.arange(2).reshape(2, 1, 1).astype(np.float64)) | |||
| output = zeros_like_dynamic(x) | |||
| expected = np.zeros([2, 1, 1]) | |||
| np.testing.assert_array_almost_equal(output.asnumpy(), expected) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||