|
|
|
@@ -28,7 +28,7 @@ namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
constexpr auto kSingleInputIndex = 1; |
|
|
|
namespace { |
|
|
|
AnfNodePtr GetReplaceNode(const AnfNodePtr &node) { |
|
|
|
AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (!node->isa<CNode>()) { |
|
|
|
return nullptr; |
|
|
|
@@ -40,6 +40,9 @@ AnfNodePtr GetReplaceNode(const AnfNodePtr &node) { |
|
|
|
if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (!IsNotRealUsedByOthers(func_graph, cnode)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
CheckCNodeInputSize(cnode, kSingleInputIndex + 1); |
|
|
|
return cnode->input(kSingleInputIndex); |
|
|
|
} |
|
|
|
@@ -50,10 +53,11 @@ AnfNodePtr ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnod |
|
|
|
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
std::vector<AnfNodePtr> new_make_tuple_inputs; |
|
|
|
std::vector<AnfNodePtr> new_make_tuple_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; |
|
|
|
bool need_update = false; |
|
|
|
for (const auto &input : cnode->inputs()) { |
|
|
|
AnfNodePtr replace_input = GetReplaceNode(input); |
|
|
|
for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) { |
|
|
|
auto input = AnfAlgo::GetInputNode(cnode, index); |
|
|
|
AnfNodePtr replace_input = GetReplaceNode(func_graph, input); |
|
|
|
// If replace input is not null, it will be the input of the TransData or Cast. |
|
|
|
if (replace_input == nullptr) { |
|
|
|
new_make_tuple_inputs.push_back(input); |
|
|
|
@@ -149,7 +153,7 @@ const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, c |
|
|
|
if (make_tuple_replace_node != nullptr) { |
|
|
|
return make_tuple_replace_node; |
|
|
|
} |
|
|
|
AnfNodePtr replace_node = GetReplaceNode(replacing_cnode); |
|
|
|
AnfNodePtr replace_node = GetReplaceNode(graph, replacing_cnode); |
|
|
|
if (replace_node == nullptr) { |
|
|
|
MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString(); |
|
|
|
return replacing_node; |
|
|
|
|