diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h index e9a6a13713..60d2bd926f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h @@ -143,12 +143,22 @@ class StridedSliceGpuKernel : public GpuKernel { } } - auto shrink_axis_mask_str = static_cast(GetAttr(kernel_node, "shrink_axis_mask")); - auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_str); - for (size_t l = 0; l < shrink_axis_mask.size(); l++) { - if (shrink_axis_mask[l]) { - end_[l] = end_[l] > begin_[l] ? begin_[l] + 1 : begin_[l] - 1; - strides_[l] = end_[l] > begin_[l] ? 1 : -1; + auto new_axis_mask_int = static_cast(GetAttr(kernel_node, "new_axis_mask")); + auto new_axis_mask = Dec2Bin(new_axis_mask_int); + for (size_t l = 0; l < new_axis_mask.size(); l++) { + if (new_axis_mask[l]) { + begin_[l] = 0; + end_[l] = input_shape_[l]; + strides_[l] = 1; + } + } + + auto shrink_axis_mask_int = static_cast(GetAttr(kernel_node, "shrink_axis_mask")); + auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_int); + for (size_t m = 0; m < shrink_axis_mask.size(); m++) { + if (shrink_axis_mask[m]) { + end_[m] = end_[m] > begin_[m] ? begin_[m] + 1 : begin_[m] - 1; + strides_[m] = end_[m] > begin_[m] ? 1 : -1; } } } diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h index 984702d50d..acc271bac9 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h @@ -149,12 +149,22 @@ class StridedSliceGradGpuKernel : public GpuKernel { } } - auto shrink_axis_mask_str = static_cast(GetAttr(kernel_node, "shrink_axis_mask")); - auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_str); - for (size_t l = 0; l < shrink_axis_mask.size(); l++) { - if (shrink_axis_mask[l]) { - end_[l] = end_[l] > begin_[l] ? begin_[l] + 1 : begin_[l] - 1; - strides_[l] = end_[l] > begin_[l] ? 1 : -1; + auto new_axis_mask_int = static_cast(GetAttr(kernel_node, "new_axis_mask")); + auto new_axis_mask = Dec2Bin(new_axis_mask_int); + for (size_t l = 0; l < new_axis_mask.size(); l++) { + if (new_axis_mask[l]) { + begin_[l] = 0; + end_[l] = input_shape_[l]; + strides_[l] = 1; + } + } + + auto shrink_axis_mask_int = static_cast(GetAttr(kernel_node, "shrink_axis_mask")); + auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_int); + for (size_t m = 0; m < shrink_axis_mask.size(); m++) { + if (shrink_axis_mask[m]) { + end_[m] = end_[m] > begin_[m] ? begin_[m] + 1 : begin_[m] - 1; + strides_[m] = end_[m] > begin_[m] ? 1 : -1; } } }