Browse Source

!22431 [lite] optimizer: fix dynamic-format op judgement and add support for the Pad op

Merge pull request !22431 from 徐安越/primitive
tags/v1.5.0-rc1
i-robot Gitee 4 years ago
parent
commit
e3f3976543
21 changed files with 601 additions and 570 deletions
  1. +1
    -1
      mindspore/lite/src/lite_session.cc
  2. +15
    -5
      mindspore/lite/src/ops/populate/conv2d_populate.cc
  3. +14
    -3
      mindspore/lite/src/ops/populate/deconv2d_populate.cc
  4. +2
    -3
      mindspore/lite/tools/converter/anf_transform.cc
  5. +13
    -14
      mindspore/lite/tools/converter/parser/inputs_adjust.cc
  6. +12
    -3
      mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc
  7. +0
    -3
      mindspore/lite/tools/converter/parser/onnx/onnx_conv_transpose_parser.cc
  8. +4
    -5
      mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc
  9. +0
    -3
      mindspore/lite/tools/converter/parser/tf/tf_deconv_parser.cc
  10. +16
    -6
      mindspore/lite/tools/optimizer/common/format_utils.cc
  11. +2
    -1
      mindspore/lite/tools/optimizer/common/format_utils.h
  12. +6
    -4
      mindspore/lite/tools/optimizer/common/gllo_utils.cc
  13. +1
    -2
      mindspore/lite/tools/optimizer/common/gllo_utils.h
  14. +143
    -208
      mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc
  15. +4
    -0
      mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.h
  16. +1
    -1
      mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc
  17. +2
    -4
      mindspore/lite/tools/optimizer/graph/decrease_transpose_algo.cc
  18. +295
    -223
      mindspore/lite/tools/optimizer/graph/transpose_strategy.cc
  19. +0
    -7
      mindspore/lite/tools/optimizer/graph/transpose_strategy.h
  20. +67
    -66
      mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.cc
  21. +3
    -8
      mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.h

+ 1
- 1
mindspore/lite/src/lite_session.cc View File

@@ -609,6 +609,7 @@ int LiteSession::PrepareKernels(Model *model, bool use_mindrt_run) {

// initialize init_ref_count for subgraphs and kernels
for (auto *kernel : this->kernels_) {
kernel->InitOutTensorInitRefCount();
#ifndef DELEGATE_CLIP
if (kernel->desc().arch == kernel::kDelegate) {
continue;
@@ -617,7 +618,6 @@ int LiteSession::PrepareKernels(Model *model, bool use_mindrt_run) {
if (IsIsolatedSubGraph(kernel)) {
static_cast<kernel::SubGraphKernel *>(kernel)->InitInputTensorInitRefCount();
}
kernel->InitOutTensorInitRefCount();
}
AdjustModelOutputTensorInitRefCount(model);
for (auto kernel : this->kernels_) {


+ 15
- 5
mindspore/lite/src/ops/populate/conv2d_populate.cc View File

@@ -41,20 +41,30 @@ OpParameter *PopulateConvParameter(const void *prim) {
auto stride = value->stride();
auto pad_list = value->pad_list();
auto dilation = value->dilation();
if (kernel_size == nullptr || stride == nullptr || dilation == nullptr) {
if (kernel_size != nullptr) {
if (kernel_size->size() < kMinShapeSizeTwo) {
MS_LOG(ERROR) << "kernel size is invalid.";
free(param);
return nullptr;
}
param->kernel_h_ = static_cast<int>(*(kernel_size->begin()));
param->kernel_w_ = static_cast<int>(*(kernel_size->begin() + 1));
} else {
param->kernel_h_ = -1;
param->kernel_w_ = -1;
}
if (stride == nullptr || dilation == nullptr) {
MS_LOG(ERROR) << "kernel_size/stride/dilation is nullptr";
free(param);
return nullptr;
}
if (kernel_size->size() < kMinShapeSizeTwo || stride->size() < kMinShapeSizeTwo ||
dilation->size() < kMinShapeSizeTwo) {
if (stride->size() < kMinShapeSizeTwo || dilation->size() < kMinShapeSizeTwo) {
MS_LOG(ERROR) << "Invalid shape size!kernel_size size: " << kernel_size->size()
<< ", stride size: " << stride->size() << ", dilation size: " << dilation->size();
free(param);
return nullptr;
}
param->kernel_h_ = static_cast<int>(*(kernel_size->begin()));
param->kernel_w_ = static_cast<int>(*(kernel_size->begin() + 1));

param->group_ = static_cast<int>(value->group());
param->stride_h_ = static_cast<int>(*(stride->begin()));
param->stride_w_ = static_cast<int>(*(stride->begin() + 1));


+ 14
- 3
mindspore/lite/src/ops/populate/deconv2d_populate.cc View File

@@ -43,13 +43,24 @@ OpParameter *PopulateDeconvParameter(const void *prim) {
auto pad_list = value->pad_list();
auto dilation = value->dilation();
auto output_paddings = value->output_paddings();
if (kernel_size == nullptr || stride == nullptr || dilation == nullptr || output_paddings == nullptr) {
if (kernel_size != nullptr) {
if (kernel_size->size() < kMinShapeSizeTwo) {
MS_LOG(ERROR) << "kernel size is invalid.";
free(param);
return nullptr;
}
param->kernel_h_ = static_cast<int>(*(kernel_size->begin()));
param->kernel_w_ = static_cast<int>(*(kernel_size->begin() + 1));
} else {
param->kernel_h_ = -1;
param->kernel_w_ = -1;
}
if (stride == nullptr || dilation == nullptr || output_paddings == nullptr) {
MS_LOG(ERROR) << "nullptr";
free(param);
return nullptr;
}
if (kernel_size->size() < kMinShapeSizeTwo || stride->size() < kMinShapeSizeTwo ||
dilation->size() < kMinShapeSizeTwo) {
if (stride->size() < kMinShapeSizeTwo || dilation->size() < kMinShapeSizeTwo) {
MS_LOG(ERROR) << "Invalid shape size!kernel_size size: " << kernel_size->size()
<< ", stride size: " << stride->size() << ", dilation size: " << dilation->size()
<< ", output_paddings size:" << output_paddings->size();


+ 2
- 3
mindspore/lite/tools/converter/anf_transform.cc View File

@@ -217,8 +217,6 @@ int AnfTransform::RunGraphPass(const FuncGraphPtr &old_graph, const converter::F
int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);
auto infershape_pass = std::make_shared<opt::InferShapePass>(config->fmk, config->trainModel);
convert_pm->AddPass(infershape_pass);
convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>());
optimizer->AddPassManager(convert_pm);
if (optimizer->Optimize(old_graph) == nullptr) {
@@ -235,8 +233,9 @@ int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const converte
if (!config->trainModel) {
const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk));
}
auto infershape_pass = std::make_shared<opt::InferShapePass>(config->fmk, config->trainModel);
const_fold_pm->AddPass(infershape_pass);
auto update_conv2d_param_pass = std::make_shared<opt::UpdateConv2DParamPass>();
update_conv2d_param_pass->SetFmkType(config->fmk);
const_fold_pm->AddPass(update_conv2d_param_pass);
optimizer->AddPassManager(const_fold_pm);
if (optimizer->Optimize(old_graph) == nullptr) {


+ 13
- 14
mindspore/lite/tools/converter/parser/inputs_adjust.cc View File

@@ -37,42 +37,38 @@ STATUS InputAdjust::AddAttrToInput(const FuncGraphPtr &func_graph, const CNodePt
MS_LOG(DEBUG) << "there is no attr :" << attr_name;
return lite::RET_NO_CHANGE;
}
auto inputs = cnode->inputs();
if (static_cast<int>(inputs.size()) > input_num) {
if (static_cast<int>(cnode->size()) > input_num) {
primitive_c->EraseAttr(attr_name);
MS_LOG(DEBUG) << "input num has been meet, which is " << inputs.size();
MS_LOG(DEBUG) << "input num has been meet, which is " << cnode->size();
return lite::RET_OK;
} else if (static_cast<int>(inputs.size()) < input_num) {
} else if (static_cast<int>(cnode->size()) < input_num) {
MS_LOG(ERROR) << "input num is invalid.";
return lite::RET_ERROR;
}
AnfNodePtr param_node;
switch (flag) {
case 1: {
auto value_data = opt::CastToInt(value_ptr).front();
auto param_node =
param_node =
opt::BuildIntValueParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name);
inputs.push_back(param_node);
break;
}
case kBuildInputFlagTwo: {
auto value_data = opt::CastToInt(value_ptr);
auto param_node =
param_node =
opt::BuildIntVecParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name);
inputs.push_back(param_node);
break;
}
case kBuildInputFlagThree: {
auto value_data = opt::CastToVec2DInt(value_ptr);
auto param_node =
param_node =
opt::BuildIntVec2DParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name);
inputs.push_back(param_node);
break;
}
case kBuildInputFlagFour: {
auto value_data = GetValue<float>(value_ptr);
auto param_node =
param_node =
opt::BuildFloatValueParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name);
inputs.push_back(param_node);
break;
}
default: {
@@ -80,8 +76,11 @@ STATUS InputAdjust::AddAttrToInput(const FuncGraphPtr &func_graph, const CNodePt
return lite::RET_ERROR;
}
}
cnode->set_inputs(inputs);

auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
auto tr = manager->Transact();
tr.AddEdge(cnode, param_node);
tr.Commit();
return lite::RET_OK;
}



+ 12
- 3
mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc View File

@@ -124,6 +124,7 @@ STATUS GetConvChannel(const onnx::GraphProto &onnx_graph, const onnx::NodeProto
[onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; });
if (node_iter == onnx_graph.initializer().end()) {
MS_LOG(WARNING) << "not find node: " << onnx_conv_weight;
return RET_NO_CHANGE;
} else {
std::vector<int> weight_shape;
auto size = (*node_iter).dims_size();
@@ -151,6 +152,12 @@ STATUS GetConvChannel(const onnx::GraphProto &onnx_graph, const onnx::NodeProto
return RET_ERROR;
}
dims.insert(dims.begin(), iter->ints().begin(), iter->ints().end());
} else {
return RET_NO_CHANGE;
}
if (dims.size() < kNumDim4) {
MS_LOG(ERROR) << "conv weight size is not 4D, please check.";
return RET_ERROR;
}
*channel_out = dims.at(0);
*channel_in = dims.at(3) * group;
@@ -211,11 +218,13 @@ ops::PrimitiveC *OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const
}

// get channel_out and channel_in
if (GetConvChannel(onnx_graph, onnx_node, group, &channel_out, &channel_in) != RET_OK) {
auto status = GetConvChannel(onnx_graph, onnx_node, group, &channel_out, &channel_in);
if (status == RET_OK) {
prim->set_in_channel(channel_in);
prim->set_out_channel(channel_out);
} else if (status != RET_NO_CHANGE) {
return nullptr;
}
prim->set_in_channel(channel_in);
prim->set_out_channel(channel_out);

// parse activationType
prim->set_activation_type(mindspore::ActivationType::NO_ACTIVATION);


+ 0
- 3
mindspore/lite/tools/converter/parser/onnx/onnx_conv_transpose_parser.cc View File

@@ -77,9 +77,6 @@ ops::PrimitiveC *OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, con
std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
[onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; });
if (node_iter == onnx_graph.initializer().end()) {
// in_channel and out_channel are set to 1 by default.
prim->set_in_channel(1);
prim->set_out_channel(1);
MS_LOG(WARNING) << "parsing of channelIn/Out is delayed.";
} else {
std::vector<int> weight_shape;


+ 4
- 5
mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc View File

@@ -60,9 +60,6 @@ ops::PrimitiveC *TFConvParser::Parse(const tensorflow::NodeDef &tf_op,
prim->set_out_channel(kernels[3]);
prim->set_in_channel(kernels[2]);
} else {
prim->set_kernel_size({0, 0});
prim->set_out_channel(1);
prim->set_in_channel(1);
MS_LOG(WARNING) << "parsing of kernelH/W channelIn/Out is delayed";
}

@@ -84,8 +81,10 @@ ops::PrimitiveC *TFConvParser::Parse(const tensorflow::NodeDef &tf_op,
}
if (tf_op.op() == "DepthwiseConv2dNative") {
prim->AddAttr(ops::kIsDepthWise, MakeValue<bool>(true));
prim->set_group(prim->get_in_channel());
prim->set_out_channel(prim->get_in_channel());
if (prim->GetAttr(ops::kInChannel) != nullptr) {
prim->set_group(prim->get_in_channel());
prim->set_out_channel(prim->get_in_channel());
}
}

return prim.release();


+ 0
- 3
mindspore/lite/tools/converter/parser/tf/tf_deconv_parser.cc View File

@@ -60,9 +60,6 @@ ops::PrimitiveC *TFDeconvParser::Parse(const tensorflow::NodeDef &tf_op,
prim->set_out_channel(kernels[2]);
prim->set_in_channel(kernels[3]);
} else {
prim->set_kernel_size({-1, -1});
prim->set_out_channel(-1);
prim->set_in_channel(-1);
MS_LOG(WARNING) << "parsing of kernelH/W channelIn/Out is delayed";
}



+ 16
- 6
mindspore/lite/tools/optimizer/common/format_utils.cc View File

@@ -39,6 +39,7 @@
#include "ops/fusion/div_fusion.h"
#include "ops/fusion/max_pool_fusion.h"
#include "ops/fusion/mul_fusion.h"
#include "ops/fusion/pad_fusion.h"
#include "ops/fusion/pow_fusion.h"
#include "ops/fusion/prelu_fusion.h"
#include "ops/fusion/slice_fusion.h"
@@ -98,18 +99,27 @@ static const std::unordered_map<std::string, std::vector<size_t>> NHWCOpMap = {

static const std::unordered_map<std::string, std::vector<size_t>> NCHWOpMap = {{ops::kNameInstanceNorm, {1}}};

// ops whose input format is not fixed.
static const std::vector<std::string> DynamicFormatOpList = {
ops::kNameEltwise, ops::kNameActivation, ops::kNameConcat, ops::kNameDivFusion, ops::kNamePowFusion,
ops::kNameStridedSlice, ops::kNameAddFusion, ops::kNameAddN, ops::kNameSplit, ops::kNameSliceFusion,
ops::kNameCrop, ops::kNameMulFusion, ops::kNameMaximum, ops::kNameActivationGrad, ops::kNameQuantDTypeCast};
// ops whose input format is not fixed; the bool value indicates whether the op has an axis attribute.
static const std::unordered_map<std::string, bool> DynamicFormatOpList = {
{ops::kNameAddN, false}, {ops::kNameCrop, true}, {ops::kNameSplit, true},
{ops::kNameConcat, true}, {ops::kNameEltwise, false}, {ops::kNameMaximum, false},
{ops::kNameAddFusion, false}, {ops::kNameDivFusion, false}, {ops::kNameMulFusion, false},
{ops::kNamePadFusion, false}, {ops::kNamePowFusion, false}, {ops::kNameActivation, false},
{ops::kNameSliceFusion, true}, {ops::kNameStridedSlice, true}, {ops::kNameActivationGrad, false},
{ops::kNameQuantDTypeCast, false}};

static const std::unordered_map<int, int> NC2NHAxisMap = {{0, 0}, {1, 3}, {2, 1}, {3, 2}};

const std::unordered_map<std::string, std::vector<size_t>> &GetNHWCOpMap() { return NHWCOpMap; }
const std::unordered_map<std::string, std::vector<size_t>> &GetNCHWOpMap() { return NCHWOpMap; }
const std::unordered_map<int, int> &GetNC2NHAxisMap() { return NC2NHAxisMap; }
const std::vector<std::string> &GetDynamicFormatOpList() { return DynamicFormatOpList; }
bool IsDynamicFormatOp(const std::string &op_type) {
return DynamicFormatOpList.find(op_type) != DynamicFormatOpList.end();
}
bool IsDynamicFormatOpWithAxis(const std::string &op_type) {
auto iter = DynamicFormatOpList.find(op_type);
return iter != DynamicFormatOpList.end() && iter->second;
}

Format GetFormat(const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);


+ 2
- 1
mindspore/lite/tools/optimizer/common/format_utils.h View File

@@ -34,8 +34,9 @@ struct TransTypePair {
};
const std::unordered_map<std::string, std::vector<size_t>> &GetNHWCOpMap();
const std::unordered_map<std::string, std::vector<size_t>> &GetNCHWOpMap();
const std::unordered_map<int, int> &GetNC2NHAxisMap();
const std::vector<std::string> &GetDynamicFormatOpList();
bool IsDynamicFormatOp(const std::string &op_type);
bool IsDynamicFormatOpWithAxis(const std::string &op_type);
Format GetFormat(const CNodePtr &cnode);
STATUS GetTransposePerm(const CNodePtr &cnode, std::vector<int> *perm);
void RemoveIfMonad(const CNodePtr &cnode);


+ 6
- 4
mindspore/lite/tools/optimizer/common/gllo_utils.cc View File

@@ -496,13 +496,12 @@ int CheckLeastInputSize(const CNodePtr &node, const int size) {
return lite::RET_OK;
}

ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num,
const tensor::TensorPtr &weight_tensor) {
ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, TypeId type_id) {
auto bias_parameter = func_graph->add_parameter();
MS_ASSERT(bias_parameter != nullptr);
std::vector<int64_t> shape_vector = {kernel_num};
auto tensor_info = lite::CreateTensorInfo(bias_data, kernel_num * sizeof(float) / sizeof(uint8_t), shape_vector,
weight_tensor->data_type());
auto tensor_info =
lite::CreateTensorInfo(bias_data, kernel_num * sizeof(float) / sizeof(uint8_t), shape_vector, type_id);
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "create tensor info failed.";
return nullptr;
@@ -613,6 +612,9 @@ bool IsParamOrValueNodeWithData(const BaseRef &n) {
}
}
if (utils::isa<ParameterPtr>(n)) {
if (!utils::cast<ParameterPtr>(n)->has_default()) {
return false;
}
auto param = utils::cast<ParameterPtr>(n)->default_param();
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param);
if (tensor == nullptr || tensor->data_c() == nullptr) {


+ 1
- 2
mindspore/lite/tools/optimizer/common/gllo_utils.h View File

@@ -82,8 +82,7 @@ int CheckIfNodeIsParamOrValue(const AnfNodePtr &node);

int CheckLeastInputSize(const CNodePtr &node, int size);

ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num,
const tensor::TensorPtr &weight_tensor);
ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, TypeId type_id);

bool IsParamNode(const BaseRef &n);



+ 143
- 208
mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc View File

@@ -14,12 +14,14 @@
* limitations under the License.
*/
#include "tools/optimizer/fusion/conv_biasadd_fusion.h"
#include <functional>
#include <memory>
#include <vector>
#include "ops/fusion/add_fusion.h"
#include "ops/fusion/conv2d_fusion.h"
#include "ops/fusion/conv2d_transpose_fusion.h"
#include "tools/common/tensor_util.h"
#include "utils/utils.h"
#include "tools/anf_exporter/fetch_content.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"

@@ -48,238 +50,171 @@ bool IsAddNode(const BaseRef &n) {
return false;
}

int Get_Kenrnel_nums(const CNodePtr &conv_node) {
MS_ASSERT(conv_node != nullptr);
auto value_primitive = conv_node->input(0);
auto value_node = value_primitive->cast<ValueNodePtr>();
MS_ASSERT(value_node != nullptr);
auto value = value_node->value();
MS_ASSERT(value != nullptr);
auto primitive = value->cast<PrimitiveCPtr>();
MS_ASSERT(primitive != nullptr);
if (primitive->isa<mindspore::ops::Conv2DFusion>()) {
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::ops::Conv2DFusion>>(primitive));
auto primc = utils::cast<std::shared_ptr<mindspore::ops::Conv2DFusion>>(primitive);
MS_ASSERT(primc != nullptr);
return primc->get_out_channel();
} else if (primitive->isa<mindspore::ops::Conv2dTransposeFusion>()) {
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::ops::Conv2dTransposeFusion>>(primitive));
auto primc = utils::cast<std::shared_ptr<mindspore::ops::Conv2dTransposeFusion>>(primitive);
MS_ASSERT(primc != nullptr);
return primc->get_out_channel();
} else {
MS_LOG(ERROR) << "Unsupported opType, " << primitive->name();
return 0;
}
}

int GetAddBiasData(const AnfNodePtr &bias_add_weight_node, const int &kernel_nums, float **add_bias_data) {
MS_ASSERT(bias_add_weight_node != nullptr);
MS_ASSERT(add_bias_data != nullptr);
MS_ASSERT(*add_bias_data != nullptr);
float *add_weight_data = nullptr;
ShapeVector add_weight_shape;
if (utils::isa<Parameter>(bias_add_weight_node)) {
auto add_weight_param_node = bias_add_weight_node->cast<ParameterPtr>();
if (!add_weight_param_node->has_default() || add_weight_param_node->default_param() == nullptr) {
MS_LOG(ERROR) << "The bias parameter of " << bias_add_weight_node->fullname_with_scope() << " is nullptr.";
return lite::RET_ERROR;
bool FuseBias(const lite::DataInfo &add_bias, const lite::DataInfo &conv_bias, std::vector<float> *fusion_bias,
int out_channel) {
MS_ASSERT(conv_bias != nullptr);
if ((add_bias.data_type_ != TypeId::kNumberTypeFloat32 && add_bias.data_type_ != TypeId::kNumberTypeFloat) ||
add_bias.data_.empty()) {
return false;
}
if (out_channel <= 0) {
return false;
}
std::vector<float> add_bias_data(add_bias.data_.size() / sizeof(float));
if (memcpy_s(add_bias_data.data(), add_bias.data_.size(), add_bias.data_.data(), add_bias.data_.size()) != EOK) {
return false;
}
fusion_bias->resize(out_channel, 0);
if (!conv_bias.data_.empty()) {
if (conv_bias.data_type_ != TypeId::kNumberTypeFloat32 && conv_bias.data_type_ != TypeId::kNumberTypeFloat &&
conv_bias.data_.size() != out_channel * sizeof(float)) {
return false;
}
auto add_weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(add_weight_param_node->default_param());
if (add_weight_tensor == nullptr) {
MS_LOG(ERROR) << "The bias data of parameter node " << bias_add_weight_node->fullname_with_scope()
<< " is not tensorPtr.";
return lite::RET_ERROR;
if (memcpy_s(fusion_bias->data(), conv_bias.data_.size(), conv_bias.data_.data(), conv_bias.data_.size()) != EOK) {
return false;
}
add_weight_data = reinterpret_cast<float *>(add_weight_tensor->data_c());
MS_ASSERT(add_weight_data != nullptr);
add_weight_shape = add_weight_tensor->shape();
} else {
MS_ASSERT(utils::isa<ValueNode>(bias_add_weight_node));
auto add_weight_value_node = bias_add_weight_node->cast<ValueNodePtr>();
auto add_weight_value = add_weight_value_node->value();
MS_ASSERT(add_weight_value != nullptr);
auto add_weight_tensor = add_weight_value->cast<tensor::TensorPtr>();
if (add_weight_tensor == nullptr) {
MS_LOG(ERROR) << "The bias data of value node " << bias_add_weight_node->fullname_with_scope()
<< " is not tensorPtr.";
return lite::RET_ERROR;
}
add_weight_data = reinterpret_cast<float *>(add_weight_tensor->data_c());
MS_ASSERT(add_weight_data != nullptr);
auto value_abstract = add_weight_value_node->abstract();
auto value_abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(value_abstract);
add_weight_shape = utils::cast<abstract::ShapePtr>(value_abstract_tensor->BuildShape())->shape();
}
if (add_weight_shape.empty() || (add_weight_shape.size() == 1 && add_weight_shape[0] == 1)) {
for (int i = 0; i < kernel_nums; i++) {
(*add_bias_data)[i] = *add_weight_data;
}
} else {
if (EOK != memcpy_s(*add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) {
MS_LOG(ERROR) << "memcpy_s conv_bias_data failed";
return lite::RET_ERROR;
}
if (fusion_bias->size() % add_bias_data.size() != 0) {
return false;
}
return lite::RET_OK;
}

int GetNewConvBiasData(const AnfNodePtr &conv_bias_node, const int &kernel_nums, const float *add_bias_data) {
MS_ASSERT(add_bias_data != nullptr);
MS_ASSERT(conv_bias_node != nullptr);
if (utils::isa<Parameter>(conv_bias_node)) {
auto conv_bias_param_node = conv_bias_node->cast<ParameterPtr>();
if (!conv_bias_param_node->has_default() || conv_bias_param_node->default_param() == nullptr) {
MS_LOG(ERROR) << "The bias parameter of " << conv_bias_node->fullname_with_scope() << " is nullptr.";
return lite::RET_ERROR;
}
auto conv_bias_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_bias_param_node->default_param());
if (conv_bias_tensor == nullptr || conv_bias_tensor->shape().empty() ||
conv_bias_tensor->shape()[0] != kernel_nums) {
MS_LOG(ERROR) << "conv_bias_node shape error";
return lite::RET_ERROR;
}
auto conv_bias_data = reinterpret_cast<float *>(conv_bias_tensor->data_c());
MS_ASSERT(conv_bias_data != nullptr);
for (int i = 0; i < kernel_nums; i++) {
conv_bias_data[i] += add_bias_data[i];
}
} else {
MS_ASSERT(utils::isa<ValueNode>(conv_bias_node));
auto conv_bias_value_node = conv_bias_node->cast<ValueNodePtr>();
auto conv_bias_value = conv_bias_value_node->value();
MS_ASSERT(conv_bias_value != nullptr);
auto conv_bias_tensor = conv_bias_value->cast<tensor::TensorPtr>();
if (conv_bias_tensor == nullptr) {
MS_LOG(ERROR) << "The bias data of value node " << conv_bias_node->fullname_with_scope() << "is not tensorPtr.";
return lite::RET_ERROR;
}
auto conv_bias_data = reinterpret_cast<float *>(conv_bias_tensor->data_c());
MS_ASSERT(conv_bias_data != nullptr);
for (int i = 0; i < kernel_nums; i++) {
conv_bias_data[i] += add_bias_data[i];
}
for (size_t i = 0; i < fusion_bias->size(); ++i) {
fusion_bias->at(i) += add_bias_data[i % add_bias_data.size()];
}
return lite::RET_OK;
return true;
}

tensor::TensorPtr GetConvWeightTensor(const AnfNodePtr &conv_weight_node) {
tensor::TensorPtr conv_weight_tensor;
if (utils::isa<ValueNode>(conv_weight_node)) {
auto conv_weight_value_node = conv_weight_node->cast<ValueNodePtr>();
auto conv_weight_value = conv_weight_value_node->value();
MS_ASSERT(conv_weight_value != nullptr);
conv_weight_tensor = conv_weight_value->cast<tensor::TensorPtr>();
MS_ASSERT(conv_weight_tensor != nullptr);
} else {
MS_ASSERT(utils::isa<Parameter>(conv_weight_node));
auto conv_weight_param = conv_weight_node->cast<ParameterPtr>()->default_param();
MS_ASSERT(conv_weight_param != nullptr);
conv_weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_weight_param);
MS_ASSERT(conv_weight_tensor != nullptr);
}
return conv_weight_tensor;
} // namespace
const BaseRef ConvBiasaddFusion::DefinePattern() const {
auto conv_var = std::make_shared<CondVar>(IsConvExtendNode);
auto add_var = std::make_shared<CondVar>(IsAddNode);
auto weight_var = std::make_shared<CondVar>(IsParamOrValueNodeWithData);
return VectorRef({add_var, conv_var, weight_var});
}

int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, const CNodePtr &bias_node) {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(conv_node != nullptr);
MS_ASSERT(bias_node != nullptr);
AnfNodePtr conv_bias_node = nullptr;
AnfNodePtr conv_weight_node = nullptr;
if (conv_node->inputs().size() == kConvNoBiasLen) {
conv_weight_node = conv_node->input(kConvWeightIndex);
} else if (conv_node->inputs().size() == kConvWithBiasLen) {
conv_weight_node = conv_node->input(kConvWeightIndex);
conv_bias_node = conv_node->input(kConvBiasIndex);
} else {
MS_LOG(ERROR) << "conv node:" << conv_node->DebugString() << "inputs size must 3 or 4";
return lite::RET_INPUT_TENSOR_ERROR;
bool ConvBiasaddFusion::CheckCanFusion(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const {
MS_ASSERT(node != nullptr);
if (!utils::isa<CNode>(node)) {
return false;
}
auto add_cnode = node->cast<CNodePtr>();
if (CheckInputSize(add_cnode, kAddInputsLength) != lite::RET_OK) {
return false;
}
auto prim_add = GetValueNode<PrimitivePtr>(add_cnode->input(0));
MS_ASSERT(rim_add != nullptr);
auto add_act_ptr = prim_add->GetAttr(ops::kActivationType);
auto add_act = add_act_ptr == nullptr ? mindspore::NO_ACTIVATION
: static_cast<mindspore::ActivationType>(GetValue<int64_t>(add_act_ptr));
auto conv_cnode = add_cnode->input(1)->cast<CNodePtr>();
if (conv_cnode == nullptr) {
return false;
}
if (IsMultiOutputTensors(func_graph, conv_cnode)) {
return false;
}
if (conv_cnode->size() == kInputSizeFour) {
auto conv_bias = conv_cnode->input(kInputIndexThree);
if (conv_bias->isa<CNode>() || (conv_bias->isa<Parameter>() && !conv_bias->cast<ParameterPtr>()->has_default())) {
return false;
}
}
auto kernel_nums = Get_Kenrnel_nums(conv_node);
if (kernel_nums <= 0) {
MS_LOG(ERROR) << "kernel num less than 0";
return lite::RET_INVALID_OP_ATTR;
auto prim_conv = GetValueNode<PrimitivePtr>(conv_cnode->input(0));
MS_ASSERT(prim_conv != nullptr);
auto conv_act_ptr = prim_add->GetAttr(ops::kActivationType);
auto conv_act = add_act_ptr == nullptr ? mindspore::NO_ACTIVATION
: static_cast<mindspore::ActivationType>(GetValue<int64_t>(conv_act_ptr));
if (add_act != mindspore::NO_ACTIVATION) {
if (conv_act != mindspore::NO_ACTIVATION || (add_act != mindspore::RELU && add_act != mindspore::RELU6)) {
return false;
}
}
auto add_bias_data = new (std::nothrow) float[kernel_nums];
if (add_bias_data == nullptr) {
MS_LOG(ERROR) << "tensor_data is nullptr";
return lite::RET_MEMORY_FAILED;

if (prim_conv->GetAttr(ops::kOutChannel) == nullptr) {
return false;
}
auto bias_add_weight = bias_node->input(kAddWEIGHTINDEX);
if (CheckIfNodeIsParamOrValue(bias_add_weight) != lite::RET_OK) {
delete[] add_bias_data;
return lite::RET_INVALID_OP_ATTR;
auto out_channel = GetValue<int64_t>(prim_conv->GetAttr(ops::kOutChannel));
auto add_weight = add_cnode->input(kInputIndexTwo);
MS_ASSERT(add_weight != nullptr);
ShapeVector shape;
if (FetchShapeFromAbstract(add_weight->abstract(), &shape) != lite::RET_OK) {
return false;
}
if (GetAddBiasData(bias_add_weight, kernel_nums, &add_bias_data) != lite::RET_OK) {
delete[] add_bias_data;
return lite::RET_INVALID_OP_ATTR;
if (std::count_if(shape.begin(), shape.end(), [](int64_t dim) { return dim > 1; }) > 1) {
return false;
}
if (conv_bias_node != nullptr) {
if (CheckIfNodeIsParamOrValue(conv_bias_node) != lite::RET_OK) {
delete[] add_bias_data;
return lite::RET_INVALID_OP_ATTR;
auto element_num = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
return out_channel % element_num == 0;
}

int ConvBiasaddFusion::DoFuison(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const {
MS_ASSERT(node != nullptr);
auto add_cnode = node->cast<CNodePtr>();
MS_ASSERT(add_cnode != nullptr);
auto add_bias = add_cnode->input(kInputIndexTwo);
lite::DataInfo add_bias_info;
int status = lite::RET_ERROR;
if (add_bias->isa<Parameter>()) {
status = lite::FetchDataFromParameterNode(add_cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &add_bias_info);
} else if (add_bias->isa<ValueNode>()) {
status = lite::FetchDataFromValueNode(add_cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &add_bias_info);
}
if (status != lite::RET_OK) {
MS_LOG(DEBUG) << "conv and add do fusion failed, please check";
return status;
}
auto conv_cnode = add_cnode->input(1)->cast<CNodePtr>();
MS_ASSERT(conv_cnode != nullptr);
lite::DataInfo conv_bias_info;
if (conv_cnode->size() > kInputSizeThree) {
auto conv_bias = conv_cnode->input(kInputIndexThree);
if (conv_bias->isa<Parameter>()) {
status =
lite::FetchDataFromParameterNode(conv_cnode, kInputIndexThree, converter::kFmkTypeMs, false, &conv_bias_info);
} else if (conv_bias->isa<ValueNode>()) {
status =
lite::FetchDataFromValueNode(conv_cnode, kInputIndexThree, converter::kFmkTypeMs, false, &conv_bias_info);
}
if (GetNewConvBiasData(conv_bias_node, kernel_nums, add_bias_data) != lite::RET_OK) {
delete[] add_bias_data;
return lite::RET_INVALID_OP_ATTR;
if (status != lite::RET_OK) {
MS_LOG(DEBUG) << "conv and add do fusion failed, please check";
return status;
}
delete[] add_bias_data;
}
auto prim = GetValueNode<PrimitivePtr>(conv_cnode->input(0));
MS_ASSERT(prim != nullptr);
if (prim->GetAttr(ops::kOutChannel) == nullptr) {
return lite::RET_ERROR;
}
int out_channel = GetValue<int64_t>(prim->GetAttr(ops::kOutChannel));
std::vector<float> fusion_data;
if (!FuseBias(add_bias_info, conv_bias_info, &fusion_data, out_channel)) {
MS_LOG(DEBUG) << "conv and add do fusion failed, please check";
return lite::RET_ERROR;
}
auto conv_new_bias =
AddNewBiasNode(fusion_data.data(), func_graph, out_channel, static_cast<TypeId>(add_bias_info.data_type_));
conv_new_bias->set_name(conv_cnode->fullname_with_scope() + "_bias");
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
auto tr = manager->Transact();
if (conv_cnode->size() > kInputSizeThree) {
tr.SetEdge(conv_cnode, kInputIndexThree, conv_new_bias);
} else {
if (CheckIfNodeIsParamOrValue(conv_weight_node) != lite::RET_OK) {
delete[] add_bias_data;
return lite::RET_INVALID_OP_ATTR;
}
tensor::TensorPtr conv_weight_tensor = GetConvWeightTensor(conv_weight_node);
auto conv_new_bias = AddNewBiasNode(add_bias_data, func_graph, kernel_nums, conv_weight_tensor);
conv_new_bias->set_name(conv_node->fullname_with_scope() + "_bias");
conv_node->add_input(conv_new_bias);
tr.AddEdge(conv_cnode, conv_new_bias);
}
tr.Commit();
return lite::RET_OK;
}
} // namespace
const BaseRef ConvBiasaddFusion::DefinePattern() const {
auto conv_var = std::make_shared<CondVar>(IsConvExtendNode);
auto add_var = std::make_shared<CondVar>(IsAddNode);
auto weight_var = std::make_shared<CondVar>(IsParamOrValueNodeWithData);
return VectorRef({add_var, conv_var, weight_var});
}

const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(node != nullptr);
MS_LOG(DEBUG) << "Enter pass process";
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
return nullptr;
}
auto add_node = node->cast<CNodePtr>();
if (CheckIfCNodeIsNull(add_node) != lite::RET_OK || CheckInputSize(add_node, kAddInputsLength) != lite::RET_OK) {
return nullptr;
}
if (CheckPrimitiveType(add_node, prim::kPrimAddFusion)) {
auto primitive_c = GetValueNode<PrimitiveCPtr>(add_node->input(0));
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::ops::AddFusion>>(primitive_c));
auto primc = utils::cast<std::shared_ptr<mindspore::ops::AddFusion>>(primitive_c);
MS_ASSERT(primc != nullptr);
if (primc->GetAttr(ops::kActivationType) != nullptr && primc->get_activation_type() != mindspore::NO_ACTIVATION) {
return add_node;
}
}

AnfNodePtr conv_node_anf = add_node->input(1);
if (CheckIfAnfNodeIsNull(conv_node_anf) != lite::RET_OK || IsMultiOutputTensors(func_graph, conv_node_anf)) {
return nullptr;
}
auto conv_node = conv_node_anf->cast<CNodePtr>();
if (CheckIfCNodeIsNull(conv_node) != lite::RET_OK) {
MS_ASSERT(func_graph != nullptr && node != nullptr);
if (!CheckCanFusion(func_graph, node)) {
return nullptr;
}
int ret = GenConvNewBias(func_graph, conv_node, add_node);
if (ret != lite::RET_OK) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
if (DoFuison(func_graph, node) != lite::RET_OK) {
return nullptr;
}
return conv_node;
auto add_cnode = node->cast<CNodePtr>();
MS_ASSERT(add_cnode != nullptr);
return add_cnode->input(1);
}
} // namespace mindspore::opt

+ 4
- 0
mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.h View File

@@ -28,6 +28,10 @@ class ConvBiasaddFusion : public PatternProcessPass {
~ConvBiasaddFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

private:
bool CheckCanFusion(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const;
int DoFuison(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const;
};
} // namespace opt
} // namespace mindspore


+ 1
- 1
mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc View File

@@ -211,7 +211,7 @@ void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const
}
CalNewBiasTensor(bias_data, kernel_num, bias_flag, trans_scale, trans_bias);
if (!bias_flag) {
auto bias_node = AddNewBiasNode(bias_data, func_graph, kernel_num, weight_tensor);
auto bias_node = AddNewBiasNode(bias_data, func_graph, kernel_num, weight_tensor->data_type());
delete[] bias_data;
bias_node->set_name(conv_node->fullname_with_scope() + "_bias");
conv_node->add_input(bias_node);


+ 2
- 4
mindspore/lite/tools/optimizer/graph/decrease_transpose_algo.cc View File

@@ -130,15 +130,13 @@ bool JudgeCanOptimizerForMultiOp(const FuncGraphPtr &func_graph, const std::set<
if (trans_info->pre_ == trans_info->post_) {
return false;
}
auto &dynamic_ops = GetDynamicFormatOpList();
TransposeStrategy transpose_strategy;
for (auto &middle_cnode : middle_nodes) {
if (IsSpecialType(middle_cnode)) {
continue;
}
auto middle_node_prim = GetValueNode<PrimitivePtr>(middle_cnode->input(0));
if (!lite::IsContain(dynamic_ops, middle_node_prim->name()) ||
!transpose_strategy.CanChangeOpAxis(func_graph, middle_cnode)) {
if (!transpose_strategy.CanChangeOpAxis(func_graph, middle_cnode)) {
return false;
}
}
@@ -642,7 +640,7 @@ bool DecreaseTransposeAlgo::DecreaseTransposeForSingleOp(const FuncGraphPtr &fun
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_ASSERT(prim != nullptr);
if (!lite::IsContain(GetDynamicFormatOpList(), prim->name())) {
if (!IsDynamicFormatOp(prim->name())) {
continue;
}
TransTypePair trans_insert_info;


+ 295
- 223
mindspore/lite/tools/optimizer/graph/transpose_strategy.cc View File

@@ -16,6 +16,8 @@

#include "tools/optimizer/graph/transpose_strategy.h"
#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <vector>
#include <string>
@@ -24,7 +26,7 @@
#include "ops/fusion/activation.h"
#include "ops/fusion/slice_fusion.h"
#include "ops/op_utils.h"
#include "ops/strided_slice.h"
#include "tools/anf_exporter/fetch_content.h"

namespace mindspore {
namespace opt {
@@ -32,7 +34,9 @@ namespace {
constexpr size_t kFirstInput = 1;
constexpr size_t kHalfDivisor = 2;
constexpr size_t kOnnxStridedSlice = 6;
constexpr int kPaddingListLength = 8;
STATUS GetPostNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<AnfNodePtr> *out_nodes) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr && out_nodes != nullptr);
auto manager = func_graph->manager();
if (manager == nullptr) {
manager = Manage(func_graph, true);
@@ -50,6 +54,268 @@ STATUS GetPostNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::
[](const std::pair<AnfNodePtr, int> &node_user) { return node_user.first; });
return lite::RET_OK;
}

// Returns true when the node can be treated as operating on 4D data:
// either its first input's inferred shape is 4D, or (when it is not) a second
// input exists whose inferred shape is 4D or still unknown (empty).
bool JudgeIs4DInput(NodeInferShape *node_infer_shape, const CNodePtr &cnode) {
  MS_ASSERT(node_infer_shape != nullptr && cnode != nullptr);
  auto first_shape = node_infer_shape->GetInputShape(cnode, 1);
  if (first_shape.size() == kInputSizeFour) {
    return true;
  }
  // First input is not 4D: fall back to the second input if one exists.
  if (cnode->size() <= kInputSizeTwo) {
    return false;
  }
  auto second_shape = node_infer_shape->GetInputShape(cnode, kInputIndexTwo);
  // An empty shape means "not inferred yet" and is tolerated.
  return second_shape.size() == kInputSizeFour || second_shape.empty();
}

// Maps every axis in `origin_axes` through the NHWC<->NCHW permutation table
// selected by `trans_type` (negative axes are first normalized into [0, 4)),
// and returns the mapped axes sorted ascending.
std::vector<int> TransformOpAxesAttr(const std::vector<int> &origin_axes, FormatTransNodeType trans_type) {
  std::vector<int> mapped_axes;
  mapped_axes.reserve(origin_axes.size());
  for (int axis : origin_axes) {
    if (axis < 0) {
      axis += kInputSizeFour;  // normalize negative axis for a 4D tensor
    }
    MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
    mapped_axes.push_back(trans_type == kNHWC2NCHW ? kNC2NH[axis] : kNH2NC[axis]);
  }
  std::sort(mapped_axes.begin(), mapped_axes.end());
  return mapped_axes;
}

// Permutes an integer-vector input of `cnode` (e.g. slice begins/sizes, one value
// per entry of `axes`) so that each element follows its axis through the
// NHWC<->NCHW permutation chosen by `trans_type`, then replaces that input with a
// freshly built parameter node holding the permuted vector.
// Silently returns (no change) when `input_index` is out of range, `axes` is
// empty, or the fetched vector's length does not match `axes`.
void TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index,
                         const std::vector<int> &axes, FormatTransNodeType trans_type,
                         NodeInferShape *node_infer_shape) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
  if (input_index >= cnode->size() || axes.empty()) {
    return;
  }
  auto origin_input = node_infer_shape->GetIntVecInput(cnode, input_index);
  if (origin_input.size() != axes.size()) {
    return;
  }
  std::vector<int> cur_input;
  // For each destination dim 0..3 in ascending order, emit the source element
  // whose transformed axis equals that dim — i.e. output is sorted by new axis.
  for (int dim = 0; dim < static_cast<int>(kInputSizeFour); ++dim) {
    for (size_t index = 0; index < axes.size(); ++index) {
      int axis = axes[index];
      if (axis < 0) {
        axis += kInputSizeFour;  // normalize negative axis for a 4D tensor
      }
      MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
      int cur_axis = kNH2NC[axis];
      if (trans_type == kNHWC2NCHW) {
        cur_axis = kNC2NH[axis];
      }
      if (cur_axis == dim) {
        cur_input.push_back(origin_input[index]);
      }
    }
  }
  // Swap the old constant input for a parameter node carrying the permuted values.
  auto param_node = BuildIntVecParameterNode(func_graph, cur_input, cnode->input(input_index)->fullname_with_scope());
  func_graph->manager()->Replace(cnode->input(input_index), param_node);
}

// Remaps the scalar `axis` attribute of an op (used for Concat/Split in the
// ChangeOpAxis dispatch table) when the surrounding layout flips between NHWC
// and NCHW.
// Returns RET_NOT_SUPPORT when the primitive has no axis attribute, RET_ERROR
// for an invalid trans_type, RET_OK otherwise.
STATUS ChangeCommonOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
                      NodeInferShape *node_infer_shape) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
  if (trans_type == kNONE) {
    MS_LOG(ERROR) << "trans_type is invalid.";
    return lite::RET_ERROR;
  }
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  MS_ASSERT(prim != nullptr);
  if (prim->GetAttr(ops::kAxis) == nullptr) {
    return lite::RET_NOT_SUPPORT;
  }
  auto axis = GetValue<int64_t>(prim->GetAttr(ops::kAxis));
  if (axis < 0) {
    axis += kInputSizeFour;  // normalize negative axis, assuming a 4D tensor
  }
  // NOTE(review): no bounds check before indexing the permutation tables —
  // presumably callers (CanChangeOpAxis / JudgeIs4DInput) guarantee 4D input
  // so axis is in [0, 4); confirm.
  auto new_axis = kNH2NC[axis];
  if (trans_type == kNHWC2NCHW) {
    new_axis = kNC2NH[axis];
  }
  prim->AddAttr(ops::kAxis, MakeValue<int64_t>(new_axis));
  return lite::RET_OK;
}

// Rewrites a Crop op's `axis` and `offsets` attributes for a layout switch.
// The offsets vector covers dims [axis, 3], so when the transformed axis moves,
// offsets must be reordered, extended (push a 0 when the span grows), or
// shrunk (pop when the span shrinks) to keep it aligned with the new dims.
// Returns RET_ERROR when trans_type is invalid or the node is not a Crop.
STATUS ChangeOpCrop(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
                    NodeInferShape *node_infer_shape) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
  if (trans_type == kNONE) {
    MS_LOG(ERROR) << "trans_type is invalid.";
    return lite::RET_ERROR;
  }
  auto crop_prim = GetValueNode<std::shared_ptr<ops::Crop>>(cnode->input(0));
  if (crop_prim == nullptr) {
    MS_LOG(ERROR) << "cnode is invalid.";
    return lite::RET_ERROR;
  }
  auto axis = crop_prim->get_axis();
  if (axis < 0) {
    axis += kInputSizeFour;  // normalize negative axis, assuming a 4D tensor
  }
  MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
  auto offsets = crop_prim->get_offsets();
  if (trans_type == kNCHW2NHWC) {
    auto new_axis = kNH2NC[axis];
    if (new_axis == 0) {
      // Offsets covered all 4 dims: permute them into the new dim order.
      offsets = {offsets[0], offsets[kInputIndexTwo], offsets[kInputIndexThree], offsets[1]};
    } else if (new_axis == kInputIndexThree) {
      // Axis moved to the last dim: only 3 offsets remain relevant, reordered.
      offsets = {offsets[1], offsets[kInputIndexTwo], offsets[0]};
    } else {
      // Axis span grew by one dim: pad with a zero offset.
      offsets.push_back(0);
    }
    crop_prim->set_axis(new_axis);
    crop_prim->set_offsets(offsets);
  } else {
    auto new_axis = kNC2NH[axis];
    if (new_axis == 0) {
      offsets = {offsets[0], offsets[kInputIndexThree], offsets[1], offsets[kInputIndexTwo]};
    } else if (new_axis == kInputIndexThree) {
      offsets = {offsets[kInputIndexTwo], offsets[0], offsets[1]};
    } else {
      // Axis span shrank by one dim: drop the trailing offset.
      offsets.pop_back();
    }
    crop_prim->set_axis(new_axis);
    crop_prim->set_offsets(offsets);
  }
  return lite::RET_OK;
}

// Rewrites a PadFusion op's paddings (second input, shape [4, 2]) for a layout
// switch: NCHW->NHWC moves the channel pad pair from position 1 to the end,
// the reverse transform moves it back. The rewritten paddings are installed as
// a new parameter node, and the `paddings` attribute (if present) is updated to
// match.
// Returns RET_OK without change when the paddings tensor does not hold exactly
// 8 elements; RET_NOT_SUPPORT when the paddings input is itself a CNode.
STATUS ChangeOpPad(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
                   NodeInferShape *node_infer_shape) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
  if (trans_type == kNONE) {
    MS_LOG(ERROR) << "trans_type is invalid.";
    return lite::RET_ERROR;
  }
  if (cnode->size() < kInputSizeThree) {
    MS_LOG(ERROR) << "pad op need three inputs.";
    return lite::RET_INPUT_TENSOR_ERROR;
  }
  auto second_input = cnode->input(kInputIndexTwo);
  lite::DataInfo data_info;
  int status;
  if (utils::isa<Parameter>(second_input)) {
    status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info);
  } else if (utils::isa<ValueNode>(second_input)) {
    status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info);
  } else {
    return lite::RET_NOT_SUPPORT;
  }
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "get paddings failed.";
    return status;
  }
  // Only a full 4x2 padding table (8 values) is transformed; anything else is
  // left untouched.
  if (std::accumulate(data_info.shape_.begin(), data_info.shape_.end(), 1, std::multiplies<int>()) !=
      kPaddingListLength) {
    return lite::RET_OK;
  }
  // Unpack the flat buffer into per-dim {before, after} pairs.
  // NOTE(review): assumes the fetched paddings data is int32 — the
  // reinterpret_cast is not guarded by a data_type_ check; confirm upstream.
  std::vector<std::vector<int32_t>> padding_list(kInputSizeFour, std::vector<int32_t>(kInputSizeTwo));
  auto data = reinterpret_cast<int32_t *>(data_info.data_.data());
  for (int i = 0; i < kPaddingListLength; ++i) {
    padding_list[i / kInputIndexTwo][i % kInputIndexTwo] = *data;
    data += 1;
  }
  if (trans_type == kNCHW2NHWC) {
    // Channel pair moves from dim 1 (NCHW's C) to the last dim (NHWC's C).
    auto chanel_pad = padding_list[1];
    padding_list.erase(padding_list.begin() + 1);
    padding_list.push_back(chanel_pad);
  } else {
    // Reverse: channel pair moves from the last dim back to dim 1.
    auto chanel_pad = padding_list.back();
    padding_list.pop_back();
    padding_list.insert(padding_list.begin() + 1, chanel_pad);
  }
  auto param_node =
    BuildIntVec2DParameterNode(func_graph, padding_list, cnode->input(kInputIndexTwo)->fullname_with_scope());
  func_graph->manager()->Replace(cnode->input(kInputIndexTwo), param_node);
  // Keep the redundant `paddings` attribute consistent with the new input.
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  MS_ASSERT(prim != nullptr);
  if (prim->GetAttr(ops::kPaddings) != nullptr) {
    std::vector<std::vector<int64_t>> padding_attr;
    (void)std::transform(padding_list.begin(), padding_list.end(), std::back_inserter(padding_attr),
                         [](const std::vector<int> &val) { return std::vector<int64_t>(val.begin(), val.end()); });
    prim->AddAttr(ops::kPaddings, MakeValue(padding_attr));
  }
  return lite::RET_OK;
}

// Rewrites a SliceFusion op's begin/size inputs and `axes` attribute for a
// layout switch. When the op carries no (or empty) axes, all dims up to the
// begin-vector's length are assumed.
// Returns RET_NOT_SUPPORT when any slice parameter is a CNode (not constant)
// or the begin input's shape is unknown; RET_ERROR for an invalid trans_type
// or a node that is not a SliceFusion.
STATUS ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
                     NodeInferShape *node_infer_shape) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
  if (trans_type == kNONE) {
    MS_LOG(ERROR) << "trans_type is invalid.";
    return lite::RET_ERROR;
  }
  // Only constant slice parameters can be rewritten offline.
  for (size_t i = 2; i < cnode->size(); ++i) {
    if (utils::isa<CNodePtr>(cnode->input(i))) {
      return lite::RET_NOT_SUPPORT;
    }
  }
  auto shape = node_infer_shape->GetInputShape(cnode, kInputIndexTwo);
  if (shape.empty()) {
    return lite::RET_NOT_SUPPORT;
  }
  int element_num = shape.front();
  auto prim = GetValueNode<std::shared_ptr<ops::SliceFusion>>(cnode->input(0));
  // Guard against a non-SliceFusion primitive (consistent with ChangeOpCrop):
  // GetValueNode returns nullptr on a type mismatch, which was previously
  // dereferenced unchecked.
  if (prim == nullptr) {
    MS_LOG(ERROR) << "cnode is invalid.";
    return lite::RET_ERROR;
  }
  std::vector<int> axes;
  if (prim->GetAttr(ops::kAxes) == nullptr || prim->get_axes().empty()) {
    // No explicit axes: the slice applies to the first element_num dims.
    for (int index = 0; index < element_num; ++index) {
      axes.push_back(index);
    }
  } else {
    auto origin_axes = prim->get_axes();
    std::transform(origin_axes.begin(), origin_axes.end(), std::back_inserter(axes),
                   [](int64_t v) { return static_cast<int>(v); });
  }
  // Permute every constant input (begin, size, ...) to follow the new layout.
  for (size_t i = 2; i < cnode->size(); ++i) {
    TransformAttrByAxes(func_graph, cnode, i, axes, trans_type, node_infer_shape);
  }
  auto tmp_axes = TransformOpAxesAttr(axes, trans_type);
  std::vector<int64_t> new_axes(tmp_axes.begin(), tmp_axes.end());
  prim->set_axes(new_axes);
  return lite::RET_OK;
}

// Rewrites an ONNX-style StridedSlice (exactly 5 operands: data, begin, end,
// axes, strides) for a layout switch: begin/end/strides are permuted via
// TransformAttrByAxes, and the axes input (index 4) is replaced with the
// transformed axes as a new parameter node.
// Returns RET_NOT_SUPPORT for any other arity or when an operand is a CNode;
// RET_ERROR when trans_type is invalid or the axes input cannot be read.
STATUS ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
                           NodeInferShape *node_infer_shape) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
  if (trans_type == kNONE) {
    MS_LOG(ERROR) << "trans_type is invalid.";
    return lite::RET_ERROR;
  }
  if (cnode->size() != kOnnxStridedSlice) {
    return lite::RET_NOT_SUPPORT;
  }
  // Only constant slice parameters can be rewritten offline.
  for (size_t i = 2; i < cnode->size(); ++i) {
    if (utils::isa<CNodePtr>(cnode->input(i))) {
      return lite::RET_NOT_SUPPORT;
    }
  }
  std::vector<int> axes = node_infer_shape->GetIntVecInput(cnode, kInputIndexFour);
  if (axes.empty()) {
    MS_LOG(ERROR) << "strided slice input invalid.";
    return lite::RET_ERROR;
  }
  // Permute begin/end/strides; the axes input itself is handled below.
  for (size_t index = 2; index < cnode->size(); ++index) {
    if (index == kInputIndexFour) {
      continue;
    }
    TransformAttrByAxes(func_graph, cnode, index, axes, trans_type, node_infer_shape);
  }
  auto cur_axes = TransformOpAxesAttr(axes, trans_type);
  auto param_node =
    BuildIntVecParameterNode(func_graph, cur_axes, cnode->input(kInputIndexFour)->fullname_with_scope());
  func_graph->manager()->Replace(cnode->input(kInputIndexFour), param_node);
  return lite::RET_OK;
}
} // namespace

AnfNodePtr TransposeStrategy::TransposePairFuseWhenInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
@@ -138,32 +404,31 @@ bool TransposeStrategy::CanFusionIfInsert(const FuncGraphPtr &func_graph, const

bool TransposeStrategy::CanChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
auto shape = node_infer_shape_.GetInputShape(cnode, 1);
if (shape.size() != kInputSizeFour) {
if (cnode->size() > kInputSizeTwo) {
shape = node_infer_shape_.GetInputShape(cnode, kInputIndexTwo);
if (shape.size() != kInputSizeFour && !shape.empty()) {
return false;
}
} else {
return false;
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_ASSERT(prim != nullptr);
if (!IsDynamicFormatOp(prim->name())) {
return false;
}
if (CheckPrimitiveType(cnode, prim::kPrimConcat) || CheckPrimitiveType(cnode, prim::kPrimSplit)) {
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim->GetAttr(ops::kAxis) == nullptr) {
return false;
}
if (IsDynamicFormatOpWithAxis(prim->name()) && !JudgeIs4DInput(&node_infer_shape_, cnode)) {
return false;
}
if (CheckPrimitiveType(cnode, prim::kPrimSliceFusion) || CheckPrimitiveType(cnode, prim::kPrimStridedSlice)) {
if (CheckPrimitiveType(cnode, prim::kPrimSliceFusion) || CheckPrimitiveType(cnode, prim::kPrimStridedSlice) ||
CheckPrimitiveType(cnode, prim::kPrimPadFusion)) {
for (size_t i = 2; i < cnode->size(); ++i) {
if (utils::isa<CNodePtr>(cnode->input(i))) {
return false;
}
if (utils::isa<Parameter>(cnode->input(i)) && !cnode->input(i)->cast<ParameterPtr>()->has_default()) {
return false;
}
}
if (CheckPrimitiveType(cnode, prim::kPrimStridedSlice) && cnode->size() != kOnnxStridedSlice) {
return false;
}
} else if (IsDynamicFormatOpWithAxis(prim->name())) {
if (prim->GetAttr(ops::kAxis) == nullptr) {
return false;
}
}
return true;
}
@@ -171,28 +436,20 @@ bool TransposeStrategy::CanChangeOpAxis(const FuncGraphPtr &func_graph, const CN
STATUS TransposeStrategy::ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
FormatTransNodeType trans_type) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
auto shape = node_infer_shape_.GetInputShape(cnode, 1);
if (shape.size() != kInputSizeFour) {
if (cnode->size() > kInputSizeTwo) {
shape = node_infer_shape_.GetInputShape(cnode, kInputIndexTwo);
if (shape.size() != kInputSizeFour && !shape.empty()) {
return lite::RET_NOT_SUPPORT;
}
} else {
return lite::RET_NOT_SUPPORT;
}
}
if (CheckPrimitiveType(cnode, prim::kPrimConcat) || CheckPrimitiveType(cnode, prim::kPrimSplit)) {
return ChangeCommonOp(cnode, trans_type);
}
if (CheckPrimitiveType(cnode, prim::kPrimCrop)) {
return ChangeOpCrop(cnode, trans_type);
}
if (CheckPrimitiveType(cnode, prim::kPrimSliceFusion)) {
return ChangeOpSlice(func_graph, cnode, trans_type);
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_ASSERT(prim != nullptr);
if (IsDynamicFormatOpWithAxis(prim->name()) && !JudgeIs4DInput(&node_infer_shape_, cnode)) {
return lite::RET_NOT_SUPPORT;
}
if (CheckPrimitiveType(cnode, prim::kPrimStridedSlice)) {
return ChangeOpStrideSlice(func_graph, cnode, trans_type);
std::map<std::string,
std::function<STATUS(const FuncGraphPtr &, const CNodePtr &, FormatTransNodeType, NodeInferShape *)>>
process_funcs = {
{prim::kPrimConcat->name(), ChangeCommonOp}, {prim::kPrimSplit->name(), ChangeCommonOp},
{prim::kPrimCrop->name(), ChangeOpCrop}, {prim::kPrimPadFusion->name(), ChangeOpPad},
{prim::kPrimSliceFusion->name(), ChangeOpSlice}, {prim::kPrimStridedSlice->name(), ChangeOpStrideSlice}};
auto iter = process_funcs.find(prim->name());
if (iter != process_funcs.end()) {
return iter->second(func_graph, cnode, trans_type, &node_infer_shape_);
}
return lite::RET_OK;
}
@@ -273,190 +530,5 @@ void TransposeStrategy::DecidePreAndPostTransType(TransTypePair *trans_info, Tra
trans_insert_info->post_ = trans_info->pre_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC;
}
}

STATUS TransposeStrategy::ChangeCommonOp(const CNodePtr &cnode, FormatTransNodeType trans_type) {
MS_ASSERT(cnode != nullptr);
if (trans_type == kNONE) {
MS_LOG(ERROR) << "trans_type is invalid.";
return lite::RET_ERROR;
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_ASSERT(prim != nullptr);
if (prim->GetAttr(ops::kAxis) == nullptr) {
return lite::RET_NOT_SUPPORT;
}
auto axis = GetValue<int64_t>(prim->GetAttr(ops::kAxis));
if (axis < 0) {
axis += kInputSizeFour;
}
auto new_axis = kNH2NC[axis];
if (trans_type == kNHWC2NCHW) {
new_axis = kNC2NH[axis];
}
prim->AddAttr(ops::kAxis, MakeValue<int64_t>(new_axis));
return lite::RET_OK;
}

STATUS TransposeStrategy::ChangeOpCrop(const CNodePtr &cnode, FormatTransNodeType trans_type) {
MS_ASSERT(cnode != nullptr);
if (trans_type == kNONE) {
MS_LOG(ERROR) << "trans_type is invalid.";
return lite::RET_ERROR;
}
auto crop_prim = GetValueNode<std::shared_ptr<ops::Crop>>(cnode->input(0));
if (crop_prim == nullptr) {
MS_LOG(ERROR) << "cnode is invalid.";
return lite::RET_ERROR;
}
auto axis = crop_prim->get_axis();
if (axis < 0) {
axis += kInputSizeFour;
}
MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
auto offsets = crop_prim->get_offsets();
if (trans_type == kNCHW2NHWC) {
auto new_axis = kNH2NC[axis];
if (new_axis == 0) {
offsets = {offsets[0], offsets[kInputIndexTwo], offsets[kInputIndexThree], offsets[1]};
} else if (new_axis == kInputIndexThree) {
offsets = {offsets[1], offsets[kInputIndexTwo], offsets[0]};
} else {
offsets.push_back(0);
}
crop_prim->set_axis(new_axis);
crop_prim->set_offsets(offsets);
} else {
auto new_axis = kNC2NH[axis];
if (new_axis == 0) {
offsets = {offsets[0], offsets[kInputIndexThree], offsets[1], offsets[kInputIndexTwo]};
} else if (new_axis == kInputIndexThree) {
offsets = {offsets[kInputIndexTwo], offsets[0], offsets[1]};
} else {
offsets.pop_back();
}
crop_prim->set_axis(new_axis);
crop_prim->set_offsets(offsets);
}
return lite::RET_OK;
}

STATUS TransposeStrategy::ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
FormatTransNodeType trans_type) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
if (trans_type == kNONE) {
MS_LOG(ERROR) << "trans_type is invalid.";
return lite::RET_ERROR;
}
for (size_t i = 2; i < cnode->size(); ++i) {
if (utils::isa<CNodePtr>(cnode->input(i))) {
return lite::RET_NOT_SUPPORT;
}
}
auto shape = node_infer_shape_.GetInputShape(cnode, kInputIndexTwo);
if (shape.empty()) {
return lite::RET_NOT_SUPPORT;
}
int element_num = shape.front();
auto prim = GetValueNode<std::shared_ptr<ops::SliceFusion>>(cnode->input(0));
std::vector<int> axes;
if (prim->GetAttr(ops::kAxes) == nullptr || prim->get_axes().empty()) {
for (int index = 0; index < element_num; ++index) {
axes.push_back(index);
}
} else {
auto origin_axes = prim->get_axes();
std::transform(origin_axes.begin(), origin_axes.end(), std::back_inserter(axes),
[](int64_t v) { return static_cast<int>(v); });
}
for (size_t i = 2; i < cnode->size(); ++i) {
TransformAttrByAxes(func_graph, cnode, i, axes, trans_type);
}
auto tmp_axes = TransformOpAxesAttr(axes, trans_type);
std::vector<int64_t> new_axes(tmp_axes.begin(), tmp_axes.end());
prim->set_axes(new_axes);
return lite::RET_OK;
}

STATUS TransposeStrategy::ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
FormatTransNodeType trans_type) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
if (trans_type == kNONE) {
MS_LOG(ERROR) << "trans_type is invalid.";
return lite::RET_ERROR;
}
if (cnode->size() != kOnnxStridedSlice) {
return lite::RET_NOT_SUPPORT;
}
for (size_t i = 2; i < cnode->size(); ++i) {
if (utils::isa<CNodePtr>(cnode->input(i))) {
return lite::RET_NOT_SUPPORT;
}
}
std::vector<int> axes = node_infer_shape_.GetIntVecInput(cnode, kInputIndexFour);
if (axes.empty()) {
MS_LOG(ERROR) << "strided slice input invalid.";
return lite::RET_ERROR;
}
for (size_t index = 2; index < cnode->size(); ++index) {
if (index == kInputIndexFour) {
continue;
}
TransformAttrByAxes(func_graph, cnode, index, axes, trans_type);
}
auto cur_axes = TransformOpAxesAttr(axes, trans_type);
auto param_node =
BuildIntVecParameterNode(func_graph, cur_axes, cnode->input(kInputIndexFour)->fullname_with_scope());
func_graph->manager()->Replace(cnode->input(kInputIndexFour), param_node);
return lite::RET_OK;
}

void TransposeStrategy::TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index,
const std::vector<int> &axes, FormatTransNodeType trans_type) {
if (cnode == nullptr || input_index >= cnode->size() || axes.empty()) {
return;
}
auto origin_input = node_infer_shape_.GetIntVecInput(cnode, input_index);
if (origin_input.size() != axes.size()) {
return;
}
std::vector<int> cur_input;
for (int dim = 0; dim < static_cast<int>(kInputSizeFour); ++dim) {
for (size_t index = 0; index < axes.size(); ++index) {
int axis = axes[index];
if (axis < 0) {
axis += kInputSizeFour;
}
MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
int cur_axis = kNH2NC[axis];
if (trans_type == kNHWC2NCHW) {
cur_axis = kNC2NH[axis];
}
if (cur_axis == dim) {
cur_input.push_back(origin_input[index]);
}
}
}
auto param_node = BuildIntVecParameterNode(func_graph, cur_input, cnode->input(input_index)->fullname_with_scope());
func_graph->manager()->Replace(cnode->input(input_index), param_node);
}

std::vector<int> TransposeStrategy::TransformOpAxesAttr(const std::vector<int> &origin_axes,
FormatTransNodeType trans_type) {
std::vector<int> cur_axes;
for (size_t i = 0; i < origin_axes.size(); ++i) {
int axis = origin_axes[i];
if (axis < 0) {
axis += kInputSizeFour;
}
MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
int cur_axis = kNH2NC[axis];
if (trans_type == kNHWC2NCHW) {
cur_axis = kNC2NH[axis];
}
cur_axes.push_back(cur_axis);
}
std::sort(cur_axes.begin(), cur_axes.end());
return cur_axes;
}
} // namespace opt
} // namespace mindspore

+ 0
- 7
mindspore/lite/tools/optimizer/graph/transpose_strategy.h View File

@@ -51,13 +51,6 @@ class TransposeStrategy {
bool IsInOutCanFuison(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes, size_t *trans_count,
FormatTransNodeType *trans_type);
void DecidePreAndPostTransType(TransTypePair *trans_info, TransTypePair *trans_insert_info);
STATUS ChangeCommonOp(const CNodePtr &cnode, FormatTransNodeType trans_type);
STATUS ChangeOpCrop(const CNodePtr &cnode, FormatTransNodeType trans_type);
STATUS ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type);
STATUS ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type);
void TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index,
const std::vector<int> &axes, FormatTransNodeType trans_type);
std::vector<int> TransformOpAxesAttr(const std::vector<int> &origin_axes, FormatTransNodeType trans_type);
FmkType fmk_type_{converter::kFmkTypeMs};
bool train_flag_{false};
NodeInferShape node_infer_shape_;


+ 67
- 66
mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.cc View File

@@ -15,75 +15,87 @@
*/
#include "tools/optimizer/graph/update_conv2d_param_pass.h"
#include <memory>
#include <utility>
#include <vector>
#include "ops/fusion/conv2d_fusion.h"
#include "mindspore/lite/include/errorcode.h"

namespace mindspore::opt {
namespace {
constexpr size_t kNumDim0 = 0;
constexpr size_t kNumDim1 = 1;
constexpr size_t kNumDim2 = 2;
constexpr size_t kNumDim3 = 3;
constexpr int kAnfPopulaterInputNumTwo = 2;
// Fills in kernel_size / in_channel / out_channel on a conv primitive, but
// only where the existing attribute is missing or holds non-positive (i.e.
// unknown) values; valid existing attributes are left untouched.
void SetConvAttr(const PrimitivePtr &prim, const std::vector<int64_t> &kernel_size, int64_t in_channel,
                 int64_t out_channel) {
  MS_ASSERT(prim != nullptr);
  auto kernel_attr = prim->GetAttr(ops::kKernelSize);
  bool kernel_needs_update = kernel_attr == nullptr;
  if (!kernel_needs_update) {
    auto recorded = GetValue<std::vector<int64_t>>(kernel_attr);
    kernel_needs_update = std::any_of(recorded.begin(), recorded.end(), [](int64_t dim) { return dim <= 0; });
  }
  if (kernel_needs_update) {
    prim->AddAttr(ops::kKernelSize, MakeValue(kernel_size));
  }
  auto in_attr = prim->GetAttr(ops::kInChannel);
  if (in_attr == nullptr || GetValue<int64_t>(in_attr) <= 0) {
    prim->AddAttr(ops::kInChannel, MakeValue(in_channel));
  }
  auto out_attr = prim->GetAttr(ops::kOutChannel);
  if (out_attr == nullptr || GetValue<int64_t>(out_attr) <= 0) {
    prim->AddAttr(ops::kOutChannel, MakeValue(out_channel));
  }
}
} // namespace

lite::STATUS UpdateConv2DParamPass::UpdateCommonConv2D(const CNodePtr &cnode) {
STATUS UpdateConv2DParamPass::UpdateConv2DAttr(const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
if (fmk_type_ != converter::kFmkTypeTf) {
return lite::RET_OK;
if (cnode->size() < kInputSizeThree) {
MS_LOG(ERROR) << "conv2d's input size is invalid, now is " << cnode->size() - 1;
return lite::RET_ERROR;
}
auto weight = cnode->input(kInputIndexTwo);
if (weight == nullptr) {
MS_LOG(ERROR) << "conv2d's weight is invalid, now is nullptr.";
return lite::RET_ERROR;
}
auto conv = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(cnode->input(0));
if (conv == nullptr) {
MS_LOG(DEBUG) << "cnode is invalid.";
auto abstract = weight->abstract();
ShapeVector shape;
if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
MS_LOG(ERROR) << "fetch shape from abstract failed.";
return lite::RET_ERROR;
}
if (conv->GetAttr(ops::kFormat) == nullptr ||
(conv->get_format() != mindspore::NHWC && conv->get_format() != mindspore::KHWC)) {
if (shape.empty()) {
return lite::RET_OK;
}
auto weight_node = cnode->input(kAnfPopulaterInputNumTwo);
if (weight_node == nullptr) {
MS_LOG(DEBUG) << "Conv2D weight node is nullptr.";
if (shape.size() != kInputSizeFour) {
MS_LOG(ERROR) << "conv2d weight shape size is invalid.";
return lite::RET_ERROR;
}
if (!weight_node->isa<Parameter>()) {
MS_LOG(DEBUG) << "Conv2D weight node is not parameter.";
return lite::RET_NO_CHANGE;
}
auto weight_param = weight_node->cast<ParameterPtr>();
if (!weight_param->has_default()) {
MS_LOG(DEBUG) << "Conv2D weight node is not parameter.";
return lite::RET_NO_CHANGE;
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim->GetAttr(ops::kFormat) == nullptr) {
MS_LOG(ERROR) << "current conv2d's format is undefined.";
return lite::RET_ERROR;
}
auto default_param = weight_param->default_param();
auto weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(default_param);
auto weight_shape = weight_tensor->shape();
std::vector<int64_t> kernel_size = {weight_shape[kNumDim1], weight_shape[kNumDim2]};
conv->set_kernel_size(kernel_size);
conv->set_in_channel(weight_shape[kNumDim3]);
conv->set_out_channel(weight_shape[kNumDim0]);
return lite::RET_OK;
}

lite::STATUS UpdateConv2DParamPass::UpdateDepthWiseConv2D(const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
auto conv = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(cnode->input(0));
if (conv == nullptr) {
MS_LOG(ERROR) << "cnode is invalid.";
auto format = static_cast<mindspore::Format>(GetValue<int64_t>(prim->GetAttr(ops::kFormat)));
if (format != mindspore::NHWC && format != mindspore::NCHW) {
MS_LOG(ERROR) << "conv2d's format only support nhwc or nchw, now is " << format;
return lite::RET_ERROR;
}
int64_t channel_in = conv->GetAttr(ops::kInChannel) != nullptr ? conv->get_in_channel() : -1;
if (channel_in == -1) {
auto input_node = cnode->input(kAnfPopulaterInputNumTwo);
MS_ASSERT(input_node != nullptr);
if (input_node->isa<Parameter>()) {
auto param_node = input_node->cast<ParameterPtr>();
auto param = param_node->default_param();
auto weight = std::dynamic_pointer_cast<tensor::Tensor>(param);
conv->set_in_channel(static_cast<int64_t>(weight->shape().at(0)));
}
auto kernel_size = format == mindspore::NHWC ? ShapeVector{shape[1], shape[kInputIndexTwo]}
: ShapeVector{shape[kInputIndexTwo], shape[kInputIndexThree]};
int64_t in_channel = format == mindspore::NHWC ? shape[kInputIndexThree] : shape[1];
int64_t out_channel = shape[0];
if (prim->GetAttr(ops::kGroup) == nullptr) {
bool is_depth_wise =
prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
prim->AddAttr(ops::kGroup, MakeValue(is_depth_wise ? out_channel : 1));
}
auto group = GetValue<int64_t>(prim->GetAttr(ops::kGroup));
if (CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) {
std::swap(in_channel, out_channel);
}
if (CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) {
in_channel *= group;
} else {
out_channel *= group;
}

SetConvAttr(prim, kernel_size, in_channel, out_channel);
return lite::RET_OK;
}

@@ -92,28 +104,17 @@ bool UpdateConv2DParamPass::Run(const FuncGraphPtr &func_graph) {
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
auto node_list = TopoSort(func_graph->get_return());
int status = lite::RET_OK;
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
if (!CheckPrimitiveType(node, prim::kPrimConv2DFusion)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
auto conv = GetValueNode<std::shared_ptr<mindspore::ops::Conv2DFusion>>(cnode->input(0));
if (conv == nullptr) {
MS_LOG(ERROR) << "Depthwise conv2D node has no primitiveC.";
return RET_ERROR;
}
if (conv->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(conv->GetAttr(ops::kIsDepthWise))) {
status = UpdateDepthWiseConv2D(cnode);
} else {
status = UpdateCommonConv2D(cnode);
}
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
MS_LOG(ERROR) << "update con2d failed.";
return false;
if (CheckPrimitiveType(node, prim::kPrimConv2DFusion) ||
CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) {
if (UpdateConv2DAttr(cnode) != lite::RET_OK) {
MS_LOG(ERROR) << "update conv2d attr failed.";
return false;
}
}
}
return true;


+ 3
- 8
mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.h View File

@@ -16,24 +16,19 @@

#ifndef MINDSPORE_LITE_SRC_PASS_UPDATE_CONV2D_PARAM_PASS_H_
#define MINDSPORE_LITE_SRC_PASS_UPDATE_CONV2D_PARAM_PASS_H_
#include "schema/inner/model_generated.h"
#include "backend/optimizer/common/pass.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/converter_flags.h"

using mindspore::converter::FmkType;
namespace mindspore::opt {
class UpdateConv2DParamPass : public Pass {
public:
UpdateConv2DParamPass() : Pass("update_conv2d_param_pass") {}
UpdateConv2DParamPass() : Pass("UpdateConv2DParamPass") {}
~UpdateConv2DParamPass() override = default;
lite::STATUS UpdateCommonConv2D(const CNodePtr &cnode);
static lite::STATUS UpdateDepthWiseConv2D(const CNodePtr &cnode);
bool Run(const FuncGraphPtr &graph) override;
void SetFmkType(FmkType fmk_type) { this->fmk_type_ = fmk_type; }

private:
FmkType fmk_type_ = converter::kFmkTypeOnnx;
STATUS UpdateConv2DAttr(const CNodePtr &cnode);
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_SRC_PASS_UPDATE_CONV2D_PARAM_PASS_H_

Loading…
Cancel
Save