From: @lingyunli63 Reviewed-by: @gaoxiong1, @ckey_dou Signed-off-by: @ckey_dou pull/14553/MERGE
| @@ -39,20 +39,6 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| bool IsFusibleOp(const AnfNodePtr &node) { | |||
| #if ENABLE_D | |||
| const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward", | |||
| "LambNextMV", "LambUpdateWithLR"}; | |||
| if (AnfAlgo::IsGraphKernel(node)) { | |||
| auto fg_attr = AnfAlgo::GetCNodeFuncGraphPtr(node)->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); | |||
| if (fg_attr != nullptr) { | |||
| return graph_kernel_black_list.count(GetValue<std::string>(fg_attr)) == 0; | |||
| } | |||
| } | |||
| #endif | |||
| return IsBasicFuseOp(node) || AnfAlgo::IsGraphKernel(node); | |||
| } | |||
| IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNodePtr &node) { | |||
| if (cur_node == node) { | |||
| return FOLLOW; | |||
| @@ -0,0 +1,120 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/graph_kernel/cast_matmul_fusion.h" | |||
| #include <tuple> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| // Check if leaf is used by root | |||
| bool HasPath(const AnfNodePtr &leaf, const AnfNodePtr &root, const FuncGraphManagerPtr &mng) { | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| bool result = false; | |||
| auto IncludeUser = [&result, &root](const AnfNodePtr &node) { | |||
| if (node == root) { | |||
| result = true; | |||
| } | |||
| return result ? EXCLUDE : FOLLOW; | |||
| }; | |||
| static_cast<void>(DeepLinkedGraphSearch(leaf, IncludeUser)); | |||
| return result; | |||
| } | |||
| } // namespace | |||
| /* MatMul supports fp32 bias, so remove the redundant cast if cast cannot fuse forword | |||
| * case1, cast only used by MatMul | |||
| * | |||
| * bias_fp32 = depend(bias_fp32, u) | |||
| * %0 = cast(bias_fp32, fp16) | |||
| * %1 = MatMul(A_fp16, B_fp16, %0) | |||
| * ------> | |||
| * bias_fp32 = depend(bias_fp32, u) | |||
| * %1 = MatMul(A_fp16, B_fp16, bias_fp32) | |||
| * | |||
| * case2, cast used by MatMul and UpdateStatus | |||
| * | |||
| * bias_fp32 = load(p, status) | |||
| * %0 = cast(bias_fp32, fp16) | |||
| * %1 = MatMul(A_fp16, B_fp16, %0) | |||
| * %2 = UpstateStatus(status, %0) | |||
| * ------> | |||
| * bias_fp32 = load(p, status) | |||
| * %1 = MatMul(A_fp16, B_fp16, bias_fp32) | |||
| * %2 = UpstateStatus(status, %1) | |||
| */ | |||
| bool CastMatmulFusion::Run(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto mng = func_graph->manager(); | |||
| if (mng == nullptr) { | |||
| mng = Manage(func_graph, true); | |||
| func_graph->set_manager(mng); | |||
| } | |||
| auto changed = false; | |||
| auto nodes = TopoSort(func_graph->get_return()); | |||
| for (auto node : nodes) { | |||
| if (!IsPrimitiveCNode(node, prim::kPrimMatMul)) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode->size() != 4) { | |||
| continue; | |||
| } | |||
| auto cast_node = cnode->input(3); | |||
| if (!IsPrimitiveCNode(cast_node, prim::kPrimCast)) { | |||
| continue; | |||
| } | |||
| auto cast_input_type = AnfAlgo::GetInputDeviceDataType(cast_node, 0); | |||
| auto cast_output_type = AnfAlgo::GetOutputDeviceDataType(cast_node, 0); | |||
| if (cast_input_type != kNumberTypeFloat32 || cast_output_type != kNumberTypeFloat16) { | |||
| continue; | |||
| } | |||
| // Cast cannot fuse with its input | |||
| if (IsFusibleOp((cast_node->cast<CNodePtr>())->input(1))) { | |||
| continue; | |||
| } | |||
| auto user_index_set = mng->node_users()[cast_node]; | |||
| // Case1 : Cast is only used by matmul | |||
| if (user_index_set.size() == 1) { | |||
| mng->Replace(cast_node, (cast_node->cast<CNodePtr>())->input(1)); | |||
| changed = true; | |||
| continue; | |||
| } | |||
| // Case2 : Cast is used by matmul and Upstatus | |||
| if (user_index_set.size() > 2) { | |||
| continue; | |||
| } | |||
| for (auto user_index : user_index_set) { | |||
| // Exclude when UpdateStatus-> ... ->matmul path is found | |||
| if (IsPrimitiveCNode(user_index.first, prim::kPrimUpdateState) && !HasPath(user_index.first, node, mng)) { | |||
| auto update_state = (user_index.first)->cast<CNodePtr>(); | |||
| update_state->set_input(2, node); | |||
| cnode->set_input(4, (cast_node->cast<CNodePtr>())->input(1)); | |||
| mng->RemoveRoots(); | |||
| mng->KeepRoots({func_graph}); | |||
| changed = true; | |||
| } | |||
| } | |||
| } | |||
| return changed; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -13,8 +13,8 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_OPTIMIZE_MATMUL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_OPTIMIZE_MATMUL_H_ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CAST_MATMUL_FUSION_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CAST_MATMUL_FUSION_H_ | |||
| #include <map> | |||
| #include <memory> | |||
| @@ -24,13 +24,13 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class OptimizeMatmul : public Pass { | |||
| class CastMatmulFusion : public Pass { | |||
| public: | |||
| OptimizeMatmul() : Pass("optimize_matmul") {} | |||
| ~OptimizeMatmul() override = default; | |||
| CastMatmulFusion() : Pass("cast_matmul_fusion") {} | |||
| ~CastMatmulFusion() override = default; | |||
| bool Run(const FuncGraphPtr &graph) override; | |||
| }; | |||
| using OptimizeMatmulPtr = std::shared_ptr<OptimizeMatmul>; | |||
| using OptimizeMatmulPtr = std::shared_ptr<CastMatmulFusion>; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_OPTIMIZE_MATMUL_H_ | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CAST_MATMUL_FUSION_H_ | |||
| @@ -627,6 +627,20 @@ bool IsBasicFuseOp(const AnfNodePtr &node) { | |||
| [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); | |||
| } | |||
| bool IsFusibleOp(const AnfNodePtr &node) { | |||
| #if ENABLE_D | |||
| const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward", | |||
| "LambNextMV", "LambUpdateWithLR"}; | |||
| if (AnfAlgo::IsGraphKernel(node)) { | |||
| auto fg_attr = AnfAlgo::GetCNodeFuncGraphPtr(node)->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); | |||
| if (fg_attr != nullptr) { | |||
| return graph_kernel_black_list.count(GetValue<std::string>(fg_attr)) == 0; | |||
| } | |||
| } | |||
| #endif | |||
| return IsBasicFuseOp(node) || AnfAlgo::IsGraphKernel(node); | |||
| } | |||
| void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| @@ -73,6 +73,7 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo | |||
| std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix = "", const string &postfix = ""); | |||
| std::vector<PrimitivePtr> GetFusibleOpList(); | |||
| bool IsBasicFuseOp(const AnfNodePtr &node); | |||
| bool IsFusibleOp(const AnfNodePtr &node); | |||
| void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE); | |||
| void InitDependPrior(const std::vector<AnfNodePtr> &todos, | |||
| std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior); | |||
| @@ -30,7 +30,7 @@ | |||
| #include "backend/optimizer/graph_kernel/tensor_promotion.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_splitter.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_expander.h" | |||
| #include "backend/optimizer/graph_kernel/optimize_matmul.h" | |||
| #include "backend/optimizer/graph_kernel/cast_matmul_fusion.h" | |||
| #include "backend/optimizer/graph_kernel/raise_reduction_precision.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_cse.h" | |||
| #include "backend/optimizer/graph_kernel/shape_ops_splitter.h" | |||
| @@ -52,7 +52,7 @@ PassManagerPtr GraphKernelOptimizer::PreProcess() { | |||
| if (is_ascend) { | |||
| // Remove redundant Cast(bias, fp16) for Matmul input | |||
| pm->AddPass(std::make_shared<OptimizeMatmul>()); | |||
| pm->AddPass(std::make_shared<CastMatmulFusion>()); | |||
| // Reorder TransData-Cast to Cast-TransData | |||
| pm->AddPass(std::make_shared<ReorderOps>()); | |||
| @@ -1,64 +0,0 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/graph_kernel/optimize_matmul.h" | |||
| #include <tuple> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| /* MatMul supports fp32 bias, so remove the redundant cast when cast only used by MatMul | |||
| * | |||
| * %0 = cast(bias_fp32, fp16) | |||
| * %1 = MatMul(A_fp16, B_fp16, %0) | |||
| * ------> | |||
| * %1 = MatMul(A_fp16, B_fp16, bias_fp32) | |||
| */ | |||
| bool OptimizeMatmul::Run(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto mng = func_graph->manager(); | |||
| if (mng == nullptr) { | |||
| mng = Manage(func_graph, true); | |||
| func_graph->set_manager(mng); | |||
| } | |||
| auto changed = false; | |||
| auto nodes = TopoSort(func_graph->get_return()); | |||
| for (auto node : nodes) { | |||
| if (!IsPrimitiveCNode(node, prim::kPrimMatMul)) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode->size() != 4) { | |||
| continue; | |||
| } | |||
| auto cast_node = cnode->input(3); | |||
| if (!IsPrimitiveCNode(cast_node, prim::kPrimCast)) { | |||
| continue; | |||
| } | |||
| auto cast_input_type = AnfAlgo::GetInputDeviceDataType(cast_node, 0); | |||
| auto cast_output_type = AnfAlgo::GetOutputDeviceDataType(cast_node, 0); | |||
| if (cast_input_type == kNumberTypeFloat32 && cast_output_type == kNumberTypeFloat16 && | |||
| mng->node_users()[cast_node].size() == 1) { | |||
| mng->Replace(cast_node, (cast_node->cast<CNodePtr>())->input(1)); | |||
| changed = true; | |||
| } | |||
| } | |||
| return changed; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||