| @@ -18,60 +18,131 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| UnsortedSegmentMax, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), | |||
| UnsortedSegmentMaxGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| UnsortedSegmentMaxGpuKernel, float, int) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| UnsortedSegmentMax, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), | |||
| UnsortedSegmentMaxGpuKernel, float, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| UnsortedSegmentMax, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), | |||
| UnsortedSegmentMaxGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| UnsortedSegmentMaxGpuKernel, half, int) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| UnsortedSegmentMax, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), | |||
| UnsortedSegmentMaxGpuKernel, half, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| UnsortedSegmentMax, | |||
| KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| UnsortedSegmentMaxGpuKernel, int) | |||
| UnsortedSegmentMaxGpuKernel, int, int) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| UnsortedSegmentMax, | |||
| KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), | |||
| UnsortedSegmentMaxGpuKernel, int, int64_t) | |||
| // Dynamic Mode - registered for int32/int64 3rd input | |||
| MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax, | |||
| MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| UnsortedSegmentMaxGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax, | |||
| UnsortedSegmentMaxGpuKernel, float, int) | |||
| MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| UnsortedSegmentMaxGpuKernel, float, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| UnsortedSegmentMaxGpuKernel, float, int) | |||
| MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| UnsortedSegmentMaxGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax, | |||
| UnsortedSegmentMaxGpuKernel, float, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| UnsortedSegmentMaxGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax, | |||
| UnsortedSegmentMaxGpuKernel, half, int) | |||
| MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| UnsortedSegmentMaxGpuKernel, half, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| UnsortedSegmentMaxGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax, | |||
| UnsortedSegmentMaxGpuKernel, half, int) | |||
| MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| UnsortedSegmentMaxGpuKernel, half, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeInt32), | |||
| UnsortedSegmentMaxGpuKernel, int) | |||
| MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax, | |||
| UnsortedSegmentMaxGpuKernel, int, int) | |||
| MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeInt32), | |||
| UnsortedSegmentMaxGpuKernel, int, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeInt32), | |||
| UnsortedSegmentMaxGpuKernel, int) | |||
| UnsortedSegmentMaxGpuKernel, int, int) | |||
| MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeInt32), | |||
| UnsortedSegmentMaxGpuKernel, int, int64_t) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -25,7 +25,7 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| template <typename T, typename S> | |||
| class UnsortedSegmentMaxGpuKernel : public GpuKernel { | |||
| public: | |||
| UnsortedSegmentMaxGpuKernel() { ResetResource(); } | |||
| @@ -41,7 +41,7 @@ class UnsortedSegmentMaxGpuKernel : public GpuKernel { | |||
| return true; | |||
| } | |||
| T *input_addr = GetDeviceAddress<T>(inputs, 0); | |||
| int *indices_addr = GetDeviceAddress<int>(inputs, 1); | |||
| S *indices_addr = GetDeviceAddress<S>(inputs, 1); | |||
| T *output_addr = GetDeviceAddress<T>(outputs, 0); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, | |||
| @@ -17,21 +17,21 @@ | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cuh" | |||
| #include <limits> | |||
| template <typename T> | |||
| __global__ void UnsortedSegmentMax(const T *input, const int *segment_ids, const int64_t num_segments, | |||
| size_t outer_size, size_t inner_size, bool fp16_flag, T init_K, T *output) { | |||
| template <typename T, typename S> | |||
| __global__ void UnsortedSegmentMax(const T *input, const S *segment_ids, const int64_t num_segments, size_t outer_size, | |||
| size_t inner_size, bool fp16_flag, T init_K, T *output) { | |||
| if (fp16_flag) { | |||
| init_K = __int2half_rd(-65504); // min value representable by float16 | |||
| } | |||
| for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < KWARPSIZE * num_segments * inner_size; | |||
| for (size_t t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < KWARPSIZE * num_segments * inner_size; | |||
| t_idx += blockDim.x * gridDim.x) { | |||
| int segment_id = t_idx / KWARPSIZE / inner_size; | |||
| int inner_id = t_idx / KWARPSIZE % inner_size; | |||
| int lane_id = threadIdx.x % KWARPSIZE; | |||
| size_t segment_id = t_idx / KWARPSIZE / inner_size; | |||
| size_t inner_id = t_idx / KWARPSIZE % inner_size; | |||
| size_t lane_id = threadIdx.x % KWARPSIZE; | |||
| T threadK = init_K; | |||
| for (int i = lane_id; i < outer_size; i += KWARPSIZE) { | |||
| for (size_t i = lane_id; i < outer_size; i += KWARPSIZE) { | |||
| if (segment_ids[i] != segment_id) continue; | |||
| T other_K = input[i * inner_size + inner_id]; | |||
| if (threadK < other_K) { | |||
| @@ -40,7 +40,7 @@ __global__ void UnsortedSegmentMax(const T *input, const int *segment_ids, const | |||
| } | |||
| __syncwarp(); | |||
| for (int offset = KWARPSIZE / 2; offset > 0; offset /= 2) { | |||
| for (size_t offset = KWARPSIZE / 2; offset > 0; offset /= 2) { | |||
| T other_K = __shfl_down_sync(0xffffffff, threadK, offset); | |||
| if (threadK < other_K) { | |||
| threadK = other_K; | |||
| @@ -56,10 +56,10 @@ __global__ void UnsortedSegmentMax(const T *input, const int *segment_ids, const | |||
| } | |||
| } | |||
| template <typename T> | |||
| void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int64_t num_segments, size_t outer_size, | |||
| template <typename T, typename S> | |||
| void CalUnsortedSegmentMax(const T *input, const S *segment_ids, const int64_t num_segments, size_t outer_size, | |||
| size_t inner_size, T *output, cudaStream_t stream) { | |||
| int size = (inner_size * KWARPSIZE * num_segments); | |||
| size_t size = (inner_size * KWARPSIZE * num_segments); | |||
| bool fp16_flag = false; | |||
| // handle fp16 min value | |||
| if (std::is_same<T, half>::value) { | |||
| @@ -71,9 +71,19 @@ void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int64_t | |||
| return; | |||
| } | |||
| template void CalUnsortedSegmentMax<float>(const float *input, const int *segment_ids, const int64_t num_segments, | |||
| size_t outer_size, size_t inner_size, float *output, cudaStream_t stream); | |||
| template void CalUnsortedSegmentMax<half>(const half *input, const int *segment_ids, const int64_t num_segments, | |||
| size_t outer_size, size_t inner_size, half *output, cudaStream_t stream); | |||
| template void CalUnsortedSegmentMax<int>(const int *input, const int *segment_ids, const int64_t num_segments, | |||
| size_t outer_size, size_t inner_size, int *output, cudaStream_t stream); | |||
| template void CalUnsortedSegmentMax<float, int>(const float *input, const int *segment_ids, const int64_t num_segments, | |||
| size_t outer_size, size_t inner_size, float *output, | |||
| cudaStream_t stream); | |||
| template void CalUnsortedSegmentMax<float, int64_t>(const float *input, const int64_t *segment_ids, | |||
| const int64_t num_segments, size_t outer_size, size_t inner_size, | |||
| float *output, cudaStream_t stream); | |||
| template void CalUnsortedSegmentMax<half, int>(const half *input, const int *segment_ids, const int64_t num_segments, | |||
| size_t outer_size, size_t inner_size, half *output, cudaStream_t stream); | |||
| template void CalUnsortedSegmentMax<half, int64_t>(const half *input, const int64_t *segment_ids, | |||
| const int64_t num_segments, size_t outer_size, size_t inner_size, | |||
| half *output, cudaStream_t stream); | |||
| template void CalUnsortedSegmentMax<int, int>(const int *input, const int *segment_ids, const int64_t num_segments, | |||
| size_t outer_size, size_t inner_size, int *output, cudaStream_t stream); | |||
| template void CalUnsortedSegmentMax<int, int64_t>(const int *input, const int64_t *segment_ids, | |||
| const int64_t num_segments, size_t outer_size, size_t inner_size, | |||
| int *output, cudaStream_t stream); | |||
| @@ -22,8 +22,8 @@ | |||
| // Setting warp size to sync data across threads | |||
| #define KWARPSIZE 32 | |||
| template <typename T> | |||
| void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int64_t num_segments, size_t outer_size, | |||
| template <typename T, typename S> | |||
| void CalUnsortedSegmentMax(const T *input, const S *segment_ids, const int64_t num_segments, size_t outer_size, | |||
| size_t inner_size, T *output, cudaStream_t stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_MAX_H_ | |||
| @@ -306,7 +306,7 @@ AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const Pri | |||
| MS_EXCEPTION_IF_NULL(segment_ids->shape()); | |||
| auto segment_ids_shape = segment_ids->shape()->shape(); | |||
| (void)CheckTensorDType(x, {kFloat16, kFloat32, kInt32}, "Input 0 (x) for UnsortedSegmentMax should be %s"); | |||
| (void)CheckTensorDType(segment_ids, {kInt32}, "Input 1 (segment_ids) for UnsortedSegmentMax should be %s"); | |||
| (void)CheckTensorDType(segment_ids, {kInt32, kInt64}, "Input 1 (segment_ids) for UnsortedSegmentMax should be %s"); | |||
| // check if dynamic shape | |||
| bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty()); | |||
| bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty()); | |||
| @@ -2001,7 +2001,8 @@ class UnsortedSegmentMax(PrimitiveWithCheck): | |||
| segment_ids_shape = segment_ids['shape'] | |||
| valid_type = [mstype.float16, mstype.float32, mstype.int32] | |||
| validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) | |||
| validator.check_tensors_dtypes_same_and_valid({"segment_ids": segment_ids['dtype']}, | |||
| [mstype.int32, mstype.int64], self.name) | |||
| validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) | |||
| num_segments_type = num_segments['dtype'] | |||
| validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name) | |||
| @@ -71,12 +71,12 @@ def test_2d_int32(): | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_3d_float16(): | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') | |||
| def test_3d_float16_int64(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||
| input_x = Tensor(np.arange( | |||
| 4 * 5 * 3, dtype=np.float16).reshape(4, 5, 3), dtype=mindspore.float16) | |||
| segment_ids = Tensor([2, 1, 1, -1], mstype.int32) | |||
| num_segments = 5 | |||
| segment_ids = Tensor([2, 1, 1, -1], mstype.int64) | |||
| num_segments = Tensor(5, dtype=mstype.int64) | |||
| net = UnsortedSegmentMaxNet(num_segments) | |||
| output = net(input_x, segment_ids).asnumpy() | |||
| expect = np.array([[[-6.55e+04, -6.55e+04, -6.55e+04], | |||
| @@ -110,12 +110,12 @@ def test_3d_float16(): | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_3d_float32(): | |||
| def test_3d_float32_int64(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||
| input_x = Tensor(np.arange( | |||
| 4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32) | |||
| segment_ids = Tensor([2, 1, 1, -1], mstype.int32) | |||
| num_segments = 3 | |||
| segment_ids = Tensor([2, 1, 1, -1], mstype.int64) | |||
| num_segments = Tensor(3, dtype=mstype.int64) | |||
| net = UnsortedSegmentMaxNet(num_segments) | |||
| output = net(input_x, segment_ids).asnumpy() | |||
| expect = np.array([[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38], | |||