From: @mengyuanli Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -20,6 +20,7 @@ | |||||
| #include <map> | #include <map> | ||||
| #include "tools/converter/quantizer/quantize_util.h" | #include "tools/converter/quantizer/quantize_util.h" | ||||
| #include "src/ops/assert_op.h" | |||||
| #include "src/ops/space_to_batch.h" | #include "src/ops/space_to_batch.h" | ||||
| #include "src/ops/space_to_batch_nd.h" | #include "src/ops/space_to_batch_nd.h" | ||||
| #include "src/ops/conv2d.h" | #include "src/ops/conv2d.h" | ||||
| @@ -614,6 +615,13 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||||
| return NewPrimitiveC<Sqrt>(prim, inputs, quantType); | return NewPrimitiveC<Sqrt>(prim, inputs, quantType); | ||||
| } else if (op_type == "Greater") { | } else if (op_type == "Greater") { | ||||
| return NewPrimitiveC<Greater>(prim, inputs, quantType); | return NewPrimitiveC<Greater>(prim, inputs, quantType); | ||||
| } else if (op_type == "Switch") { | |||||
| return NewPrimitiveC<Switch>(prim, inputs, quantType); | |||||
| } else if (op_type == "Partial") { | |||||
| return NewPrimitiveC<Partial>(prim, inputs, quantType); | |||||
| } else if (op_type == "Merge") { | |||||
| return NewPrimitiveC<Merge>(prim, inputs, quantType); | |||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| } else if (op_type == "SoftmaxCrossEntropyWithLogits") { | } else if (op_type == "SoftmaxCrossEntropyWithLogits") { | ||||
| return NewPrimitiveC<SoftmaxCrossEntropy>(prim, inputs, quantType); | return NewPrimitiveC<SoftmaxCrossEntropy>(prim, inputs, quantType); | ||||
| @@ -955,6 +963,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||||
| return new (std::nothrow) Merge(primitive); | return new (std::nothrow) Merge(primitive); | ||||
| case schema::PrimitiveType_Partial: | case schema::PrimitiveType_Partial: | ||||
| return new (std::nothrow) Partial(primitive); | return new (std::nothrow) Partial(primitive); | ||||
| case schema::PrimitiveType_Assert: | |||||
| return new (std::nothrow) AssertOP(primitive); | |||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| case schema::PrimitiveType_ActivationGrad: | case schema::PrimitiveType_ActivationGrad: | ||||
| return new (std::nothrow) ActivationGrad(primitive); | return new (std::nothrow) ActivationGrad(primitive); | ||||
| @@ -156,7 +156,8 @@ kernel::LiteKernel *CpuTransposeFp32KernelCreator(const std::vector<lite::Tensor | |||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | ||||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | const lite::InnerContext *ctx, const kernel::KernelKey &desc, | ||||
| const mindspore::lite::PrimitiveC *primitive) { | const mindspore::lite::PrimitiveC *primitive) { | ||||
| MS_ASSERT(desc.type == schema::PrimitiveType_Transpose); | |||||
| MS_ASSERT(desc.type == schema::PrimitiveType_Transpose || desc.type == schema::PrimitiveType_Nchw2Nhwc || | |||||
| desc.type == schema::PrimitiveType_Nhwc2Nchw); | |||||
| if (opParameter == nullptr) { | if (opParameter == nullptr) { | ||||
| MS_LOG(ERROR) << "desc type is not Transpose"; | MS_LOG(ERROR) << "desc type is not Transpose"; | ||||
| return nullptr; | return nullptr; | ||||
| @@ -200,6 +200,7 @@ if(ENABLE_CONVERTER) | |||||
| ${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc | ${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc | ${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/onnx_inputs_adjust_pass.cc | ${LITE_DIR}/tools/optimizer/graph/onnx_inputs_adjust_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/while_pass.cc | |||||
| ) | ) | ||||
| endif() | endif() | ||||
| ### train | ### train | ||||
| @@ -7,7 +7,7 @@ efficientnet-lite4-11.onnx | |||||
| mobilenetv2-7.onnx | mobilenetv2-7.onnx | ||||
| shufflenet-v2-10.onnx | shufflenet-v2-10.onnx | ||||
| squeezenet1.1-7.onnx | squeezenet1.1-7.onnx | ||||
| densenet-9.onnx | |||||
| #densenet-9.onnx | |||||
| ml_table_detection_fp32.onnx | ml_table_detection_fp32.onnx | ||||
| ml_table_segment.onnx | ml_table_segment.onnx | ||||
| googlenet-9.onnx | googlenet-9.onnx | ||||
| @@ -27,7 +27,7 @@ psenet_lite_mbv2.onnx;1,32,32,3 | |||||
| super-resolution-10.onnx;1,224,224,1 | super-resolution-10.onnx;1,224,224,1 | ||||
| tinyyolov2-8.onnx;1,416,416,3 | tinyyolov2-8.onnx;1,416,416,3 | ||||
| ml_2012_ocr_cn.onnx | ml_2012_ocr_cn.onnx | ||||
| ml_2012_ocr_cn_noLSTM.onnx | |||||
| #ml_2012_ocr_cn_noLSTM.onnx | |||||
| candy-9.onnx | candy-9.onnx | ||||
| mosaic-9.onnx | mosaic-9.onnx | ||||
| pointilism-9.onnx | pointilism-9.onnx | ||||
| @@ -7,7 +7,7 @@ efficientnet-lite4-11.onnx 2 | |||||
| mobilenetv2-7.onnx 8 | mobilenetv2-7.onnx 8 | ||||
| shufflenet-v2-10.onnx 5 | shufflenet-v2-10.onnx 5 | ||||
| squeezenet1.1-7.onnx 1 | squeezenet1.1-7.onnx 1 | ||||
| densenet-9.onnx 6 | |||||
| #densenet-9.onnx 6 | |||||
| ml_table_detection_fp32.onnx 2 | ml_table_detection_fp32.onnx 2 | ||||
| ml_table_segment.onnx 2 | ml_table_segment.onnx 2 | ||||
| googlenet-9.onnx 3 | googlenet-9.onnx 3 | ||||
| @@ -27,7 +27,7 @@ mnist-8.onnx 10 | |||||
| #super-resolution-10.onnx 1 | #super-resolution-10.onnx 1 | ||||
| #tinyyolov2-8.onnx 0.3 | #tinyyolov2-8.onnx 0.3 | ||||
| ml_2012_ocr_cn.onnx 200 | ml_2012_ocr_cn.onnx 200 | ||||
| ml_2012_ocr_cn_noLSTM.onnx 1 | |||||
| #ml_2012_ocr_cn_noLSTM.onnx 1 | |||||
| candy-9.onnx 5 | candy-9.onnx 5 | ||||
| mosaic-9.onnx 4 | mosaic-9.onnx 4 | ||||
| pointilism-9.onnx 3 | pointilism-9.onnx 3 | ||||
| @@ -28,6 +28,8 @@ | |||||
| #include "src/tensor.h" | #include "src/tensor.h" | ||||
| #include "src/param_value_lite.h" | #include "src/param_value_lite.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/ops/partial.h" | |||||
| #include "tools/common/graph_util.h" | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) { | void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) { | ||||
| @@ -73,7 +75,7 @@ void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) { | |||||
| if (IsPrimitiveCNode(dependNode, schema::PrimitiveType_Depend) || | if (IsPrimitiveCNode(dependNode, schema::PrimitiveType_Depend) || | ||||
| IsPrimitiveCNode(dependNode, schema::PrimitiveType_ControlDepend)) { | IsPrimitiveCNode(dependNode, schema::PrimitiveType_ControlDepend)) { | ||||
| hasDepend = true; | hasDepend = true; | ||||
| bool maskOut = (dependNode->inputs().size() == 3) ? true : false; | |||||
| bool maskOut = (dependNode->inputs().size() == 3); | |||||
| for (size_t j = 1; j < dependNode->inputs().size(); ++j) { | for (size_t j = 1; j < dependNode->inputs().size(); ++j) { | ||||
| AnfNodePtr dependInputNode = dependNode->input(j); | AnfNodePtr dependInputNode = dependNode->input(j); | ||||
| if (dependInputNode->isa<CNode>()) { | if (dependInputNode->isa<CNode>()) { | ||||
| @@ -172,22 +174,50 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| void AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT) { | |||||
| for (auto node : graph_input_nodes_) { | |||||
| std::vector<schema::CNodeT *> AnfExporter::GetSubgraphNodes(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||||
| const size_t &subgraph_index) { | |||||
| std::vector<schema::CNodeT *> subgraph_nodes{}; | |||||
| subgraph_nodes.resize(meta_graphT->subGraph.at(subgraph_index)->nodeIndices.size()); | |||||
| std::transform(meta_graphT->subGraph.at(subgraph_index)->nodeIndices.begin(), | |||||
| meta_graphT->subGraph.at(subgraph_index)->nodeIndices.end(), subgraph_nodes.begin(), | |||||
| [&meta_graphT](const uint32_t idx) { return meta_graphT->nodes.at(idx).get(); }); | |||||
| return subgraph_nodes; | |||||
| } | |||||
| int AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||||
| const size_t &subgraph_index) { | |||||
| auto &subgraph = meta_graphT->subGraph.at(subgraph_index); | |||||
| auto subgraph_nodes = GetSubgraphNodes(meta_graphT, subgraph_index); | |||||
| std::vector<schema::CNodeT *> subgraph_input_nodes{}; | |||||
| for (auto &node : subgraph_nodes) { | |||||
| if (IsContain(graph_input_nodes_, node)) { | |||||
| subgraph_input_nodes.push_back(node); | |||||
| } | |||||
| } | |||||
| std::vector<schema::TensorT *> subgraph_inputs{}; | |||||
| for (auto &node : subgraph_input_nodes) { | |||||
| for (auto input : node->inputIndex) { | for (auto input : node->inputIndex) { | ||||
| auto tensor = meta_graphT->allTensors[input].get(); | auto tensor = meta_graphT->allTensors[input].get(); | ||||
| if (tensor->nodeType != schema::NodeType_CNode && tensor->data.empty()) { | if (tensor->nodeType != schema::NodeType_CNode && tensor->data.empty()) { | ||||
| tensor->nodeType = schema::NodeType_ValueNode; | tensor->nodeType = schema::NodeType_ValueNode; | ||||
| tensor->format = schema::Format_NHWC; | tensor->format = schema::Format_NHWC; | ||||
| if (!IsContain(meta_graphT->inputIndex, input)) { | |||||
| meta_graphT->inputIndex.emplace_back(input); | |||||
| if (!IsContain(subgraph->inputIndices, input)) { | |||||
| if (subgraph_index == kMainGraphIndex) { | |||||
| meta_graphT->inputIndex.push_back(input); | |||||
| } | |||||
| subgraph->inputIndices.push_back(input); | |||||
| subgraph_inputs.push_back(tensor); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| return RET_OK; | |||||
| } | } | ||||
| int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||||
| int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgraph_index, | |||||
| const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||||
| const std::unique_ptr<schema::SubGraphT> &sub_graphT, | |||||
| schema::CNodeT *return_node) { | schema::CNodeT *return_node) { | ||||
| MS_ASSERT(nullptr != meta_graphT); | MS_ASSERT(nullptr != meta_graphT); | ||||
| MS_ASSERT(nullptr != return_node); | MS_ASSERT(nullptr != return_node); | ||||
| @@ -202,28 +232,62 @@ int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_pt | |||||
| MS_LOG(ERROR) << "obtain outputs failed"; | MS_LOG(ERROR) << "obtain outputs failed"; | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| } else if (input_node->isa<Parameter>()) { | |||||
| MS_LOG(INFO) << "the node " << input_node->fullname_with_scope().c_str() << "is parameter node"; | |||||
| continue; | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "the node " << input_node->fullname_with_scope().c_str() << "is not output node"; | MS_LOG(ERROR) << "the node " << input_node->fullname_with_scope().c_str() << "is not output node"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } | } | ||||
| for (unsigned int &i : return_node->inputIndex) { | for (unsigned int &i : return_node->inputIndex) { | ||||
| meta_graphT->outputIndex.push_back(i); | |||||
| if (subgraph_index == kMainGraphIndex) { | |||||
| meta_graphT->outputIndex.push_back(i); | |||||
| } | |||||
| meta_graphT->subGraph.at(subgraph_index)->outputIndices.push_back(i); | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) { | |||||
| auto cnodes = func_graph->GetOrderedCnodes(); | |||||
| auto meta_graphT = std::make_unique<schema::MetaGraphT>(); | |||||
| int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||||
| const size_t &subgraph_index, bool keep_graph, bool copy_primitive, | |||||
| const std::shared_ptr<AnfNode> &partial_anode) { | |||||
| int ret = RET_OK; | int ret = RET_OK; | ||||
| meta_graphT->subGraph.emplace_back(std::make_unique<schema::SubGraphT>()); | |||||
| auto &sub_graphT = meta_graphT->subGraph.at(subgraph_index); | |||||
| auto subgraph_name = func_graph->get_attr("graph_name"); | |||||
| MS_ASSERT(subgraph_name != nullptr); | |||||
| sub_graphT->name = GetValue<std::string>(subgraph_name); | |||||
| auto cnodes = func_graph->GetOrderedCnodes(); | |||||
| for (const auto &cnode : cnodes) { | for (const auto &cnode : cnodes) { | ||||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | ||||
| if (primitive_c == nullptr) { | if (primitive_c == nullptr) { | ||||
| MS_LOG(ERROR) << "primitive_c is nullptr"; | |||||
| ret = RET_MEMORY_FAILED; | |||||
| break; | |||||
| auto fg = GetValueNode<FuncGraphPtr>(cnode->input(0)); | |||||
| if (fg != nullptr) { | |||||
| auto partial_cnode = CreatePartialCnode(fg, cnode); | |||||
| primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(partial_cnode->input(0)); | |||||
| auto primT = primitive_c->primitiveT(); | |||||
| auto pos = fg_subgraph_map.find(fg); | |||||
| if (pos != fg_subgraph_map.end()) { | |||||
| primT->value.AsPartial()->subGraphIndex = fg_subgraph_map.at(fg); | |||||
| } else { | |||||
| size_t next_subgraph_index = fg_subgraph_map.size() + 1; | |||||
| fg_subgraph_map.insert(std::pair<FuncGraphPtr, int>{fg, next_subgraph_index}); | |||||
| primT->value.AsPartial()->subGraphIndex = next_subgraph_index; | |||||
| ret = ExportSubgraph(fg, meta_graphT, next_subgraph_index, keep_graph, copy_primitive, cnode); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "ExportSubgraph failed"; | |||||
| break; | |||||
| } | |||||
| } | |||||
| } else { | |||||
| MS_LOG(ERROR) << "primitive_c is nullptr"; | |||||
| ret = RET_MEMORY_FAILED; | |||||
| break; | |||||
| } | |||||
| } | } | ||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| RemoveIfMakeTuple(cnode); | RemoveIfMakeTuple(cnode); | ||||
| RemoveIfDepend(cnode); | RemoveIfDepend(cnode); | ||||
| @@ -249,13 +313,14 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee | |||||
| } | } | ||||
| if (primT->value.type == schema::PrimitiveType_Return) { | if (primT->value.type == schema::PrimitiveType_Return) { | ||||
| node->name = "return_node"; | node->name = "return_node"; | ||||
| ret = SetGraphoutputIndex(cnode, meta_graphT, node.get()); | |||||
| ret = SetGraphoutputIndex(cnode, subgraph_index, meta_graphT, sub_graphT, node.get()); | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "SetOpOutputN failed"; | MS_LOG(ERROR) << "SetOpOutputN failed"; | ||||
| break; | break; | ||||
| } | } | ||||
| continue; | continue; | ||||
| } | } | ||||
| node->nodeType = schema::NodeType_CNode; | node->nodeType = schema::NodeType_CNode; | ||||
| node->name = cnode->fullname_with_scope(); | node->name = cnode->fullname_with_scope(); | ||||
| if (copy_primitive) { | if (copy_primitive) { | ||||
| @@ -281,21 +346,45 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee | |||||
| if (!keep_graph) { | if (!keep_graph) { | ||||
| primitive_c->ClearPrimitiveT(); | primitive_c->ClearPrimitiveT(); | ||||
| } | } | ||||
| meta_graphT->nodes.emplace_back(std::move(node)); | |||||
| meta_graphT->nodes.push_back(std::move(node)); | |||||
| meta_graphT->subGraph.at(subgraph_index)->nodeIndices.push_back(node_idx++); | |||||
| } | } | ||||
| if (ret != RET_OK) { | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); | |||||
| return ret; | |||||
| } | |||||
| ret = SetGraphInputIndex(meta_graphT, subgraph_index); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "SetGraphInputIndex failed"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); | |||||
| return ret; | |||||
| } | |||||
| ret = SetSubgraphTensorIndices(meta_graphT.get()); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "SetSubgraphTensorIndices failed"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); | |||||
| return ret; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) { | |||||
| static int subgraph_index = 0; | |||||
| auto meta_graphT = std::make_unique<schema::MetaGraphT>(); | |||||
| int ret = ExportSubgraph(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive); | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| // set graph input tensors | |||||
| SetGraphInputIndex(meta_graphT); | |||||
| return meta_graphT.release(); | return meta_graphT.release(); | ||||
| } | } | ||||
| int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, schema::CNodeT *output_cnode) { | int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, schema::CNodeT *output_cnode) { | ||||
| std::string input_name = input_anode->fullname_with_scope(); | std::string input_name = input_anode->fullname_with_scope(); | ||||
| auto input_cnode = utils::cast<CNodePtr>(input_anode); | auto input_cnode = utils::cast<CNodePtr>(input_anode); | ||||
| if (!IsPrimitiveCNode(input_cnode, schema::PrimitiveType_TupleGetItem)) { | if (!IsPrimitiveCNode(input_cnode, schema::PrimitiveType_TupleGetItem)) { | ||||
| #ifndef SUPPORT_TRAIN | #ifndef SUPPORT_TRAIN | ||||
| if (node_id_map_.find(input_name) != node_id_map_.end()) { | if (node_id_map_.find(input_name) != node_id_map_.end()) { | ||||
| @@ -343,11 +432,11 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, | |||||
| input_index_key = get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(0); // try name with 0 | input_index_key = get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(0); // try name with 0 | ||||
| iter = node_id_map_.find(input_index_key); | iter = node_id_map_.find(input_index_key); | ||||
| if (iter == node_id_map_.end()) { | if (iter == node_id_map_.end()) { | ||||
| MS_LOG(ERROR) << "Can not find get_item output tensor" << input_index_key; | |||||
| MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| #else | #else | ||||
| MS_LOG(ERROR) << "Can not find get_item output tensor" << input_index_key; | |||||
| MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| #endif | #endif | ||||
| } | } | ||||
| @@ -367,6 +456,7 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> &input_ano | |||||
| } | } | ||||
| auto paramTensor = std::make_unique<schema::TensorT>(); | auto paramTensor = std::make_unique<schema::TensorT>(); | ||||
| paramTensor->format = schema::Format_NHWC; | paramTensor->format = schema::Format_NHWC; | ||||
| paramTensor->name = paramNode->name(); | |||||
| auto abstractBase = paramNode->abstract(); | auto abstractBase = paramNode->abstract(); | ||||
| if (abstractBase == nullptr) { | if (abstractBase == nullptr) { | ||||
| MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << paramNode->name(); | MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << paramNode->name(); | ||||
| @@ -518,6 +608,9 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano | |||||
| node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); | node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); | ||||
| output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); | output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); | ||||
| meta_graphT->allTensors.emplace_back(std::move(paramTensor)); | meta_graphT->allTensors.emplace_back(std::move(paramTensor)); | ||||
| } else if (value->isa<FuncGraph>()) { | |||||
| MS_LOG(INFO) << "op name:" << input_anode->fullname_with_scope() << " input is func_graph"; | |||||
| return RET_OK; | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Not support value type , need add support."; | MS_LOG(ERROR) << "Not support value type , need add support."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -644,6 +737,20 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s | |||||
| } | } | ||||
| } | } | ||||
| bool AnfExporter::HasPrimitiveCNode(const AnfNodePtr &node) { | |||||
| MS_ASSERT(node != nullptr); | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| if (cnode == nullptr) { | |||||
| return false; | |||||
| } | |||||
| auto prim = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||||
| if (prim == nullptr) { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool AnfExporter::IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type) { | bool AnfExporter::IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type) { | ||||
| MS_ASSERT(node != nullptr); | MS_ASSERT(node != nullptr); | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| @@ -658,6 +765,47 @@ bool AnfExporter::IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType | |||||
| return (schema::PrimitiveType)(prim->Type()) == type; | return (schema::PrimitiveType)(prim->Type()) == type; | ||||
| } | } | ||||
| ValueNodePtr AnfExporter::GetPartialAnfPrim() { | |||||
| auto partial_primitiveT = new (std::nothrow) schema::PrimitiveT; | |||||
| if (partial_primitiveT == nullptr) { | |||||
| MS_LOG(ERROR) << "new partial_primitiveT failed"; | |||||
| return nullptr; | |||||
| } | |||||
| partial_primitiveT->value.type = schema::PrimitiveType_Partial; | |||||
| partial_primitiveT->value.value = new (std::nothrow) schema::PartialT; | |||||
| if (partial_primitiveT->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "new PartialT failed"; | |||||
| return nullptr; | |||||
| } | |||||
| auto partial_prim = std::make_shared<lite::Partial>(partial_primitiveT); | |||||
| ValueNodePtr partial_anf_prim = NewValueNode(partial_prim); | |||||
| return partial_anf_prim; | |||||
| } | |||||
| CNodePtr AnfExporter::CreatePartialCnode(const FuncGraphPtr &fg, AnfNodePtr node) { | |||||
| if (utils::isa<CNodePtr>(node)) { | |||||
| auto cnode = utils::cast<CNodePtr>(node); | |||||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||||
| if (primitive_c != nullptr) { | |||||
| return cnode; | |||||
| } | |||||
| auto partial_anf_prim_vnode = GetPartialAnfPrim(); | |||||
| auto cnode_input = cnode->inputs(); | |||||
| cnode_input.insert(cnode_input.begin(), partial_anf_prim_vnode); | |||||
| cnode->set_inputs(cnode_input); | |||||
| return cnode; | |||||
| } else if (utils::isa<ValueNodePtr>(node)) { | |||||
| auto partial_anf_prim_vnode = GetPartialAnfPrim(); | |||||
| std::vector<AnfNodePtr> inputs{partial_anf_prim_vnode, node}; | |||||
| auto cnode = fg->NewCNode(inputs); | |||||
| return cnode; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "failed to create partial cnode."; | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) { | schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) { | ||||
| AnfExporter anf_exporter; | AnfExporter anf_exporter; | ||||
| return anf_exporter.Export(func_graph, keep_graph, copy_primitive); | return anf_exporter.Export(func_graph, keep_graph, copy_primitive); | ||||
| @@ -27,6 +27,10 @@ | |||||
| #include "tools/converter/converter_context.h" | #include "tools/converter/converter_context.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| constexpr const int kPartialMinSize = 3; | |||||
| constexpr const int kMainGraphIndex = 0; | |||||
| class AnfExporter { | class AnfExporter { | ||||
| public: | public: | ||||
| AnfExporter() = default; | AnfExporter() = default; | ||||
| @@ -45,17 +49,28 @@ class AnfExporter { | |||||
| const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode); | const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode); | ||||
| int ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_anode, | int ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_anode, | ||||
| const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode); | const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode); | ||||
| void SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT); | |||||
| int SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||||
| schema::CNodeT *return_node); | |||||
| int SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const size_t &subgraph_index); | |||||
| int SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgraph_index, | |||||
| const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||||
| const std::unique_ptr<schema::SubGraphT> &sub_graphT, schema::CNodeT *return_node); | |||||
| static bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type); | static bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type); | ||||
| static bool HasPrimitiveCNode(const AnfNodePtr &node); | |||||
| static int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph, | static int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph, | ||||
| const std::shared_ptr<PrimitiveC> &primitive, | const std::shared_ptr<PrimitiveC> &primitive, | ||||
| const std::unique_ptr<schema::CNodeT> &dst_node); | const std::unique_ptr<schema::CNodeT> &dst_node); | ||||
| int ExportSubgraph(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||||
| const size_t &subgraph_index, bool keep_graph, bool copy_primitive, | |||||
| const std::shared_ptr<AnfNode> &partial_anode = nullptr); | |||||
| ValueNodePtr GetPartialAnfPrim(); | |||||
| CNodePtr CreatePartialCnode(const FuncGraphPtr &fg, AnfNodePtr cnode); | |||||
| std::vector<schema::CNodeT *> GetSubgraphNodes(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||||
| const size_t &subgraph_index); | |||||
| private: | private: | ||||
| std::map<std::string, int> node_id_map_; | std::map<std::string, int> node_id_map_; | ||||
| std::vector<schema::CNodeT *> graph_input_nodes_; | std::vector<schema::CNodeT *> graph_input_nodes_; | ||||
| std::map<FuncGraphPtr, int> fg_subgraph_map; | |||||
| uint32_t node_idx = 0; | |||||
| }; | }; | ||||
| // by default, copy_primitive is false, which means that the MetaGraph and func_graph share the same schema::PrimitiveT. | // by default, copy_primitive is false, which means that the MetaGraph and func_graph share the same schema::PrimitiveT. | ||||
| // but in PostQuantization, the func_graph need to transfer to MetaGraph first and do MetaGraph pass, which may modify | // but in PostQuantization, the func_graph need to transfer to MetaGraph first and do MetaGraph pass, which may modify | ||||
| @@ -272,18 +272,40 @@ STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTe | |||||
| continue; | continue; | ||||
| } | } | ||||
| } | } | ||||
| // update graph input indexes | |||||
| // update graph input indices | |||||
| for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) { | for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) { | ||||
| if (*gInIdx > deleteIdx) { | if (*gInIdx > deleteIdx) { | ||||
| (*gInIdx)--; | (*gInIdx)--; | ||||
| } | } | ||||
| } | } | ||||
| // update graph output indexes | |||||
| // update graph output indices | |||||
| for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) { | for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) { | ||||
| if (*gOutIdx > deleteIdx) { | if (*gOutIdx > deleteIdx) { | ||||
| (*gOutIdx)--; | (*gOutIdx)--; | ||||
| } | } | ||||
| } | } | ||||
| for (auto &subgraph : graphT->subGraph) { | |||||
| // update subgraph input indices | |||||
| for (auto gInIdx = subgraph->inputIndices.begin(); gInIdx != subgraph->inputIndices.end(); gInIdx++) { | |||||
| if (*gInIdx > deleteIdx) { | |||||
| (*gInIdx)--; | |||||
| } | |||||
| } | |||||
| // update subgraph output indices | |||||
| for (auto gOutIdx = subgraph->outputIndices.begin(); gOutIdx != subgraph->outputIndices.end(); gOutIdx++) { | |||||
| if (*gOutIdx > deleteIdx) { | |||||
| (*gOutIdx)--; | |||||
| } | |||||
| } | |||||
| // update subgraph output indices | |||||
| for (auto idx = subgraph->tensorIndices.begin(); idx != subgraph->tensorIndices.end(); idx++) { | |||||
| if (*idx > deleteIdx) { | |||||
| (*idx)--; | |||||
| } | |||||
| } | |||||
| } | |||||
| // update nodes indexes | // update nodes indexes | ||||
| for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) { | for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) { | ||||
| // update nodes input indexes | // update nodes input indexes | ||||
| @@ -768,5 +790,30 @@ std::string GetModelName(const std::string &modelFile) { | |||||
| modelName = modelName.substr(0, modelName.find_last_of('.')); | modelName = modelName.substr(0, modelName.find_last_of('.')); | ||||
| return modelName; | return modelName; | ||||
| } | } | ||||
| int SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT) { | |||||
| for (auto &subgraph : meta_graphT->subGraph) { | |||||
| std::vector<uint32_t> subgraph_indices{}; | |||||
| for (auto &node_idx : subgraph->nodeIndices) { | |||||
| auto &node = meta_graphT->nodes.at(node_idx); | |||||
| for (auto &input_idx : node->inputIndex) { | |||||
| if (IsContain(subgraph_indices, input_idx)) { | |||||
| continue; | |||||
| } else { | |||||
| subgraph_indices.push_back(input_idx); | |||||
| } | |||||
| } | |||||
| for (auto &output_idx : node->outputIndex) { | |||||
| if (IsContain(subgraph_indices, output_idx)) { | |||||
| continue; | |||||
| } else { | |||||
| subgraph_indices.push_back(output_idx); | |||||
| } | |||||
| } | |||||
| } | |||||
| subgraph->tensorIndices.assign(subgraph_indices.begin(), subgraph_indices.end()); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -92,6 +92,8 @@ STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNo | |||||
| STATUS ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node); | STATUS ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node); | ||||
| STATUS SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT); | |||||
| std::string GetModelName(const std::string &modelFile); | std::string GetModelName(const std::string &modelFile); | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -59,6 +59,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| ../optimizer/graph/slice_prepose_pass.cc | ../optimizer/graph/slice_prepose_pass.cc | ||||
| ../optimizer/graph/mindir_adjust_pass.cc | ../optimizer/graph/mindir_adjust_pass.cc | ||||
| ../optimizer/graph/onnx_inputs_adjust_pass.cc | ../optimizer/graph/onnx_inputs_adjust_pass.cc | ||||
| ../optimizer/graph/while_pass.cc | |||||
| ) | ) | ||||
| add_subdirectory(../anf_importer anf_importer) | add_subdirectory(../anf_importer anf_importer) | ||||
| @@ -42,6 +42,7 @@ | |||||
| #include "tools/optimizer/graph/unused_transpose_node_remove_pass.h" | #include "tools/optimizer/graph/unused_transpose_node_remove_pass.h" | ||||
| #include "tools/optimizer/graph/infershape_pass.h" | #include "tools/optimizer/graph/infershape_pass.h" | ||||
| #include "tools/optimizer/graph/slice_prepose_pass.h" | #include "tools/optimizer/graph/slice_prepose_pass.h" | ||||
| #include "tools/optimizer/graph/while_pass.h" | |||||
| #include "tools/converter/quantizer/post_training_quantizer.h" | #include "tools/converter/quantizer/post_training_quantizer.h" | ||||
| #include "tools/converter/quantizer/quant_cast.h" | #include "tools/converter/quantizer/quant_cast.h" | ||||
| #include "tools/converter/quantizer/weight_quantizer.h" | #include "tools/converter/quantizer/weight_quantizer.h" | ||||
| @@ -52,18 +53,21 @@ AnfTransform::AnfTransform() = default; | |||||
| AnfTransform::~AnfTransform() = default; | AnfTransform::~AnfTransform() = default; | ||||
| FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const converter::Flags *config) { | |||||
| FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) { | |||||
| MS_ASSERT(nullptr != old_graph); | MS_ASSERT(nullptr != old_graph); | ||||
| if (config == nullptr) { | if (config == nullptr) { | ||||
| MS_LOG(ERROR) << "config shoud be specified"; | |||||
| MS_LOG(ERROR) << "config should be specified"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (old_graph->has_flag("HasTransformed")) { | |||||
| old_graph->set_flag("HasTransformed", false); | |||||
| return old_graph; | |||||
| } | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | auto optimizer = std::make_shared<opt::GraphOptimizer>(); | ||||
| auto fusion_pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false); | auto fusion_pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false); | ||||
| auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true); | auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true); | ||||
| auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true); | auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true); | ||||
| // mindir pre adjustment | |||||
| if (config->fmk == converter::FmkType_MS) { | if (config->fmk == converter::FmkType_MS) { | ||||
| auto mindir_adjust_pass = std::make_shared<opt::MindirAdjustPass>(); | auto mindir_adjust_pass = std::make_shared<opt::MindirAdjustPass>(); | ||||
| mindir_adjust_pass->SetFmkType(config->fmk); | mindir_adjust_pass->SetFmkType(config->fmk); | ||||
| @@ -85,7 +89,12 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||||
| } | } | ||||
| } | } | ||||
| // for now - trainning is not supporting fuse operations | |||||
| if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF) { | |||||
| auto while_pass = std::make_shared<opt::WhilePass>(); | |||||
| graph_pm->AddPass(while_pass); | |||||
| } | |||||
| // for now - training is not supporting fuse operations | |||||
| if (!config->trainModel) { | if (!config->trainModel) { | ||||
| // remove quantdtype when awaretraining | // remove quantdtype when awaretraining | ||||
| fusion_pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>()); | fusion_pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>()); | ||||
| @@ -191,7 +200,46 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| } | } | ||||
| return new_graph; | return new_graph; | ||||
| } | } | ||||
| STATUS AnfTransform::GetAllFuncGraph(const FuncGraphPtr &main_graph, FuncGraphPtrList *subgraphs, | |||||
| std::vector<ValueNodePtr> *vnodes) { | |||||
| auto nodes = TopoSort(main_graph->get_return()); | |||||
| for (auto &node : nodes) { | |||||
| auto fg = GetValueNode<FuncGraphPtr>(node); | |||||
| if (fg) { | |||||
| vnodes->push_back(utils::cast<ValueNodePtr>(node)); | |||||
| subgraphs->push_back(fg); | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) { | |||||
| // transform main_graph | |||||
| auto new_main_graph = TransformSingleFuncGraph(main_graph, config); | |||||
| if (new_main_graph == nullptr) { | |||||
| MS_LOG(ERROR) << "TransformSingleFuncGraph failed."; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||||
| return nullptr; | |||||
| } | |||||
| // transform sub_graph | |||||
| FuncGraphPtrList subgraphs{}; | |||||
| std::vector<ValueNodePtr> vnodes{}; | |||||
| int ret = GetAllFuncGraph(main_graph, &subgraphs, &vnodes); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "GetAllFuncGraph failed " << ret; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); | |||||
| return nullptr; | |||||
| } | |||||
| for (size_t i = 0; i < subgraphs.size(); i++) { | |||||
| auto new_graph = Transform(subgraphs.at(i), config); | |||||
| new_graph->set_flag("HasTransformed", true); | |||||
| vnodes.at(i)->set_value(new_graph); | |||||
| } | |||||
| return new_main_graph; | |||||
| } | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -18,6 +18,7 @@ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_ANF_TRANSFORM_H | #define MINDSPORE_LITE_TOOLS_CONVERTER_ANF_TRANSFORM_H | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | |||||
| #include "schema/inner/model_generated.h" | #include "schema/inner/model_generated.h" | ||||
| #include "tools/common/storage.h" | #include "tools/common/storage.h" | ||||
| #include "tools/converter/converter_flags.h" | #include "tools/converter/converter_flags.h" | ||||
| @@ -34,6 +35,9 @@ class AnfTransform { | |||||
| FuncGraphPtr Transform(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr); | FuncGraphPtr Transform(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr); | ||||
| private: | private: | ||||
| STATUS GetAllFuncGraph(const FuncGraphPtr &main_graph, FuncGraphPtrList *subgraphs, | |||||
| std::vector<ValueNodePtr> *vnodes); | |||||
| FuncGraphPtr TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr); | |||||
| std::unique_ptr<quant::Quantizer> mQuantizer = nullptr; | std::unique_ptr<quant::Quantizer> mQuantizer = nullptr; | ||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -67,6 +67,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { | |||||
| int status = modelImporter->Import(flag); | int status = modelImporter->Import(flag); | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| graph = modelImporter->GetResult(); | graph = modelImporter->GetResult(); | ||||
| graph->set_attr("graph_name", MakeValue("main_graph")); | |||||
| } else { | } else { | ||||
| MS_ASSERT(nullptr != modelParser); | MS_ASSERT(nullptr != modelParser); | ||||
| const std::string modelFile = flag->modelFile; | const std::string modelFile = flag->modelFile; | ||||
| @@ -90,6 +91,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { | |||||
| MS_LOG(ERROR) << "Export to meta graph return nullptr"; | MS_LOG(ERROR) << "Export to meta graph return nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| // transform | // transform | ||||
| transform->SetGraphDef(meta_graph); | transform->SetGraphDef(meta_graph); | ||||
| auto status = transform->Transform(*flag); | auto status = transform->Transform(*flag); | ||||
| @@ -16,6 +16,7 @@ | |||||
| #include "tools/converter/graphdef_transform.h" | #include "tools/converter/graphdef_transform.h" | ||||
| #include <string> | #include <string> | ||||
| #include <algorithm> | |||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| #include "tools/converter/converter_flags.h" | #include "tools/converter/converter_flags.h" | ||||
| @@ -37,9 +38,21 @@ | |||||
| #include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h" | #include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h" | ||||
| #include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h" | #include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h" | ||||
| #include "tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h" | #include "tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h" | ||||
| #include "tools/converter/legacy_optimizer/graph/switch_pass.h" | |||||
| #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h" | |||||
| #include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h" | |||||
| using std::string; | using std::string; | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| std::vector<schema::CNodeT *> GraphDefTransform::GetGraphNodes() { | |||||
| std::vector<schema::CNodeT *> old_nodes{}; | |||||
| old_nodes.resize(graphDefT->nodes.size()); | |||||
| std::transform(graphDefT->nodes.begin(), graphDefT->nodes.end(), old_nodes.begin(), | |||||
| [](const std::unique_ptr<schema::CNodeT> &node) { return node.get(); }); | |||||
| return old_nodes; | |||||
| } | |||||
| GraphDefTransform::GraphDefTransform() = default; | GraphDefTransform::GraphDefTransform() = default; | ||||
| GraphDefTransform::~GraphDefTransform() = default; | GraphDefTransform::~GraphDefTransform() = default; | ||||
| @@ -48,141 +61,232 @@ void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _ | |||||
| int GraphDefTransform::Transform(const converter::Flags &ctx) { | int GraphDefTransform::Transform(const converter::Flags &ctx) { | ||||
| STATUS status; | STATUS status; | ||||
| { | |||||
| Optimizer unusedOpRemoveOptimizer; | |||||
| unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass()); | |||||
| if (!ctx.trainModel) { | |||||
| unusedOpRemoveOptimizer.AddPass(new DropoutNodeRemovePass()); | |||||
| if (ctx.fmk != converter::FmkType_TF) { | |||||
| { | |||||
| auto old_nodes = GetGraphNodes(); | |||||
| Optimizer unusedOpRemoveOptimizer; | |||||
| unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass()); | |||||
| if (!ctx.trainModel) { | |||||
| unusedOpRemoveOptimizer.AddPass(new DropoutNodeRemovePass()); | |||||
| } | |||||
| unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass()); | |||||
| unusedOpRemoveOptimizer.AddPass(new SubgraphNodePass(old_nodes)); | |||||
| status = unusedOpRemoveOptimizer.Run(graphDefT); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed"; | |||||
| return status; | |||||
| } | |||||
| } | } | ||||
| unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass()); | |||||
| status = unusedOpRemoveOptimizer.Run(graphDefT); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed"; | |||||
| return status; | |||||
| // topological sorting | |||||
| { | |||||
| Optimizer topologicalOptimizer; | |||||
| topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||||
| status = topologicalOptimizer.Run(graphDefT); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; | |||||
| return status; | |||||
| } | |||||
| } | } | ||||
| } | |||||
| // topological sorting | |||||
| { | |||||
| Optimizer topologicalOptimizer; | |||||
| topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||||
| status = topologicalOptimizer.Run(graphDefT); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; | |||||
| return status; | |||||
| // generate and infer quant parameters | |||||
| { | |||||
| Optimizer inferQuantParamPass; | |||||
| inferQuantParamPass.AddPass(new (std::nothrow) TopologicalSortPass()); | |||||
| inferQuantParamPass.AddPass(new (std::nothrow) InferQuantParamPass()); | |||||
| status = inferQuantParamPass.Run(graphDefT); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; | |||||
| return status; | |||||
| } | |||||
| } | } | ||||
| } | |||||
| // generate and infer quant parameters | |||||
| { | |||||
| Optimizer inferQuantParamPass; | |||||
| inferQuantParamPass.AddPass(new (std::nothrow) TopologicalSortPass()); | |||||
| inferQuantParamPass.AddPass(new (std::nothrow) InferQuantParamPass()); | |||||
| status = inferQuantParamPass.Run(graphDefT); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; | |||||
| return status; | |||||
| // postconvert pass | |||||
| { | |||||
| // init old node indecies | |||||
| auto old_nodes = GetGraphNodes(); | |||||
| Optimizer fusionOptimizer; | |||||
| if (!ctx.trainModel) { | |||||
| auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass(); | |||||
| if (batch_norm_scale_pass == nullptr) { | |||||
| MS_LOG(ERROR) << "new batch_norm_scale_pass failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| batch_norm_scale_pass->SetFmk(ctx.fmk); | |||||
| fusionOptimizer.AddPass(batch_norm_scale_pass); | |||||
| } | |||||
| fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||||
| fusionOptimizer.AddPass(new SubgraphNodePass(old_nodes)); | |||||
| status = fusionOptimizer.Run(graphDefT); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed"; | |||||
| return status; | |||||
| } | |||||
| } | } | ||||
| } | |||||
| // postconvert pass | |||||
| { | |||||
| Optimizer fusionOptimizer; | |||||
| if (!ctx.trainModel) { | |||||
| auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass(); | |||||
| if (batch_norm_scale_pass == nullptr) { | |||||
| MS_LOG(ERROR) << "new batch_norm_scale_pass failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| batch_norm_scale_pass->SetFmk(ctx.fmk); | |||||
| fusionOptimizer.AddPass(batch_norm_scale_pass); | |||||
| // format transform | |||||
| { | |||||
| // init old node indecies | |||||
| auto old_nodes = GetGraphNodes(); | |||||
| Optimizer formatTransOptimizer; | |||||
| auto formatTransPass = new (std::nothrow) FormatTransPass(); | |||||
| if (formatTransPass == nullptr) { | |||||
| MS_LOG(ERROR) << "new formatTransPass failed"; | |||||
| return RET_MEMORY_FAILED; | |||||
| } | |||||
| formatTransPass->SetQuantType(ctx.quantType); | |||||
| formatTransPass->SetFmk(ctx.fmk); | |||||
| formatTransOptimizer.AddPass(formatTransPass); | |||||
| formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||||
| formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||||
| formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass()); | |||||
| status = formatTransOptimizer.Run(graphDefT); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { | |||||
| MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; | |||||
| return status; | |||||
| } | |||||
| } | } | ||||
| fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||||
| status = fusionOptimizer.Run(graphDefT); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed"; | |||||
| return status; | |||||
| { | |||||
| // init old node indecies | |||||
| auto old_nodes = GetGraphNodes(); | |||||
| Optimizer formatTransOptimizer; | |||||
| auto formatTransPass = new (std::nothrow) FormatTransPass(); | |||||
| if (formatTransPass == nullptr) { | |||||
| MS_LOG(ERROR) << "new formatTransPass failed"; | |||||
| return RET_MEMORY_FAILED; | |||||
| } | |||||
| formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); | |||||
| formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||||
| formatTransOptimizer.AddPass(new (std::nothrow) TransOpRemovePass()); | |||||
| formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass()); | |||||
| formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); | |||||
| formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||||
| formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||||
| status = formatTransOptimizer.Run(graphDefT); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { | |||||
| MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; | |||||
| return status; | |||||
| } | |||||
| } | } | ||||
| } | |||||
| // format transform | |||||
| { | |||||
| Optimizer formatTransOptimizer; | |||||
| auto formatTransPass = new (std::nothrow) FormatTransPass(); | |||||
| if (formatTransPass == nullptr) { | |||||
| MS_LOG(ERROR) << "new formatTransPass failed"; | |||||
| return RET_MEMORY_FAILED; | |||||
| { | |||||
| // init old node indecies | |||||
| auto old_nodes = GetGraphNodes(); | |||||
| Optimizer formatTransOptimizer; | |||||
| auto formatTransPass = new (std::nothrow) FormatTransPass(); | |||||
| if (formatTransPass == nullptr) { | |||||
| MS_LOG(ERROR) << "new formatTransPass failed"; | |||||
| return RET_MEMORY_FAILED; | |||||
| } | |||||
| if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) { | |||||
| formatTransOptimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass()); | |||||
| formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||||
| formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||||
| } | |||||
| status = formatTransOptimizer.Run(graphDefT); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { | |||||
| MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; | |||||
| return status; | |||||
| } | |||||
| } | } | ||||
| formatTransPass->SetQuantType(ctx.quantType); | |||||
| formatTransPass->SetFmk(ctx.fmk); | |||||
| formatTransOptimizer.AddPass(formatTransPass); | |||||
| formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||||
| formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass()); | |||||
| formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); | |||||
| formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||||
| formatTransOptimizer.AddPass(new (std::nothrow) TransOpRemovePass()); | |||||
| formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass()); | |||||
| formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); | |||||
| formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||||
| if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) { | |||||
| formatTransOptimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass()); | |||||
| formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||||
| { | |||||
| // init old node indecies | |||||
| auto old_nodes = GetGraphNodes(); | |||||
| Optimizer fusionOptimizer; | |||||
| fusionOptimizer.AddPass(new (std::nothrow) MulAddFusionPass()); | |||||
| fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||||
| fusionOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||||
| status = fusionOptimizer.Run(graphDefT); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed"; | |||||
| return status; | |||||
| } | |||||
| } | } | ||||
| status = formatTransOptimizer.Run(graphDefT); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { | |||||
| MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; | |||||
| return status; | |||||
| // do quantization | |||||
| { | |||||
| // init old node indecies | |||||
| auto old_nodes = GetGraphNodes(); | |||||
| Optimizer tensorQuantOptimizer; | |||||
| tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||||
| tensorQuantOptimizer.AddPass(new (std::nothrow) InferShapePass()); | |||||
| tensorQuantOptimizer.AddPass(new (std::nothrow) TensorQuantPass()); | |||||
| tensorQuantOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||||
| status = tensorQuantOptimizer.Run(graphDefT); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "DoQuantize failed!"; | |||||
| return status; | |||||
| } | |||||
| } | } | ||||
| } | |||||
| { | |||||
| Optimizer fusionOptimizer; | |||||
| fusionOptimizer.AddPass(new (std::nothrow) MulAddFusionPass()); | |||||
| fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||||
| status = fusionOptimizer.Run(graphDefT); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed"; | |||||
| return status; | |||||
| // insert quantNode and deQuantNode | |||||
| { | |||||
| // init old node indecies | |||||
| auto old_nodes = GetGraphNodes(); | |||||
| Optimizer quantNodeOptimizer; | |||||
| auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); | |||||
| if (dTypeTransPass == nullptr) { | |||||
| MS_LOG(ERROR) << "new dTypeTransPass failed"; | |||||
| return RET_MEMORY_FAILED; | |||||
| } | |||||
| dTypeTransPass->SetInputDataDType(ctx.inputDataType); | |||||
| dTypeTransPass->SetOutputDataDType(ctx.outputDataType); | |||||
| quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||||
| quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||||
| quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass()); | |||||
| status = quantNodeOptimizer.Run(graphDefT); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed"; | |||||
| return status; | |||||
| } | |||||
| auto old_nodes2 = GetGraphNodes(); | |||||
| quantNodeOptimizer.AddPass(dTypeTransPass); | |||||
| quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); | |||||
| quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||||
| quantNodeOptimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass()); | |||||
| quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2)); | |||||
| status = quantNodeOptimizer.Run(graphDefT); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed"; | |||||
| return status; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| // do quantization | |||||
| // switch pass | |||||
| { | { | ||||
| Optimizer tensorQuantOptimizer; | |||||
| tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||||
| tensorQuantOptimizer.AddPass(new (std::nothrow) InferShapePass()); | |||||
| tensorQuantOptimizer.AddPass(new (std::nothrow) TensorQuantPass()); | |||||
| status = tensorQuantOptimizer.Run(graphDefT); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "DoQuantize failed!"; | |||||
| // init old node indecies | |||||
| auto old_nodes = GetGraphNodes(); | |||||
| Optimizer switchOptimizer; | |||||
| switchOptimizer.AddPass(new (std::nothrow) SwitchPass()); | |||||
| switchOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||||
| switchOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||||
| status = switchOptimizer.Run(graphDefT); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "Run switch graphPasses Failed"; | |||||
| return status; | return status; | ||||
| } | } | ||||
| } | } | ||||
| // insert quantNode and deQuantNode | |||||
| // subgraph tensor pass | |||||
| { | { | ||||
| Optimizer quantNodeOptimizer; | |||||
| auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); | |||||
| if (dTypeTransPass == nullptr) { | |||||
| MS_LOG(ERROR) << "new dTypeTransPass failed"; | |||||
| return RET_MEMORY_FAILED; | |||||
| } | |||||
| dTypeTransPass->SetInputDataDType(ctx.inputDataType); | |||||
| dTypeTransPass->SetOutputDataDType(ctx.outputDataType); | |||||
| quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||||
| quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass()); | |||||
| quantNodeOptimizer.AddPass(dTypeTransPass); | |||||
| quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); | |||||
| quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||||
| quantNodeOptimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass()); | |||||
| status = quantNodeOptimizer.Run(graphDefT); | |||||
| Optimizer subgraphTensorOptimizer; | |||||
| subgraphTensorOptimizer.AddPass(new (std::nothrow) SubgraphTensorPass()); | |||||
| status = subgraphTensorOptimizer.Run(graphDefT); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | if (status != RET_OK && status != RET_NO_CHANGE) { | ||||
| MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed"; | |||||
| MS_LOG(ERROR) << "Run subgraph tensor pass Failed"; | |||||
| return status; | return status; | ||||
| } | } | ||||
| } | } | ||||
| // tensor name | // tensor name | ||||
| { | { | ||||
| // init old node indecies | |||||
| auto old_nodes = GetGraphNodes(); | |||||
| Optimizer nameOptimizer; | Optimizer nameOptimizer; | ||||
| nameOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||||
| nameOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | nameOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | ||||
| nameOptimizer.AddPass(new (std::nothrow) TensorNamePass()); | nameOptimizer.AddPass(new (std::nothrow) TensorNamePass()); | ||||
| status = nameOptimizer.Run(graphDefT); | status = nameOptimizer.Run(graphDefT); | ||||
| @@ -192,16 +296,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||||
| } | } | ||||
| } | } | ||||
| // topological sorting | |||||
| { | |||||
| Optimizer topologicalOptimizer; | |||||
| topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||||
| status = topologicalOptimizer.Run(graphDefT); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; | |||||
| return status; | |||||
| } | |||||
| } | |||||
| return RET_OK; | return RET_OK; | ||||
| } | |||||
| } // namespace mindspore::lite | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -18,6 +18,7 @@ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_GRAPHDEF_TRANSFORM_H | #define MINDSPORE_LITE_TOOLS_CONVERTER_GRAPHDEF_TRANSFORM_H | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | |||||
| #include "tools/converter/optimizer.h" | #include "tools/converter/optimizer.h" | ||||
| #include "tools/converter/quantizer/quantizer.h" | #include "tools/converter/quantizer/quantizer.h" | ||||
| #include "schema/inner/model_generated.h" | #include "schema/inner/model_generated.h" | ||||
| @@ -39,6 +40,7 @@ class GraphDefTransform { | |||||
| inline schema::MetaGraphT *GetOutput() { return graphDefT; } | inline schema::MetaGraphT *GetOutput() { return graphDefT; } | ||||
| protected: | protected: | ||||
| std::vector<schema::CNodeT *> GetGraphNodes(); | |||||
| schema::MetaGraphT *graphDefT = nullptr; | schema::MetaGraphT *graphDefT = nullptr; | ||||
| Optimizer *optimizer = nullptr; | Optimizer *optimizer = nullptr; | ||||
| }; | }; | ||||
| @@ -15,6 +15,9 @@ file(GLOB GRAPH_PASS | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/global_format_transform_pass.cc | ${CMAKE_CURRENT_SOURCE_DIR}/global_format_transform_pass.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/set_unused_quant_param_to_default_pass.cc | ${CMAKE_CURRENT_SOURCE_DIR}/set_unused_quant_param_to_default_pass.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/tensor_name_pass.cc | ${CMAKE_CURRENT_SOURCE_DIR}/tensor_name_pass.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/switch_pass.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_node_pass.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_tensor_pass.cc | |||||
| ) | ) | ||||
| set_property(SOURCE ${GRAPH_PASS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) | set_property(SOURCE ${GRAPH_PASS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) | ||||
| add_library(graph_pass_mid OBJECT ${GRAPH_PASS}) | add_library(graph_pass_mid OBJECT ${GRAPH_PASS}) | ||||
| @@ -0,0 +1,76 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h" | |||||
| #include "src/common/log_adapter.h" | |||||
| #include "src/common/utils.h" | |||||
| #include "tools/common/graph_util.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "schema/inner/model_generated.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| void SubgraphNodePass::UpdateSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) { | |||||
| for (auto &subgraph : graph->subGraph) { | |||||
| for (auto &idx : subgraph->nodeIndices) { | |||||
| if (idx > node_idx) { | |||||
| idx--; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) { | |||||
| MS_ASSERT(graph != nullptr); | |||||
| std::vector<schema::CNodeT *> new_nodes{}; | |||||
| std::transform(graph->nodes.begin(), graph->nodes.end(), std::back_inserter(new_nodes), | |||||
| [](std::unique_ptr<CNodeT> &node) { return node.get(); }); | |||||
| for (auto it = old_nodes_.begin(); it != old_nodes_.end();) { | |||||
| if (!IsContain(new_nodes, *it)) { | |||||
| size_t node_idx = it - old_nodes_.begin(); | |||||
| for (auto &subgraph : graph->subGraph) { | |||||
| auto node_idx_pos = std::find(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), node_idx); | |||||
| if (node_idx_pos != subgraph->nodeIndices.end()) { | |||||
| subgraph->nodeIndices.erase(node_idx_pos); | |||||
| UpdateSubgraphNodeIndices(node_idx, graph); | |||||
| break; | |||||
| } | |||||
| } | |||||
| it = old_nodes_.erase(it); | |||||
| } else { | |||||
| it++; | |||||
| } | |||||
| } | |||||
| for (uint32_t i = 0; i < new_nodes.size(); i++) { | |||||
| if (!IsContain(old_nodes_, new_nodes[i])) { | |||||
| for (auto &subgraph : graph->subGraph) { | |||||
| if (IsContain(subgraph->nodeIndices, i - 1) || IsContain(subgraph->nodeIndices, i + 1)) { | |||||
| subgraph->nodeIndices.push_back(old_nodes_.size()); | |||||
| old_nodes_.push_back(new_nodes[i]); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,40 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_NODE_PASS_H | |||||
| #define MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_NODE_PASS_H | |||||
| #include <vector> | |||||
| #include <utility> | |||||
| #include "tools/converter/optimizer.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class SubgraphNodePass : public GraphPass { | |||||
| public: | |||||
| explicit SubgraphNodePass(std::vector<schema::CNodeT *> old_nodes) : old_nodes_(std::move(old_nodes)) {} | |||||
| ~SubgraphNodePass() override = default; | |||||
| STATUS Run(schema::MetaGraphT *graph) override; | |||||
| private: | |||||
| void UpdateSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph); | |||||
| std::vector<schema::CNodeT *> old_nodes_; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_PREDICT_ISOLATED_NODE_REMOVE_PASS_H | |||||
| @@ -0,0 +1,100 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h" | |||||
| #include "src/common/log_adapter.h" | |||||
| #include "src/common/utils.h" | |||||
| #include "tools/common/graph_util.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "schema/inner/model_generated.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| bool SubgraphTensorPass::IsUsing(schema::MetaGraphT *graph, const uint32_t &tensor_idx) { | |||||
| for (const auto &node : graph->nodes) { | |||||
| if (IsContain<uint32_t>(node->inputIndex, tensor_idx)) { | |||||
| return true; | |||||
| } | |||||
| if (IsContain<uint32_t>(node->outputIndex, tensor_idx)) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| STATUS SubgraphTensorPass::UpdateTensorIdx(schema::MetaGraphT *graph, const uint32_t &tensor_idx) { | |||||
| for (const auto &subgraph : graph->subGraph) { | |||||
| UpdateVec<uint32_t>(&(subgraph->inputIndices), tensor_idx); | |||||
| UpdateVec<uint32_t>(&(subgraph->outputIndices), tensor_idx); | |||||
| } | |||||
| for (const auto &node : graph->nodes) { | |||||
| UpdateVec<uint32_t>(&(node->inputIndex), tensor_idx); | |||||
| UpdateVec<uint32_t>(&(node->outputIndex), tensor_idx); | |||||
| } | |||||
| UpdateVec<uint32_t>(&(graph->inputIndex), tensor_idx); | |||||
| UpdateVec<uint32_t>(&(graph->outputIndex), tensor_idx); | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS SubgraphTensorPass::RemoveUselessTensors(schema::MetaGraphT *graph) { | |||||
| for (auto it = graph->allTensors.begin(); it != graph->allTensors.end();) { | |||||
| uint32_t idx = it - graph->allTensors.begin(); | |||||
| if (IsUsing(graph, idx)) { | |||||
| it++; | |||||
| } else { | |||||
| it = graph->allTensors.erase(it); | |||||
| UpdateTensorIdx(graph, idx); | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS SubgraphTensorPass::SyncMainGraphInputAndOutput(schema::MetaGraphT *graph) { | |||||
| MS_ASSERT(graph->subGraph.size() > 0); | |||||
| graph->subGraph[0]->inputIndices.assign(graph->inputIndex.begin(), graph->inputIndex.end()); | |||||
| graph->subGraph[0]->outputIndices.assign(graph->outputIndex.begin(), graph->outputIndex.end()); | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS SubgraphTensorPass::Run(schema::MetaGraphT *graph) { | |||||
| MS_ASSERT(graph != nullptr); | |||||
| int ret = RemoveUselessTensors(graph); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "RemoveUselessTensors failed, ret: " << ret; | |||||
| return ret; | |||||
| } | |||||
| ret = SetSubgraphTensorIndices(graph); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "SetSubgraphTensorIndices failed, ret: " << ret; | |||||
| return ret; | |||||
| } | |||||
| ret = SyncMainGraphInputAndOutput(graph); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "SetSubgraphTensorIndices failed, ret: " << ret; | |||||
| return ret; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,51 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_TENSOR_PASS_H | |||||
| #define MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_TENSOR_PASS_H | |||||
| #include <vector> | |||||
| #include <utility> | |||||
| #include "tools/converter/optimizer.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class SubgraphTensorPass : public GraphPass { | |||||
| public: | |||||
| SubgraphTensorPass() = default; | |||||
| ~SubgraphTensorPass() override = default; | |||||
| STATUS Run(schema::MetaGraphT *graph) override; | |||||
| private: | |||||
| STATUS RemoveUselessTensors(schema::MetaGraphT *graph); | |||||
| bool IsUsing(schema::MetaGraphT *graph, const uint32_t &tensor_idx); | |||||
| STATUS UpdateTensorIdx(schema::MetaGraphT *graph, const uint32_t &tensor_idx); | |||||
| STATUS SyncMainGraphInputAndOutput(schema::MetaGraphT *graph); | |||||
| template <typename T> | |||||
| void UpdateVec(std::vector<T> *vec, T element) { | |||||
| for (auto iter = vec->begin(); iter != vec->end(); iter++) { | |||||
| if (*iter > element) { | |||||
| (*iter)--; | |||||
| } | |||||
| } | |||||
| } | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_PREDICT_ISOLATED_NODE_REMOVE_PASS_H | |||||
| @@ -16,6 +16,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <map> | #include <map> | ||||
| #include <algorithm> | |||||
| #include "tools/converter/legacy_optimizer/graph/switch_pass.h" | #include "tools/converter/legacy_optimizer/graph/switch_pass.h" | ||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| @@ -96,38 +97,6 @@ std::unique_ptr<schema::TensorT> SingleSwitchPass::NewTensor(const std::unique_p | |||||
| return out_tensor; | return out_tensor; | ||||
| } | } | ||||
| STATUS SingleSwitchPass::MoveMaxIterationToCond() { | |||||
| auto &body_subgraph_input = graph_->subGraph.at(body_subgraph_index_)->inputIndices; | |||||
| for (auto it = body_subgraph_input.begin(); it != body_subgraph_input.end();) { | |||||
| if (!body_to_cond_partial_node_->inputIndex.empty() && IsContain(body_to_cond_partial_node_->inputIndex, *it)) { | |||||
| int32_t max_iteration_idx = it - body_subgraph_input.begin(); | |||||
| // get maxiteration tensor | |||||
| auto &max_iteration_tensor = graph_->allTensors.at(cond_partial_node_->inputIndex.at(max_iteration_idx)); | |||||
| auto all_tensor_idx = std::find(graph_->allTensors.begin(), graph_->allTensors.end(), max_iteration_tensor) - | |||||
| graph_->allTensors.begin(); | |||||
| // remove maxiteration from body_to_cond partial node | |||||
| body_to_cond_partial_node_->inputIndex.erase(body_to_cond_partial_node_->inputIndex.begin() + max_iteration_idx); | |||||
| // concat body subgraph tensor to max iteration in all tensor | |||||
| auto body_max_iteration_tensor_idx = body_subgraph_input.at(max_iteration_idx); | |||||
| for (auto &node : cond_graph_nodes_) { | |||||
| std::replace_if( | |||||
| node->inputIndex.begin(), node->inputIndex.end(), | |||||
| [&body_max_iteration_tensor_idx](uint32_t idx) { return idx == body_max_iteration_tensor_idx; }, | |||||
| all_tensor_idx); | |||||
| } | |||||
| // remove maxiteration from body partial input and body func input | |||||
| body_partial_node_->inputIndex.erase(body_partial_node_->inputIndex.begin() + max_iteration_idx); | |||||
| it = body_subgraph_input.erase(it); | |||||
| } else { | |||||
| it++; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS SingleSwitchPass::InsertMerge() { | STATUS SingleSwitchPass::InsertMerge() { | ||||
| int ret = RET_OK; | int ret = RET_OK; | ||||
| auto merge_node = std::unique_ptr<CNodeT>(new (std::nothrow) CNodeT); | auto merge_node = std::unique_ptr<CNodeT>(new (std::nothrow) CNodeT); | ||||
| @@ -154,9 +123,9 @@ STATUS SingleSwitchPass::InsertMerge() { | |||||
| } | } | ||||
| // double merge inputs to contain the outputs of body node | // double merge inputs to contain the outputs of body node | ||||
| for (auto &out_index : origin_switch_output_tensor_indices_) { | |||||
| auto &switch_out_tensor = graph_->allTensors.at(out_index); | |||||
| auto tensor = NewTensor(switch_out_tensor); | |||||
| for (auto &index : cond_partial_node_->inputIndex) { | |||||
| auto &in_tensor = graph_->allTensors.at(index); | |||||
| auto tensor = NewTensor(in_tensor); | |||||
| graph_->allTensors.push_back(std::move(tensor)); | graph_->allTensors.push_back(std::move(tensor)); | ||||
| merge_node->inputIndex.push_back(graph_->allTensors.size() - 1); | merge_node->inputIndex.push_back(graph_->allTensors.size() - 1); | ||||
| } | } | ||||
| @@ -266,10 +235,6 @@ STATUS SingleSwitchPass::Init() { | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| if (switch_node_->inputIndex.size() == kSwitchMinInputSize) { | |||||
| return RET_OK; | |||||
| } | |||||
| if (switch_node_->inputIndex.size() < kSwitchMinInputSize) { | if (switch_node_->inputIndex.size() < kSwitchMinInputSize) { | ||||
| MS_LOG(ERROR) << "switch node: " << switch_node_->name | MS_LOG(ERROR) << "switch node: " << switch_node_->name | ||||
| << " 's input size is not right, size: " << switch_node_->inputIndex.size(); | << " 's input size is not right, size: " << switch_node_->inputIndex.size(); | ||||
| @@ -297,10 +262,6 @@ STATUS SingleSwitchPass::Init() { | |||||
| } | } | ||||
| } | } | ||||
| if (cond_partial_node_->primitive->value.type != PrimitiveType_Partial || | |||||
| body_partial_node_->primitive->value.type != PrimitiveType_Partial) { | |||||
| return RET_OK; | |||||
| } | |||||
| // get cond_graph_nodes_ | // get cond_graph_nodes_ | ||||
| cond_subgraph_index_ = cond_partial_node_->primitive->value.AsPartial()->subGraphIndex; | cond_subgraph_index_ = cond_partial_node_->primitive->value.AsPartial()->subGraphIndex; | ||||
| auto cond_node_indices = graph_->subGraph.at(cond_subgraph_index_)->nodeIndices; | auto cond_node_indices = graph_->subGraph.at(cond_subgraph_index_)->nodeIndices; | ||||
| @@ -330,17 +291,36 @@ STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schem | |||||
| return RET_INPUT_PARAM_INVALID; | return RET_INPUT_PARAM_INVALID; | ||||
| } | } | ||||
| auto &partial_inputs = partial_node->inputIndex; | auto &partial_inputs = partial_node->inputIndex; | ||||
| auto &subgraph_inputs = graph_->subGraph.at(subgraph_index)->inputIndices; | |||||
| auto &subgraph = graph_->subGraph.at(subgraph_index); | |||||
| auto &subgraph_inputs = subgraph->inputIndices; | |||||
| std::map<int, int> subgraph_input_map; | std::map<int, int> subgraph_input_map; | ||||
| std::vector<int> new_subgraph_inputs{}; | |||||
| std::vector<std::pair<int, int>> tmp_inputs_order{}; | |||||
| for (unsigned int &subgraph_input : subgraph_inputs) { | for (unsigned int &subgraph_input : subgraph_inputs) { | ||||
| auto &tensor = graph_->allTensors.at(subgraph_input); | auto &tensor = graph_->allTensors.at(subgraph_input); | ||||
| // get parameter input index k. subgraph name + “_input_" + "k" | |||||
| char k = tensor->name[graph_->subGraph.at(subgraph_index)->name.size() + 7]; | |||||
| int partial_idx = k - '0'; | |||||
| if (tensor->name.size() < subgraph->name.size() + 8) { | |||||
| MS_LOG(ERROR) << "tensor name: " << tensor->name << " not right."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| int partial_idx = -1; | |||||
| if (tensor->name.find("_input_") != std::string::npos) { | |||||
| // get parameter input index k. subgraph name + “_input_" + "k" | |||||
| auto pos = subgraph->name.size() + sizeof("_input_"); | |||||
| auto pos2 = tensor->name.find('_', pos); | |||||
| auto idx_str = tensor->name.substr(pos - 1, pos2); | |||||
| partial_idx = std::stoi(idx_str); | |||||
| } | |||||
| if (tensor->name.find("_output_") != std::string::npos) { | |||||
| // get parameter input index k. subgraph name + “_output_" + "k" | |||||
| auto pos = subgraph->name.size() + sizeof("_output_"); | |||||
| auto pos2 = tensor->name.find('_', pos); | |||||
| auto idx_str = tensor->name.substr(pos - 1, pos2); | |||||
| partial_idx = std::stoi(idx_str); | |||||
| } | |||||
| subgraph_input_map.insert(std::pair<int, int>{subgraph_input, partial_inputs[partial_idx]}); | subgraph_input_map.insert(std::pair<int, int>{subgraph_input, partial_inputs[partial_idx]}); | ||||
| new_subgraph_inputs.push_back(partial_inputs[partial_idx]); | |||||
| tmp_inputs_order.emplace_back(partial_idx, partial_inputs[partial_idx]); | |||||
| } | } | ||||
| for (auto &subgraph_node : subgraph_nodes) { | for (auto &subgraph_node : subgraph_nodes) { | ||||
| @@ -350,6 +330,13 @@ STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schem | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| std::sort(tmp_inputs_order.begin(), tmp_inputs_order.end(), | |||||
| [](std::pair<int, int> a, std::pair<int, int> b) { return a.first < b.first; }); | |||||
| std::vector<int> new_subgraph_inputs{}; | |||||
| std::transform(tmp_inputs_order.begin(), tmp_inputs_order.end(), std::back_inserter(new_subgraph_inputs), | |||||
| [](std::pair<int, int> iter) { return iter.second; }); | |||||
| subgraph_inputs.assign(new_subgraph_inputs.begin(), new_subgraph_inputs.end()); | subgraph_inputs.assign(new_subgraph_inputs.begin(), new_subgraph_inputs.end()); | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -362,17 +349,28 @@ STATUS SingleSwitchPass::UpdateSubgraphOutput(const size_t &subgraph_index, sche | |||||
| return RET_INPUT_PARAM_INVALID; | return RET_INPUT_PARAM_INVALID; | ||||
| } | } | ||||
| auto &partial_outputs = partial_node->outputIndex; | auto &partial_outputs = partial_node->outputIndex; | ||||
| auto &subgraph_outputs = graph_->subGraph.at(subgraph_index)->outputIndices; | |||||
| auto &subgraph = graph_->subGraph.at(subgraph_index); | |||||
| auto &subgraph_outputs = subgraph->outputIndices; | |||||
| std::map<int, int> subgraph_output_map; | std::map<int, int> subgraph_output_map; | ||||
| std::vector<int> new_subgraph_outputs{}; | |||||
| std::vector<std::pair<int, int>> tmp_outputs_order{}; | |||||
| for (unsigned int &subgraph_output : subgraph_outputs) { | for (unsigned int &subgraph_output : subgraph_outputs) { | ||||
| auto &tensor = graph_->allTensors.at(subgraph_output); | |||||
| // get parameter input index k. subgraph name + “_output_" + "k" | |||||
| char k = tensor->name[graph_->subGraph.at(subgraph_index)->name.size() + 8]; | |||||
| int partial_idx = k - '0'; | |||||
| subgraph_output_map.insert(std::pair<int, int>{subgraph_output, partial_outputs[partial_idx]}); | |||||
| new_subgraph_outputs.push_back(partial_outputs[partial_idx]); | |||||
| for (auto &node : subgraph_nodes) { | |||||
| if (IsContain(node->outputIndex, subgraph_output)) { | |||||
| int partial_idx = -1; | |||||
| if (node->name == "LogicalAnd") { | |||||
| partial_idx = 0; | |||||
| } else { | |||||
| // get parameter input index k. subgraph name + “_output_" + "k" | |||||
| auto pos = subgraph->name.size() + sizeof("_output_"); | |||||
| auto pos2 = node->name.find('_', pos); | |||||
| auto idx_str = node->name.substr(pos - 1, pos2); | |||||
| partial_idx = std::stoi(idx_str); | |||||
| } | |||||
| subgraph_output_map.insert(std::pair<int, int>{subgraph_output, partial_outputs[partial_idx]}); | |||||
| tmp_outputs_order.emplace_back(partial_idx, partial_outputs[partial_idx]); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| for (auto &subgraph_node : subgraph_nodes) { | for (auto &subgraph_node : subgraph_nodes) { | ||||
| @@ -382,6 +380,10 @@ STATUS SingleSwitchPass::UpdateSubgraphOutput(const size_t &subgraph_index, sche | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| std::vector<int> new_subgraph_outputs{}; | |||||
| std::transform(tmp_outputs_order.begin(), tmp_outputs_order.end(), std::back_inserter(new_subgraph_outputs), | |||||
| [](std::pair<int, int> iter) { return iter.second; }); | |||||
| subgraph_outputs.assign(new_subgraph_outputs.begin(), new_subgraph_outputs.end()); | subgraph_outputs.assign(new_subgraph_outputs.begin(), new_subgraph_outputs.end()); | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -416,102 +418,6 @@ STATUS SingleSwitchPass::ConcatBodySubgraphInputAndOutput() { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| STATUS SingleSwitchPass::ConvertSwitchToSelect() { | |||||
| MS_ASSERT(switch_node_->inputIndex.size() >= 3); | |||||
| MS_ASSERT(switch_node_->inputIndex.size() % 2 != 0); | |||||
| MS_ASSERT(switch_node_->outputIndex.size() * 2 + 1 == switch_node_->inputIndex.size()); | |||||
| auto bool_index = switch_node_->inputIndex.front(); | |||||
| // insert switch node1 | |||||
| auto switch_node1 = std::make_unique<CNodeT>(); | |||||
| switch_node1->name = switch_node_->name + "-Switch-1"; | |||||
| switch_node1->primitive = std::make_unique<PrimitiveT>(); | |||||
| switch_node1->primitive->value.type = PrimitiveType_Switch; | |||||
| switch_node1->primitive->value.value = new (std::nothrow) SwitchT(); | |||||
| switch_node1->inputIndex = {bool_index}; | |||||
| std::vector<int> part_one_input_index( | |||||
| switch_node_->inputIndex.begin() + 1, | |||||
| switch_node_->inputIndex.begin() + 1 + (switch_node_->inputIndex.size() - 1) / 2); | |||||
| switch_node1->inputIndex.insert(switch_node1->inputIndex.end(), part_one_input_index.begin(), | |||||
| part_one_input_index.end()); | |||||
| std::vector<std::unique_ptr<TensorT>> switch_output_tensors1(part_one_input_index.size() * 2); | |||||
| std::vector<int> switch_output_indexes1(part_one_input_index.size() * 2); | |||||
| int i = 0; | |||||
| for (const auto &input_index : part_one_input_index) { | |||||
| auto &switch_in_tensor = graph_->allTensors.at(input_index); | |||||
| auto tensor1 = NewTensor(switch_in_tensor); | |||||
| auto tensor2 = NewTensor(switch_in_tensor); | |||||
| switch_output_tensors1[i] = std::move(tensor1); | |||||
| switch_output_tensors1[part_one_input_index.size() + i] = std::move(tensor2); | |||||
| switch_output_indexes1[i] = graph_->allTensors.size() - 1 + i; | |||||
| switch_output_indexes1[part_one_input_index.size() + i] = | |||||
| graph_->allTensors.size() - 1 + i + part_one_input_index.size(); | |||||
| i++; | |||||
| } | |||||
| for (auto &tensor : switch_output_tensors1) { | |||||
| graph_->allTensors.emplace_back(std::move(tensor)); | |||||
| } | |||||
| switch_node1->outputIndex.insert(switch_node1->outputIndex.begin(), switch_output_indexes1.begin(), | |||||
| switch_output_indexes1.end()); | |||||
| // insert switch node2 | |||||
| auto switch_node2 = std::make_unique<CNodeT>(); | |||||
| switch_node2->name = switch_node_->name + "-Switch-1"; | |||||
| switch_node2->primitive = std::make_unique<PrimitiveT>(); | |||||
| switch_node2->primitive->value.type = PrimitiveType_Switch; | |||||
| switch_node2->primitive->value.value = new (std::nothrow) SwitchT(); | |||||
| switch_node2->inputIndex = {bool_index}; | |||||
| std::vector<int> part_two_input_index( | |||||
| switch_node_->inputIndex.begin() + 1 + (switch_node_->inputIndex.size() - 1) / 2, switch_node_->inputIndex.end()); | |||||
| switch_node2->inputIndex.insert(switch_node2->inputIndex.end(), part_two_input_index.begin(), | |||||
| part_two_input_index.end()); | |||||
| std::vector<std::unique_ptr<TensorT>> switch_output_tensors2(part_two_input_index.size() * 2); | |||||
| std::vector<int> switch_output_indexes2(part_two_input_index.size() * 2); | |||||
| i = 0; | |||||
| for (const auto &input_index : part_two_input_index) { | |||||
| auto &switch_in_tensor = graph_->allTensors.at(input_index); | |||||
| auto tensor1 = NewTensor(switch_in_tensor); | |||||
| auto tensor2 = NewTensor(switch_in_tensor); | |||||
| switch_output_tensors2[i] = std::move(tensor1); | |||||
| switch_output_tensors2[part_two_input_index.size() + i] = std::move(tensor2); | |||||
| switch_output_indexes2[i] = graph_->allTensors.size() - 1 + i; | |||||
| switch_output_indexes2[part_two_input_index.size() + i] = | |||||
| graph_->allTensors.size() - 1 + i + part_two_input_index.size(); | |||||
| i++; | |||||
| } | |||||
| for (auto &tensor : switch_output_tensors2) { | |||||
| graph_->allTensors.emplace_back(std::move(tensor)); | |||||
| } | |||||
| switch_node2->outputIndex.insert(switch_node2->outputIndex.begin(), switch_output_indexes2.begin(), | |||||
| switch_output_indexes2.end()); | |||||
| // insert merge | |||||
| auto merge_node = std::make_unique<CNodeT>(); | |||||
| merge_node->name = switch_node_->name + "-Merge"; | |||||
| merge_node->primitive = std::make_unique<PrimitiveT>(); | |||||
| merge_node->primitive->value.type = PrimitiveType_Merge; | |||||
| merge_node->primitive->value.value = new (std::nothrow) MergeT(); | |||||
| std::vector<int> merge_input_indexes(switch_node_->outputIndex.size() * 2); | |||||
| for (i = 0; i < switch_node_->outputIndex.size(); i++) { | |||||
| merge_input_indexes[i] = switch_output_indexes1[i]; | |||||
| merge_input_indexes[i + switch_node_->outputIndex.size()] = | |||||
| switch_output_indexes2[i + switch_node_->outputIndex.size()]; | |||||
| merge_node->outputIndex.emplace_back(switch_node_->outputIndex.at(i)); | |||||
| } | |||||
| merge_node->inputIndex.insert(merge_node->inputIndex.end(), merge_input_indexes.begin(), merge_input_indexes.end()); | |||||
| graph_->nodes.emplace_back(std::move(switch_node1)); | |||||
| graph_->subGraph.at(this_subgraph_index_)->nodeIndices.push_back(graph_->nodes.size() - 1); | |||||
| graph_->nodes.emplace_back(std::move(switch_node2)); | |||||
| graph_->subGraph.at(this_subgraph_index_)->nodeIndices.push_back(graph_->nodes.size() - 1); | |||||
| graph_->nodes.emplace_back(std::move(merge_node)); | |||||
| graph_->subGraph.at(this_subgraph_index_)->nodeIndices.push_back(graph_->nodes.size() - 1); | |||||
| RemoveUselessNode(switch_node_, graph_); | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS SingleSwitchPass::Run() { | STATUS SingleSwitchPass::Run() { | ||||
| int ret = Init(); | int ret = Init(); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| @@ -519,24 +425,6 @@ STATUS SingleSwitchPass::Run() { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| if (switch_node_->inputIndex.size() == kSwitchMinInputSize) { | |||||
| return RET_OK; | |||||
| } | |||||
| if (cond_partial_node_->primitive->value.type != PrimitiveType_Partial || | |||||
| body_partial_node_->primitive->value.type != PrimitiveType_Partial) { | |||||
| ret = ConvertSwitchToSelect(); | |||||
| return ret; | |||||
| } | |||||
| if (IsLoop()) { | |||||
| ret = MoveMaxIterationToCond(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "MoveMaxIterationToCond failed, ret: " << ret; | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| ret = DoubleSwitchOutput(); | ret = DoubleSwitchOutput(); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "DoubleSwitchOutput failed, ret: " << ret; | MS_LOG(ERROR) << "DoubleSwitchOutput failed, ret: " << ret; | ||||
| @@ -45,11 +45,9 @@ class SingleSwitchPass { | |||||
| STATUS Init(); | STATUS Init(); | ||||
| size_t InitThisGraphIndex(); | size_t InitThisGraphIndex(); | ||||
| STATUS DoubleSwitchOutput(); | STATUS DoubleSwitchOutput(); | ||||
| STATUS MoveMaxIterationToCond(); | |||||
| STATUS UpdateSwitchUser(); | STATUS UpdateSwitchUser(); | ||||
| STATUS ConcatCondSubgraphInputAndOutput(); | STATUS ConcatCondSubgraphInputAndOutput(); | ||||
| STATUS ConcatBodySubgraphInputAndOutput(); | STATUS ConcatBodySubgraphInputAndOutput(); | ||||
| STATUS ConvertSwitchToSelect(); | |||||
| bool IsLoop(); | bool IsLoop(); | ||||
| STATUS InsertMerge(); | STATUS InsertMerge(); | ||||
| STATUS UpdateSubgraphInput(const size_t &subgraph_index, schema::CNodeT *partial_node, | STATUS UpdateSubgraphInput(const size_t &subgraph_index, schema::CNodeT *partial_node, | ||||
| @@ -27,56 +27,71 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| STATUS TopologicalSortPass::Run(schema::MetaGraphT *graph) { | STATUS TopologicalSortPass::Run(schema::MetaGraphT *graph) { | ||||
| MS_ASSERT(graph != nullptr); | MS_ASSERT(graph != nullptr); | ||||
| std::vector<std::unique_ptr<schema::CNodeT>> newNodes; | |||||
| std::vector<size_t> sinkedTensorIdxes; | |||||
| // put all const tensor index into sinkedTensorIdxes | |||||
| std::vector<std::unique_ptr<schema::CNodeT>> new_nodes; | |||||
| std::vector<size_t> sinked_tensor_idxes; | |||||
| // put all const tensor index into sinked_tensor_idxes | |||||
| for (size_t i = 0; i < graph->allTensors.size(); i++) { | for (size_t i = 0; i < graph->allTensors.size(); i++) { | ||||
| if (graph->allTensors.at(i)->nodeType == schema::NodeType::NodeType_ValueNode) { | if (graph->allTensors.at(i)->nodeType == schema::NodeType::NodeType_ValueNode) { | ||||
| sinkedTensorIdxes.insert(sinkedTensorIdxes.end(), i); | |||||
| sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), i); | |||||
| } | } | ||||
| } | } | ||||
| auto &oldNodes = graph->nodes; | |||||
| std::queue<std::unique_ptr<schema::CNodeT>> opQueue; | |||||
| // put all non depend node into queue | |||||
| for (auto &node : graph->nodes) { | |||||
| if (IsNodeNonDepend(node, sinkedTensorIdxes)) { | |||||
| sinkedTensorIdxes.insert(sinkedTensorIdxes.end(), node->outputIndex.begin(), node->outputIndex.end()); | |||||
| opQueue.push(std::move(node)); | |||||
| auto &old_nodes = graph->nodes; | |||||
| std::queue<std::unique_ptr<schema::CNodeT>> op_queue; | |||||
| // put all none depend node into queue | |||||
| for (size_t i = 0; i < graph->subGraph.size(); i++) { | |||||
| std::vector<unsigned int> new_subgraph_node_indices = {}; | |||||
| auto subgraph_node_indices = graph->subGraph[i]->nodeIndices; | |||||
| for (size_t j = 0; j < subgraph_node_indices.size(); j++) { | |||||
| auto &node = old_nodes[subgraph_node_indices[j]]; | |||||
| if (IsNodeNonDepend(node, sinked_tensor_idxes)) { | |||||
| sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), node->outputIndex.begin(), node->outputIndex.end()); | |||||
| op_queue.push(std::move(node)); | |||||
| } | |||||
| } | } | ||||
| } | |||||
| // bfs | |||||
| while (!opQueue.empty()) { | |||||
| auto &node = opQueue.front(); | |||||
| auto postNodeIdxes = GetOutputNodeIdx(*graph, *(node.get())); | |||||
| for (auto postNodeIdx : postNodeIdxes) { | |||||
| auto &postNode = oldNodes.at(postNodeIdx); | |||||
| // check if postNode is non-depended | |||||
| if (IsNodeNonDepend(postNode, sinkedTensorIdxes)) { | |||||
| sinkedTensorIdxes.insert(sinkedTensorIdxes.end(), postNode->outputIndex.begin(), postNode->outputIndex.end()); | |||||
| opQueue.push(std::move(postNode)); | |||||
| while (!op_queue.empty()) { | |||||
| auto &node = op_queue.front(); | |||||
| auto post_node_idxes = GetOutputNodeIdx(*graph, *(node.get())); | |||||
| for (auto post_node_idx : post_node_idxes) { | |||||
| if (IsContain(subgraph_node_indices, (unsigned int)(post_node_idx))) { | |||||
| auto &post_node = old_nodes.at(post_node_idx); | |||||
| // check if post_node is non-depended | |||||
| if (IsNodeNonDepend(post_node, sinked_tensor_idxes)) { | |||||
| sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), post_node->outputIndex.begin(), | |||||
| post_node->outputIndex.end()); | |||||
| op_queue.push(std::move(post_node)); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| new_nodes.emplace_back(std::move(node)); | |||||
| new_subgraph_node_indices.push_back(new_nodes.size() - 1); | |||||
| op_queue.pop(); | |||||
| } | } | ||||
| newNodes.emplace_back(std::move(node)); | |||||
| opQueue.pop(); | |||||
| graph->subGraph[i]->nodeIndices.swap(new_subgraph_node_indices); | |||||
| } | } | ||||
| if (newNodes.size() != oldNodes.size()) { | |||||
| MS_LOG(ERROR) << "Unknow error in TopologicalSort, oldNodesSize: " << oldNodes.size() | |||||
| << ", newNodesSize: " << newNodes.size(); | |||||
| if (new_nodes.size() != old_nodes.size()) { | |||||
| MS_LOG(ERROR) << "Unknow error in TopologicalSort, old_nodes size: " << old_nodes.size() | |||||
| << ", new_nodes size: " << new_nodes.size(); | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| graph->nodes.swap(newNodes); | |||||
| graph->nodes.swap(new_nodes); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| bool TopologicalSortPass::IsNodeNonDepend(const std::unique_ptr<schema::CNodeT> &node, | bool TopologicalSortPass::IsNodeNonDepend(const std::unique_ptr<schema::CNodeT> &node, | ||||
| const std::vector<size_t> &sinkedTensorIdxes) { | |||||
| const std::vector<size_t> &sinked_tensor_idxes) { | |||||
| MS_ASSERT(node != nullptr); | MS_ASSERT(node != nullptr); | ||||
| for (auto inputIdx : node->inputIndex) { | |||||
| if (!IsContain(sinkedTensorIdxes, size_t(inputIdx))) { | |||||
| return false; | |||||
| } | |||||
| if (node->primitive->value.type == schema::PrimitiveType_Merge) { | |||||
| auto node_input_index = node->inputIndex; | |||||
| MS_ASSERT(node_input_index.size() % 2 == 0); | |||||
| return std::all_of(node_input_index.begin(), node_input_index.begin() + node_input_index.size() / 2, | |||||
| [&](size_t input_idx) { return IsContain(sinked_tensor_idxes, input_idx); }) || | |||||
| std::all_of(node_input_index.begin() + node_input_index.size() / 2, node_input_index.end(), | |||||
| [&](size_t input_idx) { return IsContain(sinked_tensor_idxes, input_idx); }); | |||||
| } else { | |||||
| return std::all_of(node->inputIndex.begin(), node->inputIndex.end(), | |||||
| [&](size_t input_idx) { return IsContain(sinked_tensor_idxes, size_t(input_idx)); }); | |||||
| } | } | ||||
| return true; | |||||
| } | } | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -54,6 +54,7 @@ FuncGraphPtr CaffeModelParser::Parse(const std::string &model_file, const std::s | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| func_graph_ptr_->set_attr("graph_name", MakeValue("main_graph")); | |||||
| return func_graph_ptr_; | return func_graph_ptr_; | ||||
| } | } | ||||
| @@ -80,6 +80,7 @@ FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::st | |||||
| MS_LOG(ERROR) << "convert graph outputs failed."; | MS_LOG(ERROR) << "convert graph outputs failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| func_graph_ptr_->set_attr("graph_name", MakeValue("main_graph")); | |||||
| return func_graph_ptr_; | return func_graph_ptr_; | ||||
| } | } | ||||
| @@ -61,7 +61,7 @@ STATUS TFArithmeticParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| } | } | ||||
| primitive->value.type = schema::PrimitiveType_Mul; | primitive->value.type = schema::PrimitiveType_Mul; | ||||
| primitive->value.value = attr.release(); | primitive->value.value = attr.release(); | ||||
| } else if (tf_op.op() == "Div") { | |||||
| } else if (tf_op.op() == "Div" || tf_op.op() == "RealDiv") { | |||||
| auto attr = std::make_unique<schema::DivT>(); | auto attr = std::make_unique<schema::DivT>(); | ||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new attr failed"; | MS_LOG(ERROR) << "new attr failed"; | ||||
| @@ -154,6 +154,7 @@ TFNodeRegistrar g_tfAddV2Parser("AddV2", new TFArithmeticParser()); | |||||
| TFNodeRegistrar g_tfSubParser("Sub", new TFArithmeticParser()); | TFNodeRegistrar g_tfSubParser("Sub", new TFArithmeticParser()); | ||||
| TFNodeRegistrar g_tfMulParser("Mul", new TFArithmeticParser()); | TFNodeRegistrar g_tfMulParser("Mul", new TFArithmeticParser()); | ||||
| TFNodeRegistrar g_tfDivParser("Div", new TFArithmeticParser()); | TFNodeRegistrar g_tfDivParser("Div", new TFArithmeticParser()); | ||||
| TFNodeRegistrar g_tfRealDivParser("RealDiv", new TFArithmeticParser()); | |||||
| TFNodeRegistrar g_tfMaximumParser("Maximum", new TFArithmeticParser()); | TFNodeRegistrar g_tfMaximumParser("Maximum", new TFArithmeticParser()); | ||||
| TFNodeRegistrar g_tfMinimumParser("Minimum", new TFArithmeticParser()); | TFNodeRegistrar g_tfMinimumParser("Minimum", new TFArithmeticParser()); | ||||
| TFNodeRegistrar g_tfGreaterParser("Greater", new TFArithmeticParser()); | TFNodeRegistrar g_tfGreaterParser("Greater", new TFArithmeticParser()); | ||||
| @@ -37,10 +37,11 @@ static const std::vector<schema::PrimitiveType> tensorListOutputOpList = { | |||||
| AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map<std::string, AnfNodePtr> &anf_node_map) { | AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map<std::string, AnfNodePtr> &anf_node_map) { | ||||
| AnfNodePtr ret = nullptr; | AnfNodePtr ret = nullptr; | ||||
| if (anf_node_map.find(name) != anf_node_map.end()) { | |||||
| ret = anf_node_map.at(name); | |||||
| auto flat_anf_name = TensorFlowUtils::GetFlattenNodeName(name); | |||||
| if (anf_node_map.find(flat_anf_name) != anf_node_map.end()) { | |||||
| ret = anf_node_map.at(flat_anf_name); | |||||
| } else if (anf_node_map.find(name + ":0") != anf_node_map.end()) { | } else if (anf_node_map.find(name + ":0") != anf_node_map.end()) { | ||||
| ret = anf_node_map.at(name + ":0"); | |||||
| ret = anf_node_map.at(flat_anf_name + ":0"); | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -212,6 +213,17 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| return status; | return status; | ||||
| } | } | ||||
| } else if (type == kObjectTypeString) { | |||||
| auto tensor_data = new (std::nothrow) string; | |||||
| if (tensor_proto.string_val_size() == 1) { | |||||
| string value = tensor_proto.string_val(0); | |||||
| *tensor_data = value; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "string size bigger than one, not support."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| tensor_size = (*tensor_data).size(); | |||||
| param_value->SetTensorData(tensor_data, tensor_size); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupport dataType: " << type; | MS_LOG(ERROR) << "Unsupport dataType: " << type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -318,6 +330,7 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| anf_root_graph_->set_attr("graph_name", MakeValue("main_graph")); | |||||
| for (int i = 0; i < tf_root_graph_->node_size(); i++) { | for (int i = 0; i < tf_root_graph_->node_size(); i++) { | ||||
| auto &node_def = tf_root_graph_->node(i); | auto &node_def = tf_root_graph_->node(i); | ||||
| @@ -364,7 +377,6 @@ STATUS TFModelParser::ConvertSubgraph() { | |||||
| std::map<CNodePtr, FuncGraphPtr> while_cond_map; | std::map<CNodePtr, FuncGraphPtr> while_cond_map; | ||||
| std::map<CNodePtr, FuncGraphPtr> while_body_map; | std::map<CNodePtr, FuncGraphPtr> while_body_map; | ||||
| for (int i = 0; i < subgraph_size; i++) { | for (int i = 0; i < subgraph_size; i++) { | ||||
| std::vector<ParameterPtr> sub_graph_inputs; | |||||
| auto &tf_sub_fuction = graph_def_liarary.function(i); | auto &tf_sub_fuction = graph_def_liarary.function(i); | ||||
| auto &tf_sub_signature = tf_sub_fuction.signature(); | auto &tf_sub_signature = tf_sub_fuction.signature(); | ||||
| auto input_arg_size = tf_sub_signature.input_arg_size(); | auto input_arg_size = tf_sub_signature.input_arg_size(); | ||||
| @@ -381,13 +393,17 @@ STATUS TFModelParser::ConvertSubgraph() { | |||||
| } | } | ||||
| FuncGraphPtr sub_func_graph = std::make_shared<FuncGraph>(); | FuncGraphPtr sub_func_graph = std::make_shared<FuncGraph>(); | ||||
| sub_func_graph->set_attr("graph_name", MakeValue(sub_graph_name)); | |||||
| std::unordered_map<std::string, AnfNodePtr> anf_sub_node_map; | std::unordered_map<std::string, AnfNodePtr> anf_sub_node_map; | ||||
| // convert sub graph inputs | // convert sub graph inputs | ||||
| std::vector<ParameterPtr> sub_graph_inputs; | |||||
| for (int j = 0; j < input_arg_size; j++) { | for (int j = 0; j < input_arg_size; j++) { | ||||
| auto &input_arg = tf_sub_signature.input_arg(j); | auto &input_arg = tf_sub_signature.input_arg(j); | ||||
| auto paramter = sub_func_graph->add_parameter(); | auto paramter = sub_func_graph->add_parameter(); | ||||
| paramter->set_name(input_arg.name()); | paramter->set_name(input_arg.name()); | ||||
| anf_sub_node_map[input_arg.name()] = paramter; | anf_sub_node_map[input_arg.name()] = paramter; | ||||
| auto root_while_inputs = while_cnode->inputs(); | |||||
| paramter->set_abstract(root_while_inputs[j + 1]->abstract()); | |||||
| sub_graph_inputs.emplace_back(paramter); | sub_graph_inputs.emplace_back(paramter); | ||||
| } | } | ||||
| std::map<std::string, const tensorflow::NodeDef *> tf_sub_node_map; | std::map<std::string, const tensorflow::NodeDef *> tf_sub_node_map; | ||||
| @@ -452,8 +468,19 @@ STATUS TFModelParser::ConvertSubgraph() { | |||||
| } | } | ||||
| // hardcode subgraph inputs name | // hardcode subgraph inputs name | ||||
| for (size_t j = 0; j < sub_graph_inputs.size(); j++) { | for (size_t j = 0; j < sub_graph_inputs.size(); j++) { | ||||
| sub_graph_inputs[j]->set_name("graph" + std::to_string(i) + "_input_" + std::to_string(j) + "parameter"); | |||||
| sub_graph_inputs[j]->set_name(sub_graph_name + "_input_" + std::to_string(j) + "_parameter"); | |||||
| } | |||||
| // hardcode subgraph outputs name | |||||
| for (size_t j = 1; j < sub_output_nodes.size(); j++) { | |||||
| if (utils::isa<CNodePtr>(sub_output_nodes[j])) { | |||||
| sub_output_nodes[j]->cast<CNodePtr>()->set_fullname_with_scope(sub_graph_name + "_output_" + | |||||
| std::to_string(j - 1) + "_cnode"); | |||||
| } else if (utils::isa<ParameterPtr>(sub_output_nodes[j])) { | |||||
| sub_output_nodes[j]->cast<ParameterPtr>()->set_name(sub_graph_name + "_output_" + std::to_string(j - 1) + | |||||
| "_parameter"); | |||||
| } | |||||
| } | } | ||||
| MS_LOG(INFO) << "parse subgraph end:" << sub_graph_name; | MS_LOG(INFO) << "parse subgraph end:" << sub_graph_name; | ||||
| } | } | ||||
| auto status = WhileNodePostProcess(while_cond_map, while_body_map); | auto status = WhileNodePostProcess(while_cond_map, while_body_map); | ||||
| @@ -469,9 +496,8 @@ STATUS TFModelParser::WhileNodePostProcess(const std::map<CNodePtr, FuncGraphPtr | |||||
| MS_LOG(ERROR) << "while cond body size error"; | MS_LOG(ERROR) << "while cond body size error"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| std::vector<FuncGraphPtr> roots = {anf_root_graph_}; | |||||
| auto root_func_manager = std::make_shared<FuncGraphManager>(roots); | |||||
| anf_root_graph_->set_manager(root_func_manager); | |||||
| static auto root_func_manager = Manage(anf_root_graph_); | |||||
| for (auto &kv : while_cond_map) { | for (auto &kv : while_cond_map) { | ||||
| auto while_node = kv.first; | auto while_node = kv.first; | ||||
| auto &cond_sub_graph = kv.second; | auto &cond_sub_graph = kv.second; | ||||
| @@ -633,6 +659,11 @@ STATUS TFModelParser::ConvertRootGraphOutputs() { | |||||
| for (auto &pair : tf_root_graph_nodes_) { | for (auto &pair : tf_root_graph_nodes_) { | ||||
| for (int i = 0; i < pair.second->input_size(); ++i) { | for (int i = 0; i < pair.second->input_size(); ++i) { | ||||
| all_node_inputs.insert(TensorFlowUtils::GetNodeName(pair.second->input(i))); | all_node_inputs.insert(TensorFlowUtils::GetNodeName(pair.second->input(i))); | ||||
| auto input_name = pair.second->input(i); | |||||
| if (input_name[0] == '^') { | |||||
| input_name.erase(0, 1); | |||||
| } | |||||
| all_node_inputs.insert(input_name); | |||||
| } | } | ||||
| } | } | ||||
| for (auto &pair : tf_root_graph_nodes_) { | for (auto &pair : tf_root_graph_nodes_) { | ||||
| @@ -644,7 +675,7 @@ STATUS TFModelParser::ConvertRootGraphOutputs() { | |||||
| auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes_); | auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes_); | ||||
| auto anf_node = GetAnfNode(origin_name, anf_root_node_map_); | auto anf_node = GetAnfNode(origin_name, anf_root_node_map_); | ||||
| if (anf_node == nullptr) { | if (anf_node == nullptr) { | ||||
| MS_LOG(ERROR) << "can't find anf node"; | |||||
| MS_LOG(ERROR) << "can't find anf node: " << origin_name; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| output_nodes.push_back(anf_node); | output_nodes.push_back(anf_node); | ||||
| @@ -0,0 +1,78 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tf/tf_ragged_range_parser.h" | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TFRaggedRangeParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||||
| MS_LOG(INFO) << "TF RaggedRangeParser"; | |||||
| if (primitiveC == nullptr || output_size == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto attr = std::make_unique<schema::RangeT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| tensorflow::AttrValue attr_value; | |||||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "starts", &attr_value)) { | |||||
| MS_LOG(ERROR) << "The starts attr should be specified"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->start = static_cast<int32_t>(attr_value.i()); | |||||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "limits", &attr_value)) { | |||||
| MS_LOG(ERROR) << "The limits attr should be specified"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->limit = static_cast<int32_t>(attr_value.i()); | |||||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "deltas", &attr_value)) { | |||||
| MS_LOG(ERROR) << "The deltas attr should be specified"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->delta = static_cast<int32_t>(attr_value.i()); | |||||
| primitive->value.type = schema::PrimitiveType_Range; | |||||
| primitive->value.value = attr.release(); | |||||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||||
| if (*primitiveC == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| *output_size = 1; | |||||
| auto status = AddOpInput(tf_op, 0, inputs); | |||||
| return status; | |||||
| } | |||||
| TFNodeRegistrar g_tfRaggedRangeParser("RaggedRange", new TFRaggedRangeParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RAGFED_RANGE_PARSER_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RAGGED_RANGE_PARSER_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TFRaggedRangeParser : public TFNodeParser { | |||||
| public: | |||||
| TFRaggedRangeParser() = default; | |||||
| ~TFRaggedRangeParser() override = default; | |||||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ROUND_PARSER_H_ | |||||
| @@ -0,0 +1,78 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tf/tf_range_parser.h" | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TFRangeParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC, | |||||
| std::vector<std::string> *inputs, int *output_size) { | |||||
| MS_LOG(INFO) << "TF RangeParser"; | |||||
| if (primitiveC == nullptr || output_size == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto attr = std::make_unique<schema::RangeT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| tensorflow::AttrValue attr_value; | |||||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "start", &attr_value)) { | |||||
| MS_LOG(ERROR) << "The start attr should be specified"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->start = static_cast<int32_t>(attr_value.i()); | |||||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "limit", &attr_value)) { | |||||
| MS_LOG(ERROR) << "The limit attr should be specified"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->limit = static_cast<int32_t>(attr_value.i()); | |||||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "delta", &attr_value)) { | |||||
| MS_LOG(ERROR) << "The delta attr should be specified"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->delta = static_cast<int32_t>(attr_value.i()); | |||||
| primitive->value.type = schema::PrimitiveType_Range; | |||||
| primitive->value.value = attr.release(); | |||||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||||
| if (*primitiveC == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| *output_size = 1; | |||||
| auto status = AddOpInput(tf_op, 0, inputs); | |||||
| return status; | |||||
| } | |||||
| TFNodeRegistrar g_tfRangeParser("Range", new TFRangeParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANGE_PARSER_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANGE_PARSER_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TFRangeParser : public TFNodeParser { | |||||
| public: | |||||
| TFRangeParser() = default; | |||||
| ~TFRangeParser() override = default; | |||||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ROUND_PARSER_H_ | |||||
| @@ -9,7 +9,7 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, software | * Unless required by applicable law or agreed to in writing, software | ||||
| * distributed under the License is distributed on an "AS IS" BASIS, | * distributed under the License is distributed on an "AS IS" BASIS, | ||||
| * WITHOUT WRRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| @@ -18,8 +18,8 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include <string_view> | #include <string_view> | ||||
| #include <unordered_map> | |||||
| #include <regex> | #include <regex> | ||||
| #include <unordered_map> | |||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| #include "schema/inner/model_generated.h" | #include "schema/inner/model_generated.h" | ||||
| @@ -76,6 +76,7 @@ FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std:: | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| func_graph_->set_attr("graph_name", MakeValue("main_graph")); | |||||
| return func_graph_; | return func_graph_; | ||||
| } | } | ||||
| @@ -21,7 +21,22 @@ | |||||
| #include "mindspore/lite/src/ops/primitive_c.h" | #include "mindspore/lite/src/ops/primitive_c.h" | ||||
| #include "tools/anf_importer/import_from_meta_graphT.h" | #include "tools/anf_importer/import_from_meta_graphT.h" | ||||
| using mindspore::lite::RET_INFER_INVALID; | |||||
| namespace mindspore::opt { | namespace mindspore::opt { | ||||
| ParamValueLitePtr NewParamValueLitePtr(lite::Tensor *tensor) { | |||||
| auto para_value_lite = std::make_shared<ParamValueLite>(); | |||||
| if (para_value_lite == nullptr) { | |||||
| MS_LOG(ERROR) << "new ParamValueLite failed"; | |||||
| return nullptr; | |||||
| } | |||||
| para_value_lite->set_tensor_shape(tensor->shape()); | |||||
| para_value_lite->set_tensor_type(tensor->data_type()); | |||||
| para_value_lite->set_format(tensor->format()); | |||||
| return para_value_lite; | |||||
| } | |||||
| abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor) { | abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor) { | ||||
| MS_ASSERT(nullptr != tensor); | MS_ASSERT(nullptr != tensor); | ||||
| std::vector<int> shape(tensor->shape()); | std::vector<int> shape(tensor->shape()); | ||||
| @@ -33,15 +48,30 @@ abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(li | |||||
| MS_LOG(ERROR) << "new AbstractTensor failed"; | MS_LOG(ERROR) << "new AbstractTensor failed"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto new_value = std::make_shared<ParamValueLite>(); | |||||
| if (new_value == nullptr) { | |||||
| auto para_value_lite = NewParamValueLitePtr(tensor); | |||||
| if (para_value_lite == nullptr) { | |||||
| MS_LOG(ERROR) << "new ParamValueLite failed"; | MS_LOG(ERROR) << "new ParamValueLite failed"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| new_value->set_tensor_shape(tensor->shape()); | |||||
| new_value->set_tensor_type(tensor->data_type()); | |||||
| new_value->set_format(tensor->format()); | |||||
| new_abstract->set_value(new_value); | |||||
| if (type_id == kObjectTypeTensorType) { | |||||
| auto tensor_list = dynamic_cast<lite::TensorList *>(tensor); | |||||
| if (tensor_list == nullptr) { | |||||
| MS_LOG(ERROR) << "cast tensor_list failed"; | |||||
| return nullptr; | |||||
| } | |||||
| auto tensor_info = new int[tensor_list->element_shape().size() + 2]; | |||||
| tensor_info[0] = tensor_list->tensors_data_type(); | |||||
| tensor_info[1] = tensor_list->element_shape().size(); | |||||
| for (size_t i = 0; i < tensor_list->element_shape().size(); ++i) { | |||||
| tensor_info[i + 2] = tensor_list->element_shape()[i]; | |||||
| } | |||||
| para_value_lite->set_tensor_addr(tensor_info); | |||||
| para_value_lite->set_tensor_size(tensor_list->element_shape().size() + 2); | |||||
| } | |||||
| new_abstract->set_value(para_value_lite); | |||||
| return new_abstract; | return new_abstract; | ||||
| } | } | ||||
| @@ -121,13 +151,13 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<l | |||||
| } | } | ||||
| if (utils::isa<ValueNodePtr>(cnode->input(i))) { | if (utils::isa<ValueNodePtr>(cnode->input(i))) { | ||||
| MS_LOG(WARNING) << "input is value node"; | |||||
| MS_LOG(WARNING) << cnode->fullname_with_scope() << "'s input[" << i << "] is value node"; | |||||
| continue; | continue; | ||||
| } | } | ||||
| AbstractBasePtr abstract = GetCNodeInputAbstract(cnode, i); | AbstractBasePtr abstract = GetCNodeInputAbstract(cnode, i); | ||||
| if (abstract == nullptr) { | if (abstract == nullptr) { | ||||
| MS_LOG(ERROR) << "Abstract of CNode is nullptr"; | |||||
| MS_LOG(ERROR) << "Abstract of CNode: " << cnode->fullname_with_scope() << " is nullptr"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) { | if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) { | ||||
| @@ -194,7 +224,7 @@ STATUS InferShapePass::GetCNodeOutputTensors(const CNodePtr &cnode, std::vector< | |||||
| MS_ASSERT(output_tensors != nullptr); | MS_ASSERT(output_tensors != nullptr); | ||||
| auto abstract = cnode->abstract(); | auto abstract = cnode->abstract(); | ||||
| if (abstract == nullptr) { | if (abstract == nullptr) { | ||||
| MS_LOG(ERROR) << "abstract is nullptr"; | |||||
| MS_LOG(ERROR) << "node " << cnode->fullname_with_scope() << " abstract is nullptr"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| std::vector<TypeId> types; | std::vector<TypeId> types; | ||||
| @@ -264,7 +294,62 @@ STATUS InferShapePass::SetCNodeAbstract(const std::vector<lite::Tensor *> &outpu | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int InferShapePass::StrIsContain(const std::vector<std::string> &total, const std::string &aim) { | |||||
| for (size_t i = 0; i < total.size(); i++) { | |||||
| if (aim.find(total[i]) != std::string::npos) { | |||||
| return i; | |||||
| } | |||||
| } | |||||
| return -1; | |||||
| } | |||||
| STATUS InferShapePass::SetSubGraphInputsAbstract(const CNodePtr &cnode, const FuncGraphPtr &func_graph) { | |||||
| // hard code construct input parameter name | |||||
| std::vector<std::string> inputs_names{}; | |||||
| for (size_t i = 1; i < cnode->inputs().size(); i++) { | |||||
| inputs_names.emplace_back("_input_" + std::to_string(i - 1) + "_parameter"); | |||||
| } | |||||
| // copy cnode input to func_graph input | |||||
| auto node_list = TopoSort(func_graph->get_return()); | |||||
| for (auto &node : node_list) { | |||||
| if (utils::isa<ParameterPtr>(node)) { | |||||
| auto pos = StrIsContain(inputs_names, node->fullname_with_scope()); | |||||
| if (pos != -1) { | |||||
| auto pnode = utils::cast<ParameterPtr>(node); | |||||
| auto input_pnode = utils::cast<ParameterPtr>(cnode->input(pos + 1)); | |||||
| MS_ASSERT(pnode != nullptr); | |||||
| pnode->set_abstract(input_pnode->abstract()); | |||||
| } | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS InferShapePass::SwitchCNodeInferShape(const CNodePtr &switch_cnode) { | |||||
| auto body_partial_cnode = switch_cnode->input(2)->cast<CNodePtr>(); | |||||
| MS_ASSERT(body_partial_cnode != nullptr); | |||||
| auto body_vnode = body_partial_cnode->input(0)->cast<ValueNodePtr>(); | |||||
| MS_ASSERT(body_vnode != nullptr); | |||||
| auto body_fg = GetValueNode<FuncGraphPtr>(body_vnode); | |||||
| MS_ASSERT(body_fg != nullptr); | |||||
| AbstractBasePtrList abstract_list; | |||||
| auto body_fg_output_cnode = utils::cast<CNodePtr>(body_fg->output()); | |||||
| for (auto &cnode : body_fg_output_cnode->inputs()) { | |||||
| if (!utils::isa<CNodePtr>(cnode) && !utils::isa<ParameterPtr>(cnode)) { | |||||
| continue; | |||||
| } | |||||
| abstract_list.push_back(cnode->abstract()); | |||||
| } | |||||
| switch_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | |||||
| return RET_OK; | |||||
| } | |||||
| bool InferShapePass::Run(const FuncGraphPtr &func_graph) { | bool InferShapePass::Run(const FuncGraphPtr &func_graph) { | ||||
| if (func_graph->has_flag("HasInferShaped")) { | |||||
| return true; | |||||
| } | |||||
| if (fmk_type != lite::converter::FmkType_TF && fmk_type != lite::converter::FmkType_TFLITE) { | if (fmk_type != lite::converter::FmkType_TF && fmk_type != lite::converter::FmkType_TFLITE) { | ||||
| MS_LOG(INFO) << "The framework type of model should be tf/tflite."; | MS_LOG(INFO) << "The framework type of model should be tf/tflite."; | ||||
| return false; | return false; | ||||
| @@ -287,8 +372,14 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) { | |||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| auto origin_primc = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(cnode->input(0)); | auto origin_primc = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(cnode->input(0)); | ||||
| if (origin_primc == nullptr) { | if (origin_primc == nullptr) { | ||||
| MS_LOG(ERROR) << "origin_primc is nullptr"; | |||||
| return false; | |||||
| auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(0)); | |||||
| if (sub_func_graph == nullptr) { | |||||
| MS_LOG(ERROR) << "node " << node->fullname_with_scope() << "'s origin_primc is nullptr"; | |||||
| return false; | |||||
| } else { | |||||
| MS_LOG(WARNING) << "subgraph infer shape invalid."; | |||||
| return RET_INFER_INVALID; | |||||
| } | |||||
| } | } | ||||
| auto origin_primt = origin_primc->primitiveT(); | auto origin_primt = origin_primc->primitiveT(); | ||||
| if (origin_primt == nullptr) { | if (origin_primt == nullptr) { | ||||
| @@ -296,6 +387,15 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto type = GetCNodeType(cnode); | auto type = GetCNodeType(cnode); | ||||
| if (type == schema::PrimitiveType_Switch) { | |||||
| int ret = SwitchCNodeInferShape(cnode); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "PartialCNodeInferShape failed."; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| if ((type == schema::PrimitiveType_TupleGetItem) || | if ((type == schema::PrimitiveType_TupleGetItem) || | ||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| (type == schema::PrimitiveType_Depend) || (type == schema::PrimitiveType_ControlDepend) || | (type == schema::PrimitiveType_Depend) || (type == schema::PrimitiveType_ControlDepend) || | ||||
| @@ -41,6 +41,9 @@ class InferShapePass : public Pass { | |||||
| STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *output_tensors); | STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *output_tensors); | ||||
| STATUS SetParameterAbstract(const ParameterPtr ¶meter); | STATUS SetParameterAbstract(const ParameterPtr ¶meter); | ||||
| STATUS SetCNodeAbstract(const std::vector<lite::Tensor *> &output_tensors, const std::shared_ptr<CNode> &cnode); | STATUS SetCNodeAbstract(const std::vector<lite::Tensor *> &output_tensors, const std::shared_ptr<CNode> &cnode); | ||||
| STATUS SwitchCNodeInferShape(const CNodePtr &cnode); | |||||
| int StrIsContain(const std::vector<std::string> &total, const std::string &aim); | |||||
| int SetSubGraphInputsAbstract(const CNodePtr &cnode, const FuncGraphPtr &func_graph); | |||||
| private: | private: | ||||
| FmkType fmk_type = lite::converter::FmkType_ONNX; | FmkType fmk_type = lite::converter::FmkType_ONNX; | ||||
| @@ -0,0 +1,181 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/optimizer/graph/while_pass.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <algorithm> | |||||
| #include "mindspore/lite/include/errorcode.h" | |||||
| #include "mindspore/lite/src/ops/primitive_c.h" | |||||
| #include "tools/anf_importer/import_from_meta_graphT.h" | |||||
| #include "tools/optimizer/common/gllo_utils.h" | |||||
| #include "src/ops/primitive_c.h" | |||||
| #include "schema/inner/model_generated.h" | |||||
| #include "src/tensor.h" | |||||
| #include "src/common/log_adapter.h" | |||||
| #include "src/ops/switch.h" | |||||
| #include "src/ops/partial.h" | |||||
| namespace mindspore::opt { | |||||
| ValueNodePtr WhilePass::GetSwitchAnfPrim() { | |||||
| auto switch_primitiveT = new (std::nothrow) schema::PrimitiveT; | |||||
| if (switch_primitiveT == nullptr) { | |||||
| MS_LOG(ERROR) << "new switch_primitiveT failed"; | |||||
| return nullptr; | |||||
| } | |||||
| switch_primitiveT->value.type = schema::PrimitiveType_Switch; | |||||
| switch_primitiveT->value.value = new (std::nothrow) schema::SwitchT; | |||||
| if (switch_primitiveT->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "new MakeTupleT failed"; | |||||
| return nullptr; | |||||
| } | |||||
| auto partial_prim = std::make_shared<lite::Partial>(switch_primitiveT); | |||||
| ValueNodePtr partial_anf_prim = NewValueNode(partial_prim); | |||||
| return partial_anf_prim; | |||||
| } | |||||
| void WhilePass::ReplaceInput(const std::vector<AnfNodePtr> &node_list, AnfNodePtr new_input_cnode, | |||||
| std::string para_name) { | |||||
| for (auto &node : node_list) { | |||||
| if (utils::isa<CNodePtr>(node)) { | |||||
| auto cnode = utils::cast<CNodePtr>(node); | |||||
| for (size_t k = 0; k < cnode->inputs().size(); k++) { | |||||
| if (!utils::isa<ParameterPtr>(cnode->input(k))) { | |||||
| continue; | |||||
| } | |||||
| auto para_input = utils::cast<ParameterPtr>(cnode->input(k)); | |||||
| if (para_input->name() == para_name) { | |||||
| cnode->set_input(k, new_input_cnode); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| bool WhilePass::Run(const FuncGraphPtr &graph) { | |||||
| auto node_list = TopoSort(graph->get_return()); | |||||
| static int count = 0; | |||||
| for (auto &node : node_list) { | |||||
| if (!utils::isa<CNodePtr>(node)) { | |||||
| continue; | |||||
| } | |||||
| if (opt::GetCNodeType(node) != schema::PrimitiveType_While) { | |||||
| continue; | |||||
| } | |||||
| auto while_cnode = node->cast<CNodePtr>(); | |||||
| MS_ASSERT(while_cnode != nullptr); | |||||
| if (while_cnode->inputs().size() < kWhileMinInputSize) { | |||||
| MS_LOG(ERROR) << "while input is not right."; | |||||
| return false; | |||||
| } | |||||
| // the order is fixed. | |||||
| auto cond_vnode = while_cnode->input(kWhileCondIndex); | |||||
| auto body_vnode = while_cnode->input(kWhileBodyIndex); | |||||
| // body_vnode->cast<ValueNodePtr>()->set_value() | |||||
| auto cond_fg = GetValueNode<std::shared_ptr<FuncGraph>>(cond_vnode); | |||||
| auto body_fg = GetValueNode<std::shared_ptr<FuncGraph>>(body_vnode); | |||||
| if (cond_fg == nullptr || body_fg == nullptr) { | |||||
| MS_LOG(ERROR) << "Get value as func_graph failed."; | |||||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_FAILED); | |||||
| return false; | |||||
| } | |||||
| // create cond partial cnode | |||||
| std::vector<AnfNodePtr> cond_partial_op_inputs{cond_vnode}; | |||||
| // create body partial cnode | |||||
| std::vector<AnfNodePtr> body_partial_op_inputs{body_vnode}; | |||||
| // add while op input to cond_cnode and body_cnode | |||||
| cond_partial_op_inputs.insert(cond_partial_op_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize, | |||||
| while_cnode->inputs().end()); | |||||
| body_partial_op_inputs.insert(body_partial_op_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize, | |||||
| while_cnode->inputs().end()); | |||||
| static int idx = 0; | |||||
| auto cond_partial_node = graph->NewCNode(cond_partial_op_inputs); | |||||
| cond_partial_node->set_fullname_with_scope("Partial-while-cond-" + std::to_string(idx)); | |||||
| cond_partial_node->set_abstract(cond_fg->output()->abstract()); | |||||
| auto body_partial_node = graph->NewCNode(body_partial_op_inputs); | |||||
| body_partial_node->set_fullname_with_scope("Partial-while-body-" + std::to_string(idx)); | |||||
| idx++; | |||||
| // concat body_fg output to cond_fg input | |||||
| auto body_output = body_fg->output(); | |||||
| auto body_output_cnode = utils::cast<CNodePtr>(body_output); | |||||
| auto prim = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(body_output_cnode->input(0)); | |||||
| if (prim == nullptr) { | |||||
| MS_LOG(ERROR) << "Get PrimitiveC of node:" << body_output_cnode->fullname_with_scope() << " failed."; | |||||
| return false; | |||||
| } | |||||
| // concat body to cond | |||||
| std::vector<AnfNodePtr> body_to_cond_inputs{cond_vnode}; | |||||
| if ((schema::PrimitiveType)(prim->Type()) == schema::PrimitiveType_MakeTuple) { | |||||
| for (size_t i = 1; i < body_output_cnode->inputs().size(); ++i) { | |||||
| body_to_cond_inputs.emplace_back(body_output_cnode->input(i)); | |||||
| } | |||||
| } else { | |||||
| body_to_cond_inputs.emplace_back(body_output_cnode->input(1)); | |||||
| } | |||||
| // concat body to cond | |||||
| auto body_to_cond_cnode = body_fg->NewCNode(body_to_cond_inputs); | |||||
| body_to_cond_cnode->set_fullname_with_scope("Partial-while-body-to-cond"); | |||||
| auto body_fg_manager = body_fg->manager(); | |||||
| body_fg_manager->Replace(body_fg->output(), body_to_cond_cnode); | |||||
| body_fg->set_output(body_to_cond_cnode); | |||||
| body_partial_node->set_abstract(cond_fg->output()->abstract()); | |||||
| // create switch cnode | |||||
| ValueNodePtr switch_anf_primitive = GetSwitchAnfPrim(); | |||||
| if (switch_anf_primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "GetSwitchAnfPrim failed."; | |||||
| return false; | |||||
| } | |||||
| // insert switch node | |||||
| std::vector<AnfNodePtr> switch_op_inputs = {switch_anf_primitive, cond_partial_node, body_partial_node}; | |||||
| auto switch_cnode = graph->NewCNode(switch_op_inputs); | |||||
| switch_cnode->set_fullname_with_scope("Switch-" + std::to_string(count++)); | |||||
| AbstractBasePtrList abstract_list; | |||||
| auto body_fg_output_cnode = utils::cast<CNodePtr>(body_fg->output()); | |||||
| for (auto &cnode : body_fg_output_cnode->inputs()) { | |||||
| if (!utils::isa<CNodePtr>(cnode) && !utils::isa<ParameterPtr>(cnode)) { | |||||
| continue; | |||||
| } | |||||
| abstract_list.push_back(cnode->abstract()); | |||||
| } | |||||
| switch_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | |||||
| // create cond partial cnode | |||||
| auto manager = graph->manager(); | |||||
| auto node_users = manager->node_users()[while_cnode]; | |||||
| for (auto &node_user : node_users) { | |||||
| manager->SetEdge(node_user.first, node_user.second, switch_cnode); | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace mindspore::opt | |||||
| @@ -0,0 +1,43 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_WHILE_PASS_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_WHILE_PASS_H_ | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "schema/inner/model_generated.h" | |||||
| #include "tools/converter/converter_flags.h" | |||||
| #include "backend/optimizer/common/pass.h" | |||||
| #include "src/param_value_lite.h" | |||||
| using mindspore::lite::converter::FmkType; | |||||
| namespace mindspore::opt { | |||||
| class WhilePass : public Pass { | |||||
| public: | |||||
| WhilePass() : Pass("while_pass") {} | |||||
| ~WhilePass() override = default; | |||||
| bool Run(const FuncGraphPtr &graph) override; | |||||
| private: | |||||
| void ReplaceInput(const std::vector<AnfNodePtr> &node_list, AnfNodePtr new_input_cnode, std::string para_name); | |||||
| ValueNodePtr GetSwitchAnfPrim(); | |||||
| const size_t kWhileMinInputSize = 3; | |||||
| const size_t kWhileCondIndex = 1; | |||||
| const size_t kWhileBodyIndex = 2; | |||||
| }; | |||||
| } // namespace mindspore::opt | |||||
| #endif // MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_ | |||||