@@ -145,6 +145,8 @@ const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape");
 const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile");
 const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN");
 const PrimitivePtr KPrimTransData = std::make_shared<Primitive>("TransData");
+const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask");
+const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad");
 
 // Maths
 const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd");
@@ -151,6 +151,8 @@ extern const PrimitivePtr kPrimReshape;
 extern const PrimitivePtr kPrimTile;
 extern const PrimitivePtr kPrimAddN;
 extern const PrimitivePtr KPrimTransData;
+extern const PrimitivePtr kPrimNMSWithMask;
+extern const PrimitivePtr kPrimPad;
 
 // Maths
 extern const PrimitivePtr kPrimTensorAdd;
@@ -78,6 +78,7 @@
 #include "pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h"
 #include "pre_activate/ascend/format_type/deal_ref_trans_and_cast.h"
 #include "pre_activate/ascend/enhancer/add_memcpy_async.h"
+#include "pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h"
 #include "pre_activate/ascend/format_type/insert_cast_for_runop.h"
 #include "pre_activate/ascend/format_type/insert_transdata_for_runop.h"
 #include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h"
@@ -227,6 +228,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
     ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion>());
   }
   ir_fusion_pm->AddPass(std::make_shared<AddMemcpyAsync>());
+  ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
   if (context_ptr->ir_fusion_flag()) {
     AddAscendBackendOptionalIRFusion(ir_fusion_pm.get());
   }
@@ -0,0 +1,109 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
| #include "pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h" | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include "pre_activate/ascend/ascend_helper.h" | |||||
| #include "pre_activate/common/helper.h" | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "utils/utils.h" | |||||
| #include "device/kernel_info.h" | |||||
| #include "kernel//oplib/oplib.h" | |||||
| #include "operator/ops.h" | |||||
+namespace mindspore {
+namespace opt {
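+// Match any NMSWithMask node; the SeqVar binds its variable-length inputs.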
+const BaseRef InsertPadForNMSWithMask::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  return VectorRef({prim::kPrimNMSWithMask, Xs});
+}
+
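+// Build a Pad CNode in front of `input`, attaching a kernel build info that
+// prefers the TBE Pad kernel when one is registered and falls back to AICPU.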
+AnfNodePtr InsertPadToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
+                            const TypeId &input_type, const TypeId &output_type, const TypeId &origin_type,
+                            const std::vector<size_t> &origin_shape) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  std::vector<AnfNodePtr> new_pad_inputs;
+  auto prim = std::make_shared<Primitive>(prim::kPrimPad->name());
+  new_pad_inputs.push_back(NewValueNode(prim));
+  new_pad_inputs.push_back(input);
+  CNodePtr pad = func_graph->NewCNode(new_pad_inputs);
+  MS_EXCEPTION_IF_NULL(pad);
+  // set kernel build info
+  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+  builder.SetInputsFormat({format});
+  builder.SetOutputsFormat({format});
+  builder.SetInputsDeviceType({input_type});
+  builder.SetOutputsDeviceType({output_type});
+  builder.SetFusionType(kernel::FusionType::OPAQUE);
+  builder.SetProcessor(kernel::Processor::AICORE);
+  if (kernel::OpLib::FindOp(prim::kPrimPad->name(), kernel::kTBE) != nullptr) {
+    builder.SetKernelType(KernelType::TBE_KERNEL);
+  } else {
+    builder.SetKernelType(KernelType::AICPU_KERNEL);
+  }
+  if (pad->kernel_info() == nullptr) {
+    auto kernel_info = std::make_shared<device::KernelInfo>();
+    pad->set_kernel_info(kernel_info);
+  }
+  AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), pad.get());
+  AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, pad.get());
+  return pad;
+}
+
+const AnfNodePtr InsertPadForNMSWithMask::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                  const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  size_t input_num = AnfAlgo::GetInputTensorNum(node);
+  if (input_num == 0) {
+    return nullptr;
+  }
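+  // Widen every (N, 5) box input to (N, 8) by padding three zero columns, the
+  // layout the non-GE NMSWithMask kernel expects (see the infer_shape change below).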
+  std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
+  for (size_t input_idx = 0; input_idx < input_num; input_idx++) {
+    auto cur_input = AnfAlgo::GetInputNode(cnode, input_idx);
+    auto origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_idx);
+    auto format = AnfAlgo::GetPrevNodeOutputFormat(cnode, input_idx);
+    auto origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_idx);
+    if (!(origin_shape.size() == 2 && origin_shape[1] == 5)) {
+      return nullptr;
+    }
+    origin_shape[1] = 8;
+    auto device_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_idx);
+    auto pad = InsertPadToGraph(func_graph, cur_input, format, origin_type, device_type, origin_type, origin_shape);
+    MS_EXCEPTION_IF_NULL(pad);
+    pad->set_scope(cnode->scope());
+    AnfAlgo::SetNodeAttr("paddings", MakeValue(std::vector<std::vector<int>>{{0, 0}, {0, 3}}), pad);
+    new_inputs.push_back(pad);
+  }
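+  // Re-create the NMSWithMask node so it consumes the padded inputs.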
+  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
+  CNodePtr new_node = nullptr;
+  if (kernel_graph == nullptr) {
+    new_node = std::make_shared<CNode>(*cnode);
+  } else {
+    new_node = kernel_graph->NewCNode(cnode);
+  }
+  MS_EXCEPTION_IF_NULL(new_node);
+  new_node->set_inputs(new_inputs);
+  return new_node;
+}
+}  // namespace opt
+}  // namespace mindspore
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H
+
+#include "pre_activate/common/optimizer.h"
+#include "pre_activate/common/pass.h"
+
+namespace mindspore {
+namespace opt {
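+// Pass that pads every (N, 5) input of NMSWithMask out to (N, 8) by inserting
+// a Pad node in front of it.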
+class InsertPadForNMSWithMask : public PatternProcessPass {
+ public:
+  explicit InsertPadForNMSWithMask(bool multigraph = true)
+      : PatternProcessPass("insert_pad_for_nms_with_mask", multigraph) {}
+  ~InsertPadForNMSWithMask() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H
@@ -148,6 +148,7 @@ constexpr auto kReturnOpName = "return";
 constexpr auto kLarsV2OpName = "LarsV2";
 constexpr auto kLarsV2UpdateOpName = "LarsV2Update";
 constexpr auto kSquareSumAllOpName = "SquareSumAll";
+constexpr auto kNMSWithMaskOpName = "NMSWithMask";
 
 // attr key name
 constexpr auto kAttrInputNames = "input_names";
@@ -2021,11 +2021,6 @@ class NMSWithMask(PrimitiveWithInfer):
         cls_name = self.name
         validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name)
         validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name)
-        if not self.is_ge:
-            validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 8, Rel.EQ, cls_name)
-            num = bboxes_shape[0]
-            return ((num, 5), (num,), (num,))
-
         validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name)
         num = bboxes_shape[0]
         return (bboxes_shape, (num,), (num,))