From: @wangzhe128 Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -240,13 +240,15 @@ if(ENABLE_CONVERTER) | |||
| ${LITE_DIR}/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/tf_lstm_cell_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/group_depthwise_op_convert_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/tflite_inputs_adjust_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/update_conv2d_param_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/unused_node_remove_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/unused_transpose_node_remove_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/redundant_op_remove_pass.cc | |||
| @@ -258,6 +260,7 @@ if(ENABLE_CONVERTER) | |||
| ${LITE_DIR}/tools/optimizer/graph/if_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/functionalize_control_op_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/functionalize_while.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/functionalize_cond.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/inputs_adjust_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/primitive_adjust_pass.cc | |||
| ) | |||
| @@ -61,3 +61,4 @@ ml_noya_tts_melgan.pb 1;16,16,80 | |||
| ml_video_edit_oneclick_adaptis.pb 3 | |||
| # Q_hand_0812.pb is not suitable for float16. Out of float16 range. | |||
| Q_hand_0812.pb | |||
| tacotron_encoder_stf.pb 5;1:1,62:1,62:1,62:1,62 | |||
| @@ -50,13 +50,15 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ../optimizer/fusion/conv_conv_fusion.cc | |||
| ../optimizer/fusion/tflite_lstm_cell_fusion.cc | |||
| ../optimizer/fusion/tf_lstm_cell_fusion.cc | |||
| ../optimizer/fusion/bidirection_tf_gru_cell_fusion.cc | |||
| ../optimizer/fusion/tf_bidirection_gru_fusion.cc | |||
| ../optimizer/fusion/tf_bidirection_gru_cf_fusion.cc | |||
| ../optimizer/graph/weight_format_transform_pass.cc | |||
| ../optimizer/graph/weight_format_hardcode_pass.cc | |||
| ../optimizer/graph/clip_convert_activation_pass.cc | |||
| ../optimizer/graph/group_depthwise_op_convert_pass.cc | |||
| ../optimizer/graph/tflite_inputs_adjust_pass.cc | |||
| ../optimizer/graph/update_conv2d_param_pass.cc | |||
| ../optimizer/graph/unused_node_remove_pass.cc | |||
| ../optimizer/graph/unused_cast_node_remove_pass.cc | |||
| ../optimizer/graph/unused_transpose_node_remove_pass.cc | |||
| ../optimizer/graph/redundant_op_remove_pass.cc | |||
| @@ -68,6 +70,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ../optimizer/graph/if_pass.cc | |||
| ../optimizer/graph/functionalize_control_op_pass.cc | |||
| ../optimizer/graph/functionalize_while.cc | |||
| ../optimizer/graph/functionalize_cond.cc | |||
| ../optimizer/graph/inputs_adjust_pass.cc | |||
| ../optimizer/graph/primitive_adjust_pass.cc | |||
| ) | |||
| @@ -18,6 +18,8 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include "src/common/log_adapter.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "mindspore/core/ir/primitive.h" | |||
| #include "tools/optimizer/fusion/conv_biasadd_fusion.h" | |||
| #include "tools/optimizer/fusion/conv_activation_fusion.h" | |||
| #include "tools/optimizer/fusion/conv_tuple_activation_fusion.h" | |||
| @@ -31,7 +33,8 @@ | |||
| #include "tools/optimizer/fusion/conv_conv_fusion.h" | |||
| #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h" | |||
| #include "tools/optimizer/fusion/tf_lstm_cell_fusion.h" | |||
| #include "tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h" | |||
| #include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h" | |||
| #include "tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h" | |||
| #include "tools/optimizer/graph/primitive_adjust_pass.h" | |||
| #include "tools/optimizer/graph/mindir_adjust_pass.h" | |||
| #include "tools/optimizer/graph/redundant_op_remove_pass.h" | |||
| @@ -42,6 +45,7 @@ | |||
| #include "tools/optimizer/graph/tflite_inputs_adjust_pass.h" | |||
| #include "tools/optimizer/graph/onnx_inputs_adjust_pass.h" | |||
| #include "tools/optimizer/graph/update_conv2d_param_pass.h" | |||
| #include "tools/optimizer/graph/unused_node_remove_pass.h" | |||
| #include "tools/optimizer/graph/unused_cast_node_remove_pass.h" | |||
| #include "tools/optimizer/graph/unused_transpose_node_remove_pass.h" | |||
| #include "tools/optimizer/graph/infershape_pass.h" | |||
| @@ -81,7 +85,7 @@ int AnfTransform::AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &opti | |||
| fusion_pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::TfliteLstmCellFusion>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::TfLstmCellFusion>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::BiDirectionTfGruCellFusion>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::TfBidirectionGruFusion>()); | |||
| } | |||
| if (config->fmk == lite::converter::FmkType_MS) { | |||
| auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>(); | |||
| @@ -225,6 +229,23 @@ int AnfTransform::RunTFAdjustPass(const FuncGraphPtr &old_graph, const converter | |||
| return RET_OK; | |||
| } | |||
| int AnfTransform::RunPrecedingPass(const FuncGraphPtr &old_graph, const converter::Flags &config) { | |||
| MS_ASSERT(old_graph != nullptr); | |||
| auto asylic_optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto asylic_pm = std::make_shared<opt::PassManager>("asylic pass manager", false); | |||
| // fuse tf1.x bidirection_gru into GRU, must be placed here because graph is cyclic | |||
| asylic_pm->AddPass(std::make_shared<opt::TfBidirectionGruCfFusion>()); | |||
| // remove remaining cyclic nodes | |||
| asylic_pm->AddPass(std::make_shared<opt::UnusedNodeRemovePass>()); | |||
| asylic_optimizer->AddPassManager(asylic_pm); | |||
| if (!asylic_optimizer->Optimize(old_graph)) { | |||
| MS_LOG(ERROR) << "gru cf fusion pass failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config, | |||
| const FuncGraphPtr &new_graph) { | |||
| // quant | |||
| @@ -266,7 +287,13 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap | |||
| return old_graph; | |||
| } | |||
| auto status = RunAdjustPass(old_graph, config); | |||
| auto status = RunPrecedingPass(old_graph, *config); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Run Preceding pass failed."; | |||
| return nullptr; | |||
| } | |||
| status = RunAdjustPass(old_graph, config); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Run Adjust pass failed."; | |||
| return nullptr; | |||
| @@ -50,6 +50,8 @@ class AnfTransform { | |||
| static int AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config); | |||
| static int RunPrecedingPass(const FuncGraphPtr &old_graph, const converter::Flags &config); | |||
| static int RunAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config); | |||
| static int RunMindirAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config); | |||
| @@ -1,37 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_IF_H_ | |||
| #define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_IF_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| using mindspore::ops::PrimitiveC; | |||
| namespace mindspore { | |||
| namespace lite { | |||
| constexpr auto kNameIf = "If"; | |||
| class If : public PrimitiveC { | |||
| public: | |||
| If() : PrimitiveC(kNameIf) {} | |||
| ~If() = default; | |||
| MS_DECLARE_PARENT(If, PrimitiveC); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_IF_H_ | |||
| @@ -1,39 +0,0 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_LOOP_COND_H_ | |||
| #define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_LOOP_COND_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "ops/primitive_c.h" | |||
| using mindspore::ops::PrimitiveC; | |||
| namespace mindspore { | |||
| namespace lite { | |||
| constexpr auto kNameLoopCond = "LoopCond"; | |||
| class LoopCond : public PrimitiveC { | |||
| public: | |||
| LoopCond() : PrimitiveC(kNameLoopCond) {} | |||
| ~LoopCond() = default; | |||
| MS_DECLARE_PARENT(LoopCond, PrimitiveC); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_LOOP_COND_H_ | |||
| @@ -17,16 +17,31 @@ | |||
| #ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_ | |||
| #define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_ | |||
| #include "schema/inner/model_generated.h" | |||
| #include "ops/primitive_c.h" | |||
| using mindspore::ops::PrimitiveC; | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #define ADD_CONVERTER_ONLY_OP(name) \ | |||
| constexpr auto kName##name = #name; \ | |||
| class name : public PrimitiveC { \ | |||
| public: \ | |||
| name() : PrimitiveC(kName##name) {} \ | |||
| ~name() = default; \ | |||
| MS_DECLARE_PARENT(name, PrimitiveC); \ | |||
| }; | |||
| enum ConverterPrimitiveType { | |||
| ConverterPrimitiveType_Enter = schema::PrimitiveType_MAX + 1, | |||
| ConverterPrimitiveType_LoopCond, | |||
| ConverterPrimitiveType_NextIteration, | |||
| ConverterPrimitiveType_Exit, | |||
| }; | |||
| ADD_CONVERTER_ONLY_OP(Enter); | |||
| ADD_CONVERTER_ONLY_OP(Exit); | |||
| ADD_CONVERTER_ONLY_OP(If); | |||
| ADD_CONVERTER_ONLY_OP(LoopCond); | |||
| ADD_CONVERTER_ONLY_OP(NextIteration); | |||
| ADD_CONVERTER_ONLY_OP(TensorArrayGatherV3); | |||
| ADD_CONVERTER_ONLY_OP(TensorArrayReadV3); | |||
| ADD_CONVERTER_ONLY_OP(TensorArrayScatterV3); | |||
| ADD_CONVERTER_ONLY_OP(TensorArraySizeV3); | |||
| ADD_CONVERTER_ONLY_OP(TensorArrayV3); | |||
| ADD_CONVERTER_ONLY_OP(TensorArrayWriteV3); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -17,7 +17,7 @@ | |||
| #include "tools/converter/parser/onnx/onnx_if_parser.h" | |||
| #include <memory> | |||
| #include "tools/converter/parser/onnx/onnx_model_parser.h" | |||
| #include "tools/converter/ops/if.h" | |||
| #include "tools/converter/ops/ops_def.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -19,7 +19,7 @@ | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_enter_parser.h" | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "tools/converter/ops/enter.h" | |||
| #include "tools/converter/ops/ops_def.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -18,7 +18,7 @@ | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "tools/converter/ops/exit.h" | |||
| #include "tools/converter/ops/ops_def.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -19,7 +19,7 @@ | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "tools/converter/ops/if.h" | |||
| #include "tools/converter/ops/ops_def.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -18,7 +18,7 @@ | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "tools/converter/ops/loop_cond.h" | |||
| #include "tools/converter/ops/ops_def.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -28,7 +28,7 @@ ops::PrimitiveC *TFMergeParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| auto prim = std::make_unique<ops::Merge>(); | |||
| *output_size = tf_op.input_size(); | |||
| *output_size = 1; | |||
| for (int i = 0; i < tf_op.input_size(); i++) { | |||
| inputs->emplace_back(tf_op.input(i)); | |||
| } | |||
| @@ -18,7 +18,7 @@ | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "tools/converter/ops/next_iteration.h" | |||
| #include "tools/converter/ops/ops_def.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -28,7 +28,7 @@ ops::PrimitiveC *TFSwitchParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| auto prim = std::make_unique<ops::Switch>(); | |||
| *output_size = tf_op.input_size(); | |||
| *output_size = 2; | |||
| for (int i = 0; i < tf_op.input_size(); i++) { | |||
| inputs->emplace_back(tf_op.input(i)); | |||
| } | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_tensor_array_gather_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "tools/converter/ops/ops_def.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *TFTensorArrayGatherParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(DEBUG) << "TF TensorArrayGatherParser"; | |||
| if (inputs == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "inputs or output_size is nullptr"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<TensorArrayGatherV3>(); | |||
| if (prim == nullptr) { | |||
| MS_LOG(ERROR) << "prim is nullptr"; | |||
| return nullptr; | |||
| } | |||
| *output_size = 1; | |||
| for (int i = 0; i < tf_op.input_size(); i++) { | |||
| inputs->emplace_back(tf_op.input(i)); | |||
| } | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfTensorArrayGatherParser("TensorArrayGatherV3", new TFTensorArrayGatherParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_GATHER_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_GATHER_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFTensorArrayGatherParser : public TFNodeParser { | |||
| public: | |||
| TFTensorArrayGatherParser() = default; | |||
| ~TFTensorArrayGatherParser() override = default; | |||
| ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_GATHER_PARSER_H_ | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_tensor_array_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "tools/converter/ops/ops_def.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *TFTensorArrayParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(DEBUG) << "TF TensorArrayParser"; | |||
| if (inputs == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "inputs or output_size is nullptr"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<TensorArrayV3>(); | |||
| if (prim == nullptr) { | |||
| MS_LOG(ERROR) << "prim is nullptr"; | |||
| return nullptr; | |||
| } | |||
| *output_size = 2; | |||
| for (int i = 0; i < tf_op.input_size(); i++) { | |||
| inputs->emplace_back(tf_op.input(i)); | |||
| } | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfTensorArrayParser("TensorArrayV3", new TFTensorArrayParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -13,27 +13,25 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_ENTER_H_ | |||
| #define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_ENTER_H_ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "ops/primitive_c.h" | |||
| using mindspore::ops::PrimitiveC; | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| constexpr auto kNameEnter = "Enter"; | |||
| class Enter : public PrimitiveC { | |||
| class TFTensorArrayParser : public TFNodeParser { | |||
| public: | |||
| Enter() : PrimitiveC(kNameEnter) {} | |||
| ~Enter() = default; | |||
| MS_DECLARE_PARENT(Enter, PrimitiveC); | |||
| TFTensorArrayParser() = default; | |||
| ~TFTensorArrayParser() override = default; | |||
| ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_ENTER_H_ | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_PARSER_H_ | |||
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_tensor_array_read_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "tools/converter/ops/ops_def.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *TFTensorArrayReadParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(DEBUG) << "TF TensorArrayReadParser"; | |||
| if (inputs == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "inputs or output_size is nullptr"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<TensorArrayReadV3>(); | |||
| if (prim == nullptr) { | |||
| MS_LOG(ERROR) << "prim is nullptr"; | |||
| return nullptr; | |||
| } | |||
| *output_size = 1; | |||
| for (int i = 0; i < tf_op.input_size(); i++) { | |||
| inputs->emplace_back(tf_op.input(i)); | |||
| } | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfTensorArrayReadParser("TensorArrayReadV3", new TFTensorArrayReadParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -13,27 +13,25 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_EXIT_H_ | |||
| #define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_EXIT_H_ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_READ_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_READ_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "ops/primitive_c.h" | |||
| using mindspore::ops::PrimitiveC; | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| constexpr auto kNameExit = "Exit"; | |||
| class Exit : public PrimitiveC { | |||
| class TFTensorArrayReadParser : public TFNodeParser { | |||
| public: | |||
| Exit() : PrimitiveC(kNameExit) {} | |||
| ~Exit() = default; | |||
| MS_DECLARE_PARENT(Exit, PrimitiveC); | |||
| TFTensorArrayReadParser() = default; | |||
| ~TFTensorArrayReadParser() override = default; | |||
| ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_EXIT_H_ | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_READ_PARSER_H_ | |||
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_tensor_array_scatter_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "tools/converter/ops/ops_def.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *TFTensorArrayScatterParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(DEBUG) << "TF TensorArrayScatterParser"; | |||
| if (inputs == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "inputs or output_size is nullptr"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<TensorArrayScatterV3>(); | |||
| if (prim == nullptr) { | |||
| MS_LOG(ERROR) << "prim is nullptr"; | |||
| return nullptr; | |||
| } | |||
| *output_size = 1; | |||
| for (int i = 0; i < tf_op.input_size(); i++) { | |||
| inputs->emplace_back(tf_op.input(i)); | |||
| } | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfTensorArrayScatterParser("TensorArrayScatterV3", new TFTensorArrayScatterParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SCATTER_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SCATTER_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFTensorArrayScatterParser : public TFNodeParser { | |||
| public: | |||
| TFTensorArrayScatterParser() = default; | |||
| ~TFTensorArrayScatterParser() override = default; | |||
| ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SCATTER_PARSER_H_ | |||
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_tensor_array_size_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "tools/converter/ops/ops_def.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *TFTensorArraySizeParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(DEBUG) << "TF TensorArraySizeParser"; | |||
| if (inputs == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "inputs or output_size is nullptr"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<TensorArraySizeV3>(); | |||
| if (prim == nullptr) { | |||
| MS_LOG(ERROR) << "prim is nullptr"; | |||
| return nullptr; | |||
| } | |||
| *output_size = 1; | |||
| for (int i = 0; i < tf_op.input_size(); i++) { | |||
| inputs->emplace_back(tf_op.input(i)); | |||
| } | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfTensorArraySizeParser("TensorArraySizeV3", new TFTensorArraySizeParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -13,27 +13,25 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_NEXT_ITERATION_H_ | |||
| #define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_NEXT_ITERATION_H_ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SIZE_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SIZE_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "ops/primitive_c.h" | |||
| using mindspore::ops::PrimitiveC; | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| constexpr auto kNameNextIteration = "NextIteration"; | |||
| class NextIteration : public PrimitiveC { | |||
| class TFTensorArraySizeParser : public TFNodeParser { | |||
| public: | |||
| NextIteration() : PrimitiveC(kNameNextIteration) {} | |||
| ~NextIteration() = default; | |||
| MS_DECLARE_PARENT(NextIteration, PrimitiveC); | |||
| TFTensorArraySizeParser() = default; | |||
| ~TFTensorArraySizeParser() override = default; | |||
| ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_NEXT_ITERATION_H_ | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SIZE_PARSER_H_ | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_tensor_array_write_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "tools/converter/ops/ops_def.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| ops::PrimitiveC *TFTensorArrayWriteParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(DEBUG) << "TF TensorArrayWriteParser"; | |||
| if (inputs == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "inputs or output_size is nullptr"; | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_unique<TensorArrayWriteV3>(); | |||
| if (prim == nullptr) { | |||
| MS_LOG(ERROR) << "prim is nullptr"; | |||
| return nullptr; | |||
| } | |||
| *output_size = 1; | |||
| for (int i = 0; i < tf_op.input_size(); i++) { | |||
| inputs->emplace_back(tf_op.input(i)); | |||
| } | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfTensorArrayWriteParser("TensorArrayWriteV3", new TFTensorArrayWriteParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_WRITE_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_WRITE_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
// Parser that maps the TensorFlow "TensorArrayWriteV3" node onto a lite
// TensorArrayWriteV3 primitive (registered in the .cc via TFNodeRegistrar).
class TFTensorArrayWriteParser : public TFNodeParser {
 public:
  TFTensorArrayWriteParser() = default;
  ~TFTensorArrayWriteParser() override = default;
  // Fills `inputs` with the TF node's input names, sets `*output_size` to 1 and
  // returns a new primitive (caller takes ownership); returns nullptr on error.
  ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
                         const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
                         std::vector<std::string> *inputs, int *output_size) override;
};
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_WRITE_PARSER_H_ | |||
| @@ -0,0 +1,214 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h" | |||
| #include <memory> | |||
| #include <set> | |||
| #include <functional> | |||
| #include "src/common/utils.h" | |||
| #include "utils/utils.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "securec/include/securec.h" | |||
| #include "tools/converter/ops/ops_def.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
namespace {
// Each direction of the control-flow GRU pattern contributes 4 weight/bias vars
// (gate kernel, gate bias, candidate kernel, candidate bias — see the ctor comment).
constexpr size_t kNumFwVars = 4;
constexpr size_t kNumBwVars = 4;
const auto &p1 = std::placeholders::_1;
// Wraps a primitive into a pattern node that matches any op of that primitive type.
BaseRef GetPrim(const PrimitivePtr &prim) { return std::make_shared<CondVar>(std::bind(IsOpType, p1, prim)); }
// Convenience overload: build the primitive from its registered name first.
BaseRef GetPrim(const std::string &prim_name) {
  auto prim = std::make_shared<Primitive>(prim_name);
  return GetPrim(prim);
}
}  // namespace
// Delegates to the base fusion with 4 weight/bias vars per direction; this
// variant matches the raw control-flow loop rather than a While node.
TfBidirectionGruCfFusion::TfBidirectionGruCfFusion(const std::string &name, bool multi_graph)
    : TfBidirectionGruFusion(kNumFwVars, kNumBwVars, name, multi_graph) {
  /*
   * vars for fw/bw input
   * fw:
   * 0:kernel_gate 1:bias_gate 2:cand_kernel 3:cand_bias
   * bw:
   * 0:kernel_gate 1:bias_gate 2:cand_kernel 3:cand_bias
   */
}
// Matches the body of one GRU cell step:
//   [r, z] = sigmoid([x, h] * W_gate + b_gate)   (split into rt / zt)
//   c      = tanh([x, r * h] * W_cand + b_cand)
//   h'     = z * h + (k - z) * c                 (k is a parameter node; presumably the constant 1)
// `in_ta_read` is the current input x (TensorArray read), `switch3_true` is the
// previous hidden state h, vars = {gate kernel, gate bias, cand kernel, cand bias}.
BaseRef TfBidirectionGruCfFusion::DefineGruCellPattern(const BaseRef &in_ta_read, const BaseRef &switch3_true,
                                                       const std::vector<VarPtr> &vars) const {
  auto concat = VectorRef({GetPrim(prim::kPrimConcat), in_ta_read, switch3_true});  // [x, h]
  auto matmul_enter = VectorRef({GetPrim(lite::kNameEnter), vars[0]});  // gate_kernel
  auto matmul = VectorRef({GetPrim(prim::kPrimMatMul), concat, matmul_enter});
  auto bias_enter = VectorRef({GetPrim(lite::kNameEnter), vars[1]});  // gate_bias (was mislabeled cand_bias)
  auto bias = VectorRef({GetPrim(prim::kPrimBiasAdd), matmul, bias_enter});
  auto sigmoid = VectorRef({GetPrim(prim::kPrimActivation), bias});
  auto split = VectorRef({GetPrim(prim::kPrimSplit), sigmoid});
  auto rt = VectorRef({GetPrim(prim::kPrimTupleGetItem), split, std::make_shared<Var>()});
  auto zt = VectorRef({GetPrim(prim::kPrimTupleGetItem), split, std::make_shared<Var>()});
  auto mul = VectorRef({GetPrim(prim::kPrimMulFusion), rt, switch3_true});  // r * h
  auto concat1 = VectorRef({GetPrim(prim::kPrimConcat), in_ta_read, mul});  // [x, r*h]
  auto matmul1_enter = VectorRef({GetPrim(lite::kNameEnter), vars[2]});  // cand_kernel
  auto matmul1 = VectorRef({GetPrim(prim::kPrimMatMul), concat1, matmul1_enter});
  auto bias1_enter = VectorRef({GetPrim(lite::kNameEnter), vars[3]});  // cand_bias
  auto bias1 = VectorRef({GetPrim(prim::kPrimBiasAdd), matmul1, bias1_enter});
  auto tanh = VectorRef({GetPrim(prim::kPrimActivation), bias1});
  auto sub = VectorRef({GetPrim(prim::kPrimSubFusion), std::make_shared<CondVar>(IsParameterNode), zt});  // k - z
  auto mul2 = VectorRef({GetPrim(prim::kPrimMulFusion), sub, tanh});  // (k - z) * c
  auto mul1 = VectorRef({GetPrim(prim::kPrimMulFusion), zt, switch3_true});  // z * h
  auto add = VectorRef({GetPrim(prim::kPrimAddFusion), mul1, mul2});  // new hidden state h'
  return add;
}
// Matches one direction of the TF-1.x control-flow GRU loop built from
// Enter/Merge/Switch/NextIteration/Exit nodes: the input is scattered into a
// TensorArray, read one step per iteration, run through the GRU cell pattern,
// written into an output TensorArray, then gathered and transposed back.
const BaseRef TfBidirectionGruCfFusion::DefineBidirectionRnnPattern(const BaseRef &input,
                                                                    const std::vector<VarPtr> &vars,
                                                                    const VarPtr &init_state) const {
  // in order to match cyclic graph, some node in cycle is represented by SeqVar
  // Loop bound: min(seq_len_from_shape, max over input_length_) feeds one Less condition.
  auto fw_shape1 = VectorRef({GetPrim(prim::kPrimShape), input});
  auto strided_slice = VectorRef({GetPrim(prim::kPrimStridedSlice), fw_shape1, std::make_shared<SeqVar>()});
  auto fw_max = VectorRef({GetPrim(prim::kPrimReduceFusion), input_length_, std::make_shared<Var>()});
  auto fw_maximum = VectorRef({GetPrim(prim::kPrimMaximum), std::make_shared<CondVar>(IsParameterNode), fw_max});
  auto fw_minimum = VectorRef({GetPrim(prim::kPrimMinimum), strided_slice, fw_maximum});
  auto fw_less1_enter = VectorRef({GetPrim(lite::kNameEnter), fw_minimum});
  // SeqVar:counter_merge1
  auto fw_less1 = VectorRef({GetPrim(prim::kPrimLess), std::make_shared<SeqVar>(), fw_less1_enter});
  // SeqVar:fw_merge,loop_cond
  auto fw_switch = VectorRef({GetPrim(prim::kPrimSwitch), std::make_shared<SeqVar>()});
  auto fw_switch_true = VectorRef({GetPrim(prim::kPrimTupleGetItem), fw_switch, std::make_shared<Var>()});  // identity
  auto fw_add = VectorRef({GetPrim(prim::kPrimAddFusion), fw_switch_true, std::make_shared<CondVar>(IsParameterNode)});
  auto fw_next_iter = VectorRef({GetPrim(lite::kNameNextIteration), fw_add});
  auto fw_merge_enter = VectorRef({GetPrim(lite::kNameEnter), std::make_shared<CondVar>(IsParameterNode)});
  auto fw_merge = VectorRef({GetPrim(prim::kPrimMerge), fw_merge_enter, fw_next_iter});
  auto fw_less_enter = VectorRef({GetPrim(lite::kNameEnter), strided_slice});
  auto fw_less = VectorRef({GetPrim(prim::kPrimLess), fw_merge, fw_less_enter});
  auto fw_logical_and = VectorRef({GetPrim(prim::kPrimLogicalAnd), fw_less, fw_less1});
  // SeqVar:fw_logical_and
  auto loop_cond = VectorRef({GetPrim(lite::kNameLoopCond), fw_logical_and});
  // Per-iteration counter used to index the input/output TensorArrays.
  auto fw_shape = VectorRef({GetPrim(prim::kPrimShape), input});
  auto fw_unstack_strided_slice = VectorRef({GetPrim(prim::kPrimStridedSlice), fw_shape, std::make_shared<SeqVar>()});
  auto fw_unstack_range = VectorRef({GetPrim(prim::kPrimRange), std::make_shared<CondVar>(IsParameterNode),
                                     fw_unstack_strided_slice, std::make_shared<CondVar>(IsParameterNode)});
  // SeqVar:switch1_true
  auto counter_add =
    VectorRef({GetPrim(prim::kPrimAddFusion), std::make_shared<SeqVar>(), std::make_shared<CondVar>(IsParameterNode)});
  auto counter_zero = VectorRef({GetPrim(lite::kNameEnter), std::make_shared<CondVar>(IsParameterNode)});
  auto counter_next_iter = VectorRef({GetPrim(lite::kNameNextIteration), counter_add});
  auto counter_merge1 = VectorRef({GetPrim(prim::kPrimMerge), counter_zero, counter_next_iter});
  auto counter_switch1 = VectorRef({GetPrim(prim::kPrimSwitch), counter_merge1, loop_cond});
  auto switch1_true =
    VectorRef({GetPrim(prim::kPrimTupleGetItem), counter_switch1, std::make_shared<Var>()});  // identity1
  // Input TensorArray: scatter the unstacked input once, read one timestep per iteration.
  auto in_ta = VectorRef({GetPrim(lite::kNameTensorArrayV3), strided_slice});
  auto in_ta_handle = VectorRef({GetPrim(prim::kPrimTupleGetItem), in_ta, std::make_shared<Var>()});
  auto in_ta_flow = VectorRef({GetPrim(prim::kPrimTupleGetItem), in_ta, std::make_shared<Var>()});
  auto fw_unstack_ta_scatter =
    VectorRef({GetPrim(lite::kNameTensorArrayScatterV3), in_ta_handle, fw_unstack_range, input, in_ta_flow});
  auto in_ta_enter1 = VectorRef({GetPrim(lite::kNameEnter), fw_unstack_ta_scatter});
  auto in_ta_enter = VectorRef({GetPrim(lite::kNameEnter), in_ta_handle});
  auto in_ta_read = VectorRef({GetPrim(lite::kNameTensorArrayReadV3), in_ta_enter, switch1_true, in_ta_enter1});
  // Past input_length_ the Selects keep the old state / write zeros instead of the cell output.
  auto greater_equal_enter = VectorRef({GetPrim(lite::kNameEnter), input_length_});
  auto greater_equal = VectorRef({GetPrim(prim::kPrimGreaterEqual), switch1_true, greater_equal_enter});
  auto select1 = VectorRef({GetPrim(prim::kPrimSelect), greater_equal, std::make_shared<SeqVar>()});  // select h
  auto next_iteration3 = VectorRef({GetPrim(lite::kNameNextIteration), select1});
  auto enter3 = VectorRef({GetPrim(lite::kNameEnter), init_state});
  auto merge3 = VectorRef({GetPrim(prim::kPrimMerge), enter3, next_iteration3});
  auto switch3 = VectorRef({GetPrim(prim::kPrimSwitch), merge3, loop_cond});
  auto switch3_true = VectorRef({GetPrim(prim::kPrimTupleGetItem), switch3, std::make_shared<Var>()});  // identity3
  auto rnn_cell_out = DefineGruCellPattern(in_ta_read, switch3_true, vars);
  // Output TensorArray: write each (masked) cell output, then gather all steps after Exit.
  auto out_ta = VectorRef({GetPrim(lite::kNameTensorArrayV3), strided_slice});
  auto out_ta_handle = VectorRef({GetPrim(prim::kPrimTupleGetItem), out_ta, std::make_shared<Var>()});
  auto out_ta_flow = VectorRef({GetPrim(prim::kPrimTupleGetItem), out_ta, std::make_shared<Var>()});
  auto out_ta_enter = VectorRef({GetPrim(lite::kNameEnter), out_ta_handle});
  auto switch2_true = VectorRef({GetPrim(prim::kPrimTupleGetItem), std::make_shared<SeqVar>()});  // cycle
  auto concat1 = VectorRef({GetPrim(prim::kPrimConcat), std::make_shared<SeqVar>()});
  auto zeros1 = VectorRef({GetPrim(prim::kPrimFill), std::make_shared<CondVar>(IsParameterNode), concat1});
  auto select_enter = VectorRef({GetPrim(lite::kNameEnter), zeros1});
  auto select = VectorRef({GetPrim(prim::kPrimSelect), greater_equal, select_enter, rnn_cell_out});  // select x
  auto ta_write = VectorRef({GetPrim(lite::kNameTensorArrayWriteV3), out_ta_enter, switch1_true, select, switch2_true});
  auto enter2 = VectorRef({GetPrim(lite::kNameEnter), out_ta_flow});
  auto next_iter2 = VectorRef({GetPrim(lite::kNameNextIteration), ta_write});
  auto merge2 = VectorRef({GetPrim(prim::kPrimMerge), enter2, next_iter2});
  auto switch2 = VectorRef({GetPrim(prim::kPrimSwitch), merge2, loop_cond});
  auto switch2_false = VectorRef({GetPrim(prim::kPrimTupleGetItem), switch2, std::make_shared<Var>()});
  auto exit2 = VectorRef({GetPrim(lite::kNameExit), switch2_false});
  auto ta_size = VectorRef({GetPrim(lite::kNameTensorArraySizeV3), out_ta_handle, exit2});
  auto range = VectorRef({GetPrim(prim::kPrimRange), std::make_shared<Var>(), ta_size, std::make_shared<Var>()});
  auto tensor_array_gather = VectorRef({GetPrim(lite::kNameTensorArrayGatherV3), out_ta_handle, range, exit2});
  // Transpose the gathered output back to the caller's expected layout.
  auto range1 = VectorRef({GetPrim(prim::kPrimRange), std::make_shared<SeqVar>()});
  auto concat2 = VectorRef({GetPrim(prim::kPrimConcat), std::make_shared<CondVar>(IsParameterNode), range1});
  auto fw_out_trans = VectorRef({GetPrim(prim::kPrimTranspose), tensor_array_gather, concat2});
  return fw_out_trans;
}
// Top-level pattern: forward RNN over transpose_input_, backward RNN over the
// transposed ReverseSequence of input_ (its output reversed back again), with
// the two direction outputs joined by the final Concat — the match root.
const BaseRef TfBidirectionGruCfFusion::DefinePattern() const {
  const auto fw_out_trans = DefineBidirectionRnnPattern(transpose_input_, fw_vars_, fw_init_state_);
  auto bw_reverse_in = VectorRef({GetPrim(prim::kPrimReverseSequence), input_, input_length_});
  auto bw_range = VectorRef({GetPrim(prim::kPrimRange), std::make_shared<SeqVar>()});
  auto bw_concat = VectorRef({GetPrim(prim::kPrimConcat), std::make_shared<CondVar>(IsParameterNode), bw_range});
  auto bw_transpose = VectorRef({GetPrim(prim::kPrimTranspose), bw_reverse_in, bw_concat});
  auto bw_out_trans = DefineBidirectionRnnPattern(bw_transpose, bw_vars_, bw_init_state_);
  auto bw_reverse_out = VectorRef({GetPrim(prim::kPrimReverseSequence), bw_out_trans, input_length_});
  auto concat = VectorRef({GetPrim(prim::kPrimConcat), fw_out_trans, bw_reverse_out});
  return concat;
}
| const AnfNodePtr TfBidirectionGruCfFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &concat_node, | |||
| const EquivPtr &equiv) const { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(concat_node != nullptr); | |||
| MS_LOG(DEBUG) << "bidirection tf gru fusion pass"; | |||
| if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(concat_node) != lite::RET_OK) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| return nullptr; | |||
| } | |||
| auto transpose_input = utils::cast<AnfNodePtr>((*equiv)[transpose_input_]); | |||
| MS_ASSERT(transpose_input != nullptr); | |||
| const std::string gru_name = "gru_" + concat_node->fullname_with_scope(); | |||
| auto gru_node = CreateBiDirectionGruNode(func_graph, transpose_input, equiv, gru_name, 0); | |||
| if (gru_node == nullptr) { | |||
| return nullptr; | |||
| } | |||
| if (TfliteLstmCellFusion::SetAbstractTuple(gru_node, 2) != RET_OK) { | |||
| return nullptr; | |||
| } | |||
| auto get_item_node = TfliteLstmCellFusion::CreateOutputGetItem(func_graph, gru_node, 0); | |||
| if (get_item_node == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto output_node = GetPostProcessNode(func_graph, get_item_node, gru_node->fullname_with_scope()); | |||
| MS_LOG(INFO) << "gru node:" << gru_node->fullname_with_scope() << " fusion success"; | |||
| return output_node; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_CF_FUSION_H_ | |||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_CF_FUSION_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h" | |||
| #include "schema/inner/model_generated.h" | |||
| #include "src/param_value_lite.h" | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "utils/utils.h" | |||
| #include "include/errorcode.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
// fuse tf 1.x bidirection_gru into MSLITE GRU
// Control-flow variant: matches the loop built from Enter/Merge/Switch/
// NextIteration/Exit and TensorArray ops, reusing the base class's node
// construction (CreateBiDirectionGruNode etc.).
class TfBidirectionGruCfFusion : public TfBidirectionGruFusion {
 public:
  explicit TfBidirectionGruCfFusion(const std::string &name = "tf_bidirection_gru_cf_fusion", bool multi_graph = true);
  ~TfBidirectionGruCfFusion() override = default;
  // Root pattern: the Concat joining the forward and backward RNN outputs.
  const BaseRef DefinePattern() const override;
  // Replaces the matched subgraph with a single bidirectional GRU node; nullptr on failure.
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 private:
  // One GRU cell step; vars = {gate kernel, gate bias, cand kernel, cand bias}.
  BaseRef DefineGruCellPattern(const BaseRef &in_ta_read, const BaseRef &switch3_true,
                               const std::vector<VarPtr> &vars) const;
  // One direction's full control-flow loop (TensorArray scatter/read/write/gather).
  const BaseRef DefineBidirectionRnnPattern(const BaseRef &input, const std::vector<VarPtr> &vars,
                                            const VarPtr &init_state) const;
};
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_CF_FUSION_H_ | |||
| @@ -13,7 +13,7 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h" | |||
| #include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h" | |||
| #include <memory> | |||
| #include <functional> | |||
| #include "ops/concat.h" | |||
| @@ -24,32 +24,21 @@ | |||
| #include "ops/transpose.h" | |||
| #include "src/common/utils.h" | |||
| #include "utils/utils.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "securec/include/securec.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| constexpr size_t kWhileUniqInputsLength = 6; | |||
| constexpr size_t kCondNodesNum = 12; | |||
| constexpr size_t kCondCNodesNum = 4; | |||
| constexpr size_t kBodyNodesNum = 69; | |||
| constexpr size_t kBodyCNodesNum = 25; | |||
| const auto &p1 = std::placeholders::_1; | |||
// True iff the pattern item wraps a Parameter node.
bool IsParameterNode(const BaseRef &n) { return utils::isa<ParameterPtr>(n); }
// True iff the pattern item wraps an AnfNode whose primitive type matches `prim`.
bool IsOpType(const BaseRef &n, const PrimitivePtr &prim) {
  if (utils::isa<AnfNodePtr>(n)) {
    auto anf_node = utils::cast<AnfNodePtr>(n);
    return CheckPrimitiveType(anf_node, prim);
  }
  return false;
}
| } // namespace | |||
| BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name, bool multigraph) | |||
| : PatternProcessPass(name, multigraph) { | |||
| TfBidirectionGruFusion::TfBidirectionGruFusion(int num_fw_vars, int num_bw_vars, const std::string &name, | |||
| bool multi_graph) | |||
| : PatternProcessPass(name, multi_graph) { | |||
| /* | |||
| * vars for while input | |||
| * fw_while_inputs: | |||
| @@ -57,8 +46,10 @@ BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name, | |||
| * bw_while_inputs: | |||
| * 0:cond 1:body 2:kernel_gate 3:bias_gate 4:cand_kernel 5:cand_bias | |||
| */ | |||
| for (size_t i = 0; i < kWhileUniqInputsLength; ++i) { | |||
| for (int i = 0; i < num_fw_vars; ++i) { | |||
| fw_vars_.emplace_back(std::make_shared<Var>()); | |||
| } | |||
| for (int i = 0; i < num_bw_vars; ++i) { | |||
| bw_vars_.emplace_back(std::make_shared<Var>()); | |||
| } | |||
| input_ = std::make_shared<Var>(); | |||
| @@ -68,7 +59,7 @@ BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name, | |||
| bw_init_state_ = std::make_shared<Var>(); | |||
| } | |||
| const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const { | |||
| const BaseRef TfBidirectionGruFusion::DefinePattern() const { | |||
| // forward | |||
| auto fw_reduce = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion)), | |||
| input_length_, std::make_shared<CondVar>(IsParameterNode)}); | |||
| @@ -134,7 +125,7 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const { | |||
| return concat; | |||
| } | |||
| AnfNodePtr BiDirectionTfGruCellFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const { | |||
| AnfNodePtr TfBidirectionGruFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const { | |||
| auto is_parameter1 = std::make_shared<CondVar>(IsParameterNode); | |||
| auto is_parameter2 = std::make_shared<CondVar>(IsParameterNode); | |||
| auto is_parameter3 = std::make_shared<CondVar>(IsParameterNode); | |||
| @@ -152,7 +143,7 @@ AnfNodePtr BiDirectionTfGruCellFusion::GetCondGraphPattern(const PrimitiveVarMap | |||
| return pattern; | |||
| } | |||
| AnfNodePtr BiDirectionTfGruCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const { | |||
| AnfNodePtr TfBidirectionGruFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const { | |||
| std::vector<CondVarPtr> placeholders; | |||
| for (int i = 0; i < 13; ++i) { | |||
| placeholders.emplace_back(std::make_shared<CondVar>(IsParameterNode)); | |||
| @@ -206,7 +197,7 @@ AnfNodePtr BiDirectionTfGruCellFusion::GetBodyGraphPattern(const PrimitiveVarMap | |||
| return pattern; | |||
| } | |||
| ParamValueLitePtr BiDirectionTfGruCellFusion::GetDefaultParamValue(const AnfNodePtr ¶meter_anf) const { | |||
| ParamValueLitePtr TfBidirectionGruFusion::GetDefaultParamValue(const AnfNodePtr ¶meter_anf) const { | |||
| MS_ASSERT(parameter_anf != nullptr); | |||
| if (!utils::isa<ParameterPtr>(parameter_anf)) { | |||
| MS_LOG(DEBUG) << "parameter_anf is not ParameterPtr"; | |||
| @@ -221,9 +212,9 @@ ParamValueLitePtr BiDirectionTfGruCellFusion::GetDefaultParamValue(const AnfNode | |||
| return param_value; | |||
| } | |||
| STATUS BiDirectionTfGruCellFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf, | |||
| const AnfNodePtr &bw_cand_kernel_anf, int *input_size, | |||
| int *hidden_size) const { | |||
| STATUS TfBidirectionGruFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf, | |||
| const AnfNodePtr &bw_cand_kernel_anf, int *input_size, | |||
| int *hidden_size) const { | |||
| MS_ASSERT(fw_cand_kernel != nullptr); | |||
| MS_ASSERT(bw_cand_kernel != nullptr); | |||
| MS_ASSERT(input_size != nullptr); | |||
| @@ -256,9 +247,9 @@ STATUS BiDirectionTfGruCellFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_ca | |||
| return RET_OK; | |||
| } | |||
| ParameterPtr BiDirectionTfGruCellFusion::AddDefaultParameter(const FuncGraphPtr &func_graph, const std::string &name, | |||
| const std::vector<int> &shape, const TypeId type, | |||
| void **tensor_data) const { | |||
| ParameterPtr TfBidirectionGruFusion::AddDefaultParameter(const FuncGraphPtr &func_graph, const std::string &name, | |||
| const std::vector<int> &shape, const TypeId type, | |||
| void **tensor_data) const { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(tensor_data != nullptr); | |||
| auto parameter = func_graph->add_parameter(); | |||
| @@ -300,9 +291,8 @@ ParameterPtr BiDirectionTfGruCellFusion::AddDefaultParameter(const FuncGraphPtr | |||
| return parameter; | |||
| } | |||
| void BiDirectionTfGruCellFusion::CopyFlattenMatData(const float *mat, const int R, const int C, const int r0, | |||
| const int r1, const int c0, const int c1, float *data, | |||
| bool t) const { | |||
| void TfBidirectionGruFusion::CopyFlattenMatData(const float *mat, const int R, const int C, const int r0, const int r1, | |||
| const int c0, const int c1, float *data, bool t) const { | |||
| MS_ASSERT(mat != nullptr); | |||
| MS_ASSERT(data != nullptr); | |||
| MS_ASSERT(0 <= r0 && r0 < r1 && r1 <= R); | |||
| @@ -320,9 +310,9 @@ void BiDirectionTfGruCellFusion::CopyFlattenMatData(const float *mat, const int | |||
| } | |||
| } | |||
| STATUS BiDirectionTfGruCellFusion::ConvertWeightData(const AnfNodePtr &gate_weight, const AnfNodePtr &cand_weight, | |||
| const int input_size, const int hidden_size, | |||
| float *gate_tensor_data, float *recu_tensor_data) const { | |||
| STATUS TfBidirectionGruFusion::ConvertWeightData(const AnfNodePtr &gate_weight, const AnfNodePtr &cand_weight, | |||
| const int input_size, const int hidden_size, float *gate_tensor_data, | |||
| float *recu_tensor_data) const { | |||
| MS_ASSERT(gate_weight != nullptr); | |||
| MS_ASSERT(cand_weight != nullptr); | |||
| MS_ASSERT(gate_tensor_data != nullptr); | |||
| @@ -375,8 +365,8 @@ STATUS BiDirectionTfGruCellFusion::ConvertWeightData(const AnfNodePtr &gate_weig | |||
| return RET_OK; | |||
| } | |||
| STATUS BiDirectionTfGruCellFusion::ConvertBiasData(const AnfNodePtr &gate_bias, const AnfNodePtr &cand_bias, | |||
| const int hidden_size, float *tensor_data) const { | |||
| STATUS TfBidirectionGruFusion::ConvertBiasData(const AnfNodePtr &gate_bias, const AnfNodePtr &cand_bias, | |||
| const int hidden_size, float *tensor_data) const { | |||
| MS_ASSERT(bias != nullptr); | |||
| MS_ASSERT(tensor_data != nullptr); | |||
| std::vector<int> gate_shape{hidden_size * 2}; | |||
| @@ -407,10 +397,9 @@ STATUS BiDirectionTfGruCellFusion::ConvertBiasData(const AnfNodePtr &gate_bias, | |||
| return RET_OK; | |||
| } | |||
| CNodePtr BiDirectionTfGruCellFusion::GetStackedHiddenState(const FuncGraphPtr &func_graph, | |||
| const AnfNodePtr &fw_init_state, | |||
| const AnfNodePtr &bw_init_state, | |||
| const std::string base_name) const { | |||
| CNodePtr TfBidirectionGruFusion::GetStackedHiddenState(const FuncGraphPtr &func_graph, const AnfNodePtr &fw_init_state, | |||
| const AnfNodePtr &bw_init_state, | |||
| const std::string base_name) const { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(fw_init_state != nullptr); | |||
| MS_ASSERT(bw_init_state != nullptr); | |||
| @@ -424,35 +413,32 @@ CNodePtr BiDirectionTfGruCellFusion::GetStackedHiddenState(const FuncGraphPtr &f | |||
| return new_node; | |||
| } | |||
| CNodePtr BiDirectionTfGruCellFusion::CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, | |||
| const EquivPtr &equiv, const EquivPtr &fw_body_equiv, | |||
| const EquivPtr &bw_body_equiv, | |||
| const std::string &base_name) const { | |||
| CNodePtr TfBidirectionGruFusion::CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, | |||
| const EquivPtr &equiv, const std::string &base_name, | |||
| int var_offset) const { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(input != nullptr); | |||
| MS_ASSERT(equiv != nullptr); | |||
| MS_ASSERT(fw_body_equiv != nullptr); | |||
| MS_ASSERT(bw_body_equiv != nullptr); | |||
| auto gru_prim = std::make_shared<ops::GRU>(); | |||
| gru_prim->set_bidirectional(true); | |||
| auto value_node = NewValueNode(gru_prim); | |||
| auto fw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[2]]); | |||
| auto fw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[var_offset]]); | |||
| MS_ASSERT(fw_gate_kernel != nullptr); | |||
| auto fw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[3]]); | |||
| auto fw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[var_offset + 1]]); | |||
| MS_ASSERT(fw_gate_bias != nullptr); | |||
| auto fw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[4]]); | |||
| auto fw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[var_offset + 2]]); | |||
| MS_ASSERT(fw_cand_kernel != nullptr); | |||
| auto fw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[5]]); | |||
| auto fw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[var_offset + 3]]); | |||
| MS_ASSERT(fw_cand_bias != nullptr); | |||
| auto bw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[2]]); | |||
| auto bw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[var_offset]]); | |||
| MS_ASSERT(bw_gate_kernel != nullptr); | |||
| auto bw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[3]]); | |||
| auto bw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[var_offset + 1]]); | |||
| MS_ASSERT(bw_gate_bias != nullptr); | |||
| auto bw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[4]]); | |||
| auto bw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[var_offset + 2]]); | |||
| MS_ASSERT(bw_cand_kernel != nullptr); | |||
| auto bw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[5]]); | |||
| auto bw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[var_offset + 3]]); | |||
| MS_ASSERT(bw_cand_bias != nullptr); | |||
| auto fw_init_state = utils::cast<AnfNodePtr>((*equiv)[fw_init_state_]); | |||
| @@ -522,8 +508,8 @@ CNodePtr BiDirectionTfGruCellFusion::CreateBiDirectionGruNode(const FuncGraphPtr | |||
| return new_node; | |||
| } | |||
| CNodePtr BiDirectionTfGruCellFusion::GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output, | |||
| const std::string base_name) const { | |||
| CNodePtr TfBidirectionGruFusion::GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output, | |||
| const std::string base_name) const { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(gru_output != nullptr); | |||
| auto split_prim = std::make_shared<ops::Split>(); | |||
| @@ -571,8 +557,8 @@ CNodePtr BiDirectionTfGruCellFusion::GetPostProcessNode(const FuncGraphPtr &func | |||
| return transpose_new_node; | |||
| } | |||
| const AnfNodePtr BiDirectionTfGruCellFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &concat_node, | |||
| const EquivPtr &equiv) const { | |||
| const AnfNodePtr TfBidirectionGruFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &concat_node, | |||
| const EquivPtr &equiv) const { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(concat_node != nullptr); | |||
| MS_LOG(DEBUG) << "bidirection tf gru fusion pass"; | |||
| @@ -628,7 +614,7 @@ const AnfNodePtr BiDirectionTfGruCellFusion::Process(const FuncGraphPtr &func_gr | |||
| } | |||
| const std::string gru_name = "gru_" + concat_node->fullname_with_scope(); | |||
| auto gru_node = CreateBiDirectionGruNode(func_graph, transpose_input, equiv, fw_body_equiv, bw_body_equiv, gru_name); | |||
| auto gru_node = CreateBiDirectionGruNode(func_graph, transpose_input, equiv, gru_name, 2); | |||
| if (gru_node == nullptr) { | |||
| return nullptr; | |||
| } | |||
| @@ -13,12 +13,13 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BIDIRECTION_TF_GRU_CELL_FUSION_H_ | |||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BIDIRECTION_TF_GRU_CELL_FUSION_H_ | |||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_FUSION_H_ | |||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_FUSION_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "schema/inner/model_generated.h" | |||
| #include "src/param_value_lite.h" | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| @@ -27,22 +28,26 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class BiDirectionTfGruCellFusion : public PatternProcessPass { | |||
| constexpr size_t kWhileUniqInputsLength = 6; | |||
| // fuse tf 2.x bidirection_gru into MSLITE GRU | |||
| class TfBidirectionGruFusion : public PatternProcessPass { | |||
| public: | |||
| explicit BiDirectionTfGruCellFusion(const std::string &name = "bidirection_tf_gru_cell_fusion", | |||
| bool multigraph = true); | |||
| ~BiDirectionTfGruCellFusion() override = default; | |||
| explicit TfBidirectionGruFusion(int num_fw_vars = kWhileUniqInputsLength, int num_bw_vars = kWhileUniqInputsLength, | |||
| const std::string &name = "tf_bidirection_gru_fusion", bool multi_graph = true); | |||
| ~TfBidirectionGruFusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| protected: | |||
| virtual AnfNodePtr GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const; | |||
| CNodePtr CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const EquivPtr &equiv, | |||
| const std::string &base_name, int var_offset) const; | |||
| CNodePtr GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output, | |||
| const std::string base_name) const; | |||
| private: | |||
| AnfNodePtr GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const; | |||
| CNodePtr CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const EquivPtr &equiv, | |||
| const EquivPtr &fw_body_equiv, const EquivPtr &bw_body_equiv, | |||
| const std::string &base_name) const; | |||
| ParamValueLitePtr GetDefaultParamValue(const AnfNodePtr ¶meter_anf) const; | |||
| lite::STATUS GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf, const AnfNodePtr &bw_cand_kernel_anf, | |||
| int *input_size, int *hidden_size) const; | |||
| @@ -56,10 +61,8 @@ class BiDirectionTfGruCellFusion : public PatternProcessPass { | |||
| const int c1, float *data, bool t = false) const; | |||
| CNodePtr GetStackedHiddenState(const FuncGraphPtr &func_graph, const AnfNodePtr &fw_init_state, | |||
| const AnfNodePtr &bw_init_state, const std::string base_name) const; | |||
| CNodePtr GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output, | |||
| const std::string base_name) const; | |||
| private: | |||
| protected: | |||
| std::vector<VarPtr> fw_vars_; | |||
| std::vector<VarPtr> bw_vars_; | |||
| VarPtr input_; | |||
| @@ -68,7 +71,16 @@ class BiDirectionTfGruCellFusion : public PatternProcessPass { | |||
| VarPtr fw_init_state_; | |||
| VarPtr bw_init_state_; | |||
| }; | |||
| inline bool IsParameterNode(const BaseRef &n) { return utils::isa<ParameterPtr>(n); } | |||
| inline bool IsOpType(const BaseRef &n, const PrimitivePtr &prim) { | |||
| if (utils::isa<AnfNodePtr>(n)) { | |||
| auto anf_node = utils::cast<AnfNodePtr>(n); | |||
| return CheckPrimitiveType(anf_node, prim); | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BIDIRECTION_TF_GRU_CELL_FUSION_H_ | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_FUSION_H_ | |||
| @@ -0,0 +1,249 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/graph/functionalize_cond.h" | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <deque> | |||
| #include <unordered_set> | |||
| #include "include/errorcode.h" | |||
| #include "ops/make_tuple.h" | |||
| #include "tools/converter/ops/ops_def.h" | |||
| #include "ops/return.h" | |||
| namespace mindspore::opt { | |||
| STATUS FunctionalizeCond::GetSwitchBranchType(const CNodePtr &switch_cnode, BranchType *branch_type) { | |||
| MS_ASSERT(switch_cnode != nullptr); | |||
| MS_ASSERT(branch_type != nullptr); | |||
| auto manager = fg_->manager(); | |||
| if (manager == nullptr) { | |||
| MS_LOG(ERROR) << "manager is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto node_users = manager->node_users()[switch_cnode]; | |||
| if (node_users.size() != 1) { // only one output of switch is referenced in cond | |||
| MS_LOG(ERROR) << "switch's node users is not correct"; | |||
| return RET_ERROR; | |||
| } | |||
| auto node_user = node_users.front(); | |||
| auto tuple_get_item = node_user.first; | |||
| if (!utils::isa<CNodePtr>(tuple_get_item) || !CheckPrimitiveType(tuple_get_item, prim::kPrimTupleGetItem)) { | |||
| MS_LOG(ERROR) << "switch's node user is not TupleGetItem"; | |||
| return RET_ERROR; | |||
| } | |||
| auto tuple_get_item_cnode = utils::cast<CNodePtr>(tuple_get_item); | |||
| auto idx = GetTupleGetItemOutIndex(tuple_get_item_cnode); | |||
| if (idx == 0) { | |||
| *branch_type = kElseBranch; | |||
| } else if (idx == 1) { | |||
| *branch_type = kThenBranch; | |||
| } else { | |||
| MS_LOG(ERROR) << "wrong tuple_get_item index"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
// Walk backwards from `root_node` over CNode data inputs and attach every
// reachable node to `graph` (the branch subgraph under construction).
// Switch nodes delimit the branch: they are validated against `branch_type`
// and never traversed through, since their inputs belong to the outer graph.
// Returns RET_OK on success, RET_ERROR if a Switch of the wrong branch type
// is encountered.
STATUS FunctionalizeCond::BranchSubGraphAddNodes(const FuncGraphPtr &graph, const AnfNodePtr &root_node,
                                                 BranchType branch_type) {
  // Breadth-first traversal; `vis` records nodes already processed.
  std::deque<AnfNodePtr> q;
  std::unordered_set<AnfNodePtr> vis;
  q.push_back(root_node);
  while (!q.empty()) {
    auto node = q.front();
    q.pop_front();
    vis.insert(node);
    if (FunctionalizeControlOpPass::IsSwitch(node)) {
      // Boundary of the branch: check the Switch selects this branch, stop.
      auto cnode = utils::cast<CNodePtr>(node);
      BranchType this_type;
      if (GetSwitchBranchType(cnode, &this_type) != RET_OK || this_type != branch_type) {
        MS_LOG(ERROR) << "switch node in branch " << branch_type << " is not correct";
        return RET_ERROR;
      }
      continue;
    }
    // Re-home the node into the branch graph; parameters must be registered
    // via add_parameter so the subgraph signature stays consistent.
    if (utils::isa<ParameterPtr>(node)) {
      graph->add_parameter(node->cast<ParameterPtr>());
    } else {
      graph->AddNode(node);
    }
    node->set_func_graph(graph);
    if (utils::isa<CNodePtr>(node)) {
      auto cnode = utils::cast<CNodePtr>(node);
      // Input 0 is the primitive value node; only data inputs are traversed.
      for (size_t i = 1; i < cnode->inputs().size(); i++) {
        auto inputi = cnode->input(i);
        // NOTE(review): `vis` is only updated when a node is popped, so the
        // same node can be enqueued (and processed) more than once. Benign
        // only if AddNode/set_func_graph are idempotent -- confirm.
        if (vis.find(inputi) == vis.end()) {
          q.push_back(cnode->input(i));
        }
      }
    }
  }
  return RET_OK;
}
| int FunctionalizeCond::PosInInputNodes(const CNodePtr &node) { | |||
| auto index = std::find(input_nodes_.begin(), input_nodes_.end(), node); | |||
| if (index == input_nodes_.end()) { | |||
| input_nodes_.push_back(node); | |||
| return input_nodes_.size() - 1; | |||
| } | |||
| return index - input_nodes_.begin(); | |||
| } | |||
| STATUS FunctionalizeCond::IdentifySubgraphInput(const FuncGraphPtr &graph, std::string graph_name) { | |||
| std::vector<AnfNodePtr> nodes_need_drop{}; | |||
| for (auto &cnode : graph->GetOrderedCnodes()) { | |||
| for (auto &input_node : cnode->inputs()) { | |||
| if (FunctionalizeControlOpPass::IsSwitch(input_node)) { | |||
| auto switch_node = input_node->cast<CNodePtr>(); | |||
| auto switch_input = utils::cast<CNodePtr>(switch_node->input(1)); | |||
| auto pos = PosInInputNodes(switch_input); | |||
| nodes_need_drop.push_back(cnode); | |||
| pred_nodes_.push_back(switch_node->input(2)); | |||
| // set parameter | |||
| auto parameter = graph->add_parameter(); | |||
| parameter->set_abstract(cnode->abstract()); | |||
| // hardcode for subgraph input name | |||
| parameter->set_name(graph_name + "_input_" + std::to_string(pos) + "_parameter"); | |||
| // replace switch | |||
| auto manager = fg_->manager(); | |||
| auto node_users = manager->node_users()[cnode]; | |||
| for (auto &node_user : node_users) { | |||
| if (graph->nodes().contains(node_user.first)) { | |||
| manager->SetEdge(node_user.first, node_user.second, parameter); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
// Build the subgraph for one branch of the cond. `node` is the Merge input
// rooting the branch; `branch_type` says which Switch output the branch hangs
// off. Returns the populated graph, or nullptr on failure.
FuncGraphPtr FunctionalizeCond::CreateBranchGraph(const AnfNodePtr &node, std::string name, BranchType branch_type) {
  auto graph = FunctionalizeControlOpPass::NewFuncGraph(name, mindspore::lite::converter::FmkType_TF);
  if (graph == nullptr) {
    MS_LOG(ERROR) << "new graph Partial Node return nullptr";
    return nullptr;
  }
  // Share the outer graph's manager so edge rewrites stay coherent.
  graph->set_manager(fg_->manager());
  auto status = BranchSubGraphAddNodes(graph, node, branch_type);
  if (status != RET_OK) {
    return nullptr;
  }
  // When the root is itself a Switch, the branch is empty (identity) and gets
  // no Return node; otherwise wire `node` up as the single output.
  if (!CheckPrimitiveType(node, prim::kPrimSwitch)) {  // graph is not empty
    auto return_prim_ptr = std::make_shared<ops::Return>();
    if (return_prim_ptr == nullptr) {
      MS_LOG(ERROR) << "GetReturnPrim return nullptr";
      return nullptr;
    }
    auto value_node = NewValueNode(return_prim_ptr);
    std::vector<AnfNodePtr> op_inputs{value_node, node};  // If subgraph only has one output tensor
    auto return_cnode = graph->NewCNode(op_inputs);
    // NOTE(review): NewCNode can return nullptr; it is dereferenced below
    // without a check -- confirm this cannot fail here.
    return_cnode->set_fullname_with_scope(name + "-return");
    return_cnode->set_func_graph(graph);
    graph->set_return(return_cnode);
    graph->output()->cast<CNodePtr>()->set_fullname_with_scope(name + "_output_0_cnode");
  }
  return graph;
}
| CNodePtr FunctionalizeCond::CreateNewIf(const FuncGraphPtr &else_branch, const FuncGraphPtr &then_branch) { | |||
| MS_ASSERT(else_branch != nullptr); | |||
| MS_ASSERT(then_branch != nullptr); | |||
| auto if_primc = std::make_shared<mindspore::lite::If>(); | |||
| if (if_primc == nullptr) { | |||
| MS_LOG(ERROR) << "new if_primitive failed"; | |||
| return nullptr; | |||
| } | |||
| auto if_value_node = NewValueNode(if_primc); | |||
| if (if_value_node == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto then_value_node = NewValueNode(then_branch); | |||
| auto else_value_node = NewValueNode(else_branch); | |||
| std::vector<AnfNodePtr> if_op_inputs = {if_value_node, then_value_node, else_value_node, pred_node_}; | |||
| std::copy(input_nodes_.begin(), input_nodes_.end(), std::back_inserter(if_op_inputs)); | |||
| return fg_->NewCNode(if_op_inputs); | |||
| } | |||
| STATUS FunctionalizeCond::VerifyPredictNode() { | |||
| if (pred_nodes_.empty()) { | |||
| return RET_ERROR; | |||
| } | |||
| for (size_t i = 1; i < pred_nodes_.size(); ++i) { | |||
| if (pred_nodes_[i] != pred_nodes_[0]) { | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| if (!utils::isa<CNodePtr>(pred_nodes_[0])) { | |||
| return RET_ERROR; | |||
| } | |||
| pred_node_ = utils::cast<CNodePtr>(pred_nodes_[0]); | |||
| return RET_OK; | |||
| } | |||
// Top-level driver: rewrite the Switch/Merge cluster anchored at merge_node_
// into a functional If node and splice the If in place of the Merge.
// Returns RET_OK on success; any sub-step failure is propagated.
STATUS FunctionalizeCond::Process() {
  // A Merge CNode has exactly two data inputs (plus the primitive at index 0):
  // input(1) roots the else branch, input(2) the then branch.
  if (fg_ == nullptr || merge_node_ == nullptr || merge_node_->inputs().size() != 3) {
    MS_LOG(ERROR) << "fg or merge is not correct";
    return RET_ERROR;
  }
  auto else_branch_name = merge_node_->fullname_with_scope() + "-partial-if-else";
  // NOTE(review): "-partial-then-else" looks like a typo for "-partial-if-then";
  // confirm the intended naming scheme before relying on these graph names.
  auto then_branch_name = merge_node_->fullname_with_scope() + "-partial-then-else";
  auto else_branch = CreateBranchGraph(merge_node_->input(1), else_branch_name, kElseBranch);
  if (else_branch == nullptr) {
    MS_LOG(ERROR) << "create else branch failed";
    return RET_ERROR;
  }
  auto then_branch = CreateBranchGraph(merge_node_->input(2), then_branch_name, kThenBranch);
  if (then_branch == nullptr) {
    MS_LOG(ERROR) << "create then branch failed";
    return RET_ERROR;
  }
  // Turn Switch-fed values into formal parameters of each branch subgraph;
  // this also populates input_nodes_ and pred_nodes_.
  auto status = IdentifySubgraphInput(else_branch, else_branch_name);
  if (status != RET_OK) {
    return status;
  }
  status = IdentifySubgraphInput(then_branch, then_branch_name);
  if (status != RET_OK) {
    return status;
  }
  // All Switches must share a single predicate; cache it in pred_node_.
  status = VerifyPredictNode();
  if (status != RET_OK) {
    return status;
  }
  auto if_node = CreateNewIf(else_branch, then_branch);
  if (if_node == nullptr) {
    MS_LOG(ERROR) << "create if node error";
    return RET_ERROR;
  }
  // NOTE(review): merge_node_->abstract() may be nullptr, which would crash
  // on Clone() -- confirm Merge nodes always carry an abstract here.
  if_node->set_abstract(merge_node_->abstract()->Clone());
  // Redirect every user of the Merge to the new If node.
  auto manager = fg_->manager();
  auto node_users = manager->node_users()[merge_node_];
  for (auto &node_user : node_users) {
    manager->SetEdge(node_user.first, node_user.second, if_node);
  }
  return RET_OK;
}
| } // namespace mindspore::opt | |||
| @@ -0,0 +1,58 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_FUNCTIONALIZE_COND_H_ | |||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_FUNCTIONALIZE_COND_H_ | |||
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "backend/optimizer/common/pass.h"
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/optimizer/graph/functionalize_control_op_pass.h"
| using mindspore::lite::converter::FmkType; | |||
| namespace mindspore::opt { | |||
| typedef enum { kThenBranch = 0, kElseBranch = 1 } BranchType; | |||
| // Functionalize all the switch-merge nodes of a loop-free graph into single switch node. | |||
| // Precondition: While loops must have been functionalized. | |||
| class FunctionalizeCond { | |||
| public: | |||
| FunctionalizeCond(FuncGraphPtr fg, CNodePtr merge_node) : fg_(fg), merge_node_(merge_node) {} | |||
| STATUS Process(); | |||
| private: | |||
| STATUS GetSwitchBranchType(const CNodePtr &switch_cnode, BranchType *branch_type); | |||
| STATUS BranchSubGraphAddNodes(const FuncGraphPtr &graph, const AnfNodePtr &root_node, BranchType branch_type); | |||
| FuncGraphPtr CreateBranchGraph(const AnfNodePtr &node, std::string name, BranchType branch_type); | |||
| int PosInInputNodes(const CNodePtr &node); | |||
| STATUS IdentifySubgraphInput(const FuncGraphPtr &graph, std::string graph_name); | |||
| CNodePtr CreateNewIf(const FuncGraphPtr &else_branch, const FuncGraphPtr &then_branch); | |||
| STATUS VerifyPredictNode(); | |||
| FuncGraphPtr fg_ = nullptr; | |||
| CNodePtr merge_node_ = nullptr; | |||
| CNodePtr pred_node_ = nullptr; | |||
| std::vector<CNodePtr> input_nodes_{}; | |||
| std::vector<AnfNodePtr> pred_nodes_{}; | |||
| }; | |||
| } // namespace mindspore::opt | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_FUNCTIONALIZE_COND_H_ | |||
| @@ -18,6 +18,7 @@ | |||
| #include <algorithm> | |||
| #include <deque> | |||
| #include "tools/optimizer/graph/functionalize_while.h" | |||
| #include "tools/optimizer/graph/functionalize_cond.h" | |||
| #include "include/errorcode.h" | |||
| namespace mindspore::opt { | |||
| @@ -100,6 +101,25 @@ STATUS FunctionalizeControlOpPass::BuildWhileSubgraph(const FuncGraphPtr &func_g | |||
| return ret; | |||
| } | |||
// Functionalize every Switch/Merge cond cluster in `func_graph`: each Merge
// node anchors one cluster and is rewritten into a functional If node by
// FunctionalizeCond. Returns the first non-RET_OK status, else RET_OK.
STATUS FunctionalizeControlOpPass::BuildIfSubgraph(const FuncGraphPtr &func_graph) {
  int ret = RET_OK;
  auto nodes = func_graph->nodes();
  // NOTE(review): FunctionalizeCond::Process mutates the graph (adds nodes,
  // rewires edges) while this loop iterates `nodes`; confirm nodes() yields a
  // stable snapshot, otherwise the iteration may be invalidated.
  for (auto &node : nodes) {
    if (!IsMerge(node)) {
      continue;
    }
    auto cnode = utils::cast<CNodePtr>(node);
    FunctionalizeCond fc(func_graph, cnode);
    ret = fc.Process();
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "run functionalize cond failed, ret: " << ret;
      return ret;
    }
  }
  return ret;
}
| bool FunctionalizeControlOpPass::Run(const FuncGraphPtr &func_graph) { | |||
| // use name to find the frame | |||
| InitNodeClusters(func_graph); | |||
| @@ -107,6 +127,10 @@ bool FunctionalizeControlOpPass::Run(const FuncGraphPtr &func_graph) { | |||
| MS_LOG(ERROR) << "build while subgraph failed."; | |||
| return false; | |||
| } | |||
| if (BuildIfSubgraph(func_graph) != RET_OK) { | |||
| MS_LOG(ERROR) << "build while subgraph failed."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| @@ -23,10 +23,7 @@ | |||
| #include <memory> | |||
| #include "backend/optimizer/common/pass.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "tools/converter/ops/enter.h" | |||
| #include "tools/converter/ops/exit.h" | |||
| #include "tools/converter/ops/loop_cond.h" | |||
| #include "tools/converter/ops/next_iteration.h" | |||
| #include "tools/converter/ops/ops_def.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| using mindspore::lite::converter::FmkType; | |||
| @@ -70,6 +67,7 @@ class FunctionalizeControlOpPass : public Pass { | |||
| protected: | |||
| STATUS BuildWhileSubgraph(const FuncGraphPtr &func_graph); | |||
| STATUS BuildIfSubgraph(const FuncGraphPtr &func_graph); | |||
| std::vector<std::pair<std::string, std::vector<AnfNodePtr>>> node_clusters_{}; | |||
| std::vector<CNodePtr> loop_cond_nodes_{}; | |||
| }; | |||
| @@ -0,0 +1,67 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/graph/unused_node_remove_pass.h" | |||
| #include <deque> | |||
| #include <unordered_set> | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| namespace mindspore::opt { | |||
// Mark-and-sweep over one graph: walk backwards from the Return node marking
// everything reachable, recurse into subgraph values, then drop every node
// that was never reached. A graph without a Return node is left untouched.
STATUS UnusedNodeRemovePass::ProcessGraph(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  auto return_node = func_graph->get_return();
  if (return_node == nullptr) {
    return RET_OK;
  }
  // BFS over CNode inputs, recording reachable nodes in `vis`.
  std::unordered_set<AnfNodePtr> vis;
  std::deque<AnfNodePtr> q;
  q.push_back(return_node);
  while (!q.empty()) {
    auto node = q.front();
    vis.insert(node);
    q.pop_front();
    if (utils::isa<CNodePtr>(node)) {
      auto cnode = utils::cast<CNodePtr>(node);
      // NOTE(review): `vis` only grows on pop, so a node may be enqueued
      // twice before its first visit; harmless here since revisiting only
      // re-inserts into `vis`.
      for (auto &input : cnode->inputs()) {
        if (vis.find(input) == vis.end()) {
          q.push_back(input);
        }
      }
    }
    // NOTE(review): subgraphs are usually held inside ValueNodes, so
    // isa<FuncGraphPtr> on an AnfNodePtr may never be true here -- confirm
    // whether this recursion is actually reached (IsValueNode<FuncGraph> +
    // GetValueNode would be the typical pattern).
    if (utils::isa<FuncGraphPtr>(node)) {
      auto sub_graph = utils::cast<FuncGraphPtr>(node);
      auto status = ProcessGraph(sub_graph);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "process sub graph failed";
        return RET_ERROR;
      }
    }
  }
  // Sweep: drop every node that was not reached from the Return node.
  // NOTE(review): DropNode is called while iterating nodes(); confirm
  // nodes() returns a stable snapshot.
  auto nodes = func_graph->nodes();
  for (auto &node : nodes) {
    if (vis.find(node) == vis.end()) {
      func_graph->DropNode(node);
    }
  }
  return RET_OK;
}
| bool UnusedNodeRemovePass::Run(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto status = ProcessGraph(func_graph); | |||
| return status == RET_OK; | |||
| } | |||
| } // namespace mindspore::opt | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_UNUSED_NODE_REMOVE_PASS_H_ | |||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_UNUSED_NODE_REMOVE_PASS_H_ | |||
| #include <string> | |||
| #include "backend/optimizer/common/pass.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "mindspore/lite/include/errorcode.h" | |||
| using mindspore::lite::STATUS; | |||
| namespace mindspore::opt { | |||
// Pass that removes every node not reachable from a graph's Return node
// (dead nodes left behind by earlier rewrites), recursing into subgraphs.
class UnusedNodeRemovePass : public Pass {
 public:
  UnusedNodeRemovePass() : Pass("remove_unused_node_pass") {}
  ~UnusedNodeRemovePass() override = default;
  // Returns true on success (graph pruned), false if traversal failed.
  bool Run(const FuncGraphPtr &graph) override;

 private:
  // Mark nodes reachable from Return, drop the rest; recurses into subgraphs.
  STATUS ProcessGraph(const FuncGraphPtr &func_graph);
};
| } // namespace mindspore::opt | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_UNUSED_NODE_REMOVE_PASS_H_ | |||