From: @danishnxt
Tag: v1.1.0
@@ -26,14 +26,6 @@ MS_REG_GPU_KERNEL_TWO(
   GatherV2,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
   GatherV2GpuFwdKernel, half, int)
-MS_REG_GPU_KERNEL_TWO(
-  SparseGatherV2,
-  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
-  GatherV2GpuFwdKernel, float, int)
-MS_REG_GPU_KERNEL_TWO(
-  SparseGatherV2,
-  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
-  GatherV2GpuFwdKernel, half, int)
 MS_REG_GPU_KERNEL_TWO(GatherV2,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeFloat32)
@@ -48,5 +40,14 @@ MS_REG_GPU_KERNEL_TWO(GatherV2,
                         .AddInputAttr(kNumberTypeInt64)
                         .AddOutputAttr(kNumberTypeFloat16),
                       GatherV2GpuFwdKernel, half, int)
+MS_REG_GPU_KERNEL_TWO(
+  SparseGatherV2,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
+  GatherV2GpuFwdKernel, float, int)
+MS_REG_GPU_KERNEL_TWO(
+  SparseGatherV2,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
+  GatherV2GpuFwdKernel, half, int)
 } // namespace kernel
 } // namespace mindspore
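
Note: SparseGatherV2 is registered against the same GatherV2GpuFwdKernel template, so both ops share one GPU implementation; the registrations are only moved below the int64-axis GatherV2 variants. A minimal sketch of exercising both ops from Python (illustrative values, not part of this commit):

import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
x = Tensor(np.arange(12, dtype=np.float32).reshape(3, 4))
idx = Tensor(np.array([0, 2], dtype=np.int32))
print(P.GatherV2()(x, idx, 0))        # rows 0 and 2 of x
print(P.SparseGatherV2()(x, idx, 0))  # same result via the shared GPU kernel
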
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_GATHER_GPU_KERNEL_H
-#define MINDSPORE_GATHER_GPU_KERNEL_H
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GATHER_V2_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GATHER_V2_GPU_KERNEL_H_
 #include <vector>
 #include <algorithm>
@@ -41,45 +41,17 @@ class GatherV2GpuFwdKernel : public GpuKernel {
     T *input_addr = GetDeviceAddress<T>(inputs, 0);
     S *indices_addr = GetDeviceAddress<S>(inputs, 1);
     T *output_addr = GetDeviceAddress<T>(outputs, 0);
     if (is_dynamic_shape_) {
-      // if we are in dynamic shape mode, we don't know dims_, so we need to store the input_shape_ and indices_shape_,
-      // and axis_ in the workspace to calculate dims_
-      size_t *input_shape_device_address = GetDeviceAddress<size_t>(workspace, 0);
-      size_t *indices_shape_device_address = GetDeviceAddress<size_t>(workspace, 1);
-      int64_t *axis_device_address = GetDeviceAddress<int64_t>(workspace, 2);
-      CHECK_CUDA_RET_WITH_EXCEPT(
-        cudaMemcpyAsync(input_shape_device_address, input_shapes_.data(), workspace_size_list_[0],
-                        cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
-        "cudaMemcpyAsync input_shape failed");
-      CHECK_CUDA_RET_WITH_EXCEPT(
-        cudaMemcpyAsync(indices_shape_device_address, indices_shapes_.data(), workspace_size_list_[1],
-                        cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
-        "cudaMemcpyAsync indices_shape failed");
-      CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(axis_device_address, &axis_, workspace_size_list_[2],
-                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+      int64_t *axis_device_address = GetDeviceAddress<int64_t>(inputs, 2);  // only get this if in dynamic mode
+      CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(&axis_, axis_device_address, sizeof(int64_t), cudaMemcpyDeviceToHost,
+                                                 reinterpret_cast<cudaStream_t>(stream_ptr)),
                                  "cudaMemcpyAsync axis_ failed");
-      // output shape will be here for us to copy back to host
-      size_t *output_shape_device_address = GetDeviceAddress<size_t>(workspace, 3);
-      CalGatherV2DynamicShape(input_addr, indices_addr, output_addr, input_shape_device_address, input_shapes_.size(),
-                              indices_shape_device_address, indices_shapes_.size(), axis_device_address,
-                              output_shape_device_address, max_output_size_,
-                              reinterpret_cast<cudaStream_t>(stream_ptr));
-      size_t output_rank = input_shapes_.size() - 1 + indices_shapes_.size();
-      real_output_shape_.resize(output_rank);
-      CHECK_CUDA_RET_WITH_ERROR(
-        cudaMemcpyAsync(&real_output_shape_[0], output_shape_device_address, output_rank * sizeof(int32_t),
-                        cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
-        "Failed to copy gpu memory.");
-    } else {
-      auto input_dim1 = input_shapes_[IntToSize(axis_)];
-      CalGatherV2StaticShape(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], input_dim1,
-                             reinterpret_cast<cudaStream_t>(stream_ptr));
+      CHECK_CUDA_RET_WITH_EXCEPT(cudaDeviceSynchronize(), "cudaDeviceSyncFailed - GatherV2 - in dynamic mode");
+      Reshape();
     }
+    auto input_dim1 = input_shapes_[IntToSize(axis_)];
+    GatherV2(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], input_dim1,
+             reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
   bool Init(const CNodePtr &kernel_node) override {
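
For reference, a NumPy sketch of the semantics Launch() drives (gather_v2_ref is a hypothetical helper, not MindSpore code): the input is viewed as [dims_[0], input_dim1, dims_[2]] and output position (i, j, k) reads input[i, indices[j], k]; the zero-fill for out-of-range indices matches what test_gather_dynamic_1 below expects.

import numpy as np

def gather_v2_ref(x, indices, axis):
    axis = axis % x.ndim
    out_shape = x.shape[:axis] + indices.shape + x.shape[axis + 1:]
    # collapse input to 3-d: [dims_[0], input_dim1, dims_[2]]
    x3 = x.reshape(int(np.prod(x.shape[:axis], dtype=np.int64)), x.shape[axis],
                   int(np.prod(x.shape[axis + 1:], dtype=np.int64)))
    idx = indices.reshape(-1)
    out = np.zeros((x3.shape[0], idx.size, x3.shape[2]), dtype=x.dtype)
    for j, ind in enumerate(idx):
        if 0 <= ind < x3.shape[1]:        # out-of-range indices leave zeros
            out[:, j, :] = x3[:, ind, :]
    return out.reshape(out_shape)
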
@@ -87,33 +59,24 @@ class GatherV2GpuFwdKernel : public GpuKernel {
     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
     if (input_num == 3) {
       is_dynamic_shape_ = true;
-    } else if (input_num != 2) {
-      MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherGpuV2FwdKernel needs 2.";
+      MS_LOG(INFO) << " GatherGpuV2FwdKernel running in Dynamic Mode.";
+    } else if (input_num == 2) {
+      MS_LOG(INFO) << " GatherGpuV2FwdKernel running in Normal Mode.";
+    } else {
+      MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherGpuV2FwdKernel needs 2 or 3.";
     }
     input_shapes_ = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
     indices_shapes_ = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
     output_shapes_ = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
-    if (is_dynamic_shape_) {
-      c_node_ptr_ = kernel_node;
-      size_t input_shape_min = *std::min_element(input_shapes_.begin(), input_shapes_.end());
-      max_output_size_ = (GetSize(input_shapes_) / input_shape_min) * GetSize(indices_shapes_);
-    } else {
+    if (!is_dynamic_shape_) {
       axis_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "axis"));
-      if (axis_ < 0) {
-        axis_ = axis_ + SizeToInt(input_shapes_.size());
-      }
       Reshape();
     }
     InitSizeLists();
     return true;
   }
   void ResetResource() noexcept override {
     is_dynamic_shape_ = false;
-    max_output_size_ = -1;
     input_shapes_.clear();
     indices_shapes_.clear();
     output_shapes_.clear();
@@ -128,52 +91,32 @@ class GatherV2GpuFwdKernel : public GpuKernel {
   void InitSizeLists() override {
     size_t size = GetSize(input_shapes_);
     input_size_list_.push_back(size);
     size = GetSize(indices_shapes_);
     input_size_list_.push_back(size);
     if (is_dynamic_shape_) {
-      // add by chenweifeng
-      input_size_list_.push_back(sizeof(S));
-      // allocate maximum size needed
-      output_size_list_.push_back(max_output_size_);
-      // allocate workspace memory for input, indices, axis, and output shape respectively
-      size = GetSize(input_shapes_);
-      workspace_size_list_.push_back(size);
-      size = GetSize(indices_shapes_);
-      workspace_size_list_.push_back(size);
-      size = sizeof(int32_t);
-      workspace_size_list_.push_back(size);
-      size = GetSize(input_shapes_);
-      workspace_size_list_.push_back(size);
-    } else {
-      size = GetSize(output_shapes_);
-      output_size_list_.push_back(size);
+      input_size_list_.push_back(sizeof(int64_t));
     }
+    size = GetSize(output_shapes_);
+    output_size_list_.push_back(size);
   }
  private:
   void Reshape() {
+    if (axis_ < 0) {
+      axis_ = axis_ + SizeToInt(input_shapes_.size());
+    }
     size_t dim_before_axis = 1;
     for (size_t i = 0; i < IntToSize(axis_); i++) {
       dim_before_axis *= output_shapes_[i];
     }
     size_t dim_of_indices = 1;
     for (size_t i = 0; i < indices_shapes_.size(); i++) {
       dim_of_indices *= indices_shapes_[i];
     }
     size_t dim_after_indices = 1;
     for (size_t i = IntToSize(axis_) + indices_shapes_.size(); i < output_shapes_.size(); i++) {
       dim_after_indices *= output_shapes_[i];
     }
     dims_[0] = dim_before_axis;
     dims_[1] = dim_of_indices;
     dims_[2] = dim_after_indices;
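
A short sketch of what Reshape() computes once axis_ has been normalized to be non-negative (compute_dims is an illustrative name, not MindSpore code): dims_[0], dims_[1], and dims_[2] are the products of the output dimensions before the axis, the flattened indices shape, and the dimensions after the gathered block.

import numpy as np

def compute_dims(output_shape, indices_shape, axis):
    ind_rank = len(indices_shape)
    dim_before_axis = int(np.prod(output_shape[:axis], dtype=np.int64))
    dim_of_indices = int(np.prod(indices_shape, dtype=np.int64))
    dim_after_indices = int(np.prod(output_shape[axis + ind_rank:], dtype=np.int64))
    return dim_before_axis, dim_of_indices, dim_after_indices
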
@@ -193,14 +136,9 @@ class GatherV2GpuFwdKernel : public GpuKernel {
   std::vector<size_t> input_shapes_;
   std::vector<size_t> indices_shapes_;
   std::vector<size_t> output_shapes_;
   size_t dims_[3] = {};
   int64_t axis_;
   bool is_dynamic_shape_;
-  int max_output_size_;
-  std::vector<size_t> real_output_shape_;
-  CNodePtr c_node_ptr_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
@@ -208,4 +146,4 @@ class GatherV2GpuFwdKernel : public GpuKernel {
 } // namespace kernel
 } // namespace mindspore
-#endif  // MINDSPORE_GATHER_GPU_KERNEL_H
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GATHER_V2_GPU_KERNEL_H_
@@ -18,7 +18,7 @@
 #include "backend/kernel_compiler/gpu/cuda_impl/gatherv2.cuh"
 #include "runtime/device/gpu/cuda_common.h"
 template <typename T, typename S>
-__device__ void GatherV2Kernel(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1,
+__global__ void GatherV2Kernel(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1,
                                size_t output_dim2, size_t input_dim1) {
   int num = output_dim0 * output_dim1 * output_dim2;
   int i, j, k;
@@ -38,90 +38,17 @@ __device__ void GatherV2Kernel(T *input, S *indices, T *output, size_t output_dim0,
   return;
 }
-template <typename T, typename S>
-__global__ void GatherV2StaticShapeWrapper(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1,
-                                           size_t output_dim2, size_t input_dim1) {
-  GatherV2Kernel(input, indices, output, output_dim0, output_dim1, output_dim2, input_dim1);
-}
-template <typename T, typename S>
-__global__ void GatherV2DynamicShape(T *input, S *indices, T *output, size_t *input_shape_wksp, size_t input_rank,
-                                     size_t *indices_shape_wksp, size_t indices_rank, int64_t *axis_wksp,
-                                     size_t *output_shape_wksp, const int max_output_size) {
-  int gt_id = blockIdx.x * blockDim.x + threadIdx.x;
-  size_t axis = (size_t)(*axis_wksp);
-  int output_shape_index = 0;
-  size_t output_dim0 = 1;
-  for (size_t i = 0; i < axis; i++) {
-    output_dim0 *= input_shape_wksp[i];
-    if (gt_id == 0) {
-      output_shape_wksp[output_shape_index] = input_shape_wksp[i];
-      output_shape_index++;
-    }
-  }
-  size_t output_dim1 = 1;
-  for (size_t i = 0; i < indices_rank; i++) {
-    output_dim1 *= indices_shape_wksp[i];
-    if (gt_id == 0) {
-      output_shape_wksp[output_shape_index] = indices_shape_wksp[i];
-      output_shape_index++;
-    }
-  }
-  size_t output_dim2 = 1;
-  for (size_t i = axis + 1; i < input_rank; i++) {
-    output_dim2 *= indices_shape_wksp[i];
-    if (gt_id == 0) {
-      output_shape_wksp[output_shape_index] = input_shape_wksp[i];
-      output_shape_index++;
-    }
-  }
-  size_t input_dim1 = (size_t)(input_shape_wksp[axis]);
-  GatherV2Kernel(input, indices, output, output_dim0, output_dim1, output_dim2, input_dim1);
-}
-// entry points from gpu kernel's .h file
 template <typename T, typename S>
-void CalGatherV2StaticShape(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2,
-                            size_t input_dim1, cudaStream_t stream) {
+void GatherV2(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2,
+              size_t input_dim1, cudaStream_t stream) {
   int size = output_dim0 * output_dim1 * output_dim2;
-  GatherV2StaticShapeWrapper<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, indices, output, output_dim0,
-                                                                           output_dim1, output_dim2, input_dim1);
+  GatherV2Kernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, indices, output, output_dim0, output_dim1,
+                                                               output_dim2, input_dim1);
   return;
 }
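
The kernel body is elided in this diff, but from the signature and the `int i, j, k;` declarations the usual row-major decomposition of the flat thread id is assumed; sketched in Python for readability:

def unflatten(tid, output_dim1, output_dim2):
    # output is viewed as [output_dim0, output_dim1, output_dim2]
    i = tid // (output_dim1 * output_dim2)
    j = (tid // output_dim2) % output_dim1
    k = tid % output_dim2
    return i, j, k
# element (i, j, k) then reads input[i, indices[j], k] when
# 0 <= indices[j] < input_dim1 (assumption based on the kernel's arguments)
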
-template <typename T, typename S>
-void CalGatherV2DynamicShape(T *input, S *indices, T *output, size_t *input_shape_wksp, size_t input_rank,
-                             size_t *indices_shape_wksp, size_t indices_rank, int64_t *axis_wksp,
-                             size_t *output_shape_wksp, const int max_output_size, cudaStream_t stream) {
-  GatherV2DynamicShape<<<GET_BLOCKS(max_output_size), GET_THREADS, 0, stream>>>(
-    input, indices, output, input_shape_wksp, input_rank, indices_shape_wksp, indices_rank, axis_wksp,
-    output_shape_wksp, max_output_size);
-}
-// template instantiations
-template void CalGatherV2StaticShape<float, int>(float *input, int *indices, float *output, size_t output_dim0,
-                                                 size_t output_dim1, size_t output_dim2, size_t input_dim1,
-                                                 cudaStream_t stream);
-template void CalGatherV2StaticShape<half, int>(half *input, int *indices, half *output, size_t output_dim0,
-                                                size_t output_dim1, size_t output_dim2, size_t input_dim1,
-                                                cudaStream_t stream);
-template void CalGatherV2DynamicShape<float, int>(float *input, int *indices, float *output, size_t *input_shape_wksp,
-                                                  size_t input_rank, size_t *indices_shape_wksp, size_t indices_rank,
-                                                  int64_t *axis_wksp, size_t *output_shape_wksp,
-                                                  const int max_output_size, cudaStream_t stream);
+template void GatherV2<float, int>(float *input, int *indices, float *output, size_t output_dim0, size_t output_dim1,
+                                   size_t output_dim2, size_t input_dim1, cudaStream_t stream);
-template void CalGatherV2DynamicShape<half, int>(half *input, int *indices, half *output, size_t *input_shape_wksp,
-                                                 size_t input_rank, size_t *indices_shape_wksp, size_t indices_rank,
-                                                 int64_t *axis_wksp, size_t *output_shape_wksp,
-                                                 const int max_output_size, cudaStream_t stream);
+template void GatherV2<half, int>(half *input, int *indices, half *output, size_t output_dim0, size_t output_dim1,
+                                  size_t output_dim2, size_t input_dim1, cudaStream_t stream);
@@ -14,14 +14,10 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_GATHER_GPU_CU_H
-#define MINDSPORE_GATHER_GPU_CU_H
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GATHER_V2_CU_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GATHER_V2_CU_H_
 template <typename T, typename S>
-void CalGatherV2StaticShape(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2,
-                            size_t input_dim1, cudaStream_t stream);
+void GatherV2(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2,
+              size_t input_dim1, cudaStream_t stream);
-template <typename T, typename S>
-void CalGatherV2DynamicShape(T *input, S *indices, T *output, size_t *input_shape_wksp, size_t input_rank,
-                             size_t *indices_shape_wksp, size_t indices_rank, int64_t *axis_wksp,
-                             size_t *output_shape_wksp, const int max_output_size, cudaStream_t stream);
 #endif
@@ -408,7 +408,8 @@ AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr
   CheckArgsSize(op_name, args_spec_list, 3);
   AbstractTensorPtr params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
   AbstractTensorPtr indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  bool ind_dyn = (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty());
+  bool param_dyn = (!params->shape()->min_shape().empty() && !params->shape()->max_shape().empty());
   int64_t axis_val = 0;
   // 3rd input is a Tensor when GatherV2 is a dynamic shape operator
   if (args_spec_list[2]->isa<AbstractTensor>()) {
@@ -425,31 +426,36 @@ AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr
   } else {
     MS_LOG(EXCEPTION) << "Invalid abstract type:" << args_spec_list[2]->type_name();
   }
   auto params_shp = params->shape()->shape();
   auto indices_shp = indices->shape()->shape();
   auto params_rank = static_cast<int64_t>(params_shp.size());
+  // either input or both can be dynamic; the computation requires min/max shapes for both
+  ShapeVector param_shp_min = (param_dyn) ? params->shape()->min_shape() : params->shape()->shape();
+  ShapeVector param_shp_max = (param_dyn) ? params->shape()->max_shape() : params->shape()->shape();
+  ShapeVector indices_shp_min = (ind_dyn) ? indices->shape()->min_shape() : indices->shape()->shape();
+  ShapeVector indices_shp_max = (ind_dyn) ? indices->shape()->max_shape() : indices->shape()->shape();
+  // check that axis_val is within the interval [-params_rank, params_rank)
+  if (!(-params_rank <= axis_val) || !(axis_val < params_rank)) {
+    MS_LOG(EXCEPTION) << "For GatherV2 - Axis value must be within [ " << -params_rank << ", " << params_rank << " ) "
+                      << "Got " << axis_val << ".";
+  }
   if (axis_val < 0) {
     axis_val += params_rank;
   }
-  auto calc_shape = [axis_val, &params_shp](const ShapeVector &inp_vec) -> ShapeVector {
+  auto calc_shape = [axis_val](const ShapeVector &ind_vec, const ShapeVector &params_vec) -> ShapeVector {
     ShapeVector out_vec;
-    std::copy(params_shp.begin(), params_shp.begin() + axis_val, std::back_inserter(out_vec));
-    copy(inp_vec.begin(), inp_vec.end(), std::back_inserter(out_vec));
-    copy(params_shp.begin() + axis_val + 1, params_shp.end(), std::back_inserter(out_vec));
+    std::copy(params_vec.begin(), params_vec.begin() + axis_val, std::back_inserter(out_vec));
+    copy(ind_vec.begin(), ind_vec.end(), std::back_inserter(out_vec));
+    copy(params_vec.begin() + axis_val + 1, params_vec.end(), std::back_inserter(out_vec));
     return out_vec;
   };
-  ShapeVector out_shape = calc_shape(indices_shp);
-  if (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty()) {
-    ShapeVector min_shape = calc_shape(indices->shape()->min_shape());
-    ShapeVector max_shape = calc_shape(indices->shape()->max_shape());
+  ShapeVector out_shape = calc_shape(indices_shp, params_shp);
+  if (ind_dyn || param_dyn) {
+    ShapeVector min_shape = calc_shape(indices_shp_min, param_shp_min);
+    ShapeVector max_shape = calc_shape(indices_shp_max, param_shp_max);
     return std::make_shared<AbstractTensor>(params->element(),
                                             std::make_shared<Shape>(out_shape, min_shape, max_shape));
   }
   return std::make_shared<AbstractTensor>(params->element(), std::make_shared<Shape>(out_shape));
 }
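
The inference rule calc_shape encodes, restated as a Python sketch (hypothetical helper, mirroring the lambda above): the output shape is the params shape with the axis dimension replaced by the entire indices shape, and the same rule is applied to the min/max shapes when either input is dynamic.

def calc_shape(ind_vec, params_vec, axis):
    return params_vec[:axis] + ind_vec + params_vec[axis + 1:]

# e.g. params (2, 3, 4, 5), indices (3,), axis -1 normalized to 3:
assert calc_shape([3], [2, 3, 4, 5], 3) == [2, 3, 4, 3]
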
@@ -749,7 +749,6 @@ class GatherV2(PrimitiveWithCheck):
          [ 4. 54.]
          [ 2. 55.]]
     """
-
     @prim_attr_register
     def __init__(self):
         """Initialize index_select"""
@@ -759,22 +758,7 @@ class GatherV2(PrimitiveWithCheck):
     def __check__(self, params, indices, axis):
         validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
         validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
-        validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
-        axis_v = axis['value']
-        params_shp = params['shape']
-        rank = len(params_shp)
-        validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
-        if axis_v < 0:
-            axis_v += rank
-        out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
-        out = {'shape': out_shape,
-               'dtype': params['dtype'],
-               'value': None}
-        if 'min_shape' in indices and 'max_shape' in indices:
-            out['min_shape'] = params_shp[:axis_v] + indices['min_shape'] + params_shp[axis_v + 1:]
-            out['max_shape'] = params_shp[:axis_v] + indices['max_shape'] + params_shp[axis_v + 1:]
-        return out
+        validator.check_subclass("axis", axis['dtype'], [mstype.tensor, mstype.int_], self.name)

 class SparseGatherV2(GatherV2):
@@ -19,6 +19,7 @@ import pytest
 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
+from mindspore.ops.operations import _inner_ops as inner
 from mindspore.ops import operations as P
@@ -937,3 +938,158 @@ def test_gather2():
     diff = output.asnumpy() - expect
     assert np.all(diff < error)
     assert np.all(-diff < error)
+
+
+# Dynamic Shape testing ahead
+class GatherNetDynamic1(nn.Cell):
+    def __init__(self):
+        super(GatherNetDynamic1, self).__init__()
+        self.gather = P.GatherV2()
+        self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
+
+    def construct(self, x, indices):
+        # Testing only second input dynamic
+        indices_dyn = self.gpu_convert_to_dynamic_shape(indices)
+        return self.gather(x, indices_dyn, 0)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gather_dynamic_1():
+    x = Tensor(np.array([[4., 5., 4., 1., 5.,],
+                         [4., 9., 5., 6., 4.,],
+                         [9., 8., 4., 3., 6.,],
+                         [0., 4., 2., 2., 8.,],
+                         [1., 8., 6., 2., 8.,],
+                         [8., 1., 9., 7., 3.,],
+                         [7., 9., 2., 5., 7.,],
+                         [9., 8., 6., 8., 5.,],
+                         [3., 7., 2., 7., 4.,],
+                         [4., 2., 8., 2., 9.,]]
+                        ).astype(np.float32))
+    indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int32))
+    expect = np.array([[[0., 0., 0., 0., 0.],
+                        [4., 9., 5., 6., 4.],
+                        [0., 0., 0., 0., 0.]]])
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    gather = GatherNetDynamic1()
+    output = gather(x, indices)
+    error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
+    diff = output.asnumpy() - expect
+    assert np.all(diff < error)
+    assert np.all(-diff < error)
+
+
+class GatherNetDynamic2(nn.Cell):
+    def __init__(self):
+        super(GatherNetDynamic2, self).__init__()
+        self.gather = P.GatherV2()
+        self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
+
+    def construct(self, x, indices):
+        # Testing only first input dynamic
+        x_dyn = self.gpu_convert_to_dynamic_shape(x)
+        return self.gather(x_dyn, indices, -1)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gather_dynamic_2():
+    x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5))
+    indices = Tensor(np.array([1, 3, 4], dtype='i4'))
+    expect = np.array([[[[1., 3., 4.],
+                         [6., 8., 9.],
+                         [11., 13., 14.],
+                         [16., 18., 19.]],
+                        [[21., 23., 24.],
+                         [26., 28., 29.],
+                         [31., 33., 34.],
+                         [36., 38., 39.]],
+                        [[41., 43., 44.],
+                         [46., 48., 49.],
+                         [51., 53., 54.],
+                         [56., 58., 59.]]],
+                       [[[61., 63., 64.],
+                         [66., 68., 69.],
+                         [71., 73., 74.],
+                         [76., 78., 79.]],
+                        [[81., 83., 84.],
+                         [86., 88., 89.],
+                         [91., 93., 94.],
+                         [96., 98., 99.]],
+                        [[101., 103., 104.],
+                         [106., 108., 109.],
+                         [111., 113., 114.],
+                         [116., 118., 119.]]]])
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    gather = GatherNetDynamic2()
+    output = gather(x, indices)
+    error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
+    diff = output.asnumpy() - expect
+    assert np.all(diff < error)
+    assert np.all(-diff < error)
+
+
+class GatherNetDynamic3(nn.Cell):
+    def __init__(self):
+        super(GatherNetDynamic3, self).__init__()
+        self.gather = P.GatherV2()
+        self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
+
+    def construct(self, x, indices):
+        # Testing both inputs dynamic shapes
+        x_dyn = self.gpu_convert_to_dynamic_shape(x)
+        indices_dyn = self.gpu_convert_to_dynamic_shape(indices)
+        return self.gather(x_dyn, indices_dyn, -1)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gather_dynamic_3():
+    x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5))
+    indices = Tensor(np.array([1, 3, 4], dtype='i4'))
+    expect = np.array([[[[1., 3., 4.],
+                         [6., 8., 9.],
+                         [11., 13., 14.],
+                         [16., 18., 19.]],
+                        [[21., 23., 24.],
+                         [26., 28., 29.],
+                         [31., 33., 34.],
+                         [36., 38., 39.]],
+                        [[41., 43., 44.],
+                         [46., 48., 49.],
+                         [51., 53., 54.],
+                         [56., 58., 59.]]],
+                       [[[61., 63., 64.],
+                         [66., 68., 69.],
+                         [71., 73., 74.],
+                         [76., 78., 79.]],
+                        [[81., 83., 84.],
+                         [86., 88., 89.],
+                         [91., 93., 94.],
+                         [96., 98., 99.]],
+                        [[101., 103., 104.],
+                         [106., 108., 109.],
+                         [111., 113., 114.],
+                         [116., 118., 119.]]]])
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    gather = GatherNetDynamic3()
+    output = gather(x, indices)
+    error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
+    diff = output.asnumpy() - expect
+    assert np.all(diff < error)
+    assert np.all(-diff < error)