From de63ee46907491663118cbd1cca8df22f4fdbd9b Mon Sep 17 00:00:00 2001 From: zuochuanyong Date: Wed, 12 May 2021 19:48:30 +0800 Subject: [PATCH] add some gpu ops after removing akg --- .../gpu/cuda_impl/broadcast_impl.cu | 128 +++++++++++- .../gpu/cuda_impl/broadcast_impl.cuh | 5 + .../gpu/cuda_impl/logical_not_impl.cu | 40 ++++ .../gpu/cuda_impl/logical_not_impl.cuh | 28 +++ .../gpu/math/broadcast_gpu_kernel.cc | 100 ++++++++++ .../gpu/math/broadcast_gpu_kernel.h | 5 + .../gpu/math/logical_not_gpu_kernel.cc | 23 +++ .../gpu/math/logical_not_gpu_kernel.h | 78 ++++++++ tests/st/ops/gpu/test_equal_op.py | 184 ++++++++---------- tests/st/ops/gpu/test_lessequal_op.py | 49 +++-- 10 files changed, 520 insertions(+), 120 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/logical_not_impl.cu create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/logical_not_impl.cuh create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/math/logical_not_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/math/logical_not_gpu_kernel.h diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu index 58c79b4a7c..61970a725d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu @@ -50,6 +50,75 @@ struct EqualFunc { } }; +template +struct GreaterEqualFunc { + __device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) { return lhs >= rhs ? true : false; } +}; + +template <> +struct GreaterEqualFunc { + __device__ __host__ __forceinline__ bool operator()(const half &lhs, const half &rhs) { + return std::abs(__half2float(lhs) - __half2float(rhs)) < 1e-9 ? + true : (__half2float(lhs) > __half2float(rhs) ? true : false); + } +}; + +template <> +struct GreaterEqualFunc { + __device__ __host__ __forceinline__ bool operator()(const float &lhs, const float &rhs) { + return std::abs(lhs - rhs) < 1e-9 ? true : (lhs > rhs ? true : false); + } +}; + +template +struct LessEqualFunc { + __device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) { return lhs <= rhs ? true : false; } +}; + +template <> +struct LessEqualFunc { + __device__ __host__ __forceinline__ bool operator()(const half &lhs, const half &rhs) { + return std::abs(__half2float(lhs) - __half2float(rhs)) < 1e-9 ? + true : (__half2float(lhs) < __half2float(rhs) ? true : false); + } +}; + +template <> +struct LessEqualFunc { + __device__ __host__ __forceinline__ bool operator()(const float &lhs, const float &rhs) { + return std::abs(lhs - rhs) < 1e-9 ? true : (lhs < rhs ? true : false); + } +}; + +template +struct NotEqualFunc { + __device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) { return lhs == rhs ? false : true; } +}; + +template <> +struct NotEqualFunc { + __device__ __host__ __forceinline__ bool operator()(const half &lhs, const half &rhs) { + return std::abs(__half2float(lhs) - __half2float(rhs)) < 1e-9 ? false : true; + } +}; + +template <> +struct NotEqualFunc { + __device__ __host__ __forceinline__ bool operator()(const float &lhs, const float &rhs) { + return std::abs(lhs - rhs) < 1e-9 ? false : true; + } +}; + +template +struct LogicalAndFunc { + __device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) { return lhs && rhs; } +}; + +template +struct LogicalOrFunc { + __device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) { return lhs || rhs; } +}; + template struct MinimumFunc { __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return lhs < rhs ? lhs : rhs; } @@ -329,6 +398,16 @@ void ElewiseCmp(const int &nums, enum BroadcastOpType op, const T *x0, const T * return ElewiseCmpKernel><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y); case BROADCAST_TYPE_EQUAL: return ElewiseCmpKernel><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y); + case BROADCAST_TYPE_GREATER_EQUAL: + return ElewiseCmpKernel><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y); + case BROADCAST_TYPE_LESS_EQUAL: + return ElewiseCmpKernel><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y); + case BROADCAST_TYPE_NOT_EQUAL: + return ElewiseCmpKernel><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y); + case BROADCAST_TYPE_LOGICAL_AND: + return ElewiseCmpKernel><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y); + case BROADCAST_TYPE_LOGICAL_OR: + return ElewiseCmpKernel><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y); default: break; } @@ -348,7 +427,10 @@ template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const uint8_t cudaStream_t stream); template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const int64_t *x0, const int64_t *x1, bool *y, cudaStream_t stream); - +template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const int16_t *x0, const int16_t *x1, bool *y, + cudaStream_t stream); +template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const bool *x0, const bool *x1, bool *y, + cudaStream_t stream); // Element-wise ArithMetic template __global__ void ElewiseArithKernel(const int nums, const T *x0, const T *x1, T *y) { @@ -426,7 +508,10 @@ template void ElewiseArith(const int &nums, enum BroadcastOpType op, const uint8 cudaStream_t stream); template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int64_t *x0, const int64_t *x1, int64_t *y, cudaStream_t stream); - +template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int16_t *x0, const int16_t *x1, int16_t *y, + cudaStream_t stream); +template void ElewiseArith(const int &nums, enum BroadcastOpType op, const bool *x0, const bool *x1, bool *y, + cudaStream_t stream); // Broadcast comparison __device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) { return dim == 1 ? 0 : index; } @@ -489,6 +574,31 @@ void BroadcastCmp(const std::vector &x0_dims, const std::vector x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1], x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3], y_dims[4], y_dims[5], y_dims[6], x0, x1, y); + case BROADCAST_TYPE_GREATER_EQUAL: + return BroadcastCmpKernel><<<(size + 255) / 256, 256, 0, stream>>>( + x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1], + x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3], + y_dims[4], y_dims[5], y_dims[6], x0, x1, y); + case BROADCAST_TYPE_LESS_EQUAL: + return BroadcastCmpKernel><<<(size + 255) / 256, 256, 0, stream>>>( + x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1], + x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3], + y_dims[4], y_dims[5], y_dims[6], x0, x1, y); + case BROADCAST_TYPE_NOT_EQUAL: + return BroadcastCmpKernel><<<(size + 255) / 256, 256, 0, stream>>>( + x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1], + x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3], + y_dims[4], y_dims[5], y_dims[6], x0, x1, y); + case BROADCAST_TYPE_LOGICAL_AND: + return BroadcastCmpKernel><<<(size + 255) / 256, 256, 0, stream>>>( + x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1], + x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3], + y_dims[4], y_dims[5], y_dims[6], x0, x1, y); + case BROADCAST_TYPE_LOGICAL_OR: + return BroadcastCmpKernel><<<(size + 255) / 256, 256, 0, stream>>>( + x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1], + x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3], + y_dims[4], y_dims[5], y_dims[6], x0, x1, y); default: break; } @@ -515,7 +625,12 @@ template void BroadcastCmp(const std::vector &x0_dims, const std::vector template void BroadcastCmp(const std::vector &x0_dims, const std::vector &x1_dims, const std::vector &y_dims, enum BroadcastOpType op, const int64_t *x0, const int64_t *x1, bool *y, cudaStream_t stream); - +template void BroadcastCmp(const std::vector &x0_dims, const std::vector &x1_dims, + const std::vector &y_dims, enum BroadcastOpType op, const int16_t *x0, + const int16_t *x1, bool *y, cudaStream_t stream); +template void BroadcastCmp(const std::vector &x0_dims, const std::vector &x1_dims, + const std::vector &y_dims, enum BroadcastOpType op, const bool *x0, + const bool *x1, bool *y, cudaStream_t stream); // Broadcast Arithmetic template __global__ void BroadcastArithKernel(const size_t l0, const size_t l1, const size_t l2, const size_t l3, @@ -662,7 +777,12 @@ template void BroadcastArith(const std::vector &x0_dims, const std::vect template void BroadcastArith(const std::vector &x0_dims, const std::vector &x1_dims, const std::vector &y_dims, enum BroadcastOpType op, const int64_t *x0, const int64_t *x1, int64_t *y, cudaStream_t stream); - +template void BroadcastArith(const std::vector &x0_dims, const std::vector &x1_dims, + const std::vector &y_dims, enum BroadcastOpType op, const int16_t *x0, + const int16_t *x1, int16_t *y, cudaStream_t stream); +template void BroadcastArith(const std::vector &x0_dims, const std::vector &x1_dims, + const std::vector &y_dims, enum BroadcastOpType op, const bool *x0, + const bool *x1, bool *y, cudaStream_t stream); // BroadcastTo template __global__ void BroadcastToKernel(const size_t i0, const size_t i1, const size_t i2, const size_t i3, const size_t o0, diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh index 32268ac22a..397961dfd3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh @@ -41,6 +41,11 @@ enum BroadcastOpType { BROADCAST_TYPE_MOD = 15, BROADCAST_TYPE_FLOORMOD = 16, BROADCAST_TYPE_ATAN2 = 17, + BROADCAST_TYPE_GREATER_EQUAL = 18, + BROADCAST_TYPE_LESS_EQUAL = 19, + BROADCAST_TYPE_NOT_EQUAL = 20, + BROADCAST_TYPE_LOGICAL_AND = 21, + BROADCAST_TYPE_LOGICAL_OR = 22, BROADCAST_TYPE_INVALID = 0xffffffff, }; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/logical_not_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/logical_not_impl.cu new file mode 100644 index 0000000000..f544813750 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/logical_not_impl.cu @@ -0,0 +1,40 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "backend/kernel_compiler/gpu/cuda_impl/logical_not_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +struct LogicalNotFunc { + __device__ __host__ __forceinline__ bool operator()(const T &x) { return !x; } +}; + +template +__global__ void LogicalNotKernel(const int nums, const T *x, bool *y) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) { + y[pos] = Func()(x[pos]); + } +} + +template +void LogicalNotImpl(const int &nums, const T *x, bool *y, cudaStream_t stream) { + return LogicalNotKernel><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x, y); +} + +template void LogicalNotImpl(const int &nums, const bool *x, bool *y, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/logical_not_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/logical_not_impl.cuh new file mode 100644 index 0000000000..df915f879d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/logical_not_impl.cuh @@ -0,0 +1,28 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LOGICAL_NOT_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LOGICAL_NOT_H_ + +#include +#include "runtime/device/gpu/cuda_common.h" + + + +template +void LogicalNotImpl(const int &nums, const T *x, bool *y, cudaStream_t stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LOGICAL_NOT_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc index ed3e9f0e0f..a88caa3b42 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc @@ -64,6 +64,17 @@ MS_REG_GPU_KERNEL_ONE( Atan2, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), BroadcastOpGpuKernel, double) +MS_REG_GPU_KERNEL_ONE( + Equal, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, double) +MS_REG_GPU_KERNEL_ONE( + GreaterEqual, + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, double) +MS_REG_GPU_KERNEL_ONE( + LessEqual, + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, double) // fp32 MS_REG_GPU_KERNEL_ONE( @@ -126,6 +137,18 @@ MS_REG_GPU_KERNEL_ONE( Atan2, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), BroadcastOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + GreaterEqual, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + LessEqual, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + NotEqual, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, float) // fp16 MS_REG_GPU_KERNEL_ONE( @@ -188,6 +211,18 @@ MS_REG_GPU_KERNEL_ONE( Atan2, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), BroadcastOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE( + GreaterEqual, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE( + LessEqual, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE( + NotEqual, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, half) // int32 MS_REG_GPU_KERNEL_ONE( @@ -235,6 +270,16 @@ MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE( FloorMod, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), BroadcastOpGpuKernel, int) +MS_REG_GPU_KERNEL_ONE( + GreaterEqual, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, int) +MS_REG_GPU_KERNEL_ONE( + LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, int) +MS_REG_GPU_KERNEL_ONE( + NotEqual, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, int) // int64 MS_REG_GPU_KERNEL_ONE( @@ -279,6 +324,16 @@ MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE( FloorMod, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), BroadcastOpGpuKernel, int64_t) +MS_REG_GPU_KERNEL_ONE( + GreaterEqual, + KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, int64_t) +MS_REG_GPU_KERNEL_ONE( + LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, int64_t) +MS_REG_GPU_KERNEL_ONE( + NotEqual, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, int64_t) // int8 MS_REG_GPU_KERNEL_ONE( @@ -287,6 +342,12 @@ MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE( Equal, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool), BroadcastOpGpuKernel, int8_t) +MS_REG_GPU_KERNEL_ONE( + GreaterEqual, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, int8_t) +MS_REG_GPU_KERNEL_ONE( + LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, int8_t) // uint8 MS_REG_GPU_KERNEL_ONE( @@ -295,5 +356,44 @@ MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE( Equal, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), BroadcastOpGpuKernel, uint8_t) +MS_REG_GPU_KERNEL_ONE( + GreaterEqual, + KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, uint8_t) +MS_REG_GPU_KERNEL_ONE( + LessEqual, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, uint8_t) +MS_REG_GPU_KERNEL_ONE( + NotEqual, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, uint8_t) + +// int16 +MS_REG_GPU_KERNEL_ONE( + Equal, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, int16_t) +MS_REG_GPU_KERNEL_ONE( + NotEqual, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, int16_t) +MS_REG_GPU_KERNEL_ONE( + GreaterEqual, + KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, int16_t) +MS_REG_GPU_KERNEL_ONE( + LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, int16_t) + +// bool +MS_REG_GPU_KERNEL_ONE( + Equal, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, bool) +MS_REG_GPU_KERNEL_ONE( + NotEqual, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, bool) +MS_REG_GPU_KERNEL_ONE( + LogicalAnd, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, bool) +MS_REG_GPU_KERNEL_ONE( + LogicalOr, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, bool) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h index 362f583fad..dfe392abd4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h @@ -133,6 +133,11 @@ class BroadcastOpGpuKernel : public GpuKernel { {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Equal", BROADCAST_TYPE_EQUAL}, + {"GreaterEqual", BROADCAST_TYPE_GREATER_EQUAL}, + {"LessEqual", BROADCAST_TYPE_LESS_EQUAL}, + {"NotEqual", BROADCAST_TYPE_NOT_EQUAL}, + {"LogicalAnd", BROADCAST_TYPE_LOGICAL_AND}, + {"LogicalOr", BROADCAST_TYPE_LOGICAL_OR}, }; auto iter = kBroadcastCmpTypeMap.find(kernel_name); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/logical_not_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/logical_not_gpu_kernel.cc new file mode 100644 index 0000000000..ca421fae11 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/logical_not_gpu_kernel.cc @@ -0,0 +1,23 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/logical_not_gpu_kernel.h" +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(LogicalNot, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), + LogicalNotGpuKernel, bool) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/logical_not_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/logical_not_gpu_kernel.h new file mode 100644 index 0000000000..420579184c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/logical_not_gpu_kernel.h @@ -0,0 +1,78 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LOGICAL_NOT_GPU_KERNEL_H +#define MINDSPORE_LOGICAL_NOT_GPU_KERNEL_H +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/logical_not_impl.cuh" +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "utils/convert_utils.h" + +namespace mindspore { +namespace kernel { +template +class LogicalNotGpuKernel : public GpuKernel { + public: + LogicalNotGpuKernel() { ResetResource(); } + ~LogicalNotGpuKernel() override = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + auto input_addr = GetDeviceAddress(inputs, 0); + auto output_addr = GetDeviceAddress(outputs, 0); + LogicalNotImpl(input_num_, input_addr, output_addr, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + kernel_node_ = kernel_node; + auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0); + input_num_ = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + InitSizeLists(); + return true; + } + + void ResetResource() noexcept override { + input_num_ = 1; + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_num_ * sizeof(T)); + output_size_list_.push_back(input_num_ * sizeof(T)); + } + + private: + size_t input_num_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif diff --git a/tests/st/ops/gpu/test_equal_op.py b/tests/st/ops/gpu/test_equal_op.py index 878c596e3a..cb3e40378c 100644 --- a/tests/st/ops/gpu/test_equal_op.py +++ b/tests/st/ops/gpu/test_equal_op.py @@ -31,6 +31,7 @@ class NetEqual(Cell): def construct(self, x, y): return self.Equal(x, y) + class NetEqualDynamic(Cell): def __init__(self): super(NetEqualDynamic, self).__init__() @@ -42,6 +43,7 @@ class NetEqualDynamic(Cell): y_conv = self.conv(y) return self.Equal(x_conv, y_conv) + class NetNotEqual(Cell): def __init__(self): super(NetNotEqual, self).__init__() @@ -50,6 +52,7 @@ class NetNotEqual(Cell): def construct(self, x, y): return self.NotEqual(x, y) + class NetGreaterEqual(Cell): def __init__(self): super(NetGreaterEqual, self).__init__() @@ -69,12 +72,12 @@ def test_equal(): expect0 = np.equal(x0_np, y0_np) x1_np = np.array([0, 1, 3]).astype(np.float32) x1 = Tensor(x1_np) - y1_np = np.array([0, 1, -3]).astype(np.float32) + y1_np = np.array([0]).astype(np.float32) y1 = Tensor(y1_np) expect1 = np.equal(x1_np, y1_np) x2_np = np.array([0, 1, 3]).astype(np.int32) x2 = Tensor(x2_np) - y2_np = np.array([0, 1, -3]).astype(np.int32) + y2_np = np.array([0]).astype(np.int32) y2 = Tensor(y2_np) expect2 = np.equal(x2_np, y2_np) x3_np = np.array([0, 1, 3]).astype(np.int16) @@ -93,74 +96,45 @@ def test_equal(): y5 = Tensor(y5_np) expect5 = np.equal(x5_np, y5_np) x6_np = np.array([0, 1, 4]).astype(np.int8) - x6 = Tensor(x4_np) + x6 = Tensor(x6_np) y6_np = np.array([0, 1, 3]).astype(np.int8) - y6 = Tensor(y4_np) + y6 = Tensor(y6_np) expect6 = np.equal(x6_np, y6_np) x7_np = np.array([0, 1, 4]).astype(np.int64) - x7 = Tensor(x4_np) + x7 = Tensor(x7_np) y7_np = np.array([0, 1, 3]).astype(np.int64) - y7 = Tensor(y4_np) + y7 = Tensor(y7_np) expect7 = np.equal(x7_np, y7_np) x8_np = np.array([0, 1, 4]).astype(np.float16) - x8 = Tensor(x4_np) + x8 = Tensor(x8_np) y8_np = np.array([0, 1, 3]).astype(np.float16) - y8 = Tensor(y4_np) + y8 = Tensor(y8_np) expect8 = np.equal(x8_np, y8_np) + x9_np = np.array([0, 1, 4]).astype(np.float64) + x9 = Tensor(x9_np) + y9_np = np.array([0, 1, 3]).astype(np.float64) + y9 = Tensor(y9_np) + expect9 = np.equal(x9_np, y9_np) + + x = [x0, x1, x2, x3, x4, x5, x6, x7, x8, x9] + y = [y0, y1, y2, y3, y4, y5, y6, y7, y8, y9] + expect = [expect0, expect1, expect2, expect3, expect4, expect5, expect6, expect7, expect8, expect9] context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") equal = NetEqual() - output0 = equal(x0, y0) - assert np.all(output0.asnumpy() == expect0) - assert output0.shape == expect0.shape - output1 = equal(x1, y1) - assert np.all(output1.asnumpy() == expect1) - assert output1.shape == expect1.shape - output2 = equal(x2, y2) - assert np.all(output2.asnumpy() == expect2) - assert output2.shape == expect2.shape - output3 = equal(x3, y3) - assert np.all(output3.asnumpy() == expect3) - assert output3.shape == expect3.shape - output4 = equal(x4, y4) - assert np.all(output4.asnumpy() == expect4) - assert output4.shape == expect4.shape - output5 = equal(x5, y5) - assert np.all(output5.asnumpy() == expect5) - assert output5.shape == expect5.shape - - + for i, xi in enumerate(x): + output = equal(xi, y[i]) + assert np.all(output.asnumpy() == expect[i]) + assert output.shape == expect[i].shape + print('test [%d/%d] passed!' % (i, len(x))) context.set_context(mode=context.GRAPH_MODE, device_target="GPU") equal = NetEqual() - output0 = equal(x0, y0) - assert np.all(output0.asnumpy() == expect0) - assert output0.shape == expect0.shape - output1 = equal(x1, y1) - assert np.all(output1.asnumpy() == expect1) - assert output1.shape == expect1.shape - output2 = equal(x2, y2) - assert np.all(output2.asnumpy() == expect2) - assert output2.shape == expect2.shape - output3 = equal(x3, y3) - assert np.all(output3.asnumpy() == expect3) - assert output3.shape == expect3.shape - output4 = equal(x4, y4) - assert np.all(output4.asnumpy() == expect4) - assert output4.shape == expect4.shape - output5 = equal(x5, y5) - assert np.all(output5.asnumpy() == expect5) - assert output5.shape == expect5.shape - output6 = equal(x6, y6) - assert np.all(output6.asnumpy() == expect6) - assert output6.shape == expect6.shape - output7 = equal(x7, y7) - assert np.all(output7.asnumpy() == expect7) - assert output7.shape == expect7.shape - output8 = equal(x8, y8) - assert np.all(output8.asnumpy() == expect8) - assert output8.shape == expect8.shape - + for i, xi in enumerate(x): + output = equal(xi, y[i]) + assert np.all(output.asnumpy() == expect[i]) + assert output.shape == expect[i].shape + print('test [%d/%d] passed!' % (i, len(x))) @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @@ -178,44 +152,42 @@ def test_notequal(): x3 = Tensor(np.array([[False, True], [True, False]]).astype(bool)) y3 = Tensor(np.array([[True, False]]).astype(bool)) expect3 = np.array([[True, True], [False, False]]) + x4 = Tensor(np.array([[1.2, 1], [1, 0]]).astype(np.float16)) + y4 = Tensor(np.array([[1, 2]]).astype(np.float16)) + expect4 = np.array([[True, True], [False, True]]) + x5 = Tensor(np.array([[2, 1], [1, 0]]).astype(np.int64)) + y5 = Tensor(np.array([[1, 2]]).astype(np.int64)) + expect5 = np.array([[True, True], [False, True]]) + x6 = Tensor(np.array([[2, 1], [1, 0]]).astype(np.int32)) + y6 = Tensor(np.array([[1, 2], [1, 2]]).astype(np.int32)) + expect6 = np.array([[True, True], [False, True]]) + + x = [x0, x1, x2, x3, x4, x5, x6] + y = [y0, y1, y2, y3, y4, y5, y6] + expect = [expect0, expect1, expect2, expect3, expect4, expect5, expect6] context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") notequal = NetNotEqual() - output0 = notequal(x0, y0) - assert np.all(output0.asnumpy() == expect0) - assert output0.shape == expect0.shape - output1 = notequal(x1, y1) - assert np.all(output1.asnumpy() == expect1) - assert output1.shape == expect1.shape - output2 = notequal(x2, y2) - assert np.all(output2.asnumpy() == expect2) - assert output2.shape == expect2.shape - output3 = notequal(x3, y3) - assert np.all(output3.asnumpy() == expect3) - assert output3.shape == expect3.shape + for i, xi in enumerate(x): + output = notequal(xi, y[i]) + assert np.all(output.asnumpy() == expect[i]) + assert output.shape == expect[i].shape + print('test [%d/%d] passed!' % (i, len(x))) context.set_context(mode=context.GRAPH_MODE, device_target="GPU") notequal = NetNotEqual() - output0 = notequal(x0, y0) - assert np.all(output0.asnumpy() == expect0) - assert output0.shape == expect0.shape - output1 = notequal(x1, y1) - assert np.all(output1.asnumpy() == expect1) - assert output1.shape == expect1.shape - output2 = notequal(x2, y2) - assert np.all(output2.asnumpy() == expect2) - assert output2.shape == expect2.shape - output3 = notequal(x3, y3) - assert np.all(output3.asnumpy() == expect3) - assert output3.shape == expect3.shape - + for i, xi in enumerate(x): + output = notequal(xi, y[i]) + assert np.all(output.asnumpy() == expect[i]) + assert output.shape == expect[i].shape + print('test [%d/%d] passed!' % (i, len(x))) @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_greaterqual(): x0 = Tensor(np.array([[1.2, 1], [1, 0]]).astype(np.float32)) - y0 = Tensor(np.array([[1, 2]]).astype(np.float32)) + y0 = Tensor(np.array([[1, 2], [1, 2]]).astype(np.float32)) expect0 = np.array([[True, False], [True, False]]) x1 = Tensor(np.array([[2, 1], [1, 1]]).astype(np.int16)) y1 = Tensor(np.array([[1, 2]]).astype(np.int16)) @@ -224,29 +196,41 @@ def test_greaterqual(): y2 = Tensor(np.array([[1, 2]]).astype(np.uint8)) expect2 = np.array([[True, False], [True, True]]) + x3 = Tensor(np.array([[2, 1], [1, 2]]).astype(np.float64)) + y3 = Tensor(np.array([[1, 2]]).astype(np.float64)) + expect3 = np.array([[True, False], [True, True]]) + x4 = Tensor(np.array([[2, 1], [1, 2]]).astype(np.float16)) + y4 = Tensor(np.array([[1, 2]]).astype(np.float16)) + expect4 = np.array([[True, False], [True, True]]) + x5 = Tensor(np.array([[2, 1], [1, 1]]).astype(np.int64)) + y5 = Tensor(np.array([[1, 2]]).astype(np.int64)) + expect5 = np.array([[True, False], [True, False]]) + x6 = Tensor(np.array([[2, 1], [1, 1]]).astype(np.int32)) + y6 = Tensor(np.array([[1, 2]]).astype(np.int32)) + expect6 = np.array([[True, False], [True, False]]) + x7 = Tensor(np.array([[2, 1], [1, 1]]).astype(np.int8)) + y7 = Tensor(np.array([[1, 2]]).astype(np.int8)) + expect7 = np.array([[True, False], [True, False]]) + + x = [x0, x1, x2, x3, x4, x5, x6, x7] + y = [y0, y1, y2, y3, y4, y5, y6, y7] + expect = [expect0, expect1, expect2, expect3, expect4, expect5, expect6, expect7] + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") gequal = NetGreaterEqual() - output0 = gequal(x0, y0) - assert np.all(output0.asnumpy() == expect0) - assert output0.shape == expect0.shape - output1 = gequal(x1, y1) - assert np.all(output1.asnumpy() == expect1) - assert output1.shape == expect1.shape - output2 = gequal(x2, y2) - assert np.all(output2.asnumpy() == expect2) - assert output2.shape == expect2.shape + for i, xi in enumerate(x): + output = gequal(xi, y[i]) + assert np.all(output.asnumpy() == expect[i]) + assert output.shape == expect[i].shape + print('test [%d/%d] passed!' % (i, len(x))) context.set_context(mode=context.GRAPH_MODE, device_target="GPU") gequal = NetGreaterEqual() - output0 = gequal(x0, y0) - assert np.all(output0.asnumpy() == expect0) - assert output0.shape == expect0.shape - output1 = gequal(x1, y1) - assert np.all(output1.asnumpy() == expect1) - assert output1.shape == expect1.shape - output2 = gequal(x2, y2) - assert np.all(output2.asnumpy() == expect2) - assert output2.shape == expect2.shape + for i, xi in enumerate(x): + output = gequal(xi, y[i]) + assert np.all(output.asnumpy() == expect[i]) + assert output.shape == expect[i].shape + print('test [%d/%d] passed!' % (i, len(x))) @pytest.mark.level0 diff --git a/tests/st/ops/gpu/test_lessequal_op.py b/tests/st/ops/gpu/test_lessequal_op.py index 4888ae2c19..8940ade593 100644 --- a/tests/st/ops/gpu/test_lessequal_op.py +++ b/tests/st/ops/gpu/test_lessequal_op.py @@ -36,29 +36,46 @@ class Net(Cell): @pytest.mark.env_onecard def test_lessequal(): x = Tensor(np.array([[1, 2, 3]]).astype(np.float32)) - y = Tensor(np.array([[2]]).astype(np.float32)) - expect = [[True, True, False]] + y = Tensor(np.array([[2, 2, 2]]).astype(np.float32)) + expect = np.array([[True, True, False]]) x1 = Tensor(np.array([[1, 2, 3]]).astype(np.int16)) y1 = Tensor(np.array([[2]]).astype(np.int16)) - expect = [[True, True, False]] + expect1 = np.array([[True, True, False]]) x2 = Tensor(np.array([[1, 2, 3]]).astype(np.uint8)) y2 = Tensor(np.array([[2]]).astype(np.uint8)) - expect = [[True, True, False]] + expect2 = np.array([[True, True, False]]) + x3 = Tensor(np.array([[1, 2, 3]]).astype(np.float64)) + y3 = Tensor(np.array([[2]]).astype(np.float64)) + expect3 = np.array([[True, True, False]]) + x4 = Tensor(np.array([[1, 2, 3]]).astype(np.float16)) + y4 = Tensor(np.array([[2]]).astype(np.float16)) + expect4 = np.array([[True, True, False]]) + x5 = Tensor(np.array([[1, 2, 3]]).astype(np.int64)) + y5 = Tensor(np.array([[2]]).astype(np.int64)) + expect5 = np.array([[True, True, False]]) + x6 = Tensor(np.array([[1, 2, 3]]).astype(np.int32)) + y6 = Tensor(np.array([[2, 2, 2]]).astype(np.int32)) + expect6 = np.array([[True, True, False]]) + x7 = Tensor(np.array([[1, 2, 3]]).astype(np.int8)) + y7 = Tensor(np.array([[2]]).astype(np.int8)) + expect7 = np.array([[True, True, False]]) + + x = [x, x1, x2, x3, x4, x5, x6, x7] + y = [y, y1, y2, y3, y4, y5, y6, y7] + expect = [expect, expect1, expect2, expect3, expect4, expect5, expect6, expect7] context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") lessequal = Net() - output = lessequal(x, y) - assert np.all(output.asnumpy() == expect) - output = lessequal(x1, y1) - assert np.all(output.asnumpy() == expect) - output = lessequal(x2, y2) - assert np.all(output.asnumpy() == expect) + for i, xi in enumerate(x): + output = lessequal(xi, y[i]) + assert np.all(output.asnumpy() == expect[i]) + assert output.shape == expect[i].shape + print('test [%d/%d] passed!' % (i, len(x))) context.set_context(mode=context.GRAPH_MODE, device_target="GPU") lessequal = Net() - output = lessequal(x, y) - assert np.all(output.asnumpy() == expect) - output = lessequal(x1, y1) - assert np.all(output.asnumpy() == expect) - output = lessequal(x2, y2) - assert np.all(output.asnumpy() == expect) + for i, xi in enumerate(x): + output = lessequal(xi, y[i]) + assert np.all(output.asnumpy() == expect[i]) + assert output.shape == expect[i].shape + print('test [%d/%d] passed!' % (i, len(x)))