|
|
|
@@ -21,19 +21,20 @@ |
|
|
|
#include <vector> |
|
|
|
#include "include/registry/pass_registry.h" |
|
|
|
#include "ops/custom.h" |
|
|
|
#include "ops/fusion/add_fusion.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
namespace { |
|
|
|
// check a certain node is designated node's type. |
|
|
|
bool CheckPrimitiveTypeTutorial(const AnfNodePtr &node, const PrimitivePtr &primitive_type) { |
|
|
|
bool CheckPrimitiveTypeTutorial(const api::AnfNodePtr &node, const api::PrimitivePtr &primitive_type) { |
|
|
|
if (node == nullptr) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (node->isa<CNode>()) { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
if (node->isa<api::CNode>()) { |
|
|
|
auto cnode = node->cast<api::CNodePtr>(); |
|
|
|
return IsPrimitive(cnode->input(0), primitive_type); |
|
|
|
} else if (node->isa<ValueNode>()) { |
|
|
|
} else if (node->isa<api::ValueNode>()) { |
|
|
|
return IsPrimitive(node, primitive_type); |
|
|
|
} |
|
|
|
return false; |
|
|
|
@@ -41,11 +42,11 @@ bool CheckPrimitiveTypeTutorial(const AnfNodePtr &node, const PrimitivePtr &prim |
|
|
|
} // namespace |
|
|
|
|
|
|
|
// convert addn to custom op |
|
|
|
AnfNodePtr PassTutorial::CreateCustomOp(const api::FuncGraphPtr func_graph, const CNodePtr &cnode) { |
|
|
|
api::AnfNodePtr PassTutorial::CreateCustomOp(const api::FuncGraphPtr func_graph, const api::CNodePtr &cnode) { |
|
|
|
if (cnode == nullptr) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto primc = std::make_shared<ops::Custom>(); |
|
|
|
auto primc = api::MakeShared<ops::Custom>(); |
|
|
|
if (primc == nullptr) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
@@ -78,13 +79,13 @@ bool PassTutorial::Execute(const api::FuncGraphPtr &func_graph) { |
|
|
|
} |
|
|
|
auto node_list = api::FuncGraph::TopoSort(func_graph->get_return()); |
|
|
|
for (auto &node : node_list) { |
|
|
|
if (!utils::isa<CNode>(node)) { |
|
|
|
if (!api::utils::isa<api::CNode>(node)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (!CheckPrimitiveTypeTutorial(node, prim::kPrimAddFusion)) { |
|
|
|
if (!CheckPrimitiveTypeTutorial(node, mindspore::api::MakeShared<mindspore::ops::AddFusion>())) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
auto cnode = node->cast<api::CNodePtr>(); |
|
|
|
auto custome_cnode = CreateCustomOp(func_graph, cnode); |
|
|
|
if (custome_cnode == nullptr) { |
|
|
|
return false; |
|
|
|
|