| @@ -19,6 +19,13 @@ | |||
| int FusedBatchNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, | |||
| OpParameter *parameter) { | |||
| #ifdef Debug | |||
| int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); | |||
| if (check_ret != NNACL_OK) { | |||
| return check_ret; | |||
| } | |||
| #endif | |||
| for (size_t i = 0; i < inputs_size; i++) { | |||
| if (outputs_size <= i) { | |||
| break; | |||
| @@ -31,7 +38,10 @@ int FusedBatchNormInferShape(const TensorC *const *inputs, size_t inputs_size, T | |||
| outputs[5]->shape_size_ = 1; | |||
| outputs[5]->shape_[0] = 1; | |||
| } | |||
| return 0; | |||
| if (!parameter->infer_flag_) { | |||
| return NNACL_INFER_INVALID; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| REG_INFER(FusedBatchNorm, PrimType_FusedBatchNorm, FusedBatchNormInferShape) | |||
| @@ -31,6 +31,10 @@ int TensorListGetItemInferShape(const TensorC *const *inputs, size_t inputs_size | |||
| if (GetElementNum(get_index) != 1) { | |||
| return NNACL_ERR; | |||
| } | |||
| TensorC *output = outputs[0]; | |||
| if (!parameter->infer_flag_ || input0->element_num_ == 0) { | |||
| return NNACL_INFER_INVALID; | |||
| } | |||
| if (get_index->data_ == NULL) { | |||
| return NNACL_INFER_INVALID; | |||
| } | |||
| @@ -40,7 +44,6 @@ int TensorListGetItemInferShape(const TensorC *const *inputs, size_t inputs_size | |||
| } | |||
| TensorC *tensor_index = &input0->tensors_[index]; | |||
| TensorC *output = outputs[0]; | |||
| if (tensor_index->data_type_ != kTypeUnknown) { | |||
| output->data_type_ = tensor_index->data_type_; | |||
| } else { | |||
| @@ -61,21 +61,6 @@ int OutputTensor2TensorC(const std::vector<lite::Tensor *> &tensors, std::vector | |||
| return RET_OK; | |||
| } | |||
| void SetOutputTensorAttr(const std::vector<TensorC *> &tensors_in, std::vector<lite::Tensor *> *tensors_out) { | |||
| for (size_t i = 0; i < tensors_in.size(); ++i) { | |||
| if (tensors_in[i] != nullptr) { | |||
| tensors_out->at(i)->set_format(static_cast<schema::Format>(tensors_in[i]->format_)); | |||
| tensors_out->at(i)->set_data_type(static_cast<TypeId>(tensors_in[i]->data_type_)); | |||
| tensors_out->at(i)->set_shape({tensors_in[i]->shape_, tensors_in[i]->shape_ + tensors_in[i]->shape_size_}); | |||
| if (tensors_in.at(i)->data_type_ == TypeIdC::kObjectTypeTensorType) { | |||
| auto tensor_list_in = reinterpret_cast<TensorListC *>(tensors_in.at(i)); | |||
| auto tensor_list_out = reinterpret_cast<TensorList *>(tensors_out->at(i)); | |||
| tensor_list_out->set_tensors_data_type(TypeId(tensor_list_in->tensors_data_type_)); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void FreeAllTensorC(std::vector<TensorC *> *tensors_in) { | |||
| for (auto &i : *tensors_in) { | |||
| if (i == nullptr) { | |||
| @@ -26,7 +26,6 @@ namespace mindspore { | |||
| namespace lite { | |||
| int InputTensor2TensorC(const std::vector<lite::Tensor *> &tensors_in, std::vector<TensorC *> *tensors_out); | |||
| int OutputTensor2TensorC(const std::vector<lite::Tensor *> &tensors_in, std::vector<TensorC *> *tensors_out); | |||
| void SetOutputTensorAttr(const std::vector<TensorC *> &tensors_in, std::vector<lite::Tensor *> *tensors_out); | |||
| void FreeAllTensorC(std::vector<TensorC *> *tensors_in); | |||
| void FreeTensorListC(TensorListC *tensorListC); | |||
| void Tensor2TensorC(Tensor *src, TensorC *dst); | |||
| @@ -49,24 +49,23 @@ int KernelInferShape(const std::vector<lite::Tensor *> &inputs, std::vector<lite | |||
| ret = infer_shape_func(static_cast<TensorC **>(in_tensors.data()), in_tensors.size(), out_tensors.data(), | |||
| out_tensors.size(), parameter); | |||
| if (ret == RET_OK) { | |||
| for (size_t i = 0; i < out_tensors.size(); i++) { | |||
| if (reinterpret_cast<TensorListC *>(out_tensors.at(i))->data_type_ == TypeIdC::kObjectTypeTensorType) { | |||
| auto *tensor_list_c = reinterpret_cast<TensorListC *>(out_tensors.at(i)); | |||
| auto *tensor_list = reinterpret_cast<TensorList *>(outputs->at(i)); | |||
| tensor_list->set_shape({static_cast<int>(tensor_list_c->element_num_)}); | |||
| auto tensor_shape = std::vector<std::vector<int>>( | |||
| tensor_list_c->element_num_, | |||
| std::vector<int>(tensor_list_c->element_shape_, | |||
| tensor_list_c->element_shape_ + tensor_list_c->element_shape_size_)); | |||
| tensor_list->MallocTensorListData(static_cast<TypeId>(tensor_list_c->data_type_), tensor_shape); | |||
| TensorListC2TensorList(tensor_list_c, tensor_list); | |||
| } else { | |||
| TensorC2Tensor(out_tensors.at(i), outputs->at(i)); | |||
| } | |||
| for (size_t i = 0; i < out_tensors.size(); i++) { | |||
| if (out_tensors.at(i) == nullptr) { | |||
| continue; | |||
| } | |||
| if (reinterpret_cast<TensorListC *>(out_tensors.at(i))->data_type_ == TypeIdC::kObjectTypeTensorType) { | |||
| auto *tensor_list_c = reinterpret_cast<TensorListC *>(out_tensors.at(i)); | |||
| auto *tensor_list = reinterpret_cast<TensorList *>(outputs->at(i)); | |||
| tensor_list->set_shape({static_cast<int>(tensor_list_c->element_num_)}); | |||
| auto tensor_shape = std::vector<std::vector<int>>( | |||
| tensor_list_c->element_num_, | |||
| std::vector<int>(tensor_list_c->element_shape_, | |||
| tensor_list_c->element_shape_ + tensor_list_c->element_shape_size_)); | |||
| tensor_list->MallocTensorListData(static_cast<TypeId>(tensor_list_c->data_type_), tensor_shape); | |||
| TensorListC2TensorList(tensor_list_c, tensor_list); | |||
| } else { | |||
| TensorC2Tensor(out_tensors.at(i), outputs->at(i)); | |||
| } | |||
| } else { | |||
| SetOutputTensorAttr(out_tensors, outputs); | |||
| } | |||
| FreeAllTensorC(&in_tensors); | |||
| @@ -225,6 +225,7 @@ if(ENABLE_CONVERTER) | |||
| ${LITE_DIR}/tools/optimizer/common/node_pass_extends.cc | |||
| ${LITE_DIR}/tools/optimizer/common/pass_manager_extends.cc | |||
| ${LITE_DIR}/tools/optimizer/common/gllo_utils.cc | |||
| ${LITE_DIR}/tools/optimizer/common/format_utils.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/conv_biasadd_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/conv_activation_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/conv_tuple_activation_fusion.cc | |||
| @@ -271,6 +272,9 @@ if(ENABLE_CONVERTER) | |||
| ${LITE_DIR}/tools/optimizer/graph/functionalize_cond.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/inputs_adjust_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/primitive_adjust_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/unify_format_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/node_infershape.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/transpose_strategy.cc | |||
| ${LITE_DIR}/tools/common/graph_util.cc | |||
| ${LITE_DIR}/tools/common/tensor_util.cc | |||
| ${LITE_DIR}/tools/common/node_util.cc | |||
| @@ -10,7 +10,7 @@ tracking | |||
| ml_face_contour | |||
| mnet | |||
| ml_face_landmark | |||
| ml_liveness_detect_landmark | |||
| ml_liveness_detect_landmark_tmp | |||
| deconv_test_model | |||
| deconvs_model | |||
| # onnx | |||
| @@ -5,7 +5,7 @@ gender_res_large_deploy | |||
| glasses | |||
| hat | |||
| isface | |||
| ml_bank_detect_0312 | |||
| ml_bank_detect_0312_tmp | |||
| ml_face_div_parsing | |||
| ml_hardware_eyeclose | |||
| ml_ocr_detect_20200305 | |||
| @@ -26,9 +26,9 @@ hiai_face_landmark | |||
| hiai_face_pose_tuku | |||
| ml_hand_detection | |||
| ml_ocr_cn | |||
| ml_ocr_sfz_detect_0325 | |||
| ml_ocr_sfz_detect_0325_tmp | |||
| ml_hardware_liveness | |||
| ml_liveness_detect_landmark | |||
| ml_liveness_detect_landmark_tmp | |||
| ml_face_contour | |||
| 2012_ATLANTA_1class_20190621_v4.x_nomean | |||
| ml_ocr_sfz_add_final_0325 | |||
| @@ -63,7 +63,7 @@ age_new | |||
| detection_retinaface_fix | |||
| landmark | |||
| plat_isface | |||
| PoseNet_dla_17_x512 | |||
| PoseNet_dla_17_x512_tmp | |||
| ml_location_scene_division | |||
| ml_tabel_recog | |||
| ml_text_division | |||
| @@ -106,15 +106,15 @@ ml_Hand_deploy | |||
| ml_hand_3d_detection | |||
| ml_hand_3d_regression | |||
| ml_ARengine23_bodypose | |||
| ml_ocr_bank_card_detection_inception | |||
| ml_ocr_bank_card_detection_inception_tmp | |||
| ml_ocr_bank_card_recognition_fcny | |||
| hiai_cv_aestheticsEngineModel_osp | |||
| bank_card_recognition_fcny | |||
| bank_card_detection_inception | |||
| bank_card_detection_inception_tmp | |||
| ml_ocr_identify_card_fcny | |||
| ml_ocr_identify_card_detect | |||
| identify_card_detect | |||
| ml_ocr_identify_card_detect_tmp | |||
| identify_card_detect_tmp | |||
| ml_2012_ocr_rec_caffe | |||
| ml_2012_ocr_detection_caffe | |||
| ml_2012_ocr_detection_caffe_tmp | |||
| ml_face_mnet | |||
| ml_segmentation_atlanta_1 | |||
| @@ -5,7 +5,7 @@ gender_res_large_deploy 0.1 | |||
| glasses 4 | |||
| hat 1 | |||
| isface 1 | |||
| ml_bank_detect_0312 20 | |||
| ml_bank_detect_0312_tmp 20 | |||
| ml_face_div_parsing 8 | |||
| ml_hardware_eyeclose 0.1 | |||
| ml_ocr_detect_20200305 10 | |||
| @@ -25,9 +25,9 @@ hiai_face_landmark 0.2 | |||
| hiai_face_pose_tuku 1.3 | |||
| ml_hand_detection 8 | |||
| ml_ocr_cn 6 | |||
| ml_ocr_sfz_detect_0325 3 | |||
| ml_ocr_sfz_detect_0325_tmp 3 | |||
| ml_hardware_liveness 3 | |||
| ml_liveness_detect_landmark 1 | |||
| ml_liveness_detect_landmark_tmp 1 | |||
| ml_face_contour 0.5 | |||
| 2012_ATLANTA_1class_20190621_v4.x_nomean 1 | |||
| ml_ocr_sfz_add_final_0325 0.1 | |||
| @@ -62,7 +62,7 @@ age_new 22 | |||
| detection_retinaface_fix 13 | |||
| landmark 1 | |||
| plat_isface 6 | |||
| PoseNet_dla_17_x512 5 | |||
| PoseNet_dla_17_x512_tmp 5 | |||
| ml_location_scene_division 8 | |||
| ml_tabel_recog 0.1 | |||
| ml_text_division 12 | |||
| @@ -101,16 +101,16 @@ ml_Hand_deploy 4 | |||
| ml_hand_3d_detection 12 | |||
| ml_hand_3d_regression 3 | |||
| ml_ARengine23_bodypose 56 | |||
| ml_ocr_bank_card_detection_inception 20 | |||
| ml_ocr_bank_card_detection_inception_tmp 20 | |||
| ml_ocr_bank_card_recognition_fcny 0.5 | |||
| hiai_cv_aestheticsEngineModel_osp 1.5 | |||
| ml_face_hat 0.5 | |||
| bank_card_recognition_fcny 17 | |||
| bank_card_detection_inception 12 | |||
| bank_card_detection_inception_tmp 12 | |||
| ml_ocr_identify_card_fcny 0.5 | |||
| ml_ocr_identify_card_detect 2 | |||
| identify_card_detect 0.5 | |||
| ml_2012_ocr_detection_caffe 1 | |||
| ml_ocr_identify_card_detect_tmp 2 | |||
| identify_card_detect_tmp 0.5 | |||
| ml_2012_ocr_detection_caffe_tmp 1 | |||
| ml_2012_ocr_rec_caffe 0.5 | |||
| ml_lable_model_hebing_device 2 | |||
| ml_face_sex 0.5 | |||
| @@ -7,5 +7,5 @@ mtk_new_detect.tflite | |||
| mtk_pose.tflite | |||
| mtk_model_emotions_0727_nosoftmax.tflite | |||
| landmark | |||
| PoseNet_dla_17_x512 | |||
| PoseNet_dla_17_x512_tmp | |||
| plat_isface | |||
| @@ -20,7 +20,7 @@ mtk_convert_model.tflite | |||
| mtk_model_face_dress_fp16.tflite | |||
| detection_retinaface_fix | |||
| landmark | |||
| PoseNet_dla_17_x512 | |||
| PoseNet_dla_17_x512_tmp | |||
| age_new | |||
| plat_isface | |||
| Q_hand_0812.pb | |||
| @@ -81,4 +81,4 @@ posenet_mobilenet_float_075_1_default_1.tflite 395 | |||
| nasnet_mobile.tflite 1 | |||
| ml_video_edit_art_generate.onnx 0.5 | |||
| ml_video_edit_art_transfer.onnx 3 3 | |||
| ml_video_edit_enhance_update.onnx 0.5 | |||
| ml_video_edit_enhance_update_tmp.onnx 0.5 | |||
| @@ -7,7 +7,7 @@ mobilenetv2-7.onnx | |||
| shufflenet-v2-10.onnx | |||
| squeezenet1.1-7.onnx | |||
| densenet-9.onnx | |||
| ml_table_detection_fp32.onnx | |||
| ml_table_detection_fp32_tmp.onnx | |||
| ml_table_segment.onnx | |||
| googlenet-9.onnx | |||
| inception-v1-9.onnx | |||
| @@ -52,7 +52,7 @@ hdc_resnet_1w_class.onnx | |||
| ml_video_edit_imitate_filter.onnx | |||
| #ml_voice_detect.onnx #Accuracy error: 4.59655%, the result is close to 1.0e-8 except for the last one | |||
| hdc_ocr_attention.onnx | |||
| hdc_ocr_detect.onnx | |||
| hdc_ocr_detect_tmp.onnx | |||
| ml_edu_kit_hand_detection.onnx | |||
| ml_edu_kit_hand_key_position.onnx | |||
| ml_facedetector.onnx | |||
| @@ -71,8 +71,8 @@ mtk_detect_mbv1_640_480_nopostprocess_simplified_onnx.onnx;1,480,640,3 | |||
| mtk_face_features_v2.onnx;1,256,192,3 | |||
| mtk_face_recognition_v3.onnx | |||
| mtk_face_recognition_v2.onnx | |||
| ml_2012_ocr_detection.onnx | |||
| ml_video_edit_enhance_update.onnx | |||
| ml_2012_ocr_detection_tmp.onnx | |||
| ml_video_edit_enhance_update_tmp.onnx | |||
| #Harmony_Voiceprint_resnet18.onnx | |||
| bloom_hongmo_detection.onnx | |||
| bloom_hongmo_detection_tmp.onnx | |||
| Q_face_recognition.onnx | |||
| @@ -7,7 +7,7 @@ mobilenetv2-7.onnx 8 | |||
| shufflenet-v2-10.onnx 5 | |||
| squeezenet1.1-7.onnx 1 | |||
| densenet-9.onnx 6 | |||
| ml_table_detection_fp32.onnx 2 | |||
| ml_table_detection_fp32_tmp.onnx 2 | |||
| ml_table_segment.onnx 2 | |||
| googlenet-9.onnx 3 | |||
| inception-v1-9.onnx 3 | |||
| @@ -37,7 +37,7 @@ residual_distill_res34_cifar10_bs_1_update.onnx 2 | |||
| residual_distill_res50_cifar10_bs_1_update.onnx 2 | |||
| #ml_voice_detect.onnx #out of float16 range because power op | |||
| hdc_ocr_attention.onnx 1.6 | |||
| hdc_ocr_detect.onnx 30 #one of the output has small values | |||
| hdc_ocr_detect_tmp.onnx 30 #one of the output has small values | |||
| ml_edu_kit_hand_detection.onnx 2 | |||
| ml_edu_kit_hand_key_position.onnx 2 | |||
| ml_video_edit_judge.onnx 12 | |||
| @@ -71,8 +71,8 @@ mtk_detect_mbv1_640_480_nopostprocess_simplified_onnx.onnx;1,480,640,3 2 | |||
| mtk_face_features_v2.onnx;1,256,192,3 0.5 | |||
| mtk_face_recognition_v3.onnx 0.5 | |||
| mtk_face_recognition_v2.onnx 2.5 | |||
| ml_2012_ocr_detection.onnx 0.5 | |||
| ml_2012_ocr_detection_tmp.onnx 0.5 | |||
| #Harmony_Voiceprint_resnet18.onnx;1,1,200,40 4.5 | |||
| bloom_hongmo_detection.onnx 0.5 | |||
| bloom_hongmo_detection_tmp.onnx 0.5 | |||
| Q_face_recognition.onnx 2 | |||
| ml_video_edit_enhance_update.onnx 0.5 | |||
| ml_video_edit_enhance_update_tmp.onnx 0.5 | |||
| @@ -37,6 +37,7 @@ class AnfExporter { | |||
| public: | |||
| AnfExporter() = default; | |||
| virtual ~AnfExporter() = default; | |||
| void set_train_flag(bool train_flag) { train_flag_ = train_flag; } | |||
| schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false, | |||
| bool train_flag = false); | |||
| void SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||
| @@ -35,6 +35,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ../optimizer/common/node_pass_extends.cc | |||
| ../optimizer/common/pass_manager_extends.cc | |||
| ../optimizer/common/gllo_utils.cc | |||
| ../optimizer/common/format_utils.cc | |||
| ../optimizer/fusion/conv_biasadd_fusion.cc | |||
| ../optimizer/fusion/conv_activation_fusion.cc | |||
| ../optimizer/fusion/conv_tuple_activation_fusion.cc | |||
| @@ -81,6 +82,9 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ../optimizer/graph/functionalize_cond.cc | |||
| ../optimizer/graph/inputs_adjust_pass.cc | |||
| ../optimizer/graph/primitive_adjust_pass.cc | |||
| ../optimizer/graph/unify_format_pass.cc | |||
| ../optimizer/graph/node_infershape.cc | |||
| ../optimizer/graph/transpose_strategy.cc | |||
| ) | |||
| add_subdirectory(../anf_exporter anf_exporter) | |||
| @@ -61,6 +61,7 @@ | |||
| #include "tools/optimizer/graph/if_pass.h" | |||
| #include "tools/optimizer/graph/functionalize_control_op_pass.h" | |||
| #include "tools/optimizer/graph/inputs_adjust_pass.h" | |||
| #include "tools/optimizer/graph/unify_format_pass.h" | |||
| #include "tools/converter/quantizer/post_training_quantizer.h" | |||
| #include "tools/converter/quantizer/quant_cast.h" | |||
| #include "tools/converter/quantizer/weight_quantizer.h" | |||
| @@ -71,7 +72,8 @@ AnfTransform::AnfTransform() = default; | |||
| AnfTransform::~AnfTransform() = default; | |||
| int AnfTransform::AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config) { | |||
| int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::Flags *config) { | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto fusion_pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false); | |||
| // for now - training is not supporting fuse operations | |||
| @@ -106,24 +108,20 @@ int AnfTransform::AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &opti | |||
| remove_unused_cast_pass->SetFmkType(config->fmk); | |||
| fusion_pm->AddPass(remove_unused_cast_pass); | |||
| } | |||
| if (config->fmk == lite::converter::FmkType_ONNX) { | |||
| auto remove_unused_transpose_pass = std::make_shared<opt::RemoveUnusedTransposeOpPass>(); | |||
| if (remove_unused_transpose_pass == nullptr) { | |||
| MS_LOG(ERROR) << "RemoveUnusedTransposeOpPass should be specified"; | |||
| return RET_ERROR; | |||
| } | |||
| remove_unused_transpose_pass->SetFmkType(config->fmk); | |||
| fusion_pm->AddPass(remove_unused_transpose_pass); | |||
| } | |||
| fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>()); | |||
| if (!config->trainModel) { | |||
| fusion_pm->AddPass(std::make_shared<opt::MatMulAddFusion>()); | |||
| } | |||
| optimizer->AddPassManager(fusion_pm); | |||
| if (optimizer->Optimize(old_graph) == nullptr) { | |||
| MS_LOG(ERROR) << "run op fusion failed."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int AnfTransform::AddGraphPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config) { | |||
| int AnfTransform::RunGraphPass(const FuncGraphPtr &old_graph, const converter::Flags *config) { | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true); | |||
| if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF || | |||
| config->fmk == lite::converter::FmkType_ONNX) { | |||
| @@ -134,8 +132,6 @@ int AnfTransform::AddGraphPass(const std::shared_ptr<opt::GraphOptimizer> &optim | |||
| weight_format_hardcode_pass->SetFmkType(config->fmk); | |||
| weight_format_hardcode_pass->SetQuantType(config->quantType); | |||
| graph_pm->AddPass(weight_format_hardcode_pass); | |||
| auto conv1d_weight_expanding_pass = std::make_shared<opt::Conv1DWeightExpandingPass>(); | |||
| graph_pm->AddPass(conv1d_weight_expanding_pass); | |||
| auto weight_format_transform_pass = std::make_shared<opt::WeightFormatTransformPass>(); | |||
| weight_format_transform_pass->SetFmkType(config->fmk); | |||
| weight_format_transform_pass->SetQuantType(config->quantType); | |||
| @@ -144,11 +140,15 @@ int AnfTransform::AddGraphPass(const std::shared_ptr<opt::GraphOptimizer> &optim | |||
| slice_prepose_pass->SetFmkType(config->fmk); | |||
| graph_pm->AddPass(slice_prepose_pass); | |||
| optimizer->AddPassManager(graph_pm); | |||
| if (optimizer->Optimize(old_graph) == nullptr) { | |||
| MS_LOG(ERROR) << "run graph pass failed."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int AnfTransform::AddConvertPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, | |||
| const converter::Flags *config) { | |||
| int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const converter::Flags *config) { | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true); | |||
| convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>()); | |||
| if (config->fmk == lite::converter::FmkType_TFLITE) { | |||
| @@ -156,11 +156,15 @@ int AnfTransform::AddConvertPass(const std::shared_ptr<opt::GraphOptimizer> &opt | |||
| convert_pm->AddPass(std::make_shared<opt::TfliteInputsAdjustPass>()); | |||
| } | |||
| optimizer->AddPassManager(convert_pm); | |||
| if (optimizer->Optimize(old_graph) == nullptr) { | |||
| MS_LOG(ERROR) << "run graph convert pass failed."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int AnfTransform::AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, | |||
| const converter::Flags *config) { | |||
| int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const converter::Flags *config) { | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false); | |||
| const_fold_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>()); | |||
| if (!config->trainModel) { | |||
| @@ -179,6 +183,10 @@ int AnfTransform::AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &o | |||
| infershape_pass->SetFmkType(config->fmk); | |||
| const_fold_pm->AddPass(infershape_pass); | |||
| optimizer->AddPassManager(const_fold_pm); | |||
| if (optimizer->Optimize(old_graph) == nullptr) { | |||
| MS_LOG(ERROR) << "run const fold failed."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -203,12 +211,18 @@ int AnfTransform::RunAdjustPass(const FuncGraphPtr &old_graph, const converter:: | |||
| } | |||
| } | |||
| int AnfTransform::AddConv1DAdjustPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, | |||
| const converter::Flags *config) { | |||
| int AnfTransform::RunConv1DAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config) { | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto conv1d_pm = std::make_shared<opt::PassManager>("conv1d adjust pass manager", true); | |||
| conv1d_pm->AddPass(std::make_shared<opt::Conv1DInOutAdjustPass>()); | |||
| conv1d_pm->AddPass(std::make_shared<opt::SqueezeFusion>()); | |||
| auto conv1d_weight_expanding_pass = std::make_shared<opt::Conv1DWeightExpandingPass>(); | |||
| conv1d_pm->AddPass(conv1d_weight_expanding_pass); | |||
| optimizer->AddPassManager(conv1d_pm); | |||
| if (optimizer->Optimize(old_graph) == nullptr) { | |||
| MS_LOG(ERROR) << "run conv1d adjust failed."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -276,18 +290,17 @@ int AnfTransform::RunPrecedingPass(const FuncGraphPtr &old_graph, const converte | |||
| return RET_OK; | |||
| } | |||
| int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config, | |||
| const FuncGraphPtr &new_graph) { | |||
| int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) { | |||
| // quant | |||
| if (config->quantType == schema::QuantType_PostTraining) { | |||
| this->m_quantizer_ = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, config->bitNum); | |||
| this->m_quantizer_ = std::make_unique<quant::PostTrainingQuantizer>(old_graph, config->configFile, config->bitNum); | |||
| if (m_quantizer_ == nullptr) { | |||
| MS_LOG(ERROR) << "New PostTrainingQuantizer failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); | |||
| return RET_ERROR; | |||
| } | |||
| } else if (config->quantType == schema::QuantType_WeightQuant) { | |||
| this->m_quantizer_ = std::make_unique<quant::WeightQuantizer>(new_graph, *config); | |||
| this->m_quantizer_ = std::make_unique<quant::WeightQuantizer>(old_graph, *config); | |||
| if (m_quantizer_ == nullptr) { | |||
| MS_LOG(ERROR) << "New WeightQuantizer failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); | |||
| @@ -296,7 +309,7 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Fla | |||
| } | |||
| if (m_quantizer_ != nullptr) { | |||
| m_quantizer_->flags = *config; | |||
| auto status = m_quantizer_->DoQuantize(new_graph); | |||
| auto status = m_quantizer_->DoQuantize(old_graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Quant failed " << status; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| @@ -306,70 +319,72 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Fla | |||
| return RET_OK; | |||
| } | |||
| FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) { | |||
| FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) { | |||
| MS_ASSERT(nullptr != old_graph); | |||
| if (config == nullptr) { | |||
| MS_LOG(ERROR) << "config should be specified"; | |||
| return nullptr; | |||
| } | |||
| int status; | |||
| for (auto &fg : func_graphs_) { | |||
| status = RunPrecedingPass(fg, *config); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Run Preceding pass failed."; | |||
| return nullptr; | |||
| } | |||
| auto status = RunPrecedingPass(old_graph, *config); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Run Preceding pass failed."; | |||
| return nullptr; | |||
| } | |||
| status = RunAdjustPass(old_graph, config); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Run Adjust pass failed."; | |||
| return nullptr; | |||
| } | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| status = RunAdjustPass(fg, config); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Run Adjust pass failed."; | |||
| return nullptr; | |||
| } | |||
| status = AddConstFoldPass(optimizer, config); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Add const fold pass failed."; | |||
| return nullptr; | |||
| } | |||
| status = RunConstFoldPass(fg, config); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Run const fold pass failed."; | |||
| return nullptr; | |||
| } | |||
| status = AddConvertPass(optimizer, config); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Add convert pass failed."; | |||
| return nullptr; | |||
| } | |||
| status = RunConvertPass(fg, config); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Run convert pass failed."; | |||
| return nullptr; | |||
| } | |||
| status = AddFusionPass(optimizer, config); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Add fusion pass failed."; | |||
| return nullptr; | |||
| } | |||
| status = RunFusionPass(fg, config); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Run fusion pass failed."; | |||
| return nullptr; | |||
| } | |||
| status = AddGraphPass(optimizer, config); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Add graph pass failed."; | |||
| return nullptr; | |||
| status = RunConv1DAdjustPass(fg, config); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Run conv1d adjust pass failed."; | |||
| return nullptr; | |||
| } | |||
| } | |||
| status = AddConv1DAdjustPass(optimizer, config); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Add conv1d adjust pass failed."; | |||
| auto format_pass = std::make_shared<opt::UnifyFormatPass>(); | |||
| format_pass->Init(config->fmk, config->trainModel); | |||
| if (!format_pass->Run(old_graph)) { | |||
| MS_LOG(ERROR) << "Run format pass failed."; | |||
| return nullptr; | |||
| } | |||
| auto new_graph = optimizer->Optimize(old_graph); | |||
| if (new_graph == nullptr) { | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR); | |||
| return nullptr; | |||
| } | |||
| for (auto &fg : func_graphs_) { | |||
| status = RunGraphPass(fg, config); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Run convert pass failed."; | |||
| return nullptr; | |||
| } | |||
| status = DoQuantize(old_graph, config, new_graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Do Quantize failed."; | |||
| return nullptr; | |||
| status = DoQuantize(fg, config); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Do Quantize failed."; | |||
| return nullptr; | |||
| } | |||
| } | |||
| return new_graph; | |||
| return old_graph; | |||
| } | |||
| void AnfTransform::GetAllFuncGraph(const FuncGraphPtr &func_graph) { | |||
| @@ -401,15 +416,11 @@ void AnfTransform::GetAllFuncGraph(const FuncGraphPtr &func_graph) { | |||
| FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) { | |||
| GetAllFuncGraph(main_graph); | |||
| for (auto &fg : func_graphs_) { | |||
| auto new_main_graph = TransformSingleFuncGraph(fg, config); | |||
| if (new_main_graph == nullptr) { | |||
| MS_LOG(ERROR) << "TransformSingleFuncGraph failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return nullptr; | |||
| } | |||
| auto new_graph = TransformFuncGraph(main_graph, config); | |||
| if (new_graph == nullptr) { | |||
| MS_LOG(ERROR) << "optimizer failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR); | |||
| } | |||
| return main_graph; | |||
| return new_graph; | |||
| } | |||
| } // namespace mindspore::lite | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -39,19 +39,19 @@ class AnfTransform { | |||
| private: | |||
| std::unique_ptr<quant::Quantizer> m_quantizer_ = nullptr; | |||
| FuncGraphPtr TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr); | |||
| FuncGraphPtr TransformFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr); | |||
| static int AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config); | |||
| static int RunFusionPass(const FuncGraphPtr &old_graph, const converter::Flags *config); | |||
| static int AddGraphPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config); | |||
| static int RunGraphPass(const FuncGraphPtr &old_graph, const converter::Flags *config); | |||
| static int AddConvertPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config); | |||
| static int RunConvertPass(const FuncGraphPtr &old_graph, const converter::Flags *config); | |||
| static int AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config); | |||
| static int RunConstFoldPass(const FuncGraphPtr &olde_graph, const converter::Flags *config); | |||
| static int RunPrecedingPass(const FuncGraphPtr &old_graph, const converter::Flags &config); | |||
| static int AddConv1DAdjustPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config); | |||
| static int RunConv1DAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config); | |||
| static int RunAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config); | |||
| @@ -61,7 +61,7 @@ class AnfTransform { | |||
| static int RunTFAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config); | |||
| int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config, const FuncGraphPtr &new_graph); | |||
| int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config); | |||
| void GetAllFuncGraph(const FuncGraphPtr &func_graph); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -94,14 +94,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| auto old_nodes = GetGraphNodes(); | |||
| Optimizer format_trans_optimizer; | |||
| auto format_trans_pass = new (std::nothrow) FormatTransPass(); | |||
| if (format_trans_pass == nullptr) { | |||
| MS_LOG(ERROR) << "new formatTransPass failed"; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| format_trans_pass->set_quant_type(ctx.quantType); | |||
| format_trans_pass->set_fmk_type(ctx.fmk); | |||
| format_trans_optimizer.AddPass(format_trans_pass); | |||
| format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| format_trans_optimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| if (ctx.fmk != converter::FmkType_TF) { | |||
| @@ -117,11 +109,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| // init old node indices | |||
| auto old_nodes = GetGraphNodes(); | |||
| Optimizer format_trans_optimizer; | |||
| format_trans_optimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); | |||
| format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| format_trans_optimizer.AddPass(new (std::nothrow) TransOpRemovePass()); | |||
| format_trans_optimizer.AddPass(new (std::nothrow) TransOpInsertPass()); | |||
| format_trans_optimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); | |||
| format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| status = format_trans_optimizer.Run(graph_defT_); | |||
| @@ -1,225 +0,0 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include <memory> | |||
| #include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #define MATMUL_BIASADD_MATCH_PATH_LEN 2 | |||
| #define BIASADD_OP_BIAS_INDEX 1 | |||
| #define BIASADD_OP_INPUT_NUM 2 | |||
| STATUS MatMulBiasAddFusionPass::Run(MetaGraphT *graph) { return FusionPass::Run(graph); } | |||
| STATUS MatMulBiasAddFusionPass::DefinePattern() { | |||
| auto matMulOp = std::make_shared<PatternOp>(); | |||
| matMulOp->id = MATMUL_NAME; | |||
| matMulOp->types = {schema::PrimitiveType_MatMul}; | |||
| auto baOp = std::make_shared<PatternOp>(); | |||
| baOp->id = BIASADD_NAME; | |||
| baOp->types = {schema::PrimitiveType_BiasAdd}; | |||
| baOp->left = matMulOp; | |||
| std::unique_ptr<FusionPattern> fusionPattern(new (std::nothrow) FusionPattern("MatMulBiasAddFusion")); | |||
| if (fusionPattern == nullptr) { | |||
| MS_LOG(ERROR) << "new fusionPattern failed"; | |||
| return RET_ERROR; | |||
| } | |||
| fusionPattern->AddPatternOp(matMulOp); | |||
| fusionPattern->AddPatternOp(baOp); | |||
| fusionPattern->Finish(); | |||
| this->patterns.emplace_back(fusionPattern.release()); | |||
| return RET_OK; | |||
| } | |||
// Fuse one matched MatMul + BiasAdd pair: move the bias tensor onto the MatMul,
// rewrite the MatMul into a FullConnection op, delete the BiasAdd node, and
// insert Transpose nodes where the MatMul's transpose flags require it.
STATUS MatMulBiasAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName,
                                         std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
  MS_ASSERT(graph != nullptr);
  if (matchedPath.size() != MATMUL_BIASADD_MATCH_PATH_LEN) {
    MS_LOG(ERROR) << "MatMul-BiasAdd-Fusion should have two NodeIndex in matchedPair";
    return RET_PARAM_INVALID;
  }
  auto matMulPath = matchedPath[MATMUL_NAME];
  auto baPath = matchedPath[BIASADD_NAME];
  auto &matMulNode = graph->nodes.at(matMulPath->nodeIdx);
  auto &baNode = graph->nodes.at(baPath->nodeIdx);
  // can not check shape because there is no shape infer in converter
  MS_ASSERT(matMulNode != nullptr);
  MS_ASSERT(matMulNode->inputIndex.size() == 2);
  // If the BiasAdd's second (bias) tensor is not a constant tensor, don't fuse.
  auto baNodeInputIndex = baNode->inputIndex;
  if (baNodeInputIndex.size() != BIASADD_OP_INPUT_NUM) {
    MS_LOG(ERROR) << "input num is invalid! node: " << baNode->name.c_str();
    return RET_ERROR;
  }
  MS_ASSERT(graph->allTensors.size() > baNodeInputIndex.at(BIASADD_OP_BIAS_INDEX));
  const auto &baNodeBiasTensor = graph->allTensors.at(baNodeInputIndex.at(BIASADD_OP_BIAS_INDEX));
  MS_ASSERT(baNodeBiasTensor != nullptr);
  // NOTE(review): this compares the tensor's refCount against a NodeType enum
  // value; it reads as if it should test the tensor's node type instead —
  // confirm the intended semantics before touching it.
  if (baNodeBiasTensor->refCount != NodeType_ValueNode) {
    // dont fusion, return
    return RET_OK;
  }
  // 1. add biasTensor for matMul
  auto status = AddFullConnectionBiasTensor(matMulPath, baPath, graph);
  if (RET_OK != status) {
    MS_LOG(ERROR) << "AddFullConnectionBiasTensor failed, ret: " << status;
    return status;
  }
  // 2. change matmul to full connection op
  matMulNode->name += "-fc";
  std::unique_ptr<FullConnectionT> fcAttr(new (std::nothrow) FullConnectionT());
  if (fcAttr == nullptr) {
    MS_LOG(ERROR) << "new FullConnectionT node failed";
    return RET_ERROR;
  }
  fcAttr->has_bias = true;
  fcAttr->axis = 1;
  MS_ASSERT(matMulNode->primitive != nullptr);
  MS_ASSERT(matMulNode->primitive->value != nullptr);
  MS_ASSERT(matMulNode->primitive->value.AsMatMul() != nullptr);
  // Capture the transpose flags before the MatMul attr is replaced; they are
  // consumed later by InsertTransposeNode.
  transA = matMulNode->primitive->value.AsMatMul()->transpose_a;
  transB = matMulNode->primitive->value.AsMatMul()->transpose_b;
  matMulNode->primitive->value.type = schema::PrimitiveType_FullConnection;
  matMulNode->primitive->value.value = fcAttr.release();
  // 3. delete BiasAdd node
  MergeNodeAttrFromPost(matMulNode, baNode);
  status = IsolateOneWayNode(graph, baPath->nodeIdx);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "IsolateOneWayNode failed, subGraph: " << baPath->subGraphIdx << ", node: " << baPath->nodeIdx
                  << ", ret: " << status;
    return status;
  }
  // 4. addTranspose node
  status = InsertTransposeNode(graph, matMulPath);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "InsertTransposeNode failed, subGraph: " << matMulPath->subGraphIdx
                  << ", node: " << matMulPath->nodeIdx << ", ret: " << status;
    return status;
  }
  return RET_OK;
}
| STATUS MatMulBiasAddFusionPass::InsertTransposeNode(MetaGraphT *graph, const std::shared_ptr<Path> &matMulPath) { | |||
| MS_ASSERT(graph != nullptr); | |||
| MS_ASSERT(matMulPath != nullptr); | |||
| std::vector<size_t> insertNodeIdxList; | |||
| if (transA) { | |||
| insertNodeIdxList.emplace_back(0); | |||
| } | |||
| if (!transB) { | |||
| insertNodeIdxList.emplace_back(1); | |||
| } | |||
| auto matmulOpIter = graph->nodes.begin() + matMulPath->nodeIdx; | |||
| STATUS errorCode = RET_OK; | |||
| auto perm_tensor = std::make_unique<schema::TensorT>(); | |||
| perm_tensor->dataType = kNumberTypeInt32; | |||
| perm_tensor->dims = {2}; | |||
| std::vector<int> perm{1, 0}; | |||
| size_t bytes = perm.size() * sizeof(int); | |||
| perm_tensor->data.resize(bytes); | |||
| perm_tensor->name = "perm_" + std::to_string(id++); | |||
| if (memcpy_s(perm_tensor->data.data(), bytes, perm.data(), bytes) != EOK) { | |||
| MS_LOG(ERROR) << "memcpy data failed."; | |||
| return RET_ERROR; | |||
| } | |||
| size_t index = graph->allTensors.size(); | |||
| graph->allTensors.push_back(std::move(perm_tensor)); | |||
| for (auto needInsertIdx : insertNodeIdxList) { | |||
| auto transNode = std::unique_ptr<CNodeT>(new (std::nothrow) CNodeT); | |||
| if (transNode == nullptr) { | |||
| MS_LOG(ERROR) << "new TransNode failed"; | |||
| return RET_ERROR; | |||
| } | |||
| transNode->name = "transpose" + std::to_string(id++); | |||
| transNode->primitive->value.type = schema::PrimitiveType_Transpose; | |||
| int insert_num = 0; | |||
| matmulOpIter = InsertNode(graph, matmulOpIter, kBefore, needInsertIdx, std::move(transNode), &errorCode, | |||
| &insert_num, TransposeOpCopyer); | |||
| if (errorCode != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertNode failed: " << errorCode; | |||
| return errorCode; | |||
| } | |||
| for (int i = insert_num; i > 0; --i) { | |||
| (*(matmulOpIter - i))->inputIndex.push_back(index); | |||
| } | |||
| } | |||
| graph->allTensors.at(index)->refCount = insertNodeIdxList.size(); | |||
| return RET_OK; | |||
| } | |||
| #define BIASADD_WEIGHT_SHAPE_SIZE 1 | |||
| #define BIASADD_BIAS_DIM_INDEX 0 | |||
// Append the BiasAdd's bias tensor to the MatMul node's inputs so the resulting
// FullConnection op carries the bias directly, and detach it from the BiasAdd.
STATUS MatMulBiasAddFusionPass::AddFullConnectionBiasTensor(const std::shared_ptr<Path> &matMulPath,
                                                            const std::shared_ptr<Path> &baPath, MetaGraphT *graph) {
  MS_ASSERT(matMulPath != nullptr);
  MS_ASSERT(baPath != nullptr);
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(graph->nodes.size() > matMulPath->nodeIdx);
  auto &matMulNode = graph->nodes.at(matMulPath->nodeIdx);
  MS_ASSERT(matMulNode != nullptr);
  auto baNode = graph->nodes.at(baPath->nodeIdx).get();
  MS_ASSERT(baNode != nullptr);
  // check biasTensor
  auto baWeightTensorIdxes = baNode->inputIndex;
  if (baWeightTensorIdxes.size() != BIASADD_OP_INPUT_NUM) {
    MS_LOG(ERROR) << "input number is invalid! node: " << baNode->name.c_str();
    return RET_ERROR;
  }
  MS_ASSERT(graph->allTensors.size() > baWeightTensorIdxes.at(BIASADD_OP_BIAS_INDEX));
  auto &biasTensor = graph->allTensors.at(baWeightTensorIdxes.at(BIASADD_OP_BIAS_INDEX));
  MS_ASSERT(biasTensor != nullptr);
  auto biasDims = biasTensor->dims;
  // Reject a scalar bias (empty dims) that has no backing data to copy from.
  if (biasDims.empty() && biasTensor->data.data() == nullptr) {
    MS_LOG(ERROR) << "bias tensor is invalid, node: " << baNode->name.c_str();
    return RET_ERROR;
  }
  // A non-scalar bias must be exactly one-dimensional.
  if (!biasDims.empty() && biasDims.size() != BIASADD_WEIGHT_SHAPE_SIZE) {
    MS_LOG(ERROR) << "BiasAdd bias tensor should has one dimension, current number of dimension " << biasDims.size()
                  << ". or bias tensor is a scaler";
    return RET_ERROR;
  }
  // add biasTensor to matmul and remove it from the BiasAdd node's inputs
  matMulNode->inputIndex.emplace_back(baWeightTensorIdxes.at(BIASADD_OP_BIAS_INDEX));
  baNode->inputIndex.erase(baNode->inputIndex.begin() + BIASADD_OP_BIAS_INDEX);
  return RET_OK;
}
// Defaulted out-of-line destructor; the pass owns nothing beyond base-class state.
MatMulBiasAddFusionPass::~MatMulBiasAddFusionPass() = default;
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -1,75 +0,0 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_PREDICT_MATMUL_BIASADD_FUSION_PASS_H | |||
| #define MINDSPORE_PREDICT_MATMUL_BIASADD_FUSION_PASS_H | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include <utility> | |||
| #include "tools/converter/legacy_optimizer/fusion/fusion_pass.h" | |||
| #include "tools/common/graph_util.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| constexpr const char *MATMUL_NAME = "MATMUL"; | |||
// Legacy-optimizer fusion pass that rewrites a MatMul + BiasAdd pair into a
// single FullConnection node, inserting Transpose nodes where the MatMul's
// transpose flags require it.
class MatMulBiasAddFusionPass : public FusionPass {
 public:
  MatMulBiasAddFusionPass() = default;
  ~MatMulBiasAddFusionPass() override;
  // Register the MatMul -> BiasAdd pattern to match.
  STATUS DefinePattern() override;
  // Perform the fusion on one matched occurrence of the pattern.
  STATUS DoFusion(MetaGraphT *graph, const std::string &patternName,
                  std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;
  STATUS Run(MetaGraphT *graph) override;

 protected:
  // Move the BiasAdd's bias tensor onto the MatMul node as an extra input.
  static STATUS AddFullConnectionBiasTensor(const std::shared_ptr<Path> &matMulPath,
                                            const std::shared_ptr<Path> &dstPath, MetaGraphT *subGraph);
  // Insert Transpose nodes in front of the fused node's inputs as needed.
  STATUS InsertTransposeNode(MetaGraphT *subGraph, const std::shared_ptr<Path> &matMulPath);

 protected:
  bool transA = false;  // MatMul transpose_a flag, captured during fusion
  bool transB = false;  // MatMul transpose_b flag, captured during fusion
  size_t id = 0;        // running counter used to build unique node/tensor names
  // Copier used by InsertNode to clone a Transpose op skeleton for each insertion.
  OpDefCopyer TransposeOpCopyer = [](CNodeT *inOpDef) -> std::unique_ptr<CNodeT> {
    auto newOpDef = std::make_unique<schema::CNodeT>();
    if (newOpDef == nullptr) {
      MS_LOG(ERROR) << "new CNodeT failed";
      return nullptr;
    }
    newOpDef->name = inOpDef->name;
    newOpDef->quantType = inOpDef->quantType;
    newOpDef->primitive = std::make_unique<schema::PrimitiveT>();
    if (newOpDef->primitive == nullptr) {
      MS_LOG(ERROR) << "new PrimitiveT failed";
      return nullptr;
    }
    newOpDef->primitive->value.type = schema::PrimitiveType_Transpose;
    return newOpDef;
  };
};
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_PREDICT_MATMUL_BIASADD_FUSION_PASS_H | |||
| @@ -0,0 +1,184 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/common/format_utils.h" | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "ops/adam.h" | |||
| #include "ops/addn.h" | |||
| #include "ops/apply_momentum.h" | |||
| #include "ops/batch_norm.h" | |||
| #include "ops/batch_to_space.h" | |||
| #include "ops/bias_add.h" | |||
| #include "ops/concat.h" | |||
| #include "ops/crop.h" | |||
| #include "ops/depth_to_space.h" | |||
| #include "ops/fusion/activation.h" | |||
| #include "ops/fusion/add_fusion.h" | |||
| #include "ops/fused_batch_norm.h" | |||
| #include "ops/fusion/avg_pool_fusion.h" | |||
| #include "ops/fusion/conv2d_backprop_input_fusion.h" | |||
| #include "ops/fusion/conv2d_backprop_filter_fusion.h" | |||
| #include "ops/fusion/conv2d_fusion.h" | |||
| #include "ops/fusion/conv2d_transpose_fusion.h" | |||
| #include "ops/fusion/max_pool_fusion.h" | |||
| #include "ops/fusion/mul_fusion.h" | |||
| #include "ops/fusion/pow_fusion.h" | |||
| #include "ops/fusion/prelu_fusion.h" | |||
| #include "ops/fusion/slice_fusion.h" | |||
| #include "ops/fusion/topk_fusion.h" | |||
| #include "ops/eltwise.h" | |||
| #include "ops/grad/activation_grad.h" | |||
| #include "ops/grad/avg_pool_grad.h" | |||
| #include "ops/grad/batch_norm_grad.h" | |||
| #include "ops/grad/bias_add_grad.h" | |||
| #include "ops/grad/max_pool_grad.h" | |||
| #include "ops/grad/resize_grad.h" | |||
| #include "ops/instance_norm.h" | |||
| #include "ops/lrn.h" | |||
| #include "ops/maximum.h" | |||
| #include "ops/op_utils.h" | |||
| #include "ops/quant_dtype_cast.h" | |||
| #include "ops/resize.h" | |||
| #include "ops/sgd.h" | |||
| #include "ops/space_to_batch.h" | |||
| #include "ops/space_to_batch_nd.h" | |||
| #include "ops/space_to_depth.h" | |||
| #include "ops/split.h" | |||
| #include "ops/strided_slice.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
// Table of ops with NHWC-layout requirements; the index list appears to name
// the input positions that carry layout-sensitive data — confirm with callers.
static const std::unordered_map<std::string, std::vector<size_t>> NHWCOpMap = {
  {ops::kNameAdam, {10}},
  {ops::kNameApplyMomentum, {4}},
  {ops::kNameAvgPoolFusion, {1}},
  {ops::kNameAvgPoolGrad, {}},
  {ops::kNameBatchNorm, {1}},
  {ops::kNameBatchNormGrad, {1, 2}},
  {ops::kNameBatchToSpace, {1}},
  {ops::kNameBiasAdd, {1}},
  {ops::kNameBiasAddGrad, {1}},
  {ops::kNameConv2DBackpropInputFusion, {1}},
  {ops::kNameConv2DBackpropFilterFusion, {1, 2}},
  {ops::kNameConv2DFusion, {1}},
  {ops::kNameConv2dTransposeFusion, {1}},
  {ops::kNameDepthToSpace, {1}},
  {ops::kNameFusedBatchNorm, {1}},
  {ops::kNameLRN, {1}},
  {ops::kNameMaxPoolFusion, {1}},
  {ops::kNameMaxPoolGrad, {}},
  {ops::kNamePReLUFusion, {1}},
  {ops::kNameResize, {1}},
  {ops::kNameResizeGrad, {}},
  {ops::kNameSGD, {2}},
  {ops::kNameSpaceToBatch, {1}},
  {ops::kNameSpaceToBatchND, {1}},
  {ops::kNameSpaceToDepth, {1}},
  {ops::kNameTopKFusion, {1}}};
// Same structure for ops with NCHW-layout requirements.
static const std::unordered_map<std::string, std::vector<size_t>> NCHWOpMap = {{ops::kNameInstanceNorm, {1}}};
// a certain op whose input's format is not fixed.
static const std::vector<std::string> DynamicFormatOpList = {
  ops::kNameEltwise,   ops::kNameActivation, ops::kNameConcat,         ops::kNamePowFusion,      ops::kNameStridedSlice,
  ops::kNameAddFusion, ops::kNameAddN,       ops::kNameSplit,          ops::kNameSliceFusion,    ops::kNameCrop,
  ops::kNameMulFusion, ops::kNameMaximum,    ops::kNameActivationGrad, ops::kNameQuantDTypeCast};
// Axis remapping from an NCHW axis index to the corresponding NHWC axis index.
static const std::unordered_map<int, int> NC2NHAxisMap = {{0, 0}, {1, 3}, {2, 1}, {3, 2}};
// Read-only accessors exposing the file-local tables above.
const std::unordered_map<std::string, std::vector<size_t>> &GetNHWCOpMap() { return NHWCOpMap; }
const std::unordered_map<std::string, std::vector<size_t>> &GetNCHWOpMap() { return NCHWOpMap; }
const std::unordered_map<int, int> &GetNC2NHAxisMap() { return NC2NHAxisMap; }
const std::vector<std::string> &GetDynamicFormatOpList() { return DynamicFormatOpList; }
| Format GetFormat(const CNodePtr &cnode) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| auto prim_node = cnode->input(0); | |||
| MS_ASSERT(prim_node != nullptr); | |||
| auto prim = GetValueNode<PrimitivePtr>(prim_node); | |||
| MS_ASSERT(prim != nullptr); | |||
| Format format = NHWC; | |||
| if (prim->GetAttr(ops::kFormat) != nullptr) { | |||
| format = static_cast<Format>(GetValue<int64_t>(prim->GetAttr(ops::kFormat))); | |||
| } | |||
| return format; | |||
| } | |||
// Extract a constant transpose permutation from `perm_node` into `perm`.
// Returns RET_OK while leaving `perm` untouched when the node is not a
// parameter with constant default data — the permutation is simply not
// statically known; only genuinely malformed data yields RET_ERROR.
STATUS GetTransposePerm(const AnfNodePtr &perm_node, std::vector<int> *perm) {
  MS_ASSERT(perm_node != nullptr);
  if (!utils::isa<ParameterPtr>(perm_node)) {
    return lite::RET_OK;
  }
  auto perm_param = perm_node->cast<ParameterPtr>();
  if (!perm_param->has_default() || perm_param->default_param() == nullptr) {
    return lite::RET_OK;
  }
  auto tensor_info = perm_param->default_param()->cast<tensor::TensorPtr>();
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "default param is not a tensor.";
    return lite::RET_ERROR;
  }
  // Only 32-bit integer permutation tensors are supported.
  if (tensor_info->data_type() != kNumberTypeInt && tensor_info->data_type() != kNumberTypeInt32) {
    MS_LOG(ERROR) << "data type is error, which is " << tensor_info->data_type();
    return lite::RET_ERROR;
  }
  auto tensor_shape = tensor_info->shape();
  if (tensor_shape.empty()) {
    // Scalar/empty tensor: nothing to copy.
    return lite::RET_OK;
  }
  if (tensor_shape.size() > 1) {
    // A permutation must be one-dimensional.
    return lite::RET_ERROR;
  }
  perm->resize(tensor_shape[0]);
  // assumes tensor_info->Size() == tensor_shape[0] * sizeof(int) — TODO confirm
  if (memcpy_s(perm->data(), tensor_info->Size(), tensor_info->data_c(), tensor_info->Size()) != EOK) {
    MS_LOG(ERROR) << "memcpy data failed.";
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}
| void RemoveIfMonad(const CNodePtr &cnode) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| std::vector<AnfNodePtr> inputs{cnode->input(0)}; | |||
| for (size_t i = 1; i < cnode->size(); ++i) { | |||
| if (utils::isa<ValueNodePtr>(cnode->input(i))) { | |||
| auto value_node = cnode->input(i)->cast<ValueNodePtr>(); | |||
| auto value = value_node->value(); | |||
| if (value->isa<Monad>()) { | |||
| continue; | |||
| } | |||
| } | |||
| inputs.push_back(cnode->input(i)); | |||
| } | |||
| cnode->set_inputs(inputs); | |||
| } | |||
| bool IsMonadNode(const AnfNodePtr &node) { | |||
| if (!utils::isa<ValueNodePtr>(node)) { | |||
| return false; | |||
| } | |||
| auto value_node = node->cast<ValueNodePtr>(); | |||
| auto value = value_node->value(); | |||
| if (value->isa<Monad>()) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_FORMAT_UTILS_H_ | |||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_FORMAT_UTILS_H_ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| constexpr auto kInferDone = "infer_done"; | |||
| constexpr auto kTransDone = "trans_done"; | |||
// Kind of layout transpose a pass may need to insert around a node.
enum FormatTransNodeType { kNCHW2NHWC, kNHWC2NCHW, kNONE };
// Transposes required before (pre_) and after (post_) an op; defaults to none.
struct TransTypePair {
  FormatTransNodeType pre_;
  FormatTransNodeType post_;
  TransTypePair() : pre_(kNONE), post_(kNONE) {}
};
| const std::unordered_map<std::string, std::vector<size_t>> &GetNHWCOpMap(); | |||
| const std::unordered_map<std::string, std::vector<size_t>> &GetNCHWOpMap(); | |||
| const std::unordered_map<int, int> &GetNC2NHAxisMap(); | |||
| const std::vector<std::string> &GetDynamicFormatOpList(); | |||
| Format GetFormat(const CNodePtr &cnode); | |||
| STATUS GetTransposePerm(const AnfNodePtr &perm_node, std::vector<int> *perm); | |||
| void RemoveIfMonad(const CNodePtr &cnode); | |||
| bool IsMonadNode(const AnfNodePtr &node); | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_FORMAT_UTILS_H_ | |||
| @@ -22,6 +22,8 @@ | |||
| #include <string> | |||
| #include "Eigen/Core" | |||
| #include "ops/fusion/conv2d_fusion.h" | |||
| #include "ops/transpose.h" | |||
| #include "ops/tuple_get_item.h" | |||
| #include "src/common/common.h" | |||
| #include "tools/common/tensor_util.h" | |||
| #include "frontend/operator/ops.h" | |||
| @@ -1351,5 +1353,27 @@ ParameterPtr BuildFloatValueParameterNode(const FuncGraphPtr &func_graph, const | |||
| } | |||
| return param_node; | |||
| } | |||
| CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &perm, | |||
| const std::string &cnode_name) { | |||
| MS_ASSERT(func_graph != nullptr && input_node != nullptr); | |||
| auto perm_node = BuildIntVecParameterNode(func_graph, perm, cnode_name + "_perm"); | |||
| MS_ASSERT(perm_node != nullptr); | |||
| auto trans_prim = std::make_shared<ops::Transpose>(); | |||
| MS_ASSERT(trans_prim != nullptr); | |||
| auto cnode = func_graph->NewCNode(trans_prim, {input_node, perm_node}); | |||
| MS_ASSERT(cnode != nullptr); | |||
| cnode->set_fullname_with_scope(cnode_name); | |||
| return cnode; | |||
| } | |||
// Build a TupleGetItem cnode selecting output `index` of the multi-output `input`.
CNodePtr GenTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &input, size_t index) {
  MS_ASSERT(func_graph != nullptr && input != nullptr);
  auto tuple_get_item_prim = std::make_shared<ops::TupleGetItem>();
  // NOTE(review): `index` is narrowed from size_t to int here; confirm downstream
  // consumers expect a 32-bit index value rather than int64_t.
  auto second_input = NewValueNode(MakeValue<int>(index));
  auto tuple_cnode = func_graph->NewCNode(tuple_get_item_prim, {input, second_input});
  tuple_cnode->set_fullname_with_scope(input->fullname_with_scope() + "_getitem_" + std::to_string(index));
  return tuple_cnode;
}
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -25,6 +25,7 @@ | |||
| #include "ir/func_graph.h" | |||
| #include "src/common/utils.h" | |||
| #include "backend/optimizer/common/pattern_engine.h" | |||
| #include "ops/fusion/conv2d_backprop_input_fusion.h" | |||
| #include "schema/inner/model_generated.h" | |||
| #include "tools/converter/converter_context.h" | |||
| @@ -36,6 +37,7 @@ namespace mindspore { | |||
| namespace opt { | |||
| inline const PrimitivePtr kPrimMakeTupleV2 = std::make_shared<Primitive>("make_tuple"); | |||
| inline const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("Identity"); | |||
| const PrimitivePtr kPrimConv2DBackpropInputFusion = std::make_shared<Primitive>(ops::kNameConv2DBackpropInputFusion); | |||
| constexpr auto kWeightFormat = "weight_format"; | |||
| std::vector<int> CastToInt(const ValuePtr &value); | |||
| @@ -146,6 +148,11 @@ ParameterPtr BuildIntVec2DParameterNode(const FuncGraphPtr &func_graph, const st | |||
| ParameterPtr BuildFloatValueParameterNode(const FuncGraphPtr &func_graph, const float &data, | |||
| const std::string &node_name); | |||
| CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &perm, | |||
| const std::string &cnode_name); | |||
| CNodePtr GenTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &input, size_t index); | |||
| template <const PrimitivePtr *prim = nullptr> | |||
| inline bool IsSpecifiedNode(const BaseRef &n) { | |||
| if (utils::isa<AnfNodePtr>(n)) { | |||
| @@ -68,10 +68,12 @@ bool MatMulAddFusion::Run(const FuncGraphPtr &func_graph) { | |||
| (!utils::isa<Parameter>(bias_node) || !bias_node->cast<ParameterPtr>()->default_param())) { | |||
| continue; | |||
| } | |||
| matmul_cnode->add_input(bias_node); | |||
| auto manager = func_graph->manager(); | |||
| MS_ASSERT(manager != nullptr); | |||
| matmul_cnode->set_fullname_with_scope(node->fullname_with_scope()); | |||
| auto tr = manager->Transact(); | |||
| tr.AddEdge(matmul_cnode, bias_node); | |||
| tr.Commit(); | |||
| manager->Replace(node, matmul_cnode); | |||
| } | |||
| return false; | |||
| @@ -23,12 +23,17 @@ namespace { | |||
| constexpr size_t kTripleNum = 3; | |||
| constexpr size_t kConvWeightIndex = 2; | |||
| } // namespace | |||
| lite::STATUS Conv1DWeightExpandingPass::ExpandFilterShape(const tensor::TensorPtr &tensor, | |||
| const schema::Format &format) { | |||
| if (tensor == nullptr) { | |||
| return lite::RET_NULL_PTR; | |||
| lite::STATUS Conv1DWeightExpandingPass::ExpandFilterShape(const AnfNodePtr &weight_node, const schema::Format &format) { | |||
| MS_ASSERT(weight_node != nullptr); | |||
| auto weight_tensor = GetTensorInfo(weight_node); | |||
| if (weight_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "weight node must be param value."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto shape = weight_tensor->shape(); | |||
| if (shape.size() != kTripleNum) { | |||
| return lite::RET_OK; | |||
| } | |||
| auto shape = tensor->shape(); | |||
| std::vector<int64_t> new_shape(shape); | |||
| switch (format) { | |||
| case schema::Format_NCHW: | |||
| @@ -43,7 +48,13 @@ lite::STATUS Conv1DWeightExpandingPass::ExpandFilterShape(const tensor::TensorPt | |||
| MS_LOG(ERROR) << "Unsupported format."; | |||
| return RET_ERROR; | |||
| } | |||
| tensor->set_shape(new_shape); | |||
| weight_tensor->set_shape(new_shape); | |||
| if (!utils::isa<ParameterPtr>(weight_node)) { | |||
| return lite::RET_OK; | |||
| } | |||
| auto weight_param = weight_node->cast<ParameterPtr>(); | |||
| auto type = weight_tensor->data_type(); | |||
| weight_param->set_abstract(std::make_shared<abstract::AbstractTensor>(TypeIdToType(type), new_shape)); | |||
| return RET_OK; | |||
| } | |||
| @@ -62,25 +73,18 @@ bool Conv1DWeightExpandingPass::Run(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex); | |||
| auto weight_node = conv_cnode->input(kConvWeightIndex); | |||
| MS_ASSERT(weight_node != nullptr); | |||
| auto weight_value = GetTensorInfo(weight_node); | |||
| if (weight_value == nullptr) { | |||
| MS_LOG(ERROR) << "weight node must be param value."; | |||
| return false; | |||
| } | |||
| auto prim = GetValueNode<PrimitivePtr>(conv_cnode->input(0)); | |||
| MS_ASSERT(prim != nullptr); | |||
| schema::Format schema_format = schema::Format::Format_KCHW; | |||
| if (prim->GetAttr(opt::kWeightFormat) != nullptr) { | |||
| schema_format = static_cast<schema::Format>(GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat))); | |||
| } | |||
| // expand weight tensor to 4 dimensions. | |||
| if (weight_value->shape().size() == kTripleNum) { | |||
| auto status = ExpandFilterShape(weight_value, schema_format); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Expand filter shape failed."; | |||
| return false; | |||
| } | |||
| auto status = ExpandFilterShape(weight_node, schema_format); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Expand filter shape failed."; | |||
| return false; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| @@ -30,7 +30,7 @@ class Conv1DWeightExpandingPass : public Pass { | |||
| bool Run(const FuncGraphPtr &graph) override; | |||
| private: | |||
| lite::STATUS ExpandFilterShape(const tensor::TensorPtr &tensor, const schema::Format &format); | |||
| lite::STATUS ExpandFilterShape(const AnfNodePtr &weight_node, const schema::Format &format); | |||
| }; | |||
| } // namespace mindspore::opt | |||
| #endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV1D_WEIGHT_EXPANDING_PASS_H_ | |||
| @@ -0,0 +1,529 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/graph/node_infershape.h" | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "tools/anf_exporter/anf_exporter.h" | |||
| #include "tools/common/node_util.h" | |||
| #include "tools/common/tensor_util.h" | |||
| #include "src/ops/populate/populate_register.h" | |||
| #include "src/ops/ops_utils.h" | |||
| #include "src/runtime/infer_manager.h" | |||
| #include "src/tensorlist.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| constexpr size_t INITIAL_SIZE = 1024; | |||
| void FreeTensors(std::vector<lite::Tensor *> *tensors) { | |||
| if (tensors == nullptr) { | |||
| return; | |||
| } | |||
| for (auto &v : *tensors) { | |||
| delete v; | |||
| v = nullptr; | |||
| } | |||
| tensors->resize(0); | |||
| } | |||
| void SetConvWeightFormat(const CNodePtr &cnode, const std::vector<lite::Tensor *> &inputs) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| if (!CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) && | |||
| !CheckPrimitiveType(cnode, kPrimConv2DBackpropInputFusion) && | |||
| !CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) { | |||
| return; | |||
| } | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| MS_ASSERT(prim != nullptr); | |||
| if (prim->GetAttr(kWeightFormat) != nullptr && inputs.size() > 1) { | |||
| inputs[1]->set_format(static_cast<schema::Format>(GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat)))); | |||
| } | |||
| } | |||
| bool DuceInferFlag(const CNodePtr &cnode, const std::vector<lite::Tensor *> &inputs, FmkType fmk_type) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| for (auto &input : inputs) { | |||
| auto shape = input->shape(); | |||
| if (std::find(shape.begin(), shape.end(), -1) != shape.end()) { | |||
| if (fmk_type == lite::converter::FmkType_ONNX && shape.size() == 4 && shape[3] == 3 && shape[1] == -1) { | |||
| input->set_format(schema::Format_NHWC); | |||
| } | |||
| return false; | |||
| } | |||
| } | |||
| auto origin_inputs = cnode->inputs(); | |||
| lite::AnfExporter::RemoveIfDepend(cnode); | |||
| lite::AnfExporter::RemoveIfMakeTuple(cnode); | |||
| for (size_t i = 1; i < cnode->size(); ++i) { | |||
| if (!utils::isa<CNodePtr>(cnode->input(i))) { | |||
| continue; | |||
| } | |||
| auto input_cnode = cnode->input(i)->cast<CNodePtr>(); | |||
| if (CheckPrimitiveType(cnode->input(i), prim::kPrimTupleGetItem)) { | |||
| input_cnode = input_cnode->input(1)->cast<CNodePtr>(); | |||
| } | |||
| if (input_cnode == nullptr) { | |||
| MS_LOG(ERROR) << "input is not cnode."; | |||
| cnode->set_inputs(origin_inputs); | |||
| return false; | |||
| } | |||
| auto prim = GetValueNode<PrimitivePtr>(input_cnode->input(0)); | |||
| if (prim == nullptr || prim->GetAttr(kInferDone) == nullptr) { | |||
| MS_LOG(ERROR) << "prim is invalid."; | |||
| cnode->set_inputs(origin_inputs); | |||
| return false; | |||
| } | |||
| if (!GetValue<bool>(prim->GetAttr(kInferDone))) { | |||
| cnode->set_inputs(origin_inputs); | |||
| return false; | |||
| } | |||
| } | |||
| cnode->set_inputs(origin_inputs); | |||
| return true; | |||
| } | |||
| tensor::TensorPtr NewTensorInfo(lite::Tensor *tensor) { | |||
| std::vector<int> shape(tensor->shape()); | |||
| std::vector<int64_t> shape_vector(shape.begin(), shape.end()); | |||
| auto tensor_info = std::make_shared<tensor::Tensor>(tensor->data_type(), shape_vector); | |||
| if (tensor_info == nullptr) { | |||
| MS_LOG(ERROR) << "new tensor::Tensor failed"; | |||
| return nullptr; | |||
| } | |||
| return tensor_info; | |||
| } | |||
| } // namespace | |||
| STATUS NodeInferShape::InferShape(const CNodePtr &cnode) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| auto anf_prim = GetValueNode<std::shared_ptr<Primitive>>(cnode->input(0)); | |||
| if (anf_prim == nullptr) { | |||
| MS_LOG(DEBUG) << "primitive is nullptr"; | |||
| return lite::RET_ERROR; | |||
| } | |||
| anf_prim->AddAttr(kInferDone, MakeValue<bool>(false)); | |||
| std::vector<lite::Tensor *> inputs; | |||
| std::vector<lite::Tensor *> outputs; | |||
| if (GetCNodeInputTensors(cnode, &inputs) != lite::RET_OK) { | |||
| FreeTensors(&inputs); | |||
| MS_LOG(ERROR) << "get inputs failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| SetConvWeightFormat(cnode, inputs); | |||
| if (GetCNodeOutputTensors(cnode, &outputs) != lite::RET_OK) { | |||
| FreeTensors(&inputs); | |||
| FreeTensors(&outputs); | |||
| MS_LOG(ERROR) << "get outputs failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto prim_t = lite::GetPrimitiveT(cnode->input(0)); | |||
| if (prim_t == nullptr) { | |||
| MS_LOG(DEBUG) << "prim_t is nullptr"; | |||
| FreeTensors(&inputs); | |||
| FreeTensors(&outputs); | |||
| return lite::RET_ERROR; | |||
| } | |||
| flatbuffers::FlatBufferBuilder fbb(INITIAL_SIZE); | |||
| auto prim = lite::ConvertToPrimitive(prim_t, &fbb); | |||
| delete prim_t; | |||
| if (prim == nullptr) { | |||
| MS_LOG(ERROR) << "get primitive failed."; | |||
| FreeTensors(&inputs); | |||
| FreeTensors(&outputs); | |||
| fbb.Clear(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto parameter_gen = lite::PopulateRegistry::GetInstance()->GetParameterCreator(prim->value_type(), lite::SCHEMA_CUR); | |||
| if (parameter_gen == nullptr) { | |||
| MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(prim->value_type()); | |||
| FreeTensors(&inputs); | |||
| FreeTensors(&outputs); | |||
| fbb.Clear(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto parameter = parameter_gen(prim); | |||
| if (parameter == nullptr) { | |||
| MS_LOG(ERROR) << "parameter is nullptr."; | |||
| FreeTensors(&inputs); | |||
| FreeTensors(&outputs); | |||
| fbb.Clear(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| parameter->infer_flag_ = DuceInferFlag(cnode, inputs, fmk_type_); | |||
| auto status = KernelInferShape(inputs, &outputs, parameter); | |||
| if (status == lite::RET_OK) { | |||
| anf_prim->AddAttr(kInferDone, MakeValue<bool>(true)); | |||
| } | |||
| if (status == lite::RET_OK || status == lite::RET_INFER_INVALID) { | |||
| auto set_status = SetCNodeAbstract(cnode, outputs); | |||
| if (set_status != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "set CNode abstract failed: " << cnode->fullname_with_scope(); | |||
| return set_status; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "infer shape failed."; | |||
| } | |||
| FreeTensors(&inputs); | |||
| FreeTensors(&outputs); | |||
| free(parameter); | |||
| fbb.Clear(); | |||
| return status; | |||
| } | |||
| std::vector<int> NodeInferShape::GetInputShape(const CNodePtr &cnode, size_t index) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| if (index >= cnode->size()) { | |||
| return {}; | |||
| } | |||
| auto origin_inputs = cnode->inputs(); | |||
| std::vector<AnfNodePtr> specify_inputs = {origin_inputs[0], origin_inputs[index]}; | |||
| cnode->set_inputs(specify_inputs); | |||
| std::vector<lite::Tensor *> specify_tensors; | |||
| if (GetCNodeInputTensors(cnode, &specify_tensors) != lite::RET_OK || specify_tensors.empty()) { | |||
| cnode->set_inputs(origin_inputs); | |||
| return {}; | |||
| } | |||
| cnode->set_inputs(origin_inputs); | |||
| auto shape = specify_tensors.front()->shape(); | |||
| FreeTensors(&specify_tensors); | |||
| return shape; | |||
| } | |||
| std::vector<int> NodeInferShape::GetIntVecInput(const CNodePtr &cnode, size_t index) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| if (index >= cnode->size()) { | |||
| return {}; | |||
| } | |||
| auto origin_inputs = cnode->inputs(); | |||
| std::vector<AnfNodePtr> specify_inputs = {origin_inputs[0], origin_inputs[index]}; | |||
| cnode->set_inputs(specify_inputs); | |||
| std::vector<lite::Tensor *> specify_tensors; | |||
| if (GetCNodeInputTensors(cnode, &specify_tensors) != lite::RET_OK || specify_tensors.empty()) { | |||
| cnode->set_inputs(origin_inputs); | |||
| return {}; | |||
| } | |||
| cnode->set_inputs(origin_inputs); | |||
| std::vector<int> tensor_data; | |||
| if (specify_tensors.front()->data_type() != kNumberTypeInt32 && | |||
| specify_tensors.front()->data_type() != kNumberTypeInt) { | |||
| FreeTensors(&specify_tensors); | |||
| return {}; | |||
| } | |||
| if (specify_tensors.front()->shape().size() != 1) { | |||
| FreeTensors(&specify_tensors); | |||
| return {}; | |||
| } | |||
| tensor_data.resize(specify_tensors.front()->shape()[0]); | |||
| if (memcpy_s(tensor_data.data(), tensor_data.size() * sizeof(int), specify_tensors.front()->data_c(), | |||
| tensor_data.size() * sizeof(int)) != EOK) { | |||
| FreeTensors(&specify_tensors); | |||
| return {}; | |||
| } | |||
| return tensor_data; | |||
| } | |||
| STATUS NodeInferShape::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *inputs) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| MS_ASSERT(inputs != nullptr); | |||
| auto origin_inputs = cnode->inputs(); | |||
| lite::AnfExporter::RemoveIfDepend(cnode); | |||
| lite::AnfExporter::RemoveIfMakeTuple(cnode); | |||
| RemoveIfMonad(cnode); | |||
| std::vector<lite::Tensor *> const_inputs; | |||
| if (GetCNodeConstInput(cnode, &const_inputs) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "get const inputs failed."; | |||
| FreeTensors(&const_inputs); | |||
| cnode->set_inputs(origin_inputs); | |||
| return lite::RET_ERROR; | |||
| } | |||
| std::vector<lite::Tensor *> var_inputs; | |||
| if (GetCNodeVarInput(cnode, &var_inputs) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "get var inputs failed."; | |||
| FreeTensors(&var_inputs); | |||
| cnode->set_inputs(origin_inputs); | |||
| return lite::RET_ERROR; | |||
| } | |||
| size_t const_index = 0; | |||
| size_t var_index = 0; | |||
| bool input_valid = true; | |||
| for (size_t i = 1; i < cnode->size(); ++i) { | |||
| if (utils::isa<CNodePtr>(cnode->input(i))) { | |||
| if (var_index >= var_inputs.size()) { | |||
| MS_LOG(ERROR) << "var inputs size invalid."; | |||
| input_valid = false; | |||
| break; | |||
| } | |||
| inputs->emplace_back(var_inputs[var_index++]); | |||
| } else { | |||
| if (const_index >= const_inputs.size()) { | |||
| MS_LOG(ERROR) << "const inputs size invalid."; | |||
| input_valid = false; | |||
| break; | |||
| } | |||
| inputs->emplace_back(const_inputs[const_index++]); | |||
| } | |||
| } | |||
| cnode->set_inputs(origin_inputs); | |||
| if (!input_valid) { | |||
| FreeTensors(&const_inputs); | |||
| FreeTensors(&var_inputs); | |||
| inputs->resize(0); | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS NodeInferShape::GetCNodeConstInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *const_ms_inputs) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| auto origin_inputs = cnode->inputs(); | |||
| std::vector<AnfNodePtr> const_inputs; | |||
| for (auto &input : origin_inputs) { | |||
| if (utils::isa<CNodePtr>(input)) { | |||
| continue; | |||
| } | |||
| const_inputs.push_back(input); | |||
| } | |||
| cnode->set_inputs(const_inputs); | |||
| auto meta_graph = std::make_unique<schema::MetaGraphT>(); | |||
| meta_graph->fmkType = fmk_type_; | |||
| auto fb_node = std::make_unique<schema::CNodeT>(); | |||
| lite::AnfExporter anf_exporter; | |||
| anf_exporter.set_train_flag(train_flag_); | |||
| auto status = anf_exporter.SetOpInputNode(cnode, meta_graph, fb_node.get()); | |||
| cnode->set_inputs(origin_inputs); | |||
| if (status != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "get const inputs failed."; | |||
| return status; | |||
| } | |||
| return ConvertToLiteTensor(meta_graph, fb_node->inputIndex, const_ms_inputs); | |||
| } | |||
| STATUS NodeInferShape::GetCNodeVarInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *var_ms_inputs) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| MS_ASSERT(var_ms_inputs != nullptr); | |||
| for (size_t i = 1; i < cnode->size(); ++i) { | |||
| if (!utils::isa<CNodePtr>(cnode->input(i))) { | |||
| continue; | |||
| } | |||
| auto abstract = GetCNodeInputAbstract(cnode, i); | |||
| if (abstract == nullptr) { | |||
| MS_LOG(ERROR) << "Abstract cnode is nullptr."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) { | |||
| MS_LOG(ERROR) << "Abstract should be anstract tensor."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract); | |||
| auto type_ptr = abstract_tensor->element()->GetTypeTrack(); | |||
| MS_ASSERT(typePtr != nullptr); | |||
| if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) { | |||
| MS_LOG(ERROR) << "Shape of Abstract should be ShapePtr."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape(); | |||
| std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end()); | |||
| lite::Tensor *tensor = nullptr; | |||
| if (type_ptr->type_id() == kObjectTypeTensorType) { | |||
| tensor = GetCNodeTensorListVarInput(dims, abstract_tensor); | |||
| } else { | |||
| tensor = new (std::nothrow) lite::Tensor(TypeId(type_ptr->type_id()), dims); | |||
| } | |||
| if (tensor == nullptr) { | |||
| MS_LOG(ERROR) << "new a lite tensor failed"; | |||
| return lite::RET_ERROR; | |||
| } | |||
| var_ms_inputs->emplace_back(tensor); | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| lite::Tensor *NodeInferShape::GetCNodeTensorListVarInput(std::vector<int> shape, | |||
| const abstract::AbstractTensorPtr &abstract_tensor) { | |||
| MS_ASSERT(abstract_tensor != nullptr); | |||
| auto tensor_list = new (std::nothrow) lite::TensorList(shape, {}); | |||
| if (tensor_list == nullptr) { | |||
| MS_LOG(ERROR) << "new a lite tensor list failed"; | |||
| return nullptr; | |||
| } | |||
| auto tensor_info = abstract_tensor->GetValueTrack(); | |||
| if (tensor_info == nullptr || !utils::isa<tensor::TensorPtr>(tensor_info)) { | |||
| delete tensor_list; | |||
| MS_LOG(ERROR) << "nsor list abstract is invalid."; | |||
| return nullptr; | |||
| } | |||
| auto tensor_value = tensor_info->cast<tensor::TensorPtr>(); | |||
| if (tensor_value->data_c() == nullptr) { | |||
| delete tensor_list; | |||
| MS_LOG(ERROR) << "cannot get tensor list abstract's info."; | |||
| return nullptr; | |||
| } | |||
| auto status = tensor_list->Decode(static_cast<int *>(tensor_value->data_c())); | |||
| if (status != lite::RET_OK) { | |||
| delete tensor_list; | |||
| MS_LOG(ERROR) << "decode tensor list failed."; | |||
| return nullptr; | |||
| } | |||
| return tensor_list; | |||
| } | |||
| STATUS NodeInferShape::GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *outputs) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| MS_ASSERT(outputs != nullptr); | |||
| auto meta_graph = std::make_unique<schema::MetaGraphT>(); | |||
| meta_graph->fmkType = fmk_type_; | |||
| auto fb_node = std::make_unique<schema::CNodeT>(); | |||
| lite::AnfExporter anf_exporter; | |||
| anf_exporter.set_train_flag(train_flag_); | |||
| anf_exporter.SetOpOutputNode(cnode, meta_graph, fb_node.get()); | |||
| return ConvertToLiteTensor(meta_graph, fb_node->outputIndex, outputs); | |||
| } | |||
| STATUS NodeInferShape::ConvertToLiteTensor(const std::unique_ptr<schema::MetaGraphT> &meta_graph, | |||
| const std::vector<uint32_t> &tensor_indexes, | |||
| std::vector<lite::Tensor *> *tensors) { | |||
| MS_ASSERT(meta_graph != nullptr); | |||
| MS_ASSERT(tensors != nullptr); | |||
| for (auto index : tensor_indexes) { | |||
| auto tensor_t = meta_graph->allTensors.at(index).get(); | |||
| auto tensor_shape = tensor_t->dims; | |||
| auto tensor_category = lite::TensorCategory(tensor_t->nodeType, tensor_t->dims.size(), TypeId(tensor_t->dataType), | |||
| tensor_t->data.size()); | |||
| lite::Tensor *tensor = nullptr; | |||
| if (tensor_t->dataType != kObjectTypeTensorType) { | |||
| tensor = | |||
| new (std::nothrow) lite::Tensor(TypeId(tensor_t->dataType), tensor_shape, tensor_t->format, tensor_category); | |||
| } else { | |||
| tensor = new (std::nothrow) lite::TensorList(tensor_shape, std::vector<int>(), tensor_category); | |||
| } | |||
| if (tensor == nullptr) { | |||
| MS_LOG(ERROR) << "new a lite tensor failed"; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto tensor_size = tensor_t->data.size() * sizeof(char); | |||
| if (tensor_size > 0) { | |||
| if (tensor_t->dataType == kObjectTypeTensorType) { | |||
| auto tensor_list = reinterpret_cast<lite::TensorList *>(tensor); | |||
| if (tensor_list->Decode(reinterpret_cast<const int *>(tensor_t->data.data())) != RET_OK) { | |||
| MS_LOG(ERROR) << "Decode tensorlist data failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } else { | |||
| auto tensor_data = new (std::nothrow) char[tensor_size]; | |||
| if (tensor_data == nullptr) { | |||
| MS_LOG(ERROR) << "tensor_data is nullptr"; | |||
| delete tensor; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (memcpy_s(tensor_data, tensor_size, tensor_t->data.data(), tensor_size) != EOK) { | |||
| delete tensor; | |||
| delete[](tensor_data); | |||
| MS_LOG(ERROR) << "memcpy error: "; | |||
| return lite::RET_ERROR; | |||
| } | |||
| tensor->set_data(tensor_data); | |||
| } | |||
| } | |||
| tensors->emplace_back(tensor); | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS NodeInferShape::SetCNodeAbstract(const std::shared_ptr<CNode> &cnode, | |||
| const std::vector<lite::Tensor *> &outputs) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| if (outputs.size() == 0) { | |||
| MS_LOG(ERROR) << "empty output_tensors"; | |||
| return RET_ERROR; | |||
| } | |||
| auto origin_abstract = cnode->abstract(); | |||
| if (outputs.size() == 1 && !utils::isa<abstract::AbstractTuple>(origin_abstract)) { | |||
| auto tensor = outputs.front(); | |||
| auto new_abstract = ConvertLiteTensorToAbstract(tensor); | |||
| if (new_abstract == nullptr) { | |||
| return RET_ERROR; | |||
| } | |||
| cnode->set_abstract(new_abstract); | |||
| } else { | |||
| AbstractBasePtrList abstract_list; | |||
| for (size_t i = 0; i < outputs.size(); i++) { | |||
| auto tensor = outputs.at(i); | |||
| auto new_abstract = ConvertLiteTensorToAbstract(tensor); | |||
| if (new_abstract == nullptr) { | |||
| return RET_ERROR; | |||
| } | |||
| abstract_list.emplace_back(new_abstract); | |||
| } | |||
| cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| abstract::AbstractBasePtr NodeInferShape::ConvertLiteTensorToAbstract(lite::Tensor *tensor) { | |||
| MS_ASSERT(nullptr != tensor); | |||
| if (tensor->data_type() == kObjectTypeTensorType) { | |||
| return ConvertTensorListToAbstract(tensor); | |||
| } | |||
| auto tensor_info = NewTensorInfo(tensor); | |||
| if (tensor_info == nullptr) { | |||
| MS_LOG(ERROR) << "new tensor::Tensor failed"; | |||
| return nullptr; | |||
| } | |||
| return tensor_info->ToAbstract(); | |||
| } | |||
| // stract save tensorlist's type and shape. tensor_info save tensorlist's data and data type. | |||
| // both of them is different in term of shape and type. | |||
| abstract::AbstractBasePtr NodeInferShape::ConvertTensorListToAbstract(lite::Tensor *tensor) { | |||
| MS_ASSERT(nullptr != tensor); | |||
| auto tensor_list = dynamic_cast<lite::TensorList *>(tensor); | |||
| if (tensor_list == nullptr) { | |||
| MS_LOG(ERROR) << "cast tensor_list failed"; | |||
| return nullptr; | |||
| } | |||
| std::vector<int> shape(tensor->shape()); | |||
| std::vector<int64_t> shape_vector(shape.begin(), shape.end()); | |||
| auto tensor_list_abstract = | |||
| std::make_shared<abstract::AbstractTensor>(TypeIdToType(tensor_list->data_type()), shape_vector); | |||
| if (tensor_list_abstract == nullptr) { | |||
| MS_LOG(ERROR) << "new AbstractTensor failed"; | |||
| return nullptr; | |||
| } | |||
| auto elememt_shape = tensor_list->element_shape(); | |||
| std::vector<int> data_info; | |||
| data_info.push_back(tensor_list->tensors_data_type()); | |||
| data_info.push_back(elememt_shape.size()); | |||
| std::copy(elememt_shape.begin(), elememt_shape.end(), std::back_inserter(data_info)); | |||
| data_info.push_back(tensor_list->tensors().size()); | |||
| for (size_t i = 0; i < tensor_list->tensors().size(); ++i) { | |||
| auto tensor_mem = tensor_list->tensors()[i]; | |||
| auto tensor_mem_shape = tensor_mem->shape(); | |||
| data_info.push_back(tensor_mem_shape.size()); | |||
| std::copy(tensor_mem_shape.begin(), tensor_mem_shape.end(), std::back_inserter(data_info)); | |||
| } | |||
| std::vector<int64_t> data_shape; | |||
| data_shape.push_back(data_info.size()); | |||
| auto tensor_info = std::make_shared<tensor::Tensor>(kNumberTypeInt32, data_shape, data_info.data(), kNumberTypeInt32); | |||
| if (tensor_info == nullptr) { | |||
| MS_LOG(ERROR) << "new tensor::Tensor failed"; | |||
| return nullptr; | |||
| } | |||
| tensor_list_abstract->set_value(tensor_info); | |||
| return tensor_list_abstract; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,60 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_NODE_INFERSHAPE_H_ | |||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_NODE_INFERSHAPE_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "schema/inner/model_generated.h" | |||
| #include "src/tensor.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "tools/optimizer/common/format_utils.h" | |||
| using mindspore::lite::converter::FmkType; | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class NodeInferShape { | |||
| public: | |||
| NodeInferShape() = default; | |||
| virtual ~NodeInferShape() = default; | |||
| void Init(FmkType fmk_type, bool train_flag) { | |||
| fmk_type_ = fmk_type; | |||
| train_flag_ = train_flag; | |||
| } | |||
| STATUS InferShape(const CNodePtr &cnode); | |||
| std::vector<int> GetInputShape(const CNodePtr &cnode, size_t index); | |||
| std::vector<int> GetIntVecInput(const CNodePtr &cnode, size_t index); | |||
| private: | |||
| STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *inputs); | |||
| STATUS GetCNodeConstInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *const_ms_inputs); | |||
| STATUS GetCNodeVarInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *var_ms_inputs); | |||
| lite::Tensor *GetCNodeTensorListVarInput(std::vector<int> shape, const abstract::AbstractTensorPtr &abstract_tensor); | |||
| STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *outputs); | |||
| STATUS ConvertToLiteTensor(const std::unique_ptr<schema::MetaGraphT> &meta_graph, | |||
| const std::vector<uint32_t> &tensor_indexes, std::vector<lite::Tensor *> *tensors); | |||
| STATUS SetCNodeAbstract(const std::shared_ptr<CNode> &cnode, const std::vector<lite::Tensor *> &outputs); | |||
| abstract::AbstractBasePtr ConvertLiteTensorToAbstract(lite::Tensor *tensor); | |||
| abstract::AbstractBasePtr ConvertTensorListToAbstract(lite::Tensor *tensor); | |||
| FmkType fmk_type_{lite::converter::FmkType_MS}; | |||
| bool train_flag_{false}; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_NODE_INFERSHAPE_H_ | |||
| @@ -22,8 +22,8 @@ | |||
| namespace mindspore::opt { | |||
| namespace { | |||
| constexpr size_t InputDoubleNum = 2; | |||
| constexpr size_t InputTripleNum = 3; | |||
| constexpr size_t kInputDoubleNum = 2; | |||
| constexpr size_t kInputTripleNum = 3; | |||
| void FetchCNodeFromMakeTuple(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *inputs) { | |||
| MS_ASSERT(anf_node != nullptr); | |||
| MS_ASSERT(inputs != nullptr); | |||
| @@ -45,14 +45,14 @@ int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraph | |||
| } | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| if (CheckPrimitiveType(anf_node, kPrimIdentity)) { | |||
| if (cnode->size() != InputDoubleNum) { | |||
| if (cnode->size() != kInputDoubleNum) { | |||
| MS_LOG(DEBUG) << "The node inputs size is bigger than 1"; | |||
| remove_cnode_.insert(anf_node); | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| } | |||
| if (CheckPrimitiveType(anf_node, prim::kPrimDepend)) { | |||
| if (cnode->size() != InputDoubleNum) { | |||
| if (cnode->size() != kInputDoubleNum) { | |||
| MS_LOG(DEBUG) << "The node inputs size is bigger than 1"; | |||
| remove_cnode_.insert(anf_node); | |||
| return lite::RET_NO_CHANGE; | |||
| @@ -106,7 +106,7 @@ int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| if (cnode->inputs().size() != InputTripleNum) { | |||
| if (cnode->inputs().size() != kInputTripleNum) { | |||
| MS_LOG(ERROR) << "TupleGetItem should have 3 inputs, got " << cnode->inputs().size(); | |||
| return RET_ERROR; | |||
| } | |||
| @@ -133,6 +133,45 @@ int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const | |||
| return lite::RET_OK; | |||
| } | |||
| int RemoveRedundantOpPass::RemoveDropoutOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { | |||
| MS_ASSERT(anf_node != nullptr); | |||
| MS_ASSERT(manager != nullptr); | |||
| if (!utils::isa<CNodePtr>(anf_node)) { | |||
| MS_LOG(DEBUG) << "anf node is node a cnode."; | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| if (cnode->size() > kInputDoubleNum) { | |||
| MS_LOG(ERROR) << "dropout input invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (!utils::isa<abstract::AbstractTuplePtr>(anf_node->abstract())) { | |||
| MS_LOG(DEBUG) << "dropout output size is one."; | |||
| manager->Replace(anf_node, cnode->input(1)); | |||
| } else { | |||
| auto node_users = manager->node_users()[anf_node]; | |||
| for (auto &node_user : node_users) { | |||
| auto node = node_user.first; | |||
| if (!CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { | |||
| MS_LOG(ERROR) << "dropout out node is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto get_index_node = node->cast<CNodePtr>()->input(kInputDoubleNum)->cast<ValueNodePtr>(); | |||
| if (get_index_node == nullptr) { | |||
| MS_LOG(ERROR) << "tuple get item node is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto get_index = CastToInt(get_index_node->value()).front(); | |||
| if (get_index > 0 && !manager->node_users()[node].empty()) { | |||
| MS_LOG(ERROR) << "dropout's second output is useful."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| manager->Replace(node, cnode->input(1)); | |||
| } | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto manager = func_graph->manager(); | |||
| @@ -155,6 +194,9 @@ bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) { | |||
| if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { | |||
| status = ReplaceTupleGetItem(node, manager); | |||
| } | |||
| if (CheckPrimitiveType(node, prim::kPrimDropout)) { | |||
| status = RemoveDropoutOp(node, manager); | |||
| } | |||
| if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) { | |||
| auto sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(1)); | |||
| if (sub_func_graph == nullptr) { | |||
| @@ -31,6 +31,7 @@ class RemoveRedundantOpPass : public Pass { | |||
| int ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager); | |||
| int ReplaceUpdateStateOp(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node); | |||
| int ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager); | |||
| int RemoveDropoutOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager); | |||
| bool Run(const FuncGraphPtr &graph) override; | |||
| private: | |||
| @@ -0,0 +1,364 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/graph/transpose_strategy.h" | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <utility> | |||
| #include "ops/crop.h" | |||
| #include "ops/fusion/activation.h" | |||
| #include "ops/fusion/slice_fusion.h" | |||
| #include "ops/op_utils.h" | |||
| #include "ops/strided_slice.h" | |||
| #include "tools/converter/quant_param_holder.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| constexpr size_t kFirstInput = 1; | |||
| constexpr size_t kTransposePerm = 2; | |||
| constexpr size_t kOnnxStridedSlice = 6; | |||
| const std::vector<int> NH2NC = {0, 3, 1, 2}; | |||
| const std::vector<int> NC2NH = {0, 2, 3, 1}; | |||
| STATUS GetPostNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<AnfNodePtr> *out_nodes) { | |||
| auto manager = func_graph->manager(); | |||
| if (manager == nullptr) { | |||
| manager = Manage(func_graph, true); | |||
| } | |||
| if (manager == nullptr) { | |||
| MS_LOG(ERROR) << "manager is nullptr."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto node_users = manager->node_users()[cnode]; | |||
| if (node_users.empty()) { | |||
| MS_LOG(ERROR) << "cnode is isolated."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| std::transform(node_users.begin(), node_users.end(), std::back_inserter(*out_nodes), | |||
| [](const std::pair<AnfNodePtr, int> &node_user) { return node_user.first; }); | |||
| return lite::RET_OK; | |||
| } | |||
| } // namespace | |||
| AnfNodePtr TransposeStrategy::TransposePairFuseWhenInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::vector<int> &perm, bool before, size_t index) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| AnfNodePtr trans_input_node = before ? cnode->input(index) : cnode; | |||
| // judge pair transpose after insert. | |||
| if (CheckPrimitiveType(trans_input_node, prim::kPrimTranspose)) { | |||
| std::vector<int> trans_perm; | |||
| auto input_cnode = trans_input_node->cast<CNodePtr>(); | |||
| if (input_cnode == nullptr) { | |||
| MS_LOG(ERROR) << "input node is invalid."; | |||
| return nullptr; | |||
| } | |||
| if (GetTransposePerm(input_cnode->input(kTransposePerm), &trans_perm) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "transpose perm get failed."; | |||
| return nullptr; | |||
| } | |||
| if ((perm == NH2NC && trans_perm == NC2NH) || (perm == NC2NH && trans_perm == NH2NC)) { | |||
| return input_cnode->input(kFirstInput); | |||
| } | |||
| } | |||
| // insert depend on shape | |||
| return TransposeDependOnShape(func_graph, cnode, perm, before, index); | |||
| } | |||
| AnfNodePtr TransposeStrategy::TransposeDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::vector<int> &perm, bool before, size_t index) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| AnfNodePtr trans_input_node = before ? cnode->input(index) : cnode; | |||
| auto status = TransposeInsertDependOnShape(func_graph, cnode, before, index); | |||
| if (status == lite::RET_ERROR) { | |||
| return nullptr; | |||
| } else if (status == lite::RET_NO_CHANGE) { | |||
| return before ? cnode->input(index) : cnode; | |||
| } | |||
| // insert tranpsoe | |||
| std::string trans_name = | |||
| before ? cnode->fullname_with_scope() + "_pre" + std::to_string(index - 1) : cnode->fullname_with_scope() + "_post"; | |||
| auto trans_insert_node = GenTransposeNode(func_graph, trans_input_node, perm, trans_name); | |||
| auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(); | |||
| quant_params_holder->AddInputQuantParam(std::vector<schema::QuantParamT>(1)); | |||
| quant_params_holder->AddOutputQuantParam(std::vector<schema::QuantParamT>(1)); | |||
| auto trans_insert_prim = GetValueNode<PrimitivePtr>(trans_insert_node->input(0)); | |||
| trans_insert_prim->AddAttr("quant_params", quant_params_holder); | |||
| return trans_insert_node; | |||
| } | |||
| bool TransposeStrategy::CanFusionIfInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| TransTypePair *trans_info, TransTypePair *trans_insert_info) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| MS_ASSERT(pre_type != nullptr && post_type != nullptr); | |||
| size_t trans_count = 0; | |||
| std::vector<AnfNodePtr> in_nodes; | |||
| for (size_t i = 1; i < cnode->size(); ++i) { | |||
| if (utils::isa<CNodePtr>(cnode->input(i))) { | |||
| in_nodes.push_back(cnode->input(i)); | |||
| } | |||
| } | |||
| if (!IsInOutCanFuison(func_graph, in_nodes, &trans_count, &trans_info->pre_)) { | |||
| return false; | |||
| } | |||
| std::vector<AnfNodePtr> out_nodes; | |||
| if (GetPostNodes(func_graph, cnode, &out_nodes) != lite::RET_OK) { | |||
| return false; | |||
| } | |||
| if (!IsInOutCanFuison(func_graph, out_nodes, &trans_count, &trans_info->post_)) { | |||
| return false; | |||
| } | |||
| if (trans_info->pre_ == trans_info->post_) { | |||
| return false; | |||
| } | |||
| auto total_node_count = in_nodes.size() + out_nodes.size(); | |||
| bool can_insert = trans_count > total_node_count / 2; | |||
| if (CheckPrimitiveType(cnode, prim::kPrimActivation)) { | |||
| auto prim_act = GetValueNode<std::shared_ptr<ops::Activation>>(cnode->input(0)); | |||
| MS_ASSERT(prim_act != nullptr); | |||
| if (prim_act->get_activation_type() == mindspore::ActivationType::LEAKY_RELU) { | |||
| can_insert = trans_count >= total_node_count / 2; | |||
| } | |||
| } | |||
| if (CheckPrimitiveType(cnode, prim::kPrimSplit) || CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) { | |||
| can_insert = trans_count >= total_node_count / 2; | |||
| } | |||
| if (!can_insert) { | |||
| return can_insert; | |||
| } | |||
| DecidePreAndPostTransType(trans_info, trans_insert_info); | |||
| return can_insert; | |||
| } | |||
| STATUS TransposeStrategy::ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| auto shape = node_infer_shape_.GetInputShape(cnode, 1); | |||
| if (shape.size() != 4) { | |||
| if (cnode->size() > 2) { | |||
| shape = node_infer_shape_.GetInputShape(cnode, 2); | |||
| if (shape.size() != 4 && !shape.empty()) { | |||
| return lite::RET_NOT_SUPPORT; | |||
| } | |||
| } else { | |||
| return lite::RET_NOT_SUPPORT; | |||
| } | |||
| } | |||
| auto axis_map = GetNC2NHAxisMap(); | |||
| if (CheckPrimitiveType(cnode, prim::kPrimConcat) || CheckPrimitiveType(cnode, prim::kPrimSplit)) { | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| if (prim->GetAttr(ops::kAxis) == nullptr) { | |||
| return lite::RET_NOT_SUPPORT; | |||
| } | |||
| auto axis = GetValue<int64_t>(prim->GetAttr(ops::kAxis)); | |||
| auto new_axis = axis_map[axis < 0 ? axis + 4 : axis]; | |||
| prim->AddAttr(ops::kAxis, MakeValue<int64_t>(new_axis)); | |||
| } | |||
| if (CheckPrimitiveType(cnode, prim::kPrimCrop)) { | |||
| auto crop_prim = GetValueNode<std::shared_ptr<ops::Crop>>(cnode->input(0)); | |||
| if (crop_prim == nullptr) { | |||
| return lite::RET_NULL_PTR; | |||
| } | |||
| auto axis = crop_prim->get_axis(); | |||
| auto offsets = crop_prim->get_offsets(); | |||
| auto new_axis = axis_map[axis < 0 ? axis + 4 : axis]; | |||
| if (new_axis == 0) { | |||
| offsets = {offsets[0], offsets[2], offsets[3], offsets[1]}; | |||
| } else if (new_axis == 3) { | |||
| offsets = {offsets[1], offsets[2], offsets[0]}; | |||
| } else { | |||
| offsets.push_back(0); | |||
| } | |||
| crop_prim->set_offsets(offsets); | |||
| } | |||
| if (CheckPrimitiveType(cnode, prim::kPrimSliceFusion)) { | |||
| return ChangeOpSlice(func_graph, cnode); | |||
| } | |||
| if (CheckPrimitiveType(cnode, prim::kPrimStridedSlice)) { | |||
| return ChangeOpStrideSlice(func_graph, cnode); | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS TransposeStrategy::TransposeInsertDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| bool before, size_t index) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| auto manager = func_graph->manager(); | |||
| if (manager == nullptr) { | |||
| manager = Manage(func_graph, true); | |||
| } | |||
| if (manager == nullptr) { | |||
| MS_LOG(ERROR) << "manager is nullptr."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto node_users = manager->node_users()[cnode]; | |||
| if (node_users.empty()) { | |||
| MS_LOG(ERROR) << "cnode is isolated."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (!utils::isa<CNodePtr>(node_users.front().first)) { | |||
| return lite::RET_ERROR; | |||
| } | |||
| CNodePtr base_node = before ? cnode : node_users.front().first->cast<CNodePtr>(); | |||
| size_t input_index = before ? index : node_users.front().second; | |||
| auto shape = node_infer_shape_.GetInputShape(base_node, input_index); | |||
| if (!shape.empty() && shape.size() != NH2NC.size()) { | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| bool TransposeStrategy::IsInOutCanFuison(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes, | |||
| size_t *trans_count, FormatTransNodeType *trans_type) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(trans_count != nullptr && trans_type != nullptr); | |||
| for (auto &node : nodes) { | |||
| if (CheckPrimitiveType(node, prim::kPrimTranspose)) { | |||
| FormatTransNodeType cur_type; | |||
| std::vector<int> perm; | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode == nullptr) { | |||
| return false; | |||
| } | |||
| if (GetTransposePerm(cnode->input(kTransposePerm), &perm) != lite::RET_OK) { | |||
| return false; | |||
| } | |||
| if (perm == NH2NC) { | |||
| cur_type = kNHWC2NCHW; | |||
| } else if (perm == NC2NH) { | |||
| cur_type = kNCHW2NHWC; | |||
| } else { | |||
| return false; | |||
| } | |||
| if (*trans_type == kNONE) { | |||
| *trans_type = cur_type; | |||
| } else if (*trans_type != cur_type) { | |||
| return false; | |||
| } | |||
| *trans_count += 1; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| void TransposeStrategy::DecidePreAndPostTransType(TransTypePair *trans_info, TransTypePair *trans_insert_info) { | |||
| if (trans_info->pre_ == trans_info->post_) { | |||
| return; | |||
| } | |||
| if (trans_info->pre_ != kNONE && trans_info->post_ != kNONE) { | |||
| trans_insert_info->pre_ = trans_info->pre_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW; | |||
| trans_insert_info->post_ = trans_info->post_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW; | |||
| } else if (trans_info->pre_ == kNONE) { | |||
| trans_insert_info->pre_ = trans_info->post_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC; | |||
| trans_insert_info->post_ = trans_info->post_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW; | |||
| } else { | |||
| trans_insert_info->pre_ = trans_info->pre_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW; | |||
| trans_insert_info->post_ = trans_info->pre_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC; | |||
| } | |||
| } | |||
| STATUS TransposeStrategy::ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| for (size_t i = 2; i < cnode->size(); ++i) { | |||
| if (utils::isa<CNodePtr>(cnode->input(i))) { | |||
| return lite::RET_NOT_SUPPORT; | |||
| } | |||
| } | |||
| auto shape = node_infer_shape_.GetInputShape(cnode, 2); | |||
| if (shape.empty()) { | |||
| return lite::RET_NOT_SUPPORT; | |||
| } | |||
| int element_num = shape.front(); | |||
| auto prim = GetValueNode<std::shared_ptr<ops::SliceFusion>>(cnode->input(0)); | |||
| std::vector<int> axes; | |||
| if (prim->GetAttr(ops::kAxes) == nullptr || prim->get_axes().empty()) { | |||
| for (int index = 0; index < element_num; ++index) { | |||
| axes.push_back(index); | |||
| } | |||
| } else { | |||
| auto origin_axes = prim->get_axes(); | |||
| std::transform(origin_axes.begin(), origin_axes.end(), std::back_inserter(axes), | |||
| [](int64_t v) { return static_cast<int>(v); }); | |||
| } | |||
| for (size_t i = 2; i < cnode->size(); ++i) { | |||
| TransformAttrByAxes(func_graph, cnode, i, axes); | |||
| } | |||
| auto tmp_axes = TransformOpAxesAttr(axes); | |||
| std::vector<int64_t> new_axes(tmp_axes.begin(), tmp_axes.end()); | |||
| prim->set_axes(new_axes); | |||
| return lite::RET_OK; | |||
| } | |||
// Rewrite an onnx-style StridedSlice's constant inputs (begin/end/stride and
// the axes list) after the surrounding layout changed between NCHW and NHWC.
// Only the fully-specified onnx form (kOnnxStridedSlice inputs) with constant
// attribute inputs is supported; otherwise RET_NOT_SUPPORT is returned.
STATUS TransposeStrategy::ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  if (cnode->size() != kOnnxStridedSlice) {
    return lite::RET_NOT_SUPPORT;
  }
  // every attribute input must be a constant; a computed begin/end/axes/stride
  // cannot be transformed here.
  for (size_t i = 2; i < cnode->size(); ++i) {
    if (utils::isa<CNodePtr>(cnode->input(i))) {
      return lite::RET_NOT_SUPPORT;
    }
  }
  // input kOnnxStridedSlice - 2 carries the axes list; presumably this is the
  // same input as index 4 skipped below — TODO confirm kOnnxStridedSlice == 6.
  std::vector<int> axes = node_infer_shape_.GetIntVecInput(cnode, kOnnxStridedSlice - 2);
  if (axes.empty()) {
    MS_LOG(ERROR) << "strided slice input invalid.";
    return lite::RET_ERROR;
  }
  // remap begin/end/stride by the NHWC axis order; index 4 (the axes input)
  // is handled separately below.
  for (size_t index = 2; index < cnode->size(); ++index) {
    if (index == 4) {
      continue;
    }
    TransformAttrByAxes(func_graph, cnode, index, axes);
  }
  // map the axes themselves into NHWC order and swap in a fresh parameter node.
  auto cur_axes = TransformOpAxesAttr(axes);
  auto param_node = BuildIntVecParameterNode(func_graph, cur_axes, cnode->input(4)->fullname_with_scope());
  func_graph->manager()->Replace(cnode->input(4), param_node);
  return lite::RET_OK;
}
| void TransposeStrategy::TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index, | |||
| const std::vector<int> &axes) { | |||
| if (cnode == nullptr || input_index >= cnode->size() || axes.empty()) { | |||
| return; | |||
| } | |||
| auto axis_map = GetNC2NHAxisMap(); | |||
| auto origin_input = node_infer_shape_.GetIntVecInput(cnode, input_index); | |||
| if (origin_input.size() != axes.size()) { | |||
| return; | |||
| } | |||
| std::vector<int> cur_input; | |||
| for (int dim = 0; dim < 4; ++dim) { | |||
| for (size_t index = 0; index < axes.size(); ++index) { | |||
| int nhwc_dim = axis_map[axes[index] < 0 ? axes[index] + 4 : axes[index]]; | |||
| if (nhwc_dim == dim) { | |||
| cur_input.push_back(origin_input[index]); | |||
| } | |||
| } | |||
| } | |||
| auto param_node = BuildIntVecParameterNode(func_graph, cur_input, cnode->input(input_index)->fullname_with_scope()); | |||
| func_graph->manager()->Replace(cnode->input(input_index), param_node); | |||
| } | |||
| std::vector<int> TransposeStrategy::TransformOpAxesAttr(const std::vector<int> &origin_axes) { | |||
| auto axis_map = GetNC2NHAxisMap(); | |||
| std::vector<int> cur_axis; | |||
| for (size_t i = 0; i < origin_axes.size(); ++i) { | |||
| cur_axis.push_back(axis_map[origin_axes[i] < 0 ? origin_axes[i] + 4 : origin_axes[i]]); | |||
| } | |||
| std::sort(cur_axis.begin(), cur_axis.end()); | |||
| return cur_axis; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,65 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_TRANSPOSE_STRATEGY_H_ | |||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_TRANSPOSE_STRATEGY_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "schema/inner/model_generated.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "tools/optimizer/common/format_utils.h" | |||
| #include "tools/optimizer/graph/node_infershape.h" | |||
| using mindspore::lite::converter::FmkType; | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class TransposeStrategy { | |||
| public: | |||
| TransposeStrategy() = default; | |||
| ~TransposeStrategy() = default; | |||
| void Init(FmkType fmk_type, bool train_flag) { | |||
| fmk_type_ = fmk_type; | |||
| train_flag_ = train_flag; | |||
| node_infer_shape_.Init(fmk_type, train_flag); | |||
| } | |||
| AnfNodePtr TransposePairFuseWhenInsert(const FuncGraphPtr &func_graph, const CNodePtr &code, | |||
| const std::vector<int> &perm, bool before, size_t index); | |||
| AnfNodePtr TransposeDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm, | |||
| bool before, size_t index); | |||
| bool CanFusionIfInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_info, | |||
| TransTypePair *trans_insert_info); | |||
| STATUS ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| private: | |||
| STATUS TransposeInsertDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode, bool before, size_t index); | |||
| bool IsInOutCanFuison(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes, size_t *trans_count, | |||
| FormatTransNodeType *trans_type); | |||
| void DecidePreAndPostTransType(TransTypePair *trans_info, TransTypePair *trans_insert_info); | |||
| STATUS ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| STATUS ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| void TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index, | |||
| const std::vector<int> &axes); | |||
| std::vector<int> TransformOpAxesAttr(const std::vector<int> &origin_axes); | |||
| FmkType fmk_type_{lite::converter::FmkType_MS}; | |||
| bool train_flag_{false}; | |||
| NodeInferShape node_infer_shape_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_TRANSPOSE_STRATEGY_H_ | |||
| @@ -0,0 +1,785 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/graph/unify_format_pass.h" | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include "ops/op_utils.h" | |||
| #include "src/common/common.h" | |||
| #include "src/common/utils.h" | |||
| #include "tools/anf_exporter/anf_exporter.h" | |||
| using mindspore::lite::NCHW_SHAPE; | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| constexpr size_t kNCHWDimNumber = 4; | |||
| const std::vector<int> NH2NC = {0, 3, 1, 2}; | |||
| const std::vector<int> NC2NH = {0, 2, 3, 1}; | |||
| bool IsSpecialType(const CNodePtr &cnode) { | |||
| if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) || | |||
| CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || CheckPrimitiveType(cnode, kPrimMakeTupleV2) || | |||
| CheckPrimitiveType(cnode, prim::kPrimReturn)) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace | |||
// Decide, per source framework, which transpose pair (if any) must surround
// `cnode` to unify its layout. trans_info is left untouched (kNONE) when no
// transposes are needed.
void UnifyFormatPass::GetTransNodeFormatType(const CNodePtr &cnode, TransTypePair *trans_info) {
  MS_ASSERT(cnode != nullptr);
  auto prim_node = cnode->input(0);
  auto prim = GetValueNode<PrimitivePtr>(prim_node);
  MS_ASSERT(prim != nullptr);
  auto &specify_nhwc_op_map = GetNHWCOpMap();
  auto &specify_nchw_op_map = GetNCHWOpMap();
  if (fmk_type_ == lite::converter::FmkType_TFLITE) {
    // TFLITE models are already NHWC; only NCHW-demanding ops need wrapping.
    if (specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
      return;
    }
    trans_info->pre_ = kNHWC2NCHW;
    trans_info->post_ = kNCHW2NHWC;
  } else if (fmk_type_ == lite::converter::FmkType_TF) {
    // TF ops may carry either layout; wrap based on op map plus actual format.
    if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end() && GetFormat(cnode) == NCHW) {
      trans_info->pre_ = kNCHW2NHWC;
      trans_info->post_ = kNHWC2NCHW;
    }
    if (specify_nchw_op_map.find(prim->name()) != specify_nchw_op_map.end()) {
      trans_info->pre_ = kNHWC2NCHW;
      trans_info->post_ = kNCHW2NHWC;
    }
  } else {
    // remaining frameworks (e.g. ONNX, CAFFE) default to NCHW models.
    if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end()) {
      // ONNX node already marked NHWC needs no conversion.
      if (fmk_type_ == lite::converter::FmkType_ONNX && prim->GetAttr(ops::kFormat) != nullptr &&
          GetValue<int64_t>(prim->GetAttr(ops::kFormat)) == NHWC) {
        return;
      }
      trans_info->pre_ = kNCHW2NHWC;
      trans_info->post_ = kNHWC2NCHW;
    }
  }
}
| bool UnifyFormatPass::TransTransFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) || !CheckPrimitiveType(cnode->input(1), prim::kPrimTranspose)) { | |||
| return false; | |||
| } | |||
| std::vector<int> post_perm; | |||
| if (GetTransposePerm(cnode->input(2), &post_perm) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "get tanspose perm failed."; | |||
| return false; | |||
| } | |||
| std::vector<int> pre_perm; | |||
| auto pre_node = cnode->input(1); | |||
| auto pre_cnode = pre_node->cast<CNodePtr>(); | |||
| if (pre_cnode == nullptr) { | |||
| return false; | |||
| } | |||
| if (GetTransposePerm(pre_cnode->input(2), &pre_perm) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "get tanspose perm failed."; | |||
| return false; | |||
| } | |||
| if ((pre_perm == NH2NC && post_perm == NC2NH) || (pre_perm == NC2NH && post_perm == NH2NC)) { | |||
| func_graph->manager()->Replace(cnode, pre_cnode->input(1)); | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| STATUS UnifyFormatPass::PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| if (!CheckPrimitiveType(cnode, prim::kPrimTranspose)) { | |||
| return lite::RET_OK; | |||
| } | |||
| std::vector<int> cur_perm; | |||
| if (GetTransposePerm(cnode->input(2), &cur_perm) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "get transpose perm failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto node_users = func_graph->manager()->node_users()[cnode]; | |||
| for (auto &node_user : node_users) { | |||
| auto post_node = node_user.first; | |||
| if (CheckPrimitiveType(post_node, prim::kPrimTranspose)) { | |||
| std::vector<int> post_trans_perm; | |||
| auto post_trans_node = post_node->cast<CNodePtr>(); | |||
| if (GetTransposePerm(post_trans_node->input(2), &post_trans_perm) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "get post transpose node perm failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if ((cur_perm == NH2NC && post_trans_perm == NC2NH) || (cur_perm == NC2NH && post_trans_perm == NH2NC)) { | |||
| func_graph->manager()->Replace(post_node, cnode->input(1)); | |||
| } | |||
| } | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
// Create (or fuse away) a transpose with permutation `perm` relative to
// `cnode`: before == true wires it in as cnode's input `index`; before ==
// false makes it replace cnode for all users.
// NOTE(review): some callers omit `index` (see InsertPostTransNode), so a
// default value is presumably declared in the header — confirm there.
STATUS UnifyFormatPass::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm,
                                    bool before, size_t index) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  AnfNodePtr new_input = nullptr;
  if (need_reset_) {
    // reset mode: only the tensor shape decides whether a transpose appears.
    new_input = transpose_strategy_.TransposeDependOnShape(func_graph, cnode, perm, before, index);
  } else {
    // normal mode: additionally try to fuse with an adjacent inverse transpose.
    new_input = transpose_strategy_.TransposePairFuseWhenInsert(func_graph, cnode, perm, before, index);
  }
  if (new_input == nullptr) {
    MS_LOG(ERROR) << "generate a transpose node failed.";
    return lite::RET_ERROR;
  }
  // unchanged anchor means nothing needs rewiring.
  if (new_input == cnode->input(index) || new_input == cnode) {
    return lite::RET_OK;
  } else if (utils::isa<CNodePtr>(new_input)) {
    auto new_cnode_input = new_input->cast<CNodePtr>();
    int status = lite::RET_OK;
    if (CheckPrimitiveType(new_cnode_input, prim::kPrimTranspose)) {
      if (need_reset_) {
        // remember inserted transposes so a later stage can undo them.
        if (before) {
          pre_insert_trans_.insert(new_cnode_input);
        } else {
          post_insert_trans_.insert(new_cnode_input);
        }
      }
      status = node_infer_shape_.InferShape(new_cnode_input);
    }
    if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
      MS_LOG(ERROR) << "infer shape failed.";
      return lite::RET_ERROR;
    }
  }
  auto manager = func_graph->manager();
  MS_ASSERT(manager != nullptr);
  auto tr = manager->Transact();
  if (before) {
    tr.SetEdge(cnode, index, new_input);
    tr.Commit();
  } else {
    func_graph->manager()->Replace(cnode, new_input);
    // after replacing, try to cancel the fresh transpose with its users.
    if (!need_reset_ && PostTransposeFusion(func_graph, new_input->cast<CNodePtr>()) != lite::RET_OK) {
      MS_LOG(ERROR) << "post transpose fusion failed.";
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}
// Insert a pre-transpose with permutation `perm` in front of the inputs that
// the NHWC/NCHW op maps designate for this primitive.
STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                           const std::vector<int> &perm) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  auto prim_node = cnode->input(0);
  auto prim = GetValueNode<PrimitivePtr>(prim_node);
  MS_ASSERT(prim != nullptr);
  auto &specify_nhwc_op_map = GetNHWCOpMap();
  auto &specify_nchw_op_map = GetNCHWOpMap();
  if (specify_nhwc_op_map.find(prim->name()) == specify_nhwc_op_map.end() &&
      specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
    MS_LOG(ERROR) << "op don't meet nhwc condition.";
    return lite::RET_ERROR;
  }
  // the map entry lists which input indices require a transpose.
  std::vector<size_t> insert_index = specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()
                                       ? specify_nhwc_op_map.at(prim->name())
                                       : specify_nchw_op_map.at(prim->name());
  if (insert_index.empty()) {
    // an empty list means "decide here": nearest-neighbor ResizeGrad only
    // needs input 1; every other op transposes all of its inputs.
    if (CheckPrimitiveType(cnode, prim::kPrimResizeGrad) && prim->GetAttr(ops::kMethod) != nullptr &&
        GetValue<int64_t>(prim->GetAttr(ops::kMethod)) == static_cast<int64_t>(mindspore::ResizeMethod::NEAREST)) {
      insert_index.push_back(1);
    } else {
      for (size_t i = 1; i < cnode->size(); ++i) {
        insert_index.push_back(i);
      }
    }
  }
  for (auto &index : insert_index) {
    if (GenNewInput(func_graph, cnode, perm, true, index) != lite::RET_OK) {
      MS_LOG(ERROR) << "generate a new input failed.";
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}
// Fusion-driven variant: check whether inserting transposes around `cnode`
// lets surrounding pairs cancel, then rewrite axis attributes and insert the
// pre-transposes on every real (non-monad) input. Returns RET_NO_CHANGE when
// insertion is not profitable or not supported.
STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                           TransTypePair *trans_insert_info) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  MS_ASSERT(trans_insert_info != nullptr);
  TransTypePair trans_info;
  // temporarily strip MakeTuple/monad wrappers so the fusion vote only sees
  // real data inputs; the original input list is restored afterwards.
  auto origin_inputs = cnode->inputs();
  lite::AnfExporter::RemoveIfMakeTuple(cnode);
  RemoveIfMonad(cnode);
  if (!transpose_strategy_.CanFusionIfInsert(func_graph, cnode, &trans_info, trans_insert_info)) {
    cnode->set_inputs(origin_inputs);
    return lite::RET_NO_CHANGE;
  }
  cnode->set_inputs(origin_inputs);
  auto status = transpose_strategy_.ChangeOpAxis(func_graph, cnode);
  if (status == lite::RET_NOT_SUPPORT) {
    return lite::RET_NO_CHANGE;
  } else if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "change op attr failed.";
    return lite::RET_ERROR;
  }
  auto before_perm = trans_insert_info->pre_ == kNHWC2NCHW ? NH2NC : NC2NH;
  for (size_t i = 1; i < cnode->size(); ++i) {
    if (IsMonadNode(cnode->input(i))) {
      continue;
    }
    // a MakeTuple input gets the transpose on each of its own elements.
    if (CheckPrimitiveType(cnode->input(i), prim::kPrimMakeTuple) ||
        CheckPrimitiveType(cnode->input(i), kPrimMakeTupleV2)) {
      auto input_make_tuple = cnode->input(i)->cast<CNodePtr>();
      MS_ASSERT(input_make_tuple != nullptr);
      for (size_t j = 1; j < input_make_tuple->size(); ++j) {
        if (GenNewInput(func_graph, input_make_tuple, before_perm, true, j) != lite::RET_OK) {
          MS_LOG(ERROR) << "generate a new input failed.";
          return lite::RET_ERROR;
        }
      }
      continue;
    }
    if (GenNewInput(func_graph, cnode, before_perm, true, i) != lite::RET_OK) {
      MS_LOG(ERROR) << "generate a new input failed.";
      return lite::RET_ERROR;
    }
  }
  // re-infer cnode now that its inputs changed layout.
  status = node_infer_shape_.InferShape(cnode);
  if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
    MS_LOG(ERROR) << "infer shape failed.";
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}
| STATUS UnifyFormatPass::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::vector<int> &perm) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| if (!cnode->abstract()->isa<abstract::AbstractTuple>()) { | |||
| if (GenNewInput(func_graph, cnode, perm, false) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "generate a new input failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| } else { | |||
| auto node_users = func_graph->manager()->node_users()[cnode]; | |||
| for (auto &node_user : node_users) { | |||
| auto post_node = node_user.first; | |||
| CNodePtr tuple_get_item = nullptr; | |||
| if (!CheckPrimitiveType(post_node, prim::kPrimTupleGetItem)) { | |||
| if (!train_flag_) { | |||
| MS_LOG(ERROR) << "post node is invalid."; | |||
| return lite::RET_ERROR; | |||
| } else { | |||
| tuple_get_item = GenTupleGetItemNode(func_graph, cnode, 0); | |||
| post_node = tuple_get_item; | |||
| func_graph->manager()->Replace(cnode, tuple_get_item); | |||
| } | |||
| } | |||
| if (func_graph->manager()->node_users()[post_node].empty()) { | |||
| continue; | |||
| } | |||
| auto post_cnode = post_node->cast<CNodePtr>(); | |||
| if (GenNewInput(func_graph, post_cnode, perm, false) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "generate a new input failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (tuple_get_item != nullptr) { | |||
| func_graph->manager()->Replace(tuple_get_item, tuple_get_item->input(1)); | |||
| } | |||
| } | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
// For NCHW-model frameworks, rewrite each non-constant 4D graph-input
// parameter of `cnode` to NHWC and re-insert an NHWC->NCHW transpose so the
// rest of the graph still sees the original layout.
STATUS UnifyFormatPass::HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  // TF/TFLITE inputs are already NHWC; nothing to do.
  if (fmk_type_ == lite::converter::FmkType_TF || fmk_type_ == lite::converter::FmkType_TFLITE) {
    return lite::RET_NO_CHANGE;
  }
  for (size_t i = 1; i < cnode->size(); ++i) {
    auto node = cnode->input(i);
    if (!utils::isa<ParameterPtr>(node)) {
      continue;
    }
    auto param_node = node->cast<ParameterPtr>();
    // parameters with a default value are weights, not graph inputs.
    if (param_node->has_default()) {
      continue;
    }
    auto abstract_base = param_node->abstract();
    if (abstract_base == nullptr) {
      MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
      return lite::RET_ERROR;
    }
    if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
      MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name();
      return lite::RET_ERROR;
    }
    auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
    if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
      MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
      return lite::RET_ERROR;
    }
    auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
    // only 4D inputs participate in NCHW/NHWC unification.
    if (shape_vector.size() != kNCHWDimNumber) {
      continue;
    }
    // NOTE(review): this looks like a heuristic for an ONNX input that is
    // already NHWC (last dim 3 = channels, dynamic dim 1) — confirm intent.
    if (func_graph->get_inputs().size() == 1 && fmk_type_ == lite::converter::FmkType_ONNX && shape_vector[3] == 3 &&
        shape_vector[1] == -1) {
      continue;
    }
    // rewrite the parameter's declared shape from NCHW to NHWC ...
    std::vector<int64_t> new_dims = {shape_vector[NCHW_SHAPE::NCHW_N], shape_vector[NCHW_SHAPE::NCHW_H],
                                     shape_vector[NCHW_SHAPE::NCHW_W], shape_vector[NCHW_SHAPE::NCHW_C]};
    abstract_tensor->set_shape(std::make_shared<abstract::Shape>(new_dims));
    // ... and restore the original layout for consumers via a transpose.
    auto trans_cnode = GenTransposeNode(func_graph, param_node, NH2NC, param_node->fullname_with_scope() + "_pre");
    if (trans_cnode == nullptr) {
      MS_LOG(ERROR) << "generate a transpose node failed.";
      return lite::RET_ERROR;
    }
    auto status = node_infer_shape_.InferShape(trans_cnode);
    if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
      MS_LOG(ERROR) << "infer shape failed.";
      return lite::RET_ERROR;
    }
    func_graph->manager()->Replace(param_node, trans_cnode);
  }
  return lite::RET_OK;
}
// Process one node: decide its required transpose pair, insert pre/post
// transposes and re-infer its shape. kTransDone marks already-handled nodes so
// repeated traversals skip them.
STATUS UnifyFormatPass::HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  auto prim_node = cnode->input(0);
  auto prim = GetValueNode<PrimitivePtr>(prim_node);
  MS_ASSERT(prim != nullptr);
  if (prim->GetAttr(kTransDone) != nullptr && GetValue<bool>(prim->GetAttr(kTransDone))) {
    return lite::RET_OK;
  }
  prim->AddAttr(kTransDone, MakeValue<bool>(true));
  TransTypePair trans_info;
  GetTransNodeFormatType(cnode, &trans_info);
  if (!need_reset_ && (trans_info.pre_ == kNONE || trans_info.post_ == kNONE)) {
    // node needs no layout wrapping: just fuse trans-trans pairs and keep
    // shape inference up to date.
    if (TransTransFusion(func_graph, cnode)) {
      return lite::RET_OK;
    }
    std::unordered_map<AnfNodePtr, AnfNodePtr> match;
    PreProcessFowardInsert(func_graph, cnode, &match);
    auto status = node_infer_shape_.InferShape(cnode);
    PostProcessFowardInsert(func_graph, cnode, match);
    if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
      MS_LOG(ERROR) << "infer shape failed: " << cnode->fullname_with_scope();
      return lite::RET_ERROR;
    }
    return lite::RET_NO_CHANGE;
  }
  auto before_perm = trans_info.pre_ == kNHWC2NCHW ? NH2NC : NC2NH;
  auto after_perm = trans_info.post_ == kNCHW2NHWC ? NC2NH : NH2NC;
  // temporarily swap sub-graph inputs for their real counterparts while
  // inserting and inferring (see Pre/PostProcessFowardInsert).
  std::unordered_map<AnfNodePtr, AnfNodePtr> match;
  PreProcessFowardInsert(func_graph, cnode, &match);
  if (InsertPreTransNode(func_graph, cnode, before_perm) != lite::RET_OK) {
    MS_LOG(ERROR) << "insert pre node failed." << cnode->fullname_with_scope();
    return lite::RET_ERROR;
  }
  auto status = node_infer_shape_.InferShape(cnode);
  if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
    MS_LOG(ERROR) << "infer shape failed.";
    return lite::RET_ERROR;
  }
  PostProcessFowardInsert(func_graph, cnode, match);
  if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
    MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}
| void UnifyFormatPass::PreProcessFowardInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *match) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name")); | |||
| if (sub_inputs_map_.find(graph_name) == sub_inputs_map_.end()) { | |||
| return; | |||
| } | |||
| auto manager = func_graph->manager(); | |||
| MS_ASSERT(manager != nullptr); | |||
| auto tr = manager->Transact(); | |||
| for (size_t i = 1; i < cnode->size(); ++i) { | |||
| if (sub_inputs_map_[graph_name].find(cnode->input(i)) == sub_inputs_map_[graph_name].end()) { | |||
| continue; | |||
| } | |||
| match->insert(std::make_pair(sub_inputs_map_[graph_name][cnode->input(i)], cnode->input(i))); | |||
| tr.SetEdge(cnode, i, sub_inputs_map_[graph_name][cnode->input(i)]); | |||
| tr.Commit(); | |||
| } | |||
| } | |||
// Undoes the input swapping performed by PreProcessFowardInsert: any input of
// `cnode` (and of a Transpose node that may have been inserted in between)
// found in `match` is wired back to the original node.
void UnifyFormatPass::PostProcessFowardInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                              const std::unordered_map<AnfNodePtr, AnfNodePtr> &match) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  if (match.empty()) {
    return;
  }
  auto manager = func_graph->manager();
  MS_ASSERT(manager != nullptr);
  auto tr = manager->Transact();
  for (size_t i = 1; i < cnode->size(); ++i) {
    if (match.find(cnode->input(i)) != match.end()) {
      tr.SetEdge(cnode, i, match.at(cnode->input(i)));
      // Commit immediately: the Transpose check below must observe the
      // restored input at position i, so this commit must not be deferred.
      tr.Commit();
    }
    if (CheckPrimitiveType(cnode->input(i), prim::kPrimTranspose)) {
      // An inserted Transpose sits between cnode and the swapped input:
      // restore the Transpose's own inputs as well.
      auto trans_cnode = cnode->input(i)->cast<CNodePtr>();
      for (size_t j = 1; j < trans_cnode->size(); ++j) {
        if (match.find(trans_cnode->input(j)) == match.end()) {
          continue;
        }
        tr.SetEdge(trans_cnode, j, match.at(trans_cnode->input(j)));
        tr.Commit();
      }
    }
  }
}
// Prepares a control-flow subgraph before recursing into it: for each formal
// parameter, either refresh its abstract from the caller input (when that
// input is an inserted Transpose) or record parameter -> caller input in
// sub_inputs_map_ for later swapping.
void UnifyFormatPass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
  MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
  auto subgraph_name = GetValue<std::string>(sub_graph->get_attr("graph_name"));
  sub_inputs_map_[subgraph_name] = {};
  auto sub_inputs = sub_graph->get_inputs();
  for (auto &node : sub_inputs) {
    auto param_node = node->cast<ParameterPtr>();
    MS_ASSERT(param_node != nullptr);
    // The positional index is encoded in the parameter name as the
    // second-to-last underscore-separated token; +3 skips the primitive and
    // the two branch-graph value inputs of the If/While caller.
    // NOTE(review): assumes the name always carries a numeric token there —
    // std::stoi throws otherwise; confirm against the exporter's naming rule.
    auto node_name = node->fullname_with_scope();
    auto last_underline = node_name.find_last_of("_");
    node_name = node_name.substr(0, last_underline);
    last_underline = node_name.find_last_of("_");
    auto index = std::stoi(node_name.substr(last_underline + 1)) + 3;
    if (utils::isa<CNodePtr>(cnode->input(index)) && CheckPrimitiveType(cnode->input(index), prim::kPrimTranspose)) {
      // Caller feeds this parameter through an inserted Transpose: give the
      // parameter the transpose's shape, or {-1} when it is not inferred yet.
      std::vector<int> shape = {-1};
      auto trans_cnode = cnode->input(index)->cast<CNodePtr>();
      MS_ASSERT(trans_cnode != nullptr);
      auto trans_prim = GetValueNode<PrimitivePtr>(trans_cnode->input(0));
      if (trans_prim->GetAttr(kInferDone) != nullptr && GetValue<bool>(trans_prim->GetAttr(kInferDone))) {
        shape = node_infer_shape_.GetInputShape(cnode, index);
      }
      auto type = trans_cnode->abstract()->cast<abstract::AbstractTensorPtr>()->element()->GetTypeTrack();
      std::vector<int64_t> shape_vec(shape.begin(), shape.end());
      param_node->set_abstract(std::make_shared<abstract::AbstractTensor>(type, shape_vec));
    } else {
      // Plain input: remember the mapping so PreProcessFowardInsert can swap
      // the parameter for the caller's actual input during inference.
      sub_inputs_map_[subgraph_name][node] = cnode->input(index);
    }
  }
}
| void UnifyFormatPass::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { | |||
| MS_ASSERT(cnode != nullptr && sub_graph != nullptr); | |||
| auto return_node = sub_graph->get_return(); | |||
| auto origin_input = return_node->inputs(); | |||
| lite::AnfExporter::RemoveIfDepend(return_node); | |||
| lite::AnfExporter::RemoveIfMakeTuple(return_node); | |||
| for (size_t i = 1; i < return_node->size(); ++i) { | |||
| if (!CheckPrimitiveType(return_node->input(i), prim::kPrimTranspose)) { | |||
| continue; | |||
| } | |||
| auto node_name = return_node->input(i)->fullname_with_scope(); | |||
| if (node_name.substr(node_name.size() - 5) != "_post") { | |||
| continue; | |||
| } | |||
| auto trans_cnode = return_node->input(i)->cast<CNodePtr>(); | |||
| MS_ASSERT(trans_cnode != nullptr); | |||
| auto trans_input = trans_cnode->input(1); | |||
| auto trans_input_name = trans_input->fullname_with_scope(); | |||
| if (utils::isa<ParameterPtr>(trans_input)) { | |||
| trans_input->cast<ParameterPtr>()->set_name(node_name); | |||
| } else if (utils::isa<CNodePtr>(trans_input)) { | |||
| trans_input->cast<CNodePtr>()->set_fullname_with_scope(node_name); | |||
| } | |||
| trans_input_name = trans_input_name.substr(0, trans_input_name.find_last_of("_")) + "_cnode"; | |||
| trans_cnode->set_fullname_with_scope(trans_input_name); | |||
| } | |||
| return_node->set_inputs(origin_input); | |||
| } | |||
| void UnifyFormatPass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { | |||
| MS_ASSERT(cnode != nullptr && sub_graph != nullptr); | |||
| auto return_node = sub_graph->get_return(); | |||
| auto origin_inputs = return_node->inputs(); | |||
| lite::AnfExporter::RemoveIfDepend(return_node); | |||
| lite::AnfExporter::RemoveIfMakeTuple(return_node); | |||
| AbstractBasePtrList abstract_list; | |||
| bool infer_done = true; | |||
| for (size_t i = 1; i < return_node->size(); ++i) { | |||
| auto abstract_base = GetCNodeInputAbstract(return_node, i); | |||
| MS_ASSERT(abstract_base != nullptr); | |||
| abstract_list.emplace_back(abstract_base->Clone()); | |||
| auto abstract_tensor = abstract_base->cast<abstract::AbstractTensorPtr>(); | |||
| MS_ASSERT(abstract_tensor != nullptr); | |||
| auto shape_ptr = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape()); | |||
| MS_ASSERT(shape_ptr != nullptr); | |||
| auto shape = shape_ptr->shape(); | |||
| if (std::find(shape.begin(), shape.end(), -1) != shape.end()) { | |||
| infer_done = false; | |||
| } | |||
| if (utils::isa<CNodePtr>(return_node->input(i))) { | |||
| auto input_cnode = return_node->input(i)->cast<CNodePtr>(); | |||
| if (CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) { | |||
| input_cnode = input_cnode->input(1)->cast<CNodePtr>(); | |||
| } | |||
| auto input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0)); | |||
| if (input_prim->GetAttr(kInferDone) == nullptr || !GetValue<bool>(input_prim->GetAttr(kInferDone))) { | |||
| infer_done = false; | |||
| } | |||
| } | |||
| } | |||
| return_node->set_inputs(origin_inputs); | |||
| if (utils::isa<abstract::AbstractTuplePtr>(cnode->abstract())) { | |||
| cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | |||
| } else { | |||
| if (abstract_list.size() != 1) { | |||
| MS_LOG(ERROR) << "cnode output is invalid."; | |||
| } | |||
| cnode->set_abstract(abstract_list.front()); | |||
| } | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| prim->AddAttr(kInferDone, MakeValue<bool>(infer_done)); | |||
| } | |||
// Walks one graph in topological order and applies the framework-format
// transpose handling to every ordinary cnode; recurses into both branch
// subgraphs of If/While nodes. When `main_graph` is true and no reset is
// pending, graph inputs are handled first. Returns false on the first
// unrecoverable error.
bool UnifyFormatPass::BasicProcess(const FuncGraphPtr &func_graph, bool main_graph) {
  MS_ASSERT(func_graph != nullptr);
  auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name"));
  auto manager = Manage(func_graph, true);
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr.";
    return false;
  }
  auto node_list = TopoSort(func_graph->get_return());
  int status;
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (IsSpecialType(cnode)) {
      continue;
    }
    if (main_graph && !need_reset_) {
      status = HandleGraphInput(func_graph, cnode);
      if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
        return false;
      }
    }
    if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
      // Control-flow node: restore any inputs previously swapped out via
      // sub_inputs_map_, then recurse into both branch subgraphs.
      auto origin_inputs = cnode->inputs();
      // Inputs 1 and 2 hold the branch subgraphs; data inputs start at 3.
      for (size_t i = 3; i < cnode->size(); ++i) {
        if (sub_inputs_map_.find(graph_name) != sub_inputs_map_.end() &&
            sub_inputs_map_[graph_name].find(cnode->input(i)) != sub_inputs_map_[graph_name].end()) {
          cnode->set_input(i, sub_inputs_map_[graph_name][cnode->input(i)]);
        }
      }
      auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      SetSubGraphInput(cnode, sub_func_graph);
      // Recursion failures are deliberately ignored (best-effort processing).
      (void)BasicProcess(sub_func_graph, false);
      SetSubGraphOutput(cnode, sub_func_graph);
      sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      SetSubGraphInput(cnode, sub_func_graph);
      (void)BasicProcess(sub_func_graph, false);
      SetSubGraphOutput(cnode, sub_func_graph);
      // The caller's abstract is rebuilt once, from the second branch only.
      SetSubGraphAbstract(cnode, sub_func_graph);
      cnode->set_inputs(origin_inputs);
      continue;
    }
    status = HandleGraphNode(func_graph, cnode);
    if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
      return false;
    }
  }
  return true;
}
// Second optimization stage: for ops whose input format can be adapted
// dynamically (GetDynamicFormatOpList), inserts matched pre/post transpose
// pairs so that surrounding transposes can later cancel out; recurses into
// If/While subgraphs the same way BasicProcess does. Returns false on error.
bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name"));
  auto manager = Manage(func_graph, true);
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr.";
    return false;
  }
  auto node_list = TopoSort(func_graph->get_return());
  int status;
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (IsSpecialType(cnode)) {
      continue;
    }
    if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
      // Control-flow node: restore any inputs previously swapped out via
      // sub_inputs_map_, then recurse into both branch subgraphs.
      auto origin_inputs = cnode->inputs();
      // Inputs 1 and 2 hold the branch subgraphs; data inputs start at 3.
      for (size_t i = 3; i < cnode->size(); ++i) {
        if (sub_inputs_map_.find(graph_name) != sub_inputs_map_.end() &&
            sub_inputs_map_[graph_name].find(cnode->input(i)) != sub_inputs_map_[graph_name].end()) {
          cnode->set_input(i, sub_inputs_map_[graph_name][cnode->input(i)]);
        }
      }
      auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      SetSubGraphInput(cnode, sub_func_graph);
      // Recursion failures are deliberately ignored (best-effort processing).
      (void)DecreaseTransposeForSingleOp(sub_func_graph);
      SetSubGraphOutput(cnode, sub_func_graph);
      sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      SetSubGraphInput(cnode, sub_func_graph);
      (void)DecreaseTransposeForSingleOp(sub_func_graph);
      SetSubGraphOutput(cnode, sub_func_graph);
      SetSubGraphAbstract(cnode, sub_func_graph);
      cnode->set_inputs(origin_inputs);
      continue;
    }
    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    MS_ASSERT(prim != nullptr);
    // Only ops whose format can be chosen freely are candidates here.
    if (!lite::IsContain(GetDynamicFormatOpList(), prim->name())) {
      continue;
    }
    TransTypePair trans_insert_info;
    // RET_NO_CHANGE means no transpose insertion was needed for this node.
    status = InsertPreTransNode(func_graph, cnode, &trans_insert_info);
    if (status == lite::RET_NO_CHANGE) {
      continue;
    } else if (status != lite::RET_OK) {
      MS_LOG(ERROR) << "insert pre node failed.";
      return false;
    }
    auto after_perm = trans_insert_info.post_ == kNHWC2NCHW ? NH2NC : NC2NH;
    if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
      MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
      return false;
    }
  }
  return true;
}
| bool UnifyFormatPass::ResetFuncGraph(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto manager = Manage(func_graph, true); | |||
| if (manager == nullptr) { | |||
| MS_LOG(ERROR) << "manager is nullptr."; | |||
| return false; | |||
| } | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| for (auto &node : node_list) { | |||
| if (!utils::isa<CNodePtr>(node)) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| if (prim->GetAttr(kInferDone) != nullptr) { | |||
| prim->EraseAttr(kInferDone); | |||
| } | |||
| if (prim->GetAttr(kTransDone) != nullptr) { | |||
| prim->EraseAttr(kTransDone); | |||
| } | |||
| if (pre_insert_trans_.find(cnode) != pre_insert_trans_.end()) { | |||
| manager->Replace(node, cnode->input(1)); | |||
| } | |||
| if (post_insert_trans_.find(cnode) != post_insert_trans_.end()) { | |||
| auto cnode_abstract = cnode->abstract(); | |||
| if (!utils::isa<abstract::AbstractTensorPtr>(cnode_abstract)) { | |||
| MS_LOG(ERROR) << "abstract is not abstract tensor."; | |||
| return false; | |||
| } | |||
| auto cnode_abstract_tensor = cnode_abstract->cast<abstract::AbstractTensorPtr>(); | |||
| if (!utils::isa<abstract::ShapePtr>(cnode_abstract_tensor->BuildShape())) { | |||
| MS_LOG(ERROR) << "shape of abstract tensor should be ShapePtr."; | |||
| return false; | |||
| } | |||
| auto shape_ptr = utils::cast<abstract::ShapePtr>(cnode_abstract_tensor->BuildShape()); | |||
| auto input_abstract = GetCNodeInputAbstract(cnode, 1); | |||
| if (!utils::isa<abstract::AbstractTensorPtr>(input_abstract)) { | |||
| MS_LOG(ERROR) << "abstract is not abstract tensor."; | |||
| return false; | |||
| } | |||
| auto input_abstract_tensor = input_abstract->cast<abstract::AbstractTensorPtr>(); | |||
| input_abstract_tensor->set_shape(shape_ptr); | |||
| manager->Replace(node, cnode->input(1)); | |||
| } | |||
| if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) { | |||
| auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1)); | |||
| if (sub_func_graph == nullptr) { | |||
| return false; | |||
| } | |||
| (void)ResetFuncGraph(sub_func_graph); | |||
| sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2)); | |||
| if (sub_func_graph == nullptr) { | |||
| return false; | |||
| } | |||
| (void)ResetFuncGraph(sub_func_graph); | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool UnifyFormatPass::RunOnlyForShape(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| need_reset_ = true; | |||
| // insert transpose for some ops whose format must be NHWC, which is depend on framework. | |||
| // In this process, transpose op cannot be fused to restore the original graph. | |||
| if (!BasicProcess(func_graph, true)) { | |||
| MS_LOG(ERROR) << "run framework transpose unify failed."; | |||
| return false; | |||
| } | |||
| // delete insert transpose op and update op output shape. | |||
| if (!ResetFuncGraph(func_graph)) { | |||
| MS_LOG(ERROR) << "reset func_graph failed."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| for (auto &node : node_list) { | |||
| auto prim = GetValueNode<PrimitivePtr>(node); | |||
| if (prim == nullptr) { | |||
| continue; | |||
| } | |||
| if (prim->GetAttr(kTransDone) != nullptr) { | |||
| return true; | |||
| } | |||
| } | |||
| // insert transpose for some ops whose format must be NHWC, which is depend on framework. | |||
| // In this process, tranpose can be fused, which the original graph may not be able to restored. | |||
| if (!BasicProcess(func_graph, true)) { | |||
| MS_LOG(ERROR) << "run framework transpose unify failed."; | |||
| return false; | |||
| } | |||
| // if input's format of a certain op can be NHWC, can try transform this op to decrease the number of transpose op. | |||
| if (!DecreaseTransposeForSingleOp(func_graph)) { | |||
| MS_LOG(ERROR) << "run local trans insert optimizer failed."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,80 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_UNIFY_FORMAT_PASS_H_ | |||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_UNIFY_FORMAT_PASS_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "utils/utils.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "tools/optimizer/common/format_utils.h" | |||
| #include "tools/optimizer/graph/transpose_strategy.h" | |||
| using mindspore::lite::converter::FmkType; | |||
| namespace mindspore { | |||
| namespace opt { | |||
// Graph pass that unifies operator data formats (to NHWC) by inserting and
// removing transpose nodes. Also offers a shape-only mode (RunOnlyForShape)
// that restores the original graph once shapes have been propagated.
class UnifyFormatPass : public Pass {
 public:
  UnifyFormatPass() : Pass("unify_format_pass") {}
  ~UnifyFormatPass() override = default;
  // Must be called before Run: configures the source framework type and
  // training flag for shape inference and the transpose strategy.
  void Init(FmkType fmk_type, bool train_flag) {
    fmk_type_ = fmk_type;
    train_flag_ = train_flag;
    node_infer_shape_.Init(fmk_type, train_flag);
    transpose_strategy_.Init(fmk_type, train_flag);
  }
  bool Run(const FuncGraphPtr &func_graph) override;
  // Runs the pass only to propagate shapes, then restores the original graph.
  bool RunOnlyForShape(const FuncGraphPtr &func_graph);
 private:
  bool ResetFuncGraph(const FuncGraphPtr &func_graph);
  bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph);
  bool DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph);
  bool TransTransFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
  STATUS PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
  STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, bool before,
                     size_t index = 0);
  void GetTransNodeFormatType(const CNodePtr &cnode, TransTypePair *trans_info);
  STATUS HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
  STATUS HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
  STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
  STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_insert_info);
  STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
  // Temporarily swap subgraph formal parameters for the caller's actual
  // inputs during shape inference (and undo that swap afterwards).
  void PreProcessFowardInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                              std::unordered_map<AnfNodePtr, AnfNodePtr> *match);
  void PostProcessFowardInsert(const FuncGraphPtr &funcgraph, const CNodePtr &cnode,
                               const std::unordered_map<AnfNodePtr, AnfNodePtr> &match);
  void SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
  void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
  void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
  FmkType fmk_type_{lite::converter::FmkType_MS};
  // true while RunOnlyForShape executes: inserted transposes are recorded so
  // ResetFuncGraph can strip them again.
  bool need_reset_{false};
  bool train_flag_{false};
  NodeInferShape node_infer_shape_;
  TransposeStrategy transpose_strategy_;
  // Transpose cnodes inserted before/after ops; removed again on reset.
  std::set<AnfNodePtr> pre_insert_trans_;
  std::set<AnfNodePtr> post_insert_trans_;
  // graph name -> (subgraph formal parameter -> caller's actual input node).
  std::unordered_map<std::string, std::unordered_map<AnfNodePtr, AnfNodePtr>> sub_inputs_map_;
};
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_UNIFY_FORMAT_PASS_H_ | |||
| @@ -16,7 +16,6 @@ | |||
| #include "tools/optimizer/graph/weight_format_hardcode_pass.h" | |||
| #include <memory> | |||
| #include "ops/fusion/conv2d_fusion.h" | |||
| #include "ops/fusion/conv2d_backprop_input_fusion.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| using mindspore::lite::converter::FmkType_CAFFE; | |||
| @@ -33,7 +32,6 @@ using mindspore::schema::QuantType_WeightQuant; | |||
| namespace mindspore::opt { | |||
| namespace { | |||
| constexpr size_t kConvWeightIndex = 2; | |||
| const PrimitivePtr kPrimConv2DBackpropInputFusion = std::make_shared<Primitive>(ops::kNameConv2DBackpropInputFusion); | |||
| } // namespace | |||
| void WeightFormatHardCodePass::SetQuantType(QuantType type) { this->quant_type = type; } | |||
| void WeightFormatHardCodePass::SetFmkType(FmkType type) { this->fmk_type = type; } | |||
| @@ -17,7 +17,6 @@ | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include <vector> | |||
| #include "ops/fusion/conv2d_backprop_input_fusion.h" | |||
| #include "ops/transpose.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| @@ -34,7 +33,6 @@ namespace mindspore::opt { | |||
| namespace { | |||
| constexpr size_t kFirstInputIndex = 1; | |||
| constexpr size_t kConvWeightIndex = 2; | |||
| const PrimitivePtr kPrimConv2DBackpropInputFusion = std::make_shared<Primitive>(ops::kNameConv2DBackpropInputFusion); | |||
| lite::STATUS GetTransposePerm(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm) { | |||
| MS_ASSERT(perm != nullptr); | |||
| auto src_format_str = std::string(schema::EnumNameFormat(src_format)); | |||