Browse Source

!10538 [MS][LITE]Add MirrorPad Parser, Fix transpose fusion bug

From: @gongdaguo
Reviewed-by: @hangangqiang,@zhanghaibo5
Signed-off-by: @hangangqiang
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
332087a1c3
7 changed files with 79 additions and 15 deletions
  1. +60
    -6
      mindspore/lite/src/ops/pad.cc
  2. +1
    -0
      mindspore/lite/src/ops/pad.h
  3. +2
    -0
      mindspore/lite/src/ops/primitive_c.cc
  4. +0
    -4
      mindspore/lite/src/ops/while.cc
  5. +1
    -0
      mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
  6. +1
    -0
      mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc
  7. +14
    -5
      mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc

+ 60
- 6
mindspore/lite/src/ops/pad.cc View File

@@ -15,7 +15,7 @@
*/

#include "src/ops/pad.h"
#include <string>
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
@@ -32,6 +32,52 @@ void Pad::SetPaddingMode(int padding_mode) {
this->primitive_->value.AsPad()->paddingMode = (schema::PaddingMode)padding_mode;
}
void Pad::SetConstantValue(float constant_value) { this->primitive_->value.AsPad()->constantValue = constant_value; }
int Pad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Pad;
}
if (this->primitive_->value.type != schema::PrimitiveType_Pad) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::PadT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
string paddingmode = "REFLECT";
if (prim.GetAttr("mode") == nullptr) {
MS_LOG(ERROR) << "get mode failed!";
delete this->primitive_;
delete attr;
this->primitive_ = nullptr;
attr = nullptr;
return RET_ERROR;
} else {
paddingmode = GetValue<string>(prim.GetAttr("mode"));
}
if (paddingmode == "REFLECT") {
attr->paddingMode = schema::PaddingMode_REFLECT;
} else if (paddingmode == "SYMMETRIC") {
attr->paddingMode = schema::PaddingMode_SYMMETRIC;
} else {
MS_LOG(ERROR) << "model type not supported!";
delete this->primitive_;
delete attr;
this->primitive_ = nullptr;
attr = nullptr;
return RET_ERROR;
}
this->primitive_->value.value = attr;
}
return RET_OK;
}

#else

@@ -94,14 +140,22 @@ int Pad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs)
auto paddings_tensor = inputs.at(1);
int rank = static_cast<int>(inputs.front()->shape().size());
MS_ASSERT(paddings_tensor->ElementsNum() == 2 * rank);
int *paddings_data = reinterpret_cast<int *>(paddings_tensor->MutableData());
if (paddings_data == nullptr) {
if (paddings_tensor->MutableData() == nullptr) {
return RET_INFER_ERR;
}
paddings.clear();
for (auto i = 0; i < rank; ++i) {
paddings.emplace_back(paddings_data[i * 2]);
paddings.emplace_back(paddings_data[i * 2 + 1]);
if (paddings_tensor->data_type() == mindspore::kNumberTypeInt64) {
auto paddings_data = reinterpret_cast<int64_t *>(paddings_tensor->MutableData());
for (auto i = 0; i < rank; ++i) {
paddings.emplace_back(paddings_data[i * 2]);
paddings.emplace_back(paddings_data[i * 2 + 1]);
}
} else if (paddings_tensor->data_type() == mindspore::kNumberTypeInt32) {
auto paddings_data = reinterpret_cast<int32_t *>(paddings_tensor->MutableData());
for (auto i = 0; i < rank; ++i) {
paddings.emplace_back(paddings_data[i * 2]);
paddings.emplace_back(paddings_data[i * 2 + 1]);
}
}
}



+ 1
- 0
mindspore/lite/src/ops/pad.h View File

@@ -36,6 +36,7 @@ class Pad : public PrimitiveC {
void SetPaddings(const std::vector<int> &paddings);
void SetPaddingMode(int padding_mode);
void SetConstantValue(float constant_value);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif


+ 2
- 0
mindspore/lite/src/ops/primitive_c.cc View File

@@ -591,6 +591,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<Dropout>(prim, inputs, quantType);
} else if (op_type == "While") {
return NewPrimitiveC<While>(prim, inputs, quantType);
} else if (op_type == "MirrorPad") {
return NewPrimitiveC<Pad>(prim, inputs, quantType);
} else if (op_type == "GatherV2") {
return NewPrimitiveC<Gather>(prim, inputs, quantType);
} else if (op_type == "OnesLike") {


+ 0
- 4
mindspore/lite/src/ops/while.cc View File

@@ -56,10 +56,6 @@ int While::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inpu
attr->bodySubgraphIndex = GetValue<bool>(prim.GetAttr("body_subgraph_index"));
attr->condSubgraphIndex = GetValue<bool>(prim.GetAttr("cond_subgraph_index"));
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}


+ 1
- 0
mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc View File

@@ -64,6 +64,7 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
if (fp16_weight_ != nullptr) {
free(fp16_weight_);
fp16_weight_ = nullptr;
execute_weight_ = nullptr;
}

// init bias


+ 1
- 0
mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc View File

@@ -81,6 +81,7 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
if (fp16_weight_ != nullptr) {
free(fp16_weight_);
fp16_weight_ = nullptr;
execute_weight_ = nullptr;
}

// init bias


+ 14
- 5
mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc View File

@@ -194,11 +194,20 @@ STATUS GlobalFormatTransformPass::FindPreNh2NcNodes(MetaGraphT *graph, size_t nc
}
// todo multi output,other edge need insert nh2nc node
auto pre_node_output_indexs = GetOutputNodeIdx(*graph, *pre_node);
if ((pre_node_output_indexs.size() != 1) &&
(node_type == schema::PrimitiveType_Activation || node_type == schema::PrimitiveType_Concat)) {
pre_nh2nc_nodes->clear();
pre_not_trans_nodes->clear();
return RET_OK;
if (pre_node_output_indexs.size() != 1) {
if (node_type == schema::PrimitiveType_Activation || node_type == schema::PrimitiveType_Concat) {
pre_nh2nc_nodes->clear();
pre_not_trans_nodes->clear();
return RET_OK;
}
for (auto pre_node_output_index : pre_node_output_indexs) {
MS_ASSERT(graph->nodes.size() > pre_node_output_index);
if (graph->nodes.at(pre_node_output_index)->primitive->value.type == schema::PrimitiveType_Pad) {
pre_nh2nc_nodes->clear();
pre_not_trans_nodes->clear();
return RET_OK;
}
}
}
} else {
pre_nh2nc_nodes->clear();


Loading…
Cancel
Save