diff --git a/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc index c94940b7dd..6b01c24faa 100644 --- a/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc @@ -46,8 +46,8 @@ void BackendCommonOptimization(const std::shared_ptr &kern auto common_pm = std::make_shared("common_pm"); common_pm->AddPass(std::make_shared()); common_pm->AddPass(std::make_shared()); - common_pm->AddPass(std::make_shared()); common_pm->AddPass(std::make_shared()); + common_pm->AddPass(std::make_shared()); common_pm->AddPass(std::make_shared()); optimizer->AddPassManager(common_pm); (void)optimizer->Optimize(kernel_graph); diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc index f204841f3c..cb75d3689e 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc @@ -139,7 +139,10 @@ AnfNodePtr ProcessGraphKernelOp(const AnfNodePtr &node) { const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { - if (node == nullptr || func_graph == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) { + if (node == nullptr || func_graph == nullptr || AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { + return nullptr; + } + if (!node->isa()) { return nullptr; } if (AnfAlgo::IsGraphKernel(node)) { diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc index 68543328b1..efe8d5320f 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc @@ -17,6 +17,7 @@ #include #include +#include #include "backend/session/anf_runtime_algorithm.h" #include "backend/optimizer/common/helper.h" @@ -25,68 +26,26 @@ namespace mindspore { namespace opt { namespace { -CNodePtr ConvertTupleOuputToPlantInputs(const FuncGraphPtr &graph, const AnfNodePtr &input_node) { +AnfNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const AnfNodePtr &tuple_anf, + std::unordered_map *transed_nodes) { + MS_EXCEPTION_IF_NULL(tuple_anf); MS_EXCEPTION_IF_NULL(graph); - if (!AnfAlgo::IsTupleOutput(input_node)) { - MS_LOG(EXCEPTION) << "Cannot using the function to convert a not tuple output node to maketuple!"; + MS_EXCEPTION_IF_NULL(transed_nodes); + + if (!AnfAlgo::IsTupleOutput(tuple_anf)) { + return tuple_anf; } - if (input_node->isa()) { - MS_LOG(EXCEPTION) << "The function can only split a parameter or valuenode bug got " << input_node->DebugString(); + auto transed_node_it = transed_nodes->find(tuple_anf); + if (transed_node_it != transed_nodes->end()) { + return transed_node_it->second; } - std::vector convert_inputs = {NewValueNode(prim::kPrimMakeTuple)}; auto kernel_graph = graph->cast(); - MS_EXCEPTION_IF_NULL(kernel_graph); - auto splited_node_list = kernel_graph->SplitTupleOutputNodeToNodeList(input_node); - for (const auto &node : splited_node_list) { - if (AnfAlgo::IsTupleOutput(node)) { - convert_inputs.emplace_back(ConvertTupleOuputToPlantInputs(graph, node)); - continue; - } - convert_inputs.emplace_back(node); - } - - auto make_tuple = graph->NewCNode(convert_inputs); - std::vector abstract_list; - auto make_tuple_input_size = AnfAlgo::GetInputTensorNum(make_tuple); - for (size_t index = 0; index < make_tuple_input_size; ++index) { - auto make_tuple_input = AnfAlgo::GetInputNode(make_tuple, index); - MS_EXCEPTION_IF_NULL(make_tuple_input); - abstract_list.emplace_back(make_tuple_input->abstract()); - } - make_tuple->set_abstract(std::make_shared(abstract_list)); + auto make_tuple = kernel_graph->TransTupleToMakeTuple(tuple_anf); + (*transed_nodes)[tuple_anf] = make_tuple; + // replace graph inputs if input is a parameter + kernel_graph->ReplaceGraphInput(tuple_anf, make_tuple); return make_tuple; } - -CNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) { - MS_EXCEPTION_IF_NULL(cnode_ptr); - MS_EXCEPTION_IF_NULL(graph); - std::vector convert_inputs = {cnode_ptr->input(0)}; - for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode_ptr); ++index) { - auto input_node = AnfAlgo::GetInputNode(cnode_ptr, index); - if (AnfAlgo::IsTupleOutput(input_node)) { - std::vector types; - std::vector> shapes; - std::vector make_tuple_inputs_list = {NewValueNode(prim::kPrimMakeTuple)}; - if (input_node->isa()) { - for (size_t tuple_out_index = 0; tuple_out_index < AnfAlgo::GetOutputTensorNum(input_node); ++tuple_out_index) { - make_tuple_inputs_list.emplace_back(CreatTupleGetItemNode(graph, input_node, tuple_out_index)); - types.push_back(AnfAlgo::GetOutputInferDataType(input_node, tuple_out_index)); - shapes.emplace_back(AnfAlgo::GetOutputInferShape(input_node, tuple_out_index)); - } - auto make_tuple = graph->NewCNode(make_tuple_inputs_list); - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, make_tuple.get()); - convert_inputs.emplace_back(make_tuple); - continue; - } - convert_inputs.emplace_back(ConvertTupleOuputToPlantInputs(graph, input_node)); - } else { - convert_inputs.push_back(input_node); - } - } - auto new_node = graph->NewCNode(convert_inputs); - new_node->set_abstract(cnode_ptr->abstract()); - return new_node; -} } // namespace const BaseRef ConvertTupleOutputToMaketuple::DefinePattern() const { @@ -102,15 +61,22 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func } auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); + std::unordered_map transed_nodes; if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { return nullptr; } - if (std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(), [](const AnfNodePtr &node) { - return node->Type() != nullptr && AnfAlgo::IsRealKernel(node) && AnfAlgo::IsTupleOutput(node); - })) { - return ConvertTupleInputToMakeTuple(func_graph, cnode); + bool cnode_input_changed = false; + for (size_t i = 0; i < cnode->inputs().size(); ++i) { + const auto &input = cnode->inputs()[i]; + if (input->Type() != nullptr && AnfAlgo::IsRealKernel(input) && AnfAlgo::IsTupleOutput(input) && + !AnfAlgo::CheckPrimitiveType(input, prim::kPrimCall)) { + cnode->set_input(i, ConvertTupleInputToMakeTuple(func_graph, input, &transed_nodes)); + cnode_input_changed = true; + } } - return nullptr; + auto kernel_graph = func_graph->cast(); + MS_EXCEPTION_IF_NULL(kernel_graph); + return cnode_input_changed ? kernel_graph->NewCNode(cnode) : nullptr; } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index f7be0cb6b8..e38de71170 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -1817,7 +1817,7 @@ void AscendSession::CreateMultiBranchOutput(NotNull graph, NotNu // create a parameter to store the output of multiple branch and set the parameter as the condition graph's output // auto multi_output_param = graph->NewParameter(); auto origin_inputs = graph->inputs(); - auto output_param = CreateNewParameterFromCNode(node, true, graph.get().get()); + auto output_param = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract())); MS_EXCEPTION_IF_NULL(graph->MutableInputs()); graph->MutableInputs()->operator=(origin_inputs); graph->AddChildGraphResult(output_param); @@ -1835,9 +1835,8 @@ void AscendSession::CreateMultiBranchOutput(NotNull graph, NotNu if (child_graph->get_output_null()) { continue; } - auto graph_output = child_graph->output(); - AscendControlParser::InsertMultipleAssignToGraph(NOT_NULL(child_graph), nullptr, NOT_NULL(graph_output), - NOT_NULL(output_param)); + AscendControlParser::InsertMultipleAssignToGraph(NOT_NULL(child_graph), nullptr, + NOT_NULL(child_graph->output()), NOT_NULL(output_param)); } } } diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index 4313be7379..aa5a62e7e4 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -441,83 +441,115 @@ ParameterPtr KernelGraph::NewParameter(const abstract::AbstractBasePtr &abstract return new_parameter; } -std::vector KernelGraph::SplitTupleParameterToNodeList(const ParameterPtr ¶meter) { - MS_EXCEPTION_IF_NULL(parameter); - std::vector convert_nodes_list; - auto abstract = parameter->abstract(); +ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) { + MS_EXCEPTION_IF_NULL(value_node); + auto new_value_node = MakeValueNode(value_node)->cast(); + AnfAlgo::SetGraphId(graph_id_, new_value_node.get()); + return new_value_node; +} + +ValueNodePtr KernelGraph::NewValueNode(const AbstractBasePtr &abstract, const ValuePtr &value) { MS_EXCEPTION_IF_NULL(abstract); + MS_EXCEPTION_IF_NULL(value); + ValueNodePtr new_value_node = std::make_shared(value); + new_value_node->set_abstract(abstract); + SetKernelInfoForNode(new_value_node); + AnfAlgo::SetGraphId(graph_id(), new_value_node.get()); + return new_value_node; +} + +AnfNodePtr KernelGraph::TransValueNodeTuple(const AbstractBasePtr abstract, const ValuePtr &value) { + MS_EXCEPTION_IF_NULL(abstract); + MS_EXCEPTION_IF_NULL(value); if (!abstract->isa()) { - MS_LOG(EXCEPTION) << "Multiple output Parameter's output must be a tuple abstract but got " << abstract->ToString(); + auto new_value_node = NewValueNode(abstract, value); + AddValueNodeToGraph(new_value_node); + return new_value_node; } auto tuple_abstract = abstract->cast(); + auto value_tuple = value->cast(); MS_EXCEPTION_IF_NULL(tuple_abstract); - for (size_t index = 0; index < tuple_abstract->size(); ++index) { - auto new_parameter = this->NewParameter((*tuple_abstract)[index]); - SetKernelInfoForNode(new_parameter); - convert_nodes_list.emplace_back(new_parameter); - } - auto new_inputs = std::make_shared>(); - auto old_inputs = inputs(); - for (const auto &input_node : old_inputs) { - if (input_node != parameter) { - new_inputs->emplace_back(input_node); - continue; - } - std::copy(convert_nodes_list.begin(), convert_nodes_list.end(), std::back_inserter(*new_inputs)); - } - inputs_ = new_inputs; - return convert_nodes_list; -} - -std::vector KernelGraph::SplitTupleOutputNodeToNodeList(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (node->isa()) { - MS_LOG(EXCEPTION) << "The function can only split a parameter or valuenode bug got " << node->DebugString(); + MS_EXCEPTION_IF_NULL(value_tuple); + if (tuple_abstract->size() != value_tuple->size()) { + MS_LOG(EXCEPTION) << "Abstract size:" << tuple_abstract->size() + << " is not equal to value size:" << value_tuple->size(); } - if (node->isa()) { - return SplitTupleParameterToNodeList(node->cast()); + std::vector make_tuple_inputs = { + mindspore::NewValueNode(std::make_shared(prim::kPrimMakeTuple->name()))}; + for (size_t index = 0; index < tuple_abstract->size(); ++index) { + make_tuple_inputs.push_back(TransValueNodeTuple((*tuple_abstract)[index], (*value_tuple)[index])); } - return SplitTupleValueNodeToNodeList(node->cast()); + auto make_tuple = NewCNode(make_tuple_inputs); + make_tuple->set_abstract(tuple_abstract); + return make_tuple; } -std::vector KernelGraph::SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node) { - MS_EXCEPTION_IF_NULL(value_node); - auto node_value = value_node->value(); - std::vector convert_inputs; - if (!node_value->isa()) { - MS_LOG(EXCEPTION) << "Multiple output valuenode's value must be a value tuple but got " << node_value->ToString(); - } - auto value_tuple = node_value->cast(); - MS_EXCEPTION_IF_NULL(value_tuple); - auto abstract = value_node->abstract(); +AnfNodePtr KernelGraph::TransParameterTuple(const AbstractBasePtr &abstract) { + MS_EXCEPTION_IF_NULL(abstract); if (!abstract->isa()) { - MS_LOG(EXCEPTION) << "Spilted node's output abstract is not type tuple"; + return NewParameter(abstract); } auto tuple_abstract = abstract->cast(); MS_EXCEPTION_IF_NULL(tuple_abstract); - if (tuple_abstract->size() != value_tuple->size()) { - MS_LOG(EXCEPTION) << "The node output index [" << value_tuple->size() << "]is outof range " - << tuple_abstract->size(); - } - for (size_t index = 0; index < value_tuple->value().size(); ++index) { - auto new_value_node = std::make_shared(value_tuple->value()[index]); - new_value_node->set_abstract((*tuple_abstract)[index]); - AddValueNodeToGraph(new_value_node); - SetKernelInfoForNode(new_value_node); - AnfAlgo::SetGraphId(graph_id_, new_value_node.get()); - convert_inputs.emplace_back(new_value_node); - } - if (!RemoveValueNodeFromGraph(value_node)) { - MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString(); + std::vector make_tuple_inputs = { + mindspore::NewValueNode(std::make_shared(prim::kPrimMakeTuple->name()))}; + for (size_t index = 0; index < tuple_abstract->size(); ++index) { + make_tuple_inputs.push_back(TransParameterTuple((*tuple_abstract)[index])); } - return convert_inputs; + auto make_tuple = NewCNode(make_tuple_inputs); + make_tuple->set_abstract(tuple_abstract); + return make_tuple; } -ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) { - MS_EXCEPTION_IF_NULL(value_node); - auto new_value_node = MakeValueNode(value_node)->cast(); - AnfAlgo::SetGraphId(graph_id_, new_value_node.get()); - return new_value_node; +AnfNodePtr KernelGraph::CreatTupleGetItemNode(const AnfNodePtr &node, size_t output_idx) { + auto idx = mindspore::NewValueNode(SizeToInt(output_idx)); + MS_EXCEPTION_IF_NULL(idx); + auto imm = std::make_shared(SizeToInt(output_idx)); + auto abstract_scalar = std::make_shared(imm); + idx->set_abstract(abstract_scalar); + AnfNodePtr tuple_getitem = NewCNode({mindspore::NewValueNode(prim::kPrimTupleGetItem), node, idx}); + MS_EXCEPTION_IF_NULL(tuple_getitem); + tuple_getitem->set_scope(node->scope()); + std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); + TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx); + AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get()); + return tuple_getitem; +} + +AnfNodePtr KernelGraph::TransCNodeTuple(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + std::vector types; + std::vector> shapes; + std::vector make_tuple_inputs_list = {mindspore::NewValueNode(prim::kPrimMakeTuple)}; + for (size_t tuple_out_index = 0; tuple_out_index < AnfAlgo::GetOutputTensorNum(node); ++tuple_out_index) { + make_tuple_inputs_list.emplace_back(CreatTupleGetItemNode(node, tuple_out_index)); + types.push_back(AnfAlgo::GetOutputInferDataType(node, tuple_out_index)); + shapes.emplace_back(AnfAlgo::GetOutputInferShape(node, tuple_out_index)); + } + auto make_tuple = NewCNode(make_tuple_inputs_list); + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, make_tuple.get()); + return make_tuple; +} + +AnfNodePtr KernelGraph::TransTupleToMakeTuple(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!AnfAlgo::IsTupleOutput(node)) { + return node; + } + if (node->isa()) { + return TransParameterTuple(node->abstract()); + } else if (node->isa()) { + auto value_node = node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto make_tuple = TransValueNodeTuple(value_node->abstract(), value_node->value()); + if (RemoveValueNodeFromGraph(value_node)) { + MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString(); + } + return make_tuple; + } else if (node->isa()) { + return TransCNodeTuple(node->cast()); + } + MS_LOG(EXCEPTION) << "Unexpected node:" << node->DebugString(); } const std::vector &KernelGraph::inputs() const { @@ -817,6 +849,23 @@ bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) { return false; } +void KernelGraph::ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNodePtr &new_parameter) { + // update graph inputs + MS_EXCEPTION_IF_NULL(old_parameter); + MS_EXCEPTION_IF_NULL(new_parameter); + if (old_parameter == new_parameter) { + return; + } + for (size_t i = 0; i < inputs_->size(); i++) { + if ((*inputs_)[i] == old_parameter) { + MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_parameter->DebugString() + << ",new graph input:" << new_parameter->DebugString(); + (*inputs_)[i] = new_parameter; + break; + } + } +} + void KernelGraph::ReplaceNode(NotNull old_anf_node, NotNull new_anf_node) { MS_EXCEPTION_IF_NULL(inputs_); { @@ -840,15 +889,7 @@ void KernelGraph::ReplaceNode(NotNull old_anf_node, NotNullset_input(i, new_anf_node); } } - // update graph inputs - for (size_t i = 0; i < inputs_->size(); i++) { - if ((*inputs_)[i] == old_anf_node.get()) { - MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_anf_node->DebugString() - << ",new graph input:" << new_anf_node->DebugString(); - (*inputs_)[i] = new_anf_node.get(); - break; - } - } + ReplaceGraphInput(old_anf_node, new_anf_node); } // update front to backend map FrontBackendlMapUpdate(old_anf_node, new_anf_node); diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h index 7c170a37af..7e3988e11b 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.h +++ b/mindspore/ccsrc/backend/session/kernel_graph.h @@ -49,15 +49,17 @@ class KernelGraph : public FuncGraph { const std::vector &inputs() const; std::vector *MutableInputs() const { return inputs_.get(); } + void ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNodePtr &new_parameter); std::vector outputs() const; CNodePtr NewCNode(const std::vector &inputs) override; void CreateKernelInfoFromNewParameter(const CNodePtr &cnode); CNodePtr NewCNode(const CNodePtr &cnode); ParameterPtr NewParameter(const ParameterPtr ¶meter = nullptr); ParameterPtr NewParameter(const abstract::AbstractBasePtr &abstract); - ValueNodePtr NewValueNode(const ValuePtr &value); + ValueNodePtr NewValueNode(const AbstractBasePtr &abstract, const ValuePtr &value); ValueNodePtr NewValueNode(const ValueNodePtr &value_node = nullptr); - std::vector SplitTupleOutputNodeToNodeList(const AnfNodePtr &node); + // trans tuple output to maketuple + no_tuple out + AnfNodePtr TransTupleToMakeTuple(const AnfNodePtr &node); void set_execution_order(const std::vector &order) { execution_order_ = order; } const std::vector &execution_order() const { return execution_order_; } void SetExecOrderByDefault(); @@ -167,8 +169,6 @@ class KernelGraph : public FuncGraph { // remove value node form graph bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); void SetKernelInfoForNode(const AnfNodePtr &node) const; - std::vector SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node); - std::vector SplitTupleParameterToNodeList(const ParameterPtr ¶meter); AnfNodePtr MakeValueNode(const AnfNodePtr &node); void VisitNodeDescendants(const AnfNodePtr &node, std::queue *visit_queue, std::unordered_set *visited_nodes); @@ -181,6 +181,10 @@ class KernelGraph : public FuncGraph { bool HandleControlDependNode(const AnfNodePtr &node, std::queue *que, std::unordered_set *visited_nodes); void UpdateControlDependRelations(const std::vector &depends); + AnfNodePtr TransValueNodeTuple(const AbstractBasePtr abstract, const ValuePtr &value); + AnfNodePtr TransParameterTuple(const AbstractBasePtr &abstract); + AnfNodePtr TransCNodeTuple(const CNodePtr &node); + AnfNodePtr CreatTupleGetItemNode(const AnfNodePtr &node, size_t output_idx); std::shared_ptr> inputs_; std::vector child_graph_result_; diff --git a/tests/ut/cpp/pre_activate/pass/convert_const_input_to_tensor_input_test.cc b/tests/ut/cpp/pre_activate/pass/convert_const_input_to_tensor_input_test.cc index 5b303d15a5..290fab4ed9 100644 --- a/tests/ut/cpp/pre_activate/pass/convert_const_input_to_tensor_input_test.cc +++ b/tests/ut/cpp/pre_activate/pass/convert_const_input_to_tensor_input_test.cc @@ -99,13 +99,18 @@ TEST_F(TestHWConstInputToTensorInput, test_value_tuple_tensor_input) { EXPECT_NE(ret->input(1)->cast(), nullptr); auto cnode = ret->input(1)->cast()->input(1)->cast(); EXPECT_EQ(AnfAlgo::GetCNodeName(cnode), prim::kPrimDropoutGenMask->name()); - auto input1 = cnode->input(1); - ASSERT_TRUE(input1 != nullptr); - EXPECT_TRUE(IsValueNode(input1)); - auto tensor = input1->cast()->value()->cast(); - ASSERT_TRUE(tensor != nullptr); - auto data = tensor->data_c(); - EXPECT_EQ(std::vector((int *)data, (int *)data + 4), std::vector({2, 4, 2, 2})); + std::vector out; + for (size_t i = 1; i <= 4; i++) { + auto input = cnode->input(i); + ASSERT_TRUE(input != nullptr); + EXPECT_TRUE(IsValueNode(input)); + auto tensor = input->cast()->value()->cast(); + ASSERT_TRUE(tensor != nullptr); + int *data = (int *)(tensor->data_c()); + ASSERT_TRUE(data != nullptr); + out.push_back(*data); + } + EXPECT_EQ(out, std::vector({2, 4, 2, 2})); } } // namespace opt } // namespace mindspore