Browse Source

unify format in anf

pull/14636/head
xuanyue 5 years ago
parent
commit
85741bac9d
38 changed files with 2391 additions and 498 deletions
  1. +11
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/fused_batchnorm_infer.c
  2. +4
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_getitem_infer.c
  3. +0
    -15
      mindspore/lite/src/common/tensor_util.cc
  4. +0
    -1
      mindspore/lite/src/common/tensor_util.h
  5. +16
    -17
      mindspore/lite/src/runtime/infer_manager.cc
  6. +4
    -0
      mindspore/lite/test/CMakeLists.txt
  7. +1
    -1
      mindspore/lite/test/models_arm32.cfg
  8. +9
    -9
      mindspore/lite/test/models_caffe.cfg
  9. +9
    -9
      mindspore/lite/test/models_caffe_fp16.cfg
  10. +1
    -1
      mindspore/lite/test/models_gpu_fp16.cfg
  11. +1
    -1
      mindspore/lite/test/models_gpu_fp32.cfg
  12. +1
    -1
      mindspore/lite/test/models_npu.cfg
  13. +5
    -5
      mindspore/lite/test/models_onnx.cfg
  14. +5
    -5
      mindspore/lite/test/models_onnx_fp16.cfg
  15. +1
    -0
      mindspore/lite/tools/anf_exporter/anf_exporter.h
  16. +4
    -0
      mindspore/lite/tools/converter/CMakeLists.txt
  17. +91
    -80
      mindspore/lite/tools/converter/anf_transform.cc
  18. +8
    -8
      mindspore/lite/tools/converter/anf_transform.h
  19. +1
    -14
      mindspore/lite/tools/converter/graphdef_transform.cc
  20. +0
    -225
      mindspore/lite/tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.cc
  21. +0
    -75
      mindspore/lite/tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h
  22. +184
    -0
      mindspore/lite/tools/optimizer/common/format_utils.cc
  23. +47
    -0
      mindspore/lite/tools/optimizer/common/format_utils.h
  24. +24
    -0
      mindspore/lite/tools/optimizer/common/gllo_utils.cc
  25. +7
    -0
      mindspore/lite/tools/optimizer/common/gllo_utils.h
  26. +3
    -1
      mindspore/lite/tools/optimizer/fusion/matmul_add_fusion.cc
  27. +22
    -18
      mindspore/lite/tools/optimizer/graph/conv1d_weight_expanding_pass.cc
  28. +1
    -1
      mindspore/lite/tools/optimizer/graph/conv1d_weight_expanding_pass.h
  29. +529
    -0
      mindspore/lite/tools/optimizer/graph/node_infershape.cc
  30. +60
    -0
      mindspore/lite/tools/optimizer/graph/node_infershape.h
  31. +47
    -5
      mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc
  32. +1
    -0
      mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.h
  33. +364
    -0
      mindspore/lite/tools/optimizer/graph/transpose_strategy.cc
  34. +65
    -0
      mindspore/lite/tools/optimizer/graph/transpose_strategy.h
  35. +785
    -0
      mindspore/lite/tools/optimizer/graph/unify_format_pass.cc
  36. +80
    -0
      mindspore/lite/tools/optimizer/graph/unify_format_pass.h
  37. +0
    -2
      mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc
  38. +0
    -2
      mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc

+ 11
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/fused_batchnorm_infer.c View File

@@ -19,6 +19,13 @@

int FusedBatchNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
#ifdef Debug
int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter);
if (check_ret != NNACL_OK) {
return check_ret;
}
#endif

for (size_t i = 0; i < inputs_size; i++) {
if (outputs_size <= i) {
break;
@@ -31,7 +38,10 @@ int FusedBatchNormInferShape(const TensorC *const *inputs, size_t inputs_size, T
outputs[5]->shape_size_ = 1;
outputs[5]->shape_[0] = 1;
}
return 0;
if (!parameter->infer_flag_) {
return NNACL_INFER_INVALID;
}
return NNACL_OK;
}

REG_INFER(FusedBatchNorm, PrimType_FusedBatchNorm, FusedBatchNormInferShape)

+ 4
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_getitem_infer.c View File

@@ -31,6 +31,10 @@ int TensorListGetItemInferShape(const TensorC *const *inputs, size_t inputs_size
if (GetElementNum(get_index) != 1) {
return NNACL_ERR;
}
TensorC *output = outputs[0];
if (!parameter->infer_flag_ || input0->element_num_ == 0) {
return NNACL_INFER_INVALID;
}
if (get_index->data_ == NULL) {
return NNACL_INFER_INVALID;
}
@@ -40,7 +44,6 @@ int TensorListGetItemInferShape(const TensorC *const *inputs, size_t inputs_size
}
TensorC *tensor_index = &input0->tensors_[index];

TensorC *output = outputs[0];
if (tensor_index->data_type_ != kTypeUnknown) {
output->data_type_ = tensor_index->data_type_;
} else {


+ 0
- 15
mindspore/lite/src/common/tensor_util.cc View File

@@ -61,21 +61,6 @@ int OutputTensor2TensorC(const std::vector<lite::Tensor *> &tensors, std::vector
return RET_OK;
}

void SetOutputTensorAttr(const std::vector<TensorC *> &tensors_in, std::vector<lite::Tensor *> *tensors_out) {
for (size_t i = 0; i < tensors_in.size(); ++i) {
if (tensors_in[i] != nullptr) {
tensors_out->at(i)->set_format(static_cast<schema::Format>(tensors_in[i]->format_));
tensors_out->at(i)->set_data_type(static_cast<TypeId>(tensors_in[i]->data_type_));
tensors_out->at(i)->set_shape({tensors_in[i]->shape_, tensors_in[i]->shape_ + tensors_in[i]->shape_size_});
if (tensors_in.at(i)->data_type_ == TypeIdC::kObjectTypeTensorType) {
auto tensor_list_in = reinterpret_cast<TensorListC *>(tensors_in.at(i));
auto tensor_list_out = reinterpret_cast<TensorList *>(tensors_out->at(i));
tensor_list_out->set_tensors_data_type(TypeId(tensor_list_in->tensors_data_type_));
}
}
}
}

void FreeAllTensorC(std::vector<TensorC *> *tensors_in) {
for (auto &i : *tensors_in) {
if (i == nullptr) {


+ 0
- 1
mindspore/lite/src/common/tensor_util.h View File

@@ -26,7 +26,6 @@ namespace mindspore {
namespace lite {
int InputTensor2TensorC(const std::vector<lite::Tensor *> &tensors_in, std::vector<TensorC *> *tensors_out);
int OutputTensor2TensorC(const std::vector<lite::Tensor *> &tensors_in, std::vector<TensorC *> *tensors_out);
void SetOutputTensorAttr(const std::vector<TensorC *> &tensors_in, std::vector<lite::Tensor *> *tensors_out);
void FreeAllTensorC(std::vector<TensorC *> *tensors_in);
void FreeTensorListC(TensorListC *tensorListC);
void Tensor2TensorC(Tensor *src, TensorC *dst);


+ 16
- 17
mindspore/lite/src/runtime/infer_manager.cc View File

@@ -49,24 +49,23 @@ int KernelInferShape(const std::vector<lite::Tensor *> &inputs, std::vector<lite
ret = infer_shape_func(static_cast<TensorC **>(in_tensors.data()), in_tensors.size(), out_tensors.data(),
out_tensors.size(), parameter);

if (ret == RET_OK) {
for (size_t i = 0; i < out_tensors.size(); i++) {
if (reinterpret_cast<TensorListC *>(out_tensors.at(i))->data_type_ == TypeIdC::kObjectTypeTensorType) {
auto *tensor_list_c = reinterpret_cast<TensorListC *>(out_tensors.at(i));
auto *tensor_list = reinterpret_cast<TensorList *>(outputs->at(i));
tensor_list->set_shape({static_cast<int>(tensor_list_c->element_num_)});
auto tensor_shape = std::vector<std::vector<int>>(
tensor_list_c->element_num_,
std::vector<int>(tensor_list_c->element_shape_,
tensor_list_c->element_shape_ + tensor_list_c->element_shape_size_));
tensor_list->MallocTensorListData(static_cast<TypeId>(tensor_list_c->data_type_), tensor_shape);
TensorListC2TensorList(tensor_list_c, tensor_list);
} else {
TensorC2Tensor(out_tensors.at(i), outputs->at(i));
}
for (size_t i = 0; i < out_tensors.size(); i++) {
if (out_tensors.at(i) == nullptr) {
continue;
}
if (reinterpret_cast<TensorListC *>(out_tensors.at(i))->data_type_ == TypeIdC::kObjectTypeTensorType) {
auto *tensor_list_c = reinterpret_cast<TensorListC *>(out_tensors.at(i));
auto *tensor_list = reinterpret_cast<TensorList *>(outputs->at(i));
tensor_list->set_shape({static_cast<int>(tensor_list_c->element_num_)});
auto tensor_shape = std::vector<std::vector<int>>(
tensor_list_c->element_num_,
std::vector<int>(tensor_list_c->element_shape_,
tensor_list_c->element_shape_ + tensor_list_c->element_shape_size_));
tensor_list->MallocTensorListData(static_cast<TypeId>(tensor_list_c->data_type_), tensor_shape);
TensorListC2TensorList(tensor_list_c, tensor_list);
} else {
TensorC2Tensor(out_tensors.at(i), outputs->at(i));
}
} else {
SetOutputTensorAttr(out_tensors, outputs);
}

FreeAllTensorC(&in_tensors);


+ 4
- 0
mindspore/lite/test/CMakeLists.txt View File

@@ -225,6 +225,7 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/common/node_pass_extends.cc
${LITE_DIR}/tools/optimizer/common/pass_manager_extends.cc
${LITE_DIR}/tools/optimizer/common/gllo_utils.cc
${LITE_DIR}/tools/optimizer/common/format_utils.cc
${LITE_DIR}/tools/optimizer/fusion/conv_biasadd_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/conv_activation_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/conv_tuple_activation_fusion.cc
@@ -271,6 +272,9 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/graph/functionalize_cond.cc
${LITE_DIR}/tools/optimizer/graph/inputs_adjust_pass.cc
${LITE_DIR}/tools/optimizer/graph/primitive_adjust_pass.cc
${LITE_DIR}/tools/optimizer/graph/unify_format_pass.cc
${LITE_DIR}/tools/optimizer/graph/node_infershape.cc
${LITE_DIR}/tools/optimizer/graph/transpose_strategy.cc
${LITE_DIR}/tools/common/graph_util.cc
${LITE_DIR}/tools/common/tensor_util.cc
${LITE_DIR}/tools/common/node_util.cc


+ 1
- 1
mindspore/lite/test/models_arm32.cfg View File

@@ -10,7 +10,7 @@ tracking
ml_face_contour
mnet
ml_face_landmark
ml_liveness_detect_landmark
ml_liveness_detect_landmark_tmp
deconv_test_model
deconvs_model
# onnx


+ 9
- 9
mindspore/lite/test/models_caffe.cfg View File

@@ -5,7 +5,7 @@ gender_res_large_deploy
glasses
hat
isface
ml_bank_detect_0312
ml_bank_detect_0312_tmp
ml_face_div_parsing
ml_hardware_eyeclose
ml_ocr_detect_20200305
@@ -26,9 +26,9 @@ hiai_face_landmark
hiai_face_pose_tuku
ml_hand_detection
ml_ocr_cn
ml_ocr_sfz_detect_0325
ml_ocr_sfz_detect_0325_tmp
ml_hardware_liveness
ml_liveness_detect_landmark
ml_liveness_detect_landmark_tmp
ml_face_contour
2012_ATLANTA_1class_20190621_v4.x_nomean
ml_ocr_sfz_add_final_0325
@@ -63,7 +63,7 @@ age_new
detection_retinaface_fix
landmark
plat_isface
PoseNet_dla_17_x512
PoseNet_dla_17_x512_tmp
ml_location_scene_division
ml_tabel_recog
ml_text_division
@@ -106,15 +106,15 @@ ml_Hand_deploy
ml_hand_3d_detection
ml_hand_3d_regression
ml_ARengine23_bodypose
ml_ocr_bank_card_detection_inception
ml_ocr_bank_card_detection_inception_tmp
ml_ocr_bank_card_recognition_fcny
hiai_cv_aestheticsEngineModel_osp
bank_card_recognition_fcny
bank_card_detection_inception
bank_card_detection_inception_tmp
ml_ocr_identify_card_fcny
ml_ocr_identify_card_detect
identify_card_detect
ml_ocr_identify_card_detect_tmp
identify_card_detect_tmp
ml_2012_ocr_rec_caffe
ml_2012_ocr_detection_caffe
ml_2012_ocr_detection_caffe_tmp
ml_face_mnet
ml_segmentation_atlanta_1

+ 9
- 9
mindspore/lite/test/models_caffe_fp16.cfg View File

@@ -5,7 +5,7 @@ gender_res_large_deploy 0.1
glasses 4
hat 1
isface 1
ml_bank_detect_0312 20
ml_bank_detect_0312_tmp 20
ml_face_div_parsing 8
ml_hardware_eyeclose 0.1
ml_ocr_detect_20200305 10
@@ -25,9 +25,9 @@ hiai_face_landmark 0.2
hiai_face_pose_tuku 1.3
ml_hand_detection 8
ml_ocr_cn 6
ml_ocr_sfz_detect_0325 3
ml_ocr_sfz_detect_0325_tmp 3
ml_hardware_liveness 3
ml_liveness_detect_landmark 1
ml_liveness_detect_landmark_tmp 1
ml_face_contour 0.5
2012_ATLANTA_1class_20190621_v4.x_nomean 1
ml_ocr_sfz_add_final_0325 0.1
@@ -62,7 +62,7 @@ age_new 22
detection_retinaface_fix 13
landmark 1
plat_isface 6
PoseNet_dla_17_x512 5
PoseNet_dla_17_x512_tmp 5
ml_location_scene_division 8
ml_tabel_recog 0.1
ml_text_division 12
@@ -101,16 +101,16 @@ ml_Hand_deploy 4
ml_hand_3d_detection 12
ml_hand_3d_regression 3
ml_ARengine23_bodypose 56
ml_ocr_bank_card_detection_inception 20
ml_ocr_bank_card_detection_inception_tmp 20
ml_ocr_bank_card_recognition_fcny 0.5
hiai_cv_aestheticsEngineModel_osp 1.5
ml_face_hat 0.5
bank_card_recognition_fcny 17
bank_card_detection_inception 12
bank_card_detection_inception_tmp 12
ml_ocr_identify_card_fcny 0.5
ml_ocr_identify_card_detect 2
identify_card_detect 0.5
ml_2012_ocr_detection_caffe 1
ml_ocr_identify_card_detect_tmp 2
identify_card_detect_tmp 0.5
ml_2012_ocr_detection_caffe_tmp 1
ml_2012_ocr_rec_caffe 0.5
ml_lable_model_hebing_device 2
ml_face_sex 0.5


+ 1
- 1
mindspore/lite/test/models_gpu_fp16.cfg View File

@@ -7,5 +7,5 @@ mtk_new_detect.tflite
mtk_pose.tflite
mtk_model_emotions_0727_nosoftmax.tflite
landmark
PoseNet_dla_17_x512
PoseNet_dla_17_x512_tmp
plat_isface

+ 1
- 1
mindspore/lite/test/models_gpu_fp32.cfg View File

@@ -20,7 +20,7 @@ mtk_convert_model.tflite
mtk_model_face_dress_fp16.tflite
detection_retinaface_fix
landmark
PoseNet_dla_17_x512
PoseNet_dla_17_x512_tmp
age_new
plat_isface
Q_hand_0812.pb


+ 1
- 1
mindspore/lite/test/models_npu.cfg View File

@@ -81,4 +81,4 @@ posenet_mobilenet_float_075_1_default_1.tflite 395
nasnet_mobile.tflite 1
ml_video_edit_art_generate.onnx 0.5
ml_video_edit_art_transfer.onnx 3 3
ml_video_edit_enhance_update.onnx 0.5
ml_video_edit_enhance_update_tmp.onnx 0.5

+ 5
- 5
mindspore/lite/test/models_onnx.cfg View File

@@ -7,7 +7,7 @@ mobilenetv2-7.onnx
shufflenet-v2-10.onnx
squeezenet1.1-7.onnx
densenet-9.onnx
ml_table_detection_fp32.onnx
ml_table_detection_fp32_tmp.onnx
ml_table_segment.onnx
googlenet-9.onnx
inception-v1-9.onnx
@@ -52,7 +52,7 @@ hdc_resnet_1w_class.onnx
ml_video_edit_imitate_filter.onnx
#ml_voice_detect.onnx #Accuracy error: 4.59655%, the result is close to 1.0e-8 except for the last one
hdc_ocr_attention.onnx
hdc_ocr_detect.onnx
hdc_ocr_detect_tmp.onnx
ml_edu_kit_hand_detection.onnx
ml_edu_kit_hand_key_position.onnx
ml_facedetector.onnx
@@ -71,8 +71,8 @@ mtk_detect_mbv1_640_480_nopostprocess_simplified_onnx.onnx;1,480,640,3
mtk_face_features_v2.onnx;1,256,192,3
mtk_face_recognition_v3.onnx
mtk_face_recognition_v2.onnx
ml_2012_ocr_detection.onnx
ml_video_edit_enhance_update.onnx
ml_2012_ocr_detection_tmp.onnx
ml_video_edit_enhance_update_tmp.onnx
#Harmony_Voiceprint_resnet18.onnx
bloom_hongmo_detection.onnx
bloom_hongmo_detection_tmp.onnx
Q_face_recognition.onnx

+ 5
- 5
mindspore/lite/test/models_onnx_fp16.cfg View File

@@ -7,7 +7,7 @@ mobilenetv2-7.onnx 8
shufflenet-v2-10.onnx 5
squeezenet1.1-7.onnx 1
densenet-9.onnx 6
ml_table_detection_fp32.onnx 2
ml_table_detection_fp32_tmp.onnx 2
ml_table_segment.onnx 2
googlenet-9.onnx 3
inception-v1-9.onnx 3
@@ -37,7 +37,7 @@ residual_distill_res34_cifar10_bs_1_update.onnx 2
residual_distill_res50_cifar10_bs_1_update.onnx 2
#ml_voice_detect.onnx #out of float16 range because power op
hdc_ocr_attention.onnx 1.6
hdc_ocr_detect.onnx 30 #one of the output has small values
hdc_ocr_detect_tmp.onnx 30 #one of the output has small values
ml_edu_kit_hand_detection.onnx 2
ml_edu_kit_hand_key_position.onnx 2
ml_video_edit_judge.onnx 12
@@ -71,8 +71,8 @@ mtk_detect_mbv1_640_480_nopostprocess_simplified_onnx.onnx;1,480,640,3 2
mtk_face_features_v2.onnx;1,256,192,3 0.5
mtk_face_recognition_v3.onnx 0.5
mtk_face_recognition_v2.onnx 2.5
ml_2012_ocr_detection.onnx 0.5
ml_2012_ocr_detection_tmp.onnx 0.5
#Harmony_Voiceprint_resnet18.onnx;1,1,200,40 4.5
bloom_hongmo_detection.onnx 0.5
bloom_hongmo_detection_tmp.onnx 0.5
Q_face_recognition.onnx 2
ml_video_edit_enhance_update.onnx 0.5
ml_video_edit_enhance_update_tmp.onnx 0.5

+ 1
- 0
mindspore/lite/tools/anf_exporter/anf_exporter.h View File

@@ -37,6 +37,7 @@ class AnfExporter {
public:
AnfExporter() = default;
virtual ~AnfExporter() = default;
void set_train_flag(bool train_flag) { train_flag_ = train_flag; }
schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false,
bool train_flag = false);
void SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,


+ 4
- 0
mindspore/lite/tools/converter/CMakeLists.txt View File

@@ -35,6 +35,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/common/node_pass_extends.cc
../optimizer/common/pass_manager_extends.cc
../optimizer/common/gllo_utils.cc
../optimizer/common/format_utils.cc
../optimizer/fusion/conv_biasadd_fusion.cc
../optimizer/fusion/conv_activation_fusion.cc
../optimizer/fusion/conv_tuple_activation_fusion.cc
@@ -81,6 +82,9 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/graph/functionalize_cond.cc
../optimizer/graph/inputs_adjust_pass.cc
../optimizer/graph/primitive_adjust_pass.cc
../optimizer/graph/unify_format_pass.cc
../optimizer/graph/node_infershape.cc
../optimizer/graph/transpose_strategy.cc
)

add_subdirectory(../anf_exporter anf_exporter)


+ 91
- 80
mindspore/lite/tools/converter/anf_transform.cc View File

@@ -61,6 +61,7 @@
#include "tools/optimizer/graph/if_pass.h"
#include "tools/optimizer/graph/functionalize_control_op_pass.h"
#include "tools/optimizer/graph/inputs_adjust_pass.h"
#include "tools/optimizer/graph/unify_format_pass.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/weight_quantizer.h"
@@ -71,7 +72,8 @@ AnfTransform::AnfTransform() = default;

AnfTransform::~AnfTransform() = default;

int AnfTransform::AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config) {
int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto fusion_pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);

// for now - training is not supporting fuse operations
@@ -106,24 +108,20 @@ int AnfTransform::AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &opti
remove_unused_cast_pass->SetFmkType(config->fmk);
fusion_pm->AddPass(remove_unused_cast_pass);
}
if (config->fmk == lite::converter::FmkType_ONNX) {
auto remove_unused_transpose_pass = std::make_shared<opt::RemoveUnusedTransposeOpPass>();
if (remove_unused_transpose_pass == nullptr) {
MS_LOG(ERROR) << "RemoveUnusedTransposeOpPass should be specified";
return RET_ERROR;
}
remove_unused_transpose_pass->SetFmkType(config->fmk);
fusion_pm->AddPass(remove_unused_transpose_pass);
}
fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>());
if (!config->trainModel) {
fusion_pm->AddPass(std::make_shared<opt::MatMulAddFusion>());
}
optimizer->AddPassManager(fusion_pm);
if (optimizer->Optimize(old_graph) == nullptr) {
MS_LOG(ERROR) << "run op fusion failed.";
return RET_ERROR;
}
return RET_OK;
}

int AnfTransform::AddGraphPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config) {
int AnfTransform::RunGraphPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF ||
config->fmk == lite::converter::FmkType_ONNX) {
@@ -134,8 +132,6 @@ int AnfTransform::AddGraphPass(const std::shared_ptr<opt::GraphOptimizer> &optim
weight_format_hardcode_pass->SetFmkType(config->fmk);
weight_format_hardcode_pass->SetQuantType(config->quantType);
graph_pm->AddPass(weight_format_hardcode_pass);
auto conv1d_weight_expanding_pass = std::make_shared<opt::Conv1DWeightExpandingPass>();
graph_pm->AddPass(conv1d_weight_expanding_pass);
auto weight_format_transform_pass = std::make_shared<opt::WeightFormatTransformPass>();
weight_format_transform_pass->SetFmkType(config->fmk);
weight_format_transform_pass->SetQuantType(config->quantType);
@@ -144,11 +140,15 @@ int AnfTransform::AddGraphPass(const std::shared_ptr<opt::GraphOptimizer> &optim
slice_prepose_pass->SetFmkType(config->fmk);
graph_pm->AddPass(slice_prepose_pass);
optimizer->AddPassManager(graph_pm);
if (optimizer->Optimize(old_graph) == nullptr) {
MS_LOG(ERROR) << "run graph pass failed.";
return RET_ERROR;
}
return RET_OK;
}

int AnfTransform::AddConvertPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer,
const converter::Flags *config) {
int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);
convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>());
if (config->fmk == lite::converter::FmkType_TFLITE) {
@@ -156,11 +156,15 @@ int AnfTransform::AddConvertPass(const std::shared_ptr<opt::GraphOptimizer> &opt
convert_pm->AddPass(std::make_shared<opt::TfliteInputsAdjustPass>());
}
optimizer->AddPassManager(convert_pm);
if (optimizer->Optimize(old_graph) == nullptr) {
MS_LOG(ERROR) << "run graph convert pass failed.";
return RET_ERROR;
}
return RET_OK;
}

int AnfTransform::AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer,
const converter::Flags *config) {
int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false);
const_fold_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>());
if (!config->trainModel) {
@@ -179,6 +183,10 @@ int AnfTransform::AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &o
infershape_pass->SetFmkType(config->fmk);
const_fold_pm->AddPass(infershape_pass);
optimizer->AddPassManager(const_fold_pm);
if (optimizer->Optimize(old_graph) == nullptr) {
MS_LOG(ERROR) << "run const fold failed.";
return RET_ERROR;
}
return RET_OK;
}

@@ -203,12 +211,18 @@ int AnfTransform::RunAdjustPass(const FuncGraphPtr &old_graph, const converter::
}
}

int AnfTransform::AddConv1DAdjustPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer,
const converter::Flags *config) {
int AnfTransform::RunConv1DAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto conv1d_pm = std::make_shared<opt::PassManager>("conv1d adjust pass manager", true);
conv1d_pm->AddPass(std::make_shared<opt::Conv1DInOutAdjustPass>());
conv1d_pm->AddPass(std::make_shared<opt::SqueezeFusion>());
auto conv1d_weight_expanding_pass = std::make_shared<opt::Conv1DWeightExpandingPass>();
conv1d_pm->AddPass(conv1d_weight_expanding_pass);
optimizer->AddPassManager(conv1d_pm);
if (optimizer->Optimize(old_graph) == nullptr) {
MS_LOG(ERROR) << "run conv1d adjust failed.";
return RET_ERROR;
}
return RET_OK;
}

@@ -276,18 +290,17 @@ int AnfTransform::RunPrecedingPass(const FuncGraphPtr &old_graph, const converte
return RET_OK;
}

int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config,
const FuncGraphPtr &new_graph) {
int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
// quant
if (config->quantType == schema::QuantType_PostTraining) {
this->m_quantizer_ = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, config->bitNum);
this->m_quantizer_ = std::make_unique<quant::PostTrainingQuantizer>(old_graph, config->configFile, config->bitNum);
if (m_quantizer_ == nullptr) {
MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
return RET_ERROR;
}
} else if (config->quantType == schema::QuantType_WeightQuant) {
this->m_quantizer_ = std::make_unique<quant::WeightQuantizer>(new_graph, *config);
this->m_quantizer_ = std::make_unique<quant::WeightQuantizer>(old_graph, *config);
if (m_quantizer_ == nullptr) {
MS_LOG(ERROR) << "New WeightQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
@@ -296,7 +309,7 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Fla
}
if (m_quantizer_ != nullptr) {
m_quantizer_->flags = *config;
auto status = m_quantizer_->DoQuantize(new_graph);
auto status = m_quantizer_->DoQuantize(old_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "Quant failed " << status;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@@ -306,70 +319,72 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Fla
return RET_OK;
}

FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) {
FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) {
MS_ASSERT(nullptr != old_graph);
if (config == nullptr) {
MS_LOG(ERROR) << "config should be specified";
return nullptr;
}
int status;
for (auto &fg : func_graphs_) {
status = RunPrecedingPass(fg, *config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run Preceding pass failed.";
return nullptr;
}

auto status = RunPrecedingPass(old_graph, *config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run Preceding pass failed.";
return nullptr;
}

status = RunAdjustPass(old_graph, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run Adjust pass failed.";
return nullptr;
}

auto optimizer = std::make_shared<opt::GraphOptimizer>();
status = RunAdjustPass(fg, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run Adjust pass failed.";
return nullptr;
}

status = AddConstFoldPass(optimizer, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Add const fold pass failed.";
return nullptr;
}
status = RunConstFoldPass(fg, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run const fold pass failed.";
return nullptr;
}

status = AddConvertPass(optimizer, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Add convert pass failed.";
return nullptr;
}
status = RunConvertPass(fg, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run convert pass failed.";
return nullptr;
}

status = AddFusionPass(optimizer, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Add fusion pass failed.";
return nullptr;
}
status = RunFusionPass(fg, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run fusion pass failed.";
return nullptr;
}

status = AddGraphPass(optimizer, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Add graph pass failed.";
return nullptr;
status = RunConv1DAdjustPass(fg, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run conv1d adjust pass failed.";
return nullptr;
}
}

status = AddConv1DAdjustPass(optimizer, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Add conv1d adjust pass failed.";
auto format_pass = std::make_shared<opt::UnifyFormatPass>();
format_pass->Init(config->fmk, config->trainModel);
if (!format_pass->Run(old_graph)) {
MS_LOG(ERROR) << "Run format pass failed.";
return nullptr;
}

auto new_graph = optimizer->Optimize(old_graph);
if (new_graph == nullptr) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR);
return nullptr;
}
for (auto &fg : func_graphs_) {
status = RunGraphPass(fg, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run convert pass failed.";
return nullptr;
}

status = DoQuantize(old_graph, config, new_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do Quantize failed.";
return nullptr;
status = DoQuantize(fg, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do Quantize failed.";
return nullptr;
}
}

return new_graph;
return old_graph;
}

void AnfTransform::GetAllFuncGraph(const FuncGraphPtr &func_graph) {
@@ -401,15 +416,11 @@ void AnfTransform::GetAllFuncGraph(const FuncGraphPtr &func_graph) {

FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) {
GetAllFuncGraph(main_graph);

for (auto &fg : func_graphs_) {
auto new_main_graph = TransformSingleFuncGraph(fg, config);
if (new_main_graph == nullptr) {
MS_LOG(ERROR) << "TransformSingleFuncGraph failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
auto new_graph = TransformFuncGraph(main_graph, config);
if (new_graph == nullptr) {
MS_LOG(ERROR) << "optimizer failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR);
}
return main_graph;
return new_graph;
}
} // namespace mindspore::lite

+ 8
- 8
mindspore/lite/tools/converter/anf_transform.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -39,19 +39,19 @@ class AnfTransform {
private:
std::unique_ptr<quant::Quantizer> m_quantizer_ = nullptr;

FuncGraphPtr TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr);
FuncGraphPtr TransformFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr);

static int AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config);
static int RunFusionPass(const FuncGraphPtr &old_graph, const converter::Flags *config);

static int AddGraphPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config);
static int RunGraphPass(const FuncGraphPtr &old_graph, const converter::Flags *config);

static int AddConvertPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config);
static int RunConvertPass(const FuncGraphPtr &old_graph, const converter::Flags *config);

static int AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config);
static int RunConstFoldPass(const FuncGraphPtr &olde_graph, const converter::Flags *config);

static int RunPrecedingPass(const FuncGraphPtr &old_graph, const converter::Flags &config);

static int AddConv1DAdjustPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config);
static int RunConv1DAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);

static int RunAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);

@@ -61,7 +61,7 @@ class AnfTransform {

static int RunTFAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);

int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config, const FuncGraphPtr &new_graph);
int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config);

void GetAllFuncGraph(const FuncGraphPtr &func_graph);



+ 1
- 14
mindspore/lite/tools/converter/graphdef_transform.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -94,14 +94,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
auto old_nodes = GetGraphNodes();

Optimizer format_trans_optimizer;
auto format_trans_pass = new (std::nothrow) FormatTransPass();
if (format_trans_pass == nullptr) {
MS_LOG(ERROR) << "new formatTransPass failed";
return RET_MEMORY_FAILED;
}
format_trans_pass->set_quant_type(ctx.quantType);
format_trans_pass->set_fmk_type(ctx.fmk);
format_trans_optimizer.AddPass(format_trans_pass);
format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
format_trans_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
if (ctx.fmk != converter::FmkType_TF) {
@@ -117,11 +109,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer format_trans_optimizer;
format_trans_optimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
format_trans_optimizer.AddPass(new (std::nothrow) TransOpRemovePass());
format_trans_optimizer.AddPass(new (std::nothrow) TransOpInsertPass());
format_trans_optimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = format_trans_optimizer.Run(graph_defT_);


+ 0
- 225
mindspore/lite/tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.cc View File

@@ -1,225 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <string>
#include <unordered_map>
#include <vector>
#include <utility>
#include <memory>
#include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h"
#include "src/common/log_adapter.h"
#include "tools/common/graph_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"

namespace mindspore {
namespace lite {
#define MATMUL_BIASADD_MATCH_PATH_LEN 2
#define BIASADD_OP_BIAS_INDEX 1
#define BIASADD_OP_INPUT_NUM 2

// Entry point: delegates to the generic FusionPass driver, which walks the graph
// and calls DefinePattern()/DoFusion() implemented below.
STATUS MatMulBiasAddFusionPass::Run(MetaGraphT *graph) { return FusionPass::Run(graph); }

// Registers the pattern this pass matches: a MatMul node feeding the left
// input of a BiasAdd node.
STATUS MatMulBiasAddFusionPass::DefinePattern() {
  auto matmul_pattern_op = std::make_shared<PatternOp>();
  matmul_pattern_op->id = MATMUL_NAME;
  matmul_pattern_op->types = {schema::PrimitiveType_MatMul};

  auto bias_add_pattern_op = std::make_shared<PatternOp>();
  bias_add_pattern_op->id = BIASADD_NAME;
  bias_add_pattern_op->types = {schema::PrimitiveType_BiasAdd};
  bias_add_pattern_op->left = matmul_pattern_op;

  std::unique_ptr<FusionPattern> pattern(new (std::nothrow) FusionPattern("MatMulBiasAddFusion"));
  if (pattern == nullptr) {
    MS_LOG(ERROR) << "new fusionPattern failed";
    return RET_ERROR;
  }
  pattern->AddPatternOp(matmul_pattern_op);
  pattern->AddPatternOp(bias_add_pattern_op);
  pattern->Finish();

  // Ownership of the pattern transfers to the base-class pattern list.
  this->patterns.emplace_back(pattern.release());
  return RET_OK;
}

// Fuses one matched MatMul+BiasAdd pair: moves the bias tensor onto the MatMul,
// rewrites the MatMul into a FullConnection op, removes the BiasAdd node, and
// inserts Transpose nodes required by the MatMul's transpose flags.
STATUS MatMulBiasAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName,
                                         std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
  MS_ASSERT(graph != nullptr);
  if (matchedPath.size() != MATMUL_BIASADD_MATCH_PATH_LEN) {
    MS_LOG(ERROR) << "MatMul-BiasAdd-Fusion should have two NodeIndex in matchedPair";
    return RET_PARAM_INVALID;
  }

  auto matMulPath = matchedPath[MATMUL_NAME];
  auto baPath = matchedPath[BIASADD_NAME];
  auto &matMulNode = graph->nodes.at(matMulPath->nodeIdx);
  auto &baNode = graph->nodes.at(baPath->nodeIdx);
  // cannot check shapes here because there is no shape inference in the converter
  MS_ASSERT(matMulNode != nullptr);
  MS_ASSERT(matMulNode->inputIndex.size() == 2);
  // if the BiasAdd's second input is not a constant tensor, skip the fusion
  auto baNodeInputIndex = baNode->inputIndex;
  if (baNodeInputIndex.size() != BIASADD_OP_INPUT_NUM) {
    MS_LOG(ERROR) << "input num is invalid! node: " << baNode->name.c_str();
    return RET_ERROR;
  }
  MS_ASSERT(graph->allTensors.size() > baNodeInputIndex.at(BIASADD_OP_BIAS_INDEX));
  const auto &baNodeBiasTensor = graph->allTensors.at(baNodeInputIndex.at(BIASADD_OP_BIAS_INDEX));
  MS_ASSERT(baNodeBiasTensor != nullptr);
  // NOTE(review): this compares a reference COUNT against a NodeType enum value;
  // it reads as if the tensor's node type was intended — confirm before reusing.
  if (baNodeBiasTensor->refCount != NodeType_ValueNode) {
    // bias is not a constant: deliberately skip the fusion and report success
    return RET_OK;
  }

  // 1. add biasTensor for matMul
  auto status = AddFullConnectionBiasTensor(matMulPath, baPath, graph);
  if (RET_OK != status) {
    MS_LOG(ERROR) << "AddFullConnectionBiasTensor failed, ret: " << status;
    return status;
  }

  // 2. change matmul to full connection op
  matMulNode->name += "-fc";
  std::unique_ptr<FullConnectionT> fcAttr(new (std::nothrow) FullConnectionT());
  if (fcAttr == nullptr) {
    MS_LOG(ERROR) << "new FullConnectionT node failed";
    return RET_ERROR;
  }
  fcAttr->has_bias = true;
  fcAttr->axis = 1;
  MS_ASSERT(matMulNode->primitive != nullptr);
  MS_ASSERT(matMulNode->primitive->value != nullptr);
  MS_ASSERT(matMulNode->primitive->value.AsMatMul() != nullptr);
  // remember the transpose flags: InsertTransposeNode() (step 4) consumes them
  transA = matMulNode->primitive->value.AsMatMul()->transpose_a;
  transB = matMulNode->primitive->value.AsMatMul()->transpose_b;
  matMulNode->primitive->value.type = schema::PrimitiveType_FullConnection;
  matMulNode->primitive->value.value = fcAttr.release();

  // 3. delete BiasAdd node
  MergeNodeAttrFromPost(matMulNode, baNode);
  status = IsolateOneWayNode(graph, baPath->nodeIdx);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "IsolateOneWayNode failed, subGraph: " << baPath->subGraphIdx << ", node: " << baPath->nodeIdx
                  << ", ret: " << status;
    return status;
  }

  // 4. addTranspose node
  status = InsertTransposeNode(graph, matMulPath);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "InsertTransposeNode failed, subGraph: " << matMulPath->subGraphIdx
                  << ", node: " << matMulPath->nodeIdx << ", ret: " << status;
    return status;
  }
  return RET_OK;
}

// FullConnection expects its inputs untransposed, so compensate for the MatMul's
// transpose flags by inserting explicit Transpose nodes: before input 0 when the
// MatMul had transpose_a, and before input 1 when it did NOT have transpose_b.
// All inserted transposes share a single {1, 0} perm tensor.
STATUS MatMulBiasAddFusionPass::InsertTransposeNode(MetaGraphT *graph, const std::shared_ptr<Path> &matMulPath) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(matMulPath != nullptr);

  std::vector<size_t> insertNodeIdxList;
  if (transA) {
    insertNodeIdxList.emplace_back(0);
  }
  if (!transB) {
    insertNodeIdxList.emplace_back(1);
  }

  auto matmulOpIter = graph->nodes.begin() + matMulPath->nodeIdx;
  STATUS errorCode = RET_OK;
  // build the shared 1-D int32 perm tensor holding {1, 0}
  auto perm_tensor = std::make_unique<schema::TensorT>();
  perm_tensor->dataType = kNumberTypeInt32;
  perm_tensor->dims = {2};
  std::vector<int> perm{1, 0};
  size_t bytes = perm.size() * sizeof(int);
  perm_tensor->data.resize(bytes);
  perm_tensor->name = "perm_" + std::to_string(id++);
  if (memcpy_s(perm_tensor->data.data(), bytes, perm.data(), bytes) != EOK) {
    MS_LOG(ERROR) << "memcpy data failed.";
    return RET_ERROR;
  }
  size_t index = graph->allTensors.size();
  graph->allTensors.push_back(std::move(perm_tensor));
  for (auto needInsertIdx : insertNodeIdxList) {
    auto transNode = std::unique_ptr<CNodeT>(new (std::nothrow) CNodeT);
    if (transNode == nullptr) {
      MS_LOG(ERROR) << "new TransNode failed";
      return RET_ERROR;
    }
    transNode->name = "transpose" + std::to_string(id++);
    transNode->primitive->value.type = schema::PrimitiveType_Transpose;
    // InsertNode may split across subgraph boundaries and insert more than one
    // copy; it returns an iterator past the inserted node(s).
    int insert_num = 0;
    matmulOpIter = InsertNode(graph, matmulOpIter, kBefore, needInsertIdx, std::move(transNode), &errorCode,
                              &insert_num, TransposeOpCopyer);
    if (errorCode != RET_OK) {
      MS_LOG(ERROR) << "InsertNode failed: " << errorCode;
      return errorCode;
    }
    // hook the shared perm tensor up as the second input of each inserted copy
    for (int i = insert_num; i > 0; --i) {
      (*(matmulOpIter - i))->inputIndex.push_back(index);
    }
  }
  // the perm tensor is referenced once per inserted transpose
  graph->allTensors.at(index)->refCount = insertNodeIdxList.size();
  return RET_OK;
}

#define BIASADD_WEIGHT_SHAPE_SIZE 1
#define BIASADD_BIAS_DIM_INDEX 0

// Validates the BiasAdd's bias tensor (must be a scalar or a 1-D constant) and
// re-links it as the third input of the MatMul node, detaching it from the
// BiasAdd so the BiasAdd can later be isolated and removed.
STATUS MatMulBiasAddFusionPass::AddFullConnectionBiasTensor(const std::shared_ptr<Path> &matMulPath,
                                                            const std::shared_ptr<Path> &baPath, MetaGraphT *graph) {
  MS_ASSERT(matMulPath != nullptr);
  MS_ASSERT(baPath != nullptr);
  MS_ASSERT(graph != nullptr);

  MS_ASSERT(graph->nodes.size() > matMulPath->nodeIdx);
  auto &matMulNode = graph->nodes.at(matMulPath->nodeIdx);
  MS_ASSERT(matMulNode != nullptr);
  auto baNode = graph->nodes.at(baPath->nodeIdx).get();
  MS_ASSERT(baNode != nullptr);

  // check biasTensor
  auto baWeightTensorIdxes = baNode->inputIndex;
  if (baWeightTensorIdxes.size() != BIASADD_OP_INPUT_NUM) {
    MS_LOG(ERROR) << "input number is invalid! node: " << baNode->name.c_str();
    return RET_ERROR;
  }
  MS_ASSERT(graph->allTensors.size() > baWeightTensorIdxes.at(BIASADD_OP_BIAS_INDEX));
  auto &biasTensor = graph->allTensors.at(baWeightTensorIdxes.at(BIASADD_OP_BIAS_INDEX));
  MS_ASSERT(biasTensor != nullptr);
  auto biasDims = biasTensor->dims;
  // a scalar bias has empty dims but must still carry data
  if (biasDims.empty() && biasTensor->data.data() == nullptr) {
    MS_LOG(ERROR) << "bias tensor is invalid, node: " << baNode->name.c_str();
    return RET_ERROR;
  }
  // a non-scalar bias must be exactly one-dimensional
  if (!biasDims.empty() && biasDims.size() != BIASADD_WEIGHT_SHAPE_SIZE) {
    MS_LOG(ERROR) << "BiasAdd bias tensor should has one dimension, current number of dimension " << biasDims.size()
                  << ". or bias tensor is a scaler";
    return RET_ERROR;
  }
  // move the bias tensor reference from the BiasAdd onto the MatMul
  matMulNode->inputIndex.emplace_back(baWeightTensorIdxes.at(BIASADD_OP_BIAS_INDEX));
  baNode->inputIndex.erase(baNode->inputIndex.begin() + BIASADD_OP_BIAS_INDEX);

  return RET_OK;
}

// Defaulted out-of-line so the destructor (and vtable) is emitted in this translation unit.
MatMulBiasAddFusionPass::~MatMulBiasAddFusionPass() = default;
} // namespace lite
} // namespace mindspore

+ 0
- 75
mindspore/lite/tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h View File

@@ -1,75 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_PREDICT_MATMUL_BIASADD_FUSION_PASS_H
#define MINDSPORE_PREDICT_MATMUL_BIASADD_FUSION_PASS_H

#include <string>
#include <unordered_map>
#include <memory>
#include <algorithm>
#include <utility>
#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"
#include "tools/common/graph_util.h"

namespace mindspore {
namespace lite {
// Pattern-op id for the MatMul node inside the fusion pattern.
constexpr const char *MATMUL_NAME = "MATMUL";

// Legacy MetaGraph pass that fuses a MatMul followed by a BiasAdd into a single
// FullConnection op, inserting Transpose nodes where the MatMul's transpose
// flags require them.
class MatMulBiasAddFusionPass : public FusionPass {
 public:
  MatMulBiasAddFusionPass() = default;

  ~MatMulBiasAddFusionPass() override;

  // Registers the MatMul -> BiasAdd pattern to match.
  STATUS DefinePattern() override;

  // Performs one fusion on a matched MatMul/BiasAdd pair.
  STATUS DoFusion(MetaGraphT *graph, const std::string &patternName,
                  std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;

  STATUS Run(MetaGraphT *graph) override;

 protected:
  // Re-links the BiasAdd's bias tensor as the MatMul's third input.
  static STATUS AddFullConnectionBiasTensor(const std::shared_ptr<Path> &matMulPath,
                                            const std::shared_ptr<Path> &dstPath, MetaGraphT *subGraph);
  // Inserts Transpose nodes in front of the inputs indicated by transA/transB.
  STATUS InsertTransposeNode(MetaGraphT *subGraph, const std::shared_ptr<Path> &matMulPath);

 protected:
  bool transA = false;  // MatMul transpose_a flag, captured during DoFusion
  bool transB = false;  // MatMul transpose_b flag, captured during DoFusion
  size_t id = 0;        // monotonically increasing suffix for generated node/tensor names

  // Clones a CNode into a fresh Transpose CNode, preserving name and quantType.
  OpDefCopyer TransposeOpCopyer = [](CNodeT *inOpDef) -> std::unique_ptr<CNodeT> {
    auto newOpDef = std::make_unique<schema::CNodeT>();
    if (newOpDef == nullptr) {
      MS_LOG(ERROR) << "new CNodeT failed";
      return nullptr;
    }
    newOpDef->name = inOpDef->name;
    newOpDef->quantType = inOpDef->quantType;
    newOpDef->primitive = std::make_unique<schema::PrimitiveT>();
    if (newOpDef->primitive == nullptr) {
      MS_LOG(ERROR) << "new PrimitiveT failed";
      return nullptr;
    }
    newOpDef->primitive->value.type = schema::PrimitiveType_Transpose;
    return newOpDef;
  };
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_PREDICT_MATMUL_BIASADD_FUSION_PASS_H

+ 184
- 0
mindspore/lite/tools/optimizer/common/format_utils.cc View File

@@ -0,0 +1,184 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/optimizer/common/format_utils.h"
#include <memory>
#include <vector>
#include <string>
#include <unordered_map>
#include "ops/adam.h"
#include "ops/addn.h"
#include "ops/apply_momentum.h"
#include "ops/batch_norm.h"
#include "ops/batch_to_space.h"
#include "ops/bias_add.h"
#include "ops/concat.h"
#include "ops/crop.h"
#include "ops/depth_to_space.h"
#include "ops/fusion/activation.h"
#include "ops/fusion/add_fusion.h"
#include "ops/fused_batch_norm.h"
#include "ops/fusion/avg_pool_fusion.h"
#include "ops/fusion/conv2d_backprop_input_fusion.h"
#include "ops/fusion/conv2d_backprop_filter_fusion.h"
#include "ops/fusion/conv2d_fusion.h"
#include "ops/fusion/conv2d_transpose_fusion.h"
#include "ops/fusion/max_pool_fusion.h"
#include "ops/fusion/mul_fusion.h"
#include "ops/fusion/pow_fusion.h"
#include "ops/fusion/prelu_fusion.h"
#include "ops/fusion/slice_fusion.h"
#include "ops/fusion/topk_fusion.h"
#include "ops/eltwise.h"
#include "ops/grad/activation_grad.h"
#include "ops/grad/avg_pool_grad.h"
#include "ops/grad/batch_norm_grad.h"
#include "ops/grad/bias_add_grad.h"
#include "ops/grad/max_pool_grad.h"
#include "ops/grad/resize_grad.h"
#include "ops/instance_norm.h"
#include "ops/lrn.h"
#include "ops/maximum.h"
#include "ops/op_utils.h"
#include "ops/quant_dtype_cast.h"
#include "ops/resize.h"
#include "ops/sgd.h"
#include "ops/space_to_batch.h"
#include "ops/space_to_batch_nd.h"
#include "ops/space_to_depth.h"
#include "ops/split.h"
#include "ops/strided_slice.h"

namespace mindspore {
namespace opt {
// Op name -> indices of the inputs that are laid out NHWC.
// NOTE(review): the exact meaning of an empty index list ({}) vs a populated one
// is inferred from usage elsewhere — confirm against the consuming passes.
static const std::unordered_map<std::string, std::vector<size_t>> NHWCOpMap = {
  {ops::kNameAdam, {10}},
  {ops::kNameApplyMomentum, {4}},
  {ops::kNameAvgPoolFusion, {1}},
  {ops::kNameAvgPoolGrad, {}},
  {ops::kNameBatchNorm, {1}},
  {ops::kNameBatchNormGrad, {1, 2}},
  {ops::kNameBatchToSpace, {1}},
  {ops::kNameBiasAdd, {1}},
  {ops::kNameBiasAddGrad, {1}},
  {ops::kNameConv2DBackpropInputFusion, {1}},
  {ops::kNameConv2DBackpropFilterFusion, {1, 2}},
  {ops::kNameConv2DFusion, {1}},
  {ops::kNameConv2dTransposeFusion, {1}},
  {ops::kNameDepthToSpace, {1}},
  {ops::kNameFusedBatchNorm, {1}},
  {ops::kNameLRN, {1}},
  {ops::kNameMaxPoolFusion, {1}},
  {ops::kNameMaxPoolGrad, {}},
  {ops::kNamePReLUFusion, {1}},
  {ops::kNameResize, {1}},
  {ops::kNameResizeGrad, {}},
  {ops::kNameSGD, {2}},
  {ops::kNameSpaceToBatch, {1}},
  {ops::kNameSpaceToBatchND, {1}},
  {ops::kNameSpaceToDepth, {1}},
  {ops::kNameTopKFusion, {1}}};

// Op name -> indices of the inputs that are laid out NCHW.
static const std::unordered_map<std::string, std::vector<size_t>> NCHWOpMap = {{ops::kNameInstanceNorm, {1}}};

// Ops whose input format is not fixed (they follow their producer's format).
static const std::vector<std::string> DynamicFormatOpList = {
  ops::kNameEltwise,   ops::kNameActivation, ops::kNameConcat,         ops::kNamePowFusion,     ops::kNameStridedSlice,
  ops::kNameAddFusion, ops::kNameAddN,       ops::kNameSplit,          ops::kNameSliceFusion,   ops::kNameCrop,
  ops::kNameMulFusion, ops::kNameMaximum,    ops::kNameActivationGrad, ops::kNameQuantDTypeCast};

// Axis index remapping from NCHW positions to NHWC positions (N->N, C->W-slot, etc.).
static const std::unordered_map<int, int> NC2NHAxisMap = {{0, 0}, {1, 3}, {2, 1}, {3, 2}};

// Read-only accessors; the tables above are file-local to keep a single definition.
const std::unordered_map<std::string, std::vector<size_t>> &GetNHWCOpMap() { return NHWCOpMap; }
const std::unordered_map<std::string, std::vector<size_t>> &GetNCHWOpMap() { return NCHWOpMap; }
const std::unordered_map<int, int> &GetNC2NHAxisMap() { return NC2NHAxisMap; }
const std::vector<std::string> &GetDynamicFormatOpList() { return DynamicFormatOpList; }

// Returns the format recorded on the cnode's primitive attr, defaulting to NHWC
// when no format attr is present.
Format GetFormat(const CNodePtr &cnode) {
  MS_ASSERT(cnode != nullptr);
  auto prim_node = cnode->input(0);
  MS_ASSERT(prim_node != nullptr);
  auto prim = GetValueNode<PrimitivePtr>(prim_node);
  MS_ASSERT(prim != nullptr);
  auto format_attr = prim->GetAttr(ops::kFormat);
  if (format_attr == nullptr) {
    return NHWC;
  }
  return static_cast<Format>(GetValue<int64_t>(format_attr));
}

// Extracts a transpose node's constant permutation into *perm.
// Returns RET_OK and leaves *perm untouched when the perm input is not a
// parameter with a constant default (i.e. the perm is dynamic), RET_ERROR on a
// malformed constant (non-tensor default, wrong dtype, rank > 1, copy failure).
STATUS GetTransposePerm(const AnfNodePtr &perm_node, std::vector<int> *perm) {
  MS_ASSERT(perm_node != nullptr && perm != nullptr);
  if (!utils::isa<ParameterPtr>(perm_node)) {
    return lite::RET_OK;
  }
  auto perm_param = perm_node->cast<ParameterPtr>();
  if (!perm_param->has_default() || perm_param->default_param() == nullptr) {
    return lite::RET_OK;
  }
  auto tensor_info = perm_param->default_param()->cast<tensor::TensorPtr>();
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "default param is not a tensor.";
    return lite::RET_ERROR;
  }
  if (tensor_info->data_type() != kNumberTypeInt && tensor_info->data_type() != kNumberTypeInt32) {
    MS_LOG(ERROR) << "data type is error, which is " << tensor_info->data_type();
    return lite::RET_ERROR;
  }
  auto tensor_shape = tensor_info->shape();
  if (tensor_shape.empty()) {
    return lite::RET_OK;
  }
  if (tensor_shape.size() > 1) {
    return lite::RET_ERROR;
  }
  perm->resize(tensor_shape[0]);
  if (perm->empty()) {
    return lite::RET_OK;
  }
  // Fix: pass the DESTINATION capacity as memcpy_s's second argument (the
  // original passed the source size), and reject a null data pointer or a
  // size mismatch before copying.
  size_t dst_bytes = perm->size() * sizeof(int);
  if (tensor_info->data_c() == nullptr || tensor_info->Size() != dst_bytes ||
      memcpy_s(perm->data(), dst_bytes, tensor_info->data_c(), tensor_info->Size()) != EOK) {
    MS_LOG(ERROR) << "memcpy data failed.";
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}

void RemoveIfMonad(const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
std::vector<AnfNodePtr> inputs{cnode->input(0)};
for (size_t i = 1; i < cnode->size(); ++i) {
if (utils::isa<ValueNodePtr>(cnode->input(i))) {
auto value_node = cnode->input(i)->cast<ValueNodePtr>();
auto value = value_node->value();
if (value->isa<Monad>()) {
continue;
}
}
inputs.push_back(cnode->input(i));
}
cnode->set_inputs(inputs);
}

// A monad node is a value node whose held value is a Monad.
bool IsMonadNode(const AnfNodePtr &node) {
  if (!utils::isa<ValueNodePtr>(node)) {
    return false;
  }
  return node->cast<ValueNodePtr>()->value()->isa<Monad>();
}
} // namespace opt
} // namespace mindspore

+ 47
- 0
mindspore/lite/tools/optimizer/common/format_utils.h View File

@@ -0,0 +1,47 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_FORMAT_UTILS_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_FORMAT_UTILS_H_

#include <vector>
#include <string>
#include <unordered_map>
#include "tools/optimizer/common/gllo_utils.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace opt {
constexpr auto kInferDone = "infer_done";  // primitive attr: shape inference completed for the node
constexpr auto kTransDone = "trans_done";  // primitive attr: format transposition already handled
// Direction of an inserted format-conversion (transpose) node.
enum FormatTransNodeType { kNCHW2NHWC, kNHWC2NCHW, kNONE };
// Conversions wrapped around an op: pre_ before its inputs, post_ after its outputs.
struct TransTypePair {
  FormatTransNodeType pre_;
  FormatTransNodeType post_;
  TransTypePair() : pre_(kNONE), post_(kNONE) {}
};
// Op name -> indices of the inputs laid out NHWC (resp. NCHW).
const std::unordered_map<std::string, std::vector<size_t>> &GetNHWCOpMap();
const std::unordered_map<std::string, std::vector<size_t>> &GetNCHWOpMap();
// Axis index remapping from NCHW positions to NHWC positions.
const std::unordered_map<int, int> &GetNC2NHAxisMap();
// Ops whose input format is not fixed (they follow their producer's format).
const std::vector<std::string> &GetDynamicFormatOpList();
// Format recorded on the cnode's primitive attr; defaults to NHWC when absent.
Format GetFormat(const CNodePtr &cnode);
// Reads a constant int perm from a Parameter; leaves *perm untouched for dynamic inputs.
STATUS GetTransposePerm(const AnfNodePtr &perm_node, std::vector<int> *perm);
// Strips Monad inputs from a cnode's input list, in place.
void RemoveIfMonad(const CNodePtr &cnode);
bool IsMonadNode(const AnfNodePtr &node);
} // namespace opt
} // namespace mindspore

#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_FORMAT_UTILS_H_

+ 24
- 0
mindspore/lite/tools/optimizer/common/gllo_utils.cc View File

@@ -22,6 +22,8 @@
#include <string>
#include "Eigen/Core"
#include "ops/fusion/conv2d_fusion.h"
#include "ops/transpose.h"
#include "ops/tuple_get_item.h"
#include "src/common/common.h"
#include "tools/common/tensor_util.h"
#include "frontend/operator/ops.h"
@@ -1351,5 +1353,27 @@ ParameterPtr BuildFloatValueParameterNode(const FuncGraphPtr &func_graph, const
}
return param_node;
}

// Builds a Transpose cnode over input_node with a constant perm parameter.
// The perm parameter is named "<cnode_name>_perm" and the cnode takes cnode_name.
CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &perm,
                          const std::string &cnode_name) {
  MS_ASSERT(func_graph != nullptr && input_node != nullptr);
  auto perm_param = BuildIntVecParameterNode(func_graph, perm, cnode_name + "_perm");
  MS_ASSERT(perm_param != nullptr);
  auto transpose_prim = std::make_shared<ops::Transpose>();
  MS_ASSERT(transpose_prim != nullptr);
  auto transpose_cnode = func_graph->NewCNode(transpose_prim, {input_node, perm_param});
  MS_ASSERT(transpose_cnode != nullptr);
  transpose_cnode->set_fullname_with_scope(cnode_name);
  return transpose_cnode;
}

// Builds a TupleGetItem cnode that extracts output `index` from `input`.
// The new node is named "<input-name>_getitem_<index>".
CNodePtr GenTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &input, size_t index) {
  // Consistency fix: assert the arguments and the created node like the sibling
  // GenTransposeNode does (the original checked nothing).
  MS_ASSERT(func_graph != nullptr && input != nullptr);
  auto tuple_get_item_prim = std::make_shared<ops::TupleGetItem>();
  MS_ASSERT(tuple_get_item_prim != nullptr);
  // Explicit cast: the index value node is an int, but the parameter is size_t.
  auto second_input = NewValueNode(MakeValue<int>(static_cast<int>(index)));
  auto tuple_cnode = func_graph->NewCNode(tuple_get_item_prim, {input, second_input});
  MS_ASSERT(tuple_cnode != nullptr);
  tuple_cnode->set_fullname_with_scope(input->fullname_with_scope() + "_getitem_" + std::to_string(index));
  return tuple_cnode;
}
} // namespace opt
} // namespace mindspore

+ 7
- 0
mindspore/lite/tools/optimizer/common/gllo_utils.h View File

@@ -25,6 +25,7 @@
#include "ir/func_graph.h"
#include "src/common/utils.h"
#include "backend/optimizer/common/pattern_engine.h"
#include "ops/fusion/conv2d_backprop_input_fusion.h"
#include "schema/inner/model_generated.h"
#include "tools/converter/converter_context.h"

@@ -36,6 +37,7 @@ namespace mindspore {
namespace opt {
inline const PrimitivePtr kPrimMakeTupleV2 = std::make_shared<Primitive>("make_tuple");
inline const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("Identity");
const PrimitivePtr kPrimConv2DBackpropInputFusion = std::make_shared<Primitive>(ops::kNameConv2DBackpropInputFusion);
constexpr auto kWeightFormat = "weight_format";
std::vector<int> CastToInt(const ValuePtr &value);

@@ -146,6 +148,11 @@ ParameterPtr BuildIntVec2DParameterNode(const FuncGraphPtr &func_graph, const st
ParameterPtr BuildFloatValueParameterNode(const FuncGraphPtr &func_graph, const float &data,
const std::string &node_name);

CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &perm,
const std::string &cnode_name);

CNodePtr GenTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &input, size_t index);

template <const PrimitivePtr *prim = nullptr>
inline bool IsSpecifiedNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {


+ 3
- 1
mindspore/lite/tools/optimizer/fusion/matmul_add_fusion.cc View File

@@ -68,10 +68,12 @@ bool MatMulAddFusion::Run(const FuncGraphPtr &func_graph) {
(!utils::isa<Parameter>(bias_node) || !bias_node->cast<ParameterPtr>()->default_param())) {
continue;
}
matmul_cnode->add_input(bias_node);
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
matmul_cnode->set_fullname_with_scope(node->fullname_with_scope());
auto tr = manager->Transact();
tr.AddEdge(matmul_cnode, bias_node);
tr.Commit();
manager->Replace(node, matmul_cnode);
}
return false;


+ 22
- 18
mindspore/lite/tools/optimizer/graph/conv1d_weight_expanding_pass.cc View File

@@ -23,12 +23,17 @@ namespace {
constexpr size_t kTripleNum = 3;
constexpr size_t kConvWeightIndex = 2;
} // namespace
lite::STATUS Conv1DWeightExpandingPass::ExpandFilterShape(const tensor::TensorPtr &tensor,
const schema::Format &format) {
if (tensor == nullptr) {
return lite::RET_NULL_PTR;
lite::STATUS Conv1DWeightExpandingPass::ExpandFilterShape(const AnfNodePtr &weight_node, const schema::Format &format) {
MS_ASSERT(weight_node != nullptr);
auto weight_tensor = GetTensorInfo(weight_node);
if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "weight node must be param value.";
return lite::RET_ERROR;
}
auto shape = weight_tensor->shape();
if (shape.size() != kTripleNum) {
return lite::RET_OK;
}
auto shape = tensor->shape();
std::vector<int64_t> new_shape(shape);
switch (format) {
case schema::Format_NCHW:
@@ -43,7 +48,13 @@ lite::STATUS Conv1DWeightExpandingPass::ExpandFilterShape(const tensor::TensorPt
MS_LOG(ERROR) << "Unsupported format.";
return RET_ERROR;
}
tensor->set_shape(new_shape);
weight_tensor->set_shape(new_shape);
if (!utils::isa<ParameterPtr>(weight_node)) {
return lite::RET_OK;
}
auto weight_param = weight_node->cast<ParameterPtr>();
auto type = weight_tensor->data_type();
weight_param->set_abstract(std::make_shared<abstract::AbstractTensor>(TypeIdToType(type), new_shape));
return RET_OK;
}

@@ -62,25 +73,18 @@ bool Conv1DWeightExpandingPass::Run(const FuncGraphPtr &func_graph) {
MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
auto weight_node = conv_cnode->input(kConvWeightIndex);
MS_ASSERT(weight_node != nullptr);
auto weight_value = GetTensorInfo(weight_node);
if (weight_value == nullptr) {
MS_LOG(ERROR) << "weight node must be param value.";
return false;
}

auto prim = GetValueNode<PrimitivePtr>(conv_cnode->input(0));
MS_ASSERT(prim != nullptr);

schema::Format schema_format = schema::Format::Format_KCHW;
if (prim->GetAttr(opt::kWeightFormat) != nullptr) {
schema_format = static_cast<schema::Format>(GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat)));
}
// expand weight tensor to 4 dimensions.
if (weight_value->shape().size() == kTripleNum) {
auto status = ExpandFilterShape(weight_value, schema_format);
if (status != RET_OK) {
MS_LOG(ERROR) << "Expand filter shape failed.";
return false;
}
auto status = ExpandFilterShape(weight_node, schema_format);
if (status != RET_OK) {
MS_LOG(ERROR) << "Expand filter shape failed.";
return false;
}
}
return RET_OK;


+ 1
- 1
mindspore/lite/tools/optimizer/graph/conv1d_weight_expanding_pass.h View File

@@ -30,7 +30,7 @@ class Conv1DWeightExpandingPass : public Pass {
bool Run(const FuncGraphPtr &graph) override;

private:
lite::STATUS ExpandFilterShape(const tensor::TensorPtr &tensor, const schema::Format &format);
lite::STATUS ExpandFilterShape(const AnfNodePtr &weight_node, const schema::Format &format);
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV1D_WEIGHT_EXPANDING_PASS_H_

+ 529
- 0
mindspore/lite/tools/optimizer/graph/node_infershape.cc View File

@@ -0,0 +1,529 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/optimizer/graph/node_infershape.h"
#include <algorithm>
#include <memory>
#include <vector>
#include "tools/anf_exporter/anf_exporter.h"
#include "tools/common/node_util.h"
#include "tools/common/tensor_util.h"
#include "src/ops/populate/populate_register.h"
#include "src/ops/ops_utils.h"
#include "src/runtime/infer_manager.h"
#include "src/tensorlist.h"

namespace mindspore {
namespace opt {
namespace {
constexpr size_t INITIAL_SIZE = 1024;
// Deletes every tensor in the list and empties it. Null-safe.
void FreeTensors(std::vector<lite::Tensor *> *tensors) {
  if (tensors == nullptr) {
    return;
  }
  for (auto *tensor : *tensors) {
    delete tensor;
  }
  tensors->clear();
}

// For convolution-family nodes, propagates the primitive's weight-format attr
// onto the weight tensor (input index 1). No-op for every other op type.
void SetConvWeightFormat(const CNodePtr &cnode, const std::vector<lite::Tensor *> &inputs) {
  MS_ASSERT(cnode != nullptr);
  bool is_conv_family = CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
                        CheckPrimitiveType(cnode, kPrimConv2DBackpropInputFusion) ||
                        CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion);
  if (!is_conv_family) {
    return;
  }
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  MS_ASSERT(prim != nullptr);
  auto weight_format_attr = prim->GetAttr(opt::kWeightFormat);
  if (weight_format_attr != nullptr && inputs.size() > 1) {
    inputs[1]->set_format(static_cast<schema::Format>(GetValue<int64_t>(weight_format_attr)));
  }
}

// Decides whether shape inference can run for this cnode: all input shapes must
// be fully known and every producing cnode must already have inferred
// successfully (its kInferDone attr is true).
// NOTE(review): "Duce" is presumably a typo for "Deduce"; kept for ABI/callers.
bool DuceInferFlag(const CNodePtr &cnode, const std::vector<lite::Tensor *> &inputs, FmkType fmk_type) {
  MS_ASSERT(cnode != nullptr);
  for (auto &input : inputs) {
    auto shape = input->shape();
    if (std::find(shape.begin(), shape.end(), -1) != shape.end()) {
      // Heuristic: a 4-D ONNX input shaped (N, -1, H, 3) is tagged NHWC here —
      // presumably an image tensor with 3 channels last; confirm before relying on it.
      if (fmk_type == lite::converter::FmkType_ONNX && shape.size() == 4 && shape[3] == 3 && shape[1] == -1) {
        input->set_format(schema::Format_NHWC);
      }
      return false;
    }
  }
  // Temporarily flatten Depend/MakeTuple wrappers so real producers are visited;
  // the original input list is restored on every path below.
  auto origin_inputs = cnode->inputs();
  lite::AnfExporter::RemoveIfDepend(cnode);
  lite::AnfExporter::RemoveIfMakeTuple(cnode);
  for (size_t i = 1; i < cnode->size(); ++i) {
    if (!utils::isa<CNodePtr>(cnode->input(i))) {
      continue;
    }
    auto input_cnode = cnode->input(i)->cast<CNodePtr>();
    if (CheckPrimitiveType(cnode->input(i), prim::kPrimTupleGetItem)) {
      // look through TupleGetItem to the tuple-producing node
      input_cnode = input_cnode->input(1)->cast<CNodePtr>();
    }
    if (input_cnode == nullptr) {
      MS_LOG(ERROR) << "input is not cnode.";
      cnode->set_inputs(origin_inputs);
      return false;
    }
    auto prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
    if (prim == nullptr || prim->GetAttr(kInferDone) == nullptr) {
      MS_LOG(ERROR) << "prim is invalid.";
      cnode->set_inputs(origin_inputs);
      return false;
    }
    if (!GetValue<bool>(prim->GetAttr(kInferDone))) {
      cnode->set_inputs(origin_inputs);
      return false;
    }
  }
  cnode->set_inputs(origin_inputs);
  return true;
}

// Creates a tensor::Tensor metadata object (dtype + shape only, no data copy)
// mirroring the given lite tensor. Returns nullptr on allocation failure.
tensor::TensorPtr NewTensorInfo(lite::Tensor *tensor) {
  const std::vector<int> &src_shape = tensor->shape();
  std::vector<int64_t> shape_vector(src_shape.begin(), src_shape.end());
  auto tensor_info = std::make_shared<tensor::Tensor>(tensor->data_type(), shape_vector);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "new tensor::Tensor failed";
    return nullptr;
  }
  return tensor_info;
}
} // namespace

// Runs lite-kernel shape inference for one cnode: materializes its inputs and
// outputs as lite tensors, populates the op parameter from the flatbuffer
// primitive, invokes KernelInferShape, and writes the result back as the
// cnode's abstract. The kInferDone attr is set false up front and flipped to
// true only on a successful infer.
// Returns the KernelInferShape status (RET_INFER_INVALID is non-fatal), or an
// error from tensor/parameter construction or abstract update.
STATUS NodeInferShape::InferShape(const CNodePtr &cnode) {
  MS_ASSERT(cnode != nullptr);
  auto anf_prim = GetValueNode<std::shared_ptr<Primitive>>(cnode->input(0));
  if (anf_prim == nullptr) {
    MS_LOG(DEBUG) << "primitive is nullptr";
    return lite::RET_ERROR;
  }
  anf_prim->AddAttr(kInferDone, MakeValue<bool>(false));
  std::vector<lite::Tensor *> inputs;
  std::vector<lite::Tensor *> outputs;
  // Single cleanup helper so every exit path releases the lite tensors.
  auto free_io_tensors = [&inputs, &outputs]() {
    FreeTensors(&inputs);
    FreeTensors(&outputs);
  };
  if (GetCNodeInputTensors(cnode, &inputs) != lite::RET_OK) {
    MS_LOG(ERROR) << "get inputs failed.";
    free_io_tensors();
    return lite::RET_ERROR;
  }
  SetConvWeightFormat(cnode, inputs);
  if (GetCNodeOutputTensors(cnode, &outputs) != lite::RET_OK) {
    MS_LOG(ERROR) << "get outputs failed.";
    free_io_tensors();
    return lite::RET_ERROR;
  }
  auto prim_t = lite::GetPrimitiveT(cnode->input(0));
  if (prim_t == nullptr) {
    MS_LOG(DEBUG) << "prim_t is nullptr";
    free_io_tensors();
    return lite::RET_ERROR;
  }
  flatbuffers::FlatBufferBuilder fbb(INITIAL_SIZE);
  auto prim = lite::ConvertToPrimitive(prim_t, &fbb);
  delete prim_t;
  if (prim == nullptr) {
    MS_LOG(ERROR) << "get primitive failed.";
    free_io_tensors();
    fbb.Clear();
    return lite::RET_ERROR;
  }
  auto parameter_gen = lite::PopulateRegistry::GetInstance()->GetParameterCreator(prim->value_type(), lite::SCHEMA_CUR);
  if (parameter_gen == nullptr) {
    MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(prim->value_type());
    free_io_tensors();
    fbb.Clear();
    return lite::RET_ERROR;
  }
  auto parameter = parameter_gen(prim);
  if (parameter == nullptr) {
    MS_LOG(ERROR) << "parameter is nullptr.";
    free_io_tensors();
    fbb.Clear();
    return lite::RET_ERROR;
  }
  parameter->infer_flag_ = DuceInferFlag(cnode, inputs, fmk_type_);
  auto status = KernelInferShape(inputs, &outputs, parameter);
  if (status == lite::RET_OK) {
    anf_prim->AddAttr(kInferDone, MakeValue<bool>(true));
  }
  if (status == lite::RET_OK || status == lite::RET_INFER_INVALID) {
    auto set_status = SetCNodeAbstract(cnode, outputs);
    if (set_status != lite::RET_OK) {
      MS_LOG(ERROR) << "set CNode abstract failed: " << cnode->fullname_with_scope();
      // Fix: the original early return here leaked inputs/outputs/parameter and
      // left the FlatBufferBuilder uncleared.
      free_io_tensors();
      free(parameter);
      fbb.Clear();
      return set_status;
    }
  } else {
    MS_LOG(ERROR) << "infer shape failed.";
  }
  free_io_tensors();
  free(parameter);
  fbb.Clear();
  return status;
}

std::vector<int> NodeInferShape::GetInputShape(const CNodePtr &cnode, size_t index) {
MS_ASSERT(cnode != nullptr);
if (index >= cnode->size()) {
return {};
}
auto origin_inputs = cnode->inputs();
std::vector<AnfNodePtr> specify_inputs = {origin_inputs[0], origin_inputs[index]};
cnode->set_inputs(specify_inputs);
std::vector<lite::Tensor *> specify_tensors;
if (GetCNodeInputTensors(cnode, &specify_tensors) != lite::RET_OK || specify_tensors.empty()) {
cnode->set_inputs(origin_inputs);
return {};
}
cnode->set_inputs(origin_inputs);
auto shape = specify_tensors.front()->shape();
FreeTensors(&specify_tensors);
return shape;
}

// Reads the cnode's input at `index` as a 1-D int32 vector. Returns {} if the
// index is out of range, the tensor cannot be materialized, it is not 1-D
// int32, or the copy fails. The cnode's input list is temporarily shrunk to
// {primitive, chosen input} and restored before returning.
std::vector<int> NodeInferShape::GetIntVecInput(const CNodePtr &cnode, size_t index) {
  MS_ASSERT(cnode != nullptr);
  if (index >= cnode->size()) {
    return {};
  }
  auto origin_inputs = cnode->inputs();
  std::vector<AnfNodePtr> specify_inputs = {origin_inputs[0], origin_inputs[index]};
  cnode->set_inputs(specify_inputs);
  std::vector<lite::Tensor *> specify_tensors;
  if (GetCNodeInputTensors(cnode, &specify_tensors) != lite::RET_OK || specify_tensors.empty()) {
    cnode->set_inputs(origin_inputs);
    return {};
  }
  cnode->set_inputs(origin_inputs);
  auto *first_tensor = specify_tensors.front();
  if ((first_tensor->data_type() != kNumberTypeInt32 && first_tensor->data_type() != kNumberTypeInt) ||
      first_tensor->shape().size() != 1) {
    FreeTensors(&specify_tensors);
    return {};
  }
  std::vector<int> tensor_data(first_tensor->shape()[0]);
  if (memcpy_s(tensor_data.data(), tensor_data.size() * sizeof(int), first_tensor->data_c(),
               tensor_data.size() * sizeof(int)) != EOK) {
    FreeTensors(&specify_tensors);
    return {};
  }
  // Fix: the original leaked specify_tensors on the success path.
  FreeTensors(&specify_tensors);
  return tensor_data;
}

STATUS NodeInferShape::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *inputs) {
  // Materialize every input of `cnode` as a lite::Tensor, preserving the
  // original input order: constant (non-CNode) and variable (CNode) inputs
  // are gathered separately, then re-interleaved by walking the inputs again.
  // On success the caller owns the tensors placed in `inputs`.
  MS_ASSERT(cnode != nullptr);
  MS_ASSERT(inputs != nullptr);
  auto origin_inputs = cnode->inputs();
  // The helpers below mutate cnode's inputs; they are restored from
  // origin_inputs before returning on every path.
  lite::AnfExporter::RemoveIfDepend(cnode);
  lite::AnfExporter::RemoveIfMakeTuple(cnode);
  RemoveIfMonad(cnode);
  std::vector<lite::Tensor *> const_inputs;
  if (GetCNodeConstInput(cnode, &const_inputs) != lite::RET_OK) {
    MS_LOG(ERROR) << "get const inputs failed.";
    FreeTensors(&const_inputs);
    cnode->set_inputs(origin_inputs);
    return lite::RET_ERROR;
  }
  std::vector<lite::Tensor *> var_inputs;
  if (GetCNodeVarInput(cnode, &var_inputs) != lite::RET_OK) {
    MS_LOG(ERROR) << "get var inputs failed.";
    // Fix: the already-populated const_inputs were leaked on this path.
    FreeTensors(&const_inputs);
    FreeTensors(&var_inputs);
    cnode->set_inputs(origin_inputs);
    return lite::RET_ERROR;
  }
  size_t const_index = 0;
  size_t var_index = 0;
  bool input_valid = true;
  for (size_t i = 1; i < cnode->size(); ++i) {
    if (utils::isa<CNodePtr>(cnode->input(i))) {
      if (var_index >= var_inputs.size()) {
        MS_LOG(ERROR) << "var inputs size invalid.";
        input_valid = false;
        break;
      }
      inputs->emplace_back(var_inputs[var_index++]);
    } else {
      if (const_index >= const_inputs.size()) {
        MS_LOG(ERROR) << "const inputs size invalid.";
        input_valid = false;
        break;
      }
      inputs->emplace_back(const_inputs[const_index++]);
    }
  }
  cnode->set_inputs(origin_inputs);
  if (!input_valid) {
    // Free via the source vectors (they still hold every pointer handed to
    // `inputs`), then drop the now-dangling entries. Deliberately returns
    // RET_OK with empty `inputs`; callers treat an empty result as failure.
    FreeTensors(&const_inputs);
    FreeTensors(&var_inputs);
    inputs->resize(0);
  }
  return lite::RET_OK;
}

STATUS NodeInferShape::GetCNodeConstInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *const_ms_inputs) {
  // Export only cnode's constant (non-CNode) inputs through the anf exporter
  // into a temporary meta graph, then convert them into lite tensors.
  MS_ASSERT(cnode != nullptr);
  auto saved_inputs = cnode->inputs();
  std::vector<AnfNodePtr> const_only;
  for (auto &input : saved_inputs) {
    if (!utils::isa<CNodePtr>(input)) {
      const_only.push_back(input);
    }
  }
  // Temporarily narrow the node to its constant inputs; restored right after
  // the export regardless of outcome.
  cnode->set_inputs(const_only);
  auto meta_graph = std::make_unique<schema::MetaGraphT>();
  meta_graph->fmkType = fmk_type_;
  auto fb_node = std::make_unique<schema::CNodeT>();
  lite::AnfExporter exporter;
  exporter.set_train_flag(train_flag_);
  auto ret = exporter.SetOpInputNode(cnode, meta_graph, fb_node.get());
  cnode->set_inputs(saved_inputs);
  if (ret != lite::RET_OK) {
    MS_LOG(ERROR) << "get const inputs failed.";
    return ret;
  }
  return ConvertToLiteTensor(meta_graph, fb_node->inputIndex, const_ms_inputs);
}

STATUS NodeInferShape::GetCNodeVarInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *var_ms_inputs) {
  // Build placeholder lite tensors (shape/type only) for every CNode input of
  // `cnode`, taken from the producers' abstracts. On failure, tensors already
  // appended to var_ms_inputs are freed by the caller.
  MS_ASSERT(cnode != nullptr);
  MS_ASSERT(var_ms_inputs != nullptr);
  for (size_t i = 1; i < cnode->size(); ++i) {
    if (!utils::isa<CNodePtr>(cnode->input(i))) {
      continue;
    }
    auto abstract = GetCNodeInputAbstract(cnode, i);
    if (abstract == nullptr) {
      MS_LOG(ERROR) << "Abstract cnode is nullptr.";
      return lite::RET_ERROR;
    }
    if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
      // Fix: message previously read "anstract tensor".
      MS_LOG(ERROR) << "Abstract should be abstract tensor.";
      return lite::RET_ERROR;
    }
    auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
    auto type_ptr = abstract_tensor->element()->GetTypeTrack();
    // Fix: the assert previously referenced an undeclared name `typePtr`,
    // which breaks builds where MS_ASSERT expands to a real check.
    MS_ASSERT(type_ptr != nullptr);
    if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
      MS_LOG(ERROR) << "Shape of Abstract should be ShapePtr.";
      return lite::RET_ERROR;
    }
    auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
    std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
    lite::Tensor *tensor = nullptr;
    if (type_ptr->type_id() == kObjectTypeTensorType) {
      // A tensor-of-tensors input is modeled as a TensorList decoded from the
      // abstract's value track.
      tensor = GetCNodeTensorListVarInput(dims, abstract_tensor);
    } else {
      tensor = new (std::nothrow) lite::Tensor(TypeId(type_ptr->type_id()), dims);
    }
    if (tensor == nullptr) {
      MS_LOG(ERROR) << "new a lite tensor failed";
      return lite::RET_ERROR;
    }
    var_ms_inputs->emplace_back(tensor);
  }
  return lite::RET_OK;
}

lite::Tensor *NodeInferShape::GetCNodeTensorListVarInput(std::vector<int> shape,
                                                         const abstract::AbstractTensorPtr &abstract_tensor) {
  // Create a lite::TensorList whose layout is decoded from the serialized int
  // buffer stored in the abstract's value track. Returns nullptr (and logs)
  // on failure; on success ownership passes to the caller.
  MS_ASSERT(abstract_tensor != nullptr);
  auto tensor_list = new (std::nothrow) lite::TensorList(shape, {});
  if (tensor_list == nullptr) {
    MS_LOG(ERROR) << "new a lite tensor list failed";
    return nullptr;
  }
  auto tensor_info = abstract_tensor->GetValueTrack();
  if (tensor_info == nullptr || !utils::isa<tensor::TensorPtr>(tensor_info)) {
    delete tensor_list;
    // Fix: message was truncated ("nsor list abstract is invalid.").
    MS_LOG(ERROR) << "tensor list abstract is invalid.";
    return nullptr;
  }
  auto tensor_value = tensor_info->cast<tensor::TensorPtr>();
  if (tensor_value->data_c() == nullptr) {
    delete tensor_list;
    MS_LOG(ERROR) << "cannot get tensor list abstract's info.";
    return nullptr;
  }
  auto status = tensor_list->Decode(static_cast<int *>(tensor_value->data_c()));
  if (status != lite::RET_OK) {
    delete tensor_list;
    MS_LOG(ERROR) << "decode tensor list failed.";
    return nullptr;
  }
  return tensor_list;
}

STATUS NodeInferShape::GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *outputs) {
  // Export cnode's outputs into a temporary meta graph, then turn the
  // exported tensor descriptors into lite tensors owned by the caller.
  MS_ASSERT(cnode != nullptr);
  MS_ASSERT(outputs != nullptr);
  auto graph = std::make_unique<schema::MetaGraphT>();
  graph->fmkType = fmk_type_;
  auto node_t = std::make_unique<schema::CNodeT>();
  lite::AnfExporter exporter;
  exporter.set_train_flag(train_flag_);
  exporter.SetOpOutputNode(cnode, graph, node_t.get());
  return ConvertToLiteTensor(graph, node_t->outputIndex, outputs);
}

STATUS NodeInferShape::ConvertToLiteTensor(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
                                           const std::vector<uint32_t> &tensor_indexes,
                                           std::vector<lite::Tensor *> *tensors) {
  // Turn the meta-graph tensors referenced by `tensor_indexes` into heap
  // lite::Tensor / lite::TensorList objects (data copied in when present).
  // Tensors appended to `tensors` are owned by the caller, who must free
  // them on both success and failure.
  MS_ASSERT(meta_graph != nullptr);
  MS_ASSERT(tensors != nullptr);
  for (auto index : tensor_indexes) {
    auto tensor_t = meta_graph->allTensors.at(index).get();
    auto tensor_shape = tensor_t->dims;
    auto tensor_category = lite::TensorCategory(tensor_t->nodeType, tensor_t->dims.size(), TypeId(tensor_t->dataType),
                                                tensor_t->data.size());
    lite::Tensor *tensor = nullptr;
    if (tensor_t->dataType != kObjectTypeTensorType) {
      tensor =
        new (std::nothrow) lite::Tensor(TypeId(tensor_t->dataType), tensor_shape, tensor_t->format, tensor_category);
    } else {
      tensor = new (std::nothrow) lite::TensorList(tensor_shape, std::vector<int>(), tensor_category);
    }
    if (tensor == nullptr) {
      MS_LOG(ERROR) << "new a lite tensor failed";
      return lite::RET_ERROR;
    }
    auto tensor_size = tensor_t->data.size() * sizeof(char);
    if (tensor_size > 0) {
      if (tensor_t->dataType == kObjectTypeTensorType) {
        auto tensor_list = reinterpret_cast<lite::TensorList *>(tensor);
        if (tensor_list->Decode(reinterpret_cast<const int *>(tensor_t->data.data())) != RET_OK) {
          MS_LOG(ERROR) << "Decode tensorlist data failed";
          // Fix: `tensor` was leaked on this path (the memcpy failure path
          // below already deletes it).
          delete tensor;
          return RET_ERROR;
        }
      } else {
        auto tensor_data = new (std::nothrow) char[tensor_size];
        if (tensor_data == nullptr) {
          MS_LOG(ERROR) << "tensor_data is nullptr";
          delete tensor;
          return lite::RET_ERROR;
        }
        if (memcpy_s(tensor_data, tensor_size, tensor_t->data.data(), tensor_size) != EOK) {
          delete tensor;
          delete[](tensor_data);
          MS_LOG(ERROR) << "memcpy error: ";
          return lite::RET_ERROR;
        }
        tensor->set_data(tensor_data);
      }
    }
    tensors->emplace_back(tensor);
  }
  return lite::RET_OK;
}

STATUS NodeInferShape::SetCNodeAbstract(const std::shared_ptr<CNode> &cnode,
                                        const std::vector<lite::Tensor *> &outputs) {
  // Rebuild cnode's abstract from the inferred output tensors: a single
  // abstract when there is one output (and the node did not already carry a
  // tuple abstract), otherwise an AbstractTuple with one entry per output.
  MS_ASSERT(cnode != nullptr);
  if (outputs.empty()) {
    MS_LOG(ERROR) << "empty output_tensors";
    return RET_ERROR;
  }
  auto origin_abstract = cnode->abstract();
  if (outputs.size() == 1 && !utils::isa<abstract::AbstractTuple>(origin_abstract)) {
    auto abstract = ConvertLiteTensorToAbstract(outputs.front());
    if (abstract == nullptr) {
      return RET_ERROR;
    }
    cnode->set_abstract(abstract);
    return RET_OK;
  }
  AbstractBasePtrList abstract_list;
  for (auto *out_tensor : outputs) {
    auto abstract = ConvertLiteTensorToAbstract(out_tensor);
    if (abstract == nullptr) {
      return RET_ERROR;
    }
    abstract_list.emplace_back(abstract);
  }
  cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
  return RET_OK;
}

abstract::AbstractBasePtr NodeInferShape::ConvertLiteTensorToAbstract(lite::Tensor *tensor) {
  // Tensor lists get a dedicated conversion; plain tensors go through the
  // tensor-info helper and its ToAbstract().
  MS_ASSERT(nullptr != tensor);
  if (tensor->data_type() == kObjectTypeTensorType) {
    return ConvertTensorListToAbstract(tensor);
  }
  auto info = NewTensorInfo(tensor);
  if (info == nullptr) {
    MS_LOG(ERROR) << "new tensor::Tensor failed";
    return nullptr;
  }
  return info->ToAbstract();
}

// The abstract records the tensorlist's element type and overall shape, while
// tensor_info carries the tensorlist's serialized layout data; the two
// intentionally differ in both shape and type.
abstract::AbstractBasePtr NodeInferShape::ConvertTensorListToAbstract(lite::Tensor *tensor) {
  MS_ASSERT(nullptr != tensor);
  auto tensor_list = dynamic_cast<lite::TensorList *>(tensor);
  if (tensor_list == nullptr) {
    MS_LOG(ERROR) << "cast tensor_list failed";
    return nullptr;
  }
  std::vector<int> shape(tensor->shape());
  std::vector<int64_t> shape_vector(shape.begin(), shape.end());
  auto list_abstract =
    std::make_shared<abstract::AbstractTensor>(TypeIdToType(tensor_list->data_type()), shape_vector);
  if (list_abstract == nullptr) {
    MS_LOG(ERROR) << "new AbstractTensor failed";
    return nullptr;
  }
  // Serialize the tensorlist layout as a flat int vector:
  // [element dtype, element-shape rank, element-shape dims..., tensor count,
  //  then per member tensor: rank, dims...]
  std::vector<int> encoded;
  auto element_shape = tensor_list->element_shape();
  encoded.push_back(tensor_list->tensors_data_type());
  encoded.push_back(element_shape.size());
  encoded.insert(encoded.end(), element_shape.begin(), element_shape.end());
  encoded.push_back(tensor_list->tensors().size());
  for (auto member : tensor_list->tensors()) {
    auto member_shape = member->shape();
    encoded.push_back(member_shape.size());
    encoded.insert(encoded.end(), member_shape.begin(), member_shape.end());
  }
  std::vector<int64_t> data_shape{static_cast<int64_t>(encoded.size())};
  auto tensor_info = std::make_shared<tensor::Tensor>(kNumberTypeInt32, data_shape, encoded.data(), kNumberTypeInt32);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "new tensor::Tensor failed";
    return nullptr;
  }
  list_abstract->set_value(tensor_info);
  return list_abstract;
}
} // namespace opt
} // namespace mindspore

+ 60
- 0
mindspore/lite/tools/optimizer/graph/node_infershape.h View File

@@ -0,0 +1,60 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_NODE_INFERSHAPE_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_NODE_INFERSHAPE_H_

#include <vector>
#include <memory>
#include <string>
#include "schema/inner/model_generated.h"
#include "src/tensor.h"
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/format_utils.h"

using mindspore::lite::converter::FmkType;
namespace mindspore {
namespace opt {
class NodeInferShape {
public:
NodeInferShape() = default;
virtual ~NodeInferShape() = default;
void Init(FmkType fmk_type, bool train_flag) {
fmk_type_ = fmk_type;
train_flag_ = train_flag;
}
STATUS InferShape(const CNodePtr &cnode);
std::vector<int> GetInputShape(const CNodePtr &cnode, size_t index);
std::vector<int> GetIntVecInput(const CNodePtr &cnode, size_t index);

private:
STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *inputs);
STATUS GetCNodeConstInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *const_ms_inputs);
STATUS GetCNodeVarInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *var_ms_inputs);
lite::Tensor *GetCNodeTensorListVarInput(std::vector<int> shape, const abstract::AbstractTensorPtr &abstract_tensor);
STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *outputs);
STATUS ConvertToLiteTensor(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
const std::vector<uint32_t> &tensor_indexes, std::vector<lite::Tensor *> *tensors);
STATUS SetCNodeAbstract(const std::shared_ptr<CNode> &cnode, const std::vector<lite::Tensor *> &outputs);
abstract::AbstractBasePtr ConvertLiteTensorToAbstract(lite::Tensor *tensor);
abstract::AbstractBasePtr ConvertTensorListToAbstract(lite::Tensor *tensor);
FmkType fmk_type_{lite::converter::FmkType_MS};
bool train_flag_{false};
};
} // namespace opt
} // namespace mindspore

#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_NODE_INFERSHAPE_H_

+ 47
- 5
mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc View File

@@ -22,8 +22,8 @@

namespace mindspore::opt {
namespace {
constexpr size_t InputDoubleNum = 2;
constexpr size_t InputTripleNum = 3;
constexpr size_t kInputDoubleNum = 2;
constexpr size_t kInputTripleNum = 3;
void FetchCNodeFromMakeTuple(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *inputs) {
MS_ASSERT(anf_node != nullptr);
MS_ASSERT(inputs != nullptr);
@@ -45,14 +45,14 @@ int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraph
}
auto cnode = anf_node->cast<CNodePtr>();
if (CheckPrimitiveType(anf_node, kPrimIdentity)) {
if (cnode->size() != InputDoubleNum) {
if (cnode->size() != kInputDoubleNum) {
MS_LOG(DEBUG) << "The node inputs size is bigger than 1";
remove_cnode_.insert(anf_node);
return lite::RET_NO_CHANGE;
}
}
if (CheckPrimitiveType(anf_node, prim::kPrimDepend)) {
if (cnode->size() != InputDoubleNum) {
if (cnode->size() != kInputDoubleNum) {
MS_LOG(DEBUG) << "The node inputs size is bigger than 1";
remove_cnode_.insert(anf_node);
return lite::RET_NO_CHANGE;
@@ -106,7 +106,7 @@ int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const
return lite::RET_NO_CHANGE;
}
auto cnode = anf_node->cast<CNodePtr>();
if (cnode->inputs().size() != InputTripleNum) {
if (cnode->inputs().size() != kInputTripleNum) {
MS_LOG(ERROR) << "TupleGetItem should have 3 inputs, got " << cnode->inputs().size();
return RET_ERROR;
}
@@ -133,6 +133,45 @@ int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const
return lite::RET_OK;
}

int RemoveRedundantOpPass::RemoveDropoutOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
MS_ASSERT(anf_node != nullptr);
MS_ASSERT(manager != nullptr);
if (!utils::isa<CNodePtr>(anf_node)) {
MS_LOG(DEBUG) << "anf node is node a cnode.";
return lite::RET_NO_CHANGE;
}
auto cnode = anf_node->cast<CNodePtr>();
if (cnode->size() > kInputDoubleNum) {
MS_LOG(ERROR) << "dropout input invalid.";
return lite::RET_ERROR;
}
if (!utils::isa<abstract::AbstractTuplePtr>(anf_node->abstract())) {
MS_LOG(DEBUG) << "dropout output size is one.";
manager->Replace(anf_node, cnode->input(1));
} else {
auto node_users = manager->node_users()[anf_node];
for (auto &node_user : node_users) {
auto node = node_user.first;
if (!CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
MS_LOG(ERROR) << "dropout out node is invalid.";
return lite::RET_ERROR;
}
auto get_index_node = node->cast<CNodePtr>()->input(kInputDoubleNum)->cast<ValueNodePtr>();
if (get_index_node == nullptr) {
MS_LOG(ERROR) << "tuple get item node is invalid.";
return lite::RET_ERROR;
}
auto get_index = CastToInt(get_index_node->value()).front();
if (get_index > 0 && !manager->node_users()[node].empty()) {
MS_LOG(ERROR) << "dropout's second output is useful.";
return lite::RET_ERROR;
}
manager->Replace(node, cnode->input(1));
}
}
return lite::RET_OK;
}

bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto manager = func_graph->manager();
@@ -155,6 +194,9 @@ bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) {
if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
status = ReplaceTupleGetItem(node, manager);
}
if (CheckPrimitiveType(node, prim::kPrimDropout)) {
status = RemoveDropoutOp(node, manager);
}
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(1));
if (sub_func_graph == nullptr) {


+ 1
- 0
mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.h View File

@@ -31,6 +31,7 @@ class RemoveRedundantOpPass : public Pass {
int ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
int ReplaceUpdateStateOp(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node);
int ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
int RemoveDropoutOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
bool Run(const FuncGraphPtr &graph) override;

private:


+ 364
- 0
mindspore/lite/tools/optimizer/graph/transpose_strategy.cc View File

@@ -0,0 +1,364 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/optimizer/graph/transpose_strategy.h"
#include <algorithm>
#include <memory>
#include <vector>
#include <string>
#include <utility>
#include "ops/crop.h"
#include "ops/fusion/activation.h"
#include "ops/fusion/slice_fusion.h"
#include "ops/op_utils.h"
#include "ops/strided_slice.h"
#include "tools/converter/quant_param_holder.h"

namespace mindspore {
namespace opt {
namespace {
constexpr size_t kFirstInput = 1;
constexpr size_t kTransposePerm = 2;
constexpr size_t kOnnxStridedSlice = 6;
const std::vector<int> NH2NC = {0, 3, 1, 2};
const std::vector<int> NC2NH = {0, 2, 3, 1};
STATUS GetPostNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<AnfNodePtr> *out_nodes) {
auto manager = func_graph->manager();
if (manager == nullptr) {
manager = Manage(func_graph, true);
}
if (manager == nullptr) {
MS_LOG(ERROR) << "manager is nullptr.";
return lite::RET_ERROR;
}
auto node_users = manager->node_users()[cnode];
if (node_users.empty()) {
MS_LOG(ERROR) << "cnode is isolated.";
return lite::RET_ERROR;
}
std::transform(node_users.begin(), node_users.end(), std::back_inserter(*out_nodes),
[](const std::pair<AnfNodePtr, int> &node_user) { return node_user.first; });
return lite::RET_OK;
}
} // namespace

AnfNodePtr TransposeStrategy::TransposePairFuseWhenInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<int> &perm, bool before, size_t index) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
AnfNodePtr trans_input_node = before ? cnode->input(index) : cnode;
// judge pair transpose after insert.
if (CheckPrimitiveType(trans_input_node, prim::kPrimTranspose)) {
std::vector<int> trans_perm;
auto input_cnode = trans_input_node->cast<CNodePtr>();
if (input_cnode == nullptr) {
MS_LOG(ERROR) << "input node is invalid.";
return nullptr;
}
if (GetTransposePerm(input_cnode->input(kTransposePerm), &trans_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "transpose perm get failed.";
return nullptr;
}
if ((perm == NH2NC && trans_perm == NC2NH) || (perm == NC2NH && trans_perm == NH2NC)) {
return input_cnode->input(kFirstInput);
}
}
// insert depend on shape
return TransposeDependOnShape(func_graph, cnode, perm, before, index);
}

AnfNodePtr TransposeStrategy::TransposeDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<int> &perm, bool before, size_t index) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
AnfNodePtr trans_input_node = before ? cnode->input(index) : cnode;
auto status = TransposeInsertDependOnShape(func_graph, cnode, before, index);
if (status == lite::RET_ERROR) {
return nullptr;
} else if (status == lite::RET_NO_CHANGE) {
return before ? cnode->input(index) : cnode;
}
// insert tranpsoe
std::string trans_name =
before ? cnode->fullname_with_scope() + "_pre" + std::to_string(index - 1) : cnode->fullname_with_scope() + "_post";
auto trans_insert_node = GenTransposeNode(func_graph, trans_input_node, perm, trans_name);
auto quant_params_holder = std::make_shared<lite::QuantParamHolder>();
quant_params_holder->AddInputQuantParam(std::vector<schema::QuantParamT>(1));
quant_params_holder->AddOutputQuantParam(std::vector<schema::QuantParamT>(1));
auto trans_insert_prim = GetValueNode<PrimitivePtr>(trans_insert_node->input(0));
trans_insert_prim->AddAttr("quant_params", quant_params_holder);
return trans_insert_node;
}

bool TransposeStrategy::CanFusionIfInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
TransTypePair *trans_info, TransTypePair *trans_insert_info) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
MS_ASSERT(pre_type != nullptr && post_type != nullptr);
size_t trans_count = 0;
std::vector<AnfNodePtr> in_nodes;
for (size_t i = 1; i < cnode->size(); ++i) {
if (utils::isa<CNodePtr>(cnode->input(i))) {
in_nodes.push_back(cnode->input(i));
}
}
if (!IsInOutCanFuison(func_graph, in_nodes, &trans_count, &trans_info->pre_)) {
return false;
}
std::vector<AnfNodePtr> out_nodes;
if (GetPostNodes(func_graph, cnode, &out_nodes) != lite::RET_OK) {
return false;
}
if (!IsInOutCanFuison(func_graph, out_nodes, &trans_count, &trans_info->post_)) {
return false;
}
if (trans_info->pre_ == trans_info->post_) {
return false;
}
auto total_node_count = in_nodes.size() + out_nodes.size();
bool can_insert = trans_count > total_node_count / 2;
if (CheckPrimitiveType(cnode, prim::kPrimActivation)) {
auto prim_act = GetValueNode<std::shared_ptr<ops::Activation>>(cnode->input(0));
MS_ASSERT(prim_act != nullptr);
if (prim_act->get_activation_type() == mindspore::ActivationType::LEAKY_RELU) {
can_insert = trans_count >= total_node_count / 2;
}
}
if (CheckPrimitiveType(cnode, prim::kPrimSplit) || CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
can_insert = trans_count >= total_node_count / 2;
}
if (!can_insert) {
return can_insert;
}
DecidePreAndPostTransType(trans_info, trans_insert_info);
return can_insert;
}

STATUS TransposeStrategy::ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
auto shape = node_infer_shape_.GetInputShape(cnode, 1);
if (shape.size() != 4) {
if (cnode->size() > 2) {
shape = node_infer_shape_.GetInputShape(cnode, 2);
if (shape.size() != 4 && !shape.empty()) {
return lite::RET_NOT_SUPPORT;
}
} else {
return lite::RET_NOT_SUPPORT;
}
}
auto axis_map = GetNC2NHAxisMap();
if (CheckPrimitiveType(cnode, prim::kPrimConcat) || CheckPrimitiveType(cnode, prim::kPrimSplit)) {
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim->GetAttr(ops::kAxis) == nullptr) {
return lite::RET_NOT_SUPPORT;
}
auto axis = GetValue<int64_t>(prim->GetAttr(ops::kAxis));
auto new_axis = axis_map[axis < 0 ? axis + 4 : axis];
prim->AddAttr(ops::kAxis, MakeValue<int64_t>(new_axis));
}
if (CheckPrimitiveType(cnode, prim::kPrimCrop)) {
auto crop_prim = GetValueNode<std::shared_ptr<ops::Crop>>(cnode->input(0));
if (crop_prim == nullptr) {
return lite::RET_NULL_PTR;
}
auto axis = crop_prim->get_axis();
auto offsets = crop_prim->get_offsets();
auto new_axis = axis_map[axis < 0 ? axis + 4 : axis];
if (new_axis == 0) {
offsets = {offsets[0], offsets[2], offsets[3], offsets[1]};
} else if (new_axis == 3) {
offsets = {offsets[1], offsets[2], offsets[0]};
} else {
offsets.push_back(0);
}
crop_prim->set_offsets(offsets);
}
if (CheckPrimitiveType(cnode, prim::kPrimSliceFusion)) {
return ChangeOpSlice(func_graph, cnode);
}
if (CheckPrimitiveType(cnode, prim::kPrimStridedSlice)) {
return ChangeOpStrideSlice(func_graph, cnode);
}
return lite::RET_OK;
}

STATUS TransposeStrategy::TransposeInsertDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
bool before, size_t index) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
auto manager = func_graph->manager();
if (manager == nullptr) {
manager = Manage(func_graph, true);
}
if (manager == nullptr) {
MS_LOG(ERROR) << "manager is nullptr.";
return lite::RET_ERROR;
}
auto node_users = manager->node_users()[cnode];
if (node_users.empty()) {
MS_LOG(ERROR) << "cnode is isolated.";
return lite::RET_ERROR;
}
if (!utils::isa<CNodePtr>(node_users.front().first)) {
return lite::RET_ERROR;
}
CNodePtr base_node = before ? cnode : node_users.front().first->cast<CNodePtr>();
size_t input_index = before ? index : node_users.front().second;
auto shape = node_infer_shape_.GetInputShape(base_node, input_index);
if (!shape.empty() && shape.size() != NH2NC.size()) {
return lite::RET_NO_CHANGE;
}
return lite::RET_OK;
}

bool TransposeStrategy::IsInOutCanFuison(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes,
size_t *trans_count, FormatTransNodeType *trans_type) {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(trans_count != nullptr && trans_type != nullptr);
for (auto &node : nodes) {
if (CheckPrimitiveType(node, prim::kPrimTranspose)) {
FormatTransNodeType cur_type;
std::vector<int> perm;
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
return false;
}
if (GetTransposePerm(cnode->input(kTransposePerm), &perm) != lite::RET_OK) {
return false;
}
if (perm == NH2NC) {
cur_type = kNHWC2NCHW;
} else if (perm == NC2NH) {
cur_type = kNCHW2NHWC;
} else {
return false;
}
if (*trans_type == kNONE) {
*trans_type = cur_type;
} else if (*trans_type != cur_type) {
return false;
}
*trans_count += 1;
}
}
return true;
}

void TransposeStrategy::DecidePreAndPostTransType(TransTypePair *trans_info, TransTypePair *trans_insert_info) {
if (trans_info->pre_ == trans_info->post_) {
return;
}
if (trans_info->pre_ != kNONE && trans_info->post_ != kNONE) {
trans_insert_info->pre_ = trans_info->pre_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
trans_insert_info->post_ = trans_info->post_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
} else if (trans_info->pre_ == kNONE) {
trans_insert_info->pre_ = trans_info->post_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC;
trans_insert_info->post_ = trans_info->post_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
} else {
trans_insert_info->pre_ = trans_info->pre_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
trans_insert_info->post_ = trans_info->pre_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC;
}
}

STATUS TransposeStrategy::ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
for (size_t i = 2; i < cnode->size(); ++i) {
if (utils::isa<CNodePtr>(cnode->input(i))) {
return lite::RET_NOT_SUPPORT;
}
}
auto shape = node_infer_shape_.GetInputShape(cnode, 2);
if (shape.empty()) {
return lite::RET_NOT_SUPPORT;
}
int element_num = shape.front();
auto prim = GetValueNode<std::shared_ptr<ops::SliceFusion>>(cnode->input(0));
std::vector<int> axes;
if (prim->GetAttr(ops::kAxes) == nullptr || prim->get_axes().empty()) {
for (int index = 0; index < element_num; ++index) {
axes.push_back(index);
}
} else {
auto origin_axes = prim->get_axes();
std::transform(origin_axes.begin(), origin_axes.end(), std::back_inserter(axes),
[](int64_t v) { return static_cast<int>(v); });
}
for (size_t i = 2; i < cnode->size(); ++i) {
TransformAttrByAxes(func_graph, cnode, i, axes);
}
auto tmp_axes = TransformOpAxesAttr(axes);
std::vector<int64_t> new_axes(tmp_axes.begin(), tmp_axes.end());
prim->set_axes(new_axes);
return lite::RET_OK;
}

STATUS TransposeStrategy::ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
if (cnode->size() != kOnnxStridedSlice) {
return lite::RET_NOT_SUPPORT;
}
for (size_t i = 2; i < cnode->size(); ++i) {
if (utils::isa<CNodePtr>(cnode->input(i))) {
return lite::RET_NOT_SUPPORT;
}
}
std::vector<int> axes = node_infer_shape_.GetIntVecInput(cnode, kOnnxStridedSlice - 2);
if (axes.empty()) {
MS_LOG(ERROR) << "strided slice input invalid.";
return lite::RET_ERROR;
}
for (size_t index = 2; index < cnode->size(); ++index) {
if (index == 4) {
continue;
}
TransformAttrByAxes(func_graph, cnode, index, axes);
}
auto cur_axes = TransformOpAxesAttr(axes);
auto param_node = BuildIntVecParameterNode(func_graph, cur_axes, cnode->input(4)->fullname_with_scope());
func_graph->manager()->Replace(cnode->input(4), param_node);
return lite::RET_OK;
}

void TransposeStrategy::TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index,
const std::vector<int> &axes) {
if (cnode == nullptr || input_index >= cnode->size() || axes.empty()) {
return;
}
auto axis_map = GetNC2NHAxisMap();
auto origin_input = node_infer_shape_.GetIntVecInput(cnode, input_index);
if (origin_input.size() != axes.size()) {
return;
}
std::vector<int> cur_input;
for (int dim = 0; dim < 4; ++dim) {
for (size_t index = 0; index < axes.size(); ++index) {
int nhwc_dim = axis_map[axes[index] < 0 ? axes[index] + 4 : axes[index]];
if (nhwc_dim == dim) {
cur_input.push_back(origin_input[index]);
}
}
}
auto param_node = BuildIntVecParameterNode(func_graph, cur_input, cnode->input(input_index)->fullname_with_scope());
func_graph->manager()->Replace(cnode->input(input_index), param_node);
}

std::vector<int> TransposeStrategy::TransformOpAxesAttr(const std::vector<int> &origin_axes) {
auto axis_map = GetNC2NHAxisMap();
std::vector<int> cur_axis;
for (size_t i = 0; i < origin_axes.size(); ++i) {
cur_axis.push_back(axis_map[origin_axes[i] < 0 ? origin_axes[i] + 4 : origin_axes[i]]);
}
std::sort(cur_axis.begin(), cur_axis.end());
return cur_axis;
}
} // namespace opt
} // namespace mindspore

+ 65
- 0
mindspore/lite/tools/optimizer/graph/transpose_strategy.h View File

@@ -0,0 +1,65 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_TRANSPOSE_STRATEGY_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_TRANSPOSE_STRATEGY_H_

#include <vector>
#include <memory>
#include <string>
#include "schema/inner/model_generated.h"
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/format_utils.h"
#include "tools/optimizer/graph/node_infershape.h"

using mindspore::lite::converter::FmkType;
namespace mindspore {
namespace opt {
// Helper used by the format-unification passes: it inserts transpose nodes
// around an op, fuses transpose pairs that cancel out, and rewrites
// axis-dependent attributes/inputs when an op's layout is switched between
// NCHW and NHWC.
class TransposeStrategy {
 public:
  TransposeStrategy() = default;
  ~TransposeStrategy() = default;
  // Records the source framework and training flag and forwards them to the
  // internal shape-inference helper. Call before any other method.
  void Init(FmkType fmk_type, bool train_flag) {
    fmk_type_ = fmk_type;
    train_flag_ = train_flag;
    node_infer_shape_.Init(fmk_type, train_flag);
  }
  // Inserts a transpose with `perm` before input `index` (or after the node
  // when `before` is false), fusing with an adjacent transpose if the pair
  // cancels. NOTE(review): parameter `code` looks like a typo for `cnode` —
  // confirm against the implementation.
  AnfNodePtr TransposePairFuseWhenInsert(const FuncGraphPtr &func_graph, const CNodePtr &code,
                                         const std::vector<int> &perm, bool before, size_t index);
  // Same insertion, but the decision depends on the tensor shape; used when
  // inserted transposes must stay removable (shape-only runs).
  AnfNodePtr TransposeDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm,
                                    bool before, size_t index);
  // Returns true when wrapping `cnode` with transposes would still let the
  // surrounding transposes fuse away; fills in the pre/post transpose kinds.
  bool CanFusionIfInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_info,
                         TransTypePair *trans_insert_info);
  // Rewrites axis-sensitive attributes/inputs of `cnode` after a layout swap.
  STATUS ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode);

 private:
  STATUS TransposeInsertDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode, bool before, size_t index);
  bool IsInOutCanFuison(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes, size_t *trans_count,
                        FormatTransNodeType *trans_type);
  void DecidePreAndPostTransType(TransTypePair *trans_info, TransTypePair *trans_insert_info);
  // Specializations of ChangeOpAxis for slice-like ops.
  STATUS ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
  STATUS ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
  void TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index,
                           const std::vector<int> &axes);
  // Maps NCHW axis indices to NHWC ones (4-D assumption) and sorts them.
  std::vector<int> TransformOpAxesAttr(const std::vector<int> &origin_axes);
  FmkType fmk_type_{lite::converter::FmkType_MS};
  bool train_flag_{false};
  NodeInferShape node_infer_shape_;
};
}  // namespace opt
}  // namespace mindspore

#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_TRANSPOSE_STRATEGY_H_

+ 785
- 0
mindspore/lite/tools/optimizer/graph/unify_format_pass.cc View File

@@ -0,0 +1,785 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/optimizer/graph/unify_format_pass.h"
#include <unordered_map>
#include <utility>
#include "ops/op_utils.h"
#include "src/common/common.h"
#include "src/common/utils.h"
#include "tools/anf_exporter/anf_exporter.h"

using mindspore::lite::NCHW_SHAPE;
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kNCHWDimNumber = 4;
// Permutation converting NHWC data to NCHW layout, and its inverse.
const std::vector<int> NH2NC = {0, 3, 1, 2};
const std::vector<int> NC2NH = {0, 2, 3, 1};
// Nodes that merely forward or bundle tensors (tuple plumbing, control-flow
// glue); they carry no layout semantics and are skipped by the pass.
bool IsSpecialType(const CNodePtr &cnode) {
  return CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) ||
         CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || CheckPrimitiveType(cnode, kPrimMakeTupleV2) ||
         CheckPrimitiveType(cnode, prim::kPrimReturn);
}
}  // namespace

// Decides, per originating framework, which transpose pair (pre_/post_) must
// surround `cnode` to unify its layout. Leaves `trans_info` untouched (both
// ends kNONE) when no transpose is required.
void UnifyFormatPass::GetTransNodeFormatType(const CNodePtr &cnode, TransTypePair *trans_info) {
  MS_ASSERT(cnode != nullptr);
  auto prim_node = cnode->input(0);
  auto prim = GetValueNode<PrimitivePtr>(prim_node);
  MS_ASSERT(prim != nullptr);
  auto &specify_nhwc_op_map = GetNHWCOpMap();
  auto &specify_nchw_op_map = GetNCHWOpMap();
  if (fmk_type_ == lite::converter::FmkType_TFLITE) {
    // TFLITE: only ops registered as NCHW-requiring get wrapped (the rest are
    // presumably already in the target layout — confirm against converter).
    if (specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
      return;
    }
    trans_info->pre_ = kNHWC2NCHW;
    trans_info->post_ = kNCHW2NHWC;
  } else if (fmk_type_ == lite::converter::FmkType_TF) {
    // TF ops may arrive in either layout; the recorded node format decides.
    if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end() && GetFormat(cnode) == NCHW) {
      trans_info->pre_ = kNCHW2NHWC;
      trans_info->post_ = kNHWC2NCHW;
    }
    if (specify_nchw_op_map.find(prim->name()) != specify_nchw_op_map.end()) {
      trans_info->pre_ = kNHWC2NCHW;
      trans_info->post_ = kNCHW2NHWC;
    }
  } else {
    // Remaining frameworks (e.g. caffe/onnx/ms).
    if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end()) {
      // ONNX nodes explicitly tagged NHWC already match; nothing to insert.
      if (fmk_type_ == lite::converter::FmkType_ONNX && prim->GetAttr(ops::kFormat) != nullptr &&
          GetValue<int64_t>(prim->GetAttr(ops::kFormat)) == NHWC) {
        return;
      }
      trans_info->pre_ = kNCHW2NHWC;
      trans_info->post_ = kNHWC2NCHW;
    }
  }
}

bool UnifyFormatPass::TransTransFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) || !CheckPrimitiveType(cnode->input(1), prim::kPrimTranspose)) {
return false;
}
std::vector<int> post_perm;
if (GetTransposePerm(cnode->input(2), &post_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "get tanspose perm failed.";
return false;
}
std::vector<int> pre_perm;
auto pre_node = cnode->input(1);
auto pre_cnode = pre_node->cast<CNodePtr>();
if (pre_cnode == nullptr) {
return false;
}
if (GetTransposePerm(pre_cnode->input(2), &pre_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "get tanspose perm failed.";
return false;
}
if ((pre_perm == NH2NC && post_perm == NC2NH) || (pre_perm == NC2NH && post_perm == NH2NC)) {
func_graph->manager()->Replace(cnode, pre_cnode->input(1));
return true;
}
return false;
}

// For a transpose `cnode`, scans its users: any user that is the inverse
// transpose is removed by rewiring it to `cnode`'s input, cancelling the pair.
// No-op for non-transpose nodes.
STATUS UnifyFormatPass::PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  if (!CheckPrimitiveType(cnode, prim::kPrimTranspose)) {
    return lite::RET_OK;
  }
  std::vector<int> cur_perm;
  if (GetTransposePerm(cnode->input(2), &cur_perm) != lite::RET_OK) {
    MS_LOG(ERROR) << "get transpose perm failed.";
    return lite::RET_ERROR;
  }
  // Copy the user set: Replace() below mutates the manager's user records.
  auto node_users = func_graph->manager()->node_users()[cnode];
  for (auto &node_user : node_users) {
    auto post_node = node_user.first;
    if (CheckPrimitiveType(post_node, prim::kPrimTranspose)) {
      std::vector<int> post_trans_perm;
      auto post_trans_node = post_node->cast<CNodePtr>();
      if (GetTransposePerm(post_trans_node->input(2), &post_trans_perm) != lite::RET_OK) {
        MS_LOG(ERROR) << "get post transpose node perm failed.";
        return lite::RET_ERROR;
      }
      // Inverse perms cancel: users of the post transpose read cnode's input.
      if ((cur_perm == NH2NC && post_trans_perm == NC2NH) || (cur_perm == NC2NH && post_trans_perm == NH2NC)) {
        func_graph->manager()->Replace(post_node, cnode->input(1));
      }
    }
  }
  return lite::RET_OK;
}

// Creates (or fuses) a transpose with permutation `perm` on one edge of
// `cnode`: ahead of input `index` when `before` is true, otherwise after the
// node's output. Newly created transposes are shape-inferred immediately and,
// in reset mode, recorded so they can be stripped again later.
STATUS UnifyFormatPass::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm,
                                    bool before, size_t index) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  AnfNodePtr new_input = nullptr;
  if (need_reset_) {
    // Shape-only mode: insertion must stay reversible, so no pair fusion.
    new_input = transpose_strategy_.TransposeDependOnShape(func_graph, cnode, perm, before, index);
  } else {
    new_input = transpose_strategy_.TransposePairFuseWhenInsert(func_graph, cnode, perm, before, index);
  }
  if (new_input == nullptr) {
    MS_LOG(ERROR) << "generate a transpose node failed.";
    return lite::RET_ERROR;
  }
  // The strategy may hand back the original edge/node: nothing to insert.
  if (new_input == cnode->input(index) || new_input == cnode) {
    return lite::RET_OK;
  } else if (utils::isa<CNodePtr>(new_input)) {
    auto new_cnode_input = new_input->cast<CNodePtr>();
    int status = lite::RET_OK;
    if (CheckPrimitiveType(new_cnode_input, prim::kPrimTranspose)) {
      if (need_reset_) {
        // Bookkeep the inserted transpose so ResetFuncGraph can remove it.
        if (before) {
          pre_insert_trans_.insert(new_cnode_input);
        } else {
          post_insert_trans_.insert(new_cnode_input);
        }
      }
      status = node_infer_shape_.InferShape(new_cnode_input);
    }
    if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
      MS_LOG(ERROR) << "infer shape failed.";
      return lite::RET_ERROR;
    }
  }
  auto manager = func_graph->manager();
  MS_ASSERT(manager != nullptr);
  auto tr = manager->Transact();
  if (before) {
    // Splice the transpose onto the chosen input edge.
    tr.SetEdge(cnode, index, new_input);
    tr.Commit();
  } else {
    // Redirect all users of cnode to the new output-side transpose, then try
    // to cancel it against any following inverse transpose.
    func_graph->manager()->Replace(cnode, new_input);
    if (!need_reset_ && PostTransposeFusion(func_graph, new_input->cast<CNodePtr>()) != lite::RET_OK) {
      MS_LOG(ERROR) << "post transpose fusion failed.";
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}

// Inserts a transpose with `perm` ahead of the inputs of `cnode` that the
// NHWC/NCHW op registries mark as layout-sensitive. An empty registry entry
// means "all inputs" (with a special case for nearest-neighbor ResizeGrad,
// where only input 1 is transposed).
STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                           const std::vector<int> &perm) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  auto prim_node = cnode->input(0);
  auto prim = GetValueNode<PrimitivePtr>(prim_node);
  MS_ASSERT(prim != nullptr);
  auto &specify_nhwc_op_map = GetNHWCOpMap();
  auto &specify_nchw_op_map = GetNCHWOpMap();
  if (specify_nhwc_op_map.find(prim->name()) == specify_nhwc_op_map.end() &&
      specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
    MS_LOG(ERROR) << "op don't meet nhwc condition.";
    return lite::RET_ERROR;
  }
  // Indices of inputs needing a transpose, taken from whichever registry
  // contains this op.
  std::vector<size_t> insert_index = specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()
                                       ? specify_nhwc_op_map.at(prim->name())
                                       : specify_nchw_op_map.at(prim->name());
  if (insert_index.empty()) {
    if (CheckPrimitiveType(cnode, prim::kPrimResizeGrad) && prim->GetAttr(ops::kMethod) != nullptr &&
        GetValue<int64_t>(prim->GetAttr(ops::kMethod)) == static_cast<int64_t>(mindspore::ResizeMethod::NEAREST)) {
      insert_index.push_back(1);
    } else {
      // Default: transpose every real input (input 0 is the primitive).
      for (size_t i = 1; i < cnode->size(); ++i) {
        insert_index.push_back(i);
      }
    }
  }
  for (auto &index : insert_index) {
    if (GenNewInput(func_graph, cnode, perm, true, index) != lite::RET_OK) {
      MS_LOG(ERROR) << "generate a new input failed.";
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}

// Overload used by transpose-count reduction: only proceeds when the strategy
// predicts the surrounding transposes can fuse away. On success it rewrites
// the op's axis-dependent attributes, inserts the pre-transposes on all real
// inputs (descending into MakeTuple inputs), and re-infers the node's shape.
// Returns RET_NO_CHANGE when the node should be left as-is.
STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                           TransTypePair *trans_insert_info) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  MS_ASSERT(trans_insert_info != nullptr);
  TransTypePair trans_info;
  // Temporarily flatten tuple/monad plumbing for the fusibility analysis,
  // then restore the original inputs.
  auto origin_inputs = cnode->inputs();
  lite::AnfExporter::RemoveIfMakeTuple(cnode);
  RemoveIfMonad(cnode);
  if (!transpose_strategy_.CanFusionIfInsert(func_graph, cnode, &trans_info, trans_insert_info)) {
    cnode->set_inputs(origin_inputs);
    return lite::RET_NO_CHANGE;
  }
  cnode->set_inputs(origin_inputs);
  auto status = transpose_strategy_.ChangeOpAxis(func_graph, cnode);
  if (status == lite::RET_NOT_SUPPORT) {
    return lite::RET_NO_CHANGE;
  } else if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "change op attr failed.";
    return lite::RET_ERROR;
  }
  auto before_perm = trans_insert_info->pre_ == kNHWC2NCHW ? NH2NC : NC2NH;
  for (size_t i = 1; i < cnode->size(); ++i) {
    if (IsMonadNode(cnode->input(i))) {
      continue;
    }
    if (CheckPrimitiveType(cnode->input(i), prim::kPrimMakeTuple) ||
        CheckPrimitiveType(cnode->input(i), kPrimMakeTupleV2)) {
      // Transpose each element feeding the MakeTuple instead of the tuple.
      auto input_make_tuple = cnode->input(i)->cast<CNodePtr>();
      MS_ASSERT(input_make_tuple != nullptr);
      for (size_t j = 1; j < input_make_tuple->size(); ++j) {
        if (GenNewInput(func_graph, input_make_tuple, before_perm, true, j) != lite::RET_OK) {
          MS_LOG(ERROR) << "generate a new input failed.";
          return lite::RET_ERROR;
        }
      }
      continue;
    }
    if (GenNewInput(func_graph, cnode, before_perm, true, i) != lite::RET_OK) {
      MS_LOG(ERROR) << "generate a new input failed.";
      return lite::RET_ERROR;
    }
  }
  // Inputs changed layout, so the node's own shape must be re-derived.
  status = node_infer_shape_.InferShape(cnode);
  if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
    MS_LOG(ERROR) << "infer shape failed.";
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}

// Inserts a transpose with `perm` after `cnode`'s output. For tuple outputs,
// each TupleGetItem user gets its own post-transpose; in training graphs a
// missing TupleGetItem is synthesized (index 0) and unwrapped again afterwards.
STATUS UnifyFormatPass::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                            const std::vector<int> &perm) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  if (!cnode->abstract()->isa<abstract::AbstractTuple>()) {
    // Single output: GenNewInput with before=false replaces cnode's users.
    if (GenNewInput(func_graph, cnode, perm, false) != lite::RET_OK) {
      MS_LOG(ERROR) << "generate a new input failed.";
      return lite::RET_ERROR;
    }
  } else {
    auto node_users = func_graph->manager()->node_users()[cnode];
    for (auto &node_user : node_users) {
      auto post_node = node_user.first;
      CNodePtr tuple_get_item = nullptr;
      if (!CheckPrimitiveType(post_node, prim::kPrimTupleGetItem)) {
        if (!train_flag_) {
          MS_LOG(ERROR) << "post node is invalid.";
          return lite::RET_ERROR;
        } else {
          // Training graphs may consume the tuple directly; materialize a
          // TupleGetItem so the transpose has a single-tensor anchor.
          tuple_get_item = GenTupleGetItemNode(func_graph, cnode, 0);
          post_node = tuple_get_item;
          func_graph->manager()->Replace(cnode, tuple_get_item);
        }
      }
      // Skip users that became orphaned by earlier replacements.
      if (func_graph->manager()->node_users()[post_node].empty()) {
        continue;
      }
      auto post_cnode = post_node->cast<CNodePtr>();
      if (GenNewInput(func_graph, post_cnode, perm, false) != lite::RET_OK) {
        MS_LOG(ERROR) << "generate a new input failed.";
        return lite::RET_ERROR;
      }
      if (tuple_get_item != nullptr) {
        // Drop the temporary TupleGetItem once the transpose is wired in.
        func_graph->manager()->Replace(tuple_get_item, tuple_get_item->input(1));
      }
    }
  }
  return lite::RET_OK;
}

// For non-TF/TFLITE frameworks, converts 4-D graph-input parameters feeding
// `cnode` from NCHW to NHWC: the parameter's declared shape is permuted and a
// NHWC->NCHW transpose is inserted so downstream consumers still see NCHW.
STATUS UnifyFormatPass::HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  // TF/TFLITE inputs are already NHWC; nothing to do.
  if (fmk_type_ == lite::converter::FmkType_TF || fmk_type_ == lite::converter::FmkType_TFLITE) {
    return lite::RET_NO_CHANGE;
  }
  for (size_t i = 1; i < cnode->size(); ++i) {
    auto node = cnode->input(i);
    if (!utils::isa<ParameterPtr>(node)) {
      continue;
    }
    auto param_node = node->cast<ParameterPtr>();
    // Parameters with defaults are weights, not graph inputs.
    if (param_node->has_default()) {
      continue;
    }
    auto abstract_base = param_node->abstract();
    if (abstract_base == nullptr) {
      MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
      return lite::RET_ERROR;
    }
    if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
      MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name();
      return lite::RET_ERROR;
    }
    auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
    if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
      MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
      return lite::RET_ERROR;
    }
    auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
    // Only 4-D inputs carry an NCHW/NHWC distinction.
    if (shape_vector.size() != kNCHWDimNumber) {
      continue;
    }
    // Special case: a sole ONNX input shaped [N, -1, H, 3] is presumably
    // already NHWC (trailing channel 3, dynamic dim 1) — skip it. TODO confirm.
    if (func_graph->get_inputs().size() == 1 && fmk_type_ == lite::converter::FmkType_ONNX && shape_vector[3] == 3 &&
        shape_vector[1] == -1) {
      continue;
    }
    // Re-declare the input as NHWC...
    std::vector<int64_t> new_dims = {shape_vector[NCHW_SHAPE::NCHW_N], shape_vector[NCHW_SHAPE::NCHW_H],
                                     shape_vector[NCHW_SHAPE::NCHW_W], shape_vector[NCHW_SHAPE::NCHW_C]};
    abstract_tensor->set_shape(std::make_shared<abstract::Shape>(new_dims));
    // ...and restore NCHW for existing consumers via an NHWC->NCHW transpose.
    auto trans_cnode = GenTransposeNode(func_graph, param_node, NH2NC, param_node->fullname_with_scope() + "_pre");
    if (trans_cnode == nullptr) {
      MS_LOG(ERROR) << "generate a transpose node failed.";
      return lite::RET_ERROR;
    }
    auto status = node_infer_shape_.InferShape(trans_cnode);
    if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
      MS_LOG(ERROR) << "infer shape failed.";
      return lite::RET_ERROR;
    }
    func_graph->manager()->Replace(param_node, trans_cnode);
  }
  return lite::RET_OK;
}

// Per-node driver: marks the node handled (kTransDone), determines whether it
// needs a transpose wrap, and either just re-infers its shape (no wrap) or
// inserts the pre/post transposes around it. Subgraph-parameter inputs are
// temporarily swapped for their outer nodes so inference sees real shapes.
STATUS UnifyFormatPass::HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  auto prim_node = cnode->input(0);
  auto prim = GetValueNode<PrimitivePtr>(prim_node);
  MS_ASSERT(prim != nullptr);
  // Idempotence guard: skip nodes already processed in a previous visit.
  if (prim->GetAttr(kTransDone) != nullptr && GetValue<bool>(prim->GetAttr(kTransDone))) {
    return lite::RET_OK;
  }
  prim->AddAttr(kTransDone, MakeValue<bool>(true));
  TransTypePair trans_info;
  GetTransNodeFormatType(cnode, &trans_info);
  if (!need_reset_ && (trans_info.pre_ == kNONE || trans_info.post_ == kNONE)) {
    // No wrap needed: still try trans-trans fusion and refresh the shape.
    if (TransTransFusion(func_graph, cnode)) {
      return lite::RET_OK;
    }
    std::unordered_map<AnfNodePtr, AnfNodePtr> match;
    PreProcessFowardInsert(func_graph, cnode, &match);
    auto status = node_infer_shape_.InferShape(cnode);
    PostProcessFowardInsert(func_graph, cnode, match);
    if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
      MS_LOG(ERROR) << "infer shape failed: " << cnode->fullname_with_scope();
      return lite::RET_ERROR;
    }
    return lite::RET_NO_CHANGE;
  }
  auto before_perm = trans_info.pre_ == kNHWC2NCHW ? NH2NC : NC2NH;
  auto after_perm = trans_info.post_ == kNCHW2NHWC ? NC2NH : NH2NC;
  std::unordered_map<AnfNodePtr, AnfNodePtr> match;
  PreProcessFowardInsert(func_graph, cnode, &match);
  if (InsertPreTransNode(func_graph, cnode, before_perm) != lite::RET_OK) {
    MS_LOG(ERROR) << "insert pre node failed." << cnode->fullname_with_scope();
    return lite::RET_ERROR;
  }
  auto status = node_infer_shape_.InferShape(cnode);
  if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
    MS_LOG(ERROR) << "infer shape failed.";
    return lite::RET_ERROR;
  }
  // Restore subgraph-parameter edges before touching the output side.
  PostProcessFowardInsert(func_graph, cnode, match);
  if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
    MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}

// Temporarily rewires `cnode`'s inputs from subgraph parameters to their
// corresponding outer-graph nodes (recorded in sub_inputs_map_), so shape
// inference can see real shapes. The original edges are saved into `match`,
// keyed by the substituted node, for PostProcessFowardInsert to restore.
void UnifyFormatPass::PreProcessFowardInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                             std::unordered_map<AnfNodePtr, AnfNodePtr> *match) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name"));
  auto map_iter = sub_inputs_map_.find(graph_name);
  if (map_iter == sub_inputs_map_.end()) {
    return;
  }
  auto &inputs_map = map_iter->second;
  auto manager = func_graph->manager();
  MS_ASSERT(manager != nullptr);
  auto tr = manager->Transact();
  for (size_t index = 1; index < cnode->size(); ++index) {
    auto found = inputs_map.find(cnode->input(index));
    if (found == inputs_map.end()) {
      continue;
    }
    match->insert(std::make_pair(found->second, cnode->input(index)));
    tr.SetEdge(cnode, index, found->second);
    tr.Commit();
  }
}

void UnifyFormatPass::PostProcessFowardInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::unordered_map<AnfNodePtr, AnfNodePtr> &match) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
if (match.empty()) {
return;
}
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
auto tr = manager->Transact();
for (size_t i = 1; i < cnode->size(); ++i) {
if (match.find(cnode->input(i)) != match.end()) {
tr.SetEdge(cnode, i, match.at(cnode->input(i)));
tr.Commit();
}
if (CheckPrimitiveType(cnode->input(i), prim::kPrimTranspose)) {
auto trans_cnode = cnode->input(i)->cast<CNodePtr>();
for (size_t j = 1; j < trans_cnode->size(); ++j) {
if (match.find(trans_cnode->input(j)) == match.end()) {
continue;
}
tr.SetEdge(trans_cnode, j, match.at(trans_cnode->input(j)));
tr.Commit();
}
}
}
}

// Before descending into an If/While subgraph, aligns its parameters with the
// caller's arguments: parameters fed by a transpose get a concrete abstract
// (shape copied from the caller when inference succeeded, else {-1}); all
// other parameters are recorded in sub_inputs_map_ so they can be temporarily
// substituted by the outer node during inference.
void UnifyFormatPass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
  MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
  auto subgraph_name = GetValue<std::string>(sub_graph->get_attr("graph_name"));
  sub_inputs_map_[subgraph_name] = {};
  auto sub_inputs = sub_graph->get_inputs();
  for (auto &node : sub_inputs) {
    auto param_node = node->cast<ParameterPtr>();
    MS_ASSERT(param_node != nullptr);
    // Parameter names are assumed to end with "_<arg_index>_<suffix>"; strip
    // the last two segments to recover the argument index. TODO confirm the
    // naming convention — std::stoi throws on a non-numeric segment.
    auto node_name = node->fullname_with_scope();
    auto last_underline = node_name.find_last_of("_");
    node_name = node_name.substr(0, last_underline);
    last_underline = node_name.find_last_of("_");
    // +3 skips the primitive and the two subgraph value-node inputs of the
    // If/While cnode (arguments start at input 3, see BasicProcess).
    auto index = std::stoi(node_name.substr(last_underline + 1)) + 3;
    if (utils::isa<CNodePtr>(cnode->input(index)) && CheckPrimitiveType(cnode->input(index), prim::kPrimTranspose)) {
      // Unknown shape marker until proven inferred.
      std::vector<int> shape = {-1};
      auto trans_cnode = cnode->input(index)->cast<CNodePtr>();
      MS_ASSERT(trans_cnode != nullptr);
      auto trans_prim = GetValueNode<PrimitivePtr>(trans_cnode->input(0));
      if (trans_prim->GetAttr(kInferDone) != nullptr && GetValue<bool>(trans_prim->GetAttr(kInferDone))) {
        shape = node_infer_shape_.GetInputShape(cnode, index);
      }
      auto type = trans_cnode->abstract()->cast<abstract::AbstractTensorPtr>()->element()->GetTypeTrack();
      std::vector<int64_t> shape_vec(shape.begin(), shape.end());
      param_node->set_abstract(std::make_shared<abstract::AbstractTensor>(type, shape_vec));
    } else {
      sub_inputs_map_[subgraph_name][node] = cnode->input(index);
    }
  }
}

// After processing an If/While subgraph, renames post-inserted transposes at
// the subgraph's outputs: the transpose takes over its producer's original
// name (keeping external naming stable) and the producer gets a derived
// "*_cnode" name. Only transposes named "*_post" (inserted by this pass) are
// touched.
void UnifyFormatPass::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
  MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
  auto return_node = sub_graph->get_return();
  auto origin_input = return_node->inputs();
  lite::AnfExporter::RemoveIfDepend(return_node);
  lite::AnfExporter::RemoveIfMakeTuple(return_node);
  constexpr size_t kPostSuffixLen = 5;  // strlen("_post")
  for (size_t i = 1; i < return_node->size(); ++i) {
    if (!CheckPrimitiveType(return_node->input(i), prim::kPrimTranspose)) {
      continue;
    }
    auto node_name = return_node->input(i)->fullname_with_scope();
    // Guard the length before comparing: the previous substr(size() - 5)
    // underflowed for names shorter than "_post" and threw std::out_of_range.
    if (node_name.size() < kPostSuffixLen ||
        node_name.compare(node_name.size() - kPostSuffixLen, kPostSuffixLen, "_post") != 0) {
      continue;
    }
    auto trans_cnode = return_node->input(i)->cast<CNodePtr>();
    MS_ASSERT(trans_cnode != nullptr);
    auto trans_input = trans_cnode->input(1);
    auto trans_input_name = trans_input->fullname_with_scope();
    // Hand the producer's public name over to the transpose's input...
    if (utils::isa<ParameterPtr>(trans_input)) {
      trans_input->cast<ParameterPtr>()->set_name(node_name);
    } else if (utils::isa<CNodePtr>(trans_input)) {
      trans_input->cast<CNodePtr>()->set_fullname_with_scope(node_name);
    }
    // ...and give the transpose a name derived from its input.
    trans_input_name = trans_input_name.substr(0, trans_input_name.find_last_of("_")) + "_cnode";
    trans_cnode->set_fullname_with_scope(trans_input_name);
  }
  return_node->set_inputs(origin_input);
}

// Propagates the subgraph's output abstracts back onto the calling If/While
// cnode (as a tuple when the caller expects one) and marks the caller's
// kInferDone attr false if any output shape is still dynamic (-1) or any
// producing node has not finished inference.
void UnifyFormatPass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
  MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
  auto return_node = sub_graph->get_return();
  auto origin_inputs = return_node->inputs();
  // Flatten depend/make-tuple wrappers so each real output is an input of
  // return_node; restored below.
  lite::AnfExporter::RemoveIfDepend(return_node);
  lite::AnfExporter::RemoveIfMakeTuple(return_node);
  AbstractBasePtrList abstract_list;
  bool infer_done = true;
  for (size_t i = 1; i < return_node->size(); ++i) {
    auto abstract_base = GetCNodeInputAbstract(return_node, i);
    MS_ASSERT(abstract_base != nullptr);
    abstract_list.emplace_back(abstract_base->Clone());
    auto abstract_tensor = abstract_base->cast<abstract::AbstractTensorPtr>();
    MS_ASSERT(abstract_tensor != nullptr);
    auto shape_ptr = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape());
    MS_ASSERT(shape_ptr != nullptr);
    auto shape = shape_ptr->shape();
    // A -1 dimension means the shape is not fully known yet.
    if (std::find(shape.begin(), shape.end(), -1) != shape.end()) {
      infer_done = false;
    }
    if (utils::isa<CNodePtr>(return_node->input(i))) {
      auto input_cnode = return_node->input(i)->cast<CNodePtr>();
      // Look through TupleGetItem to the node that actually produced the value.
      if (CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) {
        input_cnode = input_cnode->input(1)->cast<CNodePtr>();
      }
      auto input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
      if (input_prim->GetAttr(kInferDone) == nullptr || !GetValue<bool>(input_prim->GetAttr(kInferDone))) {
        infer_done = false;
      }
    }
  }
  return_node->set_inputs(origin_inputs);
  if (utils::isa<abstract::AbstractTuplePtr>(cnode->abstract())) {
    cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
  } else {
    if (abstract_list.size() != 1) {
      MS_LOG(ERROR) << "cnode output is invalid.";
    }
    cnode->set_abstract(abstract_list.front());
  }
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  prim->AddAttr(kInferDone, MakeValue<bool>(infer_done));
}

// First phase of the pass: walks the graph in topological order, converting
// graph inputs (main graph only) and wrapping layout-sensitive ops with
// transposes via HandleGraphNode. If/While nodes are recursed into, with
// subgraph inputs/outputs/abstracts kept consistent around the recursion.
bool UnifyFormatPass::BasicProcess(const FuncGraphPtr &func_graph, bool main_graph) {
  MS_ASSERT(func_graph != nullptr);
  auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name"));
  auto manager = Manage(func_graph, true);
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr.";
    return false;
  }
  auto node_list = TopoSort(func_graph->get_return());
  int status;
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (IsSpecialType(cnode)) {
      continue;
    }
    if (main_graph && !need_reset_) {
      // Graph inputs only exist at top level; skip when running shape-only.
      status = HandleGraphInput(func_graph, cnode);
      if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
        return false;
      }
    }
    if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
      auto origin_inputs = cnode->inputs();
      // Arguments (inputs 3+) that were substituted for this graph's own
      // subgraph parameters are swapped back before descending.
      for (size_t i = 3; i < cnode->size(); ++i) {
        if (sub_inputs_map_.find(graph_name) != sub_inputs_map_.end() &&
            sub_inputs_map_[graph_name].find(cnode->input(i)) != sub_inputs_map_[graph_name].end()) {
          cnode->set_input(i, sub_inputs_map_[graph_name][cnode->input(i)]);
        }
      }
      // input(1)/input(2) hold the two branch/body subgraphs.
      auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      SetSubGraphInput(cnode, sub_func_graph);
      (void)BasicProcess(sub_func_graph, false);
      SetSubGraphOutput(cnode, sub_func_graph);
      sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      SetSubGraphInput(cnode, sub_func_graph);
      (void)BasicProcess(sub_func_graph, false);
      SetSubGraphOutput(cnode, sub_func_graph);
      // The caller's abstract comes from the second (body/else) subgraph.
      SetSubGraphAbstract(cnode, sub_func_graph);
      cnode->set_inputs(origin_inputs);
      continue;
    }
    status = HandleGraphNode(func_graph, cnode);
    if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
      return false;
    }
  }
  return true;
}

// Second phase: for ops that tolerate either layout (GetDynamicFormatOpList),
// tries to run them in the other layout when that lets surrounding transposes
// fuse away, inserting the matching pre/post transposes. Recurses into
// If/While subgraphs the same way BasicProcess does.
bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name"));
  auto manager = Manage(func_graph, true);
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr.";
    return false;
  }
  auto node_list = TopoSort(func_graph->get_return());
  int status;
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (IsSpecialType(cnode)) {
      continue;
    }
    if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
      auto origin_inputs = cnode->inputs();
      // Restore substituted arguments (inputs 3+) before descending.
      for (size_t i = 3; i < cnode->size(); ++i) {
        if (sub_inputs_map_.find(graph_name) != sub_inputs_map_.end() &&
            sub_inputs_map_[graph_name].find(cnode->input(i)) != sub_inputs_map_[graph_name].end()) {
          cnode->set_input(i, sub_inputs_map_[graph_name][cnode->input(i)]);
        }
      }
      auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      SetSubGraphInput(cnode, sub_func_graph);
      (void)DecreaseTransposeForSingleOp(sub_func_graph);
      SetSubGraphOutput(cnode, sub_func_graph);
      sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      SetSubGraphInput(cnode, sub_func_graph);
      (void)DecreaseTransposeForSingleOp(sub_func_graph);
      SetSubGraphOutput(cnode, sub_func_graph);
      SetSubGraphAbstract(cnode, sub_func_graph);
      cnode->set_inputs(origin_inputs);
      continue;
    }
    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    MS_ASSERT(prim != nullptr);
    // Only ops declared layout-agnostic are candidates for re-layout.
    if (!lite::IsContain(GetDynamicFormatOpList(), prim->name())) {
      continue;
    }
    TransTypePair trans_insert_info;
    status = InsertPreTransNode(func_graph, cnode, &trans_insert_info);
    if (status == lite::RET_NO_CHANGE) {
      continue;
    } else if (status != lite::RET_OK) {
      MS_LOG(ERROR) << "insert pre node failed.";
      return false;
    }
    // Close the layout bracket on the output side.
    auto after_perm = trans_insert_info.post_ == kNHWC2NCHW ? NH2NC : NC2NH;
    if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
      MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
      return false;
    }
  }
  return true;
}

// Used by the shape-only run: removes all bookkept pre-/post-inserted
// transposes, propagates the refreshed output shapes back onto the original
// producers, and clears the kInferDone/kTransDone markers. Recurses into
// If/While subgraphs.
bool UnifyFormatPass::ResetFuncGraph(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  auto manager = Manage(func_graph, true);
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr.";
    return false;
  }
  auto node_list = TopoSort(func_graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    // input(0) need not hold a Primitive (e.g. a func-graph call); the
    // previous unchecked dereference would crash on such nodes.
    if (prim == nullptr) {
      continue;
    }
    if (prim->GetAttr(kInferDone) != nullptr) {
      prim->EraseAttr(kInferDone);
    }
    if (prim->GetAttr(kTransDone) != nullptr) {
      prim->EraseAttr(kTransDone);
    }
    if (pre_insert_trans_.find(cnode) != pre_insert_trans_.end()) {
      // Pre-inserted transpose: simply bypass it.
      manager->Replace(node, cnode->input(1));
    }
    if (post_insert_trans_.find(cnode) != post_insert_trans_.end()) {
      // Post-inserted transpose: copy its (refreshed) output shape back to
      // the producing node before bypassing it.
      auto cnode_abstract = cnode->abstract();
      if (!utils::isa<abstract::AbstractTensorPtr>(cnode_abstract)) {
        MS_LOG(ERROR) << "abstract is not abstract tensor.";
        return false;
      }
      auto cnode_abstract_tensor = cnode_abstract->cast<abstract::AbstractTensorPtr>();
      if (!utils::isa<abstract::ShapePtr>(cnode_abstract_tensor->BuildShape())) {
        MS_LOG(ERROR) << "shape of abstract tensor should be ShapePtr.";
        return false;
      }
      auto shape_ptr = utils::cast<abstract::ShapePtr>(cnode_abstract_tensor->BuildShape());
      auto input_abstract = GetCNodeInputAbstract(cnode, 1);
      if (!utils::isa<abstract::AbstractTensorPtr>(input_abstract)) {
        MS_LOG(ERROR) << "abstract is not abstract tensor.";
        return false;
      }
      auto input_abstract_tensor = input_abstract->cast<abstract::AbstractTensorPtr>();
      input_abstract_tensor->set_shape(shape_ptr);
      manager->Replace(node, cnode->input(1));
    }
    if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
      auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
      if (sub_func_graph == nullptr) {
        return false;
      }
      (void)ResetFuncGraph(sub_func_graph);
      sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
      if (sub_func_graph == nullptr) {
        return false;
      }
      (void)ResetFuncGraph(sub_func_graph);
    }
  }
  return true;
}

// Runs the pass purely to refresh shape information: transposes required by
// the source framework are inserted without fusion (so the graph stays
// restorable) and then stripped again by ResetFuncGraph.
bool UnifyFormatPass::RunOnlyForShape(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  need_reset_ = true;
  // insert transpose for some ops whose format must be NHWC, which is depend on framework.
  // In this process, transpose op cannot be fused to restore the original graph.
  const bool inserted = BasicProcess(func_graph, true);
  if (!inserted) {
    MS_LOG(ERROR) << "run framework transpose unify failed.";
    return false;
  }
  // delete insert transpose op and update op output shape.
  const bool reset_done = ResetFuncGraph(func_graph);
  if (!reset_done) {
    MS_LOG(ERROR) << "reset func_graph failed.";
  }
  return reset_done;
}

// Pass entry point: skips graphs already processed (any primitive carrying
// kTransDone), otherwise unifies layouts framework-wise and then reduces the
// number of remaining transposes.
bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  bool already_done = false;
  for (auto &node : TopoSort(func_graph->get_return())) {
    auto prim = GetValueNode<PrimitivePtr>(node);
    if (prim != nullptr && prim->GetAttr(kTransDone) != nullptr) {
      already_done = true;
      break;
    }
  }
  if (already_done) {
    return true;
  }
  // insert transpose for some ops whose format must be NHWC, which is depend on framework.
  // In this process, tranpose can be fused, which the original graph may not be able to restored.
  if (!BasicProcess(func_graph, true)) {
    MS_LOG(ERROR) << "run framework transpose unify failed.";
    return false;
  }
  // if input's format of a certain op can be NHWC, can try transform this op to decrease the number of transpose op.
  if (!DecreaseTransposeForSingleOp(func_graph)) {
    MS_LOG(ERROR) << "run local trans insert optimizer failed.";
    return false;
  }
  return true;
}
} // namespace opt
} // namespace mindspore

+ 80
- 0
mindspore/lite/tools/optimizer/graph/unify_format_pass.h View File

@@ -0,0 +1,80 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_UNIFY_FORMAT_PASS_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_UNIFY_FORMAT_PASS_H_

#include <vector>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include "backend/optimizer/common/optimizer.h"
#include "utils/utils.h"
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/format_utils.h"
#include "tools/optimizer/graph/transpose_strategy.h"

using mindspore::lite::converter::FmkType;
namespace mindspore {
namespace opt {
// Graph pass that unifies tensor layout across the ANF graph by inserting,
// fusing and removing Transpose nodes around format-sensitive ops.
// NOTE(review): behavior is inferred from the declarations below; confirm
// details against the definitions in unify_format_pass.cc.
class UnifyFormatPass : public Pass {
public:
// Registers the pass under the name "unify_format_pass".
UnifyFormatPass() : Pass("unify_format_pass") {}
~UnifyFormatPass() override = default;
// Configures the pass and its two helper components for the given source
// framework and training mode. Must be called before Run().
void Init(FmkType fmk_type, bool train_flag) {
fmk_type_ = fmk_type;
train_flag_ = train_flag;
node_infer_shape_.Init(fmk_type, train_flag);
transpose_strategy_.Init(fmk_type, train_flag);
}
// Entry point invoked by the pass manager; returns false on failure.
bool Run(const FuncGraphPtr &func_graph) override;
// Variant that presumably only performs shape-related processing without the
// full format rewrite — TODO(review): confirm against the .cc implementation.
bool RunOnlyForShape(const FuncGraphPtr &func_graph);

private:
// Restores/clears per-run graph state; see need_reset_ below.
bool ResetFuncGraph(const FuncGraphPtr &func_graph);
// main_graph distinguishes the top-level graph from subgraphs.
bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph);
bool DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph);
// Fuses back-to-back transposes rooted at `cnode`.
bool TransTransFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
STATUS PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
// NOTE(review): `perm` is taken by value here; a const reference would avoid a
// copy, but changing it requires touching the matching .cc definition.
STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, bool before,
size_t index = 0);
// Determines the pre/post transpose formats required by `cnode`.
void GetTransNodeFormatType(const CNodePtr &cnode, TransTypePair *trans_info);
STATUS HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
STATUS HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
// Inserts a Transpose with permutation `perm` before `cnode`.
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_insert_info);
// Inserts a Transpose with permutation `perm` after `cnode`.
STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
// NOTE(review): "Foward" is a typo for "Forward" in the two methods below;
// renaming is an interface change that must be coordinated with the .cc file.
// `match` maps original nodes to their replacements during forward insertion.
void PreProcessFowardInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
std::unordered_map<AnfNodePtr, AnfNodePtr> *match);
void PostProcessFowardInsert(const FuncGraphPtr &funcgraph, const CNodePtr &cnode,
const std::unordered_map<AnfNodePtr, AnfNodePtr> &match);
// Propagate format handling into the subgraph called by `cnode`
// (inputs, outputs and abstracts respectively).
void SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
// Source framework of the model being converted; defaults to MindSpore.
FmkType fmk_type_{lite::converter::FmkType_MS};
// Set when the graph must be reset via ResetFuncGraph() after processing.
bool need_reset_{false};
// True when converting a model intended for training.
bool train_flag_{false};
// Helper that re-infers shapes for nodes touched by this pass.
NodeInferShape node_infer_shape_;
// Helper that decides where transposes are needed and with what perms.
TransposeStrategy transpose_strategy_;
// Transpose nodes inserted before / after ops by this pass, tracked so later
// steps can fuse or skip them.
std::set<AnfNodePtr> pre_insert_trans_;
std::set<AnfNodePtr> post_insert_trans_;
// Per-subgraph map from outer input nodes to the subgraph's parameters,
// keyed by subgraph name — TODO(review): confirm key semantics in the .cc.
std::unordered_map<std::string, std::unordered_map<AnfNodePtr, AnfNodePtr>> sub_inputs_map_;
};
} // namespace opt
} // namespace mindspore

#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_UNIFY_FORMAT_PASS_H_

+ 0
- 2
mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc View File

@@ -16,7 +16,6 @@
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
#include <memory>
#include "ops/fusion/conv2d_fusion.h"
#include "ops/fusion/conv2d_backprop_input_fusion.h"
#include "tools/optimizer/common/gllo_utils.h"

using mindspore::lite::converter::FmkType_CAFFE;
@@ -33,7 +32,6 @@ using mindspore::schema::QuantType_WeightQuant;
namespace mindspore::opt {
namespace {
constexpr size_t kConvWeightIndex = 2;
const PrimitivePtr kPrimConv2DBackpropInputFusion = std::make_shared<Primitive>(ops::kNameConv2DBackpropInputFusion);
} // namespace
void WeightFormatHardCodePass::SetQuantType(QuantType type) { this->quant_type = type; }
void WeightFormatHardCodePass::SetFmkType(FmkType type) { this->fmk_type = type; }


+ 0
- 2
mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc View File

@@ -17,7 +17,6 @@
#include <memory>
#include <algorithm>
#include <vector>
#include "ops/fusion/conv2d_backprop_input_fusion.h"
#include "ops/transpose.h"
#include "tools/optimizer/common/gllo_utils.h"

@@ -34,7 +33,6 @@ namespace mindspore::opt {
namespace {
constexpr size_t kFirstInputIndex = 1;
constexpr size_t kConvWeightIndex = 2;
const PrimitivePtr kPrimConv2DBackpropInputFusion = std::make_shared<Primitive>(ops::kNameConv2DBackpropInputFusion);
lite::STATUS GetTransposePerm(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm) {
MS_ASSERT(perm != nullptr);
auto src_format_str = std::string(schema::EnumNameFormat(src_format));


Loading…
Cancel
Save