diff --git a/mindspore/ops/_grad/grad_inner_ops.py b/mindspore/ops/_grad/grad_inner_ops.py index e577b1b0ab..4a6a4c470e 100644 --- a/mindspore/ops/_grad/grad_inner_ops.py +++ b/mindspore/ops/_grad/grad_inner_ops.py @@ -15,6 +15,7 @@ """array_ops""" +from .. import operations as P from ..operations import _grad_ops as G from ..operations import _inner_ops as inner from ..composite.multitype_ops.zeros_like_impl import zeros_like @@ -24,6 +25,7 @@ from .grad_base import bprop_getters @bprop_getters.register(inner.StridedSliceAICPU) def get_bprop_strided_slice_aicpu(self): """Generate bprop for StridedSlice""" + shape_op = P.Shape() input_grad = G.StridedSliceGradAICPU(self.begin_mask, self.end_mask, self.ellipsis_mask, diff --git a/mindspore/ops/_op_impl/aicpu/strided_slice.py b/mindspore/ops/_op_impl/aicpu/strided_slice.py index b62a86f3f3..0506e4104d 100644 --- a/mindspore/ops/_op_impl/aicpu/strided_slice.py +++ b/mindspore/ops/_op_impl/aicpu/strided_slice.py @@ -28,11 +28,11 @@ strided_slice_op_info = AiCPURegOp("StridedSliceAICPU") \ .attr("ellipsis_mask", "int") \ .attr("new_axis_mask", "int") \ .attr("shrink_axis_mask", "int") \ - .dtype_format(DataType.F32_NCHW, - DataType.I32_NCHW, - DataType.I32_NCHW, - DataType.I32_NCHW, - DataType.F32_NCHW) \ + .dtype_format(DataType.F32_Default, + DataType.I32_Default, + DataType.I32_Default, + DataType.I32_Default, + DataType.F32_Default) \ .get_op_info() @op_info_register(strided_slice_op_info) diff --git a/mindspore/ops/_op_impl/aicpu/strided_slice_grad.py b/mindspore/ops/_op_impl/aicpu/strided_slice_grad.py index f1ce9319c4..b94c5d4c4b 100644 --- a/mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +++ b/mindspore/ops/_op_impl/aicpu/strided_slice_grad.py @@ -29,12 +29,12 @@ strided_slice_grad_op_info = AiCPURegOp("StridedSliceGradAICPU") \ .attr("ellipsis_mask", "int") \ .attr("new_axis_mask", "int") \ .attr("shrink_axis_mask", "int") \ - .dtype_format(DataType.F32_NCHW, - DataType.I32_NCHW, - DataType.I32_NCHW, - DataType.I32_NCHW, - DataType.I32_NCHW, - DataType.F32_NCHW) \ + .dtype_format(DataType.F32_Default, + DataType.I32_Default, + DataType.I32_Default, + DataType.I32_Default, + DataType.I32_Default, + DataType.F32_Default) \ .get_op_info() @op_info_register(strided_slice_grad_op_info)