Merge pull request !3290 from chenfei_mindspore/split-tuple-parameter-to-parameterstags/v0.7.0-beta
| @@ -46,8 +46,8 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern | |||||
| auto common_pm = std::make_shared<PassManager>("common_pm"); | auto common_pm = std::make_shared<PassManager>("common_pm"); | ||||
| common_pm->AddPass(std::make_shared<ConvertConstInputToAttr>()); | common_pm->AddPass(std::make_shared<ConvertConstInputToAttr>()); | ||||
| common_pm->AddPass(std::make_shared<ConstToAttrStridedSliceGradPass>()); | common_pm->AddPass(std::make_shared<ConstToAttrStridedSliceGradPass>()); | ||||
| common_pm->AddPass(std::make_shared<ConvertConstInputToTensorInput>()); | |||||
| common_pm->AddPass(std::make_shared<ConvertTupleOutputToMaketuple>()); | common_pm->AddPass(std::make_shared<ConvertTupleOutputToMaketuple>()); | ||||
| common_pm->AddPass(std::make_shared<ConvertConstInputToTensorInput>()); | |||||
| common_pm->AddPass(std::make_shared<ConvertTupleInputToDynamicInput>()); | common_pm->AddPass(std::make_shared<ConvertTupleInputToDynamicInput>()); | ||||
| optimizer->AddPassManager(common_pm); | optimizer->AddPassManager(common_pm); | ||||
| (void)optimizer->Optimize(kernel_graph); | (void)optimizer->Optimize(kernel_graph); | ||||
| @@ -139,7 +139,10 @@ AnfNodePtr ProcessGraphKernelOp(const AnfNodePtr &node) { | |||||
| const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | ||||
| const EquivPtr &) const { | const EquivPtr &) const { | ||||
| if (node == nullptr || func_graph == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) { | |||||
| if (node == nullptr || func_graph == nullptr || AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { | |||||
| return nullptr; | |||||
| } | |||||
| if (!node->isa<CNode>()) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (AnfAlgo::IsGraphKernel(node)) { | if (AnfAlgo::IsGraphKernel(node)) { | ||||
| @@ -17,6 +17,7 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <memory> | #include <memory> | ||||
| #include <unordered_map> | |||||
| #include "backend/session/anf_runtime_algorithm.h" | #include "backend/session/anf_runtime_algorithm.h" | ||||
| #include "backend/optimizer/common/helper.h" | #include "backend/optimizer/common/helper.h" | ||||
| @@ -25,68 +26,26 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace { | namespace { | ||||
| CNodePtr ConvertTupleOuputToPlantInputs(const FuncGraphPtr &graph, const AnfNodePtr &input_node) { | |||||
| AnfNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const AnfNodePtr &tuple_anf, | |||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *transed_nodes) { | |||||
| MS_EXCEPTION_IF_NULL(tuple_anf); | |||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| if (!AnfAlgo::IsTupleOutput(input_node)) { | |||||
| MS_LOG(EXCEPTION) << "Cannot using the function to convert a not tuple output node to maketuple!"; | |||||
| MS_EXCEPTION_IF_NULL(transed_nodes); | |||||
| if (!AnfAlgo::IsTupleOutput(tuple_anf)) { | |||||
| return tuple_anf; | |||||
| } | } | ||||
| if (input_node->isa<CNode>()) { | |||||
| MS_LOG(EXCEPTION) << "The function can only split a parameter or valuenode bug got " << input_node->DebugString(); | |||||
| auto transed_node_it = transed_nodes->find(tuple_anf); | |||||
| if (transed_node_it != transed_nodes->end()) { | |||||
| return transed_node_it->second; | |||||
| } | } | ||||
| std::vector<AnfNodePtr> convert_inputs = {NewValueNode(prim::kPrimMakeTuple)}; | |||||
| auto kernel_graph = graph->cast<KernelGraphPtr>(); | auto kernel_graph = graph->cast<KernelGraphPtr>(); | ||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| auto splited_node_list = kernel_graph->SplitTupleOutputNodeToNodeList(input_node); | |||||
| for (const auto &node : splited_node_list) { | |||||
| if (AnfAlgo::IsTupleOutput(node)) { | |||||
| convert_inputs.emplace_back(ConvertTupleOuputToPlantInputs(graph, node)); | |||||
| continue; | |||||
| } | |||||
| convert_inputs.emplace_back(node); | |||||
| } | |||||
| auto make_tuple = graph->NewCNode(convert_inputs); | |||||
| std::vector<abstract::AbstractBasePtr> abstract_list; | |||||
| auto make_tuple_input_size = AnfAlgo::GetInputTensorNum(make_tuple); | |||||
| for (size_t index = 0; index < make_tuple_input_size; ++index) { | |||||
| auto make_tuple_input = AnfAlgo::GetInputNode(make_tuple, index); | |||||
| MS_EXCEPTION_IF_NULL(make_tuple_input); | |||||
| abstract_list.emplace_back(make_tuple_input->abstract()); | |||||
| } | |||||
| make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | |||||
| auto make_tuple = kernel_graph->TransTupleToMakeTuple(tuple_anf); | |||||
| (*transed_nodes)[tuple_anf] = make_tuple; | |||||
| // replace graph inputs if input is a parameter | |||||
| kernel_graph->ReplaceGraphInput(tuple_anf, make_tuple); | |||||
| return make_tuple; | return make_tuple; | ||||
| } | } | ||||
| CNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) { | |||||
| MS_EXCEPTION_IF_NULL(cnode_ptr); | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| std::vector<AnfNodePtr> convert_inputs = {cnode_ptr->input(0)}; | |||||
| for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode_ptr); ++index) { | |||||
| auto input_node = AnfAlgo::GetInputNode(cnode_ptr, index); | |||||
| if (AnfAlgo::IsTupleOutput(input_node)) { | |||||
| std::vector<TypeId> types; | |||||
| std::vector<std::vector<size_t>> shapes; | |||||
| std::vector<AnfNodePtr> make_tuple_inputs_list = {NewValueNode(prim::kPrimMakeTuple)}; | |||||
| if (input_node->isa<CNode>()) { | |||||
| for (size_t tuple_out_index = 0; tuple_out_index < AnfAlgo::GetOutputTensorNum(input_node); ++tuple_out_index) { | |||||
| make_tuple_inputs_list.emplace_back(CreatTupleGetItemNode(graph, input_node, tuple_out_index)); | |||||
| types.push_back(AnfAlgo::GetOutputInferDataType(input_node, tuple_out_index)); | |||||
| shapes.emplace_back(AnfAlgo::GetOutputInferShape(input_node, tuple_out_index)); | |||||
| } | |||||
| auto make_tuple = graph->NewCNode(make_tuple_inputs_list); | |||||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, make_tuple.get()); | |||||
| convert_inputs.emplace_back(make_tuple); | |||||
| continue; | |||||
| } | |||||
| convert_inputs.emplace_back(ConvertTupleOuputToPlantInputs(graph, input_node)); | |||||
| } else { | |||||
| convert_inputs.push_back(input_node); | |||||
| } | |||||
| } | |||||
| auto new_node = graph->NewCNode(convert_inputs); | |||||
| new_node->set_abstract(cnode_ptr->abstract()); | |||||
| return new_node; | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| const BaseRef ConvertTupleOutputToMaketuple::DefinePattern() const { | const BaseRef ConvertTupleOutputToMaketuple::DefinePattern() const { | ||||
| @@ -102,15 +61,22 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func | |||||
| } | } | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> transed_nodes; | |||||
| if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { | if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(), [](const AnfNodePtr &node) { | |||||
| return node->Type() != nullptr && AnfAlgo::IsRealKernel(node) && AnfAlgo::IsTupleOutput(node); | |||||
| })) { | |||||
| return ConvertTupleInputToMakeTuple(func_graph, cnode); | |||||
| bool cnode_input_changed = false; | |||||
| for (size_t i = 0; i < cnode->inputs().size(); ++i) { | |||||
| const auto &input = cnode->inputs()[i]; | |||||
| if (input->Type() != nullptr && AnfAlgo::IsRealKernel(input) && AnfAlgo::IsTupleOutput(input) && | |||||
| !AnfAlgo::CheckPrimitiveType(input, prim::kPrimCall)) { | |||||
| cnode->set_input(i, ConvertTupleInputToMakeTuple(func_graph, input, &transed_nodes)); | |||||
| cnode_input_changed = true; | |||||
| } | |||||
| } | } | ||||
| return nullptr; | |||||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| return cnode_input_changed ? kernel_graph->NewCNode(cnode) : nullptr; | |||||
| } | } | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -1817,7 +1817,7 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu | |||||
| // create a parameter to store the output of multiple branch and set the parameter as the condition graph's output | // create a parameter to store the output of multiple branch and set the parameter as the condition graph's output | ||||
| // auto multi_output_param = graph->NewParameter(); | // auto multi_output_param = graph->NewParameter(); | ||||
| auto origin_inputs = graph->inputs(); | auto origin_inputs = graph->inputs(); | ||||
| auto output_param = CreateNewParameterFromCNode(node, true, graph.get().get()); | |||||
| auto output_param = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract())); | |||||
| MS_EXCEPTION_IF_NULL(graph->MutableInputs()); | MS_EXCEPTION_IF_NULL(graph->MutableInputs()); | ||||
| graph->MutableInputs()->operator=(origin_inputs); | graph->MutableInputs()->operator=(origin_inputs); | ||||
| graph->AddChildGraphResult(output_param); | graph->AddChildGraphResult(output_param); | ||||
| @@ -1835,9 +1835,8 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu | |||||
| if (child_graph->get_output_null()) { | if (child_graph->get_output_null()) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| auto graph_output = child_graph->output(); | |||||
| AscendControlParser::InsertMultipleAssignToGraph(NOT_NULL(child_graph), nullptr, NOT_NULL(graph_output), | |||||
| NOT_NULL(output_param)); | |||||
| AscendControlParser::InsertMultipleAssignToGraph(NOT_NULL(child_graph), nullptr, | |||||
| NOT_NULL(child_graph->output()), NOT_NULL(output_param)); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -441,83 +441,115 @@ ParameterPtr KernelGraph::NewParameter(const abstract::AbstractBasePtr &abstract | |||||
| return new_parameter; | return new_parameter; | ||||
| } | } | ||||
| std::vector<AnfNodePtr> KernelGraph::SplitTupleParameterToNodeList(const ParameterPtr ¶meter) { | |||||
| MS_EXCEPTION_IF_NULL(parameter); | |||||
| std::vector<AnfNodePtr> convert_nodes_list; | |||||
| auto abstract = parameter->abstract(); | |||||
| ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) { | |||||
| MS_EXCEPTION_IF_NULL(value_node); | |||||
| auto new_value_node = MakeValueNode(value_node)->cast<ValueNodePtr>(); | |||||
| AnfAlgo::SetGraphId(graph_id_, new_value_node.get()); | |||||
| return new_value_node; | |||||
| } | |||||
| ValueNodePtr KernelGraph::NewValueNode(const AbstractBasePtr &abstract, const ValuePtr &value) { | |||||
| MS_EXCEPTION_IF_NULL(abstract); | MS_EXCEPTION_IF_NULL(abstract); | ||||
| MS_EXCEPTION_IF_NULL(value); | |||||
| ValueNodePtr new_value_node = std::make_shared<ValueNode>(value); | |||||
| new_value_node->set_abstract(abstract); | |||||
| SetKernelInfoForNode(new_value_node); | |||||
| AnfAlgo::SetGraphId(graph_id(), new_value_node.get()); | |||||
| return new_value_node; | |||||
| } | |||||
| AnfNodePtr KernelGraph::TransValueNodeTuple(const AbstractBasePtr abstract, const ValuePtr &value) { | |||||
| MS_EXCEPTION_IF_NULL(abstract); | |||||
| MS_EXCEPTION_IF_NULL(value); | |||||
| if (!abstract->isa<abstract::AbstractTuple>()) { | if (!abstract->isa<abstract::AbstractTuple>()) { | ||||
| MS_LOG(EXCEPTION) << "Multiple output Parameter's output must be a tuple abstract but got " << abstract->ToString(); | |||||
| auto new_value_node = NewValueNode(abstract, value); | |||||
| AddValueNodeToGraph(new_value_node); | |||||
| return new_value_node; | |||||
| } | } | ||||
| auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>(); | auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>(); | ||||
| auto value_tuple = value->cast<ValueTuplePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(tuple_abstract); | MS_EXCEPTION_IF_NULL(tuple_abstract); | ||||
| for (size_t index = 0; index < tuple_abstract->size(); ++index) { | |||||
| auto new_parameter = this->NewParameter((*tuple_abstract)[index]); | |||||
| SetKernelInfoForNode(new_parameter); | |||||
| convert_nodes_list.emplace_back(new_parameter); | |||||
| } | |||||
| auto new_inputs = std::make_shared<std::vector<AnfNodePtr>>(); | |||||
| auto old_inputs = inputs(); | |||||
| for (const auto &input_node : old_inputs) { | |||||
| if (input_node != parameter) { | |||||
| new_inputs->emplace_back(input_node); | |||||
| continue; | |||||
| } | |||||
| std::copy(convert_nodes_list.begin(), convert_nodes_list.end(), std::back_inserter(*new_inputs)); | |||||
| } | |||||
| inputs_ = new_inputs; | |||||
| return convert_nodes_list; | |||||
| } | |||||
| std::vector<AnfNodePtr> KernelGraph::SplitTupleOutputNodeToNodeList(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (node->isa<CNode>()) { | |||||
| MS_LOG(EXCEPTION) << "The function can only split a parameter or valuenode bug got " << node->DebugString(); | |||||
| MS_EXCEPTION_IF_NULL(value_tuple); | |||||
| if (tuple_abstract->size() != value_tuple->size()) { | |||||
| MS_LOG(EXCEPTION) << "Abstract size:" << tuple_abstract->size() | |||||
| << " is not equal to value size:" << value_tuple->size(); | |||||
| } | } | ||||
| if (node->isa<Parameter>()) { | |||||
| return SplitTupleParameterToNodeList(node->cast<ParameterPtr>()); | |||||
| std::vector<AnfNodePtr> make_tuple_inputs = { | |||||
| mindspore::NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))}; | |||||
| for (size_t index = 0; index < tuple_abstract->size(); ++index) { | |||||
| make_tuple_inputs.push_back(TransValueNodeTuple((*tuple_abstract)[index], (*value_tuple)[index])); | |||||
| } | } | ||||
| return SplitTupleValueNodeToNodeList(node->cast<ValueNodePtr>()); | |||||
| auto make_tuple = NewCNode(make_tuple_inputs); | |||||
| make_tuple->set_abstract(tuple_abstract); | |||||
| return make_tuple; | |||||
| } | } | ||||
| std::vector<AnfNodePtr> KernelGraph::SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node) { | |||||
| MS_EXCEPTION_IF_NULL(value_node); | |||||
| auto node_value = value_node->value(); | |||||
| std::vector<AnfNodePtr> convert_inputs; | |||||
| if (!node_value->isa<ValueTuple>()) { | |||||
| MS_LOG(EXCEPTION) << "Multiple output valuenode's value must be a value tuple but got " << node_value->ToString(); | |||||
| } | |||||
| auto value_tuple = node_value->cast<ValueTuplePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(value_tuple); | |||||
| auto abstract = value_node->abstract(); | |||||
| AnfNodePtr KernelGraph::TransParameterTuple(const AbstractBasePtr &abstract) { | |||||
| MS_EXCEPTION_IF_NULL(abstract); | |||||
| if (!abstract->isa<abstract::AbstractTuple>()) { | if (!abstract->isa<abstract::AbstractTuple>()) { | ||||
| MS_LOG(EXCEPTION) << "Spilted node's output abstract is not type tuple"; | |||||
| return NewParameter(abstract); | |||||
| } | } | ||||
| auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>(); | auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(tuple_abstract); | MS_EXCEPTION_IF_NULL(tuple_abstract); | ||||
| if (tuple_abstract->size() != value_tuple->size()) { | |||||
| MS_LOG(EXCEPTION) << "The node output index [" << value_tuple->size() << "]is outof range " | |||||
| << tuple_abstract->size(); | |||||
| } | |||||
| for (size_t index = 0; index < value_tuple->value().size(); ++index) { | |||||
| auto new_value_node = std::make_shared<ValueNode>(value_tuple->value()[index]); | |||||
| new_value_node->set_abstract((*tuple_abstract)[index]); | |||||
| AddValueNodeToGraph(new_value_node); | |||||
| SetKernelInfoForNode(new_value_node); | |||||
| AnfAlgo::SetGraphId(graph_id_, new_value_node.get()); | |||||
| convert_inputs.emplace_back(new_value_node); | |||||
| } | |||||
| if (!RemoveValueNodeFromGraph(value_node)) { | |||||
| MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString(); | |||||
| std::vector<AnfNodePtr> make_tuple_inputs = { | |||||
| mindspore::NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))}; | |||||
| for (size_t index = 0; index < tuple_abstract->size(); ++index) { | |||||
| make_tuple_inputs.push_back(TransParameterTuple((*tuple_abstract)[index])); | |||||
| } | } | ||||
| return convert_inputs; | |||||
| auto make_tuple = NewCNode(make_tuple_inputs); | |||||
| make_tuple->set_abstract(tuple_abstract); | |||||
| return make_tuple; | |||||
| } | } | ||||
| ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) { | |||||
| MS_EXCEPTION_IF_NULL(value_node); | |||||
| auto new_value_node = MakeValueNode(value_node)->cast<ValueNodePtr>(); | |||||
| AnfAlgo::SetGraphId(graph_id_, new_value_node.get()); | |||||
| return new_value_node; | |||||
| AnfNodePtr KernelGraph::CreatTupleGetItemNode(const AnfNodePtr &node, size_t output_idx) { | |||||
| auto idx = mindspore::NewValueNode(SizeToInt(output_idx)); | |||||
| MS_EXCEPTION_IF_NULL(idx); | |||||
| auto imm = std::make_shared<Int32Imm>(SizeToInt(output_idx)); | |||||
| auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm); | |||||
| idx->set_abstract(abstract_scalar); | |||||
| AnfNodePtr tuple_getitem = NewCNode({mindspore::NewValueNode(prim::kPrimTupleGetItem), node, idx}); | |||||
| MS_EXCEPTION_IF_NULL(tuple_getitem); | |||||
| tuple_getitem->set_scope(node->scope()); | |||||
| std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); | |||||
| TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx); | |||||
| AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get()); | |||||
| return tuple_getitem; | |||||
| } | |||||
| AnfNodePtr KernelGraph::TransCNodeTuple(const CNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| std::vector<TypeId> types; | |||||
| std::vector<std::vector<size_t>> shapes; | |||||
| std::vector<AnfNodePtr> make_tuple_inputs_list = {mindspore::NewValueNode(prim::kPrimMakeTuple)}; | |||||
| for (size_t tuple_out_index = 0; tuple_out_index < AnfAlgo::GetOutputTensorNum(node); ++tuple_out_index) { | |||||
| make_tuple_inputs_list.emplace_back(CreatTupleGetItemNode(node, tuple_out_index)); | |||||
| types.push_back(AnfAlgo::GetOutputInferDataType(node, tuple_out_index)); | |||||
| shapes.emplace_back(AnfAlgo::GetOutputInferShape(node, tuple_out_index)); | |||||
| } | |||||
| auto make_tuple = NewCNode(make_tuple_inputs_list); | |||||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, make_tuple.get()); | |||||
| return make_tuple; | |||||
| } | |||||
| AnfNodePtr KernelGraph::TransTupleToMakeTuple(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (!AnfAlgo::IsTupleOutput(node)) { | |||||
| return node; | |||||
| } | |||||
| if (node->isa<Parameter>()) { | |||||
| return TransParameterTuple(node->abstract()); | |||||
| } else if (node->isa<ValueNode>()) { | |||||
| auto value_node = node->cast<ValueNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(value_node); | |||||
| auto make_tuple = TransValueNodeTuple(value_node->abstract(), value_node->value()); | |||||
| if (RemoveValueNodeFromGraph(value_node)) { | |||||
| MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString(); | |||||
| } | |||||
| return make_tuple; | |||||
| } else if (node->isa<CNode>()) { | |||||
| return TransCNodeTuple(node->cast<CNodePtr>()); | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "Unexpected node:" << node->DebugString(); | |||||
| } | } | ||||
| const std::vector<AnfNodePtr> &KernelGraph::inputs() const { | const std::vector<AnfNodePtr> &KernelGraph::inputs() const { | ||||
| @@ -817,6 +849,23 @@ bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| void KernelGraph::ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNodePtr &new_parameter) { | |||||
| // update graph inputs | |||||
| MS_EXCEPTION_IF_NULL(old_parameter); | |||||
| MS_EXCEPTION_IF_NULL(new_parameter); | |||||
| if (old_parameter == new_parameter) { | |||||
| return; | |||||
| } | |||||
| for (size_t i = 0; i < inputs_->size(); i++) { | |||||
| if ((*inputs_)[i] == old_parameter) { | |||||
| MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_parameter->DebugString() | |||||
| << ",new graph input:" << new_parameter->DebugString(); | |||||
| (*inputs_)[i] = new_parameter; | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodePtr> new_anf_node) { | void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodePtr> new_anf_node) { | ||||
| MS_EXCEPTION_IF_NULL(inputs_); | MS_EXCEPTION_IF_NULL(inputs_); | ||||
| { | { | ||||
| @@ -840,15 +889,7 @@ void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodeP | |||||
| output_cnode->set_input(i, new_anf_node); | output_cnode->set_input(i, new_anf_node); | ||||
| } | } | ||||
| } | } | ||||
| // update graph inputs | |||||
| for (size_t i = 0; i < inputs_->size(); i++) { | |||||
| if ((*inputs_)[i] == old_anf_node.get()) { | |||||
| MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_anf_node->DebugString() | |||||
| << ",new graph input:" << new_anf_node->DebugString(); | |||||
| (*inputs_)[i] = new_anf_node.get(); | |||||
| break; | |||||
| } | |||||
| } | |||||
| ReplaceGraphInput(old_anf_node, new_anf_node); | |||||
| } | } | ||||
| // update front to backend map | // update front to backend map | ||||
| FrontBackendlMapUpdate(old_anf_node, new_anf_node); | FrontBackendlMapUpdate(old_anf_node, new_anf_node); | ||||
| @@ -49,15 +49,17 @@ class KernelGraph : public FuncGraph { | |||||
| const std::vector<AnfNodePtr> &inputs() const; | const std::vector<AnfNodePtr> &inputs() const; | ||||
| std::vector<AnfNodePtr> *MutableInputs() const { return inputs_.get(); } | std::vector<AnfNodePtr> *MutableInputs() const { return inputs_.get(); } | ||||
| void ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNodePtr &new_parameter); | |||||
| std::vector<AnfNodePtr> outputs() const; | std::vector<AnfNodePtr> outputs() const; | ||||
| CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs) override; | CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs) override; | ||||
| void CreateKernelInfoFromNewParameter(const CNodePtr &cnode); | void CreateKernelInfoFromNewParameter(const CNodePtr &cnode); | ||||
| CNodePtr NewCNode(const CNodePtr &cnode); | CNodePtr NewCNode(const CNodePtr &cnode); | ||||
| ParameterPtr NewParameter(const ParameterPtr ¶meter = nullptr); | ParameterPtr NewParameter(const ParameterPtr ¶meter = nullptr); | ||||
| ParameterPtr NewParameter(const abstract::AbstractBasePtr &abstract); | ParameterPtr NewParameter(const abstract::AbstractBasePtr &abstract); | ||||
| ValueNodePtr NewValueNode(const ValuePtr &value); | |||||
| ValueNodePtr NewValueNode(const AbstractBasePtr &abstract, const ValuePtr &value); | |||||
| ValueNodePtr NewValueNode(const ValueNodePtr &value_node = nullptr); | ValueNodePtr NewValueNode(const ValueNodePtr &value_node = nullptr); | ||||
| std::vector<AnfNodePtr> SplitTupleOutputNodeToNodeList(const AnfNodePtr &node); | |||||
| // trans tuple output to maketuple + no_tuple out | |||||
| AnfNodePtr TransTupleToMakeTuple(const AnfNodePtr &node); | |||||
| void set_execution_order(const std::vector<CNodePtr> &order) { execution_order_ = order; } | void set_execution_order(const std::vector<CNodePtr> &order) { execution_order_ = order; } | ||||
| const std::vector<CNodePtr> &execution_order() const { return execution_order_; } | const std::vector<CNodePtr> &execution_order() const { return execution_order_; } | ||||
| void SetExecOrderByDefault(); | void SetExecOrderByDefault(); | ||||
| @@ -167,8 +169,6 @@ class KernelGraph : public FuncGraph { | |||||
| // remove value node form graph | // remove value node form graph | ||||
| bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); | bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); | ||||
| void SetKernelInfoForNode(const AnfNodePtr &node) const; | void SetKernelInfoForNode(const AnfNodePtr &node) const; | ||||
| std::vector<AnfNodePtr> SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node); | |||||
| std::vector<AnfNodePtr> SplitTupleParameterToNodeList(const ParameterPtr ¶meter); | |||||
| AnfNodePtr MakeValueNode(const AnfNodePtr &node); | AnfNodePtr MakeValueNode(const AnfNodePtr &node); | ||||
| void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue, | void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue, | ||||
| std::unordered_set<AnfNodePtr> *visited_nodes); | std::unordered_set<AnfNodePtr> *visited_nodes); | ||||
| @@ -181,6 +181,10 @@ class KernelGraph : public FuncGraph { | |||||
| bool HandleControlDependNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que, | bool HandleControlDependNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que, | ||||
| std::unordered_set<AnfNodePtr> *visited_nodes); | std::unordered_set<AnfNodePtr> *visited_nodes); | ||||
| void UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends); | void UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends); | ||||
| AnfNodePtr TransValueNodeTuple(const AbstractBasePtr abstract, const ValuePtr &value); | |||||
| AnfNodePtr TransParameterTuple(const AbstractBasePtr &abstract); | |||||
| AnfNodePtr TransCNodeTuple(const CNodePtr &node); | |||||
| AnfNodePtr CreatTupleGetItemNode(const AnfNodePtr &node, size_t output_idx); | |||||
| std::shared_ptr<std::vector<AnfNodePtr>> inputs_; | std::shared_ptr<std::vector<AnfNodePtr>> inputs_; | ||||
| std::vector<AnfNodePtr> child_graph_result_; | std::vector<AnfNodePtr> child_graph_result_; | ||||
| @@ -99,13 +99,18 @@ TEST_F(TestHWConstInputToTensorInput, test_value_tuple_tensor_input) { | |||||
| EXPECT_NE(ret->input(1)->cast<CNodePtr>(), nullptr); | EXPECT_NE(ret->input(1)->cast<CNodePtr>(), nullptr); | ||||
| auto cnode = ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>(); | auto cnode = ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>(); | ||||
| EXPECT_EQ(AnfAlgo::GetCNodeName(cnode), prim::kPrimDropoutGenMask->name()); | EXPECT_EQ(AnfAlgo::GetCNodeName(cnode), prim::kPrimDropoutGenMask->name()); | ||||
| auto input1 = cnode->input(1); | |||||
| ASSERT_TRUE(input1 != nullptr); | |||||
| EXPECT_TRUE(IsValueNode<tensor::Tensor>(input1)); | |||||
| auto tensor = input1->cast<ValueNodePtr>()->value()->cast<tensor::TensorPtr>(); | |||||
| ASSERT_TRUE(tensor != nullptr); | |||||
| auto data = tensor->data_c(); | |||||
| EXPECT_EQ(std::vector<int>((int *)data, (int *)data + 4), std::vector<int>({2, 4, 2, 2})); | |||||
| std::vector<int> out; | |||||
| for (size_t i = 1; i <= 4; i++) { | |||||
| auto input = cnode->input(i); | |||||
| ASSERT_TRUE(input != nullptr); | |||||
| EXPECT_TRUE(IsValueNode<tensor::Tensor>(input)); | |||||
| auto tensor = input->cast<ValueNodePtr>()->value()->cast<tensor::TensorPtr>(); | |||||
| ASSERT_TRUE(tensor != nullptr); | |||||
| int *data = (int *)(tensor->data_c()); | |||||
| ASSERT_TRUE(data != nullptr); | |||||
| out.push_back(*data); | |||||
| } | |||||
| EXPECT_EQ(out, std::vector<int>({2, 4, 2, 2})); | |||||
| } | } | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||