From 08068e6a7a7bbc831c61400c2f1f57d5837fe784 Mon Sep 17 00:00:00 2001
From: zhang__sss
Date: Wed, 21 Apr 2021 16:28:20 +0800
Subject: [PATCH] bias add fusion

---
 .../lite/tools/optimizer/common/gllo_utils.cc |  34 ++++
 .../lite/tools/optimizer/common/gllo_utils.h  |   4 +
 .../optimizer/fusion/conv_biasadd_fusion.cc   | 149 ++++++++++++++----
 3 files changed, 160 insertions(+), 27 deletions(-)

diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc
index bbc98c9fec..4ccc20330c 100644
--- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc
+++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc
@@ -417,6 +417,15 @@ int CheckIfNodeIsParam(const AnfNodePtr &node) {
   return lite::RET_OK;
 }
 
+int CheckIfNodeIsParamOrValue(const AnfNodePtr &node) {
+  if (node == nullptr || (!utils::isa<ParameterPtr>(node) && !utils::isa<ValueNodePtr>(node))) {
+    MS_LOG(DEBUG) << "The node is not a param or value node.";
+    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
+    return lite::RET_INVALID_OP_ATTR;
+  }
+  return lite::RET_OK;
+}
+
 int CheckInputSize(const CNodePtr &node, const int size) {
   if (static_cast<int>(node->inputs().size()) != size) {
     MS_LOG(ERROR) << "The input size of node must be " << size << ", but it is " << node->inputs().size();
@@ -534,6 +543,31 @@ bool IsParamNode(const BaseRef &n) {
   return tensor->data_c() != nullptr;
 }
 
+bool IsParamOrValueNodeWithData(const BaseRef &n) {
+  if (utils::isa<ValueNodePtr>(n)) {
+    auto value_node = utils::cast<ValueNodePtr>(n);
+    auto value = value_node->value();
+    if (value->isa<tensor::Tensor>()) {
+      auto tensor = value->cast<tensor::TensorPtr>();
+      if (tensor == nullptr || tensor->data_c() == nullptr) {
+        return false;
+      }
+      return true;
+    } else {
+      return false;
+    }
+  }
+  if (utils::isa<ParameterPtr>(n)) {
+    auto param = utils::cast<ParameterPtr>(n)->default_param();
+    auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param);
+    if (tensor == nullptr || tensor->data_c() == nullptr) {
+      return false;
+    }
+    return true;
+  }
+  return false;
+}
+
 bool IsConvNode(const BaseRef &n) {
   if (utils::isa<AnfNodePtr>(n)) {
     auto anf_node = utils::cast<AnfNodePtr>(n);
diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.h b/mindspore/lite/tools/optimizer/common/gllo_utils.h
index 3d86d71fed..aece7064d3 100644
--- a/mindspore/lite/tools/optimizer/common/gllo_utils.h
+++ b/mindspore/lite/tools/optimizer/common/gllo_utils.h
@@ -63,6 +63,8 @@ int CheckInputSize(const CNodePtr &node, int size);
 int CheckIfNodeIsParam(const AnfNodePtr &node);
 
+int CheckIfNodeIsParamOrValue(const AnfNodePtr &node);
+
 int CheckLeastInputSize(const CNodePtr &node, int size);
 
 ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num,
                             const tensor::TensorPtr &weight_tensor);
@@ -70,6 +72,8 @@ ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num,
 bool IsParamNode(const BaseRef &n);
 
+bool IsParamOrValueNodeWithData(const BaseRef &n);
+
 bool IsConvNode(const BaseRef &n);
 
 bool IsPoolingNode(const BaseRef &n);
diff --git a/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc
index d664b8fde5..e9f65abaaa 100644
--- a/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc
@@ -39,6 +39,7 @@ bool IsConvExtendNode(const BaseRef &n) {
   }
   return false;
 }
+
 bool IsAddNode(const BaseRef &n) {
   if (utils::isa<AnfNodePtr>(n)) {
     auto anf_node = utils::cast<AnfNodePtr>(n);
@@ -71,6 +72,115 @@ int Get_Kenrnel_nums(const CNodePtr &conv_node) {
     return 0;
   }
 }
+
+int GetAddBiasData(const AnfNodePtr &bias_add_weight_node, const int &kernel_nums, float **add_bias_data) {
+  MS_ASSERT(bias_add_weight_node != nullptr);
+  MS_ASSERT(add_bias_data != nullptr);
+  MS_ASSERT(*add_bias_data != nullptr);
+  float *add_weight_data = nullptr;
+  ShapeVector add_weight_shape;
+  if (utils::isa<ParameterPtr>(bias_add_weight_node)) {
+    auto add_weight_param_node = bias_add_weight_node->cast<ParameterPtr>();
+    if (!add_weight_param_node->has_default() || add_weight_param_node->default_param() == nullptr) {
+      MS_LOG(ERROR) << "The bias parameter of " << bias_add_weight_node->fullname_with_scope() << " is nullptr.";
+      return lite::RET_ERROR;
+    }
+    auto add_weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(add_weight_param_node->default_param());
+    if (add_weight_tensor == nullptr) {
+      MS_LOG(ERROR) << "The bias data of parameter node " << bias_add_weight_node->fullname_with_scope()
+                    << " is not a TensorPtr.";
+      return lite::RET_ERROR;
+    }
+    add_weight_data = reinterpret_cast<float *>(add_weight_tensor->data_c());
+    MS_ASSERT(add_weight_data != nullptr);
+    add_weight_shape = add_weight_tensor->shape();
+  } else {
+    MS_ASSERT(utils::isa<ValueNodePtr>(bias_add_weight_node));
+    auto add_weight_value_node = bias_add_weight_node->cast<ValueNodePtr>();
+    auto add_weight_value = add_weight_value_node->value();
+    MS_ASSERT(add_weight_value != nullptr);
+    auto add_weight_tensor = add_weight_value->cast<tensor::TensorPtr>();
+    if (add_weight_tensor == nullptr) {
+      MS_LOG(ERROR) << "The bias data of value node " << bias_add_weight_node->fullname_with_scope()
+                    << " is not a TensorPtr.";
+      return lite::RET_ERROR;
+    }
+    add_weight_data = reinterpret_cast<float *>(add_weight_tensor->data_c());
+    MS_ASSERT(add_weight_data != nullptr);
+    auto value_abstract = add_weight_value_node->abstract();
+    auto value_abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(value_abstract);
+    add_weight_shape = utils::cast<abstract::ShapePtr>(value_abstract_tensor->BuildShape())->shape();
+  }
+  if (add_weight_shape.empty() || (add_weight_shape.size() == 1 && add_weight_shape[0] == 1)) {
+    for (int i = 0; i < kernel_nums; i++) {
+      (*add_bias_data)[i] = *add_weight_data;
+    }
+  } else {
+    if (EOK != memcpy_s(*add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) {
+      MS_LOG(ERROR) << "memcpy_s conv_bias_data failed";
+      return lite::RET_ERROR;
+    }
+  }
+  return lite::RET_OK;
+}
+
+int GetNewConvBiasData(const AnfNodePtr &conv_bias_node, const int &kernel_nums, const float *add_bias_data) {
+  MS_ASSERT(add_bias_data != nullptr);
+  MS_ASSERT(conv_bias_node != nullptr);
+  if (utils::isa<ParameterPtr>(conv_bias_node)) {
+    auto conv_bias_param_node = conv_bias_node->cast<ParameterPtr>();
+    if (!conv_bias_param_node->has_default() || conv_bias_param_node->default_param() == nullptr) {
+      MS_LOG(ERROR) << "The bias parameter of " << conv_bias_node->fullname_with_scope() << " is nullptr.";
+      return lite::RET_ERROR;
+    }
+    auto conv_bias_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_bias_param_node->default_param());
+    if (conv_bias_tensor == nullptr || conv_bias_tensor->shape().empty() ||
+        conv_bias_tensor->shape()[0] != kernel_nums) {
+      MS_LOG(ERROR) << "conv_bias_node shape error";
+      return lite::RET_ERROR;
+    }
+    auto conv_bias_data = reinterpret_cast<float *>(conv_bias_tensor->data_c());
+    MS_ASSERT(conv_bias_data != nullptr);
+    for (int i = 0; i < kernel_nums; i++) {
+      conv_bias_data[i] += add_bias_data[i];
+    }
+  } else {
+    MS_ASSERT(utils::isa<ValueNodePtr>(conv_bias_node));
+    auto conv_bias_value_node = conv_bias_node->cast<ValueNodePtr>();
+    auto conv_bias_value = conv_bias_value_node->value();
+    MS_ASSERT(conv_bias_value != nullptr);
+    auto conv_bias_tensor = conv_bias_value->cast<tensor::TensorPtr>();
+    if (conv_bias_tensor == nullptr) {
MS_LOG(ERROR) << "The bias data of value node " << conv_bias_node->fullname_with_scope() << "is not tensorPtr."; + return lite::RET_ERROR; + } + auto conv_bias_data = reinterpret_cast(conv_bias_tensor->data_c()); + MS_ASSERT(conv_bias_data != nullptr); + for (int i = 0; i < kernel_nums; i++) { + conv_bias_data[i] += add_bias_data[i]; + } + } + return lite::RET_OK; +} + +tensor::TensorPtr GetConvWeightTensor(const AnfNodePtr &conv_weight_node) { + tensor::TensorPtr conv_weight_tensor; + if (utils::isa(conv_weight_node)) { + auto conv_weight_value_node = conv_weight_node->cast(); + auto conv_weight_value = conv_weight_value_node->value(); + MS_ASSERT(conv_weight_value != nullptr); + conv_weight_tensor = conv_weight_value->cast(); + MS_ASSERT(conv_weight_tensor != nullptr); + } else { + MS_ASSERT(utils::isa(conv_weight_node)); + auto conv_weight_param = conv_weight_node->cast()->default_param(); + MS_ASSERT(conv_weight_param != nullptr); + conv_weight_tensor = std::dynamic_pointer_cast(conv_weight_param); + MS_ASSERT(conv_weight_tensor != nullptr); + } + return conv_weight_tensor; +} + int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, const CNodePtr &bias_node) { MS_ASSERT(func_graph != nullptr); MS_ASSERT(conv_node != nullptr); @@ -97,45 +207,30 @@ int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, co return lite::RET_MEMORY_FAILED; } auto bias_add_weight = bias_node->input(kAddWEIGHTINDEX); - if (CheckIfNodeIsParam(bias_add_weight) != lite::RET_OK) { + if (CheckIfNodeIsParamOrValue(bias_add_weight) != lite::RET_OK) { delete[] add_bias_data; return lite::RET_INVALID_OP_ATTR; } - auto add_weight_param = bias_add_weight->cast()->default_param(); - auto add_weight_tensor = std::dynamic_pointer_cast(add_weight_param); - auto add_weight_data = reinterpret_cast(add_weight_tensor->data_c()); - auto add_weight_shape = add_weight_tensor->shape(); - if (add_weight_shape.empty() || (add_weight_shape.size() == 1 && add_weight_shape[0] == 1)) { - for (int i = 0; i < kernel_nums; i++) { - add_bias_data[i] = *add_weight_data; - } - } else { - if (EOK != memcpy_s(add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) { - MS_LOG(ERROR) << "memcpy_s conv_bias_data failed"; - delete[] add_bias_data; - return lite::RET_MEMORY_FAILED; - } + if (GetAddBiasData(bias_add_weight, kernel_nums, &add_bias_data) != lite::RET_OK) { + delete[] add_bias_data; + return lite::RET_INVALID_OP_ATTR; } if (conv_bias_node != nullptr) { - if (CheckIfNodeIsParam(conv_bias_node) != lite::RET_OK) { + if (CheckIfNodeIsParamOrValue(conv_bias_node) != lite::RET_OK) { delete[] add_bias_data; return lite::RET_INVALID_OP_ATTR; } - auto conv_bias_param = conv_bias_node->cast()->default_param(); - auto conv_bias_tensor = std::dynamic_pointer_cast(conv_bias_param); - if (conv_bias_tensor->shape().empty() || conv_bias_tensor->shape()[0] != kernel_nums) { - MS_LOG(ERROR) << "conv_bias_node shape error"; + if (GetNewConvBiasData(conv_bias_node, kernel_nums, add_bias_data) != lite::RET_OK) { delete[] add_bias_data; return lite::RET_INVALID_OP_ATTR; } - auto conv_bias_data = reinterpret_cast(conv_bias_tensor->data_c()); - for (int i = 0; i < kernel_nums; i++) { - conv_bias_data[i] += add_bias_data[i]; - } delete[] add_bias_data; } else { - auto conv_weight_param = conv_weight_node->cast()->default_param(); - auto conv_weight_tensor = std::dynamic_pointer_cast(conv_weight_param); + if (CheckIfNodeIsParamOrValue(conv_weight_node) != lite::RET_OK) { + delete[] 
+      return lite::RET_INVALID_OP_ATTR;
+    }
+    tensor::TensorPtr conv_weight_tensor = GetConvWeightTensor(conv_weight_node);
     auto conv_new_bias = AddNewBiasNode(add_bias_data, func_graph, kernel_nums, conv_weight_tensor);
     conv_new_bias->set_name(conv_node->fullname_with_scope() + "_bias");
     conv_node->add_input(conv_new_bias);
@@ -146,7 +241,7 @@ int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, const CNodePtr &bias_node) {
 
 const BaseRef ConvBiasaddFusion::DefinePattern() const {
   auto conv_var = std::make_shared<CondVar>(IsConvExtendNode);
   auto add_var = std::make_shared<CondVar>(IsAddNode);
-  auto weight_var = std::make_shared<CondVar>(IsParamNode);
+  auto weight_var = std::make_shared<CondVar>(IsParamOrValueNodeWithData);
   return VectorRef({add_var, conv_var, weight_var});
 }
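
Note: the transformation this patch performs reduces to plain bias folding. Because
Conv(x; W, b_conv) + b_add == Conv(x; W, b_conv + b_add), the weight of the trailing
BiasAdd can be accumulated into the convolution's bias input and the BiasAdd node
dropped from the graph. The snippet below is a minimal standalone sketch of that
arithmetic only (the FoldBiasAdd helper is hypothetical, not MindSpore code). It
mirrors the two weight shapes GetAddBiasData accepts, a scalar broadcast to every
output channel or a per-channel vector of length kernel_nums, and the case handled
in GenConvNewBias where the conv has no bias yet and a fresh bias is created.

#include <cstdio>
#include <vector>

// Fold a BiasAdd weight into a conv bias. conv_bias is empty when the conv has no
// bias input yet (the pass then creates a new bias parameter, here a zero vector);
// otherwise it is assumed to have kernel_nums entries, as the pass checks.
// add_weight has size 1 (scalar broadcast) or kernel_nums (one per output channel).
std::vector<float> FoldBiasAdd(const std::vector<float> &conv_bias,
                               const std::vector<float> &add_weight, int kernel_nums) {
  std::vector<float> fused = conv_bias.empty() ? std::vector<float>(kernel_nums, 0.0f) : conv_bias;
  for (int i = 0; i < kernel_nums; ++i) {
    fused[i] += add_weight.size() == 1 ? add_weight[0] : add_weight[i];
  }
  return fused;
}

int main() {
  // A conv with an existing per-channel bias fused with a scalar BiasAdd weight.
  for (float v : FoldBiasAdd({0.5f, -1.0f, 2.0f}, {0.25f}, 3)) {
    printf("%g ", v);  // prints: 0.75 -0.75 2.25
  }
  printf("\n");
  return 0;
}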