| @@ -15,6 +15,7 @@ | |||||
| """array_ops""" | """array_ops""" | ||||
| from .. import operations as P | |||||
| from ..operations import _grad_ops as G | from ..operations import _grad_ops as G | ||||
| from ..operations import _inner_ops as inner | from ..operations import _inner_ops as inner | ||||
| from ..composite.multitype_ops.zeros_like_impl import zeros_like | from ..composite.multitype_ops.zeros_like_impl import zeros_like | ||||
| @@ -24,6 +25,7 @@ from .grad_base import bprop_getters | |||||
| @bprop_getters.register(inner.StridedSliceAICPU) | @bprop_getters.register(inner.StridedSliceAICPU) | ||||
| def get_bprop_strided_slice_aicpu(self): | def get_bprop_strided_slice_aicpu(self): | ||||
| """Generate bprop for StridedSlice""" | """Generate bprop for StridedSlice""" | ||||
| shape_op = P.Shape() | |||||
| input_grad = G.StridedSliceGradAICPU(self.begin_mask, | input_grad = G.StridedSliceGradAICPU(self.begin_mask, | ||||
| self.end_mask, | self.end_mask, | ||||
| self.ellipsis_mask, | self.ellipsis_mask, | ||||
| @@ -28,11 +28,11 @@ strided_slice_op_info = AiCPURegOp("StridedSliceAICPU") \ | |||||
| .attr("ellipsis_mask", "int") \ | .attr("ellipsis_mask", "int") \ | ||||
| .attr("new_axis_mask", "int") \ | .attr("new_axis_mask", "int") \ | ||||
| .attr("shrink_axis_mask", "int") \ | .attr("shrink_axis_mask", "int") \ | ||||
| .dtype_format(DataType.F32_NCHW, | |||||
| DataType.I32_NCHW, | |||||
| DataType.I32_NCHW, | |||||
| DataType.I32_NCHW, | |||||
| DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F32_Default, | |||||
| DataType.I32_Default, | |||||
| DataType.I32_Default, | |||||
| DataType.I32_Default, | |||||
| DataType.F32_Default) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @op_info_register(strided_slice_op_info) | @op_info_register(strided_slice_op_info) | ||||
| @@ -29,12 +29,12 @@ strided_slice_grad_op_info = AiCPURegOp("StridedSliceGradAICPU") \ | |||||
| .attr("ellipsis_mask", "int") \ | .attr("ellipsis_mask", "int") \ | ||||
| .attr("new_axis_mask", "int") \ | .attr("new_axis_mask", "int") \ | ||||
| .attr("shrink_axis_mask", "int") \ | .attr("shrink_axis_mask", "int") \ | ||||
| .dtype_format(DataType.F32_NCHW, | |||||
| DataType.I32_NCHW, | |||||
| DataType.I32_NCHW, | |||||
| DataType.I32_NCHW, | |||||
| DataType.I32_NCHW, | |||||
| DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F32_Default, | |||||
| DataType.I32_Default, | |||||
| DataType.I32_Default, | |||||
| DataType.I32_Default, | |||||
| DataType.I32_Default, | |||||
| DataType.F32_Default) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @op_info_register(strided_slice_grad_op_info) | @op_info_register(strided_slice_grad_op_info) | ||||