Merge pull request !21816 from 徐安越/master1tags/v1.5.0-rc1
| @@ -20,7 +20,6 @@ | |||
| #include <map> | |||
| #include <string> | |||
| #include "include/lite_utils.h" | |||
| #include "schema/inner/model_generated.h" | |||
| namespace mindspore { | |||
| namespace converter { | |||
| @@ -36,7 +35,6 @@ enum MS_API FmkType : int { | |||
| /// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser. | |||
| struct MS_API ConverterParameters { | |||
| FmkType fmk; | |||
| schema::QuantType quant_type; | |||
| std::string model_file; | |||
| std::string weight_file; | |||
| std::map<std::string, std::string> attrs; | |||
| @@ -33,7 +33,6 @@ namespace lite { | |||
| namespace { | |||
| void InitConverterParameters(const converter::Flags &flag, converter::ConverterParameters *converter_parameters) { | |||
| converter_parameters->fmk = flag.fmk; | |||
| converter_parameters->quant_type = flag.quantType; | |||
| converter_parameters->model_file = flag.modelFile; | |||
| converter_parameters->weight_file = flag.weightFile; | |||
| } | |||
| @@ -28,7 +28,6 @@ class MindirAdjust { | |||
| public: | |||
| MindirAdjust() {} | |||
| ~MindirAdjust() = default; | |||
| void SetQuantType(QuantType quant_type) { quant_type_ = quant_type; } | |||
| void SetFmkType(FmkType fmk_type) { fmk_type_ = fmk_type; } | |||
| void SetTrainFlag(bool train_flag) { train_flag_ = train_flag; } | |||
| bool Run(const FuncGraphPtr &graph); | |||
| @@ -37,7 +36,6 @@ class MindirAdjust { | |||
| int ValueNodeInt64Convert(AnfNodePtr anf_node); | |||
| int ComputeQuantParams(AnfNodePtr anf_node); | |||
| QuantType quant_type_ = QuantType::QuantType_QUANT_NONE; | |||
| FmkType fmk_type_ = FmkType::kFmkTypeMs; | |||
| bool train_flag_ = false; | |||
| }; | |||
| @@ -42,7 +42,6 @@ STATUS MindsporeImporter::Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const | |||
| } | |||
| auto mindir_adjust_pass = std::make_shared<MindirAdjust>(); | |||
| mindir_adjust_pass->SetFmkType(flag.fmk); | |||
| mindir_adjust_pass->SetQuantType(flag.quantType); | |||
| mindir_adjust_pass->SetTrainFlag(flag.trainModel); | |||
| if (!mindir_adjust_pass->Run(func_graph)) { | |||
| MS_LOG(ERROR) << "MindIr adjust failed."; | |||
| @@ -97,7 +96,6 @@ size_t MindsporeImporter::Hex2ByteArray(const std::string &hex_str, unsigned cha | |||
| } | |||
| FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) { | |||
| quant_type_ = flag.quantType; | |||
| FuncGraphPtr func_graph; | |||
| if (flag.dec_key.size() != 0) { | |||
| unsigned char key[32]; | |||
| @@ -128,7 +126,7 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) { | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeMs, flag.trainModel, flag.quantType); | |||
| auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeMs, flag.trainModel); | |||
| if (!unify_format->Run(func_graph)) { | |||
| MS_LOG(ERROR) << "Run insert transpose failed."; | |||
| return nullptr; | |||
| @@ -31,7 +31,6 @@ class MindsporeImporter { | |||
| private: | |||
| STATUS Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const converter::Flags &flag); | |||
| schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE; | |||
| size_t Hex2ByteArray(const std::string &hex_str, unsigned char *byte_array, size_t max_len); | |||
| }; | |||
| @@ -79,7 +79,6 @@ CaffeModelParser::~CaffeModelParser() = default; | |||
| FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag) { | |||
| auto model_file = flag.model_file; | |||
| auto weight_file = flag.weight_file; | |||
| quant_type_ = flag.quant_type; | |||
| STATUS status = InitOriginModel(model_file, weight_file); | |||
| if (status != RET_OK) { | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| @@ -112,7 +111,7 @@ FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag) | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeCaffe, false, quant_type_); | |||
| auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeCaffe, false); | |||
| if (!unify_format->Run(res_graph_)) { | |||
| MS_LOG(ERROR) << "Run insert transpose failed."; | |||
| return nullptr; | |||
| @@ -66,7 +66,6 @@ class CaffeModelParser : public converter::ModelParser { | |||
| caffe::NetParameter caffe_weight_; | |||
| std::unordered_map<std::string, caffe::LayerParameter> caffe_layers_; | |||
| std::unordered_map<std::string, AnfNodePtr> nodes_; | |||
| schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -60,7 +60,6 @@ std::unordered_map<int, mindspore::TypeId> TYPE_MAP = { | |||
| FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag) { | |||
| string model_file = flag.model_file; | |||
| quant_type_ = flag.quant_type; | |||
| NotSupportOp::GetInstance()->set_fmk_type("ONNX"); | |||
| res_graph_ = std::make_shared<FuncGraph>(); | |||
| auto status = InitOriginModel(model_file); | |||
| @@ -95,7 +94,7 @@ FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag) | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeOnnx, false, quant_type_); | |||
| auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeOnnx, false); | |||
| if (!unify_format->Run(res_graph_)) { | |||
| MS_LOG(ERROR) << "Run insert transpose failed."; | |||
| return nullptr; | |||
| @@ -99,7 +99,6 @@ class OnnxModelParser : public converter::ModelParser { | |||
| std::unordered_map<std::string, AnfNodePtr> anf_nodes_map_; | |||
| std::unordered_map<std::string, std::unordered_map<std::string, AnfNodePtr> *> control_nodes_map_; | |||
| std::unordered_map<std::string, std::string> child_root_map_; // for nest control flow node | |||
| schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -492,7 +492,6 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts(const std::vector<const tensor | |||
| FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) { | |||
| auto modelFile = flag.model_file; | |||
| quant_type_ = flag.quant_type; | |||
| NotSupportOp::GetInstance()->set_fmk_type("TF"); | |||
| auto status = ValidateFileStr(modelFile, ".pb"); | |||
| if (status != RET_OK) { | |||
| @@ -581,7 +580,7 @@ FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) { | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeTf, false, quant_type_); | |||
| auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeTf, false); | |||
| if (!unify_format->Run(res_graph_)) { | |||
| MS_LOG(ERROR) << "Run insert transpose failed."; | |||
| return nullptr; | |||
| @@ -108,7 +108,6 @@ class TFModelParser : public converter::ModelParser { | |||
| std::vector<std::string> while_cond_branch_name_; | |||
| std::vector<std::string> if_then_branch_name_; | |||
| std::unordered_map<std::string, int> node_output_num_; | |||
| schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE; | |||
| std::map<CNodePtr, FuncGraphPtr> while_cond_map_, while_body_map_, if_then_map_, if_else_map_; | |||
| }; | |||
| } // namespace lite | |||
| @@ -54,7 +54,6 @@ std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const std::st | |||
| FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag) { | |||
| auto model_file = flag.model_file; | |||
| quant_type_ = flag.quant_type; | |||
| // load graph | |||
| tflite_model_ = ReadTfliteModel(model_file); | |||
| if (tflite_model_ == nullptr) { | |||
| @@ -105,7 +104,7 @@ FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeTflite, false, quant_type_); | |||
| auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeTflite, false); | |||
| if (!unify_format->Run(res_graph_)) { | |||
| MS_LOG(ERROR) << "Run insert transpose failed."; | |||
| return nullptr; | |||
| @@ -52,7 +52,6 @@ class TfliteModelParser : public converter::ModelParser { | |||
| STATUS ConvertGraphOutputs(); | |||
| static STATUS SetTensorQuantParam(const tflite::TensorT *tflite_tensor, std::vector<QuantParamT> *quant_params, | |||
| int round_type = 1); | |||
| schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -32,8 +32,7 @@ constexpr int kNumIndex_0 = 0; | |||
| constexpr int kNumIndex_1 = 1; | |||
| constexpr int kNumIndex_2 = 2; | |||
| constexpr int kNumIndex_3 = 3; | |||
| STATUS DecideMINDIRConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType quant_type, | |||
| schema::Format *src_format) { | |||
| STATUS DecideMINDIRConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) { | |||
| MS_ASSERT(cnode != nullptr && src_format != nullptr); | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| if (prim == nullptr) { | |||
| @@ -47,13 +46,13 @@ STATUS DecideMINDIRConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType | |||
| } else if (format == schema::Format_NCHW) { | |||
| *src_format = schema::Format_KCHW; | |||
| } else { | |||
| MS_LOG(ERROR) << "cnode format is invalid."; | |||
| MS_LOG(ERROR) << "cnode format is invalid. " << cnode->fullname_with_scope(); | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS DecideTFConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType quant_type, schema::Format *src_format) { | |||
| STATUS DecideTFConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) { | |||
| MS_ASSERT(cnode != nullptr && src_format != nullptr); | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| if (prim == nullptr) { | |||
| @@ -61,34 +60,22 @@ STATUS DecideTFConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType quan | |||
| return lite::RET_ERROR; | |||
| } | |||
| bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise)); | |||
| switch (quant_type) { | |||
| case schema::QuantType_AwareTraining: | |||
| case schema::QuantType_PostTraining: | |||
| case schema::QuantType_WeightQuant: | |||
| case schema::QuantType_QUANT_NONE: { | |||
| if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) { | |||
| if (!is_depth_wise) { | |||
| *src_format = schema::Format_HWCK; | |||
| } else { | |||
| *src_format = schema::Format_HWKC; | |||
| } | |||
| } else if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) { | |||
| *src_format = schema::Format::Format_HWCK; | |||
| } else { | |||
| MS_LOG(ERROR) << "depthwise-conv2dTranspose need to check."; | |||
| return RET_ERROR; | |||
| } | |||
| } break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Unsupported op: " << cnode->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) { | |||
| if (!is_depth_wise) { | |||
| *src_format = schema::Format_HWCK; | |||
| } else { | |||
| *src_format = schema::Format_HWKC; | |||
| } | |||
| } else if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) { | |||
| *src_format = schema::Format::Format_HWCK; | |||
| } else { | |||
| MS_LOG(ERROR) << "depthwise-conv2dTranspose need to check. " << cnode->fullname_with_scope(); | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS DecideTFLITEConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType quant_type, | |||
| schema::Format *src_format) { | |||
| STATUS DecideTFLITEConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) { | |||
| MS_ASSERT(cnode != nullptr && src_format != nullptr); | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| if (prim == nullptr) { | |||
| @@ -96,87 +83,49 @@ STATUS DecideTFLITEConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType | |||
| return lite::RET_ERROR; | |||
| } | |||
| bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise)); | |||
| switch (quant_type) { | |||
| case schema::QuantType_AwareTraining: | |||
| case schema::QuantType_PostTraining: | |||
| case schema::QuantType_WeightQuant: | |||
| case schema::QuantType_QUANT_NONE: { | |||
| if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) { | |||
| if (!is_depth_wise) { | |||
| *src_format = schema::Format_KHWC; | |||
| } else { | |||
| *src_format = schema::Format_CHWK; | |||
| } | |||
| } else if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) { | |||
| *src_format = schema::Format_CHWK; | |||
| } else { | |||
| MS_LOG(ERROR) << "cannot decide weight format, current situation need to check."; | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| } break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) | |||
| << ", node: " << cnode->fullname_with_scope(); | |||
| return RET_ERROR; | |||
| if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) { | |||
| if (!is_depth_wise) { | |||
| *src_format = schema::Format_KHWC; | |||
| } else { | |||
| *src_format = schema::Format_CHWK; | |||
| } | |||
| } else if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) { | |||
| *src_format = schema::Format_CHWK; | |||
| } else { | |||
| MS_LOG(ERROR) << "cannot decide weight format, current situation need to check. " << cnode->fullname_with_scope(); | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS DecideCAFFEConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType quant_type, schema::Format *src_format) { | |||
| STATUS DecideCAFFEConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) { | |||
| MS_ASSERT(cnode != nullptr && src_format != nullptr); | |||
| *src_format = schema::Format_KCHW; | |||
| return RET_OK; | |||
| } | |||
| STATUS DecideONNXConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType quant_type, schema::Format *src_format) { | |||
| STATUS DecideONNXConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) { | |||
| MS_ASSERT(cnode != nullptr && src_format != nullptr); | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| if (prim == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise)); | |||
| int64_t format = | |||
| prim->GetAttr(ops::kOriginalFormat) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kOriginalFormat)) : 0; | |||
| switch (quant_type) { | |||
| case schema::QuantType_AwareTraining: { | |||
| if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) { | |||
| if (!is_depth_wise) { | |||
| *src_format = schema::Format_KHWC; | |||
| } else { | |||
| *src_format = schema::Format_CHWK; | |||
| } | |||
| } else if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) { | |||
| *src_format = schema::Format_KCHW; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported op: " << cnode->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| } break; | |||
| case schema::QuantType_PostTraining: | |||
| case schema::QuantType_WeightQuant: | |||
| case schema::QuantType_QUANT_NONE: { | |||
| if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || | |||
| opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) { | |||
| if (format == schema::Format_NHWC) { | |||
| *src_format = schema::Format_KHWC; | |||
| } else if (format == schema::Format_NCHW) { | |||
| *src_format = schema::Format_KCHW; | |||
| } else { | |||
| MS_LOG(ERROR) << "format is invalid, format is " << format; | |||
| return RET_ERROR; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "d an unsupported op type, which need to check. the type is " << prim->name(); | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| } break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) | |||
| << ", node: " << cnode->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || | |||
| opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) { | |||
| if (format == schema::Format_NHWC) { | |||
| *src_format = schema::Format_KHWC; | |||
| } else if (format == schema::Format_NCHW) { | |||
| *src_format = schema::Format_KCHW; | |||
| } else { | |||
| MS_LOG(ERROR) << "format is invalid, format is " << format; | |||
| return RET_ERROR; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "unknown op, please check."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -246,12 +195,12 @@ STATUS UnifyFormatToNHWC::DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode, | |||
| schema::Format *dst_format) { | |||
| MS_ASSERT(cnode != nullptr && src_format != nullptr && dst_format != nullptr); | |||
| *dst_format = schema::Format_KHWC; | |||
| std::map<converter::FmkType, std::function<int(const CNodePtr &, schema::QuantType, schema::Format *)>> | |||
| decide_functions = {{converter::kFmkTypeMs, DecideMINDIRConvWeightSrcFormat}, | |||
| {converter::kFmkTypeTf, DecideTFConvWeightSrcFormat}, | |||
| {converter::kFmkTypeTflite, DecideTFLITEConvWeightSrcFormat}, | |||
| {converter::kFmkTypeCaffe, DecideCAFFEConvWeightSrcFormat}, | |||
| {converter::kFmkTypeOnnx, DecideONNXConvWeightSrcFormat}}; | |||
| std::map<converter::FmkType, std::function<int(const CNodePtr &, schema::Format *)>> decide_functions = { | |||
| {converter::kFmkTypeMs, DecideMINDIRConvWeightSrcFormat}, | |||
| {converter::kFmkTypeTf, DecideTFConvWeightSrcFormat}, | |||
| {converter::kFmkTypeTflite, DecideTFLITEConvWeightSrcFormat}, | |||
| {converter::kFmkTypeCaffe, DecideCAFFEConvWeightSrcFormat}, | |||
| {converter::kFmkTypeOnnx, DecideONNXConvWeightSrcFormat}}; | |||
| auto iter = decide_functions.find(fmk_type_); | |||
| if (iter == decide_functions.end()) { | |||
| MS_LOG(ERROR) << "current fmk don't support, please check."; | |||
| @@ -259,7 +208,7 @@ STATUS UnifyFormatToNHWC::DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode, | |||
| } | |||
| auto decide_func = iter->second; | |||
| MS_ASSERT(decide_func != nullptr); | |||
| if (decide_func(cnode, quant_type_, src_format) != RET_OK) { | |||
| if (decide_func(cnode, src_format) != RET_OK) { | |||
| MS_LOG(ERROR) << "run decide function failed, cannot decide conv weight format."; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -24,9 +24,8 @@ namespace mindspore { | |||
| namespace lite { | |||
| class UnifyFormatToNHWC : public opt::ToFormatBase { | |||
| public: | |||
| explicit UnifyFormatToNHWC(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false, | |||
| schema::QuantType quant_type = schema::QuantType_QUANT_NONE) | |||
| : ToFormatBase(fmk_type, train_flag), quant_type_(quant_type) {} | |||
| explicit UnifyFormatToNHWC(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false) | |||
| : ToFormatBase(fmk_type, train_flag) {} | |||
| ~UnifyFormatToNHWC() override = default; | |||
| bool Run(const FuncGraphPtr &func_graph) override; | |||
| @@ -41,7 +40,6 @@ class UnifyFormatToNHWC : public opt::ToFormatBase { | |||
| bool DecideWhetherInferShapeForNewNode() override; | |||
| STATUS DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode, schema::Format *src_format, | |||
| schema::Format *dst_format) override; | |||
| schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||