Merge pull request !30284 from Nizzan/export_nizzanr1.3
| @@ -142,7 +142,13 @@ set(TRAIN_SRC | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/train/accuracy_monitor.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/train/classification_train_accuracy_monitor.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/train/train_export.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/train/graph_dropout.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../tools/common/storage.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../tools/common/meta_graph_utils.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../tools/converter/optimizer.cc | |||
| ) | |||
| if(ENABLE_V0) | |||
| set(TRAIN_SRC | |||
| @@ -0,0 +1,55 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include "src/train/graph_dropout.h" | |||
| #include "tools/converter/optimizer.h" | |||
| #include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h" | |||
| #include "src/common/utils.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| std::vector<schema::CNodeT *> GetGraphNodes(const schema::MetaGraphT &graph_defT) { | |||
| std::vector<schema::CNodeT *> old_nodes{}; | |||
| old_nodes.resize(graph_defT.nodes.size()); | |||
| std::transform(graph_defT.nodes.begin(), graph_defT.nodes.end(), old_nodes.begin(), | |||
| [](const std::unique_ptr<schema::CNodeT> &node) { return node.get(); }); | |||
| return old_nodes; | |||
| } | |||
| STATUS GraphDropout::Run(schema::MetaGraphT *graph) { | |||
| if (graph == nullptr) { | |||
| MS_LOG(ERROR) << "graph is nullptr."; | |||
| return RET_ERROR; | |||
| } | |||
| Optimizer dropout_optimizer; | |||
| auto old_nodes = GetGraphNodes(*graph); | |||
| dropout_optimizer.AddPass(new (std::nothrow) DropoutNodeRemovePass()); | |||
| dropout_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| dropout_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| auto status = dropout_optimizer.Run(graph); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "graph fusion failed."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,34 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_TRAIN_GRAPH_DROPOUT_H_ | |||
| #define MINDSPORE_LITE_SRC_TRAIN_GRAPH_DROPOUT_H_ | |||
| #include "tools/converter/optimizer.h" | |||
| #include "inner/model_generated.h" | |||
| #include "include/errorcode.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class GraphDropout { | |||
| public: | |||
| GraphDropout() = default; | |||
| ~GraphDropout() = default; | |||
| STATUS Run(schema::MetaGraphT *graph); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_SRC_TRAIN_GRAPH_DROPOUT_H_ | |||
| @@ -23,6 +23,7 @@ | |||
| #include <set> | |||
| #include "schema/inner/model_generated.h" | |||
| #include "src/train/train_utils.h" | |||
| #include "src/train/graph_dropout.h" | |||
| #include "src/common/quant_utils.h" | |||
| #include "tools/common/storage.h" | |||
| @@ -420,6 +421,15 @@ int TrainExport::IsInputTensor(const schema::TensorT &t) { | |||
| return ((t.data.size() == 0) && (total_dims != 0)); | |||
| } | |||
| int TrainExport::TrainModelDrop() { | |||
| GraphDropout graph_dropout; | |||
| auto status = graph_dropout.Run(meta_graph_); | |||
| if (status != RET_OK) { | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| TrainExport::~TrainExport() { delete meta_graph_; } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -47,6 +47,7 @@ class TrainExport { | |||
| void set_connect(const std::unordered_map<size_t, size_t> &map) { connect_ = map; } | |||
| int LoadModel(void *buf, size_t buf_size); | |||
| int AddTransformNode(); | |||
| int TrainModelDrop(); | |||
| protected: | |||
| virtual std::vector<uint8_t> CreateData(const mindspore::lite::Tensor *tensor); | |||
| @@ -716,6 +716,13 @@ int TrainSession::Export(const std::string &file_name, ModelType model_type, Qua | |||
| MS_LOG(ERROR) << "cannot export Network"; | |||
| return status; | |||
| } | |||
| if (model_type == MT_INFERENCE) { | |||
| status = texport.TrainModelDrop(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "TrainModelDrop failed."; | |||
| return status; | |||
| } | |||
| } | |||
| status = texport.SaveToFile(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "failed to save to " << file_name; | |||
| @@ -270,6 +270,7 @@ if(MSLITE_ENABLE_CONVERTER) | |||
| ${LITE_DIR}/tools/optimizer/parallel/spliter.cc | |||
| ${LITE_DIR}/tools/optimizer/parallel/split_strategy.cc | |||
| ${LITE_DIR}/tools/common/graph_util.cc | |||
| ${LITE_DIR}/tools/common/meta_graph_utils.cc | |||
| ${LITE_DIR}/tools/common/tensor_util.cc | |||
| ${LITE_DIR}/tools/common/node_util.cc | |||
| ${LITE_DIR}/tools/common/storage.cc | |||
| @@ -281,6 +282,15 @@ if(MSLITE_ENABLE_CONVERTER) | |||
| ${LITE_DIR}/tools/converter/import/primitive_adjust.cc | |||
| ${LITE_DIR}/tools/converter/import/mindir_adjust.cc | |||
| ) | |||
| else() | |||
| set(TEST_LITE_SRC | |||
| ${TEST_LITE_SRC} | |||
| ${LITE_DIR}/tools/common/meta_graph_utils.cc | |||
| ${LITE_DIR}/tools/converter/optimizer.cc | |||
| ${LITE_DIR}/tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.cc | |||
| ${LITE_DIR}/tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.cc | |||
| ${LITE_DIR}/tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc | |||
| ) | |||
| endif() | |||
| ### train | |||
| if(SUPPORT_TRAIN) | |||
| @@ -292,6 +302,7 @@ if(SUPPORT_TRAIN) | |||
| ${LITE_DIR}/src/train/train_export.cc | |||
| ${LITE_DIR}/src/train/train_utils.cc | |||
| ${LITE_DIR}/src/train/transfer_session.cc | |||
| ${LITE_DIR}/src/train/graph_dropout.cc | |||
| ${LITE_DIR}/src/lite_session.cc | |||
| ${LITE_DIR}/tools/common/storage.cc | |||
| ) | |||
| @@ -37,6 +37,7 @@ | |||
| #include "tools/converter/quantizer/bitpacking.h" | |||
| #include "src/common/utils.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/common/meta_graph_utils.h" | |||
| #include "src/ops/ops_utils.h" | |||
| #include "tools/common/node_util.h" | |||
| #include "tools/converter/converter_context.h" | |||
| @@ -2,6 +2,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}) | |||
| file(GLOB CONVERTER_COMMON_SRC | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/graph_util.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/meta_graph_utils.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/node_util.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/tensor_util.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/storage.cc | |||
| @@ -24,6 +24,7 @@ | |||
| #include "tools/common/tensor_util.h" | |||
| #include "tools/converter/quantizer/bitpacking.h" | |||
| #include "tools/common/node_util.h" | |||
| #include "tools/common/meta_graph_utils.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/common/utils.h" | |||
| #include "tools/converter/ops/ops_def.h" | |||
| @@ -69,315 +70,6 @@ OpDefCopyer GetSimpleOpCopyer() { | |||
| }; | |||
| } | |||
| std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, const int inputIndexIdx) { | |||
| return GetInputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), inputIndexIdx); | |||
| } | |||
| std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int inputIndexIdx) { | |||
| std::vector<uint32_t> inputIndexes; | |||
| if (inputIndexIdx == -1) { | |||
| inputIndexes = node.inputIndex; | |||
| } else { | |||
| MS_ASSERT(node.inputIndex.size() > inputIndexIdx); | |||
| inputIndexes.emplace_back(node.inputIndex.at(inputIndexIdx)); | |||
| } | |||
| std::set<size_t> inputNodeIdx; | |||
| for (uint32_t inputIdx : inputIndexes) { | |||
| auto linkedPreIdx = GetLinkedPreIdx(graphT, inputIdx); | |||
| inputNodeIdx.insert(linkedPreIdx.begin(), linkedPreIdx.end()); | |||
| } | |||
| std::vector<size_t> ret; | |||
| ret.insert(ret.end(), inputNodeIdx.begin(), inputNodeIdx.end()); | |||
| return ret; | |||
| } | |||
| std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, | |||
| const int outputIndexIdx) { | |||
| return GetOutputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), outputIndexIdx); | |||
| } | |||
| void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT) { | |||
| std::replace_if( | |||
| std::begin(graphT->outputIndex), std::end(graphT->outputIndex), | |||
| [&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index); | |||
| for (auto &subGraph : graphT->subGraph) { | |||
| std::replace_if( | |||
| std::begin(subGraph->outputIndices), std::end(subGraph->outputIndices), | |||
| [&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index); | |||
| } | |||
| } | |||
| std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int outputIndexIdx) { | |||
| std::vector<uint32_t> outputIndexes; | |||
| if (outputIndexIdx == -1) { | |||
| outputIndexes = node.outputIndex; | |||
| } else { | |||
| MS_ASSERT(node.outputIndex.size() > outputIndexIdx); | |||
| outputIndexes.emplace_back(node.outputIndex.at(outputIndexIdx)); | |||
| } | |||
| std::set<size_t> outputNodeIdx; | |||
| for (uint32_t outputIdx : outputIndexes) { | |||
| auto linkedPostIdx = GetLinkedPostIdx(graphT, outputIdx); | |||
| outputNodeIdx.insert(linkedPostIdx.begin(), linkedPostIdx.end()); | |||
| } | |||
| std::vector<size_t> ret; | |||
| ret.insert(ret.end(), outputNodeIdx.begin(), outputNodeIdx.end()); | |||
| return ret; | |||
| } | |||
| std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) { | |||
| std::vector<size_t> preNodeIdx; | |||
| for (size_t i = 0; i < graphT.nodes.size(); i++) { | |||
| auto &oldNode = graphT.nodes.at(i); | |||
| if (oldNode == nullptr) { | |||
| continue; | |||
| } | |||
| auto outputIndexes = oldNode->outputIndex; | |||
| if (IsContain<uint32_t>(outputIndexes, tensorIdx)) { | |||
| preNodeIdx.emplace_back(i); | |||
| } | |||
| } | |||
| return preNodeIdx; | |||
| } | |||
| std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) { | |||
| std::vector<size_t> postNodeIdx; | |||
| for (size_t i = 0; i < graphT.nodes.size(); i++) { | |||
| auto &oldNode = graphT.nodes.at(i); | |||
| if (oldNode == nullptr) { | |||
| continue; | |||
| } | |||
| auto inputIndexes = oldNode->inputIndex; | |||
| if (IsContain<uint32_t>(inputIndexes, tensorIdx)) { | |||
| postNodeIdx.emplace_back(i); | |||
| } | |||
| } | |||
| return postNodeIdx; | |||
| } | |||
| STATUS IsolateNode(schema::MetaGraphT *graphT, CNodeT *node) { | |||
| MS_ASSERT(graphT != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| size_t nodeIdx = 0; | |||
| for (size_t i = 0; i < graphT->nodes.size(); i++) { | |||
| auto &inNode = graphT->nodes.at(i); | |||
| MS_ASSERT(inNode != nullptr); | |||
| if (inNode->name == node->name) { | |||
| nodeIdx = i; | |||
| break; | |||
| } | |||
| } | |||
| auto inputTensorIdxes = node->inputIndex; | |||
| auto outputTensorIdxes = node->outputIndex; | |||
| if (inputTensorIdxes.empty()) { | |||
| MS_LOG(ERROR) << "Node " << node->name.c_str() << "should has no inputs"; | |||
| return RET_ERROR; | |||
| } | |||
| if (outputTensorIdxes.size() != 1) { | |||
| MS_LOG(ERROR) << "FakeQuantNode " << node->name.c_str() | |||
| << "should has 1 output, in fact: " << outputTensorIdxes.size(); | |||
| return RET_ERROR; | |||
| } | |||
| auto inDataTensorIdx = inputTensorIdxes.front(); | |||
| auto outDataTensorIdx = outputTensorIdxes.front(); | |||
| MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx); | |||
| ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT); | |||
| // find poseNode | |||
| auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0); | |||
| for (auto postNodeIdx : postNodeIdxes) { | |||
| MS_ASSERT(graphT->nodes.size() > postNodeIdx); | |||
| auto &postNode = graphT->nodes.at(postNodeIdx); | |||
| MS_ASSERT(postNode != nullptr); | |||
| for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) { | |||
| if (*iter == outDataTensorIdx) { | |||
| *iter = inDataTensorIdx; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| RemoveTensor(graphT, outputTensorIdxes); | |||
| node->inputIndex.clear(); | |||
| node->outputIndex.clear(); | |||
| return RET_OK; | |||
| } | |||
| STATUS IsolateOneWayNode(schema::MetaGraphT *graph, size_t subGraphIdx, size_t nodeIdx, bool removeTensor) { | |||
| MS_ASSERT(graph != nullptr); | |||
| return IsolateOneWayNode(graph, nodeIdx, removeTensor); | |||
| } | |||
| STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor) { | |||
| MS_ASSERT(graphT != nullptr); | |||
| if (graphT->nodes.size() <= nodeIdx) { | |||
| MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| CNodeT *node = graphT->nodes.at(nodeIdx).get(); | |||
| if (node == nullptr) { | |||
| MS_LOG(ERROR) << "node is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto inputTensorIdxes = node->inputIndex; | |||
| auto outputTensorIdxes = node->outputIndex; | |||
| auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx); | |||
| if (preNodeIdxes.size() > 1 || outputTensorIdxes.size() > 1) { | |||
| MS_LOG(ERROR) << "Only support node who has no more than one input and one output"; | |||
| return RET_ERROR; | |||
| } | |||
| if (inputTensorIdxes.empty()) { | |||
| MS_LOG(ERROR) << "Error, " << nodeIdx << "th node has no input tensor"; | |||
| return RET_ERROR; | |||
| } | |||
| auto inDataTensorIdx = inputTensorIdxes.front(); | |||
| if (!outputTensorIdxes.empty()) { | |||
| auto outDataTensorIdx = outputTensorIdxes.front(); | |||
| MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx); | |||
| MS_ASSERT(graphT->allTensors.at(inDataTensorIdx) != nullptr); | |||
| ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT); | |||
| // find poseNode | |||
| auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0); | |||
| for (auto postNodeIdx : postNodeIdxes) { | |||
| MS_ASSERT(graphT->nodes.size() > postNodeIdx); | |||
| auto &postNode = graphT->nodes.at(postNodeIdx); | |||
| MS_ASSERT(postNode != nullptr); | |||
| for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) { | |||
| if (*iter == outDataTensorIdx) { | |||
| *iter = inDataTensorIdx; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (removeTensor) { | |||
| // now all node's outputTensors are useless | |||
| // remove all node's outputTensors | |||
| auto status = RemoveTensor(graphT, outputTensorIdxes); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "RemoveOutputTensors of node " << node->name.c_str() << "failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| node->inputIndex.clear(); | |||
| node->outputIndex.clear(); | |||
| return RET_OK; | |||
| } | |||
| STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, CNodeT *node, bool removeTensor) { | |||
| MS_ASSERT(graphT != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| bool isSubNode = false; | |||
| size_t nodeIdx = 0; | |||
| for (size_t i = 0; i < graphT->nodes.size(); i++) { | |||
| auto &inNode = graphT->nodes.at(i); | |||
| MS_ASSERT(inNode != nullptr); | |||
| if (inNode->name == node->name) { | |||
| isSubNode = true; | |||
| nodeIdx = i; | |||
| break; | |||
| } | |||
| } | |||
| if (!isSubNode) { | |||
| MS_LOG(ERROR) << "Node " << node->name.c_str() << "is not in graphT " << graphT->name.c_str(); | |||
| return RET_PARAM_INVALID; | |||
| } else { | |||
| return IsolateOneWayNode(graphT, nodeIdx, removeTensor); | |||
| } | |||
| } | |||
| STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete) { | |||
| MS_ASSERT(graphT != nullptr); | |||
| for (auto iter = toDeleteTensorIdxes.begin(); iter != toDeleteTensorIdxes.end();) { | |||
| uint32_t deleteIdx = *iter; | |||
| if (!forceDelete) { | |||
| if (GetRefCount(graphT, deleteIdx) > 1) { | |||
| iter++; | |||
| continue; | |||
| } | |||
| } | |||
| // update graph input indices | |||
| for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) { | |||
| if (*gInIdx > deleteIdx) { | |||
| (*gInIdx)--; | |||
| } | |||
| } | |||
| // update graph output indices | |||
| for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) { | |||
| if (*gOutIdx > deleteIdx) { | |||
| (*gOutIdx)--; | |||
| } | |||
| } | |||
| for (auto &subgraph : graphT->subGraph) { | |||
| // update subgraph input indices | |||
| for (auto gInIdx = subgraph->inputIndices.begin(); gInIdx != subgraph->inputIndices.end(); gInIdx++) { | |||
| if (*gInIdx > deleteIdx) { | |||
| (*gInIdx)--; | |||
| } | |||
| } | |||
| // update subgraph output indices | |||
| for (auto gOutIdx = subgraph->outputIndices.begin(); gOutIdx != subgraph->outputIndices.end(); gOutIdx++) { | |||
| if (*gOutIdx > deleteIdx) { | |||
| (*gOutIdx)--; | |||
| } | |||
| } | |||
| // update subgraph output indices | |||
| for (auto idx = subgraph->tensorIndices.begin(); idx != subgraph->tensorIndices.end(); idx++) { | |||
| if (*idx > deleteIdx) { | |||
| (*idx)--; | |||
| } | |||
| } | |||
| } | |||
| // update nodes indexes | |||
| for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) { | |||
| // update nodes input indexes | |||
| UpdateNodeIndex((*node_iter).get(), deleteIdx); | |||
| } | |||
| // update deleteTensorIdx | |||
| for (auto selfIt = toDeleteTensorIdxes.begin(); selfIt != toDeleteTensorIdxes.end(); selfIt++) { | |||
| if (*selfIt > deleteIdx) { | |||
| (*selfIt)--; | |||
| } | |||
| } | |||
| graphT->allTensors.erase(graphT->allTensors.begin() + deleteIdx); | |||
| iter = toDeleteTensorIdxes.erase(iter); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS UpdateNodeIndex(CNodeT *node, uint32_t deleteIdx) { | |||
| MS_ASSERT(node != nullptr); | |||
| for (auto inIdxIt = node->inputIndex.begin(); inIdxIt != node->inputIndex.end();) { | |||
| if (*inIdxIt == deleteIdx) { | |||
| inIdxIt = node->inputIndex.erase(inIdxIt); | |||
| } else { | |||
| if (*inIdxIt > deleteIdx) { | |||
| (*inIdxIt)--; | |||
| } | |||
| inIdxIt++; | |||
| } | |||
| } | |||
| // update nodes output indexes | |||
| for (auto outIdxIt = node->outputIndex.begin(); outIdxIt != node->outputIndex.end();) { | |||
| if (*outIdxIt == deleteIdx) { | |||
| outIdxIt = node->outputIndex.erase(outIdxIt); | |||
| } else { | |||
| if (*outIdxIt > deleteIdx) { | |||
| (*outIdxIt)--; | |||
| } | |||
| outIdxIt++; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_ptr<TensorT> tensor, | |||
| InsertPlace place) { | |||
| if (nodeIdx >= graphT->nodes.size()) { | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_COMMON_GRAPH_UTIL_H | |||
| #define MINDSPORE_LITE_TOOLS_COMMON_GRAPH_UTIL_H | |||
| #ifndef MINDSPORE_LITE_TOOLS_COMMON_GRAPH_UTIL_H_ | |||
| #define MINDSPORE_LITE_TOOLS_COMMON_GRAPH_UTIL_H_ | |||
| #include <cstdlib> | |||
| #include <unordered_map> | |||
| @@ -48,34 +48,6 @@ OpDefCopyer GetSimpleOpCopyer(); | |||
| int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &outputs); | |||
| std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int inputIndexIdx = -1); | |||
| std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node, | |||
| int inputIndexIdx = -1); | |||
| std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int outputIndexIdx = -1); | |||
| std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node, | |||
| int outputIndexIdx = -1); | |||
| std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx); | |||
| std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx); | |||
| void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT); | |||
| STATUS IsolateNode(schema::MetaGraphT *subGraph, schema::CNodeT *node); | |||
| STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor = true); | |||
| STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t subGraphIdx, size_t nodeIdx, bool removeTensor = true); | |||
| STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, schema::CNodeT *node, bool removeTensor = true); | |||
| STATUS UpdateNodeIndex(schema::CNodeT *node, uint32_t deleteIdx); | |||
| STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete = false); | |||
| STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_ptr<schema::TensorT> tensor, | |||
| InsertPlace place = kBefore); | |||
| @@ -320,4 +292,4 @@ bool PackRepetition(size_t bit_num, schema::TensorT *tensor) { | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_COMMON_GRAPH_UTIL_H | |||
| #endif // MINDSPORE_LITE_TOOLS_COMMON_GRAPH_UTIL_H_ | |||
| @@ -0,0 +1,346 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/common/meta_graph_utils.h" | |||
| #include <vector> | |||
| #include <set> | |||
| #include "inner/model_generated.h" | |||
| #include "src/common/utils.h" | |||
| #include "nnacl/op_base.h" | |||
| namespace mindspore::lite { | |||
| namespace { | |||
| size_t GetRefCount(schema::MetaGraphT *graphT, uint32_t tensorIdx) { | |||
| MS_ASSERT(graphT != nullptr); | |||
| MS_ASSERT(graphT->allTensors.size() > tensorIdx); | |||
| size_t refCount = 0; | |||
| for (auto &node : graphT->nodes) { | |||
| MS_ASSERT(node != nullptr); | |||
| if (IsContain(node->inputIndex, tensorIdx)) { | |||
| refCount++; | |||
| } | |||
| } | |||
| return refCount; | |||
| } | |||
| } // namespace | |||
| std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) { | |||
| std::vector<size_t> postNodeIdx; | |||
| for (size_t i = 0; i < graphT.nodes.size(); i++) { | |||
| auto &oldNode = graphT.nodes.at(i); | |||
| if (oldNode == nullptr) { | |||
| continue; | |||
| } | |||
| auto inputIndexes = oldNode->inputIndex; | |||
| if (IsContain<uint32_t>(inputIndexes, tensorIdx)) { | |||
| postNodeIdx.emplace_back(i); | |||
| } | |||
| } | |||
| return postNodeIdx; | |||
| } | |||
| std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) { | |||
| std::vector<size_t> preNodeIdx; | |||
| for (size_t i = 0; i < graphT.nodes.size(); i++) { | |||
| auto &oldNode = graphT.nodes.at(i); | |||
| if (oldNode == nullptr) { | |||
| continue; | |||
| } | |||
| auto outputIndexes = oldNode->outputIndex; | |||
| if (IsContain<uint32_t>(outputIndexes, tensorIdx)) { | |||
| preNodeIdx.emplace_back(i); | |||
| } | |||
| } | |||
| return preNodeIdx; | |||
| } | |||
| std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node, | |||
| const int inputIndexIdx) { | |||
| std::vector<uint32_t> inputIndexes; | |||
| if (inputIndexIdx == -1) { | |||
| inputIndexes = node.inputIndex; | |||
| } else { | |||
| MS_ASSERT(node.inputIndex.size() > static_cast<uint32_t>(inputIndexIdx)); | |||
| inputIndexes.emplace_back(node.inputIndex.at(inputIndexIdx)); | |||
| } | |||
| std::set<size_t> inputNodeIdx; | |||
| for (uint32_t inputIdx : inputIndexes) { | |||
| auto linkedPreIdx = GetLinkedPreIdx(graphT, inputIdx); | |||
| inputNodeIdx.insert(linkedPreIdx.begin(), linkedPreIdx.end()); | |||
| } | |||
| std::vector<size_t> ret; | |||
| ret.insert(ret.end(), inputNodeIdx.begin(), inputNodeIdx.end()); | |||
| return ret; | |||
| } | |||
| std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, const int inputIndexIdx) { | |||
| return GetInputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), inputIndexIdx); | |||
| } | |||
| std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node, | |||
| const int outputIndexIdx) { | |||
| std::vector<uint32_t> outputIndexes; | |||
| if (outputIndexIdx == -1) { | |||
| outputIndexes = node.outputIndex; | |||
| } else { | |||
| MS_ASSERT(node.outputIndex.size() > static_cast<uint32_t>(outputIndexIdx)); | |||
| outputIndexes.emplace_back(node.outputIndex.at(outputIndexIdx)); | |||
| } | |||
| std::set<size_t> outputNodeIdx; | |||
| for (uint32_t outputIdx : outputIndexes) { | |||
| auto linkedPostIdx = GetLinkedPostIdx(graphT, outputIdx); | |||
| outputNodeIdx.insert(linkedPostIdx.begin(), linkedPostIdx.end()); | |||
| } | |||
| std::vector<size_t> ret; | |||
| ret.insert(ret.end(), outputNodeIdx.begin(), outputNodeIdx.end()); | |||
| return ret; | |||
| } | |||
| std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, | |||
| const int outputIndexIdx) { | |||
| return GetOutputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), outputIndexIdx); | |||
| } | |||
| void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT) { | |||
| std::replace_if( | |||
| std::begin(graphT->outputIndex), std::end(graphT->outputIndex), | |||
| [&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index); | |||
| for (auto &subGraph : graphT->subGraph) { | |||
| std::replace_if( | |||
| std::begin(subGraph->outputIndices), std::end(subGraph->outputIndices), | |||
| [&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index); | |||
| } | |||
| } | |||
| STATUS UpdateNodeIndex(schema::CNodeT *node, uint32_t deleteIdx) { | |||
| MS_ASSERT(node != nullptr); | |||
| for (auto inIdxIt = node->inputIndex.begin(); inIdxIt != node->inputIndex.end();) { | |||
| if (*inIdxIt == deleteIdx) { | |||
| inIdxIt = node->inputIndex.erase(inIdxIt); | |||
| } else { | |||
| if (*inIdxIt > deleteIdx) { | |||
| (*inIdxIt)--; | |||
| } | |||
| inIdxIt++; | |||
| } | |||
| } | |||
| // update nodes output indexes | |||
| for (auto outIdxIt = node->outputIndex.begin(); outIdxIt != node->outputIndex.end();) { | |||
| if (*outIdxIt == deleteIdx) { | |||
| outIdxIt = node->outputIndex.erase(outIdxIt); | |||
| } else { | |||
| if (*outIdxIt > deleteIdx) { | |||
| (*outIdxIt)--; | |||
| } | |||
| outIdxIt++; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete) { | |||
| MS_ASSERT(graphT != nullptr); | |||
| for (auto iter = toDeleteTensorIdxes.begin(); iter != toDeleteTensorIdxes.end();) { | |||
| uint32_t deleteIdx = *iter; | |||
| if (!forceDelete) { | |||
| if (GetRefCount(graphT, deleteIdx) > 1) { | |||
| iter++; | |||
| continue; | |||
| } | |||
| } | |||
| // update graph input indices | |||
| for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) { | |||
| if (*gInIdx > deleteIdx) { | |||
| (*gInIdx)--; | |||
| } | |||
| } | |||
| // update graph output indices | |||
| for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) { | |||
| if (*gOutIdx > deleteIdx) { | |||
| (*gOutIdx)--; | |||
| } | |||
| } | |||
| for (auto &subgraph : graphT->subGraph) { | |||
| // update subgraph input indices | |||
| for (auto gInIdx = subgraph->inputIndices.begin(); gInIdx != subgraph->inputIndices.end(); gInIdx++) { | |||
| if (*gInIdx > deleteIdx) { | |||
| (*gInIdx)--; | |||
| } | |||
| } | |||
| // update subgraph output indices | |||
| for (auto gOutIdx = subgraph->outputIndices.begin(); gOutIdx != subgraph->outputIndices.end(); gOutIdx++) { | |||
| if (*gOutIdx > deleteIdx) { | |||
| (*gOutIdx)--; | |||
| } | |||
| } | |||
| // update subgraph output indices | |||
| for (auto idx = subgraph->tensorIndices.begin(); idx != subgraph->tensorIndices.end(); idx++) { | |||
| if (*idx > deleteIdx) { | |||
| (*idx)--; | |||
| } | |||
| } | |||
| } | |||
| // update nodes indexes | |||
| for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) { | |||
| // update nodes input indexes | |||
| UpdateNodeIndex((*node_iter).get(), deleteIdx); | |||
| } | |||
| // update deleteTensorIdx | |||
| for (auto selfIt = toDeleteTensorIdxes.begin(); selfIt != toDeleteTensorIdxes.end(); selfIt++) { | |||
| if (*selfIt > deleteIdx) { | |||
| (*selfIt)--; | |||
| } | |||
| } | |||
| graphT->allTensors.erase(graphT->allTensors.begin() + deleteIdx); | |||
| iter = toDeleteTensorIdxes.erase(iter); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS IsolateNode(schema::MetaGraphT *graphT, schema::CNodeT *node) { | |||
| MS_ASSERT(graphT != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| size_t nodeIdx = 0; | |||
| for (size_t i = 0; i < graphT->nodes.size(); i++) { | |||
| auto &inNode = graphT->nodes.at(i); | |||
| MS_ASSERT(postNode != nullptr); | |||
| if (inNode->name == node->name) { | |||
| nodeIdx = i; | |||
| break; | |||
| } | |||
| } | |||
| auto inputTensorIdxes = node->inputIndex; | |||
| auto outputTensorIdxes = node->outputIndex; | |||
| if (inputTensorIdxes.empty()) { | |||
| MS_LOG(ERROR) << "Node " << node->name.c_str() << "should has no inputs"; | |||
| return RET_ERROR; | |||
| } | |||
| if (outputTensorIdxes.size() != 1) { | |||
| MS_LOG(ERROR) << "FakeQuantNode " << node->name.c_str() | |||
| << "should has 1 output, in fact: " << outputTensorIdxes.size(); | |||
| return RET_ERROR; | |||
| } | |||
| auto inDataTensorIdx = inputTensorIdxes.front(); | |||
| auto outDataTensorIdx = outputTensorIdxes.front(); | |||
| MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx); | |||
| ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT); | |||
| // find poseNode | |||
| auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0); | |||
| for (auto postNodeIdx : postNodeIdxes) { | |||
| MS_ASSERT(graphT->nodes.size() > postNodeIdx); | |||
| auto &postNode = graphT->nodes.at(postNodeIdx); | |||
| MS_ASSERT(postNode != nullptr); | |||
| for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) { | |||
| if (*iter == outDataTensorIdx) { | |||
| *iter = inDataTensorIdx; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| RemoveTensor(graphT, outputTensorIdxes); | |||
| node->inputIndex.clear(); | |||
| node->outputIndex.clear(); | |||
| return RET_OK; | |||
| } | |||
| STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor) { | |||
| MS_ASSERT(graphT != nullptr); | |||
| if (graphT->nodes.size() <= nodeIdx) { | |||
| MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| schema::CNodeT *node = graphT->nodes.at(nodeIdx).get(); | |||
| if (node == nullptr) { | |||
| MS_LOG(ERROR) << "node is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto inputTensorIdxes = node->inputIndex; | |||
| auto outputTensorIdxes = node->outputIndex; | |||
| auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx); | |||
| if (preNodeIdxes.size() > 1 || outputTensorIdxes.size() > 1) { | |||
| MS_LOG(ERROR) << "Only support node who has no more than one input and one output"; | |||
| return RET_ERROR; | |||
| } | |||
| if (inputTensorIdxes.empty()) { | |||
| MS_LOG(ERROR) << "Error, " << nodeIdx << "th node has no input tensor"; | |||
| return RET_ERROR; | |||
| } | |||
| auto inDataTensorIdx = inputTensorIdxes.front(); | |||
| if (!outputTensorIdxes.empty()) { | |||
| auto outDataTensorIdx = outputTensorIdxes.front(); | |||
| MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx); | |||
| MS_ASSERT(graphT->allTensors.at(inDataTensorIdx) != nullptr); | |||
| ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT); | |||
| // find poseNode | |||
| auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0); | |||
| for (auto postNodeIdx : postNodeIdxes) { | |||
| MS_ASSERT(graphT->nodes.size() > postNodeIdx); | |||
| auto &postNode = graphT->nodes.at(postNodeIdx); | |||
| MS_ASSERT(postNode != nullptr); | |||
| for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) { | |||
| if (*iter == outDataTensorIdx) { | |||
| *iter = inDataTensorIdx; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (removeTensor) { | |||
| // now all node's outputTensors are useless | |||
| // remove all node's outputTensors | |||
| auto status = RemoveTensor(graphT, outputTensorIdxes); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "RemoveOutputTensors of node " << node->name.c_str() << "failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| node->inputIndex.clear(); | |||
| node->outputIndex.clear(); | |||
| return RET_OK; | |||
| } | |||
| STATUS IsolateOneWayNode(schema::MetaGraphT *graph, size_t subGraphIdx, size_t nodeIdx, bool removeTensor) { | |||
| MS_ASSERT(graph != nullptr); | |||
| return IsolateOneWayNode(graph, nodeIdx, removeTensor); | |||
| } | |||
| STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, schema::CNodeT *node, bool removeTensor) { | |||
| MS_ASSERT(graphT != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| bool isSubNode = false; | |||
| size_t nodeIdx = 0; | |||
| for (size_t i = 0; i < graphT->nodes.size(); i++) { | |||
| auto &inNode = graphT->nodes.at(i); | |||
| MS_ASSERT(postNode != nullptr); | |||
| if (inNode->name == node->name) { | |||
| isSubNode = true; | |||
| nodeIdx = i; | |||
| break; | |||
| } | |||
| } | |||
| if (!isSubNode) { | |||
| MS_LOG(ERROR) << "Node " << node->name.c_str() << "is not in graphT " << graphT->name.c_str(); | |||
| return RET_PARAM_INVALID; | |||
| } else { | |||
| return IsolateOneWayNode(graphT, nodeIdx, removeTensor); | |||
| } | |||
| } | |||
| } // namespace mindspore::lite | |||
| @@ -0,0 +1,54 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_COMMON_META_GRAPH_UTILS_H_ | |||
| #define MINDSPORE_LITE_TOOLS_COMMON_META_GRAPH_UTILS_H_ | |||
| #include <vector> | |||
| #include "inner/model_generated.h" | |||
| #include "include/errorcode.h" | |||
| namespace mindspore::lite { | |||
| std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx); | |||
| std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx); | |||
| std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node, | |||
| int inputIndexIdx = -1); | |||
| std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int inputIndexIdx = -1); | |||
| std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node, | |||
| int outputIndexIdx = -1); | |||
| std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int outputIndexIdx = -1); | |||
| STATUS IsolateNode(schema::MetaGraphT *subGraph, schema::CNodeT *node); | |||
| STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete = false); | |||
| void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT); | |||
| STATUS UpdateNodeIndex(schema::CNodeT *node, uint32_t deleteIdx); | |||
| STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t subGraphIdx, size_t nodeIdx, bool removeTensor = true); | |||
| STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, schema::CNodeT *node, bool removeTensor = true); | |||
| STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor = true); | |||
| } // namespace mindspore::lite | |||
| #endif // MINDSPORE_LITE_TOOLS_COMMON_META_GRAPH_UTILS_H_ | |||
| @@ -22,6 +22,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/file_utils.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/quant_utils.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../common/graph_util.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../common/meta_graph_utils.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../common/node_util.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../common/string_util.cc | |||
| @@ -27,6 +27,7 @@ | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/common/utils.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/common/meta_graph_utils.h" | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| @@ -21,6 +21,7 @@ | |||
| #include "tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/common/meta_graph_utils.h" | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| @@ -95,7 +96,7 @@ STATUS MulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternN | |||
| const auto &mulNodeBiasTensor = graph->allTensors.at(mulNodeInputIndex.at(MUL_OP_BIAS_INDEX)); | |||
| MS_ASSERT(mulNodeBiasTensor != nullptr); | |||
| if (mulNodeBiasTensor->nodeType != NodeType_ValueNode) { | |||
| // dont fusion, return | |||
| // don't fusion, return | |||
| return RET_OK; | |||
| } | |||
| if (mulNodeBiasTensor->dataType == TypeId::kNumberTypeUInt8) { | |||
| @@ -112,7 +113,7 @@ STATUS MulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternN | |||
| const auto &addNodeBiasTensor = graph->allTensors.at(addNodeInputIndex.at(ADD_OP_BIAS_INDEX)); | |||
| MS_ASSERT(addNodeBiasTensor != nullptr); | |||
| if (addNodeBiasTensor->nodeType != NodeType_ValueNode) { | |||
| // dont fusion, return | |||
| // don't fusion, return | |||
| return RET_OK; | |||
| } | |||
| // scale requires scale shape tail sub of input shape, scale shape same as bias shape | |||
| @@ -21,6 +21,7 @@ | |||
| #include "src/common/log_adapter.h" | |||
| #include "securec/include/securec.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/common/meta_graph_utils.h" | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| @@ -17,7 +17,7 @@ | |||
| #include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h" | |||
| #include <queue> | |||
| #include "src/common/log_adapter.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/common/meta_graph_utils.h" | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| @@ -20,7 +20,6 @@ | |||
| #include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| @@ -21,7 +21,6 @@ | |||
| #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/common/utils.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| @@ -22,6 +22,7 @@ | |||
| #include "tools/converter/quantizer/quantize_util.h" | |||
| #include "tools/common/tensor_util.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/common/meta_graph_utils.h" | |||
| #include "tools/common/node_util.h" | |||
| #include "src/common/quant_utils.h" | |||
| @@ -18,8 +18,10 @@ | |||
| #include <utility> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h" | |||
| #include "tools/common/node_util.h" | |||
| #include "tools/common/meta_graph_utils.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/common/utils.h" | |||