Merge pull request !22431 from 徐安越/primitivetags/v1.5.0-rc1
| @@ -609,6 +609,7 @@ int LiteSession::PrepareKernels(Model *model, bool use_mindrt_run) { | |||
| // init init_ref_count for subgraphs and kernels | |||
| for (auto *kernel : this->kernels_) { | |||
| kernel->InitOutTensorInitRefCount(); | |||
| #ifndef DELEGATE_CLIP | |||
| if (kernel->desc().arch == kernel::kDelegate) { | |||
| continue; | |||
| @@ -617,7 +618,6 @@ int LiteSession::PrepareKernels(Model *model, bool use_mindrt_run) { | |||
| if (IsIsolatedSubGraph(kernel)) { | |||
| static_cast<kernel::SubGraphKernel *>(kernel)->InitInputTensorInitRefCount(); | |||
| } | |||
| kernel->InitOutTensorInitRefCount(); | |||
| } | |||
| AdjustModelOutputTensorInitRefCount(model); | |||
| for (auto kernel : this->kernels_) { | |||
| @@ -41,20 +41,30 @@ OpParameter *PopulateConvParameter(const void *prim) { | |||
| auto stride = value->stride(); | |||
| auto pad_list = value->pad_list(); | |||
| auto dilation = value->dilation(); | |||
| if (kernel_size == nullptr || stride == nullptr || dilation == nullptr) { | |||
| if (kernel_size != nullptr) { | |||
| if (kernel_size->size() < kMinShapeSizeTwo) { | |||
| MS_LOG(ERROR) << "kernel size is invalid."; | |||
| free(param); | |||
| return nullptr; | |||
| } | |||
| param->kernel_h_ = static_cast<int>(*(kernel_size->begin())); | |||
| param->kernel_w_ = static_cast<int>(*(kernel_size->begin() + 1)); | |||
| } else { | |||
| param->kernel_h_ = -1; | |||
| param->kernel_w_ = -1; | |||
| } | |||
| if (stride == nullptr || dilation == nullptr) { | |||
| MS_LOG(ERROR) << "kernel_size/stride/dilation is nullptr"; | |||
| free(param); | |||
| return nullptr; | |||
| } | |||
| if (kernel_size->size() < kMinShapeSizeTwo || stride->size() < kMinShapeSizeTwo || | |||
| dilation->size() < kMinShapeSizeTwo) { | |||
| if (stride->size() < kMinShapeSizeTwo || dilation->size() < kMinShapeSizeTwo) { | |||
| MS_LOG(ERROR) << "Invalid shape size!kernel_size size: " << kernel_size->size() | |||
| << ", stride size: " << stride->size() << ", dilation size: " << dilation->size(); | |||
| free(param); | |||
| return nullptr; | |||
| } | |||
| param->kernel_h_ = static_cast<int>(*(kernel_size->begin())); | |||
| param->kernel_w_ = static_cast<int>(*(kernel_size->begin() + 1)); | |||
| param->group_ = static_cast<int>(value->group()); | |||
| param->stride_h_ = static_cast<int>(*(stride->begin())); | |||
| param->stride_w_ = static_cast<int>(*(stride->begin() + 1)); | |||
| @@ -43,13 +43,24 @@ OpParameter *PopulateDeconvParameter(const void *prim) { | |||
| auto pad_list = value->pad_list(); | |||
| auto dilation = value->dilation(); | |||
| auto output_paddings = value->output_paddings(); | |||
| if (kernel_size == nullptr || stride == nullptr || dilation == nullptr || output_paddings == nullptr) { | |||
| if (kernel_size != nullptr) { | |||
| if (kernel_size->size() < kMinShapeSizeTwo) { | |||
| MS_LOG(ERROR) << "kernel size is invalid."; | |||
| free(param); | |||
| return nullptr; | |||
| } | |||
| param->kernel_h_ = static_cast<int>(*(kernel_size->begin())); | |||
| param->kernel_w_ = static_cast<int>(*(kernel_size->begin() + 1)); | |||
| } else { | |||
| param->kernel_h_ = -1; | |||
| param->kernel_w_ = -1; | |||
| } | |||
| if (stride == nullptr || dilation == nullptr || output_paddings == nullptr) { | |||
| MS_LOG(ERROR) << "nullptr"; | |||
| free(param); | |||
| return nullptr; | |||
| } | |||
| if (kernel_size->size() < kMinShapeSizeTwo || stride->size() < kMinShapeSizeTwo || | |||
| dilation->size() < kMinShapeSizeTwo) { | |||
| if (stride->size() < kMinShapeSizeTwo || dilation->size() < kMinShapeSizeTwo) { | |||
| MS_LOG(ERROR) << "Invalid shape size!kernel_size size: " << kernel_size->size() | |||
| << ", stride size: " << stride->size() << ", dilation size: " << dilation->size() | |||
| << ", output_paddings size:" << output_paddings->size(); | |||
| @@ -217,8 +217,6 @@ int AnfTransform::RunGraphPass(const FuncGraphPtr &old_graph, const converter::F | |||
| int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const converter::Flags *config) { | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true); | |||
| auto infershape_pass = std::make_shared<opt::InferShapePass>(config->fmk, config->trainModel); | |||
| convert_pm->AddPass(infershape_pass); | |||
| convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>()); | |||
| optimizer->AddPassManager(convert_pm); | |||
| if (optimizer->Optimize(old_graph) == nullptr) { | |||
| @@ -235,8 +233,9 @@ int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const converte | |||
| if (!config->trainModel) { | |||
| const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk)); | |||
| } | |||
| auto infershape_pass = std::make_shared<opt::InferShapePass>(config->fmk, config->trainModel); | |||
| const_fold_pm->AddPass(infershape_pass); | |||
| auto update_conv2d_param_pass = std::make_shared<opt::UpdateConv2DParamPass>(); | |||
| update_conv2d_param_pass->SetFmkType(config->fmk); | |||
| const_fold_pm->AddPass(update_conv2d_param_pass); | |||
| optimizer->AddPassManager(const_fold_pm); | |||
| if (optimizer->Optimize(old_graph) == nullptr) { | |||
| @@ -37,42 +37,38 @@ STATUS InputAdjust::AddAttrToInput(const FuncGraphPtr &func_graph, const CNodePt | |||
| MS_LOG(DEBUG) << "there is no attr :" << attr_name; | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| auto inputs = cnode->inputs(); | |||
| if (static_cast<int>(inputs.size()) > input_num) { | |||
| if (static_cast<int>(cnode->size()) > input_num) { | |||
| primitive_c->EraseAttr(attr_name); | |||
| MS_LOG(DEBUG) << "input num has been meet, which is " << inputs.size(); | |||
| MS_LOG(DEBUG) << "input num has been meet, which is " << cnode->size(); | |||
| return lite::RET_OK; | |||
| } else if (static_cast<int>(inputs.size()) < input_num) { | |||
| } else if (static_cast<int>(cnode->size()) < input_num) { | |||
| MS_LOG(ERROR) << "input num is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| AnfNodePtr param_node; | |||
| switch (flag) { | |||
| case 1: { | |||
| auto value_data = opt::CastToInt(value_ptr).front(); | |||
| auto param_node = | |||
| param_node = | |||
| opt::BuildIntValueParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name); | |||
| inputs.push_back(param_node); | |||
| break; | |||
| } | |||
| case kBuildInputFlagTwo: { | |||
| auto value_data = opt::CastToInt(value_ptr); | |||
| auto param_node = | |||
| param_node = | |||
| opt::BuildIntVecParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name); | |||
| inputs.push_back(param_node); | |||
| break; | |||
| } | |||
| case kBuildInputFlagThree: { | |||
| auto value_data = opt::CastToVec2DInt(value_ptr); | |||
| auto param_node = | |||
| param_node = | |||
| opt::BuildIntVec2DParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name); | |||
| inputs.push_back(param_node); | |||
| break; | |||
| } | |||
| case kBuildInputFlagFour: { | |||
| auto value_data = GetValue<float>(value_ptr); | |||
| auto param_node = | |||
| param_node = | |||
| opt::BuildFloatValueParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name); | |||
| inputs.push_back(param_node); | |||
| break; | |||
| } | |||
| default: { | |||
| @@ -80,8 +76,11 @@ STATUS InputAdjust::AddAttrToInput(const FuncGraphPtr &func_graph, const CNodePt | |||
| return lite::RET_ERROR; | |||
| } | |||
| } | |||
| cnode->set_inputs(inputs); | |||
| auto manager = func_graph->manager(); | |||
| MS_ASSERT(manager != nullptr); | |||
| auto tr = manager->Transact(); | |||
| tr.AddEdge(cnode, param_node); | |||
| tr.Commit(); | |||
| return lite::RET_OK; | |||
| } | |||
| @@ -124,6 +124,7 @@ STATUS GetConvChannel(const onnx::GraphProto &onnx_graph, const onnx::NodeProto | |||
| [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; }); | |||
| if (node_iter == onnx_graph.initializer().end()) { | |||
| MS_LOG(WARNING) << "not find node: " << onnx_conv_weight; | |||
| return RET_NO_CHANGE; | |||
| } else { | |||
| std::vector<int> weight_shape; | |||
| auto size = (*node_iter).dims_size(); | |||
| @@ -151,6 +152,12 @@ STATUS GetConvChannel(const onnx::GraphProto &onnx_graph, const onnx::NodeProto | |||
| return RET_ERROR; | |||
| } | |||
| dims.insert(dims.begin(), iter->ints().begin(), iter->ints().end()); | |||
| } else { | |||
| return RET_NO_CHANGE; | |||
| } | |||
| if (dims.size() < kNumDim4) { | |||
| MS_LOG(ERROR) << "conv weight size is not 4D, please check."; | |||
| return RET_ERROR; | |||
| } | |||
| *channel_out = dims.at(0); | |||
| *channel_in = dims.at(3) * group; | |||
| @@ -211,11 +218,13 @@ ops::PrimitiveC *OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const | |||
| } | |||
| // get channel_out and channel_in | |||
| if (GetConvChannel(onnx_graph, onnx_node, group, &channel_out, &channel_in) != RET_OK) { | |||
| auto status = GetConvChannel(onnx_graph, onnx_node, group, &channel_out, &channel_in); | |||
| if (status == RET_OK) { | |||
| prim->set_in_channel(channel_in); | |||
| prim->set_out_channel(channel_out); | |||
| } else if (status != RET_NO_CHANGE) { | |||
| return nullptr; | |||
| } | |||
| prim->set_in_channel(channel_in); | |||
| prim->set_out_channel(channel_out); | |||
| // parse activationType | |||
| prim->set_activation_type(mindspore::ActivationType::NO_ACTIVATION); | |||
| @@ -77,9 +77,6 @@ ops::PrimitiveC *OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, con | |||
| std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(), | |||
| [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; }); | |||
| if (node_iter == onnx_graph.initializer().end()) { | |||
| // in_channel and out_channnel is set to 1 by default. | |||
| prim->set_in_channel(1); | |||
| prim->set_out_channel(1); | |||
| MS_LOG(WARNING) << "parsing of channelIn/Out is delayed."; | |||
| } else { | |||
| std::vector<int> weight_shape; | |||
| @@ -60,9 +60,6 @@ ops::PrimitiveC *TFConvParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| prim->set_out_channel(kernels[3]); | |||
| prim->set_in_channel(kernels[2]); | |||
| } else { | |||
| prim->set_kernel_size({0, 0}); | |||
| prim->set_out_channel(1); | |||
| prim->set_in_channel(1); | |||
| MS_LOG(WARNING) << "parsing of kernelH/W channelIn/Out is delayed"; | |||
| } | |||
| @@ -84,8 +81,10 @@ ops::PrimitiveC *TFConvParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| } | |||
| if (tf_op.op() == "DepthwiseConv2dNative") { | |||
| prim->AddAttr(ops::kIsDepthWise, MakeValue<bool>(true)); | |||
| prim->set_group(prim->get_in_channel()); | |||
| prim->set_out_channel(prim->get_in_channel()); | |||
| if (prim->GetAttr(ops::kInChannel) != nullptr) { | |||
| prim->set_group(prim->get_in_channel()); | |||
| prim->set_out_channel(prim->get_in_channel()); | |||
| } | |||
| } | |||
| return prim.release(); | |||
| @@ -60,9 +60,6 @@ ops::PrimitiveC *TFDeconvParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| prim->set_out_channel(kernels[2]); | |||
| prim->set_in_channel(kernels[3]); | |||
| } else { | |||
| prim->set_kernel_size({-1, -1}); | |||
| prim->set_out_channel(-1); | |||
| prim->set_in_channel(-1); | |||
| MS_LOG(WARNING) << "parsing of kernelH/W channelIn/Out is delayed"; | |||
| } | |||
| @@ -39,6 +39,7 @@ | |||
| #include "ops/fusion/div_fusion.h" | |||
| #include "ops/fusion/max_pool_fusion.h" | |||
| #include "ops/fusion/mul_fusion.h" | |||
| #include "ops/fusion/pad_fusion.h" | |||
| #include "ops/fusion/pow_fusion.h" | |||
| #include "ops/fusion/prelu_fusion.h" | |||
| #include "ops/fusion/slice_fusion.h" | |||
| @@ -98,18 +99,27 @@ static const std::unordered_map<std::string, std::vector<size_t>> NHWCOpMap = { | |||
| static const std::unordered_map<std::string, std::vector<size_t>> NCHWOpMap = {{ops::kNameInstanceNorm, {1}}}; | |||
| // a certain op whose input's format is not fixed. | |||
| static const std::vector<std::string> DynamicFormatOpList = { | |||
| ops::kNameEltwise, ops::kNameActivation, ops::kNameConcat, ops::kNameDivFusion, ops::kNamePowFusion, | |||
| ops::kNameStridedSlice, ops::kNameAddFusion, ops::kNameAddN, ops::kNameSplit, ops::kNameSliceFusion, | |||
| ops::kNameCrop, ops::kNameMulFusion, ops::kNameMaximum, ops::kNameActivationGrad, ops::kNameQuantDTypeCast}; | |||
| // a certain op whose input's format is not fixed, bool value determines whether the op has axis attribute or not. | |||
| static const std::unordered_map<std::string, bool> DynamicFormatOpList = { | |||
| {ops::kNameAddN, false}, {ops::kNameCrop, true}, {ops::kNameSplit, true}, | |||
| {ops::kNameConcat, true}, {ops::kNameEltwise, false}, {ops::kNameMaximum, false}, | |||
| {ops::kNameAddFusion, false}, {ops::kNameDivFusion, false}, {ops::kNameMulFusion, false}, | |||
| {ops::kNamePadFusion, false}, {ops::kNamePowFusion, false}, {ops::kNameActivation, false}, | |||
| {ops::kNameSliceFusion, true}, {ops::kNameStridedSlice, true}, {ops::kNameActivationGrad, false}, | |||
| {ops::kNameQuantDTypeCast, false}}; | |||
| static const std::unordered_map<int, int> NC2NHAxisMap = {{0, 0}, {1, 3}, {2, 1}, {3, 2}}; | |||
| const std::unordered_map<std::string, std::vector<size_t>> &GetNHWCOpMap() { return NHWCOpMap; } | |||
| const std::unordered_map<std::string, std::vector<size_t>> &GetNCHWOpMap() { return NCHWOpMap; } | |||
| const std::unordered_map<int, int> &GetNC2NHAxisMap() { return NC2NHAxisMap; } | |||
| const std::vector<std::string> &GetDynamicFormatOpList() { return DynamicFormatOpList; } | |||
| bool IsDynamicFormatOp(const std::string &op_type) { | |||
| return DynamicFormatOpList.find(op_type) != DynamicFormatOpList.end(); | |||
| } | |||
| bool IsDynamicFormatOpWithAxis(const std::string &op_type) { | |||
| auto iter = DynamicFormatOpList.find(op_type); | |||
| return iter != DynamicFormatOpList.end() && iter->second; | |||
| } | |||
| Format GetFormat(const CNodePtr &cnode) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| @@ -34,8 +34,9 @@ struct TransTypePair { | |||
| }; | |||
| const std::unordered_map<std::string, std::vector<size_t>> &GetNHWCOpMap(); | |||
| const std::unordered_map<std::string, std::vector<size_t>> &GetNCHWOpMap(); | |||
| const std::unordered_map<int, int> &GetNC2NHAxisMap(); | |||
| const std::vector<std::string> &GetDynamicFormatOpList(); | |||
| bool IsDynamicFormatOp(const std::string &op_type); | |||
| bool IsDynamicFormatOpWithAxis(const std::string &op_type); | |||
| Format GetFormat(const CNodePtr &cnode); | |||
| STATUS GetTransposePerm(const CNodePtr &cnode, std::vector<int> *perm); | |||
| void RemoveIfMonad(const CNodePtr &cnode); | |||
| @@ -496,13 +496,12 @@ int CheckLeastInputSize(const CNodePtr &node, const int size) { | |||
| return lite::RET_OK; | |||
| } | |||
| ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, | |||
| const tensor::TensorPtr &weight_tensor) { | |||
| ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, TypeId type_id) { | |||
| auto bias_parameter = func_graph->add_parameter(); | |||
| MS_ASSERT(bias_parameter != nullptr); | |||
| std::vector<int64_t> shape_vector = {kernel_num}; | |||
| auto tensor_info = lite::CreateTensorInfo(bias_data, kernel_num * sizeof(float) / sizeof(uint8_t), shape_vector, | |||
| weight_tensor->data_type()); | |||
| auto tensor_info = | |||
| lite::CreateTensorInfo(bias_data, kernel_num * sizeof(float) / sizeof(uint8_t), shape_vector, type_id); | |||
| if (tensor_info == nullptr) { | |||
| MS_LOG(ERROR) << "create tensor info failed."; | |||
| return nullptr; | |||
| @@ -613,6 +612,9 @@ bool IsParamOrValueNodeWithData(const BaseRef &n) { | |||
| } | |||
| } | |||
| if (utils::isa<ParameterPtr>(n)) { | |||
| if (!utils::cast<ParameterPtr>(n)->has_default()) { | |||
| return false; | |||
| } | |||
| auto param = utils::cast<ParameterPtr>(n)->default_param(); | |||
| auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param); | |||
| if (tensor == nullptr || tensor->data_c() == nullptr) { | |||
| @@ -82,8 +82,7 @@ int CheckIfNodeIsParamOrValue(const AnfNodePtr &node); | |||
| int CheckLeastInputSize(const CNodePtr &node, int size); | |||
| ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, | |||
| const tensor::TensorPtr &weight_tensor); | |||
| ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, TypeId type_id); | |||
| bool IsParamNode(const BaseRef &n); | |||
| @@ -14,12 +14,14 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/fusion/conv_biasadd_fusion.h" | |||
| #include <functional> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "ops/fusion/add_fusion.h" | |||
| #include "ops/fusion/conv2d_fusion.h" | |||
| #include "ops/fusion/conv2d_transpose_fusion.h" | |||
| #include "tools/common/tensor_util.h" | |||
| #include "utils/utils.h" | |||
| #include "tools/anf_exporter/fetch_content.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "securec/include/securec.h" | |||
| @@ -48,238 +50,171 @@ bool IsAddNode(const BaseRef &n) { | |||
| return false; | |||
| } | |||
| int Get_Kenrnel_nums(const CNodePtr &conv_node) { | |||
| MS_ASSERT(conv_node != nullptr); | |||
| auto value_primitive = conv_node->input(0); | |||
| auto value_node = value_primitive->cast<ValueNodePtr>(); | |||
| MS_ASSERT(value_node != nullptr); | |||
| auto value = value_node->value(); | |||
| MS_ASSERT(value != nullptr); | |||
| auto primitive = value->cast<PrimitiveCPtr>(); | |||
| MS_ASSERT(primitive != nullptr); | |||
| if (primitive->isa<mindspore::ops::Conv2DFusion>()) { | |||
| MS_ASSERT(utils::isa<std::shared_ptr<mindspore::ops::Conv2DFusion>>(primitive)); | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::ops::Conv2DFusion>>(primitive); | |||
| MS_ASSERT(primc != nullptr); | |||
| return primc->get_out_channel(); | |||
| } else if (primitive->isa<mindspore::ops::Conv2dTransposeFusion>()) { | |||
| MS_ASSERT(utils::isa<std::shared_ptr<mindspore::ops::Conv2dTransposeFusion>>(primitive)); | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::ops::Conv2dTransposeFusion>>(primitive); | |||
| MS_ASSERT(primc != nullptr); | |||
| return primc->get_out_channel(); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported opType, " << primitive->name(); | |||
| return 0; | |||
| } | |||
| } | |||
| int GetAddBiasData(const AnfNodePtr &bias_add_weight_node, const int &kernel_nums, float **add_bias_data) { | |||
| MS_ASSERT(bias_add_weight_node != nullptr); | |||
| MS_ASSERT(add_bias_data != nullptr); | |||
| MS_ASSERT(*add_bias_data != nullptr); | |||
| float *add_weight_data = nullptr; | |||
| ShapeVector add_weight_shape; | |||
| if (utils::isa<Parameter>(bias_add_weight_node)) { | |||
| auto add_weight_param_node = bias_add_weight_node->cast<ParameterPtr>(); | |||
| if (!add_weight_param_node->has_default() || add_weight_param_node->default_param() == nullptr) { | |||
| MS_LOG(ERROR) << "The bias parameter of " << bias_add_weight_node->fullname_with_scope() << " is nullptr."; | |||
| return lite::RET_ERROR; | |||
| bool FuseBias(const lite::DataInfo &add_bias, const lite::DataInfo &conv_bias, std::vector<float> *fusion_bias, | |||
| int out_channel) { | |||
| MS_ASSERT(conv_bias != nullptr); | |||
| if ((add_bias.data_type_ != TypeId::kNumberTypeFloat32 && add_bias.data_type_ != TypeId::kNumberTypeFloat) || | |||
| add_bias.data_.empty()) { | |||
| return false; | |||
| } | |||
| if (out_channel <= 0) { | |||
| return false; | |||
| } | |||
| std::vector<float> add_bias_data(add_bias.data_.size() / sizeof(float)); | |||
| if (memcpy_s(add_bias_data.data(), add_bias.data_.size(), add_bias.data_.data(), add_bias.data_.size()) != EOK) { | |||
| return false; | |||
| } | |||
| fusion_bias->resize(out_channel, 0); | |||
| if (!conv_bias.data_.empty()) { | |||
| if (conv_bias.data_type_ != TypeId::kNumberTypeFloat32 && conv_bias.data_type_ != TypeId::kNumberTypeFloat && | |||
| conv_bias.data_.size() != out_channel * sizeof(float)) { | |||
| return false; | |||
| } | |||
| auto add_weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(add_weight_param_node->default_param()); | |||
| if (add_weight_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "The bias data of parameter node " << bias_add_weight_node->fullname_with_scope() | |||
| << " is not tensorPtr."; | |||
| return lite::RET_ERROR; | |||
| if (memcpy_s(fusion_bias->data(), conv_bias.data_.size(), conv_bias.data_.data(), conv_bias.data_.size()) != EOK) { | |||
| return false; | |||
| } | |||
| add_weight_data = reinterpret_cast<float *>(add_weight_tensor->data_c()); | |||
| MS_ASSERT(add_weight_data != nullptr); | |||
| add_weight_shape = add_weight_tensor->shape(); | |||
| } else { | |||
| MS_ASSERT(utils::isa<ValueNode>(bias_add_weight_node)); | |||
| auto add_weight_value_node = bias_add_weight_node->cast<ValueNodePtr>(); | |||
| auto add_weight_value = add_weight_value_node->value(); | |||
| MS_ASSERT(add_weight_value != nullptr); | |||
| auto add_weight_tensor = add_weight_value->cast<tensor::TensorPtr>(); | |||
| if (add_weight_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "The bias data of value node " << bias_add_weight_node->fullname_with_scope() | |||
| << " is not tensorPtr."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| add_weight_data = reinterpret_cast<float *>(add_weight_tensor->data_c()); | |||
| MS_ASSERT(add_weight_data != nullptr); | |||
| auto value_abstract = add_weight_value_node->abstract(); | |||
| auto value_abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(value_abstract); | |||
| add_weight_shape = utils::cast<abstract::ShapePtr>(value_abstract_tensor->BuildShape())->shape(); | |||
| } | |||
| if (add_weight_shape.empty() || (add_weight_shape.size() == 1 && add_weight_shape[0] == 1)) { | |||
| for (int i = 0; i < kernel_nums; i++) { | |||
| (*add_bias_data)[i] = *add_weight_data; | |||
| } | |||
| } else { | |||
| if (EOK != memcpy_s(*add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) { | |||
| MS_LOG(ERROR) << "memcpy_s conv_bias_data failed"; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (fusion_bias->size() % add_bias_data.size() != 0) { | |||
| return false; | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| int GetNewConvBiasData(const AnfNodePtr &conv_bias_node, const int &kernel_nums, const float *add_bias_data) { | |||
| MS_ASSERT(add_bias_data != nullptr); | |||
| MS_ASSERT(conv_bias_node != nullptr); | |||
| if (utils::isa<Parameter>(conv_bias_node)) { | |||
| auto conv_bias_param_node = conv_bias_node->cast<ParameterPtr>(); | |||
| if (!conv_bias_param_node->has_default() || conv_bias_param_node->default_param() == nullptr) { | |||
| MS_LOG(ERROR) << "The bias parameter of " << conv_bias_node->fullname_with_scope() << " is nullptr."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto conv_bias_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_bias_param_node->default_param()); | |||
| if (conv_bias_tensor == nullptr || conv_bias_tensor->shape().empty() || | |||
| conv_bias_tensor->shape()[0] != kernel_nums) { | |||
| MS_LOG(ERROR) << "conv_bias_node shape error"; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto conv_bias_data = reinterpret_cast<float *>(conv_bias_tensor->data_c()); | |||
| MS_ASSERT(conv_bias_data != nullptr); | |||
| for (int i = 0; i < kernel_nums; i++) { | |||
| conv_bias_data[i] += add_bias_data[i]; | |||
| } | |||
| } else { | |||
| MS_ASSERT(utils::isa<ValueNode>(conv_bias_node)); | |||
| auto conv_bias_value_node = conv_bias_node->cast<ValueNodePtr>(); | |||
| auto conv_bias_value = conv_bias_value_node->value(); | |||
| MS_ASSERT(conv_bias_value != nullptr); | |||
| auto conv_bias_tensor = conv_bias_value->cast<tensor::TensorPtr>(); | |||
| if (conv_bias_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "The bias data of value node " << conv_bias_node->fullname_with_scope() << "is not tensorPtr."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto conv_bias_data = reinterpret_cast<float *>(conv_bias_tensor->data_c()); | |||
| MS_ASSERT(conv_bias_data != nullptr); | |||
| for (int i = 0; i < kernel_nums; i++) { | |||
| conv_bias_data[i] += add_bias_data[i]; | |||
| } | |||
| for (size_t i = 0; i < fusion_bias->size(); ++i) { | |||
| fusion_bias->at(i) += add_bias_data[i % add_bias_data.size()]; | |||
| } | |||
| return lite::RET_OK; | |||
| return true; | |||
| } | |||
| tensor::TensorPtr GetConvWeightTensor(const AnfNodePtr &conv_weight_node) { | |||
| tensor::TensorPtr conv_weight_tensor; | |||
| if (utils::isa<ValueNode>(conv_weight_node)) { | |||
| auto conv_weight_value_node = conv_weight_node->cast<ValueNodePtr>(); | |||
| auto conv_weight_value = conv_weight_value_node->value(); | |||
| MS_ASSERT(conv_weight_value != nullptr); | |||
| conv_weight_tensor = conv_weight_value->cast<tensor::TensorPtr>(); | |||
| MS_ASSERT(conv_weight_tensor != nullptr); | |||
| } else { | |||
| MS_ASSERT(utils::isa<Parameter>(conv_weight_node)); | |||
| auto conv_weight_param = conv_weight_node->cast<ParameterPtr>()->default_param(); | |||
| MS_ASSERT(conv_weight_param != nullptr); | |||
| conv_weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_weight_param); | |||
| MS_ASSERT(conv_weight_tensor != nullptr); | |||
| } | |||
| return conv_weight_tensor; | |||
| } // namespace | |||
| const BaseRef ConvBiasaddFusion::DefinePattern() const { | |||
| auto conv_var = std::make_shared<CondVar>(IsConvExtendNode); | |||
| auto add_var = std::make_shared<CondVar>(IsAddNode); | |||
| auto weight_var = std::make_shared<CondVar>(IsParamOrValueNodeWithData); | |||
| return VectorRef({add_var, conv_var, weight_var}); | |||
| } | |||
| int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, const CNodePtr &bias_node) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(conv_node != nullptr); | |||
| MS_ASSERT(bias_node != nullptr); | |||
| AnfNodePtr conv_bias_node = nullptr; | |||
| AnfNodePtr conv_weight_node = nullptr; | |||
| if (conv_node->inputs().size() == kConvNoBiasLen) { | |||
| conv_weight_node = conv_node->input(kConvWeightIndex); | |||
| } else if (conv_node->inputs().size() == kConvWithBiasLen) { | |||
| conv_weight_node = conv_node->input(kConvWeightIndex); | |||
| conv_bias_node = conv_node->input(kConvBiasIndex); | |||
| } else { | |||
| MS_LOG(ERROR) << "conv node:" << conv_node->DebugString() << "inputs size must 3 or 4"; | |||
| return lite::RET_INPUT_TENSOR_ERROR; | |||
| bool ConvBiasaddFusion::CheckCanFusion(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const { | |||
| MS_ASSERT(node != nullptr); | |||
| if (!utils::isa<CNode>(node)) { | |||
| return false; | |||
| } | |||
| auto add_cnode = node->cast<CNodePtr>(); | |||
| if (CheckInputSize(add_cnode, kAddInputsLength) != lite::RET_OK) { | |||
| return false; | |||
| } | |||
| auto prim_add = GetValueNode<PrimitivePtr>(add_cnode->input(0)); | |||
| MS_ASSERT(rim_add != nullptr); | |||
| auto add_act_ptr = prim_add->GetAttr(ops::kActivationType); | |||
| auto add_act = add_act_ptr == nullptr ? mindspore::NO_ACTIVATION | |||
| : static_cast<mindspore::ActivationType>(GetValue<int64_t>(add_act_ptr)); | |||
| auto conv_cnode = add_cnode->input(1)->cast<CNodePtr>(); | |||
| if (conv_cnode == nullptr) { | |||
| return false; | |||
| } | |||
| if (IsMultiOutputTensors(func_graph, conv_cnode)) { | |||
| return false; | |||
| } | |||
| if (conv_cnode->size() == kInputSizeFour) { | |||
| auto conv_bias = conv_cnode->input(kInputIndexThree); | |||
| if (conv_bias->isa<CNode>() || (conv_bias->isa<Parameter>() && !conv_bias->cast<ParameterPtr>()->has_default())) { | |||
| return false; | |||
| } | |||
| } | |||
| auto kernel_nums = Get_Kenrnel_nums(conv_node); | |||
| if (kernel_nums <= 0) { | |||
| MS_LOG(ERROR) << "kernel num less than 0"; | |||
| return lite::RET_INVALID_OP_ATTR; | |||
| auto prim_conv = GetValueNode<PrimitivePtr>(conv_cnode->input(0)); | |||
| MS_ASSERT(prim_conv != nullptr); | |||
| auto conv_act_ptr = prim_add->GetAttr(ops::kActivationType); | |||
| auto conv_act = add_act_ptr == nullptr ? mindspore::NO_ACTIVATION | |||
| : static_cast<mindspore::ActivationType>(GetValue<int64_t>(conv_act_ptr)); | |||
| if (add_act != mindspore::NO_ACTIVATION) { | |||
| if (conv_act != mindspore::NO_ACTIVATION || (add_act != mindspore::RELU && add_act != mindspore::RELU6)) { | |||
| return false; | |||
| } | |||
| } | |||
| auto add_bias_data = new (std::nothrow) float[kernel_nums]; | |||
| if (add_bias_data == nullptr) { | |||
| MS_LOG(ERROR) << "tensor_data is nullptr"; | |||
| return lite::RET_MEMORY_FAILED; | |||
| if (prim_conv->GetAttr(ops::kOutChannel) == nullptr) { | |||
| return false; | |||
| } | |||
| auto bias_add_weight = bias_node->input(kAddWEIGHTINDEX); | |||
| if (CheckIfNodeIsParamOrValue(bias_add_weight) != lite::RET_OK) { | |||
| delete[] add_bias_data; | |||
| return lite::RET_INVALID_OP_ATTR; | |||
| auto out_channel = GetValue<int64_t>(prim_conv->GetAttr(ops::kOutChannel)); | |||
| auto add_weight = add_cnode->input(kInputIndexTwo); | |||
| MS_ASSERT(add_weight != nullptr); | |||
| ShapeVector shape; | |||
| if (FetchShapeFromAbstract(add_weight->abstract(), &shape) != lite::RET_OK) { | |||
| return false; | |||
| } | |||
| if (GetAddBiasData(bias_add_weight, kernel_nums, &add_bias_data) != lite::RET_OK) { | |||
| delete[] add_bias_data; | |||
| return lite::RET_INVALID_OP_ATTR; | |||
| if (std::count_if(shape.begin(), shape.end(), [](int64_t dim) { return dim > 1; }) > 1) { | |||
| return false; | |||
| } | |||
| if (conv_bias_node != nullptr) { | |||
| if (CheckIfNodeIsParamOrValue(conv_bias_node) != lite::RET_OK) { | |||
| delete[] add_bias_data; | |||
| return lite::RET_INVALID_OP_ATTR; | |||
| auto element_num = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()); | |||
| return out_channel % element_num == 0; | |||
| } | |||
| int ConvBiasaddFusion::DoFuison(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const { | |||
| MS_ASSERT(node != nullptr); | |||
| auto add_cnode = node->cast<CNodePtr>(); | |||
| MS_ASSERT(add_cnode != nullptr); | |||
| auto add_bias = add_cnode->input(kInputIndexTwo); | |||
| lite::DataInfo add_bias_info; | |||
| int status = lite::RET_ERROR; | |||
| if (add_bias->isa<Parameter>()) { | |||
| status = lite::FetchDataFromParameterNode(add_cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &add_bias_info); | |||
| } else if (add_bias->isa<ValueNode>()) { | |||
| status = lite::FetchDataFromValueNode(add_cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &add_bias_info); | |||
| } | |||
| if (status != lite::RET_OK) { | |||
| MS_LOG(DEBUG) << "conv and add do fusion failed, please check"; | |||
| return status; | |||
| } | |||
| auto conv_cnode = add_cnode->input(1)->cast<CNodePtr>(); | |||
| MS_ASSERT(conv_cnode != nullptr); | |||
| lite::DataInfo conv_bias_info; | |||
| if (conv_cnode->size() > kInputSizeThree) { | |||
| auto conv_bias = conv_cnode->input(kInputIndexThree); | |||
| if (conv_bias->isa<Parameter>()) { | |||
| status = | |||
| lite::FetchDataFromParameterNode(conv_cnode, kInputIndexThree, converter::kFmkTypeMs, false, &conv_bias_info); | |||
| } else if (conv_bias->isa<ValueNode>()) { | |||
| status = | |||
| lite::FetchDataFromValueNode(conv_cnode, kInputIndexThree, converter::kFmkTypeMs, false, &conv_bias_info); | |||
| } | |||
| if (GetNewConvBiasData(conv_bias_node, kernel_nums, add_bias_data) != lite::RET_OK) { | |||
| delete[] add_bias_data; | |||
| return lite::RET_INVALID_OP_ATTR; | |||
| if (status != lite::RET_OK) { | |||
| MS_LOG(DEBUG) << "conv and add do fusion failed, please check"; | |||
| return status; | |||
| } | |||
| delete[] add_bias_data; | |||
| } | |||
| auto prim = GetValueNode<PrimitivePtr>(conv_cnode->input(0)); | |||
| MS_ASSERT(prim != nullptr); | |||
| if (prim->GetAttr(ops::kOutChannel) == nullptr) { | |||
| return lite::RET_ERROR; | |||
| } | |||
| int out_channel = GetValue<int64_t>(prim->GetAttr(ops::kOutChannel)); | |||
| std::vector<float> fusion_data; | |||
| if (!FuseBias(add_bias_info, conv_bias_info, &fusion_data, out_channel)) { | |||
| MS_LOG(DEBUG) << "conv and add do fusion failed, please check"; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto conv_new_bias = | |||
| AddNewBiasNode(fusion_data.data(), func_graph, out_channel, static_cast<TypeId>(add_bias_info.data_type_)); | |||
| conv_new_bias->set_name(conv_cnode->fullname_with_scope() + "_bias"); | |||
| auto manager = func_graph->manager(); | |||
| MS_ASSERT(manager != nullptr); | |||
| auto tr = manager->Transact(); | |||
| if (conv_cnode->size() > kInputSizeThree) { | |||
| tr.SetEdge(conv_cnode, kInputIndexThree, conv_new_bias); | |||
| } else { | |||
| if (CheckIfNodeIsParamOrValue(conv_weight_node) != lite::RET_OK) { | |||
| delete[] add_bias_data; | |||
| return lite::RET_INVALID_OP_ATTR; | |||
| } | |||
| tensor::TensorPtr conv_weight_tensor = GetConvWeightTensor(conv_weight_node); | |||
| auto conv_new_bias = AddNewBiasNode(add_bias_data, func_graph, kernel_nums, conv_weight_tensor); | |||
| conv_new_bias->set_name(conv_node->fullname_with_scope() + "_bias"); | |||
| conv_node->add_input(conv_new_bias); | |||
| tr.AddEdge(conv_cnode, conv_new_bias); | |||
| } | |||
| tr.Commit(); | |||
| return lite::RET_OK; | |||
| } | |||
| } // namespace | |||
| const BaseRef ConvBiasaddFusion::DefinePattern() const { | |||
| auto conv_var = std::make_shared<CondVar>(IsConvExtendNode); | |||
| auto add_var = std::make_shared<CondVar>(IsAddNode); | |||
| auto weight_var = std::make_shared<CondVar>(IsParamOrValueNodeWithData); | |||
| return VectorRef({add_var, conv_var, weight_var}); | |||
| } | |||
| const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| MS_LOG(DEBUG) << "Enter pass process"; | |||
| if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) { | |||
| return nullptr; | |||
| } | |||
| auto add_node = node->cast<CNodePtr>(); | |||
| if (CheckIfCNodeIsNull(add_node) != lite::RET_OK || CheckInputSize(add_node, kAddInputsLength) != lite::RET_OK) { | |||
| return nullptr; | |||
| } | |||
| if (CheckPrimitiveType(add_node, prim::kPrimAddFusion)) { | |||
| auto primitive_c = GetValueNode<PrimitiveCPtr>(add_node->input(0)); | |||
| MS_ASSERT(utils::isa<std::shared_ptr<mindspore::ops::AddFusion>>(primitive_c)); | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::ops::AddFusion>>(primitive_c); | |||
| MS_ASSERT(primc != nullptr); | |||
| if (primc->GetAttr(ops::kActivationType) != nullptr && primc->get_activation_type() != mindspore::NO_ACTIVATION) { | |||
| return add_node; | |||
| } | |||
| } | |||
| AnfNodePtr conv_node_anf = add_node->input(1); | |||
| if (CheckIfAnfNodeIsNull(conv_node_anf) != lite::RET_OK || IsMultiOutputTensors(func_graph, conv_node_anf)) { | |||
| return nullptr; | |||
| } | |||
| auto conv_node = conv_node_anf->cast<CNodePtr>(); | |||
| if (CheckIfCNodeIsNull(conv_node) != lite::RET_OK) { | |||
| MS_ASSERT(func_graph != nullptr && node != nullptr); | |||
| if (!CheckCanFusion(func_graph, node)) { | |||
| return nullptr; | |||
| } | |||
| int ret = GenConvNewBias(func_graph, conv_node, add_node); | |||
| if (ret != lite::RET_OK) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); | |||
| if (DoFuison(func_graph, node) != lite::RET_OK) { | |||
| return nullptr; | |||
| } | |||
| return conv_node; | |||
| auto add_cnode = node->cast<CNodePtr>(); | |||
| MS_ASSERT(add_cnode != nullptr); | |||
| return add_cnode->input(1); | |||
| } | |||
| } // namespace mindspore::opt | |||
| @@ -28,6 +28,10 @@ class ConvBiasaddFusion : public PatternProcessPass { | |||
| ~ConvBiasaddFusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| bool CheckCanFusion(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const; | |||
| int DoFuison(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -211,7 +211,7 @@ void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const | |||
| } | |||
| CalNewBiasTensor(bias_data, kernel_num, bias_flag, trans_scale, trans_bias); | |||
| if (!bias_flag) { | |||
| auto bias_node = AddNewBiasNode(bias_data, func_graph, kernel_num, weight_tensor); | |||
| auto bias_node = AddNewBiasNode(bias_data, func_graph, kernel_num, weight_tensor->data_type()); | |||
| delete[] bias_data; | |||
| bias_node->set_name(conv_node->fullname_with_scope() + "_bias"); | |||
| conv_node->add_input(bias_node); | |||
| @@ -130,15 +130,13 @@ bool JudgeCanOptimizerForMultiOp(const FuncGraphPtr &func_graph, const std::set< | |||
| if (trans_info->pre_ == trans_info->post_) { | |||
| return false; | |||
| } | |||
| auto &dynamic_ops = GetDynamicFormatOpList(); | |||
| TransposeStrategy transpose_strategy; | |||
| for (auto &middle_cnode : middle_nodes) { | |||
| if (IsSpecialType(middle_cnode)) { | |||
| continue; | |||
| } | |||
| auto middle_node_prim = GetValueNode<PrimitivePtr>(middle_cnode->input(0)); | |||
| if (!lite::IsContain(dynamic_ops, middle_node_prim->name()) || | |||
| !transpose_strategy.CanChangeOpAxis(func_graph, middle_cnode)) { | |||
| if (!transpose_strategy.CanChangeOpAxis(func_graph, middle_cnode)) { | |||
| return false; | |||
| } | |||
| } | |||
| @@ -642,7 +640,7 @@ bool DecreaseTransposeAlgo::DecreaseTransposeForSingleOp(const FuncGraphPtr &fun | |||
| } | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| MS_ASSERT(prim != nullptr); | |||
| if (!lite::IsContain(GetDynamicFormatOpList(), prim->name())) { | |||
| if (!IsDynamicFormatOp(prim->name())) { | |||
| continue; | |||
| } | |||
| TransTypePair trans_insert_info; | |||
| @@ -16,6 +16,8 @@ | |||
| #include "tools/optimizer/graph/transpose_strategy.h" | |||
| #include <algorithm> | |||
| #include <functional> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <string> | |||
| @@ -24,7 +26,7 @@ | |||
| #include "ops/fusion/activation.h" | |||
| #include "ops/fusion/slice_fusion.h" | |||
| #include "ops/op_utils.h" | |||
| #include "ops/strided_slice.h" | |||
| #include "tools/anf_exporter/fetch_content.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -32,7 +34,9 @@ namespace { | |||
| constexpr size_t kFirstInput = 1; | |||
| constexpr size_t kHalfDivisor = 2; | |||
| constexpr size_t kOnnxStridedSlice = 6; | |||
| constexpr int kPaddingListLength = 8; | |||
| STATUS GetPostNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<AnfNodePtr> *out_nodes) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr && out_nodes != nullptr); | |||
| auto manager = func_graph->manager(); | |||
| if (manager == nullptr) { | |||
| manager = Manage(func_graph, true); | |||
| @@ -50,6 +54,268 @@ STATUS GetPostNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std:: | |||
| [](const std::pair<AnfNodePtr, int> &node_user) { return node_user.first; }); | |||
| return lite::RET_OK; | |||
| } | |||
| bool JudgeIs4DInput(NodeInferShape *node_infer_shape, const CNodePtr &cnode) { | |||
| MS_ASSERT(node_infer_shape != nullptr && cnode != nullptr); | |||
| auto shape = node_infer_shape->GetInputShape(cnode, 1); | |||
| if (shape.size() != kInputSizeFour) { | |||
| if (cnode->size() > kInputSizeTwo) { | |||
| shape = node_infer_shape->GetInputShape(cnode, kInputIndexTwo); | |||
| if (shape.size() != kInputSizeFour && !shape.empty()) { | |||
| return false; | |||
| } | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| std::vector<int> TransformOpAxesAttr(const std::vector<int> &origin_axes, FormatTransNodeType trans_type) { | |||
| std::vector<int> cur_axes; | |||
| for (size_t i = 0; i < origin_axes.size(); ++i) { | |||
| int axis = origin_axes[i]; | |||
| if (axis < 0) { | |||
| axis += kInputSizeFour; | |||
| } | |||
| MS_ASSERT(axis >= 0 && axis < kInputSizeFour); | |||
| int cur_axis = kNH2NC[axis]; | |||
| if (trans_type == kNHWC2NCHW) { | |||
| cur_axis = kNC2NH[axis]; | |||
| } | |||
| cur_axes.push_back(cur_axis); | |||
| } | |||
| std::sort(cur_axes.begin(), cur_axes.end()); | |||
| return cur_axes; | |||
| } | |||
| void TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index, | |||
| const std::vector<int> &axes, FormatTransNodeType trans_type, | |||
| NodeInferShape *node_infer_shape) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr); | |||
| if (input_index >= cnode->size() || axes.empty()) { | |||
| return; | |||
| } | |||
| auto origin_input = node_infer_shape->GetIntVecInput(cnode, input_index); | |||
| if (origin_input.size() != axes.size()) { | |||
| return; | |||
| } | |||
| std::vector<int> cur_input; | |||
| for (int dim = 0; dim < static_cast<int>(kInputSizeFour); ++dim) { | |||
| for (size_t index = 0; index < axes.size(); ++index) { | |||
| int axis = axes[index]; | |||
| if (axis < 0) { | |||
| axis += kInputSizeFour; | |||
| } | |||
| MS_ASSERT(axis >= 0 && axis < kInputSizeFour); | |||
| int cur_axis = kNH2NC[axis]; | |||
| if (trans_type == kNHWC2NCHW) { | |||
| cur_axis = kNC2NH[axis]; | |||
| } | |||
| if (cur_axis == dim) { | |||
| cur_input.push_back(origin_input[index]); | |||
| } | |||
| } | |||
| } | |||
| auto param_node = BuildIntVecParameterNode(func_graph, cur_input, cnode->input(input_index)->fullname_with_scope()); | |||
| func_graph->manager()->Replace(cnode->input(input_index), param_node); | |||
| } | |||
| STATUS ChangeCommonOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type, | |||
| NodeInferShape *node_infer_shape) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr); | |||
| if (trans_type == kNONE) { | |||
| MS_LOG(ERROR) << "trans_type is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| MS_ASSERT(prim != nullptr); | |||
| if (prim->GetAttr(ops::kAxis) == nullptr) { | |||
| return lite::RET_NOT_SUPPORT; | |||
| } | |||
| auto axis = GetValue<int64_t>(prim->GetAttr(ops::kAxis)); | |||
| if (axis < 0) { | |||
| axis += kInputSizeFour; | |||
| } | |||
| auto new_axis = kNH2NC[axis]; | |||
| if (trans_type == kNHWC2NCHW) { | |||
| new_axis = kNC2NH[axis]; | |||
| } | |||
| prim->AddAttr(ops::kAxis, MakeValue<int64_t>(new_axis)); | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS ChangeOpCrop(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type, | |||
| NodeInferShape *node_infer_shape) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr); | |||
| if (trans_type == kNONE) { | |||
| MS_LOG(ERROR) << "trans_type is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto crop_prim = GetValueNode<std::shared_ptr<ops::Crop>>(cnode->input(0)); | |||
| if (crop_prim == nullptr) { | |||
| MS_LOG(ERROR) << "cnode is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto axis = crop_prim->get_axis(); | |||
| if (axis < 0) { | |||
| axis += kInputSizeFour; | |||
| } | |||
| MS_ASSERT(axis >= 0 && axis < kInputSizeFour); | |||
| auto offsets = crop_prim->get_offsets(); | |||
| if (trans_type == kNCHW2NHWC) { | |||
| auto new_axis = kNH2NC[axis]; | |||
| if (new_axis == 0) { | |||
| offsets = {offsets[0], offsets[kInputIndexTwo], offsets[kInputIndexThree], offsets[1]}; | |||
| } else if (new_axis == kInputIndexThree) { | |||
| offsets = {offsets[1], offsets[kInputIndexTwo], offsets[0]}; | |||
| } else { | |||
| offsets.push_back(0); | |||
| } | |||
| crop_prim->set_axis(new_axis); | |||
| crop_prim->set_offsets(offsets); | |||
| } else { | |||
| auto new_axis = kNC2NH[axis]; | |||
| if (new_axis == 0) { | |||
| offsets = {offsets[0], offsets[kInputIndexThree], offsets[1], offsets[kInputIndexTwo]}; | |||
| } else if (new_axis == kInputIndexThree) { | |||
| offsets = {offsets[kInputIndexTwo], offsets[0], offsets[1]}; | |||
| } else { | |||
| offsets.pop_back(); | |||
| } | |||
| crop_prim->set_axis(new_axis); | |||
| crop_prim->set_offsets(offsets); | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS ChangeOpPad(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type, | |||
| NodeInferShape *node_infer_shape) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr); | |||
| if (trans_type == kNONE) { | |||
| MS_LOG(ERROR) << "trans_type is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (cnode->size() < kInputSizeThree) { | |||
| MS_LOG(ERROR) << "pad op need three inputs."; | |||
| return lite::RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| auto second_input = cnode->input(kInputIndexTwo); | |||
| lite::DataInfo data_info; | |||
| int status; | |||
| if (utils::isa<Parameter>(second_input)) { | |||
| status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info); | |||
| } else if (utils::isa<ValueNode>(second_input)) { | |||
| status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info); | |||
| } else { | |||
| return lite::RET_NOT_SUPPORT; | |||
| } | |||
| if (status != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "get paddings failed."; | |||
| return status; | |||
| } | |||
| if (std::accumulate(data_info.shape_.begin(), data_info.shape_.end(), 1, std::multiplies<int>()) != | |||
| kPaddingListLength) { | |||
| return lite::RET_OK; | |||
| } | |||
| std::vector<std::vector<int32_t>> padding_list(kInputSizeFour, std::vector<int32_t>(kInputSizeTwo)); | |||
| auto data = reinterpret_cast<int32_t *>(data_info.data_.data()); | |||
| for (int i = 0; i < kPaddingListLength; ++i) { | |||
| padding_list[i / kInputIndexTwo][i % kInputIndexTwo] = *data; | |||
| data += 1; | |||
| } | |||
| if (trans_type == kNCHW2NHWC) { | |||
| auto chanel_pad = padding_list[1]; | |||
| padding_list.erase(padding_list.begin() + 1); | |||
| padding_list.push_back(chanel_pad); | |||
| } else { | |||
| auto chanel_pad = padding_list.back(); | |||
| padding_list.pop_back(); | |||
| padding_list.insert(padding_list.begin() + 1, chanel_pad); | |||
| } | |||
| auto param_node = | |||
| BuildIntVec2DParameterNode(func_graph, padding_list, cnode->input(kInputIndexTwo)->fullname_with_scope()); | |||
| func_graph->manager()->Replace(cnode->input(kInputIndexTwo), param_node); | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| MS_ASSERT(prim != nullptr); | |||
| if (prim->GetAttr(ops::kPaddings) != nullptr) { | |||
| std::vector<std::vector<int64_t>> padding_attr; | |||
| (void)std::transform(padding_list.begin(), padding_list.end(), std::back_inserter(padding_attr), | |||
| [](const std::vector<int> &val) { return std::vector<int64_t>(val.begin(), val.end()); }); | |||
| prim->AddAttr(ops::kPaddings, MakeValue(padding_attr)); | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type, | |||
| NodeInferShape *node_infer_shape) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr); | |||
| if (trans_type == kNONE) { | |||
| MS_LOG(ERROR) << "trans_type is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| for (size_t i = 2; i < cnode->size(); ++i) { | |||
| if (utils::isa<CNodePtr>(cnode->input(i))) { | |||
| return lite::RET_NOT_SUPPORT; | |||
| } | |||
| } | |||
| auto shape = node_infer_shape->GetInputShape(cnode, kInputIndexTwo); | |||
| if (shape.empty()) { | |||
| return lite::RET_NOT_SUPPORT; | |||
| } | |||
| int element_num = shape.front(); | |||
| auto prim = GetValueNode<std::shared_ptr<ops::SliceFusion>>(cnode->input(0)); | |||
| std::vector<int> axes; | |||
| if (prim->GetAttr(ops::kAxes) == nullptr || prim->get_axes().empty()) { | |||
| for (int index = 0; index < element_num; ++index) { | |||
| axes.push_back(index); | |||
| } | |||
| } else { | |||
| auto origin_axes = prim->get_axes(); | |||
| std::transform(origin_axes.begin(), origin_axes.end(), std::back_inserter(axes), | |||
| [](int64_t v) { return static_cast<int>(v); }); | |||
| } | |||
| for (size_t i = 2; i < cnode->size(); ++i) { | |||
| TransformAttrByAxes(func_graph, cnode, i, axes, trans_type, node_infer_shape); | |||
| } | |||
| auto tmp_axes = TransformOpAxesAttr(axes, trans_type); | |||
| std::vector<int64_t> new_axes(tmp_axes.begin(), tmp_axes.end()); | |||
| prim->set_axes(new_axes); | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type, | |||
| NodeInferShape *node_infer_shape) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr); | |||
| if (trans_type == kNONE) { | |||
| MS_LOG(ERROR) << "trans_type is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (cnode->size() != kOnnxStridedSlice) { | |||
| return lite::RET_NOT_SUPPORT; | |||
| } | |||
| for (size_t i = 2; i < cnode->size(); ++i) { | |||
| if (utils::isa<CNodePtr>(cnode->input(i))) { | |||
| return lite::RET_NOT_SUPPORT; | |||
| } | |||
| } | |||
| std::vector<int> axes = node_infer_shape->GetIntVecInput(cnode, kInputIndexFour); | |||
| if (axes.empty()) { | |||
| MS_LOG(ERROR) << "strided slice input invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| for (size_t index = 2; index < cnode->size(); ++index) { | |||
| if (index == kInputIndexFour) { | |||
| continue; | |||
| } | |||
| TransformAttrByAxes(func_graph, cnode, index, axes, trans_type, node_infer_shape); | |||
| } | |||
| auto cur_axes = TransformOpAxesAttr(axes, trans_type); | |||
| auto param_node = | |||
| BuildIntVecParameterNode(func_graph, cur_axes, cnode->input(kInputIndexFour)->fullname_with_scope()); | |||
| func_graph->manager()->Replace(cnode->input(kInputIndexFour), param_node); | |||
| return lite::RET_OK; | |||
| } | |||
| } // namespace | |||
| AnfNodePtr TransposeStrategy::TransposePairFuseWhenInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| @@ -138,32 +404,31 @@ bool TransposeStrategy::CanFusionIfInsert(const FuncGraphPtr &func_graph, const | |||
| bool TransposeStrategy::CanChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| auto shape = node_infer_shape_.GetInputShape(cnode, 1); | |||
| if (shape.size() != kInputSizeFour) { | |||
| if (cnode->size() > kInputSizeTwo) { | |||
| shape = node_infer_shape_.GetInputShape(cnode, kInputIndexTwo); | |||
| if (shape.size() != kInputSizeFour && !shape.empty()) { | |||
| return false; | |||
| } | |||
| } else { | |||
| return false; | |||
| } | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| MS_ASSERT(prim != nullptr); | |||
| if (!IsDynamicFormatOp(prim->name())) { | |||
| return false; | |||
| } | |||
| if (CheckPrimitiveType(cnode, prim::kPrimConcat) || CheckPrimitiveType(cnode, prim::kPrimSplit)) { | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| if (prim->GetAttr(ops::kAxis) == nullptr) { | |||
| return false; | |||
| } | |||
| if (IsDynamicFormatOpWithAxis(prim->name()) && !JudgeIs4DInput(&node_infer_shape_, cnode)) { | |||
| return false; | |||
| } | |||
| if (CheckPrimitiveType(cnode, prim::kPrimSliceFusion) || CheckPrimitiveType(cnode, prim::kPrimStridedSlice)) { | |||
| if (CheckPrimitiveType(cnode, prim::kPrimSliceFusion) || CheckPrimitiveType(cnode, prim::kPrimStridedSlice) || | |||
| CheckPrimitiveType(cnode, prim::kPrimPadFusion)) { | |||
| for (size_t i = 2; i < cnode->size(); ++i) { | |||
| if (utils::isa<CNodePtr>(cnode->input(i))) { | |||
| return false; | |||
| } | |||
| if (utils::isa<Parameter>(cnode->input(i)) && !cnode->input(i)->cast<ParameterPtr>()->has_default()) { | |||
| return false; | |||
| } | |||
| } | |||
| if (CheckPrimitiveType(cnode, prim::kPrimStridedSlice) && cnode->size() != kOnnxStridedSlice) { | |||
| return false; | |||
| } | |||
| } else if (IsDynamicFormatOpWithAxis(prim->name())) { | |||
| if (prim->GetAttr(ops::kAxis) == nullptr) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| @@ -171,28 +436,20 @@ bool TransposeStrategy::CanChangeOpAxis(const FuncGraphPtr &func_graph, const CN | |||
| STATUS TransposeStrategy::ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| FormatTransNodeType trans_type) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| auto shape = node_infer_shape_.GetInputShape(cnode, 1); | |||
| if (shape.size() != kInputSizeFour) { | |||
| if (cnode->size() > kInputSizeTwo) { | |||
| shape = node_infer_shape_.GetInputShape(cnode, kInputIndexTwo); | |||
| if (shape.size() != kInputSizeFour && !shape.empty()) { | |||
| return lite::RET_NOT_SUPPORT; | |||
| } | |||
| } else { | |||
| return lite::RET_NOT_SUPPORT; | |||
| } | |||
| } | |||
| if (CheckPrimitiveType(cnode, prim::kPrimConcat) || CheckPrimitiveType(cnode, prim::kPrimSplit)) { | |||
| return ChangeCommonOp(cnode, trans_type); | |||
| } | |||
| if (CheckPrimitiveType(cnode, prim::kPrimCrop)) { | |||
| return ChangeOpCrop(cnode, trans_type); | |||
| } | |||
| if (CheckPrimitiveType(cnode, prim::kPrimSliceFusion)) { | |||
| return ChangeOpSlice(func_graph, cnode, trans_type); | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| MS_ASSERT(prim != nullptr); | |||
| if (IsDynamicFormatOpWithAxis(prim->name()) && !JudgeIs4DInput(&node_infer_shape_, cnode)) { | |||
| return lite::RET_NOT_SUPPORT; | |||
| } | |||
| if (CheckPrimitiveType(cnode, prim::kPrimStridedSlice)) { | |||
| return ChangeOpStrideSlice(func_graph, cnode, trans_type); | |||
| std::map<std::string, | |||
| std::function<STATUS(const FuncGraphPtr &, const CNodePtr &, FormatTransNodeType, NodeInferShape *)>> | |||
| process_funcs = { | |||
| {prim::kPrimConcat->name(), ChangeCommonOp}, {prim::kPrimSplit->name(), ChangeCommonOp}, | |||
| {prim::kPrimCrop->name(), ChangeOpCrop}, {prim::kPrimPadFusion->name(), ChangeOpPad}, | |||
| {prim::kPrimSliceFusion->name(), ChangeOpSlice}, {prim::kPrimStridedSlice->name(), ChangeOpStrideSlice}}; | |||
| auto iter = process_funcs.find(prim->name()); | |||
| if (iter != process_funcs.end()) { | |||
| return iter->second(func_graph, cnode, trans_type, &node_infer_shape_); | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| @@ -273,190 +530,5 @@ void TransposeStrategy::DecidePreAndPostTransType(TransTypePair *trans_info, Tra | |||
| trans_insert_info->post_ = trans_info->pre_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC; | |||
| } | |||
| } | |||
| STATUS TransposeStrategy::ChangeCommonOp(const CNodePtr &cnode, FormatTransNodeType trans_type) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| if (trans_type == kNONE) { | |||
| MS_LOG(ERROR) << "trans_type is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| MS_ASSERT(prim != nullptr); | |||
| if (prim->GetAttr(ops::kAxis) == nullptr) { | |||
| return lite::RET_NOT_SUPPORT; | |||
| } | |||
| auto axis = GetValue<int64_t>(prim->GetAttr(ops::kAxis)); | |||
| if (axis < 0) { | |||
| axis += kInputSizeFour; | |||
| } | |||
| auto new_axis = kNH2NC[axis]; | |||
| if (trans_type == kNHWC2NCHW) { | |||
| new_axis = kNC2NH[axis]; | |||
| } | |||
| prim->AddAttr(ops::kAxis, MakeValue<int64_t>(new_axis)); | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS TransposeStrategy::ChangeOpCrop(const CNodePtr &cnode, FormatTransNodeType trans_type) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| if (trans_type == kNONE) { | |||
| MS_LOG(ERROR) << "trans_type is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto crop_prim = GetValueNode<std::shared_ptr<ops::Crop>>(cnode->input(0)); | |||
| if (crop_prim == nullptr) { | |||
| MS_LOG(ERROR) << "cnode is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto axis = crop_prim->get_axis(); | |||
| if (axis < 0) { | |||
| axis += kInputSizeFour; | |||
| } | |||
| MS_ASSERT(axis >= 0 && axis < kInputSizeFour); | |||
| auto offsets = crop_prim->get_offsets(); | |||
| if (trans_type == kNCHW2NHWC) { | |||
| auto new_axis = kNH2NC[axis]; | |||
| if (new_axis == 0) { | |||
| offsets = {offsets[0], offsets[kInputIndexTwo], offsets[kInputIndexThree], offsets[1]}; | |||
| } else if (new_axis == kInputIndexThree) { | |||
| offsets = {offsets[1], offsets[kInputIndexTwo], offsets[0]}; | |||
| } else { | |||
| offsets.push_back(0); | |||
| } | |||
| crop_prim->set_axis(new_axis); | |||
| crop_prim->set_offsets(offsets); | |||
| } else { | |||
| auto new_axis = kNC2NH[axis]; | |||
| if (new_axis == 0) { | |||
| offsets = {offsets[0], offsets[kInputIndexThree], offsets[1], offsets[kInputIndexTwo]}; | |||
| } else if (new_axis == kInputIndexThree) { | |||
| offsets = {offsets[kInputIndexTwo], offsets[0], offsets[1]}; | |||
| } else { | |||
| offsets.pop_back(); | |||
| } | |||
| crop_prim->set_axis(new_axis); | |||
| crop_prim->set_offsets(offsets); | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS TransposeStrategy::ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| FormatTransNodeType trans_type) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| if (trans_type == kNONE) { | |||
| MS_LOG(ERROR) << "trans_type is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| for (size_t i = 2; i < cnode->size(); ++i) { | |||
| if (utils::isa<CNodePtr>(cnode->input(i))) { | |||
| return lite::RET_NOT_SUPPORT; | |||
| } | |||
| } | |||
| auto shape = node_infer_shape_.GetInputShape(cnode, kInputIndexTwo); | |||
| if (shape.empty()) { | |||
| return lite::RET_NOT_SUPPORT; | |||
| } | |||
| int element_num = shape.front(); | |||
| auto prim = GetValueNode<std::shared_ptr<ops::SliceFusion>>(cnode->input(0)); | |||
| std::vector<int> axes; | |||
| if (prim->GetAttr(ops::kAxes) == nullptr || prim->get_axes().empty()) { | |||
| for (int index = 0; index < element_num; ++index) { | |||
| axes.push_back(index); | |||
| } | |||
| } else { | |||
| auto origin_axes = prim->get_axes(); | |||
| std::transform(origin_axes.begin(), origin_axes.end(), std::back_inserter(axes), | |||
| [](int64_t v) { return static_cast<int>(v); }); | |||
| } | |||
| for (size_t i = 2; i < cnode->size(); ++i) { | |||
| TransformAttrByAxes(func_graph, cnode, i, axes, trans_type); | |||
| } | |||
| auto tmp_axes = TransformOpAxesAttr(axes, trans_type); | |||
| std::vector<int64_t> new_axes(tmp_axes.begin(), tmp_axes.end()); | |||
| prim->set_axes(new_axes); | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS TransposeStrategy::ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| FormatTransNodeType trans_type) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| if (trans_type == kNONE) { | |||
| MS_LOG(ERROR) << "trans_type is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (cnode->size() != kOnnxStridedSlice) { | |||
| return lite::RET_NOT_SUPPORT; | |||
| } | |||
| for (size_t i = 2; i < cnode->size(); ++i) { | |||
| if (utils::isa<CNodePtr>(cnode->input(i))) { | |||
| return lite::RET_NOT_SUPPORT; | |||
| } | |||
| } | |||
| std::vector<int> axes = node_infer_shape_.GetIntVecInput(cnode, kInputIndexFour); | |||
| if (axes.empty()) { | |||
| MS_LOG(ERROR) << "strided slice input invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| for (size_t index = 2; index < cnode->size(); ++index) { | |||
| if (index == kInputIndexFour) { | |||
| continue; | |||
| } | |||
| TransformAttrByAxes(func_graph, cnode, index, axes, trans_type); | |||
| } | |||
| auto cur_axes = TransformOpAxesAttr(axes, trans_type); | |||
| auto param_node = | |||
| BuildIntVecParameterNode(func_graph, cur_axes, cnode->input(kInputIndexFour)->fullname_with_scope()); | |||
| func_graph->manager()->Replace(cnode->input(kInputIndexFour), param_node); | |||
| return lite::RET_OK; | |||
| } | |||
| void TransposeStrategy::TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index, | |||
| const std::vector<int> &axes, FormatTransNodeType trans_type) { | |||
| if (cnode == nullptr || input_index >= cnode->size() || axes.empty()) { | |||
| return; | |||
| } | |||
| auto origin_input = node_infer_shape_.GetIntVecInput(cnode, input_index); | |||
| if (origin_input.size() != axes.size()) { | |||
| return; | |||
| } | |||
| std::vector<int> cur_input; | |||
| for (int dim = 0; dim < static_cast<int>(kInputSizeFour); ++dim) { | |||
| for (size_t index = 0; index < axes.size(); ++index) { | |||
| int axis = axes[index]; | |||
| if (axis < 0) { | |||
| axis += kInputSizeFour; | |||
| } | |||
| MS_ASSERT(axis >= 0 && axis < kInputSizeFour); | |||
| int cur_axis = kNH2NC[axis]; | |||
| if (trans_type == kNHWC2NCHW) { | |||
| cur_axis = kNC2NH[axis]; | |||
| } | |||
| if (cur_axis == dim) { | |||
| cur_input.push_back(origin_input[index]); | |||
| } | |||
| } | |||
| } | |||
| auto param_node = BuildIntVecParameterNode(func_graph, cur_input, cnode->input(input_index)->fullname_with_scope()); | |||
| func_graph->manager()->Replace(cnode->input(input_index), param_node); | |||
| } | |||
| std::vector<int> TransposeStrategy::TransformOpAxesAttr(const std::vector<int> &origin_axes, | |||
| FormatTransNodeType trans_type) { | |||
| std::vector<int> cur_axes; | |||
| for (size_t i = 0; i < origin_axes.size(); ++i) { | |||
| int axis = origin_axes[i]; | |||
| if (axis < 0) { | |||
| axis += kInputSizeFour; | |||
| } | |||
| MS_ASSERT(axis >= 0 && axis < kInputSizeFour); | |||
| int cur_axis = kNH2NC[axis]; | |||
| if (trans_type == kNHWC2NCHW) { | |||
| cur_axis = kNC2NH[axis]; | |||
| } | |||
| cur_axes.push_back(cur_axis); | |||
| } | |||
| std::sort(cur_axes.begin(), cur_axes.end()); | |||
| return cur_axes; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -51,13 +51,6 @@ class TransposeStrategy { | |||
| bool IsInOutCanFuison(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes, size_t *trans_count, | |||
| FormatTransNodeType *trans_type); | |||
| void DecidePreAndPostTransType(TransTypePair *trans_info, TransTypePair *trans_insert_info); | |||
| STATUS ChangeCommonOp(const CNodePtr &cnode, FormatTransNodeType trans_type); | |||
| STATUS ChangeOpCrop(const CNodePtr &cnode, FormatTransNodeType trans_type); | |||
| STATUS ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type); | |||
| STATUS ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type); | |||
| void TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index, | |||
| const std::vector<int> &axes, FormatTransNodeType trans_type); | |||
| std::vector<int> TransformOpAxesAttr(const std::vector<int> &origin_axes, FormatTransNodeType trans_type); | |||
| FmkType fmk_type_{converter::kFmkTypeMs}; | |||
| bool train_flag_{false}; | |||
| NodeInferShape node_infer_shape_; | |||
| @@ -15,75 +15,87 @@ | |||
| */ | |||
| #include "tools/optimizer/graph/update_conv2d_param_pass.h" | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "ops/fusion/conv2d_fusion.h" | |||
| #include "mindspore/lite/include/errorcode.h" | |||
| namespace mindspore::opt { | |||
| namespace { | |||
| constexpr size_t kNumDim0 = 0; | |||
| constexpr size_t kNumDim1 = 1; | |||
| constexpr size_t kNumDim2 = 2; | |||
| constexpr size_t kNumDim3 = 3; | |||
| constexpr int kAnfPopulaterInputNumTwo = 2; | |||
| void SetConvAttr(const PrimitivePtr &prim, const std::vector<int64_t> &kernel_size, int64_t in_channel, | |||
| int64_t out_channel) { | |||
| MS_ASSERT(prim != nullptr); | |||
| if (prim->GetAttr(ops::kKernelSize) == nullptr) { | |||
| prim->AddAttr(ops::kKernelSize, MakeValue(kernel_size)); | |||
| } else { | |||
| auto origin_kernel_size = GetValue<std::vector<int64_t>>(prim->GetAttr(ops::kKernelSize)); | |||
| if (std::any_of(origin_kernel_size.begin(), origin_kernel_size.end(), [](int64_t size) { return size <= 0; })) { | |||
| prim->AddAttr(ops::kKernelSize, MakeValue(kernel_size)); | |||
| } | |||
| } | |||
| if (prim->GetAttr(ops::kInChannel) == nullptr || GetValue<int64_t>(prim->GetAttr(ops::kInChannel)) <= 0) { | |||
| prim->AddAttr(ops::kInChannel, MakeValue(in_channel)); | |||
| } | |||
| if (prim->GetAttr(ops::kOutChannel) == nullptr || GetValue<int64_t>(prim->GetAttr(ops::kOutChannel)) <= 0) { | |||
| prim->AddAttr(ops::kOutChannel, MakeValue(out_channel)); | |||
| } | |||
| } | |||
| } // namespace | |||
| lite::STATUS UpdateConv2DParamPass::UpdateCommonConv2D(const CNodePtr &cnode) { | |||
| STATUS UpdateConv2DParamPass::UpdateConv2DAttr(const CNodePtr &cnode) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| if (fmk_type_ != converter::kFmkTypeTf) { | |||
| return lite::RET_OK; | |||
| if (cnode->size() < kInputSizeThree) { | |||
| MS_LOG(ERROR) << "conv2d's input size is invalid, now is " << cnode->size() - 1; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto weight = cnode->input(kInputIndexTwo); | |||
| if (weight == nullptr) { | |||
| MS_LOG(ERROR) << "conv2d's weight is invalid, now is nullptr."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto conv = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(cnode->input(0)); | |||
| if (conv == nullptr) { | |||
| MS_LOG(DEBUG) << "cnode is invalid."; | |||
| auto abstract = weight->abstract(); | |||
| ShapeVector shape; | |||
| if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "fetch shape from abstract failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (conv->GetAttr(ops::kFormat) == nullptr || | |||
| (conv->get_format() != mindspore::NHWC && conv->get_format() != mindspore::KHWC)) { | |||
| if (shape.empty()) { | |||
| return lite::RET_OK; | |||
| } | |||
| auto weight_node = cnode->input(kAnfPopulaterInputNumTwo); | |||
| if (weight_node == nullptr) { | |||
| MS_LOG(DEBUG) << "Conv2D weight node is nullptr."; | |||
| if (shape.size() != kInputSizeFour) { | |||
| MS_LOG(ERROR) << "conv2d weight shape size is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (!weight_node->isa<Parameter>()) { | |||
| MS_LOG(DEBUG) << "Conv2D weight node is not parameter."; | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| auto weight_param = weight_node->cast<ParameterPtr>(); | |||
| if (!weight_param->has_default()) { | |||
| MS_LOG(DEBUG) << "Conv2D weight node is not parameter."; | |||
| return lite::RET_NO_CHANGE; | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| if (prim->GetAttr(ops::kFormat) == nullptr) { | |||
| MS_LOG(ERROR) << "current conv2d's format is undefined."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto default_param = weight_param->default_param(); | |||
| auto weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(default_param); | |||
| auto weight_shape = weight_tensor->shape(); | |||
| std::vector<int64_t> kernel_size = {weight_shape[kNumDim1], weight_shape[kNumDim2]}; | |||
| conv->set_kernel_size(kernel_size); | |||
| conv->set_in_channel(weight_shape[kNumDim3]); | |||
| conv->set_out_channel(weight_shape[kNumDim0]); | |||
| return lite::RET_OK; | |||
| } | |||
| lite::STATUS UpdateConv2DParamPass::UpdateDepthWiseConv2D(const CNodePtr &cnode) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| auto conv = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(cnode->input(0)); | |||
| if (conv == nullptr) { | |||
| MS_LOG(ERROR) << "cnode is invalid."; | |||
| auto format = static_cast<mindspore::Format>(GetValue<int64_t>(prim->GetAttr(ops::kFormat))); | |||
| if (format != mindspore::NHWC && format != mindspore::NCHW) { | |||
| MS_LOG(ERROR) << "conv2d's format only support nhwc or nchw, now is " << format; | |||
| return lite::RET_ERROR; | |||
| } | |||
| int64_t channel_in = conv->GetAttr(ops::kInChannel) != nullptr ? conv->get_in_channel() : -1; | |||
| if (channel_in == -1) { | |||
| auto input_node = cnode->input(kAnfPopulaterInputNumTwo); | |||
| MS_ASSERT(input_node != nullptr); | |||
| if (input_node->isa<Parameter>()) { | |||
| auto param_node = input_node->cast<ParameterPtr>(); | |||
| auto param = param_node->default_param(); | |||
| auto weight = std::dynamic_pointer_cast<tensor::Tensor>(param); | |||
| conv->set_in_channel(static_cast<int64_t>(weight->shape().at(0))); | |||
| } | |||
| auto kernel_size = format == mindspore::NHWC ? ShapeVector{shape[1], shape[kInputIndexTwo]} | |||
| : ShapeVector{shape[kInputIndexTwo], shape[kInputIndexThree]}; | |||
| int64_t in_channel = format == mindspore::NHWC ? shape[kInputIndexThree] : shape[1]; | |||
| int64_t out_channel = shape[0]; | |||
| if (prim->GetAttr(ops::kGroup) == nullptr) { | |||
| bool is_depth_wise = | |||
| prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise)); | |||
| prim->AddAttr(ops::kGroup, MakeValue(is_depth_wise ? out_channel : 1)); | |||
| } | |||
| auto group = GetValue<int64_t>(prim->GetAttr(ops::kGroup)); | |||
| if (CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) { | |||
| std::swap(in_channel, out_channel); | |||
| } | |||
| if (CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) { | |||
| in_channel *= group; | |||
| } else { | |||
| out_channel *= group; | |||
| } | |||
| SetConvAttr(prim, kernel_size, in_channel, out_channel); | |||
| return lite::RET_OK; | |||
| } | |||
| @@ -92,28 +104,17 @@ bool UpdateConv2DParamPass::Run(const FuncGraphPtr &func_graph) { | |||
| auto manager = func_graph->manager(); | |||
| MS_ASSERT(manager != nullptr); | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| int status = lite::RET_OK; | |||
| for (auto &node : node_list) { | |||
| if (!utils::isa<CNodePtr>(node)) { | |||
| continue; | |||
| } | |||
| if (!CheckPrimitiveType(node, prim::kPrimConv2DFusion)) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto conv = GetValueNode<std::shared_ptr<mindspore::ops::Conv2DFusion>>(cnode->input(0)); | |||
| if (conv == nullptr) { | |||
| MS_LOG(ERROR) << "Depthwise conv2D node has no primitiveC."; | |||
| return RET_ERROR; | |||
| } | |||
| if (conv->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(conv->GetAttr(ops::kIsDepthWise))) { | |||
| status = UpdateDepthWiseConv2D(cnode); | |||
| } else { | |||
| status = UpdateCommonConv2D(cnode); | |||
| } | |||
| if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "update con2d failed."; | |||
| return false; | |||
| if (CheckPrimitiveType(node, prim::kPrimConv2DFusion) || | |||
| CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) { | |||
| if (UpdateConv2DAttr(cnode) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "update conv2d attr failed."; | |||
| return false; | |||
| } | |||
| } | |||
| } | |||
| return true; | |||
| @@ -16,24 +16,19 @@ | |||
| #ifndef MINDSPORE_LITE_SRC_PASS_UPDATE_CONV2D_PARAM_PASS_H_ | |||
| #define MINDSPORE_LITE_SRC_PASS_UPDATE_CONV2D_PARAM_PASS_H_ | |||
| #include "schema/inner/model_generated.h" | |||
| #include "backend/optimizer/common/pass.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| using mindspore::converter::FmkType; | |||
| namespace mindspore::opt { | |||
| class UpdateConv2DParamPass : public Pass { | |||
| public: | |||
| UpdateConv2DParamPass() : Pass("update_conv2d_param_pass") {} | |||
| UpdateConv2DParamPass() : Pass("UpdateConv2DParamPass") {} | |||
| ~UpdateConv2DParamPass() override = default; | |||
| lite::STATUS UpdateCommonConv2D(const CNodePtr &cnode); | |||
| static lite::STATUS UpdateDepthWiseConv2D(const CNodePtr &cnode); | |||
| bool Run(const FuncGraphPtr &graph) override; | |||
| void SetFmkType(FmkType fmk_type) { this->fmk_type_ = fmk_type; } | |||
| private: | |||
| FmkType fmk_type_ = converter::kFmkTypeOnnx; | |||
| STATUS UpdateConv2DAttr(const CNodePtr &cnode); | |||
| }; | |||
| } // namespace mindspore::opt | |||
| #endif // MINDSPORE_LITE_SRC_PASS_UPDATE_CONV2D_PARAM_PASS_H_ | |||