align fuse_ops_fusion, align composite_ops_fusion, unify ops table

- Init new_code's kernel_info with orig_node's kernel_info in function NewCNodeWithInfo
- enable run bert
- add pass tensor_promotion
- add macro for bias_add and bias_add_grad in expander pass
- exclude unused attrs in primitive compare for GraphKernelCSE
- exclude fusion_type in kernel_info comparison for CSE in graph kernel
- check processor
- remove graph kernel passes that ran before kernel select
- recover run_standalone_pretrain_ascend.sh
- remove is_before_kernel_select
- move add_atomic_clean from the pass directory to the graph_kernel directory
- update the fuse op list in the Ascend back end
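Before the diffs: the unifying thread of this change is that free-function entry points (AddAtomicClean, FuseBasicOps, FuseCompositeOps) become opt::Pass subclasses driven by a PassManager. A minimal sketch of that contract, with hypothetical names, for orientation:

// Minimal sketch of the opt::Pass contract the passes below implement.
// "MyFusionPass" and "my_fusion_pass" are illustrative names only.
class MyFusionPass : public Pass {
 public:
  MyFusionPass() : Pass("my_fusion_pass") {}  // registered pass name
  ~MyFusionPass() override = default;
  // Run returns true iff the pass mutated the graph, so the driver
  // knows whether the IR changed.
  bool Run(const FuncGraphPtr &func_graph) override;
};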
@@ -74,7 +74,6 @@
 #include "backend/optimizer/ascend/format_type/convert_unsupported_transnode_to_aicpu.h"
 #include "backend/optimizer/pass/eliminate_redundant_op.h"
 #include "backend/optimizer/pass/common_subexpression_elimination.h"
-#include "backend/optimizer/pass/add_atomic_clean.h"
 #include "backend/optimizer/ascend/format_type/merge_cast_to_op.h"
 #include "backend/optimizer/ascend/format_type/check_consistency.h"
 #include "backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.h"
@@ -382,74 +381,6 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
   }
 }
-
-void AscendBackendGraphKernelOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph,
-                                 bool is_before_kernel_select) {
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
-    return;
-  }
-  bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
-  if (save_graphs) {
-    std::string file_name = "hwopt_d_graph_kernel_opt_before_graph_" + std::to_string(!is_before_kernel_select) + "_" +
-                            std::to_string(kernel_graph->graph_id()) + ".ir";
-    DumpIR(file_name, kernel_graph);
-  }
-  // Fuse graph kernels with basic ops
-  static_cast<void>(FuseCompositeOps(kernel_graph, is_before_kernel_select));
-  if (save_graphs) {
-    std::string file_name = "hwopt_d_graph_kernel_opt_end_graph_" + std::to_string(!is_before_kernel_select) + "_" +
-                            std::to_string(kernel_graph->graph_id()) + ".ir";
-    DumpIR(file_name, kernel_graph, true);
-  }
-}
-
-void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph,
-                               bool is_before_kernel_select) {
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
-    return;
-  }
-  bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
-  if (save_graphs) {
-    std::string file_name = "hwopt_fuse_basic_opt_before_graph_" + std::to_string(!is_before_kernel_select) + "_" +
-                            std::to_string(kernel_graph->graph_id()) + ".ir";
-    DumpIR(file_name, kernel_graph, true);
-  }
-  // Fuse basic ops with basic ops
-  static_cast<void>(FuseBasicOps(kernel_graph, is_before_kernel_select));
-  if (save_graphs) {
-    std::string file_name = "hwopt_fuse_basic_opt_end_graph_" + std::to_string(!is_before_kernel_select) + "_" +
-                            std::to_string(kernel_graph->graph_id()) + ".ir";
-    DumpIR(file_name, kernel_graph, true);
-  }
-}
-
-void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
-    return;
-  }
-  bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
-  if (save_graphs) {
-    std::string file_name = "hwopt_d_add_atomic_clean_before_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
-    DumpIR(file_name, kernel_graph);
-  }
-  AddAtomicClean(kernel_graph);
-  if (save_graphs) {
-    std::string file_name = "hwopt_d_end_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
-    DumpIR(file_name, kernel_graph, true);
-  }
-}
-
 void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
@@ -25,11 +25,6 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
 void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void AscendGraphKernelCommonProcess(const std::shared_ptr<session::KernelGraph> &kernel_graph);
-void AscendBackendGraphKernelOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph,
-                                 bool is_before_kernel_select = false);
-void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph,
-                               bool is_before_kernel_select = false);
-void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 }  // namespace opt
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
-#include "backend/optimizer/pass/add_atomic_clean.h"
+#include "backend/optimizer/graph_kernel/add_atomic_clean.h"
 #include <memory>
 #include <vector>
 #include <functional>

@@ -75,7 +75,7 @@ CNodePtr CreateTbeAtomicCleanNode(const std::shared_ptr<session::KernelGraph> &k
 }
 }  // namespace

-void AddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
+bool AddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   auto mng = kernel_graph->manager();
   if (mng == nullptr) {

@@ -83,6 +83,7 @@ void AddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
     kernel_graph->set_manager(mng);
   }
   auto &todos = kernel_graph->execution_order();
+  bool changed = false;
   for (auto iter = todos.cbegin(); iter != todos.end(); ++iter) {
     auto node = *iter;
     if (AnfAlgo::IsGraphKernel(node) && kernel_graph->nodes().contains(node)) {

@@ -112,9 +113,17 @@ void AddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
       new_cnode->set_kernel_info(node->kernel_info_ptr());
       mng->Replace(node, new_cnode);
       g_output_idx.clear();
+      changed = true;
     }
   }
-}
+  return changed;
+}
+
+bool CleanAddAtomic::Run(const FuncGraphPtr &func_graph) {
+  return AddAtomicClean(std::dynamic_pointer_cast<session::KernelGraph>(func_graph));
+}
 }  // namespace opt
 }  // namespace mindspore
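AddAtomicClean now returns whether it rewrote anything, matching the Pass::Run convention that the new CleanAddAtomic::Run wrapper forwards. A hedged sketch of that changed-flag convention (MyPass and RewriteIfPossible are illustrative, not part of this diff):

// Sketch of the changed-flag convention adopted above.
bool MyPass::Run(const FuncGraphPtr &func_graph) {
  bool changed = false;
  for (const auto &node : TopoSort(func_graph->get_return())) {
    if (RewriteIfPossible(node)) {  // illustrative helper
      changed = true;               // the graph was mutated
    }
  }
  return changed;  // true tells the caller the IR changed
}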
@@ -14,16 +14,23 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_ATOMIC_CLEAN_H_
-#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_ATOMIC_CLEAN_H_
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CLEAN_ADD_ATOMIC_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CLEAN_ADD_ATOMIC_H_
 #include <memory>
+#include "backend/optimizer/common/optimizer.h"
 #include "backend/session/kernel_graph.h"

 namespace mindspore {
 namespace opt {
-void AddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph);
+class CleanAddAtomic : public Pass {
+ public:
+  CleanAddAtomic() : Pass("clean_add_atomic") {}
+  ~CleanAddAtomic() override = default;
+  bool Run(const FuncGraphPtr &func_graph) override;
+};
+using CleanAddAtomicPtr = std::shared_ptr<CleanAddAtomic>;
 }  // namespace opt
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_ATOMIC_CLEAN_H
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CLEAN_ADD_ATOMIC_H_
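With the header exposing a Pass class instead of a free function, callers register the pass on a pass manager rather than invoking it directly; a minimal usage sketch, mirroring the session change at the end of this diff:

// Registering the relocated pass (usage sketch).
auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm");
pm->AddPass(std::make_shared<opt::CleanAddAtomic>());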
@@ -35,24 +35,7 @@
 namespace mindspore {
 namespace opt {
 namespace {
-bool IsBasicOp(const AnfNodePtr &node, bool is_before_kernel_select) {
-#if ENABLE_D
-  std::vector<PrimitivePtr> fusible_basic_ops = {prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub,
-                                                 prim::kPrimExpandDims};
-  if (!is_before_kernel_select) {
-    fusible_basic_ops.push_back(prim::kPrimCast);
-  }
-#elif ENABLE_GPU
-  std::vector<PrimitivePtr> fusible_basic_ops = GetFusibleOpList();
-#else
-  std::vector<PrimitivePtr> fusible_basic_ops;
-#endif
-  return std::any_of(fusible_basic_ops.begin(), fusible_basic_ops.end(),
-                     [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
-}
-
-IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const GraphKernelInfo &info,
-                                       const AnfNodePtr &node) {
+IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNodePtr &node) {
   if (cur_node == node) {
     return FOLLOW;
   }

@@ -60,16 +43,13 @@ IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const GraphKe
     return EXCLUDE;
   }
-  bool is_fusable = IsBasicOp(node, info.is_before_kernel_select);
+  bool is_fusable = IsBasicFuseOp(node);
   return is_fusable ? FOLLOW : EXCLUDE;
 }

-std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, bool is_before_kernel_select) {
-  GraphKernelInfo info;
-  info.is_before_kernel_select = is_before_kernel_select;
+std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode) {
   // Search fusable nodes according input direction.
-  auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, info, std::placeholders::_1);
+  auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1);
   auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
   if (used_nodes.size() > 1) {
     used_nodes = RemoveCircle(used_nodes, false);
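The include functor handed to DeepLinkedGraphSearch decides, node by node, whether the traversal keeps growing the fusion region: FOLLOW expands through the node, EXCLUDE stops there. A hedged sketch of that contract with a deliberately trivial predicate (IncludeOnlyAdds is illustrative, not part of this diff):

// Sketch of the IncludeType contract used by the search above.
IncludeType IncludeOnlyAdds(const AnfNodePtr &root, const AnfNodePtr &node) {
  if (root == node) {
    return FOLLOW;  // the seed node always belongs to the region
  }
  // Expand through TensorAdd nodes only; stop at everything else.
  return IsPrimitiveCNode(node, prim::kPrimTensorAdd) ? FOLLOW : EXCLUDE;
}
// Bound the same way the pass binds its functor:
//   auto fn = std::bind(IncludeOnlyAdds, cnode, std::placeholders::_1);
//   auto region = DeepLinkedGraphSearch(cnode, fn);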
@@ -176,7 +156,7 @@ void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, con
 }

 bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr> &todos,
-                  std::unordered_set<AnfNodePtr> *fused_ops, bool is_before_kernel_select) {
+                  std::unordered_set<AnfNodePtr> *fused_ops) {
   bool changed = false;
   auto mng = kernel_graph->manager();
   for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {

@@ -187,12 +167,12 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
     if (fused_ops->count(node)) {
       continue;
     }
-    bool is_basic_op = IsBasicOp(node, is_before_kernel_select);
+    bool is_basic_op = IsBasicFuseOp(node);
     if (!is_basic_op || !kernel_graph->nodes().contains(node)) {
       continue;
     }
-    auto fuse_nodes = FindFuseCNodes(node, is_before_kernel_select);
+    auto fuse_nodes = FindFuseCNodes(node);
     if (fuse_nodes.size() <= 1) {
       continue;
     }

@@ -204,10 +184,8 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
     std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes);
     RemoveControlDependOut(fg, &outputs, mng);
     ConvertNonscalarTensorToParameter(fg, &inputs);
-    auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs, is_before_kernel_select);
-    if (!is_before_kernel_select) {
-      SetNewKernelInfo(fuse_new_node, fg, inputs, outputs, AnfAlgo::GetProcessor(fuse_nodes[0]));
-    }
+    auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs);
+    SetNewKernelInfo(fuse_new_node, fg, inputs, outputs, AnfAlgo::GetProcessor(fuse_nodes[0]));
     ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs);

@@ -224,7 +202,7 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
   }
 }  // namespace

-bool FuseBasicOps(const FuncGraphPtr &kernel_graph, bool is_before_kernel_select) {
+bool FuseBasicOps(const FuncGraphPtr &kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   auto mng = kernel_graph->manager();
   if (mng == nullptr) {

@@ -234,9 +212,9 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, bool is_before_kernel_select
   std::unordered_set<AnfNodePtr> fused_ops;
   auto todos = TopoSort(kernel_graph->get_return());
   std::reverse(todos.begin(), todos.end());
-  return FuseBasicOps(kernel_graph, todos, &fused_ops, is_before_kernel_select);
+  return FuseBasicOps(kernel_graph, todos, &fused_ops);
 }

-bool BasicOpsFusion::Run(const FuncGraphPtr &func_graph) { return FuseBasicOps(func_graph, false); }
+bool BasicOpsFusion::Run(const FuncGraphPtr &func_graph) { return FuseBasicOps(func_graph); }
 }  // namespace opt
 }  // namespace mindspore
@@ -22,7 +22,7 @@
 namespace mindspore {
 namespace opt {
-bool FuseBasicOps(const FuncGraphPtr &kernel_graph, bool is_before_kernel_select);
+bool FuseBasicOps(const FuncGraphPtr &kernel_graph);

 class BasicOpsFusion : public Pass {
  public:
@@ -60,127 +60,23 @@ std::vector<AnfNodePtr> DeepUsersSearch(const std::vector<AnfNodePtr> &roots, co
 }
 }  // namespace

-bool IsBasicFuseOp(const AnfNodePtr &node, bool is_before_kernel_select) {
-#if ENABLE_D
-  std::vector<PrimitivePtr> basic_ops = {
-    prim::kPrimAddN, prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub, prim::kPrimMaximum,
-    prim::kPrimMinimum, prim::kPrimNeg, prim::kPrimRealDiv, prim::kPrimPow, prim::kPrimSqrt,
-    prim::kPrimExpandDims, prim::kPrimReciprocal, prim::kPrimLessEqual};
-  if (!is_before_kernel_select) {
-    basic_ops.push_back(prim::kPrimCast);
-  }
-#elif ENABLE_GPU
-  std::vector<PrimitivePtr> basic_ops = GetFusibleOpList();
-#else
-  std::vector<PrimitivePtr> basic_ops;
-#endif
-  return std::any_of(basic_ops.begin(), basic_ops.end(),
-                     [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
-}
-
-bool IsReduceOp(const AnfNodePtr &node) {
-  std::vector<PrimitivePtr> reduce_ops = {prim::kPrimReduceSum, prim::kPrimReduceMean, prim::kPrimReduceMin,
-                                          prim::kPrimReduceMax, prim::kPrimReduceAll};
-  return std::any_of(reduce_ops.begin(), reduce_ops.end(),
-                     [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
-}
-
-void GetGraphKernelInfo(const FuncGraphPtr &fg, GraphKernelInfo *info) {
-  MS_EXCEPTION_IF_NULL(fg);
-  auto mng = fg->manager();
-  if (mng == nullptr) {
-    mng = Manage(fg, false);
-    fg->set_manager(mng);
-  }
-  const auto &nodes = fg->nodes();
-  info->op_type = ELEWISE;
-  info->cal_step = -1;
-  info->reduce_op_num = 0;
-  for (auto node : nodes) {
-    auto cnode = node->cast<CNodePtr>();
-    if (cnode == nullptr) {
-      continue;
-    }
-    info->cal_step++;
-    if (IsReduceOp(node)) {
-      info->op_type = REDUCE;
-      info->reduce_op_num++;
-    }
-  }
-  auto fg_flag = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
-  if (fg_flag != nullptr) {
-    auto fg_name = GetValue<std::string>(fg_flag);
-    info->origin_composite_name = fg_name;
-  }
-}
-
-bool IsCompositeFuseBasic(const GraphKernelInfo &info, const AnfNodePtr &node) {
-#if ENABLE_D
-  std::vector<PrimitivePtr> fusable_with_reduce;
-  if (!info.is_before_kernel_select) {
-    fusable_with_reduce.push_back(prim::kPrimCast);
-  }
-  if (info.op_type == REDUCE &&
-      (info.cal_step >= MAX_REDUCE_OP_FUSION_CAL_STEP || info.reduce_op_num >= MAX_REDUCE_OP_FUSION_REDUCE_NUM)) {
-    return std::any_of(fusable_with_reduce.begin(), fusable_with_reduce.end(),
-                       [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
-  }
-#endif
-  return IsBasicFuseOp(node, info.is_before_kernel_select);
-}
-
-bool IsFuse(const GraphKernelInfo &info, const AnfNodePtr &node) {
+bool IsFuse(const AnfNodePtr &node) {
   // composite fuse composite op
   if (AnfAlgo::IsGraphKernel(node)) {
-#if ENABLE_D
-    return false;
-#else
     return true;
-#endif
   }
-  return IsCompositeFuseBasic(info, node);
+  return IsBasicFuseOp(node);
 }

-void UpdateGraphKernelInfo(GraphKernelInfo *info, const AnfNodePtr &node) {
-  if (IsPrimitiveCNode(node)) {
-    info->cal_step++;
-    if (IsReduceOp(node)) {
-      info->op_type = REDUCE;
-    }
-    info->origin_composite_name += AnfAlgo::GetCNodePrimitive(node)->name() + "_";
-  } else if (AnfAlgo::IsGraphKernel(node)) {
-    auto cnode = node->cast<CNodePtr>();
-    auto composite_g = GetValueNode<FuncGraphPtr>(cnode->input(0));
-    GraphKernelInfo fuse_info;
-    GetGraphKernelInfo(composite_g, &fuse_info);
-    info->cal_step += fuse_info.cal_step;
-    info->origin_composite_name += fuse_info.origin_composite_name;
-  }
-}
-
-IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, GraphKernelInfo *info, const AnfNodePtr &node) {
+IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNodePtr &node) {
   if (cur_node == node) {
     return FOLLOW;
   }
-#if ENABLE_D
-  if (!IsPrimitiveCNode(node)) {
-    return EXCLUDE;
-  }
-#else
-  bool is_fuse_composite = AnfAlgo::IsGraphKernel(node);
-  if (!IsPrimitiveCNode(node) && !is_fuse_composite) {
-    return EXCLUDE;
-  }
-#endif
-  bool is_fusable = IsFuse(*info, node);
-  if (is_fusable) {
-    UpdateGraphKernelInfo(info, node);
-  }
+  bool is_fuse_composite = AnfAlgo::IsGraphKernel(node);
+  if (!IsPrimitiveCNode(node) && !is_fuse_composite) {
+    return EXCLUDE;
+  }
+  bool is_fusable = IsFuse(node);
   return is_fusable ? FOLLOW : EXCLUDE;
 }

-IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, GraphKernelInfo *info, const AnfNodePtr &node) {
+IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, const AnfNodePtr &node) {
   if (cur_node == node) {
     return FOLLOW;
   }

@@ -195,14 +91,7 @@ IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, GraphKernelI
     }
     return EXCLUDE;
   }
-  if (!IsPrimitiveCNode(node)) {
-    return EXCLUDE;
-  }
-
-  bool is_fusable = IsFuse(*info, node);
-  if (is_fusable) {
-    UpdateGraphKernelInfo(info, node);
-  }
+  bool is_fusable = IsBasicFuseOp(node);
   return is_fusable ? FOLLOW : EXCLUDE;
 }

@@ -350,19 +239,15 @@ void TopoSortForNodeList(std::vector<AnfNodePtr> *lst) {
   lst->assign(res.begin(), res.end());
 }

-std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, bool is_before_kernel_select) {
+std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode) {
   auto func_graph = cnode->func_graph();
-  auto graph_kernel_g = GetValueNode<FuncGraphPtr>(cnode->input(0));
-  GraphKernelInfo info;
-  info.is_before_kernel_select = is_before_kernel_select;
-  GetGraphKernelInfo(graph_kernel_g, &info);
   auto mng = func_graph->manager();
   // Search fusable nodes according input direction.
-  auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, &info, std::placeholders::_1);
+  auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1);
   auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
   std::reverse(used_nodes.begin(), used_nodes.end());
   // Search fusable nodes according output direction.
-  auto include_func_backward = std::bind(IncludeFusedBasicOpBackward, cnode, &info, std::placeholders::_1);
+  auto include_func_backward = std::bind(IncludeFusedBasicOpBackward, cnode, std::placeholders::_1);
   auto user_nodes = DeepUsersSearch(cnode, include_func_backward, mng);

   used_nodes.insert(used_nodes.end(), user_nodes.begin() + 1, user_nodes.end());

@@ -373,7 +258,7 @@ std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, bool is_before_ker
   return used_nodes;
 }

-bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph, bool is_before_kernel_select) {
+bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   bool changed = false;
   auto &todos = kernel_graph->execution_order();

@@ -392,19 +277,19 @@ bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph,
       }
     }
-    auto fuse_nodes = FindFuseCNodes(node, is_before_kernel_select);
+    auto fuse_nodes = FindFuseCNodes(node);
     if (fuse_nodes.size() <= 1) {
       continue;
     }
     changed = true;
-    FuseNodesToSubGraph(fuse_nodes, kernel_graph, "", is_before_kernel_select);
+    FuseNodesToSubGraph(fuse_nodes, kernel_graph, "");
   }
   return changed;
 }

 bool CompositeOpsFusion::Run(const FuncGraphPtr &func_graph) {
-  return FuseCompositeOps(std::dynamic_pointer_cast<session::KernelGraph>(func_graph), false);
+  return FuseCompositeOps(std::dynamic_pointer_cast<session::KernelGraph>(func_graph));
 }
 }  // namespace opt
 }  // namespace mindspore
@@ -26,25 +26,6 @@
 namespace mindspore {
 namespace opt {
-enum GraphKernelType {
-  ELEWISE = 0,  // only contain elewise basic ops
-  REDUCE,       // contain reduce ops
-  CUBE,         // contain cube ops
-};
-
-struct GraphKernelInfo {
-  GraphKernelType op_type = ELEWISE;
-  bool is_before_kernel_select = false;
-  int reduce_op_num = 0;
-  int cal_step = 0;
-  std::string origin_composite_name = "";
-};
-
-// when composite fuse composite the cal step is greate than this number, not fuse
-#if ENABLE_D
-const int MAX_REDUCE_OP_FUSION_CAL_STEP = 5;
-const int MAX_REDUCE_OP_FUSION_REDUCE_NUM = 2;
-#endif
-
 const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward",
                                                        "LambNextMV", "LambUpdateWithLR"};

@@ -52,7 +33,7 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bo
 void TopoSortForNodeList(std::vector<AnfNodePtr> *lst);

-bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph, bool is_before_kernel_select = false);
+bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph);

 class CompositeOpsFusion : public Pass {
  public:
@@ -176,7 +176,7 @@ AnfNodePtr GraphKernelExpander::CreateExpandGraphKernel(const FuncGraphPtr &func
   EliminateRedundantParameters(new_func_graph, &inputs);
   kernel::GetValidKernelNodes(new_func_graph, &kernel_nodes);
   kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs);
-  auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs, false);
+  auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs);
   SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs, AnfAlgo::GetProcessor(node));
   std::string graph_kernel_flag;
   std::for_each(kernel_nodes.begin(), kernel_nodes.end(), [&graph_kernel_flag](const AnfNodePtr &node) {
@@ -24,6 +24,7 @@
 #include "vm/segment_runner.h"
 #include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
 #include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
+#include "backend/kernel_compiler/kernel.h"
 #include "ir/func_graph_cloner.h"
 #include "ir/func_graph.h"
 #include "backend/optimizer/pass/const_input_to_attr_registry.h"

@@ -445,7 +446,7 @@ AnfNodePtrList GetExpandOuts(const AnfNodePtrList &outs) {
 }

 AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &func_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
-                              const AnfNodePtrList &outputs, bool is_before_kernel_select) {
+                              const AnfNodePtrList &outputs) {
   auto func_node = NewValueNode(fg);
   std::vector<AnfNodePtr> fn_inputs;
   fn_inputs.push_back(func_node);

@@ -467,9 +468,6 @@ AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &func_graph, const FuncGraphPtr
     auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0);
     auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second);
     fg->parameters()[i]->set_abstract(input_abs);
-    if (is_before_kernel_select) {
-      fg->parameters()[i]->set_kernel_info(std::make_shared<device::KernelInfo>());
-    }
   }
   return fuse_cnode;
 }

@@ -529,8 +527,7 @@ void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_f
 }

 void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
-                         const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix,
-                         bool is_before_kernel_select) {
+                         const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix) {
   if (fuse_nodes.empty()) {
     return;
   }

@@ -547,10 +544,8 @@ void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
   AnfNodePtrList outputs;
   std::tie(fg, inputs, outputs) = MixedNodesTransToGraph(fuse_nodes, &src_outputs);

-  auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs, is_before_kernel_select);
-  if (!is_before_kernel_select) {
-    SetNewKernelInfo(fuse_new_node, fg, inputs, outputs, AnfAlgo::GetProcessor(fuse_nodes[0]));
-  }
+  auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs);
+  SetNewKernelInfo(fuse_new_node, fg, inputs, outputs, AnfAlgo::GetProcessor(fuse_nodes[0]));

   // Handle get-item probleam.
   ReplaceNewFuseCNode(kernel_graph, fuse_new_node, src_outputs);
@@ -702,9 +697,20 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo
 std::unordered_set<PrimitivePtr> GetExpandOps() {
   std::unordered_set<PrimitivePtr> expand_ops = {
-    prim::kPrimSquare, prim::kPrimBiasAdd, prim::kPrimBiasAddGrad, prim::kPrimGelu,
-    prim::kPrimGeluGrad, prim::kPrimFusedAdam, prim::kPrimFusedAdamWeightDecay, prim::kPrimTanhGrad,
-    prim::kPrimReduceMean, prim::kPrimMaximumGrad, prim::kPrimMinimumGrad};
+    prim::kPrimSquare,
+#if ENABLE_GPU
+    prim::kPrimBiasAdd,
+    prim::kPrimBiasAddGrad,
+    prim::kPrimGelu,
+    prim::kPrimGeluGrad,
+    prim::kPrimFusedAdam,
+    prim::kPrimFusedAdamWeightDecay,
+    prim::kPrimTanhGrad,
+    prim::kPrimReduceMean,
+    prim::kPrimMaximumGrad,
+    prim::kPrimMinimumGrad
+#endif
+  };
   return expand_ops;
 }
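With the new #if ENABLE_GPU guard, non-GPU builds currently expand only Square. A candidate node would typically be matched against this set with the same std::any_of/IsPrimitiveCNode idiom used throughout this file; a hedged sketch (CanExpand is an illustrative helper name, not part of this diff):

// Sketch: testing a node against GetExpandOps().
bool CanExpand(const AnfNodePtr &node) {
  auto expand_ops = GetExpandOps();
  return std::any_of(expand_ops.begin(), expand_ops.end(),
                     [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
}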
@@ -725,16 +731,54 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p
 }

 std::vector<PrimitivePtr> GetFusibleOpList() {
+#if ENABLE_D
+  std::vector<PrimitivePtr> fusible_basic_ops = {
+    prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
+    prim::kPrimExpandDims, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
+    prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
+    prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape,
+    prim::kPrimTranspose, prim::kPrimCast};
+#elif ENABLE_GPU
   std::vector<PrimitivePtr> fusible_basic_ops = {
-    prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
-    prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
-    prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast,
-    prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect,
-    prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape,
-    prim::kPrimTranspose};
+    prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
+    prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
+    prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
+    prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater,
+    prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
+    prim::kPrimCast};
+#else
+  std::vector<PrimitivePtr> fusible_basic_ops;
+#endif
   return fusible_basic_ops;
 }

+bool CheckProcessor(const AnfNodePtr &node, kernel::Processor processor = kernel::Processor::AICORE) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto node_kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
+  if (node_kernel_info == nullptr) {
+    return false;
+  }
+  auto node_build_info = node_kernel_info->GetMutableSelectKernelBuildInfo();
+  if (node_build_info == nullptr) {
+    return false;
+  }
+  return node_build_info->processor() == processor;
+}
+
+bool IsBasicFuseOp(const AnfNodePtr &node) {
+  std::vector<PrimitivePtr> basic_ops = GetFusibleOpList();
+#if ENABLE_D
+  if (!CheckProcessor(node)) {
+    return false;
+  }
+#endif
+  return std::any_of(basic_ops.begin(), basic_ops.end(),
+                     [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
+}
+
 void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) {
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
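CheckProcessor defaults to AICORE, so on Ascend builds (ENABLE_D) IsBasicFuseOp now refuses to fuse a node whose selected kernel build info targets another processor (the "check processor" item in the commit message). The second parameter lets callers probe other targets; a hedged usage sketch, assuming the kernel::Processor enum also carries an AICPU member as elsewhere in the code base:

// Usage sketch for the new helper.
bool on_aicore = CheckProcessor(node);                           // default: AICORE
bool on_aicpu = CheckProcessor(node, kernel::Processor::AICPU);  // assumed enum member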
@@ -45,12 +45,11 @@ void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const
                       const AnfNodePtrList &outputs, kernel::Processor processor);
 AnfNodePtrList GetExpandOuts(const AnfNodePtrList &outs);
 AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
-                              const AnfNodePtrList &outputs, bool is_before_kernel_select);
+                              const AnfNodePtrList &outputs);
 void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode,
                          const AnfNodePtrList &outputs);
 void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
-                         const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix,
-                         bool is_before_kernel_select);
+                         const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix);
 bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc);
 bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
                    std::map<std::string, AnfNodePtr> *address_node_map);

@@ -59,6 +58,7 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo
 std::unordered_set<PrimitivePtr> GetExpandOps();
 std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix = "", const string &postfix = "");
 std::vector<PrimitivePtr> GetFusibleOpList();
+bool IsBasicFuseOp(const AnfNodePtr &node);
 void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
 }  // namespace opt
 }  // namespace mindspore
@@ -45,7 +45,7 @@ bool TensorPromotion::Run(const FuncGraphPtr &func_graph) {
     AnfNodePtrList inputs, outputs;
     inputs.insert(inputs.end(), args.begin() + 1, args.end());
     kernel::GetFuncGraphOutputNodes(fg, &outputs);
-    auto new_cnode = CreateNewFuseCNode(func_graph, fg, inputs, outputs, false);
+    auto new_cnode = CreateNewFuseCNode(func_graph, fg, inputs, outputs);
     SetNewKernelInfo(new_cnode, fg, inputs, outputs, AnfAlgo::GetProcessor(node));
     mng->Replace(node, new_cnode);
     changed = true;
@@ -41,6 +41,14 @@
 #include "debug/data_dump/dump_json_parser.h"
 #include "debug/tensor_load.h"
 #include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
+#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
+#include "backend/optimizer/graph_kernel/tensor_promotion.h"
+#include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
+#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
+#include "backend/optimizer/graph_kernel/graph_kernel_cse.h"
+#include "backend/optimizer/graph_kernel/value_graph_binder.h"
+#include "backend/optimizer/graph_kernel/add_atomic_clean.h"
+#include "backend/optimizer/pass/getitem_tuple.h"
 #include "debug/data_dump/e2e_dump_util.h"
 #include "debug/anf_ir_dump.h"
 #include "debug/dump_proto.h"

@@ -291,8 +299,6 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
   MS_EXCEPTION_IF_NULL(child_graph);
   MS_LOG(INFO) << "CompileChildGraph " << child_graph->ToString();
   opt::AscendBackendIRFusionOptimization(child_graph);
-  opt::AscendBackendFuseBasicOpt(child_graph, true);
-  opt::AscendBackendGraphKernelOpt(child_graph, true);
   child_graph->SetExecOrderByDefault();
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);

@@ -466,13 +472,35 @@ void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_
   MS_LOG(INFO) << "HardwareOptimize start!";
   opt::AscendBackendOptimization(kernel_graph);
   opt::AscendGraphKernelCommonProcess(kernel_graph);
-  opt::AscendBackendFuseBasicOpt(kernel_graph, false);
-  opt::AscendBackendAddAtomicClean(kernel_graph);
+  GraphKernelOptimize(kernel_graph);
   MS_EXCEPTION_IF_NULL(kernel_graph);
   kernel_graph->SetExecOrderByDefault();
   MS_LOG(INFO) << "HardwareOptimize Finish!";
 }

+void AscendSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const {
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
+    return;
+  }
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm");
+  pm->AddPass(std::make_shared<opt::GraphKernelExpander>());
+  pm->AddPass(std::make_shared<opt::BasicOpsFusion>());
+  pm->AddPass(std::make_shared<opt::CompositeOpsFusion>());
+  pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
+  pm->AddPass(std::make_shared<opt::TensorPromotion>());
+  pm->AddPass(std::make_shared<opt::GraphKernelSplitter>());
+  // After Simplify and Splitter, a lot of redundant getitem/maketuple
+  // will be exposed, use GetitemTuple Pass to delete them.
+  pm->AddPass(std::make_shared<opt::GetitemTuple>());
+  pm->AddPass(std::make_shared<opt::BindValueToGraph>());
+  pm->AddPass(std::make_shared<opt::CleanAddAtomic>());
+  optimizer->AddPassManager(pm);
+  (void)optimizer->Optimize(kernel_graph);
+}
+
 void AscendSession::AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
   MS_LOG(INFO) << "Start!";
   opt::HideNopNode(kernel_graph.get());
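GraphKernelOptimize replaces the two ad-hoc entry points with a single PassManager run. The ordering (expansion before fusion so expanded bodies can fuse, CleanAddAtomic last so atomic-clean nodes attach to the final fused kernels) is inferred from the pass list rather than stated in the commit. The enclosing GraphOptimizer/PassManager pattern generalizes to any pass list; a minimal sketch with illustrative names:

// Minimal sketch of the GraphOptimizer/PassManager pattern used above.
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>("my_pm");  // illustrative name
pm->AddPass(std::make_shared<opt::GetitemTuple>());     // any opt::Pass works here
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);  // runs passes in registration order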
@@ -865,8 +893,6 @@ void AscendSession::IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<st
   memo->insert(graph.get());
   opt::AscendBackendIRFusionOptimization(graph);
-  opt::AscendBackendFuseBasicOpt(graph, true);
-  opt::AscendBackendGraphKernelOpt(graph, true);
   graph->SetExecOrderByDefault();
   auto context_ptr = MsContext::GetInstance();

@@ -87,6 +87,7 @@ class AscendSession : public SessionBasic {
   void InitRuntimeResource();
   void SelectKernel(const KernelGraph &kernel_graph) const;
   void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const;
+  void GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const;
   void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
   void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
   void AssignStream(NotNull<KernelGraphPtr> kernel_graph) const;