Merge pull request !1146 from Etone.Chan/mastertags/v0.3.0-alpha
| @@ -580,27 +580,43 @@ void BufferFusion::MatchDepthwiseConvRelu(const CNodePtr &cnode, const session:: | |||||
| } | } | ||||
| } | } | ||||
| void BufferFusion::MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input, | |||||
| const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| MS_EXCEPTION_IF_NULL(candidate_fusion); | |||||
| auto manager = kernel_graph.manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| std::vector<int> output_used_num{SizeToInt(manager->node_users()[relu_input].size())}; | |||||
| AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu_input); | |||||
| std::unordered_set<AnfNodePtr> record{cnode, relu_input}; | |||||
| candidate_fusion->push_back(record); | |||||
| SetRecordFusionId(record); | |||||
| } | |||||
| void BufferFusion::MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { | void BufferFusion::MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { | ||||
| MS_EXCEPTION_IF_NULL(candidate_fusion); | MS_EXCEPTION_IF_NULL(candidate_fusion); | ||||
| std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return()); | std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return()); | ||||
| for (auto &node : node_list) { | for (auto &node : node_list) { | ||||
| if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator.HasFusionIdAttr(node)) { | |||||
| if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator.HasFusionIdAttr(node) || | |||||
| AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| if (AnfAlgo::GetCNodeName(cnode) == kBNTrainingReduceOpName) { | if (AnfAlgo::GetCNodeName(cnode) == kBNTrainingReduceOpName) { | ||||
| MatchConvBnreduce(cnode, kernel_graph, candidate_fusion); | MatchConvBnreduce(cnode, kernel_graph, candidate_fusion); | ||||
| } else if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || | |||||
| AnfAlgo::GetCNodeName(cnode) == prim::kPrimRelu->name()) { | |||||
| auto relu_input = cnode->input(1); | |||||
| if (relu_input->isa<CNode>() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimTensorAdd->name()) { | |||||
| MatchBnupdateAddRelu(cnode, relu_input, kernel_graph, candidate_fusion); | |||||
| } else if (relu_input->isa<CNode>() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimTupleGetItem->name()) { | |||||
| MatchBnupdateRelu(cnode, relu_input, kernel_graph, candidate_fusion); | |||||
| } else if (relu_input->isa<CNode>() && | |||||
| AnfAlgo::GetCNodeName(relu_input) == prim::kPrimDepthwiseConv2dNative->name()) { | |||||
| MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true); | |||||
| } else if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && | |||||
| AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { | |||||
| auto eltwise_input = cnode->input(1); | |||||
| if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimMatMul)) { | |||||
| MatchMatmulEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion); | |||||
| } | |||||
| if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimRelu)) { | |||||
| if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) { | |||||
| MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); | |||||
| } else if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTupleGetItem)) { | |||||
| MatchBnupdateRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); | |||||
| } | |||||
| } | } | ||||
| } else if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimDepthwiseConv2dNative->name()) { | } else if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimDepthwiseConv2dNative->name()) { | ||||
| MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, false); | MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, false); | ||||
| @@ -53,6 +53,8 @@ class BufferFusion : public Pass { | |||||
| const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion); | const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion); | ||||
| void MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, | void MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, | ||||
| FusedNodeRecord *candidate_fusion, bool is_order); | FusedNodeRecord *candidate_fusion, bool is_order); | ||||
| void MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph, | |||||
| FusedNodeRecord *candidate_fusion); | |||||
| void MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion); | void MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion); | ||||
| void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion); | void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion); | ||||
| @@ -33,6 +33,7 @@ relu_op_info = TBERegOp("ReLU") \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | ||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | ||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | ||||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -26,15 +26,15 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; | using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; | ||||
| class TestHWBufferFusionPy : public BackendCommon { | |||||
| class TestHWBufferFusion : public BackendCommon { | |||||
| public: | public: | ||||
| TestHWBufferFusionPy() : get_py_fun_("gtest_input.pre_activate.buffer_fusion_test", true) {} | |||||
| ~TestHWBufferFusionPy() override = default; | |||||
| TestHWBufferFusion() : get_py_fun_("gtest_input.pre_activate.buffer_fusion_test", true) {} | |||||
| ~TestHWBufferFusion() override = default; | |||||
| UT::PyFuncGraphFetcher get_py_fun_; | UT::PyFuncGraphFetcher get_py_fun_; | ||||
| }; | }; | ||||
| TEST_F(TestHWBufferFusionPy, test_tbe_eltwise_fusion_1) { | |||||
| TEST_F(TestHWBufferFusion, test_tbe_eltwise_fusion_1) { | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_eltwise_fusion_1", "before"); | FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_eltwise_fusion_1", "before"); | ||||
| std::vector<int> shp{2, 32, 224, 224}; | std::vector<int> shp{2, 32, 224, 224}; | ||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | ||||
| @@ -90,7 +90,7 @@ TEST_F(TestHWBufferFusionPy, test_tbe_eltwise_fusion_1) { | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | ||||
| } | } | ||||
| TEST_F(TestHWBufferFusionPy, test_tbe_eltwise_fusion_2) { | |||||
| TEST_F(TestHWBufferFusion, test_tbe_eltwise_fusion_2) { | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_eltwise_fusion_2", "before"); | FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_eltwise_fusion_2", "before"); | ||||
| std::vector<int> shp{32, 10}; | std::vector<int> shp{32, 10}; | ||||
| std::vector<int> shp_bias{10}; | std::vector<int> shp_bias{10}; | ||||
| @@ -179,7 +179,7 @@ TEST_F(TestHWBufferFusionPy, test_tbe_eltwise_fusion_2) { | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | ||||
| } | } | ||||
| TEST_F(TestHWBufferFusionPy, test_tbe_reduce_eltwise_fusion) { | |||||
| TEST_F(TestHWBufferFusion, test_tbe_reduce_eltwise_fusion) { | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_reduce_eltwise_fusion", "before"); | FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_reduce_eltwise_fusion", "before"); | ||||
| std::vector<int> shp{32, 10}; | std::vector<int> shp{32, 10}; | ||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | ||||
| @@ -265,5 +265,71 @@ TEST_F(TestHWBufferFusionPy, test_tbe_reduce_eltwise_fusion) { | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_tbe_reduce_eltwise_fusion", "after"); | FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_tbe_reduce_eltwise_fusion", "after"); | ||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | ||||
| } | } | ||||
| TEST_F(TestHWBufferFusion, test_tbe_matmul_eltwise_fusion) { | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_matmul_eltwise_fusion", "before"); | |||||
| std::vector<int> x_shp{2048, 768}; | |||||
| std::vector<int> y_shp{768, 768}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, x_shp); | |||||
| auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, y_shp); | |||||
| AbstractBasePtrList args_spec_list{x_abstract, y_abstract}; | |||||
| auto kg = GetKernelGraph(g, args_spec_list); | |||||
| auto ret = kg->get_return(); | |||||
| EXPECT_NE(ret->input(1), nullptr); | |||||
| auto tuple = ret->input(1); | |||||
| EXPECT_NE(tuple, nullptr); | |||||
| auto cast = tuple->cast<CNodePtr>()->input(1); | |||||
| EXPECT_NE(cast, nullptr); | |||||
| auto relu = cast->cast<CNodePtr>()->input(1); | |||||
| EXPECT_NE(relu, nullptr); | |||||
| auto matmul = relu->cast<CNodePtr>()->input(1); | |||||
| KernelBuildInfoBuilder builder; | |||||
| builder.SetInputsFormat({"NC1HWC0"}); | |||||
| builder.SetOutputsFormat({"NC1HWC0"}); | |||||
| builder.SetInputsDeviceType({kFloat32->type_id()}); | |||||
| builder.SetOutputsDeviceType({kFloat32->type_id()}); | |||||
| builder.SetKernelType(KernelType::TBE_KERNEL); | |||||
| builder.SetFusionType(kernel::FusionType::ELEMWISE); | |||||
| builder.SetProcessor(kernel::Processor::AICORE); | |||||
| builder.SetKernelType(KernelType::TBE_KERNEL); | |||||
| relu->set_kernel_info(std::make_shared<device::KernelInfo>()); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), relu.get()); | |||||
| KernelBuildInfoBuilder builder2; | |||||
| builder2.SetInputsFormat({"NC1HWC0", "NC1HWC0"}); | |||||
| builder2.SetOutputsFormat({"NC1HWC0"}); | |||||
| builder2.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()}); | |||||
| builder2.SetOutputsDeviceType({kFloat32->type_id()}); | |||||
| builder2.SetKernelType(KernelType::TBE_KERNEL); | |||||
| builder2.SetFusionType(kernel::FusionType::OPAQUE); | |||||
| builder2.SetProcessor(kernel::Processor::AICORE); | |||||
| builder2.SetKernelType(KernelType::TBE_KERNEL); | |||||
| matmul->set_kernel_info(std::make_shared<device::KernelInfo>()); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(builder2.Build(), matmul.get()); | |||||
| KernelBuildInfoBuilder builder1; | |||||
| builder1.SetInputsFormat({"NC1HWC0"}); | |||||
| builder1.SetOutputsFormat({"NC1HWC0"}); | |||||
| builder1.SetInputsDeviceType({kFloat32->type_id()}); | |||||
| builder1.SetOutputsDeviceType({kFloat16->type_id()}); | |||||
| builder1.SetKernelType(KernelType::TBE_KERNEL); | |||||
| builder1.SetFusionType(kernel::FusionType::OPAQUE); | |||||
| builder1.SetProcessor(kernel::Processor::AICORE); | |||||
| builder1.SetKernelType(KernelType::TBE_KERNEL); | |||||
| cast->set_kernel_info(std::make_shared<device::KernelInfo>()); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), cast.get()); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| auto buffer_fusion_pass = std::make_shared<opt::BufferFusion>(); | |||||
| pm->AddPass(buffer_fusion_pass); | |||||
| optimizer->AddPassManager(pm); | |||||
| FuncGraphPtr new_graph = optimizer->Optimize(kg); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_tbe_matmul_eltwise_fusion", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||||
| } | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -24,10 +24,12 @@ Reduce = P.ReduceOp() | |||||
| Biasadd = P.BiasAdd() | Biasadd = P.BiasAdd() | ||||
| Biasaddgrad = G.BiasAddGrad() | Biasaddgrad = G.BiasAddGrad() | ||||
| Cast = P.Cast() | Cast = P.Cast() | ||||
| MatMul = P.MatMul() | |||||
| Fusion_relu_relu = Primitive('FusionOp_ReLU_ReLU') | Fusion_relu_relu = Primitive('FusionOp_ReLU_ReLU') | ||||
| Fusion_biasadd = Primitive('FusionOp_ReLU_ReLU_ReLU_BiasAdd_ReLU_ReLU_ReLU') | Fusion_biasadd = Primitive('FusionOp_ReLU_ReLU_ReLU_BiasAdd_ReLU_ReLU_ReLU') | ||||
| Fusion_biasaddgrad = Primitive('FusionOp_ReLU_ReLU_ReLU_BiasAddGrad_ReLU_ReLU_ReLU') | Fusion_biasaddgrad = Primitive('FusionOp_ReLU_ReLU_ReLU_BiasAddGrad_ReLU_ReLU_ReLU') | ||||
| Fusion_matmul_relu = Primitive('FusionOp_MatMul_ReLU') | |||||
| Add = P.TensorAdd() | Add = P.TensorAdd() | ||||
| Sub = P.Sub() | Sub = P.Sub() | ||||
| @@ -133,3 +135,23 @@ def test_conv_singlein_fusion(tag): | |||||
| return tuple | return tuple | ||||
| return fns[tag] | return fns[tag] | ||||
| def test_tbe_matmul_eltwise_fusion(tag): | |||||
| fns = FnDict() | |||||
| @fns | |||||
| def before(x, y): | |||||
| matmul = MatMul(x, y) | |||||
| relu = Relu(matmul) | |||||
| res = Cast(relu, mstype.float16) | |||||
| return res | |||||
| @fns | |||||
| def after(x, y): | |||||
| fusion = Fusion_matmul_relu(x, y) | |||||
| res = Cast(fusion) | |||||
| tuple = make_tuple(res) | |||||
| return tuple | |||||
| return fns[tag] | |||||