From: @danishnxt Reviewed-by: @robingrosman,@tom__chen Signed-off-by: @robingrosmantags/v1.1.0
| @@ -23,11 +23,21 @@ MS_REG_GPU_KERNEL_TWO( | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), | |||
| GatherV2GpuFwdKernel, float, int) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherV2, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), | |||
| GatherV2GpuFwdKernel, float, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherV2, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), | |||
| GatherV2GpuFwdKernel, half, int) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherV2, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), | |||
| GatherV2GpuFwdKernel, half, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO(GatherV2, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| @@ -36,6 +46,14 @@ MS_REG_GPU_KERNEL_TWO(GatherV2, | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| GatherV2GpuFwdKernel, float, int) | |||
| MS_REG_GPU_KERNEL_TWO(GatherV2, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| GatherV2GpuFwdKernel, float, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO(GatherV2, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| @@ -44,6 +62,14 @@ MS_REG_GPU_KERNEL_TWO(GatherV2, | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| GatherV2GpuFwdKernel, half, int) | |||
| MS_REG_GPU_KERNEL_TWO(GatherV2, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| GatherV2GpuFwdKernel, half, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| SparseGatherV2, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), | |||
| @@ -20,16 +20,16 @@ | |||
| template <typename T, typename S> | |||
| __global__ void GatherV2Kernel(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, | |||
| size_t output_dim2, size_t input_dim1) { | |||
| int num = output_dim0 * output_dim1 * output_dim2; | |||
| int i, j, k; | |||
| for (int write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; | |||
| size_t num = output_dim0 * output_dim1 * output_dim2; | |||
| size_t i, j, k; | |||
| for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; | |||
| write_index += blockDim.x * gridDim.x) { | |||
| i = write_index / (output_dim1 * output_dim2) % output_dim0; | |||
| j = write_index / output_dim2 % output_dim1; | |||
| k = write_index % output_dim2; | |||
| if ((indices[j] >= 0) && (indices[j] < input_dim1)) { | |||
| int read_index = i * input_dim1 * output_dim2 + indices[j] * output_dim2 + k; | |||
| size_t read_index = i * input_dim1 * output_dim2 + indices[j] * output_dim2 + k; | |||
| output[write_index] = input[read_index]; | |||
| } else { | |||
| output[write_index] = 0; | |||
| @@ -41,7 +41,7 @@ __global__ void GatherV2Kernel(T *input, S *indices, T *output, size_t output_di | |||
| template <typename T, typename S> | |||
| void GatherV2(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2, | |||
| size_t input_dim1, cudaStream_t stream) { | |||
| int size = output_dim0 * output_dim1 * output_dim2; | |||
| size_t size = output_dim0 * output_dim1 * output_dim2; | |||
| GatherV2Kernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, indices, output, output_dim0, output_dim1, | |||
| output_dim2, input_dim1); | |||
| return; | |||
| @@ -49,6 +49,9 @@ void GatherV2(T *input, S *indices, T *output, size_t output_dim0, size_t output | |||
| template void GatherV2<float, int>(float *input, int *indices, float *output, size_t output_dim0, size_t output_dim1, | |||
| size_t output_dim2, size_t input_dim1, cudaStream_t stream); | |||
| template void GatherV2<float, int64_t>(float *input, int64_t *indices, float *output, size_t output_dim0, | |||
| size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream); | |||
| template void GatherV2<half, int>(half *input, int *indices, half *output, size_t output_dim0, size_t output_dim1, | |||
| size_t output_dim2, size_t input_dim1, cudaStream_t stream); | |||
| template void GatherV2<half, int64_t>(half *input, int64_t *indices, half *output, size_t output_dim0, | |||
| size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream); | |||
| @@ -926,7 +926,7 @@ def test_gather2(): | |||
| [4., 2., 8., 2., 9.,]] | |||
| ).astype(np.float32)) | |||
| indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int32)) | |||
| indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int64)) | |||
| expect = np.array([[[0., 0., 0., 0., 0.], | |||
| [4., 9., 5., 6., 4.], | |||
| [0., 0., 0., 0., 0.]]]) | |||
| @@ -1010,7 +1010,7 @@ def test_gatherV2_dyn_a(): | |||
| [3., 7., 2., 7., 4.,], | |||
| [4., 2., 8., 2., 9.,]] | |||
| ).astype(np.float32)) | |||
| indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int32)) | |||
| indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int64)) | |||
| expect = np.array([[[0., 5., 0.]], | |||
| [[0., 9., 0.]], | |||
| [[0., 8., 0.]], | |||