Merge pull request !3973 from zhengjun10/mastertags/v0.7.0-beta
| @@ -13,83 +13,46 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/gllo/common/gllo_utils.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | |||||
| #include "src/gllo/common/utils.h" | |||||
| #include "src/ir/primitive_t_value.h" | #include "src/ir/primitive_t_value.h" | ||||
| #include "frontend/operator/ops.h" | #include "frontend/operator/ops.h" | ||||
| using PrimitiveTValuePtr = std::shared_ptr<mindspore::lite::PrimitiveTValue>; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| bool AnfEqual(const BaseRef &a, const BaseRef &b) { | |||||
| if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) { | |||||
| auto a_node = utils::cast<AnfNodePtr>(a); | |||||
| auto b_node = utils::cast<AnfNodePtr>(b); | |||||
| MS_EXCEPTION_IF_NULL(a_node); | |||||
| MS_EXCEPTION_IF_NULL(b_node); | |||||
| if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) { | |||||
| auto a_value_node = a_node->cast<ValueNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(a_value_node); | |||||
| auto a_value = a_value_node->value(); | |||||
| MS_EXCEPTION_IF_NULL(a_value); | |||||
| auto a_prim = a_value->cast<PrimitivePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(a_prim); | |||||
| auto b_value_node = b_node->cast<ValueNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(b_value_node); | |||||
| auto b_value = b_value_node->value(); | |||||
| MS_EXCEPTION_IF_NULL(b_value); | |||||
| auto b_prim = b_value->cast<PrimitivePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(b_prim); | |||||
| return a_prim->name() == b_prim->name(); | |||||
| } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) { | |||||
| auto a_value_node_ptr = a_node->cast<ValueNodePtr>(); | |||||
| if (a_value_node_ptr == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "cast value node ptr fail"; | |||||
| } | |||||
| auto a_value_ptr = a_value_node_ptr->value(); | |||||
| if (a_value_ptr == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "value ptr is nullptr"; | |||||
| } | |||||
| auto b_value_node_ptr = b_node->cast<ValueNodePtr>(); | |||||
| if (b_value_node_ptr == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "cast value node ptr fail"; | |||||
| } | |||||
| auto b_value_ptr = b_value_node_ptr->value(); | |||||
| if (b_value_ptr == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "value ptr is nullptr"; | |||||
| } | |||||
| if (utils::isa<lite::PrimitiveTValue>(a_value_ptr) && utils::isa<lite::PrimitiveTValue>(b_value_ptr)) { | |||||
| auto a_obj = (lite::PrimitiveTValue *)(a_value_ptr.get()); | |||||
| auto b_obj = (lite::PrimitiveTValue *)(b_value_ptr.get()); | |||||
| return (*a_obj) == (*b_obj); | |||||
| } else { | |||||
| return (*a_value_ptr) == (*b_value_ptr); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (a.m_ptr->isa<lite::PrimitiveTValue>()) { | |||||
| auto a_value_node_ptr = a.m_ptr->cast<PrimitiveTValuePtr>(); | |||||
| auto b_value_node_ptr = b.m_ptr->cast<PrimitiveTValuePtr>(); | |||||
| return a_value_node_ptr->GetPrimitiveT()->value.type == b_value_node_ptr->GetPrimitiveT()->value.type; | |||||
| namespace { | |||||
| constexpr auto kAnfPrimitiveIndex = 0; | |||||
| bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (!node->isa<CNode>()) { | |||||
| return false; | |||||
| } | } | ||||
| return a == b; | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type); | |||||
| } | } | ||||
| bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) { | |||||
| // To matchCNode and Kernel's type | |||||
| if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) { | |||||
| bool IsRealKernel(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| // parameter and value node is not a real kernel too | |||||
| if (!node->isa<CNode>()) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| return a.type() == b.type(); | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (cnode->inputs().empty()) { | |||||
| MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString(); | |||||
| } | |||||
| auto input = cnode->inputs()[0]; | |||||
| bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) || | |||||
| IsPrimitive(input, prim::kPrimTensorSummary) || | |||||
| IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) || | |||||
| IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) || | |||||
| IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) || | |||||
| IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial); | |||||
| return !is_virtual_node; | |||||
| } | } | ||||
| namespace { | |||||
| ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) { | ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) { | ||||
| if (utils::isa<int>(sexp)) { | if (utils::isa<int>(sexp)) { | ||||
| return NewValueNode(utils::cast<int>(sexp)); | return NewValueNode(utils::cast<int>(sexp)); | ||||
| @@ -118,11 +81,11 @@ CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const | |||||
| VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) { | VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) { | ||||
| if (utils::isa<VarPtr>(graph)) { | if (utils::isa<VarPtr>(graph)) { | ||||
| // MS_LOG(DEBUG) << "make VarPtr " + graph.ToString(); | |||||
| MS_LOG(DEBUG) << "make VarPtr " + graph.ToString(); | |||||
| return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr); | return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr); | ||||
| } | } | ||||
| if (utils::isa<FuncGraphPtr>(graph)) { | if (utils::isa<FuncGraphPtr>(graph)) { | ||||
| // MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString(); | |||||
| MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString(); | |||||
| return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph)); | return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph)); | ||||
| } | } | ||||
| MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString(); | MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString(); | ||||
| @@ -131,7 +94,7 @@ VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) { | |||||
| AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, | AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, | ||||
| bool multigraph) { | bool multigraph) { | ||||
| // MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString(); | |||||
| MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString(); | |||||
| std::vector<AnfNodePtr> input_nodes; | std::vector<AnfNodePtr> input_nodes; | ||||
| const auto &tuple = utils::cast<VectorRef>(sexp); | const auto &tuple = utils::cast<VectorRef>(sexp); | ||||
| if (multigraph && utils::isa<VarPtr>(graph)) { | if (multigraph && utils::isa<VarPtr>(graph)) { | ||||
| @@ -151,8 +114,75 @@ AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, Primitive | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| bool AnfEqual(const BaseRef &a, const BaseRef &b) { | |||||
| if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) { | |||||
| auto a_node = utils::cast<AnfNodePtr>(a); | |||||
| auto b_node = utils::cast<AnfNodePtr>(b); | |||||
| MS_EXCEPTION_IF_NULL(a_node); | |||||
| MS_EXCEPTION_IF_NULL(b_node); | |||||
| if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) { | |||||
| auto a_value_node = a_node->cast<ValueNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(a_value_node); | |||||
| auto a_value = a_value_node->value(); | |||||
| MS_EXCEPTION_IF_NULL(a_value); | |||||
| auto a_prim = a_value->cast<PrimitivePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(a_prim); | |||||
| auto b_value_node = b_node->cast<ValueNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(b_value_node); | |||||
| auto b_value = b_value_node->value(); | |||||
| MS_EXCEPTION_IF_NULL(b_value); | |||||
| auto b_prim = b_value->cast<PrimitivePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(b_prim); | |||||
| return a_prim->name() == b_prim->name(); | |||||
| } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) { | |||||
| auto a_value_node_ptr = a_node->cast<ValueNodePtr>(); | |||||
| if (a_value_node_ptr == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "cast value node ptr fail"; | |||||
| } | |||||
| auto a_value_ptr = a_value_node_ptr->value(); | |||||
| if (a_value_ptr == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "value ptr is nullptr"; | |||||
| } | |||||
| auto b_value_node_ptr = b_node->cast<ValueNodePtr>(); | |||||
| if (b_value_node_ptr == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "cast value node ptr fail"; | |||||
| } | |||||
| auto b_value_ptr = b_value_node_ptr->value(); | |||||
| if (b_value_ptr == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "value ptr is nullptr"; | |||||
| } | |||||
| if (utils::isa<lite::PrimitiveTValue>(a_value_ptr) && utils::isa<lite::PrimitiveTValue>(b_value_ptr)) { | |||||
| auto a_obj = (lite::PrimitiveTValue *) (a_value_ptr.get()); | |||||
| auto b_obj = (lite::PrimitiveTValue *) (b_value_ptr.get()); | |||||
| return (*a_obj) == (*b_obj); | |||||
| } else { | |||||
| return (*a_value_ptr) == (*b_value_ptr); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (a.m_ptr->isa<lite::PrimitiveTValue>() && b.m_ptr->isa<lite::PrimitiveTValue>()) { | |||||
| auto a_value_node_ptr = a.m_ptr->cast<PrimitiveTValuePtr>(); | |||||
| auto b_value_node_ptr = b.m_ptr->cast<PrimitiveTValuePtr>(); | |||||
| return a_value_node_ptr->GetPrimitiveT()->value.type == b_value_node_ptr->GetPrimitiveT()->value.type; | |||||
| } | |||||
| return a == b; | |||||
| } | |||||
| bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) { | |||||
| // To matchCNode and Kernel's type | |||||
| if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) { | |||||
| return true; | |||||
| } | |||||
| return a.type() == b.type(); | |||||
| } | |||||
| AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) { | AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) { | ||||
| // MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString(); | |||||
| MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString(); | |||||
| MS_EXCEPTION_IF_NULL(primitive_vars); | MS_EXCEPTION_IF_NULL(primitive_vars); | ||||
| if (utils::isa<VectorRef>(sexp)) { | if (utils::isa<VectorRef>(sexp)) { | ||||
| return HandleSexpVector(sexp, graph, primitive_vars, multigraph); | return HandleSexpVector(sexp, graph, primitive_vars, multigraph); | ||||
| @@ -176,6 +206,38 @@ AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap | |||||
| return value_node; | return value_node; | ||||
| } | } | ||||
| bool IsRealCNodeKernel(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| // parameter and value node is not a real cnode kernel | |||||
| if (!node->isa<CNode>()) { | |||||
| return false; | |||||
| } | |||||
| // return considered as a real node | |||||
| if (CheckPrimitiveType(node, prim::kPrimReturn)) { | |||||
| return true; | |||||
| } | |||||
| return IsRealKernel(node); | |||||
| } | |||||
| bool IsGraphKernel(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| // graph kernel should be a real cnode kernel. | |||||
| if (!IsRealCNodeKernel(node)) { | |||||
| return false; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto input = cnode->input(kAnfPrimitiveIndex); | |||||
| // graph kernel should has func_graph as first input. | |||||
| if (!IsValueNode<FuncGraph>(input)) { | |||||
| return false; | |||||
| } | |||||
| auto func_graph = GetValueNode<FuncGraphPtr>(input); | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); | |||||
| } | |||||
| void CheckIfFuncGraphIsNull(const FuncGraphPtr &graph) { | void CheckIfFuncGraphIsNull(const FuncGraphPtr &graph) { | ||||
| if (graph == nullptr) { | if (graph == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "The graph is null."; | MS_LOG(EXCEPTION) << "The graph is null."; | ||||
| @@ -14,22 +14,21 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_PASS_COMMON_UTILS_H_ | |||||
| #define MINDSPORE_LITE_SRC_PASS_COMMON_UTILS_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_ | |||||
| #define MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_ | |||||
| #include <mindspore/lite/src/ir/primitive_t_value.h> | |||||
| #include <memory> | #include <memory> | ||||
| #include "src/ir/primitive_t_value.h" | |||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/gllo/common/pattern_engine.h" | |||||
| #include "backend/optimizer/common/pattern_engine.h" | |||||
| #include "schema/inner/model_generated.h" | #include "schema/inner/model_generated.h" | ||||
| #include "src/param_value_lite.h" | #include "src/param_value_lite.h" | ||||
| using PrimitiveTValuePtr = std::shared_ptr<mindspore::lite::PrimitiveTValue>; | using PrimitiveTValuePtr = std::shared_ptr<mindspore::lite::PrimitiveTValue>; | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| bool AnfEqual(const BaseRef &a, const BaseRef &b); | bool AnfEqual(const BaseRef &a, const BaseRef &b); | ||||
| bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b); | bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b); | ||||
| @@ -37,6 +36,10 @@ bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b); | |||||
| AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, | AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, | ||||
| bool multigraph = false); | bool multigraph = false); | ||||
| bool IsRealCNodeKernel(const AnfNodePtr &node); | |||||
| bool IsGraphKernel(const AnfNodePtr &node); | |||||
| void CheckIfFuncGraphIsNull(const FuncGraphPtr &graph); | void CheckIfFuncGraphIsNull(const FuncGraphPtr &graph); | ||||
| void CheckIfAnfNodeIsNull(const AnfNodePtr &node); | void CheckIfAnfNodeIsNull(const AnfNodePtr &node); | ||||
| @@ -61,4 +64,4 @@ bool IsParamNode(const BaseRef &n); | |||||
| bool IsConvNode(const BaseRef &n); | bool IsConvNode(const BaseRef &n); | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_LITE_SRC_PASS_COMMON_UTILS_H_ | |||||
| #endif // MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_ | |||||
| @@ -13,7 +13,7 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/gllo/common/node_pass.h" | |||||
| #include "backend/optimizer/common/node_pass.h" | |||||
| #include <unordered_set> | #include <unordered_set> | ||||
| #include <deque> | #include <deque> | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/manager.h" | #include "ir/manager.h" | ||||
| #include "src/gllo/common/gllo_utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -54,6 +55,9 @@ bool NodePass::Run(const FuncGraphPtr &func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(const_func_graph); | MS_EXCEPTION_IF_NULL(const_func_graph); | ||||
| todo.push_back(const_func_graph->output()); | todo.push_back(const_func_graph->output()); | ||||
| } else if (new_node && new_node->isa<CNode>()) { | } else if (new_node && new_node->isa<CNode>()) { | ||||
| if (IsGraphKernel(new_node)) { | |||||
| todo.push_back(new_node); | |||||
| } | |||||
| auto cnode = new_node->cast<CNodePtr>(); | auto cnode = new_node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| auto inputs = cnode->inputs(); | auto inputs = cnode->inputs(); | ||||
| @@ -1,36 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_PASS_COMMON_NODE_PASS_H_ | |||||
| #define MINDSPORE_LITE_SRC_PASS_COMMON_NODE_PASS_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include "src/gllo/common/pass.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| // @brief ANF Node level optimization base pass | |||||
| class NodePass : public Pass { | |||||
| public: | |||||
| explicit NodePass(const std::string &name) : Pass(name) {} | |||||
| ~NodePass() override = default; | |||||
| bool Run(const FuncGraphPtr &func_graph) final; | |||||
| virtual AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) = 0; | |||||
| }; | |||||
| using NodePassPtr = std::shared_ptr<NodePass>; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_SRC_PASS_COMMON_NODE_PASS_H_ | |||||
| @@ -23,8 +23,7 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <initializer_list> | #include <initializer_list> | ||||
| #include "src/gllo/common/pass_manager.h" | |||||
| #include "backend/session/anf_runtime_algorithm.h" | |||||
| #include "backend/optimizer/common/pass_manager.h" | |||||
| #include "ir/manager.h" | #include "ir/manager.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -26,9 +26,9 @@ | |||||
| #include "ir/graph_utils.h" | #include "ir/graph_utils.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/gllo/common/pass_manager.h" | |||||
| #include "src/gllo/common/pattern_engine.h" | |||||
| #include "src/gllo/common/utils.h" | |||||
| #include "backend/optimizer/common/pass_manager.h" | |||||
| #include "backend/optimizer/common/pattern_engine.h" | |||||
| #include "src/gllo/common/gllo_utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -13,7 +13,7 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/gllo/common/pass_manager.h" | |||||
| #include "backend/optimizer/common/pass_manager.h" | |||||
| #include <sys/time.h> | #include <sys/time.h> | ||||
| #include <unordered_set> | #include <unordered_set> | ||||
| @@ -1,61 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_PASS_COMMON_PASS_MANAGER_H_ | |||||
| #define MINDSPORE_LITE_SRC_PASS_COMMON_PASS_MANAGER_H_ | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include "src/gllo/common/pass.h" | |||||
| #include "src/gllo/common/node_pass.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| // @brief For optimization passes management | |||||
| class PassManager { | |||||
| public: | |||||
| explicit PassManager(const std::string &name = "pm", bool run_only_once = true) | |||||
| : name_(name), passes_{}, run_only_once_(run_only_once) {} | |||||
| virtual ~PassManager() = default; | |||||
| // Get all the passes added by AddPass | |||||
| const std::vector<PassPtr> &Passes() const; | |||||
| // Add graph pass, the pass object will be freed when pass manager freed. | |||||
| void AddPass(const PassPtr &pass); | |||||
| // Run passes added in pass manager on the input graph | |||||
| // @param [inout] graph The graph to be optimized | |||||
| // @return true, graph changed | |||||
| // @return false, graph not changed | |||||
| bool Run(const FuncGraphPtr &func_graph) const; | |||||
| // Run the given graph passes on the input graph | |||||
| // @param [inout] graph The graph to be optimized | |||||
| // @param [in] passes The given graph passes | |||||
| // @return true, graph changed | |||||
| // @return false, graph not changed | |||||
| bool Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr> &passes) const; | |||||
| std::string name() const { return name_; } | |||||
| private: | |||||
| const std::string name_; | |||||
| std::vector<PassPtr> passes_; | |||||
| bool run_only_once_; | |||||
| }; | |||||
| using PassManagerPtr = std::shared_ptr<PassManager>; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_SRC_PASS_COMMON_PASS_MANAGER_H_ | |||||
| @@ -1,365 +0,0 @@ | |||||
| /** | |||||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||||
| * | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/gllo/common/pattern_engine.h" | |||||
| #include <exception> | |||||
| #include <iostream> | |||||
| #include <functional> | |||||
| #include <iterator> | |||||
| #include "ir/func_graph.h" | |||||
| #include "mindspore/core/ir/primitive.h" | |||||
| #include "utils/info.h" | |||||
| #include "ir/anf.h" | |||||
| #include "utils/convert_utils_base.h" | |||||
| #include "utils/overload.h" | |||||
| namespace mindspore { | |||||
| static int GetNextTag() { | |||||
| static int kID = 0; | |||||
| return kID++; | |||||
| } | |||||
| void Var::EnsureTag() { | |||||
| if (tag_.length() == 0) { | |||||
| std::ostringstream buffer; | |||||
| buffer << "_" << GetNextTag(); | |||||
| tag_ = buffer.str(); | |||||
| } | |||||
| } | |||||
| bool operator==(const VarPtr &lhs, const VarPtr &rhs) { | |||||
| if (lhs->isa<CondVar>() && rhs->isa<CondVar>()) { | |||||
| CondVarPtr v1 = dyn_cast<CondVar>(lhs); | |||||
| CondVarPtr v2 = dyn_cast<CondVar>(rhs); | |||||
| return *v1 == *v2; | |||||
| } | |||||
| if (lhs->isa<SeqVar>() && rhs->isa<SeqVar>()) { | |||||
| SVarPtr v1 = dyn_cast<SeqVar>(lhs); | |||||
| SVarPtr v2 = dyn_cast<SeqVar>(rhs); | |||||
| return *v1 == *v2; | |||||
| } | |||||
| return (*lhs == *rhs); | |||||
| } | |||||
| std::string SeqVar::ToString() const { | |||||
| std::ostringstream buffer; | |||||
| buffer << "SeqVar(" << tag() << ", " << subvar_->ToString() << ")"; | |||||
| return buffer.str(); | |||||
| } | |||||
| std::ostream &operator<<(std::ostream &os, const VarPtr &var) { | |||||
| if (var == nullptr) { | |||||
| os << ""; | |||||
| } else { | |||||
| os << var->ToString(); | |||||
| } | |||||
| return os; | |||||
| } | |||||
| template <> | |||||
| std::ostream &operator<<<VarPtr, BaseRef>(std::ostream &os, const Equiv &equiv) { | |||||
| os << "[Equiv]" | |||||
| << "\n"; | |||||
| for (auto &equiv_item : equiv) { | |||||
| auto k = equiv_item.first; | |||||
| os << k << ":"; | |||||
| BaseRef x = equiv_item.second; | |||||
| if (utils::isa<AnfNodePtr>(x)) { | |||||
| auto node = utils::cast<AnfNodePtr>(x); | |||||
| os << "TypeString[" << node->type_name() << "]"; | |||||
| if (IsValueNode<FuncGraph>(node)) { | |||||
| os << "IsValueNodeGraph "; | |||||
| } | |||||
| os << "type " << node->type_name(); | |||||
| if (node->isa<ValueNode>()) { | |||||
| os << " value " << GetValueNode(node); | |||||
| } | |||||
| os << " addr: " << node; | |||||
| } else if (utils::isa<Named>(x)) { | |||||
| os << "Named " << x.ToString().c_str(); | |||||
| } else if (utils::isa<VarPtr>(x)) { | |||||
| os << "TypeString[Var]"; | |||||
| os << utils::cast<VarPtr>(x); | |||||
| } else if (utils::isa<FuncGraphPtr>(x)) { | |||||
| os << "TypeString[Graph]"; | |||||
| } | |||||
| os << "\n"; | |||||
| } | |||||
| return os; | |||||
| } | |||||
| static BaseRef GetVar(const BaseRef &x) { | |||||
| // MS_LOG(DEBUG) << "getVar start :%s" + x.ToString(); | |||||
| if (utils::isa<AnfNodePtr>(x)) { | |||||
| auto node = utils::cast<AnfNodePtr>(x); | |||||
| // MS_LOG(DEBUG) << "TypeString [" + node->type_name() + "]"; | |||||
| if (node->isa<VarNode>()) { | |||||
| // MS_LOG(DEBUG) << "IsVarNode " + node->cast<VarNodePtr>()->var_->ToString(); | |||||
| return node->cast<VarNodePtr>()->var_; | |||||
| } | |||||
| // if (node->isa<ValueNode>()) { | |||||
| // MS_LOG(DEBUG) << "value " + GetValueNode(node)->ToString() + " addr: " + node->ToString(); | |||||
| // } else { | |||||
| // MS_LOG(DEBUG) << "type " + node->type_name(); | |||||
| // } | |||||
| // } else if (utils::isa<Named>(x)) { | |||||
| // MS_LOG(DEBUG) << "Named " + x.ToString(); | |||||
| // } else if (utils::isa<VectorRef>(x)) { | |||||
| // MS_LOG(DEBUG) << "VectorRef"; | |||||
| // } else if (utils::isa<VarPtr>(x)) { | |||||
| // MS_LOG(DEBUG) << "TypeString[Var] " + x.ToString(); | |||||
| } | |||||
| // MS_LOG(DEBUG) << "GetVar end: " + x.ToString(); | |||||
| return x; | |||||
| } | |||||
| EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv) { | |||||
| MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString(); | |||||
| MS_EXCEPTION_IF_NULL(equiv); | |||||
| if (utils::isa<VarPtr>(pattern)) { | |||||
| VarPtr var = utils::cast<VarPtr>(pattern); | |||||
| if (var->matches(expr)) { | |||||
| (*equiv)[var] = expr; | |||||
| MS_LOG(DEBUG) << "pattern is var match: " + pattern.ToString() + ", " + expr.ToString(); | |||||
| return equiv; | |||||
| } | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| bool PatternEngine::ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern, | |||||
| VectorRef *const values_expr) const { | |||||
| MS_EXCEPTION_IF_NULL(values_expr); | |||||
| if (utils::isa<SeqPtr>(pattern_ref)) { | |||||
| *values_pattern = pattern_ref; | |||||
| *values_expr = expr_ref; | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool PatternEngine::ToVector(const BaseRef &pattern_ref, const BaseRef &expr_ref, VectorRef *const values_pattern, | |||||
| VectorRef *const values_expr) const { | |||||
| MS_EXCEPTION_IF_NULL(values_expr); | |||||
| // visitor to visite the list | |||||
| auto appender_pattern = [](VectorRef &values) { | |||||
| std::function<BaseRef(const BaseRef &)> fn = [&](const BaseRef &u) { | |||||
| values.push_back(GetVar(u)); | |||||
| return u; | |||||
| }; | |||||
| return fn; | |||||
| }; | |||||
| visitor_->SetFn(appender_pattern(*values_pattern)); | |||||
| // MS_LOG(DEBUG) << "visit pattern_ref"; | |||||
| bool success = visitor_->Visit(pattern_ref, nullptr); | |||||
| if (!success) { | |||||
| return false; | |||||
| } | |||||
| auto appender_expr = [](VectorRef &values) { | |||||
| std::function<BaseRef(const BaseRef &)> fn = [&](const BaseRef &u) { | |||||
| values.push_back(u); | |||||
| return u; | |||||
| }; | |||||
| return fn; | |||||
| }; | |||||
| visitor_->SetFn(appender_expr(*values_expr)); | |||||
| // MS_LOG(DEBUG) << "visit expr_ref"; | |||||
| return visitor_->Visit(expr_ref, nullptr); | |||||
| } | |||||
| static int GetSVarStartIndex(const VectorRef &values) { | |||||
| int index = -1; | |||||
| int count = 0; | |||||
| for (auto &value : values) { | |||||
| if (utils::isa<VarPtr>(value) && utils::cast<VarPtr>(value)->isa<SeqVar>()) { | |||||
| if (index != -1) { | |||||
| // MS_LOG(DEBUG) << "Multiple SVars in sequence"; | |||||
| return kInvalidVarIndex; | |||||
| } | |||||
| index = count; | |||||
| } | |||||
| count++; | |||||
| } | |||||
| return index; | |||||
| } | |||||
| void UpdateEquivMap(const VectorRef &values_pattern, const BaseRef &expr_ref, const PrimitiveVarMap &primitive_vars, | |||||
| EquivPtr equiv) { | |||||
| if (equiv == nullptr || values_pattern.empty() || !utils::isa<AnfNodePtr>(values_pattern[0]) || | |||||
| !utils::isa<AnfNodePtr>(expr_ref)) { | |||||
| return; | |||||
| } | |||||
| auto real_node = utils::cast<AnfNodePtr>(expr_ref); | |||||
| MS_EXCEPTION_IF_NULL(real_node); | |||||
| if (!real_node->isa<CNode>()) { | |||||
| return; | |||||
| } | |||||
| auto prim_node = utils::cast<AnfNodePtr>(values_pattern[0]); | |||||
| MS_EXCEPTION_IF_NULL(prim_node); | |||||
| if (!IsValueNode<Primitive>(prim_node)) { | |||||
| return; | |||||
| } | |||||
| ValuePtr value = GetValueNode(prim_node); | |||||
| MS_EXCEPTION_IF_NULL(value); | |||||
| auto prim = value->cast<PrimitivePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| auto iter = primitive_vars.find(prim); | |||||
| if (iter == primitive_vars.end()) { | |||||
| return; | |||||
| } | |||||
| (*equiv)[iter->second] = real_node; | |||||
| } | |||||
| EquivPtr PatternEngine::AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr, | |||||
| const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const { | |||||
| int svar_index = GetSVarStartIndex(values_pattern); | |||||
| if (svar_index == kInvalidVarIndex) { | |||||
| return nullptr; | |||||
| } | |||||
| size_t values_pattern_len = values_pattern.size(); | |||||
| size_t values_expr_len = values_expr.size(); | |||||
| if (svar_index == -1) { | |||||
| if (values_pattern_len != values_expr_len) { | |||||
| // MS_LOG(DEBUG) << "Structures of differing size: pattern len " << values_pattern_len << ", | |||||
| // expr len " << values_expr_len; | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| if (values_expr_len < values_pattern_len - 1) { | |||||
| MS_LOG(DEBUG) << "invalid size: pattern len " << values_pattern_len << ", expr len " << values_expr_len; | |||||
| return nullptr; | |||||
| } | |||||
| size_t diff = values_expr_len - values_pattern_len + 1; | |||||
| for (size_t i = 0; i < values_pattern_len; i++) { | |||||
| size_t expr_i = i; | |||||
| if (svar_index != -1 && i == IntToSize(svar_index)) { | |||||
| auto seq = | |||||
| std::vector<BaseRef>(values_expr.begin() + svar_index, values_expr.begin() + svar_index + SizeToInt(diff)); | |||||
| equiv = Match(values_pattern[svar_index], seq, primitive_vars, equiv); | |||||
| } else { | |||||
| if (svar_index != -1 && i > IntToSize(svar_index)) { | |||||
| expr_i = i + diff - 1; | |||||
| } | |||||
| equiv = Match(values_pattern[i], values_expr[expr_i], primitive_vars, equiv); | |||||
| } | |||||
| if (equiv == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| return equiv; | |||||
| } | |||||
| EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars, | |||||
| EquivPtr equiv) const { | |||||
| MS_LOG(DEBUG) << "-----[in Match]"; | |||||
| // MS_LOG(DEBUG) << "GetVar w"; | |||||
| BaseRef pattern_ref = GetVar(pattern); | |||||
| // MS_LOG(DEBUG) << "GetVar v"; | |||||
| BaseRef expr_ref = expr; | |||||
| if (equiv == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Equiv pointer is null"; | |||||
| } | |||||
| MS_LOG(DEBUG) << "Pattern ref " + pattern_ref.ToString() + ", expr ref" + expr_ref.ToString(); | |||||
| // 1. if pattern_ref is var and already in equiv, replace it. | |||||
| if (utils::isa<VarPtr>(pattern_ref)) { | |||||
| VarPtr var = utils::cast<VarPtr>(pattern_ref); | |||||
| auto iter = equiv->find(var); | |||||
| if (iter != equiv->end()) { | |||||
| pattern_ref = iter->second; | |||||
| } | |||||
| } | |||||
| // 2. check equal | |||||
| if (eq_(pattern_ref, expr_ref)) { | |||||
| return equiv; | |||||
| } | |||||
| // 3. match var | |||||
| EquivPtr ret_equiv = MatchOnVar(pattern_ref, expr_ref, equiv); | |||||
| if (ret_equiv) { | |||||
| return ret_equiv; | |||||
| } | |||||
| // 4. here the type can be std:vector, std:list, | |||||
| // or cnode. | |||||
| if (!type_eq_(pattern_ref, expr_ref)) { | |||||
| MS_LOG(DEBUG) << "Type mismatch"; | |||||
| return nullptr; | |||||
| } | |||||
| // 5. transfer the Containers by visitor to std::vector | |||||
| VectorRef values_pattern; | |||||
| VectorRef values_expr; | |||||
| if (!ToVector(pattern_ref, expr_ref, &values_pattern, &values_expr)) { | |||||
| return nullptr; | |||||
| } | |||||
| // 6. if any svar in both side, find the SeqVar index, | |||||
| // try to pack the Var s in std::vector to a Seq and match elements one by one. | |||||
| // check svar | |||||
| equiv = AlignSVar(values_pattern, values_expr, primitive_vars, equiv); | |||||
| UpdateEquivMap(values_pattern, expr_ref, primitive_vars, equiv); | |||||
| return equiv; | |||||
| } | |||||
| BaseRef PatternEngine::Replace(const BaseRef &pattern, const EquivPtr &equiv) const { | |||||
| MS_EXCEPTION_IF_NULL(equiv); | |||||
| MS_LOG(DEBUG) << "-----[in Replace]"; | |||||
| BaseRef ref = GetVar(pattern); | |||||
| BaseRef out; | |||||
| bool is_match = false; | |||||
| // w is var | |||||
| if (utils::isa<VarPtr>(ref)) { | |||||
| const VarPtr &var = utils::cast<VarPtr>(ref); | |||||
| auto iter = equiv->find(var); | |||||
| if (iter != equiv->end()) { | |||||
| out = iter->second; | |||||
| is_match = true; | |||||
| } | |||||
| } | |||||
| if (is_match) { | |||||
| return out; | |||||
| } | |||||
| // visitor to visit the list | |||||
| std::function<BaseRef(BaseRef)> fn = [&, this, equiv](const BaseRef &u) { return Replace(u, equiv); }; | |||||
| visitor_->SetFn(fn); | |||||
| BaseRef visit_out; | |||||
| if (!visitor_->Visit(pattern, &visit_out)) { | |||||
| return pattern; | |||||
| } | |||||
| return visit_out; | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -1,203 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_PASS_COMMON_PATTERN_ENGINE_H_ | |||||
| #define MINDSPORE_LITE_SRC_PASS_COMMON_PATTERN_ENGINE_H_ | |||||
| #include <string> | |||||
| #include <sstream> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include <unordered_set> | |||||
| #include <unordered_map> | |||||
| #include <initializer_list> | |||||
| #include <iostream> | |||||
| #include <algorithm> | |||||
| #include <map> | |||||
| #include <stdexcept> | |||||
| #include <list> | |||||
| #include <utility> | |||||
| #include "src/gllo/common/visit.h" | |||||
| #include "mindspore/core/base/base.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "base/base_ref.h" | |||||
| namespace mindspore { | |||||
| class CondVar; | |||||
| class SeqVar; | |||||
| using CondVarPtr = std::shared_ptr<CondVar>; | |||||
| using SVarPtr = std::shared_ptr<SeqVar>; | |||||
| const int kInvalidVarIndex = -2; | |||||
| using ConditionFunc = std::function<bool(const BaseRef &)>; | |||||
| // Base wildcard variable which could match any anf node. | |||||
| class Var : public Base { | |||||
| friend class VarHasher; | |||||
| public: | |||||
| explicit Var(std::string tag = "") : tag_(std::move(tag)), primitive_(nullptr) { EnsureTag(); } | |||||
| explicit Var(const PrimitivePtr &primitive, std::string tag = "") : tag_(std::move(tag)), primitive_(primitive) { | |||||
| EnsureTag(); | |||||
| } | |||||
| Var(const Var &other) : Base(other), tag_(other.tag_) {} | |||||
| virtual Var &operator=(const Var &other) { | |||||
| if (&other == this) { | |||||
| return *this; | |||||
| } | |||||
| this->tag_ = other.tag_; | |||||
| return *this; | |||||
| } | |||||
| ~Var() override = default; | |||||
| MS_DECLARE_PARENT(Var, Base); | |||||
| virtual bool matches(const BaseRef &) { return true; } | |||||
| virtual bool operator==(const Var &other) const { return tag_ == other.tag_; } | |||||
| bool operator!=(const Var &other) const { return !(&other == this); } | |||||
| std::string tag() const { return tag_; } | |||||
| PrimitivePtr primitive() const { return primitive_; } | |||||
| std::string ToString() const override { | |||||
| std::ostringstream buffer; | |||||
| buffer << "Var(" << tag_ << ")"; | |||||
| return buffer.str(); | |||||
| } | |||||
| std::size_t hash() const override { return std::hash<std::string>()(tag_); } | |||||
| protected: | |||||
| void EnsureTag(); | |||||
| std::string tag_; | |||||
| PrimitivePtr primitive_; | |||||
| }; | |||||
| // VarNode means variable node, a subclass of AnfNode | |||||
| class VarNode : public AnfNode { | |||||
| public: | |||||
| VarNode(const VarPtr &value, const FuncGraphPtr &func_graph) : AnfNode(func_graph), var_(value) {} | |||||
| ~VarNode() override = default; | |||||
| MS_DECLARE_PARENT(VarNode, AnfNode); | |||||
| const VarPtr var_; | |||||
| }; | |||||
| using VarNodePtr = std::shared_ptr<VarNode>; | |||||
| class VarHasher { | |||||
| public: | |||||
| std::size_t operator()(const Var &var) const { return var.hash(); } | |||||
| }; | |||||
| // Condition Var, match an anf node when condition function return true. | |||||
| class CondVar : public Var { | |||||
| public: | |||||
| explicit CondVar(const ConditionFunc &cond) : cond_fn_(cond) {} | |||||
| ~CondVar() override = default; | |||||
| MS_DECLARE_PARENT(CondVar, Var); | |||||
| bool matches(const BaseRef &value) override { | |||||
| // MS_LOG(DEBUG) << "CondVarPtr match: " + value.ToString(); | |||||
| if (utils::isa<Var>(value)) { | |||||
| return false; | |||||
| } | |||||
| return cond_fn_(value); | |||||
| } | |||||
| ConditionFunc cond_fn_; | |||||
| }; | |||||
| using Seq = VectorRef; | |||||
| using SeqPtr = std::shared_ptr<Seq>; | |||||
| // Sequence Var which could match multiple consecutive input nodes of a CNode. | |||||
| class SeqVar : public Var { | |||||
| public: | |||||
| SeqVar() : subvar_(std::make_shared<Var>()) {} | |||||
| ~SeqVar() override = default; | |||||
| MS_DECLARE_PARENT(SeqVar, Var); | |||||
| explicit SeqVar(const VarPtr subvar) : subvar_(subvar) {} | |||||
| bool matches(const BaseRef &value) override { | |||||
| // match Seq. | |||||
| if (utils::isa<Seq>(value)) { | |||||
| const Seq &seq = utils::cast<Seq>(value); | |||||
| return std::all_of(seq.begin(), seq.end(), [this](const BaseRef &v) { | |||||
| auto eq = subvar_->matches(v); | |||||
| return eq; | |||||
| }); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool operator==(const SeqVar &other) const { return *subvar_ == *other.subvar_; } | |||||
| std::string ToString() const override; | |||||
| private: | |||||
| VarPtr subvar_; | |||||
| }; | |||||
| bool operator==(const VarPtr &lhs, const VarPtr &rhs); | |||||
| inline bool operator!=(const VarPtr &lhs, const VarPtr &rhs) { return !(lhs == rhs); } | |||||
| std::ostream &operator<<(std::ostream &os, const VarPtr &var); | |||||
| using Equiv = std::map<VarPtr, BaseRef>; | |||||
| using EquivPtr = std::shared_ptr<Equiv>; | |||||
| using PrimitiveVarMap = std::unordered_map<PrimitivePtr, VarPtr>; | |||||
| using PrimitiveVarMapPtr = std::shared_ptr<PrimitiveVarMap>; | |||||
| inline bool DefaultTypeEq(const BaseRef &x, const BaseRef &y) { return x.type() == y.type(); } | |||||
| class PatternEngine { | |||||
| public: | |||||
| PatternEngine(const std::shared_ptr<Visitor> &visitor, | |||||
| const std::function<bool(const BaseRef &, const BaseRef &)> &eq, | |||||
| const std::function<bool(const BaseRef &, const BaseRef &)> &type_eq = DefaultTypeEq) | |||||
| : visitor_(visitor), eq_(eq), type_eq_(type_eq) {} | |||||
| ~PatternEngine() = default; | |||||
| EquivPtr Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars, | |||||
| EquivPtr equiv) const; | |||||
| // Replace pattern with equivalent | |||||
| BaseRef Replace(const BaseRef &pattern, const EquivPtr &equiv) const; | |||||
| private: | |||||
| EquivPtr AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr, | |||||
| const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const; | |||||
| bool ToVector(const BaseRef &pattern, const BaseRef &expr, VectorRef *const values_pattern, | |||||
| VectorRef *const values_expr) const; | |||||
| bool ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern, | |||||
| VectorRef *const values_expr) const; | |||||
| std::shared_ptr<Visitor> visitor_; | |||||
| std::function<bool(const BaseRef &, const BaseRef &)> eq_; | |||||
| std::function<bool(const BaseRef &, const BaseRef &)> type_eq_; | |||||
| }; | |||||
| } // namespace mindspore | |||||
| namespace std { | |||||
| using mindspore::ERROR; | |||||
| using mindspore::LogStream; | |||||
| using mindspore::NoExceptionType; | |||||
| template <> | |||||
| struct hash<mindspore::VarPtr> { | |||||
| std::size_t operator()(const mindspore::VarPtr var) const { | |||||
| if (var == nullptr) { | |||||
| MS_LOG(ERROR) << "Invalid var ptr"; | |||||
| return 0; | |||||
| } | |||||
| return std::hash<std::string>{}(var->tag()); | |||||
| } | |||||
| }; | |||||
| } // namespace std | |||||
| #endif // MINDSPORE_LITE_SRC_PASS_COMMON_PATTERN_ENGINE_H_ | |||||
| @@ -1,165 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <algorithm> | |||||
| #include <utility> | |||||
| #include "src/gllo/common/visit.h" | |||||
| #include "src/gllo/common/pattern_engine.h" | |||||
| #include "utils/any.h" | |||||
| #include "ir/anf.h" | |||||
| #include "ir/func_graph.h" | |||||
| #include "utils/log_adapter.h" | |||||
| namespace mindspore { | |||||
| bool CheckIfNeedExpand(const std::vector<BaseRef> &list) { | |||||
| return std::any_of(list.begin(), list.end(), [](const BaseRef &any) { return utils::isa<Seq>(any); }); | |||||
| } | |||||
| std::shared_ptr<VectorRef> ExpandList(const std::vector<BaseRef> &list) { | |||||
| std::shared_ptr<VectorRef> new_list = std::make_shared<VectorRef>(); | |||||
| for (auto &item : list) { | |||||
| if (utils::isa<Seq>(item)) { | |||||
| const Seq &seq = utils::cast<Seq>(item); | |||||
| new_list->insert(new_list->end(), seq.begin(), seq.end()); | |||||
| } else { | |||||
| new_list->push_back(item); | |||||
| } | |||||
| } | |||||
| return new_list; | |||||
| } | |||||
| bool DefaultVisitor::Visit(const VectorRef &v_any, BaseRef *const visit_out) const { | |||||
| std::vector<BaseRef> out; | |||||
| (void)std::transform(v_any.begin(), v_any.end(), std::back_inserter(out), | |||||
| [this](const BaseRef &item) { return fn_(item); }); | |||||
| if (visit_out != nullptr) { | |||||
| *visit_out = ExpandList(out); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool DefaultVisitor::Visit(const BaseRef &any, BaseRef *const visit_out) const { | |||||
| if (utils::isa<Seq>(any)) { | |||||
| return Visit(utils::cast<Seq>(any), visit_out); | |||||
| } else if (utils::isa<AnfNodePtr>(any)) { | |||||
| auto nodeptr = utils::cast<AnfNodePtr>(any); | |||||
| AnfNodePtr output; | |||||
| AnfNodePtr *p_output = &output; | |||||
| if (visit_out == nullptr) { | |||||
| p_output = nullptr; | |||||
| } | |||||
| Visit(nodeptr, fn_, p_output); | |||||
| if (visit_out != nullptr) { | |||||
| *visit_out = output; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| MS_LOG(DEBUG) << "VisitError, not support type to Visit: " + any.ToString(); | |||||
| return false; | |||||
| } | |||||
| void DefaultVisitor::Visit(const AnfNodePtr &node, const VisitFn &fn, AnfNodePtr *output) const { | |||||
| if (node->isa<CNode>()) { | |||||
| Visit(node->cast<CNodePtr>(), fn, output); | |||||
| return; | |||||
| } | |||||
| if (node->isa<ValueNode>()) { | |||||
| Visit(node->cast<ValueNodePtr>(), fn, output); | |||||
| return; | |||||
| } | |||||
| if (output != nullptr) { | |||||
| *output = node; | |||||
| } | |||||
| } | |||||
| void DefaultVisitor::Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr *output) const { | |||||
| // if output is nullptr, it's not required to make the new CNode node. | |||||
| if (output == nullptr) { | |||||
| for (auto &inp : cnode->inputs()) { | |||||
| (void)fn(inp); | |||||
| } | |||||
| if (cnode->func_graph() != nullptr) { | |||||
| (void)fn(cnode->func_graph()); | |||||
| } else { | |||||
| (void)fn(cnode->func_graph_as_var()); | |||||
| } | |||||
| return; | |||||
| } | |||||
| std::vector<AnfNodePtr> new_inputs; | |||||
| std::vector<BaseRef> after_cnode_fn; | |||||
| std::shared_ptr<VectorRef> out; | |||||
| (void)std::transform(cnode->inputs().begin(), cnode->inputs().end(), std::back_inserter(after_cnode_fn), fn); | |||||
| if (CheckIfNeedExpand(after_cnode_fn)) { | |||||
| out = ExpandList(after_cnode_fn); | |||||
| } | |||||
| std::vector<BaseRef> &outs = after_cnode_fn; | |||||
| if (out != nullptr) { | |||||
| outs = out->elements(); | |||||
| } | |||||
| for (auto &any_item : outs) { | |||||
| if (!utils::isa<AnfNodePtr>(any_item)) { | |||||
| MS_LOG(EXCEPTION) << "VisitError, fn not return the same type AnfNodePtr"; | |||||
| } | |||||
| new_inputs.push_back(utils::cast<AnfNodePtr>(any_item)); | |||||
| } | |||||
| BaseRef any_fg; | |||||
| AnfNodePtr new_cnode = nullptr; | |||||
| if (cnode->func_graph() != nullptr) { | |||||
| any_fg = fn(cnode->func_graph()); | |||||
| if (!utils::isa<FuncGraphPtr>(any_fg)) { | |||||
| MS_LOG(EXCEPTION) << "VisitError, fn not return the same type FuncGraphPtr"; | |||||
| } | |||||
| new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<FuncGraphPtr>(any_fg)); | |||||
| } else { | |||||
| any_fg = fn(cnode->func_graph_as_var()); | |||||
| if (utils::isa<VarPtr>(any_fg)) { | |||||
| new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<VarPtr>(any_fg)); | |||||
| } else if (utils::isa<FuncGraphPtr>(any_fg)) { | |||||
| new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<FuncGraphPtr>(any_fg)); | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "VisitError, fn not return VarPtr or FuncGraphPtr"; | |||||
| } | |||||
| } | |||||
| new_cnode->set_abstract(cnode->abstract()); | |||||
| *output = new_cnode; | |||||
| } | |||||
| void DefaultVisitor::Visit(const ValueNodePtr &vnode, const VisitFn &fn, AnfNodePtr *output) const { | |||||
| const BaseRef &value = utils::cast<ValuePtr>(fn(vnode->value())); | |||||
| if (utils::isa<ValuePtr>(value)) { | |||||
| if (output != nullptr) { | |||||
| auto ct = NewValueNode(utils::cast<ValuePtr>(value)); | |||||
| ct->set_abstract(vnode->abstract()); | |||||
| *output = ct; | |||||
| } | |||||
| return; | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "Visit result is not ValuePtr."; | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -1,59 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LIFT_SRC_PASS_COMMON_VISIT_H_ | |||||
| #define MINDSPORE_LIFT_SRC_PASS_COMMON_VISIT_H_ | |||||
| #include <unordered_map> | |||||
| #include <stdexcept> | |||||
| #include <list> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include "mindspore/core/base/base.h" | |||||
| #include "base/base_ref.h" | |||||
| namespace mindspore { | |||||
| using VisitFn = std::function<BaseRef(const BaseRef &)>; | |||||
| class Visitor { | |||||
| public: | |||||
| virtual void SetFn(VisitFn fn) = 0; | |||||
| virtual bool Visit(const BaseRef &e, BaseRef *out) const = 0; | |||||
| virtual bool Visit(const VectorRef &e, BaseRef *out) const = 0; | |||||
| virtual ~Visitor() = default; | |||||
| }; | |||||
| class DefaultVisitor : public Visitor { | |||||
| public: | |||||
| DefaultVisitor() : fn_(nullptr) {} | |||||
| ~DefaultVisitor() override = default; | |||||
| void SetFn(VisitFn fn) override { fn_ = fn; }; | |||||
| bool Visit(const VectorRef &e, BaseRef *out) const override; | |||||
| bool Visit(const BaseRef &e, BaseRef *out) const override; | |||||
| void Visit(const AnfNodePtr &node, const VisitFn &fn, AnfNodePtr *output) const; | |||||
| void Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr *output) const; | |||||
| void Visit(const ValueNodePtr &vnode, const VisitFn &fn, AnfNodePtr *output) const; | |||||
| VisitFn fn_; | |||||
| }; | |||||
| std::shared_ptr<VectorRef> ExpandList(const std::vector<BaseRef> &list); | |||||
| bool CheckIfNeedExpand(const std::vector<BaseRef> &list); | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LIFT_SRC_PASS_COMMON_VISIT_H_ | |||||
| @@ -14,12 +14,12 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "mindspore/lite/src/gllo/fusion/conv_activation_fusion.h" | |||||
| #include "src/gllo/fusion/conv_activation_fusion.h" | |||||
| #include <memory> | #include <memory> | ||||
| #include "mindspore/lite/schema/inner/model_generated.h" | |||||
| #include "mindspore/lite/src/ir/primitive_t_value.h" | |||||
| #include "mindspore/ccsrc/utils/utils.h" | |||||
| #include "mindspore/lite/src/gllo/common/utils.h" | |||||
| #include "schema/inner/model_generated.h" | |||||
| #include "src/ir/primitive_t_value.h" | |||||
| #include "utils/utils.h" | |||||
| #include "src/gllo/common/gllo_utils.h" | |||||
| namespace mindspore::opt { | namespace mindspore::opt { | ||||
| namespace { | namespace { | ||||
| @@ -18,7 +18,7 @@ | |||||
| #define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_ | #define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_ | ||||
| #include <string> | #include <string> | ||||
| #include "mindspore/lite/src/gllo/common/optimizer.h" | |||||
| #include "src/gllo/common/optimizer.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -13,13 +13,13 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.h" | |||||
| #include <mindspore/lite/src/param_value_lite.h> | |||||
| #include "src/gllo/fusion/conv_biasadd_fusion.h" | |||||
| #include <memory> | #include <memory> | ||||
| #include "mindspore/lite/schema/inner/model_generated.h" | |||||
| #include "mindspore/lite/src/ir/primitive_t_value.h" | |||||
| #include "mindspore/ccsrc/utils/utils.h" | |||||
| #include "mindspore/lite/src/gllo/common/utils.h" | |||||
| #include "src/param_value_lite.h" | |||||
| #include "schema/inner/model_generated.h" | |||||
| #include "src/ir/primitive_t_value.h" | |||||
| #include "utils/utils.h" | |||||
| #include "src/gllo/common/gllo_utils.h" | |||||
| #include "securec/include/securec.h" | #include "securec/include/securec.h" | ||||
| namespace mindspore::opt { | namespace mindspore::opt { | ||||
| @@ -142,7 +142,7 @@ const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, cons | |||||
| CheckIfCNodeIsNull(conv_node); | CheckIfCNodeIsNull(conv_node); | ||||
| GenConvNewBias(func_graph, conv_node, add_node); | GenConvNewBias(func_graph, conv_node, add_node); | ||||
| auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveTValue>>(conv_node->input(0)); | auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveTValue>>(conv_node->input(0)); | ||||
| MS_ASSERT(primitiveT_value); | |||||
| MS_ASSERT(primitiveT_value != nullptr); | |||||
| auto type = primitiveT_value->GetPrimitiveT()->value.type; | auto type = primitiveT_value->GetPrimitiveT()->value.type; | ||||
| if (type == schema::PrimitiveType_Conv2D) { | if (type == schema::PrimitiveType_Conv2D) { | ||||
| primitiveT_value->GetPrimitiveT()->value.AsConv2D()->hasBias = true; | primitiveT_value->GetPrimitiveT()->value.AsConv2D()->hasBias = true; | ||||
| @@ -17,7 +17,7 @@ | |||||
| #ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_ | #ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_ | ||||
| #define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_ | #define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_ | ||||
| #include "mindspore/lite/src/gllo/common/optimizer.h" | |||||
| #include "src/gllo/common/optimizer.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -14,13 +14,13 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "mindspore/lite/src/gllo/fusion/conv_bn_fusion.h" | |||||
| #include <mindspore/lite/src/param_value_lite.h> | |||||
| #include "src/gllo/fusion/conv_bn_fusion.h" | |||||
| #include <memory> | #include <memory> | ||||
| #include "mindspore/lite/schema/inner/model_generated.h" | |||||
| #include "mindspore/lite/src/ir/primitive_t_value.h" | |||||
| #include "mindspore/ccsrc/utils/utils.h" | |||||
| #include "mindspore/lite/src/gllo/common/utils.h" | |||||
| #include "src/param_value_lite.h" | |||||
| #include "schema/inner/model_generated.h" | |||||
| #include "src/ir/primitive_t_value.h" | |||||
| #include "utils/utils.h" | |||||
| #include "src/gllo/common/gllo_utils.h" | |||||
| #include "securec/include/securec.h" | #include "securec/include/securec.h" | ||||
| namespace mindspore::opt { | namespace mindspore::opt { | ||||
| @@ -17,7 +17,7 @@ | |||||
| #ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_ | #ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_ | ||||
| #define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_ | #define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_ | ||||
| #include "mindspore/lite/src/gllo/fusion/conv_transform_fusion.h" | |||||
| #include "src/gllo/fusion/conv_transform_fusion.h" | |||||
| namespace mindspore::opt { | namespace mindspore::opt { | ||||
| class ConvBatchNormFusion : public ConvTransformFusion { | class ConvBatchNormFusion : public ConvTransformFusion { | ||||
| @@ -14,13 +14,13 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "mindspore/lite/src/gllo/fusion/conv_scale_fusion.h" | |||||
| #include "src/gllo/fusion/conv_scale_fusion.h" | |||||
| #include <memory> | #include <memory> | ||||
| #include "mindspore/lite/src/param_value_lite.h" | |||||
| #include "mindspore/lite/schema/inner/model_generated.h" | |||||
| #include "mindspore/lite/src/ir/primitive_t_value.h" | |||||
| #include "mindspore/ccsrc/utils/utils.h" | |||||
| #include "mindspore/lite/src/gllo/common/utils.h" | |||||
| #include "src/param_value_lite.h" | |||||
| #include "schema/inner/model_generated.h" | |||||
| #include "src/ir/primitive_t_value.h" | |||||
| #include "utils/utils.h" | |||||
| #include "src/gllo/common/gllo_utils.h" | |||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "securec/include/securec.h" | #include "securec/include/securec.h" | ||||
| @@ -17,7 +17,7 @@ | |||||
| #ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_ | #ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_ | ||||
| #define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_ | #define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_ | ||||
| #include "mindspore/lite/src/gllo/fusion/conv_transform_fusion.h" | |||||
| #include "src/gllo/fusion/conv_transform_fusion.h" | |||||
| namespace mindspore::opt { | namespace mindspore::opt { | ||||
| class ConvScaleFusion : public ConvTransformFusion { | class ConvScaleFusion : public ConvTransformFusion { | ||||
| @@ -14,13 +14,13 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "mindspore/lite/src/gllo/fusion/conv_transform_fusion.h" | |||||
| #include "src/gllo/fusion/conv_transform_fusion.h" | |||||
| #include <memory> | #include <memory> | ||||
| #include "mindspore/lite/src/param_value_lite.h" | |||||
| #include "mindspore/lite/schema/inner/model_generated.h" | |||||
| #include "mindspore/lite/src/ir/primitive_t_value.h" | |||||
| #include "mindspore/ccsrc/utils/utils.h" | |||||
| #include "mindspore/lite/src/gllo/common/utils.h" | |||||
| #include "src/param_value_lite.h" | |||||
| #include "schema/inner/model_generated.h" | |||||
| #include "src/ir/primitive_t_value.h" | |||||
| #include "utils/utils.h" | |||||
| #include "src/gllo/common/gllo_utils.h" | |||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "securec/include/securec.h" | #include "securec/include/securec.h" | ||||
| @@ -78,6 +78,16 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co | |||||
| GenNewConvTensor(func_graph, conv_node, kernel_nums, trans_scale, trans_bias); | GenNewConvTensor(func_graph, conv_node, kernel_nums, trans_scale, trans_bias); | ||||
| delete[] trans_bias; | delete[] trans_bias; | ||||
| delete[] trans_scale; | delete[] trans_scale; | ||||
| auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveTValue>>(conv_node->input(0)); | |||||
| MS_ASSERT(primitiveT_value != nullptr); | |||||
| auto type = primitiveT_value->GetPrimitiveT()->value.type; | |||||
| if (type == schema::PrimitiveType_Conv2D) { | |||||
| primitiveT_value->GetPrimitiveT()->value.AsConv2D()->hasBias = true; | |||||
| } else if (type == schema::PrimitiveType_DepthwiseConv2D) { | |||||
| primitiveT_value->GetPrimitiveT()->value.AsDepthwiseConv2D()->hasBias = true; | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Unsupported opType, " << type; | |||||
| } | |||||
| return pre_node; | return pre_node; | ||||
| } | } | ||||
| @@ -18,7 +18,7 @@ | |||||
| #define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_ | #define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_ | ||||
| #include <string> | #include <string> | ||||
| #include "mindspore/lite/src/gllo/common/optimizer.h" | |||||
| #include "src/gllo/common/optimizer.h" | |||||
| namespace mindspore::opt { | namespace mindspore::opt { | ||||
| class ConvTransformFusion : public PatternProcessPass { | class ConvTransformFusion : public PatternProcessPass { | ||||
| @@ -63,6 +63,8 @@ if(BUILD_CONVERTER) | |||||
| ${CCSRC_DIR}/pybind_api/export_flags.cc | ${CCSRC_DIR}/pybind_api/export_flags.cc | ||||
| ${CCSRC_DIR}/utils/context/context_extends.cc | ${CCSRC_DIR}/utils/context/context_extends.cc | ||||
| ${CCSRC_DIR}/frontend/parallel/costmodel_context.cc | ${CCSRC_DIR}/frontend/parallel/costmodel_context.cc | ||||
| ${CCSRC_DIR}/backend/optimizer/common/pattern_engine.cc | |||||
| ${CCSRC_DIR}/backend/optimizer/common/visit.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../src/common/graph_utils_extends.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../src/common/graph_utils_extends.cc | ||||
| ) | ) | ||||
| else() | else() | ||||
| @@ -202,12 +204,14 @@ if(BUILD_CONVERTER) | |||||
| ${LITE_DIR}/tools/converter/converter.cc | ${LITE_DIR}/tools/converter/converter.cc | ||||
| ${LITE_DIR}/tools/converter/parser/onnx/onnx.pb.cc | ${LITE_DIR}/tools/converter/parser/onnx/onnx.pb.cc | ||||
| ${LITE_DIR}/test/st/converter_test.cc | ${LITE_DIR}/test/st/converter_test.cc | ||||
| ${LITE_DIR}/test/ut/src/gllo/fusion/conv_activation_fusion_test.cc | |||||
| ${LITE_DIR}/test/ut/src/gllo/fusion/conv_biasadd_fusion_test.cc | |||||
| ${LITE_DIR}/test/ut/src/gllo/fusion/conv_bn_fusion_test.cc | |||||
| ${LITE_DIR}/test/ut/src/gllo/fusion/conv_scale_fusion_test.cc | |||||
| ${LITE_DIR}/src/gllo/common/node_pass.cc | ${LITE_DIR}/src/gllo/common/node_pass.cc | ||||
| ${LITE_DIR}/src/gllo/common/optimizer.cc | ${LITE_DIR}/src/gllo/common/optimizer.cc | ||||
| ${LITE_DIR}/src/gllo/common/pass_manager.cc | ${LITE_DIR}/src/gllo/common/pass_manager.cc | ||||
| ${LITE_DIR}/src/gllo/common/pattern_engine.cc | |||||
| ${LITE_DIR}/src/gllo/common/visit.cc | |||||
| ${LITE_DIR}/src/gllo/common/utils.cc | |||||
| ${LITE_DIR}/src/gllo/common/gllo_utils.cc | |||||
| ${LITE_DIR}/src/gllo/fusion/conv_biasadd_fusion.cc | ${LITE_DIR}/src/gllo/fusion/conv_biasadd_fusion.cc | ||||
| ${LITE_DIR}/src/gllo/fusion/conv_activation_fusion.cc | ${LITE_DIR}/src/gllo/fusion/conv_activation_fusion.cc | ||||
| ${LITE_DIR}/src/gllo/fusion/conv_transform_fusion.cc | ${LITE_DIR}/src/gllo/fusion/conv_transform_fusion.cc | ||||
| @@ -0,0 +1,184 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <memory> | |||||
| #include "schema/inner/model_generated.h" | |||||
| #include "include/model.h" | |||||
| #include "common/common_test.h" | |||||
| #include "include/lite_session.h" | |||||
| #include "include/context.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "tools/converter/model_parser.h" | |||||
| #include "tools/converter/anf_transform.h" | |||||
| #include "src/common/anf_exporter/anf_exporter.h" | |||||
| namespace mindspore { | |||||
| class ConvActivationFusionTest : public mindspore::Common { | |||||
| public: | |||||
| ConvActivationFusionTest() = default; | |||||
| }; | |||||
| using MetaGraphTptr = std::shared_ptr<schema::MetaGraphT>; | |||||
| using CNodeTptr = std::unique_ptr<schema::CNodeT>; | |||||
| namespace { | |||||
| CNodeTptr BuildConv2D() { | |||||
| auto convNode = std::make_unique<schema::CNodeT>(); | |||||
| convNode->inputIndex = {0, 1}; | |||||
| convNode->outputIndex = {2}; | |||||
| convNode->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| convNode->primitive->value.type = schema::PrimitiveType_Conv2D; | |||||
| auto prim1 = new schema::Conv2DT; | |||||
| prim1->padMode = schema::PadMode_SAME; | |||||
| prim1->format = schema::Format_NHWC; | |||||
| prim1->strideH = 1; | |||||
| prim1->strideW = 1; | |||||
| prim1->kernelH = 3; | |||||
| prim1->kernelW = 3; | |||||
| prim1->dilateH = 1; | |||||
| prim1->dilateW = 1; | |||||
| prim1->channelOut = 3; | |||||
| convNode->primitive->value.value = prim1; | |||||
| convNode->name = "Conv2D"; | |||||
| return convNode; | |||||
| } | |||||
| CNodeTptr BuildDepthwiseConv2D() { | |||||
| auto convNode = std::make_unique<schema::CNodeT>(); | |||||
| convNode->inputIndex = {0, 1}; | |||||
| convNode->outputIndex = {2}; | |||||
| convNode->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| convNode->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||||
| auto prim1 = new schema::DepthwiseConv2DT; | |||||
| prim1->padMode = schema::PadMode_SAME; | |||||
| prim1->format = schema::Format_NHWC; | |||||
| prim1->strideH = 1; | |||||
| prim1->strideW = 1; | |||||
| prim1->kernelH = 3; | |||||
| prim1->kernelW = 3; | |||||
| prim1->dilateH = 1; | |||||
| prim1->dilateW = 1; | |||||
| prim1->channelIn = 1; | |||||
| prim1->channelMultiplier = 3; | |||||
| convNode->primitive->value.value = prim1; | |||||
| convNode->name = "Conv2D"; | |||||
| return convNode; | |||||
| } | |||||
| MetaGraphTptr BuildGraph(schema::PrimitiveType conv_type, | |||||
| schema::ActivationType activation_type) { | |||||
| auto meta_graph = std::make_shared<schema::MetaGraphT>(); | |||||
| meta_graph->name = "graph"; | |||||
| // conv node | |||||
| CNodeTptr convNode; | |||||
| if (conv_type == schema::PrimitiveType_Conv2D) { | |||||
| convNode = BuildConv2D(); | |||||
| } else { | |||||
| convNode = BuildDepthwiseConv2D(); | |||||
| } | |||||
| meta_graph->nodes.emplace_back(std::move(convNode)); | |||||
| // relu node | |||||
| auto next_node = std::make_unique<schema::CNodeT>(); | |||||
| next_node->inputIndex = {2}; | |||||
| next_node->outputIndex = {3}; | |||||
| next_node->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| next_node->primitive->value.type = schema::PrimitiveType_Activation; | |||||
| auto prim2 = new schema::ActivationT; | |||||
| prim2->type = activation_type; | |||||
| next_node->primitive->value.value = prim2; | |||||
| next_node->name = "activation"; | |||||
| meta_graph->nodes.emplace_back(std::move(next_node)); | |||||
| meta_graph->inputIndex = {0}; | |||||
| meta_graph->outputIndex = {3}; | |||||
| // input 0: data | |||||
| auto input0 = std::make_unique<schema::TensorT>(); | |||||
| input0->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| input0->format = schema::Format_NHWC; | |||||
| input0->dataType = TypeId::kNumberTypeFloat32; | |||||
| input0->dims = {1, 5, 5, 3}; | |||||
| input0->offset = -1; | |||||
| meta_graph->allTensors.emplace_back(std::move(input0)); | |||||
| // input 1: weight | |||||
| auto input1 = std::make_unique<schema::TensorT>(); | |||||
| input1->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| input1->format = schema::Format_KHWC; | |||||
| input1->dataType = TypeId::kNumberTypeFloat32; | |||||
| input1->dims = {8, 3, 3, 3}; | |||||
| input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3); | |||||
| meta_graph->allTensors.emplace_back(std::move(input1)); | |||||
| // conv output | |||||
| auto conv_output = std::make_unique<schema::TensorT>(); | |||||
| conv_output->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| conv_output->format = schema::Format_NHWC; | |||||
| conv_output->dataType = TypeId::kNumberTypeFloat32; | |||||
| conv_output->dims = {1, 5, 5, 8}; | |||||
| meta_graph->allTensors.emplace_back(std::move(conv_output)); | |||||
| // final output | |||||
| auto output = std::make_unique<schema::TensorT>(); | |||||
| output->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| output->format = schema::Format_NHWC; | |||||
| output->dataType = TypeId::kNumberTypeFloat32; | |||||
| output->dims = {1, 5, 5, 8}; | |||||
| meta_graph->allTensors.emplace_back(std::move(output)); | |||||
| return meta_graph; | |||||
| } | |||||
| } // namespace | |||||
| TEST_F(ConvActivationFusionTest, TestConvReluNode) { | |||||
| auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2D, schema::ActivationType_RELU); | |||||
| auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get()); | |||||
| auto anf_transform = new lite::AnfTransform(); | |||||
| auto new_graph = anf_transform->Transform(func_graph); | |||||
| ASSERT_NE(nullptr, new_graph); | |||||
| auto new_meta_graph = lite::Export(new_graph); | |||||
| ASSERT_EQ(new_meta_graph->nodes.size(), 1); | |||||
| for (auto &cnode : new_meta_graph->nodes) { | |||||
| ASSERT_EQ(cnode->primitive->value.AsConv2D()->activationType, schema::ActivationType_RELU); | |||||
| } | |||||
| } | |||||
| TEST_F(ConvActivationFusionTest, TestConvRelu6Node) { | |||||
| auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2D, schema::ActivationType_RELU6); | |||||
| auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get()); | |||||
| auto anf_transform = new lite::AnfTransform(); | |||||
| auto new_graph = anf_transform->Transform(func_graph); | |||||
| ASSERT_NE(nullptr, new_graph); | |||||
| auto new_meta_graph = lite::Export(new_graph); | |||||
| ASSERT_EQ(new_meta_graph->nodes.size(), 1); | |||||
| for (auto &cnode : new_meta_graph->nodes) { | |||||
| ASSERT_EQ(cnode->primitive->value.AsConv2D()->activationType, schema::ActivationType_RELU6); | |||||
| } | |||||
| } | |||||
| TEST_F(ConvActivationFusionTest, TestBadCase_ConvRelu) { | |||||
| auto meta_graph = BuildGraph(schema::PrimitiveType_DepthwiseConv2D, schema::ActivationType_LEAKY_RELU); | |||||
| auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get()); | |||||
| auto anf_transform = new lite::AnfTransform(); | |||||
| auto new_graph = anf_transform->Transform(func_graph); | |||||
| ASSERT_NE(nullptr, new_graph); | |||||
| auto new_meta_graph = lite::Export(new_graph); | |||||
| ASSERT_EQ(new_meta_graph->nodes.size(), 2); | |||||
| for (auto &cnode : new_meta_graph->nodes) { | |||||
| if (cnode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { | |||||
| ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->activationType, schema::ActivationType_NO_ACTIVATION); | |||||
| } | |||||
| } | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,194 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <memory> | |||||
| #include "schema/inner/model_generated.h" | |||||
| #include "include/model.h" | |||||
| #include "common/common_test.h" | |||||
| #include "include/lite_session.h" | |||||
| #include "include/context.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "tools/converter/model_parser.h" | |||||
| #include "tools/converter/anf_transform.h" | |||||
| #include "src/common/anf_exporter/anf_exporter.h" | |||||
| namespace mindspore { | |||||
| class ConvBiasAddFusionTest : public mindspore::Common { | |||||
| public: | |||||
| ConvBiasAddFusionTest() = default; | |||||
| }; | |||||
| using MetaGraphTptr = std::shared_ptr<schema::MetaGraphT>; | |||||
| using CNodeTptr = std::unique_ptr<schema::CNodeT>; | |||||
| namespace { | |||||
| CNodeTptr BuildConv2D() { | |||||
| auto convNode = std::make_unique<schema::CNodeT>(); | |||||
| convNode->inputIndex = {0, 1}; | |||||
| convNode->outputIndex = {2}; | |||||
| convNode->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| convNode->primitive->value.type = schema::PrimitiveType_Conv2D; | |||||
| auto prim1 = new schema::Conv2DT; | |||||
| prim1->padMode = schema::PadMode_SAME; | |||||
| prim1->format = schema::Format_NHWC; | |||||
| prim1->strideH = 1; | |||||
| prim1->strideW = 1; | |||||
| prim1->kernelH = 3; | |||||
| prim1->kernelW = 3; | |||||
| prim1->dilateH = 1; | |||||
| prim1->dilateW = 1; | |||||
| prim1->channelOut = 3; | |||||
| convNode->primitive->value.value = prim1; | |||||
| convNode->name = "Conv2D"; | |||||
| return convNode; | |||||
| } | |||||
| CNodeTptr BuildDepthwiseConv2D() { | |||||
| auto convNode = std::make_unique<schema::CNodeT>(); | |||||
| convNode->inputIndex = {0, 1}; | |||||
| convNode->outputIndex = {2}; | |||||
| convNode->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| convNode->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||||
| auto prim1 = new schema::DepthwiseConv2DT; | |||||
| prim1->padMode = schema::PadMode_SAME; | |||||
| prim1->format = schema::Format_NHWC; | |||||
| prim1->strideH = 1; | |||||
| prim1->strideW = 1; | |||||
| prim1->kernelH = 3; | |||||
| prim1->kernelW = 3; | |||||
| prim1->dilateH = 1; | |||||
| prim1->dilateW = 1; | |||||
| prim1->channelIn = 1; | |||||
| prim1->channelMultiplier = 3; | |||||
| convNode->primitive->value.value = prim1; | |||||
| convNode->name = "Conv2D"; | |||||
| return convNode; | |||||
| } | |||||
| MetaGraphTptr BuildGraph(schema::PrimitiveType conv_type, | |||||
| schema::PrimitiveType add_type) { | |||||
| auto meta_graph = std::make_shared<schema::MetaGraphT>(); | |||||
| meta_graph->name = "graph"; | |||||
| // conv node | |||||
| CNodeTptr convNode; | |||||
| if (conv_type == schema::PrimitiveType_Conv2D) { | |||||
| convNode = BuildConv2D(); | |||||
| } else { | |||||
| convNode = BuildDepthwiseConv2D(); | |||||
| } | |||||
| meta_graph->nodes.emplace_back(std::move(convNode)); | |||||
| // biasadd node | |||||
| auto biasadd_node = std::make_unique<schema::CNodeT>(); | |||||
| biasadd_node->inputIndex = {2, 3}; | |||||
| biasadd_node->outputIndex = {4}; | |||||
| biasadd_node->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| biasadd_node->primitive->value.type = add_type; | |||||
| auto prim2 = new schema::BiasAddT; | |||||
| biasadd_node->primitive->value.value = prim2; | |||||
| biasadd_node->name = "BiasAdd"; | |||||
| meta_graph->nodes.emplace_back(std::move(biasadd_node)); | |||||
| meta_graph->inputIndex = {0}; | |||||
| meta_graph->outputIndex = {4}; | |||||
| // input 0: data | |||||
| auto input0 = std::make_unique<schema::TensorT>(); | |||||
| input0->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| input0->format = schema::Format_NHWC; | |||||
| input0->dataType = TypeId::kNumberTypeFloat32; | |||||
| input0->dims = {1, 5, 5, 3}; | |||||
| input0->offset = -1; | |||||
| meta_graph->allTensors.emplace_back(std::move(input0)); | |||||
| // input 1: weight | |||||
| auto input1 = std::make_unique<schema::TensorT>(); | |||||
| input1->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| input1->format = schema::Format_KHWC; | |||||
| input1->dataType = TypeId::kNumberTypeFloat32; | |||||
| input1->dims = {8, 3, 3, 3}; | |||||
| input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3); | |||||
| meta_graph->allTensors.emplace_back(std::move(input1)); | |||||
| // conv output | |||||
| auto conv_output = std::make_unique<schema::TensorT>(); | |||||
| conv_output->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| conv_output->format = schema::Format_NHWC; | |||||
| conv_output->dataType = TypeId::kNumberTypeFloat32; | |||||
| conv_output->dims = {1, 5, 5, 8}; | |||||
| meta_graph->allTensors.emplace_back(std::move(conv_output)); | |||||
| // input2: bias | |||||
| auto input2 = std::make_unique<schema::TensorT>(); | |||||
| input2->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| input2->format = schema::Format_NHWC; | |||||
| input2->dataType = TypeId::kNumberTypeFloat32; | |||||
| input2->dims = {1, 5, 5, 8}; | |||||
| input2->data.resize(sizeof(float) * 8 * 5 * 5); | |||||
| meta_graph->allTensors.emplace_back(std::move(input2)); | |||||
| // final output | |||||
| auto output = std::make_unique<schema::TensorT>(); | |||||
| output->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| output->format = schema::Format_NHWC; | |||||
| output->dataType = TypeId::kNumberTypeFloat32; | |||||
| output->dims = {1, 5, 5, 8}; | |||||
| meta_graph->allTensors.emplace_back(std::move(output)); | |||||
| return meta_graph; | |||||
| } | |||||
| } // namespace | |||||
| TEST_F(ConvBiasAddFusionTest, TestConvAddNode) { | |||||
| auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2D, schema::PrimitiveType_BiasAdd); | |||||
| auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get()); | |||||
| auto anf_transform = new lite::AnfTransform(); | |||||
| auto new_graph = anf_transform->Transform(func_graph); | |||||
| ASSERT_NE(nullptr, new_graph); | |||||
| auto new_meta_graph = lite::Export(new_graph); | |||||
| ASSERT_EQ(new_meta_graph->nodes.size(), 1); | |||||
| for (auto &cnode : new_meta_graph->nodes) { | |||||
| ASSERT_EQ(cnode->primitive->value.AsConv2D()->hasBias, true); | |||||
| } | |||||
| MS_LOG(INFO) << "Passed"; | |||||
| } | |||||
| TEST_F(ConvBiasAddFusionTest, TestDeptiwiseConvAddNode) { | |||||
| auto meta_graph = BuildGraph(schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add); | |||||
| auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get()); | |||||
| auto anf_transform = new lite::AnfTransform(); | |||||
| auto new_graph = anf_transform->Transform(func_graph); | |||||
| ASSERT_NE(nullptr, new_graph); | |||||
| auto new_meta_graph = lite::Export(new_graph); | |||||
| ASSERT_EQ(new_meta_graph->nodes.size(), 1); | |||||
| for (auto &cnode : new_meta_graph->nodes) { | |||||
| ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->hasBias, true); | |||||
| } | |||||
| } | |||||
| TEST_F(ConvBiasAddFusionTest, TestBadCase_ConvAdd) { | |||||
| auto meta_graph = BuildGraph(schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_MatMul); | |||||
| auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get()); | |||||
| auto anf_transform = new lite::AnfTransform(); | |||||
| auto new_graph = anf_transform->Transform(func_graph); | |||||
| ASSERT_NE(nullptr, new_graph); | |||||
| auto new_meta_graph = lite::Export(new_graph); | |||||
| ASSERT_EQ(new_meta_graph->nodes.size(), 2); | |||||
| for (auto &cnode : new_meta_graph->nodes) { | |||||
| if (cnode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { | |||||
| ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->hasBias, false); | |||||
| } | |||||
| } | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,296 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <memory> | |||||
| #include "schema/inner/model_generated.h" | |||||
| #include "include/model.h" | |||||
| #include "common/common_test.h" | |||||
| #include "include/lite_session.h" | |||||
| #include "include/context.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "mindspore/core/utils/log_adapter.h" | |||||
| #include "tools/converter/model_parser.h" | |||||
| #include "tools/converter/anf_transform.h" | |||||
| #include "src/common/anf_exporter/anf_exporter.h" | |||||
| namespace mindspore { | |||||
| class ConvBNFusionTest : public mindspore::Common { | |||||
| public: | |||||
| ConvBNFusionTest() = default; | |||||
| }; | |||||
| using MetaGraphTptr = std::shared_ptr<schema::MetaGraphT>; | |||||
| using CNodeTptr = std::unique_ptr<schema::CNodeT>; | |||||
| namespace { | |||||
| CNodeTptr BuildConv2D() { | |||||
| auto convNode = std::make_unique<schema::CNodeT>(); | |||||
| convNode->inputIndex = {0, 1}; | |||||
| convNode->outputIndex = {2}; | |||||
| convNode->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| convNode->primitive->value.type = schema::PrimitiveType_Conv2D; | |||||
| auto prim1 = new schema::Conv2DT; | |||||
| prim1->padMode = schema::PadMode_SAME; | |||||
| prim1->format = schema::Format_NHWC; | |||||
| prim1->strideH = 1; | |||||
| prim1->strideW = 1; | |||||
| prim1->kernelH = 3; | |||||
| prim1->kernelW = 3; | |||||
| prim1->dilateH = 1; | |||||
| prim1->dilateW = 1; | |||||
| prim1->channelOut = 3; | |||||
| convNode->primitive->value.value = prim1; | |||||
| convNode->name = "Conv2D"; | |||||
| return convNode; | |||||
| } | |||||
| CNodeTptr BuildDepthwiseConv2D() { | |||||
| auto convNode = std::make_unique<schema::CNodeT>(); | |||||
| convNode->inputIndex = {0, 1, 2}; | |||||
| convNode->outputIndex = {3}; | |||||
| convNode->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| convNode->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||||
| auto prim1 = new schema::DepthwiseConv2DT; | |||||
| prim1->padMode = schema::PadMode_SAME; | |||||
| prim1->format = schema::Format_NHWC; | |||||
| prim1->strideH = 1; | |||||
| prim1->strideW = 1; | |||||
| prim1->kernelH = 3; | |||||
| prim1->kernelW = 3; | |||||
| prim1->dilateH = 1; | |||||
| prim1->dilateW = 1; | |||||
| prim1->channelIn = 1; | |||||
| prim1->channelMultiplier = 3; | |||||
| convNode->primitive->value.value = prim1; | |||||
| convNode->name = "Conv2D"; | |||||
| return convNode; | |||||
| } | |||||
| // caffe bn op has 3 inputs | |||||
| MetaGraphTptr BuildCaffeGraph(schema::PrimitiveType conv_type) { | |||||
| auto meta_graph = std::make_shared<schema::MetaGraphT>(); | |||||
| meta_graph->name = "graph"; | |||||
| // conv node | |||||
| CNodeTptr convNode; | |||||
| if (conv_type == schema::PrimitiveType_Conv2D) { | |||||
| convNode = BuildConv2D(); | |||||
| } else { | |||||
| convNode = BuildDepthwiseConv2D(); | |||||
| } | |||||
| meta_graph->nodes.emplace_back(std::move(convNode)); | |||||
| // bn_node | |||||
| auto bn_node = std::make_unique<schema::CNodeT>(); | |||||
| bn_node->inputIndex = {2, 3, 4}; | |||||
| bn_node->outputIndex = {5}; | |||||
| bn_node->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| bn_node->primitive->value.type = schema::PrimitiveType_CaffeBatchNorm; | |||||
| auto prim2 = new schema::CaffeBatchNormT; | |||||
| bn_node->primitive->value.value = prim2; | |||||
| bn_node->name = "bn"; | |||||
| meta_graph->nodes.emplace_back(std::move(bn_node)); | |||||
| // input 0: data | |||||
| auto input0 = std::make_unique<schema::TensorT>(); | |||||
| input0->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| input0->format = schema::Format_NHWC; | |||||
| input0->dataType = TypeId::kNumberTypeFloat32; | |||||
| input0->dims = {1, 5, 5, 3}; | |||||
| input0->offset = -1; | |||||
| meta_graph->allTensors.emplace_back(std::move(input0)); | |||||
| // input 1: weight | |||||
| auto input1 = std::make_unique<schema::TensorT>(); | |||||
| input1->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| input1->format = schema::Format_KHWC; | |||||
| input1->dataType = TypeId::kNumberTypeFloat32; | |||||
| input1->dims = {8, 3, 3, 3}; | |||||
| input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3); | |||||
| meta_graph->allTensors.emplace_back(std::move(input1)); | |||||
| // conv output | |||||
| auto conv_output = std::make_unique<schema::TensorT>(); | |||||
| conv_output->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| conv_output->format = schema::Format_NHWC; | |||||
| conv_output->dataType = TypeId::kNumberTypeFloat32; | |||||
| conv_output->dims = {1, 5, 5, 8}; | |||||
| meta_graph->allTensors.emplace_back(std::move(conv_output)); | |||||
| // caffe bn : mean | |||||
| auto input2 = std::make_unique<schema::TensorT>(); | |||||
| input2->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| input2->format = schema::Format_NHWC; | |||||
| input2->dataType = TypeId::kNumberTypeFloat32; | |||||
| input2->dims = {1, 5, 5, 8}; | |||||
| input2->data.resize(sizeof(float) * 8 * 5 * 5); | |||||
| meta_graph->allTensors.emplace_back(std::move(input2)); | |||||
| // caffe bn : var | |||||
| auto input3 = std::make_unique<schema::TensorT>(); | |||||
| input3->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| input3->format = schema::Format_NHWC; | |||||
| input3->dataType = TypeId::kNumberTypeFloat32; | |||||
| input3->dims = {1, 5, 5, 8}; | |||||
| input3->data.resize(sizeof(float) * 8 * 5 * 5); | |||||
| meta_graph->allTensors.emplace_back(std::move(input3)); | |||||
| // final bn output | |||||
| auto output = std::make_unique<schema::TensorT>(); | |||||
| output->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| output->format = schema::Format_NHWC; | |||||
| output->dataType = TypeId::kNumberTypeFloat32; | |||||
| output->dims = {1, 5, 5, 8}; | |||||
| meta_graph->allTensors.emplace_back(std::move(output)); | |||||
| meta_graph->inputIndex = {0}; | |||||
| meta_graph->outputIndex = {5}; | |||||
| return meta_graph; | |||||
| } | |||||
| // tf bn op has 4 inputs | |||||
| MetaGraphTptr BuildTFGraph(schema::PrimitiveType conv_type) { | |||||
| auto meta_graph = std::make_shared<schema::MetaGraphT>(); | |||||
| meta_graph->name = "graph"; | |||||
| // conv node | |||||
| CNodeTptr convNode; | |||||
| if (conv_type == schema::PrimitiveType_Conv2D) { | |||||
| convNode = BuildConv2D(); | |||||
| } else { | |||||
| convNode = BuildDepthwiseConv2D(); | |||||
| } | |||||
| meta_graph->nodes.emplace_back(std::move(convNode)); | |||||
| // bn_node | |||||
| auto bn_node = std::make_unique<schema::CNodeT>(); | |||||
| bn_node->inputIndex = {3, 4, 5, 6, 7}; | |||||
| bn_node->outputIndex = {8}; | |||||
| bn_node->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| bn_node->primitive->value.type = schema::PrimitiveType_FusedBatchNorm; | |||||
| auto prim2 = new schema::FusedBatchNormT; | |||||
| bn_node->primitive->value.value = prim2; | |||||
| bn_node->name = "bn"; | |||||
| meta_graph->nodes.emplace_back(std::move(bn_node)); | |||||
| // input 0: data | |||||
| auto input0 = std::make_unique<schema::TensorT>(); | |||||
| input0->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| input0->format = schema::Format_NHWC; | |||||
| input0->dataType = TypeId::kNumberTypeFloat32; | |||||
| input0->dims = {1, 5, 5, 3}; | |||||
| input0->offset = -1; | |||||
| meta_graph->allTensors.emplace_back(std::move(input0)); | |||||
| // input 1: conv_bias | |||||
| auto input11 = std::make_unique<schema::TensorT>(); | |||||
| input11->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| input11->format = schema::Format_KHWC; | |||||
| input11->dataType = TypeId::kNumberTypeFloat32; | |||||
| input11->dims = {8, 3, 3, 3}; | |||||
| input11->data.resize(sizeof(float) * 8 * 3 * 3 * 3); | |||||
| meta_graph->allTensors.emplace_back(std::move(input11)); | |||||
| // input 1: weight | |||||
| auto input1 = std::make_unique<schema::TensorT>(); | |||||
| input1->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| input1->format = schema::Format_KHWC; | |||||
| input1->dataType = TypeId::kNumberTypeFloat32; | |||||
| input1->dims = {8, 3, 3, 3}; | |||||
| input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3); | |||||
| meta_graph->allTensors.emplace_back(std::move(input1)); | |||||
| // conv output | |||||
| auto conv_output = std::make_unique<schema::TensorT>(); | |||||
| conv_output->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| conv_output->format = schema::Format_NHWC; | |||||
| conv_output->dataType = TypeId::kNumberTypeFloat32; | |||||
| conv_output->dims = {1, 5, 5, 8}; | |||||
| meta_graph->allTensors.emplace_back(std::move(conv_output)); | |||||
| // tflite bn : scale | |||||
| auto input2 = std::make_unique<schema::TensorT>(); | |||||
| input2->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| input2->format = schema::Format_NHWC; | |||||
| input2->dataType = TypeId::kNumberTypeFloat32; | |||||
| input2->dims = {1, 5, 5, 8}; | |||||
| input2->data.resize(sizeof(float) * 8 * 5 * 5); | |||||
| meta_graph->allTensors.emplace_back(std::move(input2)); | |||||
| // tflite bn : bias | |||||
| auto input3 = std::make_unique<schema::TensorT>(); | |||||
| input3->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| input3->format = schema::Format_NHWC; | |||||
| input3->dataType = TypeId::kNumberTypeFloat32; | |||||
| input3->dims = {1, 5, 5, 8}; | |||||
| input3->data.resize(sizeof(float) * 8 * 5 * 5); | |||||
| meta_graph->allTensors.emplace_back(std::move(input3)); | |||||
| // tflite bn : mean | |||||
| auto input4 = std::make_unique<schema::TensorT>(); | |||||
| input4->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| input4->format = schema::Format_NHWC; | |||||
| input4->dataType = TypeId::kNumberTypeFloat32; | |||||
| input4->dims = {1, 5, 5, 8}; | |||||
| input4->data.resize(sizeof(float) * 8 * 5 * 5); | |||||
| meta_graph->allTensors.emplace_back(std::move(input4)); | |||||
| // tflite bn : var | |||||
| auto input5 = std::make_unique<schema::TensorT>(); | |||||
| input5->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| input5->format = schema::Format_NHWC; | |||||
| input5->dataType = TypeId::kNumberTypeFloat32; | |||||
| input5->dims = {1, 5, 5, 8}; | |||||
| input5->data.resize(sizeof(float) * 8 * 5 * 5); | |||||
| meta_graph->allTensors.emplace_back(std::move(input5)); | |||||
| // final output | |||||
| auto output = std::make_unique<schema::TensorT>(); | |||||
| output->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| output->format = schema::Format_NHWC; | |||||
| output->dataType = TypeId::kNumberTypeFloat32; | |||||
| output->dims = {1, 5, 5, 8}; | |||||
| meta_graph->allTensors.emplace_back(std::move(output)); | |||||
| meta_graph->inputIndex = {0}; | |||||
| meta_graph->outputIndex = {8}; | |||||
| return meta_graph; | |||||
| } | |||||
| } // namespace | |||||
| TEST_F(ConvBNFusionTest, TestConvAddNode) { | |||||
| auto meta_graph = BuildCaffeGraph(schema::PrimitiveType_Conv2D); | |||||
| auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get()); | |||||
| auto anf_transform = new lite::AnfTransform(); | |||||
| auto new_graph = anf_transform->Transform(func_graph); | |||||
| ASSERT_NE(nullptr, new_graph); | |||||
| auto new_meta_graph = lite::Export(new_graph); | |||||
| ASSERT_EQ(new_meta_graph->nodes.size(), 1); | |||||
| for (auto &cnode : new_meta_graph->nodes) { | |||||
| ASSERT_EQ(cnode->primitive->value.AsConv2D()->hasBias, true); | |||||
| } | |||||
| } | |||||
| TEST_F(ConvBNFusionTest, TestDeptiwiseConvAddNode) { | |||||
| auto meta_graph = BuildTFGraph(schema::PrimitiveType_DepthwiseConv2D); | |||||
| auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get()); | |||||
| auto anf_transform = new lite::AnfTransform(); | |||||
| auto new_graph = anf_transform->Transform(func_graph); | |||||
| ASSERT_NE(nullptr, new_graph); | |||||
| auto new_meta_graph = lite::Export(new_graph); | |||||
| ASSERT_EQ(new_meta_graph->nodes.size(), 1); | |||||
| for (auto &cnode : new_meta_graph->nodes) { | |||||
| ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->hasBias, true); | |||||
| } | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,221 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <memory> | |||||
| #include "schema/inner/model_generated.h" | |||||
| #include "include/model.h" | |||||
| #include "common/common_test.h" | |||||
| #include "include/lite_session.h" | |||||
| #include "include/context.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "tools/converter/model_parser.h" | |||||
| #include "tools/converter/anf_transform.h" | |||||
| #include "src/common/anf_exporter/anf_exporter.h" | |||||
| namespace mindspore { | |||||
| class ConvScaleFusionTest : public mindspore::Common { | |||||
| public: | |||||
| ConvScaleFusionTest() = default; | |||||
| }; | |||||
| using MetaGraphTptr = std::shared_ptr<schema::MetaGraphT>; | |||||
| using CNodeTptr = std::unique_ptr<schema::CNodeT>; | |||||
| namespace { | |||||
| // conv has 2 inputs | |||||
| CNodeTptr BuildConv2D(int with_bias_flag) { | |||||
| auto convNode = std::make_unique<schema::CNodeT>(); | |||||
| if (with_bias_flag) { | |||||
| convNode->inputIndex = {0, 1, 2}; | |||||
| convNode->outputIndex = {3}; | |||||
| } else { | |||||
| convNode->inputIndex = {0, 1}; | |||||
| convNode->outputIndex = {2}; | |||||
| } | |||||
| convNode->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| convNode->primitive->value.type = schema::PrimitiveType_Conv2D; | |||||
| auto prim1 = new schema::Conv2DT; | |||||
| prim1->padMode = schema::PadMode_SAME; | |||||
| prim1->format = schema::Format_NHWC; | |||||
| prim1->strideH = 1; | |||||
| prim1->strideW = 1; | |||||
| prim1->kernelH = 3; | |||||
| prim1->kernelW = 3; | |||||
| prim1->dilateH = 1; | |||||
| prim1->dilateW = 1; | |||||
| prim1->channelOut = 3; | |||||
| convNode->primitive->value.value = prim1; | |||||
| convNode->name = "Conv2D"; | |||||
| return convNode; | |||||
| } | |||||
| // conv2d has 3 inputs | |||||
| CNodeTptr BuildDepthwiseConv2D(int with_bias_flag) { | |||||
| auto convNode = std::make_unique<schema::CNodeT>(); | |||||
| if (with_bias_flag) { | |||||
| convNode->inputIndex = {0, 1, 2}; | |||||
| convNode->outputIndex = {3}; | |||||
| } else { | |||||
| convNode->inputIndex = {0, 1}; | |||||
| convNode->outputIndex = {2}; | |||||
| } | |||||
| convNode->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| convNode->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||||
| auto prim1 = new schema::DepthwiseConv2DT; | |||||
| prim1->padMode = schema::PadMode_SAME; | |||||
| prim1->format = schema::Format_NHWC; | |||||
| prim1->strideH = 1; | |||||
| prim1->strideW = 1; | |||||
| prim1->kernelH = 3; | |||||
| prim1->kernelW = 3; | |||||
| prim1->dilateH = 1; | |||||
| prim1->dilateW = 1; | |||||
| prim1->channelIn = 1; | |||||
| prim1->channelMultiplier = 3; | |||||
| convNode->primitive->value.value = prim1; | |||||
| convNode->name = "Conv2D"; | |||||
| return convNode; | |||||
| } | |||||
| MetaGraphTptr BuildGraph(schema::PrimitiveType conv_type, bool conv_with_bias) { | |||||
| auto meta_graph = std::make_shared<schema::MetaGraphT>(); | |||||
| meta_graph->name = "graph"; | |||||
| // conv node | |||||
| CNodeTptr convNode; | |||||
| if (conv_type == schema::PrimitiveType_Conv2D) { | |||||
| convNode = BuildConv2D(conv_with_bias); | |||||
| } else { | |||||
| convNode = BuildDepthwiseConv2D(conv_with_bias); | |||||
| } | |||||
| meta_graph->nodes.emplace_back(std::move(convNode)); | |||||
| // scale_node weight bias | |||||
| auto scale_node = std::make_unique<schema::CNodeT>(); | |||||
| if (conv_with_bias) { | |||||
| scale_node->inputIndex = {3, 4, 5}; | |||||
| scale_node->outputIndex = {6}; | |||||
| } else { | |||||
| scale_node->inputIndex = {2, 3, 4}; | |||||
| scale_node->outputIndex = {5}; | |||||
| } | |||||
| scale_node->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| scale_node->primitive->value.type = schema::PrimitiveType_Scale; | |||||
| auto prim2 = new schema::ScaleT; | |||||
| scale_node->primitive->value.value = prim2; | |||||
| scale_node->name = "scale"; | |||||
| meta_graph->nodes.emplace_back(std::move(scale_node)); | |||||
| // input 0: data | |||||
| auto input0 = std::make_unique<schema::TensorT>(); | |||||
| input0->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| input0->format = schema::Format_NHWC; | |||||
| input0->dataType = TypeId::kNumberTypeFloat32; | |||||
| input0->dims = {1, 5, 5, 3}; | |||||
| input0->offset = -1; | |||||
| meta_graph->allTensors.emplace_back(std::move(input0)); | |||||
| // input 1: weight | |||||
| auto input1 = std::make_unique<schema::TensorT>(); | |||||
| input1->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| input1->format = schema::Format_KHWC; | |||||
| input1->dataType = TypeId::kNumberTypeFloat32; | |||||
| input1->dims = {8, 3, 3, 3}; | |||||
| input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3); | |||||
| meta_graph->allTensors.emplace_back(std::move(input1)); | |||||
| if (conv_with_bias) { | |||||
| // input 00: bias | |||||
| auto input00 = std::make_unique<schema::TensorT>(); | |||||
| input00->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| input00->format = schema::Format_NHWC; | |||||
| input00->dataType = TypeId::kNumberTypeFloat32; | |||||
| input00->dims = {1, 5, 5, 3}; | |||||
| input00->offset = -1; | |||||
| meta_graph->allTensors.emplace_back(std::move(input00)); | |||||
| } | |||||
| // conv output | |||||
| auto conv_output = std::make_unique<schema::TensorT>(); | |||||
| conv_output->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| conv_output->format = schema::Format_NHWC; | |||||
| conv_output->dataType = TypeId::kNumberTypeFloat32; | |||||
| conv_output->dims = {1, 5, 5, 8}; | |||||
| meta_graph->allTensors.emplace_back(std::move(conv_output)); | |||||
| // scale weight input | |||||
| auto input2 = std::make_unique<schema::TensorT>(); | |||||
| input2->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| input2->format = schema::Format_NHWC; | |||||
| input2->dataType = TypeId::kNumberTypeFloat32; | |||||
| input2->dims = {1, 5, 5, 8}; | |||||
| input2->data.resize(sizeof(float) * 8 * 5 * 5); | |||||
| meta_graph->allTensors.emplace_back(std::move(input2)); | |||||
| // scale bias input | |||||
| auto input3 = std::make_unique<schema::TensorT>(); | |||||
| input3->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| input3->format = schema::Format_NHWC; | |||||
| input3->dataType = TypeId::kNumberTypeFloat32; | |||||
| input3->dims = {1, 5, 5, 8}; | |||||
| input3->data.resize(sizeof(float) * 8 * 5 * 5); | |||||
| meta_graph->allTensors.emplace_back(std::move(input3)); | |||||
| // final scale output | |||||
| auto output = std::make_unique<schema::TensorT>(); | |||||
| output->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| output->format = schema::Format_NHWC; | |||||
| output->dataType = TypeId::kNumberTypeFloat32; | |||||
| output->dims = {1, 5, 5, 8}; | |||||
| meta_graph->allTensors.emplace_back(std::move(output)); | |||||
| if (conv_with_bias) { | |||||
| meta_graph->inputIndex = {0}; | |||||
| meta_graph->outputIndex = {6}; | |||||
| } else { | |||||
| meta_graph->inputIndex = {0}; | |||||
| meta_graph->outputIndex = {5}; | |||||
| } | |||||
| return meta_graph; | |||||
| } | |||||
| } // namespace | |||||
| TEST_F(ConvScaleFusionTest, TestConvScaleNode) { | |||||
| auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2D, true); | |||||
| auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get()); | |||||
| auto anf_transform = new lite::AnfTransform(); | |||||
| auto new_graph = anf_transform->Transform(func_graph); | |||||
| ASSERT_NE(nullptr, new_graph); | |||||
| auto new_meta_graph = lite::Export(new_graph); | |||||
| ASSERT_EQ(new_meta_graph->nodes.size(), 1); | |||||
| for (auto &cnode : new_meta_graph->nodes) { | |||||
| ASSERT_EQ(cnode->primitive->value.AsConv2D()->hasBias, true); | |||||
| } | |||||
| } | |||||
| TEST_F(ConvScaleFusionTest, TestDeptiwiseConvScaleNode) { | |||||
| auto meta_graph = BuildGraph(schema::PrimitiveType_DepthwiseConv2D, false); | |||||
| auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get()); | |||||
| auto anf_transform = new lite::AnfTransform(); | |||||
| auto new_graph = anf_transform->Transform(func_graph); | |||||
| ASSERT_NE(nullptr, new_graph); | |||||
| auto new_meta_graph = lite::Export(new_graph); | |||||
| ASSERT_EQ(new_meta_graph->nodes.size(), 1); | |||||
| for (auto &cnode : new_meta_graph->nodes) { | |||||
| ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->hasBias, true); | |||||
| ASSERT_EQ(cnode->inputIndex.size(), 3); | |||||
| } | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -49,6 +49,8 @@ set(ANF_SRC | |||||
| ${CCSRC_DIR}/pybind_api/export_flags.cc | ${CCSRC_DIR}/pybind_api/export_flags.cc | ||||
| ${CCSRC_DIR}/utils/context/context_extends.cc | ${CCSRC_DIR}/utils/context/context_extends.cc | ||||
| ${CCSRC_DIR}/frontend/parallel/costmodel_context.cc | ${CCSRC_DIR}/frontend/parallel/costmodel_context.cc | ||||
| ${CCSRC_DIR}/backend/optimizer/common/pattern_engine.cc | |||||
| ${CCSRC_DIR}/backend/optimizer/common/visit.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/graph_utils_extends.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/graph_utils_extends.cc | ||||
| ) | ) | ||||
| @@ -75,9 +77,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/node_pass.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/node_pass.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/optimizer.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/optimizer.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/pass_manager.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/pass_manager.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/pattern_engine.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/visit.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/utils.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/gllo_utils.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_biasadd_fusion.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_biasadd_fusion.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_activation_fusion.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_activation_fusion.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_transform_fusion.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_transform_fusion.cc | ||||
| @@ -90,7 +90,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| // auto newGraph = anfTransform->Transform(graph); | |||||
| graph = anfTransform->Transform(graph); | |||||
| CreateQuantizer(graph, flag); | CreateQuantizer(graph, flag); | ||||
| if (mQuantizer != nullptr) { | if (mQuantizer != nullptr) { | ||||
| @@ -100,20 +100,20 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||||
| // } | // } | ||||
| // fusion | // fusion | ||||
| { | |||||
| Optimizer fusionOptimizer; | |||||
| fusionOptimizer.AddPass(new (std::nothrow) ConvBiasAddFusionPass()); | |||||
| fusionOptimizer.AddPass(new (std::nothrow) ConvBNFusionPass()); | |||||
| fusionOptimizer.AddPass(new (std::nothrow) ConvScaleFusionPass()); | |||||
| fusionOptimizer.AddPass(new (std::nothrow) ConvReluFusionPass()); | |||||
| fusionOptimizer.AddPass(new (std::nothrow) ConvRelu6FusionPass()); | |||||
| fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||||
| status = fusionOptimizer.Run(graphDefT); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed"; | |||||
| return status; | |||||
| } | |||||
| } | |||||
| // { | |||||
| // Optimizer fusionOptimizer; | |||||
| // fusionOptimizer.AddPass(new (std::nothrow) ConvBiasAddFusionPass()); | |||||
| // fusionOptimizer.AddPass(new (std::nothrow) ConvBNFusionPass()); | |||||
| // fusionOptimizer.AddPass(new (std::nothrow) ConvScaleFusionPass()); | |||||
| // fusionOptimizer.AddPass(new (std::nothrow) ConvReluFusionPass()); | |||||
| // fusionOptimizer.AddPass(new (std::nothrow) ConvRelu6FusionPass()); | |||||
| // fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||||
| // status = fusionOptimizer.Run(graphDefT); | |||||
| // if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| // MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed"; | |||||
| // return status; | |||||
| // } | |||||
| // } | |||||
| // weight format trans | // weight format trans | ||||
| if (ctx.formatTrans) { | if (ctx.formatTrans) { | ||||