| @@ -736,7 +736,6 @@ STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNo | |||
| MS_LOG(ERROR) << "node->primitive->value.AsCrop() is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| node->primitive->value.AsCrop()->axis = axis_map[origin_axis]; | |||
| // nchw->nhwc,offsets need pad 0; | |||
| if (axis_map[origin_axis] == 0) { | |||
| offsets = {offsets[0], offsets[2], offsets[3], offsets[1]}; | |||
| @@ -57,34 +57,23 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||
| return nullptr; | |||
| } | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false); | |||
| auto fusion_pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false); | |||
| auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true); | |||
| auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true); | |||
| // fusion const_fold | |||
| auto cf_pm = std::make_shared<opt::PassManager>("constant folding pass manager", false); | |||
| cf_pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | |||
| cf_pm->AddPass(std::make_shared<opt::UpdateConv2DParamPass>()); | |||
| // for now - trainning is not supporting fuse operations | |||
| if (!config->trainModel) { | |||
| // remove quantdtype when awaretraining | |||
| pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>()); | |||
| pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ConvScaleFusion>()); | |||
| pm->AddPass(std::make_shared<opt::LayerNormFusion>()); | |||
| pm->AddPass(std::make_shared<opt::BatchMatMulFusion>()); | |||
| pm->AddPass(std::make_shared<opt::SigmoidMulFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ConvActivationFusion>(true, "conv_relu", schema::PrimitiveType_Activation, | |||
| schema::ActivationType_RELU)); | |||
| pm->AddPass(std::make_shared<opt::ConvActivationFusion>(true, "conv_relu6", schema::PrimitiveType_Activation, | |||
| schema::ActivationType_RELU6)); | |||
| pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>( | |||
| true, "conv_tuple_relu", schema::PrimitiveType_Activation, schema::ActivationType_RELU)); | |||
| pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>( | |||
| true, "conv_tuple_relu6", schema::PrimitiveType_Activation, schema::ActivationType_RELU6)); | |||
| pm->AddPass(std::make_shared<opt::ConvTupleGetItemFusion>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::ConvScaleFusion>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::LayerNormFusion>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::BatchMatMulFusion>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::SigmoidMulFusion>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::ConvActivationFusion>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::ConvTupleGetItemFusion>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>()); | |||
| } | |||
| auto weight_format_hardcode_pass = std::make_shared<opt::WeightFormatHardCodePass>(); | |||
| weight_format_hardcode_pass->SetFmkType(config->fmk); | |||
| @@ -108,7 +97,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||
| return nullptr; | |||
| } | |||
| remove_unused_cast_pass->SetFmkType(config->fmk); | |||
| pm->AddPass(remove_unused_cast_pass); | |||
| fusion_pm->AddPass(remove_unused_cast_pass); | |||
| } | |||
| if (config->fmk == lite::converter::FmkType_ONNX) { | |||
| auto remove_unused_transpose_pass = std::make_shared<opt::RemoveUnusedTransposeOpPass>(); | |||
| @@ -117,17 +106,22 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||
| return nullptr; | |||
| } | |||
| remove_unused_transpose_pass->SetFmkType(config->fmk); | |||
| pm->AddPass(remove_unused_transpose_pass); | |||
| fusion_pm->AddPass(remove_unused_transpose_pass); | |||
| } | |||
| pm->AddPass(std::make_shared<opt::ConvConvFusion>()); | |||
| auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false); | |||
| auto inne_context_ptr = std::make_shared<lite::InnerContext>(); | |||
| inne_context_ptr->Init(); | |||
| const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(inne_context_ptr)); | |||
| const_fold_pm->AddPass(std::make_shared<opt::UpdateConv2DParamPass>()); | |||
| fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>()); | |||
| convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>()); | |||
| if (config->fmk == lite::converter::FmkType_TFLITE) { | |||
| convert_pm->AddPass(std::make_shared<opt::GroupDepthwiseOpConvertPass>()); | |||
| convert_pm->AddPass(std::make_shared<opt::TfliteInputsOrderExchangePass>()); | |||
| } | |||
| optimizer->AddPassManager(cf_pm); | |||
| optimizer->AddPassManager(const_fold_pm); | |||
| optimizer->AddPassManager(convert_pm); | |||
| optimizer->AddPassManager(pm); | |||
| optimizer->AddPassManager(fusion_pm); | |||
| optimizer->AddPassManager(graph_pm); | |||
| auto new_graph = optimizer->Optimize(old_graph); | |||
| if (new_graph == nullptr) { | |||
| @@ -24,9 +24,9 @@ namespace mindspore { | |||
| namespace lite { | |||
| namespace converter { | |||
| Flags::Flags() { | |||
| AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TFLITE | CAFFE | MINDIR | ONNX", ""); | |||
| AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TF | TFLITE | CAFFE | MINDIR | ONNX", ""); | |||
| AddFlag(&Flags::modelFile, "modelFile", | |||
| "Input model file. TFLITE: *.tflite | CAFFE: *.prototxt | MINDIR: *.mindir | ONNX: *.onnx", ""); | |||
| "Input model file. TF: *.pb | TFLITE: *.tflite | CAFFE: *.prototxt | MINDIR: *.mindir | ONNX: *.onnx", ""); | |||
| AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", ""); | |||
| AddFlag(&Flags::weightFile, "weightFile", "Input model weight file. Needed when fmk is CAFFE. CAFFE: *.caffemodel", | |||
| ""); | |||
| @@ -494,6 +494,14 @@ bool IsPoolingNode(const BaseRef &n) { | |||
| return false; | |||
| } | |||
| bool IsActivationNode(const BaseRef &n) { | |||
| if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) { | |||
| auto type = opt::GetCNodeType(n); | |||
| return type == schema::PrimitiveType_Activation; | |||
| } | |||
| return false; | |||
| } | |||
| bool IsQuantNode(const BaseRef &n) { | |||
| if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) { | |||
| auto type = opt::GetCNodeType(n); | |||
| @@ -65,6 +65,8 @@ bool IsPoolingNode(const BaseRef &n); | |||
| bool IsQuantNode(const BaseRef &n); | |||
| bool IsActivationNode(const BaseRef &n); | |||
| bool CheckIsAllInputsParam(const AnfNodePtr &node); | |||
| size_t GetOutputTensorNum(const AnfNodePtr &node); | |||
| @@ -252,7 +252,7 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An | |||
| << schema::EnumNamePrimitiveType((schema::PrimitiveType)(lite_primitive->Type())); | |||
| return nullptr; | |||
| } | |||
| auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, context, lite_primitive.get()); | |||
| auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, context.get(), lite_primitive.get()); | |||
| if (lite_kernel == nullptr) { | |||
| MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr"; | |||
| FreeTensors(&input_tensors, &output_tensors); | |||
| @@ -17,6 +17,8 @@ | |||
| #ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_ | |||
| #define MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_ | |||
| #include <utility> | |||
| #include <memory> | |||
| #include "schema/inner/model_generated.h" | |||
| #include "src/tensor.h" | |||
| #include "src/lite_kernel.h" | |||
| @@ -27,15 +29,13 @@ namespace mindspore { | |||
| namespace opt { | |||
| class ConstFoldPass : public PatternProcessPass { | |||
| public: | |||
| explicit ConstFoldPass(bool multigraph = true) : PatternProcessPass("constfold_pass", multigraph) { | |||
| this->context = new lite::InnerContext; | |||
| this->context->Init(); | |||
| } | |||
| ~ConstFoldPass() override { delete (this->context); } | |||
| explicit ConstFoldPass(std::shared_ptr<lite::InnerContext> context_ptr = nullptr, bool multigraph = true) | |||
| : PatternProcessPass("constfold_pass", multigraph), context(std::move(context_ptr)) {} | |||
| ~ConstFoldPass() override = default; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| lite::InnerContext *context = nullptr; | |||
| std::shared_ptr<lite::InnerContext> context; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -29,17 +29,14 @@ constexpr size_t kActivationInputsLength = 2; | |||
| } | |||
| const BaseRef ConvActivationFusion::DefinePattern() const { | |||
| auto conv_var = std::make_shared<CondVar>(IsConvNode); | |||
| auto prim = new schema::PrimitiveT(); | |||
| prim->value.type = primitive_type; | |||
| auto prim_value = std::make_shared<lite::PrimitiveC>(prim); | |||
| return VectorRef({prim_value, conv_var}); | |||
| auto act_var = std::make_shared<CondVar>(IsActivationNode); | |||
| return VectorRef({act_var, conv_var}); | |||
| } | |||
| const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| MS_LOG(DEBUG) << "conv activation pass process:" << schema::EnumNamesPrimitiveType()[primitive_type]; | |||
| if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| return nullptr; | |||
| @@ -53,7 +50,8 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c | |||
| MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Activation>>(primitivec)); | |||
| auto act_primitivec = utils::cast<std::shared_ptr<mindspore::lite::Activation>>(primitivec); | |||
| MS_ASSERT(act_primitivec != nullptr); | |||
| if (act_primitivec->GetType() != activation_type) { | |||
| if (act_primitivec->GetType() != schema::ActivationType_RELU && | |||
| act_primitivec->GetType() != schema::ActivationType_RELU6) { | |||
| return nullptr; | |||
| } | |||
| AnfNodePtr pre_node = act_node->input(1); | |||
| @@ -73,7 +71,7 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c); | |||
| MS_ASSERT(primc != nullptr); | |||
| if (primc->GetActivationType() == schema::ActivationType_NO_ACTIVATION) { | |||
| primc->SetActivationType(activation_type); | |||
| primc->SetActivationType(act_primitivec->GetType()); | |||
| return pre_node; | |||
| } | |||
| } else if (node_type == schema::PrimitiveType_DepthwiseConv2D) { | |||
| @@ -81,7 +79,7 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c); | |||
| MS_ASSERT(primc != nullptr); | |||
| if (primc->GetActivationType() == schema::ActivationType_NO_ACTIVATION) { | |||
| primc->SetActivationType(activation_type); | |||
| primc->SetActivationType(act_primitivec->GetType()); | |||
| return pre_node; | |||
| } | |||
| } else { | |||
| @@ -25,15 +25,11 @@ namespace mindspore { | |||
| namespace opt { | |||
| class ConvActivationFusion : public PatternProcessPass { | |||
| public: | |||
| ConvActivationFusion(bool multigraph = true, const std::string &name = "conv_activation_fusion", | |||
| schema::PrimitiveType primitive = schema::PrimitiveType_LeakyReLU, | |||
| schema::ActivationType activation = schema::ActivationType_LEAKY_RELU) | |||
| : PatternProcessPass(name, multigraph), primitive_type(primitive), activation_type(activation) {} | |||
| explicit ConvActivationFusion(bool multigraph = true, const std::string &name = "conv_activation_fusion") | |||
| : PatternProcessPass(name, multigraph) {} | |||
| ~ConvActivationFusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| schema::PrimitiveType primitive_type; | |||
| schema::ActivationType activation_type; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -113,8 +113,8 @@ const BaseRef ConvBatchNormFusion::DefinePattern() const { | |||
| // bias --1 | |||
| // estimated_mean --2 | |||
| // estimated_variance --3 | |||
| const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kernel_num, float *trans_scale, | |||
| float *trans_bias) const { | |||
| void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kernel_num, float *trans_scale, | |||
| float *trans_bias) const { | |||
| MS_ASSERT(bn_node != nullptr); | |||
| MS_ASSERT(trans_bias != nullptr); | |||
| MS_ASSERT(trans_scale != nullptr); | |||
| @@ -25,7 +25,7 @@ class ConvBatchNormFusion : public ConvTransformFusion { | |||
| explicit ConvBatchNormFusion(bool multigraph = true) : ConvTransformFusion(multigraph, "conv_batchnorm_fusion") {} | |||
| ~ConvBatchNormFusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const void InitTransParam(const CNodePtr &, int, float *, float *) const override; | |||
| void InitTransParam(const CNodePtr &, int, float *, float *) const override; | |||
| }; | |||
| } // namespace mindspore::opt | |||
| #endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_ | |||
| @@ -15,11 +15,11 @@ | |||
| */ | |||
| #include "tools/optimizer/fusion/conv_conv_fusion.h" | |||
| #include <memory> | |||
| #include <functional> | |||
| #include "src/ops/primitive_c.h" | |||
| #include "src/ops/conv2d.h" | |||
| #include <memory> | |||
| #include "schema/inner/model_generated.h" | |||
| #include "src/ops/conv2d.h" | |||
| #include "src/ops/primitive_c.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| namespace mindspore::opt { | |||
| @@ -128,6 +128,7 @@ STATUS GenNewConvWeight(const ParameterPtr &down_weight_node, const ParameterPtr | |||
| for (int k = 0; k < cout0; k++) { | |||
| auto up_weight_offset = k * window_size * cin0 + j; | |||
| auto down_weight_offset = down_weight_base + k; | |||
| auto new_weight_offset = new_weight_base + j; | |||
| for (int m = 0; m < window_size; m++) { | |||
| new_weight_data[new_weight_offset + cin0 * m] += | |||
| @@ -44,8 +44,8 @@ const BaseRef ConvScaleFusion::DefinePattern() const { | |||
| auto bias_var = std::make_shared<SeqVar>(); | |||
| return VectorRef({bn_var, conv_var, weight_var, bias_var}); | |||
| } | |||
| const void ConvScaleFusion::InitTransParam(const CNodePtr &scale_node, int kernel_num, float *trans_scale, | |||
| float *trans_bias) const { | |||
| void ConvScaleFusion::InitTransParam(const CNodePtr &scale_node, int kernel_num, float *trans_scale, | |||
| float *trans_bias) const { | |||
| MS_ASSERT(scale_node != nullptr); | |||
| MS_ASSERT(trans_bias != nullptr); | |||
| MS_ASSERT(trans_scale != nullptr); | |||
| @@ -25,7 +25,7 @@ class ConvScaleFusion : public ConvTransformFusion { | |||
| explicit ConvScaleFusion(bool multigraph = true) : ConvTransformFusion(multigraph, "conv_scale_fusion") {} | |||
| ~ConvScaleFusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const void InitTransParam(const CNodePtr &, int, float *, float *) const override; | |||
| void InitTransParam(const CNodePtr &, int, float *, float *) const override; | |||
| }; | |||
| } // namespace mindspore::opt | |||
| #endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_ | |||
| @@ -119,8 +119,8 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co | |||
| return pre_node; | |||
| } | |||
| const void ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, int kernel_nums, float *trans_scale, | |||
| float *trans_bias) const { | |||
| void ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, int kernel_nums, float *trans_scale, | |||
| float *trans_bias) const { | |||
| if (trans_scale == nullptr) { | |||
| MS_LOG(ERROR) << "new transScale failed"; | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| @@ -145,9 +145,8 @@ const void ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, in | |||
| InitTransParam(transform_node, kernel_nums, trans_scale, trans_bias); | |||
| } | |||
| const void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, | |||
| int kernel_num, const float *trans_scale, | |||
| const float *trans_bias) const { | |||
| void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, int kernel_num, | |||
| const float *trans_scale, const float *trans_bias) const { | |||
| MS_ASSERT(conv_node != nullptr); | |||
| AnfNodePtr conv_weight_node = nullptr; | |||
| AnfNodePtr conv_bias_node = nullptr; | |||
| @@ -203,8 +202,8 @@ const void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, | |||
| conv_node->add_input(bias_node); | |||
| } | |||
| } | |||
| const void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kernel_num, int kernel_size, | |||
| const float *trans_scale) const { | |||
| void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kernel_num, int kernel_size, | |||
| const float *trans_scale) const { | |||
| MS_ASSERT(weight_data != nullptr); | |||
| MS_ASSERT(trans_scale != nullptr); | |||
| auto tmp_weight_data = new (std::nothrow) float[kernel_num * kernel_size]; | |||
| @@ -237,8 +236,8 @@ const void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kerne | |||
| delete[] tmp_weight_data; | |||
| } | |||
| const void ConvTransformFusion::CalNewBiasTensor(float *bias_data, int kernel_num, bool bias_flag, | |||
| const float *trans_scale, const float *trans_bias) { | |||
| void ConvTransformFusion::CalNewBiasTensor(float *bias_data, int kernel_num, bool bias_flag, const float *trans_scale, | |||
| const float *trans_bias) const { | |||
| MS_ASSERT(bias_data != nullptr); | |||
| MS_ASSERT(trans_bias != nullptr); | |||
| MS_ASSERT(trans_scale != nullptr); | |||
| @@ -27,11 +27,11 @@ class ConvTransformFusion : public PatternProcessPass { | |||
| : PatternProcessPass(name, multigraph) {} | |||
| ~ConvTransformFusion() override = default; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| const void GenTransParam(const CNodePtr &, int, float *, float *) const; | |||
| virtual const void InitTransParam(const CNodePtr &, int, float *, float *) const = 0; | |||
| const void GenNewConvTensor(const FuncGraphPtr &, const CNodePtr &, int, const float *, const float *) const; | |||
| const void CalNewWeightTensor(float *, int, int, const float *) const; | |||
| static const void CalNewBiasTensor(float *, int, bool, const float *, const float *); | |||
| void GenTransParam(const CNodePtr &, int, float *, float *) const; | |||
| virtual void InitTransParam(const CNodePtr &, int, float *, float *) const = 0; | |||
| void GenNewConvTensor(const FuncGraphPtr &, const CNodePtr &, int, const float *, const float *) const; | |||
| void CalNewWeightTensor(float *, int, int, const float *) const; | |||
| void CalNewBiasTensor(float *, int, bool, const float *, const float *) const; | |||
| }; | |||
| } // namespace mindspore::opt | |||
| #endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_ | |||
| @@ -26,26 +26,27 @@ | |||
| namespace mindspore::opt { | |||
| namespace { | |||
| constexpr size_t kActivationInputsLength = 2; | |||
| bool IsTupleGetItemNode(const BaseRef &n) { | |||
| if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) { | |||
| auto type = opt::GetCNodeType(n); | |||
| return type == schema::PrimitiveType_TupleGetItem; | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace | |||
| const BaseRef ConvTupleActivationFusion::DefinePattern() const { | |||
| auto conv_var = std::make_shared<CondVar>(IsConvNode); | |||
| auto tuple_getitem_var = std::make_shared<CondVar>(IsTupleGetItemNode); | |||
| auto tuple_index = std::make_shared<Var>(); | |||
| auto tuple_prim = new schema::PrimitiveT(); | |||
| tuple_prim->value.type = schema::PrimitiveType_TupleGetItem; | |||
| auto tuple_value = std::make_shared<lite::PrimitiveC>(tuple_prim); | |||
| VectorRef tuple_get_item = VectorRef({tuple_value, conv_var, tuple_index}); | |||
| auto act_prim = new schema::PrimitiveT(); | |||
| act_prim->value.type = primitive_type; | |||
| auto act_value = std::make_shared<lite::PrimitiveC>(act_prim); | |||
| return VectorRef({act_value, tuple_get_item}); | |||
| VectorRef tuple_get_item = VectorRef({tuple_getitem_var, conv_var, tuple_index}); | |||
| auto act_var = std::make_shared<CondVar>(IsActivationNode); | |||
| return VectorRef({act_var, tuple_get_item}); | |||
| } | |||
| const AnfNodePtr ConvTupleActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| MS_LOG(DEBUG) << "conv tuple activation pass process:" << schema::EnumNamesPrimitiveType()[primitive_type]; | |||
| if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) { | |||
| return nullptr; | |||
| } | |||
| @@ -59,7 +60,8 @@ const AnfNodePtr ConvTupleActivationFusion::Process(const FuncGraphPtr &func_gra | |||
| MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Activation>>(primitivec)); | |||
| auto act_primitivec = utils::cast<std::shared_ptr<mindspore::lite::Activation>>(primitivec); | |||
| MS_ASSERT(act_primitivec != nullptr); | |||
| if (act_primitivec->GetType() != activation_type) { | |||
| if (act_primitivec->GetType() != schema::ActivationType_RELU && | |||
| act_primitivec->GetType() != schema::ActivationType_RELU6) { | |||
| return nullptr; | |||
| } | |||
| AnfNodePtr tuple_node = act_node->input(1); | |||
| @@ -82,7 +84,7 @@ const AnfNodePtr ConvTupleActivationFusion::Process(const FuncGraphPtr &func_gra | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c); | |||
| MS_ASSERT(primc != nullptr); | |||
| if (primc->GetActivationType() == schema::ActivationType_NO_ACTIVATION) { | |||
| primc->SetActivationType(activation_type); | |||
| primc->SetActivationType(act_primitivec->GetType()); | |||
| conv_node->set_abstract(act_node->abstract()); | |||
| return conv_node; | |||
| } | |||
| @@ -91,7 +93,7 @@ const AnfNodePtr ConvTupleActivationFusion::Process(const FuncGraphPtr &func_gra | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c); | |||
| MS_ASSERT(primc != nullptr); | |||
| if (primc->GetActivationType() == schema::ActivationType_NO_ACTIVATION) { | |||
| primc->SetActivationType(activation_type); | |||
| primc->SetActivationType(act_primitivec->GetType()); | |||
| conv_node->set_abstract(act_node->abstract()); | |||
| return conv_node; | |||
| } | |||
| @@ -25,15 +25,11 @@ namespace mindspore { | |||
| namespace opt { | |||
| class ConvTupleActivationFusion : public PatternProcessPass { | |||
| public: | |||
| ConvTupleActivationFusion(bool multigraph = true, const std::string &name = "conv_tuple_activation_fusion", | |||
| schema::PrimitiveType primitive = schema::PrimitiveType_LeakyReLU, | |||
| schema::ActivationType activation = schema::ActivationType_LEAKY_RELU) | |||
| : PatternProcessPass(name, multigraph), primitive_type(primitive), activation_type(activation) {} | |||
| explicit ConvTupleActivationFusion(bool multigraph = true, const std::string &name = "conv_tuple_activation_fusion") | |||
| : PatternProcessPass(name, multigraph) {} | |||
| ~ConvTupleActivationFusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| schema::PrimitiveType primitive_type; | |||
| schema::ActivationType activation_type; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -24,13 +24,6 @@ | |||
| namespace mindspore::opt { | |||
| namespace { | |||
| bool IsActivationNode(const BaseRef &n) { | |||
| if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) { | |||
| auto type = opt::GetCNodeType(n); | |||
| return type == schema::PrimitiveType_Activation; | |||
| } | |||
| return false; | |||
| } | |||
| bool IsMulNode(const BaseRef &n) { | |||
| if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) { | |||
| auto type = opt::GetCNodeType(n); | |||