From: @hangangqiang Reviewed-by: @zhang_xue_tong,@zhanghaibo5 Signed-off-by: @zhang_xue_tongpull/15102/MERGE
| @@ -121,8 +121,8 @@ int RunConverter(int argc, const char **argv) { | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| auto meta_graph = converter->Convert(flags); | |||
| NoSupportOp::GetInstance()->PrintOps(); | |||
| status = ReturnCode::GetSingleReturnCode()->GetReturnCode(); | |||
| NotSupportOp::GetInstance()->PrintOps(); | |||
| status = ReturnCode::GetSingleReturnCode()->status_code(); | |||
| if (meta_graph == nullptr) { | |||
| oss.clear(); | |||
| oss << "CONVERT RESULT FAILED:" << status << " " << GetErrorInfo(status); | |||
| @@ -28,67 +28,67 @@ namespace mindspore { | |||
| namespace lite { | |||
| class ReturnCode { | |||
| public: | |||
| ~ReturnCode() = default; | |||
| virtual ~ReturnCode() = default; | |||
| static ReturnCode *GetSingleReturnCode() { | |||
| static ReturnCode returnCode; | |||
| return &returnCode; | |||
| static ReturnCode return_code; | |||
| return &return_code; | |||
| } | |||
| void UpdateReturnCode(STATUS status) { | |||
| if (statusCode == RET_OK) { | |||
| statusCode = status; | |||
| if (status_code_ == RET_OK) { | |||
| status_code_ = status; | |||
| } | |||
| } | |||
| STATUS GetReturnCode() const { return statusCode; } | |||
| STATUS status_code() const { return status_code_; } | |||
| private: | |||
| ReturnCode() { statusCode = RET_OK; } | |||
| int statusCode; | |||
| ReturnCode() = default; | |||
| int status_code_ = RET_OK; | |||
| }; | |||
| class NoSupportOp { | |||
| class NotSupportOp { | |||
| public: | |||
| ~NoSupportOp() = default; | |||
| static NoSupportOp *GetInstance() { | |||
| static NoSupportOp noSupportOp; | |||
| return &noSupportOp; | |||
| virtual ~NotSupportOp() = default; | |||
| static NotSupportOp *GetInstance() { | |||
| static NotSupportOp not_support_op; | |||
| return ¬_support_op; | |||
| } | |||
| void SetFmkType(const std::string &fmk_type) { fmkType = fmk_type; } | |||
| void InsertOp(const std::string &op_name) { noSupportOps.insert(op_name); } | |||
| void set_fmk_type(const std::string &fmk_type) { fmk_type_ = fmk_type; } | |||
| void InsertOp(const std::string &op_name) { not_support_ops_.insert(op_name); } | |||
| void PrintOps() const { | |||
| if (!noSupportOps.empty()) { | |||
| if (!not_support_ops_.empty()) { | |||
| MS_LOG(ERROR) << "==========================================="; | |||
| MS_LOG(ERROR) << "UNSUPPORTED OP LIST:"; | |||
| for (auto &op_name : noSupportOps) { | |||
| MS_LOG(ERROR) << "FMKTYPE: " << fmkType << ", OP TYPE: " << op_name; | |||
| for (auto &op_name : not_support_ops_) { | |||
| MS_LOG(ERROR) << "FMKTYPE: " << fmk_type_ << ", OP TYPE: " << op_name; | |||
| } | |||
| MS_LOG(ERROR) << "==========================================="; | |||
| } | |||
| } | |||
| private: | |||
| NoSupportOp() { noSupportOps.clear(); } | |||
| std::set<std::string> noSupportOps; | |||
| std::string fmkType; | |||
| NotSupportOp() = default; | |||
| std::set<std::string> not_support_ops_; | |||
| std::string fmk_type_; | |||
| }; | |||
| class TensorDataType { | |||
| public: | |||
| ~TensorDataType() = default; | |||
| static TensorDataType *GetInstance() { | |||
| static TensorDataType tensorDataType; | |||
| return &tensorDataType; | |||
| static TensorDataType tensor_data_type; | |||
| return &tensor_data_type; | |||
| } | |||
| void UpdateTensorType(int32_t index, int32_t type) { tensorDataTypeMap[index] = type; } | |||
| void UpdateTensorType(int32_t index, int32_t type) { tensor_data_type_map_[index] = type; } | |||
| int32_t GetTensorType(int32_t index) const { | |||
| if (tensorDataTypeMap.find(index) == tensorDataTypeMap.end()) { | |||
| if (tensor_data_type_map_.find(index) == tensor_data_type_map_.end()) { | |||
| return TypeId::kTypeUnknown; | |||
| } | |||
| return tensorDataTypeMap.at(index); | |||
| return tensor_data_type_map_.at(index); | |||
| } | |||
| private: | |||
| TensorDataType() {} | |||
| std::map<int32_t, int32_t> tensorDataTypeMap; | |||
| std::map<int32_t, int32_t> tensor_data_type_map_; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -30,17 +30,17 @@ Flags::Flags() { | |||
| AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", ""); | |||
| AddFlag(&Flags::weightFile, "weightFile", "Input model weight file. Needed when fmk is CAFFE. CAFFE: *.caffemodel", | |||
| ""); | |||
| AddFlag(&Flags::inputDataTypeIn, "inputDataType", | |||
| AddFlag(&Flags::inputDataTypeStr, "inputDataType", | |||
| "Data type of input tensors, default is same with the type defined in model. FLOAT | INT8 | UINT8 | DEFAULT", | |||
| "DEFAULT"); | |||
| AddFlag(&Flags::outputDataTypeIn, "outputDataType", | |||
| AddFlag(&Flags::outputDataTypeStr, "outputDataType", | |||
| "Data type of output and output tensors, default is same with the type defined in model. FLOAT | INT8 | " | |||
| "UINT8 | DEFAULT", | |||
| "DEFAULT"); | |||
| AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. PostTraining | WeightQuant", ""); | |||
| AddFlag(&Flags::quantTypeStr, "quantType", "Quantization Type. PostTraining | WeightQuant", ""); | |||
| AddFlag(&Flags::bitNumIn, "bitNum", "Weight quantization bitNum", "8"); | |||
| AddFlag(&Flags::quantWeightSizeIn, "quantWeightSize", "Weight quantization size threshold", "0"); | |||
| AddFlag(&Flags::quantWeightChannelIn, "quantWeightChannel", "Channel threshold for weight quantization", "16"); | |||
| AddFlag(&Flags::quantWeightSizeStr, "quantWeightSize", "Weight quantization size threshold", "0"); | |||
| AddFlag(&Flags::quantWeightChannelStr, "quantWeightChannel", "Channel threshold for weight quantization", "16"); | |||
| AddFlag(&Flags::configFile, "configFile", "Configuration for post-training.", ""); | |||
| AddFlag(&Flags::trainModelIn, "trainModel", | |||
| "whether the model is going to be trained on device. " | |||
| @@ -49,32 +49,32 @@ Flags::Flags() { | |||
| } | |||
| int Flags::InitInputOutputDataType() { | |||
| if (this->inputDataTypeIn == "FLOAT") { | |||
| if (this->inputDataTypeStr == "FLOAT") { | |||
| this->inputDataType = TypeId::kNumberTypeFloat32; | |||
| } else if (this->inputDataTypeIn == "INT8") { | |||
| } else if (this->inputDataTypeStr == "INT8") { | |||
| this->inputDataType = TypeId::kNumberTypeInt8; | |||
| } else if (this->inputDataTypeIn == "UINT8") { | |||
| } else if (this->inputDataTypeStr == "UINT8") { | |||
| this->inputDataType = TypeId::kNumberTypeUInt8; | |||
| } else if (this->inputDataTypeIn == "DEFAULT") { | |||
| } else if (this->inputDataTypeStr == "DEFAULT") { | |||
| this->inputDataType = TypeId::kTypeUnknown; | |||
| } else { | |||
| std::cerr << "INPUT INVALID: inputDataType is invalid: %s, supported inputDataType: FLOAT | INT8 | UINT8 | DEFAULT", | |||
| this->inputDataTypeIn.c_str(); | |||
| this->inputDataTypeStr.c_str(); | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| if (this->outputDataTypeIn == "FLOAT") { | |||
| if (this->outputDataTypeStr == "FLOAT") { | |||
| this->outputDataType = TypeId::kNumberTypeFloat32; | |||
| } else if (this->outputDataTypeIn == "INT8") { | |||
| } else if (this->outputDataTypeStr == "INT8") { | |||
| this->outputDataType = TypeId::kNumberTypeInt8; | |||
| } else if (this->outputDataTypeIn == "UINT8") { | |||
| } else if (this->outputDataTypeStr == "UINT8") { | |||
| this->outputDataType = TypeId::kNumberTypeUInt8; | |||
| } else if (this->outputDataTypeIn == "DEFAULT") { | |||
| } else if (this->outputDataTypeStr == "DEFAULT") { | |||
| this->outputDataType = TypeId::kTypeUnknown; | |||
| } else { | |||
| std::cerr | |||
| << "INPUT INVALID: outputDataType is invalid: %s, supported outputDataType: FLOAT | INT8 | UINT8 | DEFAULT", | |||
| this->outputDataTypeIn.c_str(); | |||
| this->outputDataTypeStr.c_str(); | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| return RET_OK; | |||
| @@ -110,7 +110,7 @@ bool Flags::IsValidNum(const std::string &str, int *num) { | |||
| } | |||
| int Flags::QuantParamInputCheck() { | |||
| if (!Flags::IsValidNum(this->quantWeightChannelIn, &this->quantWeightChannel)) { | |||
| if (!Flags::IsValidNum(this->quantWeightChannelStr, &this->quantWeightChannel)) { | |||
| std::cerr << "quantWeightChannel should be a valid number."; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| @@ -118,7 +118,7 @@ int Flags::QuantParamInputCheck() { | |||
| std::cerr << "quantWeightChannel should be greater than or equal to zero."; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| if (!Flags::IsValidNum(this->quantWeightSizeIn, &this->quantWeightSize)) { | |||
| if (!Flags::IsValidNum(this->quantWeightSizeStr, &this->quantWeightSize)) { | |||
| std::cerr << "quantWeightSize should be a valid number."; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| @@ -138,11 +138,11 @@ int Flags::QuantParamInputCheck() { | |||
| } | |||
| int Flags::InitQuantParam() { | |||
| if (this->quantTypeIn == "WeightQuant") { | |||
| if (this->quantTypeStr == "WeightQuant") { | |||
| this->quantType = QuantType_WeightQuant; | |||
| } else if (this->quantTypeIn == "PostTraining") { | |||
| } else if (this->quantTypeStr == "PostTraining") { | |||
| this->quantType = QuantType_PostTraining; | |||
| } else if (this->quantTypeIn.empty()) { | |||
| } else if (this->quantTypeStr.empty()) { | |||
| this->quantType = QuantType_QUANT_NONE; | |||
| } else { | |||
| std::cerr << "INPUT ILLEGAL: quantType must be WeightQuant|PostTraining"; | |||
| @@ -65,25 +65,20 @@ class Flags : public virtual mindspore::lite::FlagParser { | |||
| std::string fmkIn; | |||
| FmkType fmk; | |||
| std::string weightFile; | |||
| std::string inputArrays; | |||
| std::string outputArrays; | |||
| std::string inputShapes; | |||
| // used for quantization | |||
| std::string quantTypeIn; | |||
| QuantType quantType; | |||
| std::string inferenceTypeIn; | |||
| std::string inputDataTypeIn; | |||
| std::string outputDataTypeIn; | |||
| // used for parse aware trainning | |||
| TypeId inputDataType; | |||
| TypeId outputDataType; | |||
| // used for quantization | |||
| std::string quantTypeStr; | |||
| QuantType quantType; | |||
| std::string inputDataTypeStr; | |||
| std::string outputDataTypeStr; | |||
| // used for post-trainning-weight | |||
| std::string quantWeightSizeIn; | |||
| std::string quantWeightSizeStr; | |||
| int quantWeightSize; | |||
| std::string bitNumIn; | |||
| int bitNum; | |||
| std::string configFile; | |||
| std::string quantWeightChannelIn; | |||
| std::string quantWeightChannelStr; | |||
| int quantWeightChannel; | |||
| std::string trainModelIn; | |||
| bool trainModel = false; | |||
| @@ -47,8 +47,8 @@ namespace mindspore::lite { | |||
| std::vector<schema::CNodeT *> GraphDefTransform::GetGraphNodes() { | |||
| std::vector<schema::CNodeT *> old_nodes{}; | |||
| old_nodes.resize(graphDefT->nodes.size()); | |||
| std::transform(graphDefT->nodes.begin(), graphDefT->nodes.end(), old_nodes.begin(), | |||
| old_nodes.resize(graph_defT_->nodes.size()); | |||
| std::transform(graph_defT_->nodes.begin(), graph_defT_->nodes.end(), old_nodes.begin(), | |||
| [](const std::unique_ptr<schema::CNodeT> &node) { return node.get(); }); | |||
| return old_nodes; | |||
| } | |||
| @@ -57,33 +57,33 @@ GraphDefTransform::GraphDefTransform() = default; | |||
| GraphDefTransform::~GraphDefTransform() = default; | |||
| void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _dstDef; } | |||
| void GraphDefTransform::SetGraphDef(schema::MetaGraphT *dst_def) { graph_defT_ = dst_def; } | |||
| int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| STATUS status; | |||
| { | |||
| auto old_nodes = GetGraphNodes(); | |||
| Optimizer unusedOpRemoveOptimizer; | |||
| Optimizer unused_op_remove_optimizer; | |||
| if (!ctx.trainModel) { | |||
| unusedOpRemoveOptimizer.AddPass(new DropoutNodeRemovePass()); | |||
| unused_op_remove_optimizer.AddPass(new DropoutNodeRemovePass()); | |||
| } | |||
| unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass()); | |||
| unusedOpRemoveOptimizer.AddPass(new SubgraphNodePass(old_nodes)); | |||
| status = unusedOpRemoveOptimizer.Run(graphDefT); | |||
| unused_op_remove_optimizer.AddPass(new IsolatedNodeRemovePass()); | |||
| unused_op_remove_optimizer.AddPass(new SubgraphNodePass(old_nodes)); | |||
| status = unused_op_remove_optimizer.Run(graph_defT_); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed"; | |||
| MS_LOG(ERROR) << "Run unused_op_remove_optimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| // generate and infer quant parameters | |||
| { | |||
| Optimizer inferQuantParamPass; | |||
| inferQuantParamPass.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| inferQuantParamPass.AddPass(new (std::nothrow) InferQuantParamPass()); | |||
| status = inferQuantParamPass.Run(graphDefT); | |||
| Optimizer infer_quant_param_pass; | |||
| infer_quant_param_pass.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| infer_quant_param_pass.AddPass(new (std::nothrow) InferQuantParamPass()); | |||
| status = infer_quant_param_pass.Run(graph_defT_); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; | |||
| MS_LOG(ERROR) << "Run infer_quant_param_pass graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| @@ -93,40 +93,40 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| // init old node indices | |||
| auto old_nodes = GetGraphNodes(); | |||
| Optimizer formatTransOptimizer; | |||
| auto formatTransPass = new (std::nothrow) FormatTransPass(); | |||
| if (formatTransPass == nullptr) { | |||
| Optimizer format_trans_optimizer; | |||
| auto format_trans_pass = new (std::nothrow) FormatTransPass(); | |||
| if (format_trans_pass == nullptr) { | |||
| MS_LOG(ERROR) << "new formatTransPass failed"; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| formatTransPass->SetQuantType(ctx.quantType); | |||
| formatTransPass->SetFmk(ctx.fmk); | |||
| formatTransOptimizer.AddPass(formatTransPass); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| format_trans_pass->set_quant_type(ctx.quantType); | |||
| format_trans_pass->set_fmk_type(ctx.fmk); | |||
| format_trans_optimizer.AddPass(format_trans_pass); | |||
| format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| format_trans_optimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| if (ctx.fmk != converter::FmkType_TF) { | |||
| formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass()); | |||
| format_trans_optimizer.AddPass(new (std::nothrow) InferShapePass()); | |||
| } | |||
| status = formatTransOptimizer.Run(graphDefT); | |||
| status = format_trans_optimizer.Run(graph_defT_); | |||
| if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { | |||
| MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; | |||
| MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| { | |||
| // init old node indices | |||
| auto old_nodes = GetGraphNodes(); | |||
| Optimizer formatTransOptimizer; | |||
| formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) TransOpRemovePass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| status = formatTransOptimizer.Run(graphDefT); | |||
| Optimizer format_trans_optimizer; | |||
| format_trans_optimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); | |||
| format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| format_trans_optimizer.AddPass(new (std::nothrow) TransOpRemovePass()); | |||
| format_trans_optimizer.AddPass(new (std::nothrow) TransOpInsertPass()); | |||
| format_trans_optimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); | |||
| format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| status = format_trans_optimizer.Run(graph_defT_); | |||
| if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { | |||
| MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; | |||
| MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| @@ -134,15 +134,15 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| { | |||
| // init old node indices | |||
| auto old_nodes = GetGraphNodes(); | |||
| Optimizer formatTransOptimizer; | |||
| Optimizer format_trans_optimizer; | |||
| if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) { | |||
| formatTransOptimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| format_trans_optimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass()); | |||
| format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| } | |||
| status = formatTransOptimizer.Run(graphDefT); | |||
| status = format_trans_optimizer.Run(graph_defT_); | |||
| if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { | |||
| MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; | |||
| MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| @@ -151,7 +151,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| { | |||
| // init old node indices | |||
| auto old_nodes = GetGraphNodes(); | |||
| Optimizer fusionOptimizer; | |||
| Optimizer replace_optimizer; | |||
| if (!ctx.trainModel) { | |||
| auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass(); | |||
| if (batch_norm_scale_pass == nullptr) { | |||
| @@ -159,13 +159,13 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| return RET_ERROR; | |||
| } | |||
| batch_norm_scale_pass->SetFmk(ctx.fmk); | |||
| fusionOptimizer.AddPass(batch_norm_scale_pass); | |||
| replace_optimizer.AddPass(batch_norm_scale_pass); | |||
| } | |||
| fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| fusionOptimizer.AddPass(new SubgraphNodePass(old_nodes)); | |||
| status = fusionOptimizer.Run(graphDefT); | |||
| replace_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| replace_optimizer.AddPass(new SubgraphNodePass(old_nodes)); | |||
| status = replace_optimizer.Run(graph_defT_); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed"; | |||
| MS_LOG(ERROR) << "Run replace_optimizer BatchNormConvertScalePass Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| @@ -173,13 +173,13 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| { | |||
| // init old node indices | |||
| auto old_nodes = GetGraphNodes(); | |||
| Optimizer fusionOptimizer; | |||
| fusionOptimizer.AddPass(new (std::nothrow) MulAddFusionPass()); | |||
| fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| fusionOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| status = fusionOptimizer.Run(graphDefT); | |||
| Optimizer fusion_optimizer; | |||
| fusion_optimizer.AddPass(new (std::nothrow) MulAddFusionPass()); | |||
| fusion_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| fusion_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| status = fusion_optimizer.Run(graph_defT_); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed"; | |||
| MS_LOG(ERROR) << "Run fusion_optimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| @@ -188,12 +188,12 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| if (ctx.fmk != converter::FmkType_TF) { | |||
| // init old node indices | |||
| auto old_nodes = GetGraphNodes(); | |||
| Optimizer tensorQuantOptimizer; | |||
| tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| tensorQuantOptimizer.AddPass(new (std::nothrow) InferShapePass()); | |||
| tensorQuantOptimizer.AddPass(new (std::nothrow) TensorQuantPass()); | |||
| tensorQuantOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| status = tensorQuantOptimizer.Run(graphDefT); | |||
| Optimizer tensor_quant_optimizer; | |||
| tensor_quant_optimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| tensor_quant_optimizer.AddPass(new (std::nothrow) InferShapePass()); | |||
| tensor_quant_optimizer.AddPass(new (std::nothrow) TensorQuantPass()); | |||
| tensor_quant_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| status = tensor_quant_optimizer.Run(graph_defT_); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoQuantize failed!"; | |||
| return status; | |||
| @@ -204,31 +204,31 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| if (ctx.fmk != converter::FmkType_TF) { | |||
| // init old node indices | |||
| auto old_nodes = GetGraphNodes(); | |||
| Optimizer quantNodeOptimizer; | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass()); | |||
| status = quantNodeOptimizer.Run(graphDefT); | |||
| Optimizer quant_node_optimizer; | |||
| quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| quant_node_optimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| quant_node_optimizer.AddPass(new (std::nothrow) InferShapePass()); | |||
| status = quant_node_optimizer.Run(graph_defT_); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed"; | |||
| MS_LOG(ERROR) << "Run quant_node_optimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| auto old_nodes2 = GetGraphNodes(); | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) InferQuantParamPass()); | |||
| auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); | |||
| if (dTypeTransPass == nullptr) { | |||
| MS_LOG(ERROR) << "new dTypeTransPass failed"; | |||
| quant_node_optimizer.AddPass(new (std::nothrow) InferQuantParamPass()); | |||
| auto dtype_trans_pass = new (std::nothrow) DTypeTransPass(); | |||
| if (dtype_trans_pass == nullptr) { | |||
| MS_LOG(ERROR) << "new dtype_trans_pass failed"; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| dTypeTransPass->SetInputDataDType(ctx.inputDataType); | |||
| dTypeTransPass->SetOutputDataDType(ctx.outputDataType); | |||
| quantNodeOptimizer.AddPass(dTypeTransPass); | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2)); | |||
| status = quantNodeOptimizer.Run(graphDefT); | |||
| dtype_trans_pass->set_input_data_dtype(ctx.inputDataType); | |||
| dtype_trans_pass->set_output_data_dtype(ctx.outputDataType); | |||
| quant_node_optimizer.AddPass(dtype_trans_pass); | |||
| quant_node_optimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); | |||
| quant_node_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2)); | |||
| status = quant_node_optimizer.Run(graph_defT_); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed"; | |||
| MS_LOG(ERROR) << "Run quant_node_optimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| @@ -237,22 +237,22 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| { | |||
| // init old node indices | |||
| auto old_nodes = GetGraphNodes(); | |||
| Optimizer switchOptimizer; | |||
| switchOptimizer.AddPass(new (std::nothrow) SwitchPass()); | |||
| switchOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| switchOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| status = switchOptimizer.Run(graphDefT); | |||
| Optimizer switch_optimizer; | |||
| switch_optimizer.AddPass(new (std::nothrow) SwitchPass()); | |||
| switch_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| switch_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| status = switch_optimizer.Run(graph_defT_); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run switch graphPasses Failed"; | |||
| MS_LOG(ERROR) << "Run switch_optimizer Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| // subgraph tensor pass | |||
| { | |||
| Optimizer subgraphTensorOptimizer; | |||
| subgraphTensorOptimizer.AddPass(new (std::nothrow) SubgraphTensorPass()); | |||
| status = subgraphTensorOptimizer.Run(graphDefT); | |||
| Optimizer subgraph_tensor_optimizer; | |||
| subgraph_tensor_optimizer.AddPass(new (std::nothrow) SubgraphTensorPass()); | |||
| status = subgraph_tensor_optimizer.Run(graph_defT_); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run subgraph tensor pass Failed"; | |||
| return status; | |||
| @@ -263,33 +263,33 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| { | |||
| // init old node indices | |||
| auto old_nodes = GetGraphNodes(); | |||
| Optimizer nameOptimizer; | |||
| nameOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| nameOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| nameOptimizer.AddPass(new (std::nothrow) TensorNamePass()); | |||
| status = nameOptimizer.Run(graphDefT); | |||
| Optimizer name_optimizer; | |||
| name_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||
| name_optimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| name_optimizer.AddPass(new (std::nothrow) TensorNamePass()); | |||
| status = name_optimizer.Run(graph_defT_); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run nameOptimizer graphPasses Failed"; | |||
| MS_LOG(ERROR) << "Run name_optimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| { | |||
| Optimizer nestedLoopOptimizer; | |||
| nestedLoopOptimizer.AddPass(new (std::nothrow) NestedLoopExpandPass()); | |||
| status = nestedLoopOptimizer.Run(graphDefT); | |||
| Optimizer nested_loop_optimizer; | |||
| nested_loop_optimizer.AddPass(new (std::nothrow) NestedLoopExpandPass()); | |||
| status = nested_loop_optimizer.Run(graph_defT_); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run nestedLoopOptimizer graphPasses Failed"; | |||
| MS_LOG(ERROR) << "Run nested_loop_optimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| { | |||
| Optimizer quantNodeOptimizer; | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass()); | |||
| status = quantNodeOptimizer.Run(graphDefT); | |||
| Optimizer quant_param_optimizer; | |||
| quant_param_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass()); | |||
| status = quant_param_optimizer.Run(graph_defT_); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed"; | |||
| MS_LOG(ERROR) << "Run quant_param_optimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| @@ -36,12 +36,12 @@ class GraphDefTransform { | |||
| GraphDefTransform(); | |||
| virtual ~GraphDefTransform(); | |||
| virtual int Transform(const converter::Flags &ctx); | |||
| void SetGraphDef(schema::MetaGraphT *dstDef); | |||
| inline schema::MetaGraphT *GetOutput() { return graphDefT; } | |||
| void SetGraphDef(schema::MetaGraphT *dst_def); | |||
| inline schema::MetaGraphT *GetOutput() { return graph_defT_; } | |||
| protected: | |||
| std::vector<schema::CNodeT *> GetGraphNodes(); | |||
| schema::MetaGraphT *graphDefT = nullptr; | |||
| schema::MetaGraphT *graph_defT_ = nullptr; | |||
| Optimizer *optimizer = nullptr; | |||
| }; | |||
| } // namespace lite | |||
| @@ -55,34 +55,35 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) { | |||
| STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| auto &graphInIdxes = graph->inputIndex; | |||
| if (this->inputDataDType != TypeId::kNumberTypeFloat32 && this->inputDataDType != TypeId::kNumberTypeUInt8 && | |||
| this->inputDataDType != TypeId::kNumberTypeInt8 && this->inputDataDType != TypeId::kTypeUnknown) { | |||
| MS_LOG(ERROR) << "Invalid inputDataType: " << this->inputDataDType; | |||
| auto &graph_in_idxes = graph->inputIndex; | |||
| if (this->input_data_dtype != TypeId::kNumberTypeFloat32 && this->input_data_dtype != TypeId::kNumberTypeUInt8 && | |||
| this->input_data_dtype != TypeId::kNumberTypeInt8 && this->input_data_dtype != TypeId::kTypeUnknown) { | |||
| MS_LOG(ERROR) << "Invalid inputDataType: " << this->input_data_dtype; | |||
| return RET_ERROR; | |||
| } | |||
| for (auto graphInIdx : graphInIdxes) { | |||
| MS_ASSERT(graphInIdx < graph->allTensors.size()); | |||
| auto &tensor = graph->allTensors.at(graphInIdx); | |||
| for (auto graph_in_idx : graph_in_idxes) { | |||
| MS_ASSERT(graph_in_idx < graph->allTensors.size()); | |||
| auto &tensor = graph->allTensors.at(graph_in_idx); | |||
| if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { | |||
| continue; | |||
| } | |||
| int32_t tensorDataType = this->inputDataDType != TypeId::kTypeUnknown | |||
| ? this->inputDataDType | |||
| : TensorDataType::GetInstance()->GetTensorType(graphInIdx); | |||
| int32_t tensor_data_type = this->input_data_dtype != TypeId::kTypeUnknown | |||
| ? this->input_data_dtype | |||
| : TensorDataType::GetInstance()->GetTensorType(graph_in_idx); | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| auto nodeName = (*iter)->name; | |||
| for (size_t inputIndexIdx = 0; inputIndexIdx < (*iter)->inputIndex.size(); inputIndexIdx++) { | |||
| if ((*iter)->inputIndex.at(inputIndexIdx) == graphInIdx) { | |||
| auto node_name = (*iter)->name; | |||
| for (size_t input_indexidx = 0; input_indexidx < (*iter)->inputIndex.size(); input_indexidx++) { | |||
| if ((*iter)->inputIndex.at(input_indexidx) == graph_in_idx) { | |||
| STATUS status = RET_OK; | |||
| // insert dtype cast node between input tensor and input node | |||
| if (tensorDataType != tensor->dataType && tensorDataType != kTypeUnknown) { | |||
| iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, tensorDataType, tensor->dataType, &status); | |||
| if (tensor_data_type != tensor->dataType && tensor_data_type != kTypeUnknown) { | |||
| iter = | |||
| InsertDTypeTransNode(graph, iter, kBefore, input_indexidx, tensor_data_type, tensor->dataType, &status); | |||
| } | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertDTypeTransNode before " << nodeName.c_str() << " failed"; | |||
| MS_LOG(ERROR) << "InsertDTypeTransNode before " << node_name.c_str() << " failed"; | |||
| return status; | |||
| } | |||
| } | |||
| @@ -94,33 +95,34 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { | |||
| STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| if (this->outputDataDType != TypeId::kNumberTypeFloat32 && this->outputDataDType != TypeId::kNumberTypeUInt8 && | |||
| this->outputDataDType != TypeId::kNumberTypeInt8 && this->outputDataDType != TypeId::kTypeUnknown) { | |||
| MS_LOG(ERROR) << "Invalid outputDataType: " << this->outputDataDType; | |||
| if (this->output_data_dtype != TypeId::kNumberTypeFloat32 && this->output_data_dtype != TypeId::kNumberTypeUInt8 && | |||
| this->output_data_dtype != TypeId::kNumberTypeInt8 && this->output_data_dtype != TypeId::kTypeUnknown) { | |||
| MS_LOG(ERROR) << "Invalid outputDataType: " << this->output_data_dtype; | |||
| return RET_ERROR; | |||
| } | |||
| auto &graphOutIdxes = graph->outputIndex; | |||
| for (auto graphOutIdx : graphOutIdxes) { | |||
| MS_ASSERT(graphOutIdx < graph->allTensors.size()); | |||
| auto &tensor = graph->allTensors.at(graphOutIdx); | |||
| auto &graph_out_idxes = graph->outputIndex; | |||
| for (auto graph_out_idx : graph_out_idxes) { | |||
| MS_ASSERT(graph_out_idx < graph->allTensors.size()); | |||
| auto &tensor = graph->allTensors.at(graph_out_idx); | |||
| if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { | |||
| continue; | |||
| } | |||
| int32_t tensorDataType = this->outputDataDType != TypeId::kTypeUnknown | |||
| ? this->outputDataDType | |||
| : TensorDataType::GetInstance()->GetTensorType(graphOutIdx); | |||
| int32_t tensor_data_type = this->output_data_dtype != TypeId::kTypeUnknown | |||
| ? this->output_data_dtype | |||
| : TensorDataType::GetInstance()->GetTensorType(graph_out_idx); | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| auto nodeName = (*iter)->name; | |||
| auto node_name = (*iter)->name; | |||
| MS_ASSERT(node != nullptr); | |||
| for (size_t outputIndexIdx = 0; outputIndexIdx < (*iter)->outputIndex.size(); outputIndexIdx++) { | |||
| if ((*iter)->outputIndex.at(outputIndexIdx) == graphOutIdx) { | |||
| if ((*iter)->outputIndex.at(outputIndexIdx) == graph_out_idx) { | |||
| // insert transNode | |||
| STATUS status = RET_OK; | |||
| if (tensorDataType != tensor->dataType && tensorDataType != kTypeUnknown) { | |||
| iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, tensor->dataType, tensorDataType, &status); | |||
| if (tensor_data_type != tensor->dataType && tensor_data_type != kTypeUnknown) { | |||
| iter = | |||
| InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, tensor->dataType, tensor_data_type, &status); | |||
| } | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertDTypeTransNode after " << nodeName.c_str() << " failed"; | |||
| MS_LOG(ERROR) << "InsertDTypeTransNode after " << node_name.c_str() << " failed"; | |||
| return status; | |||
| } | |||
| break; | |||
| @@ -231,52 +233,53 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { | |||
| return RET_OK; | |||
| } | |||
| NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, | |||
| size_t inoutIdx, int32_t inputDataType, int32_t outputDataType, | |||
| STATUS *errorCode) { | |||
| MS_ASSERT((*existNodeIter) != nullptr); | |||
| auto existNodeName = (*existNodeIter)->name; | |||
| std::string tileName; | |||
| NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter exist_node_iter, InsertPlace place, | |||
| size_t inout_idx, int32_t input_data_type, int32_t output_data_type, | |||
| STATUS *error_code) { | |||
| MS_ASSERT((*exist_node_iter) != nullptr); | |||
| auto exist_node_name = (*exist_node_iter)->name; | |||
| std::string tile_name; | |||
| if (place == kBefore) { | |||
| tileName = existNodeName + "_pre"; | |||
| tile_name = exist_node_name + "_pre"; | |||
| } else { | |||
| tileName = existNodeName + "_post"; | |||
| tile_name = exist_node_name + "_post"; | |||
| } | |||
| auto transNode = std::unique_ptr<CNodeT>(new (std::nothrow) CNodeT); | |||
| if (transNode == nullptr) { | |||
| auto trans_node = std::unique_ptr<CNodeT>(new (std::nothrow) CNodeT); | |||
| if (trans_node == nullptr) { | |||
| MS_LOG(ERROR) << "new TransNode failed"; | |||
| *errorCode = RET_ERROR; | |||
| *error_code = RET_ERROR; | |||
| return graph->nodes.end(); | |||
| } | |||
| auto quantDTypeCastParam = new (std::nothrow) QuantDTypeCastT; | |||
| if (quantDTypeCastParam == nullptr) { | |||
| auto quant_dtype_cast_param = new (std::nothrow) QuantDTypeCastT; | |||
| if (quant_dtype_cast_param == nullptr) { | |||
| MS_LOG(ERROR) << "new quantDTypeCastParam failed"; | |||
| *errorCode = RET_ERROR; | |||
| *error_code = RET_ERROR; | |||
| return graph->nodes.end(); | |||
| } | |||
| transNode->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| transNode->primitive->value.value = quantDTypeCastParam; | |||
| transNode->primitive->value.type = PrimitiveType_QuantDTypeCast; | |||
| transNode->quantType = QuantType_AwareTraining; | |||
| quantDTypeCastParam->src_t = inputDataType; | |||
| quantDTypeCastParam->dst_t = outputDataType; | |||
| if (inputDataType == TypeId::kNumberTypeInt8 && outputDataType == TypeId::kNumberTypeFloat32) { | |||
| transNode->name = "int8toft32_" + tileName + std::to_string(id++); | |||
| } else if (inputDataType == TypeId::kNumberTypeFloat32 && outputDataType == TypeId::kNumberTypeInt8) { | |||
| transNode->name = "ft32toint8_" + tileName + std::to_string(id++); | |||
| } else if (inputDataType == TypeId::kNumberTypeUInt8 && outputDataType == TypeId::kNumberTypeInt8) { | |||
| transNode->name = "uint8toint8_" + tileName + std::to_string(id++); | |||
| } else if (inputDataType == TypeId::kNumberTypeInt8 && outputDataType == TypeId::kNumberTypeUInt8) { | |||
| transNode->name = "int8touint8_" + tileName + std::to_string(id++); | |||
| trans_node->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| trans_node->primitive->value.value = quant_dtype_cast_param; | |||
| trans_node->primitive->value.type = PrimitiveType_QuantDTypeCast; | |||
| trans_node->quantType = QuantType_AwareTraining; | |||
| quant_dtype_cast_param->src_t = input_data_type; | |||
| quant_dtype_cast_param->dst_t = output_data_type; | |||
| if (input_data_type == TypeId::kNumberTypeInt8 && output_data_type == TypeId::kNumberTypeFloat32) { | |||
| trans_node->name = "int8toft32_" + tile_name + std::to_string(id_++); | |||
| } else if (input_data_type == TypeId::kNumberTypeFloat32 && output_data_type == TypeId::kNumberTypeInt8) { | |||
| trans_node->name = "ft32toint8_" + tile_name + std::to_string(id_++); | |||
| } else if (input_data_type == TypeId::kNumberTypeUInt8 && output_data_type == TypeId::kNumberTypeInt8) { | |||
| trans_node->name = "uint8toint8_" + tile_name + std::to_string(id_++); | |||
| } else if (input_data_type == TypeId::kNumberTypeInt8 && output_data_type == TypeId::kNumberTypeUInt8) { | |||
| trans_node->name = "int8touint8_" + tile_name + std::to_string(id_++); | |||
| } | |||
| transNode->primitive->value.value = quantDTypeCastParam; | |||
| trans_node->primitive->value.value = quant_dtype_cast_param; | |||
| int insert_num = 0; | |||
| return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode, &insert_num, castOpCopyer); | |||
| return InsertNode(graph, exist_node_iter, place, inout_idx, std::move(trans_node), error_code, &insert_num, | |||
| castOpCopyer); | |||
| } | |||
| void DTypeTransPass::SetInputDataDType(TypeId dataType) { this->inputDataDType = dataType; } | |||
| void DTypeTransPass::set_input_data_dtype(TypeId data_type) { this->input_data_dtype = data_type; } | |||
| void DTypeTransPass::SetOutputDataDType(TypeId dataType) { this->outputDataDType = dataType; } | |||
| void DTypeTransPass::set_output_data_dtype(TypeId data_type) { this->output_data_dtype = data_type; } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -30,15 +30,15 @@ enum DTypeTransNodeType { kInt8ToFP32, kFP32ToInt8, kUInt8ToInt8, kInt8ToUInt8 } | |||
| class DTypeTransPass : public GraphPass { | |||
| public: | |||
| DTypeTransPass() : id(0) {} | |||
| DTypeTransPass() : id_(0) {} | |||
| ~DTypeTransPass() override = default; | |||
| STATUS Run(schema::MetaGraphT *graph) override; | |||
| void SetInputDataDType(TypeId dataType); | |||
| void set_input_data_dtype(TypeId data_type); | |||
| void SetOutputDataDType(TypeId dataType); | |||
| void set_output_data_dtype(TypeId dataType); | |||
| private: | |||
| STATUS DoModelInputDTypeTrans(schema::MetaGraphT *graph); | |||
| @@ -51,13 +51,14 @@ class DTypeTransPass : public GraphPass { | |||
| STATUS InsetDTypeTransNodeForUnsupportedInt8Op(schema::MetaGraphT *graph, NodeIter *iter); | |||
| NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, | |||
| int32_t inputDataType, int32_t outputDataType, STATUS *errorCode); | |||
| NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter exist_node_iter, InsertPlace place, | |||
| size_t inout_idx, int32_t input_data_type, int32_t output_data_type, | |||
| STATUS *error_code); | |||
| private: | |||
| size_t id; | |||
| TypeId inputDataDType = TypeId::kNumberTypeFloat; | |||
| TypeId outputDataDType = TypeId::kNumberTypeFloat; | |||
| size_t id_; | |||
| TypeId input_data_dtype = TypeId::kNumberTypeFloat; | |||
| TypeId output_data_dtype = TypeId::kNumberTypeFloat; | |||
| OpDefCopyer castOpCopyer = [](schema::CNodeT *inCNode) -> std::unique_ptr<schema::CNodeT> { | |||
| std::unique_ptr<schema::CNodeT> newCNode(new (std::nothrow) schema::CNodeT); | |||
| @@ -45,32 +45,32 @@ STATUS FormatTransPass::Run(schema::MetaGraphT *graph) { | |||
| return RET_OK; | |||
| } | |||
| STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *beforeNodeType, | |||
| FormatTransNodeType *afterNodeType) { | |||
| STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *before_node_type, | |||
| FormatTransNodeType *after_node_type) { | |||
| if (fmk_type_ == converter::FmkType_TFLITE) { // inference by nhwc | |||
| if (!IsContain(GetNchwOpList(), GetCNodeTType(node))) { | |||
| return RET_NO_CHANGE; | |||
| } | |||
| *beforeNodeType = kNHWC2NCHW; | |||
| *afterNodeType = kNCHW2NHWC; | |||
| *before_node_type = kNHWC2NCHW; | |||
| *after_node_type = kNCHW2NHWC; | |||
| return RET_OK; | |||
| } else if (fmk_type_ == converter::FmkType_CAFFE || fmk_type_ == converter::FmkType_MS || | |||
| fmk_type_ == converter::FmkType_ONNX) { | |||
| if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) { | |||
| return RET_NO_CHANGE; | |||
| } | |||
| *beforeNodeType = kNCHW2NHWC; | |||
| *afterNodeType = kNHWC2NCHW; | |||
| *before_node_type = kNCHW2NHWC; | |||
| *after_node_type = kNHWC2NCHW; | |||
| return RET_OK; | |||
| } else if (fmk_type_ == converter::FmkType_TF) { | |||
| if (IsContain(GetNhwcOpList(), GetCNodeTType(node)) && GetFormat(node) == schema::Format_NCHW) { | |||
| *beforeNodeType = kNCHW2NHWC; | |||
| *afterNodeType = kNHWC2NCHW; | |||
| *before_node_type = kNCHW2NHWC; | |||
| *after_node_type = kNHWC2NCHW; | |||
| return RET_OK; | |||
| } | |||
| if (IsContain(GetNchwOpList(), GetCNodeTType(node))) { | |||
| *beforeNodeType = kNHWC2NCHW; | |||
| *afterNodeType = kNCHW2NHWC; | |||
| *before_node_type = kNHWC2NCHW; | |||
| *after_node_type = kNCHW2NHWC; | |||
| return RET_OK; | |||
| } | |||
| return RET_NO_CHANGE; | |||
| @@ -96,36 +96,34 @@ STATUS FormatTransPass::DoModelInputFormatTrans(schema::MetaGraphT *graph) { | |||
| return RET_OK; | |||
| } | |||
| } | |||
| auto graphInputIdxes = graph->inputIndex; | |||
| for (size_t i = 0; i < graphInputIdxes.size(); i++) { | |||
| auto graph_input_idxes = graph->inputIndex; | |||
| for (size_t i = 0; i < graph_input_idxes.size(); i++) { | |||
| bool transed = false; | |||
| auto inputIdx = graphInputIdxes.at(i); | |||
| MS_ASSERT(inputIdx < subGraph->allTensors.size()); | |||
| auto &tensor = graph->allTensors.at(inputIdx); | |||
| auto input_idx = graph_input_idxes.at(i); | |||
| auto &tensor = graph->allTensors.at(input_idx); | |||
| if (tensor->dims.size() != kNCHWDimNumber) { | |||
| continue; | |||
| } | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| for (size_t inputIndexIdx = 0; inputIndexIdx < (*iter)->inputIndex.size(); inputIndexIdx++) { | |||
| if ((*iter)->inputIndex.at(inputIndexIdx) == inputIdx) { | |||
| for (size_t input_index_idx = 0; input_index_idx < (*iter)->inputIndex.size(); input_index_idx++) { | |||
| if ((*iter)->inputIndex.at(input_index_idx) == input_idx) { | |||
| STATUS status = RET_OK; | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, inputIndexIdx, kNHWC2NCHW, &status); | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, input_index_idx, kNHWC2NCHW, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertNhwc2NchwNode before " << (*iter)->name << " failed"; | |||
| return status; | |||
| } | |||
| // set first tensor format to nhwc | |||
| auto &transNode = *(iter - 1); | |||
| MS_ASSERT(transNode != nullptr); | |||
| MS_ASSERT(transNode->inputIndex.size() == 1); | |||
| MS_ASSERT(subGraph->allTensors.size() > transNode->inputIndex.front()); | |||
| auto &graphInTensor = graph->allTensors.at(transNode->inputIndex.front()); | |||
| graphInTensor->format = schema::Format::Format_NHWC; | |||
| auto &trans_node = *(iter - 1); | |||
| MS_ASSERT(trans_node != nullptr); | |||
| MS_ASSERT(trans_node->inputIndex.size() == 1); | |||
| auto &graph_in_tensor = graph->allTensors.at(trans_node->inputIndex.front()); | |||
| graph_in_tensor->format = schema::Format::Format_NHWC; | |||
| // assume parser not reformat shape | |||
| auto oldDims = graphInTensor->dims; | |||
| auto old_dims = graph_in_tensor->dims; | |||
| if (!transed) { | |||
| graphInTensor->dims = {oldDims[NCHW_N], oldDims[NCHW_H], oldDims[NCHW_W], oldDims[NCHW_C]}; | |||
| graph_in_tensor->dims = {old_dims[NCHW_N], old_dims[NCHW_H], old_dims[NCHW_W], old_dims[NCHW_C]}; | |||
| transed = true; | |||
| } | |||
| } | |||
| @@ -143,10 +141,10 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| // insert before and after the op cal by nchw/nc4hw4 | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| FormatTransNodeType beforeNodeType = kNCHW2NHWC; | |||
| FormatTransNodeType afterNodeType = kNHWC2NCHW; | |||
| FormatTransNodeType before_node_type = kNCHW2NHWC; | |||
| FormatTransNodeType after_node_type = kNHWC2NCHW; | |||
| STATUS status = RET_OK; | |||
| status = GetInsertFormatTrans(**iter, &beforeNodeType, &afterNodeType); | |||
| status = GetInsertFormatTrans(**iter, &before_node_type, &after_node_type); | |||
| if (status == RET_NO_CHANGE) { | |||
| continue; | |||
| } | |||
| @@ -170,17 +168,17 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { | |||
| if (node->primitive->value.type == schema::PrimitiveType_DepthToSpace) { | |||
| reinterpret_cast<schema::DepthToSpaceT *>(attr)->format = schema::Format_NHWC; | |||
| } | |||
| auto specInsertIndexes = GetExtNhwcIndexes(); | |||
| auto opType = GetCNodeTType(**iter); | |||
| if (specInsertIndexes.find(opType) != specInsertIndexes.end()) { | |||
| for (auto insert_index : specInsertIndexes[opType]) { | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, insert_index, beforeNodeType, &status); | |||
| auto spec_insert_indexes = GetExtNhwcIndexes(); | |||
| auto op_type = GetCNodeTType(**iter); | |||
| if (spec_insert_indexes.find(op_type) != spec_insert_indexes.end()) { | |||
| for (auto insert_index : spec_insert_indexes[op_type]) { | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, insert_index, before_node_type, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } else if (IsContain(GetNhwcAllInputOpList(), opType)) { | |||
| } else if (IsContain(GetNhwcAllInputOpList(), op_type)) { | |||
| auto input_size = node->inputIndex.size(); | |||
| if (GetCNodeTType(**iter) == schema::PrimitiveType_ResizeGrad) { | |||
| if ((**iter).primitive->value.AsResizeGrad()->method == schema::ResizeMethod_NEAREST) { | |||
| @@ -188,16 +186,16 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { | |||
| } | |||
| } | |||
| for (size_t i = 0; i < input_size; i++) { | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, i, beforeNodeType, &status); | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, i, before_node_type, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } else { | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, 0, beforeNodeType, &status); | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, 0, before_node_type, &status); | |||
| } | |||
| iter = InsertFormatTransNode(graph, iter, kAfter, 0, afterNodeType, &status); | |||
| iter = InsertFormatTransNode(graph, iter, kAfter, 0, after_node_type, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed"; | |||
| return RET_ERROR; | |||
| @@ -206,29 +204,29 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { | |||
| return RET_OK; | |||
| } | |||
| NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, | |||
| size_t inoutIdx, FormatTransNodeType nodeType, STATUS *errorCode) { | |||
| MS_ASSERT((*existNodeIter) != nullptr); | |||
| NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter exist_node_iter, InsertPlace place, | |||
| size_t inout_idx, FormatTransNodeType node_type, STATUS *error_code) { | |||
| MS_ASSERT((*exist_node_iter) != nullptr); | |||
| MS_ASSERT(graph != nullptr); | |||
| auto existNodeName = (*existNodeIter)->name; | |||
| std::string tileName; | |||
| auto exist_node_name = (*exist_node_iter)->name; | |||
| std::string tile_name; | |||
| if (place == kBefore) { | |||
| tileName = existNodeName + "_pre"; | |||
| tile_name = exist_node_name + "_pre"; | |||
| } else { | |||
| tileName = existNodeName + "_post"; | |||
| tile_name = exist_node_name + "_post"; | |||
| } | |||
| auto transNode = std::make_unique<schema::CNodeT>(); | |||
| transNode->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| transNode->primitive->value.type = schema::PrimitiveType_Transpose; | |||
| auto trans_node = std::make_unique<schema::CNodeT>(); | |||
| trans_node->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| trans_node->primitive->value.type = schema::PrimitiveType_Transpose; | |||
| auto perm_tensor = std::make_unique<schema::TensorT>(); | |||
| perm_tensor->dataType = kNumberTypeInt32; | |||
| perm_tensor->dims = {4}; | |||
| std::vector<int> perm; | |||
| if (nodeType == kNCHW2NHWC) { | |||
| transNode->name = "nchw2nhwc_" + tileName + std::to_string(id_++); | |||
| if (node_type == kNCHW2NHWC) { | |||
| trans_node->name = "nchw2nhwc_" + tile_name + std::to_string(id_++); | |||
| perm = {0, 2, 3, 1}; | |||
| } else { | |||
| transNode->name = "nhwc2nchw_" + tileName + std::to_string(id_++); | |||
| trans_node->name = "nhwc2nchw_" + tile_name + std::to_string(id_++); | |||
| perm = {0, 3, 1, 2}; | |||
| } | |||
| size_t bytes = perm.size() * sizeof(int); | |||
| @@ -236,27 +234,27 @@ NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeI | |||
| if (memcpy_s(perm_tensor->data.data(), bytes, perm.data(), bytes) != EOK) { | |||
| MS_LOG(ERROR) << "memcpy data failed."; | |||
| } | |||
| perm_tensor->name = transNode->name + "_perm"; | |||
| perm_tensor->name = trans_node->name + "_perm"; | |||
| OpDefCopyer TransposeOpCopyer = [](CNodeT *inOpDef) -> std::unique_ptr<CNodeT> { | |||
| auto newOpDef = std::make_unique<schema::CNodeT>(); | |||
| if (newOpDef == nullptr) { | |||
| OpDefCopyer transpose_op_copyer = [](CNodeT *in_op_def) -> std::unique_ptr<CNodeT> { | |||
| auto new_op_def = std::make_unique<schema::CNodeT>(); | |||
| if (new_op_def == nullptr) { | |||
| MS_LOG(ERROR) << "new CNodeT failed"; | |||
| return nullptr; | |||
| } | |||
| newOpDef->name = inOpDef->name; | |||
| newOpDef->quantType = inOpDef->quantType; | |||
| newOpDef->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (newOpDef->primitive == nullptr) { | |||
| new_op_def->name = in_op_def->name; | |||
| new_op_def->quantType = in_op_def->quantType; | |||
| new_op_def->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (new_op_def->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new PrimitiveT failed"; | |||
| return nullptr; | |||
| } | |||
| newOpDef->primitive->value.type = schema::PrimitiveType_Transpose; | |||
| return newOpDef; | |||
| new_op_def->primitive->value.type = schema::PrimitiveType_Transpose; | |||
| return new_op_def; | |||
| }; | |||
| int insert_num = 0; | |||
| auto iter = | |||
| InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode, &insert_num, TransposeOpCopyer); | |||
| auto iter = InsertNode(graph, exist_node_iter, place, inout_idx, std::move(trans_node), error_code, &insert_num, | |||
| transpose_op_copyer); | |||
| size_t index = graph->allTensors.size(); | |||
| graph->allTensors.push_back(std::move(perm_tensor)); | |||
| for (int i = insert_num; i > 0; --i) { | |||
| @@ -34,13 +34,13 @@ class FormatTransPass : public GraphPass { | |||
| STATUS Run(schema::MetaGraphT *graph) override; | |||
| void SetQuantType(QuantType quantType) { this->quant_type_ = quantType; } | |||
| void set_quant_type(QuantType quant_type) { this->quant_type_ = quant_type; } | |||
| void SetFmk(converter::FmkType fmkType) { this->fmk_type_ = fmkType; } | |||
| void set_fmk_type(converter::FmkType fmk_type) { this->fmk_type_ = fmk_type; } | |||
| protected: | |||
| NodeIter InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, | |||
| FormatTransNodeType nodeType, STATUS *errorCode); | |||
| NodeIter InsertFormatTransNode(schema::MetaGraphT *in_op_def, NodeIter exist_node_iter, InsertPlace place, | |||
| size_t inout_idx, FormatTransNodeType node_type, STATUS *error_code); | |||
| STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node); | |||
| @@ -61,8 +61,8 @@ class FormatTransPass : public GraphPass { | |||
| int GetFormat(const schema::CNodeT &); | |||
| STATUS GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *beforeNodeType, | |||
| FormatTransNodeType *afterNodeType); | |||
| STATUS GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *before_node_type, | |||
| FormatTransNodeType *after_node_type); | |||
| protected: | |||
| size_t id_ = 0; | |||
| @@ -20,13 +20,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| Optimizer::~Optimizer() { | |||
| for (auto pass : graphPasses) { | |||
| for (auto pass : graph_passes_) { | |||
| if (pass != nullptr) { | |||
| delete (pass); | |||
| } | |||
| } | |||
| for (auto pass : nodePasses) { | |||
| for (auto pass : node_passes_) { | |||
| if (pass != nullptr) { | |||
| delete (pass); | |||
| } | |||
| @@ -35,13 +35,13 @@ Optimizer::~Optimizer() { | |||
| void Optimizer::AddPass(GraphPass *graphPass) { | |||
| if (graphPass != nullptr) { | |||
| this->graphPasses.emplace_back(graphPass); | |||
| this->graph_passes_.emplace_back(graphPass); | |||
| } | |||
| } | |||
| void Optimizer::AddPass(NodePass *nodePass) { | |||
| if (nodePass != nullptr) { | |||
| this->nodePasses.emplace_back(nodePass); | |||
| this->node_passes_.emplace_back(nodePass); | |||
| } | |||
| } | |||
| @@ -51,7 +51,7 @@ STATUS Optimizer::Run(schema::MetaGraphT *graphDefT) { | |||
| bool ifNotChanged = true; | |||
| // each node should go through all node pass not each node pass go through all node | |||
| for (auto &opDef : graphDefT->nodes) { | |||
| for (auto pass : this->nodePasses) { | |||
| for (auto pass : this->node_passes_) { | |||
| status = pass->Run(new (std::nothrow) GraphNode(graphDefT, opDef.get())); | |||
| if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { | |||
| MS_LOG(ERROR) << "Run NodePass failed"; | |||
| @@ -64,7 +64,7 @@ STATUS Optimizer::Run(schema::MetaGraphT *graphDefT) { | |||
| } | |||
| } | |||
| for (auto pass : this->graphPasses) { | |||
| for (auto pass : this->graph_passes_) { | |||
| status = pass->Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { | |||
| MS_LOG(ERROR) << "Run GraphPass failed"; | |||
| @@ -41,10 +41,10 @@ class GraphPass : public Pass<schema::MetaGraphT> { | |||
| }; | |||
| struct GraphNode { | |||
| GraphNode(schema::MetaGraphT *subGraph, schema::CNodeT *opDefT) : subGraph(subGraph), opDef(opDefT) {} | |||
| GraphNode(schema::MetaGraphT *subGraph, schema::CNodeT *opDefT) : sub_graph_(subGraph), op_def_(opDefT) {} | |||
| ~GraphNode() = default; | |||
| schema::MetaGraphT *subGraph = nullptr; | |||
| schema::CNodeT *opDef = nullptr; | |||
| schema::MetaGraphT *sub_graph_ = nullptr; | |||
| schema::CNodeT *op_def_ = nullptr; | |||
| }; | |||
| class NodePass : public Pass<GraphNode> { | |||
| @@ -72,8 +72,8 @@ class Optimizer { | |||
| STATUS Run(schema::MetaGraphT *graphDefT); | |||
| private: | |||
| std::vector<GraphPass *> graphPasses; | |||
| std::vector<NodePass *> nodePasses; | |||
| std::vector<GraphPass *> graph_passes_; | |||
| std::vector<NodePass *> node_passes_; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -98,7 +98,7 @@ STATUS CaffeModelParser::ConvertLayers() { | |||
| MS_LOG(INFO) << "parse op : " << layer.type(); | |||
| auto node_parser = CaffeNodeParserRegistry::GetInstance()->GetNodeParser(layer.type()); | |||
| if (node_parser == nullptr) { | |||
| NoSupportOp::GetInstance()->InsertOp(layer.type()); | |||
| NotSupportOp::GetInstance()->InsertOp(layer.type()); | |||
| status = (status == RET_OK ? RET_NOT_FIND_OP : status); | |||
| continue; | |||
| } | |||
| @@ -47,7 +47,7 @@ static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = { | |||
| FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file, | |||
| const QuantType &quant_type) { | |||
| NoSupportOp::GetInstance()->SetFmkType("ONNX"); | |||
| NotSupportOp::GetInstance()->set_fmk_type("ONNX"); | |||
| anf_root_graph_ = std::make_shared<FuncGraph>(); | |||
| auto status = InitOriginModel(model_file); | |||
| if (RET_OK != status) { | |||
| @@ -195,7 +195,7 @@ STATUS OnnxModelParser::ConvertNodes(const onnx::GraphProto &onnx_graph, const F | |||
| for (const auto &onnx_node : onnx_graph.node()) { | |||
| auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_node.op_type()); | |||
| if (node_parser == nullptr) { | |||
| NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type()); | |||
| NotSupportOp::GetInstance()->InsertOp(onnx_node.op_type()); | |||
| status = status == RET_OK ? RET_NOT_FIND_OP : status; | |||
| MS_LOG(ERROR) << "not support onnx data type " << onnx_node.op_type(); | |||
| } | |||
| @@ -476,7 +476,7 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts( | |||
| FuncGraphPtr paserTfFuction() { return nullptr; } | |||
| FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile, | |||
| const QuantType &quantType) { | |||
| NoSupportOp::GetInstance()->SetFmkType("TF"); | |||
| NotSupportOp::GetInstance()->set_fmk_type("TF"); | |||
| auto status = ValidateFileStr(modelFile, ".pb"); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.pb"; | |||
| @@ -888,7 +888,7 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, | |||
| MS_LOG(INFO) << "parse op : " << op_type; | |||
| auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(op_type); | |||
| if (node_parser == nullptr) { | |||
| NoSupportOp::GetInstance()->InsertOp(op_type); | |||
| NotSupportOp::GetInstance()->InsertOp(op_type); | |||
| MS_LOG(ERROR) << "cannot find node parser: " << node_def.name() << " in " | |||
| << func_graph_ptr->get_attr("graph_name")->ToString(); | |||
| return RET_NOT_FIND_OP; | |||
| @@ -101,7 +101,7 @@ std::string GetTensorName(size_t index, const tflite::BuiltinOperator &op_type, | |||
| STATUS TfliteModelParser::ConvertOps() { | |||
| const auto &tflite_subgraph = tflite_model_->subgraphs.front(); | |||
| NoSupportOp::GetInstance()->SetFmkType("TFLITE"); | |||
| NotSupportOp::GetInstance()->set_fmk_type("TFLITE"); | |||
| STATUS status = RET_OK; | |||
| int op_idx = 0; | |||
| for (auto &op : tflite_subgraph->operators) { | |||
| @@ -113,7 +113,7 @@ STATUS TfliteModelParser::ConvertOps() { | |||
| MS_LOG(INFO) << "parse node :" << op_name; | |||
| auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(tflite_op_type); | |||
| if (node_parser == nullptr) { | |||
| NoSupportOp::GetInstance()->InsertOp(op_type); | |||
| NotSupportOp::GetInstance()->InsertOp(op_type); | |||
| status = (status == RET_OK ? RET_NOT_FIND_OP : status); | |||
| continue; | |||
| } | |||
| @@ -1344,7 +1344,7 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) { | |||
| // add quant_cast | |||
| quant::QuantCast quant_cast; | |||
| quant_cast.SetInputDataDType(kNumberTypeFloat32); | |||
| quant_cast.set_input_data_dtype(kNumberTypeFloat32); | |||
| status = quant_cast.Run(func_graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "add QuantCast error"; | |||
| @@ -26,12 +26,12 @@ namespace mindspore::lite::quant { | |||
| class QuantCast { | |||
| public: | |||
| QuantCast() = default; | |||
| ~QuantCast() = default; | |||
| virtual ~QuantCast() = default; | |||
| STATUS Run(const FuncGraphPtr &graph); | |||
| void SetInputDataDType(TypeId dataType) { this->inputDataDType = dataType; } | |||
| void set_input_data_dtype(TypeId data_type) { this->input_data_dtype_ = data_type; } | |||
| private: | |||
| TypeId inputDataDType = kNumberTypeFloat32; | |||
| TypeId input_data_dtype_ = kNumberTypeFloat32; | |||
| }; | |||
| } // namespace mindspore::lite::quant | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER__QUANT_CAST_H | |||