| @@ -100,6 +100,7 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { | |||||
| ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>()); | ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>()); | ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<SoftmaxGradExtFusion>()); | ir_fusion_pm->AddPass(std::make_shared<SoftmaxGradExtFusion>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<ConfusionMulGradFusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<ConfusionSoftmaxGradRule>()); | ir_fusion_pm->AddPass(std::make_shared<ConfusionSoftmaxGradRule>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond1>()); | ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond1>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond2>()); | ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond2>()); | ||||
| @@ -74,10 +74,21 @@ AnfNodePtr GetMul0(const FuncGraphPtr &graph, const AnfNodePtr &input2, const An | |||||
| } | } | ||||
| bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const AnfNodePtr &mul1_anf, | bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const AnfNodePtr &mul1_anf, | ||||
| const AnfNodePtr &reduce_sum) { | |||||
| const AnfNodePtr &reduce_sum, const AnfNodePtr &input2) { | |||||
| MS_EXCEPTION_IF_NULL(mul0_anf); | MS_EXCEPTION_IF_NULL(mul0_anf); | ||||
| MS_EXCEPTION_IF_NULL(mul1_anf); | MS_EXCEPTION_IF_NULL(mul1_anf); | ||||
| MS_EXCEPTION_IF_NULL(reduce_sum); | MS_EXCEPTION_IF_NULL(reduce_sum); | ||||
| MS_EXCEPTION_IF_NULL(input2); | |||||
| auto addn = input2->cast<CNodePtr>(); | |||||
| if (addn == nullptr || AnfAlgo::GetCNodeName(addn) != prim::kPrimAddN->name()) { | |||||
| MS_LOG(INFO) << "mul's second input is not addn"; | |||||
| return true; | |||||
| } | |||||
| std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(addn, 0); | |||||
| if (shape.size() != 2 || !(shape[1] == 1024 || shape[1] == 768)) { | |||||
| MS_LOG(INFO) << "Addn's infer shape is not equal [x,1024] or [x,768]"; | |||||
| return true; | |||||
| } | |||||
| if (!mul0_anf->isa<CNode>() || !mul1_anf->isa<CNode>()) { | if (!mul0_anf->isa<CNode>() || !mul1_anf->isa<CNode>()) { | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -86,11 +97,6 @@ bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const Anf | |||||
| auto mul0 = mul0_anf->cast<CNodePtr>(); | auto mul0 = mul0_anf->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(mul0); | MS_EXCEPTION_IF_NULL(mul0); | ||||
| // when network is _VirtualDatasetCell, quit fusion | |||||
| if (mul0->fullname_with_scope().find("network-_VirtualDatasetCell") != std::string::npos) { | |||||
| return true; | |||||
| } | |||||
| if (IsDepend(graph, mul0->input(1), reduce_sum)) { | if (IsDepend(graph, mul0->input(1), reduce_sum)) { | ||||
| MS_LOG(INFO) << "mul0->input(1) depends on reduce_sum, quit fusion"; | MS_LOG(INFO) << "mul0->input(1) depends on reduce_sum, quit fusion"; | ||||
| return true; | return true; | ||||
| @@ -128,7 +134,7 @@ const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, cons | |||||
| MS_LOG(INFO) << "Mul0 do not exist, quit fusion"; | MS_LOG(INFO) << "Mul0 do not exist, quit fusion"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (QuitFusion(graph, mul0, mul1, node)) { | |||||
| if (QuitFusion(graph, mul0, mul1, node, input2)) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -84,8 +84,9 @@ const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodeP | |||||
| inputs.push_back(mul->input(index)); | inputs.push_back(mul->input(index)); | ||||
| } | } | ||||
| auto another_input_node = add->input(add->size() - mul_index); | auto another_input_node = add->input(add->size() - mul_index); | ||||
| if (IsUsedByOthers(graph, another_input_node)) { | |||||
| MS_LOG(INFO) << "Add's another input node is used by others, do not fuse"; | |||||
| if (another_input_node->isa<CNode>() && | |||||
| AnfAlgo::GetCNodeName(another_input_node) == prim::kPrimTupleGetItem->name()) { | |||||
| MS_LOG(INFO) << "Add's another input node has multiple outputs, do not fuse"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| inputs.push_back(another_input_node); | inputs.push_back(another_input_node); | ||||
| @@ -32,7 +32,7 @@ class TestHWOptimizeConfusionMulGradFusion : public BackendCommon { | |||||
| TEST_F(TestHWOptimizeConfusionMulGradFusion, test_fusion) { | TEST_F(TestHWOptimizeConfusionMulGradFusion, test_fusion) { | ||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_confusion_mul_grad_fusion", "before"); | FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_confusion_mul_grad_fusion", "before"); | ||||
| EXPECT_NE(g, nullptr); | EXPECT_NE(g, nullptr); | ||||
| std::vector<int> shp{1, 1, 1, 1}; | |||||
| std::vector<int> shp{10, 1024}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | ||||
| AbstractBasePtrList args_spec_list; | AbstractBasePtrList args_spec_list; | ||||
| for (size_t i = 0; i < 3; ++i) { | for (size_t i = 0; i < 3; ++i) { | ||||
| @@ -49,6 +49,5 @@ TEST_F(TestHWOptimizeConfusionMulGradFusion, test_fusion) { | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_confusion_mul_grad_fusion", "after"); | FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_confusion_mul_grad_fusion", "after"); | ||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | ||||
| } | } | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -15,12 +15,13 @@ | |||||
| from mindspore.ops import Primitive | from mindspore.ops import Primitive | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| addn = P.AddN() | |||||
| mul = P.Mul() | mul = P.Mul() | ||||
| reduce_sum = P.ReduceSum() | reduce_sum = P.ReduceSum() | ||||
| confusion_mul_grad = Primitive('ConfusionMulGrad') | confusion_mul_grad = Primitive('ConfusionMulGrad') | ||||
| make_tuple = Primitive('make_tuple') | make_tuple = Primitive('make_tuple') | ||||
| tuple_getitem = Primitive('tuple_getitem') | tuple_getitem = Primitive('tuple_getitem') | ||||
| axis = 2 | |||||
| axis = 1 | |||||
| class FnDict: | class FnDict: | ||||
| @@ -39,8 +40,10 @@ def test_confusion_mul_grad_fusion(tag): | |||||
| @fns | @fns | ||||
| def before(input1, input2, input3): | def before(input1, input2, input3): | ||||
| output1 = mul(input1, input2) | |||||
| mul1 = mul(input3, input2) | |||||
| addn0 = addn((input2, input2)) | |||||
| output1 = mul(input1, addn0) | |||||
| mul1 = mul(input3, addn0) | |||||
| # input axis will be converted to attr in step ConstructKernelGraph | # input axis will be converted to attr in step ConstructKernelGraph | ||||
| output2 = reduce_sum(mul1, axis) | output2 = reduce_sum(mul1, axis) | ||||
| res = make_tuple(output1, output2) | res = make_tuple(output1, output2) | ||||
| @@ -48,7 +51,8 @@ def test_confusion_mul_grad_fusion(tag): | |||||
| @fns | @fns | ||||
| def after(input1, input2, input3): | def after(input1, input2, input3): | ||||
| res = confusion_mul_grad(input1, input2, input3) | |||||
| addn0 = addn(input2, input2) | |||||
| res = confusion_mul_grad(input1, addn0, input3) | |||||
| item0 = tuple_getitem(res, 0) | item0 = tuple_getitem(res, 0) | ||||
| item1 = tuple_getitem(res, 1) | item1 = tuple_getitem(res, 1) | ||||
| res = make_tuple(item0, item1) | res = make_tuple(item0, item1) | ||||