@@ -236,6 +236,136 @@ MetaGraphTptr BuildMixGraph() {
   // final output
   return meta_graph;
 }
+MetaGraphTptr BuildSplitGraph() {
+  auto meta_graph = std::make_shared<schema::MetaGraphT>();
+  meta_graph->name = "graph";
+  // split node
+  auto split_node = std::make_unique<schema::CNodeT>();
+  split_node->inputIndex = {0};
+  split_node->outputIndex = {1, 2};
+  split_node->primitive = std::make_unique<schema::PrimitiveT>();
+  split_node->primitive->value.type = schema::PrimitiveType_Split;
+  std::unique_ptr<schema::SplitT> attr = std::make_unique<schema::SplitT>();
+  attr->numberSplit = 2;
+  attr->splitDim = 1;
+  split_node->primitive->value.value = attr.release();
+  split_node->name = "split";
+  meta_graph->nodes.emplace_back(std::move(split_node));
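+  // tensor layout: 0 = split input, 1/2 = split outputs,
+  // 3/4 = constant mul operands, 5/6 = graph outputs (mul results)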
+  meta_graph->inputIndex = {0, 3, 4};
+  meta_graph->outputIndex = {5, 6};
+  auto mul_node1 = std::make_unique<schema::CNodeT>();
+  mul_node1->inputIndex = {1, 3};
+  mul_node1->outputIndex = {5};
+  mul_node1->primitive = std::make_unique<schema::PrimitiveT>();
+  mul_node1->primitive->value.type = schema::PrimitiveType_Mul;
+  std::unique_ptr<schema::MulT> mul_attr = std::make_unique<schema::MulT>();
+  mul_node1->primitive->value.value = mul_attr.release();
+  mul_node1->name = "mul1";
+  meta_graph->nodes.emplace_back(std::move(mul_node1));
+  auto mul_node2 = std::make_unique<schema::CNodeT>();
+  mul_node2->inputIndex = {2, 4};
+  mul_node2->outputIndex = {6};
+  mul_node2->primitive = std::make_unique<schema::PrimitiveT>();
+  mul_node2->primitive->value.type = schema::PrimitiveType_Mul;
+  std::unique_ptr<schema::MulT> mul2_attr = std::make_unique<schema::MulT>();
+  mul_node2->primitive->value.value = mul2_attr.release();
+  mul_node2->name = "mul2";
+  meta_graph->nodes.emplace_back(std::move(mul_node2));
+  // input 0: data1
+  auto input0 = std::make_unique<schema::TensorT>();
+  input0->nodeType = schema::NodeType::NodeType_ValueNode;
+  input0->format = schema::Format_NHWC;
+  input0->dataType = TypeId::kNumberTypeFloat32;
+  input0->dims = {1, 2, 2, 3};
+  input0->offset = -1;
+  auto input0_data = new (std::nothrow) float[2 * 2 * 3];
+  for (auto i = 0; i < 2 * 2 * 3; i++) {
+    input0_data[i] = i;
+  }
+  input0->data.resize(sizeof(float) * 2 * 2 * 3);
+  memcpy(input0->data.data(), input0_data, 2 * 2 * 3 * sizeof(float));
+  delete[] input0_data;
+  meta_graph->allTensors.emplace_back(std::move(input0));
+  // split output1
+  auto split_output1 = std::make_unique<schema::TensorT>();
+  split_output1->nodeType = schema::NodeType::NodeType_Parameter;
+  split_output1->format = schema::Format_NHWC;
+  split_output1->dataType = TypeId::kNumberTypeFloat32;
+  split_output1->dims = {1, 1, 2, 3};
+  split_output1->offset = -1;
+  // placeholder only: resize() zero-fills, and the real contents are computed
+  // by constant folding, so no scratch buffer or memcpy is needed here
+  split_output1->data.resize(sizeof(float) * 1 * 2 * 3);
+  meta_graph->allTensors.emplace_back(std::move(split_output1));
+  // split output2
+  auto split_output2 = std::make_unique<schema::TensorT>();
+  split_output2->nodeType = schema::NodeType::NodeType_Parameter;
+  split_output2->format = schema::Format_NHWC;
+  split_output2->dataType = TypeId::kNumberTypeFloat32;
+  split_output2->dims = {1, 1, 2, 3};
+  split_output2->offset = -1;
+  split_output2->data.resize(sizeof(float) * 1 * 2 * 3);
+  meta_graph->allTensors.emplace_back(std::move(split_output2));
+  // input 1: data2
+  auto input1 = std::make_unique<schema::TensorT>();
+  input1->nodeType = schema::NodeType::NodeType_ValueNode;
+  input1->format = schema::Format_NHWC;
+  input1->dataType = TypeId::kNumberTypeFloat32;
+  input1->dims = {1, 1, 2, 3};
+  input1->offset = -1;
+  input1->data.resize(sizeof(float) * 2 * 3);
+  auto input1_data = new (std::nothrow) float[2 * 3];
+  for (auto i = 0; i < 2 * 3; i++) {
+    input1_data[i] = i;
+  }
+  memcpy(input1->data.data(), input1_data, 2 * 3 * sizeof(float));
+  delete[] input1_data;
+  meta_graph->allTensors.emplace_back(std::move(input1));
+  // input 2: data3
+  auto input2 = std::make_unique<schema::TensorT>();
+  input2->nodeType = schema::NodeType::NodeType_ValueNode;
+  input2->format = schema::Format_NHWC;
+  input2->dataType = TypeId::kNumberTypeFloat32;
+  input2->dims = {1, 1, 2, 3};
+  input2->offset = -1;
+  input2->data.resize(sizeof(float) * 2 * 3);
+  auto input2_data = new (std::nothrow) float[2 * 3];
+  for (auto i = 0; i < 2 * 3; i++) {
+    input2_data[i] = 10;
+  }
+  memcpy(input2->data.data(), input2_data, 2 * 3 * sizeof(float));
+  delete[] input2_data;
+  meta_graph->allTensors.emplace_back(std::move(input2));
+  // final mul output1
+  auto mul_output = std::make_unique<schema::TensorT>();
+  mul_output->nodeType = schema::NodeType::NodeType_Parameter;
+  mul_output->format = schema::Format_NHWC;
+  mul_output->dataType = TypeId::kNumberTypeFloat32;
+  mul_output->dims = {1, 1, 2, 3};
+  meta_graph->allTensors.emplace_back(std::move(mul_output));
+  // final mul output2
+  auto mul_output2 = std::make_unique<schema::TensorT>();
+  mul_output2->nodeType = schema::NodeType::NodeType_Parameter;
+  mul_output2->format = schema::Format_NHWC;
+  mul_output2->dataType = TypeId::kNumberTypeFloat32;
+  mul_output2->dims = {1, 1, 2, 3};
+  meta_graph->allTensors.emplace_back(std::move(mul_output2));
+  return meta_graph;
+}
 } // namespace
 TEST_F(ConstantFoldingFusionTest, TestADDConstantFold) {
   auto meta_graph = BuildGraph(schema::PrimitiveType_Add, new schema::AddT);
@@ -483,4 +613,19 @@ TEST_F(ConstantFoldingFusionTest, TestCastDimsConstantFold) {
   auto new_meta_graph = lite::Export(new_graph);
   ASSERT_EQ(new_meta_graph->nodes.size(), 0);
 }
+TEST_F(ConstantFoldingFusionTest, TestSplitConstantFold) {
+  auto meta_graph = BuildSplitGraph();
+  auto input_tensor = meta_graph->allTensors.at(0).get();
+  input_tensor->dataType = kNumberTypeFloat32;
+  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>("test", false);
+  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
+  ASSERT_NE(nullptr, new_graph);
+  auto new_meta_graph = lite::Export(new_graph);
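+  // the Split and both Muls consume only constant inputs, so folding should
+  // collapse the whole graph and leave no nodes in the exported model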
+  ASSERT_EQ(new_meta_graph->nodes.size(), 0);
+}
 } // namespace mindspore
@@ -319,7 +319,7 @@ schema::PrimitiveType GetCNodeType(const BaseRef &n) {
   if (utils::isa<PrimitiveCPtr>(value)) {
     auto primitive = value->cast<PrimitiveCPtr>();
     MS_ASSERT(primitive != nullptr);
-    return (schema::PrimitiveType)primitive->Type();
+    return (schema::PrimitiveType) primitive->Type();
   } else if (utils::isa<Primitive>(value)) {
     auto primitive = value->cast<PrimitivePtr>();
     MS_ASSERT(primitive != nullptr);
@@ -392,8 +392,8 @@ size_t GetOutputTensorNum(const AnfNodePtr &node) {
 bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node) {
   auto output_node_list = GetRealNodeUsedList(graph, node);
   if (output_node_list->size() != 1) {
-      MS_LOG(DEBUG) << "fusion node has multi output nodes";
-      return true;
+    MS_LOG(DEBUG) << "fusion node has multi output nodes";
+    return true;
   }
   return false;
 }
@@ -412,5 +412,50 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(con
   std::copy(output_info_list.begin(), output_info_list.end(), std::back_inserter(*output_node_list));
   return output_node_list;
 }
+size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
+  MS_ASSERT(tuple_get_item != nullptr);
+  if (tuple_get_item->size() != kTupleGetItemInputSize) {
+    MS_LOG(ERROR) << "The node tuple_get_item must have 2 inputs!";
+    return -1;  // wraps to SIZE_MAX; callers never match it as a valid index
+  }
+  auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem);
+  MS_ASSERT(output_index_value_node != nullptr);
+  auto value_node = output_index_value_node->cast<ValueNodePtr>();
+  MS_ASSERT(value_node != nullptr);
+  return IntToSize(GetValue<int>(value_node->value()));
+}
+std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
+                                                                                        const AnfNodePtr &node,
+                                                                                        size_t output_index) {
+  MS_ASSERT(graph != nullptr);
+  MS_ASSERT(node != nullptr);
+  auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
+  auto manager = graph->manager();
+  MS_ASSERT(manager != nullptr);
+  auto iter = manager->node_users().find(node);
+  if (iter == manager->node_users().end()) {
+    MS_LOG(ERROR) << "node has no output in manager";
+    return output_node_list;
+  }
+  auto output_info_list = iter->second;
+  for (const auto &output_info : output_info_list) {
+    size_t used_output_index;
+    if (GetCNodeType(output_info.first) == schema::PrimitiveType_TupleGetItem) {
+      used_output_index = GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
+    } else if (GetCNodeType(node) == schema::PrimitiveType_TupleGetItem) {
+      used_output_index = output_index;
+    } else {
+      // a user that is not tuple_get_item can only consume output 0
+      if (output_index != 0) {
+        MS_LOG(ERROR) << "a non-tuple_get_item user cannot consume output " << output_index;
+      }
+      return output_node_list;
+    }
+    if (used_output_index == output_index) {
+      output_node_list->push_back(output_info);
+    }
+  }
+  return output_node_list;
+}
 } // namespace opt
 } // namespace mindspore
@@ -63,6 +63,8 @@ bool CheckIsAllInputsParam(const AnfNodePtr &node);
 size_t GetOutputTensorNum(const AnfNodePtr &node);
 bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node);
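+// returns the output index selected by a tuple_get_item node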
+size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item);
 } // namespace opt
 } // namespace mindspore
 #endif // MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_
@@ -41,7 +41,7 @@ std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) {
     auto tensorT = tmp_meta_graph->allTensors.at(input_index).get();
     auto tensor_shape = tensorT->dims;
     auto lite_tensor =
-        new (std::nothrow) Tensor(TypeId(tensorT->dataType), tensor_shape, tensorT->format, tensorT->nodeType);
+      new (std::nothrow) Tensor(TypeId(tensorT->dataType), tensor_shape, tensorT->format, tensorT->nodeType);
     if (lite_tensor == nullptr) {
       MS_LOG(ERROR) << "lite tensor is nullptr";
       return input_tensors;
@@ -106,7 +106,7 @@ kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tens
                                   mindspore::lite::PrimitiveC *primitive) {
   MS_ASSERT(nullptr != lite_primitive);
   auto data_type = inputs.front()->data_type();
-  kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, (schema::PrimitiveType)primitive->Type()};
+  kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, (schema::PrimitiveType) primitive->Type()};
   lite::Context context;
   auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
   if (creator != nullptr) {
@@ -115,6 +115,44 @@ kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tens
   }
   return nullptr;
 }
+lite::STATUS ReplaceCNode(const FuncGraphPtr &func_graph, const CNodePtr &any_node, const AnfNodePtr &input_node,
+                          std::vector<Tensor *> output_tensors, size_t replace_index) {
+  MS_ASSERT(func_graph != nullptr);
+  auto manager = func_graph->manager();
+  MS_ASSERT(manager != nullptr);
+  if (output_tensors.size() != 1) {
+    for (size_t k = 0; k < output_tensors.size(); k++) {
+      auto used_node_list = GetRealNodeUsedListByOutputIdx(func_graph, input_node, k);
+      if (used_node_list->size() != 1) {
+        MS_LOG(ERROR) << "each output must be used by exactly one tuple_get_item";
+        return lite::RET_ERROR;
+      }
+      auto tuple_node = used_node_list->at(0).first;
+      if (GetCNodeType(tuple_node) == schema::PrimitiveType_TupleGetItem) {
+        auto new_parameter = CreateNewParamter(func_graph, output_tensors.at(k));
+        if (new_parameter == nullptr) {
+          MS_LOG(ERROR) << "CreateNewParamter failed, name: " << input_node->fullname_with_scope();
+          return lite::RET_ERROR;
+        }
+        new_parameter->set_name(input_node->fullname_with_scope() + "_const_" + std::to_string(k));
+        manager->Replace(tuple_node, new_parameter);
+      } else {
+        MS_LOG(ERROR) << "multi-output tensor must connect to tuple_get_item: " << input_node->fullname_with_scope();
+        return lite::RET_ERROR;
+      }
+    }
+  } else {
+    auto new_parameter = CreateNewParamter(func_graph, output_tensors.front());
+    if (new_parameter == nullptr) {
+      MS_LOG(ERROR) << "CreateNewParamter failed, name: " << input_node->fullname_with_scope();
+      return lite::RET_ERROR;
+    }
+    new_parameter->set_name(input_node->fullname_with_scope());
+    any_node->set_input(replace_index, new_parameter);
+  }
+  return lite::RET_OK;
+}
 } // namespace
 void FreeTensors(std::vector<Tensor *> *input_tensor, std::vector<Tensor *> *output_tensor) {
   if (input_tensor != nullptr) {
@@ -140,64 +178,66 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
   }
   auto any_node = node->cast<CNodePtr>();
   CheckIfCNodeIsNull(any_node);
+  bool changed = false;
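+  // track whether any input was actually folded so the pass can report
+  // "no change" (nullptr) to the optimizer when nothing happened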
   for (size_t i = 1; i < any_node->inputs().size(); i++) {
     auto input_node = any_node->input(i);
-    if (input_node->isa<CNode>() && CheckIsAllInputsParam(input_node)) {
-      auto input_cnode = input_node->cast<CNodePtr>();
-      auto input_tensors = GetCNodeInputTensors(input_cnode);
-      if (input_tensors.empty() || input_tensors.size() != input_cnode->inputs().size() - 1) {
-        FreeTensors(&input_tensors, nullptr);
-        continue;
-      }
-      MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope();
-      auto output_nums = GetOutputTensorNum(input_cnode);
-      std::vector<Tensor *> output_tensors{output_nums, new Tensor()};
-      auto lite_primitive = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
-      if (lite_primitive == nullptr) {
-        MS_LOG(ERROR) << "lite_primitive is nullptr";
-        FreeTensors(&input_tensors, &output_tensors);
-        return nullptr;
-      }
-      // here, input_tensor's format need to be transposed nhwc according to fmkType,
-      // but for the time being, we only transpose the tensor with 0/1/2/3D.
-      // Others should be added in future.
-      for (size_t j = 0; j < input_tensors.size(); ++j) {
-        input_tensors[j]->SetFormat(schema::Format_NHWC);
-        if (input_tensors[j]->shape().size() == 4) {
-          MS_LOG(INFO) << "init input_tensor format to nhwc";
-        }
-      }
-      lite_primitive->InferShape(input_tensors, output_tensors);
-      auto parameter = kernel::PopulateParameter(lite_primitive.get());
-      if (parameter == nullptr) {
-        MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
-                      << schema::EnumNamePrimitiveType((schema::PrimitiveType)(lite_primitive->Type()));
-        return nullptr;
-      }
-      auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, lite_primitive.get());
-      if (lite_kernel == nullptr) {
-        MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
-        FreeTensors(&input_tensors, &output_tensors);
-        return nullptr;
-      }
-      auto ret = lite_kernel->Run();
-      if (0 != ret) {
-        FreeTensors(&input_tensors, &output_tensors);
-        MS_LOG(ERROR) << "run kernel failed, name: " << lite_kernel->name();
-        return nullptr;
-      }
-      auto new_parameter = CreateNewParamter(func_graph, output_tensors.front());
-      if (new_parameter == nullptr) {
-        FreeTensors(&input_tensors, &output_tensors);
-        MS_LOG(ERROR) << "CreateNewParamter failed, name: " << lite_kernel->name();
-        return nullptr;
+    if (!input_node->isa<CNode>() || !CheckIsAllInputsParam(input_node)) {
+      continue;
+    }
+    auto input_cnode = input_node->cast<CNodePtr>();
+    auto input_tensors = GetCNodeInputTensors(input_cnode);
+    if (input_tensors.empty() || input_tensors.size() != input_cnode->inputs().size() - 1) {
+      FreeTensors(&input_tensors, nullptr);
+      continue;
+    }
+    changed = true;
+    MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope();
+    auto output_nums = GetOutputTensorNum(input_cnode);
+    // allocate a distinct Tensor per output: the brace form
+    // {output_nums, new Tensor()} picks the (count, value) constructor and
+    // would alias one Tensor across all outputs, double-freeing it later
+    std::vector<Tensor *> output_tensors;
+    for (size_t j = 0; j < output_nums; ++j) {
+      output_tensors.push_back(new (std::nothrow) Tensor());
+    }
+    auto lite_primitive = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
+    if (lite_primitive == nullptr) {
+      MS_LOG(ERROR) << "lite_primitive is nullptr";
+      FreeTensors(&input_tensors, &output_tensors);
+      return nullptr;
+    }
+    // here, each input_tensor's format needs to be transposed to NHWC
+    // according to fmkType, but for the time being we only transpose tensors
+    // with 0/1/2/3 dimensions; others should be added in the future
+    for (size_t j = 0; j < input_tensors.size(); ++j) {
+      input_tensors[j]->SetFormat(schema::Format_NHWC);
+      if (input_tensors[j]->shape().size() == 4) {
+        MS_LOG(INFO) << "init input_tensor format to nhwc";
       }
-      new_parameter->set_name(input_node->fullname_with_scope());
-      any_node->set_input(i, new_parameter);
+    }
+    lite_primitive->InferShape(input_tensors, output_tensors);
+    auto parameter = kernel::PopulateParameter(lite_primitive.get());
+    if (parameter == nullptr) {
+      MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
+                    << schema::EnumNamePrimitiveType((schema::PrimitiveType) (lite_primitive->Type()));
+      return nullptr;
+    }
+    auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, lite_primitive.get());
+    if (lite_kernel == nullptr) {
+      MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
+      FreeTensors(&input_tensors, &output_tensors);
+      return nullptr;
+    }
+    auto ret = lite_kernel->Run();
+    if (0 != ret) {
+      FreeTensors(&input_tensors, &output_tensors);
+      MS_LOG(ERROR) << "run kernel failed, name: " << lite_kernel->name();
+      return nullptr;
+    }
+    // replace cnode by new param
+    if (ReplaceCNode(func_graph, any_node, input_node, output_tensors, i) != lite::RET_OK) {
       FreeTensors(&input_tensors, &output_tensors);
       delete (lite_kernel);
+      MS_LOG(ERROR) << "constant_folding replace cnode failed";
+      return nullptr;
     }
+    FreeTensors(&input_tensors, &output_tensors);
+    delete (lite_kernel);
   }
-  return any_node;
+  return changed ? any_node : nullptr;
 }
 } // namespace mindspore::opt