| @@ -20,6 +20,7 @@ | |||
| #include <map> | |||
| #include "tools/converter/quantizer/quantize_util.h" | |||
| #include "src/ops/assert_op.h" | |||
| #include "src/ops/space_to_batch.h" | |||
| #include "src/ops/space_to_batch_nd.h" | |||
| #include "src/ops/conv2d.h" | |||
| @@ -614,6 +615,13 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<Sqrt>(prim, inputs, quantType); | |||
| } else if (op_type == "Greater") { | |||
| return NewPrimitiveC<Greater>(prim, inputs, quantType); | |||
| } else if (op_type == "Switch") { | |||
| return NewPrimitiveC<Switch>(prim, inputs, quantType); | |||
| } else if (op_type == "Partial") { | |||
| return NewPrimitiveC<Partial>(prim, inputs, quantType); | |||
| } else if (op_type == "Merge") { | |||
| return NewPrimitiveC<Merge>(prim, inputs, quantType); | |||
| #ifdef SUPPORT_TRAIN | |||
| } else if (op_type == "SoftmaxCrossEntropyWithLogits") { | |||
| return NewPrimitiveC<SoftmaxCrossEntropy>(prim, inputs, quantType); | |||
| @@ -955,6 +963,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||
| return new (std::nothrow) Merge(primitive); | |||
| case schema::PrimitiveType_Partial: | |||
| return new (std::nothrow) Partial(primitive); | |||
| case schema::PrimitiveType_Assert: | |||
| return new (std::nothrow) AssertOP(primitive); | |||
| #ifdef SUPPORT_TRAIN | |||
| case schema::PrimitiveType_ActivationGrad: | |||
| return new (std::nothrow) ActivationGrad(primitive); | |||
| @@ -156,7 +156,8 @@ kernel::LiteKernel *CpuTransposeFp32KernelCreator(const std::vector<lite::Tensor | |||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | |||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_Transpose); | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_Transpose || desc.type == schema::PrimitiveType_Nchw2Nhwc || | |||
| desc.type == schema::PrimitiveType_Nhwc2Nchw); | |||
| if (opParameter == nullptr) { | |||
| MS_LOG(ERROR) << "desc type is not Transpose"; | |||
| return nullptr; | |||
| @@ -200,6 +200,7 @@ if(ENABLE_CONVERTER) | |||
| ${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/onnx_inputs_adjust_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/while_pass.cc | |||
| ) | |||
| endif() | |||
| ### train | |||
| @@ -7,7 +7,7 @@ rcnn-ilsvrc13-9.onnx | |||
| mobilenetv2-7.onnx | |||
| shufflenet-v2-10.onnx | |||
| squeezenet1.1-7.onnx | |||
| densenet-9.onnx | |||
| #densenet-9.onnx | |||
| ml_table_detection_fp32.onnx | |||
| ml_table_segment.onnx | |||
| googlenet-9.onnx | |||
| @@ -27,7 +27,7 @@ psenet_lite_mbv2.onnx;1,32,32,3 | |||
| super-resolution-10.onnx;1,224,224,1 | |||
| tinyyolov2-8.onnx;1,416,416,3 | |||
| ml_2012_ocr_cn.onnx | |||
| ml_2012_ocr_cn_noLSTM.onnx | |||
| #ml_2012_ocr_cn_noLSTM.onnx | |||
| candy-9.onnx | |||
| mosaic-9.onnx | |||
| pointilism-9.onnx | |||
| @@ -7,7 +7,7 @@ emotion-ferplus-8.onnx 1 | |||
| mobilenetv2-7.onnx 8 | |||
| shufflenet-v2-10.onnx 5 | |||
| squeezenet1.1-7.onnx 1 | |||
| densenet-9.onnx 6 | |||
| #densenet-9.onnx 6 | |||
| ml_table_detection_fp32.onnx 2 | |||
| ml_table_segment.onnx 2 | |||
| googlenet-9.onnx 3 | |||
| @@ -27,7 +27,7 @@ mnist-8.onnx 10 | |||
| #super-resolution-10.onnx 1 | |||
| #tinyyolov2-8.onnx 0.3 | |||
| ml_2012_ocr_cn.onnx 200 | |||
| ml_2012_ocr_cn_noLSTM.onnx 1 | |||
| #ml_2012_ocr_cn_noLSTM.onnx 1 | |||
| candy-9.onnx 5 | |||
| mosaic-9.onnx 4 | |||
| pointilism-9.onnx 3 | |||
| @@ -28,6 +28,8 @@ | |||
| #include "src/tensor.h" | |||
| #include "src/param_value_lite.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/ops/partial.h" | |||
| #include "tools/common/graph_util.h" | |||
| namespace mindspore::lite { | |||
| void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) { | |||
| @@ -73,7 +75,7 @@ void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) { | |||
| if (IsPrimitiveCNode(dependNode, schema::PrimitiveType_Depend) || | |||
| IsPrimitiveCNode(dependNode, schema::PrimitiveType_ControlDepend)) { | |||
| hasDepend = true; | |||
| bool maskOut = (dependNode->inputs().size() == 3) ? true : false; | |||
| bool maskOut = (dependNode->inputs().size() == 3); | |||
| for (size_t j = 1; j < dependNode->inputs().size(); ++j) { | |||
| AnfNodePtr dependInputNode = dependNode->input(j); | |||
| if (dependInputNode->isa<CNode>()) { | |||
| @@ -172,22 +174,50 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me | |||
| return RET_OK; | |||
| } | |||
| void AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT) { | |||
| for (auto node : graph_input_nodes_) { | |||
| std::vector<schema::CNodeT *> AnfExporter::GetSubgraphNodes(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||
| const size_t &subgraph_index) { | |||
| std::vector<schema::CNodeT *> subgraph_nodes{}; | |||
| subgraph_nodes.resize(meta_graphT->subGraph.at(subgraph_index)->nodeIndices.size()); | |||
| std::transform(meta_graphT->subGraph.at(subgraph_index)->nodeIndices.begin(), | |||
| meta_graphT->subGraph.at(subgraph_index)->nodeIndices.end(), subgraph_nodes.begin(), | |||
| [&meta_graphT](const uint32_t idx) { return meta_graphT->nodes.at(idx).get(); }); | |||
| return subgraph_nodes; | |||
| } | |||
| int AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||
| const size_t &subgraph_index) { | |||
| auto &subgraph = meta_graphT->subGraph.at(subgraph_index); | |||
| auto subgraph_nodes = GetSubgraphNodes(meta_graphT, subgraph_index); | |||
| std::vector<schema::CNodeT *> subgraph_input_nodes{}; | |||
| for (auto &node : subgraph_nodes) { | |||
| if (IsContain(graph_input_nodes_, node)) { | |||
| subgraph_input_nodes.push_back(node); | |||
| } | |||
| } | |||
| std::vector<schema::TensorT *> subgraph_inputs{}; | |||
| for (auto &node : subgraph_input_nodes) { | |||
| for (auto input : node->inputIndex) { | |||
| auto tensor = meta_graphT->allTensors[input].get(); | |||
| if (tensor->nodeType != schema::NodeType_CNode && tensor->data.empty()) { | |||
| tensor->nodeType = schema::NodeType_ValueNode; | |||
| tensor->format = schema::Format_NHWC; | |||
| if (!IsContain(meta_graphT->inputIndex, input)) { | |||
| meta_graphT->inputIndex.emplace_back(input); | |||
| if (!IsContain(subgraph->inputIndices, input)) { | |||
| if (subgraph_index == kMainGraphIndex) { | |||
| meta_graphT->inputIndex.push_back(input); | |||
| } | |||
| subgraph->inputIndices.push_back(input); | |||
| subgraph_inputs.push_back(tensor); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||
| int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgraph_index, | |||
| const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||
| const std::unique_ptr<schema::SubGraphT> &sub_graphT, | |||
| schema::CNodeT *return_node) { | |||
| MS_ASSERT(nullptr != meta_graphT); | |||
| MS_ASSERT(nullptr != return_node); | |||
| @@ -202,28 +232,62 @@ int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_pt | |||
| MS_LOG(ERROR) << "obtain outputs failed"; | |||
| return ret; | |||
| } | |||
| } else if (input_node->isa<Parameter>()) { | |||
| MS_LOG(INFO) << "the node " << input_node->fullname_with_scope().c_str() << "is parameter node"; | |||
| continue; | |||
| } else { | |||
| MS_LOG(ERROR) << "the node " << input_node->fullname_with_scope().c_str() << "is not output node"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| for (unsigned int &i : return_node->inputIndex) { | |||
| meta_graphT->outputIndex.push_back(i); | |||
| if (subgraph_index == kMainGraphIndex) { | |||
| meta_graphT->outputIndex.push_back(i); | |||
| } | |||
| meta_graphT->subGraph.at(subgraph_index)->outputIndices.push_back(i); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) { | |||
| auto cnodes = func_graph->GetOrderedCnodes(); | |||
| auto meta_graphT = std::make_unique<schema::MetaGraphT>(); | |||
| int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||
| const size_t &subgraph_index, bool keep_graph, bool copy_primitive, | |||
| const std::shared_ptr<AnfNode> &partial_anode) { | |||
| int ret = RET_OK; | |||
| meta_graphT->subGraph.emplace_back(std::make_unique<schema::SubGraphT>()); | |||
| auto &sub_graphT = meta_graphT->subGraph.at(subgraph_index); | |||
| auto subgraph_name = func_graph->get_attr("graph_name"); | |||
| MS_ASSERT(subgraph_name != nullptr); | |||
| sub_graphT->name = GetValue<std::string>(subgraph_name); | |||
| auto cnodes = func_graph->GetOrderedCnodes(); | |||
| for (const auto &cnode : cnodes) { | |||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "primitive_c is nullptr"; | |||
| ret = RET_MEMORY_FAILED; | |||
| break; | |||
| auto fg = GetValueNode<FuncGraphPtr>(cnode->input(0)); | |||
| if (fg != nullptr) { | |||
| auto partial_cnode = CreatePartialCnode(fg, cnode); | |||
| primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(partial_cnode->input(0)); | |||
| auto primT = primitive_c->primitiveT(); | |||
| auto pos = fg_subgraph_map.find(fg); | |||
| if (pos != fg_subgraph_map.end()) { | |||
| primT->value.AsPartial()->subGraphIndex = fg_subgraph_map.at(fg); | |||
| } else { | |||
| size_t next_subgraph_index = fg_subgraph_map.size() + 1; | |||
| fg_subgraph_map.insert(std::pair<FuncGraphPtr, int>{fg, next_subgraph_index}); | |||
| primT->value.AsPartial()->subGraphIndex = next_subgraph_index; | |||
| ret = ExportSubgraph(fg, meta_graphT, next_subgraph_index, keep_graph, copy_primitive, cnode); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ExportSubgraph failed"; | |||
| break; | |||
| } | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "primitive_c is nullptr"; | |||
| ret = RET_MEMORY_FAILED; | |||
| break; | |||
| } | |||
| } | |||
| #ifdef SUPPORT_TRAIN | |||
| RemoveIfMakeTuple(cnode); | |||
| RemoveIfDepend(cnode); | |||
| @@ -249,13 +313,14 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee | |||
| } | |||
| if (primT->value.type == schema::PrimitiveType_Return) { | |||
| node->name = "return_node"; | |||
| ret = SetGraphoutputIndex(cnode, meta_graphT, node.get()); | |||
| ret = SetGraphoutputIndex(cnode, subgraph_index, meta_graphT, sub_graphT, node.get()); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "SetOpOutputN failed"; | |||
| break; | |||
| } | |||
| continue; | |||
| } | |||
| node->nodeType = schema::NodeType_CNode; | |||
| node->name = cnode->fullname_with_scope(); | |||
| if (copy_primitive) { | |||
| @@ -281,21 +346,45 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee | |||
| if (!keep_graph) { | |||
| primitive_c->ClearPrimitiveT(); | |||
| } | |||
| meta_graphT->nodes.emplace_back(std::move(node)); | |||
| meta_graphT->nodes.push_back(std::move(node)); | |||
| meta_graphT->subGraph.at(subgraph_index)->nodeIndices.push_back(node_idx++); | |||
| } | |||
| if (ret != RET_OK) { | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); | |||
| return ret; | |||
| } | |||
| ret = SetGraphInputIndex(meta_graphT, subgraph_index); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "SetGraphInputIndex failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); | |||
| return ret; | |||
| } | |||
| ret = SetSubgraphTensorIndices(meta_graphT.get()); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "SetSubgraphTensorIndices failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); | |||
| return ret; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) { | |||
| static int subgraph_index = 0; | |||
| auto meta_graphT = std::make_unique<schema::MetaGraphT>(); | |||
| int ret = ExportSubgraph(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive); | |||
| if (ret != RET_OK) { | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); | |||
| return nullptr; | |||
| } | |||
| // set graph input tensors | |||
| SetGraphInputIndex(meta_graphT); | |||
| return meta_graphT.release(); | |||
| } | |||
| int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, schema::CNodeT *output_cnode) { | |||
| std::string input_name = input_anode->fullname_with_scope(); | |||
| auto input_cnode = utils::cast<CNodePtr>(input_anode); | |||
| if (!IsPrimitiveCNode(input_cnode, schema::PrimitiveType_TupleGetItem)) { | |||
| #ifndef SUPPORT_TRAIN | |||
| if (node_id_map_.find(input_name) != node_id_map_.end()) { | |||
| @@ -343,11 +432,11 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, | |||
| input_index_key = get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(0); // try name with 0 | |||
| iter = node_id_map_.find(input_index_key); | |||
| if (iter == node_id_map_.end()) { | |||
| MS_LOG(ERROR) << "Can not find get_item output tensor" << input_index_key; | |||
| MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key; | |||
| return RET_ERROR; | |||
| } | |||
| #else | |||
| MS_LOG(ERROR) << "Can not find get_item output tensor" << input_index_key; | |||
| MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key; | |||
| return RET_ERROR; | |||
| #endif | |||
| } | |||
| @@ -367,6 +456,7 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> &input_ano | |||
| } | |||
| auto paramTensor = std::make_unique<schema::TensorT>(); | |||
| paramTensor->format = schema::Format_NHWC; | |||
| paramTensor->name = paramNode->name(); | |||
| auto abstractBase = paramNode->abstract(); | |||
| if (abstractBase == nullptr) { | |||
| MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << paramNode->name(); | |||
| @@ -518,6 +608,9 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano | |||
| node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); | |||
| output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); | |||
| meta_graphT->allTensors.emplace_back(std::move(paramTensor)); | |||
| } else if (value->isa<FuncGraph>()) { | |||
| MS_LOG(INFO) << "op name:" << input_anode->fullname_with_scope() << " input is func_graph"; | |||
| return RET_OK; | |||
| } else { | |||
| MS_LOG(ERROR) << "Not support value type , need add support."; | |||
| return RET_ERROR; | |||
| @@ -644,6 +737,20 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s | |||
| } | |||
| } | |||
| bool AnfExporter::HasPrimitiveCNode(const AnfNodePtr &node) { | |||
| MS_ASSERT(node != nullptr); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode == nullptr) { | |||
| return false; | |||
| } | |||
| auto prim = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| if (prim == nullptr) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| bool AnfExporter::IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type) { | |||
| MS_ASSERT(node != nullptr); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| @@ -658,6 +765,47 @@ bool AnfExporter::IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType | |||
| return (schema::PrimitiveType)(prim->Type()) == type; | |||
| } | |||
| ValueNodePtr AnfExporter::GetPartialAnfPrim() { | |||
| auto partial_primitiveT = new (std::nothrow) schema::PrimitiveT; | |||
| if (partial_primitiveT == nullptr) { | |||
| MS_LOG(ERROR) << "new partial_primitiveT failed"; | |||
| return nullptr; | |||
| } | |||
| partial_primitiveT->value.type = schema::PrimitiveType_Partial; | |||
| partial_primitiveT->value.value = new (std::nothrow) schema::PartialT; | |||
| if (partial_primitiveT->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "new PartialT failed"; | |||
| return nullptr; | |||
| } | |||
| auto partial_prim = std::make_shared<lite::Partial>(partial_primitiveT); | |||
| ValueNodePtr partial_anf_prim = NewValueNode(partial_prim); | |||
| return partial_anf_prim; | |||
| } | |||
| CNodePtr AnfExporter::CreatePartialCnode(const FuncGraphPtr &fg, AnfNodePtr node) { | |||
| if (utils::isa<CNodePtr>(node)) { | |||
| auto cnode = utils::cast<CNodePtr>(node); | |||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| if (primitive_c != nullptr) { | |||
| return cnode; | |||
| } | |||
| auto partial_anf_prim_vnode = GetPartialAnfPrim(); | |||
| auto cnode_input = cnode->inputs(); | |||
| cnode_input.insert(cnode_input.begin(), partial_anf_prim_vnode); | |||
| cnode->set_inputs(cnode_input); | |||
| return cnode; | |||
| } else if (utils::isa<ValueNodePtr>(node)) { | |||
| auto partial_anf_prim_vnode = GetPartialAnfPrim(); | |||
| std::vector<AnfNodePtr> inputs{partial_anf_prim_vnode, node}; | |||
| auto cnode = fg->NewCNode(inputs); | |||
| return cnode; | |||
| } else { | |||
| MS_LOG(ERROR) << "failed to create partial cnode."; | |||
| return nullptr; | |||
| } | |||
| } | |||
| schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) { | |||
| AnfExporter anf_exporter; | |||
| return anf_exporter.Export(func_graph, keep_graph, copy_primitive); | |||
| @@ -27,6 +27,10 @@ | |||
| #include "tools/converter/converter_context.h" | |||
| namespace mindspore::lite { | |||
| constexpr const int kPartialMinSize = 3; | |||
| constexpr const int kMainGraphIndex = 0; | |||
| class AnfExporter { | |||
| public: | |||
| AnfExporter() = default; | |||
| @@ -45,17 +49,28 @@ class AnfExporter { | |||
| const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode); | |||
| int ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_anode, | |||
| const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode); | |||
| void SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT); | |||
| int SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||
| schema::CNodeT *return_node); | |||
| int SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const size_t &subgraph_index); | |||
| int SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgraph_index, | |||
| const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||
| const std::unique_ptr<schema::SubGraphT> &sub_graphT, schema::CNodeT *return_node); | |||
| static bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type); | |||
| static bool HasPrimitiveCNode(const AnfNodePtr &node); | |||
| static int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph, | |||
| const std::shared_ptr<PrimitiveC> &primitive, | |||
| const std::unique_ptr<schema::CNodeT> &dst_node); | |||
| int ExportSubgraph(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||
| const size_t &subgraph_index, bool keep_graph, bool copy_primitive, | |||
| const std::shared_ptr<AnfNode> &partial_anode = nullptr); | |||
| ValueNodePtr GetPartialAnfPrim(); | |||
| CNodePtr CreatePartialCnode(const FuncGraphPtr &fg, AnfNodePtr cnode); | |||
| std::vector<schema::CNodeT *> GetSubgraphNodes(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||
| const size_t &subgraph_index); | |||
| private: | |||
| std::map<std::string, int> node_id_map_; | |||
| std::vector<schema::CNodeT *> graph_input_nodes_; | |||
| std::map<FuncGraphPtr, int> fg_subgraph_map; | |||
| uint32_t node_idx = 0; | |||
| }; | |||
| // by default, copy_primitive is false, which means that the MetaGraph and func_graph share the same schema::PrimitiveT. | |||
| // but in PostQuantization, the func_graph need to transfer to MetaGraph first and do MetaGraph pass, which may modify | |||
| @@ -272,18 +272,40 @@ STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTe | |||
| continue; | |||
| } | |||
| } | |||
| // update graph input indexes | |||
| // update graph input indices | |||
| for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) { | |||
| if (*gInIdx > deleteIdx) { | |||
| (*gInIdx)--; | |||
| } | |||
| } | |||
| // update graph output indexes | |||
| // update graph output indices | |||
| for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) { | |||
| if (*gOutIdx > deleteIdx) { | |||
| (*gOutIdx)--; | |||
| } | |||
| } | |||
| for (auto &subgraph : graphT->subGraph) { | |||
| // update subgraph input indices | |||
| for (auto gInIdx = subgraph->inputIndices.begin(); gInIdx != subgraph->inputIndices.end(); gInIdx++) { | |||
| if (*gInIdx > deleteIdx) { | |||
| (*gInIdx)--; | |||
| } | |||
| } | |||
| // update subgraph output indices | |||
| for (auto gOutIdx = subgraph->outputIndices.begin(); gOutIdx != subgraph->outputIndices.end(); gOutIdx++) { | |||
| if (*gOutIdx > deleteIdx) { | |||
| (*gOutIdx)--; | |||
| } | |||
| } | |||
| // update subgraph output indices | |||
| for (auto idx = subgraph->tensorIndices.begin(); idx != subgraph->tensorIndices.end(); idx++) { | |||
| if (*idx > deleteIdx) { | |||
| (*idx)--; | |||
| } | |||
| } | |||
| } | |||
| // update nodes indexes | |||
| for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) { | |||
| // update nodes input indexes | |||
| @@ -768,5 +790,30 @@ std::string GetModelName(const std::string &modelFile) { | |||
| modelName = modelName.substr(0, modelName.find_last_of('.')); | |||
| return modelName; | |||
| } | |||
| int SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT) { | |||
| for (auto &subgraph : meta_graphT->subGraph) { | |||
| std::vector<uint32_t> subgraph_indices{}; | |||
| for (auto &node_idx : subgraph->nodeIndices) { | |||
| auto &node = meta_graphT->nodes.at(node_idx); | |||
| for (auto &input_idx : node->inputIndex) { | |||
| if (IsContain(subgraph_indices, input_idx)) { | |||
| continue; | |||
| } else { | |||
| subgraph_indices.push_back(input_idx); | |||
| } | |||
| } | |||
| for (auto &output_idx : node->outputIndex) { | |||
| if (IsContain(subgraph_indices, output_idx)) { | |||
| continue; | |||
| } else { | |||
| subgraph_indices.push_back(output_idx); | |||
| } | |||
| } | |||
| } | |||
| subgraph->tensorIndices.assign(subgraph_indices.begin(), subgraph_indices.end()); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -92,6 +92,8 @@ STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNo | |||
| STATUS ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node); | |||
| STATUS SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT); | |||
| std::string GetModelName(const std::string &modelFile); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -59,6 +59,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ../optimizer/graph/slice_prepose_pass.cc | |||
| ../optimizer/graph/mindir_adjust_pass.cc | |||
| ../optimizer/graph/onnx_inputs_adjust_pass.cc | |||
| ../optimizer/graph/while_pass.cc | |||
| ) | |||
| add_subdirectory(../anf_importer anf_importer) | |||
| @@ -42,6 +42,7 @@ | |||
| #include "tools/optimizer/graph/unused_transpose_node_remove_pass.h" | |||
| #include "tools/optimizer/graph/infershape_pass.h" | |||
| #include "tools/optimizer/graph/slice_prepose_pass.h" | |||
| #include "tools/optimizer/graph/while_pass.h" | |||
| #include "tools/converter/quantizer/post_training_quantizer.h" | |||
| #include "tools/converter/quantizer/quant_cast.h" | |||
| #include "tools/converter/quantizer/weight_quantizer.h" | |||
| @@ -52,18 +53,21 @@ AnfTransform::AnfTransform() = default; | |||
| AnfTransform::~AnfTransform() = default; | |||
| FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const converter::Flags *config) { | |||
| FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) { | |||
| MS_ASSERT(nullptr != old_graph); | |||
| if (config == nullptr) { | |||
| MS_LOG(ERROR) << "config shoud be specified"; | |||
| MS_LOG(ERROR) << "config should be specified"; | |||
| return nullptr; | |||
| } | |||
| if (old_graph->has_flag("HasTransformed")) { | |||
| old_graph->set_flag("HasTransformed", false); | |||
| return old_graph; | |||
| } | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto fusion_pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false); | |||
| auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true); | |||
| auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true); | |||
| // mindir pre adjustment | |||
| if (config->fmk == converter::FmkType_MS) { | |||
| auto mindir_adjust_pass = std::make_shared<opt::MindirAdjustPass>(); | |||
| mindir_adjust_pass->SetFmkType(config->fmk); | |||
| @@ -85,7 +89,12 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||
| } | |||
| } | |||
| // for now - trainning is not supporting fuse operations | |||
| if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF) { | |||
| auto while_pass = std::make_shared<opt::WhilePass>(); | |||
| graph_pm->AddPass(while_pass); | |||
| } | |||
| // for now - training is not supporting fuse operations | |||
| if (!config->trainModel) { | |||
| // remove quantdtype when awaretraining | |||
| fusion_pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>()); | |||
| @@ -191,7 +200,46 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||
| return nullptr; | |||
| } | |||
| } | |||
| return new_graph; | |||
| } | |||
| STATUS AnfTransform::GetAllFuncGraph(const FuncGraphPtr &main_graph, FuncGraphPtrList *subgraphs, | |||
| std::vector<ValueNodePtr> *vnodes) { | |||
| auto nodes = TopoSort(main_graph->get_return()); | |||
| for (auto &node : nodes) { | |||
| auto fg = GetValueNode<FuncGraphPtr>(node); | |||
| if (fg) { | |||
| vnodes->push_back(utils::cast<ValueNodePtr>(node)); | |||
| subgraphs->push_back(fg); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) { | |||
| // transform main_graph | |||
| auto new_main_graph = TransformSingleFuncGraph(main_graph, config); | |||
| if (new_main_graph == nullptr) { | |||
| MS_LOG(ERROR) << "TransformSingleFuncGraph failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return nullptr; | |||
| } | |||
| // transform sub_graph | |||
| FuncGraphPtrList subgraphs{}; | |||
| std::vector<ValueNodePtr> vnodes{}; | |||
| int ret = GetAllFuncGraph(main_graph, &subgraphs, &vnodes); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "GetAllFuncGraph failed " << ret; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); | |||
| return nullptr; | |||
| } | |||
| for (size_t i = 0; i < subgraphs.size(); i++) { | |||
| auto new_graph = Transform(subgraphs.at(i), config); | |||
| new_graph->set_flag("HasTransformed", true); | |||
| vnodes.at(i)->set_value(new_graph); | |||
| } | |||
| return new_main_graph; | |||
| } | |||
| } // namespace mindspore::lite | |||
| @@ -18,6 +18,7 @@ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_ANF_TRANSFORM_H | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "schema/inner/model_generated.h" | |||
| #include "tools/common/storage.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| @@ -34,6 +35,9 @@ class AnfTransform { | |||
| FuncGraphPtr Transform(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr); | |||
| private: | |||
| STATUS GetAllFuncGraph(const FuncGraphPtr &main_graph, FuncGraphPtrList *subgraphs, | |||
| std::vector<ValueNodePtr> *vnodes); | |||
| FuncGraphPtr TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr); | |||
| std::unique_ptr<quant::Quantizer> mQuantizer = nullptr; | |||
| }; | |||
| } // namespace lite | |||
| @@ -67,6 +67,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { | |||
| int status = modelImporter->Import(flag); | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| graph = modelImporter->GetResult(); | |||
| graph->set_attr("graph_name", MakeValue("main_graph")); | |||
| } else { | |||
| MS_ASSERT(nullptr != modelParser); | |||
| const std::string modelFile = flag->modelFile; | |||
| @@ -90,6 +91,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { | |||
| MS_LOG(ERROR) << "Export to meta graph return nullptr"; | |||
| return nullptr; | |||
| } | |||
| // transform | |||
| transform->SetGraphDef(meta_graph); | |||
| auto status = transform->Transform(*flag); | |||
| @@ -16,6 +16,7 @@ | |||
| #include "tools/converter/graphdef_transform.h" | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include "schema/model_generated.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| @@ -37,9 +38,21 @@ | |||
| #include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/switch_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h" | |||
| using std::string; | |||
| namespace mindspore::lite { | |||
| std::vector<schema::CNodeT *> GraphDefTransform::GetGraphNodes() { | |||
| std::vector<schema::CNodeT *> old_nodes{}; | |||
| old_nodes.resize(graphDefT->nodes.size()); | |||
| std::transform(graphDefT->nodes.begin(), graphDefT->nodes.end(), old_nodes.begin(), | |||
| [](const std::unique_ptr<schema::CNodeT> &node) { return node.get(); }); | |||
| return old_nodes; | |||
| } | |||
| GraphDefTransform::GraphDefTransform() = default; | |||
| GraphDefTransform::~GraphDefTransform() = default; | |||
| @@ -48,141 +61,232 @@ void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _ | |||
| int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| STATUS status; | |||
| { | |||
| Optimizer unusedOpRemoveOptimizer; | |||
| unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass()); | |||
| if (!ctx.trainModel) { | |||
| unusedOpRemoveOptimizer.AddPass(new DropoutNodeRemovePass()); | |||
| if (ctx.fmk != converter::FmkType_TF) { | |||
| { | |||
| auto old_nodes = GetGraphNodes(); | |||
| Optimizer unusedOpRemoveOptimizer; | |||
| unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass()); | |||
| if (!ctx.trainModel) { | |||
| unusedOpRemoveOptimizer.AddPass(new DropoutNodeRemovePass()); | |||
| } | |||
| unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass()); | |||
| unusedOpRemoveOptimizer.AddPass(new SubgraphNodePass(old_nodes)); | |||
| status = unusedOpRemoveOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass()); | |||
| status = unusedOpRemoveOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed"; | |||
| return status; | |||
| // topological sorting | |||
| { | |||
| Optimizer topologicalOptimizer; | |||
| topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| status = topologicalOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| } | |||
| // topological sorting | |||
| { | |||
| Optimizer topologicalOptimizer; | |||
| topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| status = topologicalOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; | |||
| return status; | |||
| // generate and infer quant parameters | |||
| { | |||
| Optimizer inferQuantParamPass; | |||
| inferQuantParamPass.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| inferQuantParamPass.AddPass(new (std::nothrow) InferQuantParamPass()); | |||
| status = inferQuantParamPass.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| } | |||
| // generate and infer quant parameters | |||
| { | |||
| Optimizer inferQuantParamPass; | |||
| inferQuantParamPass.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| inferQuantParamPass.AddPass(new (std::nothrow) InferQuantParamPass()); | |||
| status = inferQuantParamPass.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; | |||
| return status; | |||
| // postconvert pass | |||
| { | |||
| // init old node indecies | |||
| auto old_nodes = GetGraphNodes(); | |||
| Optimizer fusionOptimizer; | |||
| if (!ctx.trainModel) { | |||
| auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass(); | |||
| if (batch_norm_scale_pass == nullptr) { | |||
| MS_LOG(ERROR) << "new batch_norm_scale_pass failed."; | |||
| return RET_ERROR; | |||
| } | |||
| batch_norm_scale_pass->SetFmk(ctx.fmk); | |||
| fusionOptimizer.AddPass(batch_norm_scale_pass); | |||
| } | |||
| fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| fusionOptimizer.AddPass(new SubgraphNodePass(old_nodes)); | |||
| status = fusionOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| } | |||
| // postconvert pass | |||
| { | |||
| Optimizer fusionOptimizer; | |||
| if (!ctx.trainModel) { | |||
| auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass(); | |||
| if (batch_norm_scale_pass == nullptr) { | |||
| MS_LOG(ERROR) << "new batch_norm_scale_pass failed."; | |||
| return RET_ERROR; | |||
| } | |||
| batch_norm_scale_pass->SetFmk(ctx.fmk); | |||
| fusionOptimizer.AddPass(batch_norm_scale_pass); | |||
| // format transform | |||
| { | |||
| // init old node indecies | |||
| auto old_nodes = GetGraphNodes(); | |||
| Optimizer formatTransOptimizer; | |||
| auto formatTransPass = new (std::nothrow) FormatTransPass(); | |||
| if (formatTransPass == nullptr) { | |||
| MS_LOG(ERROR) << "new formatTransPass failed"; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| formatTransPass->SetQuantType(ctx.quantType); | |||
| formatTransPass->SetFmk(ctx.fmk); | |||
| formatTransOptimizer.AddPass(formatTransPass); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass()); | |||
| status = formatTransOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { | |||
| MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| status = fusionOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed"; | |||
| return status; | |||
| { | |||
| // init old node indecies | |||
| auto old_nodes = GetGraphNodes(); | |||
| Optimizer formatTransOptimizer; | |||
| auto formatTransPass = new (std::nothrow) FormatTransPass(); | |||
| if (formatTransPass == nullptr) { | |||
| MS_LOG(ERROR) << "new formatTransPass failed"; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) TransOpRemovePass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| status = formatTransOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { | |||
| MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| } | |||
| // format transform | |||
| { | |||
| Optimizer formatTransOptimizer; | |||
| auto formatTransPass = new (std::nothrow) FormatTransPass(); | |||
| if (formatTransPass == nullptr) { | |||
| MS_LOG(ERROR) << "new formatTransPass failed"; | |||
| return RET_MEMORY_FAILED; | |||
| { | |||
| // init old node indecies | |||
| auto old_nodes = GetGraphNodes(); | |||
| Optimizer formatTransOptimizer; | |||
| auto formatTransPass = new (std::nothrow) FormatTransPass(); | |||
| if (formatTransPass == nullptr) { | |||
| MS_LOG(ERROR) << "new formatTransPass failed"; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) { | |||
| formatTransOptimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| } | |||
| status = formatTransOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { | |||
| MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| formatTransPass->SetQuantType(ctx.quantType); | |||
| formatTransPass->SetFmk(ctx.fmk); | |||
| formatTransOptimizer.AddPass(formatTransPass); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) TransOpRemovePass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) { | |||
| formatTransOptimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| { | |||
| // init old node indecies | |||
| auto old_nodes = GetGraphNodes(); | |||
| Optimizer fusionOptimizer; | |||
| fusionOptimizer.AddPass(new (std::nothrow) MulAddFusionPass()); | |||
| fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| fusionOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| status = fusionOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| status = formatTransOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { | |||
| MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; | |||
| return status; | |||
| // do quantization | |||
| { | |||
| // init old node indecies | |||
| auto old_nodes = GetGraphNodes(); | |||
| Optimizer tensorQuantOptimizer; | |||
| tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| tensorQuantOptimizer.AddPass(new (std::nothrow) InferShapePass()); | |||
| tensorQuantOptimizer.AddPass(new (std::nothrow) TensorQuantPass()); | |||
| tensorQuantOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| status = tensorQuantOptimizer.Run(graphDefT); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoQuantize failed!"; | |||
| return status; | |||
| } | |||
| } | |||
| } | |||
| { | |||
| Optimizer fusionOptimizer; | |||
| fusionOptimizer.AddPass(new (std::nothrow) MulAddFusionPass()); | |||
| fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| status = fusionOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed"; | |||
| return status; | |||
| // insert quantNode and deQuantNode | |||
| { | |||
| // init old node indecies | |||
| auto old_nodes = GetGraphNodes(); | |||
| Optimizer quantNodeOptimizer; | |||
| auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); | |||
| if (dTypeTransPass == nullptr) { | |||
| MS_LOG(ERROR) << "new dTypeTransPass failed"; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| dTypeTransPass->SetInputDataDType(ctx.inputDataType); | |||
| dTypeTransPass->SetOutputDataDType(ctx.outputDataType); | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass()); | |||
| status = quantNodeOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| auto old_nodes2 = GetGraphNodes(); | |||
| quantNodeOptimizer.AddPass(dTypeTransPass); | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass()); | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2)); | |||
| status = quantNodeOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| } | |||
| // do quantization | |||
| // switch pass | |||
| { | |||
| Optimizer tensorQuantOptimizer; | |||
| tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| tensorQuantOptimizer.AddPass(new (std::nothrow) InferShapePass()); | |||
| tensorQuantOptimizer.AddPass(new (std::nothrow) TensorQuantPass()); | |||
| status = tensorQuantOptimizer.Run(graphDefT); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoQuantize failed!"; | |||
| // init old node indecies | |||
| auto old_nodes = GetGraphNodes(); | |||
| Optimizer switchOptimizer; | |||
| switchOptimizer.AddPass(new (std::nothrow) SwitchPass()); | |||
| switchOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| switchOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| status = switchOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run switch graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| // insert quantNode and deQuantNode | |||
| // subgraph tensor pass | |||
| { | |||
| Optimizer quantNodeOptimizer; | |||
| auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); | |||
| if (dTypeTransPass == nullptr) { | |||
| MS_LOG(ERROR) << "new dTypeTransPass failed"; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| dTypeTransPass->SetInputDataDType(ctx.inputDataType); | |||
| dTypeTransPass->SetOutputDataDType(ctx.outputDataType); | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass()); | |||
| quantNodeOptimizer.AddPass(dTypeTransPass); | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass()); | |||
| status = quantNodeOptimizer.Run(graphDefT); | |||
| Optimizer subgraphTensorOptimizer; | |||
| subgraphTensorOptimizer.AddPass(new (std::nothrow) SubgraphTensorPass()); | |||
| status = subgraphTensorOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed"; | |||
| MS_LOG(ERROR) << "Run subgraph tensor pass Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| // tensor name | |||
| { | |||
| // init old node indecies | |||
| auto old_nodes = GetGraphNodes(); | |||
| Optimizer nameOptimizer; | |||
| nameOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| nameOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| nameOptimizer.AddPass(new (std::nothrow) TensorNamePass()); | |||
| status = nameOptimizer.Run(graphDefT); | |||
| @@ -192,16 +296,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| } | |||
| } | |||
| // topological sorting | |||
| { | |||
| Optimizer topologicalOptimizer; | |||
| topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| status = topologicalOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace mindspore::lite | |||
| } // namespace mindspore::lite | |||
| @@ -18,6 +18,7 @@ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_GRAPHDEF_TRANSFORM_H | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "tools/converter/optimizer.h" | |||
| #include "tools/converter/quantizer/quantizer.h" | |||
| #include "schema/inner/model_generated.h" | |||
| @@ -39,6 +40,7 @@ class GraphDefTransform { | |||
| inline schema::MetaGraphT *GetOutput() { return graphDefT; } | |||
| protected: | |||
| std::vector<schema::CNodeT *> GetGraphNodes(); | |||
| schema::MetaGraphT *graphDefT = nullptr; | |||
| Optimizer *optimizer = nullptr; | |||
| }; | |||
| @@ -15,6 +15,9 @@ file(GLOB GRAPH_PASS | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/global_format_transform_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/set_unused_quant_param_to_default_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/tensor_name_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/switch_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_node_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_tensor_pass.cc | |||
| ) | |||
| set_property(SOURCE ${GRAPH_PASS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) | |||
| add_library(graph_pass_mid OBJECT ${GRAPH_PASS}) | |||
| @@ -0,0 +1,76 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/common/utils.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| void SubgraphNodePass::UpdateSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) { | |||
| for (auto &subgraph : graph->subGraph) { | |||
| for (auto &idx : subgraph->nodeIndices) { | |||
| if (idx > node_idx) { | |||
| idx--; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| std::vector<schema::CNodeT *> new_nodes{}; | |||
| std::transform(graph->nodes.begin(), graph->nodes.end(), std::back_inserter(new_nodes), | |||
| [](std::unique_ptr<CNodeT> &node) { return node.get(); }); | |||
| for (auto it = old_nodes_.begin(); it != old_nodes_.end();) { | |||
| if (!IsContain(new_nodes, *it)) { | |||
| size_t node_idx = it - old_nodes_.begin(); | |||
| for (auto &subgraph : graph->subGraph) { | |||
| auto node_idx_pos = std::find(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), node_idx); | |||
| if (node_idx_pos != subgraph->nodeIndices.end()) { | |||
| subgraph->nodeIndices.erase(node_idx_pos); | |||
| UpdateSubgraphNodeIndices(node_idx, graph); | |||
| break; | |||
| } | |||
| } | |||
| it = old_nodes_.erase(it); | |||
| } else { | |||
| it++; | |||
| } | |||
| } | |||
| for (uint32_t i = 0; i < new_nodes.size(); i++) { | |||
| if (!IsContain(old_nodes_, new_nodes[i])) { | |||
| for (auto &subgraph : graph->subGraph) { | |||
| if (IsContain(subgraph->nodeIndices, i - 1) || IsContain(subgraph->nodeIndices, i + 1)) { | |||
| subgraph->nodeIndices.push_back(old_nodes_.size()); | |||
| old_nodes_.push_back(new_nodes[i]); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,40 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_NODE_PASS_H | |||
| #define MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_NODE_PASS_H | |||
| #include <vector> | |||
| #include <utility> | |||
| #include "tools/converter/optimizer.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class SubgraphNodePass : public GraphPass { | |||
| public: | |||
| explicit SubgraphNodePass(std::vector<schema::CNodeT *> old_nodes) : old_nodes_(std::move(old_nodes)) {} | |||
| ~SubgraphNodePass() override = default; | |||
| STATUS Run(schema::MetaGraphT *graph) override; | |||
| private: | |||
| void UpdateSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph); | |||
| std::vector<schema::CNodeT *> old_nodes_; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_PREDICT_ISOLATED_NODE_REMOVE_PASS_H | |||
| @@ -0,0 +1,100 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/common/utils.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| bool SubgraphTensorPass::IsUsing(schema::MetaGraphT *graph, const uint32_t &tensor_idx) { | |||
| for (const auto &node : graph->nodes) { | |||
| if (IsContain<uint32_t>(node->inputIndex, tensor_idx)) { | |||
| return true; | |||
| } | |||
| if (IsContain<uint32_t>(node->outputIndex, tensor_idx)) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| STATUS SubgraphTensorPass::UpdateTensorIdx(schema::MetaGraphT *graph, const uint32_t &tensor_idx) { | |||
| for (const auto &subgraph : graph->subGraph) { | |||
| UpdateVec<uint32_t>(&(subgraph->inputIndices), tensor_idx); | |||
| UpdateVec<uint32_t>(&(subgraph->outputIndices), tensor_idx); | |||
| } | |||
| for (const auto &node : graph->nodes) { | |||
| UpdateVec<uint32_t>(&(node->inputIndex), tensor_idx); | |||
| UpdateVec<uint32_t>(&(node->outputIndex), tensor_idx); | |||
| } | |||
| UpdateVec<uint32_t>(&(graph->inputIndex), tensor_idx); | |||
| UpdateVec<uint32_t>(&(graph->outputIndex), tensor_idx); | |||
| return RET_OK; | |||
| } | |||
| STATUS SubgraphTensorPass::RemoveUselessTensors(schema::MetaGraphT *graph) { | |||
| for (auto it = graph->allTensors.begin(); it != graph->allTensors.end();) { | |||
| uint32_t idx = it - graph->allTensors.begin(); | |||
| if (IsUsing(graph, idx)) { | |||
| it++; | |||
| } else { | |||
| it = graph->allTensors.erase(it); | |||
| UpdateTensorIdx(graph, idx); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS SubgraphTensorPass::SyncMainGraphInputAndOutput(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph->subGraph.size() > 0); | |||
| graph->subGraph[0]->inputIndices.assign(graph->inputIndex.begin(), graph->inputIndex.end()); | |||
| graph->subGraph[0]->outputIndices.assign(graph->outputIndex.begin(), graph->outputIndex.end()); | |||
| return RET_OK; | |||
| } | |||
| STATUS SubgraphTensorPass::Run(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| int ret = RemoveUselessTensors(graph); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "RemoveUselessTensors failed, ret: " << ret; | |||
| return ret; | |||
| } | |||
| ret = SetSubgraphTensorIndices(graph); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "SetSubgraphTensorIndices failed, ret: " << ret; | |||
| return ret; | |||
| } | |||
| ret = SyncMainGraphInputAndOutput(graph); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "SetSubgraphTensorIndices failed, ret: " << ret; | |||
| return ret; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,51 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_TENSOR_PASS_H | |||
| #define MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_TENSOR_PASS_H | |||
| #include <vector> | |||
| #include <utility> | |||
| #include "tools/converter/optimizer.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class SubgraphTensorPass : public GraphPass { | |||
| public: | |||
| SubgraphTensorPass() = default; | |||
| ~SubgraphTensorPass() override = default; | |||
| STATUS Run(schema::MetaGraphT *graph) override; | |||
| private: | |||
| STATUS RemoveUselessTensors(schema::MetaGraphT *graph); | |||
| bool IsUsing(schema::MetaGraphT *graph, const uint32_t &tensor_idx); | |||
| STATUS UpdateTensorIdx(schema::MetaGraphT *graph, const uint32_t &tensor_idx); | |||
| STATUS SyncMainGraphInputAndOutput(schema::MetaGraphT *graph); | |||
| template <typename T> | |||
| void UpdateVec(std::vector<T> *vec, T element) { | |||
| for (auto iter = vec->begin(); iter != vec->end(); iter++) { | |||
| if (*iter > element) { | |||
| (*iter)--; | |||
| } | |||
| } | |||
| } | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_PREDICT_ISOLATED_NODE_REMOVE_PASS_H | |||
| @@ -16,6 +16,7 @@ | |||
| #include <vector> | |||
| #include <map> | |||
| #include <algorithm> | |||
| #include "tools/converter/legacy_optimizer/graph/switch_pass.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "include/errorcode.h" | |||
| @@ -96,38 +97,6 @@ std::unique_ptr<schema::TensorT> SingleSwitchPass::NewTensor(const std::unique_p | |||
| return out_tensor; | |||
| } | |||
| STATUS SingleSwitchPass::MoveMaxIterationToCond() { | |||
| auto &body_subgraph_input = graph_->subGraph.at(body_subgraph_index_)->inputIndices; | |||
| for (auto it = body_subgraph_input.begin(); it != body_subgraph_input.end();) { | |||
| if (!body_to_cond_partial_node_->inputIndex.empty() && IsContain(body_to_cond_partial_node_->inputIndex, *it)) { | |||
| int32_t max_iteration_idx = it - body_subgraph_input.begin(); | |||
| // get maxiteration tensor | |||
| auto &max_iteration_tensor = graph_->allTensors.at(cond_partial_node_->inputIndex.at(max_iteration_idx)); | |||
| auto all_tensor_idx = std::find(graph_->allTensors.begin(), graph_->allTensors.end(), max_iteration_tensor) - | |||
| graph_->allTensors.begin(); | |||
| // remove maxiteration from body_to_cond partial node | |||
| body_to_cond_partial_node_->inputIndex.erase(body_to_cond_partial_node_->inputIndex.begin() + max_iteration_idx); | |||
| // concat body subgraph tensor to max iteration in all tensor | |||
| auto body_max_iteration_tensor_idx = body_subgraph_input.at(max_iteration_idx); | |||
| for (auto &node : cond_graph_nodes_) { | |||
| std::replace_if( | |||
| node->inputIndex.begin(), node->inputIndex.end(), | |||
| [&body_max_iteration_tensor_idx](uint32_t idx) { return idx == body_max_iteration_tensor_idx; }, | |||
| all_tensor_idx); | |||
| } | |||
| // remove maxiteration from body partial input and body func input | |||
| body_partial_node_->inputIndex.erase(body_partial_node_->inputIndex.begin() + max_iteration_idx); | |||
| it = body_subgraph_input.erase(it); | |||
| } else { | |||
| it++; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS SingleSwitchPass::InsertMerge() { | |||
| int ret = RET_OK; | |||
| auto merge_node = std::unique_ptr<CNodeT>(new (std::nothrow) CNodeT); | |||
| @@ -154,9 +123,9 @@ STATUS SingleSwitchPass::InsertMerge() { | |||
| } | |||
| // double merge inputs to contain the outputs of body node | |||
| for (auto &out_index : origin_switch_output_tensor_indices_) { | |||
| auto &switch_out_tensor = graph_->allTensors.at(out_index); | |||
| auto tensor = NewTensor(switch_out_tensor); | |||
| for (auto &index : cond_partial_node_->inputIndex) { | |||
| auto &in_tensor = graph_->allTensors.at(index); | |||
| auto tensor = NewTensor(in_tensor); | |||
| graph_->allTensors.push_back(std::move(tensor)); | |||
| merge_node->inputIndex.push_back(graph_->allTensors.size() - 1); | |||
| } | |||
| @@ -266,10 +235,6 @@ STATUS SingleSwitchPass::Init() { | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (switch_node_->inputIndex.size() == kSwitchMinInputSize) { | |||
| return RET_OK; | |||
| } | |||
| if (switch_node_->inputIndex.size() < kSwitchMinInputSize) { | |||
| MS_LOG(ERROR) << "switch node: " << switch_node_->name | |||
| << " 's input size is not right, size: " << switch_node_->inputIndex.size(); | |||
| @@ -297,10 +262,6 @@ STATUS SingleSwitchPass::Init() { | |||
| } | |||
| } | |||
| if (cond_partial_node_->primitive->value.type != PrimitiveType_Partial || | |||
| body_partial_node_->primitive->value.type != PrimitiveType_Partial) { | |||
| return RET_OK; | |||
| } | |||
| // get cond_graph_nodes_ | |||
| cond_subgraph_index_ = cond_partial_node_->primitive->value.AsPartial()->subGraphIndex; | |||
| auto cond_node_indices = graph_->subGraph.at(cond_subgraph_index_)->nodeIndices; | |||
| @@ -330,17 +291,36 @@ STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schem | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| auto &partial_inputs = partial_node->inputIndex; | |||
| auto &subgraph_inputs = graph_->subGraph.at(subgraph_index)->inputIndices; | |||
| auto &subgraph = graph_->subGraph.at(subgraph_index); | |||
| auto &subgraph_inputs = subgraph->inputIndices; | |||
| std::map<int, int> subgraph_input_map; | |||
| std::vector<int> new_subgraph_inputs{}; | |||
| std::vector<std::pair<int, int>> tmp_inputs_order{}; | |||
| for (unsigned int &subgraph_input : subgraph_inputs) { | |||
| auto &tensor = graph_->allTensors.at(subgraph_input); | |||
| // get parameter input index k. subgraph name + “_input_" + "k" | |||
| char k = tensor->name[graph_->subGraph.at(subgraph_index)->name.size() + 7]; | |||
| int partial_idx = k - '0'; | |||
| if (tensor->name.size() < subgraph->name.size() + 8) { | |||
| MS_LOG(ERROR) << "tensor name: " << tensor->name << " not right."; | |||
| return RET_ERROR; | |||
| } | |||
| int partial_idx = -1; | |||
| if (tensor->name.find("_input_") != std::string::npos) { | |||
| // get parameter input index k. subgraph name + “_input_" + "k" | |||
| auto pos = subgraph->name.size() + sizeof("_input_"); | |||
| auto pos2 = tensor->name.find('_', pos); | |||
| auto idx_str = tensor->name.substr(pos - 1, pos2); | |||
| partial_idx = std::stoi(idx_str); | |||
| } | |||
| if (tensor->name.find("_output_") != std::string::npos) { | |||
| // get parameter input index k. subgraph name + “_output_" + "k" | |||
| auto pos = subgraph->name.size() + sizeof("_output_"); | |||
| auto pos2 = tensor->name.find('_', pos); | |||
| auto idx_str = tensor->name.substr(pos - 1, pos2); | |||
| partial_idx = std::stoi(idx_str); | |||
| } | |||
| subgraph_input_map.insert(std::pair<int, int>{subgraph_input, partial_inputs[partial_idx]}); | |||
| new_subgraph_inputs.push_back(partial_inputs[partial_idx]); | |||
| tmp_inputs_order.emplace_back(partial_idx, partial_inputs[partial_idx]); | |||
| } | |||
| for (auto &subgraph_node : subgraph_nodes) { | |||
| @@ -350,6 +330,13 @@ STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schem | |||
| } | |||
| } | |||
| } | |||
| std::sort(tmp_inputs_order.begin(), tmp_inputs_order.end(), | |||
| [](std::pair<int, int> a, std::pair<int, int> b) { return a.first < b.first; }); | |||
| std::vector<int> new_subgraph_inputs{}; | |||
| std::transform(tmp_inputs_order.begin(), tmp_inputs_order.end(), std::back_inserter(new_subgraph_inputs), | |||
| [](std::pair<int, int> iter) { return iter.second; }); | |||
| subgraph_inputs.assign(new_subgraph_inputs.begin(), new_subgraph_inputs.end()); | |||
| return RET_OK; | |||
| @@ -362,17 +349,28 @@ STATUS SingleSwitchPass::UpdateSubgraphOutput(const size_t &subgraph_index, sche | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| auto &partial_outputs = partial_node->outputIndex; | |||
| auto &subgraph_outputs = graph_->subGraph.at(subgraph_index)->outputIndices; | |||
| auto &subgraph = graph_->subGraph.at(subgraph_index); | |||
| auto &subgraph_outputs = subgraph->outputIndices; | |||
| std::map<int, int> subgraph_output_map; | |||
| std::vector<int> new_subgraph_outputs{}; | |||
| std::vector<std::pair<int, int>> tmp_outputs_order{}; | |||
| for (unsigned int &subgraph_output : subgraph_outputs) { | |||
| auto &tensor = graph_->allTensors.at(subgraph_output); | |||
| // get parameter input index k. subgraph name + “_output_" + "k" | |||
| char k = tensor->name[graph_->subGraph.at(subgraph_index)->name.size() + 8]; | |||
| int partial_idx = k - '0'; | |||
| subgraph_output_map.insert(std::pair<int, int>{subgraph_output, partial_outputs[partial_idx]}); | |||
| new_subgraph_outputs.push_back(partial_outputs[partial_idx]); | |||
| for (auto &node : subgraph_nodes) { | |||
| if (IsContain(node->outputIndex, subgraph_output)) { | |||
| int partial_idx = -1; | |||
| if (node->name == "LogicalAnd") { | |||
| partial_idx = 0; | |||
| } else { | |||
| // get parameter input index k. subgraph name + “_output_" + "k" | |||
| auto pos = subgraph->name.size() + sizeof("_output_"); | |||
| auto pos2 = node->name.find('_', pos); | |||
| auto idx_str = node->name.substr(pos - 1, pos2); | |||
| partial_idx = std::stoi(idx_str); | |||
| } | |||
| subgraph_output_map.insert(std::pair<int, int>{subgraph_output, partial_outputs[partial_idx]}); | |||
| tmp_outputs_order.emplace_back(partial_idx, partial_outputs[partial_idx]); | |||
| } | |||
| } | |||
| } | |||
| for (auto &subgraph_node : subgraph_nodes) { | |||
| @@ -382,6 +380,10 @@ STATUS SingleSwitchPass::UpdateSubgraphOutput(const size_t &subgraph_index, sche | |||
| } | |||
| } | |||
| } | |||
| std::vector<int> new_subgraph_outputs{}; | |||
| std::transform(tmp_outputs_order.begin(), tmp_outputs_order.end(), std::back_inserter(new_subgraph_outputs), | |||
| [](std::pair<int, int> iter) { return iter.second; }); | |||
| subgraph_outputs.assign(new_subgraph_outputs.begin(), new_subgraph_outputs.end()); | |||
| return RET_OK; | |||
| @@ -416,102 +418,6 @@ STATUS SingleSwitchPass::ConcatBodySubgraphInputAndOutput() { | |||
| return ret; | |||
| } | |||
| STATUS SingleSwitchPass::ConvertSwitchToSelect() { | |||
| MS_ASSERT(switch_node_->inputIndex.size() >= 3); | |||
| MS_ASSERT(switch_node_->inputIndex.size() % 2 != 0); | |||
| MS_ASSERT(switch_node_->outputIndex.size() * 2 + 1 == switch_node_->inputIndex.size()); | |||
| auto bool_index = switch_node_->inputIndex.front(); | |||
| // insert switch node1 | |||
| auto switch_node1 = std::make_unique<CNodeT>(); | |||
| switch_node1->name = switch_node_->name + "-Switch-1"; | |||
| switch_node1->primitive = std::make_unique<PrimitiveT>(); | |||
| switch_node1->primitive->value.type = PrimitiveType_Switch; | |||
| switch_node1->primitive->value.value = new (std::nothrow) SwitchT(); | |||
| switch_node1->inputIndex = {bool_index}; | |||
| std::vector<int> part_one_input_index( | |||
| switch_node_->inputIndex.begin() + 1, | |||
| switch_node_->inputIndex.begin() + 1 + (switch_node_->inputIndex.size() - 1) / 2); | |||
| switch_node1->inputIndex.insert(switch_node1->inputIndex.end(), part_one_input_index.begin(), | |||
| part_one_input_index.end()); | |||
| std::vector<std::unique_ptr<TensorT>> switch_output_tensors1(part_one_input_index.size() * 2); | |||
| std::vector<int> switch_output_indexes1(part_one_input_index.size() * 2); | |||
| int i = 0; | |||
| for (const auto &input_index : part_one_input_index) { | |||
| auto &switch_in_tensor = graph_->allTensors.at(input_index); | |||
| auto tensor1 = NewTensor(switch_in_tensor); | |||
| auto tensor2 = NewTensor(switch_in_tensor); | |||
| switch_output_tensors1[i] = std::move(tensor1); | |||
| switch_output_tensors1[part_one_input_index.size() + i] = std::move(tensor2); | |||
| switch_output_indexes1[i] = graph_->allTensors.size() - 1 + i; | |||
| switch_output_indexes1[part_one_input_index.size() + i] = | |||
| graph_->allTensors.size() - 1 + i + part_one_input_index.size(); | |||
| i++; | |||
| } | |||
| for (auto &tensor : switch_output_tensors1) { | |||
| graph_->allTensors.emplace_back(std::move(tensor)); | |||
| } | |||
| switch_node1->outputIndex.insert(switch_node1->outputIndex.begin(), switch_output_indexes1.begin(), | |||
| switch_output_indexes1.end()); | |||
| // insert switch node2 | |||
| auto switch_node2 = std::make_unique<CNodeT>(); | |||
| switch_node2->name = switch_node_->name + "-Switch-1"; | |||
| switch_node2->primitive = std::make_unique<PrimitiveT>(); | |||
| switch_node2->primitive->value.type = PrimitiveType_Switch; | |||
| switch_node2->primitive->value.value = new (std::nothrow) SwitchT(); | |||
| switch_node2->inputIndex = {bool_index}; | |||
| std::vector<int> part_two_input_index( | |||
| switch_node_->inputIndex.begin() + 1 + (switch_node_->inputIndex.size() - 1) / 2, switch_node_->inputIndex.end()); | |||
| switch_node2->inputIndex.insert(switch_node2->inputIndex.end(), part_two_input_index.begin(), | |||
| part_two_input_index.end()); | |||
| std::vector<std::unique_ptr<TensorT>> switch_output_tensors2(part_two_input_index.size() * 2); | |||
| std::vector<int> switch_output_indexes2(part_two_input_index.size() * 2); | |||
| i = 0; | |||
| for (const auto &input_index : part_two_input_index) { | |||
| auto &switch_in_tensor = graph_->allTensors.at(input_index); | |||
| auto tensor1 = NewTensor(switch_in_tensor); | |||
| auto tensor2 = NewTensor(switch_in_tensor); | |||
| switch_output_tensors2[i] = std::move(tensor1); | |||
| switch_output_tensors2[part_two_input_index.size() + i] = std::move(tensor2); | |||
| switch_output_indexes2[i] = graph_->allTensors.size() - 1 + i; | |||
| switch_output_indexes2[part_two_input_index.size() + i] = | |||
| graph_->allTensors.size() - 1 + i + part_two_input_index.size(); | |||
| i++; | |||
| } | |||
| for (auto &tensor : switch_output_tensors2) { | |||
| graph_->allTensors.emplace_back(std::move(tensor)); | |||
| } | |||
| switch_node2->outputIndex.insert(switch_node2->outputIndex.begin(), switch_output_indexes2.begin(), | |||
| switch_output_indexes2.end()); | |||
| // insert merge | |||
| auto merge_node = std::make_unique<CNodeT>(); | |||
| merge_node->name = switch_node_->name + "-Merge"; | |||
| merge_node->primitive = std::make_unique<PrimitiveT>(); | |||
| merge_node->primitive->value.type = PrimitiveType_Merge; | |||
| merge_node->primitive->value.value = new (std::nothrow) MergeT(); | |||
| std::vector<int> merge_input_indexes(switch_node_->outputIndex.size() * 2); | |||
| for (i = 0; i < switch_node_->outputIndex.size(); i++) { | |||
| merge_input_indexes[i] = switch_output_indexes1[i]; | |||
| merge_input_indexes[i + switch_node_->outputIndex.size()] = | |||
| switch_output_indexes2[i + switch_node_->outputIndex.size()]; | |||
| merge_node->outputIndex.emplace_back(switch_node_->outputIndex.at(i)); | |||
| } | |||
| merge_node->inputIndex.insert(merge_node->inputIndex.end(), merge_input_indexes.begin(), merge_input_indexes.end()); | |||
| graph_->nodes.emplace_back(std::move(switch_node1)); | |||
| graph_->subGraph.at(this_subgraph_index_)->nodeIndices.push_back(graph_->nodes.size() - 1); | |||
| graph_->nodes.emplace_back(std::move(switch_node2)); | |||
| graph_->subGraph.at(this_subgraph_index_)->nodeIndices.push_back(graph_->nodes.size() - 1); | |||
| graph_->nodes.emplace_back(std::move(merge_node)); | |||
| graph_->subGraph.at(this_subgraph_index_)->nodeIndices.push_back(graph_->nodes.size() - 1); | |||
| RemoveUselessNode(switch_node_, graph_); | |||
| return RET_OK; | |||
| } | |||
| STATUS SingleSwitchPass::Run() { | |||
| int ret = Init(); | |||
| if (ret != RET_OK) { | |||
| @@ -519,24 +425,6 @@ STATUS SingleSwitchPass::Run() { | |||
| return ret; | |||
| } | |||
| if (switch_node_->inputIndex.size() == kSwitchMinInputSize) { | |||
| return RET_OK; | |||
| } | |||
| if (cond_partial_node_->primitive->value.type != PrimitiveType_Partial || | |||
| body_partial_node_->primitive->value.type != PrimitiveType_Partial) { | |||
| ret = ConvertSwitchToSelect(); | |||
| return ret; | |||
| } | |||
| if (IsLoop()) { | |||
| ret = MoveMaxIterationToCond(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "MoveMaxIterationToCond failed, ret: " << ret; | |||
| return ret; | |||
| } | |||
| } | |||
| ret = DoubleSwitchOutput(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "DoubleSwitchOutput failed, ret: " << ret; | |||
| @@ -45,11 +45,9 @@ class SingleSwitchPass { | |||
| STATUS Init(); | |||
| size_t InitThisGraphIndex(); | |||
| STATUS DoubleSwitchOutput(); | |||
| STATUS MoveMaxIterationToCond(); | |||
| STATUS UpdateSwitchUser(); | |||
| STATUS ConcatCondSubgraphInputAndOutput(); | |||
| STATUS ConcatBodySubgraphInputAndOutput(); | |||
| STATUS ConvertSwitchToSelect(); | |||
| bool IsLoop(); | |||
| STATUS InsertMerge(); | |||
| STATUS UpdateSubgraphInput(const size_t &subgraph_index, schema::CNodeT *partial_node, | |||
| @@ -27,56 +27,71 @@ namespace mindspore { | |||
| namespace lite { | |||
| STATUS TopologicalSortPass::Run(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| std::vector<std::unique_ptr<schema::CNodeT>> newNodes; | |||
| std::vector<size_t> sinkedTensorIdxes; | |||
| // put all const tensor index into sinkedTensorIdxes | |||
| std::vector<std::unique_ptr<schema::CNodeT>> new_nodes; | |||
| std::vector<size_t> sinked_tensor_idxes; | |||
| // put all const tensor index into sinked_tensor_idxes | |||
| for (size_t i = 0; i < graph->allTensors.size(); i++) { | |||
| if (graph->allTensors.at(i)->nodeType == schema::NodeType::NodeType_ValueNode) { | |||
| sinkedTensorIdxes.insert(sinkedTensorIdxes.end(), i); | |||
| sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), i); | |||
| } | |||
| } | |||
| auto &oldNodes = graph->nodes; | |||
| std::queue<std::unique_ptr<schema::CNodeT>> opQueue; | |||
| // put all non depend node into queue | |||
| for (auto &node : graph->nodes) { | |||
| if (IsNodeNonDepend(node, sinkedTensorIdxes)) { | |||
| sinkedTensorIdxes.insert(sinkedTensorIdxes.end(), node->outputIndex.begin(), node->outputIndex.end()); | |||
| opQueue.push(std::move(node)); | |||
| auto &old_nodes = graph->nodes; | |||
| std::queue<std::unique_ptr<schema::CNodeT>> op_queue; | |||
| // put all none depend node into queue | |||
| for (size_t i = 0; i < graph->subGraph.size(); i++) { | |||
| std::vector<unsigned int> new_subgraph_node_indices = {}; | |||
| auto subgraph_node_indices = graph->subGraph[i]->nodeIndices; | |||
| for (size_t j = 0; j < subgraph_node_indices.size(); j++) { | |||
| auto &node = old_nodes[subgraph_node_indices[j]]; | |||
| if (IsNodeNonDepend(node, sinked_tensor_idxes)) { | |||
| sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), node->outputIndex.begin(), node->outputIndex.end()); | |||
| op_queue.push(std::move(node)); | |||
| } | |||
| } | |||
| } | |||
| // bfs | |||
| while (!opQueue.empty()) { | |||
| auto &node = opQueue.front(); | |||
| auto postNodeIdxes = GetOutputNodeIdx(*graph, *(node.get())); | |||
| for (auto postNodeIdx : postNodeIdxes) { | |||
| auto &postNode = oldNodes.at(postNodeIdx); | |||
| // check if postNode is non-depended | |||
| if (IsNodeNonDepend(postNode, sinkedTensorIdxes)) { | |||
| sinkedTensorIdxes.insert(sinkedTensorIdxes.end(), postNode->outputIndex.begin(), postNode->outputIndex.end()); | |||
| opQueue.push(std::move(postNode)); | |||
| while (!op_queue.empty()) { | |||
| auto &node = op_queue.front(); | |||
| auto post_node_idxes = GetOutputNodeIdx(*graph, *(node.get())); | |||
| for (auto post_node_idx : post_node_idxes) { | |||
| if (IsContain(subgraph_node_indices, (unsigned int)(post_node_idx))) { | |||
| auto &post_node = old_nodes.at(post_node_idx); | |||
| // check if post_node is non-depended | |||
| if (IsNodeNonDepend(post_node, sinked_tensor_idxes)) { | |||
| sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), post_node->outputIndex.begin(), | |||
| post_node->outputIndex.end()); | |||
| op_queue.push(std::move(post_node)); | |||
| } | |||
| } | |||
| } | |||
| new_nodes.emplace_back(std::move(node)); | |||
| new_subgraph_node_indices.push_back(new_nodes.size() - 1); | |||
| op_queue.pop(); | |||
| } | |||
| newNodes.emplace_back(std::move(node)); | |||
| opQueue.pop(); | |||
| graph->subGraph[i]->nodeIndices.swap(new_subgraph_node_indices); | |||
| } | |||
| if (newNodes.size() != oldNodes.size()) { | |||
| MS_LOG(ERROR) << "Unknow error in TopologicalSort, oldNodesSize: " << oldNodes.size() | |||
| << ", newNodesSize: " << newNodes.size(); | |||
| if (new_nodes.size() != old_nodes.size()) { | |||
| MS_LOG(ERROR) << "Unknow error in TopologicalSort, old_nodes size: " << old_nodes.size() | |||
| << ", new_nodes size: " << new_nodes.size(); | |||
| return RET_ERROR; | |||
| } | |||
| graph->nodes.swap(newNodes); | |||
| graph->nodes.swap(new_nodes); | |||
| return RET_OK; | |||
| } | |||
| bool TopologicalSortPass::IsNodeNonDepend(const std::unique_ptr<schema::CNodeT> &node, | |||
| const std::vector<size_t> &sinkedTensorIdxes) { | |||
| const std::vector<size_t> &sinked_tensor_idxes) { | |||
| MS_ASSERT(node != nullptr); | |||
| for (auto inputIdx : node->inputIndex) { | |||
| if (!IsContain(sinkedTensorIdxes, size_t(inputIdx))) { | |||
| return false; | |||
| } | |||
| if (node->primitive->value.type == schema::PrimitiveType_Merge) { | |||
| auto node_input_index = node->inputIndex; | |||
| MS_ASSERT(node_input_index.size() % 2 == 0); | |||
| return std::all_of(node_input_index.begin(), node_input_index.begin() + node_input_index.size() / 2, | |||
| [&](size_t input_idx) { return IsContain(sinked_tensor_idxes, input_idx); }) || | |||
| std::all_of(node_input_index.begin() + node_input_index.size() / 2, node_input_index.end(), | |||
| [&](size_t input_idx) { return IsContain(sinked_tensor_idxes, input_idx); }); | |||
| } else { | |||
| return std::all_of(node->inputIndex.begin(), node->inputIndex.end(), | |||
| [&](size_t input_idx) { return IsContain(sinked_tensor_idxes, size_t(input_idx)); }); | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -54,6 +54,7 @@ FuncGraphPtr CaffeModelParser::Parse(const std::string &model_file, const std::s | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| func_graph_ptr_->set_attr("graph_name", MakeValue("main_graph")); | |||
| return func_graph_ptr_; | |||
| } | |||
| @@ -80,6 +80,7 @@ FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::st | |||
| MS_LOG(ERROR) << "convert graph outputs failed."; | |||
| return nullptr; | |||
| } | |||
| func_graph_ptr_->set_attr("graph_name", MakeValue("main_graph")); | |||
| return func_graph_ptr_; | |||
| } | |||
| @@ -61,7 +61,7 @@ STATUS TFArithmeticParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Mul; | |||
| primitive->value.value = attr.release(); | |||
| } else if (tf_op.op() == "Div") { | |||
| } else if (tf_op.op() == "Div" || tf_op.op() == "RealDiv") { | |||
| auto attr = std::make_unique<schema::DivT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| @@ -154,6 +154,7 @@ TFNodeRegistrar g_tfAddV2Parser("AddV2", new TFArithmeticParser()); | |||
| TFNodeRegistrar g_tfSubParser("Sub", new TFArithmeticParser()); | |||
| TFNodeRegistrar g_tfMulParser("Mul", new TFArithmeticParser()); | |||
| TFNodeRegistrar g_tfDivParser("Div", new TFArithmeticParser()); | |||
| TFNodeRegistrar g_tfRealDivParser("RealDiv", new TFArithmeticParser()); | |||
| TFNodeRegistrar g_tfMaximumParser("Maximum", new TFArithmeticParser()); | |||
| TFNodeRegistrar g_tfMinimumParser("Minimum", new TFArithmeticParser()); | |||
| TFNodeRegistrar g_tfGreaterParser("Greater", new TFArithmeticParser()); | |||
| @@ -37,10 +37,11 @@ static const std::vector<schema::PrimitiveType> tensorListOutputOpList = { | |||
| AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map<std::string, AnfNodePtr> &anf_node_map) { | |||
| AnfNodePtr ret = nullptr; | |||
| if (anf_node_map.find(name) != anf_node_map.end()) { | |||
| ret = anf_node_map.at(name); | |||
| auto flat_anf_name = TensorFlowUtils::GetFlattenNodeName(name); | |||
| if (anf_node_map.find(flat_anf_name) != anf_node_map.end()) { | |||
| ret = anf_node_map.at(flat_anf_name); | |||
| } else if (anf_node_map.find(name + ":0") != anf_node_map.end()) { | |||
| ret = anf_node_map.at(name + ":0"); | |||
| ret = anf_node_map.at(flat_anf_name + ":0"); | |||
| } | |||
| return ret; | |||
| } | |||
| @@ -212,6 +213,17 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| } else if (type == kObjectTypeString) { | |||
| auto tensor_data = new (std::nothrow) string; | |||
| if (tensor_proto.string_val_size() == 1) { | |||
| string value = tensor_proto.string_val(0); | |||
| *tensor_data = value; | |||
| } else { | |||
| MS_LOG(ERROR) << "string size bigger than one, not support."; | |||
| return RET_ERROR; | |||
| } | |||
| tensor_size = (*tensor_data).size(); | |||
| param_value->SetTensorData(tensor_data, tensor_size); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupport dataType: " << type; | |||
| return RET_ERROR; | |||
| @@ -318,6 +330,7 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return nullptr; | |||
| } | |||
| anf_root_graph_->set_attr("graph_name", MakeValue("main_graph")); | |||
| for (int i = 0; i < tf_root_graph_->node_size(); i++) { | |||
| auto &node_def = tf_root_graph_->node(i); | |||
| @@ -364,7 +377,6 @@ STATUS TFModelParser::ConvertSubgraph() { | |||
| std::map<CNodePtr, FuncGraphPtr> while_cond_map; | |||
| std::map<CNodePtr, FuncGraphPtr> while_body_map; | |||
| for (int i = 0; i < subgraph_size; i++) { | |||
| std::vector<ParameterPtr> sub_graph_inputs; | |||
| auto &tf_sub_fuction = graph_def_liarary.function(i); | |||
| auto &tf_sub_signature = tf_sub_fuction.signature(); | |||
| auto input_arg_size = tf_sub_signature.input_arg_size(); | |||
| @@ -381,13 +393,17 @@ STATUS TFModelParser::ConvertSubgraph() { | |||
| } | |||
| FuncGraphPtr sub_func_graph = std::make_shared<FuncGraph>(); | |||
| sub_func_graph->set_attr("graph_name", MakeValue(sub_graph_name)); | |||
| std::unordered_map<std::string, AnfNodePtr> anf_sub_node_map; | |||
| // convert sub graph inputs | |||
| std::vector<ParameterPtr> sub_graph_inputs; | |||
| for (int j = 0; j < input_arg_size; j++) { | |||
| auto &input_arg = tf_sub_signature.input_arg(j); | |||
| auto paramter = sub_func_graph->add_parameter(); | |||
| paramter->set_name(input_arg.name()); | |||
| anf_sub_node_map[input_arg.name()] = paramter; | |||
| auto root_while_inputs = while_cnode->inputs(); | |||
| paramter->set_abstract(root_while_inputs[j + 1]->abstract()); | |||
| sub_graph_inputs.emplace_back(paramter); | |||
| } | |||
| std::map<std::string, const tensorflow::NodeDef *> tf_sub_node_map; | |||
| @@ -452,8 +468,19 @@ STATUS TFModelParser::ConvertSubgraph() { | |||
| } | |||
| // hardcode subgraph inputs name | |||
| for (size_t j = 0; j < sub_graph_inputs.size(); j++) { | |||
| sub_graph_inputs[j]->set_name("graph" + std::to_string(i) + "_input_" + std::to_string(j) + "parameter"); | |||
| sub_graph_inputs[j]->set_name(sub_graph_name + "_input_" + std::to_string(j) + "_parameter"); | |||
| } | |||
| // hardcode subgraph outputs name | |||
| for (size_t j = 1; j < sub_output_nodes.size(); j++) { | |||
| if (utils::isa<CNodePtr>(sub_output_nodes[j])) { | |||
| sub_output_nodes[j]->cast<CNodePtr>()->set_fullname_with_scope(sub_graph_name + "_output_" + | |||
| std::to_string(j - 1) + "_cnode"); | |||
| } else if (utils::isa<ParameterPtr>(sub_output_nodes[j])) { | |||
| sub_output_nodes[j]->cast<ParameterPtr>()->set_name(sub_graph_name + "_output_" + std::to_string(j - 1) + | |||
| "_parameter"); | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "parse subgraph end:" << sub_graph_name; | |||
| } | |||
| auto status = WhileNodePostProcess(while_cond_map, while_body_map); | |||
| @@ -469,9 +496,8 @@ STATUS TFModelParser::WhileNodePostProcess(const std::map<CNodePtr, FuncGraphPtr | |||
| MS_LOG(ERROR) << "while cond body size error"; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<FuncGraphPtr> roots = {anf_root_graph_}; | |||
| auto root_func_manager = std::make_shared<FuncGraphManager>(roots); | |||
| anf_root_graph_->set_manager(root_func_manager); | |||
| static auto root_func_manager = Manage(anf_root_graph_); | |||
| for (auto &kv : while_cond_map) { | |||
| auto while_node = kv.first; | |||
| auto &cond_sub_graph = kv.second; | |||
| @@ -633,6 +659,11 @@ STATUS TFModelParser::ConvertRootGraphOutputs() { | |||
| for (auto &pair : tf_root_graph_nodes_) { | |||
| for (int i = 0; i < pair.second->input_size(); ++i) { | |||
| all_node_inputs.insert(TensorFlowUtils::GetNodeName(pair.second->input(i))); | |||
| auto input_name = pair.second->input(i); | |||
| if (input_name[0] == '^') { | |||
| input_name.erase(0, 1); | |||
| } | |||
| all_node_inputs.insert(input_name); | |||
| } | |||
| } | |||
| for (auto &pair : tf_root_graph_nodes_) { | |||
| @@ -644,7 +675,7 @@ STATUS TFModelParser::ConvertRootGraphOutputs() { | |||
| auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes_); | |||
| auto anf_node = GetAnfNode(origin_name, anf_root_node_map_); | |||
| if (anf_node == nullptr) { | |||
| MS_LOG(ERROR) << "can't find anf node"; | |||
| MS_LOG(ERROR) << "can't find anf node: " << origin_name; | |||
| return RET_ERROR; | |||
| } | |||
| output_nodes.push_back(anf_node); | |||
| @@ -0,0 +1,78 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_ragged_range_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFRaggedRangeParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(INFO) << "TF RaggedRangeParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto attr = std::make_unique<schema::RangeT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "starts", &attr_value)) { | |||
| MS_LOG(ERROR) << "The starts attr should be specified"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->start = static_cast<int32_t>(attr_value.i()); | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "limits", &attr_value)) { | |||
| MS_LOG(ERROR) << "The limits attr should be specified"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->limit = static_cast<int32_t>(attr_value.i()); | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "deltas", &attr_value)) { | |||
| MS_LOG(ERROR) << "The deltas attr should be specified"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->delta = static_cast<int32_t>(attr_value.i()); | |||
| primitive->value.type = schema::PrimitiveType_Range; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = 1; | |||
| auto status = AddOpInput(tf_op, 0, inputs); | |||
| return status; | |||
| } | |||
| TFNodeRegistrar g_tfRaggedRangeParser("RaggedRange", new TFRaggedRangeParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RAGFED_RANGE_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RAGGED_RANGE_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFRaggedRangeParser : public TFNodeParser { | |||
| public: | |||
| TFRaggedRangeParser() = default; | |||
| ~TFRaggedRangeParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ROUND_PARSER_H_ | |||
| @@ -0,0 +1,78 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_range_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFRangeParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(INFO) << "TF RangeParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto attr = std::make_unique<schema::RangeT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "start", &attr_value)) { | |||
| MS_LOG(ERROR) << "The start attr should be specified"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->start = static_cast<int32_t>(attr_value.i()); | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "limit", &attr_value)) { | |||
| MS_LOG(ERROR) << "The limit attr should be specified"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->limit = static_cast<int32_t>(attr_value.i()); | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "delta", &attr_value)) { | |||
| MS_LOG(ERROR) << "The delta attr should be specified"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->delta = static_cast<int32_t>(attr_value.i()); | |||
| primitive->value.type = schema::PrimitiveType_Range; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = 1; | |||
| auto status = AddOpInput(tf_op, 0, inputs); | |||
| return status; | |||
| } | |||
| TFNodeRegistrar g_tfRangeParser("Range", new TFRangeParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANGE_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANGE_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFRangeParser : public TFNodeParser { | |||
| public: | |||
| TFRangeParser() = default; | |||
| ~TFRangeParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ROUND_PARSER_H_ | |||
| @@ -9,7 +9,7 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WRRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| @@ -18,8 +18,8 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| #include <string_view> | |||
| #include <unordered_map> | |||
| #include <regex> | |||
| #include <unordered_map> | |||
| #include "src/common/log_adapter.h" | |||
| #include "schema/inner/model_generated.h" | |||
| @@ -76,6 +76,7 @@ FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std:: | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| func_graph_->set_attr("graph_name", MakeValue("main_graph")); | |||
| return func_graph_; | |||
| } | |||
| @@ -21,7 +21,22 @@ | |||
| #include "mindspore/lite/src/ops/primitive_c.h" | |||
| #include "tools/anf_importer/import_from_meta_graphT.h" | |||
| using mindspore::lite::RET_INFER_INVALID; | |||
| namespace mindspore::opt { | |||
| ParamValueLitePtr NewParamValueLitePtr(lite::Tensor *tensor) { | |||
| auto para_value_lite = std::make_shared<ParamValueLite>(); | |||
| if (para_value_lite == nullptr) { | |||
| MS_LOG(ERROR) << "new ParamValueLite failed"; | |||
| return nullptr; | |||
| } | |||
| para_value_lite->set_tensor_shape(tensor->shape()); | |||
| para_value_lite->set_tensor_type(tensor->data_type()); | |||
| para_value_lite->set_format(tensor->format()); | |||
| return para_value_lite; | |||
| } | |||
| abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor) { | |||
| MS_ASSERT(nullptr != tensor); | |||
| std::vector<int> shape(tensor->shape()); | |||
| @@ -33,15 +48,30 @@ abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(li | |||
| MS_LOG(ERROR) << "new AbstractTensor failed"; | |||
| return nullptr; | |||
| } | |||
| auto new_value = std::make_shared<ParamValueLite>(); | |||
| if (new_value == nullptr) { | |||
| auto para_value_lite = NewParamValueLitePtr(tensor); | |||
| if (para_value_lite == nullptr) { | |||
| MS_LOG(ERROR) << "new ParamValueLite failed"; | |||
| return nullptr; | |||
| } | |||
| new_value->set_tensor_shape(tensor->shape()); | |||
| new_value->set_tensor_type(tensor->data_type()); | |||
| new_value->set_format(tensor->format()); | |||
| new_abstract->set_value(new_value); | |||
| if (type_id == kObjectTypeTensorType) { | |||
| auto tensor_list = dynamic_cast<lite::TensorList *>(tensor); | |||
| if (tensor_list == nullptr) { | |||
| MS_LOG(ERROR) << "cast tensor_list failed"; | |||
| return nullptr; | |||
| } | |||
| auto tensor_info = new int[tensor_list->element_shape().size() + 2]; | |||
| tensor_info[0] = tensor_list->tensors_data_type(); | |||
| tensor_info[1] = tensor_list->element_shape().size(); | |||
| for (size_t i = 0; i < tensor_list->element_shape().size(); ++i) { | |||
| tensor_info[i + 2] = tensor_list->element_shape()[i]; | |||
| } | |||
| para_value_lite->set_tensor_addr(tensor_info); | |||
| para_value_lite->set_tensor_size(tensor_list->element_shape().size() + 2); | |||
| } | |||
| new_abstract->set_value(para_value_lite); | |||
| return new_abstract; | |||
| } | |||
| @@ -121,13 +151,13 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<l | |||
| } | |||
| if (utils::isa<ValueNodePtr>(cnode->input(i))) { | |||
| MS_LOG(WARNING) << "input is value node"; | |||
| MS_LOG(WARNING) << cnode->fullname_with_scope() << "'s input[" << i << "] is value node"; | |||
| continue; | |||
| } | |||
| AbstractBasePtr abstract = GetCNodeInputAbstract(cnode, i); | |||
| if (abstract == nullptr) { | |||
| MS_LOG(ERROR) << "Abstract of CNode is nullptr"; | |||
| MS_LOG(ERROR) << "Abstract of CNode: " << cnode->fullname_with_scope() << " is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) { | |||
| @@ -194,7 +224,7 @@ STATUS InferShapePass::GetCNodeOutputTensors(const CNodePtr &cnode, std::vector< | |||
| MS_ASSERT(output_tensors != nullptr); | |||
| auto abstract = cnode->abstract(); | |||
| if (abstract == nullptr) { | |||
| MS_LOG(ERROR) << "abstract is nullptr"; | |||
| MS_LOG(ERROR) << "node " << cnode->fullname_with_scope() << " abstract is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<TypeId> types; | |||
| @@ -264,7 +294,62 @@ STATUS InferShapePass::SetCNodeAbstract(const std::vector<lite::Tensor *> &outpu | |||
| return RET_OK; | |||
| } | |||
| int InferShapePass::StrIsContain(const std::vector<std::string> &total, const std::string &aim) { | |||
| for (size_t i = 0; i < total.size(); i++) { | |||
| if (aim.find(total[i]) != std::string::npos) { | |||
| return i; | |||
| } | |||
| } | |||
| return -1; | |||
| } | |||
| STATUS InferShapePass::SetSubGraphInputsAbstract(const CNodePtr &cnode, const FuncGraphPtr &func_graph) { | |||
| // hard code construct input parameter name | |||
| std::vector<std::string> inputs_names{}; | |||
| for (size_t i = 1; i < cnode->inputs().size(); i++) { | |||
| inputs_names.emplace_back("_input_" + std::to_string(i - 1) + "_parameter"); | |||
| } | |||
| // copy cnode input to func_graph input | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| for (auto &node : node_list) { | |||
| if (utils::isa<ParameterPtr>(node)) { | |||
| auto pos = StrIsContain(inputs_names, node->fullname_with_scope()); | |||
| if (pos != -1) { | |||
| auto pnode = utils::cast<ParameterPtr>(node); | |||
| auto input_pnode = utils::cast<ParameterPtr>(cnode->input(pos + 1)); | |||
| MS_ASSERT(pnode != nullptr); | |||
| pnode->set_abstract(input_pnode->abstract()); | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS InferShapePass::SwitchCNodeInferShape(const CNodePtr &switch_cnode) { | |||
| auto body_partial_cnode = switch_cnode->input(2)->cast<CNodePtr>(); | |||
| MS_ASSERT(body_partial_cnode != nullptr); | |||
| auto body_vnode = body_partial_cnode->input(0)->cast<ValueNodePtr>(); | |||
| MS_ASSERT(body_vnode != nullptr); | |||
| auto body_fg = GetValueNode<FuncGraphPtr>(body_vnode); | |||
| MS_ASSERT(body_fg != nullptr); | |||
| AbstractBasePtrList abstract_list; | |||
| auto body_fg_output_cnode = utils::cast<CNodePtr>(body_fg->output()); | |||
| for (auto &cnode : body_fg_output_cnode->inputs()) { | |||
| if (!utils::isa<CNodePtr>(cnode) && !utils::isa<ParameterPtr>(cnode)) { | |||
| continue; | |||
| } | |||
| abstract_list.push_back(cnode->abstract()); | |||
| } | |||
| switch_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | |||
| return RET_OK; | |||
| } | |||
| bool InferShapePass::Run(const FuncGraphPtr &func_graph) { | |||
| if (func_graph->has_flag("HasInferShaped")) { | |||
| return true; | |||
| } | |||
| if (fmk_type != lite::converter::FmkType_TF && fmk_type != lite::converter::FmkType_TFLITE) { | |||
| MS_LOG(INFO) << "The framework type of model should be tf/tflite."; | |||
| return false; | |||
| @@ -287,8 +372,14 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto origin_primc = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(cnode->input(0)); | |||
| if (origin_primc == nullptr) { | |||
| MS_LOG(ERROR) << "origin_primc is nullptr"; | |||
| return false; | |||
| auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(0)); | |||
| if (sub_func_graph == nullptr) { | |||
| MS_LOG(ERROR) << "node " << node->fullname_with_scope() << "'s origin_primc is nullptr"; | |||
| return false; | |||
| } else { | |||
| MS_LOG(WARNING) << "subgraph infer shape invalid."; | |||
| return RET_INFER_INVALID; | |||
| } | |||
| } | |||
| auto origin_primt = origin_primc->primitiveT(); | |||
| if (origin_primt == nullptr) { | |||
| @@ -296,6 +387,15 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) { | |||
| return false; | |||
| } | |||
| auto type = GetCNodeType(cnode); | |||
| if (type == schema::PrimitiveType_Switch) { | |||
| int ret = SwitchCNodeInferShape(cnode); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "PartialCNodeInferShape failed."; | |||
| return false; | |||
| } | |||
| } | |||
| if ((type == schema::PrimitiveType_TupleGetItem) || | |||
| #ifdef SUPPORT_TRAIN | |||
| (type == schema::PrimitiveType_Depend) || (type == schema::PrimitiveType_ControlDepend) || | |||
| @@ -41,6 +41,9 @@ class InferShapePass : public Pass { | |||
| STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *output_tensors); | |||
| STATUS SetParameterAbstract(const ParameterPtr ¶meter); | |||
| STATUS SetCNodeAbstract(const std::vector<lite::Tensor *> &output_tensors, const std::shared_ptr<CNode> &cnode); | |||
| STATUS SwitchCNodeInferShape(const CNodePtr &cnode); | |||
| int StrIsContain(const std::vector<std::string> &total, const std::string &aim); | |||
| int SetSubGraphInputsAbstract(const CNodePtr &cnode, const FuncGraphPtr &func_graph); | |||
| private: | |||
| FmkType fmk_type = lite::converter::FmkType_ONNX; | |||
| @@ -0,0 +1,181 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/graph/while_pass.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include "mindspore/lite/include/errorcode.h" | |||
| #include "mindspore/lite/src/ops/primitive_c.h" | |||
| #include "tools/anf_importer/import_from_meta_graphT.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "src/ops/primitive_c.h" | |||
| #include "schema/inner/model_generated.h" | |||
| #include "src/tensor.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/ops/switch.h" | |||
| #include "src/ops/partial.h" | |||
| namespace mindspore::opt { | |||
| ValueNodePtr WhilePass::GetSwitchAnfPrim() { | |||
| auto switch_primitiveT = new (std::nothrow) schema::PrimitiveT; | |||
| if (switch_primitiveT == nullptr) { | |||
| MS_LOG(ERROR) << "new switch_primitiveT failed"; | |||
| return nullptr; | |||
| } | |||
| switch_primitiveT->value.type = schema::PrimitiveType_Switch; | |||
| switch_primitiveT->value.value = new (std::nothrow) schema::SwitchT; | |||
| if (switch_primitiveT->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "new MakeTupleT failed"; | |||
| return nullptr; | |||
| } | |||
| auto partial_prim = std::make_shared<lite::Partial>(switch_primitiveT); | |||
| ValueNodePtr partial_anf_prim = NewValueNode(partial_prim); | |||
| return partial_anf_prim; | |||
| } | |||
| void WhilePass::ReplaceInput(const std::vector<AnfNodePtr> &node_list, AnfNodePtr new_input_cnode, | |||
| std::string para_name) { | |||
| for (auto &node : node_list) { | |||
| if (utils::isa<CNodePtr>(node)) { | |||
| auto cnode = utils::cast<CNodePtr>(node); | |||
| for (size_t k = 0; k < cnode->inputs().size(); k++) { | |||
| if (!utils::isa<ParameterPtr>(cnode->input(k))) { | |||
| continue; | |||
| } | |||
| auto para_input = utils::cast<ParameterPtr>(cnode->input(k)); | |||
| if (para_input->name() == para_name) { | |||
| cnode->set_input(k, new_input_cnode); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| bool WhilePass::Run(const FuncGraphPtr &graph) { | |||
| auto node_list = TopoSort(graph->get_return()); | |||
| static int count = 0; | |||
| for (auto &node : node_list) { | |||
| if (!utils::isa<CNodePtr>(node)) { | |||
| continue; | |||
| } | |||
| if (opt::GetCNodeType(node) != schema::PrimitiveType_While) { | |||
| continue; | |||
| } | |||
| auto while_cnode = node->cast<CNodePtr>(); | |||
| MS_ASSERT(while_cnode != nullptr); | |||
| if (while_cnode->inputs().size() < kWhileMinInputSize) { | |||
| MS_LOG(ERROR) << "while input is not right."; | |||
| return false; | |||
| } | |||
| // the order is fixed. | |||
| auto cond_vnode = while_cnode->input(kWhileCondIndex); | |||
| auto body_vnode = while_cnode->input(kWhileBodyIndex); | |||
| // body_vnode->cast<ValueNodePtr>()->set_value() | |||
| auto cond_fg = GetValueNode<std::shared_ptr<FuncGraph>>(cond_vnode); | |||
| auto body_fg = GetValueNode<std::shared_ptr<FuncGraph>>(body_vnode); | |||
| if (cond_fg == nullptr || body_fg == nullptr) { | |||
| MS_LOG(ERROR) << "Get value as func_graph failed."; | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_FAILED); | |||
| return false; | |||
| } | |||
| // create cond partial cnode | |||
| std::vector<AnfNodePtr> cond_partial_op_inputs{cond_vnode}; | |||
| // create body partial cnode | |||
| std::vector<AnfNodePtr> body_partial_op_inputs{body_vnode}; | |||
| // add while op input to cond_cnode and body_cnode | |||
| cond_partial_op_inputs.insert(cond_partial_op_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize, | |||
| while_cnode->inputs().end()); | |||
| body_partial_op_inputs.insert(body_partial_op_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize, | |||
| while_cnode->inputs().end()); | |||
| static int idx = 0; | |||
| auto cond_partial_node = graph->NewCNode(cond_partial_op_inputs); | |||
| cond_partial_node->set_fullname_with_scope("Partial-while-cond-" + std::to_string(idx)); | |||
| cond_partial_node->set_abstract(cond_fg->output()->abstract()); | |||
| auto body_partial_node = graph->NewCNode(body_partial_op_inputs); | |||
| body_partial_node->set_fullname_with_scope("Partial-while-body-" + std::to_string(idx)); | |||
| idx++; | |||
| // concat body_fg output to cond_fg input | |||
| auto body_output = body_fg->output(); | |||
| auto body_output_cnode = utils::cast<CNodePtr>(body_output); | |||
| auto prim = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(body_output_cnode->input(0)); | |||
| if (prim == nullptr) { | |||
| MS_LOG(ERROR) << "Get PrimitiveC of node:" << body_output_cnode->fullname_with_scope() << " failed."; | |||
| return false; | |||
| } | |||
| // concat body to cond | |||
| std::vector<AnfNodePtr> body_to_cond_inputs{cond_vnode}; | |||
| if ((schema::PrimitiveType)(prim->Type()) == schema::PrimitiveType_MakeTuple) { | |||
| for (size_t i = 1; i < body_output_cnode->inputs().size(); ++i) { | |||
| body_to_cond_inputs.emplace_back(body_output_cnode->input(i)); | |||
| } | |||
| } else { | |||
| body_to_cond_inputs.emplace_back(body_output_cnode->input(1)); | |||
| } | |||
| // concat body to cond | |||
| auto body_to_cond_cnode = body_fg->NewCNode(body_to_cond_inputs); | |||
| body_to_cond_cnode->set_fullname_with_scope("Partial-while-body-to-cond"); | |||
| auto body_fg_manager = body_fg->manager(); | |||
| body_fg_manager->Replace(body_fg->output(), body_to_cond_cnode); | |||
| body_fg->set_output(body_to_cond_cnode); | |||
| body_partial_node->set_abstract(cond_fg->output()->abstract()); | |||
| // create switch cnode | |||
| ValueNodePtr switch_anf_primitive = GetSwitchAnfPrim(); | |||
| if (switch_anf_primitive == nullptr) { | |||
| MS_LOG(ERROR) << "GetSwitchAnfPrim failed."; | |||
| return false; | |||
| } | |||
| // insert switch node | |||
| std::vector<AnfNodePtr> switch_op_inputs = {switch_anf_primitive, cond_partial_node, body_partial_node}; | |||
| auto switch_cnode = graph->NewCNode(switch_op_inputs); | |||
| switch_cnode->set_fullname_with_scope("Switch-" + std::to_string(count++)); | |||
| AbstractBasePtrList abstract_list; | |||
| auto body_fg_output_cnode = utils::cast<CNodePtr>(body_fg->output()); | |||
| for (auto &cnode : body_fg_output_cnode->inputs()) { | |||
| if (!utils::isa<CNodePtr>(cnode) && !utils::isa<ParameterPtr>(cnode)) { | |||
| continue; | |||
| } | |||
| abstract_list.push_back(cnode->abstract()); | |||
| } | |||
| switch_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | |||
| // create cond partial cnode | |||
| auto manager = graph->manager(); | |||
| auto node_users = manager->node_users()[while_cnode]; | |||
| for (auto &node_user : node_users) { | |||
| manager->SetEdge(node_user.first, node_user.second, switch_cnode); | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace mindspore::opt | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_WHILE_PASS_H_ | |||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_WHILE_PASS_H_ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "schema/inner/model_generated.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "backend/optimizer/common/pass.h" | |||
| #include "src/param_value_lite.h" | |||
| using mindspore::lite::converter::FmkType; | |||
| namespace mindspore::opt { | |||
| class WhilePass : public Pass { | |||
| public: | |||
| WhilePass() : Pass("while_pass") {} | |||
| ~WhilePass() override = default; | |||
| bool Run(const FuncGraphPtr &graph) override; | |||
| private: | |||
| void ReplaceInput(const std::vector<AnfNodePtr> &node_list, AnfNodePtr new_input_cnode, std::string para_name); | |||
| ValueNodePtr GetSwitchAnfPrim(); | |||
| const size_t kWhileMinInputSize = 3; | |||
| const size_t kWhileCondIndex = 1; | |||
| const size_t kWhileBodyIndex = 2; | |||
| }; | |||
| } // namespace mindspore::opt | |||
| #endif // MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_ | |||