Merge pull request !3317 from zhangbuxue/add_check_for_stridedslice_when_choose_aicpu_or_aicoretags/v1.0.0
| @@ -22,25 +22,58 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| using Tensor = mindspore::tensor::Tensor; | |||||
| using TensorPtr = mindspore::tensor::TensorPtr; | |||||
| using AbstractTensor = mindspore::abstract::AbstractTensor; | |||||
| using AbstractTensorPtr = mindspore::abstract::AbstractTensorPtr; | |||||
| using CheckSupportFun = bool (*)(const CNodePtr &cnode); | using CheckSupportFun = bool (*)(const CNodePtr &cnode); | ||||
| constexpr char kAttrStrides[] = "strides"; | constexpr char kAttrStrides[] = "strides"; | ||||
| constexpr char kAttrShrinkAxisMask[] = "shrink_axis_mask"; | |||||
| static bool CheckStridedSlice(const CNodePtr &cnode) { | static bool CheckStridedSlice(const CNodePtr &cnode) { | ||||
| // check stride[-1] != 1 TODO | |||||
| // check stride[-1] != 1 | |||||
| if (AnfAlgo::HasNodeAttr(kAttrStrides, cnode)) { | if (AnfAlgo::HasNodeAttr(kAttrStrides, cnode)) { | ||||
| auto strides = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrStrides); | auto strides = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrStrides); | ||||
| if (!strides.empty() && strides[strides.size() - 1] == 1) { | |||||
| return true; | |||||
| if (!strides.empty() && strides[strides.size() - 1] != 1) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| // check reduction on the last dimension | |||||
| if (AnfAlgo::HasNodeAttr(kAttrShrinkAxisMask, cnode)) { | |||||
| auto shrink_axis_mask = AnfAlgo::GetNodeAttr<int>(cnode, kAttrShrinkAxisMask); | |||||
| AnfNodePtr input = cnode->input(1); | |||||
| int input_dims = 0; | |||||
| if (input->isa<ValueNode>()) { | |||||
| ValuePtr input_value = input->cast<ValueNodePtr>()->value(); | |||||
| if (!input_value->isa<Tensor>()) { | |||||
| MS_LOG(EXCEPTION) << "For 'StrideSlice', the first input value should be a tensor, but got " | |||||
| << input_value->ToString(); | |||||
| } | |||||
| input_dims = SizeToInt(input_value->cast<TensorPtr>()->shape().size()); | |||||
| } else if (input->isa<CNode>() || input->isa<Parameter>()) { | |||||
| AbstractBasePtr input_abstract = input->abstract(); | |||||
| if (!input_abstract->isa<AbstractTensor>()) { | |||||
| MS_LOG(EXCEPTION) << "For 'StrideSlice', the first input value should be a tensor, but got " | |||||
| << input_abstract->ToString(); | |||||
| } | |||||
| input_dims = SizeToInt(input_abstract->cast<AbstractTensorPtr>()->shape()->shape().size()); | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "For 'StrideSlice', the first input node should be a 'ValueNode' or a 'CNode', but got " | |||||
| << input->ToString(); | |||||
| } | |||||
| int base_number = 2; | |||||
| if (shrink_axis_mask >= std::pow<int, int>(base_number, input_dims - 1)) { | |||||
| return false; | |||||
| } | } | ||||
| } | } | ||||
| // last tensor TODO | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool TbePropertyChecker::CheckTbeProperties(const mindspore::CNodePtr &cnode) { | bool TbePropertyChecker::CheckTbeProperties(const mindspore::CNodePtr &cnode) { | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| static std::map<std::string, CheckSupportFun> tbe_property_checker = {{parallel::KStridedSlice, CheckStridedSlice}}; | |||||
| static std::map<std::string, CheckSupportFun> tbe_property_checker = {{kStridedSliceOpName, CheckStridedSlice}, | |||||
| {kStridedSliceGradOpName, CheckStridedSlice}}; | |||||
| auto cnode_type = AnfAlgo::GetCNodeName(cnode); | auto cnode_type = AnfAlgo::GetCNodeName(cnode); | ||||
| auto find_iter = tbe_property_checker.find(cnode_type); | auto find_iter = tbe_property_checker.find(cnode_type); | ||||
| if (find_iter != tbe_property_checker.end()) { | if (find_iter != tbe_property_checker.end()) { | ||||
| @@ -59,6 +59,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { | |||||
| Register(kScatterNdOpName, {2}); | Register(kScatterNdOpName, {2}); | ||||
| Register(kStridedSliceAssignOpName, {1, 2, 3}); | Register(kStridedSliceAssignOpName, {1, 2, 3}); | ||||
| Register(kStridedSliceOpName, {1, 2, 3}); | Register(kStridedSliceOpName, {1, 2, 3}); | ||||
| Register(kStridedSliceGradOpName, {1, 2, 3, 4}); | |||||
| Register(kFlattenGradOpName, {1}); | Register(kFlattenGradOpName, {1}); | ||||
| Register(kExpandDimsOpName, {1}); | Register(kExpandDimsOpName, {1}); | ||||
| Register(kSplitOpName, {0}); | Register(kSplitOpName, {0}); | ||||
| @@ -16,25 +16,27 @@ | |||||
| """StridedSlice op""" | """StridedSlice op""" | ||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | ||||
| strided_slice_op_info = AiCPURegOp("StridedSliceAICPU") \ | |||||
| strided_slice_op_info = AiCPURegOp("StridedSlice") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .input(0, "input", "required") \ | .input(0, "input", "required") \ | ||||
| .input(1, "begin", "required") \ | |||||
| .input(2, "end", "required") \ | |||||
| .input(3, "stride", "required") \ | |||||
| .output(0, "output", "required") \ | .output(0, "output", "required") \ | ||||
| .attr("begin", "listInt") \ | |||||
| .attr("end", "listInt") \ | |||||
| .attr("strides", "listInt") \ | |||||
| .attr("begin_mask", "int") \ | .attr("begin_mask", "int") \ | ||||
| .attr("end_mask", "int") \ | .attr("end_mask", "int") \ | ||||
| .attr("ellipsis_mask", "int") \ | .attr("ellipsis_mask", "int") \ | ||||
| .attr("new_axis_mask", "int") \ | .attr("new_axis_mask", "int") \ | ||||
| .attr("shrink_axis_mask", "int") \ | .attr("shrink_axis_mask", "int") \ | ||||
| .dtype_format(DataType.F32_Default, | |||||
| DataType.I32_Default, | |||||
| DataType.I32_Default, | |||||
| DataType.I32_Default, | |||||
| DataType.F32_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @op_info_register(strided_slice_op_info) | @op_info_register(strided_slice_op_info) | ||||
| def _strided_slice_aicpu(): | def _strided_slice_aicpu(): | ||||
| """StridedSlice AiCPU register""" | """StridedSlice AiCPU register""" | ||||
| @@ -16,27 +16,28 @@ | |||||
| """StridedSliceGrad op""" | """StridedSliceGrad op""" | ||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | ||||
| strided_slice_grad_op_info = AiCPURegOp("StridedSliceGradAICPU") \ | |||||
| strided_slice_grad_op_info = AiCPURegOp("StridedSliceGrad") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .input(0, "dy", "required") \ | .input(0, "dy", "required") \ | ||||
| .input(1, "shape", "required") \ | |||||
| .input(2, "begin", "required") \ | |||||
| .input(3, "end", "required") \ | |||||
| .input(4, "stride", "required") \ | |||||
| .output(0, "output", "required") \ | .output(0, "output", "required") \ | ||||
| .attr("shapex", "listInt") \ | |||||
| .attr("begin", "listInt") \ | |||||
| .attr("end", "listInt") \ | |||||
| .attr("strides", "listInt") \ | |||||
| .attr("begin_mask", "int") \ | .attr("begin_mask", "int") \ | ||||
| .attr("end_mask", "int") \ | .attr("end_mask", "int") \ | ||||
| .attr("ellipsis_mask", "int") \ | .attr("ellipsis_mask", "int") \ | ||||
| .attr("new_axis_mask", "int") \ | .attr("new_axis_mask", "int") \ | ||||
| .attr("shrink_axis_mask", "int") \ | .attr("shrink_axis_mask", "int") \ | ||||
| .dtype_format(DataType.F32_Default, | |||||
| DataType.I32_Default, | |||||
| DataType.I32_Default, | |||||
| DataType.I32_Default, | |||||
| DataType.I32_Default, | |||||
| DataType.F32_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @op_info_register(strided_slice_grad_op_info) | @op_info_register(strided_slice_grad_op_info) | ||||
| def _strided_slice_grad_aicpu(): | def _strided_slice_grad_aicpu(): | ||||
| """StridedSliceGrad AiCPU register""" | """StridedSliceGrad AiCPU register""" | ||||
| @@ -915,13 +915,14 @@ test_case_math_ops = [ | |||||
| 'block': G.MinimumGrad(), | 'block': G.MinimumGrad(), | ||||
| 'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5], [2, 3, 3, 5]], | 'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5], [2, 3, 3, 5]], | ||||
| 'skip': ['backward']}), | 'skip': ['backward']}), | ||||
| ('StridedSlice', { | |||||
| 'block': P.StridedSlice(), | |||||
| ('StridedSlice_00', { | |||||
| 'block': P.StridedSlice(shrink_axis_mask=0), | |||||
| 'desc_const': [(0, 1, 2, 1), | 'desc_const': [(0, 1, 2, 1), | ||||
| (2, 3, 3, 4), | (2, 3, 3, 4), | ||||
| (1, 1, 1, 1)], | |||||
| (1, 1, 1, 2)], | |||||
| 'desc_inputs': [[2, 3, 3, 5]], | 'desc_inputs': [[2, 3, 3, 5]], | ||||
| 'desc_bprop': [[2, 2, 1, 3]]}), | |||||
| 'desc_bprop': [[2, 2, 1, 3]], | |||||
| 'skip': ['backward']}), | |||||
| ('Slice_1', { | ('Slice_1', { | ||||
| 'block': P.Slice(), | 'block': P.Slice(), | ||||
| 'desc_const': [(0, 1, 2, 1), | 'desc_const': [(0, 1, 2, 1), | ||||