| @@ -81,6 +81,7 @@ | |||||
| #include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h" | #include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h" | ||||
| #include "pre_activate/ascend/ir_fission/addn_fission.h" | #include "pre_activate/ascend/ir_fission/addn_fission.h" | ||||
| #include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h" | #include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h" | ||||
| #include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h" | |||||
| #include "utils/context/ms_context.h" | #include "utils/context/ms_context.h" | ||||
| #include "utils/config_manager.h" | #include "utils/config_manager.h" | ||||
| #include "debug/anf_ir_dump.h" | #include "debug/anf_ir_dump.h" | ||||
| @@ -116,6 +117,7 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { | |||||
| ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); | ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>()); | ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BNInferGrad>()); | ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BNInferGrad>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<BatchNormGradInferFission>()); | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -34,6 +34,9 @@ bool CheckOutputsIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||||
| for (const auto &node_index : manager->node_users()[node]) { | for (const auto &node_index : manager->node_users()[node]) { | ||||
| AnfNodePtr output = node_index.first; | AnfNodePtr output = node_index.first; | ||||
| MS_EXCEPTION_IF_NULL(output); | MS_EXCEPTION_IF_NULL(output); | ||||
| if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { | |||||
| continue; | |||||
| } | |||||
| auto tuple_getiterm_cnode = output->cast<CNodePtr>(); | auto tuple_getiterm_cnode = output->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode); | MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode); | ||||
| auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem); | auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem); | ||||
| @@ -274,6 +274,9 @@ const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, c | |||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| for (const auto &output : bn_outputs) { | for (const auto &output : bn_outputs) { | ||||
| MS_EXCEPTION_IF_NULL(output); | MS_EXCEPTION_IF_NULL(output); | ||||
| if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { | |||||
| continue; | |||||
| } | |||||
| auto tuple_getitem_cnode = output->cast<CNodePtr>(); | auto tuple_getitem_cnode = output->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(tuple_getitem_cnode); | MS_EXCEPTION_IF_NULL(tuple_getitem_cnode); | ||||
| AnfNodePtr index_node = tuple_getitem_cnode->input(kInputNodeOutputIndexInTupleGetItem); | AnfNodePtr index_node = tuple_getitem_cnode->input(kInputNodeOutputIndexInTupleGetItem); | ||||
| @@ -32,7 +32,21 @@ bool CheckValueNodeInputOfMul(const AnfNodePtr &node) { | |||||
| std::vector<size_t> mul_input_shape = AnfAlgo::GetOutputInferShape(node, 0); | std::vector<size_t> mul_input_shape = AnfAlgo::GetOutputInferShape(node, 0); | ||||
| return mul_input_shape.empty() || (mul_input_shape.size() == 1 && mul_input_shape[0] == 1); | return mul_input_shape.empty() || (mul_input_shape.size() == 1 && mul_input_shape[0] == 1); | ||||
| } | } | ||||
// Gives `new_node` a two-element tuple abstract and then creates its outputs.
//
// The tuple abstract is built from, in order:
//   1. the abstract of `old_cnode` itself, and
//   2. the abstract of `old_cnode`'s input at index `kAccumIndex + 1`.
//
// Parameters:
//   func_graph  - forwarded unchanged to CreateMultipleOutputsOfAnfNode; it is not null-checked here.
//   old_cnode   - node whose abstract and selected input abstract are read.
//   new_node    - node whose abstract is overwritten with the tuple abstract.
//   new_outputs - out-parameter forwarded to CreateMultipleOutputsOfAnfNode.
//
// `old_cnode`, `new_node`, `new_outputs` and the selected input are each guarded by
// MS_EXCEPTION_IF_NULL before use.
| void AddInputToOutput(const FuncGraphPtr &func_graph, const CNodePtr &old_cnode, const AnfNodePtr &new_node, | |||||
| std::vector<AnfNodePtr> *new_outputs) { | |||||
| MS_EXCEPTION_IF_NULL(old_cnode); | |||||
| MS_EXCEPTION_IF_NULL(new_node); | |||||
| MS_EXCEPTION_IF_NULL(new_outputs); | |||||
// The old node's input at kAccumIndex + 1 supplies the second element of the tuple abstract.
| auto node_to_output = old_cnode->input(kAccumIndex + 1); | |||||
| MS_EXCEPTION_IF_NULL(node_to_output); | |||||
// Element order matters: the old node's own abstract first, the selected input's abstract second.
| AbstractBasePtrList abstract_list{old_cnode->abstract(), node_to_output->abstract()}; | |||||
| auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); | |||||
| new_node->set_abstract(abstract_tuple); | |||||
| // Create Output: request kFusedMulApplyMomentumOutputNum outputs of new_node (presumably appended to *new_outputs; confirm against the helper) | |||||
| CreateMultipleOutputsOfAnfNode(func_graph, new_node, kFusedMulApplyMomentumOutputNum, new_outputs); | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| const BaseRef MomentumLossscaleFusion::DefinePattern() const { | const BaseRef MomentumLossscaleFusion::DefinePattern() const { | ||||
| VarPtr Xs = std::make_shared<SeqVar>(); | VarPtr Xs = std::make_shared<SeqVar>(); | ||||
| VarPtr X0 = std::make_shared<Var>(); | VarPtr X0 = std::make_shared<Var>(); | ||||
| @@ -80,15 +94,10 @@ const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph | |||||
| input_names_value[3] = "x1"; | input_names_value[3] = "x1"; | ||||
| input_names_value.emplace_back("x2"); | input_names_value.emplace_back("x2"); | ||||
| AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names_value), new_node); | AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names_value), new_node); | ||||
| auto node_to_output = cnode->input(kAccumIndex + 1); | |||||
| MS_EXCEPTION_IF_NULL(node_to_output); | |||||
| AbstractBasePtrList abstract_list{node->abstract(), node_to_output->abstract()}; | |||||
| auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); | |||||
| new_node->set_abstract(abstract_tuple); | |||||
| new_node->set_scope(node->scope()); | new_node->set_scope(node->scope()); | ||||
| // Create Output | |||||
| // Create Outputs | |||||
| std::vector<AnfNodePtr> new_outputs; | std::vector<AnfNodePtr> new_outputs; | ||||
| CreateMultipleOutputsOfAnfNode(func_graph, new_node, kFusedMulApplyMomentumOutputNum, &new_outputs); | |||||
| AddInputToOutput(func_graph, cnode, new_node, &new_outputs); | |||||
| if (new_outputs.size() != kFusedMulApplyMomentumOutputNum) { | if (new_outputs.size() != kFusedMulApplyMomentumOutputNum) { | ||||
| MS_LOG(EXCEPTION) << "Failed to create outputs of " << new_node->DebugString(); | MS_LOG(EXCEPTION) << "Failed to create outputs of " << new_node->DebugString(); | ||||
| } | } | ||||