diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc
index c64696de6e..7db992c0fe 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc
@@ -34,5 +34,20 @@ MS_REG_GPU_KERNEL_TWO(
   SparseGatherV2,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
   GatherV2GpuFwdKernel, half, int)
+// Three-input registrations (extra Int64 input); GatherV2GpuFwdKernel::Init treats three inputs as dynamic shape.
+MS_REG_GPU_KERNEL_TWO(GatherV2,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      GatherV2GpuFwdKernel, float, int)
+MS_REG_GPU_KERNEL_TWO(GatherV2,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      GatherV2GpuFwdKernel, half, int)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.h
index 81136ff4fe..6d6b079fb6 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.h
@@ -18,6 +18,7 @@
 #define MINDSPORE_GATHER_GPU_KERNEL_H
 
 #include <vector>
+#include <algorithm>
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
 #include "backend/kernel_compiler/gpu/cuda_impl/gatherv2.cuh"
@@ -41,31 +42,82 @@ class GatherV2GpuFwdKernel : public GpuKernel {
     S *indices_addr = GetDeviceAddress<S>(inputs, 1);
     T *output_addr = GetDeviceAddress<T>(outputs, 0);
 
-    auto input_dim1 = input_shapes_[IntToSize(axis_)];
-    GatherV2(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], input_dim1,
-             reinterpret_cast<cudaStream_t>(stream_ptr));
+    if (is_dynamic_shape_) {
+      // In dynamic shape mode dims_ is not precomputed on the host, so input_shapes_, indices_shapes_ and axis_
+      // are staged in workspace buffers and the kernel derives the three collapsed output dims itself.
+      // NOTE(review): Init() does not read the "axis" attribute on this path, so axis_ holds whatever was
+      // last stored in it; presumably the value should come from the third input - confirm.
+      size_t *input_shape_device_address = GetDeviceAddress<size_t>(workspace, 0);
+      size_t *indices_shape_device_address = GetDeviceAddress<size_t>(workspace, 1);
+      int64_t *axis_device_address = GetDeviceAddress<int64_t>(workspace, 2);
+
+      CHECK_CUDA_RET_WITH_EXCEPT(
+        cudaMemcpyAsync(input_shape_device_address, input_shapes_.data(), workspace_size_list_[0],
+                        cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+        "cudaMemcpyAsync input_shape failed");
+      CHECK_CUDA_RET_WITH_EXCEPT(
+        cudaMemcpyAsync(indices_shape_device_address, indices_shapes_.data(), workspace_size_list_[1],
+                        cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+        "cudaMemcpyAsync indices_shape failed");
+      CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(axis_device_address, &axis_, workspace_size_list_[2],
+                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "cudaMemcpyAsync axis_ failed");
+
+      // output shape will be here for us to copy back to host
+      size_t *output_shape_device_address = GetDeviceAddress<size_t>(workspace, 3);
+      CalGatherV2DynamicShape(input_addr, indices_addr, output_addr, input_shape_device_address, input_shapes_.size(),
+                              indices_shape_device_address, indices_shapes_.size(), axis_device_address,
+                              output_shape_device_address, max_output_size_,
+                              reinterpret_cast<cudaStream_t>(stream_ptr));
+
+      size_t output_rank = input_shapes_.size() - 1 + indices_shapes_.size();
+      real_output_shape_.resize(output_rank);
+      // output_shape_device_address holds size_t extents, so copy size_t-sized elements back.
+      // NOTE(review): asynchronous copy - do not read real_output_shape_ before the stream is synchronized.
+      CHECK_CUDA_RET_WITH_ERROR(
+        cudaMemcpyAsync(&real_output_shape_[0], output_shape_device_address, output_rank * sizeof(size_t),
+                        cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
+        "Failed to copy gpu memory.");
+
+    } else {
+      auto input_dim1 = input_shapes_[IntToSize(axis_)];
+      CalGatherV2StaticShape(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], input_dim1,
+                             reinterpret_cast<cudaStream_t>(stream_ptr));
+    }
     return true;
   }
 
   bool Init(const CNodePtr &kernel_node) override {
     InitResource();
     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-    if (input_num != 2) {
+    if (input_num == 3) {
+      is_dynamic_shape_ = true;
+    } else if (input_num != 2) {
       MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherGpuV2FwdKernel needs 2.";
     }
+
     input_shapes_ = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
     indices_shapes_ = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
     output_shapes_ = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
 
-    axis_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "axis"));
-    if (axis_ < 0) {
-      axis_ = axis_ + SizeToInt(input_shapes_.size());
+    if (is_dynamic_shape_) {
+      c_node_ptr_ = kernel_node;
+      size_t input_shape_min = *std::min_element(input_shapes_.begin(), input_shapes_.end());
+      max_output_size_ = (GetSize(input_shapes_) / input_shape_min) * GetSize(indices_shapes_);
+    } else {
+      axis_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "axis"));
+      if (axis_ < 0) {
+        axis_ = axis_ + SizeToInt(input_shapes_.size());
+      }
+
+      Reshape();
     }
-    Reshape();
     InitSizeLists();
     return true;
   }
   void ResetResource() noexcept override {
+    is_dynamic_shape_ = false;
+    max_output_size_ = -1;
     input_shapes_.clear();
     indices_shapes_.clear();
     output_shapes_.clear();
@@ -84,8 +136,31 @@ class GatherV2GpuFwdKernel : public GpuKernel {
     size = GetSize(indices_shapes_);
     input_size_list_.push_back(size);
 
-    size = GetSize(output_shapes_);
-    output_size_list_.push_back(size);
+    if (is_dynamic_shape_) {
+      // third input is registered as Int64 (see the GatherV2 registrations), one element
+      input_size_list_.push_back(sizeof(int64_t));
+
+      // allocate maximum size needed
+      output_size_list_.push_back(max_output_size_);
+
+      // allocate workspace memory for input shape, indices shape, axis, and output shape respectively;
+      // the shape buffers are filled from the host shape vectors, so they hold one size_t per dimension
+      size = input_shapes_.size() * sizeof(size_t);
+      workspace_size_list_.push_back(size);
+
+      size = indices_shapes_.size() * sizeof(size_t);
+      workspace_size_list_.push_back(size);
+
+      size = sizeof(int64_t);
+      workspace_size_list_.push_back(size);
+
+      // output rank is input rank - 1 + indices rank, matching output_rank in Launch()
+      size = (input_shapes_.size() - 1 + indices_shapes_.size()) * sizeof(size_t);
+      workspace_size_list_.push_back(size);
+    } else {
+      size = GetSize(output_shapes_);
+      output_size_list_.push_back(size);
+    }
   }
 
  private:
@@ -126,7 +201,11 @@ class GatherV2GpuFwdKernel : public GpuKernel {
   std::vector<size_t> output_shapes_;
 
   size_t dims_[3] = {};
-  int axis_;
+  int64_t axis_;
+  bool is_dynamic_shape_;                  // set in Init() when the node has three inputs
+  int max_output_size_;                    // value pushed to output_size_list_ in dynamic shape mode
+  std::vector<size_t> real_output_shape_;  // output extents copied back from the device in Launch()
+  CNodePtr c_node_ptr_;
 
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cu
index 50fb1ab851..fafc51e0fd 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cu
@@ -18,8 +18,8 @@
 #include "backend/kernel_compiler/gpu/cuda_impl/gatherv2.cuh"
 #include "runtime/device/gpu/cuda_common.h"
 template <typename T, typename S>
-__global__ void GatherV2Kernel(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1,
-                               size_t output_dim2, size_t input_dim1) {
+__device__ void GatherV2Kernel(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1,
+                               size_t output_dim2, size_t input_dim1) {
   int num = output_dim0 * output_dim1 * output_dim2;
   int i, j, k;
   for (int write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num;
@@ -38,17 +38,94 @@ __global__ void GatherV2Kernel(T *input, S *indices, T *output, size_t output_di
 
   return;
 }
+
+// __global__ entry point for the static shape path; forwards its arguments to the __device__ GatherV2Kernel.
+template <typename T, typename S>
+__global__ void GatherV2StaticShapeWrapper(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1,
+                                           size_t output_dim2, size_t input_dim1) {
+  GatherV2Kernel(input, indices, output, output_dim0, output_dim1, output_dim2, input_dim1);
+}
+
+// __global__ entry point for the dynamic shape path. Every thread derives the collapsed output dims from the
+// shape buffers (dim0 = input dims before axis, dim1 = all indices dims, dim2 = input dims after axis);
+// thread 0 additionally records the full output shape in output_shape_wksp. GatherV2Kernel then does the gather.
+template <typename T, typename S>
+__global__ void GatherV2DynamicShape(T *input, S *indices, T *output, size_t *input_shape_wksp, size_t input_rank,
+                                     size_t *indices_shape_wksp, size_t indices_rank, int64_t *axis_wksp,
+                                     size_t *output_shape_wksp, const int max_output_size) {
+  int gt_id = blockIdx.x * blockDim.x + threadIdx.x;
+  size_t axis = (size_t)(*axis_wksp);
+
+  int output_shape_index = 0;
+  size_t output_dim0 = 1;
+  for (size_t i = 0; i < axis; i++) {
+    output_dim0 *= input_shape_wksp[i];
+
+    if (gt_id == 0) {
+      output_shape_wksp[output_shape_index] = input_shape_wksp[i];
+      output_shape_index++;
+    }
+  }
+
+  size_t output_dim1 = 1;
+  for (size_t i = 0; i < indices_rank; i++) {
+    output_dim1 *= indices_shape_wksp[i];
+
+    if (gt_id == 0) {
+      output_shape_wksp[output_shape_index] = indices_shape_wksp[i];
+      output_shape_index++;
+    }
+  }
+
+  size_t output_dim2 = 1;
+  for (size_t i = axis + 1; i < input_rank; i++) {
+    output_dim2 *= input_shape_wksp[i];  // trailing extents come from the input shape, not the indices shape
+
+    if (gt_id == 0) {
+      output_shape_wksp[output_shape_index] = input_shape_wksp[i];
+      output_shape_index++;
+    }
+  }
+
+  size_t input_dim1 = (size_t)(input_shape_wksp[axis]);
+
+  GatherV2Kernel(input, indices, output, output_dim0, output_dim1, output_dim2, input_dim1);
+}
+
+// entry points from gpu kernel's .h file
 template <typename T, typename S>
-void GatherV2(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2,
-              size_t input_dim1, cudaStream_t stream) {
+void CalGatherV2StaticShape(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2,
+                            size_t input_dim1, cudaStream_t stream) {
   int size = output_dim0 * output_dim1 * output_dim2;
-  GatherV2Kernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, indices, output, output_dim0, output_dim1,
-                                                               output_dim2, input_dim1);
+  GatherV2StaticShapeWrapper<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, indices, output, output_dim0,
+                                                                           output_dim1, output_dim2, input_dim1);
   return;
 }
 
-template void GatherV2<float, int>(float *input, int *indices, float *output, size_t output_dim0, size_t output_dim1,
-                                   size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template <typename T, typename S>
+void CalGatherV2DynamicShape(T *input, S *indices, T *output, size_t *input_shape_wksp, size_t input_rank,
+                             size_t *indices_shape_wksp, size_t indices_rank, int64_t *axis_wksp,
+                             size_t *output_shape_wksp, const int max_output_size, cudaStream_t stream) {
+  GatherV2DynamicShape<<<GET_BLOCKS(max_output_size), GET_THREADS, 0, stream>>>(
+    input, indices, output, input_shape_wksp, input_rank, indices_shape_wksp, indices_rank, axis_wksp,
+    output_shape_wksp, max_output_size);
+}
+
+// template instantiations
+template void CalGatherV2StaticShape<float, int>(float *input, int *indices, float *output, size_t output_dim0,
+                                                 size_t output_dim1, size_t output_dim2, size_t input_dim1,
+                                                 cudaStream_t stream);
+
+template void CalGatherV2StaticShape<half, int>(half *input, int *indices, half *output, size_t output_dim0,
+                                                size_t output_dim1, size_t output_dim2, size_t input_dim1,
+                                                cudaStream_t stream);
+
+template void CalGatherV2DynamicShape<float, int>(float *input, int *indices, float *output, size_t *input_shape_wksp,
+                                                  size_t input_rank, size_t *indices_shape_wksp, size_t indices_rank,
+                                                  int64_t *axis_wksp, size_t *output_shape_wksp,
+                                                  const int max_output_size, cudaStream_t stream);
 
-template void GatherV2<half, int>(half *input, int *indices, half *output, size_t output_dim0, size_t output_dim1,
-                                  size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template void CalGatherV2DynamicShape<half, int>(half *input, int *indices, half *output, size_t *input_shape_wksp,
+                                                 size_t input_rank, size_t *indices_shape_wksp, size_t indices_rank,
+                                                 int64_t *axis_wksp, size_t *output_shape_wksp,
+                                                 const int max_output_size, cudaStream_t stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cuh
index b96fee9dc7..b45bb35159 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cuh
@@ -17,7 +17,11 @@
 #ifndef MINDSPORE_GATHER_GPU_CU_H
 #define MINDSPORE_GATHER_GPU_CU_H
 template <typename T, typename S>
-void GatherV2(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1,
-              size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+void CalGatherV2StaticShape(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2,
+                            size_t input_dim1, cudaStream_t stream);
+template <typename T, typename S>
+void CalGatherV2DynamicShape(T *input, S *indices, T *output, size_t *input_shape_wksp, size_t input_rank,
+                             size_t *indices_shape_wksp, size_t indices_rank, int64_t *axis_wksp,
+                             size_t *output_shape_wksp, const int max_output_size, cudaStream_t stream);
 
 #endif