Merge pull request !289 from YuJianfeng/find_optags/v0.2.0-alpha
| @@ -15,43 +15,9 @@ | |||||
| */ | */ | ||||
| #include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h" | #include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h" | ||||
| #include "pre_activate/common/helper.h" | #include "pre_activate/common/helper.h" | ||||
| #include "utils/utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace { | |||||
| void GetAdd0AndAdd1(const AnfNodePtr &sub0, AnfNodePtr *add0, AnfNodePtr *add1) { | |||||
| MS_EXCEPTION_IF_NULL(sub0); | |||||
| MS_EXCEPTION_IF_NULL(add0); | |||||
| MS_EXCEPTION_IF_NULL(add1); | |||||
| auto sub0_cnode = sub0->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(sub0_cnode); | |||||
| CheckCNodeInputSize(sub0_cnode, kSubInputNum); | |||||
| AnfNodePtr mul4 = sub0_cnode->input(2); | |||||
| MS_EXCEPTION_IF_NULL(mul4); | |||||
| auto mul4_cnode = mul4->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(mul4_cnode); | |||||
| CheckCNodeInputSize(mul4_cnode, kMulInputNum); | |||||
| AnfNodePtr true_div0 = mul4_cnode->input(2); | |||||
| MS_EXCEPTION_IF_NULL(true_div0); | |||||
| auto true_div0_cnode = true_div0->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(true_div0_cnode); | |||||
| CheckCNodeInputSize(true_div0_cnode, kRealDivInputNum); | |||||
| *add0 = true_div0_cnode->input(1); | |||||
| AnfNodePtr add2 = true_div0_cnode->input(2); | |||||
| MS_EXCEPTION_IF_NULL(add2); | |||||
| auto add2_cnode = add2->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(add2_cnode); | |||||
| CheckCNodeInputSize(add2_cnode, kAddInputNum); | |||||
| AnfNodePtr sqrt0 = add2_cnode->input(1); | |||||
| MS_EXCEPTION_IF_NULL(sqrt0); | |||||
| auto sqrt0_cnode = sqrt0->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(sqrt0_cnode); | |||||
| CheckCNodeInputSize(sqrt0_cnode, kSqrtInputNum); | |||||
| *add1 = sqrt0_cnode->input(1); | |||||
| } | |||||
| } // namespace | |||||
| AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const { | AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const { | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(equiv); | MS_EXCEPTION_IF_NULL(equiv); | ||||
| @@ -79,10 +45,10 @@ const BaseRef AdamApplyOneFusion::DefinePattern() const { | |||||
| const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName); | const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName); | ||||
| VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); | VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); | ||||
| VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); | VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); | ||||
| VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({prim::kPrimTensorAdd, mul2, mul3})}); | |||||
| VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); | |||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | ||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | ||||
| VectorRef add0 = VectorRef({prim::kPrimTensorAdd, mul0, mul1}); | |||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | |||||
| VectorRef true_div0 = VectorRef({prim_deal_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); | VectorRef true_div0 = VectorRef({prim_deal_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); | ||||
| return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); | return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); | ||||
| } | } | ||||
| @@ -96,10 +62,17 @@ const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, con | |||||
| new_node->set_scope(node->scope()); | new_node->set_scope(node->scope()); | ||||
| // Set abstract of new node | // Set abstract of new node | ||||
| AbstractBasePtrList new_node_abstract_list; | AbstractBasePtrList new_node_abstract_list; | ||||
| AnfNodePtr add0 = nullptr; | |||||
| AnfNodePtr add1 = nullptr; | |||||
| GetAdd0AndAdd1(node, &add0, &add1); | |||||
| auto iter_add0 = (*equiv).find(add0_var_); | |||||
| if (iter_add0 == (*equiv).end()) { | |||||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add0 var after matched."; | |||||
| } | |||||
| auto iter_add1 = (*equiv).find(add1_var_); | |||||
| if (iter_add1 == (*equiv).end()) { | |||||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add1 var after matched."; | |||||
| } | |||||
| auto add0 = utils::cast<AnfNodePtr>(iter_add0->second); | |||||
| MS_EXCEPTION_IF_NULL(add0); | MS_EXCEPTION_IF_NULL(add0); | ||||
| auto add1 = utils::cast<AnfNodePtr>(iter_add1->second); | |||||
| MS_EXCEPTION_IF_NULL(add1); | MS_EXCEPTION_IF_NULL(add1); | ||||
| new_node_abstract_list.push_back(add1->abstract()); | new_node_abstract_list.push_back(add1->abstract()); | ||||
| new_node_abstract_list.push_back(add0->abstract()); | new_node_abstract_list.push_back(add0->abstract()); | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "pre_activate/common/optimizer.h" | #include "pre_activate/common/optimizer.h" | ||||
| #include "utils/utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -35,6 +36,8 @@ class AdamApplyOneFusion : public PatternProcessPass { | |||||
| mul_x_input_vars_.push_back(std::make_shared<Var>()); | mul_x_input_vars_.push_back(std::make_shared<Var>()); | ||||
| } | } | ||||
| add2_y_ = std::make_shared<Var>(); | add2_y_ = std::make_shared<Var>(); | ||||
| add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | |||||
| add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | |||||
| } | } | ||||
| ~AdamApplyOneFusion() override = default; | ~AdamApplyOneFusion() override = default; | ||||
| @@ -46,6 +49,8 @@ class AdamApplyOneFusion : public PatternProcessPass { | |||||
| std::vector<VarPtr> input_vars_; | std::vector<VarPtr> input_vars_; | ||||
| std::vector<VarPtr> mul_x_input_vars_; | std::vector<VarPtr> mul_x_input_vars_; | ||||
| VarPtr add2_y_; | VarPtr add2_y_; | ||||
| VarPtr add0_var_; | |||||
| VarPtr add1_var_; | |||||
| }; | }; | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -17,48 +17,13 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include <tuple> | |||||
| #include "session/anf_runtime_algorithm.h" | #include "session/anf_runtime_algorithm.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| #include "utils/utils.h" | |||||
| #include "pre_activate/common/helper.h" | #include "pre_activate/common/helper.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace { | |||||
| std::tuple<AnfNodePtr, AnfNodePtr> GetAdd0Add1Node(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto sub0 = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(sub0); | |||||
| auto mul5_anf = sub0->input(2); | |||||
| MS_EXCEPTION_IF_NULL(mul5_anf); | |||||
| auto mul5 = mul5_anf->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(mul5); | |||||
| auto add3_anf = mul5->input(2); | |||||
| MS_EXCEPTION_IF_NULL(add3_anf); | |||||
| auto add3 = add3_anf->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(add3); | |||||
| auto real_div0_anf = add3->input(1); | |||||
| MS_EXCEPTION_IF_NULL(real_div0_anf); | |||||
| auto real_div0 = real_div0_anf->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(real_div0); | |||||
| auto add0_anf = real_div0->input(1); | |||||
| MS_EXCEPTION_IF_NULL(add0_anf); | |||||
| auto add2_anf = real_div0->input(2); | |||||
| MS_EXCEPTION_IF_NULL(add2_anf); | |||||
| auto add2 = add2_anf->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(add2); | |||||
| auto sqrt0_anf = add2->input(1); | |||||
| MS_EXCEPTION_IF_NULL(sqrt0_anf); | |||||
| auto sqrt0 = sqrt0_anf->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(sqrt0); | |||||
| auto add1_anf = sqrt0->input(1); | |||||
| MS_EXCEPTION_IF_NULL(add1_anf); | |||||
| return std::make_tuple(add0_anf, add1_anf); | |||||
| } | |||||
| } // namespace | |||||
| std::vector<AnfNodePtr> AdamApplyOneWithDecayRule::GetFusionNodeInputs(const EquivPtr &equiv) const { | std::vector<AnfNodePtr> AdamApplyOneWithDecayRule::GetFusionNodeInputs(const EquivPtr &equiv) const { | ||||
| MS_EXCEPTION_IF_NULL(equiv); | MS_EXCEPTION_IF_NULL(equiv); | ||||
| auto input0 = utils::cast<AnfNodePtr>((*equiv)[input0_]); | auto input0 = utils::cast<AnfNodePtr>((*equiv)[input0_]); | ||||
| @@ -82,10 +47,10 @@ const BaseRef AdamApplyOneWithDecayRule::DefinePattern() const { | |||||
| VectorRef mul0_pattern({prim::kPrimMul, mul0_x_, input2_}); | VectorRef mul0_pattern({prim::kPrimMul, mul0_x_, input2_}); | ||||
| VectorRef mul1_pattern({prim::kPrimMul, mul1_x_, input0_}); | VectorRef mul1_pattern({prim::kPrimMul, mul1_x_, input0_}); | ||||
| VectorRef square0_pattern({prim::kPrimSquare, input0_}); | VectorRef square0_pattern({prim::kPrimSquare, input0_}); | ||||
| VectorRef add0_pattern({prim::kPrimTensorAdd, mul0_pattern, mul1_pattern}); | |||||
| VectorRef add0_pattern({add0_var_, mul0_pattern, mul1_pattern}); | |||||
| VectorRef mul2_pattern({prim::kPrimMul, mul2_x_, input1_}); | VectorRef mul2_pattern({prim::kPrimMul, mul2_x_, input1_}); | ||||
| VectorRef mul3_pattern({prim::kPrimMul, mul3_x_, square0_pattern}); | VectorRef mul3_pattern({prim::kPrimMul, mul3_x_, square0_pattern}); | ||||
| VectorRef add1_pattern({prim::kPrimTensorAdd, mul2_pattern, mul3_pattern}); | |||||
| VectorRef add1_pattern({add1_var_, mul2_pattern, mul3_pattern}); | |||||
| VectorRef sqrt0_pattern({sqrt, add1_pattern}); | VectorRef sqrt0_pattern({sqrt, add1_pattern}); | ||||
| VectorRef add2_pattern({prim::kPrimTensorAdd, sqrt0_pattern, add2_y_}); | VectorRef add2_pattern({prim::kPrimTensorAdd, sqrt0_pattern, add2_y_}); | ||||
| VectorRef mul4_pattern({prim::kPrimMul, mul4_x_, input3_}); | VectorRef mul4_pattern({prim::kPrimMul, mul4_x_, input3_}); | ||||
| @@ -107,9 +72,18 @@ const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, c | |||||
| MS_EXCEPTION_IF_NULL(fusion_node); | MS_EXCEPTION_IF_NULL(fusion_node); | ||||
| fusion_node->set_scope(node->scope()); | fusion_node->set_scope(node->scope()); | ||||
| AnfNodePtr add0 = nullptr; | |||||
| AnfNodePtr add1 = nullptr; | |||||
| std::tie(add0, add1) = GetAdd0Add1Node(node); | |||||
| auto iter_add0 = (*equiv).find(add0_var_); | |||||
| if (iter_add0 == (*equiv).end()) { | |||||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add0 var after matched."; | |||||
| } | |||||
| auto iter_add1 = (*equiv).find(add1_var_); | |||||
| if (iter_add1 == (*equiv).end()) { | |||||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add1 var after matched."; | |||||
| } | |||||
| auto add0 = utils::cast<AnfNodePtr>(iter_add0->second); | |||||
| MS_EXCEPTION_IF_NULL(add0); | |||||
| auto add1 = utils::cast<AnfNodePtr>(iter_add1->second); | |||||
| MS_EXCEPTION_IF_NULL(add1); | |||||
| auto types = {AnfAlgo::GetOutputInferDataType(add1, 0), AnfAlgo::GetOutputInferDataType(add0, 0), | auto types = {AnfAlgo::GetOutputInferDataType(add1, 0), AnfAlgo::GetOutputInferDataType(add0, 0), | ||||
| AnfAlgo::GetOutputInferDataType(node, 0)}; | AnfAlgo::GetOutputInferDataType(node, 0)}; | ||||
| auto shapes = {AnfAlgo::GetOutputInferShape(add1, 0), AnfAlgo::GetOutputInferShape(add0, 0), | auto shapes = {AnfAlgo::GetOutputInferShape(add1, 0), AnfAlgo::GetOutputInferShape(add0, 0), | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "pre_activate/common/optimizer.h" | #include "pre_activate/common/optimizer.h" | ||||
| #include "utils/utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| class AdamApplyOneWithDecayRule : public PatternProcessPass { | class AdamApplyOneWithDecayRule : public PatternProcessPass { | ||||
| @@ -36,6 +37,8 @@ class AdamApplyOneWithDecayRule : public PatternProcessPass { | |||||
| mul3_x_ = std::make_shared<Var>(); | mul3_x_ = std::make_shared<Var>(); | ||||
| mul4_x_ = std::make_shared<Var>(); | mul4_x_ = std::make_shared<Var>(); | ||||
| add2_y_ = std::make_shared<Var>(); | add2_y_ = std::make_shared<Var>(); | ||||
| add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | |||||
| add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | |||||
| } | } | ||||
| ~AdamApplyOneWithDecayRule() override = default; | ~AdamApplyOneWithDecayRule() override = default; | ||||
| const BaseRef DefinePattern() const override; | const BaseRef DefinePattern() const override; | ||||
| @@ -54,6 +57,8 @@ class AdamApplyOneWithDecayRule : public PatternProcessPass { | |||||
| VarPtr mul3_x_; | VarPtr mul3_x_; | ||||
| VarPtr mul4_x_; | VarPtr mul4_x_; | ||||
| VarPtr add2_y_; | VarPtr add2_y_; | ||||
| VarPtr add0_var_; | |||||
| VarPtr add1_var_; | |||||
| }; | }; | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,36 +16,9 @@ | |||||
| #include "pre_activate/ascend/ir_fusion/lamb_next_right_rule.h" | #include "pre_activate/ascend/ir_fusion/lamb_next_right_rule.h" | ||||
| #include <vector> | #include <vector> | ||||
| #include "pre_activate/common/helper.h" | #include "pre_activate/common/helper.h" | ||||
| #include "utils/utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace { | |||||
| AnfNodePtr GetAdd1Node(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto add2_cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(add2_cnode); | |||||
| if (add2_cnode->inputs().size() != kAddInputNum) { | |||||
| MS_LOG(ERROR) << "The input size of Add2 is not equal to " << kAddInputNum; | |||||
| } | |||||
| AnfNodePtr sqrt0 = add2_cnode->input(1); | |||||
| MS_EXCEPTION_IF_NULL(sqrt0); | |||||
| auto sqrt0_cnode = sqrt0->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(sqrt0_cnode); | |||||
| if (sqrt0_cnode->inputs().size() != kSqrtInputNum) { | |||||
| MS_LOG(ERROR) << "The input size of Sqrt0 is not equal to " << kSqrtInputNum; | |||||
| } | |||||
| AnfNodePtr real_div1 = sqrt0_cnode->input(1); | |||||
| MS_EXCEPTION_IF_NULL(real_div1); | |||||
| auto real_div1_cnode = real_div1->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(real_div1_cnode); | |||||
| if (real_div1_cnode->inputs().size() != kMulInputNum) { | |||||
| MS_LOG(ERROR) << "The input size of RealDiv1 is not equal to " << kMulInputNum; | |||||
| } | |||||
| return real_div1_cnode->input(1); | |||||
| } | |||||
| } // namespace | |||||
| AnfNodePtr LambNextRightRule::CreateLambNextRightNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const { | AnfNodePtr LambNextRightRule::CreateLambNextRightNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const { | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(equiv); | MS_EXCEPTION_IF_NULL(equiv); | ||||
| @@ -79,7 +52,7 @@ const BaseRef LambNextRightRule::DefinePattern() const { | |||||
| const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | ||||
| MS_EXCEPTION_IF_NULL(prim_sqrt); | MS_EXCEPTION_IF_NULL(prim_sqrt); | ||||
| VectorRef mul3 = VectorRef({prim::kPrimMul, mul3_x_, VectorRef({prim::kPrimSquare, input0_})}); | VectorRef mul3 = VectorRef({prim::kPrimMul, mul3_x_, VectorRef({prim::kPrimSquare, input0_})}); | ||||
| VectorRef add1 = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, mul2_x_, input1_}), mul3}); | |||||
| VectorRef add1 = VectorRef({add1_var_, VectorRef({prim::kPrimMul, mul2_x_, input1_}), mul3}); | |||||
| return VectorRef( | return VectorRef( | ||||
| {prim::kPrimTensorAdd, VectorRef({prim_sqrt, VectorRef({prim::kPrimMul, add1, true_div1_recip_})}), add2_y_}); | {prim::kPrimTensorAdd, VectorRef({prim_sqrt, VectorRef({prim::kPrimMul, add1, true_div1_recip_})}), add2_y_}); | ||||
| } | } | ||||
| @@ -91,7 +64,11 @@ const AnfNodePtr LambNextRightRule::Process(const FuncGraphPtr &func_graph, cons | |||||
| auto new_node = CreateLambNextRightNode(func_graph, equiv); | auto new_node = CreateLambNextRightNode(func_graph, equiv); | ||||
| MS_EXCEPTION_IF_NULL(new_node); | MS_EXCEPTION_IF_NULL(new_node); | ||||
| // Set abstract of new node | // Set abstract of new node | ||||
| AnfNodePtr add1 = GetAdd1Node(node); | |||||
| auto iter_add1 = (*equiv).find(add1_var_); | |||||
| if (iter_add1 == (*equiv).end()) { | |||||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add1 var after matched."; | |||||
| } | |||||
| auto add1 = utils::cast<AnfNodePtr>(iter_add1->second); | |||||
| MS_EXCEPTION_IF_NULL(add1); | MS_EXCEPTION_IF_NULL(add1); | ||||
| AbstractBasePtrList new_node_abstract_list; | AbstractBasePtrList new_node_abstract_list; | ||||
| new_node_abstract_list.push_back(add1->abstract()); | new_node_abstract_list.push_back(add1->abstract()); | ||||
| @@ -18,6 +18,8 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include "pre_activate/common/optimizer.h" | #include "pre_activate/common/optimizer.h" | ||||
| #include "utils/utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| class LambNextRightRule : public PatternProcessPass { | class LambNextRightRule : public PatternProcessPass { | ||||
| @@ -29,7 +31,8 @@ class LambNextRightRule : public PatternProcessPass { | |||||
| mul2_x_(std::make_shared<Var>()), | mul2_x_(std::make_shared<Var>()), | ||||
| mul3_x_(std::make_shared<Var>()), | mul3_x_(std::make_shared<Var>()), | ||||
| true_div1_recip_(std::make_shared<Var>()), | true_div1_recip_(std::make_shared<Var>()), | ||||
| add2_y_(std::make_shared<Var>()) {} | |||||
| add2_y_(std::make_shared<Var>()), | |||||
| add1_var_(std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()))) {} | |||||
| ~LambNextRightRule() override = default; | ~LambNextRightRule() override = default; | ||||
| const BaseRef DefinePattern() const override; | const BaseRef DefinePattern() const override; | ||||
| @@ -44,6 +47,7 @@ class LambNextRightRule : public PatternProcessPass { | |||||
| VarPtr mul3_x_; | VarPtr mul3_x_; | ||||
| VarPtr true_div1_recip_; | VarPtr true_div1_recip_; | ||||
| VarPtr add2_y_; | VarPtr add2_y_; | ||||
| VarPtr add1_var_; | |||||
| }; | }; | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -30,7 +30,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace { | namespace { | ||||
| AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, bool multigraph); | |||||
| AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, | |||||
| bool multigraph); | |||||
| ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) { | ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) { | ||||
| if (utils::isa<int>(sexp)) { | if (utils::isa<int>(sexp)) { | ||||
| @@ -71,12 +72,20 @@ VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, bool multigraph = false) { | |||||
| AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, | |||||
| bool multigraph = false) { | |||||
| MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString(); | MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString(); | ||||
| MS_EXCEPTION_IF_NULL(primitive_vars); | |||||
| if (utils::isa<VectorRef>(sexp)) { | if (utils::isa<VectorRef>(sexp)) { | ||||
| return HandleSexpVector(sexp, graph, multigraph); | |||||
| return HandleSexpVector(sexp, graph, primitive_vars, multigraph); | |||||
| } | } | ||||
| if (utils::isa<VarPtr>(sexp)) { | if (utils::isa<VarPtr>(sexp)) { | ||||
| auto var_ptr = utils::cast<VarPtr>(sexp); | |||||
| MS_EXCEPTION_IF_NULL(var_ptr); | |||||
| if (var_ptr->primitive()) { | |||||
| (*primitive_vars)[var_ptr->primitive()] = var_ptr; | |||||
| return NewValueNode(var_ptr->primitive()); | |||||
| } | |||||
| return CreateVarNodeWithSexp(sexp, graph); | return CreateVarNodeWithSexp(sexp, graph); | ||||
| } | } | ||||
| if (utils::isa<AnfNodePtr>(sexp)) { | if (utils::isa<AnfNodePtr>(sexp)) { | ||||
| @@ -89,13 +98,14 @@ AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, bool multigraph | |||||
| return value_node; | return value_node; | ||||
| } | } | ||||
| AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, bool multigraph) { | |||||
| AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, | |||||
| bool multigraph) { | |||||
| MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString(); | MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString(); | ||||
| std::vector<AnfNodePtr> input_nodes; | std::vector<AnfNodePtr> input_nodes; | ||||
| const auto &tuple = utils::cast<VectorRef>(sexp); | const auto &tuple = utils::cast<VectorRef>(sexp); | ||||
| if (multigraph && utils::isa<VarPtr>(graph)) { | if (multigraph && utils::isa<VarPtr>(graph)) { | ||||
| for (auto &x : tuple) { | for (auto &x : tuple) { | ||||
| AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), true); | |||||
| AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true); | |||||
| input_nodes.push_back(node); | input_nodes.push_back(node); | ||||
| } | } | ||||
| VarPtr var_ptr = utils::cast<VarPtr>(graph); | VarPtr var_ptr = utils::cast<VarPtr>(graph); | ||||
| @@ -103,7 +113,7 @@ AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, bool mult | |||||
| } | } | ||||
| for (auto &x : tuple) { | for (auto &x : tuple) { | ||||
| AnfNodePtr node = SexpToNode(x, graph, multigraph); | |||||
| AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph); | |||||
| input_nodes.push_back(node); | input_nodes.push_back(node); | ||||
| } | } | ||||
| return CreateCNodeWithGraph(input_nodes, graph); | return CreateCNodeWithGraph(input_nodes, graph); | ||||
| @@ -166,7 +176,8 @@ PatternProcessPass::PatternProcessPass(const std::string &name, bool multigraph) | |||||
| multigraph_(multigraph), | multigraph_(multigraph), | ||||
| pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(), | pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(), | ||||
| std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual), | std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual), | ||||
| std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))) {} | |||||
| std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))), | |||||
| primitive_vars_(std::make_shared<PrimitiveVarMap>()) {} | |||||
| const BaseRef PatternProcessPass::DefinePattern() const { | const BaseRef PatternProcessPass::DefinePattern() const { | ||||
| VarPtr X = std::make_shared<Var>(); | VarPtr X = std::make_shared<Var>(); | ||||
| @@ -176,7 +187,7 @@ const BaseRef PatternProcessPass::DefinePattern() const { | |||||
| void PatternProcessPass::Build() { | void PatternProcessPass::Build() { | ||||
| VarPtr fg = std::make_shared<Var>("RootG"); | VarPtr fg = std::make_shared<Var>("RootG"); | ||||
| BaseRef pattern = std::move(DefinePattern()); | BaseRef pattern = std::move(DefinePattern()); | ||||
| pattern_ = SexpToNode(pattern, fg, multigraph_); | |||||
| pattern_ = SexpToNode(pattern, fg, primitive_vars_.get(), multigraph_); | |||||
| } | } | ||||
| AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | ||||
| @@ -185,7 +196,8 @@ AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNode | |||||
| } | } | ||||
| auto empty_equiv = std::make_shared<Equiv>(); | auto empty_equiv = std::make_shared<Equiv>(); | ||||
| EquivPtr equiv = pattern_engine_.Match(pattern_, node, empty_equiv); | |||||
| MS_EXCEPTION_IF_NULL(primitive_vars_); | |||||
| EquivPtr equiv = pattern_engine_.Match(pattern_, node, *primitive_vars_, empty_equiv); | |||||
| if (equiv != nullptr && !equiv->empty()) { | if (equiv != nullptr && !equiv->empty()) { | ||||
| return Process(func_graph, node, equiv); | return Process(func_graph, node, equiv); | ||||
| } | } | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include <unordered_map> | |||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| @@ -46,6 +47,7 @@ class PatternProcessPass : public NodePass { | |||||
| AnfNodePtr pattern_ = nullptr; | AnfNodePtr pattern_ = nullptr; | ||||
| bool multigraph_ = true; | bool multigraph_ = true; | ||||
| PatternEngine pattern_engine_; | PatternEngine pattern_engine_; | ||||
| PrimitiveVarMapPtr primitive_vars_; | |||||
| }; | }; | ||||
| class GraphOptimizer { | class GraphOptimizer { | ||||
| @@ -42,7 +42,7 @@ void Var::EnsureTag() { | |||||
| } | } | ||||
| } | } | ||||
| bool operator==(const VarPtr& lhs, const VarPtr& rhs) { | |||||
| bool operator==(const VarPtr &lhs, const VarPtr &rhs) { | |||||
| if (lhs->isa<CondVar>() && rhs->isa<CondVar>()) { | if (lhs->isa<CondVar>() && rhs->isa<CondVar>()) { | ||||
| CondVarPtr v1 = dyn_cast<CondVar>(lhs); | CondVarPtr v1 = dyn_cast<CondVar>(lhs); | ||||
| CondVarPtr v2 = dyn_cast<CondVar>(rhs); | CondVarPtr v2 = dyn_cast<CondVar>(rhs); | ||||
| @@ -63,7 +63,7 @@ std::string SeqVar::ToString() const { | |||||
| return buffer.str(); | return buffer.str(); | ||||
| } | } | ||||
| std::ostream& operator<<(std::ostream& os, const VarPtr& var) { | |||||
| std::ostream &operator<<(std::ostream &os, const VarPtr &var) { | |||||
| if (var == nullptr) { | if (var == nullptr) { | ||||
| os << ""; | os << ""; | ||||
| } else { | } else { | ||||
| @@ -73,10 +73,10 @@ std::ostream& operator<<(std::ostream& os, const VarPtr& var) { | |||||
| } | } | ||||
| template <> | template <> | ||||
| std::ostream& operator<<<VarPtr, BaseRef>(std::ostream& os, const Equiv& equiv) { | |||||
| std::ostream &operator<<<VarPtr, BaseRef>(std::ostream &os, const Equiv &equiv) { | |||||
| os << "[Equiv]" | os << "[Equiv]" | ||||
| << "\n"; | << "\n"; | ||||
| for (auto& equiv_item : equiv) { | |||||
| for (auto &equiv_item : equiv) { | |||||
| auto k = equiv_item.first; | auto k = equiv_item.first; | ||||
| os << k << ":"; | os << k << ":"; | ||||
| BaseRef x = equiv_item.second; | BaseRef x = equiv_item.second; | ||||
| @@ -104,7 +104,7 @@ std::ostream& operator<<<VarPtr, BaseRef>(std::ostream& os, const Equiv& equiv) | |||||
| return os; | return os; | ||||
| } | } | ||||
| static BaseRef GetVar(const BaseRef& x) { | |||||
| static BaseRef GetVar(const BaseRef &x) { | |||||
| MS_LOG(DEBUG) << "getVar start :%s" + x.ToString(); | MS_LOG(DEBUG) << "getVar start :%s" + x.ToString(); | ||||
| if (utils::isa<AnfNodePtr>(x)) { | if (utils::isa<AnfNodePtr>(x)) { | ||||
| auto node = utils::cast<AnfNodePtr>(x); | auto node = utils::cast<AnfNodePtr>(x); | ||||
| @@ -129,7 +129,7 @@ static BaseRef GetVar(const BaseRef& x) { | |||||
| return x; | return x; | ||||
| } | } | ||||
| EquivPtr MatchOnVar(const BaseRef& pattern, const BaseRef& expr, EquivPtr equiv) { | |||||
| EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv) { | |||||
| MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString(); | MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString(); | ||||
| MS_EXCEPTION_IF_NULL(equiv); | MS_EXCEPTION_IF_NULL(equiv); | ||||
| if (utils::isa<VarPtr>(pattern)) { | if (utils::isa<VarPtr>(pattern)) { | ||||
| @@ -144,8 +144,8 @@ EquivPtr MatchOnVar(const BaseRef& pattern, const BaseRef& expr, EquivPtr equiv) | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| bool PatternEngine::ToVector(const VectorRef& pattern_ref, const VectorRef& expr_ref, VectorRef* const values_pattern, | |||||
| VectorRef* const values_expr) const { | |||||
| bool PatternEngine::ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern, | |||||
| VectorRef *const values_expr) const { | |||||
| MS_EXCEPTION_IF_NULL(values_expr); | MS_EXCEPTION_IF_NULL(values_expr); | ||||
| if (utils::isa<SeqPtr>(pattern_ref)) { | if (utils::isa<SeqPtr>(pattern_ref)) { | ||||
| *values_pattern = pattern_ref; | *values_pattern = pattern_ref; | ||||
| @@ -155,12 +155,12 @@ bool PatternEngine::ToVector(const VectorRef& pattern_ref, const VectorRef& expr | |||||
| return false; | return false; | ||||
| } | } | ||||
| bool PatternEngine::ToVector(const BaseRef& pattern_ref, const BaseRef& expr_ref, VectorRef* const values_pattern, | |||||
| VectorRef* const values_expr) const { | |||||
| bool PatternEngine::ToVector(const BaseRef &pattern_ref, const BaseRef &expr_ref, VectorRef *const values_pattern, | |||||
| VectorRef *const values_expr) const { | |||||
| MS_EXCEPTION_IF_NULL(values_expr); | MS_EXCEPTION_IF_NULL(values_expr); | ||||
| // visitor to visite the list | // visitor to visite the list | ||||
| auto appender_pattern = [](VectorRef& values) { | |||||
| std::function<BaseRef(const BaseRef&)> fn = [&](const BaseRef& u) { | |||||
| auto appender_pattern = [](VectorRef &values) { | |||||
| std::function<BaseRef(const BaseRef &)> fn = [&](const BaseRef &u) { | |||||
| values.push_back(GetVar(u)); | values.push_back(GetVar(u)); | ||||
| return u; | return u; | ||||
| }; | }; | ||||
| @@ -174,8 +174,8 @@ bool PatternEngine::ToVector(const BaseRef& pattern_ref, const BaseRef& expr_ref | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto appender_expr = [](VectorRef& values) { | |||||
| std::function<BaseRef(const BaseRef&)> fn = [&](const BaseRef& u) { | |||||
| auto appender_expr = [](VectorRef &values) { | |||||
| std::function<BaseRef(const BaseRef &)> fn = [&](const BaseRef &u) { | |||||
| values.push_back(u); | values.push_back(u); | ||||
| return u; | return u; | ||||
| }; | }; | ||||
| @@ -187,10 +187,10 @@ bool PatternEngine::ToVector(const BaseRef& pattern_ref, const BaseRef& expr_ref | |||||
| return visitor_->Visit(expr_ref, nullptr); | return visitor_->Visit(expr_ref, nullptr); | ||||
| } | } | ||||
| static int GetSVarStartIndex(const VectorRef& values) { | |||||
| static int GetSVarStartIndex(const VectorRef &values) { | |||||
| int index = -1; | int index = -1; | ||||
| int count = 0; | int count = 0; | ||||
| for (auto& value : values) { | |||||
| for (auto &value : values) { | |||||
| if (utils::isa<VarPtr>(value) && utils::cast<VarPtr>(value)->isa<SeqVar>()) { | if (utils::isa<VarPtr>(value) && utils::cast<VarPtr>(value)->isa<SeqVar>()) { | ||||
| if (index != -1) { | if (index != -1) { | ||||
| MS_LOG(DEBUG) << "Multiple SVars in sequence"; | MS_LOG(DEBUG) << "Multiple SVars in sequence"; | ||||
| @@ -203,7 +203,35 @@ static int GetSVarStartIndex(const VectorRef& values) { | |||||
| return index; | return index; | ||||
| } | } | ||||
| EquivPtr PatternEngine::AlignSVar(const VectorRef& values_pattern, const VectorRef& values_expr, EquivPtr equiv) const { | |||||
| void UpdateEquivMap(const VectorRef &values_pattern, const BaseRef &expr_ref, const PrimitiveVarMap &primitive_vars, | |||||
| EquivPtr equiv) { | |||||
| if (equiv == nullptr || values_pattern.empty() || !utils::isa<AnfNodePtr>(values_pattern[0]) || | |||||
| !utils::isa<AnfNodePtr>(expr_ref)) { | |||||
| return; | |||||
| } | |||||
| auto real_node = utils::cast<AnfNodePtr>(expr_ref); | |||||
| MS_EXCEPTION_IF_NULL(real_node); | |||||
| if (!real_node->isa<CNode>()) { | |||||
| return; | |||||
| } | |||||
| auto prim_node = utils::cast<AnfNodePtr>(values_pattern[0]); | |||||
| MS_EXCEPTION_IF_NULL(prim_node); | |||||
| if (!IsValueNode<Primitive>(prim_node)) { | |||||
| return; | |||||
| } | |||||
| ValuePtr value = GetValueNode(prim_node); | |||||
| MS_EXCEPTION_IF_NULL(value); | |||||
| auto prim = value->cast<PrimitivePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| auto iter = primitive_vars.find(prim); | |||||
| if (iter == primitive_vars.end()) { | |||||
| return; | |||||
| } | |||||
| (*equiv)[iter->second] = real_node; | |||||
| } | |||||
| EquivPtr PatternEngine::AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr, | |||||
| const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const { | |||||
| int svar_index = GetSVarStartIndex(values_pattern); | int svar_index = GetSVarStartIndex(values_pattern); | ||||
| if (svar_index == kInvalidVarIndex) { | if (svar_index == kInvalidVarIndex) { | ||||
| return nullptr; | return nullptr; | ||||
| @@ -229,12 +257,12 @@ EquivPtr PatternEngine::AlignSVar(const VectorRef& values_pattern, const VectorR | |||||
| if (svar_index != -1 && i == IntToSize(svar_index)) { | if (svar_index != -1 && i == IntToSize(svar_index)) { | ||||
| auto seq = | auto seq = | ||||
| std::vector<BaseRef>(values_expr.begin() + svar_index, values_expr.begin() + svar_index + SizeToInt(diff)); | std::vector<BaseRef>(values_expr.begin() + svar_index, values_expr.begin() + svar_index + SizeToInt(diff)); | ||||
| equiv = Match(values_pattern[svar_index], seq, equiv); | |||||
| equiv = Match(values_pattern[svar_index], seq, primitive_vars, equiv); | |||||
| } else { | } else { | ||||
| if (svar_index != -1 && i > IntToSize(svar_index)) { | if (svar_index != -1 && i > IntToSize(svar_index)) { | ||||
| expr_i = i + diff - 1; | expr_i = i + diff - 1; | ||||
| } | } | ||||
| equiv = Match(values_pattern[i], values_expr[expr_i], equiv); | |||||
| equiv = Match(values_pattern[i], values_expr[expr_i], primitive_vars, equiv); | |||||
| } | } | ||||
| if (equiv == nullptr) { | if (equiv == nullptr) { | ||||
| return nullptr; | return nullptr; | ||||
| @@ -243,7 +271,8 @@ EquivPtr PatternEngine::AlignSVar(const VectorRef& values_pattern, const VectorR | |||||
| return equiv; | return equiv; | ||||
| } | } | ||||
| EquivPtr PatternEngine::Match(const BaseRef& pattern, const BaseRef& expr, EquivPtr equiv) const { | |||||
| EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars, | |||||
| EquivPtr equiv) const { | |||||
| MS_LOG(DEBUG) << "-----[in Match]"; | MS_LOG(DEBUG) << "-----[in Match]"; | ||||
| MS_LOG(DEBUG) << "GetVar w"; | MS_LOG(DEBUG) << "GetVar w"; | ||||
| BaseRef pattern_ref = GetVar(pattern); | BaseRef pattern_ref = GetVar(pattern); | ||||
| @@ -292,10 +321,12 @@ EquivPtr PatternEngine::Match(const BaseRef& pattern, const BaseRef& expr, Equiv | |||||
| // 6. if any svar in both side, find the SeqVar index, | // 6. if any svar in both side, find the SeqVar index, | ||||
| // try to pack the Var s in std::vector to a Seq and match elements one by one. | // try to pack the Var s in std::vector to a Seq and match elements one by one. | ||||
| // check svar | // check svar | ||||
| return AlignSVar(values_pattern, values_expr, equiv); | |||||
| equiv = AlignSVar(values_pattern, values_expr, primitive_vars, equiv); | |||||
| UpdateEquivMap(values_pattern, expr_ref, primitive_vars, equiv); | |||||
| return equiv; | |||||
| } | } | ||||
| BaseRef PatternEngine::Replace(const BaseRef& pattern, const EquivPtr& equiv) const { | |||||
| BaseRef PatternEngine::Replace(const BaseRef &pattern, const EquivPtr &equiv) const { | |||||
| MS_EXCEPTION_IF_NULL(equiv); | MS_EXCEPTION_IF_NULL(equiv); | ||||
| MS_LOG(DEBUG) << "-----[in Replace]"; | MS_LOG(DEBUG) << "-----[in Replace]"; | ||||
| BaseRef ref = GetVar(pattern); | BaseRef ref = GetVar(pattern); | ||||
| @@ -304,7 +335,7 @@ BaseRef PatternEngine::Replace(const BaseRef& pattern, const EquivPtr& equiv) co | |||||
| // w is var | // w is var | ||||
| if (utils::isa<VarPtr>(ref)) { | if (utils::isa<VarPtr>(ref)) { | ||||
| const VarPtr& var = utils::cast<VarPtr>(ref); | |||||
| const VarPtr &var = utils::cast<VarPtr>(ref); | |||||
| auto iter = equiv->find(var); | auto iter = equiv->find(var); | ||||
| if (iter != equiv->end()) { | if (iter != equiv->end()) { | ||||
| out = iter->second; | out = iter->second; | ||||
| @@ -316,7 +347,7 @@ BaseRef PatternEngine::Replace(const BaseRef& pattern, const EquivPtr& equiv) co | |||||
| } | } | ||||
| // visitor to visit the list | // visitor to visit the list | ||||
| std::function<BaseRef(BaseRef)> fn = [&, this, equiv](const BaseRef& u) { return Replace(u, equiv); }; | |||||
| std::function<BaseRef(BaseRef)> fn = [&, this, equiv](const BaseRef &u) { return Replace(u, equiv); }; | |||||
| visitor_->SetFn(fn); | visitor_->SetFn(fn); | ||||
| BaseRef visit_out; | BaseRef visit_out; | ||||
| @@ -31,6 +31,7 @@ | |||||
| #include <map> | #include <map> | ||||
| #include <stdexcept> | #include <stdexcept> | ||||
| #include <list> | #include <list> | ||||
| #include <utility> | |||||
| #include "pre_activate/common/visit.h" | #include "pre_activate/common/visit.h" | ||||
| #include "ir/base.h" | #include "ir/base.h" | ||||
| @@ -44,16 +45,19 @@ using CondVarPtr = std::shared_ptr<CondVar>; | |||||
| using SVarPtr = std::shared_ptr<SeqVar>; | using SVarPtr = std::shared_ptr<SeqVar>; | ||||
| const int kInvalidVarIndex = -2; | const int kInvalidVarIndex = -2; | ||||
| using ConditionFunc = std::function<bool(const BaseRef&)>; | |||||
| using ConditionFunc = std::function<bool(const BaseRef &)>; | |||||
| // Base wildcard variable which could match any anf node. | // Base wildcard variable which could match any anf node. | ||||
| class Var : public Base { | class Var : public Base { | ||||
| friend class VarHasher; | friend class VarHasher; | ||||
| public: | public: | ||||
| explicit Var(const std::string& tag = "") : tag_(tag) { EnsureTag(); } | |||||
| Var(const Var& other) : Base(other), tag_(other.tag_) {} | |||||
| virtual Var& operator=(const Var& other) { | |||||
| explicit Var(std::string tag = "") : tag_(std::move(tag)), primitive_(nullptr) { EnsureTag(); } | |||||
| explicit Var(const PrimitivePtr &primitive, std::string tag = "") : tag_(std::move(tag)), primitive_(primitive) { | |||||
| EnsureTag(); | |||||
| } | |||||
| Var(const Var &other) : Base(other), tag_(other.tag_) {} | |||||
| virtual Var &operator=(const Var &other) { | |||||
| if (&other == this) { | if (&other == this) { | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -63,12 +67,13 @@ class Var : public Base { | |||||
| ~Var() override = default; | ~Var() override = default; | ||||
| MS_DECLARE_PARENT(Var, Base); | MS_DECLARE_PARENT(Var, Base); | ||||
| virtual bool matches(const BaseRef&) { return true; } | |||||
| virtual bool matches(const BaseRef &) { return true; } | |||||
| virtual bool operator==(const Var& other) const { return tag_ == other.tag_; } | |||||
| bool operator!=(const Var& other) const { return !(&other == this); } | |||||
| virtual bool operator==(const Var &other) const { return tag_ == other.tag_; } | |||||
| bool operator!=(const Var &other) const { return !(&other == this); } | |||||
| std::string tag() const { return tag_; } | std::string tag() const { return tag_; } | ||||
| PrimitivePtr primitive() const { return primitive_; } | |||||
| std::string ToString() const override { | std::string ToString() const override { | ||||
| std::ostringstream buffer; | std::ostringstream buffer; | ||||
| buffer << "Var(" << tag_ << ")"; | buffer << "Var(" << tag_ << ")"; | ||||
| @@ -80,12 +85,13 @@ class Var : public Base { | |||||
| void EnsureTag(); | void EnsureTag(); | ||||
| std::string tag_; | std::string tag_; | ||||
| PrimitivePtr primitive_; | |||||
| }; | }; | ||||
| // VarNode means variable node, a subclass of AnfNode | // VarNode means variable node, a subclass of AnfNode | ||||
| class VarNode : public AnfNode { | class VarNode : public AnfNode { | ||||
| public: | public: | ||||
| VarNode(const VarPtr& value, const FuncGraphPtr& func_graph) : AnfNode(func_graph), var_(value) {} | |||||
| VarNode(const VarPtr &value, const FuncGraphPtr &func_graph) : AnfNode(func_graph), var_(value) {} | |||||
| ~VarNode() override = default; | ~VarNode() override = default; | ||||
| MS_DECLARE_PARENT(VarNode, AnfNode); | MS_DECLARE_PARENT(VarNode, AnfNode); | ||||
| @@ -95,16 +101,16 @@ using VarNodePtr = std::shared_ptr<VarNode>; | |||||
| class VarHasher { | class VarHasher { | ||||
| public: | public: | ||||
| std::size_t operator()(const Var& var) const { return var.hash(); } | |||||
| std::size_t operator()(const Var &var) const { return var.hash(); } | |||||
| }; | }; | ||||
| // Condition Var, match an anf node when condition function return true. | // Condition Var, match an anf node when condition function return true. | ||||
| class CondVar : public Var { | class CondVar : public Var { | ||||
| public: | public: | ||||
| explicit CondVar(const ConditionFunc& cond) : cond_fn_(cond) {} | |||||
| explicit CondVar(const ConditionFunc &cond) : cond_fn_(cond) {} | |||||
| ~CondVar() override = default; | ~CondVar() override = default; | ||||
| MS_DECLARE_PARENT(CondVar, Var); | MS_DECLARE_PARENT(CondVar, Var); | ||||
| bool matches(const BaseRef& value) override { | |||||
| bool matches(const BaseRef &value) override { | |||||
| MS_LOG(DEBUG) << "CondVarPtr match: " + value.ToString(); | MS_LOG(DEBUG) << "CondVarPtr match: " + value.ToString(); | ||||
| if (utils::isa<Var>(value)) { | if (utils::isa<Var>(value)) { | ||||
| return false; | return false; | ||||
| @@ -124,55 +130,60 @@ class SeqVar : public Var { | |||||
| ~SeqVar() override = default; | ~SeqVar() override = default; | ||||
| MS_DECLARE_PARENT(SeqVar, Var); | MS_DECLARE_PARENT(SeqVar, Var); | ||||
| explicit SeqVar(const VarPtr subvar) : subvar_(nullptr) { subvar_ = subvar; } | explicit SeqVar(const VarPtr subvar) : subvar_(nullptr) { subvar_ = subvar; } | ||||
| bool matches(const BaseRef& value) override { | |||||
| bool matches(const BaseRef &value) override { | |||||
| // match Seq. | // match Seq. | ||||
| if (utils::isa<Seq>(value)) { | if (utils::isa<Seq>(value)) { | ||||
| const Seq& seq = utils::cast<Seq>(value); | |||||
| return std::all_of(seq.begin(), seq.end(), [this](const BaseRef& v) { | |||||
| const Seq &seq = utils::cast<Seq>(value); | |||||
| return std::all_of(seq.begin(), seq.end(), [this](const BaseRef &v) { | |||||
| auto eq = subvar_->matches(v); | auto eq = subvar_->matches(v); | ||||
| return eq; | return eq; | ||||
| }); | }); | ||||
| } | } | ||||
| return false; | return false; | ||||
| } | } | ||||
| bool operator==(const SeqVar& other) const { return *subvar_ == *other.subvar_; } | |||||
| bool operator==(const SeqVar &other) const { return *subvar_ == *other.subvar_; } | |||||
| std::string ToString() const override; | std::string ToString() const override; | ||||
| private: | private: | ||||
| VarPtr subvar_; | VarPtr subvar_; | ||||
| }; | }; | ||||
| bool operator==(const VarPtr& lhs, const VarPtr& rhs); | |||||
| bool operator==(const VarPtr &lhs, const VarPtr &rhs); | |||||
| inline bool operator!=(const VarPtr& lhs, const VarPtr& rhs) { return !(lhs == rhs); } | |||||
| inline bool operator!=(const VarPtr &lhs, const VarPtr &rhs) { return !(lhs == rhs); } | |||||
| std::ostream& operator<<(std::ostream& os, const VarPtr& var); | |||||
| std::ostream &operator<<(std::ostream &os, const VarPtr &var); | |||||
| using Equiv = std::map<VarPtr, BaseRef>; | using Equiv = std::map<VarPtr, BaseRef>; | ||||
| using EquivPtr = std::shared_ptr<Equiv>; | using EquivPtr = std::shared_ptr<Equiv>; | ||||
| using PrimitiveVarMap = std::unordered_map<PrimitivePtr, VarPtr>; | |||||
| using PrimitiveVarMapPtr = std::shared_ptr<PrimitiveVarMap>; | |||||
| inline bool DefaultTypeEq(const BaseRef& x, const BaseRef& y) { return x.type() == y.type(); } | |||||
| inline bool DefaultTypeEq(const BaseRef &x, const BaseRef &y) { return x.type() == y.type(); } | |||||
| class PatternEngine { | class PatternEngine { | ||||
| public: | public: | ||||
| PatternEngine(const std::shared_ptr<Visitor>& visitor, const std::function<bool(const BaseRef&, const BaseRef&)>& eq, | |||||
| const std::function<bool(const BaseRef&, const BaseRef&)>& type_eq = DefaultTypeEq) | |||||
| PatternEngine(const std::shared_ptr<Visitor> &visitor, | |||||
| const std::function<bool(const BaseRef &, const BaseRef &)> &eq, | |||||
| const std::function<bool(const BaseRef &, const BaseRef &)> &type_eq = DefaultTypeEq) | |||||
| : visitor_(visitor), eq_(eq), type_eq_(type_eq) {} | : visitor_(visitor), eq_(eq), type_eq_(type_eq) {} | ||||
| ~PatternEngine() = default; | ~PatternEngine() = default; | ||||
| EquivPtr Match(const BaseRef& pattern, const BaseRef& expr, EquivPtr equiv) const; | |||||
| EquivPtr Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars, | |||||
| EquivPtr equiv) const; | |||||
| // Replace pattern with equivalent | // Replace pattern with equivalent | ||||
| BaseRef Replace(const BaseRef& pattern, const EquivPtr& equiv) const; | |||||
| BaseRef Replace(const BaseRef &pattern, const EquivPtr &equiv) const; | |||||
| private: | private: | ||||
| EquivPtr AlignSVar(const VectorRef& values_pattern, const VectorRef& values_expr, EquivPtr equiv) const; | |||||
| bool ToVector(const BaseRef& pattern, const BaseRef& expr, VectorRef* const values_pattern, | |||||
| VectorRef* const values_expr) const; | |||||
| bool ToVector(const VectorRef& pattern_ref, const VectorRef& expr_ref, VectorRef* const values_pattern, | |||||
| VectorRef* const values_expr) const; | |||||
| EquivPtr AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr, | |||||
| const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const; | |||||
| bool ToVector(const BaseRef &pattern, const BaseRef &expr, VectorRef *const values_pattern, | |||||
| VectorRef *const values_expr) const; | |||||
| bool ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern, | |||||
| VectorRef *const values_expr) const; | |||||
| std::shared_ptr<Visitor> visitor_; | std::shared_ptr<Visitor> visitor_; | ||||
| std::function<bool(const BaseRef&, const BaseRef&)> eq_; | |||||
| std::function<bool(const BaseRef&, const BaseRef&)> type_eq_; | |||||
| std::function<bool(const BaseRef &, const BaseRef &)> eq_; | |||||
| std::function<bool(const BaseRef &, const BaseRef &)> type_eq_; | |||||
| }; | }; | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| namespace std { | namespace std { | ||||
| @@ -40,6 +40,7 @@ class TestMatchEngine : public UT::Common { | |||||
| public: | public: | ||||
| PatternEngine TU; | PatternEngine TU; | ||||
| EquivPtr equiv_null; | EquivPtr equiv_null; | ||||
| PrimitiveVarMap primitive_vars_null; | |||||
| }; | }; | ||||
| TEST_F(TestMatchEngine, Var) { | TEST_F(TestMatchEngine, Var) { | ||||
| @@ -106,30 +107,30 @@ TEST_F(TestMatchEngine, MatchRaw_Var) { | |||||
| // common | // common | ||||
| equiv_null->clear(); | equiv_null->clear(); | ||||
| d = TU.Match(v1, 1, equiv_null); | |||||
| d = TU.Match(v1, 1, primitive_vars_null, equiv_null); | |||||
| ASSERT_EQ((*d)[v1], 1); | ASSERT_EQ((*d)[v1], 1); | ||||
| equiv_null->clear(); | equiv_null->clear(); | ||||
| (*equiv_null)[v1] = v2; | (*equiv_null)[v1] = v2; | ||||
| d = TU.Match(v1, 1, equiv_null); | |||||
| d = TU.Match(v1, 1, primitive_vars_null, equiv_null); | |||||
| ASSERT_EQ(d->count(v2), std::size_t(1)); | ASSERT_EQ(d->count(v2), std::size_t(1)); | ||||
| ASSERT_EQ((*d)[v2], 1); | ASSERT_EQ((*d)[v2], 1); | ||||
| equiv_null->clear(); | equiv_null->clear(); | ||||
| (*equiv_null)[v1] = v2; | (*equiv_null)[v1] = v2; | ||||
| (*equiv_null)[v3] = 1; | (*equiv_null)[v3] = 1; | ||||
| d = TU.Match(v1, 1, equiv_null); | |||||
| d = TU.Match(v1, 1, primitive_vars_null, equiv_null); | |||||
| ASSERT_EQ(d->count(v2), std::size_t(1)); | ASSERT_EQ(d->count(v2), std::size_t(1)); | ||||
| ASSERT_EQ((*d)[v2], 1); | ASSERT_EQ((*d)[v2], 1); | ||||
| equiv_null->clear(); | equiv_null->clear(); | ||||
| d = TU.Match(VectorRef({v1}), VectorRef({1}), equiv_null); | |||||
| d = TU.Match(VectorRef({v1}), VectorRef({1}), primitive_vars_null, equiv_null); | |||||
| ASSERT_EQ(d->size(), std::size_t(1)); | ASSERT_EQ(d->size(), std::size_t(1)); | ||||
| ASSERT_EQ(d->count(v1), std::size_t(1)); | ASSERT_EQ(d->count(v1), std::size_t(1)); | ||||
| ASSERT_EQ((*d)[v1], 1); | ASSERT_EQ((*d)[v1], 1); | ||||
| equiv_null->clear(); | equiv_null->clear(); | ||||
| ASSERT_EQ(TU.Match(1, 2, equiv_null), nullptr); | |||||
| ASSERT_EQ(TU.Match(1, 2, primitive_vars_null, equiv_null), nullptr); | |||||
| } | } | ||||
| TEST_F(TestMatchEngine, MatchRaw_SVar) { | TEST_F(TestMatchEngine, MatchRaw_SVar) { | ||||
| @@ -139,22 +140,22 @@ TEST_F(TestMatchEngine, MatchRaw_SVar) { | |||||
| EquivPtr d; | EquivPtr d; | ||||
| equiv_null->clear(); | equiv_null->clear(); | ||||
| d = TU.Match(VectorRef({sv1}), VectorRef({1, 2}), equiv_null); | |||||
| d = TU.Match(VectorRef({sv1}), VectorRef({1, 2}), primitive_vars_null, equiv_null); | |||||
| ASSERT_EQ(d->size(), std::size_t(1)); | ASSERT_EQ(d->size(), std::size_t(1)); | ||||
| ASSERT_EQ(d->count(sv1), std::size_t(1)); | ASSERT_EQ(d->count(sv1), std::size_t(1)); | ||||
| ASSERT_EQ(utils::cast<Seq>((*d)[sv1]), Seq({1, 2})); | ASSERT_EQ(utils::cast<Seq>((*d)[sv1]), Seq({1, 2})); | ||||
| equiv_null->clear(); | equiv_null->clear(); | ||||
| d = TU.Match(VectorRef({v1, sv1}), VectorRef({1, 2}), equiv_null); | |||||
| d = TU.Match(VectorRef({v1, sv1}), VectorRef({1, 2}), primitive_vars_null, equiv_null); | |||||
| ASSERT_EQ(d->size(), std::size_t(2)); | ASSERT_EQ(d->size(), std::size_t(2)); | ||||
| ASSERT_EQ(utils::cast<Seq>((*d)[sv1]), Seq({2})); | ASSERT_EQ(utils::cast<Seq>((*d)[sv1]), Seq({2})); | ||||
| equiv_null->clear(); | equiv_null->clear(); | ||||
| ASSERT_EQ(TU.Match(VectorRef({sv1, sv2}), VectorRef({1, 2}), equiv_null), nullptr); | |||||
| ASSERT_EQ(TU.Match(VectorRef({sv1, sv2}), VectorRef({1, 2}), primitive_vars_null, equiv_null), nullptr); | |||||
| equiv_null->clear(); | equiv_null->clear(); | ||||
| (*equiv_null)[sv1] = std::make_shared<Seq>(PatternListType{1, 2}); | (*equiv_null)[sv1] = std::make_shared<Seq>(PatternListType{1, 2}); | ||||
| d = TU.Match(VectorRef({v1, sv1}), VectorRef({1, 1, 2}), equiv_null); | |||||
| d = TU.Match(VectorRef({v1, sv1}), VectorRef({1, 1, 2}), primitive_vars_null, equiv_null); | |||||
| ASSERT_EQ(d->size(), std::size_t(2)); | ASSERT_EQ(d->size(), std::size_t(2)); | ||||
| ASSERT_EQ((*d)[v1], 1); | ASSERT_EQ((*d)[v1], 1); | ||||
| } | } | ||||
| @@ -167,13 +168,13 @@ TEST_F(TestMatchEngine, Match) { | |||||
| EquivPtr d; | EquivPtr d; | ||||
| equiv_null->clear(); | equiv_null->clear(); | ||||
| d = TU.Match(VectorRef({v1, v1, v2}), VectorRef({1, 1, 2}), equiv_null); | |||||
| d = TU.Match(VectorRef({v1, v1, v2}), VectorRef({1, 1, 2}), primitive_vars_null, equiv_null); | |||||
| ASSERT_EQ(d->size(), std::size_t(2)); | ASSERT_EQ(d->size(), std::size_t(2)); | ||||
| ASSERT_EQ((*d)[v1], 1); | ASSERT_EQ((*d)[v1], 1); | ||||
| ASSERT_EQ((*d)[v2], 2); | ASSERT_EQ((*d)[v2], 2); | ||||
| equiv_null->clear(); | equiv_null->clear(); | ||||
| d = TU.Match(static_cast<int>(1), static_cast<float>(1), equiv_null); | |||||
| d = TU.Match(static_cast<int>(1), static_cast<float>(1), primitive_vars_null, equiv_null); | |||||
| ASSERT_EQ(d, nullptr); | ASSERT_EQ(d, nullptr); | ||||
| } | } | ||||
| @@ -197,18 +198,19 @@ TEST_F(TestMatchEngine, Match_CondVar) { | |||||
| EquivPtr d; | EquivPtr d; | ||||
| equiv_null->clear(); | equiv_null->clear(); | ||||
| d = TU.Match(VectorRef({vf, vn}), VectorRef({static_cast<float>(1.0), -1}), equiv_null); | |||||
| d = TU.Match(VectorRef({vf, vn}), VectorRef({static_cast<float>(1.0), -1}), primitive_vars_null, equiv_null); | |||||
| ASSERT_GE(d->size(), std::size_t(0)); | ASSERT_GE(d->size(), std::size_t(0)); | ||||
| auto vfn = (*d)[vf]; | auto vfn = (*d)[vf]; | ||||
| ASSERT_EQ((*d)[vf], static_cast<float>(1.0)); | ASSERT_EQ((*d)[vf], static_cast<float>(1.0)); | ||||
| ASSERT_EQ((*d)[vn], -1); | ASSERT_EQ((*d)[vn], -1); | ||||
| equiv_null->clear(); | equiv_null->clear(); | ||||
| d = TU.Match(VectorRef({vf, vn}), VectorRef({1, static_cast<float>(-1.0)}), equiv_null); | |||||
| d = TU.Match(VectorRef({vf, vn}), VectorRef({1, static_cast<float>(-1.0)}), primitive_vars_null, equiv_null); | |||||
| ASSERT_EQ(d, nullptr); | ASSERT_EQ(d, nullptr); | ||||
| equiv_null->clear(); | equiv_null->clear(); | ||||
| d = TU.Match(VectorRef({vf, vn}), VectorRef({static_cast<float>(1.0), static_cast<int>(1)}), equiv_null); | |||||
| d = TU.Match(VectorRef({vf, vn}), VectorRef({static_cast<float>(1.0), static_cast<int>(1)}), primitive_vars_null, | |||||
| equiv_null); | |||||
| ASSERT_EQ(d, nullptr); | ASSERT_EQ(d, nullptr); | ||||
| } | } | ||||