diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc index 940661b300..39ebc39612 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc @@ -71,7 +71,8 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt } else if (is_insert_input) { // if need padding & is input need insert a transdata // reshape[padding shape] -> transdata[padding shape] -> node - auto padding_shape = trans::PaddingShapeTo4d(input_node_out_shape, AnfAlgo::GetInputReshapeType(node, 0)); + auto padding_shape = + trans::PaddingShapeTo4d(input_node_out_shape, AnfAlgo::GetInputReshapeType(node, insert_index)); auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape); trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name()); trans_node = trans_data; diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc index 6983162cda..aebd12b727 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.cc +++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc @@ -553,6 +553,30 @@ bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) { return output_node_list->size() > 1; } +bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto output_node_list = GetRealNodeUsedList(graph, node); + MS_EXCEPTION_IF_NULL(output_node_list); + if (output_node_list->empty()) { + return true; + } + for (const auto &output : *output_node_list) { + auto out_node = output.first; + auto name = AnfAlgo::GetCNodeName(out_node); + if (name == prim::kPrimDepend->name() || name == prim::kPrimMakeTuple->name() || + name == prim::kPrimTupleGetItem->name()) { + auto result = IsNotRealUsedByOthers(graph, out_node); + if (!result) { + return result; + } + continue; + } + return false; + } + return true; +} + AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) { auto idx = NewValueNode(SizeToInt(output_idx)); MS_EXCEPTION_IF_NULL(idx); diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.h b/mindspore/ccsrc/backend/optimizer/common/helper.h index 144f6e3e77..87731b5445 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.h +++ b/mindspore/ccsrc/backend/optimizer/common/helper.h @@ -175,6 +175,7 @@ std::shared_ptr>> GetRealNodeUsedList(con std::shared_ptr>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph, const AnfNodePtr &node, size_t output_index); +bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node); void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set &input_attrs); diff --git a/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc b/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc index 48c0c78047..730cd5b83e 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc @@ -28,7 +28,7 @@ namespace mindspore { namespace opt { constexpr auto kSingleInputIndex = 1; namespace { -AnfNodePtr GetReplaceNode(const AnfNodePtr &node) { +AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { return nullptr; @@ -40,6 +40,9 @@ AnfNodePtr GetReplaceNode(const AnfNodePtr &node) { if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) { return nullptr; } + if (!IsNotRealUsedByOthers(func_graph, cnode)) { + return nullptr; + } CheckCNodeInputSize(cnode, kSingleInputIndex + 1); return cnode->input(kSingleInputIndex); } @@ -50,10 +53,11 @@ AnfNodePtr ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnod if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) { return nullptr; } - std::vector new_make_tuple_inputs; + std::vector new_make_tuple_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; bool need_update = false; - for (const auto &input : cnode->inputs()) { - AnfNodePtr replace_input = GetReplaceNode(input); + for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) { + auto input = AnfAlgo::GetInputNode(cnode, index); + AnfNodePtr replace_input = GetReplaceNode(func_graph, input); // If replace input is not null, it will be the input of the TransData or Cast. if (replace_input == nullptr) { new_make_tuple_inputs.push_back(input); @@ -149,7 +153,7 @@ const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, c if (make_tuple_replace_node != nullptr) { return make_tuple_replace_node; } - AnfNodePtr replace_node = GetReplaceNode(replacing_cnode); + AnfNodePtr replace_node = GetReplaceNode(graph, replacing_cnode); if (replace_node == nullptr) { MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString(); return replacing_node; diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc index f8f5e90d62..1ec0db7f37 100644 --- a/mindspore/ccsrc/common/trans.cc +++ b/mindspore/ccsrc/common/trans.cc @@ -392,8 +392,8 @@ std::vector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) { } else { host_shape = AnfAlgo::GetOutputInferShape(node, index); } - if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, 0), host_shape.size())) { - host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, 0)); + if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, index), host_shape.size())) { + host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, index)); } std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToInt); return shape;