| @@ -71,7 +71,8 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt | |||||
| } else if (is_insert_input) { | } else if (is_insert_input) { | ||||
| // if need padding & is input need insert a transdata | // if need padding & is input need insert a transdata | ||||
| // reshape[padding shape] -> transdata[padding shape] -> node | // reshape[padding shape] -> transdata[padding shape] -> node | ||||
| auto padding_shape = trans::PaddingShapeTo4d(input_node_out_shape, AnfAlgo::GetInputReshapeType(node, 0)); | |||||
| auto padding_shape = | |||||
| trans::PaddingShapeTo4d(input_node_out_shape, AnfAlgo::GetInputReshapeType(node, insert_index)); | |||||
| auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape); | auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape); | ||||
| trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name()); | trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name()); | ||||
| trans_node = trans_data; | trans_node = trans_data; | ||||
| @@ -553,6 +553,30 @@ bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) { | |||||
| return output_node_list->size() > 1; | return output_node_list->size() > 1; | ||||
| } | } | ||||
| bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto output_node_list = GetRealNodeUsedList(graph, node); | |||||
| MS_EXCEPTION_IF_NULL(output_node_list); | |||||
| if (output_node_list->empty()) { | |||||
| return true; | |||||
| } | |||||
| for (const auto &output : *output_node_list) { | |||||
| auto out_node = output.first; | |||||
| auto name = AnfAlgo::GetCNodeName(out_node); | |||||
| if (name == prim::kPrimDepend->name() || name == prim::kPrimMakeTuple->name() || | |||||
| name == prim::kPrimTupleGetItem->name()) { | |||||
| auto result = IsNotRealUsedByOthers(graph, out_node); | |||||
| if (!result) { | |||||
| return result; | |||||
| } | |||||
| continue; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) { | AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) { | ||||
| auto idx = NewValueNode(SizeToInt(output_idx)); | auto idx = NewValueNode(SizeToInt(output_idx)); | ||||
| MS_EXCEPTION_IF_NULL(idx); | MS_EXCEPTION_IF_NULL(idx); | ||||
| @@ -175,6 +175,7 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(con | |||||
| std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph, | std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph, | ||||
| const AnfNodePtr &node, | const AnfNodePtr &node, | ||||
| size_t output_index); | size_t output_index); | ||||
| bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node); | |||||
| void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs); | void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs); | ||||
| @@ -28,7 +28,7 @@ namespace mindspore { | |||||
| namespace opt { | namespace opt { | ||||
| constexpr auto kSingleInputIndex = 1; | constexpr auto kSingleInputIndex = 1; | ||||
| namespace { | namespace { | ||||
| AnfNodePtr GetReplaceNode(const AnfNodePtr &node) { | |||||
| AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (!node->isa<CNode>()) { | if (!node->isa<CNode>()) { | ||||
| return nullptr; | return nullptr; | ||||
| @@ -40,6 +40,9 @@ AnfNodePtr GetReplaceNode(const AnfNodePtr &node) { | |||||
| if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) { | if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (!IsNotRealUsedByOthers(func_graph, cnode)) { | |||||
| return nullptr; | |||||
| } | |||||
| CheckCNodeInputSize(cnode, kSingleInputIndex + 1); | CheckCNodeInputSize(cnode, kSingleInputIndex + 1); | ||||
| return cnode->input(kSingleInputIndex); | return cnode->input(kSingleInputIndex); | ||||
| } | } | ||||
| @@ -50,10 +53,11 @@ AnfNodePtr ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnod | |||||
| if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) { | if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| std::vector<AnfNodePtr> new_make_tuple_inputs; | |||||
| std::vector<AnfNodePtr> new_make_tuple_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; | |||||
| bool need_update = false; | bool need_update = false; | ||||
| for (const auto &input : cnode->inputs()) { | |||||
| AnfNodePtr replace_input = GetReplaceNode(input); | |||||
| for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) { | |||||
| auto input = AnfAlgo::GetInputNode(cnode, index); | |||||
| AnfNodePtr replace_input = GetReplaceNode(func_graph, input); | |||||
| // If replace input is not null, it will be the input of the TransData or Cast. | // If replace input is not null, it will be the input of the TransData or Cast. | ||||
| if (replace_input == nullptr) { | if (replace_input == nullptr) { | ||||
| new_make_tuple_inputs.push_back(input); | new_make_tuple_inputs.push_back(input); | ||||
| @@ -149,7 +153,7 @@ const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, c | |||||
| if (make_tuple_replace_node != nullptr) { | if (make_tuple_replace_node != nullptr) { | ||||
| return make_tuple_replace_node; | return make_tuple_replace_node; | ||||
| } | } | ||||
| AnfNodePtr replace_node = GetReplaceNode(replacing_cnode); | |||||
| AnfNodePtr replace_node = GetReplaceNode(graph, replacing_cnode); | |||||
| if (replace_node == nullptr) { | if (replace_node == nullptr) { | ||||
| MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString(); | MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString(); | ||||
| return replacing_node; | return replacing_node; | ||||
| @@ -392,8 +392,8 @@ std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) { | |||||
| } else { | } else { | ||||
| host_shape = AnfAlgo::GetOutputInferShape(node, index); | host_shape = AnfAlgo::GetOutputInferShape(node, index); | ||||
| } | } | ||||
| if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, 0), host_shape.size())) { | |||||
| host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, 0)); | |||||
| if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, index), host_shape.size())) { | |||||
| host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, index)); | |||||
| } | } | ||||
| std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToInt); | std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToInt); | ||||
| return shape; | return shape; | ||||