From: @zhengjun10
Reviewed-by:
Signed-off-by:
tags/v1.2.0-rc1
| @@ -114,16 +114,6 @@ int AddN::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs | |||
| max_dim = dim; | |||
| } | |||
| } | |||
| #ifndef SUPPORT_TRAIN | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| size_t shift = max_dims - inputs.at(i)->shape().size(); | |||
| size_t dim = (i < shift) ? 1 : inputs.at(i)->shape().at(d); | |||
| if ((dim != max_dim) && (dim != 1)) { | |||
| MS_LOG(ERROR) << "AddN inputs shape is not equal!"; | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| } | |||
| #endif | |||
| output->shape()[d] = max_dim; // set the biggest dimension in the output tensor | |||
| } | |||
| @@ -149,13 +149,13 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT | |||
| attr->padRight = pad_list.at(3); | |||
| auto dilation = CastToInt(prim.GetAttr("dilation")); | |||
| #ifdef SUPPORT_TRAIN | |||
| attr->dilateH = dilation.at(2); | |||
| attr->dilateW = dilation.at(3); | |||
| #else | |||
| attr->dilateH = dilation.at(0); | |||
| attr->dilateW = dilation.at(1); | |||
| #endif | |||
| if (train_flag()) { | |||
| attr->dilateH = dilation.at(2); | |||
| attr->dilateW = dilation.at(3); | |||
| } else { | |||
| attr->dilateH = dilation.at(0); | |||
| attr->dilateW = dilation.at(1); | |||
| } | |||
| auto kernel_size = CastToInt(prim.GetAttr("kernel_size")); | |||
| attr->kernelH = kernel_size.at(0); | |||
| attr->kernelW = (kernel_size.size() > 1) ? kernel_size.at(1) : kernel_size.at(0); | |||
| @@ -19,9 +19,6 @@ | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| #include "src/ops/ops_register.h" | |||
| #endif | |||
| #ifdef SUPPORT_TRAIN | |||
| #include <tuple> | |||
| #endif | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -56,20 +53,16 @@ int Pad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs | |||
| } | |||
| string paddingmode = "REFLECT"; | |||
| if (prim.GetAttr("mode") == nullptr) { | |||
| #ifdef SUPPORT_TRAIN | |||
| if (prim.name() == "Pad") { | |||
| paddingmode = "CONSTANT"; | |||
| } else { | |||
| #endif | |||
| MS_LOG(ERROR) << "get mode failed!"; | |||
| delete this->primitive_; | |||
| delete attr; | |||
| this->primitive_ = nullptr; | |||
| attr = nullptr; | |||
| return RET_ERROR; | |||
| #ifdef SUPPORT_TRAIN | |||
| } | |||
| #endif | |||
| } else { | |||
| paddingmode = GetValue<string>(prim.GetAttr("mode")); | |||
| } | |||
| @@ -77,7 +70,6 @@ int Pad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs | |||
| attr->paddingMode = schema::PaddingMode_REFLECT; | |||
| } else if (paddingmode == "SYMMETRIC") { | |||
| attr->paddingMode = schema::PaddingMode_SYMMETRIC; | |||
| #ifdef SUPPORT_TRAIN | |||
| } else if (paddingmode == "CONSTANT") { | |||
| attr->paddingMode = schema::PaddingMode_CONSTANT; | |||
| if (prim.GetAttr("paddings") != nullptr) { | |||
| @@ -91,7 +83,6 @@ int Pad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs | |||
| attr->paddings.push_back(i); | |||
| } | |||
| } | |||
| #endif | |||
| } else { | |||
| MS_LOG(ERROR) << "model type not supported!"; | |||
| delete this->primitive_; | |||
| @@ -18,7 +18,6 @@ | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| #include <memory> | |||
| #include <map> | |||
| #include "tools/converter/quantizer/quantize_util.h" | |||
| #include "src/ops/assert_op.h" | |||
| #include "src/ops/space_to_batch.h" | |||
| @@ -175,8 +174,6 @@ | |||
| #include "src/ops/uniform_real.h" | |||
| #include "src/ops/rank.h" | |||
| #include "src/ops/is_finite.h" | |||
| #ifdef SUPPORT_TRAIN | |||
| #include "src/ops/neg_grad.h" | |||
| #include "src/ops/activation_grad.h" | |||
| #include "src/ops/apply_momentum.h" | |||
| @@ -210,7 +207,6 @@ | |||
| #include "src/ops/sigmoid_cross_entropy_with_logits_grad.h" | |||
| #include "src/ops/strided_slice_grad.h" | |||
| #endif | |||
| #endif | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| @@ -513,13 +509,14 @@ std::shared_ptr<PrimitiveC> GetTupleGetItemPrim() { | |||
| template <typename T, typename = std::enable_if<std::is_base_of<PrimitiveC, T>::value>> | |||
| std::shared_ptr<PrimitiveC> NewPrimitiveC(const mindspore::Primitive &prim, const std::vector<AnfNodePtr> &inputs, | |||
| const schema::QuantType &quantType) { | |||
| const schema::QuantType &quantType, bool train_flag = false) { | |||
| auto primc = std::make_shared<T>(); | |||
| if (primc == nullptr) { | |||
| MS_LOG(ERROR) << "make_shared PrimitiveC failed"; | |||
| return nullptr; | |||
| } | |||
| primc->set_quant_type(quantType); | |||
| primc->set_train_flag(train_flag); | |||
| auto ret = primc->UnPackAttr(prim, inputs); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "UnPackAttr failed"; | |||
| @@ -529,7 +526,7 @@ std::shared_ptr<PrimitiveC> NewPrimitiveC(const mindspore::Primitive &prim, cons | |||
| } | |||
| std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs, | |||
| const schema::QuantType &quantType) { | |||
| const schema::QuantType &quantType, bool train_flag) { | |||
| const auto &op_type = prim.name(); | |||
| if (op_type == "ReLU" || op_type == "ReLU6" || op_type == "Sigmoid" || op_type == "HSwish" || op_type == "HSigmoid") { | |||
| return NewPrimitiveC<Activation>(prim, inputs, quantType); | |||
| @@ -544,7 +541,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| } else if (op_type == "Concat") { | |||
| return NewPrimitiveC<Concat>(prim, inputs, quantType); | |||
| } else if (op_type == "Conv2D") { | |||
| return NewPrimitiveC<Conv2D>(prim, inputs, quantType); | |||
| return NewPrimitiveC<Conv2D>(prim, inputs, quantType, train_flag); | |||
| } else if (op_type == "Cos") { | |||
| return NewPrimitiveC<Cos>(prim, inputs, quantType); | |||
| } else if (op_type == "DepthwiseConv2dNative" || op_type == "DepthwiseConv2D") { | |||
| @@ -664,7 +661,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| } else if (op_type == "Range") { | |||
| return NewPrimitiveC<Range>(prim, inputs, quantType); | |||
| } else if (op_type == "Tile") { | |||
| return NewPrimitiveC<Tile>(prim, inputs, quantType); | |||
| return NewPrimitiveC<Tile>(prim, inputs, quantType, train_flag); | |||
| } else if (op_type == "GatherNd") { | |||
| return NewPrimitiveC<GatherNd>(prim, inputs, quantType); | |||
| } else if (op_type == "Square") { | |||
| @@ -685,7 +682,6 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<ArgMax>(prim, inputs, quantType); | |||
| } else if (op_type == "Gelu") { | |||
| return NewPrimitiveC<GeLU>(prim, inputs, quantType); | |||
| #ifdef SUPPORT_TRAIN | |||
| } else if (op_type == "SoftmaxCrossEntropyWithLogits") { | |||
| return NewPrimitiveC<SoftmaxCrossEntropy>(prim, inputs, quantType); | |||
| } else if (op_type == "SparseSoftmaxCrossEntropyWithLogits") { | |||
| @@ -706,7 +702,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<PoolingGrad>(prim, inputs, quantType); | |||
| } else if (op_type == "Conv2DBackpropFilter") { | |||
| return NewPrimitiveC<Conv2DGradFilter>(prim, inputs, quantType); | |||
| } else if (op_type == "Conv2DBackpropInput") { | |||
| } else if (op_type == "Conv2DBackpropInput" && train_flag) { | |||
| return NewPrimitiveC<Conv2DGradInput>(prim, inputs, quantType); | |||
| } else if ((op_type == "BatchNormGrad") || (op_type == "FusedBatchNormGradEx")) { | |||
| return NewPrimitiveC<BNGrad>(prim, inputs, quantType); | |||
| @@ -748,10 +744,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<StridedSliceGrad>(prim, inputs, quantType); | |||
| } else if (op_type == "AbsGrad") { | |||
| return NewPrimitiveC<AbsGrad>(prim, inputs, quantType); | |||
| #else | |||
| } else if (op_type == "Conv2DBackpropInput") { | |||
| } else if (op_type == "Conv2DBackpropInput" && !train_flag) { | |||
| return NewPrimitiveC<DeConv2D>(prim, inputs, quantType); | |||
| #endif | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported primitive type in Create : " << op_type; | |||
| return nullptr; | |||
| @@ -1065,7 +1059,6 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||
| return new (std::nothrow) UniformReal(primitive); | |||
| case schema::PrimitiveType_Rank: | |||
| return new (std::nothrow) Rank(primitive); | |||
| #ifdef SUPPORT_TRAIN | |||
| case schema::PrimitiveType_ActivationGrad: | |||
| return new (std::nothrow) ActivationGrad(primitive); | |||
| case schema::PrimitiveType_PoolingGrad: | |||
| @@ -1140,7 +1133,6 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||
| return new (std::nothrow) SigmoidCrossEntropyWithLogitsGrad(primitive); | |||
| case schema::PrimitiveType_StridedSliceGrad: | |||
| return new (std::nothrow) StridedSliceGrad(primitive); | |||
| #endif | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type); | |||
| break; | |||
| @@ -1170,6 +1162,10 @@ bool PrimitiveC::infer_flag() const { return this->infer_flag_; } | |||
| void PrimitiveC::set_infer_flag(bool flag) { this->infer_flag_ = flag; } | |||
| bool PrimitiveC::train_flag() const { return this->train_flag_; } | |||
| void PrimitiveC::set_train_flag(bool flag) { this->train_flag_ = flag; } | |||
| int PrimitiveC::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { | |||
| auto input = inputs.front(); | |||
| MS_ASSERT(input != nullptr); | |||
| @@ -133,6 +133,10 @@ class PrimitiveC : public mindspore::Primitive { | |||
| void set_infer_flag(bool flag); | |||
| bool train_flag() const; | |||
| void set_train_flag(bool flag); | |||
| static PrimitiveC *Create(mindspore::schema::Primitive *primitive) { return Create(primitive->UnPack()); } | |||
| static PrimitiveC *Create(mindspore::schema::PrimitiveT *primitive); | |||
| @@ -140,7 +144,7 @@ class PrimitiveC : public mindspore::Primitive { | |||
| static void GetAttrDataFromInput(const AnfNodePtr &inputNode, std::vector<int> *data); | |||
| static std::shared_ptr<PrimitiveC> Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs, | |||
| const schema::QuantType &quantType); | |||
| const schema::QuantType &quantType, bool train_flag = false); | |||
| void PopulaterQuantParam(const Primitive &prim, const std::vector<AnfNodePtr> &inputs); | |||
| void FillDefaultInputQuantParamIfNeed(const size_t &inputSize); | |||
| void PopulaterInputQuantParam(const Primitive &prim, const std::vector<AnfNodePtr> &inputs, | |||
| @@ -159,6 +163,7 @@ class PrimitiveC : public mindspore::Primitive { | |||
| bool infer_flag_ = true; | |||
| int op_type_ = OP_TYPE_NOT_SET; | |||
| bool enable_huffman_code_ = false; | |||
| bool train_flag_ = false; | |||
| }; | |||
| std::shared_ptr<PrimitiveC> GetReturnPrim(); | |||
| @@ -179,6 +184,10 @@ class PrimitiveC { | |||
| void set_infer_flag(bool flag); | |||
| bool train_flag() const; | |||
| void set_train_flag(bool flag); | |||
| virtual int InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs); | |||
| int Type() const; | |||
| @@ -238,6 +247,7 @@ class PrimitiveC { | |||
| bool infer_flag_ = true; | |||
| schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; | |||
| int op_type_ = OP_TYPE_NOT_SET; | |||
| bool train_flag_ = false; | |||
| }; | |||
| using PrimitiveCPtr = std::shared_ptr<PrimitiveC>; | |||
| typedef PrimitiveC *(*PrimitiveCCreator)(const schema::Primitive *primitive); | |||
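Note: the new train_flag_ member and its accessors replace the compile-time SUPPORT_TRAIN guard with a per-primitive runtime switch. A minimal usage sketch, assuming the caller already has the prim, inputs and quantType arguments that PrimitiveC::Create takes (the surrounding converter code is assumed, not part of this diff):

    // assumed context: prim, inputs and quantType as passed in MindirAdjustPass::PrimitiveConvert
    auto primc = PrimitiveC::Create(prim, inputs, quantType, /*train_flag=*/true);
    if (primc != nullptr && primc->train_flag()) {
      // training-specific behavior (e.g. the Conv2D dilation index order) is now chosen at runtime
    }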
| @@ -159,41 +159,41 @@ int Tile::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output | |||
| } else { | |||
| multiples = GetMultiples(); | |||
| } | |||
| #ifdef SUPPORT_TRAIN | |||
| const size_t in_dims = input->shape().size(); | |||
| const size_t delta_dims = in_dims - multiples.size(); | |||
| size_t i = 0; | |||
| for (; i < delta_dims; ++i) { | |||
| int tmp = input->shape().at(i); | |||
| out_shape.push_back(tmp); | |||
| } | |||
| for (; i < in_dims; ++i) { | |||
| int tmp = input->shape().at(i) * (multiples[i - delta_dims]); | |||
| out_shape.push_back(tmp); | |||
| } | |||
| #else | |||
| std::vector<int> dims = GetDims(); | |||
| if (inputs_.size() == 2 && dims.empty()) { | |||
| for (int dim = 0; dim < inputs_[1]->ElementsNum(); ++dim) { | |||
| dims.push_back(dim); | |||
| if (train_flag()) { | |||
| const size_t in_dims = input->shape().size(); | |||
| const size_t delta_dims = in_dims - multiples.size(); | |||
| size_t i = 0; | |||
| for (; i < delta_dims; ++i) { | |||
| int tmp = input->shape().at(i); | |||
| out_shape.push_back(tmp); | |||
| } | |||
| } | |||
| const size_t in_dims = input->shape().size(); | |||
| for (; i < in_dims; ++i) { | |||
| int tmp = input->shape().at(i) * (multiples[i - delta_dims]); | |||
| out_shape.push_back(tmp); | |||
| } | |||
| } else { | |||
| std::vector<int> dims = GetDims(); | |||
| if (inputs_.size() == 2 && dims.empty()) { | |||
| for (int dim = 0; dim < inputs_[1]->ElementsNum(); ++dim) { | |||
| dims.push_back(dim); | |||
| } | |||
| } | |||
| const size_t in_dims = input->shape().size(); | |||
| MS_ASSERT(multiples.size() == dims.size()); | |||
| for (size_t i = 0; i < in_dims; ++i) { | |||
| out_shape.push_back(input->shape().at(i)); | |||
| } | |||
| for (size_t i = 0; i < dims.size(); ++i) { | |||
| if (input->shape().at(dims.at(i)) != 0 && | |||
| multiples.at(i) > std::numeric_limits<int>::max() / input->shape().at(dims.at(i))) { | |||
| MS_LOG(ERROR) << "The value of multiples[" << i << "] is too big"; | |||
| return RET_ERROR; | |||
| MS_ASSERT(multiples.size() == dims.size()); | |||
| for (size_t i = 0; i < in_dims; ++i) { | |||
| out_shape.push_back(input->shape().at(i)); | |||
| } | |||
| for (size_t i = 0; i < dims.size(); ++i) { | |||
| if (input->shape().at(dims.at(i)) != 0 && | |||
| multiples.at(i) > std::numeric_limits<int>::max() / input->shape().at(dims.at(i))) { | |||
| MS_LOG(ERROR) << "The value of multiples[" << i << "] is too big"; | |||
| return RET_ERROR; | |||
| } | |||
| out_shape.at(dims.at(i)) = input->shape().at(dims.at(i)) * (multiples.at(i)); | |||
| } | |||
| out_shape.at(dims.at(i)) = input->shape().at(dims.at(i)) * (multiples.at(i)); | |||
| } | |||
| #endif | |||
| output->set_shape(out_shape); | |||
| return RET_OK; | |||
| } | |||
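For reference, a small worked example of the train_flag() branch above: with an input shape of {4, 2, 3} and multiples {2, 5}, delta_dims is 1, so the leading dimension is copied unchanged and the multiples are applied to the trailing dimensions, giving an output shape of {4, 2 * 2, 3 * 5} = {4, 4, 15}.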
| @@ -328,15 +328,15 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc | |||
| } | |||
| RemoveIfMakeTuple(cnode); | |||
| #ifdef SUPPORT_TRAIN | |||
| RemoveIfDepend(cnode); | |||
| #endif | |||
| if (train_flag) { | |||
| RemoveIfDepend(cnode); | |||
| if (primitive_c->Type() == schema::PrimitiveType_Depend || | |||
| primitive_c->Type() == schema::PrimitiveType_ControlDepend) { | |||
| continue; | |||
| } | |||
| } | |||
| if ((primitive_c->Type() == schema::PrimitiveType_TupleGetItem) || | |||
| #ifdef SUPPORT_TRAIN | |||
| (primitive_c->Type() == schema::PrimitiveType_Depend) || | |||
| (primitive_c->Type() == schema::PrimitiveType_ControlDepend) || | |||
| #endif | |||
| (primitive_c->Type() == schema::PrimitiveType_MakeTuple)) { | |||
| continue; | |||
| } | |||
| @@ -424,8 +424,10 @@ int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::uniqu | |||
| return RET_OK; | |||
| } | |||
| schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) { | |||
| schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive, | |||
| bool train_flag) { | |||
| static int subgraph_index = 0; | |||
| this->train_flag = train_flag; | |||
| auto meta_graphT = std::make_unique<schema::MetaGraphT>(); | |||
| int ret = ExportSubgraph(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive); | |||
| if (ret != RET_OK) { | |||
| @@ -439,24 +441,18 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, | |||
| std::string input_name = input_anode->fullname_with_scope(); | |||
| auto input_cnode = utils::cast<CNodePtr>(input_anode); | |||
| if (!IsPrimitiveCNode(input_cnode, schema::PrimitiveType_TupleGetItem)) { | |||
| #ifndef SUPPORT_TRAIN | |||
| if (node_id_map_.find(input_name) != node_id_map_.end()) { | |||
| output_cnode->inputIndex.emplace_back(node_id_map_[input_name]); | |||
| } | |||
| #else | |||
| bool found = false; | |||
| if (node_id_map_.find(input_name) != node_id_map_.end()) { | |||
| output_cnode->inputIndex.emplace_back(node_id_map_[input_name]); | |||
| found = true; | |||
| } | |||
| if (found == false) { | |||
| if (!found) { | |||
| auto input_index_key = input_name + "_o:" + std::to_string(0); | |||
| if (node_id_map_.find(input_index_key) != node_id_map_.end()) { | |||
| output_cnode->inputIndex.emplace_back(node_id_map_[input_index_key]); | |||
| } | |||
| } | |||
| #endif | |||
| } else { | |||
| auto inputs = input_cnode->inputs(); | |||
| @@ -481,17 +477,12 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, | |||
| : GetValue<int>(value_node->value())); | |||
| auto iter = node_id_map_.find(input_index_key); | |||
| if (iter == node_id_map_.end()) { | |||
| #ifdef SUPPORT_TRAIN | |||
| input_index_key = get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(0); // try name with 0 | |||
| iter = node_id_map_.find(input_index_key); | |||
| if (iter == node_id_map_.end()) { | |||
| MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key; | |||
| return RET_ERROR; | |||
| } | |||
| #else | |||
| MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key; | |||
| return RET_ERROR; | |||
| #endif | |||
| } | |||
| output_cnode->inputIndex.emplace_back(iter->second); | |||
| } | |||
| @@ -571,9 +562,7 @@ int AnfExporter::ProcessTensor(const ValueNodePtr &valueNode, std::unique_ptr<sc | |||
| (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(dims), | |||
| [](const int64_t &value) { return static_cast<int32_t>(value); }); | |||
| (*paramTensor)->dims = dims; | |||
| #ifdef SUPPORT_TRAIN | |||
| if ((*paramTensor)->dims.size() == 0) (*paramTensor)->dims = {1}; | |||
| #endif | |||
| if (train_flag && (*paramTensor)->dims.empty()) (*paramTensor)->dims = {1}; | |||
| (*paramTensor)->nodeType = schema::NodeType::NodeType_ValueNode; | |||
| auto data = value->cast<tensor::TensorPtr>(); | |||
| (*paramTensor)->data.resize(data->Size()); | |||
| @@ -679,11 +668,11 @@ int AnfExporter::ProcessParamValueLite(const ValueNodePtr &valueNode, std::uniqu | |||
| (*paramTensor)->format = schema::Format(valueLite->format()); | |||
| (*paramTensor)->dataType = valueLite->tensor_type(); | |||
| (*paramTensor)->dims = valueLite->tensor_shape(); | |||
| #ifdef SUPPORT_TRAIN | |||
| if ((*paramTensor)->dims.size() == 0) { | |||
| if (train_flag && (*paramTensor)->dims.empty()) { | |||
| (*paramTensor)->dims = {1}; | |||
| } | |||
| #endif | |||
| ret = memcpy_s((*paramTensor)->data.data(), valueLite->tensor_size() * sizeof(uint8_t), valueLite->tensor_addr(), | |||
| valueLite->tensor_size()); | |||
| if (ret != EOK) { | |||
| @@ -703,9 +692,9 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano | |||
| auto paramTensor = std::make_unique<schema::TensorT>(); | |||
| auto value = valueNode->value(); | |||
| int ret = RET_OK; | |||
| #ifdef SUPPORT_TRAIN | |||
| paramTensor->name = valueNode->fullname_with_scope(); | |||
| #endif | |||
| if (train_flag) { | |||
| paramTensor->name = valueNode->fullname_with_scope(); | |||
| } | |||
| if (value->isa<tensor::Tensor>()) { | |||
| ret = ProcessTensor(valueNode, ¶mTensor, value, output_cnode, meta_graphT); | |||
| } else if (value->isa<mindspore::Int32Imm>() || value->isa<mindspore::Int64Imm>()) { | |||
| @@ -797,44 +786,44 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s | |||
| } | |||
| msTensor->nodeType = schema::NodeType_CNode; | |||
| fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size()); | |||
| #ifdef SUPPORT_TRAIN | |||
| std::string name = cnode_name + "_o:" + std::to_string(i); | |||
| node_id_map_[name] = meta_graphT->allTensors.size(); | |||
| meta_graphT->allTensors.emplace_back(msTensor); | |||
| if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) || | |||
| IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) || | |||
| IsPrimitiveCNode(cnode, schema::PrimitiveType_Adam)) | |||
| break; | |||
| #else | |||
| if (elements.size() == 1) { | |||
| node_id_map_[cnode_name] = meta_graphT->allTensors.size(); | |||
| msTensor->name = cnode_name; | |||
| } else { | |||
| if (train_flag) { | |||
| std::string name = cnode_name + "_o:" + std::to_string(i); | |||
| node_id_map_[name] = meta_graphT->allTensors.size(); | |||
| msTensor->name = name; | |||
| } | |||
| meta_graphT->allTensors.emplace_back(msTensor); | |||
| if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) || | |||
| IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) || | |||
| IsPrimitiveCNode(cnode, schema::PrimitiveType_Adam)) | |||
| break; | |||
| } else { | |||
| if (elements.size() == 1) { | |||
| node_id_map_[cnode_name] = meta_graphT->allTensors.size(); | |||
| msTensor->name = cnode_name; | |||
| } else { | |||
| std::string name = cnode_name + "_o:" + std::to_string(i); | |||
| node_id_map_[name] = meta_graphT->allTensors.size(); | |||
| msTensor->name = name; | |||
| } | |||
| if (!utils::isa<abstract::AbstractTensorPtr>(elements[i])) { | |||
| MS_LOG(ERROR) << "abstract is not AbstractTensor"; | |||
| delete (msTensor); | |||
| return; | |||
| } | |||
| auto type = kNumberTypeFloat32; | |||
| if (utils::isa<abstract::AbstractTensorPtr>(elements[i])) { | |||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]); | |||
| auto typePtr = abstract_tensor->element()->GetTypeTrack(); | |||
| type = typePtr->type_id(); | |||
| } | |||
| msTensor->dataType = type; | |||
| meta_graphT->allTensors.emplace_back(msTensor); | |||
| if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) || | |||
| IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) || | |||
| IsPrimitiveCNode(cnode, schema::PrimitiveType_FusedBatchNorm) || | |||
| IsPrimitiveCNode(cnode, schema::PrimitiveType_LayerNorm)) { | |||
| break; | |||
| if (!utils::isa<abstract::AbstractTensorPtr>(elements[i])) { | |||
| MS_LOG(ERROR) << "abstract is not AbstractTensor"; | |||
| delete (msTensor); | |||
| return; | |||
| } | |||
| auto type = kNumberTypeFloat32; | |||
| if (utils::isa<abstract::AbstractTensorPtr>(elements[i])) { | |||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]); | |||
| auto typePtr = abstract_tensor->element()->GetTypeTrack(); | |||
| type = typePtr->type_id(); | |||
| } | |||
| msTensor->dataType = type; | |||
| meta_graphT->allTensors.emplace_back(msTensor); | |||
| if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) || | |||
| IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) || | |||
| IsPrimitiveCNode(cnode, schema::PrimitiveType_FusedBatchNorm) || | |||
| IsPrimitiveCNode(cnode, schema::PrimitiveType_LayerNorm)) { | |||
| break; | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| } else { | |||
| auto ms_tensor = new (std::nothrow) schema::TensorT(); | |||
| @@ -927,8 +916,8 @@ CNodePtr AnfExporter::CreatePartialCnode(const FuncGraphPtr &fg, AnfNodePtr node | |||
| } | |||
| } | |||
| schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) { | |||
| schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive, bool train_flag) { | |||
| AnfExporter anf_exporter; | |||
| return anf_exporter.Export(func_graph, keep_graph, copy_primitive); | |||
| return anf_exporter.Export(func_graph, keep_graph, copy_primitive, train_flag); | |||
| } | |||
| } // namespace mindspore::lite | |||
| @@ -35,7 +35,8 @@ class AnfExporter { | |||
| public: | |||
| AnfExporter() = default; | |||
| virtual ~AnfExporter() = default; | |||
| schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false); | |||
| schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false, | |||
| bool train_flag = false); | |||
| void SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||
| schema::CNodeT *fb_node); | |||
| int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||
| @@ -91,11 +92,13 @@ class AnfExporter { | |||
| std::vector<schema::CNodeT *> graph_input_nodes_; | |||
| std::map<FuncGraphPtr, int> fg_subgraph_map; | |||
| uint32_t node_idx = 0; | |||
| bool train_flag = false; | |||
| }; | |||
| // by default, copy_primitive is false, which means that the MetaGraph and func_graph share the same schema::PrimitiveT. | |||
| // but in PostQuantization, the func_graph needs to be converted to a MetaGraph first so that the MetaGraph passes can run, which may modify | |||
| // the schema::PrimitiveT and cause bugs; if all the passes had been done on the func_graph, everything would be simple | |||
| // and clear. | |||
| schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false); | |||
| schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false, | |||
| bool train_flag = false); | |||
| } // namespace mindspore::lite | |||
| #endif // MINDSPORE_LITE_TOOLS_ANF_EXPORTER_ANF_EXPORTER_H_ | |||
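A minimal sketch of how the extended Export entry point is meant to be driven, mirroring the converter call that appears later in this change (flag is assumed to be the converter::Flags instance; everything else is as declared above):

    // assumed context: inside Converter::Convert, after the func_graph passes have run
    schema::MetaGraphT *meta_graph = Export(graph, /*keep_graph=*/false, /*copy_primitive=*/false,
                                            /*train_flag=*/flag->trainModel);
    if (meta_graph == nullptr) {
      MS_LOG(ERROR) << "Export to meta graph return nullptr";
    }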
| @@ -855,14 +855,14 @@ int AnfImporterFromMindir::ParseModelConfigureInfo(const onnx::ModelProto &model | |||
| } | |||
| int AnfImporterFromMindir::Import(const converter::Flags *flag) { | |||
| #if SUPPORT_TRAIN | |||
| func_graph_ = LoadMindIR(flag->modelFile, true); | |||
| if (func_graph_ != nullptr) { | |||
| return RET_OK; | |||
| } else { | |||
| MS_LOG(ERROR) << "Parse new mind_ir proto failed, Trying old onnx format"; | |||
| if (flag->trainModel) { | |||
| func_graph_ = LoadMindIR(flag->modelFile, true); | |||
| if (func_graph_ != nullptr) { | |||
| return RET_OK; | |||
| } else { | |||
| MS_LOG(ERROR) << "Parse new mind_ir proto failed, Trying old onnx format"; | |||
| } | |||
| } | |||
| #endif | |||
| onnx_model_ = ReadOnnxFromBinary(flag->modelFile); | |||
| if (onnx_model_ == nullptr) { | |||
| MS_LOG(DEBUG) << "Parse model failed, which is not an old mindir model"; | |||
| @@ -24,39 +24,40 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| static const std::vector<schema::PrimitiveType> nhwcOpList = { | |||
| #ifdef SUPPORT_TRAIN | |||
| schema::PrimitiveType_Conv2DGradFilter, | |||
| schema::PrimitiveType_Conv2DGradInput, | |||
| schema::PrimitiveType_GroupConv2DGradInput, | |||
| schema::PrimitiveType_PoolingGrad, | |||
| schema::PrimitiveType_BiasGrad, | |||
| schema::PrimitiveType_BNGrad, | |||
| schema::PrimitiveType_ApplyMomentum, | |||
| schema::PrimitiveType_Sgd, | |||
| schema::PrimitiveType_Adam, | |||
| #endif | |||
| schema::PrimitiveType_Conv2D, | |||
| schema::PrimitiveType_DeConv2D, | |||
| schema::PrimitiveType_DepthwiseConv2D, | |||
| schema::PrimitiveType_DeDepthwiseConv2D, | |||
| schema::PrimitiveType_Pooling, | |||
| schema::PrimitiveType_LocalResponseNormalization, | |||
| schema::PrimitiveType_Resize, | |||
| schema::PrimitiveType_BatchNorm, | |||
| schema::PrimitiveType_FusedBatchNorm, | |||
| schema::PrimitiveType_PReLU, | |||
| schema::PrimitiveType_BiasAdd, | |||
| schema::PrimitiveType_SpaceToDepth, | |||
| schema::PrimitiveType_DepthToSpace, | |||
| schema::PrimitiveType_TopK}; | |||
| static const std::vector<schema::PrimitiveType> nhwcOpList = {schema::PrimitiveType_Conv2DGradFilter, | |||
| schema::PrimitiveType_Conv2DGradInput, | |||
| schema::PrimitiveType_GroupConv2DGradInput, | |||
| schema::PrimitiveType_PoolingGrad, | |||
| schema::PrimitiveType_BiasGrad, | |||
| schema::PrimitiveType_BNGrad, | |||
| schema::PrimitiveType_ApplyMomentum, | |||
| schema::PrimitiveType_Sgd, | |||
| schema::PrimitiveType_Adam, | |||
| schema::PrimitiveType_Conv2D, | |||
| schema::PrimitiveType_DeConv2D, | |||
| schema::PrimitiveType_DepthwiseConv2D, | |||
| schema::PrimitiveType_DeDepthwiseConv2D, | |||
| schema::PrimitiveType_Pooling, | |||
| schema::PrimitiveType_LocalResponseNormalization, | |||
| schema::PrimitiveType_Resize, | |||
| schema::PrimitiveType_BatchNorm, | |||
| schema::PrimitiveType_FusedBatchNorm, | |||
| schema::PrimitiveType_PReLU, | |||
| schema::PrimitiveType_BiasAdd, | |||
| schema::PrimitiveType_SpaceToDepth, | |||
| schema::PrimitiveType_DepthToSpace, | |||
| schema::PrimitiveType_TopK}; | |||
| static const std::vector<schema::PrimitiveType> nhwcOpAllInputList = { | |||
| #ifdef SUPPORT_TRAIN | |||
| schema::PrimitiveType_PoolingGrad, schema::PrimitiveType_ActivationGrad, schema::PrimitiveType_Conv2DGradFilter, | |||
| schema::PrimitiveType_BNGrad | |||
| #endif | |||
| }; | |||
| schema::PrimitiveType_BNGrad}; | |||
| // an empty index list {} means the insert is needed for all inputs | |||
| static std::unordered_map<schema::PrimitiveType, std::vector<int>> extNhwcInsertIndex = { | |||
| {schema::PrimitiveType_BNGrad, {0, 1}}, | |||
| {schema::PrimitiveType_ApplyMomentum, {3}}, | |||
| {schema::PrimitiveType_Sgd, {1}}, | |||
| {schema::PrimitiveType_Adam, {9}}}; | |||
| static const std::vector<schema::PrimitiveType> fp32FullOpList = { | |||
| schema::PrimitiveType_Concat, schema::PrimitiveType_Add, | |||
| @@ -133,18 +134,10 @@ static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveT | |||
| schema::PrimitiveType_L2Norm}; | |||
| static const std::vector<schema::PrimitiveType> needInsertOpList = { | |||
| #ifdef SUPPORT_TRAIN | |||
| schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, | |||
| schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Split, | |||
| schema::PrimitiveType_Crop, schema::PrimitiveType_Mul, schema::PrimitiveType_Add, | |||
| schema::PrimitiveType_ActivationGrad | |||
| #else | |||
| schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, | |||
| schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Add, | |||
| schema::PrimitiveType_Split, schema::PrimitiveType_Slice, schema::PrimitiveType_Crop, | |||
| schema::PrimitiveType_Mul, schema::PrimitiveType_Maximum | |||
| #endif | |||
| }; | |||
| schema::PrimitiveType_Mul, schema::PrimitiveType_Maximum, schema::PrimitiveType_ActivationGrad}; | |||
| static const std::unordered_map<int, int> nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}}; | |||
| @@ -156,6 +149,8 @@ std::vector<schema::PrimitiveType> Getfp32FullOpList() { return fp32FullOpList; | |||
| std::vector<schema::PrimitiveType> GetNhwcOpList() { return nhwcOpList; } | |||
| std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes() { return extNhwcInsertIndex; } | |||
| std::vector<schema::PrimitiveType> GetNhwcAllInputOpList() { return nhwcOpAllInputList; } | |||
| std::vector<schema::PrimitiveType> GetUint8NhwcOpList() { return int8NeedNhwcOpList; } | |||
| @@ -62,6 +62,8 @@ std::vector<schema::PrimitiveType> GetNhwcOpList(); | |||
| std::vector<schema::PrimitiveType> GetNhwcAllInputOpList(); | |||
| std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes(); | |||
| std::vector<schema::PrimitiveType> Getfp32FullOpList(); | |||
| std::vector<schema::PrimitiveType> GetUint8NhwcOpList(); | |||
| @@ -101,12 +101,6 @@ set(LITE_SRC | |||
| ${SRC_DIR}/dequant.cc | |||
| ${SRC_DIR}/huffman_decode.cc | |||
| ) | |||
| if(SUPPORT_TRAIN) | |||
| set(LITE_SRC | |||
| ${LITE_SRC} | |||
| ) | |||
| endif() | |||
| set(ARM_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src/runtime/kernel/arm) | |||
| file(GLOB KERNEL_SRC | |||
| ${ARM_DIR}/base/*.cc | |||
| @@ -177,6 +177,7 @@ int AnfTransform::RunMindirAdjustPass(const FuncGraphPtr &old_graph, const conve | |||
| auto mindir_adjust_pass = std::make_shared<opt::MindirAdjustPass>(); | |||
| mindir_adjust_pass->SetFmkType(config->fmk); | |||
| mindir_adjust_pass->SetQuantType(config->quantType); | |||
| mindir_adjust_pass->SetTrainFlag(config->trainModel); | |||
| if (!mindir_adjust_pass->Run(old_graph)) { | |||
| MS_LOG(ERROR) << "mindir adjust failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| @@ -88,7 +88,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { | |||
| } | |||
| // anf -- fb | |||
| auto meta_graph = Export(graph); | |||
| auto meta_graph = Export(graph, false, false, flag->trainModel); | |||
| if (meta_graph == nullptr) { | |||
| MS_LOG(ERROR) << "Export to meta graph return nullptr"; | |||
| return nullptr; | |||
| @@ -48,21 +48,8 @@ STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatT | |||
| FormatTransNodeType *afterNodeType) { | |||
| if (fmkType == converter::FmkType_TFLITE) { // inference by nhwc | |||
| return RET_NO_CHANGE; | |||
| } else if (fmkType == converter::FmkType_CAFFE) { // inference by nchw | |||
| if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) { | |||
| return RET_NO_CHANGE; | |||
| } | |||
| *beforeNodeType = kNCHW2NHWC; | |||
| *afterNodeType = kNHWC2NCHW; | |||
| return RET_OK; | |||
| } else if (fmkType == converter::FmkType_MS) { | |||
| if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) { | |||
| return RET_NO_CHANGE; | |||
| } | |||
| *beforeNodeType = kNCHW2NHWC; | |||
| *afterNodeType = kNHWC2NCHW; | |||
| return RET_OK; | |||
| } else if (fmkType == converter::FmkType_ONNX) { | |||
| } else if (fmkType == converter::FmkType_CAFFE || fmkType == converter::FmkType_MS || | |||
| fmkType == converter::FmkType_ONNX) { | |||
| if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) { | |||
| return RET_NO_CHANGE; | |||
| } | |||
| @@ -173,11 +160,19 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { | |||
| if (node->primitive->value.type == schema::PrimitiveType_DepthToSpace) { | |||
| reinterpret_cast<schema::DepthToSpaceT *>(attr)->format = schema::Format_NHWC; | |||
| } | |||
| #ifdef SUPPORT_TRAIN | |||
| if (IsContain(GetNhwcAllInputOpList(), GetCNodeTType(**iter))) { | |||
| int idx_num = node->inputIndex.size(); | |||
| if (GetCNodeTType(**iter) == schema::PrimitiveType_BNGrad) idx_num = 2; | |||
| for (int i = 0; i < idx_num; i++) { | |||
| auto specInsertIndexes = GetExtNhwcIndexes(); | |||
| auto opType = GetCNodeTType(**iter); | |||
| if (specInsertIndexes.find(opType) != specInsertIndexes.end()) { | |||
| for (auto insert_index : specInsertIndexes[opType]) { | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, insert_index, beforeNodeType, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } else if (IsContain(GetNhwcAllInputOpList(), opType)) { | |||
| auto input_size = node->inputIndex.size(); | |||
| for (size_t i = 0; i < input_size; i++) { | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, i, beforeNodeType, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed"; | |||
| @@ -185,23 +180,8 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { | |||
| } | |||
| } | |||
| } else { | |||
| int idx = 0; | |||
| if (GetCNodeTType(**iter) == schema::PrimitiveType_ApplyMomentum) idx = 3; | |||
| if (GetCNodeTType(**iter) == schema::PrimitiveType_Sgd) idx = 1; | |||
| if (GetCNodeTType(**iter) == schema::PrimitiveType_Adam) idx = 9; | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, idx, beforeNodeType, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| #else | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, 0, beforeNodeType, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed"; | |||
| return RET_ERROR; | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, 0, beforeNodeType, &status); | |||
| } | |||
| #endif | |||
| iter = InsertFormatTransNode(graph, iter, kAfter, 0, afterNodeType, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed"; | |||
| @@ -194,7 +194,6 @@ STATUS GlobalFormatTransformPass::FindPreNh2NcNodes(MetaGraphT *graph, size_t nc | |||
| if (!IsContain(bfs_queue, input_node_index)) { | |||
| bfs_queue.emplace_back(input_node_index); | |||
| } | |||
| // todo multi output,other edge need insert nh2nc node | |||
| auto pre_node_output_indexs = GetOutputNodeIdx(*graph, *pre_node); | |||
| if (pre_node_output_indexs.size() != 1) { | |||
| if (node_type == schema::PrimitiveType_Activation || node_type == schema::PrimitiveType_Concat) { | |||
| @@ -29,30 +29,25 @@ std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1}; | |||
| std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2}; | |||
| } // namespace | |||
| namespace lite { | |||
| bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) { | |||
| MS_ASSERT(graph != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| auto input_node_indexes = GetInputNodeIdx(*graph, *node); | |||
| pre_type_ = kNONE; | |||
| size_t has_trans_count = 0; | |||
| auto can_fusion = true; | |||
| for (auto input_node_index : input_node_indexes) { | |||
| bool IsInOutCanFusion(schema::MetaGraphT *graph, const std::vector<size_t> &node_indexes, size_t *has_trans_count, | |||
| FormatTransNodeType *trans_type) { | |||
| for (auto input_node_index : node_indexes) { | |||
| MS_ASSERT(graph->nodes.size() > input_node_index); | |||
| auto &pre_node = graph->nodes.at(input_node_index); | |||
| MS_ASSERT(pre_node != nullptr); | |||
| MS_ASSERT(pre_node->primitive != nullptr); | |||
| MS_ASSERT(pre_node->primitive->value != nullptr); | |||
| if (pre_type_ == kNONE) { | |||
| if (*trans_type == kNONE) { | |||
| if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) { | |||
| MS_ASSERT(pre_node->primitive->value.AsTranspose() != nullptr); | |||
| if (pre_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) { | |||
| pre_type_ = kNCHW2NHWC; | |||
| *trans_type = kNCHW2NHWC; | |||
| } else if (pre_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) { | |||
| pre_type_ = kNHWC2NCHW; | |||
| *trans_type = kNHWC2NCHW; | |||
| } else { | |||
| return false; | |||
| } | |||
| has_trans_count++; | |||
| (*has_trans_count)++; | |||
| } | |||
| } else { | |||
| if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) { | |||
| @@ -64,57 +59,28 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p | |||
| } else { | |||
| return false; | |||
| } | |||
| if (pre_type_ != cur_type) { | |||
| can_fusion = false; | |||
| break; | |||
| if (*trans_type != cur_type) { | |||
| return false; | |||
| } else { | |||
| has_trans_count++; | |||
| (*has_trans_count)++; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (!can_fusion) { | |||
| return true; | |||
| } | |||
| bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) { | |||
| MS_ASSERT(graph != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| auto input_node_indexes = GetInputNodeIdx(*graph, *node); | |||
| pre_type_ = kNONE; | |||
| size_t has_trans_count = 0; | |||
| if (!IsInOutCanFusion(graph, input_node_indexes, &has_trans_count, &pre_type_)) { | |||
| return false; | |||
| } | |||
| auto output_node_indexes = GetOutputNodeIdx(*graph, *node); | |||
| post_type_ = kNONE; | |||
| for (auto output_node_index : output_node_indexes) { | |||
| MS_ASSERT(graph->nodes.size() > output_node_index); | |||
| auto &post_node = graph->nodes.at(output_node_index); | |||
| MS_ASSERT(post_node != nullptr); | |||
| MS_ASSERT(post_node->primitive != nullptr); | |||
| MS_ASSERT(post_node->primitive->value != nullptr); | |||
| if (post_type_ == kNONE) { | |||
| if (post_node->primitive->value.type == schema::PrimitiveType_Transpose) { | |||
| if (post_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) { | |||
| post_type_ = kNCHW2NHWC; | |||
| } else if (post_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) { | |||
| post_type_ = kNHWC2NCHW; | |||
| } else { | |||
| return false; | |||
| } | |||
| has_trans_count++; | |||
| } | |||
| } else { | |||
| if (post_node->primitive->value.type == schema::PrimitiveType_Transpose) { | |||
| auto cur_type = kNONE; | |||
| if (post_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) { | |||
| cur_type = kNCHW2NHWC; | |||
| } else if (post_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) { | |||
| cur_type = kNHWC2NCHW; | |||
| } else { | |||
| return false; | |||
| } | |||
| if (post_type_ != cur_type) { | |||
| can_fusion = false; | |||
| break; | |||
| } else { | |||
| has_trans_count++; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (!can_fusion) { | |||
| if (!IsInOutCanFusion(graph, output_node_indexes, &has_trans_count, &post_type_)) { | |||
| return false; | |||
| } | |||
| if (pre_type_ == kNONE && post_type_ == kNONE) { | |||
| @@ -136,10 +102,8 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p | |||
| if (GetCNodeTType(*node) == schema::PrimitiveType_Split) { | |||
| return has_trans_count >= half_count; | |||
| } | |||
| can_fusion = has_trans_count > half_count; | |||
| return can_fusion; | |||
| return has_trans_count > half_count; | |||
| } | |||
| STATUS TransOpInsertPass::FindOutTransType() { | |||
| pre_insert_trans_type_ = kNHWC2NCHW; | |||
| post_insert_trans_type_ = kNHWC2NCHW; | |||
| @@ -153,7 +117,7 @@ STATUS TransOpInsertPass::FindOutTransType() { | |||
| MS_ASSERT(false); | |||
| } else { | |||
| if (pre_type_ == post_type_) { | |||
| MS_LOG(ERROR) << "Unknow error"; | |||
| MS_LOG(ERROR) << "Unknown error"; | |||
| return RET_ERROR; | |||
| } | |||
| pre_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW; | |||
| @@ -200,13 +164,6 @@ STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) { | |||
| STATUS status = RET_OK; | |||
| auto input_tensor_size = (*iter)->inputIndex.size(); | |||
| for (size_t i = 0; i < input_tensor_size; i++) { | |||
| #ifdef SUPPORT_TRAIN | |||
| auto &tensor = graph->allTensors.at((*iter)->inputIndex[i]); | |||
| MS_ASSERT(tensor != nullptr); | |||
| if (tensor->nodeType == schema::NodeType_ValueNode) { | |||
| continue; | |||
| } | |||
| #endif | |||
| auto &input_tensor = graph->allTensors.at((*iter)->inputIndex[i]); | |||
| if (input_tensor->nodeType == NodeType_ValueNode && input_tensor->dims.size() < 4) { | |||
| continue; | |||
| @@ -37,6 +37,15 @@ ParamValueLitePtr NewParamValueLitePtr(lite::Tensor *tensor) { | |||
| return para_value_lite; | |||
| } | |||
| bool IsSpecialType(schema::PrimitiveType type) { | |||
| if ((type == schema::PrimitiveType_TupleGetItem) || (type == schema::PrimitiveType_Depend) || | |||
| (type == schema::PrimitiveType_ControlDepend) || | |||
| (type == schema::PrimitiveType_MakeTuple || type == schema::PrimitiveType_Return)) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor) { | |||
| MS_ASSERT(nullptr != tensor); | |||
| std::vector<int> shape(tensor->shape()); | |||
| @@ -363,12 +372,7 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) { | |||
| return false; | |||
| } | |||
| auto type = GetCNodeType(cnode); | |||
| if ((type == schema::PrimitiveType_TupleGetItem) || | |||
| #ifdef SUPPORT_TRAIN | |||
| (type == schema::PrimitiveType_Depend) || (type == schema::PrimitiveType_ControlDepend) || | |||
| #endif | |||
| (type == schema::PrimitiveType_MakeTuple || type == schema::PrimitiveType_Return)) { | |||
| if (IsSpecialType(type)) { | |||
| continue; | |||
| } | |||
| std::vector<lite::Tensor *> input_tensors; | |||
| @@ -147,7 +147,7 @@ int MindirAdjustPass::PrimitiveConvert(std::shared_ptr<AnfNode> anf_node) { | |||
| auto inputs = cnode->inputs(); | |||
| inputs.erase(inputs.begin()); | |||
| if (!CheckPrimitiveType(anf_node, prim::kPrimReturn) && !CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) { | |||
| auto primitive_c = PrimitiveC::Create(*primitive, inputs, quant_type_); | |||
| auto primitive_c = PrimitiveC::Create(*primitive, inputs, quant_type_, train_flag_); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "fail to create a primitive_c: " << cnode->fullname_with_scope(); | |||
| lite::NoSupportOp::GetInstance()->InsertOp(primitive->name()); | |||
| @@ -33,6 +33,7 @@ class MindirAdjustPass : public Pass { | |||
| void SetQuantType(QuantType quant_type) { quant_type_ = quant_type; } | |||
| void SetFmkType(FmkType fmk_type) { fmk_type_ = fmk_type; } | |||
| int ValueNodeInt64Convert(AnfNodePtr anf_node); | |||
| void SetTrainFlag(bool train_flag) { train_flag_ = train_flag; } | |||
| int ParameterNodeConvert(AnfNodePtr anf_node); | |||
| int PrimitiveConvert(AnfNodePtr anf_node); | |||
| bool Run(const FuncGraphPtr &graph) override; | |||
| @@ -40,6 +41,7 @@ class MindirAdjustPass : public Pass { | |||
| protected: | |||
| QuantType quant_type_ = QuantType::QuantType_QUANT_NONE; | |||
| FmkType fmk_type_ = FmkType::FmkType_MS; | |||
| bool train_flag_ = false; | |||
| }; | |||
| } // namespace mindspore::opt | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_ | |||
| @@ -131,12 +131,10 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node, | |||
| param_value->set_format(schema::Format::Format_CKHW); | |||
| } else if (op_type == schema::PrimitiveType_DeConv2D) { | |||
| param_value->set_format(schema::Format::Format_KCHW); | |||
| #ifdef SUPPORT_TRAIN | |||
| } else if (op_type == schema::PrimitiveType_Conv2DGradInput) { | |||
| param_value->set_format(schema::Format::Format_KCHW); | |||
| } else if (op_type == schema::PrimitiveType_GroupConv2DGradInput) { | |||
| param_value->set_format(schema::Format::Format_CKHW); | |||
| #endif | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) | |||
| << ", node: " << conv_node->fullname_with_scope(); | |||
| @@ -213,10 +211,8 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) { | |||
| auto conv_cnode = node->cast<CNodePtr>(); | |||
| auto type = opt::GetCNodeType(node); | |||
| if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D && | |||
| #ifdef SUPPORT_TRAIN | |||
| ((type != schema::PrimitiveType_Conv2DGradInput) || (fmk_type != FmkType_MS)) && | |||
| ((type != schema::PrimitiveType_GroupConv2DGradInput) || (fmk_type != FmkType_MS)) && | |||
| #endif | |||
| type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { | |||
| continue; | |||
| } | |||
| @@ -43,11 +43,9 @@ lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr | |||
| continue; | |||
| } | |||
| auto type = opt::GetCNodeType(node); | |||
| if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D | |||
| #ifdef SUPPORT_TRAIN | |||
| && type != schema::PrimitiveType_Conv2DGradInput && type != schema::PrimitiveType_GroupConv2DGradInput | |||
| #endif | |||
| && type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { | |||
| if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D && | |||
| type != schema::PrimitiveType_Conv2DGradInput && type != schema::PrimitiveType_GroupConv2DGradInput && | |||
| type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { | |||
| continue; | |||
| } | |||
| auto conv_cnode = node->cast<CNodePtr>(); | |||