|
|
|
@@ -15,12 +15,28 @@ |
|
|
|
*/ |
|
|
|
|
|
|
|
#include "tools/optimizer/graph/redundant_op_remove_pass.h" |
|
|
|
#include "mindspore/lite/include/errorcode.h" |
|
|
|
#include <memory> |
|
|
|
#include <vector> |
|
|
|
#include "include/errorcode.h" |
|
|
|
#include "ops/make_tuple.h" |
|
|
|
|
|
|
|
namespace mindspore::opt { |
|
|
|
namespace { |
|
|
|
constexpr size_t InputDoubleNum = 2; |
|
|
|
constexpr size_t InputTripleNum = 3; |
|
|
|
void FetchCNodeFromMakeTuple(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *inputs) { |
|
|
|
MS_ASSERT(anf_node != nullptr); |
|
|
|
MS_ASSERT(inputs != nullptr); |
|
|
|
auto cnode = anf_node->cast<CNodePtr>(); |
|
|
|
if (cnode == nullptr) { |
|
|
|
return; |
|
|
|
} |
|
|
|
for (size_t i = 1; i < cnode->size(); ++i) { |
|
|
|
if (cnode->input(i)->isa<CNode>()) { |
|
|
|
inputs->push_back(cnode->input(i)); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { |
|
|
|
if (!utils::isa<CNodePtr>(anf_node)) { |
|
|
|
@@ -58,6 +74,36 @@ int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraph |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
int RemoveRedundantOpPass::ReplaceUpdateStateOp(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node) { |
|
|
|
if (!utils::isa<CNodePtr>(anf_node)) { |
|
|
|
MS_LOG(DEBUG) << "anf node is node a cnode."; |
|
|
|
return lite::RET_NO_CHANGE; |
|
|
|
} |
|
|
|
auto cnode = anf_node->cast<CNodePtr>(); |
|
|
|
auto inputs = cnode->inputs(); |
|
|
|
std::vector<AnfNodePtr> new_inputs; |
|
|
|
for (size_t i = 1; i < inputs.size(); ++i) { |
|
|
|
if (!inputs[i]->isa<CNode>()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (CheckPrimitiveType(inputs[i], prim::kPrimMakeTuple)) { |
|
|
|
FetchCNodeFromMakeTuple(inputs[i], &new_inputs); |
|
|
|
continue; |
|
|
|
} |
|
|
|
new_inputs.push_back(inputs[i]); |
|
|
|
} |
|
|
|
for (auto &node : new_inputs) { |
|
|
|
func_graph->get_return()->add_input(node); |
|
|
|
} |
|
|
|
auto value = std::make_shared<UMonad>(); |
|
|
|
bool replace_succ = func_graph->manager()->Replace(anf_node, NewValueNode(value)); |
|
|
|
if (!replace_succ) { |
|
|
|
MS_LOG(ERROR) << "replace redundant op failed."; |
|
|
|
return lite::RET_ERROR; |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { |
|
|
|
if (!utils::isa<CNodePtr>(anf_node)) { |
|
|
|
MS_LOG(DEBUG) << "anf node is node a cnode."; |
|
|
|
@@ -111,7 +157,7 @@ bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) { |
|
|
|
status = ReplaceOp(node, manager); |
|
|
|
} |
|
|
|
if (CheckPrimitiveType(node, prim::kPrimUpdateState)) { |
|
|
|
status = ReplaceOp(node, manager); |
|
|
|
status = ReplaceUpdateStateOp(func_graph, node); |
|
|
|
} |
|
|
|
if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { |
|
|
|
status = ReplaceTupleGetItem(node, manager); |
|
|
|
|