@@ -736,7 +736,6 @@ STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNo
     MS_LOG(ERROR) << "node->primitive->value.AsCrop() is nullptr";
     return RET_NULL_PTR;
   }
-  node->primitive->value.AsCrop()->axis = axis_map[origin_axis];
   // nchw->nhwc,offsets need pad 0;
   if (axis_map[origin_axis] == 0) {
     offsets = {offsets[0], offsets[2], offsets[3], offsets[1]};
@@ -57,34 +57,23 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
     return nullptr;
   }
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
-  auto pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
+  auto fusion_pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
   auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
   auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);
-  // fusion const_fold
-  auto cf_pm = std::make_shared<opt::PassManager>("constant folding pass manager", false);
-  cf_pm->AddPass(std::make_shared<opt::ConstFoldPass>());
-  cf_pm->AddPass(std::make_shared<opt::UpdateConv2DParamPass>());
   // for now - trainning is not supporting fuse operations
   if (!config->trainModel) {
     // remove quantdtype when awaretraining
-    pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>());
-    pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
-    pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>());
-    pm->AddPass(std::make_shared<opt::ConvScaleFusion>());
-    pm->AddPass(std::make_shared<opt::LayerNormFusion>());
-    pm->AddPass(std::make_shared<opt::BatchMatMulFusion>());
-    pm->AddPass(std::make_shared<opt::SigmoidMulFusion>());
-    pm->AddPass(std::make_shared<opt::ConvActivationFusion>(true, "conv_relu", schema::PrimitiveType_Activation,
-                                                            schema::ActivationType_RELU));
-    pm->AddPass(std::make_shared<opt::ConvActivationFusion>(true, "conv_relu6", schema::PrimitiveType_Activation,
-                                                            schema::ActivationType_RELU6));
-    pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>(
-      true, "conv_tuple_relu", schema::PrimitiveType_Activation, schema::ActivationType_RELU));
-    pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>(
-      true, "conv_tuple_relu6", schema::PrimitiveType_Activation, schema::ActivationType_RELU6));
-    pm->AddPass(std::make_shared<opt::ConvTupleGetItemFusion>());
+    fusion_pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>());
+    fusion_pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
+    fusion_pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>());
+    fusion_pm->AddPass(std::make_shared<opt::ConvScaleFusion>());
+    fusion_pm->AddPass(std::make_shared<opt::LayerNormFusion>());
+    fusion_pm->AddPass(std::make_shared<opt::BatchMatMulFusion>());
+    fusion_pm->AddPass(std::make_shared<opt::SigmoidMulFusion>());
+    fusion_pm->AddPass(std::make_shared<opt::ConvActivationFusion>());
+    fusion_pm->AddPass(std::make_shared<opt::ConvTupleGetItemFusion>());
+    fusion_pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>());
   }
   auto weight_format_hardcode_pass = std::make_shared<opt::WeightFormatHardCodePass>();
   weight_format_hardcode_pass->SetFmkType(config->fmk);
@@ -108,7 +97,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
       return nullptr;
     }
     remove_unused_cast_pass->SetFmkType(config->fmk);
-    pm->AddPass(remove_unused_cast_pass);
+    fusion_pm->AddPass(remove_unused_cast_pass);
   }
   if (config->fmk == lite::converter::FmkType_ONNX) {
     auto remove_unused_transpose_pass = std::make_shared<opt::RemoveUnusedTransposeOpPass>();
@@ -117,17 +106,22 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
       return nullptr;
     }
     remove_unused_transpose_pass->SetFmkType(config->fmk);
-    pm->AddPass(remove_unused_transpose_pass);
+    fusion_pm->AddPass(remove_unused_transpose_pass);
   }
-  pm->AddPass(std::make_shared<opt::ConvConvFusion>());
+  auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false);
+  auto inne_context_ptr = std::make_shared<lite::InnerContext>();
+  inne_context_ptr->Init();
+  const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(inne_context_ptr));
+  const_fold_pm->AddPass(std::make_shared<opt::UpdateConv2DParamPass>());
+  fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>());
   convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>());
   if (config->fmk == lite::converter::FmkType_TFLITE) {
     convert_pm->AddPass(std::make_shared<opt::GroupDepthwiseOpConvertPass>());
     convert_pm->AddPass(std::make_shared<opt::TfliteInputsOrderExchangePass>());
   }
-  optimizer->AddPassManager(cf_pm);
+  optimizer->AddPassManager(const_fold_pm);
   optimizer->AddPassManager(convert_pm);
-  optimizer->AddPassManager(pm);
+  optimizer->AddPassManager(fusion_pm);
   optimizer->AddPassManager(graph_pm);
   auto new_graph = optimizer->Optimize(old_graph);
   if (new_graph == nullptr) {
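The hunks above rearrange AnfTransform::Transform so that constant folding, conversion, fusion, and graph passes live in separately named PassManagers. For orientation, the resulting registration order is sketched below; this is a condensed reconstruction from the diff (it assumes the surrounding function's includes and the convert_pm/fusion_pm/graph_pm managers populated as shown above), not a verbatim excerpt.

  auto inne_context_ptr = std::make_shared<lite::InnerContext>();
  inne_context_ptr->Init();  // one context handed to ConstFoldPass instead of a per-pass raw pointer
  auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false);
  const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(inne_context_ptr));
  const_fold_pm->AddPass(std::make_shared<opt::UpdateConv2DParamPass>());
  optimizer->AddPassManager(const_fold_pm);  // 1) constant folding
  optimizer->AddPassManager(convert_pm);     // 2) conversion passes
  optimizer->AddPassManager(fusion_pm);      // 3) fusions
  optimizer->AddPassManager(graph_pm);       // 4) graph-level passes
  auto new_graph = optimizer->Optimize(old_graph);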
@@ -24,9 +24,9 @@ namespace mindspore {
 namespace lite {
 namespace converter {
 Flags::Flags() {
-  AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TFLITE | CAFFE | MINDIR | ONNX", "");
+  AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TF | TFLITE | CAFFE | MINDIR | ONNX", "");
   AddFlag(&Flags::modelFile, "modelFile",
-          "Input model file. TFLITE: *.tflite | CAFFE: *.prototxt | MINDIR: *.mindir | ONNX: *.onnx", "");
+          "Input model file. TF: *.pb | TFLITE: *.tflite | CAFFE: *.prototxt | MINDIR: *.mindir | ONNX: *.onnx", "");
   AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", "");
   AddFlag(&Flags::weightFile, "weightFile", "Input model weight file. Needed when fmk is CAFFE. CAFFE: *.caffemodel",
           "");
@@ -494,6 +494,14 @@ bool IsPoolingNode(const BaseRef &n) {
   return false;
 }
+bool IsActivationNode(const BaseRef &n) {
+  if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
+    auto type = opt::GetCNodeType(n);
+    return type == schema::PrimitiveType_Activation;
+  }
+  return false;
+}
 bool IsQuantNode(const BaseRef &n) {
   if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
     auto type = opt::GetCNodeType(n);
@@ -65,6 +65,8 @@ bool IsPoolingNode(const BaseRef &n);
 bool IsQuantNode(const BaseRef &n);
+bool IsActivationNode(const BaseRef &n);
 bool CheckIsAllInputsParam(const AnfNodePtr &node);
 size_t GetOutputTensorNum(const AnfNodePtr &node);
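The new predicate is intended to be used as a CondVar condition when defining fusion patterns, replacing the hand-built schema::PrimitiveT pattern nodes removed in the fusion hunks further down. A minimal sketch of that usage, mirroring the ConvActivationFusion::DefinePattern hunk below (IsConvNode already exists alongside the other predicates in this header):

  auto conv_var = std::make_shared<CondVar>(IsConvNode);
  auto act_var = std::make_shared<CondVar>(IsActivationNode);
  return VectorRef({act_var, conv_var});  // matches any Activation node fed by a convolution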
@@ -252,7 +252,7 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
                   << schema::EnumNamePrimitiveType((schema::PrimitiveType)(lite_primitive->Type()));
     return nullptr;
   }
-  auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, context, lite_primitive.get());
+  auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, context.get(), lite_primitive.get());
   if (lite_kernel == nullptr) {
     MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
     FreeTensors(&input_tensors, &output_tensors);
@@ -17,6 +17,8 @@
 #ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_
 #define MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_
+#include <utility>
+#include <memory>
 #include "schema/inner/model_generated.h"
 #include "src/tensor.h"
 #include "src/lite_kernel.h"
@@ -27,15 +29,13 @@ namespace mindspore {
 namespace opt {
 class ConstFoldPass : public PatternProcessPass {
  public:
-  explicit ConstFoldPass(bool multigraph = true) : PatternProcessPass("constfold_pass", multigraph) {
-    this->context = new lite::InnerContext;
-    this->context->Init();
-  }
-  ~ConstFoldPass() override { delete (this->context); }
+  explicit ConstFoldPass(std::shared_ptr<lite::InnerContext> context_ptr = nullptr, bool multigraph = true)
+      : PatternProcessPass("constfold_pass", multigraph), context(std::move(context_ptr)) {}
+  ~ConstFoldPass() override = default;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
  private:
-  lite::InnerContext *context = nullptr;
+  std::shared_ptr<lite::InnerContext> context;
 };
 }  // namespace opt
 }  // namespace mindspore
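With the new constructor, ConstFoldPass no longer allocates and deletes its own InnerContext; the caller creates one and can share it across passes, which is what the AnfTransform hunk above does. A minimal usage sketch under those assumptions (variable names here are illustrative, not from the repository):

  auto ctx = std::make_shared<lite::InnerContext>();
  ctx->Init();
  auto const_fold = std::make_shared<opt::ConstFoldPass>(ctx);  // shared ownership; no manual delete in the destructor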
@@ -29,17 +29,14 @@ constexpr size_t kActivationInputsLength = 2;
 }
 const BaseRef ConvActivationFusion::DefinePattern() const {
   auto conv_var = std::make_shared<CondVar>(IsConvNode);
-  auto prim = new schema::PrimitiveT();
-  prim->value.type = primitive_type;
-  auto prim_value = std::make_shared<lite::PrimitiveC>(prim);
-  return VectorRef({prim_value, conv_var});
+  auto act_var = std::make_shared<CondVar>(IsActivationNode);
+  return VectorRef({act_var, conv_var});
 }
 const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                const EquivPtr &) const {
   MS_ASSERT(func_graph != nullptr);
   MS_ASSERT(node != nullptr);
-  MS_LOG(DEBUG) << "conv activation pass process:" << schema::EnumNamesPrimitiveType()[primitive_type];
   if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
     return nullptr;
@@ -53,7 +50,8 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c
   MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Activation>>(primitivec));
   auto act_primitivec = utils::cast<std::shared_ptr<mindspore::lite::Activation>>(primitivec);
   MS_ASSERT(act_primitivec != nullptr);
-  if (act_primitivec->GetType() != activation_type) {
+  if (act_primitivec->GetType() != schema::ActivationType_RELU &&
+      act_primitivec->GetType() != schema::ActivationType_RELU6) {
     return nullptr;
   }
   AnfNodePtr pre_node = act_node->input(1);
@@ -73,7 +71,7 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c
     auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c);
     MS_ASSERT(primc != nullptr);
     if (primc->GetActivationType() == schema::ActivationType_NO_ACTIVATION) {
-      primc->SetActivationType(activation_type);
+      primc->SetActivationType(act_primitivec->GetType());
       return pre_node;
     }
   } else if (node_type == schema::PrimitiveType_DepthwiseConv2D) {
@@ -81,7 +79,7 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c
     auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c);
     MS_ASSERT(primc != nullptr);
     if (primc->GetActivationType() == schema::ActivationType_NO_ACTIVATION) {
-      primc->SetActivationType(activation_type);
+      primc->SetActivationType(act_primitivec->GetType());
       return pre_node;
     }
   } else {
@@ -25,15 +25,11 @@ namespace mindspore {
 namespace opt {
 class ConvActivationFusion : public PatternProcessPass {
  public:
-  ConvActivationFusion(bool multigraph = true, const std::string &name = "conv_activation_fusion",
-                       schema::PrimitiveType primitive = schema::PrimitiveType_LeakyReLU,
-                       schema::ActivationType activation = schema::ActivationType_LEAKY_RELU)
-      : PatternProcessPass(name, multigraph), primitive_type(primitive), activation_type(activation) {}
+  explicit ConvActivationFusion(bool multigraph = true, const std::string &name = "conv_activation_fusion")
+      : PatternProcessPass(name, multigraph) {}
   ~ConvActivationFusion() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-  schema::PrimitiveType primitive_type;
-  schema::ActivationType activation_type;
 };
 }  // namespace opt
 }  // namespace mindspore
@@ -113,8 +113,8 @@ const BaseRef ConvBatchNormFusion::DefinePattern() const {
 //   bias --1
 //   estimated_mean --2
 //   estimated_variance --3
-const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kernel_num, float *trans_scale,
-                                               float *trans_bias) const {
+void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kernel_num, float *trans_scale,
+                                         float *trans_bias) const {
   MS_ASSERT(bn_node != nullptr);
   MS_ASSERT(trans_bias != nullptr);
   MS_ASSERT(trans_scale != nullptr);
@@ -25,7 +25,7 @@ class ConvBatchNormFusion : public ConvTransformFusion {
   explicit ConvBatchNormFusion(bool multigraph = true) : ConvTransformFusion(multigraph, "conv_batchnorm_fusion") {}
   ~ConvBatchNormFusion() override = default;
   const BaseRef DefinePattern() const override;
-  const void InitTransParam(const CNodePtr &, int, float *, float *) const override;
+  void InitTransParam(const CNodePtr &, int, float *, float *) const override;
 };
 }  // namespace mindspore::opt
 #endif  // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_
@@ -15,11 +15,11 @@
  */
 #include "tools/optimizer/fusion/conv_conv_fusion.h"
-#include <memory>
 #include <functional>
-#include "src/ops/primitive_c.h"
-#include "src/ops/conv2d.h"
+#include <memory>
 #include "schema/inner/model_generated.h"
+#include "src/ops/conv2d.h"
+#include "src/ops/primitive_c.h"
 #include "tools/optimizer/common/gllo_utils.h"
 namespace mindspore::opt {
@@ -128,6 +128,7 @@ STATUS GenNewConvWeight(const ParameterPtr &down_weight_node, const ParameterPtr
     for (int k = 0; k < cout0; k++) {
       auto up_weight_offset = k * window_size * cin0 + j;
       auto down_weight_offset = down_weight_base + k;
       auto new_weight_offset = new_weight_base + j;
       for (int m = 0; m < window_size; m++) {
         new_weight_data[new_weight_offset + cin0 * m] +=
@@ -44,8 +44,8 @@ const BaseRef ConvScaleFusion::DefinePattern() const {
   auto bias_var = std::make_shared<SeqVar>();
   return VectorRef({bn_var, conv_var, weight_var, bias_var});
 }
-const void ConvScaleFusion::InitTransParam(const CNodePtr &scale_node, int kernel_num, float *trans_scale,
-                                           float *trans_bias) const {
+void ConvScaleFusion::InitTransParam(const CNodePtr &scale_node, int kernel_num, float *trans_scale,
+                                     float *trans_bias) const {
   MS_ASSERT(scale_node != nullptr);
   MS_ASSERT(trans_bias != nullptr);
   MS_ASSERT(trans_scale != nullptr);
@@ -25,7 +25,7 @@ class ConvScaleFusion : public ConvTransformFusion {
   explicit ConvScaleFusion(bool multigraph = true) : ConvTransformFusion(multigraph, "conv_scale_fusion") {}
   ~ConvScaleFusion() override = default;
   const BaseRef DefinePattern() const override;
-  const void InitTransParam(const CNodePtr &, int, float *, float *) const override;
+  void InitTransParam(const CNodePtr &, int, float *, float *) const override;
 };
 }  // namespace mindspore::opt
 #endif  // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_
@@ -119,8 +119,8 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co
   return pre_node;
 }
-const void ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, int kernel_nums, float *trans_scale,
-                                              float *trans_bias) const {
+void ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, int kernel_nums, float *trans_scale,
+                                        float *trans_bias) const {
   if (trans_scale == nullptr) {
     MS_LOG(ERROR) << "new transScale failed";
     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
@@ -145,9 +145,8 @@ const void ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, in
   InitTransParam(transform_node, kernel_nums, trans_scale, trans_bias);
 }
-const void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const CNodePtr &conv_node,
-                                                 int kernel_num, const float *trans_scale,
-                                                 const float *trans_bias) const {
+void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, int kernel_num,
+                                           const float *trans_scale, const float *trans_bias) const {
   MS_ASSERT(conv_node != nullptr);
   AnfNodePtr conv_weight_node = nullptr;
   AnfNodePtr conv_bias_node = nullptr;
@@ -203,8 +202,8 @@ const void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph,
     conv_node->add_input(bias_node);
   }
 }
-const void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kernel_num, int kernel_size,
-                                                   const float *trans_scale) const {
+void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kernel_num, int kernel_size,
+                                             const float *trans_scale) const {
   MS_ASSERT(weight_data != nullptr);
   MS_ASSERT(trans_scale != nullptr);
   auto tmp_weight_data = new (std::nothrow) float[kernel_num * kernel_size];
@@ -237,8 +236,8 @@ const void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kerne
   delete[] tmp_weight_data;
 }
-const void ConvTransformFusion::CalNewBiasTensor(float *bias_data, int kernel_num, bool bias_flag,
-                                                 const float *trans_scale, const float *trans_bias) {
+void ConvTransformFusion::CalNewBiasTensor(float *bias_data, int kernel_num, bool bias_flag, const float *trans_scale,
+                                           const float *trans_bias) const {
   MS_ASSERT(bias_data != nullptr);
   MS_ASSERT(trans_bias != nullptr);
   MS_ASSERT(trans_scale != nullptr);
@@ -27,11 +27,11 @@ class ConvTransformFusion : public PatternProcessPass {
       : PatternProcessPass(name, multigraph) {}
   ~ConvTransformFusion() override = default;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-  const void GenTransParam(const CNodePtr &, int, float *, float *) const;
-  virtual const void InitTransParam(const CNodePtr &, int, float *, float *) const = 0;
-  const void GenNewConvTensor(const FuncGraphPtr &, const CNodePtr &, int, const float *, const float *) const;
-  const void CalNewWeightTensor(float *, int, int, const float *) const;
-  static const void CalNewBiasTensor(float *, int, bool, const float *, const float *);
+  void GenTransParam(const CNodePtr &, int, float *, float *) const;
+  virtual void InitTransParam(const CNodePtr &, int, float *, float *) const = 0;
+  void GenNewConvTensor(const FuncGraphPtr &, const CNodePtr &, int, const float *, const float *) const;
+  void CalNewWeightTensor(float *, int, int, const float *) const;
+  void CalNewBiasTensor(float *, int, bool, const float *, const float *) const;
 };
 }  // namespace mindspore::opt
 #endif  // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_
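A note on the signature changes running through these ConvTransformFusion hunks: a top-level const on a void return type has no effect and compilers can warn about it (for example GCC's -Wignored-qualifiers reports ignored qualifiers on function return types), so dropping it is pure cleanup; CalNewBiasTensor additionally changes from a static member to a const member function, where the trailing const qualifies the function itself. An illustrative snippet, not taken from the repository:

  struct Demo {
    const void Old();      // 'const' on a void return type adds nothing; may draw a warning
    void Cleaned();        // the cleaned-up form these headers adopt
    void Member() const;   // a trailing 'const' instead qualifies the member function (cf. CalNewBiasTensor)
  };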
@@ -26,26 +26,27 @@
 namespace mindspore::opt {
 namespace {
 constexpr size_t kActivationInputsLength = 2;
+bool IsTupleGetItemNode(const BaseRef &n) {
+  if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
+    auto type = opt::GetCNodeType(n);
+    return type == schema::PrimitiveType_TupleGetItem;
+  }
+  return false;
 }
+}  // namespace
 const BaseRef ConvTupleActivationFusion::DefinePattern() const {
   auto conv_var = std::make_shared<CondVar>(IsConvNode);
+  auto tuple_getitem_var = std::make_shared<CondVar>(IsTupleGetItemNode);
   auto tuple_index = std::make_shared<Var>();
-  auto tuple_prim = new schema::PrimitiveT();
-  tuple_prim->value.type = schema::PrimitiveType_TupleGetItem;
-  auto tuple_value = std::make_shared<lite::PrimitiveC>(tuple_prim);
-  VectorRef tuple_get_item = VectorRef({tuple_value, conv_var, tuple_index});
-  auto act_prim = new schema::PrimitiveT();
-  act_prim->value.type = primitive_type;
-  auto act_value = std::make_shared<lite::PrimitiveC>(act_prim);
-  return VectorRef({act_value, tuple_get_item});
+  VectorRef tuple_get_item = VectorRef({tuple_getitem_var, conv_var, tuple_index});
+  auto act_var = std::make_shared<CondVar>(IsActivationNode);
+  return VectorRef({act_var, tuple_get_item});
 }
 const AnfNodePtr ConvTupleActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                     const EquivPtr &) const {
   MS_ASSERT(func_graph != nullptr);
   MS_ASSERT(node != nullptr);
-  MS_LOG(DEBUG) << "conv tuple activation pass process:" << schema::EnumNamesPrimitiveType()[primitive_type];
   if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
     return nullptr;
   }
@@ -59,7 +60,8 @@ const AnfNodePtr ConvTupleActivationFusion::Process(const FuncGraphPtr &func_gra
   MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Activation>>(primitivec));
   auto act_primitivec = utils::cast<std::shared_ptr<mindspore::lite::Activation>>(primitivec);
   MS_ASSERT(act_primitivec != nullptr);
-  if (act_primitivec->GetType() != activation_type) {
+  if (act_primitivec->GetType() != schema::ActivationType_RELU &&
+      act_primitivec->GetType() != schema::ActivationType_RELU6) {
     return nullptr;
   }
   AnfNodePtr tuple_node = act_node->input(1);
@@ -82,7 +84,7 @@ const AnfNodePtr ConvTupleActivationFusion::Process(const FuncGraphPtr &func_gra
     auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c);
     MS_ASSERT(primc != nullptr);
     if (primc->GetActivationType() == schema::ActivationType_NO_ACTIVATION) {
-      primc->SetActivationType(activation_type);
+      primc->SetActivationType(act_primitivec->GetType());
       conv_node->set_abstract(act_node->abstract());
       return conv_node;
     }
@@ -91,7 +93,7 @@ const AnfNodePtr ConvTupleActivationFusion::Process(const FuncGraphPtr &func_gra
     auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c);
     MS_ASSERT(primc != nullptr);
     if (primc->GetActivationType() == schema::ActivationType_NO_ACTIVATION) {
-      primc->SetActivationType(activation_type);
+      primc->SetActivationType(act_primitivec->GetType());
       conv_node->set_abstract(act_node->abstract());
       return conv_node;
     }
@@ -25,15 +25,11 @@ namespace mindspore {
 namespace opt {
 class ConvTupleActivationFusion : public PatternProcessPass {
  public:
-  ConvTupleActivationFusion(bool multigraph = true, const std::string &name = "conv_tuple_activation_fusion",
-                            schema::PrimitiveType primitive = schema::PrimitiveType_LeakyReLU,
-                            schema::ActivationType activation = schema::ActivationType_LEAKY_RELU)
-      : PatternProcessPass(name, multigraph), primitive_type(primitive), activation_type(activation) {}
+  explicit ConvTupleActivationFusion(bool multigraph = true, const std::string &name = "conv_tuple_activation_fusion")
+      : PatternProcessPass(name, multigraph) {}
   ~ConvTupleActivationFusion() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-  schema::PrimitiveType primitive_type;
-  schema::ActivationType activation_type;
 };
 }  // namespace opt
 }  // namespace mindspore
@@ -24,13 +24,6 @@
 namespace mindspore::opt {
 namespace {
-bool IsActivationNode(const BaseRef &n) {
-  if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
-    auto type = opt::GetCNodeType(n);
-    return type == schema::PrimitiveType_Activation;
-  }
-  return false;
-}
 bool IsMulNode(const BaseRef &n) {
   if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
     auto type = opt::GetCNodeType(n);