Browse Source

bias add fusion

pull/15507/head
zhang__sss 4 years ago
parent
commit
08068e6a7a
3 changed files with 160 additions and 27 deletions
  1. +34
    -0
      mindspore/lite/tools/optimizer/common/gllo_utils.cc
  2. +4
    -0
      mindspore/lite/tools/optimizer/common/gllo_utils.h
  3. +122
    -27
      mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc

+ 34
- 0
mindspore/lite/tools/optimizer/common/gllo_utils.cc View File

@@ -417,6 +417,15 @@ int CheckIfNodeIsParam(const AnfNodePtr &node) {
return lite::RET_OK;
}

int CheckIfNodeIsParamOrValue(const AnfNodePtr &node) {
if (node == nullptr || (node != nullptr && !utils::isa<ParameterPtr>(node) && !utils::isa<ValueNode>(node))) {
MS_LOG(DEBUG) << "The Node is not param or value node.";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
return lite::RET_INVALID_OP_ATTR;
}
return lite::RET_OK;
}

int CheckInputSize(const CNodePtr &node, const int size) {
if (static_cast<int>(node->inputs().size()) != size) {
MS_LOG(ERROR) << "The input size of node must be " << size << ", but it is" << node->inputs().size();
@@ -534,6 +543,31 @@ bool IsParamNode(const BaseRef &n) {
return tensor->data_c() != nullptr;
}

bool IsParamOrValueNodeWithData(const BaseRef &n) {
if (utils::isa<ValueNode>(n)) {
auto value_node = utils::cast<ValueNodePtr>(n);
auto value = value_node->value();
if (value->isa<tensor::Tensor>()) {
auto tensor = value->cast<tensor::TensorPtr>();
if (tensor == nullptr || tensor->data_c() == nullptr) {
return false;
}
return true;
} else {
return false;
}
}
if (utils::isa<ParameterPtr>(n)) {
auto param = utils::cast<ParameterPtr>(n)->default_param();
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param);
if (tensor == nullptr || tensor->data_c() == nullptr) {
return false;
}
return true;
}
return false;
}

bool IsConvNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
auto anf_node = utils::cast<AnfNodePtr>(n);


+ 4
- 0
mindspore/lite/tools/optimizer/common/gllo_utils.h View File

@@ -63,6 +63,8 @@ int CheckInputSize(const CNodePtr &node, int size);

int CheckIfNodeIsParam(const AnfNodePtr &node);

int CheckIfNodeIsParamOrValue(const AnfNodePtr &node);

int CheckLeastInputSize(const CNodePtr &node, int size);

ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num,
@@ -70,6 +72,8 @@ ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, in

bool IsParamNode(const BaseRef &n);

bool IsParamOrValueNodeWithData(const BaseRef &n);

bool IsConvNode(const BaseRef &n);

bool IsPoolingNode(const BaseRef &n);


+ 122
- 27
mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc View File

@@ -39,6 +39,7 @@ bool IsConvExtendNode(const BaseRef &n) {
}
return false;
}

bool IsAddNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
auto anf_node = utils::cast<AnfNodePtr>(n);
@@ -71,6 +72,115 @@ int Get_Kenrnel_nums(const CNodePtr &conv_node) {
return 0;
}
}

int GetAddBiasData(const AnfNodePtr &bias_add_weight_node, const int &kernel_nums, float **add_bias_data) {
MS_ASSERT(bias_add_weight_node != nullptr);
MS_ASSERT(add_bias_data != nullptr);
MS_ASSERT(*add_bias_data != nullptr);
float *add_weight_data = nullptr;
ShapeVector add_weight_shape;
if (utils::isa<Parameter>(bias_add_weight_node)) {
auto add_weight_param_node = bias_add_weight_node->cast<ParameterPtr>();
if (!add_weight_param_node->has_default() || add_weight_param_node->default_param() == nullptr) {
MS_LOG(ERROR) << "The bias parameter of " << bias_add_weight_node->fullname_with_scope() << " is nullptr.";
return lite::RET_ERROR;
}
auto add_weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(add_weight_param_node->default_param());
if (add_weight_tensor == nullptr) {
MS_LOG(ERROR) << "The bias data of parameter node " << bias_add_weight_node->fullname_with_scope()
<< " is not tensorPtr.";
return lite::RET_ERROR;
}
add_weight_data = reinterpret_cast<float *>(add_weight_tensor->data_c());
MS_ASSERT(add_weight_data != nullptr);
add_weight_shape = add_weight_tensor->shape();
} else {
MS_ASSERT(utils::isa<ValueNode>(bias_add_weight_node));
auto add_weight_value_node = bias_add_weight_node->cast<ValueNodePtr>();
auto add_weight_value = add_weight_value_node->value();
MS_ASSERT(add_weight_value != nullptr);
auto add_weight_tensor = add_weight_value->cast<tensor::TensorPtr>();
if (add_weight_tensor == nullptr) {
MS_LOG(ERROR) << "The bias data of value node " << bias_add_weight_node->fullname_with_scope()
<< " is not tensorPtr.";
return lite::RET_ERROR;
}
add_weight_data = reinterpret_cast<float *>(add_weight_tensor->data_c());
MS_ASSERT(add_weight_data != nullptr);
auto value_abstract = add_weight_value_node->abstract();
auto value_abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(value_abstract);
add_weight_shape = utils::cast<abstract::ShapePtr>(value_abstract_tensor->BuildShape())->shape();
}
if (add_weight_shape.empty() || (add_weight_shape.size() == 1 && add_weight_shape[0] == 1)) {
for (int i = 0; i < kernel_nums; i++) {
(*add_bias_data)[i] = *add_weight_data;
}
} else {
if (EOK != memcpy_s(*add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) {
MS_LOG(ERROR) << "memcpy_s conv_bias_data failed";
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}

int GetNewConvBiasData(const AnfNodePtr &conv_bias_node, const int &kernel_nums, const float *add_bias_data) {
MS_ASSERT(add_bias_data != nullptr);
MS_ASSERT(conv_bias_node != nullptr);
if (utils::isa<Parameter>(conv_bias_node)) {
auto conv_bias_param_node = conv_bias_node->cast<ParameterPtr>();
if (!conv_bias_param_node->has_default() || conv_bias_param_node->default_param() == nullptr) {
MS_LOG(ERROR) << "The bias parameter of " << conv_bias_node->fullname_with_scope() << " is nullptr.";
return lite::RET_ERROR;
}
auto conv_bias_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_bias_param_node->default_param());
if (conv_bias_tensor == nullptr || conv_bias_tensor->shape().empty() ||
conv_bias_tensor->shape()[0] != kernel_nums) {
MS_LOG(ERROR) << "conv_bias_node shape error";
return lite::RET_ERROR;
}
auto conv_bias_data = reinterpret_cast<float *>(conv_bias_tensor->data_c());
MS_ASSERT(conv_bias_data != nullptr);
for (int i = 0; i < kernel_nums; i++) {
conv_bias_data[i] += add_bias_data[i];
}
} else {
MS_ASSERT(utils::isa<ValueNode>(conv_bias_node));
auto conv_bias_value_node = conv_bias_node->cast<ValueNodePtr>();
auto conv_bias_value = conv_bias_value_node->value();
MS_ASSERT(conv_bias_value != nullptr);
auto conv_bias_tensor = conv_bias_value->cast<tensor::TensorPtr>();
if (conv_bias_tensor == nullptr) {
MS_LOG(ERROR) << "The bias data of value node " << conv_bias_node->fullname_with_scope() << "is not tensorPtr.";
return lite::RET_ERROR;
}
auto conv_bias_data = reinterpret_cast<float *>(conv_bias_tensor->data_c());
MS_ASSERT(conv_bias_data != nullptr);
for (int i = 0; i < kernel_nums; i++) {
conv_bias_data[i] += add_bias_data[i];
}
}
return lite::RET_OK;
}

tensor::TensorPtr GetConvWeightTensor(const AnfNodePtr &conv_weight_node) {
tensor::TensorPtr conv_weight_tensor;
if (utils::isa<ValueNode>(conv_weight_node)) {
auto conv_weight_value_node = conv_weight_node->cast<ValueNodePtr>();
auto conv_weight_value = conv_weight_value_node->value();
MS_ASSERT(conv_weight_value != nullptr);
conv_weight_tensor = conv_weight_value->cast<tensor::TensorPtr>();
MS_ASSERT(conv_weight_tensor != nullptr);
} else {
MS_ASSERT(utils::isa<Parameter>(conv_weight_node));
auto conv_weight_param = conv_weight_node->cast<ParameterPtr>()->default_param();
MS_ASSERT(conv_weight_param != nullptr);
conv_weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_weight_param);
MS_ASSERT(conv_weight_tensor != nullptr);
}
return conv_weight_tensor;
}

int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, const CNodePtr &bias_node) {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(conv_node != nullptr);
@@ -97,45 +207,30 @@ int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, co
return lite::RET_MEMORY_FAILED;
}
auto bias_add_weight = bias_node->input(kAddWEIGHTINDEX);
if (CheckIfNodeIsParam(bias_add_weight) != lite::RET_OK) {
if (CheckIfNodeIsParamOrValue(bias_add_weight) != lite::RET_OK) {
delete[] add_bias_data;
return lite::RET_INVALID_OP_ATTR;
}
auto add_weight_param = bias_add_weight->cast<ParameterPtr>()->default_param();
auto add_weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(add_weight_param);
auto add_weight_data = reinterpret_cast<float *>(add_weight_tensor->data_c());
auto add_weight_shape = add_weight_tensor->shape();
if (add_weight_shape.empty() || (add_weight_shape.size() == 1 && add_weight_shape[0] == 1)) {
for (int i = 0; i < kernel_nums; i++) {
add_bias_data[i] = *add_weight_data;
}
} else {
if (EOK != memcpy_s(add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) {
MS_LOG(ERROR) << "memcpy_s conv_bias_data failed";
delete[] add_bias_data;
return lite::RET_MEMORY_FAILED;
}
if (GetAddBiasData(bias_add_weight, kernel_nums, &add_bias_data) != lite::RET_OK) {
delete[] add_bias_data;
return lite::RET_INVALID_OP_ATTR;
}
if (conv_bias_node != nullptr) {
if (CheckIfNodeIsParam(conv_bias_node) != lite::RET_OK) {
if (CheckIfNodeIsParamOrValue(conv_bias_node) != lite::RET_OK) {
delete[] add_bias_data;
return lite::RET_INVALID_OP_ATTR;
}
auto conv_bias_param = conv_bias_node->cast<ParameterPtr>()->default_param();
auto conv_bias_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_bias_param);
if (conv_bias_tensor->shape().empty() || conv_bias_tensor->shape()[0] != kernel_nums) {
MS_LOG(ERROR) << "conv_bias_node shape error";
if (GetNewConvBiasData(conv_bias_node, kernel_nums, add_bias_data) != lite::RET_OK) {
delete[] add_bias_data;
return lite::RET_INVALID_OP_ATTR;
}
auto conv_bias_data = reinterpret_cast<float *>(conv_bias_tensor->data_c());
for (int i = 0; i < kernel_nums; i++) {
conv_bias_data[i] += add_bias_data[i];
}
delete[] add_bias_data;
} else {
auto conv_weight_param = conv_weight_node->cast<ParameterPtr>()->default_param();
auto conv_weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_weight_param);
if (CheckIfNodeIsParamOrValue(conv_weight_node) != lite::RET_OK) {
delete[] add_bias_data;
return lite::RET_INVALID_OP_ATTR;
}
tensor::TensorPtr conv_weight_tensor = GetConvWeightTensor(conv_weight_node);
auto conv_new_bias = AddNewBiasNode(add_bias_data, func_graph, kernel_nums, conv_weight_tensor);
conv_new_bias->set_name(conv_node->fullname_with_scope() + "_bias");
conv_node->add_input(conv_new_bias);
@@ -146,7 +241,7 @@ int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, co
const BaseRef ConvBiasaddFusion::DefinePattern() const {
auto conv_var = std::make_shared<CondVar>(IsConvExtendNode);
auto add_var = std::make_shared<CondVar>(IsAddNode);
auto weight_var = std::make_shared<CondVar>(IsParamNode);
auto weight_var = std::make_shared<CondVar>(IsParamOrValueNodeWithData);
return VectorRef({add_var, conv_var, weight_var});
}



Loading…
Cancel
Save