From 241c8f3d960eb12dbd92cca5f1bf6a1b4b975b0d Mon Sep 17 00:00:00 2001 From: danishnxt Date: Wed, 25 Nov 2020 17:49:27 -0500 Subject: [PATCH] GatherUpdate test file finishing pending Update to GatherV2_Bug_Fix lint fix lint fix - 2 lint fix Update to GatherV2 - fixed default inferImpl func + CudeStreamSync lint fix SyncDevice added dynamic shape init_size input lint --- .../gpu/arrays/gatherv2_gpu_kernel.cc | 17 +- .../gpu/arrays/gatherv2_gpu_kernel.h | 108 +++--------- .../kernel_compiler/gpu/cuda_impl/gatherv2.cu | 91 +--------- .../gpu/cuda_impl/gatherv2.cuh | 12 +- mindspore/core/abstract/prim_arrays.cc | 34 ++-- mindspore/core/abstract/prim_others.cc | 2 +- mindspore/ops/operations/array_ops.py | 18 +- tests/st/ops/gpu/test_gatherV2_op.py | 156 ++++++++++++++++++ 8 files changed, 223 insertions(+), 215 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc index 7db992c0fe..78836ae394 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc @@ -26,14 +26,6 @@ MS_REG_GPU_KERNEL_TWO( GatherV2, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), GatherV2GpuFwdKernel, half, int) -MS_REG_GPU_KERNEL_TWO( - SparseGatherV2, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - GatherV2GpuFwdKernel, float, int) -MS_REG_GPU_KERNEL_TWO( - SparseGatherV2, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), - GatherV2GpuFwdKernel, half, int) MS_REG_GPU_KERNEL_TWO(GatherV2, KernelAttr() .AddInputAttr(kNumberTypeFloat32) @@ -48,5 +40,14 @@ MS_REG_GPU_KERNEL_TWO(GatherV2, .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeFloat16), GatherV2GpuFwdKernel, half, int) +MS_REG_GPU_KERNEL_TWO( + SparseGatherV2, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + GatherV2GpuFwdKernel, float, int) +MS_REG_GPU_KERNEL_TWO( + SparseGatherV2, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), + GatherV2GpuFwdKernel, half, int) + } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.h index 6d6b079fb6..45f3fb6580 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_GATHER_GPU_KERNEL_H -#define MINDSPORE_GATHER_GPU_KERNEL_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GATHER_V2_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GATHER_V2_GPU_KERNEL_H_ #include #include @@ -41,45 +41,17 @@ class GatherV2GpuFwdKernel : public GpuKernel { T *input_addr = GetDeviceAddress(inputs, 0); S *indices_addr = GetDeviceAddress(inputs, 1); T *output_addr = GetDeviceAddress(outputs, 0); - if (is_dynamic_shape_) { - // if we are in dynamic shape mode, we don't know dims_, so we need to store the input_shape_ and indices_shape_, - // and axis_ in the workspace to calculate dims_ - size_t *input_shape_device_address = GetDeviceAddress(workspace, 0); - size_t *indices_shape_device_address = GetDeviceAddress(workspace, 1); - int64_t *axis_device_address = GetDeviceAddress(workspace, 2); - - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(input_shape_device_address, input_shapes_.data(), workspace_size_list_[0], - cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync input_shape failed"); - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(indices_shape_device_address, indices_shapes_.data(), workspace_size_list_[1], - cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync indices_shape failed"); - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(axis_device_address, &axis_, workspace_size_list_[2], - cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + int64_t *axis_device_address = GetDeviceAddress(inputs, 2); // only get this if in dynamic mode + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(&axis_, axis_device_address, sizeof(int64_t), cudaMemcpyDeviceToHost, + reinterpret_cast(stream_ptr)), "cudaMemcpyAsync axis_ failed"); - - // output shape will be here for us to copy back to host - size_t *output_shape_device_address = GetDeviceAddress(workspace, 3); - CalGatherV2DynamicShape(input_addr, indices_addr, output_addr, input_shape_device_address, input_shapes_.size(), - indices_shape_device_address, indices_shapes_.size(), axis_device_address, - output_shape_device_address, max_output_size_, - reinterpret_cast(stream_ptr)); - - size_t output_rank = input_shapes_.size() - 1 + indices_shapes_.size(); - real_output_shape_.resize(output_rank); - CHECK_CUDA_RET_WITH_ERROR( - cudaMemcpyAsync(&real_output_shape_[0], output_shape_device_address, output_rank * sizeof(int32_t), - cudaMemcpyDeviceToHost, reinterpret_cast(stream_ptr)), - "Failed to copy gpu memory."); - - } else { - auto input_dim1 = input_shapes_[IntToSize(axis_)]; - CalGatherV2StaticShape(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], input_dim1, - reinterpret_cast(stream_ptr)); + CHECK_CUDA_RET_WITH_EXCEPT(cudaDeviceSynchronize(), "cudaDeviceSyncFailed - GatherV2 - in dynamic mode"); + Reshape(); } + auto input_dim1 = input_shapes_[IntToSize(axis_)]; + GatherV2(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], input_dim1, + reinterpret_cast(stream_ptr)); return true; } bool Init(const CNodePtr &kernel_node) override { @@ -87,33 +59,24 @@ class GatherV2GpuFwdKernel : public GpuKernel { size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); if (input_num == 3) { is_dynamic_shape_ = true; - } else if (input_num != 2) { - MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherGpuV2FwdKernel needs 2."; + MS_LOG(INFO) << " GatherGpuV2FwdKernel running in Dynamic Mode."; + } else if (input_num == 2) { + MS_LOG(INFO) << " GatherGpuV2FwdKernel running in Normal Mode."; + } else { + MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherGpuV2FwdKernel needs 2 or 3."; } - input_shapes_ = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0); indices_shapes_ = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1); output_shapes_ = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0); - - if (is_dynamic_shape_) { - c_node_ptr_ = kernel_node; - size_t input_shape_min = *std::min_element(input_shapes_.begin(), input_shapes_.end()); - max_output_size_ = (GetSize(input_shapes_) / input_shape_min) * GetSize(indices_shapes_); - } else { + if (!is_dynamic_shape_) { axis_ = static_cast(GetAttr(kernel_node, "axis")); - if (axis_ < 0) { - axis_ = axis_ + SizeToInt(input_shapes_.size()); - } - Reshape(); } - InitSizeLists(); return true; } void ResetResource() noexcept override { is_dynamic_shape_ = false; - max_output_size_ = -1; input_shapes_.clear(); indices_shapes_.clear(); output_shapes_.clear(); @@ -128,52 +91,32 @@ class GatherV2GpuFwdKernel : public GpuKernel { void InitSizeLists() override { size_t size = GetSize(input_shapes_); input_size_list_.push_back(size); - size = GetSize(indices_shapes_); input_size_list_.push_back(size); - if (is_dynamic_shape_) { - // add by chenweifeng - input_size_list_.push_back(sizeof(S)); - - // allocate maximum size needed - output_size_list_.push_back(max_output_size_); - - // allocate workspace memory for input, indices, axis, and output shape respectively - size = GetSize(input_shapes_); - workspace_size_list_.push_back(size); - - size = GetSize(indices_shapes_); - workspace_size_list_.push_back(size); - - size = sizeof(int32_t); - workspace_size_list_.push_back(size); - - size = GetSize(input_shapes_); - workspace_size_list_.push_back(size); - } else { - size = GetSize(output_shapes_); - output_size_list_.push_back(size); + input_size_list_.push_back(sizeof(int64_t)); } + size = GetSize(output_shapes_); + output_size_list_.push_back(size); } private: void Reshape() { + if (axis_ < 0) { + axis_ = axis_ + SizeToInt(input_shapes_.size()); + } size_t dim_before_axis = 1; for (size_t i = 0; i < IntToSize(axis_); i++) { dim_before_axis *= output_shapes_[i]; } - size_t dim_of_indices = 1; for (size_t i = 0; i < indices_shapes_.size(); i++) { dim_of_indices *= indices_shapes_[i]; } - size_t dim_after_indices = 1; for (size_t i = IntToSize(axis_) + indices_shapes_.size(); i < output_shapes_.size(); i++) { dim_after_indices *= output_shapes_[i]; } - dims_[0] = dim_before_axis; dims_[1] = dim_of_indices; dims_[2] = dim_after_indices; @@ -193,14 +136,9 @@ class GatherV2GpuFwdKernel : public GpuKernel { std::vector input_shapes_; std::vector indices_shapes_; std::vector output_shapes_; - size_t dims_[3] = {}; int64_t axis_; bool is_dynamic_shape_; - int max_output_size_; - std::vector real_output_shape_; - CNodePtr c_node_ptr_; - std::vector input_size_list_; std::vector output_size_list_; std::vector workspace_size_list_; @@ -208,4 +146,4 @@ class GatherV2GpuFwdKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_GATHER_GPU_KERNEL_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GATHER_V2_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cu index fafc51e0fd..a02cd215e7 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cu @@ -18,7 +18,7 @@ #include "backend/kernel_compiler/gpu/cuda_impl/gatherv2.cuh" #include "runtime/device/gpu/cuda_common.h" template -__device__ void GatherV2Kernel(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, +__global__ void GatherV2Kernel(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1) { int num = output_dim0 * output_dim1 * output_dim2; int i, j, k; @@ -38,90 +38,17 @@ __device__ void GatherV2Kernel(T *input, S *indices, T *output, size_t output_di return; } - -template -__global__ void GatherV2StaticShapeWrapper(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, - size_t output_dim2, size_t input_dim1) { - GatherV2Kernel(input, indices, output, output_dim0, output_dim1, output_dim2, input_dim1); -} - -template -__global__ void GatherV2DynamicShape(T *input, S *indices, T *output, size_t *input_shape_wksp, size_t input_rank, - size_t *indices_shape_wksp, size_t indices_rank, int64_t *axis_wksp, - size_t *output_shape_wksp, const int max_output_size) { - int gt_id = blockIdx.x * blockDim.x + threadIdx.x; - size_t axis = (size_t)(*axis_wksp); - - int output_shape_index = 0; - size_t output_dim0 = 1; - for (size_t i = 0; i < axis; i++) { - output_dim0 *= input_shape_wksp[i]; - - if (gt_id == 0) { - output_shape_wksp[output_shape_index] = input_shape_wksp[i]; - output_shape_index++; - } - } - - size_t output_dim1 = 1; - for (size_t i = 0; i < indices_rank; i++) { - output_dim1 *= indices_shape_wksp[i]; - - if (gt_id == 0) { - output_shape_wksp[output_shape_index] = indices_shape_wksp[i]; - output_shape_index++; - } - } - - size_t output_dim2 = 1; - for (size_t i = axis + 1; i < input_rank; i++) { - output_dim2 *= indices_shape_wksp[i]; - - if (gt_id == 0) { - output_shape_wksp[output_shape_index] = input_shape_wksp[i]; - output_shape_index++; - } - } - - size_t input_dim1 = (size_t)(input_shape_wksp[axis]); - - GatherV2Kernel(input, indices, output, output_dim0, output_dim1, output_dim2, input_dim1); -} - -// entry points from gpu kernel's .h file template -void CalGatherV2StaticShape(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2, - size_t input_dim1, cudaStream_t stream) { +void GatherV2(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2, + size_t input_dim1, cudaStream_t stream) { int size = output_dim0 * output_dim1 * output_dim2; - GatherV2StaticShapeWrapper<<>>(input, indices, output, output_dim0, - output_dim1, output_dim2, input_dim1); + GatherV2Kernel<<>>(input, indices, output, output_dim0, output_dim1, + output_dim2, input_dim1); return; } -template -void CalGatherV2DynamicShape(T *input, S *indices, T *output, size_t *input_shape_wksp, size_t input_rank, - size_t *indices_shape_wksp, size_t indices_rank, int64_t *axis_wksp, - size_t *output_shape_wksp, const int max_output_size, cudaStream_t stream) { - GatherV2DynamicShape<<>>( - input, indices, output, input_shape_wksp, input_rank, indices_shape_wksp, indices_rank, axis_wksp, - output_shape_wksp, max_output_size); -} - -// template instantiations -template void CalGatherV2StaticShape(float *input, int *indices, float *output, size_t output_dim0, - size_t output_dim1, size_t output_dim2, size_t input_dim1, - cudaStream_t stream); - -template void CalGatherV2StaticShape(half *input, int *indices, half *output, size_t output_dim0, - size_t output_dim1, size_t output_dim2, size_t input_dim1, - cudaStream_t stream); - -template void CalGatherV2DynamicShape(float *input, int *indices, float *output, size_t *input_shape_wksp, - size_t input_rank, size_t *indices_shape_wksp, size_t indices_rank, - int64_t *axis_wksp, size_t *output_shape_wksp, - const int max_output_size, cudaStream_t stream); +template void GatherV2(float *input, int *indices, float *output, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, cudaStream_t stream); -template void CalGatherV2DynamicShape(half *input, int *indices, half *output, size_t *input_shape_wksp, - size_t input_rank, size_t *indices_shape_wksp, size_t indices_rank, - int64_t *axis_wksp, size_t *output_shape_wksp, - const int max_output_size, cudaStream_t stream); +template void GatherV2(half *input, int *indices, half *output, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cuh index b45bb35159..9af9fd1b71 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cuh @@ -14,14 +14,10 @@ * limitations under the License. */ -#ifndef MINDSPORE_GATHER_GPU_CU_H -#define MINDSPORE_GATHER_GPU_CU_H +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GATHER_V2_CU_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GATHER_V2_CU_H_ template -void CalGatherV2StaticShape(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2, - size_t input_dim1, cudaStream_t stream); +void GatherV2(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2, + size_t input_dim1, cudaStream_t stream); -template -void CalGatherV2DynamicShape(T *input, S *indices, T *output, size_t *input_shape_wksp, size_t input_rank, - size_t *indices_shape_wksp, size_t indices_rank, int64_t *axis_wksp, - size_t *output_shape_wksp, const int max_output_size, cudaStream_t stream); #endif diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc index 7a6143f401..d183c4c770 100644 --- a/mindspore/core/abstract/prim_arrays.cc +++ b/mindspore/core/abstract/prim_arrays.cc @@ -408,7 +408,8 @@ AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr CheckArgsSize(op_name, args_spec_list, 3); AbstractTensorPtr params = CheckArg(op_name, args_spec_list, 0); AbstractTensorPtr indices = CheckArg(op_name, args_spec_list, 1); - + bool ind_dyn = (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty()); + bool param_dyn = (!params->shape()->min_shape().empty() && !params->shape()->max_shape().empty()); int64_t axis_val = 0; // 3rd input is a Tensor when GatherV2 is a dynamic shape operator if (args_spec_list[2]->isa()) { @@ -425,31 +426,36 @@ AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr } else { MS_LOG(EXCEPTION) << "Invalid abstract type:" << args_spec_list[2]->type_name(); } - auto params_shp = params->shape()->shape(); auto indices_shp = indices->shape()->shape(); - auto params_rank = static_cast(params_shp.size()); + // either inputs or both can be dynamic and computation requires min/max shapes for both + ShapeVector param_shp_min = (param_dyn) ? params->shape()->min_shape() : params->shape()->shape(); + ShapeVector param_shp_max = (param_dyn) ? params->shape()->max_shape() : params->shape()->shape(); + ShapeVector indices_shp_min = (ind_dyn) ? indices->shape()->min_shape() : indices->shape()->shape(); + ShapeVector indices_shp_max = (ind_dyn) ? indices->shape()->max_shape() : indices->shape()->shape(); + // check axis_val within interval: [-params_rank, params_rank) + if (!(-params_rank <= axis_val) || !(axis_val < params_rank)) { + MS_LOG(EXCEPTION) << "For GatherV2 - Axis value must be within [ " << -params_rank << ", " << params_rank << " ) " + << "Got " << axis_val << "."; + } if (axis_val < 0) { axis_val += params_rank; } - - auto calc_shape = [axis_val, ¶ms_shp](const ShapeVector &inp_vec) -> ShapeVector { + auto calc_shape = [axis_val](const ShapeVector &ind_vec, const ShapeVector ¶ms_vec) -> ShapeVector { ShapeVector out_vec; - std::copy(params_shp.begin(), params_shp.begin() + axis_val, std::back_inserter(out_vec)); - copy(inp_vec.begin(), inp_vec.end(), std::back_inserter(out_vec)); - copy(params_shp.begin() + axis_val + 1, params_shp.end(), std::back_inserter(out_vec)); + std::copy(params_vec.begin(), params_vec.begin() + axis_val, std::back_inserter(out_vec)); + copy(ind_vec.begin(), ind_vec.end(), std::back_inserter(out_vec)); + copy(params_vec.begin() + axis_val + 1, params_vec.end(), std::back_inserter(out_vec)); return out_vec; }; - - ShapeVector out_shape = calc_shape(indices_shp); - if (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty()) { - ShapeVector min_shape = calc_shape(indices->shape()->min_shape()); - ShapeVector max_shape = calc_shape(indices->shape()->max_shape()); + ShapeVector out_shape = calc_shape(indices_shp, params_shp); + if (ind_dyn || param_dyn) { + ShapeVector min_shape = calc_shape(indices_shp_min, param_shp_min); + ShapeVector max_shape = calc_shape(indices_shp_max, param_shp_max); return std::make_shared(params->element(), std::make_shared(out_shape, min_shape, max_shape)); } - return std::make_shared(params->element(), std::make_shared(out_shape)); } diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc index 7fe4be12d9..6a2da572c1 100644 --- a/mindspore/core/abstract/prim_others.cc +++ b/mindspore/core/abstract/prim_others.cc @@ -535,7 +535,7 @@ AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, con ShapeVector input_shape = input->shape()->shape(); int32_t input_rank = input_shape.size(); ShapeVector inferred_shape(input_rank, Shape::SHP_ANY); - ShapeVector min_shape = {1}; + ShapeVector min_shape(input_rank, 1); ShapeVector max_shape = input_shape; ShapePtr shape = std::make_shared(inferred_shape, min_shape, max_shape); diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 4a871f6fc7..6a40f9f2db 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -703,7 +703,6 @@ class GatherV2(PrimitiveWithCheck): [ 4. 54.] [ 2. 55.]] """ - @prim_attr_register def __init__(self): """Initialize index_select""" @@ -713,22 +712,7 @@ class GatherV2(PrimitiveWithCheck): def __check__(self, params, indices, axis): validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) - validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name) - axis_v = axis['value'] - params_shp = params['shape'] - rank = len(params_shp) - validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) - - if axis_v < 0: - axis_v += rank - out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:] - out = {'shape': out_shape, - 'dtype': params['dtype'], - 'value': None} - if 'min_shape' in indices and 'max_shape' in indices: - out['min_shape'] = params_shp[:axis_v] + indices['min_shape'] + params_shp[axis_v + 1:] - out['max_shape'] = params_shp[:axis_v] + indices['max_shape'] + params_shp[axis_v + 1:] - return out + validator.check_subclass("axis", axis['dtype'], [mstype.tensor, mstype.int_], self.name) class SparseGatherV2(GatherV2): diff --git a/tests/st/ops/gpu/test_gatherV2_op.py b/tests/st/ops/gpu/test_gatherV2_op.py index dae1cfb62c..c530f5108f 100644 --- a/tests/st/ops/gpu/test_gatherV2_op.py +++ b/tests/st/ops/gpu/test_gatherV2_op.py @@ -19,6 +19,7 @@ import pytest import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor +from mindspore.ops.operations import _inner_ops as inner from mindspore.ops import operations as P @@ -937,3 +938,158 @@ def test_gather2(): diff = output.asnumpy() - expect assert np.all(diff < error) assert np.all(-diff < error) + + +# Dynamic Shape testing ahead +class GatherNetDynamic1(nn.Cell): + def __init__(self): + super(GatherNetDynamic1, self).__init__() + self.gather = P.GatherV2() + self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape() + + def construct(self, x, indices): + # Testing only second input dynamic + indices_dyn = self.gpu_convert_to_dynamic_shape(indices) + return self.gather(x, indices_dyn, 0) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gather_dynamic_1(): + x = Tensor(np.array([[4., 5., 4., 1., 5.,], + [4., 9., 5., 6., 4.,], + [9., 8., 4., 3., 6.,], + [0., 4., 2., 2., 8.,], + [1., 8., 6., 2., 8.,], + [8., 1., 9., 7., 3.,], + [7., 9., 2., 5., 7.,], + [9., 8., 6., 8., 5.,], + [3., 7., 2., 7., 4.,], + [4., 2., 8., 2., 9.,]] + ).astype(np.float32)) + + indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int32)) + expect = np.array([[[0., 0., 0., 0., 0.], + [4., 9., 5., 6., 4.], + [0., 0., 0., 0., 0.]]]) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + gather = GatherNetDynamic1() + output = gather(x, indices) + error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + + +class GatherNetDynamic2(nn.Cell): + def __init__(self): + super(GatherNetDynamic2, self).__init__() + self.gather = P.GatherV2() + self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape() + + def construct(self, x, indices): + # Testing only first input dynamic + x_dyn = self.gpu_convert_to_dynamic_shape(x) + return self.gather(x_dyn, indices, -1) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gather_dynamic_2(): + x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)) + indices = Tensor(np.array([1, 3, 4], dtype='i4')) + expect = np.array([[[[1., 3., 4.], + [6., 8., 9.], + [11., 13., 14.], + [16., 18., 19.]], + + [[21., 23., 24.], + [26., 28., 29.], + [31., 33., 34.], + [36., 38., 39.]], + + [[41., 43., 44.], + [46., 48., 49.], + [51., 53., 54.], + [56., 58., 59.]]], + + [[[61., 63., 64.], + [66., 68., 69.], + [71., 73., 74.], + [76., 78., 79.]], + + [[81., 83., 84.], + [86., 88., 89.], + [91., 93., 94.], + [96., 98., 99.]], + + [[101., 103., 104.], + [106., 108., 109.], + [111., 113., 114.], + [116., 118., 119.]]]]) + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + gather = GatherNetDynamic2() + output = gather(x, indices) + error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + + +class GatherNetDynamic3(nn.Cell): + def __init__(self): + super(GatherNetDynamic3, self).__init__() + self.gather = P.GatherV2() + self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape() + + def construct(self, x, indices): + # Testing both inputs dynamic shapes + x_dyn = self.gpu_convert_to_dynamic_shape(x) + indices_dyn = self.gpu_convert_to_dynamic_shape(indices) + return self.gather(x_dyn, indices_dyn, -1) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gather_dynamic_3(): + x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)) + indices = Tensor(np.array([1, 3, 4], dtype='i4')) + expect = np.array([[[[1., 3., 4.], + [6., 8., 9.], + [11., 13., 14.], + [16., 18., 19.]], + + [[21., 23., 24.], + [26., 28., 29.], + [31., 33., 34.], + [36., 38., 39.]], + + [[41., 43., 44.], + [46., 48., 49.], + [51., 53., 54.], + [56., 58., 59.]]], + + [[[61., 63., 64.], + [66., 68., 69.], + [71., 73., 74.], + [76., 78., 79.]], + + [[81., 83., 84.], + [86., 88., 89.], + [91., 93., 94.], + [96., 98., 99.]], + + [[101., 103., 104.], + [106., 108., 109.], + [111., 113., 114.], + [116., 118., 119.]]]]) + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + gather = GatherNetDynamic3() + output = gather(x, indices) + error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error)