# Patch summary: add an "Equal" comparison to the GPU broadcast kernels.
#   - broadcast_impl.cu : new EqualFunc functor (generic, float, half) and
#                         BROADCAST_TYPE_EQUAL arms in ElewiseCmp / BroadcastCmp.
#   - broadcast_impl.cuh: new enumerator BROADCAST_TYPE_EQUAL = 13.
#   - broadcast_gpu_kernel.cc/.h: register Equal for float32/float16/int32/int64
#                         (plus Sub for int32/int64) and map "Equal" to the new op type.
#
# EqualFunc semantics:
#   - Generic T       : exact operator==.
#   - float           : true when the operands are exactly equal OR differ by less
#                       than 1e-9f. The exact test runs first because inf - inf is
#                       NaN, so the tolerance test alone reports +inf != +inf.
#                       The tolerance is a float literal to keep the compare in FP32.
#   - half            : operands are widened with __half2float and compared with the
#                       float rule. The float specialization is therefore declared
#                       before the half one.
#   - NaN never compares equal to anything, including itself.
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
index 6f825e460b..5277bb9f7c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
@@ -31,6 +31,25 @@ struct LessFunc {
   __device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) { return lhs < rhs ? true : false; }
 };
 
+template <typename T>
+struct EqualFunc {
+  __device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) { return lhs == rhs ? true : false; }
+};
+
+template <>
+struct EqualFunc<float> {
+  __device__ __host__ __forceinline__ bool operator()(const float &lhs, const float &rhs) {
+    return lhs == rhs || std::abs(lhs - rhs) < 1e-9f;  // exact match first: inf - inf is NaN
+  }
+};
+
+template <>
+struct EqualFunc<half> {
+  __device__ __host__ __forceinline__ bool operator()(const half &lhs, const half &rhs) {
+    return EqualFunc<float>()(__half2float(lhs), __half2float(rhs));  // widen, then reuse the float rule
+  }
+};
+
 template <typename T>
 struct MinimumFunc {
   __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return lhs < rhs ?
lhs : rhs; } @@ -188,6 +207,8 @@ void ElewiseCmp(const int &nums, enum BroadcastOpType op, const T *x0, const T * return ElewiseCmpKernel><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y); case BROADCAST_TYPE_LESS: return ElewiseCmpKernel><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y); + case BROADCAST_TYPE_EQUAL: + return ElewiseCmpKernel><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y); default: break; } @@ -331,6 +352,11 @@ void BroadcastCmp(const std::vector &x0_dims, const std::vector x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1], x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3], y_dims[4], y_dims[5], y_dims[6], x0, x1, y); + case BROADCAST_TYPE_EQUAL: + return BroadcastCmpKernel><<<(size + 255) / 256, 256, 0, stream>>>( + x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1], + x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3], + y_dims[4], y_dims[5], y_dims[6], x0, x1, y); default: break; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh index 3984c0ad99..c72c144ef5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh @@ -36,6 +36,7 @@ enum BroadcastOpType { BROADCAST_TYPE_ABSGRAD = 10, BROADCAST_TYPE_DIV = 11, BROADCAST_TYPE_DIVNONAN = 12, + BROADCAST_TYPE_EQUAL = 13, BROADCAST_TYPE_INVALID = 0xffffffff, }; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc index 8eee77fb5b..e0296cabed 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc +++ 
b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc @@ -26,6 +26,9 @@ MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE( Less, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), BroadcastOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + Equal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, float) MS_REG_GPU_KERNEL_ONE( Maximum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), @@ -75,6 +78,9 @@ MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE( Less, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), BroadcastOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE( + Equal, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, half) MS_REG_GPU_KERNEL_ONE( Maximum, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), @@ -123,6 +129,9 @@ MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE( Less, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), BroadcastOpGpuKernel, int) +MS_REG_GPU_KERNEL_ONE( + Equal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, int) MS_REG_GPU_KERNEL_ONE( TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), BroadcastOpGpuKernel, int) @@ -135,6 +144,9 @@ MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE( Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), BroadcastOpGpuKernel, int) +MS_REG_GPU_KERNEL_ONE( + Sub, 
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + BroadcastOpGpuKernel, int) MS_REG_GPU_KERNEL_ONE( FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), BroadcastOpGpuKernel, int) @@ -156,6 +168,9 @@ MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE( Less, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), BroadcastOpGpuKernel, int64_t) +MS_REG_GPU_KERNEL_ONE( + Equal, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, int64_t) MS_REG_GPU_KERNEL_ONE( TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), BroadcastOpGpuKernel, int64_t) @@ -168,6 +183,9 @@ MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE( Mul, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), BroadcastOpGpuKernel, int64_t) +MS_REG_GPU_KERNEL_ONE( + Sub, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + BroadcastOpGpuKernel, int64_t) MS_REG_GPU_KERNEL_ONE( FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), BroadcastOpGpuKernel, int64_t) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h index e8a97d9413..e9617220e0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h @@ -132,6 +132,7 @@ class BroadcastOpGpuKernel : public GpuKernel { static std::map kBroadcastCmpTypeMap = { {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, + {"Equal", BROADCAST_TYPE_EQUAL}, }; auto iter = 
kBroadcastCmpTypeMap.find(kernel_name);