Browse Source

functionalize while

tags/v1.2.0-rc1
mengyuanli 5 years ago
parent
commit
0f426c434f
34 changed files with 1902 additions and 73 deletions
  1. +4
    -1
      mindspore/lite/src/ops/primitive_c.cc
  2. +6
    -0
      mindspore/lite/src/ops/primitive_c.h
  3. +30
    -24
      mindspore/lite/test/CMakeLists.txt
  4. +12
    -5
      mindspore/lite/tools/converter/CMakeLists.txt
  5. +12
    -2
      mindspore/lite/tools/converter/anf_transform.cc
  6. +56
    -0
      mindspore/lite/tools/converter/ops/enter.cc
  7. +39
    -0
      mindspore/lite/tools/converter/ops/enter.h
  8. +56
    -0
      mindspore/lite/tools/converter/ops/exit.cc
  9. +39
    -0
      mindspore/lite/tools/converter/ops/exit.h
  10. +56
    -0
      mindspore/lite/tools/converter/ops/loop_cond.cc
  11. +39
    -0
      mindspore/lite/tools/converter/ops/loop_cond.h
  12. +56
    -0
      mindspore/lite/tools/converter/ops/next_iteration.cc
  13. +39
    -0
      mindspore/lite/tools/converter/ops/next_iteration.h
  14. +33
    -0
      mindspore/lite/tools/converter/ops/ops_def.h
  15. +50
    -0
      mindspore/lite/tools/converter/parser/tf/tf_enter_parser.cc
  16. +38
    -0
      mindspore/lite/tools/converter/parser/tf/tf_enter_parser.h
  17. +49
    -0
      mindspore/lite/tools/converter/parser/tf/tf_exit_parser.cc
  18. +38
    -0
      mindspore/lite/tools/converter/parser/tf/tf_exit_parser.h
  19. +49
    -0
      mindspore/lite/tools/converter/parser/tf/tf_loop_cond_parser.cc
  20. +38
    -0
      mindspore/lite/tools/converter/parser/tf/tf_loop_cond_parser.h
  21. +62
    -0
      mindspore/lite/tools/converter/parser/tf/tf_merge_parser.cc
  22. +38
    -0
      mindspore/lite/tools/converter/parser/tf/tf_merge_parser.h
  23. +48
    -11
      mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc
  24. +7
    -1
      mindspore/lite/tools/converter/parser/tf/tf_model_parser.h
  25. +49
    -0
      mindspore/lite/tools/converter/parser/tf/tf_next_iteration_parser.cc
  26. +38
    -0
      mindspore/lite/tools/converter/parser/tf/tf_next_iteration_parser.h
  27. +62
    -0
      mindspore/lite/tools/converter/parser/tf/tf_switch_parser.cc
  28. +38
    -0
      mindspore/lite/tools/converter/parser/tf/tf_switch_parser.h
  29. +140
    -0
      mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.cc
  30. +71
    -0
      mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.h
  31. +521
    -0
      mindspore/lite/tools/optimizer/graph/functionalize_while.cc
  32. +89
    -0
      mindspore/lite/tools/optimizer/graph/functionalize_while.h
  33. +0
    -28
      mindspore/lite/tools/optimizer/graph/infershape_pass.cc
  34. +0
    -1
      mindspore/lite/tools/optimizer/graph/infershape_pass.h

+ 4
- 1
mindspore/lite/src/ops/primitive_c.cc View File

@@ -1103,10 +1103,13 @@ schema::QuantType PrimitiveC::quant_type() const { return quant_type_; }
#endif

int PrimitiveC::Type() const {
if (this->primitive_ == nullptr) {
if (this->primitive_ == nullptr && this->op_type_ == OP_TYPE_NOT_SET) {
return schema::PrimitiveType_NONE;
}
#ifdef PRIMITIVE_WRITEABLE
if (op_type_ != OP_TYPE_NOT_SET) {
return op_type_;
}
return this->primitive_->value.type;
#else
return this->primitive_->value_type();


+ 6
- 0
mindspore/lite/src/ops/primitive_c.h View File

@@ -24,6 +24,9 @@
#ifdef PRIMITIVE_WRITEABLE
#include "ir/primitive.h"
#include "schema/inner/model_generated.h"
#include "schema/inner/ops_generated.h"
#include "schema/ops_generated.h"
#include "tools/converter/ops/ops_def.h"
#else
#include "schema/model_generated.h"
#endif
@@ -34,6 +37,7 @@

namespace mindspore {
namespace lite {
constexpr const int OP_TYPE_NOT_SET = -1;
constexpr uint32_t kSingleNum = 1;
constexpr uint32_t kDoubleNum = 2;
constexpr uint32_t kMultiNum = 3;
@@ -149,6 +153,7 @@ class PrimitiveC : public mindspore::Primitive {
std::vector<std::vector<schema::QuantParamT>> output_quant_param_;
schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
bool infer_flag_ = true;
int op_type_ = OP_TYPE_NOT_SET;
};
std::shared_ptr<PrimitiveC> GetReturnPrim();

@@ -227,6 +232,7 @@ class PrimitiveC {
char *primitive_buf_ = nullptr;
bool infer_flag_ = true;
schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
int op_type_ = OP_TYPE_NOT_SET;
};
using PrimitiveCPtr = std::shared_ptr<PrimitiveC>;
typedef PrimitiveC *(*PrimitiveCCreator)(const schema::Primitive *primitive);


+ 30
- 24
mindspore/lite/test/CMakeLists.txt View File

@@ -2,6 +2,7 @@ set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
set(TEST_DIR ${TOP_DIR}/mindspore/lite/test)
set(LITE_DIR ${TOP_DIR}/mindspore/lite)
set(CCSRC_DIR ${TOP_DIR}/mindspore/ccsrc)
set(CONVERTER_DIR ${TOP_DIR}/mindspore/lite/tools/converter)
include_directories(${TOP_DIR})
include_directories(${TEST_DIR})
include(${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake/external_libs/gtest.cmake)
@@ -16,7 +17,7 @@ set(CCSRC_SRC
${CCSRC_DIR}/backend/optimizer/common/visit.cc
${CCSRC_DIR}/backend/optimizer/common/optimizer.cc
)
else(ENABLE_CONVERTER)
else()
set(TEST_LITE_SRC ${LITE_DIR}/src/common/log_adapter.cc)
add_compile_definitions(USE_ANDROID_LOG)
endif()
@@ -38,10 +39,10 @@ file(GLOB KERNEL_OP_TRAIN_SRC
${LITE_DIR}/src/runtime/kernel/arm/fp32_grad/*.cc
)

if (SUPPORT_TRAIN)
if(SUPPORT_TRAIN)
list(APPEND KERNEL_OP_SRC ${KERNEL_OP_TRAIN_SRC})
endif()
if (PLATFORM_ARM64)
if(PLATFORM_ARM64)
# assembly
file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/assembly/arm64/*.s
${LITE_DIR}/nnacl/assembly/arm64/*.S)
@@ -53,7 +54,7 @@ if (PLATFORM_ARM64)
)
endif()

if (PLATFORM_ARM32)
if(PLATFORM_ARM32)
# assembly
file(GLOB TEST_ASSEMBLY_SRC
${LITE_DIR}/nnacl/assembly/arm32/*.S
@@ -65,7 +66,7 @@ if (PLATFORM_ARM32)
)
endif()

if ("${X86_64_SIMD}" STREQUAL "sse")
if("${X86_64_SIMD}" STREQUAL "sse")
file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/x86_64_sse/*.c)
set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C)
set(KERNEL_OP_SRC
@@ -74,12 +75,12 @@ if ("${X86_64_SIMD}" STREQUAL "sse")
)
endif()

if ("${X86_64_SIMD}" STREQUAL "avx")
if("${X86_64_SIMD}" STREQUAL "avx")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.1 -mavx -mavx2")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1 -mavx -mavx2")
file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/x86_64_sse/*.c
${LITE_DIR}/nnacl/x86_64_avx/*.c
${LITE_DIR}/nnacl/assembly/avx/*.S)
${LITE_DIR}/nnacl/x86_64_avx/*.c
${LITE_DIR}/nnacl/assembly/avx/*.S)
set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C)
set(KERNEL_OP_SRC
${KERNEL_OP_SRC}
@@ -88,7 +89,7 @@ if ("${X86_64_SIMD}" STREQUAL "avx")
endif()

### gpu kernel
if (SUPPORT_GPU)
if(SUPPORT_GPU)
file(GLOB GPU_KERNEL_OP_SRC
${LITE_DIR}/src/runtime/kernel/opencl/kernel/*.cc
)
@@ -102,14 +103,18 @@ if (SUPPORT_GPU)
)
endif()

if (PLATFORM_ARM32 OR PLATFORM_ARM64)
if (ENABLE_CONVERTER)
if(PLATFORM_ARM32 OR PLATFORM_ARM64)
if(ENABLE_CONVERTER)
set(BUILD_MINDDATA "off")
endif()
endif()
### runtime framework
add_definitions(-DENABLE_V0)
file(GLOB_RECURSE OPS_SRC ${LITE_DIR}/src/ops/*.cc)
if(ENABLE_CONVERTER)
file(GLOB_RECURSE CONVERTER_OPS_SRC ${CONVERTER_DIR}/ops/*.cc)
set(OPS_SRC ${OPS_SRC} ${CONVERTER_OPS_SRC})
endif()
set(TEST_LITE_SRC
${TEST_LITE_SRC}
${CCSRC_SRC}
@@ -144,7 +149,7 @@ set(TEST_LITE_SRC
${LITE_DIR}/src/errorcode.cc
)
### gpu runtime
if (SUPPORT_GPU)
if(SUPPORT_GPU)
include_directories(${TOP_DIR}/third_party/OpenCL-Headers)
include_directories(${TOP_DIR}/third_party/OpenCL-CLHPP/include)
set(OPENCL_RUNTIME_SRC
@@ -210,13 +215,14 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/graph/onnx_inputs_adjust_pass.cc
${LITE_DIR}/tools/optimizer/graph/while_pass.cc
${LITE_DIR}/tools/optimizer/graph/if_pass.cc
${LITE_DIR}/tools/optimizer/graph/functionalize_control_op_pass.cc
${LITE_DIR}/tools/optimizer/graph/functionalize_while.cc
)
endif()
### train
if (SUPPORT_TRAIN)
if(SUPPORT_TRAIN)
set(TEST_LITE_SRC
${TEST_LITE_SRC}
# ${LITE_DIR}/src/train/ops/train_ops.cc
${LITE_DIR}/src/train/train_populate_parameter.cc
${LITE_DIR}/src/train/train_session.cc
${LITE_DIR}/src/train/train_model.cc
@@ -251,7 +257,7 @@ set(TEST_SRC
${TEST_DIR}/ut/src/scheduler_test.cc
)

if (ENABLE_CONVERTER)
if(ENABLE_CONVERTER)
set(TEST_SRC
${TEST_SRC}
${TEST_DIR}/st/converter_test.cc
@@ -265,7 +271,7 @@ if (ENABLE_CONVERTER)
)
endif()

if (SUPPORT_TRAIN)
if(SUPPORT_TRAIN)
set(TEST_SRC
${TEST_SRC}
${TEST_CASE_KERNEL_TRAIN_SRC}
@@ -278,7 +284,7 @@ else()
)
endif()

if (SUPPORT_GPU)
if(SUPPORT_GPU)
file(GLOB_RECURSE TEST_CASE_KERNEL_GPU_SRC
${TEST_DIR}/ut/src/runtime/kernel/opencl/*.cc
)
@@ -288,7 +294,7 @@ if (SUPPORT_GPU)
)
endif()

if (ENABLE_FP16)
if(ENABLE_FP16)
file(GLOB_RECURSE TEST_CASE_KERNEL_FP16_SRC
${TEST_DIR}/ut/src/runtime/kernel/arm/fp16/*.cc
)
@@ -296,24 +302,24 @@ if (ENABLE_FP16)
${TEST_SRC}
${TEST_CASE_KERNEL_FP16_SRC}
)
endif ()
endif()

add_executable(lite-test ${TEST_SRC})
add_dependencies(lite-test fbs_src)
target_link_libraries(lite-test dl mindspore::gtest)
if (PLATFORM_ARM64 AND ENABLE_FP16)
if(PLATFORM_ARM64 AND ENABLE_FP16)
target_link_libraries(lite-test nnacl_fp16_mid nnacl_optimize_mid)
endif()

if (PLATFORM_ARM)
if(PLATFORM_ARM)
target_link_libraries(lite-test log)
endif()

if (SUPPORT_NPU)
if(SUPPORT_NPU)
include_directories(${DDK_PATH})
target_link_libraries(lite-test npu_kernel_mid)
endif ()
if (ENABLE_CONVERTER)
endif()
if(ENABLE_CONVERTER)
add_dependencies(lite-test fbs_inner_src)
target_link_libraries(lite-test
anf_importer_mid


+ 12
- 5
mindspore/lite/tools/converter/CMakeLists.txt View File

@@ -12,6 +12,8 @@ include(${TOP_DIR}/cmake/external_libs/glog.cmake)
file(GLOB OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/populate/*.cc)

file(GLOB CONVERTER_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/ops/*.cc)

file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/../flag/flag_parser.cc
${CMAKE_CURRENT_SOURCE_DIR}/converter.cc
@@ -65,6 +67,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/graph/while_pass.cc
../optimizer/graph/if_pass.cc
../optimizer/graph/mindir_inputs_adjust_pass.cc
../optimizer/graph/functionalize_control_op_pass.cc
../optimizer/graph/functionalize_while.cc
)

add_subdirectory(../anf_importer anf_importer)
@@ -97,12 +101,12 @@ set(LITE_SRC
${SRC_DIR}/errorcode.cc
${SRC_DIR}/dequant.cc
)
if (SUPPORT_TRAIN)
if(SUPPORT_TRAIN)
set(LITE_SRC
${LITE_SRC}
)

endif ()
endif()
set(ARM_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src/runtime/kernel/arm)
file(GLOB KERNEL_SRC
${ARM_DIR}/base/*.cc
@@ -114,13 +118,13 @@ file(GLOB KERNEL_SRC
${ARM_DIR}/int8/*.cc
)

if (PLATFORM_ARM64)
if(PLATFORM_ARM64)
# assembly
file(GLOB ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../nnacl/assembly/arm64/*.s
${CMAKE_CURRENT_SOURCE_DIR}/../../nnacl/assembly/arm64/*.S)
set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC})
endif ()
endif()

file(GLOB PROTO_FILE ""
${CMAKE_CURRENT_SOURCE_DIR}/parser/caffe/caffe.proto
@@ -133,11 +137,13 @@ add_library(proto_mid OBJECT ${PROTO_SRCS})
set(TFLITE_FBS_FILES
${CMAKE_CURRENT_SOURCE_DIR}/parser/tflite/schema.fbs
)
ms_build_flatbuffers_lite(TFLITE_FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR}/parser/tflite/ tflite_fbs_src ${CMAKE_BINARY_DIR}/schema "inner")
ms_build_flatbuffers_lite(TFLITE_FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR}/parser/tflite/ tflite_fbs_src
${CMAKE_BINARY_DIR}/schema "inner")

set_property(SOURCE ${CONVERTER_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
set_property(SOURCE ${CCSRC_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
set_property(SOURCE ${OPS_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
set_property(SOURCE ${CONVERTER_OPS_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
set_property(SOURCE ${KERNEL_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
set_property(SOURCE ${LITE_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
add_executable(converter_lite
@@ -145,6 +151,7 @@ add_executable(converter_lite
${CCSRC_SRC}
${CONVERTER_SRC}
${OPS_SRC}
${CONVERTER_OPS_SRC}
${KERNEL_SRC}
${LITE_SRC}
)


+ 12
- 2
mindspore/lite/tools/converter/anf_transform.cc View File

@@ -48,6 +48,7 @@
#include "tools/optimizer/graph/slice_prepose_pass.h"
#include "tools/optimizer/graph/while_pass.h"
#include "tools/optimizer/graph/if_pass.h"
#include "tools/optimizer/graph/functionalize_control_op_pass.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/weight_quantizer.h"
@@ -100,6 +101,15 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
}
}

if (config->fmk == lite::converter::FmkType_TF) {
auto functionalize_control_op_pass = std::make_shared<opt::FunctionalizeControlOpPass>();
if (!functionalize_control_op_pass->Run(old_graph)) {
MS_LOG(ERROR) << "functionalize control op pass failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
}

if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF ||
config->fmk == lite::converter::FmkType_ONNX) {
graph_pm->AddPass(std::make_shared<opt::WhilePass>());
@@ -145,7 +155,7 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
if (config->fmk == lite::converter::FmkType_MS) {
auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>();
if (remove_unused_cast_pass == nullptr) {
MS_LOG(ERROR) << "RemoveUnusedCastOpPass shoud be specified";
MS_LOG(ERROR) << "RemoveUnusedCastOpPass should be specified";
return nullptr;
}
remove_unused_cast_pass->SetFmkType(config->fmk);
@@ -154,7 +164,7 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
if (config->fmk == lite::converter::FmkType_ONNX) {
auto remove_unused_transpose_pass = std::make_shared<opt::RemoveUnusedTransposeOpPass>();
if (remove_unused_transpose_pass == nullptr) {
MS_LOG(ERROR) << "RemoveUnusedTransposeOpPass shoud be specified";
MS_LOG(ERROR) << "RemoveUnusedTransposeOpPass should be specified";
return nullptr;
}
remove_unused_transpose_pass->SetFmkType(config->fmk);


+ 56
- 0
mindspore/lite/tools/converter/ops/enter.cc View File

@@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/converter/ops/enter.h"
#include "src/tensorlist.h"

namespace mindspore {
namespace lite {

// Pass-through shape inference for the TF control-flow op "Enter": each
// output copies data type, shape and format from the input at the same index.
// Returns RET_INFER_INVALID when inference is disabled, RET_ERROR on null
// tensors or when there are fewer outputs than inputs, RET_OK otherwise.
int Enter::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  if (!infer_flag()) {
    return RET_INFER_INVALID;
  }
  // Outputs are produced one-for-one from inputs; guard the index below
  // against a malformed node supplying fewer outputs than inputs.
  if (outputs_.size() < inputs_.size()) {
    MS_LOG(ERROR) << "outputs size is less than inputs size";
    return RET_ERROR;
  }
  for (size_t i = 0; i < inputs_.size(); i++) {
    auto *input = inputs_[i];
    auto *output = outputs_[i];
    if (input == nullptr) {
      MS_LOG(ERROR) << "input tensor is nullptr";
      return RET_ERROR;
    }
    if (output == nullptr) {
      MS_LOG(ERROR) << "output tensor is nullptr";
      return RET_ERROR;
    }
    output->set_data_type(input->data_type());
    output->set_shape(input->shape());
    output->set_format(input->format());
    // TensorList inputs carry extra metadata beyond shape/type; copy it too.
    if (input->data_type() == kObjectTypeTensorType) {
      auto input_tensorlist = reinterpret_cast<TensorList *>(input);
      auto output_tensorlist = reinterpret_cast<TensorList *>(output);
      output_tensorlist->set_element_shape(input_tensorlist->element_shape());
      output_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num());
      output_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type());
    }
  }
  return RET_OK;
}

} // namespace lite
} // namespace mindspore

+ 39
- 0
mindspore/lite/tools/converter/ops/enter.h View File

@@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_ENTER_H_
#define LITE_MINDSPORE_LITE_C_OPS_ENTER_H_

#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {

// Converter-only primitive for the TensorFlow control-flow op "Enter".
// This op has no flatbuffer schema entry; it exists only during conversion
// (see the functionalize control-op passes added alongside it).
class Enter : public PrimitiveC {
 public:
  // Tags the primitive with a converter-local type id (outside the schema
  // enum range) so PrimitiveC::Type() can report it via op_type_.
  Enter() { op_type_ = ConverterPrimitiveType_Enter; }
  ~Enter() = default;
  MS_DECLARE_PARENT(Enter, PrimitiveC);
  // Wraps an existing schema primitive; op_type_ is left unset here.
  explicit Enter(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Pass-through inference: each output copies shape/type/format from the
  // input at the same index (implemented in enter.cc).
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore

#endif // LITE_MINDSPORE_LITE_C_OPS_ENTER_H_

+ 56
- 0
mindspore/lite/tools/converter/ops/exit.cc View File

@@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/converter/ops/exit.h"
#include "src/tensorlist.h"

namespace mindspore {
namespace lite {

// Pass-through shape inference for the TF control-flow op "Exit": each
// output copies data type, shape and format from the input at the same index.
// Returns RET_INFER_INVALID when inference is disabled, RET_ERROR on null
// tensors or when there are fewer outputs than inputs, RET_OK otherwise.
int Exit::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  if (!infer_flag()) {
    return RET_INFER_INVALID;
  }
  // Outputs are produced one-for-one from inputs; guard the index below
  // against a malformed node supplying fewer outputs than inputs.
  if (outputs_.size() < inputs_.size()) {
    MS_LOG(ERROR) << "outputs size is less than inputs size";
    return RET_ERROR;
  }
  for (size_t i = 0; i < inputs_.size(); i++) {
    auto *input = inputs_[i];
    auto *output = outputs_[i];
    if (input == nullptr) {
      MS_LOG(ERROR) << "input tensor is nullptr";
      return RET_ERROR;
    }
    if (output == nullptr) {
      MS_LOG(ERROR) << "output tensor is nullptr";
      return RET_ERROR;
    }
    output->set_data_type(input->data_type());
    output->set_shape(input->shape());
    output->set_format(input->format());
    // TensorList inputs carry extra metadata beyond shape/type; copy it too.
    if (input->data_type() == kObjectTypeTensorType) {
      auto input_tensorlist = reinterpret_cast<TensorList *>(input);
      auto output_tensorlist = reinterpret_cast<TensorList *>(output);
      output_tensorlist->set_element_shape(input_tensorlist->element_shape());
      output_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num());
      output_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type());
    }
  }
  return RET_OK;
}

} // namespace lite
} // namespace mindspore

+ 39
- 0
mindspore/lite/tools/converter/ops/exit.h View File

@@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_EXIT_H_
#define LITE_MINDSPORE_LITE_C_OPS_EXIT_H_

#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {

// Converter-only primitive for the TensorFlow control-flow op "Exit".
// This op has no flatbuffer schema entry; it exists only during conversion
// (see the functionalize control-op passes added alongside it).
class Exit : public PrimitiveC {
 public:
  // Tags the primitive with a converter-local type id (outside the schema
  // enum range) so PrimitiveC::Type() can report it via op_type_.
  Exit() { op_type_ = ConverterPrimitiveType_Exit; }
  ~Exit() = default;
  MS_DECLARE_PARENT(Exit, PrimitiveC);
  // Wraps an existing schema primitive; op_type_ is left unset here.
  explicit Exit(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Pass-through inference: each output copies shape/type/format from the
  // input at the same index (implemented in exit.cc).
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore

#endif // LITE_MINDSPORE_LITE_C_OPS_EXIT_H_

+ 56
- 0
mindspore/lite/tools/converter/ops/loop_cond.cc View File

@@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/converter/ops/loop_cond.h"
#include "src/tensorlist.h"

namespace mindspore {
namespace lite {

// Pass-through shape inference for the TF control-flow op "LoopCond": each
// output copies data type, shape and format from the input at the same index.
// Returns RET_INFER_INVALID when inference is disabled, RET_ERROR on null
// tensors or when there are fewer outputs than inputs, RET_OK otherwise.
int LoopCond::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  if (!infer_flag()) {
    return RET_INFER_INVALID;
  }
  // Outputs are produced one-for-one from inputs; guard the index below
  // against a malformed node supplying fewer outputs than inputs.
  if (outputs_.size() < inputs_.size()) {
    MS_LOG(ERROR) << "outputs size is less than inputs size";
    return RET_ERROR;
  }
  for (size_t i = 0; i < inputs_.size(); i++) {
    auto *input = inputs_[i];
    auto *output = outputs_[i];
    if (input == nullptr) {
      MS_LOG(ERROR) << "input tensor is nullptr";
      return RET_ERROR;
    }
    if (output == nullptr) {
      MS_LOG(ERROR) << "output tensor is nullptr";
      return RET_ERROR;
    }
    output->set_data_type(input->data_type());
    output->set_shape(input->shape());
    output->set_format(input->format());
    // TensorList inputs carry extra metadata beyond shape/type; copy it too.
    if (input->data_type() == kObjectTypeTensorType) {
      auto input_tensorlist = reinterpret_cast<TensorList *>(input);
      auto output_tensorlist = reinterpret_cast<TensorList *>(output);
      output_tensorlist->set_element_shape(input_tensorlist->element_shape());
      output_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num());
      output_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type());
    }
  }
  return RET_OK;
}

} // namespace lite
} // namespace mindspore

+ 39
- 0
mindspore/lite/tools/converter/ops/loop_cond.h View File

@@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_LOOPCOND_H_
#define LITE_MINDSPORE_LITE_C_OPS_LOOPCOND_H_

#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {

// Converter-only primitive for the TensorFlow control-flow op "LoopCond".
// This op has no flatbuffer schema entry; it exists only during conversion
// (see the functionalize control-op passes added alongside it).
class LoopCond : public PrimitiveC {
 public:
  // Tags the primitive with a converter-local type id (outside the schema
  // enum range) so PrimitiveC::Type() can report it via op_type_.
  LoopCond() { op_type_ = ConverterPrimitiveType_LoopCond; }
  ~LoopCond() = default;
  MS_DECLARE_PARENT(LoopCond, PrimitiveC);
  // Wraps an existing schema primitive; op_type_ is left unset here.
  explicit LoopCond(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Pass-through inference: each output copies shape/type/format from the
  // input at the same index (implemented in loop_cond.cc).
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore

#endif // LITE_MINDSPORE_LITE_C_OPS_LOOPCOND_H_

+ 56
- 0
mindspore/lite/tools/converter/ops/next_iteration.cc View File

@@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/converter/ops/next_iteration.h"
#include "src/tensorlist.h"

namespace mindspore {
namespace lite {

// Pass-through shape inference for the TF control-flow op "NextIteration":
// each output copies data type, shape and format from the input at the same
// index. Returns RET_INFER_INVALID when inference is disabled, RET_ERROR on
// null tensors or when there are fewer outputs than inputs, RET_OK otherwise.
int NextIteration::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  if (!infer_flag()) {
    return RET_INFER_INVALID;
  }
  // Outputs are produced one-for-one from inputs; guard the index below
  // against a malformed node supplying fewer outputs than inputs.
  if (outputs_.size() < inputs_.size()) {
    MS_LOG(ERROR) << "outputs size is less than inputs size";
    return RET_ERROR;
  }
  for (size_t i = 0; i < inputs_.size(); i++) {
    auto *input = inputs_[i];
    auto *output = outputs_[i];
    if (input == nullptr) {
      MS_LOG(ERROR) << "input tensor is nullptr";
      return RET_ERROR;
    }
    if (output == nullptr) {
      MS_LOG(ERROR) << "output tensor is nullptr";
      return RET_ERROR;
    }
    output->set_data_type(input->data_type());
    output->set_shape(input->shape());
    output->set_format(input->format());
    // TensorList inputs carry extra metadata beyond shape/type; copy it too.
    if (input->data_type() == kObjectTypeTensorType) {
      auto input_tensorlist = reinterpret_cast<TensorList *>(input);
      auto output_tensorlist = reinterpret_cast<TensorList *>(output);
      output_tensorlist->set_element_shape(input_tensorlist->element_shape());
      output_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num());
      output_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type());
    }
  }
  return RET_OK;
}

} // namespace lite
} // namespace mindspore

+ 39
- 0
mindspore/lite/tools/converter/ops/next_iteration.h View File

@@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_NEXTITERATION_H_
#define LITE_MINDSPORE_LITE_C_OPS_NEXTITERATION_H_

#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {

// Converter-only primitive for the TensorFlow control-flow op
// "NextIteration". This op has no flatbuffer schema entry; it exists only
// during conversion (see the functionalize control-op passes added alongside).
class NextIteration : public PrimitiveC {
 public:
  // Tags the primitive with a converter-local type id (outside the schema
  // enum range) so PrimitiveC::Type() can report it via op_type_.
  NextIteration() { op_type_ = ConverterPrimitiveType_NextIteration; }
  ~NextIteration() = default;
  MS_DECLARE_PARENT(NextIteration, PrimitiveC);
  // Wraps an existing schema primitive; op_type_ is left unset here.
  explicit NextIteration(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Pass-through inference: each output copies shape/type/format from the
  // input at the same index (implemented in next_iteration.cc).
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore

#endif // LITE_MINDSPORE_LITE_C_OPS_NEXTITERATION_H_

+ 33
- 0
mindspore/lite/tools/converter/ops/ops_def.h View File

@@ -0,0 +1,33 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_
#include "schema/inner/model_generated.h"

namespace mindspore {
namespace lite {

// Converter-only primitive type ids for TF control-flow ops that have no
// entry in the flatbuffer schema. Values start just past
// schema::PrimitiveType_MAX so they can never collide with real schema types.
enum ConverterPrimitiveType {
  ConverterPrimitiveType_Enter = schema::PrimitiveType_MAX + 1,
  ConverterPrimitiveType_LoopCond,
  ConverterPrimitiveType_NextIteration,
  ConverterPrimitiveType_Exit,
};
}  // namespace lite
}  // namespace mindspore

// Fixed: the #endif comment previously named the wrong guard
// (LITE_MINDSPORE_LITE_C_OPS_NEXTITERATION_H_, a copy-paste leftover).
#endif  // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_

+ 50
- 0
mindspore/lite/tools/converter/parser/tf/tf_enter_parser.cc View File

@@ -0,0 +1,50 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <string>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_enter_parser.h"
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/enter.h"

namespace mindspore {
namespace lite {
// Parses a TF "Enter" node into a converter-side Enter primitive.
// Enter forwards all of its inputs unchanged, so the primitive carries no
// attributes: every TF input name is copied into *inputs and *output_size is
// set to the input count.
// Returns RET_NULL_PTR if any out-parameter is null, RET_ERROR when the
// primitive cannot be allocated, RET_OK otherwise.
STATUS TFEnterParser::Parse(const tensorflow::NodeDef &tf_op,
                            const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
                            std::vector<std::string> *inputs, int *output_size) {
  MS_LOG(INFO) << "TF EnterParser";
  // `inputs` is dereferenced below, so it must be null-checked along with the
  // other out-parameters (the original only checked primitiveC/output_size).
  if (primitiveC == nullptr || inputs == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC or inputs or output_size is nullptr";
    return RET_NULL_PTR;
  }

  *primitiveC = new (std::nothrow) Enter();
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "new Enter failed";
    return RET_ERROR;
  }

  *output_size = tf_op.input_size();
  for (int i = 0; i < tf_op.input_size(); i++) {
    inputs->emplace_back(tf_op.input(i));
  }

  return RET_OK;
}
TFNodeRegistrar g_tfEnterParser("Enter", new TFEnterParser());
} // namespace lite
} // namespace mindspore

+ 38
- 0
mindspore/lite/tools/converter/parser/tf/tf_enter_parser.h View File

@@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ENTER_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ENTER_PARSER_H_

#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
// Parses a TF "Enter" node into a converter-side Enter primitive.
class TFEnterParser : public TFNodeParser {
 public:
  TFEnterParser() = default;
  ~TFEnterParser() override = default;

  // Creates the Enter primitive, copies every TF input name into *inputs and
  // sets *output_size to the node's input count (Enter is a pass-through op).
  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
}  // namespace lite
}  // namespace mindspore

// Fixed: the #endif comment previously said TF_FILL_PARSER_H_, a copy-paste
// leftover from the fill parser header.
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ENTER_PARSER_H_

+ 49
- 0
mindspore/lite/tools/converter/parser/tf/tf_exit_parser.cc View File

@@ -0,0 +1,49 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_exit_parser.h"
#include <string>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/exit.h"

namespace mindspore {
namespace lite {
STATUS TFExitParser::Parse(const tensorflow::NodeDef &tf_op,
                           const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
                           std::vector<std::string> *inputs, int *output_size) {
  MS_LOG(INFO) << "TF ExitParser";
  // Validate every out-parameter before use; `inputs` is dereferenced below
  // and previously escaped the null check.
  if (primitiveC == nullptr || inputs == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC or inputs or output_size is nullptr";
    return RET_NULL_PTR;
  }

  // Exit is a converter-internal op (no schema primitive); ownership of the
  // allocation transfers to the caller through *primitiveC.
  *primitiveC = new (std::nothrow) Exit();
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "new Exit failed";
    return RET_ERROR;
  }

  // Forward every TF input edge unchanged; output count mirrors input count.
  *output_size = tf_op.input_size();
  for (int i = 0; i < tf_op.input_size(); i++) {
    inputs->emplace_back(tf_op.input(i));
  }

  return RET_OK;
}
TFNodeRegistrar g_tfExitParser("Exit", new TFExitParser());
} // namespace lite
} // namespace mindspore

+ 38
- 0
mindspore/lite/tools/converter/parser/tf/tf_exit_parser.h View File

@@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_EXIT_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_EXIT_PARSER_H_

#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
// Parses the TensorFlow "Exit" control-flow node (a value leaving a
// while-loop frame) into the converter-internal Exit primitive.
class TFExitParser : public TFNodeParser {
 public:
  TFExitParser() = default;
  ~TFExitParser() override = default;

  // Fills *primitiveC with the parsed primitive, *inputs with the TF input
  // tensor names and *output_size with the op's output count.
  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore

#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_EXIT_PARSER_H_

+ 49
- 0
mindspore/lite/tools/converter/parser/tf/tf_loop_cond_parser.cc View File

@@ -0,0 +1,49 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_loop_cond_parser.h"
#include <string>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/loop_cond.h"

namespace mindspore {
namespace lite {
STATUS TFLoopCondParser::Parse(const tensorflow::NodeDef &tf_op,
                               const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
                               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
  MS_LOG(INFO) << "TF LoopCondParser";
  // Validate every out-parameter before use; `inputs` is dereferenced below
  // and previously escaped the null check.
  if (primitiveC == nullptr || inputs == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC or inputs or output_size is nullptr";
    return RET_NULL_PTR;
  }

  // LoopCond is a converter-internal op (no schema primitive); ownership of
  // the allocation transfers to the caller through *primitiveC.
  *primitiveC = new (std::nothrow) LoopCond();
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "new LoopCond failed";
    return RET_ERROR;
  }

  // Forward every TF input edge unchanged; output count mirrors input count.
  *output_size = tf_op.input_size();
  for (int i = 0; i < tf_op.input_size(); i++) {
    inputs->emplace_back(tf_op.input(i));
  }

  return RET_OK;
}
TFNodeRegistrar g_tfLoopCondParser("LoopCond", new TFLoopCondParser());
} // namespace lite
} // namespace mindspore

+ 38
- 0
mindspore/lite/tools/converter/parser/tf/tf_loop_cond_parser.h View File

@@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LOOP_COND_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LOOP_COND_PARSER_H_

#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
// Parses the TensorFlow "LoopCond" control-flow node (the boolean loop
// predicate of a while loop) into the converter-internal LoopCond primitive.
class TFLoopCondParser : public TFNodeParser {
 public:
  TFLoopCondParser() = default;
  ~TFLoopCondParser() override = default;

  // Fills *primitiveC with the parsed primitive, *inputs with the TF input
  // tensor names and *output_size with the op's output count.
  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore

#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LOOP_COND_PARSER_H_

+ 62
- 0
mindspore/lite/tools/converter/parser/tf/tf_merge_parser.cc View File

@@ -0,0 +1,62 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_merge_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
STATUS TFMergeParser::Parse(const tensorflow::NodeDef &tf_op,
                            const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
                            std::vector<std::string> *inputs, int *output_size) {
  MS_LOG(INFO) << "TF MergeParser";
  // Validate every out-parameter before use; `inputs` is dereferenced below
  // and previously escaped the null check.
  if (primitiveC == nullptr || inputs == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC or inputs or output_size is nullptr";
    return RET_NULL_PTR;
  }

  // Merge is a schema primitive: build the flatbuffer object representation
  // and hand ownership down through release() at each step.
  auto primitive = std::make_unique<schema::PrimitiveT>();
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "primitive is nullptr";
    return RET_NULL_PTR;
  }
  auto attr = std::make_unique<schema::MergeT>();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new op failed";
    return RET_NULL_PTR;
  }

  primitive->value.type = schema::PrimitiveType_Merge;
  primitive->value.value = attr.release();
  *primitiveC = PrimitiveC::Create(primitive.release());
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }

  // Forward every TF input edge unchanged; output count mirrors input count.
  *output_size = tf_op.input_size();
  for (int i = 0; i < tf_op.input_size(); i++) {
    inputs->emplace_back(tf_op.input(i));
  }

  return RET_OK;
}
TFNodeRegistrar g_tfMergeParser("Merge", new TFMergeParser());
} // namespace lite
} // namespace mindspore

+ 38
- 0
mindspore/lite/tools/converter/parser/tf/tf_merge_parser.h View File

@@ -0,0 +1,38 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MERGE_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MERGE_PARSER_H_

#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
// Parses the TensorFlow "Merge" control-flow node (forwards whichever of its
// inputs becomes available) into the schema Merge primitive.
class TFMergeParser : public TFNodeParser {
 public:
  TFMergeParser() = default;
  ~TFMergeParser() override = default;

  // Fills *primitiveC with the parsed primitive, *inputs with the TF input
  // tensor names and *output_size with the op's output count.
  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore

#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MERGE_PARSER_H_

+ 48
- 11
mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc View File

@@ -277,7 +277,7 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::NodeDef &node_def, co
}
param_value->SetTensorData(tensor_data, shape_size * sizeof(int32_t));
} else {
MS_LOG(ERROR) << "Unsupport dataType: " << type;
MS_LOG(ERROR) << "Unsupported dataType: " << type;
return RET_ERROR;
}

@@ -417,6 +417,16 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
MS_LOG(ERROR) << "Convert ops failed.";
return nullptr;
}

if (!nodes_with_null_input_.empty()) {
status = ConnectNullInput();
if (status != RET_OK) {
MS_LOG(ERROR) << "Connect null inputs failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
}

status = ConvertRootGraphOutputs();
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert graph outputs failed.";
@@ -474,16 +484,16 @@ STATUS TFModelParser::ConvertSubgraph() {
std::vector<ParameterPtr> sub_graph_inputs;
for (int j = 0; j < input_arg_size; j++) {
auto &input_arg = tf_sub_signature.input_arg(j);
auto paramter = sub_func_graph->add_parameter();
paramter->set_name(input_arg.name());
anf_sub_node_map[input_arg.name()] = paramter;
auto parameter = sub_func_graph->add_parameter();
parameter->set_name(input_arg.name());
anf_sub_node_map[input_arg.name()] = parameter;
auto root_inputs = cnode->inputs();
if (op_type == schema::PrimitiveType_While) {
paramter->set_abstract(root_inputs[j + 1]->abstract());
parameter->set_abstract(root_inputs[j + 1]->abstract());
} else {
paramter->set_abstract(root_inputs[j + 2]->abstract());
parameter->set_abstract(root_inputs[j + 2]->abstract());
}
sub_graph_inputs.emplace_back(paramter);
sub_graph_inputs.emplace_back(parameter);
}
std::map<std::string, const tensorflow::NodeDef *> tf_sub_node_map;
for (int j = 0; j < tf_sub_fuction.node_def_size(); j++) {
@@ -643,7 +653,8 @@ STATUS TFModelParser::ConvertInputNodes(const tensorflow::NodeDef &node_def,
const std::vector<std::string> &input_names,
const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map,
const std::unordered_map<std::string, AnfNodePtr> &anf_node_map,
std::vector<AnfNodePtr> *inputs) {
std::vector<AnfNodePtr> *inputs,
std::vector<std::string> *input_name_not_found) {
MS_ASSERT(node_def != nullptr);
// parse inputs
for (size_t j = 0; j < input_names.size(); j++) {
@@ -656,8 +667,8 @@ STATUS TFModelParser::ConvertInputNodes(const tensorflow::NodeDef &node_def,
}
auto input = GetAnfNode(flatten_input_name, anf_node_map);
if (input == nullptr) {
MS_LOG(ERROR) << node_def.name() << " input " << j << ": " << input_name << " can't find parsed in_nodes";
return RET_ERROR;
MS_LOG(WARNING) << node_def.name() << " input " << j << ": " << input_name << " can't find parsed in_nodes";
(*input_name_not_found).push_back(flatten_input_name);
}
inputs->emplace_back(input);
}
@@ -718,6 +729,27 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C
return RET_OK;
}

// Remembers a cnode that has at least one unresolved (nullptr) input so it
// can be patched once all graph nodes have been converted.  Needed because TF
// while loops are cyclic: NextIteration feeds a Merge that was parsed first.
STATUS TFModelParser::RecordNullInput(const CNodePtr &node, const std::vector<std::string> &input_name_not_found) {
  nodes_with_null_input_.emplace_back(node, input_name_not_found);
  return RET_OK;
}

// Second pass: replaces every nullptr input recorded above with the
// now-parsed node looked up by name in the root-graph node map.  The k-th
// nullptr slot corresponds to the k-th recorded missing name.
STATUS TFModelParser::ConnectNullInput() {
  for (auto &it : nodes_with_null_input_) {
    auto &cnode = it.first;
    auto &input_name_not_found = it.second;
    auto &inputs = cnode->inputs();
    int i = 0;
    for (size_t j = 0; j < inputs.size(); ++j) {
      if (inputs[j] == nullptr) {
        // NOTE(review): GetAnfNode may still return nullptr if the name never
        // parsed — confirm downstream passes tolerate that.
        cnode->set_input(j, GetAnfNode(input_name_not_found[i], anf_root_node_map_));
        ++i;
      }
    }
  }
  return RET_OK;
}

STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map,
const FuncGraphPtr &func_graph_ptr,
@@ -752,7 +784,8 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
return RET_ERROR;
}
std::vector<AnfNodePtr> inputs = {value_node};
status = ConvertInputNodes(node_def, input_names, tf_node_map, *anf_node_map, &inputs);
std::vector<std::string> input_name_not_found{};
status = ConvertInputNodes(node_def, input_names, tf_node_map, *anf_node_map, &inputs, &input_name_not_found);
if (status != RET_OK) {
return status;
}
@@ -787,6 +820,10 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
}
}

if (!input_name_not_found.empty()) {
RecordNullInput(anf_node, input_name_not_found);
}

status = ConvertOutputTensor(node_def, anf_node, anf_node_map, func_graph_ptr, output_size);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed.";


+ 7
- 1
mindspore/lite/tools/converter/parser/tf/tf_model_parser.h View File

@@ -22,6 +22,7 @@
#include <string>
#include <unordered_map>
#include <vector>
#include <utility>
#include "proto/graph.pb.h"
#include "proto/node_def.pb.h"
#include "schema/inner/model_generated.h"
@@ -55,7 +56,7 @@ class TFModelParser : public ModelParser {
STATUS ConvertInputNodes(const tensorflow::NodeDef &node_def, const std::vector<std::string> &input_names,
const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map,
const std::unordered_map<std::string, AnfNodePtr> &anf_node_map,
std::vector<AnfNodePtr> *inputs);
std::vector<AnfNodePtr> *inputs, std::vector<std::string> *input_name_not_found);
STATUS ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node,
std::unordered_map<std::string, AnfNodePtr> *anf_node_map, const FuncGraphPtr &anf_graph,
int output_size);
@@ -71,6 +72,10 @@ class TFModelParser : public ModelParser {

STATUS MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes, const FuncGraphPtr &anf_graph);

STATUS RecordNullInput(const CNodePtr &node, const std::vector<std::string> &input_name_not_found);

STATUS ConnectNullInput();

FuncGraphPtr anf_root_graph_;
std::unique_ptr<tensorflow::GraphDef> tf_root_graph_; // tf root graph def
std::map<std::string, const tensorflow::NodeDef *> tf_root_graph_nodes_; // tf root graph node map
@@ -79,6 +84,7 @@ class TFModelParser : public ModelParser {
std::vector<std::string> graph_output_names_;
std::map<std::string, AnfNodePtr> function_while_map_; // tf function name->while_node_name
std::map<std::string, AnfNodePtr> function_if_map_; // tf function name->if_node
std::vector<std::pair<CNodePtr, std::vector<std::string>>> nodes_with_null_input_{};
};
} // namespace lite
} // namespace mindspore


+ 49
- 0
mindspore/lite/tools/converter/parser/tf/tf_next_iteration_parser.cc View File

@@ -0,0 +1,49 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_next_iteration_parser.h"
#include <string>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/next_iteration.h"

namespace mindspore {
namespace lite {
STATUS TFNextIterationParser::Parse(const tensorflow::NodeDef &tf_op,
                                    const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
                                    PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
  MS_LOG(INFO) << "TF NextIterationParser";
  // Validate every out-parameter before use; `inputs` is dereferenced below
  // and previously escaped the null check.
  if (primitiveC == nullptr || inputs == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC or inputs or output_size is nullptr";
    return RET_NULL_PTR;
  }

  // NextIteration is a converter-internal op (no schema primitive); ownership
  // of the allocation transfers to the caller through *primitiveC.
  *primitiveC = new (std::nothrow) NextIteration();
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "new NextIteration failed";
    return RET_ERROR;
  }

  // Forward every TF input edge unchanged; output count mirrors input count.
  *output_size = tf_op.input_size();
  for (int i = 0; i < tf_op.input_size(); i++) {
    inputs->emplace_back(tf_op.input(i));
  }

  return RET_OK;
}
TFNodeRegistrar g_tfNextIterationParser("NextIteration", new TFNextIterationParser());
} // namespace lite
} // namespace mindspore

+ 38
- 0
mindspore/lite/tools/converter/parser/tf/tf_next_iteration_parser.h View File

@@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_NEXT_ITERATION_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_NEXT_ITERATION_PARSER_H_

#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
// Parses the TensorFlow "NextIteration" control-flow node (carries a value to
// the next loop iteration) into the converter-internal NextIteration
// primitive.
class TFNextIterationParser : public TFNodeParser {
 public:
  TFNextIterationParser() = default;
  ~TFNextIterationParser() override = default;

  // Fills *primitiveC with the parsed primitive, *inputs with the TF input
  // tensor names and *output_size with the op's output count.
  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore

#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_NEXT_ITERATION_PARSER_H_

+ 62
- 0
mindspore/lite/tools/converter/parser/tf/tf_switch_parser.cc View File

@@ -0,0 +1,62 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_switch_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
STATUS TFSwitchParser::Parse(const tensorflow::NodeDef &tf_op,
                             const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
                             std::vector<std::string> *inputs, int *output_size) {
  MS_LOG(INFO) << "TF SwitchParser";
  // Validate every out-parameter before use; `inputs` is dereferenced below
  // and previously escaped the null check.
  if (primitiveC == nullptr || inputs == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC or inputs or output_size is nullptr";
    return RET_NULL_PTR;
  }

  // Switch is a schema primitive: build the flatbuffer object representation
  // and hand ownership down through release() at each step.
  auto primitive = std::make_unique<schema::PrimitiveT>();
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "primitive is nullptr";
    return RET_NULL_PTR;
  }
  auto attr = std::make_unique<schema::SwitchT>();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new op failed";
    return RET_NULL_PTR;
  }

  primitive->value.type = schema::PrimitiveType_Switch;
  primitive->value.value = attr.release();
  *primitiveC = PrimitiveC::Create(primitive.release());
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }

  // Forward every TF input edge unchanged; output count mirrors input count.
  *output_size = tf_op.input_size();
  for (int i = 0; i < tf_op.input_size(); i++) {
    inputs->emplace_back(tf_op.input(i));
  }

  return RET_OK;
}
TFNodeRegistrar g_tfSwitchParser("Switch", new TFSwitchParser());
} // namespace lite
} // namespace mindspore

+ 38
- 0
mindspore/lite/tools/converter/parser/tf/tf_switch_parser.h View File

@@ -0,0 +1,38 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SWITCH_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SWITCH_PARSER_H_

#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
// Parses the TensorFlow "Switch" control-flow node (routes its data input to
// one of two outputs depending on a predicate) into the schema Switch
// primitive.
class TFSwitchParser : public TFNodeParser {
 public:
  TFSwitchParser() = default;
  ~TFSwitchParser() override = default;

  // Fills *primitiveC with the parsed primitive, *inputs with the TF input
  // tensor names and *output_size with the op's output count.
  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore

#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SWITCH_PARSER_H_

+ 140
- 0
mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.cc View File

@@ -0,0 +1,140 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
 *
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <deque>
#include "tools/optimizer/graph/functionalize_control_op_pass.h"
#include "tools/optimizer/graph/functionalize_while.h"
#include "mindspore/lite/include/errorcode.h"
#include "src/ops/primitive_c.h"

namespace mindspore::opt {

// Creates an empty FuncGraph tagged with the subgraph name and the source
// framework type so downstream passes can identify it.
FuncGraphPtr FunctionalizeControlOpPass::NewFuncGraph(const std::string &subgraph_name, const FmkType &fmk_type) {
  auto fg = std::make_shared<FuncGraph>();
  if (fg == nullptr) {
    // Fixed garbled log text ("new func)graph failed.").
    MS_LOG(ERROR) << "new func_graph failed.";
    return nullptr;
  }
  fg->set_attr("graph_name", MakeValue(subgraph_name));
  fg->set_attr("fmk", MakeValue(static_cast<int>(fmk_type)));
  return fg;
}

std::string FunctionalizeControlOpPass::NodeClusterName(const AnfNodePtr &node) {
std::string cluster_name{};
// tf node name use '/' split node name
auto cnode = utils::cast<CNodePtr>(node);
size_t pos = cnode->fullname_with_scope().rfind('/');
if (pos != std::string::npos) {
cluster_name = cnode->fullname_with_scope().substr(0, pos);
} else {
cluster_name = cnode->fullname_with_scope();
}
return cluster_name;
}

// Groups every graph node into node_clusters_ by its cluster name, preserving
// insertion order of clusters and of nodes within a cluster.
void FunctionalizeControlOpPass::InitNodeClusters(const FuncGraphPtr &func_graph) {
  for (const auto &anf_node : func_graph->nodes()) {
    const auto name = NodeClusterName(anf_node);
    const auto idx = WhichCluster(name);
    if (idx < node_clusters_.size()) {
      // Existing cluster: append the node.
      node_clusters_[idx].second.push_back(anf_node);
    } else {
      // Unseen cluster name: start a new cluster holding this node.
      node_clusters_.emplace_back(name, std::vector<AnfNodePtr>{anf_node});
    }
  }
}

// Linear scan for `cluster_name` in node_clusters_; returns its index, or
// node_clusters_.size() when the name is not present.
size_t FunctionalizeControlOpPass::WhichCluster(const std::string &cluster_name) {
  const size_t cluster_count = node_clusters_.size();
  size_t index = 0;
  while (index < cluster_count && node_clusters_[index].first != cluster_name) {
    ++index;
  }
  return index;
}

// Each LoopCond node marks one TF while loop.  For every LoopCond found in a
// cluster, run FunctionalizeWhile over that cluster to rewrite the raw
// control-flow ops into a functional While node.  Stops at the first failure.
STATUS FunctionalizeControlOpPass::BuildWhileSubgraph(const FuncGraphPtr &func_graph) {
  int ret = RET_OK;
  for (auto &node_cluster : node_clusters_) {
    for (auto &node : node_cluster.second) {
      if (IsLoopCond(node)) {
        loop_cond_nodes_.push_back(node->cast<CNodePtr>());
        FunctionalizeWhile fw(node_cluster.second, node->cast<CNodePtr>(), func_graph);
        ret = fw.Process();
        if (ret != RET_OK) {
          MS_LOG(ERROR) << "run functionalize while failed, ret: " << ret;
          return ret;
        }
      }
    }
  }
  return ret;
}

// Pass entry point: cluster nodes by name prefix to recover each TF loop
// frame, then rewrite every frame's raw control-flow ops into a functional
// While node.
bool FunctionalizeControlOpPass::Run(const FuncGraphPtr &func_graph) {
  // use name to find the frame
  InitNodeClusters(func_graph);
  const auto status = BuildWhileSubgraph(func_graph);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "build while subgraph failed.";
    return false;
  }
  return true;
}
// BFS backwards from `node` through its inputs; returns the first node
// accepted by `func`, or nullptr when none is reachable.
// NOTE(review): while-loop graphs may be cyclic (Merge <- NextIteration); a
// visited set would guard against re-traversal — confirm whether callers can
// hit a cycle before `func` matches.
CNodePtr FunctionalizeControlOpPass::BelongToWhichNode(const CNodePtr &node, const FilterFunc &func) {
  if (node == nullptr) {
    return nullptr;
  }
  if (func(node)) {
    return node;
  }
  CNodePtr aim_node = nullptr;
  // Default-construct the work queue: constructing with 256 elements and
  // immediately clearing allocated and destroyed 256 null entries for nothing.
  std::deque<AnfNodePtr> todo;
  for (auto &input_node : node->inputs()) {
    if (func(input_node)) {
      aim_node = utils::cast<CNodePtr>(input_node);
      todo.clear();
      break;
    }
    todo.push_back(input_node);
  }

  while (!todo.empty()) {
    AnfNodePtr todo_node = todo.front();
    todo.pop_front();
    if (func(todo_node)) {
      aim_node = utils::cast<CNodePtr>(todo_node);
      todo.clear();
      break;
    }
    if (utils::isa<CNodePtr>(todo_node)) {
      auto cnode = utils::cast<CNodePtr>(todo_node);
      for (size_t i = 0; i < cnode->inputs().size(); i++) {
        todo.push_back(cnode->input(i));
      }
    }
  }
  if (aim_node == nullptr) {
    MS_LOG(WARNING) << "not found belonging enter node.";
    return nullptr;
  }

  return aim_node;
}
} // namespace mindspore::opt

+ 71
- 0
mindspore/lite/tools/optimizer/graph/functionalize_control_op_pass.h View File

@@ -0,0 +1,71 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
 *
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_PASS_FUNCTIONALIZE_CONTROL_OP_PASS_H_
#define MINDSPORE_LITE_SRC_PASS_FUNCTIONALIZE_CONTROL_OP_PASS_H_
#include <string>
#include <set>
#include <utility>
#include <vector>
#include <memory>
#include "backend/optimizer/common/pass.h"
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/gllo_utils.h"

using mindspore::lite::converter::FmkType;
namespace mindspore::opt {
// Rewrites raw TF control-flow ops (Enter/Merge/Switch/Exit/NextIteration/
// LoopCond) into functional control-flow nodes (e.g. While with cond/body
// subgraphs).  Nodes are grouped into "clusters" by name prefix: in TF, the
// nodes of one loop frame share the scope before the last '/'.
class FunctionalizeControlOpPass : public Pass {
 public:
  FunctionalizeControlOpPass() : Pass("functionalize_control_op_pass") {}
  ~FunctionalizeControlOpPass() override = default;
  bool Run(const FuncGraphPtr &graph) override;
  // Creates an empty FuncGraph tagged with name and framework attributes.
  static FuncGraphPtr NewFuncGraph(const std::string &subgraph_name, const FmkType &fmk_type);
  // Type predicates for the control-flow primitives.  The converter-only ops
  // (Enter/Exit/LoopCond/NextIteration) live outside the schema enum, so
  // their comparisons cast both sides to int.
  static bool IsMerge(const AnfNodePtr &node) { return opt::GetCNodeType(node) == schema::PrimitiveType_Merge; }
  static bool IsLoopCond(const AnfNodePtr &node) {
    return static_cast<int>(opt::GetCNodeType(node)) == static_cast<int>(lite::ConverterPrimitiveType_LoopCond);
  }
  static bool IsEnter(const AnfNodePtr &node) {
    return static_cast<int>(opt::GetCNodeType(node)) == static_cast<int>(lite::ConverterPrimitiveType_Enter);
  }
  static bool IsExit(const AnfNodePtr &node) {
    return static_cast<int>(opt::GetCNodeType(node)) == static_cast<int>(lite::ConverterPrimitiveType_Exit);
  }
  static bool IsSwitch(const AnfNodePtr &node) { return opt::GetCNodeType(node) == schema::PrimitiveType_Switch; }
  static bool IsNextIteration(const AnfNodePtr &node) {
    return static_cast<int>(opt::GetCNodeType(node)) == static_cast<int>(lite::ConverterPrimitiveType_NextIteration);
  }
  static bool IsControlFlowOp(const AnfNodePtr &node) {
    return IsLoopCond(node) || IsEnter(node) || IsMerge(node) || IsSwitch(node) || IsExit(node) ||
           IsNextIteration(node);
  }
  // BFS backwards from `node` through its inputs; returns the first node
  // accepted by `func`, or nullptr when none is reachable.
  static CNodePtr BelongToWhichNode(const CNodePtr &node, const FilterFunc &func);
  // Hands out process-wide unique subgraph indices, starting at 1.
  static int GetSubgraphIndex() {
    static int subgraph_index = 1;
    return subgraph_index++;
  }
  // The names of nodes with the same prefix are a cluster.
  static std::string NodeClusterName(const AnfNodePtr &node);
  // Partitions all graph nodes into node_clusters_ by cluster name.
  void InitNodeClusters(const FuncGraphPtr &func_graph);
  // return the position in node_clusters_
  size_t WhichCluster(const std::string &cluster_name);

 protected:
  // Runs FunctionalizeWhile for every LoopCond node found in each cluster.
  STATUS BuildWhileSubgraph(const FuncGraphPtr &func_graph);
  std::vector<std::pair<std::string, std::vector<AnfNodePtr>>> node_clusters_{};
  std::vector<CNodePtr> loop_cond_nodes_{};
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_SRC_PASS_FUNCTIONALIZE_CONTROL_OP_PASS_H_

+ 521
- 0
mindspore/lite/tools/optimizer/graph/functionalize_while.cc View File

@@ -0,0 +1,521 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
 *
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <algorithm>
#include <memory>
#include <deque>
#include "tools/optimizer/graph/functionalize_while.h"
#include "mindspore/lite/include/errorcode.h"
#include "src/ops/primitive_c.h"
#include "src/ops/while.h"

namespace {
// Builds a ValueNode holding a While primitive with freshly allocated
// cond/body subgraph indices.  Returns nullptr on allocation failure.
mindspore::ValueNodePtr GetWhileAnfPrim() {
  auto while_primitiveT = new (std::nothrow) mindspore::schema::PrimitiveT;
  if (while_primitiveT == nullptr) {
    MS_LOG(ERROR) << "new while_primitiveT failed";
    return nullptr;
  }
  while_primitiveT->value.type = mindspore::schema::PrimitiveType_While;

  auto whileT = new (std::nothrow) mindspore::schema::WhileT;
  if (whileT == nullptr) {
    // Guard against allocation failure BEFORE the first dereference (the
    // subgraph indices were previously assigned before any null check).
    MS_LOG(ERROR) << "new WhileT failed";
    delete while_primitiveT;
    return nullptr;
  }
  whileT->condSubgraphIndex = mindspore::opt::FunctionalizeControlOpPass::GetSubgraphIndex();
  whileT->bodySubgraphIndex = mindspore::opt::FunctionalizeControlOpPass::GetSubgraphIndex();
  while_primitiveT->value.value = whileT;

  // The While wrapper takes ownership of while_primitiveT.
  auto while_prim = std::make_shared<mindspore::lite::While>(while_primitiveT);
  mindspore::ValueNodePtr partial_anf_prim = NewValueNode(while_prim);
  return partial_anf_prim;
}
} // namespace

namespace mindspore::opt {

using mindspore::lite::RET_NULL_PTR;

// NOTE(review): "Blong" looks like a typo of "Belong"; the names are kept
// because renaming would break callers.
// Traces `node` backwards through its inputs to the Switch it derives from.
CNodePtr FunctionalizeWhile::BlongToWhichSwitch(const CNodePtr &node) {
  return FunctionalizeControlOpPass::BelongToWhichNode(node, FunctionalizeControlOpPass::IsSwitch);
}
// Traces `node` backwards through its inputs to the Merge it derives from.
CNodePtr FunctionalizeWhile::BlongToWhichMerge(const CNodePtr &node) {
  return FunctionalizeControlOpPass::BelongToWhichNode(node, FunctionalizeControlOpPass::IsMerge);
}
// Traces `node` backwards through its inputs to the Enter it derives from.
CNodePtr FunctionalizeWhile::BlongToWhichEnter(const CNodePtr &node) {
  return FunctionalizeControlOpPass::BelongToWhichNode(node, FunctionalizeControlOpPass::IsEnter);
}

// Returns the index of `node` within input_enter_nodes_, or -1 (with a
// warning) when it is absent.
int FunctionalizeWhile::PosInInputEnterNodes(const CNodePtr &node) {
  const auto first = input_enter_nodes_.begin();
  const auto last = input_enter_nodes_.end();
  const auto found = std::find(first, last, node);
  if (found == last) {
    MS_LOG(WARNING) << node->fullname_with_scope() << " is not in input_enter_nodes_";
    return -1;
  }
  return found - first;
}

// Creates the (initially input-less) functional while cnode in the main graph.
// Inputs are attached later by IdentifyWhileNodeInput and
// InsertFuncGraphToWhileInput.
STATUS FunctionalizeWhile::NewWhileNode() {
  ValueNodePtr while_anf_primitive = GetWhileAnfPrim();
  if (while_anf_primitive == nullptr) {
    MS_LOG(ERROR) << "Get while anf primitive failed.";
    return RET_NULL_PTR;
  }

  // monotonically increasing suffix keeps generated while-node names unique
  static int count = 0;
  std::vector<AnfNodePtr> while_op_inputs = {while_anf_primitive};
  while_node_ = fg_->NewCNode(while_op_inputs);
  // guard the allocation, matching the file's convention elsewhere
  if (while_node_ == nullptr) {
    MS_LOG(ERROR) << "new while node failed.";
    return RET_NULL_PTR;
  }
  while_node_->set_fullname_with_scope(loop_cond_node_->fullname_with_scope() + "-while-" + std::to_string(count++));
  return RET_OK;
}

// Collects every Enter op in the cluster: each one contributes one input to
// the functional while node (the Enter's own input). Fails when the cluster
// contains no Enter at all.
STATUS FunctionalizeWhile::IdentifyWhileNodeInput() {
  for (auto &node : node_cluster_) {
    if (!FunctionalizeControlOpPass::IsEnter(node)) {
      continue;
    }
    auto enter_cnode = node->cast<CNodePtr>();
    input_enter_nodes_.push_back(enter_cnode);
    while_node_->add_input(enter_cnode->input(1));
  }
  if (input_enter_nodes_.empty()) {
    MS_LOG(ERROR) << "not found input of while node.";
    return RET_ERROR;
  }
  return RET_OK;
}

// Determines the while node outputs from the cluster's Exit nodes and sets
// the while node's abstract. Each Exit is stored at the position of the Enter
// it chains back to, so output order mirrors input order.
STATUS FunctionalizeWhile::IdentifyWhileNodeOutput() {
  output_exit_nodes_.resize(input_enter_nodes_.size());
  for (auto &node : node_cluster_) {
    // exit ->switch->merge->enter
    if (FunctionalizeControlOpPass::IsExit(node)) {
      auto exit_node = node->cast<CNodePtr>();
      // walk the control-op chain back to the Enter to find the output slot
      auto switch_node = BlongToWhichSwitch(exit_node);
      auto merge_node = BlongToWhichMerge(switch_node);
      auto enter_node = BlongToWhichEnter(merge_node);
      int pos = PosInInputEnterNodes(enter_node);
      if (pos == -1) {
        MS_LOG(ERROR) << "not find in input enter nodes.";
        return RET_ERROR;
      }
      output_exit_nodes_.at(pos) = exit_node;
    }
  }

  // single output: reuse the exit's abstract; several: wrap them in a tuple
  if (output_exit_nodes_.size() == 1) {
    while_node_->set_abstract(output_exit_nodes_[0]->abstract());
  } else {
    AbstractBasePtrList abstract_list;
    abstract_list.resize(output_exit_nodes_.size());
    std::transform(output_exit_nodes_.begin(), output_exit_nodes_.end(), abstract_list.begin(),
                   [](const CNodePtr &cnode) { return cnode->abstract(); });
    while_node_->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
  }
  return RET_OK;
}

// Re-points every user of each Exit node to the corresponding while output:
// directly when the while has a single output, otherwise through a
// TupleGetItem(while, idx) node.
STATUS FunctionalizeWhile::UpdateExitNodeUser() {
  auto manager = fg_->manager();
  if (output_exit_nodes_.size() == 1) {
    auto node_users = manager->node_users()[output_exit_nodes_[0]];
    for (auto &node_user : node_users) {
      // only redirect users that live in the main graph
      if (fg_->nodes().contains(node_user.first)) {
        manager->SetEdge(node_user.first, node_user.second, while_node_);
      }
    }
    return RET_OK;
  }
  for (auto &exit_node : output_exit_nodes_) {
    // exit -> switch -> merge -> enter identifies the tuple output index;
    // this is invariant per exit node, so compute it once (the original code
    // rebuilt an identically-named TupleGetItem for every user)
    auto switch_node = BlongToWhichSwitch(exit_node);
    auto merge_node = BlongToWhichMerge(switch_node);
    auto enter_node = BlongToWhichEnter(merge_node);
    int output_idx = PosInInputEnterNodes(enter_node);
    if (output_idx == -1) {
      MS_LOG(ERROR) << "not find in input enter nodes.";
      return RET_ERROR;
    }
    auto tuple_get_item_prim_ptr = lite::GetTupleGetItemPrim();
    if (tuple_get_item_prim_ptr == nullptr) {
      MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr";
      return RET_NULL_PTR;
    }
    auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
    auto getItemValue = NewValueNode(MakeValue<int>(output_idx));
    std::vector<AnfNodePtr> inputs{tuple_get_item_prim, while_node_, getItemValue};
    // one shared TupleGetItem per exit node
    CNodePtr get_item_node = fg_->NewCNode(inputs);
    if (get_item_node == nullptr) {
      MS_LOG(ERROR) << "new TupleGetItem node failed";
      return RET_NULL_PTR;
    }
    // placeholder abstract; NOTE(review): assumes float32 with unknown shape
    // (as the original did) — infer-shape presumably refines it later, confirm
    std::vector<int64_t> shape_vector;
    auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector);
    if (abstract == nullptr) {
      MS_LOG(ERROR) << "create AbstractTensor failed";
      return RET_NULL_PTR;
    }
    get_item_node->set_abstract(abstract);
    get_item_node->set_fullname_with_scope(while_node_->fullname_with_scope() + "_getitem_" +
                                           std::to_string(output_idx));
    auto node_users = manager->node_users()[exit_node];
    for (auto &node_user : node_users) {
      // only redirect users that live in the main graph
      if (fg_->nodes().contains(node_user.first)) {
        manager->SetEdge(node_user.first, node_user.second, get_item_node);
      }
    }
  }
  return RET_OK;
}

// Builds the functional while node in four steps: allocate the cnode, wire
// its inputs from the Enter ops, derive its outputs from the Exit ops, then
// redirect the Exit users onto the new node.
STATUS FunctionalizeWhile::BuildWhileNode() {
  int status = NewWhileNode();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "new while node failed, ret:" << status;
    return status;
  }
  if ((status = IdentifyWhileNodeInput()) != RET_OK) {
    MS_LOG(ERROR) << "identify while node input failed, ret:" << status;
    return status;
  }
  if ((status = IdentifyWhileNodeOutput()) != RET_OK) {
    MS_LOG(ERROR) << "identify while node output failed, ret:" << status;
    return status;
  }
  // update exit node user from exit to while
  if ((status = UpdateExitNodeUser()) != RET_OK) {
    MS_LOG(ERROR) << "update while node users, ret:" << status;
    return status;
  }
  return status;
}

// nodes between loop_cond op and merge op be added into cond_func_graph
// nodes between loop_cond op and merge op be added into cond_func_graph
STATUS FunctionalizeWhile::CondSubgraphAddNodes() {
  // depth-first walk backwards from the LoopCond inputs; Merge ops mark the
  // subgraph boundary and are not pulled in.
  // (original code did `std::deque<AnfNodePtr> todo(512); todo.clear();`,
  // constructing and immediately discarding 512 null elements)
  std::deque<AnfNodePtr> todo;
  for (size_t i = 1; i < loop_cond_node_->inputs().size(); i++) {
    todo.push_back(loop_cond_node_->input(i));
  }
  while (!todo.empty()) {
    AnfNodePtr node = todo.back();
    todo.pop_back();
    if (FunctionalizeControlOpPass::IsMerge(node)) {
      continue;
    }
    // parameters and cnodes are registered differently on the func graph
    if (utils::isa<ParameterPtr>(node)) {
      cond_sub_func_graph_->add_parameter(node->cast<ParameterPtr>());
    } else {
      cond_sub_func_graph_->AddNode(node);
    }
    node->set_func_graph(cond_sub_func_graph_);
    if (utils::isa<CNodePtr>(node)) {
      auto cnode = utils::cast<CNodePtr>(node);
      for (size_t i = 1; i < cnode->inputs().size(); i++) {
        todo.push_back(cnode->input(i));
      }
    }
  }
  return RET_OK;
}

// Turns each Merge feeding the cond subgraph into a formal parameter whose
// index mirrors the matching Enter's position among the while inputs, so the
// cond subgraph's parameters line up with the while node's inputs.
STATUS FunctionalizeWhile::IdentifyCondSubgraphInput() {
  std::vector<AnfNodePtr> nodes_need_drop{};
  for (auto &cnode : cond_sub_func_graph_->GetOrderedCnodes()) {
    for (auto &input_node : cnode->inputs()) {
      if (FunctionalizeControlOpPass::IsMerge(input_node)) {
        // merge -> enter gives the while-input position
        auto merge_node = input_node->cast<CNodePtr>();
        auto enter_node = BlongToWhichEnter(merge_node);
        int pos = PosInInputEnterNodes(enter_node);
        // NOTE(review): pos may be -1 here and would end up embedded in the
        // parameter name below — confirm whether this should bail out like
        // IdentifyWhileNodeOutput does.
        nodes_need_drop.push_back(cnode);

        // set parameter
        auto parameter = cond_sub_func_graph_->add_parameter();
        parameter->set_abstract(cnode->abstract());
        // hardcode for subgraph input name
        parameter->set_name(cond_subgraph_name_ + "_input_" + std::to_string(pos) + "_parameter");

        // replace merge: re-point in-subgraph users of this cnode to the new
        // parameter; the cnode itself is dropped below
        auto manager = fg_->manager();
        auto node_users = manager->node_users()[cnode];
        for (auto &node_user : node_users) {
          if (cond_sub_func_graph_->nodes().contains(node_user.first)) {
            manager->SetEdge(node_user.first, node_user.second, parameter);
          }
        }
      }
    }
  }

  // drop node from cond_func_graph
  for (const auto &node : nodes_need_drop) {
    cond_sub_func_graph_->DropNode(node);
  }
  return RET_OK;
}

// Attaches a Return node to the cond subgraph; the subgraph's single output
// is the value that fed the LoopCond op.
STATUS FunctionalizeWhile::IdentifyCondSubgraphOutput() {
  auto return_prim_ptr = lite::GetReturnPrim();
  if (return_prim_ptr == nullptr) {
    MS_LOG(ERROR) << "GetReturnPrim return nullptr";
    return RET_NULL_PTR;
  }
  auto value_node = NewValueNode(return_prim_ptr);
  if (value_node == nullptr) {
    MS_LOG(ERROR) << "new value_node failed.";
    return RET_NULL_PTR;
  }
  // cond subgraph output is LoopCond's input
  std::vector<AnfNodePtr> op_inputs{value_node, loop_cond_node_->input(1)};
  auto return_cnode = cond_sub_func_graph_->NewCNode(op_inputs);
  // guard the allocation, matching the value_node check above
  if (return_cnode == nullptr) {
    MS_LOG(ERROR) << "new return cnode failed.";
    return RET_NULL_PTR;
  }
  return_cnode->set_fullname_with_scope(cond_subgraph_name_ + "-return");
  cond_sub_func_graph_->set_return(return_cnode);

  // hardcode subgraph outputs name
  cond_sub_func_graph_->output()->cast<CNodePtr>()->set_fullname_with_scope(cond_subgraph_name_ + "_output_0_cnode");
  return RET_OK;
}

// Assembles the cond subgraph: create the graph, pull in its nodes, then
// identify its output and finally its inputs (same ordering as the original).
STATUS FunctionalizeWhile::BuildCondGraph() {
  cond_subgraph_name_ = FunctionalizeControlOpPass::NodeClusterName(loop_cond_node_) + "_cond";
  cond_sub_func_graph_ =
    FunctionalizeControlOpPass::NewFuncGraph(cond_subgraph_name_, mindspore::lite::converter::FmkType_TF);
  if (cond_sub_func_graph_ == nullptr) {
    MS_LOG(ERROR) << "new cond_sub_func_graph_ return nullptr";
    return RET_NULL_PTR;
  }
  // the subgraph shares the main graph's manager
  cond_sub_func_graph_->set_manager(fg_->manager());

  int status = CondSubgraphAddNodes();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "add cond_subgraph node failed, ret:" << status;
    return status;
  }
  if ((status = IdentifyCondSubgraphOutput()) != RET_OK) {
    MS_LOG(ERROR) << "identify cond_subgraph output failed, ret:" << status;
    return status;
  }
  if ((status = IdentifyCondSubgraphInput()) != RET_OK) {
    MS_LOG(ERROR) << "identify cond_subgraph input failed, ret:" << status;
    return status;
  }
  return status;
}

// nodes between next_iteration op and switch op will be added into body_func_graph
// nodes between next_iteration op and switch op will be added into body_func_graph
STATUS FunctionalizeWhile::BodySubgraphAddNodes() {
  // depth-first walk backwards from every NextIteration input; Switch ops
  // mark the subgraph boundary and are not pulled in.
  // (original code did `std::deque<AnfNodePtr> todo(512); todo.clear();`,
  // constructing and immediately discarding 512 null elements)
  std::deque<AnfNodePtr> todo;
  for (auto &node : node_cluster_) {
    if (FunctionalizeControlOpPass::IsNextIteration(node)) {
      auto next_iteration_cnode = node->cast<CNodePtr>();
      for (size_t i = 1; i < next_iteration_cnode->inputs().size(); i++) {
        todo.push_back(next_iteration_cnode->input(i));
      }
      // remember the value each NextIteration feeds back; it becomes a body
      // subgraph output in IdentifyBodySubgraphOutput
      body_subgraph_output_map_[node] = next_iteration_cnode->input(1);
    }
  }

  while (!todo.empty()) {
    AnfNodePtr node = todo.back();
    todo.pop_back();
    if (FunctionalizeControlOpPass::IsSwitch(node)) {
      continue;
    }
    // parameters and cnodes are registered differently on the func graph
    if (utils::isa<ParameterPtr>(node)) {
      body_sub_func_graph_->add_parameter(node->cast<ParameterPtr>());
    } else {
      body_sub_func_graph_->AddNode(node);
    }
    node->set_func_graph(body_sub_func_graph_);
    if (utils::isa<CNodePtr>(node)) {
      auto cnode = utils::cast<CNodePtr>(node);
      for (size_t i = 1; i < cnode->inputs().size(); i++) {
        todo.push_back(cnode->input(i));
      }
    }
  }
  return RET_OK;
}

// Turns each Switch feeding the body subgraph into a formal parameter whose
// index mirrors the matching Enter's position among the while inputs, so the
// body subgraph's parameters line up with the while node's inputs.
STATUS FunctionalizeWhile::IdentifyBodySubgraphInput() {
  std::vector<AnfNodePtr> nodes_need_drop{};
  for (auto &cnode : body_sub_func_graph_->GetOrderedCnodes()) {
    for (auto &input_node : cnode->inputs()) {
      if (FunctionalizeControlOpPass::IsSwitch(input_node)) {
        // switch -> merge -> enter gives the while-input position
        auto switch_node = input_node->cast<CNodePtr>();
        auto merge_node = BlongToWhichMerge(switch_node);
        auto enter_node = BlongToWhichEnter(merge_node);
        int pos = PosInInputEnterNodes(enter_node);
        // NOTE(review): pos may be -1 here and would end up embedded in the
        // parameter name below — confirm whether this should bail out like
        // IdentifyWhileNodeOutput does.
        nodes_need_drop.push_back(cnode);

        // set parameter
        auto parameter = body_sub_func_graph_->add_parameter();
        parameter->set_abstract(cnode->abstract());
        // hardcode for subgraph input name
        parameter->set_name(body_subgraph_name_ + "_input_" + std::to_string(pos) + "_parameter");

        // replace switch: re-point in-subgraph users of this cnode to the new
        // parameter; the cnode itself is dropped below
        auto manager = fg_->manager();
        auto node_users = manager->node_users()[cnode];
        for (auto &node_user : node_users) {
          if (body_sub_func_graph_->nodes().contains(node_user.first)) {
            manager->SetEdge(node_user.first, node_user.second, parameter);
          }
        }
      }
    }
  }

  // drop node from body_func_graph
  for (const auto &node : nodes_need_drop) {
    body_sub_func_graph_->DropNode(node);
  }
  return RET_OK;
}

// Builds the body subgraph's Return: each NextIteration's feedback value
// becomes one output, placed at the position of the Enter it chains back to;
// several outputs are packed with MakeTuple.
STATUS FunctionalizeWhile::IdentifyBodySubgraphOutput() {
  std::vector<AnfNodePtr> tmp_output{};
  tmp_output.resize(input_enter_nodes_.size());
  // next_iteration -> switch -> merge -> enter
  for (auto &node_pair : body_subgraph_output_map_) {
    auto next_iteration_cnode = utils::cast<CNodePtr>(node_pair.first);
    auto switch_node = BlongToWhichSwitch(next_iteration_cnode);
    auto merge_node = BlongToWhichMerge(switch_node);
    auto enter_node = BlongToWhichEnter(merge_node);
    int pos = PosInInputEnterNodes(enter_node);
    // pos == -1 previously fell through to tmp_output[-1] — out-of-range write
    if (pos == -1) {
      MS_LOG(ERROR) << "not find in input enter nodes.";
      return RET_ERROR;
    }

    tmp_output[pos] = node_pair.second;
    // hard code. set cnode output name
    node_pair.second->cast<CNodePtr>()->set_fullname_with_scope(body_subgraph_name_ + "_output_" +
                                                                std::to_string(pos) + "_cnode");
  }

  auto return_prim_ptr = lite::GetReturnPrim();
  if (return_prim_ptr == nullptr) {
    MS_LOG(ERROR) << "GetReturnPrim return nullptr";
    return RET_NULL_PTR;
  }
  auto value_node = NewValueNode(return_prim_ptr);
  if (value_node == nullptr) {
    MS_LOG(ERROR) << "new value_node failed.";
    return RET_NULL_PTR;
  }
  // body subgraph outputs are the values fed into the NextIteration ops
  std::vector<AnfNodePtr> op_inputs{value_node};
  auto return_cnode = body_sub_func_graph_->NewCNode(op_inputs);
  if (return_cnode == nullptr) {
    MS_LOG(ERROR) << "new return cnode failed.";
    return RET_NULL_PTR;
  }
  return_cnode->set_fullname_with_scope(body_subgraph_name_ + "-return");

  if (tmp_output.size() == 1) {
    return_cnode->add_input(tmp_output[0]);
  } else {
    // several outputs -> pack them with MakeTuple before returning
    std::vector<AnfNodePtr> make_tuple_inputs = tmp_output;
    auto make_tuple_prim_ptr = lite::GetMakeTuplePrim();
    if (make_tuple_prim_ptr == nullptr) {
      MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr";
      return RET_NULL_PTR;
    }
    auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr);
    make_tuple_inputs.insert(make_tuple_inputs.begin(), make_tuple_prim);
    auto make_tuple_cnode = body_sub_func_graph_->NewCNode(make_tuple_inputs);
    make_tuple_cnode->set_fullname_with_scope(return_cnode->fullname_with_scope() + "tuple");

    return_cnode->add_input(make_tuple_cnode);
  }

  body_sub_func_graph_->set_return(return_cnode);
  return RET_OK;
}

// Assembles the body subgraph: create the graph, pull in its nodes, then
// identify its output and finally its inputs (same ordering as the original).
STATUS FunctionalizeWhile::BuildBodyGraph() {
  body_subgraph_name_ = FunctionalizeControlOpPass::NodeClusterName(loop_cond_node_) + "_body";
  body_sub_func_graph_ =
    FunctionalizeControlOpPass::NewFuncGraph(body_subgraph_name_, mindspore::lite::converter::FmkType_TF);
  if (body_sub_func_graph_ == nullptr) {
    MS_LOG(ERROR) << "new body_sub_func_graph_ return nullptr";
    return RET_NULL_PTR;
  }
  // the subgraph shares the main graph's manager
  body_sub_func_graph_->set_manager(fg_->manager());

  int status = BodySubgraphAddNodes();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "add body_subgraph node failed, ret:" << status;
    return status;
  }
  if ((status = IdentifyBodySubgraphOutput()) != RET_OK) {
    MS_LOG(ERROR) << "identify body_subgraph output failed, ret:" << status;
    return status;
  }
  if ((status = IdentifyBodySubgraphInput()) != RET_OK) {
    MS_LOG(ERROR) << "identify body_subgraph input failed, ret:" << status;
    return status;
  }
  return status;
}

// Wraps the cond/body subgraphs in ValueNodes and splices them in as while
// inputs 1 and 2 (input 0 is the While primitive itself).
STATUS FunctionalizeWhile::InsertFuncGraphToWhileInput() {
  auto cond_value_node = NewValueNode(cond_sub_func_graph_);
  auto body_value_node = NewValueNode(body_sub_func_graph_);
  std::vector<AnfNodePtr> new_inputs = while_node_->inputs();
  const auto insert_pos = new_inputs.begin() + 1;
  new_inputs.insert(insert_pos, {cond_value_node, body_value_node});
  while_node_->set_inputs(new_inputs);
  return RET_OK;
}

// Every cluster node is now represented by the functional while; remove them
// all from the main graph.
STATUS FunctionalizeWhile::DropUselessNodesInMainGraph() {
  std::for_each(node_cluster_.begin(), node_cluster_.end(),
                [this](const AnfNodePtr &node) { fg_->DropNode(node); });
  return RET_OK;
}

// Entry point: functionalizes one while cluster — build the while node, build
// the cond and body subgraphs, attach them as while inputs, then drop the raw
// control-flow ops from the main graph.
STATUS FunctionalizeWhile::Process() {
  int ret = BuildWhileNode();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "build while node failed, ret:" << ret;
    return ret;
  }

  // error messages below were all "build while node failed" (copy-paste);
  // each step now reports its own failure
  ret = BuildCondGraph();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "build cond graph failed, ret:" << ret;
    return ret;
  }

  ret = BuildBodyGraph();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "build body graph failed, ret:" << ret;
    return ret;
  }

  ret = InsertFuncGraphToWhileInput();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "insert func_graph to while input failed, ret:" << ret;
    return ret;
  }

  ret = DropUselessNodesInMainGraph();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "main func_graph drop nodes failed, ret:" << ret;
    return ret;
  }
  return ret;
}
} // namespace mindspore::opt

+ 89
- 0
mindspore/lite/tools/optimizer/graph/functionalize_while.h View File

@@ -0,0 +1,89 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_PASS_FUNCTIONALIZE_WHILE_H_
#define MINDSPORE_LITE_SRC_PASS_FUNCTIONALIZE_WHILE_H_
#include <string>
#include <set>
#include <vector>
#include <map>
#include "backend/optimizer/common/pass.h"
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/optimizer/graph/functionalize_control_op_pass.h"

using mindspore::lite::converter::FmkType;
namespace mindspore::opt {

// Rewrites one TF-style while cluster (Enter/Merge/Switch/Exit/NextIteration
// around a LoopCond) into a functional While cnode with explicit cond and
// body subgraphs. Call Process() once per cluster.
class FunctionalizeWhile {
 public:
  // node_cluster: all nodes belonging to this while loop; loop_cond_node: the
  // cluster's LoopCond op; fg: the main graph being rewritten.
  FunctionalizeWhile(std::vector<AnfNodePtr> node_cluster, const CNodePtr &loop_cond_node, FuncGraphPtr fg)
      : node_cluster_(node_cluster), loop_cond_node_(loop_cond_node), fg_(fg) {}

  // while
  STATUS BuildWhileNode();
  STATUS IdentifyWhileNodeInput();
  STATUS IdentifyWhileNodeOutput();
  STATUS UpdateExitNodeUser();
  STATUS NewWhileNode();
  STATUS InsertFuncGraphToWhileInput();

  // cond subgraph
  STATUS BuildCondGraph();
  STATUS CondSubgraphAddNodes();
  STATUS IdentifyCondSubgraphInput();
  STATUS IdentifyCondSubgraphOutput();

  // body subgraph
  STATUS BuildBodyGraph();
  STATUS BodySubgraphAddNodes();
  STATUS IdentifyBodySubgraphInput();
  STATUS IdentifyBodySubgraphOutput();

  // control-op ancestry lookups ("Blong" is a typo for "Belong", kept as the
  // declared interface)
  CNodePtr BlongToWhichSwitch(const CNodePtr &node);
  CNodePtr BlongToWhichMerge(const CNodePtr &node);
  CNodePtr BlongToWhichEnter(const CNodePtr &node);
  // index of node in input_enter_nodes_, or -1 when absent
  int PosInInputEnterNodes(const CNodePtr &node);
  STATUS DropUselessNodesInMainGraph();

  // runs the whole functionalization pipeline
  STATUS Process();

 private:
  std::vector<AnfNodePtr> node_cluster_{};
  const CNodePtr loop_cond_node_;
  FuncGraphPtr fg_;

  FuncGraphPtr cond_sub_func_graph_ = nullptr;
  FuncGraphPtr body_sub_func_graph_ = nullptr;
  CNodePtr while_node_ = nullptr;

  std::string cond_subgraph_name_{};
  std::string body_subgraph_name_{};

  // while inputs (Enter ops) and outputs (Exit ops), index-aligned
  std::vector<CNodePtr> input_enter_nodes_{};
  std::vector<CNodePtr> output_exit_nodes_{};

  // pair (next iteration node, next iteration node input)
  std::map<AnfNodePtr, AnfNodePtr> body_subgraph_output_map_{};
  // pair (switch node, switch output in body graph)
  std::map<AnfNodePtr, AnfNodePtr> body_subgraph_input_map_{};
  // pair (switch node, switch output in body graph)
  // NOTE(review): comment duplicated from the body map — presumably should
  // read "merge node, merge output in cond graph"; confirm (both maps appear
  // unused in the visible .cc)
  std::map<AnfNodePtr, AnfNodePtr> cond_subgraph_input_map_{};
};

} // namespace mindspore::opt
#endif  // MINDSPORE_LITE_SRC_PASS_FUNCTIONALIZE_WHILE_H_

+ 0
- 28
mindspore/lite/tools/optimizer/graph/infershape_pass.cc View File

@@ -325,26 +325,6 @@ STATUS InferShapePass::SetSubGraphInputsAbstract(const CNodePtr &cnode, const Fu
return RET_OK;
}

STATUS InferShapePass::SwitchCNodeInferShape(const CNodePtr &switch_cnode) {
auto body_partial_cnode = switch_cnode->input(2)->cast<CNodePtr>();
MS_ASSERT(body_partial_cnode != nullptr);
auto body_vnode = body_partial_cnode->input(0)->cast<ValueNodePtr>();
MS_ASSERT(body_vnode != nullptr);
auto body_fg = GetValueNode<FuncGraphPtr>(body_vnode);
MS_ASSERT(body_fg != nullptr);
AbstractBasePtrList abstract_list;
auto body_fg_output_cnode = utils::cast<CNodePtr>(body_fg->output());
for (auto &cnode : body_fg_output_cnode->inputs()) {
if (!utils::isa<CNodePtr>(cnode) && !utils::isa<ParameterPtr>(cnode)) {
continue;
}
abstract_list.push_back(cnode->abstract());
}

switch_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
return RET_OK;
}

bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
if (fmk_type != lite::converter::FmkType_TF && fmk_type != lite::converter::FmkType_TFLITE) {
MS_LOG(INFO) << "The framework type of model should be tf/tflite.";
@@ -384,14 +364,6 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
}
auto type = GetCNodeType(cnode);

if (type == schema::PrimitiveType_Switch) {
int ret = SwitchCNodeInferShape(cnode);
if (ret != RET_OK) {
MS_LOG(ERROR) << "PartialCNodeInferShape failed.";
return false;
}
}

if ((type == schema::PrimitiveType_TupleGetItem) ||
#ifdef SUPPORT_TRAIN
(type == schema::PrimitiveType_Depend) || (type == schema::PrimitiveType_ControlDepend) ||


+ 0
- 1
mindspore/lite/tools/optimizer/graph/infershape_pass.h View File

@@ -41,7 +41,6 @@ class InferShapePass : public Pass {
STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *output_tensors);
STATUS SetParameterAbstract(const ParameterPtr &parameter);
STATUS SetCNodeAbstract(const std::vector<lite::Tensor *> &output_tensors, const std::shared_ptr<CNode> &cnode);
STATUS SwitchCNodeInferShape(const CNodePtr &cnode);
int StrIsContain(const std::vector<std::string> &total, const std::string &aim);
int SetSubGraphInputsAbstract(const CNodePtr &cnode, const FuncGraphPtr &func_graph);



Loading…
Cancel
Save