| @@ -1103,10 +1103,13 @@ schema::QuantType PrimitiveC::quant_type() const { return quant_type_; } | |||
| #endif | |||
| int PrimitiveC::Type() const { | |||
| if (this->primitive_ == nullptr) { | |||
| if (this->primitive_ == nullptr && this->op_type_ == OP_TYPE_NOT_SET) { | |||
| return schema::PrimitiveType_NONE; | |||
| } | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| if (op_type_ != OP_TYPE_NOT_SET) { | |||
| return op_type_; | |||
| } | |||
| return this->primitive_->value.type; | |||
| #else | |||
| return this->primitive_->value_type(); | |||
| @@ -24,6 +24,9 @@ | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| #include "ir/primitive.h" | |||
| #include "schema/inner/model_generated.h" | |||
| #include "schema/inner/ops_generated.h" | |||
| #include "schema/ops_generated.h" | |||
| #include "tools/converter/ops/ops_def.h" | |||
| #else | |||
| #include "schema/model_generated.h" | |||
| #endif | |||
| @@ -34,6 +37,7 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| constexpr const int OP_TYPE_NOT_SET = -1; | |||
| constexpr uint32_t kSingleNum = 1; | |||
| constexpr uint32_t kDoubleNum = 2; | |||
| constexpr uint32_t kMultiNum = 3; | |||
| @@ -149,6 +153,7 @@ class PrimitiveC : public mindspore::Primitive { | |||
| std::vector<std::vector<schema::QuantParamT>> output_quant_param_; | |||
| schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; | |||
| bool infer_flag_ = true; | |||
| int op_type_ = OP_TYPE_NOT_SET; | |||
| }; | |||
| std::shared_ptr<PrimitiveC> GetReturnPrim(); | |||
| @@ -227,6 +232,7 @@ class PrimitiveC { | |||
| char *primitive_buf_ = nullptr; | |||
| bool infer_flag_ = true; | |||
| schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; | |||
| int op_type_ = OP_TYPE_NOT_SET; | |||
| }; | |||
| using PrimitiveCPtr = std::shared_ptr<PrimitiveC>; | |||
| typedef PrimitiveC *(*PrimitiveCCreator)(const schema::Primitive *primitive); | |||
| @@ -2,6 +2,7 @@ set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../..) | |||
| set(TEST_DIR ${TOP_DIR}/mindspore/lite/test) | |||
| set(LITE_DIR ${TOP_DIR}/mindspore/lite) | |||
| set(CCSRC_DIR ${TOP_DIR}/mindspore/ccsrc) | |||
| set(CONVERTER_DIR ${TOP_DIR}/mindspore/lite/tools/converter) | |||
| include_directories(${TOP_DIR}) | |||
| include_directories(${TEST_DIR}) | |||
| include(${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake/external_libs/gtest.cmake) | |||
| @@ -16,7 +17,7 @@ set(CCSRC_SRC | |||
| ${CCSRC_DIR}/backend/optimizer/common/visit.cc | |||
| ${CCSRC_DIR}/backend/optimizer/common/optimizer.cc | |||
| ) | |||
| else(ENABLE_CONVERTER) | |||
| else() | |||
| set(TEST_LITE_SRC ${LITE_DIR}/src/common/log_adapter.cc) | |||
| add_compile_definitions(USE_ANDROID_LOG) | |||
| endif() | |||
| @@ -38,10 +39,10 @@ file(GLOB KERNEL_OP_TRAIN_SRC | |||
| ${LITE_DIR}/src/runtime/kernel/arm/fp32_grad/*.cc | |||
| ) | |||
| if (SUPPORT_TRAIN) | |||
| if(SUPPORT_TRAIN) | |||
| list(APPEND KERNEL_OP_SRC ${KERNEL_OP_TRAIN_SRC}) | |||
| endif() | |||
| if (PLATFORM_ARM64) | |||
| if(PLATFORM_ARM64) | |||
| # assembly | |||
| file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/assembly/arm64/*.s | |||
| ${LITE_DIR}/nnacl/assembly/arm64/*.S) | |||
| @@ -53,7 +54,7 @@ if (PLATFORM_ARM64) | |||
| ) | |||
| endif() | |||
| if (PLATFORM_ARM32) | |||
| if(PLATFORM_ARM32) | |||
| # assembly | |||
| file(GLOB TEST_ASSEMBLY_SRC | |||
| ${LITE_DIR}/nnacl/assembly/arm32/*.S | |||
| @@ -65,7 +66,7 @@ if (PLATFORM_ARM32) | |||
| ) | |||
| endif() | |||
| if ("${X86_64_SIMD}" STREQUAL "sse") | |||
| if("${X86_64_SIMD}" STREQUAL "sse") | |||
| file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/x86_64_sse/*.c) | |||
| set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C) | |||
| set(KERNEL_OP_SRC | |||
| @@ -74,12 +75,12 @@ if ("${X86_64_SIMD}" STREQUAL "sse") | |||
| ) | |||
| endif() | |||
| if ("${X86_64_SIMD}" STREQUAL "avx") | |||
| if("${X86_64_SIMD}" STREQUAL "avx") | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.1 -mavx -mavx2") | |||
| set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1 -mavx -mavx2") | |||
| file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/x86_64_sse/*.c | |||
| ${LITE_DIR}/nnacl/x86_64_avx/*.c | |||
| ${LITE_DIR}/nnacl/assembly/avx/*.S) | |||
| ${LITE_DIR}/nnacl/x86_64_avx/*.c | |||
| ${LITE_DIR}/nnacl/assembly/avx/*.S) | |||
| set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C) | |||
| set(KERNEL_OP_SRC | |||
| ${KERNEL_OP_SRC} | |||
| @@ -88,7 +89,7 @@ if ("${X86_64_SIMD}" STREQUAL "avx") | |||
| endif() | |||
| ### gpu kernel | |||
| if (SUPPORT_GPU) | |||
| if(SUPPORT_GPU) | |||
| file(GLOB GPU_KERNEL_OP_SRC | |||
| ${LITE_DIR}/src/runtime/kernel/opencl/kernel/*.cc | |||
| ) | |||
| @@ -102,14 +103,18 @@ if (SUPPORT_GPU) | |||
| ) | |||
| endif() | |||
| if (PLATFORM_ARM32 OR PLATFORM_ARM64) | |||
| if (ENABLE_CONVERTER) | |||
| if(PLATFORM_ARM32 OR PLATFORM_ARM64) | |||
| if(ENABLE_CONVERTER) | |||
| set(BUILD_MINDDATA "off") | |||
| endif() | |||
| endif() | |||
| ### runtime framework | |||
| add_definitions(-DENABLE_V0) | |||
| file(GLOB_RECURSE OPS_SRC ${LITE_DIR}/src/ops/*.cc) | |||
| if(ENABLE_CONVERTER) | |||
| file(GLOB_RECURSE CONVERTER_OPS_SRC ${CONVERTER_DIR}/ops/*.cc) | |||
| set(OPS_SRC ${OPS_SRC} ${CONVERTER_OPS_SRC}) | |||
| endif() | |||
| set(TEST_LITE_SRC | |||
| ${TEST_LITE_SRC} | |||
| ${CCSRC_SRC} | |||
| @@ -144,7 +149,7 @@ set(TEST_LITE_SRC | |||
| ${LITE_DIR}/src/errorcode.cc | |||
| ) | |||
| ### gpu runtime | |||
| if (SUPPORT_GPU) | |||
| if(SUPPORT_GPU) | |||
| include_directories(${TOP_DIR}/third_party/OpenCL-Headers) | |||
| include_directories(${TOP_DIR}/third_party/OpenCL-CLHPP/include) | |||
| set(OPENCL_RUNTIME_SRC | |||
| @@ -210,13 +215,14 @@ if(ENABLE_CONVERTER) | |||
| ${LITE_DIR}/tools/optimizer/graph/onnx_inputs_adjust_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/while_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/if_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/functionalize_control_op_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/functionalize_while.cc | |||
| ) | |||
| endif() | |||
| ### train | |||
| if (SUPPORT_TRAIN) | |||
| if(SUPPORT_TRAIN) | |||
| set(TEST_LITE_SRC | |||
| ${TEST_LITE_SRC} | |||
| # ${LITE_DIR}/src/train/ops/train_ops.cc | |||
| ${LITE_DIR}/src/train/train_populate_parameter.cc | |||
| ${LITE_DIR}/src/train/train_session.cc | |||
| ${LITE_DIR}/src/train/train_model.cc | |||
| @@ -251,7 +257,7 @@ set(TEST_SRC | |||
| ${TEST_DIR}/ut/src/scheduler_test.cc | |||
| ) | |||
| if (ENABLE_CONVERTER) | |||
| if(ENABLE_CONVERTER) | |||
| set(TEST_SRC | |||
| ${TEST_SRC} | |||
| ${TEST_DIR}/st/converter_test.cc | |||
| @@ -265,7 +271,7 @@ if (ENABLE_CONVERTER) | |||
| ) | |||
| endif() | |||
| if (SUPPORT_TRAIN) | |||
| if(SUPPORT_TRAIN) | |||
| set(TEST_SRC | |||
| ${TEST_SRC} | |||
| ${TEST_CASE_KERNEL_TRAIN_SRC} | |||
| @@ -278,7 +284,7 @@ else() | |||
| ) | |||
| endif() | |||
| if (SUPPORT_GPU) | |||
| if(SUPPORT_GPU) | |||
| file(GLOB_RECURSE TEST_CASE_KERNEL_GPU_SRC | |||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/*.cc | |||
| ) | |||
| @@ -288,7 +294,7 @@ if (SUPPORT_GPU) | |||
| ) | |||
| endif() | |||
| if (ENABLE_FP16) | |||
| if(ENABLE_FP16) | |||
| file(GLOB_RECURSE TEST_CASE_KERNEL_FP16_SRC | |||
| ${TEST_DIR}/ut/src/runtime/kernel/arm/fp16/*.cc | |||
| ) | |||
| @@ -296,24 +302,24 @@ if (ENABLE_FP16) | |||
| ${TEST_SRC} | |||
| ${TEST_CASE_KERNEL_FP16_SRC} | |||
| ) | |||
| endif () | |||
| endif() | |||
| add_executable(lite-test ${TEST_SRC}) | |||
| add_dependencies(lite-test fbs_src) | |||
| target_link_libraries(lite-test dl mindspore::gtest) | |||
| if (PLATFORM_ARM64 AND ENABLE_FP16) | |||
| if(PLATFORM_ARM64 AND ENABLE_FP16) | |||
| target_link_libraries(lite-test nnacl_fp16_mid nnacl_optimize_mid) | |||
| endif() | |||
| if (PLATFORM_ARM) | |||
| if(PLATFORM_ARM) | |||
| target_link_libraries(lite-test log) | |||
| endif() | |||
| if (SUPPORT_NPU) | |||
| if(SUPPORT_NPU) | |||
| include_directories(${DDK_PATH}) | |||
| target_link_libraries(lite-test npu_kernel_mid) | |||
| endif () | |||
| if (ENABLE_CONVERTER) | |||
| endif() | |||
| if(ENABLE_CONVERTER) | |||
| add_dependencies(lite-test fbs_inner_src) | |||
| target_link_libraries(lite-test | |||
| anf_importer_mid | |||
| @@ -12,6 +12,8 @@ include(${TOP_DIR}/cmake/external_libs/glog.cmake) | |||
| file(GLOB OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/*.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/populate/*.cc) | |||
| file(GLOB CONVERTER_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/ops/*.cc) | |||
| file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../flag/flag_parser.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/converter.cc | |||
| @@ -65,6 +67,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ../optimizer/graph/while_pass.cc | |||
| ../optimizer/graph/if_pass.cc | |||
| ../optimizer/graph/mindir_inputs_adjust_pass.cc | |||
| ../optimizer/graph/functionalize_control_op_pass.cc | |||
| ../optimizer/graph/functionalize_while.cc | |||
| ) | |||
| add_subdirectory(../anf_importer anf_importer) | |||
| @@ -97,12 +101,12 @@ set(LITE_SRC | |||
| ${SRC_DIR}/errorcode.cc | |||
| ${SRC_DIR}/dequant.cc | |||
| ) | |||
| if (SUPPORT_TRAIN) | |||
| if(SUPPORT_TRAIN) | |||
| set(LITE_SRC | |||
| ${LITE_SRC} | |||
| ) | |||
| endif () | |||
| endif() | |||
| set(ARM_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src/runtime/kernel/arm) | |||
| file(GLOB KERNEL_SRC | |||
| ${ARM_DIR}/base/*.cc | |||
| @@ -114,13 +118,13 @@ file(GLOB KERNEL_SRC | |||
| ${ARM_DIR}/int8/*.cc | |||
| ) | |||
| if (PLATFORM_ARM64) | |||
| if(PLATFORM_ARM64) | |||
| # assembly | |||
| file(GLOB ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../nnacl/assembly/arm64/*.s | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../nnacl/assembly/arm64/*.S) | |||
| set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) | |||
| set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC}) | |||
| endif () | |||
| endif() | |||
| file(GLOB PROTO_FILE "" | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/caffe/caffe.proto | |||
| @@ -133,11 +137,13 @@ add_library(proto_mid OBJECT ${PROTO_SRCS}) | |||
| set(TFLITE_FBS_FILES | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/tflite/schema.fbs | |||
| ) | |||
| ms_build_flatbuffers_lite(TFLITE_FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR}/parser/tflite/ tflite_fbs_src ${CMAKE_BINARY_DIR}/schema "inner") | |||
| ms_build_flatbuffers_lite(TFLITE_FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR}/parser/tflite/ tflite_fbs_src | |||
| ${CMAKE_BINARY_DIR}/schema "inner") | |||
| set_property(SOURCE ${CONVERTER_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) | |||
| set_property(SOURCE ${CCSRC_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) | |||
| set_property(SOURCE ${OPS_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) | |||
| set_property(SOURCE ${CONVERTER_OPS_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) | |||
| set_property(SOURCE ${KERNEL_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) | |||
| set_property(SOURCE ${LITE_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) | |||
| add_executable(converter_lite | |||
| @@ -145,6 +151,7 @@ add_executable(converter_lite | |||
| ${CCSRC_SRC} | |||
| ${CONVERTER_SRC} | |||
| ${OPS_SRC} | |||
| ${CONVERTER_OPS_SRC} | |||
| ${KERNEL_SRC} | |||
| ${LITE_SRC} | |||
| ) | |||
| @@ -48,6 +48,7 @@ | |||
| #include "tools/optimizer/graph/slice_prepose_pass.h" | |||
| #include "tools/optimizer/graph/while_pass.h" | |||
| #include "tools/optimizer/graph/if_pass.h" | |||
| #include "tools/optimizer/graph/functionalize_control_op_pass.h" | |||
| #include "tools/converter/quantizer/post_training_quantizer.h" | |||
| #include "tools/converter/quantizer/quant_cast.h" | |||
| #include "tools/converter/quantizer/weight_quantizer.h" | |||
| @@ -100,6 +101,15 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap | |||
| } | |||
| } | |||
| if (config->fmk == lite::converter::FmkType_TF) { | |||
| auto functionalize_control_op_pass = std::make_shared<opt::FunctionalizeControlOpPass>(); | |||
| if (!functionalize_control_op_pass->Run(old_graph)) { | |||
| MS_LOG(ERROR) << "functionalize control op pass failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return nullptr; | |||
| } | |||
| } | |||
| if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF || | |||
| config->fmk == lite::converter::FmkType_ONNX) { | |||
| graph_pm->AddPass(std::make_shared<opt::WhilePass>()); | |||
| @@ -145,7 +155,7 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap | |||
| if (config->fmk == lite::converter::FmkType_MS) { | |||
| auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>(); | |||
| if (remove_unused_cast_pass == nullptr) { | |||
| MS_LOG(ERROR) << "RemoveUnusedCastOpPass shoud be specified"; | |||
| MS_LOG(ERROR) << "RemoveUnusedCastOpPass should be specified"; | |||
| return nullptr; | |||
| } | |||
| remove_unused_cast_pass->SetFmkType(config->fmk); | |||
| @@ -154,7 +164,7 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap | |||
| if (config->fmk == lite::converter::FmkType_ONNX) { | |||
| auto remove_unused_transpose_pass = std::make_shared<opt::RemoveUnusedTransposeOpPass>(); | |||
| if (remove_unused_transpose_pass == nullptr) { | |||
| MS_LOG(ERROR) << "RemoveUnusedTransposeOpPass shoud be specified"; | |||
| MS_LOG(ERROR) << "RemoveUnusedTransposeOpPass should be specified"; | |||
| return nullptr; | |||
| } | |||
| remove_unused_transpose_pass->SetFmkType(config->fmk); | |||
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/ops/enter.h" | |||
| #include "src/tensorlist.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| int Enter::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| if (!infer_flag()) { | |||
| return RET_INFER_INVALID; | |||
| } | |||
| for (size_t i = 0; i < inputs_.size(); i++) { | |||
| auto *input = inputs_[i]; | |||
| auto *output = outputs_[i]; | |||
| if (input == nullptr) { | |||
| MS_LOG(ERROR) << "input tensor is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| if (output == nullptr) { | |||
| MS_LOG(ERROR) << "output tensor is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| output->set_data_type(input->data_type()); | |||
| output->set_shape(input->shape()); | |||
| output->set_format(input->format()); | |||
| auto data_type = input->data_type(); | |||
| if (data_type != kObjectTypeTensorType) { | |||
| continue; | |||
| } else { | |||
| auto input_tensorlist = reinterpret_cast<TensorList *>(input); | |||
| auto output_tensorlist = reinterpret_cast<TensorList *>(output); | |||
| output_tensorlist->set_element_shape(input_tensorlist->element_shape()); | |||
| output_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num()); | |||
| output_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type()); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_ENTER_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_ENTER_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class Enter : public PrimitiveC { | |||
| public: | |||
| Enter() { op_type_ = ConverterPrimitiveType_Enter; } | |||
| ~Enter() = default; | |||
| MS_DECLARE_PARENT(Enter, PrimitiveC); | |||
| explicit Enter(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_ENTER_H_ | |||
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/ops/exit.h" | |||
| #include "src/tensorlist.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| int Exit::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| if (!infer_flag()) { | |||
| return RET_INFER_INVALID; | |||
| } | |||
| for (size_t i = 0; i < inputs_.size(); i++) { | |||
| auto *input = inputs_[i]; | |||
| auto *output = outputs_[i]; | |||
| if (input == nullptr) { | |||
| MS_LOG(ERROR) << "input tensor is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| if (output == nullptr) { | |||
| MS_LOG(ERROR) << "output tensor is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| output->set_data_type(input->data_type()); | |||
| output->set_shape(input->shape()); | |||
| output->set_format(input->format()); | |||
| auto data_type = input->data_type(); | |||
| if (data_type != kObjectTypeTensorType) { | |||
| continue; | |||
| } else { | |||
| auto input_tensorlist = reinterpret_cast<TensorList *>(input); | |||
| auto output_tensorlist = reinterpret_cast<TensorList *>(output); | |||
| output_tensorlist->set_element_shape(input_tensorlist->element_shape()); | |||
| output_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num()); | |||
| output_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type()); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_EXIT_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_EXIT_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class Exit : public PrimitiveC { | |||
| public: | |||
| Exit() { op_type_ = ConverterPrimitiveType_Exit; } | |||
| ~Exit() = default; | |||
| MS_DECLARE_PARENT(Exit, PrimitiveC); | |||
| explicit Exit(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_EXIT_H_ | |||
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/ops/loop_cond.h" | |||
| #include "src/tensorlist.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| int LoopCond::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| if (!infer_flag()) { | |||
| return RET_INFER_INVALID; | |||
| } | |||
| for (size_t i = 0; i < inputs_.size(); i++) { | |||
| auto *input = inputs_[i]; | |||
| auto *output = outputs_[i]; | |||
| if (input == nullptr) { | |||
| MS_LOG(ERROR) << "input tensor is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| if (output == nullptr) { | |||
| MS_LOG(ERROR) << "output tensor is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| output->set_data_type(input->data_type()); | |||
| output->set_shape(input->shape()); | |||
| output->set_format(input->format()); | |||
| auto data_type = input->data_type(); | |||
| if (data_type != kObjectTypeTensorType) { | |||
| continue; | |||
| } else { | |||
| auto input_tensorlist = reinterpret_cast<TensorList *>(input); | |||
| auto output_tensorlist = reinterpret_cast<TensorList *>(output); | |||
| output_tensorlist->set_element_shape(input_tensorlist->element_shape()); | |||
| output_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num()); | |||
| output_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type()); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_LOOPCOND_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_LOOPCOND_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class LoopCond : public PrimitiveC { | |||
| public: | |||
| LoopCond() { op_type_ = ConverterPrimitiveType_LoopCond; } | |||
| ~LoopCond() = default; | |||
| MS_DECLARE_PARENT(LoopCond, PrimitiveC); | |||
| explicit LoopCond(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_LOOPCOND_H_ | |||
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/ops/next_iteration.h" | |||
| #include "src/tensorlist.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| int NextIteration::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| if (!infer_flag()) { | |||
| return RET_INFER_INVALID; | |||
| } | |||
| for (size_t i = 0; i < inputs_.size(); i++) { | |||
| auto *input = inputs_[i]; | |||
| auto *output = outputs_[i]; | |||
| if (input == nullptr) { | |||
| MS_LOG(ERROR) << "input tensor is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| if (output == nullptr) { | |||
| MS_LOG(ERROR) << "output tensor is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| output->set_data_type(input->data_type()); | |||
| output->set_shape(input->shape()); | |||
| output->set_format(input->format()); | |||
| auto data_type = input->data_type(); | |||
| if (data_type != kObjectTypeTensorType) { | |||
| continue; | |||
| } else { | |||
| auto input_tensorlist = reinterpret_cast<TensorList *>(input); | |||
| auto output_tensorlist = reinterpret_cast<TensorList *>(output); | |||
| output_tensorlist->set_element_shape(input_tensorlist->element_shape()); | |||
| output_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num()); | |||
| output_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type()); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_NEXTITERATION_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_NEXTITERATION_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class NextIteration : public PrimitiveC { | |||
| public: | |||
| NextIteration() { op_type_ = ConverterPrimitiveType_NextIteration; } | |||
| ~NextIteration() = default; | |||
| MS_DECLARE_PARENT(NextIteration, PrimitiveC); | |||
| explicit NextIteration(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_NEXTITERATION_H_ | |||
| @@ -0,0 +1,33 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_ | |||
| #define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_ | |||
| #include "schema/inner/model_generated.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| enum ConverterPrimitiveType { | |||
| ConverterPrimitiveType_Enter = schema::PrimitiveType_MAX + 1, | |||
| ConverterPrimitiveType_LoopCond, | |||
| ConverterPrimitiveType_NextIteration, | |||
| ConverterPrimitiveType_Exit, | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_NEXTITERATION_H_ | |||
| @@ -0,0 +1,50 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <string> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_enter_parser.h" | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "tools/converter/ops/enter.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFEnterParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(INFO) << "TF EnterParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| *primitiveC = new (std::nothrow) Enter(); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = tf_op.input_size(); | |||
| for (int i = 0; i < tf_op.input_size(); i++) { | |||
| inputs->emplace_back(tf_op.input(i)); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| TFNodeRegistrar g_tfEnterParser("Enter", new TFEnterParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ENTER_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ENTER_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFEnterParser : public TFNodeParser { | |||
| public: | |||
| TFEnterParser() = default; | |||
| ~TFEnterParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FILL_PARSER_H_ | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_exit_parser.h" | |||
| #include <string> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "tools/converter/ops/exit.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFExitParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(INFO) << "TF ExitParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| *primitiveC = new (std::nothrow) Exit(); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = tf_op.input_size(); | |||
| for (int i = 0; i < tf_op.input_size(); i++) { | |||
| inputs->emplace_back(tf_op.input(i)); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| TFNodeRegistrar g_tfExitParser("Exit", new TFExitParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_EXIT_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_EXIT_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFExitParser : public TFNodeParser { | |||
| public: | |||
| TFExitParser() = default; | |||
| ~TFExitParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FILL_PARSER_H_ | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_loop_cond_parser.h" | |||
| #include <string> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "tools/converter/ops/loop_cond.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFLoopCondParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(INFO) << "TF LoopCondParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| *primitiveC = new (std::nothrow) LoopCond(); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = tf_op.input_size(); | |||
| for (int i = 0; i < tf_op.input_size(); i++) { | |||
| inputs->emplace_back(tf_op.input(i)); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| TFNodeRegistrar g_tfLoopCondParser("LoopCond", new TFLoopCondParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LOOP_COND_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LOOP_COND_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFLoopCondParser : public TFNodeParser { | |||
| public: | |||
| TFLoopCondParser() = default; | |||
| ~TFLoopCondParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FILL_PARSER_H_ | |||
| @@ -0,0 +1,62 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_merge_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFMergeParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(INFO) << "TF MergeParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "primitive is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto attr = std::make_unique<schema::MergeT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Merge; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = tf_op.input_size(); | |||
| for (int i = 0; i < tf_op.input_size(); i++) { | |||
| inputs->emplace_back(tf_op.input(i)); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| TFNodeRegistrar g_tfMergeParser("Merge", new TFMergeParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MERGE_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MERGE_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFMergeParser : public TFNodeParser { | |||
| public: | |||
| TFMergeParser() = default; | |||
| ~TFMergeParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FILL_PARSER_H_ | |||
| @@ -277,7 +277,7 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::NodeDef &node_def, co | |||
| } | |||
| param_value->SetTensorData(tensor_data, shape_size * sizeof(int32_t)); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupport dataType: " << type; | |||
| MS_LOG(ERROR) << "Unsupported dataType: " << type; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -417,6 +417,16 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin | |||
| MS_LOG(ERROR) << "Convert ops failed."; | |||
| return nullptr; | |||
| } | |||
| if (!nodes_with_null_input_.empty()) { | |||
| status = ConnectNullInput(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Connect null inputs failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| } | |||
| status = ConvertRootGraphOutputs(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert graph outputs failed."; | |||
| @@ -474,16 +484,16 @@ STATUS TFModelParser::ConvertSubgraph() { | |||
| std::vector<ParameterPtr> sub_graph_inputs; | |||
| for (int j = 0; j < input_arg_size; j++) { | |||
| auto &input_arg = tf_sub_signature.input_arg(j); | |||
| auto paramter = sub_func_graph->add_parameter(); | |||
| paramter->set_name(input_arg.name()); | |||
| anf_sub_node_map[input_arg.name()] = paramter; | |||
| auto parameter = sub_func_graph->add_parameter(); | |||
| parameter->set_name(input_arg.name()); | |||
| anf_sub_node_map[input_arg.name()] = parameter; | |||
| auto root_inputs = cnode->inputs(); | |||
| if (op_type == schema::PrimitiveType_While) { | |||
| paramter->set_abstract(root_inputs[j + 1]->abstract()); | |||
| parameter->set_abstract(root_inputs[j + 1]->abstract()); | |||
| } else { | |||
| paramter->set_abstract(root_inputs[j + 2]->abstract()); | |||
| parameter->set_abstract(root_inputs[j + 2]->abstract()); | |||
| } | |||
| sub_graph_inputs.emplace_back(paramter); | |||
| sub_graph_inputs.emplace_back(parameter); | |||
| } | |||
| std::map<std::string, const tensorflow::NodeDef *> tf_sub_node_map; | |||
| for (int j = 0; j < tf_sub_fuction.node_def_size(); j++) { | |||
| @@ -643,7 +653,8 @@ STATUS TFModelParser::ConvertInputNodes(const tensorflow::NodeDef &node_def, | |||
| const std::vector<std::string> &input_names, | |||
| const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map, | |||
| const std::unordered_map<std::string, AnfNodePtr> &anf_node_map, | |||
| std::vector<AnfNodePtr> *inputs) { | |||
| std::vector<AnfNodePtr> *inputs, | |||
| std::vector<std::string> *input_name_not_found) { | |||
| MS_ASSERT(node_def != nullptr); | |||
| // parse inputs | |||
| for (size_t j = 0; j < input_names.size(); j++) { | |||
| @@ -656,8 +667,8 @@ STATUS TFModelParser::ConvertInputNodes(const tensorflow::NodeDef &node_def, | |||
| } | |||
| auto input = GetAnfNode(flatten_input_name, anf_node_map); | |||
| if (input == nullptr) { | |||
| MS_LOG(ERROR) << node_def.name() << " input " << j << ": " << input_name << " can't find parsed in_nodes"; | |||
| return RET_ERROR; | |||
| MS_LOG(WARNING) << node_def.name() << " input " << j << ": " << input_name << " can't find parsed in_nodes"; | |||
| (*input_name_not_found).push_back(flatten_input_name); | |||
| } | |||
| inputs->emplace_back(input); | |||
| } | |||
| @@ -718,6 +729,27 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C | |||
| return RET_OK; | |||
| } | |||
| STATUS TFModelParser::RecordNullInput(const CNodePtr &node, const std::vector<std::string> &input_name_not_found) { | |||
| nodes_with_null_input_.emplace_back(node, input_name_not_found); | |||
| return RET_OK; | |||
| } | |||
| STATUS TFModelParser::ConnectNullInput() { | |||
| for (auto &it : nodes_with_null_input_) { | |||
| auto &cnode = it.first; | |||
| auto &input_name_not_found = it.second; | |||
| auto &inputs = cnode->inputs(); | |||
| int i = 0; | |||
| for (size_t j = 0; j < inputs.size(); ++j) { | |||
| if (inputs[j] == nullptr) { | |||
| cnode->set_input(j, GetAnfNode(input_name_not_found[i], anf_root_node_map_)); | |||
| ++i; | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, | |||
| const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map, | |||
| const FuncGraphPtr &func_graph_ptr, | |||
| @@ -752,7 +784,8 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<AnfNodePtr> inputs = {value_node}; | |||
| status = ConvertInputNodes(node_def, input_names, tf_node_map, *anf_node_map, &inputs); | |||
| std::vector<std::string> input_name_not_found{}; | |||
| status = ConvertInputNodes(node_def, input_names, tf_node_map, *anf_node_map, &inputs, &input_name_not_found); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| @@ -787,6 +820,10 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, | |||
| } | |||
| } | |||
| if (!input_name_not_found.empty()) { | |||
| RecordNullInput(anf_node, input_name_not_found); | |||
| } | |||
| status = ConvertOutputTensor(node_def, anf_node, anf_node_map, func_graph_ptr, output_size); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed."; | |||
| @@ -22,6 +22,7 @@ | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include "proto/graph.pb.h" | |||
| #include "proto/node_def.pb.h" | |||
| #include "schema/inner/model_generated.h" | |||
| @@ -55,7 +56,7 @@ class TFModelParser : public ModelParser { | |||
| STATUS ConvertInputNodes(const tensorflow::NodeDef &node_def, const std::vector<std::string> &input_names, | |||
| const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map, | |||
| const std::unordered_map<std::string, AnfNodePtr> &anf_node_map, | |||
| std::vector<AnfNodePtr> *inputs); | |||
| std::vector<AnfNodePtr> *inputs, std::vector<std::string> *input_name_not_found); | |||
| STATUS ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node, | |||
| std::unordered_map<std::string, AnfNodePtr> *anf_node_map, const FuncGraphPtr &anf_graph, | |||
| int output_size); | |||
| @@ -71,6 +72,10 @@ class TFModelParser : public ModelParser { | |||
| STATUS MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes, const FuncGraphPtr &anf_graph); | |||
| STATUS RecordNullInput(const CNodePtr &node, const std::vector<std::string> &input_name_not_found); | |||
| STATUS ConnectNullInput(); | |||
| FuncGraphPtr anf_root_graph_; | |||
| std::unique_ptr<tensorflow::GraphDef> tf_root_graph_; // tf root graph def | |||
| std::map<std::string, const tensorflow::NodeDef *> tf_root_graph_nodes_; // tf root graph node map | |||
| @@ -79,6 +84,7 @@ class TFModelParser : public ModelParser { | |||
| std::vector<std::string> graph_output_names_; | |||
| std::map<std::string, AnfNodePtr> function_while_map_; // tf function name->while_node_name | |||
| std::map<std::string, AnfNodePtr> function_if_map_; // tf function name->if_node | |||
| std::vector<std::pair<CNodePtr, std::vector<std::string>>> nodes_with_null_input_{}; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_next_iteration_parser.h" | |||
| #include <string> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "tools/converter/ops/next_iteration.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFNextIterationParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(INFO) << "TF NextIterationParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| *primitiveC = new (std::nothrow) NextIteration(); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = tf_op.input_size(); | |||
| for (int i = 0; i < tf_op.input_size(); i++) { | |||
| inputs->emplace_back(tf_op.input(i)); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| TFNodeRegistrar g_tfNextIterationParser("NextIteration", new TFNextIterationParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_NEXT_ITERATION_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_NEXT_ITERATION_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFNextIterationParser : public TFNodeParser { | |||
| public: | |||
| TFNextIterationParser() = default; | |||
| ~TFNextIterationParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FILL_PARSER_H_ | |||
| @@ -0,0 +1,62 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_switch_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFSwitchParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(INFO) << "TF SwitchParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "primitive is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto attr = std::make_unique<schema::SwitchT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Switch; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = tf_op.input_size(); | |||
| for (int i = 0; i < tf_op.input_size(); i++) { | |||
| inputs->emplace_back(tf_op.input(i)); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| TFNodeRegistrar g_tfSwitchParser("Switch", new TFSwitchParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SWITCH_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SWITCH_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFSwitchParser : public TFNodeParser { | |||
| public: | |||
| TFSwitchParser() = default; | |||
| ~TFSwitchParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FILL_PARSER_H_ | |||
| @@ -0,0 +1,140 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| *conv_activation_fusion.h | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <deque> | |||
| #include "tools/optimizer/graph/functionalize_control_op_pass.h" | |||
| #include "tools/optimizer/graph/functionalize_while.h" | |||
| #include "mindspore/lite/include/errorcode.h" | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore::opt { | |||
| FuncGraphPtr FunctionalizeControlOpPass::NewFuncGraph(const std::string &subgraph_name, const FmkType &fmk_type) { | |||
| auto fg = std::make_shared<FuncGraph>(); | |||
| if (fg == nullptr) { | |||
| MS_LOG(ERROR) << "new func)graph failed."; | |||
| return nullptr; | |||
| } | |||
| fg->set_attr("graph_name", MakeValue(subgraph_name)); | |||
| fg->set_attr("fmk", MakeValue(static_cast<int>(fmk_type))); | |||
| return fg; | |||
| } | |||
| std::string FunctionalizeControlOpPass::NodeClusterName(const AnfNodePtr &node) { | |||
| std::string cluster_name{}; | |||
| // tf node name use '/' split node name | |||
| auto cnode = utils::cast<CNodePtr>(node); | |||
| size_t pos = cnode->fullname_with_scope().rfind('/'); | |||
| if (pos != std::string::npos) { | |||
| cluster_name = cnode->fullname_with_scope().substr(0, pos); | |||
| } else { | |||
| cluster_name = cnode->fullname_with_scope(); | |||
| } | |||
| return cluster_name; | |||
| } | |||
| void FunctionalizeControlOpPass::InitNodeClusters(const FuncGraphPtr &func_graph) { | |||
| for (auto &node : func_graph->nodes()) { | |||
| auto cluster_name = NodeClusterName(node); | |||
| auto cluster_pos = WhichCluster(cluster_name); | |||
| if (cluster_pos == node_clusters_.size()) { | |||
| std::vector<AnfNodePtr> node_list{node}; | |||
| node_clusters_.emplace_back(std::make_pair(cluster_name, node_list)); | |||
| } else { | |||
| node_clusters_[cluster_pos].second.push_back(node); | |||
| } | |||
| } | |||
| } | |||
| size_t FunctionalizeControlOpPass::WhichCluster(const std::string &cluster_name) { | |||
| size_t pos = node_clusters_.size(); | |||
| for (size_t i = 0; i < pos; ++i) { | |||
| if (node_clusters_[i].first == cluster_name) { | |||
| return i; | |||
| } | |||
| } | |||
| return pos; | |||
| } | |||
| STATUS FunctionalizeControlOpPass::BuildWhileSubgraph(const FuncGraphPtr &func_graph) { | |||
| int ret = RET_OK; | |||
| for (auto &node_cluster : node_clusters_) { | |||
| for (auto &node : node_cluster.second) { | |||
| if (IsLoopCond(node)) { | |||
| loop_cond_nodes_.push_back(node->cast<CNodePtr>()); | |||
| FunctionalizeWhile fw(node_cluster.second, node->cast<CNodePtr>(), func_graph); | |||
| ret = fw.Process(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "run functionalize while failed, ret: " << ret; | |||
| return ret; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return ret; | |||
| } | |||
| bool FunctionalizeControlOpPass::Run(const FuncGraphPtr &func_graph) { | |||
| // use name to find the frame | |||
| InitNodeClusters(func_graph); | |||
| if (BuildWhileSubgraph(func_graph) != RET_OK) { | |||
| MS_LOG(ERROR) << "build while subgraph failed."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| CNodePtr FunctionalizeControlOpPass::BelongToWhichNode(const CNodePtr &node, const FilterFunc &func) { | |||
| if (node == nullptr) { | |||
| return nullptr; | |||
| } | |||
| if (func(node)) { | |||
| return node; | |||
| } | |||
| CNodePtr aim_node = nullptr; | |||
| std::deque<AnfNodePtr> todo(256); | |||
| todo.clear(); | |||
| for (auto &input_node : node->inputs()) { | |||
| if (func(input_node)) { | |||
| aim_node = utils::cast<CNodePtr>(input_node); | |||
| todo.clear(); | |||
| break; | |||
| } | |||
| todo.push_back(input_node); | |||
| } | |||
| while (!todo.empty()) { | |||
| AnfNodePtr todo_node = todo.front(); | |||
| todo.pop_front(); | |||
| if (func(todo_node)) { | |||
| aim_node = utils::cast<CNodePtr>(todo_node); | |||
| todo.clear(); | |||
| break; | |||
| } | |||
| if (utils::isa<CNodePtr>(todo_node)) { | |||
| auto cnode = utils::cast<CNodePtr>(todo_node); | |||
| for (size_t i = 0; i < cnode->inputs().size(); i++) { | |||
| todo.push_back(cnode->input(i)); | |||
| } | |||
| } | |||
| } | |||
| if (aim_node == nullptr) { | |||
| MS_LOG(WARNING) << "not found belonging enter node."; | |||
| return nullptr; | |||
| } | |||
| return aim_node; | |||
| } | |||
| } // namespace mindspore::opt | |||
| @@ -0,0 +1,71 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| *conv_activation_fusion.h | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_PASS_FUNCTIONALIZE_CONTROL_OP_PASS_H_ | |||
| #define MINDSPORE_LITE_SRC_PASS_FUNCTIONALIZE_CONTROL_OP_PASS_H_ | |||
| #include <string> | |||
| #include <set> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "backend/optimizer/common/pass.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| using mindspore::lite::converter::FmkType; | |||
| namespace mindspore::opt { | |||
| class FunctionalizeControlOpPass : public Pass { | |||
| public: | |||
| FunctionalizeControlOpPass() : Pass("functionalize_control_op_pass") {} | |||
| ~FunctionalizeControlOpPass() override = default; | |||
| bool Run(const FuncGraphPtr &graph) override; | |||
| static FuncGraphPtr NewFuncGraph(const std::string &subgraph_name, const FmkType &fmk_type); | |||
| static bool IsMerge(const AnfNodePtr &node) { return opt::GetCNodeType(node) == schema::PrimitiveType_Merge; } | |||
| static bool IsLoopCond(const AnfNodePtr &node) { | |||
| return static_cast<int>(opt::GetCNodeType(node)) == static_cast<int>(lite::ConverterPrimitiveType_LoopCond); | |||
| } | |||
| static bool IsEnter(const AnfNodePtr &node) { | |||
| return static_cast<int>(opt::GetCNodeType(node)) == static_cast<int>(lite::ConverterPrimitiveType_Enter); | |||
| } | |||
| static bool IsExit(const AnfNodePtr &node) { | |||
| return static_cast<int>(opt::GetCNodeType(node)) == static_cast<int>(lite::ConverterPrimitiveType_Exit); | |||
| } | |||
| static bool IsSwitch(const AnfNodePtr &node) { return opt::GetCNodeType(node) == schema::PrimitiveType_Switch; } | |||
| static bool IsNextIteration(const AnfNodePtr &node) { | |||
| return static_cast<int>(opt::GetCNodeType(node)) == static_cast<int>(lite::ConverterPrimitiveType_NextIteration); | |||
| } | |||
| static bool IsControlFlowOp(const AnfNodePtr &node) { | |||
| return IsLoopCond(node) || IsEnter(node) || IsMerge(node) || IsSwitch(node) || IsExit(node) || | |||
| IsNextIteration(node); | |||
| } | |||
| static CNodePtr BelongToWhichNode(const CNodePtr &node, const FilterFunc &func); | |||
| static int GetSubgraphIndex() { | |||
| static int subgraph_index = 1; | |||
| return subgraph_index++; | |||
| } | |||
| // The names of nodes with the same prefix are a cluster. | |||
| static std::string NodeClusterName(const AnfNodePtr &node); | |||
| void InitNodeClusters(const FuncGraphPtr &func_graph); | |||
| // return the position in node_clusters_ | |||
| size_t WhichCluster(const std::string &cluster_name); | |||
| protected: | |||
| STATUS BuildWhileSubgraph(const FuncGraphPtr &func_graph); | |||
| std::vector<std::pair<std::string, std::vector<AnfNodePtr>>> node_clusters_{}; | |||
| std::vector<CNodePtr> loop_cond_nodes_{}; | |||
| }; | |||
| } // namespace mindspore::opt | |||
| #endif // MINDSPORE_LITE_SRC_PASS_FUNCTIONALIZE_CONTROL_OP_PASS_H_ | |||
| @@ -0,0 +1,521 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| *conv_activation_fusion.h | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <deque> | |||
| #include "tools/optimizer/graph/functionalize_while.h" | |||
| #include "mindspore/lite/include/errorcode.h" | |||
| #include "src/ops/primitive_c.h" | |||
| #include "src/ops/while.h" | |||
| namespace { | |||
| mindspore::ValueNodePtr GetWhileAnfPrim() { | |||
| auto while_primitiveT = new (std::nothrow) mindspore::schema::PrimitiveT; | |||
| if (while_primitiveT == nullptr) { | |||
| MS_LOG(ERROR) << "new while_primitiveT failed"; | |||
| return nullptr; | |||
| } | |||
| while_primitiveT->value.type = mindspore::schema::PrimitiveType_While; | |||
| auto whileT = new (std::nothrow) mindspore::schema::WhileT; | |||
| whileT->condSubgraphIndex = mindspore::opt::FunctionalizeControlOpPass::GetSubgraphIndex(); | |||
| whileT->bodySubgraphIndex = mindspore::opt::FunctionalizeControlOpPass::GetSubgraphIndex(); | |||
| while_primitiveT->value.value = whileT; | |||
| if (while_primitiveT->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "new WhileT failed"; | |||
| delete (while_primitiveT); | |||
| return nullptr; | |||
| } | |||
| auto while_prim = std::make_shared<mindspore::lite::While>(while_primitiveT); | |||
| mindspore::ValueNodePtr partial_anf_prim = NewValueNode(while_prim); | |||
| return partial_anf_prim; | |||
| } | |||
| } // namespace | |||
| namespace mindspore::opt { | |||
| using mindspore::lite::RET_NULL_PTR; | |||
| CNodePtr FunctionalizeWhile::BlongToWhichSwitch(const CNodePtr &node) { | |||
| return FunctionalizeControlOpPass::BelongToWhichNode(node, FunctionalizeControlOpPass::IsSwitch); | |||
| } | |||
| CNodePtr FunctionalizeWhile::BlongToWhichMerge(const CNodePtr &node) { | |||
| return FunctionalizeControlOpPass::BelongToWhichNode(node, FunctionalizeControlOpPass::IsMerge); | |||
| } | |||
| CNodePtr FunctionalizeWhile::BlongToWhichEnter(const CNodePtr &node) { | |||
| return FunctionalizeControlOpPass::BelongToWhichNode(node, FunctionalizeControlOpPass::IsEnter); | |||
| } | |||
| int FunctionalizeWhile::PosInInputEnterNodes(const CNodePtr &node) { | |||
| auto index = std::find(input_enter_nodes_.begin(), input_enter_nodes_.end(), node); | |||
| if (index == input_enter_nodes_.end()) { | |||
| MS_LOG(WARNING) << node->fullname_with_scope() << " is not in input_enter_nodes_"; | |||
| return -1; | |||
| } | |||
| return index - input_enter_nodes_.begin(); | |||
| } | |||
| STATUS FunctionalizeWhile::NewWhileNode() { | |||
| ValueNodePtr while_anf_primitive = GetWhileAnfPrim(); | |||
| if (while_anf_primitive == nullptr) { | |||
| MS_LOG(ERROR) << "Get while anf primitive failed."; | |||
| return RET_NULL_PTR; | |||
| } | |||
| static int count = 0; | |||
| std::vector<AnfNodePtr> while_op_inputs = {while_anf_primitive}; | |||
| while_node_ = fg_->NewCNode(while_op_inputs); | |||
| while_node_->set_fullname_with_scope(loop_cond_node_->fullname_with_scope() + "-while-" + std::to_string(count++)); | |||
| return RET_OK; | |||
| } | |||
| STATUS FunctionalizeWhile::IdentifyWhileNodeInput() { | |||
| for (auto &node : node_cluster_) { | |||
| if (FunctionalizeControlOpPass::IsEnter(node)) { | |||
| auto enter_cnode = node->cast<CNodePtr>(); | |||
| input_enter_nodes_.push_back(enter_cnode); | |||
| while_node_->add_input(enter_cnode->input(1)); | |||
| } | |||
| } | |||
| if (input_enter_nodes_.empty()) { | |||
| MS_LOG(ERROR) << "not found input of while node."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS FunctionalizeWhile::IdentifyWhileNodeOutput() { | |||
| output_exit_nodes_.resize(input_enter_nodes_.size()); | |||
| for (auto &node : node_cluster_) { | |||
| // exit ->switch->merge->enter | |||
| if (FunctionalizeControlOpPass::IsExit(node)) { | |||
| auto exit_node = node->cast<CNodePtr>(); | |||
| auto switch_node = BlongToWhichSwitch(exit_node); | |||
| auto merge_node = BlongToWhichMerge(switch_node); | |||
| auto enter_node = BlongToWhichEnter(merge_node); | |||
| int pos = PosInInputEnterNodes(enter_node); | |||
| if (pos == -1) { | |||
| MS_LOG(ERROR) << "not find in input enter nodes."; | |||
| return RET_ERROR; | |||
| } | |||
| output_exit_nodes_.at(pos) = exit_node; | |||
| } | |||
| } | |||
| if (output_exit_nodes_.size() == 1) { | |||
| while_node_->set_abstract(output_exit_nodes_[0]->abstract()); | |||
| } else { | |||
| AbstractBasePtrList abstract_list; | |||
| abstract_list.resize(output_exit_nodes_.size()); | |||
| std::transform(output_exit_nodes_.begin(), output_exit_nodes_.end(), abstract_list.begin(), | |||
| [](const CNodePtr &cnode) { return cnode->abstract(); }); | |||
| while_node_->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS FunctionalizeWhile::UpdateExitNodeUser() { | |||
| if (output_exit_nodes_.size() == 1) { | |||
| auto manager = fg_->manager(); | |||
| auto node_users = manager->node_users()[output_exit_nodes_[0]]; | |||
| for (auto &node_user : node_users) { | |||
| if (fg_->nodes().contains(node_user.first)) { | |||
| manager->SetEdge(node_user.first, node_user.second, while_node_); | |||
| } | |||
| } | |||
| } else { | |||
| for (auto &node : output_exit_nodes_) { | |||
| auto manager = fg_->manager(); | |||
| auto node_users = manager->node_users()[node]; | |||
| for (auto &node_user : node_users) { | |||
| // new getitem | |||
| AbstractBasePtrList abstractList; | |||
| std::vector<int64_t> shape_vector; | |||
| abstractList.emplace_back(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector)); | |||
| auto tuple_get_item_prim_ptr = lite::GetTupleGetItemPrim(); | |||
| if (tuple_get_item_prim_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr); | |||
| const auto &exit_node = node; | |||
| auto switch_node = BlongToWhichSwitch(exit_node); | |||
| auto merge_node = BlongToWhichMerge(switch_node); | |||
| auto enter_node = BlongToWhichEnter(merge_node); | |||
| int output_idx = PosInInputEnterNodes(enter_node); | |||
| auto getItemValue = NewValueNode(MakeValue<int>(output_idx)); | |||
| std::vector<AnfNodePtr> inputs{tuple_get_item_prim, while_node_, getItemValue}; | |||
| CNodePtr get_item_node = fg_->NewCNode(inputs); | |||
| std::string output_item_name = while_node_->fullname_with_scope() + "_getitem_" + std::to_string(output_idx); | |||
| auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector); | |||
| if (abstract == nullptr) { | |||
| MS_LOG(ERROR) << "create AbstractTensor failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| get_item_node->set_abstract(abstract); | |||
| get_item_node->set_fullname_with_scope(output_item_name); | |||
| // set | |||
| if (fg_->nodes().contains(node_user.first)) { | |||
| manager->SetEdge(node_user.first, node_user.second, get_item_node); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS FunctionalizeWhile::BuildWhileNode() { | |||
| int ret = NewWhileNode(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "new while node failed, ret:" << ret; | |||
| return ret; | |||
| } | |||
| ret = IdentifyWhileNodeInput(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "identify while node input failed, ret:" << ret; | |||
| return ret; | |||
| } | |||
| ret = IdentifyWhileNodeOutput(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "identify while node output failed, ret:" << ret; | |||
| return ret; | |||
| } | |||
| // update exit node user from exit to while | |||
| ret = UpdateExitNodeUser(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "update while node users, ret:" << ret; | |||
| return ret; | |||
| } | |||
| return ret; | |||
| } | |||
| // nodes between loop_cond op and merge op be added into cond_func_graph | |||
| STATUS FunctionalizeWhile::CondSubgraphAddNodes() { | |||
| std::deque<AnfNodePtr> todo(512); | |||
| todo.clear(); | |||
| for (size_t i = 1; i < loop_cond_node_->inputs().size(); i++) { | |||
| todo.push_back(loop_cond_node_->input(i)); | |||
| } | |||
| while (!todo.empty()) { | |||
| AnfNodePtr node = todo.back(); | |||
| todo.pop_back(); | |||
| if (FunctionalizeControlOpPass::IsMerge(node)) { | |||
| continue; | |||
| } | |||
| if (utils::isa<ParameterPtr>(node)) { | |||
| cond_sub_func_graph_->add_parameter(node->cast<ParameterPtr>()); | |||
| } else { | |||
| cond_sub_func_graph_->AddNode(node); | |||
| } | |||
| node->set_func_graph(cond_sub_func_graph_); | |||
| if (utils::isa<CNodePtr>(node)) { | |||
| auto cnode = utils::cast<CNodePtr>(node); | |||
| for (size_t i = 1; i < cnode->inputs().size(); i++) { | |||
| todo.push_back(cnode->input(i)); | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS FunctionalizeWhile::IdentifyCondSubgraphInput() { | |||
| std::vector<AnfNodePtr> nodes_need_drop{}; | |||
| for (auto &cnode : cond_sub_func_graph_->GetOrderedCnodes()) { | |||
| for (auto &input_node : cnode->inputs()) { | |||
| if (FunctionalizeControlOpPass::IsMerge(input_node)) { | |||
| auto merge_node = input_node->cast<CNodePtr>(); | |||
| auto enter_node = BlongToWhichEnter(merge_node); | |||
| int pos = PosInInputEnterNodes(enter_node); | |||
| nodes_need_drop.push_back(cnode); | |||
| // set parameter | |||
| auto parameter = cond_sub_func_graph_->add_parameter(); | |||
| parameter->set_abstract(cnode->abstract()); | |||
| // hardcode for subgraph input name | |||
| parameter->set_name(cond_subgraph_name_ + "_input_" + std::to_string(pos) + "_parameter"); | |||
| // replace merge | |||
| auto manager = fg_->manager(); | |||
| auto node_users = manager->node_users()[cnode]; | |||
| for (auto &node_user : node_users) { | |||
| if (cond_sub_func_graph_->nodes().contains(node_user.first)) { | |||
| manager->SetEdge(node_user.first, node_user.second, parameter); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| // drop node from cond_func_graph | |||
| for (const auto &node : nodes_need_drop) { | |||
| cond_sub_func_graph_->DropNode(node); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS FunctionalizeWhile::IdentifyCondSubgraphOutput() { | |||
| auto return_prim_ptr = lite::GetReturnPrim(); | |||
| if (return_prim_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "GetReturnPrim return nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto value_node = NewValueNode(return_prim_ptr); | |||
| if (value_node == nullptr) { | |||
| MS_LOG(ERROR) << "new value_node failed."; | |||
| return RET_NULL_PTR; | |||
| } | |||
| // cond subgraph output is LoopCond's input | |||
| std::vector<AnfNodePtr> op_inputs{value_node, loop_cond_node_->input(1)}; | |||
| auto return_cnode = cond_sub_func_graph_->NewCNode(op_inputs); | |||
| return_cnode->set_fullname_with_scope(cond_subgraph_name_ + "-return"); | |||
| cond_sub_func_graph_->set_return(return_cnode); | |||
| // hardcode subgraph outputs name | |||
| cond_sub_func_graph_->output()->cast<CNodePtr>()->set_fullname_with_scope(cond_subgraph_name_ + "_output_0_cnode"); | |||
| return RET_OK; | |||
| } | |||
| STATUS FunctionalizeWhile::BuildCondGraph() { | |||
| cond_subgraph_name_ = FunctionalizeControlOpPass::NodeClusterName(loop_cond_node_) + "_cond"; | |||
| cond_sub_func_graph_ = | |||
| FunctionalizeControlOpPass::NewFuncGraph(cond_subgraph_name_, mindspore::lite::converter::FmkType_TF); | |||
| if (cond_sub_func_graph_ == nullptr) { | |||
| MS_LOG(ERROR) << "new cond_sub_func_graph_ return nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| cond_sub_func_graph_->set_manager(fg_->manager()); | |||
| int ret = CondSubgraphAddNodes(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "add cond_subgraph node failed, ret:" << ret; | |||
| return ret; | |||
| } | |||
| ret = IdentifyCondSubgraphOutput(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "identify cond_subgraph output failed, ret:" << ret; | |||
| return ret; | |||
| } | |||
| ret = IdentifyCondSubgraphInput(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "identify cond_subgraph input failed, ret:" << ret; | |||
| return ret; | |||
| } | |||
| return ret; | |||
| } | |||
| // nodes between next_iteration op and switch op will be added into body_func_graph | |||
| STATUS FunctionalizeWhile::BodySubgraphAddNodes() { | |||
| std::deque<AnfNodePtr> todo(512); | |||
| todo.clear(); | |||
| for (auto &node : node_cluster_) { | |||
| if (FunctionalizeControlOpPass::IsNextIteration(node)) { | |||
| auto next_iteration_cnode = node->cast<CNodePtr>(); | |||
| for (size_t i = 1; i < next_iteration_cnode->inputs().size(); i++) { | |||
| todo.push_back(next_iteration_cnode->input(i)); | |||
| } | |||
| body_subgraph_output_map_[node] = next_iteration_cnode->input(1); | |||
| } | |||
| } | |||
| while (!todo.empty()) { | |||
| AnfNodePtr node = todo.back(); | |||
| todo.pop_back(); | |||
| if (FunctionalizeControlOpPass::IsSwitch(node)) { | |||
| continue; | |||
| } | |||
| if (utils::isa<ParameterPtr>(node)) { | |||
| body_sub_func_graph_->add_parameter(node->cast<ParameterPtr>()); | |||
| } else { | |||
| body_sub_func_graph_->AddNode(node); | |||
| } | |||
| node->set_func_graph(body_sub_func_graph_); | |||
| if (utils::isa<CNodePtr>(node)) { | |||
| auto cnode = utils::cast<CNodePtr>(node); | |||
| for (size_t i = 1; i < cnode->inputs().size(); i++) { | |||
| todo.push_back(cnode->input(i)); | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS FunctionalizeWhile::IdentifyBodySubgraphInput() { | |||
| std::vector<AnfNodePtr> nodes_need_drop{}; | |||
| for (auto &cnode : body_sub_func_graph_->GetOrderedCnodes()) { | |||
| for (auto &input_node : cnode->inputs()) { | |||
| if (FunctionalizeControlOpPass::IsSwitch(input_node)) { | |||
| auto switch_node = input_node->cast<CNodePtr>(); | |||
| auto merge_node = BlongToWhichMerge(switch_node); | |||
| auto enter_node = BlongToWhichEnter(merge_node); | |||
| int pos = PosInInputEnterNodes(enter_node); | |||
| nodes_need_drop.push_back(cnode); | |||
| // set parameter | |||
| auto parameter = body_sub_func_graph_->add_parameter(); | |||
| parameter->set_abstract(cnode->abstract()); | |||
| // hardcode for subgraph input name | |||
| parameter->set_name(body_subgraph_name_ + "_input_" + std::to_string(pos) + "_parameter"); | |||
| // replace switch | |||
| auto manager = fg_->manager(); | |||
| auto node_users = manager->node_users()[cnode]; | |||
| for (auto &node_user : node_users) { | |||
| if (body_sub_func_graph_->nodes().contains(node_user.first)) { | |||
| manager->SetEdge(node_user.first, node_user.second, parameter); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| // drop node from cond_func_graph | |||
| for (const auto &node : nodes_need_drop) { | |||
| body_sub_func_graph_->DropNode(node); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS FunctionalizeWhile::IdentifyBodySubgraphOutput() { | |||
| std::vector<AnfNodePtr> tmp_output{}; | |||
| tmp_output.resize(input_enter_nodes_.size()); | |||
| // next_iteration -> switch -> merge -> enter | |||
| for (auto &node_pair : body_subgraph_output_map_) { | |||
| auto next_iteration_cnode = utils::cast<CNodePtr>(node_pair.first); | |||
| auto switch_node = BlongToWhichSwitch(next_iteration_cnode); | |||
| auto merge_node = BlongToWhichMerge(switch_node); | |||
| auto enter_node = BlongToWhichEnter(merge_node); | |||
| int pos = PosInInputEnterNodes(enter_node); | |||
| tmp_output[pos] = node_pair.second; | |||
| // hard code. set cnode output name | |||
| node_pair.second->cast<CNodePtr>()->set_fullname_with_scope(body_subgraph_name_ + "_output_" + std::to_string(pos) + | |||
| "_cnode"); | |||
| } | |||
| auto return_prim_ptr = lite::GetReturnPrim(); | |||
| if (return_prim_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "GetReturnPrim return nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto value_node = NewValueNode(return_prim_ptr); | |||
| // cond subgraph output is LoopCond's input | |||
| std::vector<AnfNodePtr> op_inputs{value_node}; | |||
| auto return_cnode = body_sub_func_graph_->NewCNode(op_inputs); | |||
| return_cnode->set_fullname_with_scope(body_subgraph_name_ + "-return"); | |||
| if (tmp_output.size() == 1) { | |||
| return_cnode->add_input(tmp_output[0]); | |||
| } else { | |||
| std::vector<AnfNodePtr> make_tuple_inputs = tmp_output; | |||
| auto make_tuple_prim_ptr = lite::GetMakeTuplePrim(); | |||
| if (make_tuple_prim_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr); | |||
| make_tuple_inputs.insert(make_tuple_inputs.begin(), make_tuple_prim); | |||
| auto make_tuple_cnode = body_sub_func_graph_->NewCNode(make_tuple_inputs); | |||
| make_tuple_cnode->set_fullname_with_scope(return_cnode->fullname_with_scope() + "tuple"); | |||
| return_cnode->add_input(make_tuple_cnode); | |||
| } | |||
| body_sub_func_graph_->set_return(return_cnode); | |||
| return RET_OK; | |||
| } | |||
| STATUS FunctionalizeWhile::BuildBodyGraph() { | |||
| body_subgraph_name_ = FunctionalizeControlOpPass::NodeClusterName(loop_cond_node_) + "_body"; | |||
| body_sub_func_graph_ = | |||
| FunctionalizeControlOpPass::NewFuncGraph(body_subgraph_name_, mindspore::lite::converter::FmkType_TF); | |||
| if (body_sub_func_graph_ == nullptr) { | |||
| MS_LOG(ERROR) << "new body_sub_func_graph_ return nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| body_sub_func_graph_->set_manager(fg_->manager()); | |||
| int ret = BodySubgraphAddNodes(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "add body_subgraph node failed, ret:" << ret; | |||
| return ret; | |||
| } | |||
| ret = IdentifyBodySubgraphOutput(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "identify body_subgraph output failed, ret:" << ret; | |||
| return ret; | |||
| } | |||
| ret = IdentifyBodySubgraphInput(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "identify body_subgraph input failed, ret:" << ret; | |||
| return ret; | |||
| } | |||
| return ret; | |||
| } | |||
| STATUS FunctionalizeWhile::InsertFuncGraphToWhileInput() { | |||
| // set while input cond and body vnode | |||
| auto cond_value_node = NewValueNode(cond_sub_func_graph_); | |||
| auto body_value_node = NewValueNode(body_sub_func_graph_); | |||
| auto inputs = while_node_->inputs(); | |||
| inputs.insert(inputs.begin() + 1, {cond_value_node, body_value_node}); | |||
| while_node_->set_inputs(inputs); | |||
| return RET_OK; | |||
| } | |||
| STATUS FunctionalizeWhile::DropUselessNodesInMainGraph() { | |||
| // fg_ drop cluster node | |||
| for (auto &node : node_cluster_) { | |||
| fg_->DropNode(node); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS FunctionalizeWhile::Process() { | |||
| int ret = BuildWhileNode(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "build while node failed, ret:" << ret; | |||
| return ret; | |||
| } | |||
| ret = BuildCondGraph(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "build while node failed, ret:" << ret; | |||
| return ret; | |||
| } | |||
| ret = BuildBodyGraph(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "build while node failed, ret:" << ret; | |||
| return ret; | |||
| } | |||
| ret = InsertFuncGraphToWhileInput(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "insert func_graph to while input failed, ret:" << ret; | |||
| return ret; | |||
| } | |||
| ret = DropUselessNodesInMainGraph(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "main func_graph drop nodes failed, ret:" << ret; | |||
| return ret; | |||
| } | |||
| return ret; | |||
| } | |||
| } // namespace mindspore::opt | |||
| @@ -0,0 +1,89 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| *conv_activation_fusion.h | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_PASS_FUNCTIONALIZE_WHILE_H_ | |||
| #define MINDSPORE_LITE_SRC_PASS_FUNCTIONALIZE_WHILE_H_ | |||
| #include <string> | |||
| #include <set> | |||
| #include <vector> | |||
| #include <map> | |||
| #include "backend/optimizer/common/pass.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "tools/optimizer/graph/functionalize_control_op_pass.h" | |||
| using mindspore::lite::converter::FmkType; | |||
| namespace mindspore::opt { | |||
| class FunctionalizeWhile { | |||
| public: | |||
| FunctionalizeWhile(std::vector<AnfNodePtr> node_cluster, const CNodePtr &loop_cond_node, FuncGraphPtr fg) | |||
| : node_cluster_(node_cluster), loop_cond_node_(loop_cond_node), fg_(fg) {} | |||
| // while | |||
| STATUS BuildWhileNode(); | |||
| STATUS IdentifyWhileNodeInput(); | |||
| STATUS IdentifyWhileNodeOutput(); | |||
| STATUS UpdateExitNodeUser(); | |||
| STATUS NewWhileNode(); | |||
| STATUS InsertFuncGraphToWhileInput(); | |||
| // cond subgraph | |||
| STATUS BuildCondGraph(); | |||
| STATUS CondSubgraphAddNodes(); | |||
| STATUS IdentifyCondSubgraphInput(); | |||
| STATUS IdentifyCondSubgraphOutput(); | |||
| // body subgraph | |||
| STATUS BuildBodyGraph(); | |||
| STATUS BodySubgraphAddNodes(); | |||
| STATUS IdentifyBodySubgraphInput(); | |||
| STATUS IdentifyBodySubgraphOutput(); | |||
| CNodePtr BlongToWhichSwitch(const CNodePtr &node); | |||
| CNodePtr BlongToWhichMerge(const CNodePtr &node); | |||
| CNodePtr BlongToWhichEnter(const CNodePtr &node); | |||
| int PosInInputEnterNodes(const CNodePtr &node); | |||
| STATUS DropUselessNodesInMainGraph(); | |||
| STATUS Process(); | |||
| private: | |||
| std::vector<AnfNodePtr> node_cluster_{}; | |||
| const CNodePtr loop_cond_node_; | |||
| FuncGraphPtr fg_; | |||
| FuncGraphPtr cond_sub_func_graph_ = nullptr; | |||
| FuncGraphPtr body_sub_func_graph_ = nullptr; | |||
| CNodePtr while_node_ = nullptr; | |||
| std::string cond_subgraph_name_{}; | |||
| std::string body_subgraph_name_{}; | |||
| // while | |||
| std::vector<CNodePtr> input_enter_nodes_{}; | |||
| std::vector<CNodePtr> output_exit_nodes_{}; | |||
| // pair (next iteration node, next iteration node input) | |||
| std::map<AnfNodePtr, AnfNodePtr> body_subgraph_output_map_{}; | |||
| // pair (switch node, switch output in body graph) | |||
| std::map<AnfNodePtr, AnfNodePtr> body_subgraph_input_map_{}; | |||
| // pair (switch node, switch output in body graph) | |||
| std::map<AnfNodePtr, AnfNodePtr> cond_subgraph_input_map_{}; | |||
| }; | |||
| } // namespace mindspore::opt | |||
| #endif // MINDSPORE_LITE_SRC_PASS_FUNCTIONALIZE_WHILE_PASS_H_ | |||
| @@ -325,26 +325,6 @@ STATUS InferShapePass::SetSubGraphInputsAbstract(const CNodePtr &cnode, const Fu | |||
| return RET_OK; | |||
| } | |||
| STATUS InferShapePass::SwitchCNodeInferShape(const CNodePtr &switch_cnode) { | |||
| auto body_partial_cnode = switch_cnode->input(2)->cast<CNodePtr>(); | |||
| MS_ASSERT(body_partial_cnode != nullptr); | |||
| auto body_vnode = body_partial_cnode->input(0)->cast<ValueNodePtr>(); | |||
| MS_ASSERT(body_vnode != nullptr); | |||
| auto body_fg = GetValueNode<FuncGraphPtr>(body_vnode); | |||
| MS_ASSERT(body_fg != nullptr); | |||
| AbstractBasePtrList abstract_list; | |||
| auto body_fg_output_cnode = utils::cast<CNodePtr>(body_fg->output()); | |||
| for (auto &cnode : body_fg_output_cnode->inputs()) { | |||
| if (!utils::isa<CNodePtr>(cnode) && !utils::isa<ParameterPtr>(cnode)) { | |||
| continue; | |||
| } | |||
| abstract_list.push_back(cnode->abstract()); | |||
| } | |||
| switch_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | |||
| return RET_OK; | |||
| } | |||
| bool InferShapePass::Run(const FuncGraphPtr &func_graph) { | |||
| if (fmk_type != lite::converter::FmkType_TF && fmk_type != lite::converter::FmkType_TFLITE) { | |||
| MS_LOG(INFO) << "The framework type of model should be tf/tflite."; | |||
| @@ -384,14 +364,6 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) { | |||
| } | |||
| auto type = GetCNodeType(cnode); | |||
| if (type == schema::PrimitiveType_Switch) { | |||
| int ret = SwitchCNodeInferShape(cnode); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "PartialCNodeInferShape failed."; | |||
| return false; | |||
| } | |||
| } | |||
| if ((type == schema::PrimitiveType_TupleGetItem) || | |||
| #ifdef SUPPORT_TRAIN | |||
| (type == schema::PrimitiveType_Depend) || (type == schema::PrimitiveType_ControlDepend) || | |||
| @@ -41,7 +41,6 @@ class InferShapePass : public Pass { | |||
| STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *output_tensors); | |||
| STATUS SetParameterAbstract(const ParameterPtr ¶meter); | |||
| STATUS SetCNodeAbstract(const std::vector<lite::Tensor *> &output_tensors, const std::shared_ptr<CNode> &cnode); | |||
| STATUS SwitchCNodeInferShape(const CNodePtr &cnode); | |||
| int StrIsContain(const std::vector<std::string> &total, const std::string &aim); | |||
| int SetSubGraphInputsAbstract(const CNodePtr &cnode, const FuncGraphPtr &func_graph); | |||