From: @zhengjun10
Tag: tags/v1.2.0-rc1
@@ -114,16 +114,6 @@ int AddN::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs
       max_dim = dim;
     }
   }
-#ifndef SUPPORT_TRAIN
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    size_t shift = max_dims - inputs.at(i)->shape().size();
-    size_t dim = (i < shift) ? 1 : inputs.at(i)->shape().at(d);
-    if ((dim != max_dim) && (dim != 1)) {
-      MS_LOG(ERROR) << "AddN inputs shape is not equal!";
-      return RET_INPUT_TENSOR_ERROR;
-    }
-  }
-#endif
   output->shape()[d] = max_dim;  // set the biggest dimension in the output tensor
 }
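The deleted block was the inference-only broadcast check for AddN inputs. A minimal sketch, assuming the usual broadcasting semantics, of the rule it enforced along one axis:

    // Hedged sketch (not the project's code) of the broadcast rule the removed
    // inference-only guard enforced: along an axis, every input extent must
    // equal the axis maximum or be 1.
    #include <vector>

    bool AxisIsBroadcastable(const std::vector<int> &extents) {
      int max_dim = 1;
      for (int e : extents) {
        if (e > max_dim) max_dim = e;  // the extent the output takes on this axis
      }
      for (int e : extents) {
        if (e != max_dim && e != 1) return false;  // cannot broadcast to max_dim
      }
      return true;
    }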
@@ -149,13 +149,13 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
   attr->padRight = pad_list.at(3);
   auto dilation = CastToInt(prim.GetAttr("dilation"));
-#ifdef SUPPORT_TRAIN
-  attr->dilateH = dilation.at(2);
-  attr->dilateW = dilation.at(3);
-#else
-  attr->dilateH = dilation.at(0);
-  attr->dilateW = dilation.at(1);
-#endif
+  if (train_flag()) {
+    attr->dilateH = dilation.at(2);
+    attr->dilateW = dilation.at(3);
+  } else {
+    attr->dilateH = dilation.at(0);
+    attr->dilateW = dilation.at(1);
+  }
   auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
   attr->kernelH = kernel_size.at(0);
   attr->kernelW = (kernel_size.size() > 1) ? kernel_size.at(1) : kernel_size.at(0);
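The two branches differ only in which indices of the dilation attribute they read. A hedged sketch of that selection, assuming from the diff that the training frontend stores a 4-element NCHW-ordered vector while the inference path exports just (H, W):

    // Hedged sketch: pick dilation H/W from either a 4-element NCHW-style
    // attribute (training frontend, H/W at indices 2/3) or a 2-element (H, W)
    // attribute (inference frontend). The layouts are assumptions from the diff.
    #include <cassert>
    #include <utility>
    #include <vector>

    std::pair<int, int> DilationHW(const std::vector<int> &dilation, bool train_flag) {
      if (train_flag) {
        assert(dilation.size() >= 4);
        return {dilation.at(2), dilation.at(3)};
      }
      assert(dilation.size() >= 2);
      return {dilation.at(0), dilation.at(1)};
    }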
@@ -19,9 +19,6 @@
 #ifndef PRIMITIVE_WRITEABLE
 #include "src/ops/ops_register.h"
 #endif
-#ifdef SUPPORT_TRAIN
-#include <tuple>
-#endif

 namespace mindspore {
 namespace lite {
@@ -56,20 +53,16 @@ int Pad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs
   }
   string paddingmode = "REFLECT";
   if (prim.GetAttr("mode") == nullptr) {
-#ifdef SUPPORT_TRAIN
     if (prim.name() == "Pad") {
       paddingmode = "CONSTANT";
     } else {
-#endif
       MS_LOG(ERROR) << "get mode failed!";
       delete this->primitive_;
       delete attr;
       this->primitive_ = nullptr;
       attr = nullptr;
       return RET_ERROR;
-#ifdef SUPPORT_TRAIN
     }
-#endif
   } else {
     paddingmode = GetValue<string>(prim.GetAttr("mode"));
   }
@@ -77,7 +70,6 @@ int Pad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs
     attr->paddingMode = schema::PaddingMode_REFLECT;
   } else if (paddingmode == "SYMMETRIC") {
     attr->paddingMode = schema::PaddingMode_SYMMETRIC;
-#ifdef SUPPORT_TRAIN
   } else if (paddingmode == "CONSTANT") {
     attr->paddingMode = schema::PaddingMode_CONSTANT;
     if (prim.GetAttr("paddings") != nullptr) {
@@ -91,7 +83,6 @@ int Pad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs
         attr->paddings.push_back(i);
       }
     }
-#endif
   } else {
     MS_LOG(ERROR) << "model type not supported!";
     delete this->primitive_;
@@ -18,7 +18,6 @@
 #ifdef PRIMITIVE_WRITEABLE
 #include <memory>
 #include <map>
 #include "tools/converter/quantizer/quantize_util.h"
 #include "src/ops/assert_op.h"
 #include "src/ops/space_to_batch.h"
@@ -175,8 +174,6 @@
 #include "src/ops/uniform_real.h"
 #include "src/ops/rank.h"
 #include "src/ops/is_finite.h"
-#ifdef SUPPORT_TRAIN
 #include "src/ops/neg_grad.h"
 #include "src/ops/activation_grad.h"
 #include "src/ops/apply_momentum.h"
@@ -210,7 +207,6 @@
 #include "src/ops/sigmoid_cross_entropy_with_logits_grad.h"
 #include "src/ops/strided_slice_grad.h"
 #endif
-#endif

 namespace mindspore {
 namespace lite {
 #ifdef PRIMITIVE_WRITEABLE
@@ -513,13 +509,14 @@ std::shared_ptr<PrimitiveC> GetTupleGetItemPrim() {
 template <typename T, typename = std::enable_if<std::is_base_of<PrimitiveC, T>::value>>
 std::shared_ptr<PrimitiveC> NewPrimitiveC(const mindspore::Primitive &prim, const std::vector<AnfNodePtr> &inputs,
-                                          const schema::QuantType &quantType) {
+                                          const schema::QuantType &quantType, bool train_flag = false) {
   auto primc = std::make_shared<T>();
   if (primc == nullptr) {
     MS_LOG(ERROR) << "make_shared PrimitiveC failed";
     return nullptr;
   }
   primc->set_quant_type(quantType);
+  primc->set_train_flag(train_flag);
   auto ret = primc->UnPackAttr(prim, inputs);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "UnPackAttr failed";
@@ -529,7 +526,7 @@ std::shared_ptr<PrimitiveC> NewPrimitiveC(const mindspore::Primitive &prim, cons
 }

 std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs,
-                                               const schema::QuantType &quantType) {
+                                               const schema::QuantType &quantType, bool train_flag) {
   const auto &op_type = prim.name();
   if (op_type == "ReLU" || op_type == "ReLU6" || op_type == "Sigmoid" || op_type == "HSwish" || op_type == "HSigmoid") {
     return NewPrimitiveC<Activation>(prim, inputs, quantType);
@@ -544,7 +541,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
   } else if (op_type == "Concat") {
     return NewPrimitiveC<Concat>(prim, inputs, quantType);
   } else if (op_type == "Conv2D") {
-    return NewPrimitiveC<Conv2D>(prim, inputs, quantType);
+    return NewPrimitiveC<Conv2D>(prim, inputs, quantType, train_flag);
   } else if (op_type == "Cos") {
     return NewPrimitiveC<Cos>(prim, inputs, quantType);
   } else if (op_type == "DepthwiseConv2dNative" || op_type == "DepthwiseConv2D") {
@@ -664,7 +661,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
   } else if (op_type == "Range") {
     return NewPrimitiveC<Range>(prim, inputs, quantType);
   } else if (op_type == "Tile") {
-    return NewPrimitiveC<Tile>(prim, inputs, quantType);
+    return NewPrimitiveC<Tile>(prim, inputs, quantType, train_flag);
   } else if (op_type == "GatherNd") {
     return NewPrimitiveC<GatherNd>(prim, inputs, quantType);
   } else if (op_type == "Square") {
@@ -685,7 +682,6 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
     return NewPrimitiveC<ArgMax>(prim, inputs, quantType);
   } else if (op_type == "Gelu") {
     return NewPrimitiveC<GeLU>(prim, inputs, quantType);
-#ifdef SUPPORT_TRAIN
   } else if (op_type == "SoftmaxCrossEntropyWithLogits") {
     return NewPrimitiveC<SoftmaxCrossEntropy>(prim, inputs, quantType);
   } else if (op_type == "SparseSoftmaxCrossEntropyWithLogits") {
@@ -706,7 +702,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
     return NewPrimitiveC<PoolingGrad>(prim, inputs, quantType);
   } else if (op_type == "Conv2DBackpropFilter") {
     return NewPrimitiveC<Conv2DGradFilter>(prim, inputs, quantType);
-  } else if (op_type == "Conv2DBackpropInput") {
+  } else if (op_type == "Conv2DBackpropInput" && train_flag) {
     return NewPrimitiveC<Conv2DGradInput>(prim, inputs, quantType);
   } else if ((op_type == "BatchNormGrad") || (op_type == "FusedBatchNormGradEx")) {
     return NewPrimitiveC<BNGrad>(prim, inputs, quantType);
@@ -748,10 +744,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
     return NewPrimitiveC<StridedSliceGrad>(prim, inputs, quantType);
   } else if (op_type == "AbsGrad") {
     return NewPrimitiveC<AbsGrad>(prim, inputs, quantType);
-#else
-  } else if (op_type == "Conv2DBackpropInput") {
+  } else if (op_type == "Conv2DBackpropInput" && !train_flag) {
     return NewPrimitiveC<DeConv2D>(prim, inputs, quantType);
-#endif
   } else {
     MS_LOG(ERROR) << "Unsupported primitive type in Create : " << op_type;
     return nullptr;
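With both branches now visible in one translation unit, the same frontend op name resolves to different lite ops at runtime rather than at build time. A minimal sketch of the dispatch idea, with stand-in types:

    // Hedged sketch: "Conv2DBackpropInput" maps to a gradient op when
    // converting a training graph and to a transposed convolution otherwise.
    // Types here are stand-ins, not the project's classes.
    #include <memory>

    struct Op { virtual ~Op() = default; };
    struct Conv2DGradInput : Op {};  // training: gradient w.r.t. the conv input
    struct DeConv2D : Op {};         // inference: transposed convolution

    std::unique_ptr<Op> CreateConv2DBackpropInput(bool train_flag) {
      if (train_flag) return std::make_unique<Conv2DGradInput>();
      return std::make_unique<DeConv2D>();
    }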
@@ -1065,7 +1059,6 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
       return new (std::nothrow) UniformReal(primitive);
     case schema::PrimitiveType_Rank:
       return new (std::nothrow) Rank(primitive);
-#ifdef SUPPORT_TRAIN
     case schema::PrimitiveType_ActivationGrad:
       return new (std::nothrow) ActivationGrad(primitive);
     case schema::PrimitiveType_PoolingGrad:
@@ -1140,7 +1133,6 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
       return new (std::nothrow) SigmoidCrossEntropyWithLogitsGrad(primitive);
     case schema::PrimitiveType_StridedSliceGrad:
       return new (std::nothrow) StridedSliceGrad(primitive);
-#endif
     default:
       MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type);
       break;
@@ -1170,6 +1162,10 @@ bool PrimitiveC::infer_flag() const { return this->infer_flag_; }

 void PrimitiveC::set_infer_flag(bool flag) { this->infer_flag_ = flag; }

+bool PrimitiveC::train_flag() const { return this->train_flag_; }
+
+void PrimitiveC::set_train_flag(bool flag) { this->train_flag_ = flag; }
+
 int PrimitiveC::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
   auto input = inputs.front();
   MS_ASSERT(input != nullptr);
@@ -133,6 +133,10 @@ class PrimitiveC : public mindspore::Primitive {
   void set_infer_flag(bool flag);

+  bool train_flag() const;
+
+  void set_train_flag(bool flag);
+
   static PrimitiveC *Create(mindspore::schema::Primitive *primitive) { return Create(primitive->UnPack()); }

   static PrimitiveC *Create(mindspore::schema::PrimitiveT *primitive);
@@ -140,7 +144,7 @@ class PrimitiveC : public mindspore::Primitive {
   static void GetAttrDataFromInput(const AnfNodePtr &inputNode, std::vector<int> *data);

   static std::shared_ptr<PrimitiveC> Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs,
-                                            const schema::QuantType &quantType);
+                                            const schema::QuantType &quantType, bool train_flag = false);

   void PopulaterQuantParam(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
   void FillDefaultInputQuantParamIfNeed(const size_t &inputSize);
   void PopulaterInputQuantParam(const Primitive &prim, const std::vector<AnfNodePtr> &inputs,
@@ -159,6 +163,7 @@ class PrimitiveC : public mindspore::Primitive {
   bool infer_flag_ = true;
   int op_type_ = OP_TYPE_NOT_SET;
   bool enable_huffman_code_ = false;
+  bool train_flag_ = false;
 };

 std::shared_ptr<PrimitiveC> GetReturnPrim();
@@ -179,6 +184,10 @@ class PrimitiveC {
   void set_infer_flag(bool flag);

+  bool train_flag() const;
+
+  void set_train_flag(bool flag);
+
   virtual int InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs);

   int Type() const;
@@ -238,6 +247,7 @@ class PrimitiveC {
   bool infer_flag_ = true;
   schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
   int op_type_ = OP_TYPE_NOT_SET;
+  bool train_flag_ = false;
 };

 using PrimitiveCPtr = std::shared_ptr<PrimitiveC>;
 typedef PrimitiveC *(*PrimitiveCCreator)(const schema::Primitive *primitive);
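These header changes carry the core design choice of the PR: a per-primitive runtime flag, defaulting to false, replaces the SUPPORT_TRAIN build switch, so one converter binary serves both modes. A self-contained sketch of the pattern:

    // Hedged sketch of the flag-threading pattern: the converter sets the flag
    // once (from config->trainModel in this PR) and ops read it later.
    #include <iostream>

    class PrimitiveLike {
     public:
      bool train_flag() const { return train_flag_; }
      void set_train_flag(bool flag) { train_flag_ = flag; }

     private:
      bool train_flag_ = false;  // default mirrors the diff: inference unless set
    };

    int main() {
      PrimitiveLike p;
      p.set_train_flag(true);
      std::cout << (p.train_flag() ? "train path" : "inference path") << "\n";
    }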
@@ -159,41 +159,41 @@ int Tile::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
   } else {
     multiples = GetMultiples();
   }
-#ifdef SUPPORT_TRAIN
-  const size_t in_dims = input->shape().size();
-  const size_t delta_dims = in_dims - multiples.size();
-  size_t i = 0;
-  for (; i < delta_dims; ++i) {
-    int tmp = input->shape().at(i);
-    out_shape.push_back(tmp);
-  }
-  for (; i < in_dims; ++i) {
-    int tmp = input->shape().at(i) * (multiples[i - delta_dims]);
-    out_shape.push_back(tmp);
-  }
-#else
-  std::vector<int> dims = GetDims();
-  if (inputs_.size() == 2 && dims.empty()) {
-    for (int dim = 0; dim < inputs_[1]->ElementsNum(); ++dim) {
-      dims.push_back(dim);
-    }
-  }
-  const size_t in_dims = input->shape().size();
-  MS_ASSERT(multiples.size() == dims.size());
-  for (size_t i = 0; i < in_dims; ++i) {
-    out_shape.push_back(input->shape().at(i));
-  }
-  for (size_t i = 0; i < dims.size(); ++i) {
-    if (input->shape().at(dims.at(i)) != 0 &&
-        multiples.at(i) > std::numeric_limits<int>::max() / input->shape().at(dims.at(i))) {
-      MS_LOG(ERROR) << "The value of multiples[" << i << "] is too big";
-      return RET_ERROR;
-    }
-    out_shape.at(dims.at(i)) = input->shape().at(dims.at(i)) * (multiples.at(i));
-  }
-#endif
+  if (train_flag()) {
+    const size_t in_dims = input->shape().size();
+    const size_t delta_dims = in_dims - multiples.size();
+    size_t i = 0;
+    for (; i < delta_dims; ++i) {
+      int tmp = input->shape().at(i);
+      out_shape.push_back(tmp);
+    }
+    for (; i < in_dims; ++i) {
+      int tmp = input->shape().at(i) * (multiples[i - delta_dims]);
+      out_shape.push_back(tmp);
+    }
+  } else {
+    std::vector<int> dims = GetDims();
+    if (inputs_.size() == 2 && dims.empty()) {
+      for (int dim = 0; dim < inputs_[1]->ElementsNum(); ++dim) {
+        dims.push_back(dim);
+      }
+    }
+    const size_t in_dims = input->shape().size();
+    MS_ASSERT(multiples.size() == dims.size());
+    for (size_t i = 0; i < in_dims; ++i) {
+      out_shape.push_back(input->shape().at(i));
+    }
+    for (size_t i = 0; i < dims.size(); ++i) {
+      if (input->shape().at(dims.at(i)) != 0 &&
+          multiples.at(i) > std::numeric_limits<int>::max() / input->shape().at(dims.at(i))) {
+        MS_LOG(ERROR) << "The value of multiples[" << i << "] is too big";
+        return RET_ERROR;
+      }
+      out_shape.at(dims.at(i)) = input->shape().at(dims.at(i)) * (multiples.at(i));
+    }
+  }
   output->set_shape(out_shape);
   return RET_OK;
 }
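The inference branch guards the per-axis multiply against 32-bit overflow by comparing the multiple against INT_MAX divided by the extent before multiplying. A runnable sketch of that shape rule:

    // Hedged sketch of the inference-side Tile shape rule kept by the diff:
    // scale each listed axis by its multiple, refusing products that would
    // overflow int (the division-based check cannot itself overflow).
    #include <cstdio>
    #include <limits>
    #include <vector>

    bool TileOutShape(const std::vector<int> &in_shape, const std::vector<int> &dims,
                      const std::vector<int> &multiples, std::vector<int> *out_shape) {
      *out_shape = in_shape;
      for (size_t i = 0; i < dims.size(); ++i) {
        int extent = in_shape.at(dims.at(i));
        if (extent != 0 && multiples.at(i) > std::numeric_limits<int>::max() / extent) {
          return false;  // extent * multiples[i] would overflow
        }
        out_shape->at(dims.at(i)) = extent * multiples.at(i);
      }
      return true;
    }

    int main() {
      std::vector<int> out;
      bool ok = TileOutShape({2, 3}, {0, 1}, {4, 2}, &out);  // expect {8, 6}
      std::printf("%d: %d %d\n", ok, out[0], out[1]);
    }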
@@ -328,15 +328,15 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
     }
     RemoveIfMakeTuple(cnode);
-#ifdef SUPPORT_TRAIN
-    RemoveIfDepend(cnode);
-#endif
+    if (train_flag) {
+      RemoveIfDepend(cnode);
+      if (primitive_c->Type() == schema::PrimitiveType_Depend ||
+          primitive_c->Type() == schema::PrimitiveType_ControlDepend) {
+        continue;
+      }
+    }
     if ((primitive_c->Type() == schema::PrimitiveType_TupleGetItem) ||
-#ifdef SUPPORT_TRAIN
-        (primitive_c->Type() == schema::PrimitiveType_Depend) ||
-        (primitive_c->Type() == schema::PrimitiveType_ControlDepend) ||
-#endif
         (primitive_c->Type() == schema::PrimitiveType_MakeTuple)) {
       continue;
     }
@@ -424,8 +424,10 @@ int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::uniqu
   return RET_OK;
 }

-schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) {
+schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive,
+                                        bool train_flag) {
   static int subgraph_index = 0;
+  this->train_flag = train_flag;
   auto meta_graphT = std::make_unique<schema::MetaGraphT>();
   int ret = ExportSubgraph(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive);
   if (ret != RET_OK) {
@@ -439,24 +441,18 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode,
   std::string input_name = input_anode->fullname_with_scope();
   auto input_cnode = utils::cast<CNodePtr>(input_anode);
   if (!IsPrimitiveCNode(input_cnode, schema::PrimitiveType_TupleGetItem)) {
-#ifndef SUPPORT_TRAIN
-    if (node_id_map_.find(input_name) != node_id_map_.end()) {
-      output_cnode->inputIndex.emplace_back(node_id_map_[input_name]);
-    }
-#else
     bool found = false;
     if (node_id_map_.find(input_name) != node_id_map_.end()) {
       output_cnode->inputIndex.emplace_back(node_id_map_[input_name]);
       found = true;
     }
-    if (found == false) {
+    if (!found) {
       auto input_index_key = input_name + "_o:" + std::to_string(0);
       if (node_id_map_.find(input_index_key) != node_id_map_.end()) {
         output_cnode->inputIndex.emplace_back(node_id_map_[input_index_key]);
       }
     }
-#endif
   } else {
     auto inputs = input_cnode->inputs();
@@ -481,17 +477,12 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode,
                                    : GetValue<int>(value_node->value()));
     auto iter = node_id_map_.find(input_index_key);
     if (iter == node_id_map_.end()) {
-#ifdef SUPPORT_TRAIN
       input_index_key = get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(0);  // try name with 0
       iter = node_id_map_.find(input_index_key);
       if (iter == node_id_map_.end()) {
         MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key;
         return RET_ERROR;
       }
-#else
-      MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key;
-      return RET_ERROR;
-#endif
     }
     output_cnode->inputIndex.emplace_back(iter->second);
   }
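The training-only fallback lookup is now the unconditional behavior: resolve by exact name first, then retry with the "_o:0" suffix under which multi-output nodes register. A minimal sketch of that two-step lookup:

    // Hedged sketch of the two-step tensor-id lookup: exact node name first,
    // then the "<name>_o:0" key that multi-output nodes register for their
    // first output.
    #include <optional>
    #include <string>
    #include <unordered_map>

    std::optional<int> FindTensorId(const std::unordered_map<std::string, int> &ids,
                                    const std::string &name) {
      if (auto it = ids.find(name); it != ids.end()) return it->second;
      if (auto it = ids.find(name + "_o:0"); it != ids.end()) return it->second;
      return std::nullopt;  // caller logs "Can not find get_item output tensor"
    }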
@@ -571,9 +562,7 @@ int AnfExporter::ProcessTensor(const ValueNodePtr &valueNode, std::unique_ptr<sc
   (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(dims),
                        [](const int64_t &value) { return static_cast<int32_t>(value); });
   (*paramTensor)->dims = dims;
-#ifdef SUPPORT_TRAIN
-  if ((*paramTensor)->dims.size() == 0) (*paramTensor)->dims = {1};
-#endif
+  if (train_flag && (*paramTensor)->dims.empty()) (*paramTensor)->dims = {1};
   (*paramTensor)->nodeType = schema::NodeType::NodeType_ValueNode;
   auto data = value->cast<tensor::TensorPtr>();
   (*paramTensor)->data.resize(data->Size());
@@ -679,11 +668,11 @@ int AnfExporter::ProcessParamValueLite(const ValueNodePtr &valueNode, std::uniqu
   (*paramTensor)->format = schema::Format(valueLite->format());
   (*paramTensor)->dataType = valueLite->tensor_type();
   (*paramTensor)->dims = valueLite->tensor_shape();
-#ifdef SUPPORT_TRAIN
-  if ((*paramTensor)->dims.size() == 0) {
+  if (train_flag && (*paramTensor)->dims.empty()) {
     (*paramTensor)->dims = {1};
   }
-#endif
   ret = memcpy_s((*paramTensor)->data.data(), valueLite->tensor_size() * sizeof(uint8_t), valueLite->tensor_addr(),
                  valueLite->tensor_size());
   if (ret != EOK) {
@@ -703,9 +692,9 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano
   auto paramTensor = std::make_unique<schema::TensorT>();
   auto value = valueNode->value();
   int ret = RET_OK;
-#ifdef SUPPORT_TRAIN
-  paramTensor->name = valueNode->fullname_with_scope();
-#endif
+  if (train_flag) {
+    paramTensor->name = valueNode->fullname_with_scope();
+  }
   if (value->isa<tensor::Tensor>()) {
     ret = ProcessTensor(valueNode, &paramTensor, value, output_cnode, meta_graphT);
   } else if (value->isa<mindspore::Int32Imm>() || value->isa<mindspore::Int64Imm>()) {
@@ -797,44 +786,44 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
       }
       msTensor->nodeType = schema::NodeType_CNode;
       fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size());
-#ifdef SUPPORT_TRAIN
-      std::string name = cnode_name + "_o:" + std::to_string(i);
-      node_id_map_[name] = meta_graphT->allTensors.size();
-      meta_graphT->allTensors.emplace_back(msTensor);
-      if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) ||
-          IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) ||
-          IsPrimitiveCNode(cnode, schema::PrimitiveType_Adam))
-        break;
-#else
-      if (elements.size() == 1) {
-        node_id_map_[cnode_name] = meta_graphT->allTensors.size();
-        msTensor->name = cnode_name;
-      } else {
+      if (train_flag) {
         std::string name = cnode_name + "_o:" + std::to_string(i);
         node_id_map_[name] = meta_graphT->allTensors.size();
-        msTensor->name = name;
-      }
-      if (!utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
-        MS_LOG(ERROR) << "abstract is not AbstractTensor";
-        delete (msTensor);
-        return;
-      }
-      auto type = kNumberTypeFloat32;
-      if (utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
-        auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]);
-        auto typePtr = abstract_tensor->element()->GetTypeTrack();
-        type = typePtr->type_id();
-      }
-      msTensor->dataType = type;
-      meta_graphT->allTensors.emplace_back(msTensor);
-      if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) ||
-          IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) ||
-          IsPrimitiveCNode(cnode, schema::PrimitiveType_FusedBatchNorm) ||
-          IsPrimitiveCNode(cnode, schema::PrimitiveType_LayerNorm)) {
-        break;
+        msTensor->name = name;
+        meta_graphT->allTensors.emplace_back(msTensor);
+        if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) ||
+            IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) ||
+            IsPrimitiveCNode(cnode, schema::PrimitiveType_Adam))
+          break;
+      } else {
+        if (elements.size() == 1) {
+          node_id_map_[cnode_name] = meta_graphT->allTensors.size();
+          msTensor->name = cnode_name;
+        } else {
+          std::string name = cnode_name + "_o:" + std::to_string(i);
+          node_id_map_[name] = meta_graphT->allTensors.size();
+          msTensor->name = name;
+        }
+        if (!utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
+          MS_LOG(ERROR) << "abstract is not AbstractTensor";
+          delete (msTensor);
+          return;
+        }
+        auto type = kNumberTypeFloat32;
+        if (utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
+          auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]);
+          auto typePtr = abstract_tensor->element()->GetTypeTrack();
+          type = typePtr->type_id();
+        }
+        msTensor->dataType = type;
+        meta_graphT->allTensors.emplace_back(msTensor);
+        if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) ||
+            IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) ||
+            IsPrimitiveCNode(cnode, schema::PrimitiveType_FusedBatchNorm) ||
+            IsPrimitiveCNode(cnode, schema::PrimitiveType_LayerNorm)) {
+          break;
+        }
       }
-#endif
     }
   } else {
     auto ms_tensor = new (std::nothrow) schema::TensorT();
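Both branches register output tensors under the same naming scheme, which is what the input-resolution fallback shown earlier relies on. A one-function sketch of that assumed convention:

    // Hedged sketch of the assumed output-tensor naming scheme: a single-output
    // node keeps its own name; output i of a multi-output node is "<name>_o:i".
    #include <string>

    std::string OutputTensorName(const std::string &cnode_name, size_t n_outputs, size_t i) {
      return (n_outputs == 1) ? cnode_name : cnode_name + "_o:" + std::to_string(i);
    }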
@@ -927,8 +916,8 @@ CNodePtr AnfExporter::CreatePartialCnode(const FuncGraphPtr &fg, AnfNodePtr node
   }
 }

-schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) {
+schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive, bool train_flag) {
   AnfExporter anf_exporter;
-  return anf_exporter.Export(func_graph, keep_graph, copy_primitive);
+  return anf_exporter.Export(func_graph, keep_graph, copy_primitive, train_flag);
 }
 }  // namespace mindspore::lite
@@ -35,7 +35,8 @@ class AnfExporter {
  public:
   AnfExporter() = default;
   virtual ~AnfExporter() = default;
-  schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false);
+  schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false,
+                             bool train_flag = false);
   void SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
                        schema::CNodeT *fb_node);
   int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
@@ -91,11 +92,13 @@ class AnfExporter {
   std::vector<schema::CNodeT *> graph_input_nodes_;
   std::map<FuncGraphPtr, int> fg_subgraph_map;
   uint32_t node_idx = 0;
+  bool train_flag = false;
 };

 // by default, copy_primitive is false, which means that the MetaGraph and func_graph share the same schema::PrimitiveT.
 // but in PostQuantization, the func_graph need to transfer to MetaGraph first and do MetaGraph pass, which may modify
 // the schema::PrimitiveT and cause bug; If all the passes have been done in func_graph, every thing would be simple
 // and clear.
-schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false);
+schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false,
+                           bool train_flag = false);
 }  // namespace mindspore::lite
 #endif  // MINDSPORE_LITE_TOOLS_ANF_EXPORTER_ANF_EXPORTER_H_
@@ -855,14 +855,14 @@ int AnfImporterFromMindir::ParseModelConfigureInfo(const onnx::ModelProto &model
 }

 int AnfImporterFromMindir::Import(const converter::Flags *flag) {
-#if SUPPORT_TRAIN
-  func_graph_ = LoadMindIR(flag->modelFile, true);
-  if (func_graph_ != nullptr) {
-    return RET_OK;
-  } else {
-    MS_LOG(ERROR) << "Parse new mind_ir proto failed, Trying old onnx format";
+  if (flag->trainModel) {
+    func_graph_ = LoadMindIR(flag->modelFile, true);
+    if (func_graph_ != nullptr) {
+      return RET_OK;
+    } else {
+      MS_LOG(ERROR) << "Parse new mind_ir proto failed, Trying old onnx format";
+    }
   }
-#endif
   onnx_model_ = ReadOnnxFromBinary(flag->modelFile);
   if (onnx_model_ == nullptr) {
     MS_LOG(DEBUG) << "Parse model failed, which is not an old mindir model";
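The importer now decides at runtime: for training models it tries the new MindIR format first and falls back to the legacy ONNX-based binary when parsing fails. A sketch of the pattern with stub loaders (the loader names are stand-ins for LoadMindIR/ReadOnnxFromBinary, not the project's API):

    // Hedged sketch of load-with-fallback; both loaders are hypothetical stubs.
    #include <memory>
    #include <string>

    struct Graph {};
    std::unique_ptr<Graph> TryLoadNewFormat(const std::string &) { return nullptr; }  // stub: fails
    std::unique_ptr<Graph> TryLoadLegacyFormat(const std::string &) { return std::make_unique<Graph>(); }

    std::unique_ptr<Graph> LoadModel(const std::string &path, bool train_model) {
      if (train_model) {
        if (auto g = TryLoadNewFormat(path)) return g;
        // log: parse of the new format failed, trying the old format
      }
      return TryLoadLegacyFormat(path);
    }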
@@ -24,39 +24,40 @@
 namespace mindspore {
 namespace lite {
-static const std::vector<schema::PrimitiveType> nhwcOpList = {
-#ifdef SUPPORT_TRAIN
-  schema::PrimitiveType_Conv2DGradFilter,
-  schema::PrimitiveType_Conv2DGradInput,
-  schema::PrimitiveType_GroupConv2DGradInput,
-  schema::PrimitiveType_PoolingGrad,
-  schema::PrimitiveType_BiasGrad,
-  schema::PrimitiveType_BNGrad,
-  schema::PrimitiveType_ApplyMomentum,
-  schema::PrimitiveType_Sgd,
-  schema::PrimitiveType_Adam,
-#endif
-  schema::PrimitiveType_Conv2D,
-  schema::PrimitiveType_DeConv2D,
-  schema::PrimitiveType_DepthwiseConv2D,
-  schema::PrimitiveType_DeDepthwiseConv2D,
-  schema::PrimitiveType_Pooling,
-  schema::PrimitiveType_LocalResponseNormalization,
-  schema::PrimitiveType_Resize,
-  schema::PrimitiveType_BatchNorm,
-  schema::PrimitiveType_FusedBatchNorm,
-  schema::PrimitiveType_PReLU,
-  schema::PrimitiveType_BiasAdd,
-  schema::PrimitiveType_SpaceToDepth,
-  schema::PrimitiveType_DepthToSpace,
-  schema::PrimitiveType_TopK};
+static const std::vector<schema::PrimitiveType> nhwcOpList = {schema::PrimitiveType_Conv2DGradFilter,
+                                                              schema::PrimitiveType_Conv2DGradInput,
+                                                              schema::PrimitiveType_GroupConv2DGradInput,
+                                                              schema::PrimitiveType_PoolingGrad,
+                                                              schema::PrimitiveType_BiasGrad,
+                                                              schema::PrimitiveType_BNGrad,
+                                                              schema::PrimitiveType_ApplyMomentum,
+                                                              schema::PrimitiveType_Sgd,
+                                                              schema::PrimitiveType_Adam,
+                                                              schema::PrimitiveType_Conv2D,
+                                                              schema::PrimitiveType_DeConv2D,
+                                                              schema::PrimitiveType_DepthwiseConv2D,
+                                                              schema::PrimitiveType_DeDepthwiseConv2D,
+                                                              schema::PrimitiveType_Pooling,
+                                                              schema::PrimitiveType_LocalResponseNormalization,
+                                                              schema::PrimitiveType_Resize,
+                                                              schema::PrimitiveType_BatchNorm,
+                                                              schema::PrimitiveType_FusedBatchNorm,
+                                                              schema::PrimitiveType_PReLU,
+                                                              schema::PrimitiveType_BiasAdd,
+                                                              schema::PrimitiveType_SpaceToDepth,
+                                                              schema::PrimitiveType_DepthToSpace,
+                                                              schema::PrimitiveType_TopK};

 static const std::vector<schema::PrimitiveType> nhwcOpAllInputList = {
-#ifdef SUPPORT_TRAIN
   schema::PrimitiveType_PoolingGrad, schema::PrimitiveType_ActivationGrad, schema::PrimitiveType_Conv2DGradFilter,
-  schema::PrimitiveType_BNGrad
-#endif
-};
+  schema::PrimitiveType_BNGrad};

+// index {} mean all inputs need insert
+static std::unordered_map<schema::PrimitiveType, std::vector<int>> extNhwcInsertIndex = {
+  {schema::PrimitiveType_BNGrad, {0, 1}},
+  {schema::PrimitiveType_ApplyMomentum, {3}},
+  {schema::PrimitiveType_Sgd, {1}},
+  {schema::PrimitiveType_Adam, {9}}};
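The new extNhwcInsertIndex table turns the former per-op if-chains into data. A sketch of how a pass can consume such a table, with a simplified stand-in op enum:

    // Hedged sketch of table-driven insertion replacing hard-coded special
    // cases: ops in the map get a format transpose before the listed input
    // indexes; others fall back to a generic rule (input 0 only).
    #include <cstdio>
    #include <unordered_map>
    #include <vector>

    enum class OpType { BNGrad, ApplyMomentum, Sgd, Adam, Other };

    static const std::unordered_map<OpType, std::vector<int>> kInsertIndex = {
        {OpType::BNGrad, {0, 1}}, {OpType::ApplyMomentum, {3}},
        {OpType::Sgd, {1}}, {OpType::Adam, {9}}};

    void InsertTransposes(OpType op) {
      auto it = kInsertIndex.find(op);
      if (it == kInsertIndex.end()) {
        std::printf("generic rule: transpose before input 0\n");
        return;
      }
      for (int idx : it->second) std::printf("transpose before input %d\n", idx);
    }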
 static const std::vector<schema::PrimitiveType> fp32FullOpList = {
   schema::PrimitiveType_Concat, schema::PrimitiveType_Add,
@@ -133,18 +134,10 @@ static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveT
   schema::PrimitiveType_L2Norm};

 static const std::vector<schema::PrimitiveType> needInsertOpList = {
-#ifdef SUPPORT_TRAIN
-  schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
-  schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Split,
-  schema::PrimitiveType_Crop, schema::PrimitiveType_Mul, schema::PrimitiveType_Add,
-  schema::PrimitiveType_ActivationGrad
-#else
   schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
   schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Add,
   schema::PrimitiveType_Split, schema::PrimitiveType_Slice, schema::PrimitiveType_Crop,
-  schema::PrimitiveType_Mul, schema::PrimitiveType_Maximum
-#endif
-};
+  schema::PrimitiveType_Mul, schema::PrimitiveType_Maximum, schema::PrimitiveType_ActivationGrad};
| @@ -156,6 +149,8 @@ std::vector<schema::PrimitiveType> Getfp32FullOpList() { return fp32FullOpList; | |||||
| std::vector<schema::PrimitiveType> GetNhwcOpList() { return nhwcOpList; } | std::vector<schema::PrimitiveType> GetNhwcOpList() { return nhwcOpList; } | ||||
| std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes() { return extNhwcInsertIndex; } | |||||
| std::vector<schema::PrimitiveType> GetNhwcAllInputOpList() { return nhwcOpAllInputList; } | std::vector<schema::PrimitiveType> GetNhwcAllInputOpList() { return nhwcOpAllInputList; } | ||||
| std::vector<schema::PrimitiveType> GetUint8NhwcOpList() { return int8NeedNhwcOpList; } | std::vector<schema::PrimitiveType> GetUint8NhwcOpList() { return int8NeedNhwcOpList; } | ||||
| @@ -62,6 +62,8 @@ std::vector<schema::PrimitiveType> GetNhwcOpList(); | |||||
| std::vector<schema::PrimitiveType> GetNhwcAllInputOpList(); | std::vector<schema::PrimitiveType> GetNhwcAllInputOpList(); | ||||
| std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes(); | |||||
| std::vector<schema::PrimitiveType> Getfp32FullOpList(); | std::vector<schema::PrimitiveType> Getfp32FullOpList(); | ||||
| std::vector<schema::PrimitiveType> GetUint8NhwcOpList(); | std::vector<schema::PrimitiveType> GetUint8NhwcOpList(); | ||||
@@ -101,12 +101,6 @@ set(LITE_SRC
     ${SRC_DIR}/dequant.cc
    ${SRC_DIR}/huffman_decode.cc
     )
-if(SUPPORT_TRAIN)
-    set(LITE_SRC
-        ${LITE_SRC}
-        )
-endif()
 set(ARM_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src/runtime/kernel/arm)
 file(GLOB KERNEL_SRC
     ${ARM_DIR}/base/*.cc
@@ -177,6 +177,7 @@ int AnfTransform::RunMindirAdjustPass(const FuncGraphPtr &old_graph, const conve
   auto mindir_adjust_pass = std::make_shared<opt::MindirAdjustPass>();
   mindir_adjust_pass->SetFmkType(config->fmk);
   mindir_adjust_pass->SetQuantType(config->quantType);
+  mindir_adjust_pass->SetTrainFlag(config->trainModel);
   if (!mindir_adjust_pass->Run(old_graph)) {
     MS_LOG(ERROR) << "mindir adjust failed.";
     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
@@ -88,7 +88,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
   }

   // anf -- fb
-  auto meta_graph = Export(graph);
+  auto meta_graph = Export(graph, false, false, flag->trainModel);
   if (meta_graph == nullptr) {
     MS_LOG(ERROR) << "Export to meta graph return nullptr";
     return nullptr;
@@ -48,21 +48,8 @@ STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatT
                                              FormatTransNodeType *afterNodeType) {
   if (fmkType == converter::FmkType_TFLITE) {  // inference by nhwc
     return RET_NO_CHANGE;
-  } else if (fmkType == converter::FmkType_CAFFE) {  // inference by nchw
-    if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) {
-      return RET_NO_CHANGE;
-    }
-    *beforeNodeType = kNCHW2NHWC;
-    *afterNodeType = kNHWC2NCHW;
-    return RET_OK;
-  } else if (fmkType == converter::FmkType_MS) {
-    if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) {
-      return RET_NO_CHANGE;
-    }
-    *beforeNodeType = kNCHW2NHWC;
-    *afterNodeType = kNHWC2NCHW;
-    return RET_OK;
-  } else if (fmkType == converter::FmkType_ONNX) {
+  } else if (fmkType == converter::FmkType_CAFFE || fmkType == converter::FmkType_MS ||
+             fmkType == converter::FmkType_ONNX) {
     if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) {
       return RET_NO_CHANGE;
     }
@@ -173,11 +160,19 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
     if (node->primitive->value.type == schema::PrimitiveType_DepthToSpace) {
       reinterpret_cast<schema::DepthToSpaceT *>(attr)->format = schema::Format_NHWC;
     }
-#ifdef SUPPORT_TRAIN
-    if (IsContain(GetNhwcAllInputOpList(), GetCNodeTType(**iter))) {
-      int idx_num = node->inputIndex.size();
-      if (GetCNodeTType(**iter) == schema::PrimitiveType_BNGrad) idx_num = 2;
-      for (int i = 0; i < idx_num; i++) {
+    auto specInsertIndexes = GetExtNhwcIndexes();
+    auto opType = GetCNodeTType(**iter);
+    if (specInsertIndexes.find(opType) != specInsertIndexes.end()) {
+      for (auto insert_index : specInsertIndexes[opType]) {
+        iter = InsertFormatTransNode(graph, iter, kBefore, insert_index, beforeNodeType, &status);
+        if (status != RET_OK) {
+          MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed";
+          return RET_ERROR;
+        }
+      }
+    } else if (IsContain(GetNhwcAllInputOpList(), opType)) {
+      auto input_size = node->inputIndex.size();
+      for (size_t i = 0; i < input_size; i++) {
         iter = InsertFormatTransNode(graph, iter, kBefore, i, beforeNodeType, &status);
         if (status != RET_OK) {
           MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed";
@@ -185,23 +180,8 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
         }
       }
     } else {
-      int idx = 0;
-      if (GetCNodeTType(**iter) == schema::PrimitiveType_ApplyMomentum) idx = 3;
-      if (GetCNodeTType(**iter) == schema::PrimitiveType_Sgd) idx = 1;
-      if (GetCNodeTType(**iter) == schema::PrimitiveType_Adam) idx = 9;
-      iter = InsertFormatTransNode(graph, iter, kBefore, idx, beforeNodeType, &status);
-      if (status != RET_OK) {
-        MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed";
-        return RET_ERROR;
-      }
-    }
-#else
-    iter = InsertFormatTransNode(graph, iter, kBefore, 0, beforeNodeType, &status);
-    if (status != RET_OK) {
-      MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed";
-      return RET_ERROR;
+      iter = InsertFormatTransNode(graph, iter, kBefore, 0, beforeNodeType, &status);
     }
-#endif
     iter = InsertFormatTransNode(graph, iter, kAfter, 0, afterNodeType, &status);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed";
@@ -194,7 +194,6 @@ STATUS GlobalFormatTransformPass::FindPreNh2NcNodes(MetaGraphT *graph, size_t nc
       if (!IsContain(bfs_queue, input_node_index)) {
         bfs_queue.emplace_back(input_node_index);
       }
-      // todo multi output,other edge need insert nh2nc node
       auto pre_node_output_indexs = GetOutputNodeIdx(*graph, *pre_node);
       if (pre_node_output_indexs.size() != 1) {
         if (node_type == schema::PrimitiveType_Activation || node_type == schema::PrimitiveType_Concat) {
@@ -29,30 +29,25 @@ std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1};
 std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2};
 }  // namespace
 namespace lite {
-bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) {
-  MS_ASSERT(graph != nullptr);
-  MS_ASSERT(node != nullptr);
-  auto input_node_indexes = GetInputNodeIdx(*graph, *node);
-  pre_type_ = kNONE;
-  size_t has_trans_count = 0;
-  auto can_fusion = true;
-  for (auto input_node_index : input_node_indexes) {
+bool IsInOutCanFusion(schema::MetaGraphT *graph, const std::vector<size_t> &node_indexes, size_t *has_trans_count,
+                      FormatTransNodeType *trans_type) {
+  for (auto input_node_index : node_indexes) {
     MS_ASSERT(graph->nodes.size() > input_node_index);
     auto &pre_node = graph->nodes.at(input_node_index);
     MS_ASSERT(pre_node != nullptr);
     MS_ASSERT(pre_node->primitive != nullptr);
     MS_ASSERT(pre_node->primitive->value != nullptr);
-    if (pre_type_ == kNONE) {
+    if (*trans_type == kNONE) {
       if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) {
         MS_ASSERT(pre_node->primitive->value.AsTranspose() != nullptr);
         if (pre_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) {
-          pre_type_ = kNCHW2NHWC;
+          *trans_type = kNCHW2NHWC;
         } else if (pre_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
-          pre_type_ = kNHWC2NCHW;
+          *trans_type = kNHWC2NCHW;
         } else {
           return false;
         }
-        has_trans_count++;
+        (*has_trans_count)++;
       }
     } else {
       if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) {
@@ -64,57 +59,28 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p
         } else {
           return false;
         }
-        if (pre_type_ != cur_type) {
-          can_fusion = false;
-          break;
+        if (*trans_type != cur_type) {
+          return false;
         } else {
-          has_trans_count++;
+          (*has_trans_count)++;
         }
       }
     }
   }
-  if (!can_fusion) {
+  return true;
+}
+
+bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) {
+  MS_ASSERT(graph != nullptr);
+  MS_ASSERT(node != nullptr);
+  auto input_node_indexes = GetInputNodeIdx(*graph, *node);
+  pre_type_ = kNONE;
+  size_t has_trans_count = 0;
+  if (!IsInOutCanFusion(graph, input_node_indexes, &has_trans_count, &pre_type_)) {
     return false;
   }
   auto output_node_indexes = GetOutputNodeIdx(*graph, *node);
   post_type_ = kNONE;
-  for (auto output_node_index : output_node_indexes) {
-    MS_ASSERT(graph->nodes.size() > output_node_index);
-    auto &post_node = graph->nodes.at(output_node_index);
-    MS_ASSERT(post_node != nullptr);
-    MS_ASSERT(post_node->primitive != nullptr);
-    MS_ASSERT(post_node->primitive->value != nullptr);
-    if (post_type_ == kNONE) {
-      if (post_node->primitive->value.type == schema::PrimitiveType_Transpose) {
-        if (post_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) {
-          post_type_ = kNCHW2NHWC;
-        } else if (post_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
-          post_type_ = kNHWC2NCHW;
-        } else {
-          return false;
-        }
-        has_trans_count++;
-      }
-    } else {
-      if (post_node->primitive->value.type == schema::PrimitiveType_Transpose) {
-        auto cur_type = kNONE;
-        if (post_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) {
-          cur_type = kNCHW2NHWC;
-        } else if (post_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
-          cur_type = kNHWC2NCHW;
-        } else {
-          return false;
-        }
-        if (post_type_ != cur_type) {
-          can_fusion = false;
-          break;
-        } else {
-          has_trans_count++;
-        }
-      }
-    }
-  }
-  if (!can_fusion) {
+  if (!IsInOutCanFusion(graph, output_node_indexes, &has_trans_count, &post_type_)) {
    return false;
   }
   if (pre_type_ == kNONE && post_type_ == kNONE) {
| @@ -136,10 +102,8 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p | |||||
| if (GetCNodeTType(*node) == schema::PrimitiveType_Split) { | if (GetCNodeTType(*node) == schema::PrimitiveType_Split) { | ||||
| return has_trans_count >= half_count; | return has_trans_count >= half_count; | ||||
| } | } | ||||
| can_fusion = has_trans_count > half_count; | |||||
| return can_fusion; | |||||
| return has_trans_count > half_count; | |||||
| } | } | ||||
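The producer and consumer scans now funnel through one shared helper. Its declaration is implied by the two call sites above; the following is a minimal sketch, assuming the parameter types (the index vector returned by GetInputNodeIdx/GetOutputNodeIdx, and FormatTransNodeType as the name of the enum behind pre_type_/post_type_, are not shown in this hunk):

  // Assumed companion declaration in the pass header, reconstructed from the
  // call sites only; the real helper body is the refactored loop shown above.
  bool IsInOutCanFusion(schema::MetaGraphT *graph, const std::vector<size_t> &node_indexes,
                        size_t *has_trans_count, FormatTransNodeType *trans_type);

Passing trans_type by pointer lets one body serve both directions: the first transpose encountered fixes the direction, and any later transpose with the opposite permutation vetoes the fusion.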
| STATUS TransOpInsertPass::FindOutTransType() { | STATUS TransOpInsertPass::FindOutTransType() { | ||||
| pre_insert_trans_type_ = kNHWC2NCHW; | pre_insert_trans_type_ = kNHWC2NCHW; | ||||
| post_insert_trans_type_ = kNHWC2NCHW; | post_insert_trans_type_ = kNHWC2NCHW; | ||||
| @@ -153,7 +117,7 @@ STATUS TransOpInsertPass::FindOutTransType() { | |||||
| MS_ASSERT(false); | MS_ASSERT(false); | ||||
| } else { | } else { | ||||
| if (pre_type_ == post_type_) { | if (pre_type_ == post_type_) { | ||||
| MS_LOG(ERROR) << "Unknow error"; | |||||
| MS_LOG(ERROR) << "Unknown error"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| pre_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW; | pre_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW; | ||||
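This assignment (and the matching one for post_insert_trans_type_) always picks the inverse of the direction that was detected, so the inserted transpose cancels against the one it later fuses with. A one-line helper states the invariant; a sketch only, again assuming FormatTransNodeType as the enum name (kNONE is already excluded by the guard above):

  // The inserted transpose must invert the detected one: kNHWC2NCHW <-> kNCHW2NHWC.
  FormatTransNodeType InvertTransType(FormatTransNodeType t) {
    return t == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
  }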
| @@ -200,13 +164,6 @@ STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) { | |||||
| STATUS status = RET_OK; | STATUS status = RET_OK; | ||||
| auto input_tensor_size = (*iter)->inputIndex.size(); | auto input_tensor_size = (*iter)->inputIndex.size(); | ||||
| for (size_t i = 0; i < input_tensor_size; i++) { | for (size_t i = 0; i < input_tensor_size; i++) { | ||||
| #ifdef SUPPORT_TRAIN | |||||
| auto &tensor = graph->allTensors.at((*iter)->inputIndex[i]); | |||||
| MS_ASSERT(tensor != nullptr); | |||||
| if (tensor->nodeType == schema::NodeType_ValueNode) { | |||||
| continue; | |||||
| } | |||||
| #endif | |||||
| auto &input_tensor = graph->allTensors.at((*iter)->inputIndex[i]); | auto &input_tensor = graph->allTensors.at((*iter)->inputIndex[i]); | ||||
| if (input_tensor->nodeType == NodeType_ValueNode && input_tensor->dims.size() < 4) { | if (input_tensor->nodeType == NodeType_ValueNode && input_tensor->dims.size() < 4) { | ||||
| continue; | continue; | ||||
| @@ -37,6 +37,15 @@ ParamValueLitePtr NewParamValueLitePtr(lite::Tensor *tensor) { | |||||
| return para_value_lite; | return para_value_lite; | ||||
| } | } | ||||
| bool IsSpecialType(schema::PrimitiveType type) { | |||||
| if ((type == schema::PrimitiveType_TupleGetItem) || (type == schema::PrimitiveType_Depend) || | |||||
| (type == schema::PrimitiveType_ControlDepend) || | |||||
| (type == schema::PrimitiveType_MakeTuple || type == schema::PrimitiveType_Return)) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor) { | abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor) { | ||||
| MS_ASSERT(nullptr != tensor); | MS_ASSERT(nullptr != tensor); | ||||
| std::vector<int> shape(tensor->shape()); | std::vector<int> shape(tensor->shape()); | ||||
| @@ -363,12 +372,7 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto type = GetCNodeType(cnode); | auto type = GetCNodeType(cnode); | ||||
| if ((type == schema::PrimitiveType_TupleGetItem) || | |||||
| #ifdef SUPPORT_TRAIN | |||||
| (type == schema::PrimitiveType_Depend) || (type == schema::PrimitiveType_ControlDepend) || | |||||
| #endif | |||||
| (type == schema::PrimitiveType_MakeTuple || type == schema::PrimitiveType_Return)) { | |||||
| if (IsSpecialType(type)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| std::vector<lite::Tensor *> input_tensors; | std::vector<lite::Tensor *> input_tensors; | ||||
| @@ -147,7 +147,7 @@ int MindirAdjustPass::PrimitiveConvert(std::shared_ptr<AnfNode> anf_node) { | |||||
| auto inputs = cnode->inputs(); | auto inputs = cnode->inputs(); | ||||
| inputs.erase(inputs.begin()); | inputs.erase(inputs.begin()); | ||||
| if (!CheckPrimitiveType(anf_node, prim::kPrimReturn) && !CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) { | if (!CheckPrimitiveType(anf_node, prim::kPrimReturn) && !CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) { | ||||
| auto primitive_c = PrimitiveC::Create(*primitive, inputs, quant_type_); | |||||
| auto primitive_c = PrimitiveC::Create(*primitive, inputs, quant_type_, train_flag_); | |||||
| if (primitive_c == nullptr) { | if (primitive_c == nullptr) { | ||||
| MS_LOG(ERROR) << "fail to create a primitive_c: " << cnode->fullname_with_scope(); | MS_LOG(ERROR) << "fail to create a primitive_c: " << cnode->fullname_with_scope(); | ||||
| lite::NoSupportOp::GetInstance()->InsertOp(primitive->name()); | lite::NoSupportOp::GetInstance()->InsertOp(primitive->name()); | ||||
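The extra argument at the call site implies the factory gained a fourth parameter. An assumed declaration, reconstructed from this call rather than quoted from the patch (parameter types mirror the fields of MindirAdjustPass below):

  // The trailing bool replaces the old compile-time SUPPORT_TRAIN branches in
  // the per-op attribute unpacking, so one converter binary serves both modes.
  static PrimitiveC *Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs,
                            QuantType quant_type, bool train_flag);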
| @@ -33,6 +33,7 @@ class MindirAdjustPass : public Pass { | |||||
| void SetQuantType(QuantType quant_type) { quant_type_ = quant_type; } | void SetQuantType(QuantType quant_type) { quant_type_ = quant_type; } | ||||
| void SetFmkType(FmkType fmk_type) { fmk_type_ = fmk_type; } | void SetFmkType(FmkType fmk_type) { fmk_type_ = fmk_type; } | ||||
| int ValueNodeInt64Convert(AnfNodePtr anf_node); | int ValueNodeInt64Convert(AnfNodePtr anf_node); | ||||
| void SetTrainFlag(bool train_flag) { train_flag_ = train_flag; } | |||||
| int ParameterNodeConvert(AnfNodePtr anf_node); | int ParameterNodeConvert(AnfNodePtr anf_node); | ||||
| int PrimitiveConvert(AnfNodePtr anf_node); | int PrimitiveConvert(AnfNodePtr anf_node); | ||||
| bool Run(const FuncGraphPtr &graph) override; | bool Run(const FuncGraphPtr &graph) override; | ||||
| @@ -40,6 +41,7 @@ class MindirAdjustPass : public Pass { | |||||
| protected: | protected: | ||||
| QuantType quant_type_ = QuantType::QuantType_QUANT_NONE; | QuantType quant_type_ = QuantType::QuantType_QUANT_NONE; | ||||
| FmkType fmk_type_ = FmkType::FmkType_MS; | FmkType fmk_type_ = FmkType::FmkType_MS; | ||||
| bool train_flag_ = false; | |||||
| }; | }; | ||||
| } // namespace mindspore::opt | } // namespace mindspore::opt | ||||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_ | #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_ | ||||
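With the new setter, a converter can select training behavior per run instead of per build. A hypothetical wiring sketch; the flags object and its field names are illustrative assumptions, not part of this patch:

  auto mindir_adjust_pass = std::make_shared<opt::MindirAdjustPass>();
  mindir_adjust_pass->SetFmkType(flags->fmk);
  mindir_adjust_pass->SetQuantType(flags->quantType);
  mindir_adjust_pass->SetTrainFlag(flags->trainModel);  // runtime switch, no SUPPORT_TRAIN rebuild
  if (!mindir_adjust_pass->Run(func_graph)) {
    MS_LOG(ERROR) << "mindir adjust pass failed";
    return RET_ERROR;
  }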
| @@ -131,12 +131,10 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node, | |||||
| param_value->set_format(schema::Format::Format_CKHW); | param_value->set_format(schema::Format::Format_CKHW); | ||||
| } else if (op_type == schema::PrimitiveType_DeConv2D) { | } else if (op_type == schema::PrimitiveType_DeConv2D) { | ||||
| param_value->set_format(schema::Format::Format_KCHW); | param_value->set_format(schema::Format::Format_KCHW); | ||||
| #ifdef SUPPORT_TRAIN | |||||
| } else if (op_type == schema::PrimitiveType_Conv2DGradInput) { | } else if (op_type == schema::PrimitiveType_Conv2DGradInput) { | ||||
| param_value->set_format(schema::Format::Format_KCHW); | param_value->set_format(schema::Format::Format_KCHW); | ||||
| } else if (op_type == schema::PrimitiveType_GroupConv2DGradInput) { | } else if (op_type == schema::PrimitiveType_GroupConv2DGradInput) { | ||||
| param_value->set_format(schema::Format::Format_CKHW); | param_value->set_format(schema::Format::Format_CKHW); | ||||
| #endif | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) | MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) | ||||
| << ", node: " << conv_node->fullname_with_scope(); | << ", node: " << conv_node->fullname_with_scope(); | ||||
| @@ -213,10 +211,8 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) { | |||||
| auto conv_cnode = node->cast<CNodePtr>(); | auto conv_cnode = node->cast<CNodePtr>(); | ||||
| auto type = opt::GetCNodeType(node); | auto type = opt::GetCNodeType(node); | ||||
| if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D && | if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D && | ||||
| #ifdef SUPPORT_TRAIN | |||||
| ((type != schema::PrimitiveType_Conv2DGradInput) || (fmk_type != FmkType_MS)) && | ((type != schema::PrimitiveType_Conv2DGradInput) || (fmk_type != FmkType_MS)) && | ||||
| ((type != schema::PrimitiveType_GroupConv2DGradInput) || (fmk_type != FmkType_MS)) && | ((type != schema::PrimitiveType_GroupConv2DGradInput) || (fmk_type != FmkType_MS)) && | ||||
| #endif | |||||
| type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { | type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { | ||||
| continue; | continue; | ||||
| } | } | ||||
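The double-negated clauses in the guard above are easiest to read in their positive form: the node is processed when it is a plain conv/deconv, or a grad-input conv coming from the MS framework. An equivalent sketch of the same condition, for reference only:

  bool plain_conv = type == schema::PrimitiveType_Conv2D || type == schema::PrimitiveType_DepthwiseConv2D ||
                    type == schema::PrimitiveType_DeConv2D || type == schema::PrimitiveType_DeDepthwiseConv2D;
  bool ms_grad_conv = (type == schema::PrimitiveType_Conv2DGradInput ||
                       type == schema::PrimitiveType_GroupConv2DGradInput) &&
                      fmk_type == FmkType_MS;
  if (!plain_conv && !ms_grad_conv) {
    continue;  // skip: weight format hard-coding does not apply
  }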
| @@ -43,11 +43,9 @@ lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto type = opt::GetCNodeType(node); | auto type = opt::GetCNodeType(node); | ||||
| if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D | |||||
| #ifdef SUPPORT_TRAIN | |||||
| && type != schema::PrimitiveType_Conv2DGradInput && type != schema::PrimitiveType_GroupConv2DGradInput | |||||
| #endif | |||||
| && type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { | |||||
| if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D && | |||||
| type != schema::PrimitiveType_Conv2DGradInput && type != schema::PrimitiveType_GroupConv2DGradInput && | |||||
| type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto conv_cnode = node->cast<CNodePtr>(); | auto conv_cnode = node->cast<CNodePtr>(); | ||||
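The merged condition now lists six op types in a row. If more grad ops join the list later, a lookup set is one way to keep the guard flat; an alternative sketch, not what the patch does:

  #include <set>

  // Ops whose weights need format transformation, in training and inference alike.
  static const std::set<schema::PrimitiveType> kConvWeightOps = {
      schema::PrimitiveType_Conv2D,          schema::PrimitiveType_DepthwiseConv2D,
      schema::PrimitiveType_Conv2DGradInput, schema::PrimitiveType_GroupConv2DGradInput,
      schema::PrimitiveType_DeConv2D,        schema::PrimitiveType_DeDepthwiseConv2D};
  if (kConvWeightOps.find(type) == kConvWeightOps.end()) {
    continue;  // not a conv-like weight holder
  }

A static local keeps the set initialized once; the behavior is identical to the chained comparisons in the hunk above.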