From f3f9fc958ab043e98797e9e6b1a4d304e0896862 Mon Sep 17 00:00:00 2001 From: CaoJian Date: Thu, 16 Jul 2020 22:54:53 +0800 Subject: [PATCH] add GPU operator: abs and floor --- .../gpu/cuda_impl/unary_op_impl.cu | 43 +++++++++++++++++++ .../gpu/cuda_impl/unary_op_impl.cuh | 4 ++ .../gpu/math/unary_op_gpu_kernel.cc | 8 ++++ .../gpu/math/unary_op_gpu_kernel.h | 14 +++++- 4 files changed, 68 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu index 09b347e3d5..629c4c29dc 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu @@ -103,6 +103,35 @@ __global__ void ZeroslikeKernel(T *output, size_t count) { return; } template +__global__ void AbsKernel(T *input, T *output, size_t count) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + output[i] = abs(input[i]); + } + return; +} +template <> +__global__ void AbsKernel(half *input, half *output, size_t count) { + half zero = 0.0; + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + output[i] = input[i] < zero ? -input[i] : input[i]; + } + return; +} +template +__global__ void FloorKernel(T *input, T *output, size_t count) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + output[i] = floor(input[i]); + } + return; +} +template <> +__global__ void FloorKernel(half *input, half *output, size_t count) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + output[i] = hfloor(input[i]); + } + return; +} +template void Exponential(T *input, T *output, size_t count, cudaStream_t cuda_stream) { ExponentialKernel<<>>(input, output, count); return; @@ -147,6 +176,16 @@ void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream) { ZeroslikeKernel<<>>(output, count); return; } +template +void Abs(T *input, T *output, size_t count, cudaStream_t cuda_stream) { + AbsKernel<<>>(input, output, count); + return; +} +template +void Floor(T *input, T *output, size_t count, cudaStream_t cuda_stream) { + FloorKernel<<>>(input, output, count); + return; +} template void Exponential(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Logarithm(float *input, float *output, size_t count, cudaStream_t cuda_stream); @@ -156,6 +195,8 @@ template void Square(float *input, float *output, size_t count, cudaStrea template void Sqrt(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Rsqrt(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Zeroslike(float *output, size_t count, cudaStream_t cuda_stream); +template void Abs(float *input, float *output, size_t count, cudaStream_t cuda_stream); +template void Floor(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Exponential(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Logarithm(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Negative(half *input, half *output, size_t count, cudaStream_t cuda_stream); @@ -164,3 +205,5 @@ template void Square(half *input, half *output, size_t count, cudaStream_t template void Sqrt(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Rsqrt(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Zeroslike(half *output, size_t count, cudaStream_t cuda_stream); +template void Abs(half *input, half *output, size_t count, cudaStream_t cuda_stream); +template void Floor(half *input, half *output, size_t count, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh index cf8b30866e..4020f93df2 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh @@ -34,5 +34,9 @@ template void Rsqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream); template void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream); +template +void Abs(T *input, T *output, size_t count, cudaStream_t cuda_stream); +template +void Floor(T *input, T *output, size_t count, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc index ae8e7bbd0b..d646ef417c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc @@ -46,5 +46,13 @@ MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut UnaryOpGpuKernel, float) MS_REG_GPU_KERNEL_ONE(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + UnaryOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + UnaryOpGpuKernel, half) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h index 26993bc3bd..a02b94130c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h @@ -36,6 +36,8 @@ enum UnaryOptype { UNARY_OP_SQUARE, UNARY_OP_SQRT, UNARY_OP_RSQRT, + UNARY_OP_ABS, + UNARY_OP_FLOOR, UNARY_OP_INVALID_TYPE = 255 }; static const std::map kUnaryOpTypeMap = {{"Exp", UNARY_OP_EXP}, @@ -45,7 +47,9 @@ static const std::map kUnaryOpTypeMap = {{"Exp", UNARY {"ZerosLike", UNARY_OP_ZEROSLIKE}, {"Square", UNARY_OP_SQUARE}, {"Sqrt", UNARY_OP_SQRT}, - {"Rsqrt", UNARY_OP_RSQRT}}; + {"Rsqrt", UNARY_OP_RSQRT}, + {"Abs", UNARY_OP_ABS}, + {"Floor", UNARY_OP_FLOOR}}; template class UnaryOpGpuKernel : public GpuKernel { public: @@ -100,6 +104,14 @@ class UnaryOpGpuKernel : public GpuKernel { Zeroslike(output_addr, output_size_ / sizeof(T), reinterpret_cast(stream_ptr)); return true; } + case UNARY_OP_ABS: { + Abs(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); + break; + } + case UNARY_OP_FLOOR: { + Floor(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); + break; + } default: { MS_LOG(EXCEPTION) << "Unary operation " << unary_op_type_ << " is not supported."; }