Merge pull request !3317 from zhangbuxue/add_check_for_stridedslice_when_choose_aicpu_or_aicoretags/v1.0.0
| @@ -22,25 +22,58 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| using Tensor = mindspore::tensor::Tensor; | |||
| using TensorPtr = mindspore::tensor::TensorPtr; | |||
| using AbstractTensor = mindspore::abstract::AbstractTensor; | |||
| using AbstractTensorPtr = mindspore::abstract::AbstractTensorPtr; | |||
| using CheckSupportFun = bool (*)(const CNodePtr &cnode); | |||
| constexpr char kAttrStrides[] = "strides"; | |||
| constexpr char kAttrShrinkAxisMask[] = "shrink_axis_mask"; | |||
| static bool CheckStridedSlice(const CNodePtr &cnode) { | |||
| // check stride[-1] != 1 TODO | |||
| // check stride[-1] != 1 | |||
| if (AnfAlgo::HasNodeAttr(kAttrStrides, cnode)) { | |||
| auto strides = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrStrides); | |||
| if (!strides.empty() && strides[strides.size() - 1] == 1) { | |||
| return true; | |||
| if (!strides.empty() && strides[strides.size() - 1] != 1) { | |||
| return false; | |||
| } | |||
| } | |||
| // check reduction on the last dimension | |||
| if (AnfAlgo::HasNodeAttr(kAttrShrinkAxisMask, cnode)) { | |||
| auto shrink_axis_mask = AnfAlgo::GetNodeAttr<int>(cnode, kAttrShrinkAxisMask); | |||
| AnfNodePtr input = cnode->input(1); | |||
| int input_dims = 0; | |||
| if (input->isa<ValueNode>()) { | |||
| ValuePtr input_value = input->cast<ValueNodePtr>()->value(); | |||
| if (!input_value->isa<Tensor>()) { | |||
| MS_LOG(EXCEPTION) << "For 'StrideSlice', the first input value should be a tensor, but got " | |||
| << input_value->ToString(); | |||
| } | |||
| input_dims = SizeToInt(input_value->cast<TensorPtr>()->shape().size()); | |||
| } else if (input->isa<CNode>() || input->isa<Parameter>()) { | |||
| AbstractBasePtr input_abstract = input->abstract(); | |||
| if (!input_abstract->isa<AbstractTensor>()) { | |||
| MS_LOG(EXCEPTION) << "For 'StrideSlice', the first input value should be a tensor, but got " | |||
| << input_abstract->ToString(); | |||
| } | |||
| input_dims = SizeToInt(input_abstract->cast<AbstractTensorPtr>()->shape()->shape().size()); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "For 'StrideSlice', the first input node should be a 'ValueNode' or a 'CNode', but got " | |||
| << input->ToString(); | |||
| } | |||
| int base_number = 2; | |||
| if (shrink_axis_mask >= std::pow<int, int>(base_number, input_dims - 1)) { | |||
| return false; | |||
| } | |||
| } | |||
| // last tensor TODO | |||
| return true; | |||
| } | |||
| bool TbePropertyChecker::CheckTbeProperties(const mindspore::CNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| static std::map<std::string, CheckSupportFun> tbe_property_checker = {{parallel::KStridedSlice, CheckStridedSlice}}; | |||
| static std::map<std::string, CheckSupportFun> tbe_property_checker = {{kStridedSliceOpName, CheckStridedSlice}, | |||
| {kStridedSliceGradOpName, CheckStridedSlice}}; | |||
| auto cnode_type = AnfAlgo::GetCNodeName(cnode); | |||
| auto find_iter = tbe_property_checker.find(cnode_type); | |||
| if (find_iter != tbe_property_checker.end()) { | |||
| @@ -59,6 +59,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { | |||
| Register(kScatterNdOpName, {2}); | |||
| Register(kStridedSliceAssignOpName, {1, 2, 3}); | |||
| Register(kStridedSliceOpName, {1, 2, 3}); | |||
| Register(kStridedSliceGradOpName, {1, 2, 3, 4}); | |||
| Register(kFlattenGradOpName, {1}); | |||
| Register(kExpandDimsOpName, {1}); | |||
| Register(kSplitOpName, {0}); | |||
| @@ -16,25 +16,27 @@ | |||
| """StridedSlice op""" | |||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||
| strided_slice_op_info = AiCPURegOp("StridedSliceAICPU") \ | |||
| strided_slice_op_info = AiCPURegOp("StridedSlice") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .input(0, "input", "required") \ | |||
| .input(1, "begin", "required") \ | |||
| .input(2, "end", "required") \ | |||
| .input(3, "stride", "required") \ | |||
| .output(0, "output", "required") \ | |||
| .attr("begin", "listInt") \ | |||
| .attr("end", "listInt") \ | |||
| .attr("strides", "listInt") \ | |||
| .attr("begin_mask", "int") \ | |||
| .attr("end_mask", "int") \ | |||
| .attr("ellipsis_mask", "int") \ | |||
| .attr("new_axis_mask", "int") \ | |||
| .attr("shrink_axis_mask", "int") \ | |||
| .dtype_format(DataType.F32_Default, | |||
| DataType.I32_Default, | |||
| DataType.I32_Default, | |||
| DataType.I32_Default, | |||
| DataType.F32_Default) \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(strided_slice_op_info) | |||
| def _strided_slice_aicpu(): | |||
| """StridedSlice AiCPU register""" | |||
| @@ -16,27 +16,28 @@ | |||
| """StridedSliceGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||
| strided_slice_grad_op_info = AiCPURegOp("StridedSliceGradAICPU") \ | |||
| strided_slice_grad_op_info = AiCPURegOp("StridedSliceGrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .input(0, "dy", "required") \ | |||
| .input(1, "shape", "required") \ | |||
| .input(2, "begin", "required") \ | |||
| .input(3, "end", "required") \ | |||
| .input(4, "stride", "required") \ | |||
| .output(0, "output", "required") \ | |||
| .attr("shapex", "listInt") \ | |||
| .attr("begin", "listInt") \ | |||
| .attr("end", "listInt") \ | |||
| .attr("strides", "listInt") \ | |||
| .attr("begin_mask", "int") \ | |||
| .attr("end_mask", "int") \ | |||
| .attr("ellipsis_mask", "int") \ | |||
| .attr("new_axis_mask", "int") \ | |||
| .attr("shrink_axis_mask", "int") \ | |||
| .dtype_format(DataType.F32_Default, | |||
| DataType.I32_Default, | |||
| DataType.I32_Default, | |||
| DataType.I32_Default, | |||
| DataType.I32_Default, | |||
| DataType.F32_Default) \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(strided_slice_grad_op_info) | |||
| def _strided_slice_grad_aicpu(): | |||
| """StridedSliceGrad AiCPU register""" | |||
| @@ -915,13 +915,14 @@ test_case_math_ops = [ | |||
| 'block': G.MinimumGrad(), | |||
| 'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5], [2, 3, 3, 5]], | |||
| 'skip': ['backward']}), | |||
| ('StridedSlice', { | |||
| 'block': P.StridedSlice(), | |||
| ('StridedSlice_00', { | |||
| 'block': P.StridedSlice(shrink_axis_mask=0), | |||
| 'desc_const': [(0, 1, 2, 1), | |||
| (2, 3, 3, 4), | |||
| (1, 1, 1, 1)], | |||
| (1, 1, 1, 2)], | |||
| 'desc_inputs': [[2, 3, 3, 5]], | |||
| 'desc_bprop': [[2, 2, 1, 3]]}), | |||
| 'desc_bprop': [[2, 2, 1, 3]], | |||
| 'skip': ['backward']}), | |||
| ('Slice_1', { | |||
| 'block': P.Slice(), | |||
| 'desc_const': [(0, 1, 2, 1), | |||