diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt
index 54d2c0c89b..5a1e996312 100644
--- a/mindspore/lite/test/CMakeLists.txt
+++ b/mindspore/lite/test/CMakeLists.txt
@@ -240,13 +240,15 @@ if(ENABLE_CONVERTER)
         ${LITE_DIR}/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc
         ${LITE_DIR}/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc
         ${LITE_DIR}/tools/optimizer/fusion/tf_lstm_cell_fusion.cc
-        ${LITE_DIR}/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc
+        ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc
+        ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
         ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
         ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
         ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
         ${LITE_DIR}/tools/optimizer/graph/group_depthwise_op_convert_pass.cc
         ${LITE_DIR}/tools/optimizer/graph/tflite_inputs_adjust_pass.cc
         ${LITE_DIR}/tools/optimizer/graph/update_conv2d_param_pass.cc
+        ${LITE_DIR}/tools/optimizer/graph/unused_node_remove_pass.cc
         ${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc
         ${LITE_DIR}/tools/optimizer/graph/unused_transpose_node_remove_pass.cc
         ${LITE_DIR}/tools/optimizer/graph/redundant_op_remove_pass.cc
@@ -258,6 +260,7 @@ if(ENABLE_CONVERTER)
         ${LITE_DIR}/tools/optimizer/graph/if_pass.cc
         ${LITE_DIR}/tools/optimizer/graph/functionalize_control_op_pass.cc
         ${LITE_DIR}/tools/optimizer/graph/functionalize_while.cc
+        ${LITE_DIR}/tools/optimizer/graph/functionalize_cond.cc
         ${LITE_DIR}/tools/optimizer/graph/inputs_adjust_pass.cc
         ${LITE_DIR}/tools/optimizer/graph/primitive_adjust_pass.cc
         )
diff --git a/mindspore/lite/test/models_tf.cfg b/mindspore/lite/test/models_tf.cfg
index 653a72249e..5a9bb53b6f 100644
--- a/mindspore/lite/test/models_tf.cfg
+++ b/mindspore/lite/test/models_tf.cfg
@@ -61,3 +61,4 @@ ml_noya_tts_melgan.pb 1;16,16,80
 ml_video_edit_oneclick_adaptis.pb 3
 # Q_hand_0812.pb is not suitable for float16. Out of float16 range.
 Q_hand_0812.pb
+tacotron_encoder_stf.pb 5;1:1,62:1,62:1,62:1,62
diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt
index 8505cbe73a..cf20f723de 100644
--- a/mindspore/lite/tools/converter/CMakeLists.txt
+++ b/mindspore/lite/tools/converter/CMakeLists.txt
@@ -50,13 +50,15 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
         ../optimizer/fusion/conv_conv_fusion.cc
         ../optimizer/fusion/tflite_lstm_cell_fusion.cc
         ../optimizer/fusion/tf_lstm_cell_fusion.cc
-        ../optimizer/fusion/bidirection_tf_gru_cell_fusion.cc
+        ../optimizer/fusion/tf_bidirection_gru_fusion.cc
+        ../optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
         ../optimizer/graph/weight_format_transform_pass.cc
         ../optimizer/graph/weight_format_hardcode_pass.cc
         ../optimizer/graph/clip_convert_activation_pass.cc
         ../optimizer/graph/group_depthwise_op_convert_pass.cc
         ../optimizer/graph/tflite_inputs_adjust_pass.cc
         ../optimizer/graph/update_conv2d_param_pass.cc
+        ../optimizer/graph/unused_node_remove_pass.cc
         ../optimizer/graph/unused_cast_node_remove_pass.cc
         ../optimizer/graph/unused_transpose_node_remove_pass.cc
         ../optimizer/graph/redundant_op_remove_pass.cc
@@ -68,6 +70,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
         ../optimizer/graph/if_pass.cc
         ../optimizer/graph/functionalize_control_op_pass.cc
         ../optimizer/graph/functionalize_while.cc
+        ../optimizer/graph/functionalize_cond.cc
         ../optimizer/graph/inputs_adjust_pass.cc
         ../optimizer/graph/primitive_adjust_pass.cc
         )
diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc
index 87337c3697..5b65c9c939 100644
--- a/mindspore/lite/tools/converter/anf_transform.cc
+++ b/mindspore/lite/tools/converter/anf_transform.cc
@@ -18,6 +18,8 @@
 #include
 #include
 #include "src/common/log_adapter.h"
+#include "tools/optimizer/common/gllo_utils.h"
+#include "mindspore/core/ir/primitive.h"
 #include "tools/optimizer/fusion/conv_biasadd_fusion.h"
 #include "tools/optimizer/fusion/conv_activation_fusion.h"
 #include "tools/optimizer/fusion/conv_tuple_activation_fusion.h"
@@ -31,7 +33,8 @@
 #include "tools/optimizer/fusion/conv_conv_fusion.h"
 #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
 #include "tools/optimizer/fusion/tf_lstm_cell_fusion.h"
-#include "tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h"
+#include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
+#include "tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h"
 #include "tools/optimizer/graph/primitive_adjust_pass.h"
 #include "tools/optimizer/graph/mindir_adjust_pass.h"
 #include "tools/optimizer/graph/redundant_op_remove_pass.h"
@@ -42,6 +45,7 @@
 #include "tools/optimizer/graph/tflite_inputs_adjust_pass.h"
 #include "tools/optimizer/graph/onnx_inputs_adjust_pass.h"
 #include "tools/optimizer/graph/update_conv2d_param_pass.h"
+#include "tools/optimizer/graph/unused_node_remove_pass.h"
 #include "tools/optimizer/graph/unused_cast_node_remove_pass.h"
 #include "tools/optimizer/graph/unused_transpose_node_remove_pass.h"
 #include "tools/optimizer/graph/infershape_pass.h"
@@ -81,7 +85,7 @@ int AnfTransform::AddFusionPass(const std::shared_ptr &opti
     fusion_pm->AddPass(std::make_shared());
     fusion_pm->AddPass(std::make_shared());
     fusion_pm->AddPass(std::make_shared());
-    fusion_pm->AddPass(std::make_shared<opt::BiDirectionTfGruCellFusion>());
+    fusion_pm->AddPass(std::make_shared<opt::TfBidirectionGruFusion>());
   }
   if (config->fmk == lite::converter::FmkType_MS) {
     auto remove_unused_cast_pass = std::make_shared();
@@ -225,6 +229,23 @@ int AnfTransform::RunTFAdjustPass(const FuncGraphPtr &old_graph, const converter
   return RET_OK;
 }
 
+int AnfTransform::RunPrecedingPass(const FuncGraphPtr &old_graph, const converter::Flags &config) {
+  MS_ASSERT(old_graph != nullptr);
+  auto asylic_optimizer = std::make_shared();
+  auto asylic_pm = std::make_shared("asylic pass manager", false);
+  // fuse tf1.x bidirection_gru into GRU, must be placed here because graph is cyclic
+  asylic_pm->AddPass(std::make_shared<opt::TfBidirectionGruCfFusion>());
+  // remove remaining cyclic nodes
+  asylic_pm->AddPass(std::make_shared<opt::UnusedNodeRemovePass>());
+  asylic_optimizer->AddPassManager(asylic_pm);
+  if (!asylic_optimizer->Optimize(old_graph)) {
+    MS_LOG(ERROR) << "gru cf fusion pass failed.";
+    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
 int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config,
                              const FuncGraphPtr &new_graph) {
   // quant
@@ -266,7 +287,13 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
     return old_graph;
   }
 
-  auto status = RunAdjustPass(old_graph, config);
+  auto status = RunPrecedingPass(old_graph, *config);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "Run Preceding pass failed.";
+    return nullptr;
+  }
+
+  status = RunAdjustPass(old_graph, config);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "Run Adjust pass failed.";
     return nullptr;
diff --git a/mindspore/lite/tools/converter/anf_transform.h b/mindspore/lite/tools/converter/anf_transform.h
index 7e64bc81c3..2970d69da3 100644
--- a/mindspore/lite/tools/converter/anf_transform.h
+++ b/mindspore/lite/tools/converter/anf_transform.h
@@ -50,6 +50,8 @@ class AnfTransform {
   static int AddConstFoldPass(const std::shared_ptr &optimizer,
                               const converter::Flags *config);
 
+  static int RunPrecedingPass(const FuncGraphPtr &old_graph, const converter::Flags &config);
+
   static int RunAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
 
   static int RunMindirAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
diff --git a/mindspore/lite/tools/converter/ops/if.h b/mindspore/lite/tools/converter/ops/if.h
deleted file mode 100644
index 12a66a2bab..0000000000
--- a/mindspore/lite/tools/converter/ops/if.h
+++ /dev/null
@@ -1,37 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_IF_H_
-#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_IF_H_
-#include
-#include
-#include "ops/primitive_c.h"
-
-using mindspore::ops::PrimitiveC;
-
-namespace mindspore {
-namespace lite {
-constexpr auto kNameIf = "If";
-class If : public PrimitiveC {
- public:
-  If() : PrimitiveC(kNameIf) {}
-  ~If() = default;
-  MS_DECLARE_PARENT(If, PrimitiveC);
-};
-} // namespace lite
-} // namespace mindspore
-
-#endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_IF_H_
diff --git a/mindspore/lite/tools/converter/ops/loop_cond.h b/mindspore/lite/tools/converter/ops/loop_cond.h
deleted file mode 100644
index 25cbc4d096..0000000000
--- a/mindspore/lite/tools/converter/ops/loop_cond.h
+++ /dev/null
@@ -1,39 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_LOOP_COND_H_
-#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_LOOP_COND_H_
-
-#include
-#include
-#include
-#include "ops/primitive_c.h"
-
-using mindspore::ops::PrimitiveC;
-
-namespace mindspore {
-namespace lite {
-constexpr auto kNameLoopCond = "LoopCond";
-class LoopCond : public PrimitiveC {
- public:
-  LoopCond() : PrimitiveC(kNameLoopCond) {}
-  ~LoopCond() = default;
-  MS_DECLARE_PARENT(LoopCond, PrimitiveC);
-};
-} // namespace lite
-} // namespace mindspore
-
-#endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_LOOP_COND_H_
diff --git a/mindspore/lite/tools/converter/ops/ops_def.h b/mindspore/lite/tools/converter/ops/ops_def.h
index 16a565ae48..1f068fbc63 100644
--- a/mindspore/lite/tools/converter/ops/ops_def.h
+++ b/mindspore/lite/tools/converter/ops/ops_def.h
@@ -17,16 +17,31 @@
 #ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_
 #define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_
 #include "schema/inner/model_generated.h"
+#include "ops/primitive_c.h"
+using mindspore::ops::PrimitiveC;
 
 namespace mindspore {
 namespace lite {
+#define ADD_CONVERTER_ONLY_OP(name)        \
+  constexpr auto kName##name = #name;      \
+  class name : public PrimitiveC {         \
+   public:                                 \
+    name() : PrimitiveC(kName##name) {}    \
+    ~name() = default;                     \
+    MS_DECLARE_PARENT(name, PrimitiveC);   \
+  };
 
-enum ConverterPrimitiveType {
-  ConverterPrimitiveType_Enter = schema::PrimitiveType_MAX + 1,
-  ConverterPrimitiveType_LoopCond,
-  ConverterPrimitiveType_NextIteration,
-  ConverterPrimitiveType_Exit,
-};
+ADD_CONVERTER_ONLY_OP(Enter);
+ADD_CONVERTER_ONLY_OP(Exit);
+ADD_CONVERTER_ONLY_OP(If);
+ADD_CONVERTER_ONLY_OP(LoopCond);
+ADD_CONVERTER_ONLY_OP(NextIteration);
+ADD_CONVERTER_ONLY_OP(TensorArrayGatherV3);
+ADD_CONVERTER_ONLY_OP(TensorArrayReadV3);
+ADD_CONVERTER_ONLY_OP(TensorArrayScatterV3);
+ADD_CONVERTER_ONLY_OP(TensorArraySizeV3);
+ADD_CONVERTER_ONLY_OP(TensorArrayV3);
+ADD_CONVERTER_ONLY_OP(TensorArrayWriteV3);
 } // namespace lite
 } // namespace mindspore
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_if_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_if_parser.cc
index 6e2ced7776..1dcb54c1cb 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_if_parser.cc
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_if_parser.cc
@@ -17,7 +17,7 @@
 #include "tools/converter/parser/onnx/onnx_if_parser.h"
 #include
 #include "tools/converter/parser/onnx/onnx_model_parser.h"
-#include "tools/converter/ops/if.h"
+#include "tools/converter/ops/ops_def.h"
 
 namespace mindspore {
 namespace lite {
diff --git a/mindspore/lite/tools/converter/parser/tf/tf_enter_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_enter_parser.cc
index 35824aa655..f31d1e87a0 100644
--- a/mindspore/lite/tools/converter/parser/tf/tf_enter_parser.cc
+++ b/mindspore/lite/tools/converter/parser/tf/tf_enter_parser.cc
@@ -19,7 +19,7 @@
 #include
 #include "tools/converter/parser/tf/tf_enter_parser.h"
 #include "tools/converter/parser/tf/tf_node_parser_registry.h"
-#include "tools/converter/ops/enter.h"
+#include "tools/converter/ops/ops_def.h"
 
 namespace mindspore {
 namespace lite {
diff --git a/mindspore/lite/tools/converter/parser/tf/tf_exit_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_exit_parser.cc
index 79d0a5b3d3..26a3332ca1 100644
--- a/mindspore/lite/tools/converter/parser/tf/tf_exit_parser.cc
+++ b/mindspore/lite/tools/converter/parser/tf/tf_exit_parser.cc
@@ -18,7 +18,7 @@
 #include
 #include
 #include "tools/converter/parser/tf/tf_node_parser_registry.h"
-#include "tools/converter/ops/exit.h"
+#include "tools/converter/ops/ops_def.h"
 
 namespace mindspore {
 namespace lite {
diff --git a/mindspore/lite/tools/converter/parser/tf/tf_if_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_if_parser.cc
index 4d6d93d672..0adc6ccbf9 100644
--- a/mindspore/lite/tools/converter/parser/tf/tf_if_parser.cc
+++ b/mindspore/lite/tools/converter/parser/tf/tf_if_parser.cc
@@ -19,7 +19,7 @@
 #include
 #include
 #include "tools/converter/parser/tf/tf_node_parser_registry.h"
-#include "tools/converter/ops/if.h"
+#include "tools/converter/ops/ops_def.h"
 
 namespace mindspore {
 namespace lite {
diff --git a/mindspore/lite/tools/converter/parser/tf/tf_loop_cond_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_loop_cond_parser.cc
index 7a24b104ca..afba27d3c9 100644
--- a/mindspore/lite/tools/converter/parser/tf/tf_loop_cond_parser.cc
+++ b/mindspore/lite/tools/converter/parser/tf/tf_loop_cond_parser.cc
@@ -18,7 +18,7 @@
 #include
 #include
 #include "tools/converter/parser/tf/tf_node_parser_registry.h"
-#include "tools/converter/ops/loop_cond.h"
+#include "tools/converter/ops/ops_def.h"
 
 namespace mindspore {
 namespace lite {
diff --git a/mindspore/lite/tools/converter/parser/tf/tf_merge_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_merge_parser.cc
index 2a38c4638d..70736fe8a1 100644
--- a/mindspore/lite/tools/converter/parser/tf/tf_merge_parser.cc
+++ b/mindspore/lite/tools/converter/parser/tf/tf_merge_parser.cc
@@ -28,7 +28,7 @@ ops::PrimitiveC *TFMergeParser::Parse(const tensorflow::NodeDef &tf_op,
                                       std::vector *inputs, int *output_size) {
   auto prim = std::make_unique();
 
-  *output_size = tf_op.input_size();
+  *output_size = 1;
   for (int i = 0; i < tf_op.input_size(); i++) {
     inputs->emplace_back(tf_op.input(i));
   }
diff --git a/mindspore/lite/tools/converter/parser/tf/tf_next_iteration_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_next_iteration_parser.cc
index 3253f4bad3..af3e41bb51 100644
--- a/mindspore/lite/tools/converter/parser/tf/tf_next_iteration_parser.cc
+++ b/mindspore/lite/tools/converter/parser/tf/tf_next_iteration_parser.cc
@@ -18,7 +18,7 @@
 #include
 #include
#include "tools/converter/parser/tf/tf_node_parser_registry.h" -#include "tools/converter/ops/next_iteration.h" +#include "tools/converter/ops/ops_def.h" namespace mindspore { namespace lite { diff --git a/mindspore/lite/tools/converter/parser/tf/tf_switch_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_switch_parser.cc index 6e3d5e8027..5a867f7a57 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_switch_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_switch_parser.cc @@ -28,7 +28,7 @@ ops::PrimitiveC *TFSwitchParser::Parse(const tensorflow::NodeDef &tf_op, std::vector *inputs, int *output_size) { auto prim = std::make_unique(); - *output_size = tf_op.input_size(); + *output_size = 2; for (int i = 0; i < tf_op.input_size(); i++) { inputs->emplace_back(tf_op.input(i)); } diff --git a/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_gather_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_gather_parser.cc new file mode 100644 index 0000000000..bf7329c376 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_gather_parser.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_tensor_array_gather_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" +#include "tools/converter/ops/ops_def.h" + +namespace mindspore { +namespace lite { +ops::PrimitiveC *TFTensorArrayGatherParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + std::vector *inputs, int *output_size) { + MS_LOG(DEBUG) << "TF TensorArrayGatherParser"; + if (inputs == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "inputs or output_size is nullptr"; + return nullptr; + } + auto prim = std::make_unique(); + if (prim == nullptr) { + MS_LOG(ERROR) << "prim is nullptr"; + return nullptr; + } + *output_size = 1; + for (int i = 0; i < tf_op.input_size(); i++) { + inputs->emplace_back(tf_op.input(i)); + } + return prim.release(); +} +TFNodeRegistrar g_tfTensorArrayGatherParser("TensorArrayGatherV3", new TFTensorArrayGatherParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_gather_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_gather_parser.h new file mode 100644 index 0000000000..76ab611611 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_gather_parser.h @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_GATHER_PARSER_H_
+#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_GATHER_PARSER_H_
+#include
+#include
+#include
+#include
+#include "tools/converter/parser/tf/tf_node_parser.h"
+
+namespace mindspore {
+namespace lite {
+class TFTensorArrayGatherParser : public TFNodeParser {
+ public:
+  TFTensorArrayGatherParser() = default;
+  ~TFTensorArrayGatherParser() override = default;
+
+  ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
+                         const std::map &tf_node_map,
+                         std::vector *inputs, int *output_size) override;
+};
+} // namespace lite
+} // namespace mindspore
+
+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_GATHER_PARSER_H_
diff --git a/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_parser.cc
new file mode 100644
index 0000000000..f6b4c722f0
--- /dev/null
+++ b/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_parser.cc
@@ -0,0 +1,49 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tools/converter/parser/tf/tf_tensor_array_parser.h"
+#include
+#include
+#include
+#include
+#include "tools/converter/parser/tf/tf_node_parser_registry.h"
+#include "tools/converter/ops/ops_def.h"
+
+namespace mindspore {
+namespace lite {
+ops::PrimitiveC *TFTensorArrayParser::Parse(const tensorflow::NodeDef &tf_op,
+                                            const std::map &tf_node_map,
+                                            std::vector *inputs, int *output_size) {
+  MS_LOG(DEBUG) << "TF TensorArrayParser";
+  if (inputs == nullptr || output_size == nullptr) {
+    MS_LOG(ERROR) << "inputs or output_size is nullptr";
+    return nullptr;
+  }
+  auto prim = std::make_unique<TensorArrayV3>();
+  if (prim == nullptr) {
+    MS_LOG(ERROR) << "prim is nullptr";
+    return nullptr;
+  }
+
+  *output_size = 2;
+  for (int i = 0; i < tf_op.input_size(); i++) {
+    inputs->emplace_back(tf_op.input(i));
+  }
+
+  return prim.release();
+}
+TFNodeRegistrar g_tfTensorArrayParser("TensorArrayV3", new TFTensorArrayParser());
+} // namespace lite
+} // namespace mindspore
diff --git a/mindspore/lite/tools/converter/ops/enter.h b/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_parser.h
similarity index 50%
rename from mindspore/lite/tools/converter/ops/enter.h
rename to mindspore/lite/tools/converter/parser/tf/tf_tensor_array_parser.h
index 9290b9c40d..ddc303c73a 100644
--- a/mindspore/lite/tools/converter/ops/enter.h
+++ b/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_parser.h
@@ -13,27 +13,25 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
-#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_ENTER_H_
-#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_ENTER_H_
-
+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_PARSER_H_
+#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_PARSER_H_
+#include
+#include
+#include
 #include
-#include
-#include
-#include "ops/primitive_c.h"
-
-using mindspore::ops::PrimitiveC;
+#include "tools/converter/parser/tf/tf_node_parser.h"
 
 namespace mindspore {
 namespace lite {
-constexpr auto kNameEnter = "Enter";
-class Enter : public PrimitiveC {
+class TFTensorArrayParser : public TFNodeParser {
  public:
-  Enter() : PrimitiveC(kNameEnter) {}
-  ~Enter() = default;
-  MS_DECLARE_PARENT(Enter, PrimitiveC);
+  TFTensorArrayParser() = default;
+  ~TFTensorArrayParser() override = default;
+
+  ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
+                         const std::map &tf_node_map,
+                         std::vector *inputs, int *output_size) override;
 };
 } // namespace lite
 } // namespace mindspore
-
-#endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_ENTER_H_
+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_PARSER_H_
diff --git a/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_read_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_read_parser.cc
new file mode 100644
index 0000000000..49ba480348
--- /dev/null
+++ b/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_read_parser.cc
@@ -0,0 +1,48 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tools/converter/parser/tf/tf_tensor_array_read_parser.h"
+#include
+#include
+#include
+#include
+#include "tools/converter/parser/tf/tf_node_parser_registry.h"
+#include "tools/converter/ops/ops_def.h"
+
+namespace mindspore {
+namespace lite {
+ops::PrimitiveC *TFTensorArrayReadParser::Parse(const tensorflow::NodeDef &tf_op,
+                                                const std::map &tf_node_map,
+                                                std::vector *inputs, int *output_size) {
+  MS_LOG(DEBUG) << "TF TensorArrayReadParser";
+  if (inputs == nullptr || output_size == nullptr) {
+    MS_LOG(ERROR) << "inputs or output_size is nullptr";
+    return nullptr;
+  }
+  auto prim = std::make_unique<TensorArrayReadV3>();
+  if (prim == nullptr) {
+    MS_LOG(ERROR) << "prim is nullptr";
+    return nullptr;
+  }
+
+  *output_size = 1;
+  for (int i = 0; i < tf_op.input_size(); i++) {
+    inputs->emplace_back(tf_op.input(i));
+  }
+  return prim.release();
+}
+TFNodeRegistrar g_tfTensorArrayReadParser("TensorArrayReadV3", new TFTensorArrayReadParser());
+} // namespace lite
+} // namespace mindspore
diff --git a/mindspore/lite/tools/converter/ops/exit.h b/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_read_parser.h
similarity index 50%
rename from mindspore/lite/tools/converter/ops/exit.h
rename to mindspore/lite/tools/converter/parser/tf/tf_tensor_array_read_parser.h
index 5e0c431d83..287b8ae744 100644
--- a/mindspore/lite/tools/converter/ops/exit.h
+++ b/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_read_parser.h
@@ -13,27 +13,25 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
-#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_EXIT_H_
-#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_EXIT_H_
-
+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_READ_PARSER_H_
+#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_READ_PARSER_H_
+#include
+#include
+#include
 #include
-#include
-#include
-#include "ops/primitive_c.h"
-
-using mindspore::ops::PrimitiveC;
+#include "tools/converter/parser/tf/tf_node_parser.h"
 
 namespace mindspore {
 namespace lite {
-constexpr auto kNameExit = "Exit";
-class Exit : public PrimitiveC {
+class TFTensorArrayReadParser : public TFNodeParser {
  public:
-  Exit() : PrimitiveC(kNameExit) {}
-  ~Exit() = default;
-  MS_DECLARE_PARENT(Exit, PrimitiveC);
+  TFTensorArrayReadParser() = default;
+  ~TFTensorArrayReadParser() override = default;
+  ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
+                         const std::map &tf_node_map,
+                         std::vector *inputs, int *output_size) override;
 };
 } // namespace lite
 } // namespace mindspore
-#endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_EXIT_H_
+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_READ_PARSER_H_
diff --git a/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_scatter_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_scatter_parser.cc
new file mode 100644
index 0000000000..ca4e5f03db
--- /dev/null
+++ b/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_scatter_parser.cc
@@ -0,0 +1,48 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tools/converter/parser/tf/tf_tensor_array_scatter_parser.h"
+#include
+#include
+#include
+#include
+#include "tools/converter/parser/tf/tf_node_parser_registry.h"
+#include "tools/converter/ops/ops_def.h"
+
+namespace mindspore {
+namespace lite {
+ops::PrimitiveC *TFTensorArrayScatterParser::Parse(const tensorflow::NodeDef &tf_op,
+                                                   const std::map &tf_node_map,
+                                                   std::vector *inputs, int *output_size) {
+  MS_LOG(DEBUG) << "TF TensorArrayScatterParser";
+  if (inputs == nullptr || output_size == nullptr) {
+    MS_LOG(ERROR) << "inputs or output_size is nullptr";
+    return nullptr;
+  }
+  auto prim = std::make_unique<TensorArrayScatterV3>();
+  if (prim == nullptr) {
+    MS_LOG(ERROR) << "prim is nullptr";
+    return nullptr;
+  }
+
+  *output_size = 1;
+  for (int i = 0; i < tf_op.input_size(); i++) {
+    inputs->emplace_back(tf_op.input(i));
+  }
+  return prim.release();
+}
+TFNodeRegistrar g_tfTensorArrayScatterParser("TensorArrayScatterV3", new TFTensorArrayScatterParser());
+} // namespace lite
+} // namespace mindspore
diff --git a/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_scatter_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_scatter_parser.h
new file mode 100644
index 0000000000..ebbf796029
--- /dev/null
+++ b/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_scatter_parser.h
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SCATTER_PARSER_H_
+#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SCATTER_PARSER_H_
+#include
+#include
+#include
+#include
+#include "tools/converter/parser/tf/tf_node_parser.h"
+
+namespace mindspore {
+namespace lite {
+class TFTensorArrayScatterParser : public TFNodeParser {
+ public:
+  TFTensorArrayScatterParser() = default;
+  ~TFTensorArrayScatterParser() override = default;
+  ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
+                         const std::map &tf_node_map,
+                         std::vector *inputs, int *output_size) override;
+};
+} // namespace lite
+} // namespace mindspore
+
+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SCATTER_PARSER_H_
diff --git a/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_size_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_size_parser.cc
new file mode 100644
index 0000000000..bc37f977bd
--- /dev/null
+++ b/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_size_parser.cc
@@ -0,0 +1,48 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tools/converter/parser/tf/tf_tensor_array_size_parser.h"
+#include
+#include
+#include
+#include
+#include "tools/converter/parser/tf/tf_node_parser_registry.h"
+#include "tools/converter/ops/ops_def.h"
+
+namespace mindspore {
+namespace lite {
+ops::PrimitiveC *TFTensorArraySizeParser::Parse(const tensorflow::NodeDef &tf_op,
+                                                const std::map &tf_node_map,
+                                                std::vector *inputs, int *output_size) {
+  MS_LOG(DEBUG) << "TF TensorArraySizeParser";
+  if (inputs == nullptr || output_size == nullptr) {
+    MS_LOG(ERROR) << "inputs or output_size is nullptr";
+    return nullptr;
+  }
+  auto prim = std::make_unique<TensorArraySizeV3>();
+  if (prim == nullptr) {
+    MS_LOG(ERROR) << "prim is nullptr";
+    return nullptr;
+  }
+  *output_size = 1;
+  for (int i = 0; i < tf_op.input_size(); i++) {
+    inputs->emplace_back(tf_op.input(i));
+  }
+
+  return prim.release();
+}
+TFNodeRegistrar g_tfTensorArraySizeParser("TensorArraySizeV3", new TFTensorArraySizeParser());
+} // namespace lite
+} // namespace mindspore
diff --git a/mindspore/lite/tools/converter/ops/next_iteration.h b/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_size_parser.h
similarity index 50%
rename from mindspore/lite/tools/converter/ops/next_iteration.h
rename to mindspore/lite/tools/converter/parser/tf/tf_tensor_array_size_parser.h
index b98a8601b8..02e967775d 100644
--- a/mindspore/lite/tools/converter/ops/next_iteration.h
+++ b/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_size_parser.h
@@ -13,27 +13,25 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
-#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_NEXT_ITERATION_H_
-#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_NEXT_ITERATION_H_
-
+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SIZE_PARSER_H_
+#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SIZE_PARSER_H_
+#include
+#include
+#include
 #include
-#include
-#include
-#include "ops/primitive_c.h"
-
-using mindspore::ops::PrimitiveC;
+#include "tools/converter/parser/tf/tf_node_parser.h"
 
 namespace mindspore {
 namespace lite {
-constexpr auto kNameNextIteration = "NextIteration";
-class NextIteration : public PrimitiveC {
+class TFTensorArraySizeParser : public TFNodeParser {
  public:
-  NextIteration() : PrimitiveC(kNameNextIteration) {}
-  ~NextIteration() = default;
-  MS_DECLARE_PARENT(NextIteration, PrimitiveC);
+  TFTensorArraySizeParser() = default;
+  ~TFTensorArraySizeParser() override = default;
+  ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
+                         const std::map &tf_node_map,
+                         std::vector *inputs, int *output_size) override;
 };
 } // namespace lite
 } // namespace mindspore
-#endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_NEXT_ITERATION_H_
+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SIZE_PARSER_H_
diff --git a/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_write_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_write_parser.cc
new file mode 100644
index 0000000000..47d4ca4a6a
--- /dev/null
+++ b/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_write_parser.cc
@@ -0,0 +1,49 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tools/converter/parser/tf/tf_tensor_array_write_parser.h"
+#include
+#include
+#include
+#include
+#include "tools/converter/parser/tf/tf_node_parser_registry.h"
+#include "tools/converter/ops/ops_def.h"
+
+namespace mindspore {
+namespace lite {
+ops::PrimitiveC *TFTensorArrayWriteParser::Parse(const tensorflow::NodeDef &tf_op,
+                                                 const std::map &tf_node_map,
+                                                 std::vector *inputs, int *output_size) {
+  MS_LOG(DEBUG) << "TF TensorArrayWriteParser";
+  if (inputs == nullptr || output_size == nullptr) {
+    MS_LOG(ERROR) << "inputs or output_size is nullptr";
+    return nullptr;
+  }
+  auto prim = std::make_unique<TensorArrayWriteV3>();
+  if (prim == nullptr) {
+    MS_LOG(ERROR) << "prim is nullptr";
+    return nullptr;
+  }
+
+  *output_size = 1;
+  for (int i = 0; i < tf_op.input_size(); i++) {
+    inputs->emplace_back(tf_op.input(i));
+  }
+
+  return prim.release();
+}
+TFNodeRegistrar g_tfTensorArrayWriteParser("TensorArrayWriteV3", new TFTensorArrayWriteParser());
+} // namespace lite
+} // namespace mindspore
diff --git a/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_write_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_write_parser.h
new file mode 100644
index 0000000000..39dafec2f6
--- /dev/null
+++ b/mindspore/lite/tools/converter/parser/tf/tf_tensor_array_write_parser.h
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_WRITE_PARSER_H_
+#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_WRITE_PARSER_H_
+#include
+#include
+#include
+#include
+#include "tools/converter/parser/tf/tf_node_parser.h"
+
+namespace mindspore {
+namespace lite {
+class TFTensorArrayWriteParser : public TFNodeParser {
+ public:
+  TFTensorArrayWriteParser() = default;
+  ~TFTensorArrayWriteParser() override = default;
+
+  ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
+                         const std::map &tf_node_map,
+                         std::vector *inputs, int *output_size) override;
+};
+} // namespace lite
+} // namespace mindspore
+
+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_WRITE_PARSER_H_
diff --git a/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc b/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
new file mode 100644
index 0000000000..2bf73e371a
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
@@ -0,0 +1,214 @@
+/**
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h"
+#include
+#include
+#include
+#include "src/common/utils.h"
+#include "utils/utils.h"
+#include "tools/optimizer/common/gllo_utils.h"
+#include "securec/include/securec.h"
+#include "tools/converter/ops/ops_def.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+constexpr size_t kNumFwVars = 4;
+constexpr size_t kNumBwVars = 4;
+const auto &p1 = std::placeholders::_1;
+BaseRef GetPrim(const PrimitivePtr &prim) { return std::make_shared(std::bind(IsOpType, p1, prim)); }
+
+BaseRef GetPrim(const std::string &prim_name) {
+  auto prim = std::make_shared<Primitive>(prim_name);
+  return GetPrim(prim);
+}
+} // namespace
+
+TfBidirectionGruCfFusion::TfBidirectionGruCfFusion(const std::string &name, bool multi_graph)
+    : TfBidirectionGruFusion(kNumFwVars, kNumBwVars, name, multi_graph) {
+  /*
+   * vars for fw/bw input
+   * fw:
+   * 0:kernel_gate 1:bias_gate 2:cand_kernel 3:cand_bias
+   * bw:
+   * 0:kernel_gate 1:bias_gate 2:cand_kernel 3:cand_bias
+   */
+}
+
+BaseRef TfBidirectionGruCfFusion::DefineGruCellPattern(const BaseRef &in_ta_read, const BaseRef &switch3_true,
+                                                       const std::vector &vars) const {
+  auto concat = VectorRef({GetPrim(prim::kPrimConcat), in_ta_read, switch3_true});
+  auto matmul_enter = VectorRef({GetPrim(lite::kNameEnter), vars[0]});  // gate_kernel
+  auto matmul = VectorRef({GetPrim(prim::kPrimMatMul), concat, matmul_enter});
+  auto bias_enter = VectorRef({GetPrim(lite::kNameEnter), vars[1]});  // gate_bias
+  auto bias = VectorRef({GetPrim(prim::kPrimBiasAdd), matmul, bias_enter});
+  auto sigmoid = VectorRef({GetPrim(prim::kPrimActivation), bias});
+  auto split = VectorRef({GetPrim(prim::kPrimSplit), sigmoid});
+  auto rt = VectorRef({GetPrim(prim::kPrimTupleGetItem), split, std::make_shared()});
+  auto zt = VectorRef({GetPrim(prim::kPrimTupleGetItem), split, std::make_shared()});
+  auto mul = VectorRef({GetPrim(prim::kPrimMulFusion), rt, switch3_true});
+  auto concat1 = VectorRef({GetPrim(prim::kPrimConcat), in_ta_read, mul});
+  auto matmul1_enter = VectorRef({GetPrim(lite::kNameEnter), vars[2]});  // cand_kernel
+  auto matmul1 = VectorRef({GetPrim(prim::kPrimMatMul), concat1, matmul1_enter});
+  auto bias1_enter = VectorRef({GetPrim(lite::kNameEnter), vars[3]});  // cand_bias
+  auto bias1 = VectorRef({GetPrim(prim::kPrimBiasAdd), matmul1, bias1_enter});
+  auto tanh = VectorRef({GetPrim(prim::kPrimActivation), bias1});
+  auto sub = VectorRef({GetPrim(prim::kPrimSubFusion), std::make_shared(IsParameterNode), zt});
+  auto mul2 = VectorRef({GetPrim(prim::kPrimMulFusion), sub, tanh});
+  auto mul1 = VectorRef({GetPrim(prim::kPrimMulFusion), zt, switch3_true});
+  auto add = VectorRef({GetPrim(prim::kPrimAddFusion), mul1, mul2});
+  return add;
+}
+
+const BaseRef TfBidirectionGruCfFusion::DefineBidirectionRnnPattern(const BaseRef &input,
+                                                                    const std::vector &vars,
+                                                                    const VarPtr &init_state) const {
+  // in order to match cyclic graph, some node in cycle is represented by SeqVar
+  auto fw_shape1 = VectorRef({GetPrim(prim::kPrimShape), input});
+  auto strided_slice = VectorRef({GetPrim(prim::kPrimStridedSlice), fw_shape1, std::make_shared()});
+  auto fw_max = VectorRef({GetPrim(prim::kPrimReduceFusion), input_length_, std::make_shared()});
+  auto fw_maximum = VectorRef({GetPrim(prim::kPrimMaximum), std::make_shared(IsParameterNode), fw_max});
+  auto fw_minimum = VectorRef({GetPrim(prim::kPrimMinimum), strided_slice, fw_maximum});
+  auto fw_less1_enter = VectorRef({GetPrim(lite::kNameEnter), fw_minimum});
+  // SeqVar:counter_merge1
+  auto fw_less1 = VectorRef({GetPrim(prim::kPrimLess), std::make_shared(), fw_less1_enter});
+
+  // SeqVar:fw_merge,loop_cond
+  auto fw_switch = VectorRef({GetPrim(prim::kPrimSwitch), std::make_shared()});
+  auto fw_switch_true = VectorRef({GetPrim(prim::kPrimTupleGetItem), fw_switch, std::make_shared()});  // identity
+  auto fw_add = VectorRef({GetPrim(prim::kPrimAddFusion), fw_switch_true, std::make_shared(IsParameterNode)});
+  auto fw_next_iter = VectorRef({GetPrim(lite::kNameNextIteration), fw_add});
+  auto fw_merge_enter = VectorRef({GetPrim(lite::kNameEnter), std::make_shared(IsParameterNode)});
+  auto fw_merge = VectorRef({GetPrim(prim::kPrimMerge), fw_merge_enter, fw_next_iter});
+  auto fw_less_enter = VectorRef({GetPrim(lite::kNameEnter), strided_slice});
+  auto fw_less = VectorRef({GetPrim(prim::kPrimLess), fw_merge, fw_less_enter});
+
+  auto fw_logical_and = VectorRef({GetPrim(prim::kPrimLogicalAnd), fw_less, fw_less1});
+  // SeqVar:fw_logical_and
+  auto loop_cond = VectorRef({GetPrim(lite::kNameLoopCond), fw_logical_and});
+
+  auto fw_shape = VectorRef({GetPrim(prim::kPrimShape), input});
+  auto fw_unstack_strided_slice = VectorRef({GetPrim(prim::kPrimStridedSlice), fw_shape, std::make_shared()});
+  auto fw_unstack_range = VectorRef({GetPrim(prim::kPrimRange), std::make_shared(IsParameterNode),
+                                     fw_unstack_strided_slice, std::make_shared(IsParameterNode)});
+
+  // SeqVar:switch1_true
+  auto counter_add =
+    VectorRef({GetPrim(prim::kPrimAddFusion), std::make_shared(), std::make_shared(IsParameterNode)});
+  auto counter_zero = VectorRef({GetPrim(lite::kNameEnter), std::make_shared(IsParameterNode)});
+  auto counter_next_iter = VectorRef({GetPrim(lite::kNameNextIteration), counter_add});
+  auto counter_merge1 = VectorRef({GetPrim(prim::kPrimMerge), counter_zero, counter_next_iter});
+  auto counter_switch1 = VectorRef({GetPrim(prim::kPrimSwitch), counter_merge1, loop_cond});
+  auto switch1_true =
+    VectorRef({GetPrim(prim::kPrimTupleGetItem), counter_switch1, std::make_shared()});  // identity1
+
+  auto in_ta = VectorRef({GetPrim(lite::kNameTensorArrayV3), strided_slice});
+  auto in_ta_handle = VectorRef({GetPrim(prim::kPrimTupleGetItem), in_ta, std::make_shared()});
+  auto in_ta_flow = VectorRef({GetPrim(prim::kPrimTupleGetItem), in_ta, std::make_shared()});
+  auto fw_unstack_ta_scatter =
+    VectorRef({GetPrim(lite::kNameTensorArrayScatterV3), in_ta_handle, fw_unstack_range, input, in_ta_flow});
+  auto in_ta_enter1 = VectorRef({GetPrim(lite::kNameEnter), fw_unstack_ta_scatter});
+  auto in_ta_enter = VectorRef({GetPrim(lite::kNameEnter), in_ta_handle});
+  auto in_ta_read = VectorRef({GetPrim(lite::kNameTensorArrayReadV3), in_ta_enter, switch1_true, in_ta_enter1});
+
+  auto greater_equal_enter = VectorRef({GetPrim(lite::kNameEnter), input_length_});
+  auto greater_equal = VectorRef({GetPrim(prim::kPrimGreaterEqual), switch1_true, greater_equal_enter});
+  auto select1 = VectorRef({GetPrim(prim::kPrimSelect), greater_equal, std::make_shared()});  // select h
+
+  auto next_iteration3 = VectorRef({GetPrim(lite::kNameNextIteration), select1});
+  auto enter3 = VectorRef({GetPrim(lite::kNameEnter), init_state});
+  auto merge3 = VectorRef({GetPrim(prim::kPrimMerge), enter3, next_iteration3});
+  auto switch3 = VectorRef({GetPrim(prim::kPrimSwitch), merge3, loop_cond});
+  auto switch3_true = VectorRef({GetPrim(prim::kPrimTupleGetItem), switch3, std::make_shared()});  // identity3
+
+  auto rnn_cell_out = DefineGruCellPattern(in_ta_read, switch3_true, vars);
+
+  auto out_ta = VectorRef({GetPrim(lite::kNameTensorArrayV3), strided_slice});
+  auto out_ta_handle = VectorRef({GetPrim(prim::kPrimTupleGetItem), out_ta, std::make_shared()});
+  auto out_ta_flow = VectorRef({GetPrim(prim::kPrimTupleGetItem), out_ta, std::make_shared()});
+  auto out_ta_enter = VectorRef({GetPrim(lite::kNameEnter), out_ta_handle});
+
+  auto switch2_true = VectorRef({GetPrim(prim::kPrimTupleGetItem), std::make_shared()});  // cycle
+
+  auto concat1 = VectorRef({GetPrim(prim::kPrimConcat), std::make_shared()});
+  auto zeros1 = VectorRef({GetPrim(prim::kPrimFill), std::make_shared(IsParameterNode), concat1});
+  auto select_enter = VectorRef({GetPrim(lite::kNameEnter), zeros1});
+  auto select = VectorRef({GetPrim(prim::kPrimSelect), greater_equal, select_enter, rnn_cell_out});  // select x
+  auto ta_write = VectorRef({GetPrim(lite::kNameTensorArrayWriteV3), out_ta_enter, switch1_true, select, switch2_true});
+
+  auto enter2 = VectorRef({GetPrim(lite::kNameEnter), out_ta_flow});
+  auto next_iter2 = VectorRef({GetPrim(lite::kNameNextIteration), ta_write});
+  auto merge2 = VectorRef({GetPrim(prim::kPrimMerge), enter2, next_iter2});
+  auto switch2 = VectorRef({GetPrim(prim::kPrimSwitch), merge2, loop_cond});
+  auto switch2_false = VectorRef({GetPrim(prim::kPrimTupleGetItem), switch2, std::make_shared()});
+
+  auto exit2 = VectorRef({GetPrim(lite::kNameExit), switch2_false});
+  auto ta_size = VectorRef({GetPrim(lite::kNameTensorArraySizeV3), out_ta_handle, exit2});
+  auto range = VectorRef({GetPrim(prim::kPrimRange), std::make_shared(), ta_size, std::make_shared()});
+  auto tensor_array_gather = VectorRef({GetPrim(lite::kNameTensorArrayGatherV3), out_ta_handle, range, exit2});
+  auto range1 = VectorRef({GetPrim(prim::kPrimRange), std::make_shared()});
+  auto concat2 = VectorRef({GetPrim(prim::kPrimConcat), std::make_shared(IsParameterNode), range1});
+  auto fw_out_trans = VectorRef({GetPrim(prim::kPrimTranspose), tensor_array_gather, concat2});
+  return fw_out_trans;
+}
+
+const BaseRef TfBidirectionGruCfFusion::DefinePattern() const {
+  const auto fw_out_trans = DefineBidirectionRnnPattern(transpose_input_, fw_vars_, fw_init_state_);
+
+  auto bw_reverse_in = VectorRef({GetPrim(prim::kPrimReverseSequence), input_, input_length_});
+  auto bw_range = VectorRef({GetPrim(prim::kPrimRange), std::make_shared()});
+  auto bw_concat = VectorRef({GetPrim(prim::kPrimConcat), std::make_shared(IsParameterNode), bw_range});
+  auto bw_transpose = VectorRef({GetPrim(prim::kPrimTranspose), bw_reverse_in, bw_concat});
+  auto bw_out_trans = DefineBidirectionRnnPattern(bw_transpose, bw_vars_, bw_init_state_);
+  auto bw_reverse_out = VectorRef({GetPrim(prim::kPrimReverseSequence), bw_out_trans, input_length_});
+  auto concat = VectorRef({GetPrim(prim::kPrimConcat), fw_out_trans, bw_reverse_out});
+  return concat;
+}
+
+const AnfNodePtr TfBidirectionGruCfFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &concat_node,
+                                                   const EquivPtr &equiv) const {
+  MS_ASSERT(func_graph != nullptr);
+  MS_ASSERT(concat_node != nullptr);
+  MS_LOG(DEBUG) << "bidirection tf gru fusion pass";
+
+  if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(concat_node) != lite::RET_OK) {
+    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
+    return nullptr;
+  }
+
+  auto transpose_input = utils::cast<AnfNodePtr>((*equiv)[transpose_input_]);
+  MS_ASSERT(transpose_input != nullptr);
+
+  const std::string gru_name = "gru_" + concat_node->fullname_with_scope();
+  auto gru_node = CreateBiDirectionGruNode(func_graph, transpose_input, equiv, gru_name, 0);
+  if (gru_node == nullptr) {
+    return nullptr;
+  }
+  if (TfliteLstmCellFusion::SetAbstractTuple(gru_node, 2) != RET_OK) {
+    return nullptr;
+  }
+
+  auto get_item_node = TfliteLstmCellFusion::CreateOutputGetItem(func_graph, gru_node, 0);
+  if (get_item_node == nullptr) {
+    return nullptr;
+  }
+
+  auto output_node = GetPostProcessNode(func_graph, get_item_node, gru_node->fullname_with_scope());
+  MS_LOG(INFO) << "gru node:" << gru_node->fullname_with_scope() << " fusion success";
+  return output_node;
+}
+} // namespace opt
+} // namespace mindspore
diff --git a/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h b/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h
new file mode 100644
index 0000000000..a54a797fd1
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h
@@ -0,0 +1,48 @@
+/**
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_CF_FUSION_H_
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_CF_FUSION_H_
+
+#include
+#include
+#include
+#include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
+#include "schema/inner/model_generated.h"
+#include "src/param_value_lite.h"
+#include "backend/optimizer/common/optimizer.h"
+#include "utils/utils.h"
+#include "include/errorcode.h"
+
+namespace mindspore {
+namespace opt {
+// fuse tf 1.x bidirection_gru into MSLITE GRU
+class TfBidirectionGruCfFusion : public TfBidirectionGruFusion {
+ public:
+  explicit TfBidirectionGruCfFusion(const std::string &name = "tf_bidirection_gru_cf_fusion", bool multi_graph = true);
+  ~TfBidirectionGruCfFusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  BaseRef DefineGruCellPattern(const BaseRef &in_ta_read, const BaseRef &switch3_true,
+                               const std::vector &vars) const;
+  const BaseRef DefineBidirectionRnnPattern(const BaseRef &input, const std::vector &vars,
+                                            const VarPtr &init_state) const;
+};
+} // namespace opt
+} // namespace mindspore
+
+#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_CF_FUSION_H_
diff --git a/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc b/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc
similarity index 88%
rename from mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc
rename to mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc
index 730202a47d..a04ffe7cdc 100644
--- a/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc
@@ -13,7 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h"
+#include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
 #include
 #include
 #include "ops/concat.h"
@@ -24,32 +24,21 @@
 #include "ops/transpose.h"
 #include "src/common/utils.h"
 #include "utils/utils.h"
-#include "tools/optimizer/common/gllo_utils.h"
 #include "securec/include/securec.h"
 
 namespace mindspore {
 namespace opt {
 namespace {
-constexpr size_t kWhileUniqInputsLength = 6;
 constexpr size_t kCondNodesNum = 12;
 constexpr size_t kCondCNodesNum = 4;
 constexpr size_t kBodyNodesNum = 69;
 constexpr size_t kBodyCNodesNum = 25;
 const auto &p1 = std::placeholders::_1;
-
-bool IsParameterNode(const BaseRef &n) { return utils::isa(n); }
-
-bool IsOpType(const BaseRef &n, const PrimitivePtr &prim) {
-  if (utils::isa(n)) {
-    auto anf_node = utils::cast(n);
-    return CheckPrimitiveType(anf_node, prim);
-  }
-  return false;
-}
 } // namespace
 
-BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name, bool multigraph)
-    : PatternProcessPass(name, multigraph) {
+TfBidirectionGruFusion::TfBidirectionGruFusion(int num_fw_vars, int num_bw_vars, const std::string &name,
+                                               bool multi_graph)
+    : PatternProcessPass(name, multi_graph) {
   /*
    * vars for while input
    * fw_while_inputs:
@@ -57,8 +46,10 @@ BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name,
    * bw_while_inputs:
    * 0:cond 1:body 2:kernel_gate 3:bias_gate 4:cand_kernel 5:cand_bias
    */
-  for (size_t i = 0; i < kWhileUniqInputsLength; ++i) {
+  for (int i = 0; i < num_fw_vars; ++i) {
     fw_vars_.emplace_back(std::make_shared());
+  }
+  for (int i = 0; i < num_bw_vars; ++i) {
     bw_vars_.emplace_back(std::make_shared());
   }
   input_ = std::make_shared();
@@ -68,7 +59,7 @@ BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name,
   bw_init_state_ = std::make_shared();
 }
 
-const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
+const BaseRef TfBidirectionGruFusion::DefinePattern() const {
   // forward
   auto fw_reduce = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimReduceFusion)),
                               input_length_, std::make_shared(IsParameterNode)});
@@ -134,7 +125,7 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
   return concat;
 }
 
-AnfNodePtr BiDirectionTfGruCellFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
+AnfNodePtr TfBidirectionGruFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
   auto is_parameter1 = std::make_shared(IsParameterNode);
   auto is_parameter2 = std::make_shared(IsParameterNode);
   auto is_parameter3 = std::make_shared(IsParameterNode);
@@ -152,7 +143,7 @@ AnfNodePtr BiDirectionTfGruCellFusion::GetCondGraphPattern(const PrimitiveVarMap
   return pattern;
 }
 
-AnfNodePtr BiDirectionTfGruCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
+AnfNodePtr TfBidirectionGruFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
   std::vector placeholders;
   for (int i = 0; i < 13; ++i) {
     placeholders.emplace_back(std::make_shared(IsParameterNode));
   }
@@ -206,7 +197,7 @@ AnfNodePtr BiDirectionTfGruCellFusion::GetBodyGraphPattern(const PrimitiveVarMap
   return pattern;
 }
 
-ParamValueLitePtr BiDirectionTfGruCellFusion::GetDefaultParamValue(const AnfNodePtr &parameter_anf) const {
+ParamValueLitePtr TfBidirectionGruFusion::GetDefaultParamValue(const AnfNodePtr &parameter_anf) const {
   MS_ASSERT(parameter_anf != nullptr);
   if (!utils::isa(parameter_anf)) {
"parameter_anf is not ParameterPtr"; @@ -221,9 +212,9 @@ ParamValueLitePtr BiDirectionTfGruCellFusion::GetDefaultParamValue(const AnfNode return param_value; } -STATUS BiDirectionTfGruCellFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf, - const AnfNodePtr &bw_cand_kernel_anf, int *input_size, - int *hidden_size) const { +STATUS TfBidirectionGruFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf, + const AnfNodePtr &bw_cand_kernel_anf, int *input_size, + int *hidden_size) const { MS_ASSERT(fw_cand_kernel != nullptr); MS_ASSERT(bw_cand_kernel != nullptr); MS_ASSERT(input_size != nullptr); @@ -256,9 +247,9 @@ STATUS BiDirectionTfGruCellFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_ca return RET_OK; } -ParameterPtr BiDirectionTfGruCellFusion::AddDefaultParameter(const FuncGraphPtr &func_graph, const std::string &name, - const std::vector &shape, const TypeId type, - void **tensor_data) const { +ParameterPtr TfBidirectionGruFusion::AddDefaultParameter(const FuncGraphPtr &func_graph, const std::string &name, + const std::vector &shape, const TypeId type, + void **tensor_data) const { MS_ASSERT(func_graph != nullptr); MS_ASSERT(tensor_data != nullptr); auto parameter = func_graph->add_parameter(); @@ -300,9 +291,8 @@ ParameterPtr BiDirectionTfGruCellFusion::AddDefaultParameter(const FuncGraphPtr return parameter; } -void BiDirectionTfGruCellFusion::CopyFlattenMatData(const float *mat, const int R, const int C, const int r0, - const int r1, const int c0, const int c1, float *data, - bool t) const { +void TfBidirectionGruFusion::CopyFlattenMatData(const float *mat, const int R, const int C, const int r0, const int r1, + const int c0, const int c1, float *data, bool t) const { MS_ASSERT(mat != nullptr); MS_ASSERT(data != nullptr); MS_ASSERT(0 <= r0 && r0 < r1 && r1 <= R); @@ -320,9 +310,9 @@ void BiDirectionTfGruCellFusion::CopyFlattenMatData(const float *mat, const int } } -STATUS BiDirectionTfGruCellFusion::ConvertWeightData(const AnfNodePtr &gate_weight, const AnfNodePtr &cand_weight, - const int input_size, const int hidden_size, - float *gate_tensor_data, float *recu_tensor_data) const { +STATUS TfBidirectionGruFusion::ConvertWeightData(const AnfNodePtr &gate_weight, const AnfNodePtr &cand_weight, + const int input_size, const int hidden_size, float *gate_tensor_data, + float *recu_tensor_data) const { MS_ASSERT(gate_weight != nullptr); MS_ASSERT(cand_weight != nullptr); MS_ASSERT(gate_tensor_data != nullptr); @@ -375,8 +365,8 @@ STATUS BiDirectionTfGruCellFusion::ConvertWeightData(const AnfNodePtr &gate_weig return RET_OK; } -STATUS BiDirectionTfGruCellFusion::ConvertBiasData(const AnfNodePtr &gate_bias, const AnfNodePtr &cand_bias, - const int hidden_size, float *tensor_data) const { +STATUS TfBidirectionGruFusion::ConvertBiasData(const AnfNodePtr &gate_bias, const AnfNodePtr &cand_bias, + const int hidden_size, float *tensor_data) const { MS_ASSERT(bias != nullptr); MS_ASSERT(tensor_data != nullptr); std::vector gate_shape{hidden_size * 2}; @@ -407,10 +397,9 @@ STATUS BiDirectionTfGruCellFusion::ConvertBiasData(const AnfNodePtr &gate_bias, return RET_OK; } -CNodePtr BiDirectionTfGruCellFusion::GetStackedHiddenState(const FuncGraphPtr &func_graph, - const AnfNodePtr &fw_init_state, - const AnfNodePtr &bw_init_state, - const std::string base_name) const { +CNodePtr TfBidirectionGruFusion::GetStackedHiddenState(const FuncGraphPtr &func_graph, const AnfNodePtr &fw_init_state, + const AnfNodePtr &bw_init_state, + const std::string base_name) const { 
   MS_ASSERT(func_graph != nullptr);
   MS_ASSERT(fw_init_state != nullptr);
   MS_ASSERT(bw_init_state != nullptr);
@@ -424,35 +413,32 @@ CNodePtr BiDirectionTfGruCellFusion::GetStackedHiddenState(const FuncGraphPtr &f
   return new_node;
 }
 
-CNodePtr BiDirectionTfGruCellFusion::CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
-                                                              const EquivPtr &equiv, const EquivPtr &fw_body_equiv,
-                                                              const EquivPtr &bw_body_equiv,
-                                                              const std::string &base_name) const {
+CNodePtr TfBidirectionGruFusion::CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
+                                                          const EquivPtr &equiv, const std::string &base_name,
+                                                          int var_offset) const {
   MS_ASSERT(func_graph != nullptr);
   MS_ASSERT(input != nullptr);
   MS_ASSERT(equiv != nullptr);
-  MS_ASSERT(fw_body_equiv != nullptr);
-  MS_ASSERT(bw_body_equiv != nullptr);
   auto gru_prim = std::make_shared<ops::GRU>();
   gru_prim->set_bidirectional(true);
   auto value_node = NewValueNode(gru_prim);
 
-  auto fw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[2]]);
+  auto fw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[var_offset]]);
   MS_ASSERT(fw_gate_kernel != nullptr);
-  auto fw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[3]]);
+  auto fw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[var_offset + 1]]);
   MS_ASSERT(fw_gate_bias != nullptr);
-  auto fw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[4]]);
+  auto fw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[var_offset + 2]]);
   MS_ASSERT(fw_cand_kernel != nullptr);
-  auto fw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[5]]);
+  auto fw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[var_offset + 3]]);
   MS_ASSERT(fw_cand_bias != nullptr);
-  auto bw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[2]]);
+  auto bw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[var_offset]]);
   MS_ASSERT(bw_gate_kernel != nullptr);
-  auto bw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[3]]);
+  auto bw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[var_offset + 1]]);
   MS_ASSERT(bw_gate_bias != nullptr);
-  auto bw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[4]]);
+  auto bw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[var_offset + 2]]);
   MS_ASSERT(bw_cand_kernel != nullptr);
-  auto bw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[5]]);
+  auto bw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[var_offset + 3]]);
   MS_ASSERT(bw_cand_bias != nullptr);
 
   auto fw_init_state = utils::cast<AnfNodePtr>((*equiv)[fw_init_state_]);
@@ -522,8 +508,8 @@ CNodePtr BiDirectionTfGruCellFusion::CreateBiDirectionGruNode(const FuncGraphPtr
   return new_node;
 }
 
-CNodePtr BiDirectionTfGruCellFusion::GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output,
-                                                        const std::string base_name) const {
+CNodePtr TfBidirectionGruFusion::GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output,
+                                                    const std::string base_name) const {
   MS_ASSERT(func_graph != nullptr);
   MS_ASSERT(gru_output != nullptr);
   auto split_prim = std::make_shared<ops::Split>();
@@ -571,8 +557,8 @@ CNodePtr BiDirectionTfGruCellFusion::GetPostProcessNode(const FuncGraphPtr &func
   return transpose_new_node;
 }
 
-const AnfNodePtr BiDirectionTfGruCellFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &concat_node,
-                                                     const EquivPtr &equiv) const {
+const AnfNodePtr TfBidirectionGruFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &concat_node,
+                                                 const EquivPtr &equiv) const {
   MS_ASSERT(func_graph != nullptr);
   MS_ASSERT(concat_node != nullptr);
   MS_LOG(DEBUG) << "bidirection tf gru fusion pass";
@@ -628,7 +614,7 @@ const AnfNodePtr BiDirectionTfGruCellFusion::Process(const FuncGraphPtr &func_gr
   }
   const std::string gru_name = "gru_" + concat_node->fullname_with_scope();
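+  // var_offset = 2 skips cond (index 0) and body (index 1) in fw/bw_while_inputs, per the
+  // layout documented in the constructor, so the offset lands on kernel_gate/bias_gate/
+  // cand_kernel/cand_bias (descriptive comment only).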
-  auto gru_node = CreateBiDirectionGruNode(func_graph, transpose_input, equiv, fw_body_equiv, bw_body_equiv, gru_name);
+  auto gru_node = CreateBiDirectionGruNode(func_graph, transpose_input, equiv, gru_name, 2);
   if (gru_node == nullptr) {
     return nullptr;
   }
diff --git a/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h b/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.h
similarity index 74%
rename from mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h
rename to mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.h
index a4222e47d7..0c31514926 100644
--- a/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h
+++ b/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.h
@@ -13,12 +13,13 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BIDIRECTION_TF_GRU_CELL_FUSION_H_
-#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BIDIRECTION_TF_GRU_CELL_FUSION_H_
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_FUSION_H_
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_FUSION_H_
 #include 
 #include 
 #include 
 #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
+#include "tools/optimizer/common/gllo_utils.h"
 #include "schema/inner/model_generated.h"
 #include "src/param_value_lite.h"
 #include "backend/optimizer/common/optimizer.h"
@@ -27,22 +28,26 @@
 namespace mindspore {
 namespace opt {
-class BiDirectionTfGruCellFusion : public PatternProcessPass {
+constexpr size_t kWhileUniqInputsLength = 6;
+// fuse tf 2.x bidirection_gru into MSLITE GRU
+class TfBidirectionGruFusion : public PatternProcessPass {
  public:
-  explicit BiDirectionTfGruCellFusion(const std::string &name = "bidirection_tf_gru_cell_fusion",
-                                      bool multigraph = true);
-  ~BiDirectionTfGruCellFusion() override = default;
+  explicit TfBidirectionGruFusion(int num_fw_vars = kWhileUniqInputsLength, int num_bw_vars = kWhileUniqInputsLength,
+                                  const std::string &name = "tf_bidirection_gru_fusion", bool multi_graph = true);
+  ~TfBidirectionGruFusion() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 
  protected:
   virtual AnfNodePtr GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const;
+  CNodePtr CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const EquivPtr &equiv,
+                                    const std::string &base_name, int var_offset) const;
+  CNodePtr GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output,
+                              const std::string base_name) const;
 
  private:
   AnfNodePtr GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const;
-  CNodePtr CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const EquivPtr &equiv,
-                                    const EquivPtr &fw_body_equiv, const EquivPtr &bw_body_equiv,
-                                    const std::string &base_name) const;
+
   ParamValueLitePtr GetDefaultParamValue(const AnfNodePtr &parameter_anf) const;
   lite::STATUS GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf, const AnfNodePtr &bw_cand_kernel_anf,
                                      int *input_size, int *hidden_size) const;
@@ -56,10 +61,8 @@ class BiDirectionTfGruCellFusion : public PatternProcessPass {
                          const int c1, float *data, bool t = false) const;
   CNodePtr GetStackedHiddenState(const FuncGraphPtr &func_graph, const AnfNodePtr &fw_init_state,
                                  const AnfNodePtr &bw_init_state, const std::string base_name) const;
-  CNodePtr GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output,
-                              const std::string base_name) const;
 
- private:
+ protected:
   std::vector<VarPtr> fw_vars_;
   std::vector<VarPtr> bw_vars_;
   VarPtr input_;
@@ -68,7 +71,16 @@ class BiDirectionTfGruCellFusion : public PatternProcessPass {
   VarPtr fw_init_state_;
   VarPtr bw_init_state_;
 };
+inline bool IsParameterNode(const BaseRef &n) { return utils::isa<ParameterPtr>(n); }
+
+inline bool IsOpType(const BaseRef &n, const PrimitivePtr &prim) {
+  if (utils::isa<AnfNodePtr>(n)) {
+    auto anf_node = utils::cast<AnfNodePtr>(n);
+    return CheckPrimitiveType(anf_node, prim);
+  }
+  return false;
+}
 }  // namespace opt
 }  // namespace mindspore
 
-#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BIDIRECTION_TF_GRU_CELL_FUSION_H_
+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_FUSION_H_
diff --git a/mindspore/lite/tools/optimizer/graph/functionalize_cond.cc b/mindspore/lite/tools/optimizer/graph/functionalize_cond.cc
new file mode 100644
index 0000000000..ca241be534
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/graph/functionalize_cond.cc
@@ -0,0 +1,249 @@
+/**
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tools/optimizer/graph/functionalize_cond.h"
+#include 
+#include 
+#include 
+#include "include/errorcode.h"
+#include "ops/make_tuple.h"
+#include "tools/converter/ops/ops_def.h"
+#include "ops/return.h"
+
+namespace mindspore::opt {
+
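+// Overall flow (descriptive summary of this file): starting from a Merge node, each of
+// its two inputs is walked backwards to collect one branch; every branch is packaged
+// into its own FuncGraph, the shared Switch predicate becomes the If condition, and the
+// users of the Merge are rewired to the new If node. Branch indexing follows TF Switch
+// semantics: output 0 carries the false (else) path, output 1 the true (then) path.
+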
+STATUS FunctionalizeCond::GetSwitchBranchType(const CNodePtr &switch_cnode, BranchType *branch_type) {
+  MS_ASSERT(switch_cnode != nullptr);
+  MS_ASSERT(branch_type != nullptr);
+  auto manager = fg_->manager();
+  if (manager == nullptr) {
+    MS_LOG(ERROR) << "manager is nullptr";
+    return RET_ERROR;
+  }
+  auto node_users = manager->node_users()[switch_cnode];
+  if (node_users.size() != 1) {  // only one output of switch is referenced in cond
+    MS_LOG(ERROR) << "switch's node users is not correct";
+    return RET_ERROR;
+  }
+  auto node_user = node_users.front();
+  auto tuple_get_item = node_user.first;
+  if (!utils::isa<CNodePtr>(tuple_get_item) || !CheckPrimitiveType(tuple_get_item, prim::kPrimTupleGetItem)) {
+    MS_LOG(ERROR) << "switch's node user is not TupleGetItem";
+    return RET_ERROR;
+  }
+  auto tuple_get_item_cnode = utils::cast<CNodePtr>(tuple_get_item);
+  auto idx = GetTupleGetItemOutIndex(tuple_get_item_cnode);
+  if (idx == 0) {
+    *branch_type = kElseBranch;
+  } else if (idx == 1) {
+    *branch_type = kThenBranch;
+  } else {
+    MS_LOG(ERROR) << "wrong tuple_get_item index";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
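+// BFS from the branch root back towards the graph inputs: traversal stops at Switch
+// nodes (only their branch type is validated), Parameters are re-added as parameters
+// of the branch graph, and every other visited node is moved into it.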
+STATUS FunctionalizeCond::BranchSubGraphAddNodes(const FuncGraphPtr &graph, const AnfNodePtr &root_node,
+                                                 BranchType branch_type) {
+  std::deque<AnfNodePtr> q;
+  std::unordered_set<AnfNodePtr> vis;
+  q.push_back(root_node);
+  while (!q.empty()) {
+    auto node = q.front();
+    q.pop_front();
+    vis.insert(node);
+    if (FunctionalizeControlOpPass::IsSwitch(node)) {
+      auto cnode = utils::cast<CNodePtr>(node);
+      BranchType this_type;
+      if (GetSwitchBranchType(cnode, &this_type) != RET_OK || this_type != branch_type) {
+        MS_LOG(ERROR) << "switch node in branch " << branch_type << " is not correct";
+        return RET_ERROR;
+      }
+      continue;
+    }
+    if (utils::isa<ParameterPtr>(node)) {
+      graph->add_parameter(node->cast<ParameterPtr>());
+    } else {
+      graph->AddNode(node);
+    }
+    node->set_func_graph(graph);
+    if (utils::isa<CNodePtr>(node)) {
+      auto cnode = utils::cast<CNodePtr>(node);
+      for (size_t i = 1; i < cnode->inputs().size(); i++) {
+        auto inputi = cnode->input(i);
+        if (vis.find(inputi) == vis.end()) {
+          q.push_back(cnode->input(i));
+        }
+      }
+    }
+  }
+  return RET_OK;
+}
+
+int FunctionalizeCond::PosInInputNodes(const CNodePtr &node) {
+  auto index = std::find(input_nodes_.begin(), input_nodes_.end(), node);
+  if (index == input_nodes_.end()) {
+    input_nodes_.push_back(node);
+    return input_nodes_.size() - 1;
+  }
+  return index - input_nodes_.begin();
+}
+
+STATUS FunctionalizeCond::IdentifySubgraphInput(const FuncGraphPtr &graph, std::string graph_name) {
+  std::vector<CNodePtr> nodes_need_drop{};
+  for (auto &cnode : graph->GetOrderedCnodes()) {
+    for (auto &input_node : cnode->inputs()) {
+      if (FunctionalizeControlOpPass::IsSwitch(input_node)) {
+        auto switch_node = input_node->cast<CNodePtr>();
+        auto switch_input = utils::cast<CNodePtr>(switch_node->input(1));
+        auto pos = PosInInputNodes(switch_input);
+        nodes_need_drop.push_back(cnode);
+        pred_nodes_.push_back(switch_node->input(2));
+        // set parameter
+        auto parameter = graph->add_parameter();
+        parameter->set_abstract(cnode->abstract());
+        // hardcode for subgraph input name
+        parameter->set_name(graph_name + "_input_" + std::to_string(pos) + "_parameter");
+
+        // replace switch
+        auto manager = fg_->manager();
+        auto node_users = manager->node_users()[cnode];
+        for (auto &node_user : node_users) {
+          if (graph->nodes().contains(node_user.first)) {
+            manager->SetEdge(node_user.first, node_user.second, parameter);
+          }
+        }
+      }
+    }
+  }
+  return RET_OK;
+}
+
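+// Builds one branch FuncGraph rooted at a Merge input. When that input is the Switch
+// itself the branch graph stays empty and no Return is attached here; otherwise a
+// single-output Return node is appended.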
+FuncGraphPtr FunctionalizeCond::CreateBranchGraph(const AnfNodePtr &node, std::string name, BranchType branch_type) {
+  auto graph = FunctionalizeControlOpPass::NewFuncGraph(name, mindspore::lite::converter::FmkType_TF);
+  if (graph == nullptr) {
+    MS_LOG(ERROR) << "new graph Partial Node return nullptr";
+    return nullptr;
+  }
+  graph->set_manager(fg_->manager());
+  auto status = BranchSubGraphAddNodes(graph, node, branch_type);
+  if (status != RET_OK) {
+    return nullptr;
+  }
+
+  if (!CheckPrimitiveType(node, prim::kPrimSwitch)) {  // graph is not empty
+    auto return_prim_ptr = std::make_shared<ops::Return>();
+    if (return_prim_ptr == nullptr) {
+      MS_LOG(ERROR) << "GetReturnPrim return nullptr";
+      return nullptr;
+    }
+    auto value_node = NewValueNode(return_prim_ptr);
+    std::vector<AnfNodePtr> op_inputs{value_node, node};  // If subgraph only has one output tensor
+    auto return_cnode = graph->NewCNode(op_inputs);
+    return_cnode->set_fullname_with_scope(name + "-return");
+    return_cnode->set_func_graph(graph);
+    graph->set_return(return_cnode);
+    graph->output()->cast<CNodePtr>()->set_fullname_with_scope(name + "_output_0_cnode");
+  }
+  return graph;
+}
+
+CNodePtr FunctionalizeCond::CreateNewIf(const FuncGraphPtr &else_branch, const FuncGraphPtr &then_branch) {
+  MS_ASSERT(else_branch != nullptr);
+  MS_ASSERT(then_branch != nullptr);
+
+  auto if_primc = std::make_shared<mindspore::lite::If>();
+  if (if_primc == nullptr) {
+    MS_LOG(ERROR) << "new if_primitive failed";
+    return nullptr;
+  }
+  auto if_value_node = NewValueNode(if_primc);
+  if (if_value_node == nullptr) {
+    return nullptr;
+  }
+  auto then_value_node = NewValueNode(then_branch);
+  auto else_value_node = NewValueNode(else_branch);
+  std::vector<AnfNodePtr> if_op_inputs = {if_value_node, then_value_node, else_value_node, pred_node_};
+  std::copy(input_nodes_.begin(), input_nodes_.end(), std::back_inserter(if_op_inputs));
+  return fg_->NewCNode(if_op_inputs);
+}
+
+STATUS FunctionalizeCond::VerifyPredictNode() {
+  if (pred_nodes_.empty()) {
+    return RET_ERROR;
+  }
+  for (size_t i = 1; i < pred_nodes_.size(); ++i) {
+    if (pred_nodes_[i] != pred_nodes_[0]) {
+      return RET_ERROR;
+    }
+  }
+  if (!utils::isa<CNodePtr>(pred_nodes_[0])) {
+    return RET_ERROR;
+  }
+  pred_node_ = utils::cast<CNodePtr>(pred_nodes_[0]);
+  return RET_OK;
+}
+
+STATUS FunctionalizeCond::Process() {
+  if (fg_ == nullptr || merge_node_ == nullptr || merge_node_->inputs().size() != 3) {
+    MS_LOG(ERROR) << "fg or merge is not correct";
+    return RET_ERROR;
+  }
+
+  auto else_branch_name = merge_node_->fullname_with_scope() + "-partial-if-else";
+  auto then_branch_name = merge_node_->fullname_with_scope() + "-partial-then-else";
+
+  auto else_branch = CreateBranchGraph(merge_node_->input(1), else_branch_name, kElseBranch);
+  if (else_branch == nullptr) {
+    MS_LOG(ERROR) << "create else branch failed";
+    return RET_ERROR;
+  }
+  auto then_branch = CreateBranchGraph(merge_node_->input(2), then_branch_name, kThenBranch);
+  if (then_branch == nullptr) {
+    MS_LOG(ERROR) << "create then branch failed";
+    return RET_ERROR;
+  }
+
+  auto status = IdentifySubgraphInput(else_branch, else_branch_name);
+  if (status != RET_OK) {
+    return status;
+  }
+  status = IdentifySubgraphInput(then_branch, then_branch_name);
+  if (status != RET_OK) {
+    return status;
+  }
+
+  status = VerifyPredictNode();
+  if (status != RET_OK) {
+    return status;
+  }
+
+  auto if_node = CreateNewIf(else_branch, then_branch);
+  if (if_node == nullptr) {
+    MS_LOG(ERROR) << "create if node error";
+    return RET_ERROR;
+  }
+  if_node->set_abstract(merge_node_->abstract()->Clone());
+  auto manager = fg_->manager();
+  auto node_users = manager->node_users()[merge_node_];
+  for (auto &node_user : node_users) {
+    manager->SetEdge(node_user.first, node_user.second, if_node);
+  }
+  return RET_OK;
+}
+}  // namespace mindspore::opt
diff --git a/mindspore/lite/tools/optimizer/graph/functionalize_cond.h b/mindspore/lite/tools/optimizer/graph/functionalize_cond.h
new file mode 100644
index 0000000000..2f150669b5
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/graph/functionalize_cond.h
@@ -0,0 +1,58 @@
+/**
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_FUNCTIONALIZE_COND_H_
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_FUNCTIONALIZE_COND_H_
+
+#include 
+#include 
+#include 
+#include 
+#include "backend/optimizer/common/pass.h"
+#include "tools/converter/converter_flags.h"
+#include "tools/optimizer/common/gllo_utils.h"
+#include "tools/optimizer/graph/functionalize_control_op_pass.h"
+
+using mindspore::lite::converter::FmkType;
+namespace mindspore::opt {
+
+typedef enum { kThenBranch = 0, kElseBranch = 1 } BranchType;
+
+// Functionalize all the switch-merge nodes of a loop-free graph into a single If node.
+// Precondition: While loops must have been functionalized.
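+// Typical use (sketch, mirroring FunctionalizeControlOpPass::BuildIfSubgraph below):
+//   FunctionalizeCond fc(func_graph, merge_cnode);
+//   auto status = fc.Process();  // replaces the switch-merge cluster with an If node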
+class FunctionalizeCond {
+ public:
+  FunctionalizeCond(FuncGraphPtr fg, CNodePtr merge_node) : fg_(fg), merge_node_(merge_node) {}
+
+  STATUS Process();
+
+ private:
+  STATUS GetSwitchBranchType(const CNodePtr &switch_cnode, BranchType *branch_type);
+  STATUS BranchSubGraphAddNodes(const FuncGraphPtr &graph, const AnfNodePtr &root_node, BranchType branch_type);
+  FuncGraphPtr CreateBranchGraph(const AnfNodePtr &node, std::string name, BranchType branch_type);
+  int PosInInputNodes(const CNodePtr &node);
+  STATUS IdentifySubgraphInput(const FuncGraphPtr &graph, std::string graph_name);
+  CNodePtr CreateNewIf(const FuncGraphPtr &else_branch, const FuncGraphPtr &then_branch);
+  STATUS VerifyPredictNode();
+
+  FuncGraphPtr fg_ = nullptr;
+  CNodePtr merge_node_ = nullptr;
+  CNodePtr pred_node_ = nullptr;
+  std::vector<CNodePtr> input_nodes_{};
+  std::vector<AnfNodePtr> pred_nodes_{};
+};
+}  // namespace mindspore::opt
+
+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_FUNCTIONALIZE_COND_H_
diff --git a/mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.cc b/mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.cc
index 527d7cfe0c..2a8704f51b 100644
--- a/mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.cc
+++ b/mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.cc
@@ -18,6 +18,7 @@
 #include 
 #include 
 #include "tools/optimizer/graph/functionalize_while.h"
+#include "tools/optimizer/graph/functionalize_cond.h"
 #include "include/errorcode.h"
 
 namespace mindspore::opt {
@@ -100,6 +101,25 @@ STATUS FunctionalizeControlOpPass::BuildWhileSubgraph(const FuncGraphPtr &func_g
   return ret;
 }
 
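+// Walks every Merge node in the graph and functionalizes its switch-merge cluster
+// into an If subgraph; aborts on the first failure.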
+STATUS FunctionalizeControlOpPass::BuildIfSubgraph(const FuncGraphPtr &func_graph) {
+  int ret = RET_OK;
+  auto nodes = func_graph->nodes();
+  for (auto &node : nodes) {
+    if (!IsMerge(node)) {
+      continue;
+    }
+    auto cnode = utils::cast<CNodePtr>(node);
+    FunctionalizeCond fc(func_graph, cnode);
+    ret = fc.Process();
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "run functionalize cond failed, ret: " << ret;
+      return ret;
+    }
+  }
+
+  return ret;
+}
+
 bool FunctionalizeControlOpPass::Run(const FuncGraphPtr &func_graph) {
   // use name to find the frame
   InitNodeClusters(func_graph);
@@ -107,6 +127,10 @@ bool FunctionalizeControlOpPass::Run(const FuncGraphPtr &func_graph) {
     MS_LOG(ERROR) << "build while subgraph failed.";
     return false;
   }
+  if (BuildIfSubgraph(func_graph) != RET_OK) {
+    MS_LOG(ERROR) << "build if subgraph failed.";
+    return false;
+  }
   return true;
 }
 
diff --git a/mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.h b/mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.h
index 3d470943e0..3e23fdd73c 100644
--- a/mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.h
+++ b/mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.h
@@ -23,10 +23,7 @@
 #include 
 #include "backend/optimizer/common/pass.h"
 #include "tools/converter/converter_flags.h"
-#include "tools/converter/ops/enter.h"
-#include "tools/converter/ops/exit.h"
-#include "tools/converter/ops/loop_cond.h"
-#include "tools/converter/ops/next_iteration.h"
+#include "tools/converter/ops/ops_def.h"
 #include "tools/optimizer/common/gllo_utils.h"
 
 using mindspore::lite::converter::FmkType;
@@ -70,6 +67,7 @@ class FunctionalizeControlOpPass : public Pass {
 
  protected:
   STATUS BuildWhileSubgraph(const FuncGraphPtr &func_graph);
+  STATUS BuildIfSubgraph(const FuncGraphPtr &func_graph);
   std::vector>> node_clusters_{};
   std::vector loop_cond_nodes_{};
 };
diff --git a/mindspore/lite/tools/optimizer/graph/unused_node_remove_pass.cc b/mindspore/lite/tools/optimizer/graph/unused_node_remove_pass.cc
new file mode 100644
index 0000000000..8f0466bc03
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/graph/unused_node_remove_pass.cc
@@ -0,0 +1,67 @@
+/**
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tools/optimizer/graph/unused_node_remove_pass.h"
+#include 
+#include 
+#include "tools/optimizer/common/gllo_utils.h"
+
+namespace mindspore::opt {
+
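+// Mark-and-sweep over the graph: a BFS from the Return node marks every node reachable
+// through CNode inputs (recursing into sub-graphs held by visited nodes); any node left
+// unmarked is then dropped from the FuncGraph.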
+STATUS UnusedNodeRemovePass::ProcessGraph(const FuncGraphPtr &func_graph) {
+  MS_ASSERT(func_graph != nullptr);
+  auto return_node = func_graph->get_return();
+  if (return_node == nullptr) {
+    return RET_OK;
+  }
+  std::unordered_set<AnfNodePtr> vis;
+  std::deque<AnfNodePtr> q;
+  q.push_back(return_node);
+  while (!q.empty()) {
+    auto node = q.front();
+    vis.insert(node);
+    q.pop_front();
+    if (utils::isa<CNodePtr>(node)) {
+      auto cnode = utils::cast<CNodePtr>(node);
+      for (auto &input : cnode->inputs()) {
+        if (vis.find(input) == vis.end()) {
+          q.push_back(input);
+        }
+      }
+    }
+    if (utils::isa<FuncGraphPtr>(node)) {
+      auto sub_graph = utils::cast<FuncGraphPtr>(node);
+      auto status = ProcessGraph(sub_graph);
+      if (status != RET_OK) {
+        MS_LOG(ERROR) << "process sub graph failed";
+        return RET_ERROR;
+      }
+    }
+  }
+  auto nodes = func_graph->nodes();
+  for (auto &node : nodes) {
+    if (vis.find(node) == vis.end()) {
+      func_graph->DropNode(node);
+    }
+  }
+  return RET_OK;
+}
+
+bool UnusedNodeRemovePass::Run(const FuncGraphPtr &func_graph) {
+  MS_ASSERT(func_graph != nullptr);
+  auto status = ProcessGraph(func_graph);
+  return status == RET_OK;
+}
+}  // namespace mindspore::opt
diff --git a/mindspore/lite/tools/optimizer/graph/unused_node_remove_pass.h b/mindspore/lite/tools/optimizer/graph/unused_node_remove_pass.h
new file mode 100644
index 0000000000..26ba43dceb
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/graph/unused_node_remove_pass.h
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_UNUSED_NODE_REMOVE_PASS_H_
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_UNUSED_NODE_REMOVE_PASS_H_
+
+#include 
+#include "backend/optimizer/common/pass.h"
+#include "tools/converter/converter_flags.h"
+#include "mindspore/lite/include/errorcode.h"
+
+using mindspore::lite::STATUS;
+namespace mindspore::opt {
+class UnusedNodeRemovePass : public Pass {
+ public:
+  UnusedNodeRemovePass() : Pass("remove_unused_node_pass") {}
+  ~UnusedNodeRemovePass() override = default;
+  bool Run(const FuncGraphPtr &graph) override;
+
+ private:
+  STATUS ProcessGraph(const FuncGraphPtr &func_graph);
+};
+}  // namespace mindspore::opt
+
+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_UNUSED_NODE_REMOVE_PASS_H_