|
|
|
@@ -39,6 +39,7 @@ bool IsConvExtendNode(const BaseRef &n) { |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
bool IsAddNode(const BaseRef &n) { |
|
|
|
if (utils::isa<AnfNodePtr>(n)) { |
|
|
|
auto anf_node = utils::cast<AnfNodePtr>(n); |
|
|
|
@@ -71,6 +72,115 @@ int Get_Kenrnel_nums(const CNodePtr &conv_node) { |
|
|
|
return 0; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
int GetAddBiasData(const AnfNodePtr &bias_add_weight_node, const int &kernel_nums, float **add_bias_data) { |
|
|
|
MS_ASSERT(bias_add_weight_node != nullptr); |
|
|
|
MS_ASSERT(add_bias_data != nullptr); |
|
|
|
MS_ASSERT(*add_bias_data != nullptr); |
|
|
|
float *add_weight_data = nullptr; |
|
|
|
ShapeVector add_weight_shape; |
|
|
|
if (utils::isa<Parameter>(bias_add_weight_node)) { |
|
|
|
auto add_weight_param_node = bias_add_weight_node->cast<ParameterPtr>(); |
|
|
|
if (!add_weight_param_node->has_default() || add_weight_param_node->default_param() == nullptr) { |
|
|
|
MS_LOG(ERROR) << "The bias parameter of " << bias_add_weight_node->fullname_with_scope() << " is nullptr."; |
|
|
|
return lite::RET_ERROR; |
|
|
|
} |
|
|
|
auto add_weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(add_weight_param_node->default_param()); |
|
|
|
if (add_weight_tensor == nullptr) { |
|
|
|
MS_LOG(ERROR) << "The bias data of parameter node " << bias_add_weight_node->fullname_with_scope() |
|
|
|
<< " is not tensorPtr."; |
|
|
|
return lite::RET_ERROR; |
|
|
|
} |
|
|
|
add_weight_data = reinterpret_cast<float *>(add_weight_tensor->data_c()); |
|
|
|
MS_ASSERT(add_weight_data != nullptr); |
|
|
|
add_weight_shape = add_weight_tensor->shape(); |
|
|
|
} else { |
|
|
|
MS_ASSERT(utils::isa<ValueNode>(bias_add_weight_node)); |
|
|
|
auto add_weight_value_node = bias_add_weight_node->cast<ValueNodePtr>(); |
|
|
|
auto add_weight_value = add_weight_value_node->value(); |
|
|
|
MS_ASSERT(add_weight_value != nullptr); |
|
|
|
auto add_weight_tensor = add_weight_value->cast<tensor::TensorPtr>(); |
|
|
|
if (add_weight_tensor == nullptr) { |
|
|
|
MS_LOG(ERROR) << "The bias data of value node " << bias_add_weight_node->fullname_with_scope() |
|
|
|
<< " is not tensorPtr."; |
|
|
|
return lite::RET_ERROR; |
|
|
|
} |
|
|
|
add_weight_data = reinterpret_cast<float *>(add_weight_tensor->data_c()); |
|
|
|
MS_ASSERT(add_weight_data != nullptr); |
|
|
|
auto value_abstract = add_weight_value_node->abstract(); |
|
|
|
auto value_abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(value_abstract); |
|
|
|
add_weight_shape = utils::cast<abstract::ShapePtr>(value_abstract_tensor->BuildShape())->shape(); |
|
|
|
} |
|
|
|
if (add_weight_shape.empty() || (add_weight_shape.size() == 1 && add_weight_shape[0] == 1)) { |
|
|
|
for (int i = 0; i < kernel_nums; i++) { |
|
|
|
(*add_bias_data)[i] = *add_weight_data; |
|
|
|
} |
|
|
|
} else { |
|
|
|
if (EOK != memcpy_s(*add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) { |
|
|
|
MS_LOG(ERROR) << "memcpy_s conv_bias_data failed"; |
|
|
|
return lite::RET_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
return lite::RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
int GetNewConvBiasData(const AnfNodePtr &conv_bias_node, const int &kernel_nums, const float *add_bias_data) { |
|
|
|
MS_ASSERT(add_bias_data != nullptr); |
|
|
|
MS_ASSERT(conv_bias_node != nullptr); |
|
|
|
if (utils::isa<Parameter>(conv_bias_node)) { |
|
|
|
auto conv_bias_param_node = conv_bias_node->cast<ParameterPtr>(); |
|
|
|
if (!conv_bias_param_node->has_default() || conv_bias_param_node->default_param() == nullptr) { |
|
|
|
MS_LOG(ERROR) << "The bias parameter of " << conv_bias_node->fullname_with_scope() << " is nullptr."; |
|
|
|
return lite::RET_ERROR; |
|
|
|
} |
|
|
|
auto conv_bias_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_bias_param_node->default_param()); |
|
|
|
if (conv_bias_tensor == nullptr || conv_bias_tensor->shape().empty() || |
|
|
|
conv_bias_tensor->shape()[0] != kernel_nums) { |
|
|
|
MS_LOG(ERROR) << "conv_bias_node shape error"; |
|
|
|
return lite::RET_ERROR; |
|
|
|
} |
|
|
|
auto conv_bias_data = reinterpret_cast<float *>(conv_bias_tensor->data_c()); |
|
|
|
MS_ASSERT(conv_bias_data != nullptr); |
|
|
|
for (int i = 0; i < kernel_nums; i++) { |
|
|
|
conv_bias_data[i] += add_bias_data[i]; |
|
|
|
} |
|
|
|
} else { |
|
|
|
MS_ASSERT(utils::isa<ValueNode>(conv_bias_node)); |
|
|
|
auto conv_bias_value_node = conv_bias_node->cast<ValueNodePtr>(); |
|
|
|
auto conv_bias_value = conv_bias_value_node->value(); |
|
|
|
MS_ASSERT(conv_bias_value != nullptr); |
|
|
|
auto conv_bias_tensor = conv_bias_value->cast<tensor::TensorPtr>(); |
|
|
|
if (conv_bias_tensor == nullptr) { |
|
|
|
MS_LOG(ERROR) << "The bias data of value node " << conv_bias_node->fullname_with_scope() << "is not tensorPtr."; |
|
|
|
return lite::RET_ERROR; |
|
|
|
} |
|
|
|
auto conv_bias_data = reinterpret_cast<float *>(conv_bias_tensor->data_c()); |
|
|
|
MS_ASSERT(conv_bias_data != nullptr); |
|
|
|
for (int i = 0; i < kernel_nums; i++) { |
|
|
|
conv_bias_data[i] += add_bias_data[i]; |
|
|
|
} |
|
|
|
} |
|
|
|
return lite::RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
tensor::TensorPtr GetConvWeightTensor(const AnfNodePtr &conv_weight_node) { |
|
|
|
tensor::TensorPtr conv_weight_tensor; |
|
|
|
if (utils::isa<ValueNode>(conv_weight_node)) { |
|
|
|
auto conv_weight_value_node = conv_weight_node->cast<ValueNodePtr>(); |
|
|
|
auto conv_weight_value = conv_weight_value_node->value(); |
|
|
|
MS_ASSERT(conv_weight_value != nullptr); |
|
|
|
conv_weight_tensor = conv_weight_value->cast<tensor::TensorPtr>(); |
|
|
|
MS_ASSERT(conv_weight_tensor != nullptr); |
|
|
|
} else { |
|
|
|
MS_ASSERT(utils::isa<Parameter>(conv_weight_node)); |
|
|
|
auto conv_weight_param = conv_weight_node->cast<ParameterPtr>()->default_param(); |
|
|
|
MS_ASSERT(conv_weight_param != nullptr); |
|
|
|
conv_weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_weight_param); |
|
|
|
MS_ASSERT(conv_weight_tensor != nullptr); |
|
|
|
} |
|
|
|
return conv_weight_tensor; |
|
|
|
} |
|
|
|
|
|
|
|
int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, const CNodePtr &bias_node) { |
|
|
|
MS_ASSERT(func_graph != nullptr); |
|
|
|
MS_ASSERT(conv_node != nullptr); |
|
|
|
@@ -97,45 +207,30 @@ int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, co |
|
|
|
return lite::RET_MEMORY_FAILED; |
|
|
|
} |
|
|
|
auto bias_add_weight = bias_node->input(kAddWEIGHTINDEX); |
|
|
|
if (CheckIfNodeIsParam(bias_add_weight) != lite::RET_OK) { |
|
|
|
if (CheckIfNodeIsParamOrValue(bias_add_weight) != lite::RET_OK) { |
|
|
|
delete[] add_bias_data; |
|
|
|
return lite::RET_INVALID_OP_ATTR; |
|
|
|
} |
|
|
|
auto add_weight_param = bias_add_weight->cast<ParameterPtr>()->default_param(); |
|
|
|
auto add_weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(add_weight_param); |
|
|
|
auto add_weight_data = reinterpret_cast<float *>(add_weight_tensor->data_c()); |
|
|
|
auto add_weight_shape = add_weight_tensor->shape(); |
|
|
|
if (add_weight_shape.empty() || (add_weight_shape.size() == 1 && add_weight_shape[0] == 1)) { |
|
|
|
for (int i = 0; i < kernel_nums; i++) { |
|
|
|
add_bias_data[i] = *add_weight_data; |
|
|
|
} |
|
|
|
} else { |
|
|
|
if (EOK != memcpy_s(add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) { |
|
|
|
MS_LOG(ERROR) << "memcpy_s conv_bias_data failed"; |
|
|
|
delete[] add_bias_data; |
|
|
|
return lite::RET_MEMORY_FAILED; |
|
|
|
} |
|
|
|
if (GetAddBiasData(bias_add_weight, kernel_nums, &add_bias_data) != lite::RET_OK) { |
|
|
|
delete[] add_bias_data; |
|
|
|
return lite::RET_INVALID_OP_ATTR; |
|
|
|
} |
|
|
|
if (conv_bias_node != nullptr) { |
|
|
|
if (CheckIfNodeIsParam(conv_bias_node) != lite::RET_OK) { |
|
|
|
if (CheckIfNodeIsParamOrValue(conv_bias_node) != lite::RET_OK) { |
|
|
|
delete[] add_bias_data; |
|
|
|
return lite::RET_INVALID_OP_ATTR; |
|
|
|
} |
|
|
|
auto conv_bias_param = conv_bias_node->cast<ParameterPtr>()->default_param(); |
|
|
|
auto conv_bias_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_bias_param); |
|
|
|
if (conv_bias_tensor->shape().empty() || conv_bias_tensor->shape()[0] != kernel_nums) { |
|
|
|
MS_LOG(ERROR) << "conv_bias_node shape error"; |
|
|
|
if (GetNewConvBiasData(conv_bias_node, kernel_nums, add_bias_data) != lite::RET_OK) { |
|
|
|
delete[] add_bias_data; |
|
|
|
return lite::RET_INVALID_OP_ATTR; |
|
|
|
} |
|
|
|
auto conv_bias_data = reinterpret_cast<float *>(conv_bias_tensor->data_c()); |
|
|
|
for (int i = 0; i < kernel_nums; i++) { |
|
|
|
conv_bias_data[i] += add_bias_data[i]; |
|
|
|
} |
|
|
|
delete[] add_bias_data; |
|
|
|
} else { |
|
|
|
auto conv_weight_param = conv_weight_node->cast<ParameterPtr>()->default_param(); |
|
|
|
auto conv_weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_weight_param); |
|
|
|
if (CheckIfNodeIsParamOrValue(conv_weight_node) != lite::RET_OK) { |
|
|
|
delete[] add_bias_data; |
|
|
|
return lite::RET_INVALID_OP_ATTR; |
|
|
|
} |
|
|
|
tensor::TensorPtr conv_weight_tensor = GetConvWeightTensor(conv_weight_node); |
|
|
|
auto conv_new_bias = AddNewBiasNode(add_bias_data, func_graph, kernel_nums, conv_weight_tensor); |
|
|
|
conv_new_bias->set_name(conv_node->fullname_with_scope() + "_bias"); |
|
|
|
conv_node->add_input(conv_new_bias); |
|
|
|
@@ -146,7 +241,7 @@ int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, co |
|
|
|
const BaseRef ConvBiasaddFusion::DefinePattern() const { |
|
|
|
auto conv_var = std::make_shared<CondVar>(IsConvExtendNode); |
|
|
|
auto add_var = std::make_shared<CondVar>(IsAddNode); |
|
|
|
auto weight_var = std::make_shared<CondVar>(IsParamNode); |
|
|
|
auto weight_var = std::make_shared<CondVar>(IsParamOrValueNodeWithData); |
|
|
|
return VectorRef({add_var, conv_var, weight_var}); |
|
|
|
} |
|
|
|
|
|
|
|
|