Merge pull request !21148 from 徐安越/master1tags/v1.4.0
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -96,6 +96,7 @@ bool PassTutorial::Run(const FuncGraphPtr &func_graph) { | |||
| } | |||
| // register customed Pass | |||
| REG_PASS(POSITION_BEGIN, PassTutorial) | |||
| REG_PASS(PassTutorial, PassTutorial) | |||
| REG_SCHEDULED_PASS(POSITION_BEGIN, {"PassTutorial"}) | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -20,9 +20,7 @@ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <mutex> | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include "include/lite_utils.h" | |||
| namespace mindspore { | |||
| @@ -39,53 +37,33 @@ using PassPtr = std::shared_ptr<Pass>; | |||
| /// \brief PassRegistry defined registration of Pass. | |||
| class MS_API PassRegistry { | |||
| public: | |||
| /// \brief Destructor of PassRegistry. | |||
| virtual ~PassRegistry() = default; | |||
| /// \brief Static method to get a single instance of PassRegistry. | |||
| /// | |||
| /// \return Pointer of PassRegistry. | |||
| static PassRegistry *GetInstance(); | |||
| /// \brief Method to register Pass. | |||
| /// \brief Constructor of PassRegistry to register pass. | |||
| /// | |||
| /// \param[in] position Define where to replace the pass. | |||
| /// \param[in] pos Define where to replace the pass. | |||
| /// \param[in] pass Define user's defined pass. | |||
| void RegPass(int position, const PassPtr &pass); | |||
| PassRegistry(const std::string &pass_name, const PassPtr &pass); | |||
| /// \brief Method to get all passes user write. | |||
| /// \brief Constructor of PassRegistry to assign which passes are required for external extension. | |||
| /// | |||
| /// \return A map include all pass. | |||
| const std::unordered_map<int, PassPtr> &GetPasses() const; | |||
| private: | |||
| /// \brief Constructor of PassRegistry. | |||
| PassRegistry() = default; | |||
| private: | |||
| std::unordered_map<int, PassPtr> passes_; | |||
| std::mutex mutex_; | |||
| }; | |||
| /// \brief PassRegistrar defined registration class of Pass. | |||
| class MS_API PassRegistrar { | |||
| public: | |||
| /// \brief Constructor of PassRegistrar to register pass. | |||
| /// | |||
| /// \param[in] pos Define where to replace the pass. | |||
| /// \param[in] pass Define user's defined pass. | |||
| PassRegistrar(int pos, const PassPtr &pass) { PassRegistry::GetInstance()->RegPass(pos, pass); } | |||
| /// \param[in position Define the place where assigned passes will run. | |||
| /// \param[in] assigned Define the name of passes assigned by user. | |||
| PassRegistry(PassPosition position, const std::vector<std::string> &assigned); | |||
| /// \brief Destructor of PassRegistrar. | |||
| ~PassRegistrar() = default; | |||
| ~PassRegistry() = default; | |||
| }; | |||
| /// \brief Defined registering macro to register Pass, which called by user directly. | |||
| /// | |||
| /// \param[in] position Define where to replace the pass. | |||
| /// \param[in] name Define name of user's pass, which is a string. | |||
| /// \param[in] pass Define user's defined pass. | |||
| #define REG_PASS(position, pass) static PassRegistrar g_##position##PassReg(position, std::make_shared<pass>()); | |||
| #define REG_PASS(name, pass) static PassRegistry g_##name##PassReg(#name, std::make_shared<pass>()); | |||
| /// \brief Defined assigning macro to assign Passes, which called by user directly. | |||
| /// | |||
| /// \param[in] position Define the place where assigned passes will run. | |||
| /// \param[in] assigned Define the name of passes assigned by user. | |||
| #define REG_SCHEDULED_PASS(position, assigned) static PassRegistry g_##position(position, assigned); | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -201,12 +201,18 @@ if(MSLITE_ENABLE_CONVERTER) | |||
| ${LITE_DIR}/tools/converter/converter.cc | |||
| ${LITE_DIR}/tools/converter/export_model.cc | |||
| ${LITE_DIR}/tools/converter/dump_graph.cc | |||
| ${LITE_DIR}/tools/converter/optimizer_manager.cc | |||
| ${LITE_DIR}/tools/converter/parser/parser_utils.cc | |||
| ${LITE_DIR}/tools/optimizer/common/node_pass_extends.cc | |||
| ${LITE_DIR}/tools/optimizer/common/pass_manager_extends.cc | |||
| ${LITE_DIR}/tools/optimizer/common/gllo_utils.cc | |||
| ${LITE_DIR}/tools/optimizer/common/format_utils.cc | |||
| ${LITE_DIR}/tools/optimizer/common/multiple_pattern_process_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/format/conv_weight_format.cc | |||
| ${LITE_DIR}/tools/optimizer/format/delete_redundant_transpose.cc | |||
| ${LITE_DIR}/tools/optimizer/format/to_format_base.cc | |||
| ${LITE_DIR}/tools/optimizer/format/to_nchw_format.cc | |||
| ${LITE_DIR}/tools/optimizer/format/to_nhwc_format.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/affine_activation_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/affine_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/conv_biasadd_fusion.cc | |||
| @@ -247,7 +253,7 @@ if(MSLITE_ENABLE_CONVERTER) | |||
| ${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/control_flow_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/unify_format_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/decrease_transpose_algo.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/node_infershape.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/transpose_strategy.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/reduce_same_act_pass.cc | |||
| @@ -271,7 +277,7 @@ if(MSLITE_ENABLE_CONVERTER) | |||
| ${LITE_DIR}/tools/common/node_util.cc | |||
| ${LITE_DIR}/tools/common/storage.cc | |||
| ${LITE_DIR}/tools/converter/parser/inputs_adjust.cc | |||
| ${LITE_DIR}/tools/converter/parser/insert_transpose.cc | |||
| ${LITE_DIR}/tools/converter/parser/unify_format.cc | |||
| ${LITE_DIR}/tools/converter/parser/unused_node_remove_pass.cc | |||
| ${LITE_DIR}/tools/converter/parser/conv1d_inout_adjust.cc | |||
| ${LITE_DIR}/tools/converter/parser/tf_bidirection_gru_cf_fusion.cc | |||
| @@ -119,10 +119,11 @@ if [[ $backend == "all" || $backend == "x86-all" || $backend == "x86" || $backen | |||
| fi | |||
| if [[ $backend == "all" || $backend == "arm32_3516D" ]]; then | |||
| sh $cur_path/scripts/nnie/run_converter_nnie.sh -r $release_path -m $models_path -d $device_id -e $backend | |||
| hi3516_status=$? | |||
| if [[ $hi3516_status -ne 0 ]]; then | |||
| echo "Run nnie hi3516 failed" | |||
| exit 1 | |||
| fi | |||
| exit 0 | |||
| # sh $cur_path/scripts/nnie/run_converter_nnie.sh -r $release_path -m $models_path -d $device_id -e $backend | |||
| # hi3516_status=$? | |||
| # if [[ $hi3516_status -ne 0 ]]; then | |||
| # echo "Run nnie hi3516 failed" | |||
| # exit 1 | |||
| # fi | |||
| fi | |||
| @@ -25,6 +25,7 @@ | |||
| #include "ops/addn.h" | |||
| #include "ops/custom.h" | |||
| #include "tools/converter/model_parser.h" | |||
| #include "tools/converter/registry/pass_content.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "ut/tools/converter/registry/model_parser_test.h" | |||
| @@ -207,13 +208,17 @@ class TestFusion : public Pass { | |||
| return true; | |||
| } | |||
| }; | |||
| REG_PASS(POSITION_BEGIN, TestFusion) | |||
| REG_PASS(TestFusion, TestFusion) | |||
| REG_SCHEDULED_PASS(POSITION_BEGIN, {"TestFusion"}) | |||
| } // namespace opt | |||
| TEST_F(PassRegistryTest, TestRegistry) { | |||
| auto passes = opt::PassRegistry::GetInstance()->GetPasses(); | |||
| ASSERT_EQ(passes.size(), 1); | |||
| auto begin_pass = passes[opt::POSITION_BEGIN]; | |||
| auto &passes = opt::PassStoreRoomInfo(); | |||
| auto &assigned_passes = opt::ExternalAssignedPassesInfo(); | |||
| ASSERT_EQ(assigned_passes.size(), 1); | |||
| auto pass_names = assigned_passes[opt::POSITION_BEGIN]; | |||
| ASSERT_EQ(pass_names.size(), 1); | |||
| auto begin_pass = passes[pass_names.front()]; | |||
| ASSERT_NE(begin_pass, nullptr); | |||
| auto begin_pass_test = std::dynamic_pointer_cast<opt::TestFusion>(begin_pass); | |||
| ASSERT_NE(begin_pass_test, nullptr); | |||
| @@ -19,6 +19,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/graphdef_transform.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/optimizer.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/export_model.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/optimizer_manager.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/file_utils.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/quant_utils.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../common/graph_util.cc | |||
| @@ -36,7 +37,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/unused_node_remove_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/conv1d_inout_adjust.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/inputs_adjust.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/insert_transpose.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/unify_format.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/import/mindspore_importer.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/import/primitive_adjust.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/import/mindir_adjust.cc | |||
| @@ -46,6 +47,11 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ../optimizer/common/gllo_utils.cc | |||
| ../optimizer/common/format_utils.cc | |||
| ../optimizer/common/multiple_pattern_process_pass.cc | |||
| ../optimizer/format/conv_weight_format.cc | |||
| ../optimizer/format/delete_redundant_transpose.cc | |||
| ../optimizer/format/to_format_base.cc | |||
| ../optimizer/format/to_nchw_format.cc | |||
| ../optimizer/format/to_nhwc_format.cc | |||
| ../optimizer/fusion/affine_activation_fusion.cc | |||
| ../optimizer/fusion/affine_fusion.cc | |||
| ../optimizer/fusion/conv_biasadd_fusion.cc | |||
| @@ -102,7 +108,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ../optimizer/graph/mindir_adjust_pass.cc | |||
| ../optimizer/graph/control_flow_pass.cc | |||
| ../optimizer/graph/primitive_adjust_pass.cc | |||
| ../optimizer/graph/unify_format_pass.cc | |||
| ../optimizer/graph/decrease_transpose_algo.cc | |||
| ../optimizer/graph/node_infershape.cc | |||
| ../optimizer/graph/transpose_strategy.cc | |||
| ../optimizer/graph/reduce_same_act_pass.cc | |||
| @@ -20,8 +20,9 @@ | |||
| #include <unordered_map> | |||
| #include <deque> | |||
| #include "src/common/log_adapter.h" | |||
| #include "tools/converter/optimizer_manager.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "mindspore/core/ir/primitive.h" | |||
| #include "ir/primitive.h" | |||
| #include "tools/optimizer/fusion/affine_activation_fusion.h" | |||
| #include "tools/optimizer/fusion/affine_fusion.h" | |||
| #include "tools/optimizer/fusion/conv_biasadd_fusion.h" | |||
| @@ -56,7 +57,7 @@ | |||
| #include "tools/optimizer/graph/control_flow_pass.h" | |||
| #include "tools/optimizer/graph/reduce_same_act_pass.h" | |||
| #include "tools/optimizer/graph/split_one_pass.h" | |||
| #include "tools/optimizer/graph/unify_format_pass.h" | |||
| #include "tools/optimizer/graph/decrease_transpose_algo.h" | |||
| #include "tools/converter/quantizer/post_training_quantizer.h" | |||
| #include "tools/converter/quantizer/quant_cast.h" | |||
| #include "tools/converter/quantizer/weight_quantizer.h" | |||
| @@ -68,6 +69,10 @@ | |||
| #include "include/registry/pass_registry.h" | |||
| #include "tools/optimizer/fisson/multi_conv_split_pass.h" | |||
| #include "tools/optimizer/fusion/transpose_fusion.h" | |||
| #include "tools/optimizer/format/delete_redundant_transpose.h" | |||
| #include "tools/optimizer/format/to_nchw_format.h" | |||
| #include "tools/optimizer/format/to_nhwc_format.h" | |||
| #include "tools/optimizer/format/conv_weight_format.h" | |||
| using std::string; | |||
| namespace mindspore::lite { | |||
| @@ -238,22 +243,6 @@ int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const converte | |||
| return RET_OK; | |||
| } | |||
| STATUS AnfTransform::RunPluginPass(const FuncGraphPtr &old_graph, int position) { | |||
| auto instance = opt::PassRegistry::GetInstance(); | |||
| auto plugin_passes = instance->GetPasses(); | |||
| if (plugin_passes.find(position) == plugin_passes.end()) { | |||
| MS_LOG(DEBUG) << "there is no plugin pass in current position."; | |||
| return RET_OK; | |||
| } | |||
| auto plugin_pass = plugin_passes.at(position); | |||
| if (!plugin_pass->Run(old_graph)) { | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| void AnfTransform::GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs) { | |||
| all_func_graphs->insert(func_graph); | |||
| auto nodes = func_graph->GetOrderedCnodes(); | |||
| @@ -337,16 +326,13 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con | |||
| return nullptr; | |||
| } | |||
| status = RunPluginPass(old_graph, opt::POSITION_BEGIN); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Run plugin pass failed."; | |||
| if (!opt::RunExternalPass(old_graph, opt::POSITION_BEGIN)) { | |||
| MS_LOG(ERROR) << "Run external pass failed, place is BEGIN"; | |||
| return nullptr; | |||
| } | |||
| auto format_pass = std::make_shared<opt::UnifyFormatPass>(); | |||
| format_pass->Init(config->fmk, config->trainModel); | |||
| if (!format_pass->Run(old_graph)) { | |||
| MS_LOG(ERROR) << "Run format pass failed."; | |||
| if (!opt::RunOptimizerPass(old_graph, {"InferShapePass", "DeleteRedundantTranspose", "DecreaseTransposeAlgo"})) { | |||
| MS_LOG(ERROR) << "Run transpose opt pass failed."; | |||
| return nullptr; | |||
| } | |||
| @@ -370,16 +356,13 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con | |||
| } | |||
| } | |||
| format_pass = std::make_shared<opt::UnifyFormatPass>(); | |||
| format_pass->Init(config->fmk, config->trainModel); | |||
| if (!format_pass->Run(old_graph)) { | |||
| MS_LOG(ERROR) << "Run format pass failed."; | |||
| if (!opt::RunOptimizerPass(old_graph, {"InferShapePass", "DeleteRedundantTranspose", "DecreaseTransposeAlgo"})) { | |||
| MS_LOG(ERROR) << "Run transpose opt pass failed."; | |||
| return nullptr; | |||
| } | |||
| status = RunPluginPass(old_graph, opt::POSITION_END); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Run plugin pass failed."; | |||
| if (!opt::RunExternalPass(old_graph, opt::POSITION_END)) { | |||
| MS_LOG(ERROR) << "Run external pass failed, place is END"; | |||
| return nullptr; | |||
| } | |||
| @@ -403,7 +386,20 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con | |||
| return old_graph; | |||
| } | |||
| void AnfTransform::AppendPassToStoreRoom(const converter::Flags *config) { | |||
| auto fmk = config->fmk; | |||
| auto is_train = config->trainModel; | |||
| opt::PassRegistry("ConvWeightToKHWC", std::make_shared<opt::ConvWeightToKHWC>()); | |||
| opt::PassRegistry("ConvWeightToKCHW", std::make_shared<opt::ConvWeightToKCHW>()); | |||
| opt::PassRegistry("DecreaseTransposeAlgo", std::make_shared<opt::DecreaseTransposeAlgo>(fmk, is_train)); | |||
| opt::PassRegistry("DeleteRedundantTranspose", std::make_shared<opt::DeleteRedundantTranspose>()); | |||
| opt::PassRegistry("InferShapePass", std::make_shared<opt::InferShapePass>(fmk, is_train)); | |||
| opt::PassRegistry("ToNCHWFormat", std::make_shared<opt::ToNCHWFormat>(fmk, is_train)); | |||
| opt::PassRegistry("ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train)); | |||
| } | |||
| FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) { | |||
| AppendPassToStoreRoom(config); | |||
| auto new_graph = TransformFuncGraph(main_graph, config); | |||
| if (new_graph == nullptr) { | |||
| MS_LOG(ERROR) << "optimizer failed."; | |||
| @@ -51,13 +51,13 @@ class AnfTransform { | |||
| static int RunParallelPass(const FuncGraphPtr &old_graph, const converter::Flags *config); | |||
| static STATUS RunPluginPass(const FuncGraphPtr &old_graph, int position); | |||
| int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config); | |||
| static void GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs); | |||
| int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config); | |||
| void AppendPassToStoreRoom(const converter::Flags *config); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -20,13 +20,14 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "include/errorcode.h" | |||
| #include "include/version.h" | |||
| #include "ir/func_graph.h" | |||
| #include "tools/anf_exporter/anf_exporter.h" | |||
| #include "tools/converter/graphdef_transform.h" | |||
| #include "tools/converter/dump_graph_init.h" | |||
| #include "tools/optimizer/graph/unify_format_pass.h" | |||
| #include "tools/converter/optimizer_manager.h" | |||
| #include "tools/optimizer/graph/control_flow_pass.h" | |||
| namespace mindspore { | |||
| @@ -192,10 +193,8 @@ STATUS ExportModel(const FuncGraphPtr &graph) { | |||
| return RET_ERROR; | |||
| } | |||
| (void)Manage(mirror_graph, true); | |||
| auto format_pass = std::make_shared<opt::UnifyFormatPass>(); | |||
| format_pass->Init(flags->fmk, flags->trainModel); | |||
| if (!format_pass->Run(mirror_graph)) { | |||
| MS_LOG(ERROR) << "Run format pass failed."; | |||
| if (!opt::RunOptimizerPass(mirror_graph, {"InferShapePass", "DecreaseTransposeAlgo"})) { | |||
| MS_LOG(ERROR) << "Run transpose opt pass failed."; | |||
| return RET_ERROR; | |||
| } | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| @@ -23,7 +23,7 @@ | |||
| #include "tools/converter/import/mindir_adjust.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "tools/common/tensor_util.h" | |||
| #include "tools/converter/parser/insert_transpose.h" | |||
| #include "tools/converter/parser/unify_format.h" | |||
| namespace mindspore::lite { | |||
| namespace { | |||
| @@ -208,8 +208,8 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) { | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_MS, flag.trainModel); | |||
| if (!insert_transpose->Run(func_graph)) { | |||
| auto unify_format = std::make_shared<UnifyFormatToNHWC>(lite::converter::FmkType_MS, flag.trainModel); | |||
| if (!unify_format->Run(func_graph)) { | |||
| MS_LOG(ERROR) << "Run insert transpose failed."; | |||
| return nullptr; | |||
| } | |||
| @@ -0,0 +1,68 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/optimizer_manager.h" | |||
| #include <string> | |||
| #include <vector> | |||
| #include "backend/optimizer/common/pass.h" | |||
| #include "tools/converter/registry/pass_content.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| bool RunOptimizerPass(const FuncGraphPtr &func_graph, std::vector<std::string> pass_names) { | |||
| if (func_graph == nullptr) { | |||
| MS_LOG(ERROR) << "func graph is nullptr."; | |||
| return false; | |||
| } | |||
| auto &passes_info = PassStoreRoomInfo(); | |||
| for (auto &name : pass_names) { | |||
| if (passes_info.find(name) == passes_info.end()) { | |||
| MS_LOG(ERROR) << "cannot find required pass."; | |||
| return false; | |||
| } | |||
| if (!passes_info[name]->Run(func_graph)) { | |||
| MS_LOG(ERROR) << "run pass failed, pass name is " << name; | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool RunExternalPass(const FuncGraphPtr &func_graph, PassPosition position) { | |||
| if (func_graph == nullptr) { | |||
| MS_LOG(ERROR) << "func graph is nullptr."; | |||
| return false; | |||
| } | |||
| auto &external_assigned = ExternalAssignedPassesInfo(); | |||
| if (external_assigned.find(position) == external_assigned.end()) { | |||
| MS_LOG(DEBUG) << "there is no external pass in current position, position is " << position; | |||
| return true; | |||
| } | |||
| auto &passes_info = PassStoreRoomInfo(); | |||
| for (auto &name : external_assigned[position]) { | |||
| if (passes_info.find(name) == passes_info.end()) { | |||
| MS_LOG(ERROR) << "cannot find required pass."; | |||
| return false; | |||
| } | |||
| if (!passes_info[name]->Run(func_graph)) { | |||
| MS_LOG(ERROR) << "run pass failed, pass name is " << name; | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,32 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_OPTIMIZER_MANAGER_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_OPTIMIZER_MANAGER_H | |||
| #include <string> | |||
| #include <vector> | |||
| #include "include/registry/pass_registry.h" | |||
| #include "ir/func_graph.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| bool RunOptimizerPass(const FuncGraphPtr &func_graph, std::vector<std::string> pass_names); | |||
| bool RunExternalPass(const FuncGraphPtr &func_graph, PassPosition position); | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_OPTIMIZER_MANAGER_H | |||
| @@ -31,7 +31,7 @@ | |||
| #include "tools/converter/quant_param_holder.h" | |||
| #include "tools/converter/parser/parser_utils.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "tools/converter/parser/insert_transpose.h" | |||
| #include "tools/converter/parser/unify_format.h" | |||
| using mindspore::lite::converter::FmkType_CAFFE; | |||
| namespace mindspore::lite { | |||
| @@ -104,8 +104,8 @@ FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag) | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_CAFFE, false); | |||
| if (!insert_transpose->Run(res_graph_)) { | |||
| auto unify_format = std::make_shared<UnifyFormatToNHWC>(lite::converter::FmkType_CAFFE, false); | |||
| if (!unify_format->Run(res_graph_)) { | |||
| MS_LOG(ERROR) << "Run insert transpose failed."; | |||
| return nullptr; | |||
| } | |||
| @@ -54,21 +54,21 @@ STATUS InputAdjust::AddAttrToInput(const FuncGraphPtr &func_graph, const CNodePt | |||
| inputs.push_back(param_node); | |||
| break; | |||
| } | |||
| case 2: { | |||
| case kBuildInputFlagTwo: { | |||
| auto value_data = opt::CastToInt(value_ptr); | |||
| auto param_node = | |||
| opt::BuildIntVecParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name); | |||
| inputs.push_back(param_node); | |||
| break; | |||
| } | |||
| case 3: { | |||
| case kBuildInputFlagThree: { | |||
| auto value_data = opt::CastToVec2DInt(value_ptr); | |||
| auto param_node = | |||
| opt::BuildIntVec2DParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name); | |||
| inputs.push_back(param_node); | |||
| break; | |||
| } | |||
| case 4: { | |||
| case kBuildInputFlagFour: { | |||
| auto value_data = GetValue<float>(value_ptr); | |||
| auto param_node = | |||
| opt::BuildFloatValueParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name); | |||
| @@ -1,511 +0,0 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/insert_transpose.h" | |||
| #include <queue> | |||
| #include <set> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include "ops/op_utils.h" | |||
| #include "src/common/common.h" | |||
| #include "src/common/utils.h" | |||
| #include "tools/common/tensor_util.h" | |||
| using mindspore::lite::NCHW_SHAPE; | |||
| namespace mindspore { | |||
| namespace lite { | |||
| namespace { | |||
| constexpr size_t kNCHWDimNumber = 4; | |||
| const std::vector<int> NH2NC = {0, 3, 1, 2}; | |||
| const std::vector<int> NC2NH = {0, 2, 3, 1}; | |||
| bool IsSpecialType(const CNodePtr &cnode) { | |||
| if (opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || opt::CheckPrimitiveType(cnode, prim::kPrimDepend) || | |||
| opt::CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || opt::CheckPrimitiveType(cnode, opt::kPrimMakeTupleV2) || | |||
| opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace | |||
| void InsertTranspose::GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| auto prim_node = cnode->input(0); | |||
| auto prim = GetValueNode<PrimitivePtr>(prim_node); | |||
| MS_ASSERT(prim != nullptr); | |||
| auto &specify_nhwc_op_map = opt::GetNHWCOpMap(); | |||
| auto &specify_nchw_op_map = opt::GetNCHWOpMap(); | |||
| if (fmk_type_ == lite::converter::FmkType_TFLITE) { | |||
| if (specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) { | |||
| return; | |||
| } | |||
| trans_info->pre_ = opt::kNHWC2NCHW; | |||
| trans_info->post_ = opt::kNCHW2NHWC; | |||
| } else if (fmk_type_ == lite::converter::FmkType_TF) { | |||
| if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end() && opt::GetFormat(cnode) == NCHW) { | |||
| trans_info->pre_ = opt::kNCHW2NHWC; | |||
| trans_info->post_ = opt::kNHWC2NCHW; | |||
| } | |||
| if (specify_nchw_op_map.find(prim->name()) != specify_nchw_op_map.end()) { | |||
| trans_info->pre_ = opt::kNHWC2NCHW; | |||
| trans_info->post_ = opt::kNCHW2NHWC; | |||
| } | |||
| } else { | |||
| if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end()) { | |||
| if (fmk_type_ == lite::converter::FmkType_ONNX && prim->GetAttr(ops::kFormat) != nullptr && | |||
| GetValue<int64_t>(prim->GetAttr(ops::kFormat)) == NHWC) { | |||
| return; | |||
| } | |||
| trans_info->pre_ = opt::kNCHW2NHWC; | |||
| trans_info->post_ = opt::kNHWC2NCHW; | |||
| } | |||
| } | |||
| } | |||
| AnfNodePtr InsertTranspose::GenNewInputWithoutShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::vector<int> &perm, bool before, size_t index) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| AnfNodePtr new_input = nullptr; | |||
| AnfNodePtr trans_input_node = before ? cnode->input(index) : cnode; | |||
| std::string trans_name = | |||
| before ? cnode->fullname_with_scope() + "_pre" + std::to_string(index - 1) : cnode->fullname_with_scope() + "_post"; | |||
| new_input = opt::GenTransposeNode(func_graph, trans_input_node, perm, trans_name); | |||
| auto new_input_prim = GetValueNode<PrimitivePtr>(new_input->cast<CNodePtr>()->input(0)); | |||
| if (perm == NC2NH) { | |||
| new_input_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW)); | |||
| } else if (perm == NH2NC) { | |||
| new_input_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC)); | |||
| } | |||
| return new_input; | |||
| } | |||
| STATUS InsertTranspose::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, | |||
| bool before, size_t index) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| AnfNodePtr new_input = nullptr; | |||
| new_input = GenNewInputWithoutShape(func_graph, cnode, perm, before, index); | |||
| if (new_input == nullptr) { | |||
| MS_LOG(ERROR) << "generate a transpose node failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (new_input == cnode->input(index) || new_input == cnode) { | |||
| return lite::RET_OK; | |||
| } | |||
| auto manager = func_graph->manager(); | |||
| if (manager == nullptr) { | |||
| manager = Manage(func_graph, true); | |||
| } | |||
| MS_ASSERT(manager != nullptr); | |||
| auto tr = manager->Transact(); | |||
| if (before) { | |||
| tr.SetEdge(cnode, index, new_input); | |||
| tr.Commit(); | |||
| } else { | |||
| func_graph->manager()->Replace(cnode, new_input); | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS InsertTranspose::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::vector<int> &perm) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| auto prim_node = cnode->input(0); | |||
| auto prim = GetValueNode<PrimitivePtr>(prim_node); | |||
| MS_ASSERT(prim != nullptr); | |||
| auto &specify_nhwc_op_map = opt::GetNHWCOpMap(); | |||
| auto &specify_nchw_op_map = opt::GetNCHWOpMap(); | |||
| if (specify_nhwc_op_map.find(prim->name()) == specify_nhwc_op_map.end() && | |||
| specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) { | |||
| MS_LOG(ERROR) << "op don't meet nhwc condition."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| std::vector<size_t> insert_index = specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end() | |||
| ? specify_nhwc_op_map.at(prim->name()) | |||
| : specify_nchw_op_map.at(prim->name()); | |||
| if (insert_index.empty()) { | |||
| if (opt::CheckPrimitiveType(cnode, prim::kPrimResizeGrad) && prim->GetAttr(ops::kMethod) != nullptr && | |||
| GetValue<int64_t>(prim->GetAttr(ops::kMethod)) == static_cast<int64_t>(mindspore::ResizeMethod::NEAREST)) { | |||
| insert_index.push_back(1); | |||
| } else { | |||
| for (size_t i = 1; i < cnode->size(); ++i) { | |||
| insert_index.push_back(i); | |||
| } | |||
| } | |||
| } | |||
| for (auto &index : insert_index) { | |||
| if (GenNewInput(func_graph, cnode, perm, true, index) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "generate a new input failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS InsertTranspose::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::vector<int> &perm) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| if (!cnode->abstract()->isa<abstract::AbstractTuple>()) { | |||
| if (GenNewInput(func_graph, cnode, perm, false) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "generate a new input failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| } else { | |||
| auto node_users = func_graph->manager()->node_users()[cnode]; | |||
| for (auto &node_user : node_users) { | |||
| auto post_node = node_user.first; | |||
| CNodePtr tuple_get_item = nullptr; | |||
| if (!opt::CheckPrimitiveType(post_node, prim::kPrimTupleGetItem)) { | |||
| if (!train_flag_) { | |||
| MS_LOG(ERROR) << "post node is invalid."; | |||
| return lite::RET_ERROR; | |||
| } else { | |||
| tuple_get_item = opt::GenTupleGetItemNode(func_graph, cnode, 0); | |||
| post_node = tuple_get_item; | |||
| func_graph->manager()->Replace(cnode, tuple_get_item); | |||
| } | |||
| } | |||
| if (func_graph->manager()->node_users()[post_node].empty()) { | |||
| continue; | |||
| } | |||
| auto post_cnode = post_node->cast<CNodePtr>(); | |||
| if (GenNewInput(func_graph, post_cnode, perm, false) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "generate a new input failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (tuple_get_item != nullptr) { | |||
| func_graph->manager()->Replace(tuple_get_item, tuple_get_item->input(1)); | |||
| } | |||
| } | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS InsertTranspose::HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| if (fmk_type_ == lite::converter::FmkType_TF || fmk_type_ == lite::converter::FmkType_TFLITE) { | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| for (size_t i = 1; i < cnode->size(); ++i) { | |||
| auto node = cnode->input(i); | |||
| if (!utils::isa<ParameterPtr>(node)) { | |||
| continue; | |||
| } | |||
| auto param_node = node->cast<ParameterPtr>(); | |||
| if (param_node->has_default()) { | |||
| continue; | |||
| } | |||
| auto abstract_base = param_node->abstract(); | |||
| if (abstract_base == nullptr) { | |||
| MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) { | |||
| MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base); | |||
| if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) { | |||
| MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape(); | |||
| if (shape_vector.size() != 4) { | |||
| continue; | |||
| } | |||
| if (func_graph->get_inputs().size() == 1 && fmk_type_ == lite::converter::FmkType_ONNX && shape_vector[3] == 3 && | |||
| shape_vector[1] == -1) { | |||
| continue; | |||
| } | |||
| std::vector<int64_t> new_dims = {shape_vector[NCHW_SHAPE::NCHW_N], shape_vector[NCHW_SHAPE::NCHW_H], | |||
| shape_vector[NCHW_SHAPE::NCHW_W], shape_vector[NCHW_SHAPE::NCHW_C]}; | |||
| abstract_tensor->set_shape(std::make_shared<abstract::Shape>(new_dims)); | |||
| auto trans_cnode = opt::GenTransposeNode(func_graph, param_node, NH2NC, param_node->fullname_with_scope() + "_pre"); | |||
| auto new_input_prim = GetValueNode<PrimitivePtr>(trans_cnode->cast<CNodePtr>()->input(0)); | |||
| new_input_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC)); | |||
| if (trans_cnode == nullptr) { | |||
| MS_LOG(ERROR) << "generate a transpose node failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| func_graph->manager()->Replace(param_node, trans_cnode); | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS InsertTranspose::HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| opt::TransTypePair trans_info; | |||
| GetTransNodeFormatType(cnode, &trans_info); | |||
| if (trans_info.pre_ == opt::kNONE || trans_info.post_ == opt::kNONE) { | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| auto before_perm = trans_info.pre_ == opt::kNHWC2NCHW ? NH2NC : NC2NH; | |||
| auto after_perm = trans_info.post_ == opt::kNCHW2NHWC ? NC2NH : NH2NC; | |||
| if (InsertPreTransNode(func_graph, cnode, before_perm) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "insert pre node failed." << cnode->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (opt::CheckPrimitiveType(cnode, prim::kPrimAdam) || opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) { | |||
| return RET_OK; | |||
| } | |||
| if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| void InsertTranspose::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { | |||
| MS_ASSERT(cnode != nullptr && sub_graph != nullptr); | |||
| auto sub_inputs = sub_graph->get_inputs(); | |||
| sub_inputs_map_[sub_graph] = sub_inputs; | |||
| for (auto &node : sub_inputs) { | |||
| auto param_node = node->cast<ParameterPtr>(); | |||
| MS_ASSERT(param_node != nullptr); | |||
| auto node_name = node->fullname_with_scope(); | |||
| auto last_underline = node_name.find_last_of("_"); | |||
| node_name = node_name.substr(0, last_underline); | |||
| last_underline = node_name.find_last_of("_"); | |||
| auto index = std::stoi(node_name.substr(last_underline + 1)) + 3; | |||
| param_node->set_abstract(opt::GetCNodeInputAbstract(cnode, index)->Clone()); | |||
| if (utils::isa<CNodePtr>(cnode->input(index))) { | |||
| ShapeVector shape_vec = {-1}; | |||
| auto out_cnode = cnode->input(index)->cast<CNodePtr>(); | |||
| MS_ASSERT(trans_cnode != nullptr); | |||
| auto out_prim = GetValueNode<PrimitivePtr>(out_cnode->input(0)); | |||
| if (out_prim->GetAttr(opt::kInferDone) == nullptr || !GetValue<bool>(out_prim->GetAttr(opt::kInferDone))) { | |||
| param_node->abstract()->set_shape(std::make_shared<abstract::Shape>(shape_vec)); | |||
| } | |||
| } else { | |||
| lite::DataInfo data_info; | |||
| if (utils::isa<ParameterPtr>(cnode->input(index))) { | |||
| if (cnode->input(index)->cast<ParameterPtr>()->has_default()) { | |||
| param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param()); | |||
| } | |||
| continue; | |||
| } | |||
| auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info); | |||
| if (status != lite::RET_OK) { | |||
| continue; | |||
| } | |||
| ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end()); | |||
| if (data_info.data_.empty()) { | |||
| param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec)); | |||
| } else { | |||
| param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec, | |||
| data_info.data_.data(), data_info.data_.size())); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void InsertTranspose::ResetSubGraphInput() { | |||
| for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) { | |||
| auto &sub_graph = iter->first; | |||
| auto &sub_inputs = iter->second; | |||
| auto manager = sub_graph->manager(); | |||
| MS_ASSERT(manager != nullptr); | |||
| for (auto &sub_input : sub_inputs) { | |||
| auto param_node = sub_graph->add_parameter(); | |||
| MS_ASSERT(param_node != nullptr); | |||
| param_node->set_abstract(sub_input->abstract()->Clone()); | |||
| param_node->set_name(sub_input->fullname_with_scope()); | |||
| manager->Replace(sub_input, param_node); | |||
| auto sub_param_input = sub_input->cast<ParameterPtr>(); | |||
| MS_ASSERT(sub_param_input != nullptr); | |||
| sub_param_input->set_default_param(nullptr); | |||
| } | |||
| } | |||
| } | |||
| void InsertTranspose::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { | |||
| MS_ASSERT(cnode != nullptr && sub_graph != nullptr); | |||
| auto return_node = sub_graph->get_return(); | |||
| auto origin_input = return_node->inputs(); | |||
| lite::RemoveIfDepend(return_node); | |||
| lite::RemoveIfMakeTuple(return_node); | |||
| for (size_t i = 1; i < return_node->size(); ++i) { | |||
| if (!opt::CheckPrimitiveType(return_node->input(i), prim::kPrimTranspose)) { | |||
| continue; | |||
| } | |||
| auto node_name = return_node->input(i)->fullname_with_scope(); | |||
| if (node_name.substr(node_name.size() - 5) != "_post") { | |||
| continue; | |||
| } | |||
| auto trans_cnode = return_node->input(i)->cast<CNodePtr>(); | |||
| MS_ASSERT(trans_cnode != nullptr); | |||
| auto trans_input = trans_cnode->input(1); | |||
| auto trans_input_name = trans_input->fullname_with_scope(); | |||
| if (utils::isa<ParameterPtr>(trans_input)) { | |||
| trans_input->cast<ParameterPtr>()->set_name(node_name); | |||
| } else if (utils::isa<CNodePtr>(trans_input)) { | |||
| trans_input->cast<CNodePtr>()->set_fullname_with_scope(node_name); | |||
| } | |||
| trans_input_name = trans_input_name.substr(0, trans_input_name.find_last_of("_")) + "_cnode"; | |||
| trans_cnode->set_fullname_with_scope(trans_input_name); | |||
| } | |||
| return_node->set_inputs(origin_input); | |||
| } | |||
| void InsertTranspose::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { | |||
| MS_ASSERT(cnode != nullptr && sub_graph != nullptr); | |||
| auto return_node = sub_graph->get_return(); | |||
| auto origin_inputs = return_node->inputs(); | |||
| lite::RemoveIfDepend(return_node); | |||
| lite::RemoveIfMakeTuple(return_node); | |||
| AbstractBasePtrList abstract_list; | |||
| bool infer_done = true; | |||
| for (size_t i = 1; i < return_node->size(); ++i) { | |||
| auto abstract_base = opt::GetCNodeInputAbstract(return_node, i); | |||
| MS_ASSERT(abstract_base != nullptr); | |||
| abstract_list.emplace_back(abstract_base->Clone()); | |||
| auto abstract_tensor = abstract_base->cast<abstract::AbstractTensorPtr>(); | |||
| MS_ASSERT(abstract_tensor != nullptr); | |||
| auto shape_ptr = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape()); | |||
| MS_ASSERT(shape_ptr != nullptr); | |||
| auto shape = shape_ptr->shape(); | |||
| if (std::find(shape.begin(), shape.end(), -1) != shape.end()) { | |||
| infer_done = false; | |||
| } | |||
| if (utils::isa<CNodePtr>(return_node->input(i))) { | |||
| auto input_cnode = return_node->input(i)->cast<CNodePtr>(); | |||
| if (opt::CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) { | |||
| input_cnode = input_cnode->input(1)->cast<CNodePtr>(); | |||
| } | |||
| auto input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0)); | |||
| if (input_prim->GetAttr(opt::kInferDone) == nullptr || !GetValue<bool>(input_prim->GetAttr(opt::kInferDone))) { | |||
| infer_done = false; | |||
| } | |||
| } | |||
| } | |||
| return_node->set_inputs(origin_inputs); | |||
| if (utils::isa<abstract::AbstractTuplePtr>(cnode->abstract())) { | |||
| cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | |||
| } else { | |||
| if (abstract_list.size() != 1) { | |||
| MS_LOG(ERROR) << "cnode output is invalid."; | |||
| } | |||
| cnode->set_abstract(abstract_list.front()); | |||
| } | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| prim->AddAttr(opt::kInferDone, MakeValue<bool>(infer_done)); | |||
| } | |||
| bool InsertTranspose::BasicProcess(const FuncGraphPtr &func_graph, bool main_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name")); | |||
| auto manager = Manage(func_graph, true); | |||
| if (manager == nullptr) { | |||
| MS_LOG(ERROR) << "manager is nullptr."; | |||
| return false; | |||
| } | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| int status; | |||
| for (auto &node : node_list) { | |||
| if (!utils::isa<CNodePtr>(node)) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (IsSpecialType(cnode)) { | |||
| continue; | |||
| } | |||
| if (main_graph) { | |||
| status = HandleGraphInput(func_graph, cnode); | |||
| if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { | |||
| return false; | |||
| } | |||
| } | |||
| if (opt::CheckPrimitiveType(node, prim::kPrimIf) || opt::CheckPrimitiveType(node, prim::kPrimWhile)) { | |||
| auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1)); | |||
| if (sub_func_graph == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| return false; | |||
| } | |||
| SetSubGraphInput(cnode, sub_func_graph); | |||
| (void)BasicProcess(sub_func_graph, false); | |||
| SetSubGraphOutput(cnode, sub_func_graph); | |||
| sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2)); | |||
| if (sub_func_graph == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| return false; | |||
| } | |||
| SetSubGraphInput(cnode, sub_func_graph); | |||
| (void)BasicProcess(sub_func_graph, false); | |||
| SetSubGraphOutput(cnode, sub_func_graph); | |||
| SetSubGraphAbstract(cnode, sub_func_graph); | |||
| continue; | |||
| } | |||
| status = HandleGraphNode(func_graph, cnode); | |||
| if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool InsertTranspose::ResetFuncGraph(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto manager = Manage(func_graph, true); | |||
| if (manager == nullptr) { | |||
| MS_LOG(ERROR) << "manager is nullptr."; | |||
| return false; | |||
| } | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| for (auto &node : node_list) { | |||
| if (!utils::isa<CNodePtr>(node)) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| if (prim->GetAttr(opt::kInferDone) != nullptr) { | |||
| prim->EraseAttr(opt::kInferDone); | |||
| } | |||
| if (opt::CheckPrimitiveType(node, prim::kPrimIf) || opt::CheckPrimitiveType(node, prim::kPrimWhile)) { | |||
| auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1)); | |||
| if (sub_func_graph == nullptr) { | |||
| return false; | |||
| } | |||
| (void)ResetFuncGraph(sub_func_graph); | |||
| sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2)); | |||
| if (sub_func_graph == nullptr) { | |||
| return false; | |||
| } | |||
| (void)ResetFuncGraph(sub_func_graph); | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool InsertTranspose::Run(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| for (auto &node : node_list) { | |||
| auto prim = GetValueNode<PrimitivePtr>(node); | |||
| if (prim == nullptr) { | |||
| continue; | |||
| } | |||
| } | |||
| // insert transpose for some ops whose format must be NHWC, which is depend on framework. | |||
| // In this process, tranpose can be fused, which the original graph may not be able to restored. | |||
| if (!BasicProcess(func_graph, true)) { | |||
| MS_LOG(ERROR) << "run framework transpose unify failed."; | |||
| return false; | |||
| } | |||
| ResetSubGraphInput(); | |||
| return true; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -35,7 +35,7 @@ | |||
| #include "tools/converter/parser/onnx/onnx_pad_adjust.h" | |||
| #include "tools/converter/parser/parser_utils.h" | |||
| #include "ops/transpose.h" | |||
| #include "tools/converter/parser/insert_transpose.h" | |||
| #include "tools/converter/parser/unify_format.h" | |||
| using mindspore::lite::converter::FmkType_ONNX; | |||
| namespace mindspore { | |||
| @@ -95,8 +95,8 @@ FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag) | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_ONNX, false); | |||
| if (!insert_transpose->Run(res_graph_)) { | |||
| auto unify_format = std::make_shared<UnifyFormatToNHWC>(lite::converter::FmkType_ONNX, false); | |||
| if (!unify_format->Run(res_graph_)) { | |||
| MS_LOG(ERROR) << "Run insert transpose failed."; | |||
| return nullptr; | |||
| } | |||
| @@ -33,7 +33,7 @@ | |||
| #include "tools/converter/parser/tf/functionalize_control_op_pass.h" | |||
| #include "tools/converter/parser/parser_utils.h" | |||
| #include "tools/common/tensor_util.h" | |||
| #include "tools/converter/parser/insert_transpose.h" | |||
| #include "tools/converter/parser/unify_format.h" | |||
| using mindspore::lite::converter::FmkType_TF; | |||
| namespace mindspore { | |||
| @@ -576,8 +576,8 @@ FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) { | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_TF, false); | |||
| if (!insert_transpose->Run(res_graph_)) { | |||
| auto unify_format = std::make_shared<UnifyFormatToNHWC>(lite::converter::FmkType_TF, false); | |||
| if (!unify_format->Run(res_graph_)) { | |||
| MS_LOG(ERROR) << "Run insert transpose failed."; | |||
| return nullptr; | |||
| } | |||
| @@ -30,7 +30,7 @@ | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "tools/converter/parser/tflite/tflite_inputs_adjust.h" | |||
| #include "tools/converter/parser/parser_utils.h" | |||
| #include "tools/converter/parser/insert_transpose.h" | |||
| #include "tools/converter/parser/unify_format.h" | |||
| using mindspore::lite::converter::FmkType_TFLITE; | |||
| namespace mindspore::lite { | |||
| @@ -105,8 +105,8 @@ FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_TFLITE, false); | |||
| if (!insert_transpose->Run(res_graph_)) { | |||
| auto unify_format = std::make_shared<UnifyFormatToNHWC>(lite::converter::FmkType_TFLITE, false); | |||
| if (!unify_format->Run(res_graph_)) { | |||
| MS_LOG(ERROR) << "Run insert transpose failed."; | |||
| return nullptr; | |||
| } | |||
| @@ -0,0 +1,78 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/unify_format.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| namespace { | |||
| constexpr int kInputChannal = 3; | |||
| } | |||
| void UnifyFormatToNHWC::GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| auto prim_node = cnode->input(0); | |||
| auto prim = GetValueNode<PrimitivePtr>(prim_node); | |||
| MS_ASSERT(prim != nullptr); | |||
| auto &specify_nhwc_op_map = opt::GetNHWCOpMap(); | |||
| auto &specify_nchw_op_map = opt::GetNCHWOpMap(); | |||
| if (fmk_type_ == lite::converter::FmkType_TFLITE) { | |||
| if (specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) { | |||
| return; | |||
| } | |||
| trans_info->pre_ = opt::kNHWC2NCHW; | |||
| trans_info->post_ = opt::kNCHW2NHWC; | |||
| } else if (fmk_type_ == lite::converter::FmkType_TF) { | |||
| if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end() && opt::GetFormat(cnode) == NCHW) { | |||
| trans_info->pre_ = opt::kNCHW2NHWC; | |||
| trans_info->post_ = opt::kNHWC2NCHW; | |||
| } | |||
| if (specify_nchw_op_map.find(prim->name()) != specify_nchw_op_map.end()) { | |||
| trans_info->pre_ = opt::kNHWC2NCHW; | |||
| trans_info->post_ = opt::kNCHW2NHWC; | |||
| } | |||
| } else { | |||
| if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end()) { | |||
| if (fmk_type_ == lite::converter::FmkType_ONNX && prim->GetAttr(ops::kFormat) != nullptr && | |||
| GetValue<int64_t>(prim->GetAttr(ops::kFormat)) == NHWC) { | |||
| return; | |||
| } | |||
| trans_info->pre_ = opt::kNCHW2NHWC; | |||
| trans_info->post_ = opt::kNHWC2NCHW; | |||
| } | |||
| } | |||
| } | |||
| void UnifyFormatToNHWC::SetSensitiveOps() { | |||
| auto &sensitive_nhwc_ops = opt::GetNHWCOpMap(); | |||
| auto &sensitive_nchw_ops = opt::GetNCHWOpMap(); | |||
| sensitive_ops_.insert(sensitive_nhwc_ops.begin(), sensitive_nhwc_ops.end()); | |||
| sensitive_ops_.insert(sensitive_nchw_ops.begin(), sensitive_nchw_ops.end()); | |||
| } | |||
| bool UnifyFormatToNHWC::DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ShapeVector &shape) { | |||
| if (fmk_type_ == converter::FmkType_TF || fmk_type_ == converter::FmkType_TFLITE) { | |||
| return false; | |||
| } | |||
| if (func_graph->get_inputs().size() == 1 && fmk_type_ == lite::converter::FmkType_ONNX && | |||
| shape[opt::kInputIndexThree] == kInputChannal && shape[1] == -1) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| bool UnifyFormatToNHWC::DecideWhetherInferShapeForNewNode() { return false; } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,40 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_UNIFY_FORMAT_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_UNIFY_FORMAT_H_ | |||
| #include "tools/optimizer/format/to_format_base.h" | |||
| using mindspore::lite::converter::FmkType; | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class UnifyFormatToNHWC : public opt::ToFormatBase { | |||
| public: | |||
| explicit UnifyFormatToNHWC(FmkType fmk_type = lite::converter::FmkType_MS, bool train_flag = false) | |||
| : ToFormatBase(fmk_type, train_flag) {} | |||
| ~UnifyFormatToNHWC() override = default; | |||
| private: | |||
| void GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) override; | |||
| void SetSensitiveOps() override; | |||
| bool DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ShapeVector &shape) override; | |||
| bool DecideWhetherInferShapeForNewNode() override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_UNIFY_FORMAT_H_ | |||
| @@ -0,0 +1,32 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_CONTENT_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_CONTENT_H | |||
| #include <map> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "include/registry/pass_registry.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| std::map<std::string, PassPtr> &MS_API PassStoreRoomInfo(); | |||
| std::map<PassPosition, std::vector<std::string>> &MS_API ExternalAssignedPassesInfo(); | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_CONTENT_H | |||
| @@ -15,31 +15,38 @@ | |||
| */ | |||
| #include "include/registry/pass_registry.h" | |||
| #include <unordered_map> | |||
| #include <map> | |||
| #include <mutex> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "tools/converter/registry/pass_content.h" | |||
| #include "src/common/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| PassRegistry *PassRegistry::GetInstance() { | |||
| static PassRegistry instance; | |||
| return &instance; | |||
| } | |||
| void PassRegistry::RegPass(int position, const PassPtr &pass) { | |||
| namespace { | |||
| std::map<std::string, PassPtr> pass_store_room; | |||
| std::map<PassPosition, std::vector<std::string>> external_assigned_passes; | |||
| std::mutex pass_mutex; | |||
| void RegPass(const std::string &pass_name, const PassPtr &pass) { | |||
| if (pass == nullptr) { | |||
| MS_LOG(ERROR) << "pass is nullptr."; | |||
| return; | |||
| } | |||
| auto instance = PassRegistry::GetInstance(); | |||
| std::unique_lock<std::mutex> lock(instance->mutex_); | |||
| instance->passes_[position] = pass; | |||
| std::unique_lock<std::mutex> lock(pass_mutex); | |||
| pass_store_room[pass_name] = pass; | |||
| } | |||
| } // namespace | |||
| const std::unordered_map<int, PassPtr> &PassRegistry::GetPasses() const { | |||
| auto instance = PassRegistry::GetInstance(); | |||
| std::unique_lock<std::mutex> lock(instance->mutex_); | |||
| return instance->passes_; | |||
| PassRegistry::PassRegistry(const std::string &pass_name, const PassPtr &pass) { RegPass(pass_name, pass); } | |||
| PassRegistry::PassRegistry(PassPosition position, const std::vector<std::string> &assigned) { | |||
| std::unique_lock<std::mutex> lock(pass_mutex); | |||
| external_assigned_passes[position] = assigned; | |||
| } | |||
| std::map<std::string, PassPtr> &PassStoreRoomInfo() { return pass_store_room; } | |||
| std::map<PassPosition, std::vector<std::string>> &ExternalAssignedPassesInfo() { return external_assigned_passes; } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -852,10 +852,10 @@ STATUS GetFilterDim(const std::vector<int64_t> &oriDims, kTransFilterType type, | |||
| int64_t *filterH, int64_t *filterW) { | |||
| MS_ASSERT(oriDims.size() == 4); | |||
| std::unordered_map<kTransFilterType, int> maps = { | |||
| {kKCHW2HWCK, 1}, {kKCHW2HWKC, 1}, {kKCHW2KHWC, 1}, {kKCHW2CKHW, 1}, {kCKHW2HWCK, 2}, | |||
| {kCKHW2HWKC, 2}, {kCKHW2KHWC, 2}, {kHWCK2KCHW, 3}, {kHWCK2CKHW, 3}, {kHWCK2KHWC, 3}, | |||
| {kHWKC2KCHW, 4}, {kHWKC2CKHW, 4}, {kHWKC2KHWC, 4}, {kNHWC2KCHW, 5}, {kNHWC2HWCK, 5}, | |||
| {kNHWC2CKHW, 5}, {kCHWK2HWCK, 6}, {kCHWK2KHWC, 6}, {kKHWC2HWCK, 7}, {kKHWC2CHWK, 7}, | |||
| {kKCHW2HWCK, 1}, {kKCHW2HWKC, 1}, {kKCHW2KHWC, 1}, {kKCHW2CKHW, 1}, {kCKHW2HWCK, 2}, {kCKHW2HWKC, 2}, | |||
| {kCKHW2KHWC, 2}, {kHWCK2KCHW, 3}, {kHWCK2CKHW, 3}, {kHWCK2KHWC, 3}, {kHWKC2KCHW, 4}, {kHWKC2CKHW, 4}, | |||
| {kHWKC2KHWC, 4}, {kNHWC2KCHW, 5}, {kNHWC2HWCK, 5}, {kNHWC2CKHW, 5}, {kKHWC2KCHW, 5}, {kCHWK2HWCK, 6}, | |||
| {kCHWK2KHWC, 6}, {kKHWC2HWCK, 7}, {kKHWC2CHWK, 7}, | |||
| }; | |||
| if (maps.find(type) == maps.end()) { | |||
| MS_LOG(ERROR) << "Unsupported transFilterType: " << type; | |||
| @@ -915,10 +915,10 @@ STATUS SetFilterDim(const tensor::TensorPtr &tensor, kTransFilterType type, int3 | |||
| int32_t filterH, int32_t filterW) { | |||
| MS_ASSERT(tensor != nullptr); | |||
| std::unordered_map<kTransFilterType, int> maps = { | |||
| {kKCHW2HWCK, 1}, {kCKHW2HWCK, 1}, {kNHWC2HWCK, 1}, {kKHWC2HWCK, 1}, {kCHWK2HWCK, 1}, | |||
| {kKCHW2HWKC, 2}, {kCKHW2HWKC, 2}, {kHWCK2KCHW, 3}, {kHWKC2KCHW, 3}, {kNHWC2KCHW, 3}, | |||
| {kHWCK2CKHW, 4}, {kHWKC2CKHW, 4}, {kNHWC2CKHW, 4}, {kKCHW2CKHW, 4}, {kKHWC2CHWK, 5}, | |||
| {kKCHW2KHWC, 6}, {kCKHW2KHWC, 6}, {kCHWK2KHWC, 6}, {kHWCK2KHWC, 6}, {kHWKC2KHWC, 6}, | |||
| {kKCHW2HWCK, 1}, {kCKHW2HWCK, 1}, {kNHWC2HWCK, 1}, {kKHWC2HWCK, 1}, {kCHWK2HWCK, 1}, {kKCHW2HWKC, 2}, | |||
| {kCKHW2HWKC, 2}, {kHWCK2KCHW, 3}, {kHWKC2KCHW, 3}, {kNHWC2KCHW, 3}, {kKHWC2KCHW, 3}, {kHWCK2CKHW, 4}, | |||
| {kHWKC2CKHW, 4}, {kNHWC2CKHW, 4}, {kKCHW2CKHW, 4}, {kKHWC2CHWK, 5}, {kKCHW2KHWC, 6}, {kCKHW2KHWC, 6}, | |||
| {kCHWK2KHWC, 6}, {kHWCK2KHWC, 6}, {kHWKC2KHWC, 6}, | |||
| }; | |||
| if (maps.find(type) == maps.end()) { | |||
| MS_LOG(ERROR) << "Unsupported transFilterType: " << type; | |||
| @@ -1137,10 +1137,10 @@ static STATUS TransFilterData(const tensor::TensorPtr &tensor, kTransFilterType | |||
| T *p2Buff = nullptr; | |||
| std::unordered_map<kTransFilterType, int> maps = { | |||
| {kCHWK2HWCK, 1}, {kCHWK2KHWC, 1}, {kKHWC2HWCK, 2}, {kKCHW2HWCK, 3}, {kKCHW2CKHW, 3}, | |||
| {kKCHW2KHWC, 3}, {kKCHW2HWKC, 3}, {kCKHW2HWCK, 4}, {kCKHW2KHWC, 4}, {kCKHW2HWKC, 4}, | |||
| {kHWCK2KCHW, 5}, {kHWCK2CKHW, 5}, {kHWCK2KHWC, 5}, {kHWKC2KCHW, 6}, {kHWKC2KHWC, 6}, | |||
| {kHWKC2CKHW, 6}, {kNHWC2HWCK, 7}, {kNHWC2KCHW, 7}, {kNHWC2CKHW, 7}, {kKHWC2CHWK, 8}, | |||
| {kCHWK2HWCK, 1}, {kCHWK2KHWC, 1}, {kKHWC2HWCK, 2}, {kKCHW2HWCK, 3}, {kKCHW2CKHW, 3}, {kKCHW2KHWC, 3}, | |||
| {kKCHW2HWKC, 3}, {kCKHW2HWCK, 4}, {kCKHW2KHWC, 4}, {kCKHW2HWKC, 4}, {kHWCK2KCHW, 5}, {kHWCK2CKHW, 5}, | |||
| {kHWCK2KHWC, 5}, {kHWKC2KCHW, 6}, {kHWKC2KHWC, 6}, {kHWKC2CKHW, 6}, {kNHWC2HWCK, 7}, {kNHWC2KCHW, 7}, | |||
| {kKHWC2KCHW, 7}, {kNHWC2CKHW, 7}, {kKHWC2CHWK, 8}, | |||
| }; | |||
| if (maps.find(type) == maps.end()) { | |||
| MS_LOG(ERROR) << "Unsupported transFilterType: " << type; | |||
| @@ -1510,5 +1510,23 @@ CNodePtr GenTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &inp | |||
| tuple_cnode->set_fullname_with_scope(input->fullname_with_scope() + "_getitem_" + std::to_string(index)); | |||
| return tuple_cnode; | |||
| } | |||
| STATUS FetchShapeFromAbstract(const abstract::AbstractBasePtr &abstract, ShapeVector *shape) { | |||
| if (abstract == nullptr) { | |||
| MS_LOG(ERROR) << "abstract of cnode is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (!utils::isa<abstract::AbstractTensor>(abstract)) { | |||
| MS_LOG(ERROR) << "abstract of cnode is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto abstract_tensor = abstract->cast<abstract::AbstractTensorPtr>(); | |||
| if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) { | |||
| MS_LOG(ERROR) << "shape of cnode's output is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| *shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape(); | |||
| return lite::RET_OK; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -43,6 +43,8 @@ inline constexpr size_t kInputSizeTwo = 2; | |||
| inline constexpr size_t kInputSizeThree = 3; | |||
| inline constexpr size_t kInputSizeFour = 4; | |||
| inline constexpr size_t kInputSizeFive = 5; | |||
| inline const std::vector<int> kNH2NC = {0, 3, 1, 2}; | |||
| inline const std::vector<int> kNC2NH = {0, 2, 3, 1}; | |||
| inline const PrimitivePtr kPrimMakeTupleV2 = std::make_shared<Primitive>("make_tuple"); | |||
| inline const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("Identity"); | |||
| inline const PrimitivePtr kPrimConv2DBackpropInputFusion = | |||
| @@ -178,6 +180,8 @@ CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &inpu | |||
| CNodePtr GenTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &input, size_t index); | |||
| STATUS FetchShapeFromAbstract(const abstract::AbstractBasePtr &abstract, ShapeVector *shape); | |||
| template <const PrimitivePtr *prim = nullptr> | |||
| inline bool IsSpecifiedNode(const BaseRef &n) { | |||
| if (utils::isa<AnfNodePtr>(n)) { | |||
| @@ -0,0 +1,129 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/format/conv_weight_format.h" | |||
| #include <vector> | |||
| #include "tools/common/tensor_util.h" | |||
| #include "tools/converter/parser/parser_utils.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| constexpr size_t kConvWeightIndex = 2; | |||
| } // namespace | |||
| STATUS ConvWeightFormatBase::ConvWeightFormatTrans(const FuncGraphPtr &graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| auto node_list = TopoSort(graph->get_return()); | |||
| for (auto &node : node_list) { | |||
| if (!utils::isa<CNodePtr>(node)) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) { | |||
| auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1)); | |||
| if (sub_func_graph == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| return false; | |||
| } | |||
| if (ConvWeightFormatTrans(sub_func_graph) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "transform conv weight format failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo)); | |||
| if (sub_func_graph == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| return false; | |||
| } | |||
| if (ConvWeightFormatTrans(sub_func_graph) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "transform conv weight format failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| continue; | |||
| } | |||
| if (!CheckPrimitiveType(node, prim::kPrimConv2DFusion) && | |||
| !CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) && | |||
| !CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) { | |||
| continue; | |||
| } | |||
| MS_ASSERT(cnode->inputs().size() > kConvWeightIndex); | |||
| auto weight_node = cnode->input(kConvWeightIndex); | |||
| MS_ASSERT(weight_node != nullptr); | |||
| if (utils::isa<CNodePtr>(weight_node)) { | |||
| if (lite::HandleWeightConst(graph, cnode, weight_node->cast<CNodePtr>(), src_format_, dst_format_) != | |||
| lite::RET_OK) { | |||
| MS_LOG(ERROR) << "handle cnode weight failed."; | |||
| return RET_ERROR; | |||
| } | |||
| continue; | |||
| } | |||
| if (TransferConvWeight(weight_node) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "transfer weight format failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (utils::isa<Parameter>(weight_node)) { | |||
| if (lite::HandleWeightSharing(graph, dst_format_, weight_node->cast<ParameterPtr>(), src_format_, dst_format_) != | |||
| lite::RET_OK) { | |||
| MS_LOG(ERROR) << "handle weight-sharing failed."; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS ConvWeightFormatBase::TransferConvWeight(const AnfNodePtr &weight_node) { | |||
| MS_ASSERT(weight_node != nullptr); | |||
| auto weight_value = GetTensorInfo(weight_node); | |||
| if (weight_value == nullptr) { | |||
| MS_LOG(ERROR) << "weight node must const value"; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto status = TransFilterFormat(weight_value, src_format_, dst_format_); | |||
| if (status != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "trans conv weight failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto type_id = static_cast<TypeId>(weight_value->data_type()); | |||
| auto shape = weight_value->shape(); | |||
| std::vector<int64_t> shape_vector(shape.begin(), shape.end()); | |||
| auto abstract = lite::CreateTensorAbstract(shape_vector, type_id); | |||
| if (abstract == nullptr) { | |||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||
| return lite::RET_ERROR; | |||
| } | |||
| weight_node->set_abstract(abstract); | |||
| return lite::RET_OK; | |||
| } | |||
| bool ConvWeightFormatBase::Run(const FuncGraphPtr &graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| if (src_format_ == dst_format_) { | |||
| return true; | |||
| } | |||
| auto manager = Manage(graph, true); | |||
| if (manager == nullptr) { | |||
| MS_LOG(ERROR) << "manager is nullptr."; | |||
| return false; | |||
| } | |||
| auto status = ConvWeightFormatTrans(graph); | |||
| if (status != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "Conv2D weight FormatTrans failed: " << status; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,54 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_CONV_WEIGHT_FORMAT_H_ | |||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_CONV_WEIGHT_FORMAT_H_ | |||
| #include <string> | |||
| #include "backend/optimizer/common/pass.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class ConvWeightFormatBase : public Pass { | |||
| public: | |||
| explicit ConvWeightFormatBase(const std::string &name = "ConvWeightFormatBase") : Pass(name) {} | |||
| ~ConvWeightFormatBase() override = default; | |||
| bool Run(const FuncGraphPtr &graph) override; | |||
| private: | |||
| STATUS ConvWeightFormatTrans(const FuncGraphPtr &graph); | |||
| STATUS TransferConvWeight(const AnfNodePtr &weight_node); | |||
| protected: | |||
| schema::Format src_format_{schema::Format_KHWC}; | |||
| schema::Format dst_format_{schema::Format_KHWC}; | |||
| }; | |||
| class ConvWeightToKHWC : public ConvWeightFormatBase { | |||
| public: | |||
| ConvWeightToKHWC() : ConvWeightFormatBase("ConvWeightToKHWC") { src_format_ = schema::Format_KCHW; } | |||
| ~ConvWeightToKHWC() override = default; | |||
| }; | |||
| class ConvWeightToKCHW : public ConvWeightFormatBase { | |||
| public: | |||
| ConvWeightToKCHW() : ConvWeightFormatBase("ConvWeightToKCHW") { dst_format_ = schema::Format_KCHW; } | |||
| ~ConvWeightToKCHW() override = default; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_CONV_WEIGHT_FORMAT_H_ | |||
| @@ -0,0 +1,150 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/format/delete_redundant_transpose.h" | |||
| #include <vector> | |||
| #include "tools/optimizer/common/format_utils.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| constexpr size_t kDimNumber = 4; | |||
| } // namespace | |||
| STATUS DeleteRedundantTranspose::DeleteNot4DTranspose(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto manager = func_graph->manager(); | |||
| MS_ASSERT(manager != nullptr); | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| for (auto &node : node_list) { | |||
| if (!utils::isa<CNode>(node)) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) { | |||
| auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1)); | |||
| if (sub_func_graph == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| return lite::RET_NULL_PTR; | |||
| } | |||
| if (DeleteNot4DTranspose(sub_func_graph) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "delete transpose failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo)); | |||
| if (sub_func_graph == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| return lite::RET_NULL_PTR; | |||
| } | |||
| if (DeleteNot4DTranspose(sub_func_graph) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "delete transpose failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| continue; | |||
| } | |||
| if (!CheckPrimitiveType(node, prim::kPrimTranspose)) { | |||
| continue; | |||
| } | |||
| auto abstract = GetCNodeInputAbstract(cnode, 1); | |||
| ShapeVector shape; | |||
| if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "fetch shape failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| std::vector<int> perm; | |||
| if (GetTransposePerm(cnode, &perm) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "fetch transpose perm failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (!shape.empty() && shape.size() != perm.size()) { | |||
| MS_LOG(DEBUG) << "transpose node need to be deleted."; | |||
| manager->Replace(node, cnode->input(1)); | |||
| } | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS DeleteRedundantTranspose::TransTransFusion(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto node_lite = TopoSort(func_graph->get_return()); | |||
| for (auto &node : node_lite) { | |||
| if (!utils::isa<CNode>(node)) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) { | |||
| auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1)); | |||
| if (sub_func_graph == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| return lite::RET_NULL_PTR; | |||
| } | |||
| if (TransTransFusion(sub_func_graph) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "delete transpose failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo)); | |||
| if (sub_func_graph == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| return lite::RET_NULL_PTR; | |||
| } | |||
| if (TransTransFusion(sub_func_graph) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "delete transpose failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| continue; | |||
| } | |||
| if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) || | |||
| !CheckPrimitiveType(cnode->input(1), prim::kPrimTranspose)) { | |||
| continue; | |||
| } | |||
| std::vector<int> post_perm; | |||
| if (GetTransposePerm(cnode, &post_perm) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "transpose rm cannot be obtained, " << cnode->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| std::vector<int> pre_perm; | |||
| auto pre_cnode = cnode->input(1)->cast<CNodePtr>(); | |||
| MS_ASSERT(pre_cnode != nullptr); | |||
| if (GetTransposePerm(pre_cnode, &pre_perm) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "transpose rm cannot be obtained, " << pre_cnode->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| if ((pre_perm == kNH2NC && post_perm == kNC2NH) || (pre_perm == kNC2NH && post_perm == kNH2NC)) { | |||
| func_graph->manager()->Replace(cnode, pre_cnode->input(1)); | |||
| } | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| bool DeleteRedundantTranspose::Run(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto manager = Manage(func_graph, true); | |||
| if (manager == nullptr) { | |||
| MS_LOG(ERROR) << "manager is nullptr."; | |||
| return false; | |||
| } | |||
| if (TransTransFusion(func_graph) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "ranspose and transpose fusion failed."; | |||
| return false; | |||
| } | |||
| if (DeleteNot4DTranspose(func_graph) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "delete not 4D transpose failed."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_DELETE_REDUNDANT_TRANSPOSE_H_ | |||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_DELETE_REDUNDANT_TRANSPOSE_H_ | |||
| #include "backend/optimizer/common/pass.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class DeleteRedundantTranspose : public Pass { | |||
| public: | |||
| DeleteRedundantTranspose() : Pass("delete_redundant_transpose") {} | |||
| ~DeleteRedundantTranspose() = default; | |||
| bool Run(const FuncGraphPtr &func_graph) override; | |||
| private: | |||
| STATUS DeleteNot4DTranspose(const FuncGraphPtr &func_graph); | |||
| STATUS TransTransFusion(const FuncGraphPtr &func_graph); | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_DELETE_REDUNDANT_TRANSPOSE_H_ | |||
| @@ -0,0 +1,315 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/format/to_format_base.h" | |||
| #include "ops/op_utils.h" | |||
| #include "src/common/common.h" | |||
| #include "src/common/utils.h" | |||
| #include "tools/common/tensor_util.h" | |||
| using mindspore::lite::NHWC_SHAPE; | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| constexpr size_t kDimNumber = 4; | |||
| } // namespace | |||
| STATUS ToFormatBase::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, | |||
| bool before, size_t index) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| AnfNodePtr trans_input = before ? cnode->input(index) : cnode; | |||
| std::string trans_name = before ? cnode->fullname_with_scope() + "_pre_" + std::to_string(index - 1) | |||
| : cnode->fullname_with_scope() + "_post"; | |||
| auto trans_cnode = opt::GenTransposeNode(func_graph, trans_input, perm, trans_name); | |||
| if (DecideWhetherInferShapeForNewNode()) { | |||
| auto status = node_infer_shape_->InferShape(trans_cnode); | |||
| if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) { | |||
| MS_LOG(ERROR) << "infer generated trans node failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| } else { | |||
| auto abstract = trans_input->abstract(); | |||
| if (abstract != nullptr) { | |||
| trans_cnode->set_abstract(abstract->Clone()); | |||
| } | |||
| } | |||
| auto trans_prim = GetValueNode<PrimitivePtr>(trans_cnode->cast<CNodePtr>()->input(0)); | |||
| if (perm == kNC2NH) { | |||
| trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW)); | |||
| } else if (perm == kNH2NC) { | |||
| trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC)); | |||
| } | |||
| auto manager = func_graph->manager(); | |||
| if (manager == nullptr) { | |||
| manager = Manage(func_graph, true); | |||
| } | |||
| MS_ASSERT(manager != nullptr); | |||
| auto tr = manager->Transact(); | |||
| if (before) { | |||
| tr.SetEdge(cnode, index, trans_cnode); | |||
| tr.Commit(); | |||
| } else { | |||
| func_graph->manager()->Replace(cnode, trans_cnode); | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS ToFormatBase::ModifyCNodeAbstract(const CNodePtr &cnode) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| auto abstract_base = cnode->abstract(); | |||
| std::vector<AbstractBasePtr> abstracts; | |||
| if (utils::isa<abstract::AbstractTuple>(abstract_base)) { | |||
| auto abstract_tuple = utils::cast<abstract::AbstractTuplePtr>(abstract_base); | |||
| abstracts = abstract_tuple->elements(); | |||
| } else { | |||
| abstracts.push_back(abstract_base); | |||
| } | |||
| for (auto &abstract : abstracts) { | |||
| ShapeVector shape; | |||
| if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "fetch shape failed, " << cnode->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (shape.size() != kDimNumber) { | |||
| MS_LOG(DEBUG) << "shape don't need to modify."; | |||
| continue; | |||
| } | |||
| if (format_ == mindspore::NCHW) { | |||
| ShapeVector transfer_shape = {shape[0], shape[kInputIndexThree], shape[1], shape[kInputIndexTwo]}; | |||
| abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape)); | |||
| } else { | |||
| ShapeVector transfer_shape = {shape[0], shape[kInputIndexTwo], shape[kInputIndexThree], shape[1]}; | |||
| abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape)); | |||
| } | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS ToFormatBase::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::vector<int> &perm) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| auto prim_node = cnode->input(0); | |||
| auto prim = GetValueNode<PrimitivePtr>(prim_node); | |||
| MS_ASSERT(prim != nullptr); | |||
| if (sensitive_ops_.find(prim->name()) == sensitive_ops_.end()) { | |||
| MS_LOG(ERROR) << "op don't meet condition."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto insert_index = sensitive_ops_.at(prim->name()); | |||
| if (insert_index.empty()) { | |||
| if (opt::CheckPrimitiveType(cnode, prim::kPrimResizeGrad) && prim->GetAttr(ops::kMethod) != nullptr && | |||
| GetValue<int64_t>(prim->GetAttr(ops::kMethod)) == static_cast<int64_t>(mindspore::ResizeMethod::NEAREST)) { | |||
| insert_index.push_back(1); | |||
| } else { | |||
| for (size_t i = 1; i < cnode->size(); ++i) { | |||
| insert_index.push_back(i); | |||
| } | |||
| } | |||
| } | |||
| for (auto &index : insert_index) { | |||
| if (GenNewInput(func_graph, cnode, perm, true, index) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "generate a new input failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS ToFormatBase::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::vector<int> &perm) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| if (!cnode->abstract()->isa<abstract::AbstractTuple>()) { | |||
| if (GenNewInput(func_graph, cnode, perm, false) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "generate a new input failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| } else { | |||
| auto node_users = func_graph->manager()->node_users()[cnode]; | |||
| for (auto &node_user : node_users) { | |||
| auto post_node = node_user.first; | |||
| CNodePtr tuple_get_item = nullptr; | |||
| if (!opt::CheckPrimitiveType(post_node, prim::kPrimTupleGetItem)) { | |||
| if (!train_flag_) { | |||
| MS_LOG(ERROR) << "post node is invalid."; | |||
| return lite::RET_ERROR; | |||
| } else { | |||
| tuple_get_item = opt::GenTupleGetItemNode(func_graph, cnode, 0); | |||
| post_node = tuple_get_item; | |||
| func_graph->manager()->Replace(cnode, tuple_get_item); | |||
| } | |||
| } | |||
| if (func_graph->manager()->node_users()[post_node].empty()) { | |||
| continue; | |||
| } | |||
| auto post_cnode = post_node->cast<CNodePtr>(); | |||
| if (GenNewInput(func_graph, post_cnode, perm, false) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "generate a new input failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (tuple_get_item != nullptr) { | |||
| func_graph->manager()->Replace(tuple_get_item, tuple_get_item->input(1)); | |||
| } | |||
| } | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS ToFormatBase::HandleGraphInput(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto graph_input = func_graph->get_inputs(); | |||
| for (auto &input : graph_input) { | |||
| auto input_param = input->cast<ParameterPtr>(); | |||
| MS_ASSERT(input_param != nullptr); | |||
| auto abstract = input_param->abstract(); | |||
| MS_ASSERT(abstract != nullptr); | |||
| ShapeVector shape; | |||
| if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "fetch shape failed." << input->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (shape.size() != kDimNumber || !DecideWhetherHandleGraphInput(func_graph, shape)) { | |||
| continue; | |||
| } | |||
| ShapeVector transfer_shape; | |||
| if (format_ == mindspore::NCHW) { | |||
| transfer_shape = {shape[0], shape[kInputIndexThree], shape[1], shape[kInputIndexTwo]}; | |||
| } else { | |||
| transfer_shape = {shape[0], shape[kInputIndexTwo], shape[kInputIndexThree], shape[1]}; | |||
| } | |||
| CNodePtr trans_cnode; | |||
| if (format_ == mindspore::NCHW) { | |||
| trans_cnode = opt::GenTransposeNode(func_graph, input, kNC2NH, input->fullname_with_scope() + "_nc2nh"); | |||
| } else { | |||
| trans_cnode = opt::GenTransposeNode(func_graph, input, kNH2NC, input->fullname_with_scope() + "_nh2nc"); | |||
| } | |||
| if (trans_cnode == nullptr) { | |||
| MS_LOG(ERROR) << "create transpose cnode failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto trans_prim = GetValueNode<PrimitivePtr>(trans_cnode->input(0)); | |||
| MS_ASSERT(trans_prim != nullptr); | |||
| if (format_ == mindspore::NCHW) { | |||
| trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW)); | |||
| } else { | |||
| trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC)); | |||
| } | |||
| trans_cnode->set_abstract(abstract->Clone()); | |||
| abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape)); | |||
| func_graph->manager()->Replace(input, trans_cnode); | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS ToFormatBase::HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| opt::TransTypePair trans_info; | |||
| GetTransNodeFormatType(cnode, &trans_info); | |||
| if (trans_info.pre_ == opt::kNONE || trans_info.post_ == opt::kNONE) { | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| auto before_perm = trans_info.pre_ == opt::kNHWC2NCHW ? kNH2NC : kNC2NH; | |||
| auto after_perm = trans_info.post_ == opt::kNCHW2NHWC ? kNC2NH : kNH2NC; | |||
| if (InsertPreTransNode(func_graph, cnode, before_perm) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "insert pre node failed." << cnode->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (opt::CheckPrimitiveType(cnode, prim::kPrimAdam) || opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) { | |||
| return lite::RET_OK; | |||
| } | |||
| if (ModifyCNodeAbstract(cnode) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "adjust cnode's output shape failed, " << cnode->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| bool ToFormatBase::BasicProcess(const FuncGraphPtr &func_graph, bool main_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| int status; | |||
| for (auto &node : node_list) { | |||
| if (!utils::isa<CNodePtr>(node)) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (IsSpecialType(cnode)) { | |||
| continue; | |||
| } | |||
| if (opt::CheckPrimitiveType(node, prim::kPrimIf) || opt::CheckPrimitiveType(node, prim::kPrimWhile)) { | |||
| auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1)); | |||
| if (sub_func_graph == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| return false; | |||
| } | |||
| (void)BasicProcess(sub_func_graph, false); | |||
| sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo)); | |||
| if (sub_func_graph == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| return false; | |||
| } | |||
| (void)BasicProcess(sub_func_graph, false); | |||
| continue; | |||
| } | |||
| status = HandleGraphNode(func_graph, cnode); | |||
| if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { | |||
| return false; | |||
| } | |||
| } | |||
| if (main_graph) { | |||
| status = HandleGraphInput(func_graph); | |||
| if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool ToFormatBase::Run(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| if (format_ != mindspore::NHWC && format_ != mindspore::NCHW) { | |||
| MS_LOG(ERROR) << "format transferring only support nc2nh or nh2nc."; | |||
| return false; | |||
| } | |||
| auto manager = Manage(func_graph, true); | |||
| if (manager == nullptr) { | |||
| MS_LOG(ERROR) << "manager is nullptr."; | |||
| return false; | |||
| } | |||
| node_infer_shape_ = std::make_shared<NodeInferShape>(fmk_type_, train_flag_); | |||
| if (node_infer_shape_ == nullptr) { | |||
| MS_LOG(ERROR) << "create NodeInferShape object failed."; | |||
| return false; | |||
| } | |||
| SetSensitiveOps(); | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| for (auto &node : node_list) { | |||
| auto prim = GetValueNode<PrimitivePtr>(node); | |||
| if (prim == nullptr) { | |||
| continue; | |||
| } | |||
| } | |||
| if (!BasicProcess(func_graph, true)) { | |||
| MS_LOG(ERROR) << "transfer format failed."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -14,49 +14,51 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_INSERT_TRANSPOSE_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_INSERT_TRANSPOSE_H_ | |||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_FORMAT_BASE_H_ | |||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_FORMAT_BASE_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "utils/utils.h" | |||
| #include <vector> | |||
| #include "backend/optimizer/common/pass.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "tools/optimizer/common/format_utils.h" | |||
| #include "tools/anf_exporter/fetch_content.h" | |||
| #include "tools/optimizer/graph/infershape_pass.h" | |||
| using mindspore::lite::converter::FmkType; | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class InsertTranspose { | |||
| namespace opt { | |||
| class ToFormatBase : public Pass { | |||
| public: | |||
| InsertTranspose(FmkType fmk_type, bool train_flag) : fmk_type_(fmk_type), train_flag_(train_flag) {} | |||
| ~InsertTranspose() = default; | |||
| bool Run(const FuncGraphPtr &func_graph); | |||
| explicit ToFormatBase(FmkType fmk_type = lite::converter::FmkType_MS, bool train_flag = false, | |||
| std::string pass_name = "to_format_base") | |||
| : Pass(pass_name), fmk_type_(fmk_type), train_flag_(train_flag) {} | |||
| ~ToFormatBase() override = default; | |||
| bool Run(const FuncGraphPtr &func_graph) override; | |||
| private: | |||
| bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph); | |||
| STATUS HandleGraphInput(const FuncGraphPtr &func_graph); | |||
| STATUS HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm); | |||
| STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm); | |||
| STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, bool before, | |||
| size_t index = 0); | |||
| STATUS ModifyCNodeAbstract(const CNodePtr &cnode); | |||
| private: | |||
| AnfNodePtr GenNewInputWithoutShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::vector<int> &perm, bool before, size_t index); | |||
| bool ResetFuncGraph(const FuncGraphPtr &func_graph); | |||
| bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph); | |||
| void GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info); | |||
| STATUS HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| STATUS HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| void SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); | |||
| void ResetSubGraphInput(); | |||
| void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); | |||
| void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); | |||
| protected: | |||
| virtual void GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) = 0; | |||
| virtual void SetSensitiveOps() { sensitive_ops_ = opt::GetNHWCOpMap(); } | |||
| virtual bool DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ShapeVector &shape) { return true; } | |||
| virtual bool DecideWhetherInferShapeForNewNode() { return true; } | |||
| FmkType fmk_type_{lite::converter::FmkType_MS}; | |||
| bool train_flag_{false}; | |||
| std::unordered_map<FuncGraphPtr, std::vector<AnfNodePtr>> sub_inputs_map_; | |||
| mindspore::Format format_{mindspore::NHWC}; | |||
| std::shared_ptr<NodeInferShape> node_infer_shape_{nullptr}; | |||
| std::unordered_map<std::string, std::vector<size_t>> sensitive_ops_; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_INSERT_TRANSPOSE_H_ | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_FORMAT_BASE_H_ | |||
| @@ -0,0 +1,33 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/format/to_nchw_format.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| void ToNCHWFormat::GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| auto prim_node = cnode->input(0); | |||
| auto prim = GetValueNode<PrimitivePtr>(prim_node); | |||
| MS_ASSERT(prim != nullptr); | |||
| if (sensitive_ops_.find(prim->name()) != sensitive_ops_.end()) { | |||
| trans_info->pre_ = opt::kNHWC2NCHW; | |||
| trans_info->post_ = opt::kNCHW2NHWC; | |||
| } | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_NCHW_FORMAT_H_ | |||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_NCHW_FORMAT_H_ | |||
| #include "tools/optimizer/format/to_format_base.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class ToNCHWFormat : public ToFormatBase { | |||
| public: | |||
| explicit ToNCHWFormat(FmkType fmk_type = lite::converter::FmkType_MS, bool train_flag = false) | |||
| : ToFormatBase(fmk_type, train_flag, "to_nchw_format") { | |||
| format_ = mindspore::NCHW; | |||
| } | |||
| ~ToNCHWFormat() = default; | |||
| private: | |||
| void GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_NCHW_FORMAT_H_ | |||
| @@ -0,0 +1,32 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/format/to_nhwc_format.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| void ToNHWCFormat::GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| auto prim_node = cnode->input(0); | |||
| auto prim = GetValueNode<PrimitivePtr>(prim_node); | |||
| MS_ASSERT(prim != nullptr); | |||
| if (sensitive_ops_.find(prim->name()) != sensitive_ops_.end()) { | |||
| trans_info->pre_ = opt::kNCHW2NHWC; | |||
| trans_info->post_ = opt::kNHWC2NCHW; | |||
| } | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,35 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_NHWC_FORMAT_H_ | |||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_NHWC_FORMAT_H_ | |||
| #include "tools/optimizer/format/to_format_base.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class ToNHWCFormat : public ToFormatBase { | |||
| public: | |||
| explicit ToNHWCFormat(FmkType fmk_type = lite::converter::FmkType_MS, bool train_flag = false) | |||
| : ToFormatBase(fmk_type, train_flag, "to_nhwc_format") {} | |||
| ~ToNHWCFormat() = default; | |||
| private: | |||
| void GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_NHWC_FORMAT_H_ | |||
| @@ -20,13 +20,9 @@ | |||
| #include <vector> | |||
| #include "tools/converter/quant_param_holder.h" | |||
| #include "mindspore/core/ops/transpose.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "tools/optimizer/common/format_utils.h" | |||
| namespace mindspore::opt { | |||
| namespace { | |||
| const std::vector<int> NH2NC = {0, 3, 1, 2}; | |||
| const std::vector<int> NC2NH = {0, 2, 3, 1}; | |||
| } // namespace | |||
| bool IsBNCNode(const BaseRef &n) { | |||
| if (utils::isa<AnfNodePtr>(n)) { | |||
| auto anf_node = utils::cast<AnfNodePtr>(n); | |||
| @@ -142,7 +138,7 @@ AnfNodePtr TransposeFusion::TransTransFusion(const mindspore::FuncGraphPtr &func | |||
| MS_LOG(ERROR) << "get tanspose perm failed."; | |||
| return nullptr; | |||
| } | |||
| if ((pre_perm == NH2NC && post_perm == NC2NH) || (pre_perm == NC2NH && post_perm == NH2NC)) { | |||
| if ((pre_perm == kNH2NC && post_perm == kNC2NH) || (pre_perm == kNC2NH && post_perm == kNH2NC)) { | |||
| return pre_cnode->input(1); | |||
| } | |||
| return nullptr; | |||
| @@ -166,8 +162,10 @@ AnfNodePtr TransposeFusion::Process(const std::string &pattern_name, const minds | |||
| return nullptr; | |||
| } | |||
| const CNodePtr &transpose_cnode = transpose_node->cast<CNodePtr>(); | |||
| auto perm_node = transpose_cnode->input(2); | |||
| auto perm_node = transpose_cnode->input(kInputIndexTwo); | |||
| auto trans_post_node = GenTransposeNode(func_graph, any_cnode, perm_node, any_cnode->fullname_with_scope() + "_post"); | |||
| trans_post_node->set_abstract(any_cnode->abstract()->Clone()); | |||
| any_cnode->set_abstract(transpose_cnode->input(1)->abstract()->Clone()); | |||
| auto tr = func_graph->manager()->Transact(); | |||
| tr.SetEdge(any_cnode, 1, transpose_cnode->input(1)); | |||
| tr.Commit(); | |||
| @@ -19,9 +19,7 @@ | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "tools/optimizer/graph/unify_format_pass.h" | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "schema/inner/model_generated.h" | |||
| #include "tools/optimizer/common/multiple_pattern_process_pass.h" | |||
| namespace mindspore { | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/graph/unify_format_pass.h" | |||
| #include "tools/optimizer/graph/decrease_transpose_algo.h" | |||
| #include <queue> | |||
| #include <set> | |||
| #include <unordered_map> | |||
| @@ -24,14 +24,9 @@ | |||
| #include "src/common/utils.h" | |||
| #include "tools/common/tensor_util.h" | |||
| using mindspore::lite::NCHW_SHAPE; | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| constexpr int kInputChannel = 3; | |||
| const std::vector<int> NH2NC = {0, 3, 1, 2}; | |||
| const std::vector<int> NC2NH = {0, 2, 3, 1}; | |||
| STATUS FindAreaSurroundedByTranspose(const FuncGraphPtr &func_graph, const CNodePtr &root_node, | |||
| std::set<CNodePtr> *in_nodes, std::set<CNodePtr> *out_nodes, | |||
| std::set<CNodePtr> *middle_nodes) { | |||
| @@ -101,23 +96,40 @@ STATUS FindAreaSurroundedByTranspose(const FuncGraphPtr &func_graph, const CNode | |||
| return lite::RET_OK; | |||
| } | |||
| bool JudgeCanOptimizerForMultiOp(const FuncGraphPtr &func_graph, const std::set<CNodePtr> &in_nodes, | |||
| const std::set<CNodePtr> &out_nodes, const std::set<CNodePtr> &middle_nodes) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| for (auto &in_cnode : in_nodes) { | |||
| void SetTransType(const std::set<CNodePtr> &cnodes, FormatTransNodeType *trans_type) { | |||
| MS_ASSERT(trans_type != nullptr); | |||
| FormatTransNodeType local_trans_type; | |||
| for (auto &cnode : cnodes) { | |||
| std::vector<int> perm; | |||
| if (!CheckPrimitiveType(in_cnode, prim::kPrimTranspose) || GetTransposePerm(in_cnode, &perm) != lite::RET_OK || | |||
| perm != NH2NC) { | |||
| return false; | |||
| if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) || GetTransposePerm(cnode, &perm) != lite::RET_OK || | |||
| (perm != kNH2NC && perm != kNC2NH)) { | |||
| *trans_type = kNONE; | |||
| return; | |||
| } | |||
| } | |||
| for (auto &out_cnode : out_nodes) { | |||
| std::vector<int> perm; | |||
| if (!CheckPrimitiveType(out_cnode, prim::kPrimTranspose) || GetTransposePerm(out_cnode, &perm) != lite::RET_OK || | |||
| perm != NC2NH) { | |||
| return false; | |||
| local_trans_type = perm == kNH2NC ? kNHWC2NCHW : kNCHW2NHWC; | |||
| *trans_type = *trans_type == kNONE ? local_trans_type : *trans_type; | |||
| if (*trans_type != local_trans_type) { | |||
| *trans_type = kNONE; | |||
| return; | |||
| } | |||
| } | |||
| } | |||
| bool JudgeCanOptimizerForMultiOp(const FuncGraphPtr &func_graph, const std::set<CNodePtr> &in_nodes, | |||
| const std::set<CNodePtr> &out_nodes, const std::set<CNodePtr> &middle_nodes, | |||
| TransTypePair *trans_info) { | |||
| MS_ASSERT(func_graph != nullptr && trans_info != nullptr); | |||
| SetTransType(in_nodes, &trans_info->pre_); | |||
| if (trans_info->pre_ == kNONE) { | |||
| return false; | |||
| } | |||
| SetTransType(out_nodes, &trans_info->post_); | |||
| if (trans_info->post_ == kNONE) { | |||
| return false; | |||
| } | |||
| if (trans_info->pre_ == trans_info->post_) { | |||
| return false; | |||
| } | |||
| auto &dynamic_ops = GetDynamicFormatOpList(); | |||
| TransposeStrategy transpose_strategy; | |||
| for (auto &middle_cnode : middle_nodes) { | |||
| @@ -133,8 +145,8 @@ bool JudgeCanOptimizerForMultiOp(const FuncGraphPtr &func_graph, const std::set< | |||
| return true; | |||
| } | |||
| void ConvertNcTensor2Nh(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index, FmkType fmk_type, | |||
| bool train_flag) { | |||
| void ConvertTensorToNCOrNH(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index, FmkType fmk_type, | |||
| bool train_flag, FormatTransNodeType trans_type) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| if (utils::isa<CNodePtr>(cnode->input(index))) { | |||
| return; | |||
| @@ -157,42 +169,24 @@ void ConvertNcTensor2Nh(const FuncGraphPtr &func_graph, const CNodePtr &cnode, s | |||
| (data_info.data_type_ != kNumberTypeFloat32 && data_info.data_type_ != kNumberTypeFloat)) { | |||
| return; | |||
| } | |||
| std::vector<int> new_shape = data_info.shape_; | |||
| ShapeVector expand_shape(data_info.shape_.begin(), data_info.shape_.end()); | |||
| if (data_info.shape_.size() == 1) { | |||
| new_shape = {1, 1, 1, data_info.shape_[0]}; | |||
| expand_shape = {1, 1, 1, data_info.shape_[0]}; | |||
| } else if (data_info.shape_.size() == kInputSizeTwo) { | |||
| new_shape = {1, 1, data_info.shape_[0], data_info.shape_[1]}; | |||
| expand_shape = {1, 1, data_info.shape_[0], data_info.shape_[1]}; | |||
| } else if (data_info.shape_.size() == kInputSizeThree) { | |||
| new_shape = {1, data_info.shape_[0], data_info.shape_[1], data_info.shape_[kInputIndexTwo]}; | |||
| expand_shape = {1, data_info.shape_[0], data_info.shape_[1], data_info.shape_[kInputIndexTwo]}; | |||
| } | |||
| auto size = data_info.data_.size() / sizeof(float); | |||
| std::vector<float> new_data(size); | |||
| auto new_data_ptr = static_cast<float *>(new_data.data()); | |||
| auto nchw_data = reinterpret_cast<float *>(data_info.data_.data()); | |||
| // nchw to nhwc | |||
| auto batch = new_shape[lite::NCHW_N]; | |||
| auto channel = new_shape[lite::NCHW_C]; | |||
| auto area = new_shape[lite::NCHW_H] * new_shape[lite::NCHW_W]; | |||
| for (auto i = 0; i < batch; i++) { | |||
| float *src_batch = nchw_data + i * channel * area; | |||
| float *dst_batch = new_data_ptr + i * channel * area; | |||
| for (int j = 0; j < area; ++j) { | |||
| float *src_area = src_batch + i; | |||
| float *dst_area = dst_batch + i * channel; | |||
| for (int k = 0; k < channel; ++k) { | |||
| dst_area[k] = src_area[k * area]; | |||
| } | |||
| } | |||
| auto tensor = std::make_shared<tensor::Tensor>(static_cast<TypeId>(data_info.data_type_), expand_shape, | |||
| data_info.data_.data(), data_info.data_.size()); | |||
| if (trans_type == kNHWC2NCHW) { | |||
| (void)TransFilterFormat(tensor, schema::Format_KHWC, schema::Format_KCHW); | |||
| } else { | |||
| (void)TransFilterFormat(tensor, schema::Format_KCHW, schema::Format_KHWC); | |||
| } | |||
| auto param_node = func_graph->add_parameter(); | |||
| param_node->set_name(cnode->input(index)->fullname_with_scope()); | |||
| std::vector<int64_t> shape_vec{new_shape[0], new_shape[kInputIndexTwo], new_shape[kInputIndexThree], new_shape[1]}; | |||
| auto tensor_info = lite::CreateTensorInfo(new_data.data(), size * sizeof(float), shape_vec, kNumberTypeFloat32); | |||
| if (tensor_info == nullptr) { | |||
| MS_LOG(ERROR) << "Create tensor info failed"; | |||
| return; | |||
| } | |||
| status = lite::InitParameterFromTensorInfo(param_node, tensor_info); | |||
| status = lite::InitParameterFromTensorInfo(param_node, tensor); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "init parameter from tensor info failed"; | |||
| return; | |||
| @@ -200,72 +194,10 @@ void ConvertNcTensor2Nh(const FuncGraphPtr &func_graph, const CNodePtr &cnode, s | |||
| auto tr = func_graph->manager()->Transact(); | |||
| tr.SetEdge(cnode, index, param_node); | |||
| tr.Commit(); | |||
| return; | |||
| } | |||
| } // namespace | |||
| void UnifyFormatPass::GetTransNodeFormatType(const CNodePtr &cnode, TransTypePair *trans_info) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| auto prim_node = cnode->input(0); | |||
| auto prim = GetValueNode<PrimitivePtr>(prim_node); | |||
| MS_ASSERT(prim != nullptr); | |||
| auto &specify_nhwc_op_map = GetNHWCOpMap(); | |||
| auto &specify_nchw_op_map = GetNCHWOpMap(); | |||
| if (fmk_type_ == lite::converter::FmkType_TFLITE) { | |||
| if (specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) { | |||
| return; | |||
| } | |||
| trans_info->pre_ = kNHWC2NCHW; | |||
| trans_info->post_ = kNCHW2NHWC; | |||
| } else if (fmk_type_ == lite::converter::FmkType_TF) { | |||
| if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end() && GetFormat(cnode) == NCHW) { | |||
| trans_info->pre_ = kNCHW2NHWC; | |||
| trans_info->post_ = kNHWC2NCHW; | |||
| } | |||
| if (specify_nchw_op_map.find(prim->name()) != specify_nchw_op_map.end()) { | |||
| trans_info->pre_ = kNHWC2NCHW; | |||
| trans_info->post_ = kNCHW2NHWC; | |||
| } | |||
| } else { | |||
| if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end()) { | |||
| if (fmk_type_ == lite::converter::FmkType_ONNX && prim->GetAttr(ops::kFormat) != nullptr && | |||
| GetValue<int64_t>(prim->GetAttr(ops::kFormat)) == NHWC) { | |||
| return; | |||
| } | |||
| trans_info->pre_ = kNCHW2NHWC; | |||
| trans_info->post_ = kNHWC2NCHW; | |||
| } | |||
| } | |||
| } | |||
| bool UnifyFormatPass::TransTransFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) || !CheckPrimitiveType(cnode->input(1), prim::kPrimTranspose)) { | |||
| return false; | |||
| } | |||
| std::vector<int> post_perm; | |||
| if (GetTransposePerm(cnode, &post_perm) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "get tanspose perm failed."; | |||
| return false; | |||
| } | |||
| std::vector<int> pre_perm; | |||
| auto pre_node = cnode->input(1); | |||
| auto pre_cnode = pre_node->cast<CNodePtr>(); | |||
| if (pre_cnode == nullptr) { | |||
| return false; | |||
| } | |||
| if (GetTransposePerm(pre_cnode, &pre_perm) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "get tanspose perm failed."; | |||
| return false; | |||
| } | |||
| if ((pre_perm == NH2NC && post_perm == NC2NH) || (pre_perm == NC2NH && post_perm == NH2NC)) { | |||
| func_graph->manager()->Replace(cnode, pre_cnode->input(1)); | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| STATUS UnifyFormatPass::PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| STATUS DecreaseTransposeAlgo::PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| if (!CheckPrimitiveType(cnode, prim::kPrimTranspose)) { | |||
| return lite::RET_OK; | |||
| @@ -285,7 +217,7 @@ STATUS UnifyFormatPass::PostTransposeFusion(const FuncGraphPtr &func_graph, cons | |||
| MS_LOG(ERROR) << "get post transpose node perm failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if ((cur_perm == NH2NC && post_trans_perm == NC2NH) || (cur_perm == NC2NH && post_trans_perm == NH2NC)) { | |||
| if ((cur_perm == kNH2NC && post_trans_perm == kNC2NH) || (cur_perm == kNC2NH && post_trans_perm == kNH2NC)) { | |||
| func_graph->manager()->Replace(post_node, cnode->input(1)); | |||
| } | |||
| } | |||
| @@ -293,15 +225,11 @@ STATUS UnifyFormatPass::PostTransposeFusion(const FuncGraphPtr &func_graph, cons | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS UnifyFormatPass::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, | |||
| bool before, size_t index) { | |||
| STATUS DecreaseTransposeAlgo::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, | |||
| bool before, size_t index) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| AnfNodePtr new_input = nullptr; | |||
| if (need_reset_) { | |||
| new_input = transpose_strategy_.TransposeDependOnShape(func_graph, cnode, perm, before, index); | |||
| } else { | |||
| new_input = transpose_strategy_.TransposePairFuseWhenInsert(func_graph, cnode, perm, before, index); | |||
| } | |||
| new_input = transpose_strategy_.TransposePairFuseWhenInsert(func_graph, cnode, perm, before, index); | |||
| if (new_input == nullptr) { | |||
| MS_LOG(ERROR) << "generate a transpose node failed."; | |||
| return lite::RET_ERROR; | |||
| @@ -312,13 +240,6 @@ STATUS UnifyFormatPass::GenNewInput(const FuncGraphPtr &func_graph, const CNodeP | |||
| auto new_cnode_input = new_input->cast<CNodePtr>(); | |||
| int status = lite::RET_OK; | |||
| if (CheckPrimitiveType(new_cnode_input, prim::kPrimTranspose)) { | |||
| if (need_reset_) { | |||
| if (before) { | |||
| pre_insert_trans_.insert(new_cnode_input); | |||
| } else { | |||
| post_insert_trans_.insert(new_cnode_input); | |||
| } | |||
| } | |||
| status = node_infer_shape_.InferShape(new_cnode_input); | |||
| } | |||
| if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) { | |||
| @@ -337,7 +258,7 @@ STATUS UnifyFormatPass::GenNewInput(const FuncGraphPtr &func_graph, const CNodeP | |||
| tr.Commit(); | |||
| } else { | |||
| func_graph->manager()->Replace(cnode, new_input); | |||
| if (!need_reset_ && PostTransposeFusion(func_graph, new_input->cast<CNodePtr>()) != lite::RET_OK) { | |||
| if (PostTransposeFusion(func_graph, new_input->cast<CNodePtr>()) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "post transpose fusion failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| @@ -345,8 +266,8 @@ STATUS UnifyFormatPass::GenNewInput(const FuncGraphPtr &func_graph, const CNodeP | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::vector<int> &perm) { | |||
| STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::vector<int> &perm) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| auto prim_node = cnode->input(0); | |||
| auto prim = GetValueNode<PrimitivePtr>(prim_node); | |||
| @@ -380,8 +301,8 @@ STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| TransTypePair *trans_insert_info) { | |||
| STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| TransTypePair *trans_insert_info) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| MS_ASSERT(trans_insert_info != nullptr); | |||
| TransTypePair trans_info; | |||
| @@ -393,14 +314,14 @@ STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| cnode->set_inputs(origin_inputs); | |||
| auto status = transpose_strategy_.ChangeOpAxis(func_graph, cnode); | |||
| auto status = transpose_strategy_.ChangeOpAxis(func_graph, cnode, trans_insert_info->pre_); | |||
| if (status == lite::RET_NOT_SUPPORT) { | |||
| return lite::RET_NO_CHANGE; | |||
| } else if (status != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "change op attr failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto before_perm = trans_insert_info->pre_ == kNHWC2NCHW ? NH2NC : NC2NH; | |||
| auto before_perm = trans_insert_info->pre_ == kNHWC2NCHW ? kNH2NC : kNC2NH; | |||
| for (size_t i = 1; i < cnode->size(); ++i) { | |||
| if (IsMonadNode(cnode->input(i))) { | |||
| continue; | |||
| @@ -431,8 +352,8 @@ STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS UnifyFormatPass::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::vector<int> &perm) { | |||
| STATUS DecreaseTransposeAlgo::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::vector<int> &perm) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| if (!cnode->abstract()->isa<abstract::AbstractTuple>()) { | |||
| if (GenNewInput(func_graph, cnode, perm, false) != lite::RET_OK) { | |||
| @@ -470,63 +391,8 @@ STATUS UnifyFormatPass::InsertPostTransNode(const FuncGraphPtr &func_graph, cons | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS UnifyFormatPass::HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| if (fmk_type_ == lite::converter::FmkType_TF || fmk_type_ == lite::converter::FmkType_TFLITE) { | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| for (size_t i = 1; i < cnode->size(); ++i) { | |||
| auto node = cnode->input(i); | |||
| if (!utils::isa<ParameterPtr>(node)) { | |||
| continue; | |||
| } | |||
| auto param_node = node->cast<ParameterPtr>(); | |||
| if (param_node->has_default()) { | |||
| continue; | |||
| } | |||
| auto abstract_base = param_node->abstract(); | |||
| if (abstract_base == nullptr) { | |||
| MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) { | |||
| MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base); | |||
| if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) { | |||
| MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape(); | |||
| if (shape_vector.size() != kInputSizeFour) { | |||
| continue; | |||
| } | |||
| if (func_graph->get_inputs().size() == 1 && fmk_type_ == lite::converter::FmkType_ONNX && | |||
| shape_vector[kInputIndexThree] == kInputChannel && shape_vector[1] == -1) { | |||
| continue; | |||
| } | |||
| std::vector<int64_t> new_dims = {shape_vector[NCHW_SHAPE::NCHW_N], shape_vector[NCHW_SHAPE::NCHW_H], | |||
| shape_vector[NCHW_SHAPE::NCHW_W], shape_vector[NCHW_SHAPE::NCHW_C]}; | |||
| abstract_tensor->set_shape(std::make_shared<abstract::Shape>(new_dims)); | |||
| auto trans_cnode = GenTransposeNode(func_graph, param_node, NH2NC, param_node->fullname_with_scope() + "_pre"); | |||
| if (trans_cnode == nullptr) { | |||
| MS_LOG(ERROR) << "generate a transpose node failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto status = node_infer_shape_.InferShape(trans_cnode); | |||
| if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) { | |||
| MS_LOG(ERROR) << "infer shape failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| func_graph->manager()->Replace(param_node, trans_cnode); | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS UnifyFormatPass::HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| std::set<CNodePtr> *visit_transposes) { | |||
| STATUS DecreaseTransposeAlgo::HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| std::set<CNodePtr> *visit_transposes) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr && visit_transposes != nullptr); | |||
| auto manager = func_graph->manager(); | |||
| MS_ASSERT(manager != nullptr); | |||
| @@ -543,7 +409,8 @@ STATUS UnifyFormatPass::HandleGraphMultiNode(const FuncGraphPtr &func_graph, con | |||
| visit_transposes->insert(in_cnode); | |||
| } | |||
| } | |||
| if (!JudgeCanOptimizerForMultiOp(func_graph, in_nodes, out_nodes, middle_nodes)) { | |||
| TransTypePair trans_info; | |||
| if (!JudgeCanOptimizerForMultiOp(func_graph, in_nodes, out_nodes, middle_nodes, &trans_info)) { | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| @@ -568,9 +435,9 @@ STATUS UnifyFormatPass::HandleGraphMultiNode(const FuncGraphPtr &func_graph, con | |||
| continue; | |||
| } | |||
| for (size_t i = 1; i < middle_cnode->size(); ++i) { | |||
| ConvertNcTensor2Nh(func_graph, middle_cnode, i, fmk_type_, train_flag_); | |||
| ConvertTensorToNCOrNH(func_graph, middle_cnode, i, fmk_type_, train_flag_, trans_info.post_); | |||
| } | |||
| status = transpose_strategy_.ChangeOpAxis(func_graph, middle_cnode); | |||
| status = transpose_strategy_.ChangeOpAxis(func_graph, middle_cnode, trans_info.post_); | |||
| if (status != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "change op attr failed."; | |||
| return lite::RET_ERROR; | |||
| @@ -584,7 +451,7 @@ STATUS UnifyFormatPass::HandleGraphMultiNode(const FuncGraphPtr &func_graph, con | |||
| return lite::RET_OK; | |||
| } | |||
| void UnifyFormatPass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { | |||
| void DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { | |||
| MS_ASSERT(cnode != nullptr && sub_graph != nullptr); | |||
| auto sub_inputs = sub_graph->get_inputs(); | |||
| sub_inputs_map_[sub_graph] = sub_inputs; | |||
| @@ -628,7 +495,7 @@ void UnifyFormatPass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr | |||
| } | |||
| } | |||
| void UnifyFormatPass::ResetSubGraphInput() { | |||
| void DecreaseTransposeAlgo::ResetSubGraphInput() { | |||
| for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) { | |||
| auto &sub_graph = iter->first; | |||
| auto &sub_inputs = iter->second; | |||
| @@ -647,7 +514,7 @@ void UnifyFormatPass::ResetSubGraphInput() { | |||
| } | |||
| } | |||
| void UnifyFormatPass::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { | |||
| void DecreaseTransposeAlgo::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { | |||
| MS_ASSERT(cnode != nullptr && sub_graph != nullptr); | |||
| auto return_node = sub_graph->get_return(); | |||
| auto origin_input = return_node->inputs(); | |||
| @@ -676,7 +543,7 @@ void UnifyFormatPass::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPt | |||
| return_node->set_inputs(origin_input); | |||
| } | |||
| void UnifyFormatPass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { | |||
| void DecreaseTransposeAlgo::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { | |||
| MS_ASSERT(cnode != nullptr && sub_graph != nullptr); | |||
| auto return_node = sub_graph->get_return(); | |||
| auto origin_inputs = return_node->inputs(); | |||
| @@ -720,7 +587,7 @@ void UnifyFormatPass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraph | |||
| prim->AddAttr(kInferDone, MakeValue<bool>(infer_done)); | |||
| } | |||
| bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph) { | |||
| bool DecreaseTransposeAlgo::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name")); | |||
| auto manager = Manage(func_graph, true); | |||
| @@ -771,7 +638,7 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap | |||
| MS_LOG(ERROR) << "insert pre node failed."; | |||
| return false; | |||
| } | |||
| auto after_perm = trans_insert_info.post_ == kNHWC2NCHW ? NH2NC : NC2NH; | |||
| auto after_perm = trans_insert_info.post_ == kNHWC2NCHW ? kNH2NC : kNC2NH; | |||
| if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope(); | |||
| return false; | |||
| @@ -780,7 +647,7 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap | |||
| return true; | |||
| } | |||
| bool UnifyFormatPass::DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph) { | |||
| bool DecreaseTransposeAlgo::DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto manager = Manage(func_graph, true); | |||
| if (manager == nullptr) { | |||
| @@ -811,7 +678,7 @@ bool UnifyFormatPass::DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph | |||
| } | |||
| std::vector<int> perm; | |||
| if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) || GetTransposePerm(cnode, &perm) != lite::RET_OK || | |||
| perm != NH2NC) { | |||
| perm != kNH2NC) { | |||
| continue; | |||
| } | |||
| auto status = HandleGraphMultiNode(func_graph, cnode, &visit_transposes); | |||
| @@ -823,144 +690,7 @@ bool UnifyFormatPass::DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph | |||
| return true; | |||
| } | |||
| bool UnifyFormatPass::ResetFuncGraph(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto manager = Manage(func_graph, true); | |||
| if (manager == nullptr) { | |||
| MS_LOG(ERROR) << "manager is nullptr."; | |||
| return false; | |||
| } | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| for (auto &node : node_list) { | |||
| if (!utils::isa<CNodePtr>(node)) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| if (prim->GetAttr(kInferDone) != nullptr) { | |||
| prim->EraseAttr(kInferDone); | |||
| } | |||
| if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) { | |||
| auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1)); | |||
| if (sub_func_graph == nullptr) { | |||
| return false; | |||
| } | |||
| (void)ResetFuncGraph(sub_func_graph); | |||
| sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo)); | |||
| if (sub_func_graph == nullptr) { | |||
| return false; | |||
| } | |||
| (void)ResetFuncGraph(sub_func_graph); | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool UnifyFormatPass::JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| bool all_op_can_infer = true; | |||
| for (auto &node : node_list) { | |||
| if (!utils::isa<CNodePtr>(node)) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (IsSpecialType(cnode)) { | |||
| continue; | |||
| } | |||
| if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) { | |||
| auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1)); | |||
| if (sub_func_graph == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| all_op_can_infer = false; | |||
| } else { | |||
| all_op_can_infer = all_op_can_infer && JudgeAllOpsCanInfer(sub_func_graph); | |||
| } | |||
| sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1)); | |||
| if (sub_func_graph == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| all_op_can_infer = false; | |||
| } else { | |||
| all_op_can_infer = all_op_can_infer && JudgeAllOpsCanInfer(sub_func_graph); | |||
| } | |||
| continue; | |||
| } | |||
| auto cur_op_can_infer = node_infer_shape_.JudgeOpSupportInfer(cnode); | |||
| if (!cur_op_can_infer) { | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| MS_ASSERT(prim != nullptr); | |||
| lite::NotSupportOp::GetInstance()->InsertOp(prim->name()); | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NOT_SUPPORT); | |||
| all_op_can_infer = false; | |||
| } | |||
| } | |||
| return all_op_can_infer; | |||
| } | |||
| bool UnifyFormatPass::RunOnlyForShape(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| if (!JudgeAllOpsCanInfer(func_graph)) { | |||
| MS_LOG(ERROR) << "exist op cannot support infer shape."; | |||
| return false; | |||
| } | |||
| if (!RunNodeInferShape(func_graph)) { | |||
| MS_LOG(ERROR) << "RunNodeInferShape failed."; | |||
| return false; | |||
| } | |||
| ResetSubGraphInput(); | |||
| ResetFuncGraph(func_graph); | |||
| return true; | |||
| } | |||
| bool UnifyFormatPass::RunNodeInferShape(const FuncGraphPtr &func_graph) { | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| for (auto &node : node_list) { | |||
| if (!utils::isa<CNodePtr>(node)) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (IsSpecialType(cnode)) { | |||
| continue; | |||
| } | |||
| if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) { | |||
| auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1)); | |||
| if (sub_func_graph == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| return false; | |||
| } | |||
| SetSubGraphInput(cnode, sub_func_graph); | |||
| if (!RunNodeInferShape(sub_func_graph)) { | |||
| MS_LOG(ERROR) << "subgraph infer shape failed."; | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR); | |||
| return false; | |||
| } | |||
| SetSubGraphOutput(cnode, sub_func_graph); | |||
| sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo)); | |||
| if (sub_func_graph == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| return false; | |||
| } | |||
| SetSubGraphInput(cnode, sub_func_graph); | |||
| if (!RunNodeInferShape(sub_func_graph)) { | |||
| MS_LOG(ERROR) << "subgraph infer shape failed."; | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR); | |||
| return false; | |||
| } | |||
| SetSubGraphOutput(cnode, sub_func_graph); | |||
| SetSubGraphAbstract(cnode, sub_func_graph); | |||
| continue; | |||
| } | |||
| auto status = node_infer_shape_.InferShape(cnode); | |||
| if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) { | |||
| MS_LOG(ERROR) << "infer shape failed." << cnode->fullname_with_scope(); | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool UnifyFormatPass::RunDoFixFormat(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| bool DecreaseTransposeAlgo::RunDoFixFormat(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| auto prim_node = cnode->input(0); | |||
| auto prim = GetValueNode<PrimitivePtr>(prim_node); | |||
| auto &nchw_op = GetNCHWOpMap(); | |||
| @@ -970,32 +700,14 @@ bool UnifyFormatPass::RunDoFixFormat(const FuncGraphPtr &func_graph, const CNode | |||
| if (utils::isa<CNodePtr>(cnode->input(1))) { | |||
| auto format = GetValue<int64_t>(prim->GetAttr(ops::kFormat)); | |||
| if (nchw_op.find(prim->name()) != nchw_op.end() && format != NCHW) { | |||
| InsertPreTransNode(func_graph, cnode, {0, 3, 1, 2}); | |||
| InsertPostTransNode(func_graph, cnode, {0, 2, 3, 1}); | |||
| } | |||
| } | |||
| { | |||
| if (CheckPrimitiveType(cnode, prim::kPrimTranspose)) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| auto manager = func_graph->manager(); | |||
| if (manager == nullptr) { | |||
| manager = Manage(func_graph, true); | |||
| } | |||
| auto shape = node_infer_shape_.GetInputShape(cnode, 1); | |||
| std::vector<int> perm; | |||
| auto status = GetTransposePerm(cnode, &perm); | |||
| if (status != RET_OK) { | |||
| return false; | |||
| } | |||
| if (!shape.empty() && shape.size() != perm.size()) { | |||
| manager->Replace(cnode, cnode->input(1)); | |||
| } | |||
| InsertPreTransNode(func_graph, cnode, kNH2NC); | |||
| InsertPostTransNode(func_graph, cnode, kNC2NH); | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool UnifyFormatPass::DoFixFormat(const FuncGraphPtr &func_graph) { | |||
| bool DecreaseTransposeAlgo::DoFixFormat(const FuncGraphPtr &func_graph) { | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| for (auto &node : node_list) { | |||
| if (!utils::isa<CNodePtr>(node)) { | |||
| @@ -1041,8 +753,10 @@ bool UnifyFormatPass::DoFixFormat(const FuncGraphPtr &func_graph) { | |||
| return true; | |||
| } | |||
| bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) { | |||
| bool DecreaseTransposeAlgo::Run(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| node_infer_shape_.Init(fmk_type_, train_flag_); | |||
| transpose_strategy_.Init(fmk_type_, train_flag_); | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| for (auto &node : node_list) { | |||
| auto prim = GetValueNode<PrimitivePtr>(node); | |||
| @@ -1050,15 +764,6 @@ bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) { | |||
| continue; | |||
| } | |||
| } | |||
| if (!JudgeAllOpsCanInfer(func_graph)) { | |||
| MS_LOG(ERROR) << "exist op cannot support infer shape."; | |||
| return false; | |||
| } | |||
| if (!RunNodeInferShape(func_graph)) { | |||
| MS_LOG(ERROR) << "infer shape failed."; | |||
| return false; | |||
| } | |||
| ResetSubGraphInput(); | |||
| if (!DoFixFormat(func_graph)) { | |||
| MS_LOG(ERROR) << "DoFixFormat failed."; | |||
| @@ -31,10 +31,11 @@ | |||
| using mindspore::lite::converter::FmkType; | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class UnifyFormatPass : public Pass { | |||
| class DecreaseTransposeAlgo : public Pass { | |||
| public: | |||
| UnifyFormatPass() : Pass("unify_format_pass") {} | |||
| ~UnifyFormatPass() override = default; | |||
| explicit DecreaseTransposeAlgo(FmkType fmk_type = FmkType::FmkType_MS, bool train_flag = false) | |||
| : Pass("DecreaseTransposeAlgo"), fmk_type_(fmk_type), train_flag_(train_flag) {} | |||
| ~DecreaseTransposeAlgo() override = default; | |||
| void Init(FmkType fmk_type, bool train_flag) { | |||
| fmk_type_ = fmk_type; | |||
| train_flag_ = train_flag; | |||
| @@ -42,7 +43,6 @@ class UnifyFormatPass : public Pass { | |||
| transpose_strategy_.Init(fmk_type, train_flag); | |||
| } | |||
| bool Run(const FuncGraphPtr &func_graph) override; | |||
| bool RunOnlyForShape(const FuncGraphPtr &func_graph); | |||
| private: | |||
| STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm); | |||
| @@ -51,16 +51,10 @@ class UnifyFormatPass : public Pass { | |||
| size_t index = 0); | |||
| bool RunDoFixFormat(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| bool DoFixFormat(const FuncGraphPtr &func_graph); | |||
| bool RunNodeInferShape(const FuncGraphPtr &func_graph); | |||
| bool JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph); | |||
| bool ResetFuncGraph(const FuncGraphPtr &func_graph); | |||
| bool DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph); | |||
| bool DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph); | |||
| bool TransTransFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| STATUS PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| void GetTransNodeFormatType(const CNodePtr &cnode, TransTypePair *trans_info); | |||
| STATUS HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| STATUS HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| std::set<CNodePtr> *visit_transposes); | |||
| STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_insert_info); | |||
| @@ -69,12 +63,9 @@ class UnifyFormatPass : public Pass { | |||
| void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); | |||
| void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); | |||
| FmkType fmk_type_{lite::converter::FmkType_MS}; | |||
| bool need_reset_{false}; | |||
| bool train_flag_{false}; | |||
| NodeInferShape node_infer_shape_; | |||
| TransposeStrategy transpose_strategy_; | |||
| std::set<AnfNodePtr> pre_insert_trans_; | |||
| std::set<AnfNodePtr> post_insert_trans_; | |||
| std::unordered_map<FuncGraphPtr, std::vector<AnfNodePtr>> sub_inputs_map_; | |||
| }; | |||
| } // namespace opt | |||
| @@ -29,6 +29,10 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) { | |||
| MS_LOG(ERROR) << "create NodeInferShape object failed."; | |||
| return false; | |||
| } | |||
| if (!JudgeAllOpsCanInfer(func_graph)) { | |||
| MS_LOG(ERROR) << "exist op cannot support infer shape."; | |||
| return false; | |||
| } | |||
| if (InferProcess(func_graph) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "infer shape failed."; | |||
| return false; | |||
| @@ -37,6 +41,47 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) { | |||
| return true; | |||
| } | |||
| bool InferShapePass::JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| bool all_op_can_infer = true; | |||
| for (auto &node : node_list) { | |||
| if (!utils::isa<CNodePtr>(node)) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (IsSpecialType(cnode)) { | |||
| continue; | |||
| } | |||
| if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) { | |||
| auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1)); | |||
| if (sub_func_graph == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| all_op_can_infer = false; | |||
| } else { | |||
| all_op_can_infer = all_op_can_infer && JudgeAllOpsCanInfer(sub_func_graph); | |||
| } | |||
| sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1)); | |||
| if (sub_func_graph == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| all_op_can_infer = false; | |||
| } else { | |||
| all_op_can_infer = all_op_can_infer && JudgeAllOpsCanInfer(sub_func_graph); | |||
| } | |||
| continue; | |||
| } | |||
| auto cur_op_can_infer = node_infer_shape_->JudgeOpSupportInfer(cnode); | |||
| if (!cur_op_can_infer) { | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| MS_ASSERT(prim != nullptr); | |||
| lite::NotSupportOp::GetInstance()->InsertOp(prim->name()); | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NOT_SUPPORT); | |||
| all_op_can_infer = false; | |||
| } | |||
| } | |||
| return all_op_can_infer; | |||
| } | |||
| STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| @@ -55,7 +100,10 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) { | |||
| return false; | |||
| } | |||
| SetSubGraphInput(cnode, sub_func_graph); | |||
| (void)InferProcess(sub_func_graph); | |||
| if (InferProcess(sub_func_graph) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "subgraph infer shape failed."; | |||
| return false; | |||
| } | |||
| SetSubGraphOutput(cnode, sub_func_graph); | |||
| sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2)); | |||
| if (sub_func_graph == nullptr) { | |||
| @@ -63,7 +111,10 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) { | |||
| return false; | |||
| } | |||
| SetSubGraphInput(cnode, sub_func_graph); | |||
| (void)InferProcess(sub_func_graph); | |||
| if (InferProcess(sub_func_graph) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "subgraph infer shape failed."; | |||
| return false; | |||
| } | |||
| SetSubGraphOutput(cnode, sub_func_graph); | |||
| SetSubGraphAbstract(cnode, sub_func_graph); | |||
| continue; | |||
| @@ -132,7 +183,7 @@ void InferShapePass::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr | |||
| continue; | |||
| } | |||
| auto node_name = return_node->input(i)->fullname_with_scope(); | |||
| if (node_name.substr(node_name.size() - 5) != "_post") { | |||
| if (node_name.size() < kInputSizeFive || node_name.substr(node_name.size() - kInputSizeFive) != "_post") { | |||
| continue; | |||
| } | |||
| auto trans_cnode = return_node->input(i)->cast<CNodePtr>(); | |||
| @@ -29,10 +29,11 @@ class InferShapePass : public Pass { | |||
| public: | |||
| explicit InferShapePass(FmkType fmk_type = lite::converter::FmkType_MS, bool train_flag = false) | |||
| : Pass("infer_shape"), fmk_type_(fmk_type), train_flag_(train_flag) {} | |||
| ~InferShapePass() = default; | |||
| ~InferShapePass() override = default; | |||
| bool Run(const FuncGraphPtr &func_graph) override; | |||
| private: | |||
| bool JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph); | |||
| STATUS InferProcess(const FuncGraphPtr &func_graph); | |||
| void SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); | |||
| void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); | |||
| @@ -32,8 +32,6 @@ namespace { | |||
| constexpr size_t kFirstInput = 1; | |||
| constexpr size_t kHalfDivisor = 2; | |||
| constexpr size_t kOnnxStridedSlice = 6; | |||
| const std::vector<int> NH2NC = {0, 3, 1, 2}; | |||
| const std::vector<int> NC2NH = {0, 2, 3, 1}; | |||
| STATUS GetPostNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<AnfNodePtr> *out_nodes) { | |||
| auto manager = func_graph->manager(); | |||
| if (manager == nullptr) { | |||
| @@ -70,7 +68,7 @@ AnfNodePtr TransposeStrategy::TransposePairFuseWhenInsert(const FuncGraphPtr &fu | |||
| MS_LOG(ERROR) << "transpose perm get failed."; | |||
| return nullptr; | |||
| } | |||
| if ((perm == NH2NC && trans_perm == NC2NH) || (perm == NC2NH && trans_perm == NH2NC)) { | |||
| if ((perm == kNH2NC && trans_perm == kNC2NH) || (perm == kNC2NH && trans_perm == kNH2NC)) { | |||
| return input_cnode->input(kFirstInput); | |||
| } | |||
| } | |||
| @@ -170,7 +168,8 @@ bool TransposeStrategy::CanChangeOpAxis(const FuncGraphPtr &func_graph, const CN | |||
| return true; | |||
| } | |||
| STATUS TransposeStrategy::ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| STATUS TransposeStrategy::ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| FormatTransNodeType trans_type) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| auto shape = node_infer_shape_.GetInputShape(cnode, 1); | |||
| if (shape.size() != kInputSizeFour) { | |||
| @@ -183,39 +182,17 @@ STATUS TransposeStrategy::ChangeOpAxis(const FuncGraphPtr &func_graph, const CNo | |||
| return lite::RET_NOT_SUPPORT; | |||
| } | |||
| } | |||
| auto axis_map = GetNC2NHAxisMap(); | |||
| if (CheckPrimitiveType(cnode, prim::kPrimConcat) || CheckPrimitiveType(cnode, prim::kPrimSplit)) { | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| if (prim->GetAttr(ops::kAxis) == nullptr) { | |||
| return lite::RET_NOT_SUPPORT; | |||
| } | |||
| auto axis = GetValue<int64_t>(prim->GetAttr(ops::kAxis)); | |||
| auto new_axis = axis_map[axis < 0 ? axis + kInputSizeFour : axis]; | |||
| prim->AddAttr(ops::kAxis, MakeValue<int64_t>(new_axis)); | |||
| return ChangeCommonOp(cnode, trans_type); | |||
| } | |||
| if (CheckPrimitiveType(cnode, prim::kPrimCrop)) { | |||
| auto crop_prim = GetValueNode<std::shared_ptr<ops::Crop>>(cnode->input(0)); | |||
| if (crop_prim == nullptr) { | |||
| return lite::RET_NULL_PTR; | |||
| } | |||
| auto axis = crop_prim->get_axis(); | |||
| auto offsets = crop_prim->get_offsets(); | |||
| auto new_axis = axis_map[axis < 0 ? axis + kInputSizeFour : axis]; | |||
| if (new_axis == 0) { | |||
| offsets = {offsets[0], offsets[kInputIndexTwo], offsets[kInputIndexThree], offsets[1]}; | |||
| } else if (new_axis == kInputIndexThree) { | |||
| offsets = {offsets[1], offsets[kInputIndexTwo], offsets[0]}; | |||
| } else { | |||
| offsets.push_back(0); | |||
| } | |||
| crop_prim->set_axis(new_axis); | |||
| crop_prim->set_offsets(offsets); | |||
| return ChangeOpCrop(cnode, trans_type); | |||
| } | |||
| if (CheckPrimitiveType(cnode, prim::kPrimSliceFusion)) { | |||
| return ChangeOpSlice(func_graph, cnode); | |||
| return ChangeOpSlice(func_graph, cnode, trans_type); | |||
| } | |||
| if (CheckPrimitiveType(cnode, prim::kPrimStridedSlice)) { | |||
| return ChangeOpStrideSlice(func_graph, cnode); | |||
| return ChangeOpStrideSlice(func_graph, cnode, trans_type); | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| @@ -242,7 +219,7 @@ STATUS TransposeStrategy::TransposeInsertDependOnShape(const FuncGraphPtr &func_ | |||
| CNodePtr base_node = before ? cnode : node_users.front().first->cast<CNodePtr>(); | |||
| size_t input_index = before ? index : node_users.front().second; | |||
| auto shape = node_infer_shape_.GetInputShape(base_node, input_index); | |||
| if (!shape.empty() && shape.size() != NH2NC.size()) { | |||
| if (!shape.empty() && shape.size() != kNH2NC.size()) { | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| return lite::RET_OK; | |||
| @@ -263,9 +240,9 @@ bool TransposeStrategy::IsInOutCanFuison(const FuncGraphPtr &func_graph, const s | |||
| if (GetTransposePerm(cnode, &perm) != lite::RET_OK) { | |||
| return false; | |||
| } | |||
| if (perm == NH2NC) { | |||
| if (perm == kNH2NC) { | |||
| cur_type = kNHWC2NCHW; | |||
| } else if (perm == NC2NH) { | |||
| } else if (perm == kNC2NH) { | |||
| cur_type = kNCHW2NHWC; | |||
| } else { | |||
| return false; | |||
| @@ -297,8 +274,79 @@ void TransposeStrategy::DecidePreAndPostTransType(TransTypePair *trans_info, Tra | |||
| } | |||
| } | |||
| STATUS TransposeStrategy::ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| STATUS TransposeStrategy::ChangeCommonOp(const CNodePtr &cnode, FormatTransNodeType trans_type) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| if (trans_type == kNONE) { | |||
| MS_LOG(ERROR) << "trans_type is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| MS_ASSERT(prim != nullptr); | |||
| if (prim->GetAttr(ops::kAxis) == nullptr) { | |||
| return lite::RET_NOT_SUPPORT; | |||
| } | |||
| auto axis = GetValue<int64_t>(prim->GetAttr(ops::kAxis)); | |||
| if (axis < 0) { | |||
| axis += kInputSizeFour; | |||
| } | |||
| auto new_axis = kNH2NC[axis]; | |||
| if (trans_type == kNHWC2NCHW) { | |||
| new_axis = kNC2NH[axis]; | |||
| } | |||
| prim->AddAttr(ops::kAxis, MakeValue<int64_t>(new_axis)); | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS TransposeStrategy::ChangeOpCrop(const CNodePtr &cnode, FormatTransNodeType trans_type) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| if (trans_type == kNONE) { | |||
| MS_LOG(ERROR) << "trans_type is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto crop_prim = GetValueNode<std::shared_ptr<ops::Crop>>(cnode->input(0)); | |||
| if (crop_prim == nullptr) { | |||
| MS_LOG(ERROR) << "cnode is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto axis = crop_prim->get_axis(); | |||
| if (axis < 0) { | |||
| axis += kInputSizeFour; | |||
| } | |||
| MS_ASSERT(axis >= 0 && axis < kInputSizeFour); | |||
| auto offsets = crop_prim->get_offsets(); | |||
| if (trans_type == kNCHW2NHWC) { | |||
| auto new_axis = kNH2NC[axis]; | |||
| if (new_axis == 0) { | |||
| offsets = {offsets[0], offsets[kInputIndexTwo], offsets[kInputIndexThree], offsets[1]}; | |||
| } else if (new_axis == kInputIndexThree) { | |||
| offsets = {offsets[1], offsets[kInputIndexTwo], offsets[0]}; | |||
| } else { | |||
| offsets.push_back(0); | |||
| } | |||
| crop_prim->set_axis(new_axis); | |||
| crop_prim->set_offsets(offsets); | |||
| } else { | |||
| auto new_axis = kNC2NH[axis]; | |||
| if (new_axis == 0) { | |||
| offsets = {offsets[0], offsets[kInputIndexThree], offsets[1], offsets[kInputIndexTwo]}; | |||
| } else if (new_axis == kInputIndexThree) { | |||
| offsets = {offsets[kInputIndexTwo], offsets[0], offsets[1]}; | |||
| } else { | |||
| offsets.pop_back(); | |||
| } | |||
| crop_prim->set_axis(new_axis); | |||
| crop_prim->set_offsets(offsets); | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS TransposeStrategy::ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| FormatTransNodeType trans_type) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| if (trans_type == kNONE) { | |||
| MS_LOG(ERROR) << "trans_type is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| for (size_t i = 2; i < cnode->size(); ++i) { | |||
| if (utils::isa<CNodePtr>(cnode->input(i))) { | |||
| return lite::RET_NOT_SUPPORT; | |||
| @@ -321,15 +369,21 @@ STATUS TransposeStrategy::ChangeOpSlice(const FuncGraphPtr &func_graph, const CN | |||
| [](int64_t v) { return static_cast<int>(v); }); | |||
| } | |||
| for (size_t i = 2; i < cnode->size(); ++i) { | |||
| TransformAttrByAxes(func_graph, cnode, i, axes); | |||
| TransformAttrByAxes(func_graph, cnode, i, axes, trans_type); | |||
| } | |||
| auto tmp_axes = TransformOpAxesAttr(axes); | |||
| auto tmp_axes = TransformOpAxesAttr(axes, trans_type); | |||
| std::vector<int64_t> new_axes(tmp_axes.begin(), tmp_axes.end()); | |||
| prim->set_axes(new_axes); | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS TransposeStrategy::ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| STATUS TransposeStrategy::ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| FormatTransNodeType trans_type) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| if (trans_type == kNONE) { | |||
| MS_LOG(ERROR) << "trans_type is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (cnode->size() != kOnnxStridedSlice) { | |||
| return lite::RET_NOT_SUPPORT; | |||
| } | |||
| @@ -347,9 +401,9 @@ STATUS TransposeStrategy::ChangeOpStrideSlice(const FuncGraphPtr &func_graph, co | |||
| if (index == kInputIndexFour) { | |||
| continue; | |||
| } | |||
| TransformAttrByAxes(func_graph, cnode, index, axes); | |||
| TransformAttrByAxes(func_graph, cnode, index, axes, trans_type); | |||
| } | |||
| auto cur_axes = TransformOpAxesAttr(axes); | |||
| auto cur_axes = TransformOpAxesAttr(axes, trans_type); | |||
| auto param_node = | |||
| BuildIntVecParameterNode(func_graph, cur_axes, cnode->input(kInputIndexFour)->fullname_with_scope()); | |||
| func_graph->manager()->Replace(cnode->input(kInputIndexFour), param_node); | |||
| @@ -357,11 +411,10 @@ STATUS TransposeStrategy::ChangeOpStrideSlice(const FuncGraphPtr &func_graph, co | |||
| } | |||
| void TransposeStrategy::TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index, | |||
| const std::vector<int> &axes) { | |||
| const std::vector<int> &axes, FormatTransNodeType trans_type) { | |||
| if (cnode == nullptr || input_index >= cnode->size() || axes.empty()) { | |||
| return; | |||
| } | |||
| auto axis_map = GetNC2NHAxisMap(); | |||
| auto origin_input = node_infer_shape_.GetIntVecInput(cnode, input_index); | |||
| if (origin_input.size() != axes.size()) { | |||
| return; | |||
| @@ -369,8 +422,16 @@ void TransposeStrategy::TransformAttrByAxes(const FuncGraphPtr &func_graph, cons | |||
| std::vector<int> cur_input; | |||
| for (int dim = 0; dim < static_cast<int>(kInputSizeFour); ++dim) { | |||
| for (size_t index = 0; index < axes.size(); ++index) { | |||
| int nhwc_dim = axis_map[axes[index] < 0 ? axes[index] + kInputSizeFour : axes[index]]; | |||
| if (nhwc_dim == dim) { | |||
| int axis = axes[index]; | |||
| if (axis < 0) { | |||
| axis += kInputSizeFour; | |||
| } | |||
| MS_ASSERT(axis >= 0 && axis < kInputSizeFour); | |||
| int cur_axis = kNH2NC[axis]; | |||
| if (trans_type == kNHWC2NCHW) { | |||
| cur_axis = kNC2NH[axis]; | |||
| } | |||
| if (cur_axis == dim) { | |||
| cur_input.push_back(origin_input[index]); | |||
| } | |||
| } | |||
| @@ -379,14 +440,23 @@ void TransposeStrategy::TransformAttrByAxes(const FuncGraphPtr &func_graph, cons | |||
| func_graph->manager()->Replace(cnode->input(input_index), param_node); | |||
| } | |||
| std::vector<int> TransposeStrategy::TransformOpAxesAttr(const std::vector<int> &origin_axes) { | |||
| auto axis_map = GetNC2NHAxisMap(); | |||
| std::vector<int> cur_axis; | |||
| std::vector<int> TransposeStrategy::TransformOpAxesAttr(const std::vector<int> &origin_axes, | |||
| FormatTransNodeType trans_type) { | |||
| std::vector<int> cur_axes; | |||
| for (size_t i = 0; i < origin_axes.size(); ++i) { | |||
| cur_axis.push_back(axis_map[origin_axes[i] < 0 ? origin_axes[i] + kInputSizeFour : origin_axes[i]]); | |||
| int axis = origin_axes[i]; | |||
| if (axis < 0) { | |||
| axis += kInputSizeFour; | |||
| } | |||
| MS_ASSERT(axis >= 0 && axis < kInputSizeFour); | |||
| int cur_axis = kNH2NC[axis]; | |||
| if (trans_type == kNHWC2NCHW) { | |||
| cur_axis = kNC2NH[axis]; | |||
| } | |||
| cur_axes.push_back(cur_axis); | |||
| } | |||
| std::sort(cur_axis.begin(), cur_axis.end()); | |||
| return cur_axis; | |||
| std::sort(cur_axes.begin(), cur_axes.end()); | |||
| return cur_axes; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -39,23 +39,25 @@ class TransposeStrategy { | |||
| } | |||
| AnfNodePtr TransposePairFuseWhenInsert(const FuncGraphPtr &func_graph, const CNodePtr &code, | |||
| const std::vector<int> &perm, bool before, size_t index); | |||
| AnfNodePtr TransposeDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm, | |||
| bool before, size_t index); | |||
| bool CanFusionIfInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_info, | |||
| TransTypePair *trans_insert_info); | |||
| STATUS ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| STATUS ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type); | |||
| bool CanChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| private: | |||
| AnfNodePtr TransposeDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm, | |||
| bool before, size_t index); | |||
| STATUS TransposeInsertDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode, bool before, size_t index); | |||
| bool IsInOutCanFuison(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes, size_t *trans_count, | |||
| FormatTransNodeType *trans_type); | |||
| void DecidePreAndPostTransType(TransTypePair *trans_info, TransTypePair *trans_insert_info); | |||
| STATUS ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| STATUS ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| STATUS ChangeCommonOp(const CNodePtr &cnode, FormatTransNodeType trans_type); | |||
| STATUS ChangeOpCrop(const CNodePtr &cnode, FormatTransNodeType trans_type); | |||
| STATUS ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type); | |||
| STATUS ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type); | |||
| void TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index, | |||
| const std::vector<int> &axes); | |||
| std::vector<int> TransformOpAxesAttr(const std::vector<int> &origin_axes); | |||
| const std::vector<int> &axes, FormatTransNodeType trans_type); | |||
| std::vector<int> TransformOpAxesAttr(const std::vector<int> &origin_axes, FormatTransNodeType trans_type); | |||
| FmkType fmk_type_{lite::converter::FmkType_MS}; | |||
| bool train_flag_{false}; | |||
| NodeInferShape node_infer_shape_; | |||