From f640ca25c962660a4f43c2fa31f7d2b572a834dc Mon Sep 17 00:00:00 2001
From: zhengjun10
Date: Fri, 19 Feb 2021 20:32:59 +0800
Subject: [PATCH] remove convert support_train define

---
 mindspore/lite/src/ops/addn.cc                     |  10 --
 mindspore/lite/src/ops/conv2d.cc                   |  14 +--
 mindspore/lite/src/ops/pad.cc                      |   9 --
 mindspore/lite/src/ops/primitive_c.cc              |  26 ++--
 mindspore/lite/src/ops/primitive_c.h               |  12 +-
 mindspore/lite/src/ops/tile.cc                     |  62 ++++-----
 .../lite/tools/anf_exporter/anf_exporter.cc        | 119 ++++++++----------
 .../lite/tools/anf_exporter/anf_exporter.h         |   7 +-
 .../tools/anf_importer/import_from_mindir.cc       |  14 +--
 mindspore/lite/tools/common/node_util.cc           |  73 +++++------
 mindspore/lite/tools/common/node_util.h            |   2 +
 mindspore/lite/tools/converter/CMakeLists.txt      |   6 -
 .../lite/tools/converter/anf_transform.cc          |   1 +
 mindspore/lite/tools/converter/converter.cc        |   2 +-
 .../graph/format_trans_pass.cc                     |  52 +++-----
 .../graph/global_format_transform_pass.cc          |   1 -
 .../graph/trans_format_insert_pass.cc              |  87 ++++---------
 .../tools/optimizer/graph/infershape_pass.cc       |  16 ++-
 .../optimizer/graph/mindir_adjust_pass.cc          |   2 +-
 .../optimizer/graph/mindir_adjust_pass.h           |   2 +
 .../graph/weight_format_hardcode_pass.cc           |   4 -
 .../graph/weight_format_transform_pass.cc          |   8 +-
 22 files changed, 218 insertions(+), 311 deletions(-)

diff --git a/mindspore/lite/src/ops/addn.cc b/mindspore/lite/src/ops/addn.cc
index 387c38b62f..26e244cb28 100644
--- a/mindspore/lite/src/ops/addn.cc
+++ b/mindspore/lite/src/ops/addn.cc
@@ -114,16 +114,6 @@ int AddN::InferShape(std::vector inputs, std::vector outputs
       max_dim = dim;
     }
   }
-#ifndef SUPPORT_TRAIN
-    for (size_t i = 0; i < inputs.size(); ++i) {
-      size_t shift = max_dims - inputs.at(i)->shape().size();
-      size_t dim = (i < shift) ? 1 : inputs.at(i)->shape().at(d);
-      if ((dim != max_dim) && (dim != 1)) {
-        MS_LOG(ERROR) << "AddN inputs shape is not equal!";
-        return RET_INPUT_TENSOR_ERROR;
-      }
-    }
-#endif
     output->shape()[d] = max_dim;  // set the biggest dimension in the output tensor
   }
diff --git a/mindspore/lite/src/ops/conv2d.cc b/mindspore/lite/src/ops/conv2d.cc
index 5135891cea..c493586503 100644
--- a/mindspore/lite/src/ops/conv2d.cc
+++ b/mindspore/lite/src/ops/conv2d.cc
@@ -149,13 +149,13 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
   attr->padRight = pad_list.at(3);
   auto dilation = CastToInt(prim.GetAttr("dilation"));
-#ifdef SUPPORT_TRAIN
-  attr->dilateH = dilation.at(2);
-  attr->dilateW = dilation.at(3);
-#else
-  attr->dilateH = dilation.at(0);
-  attr->dilateW = dilation.at(1);
-#endif
+  if (train_flag()) {
+    attr->dilateH = dilation.at(2);
+    attr->dilateW = dilation.at(3);
+  } else {
+    attr->dilateH = dilation.at(0);
+    attr->dilateW = dilation.at(1);
+  }
   auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
   attr->kernelH = kernel_size.at(0);
   attr->kernelW = (kernel_size.size() > 1) ?
kernel_size.at(1) : kernel_size.at(0); diff --git a/mindspore/lite/src/ops/pad.cc b/mindspore/lite/src/ops/pad.cc index 7aca1217b9..951b9d4f5f 100644 --- a/mindspore/lite/src/ops/pad.cc +++ b/mindspore/lite/src/ops/pad.cc @@ -19,9 +19,6 @@ #ifndef PRIMITIVE_WRITEABLE #include "src/ops/ops_register.h" #endif -#ifdef SUPPORT_TRAIN -#include -#endif namespace mindspore { namespace lite { @@ -56,20 +53,16 @@ int Pad::UnPackAttr(const Primitive &prim, const std::vector &inputs } string paddingmode = "REFLECT"; if (prim.GetAttr("mode") == nullptr) { -#ifdef SUPPORT_TRAIN if (prim.name() == "Pad") { paddingmode = "CONSTANT"; } else { -#endif MS_LOG(ERROR) << "get mode failed!"; delete this->primitive_; delete attr; this->primitive_ = nullptr; attr = nullptr; return RET_ERROR; -#ifdef SUPPORT_TRAIN } -#endif } else { paddingmode = GetValue(prim.GetAttr("mode")); } @@ -77,7 +70,6 @@ int Pad::UnPackAttr(const Primitive &prim, const std::vector &inputs attr->paddingMode = schema::PaddingMode_REFLECT; } else if (paddingmode == "SYMMETRIC") { attr->paddingMode = schema::PaddingMode_SYMMETRIC; -#ifdef SUPPORT_TRAIN } else if (paddingmode == "CONSTANT") { attr->paddingMode = schema::PaddingMode_CONSTANT; if (prim.GetAttr("paddings") != nullptr) { @@ -91,7 +83,6 @@ int Pad::UnPackAttr(const Primitive &prim, const std::vector &inputs attr->paddings.push_back(i); } } -#endif } else { MS_LOG(ERROR) << "model type not supported!"; delete this->primitive_; diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 7125b11872..8aaca1bbd8 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -18,7 +18,6 @@ #ifdef PRIMITIVE_WRITEABLE #include #include - #include "tools/converter/quantizer/quantize_util.h" #include "src/ops/assert_op.h" #include "src/ops/space_to_batch.h" @@ -175,8 +174,6 @@ #include "src/ops/uniform_real.h" #include "src/ops/rank.h" #include "src/ops/is_finite.h" - -#ifdef SUPPORT_TRAIN #include "src/ops/neg_grad.h" #include "src/ops/activation_grad.h" #include "src/ops/apply_momentum.h" @@ -210,7 +207,6 @@ #include "src/ops/sigmoid_cross_entropy_with_logits_grad.h" #include "src/ops/strided_slice_grad.h" #endif -#endif namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -513,13 +509,14 @@ std::shared_ptr GetTupleGetItemPrim() { template ::value>> std::shared_ptr NewPrimitiveC(const mindspore::Primitive &prim, const std::vector &inputs, - const schema::QuantType &quantType) { + const schema::QuantType &quantType, bool train_flag = false) { auto primc = std::make_shared(); if (primc == nullptr) { MS_LOG(ERROR) << "make_shared PrimitiveC failed"; return nullptr; } primc->set_quant_type(quantType); + primc->set_train_flag(train_flag); auto ret = primc->UnPackAttr(prim, inputs); if (ret != RET_OK) { MS_LOG(ERROR) << "UnPackAttr failed"; @@ -529,7 +526,7 @@ std::shared_ptr NewPrimitiveC(const mindspore::Primitive &prim, cons } std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std::vector &inputs, - const schema::QuantType &quantType) { + const schema::QuantType &quantType, bool train_flag) { const auto &op_type = prim.name(); if (op_type == "ReLU" || op_type == "ReLU6" || op_type == "Sigmoid" || op_type == "HSwish" || op_type == "HSigmoid") { return NewPrimitiveC(prim, inputs, quantType); @@ -544,7 +541,7 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: } else if (op_type == "Concat") { return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Conv2D") { - 
return NewPrimitiveC(prim, inputs, quantType); + return NewPrimitiveC(prim, inputs, quantType, train_flag); } else if (op_type == "Cos") { return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "DepthwiseConv2dNative" || op_type == "DepthwiseConv2D") { @@ -664,7 +661,7 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: } else if (op_type == "Range") { return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Tile") { - return NewPrimitiveC(prim, inputs, quantType); + return NewPrimitiveC(prim, inputs, quantType, train_flag); } else if (op_type == "GatherNd") { return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Square") { @@ -685,7 +682,6 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Gelu") { return NewPrimitiveC(prim, inputs, quantType); -#ifdef SUPPORT_TRAIN } else if (op_type == "SoftmaxCrossEntropyWithLogits") { return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "SparseSoftmaxCrossEntropyWithLogits") { @@ -706,7 +702,7 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Conv2DBackpropFilter") { return NewPrimitiveC(prim, inputs, quantType); - } else if (op_type == "Conv2DBackpropInput") { + } else if (op_type == "Conv2DBackpropInput" && train_flag) { return NewPrimitiveC(prim, inputs, quantType); } else if ((op_type == "BatchNormGrad") || (op_type == "FusedBatchNormGradEx")) { return NewPrimitiveC(prim, inputs, quantType); @@ -748,10 +744,8 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "AbsGrad") { return NewPrimitiveC(prim, inputs, quantType); -#else - } else if (op_type == "Conv2DBackpropInput") { + } else if (op_type == "Conv2DBackpropInput" && !train_flag) { return NewPrimitiveC(prim, inputs, quantType); -#endif } else { MS_LOG(ERROR) << "Unsupported primitive type in Create : " << op_type; return nullptr; @@ -1065,7 +1059,6 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new (std::nothrow) UniformReal(primitive); case schema::PrimitiveType_Rank: return new (std::nothrow) Rank(primitive); -#ifdef SUPPORT_TRAIN case schema::PrimitiveType_ActivationGrad: return new (std::nothrow) ActivationGrad(primitive); case schema::PrimitiveType_PoolingGrad: @@ -1140,7 +1133,6 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new (std::nothrow) SigmoidCrossEntropyWithLogitsGrad(primitive); case schema::PrimitiveType_StridedSliceGrad: return new (std::nothrow) StridedSliceGrad(primitive); -#endif default: MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type); break; @@ -1170,6 +1162,10 @@ bool PrimitiveC::infer_flag() const { return this->infer_flag_; } void PrimitiveC::set_infer_flag(bool flag) { this->infer_flag_ = flag; } +bool PrimitiveC::train_flag() const { return this->train_flag_; } + +void PrimitiveC::set_train_flag(bool flag) { this->train_flag_ = flag; } + int PrimitiveC::InferShape(std::vector inputs, std::vector outputs) { auto input = inputs.front(); MS_ASSERT(input != nullptr); diff --git a/mindspore/lite/src/ops/primitive_c.h b/mindspore/lite/src/ops/primitive_c.h index f972607ab2..469992a7b9 100644 --- a/mindspore/lite/src/ops/primitive_c.h +++ b/mindspore/lite/src/ops/primitive_c.h @@ -133,6 +133,10 @@ 
class PrimitiveC : public mindspore::Primitive { void set_infer_flag(bool flag); + bool train_flag() const; + + void set_train_flag(bool flag); + static PrimitiveC *Create(mindspore::schema::Primitive *primitive) { return Create(primitive->UnPack()); } static PrimitiveC *Create(mindspore::schema::PrimitiveT *primitive); @@ -140,7 +144,7 @@ class PrimitiveC : public mindspore::Primitive { static void GetAttrDataFromInput(const AnfNodePtr &inputNode, std::vector *data); static std::shared_ptr Create(const Primitive &prim, const std::vector &inputs, - const schema::QuantType &quantType); + const schema::QuantType &quantType, bool train_flag = false); void PopulaterQuantParam(const Primitive &prim, const std::vector &inputs); void FillDefaultInputQuantParamIfNeed(const size_t &inputSize); void PopulaterInputQuantParam(const Primitive &prim, const std::vector &inputs, @@ -159,6 +163,7 @@ class PrimitiveC : public mindspore::Primitive { bool infer_flag_ = true; int op_type_ = OP_TYPE_NOT_SET; bool enable_huffman_code_ = false; + bool train_flag_ = false; }; std::shared_ptr GetReturnPrim(); @@ -179,6 +184,10 @@ class PrimitiveC { void set_infer_flag(bool flag); + bool train_flag() const; + + void set_train_flag(bool flag); + virtual int InferShape(std::vector inputs, std::vector outputs); int Type() const; @@ -238,6 +247,7 @@ class PrimitiveC { bool infer_flag_ = true; schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; int op_type_ = OP_TYPE_NOT_SET; + bool train_flag_ = false; }; using PrimitiveCPtr = std::shared_ptr; typedef PrimitiveC *(*PrimitiveCCreator)(const schema::Primitive *primitive); diff --git a/mindspore/lite/src/ops/tile.cc b/mindspore/lite/src/ops/tile.cc index 4c27c6a2e8..2b6a784a64 100644 --- a/mindspore/lite/src/ops/tile.cc +++ b/mindspore/lite/src/ops/tile.cc @@ -159,41 +159,41 @@ int Tile::InferShape(std::vector inputs_, std::vector output } else { multiples = GetMultiples(); } -#ifdef SUPPORT_TRAIN - const size_t in_dims = input->shape().size(); - const size_t delta_dims = in_dims - multiples.size(); - - size_t i = 0; - for (; i < delta_dims; ++i) { - int tmp = input->shape().at(i); - out_shape.push_back(tmp); - } - for (; i < in_dims; ++i) { - int tmp = input->shape().at(i) * (multiples[i - delta_dims]); - out_shape.push_back(tmp); - } -#else - std::vector dims = GetDims(); - if (inputs_.size() == 2 && dims.empty()) { - for (int dim = 0; dim < inputs_[1]->ElementsNum(); ++dim) { - dims.push_back(dim); + if (train_flag()) { + const size_t in_dims = input->shape().size(); + const size_t delta_dims = in_dims - multiples.size(); + + size_t i = 0; + for (; i < delta_dims; ++i) { + int tmp = input->shape().at(i); + out_shape.push_back(tmp); } - } - const size_t in_dims = input->shape().size(); + for (; i < in_dims; ++i) { + int tmp = input->shape().at(i) * (multiples[i - delta_dims]); + out_shape.push_back(tmp); + } + } else { + std::vector dims = GetDims(); + if (inputs_.size() == 2 && dims.empty()) { + for (int dim = 0; dim < inputs_[1]->ElementsNum(); ++dim) { + dims.push_back(dim); + } + } + const size_t in_dims = input->shape().size(); - MS_ASSERT(multiples.size() == dims.size()); - for (size_t i = 0; i < in_dims; ++i) { - out_shape.push_back(input->shape().at(i)); - } - for (size_t i = 0; i < dims.size(); ++i) { - if (input->shape().at(dims.at(i)) != 0 && - multiples.at(i) > std::numeric_limits::max() / input->shape().at(dims.at(i))) { - MS_LOG(ERROR) << "The value of multiples[" << i << "] is too big"; - return RET_ERROR; + MS_ASSERT(multiples.size() == 
dims.size()); + for (size_t i = 0; i < in_dims; ++i) { + out_shape.push_back(input->shape().at(i)); + } + for (size_t i = 0; i < dims.size(); ++i) { + if (input->shape().at(dims.at(i)) != 0 && + multiples.at(i) > std::numeric_limits::max() / input->shape().at(dims.at(i))) { + MS_LOG(ERROR) << "The value of multiples[" << i << "] is too big"; + return RET_ERROR; + } + out_shape.at(dims.at(i)) = input->shape().at(dims.at(i)) * (multiples.at(i)); } - out_shape.at(dims.at(i)) = input->shape().at(dims.at(i)) * (multiples.at(i)); } -#endif output->set_shape(out_shape); return RET_OK; } diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index c2002ea112..6420fd13f3 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -328,15 +328,15 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptrType() == schema::PrimitiveType_Depend || + primitive_c->Type() == schema::PrimitiveType_ControlDepend) { + continue; + } + } if ((primitive_c->Type() == schema::PrimitiveType_TupleGetItem) || -#ifdef SUPPORT_TRAIN - (primitive_c->Type() == schema::PrimitiveType_Depend) || - (primitive_c->Type() == schema::PrimitiveType_ControlDepend) || -#endif (primitive_c->Type() == schema::PrimitiveType_MakeTuple)) { continue; } @@ -424,8 +424,10 @@ int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::uniqu return RET_OK; } -schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) { +schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive, + bool train_flag) { static int subgraph_index = 0; + this->train_flag = train_flag; auto meta_graphT = std::make_unique(); int ret = ExportSubgraph(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive); if (ret != RET_OK) { @@ -439,24 +441,18 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr &input_anode, std::string input_name = input_anode->fullname_with_scope(); auto input_cnode = utils::cast(input_anode); if (!IsPrimitiveCNode(input_cnode, schema::PrimitiveType_TupleGetItem)) { -#ifndef SUPPORT_TRAIN - if (node_id_map_.find(input_name) != node_id_map_.end()) { - output_cnode->inputIndex.emplace_back(node_id_map_[input_name]); - } -#else bool found = false; if (node_id_map_.find(input_name) != node_id_map_.end()) { output_cnode->inputIndex.emplace_back(node_id_map_[input_name]); found = true; } - if (found == false) { + if (!found) { auto input_index_key = input_name + "_o:" + std::to_string(0); if (node_id_map_.find(input_index_key) != node_id_map_.end()) { output_cnode->inputIndex.emplace_back(node_id_map_[input_index_key]); } } -#endif } else { auto inputs = input_cnode->inputs(); @@ -481,17 +477,12 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr &input_anode, : GetValue(value_node->value())); auto iter = node_id_map_.find(input_index_key); if (iter == node_id_map_.end()) { -#ifdef SUPPORT_TRAIN input_index_key = get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(0); // try name with 0 iter = node_id_map_.find(input_index_key); if (iter == node_id_map_.end()) { MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key; return RET_ERROR; } -#else - MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key; - return RET_ERROR; -#endif } output_cnode->inputIndex.emplace_back(iter->second); } @@ -571,9 +562,7 @@ int 
AnfExporter::ProcessTensor(const ValueNodePtr &valueNode, std::unique_ptr(value); }); (*paramTensor)->dims = dims; -#ifdef SUPPORT_TRAIN - if ((*paramTensor)->dims.size() == 0) (*paramTensor)->dims = {1}; -#endif + if (train_flag && (*paramTensor)->dims.empty()) (*paramTensor)->dims = {1}; (*paramTensor)->nodeType = schema::NodeType::NodeType_ValueNode; auto data = value->cast(); (*paramTensor)->data.resize(data->Size()); @@ -679,11 +668,11 @@ int AnfExporter::ProcessParamValueLite(const ValueNodePtr &valueNode, std::uniqu (*paramTensor)->format = schema::Format(valueLite->format()); (*paramTensor)->dataType = valueLite->tensor_type(); (*paramTensor)->dims = valueLite->tensor_shape(); -#ifdef SUPPORT_TRAIN - if ((*paramTensor)->dims.size() == 0) { + + if (train_flag && (*paramTensor)->dims.empty()) { (*paramTensor)->dims = {1}; } -#endif + ret = memcpy_s((*paramTensor)->data.data(), valueLite->tensor_size() * sizeof(uint8_t), valueLite->tensor_addr(), valueLite->tensor_size()); if (ret != EOK) { @@ -703,9 +692,9 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr &input_ano auto paramTensor = std::make_unique(); auto value = valueNode->value(); int ret = RET_OK; -#ifdef SUPPORT_TRAIN - paramTensor->name = valueNode->fullname_with_scope(); -#endif + if (train_flag) { + paramTensor->name = valueNode->fullname_with_scope(); + } if (value->isa()) { ret = ProcessTensor(valueNode, ¶mTensor, value, output_cnode, meta_graphT); } else if (value->isa() || value->isa()) { @@ -797,44 +786,44 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptrnodeType = schema::NodeType_CNode; fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size()); -#ifdef SUPPORT_TRAIN - std::string name = cnode_name + "_o:" + std::to_string(i); - node_id_map_[name] = meta_graphT->allTensors.size(); - meta_graphT->allTensors.emplace_back(msTensor); - if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) || - IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) || - IsPrimitiveCNode(cnode, schema::PrimitiveType_Adam)) - break; -#else - if (elements.size() == 1) { - node_id_map_[cnode_name] = meta_graphT->allTensors.size(); - msTensor->name = cnode_name; - } else { + if (train_flag) { std::string name = cnode_name + "_o:" + std::to_string(i); node_id_map_[name] = meta_graphT->allTensors.size(); - msTensor->name = name; - } + meta_graphT->allTensors.emplace_back(msTensor); + if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) || + IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) || + IsPrimitiveCNode(cnode, schema::PrimitiveType_Adam)) + break; + } else { + if (elements.size() == 1) { + node_id_map_[cnode_name] = meta_graphT->allTensors.size(); + msTensor->name = cnode_name; + } else { + std::string name = cnode_name + "_o:" + std::to_string(i); + node_id_map_[name] = meta_graphT->allTensors.size(); + msTensor->name = name; + } - if (!utils::isa(elements[i])) { - MS_LOG(ERROR) << "abstract is not AbstractTensor"; - delete (msTensor); - return; - } - auto type = kNumberTypeFloat32; - if (utils::isa(elements[i])) { - auto abstract_tensor = utils::cast(elements[i]); - auto typePtr = abstract_tensor->element()->GetTypeTrack(); - type = typePtr->type_id(); - } - msTensor->dataType = type; - meta_graphT->allTensors.emplace_back(msTensor); - if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) || - IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) || - IsPrimitiveCNode(cnode, schema::PrimitiveType_FusedBatchNorm) || - IsPrimitiveCNode(cnode, 
schema::PrimitiveType_LayerNorm)) {
-      break;
+        if (!utils::isa(elements[i])) {
+          MS_LOG(ERROR) << "abstract is not AbstractTensor";
+          delete (msTensor);
+          return;
+        }
+        auto type = kNumberTypeFloat32;
+        if (utils::isa(elements[i])) {
+          auto abstract_tensor = utils::cast(elements[i]);
+          auto typePtr = abstract_tensor->element()->GetTypeTrack();
+          type = typePtr->type_id();
+        }
+        msTensor->dataType = type;
+        meta_graphT->allTensors.emplace_back(msTensor);
+        if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) ||
+            IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) ||
+            IsPrimitiveCNode(cnode, schema::PrimitiveType_FusedBatchNorm) ||
+            IsPrimitiveCNode(cnode, schema::PrimitiveType_LayerNorm)) {
+          break;
+        }
       }
-#endif
     }
   } else {
     auto ms_tensor = new (std::nothrow) schema::TensorT();
@@ -927,8 +916,8 @@ CNodePtr AnfExporter::CreatePartialCnode(const FuncGraphPtr &fg, AnfNodePtr node
   }
 }
-schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) {
+schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive, bool train_flag) {
   AnfExporter anf_exporter;
-  return anf_exporter.Export(func_graph, keep_graph, copy_primitive);
+  return anf_exporter.Export(func_graph, keep_graph, copy_primitive, train_flag);
 }
 }  // namespace mindspore::lite
diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.h b/mindspore/lite/tools/anf_exporter/anf_exporter.h
index 4a7aaecb32..2496ae3c70 100644
--- a/mindspore/lite/tools/anf_exporter/anf_exporter.h
+++ b/mindspore/lite/tools/anf_exporter/anf_exporter.h
@@ -35,7 +35,8 @@ class AnfExporter {
  public:
   AnfExporter() = default;
   virtual ~AnfExporter() = default;
-  schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false);
+  schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false,
+                             bool train_flag = false);
   void SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr &meta_graphT, schema::CNodeT *fb_node);
   int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr &meta_graphT,
@@ -91,11 +92,13 @@ class AnfExporter {
   std::vector graph_input_nodes_;
   std::map fg_subgraph_map;
   uint32_t node_idx = 0;
+  bool train_flag = false;
 };
 // by default, copy_primitive is false, which means that the MetaGraph and func_graph share the same schema::PrimitiveT.
 // but in PostQuantization, the func_graph needs to be transferred to a MetaGraph first to run the MetaGraph passes, which may modify
 // the schema::PrimitiveT and cause bugs; if all the passes had been done in func_graph, everything would be simple
 // and clear.
-schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false); +schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false, + bool train_flag = false); } // namespace mindspore::lite #endif // MINDSPORE_LITE_TOOLS_ANF_EXPORTER_ANF_EXPORTER_H_ diff --git a/mindspore/lite/tools/anf_importer/import_from_mindir.cc b/mindspore/lite/tools/anf_importer/import_from_mindir.cc index 63a01b9caa..3720ec91bb 100644 --- a/mindspore/lite/tools/anf_importer/import_from_mindir.cc +++ b/mindspore/lite/tools/anf_importer/import_from_mindir.cc @@ -855,14 +855,14 @@ int AnfImporterFromMindir::ParseModelConfigureInfo(const onnx::ModelProto &model } int AnfImporterFromMindir::Import(const converter::Flags *flag) { -#if SUPPORT_TRAIN - func_graph_ = LoadMindIR(flag->modelFile, true); - if (func_graph_ != nullptr) { - return RET_OK; - } else { - MS_LOG(ERROR) << "Parse new mind_ir proto failed, Trying old onnx format"; + if (flag->trainModel) { + func_graph_ = LoadMindIR(flag->modelFile, true); + if (func_graph_ != nullptr) { + return RET_OK; + } else { + MS_LOG(ERROR) << "Parse new mind_ir proto failed, Trying old onnx format"; + } } -#endif onnx_model_ = ReadOnnxFromBinary(flag->modelFile); if (onnx_model_ == nullptr) { MS_LOG(DEBUG) << "Parse model failed, which is not an old mindir model"; diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc index 4739a82173..c32e585ff0 100644 --- a/mindspore/lite/tools/common/node_util.cc +++ b/mindspore/lite/tools/common/node_util.cc @@ -24,39 +24,40 @@ namespace mindspore { namespace lite { -static const std::vector nhwcOpList = { -#ifdef SUPPORT_TRAIN - schema::PrimitiveType_Conv2DGradFilter, - schema::PrimitiveType_Conv2DGradInput, - schema::PrimitiveType_GroupConv2DGradInput, - schema::PrimitiveType_PoolingGrad, - schema::PrimitiveType_BiasGrad, - schema::PrimitiveType_BNGrad, - schema::PrimitiveType_ApplyMomentum, - schema::PrimitiveType_Sgd, - schema::PrimitiveType_Adam, -#endif - schema::PrimitiveType_Conv2D, - schema::PrimitiveType_DeConv2D, - schema::PrimitiveType_DepthwiseConv2D, - schema::PrimitiveType_DeDepthwiseConv2D, - schema::PrimitiveType_Pooling, - schema::PrimitiveType_LocalResponseNormalization, - schema::PrimitiveType_Resize, - schema::PrimitiveType_BatchNorm, - schema::PrimitiveType_FusedBatchNorm, - schema::PrimitiveType_PReLU, - schema::PrimitiveType_BiasAdd, - schema::PrimitiveType_SpaceToDepth, - schema::PrimitiveType_DepthToSpace, - schema::PrimitiveType_TopK}; +static const std::vector nhwcOpList = {schema::PrimitiveType_Conv2DGradFilter, + schema::PrimitiveType_Conv2DGradInput, + schema::PrimitiveType_GroupConv2DGradInput, + schema::PrimitiveType_PoolingGrad, + schema::PrimitiveType_BiasGrad, + schema::PrimitiveType_BNGrad, + schema::PrimitiveType_ApplyMomentum, + schema::PrimitiveType_Sgd, + schema::PrimitiveType_Adam, + schema::PrimitiveType_Conv2D, + schema::PrimitiveType_DeConv2D, + schema::PrimitiveType_DepthwiseConv2D, + schema::PrimitiveType_DeDepthwiseConv2D, + schema::PrimitiveType_Pooling, + schema::PrimitiveType_LocalResponseNormalization, + schema::PrimitiveType_Resize, + schema::PrimitiveType_BatchNorm, + schema::PrimitiveType_FusedBatchNorm, + schema::PrimitiveType_PReLU, + schema::PrimitiveType_BiasAdd, + schema::PrimitiveType_SpaceToDepth, + schema::PrimitiveType_DepthToSpace, + schema::PrimitiveType_TopK}; static const std::vector nhwcOpAllInputList = { -#ifdef SUPPORT_TRAIN 
schema::PrimitiveType_PoolingGrad, schema::PrimitiveType_ActivationGrad, schema::PrimitiveType_Conv2DGradFilter, - schema::PrimitiveType_BNGrad -#endif -}; + schema::PrimitiveType_BNGrad}; + +// index {} mean all inputs need insert +static std::unordered_map> extNhwcInsertIndex = { + {schema::PrimitiveType_BNGrad, {0, 1}}, + {schema::PrimitiveType_ApplyMomentum, {3}}, + {schema::PrimitiveType_Sgd, {1}}, + {schema::PrimitiveType_Adam, {9}}}; static const std::vector fp32FullOpList = { schema::PrimitiveType_Concat, schema::PrimitiveType_Add, @@ -133,18 +134,10 @@ static const std::vector int8OpList = {schema::PrimitiveT schema::PrimitiveType_L2Norm}; static const std::vector needInsertOpList = { -#ifdef SUPPORT_TRAIN - schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, - schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Split, - schema::PrimitiveType_Crop, schema::PrimitiveType_Mul, schema::PrimitiveType_Add, - schema::PrimitiveType_ActivationGrad -#else schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Add, schema::PrimitiveType_Split, schema::PrimitiveType_Slice, schema::PrimitiveType_Crop, - schema::PrimitiveType_Mul, schema::PrimitiveType_Maximum -#endif -}; + schema::PrimitiveType_Mul, schema::PrimitiveType_Maximum, schema::PrimitiveType_ActivationGrad}; static const std::unordered_map nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}}; @@ -156,6 +149,8 @@ std::vector Getfp32FullOpList() { return fp32FullOpList; std::vector GetNhwcOpList() { return nhwcOpList; } +std::unordered_map> GetExtNhwcIndexes() { return extNhwcInsertIndex; } + std::vector GetNhwcAllInputOpList() { return nhwcOpAllInputList; } std::vector GetUint8NhwcOpList() { return int8NeedNhwcOpList; } diff --git a/mindspore/lite/tools/common/node_util.h b/mindspore/lite/tools/common/node_util.h index 7f87001974..fb50012e88 100644 --- a/mindspore/lite/tools/common/node_util.h +++ b/mindspore/lite/tools/common/node_util.h @@ -62,6 +62,8 @@ std::vector GetNhwcOpList(); std::vector GetNhwcAllInputOpList(); +std::unordered_map> GetExtNhwcIndexes(); + std::vector Getfp32FullOpList(); std::vector GetUint8NhwcOpList(); diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 3420b057c3..79f30a4933 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -103,12 +103,6 @@ set(LITE_SRC ${SRC_DIR}/dequant.cc ${SRC_DIR}/huffman_decode.cc ) -if(SUPPORT_TRAIN) - set(LITE_SRC - ${LITE_SRC} - ) - -endif() set(ARM_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src/runtime/kernel/arm) file(GLOB KERNEL_SRC ${ARM_DIR}/base/*.cc diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 6128722572..f4eb1e9044 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -177,6 +177,7 @@ int AnfTransform::RunMindirAdjustPass(const FuncGraphPtr &old_graph, const conve auto mindir_adjust_pass = std::make_shared(); mindir_adjust_pass->SetFmkType(config->fmk); mindir_adjust_pass->SetQuantType(config->quantType); + mindir_adjust_pass->SetTrainFlag(config->trainModel); if (!mindir_adjust_pass->Run(old_graph)) { MS_LOG(ERROR) << "mindir adjust failed."; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); diff --git 
a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 69e11accc5..0c87ec13ba 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -88,7 +88,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { } // anf -- fb - auto meta_graph = Export(graph); + auto meta_graph = Export(graph, false, false, flag->trainModel); if (meta_graph == nullptr) { MS_LOG(ERROR) << "Export to meta graph return nullptr"; return nullptr; diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc index d9cbd4ab61..270f22f179 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc @@ -48,21 +48,8 @@ STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatT FormatTransNodeType *afterNodeType) { if (fmkType == converter::FmkType_TFLITE) { // inference by nhwc return RET_NO_CHANGE; - } else if (fmkType == converter::FmkType_CAFFE) { // inference by nchw - if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) { - return RET_NO_CHANGE; - } - *beforeNodeType = kNCHW2NHWC; - *afterNodeType = kNHWC2NCHW; - return RET_OK; - } else if (fmkType == converter::FmkType_MS) { - if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) { - return RET_NO_CHANGE; - } - *beforeNodeType = kNCHW2NHWC; - *afterNodeType = kNHWC2NCHW; - return RET_OK; - } else if (fmkType == converter::FmkType_ONNX) { + } else if (fmkType == converter::FmkType_CAFFE || fmkType == converter::FmkType_MS || + fmkType == converter::FmkType_ONNX) { if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) { return RET_NO_CHANGE; } @@ -173,11 +160,19 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { if (node->primitive->value.type == schema::PrimitiveType_DepthToSpace) { reinterpret_cast(attr)->format = schema::Format_NHWC; } -#ifdef SUPPORT_TRAIN - if (IsContain(GetNhwcAllInputOpList(), GetCNodeTType(**iter))) { - int idx_num = node->inputIndex.size(); - if (GetCNodeTType(**iter) == schema::PrimitiveType_BNGrad) idx_num = 2; - for (int i = 0; i < idx_num; i++) { + auto specInsertIndexes = GetExtNhwcIndexes(); + auto opType = GetCNodeTType(**iter); + if (specInsertIndexes.find(opType) != specInsertIndexes.end()) { + for (auto insert_index : specInsertIndexes[opType]) { + iter = InsertFormatTransNode(graph, iter, kBefore, insert_index, beforeNodeType, &status); + if (status != RET_OK) { + MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed"; + return RET_ERROR; + } + } + } else if (IsContain(GetNhwcAllInputOpList(), opType)) { + auto input_size = node->inputIndex.size(); + for (size_t i = 0; i < input_size; i++) { iter = InsertFormatTransNode(graph, iter, kBefore, i, beforeNodeType, &status); if (status != RET_OK) { MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed"; @@ -185,23 +180,8 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { } } } else { - int idx = 0; - if (GetCNodeTType(**iter) == schema::PrimitiveType_ApplyMomentum) idx = 3; - if (GetCNodeTType(**iter) == schema::PrimitiveType_Sgd) idx = 1; - if (GetCNodeTType(**iter) == schema::PrimitiveType_Adam) idx = 9; - iter = InsertFormatTransNode(graph, iter, kBefore, idx, beforeNodeType, &status); - if (status != RET_OK) { - MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << 
nodeName << "failed"; - return RET_ERROR; - } - } -#else - iter = InsertFormatTransNode(graph, iter, kBefore, 0, beforeNodeType, &status); - if (status != RET_OK) { - MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed"; - return RET_ERROR; + iter = InsertFormatTransNode(graph, iter, kBefore, 0, beforeNodeType, &status); } -#endif iter = InsertFormatTransNode(graph, iter, kAfter, 0, afterNodeType, &status); if (status != RET_OK) { MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed"; diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc index 7d0b7129af..a5fb729bee 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc @@ -194,7 +194,6 @@ STATUS GlobalFormatTransformPass::FindPreNh2NcNodes(MetaGraphT *graph, size_t nc if (!IsContain(bfs_queue, input_node_index)) { bfs_queue.emplace_back(input_node_index); } - // todo multi output,other edge need insert nh2nc node auto pre_node_output_indexs = GetOutputNodeIdx(*graph, *pre_node); if (pre_node_output_indexs.size() != 1) { if (node_type == schema::PrimitiveType_Activation || node_type == schema::PrimitiveType_Concat) { diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc index 263ee92a7e..69d6b89c76 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc @@ -29,30 +29,25 @@ std::vector nchw2nhwc_perm = {0, 2, 3, 1}; std::vector nhwc2nchw_perm = {0, 3, 1, 2}; } // namespace namespace lite { -bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr &node) { - MS_ASSERT(graph != nullptr); - MS_ASSERT(node != nullptr); - auto input_node_indexes = GetInputNodeIdx(*graph, *node); - pre_type_ = kNONE; - size_t has_trans_count = 0; - auto can_fusion = true; - for (auto input_node_index : input_node_indexes) { +bool IsInOutCanFusion(schema::MetaGraphT *graph, const std::vector &node_indexes, size_t *has_trans_count, + FormatTransNodeType *trans_type) { + for (auto input_node_index : node_indexes) { MS_ASSERT(graph->nodes.size() > input_node_index); auto &pre_node = graph->nodes.at(input_node_index); MS_ASSERT(pre_node != nullptr); MS_ASSERT(pre_node->primitive != nullptr); MS_ASSERT(pre_node->primitive->value != nullptr); - if (pre_type_ == kNONE) { + if (*trans_type == kNONE) { if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) { MS_ASSERT(pre_node->primitive->value.AsTranspose() != nullptr); if (pre_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) { - pre_type_ = kNCHW2NHWC; + *trans_type = kNCHW2NHWC; } else if (pre_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) { - pre_type_ = kNHWC2NCHW; + *trans_type = kNHWC2NCHW; } else { return false; } - has_trans_count++; + (*has_trans_count)++; } } else { if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) { @@ -64,57 +59,28 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p } else { return false; } - if (pre_type_ != cur_type) { - can_fusion = false; - break; + if (*trans_type != cur_type) { + return false; } else { - 
has_trans_count++; + (*has_trans_count)++; } } } } - if (!can_fusion) { + return true; +} +bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr &node) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(node != nullptr); + auto input_node_indexes = GetInputNodeIdx(*graph, *node); + pre_type_ = kNONE; + size_t has_trans_count = 0; + if (!IsInOutCanFusion(graph, input_node_indexes, &has_trans_count, &pre_type_)) { return false; } auto output_node_indexes = GetOutputNodeIdx(*graph, *node); post_type_ = kNONE; - for (auto output_node_index : output_node_indexes) { - MS_ASSERT(graph->nodes.size() > output_node_index); - auto &post_node = graph->nodes.at(output_node_index); - MS_ASSERT(post_node != nullptr); - MS_ASSERT(post_node->primitive != nullptr); - MS_ASSERT(post_node->primitive->value != nullptr); - if (post_type_ == kNONE) { - if (post_node->primitive->value.type == schema::PrimitiveType_Transpose) { - if (post_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) { - post_type_ = kNCHW2NHWC; - } else if (post_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) { - post_type_ = kNHWC2NCHW; - } else { - return false; - } - has_trans_count++; - } - } else { - if (post_node->primitive->value.type == schema::PrimitiveType_Transpose) { - auto cur_type = kNONE; - if (post_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) { - cur_type = kNCHW2NHWC; - } else if (post_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) { - cur_type = kNHWC2NCHW; - } else { - return false; - } - if (post_type_ != cur_type) { - can_fusion = false; - break; - } else { - has_trans_count++; - } - } - } - } - if (!can_fusion) { + if (!IsInOutCanFusion(graph, output_node_indexes, &has_trans_count, &post_type_)) { return false; } if (pre_type_ == kNONE && post_type_ == kNONE) { @@ -136,10 +102,8 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p if (GetCNodeTType(*node) == schema::PrimitiveType_Split) { return has_trans_count >= half_count; } - can_fusion = has_trans_count > half_count; - return can_fusion; + return has_trans_count > half_count; } - STATUS TransOpInsertPass::FindOutTransType() { pre_insert_trans_type_ = kNHWC2NCHW; post_insert_trans_type_ = kNHWC2NCHW; @@ -153,7 +117,7 @@ STATUS TransOpInsertPass::FindOutTransType() { MS_ASSERT(false); } else { if (pre_type_ == post_type_) { - MS_LOG(ERROR) << "Unknow error"; + MS_LOG(ERROR) << "Unknown error"; return RET_ERROR; } pre_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? 
kNCHW2NHWC : kNHWC2NCHW; @@ -200,13 +164,6 @@ STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) { STATUS status = RET_OK; auto input_tensor_size = (*iter)->inputIndex.size(); for (size_t i = 0; i < input_tensor_size; i++) { -#ifdef SUPPORT_TRAIN - auto &tensor = graph->allTensors.at((*iter)->inputIndex[i]); - MS_ASSERT(tensor != nullptr); - if (tensor->nodeType == schema::NodeType_ValueNode) { - continue; - } -#endif auto &input_tensor = graph->allTensors.at((*iter)->inputIndex[i]); if (input_tensor->nodeType == NodeType_ValueNode && input_tensor->dims.size() < 4) { continue; diff --git a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc index 8504626d10..e1aa855aed 100644 --- a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc @@ -37,6 +37,15 @@ ParamValueLitePtr NewParamValueLitePtr(lite::Tensor *tensor) { return para_value_lite; } +bool IsSpecialType(schema::PrimitiveType type) { + if ((type == schema::PrimitiveType_TupleGetItem) || (type == schema::PrimitiveType_Depend) || + (type == schema::PrimitiveType_ControlDepend) || + (type == schema::PrimitiveType_MakeTuple || type == schema::PrimitiveType_Return)) { + return true; + } + return false; +} + abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor) { MS_ASSERT(nullptr != tensor); std::vector shape(tensor->shape()); @@ -363,12 +372,7 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) { return false; } auto type = GetCNodeType(cnode); - - if ((type == schema::PrimitiveType_TupleGetItem) || -#ifdef SUPPORT_TRAIN - (type == schema::PrimitiveType_Depend) || (type == schema::PrimitiveType_ControlDepend) || -#endif - (type == schema::PrimitiveType_MakeTuple || type == schema::PrimitiveType_Return)) { + if (IsSpecialType(type)) { continue; } std::vector input_tensors; diff --git a/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc index 881a516186..e713634e63 100644 --- a/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc @@ -147,7 +147,7 @@ int MindirAdjustPass::PrimitiveConvert(std::shared_ptr anf_node) { auto inputs = cnode->inputs(); inputs.erase(inputs.begin()); if (!CheckPrimitiveType(anf_node, prim::kPrimReturn) && !CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) { - auto primitive_c = PrimitiveC::Create(*primitive, inputs, quant_type_); + auto primitive_c = PrimitiveC::Create(*primitive, inputs, quant_type_, train_flag_); if (primitive_c == nullptr) { MS_LOG(ERROR) << "fail to create a primitive_c: " << cnode->fullname_with_scope(); lite::NoSupportOp::GetInstance()->InsertOp(primitive->name()); diff --git a/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.h b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.h index 6a9383e90c..dbc47652c7 100644 --- a/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.h +++ b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.h @@ -33,6 +33,7 @@ class MindirAdjustPass : public Pass { void SetQuantType(QuantType quant_type) { quant_type_ = quant_type; } void SetFmkType(FmkType fmk_type) { fmk_type_ = fmk_type; } int ValueNodeInt64Convert(AnfNodePtr anf_node); + void SetTrainFlag(bool train_flag) { train_flag_ = train_flag; } int ParameterNodeConvert(AnfNodePtr anf_node); int PrimitiveConvert(AnfNodePtr anf_node); bool Run(const FuncGraphPtr 
&graph) override; @@ -40,6 +41,7 @@ class MindirAdjustPass : public Pass { protected: QuantType quant_type_ = QuantType::QuantType_QUANT_NONE; FmkType fmk_type_ = FmkType::FmkType_MS; + bool train_flag_ = false; }; } // namespace mindspore::opt #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_ diff --git a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc index af8d1ad173..79952ee49f 100644 --- a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc @@ -131,12 +131,10 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node, param_value->set_format(schema::Format::Format_CKHW); } else if (op_type == schema::PrimitiveType_DeConv2D) { param_value->set_format(schema::Format::Format_KCHW); -#ifdef SUPPORT_TRAIN } else if (op_type == schema::PrimitiveType_Conv2DGradInput) { param_value->set_format(schema::Format::Format_KCHW); } else if (op_type == schema::PrimitiveType_GroupConv2DGradInput) { param_value->set_format(schema::Format::Format_CKHW); -#endif } else { MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: " << conv_node->fullname_with_scope(); @@ -213,10 +211,8 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) { auto conv_cnode = node->cast(); auto type = opt::GetCNodeType(node); if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D && -#ifdef SUPPORT_TRAIN ((type != schema::PrimitiveType_Conv2DGradInput) || (fmk_type != FmkType_MS)) && ((type != schema::PrimitiveType_GroupConv2DGradInput) || (fmk_type != FmkType_MS)) && -#endif type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { continue; } diff --git a/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc b/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc index f63501901c..1b1733c101 100644 --- a/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc @@ -43,11 +43,9 @@ lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr continue; } auto type = opt::GetCNodeType(node); - if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D -#ifdef SUPPORT_TRAIN - && type != schema::PrimitiveType_Conv2DGradInput && type != schema::PrimitiveType_GroupConv2DGradInput -#endif - && type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { + if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D && + type != schema::PrimitiveType_Conv2DGradInput && type != schema::PrimitiveType_GroupConv2DGradInput && + type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { continue; } auto conv_cnode = node->cast();