| @@ -18,6 +18,14 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherD, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), | |||
| GatherGpuFwdKernel, double, int) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherD, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), | |||
| GatherGpuFwdKernel, double, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherD, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), | |||
| @@ -34,5 +42,59 @@ MS_REG_GPU_KERNEL_TWO( | |||
| GatherD, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), | |||
| GatherGpuFwdKernel, half, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherD, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| GatherGpuFwdKernel, int, int) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherD, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), | |||
| GatherGpuFwdKernel, int, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherD, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), | |||
| GatherGpuFwdKernel, int8_t, int) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherD, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), | |||
| GatherGpuFwdKernel, int8_t, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherD, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), | |||
| GatherGpuFwdKernel, int16_t, int) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherD, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), | |||
| GatherGpuFwdKernel, int16_t, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherD, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), | |||
| GatherGpuFwdKernel, int64_t, int) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherD, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), | |||
| GatherGpuFwdKernel, int64_t, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), | |||
| GatherGpuFwdKernel, uchar, int) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), | |||
| GatherGpuFwdKernel, uchar, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherD, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), | |||
| GatherGpuFwdKernel, bool, int) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherD, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), | |||
| GatherGpuFwdKernel, bool, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32), | |||
| GatherGpuFwdKernel, uint32_t, int) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32), | |||
| GatherGpuFwdKernel, uint32_t, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64), | |||
| GatherGpuFwdKernel, uint64_t, int) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64), | |||
| GatherGpuFwdKernel, uint64_t, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16), | |||
| GatherGpuFwdKernel, uint16_t, int) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16), | |||
| GatherGpuFwdKernel, uint16_t, int64_t) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -41,7 +41,7 @@ class GatherGpuFwdKernel : public GpuKernel { | |||
| S *index_addr = GetDeviceAddress<S>(inputs, 1); | |||
| T *output_addr = GetDeviceAddress<T>(outputs, 0); | |||
| Gather(input_addr, index_addr, output_addr, dims_[0], dims_[1], dims_[2], | |||
| Gather(input_addr, index_addr, output_addr, dims_[0], dims_[1], dims_[2], dims_[3], | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| @@ -83,15 +83,17 @@ class GatherGpuFwdKernel : public GpuKernel { | |||
| for (size_t i = 0; i < IntToSize(axis_); i++) { | |||
| dim_before_axis *= output_shapes_[i]; | |||
| } | |||
| size_t dim_of_index = output_shapes_[IntToSize(axis_)]; | |||
| size_t dim_after_index = 1; | |||
| size_t dim_at_axis_input = input_shapes_[IntToSize(axis_)]; | |||
| size_t dim_at_axis_output = output_shapes_[IntToSize(axis_)]; | |||
| size_t dim_after_axis = 1; | |||
| for (size_t i = IntToSize(axis_) + 1; i < output_shapes_.size(); i++) { | |||
| dim_after_index *= output_shapes_[i]; | |||
| dim_after_axis *= output_shapes_[i]; | |||
| } | |||
| dims_[0] = dim_before_axis; | |||
| dims_[1] = dim_of_index; | |||
| dims_[2] = dim_after_index; | |||
| dims_[1] = dim_at_axis_input; | |||
| dims_[2] = dim_at_axis_output; | |||
| dims_[3] = dim_after_axis; | |||
| return; | |||
| } | |||
| size_t GetSize(const std::vector<size_t> &shape, const bool flag = true) const { | |||
| @@ -109,7 +111,7 @@ class GatherGpuFwdKernel : public GpuKernel { | |||
| std::vector<size_t> index_shapes_; | |||
| std::vector<size_t> output_shapes_; | |||
| size_t dims_[3] = {}; | |||
| size_t dims_[4] = {}; | |||
| int axis_; | |||
| cudnnHandle_t handle_; | |||
| @@ -18,6 +18,14 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherDGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||
| GatherGradGpuKernel, int, double) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherDGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||
| GatherGradGpuKernel, int64_t, double) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherDGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| @@ -41,7 +41,7 @@ class GatherGradGpuKernel : public GpuKernel { | |||
| S *grad_addr = GetDeviceAddress<S>(inputs, 1); | |||
| S *output_addr = GetDeviceAddress<S>(outputs, 0); | |||
| GatherGrad(index_addr, grad_addr, output_addr, dims_[0], dims_[1], dims_[2], | |||
| GatherGrad(index_addr, grad_addr, output_addr, dims_[0], dims_[1], dims_[2], dims_[3], | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| @@ -84,15 +84,17 @@ class GatherGradGpuKernel : public GpuKernel { | |||
| for (size_t i = 0; i < IntToSize(axis_); i++) { | |||
| dim_before_axis *= output_shapes_[i]; | |||
| } | |||
| size_t dim_of_indices = output_shapes_[IntToSize(axis_)]; | |||
| size_t dim_after_indices = 1; | |||
| size_t dim_at_axis_index = index_shapes_[IntToSize(axis_)]; | |||
| size_t dim_at_axis_output = output_shapes_[IntToSize(axis_)]; | |||
| size_t dim_after_axis = 1; | |||
| for (size_t i = IntToSize(axis_) + 1; i < output_shapes_.size(); i++) { | |||
| dim_after_indices *= output_shapes_[i]; | |||
| dim_after_axis *= output_shapes_[i]; | |||
| } | |||
| dims_[0] = dim_before_axis; | |||
| dims_[1] = dim_of_indices; | |||
| dims_[2] = dim_after_indices; | |||
| dims_[1] = dim_at_axis_index; | |||
| dims_[2] = dim_at_axis_output; | |||
| dims_[3] = dim_after_axis; | |||
| return; | |||
| } | |||
| size_t GetSize(const std::vector<size_t> &shape, const bool flag = true) const { | |||
| @@ -110,7 +112,7 @@ class GatherGradGpuKernel : public GpuKernel { | |||
| std::vector<size_t> grad_shapes_; | |||
| std::vector<size_t> output_shapes_; | |||
| size_t dims_[3] = {}; | |||
| size_t dims_[4] = {}; | |||
| int axis_; | |||
| cudnnHandle_t handle_; | |||
| @@ -18,35 +18,125 @@ | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/gather.cuh" | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| template <typename T, typename S> | |||
| __global__ void GatherKernel(const T *input, const S *index, T *output, const size_t output_dim0, | |||
| const size_t output_dim1, const size_t output_dim2) { | |||
| size_t num = output_dim0 * output_dim1 * output_dim2; | |||
| __global__ void GatherKernel(const T *input, const S *index, T *output, const size_t dim_before_axis, | |||
| const size_t dim_at_axis_input, const size_t dim_at_axis_output, | |||
| const size_t dim_after_axis) { | |||
| size_t num = dim_before_axis * dim_at_axis_output * dim_after_axis; | |||
| size_t i, k; | |||
| for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id < num; | |||
| id += blockDim.x * gridDim.x) { | |||
| i = id / (output_dim1 * output_dim2) % output_dim0; | |||
| k = id % output_dim2; | |||
| i = id / (dim_at_axis_output * dim_after_axis); | |||
| k = id % dim_after_axis; | |||
| size_t j_read = static_cast<size_t>(index[id]); | |||
| size_t read_id = i * output_dim1 * output_dim2 + j_read * output_dim2 + k; | |||
| size_t read_id = i * dim_at_axis_input * dim_after_axis + j_read * dim_after_axis + k; | |||
| output[id] = input[read_id]; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T, typename S> | |||
| void Gather(const T *input, const S *index, T *output, const size_t output_dim0, const size_t output_dim1, | |||
| const size_t output_dim2, cudaStream_t stream) { | |||
| size_t size = output_dim0 * output_dim1 * output_dim2; | |||
| GatherKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, index, output, output_dim0, output_dim1, | |||
| output_dim2); | |||
| void Gather(const T *input, const S *index, T *output, const size_t dim_before_axis, | |||
| const size_t dim_at_axis_input, const size_t dim_at_axis_output, | |||
| const size_t dim_after_axis, cudaStream_t stream) { | |||
| size_t size = dim_before_axis * dim_at_axis_output * dim_after_axis; | |||
| GatherKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, index, output, dim_before_axis, dim_at_axis_input, | |||
| dim_at_axis_output, dim_after_axis); | |||
| return; | |||
| } | |||
| template void Gather<float, int>(const float *input, const int *index, float *output, const size_t output_dim0, | |||
| const size_t output_dim1, const size_t output_dim2, cudaStream_t stream); | |||
| template void Gather<float, int64_t>(const float *input, const int64_t *index, float *output, const size_t output_dim0, | |||
| const size_t output_dim1, const size_t output_dim2, cudaStream_t stream); | |||
| template void Gather<half, int>(const half *input, const int *index, half *output, const size_t output_dim0, | |||
| const size_t output_dim1, const size_t output_dim2, cudaStream_t stream); | |||
| template void Gather<half, int64_t>(const half *input, const int64_t *index, half *output, const size_t output_dim0, | |||
| const size_t output_dim1, const size_t output_dim2, cudaStream_t stream); | |||
| template void Gather<double, int>(const double *input, const int *index, double *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_input, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void Gather<double, int64_t>(const double *input, const int64_t *index, double *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_input, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void Gather<float, int>(const float *input, const int *index, float *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_input, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void Gather<float, int64_t>(const float *input, const int64_t *index, float *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_input, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void Gather<half, int>(const half *input, const int *index, half *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_input, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void Gather<half, int64_t>(const half *input, const int64_t *index, half *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_input, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void Gather<int64_t, int>(const int64_t *input, const int *index, int64_t *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_input, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void Gather<int64_t, int64_t>(const int64_t *input, const int64_t *index, int64_t *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_input, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void Gather<int, int>(const int *input, const int *index, int *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_input, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void Gather<int, int64_t>(const int *input, const int64_t *index, int *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_input, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void Gather<int16_t, int>(const int16_t *input, const int *index, int16_t *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_input, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void Gather<int16_t, int64_t>(const int16_t *input, const int64_t *index, int16_t *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_input, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void Gather<int8_t, int>(const int8_t *input, const int *index, int8_t *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_input, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void Gather<int8_t, int64_t>(const int8_t *input, const int64_t *index, int8_t *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_input, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void Gather<unsigned char, int>(const unsigned char *input, const int *index, unsigned char *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_input, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void Gather<unsigned char, int64_t>(const unsigned char *input, const int64_t *index, unsigned char *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_input, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void Gather<bool, int>(const bool *input, const int *index, bool *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_input, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void Gather<bool, int64_t>(const bool *input, const int64_t *index, bool *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_input, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void Gather<uint16_t, int>(const uint16_t *input, const int *index, uint16_t *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_input, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void Gather<uint16_t, int64_t>(const uint16_t *input, const int64_t *index, uint16_t *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_input, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void Gather<uint32_t, int>(const uint32_t *input, const int *index, uint32_t *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_input, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void Gather<uint32_t, int64_t>(const uint32_t *input, const int64_t *index, uint32_t *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_input, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void Gather<uint64_t, int>(const uint64_t *input, const int *index, uint64_t *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_input, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void Gather<uint64_t, int64_t>(const uint64_t *input, const int64_t *index, uint64_t *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_input, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| @@ -17,7 +17,8 @@ | |||
| #ifndef MINDSPORE_GATHER_GPU_CU_H | |||
| #define MINDSPORE_GATHER_GPU_CU_H | |||
| template <typename T, typename S> | |||
| void Gather(const T *input, const S *index, T *output, const size_t output_dim0, const size_t output_dim1, | |||
| const size_t output_dim2, cudaStream_t stream); | |||
| void Gather(const T *input, const S *index, T *output, const size_t dim_before_axis, | |||
| const size_t dim_at_axis_input, const size_t dim_at_axis_output, | |||
| const size_t dim_after_axis, cudaStream_t stream); | |||
| #endif | |||
| @@ -16,44 +16,71 @@ | |||
| #include <iostream> | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/gather_grad.cuh" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/util.cuh" | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| template <typename T, typename S> | |||
| __global__ void GatherGradKernel(const T *index, const S *grad, S *output, const size_t output_dim0, | |||
| const size_t output_dim1, const size_t output_dim2) { | |||
| size_t num = output_dim0 * output_dim1 * output_dim2; | |||
| __global__ void GatherGradKernel(const size_t num, const T *index, const S *grad, S *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_index, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis) { | |||
| size_t i, k; | |||
| for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id < num; | |||
| id += blockDim.x * gridDim.x) { | |||
| i = id / (output_dim1 * output_dim2) % output_dim0; | |||
| k = id % output_dim2; | |||
| i = id / (dim_at_axis_index * dim_after_axis); | |||
| k = id % dim_after_axis; | |||
| size_t j_read = static_cast<size_t>(index[id]); | |||
| size_t read_id = i * output_dim1 * output_dim2 + j_read * output_dim2 + k; | |||
| output[read_id] = grad[id]; | |||
| size_t read_id = i * dim_at_axis_output * dim_after_axis + j_read * dim_after_axis + k; | |||
| MsAtomicAdd(output + read_id, grad[id]); | |||
| } | |||
| return; | |||
| } | |||
| template <typename S> | |||
| __global__ void InitOutput(const size_t size, S *output) { | |||
| S zero = 0; | |||
| for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id < size; id += blockDim.x * gridDim.x) { | |||
| output[id] = zero; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T, typename S> | |||
| void GatherGrad(const T *index, const S *grad, S *output, const size_t output_dim0, | |||
| const size_t output_dim1, const size_t output_dim2, cudaStream_t stream) { | |||
| size_t size = output_dim0 * output_dim1 * output_dim2; | |||
| GatherGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(index, grad, output, | |||
| output_dim0, output_dim1, output_dim2); | |||
| void GatherGrad(const T *index, const S *grad, S *output, const size_t dim_before_axis, | |||
| const size_t dim_at_axis_index, const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream) { | |||
| size_t size = dim_before_axis * dim_at_axis_output * dim_after_axis; | |||
| InitOutput<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(size, output); | |||
| size = dim_before_axis * dim_at_axis_index * dim_after_axis; | |||
| GatherGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(size, index, grad, output, | |||
| dim_before_axis, dim_at_axis_index, | |||
| dim_at_axis_output, dim_after_axis); | |||
| return; | |||
| } | |||
| template void GatherGrad<int, double>(const int *index, const double *grad, double *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_index, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void GatherGrad<int64_t, double>(const int64_t *index, const double *grad, double *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_index, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void GatherGrad<int, float>(const int *index, const float *grad, float *output, | |||
| const size_t output_dim0, const size_t output_dim1, | |||
| const size_t output_dim2, cudaStream_t stream); | |||
| template void GatherGrad<int, half>(const int *index, const half *grad, half *output, | |||
| const size_t output_dim0, const size_t output_dim1, | |||
| const size_t output_dim2, cudaStream_t stream); | |||
| const size_t dim_before_axis, const size_t dim_at_axis_index, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void GatherGrad<int64_t, float>(const int64_t *index, const float *grad, float *output, | |||
| const size_t output_dim0, const size_t output_dim1, | |||
| const size_t output_dim2, cudaStream_t stream); | |||
| const size_t dim_before_axis, const size_t dim_at_axis_index, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void GatherGrad<int, half>(const int *index, const half *grad, half *output, | |||
| const size_t dim_before_axis, const size_t dim_at_axis_index, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| template void GatherGrad<int64_t, half>(const int64_t *index, const half *grad, half *output, | |||
| const size_t output_dim0, const size_t output_dim1, | |||
| const size_t output_dim2, cudaStream_t stream); | |||
| const size_t dim_before_axis, const size_t dim_at_axis_index, | |||
| const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| @@ -17,7 +17,8 @@ | |||
| #ifndef MINDSPORE_GATHER_GRAD_GPU_CU_H | |||
| #define MINDSPORE_GATHER_GRAD_GPU_CU_H | |||
| template <typename T, typename S> | |||
| void GatherGrad(const T *index, const S *grad, S *output, const size_t output_dim0, | |||
| const size_t output_dim1, const size_t output_dim2, cudaStream_t stream); | |||
| void GatherGrad(const T *index, const S *grad, S *output, const size_t dim_before_axis, | |||
| const size_t dim_at_axis_index, const size_t dim_at_axis_output, const size_t dim_after_axis, | |||
| cudaStream_t stream); | |||
| #endif | |||
| @@ -19,6 +19,18 @@ | |||
| #include <cuda_fp16.h> | |||
| __device__ static inline double MsAtomicAdd(double *address, const double val) { | |||
| unsigned long long int* address_as_ull = (unsigned long long int*)address; // NOLINT | |||
| unsigned long long int old = *address_as_ull; // NOLINT | |||
| unsigned long long int assumed; // NOLINT | |||
| do { | |||
| assumed = old; | |||
| old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed))); | |||
| } | |||
| while (assumed != old); // NOLINT | |||
| return __longlong_as_double(old); | |||
| } | |||
| __device__ static inline float MsAtomicAdd(float *address, const float val) { return atomicAdd(address, val); } | |||
| __device__ static inline int MsAtomicAdd(int *address, int val) { return atomicAdd(address, val); } | |||
| @@ -381,16 +381,11 @@ def get_bprop_gather_v2(self): | |||
| @bprop_getters.register(P.GatherD) | |||
| def get_bprop_gather_d(self): | |||
| """Generate bprop for GatherD""" | |||
| gather_d = P.GatherD() | |||
| def bprop(x, dim, index, out, dout): | |||
| return P.GatherDGrad(dim)(index, dout) | |||
| def bprop_ascend(x, dim, index, out, dout): | |||
| return (gather_d(dout, dim, index), zeros_like(dim), zeros_like(index)) | |||
| if context.get_context('device_target') == 'Ascend': | |||
| return bprop_ascend | |||
| x_shp = shape_op(x) | |||
| dx = G.GatherDGrad(dim, x_shp)(index, dout) | |||
| return dx, zeros_like(dim), zeros_like(index) | |||
| return bprop | |||
| @@ -1385,14 +1385,15 @@ class GatherDGrad(PrimitiveWithInfer): | |||
| """Performs grad of GatherD operation.""" | |||
| @prim_attr_register | |||
| def __init__(self, dim=0): | |||
| def __init__(self, dim=0, shape=None): | |||
| """Initialize GatherDGrad""" | |||
| validator.check_is_int(dim, int) | |||
| self.add_prim_attr("dim", dim) | |||
| self.out_shape = shape | |||
| self.init_prim_io_names(inputs=['index', 'grad'], outputs=['output']) | |||
| def infer_shape(self, index_shape, grad_shape): | |||
| return grad_shape | |||
| return self.out_shape | |||
| def infer_dtype(self, index_dtype, grad_dtype): | |||
| return grad_dtype | |||
| @@ -19,32 +19,37 @@ import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| import mindspore as ms | |||
| import mindspore.ops.operations._grad_ops as P | |||
| import mindspore.ops.operations as P | |||
| import mindspore.ops.operations._grad_ops as G | |||
| from mindspore.ops.composite import GradOperation | |||
| from mindspore import Tensor | |||
| class GatherDGradNet(nn.Cell): | |||
| class GatherDNet(nn.Cell): | |||
| def __init__(self, dim=0): | |||
| super(GatherDGradNet, self).__init__() | |||
| self.gather_d_grad = P.GatherDGrad(dim) | |||
| super(GatherDNet, self).__init__() | |||
| self.gather_d = P.GatherD() | |||
| self.dim = dim | |||
| def construct(self, index, grad): | |||
| return self.gather_d_grad(index, grad) | |||
| def construct(self, x, index): | |||
| return self.gather_d(x, self.dim, index) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_gather_grad_graph_int32_fp32(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| x = Tensor(np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]), ms.float32) | |||
| dim = 0 | |||
| index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int32) | |||
| grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710], | |||
| [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float32) | |||
| expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710], | |||
| [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float32) | |||
| net = GatherDGradNet(dim) | |||
| output = net(index, grad) | |||
| net = GatherDNet(dim) | |||
| grad_net = GradOperation(get_all=True, sens_param=True)(net) | |||
| output = grad_net(x, index, grad) | |||
| error = 1e-4 | |||
| diff = output.asnumpy() - expect | |||
| diff = output[0].asnumpy() - expect | |||
| assert np.all(diff < error) | |||
| @pytest.mark.level0 | |||
| @@ -52,16 +57,18 @@ def test_gather_grad_graph_int32_fp32(): | |||
| @pytest.mark.env_onecard | |||
| def test_gather_grad_graph_int64_fp32(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| x = Tensor(np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]), ms.float32) | |||
| dim = 0 | |||
| index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int64) | |||
| grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710], | |||
| [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float32) | |||
| expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710], | |||
| [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float32) | |||
| net = GatherDGradNet(dim) | |||
| output = net(index, grad) | |||
| net = GatherDNet(dim) | |||
| grad_net = GradOperation(get_all=True, sens_param=True)(net) | |||
| output = grad_net(x, index, grad) | |||
| error = 1e-4 | |||
| diff = output.asnumpy() - expect | |||
| diff = output[0].asnumpy() - expect | |||
| assert np.all(diff < error) | |||
| @pytest.mark.level0 | |||
| @@ -69,16 +76,18 @@ def test_gather_grad_graph_int64_fp32(): | |||
| @pytest.mark.env_onecard | |||
| def test_gather_grad_graph_int32_fp16(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| x = Tensor(np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]), ms.float16) | |||
| dim = 0 | |||
| index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int32) | |||
| grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710], | |||
| [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float16) | |||
| expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710], | |||
| [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float16) | |||
| net = GatherDGradNet(dim) | |||
| output = net(index, grad) | |||
| net = GatherDNet(dim) | |||
| grad_net = GradOperation(get_all=True, sens_param=True)(net) | |||
| output = grad_net(x, index, grad) | |||
| error = 1e-4 | |||
| diff = output.asnumpy() - expect | |||
| diff = output[0].asnumpy() - expect | |||
| assert np.all(diff < error) | |||
| @pytest.mark.level0 | |||
| @@ -86,16 +95,18 @@ def test_gather_grad_graph_int32_fp16(): | |||
| @pytest.mark.env_onecard | |||
| def test_gather_grad_graph_int64_fp16(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| x = Tensor(np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]), ms.float16) | |||
| dim = 0 | |||
| index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int64) | |||
| grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710], | |||
| [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float16) | |||
| expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710], | |||
| [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float16) | |||
| net = GatherDGradNet(dim) | |||
| output = net(index, grad) | |||
| net = GatherDNet(dim) | |||
| grad_net = GradOperation(get_all=True, sens_param=True)(net) | |||
| output = grad_net(x, index, grad) | |||
| error = 1e-4 | |||
| diff = output.asnumpy() - expect | |||
| diff = output[0].asnumpy() - expect | |||
| assert np.all(diff < error) | |||
| @pytest.mark.level0 | |||
| @@ -103,13 +114,14 @@ def test_gather_grad_graph_int64_fp16(): | |||
| @pytest.mark.env_onecard | |||
| def test_gather_grad_pynative_int32_fp32(): | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| x_shape = (2, 5) | |||
| dim = 0 | |||
| index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int32) | |||
| grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710], | |||
| [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float32) | |||
| expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710], | |||
| [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float32) | |||
| output = P.GatherDGrad(dim)(index, grad) | |||
| output = G.GatherDGrad(dim, x_shape)(index, grad) | |||
| error = 1e-4 | |||
| diff = output.asnumpy() - expect | |||
| assert np.all(diff < error) | |||
| @@ -119,13 +131,14 @@ def test_gather_grad_pynative_int32_fp32(): | |||
| @pytest.mark.env_onecard | |||
| def test_gather_grad_pynative_int64_fp32(): | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| x_shape = (2, 5) | |||
| dim = 0 | |||
| index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int64) | |||
| grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710], | |||
| [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float32) | |||
| expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710], | |||
| [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float32) | |||
| output = P.GatherDGrad(dim)(index, grad) | |||
| output = G.GatherDGrad(dim, x_shape)(index, grad) | |||
| error = 1e-4 | |||
| diff = output.asnumpy() - expect | |||
| assert np.all(diff < error) | |||
| @@ -135,13 +148,14 @@ def test_gather_grad_pynative_int64_fp32(): | |||
| @pytest.mark.env_onecard | |||
| def test_gather_grad_pynative_int32_fp16(): | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| x_shape = (2, 5) | |||
| dim = 0 | |||
| index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int32) | |||
| grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710], | |||
| [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float16) | |||
| expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710], | |||
| [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float16) | |||
| output = P.GatherDGrad(dim)(index, grad) | |||
| output = G.GatherDGrad(dim, x_shape)(index, grad) | |||
| error = 1e-4 | |||
| diff = output.asnumpy() - expect | |||
| assert np.all(diff < error) | |||
| @@ -151,13 +165,14 @@ def test_gather_grad_pynative_int32_fp16(): | |||
| @pytest.mark.env_onecard | |||
| def test_gather_grad_pynative_int64_fp16(): | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| x_shape = (2, 5) | |||
| dim = 0 | |||
| index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int64) | |||
| grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710], | |||
| [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float16) | |||
| expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710], | |||
| [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float16) | |||
| output = P.GatherDGrad(dim)(index, grad) | |||
| output = G.GatherDGrad(dim, x_shape)(index, grad) | |||
| error = 1e-4 | |||
| diff = output.asnumpy() - expect | |||
| assert np.all(diff < error) | |||