From 7a6876c21ca15a3e9e573aa446c1873b588d909c Mon Sep 17 00:00:00 2001
From: tx1103mark
Date: Thu, 26 Nov 2020 15:47:47 +0800
Subject: [PATCH] converter fusion module sec check

---
 mindspore/lite/tools/common/graph_util.cc     |  1 -
 .../lite/tools/converter/anf_transform.cc     | 48 ++++++++-----------
 .../lite/tools/converter/converter_flags.cc   |  4 +-
 .../lite/tools/optimizer/common/gllo_utils.cc |  8 ++++
 .../lite/tools/optimizer/common/gllo_utils.h  |  2 +
 .../fusion/constant_folding_fusion.cc         |  2 +-
 .../fusion/constant_folding_fusion.h          | 12 ++---
 .../fusion/conv_activation_fusion.cc          | 14 +++---
 .../optimizer/fusion/conv_activation_fusion.h |  8 +---
 .../tools/optimizer/fusion/conv_bn_fusion.cc  |  4 +-
 .../tools/optimizer/fusion/conv_bn_fusion.h   |  2 +-
 .../optimizer/fusion/conv_conv_fusion.cc      |  7 +--
 .../optimizer/fusion/conv_scale_fusion.cc     |  4 +-
 .../optimizer/fusion/conv_scale_fusion.h      |  2 +-
 .../optimizer/fusion/conv_transform_fusion.cc | 17 ++++---
 .../optimizer/fusion/conv_transform_fusion.h  | 10 ++--
 .../fusion/conv_tuple_activation_fusion.cc    | 28 ++++++-----
 .../fusion/conv_tuple_activation_fusion.h     |  8 +---
 .../optimizer/fusion/sigmoid_mul_fusion.cc    |  7 ---
 19 files changed, 88 insertions(+), 100 deletions(-)

diff --git a/mindspore/lite/tools/common/graph_util.cc b/mindspore/lite/tools/common/graph_util.cc
index 30f2987875..dc16534641 100644
--- a/mindspore/lite/tools/common/graph_util.cc
+++ b/mindspore/lite/tools/common/graph_util.cc
@@ -736,7 +736,6 @@ STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr
       MS_LOG(ERROR) << "node->primitive->value.AsCrop() is nullptr";
       return RET_NULL_PTR;
     }
-    node->primitive->value.AsCrop()->axis = axis_map[origin_axis];
     // nchw->nhwc,offsets need pad 0;
     if (axis_map[origin_axis] == 0) {
       offsets = {offsets[0], offsets[2], offsets[3], offsets[1]};
diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc
index 28332eca12..a862719436 100644
--- a/mindspore/lite/tools/converter/anf_transform.cc
+++ b/mindspore/lite/tools/converter/anf_transform.cc
@@ -57,34 +57,23 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
     return nullptr;
   }
   auto optimizer = std::make_shared();
-  auto pm = std::make_shared("anf fusion pass manager", false);
+  auto fusion_pm = std::make_shared("anf fusion pass manager", false);
   auto graph_pm = std::make_shared("anf graph pass manager", true);
   auto convert_pm = std::make_shared("anf graph convert pass manager", true);
-  // fusion const_fold
-  auto cf_pm = std::make_shared("constant folding pass manager", false);
-  cf_pm->AddPass(std::make_shared());
-  cf_pm->AddPass(std::make_shared());
-  // for now - trainning is not supporting fuse operations
   if (!config->trainModel) {
     // remove quantdtype when awaretraining
-    pm->AddPass(std::make_shared());
-    pm->AddPass(std::make_shared());
-    pm->AddPass(std::make_shared());
-    pm->AddPass(std::make_shared());
-    pm->AddPass(std::make_shared());
-    pm->AddPass(std::make_shared());
-    pm->AddPass(std::make_shared());
-    pm->AddPass(std::make_shared(true, "conv_relu", schema::PrimitiveType_Activation,
-                                 schema::ActivationType_RELU));
-    pm->AddPass(std::make_shared(true, "conv_relu6", schema::PrimitiveType_Activation,
-                                 schema::ActivationType_RELU6));
-    pm->AddPass(std::make_shared(
-      true, "conv_tuple_relu", schema::PrimitiveType_Activation, schema::ActivationType_RELU));
-    pm->AddPass(std::make_shared(
-      true, "conv_tuple_relu6", schema::PrimitiveType_Activation, schema::ActivationType_RELU6));
-    pm->AddPass(std::make_shared());
+    fusion_pm->AddPass(std::make_shared());
+    fusion_pm->AddPass(std::make_shared());
+    fusion_pm->AddPass(std::make_shared());
+    fusion_pm->AddPass(std::make_shared());
+    fusion_pm->AddPass(std::make_shared());
+    fusion_pm->AddPass(std::make_shared());
+    fusion_pm->AddPass(std::make_shared());
+    fusion_pm->AddPass(std::make_shared());
+    fusion_pm->AddPass(std::make_shared());
+    fusion_pm->AddPass(std::make_shared());
   }
   auto weight_format_hardcode_pass = std::make_shared();
   weight_format_hardcode_pass->SetFmkType(config->fmk);
@@ -108,7 +97,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
       return nullptr;
     }
     remove_unused_cast_pass->SetFmkType(config->fmk);
-    pm->AddPass(remove_unused_cast_pass);
+    fusion_pm->AddPass(remove_unused_cast_pass);
   }
   if (config->fmk == lite::converter::FmkType_ONNX) {
     auto remove_unused_transpose_pass = std::make_shared();
@@ -117,17 +106,22 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
       return nullptr;
     }
     remove_unused_transpose_pass->SetFmkType(config->fmk);
-    pm->AddPass(remove_unused_transpose_pass);
+    fusion_pm->AddPass(remove_unused_transpose_pass);
   }
-  pm->AddPass(std::make_shared());
+  auto const_fold_pm = std::make_shared("const fold fusion pass manager", false);
+  auto inne_context_ptr = std::make_shared();
+  inne_context_ptr->Init();
+  const_fold_pm->AddPass(std::make_shared(inne_context_ptr));
+  const_fold_pm->AddPass(std::make_shared());
+  fusion_pm->AddPass(std::make_shared());
   convert_pm->AddPass(std::make_shared());
   if (config->fmk == lite::converter::FmkType_TFLITE) {
     convert_pm->AddPass(std::make_shared());
     convert_pm->AddPass(std::make_shared());
   }
-  optimizer->AddPassManager(cf_pm);
+  optimizer->AddPassManager(const_fold_pm);
   optimizer->AddPassManager(convert_pm);
-  optimizer->AddPassManager(pm);
+  optimizer->AddPassManager(fusion_pm);
   optimizer->AddPassManager(graph_pm);
   auto new_graph = optimizer->Optimize(old_graph);
   if (new_graph == nullptr) {
diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_flags.cc
index 40ccc61209..2327b17122 100644
--- a/mindspore/lite/tools/converter/converter_flags.cc
+++ b/mindspore/lite/tools/converter/converter_flags.cc
@@ -24,9 +24,9 @@ namespace mindspore {
 namespace lite {
 namespace converter {
 Flags::Flags() {
-  AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TFLITE | CAFFE | MINDIR | ONNX", "");
+  AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TF | TFLITE | CAFFE | MINDIR | ONNX", "");
   AddFlag(&Flags::modelFile, "modelFile",
-          "Input model file. TFLITE: *.tflite | CAFFE: *.prototxt | MINDIR: *.mindir | ONNX: *.onnx", "");
+          "Input model file. TF: *.pb | TFLITE: *.tflite | CAFFE: *.prototxt | MINDIR: *.mindir | ONNX: *.onnx", "");
   AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", "");
   AddFlag(&Flags::weightFile, "weightFile",
           "Input model weight file. Needed when fmk is CAFFE. CAFFE: *.caffemodel", "");
diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc
index 940b2a73a1..0a79383d2a 100644
--- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc
+++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc
@@ -494,6 +494,14 @@ bool IsPoolingNode(const BaseRef &n) {
   return false;
 }
 
+bool IsActivationNode(const BaseRef &n) {
+  if (utils::isa(n) || utils::isa(n)) {
+    auto type = opt::GetCNodeType(n);
+    return type == schema::PrimitiveType_Activation;
+  }
+  return false;
+}
+
 bool IsQuantNode(const BaseRef &n) {
   if (utils::isa(n) || utils::isa(n)) {
     auto type = opt::GetCNodeType(n);
diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.h b/mindspore/lite/tools/optimizer/common/gllo_utils.h
index 893a5b48fa..e0c4dd760f 100644
--- a/mindspore/lite/tools/optimizer/common/gllo_utils.h
+++ b/mindspore/lite/tools/optimizer/common/gllo_utils.h
@@ -65,6 +65,8 @@ bool IsPoolingNode(const BaseRef &n);
 
 bool IsQuantNode(const BaseRef &n);
 
+bool IsActivationNode(const BaseRef &n);
+
 bool CheckIsAllInputsParam(const AnfNodePtr &node);
 
 size_t GetOutputTensorNum(const AnfNodePtr &node);
diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc
index 566ddb0d93..1d4388a4d3 100644
--- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc
@@ -252,7 +252,7 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
                   << schema::EnumNamePrimitiveType((schema::PrimitiveType)(lite_primitive->Type()));
     return nullptr;
   }
-  auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, context, lite_primitive.get());
+  auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, context.get(), lite_primitive.get());
   if (lite_kernel == nullptr) {
     MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
     FreeTensors(&input_tensors, &output_tensors);
diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.h b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.h
index fe53b74ac3..9de1eb2d03 100644
--- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.h
+++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.h
@@ -17,6 +17,8 @@
 #ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_
 #define MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_
+#include
+#include
 #include "schema/inner/model_generated.h"
 #include "src/tensor.h"
 #include "src/lite_kernel.h"
@@ -27,15 +29,13 @@ namespace mindspore {
 namespace opt {
 class ConstFoldPass : public PatternProcessPass {
  public:
-  explicit ConstFoldPass(bool multigraph = true) : PatternProcessPass("constfold_pass", multigraph) {
-    this->context = new lite::InnerContext;
-    this->context->Init();
-  }
-  ~ConstFoldPass() override { delete (this->context); }
+  explicit ConstFoldPass(std::shared_ptr context_ptr = nullptr, bool multigraph = true)
+      : PatternProcessPass("constfold_pass", multigraph), context(std::move(context_ptr)) {}
+  ~ConstFoldPass() override = default;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 
  private:
-  lite::InnerContext *context = nullptr;
+  std::shared_ptr context;
 };
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc
index e82d8cb6d1..47b8d172b0 100644
--- a/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc
@@ -29,17 +29,14 @@ constexpr size_t kActivationInputsLength = 2;
 }
 const BaseRef ConvActivationFusion::DefinePattern() const {
   auto conv_var = std::make_shared(IsConvNode);
-  auto prim = new schema::PrimitiveT();
-  prim->value.type = primitive_type;
-  auto prim_value = std::make_shared(prim);
-  return VectorRef({prim_value, conv_var});
+  auto act_var = std::make_shared(IsActivationNode);
+  return VectorRef({act_var, conv_var});
 }
 
 const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                const EquivPtr &) const {
   MS_ASSERT(func_graph != nullptr);
   MS_ASSERT(node != nullptr);
-  MS_LOG(DEBUG) << "conv activation pass process:" << schema::EnumNamesPrimitiveType()[primitive_type];
   if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
     return nullptr;
@@ -53,7 +50,8 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c
   MS_ASSERT(utils::isa>(primitivec));
   auto act_primitivec = utils::cast>(primitivec);
   MS_ASSERT(act_primitivec != nullptr);
-  if (act_primitivec->GetType() != activation_type) {
+  if (act_primitivec->GetType() != schema::ActivationType_RELU &&
+      act_primitivec->GetType() != schema::ActivationType_RELU6) {
     return nullptr;
   }
   AnfNodePtr pre_node = act_node->input(1);
@@ -73,7 +71,7 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c
     auto primc = utils::cast>(primitive_c);
     MS_ASSERT(primc != nullptr);
     if (primc->GetActivationType() == schema::ActivationType_NO_ACTIVATION) {
-      primc->SetActivationType(activation_type);
+      primc->SetActivationType(act_primitivec->GetType());
       return pre_node;
     }
   } else if (node_type == schema::PrimitiveType_DepthwiseConv2D) {
@@ -81,7 +79,7 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c
     auto primc = utils::cast>(primitive_c);
     MS_ASSERT(primc != nullptr);
     if (primc->GetActivationType() == schema::ActivationType_NO_ACTIVATION) {
-      primc->SetActivationType(activation_type);
+      primc->SetActivationType(act_primitivec->GetType());
       return pre_node;
     }
   } else {
diff --git a/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.h b/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.h
index ed6417d59b..39077fe9a9 100644
--- a/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.h
+++ b/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.h
@@ -25,15 +25,11 @@ namespace mindspore {
 namespace opt {
 class ConvActivationFusion : public PatternProcessPass {
  public:
-  ConvActivationFusion(bool multigraph = true, const std::string &name = "conv_activation_fusion",
-                       schema::PrimitiveType primitive = schema::PrimitiveType_LeakyReLU,
-                       schema::ActivationType activation = schema::ActivationType_LEAKY_RELU)
-      : PatternProcessPass(name, multigraph), primitive_type(primitive), activation_type(activation) {}
+  explicit ConvActivationFusion(bool multigraph = true, const std::string &name = "conv_activation_fusion")
+      : PatternProcessPass(name, multigraph) {}
   ~ConvActivationFusion() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-  schema::PrimitiveType primitive_type;
-  schema::ActivationType activation_type;
 };
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc
index 45f1bdff16..802a9f805c 100644
--- a/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc
@@ -113,8 +113,8 @@ const BaseRef ConvBatchNormFusion::DefinePattern() const {
 //   bias               --1
 //   estimated_mean     --2
 //   estimated_variance --3
-const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kernel_num, float *trans_scale,
-                                               float *trans_bias) const {
+void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kernel_num, float *trans_scale,
+                                         float *trans_bias) const {
   MS_ASSERT(bn_node != nullptr);
   MS_ASSERT(trans_bias != nullptr);
   MS_ASSERT(trans_scale != nullptr);
diff --git a/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.h b/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.h
index 201e582fcd..3646927c2d 100644
--- a/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.h
+++ b/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.h
@@ -25,7 +25,7 @@ class ConvBatchNormFusion : public ConvTransformFusion {
   explicit ConvBatchNormFusion(bool multigraph = true) : ConvTransformFusion(multigraph, "conv_batchnorm_fusion") {}
   ~ConvBatchNormFusion() override = default;
   const BaseRef DefinePattern() const override;
-  const void InitTransParam(const CNodePtr &, int, float *, float *) const override;
+  void InitTransParam(const CNodePtr &, int, float *, float *) const override;
 };
 }  // namespace mindspore::opt
 #endif  // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_
diff --git a/mindspore/lite/tools/optimizer/fusion/conv_conv_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_conv_fusion.cc
index e791e6b93b..f5a03cf7df 100644
--- a/mindspore/lite/tools/optimizer/fusion/conv_conv_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/conv_conv_fusion.cc
@@ -15,11 +15,11 @@
  */
 #include "tools/optimizer/fusion/conv_conv_fusion.h"
-#include
 #include
-#include "src/ops/primitive_c.h"
-#include "src/ops/conv2d.h"
+#include
 #include "schema/inner/model_generated.h"
+#include "src/ops/conv2d.h"
+#include "src/ops/primitive_c.h"
 #include "tools/optimizer/common/gllo_utils.h"
 
 namespace mindspore::opt {
@@ -128,6 +128,7 @@ STATUS GenNewConvWeight(const ParameterPtr &down_weight_node, const ParameterPtr
   for (int k = 0; k < cout0; k++) {
     auto up_weight_offset = k * window_size * cin0 + j;
     auto down_weight_offset = down_weight_base + k;
+    auto new_weight_offset = new_weight_base + j;
 
     for (int m = 0; m < window_size; m++) {
       new_weight_data[new_weight_offset + cin0 * m] +=
diff --git a/mindspore/lite/tools/optimizer/fusion/conv_scale_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_scale_fusion.cc
index 1a53132e73..af52cb2818 100644
--- a/mindspore/lite/tools/optimizer/fusion/conv_scale_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/conv_scale_fusion.cc
@@ -44,8 +44,8 @@ const BaseRef ConvScaleFusion::DefinePattern() const {
   auto bias_var = std::make_shared();
   return VectorRef({bn_var, conv_var, weight_var, bias_var});
 }
-const void ConvScaleFusion::InitTransParam(const CNodePtr &scale_node, int kernel_num, float *trans_scale,
-                                           float *trans_bias) const {
+void ConvScaleFusion::InitTransParam(const CNodePtr &scale_node, int kernel_num, float *trans_scale,
+                                     float *trans_bias) const {
   MS_ASSERT(scale_node != nullptr);
   MS_ASSERT(trans_bias != nullptr);
   MS_ASSERT(trans_scale != nullptr);
diff --git a/mindspore/lite/tools/optimizer/fusion/conv_scale_fusion.h b/mindspore/lite/tools/optimizer/fusion/conv_scale_fusion.h
index 969490dd59..ac58a6db0f 100644
--- a/mindspore/lite/tools/optimizer/fusion/conv_scale_fusion.h
+++ b/mindspore/lite/tools/optimizer/fusion/conv_scale_fusion.h
@@ -25,7 +25,7 @@ class ConvScaleFusion : public ConvTransformFusion {
   explicit ConvScaleFusion(bool multigraph = true) : ConvTransformFusion(multigraph, "conv_scale_fusion") {}
   ~ConvScaleFusion() override = default;
   const BaseRef DefinePattern() const override;
-  const void InitTransParam(const CNodePtr &, int, float *, float *) const override;
+  void InitTransParam(const CNodePtr &, int, float *, float *) const override;
 };
 }  // namespace mindspore::opt
 #endif  // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_
diff --git a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc
index 745e6ac253..1546927589 100644
--- a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc
@@ -119,8 +119,8 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co
   return pre_node;
 }
 
-const void ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, int kernel_nums, float *trans_scale,
-                                              float *trans_bias) const {
+void ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, int kernel_nums, float *trans_scale,
+                                        float *trans_bias) const {
   if (trans_scale == nullptr) {
     MS_LOG(ERROR) << "new transScale failed";
     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
@@ -145,9 +145,8 @@ const void ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, in
   InitTransParam(transform_node, kernel_nums, trans_scale, trans_bias);
 }
 
-const void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const CNodePtr &conv_node,
-                                                 int kernel_num, const float *trans_scale,
-                                                 const float *trans_bias) const {
+void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, int kernel_num,
+                                           const float *trans_scale, const float *trans_bias) const {
   MS_ASSERT(conv_node != nullptr);
   AnfNodePtr conv_weight_node = nullptr;
   AnfNodePtr conv_bias_node = nullptr;
@@ -203,8 +202,8 @@ const void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph,
     conv_node->add_input(bias_node);
   }
 }
-const void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kernel_num, int kernel_size,
-                                                   const float *trans_scale) const {
+void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kernel_num, int kernel_size,
+                                             const float *trans_scale) const {
   MS_ASSERT(weight_data != nullptr);
   MS_ASSERT(trans_scale != nullptr);
   auto tmp_weight_data = new (std::nothrow) float[kernel_num * kernel_size];
@@ -237,8 +236,8 @@ const void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kerne
   delete[] tmp_weight_data;
 }
 
-const void ConvTransformFusion::CalNewBiasTensor(float *bias_data, int kernel_num, bool bias_flag,
-                                                 const float *trans_scale, const float *trans_bias) {
+void ConvTransformFusion::CalNewBiasTensor(float *bias_data, int kernel_num, bool bias_flag, const float *trans_scale,
+                                           const float *trans_bias) const {
   MS_ASSERT(bias_data != nullptr);
   MS_ASSERT(trans_bias != nullptr);
   MS_ASSERT(trans_scale != nullptr);
diff --git a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.h b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.h
index da161c0192..04ba3e5f91 100644
--- a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.h
+++ b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.h
@@ -27,11 +27,11 @@ class ConvTransformFusion : public PatternProcessPass {
       : PatternProcessPass(name, multigraph) {}
   ~ConvTransformFusion() override = default;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-  const void GenTransParam(const CNodePtr &, int, float *, float *) const;
-  virtual const void InitTransParam(const CNodePtr &, int, float *, float *) const = 0;
-  const void GenNewConvTensor(const FuncGraphPtr &, const CNodePtr &, int, const float *, const float *) const;
-  const void CalNewWeightTensor(float *, int, int, const float *) const;
-  static const void CalNewBiasTensor(float *, int, bool, const float *, const float *);
+  void GenTransParam(const CNodePtr &, int, float *, float *) const;
+  virtual void InitTransParam(const CNodePtr &, int, float *, float *) const = 0;
+  void GenNewConvTensor(const FuncGraphPtr &, const CNodePtr &, int, const float *, const float *) const;
+  void CalNewWeightTensor(float *, int, int, const float *) const;
+  void CalNewBiasTensor(float *, int, bool, const float *, const float *) const;
 };
 }  // namespace mindspore::opt
 #endif  // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_
diff --git a/mindspore/lite/tools/optimizer/fusion/conv_tuple_activation_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_tuple_activation_fusion.cc
index 6f4ec96295..796a928f72 100644
--- a/mindspore/lite/tools/optimizer/fusion/conv_tuple_activation_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/conv_tuple_activation_fusion.cc
@@ -26,26 +26,27 @@ namespace mindspore::opt {
 namespace {
 constexpr size_t kActivationInputsLength = 2;
+bool IsTupleGetItemNode(const BaseRef &n) {
+  if (utils::isa(n) || utils::isa(n)) {
+    auto type = opt::GetCNodeType(n);
+    return type == schema::PrimitiveType_TupleGetItem;
+  }
+  return false;
 }
+}  // namespace
 
 const BaseRef ConvTupleActivationFusion::DefinePattern() const {
   auto conv_var = std::make_shared(IsConvNode);
+  auto tuple_getitem_var = std::make_shared(IsTupleGetItemNode);
   auto tuple_index = std::make_shared();
-  auto tuple_prim = new schema::PrimitiveT();
-  tuple_prim->value.type = schema::PrimitiveType_TupleGetItem;
-  auto tuple_value = std::make_shared(tuple_prim);
-  VectorRef tuple_get_item = VectorRef({tuple_value, conv_var, tuple_index});
-
-  auto act_prim = new schema::PrimitiveT();
-  act_prim->value.type = primitive_type;
-  auto act_value = std::make_shared(act_prim);
-  return VectorRef({act_value, tuple_get_item});
+  VectorRef tuple_get_item = VectorRef({tuple_getitem_var, conv_var, tuple_index});
+  auto act_var = std::make_shared(IsActivationNode);
+  return VectorRef({act_var, tuple_get_item});
 }
 
 const AnfNodePtr ConvTupleActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                     const EquivPtr &) const {
   MS_ASSERT(func_graph != nullptr);
   MS_ASSERT(node != nullptr);
-  MS_LOG(DEBUG) << "conv tuple activation pass process:" << schema::EnumNamesPrimitiveType()[primitive_type];
   if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
     return nullptr;
   }
@@ -59,7 +60,8 @@ const AnfNodePtr ConvTupleActivationFusion::Process(const FuncGraphPtr &func_gra
   MS_ASSERT(utils::isa>(primitivec));
   auto act_primitivec = utils::cast>(primitivec);
   MS_ASSERT(act_primitivec != nullptr);
-  if (act_primitivec->GetType() != activation_type) {
+  if (act_primitivec->GetType() != schema::ActivationType_RELU &&
+      act_primitivec->GetType() != schema::ActivationType_RELU6) {
     return nullptr;
   }
   AnfNodePtr tuple_node = act_node->input(1);
@@ -82,7 +84,7 @@ const AnfNodePtr ConvTupleActivationFusion::Process(const FuncGraphPtr &func_gra
       auto primc = utils::cast>(primitive_c);
       MS_ASSERT(primc != nullptr);
       if (primc->GetActivationType() == schema::ActivationType_NO_ACTIVATION) {
-        primc->SetActivationType(activation_type);
+        primc->SetActivationType(act_primitivec->GetType());
         conv_node->set_abstract(act_node->abstract());
         return conv_node;
       }
@@ -91,7 +93,7 @@ const AnfNodePtr ConvTupleActivationFusion::Process(const FuncGraphPtr &func_gra
       auto primc = utils::cast>(primitive_c);
       MS_ASSERT(primc != nullptr);
       if (primc->GetActivationType() == schema::ActivationType_NO_ACTIVATION) {
-        primc->SetActivationType(activation_type);
+        primc->SetActivationType(act_primitivec->GetType());
         conv_node->set_abstract(act_node->abstract());
         return conv_node;
       }
diff --git a/mindspore/lite/tools/optimizer/fusion/conv_tuple_activation_fusion.h b/mindspore/lite/tools/optimizer/fusion/conv_tuple_activation_fusion.h
index e89974976f..74b499415f 100644
--- a/mindspore/lite/tools/optimizer/fusion/conv_tuple_activation_fusion.h
+++ b/mindspore/lite/tools/optimizer/fusion/conv_tuple_activation_fusion.h
@@ -25,15 +25,11 @@ namespace mindspore {
 namespace opt {
 class ConvTupleActivationFusion : public PatternProcessPass {
  public:
-  ConvTupleActivationFusion(bool multigraph = true, const std::string &name = "conv_tuple_activation_fusion",
-                            schema::PrimitiveType primitive = schema::PrimitiveType_LeakyReLU,
-                            schema::ActivationType activation = schema::ActivationType_LEAKY_RELU)
-      : PatternProcessPass(name, multigraph), primitive_type(primitive), activation_type(activation) {}
+  explicit ConvTupleActivationFusion(bool multigraph = true, const std::string &name = "conv_tuple_activation_fusion")
+      : PatternProcessPass(name, multigraph) {}
   ~ConvTupleActivationFusion() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-  schema::PrimitiveType primitive_type;
-  schema::ActivationType activation_type;
 };
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/lite/tools/optimizer/fusion/sigmoid_mul_fusion.cc b/mindspore/lite/tools/optimizer/fusion/sigmoid_mul_fusion.cc
index af56444dd1..c61617bd18 100644
--- a/mindspore/lite/tools/optimizer/fusion/sigmoid_mul_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/sigmoid_mul_fusion.cc
@@ -24,13 +24,6 @@
 namespace mindspore::opt {
 namespace {
-bool IsActivationNode(const BaseRef &n) {
-  if (utils::isa(n) || utils::isa(n)) {
-    auto type = opt::GetCNodeType(n);
-    return type == schema::PrimitiveType_Activation;
-  }
-  return false;
-}
 bool IsMulNode(const BaseRef &n) {
   if (utils::isa(n) || utils::isa(n)) {
     auto type = opt::GetCNodeType(n);