From 3f659076048fe9ad21406cf9c126eb29df287774 Mon Sep 17 00:00:00 2001
From: gongdaguo
Date: Thu, 24 Dec 2020 20:47:32 +0800
Subject: [PATCH] add mirrorpad and fix transpose bug

---
 mindspore/lite/src/ops/pad.cc                 | 66 +++++++++++++++++--
 mindspore/lite/src/ops/pad.h                  |  1 +
 mindspore/lite/src/ops/primitive_c.cc         |  2 +
 mindspore/lite/src/ops/while.cc               |  4 --
 .../kernel/arm/fp16/convolution_fp16.cc       |  1 +
 .../arm/fp16/convolution_winograd_fp16.cc     |  1 +
 .../graph/global_format_transform_pass.cc     | 19 ++++--
 7 files changed, 79 insertions(+), 15 deletions(-)

diff --git a/mindspore/lite/src/ops/pad.cc b/mindspore/lite/src/ops/pad.cc
index f9ac5ae519..4473934538 100644
--- a/mindspore/lite/src/ops/pad.cc
+++ b/mindspore/lite/src/ops/pad.cc
@@ -15,7 +15,7 @@
  */
 
 #include "src/ops/pad.h"
-
+#include <string>
 #ifndef PRIMITIVE_WRITEABLE
 #include "src/ops/ops_register.h"
 #endif
@@ -32,6 +32,52 @@ void Pad::SetPaddingMode(int padding_mode) {
   this->primitive_->value.AsPad()->paddingMode = (schema::PaddingMode)padding_mode;
 }
 void Pad::SetConstantValue(float constant_value) { this->primitive_->value.AsPad()->constantValue = constant_value; }
+int Pad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_Pad;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_Pad) {
+    MS_LOG(ERROR) << "Primitive type is error: " << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    auto attr = new (std::nothrow) schema::PadT();
+    if (attr == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+    string paddingmode = "REFLECT";
+    if (prim.GetAttr("mode") == nullptr) {
+      MS_LOG(ERROR) << "get mode failed!";
+      delete this->primitive_;
+      delete attr;
+      this->primitive_ = nullptr;
+      attr = nullptr;
+      return RET_ERROR;
+    } else {
+      paddingmode = GetValue<string>(prim.GetAttr("mode"));
+    }
+    if (paddingmode == "REFLECT") {
+      attr->paddingMode = schema::PaddingMode_REFLECT;
+    } else if (paddingmode == "SYMMETRIC") {
+      attr->paddingMode = schema::PaddingMode_SYMMETRIC;
+    } else {
+      MS_LOG(ERROR) << "padding mode not supported!";
+      delete this->primitive_;
+      delete attr;
+      this->primitive_ = nullptr;
+      attr = nullptr;
+      return RET_ERROR;
+    }
+    this->primitive_->value.value = attr;
+  }
+  return RET_OK;
+}
 
 #else
 
@@ -94,14 +140,22 @@ int Pad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs
   auto paddings_tensor = inputs.at(1);
   int rank = static_cast<int>(inputs.front()->shape().size());
   MS_ASSERT(paddings_tensor->ElementsNum() == 2 * rank);
-  int *paddings_data = reinterpret_cast<int *>(paddings_tensor->MutableData());
-  if (paddings_data == nullptr) {
+  if (paddings_tensor->MutableData() == nullptr) {
     return RET_INFER_ERR;
   }
   paddings.clear();
-  for (auto i = 0; i < rank; ++i) {
-    paddings.emplace_back(paddings_data[i * 2]);
-    paddings.emplace_back(paddings_data[i * 2 + 1]);
+  if (paddings_tensor->data_type() == mindspore::kNumberTypeInt64) {
+    auto paddings_data = reinterpret_cast<int64_t *>(paddings_tensor->MutableData());
+    for (auto i = 0; i < rank; ++i) {
+      paddings.emplace_back(paddings_data[i * 2]);
+      paddings.emplace_back(paddings_data[i * 2 + 1]);
+    }
+  } else if (paddings_tensor->data_type() == mindspore::kNumberTypeInt32) {
+    auto paddings_data = reinterpret_cast<int *>(paddings_tensor->MutableData());
+    for (auto i = 0; i < rank; ++i) {
+      paddings.emplace_back(paddings_data[i * 2]);
+      paddings.emplace_back(paddings_data[i * 2 + 1]);
+    }
   }
 }
diff --git a/mindspore/lite/src/ops/pad.h b/mindspore/lite/src/ops/pad.h
index d7d1348e46..af18c746a5 100644
--- a/mindspore/lite/src/ops/pad.h
+++ b/mindspore/lite/src/ops/pad.h
@@ -36,6 +36,7 @@ class Pad : public PrimitiveC {
   void SetPaddings(const std::vector<int> &paddings);
   void SetPaddingMode(int padding_mode);
   void SetConstantValue(float constant_value);
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
 #else
   int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
 #endif
diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc
index 056146919e..f94c0cbdde 100644
--- a/mindspore/lite/src/ops/primitive_c.cc
+++ b/mindspore/lite/src/ops/primitive_c.cc
@@ -571,6 +571,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
   } else if (op_type == "While") {
     return NewPrimitiveC<While>(prim, inputs, quantType);
+  } else if (op_type == "MirrorPad") {
+    return NewPrimitiveC<Pad>(prim, inputs, quantType);
   } else if (op_type == "GatherV2") {
     return NewPrimitiveC<Gather>(prim, inputs, quantType);
   } else if (op_type == "OnesLike") {
diff --git a/mindspore/lite/src/ops/while.cc b/mindspore/lite/src/ops/while.cc
index 31ee5068c4..60b1537942 100644
--- a/mindspore/lite/src/ops/while.cc
+++ b/mindspore/lite/src/ops/while.cc
@@ -56,10 +56,6 @@ int While::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
     attr->bodySubgraphIndex = GetValue<int64_t>(prim.GetAttr("body_subgraph_index"));
     attr->condSubgraphIndex = GetValue<int64_t>(prim.GetAttr("cond_subgraph_index"));
     this->primitive_->value.value = attr;
-    if (this->primitive_->value.value == nullptr) {
-      MS_LOG(ERROR) << "primitive value is nullptr";
-      return RET_ERROR;
-    }
   }
   return RET_OK;
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
index 07a3d66922..88e1b21a0d 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
@@ -64,6 +64,7 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
   if (fp16_weight_ != nullptr) {
     free(fp16_weight_);
     fp16_weight_ = nullptr;
+    execute_weight_ = nullptr;
   }
 
   // init bias
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc
index 21dd2ae16f..48d9111aca 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc
@@ -81,6 +81,7 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
   if (fp16_weight_ != nullptr) {
     free(fp16_weight_);
     fp16_weight_ = nullptr;
+    execute_weight_ = nullptr;
   }
 
   // init bias
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc
index 188366dbf8..27ab8168b5 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc
@@ -194,11 +194,20 @@ STATUS GlobalFormatTransformPass::FindPreNh2NcNodes(MetaGraphT *graph, size_t nc
     }
     // todo multi output,other edge need insert nh2nc node
     auto pre_node_output_indexs = GetOutputNodeIdx(*graph, *pre_node);
-    if ((pre_node_output_indexs.size() != 1) &&
-        (node_type == schema::PrimitiveType_Activation || node_type == schema::PrimitiveType_Concat)) {
-      pre_nh2nc_nodes->clear();
-      pre_not_trans_nodes->clear();
-      return RET_OK;
+    if (pre_node_output_indexs.size() != 1) {
+      if (node_type == schema::PrimitiveType_Activation || node_type == schema::PrimitiveType_Concat) {
+        pre_nh2nc_nodes->clear();
+        pre_not_trans_nodes->clear();
+        return RET_OK;
+      }
+      for (auto pre_node_output_index : pre_node_output_indexs) {
+        MS_ASSERT(graph->nodes.size() > pre_node_output_index);
+        if (graph->nodes.at(pre_node_output_index)->primitive->value.type == schema::PrimitiveType_Pad) {
+          pre_nh2nc_nodes->clear();
+          pre_not_trans_nodes->clear();
+          return RET_OK;
+        }
+      }
     }
   } else {
     pre_nh2nc_nodes->clear();
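
A note for reviewers on the pad.cc InferShape hunk: the old code always reinterpreted the paddings buffer as int *, so a paddings tensor stored as int64 was read as garbage; the fixed code dispatches on data_type() first. Below is a minimal standalone sketch of that pattern. ReadPaddings and the DType tag are illustrative stand-ins, not the MindSpore tensor API:

#include <cstdint>
#include <vector>

// Illustrative stand-in for the tensor dtype tag; not the MindSpore API.
enum class DType { kInt32, kInt64 };

// Mirrors the fixed InferShape logic: pick the element type from the
// tensor's dtype before reading 2 * rank padding values, rather than
// unconditionally casting the raw buffer to int *.
std::vector<int> ReadPaddings(const void *data, DType dtype, int rank) {
  std::vector<int> paddings;
  paddings.reserve(2 * rank);
  if (dtype == DType::kInt64) {
    auto p = static_cast<const int64_t *>(data);
    for (int i = 0; i < 2 * rank; ++i) {
      paddings.push_back(static_cast<int>(p[i]));  // per-element narrowing
    }
  } else {  // DType::kInt32
    auto p = static_cast<const int32_t *>(data);
    for (int i = 0; i < 2 * rank; ++i) {
      paddings.push_back(p[i]);
    }
  }
  return paddings;
}

int main() {
  const int64_t raw[4] = {1, 1, 2, 2};  // paddings for a rank-2 tensor
  auto paddings = ReadPaddings(raw, DType::kInt64, 2);
  return paddings.size() == 4 ? 0 : 1;
}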
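
On the two one-line kernel fixes (convolution_fp16.cc and convolution_winograd_fp16.cc): they null out execute_weight_ alongside fp16_weight_. The likely rationale, inferred from the patch, is that execute_weight_ can point at the fp16_weight_ buffer, so freeing the buffer while clearing only fp16_weight_ leaves a dangling alias. A sketch of that hazard under this assumption; only the two field names come from the patch, the surrounding class and functions are hypothetical:

#include <cstdlib>

// Hypothetical kernel fragment illustrating the alias hazard.
struct ConvKernelSketch {
  float *fp16_weight_ = nullptr;     // owned buffer (float16 in the real kernels)
  float *execute_weight_ = nullptr;  // may alias fp16_weight_

  void AllocWeights(int n) {
    fp16_weight_ = static_cast<float *>(std::malloc(n * sizeof(float)));
    execute_weight_ = fp16_weight_;  // alias established
  }

  void FreeWeights() {
    if (fp16_weight_ != nullptr) {
      std::free(fp16_weight_);
      fp16_weight_ = nullptr;
      execute_weight_ = nullptr;  // the added line: without it the alias dangles
    }
  }
};

int main() {
  ConvKernelSketch k;
  k.AllocWeights(16);
  k.FreeWeights();  // leaves no dangling execute_weight_ behind
  return k.execute_weight_ == nullptr ? 0 : 1;
}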