| @@ -47,13 +47,17 @@ using TensorPtr = std::shared_ptr<mindspore::tensor::Tensor>; | |||||
| constexpr int kAnfPopulaterInputNumOne = 1; | constexpr int kAnfPopulaterInputNumOne = 1; | ||||
| constexpr int kAnfPopulaterInputNumTwo = 2; | constexpr int kAnfPopulaterInputNumTwo = 2; | ||||
| constexpr int kAnfPopulaterInputNumThree = 3; | constexpr int kAnfPopulaterInputNumThree = 3; | ||||
| static std::map<std::string, schema::ActivationType> kActivationTypeMap{{"ReLU", schema::ActivationType_RELU}, | |||||
| {"ReLU6", schema::ActivationType_RELU6}, | |||||
| {"Sigmoid", schema::ActivationType_SIGMOID}, | |||||
| {"HSwish", schema::ActivationType_HSWISH}, | |||||
| {"HSigmoid", schema::ActivationType_HSIGMOID}}; | |||||
| static std::map<std::string, schema::ActivationType> kActivationTypeMap{ | |||||
| {"ReLU", schema::ActivationType_RELU}, | |||||
| {"ReLU6", schema::ActivationType_RELU6}, | |||||
| {"Sigmoid", schema::ActivationType_SIGMOID}, | |||||
| {"HSwish", schema::ActivationType_HSWISH}, | |||||
| {"HSigmoid", schema::ActivationType_HSIGMOID}, | |||||
| {"Swish", schema::ActivationType_SWISH}, | |||||
| {"LeakyRelu", schema::ActivationType_LEAKY_RELU}, | |||||
| {"Tanh", schema::ActivationType_TANH}, | |||||
| {"Logistic", schema::ActivationType_SIGMOID}}; | |||||
| std::vector<int> CastToInt(const ValuePtr value, bool is_vector); | std::vector<int> CastToInt(const ValuePtr value, bool is_vector); | ||||
| class PrimitiveC : public mindspore::Primitive { | class PrimitiveC : public mindspore::Primitive { | ||||
| public: | public: | ||||
| // Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC(). | // Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC(). | ||||
| @@ -104,8 +104,8 @@ int Split::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu | |||||
| MS_ASSERT(this->primitive_ != nullptr); | MS_ASSERT(this->primitive_ != nullptr); | ||||
| auto input = inputs_.front(); | auto input = inputs_.front(); | ||||
| MS_ASSERT(input != nullptr); | MS_ASSERT(input != nullptr); | ||||
| if (inputs_.size() != kSplitInputNum) { | |||||
| MS_LOG(ERROR) << "inputs number is not equal to " << kSplitInputNum; | |||||
| if (inputs_.size() < kSplitInputNum) { | |||||
| MS_LOG(ERROR) << "inputs number is less to " << kSplitInputNum; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto output = outputs_.front(); | auto output = outputs_.front(); | ||||
| @@ -194,6 +194,8 @@ if(ENABLE_CONVERTER) | |||||
| ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc | ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc | ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc | ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/group_depthwise_op_convert_pass.cc | |||||
| ${LITE_DIR}/tools/optimizer/graph/tflite_inputs_order_exchange_pass.cc | |||||
| ${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc | ${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/unused_transpose_node_remove_pass.cc | ${LITE_DIR}/tools/optimizer/graph/unused_transpose_node_remove_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc | ${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc | ||||
| @@ -135,6 +135,6 @@ mtk_convert_model.tflite | |||||
| mtk_model_face_dress_fp16.tflite | mtk_model_face_dress_fp16.tflite | ||||
| smartreply.tflite | smartreply.tflite | ||||
| mindspore_text_classification_tflite.tflite | mindspore_text_classification_tflite.tflite | ||||
| ml_location.tflite | |||||
| # ml_location.tflite | |||||
| ml_text_correction.tflite | ml_text_correction.tflite | ||||
| ml_pic_shopping.tflite | ml_pic_shopping.tflite | ||||
| @@ -49,6 +49,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| ../optimizer/graph/weight_format_transform_pass.cc | ../optimizer/graph/weight_format_transform_pass.cc | ||||
| ../optimizer/graph/weight_format_hardcode_pass.cc | ../optimizer/graph/weight_format_hardcode_pass.cc | ||||
| ../optimizer/graph/clip_convert_activation_pass.cc | ../optimizer/graph/clip_convert_activation_pass.cc | ||||
| ../optimizer/graph/group_depthwise_op_convert_pass.cc | |||||
| ../optimizer/graph/tflite_inputs_order_exchange_pass.cc | |||||
| ../optimizer/graph/unused_cast_node_remove_pass.cc | ../optimizer/graph/unused_cast_node_remove_pass.cc | ||||
| ../optimizer/graph/unused_transpose_node_remove_pass.cc | ../optimizer/graph/unused_transpose_node_remove_pass.cc | ||||
| ../optimizer/graph/identity_remove_pass.cc | ../optimizer/graph/identity_remove_pass.cc | ||||
| @@ -25,7 +25,6 @@ | |||||
| #include "tools/optimizer/fusion/conv_bn_fusion.h" | #include "tools/optimizer/fusion/conv_bn_fusion.h" | ||||
| #include "tools/optimizer/fusion/conv_tuplegetitem_fusion.h" | #include "tools/optimizer/fusion/conv_tuplegetitem_fusion.h" | ||||
| #include "tools/optimizer/fusion/constant_folding_fusion.h" | #include "tools/optimizer/fusion/constant_folding_fusion.h" | ||||
| #include "tools/optimizer/fusion/quant_dtype_cast_fusion.h" | |||||
| #include "tools/optimizer/fusion/layer_norm_fusion.h" | #include "tools/optimizer/fusion/layer_norm_fusion.h" | ||||
| #include "tools/optimizer/fusion/batchmatmul_fusion.h" | #include "tools/optimizer/fusion/batchmatmul_fusion.h" | ||||
| #include "tools/optimizer/fusion/sigmoid_mul_fusion.h" | #include "tools/optimizer/fusion/sigmoid_mul_fusion.h" | ||||
| @@ -34,6 +33,8 @@ | |||||
| #include "tools/optimizer/graph/weight_format_hardcode_pass.h" | #include "tools/optimizer/graph/weight_format_hardcode_pass.h" | ||||
| #include "tools/optimizer/graph/weight_format_transform_pass.h" | #include "tools/optimizer/graph/weight_format_transform_pass.h" | ||||
| #include "tools/optimizer/graph/clip_convert_activation_pass.h" | #include "tools/optimizer/graph/clip_convert_activation_pass.h" | ||||
| #include "tools/optimizer/graph/group_depthwise_op_convert_pass.h" | |||||
| #include "tools/optimizer/graph/tflite_inputs_order_exchange_pass.h" | |||||
| #include "tools/optimizer/graph/unused_cast_node_remove_pass.h" | #include "tools/optimizer/graph/unused_cast_node_remove_pass.h" | ||||
| #include "tools/optimizer/graph/unused_transpose_node_remove_pass.h" | #include "tools/optimizer/graph/unused_transpose_node_remove_pass.h" | ||||
| #include "tools/optimizer/graph/infershape_pass.h" | #include "tools/optimizer/graph/infershape_pass.h" | ||||
| @@ -43,8 +44,7 @@ | |||||
| #include "tools/converter/quantizer/weight_quantizer.h" | #include "tools/converter/quantizer/weight_quantizer.h" | ||||
| using std::string; | using std::string; | ||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| namespace mindspore::lite { | |||||
| AnfTransform::AnfTransform() = default; | AnfTransform::AnfTransform() = default; | ||||
| AnfTransform::~AnfTransform() = default; | AnfTransform::~AnfTransform() = default; | ||||
| @@ -65,7 +65,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||||
| cf_pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | cf_pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | ||||
| // for now - trainning is not supporting fuse operations | // for now - trainning is not supporting fuse operations | ||||
| if (config != nullptr && !config->trainModel) { | |||||
| if (!config->trainModel) { | |||||
| // remove quantdtype when awaretraining | // remove quantdtype when awaretraining | ||||
| pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>()); | pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>()); | ||||
| pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>()); | pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>()); | ||||
| @@ -119,6 +119,10 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||||
| } | } | ||||
| pm->AddPass(std::make_shared<opt::ConvConvFusion>()); | pm->AddPass(std::make_shared<opt::ConvConvFusion>()); | ||||
| convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>()); | convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>()); | ||||
| if (config->fmk == lite::converter::FmkType_TFLITE) { | |||||
| convert_pm->AddPass(std::make_shared<opt::GroupDepthwiseOpConvertPass>()); | |||||
| convert_pm->AddPass(std::make_shared<opt::TfliteInputsOrderExchangePass>()); | |||||
| } | |||||
| optimizer->AddPassManager(cf_pm); | optimizer->AddPassManager(cf_pm); | ||||
| optimizer->AddPassManager(convert_pm); | optimizer->AddPassManager(convert_pm); | ||||
| optimizer->AddPassManager(pm); | optimizer->AddPassManager(pm); | ||||
| @@ -168,5 +172,4 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||||
| return new_graph; | return new_graph; | ||||
| } | } | ||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| } // namespace mindspore::lite | |||||
| @@ -32,8 +32,9 @@ class ModelParser { | |||||
| virtual ~ModelParser() = default; | virtual ~ModelParser() = default; | ||||
| virtual FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType) { | |||||
| auto *meta_graph = ParseToFb(modelFile, weightFile, quantType); | |||||
| virtual FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) { | |||||
| auto *meta_graph = ParseToFb(model_file, weight_file, quant_type); | |||||
| if (meta_graph == nullptr) { | if (meta_graph == nullptr) { | ||||
| MS_LOG(ERROR) << "parse model to fb failed"; | MS_LOG(ERROR) << "parse model to fb failed"; | ||||
| return nullptr; | return nullptr; | ||||
| @@ -43,8 +44,8 @@ class ModelParser { | |||||
| return func_graph; | return func_graph; | ||||
| } | } | ||||
| virtual schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile, | |||||
| const QuantType &quantType = QuantType_QUANT_NONE) = 0; | |||||
| virtual schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type = QuantType_QUANT_NONE) = 0; | |||||
| public: | public: | ||||
| static FuncGraphPtr Fb2Anf(schema::MetaGraphT *meta_graph) { | static FuncGraphPtr Fb2Anf(schema::MetaGraphT *meta_graph) { | ||||
| @@ -31,22 +31,22 @@ CaffeModelParser::~CaffeModelParser() {} | |||||
| const std::set<std::string> CaffeModelParser::skipedLayerType = {"Dropout"}; | const std::set<std::string> CaffeModelParser::skipedLayerType = {"Dropout"}; | ||||
| schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, | |||||
| const QuantType &quantType) { | |||||
| int status = ValidateFileStr(modelFile, ".prototxt"); | |||||
| schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) { | |||||
| int status = ValidateFileStr(model_file, ".prototxt"); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.prototxt"; | MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.prototxt"; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (weightFile.empty()) { | |||||
| if (weight_file.empty()) { | |||||
| MS_LOG(ERROR) << "INPUT MISSING: weightFile is necessary"; | MS_LOG(ERROR) << "INPUT MISSING: weightFile is necessary"; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| status = ValidateFileStr(weightFile, ".caffemodel"); | |||||
| status = ValidateFileStr(weight_file, ".caffemodel"); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "INPUT ILLEGAL: weightFile must be *.caffemodel"; | MS_LOG(ERROR) << "INPUT ILLEGAL: weightFile must be *.caffemodel"; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| @@ -57,18 +57,18 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co | |||||
| TensorCache tensorCache; | TensorCache tensorCache; | ||||
| caffe::NetParameter proto; | caffe::NetParameter proto; | ||||
| status = ReadProtoFromText((const char *)modelFile.c_str(), &proto); | |||||
| status = ReadProtoFromText((const char *)model_file.c_str(), &proto); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "Read prototxt file failed, model path: " << modelFile; | |||||
| MS_LOG(ERROR) << "Read prototxt file failed, model path: " << model_file; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| metaGraph->name = proto.name(); | metaGraph->name = proto.name(); | ||||
| caffe::NetParameter weight; | caffe::NetParameter weight; | ||||
| status = ReadProtoFromBinaryFile((const char *)weightFile.c_str(), &weight); | |||||
| status = ReadProtoFromBinaryFile((const char *)weight_file.c_str(), &weight); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "Read caffemodel file failed, model path: " << weightFile; | |||||
| MS_LOG(ERROR) << "Read caffemodel file failed, model path: " << weight_file; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -81,7 +81,7 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co | |||||
| } | } | ||||
| NoSupportOp::GetInstance()->SetFmkType("CAFFE"); | NoSupportOp::GetInstance()->SetFmkType("CAFFE"); | ||||
| status = ParseLayer(proto, weight, &tensorCache, metaGraph.get(), quantType); | |||||
| status = ParseLayer(proto, weight, &tensorCache, metaGraph.get(), quant_type); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "ParseLayer failed " << status; | MS_LOG(ERROR) << "ParseLayer failed " << status; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| @@ -97,7 +97,7 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| metaGraph->name = GetModelName(modelFile); | |||||
| metaGraph->name = GetModelName(model_file); | |||||
| SetAllTensors(tensorCache, metaGraph.get()); | SetAllTensors(tensorCache, metaGraph.get()); | ||||
| @@ -34,8 +34,8 @@ class CaffeModelParser : public ModelParser { | |||||
| virtual ~CaffeModelParser(); | virtual ~CaffeModelParser(); | ||||
| schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile, | |||||
| const QuantType &quantType = QuantType_QUANT_NONE) override; | |||||
| schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type = QuantType_QUANT_NONE) override; | |||||
| private: | private: | ||||
| STATUS SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache); | STATUS SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache); | ||||
| @@ -623,9 +623,9 @@ int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, | |||||
| const QuantType &quantType) { | |||||
| int status = ValidateFileStr(modelFile, ".onnx"); | |||||
| schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) { | |||||
| int status = ValidateFileStr(model_file, ".onnx"); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.onnx"; | MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.onnx"; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| @@ -633,9 +633,9 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con | |||||
| } | } | ||||
| onnx::ModelProto onnx_model; | onnx::ModelProto onnx_model; | ||||
| status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), &onnx_model); | |||||
| status = ReadProtoFromBinaryFile((const char *)model_file.c_str(), &onnx_model); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "Read onnx model file failed, model path: " << modelFile; | |||||
| MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_file; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -645,13 +645,13 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con | |||||
| auto dst_graph = std::make_unique<schema::MetaGraphT>(); | auto dst_graph = std::make_unique<schema::MetaGraphT>(); | ||||
| auto dst_sub_graph = std::make_unique<schema::SubGraphT>(); | auto dst_sub_graph = std::make_unique<schema::SubGraphT>(); | ||||
| int ret = ParseGraph(dst_graph.get(), dst_sub_graph.get(), onnx_graph, quantType); | |||||
| int ret = ParseGraph(dst_graph.get(), dst_sub_graph.get(), onnx_graph, quant_type); | |||||
| dst_graph->subGraph.push_back(std::move(dst_sub_graph)); | dst_graph->subGraph.push_back(std::move(dst_sub_graph)); | ||||
| subGraphNum += 1; | subGraphNum += 1; | ||||
| if (ret == RET_ERROR) { | if (ret == RET_ERROR) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| dst_graph->name = GetModelName(modelFile); | |||||
| dst_graph->name = GetModelName(model_file); | |||||
| std::vector<uint32_t> input_temp_index; | std::vector<uint32_t> input_temp_index; | ||||
| for (size_t i = 0; i < dst_graph->subGraph.front()->inputIndices.size(); i++) { | for (size_t i = 0; i < dst_graph->subGraph.front()->inputIndices.size(); i++) { | ||||
| @@ -45,8 +45,8 @@ class OnnxModelParser : public ModelParser { | |||||
| int ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT *dst_sub_graph, const onnx::GraphProto &onnx_graph, | int ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT *dst_sub_graph, const onnx::GraphProto &onnx_graph, | ||||
| const QuantType &quantType); | const QuantType &quantType); | ||||
| schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile, | |||||
| const QuantType &quantType = QuantType_QUANT_NONE) override; | |||||
| schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type = QuantType_QUANT_NONE) override; | |||||
| static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); | static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); | ||||
| @@ -1,273 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tflite/model_parser_for_tflite.h" | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <algorithm> | |||||
| #include "src/param_value_lite.h" | |||||
| namespace mindspore::lite { | |||||
| FuncGraphPtr ModelParserForTflite::Parse(const std::string &modelFile, const std::string &weightFile, | |||||
| const QuantType &quantType) { | |||||
| // load graph | |||||
| tfliteModel = ReadTfliteModel(modelFile.c_str()); | |||||
| if (tfliteModel == nullptr) { | |||||
| MS_LOG(ERROR) << "read tflite model failed"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); | |||||
| return nullptr; | |||||
| } | |||||
| if (tfliteModel->subgraphs.size() != 1) { | |||||
| MS_LOG(ERROR) << "read tflite model subgraphs failed"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); | |||||
| return nullptr; | |||||
| } | |||||
| funcGraphPtr = std::make_shared<FuncGraph>(); | |||||
| auto status = ConvertGraphInputs(); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Convert graph inputs failed."; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| return nullptr; | |||||
| } | |||||
| status = ConvertOps(); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Convert ops failed."; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| return nullptr; | |||||
| } | |||||
| status = ConvertGraphOutputs(); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Convert graph outputs failed."; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| return nullptr; | |||||
| } | |||||
| return funcGraphPtr; | |||||
| } | |||||
| STATUS ModelParserForTflite::ConvertOps() { | |||||
| const auto &tfliteSubgraph = tfliteModel->subgraphs.front(); | |||||
| const auto &tfliteModelBuffers = tfliteModel->buffers; | |||||
| NoSupportOp::GetInstance()->SetFmkType("TFLITE"); | |||||
| STATUS status = RET_OK; | |||||
| int opIdx = 0; | |||||
| for (auto &op : tfliteSubgraph->operators) { | |||||
| auto tfliteOpType = (tfliteModel->operator_codes[op->opcode_index])->builtin_code; | |||||
| auto opType = GetMSOpType(tfliteOpType); | |||||
| // parse primitive | |||||
| auto nodeParser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(opType); | |||||
| if (nodeParser == nullptr) { | |||||
| NoSupportOp::GetInstance()->InsertOp(opType); | |||||
| status = (status == RET_OK ? RET_NOT_FIND_OP : status); | |||||
| continue; | |||||
| } | |||||
| PrimitiveC *primitiveC = nullptr; | |||||
| if (status == RET_OK) { | |||||
| status = nodeParser->Parse(op, tfliteModel, primitiveC); | |||||
| if (status != RET_OK) { | |||||
| if (status == RET_NOT_FIND_OP) { | |||||
| opType = (opType != "Custom" ? opType : (tfliteModel->operator_codes[op->opcode_index])->custom_code); | |||||
| NoSupportOp::GetInstance()->InsertOp(opType); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "node " << opType.c_str() << " parser failed"; | |||||
| } | |||||
| continue; | |||||
| } | |||||
| std::vector<AnfNodePtr> opInputs = {NewValueNode(std::shared_ptr<PrimitiveC>(primitiveC))}; | |||||
| // parse inputs | |||||
| for (auto inputIdx : op->inputs) { | |||||
| const auto &inputTensor = tfliteSubgraph->tensors[inputIdx]; | |||||
| if (nodes.find(inputIdx) != nodes.end()) { | |||||
| opInputs.emplace_back(nodes.at(inputIdx)); | |||||
| continue; | |||||
| } | |||||
| // const tensor | |||||
| if (tfliteModelBuffers.at(inputTensor->buffer)->data.empty()) { | |||||
| ParameterPtr parameter; | |||||
| ConvertConstTensor(inputTensor.get(), parameter); | |||||
| opInputs.emplace_back(parameter); | |||||
| nodes.insert(std::pair(inputIdx, parameter)); | |||||
| continue; | |||||
| } | |||||
| MS_LOG(ERROR) << "tensor" << inputIdx << " is neither a node output nor a weight tensor."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto newCNode = funcGraphPtr->NewCNode(opInputs); | |||||
| newCNode->set_fullname_with_scope(opType + "-" + std::to_string(opIdx++)); | |||||
| // parse outputs | |||||
| status = ConvertOutputTensor(op.get(), newCNode); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Convert output tensors for " << newCNode->fullname_with_scope() << " failed."; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| return status; | |||||
| } | |||||
| } | |||||
| } | |||||
| return status; | |||||
| } | |||||
| STATUS ModelParserForTflite::ConvertGraphInputs() { | |||||
| const auto &tfliteSubgraph = tfliteModel->subgraphs.front(); | |||||
| for (auto tfliteGraphInput : tfliteSubgraph->inputs) { | |||||
| if (tfliteGraphInput < 0) { | |||||
| tfliteGraphInput = tfliteGraphInput + tfliteSubgraph->tensors.size(); | |||||
| } | |||||
| auto parameter = funcGraphPtr->add_parameter(); | |||||
| const auto &tensor = tfliteSubgraph->tensors.at(tfliteGraphInput); | |||||
| std::vector<int64_t> shape_vector; | |||||
| (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), | |||||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | |||||
| auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type)); | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||||
| parameter->set_abstract(abstract_tensor); | |||||
| parameter->set_name("graph_input_" + std::to_string(tfliteGraphInput) + "_parameter"); | |||||
| nodes.insert(std::pair(tfliteGraphInput, parameter)); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS ModelParserForTflite::ConvertGraphOutputs() { | |||||
| const auto &tfliteSubgraph = tfliteModel->subgraphs.front(); | |||||
| if (tfliteSubgraph->outputs.size() > 1) { | |||||
| std::vector<AnfNodePtr> make_tuple_inputs; | |||||
| auto make_tuple_prim_ptr = GetMakeTuplePrim(); | |||||
| if (make_tuple_prim_ptr == nullptr) { | |||||
| MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr); | |||||
| make_tuple_inputs.emplace_back(make_tuple_prim); | |||||
| for (auto outputNode : tfliteSubgraph->outputs) { | |||||
| auto cnode = nodes.at(outputNode); | |||||
| if (nullptr == cnode) { | |||||
| MS_LOG(ERROR) << "Can't find input node."; | |||||
| return RET_NOT_FIND_OP; | |||||
| } | |||||
| make_tuple_inputs.emplace_back(cnode); | |||||
| } | |||||
| auto make_tuple_cnode = funcGraphPtr->NewCNode(make_tuple_inputs); | |||||
| make_tuple_cnode->set_fullname_with_scope("return tuple"); | |||||
| std::vector<AnfNodePtr> op_inputs; | |||||
| auto return_prim_ptr = GetReturnPrim(); | |||||
| if (return_prim_ptr == nullptr) { | |||||
| MS_LOG(ERROR) << "GetReturnPrim return nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto value_node = NewValueNode(return_prim_ptr); | |||||
| op_inputs.emplace_back(value_node); | |||||
| op_inputs.emplace_back(make_tuple_cnode); | |||||
| auto cnode = funcGraphPtr->NewCNode(op_inputs); | |||||
| cnode->set_fullname_with_scope("return"); | |||||
| funcGraphPtr->set_return(cnode); | |||||
| } else { | |||||
| auto returnPrim = GetReturnPrim(); | |||||
| if (returnPrim == nullptr) { | |||||
| MS_LOG(ERROR) << "GetReturnPrim return nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto valueNode = NewValueNode(returnPrim); | |||||
| std::vector<AnfNodePtr> opInputs{valueNode}; | |||||
| auto cnode = nodes.at(tfliteSubgraph->outputs.front()); | |||||
| if (nullptr == cnode) { | |||||
| MS_LOG(ERROR) << "Can't find input node."; | |||||
| return RET_NOT_FIND_OP; | |||||
| } | |||||
| opInputs.emplace_back(cnode); | |||||
| auto returnCnode = funcGraphPtr->NewCNode(opInputs); | |||||
| returnCnode->set_fullname_with_scope("return"); | |||||
| funcGraphPtr->set_return(returnCnode); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS ModelParserForTflite::ConvertConstTensor(const tflite::TensorT *tensor, ParameterPtr parameter) { | |||||
| parameter = funcGraphPtr->add_parameter(); | |||||
| const auto &tfliteModelBuffers = tfliteModel->buffers; | |||||
| auto type_id = static_cast<TypeId>(tensor->type); | |||||
| auto type_ptr = TypeIdToType(type_id); | |||||
| std::vector<int64_t> shape_vector; | |||||
| (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), | |||||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||||
| parameter->set_abstract(abstract_tensor); | |||||
| parameter->set_name("const_" + std::to_string(nodes.size()) + "_parameter"); | |||||
| ParamValueLitePtr paramValue = std::make_shared<ParamValueLite>(); | |||||
| MS_ASSERT(paramValue != nullptr); | |||||
| paramValue->set_tensor_shape(tensor->shape); | |||||
| paramValue->set_tensor_type(GetTfliteDataType(tensor->type)); | |||||
| paramValue->set_format(schema::Format::Format_NHWC); | |||||
| const auto &data = tfliteModelBuffers.at(tensor->buffer)->data; | |||||
| if (!data.empty()) { | |||||
| auto size = data.size(); | |||||
| char *tensor_data = new (std::nothrow) char[size]; | |||||
| if (tensor_data == nullptr) { | |||||
| MS_LOG(ERROR) << "new char[] failed"; | |||||
| return RET_MEMORY_FAILED; | |||||
| } | |||||
| std::memcpy(tensor_data, data.data(), size); | |||||
| paramValue->set_tensor_addr(tensor_data); | |||||
| paramValue->set_tensor_size(size); | |||||
| parameter->set_default_param(paramValue); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS ModelParserForTflite::ConvertOutputTensor(const tflite::OperatorT *op, CNodePtr dstCNode) { | |||||
| MS_ASSERT(op != nullptr); | |||||
| MS_ASSERT(dstCNode != nullptr); | |||||
| const auto &tfliteSubgraph = tfliteModel->subgraphs.front(); | |||||
| if (op->outputs.size() == 1) { | |||||
| const auto &tensor = tfliteSubgraph->tensors.at(op->outputs.front()); | |||||
| std::vector<int64_t> shape_vector; | |||||
| (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), | |||||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | |||||
| auto typePtr = TypeIdToType(GetTfliteDataType(tensor->type)); | |||||
| dstCNode->set_abstract(std::make_shared<abstract::AbstractTensor>(typePtr, shape_vector)); | |||||
| nodes.insert(std::pair(op->outputs.front(), dstCNode)); | |||||
| } else { | |||||
| AbstractBasePtrList abstractList; | |||||
| for (auto outputIdx : op->outputs) { | |||||
| const auto &tensor = tfliteSubgraph->tensors.at(outputIdx); | |||||
| std::vector<int64_t> shape_vector; | |||||
| (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), | |||||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | |||||
| auto typePtr = TypeIdToType(GetTfliteDataType(tensor->type)); | |||||
| abstractList.emplace_back(std::make_shared<abstract::AbstractTensor>(typePtr, shape_vector)); | |||||
| auto tupleGetItemPrimPtr = GetTupleGetItemPrim(); | |||||
| if (tupleGetItemPrimPtr == nullptr) { | |||||
| MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto tupleGetItemPrim = NewValueNode(tupleGetItemPrimPtr); | |||||
| auto getItemValue = NewValueNode(MakeValue<int>(outputIdx)); | |||||
| std::vector<AnfNodePtr> inputs{tupleGetItemPrim, dstCNode, getItemValue}; | |||||
| CNodePtr getItemCNode = funcGraphPtr->NewCNode(inputs); | |||||
| getItemCNode->set_fullname_with_scope(dstCNode->fullname_with_scope() + "_getitem_" + std::to_string(outputIdx)); | |||||
| nodes.insert(std::pair(outputIdx, getItemCNode)); | |||||
| } | |||||
| dstCNode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstractList)); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace mindspore::lite | |||||
| @@ -1,44 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_MODEL_PARSER_FOR_TFLITE_H | |||||
| #define LITE_MODEL_PARSER_FOR_TFLITE_H | |||||
| #include <string> | |||||
| #include <unordered_map> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_model_parser.h" | |||||
| namespace mindspore::lite { | |||||
| class ModelParserForTflite : public TfliteModelParser { | |||||
| public: | |||||
| ModelParserForTflite() = default; | |||||
| ~ModelParserForTflite() override = default; | |||||
| FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType) override; | |||||
| private: | |||||
| std::unordered_map<int, AnfNodePtr> nodes; | |||||
| std::unique_ptr<tflite::ModelT> tfliteModel; | |||||
| FuncGraphPtr funcGraphPtr; | |||||
| STATUS ConvertConstTensor(const tflite::TensorT *tensor, ParameterPtr parameter); | |||||
| STATUS ConvertOutputTensor(const tflite::OperatorT *op, CNodePtr dstCNode); | |||||
| STATUS ConvertOps(); | |||||
| STATUS ConvertGraphInputs(); | |||||
| STATUS ConvertGraphOutputs(); | |||||
| }; | |||||
| } // namespace mindspore::lite | |||||
| #endif // LITE_MODEL_PARSER_FOR_TFLITE_H | |||||
| @@ -18,9 +18,11 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include "src/ops/activation.h" | |||||
| #include "src/ops/primitive_c.h" | |||||
| #include "tools/converter/parser/tflite/tflite_util.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| namespace mindspore::lite { | |||||
| STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info, | STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| @@ -86,12 +88,40 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| TfliteNodeRegister g_tfliteReluParser("Relu", new TfliteActivationParser()); | |||||
| TfliteNodeRegister g_tfliteRelu6Parser("Relu6", new TfliteActivationParser()); | |||||
| TfliteNodeRegister g_tfliteTanhParser("Tanh", new TfliteActivationParser()); | |||||
| TfliteNodeRegister g_tfliteSwishParser("Swish", new TfliteActivationParser()); | |||||
| TfliteNodeRegister g_tfliteHardSwishParser("HardSwish", new TfliteActivationParser()); | |||||
| lite::PrimitiveC *TfliteActivationParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; | |||||
| auto ms_op_type = GetMSOpType(tflite_op_type); | |||||
| if (kActivationTypeMap.find(ms_op_type) == kActivationTypeMap.end()) { | |||||
| MS_LOG(ERROR) << ms_op_type << "is a not supported activation type"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->type = kActivationTypeMap.find(GetMSOpType(tflite_op_type))->second; | |||||
| if (attr->type == schema::ActivationType_LEAKY_RELU) { | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsLeakyReluOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op: " << GetMSOpType(tflite_op_type) << " attr failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->alpha = tflite_attr->alpha; | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| primitive->value.type = schema::PrimitiveType_Activation; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_TfliteReluParser("ReLU", new TfliteActivationParser()); | |||||
| TfliteNodeRegister g_TfliteRelu6Parser("ReLU6", new TfliteActivationParser()); | |||||
| TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteActivationParser()); | |||||
| TfliteNodeRegister g_TfliteSwishParser("Swish", new TfliteActivationParser()); | |||||
| TfliteNodeRegister g_TfliteHardSwishParser("HSwish", new TfliteActivationParser()); | |||||
| TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteActivationParser()); | TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteActivationParser()); | ||||
| TfliteNodeRegister g_tfliteLeakyReluParser("LeakyRelu", new TfliteActivationParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteActivationParser()); | |||||
| } // namespace mindspore::lite | |||||
| @@ -23,8 +23,7 @@ | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| namespace mindspore::lite { | |||||
| class TfliteActivationParser : public TfliteNodeParser { | class TfliteActivationParser : public TfliteNodeParser { | ||||
| public: | public: | ||||
| TfliteActivationParser() : TfliteNodeParser("node_name") {} | TfliteActivationParser() : TfliteNodeParser("node_name") {} | ||||
| @@ -32,9 +31,10 @@ class TfliteActivationParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | |||||
| } // namespace mindspore::lite | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H | #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H | ||||
| @@ -18,9 +18,10 @@ | |||||
| #include "tools/converter/parser/tflite/tflite_addn_parser.h" | #include "tools/converter/parser/tflite/tflite_addn_parser.h" | ||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <map> | |||||
| #include "src/ops/addn.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| namespace mindspore::lite { | |||||
| STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | ||||
| @@ -55,7 +56,18 @@ STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| lite::PrimitiveC *TfliteAddNParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto attr = std::make_unique<schema::AddNT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| primitive->value.type = schema::PrimitiveType_AddN; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteAddNParser("AddN", new TfliteAddNParser()); | TfliteNodeRegister g_tfliteAddNParser("AddN", new TfliteAddNParser()); | ||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| } // namespace mindspore::lite | |||||
| @@ -32,6 +32,9 @@ class TfliteAddNParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -76,6 +76,39 @@ STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteArgmaxParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| const auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||||
| std::unique_ptr<schema::ArgMaxT> attr = std::make_unique<schema::ArgMaxT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->outMaxValue = false; | |||||
| attr->topK = 1; | |||||
| attr->keepDims = false; | |||||
| attr->axisType = 1; | |||||
| // get axis attr | |||||
| auto axis_idx = tflite_op->inputs[1]; | |||||
| auto buffer_idx = tflite_subgraph->tensors[axis_idx]->buffer; | |||||
| auto &buf_data = tflite_model->buffers[buffer_idx]; | |||||
| if (buf_data == nullptr) { | |||||
| MS_LOG(ERROR) << "the buf data is null"; | |||||
| return nullptr; | |||||
| } | |||||
| auto data_ptr = buf_data->data.data(); | |||||
| if (data_ptr == nullptr) { | |||||
| MS_LOG(ERROR) << "the data is null"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->axis = *(static_cast<int32_t *>(static_cast<void *>(data_ptr))); | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| primitive->value.type = schema::PrimitiveType_ArgMax; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteArgmaxParser("Argmax", new TfliteArgmaxParser()); | TfliteNodeRegister g_tfliteArgmaxParser("Argmax", new TfliteArgmaxParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -32,6 +32,9 @@ class TfliteArgmaxParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -76,6 +76,39 @@ STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteArgminParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| const auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||||
| std::unique_ptr<schema::ArgMinT> attr = std::make_unique<schema::ArgMinT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->outMaxValue = false; | |||||
| attr->topK = 1; | |||||
| attr->keepDims = false; | |||||
| attr->axisType = 1; | |||||
| // get axis attr | |||||
| auto axis_idx = tflite_op->inputs[1]; | |||||
| auto buffer_idx = tflite_subgraph->tensors[axis_idx]->buffer; | |||||
| auto &buf_data = tflite_model->buffers[buffer_idx]; | |||||
| if (buf_data == nullptr) { | |||||
| MS_LOG(ERROR) << "the buf data is null"; | |||||
| return nullptr; | |||||
| } | |||||
| auto data_ptr = buf_data->data.data(); | |||||
| if (data_ptr == nullptr) { | |||||
| MS_LOG(ERROR) << "the data is null"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->axis = *(static_cast<int32_t *>(static_cast<void *>(data_ptr))); | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| primitive->value.type = schema::PrimitiveType_ArgMin; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteArgminParser("Argmin", new TfliteArgminParser()); | TfliteNodeRegister g_tfliteArgminParser("Argmin", new TfliteArgminParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -32,6 +32,8 @@ class TfliteArgminParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -179,6 +179,133 @@ STATUS TfliteDoubleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteDoubleInputOpParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (tflite_op_type == tflite::BuiltinOperator_ADD) { | |||||
| MS_LOG(DEBUG) << "parse TfliteAddParser"; | |||||
| auto attr = std::make_unique<schema::AddT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &tfliteAttr = tflite_op->builtin_options.AsAddOptions(); | |||||
| if (nullptr == tfliteAttr) { | |||||
| MS_LOG(ERROR) << "get op: " << tflite_op_type << " attr failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); | |||||
| primitive->value.type = schema::PrimitiveType_Add; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_SUB) { | |||||
| MS_LOG(DEBUG) << "parse TfliteSubParser"; | |||||
| auto attr = std::make_unique<schema::SubT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &tfliteAttr = tflite_op->builtin_options.AsSubOptions(); | |||||
| if (nullptr == tfliteAttr) { | |||||
| MS_LOG(ERROR) << "get op: " << tflite_op_type << " attr failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); | |||||
| primitive->value.type = schema::PrimitiveType_Sub; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_MUL) { | |||||
| MS_LOG(DEBUG) << "parse TfliteMulParser"; | |||||
| auto attr = std::make_unique<schema::MulT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &tfliteAttr = tflite_op->builtin_options.AsMulOptions(); | |||||
| if (nullptr == tfliteAttr) { | |||||
| MS_LOG(ERROR) << "get op: " << tflite_op_type << " attr failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); | |||||
| primitive->value.type = schema::PrimitiveType_Mul; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_DIV) { | |||||
| MS_LOG(DEBUG) << "parse TfliteDivParser"; | |||||
| auto attr = std::make_unique<schema::DivT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &tfliteAttr = tflite_op->builtin_options.AsDivOptions(); | |||||
| if (nullptr == tfliteAttr) { | |||||
| MS_LOG(ERROR) << "get op: " << tflite_op_type << " attr failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); | |||||
| primitive->value.type = schema::PrimitiveType_Div; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_FLOOR_DIV) { | |||||
| MS_LOG(DEBUG) << "parse TfliteFloorDivParser"; | |||||
| std::unique_ptr<schema::FloorDivT> attr = std::make_unique<schema::FloorDivT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_FloorDiv; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_FLOOR_MOD) { | |||||
| MS_LOG(DEBUG) << "parse TfliteFloorModParser"; | |||||
| auto attr = std::make_unique<schema::FloorModT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_FloorMod; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_SQUARED_DIFFERENCE) { | |||||
| MS_LOG(DEBUG) << "parse TfliteSquaredDifferenceParser"; | |||||
| auto attr = std::make_unique<schema::SquaredDifferenceT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_SquaredDifference; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_POW) { | |||||
| MS_LOG(DEBUG) << "parse TflitePowParser"; | |||||
| auto attr = std::make_unique<schema::PowerT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->power = 1.0f; | |||||
| attr->scale = 1.0f; | |||||
| attr->shift = 0.0f; | |||||
| primitive->value.type = schema::PrimitiveType_Power; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_MAXIMUM) { | |||||
| MS_LOG(DEBUG) << "parse TfliteMaximumParser"; | |||||
| auto attr = std::make_unique<schema::MaximumT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Maximum; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_MINIMUM) { | |||||
| MS_LOG(DEBUG) << "parse TfliteMinimumParser"; | |||||
| auto attr = std::make_unique<schema::MinimumT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Minimum; | |||||
| primitive->value.value = attr.release(); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "op hasn't been supported"; | |||||
| return nullptr; | |||||
| } | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, | STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| @@ -320,6 +447,124 @@ STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteSingleInputOpParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (tflite_op_type == tflite::BuiltinOperator_ABS) { | |||||
| MS_LOG(DEBUG) << "parse TfliteAbsParser"; | |||||
| auto attr = std::make_unique<schema::AbsT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Abs; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_EXP) { | |||||
| MS_LOG(DEBUG) << "parse TfliteExpParser"; | |||||
| auto attr = std::make_unique<schema::ExpT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->base = -1; // -1 represent base = e | |||||
| attr->scale = 1; | |||||
| attr->shift = 0; | |||||
| primitive->value.type = schema::PrimitiveType_Exp; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_SQRT) { | |||||
| MS_LOG(DEBUG) << "parse TfliteSqrtParser"; | |||||
| auto attr = std::make_unique<schema::SqrtT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Sqrt; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_RSQRT) { | |||||
| MS_LOG(DEBUG) << "parse TfliteRsqrtParser"; | |||||
| auto attr = std::make_unique<schema::RsqrtT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Rsqrt; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_SQUARE) { | |||||
| MS_LOG(DEBUG) << "parse TfliteSquareParser"; | |||||
| auto attr = std::make_unique<schema::SquareT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Square; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_SIN) { | |||||
| MS_LOG(DEBUG) << "parse TfliteSinParser"; | |||||
| auto attr = std::make_unique<schema::SinT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Sin; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_COS) { | |||||
| MS_LOG(DEBUG) << "parse TfliteCosParser"; | |||||
| std::unique_ptr<schema::CosT> attr = std::make_unique<schema::CosT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Cos; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_LOG) { | |||||
| MS_LOG(DEBUG) << "parse TfliteLogParser"; | |||||
| auto attr = std::make_unique<schema::LogT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Log; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_ROUND) { | |||||
| MS_LOG(DEBUG) << "parse TfliteRoundParser"; | |||||
| auto attr = std::make_unique<schema::RoundT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Round; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_CEIL) { | |||||
| MS_LOG(DEBUG) << "parse TfliteCeilParser"; | |||||
| auto attr = std::make_unique<schema::CeilT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Ceil; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_FLOOR) { | |||||
| MS_LOG(DEBUG) << "parse TfliteFloorParser"; | |||||
| auto attr = std::make_unique<schema::FloorT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Floor; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_NEG) { | |||||
| MS_LOG(DEBUG) << "parse TfliteNegParser"; | |||||
| auto attr = std::make_unique<schema::NegT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Neg; | |||||
| primitive->value.value = attr.release(); | |||||
| } | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info, | STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| @@ -406,29 +651,91 @@ STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteCompareOpParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (tflite_op_type == tflite::BuiltinOperator_EQUAL) { | |||||
| MS_LOG(DEBUG) << "parse TfliteEqualParser"; | |||||
| std::unique_ptr<schema::EqualT> attr = std::make_unique<schema::EqualT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Equal; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_NOT_EQUAL) { | |||||
| MS_LOG(DEBUG) << "parse TfliteNotEqualParser"; | |||||
| std::unique_ptr<schema::NotEqualT> attr = std::make_unique<schema::NotEqualT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_NotEqual; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_GREATER) { | |||||
| MS_LOG(DEBUG) << "parse TfliteGreaterParser"; | |||||
| std::unique_ptr<schema::GreaterT> attr = std::make_unique<schema::GreaterT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Greater; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_GREATER_EQUAL) { | |||||
| MS_LOG(DEBUG) << "parse TfliteGreaterEqualParser"; | |||||
| std::unique_ptr<schema::GreaterEqualT> attr = std::make_unique<schema::GreaterEqualT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_GreaterEqual; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_LESS) { | |||||
| MS_LOG(DEBUG) << "parse TfliteLessParser"; | |||||
| std::unique_ptr<schema::LessT> attr = std::make_unique<schema::LessT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Less; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_LESS_EQUAL) { | |||||
| MS_LOG(DEBUG) << "parse TfliteLessEqualParser"; | |||||
| std::unique_ptr<schema::LessEqualT> attr = std::make_unique<schema::LessEqualT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_LessEqual; | |||||
| primitive->value.value = attr.release(); | |||||
| } | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteAddParser("Add", new TfliteDoubleInputOpParser()); | TfliteNodeRegister g_tfliteAddParser("Add", new TfliteDoubleInputOpParser()); | ||||
| TfliteNodeRegister g_tfliteSubParser("Sub", new TfliteDoubleInputOpParser()); | TfliteNodeRegister g_tfliteSubParser("Sub", new TfliteDoubleInputOpParser()); | ||||
| TfliteNodeRegister g_tfliteMulParser("Mul", new TfliteDoubleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteDivParser("Div", new TfliteDoubleInputOpParser()); | |||||
| TfliteNodeRegister g_TfliteMulParser("Mul", new TfliteDoubleInputOpParser()); | |||||
| TfliteNodeRegister g_TfliteDivParser("Div", new TfliteDoubleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteFloorDivParser("FloorDiv", new TfliteDoubleInputOpParser()); | TfliteNodeRegister g_tfliteFloorDivParser("FloorDiv", new TfliteDoubleInputOpParser()); | ||||
| TfliteNodeRegister g_tfliteFloorModParser("FloorMod", new TfliteDoubleInputOpParser()); | TfliteNodeRegister g_tfliteFloorModParser("FloorMod", new TfliteDoubleInputOpParser()); | ||||
| TfliteNodeRegister g_tfliteRealDivParser("RealDiv", new TfliteDoubleInputOpParser()); | TfliteNodeRegister g_tfliteRealDivParser("RealDiv", new TfliteDoubleInputOpParser()); | ||||
| TfliteNodeRegister g_tflitePowParser("Pow", new TfliteDoubleInputOpParser()); | |||||
| TfliteNodeRegister g_TflitePowParser("Pow", new TfliteDoubleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteSquaredDifferenceParser("SquaredDifference", new TfliteDoubleInputOpParser()); | TfliteNodeRegister g_tfliteSquaredDifferenceParser("SquaredDifference", new TfliteDoubleInputOpParser()); | ||||
| TfliteNodeRegister g_tfliteMaximumParser("Maximum", new TfliteDoubleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteMinimumParser("Minimum", new TfliteDoubleInputOpParser()); | |||||
| TfliteNodeRegister g_TfliteMaximumParser("Maximum", new TfliteDoubleInputOpParser()); | |||||
| TfliteNodeRegister g_TfliteMinimumParser("Minimum", new TfliteDoubleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteAbsParser("Abs", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteExpParser("Exp", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteSqrtParser("Sqrt", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_TfliteAbsParser("Abs", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_TfliteExpParser("Exp", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_TfliteSqrtParser("Sqrt", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteRsqrtParser("Rsqrt", new TfliteSingleInputOpParser()); | TfliteNodeRegister g_tfliteRsqrtParser("Rsqrt", new TfliteSingleInputOpParser()); | ||||
| TfliteNodeRegister g_tfliteSquareParser("Square", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteSinParser("Sin", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteCosParser("Cos", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteLogParser("Log", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_TfliteSquareParser("Square", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_TfliteSinParser("Sin", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_TfliteCosParser("Cos", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_TfliteLogParser("Log", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteRoundParser("Round", new TfliteSingleInputOpParser()); | TfliteNodeRegister g_tfliteRoundParser("Round", new TfliteSingleInputOpParser()); | ||||
| TfliteNodeRegister g_tfliteCeilParser("Ceil", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_TfliteCeilParser("Ceil", new TfliteSingleInputOpParser()); | |||||
| TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteSingleInputOpParser()); | TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteSingleInputOpParser()); | ||||
| TfliteNodeRegister g_tfliteNegParser("Neg", new TfliteSingleInputOpParser()); | TfliteNodeRegister g_tfliteNegParser("Neg", new TfliteSingleInputOpParser()); | ||||
| @@ -32,6 +32,9 @@ class TfliteDoubleInputOpParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| class TfliteSingleInputOpParser : public TfliteNodeParser { | class TfliteSingleInputOpParser : public TfliteNodeParser { | ||||
| @@ -41,6 +44,9 @@ class TfliteSingleInputOpParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| class TfliteCompareOpParser : public TfliteNodeParser { | class TfliteCompareOpParser : public TfliteNodeParser { | ||||
| @@ -50,7 +56,11 @@ class TfliteCompareOpParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -74,6 +74,29 @@ STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteBatchToSpaceParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| const auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| std::unique_ptr<schema::BatchToSpaceT> attr = std::make_unique<schema::BatchToSpaceT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->blockShape)) { | |||||
| MS_LOG(ERROR) << "get batchToSpace -> blockShape failed"; | |||||
| return nullptr; | |||||
| } | |||||
| if (GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->crops)) { | |||||
| MS_LOG(ERROR) << "get batchToSpace -> crops failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_BatchToSpace; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteBatchToSpaceParser("BatchToSpace", new TfliteBatchToSpaceParser()); | TfliteNodeRegister g_tfliteBatchToSpaceParser("BatchToSpace", new TfliteBatchToSpaceParser()); | ||||
| TfliteNodeRegister g_tfliteBatchToSpaceNDParser("BatchToSpaceND", new TfliteBatchToSpaceParser()); | TfliteNodeRegister g_tfliteBatchToSpaceNDParser("BatchToSpaceND", new TfliteBatchToSpaceParser()); | ||||
| @@ -32,7 +32,10 @@ class TfliteBatchToSpaceParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -57,6 +57,28 @@ STATUS TfliteBroadcastToParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteBroadcastToParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| std::unique_ptr<schema::BroadcastToT> attr = std::make_unique<schema::BroadcastToT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->dst_shape)) { | |||||
| MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_BroadcastTo; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteBroadcastToParser("BroadcastTo", new TfliteBroadcastToParser()); | TfliteNodeRegister g_tfliteBroadcastToParser("BroadcastTo", new TfliteBroadcastToParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -23,8 +23,7 @@ | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| namespace mindspore::lite { | |||||
| class TfliteBroadcastToParser : public TfliteNodeParser { | class TfliteBroadcastToParser : public TfliteNodeParser { | ||||
| public: | public: | ||||
| TfliteBroadcastToParser() : TfliteNodeParser("BroadcastTo") {} | TfliteBroadcastToParser() : TfliteNodeParser("BroadcastTo") {} | ||||
| @@ -32,8 +31,10 @@ class TfliteBroadcastToParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| } // namespace mindspore::lite | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H | #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H | ||||
| @@ -63,6 +63,32 @@ STATUS TfliteCastParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteCastParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; | |||||
| if (in_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "tensor is null"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->srcT = GetTfliteDataType(in_tensor->type); | |||||
| const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]]; | |||||
| if (out_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "tensor is null"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->dstT = GetTfliteDataType(out_tensor->type); | |||||
| primitive->value.type = schema::PrimitiveType_Cast; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteCastParser("Cast", new TfliteCastParser()); | TfliteNodeRegister g_tfliteCastParser("Cast", new TfliteCastParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -32,6 +32,8 @@ class TfliteCastParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -60,6 +60,26 @@ STATUS TfliteConcatParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteConcatParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| std::unique_ptr<schema::ConcatT> attr = std::make_unique<schema::ConcatT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &tfliteAttr = tflite_op->builtin_options.AsConcatenationOptions(); | |||||
| if (tfliteAttr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op concat attr failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->axis = tfliteAttr->axis; | |||||
| attr->n = tflite_op->inputs.size(); | |||||
| primitive->value.type = schema::PrimitiveType_Concat; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteConcatParser("Concat", new TfliteConcatParser()); | TfliteNodeRegister g_tfliteConcatParser("Concat", new TfliteConcatParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -32,6 +32,8 @@ class TfliteConcatParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -18,8 +18,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| namespace mindspore::lite { | |||||
| STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | ||||
| @@ -74,7 +73,7 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu | |||||
| // calculate pad params | // calculate pad params | ||||
| auto data_index = tflite_op->inputs[0]; | auto data_index = tflite_op->inputs[0]; | ||||
| const auto &data_tensor = tflite_subgraph->tensors[data_index]; | const auto &data_tensor = tflite_subgraph->tensors[data_index]; | ||||
| std::vector<int> params; | |||||
| std::vector<int64_t> params; | |||||
| int status = | int status = | ||||
| getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); | getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); | ||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | if (status != RET_OK && status != RET_NO_CHANGE) { | ||||
| @@ -96,7 +95,63 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| lite::PrimitiveC *TfliteConvParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| const auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||||
| std::unique_ptr<schema::Conv2DT> attr = std::make_unique<schema::Conv2DT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsConv2DOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get conv attr failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->group = 1; | |||||
| attr->strideW = tflite_attr->stride_w; | |||||
| attr->strideH = tflite_attr->stride_h; | |||||
| attr->dilateH = tflite_attr->dilation_h_factor; | |||||
| attr->dilateW = tflite_attr->dilation_w_factor; | |||||
| attr->padMode = GetPadMode(tflite_attr->padding); | |||||
| attr->format = schema::Format::Format_NHWC; | |||||
| attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); | |||||
| attr->hasBias = true; | |||||
| // get the conv op weight tensor | |||||
| auto weight_index = tflite_op->inputs[1]; | |||||
| const auto &weight_tensor = tflite_subgraph->tensors[weight_index]; | |||||
| if (weight_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "the weight tensor is null"; | |||||
| return nullptr; | |||||
| } | |||||
| auto weight_shape = weight_tensor->shape; | |||||
| attr->channelIn = weight_shape[3]; | |||||
| attr->channelOut = weight_shape[0]; | |||||
| attr->kernelH = weight_shape[1]; | |||||
| attr->kernelW = weight_shape[2]; | |||||
| // calculate pad params | |||||
| auto data_index = tflite_op->inputs[0]; | |||||
| const auto &data_tensor = tflite_subgraph->tensors[data_index]; | |||||
| std::vector<int64_t> params; | |||||
| int status = | |||||
| getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "get padding params failed"; | |||||
| return nullptr; | |||||
| } else if (status == RET_OK) { | |||||
| attr->padUp = params.at(0); | |||||
| attr->padDown = params.at(1); | |||||
| attr->padLeft = params.at(2); | |||||
| attr->padRight = params.at(3); | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| primitive->value.type = schema::PrimitiveType_Conv2D; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteConv2DParser("Conv2D", new TfliteConvParser()); | TfliteNodeRegister g_tfliteConv2DParser("Conv2D", new TfliteConvParser()); | ||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| } // namespace mindspore::lite | |||||
| @@ -23,8 +23,7 @@ | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| namespace mindspore::lite { | |||||
| class TfliteConvParser : public TfliteNodeParser { | class TfliteConvParser : public TfliteNodeParser { | ||||
| public: | public: | ||||
| TfliteConvParser() : TfliteNodeParser("Conv2D") {} | TfliteConvParser() : TfliteNodeParser("Conv2D") {} | ||||
| @@ -32,8 +31,9 @@ class TfliteConvParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| } // namespace mindspore::lite | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H | #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H | ||||
| @@ -15,9 +15,8 @@ | |||||
| */ | */ | ||||
| #include "tools/converter/parser/tflite/tflite_converter.h" | #include "tools/converter/parser/tflite/tflite_converter.h" | ||||
| #include "tools/converter/parser/tflite/tflite_model_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| namespace mindspore::lite { | |||||
| TfliteConverter::TfliteConverter() { modelParser = new TfliteModelParser(); } | TfliteConverter::TfliteConverter() { modelParser = new TfliteModelParser(); } | ||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| } // namespace mindspore::lite | |||||
| @@ -21,18 +21,15 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <map> | #include <map> | ||||
| #include "tools/converter/converter.h" | #include "tools/converter/converter.h" | ||||
| #include "tools/converter/parser/tflite/tflite_model_parser.h" | |||||
| #include "tools/converter/graphdef_transform.h" | #include "tools/converter/graphdef_transform.h" | ||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| namespace mindspore::lite { | |||||
| class TfliteConverter : public Converter { | class TfliteConverter : public Converter { | ||||
| public: | public: | ||||
| TfliteConverter(); | TfliteConverter(); | ||||
| ~TfliteConverter() override = default; | ~TfliteConverter() override = default; | ||||
| }; | }; | ||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| } // namespace mindspore::lite | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONVERTER_H_ | #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONVERTER_H_ | ||||
| @@ -271,6 +271,48 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| } | } | ||||
| return status; | return status; | ||||
| } | } | ||||
| PrimitiveC *TfliteCustomParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||||
| auto op = new schema::CNodeT; | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &custom_attr = tflite_op->custom_options; | |||||
| const auto &opcode_index = tflite_op->opcode_index; | |||||
| const auto &custom_type = tflite_model->operator_codes[opcode_index]->custom_code; | |||||
| int status = RET_OK; | |||||
| if (custom_type == "TFLite_Detection_PostProcess") { | |||||
| status = DetectPostProcess(custom_attr, op, tflite_op); | |||||
| } else if (custom_type == "Predict") { | |||||
| status = Predict(custom_attr, op, tflite_op); | |||||
| } else if (custom_type == "Normalize") { | |||||
| status = Normalize(custom_attr, op, tflite_op); | |||||
| } else if (custom_type == "ExtractFeatures") { | |||||
| status = ExtractFeatures(custom_attr, op, tflite_op); | |||||
| } else if (custom_type == "AudioSpectrogram") { | |||||
| status = AudioSpectrogram(custom_attr, op, tflite_op); | |||||
| } else if (custom_type == "Mfcc") { | |||||
| status = Mfcc(custom_attr, op, tflite_op); | |||||
| } else if (custom_type == "FlexRFFT") { | |||||
| status = Rfft(custom_attr, op, tflite_op, tflite_model, tflite_subgraph); | |||||
| } else if (custom_type == "FlexReal") { | |||||
| status = FftReal(custom_attr, op, tflite_op); | |||||
| } else if (custom_type == "FlexImag") { | |||||
| status = FftImag(custom_attr, op, tflite_op); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "the custom op hasn't been supported now"; | |||||
| status = RET_NOT_FIND_OP; | |||||
| } | |||||
| if (status != RET_OK) { | |||||
| return nullptr; | |||||
| } | |||||
| auto primitive = op->primitive.release(); | |||||
| delete op; | |||||
| return PrimitiveC::Create(primitive); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteCustomParser("Custom", new TfliteCustomParser()); | TfliteNodeRegister g_tfliteCustomParser("Custom", new TfliteCustomParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -31,6 +31,8 @@ class TfliteCustomParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| static STATUS DetectPostProcess(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, | static STATUS DetectPostProcess(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op); | const std::unique_ptr<tflite::OperatorT> &tflite_op); | ||||
| @@ -75,7 +75,7 @@ STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| // calculate pad params | // calculate pad params | ||||
| auto data_index = tflite_op->inputs[2]; | auto data_index = tflite_op->inputs[2]; | ||||
| const auto &data_tensor = tflite_subgraph->tensors[data_index]; | const auto &data_tensor = tflite_subgraph->tensors[data_index]; | ||||
| std::vector<int> params; | |||||
| std::vector<int64_t> params; | |||||
| int status = | int status = | ||||
| getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); | getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); | ||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | if (status != RET_OK && status != RET_NO_CHANGE) { | ||||
| @@ -96,6 +96,64 @@ STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteDeConvParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||||
| std::unique_ptr<schema::DeConv2DT> attr = std::make_unique<schema::DeConv2DT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsTransposeConvOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op deconv attr failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->group = 1; | |||||
| attr->strideW = tflite_attr->stride_w; | |||||
| attr->strideH = tflite_attr->stride_h; | |||||
| attr->dilateH = 1; | |||||
| attr->dilateW = 1; | |||||
| attr->padMode = GetPadMode(tflite_attr->padding); | |||||
| attr->format = schema::Format::Format_NHWC; | |||||
| attr->activationType = schema::ActivationType_NO_ACTIVATION; | |||||
| attr->hasBias = true; | |||||
| // get the conv op weight tensor | |||||
| auto weight_index = tflite_op->inputs[1]; | |||||
| const auto &weight_tensor = tflite_subgraph->tensors[weight_index]; | |||||
| if (weight_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "the weight tensor is null"; | |||||
| return nullptr; | |||||
| } | |||||
| auto weight_shape = weight_tensor->shape; | |||||
| attr->channelIn = weight_shape[3]; | |||||
| attr->channelOut = weight_shape[0]; | |||||
| attr->kernelH = weight_shape[1]; | |||||
| attr->kernelW = weight_shape[2]; | |||||
| // calculate pad params | |||||
| auto data_index = tflite_op->inputs[2]; | |||||
| const auto &data_tensor = tflite_subgraph->tensors[data_index]; | |||||
| std::vector<int64_t> params; | |||||
| int status = | |||||
| getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "get padding params failed"; | |||||
| return nullptr; | |||||
| } else if (status == RET_OK) { | |||||
| attr->padUp = params.at(0); | |||||
| attr->padDown = params.at(1); | |||||
| attr->padLeft = params.at(2); | |||||
| attr->padRight = params.at(3); | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_DeConv2D; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteDeConv2DParser("DeConv2D", new TfliteDeConvParser()); | TfliteNodeRegister g_tfliteDeConv2DParser("DeConv2D", new TfliteDeConvParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -32,6 +32,8 @@ class TfliteDeConvParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -60,6 +60,26 @@ STATUS TfliteDepthToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteDepthToSpaceParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| std::unique_ptr<schema::DepthToSpaceT> attr = std::make_unique<schema::DepthToSpaceT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsDepthToSpaceOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op depthtospace attr failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->blockSize = tflite_attr->block_size; | |||||
| attr->format = schema::Format::Format_NHWC; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| primitive->value.type = schema::PrimitiveType_Concat; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteDepthToSpaceParser("DepthToSpace", new TfliteDepthToSpaceParser()); | TfliteNodeRegister g_tfliteDepthToSpaceParser("DepthToSpace", new TfliteDepthToSpaceParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -32,6 +32,8 @@ class TfliteDepthToSpaceParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -18,8 +18,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| namespace mindspore::lite { | |||||
| STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info, | STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| @@ -82,7 +81,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| attr->kernelW = weight_shape[2]; | attr->kernelW = weight_shape[2]; | ||||
| // calculate pad params | // calculate pad params | ||||
| std::vector<int> params; | |||||
| std::vector<int64_t> params; | |||||
| int status = | int status = | ||||
| getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); | getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); | ||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | if (status != RET_OK && status != RET_NO_CHANGE) { | ||||
| @@ -104,7 +103,71 @@ STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| lite::PrimitiveC *TfliteDepthwiseConv2DParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser"; | |||||
| std::unique_ptr<schema::DepthwiseConv2DT> attr = std::make_unique<schema::DepthwiseConv2DT>(); | |||||
| const auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsDepthwiseConv2DOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op de attr failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->strideW = tflite_attr->stride_w; | |||||
| attr->strideH = tflite_attr->stride_h; | |||||
| attr->dilateH = tflite_attr->dilation_h_factor; | |||||
| attr->dilateW = tflite_attr->dilation_w_factor; | |||||
| attr->padMode = GetPadMode(tflite_attr->padding); | |||||
| attr->format = schema::Format::Format_NHWC; | |||||
| attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); | |||||
| attr->hasBias = true; | |||||
| attr->channelMultiplier = tflite_attr->depth_multiplier; | |||||
| // get the data tensor | |||||
| auto data_index = tflite_op->inputs[1]; | |||||
| const auto &data_tensor = tflite_subgraph->tensors[data_index]; | |||||
| if (data_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "the data tensor is null"; | |||||
| return nullptr; | |||||
| } | |||||
| auto data_shape = data_tensor->shape; | |||||
| attr->channelIn = data_shape[3]; | |||||
| // get the weight tensor | |||||
| auto weight_index = tflite_op->inputs[1]; | |||||
| const auto &weight_tensor = tflite_subgraph->tensors[weight_index]; | |||||
| if (weight_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "the weight tensor is null"; | |||||
| return nullptr; | |||||
| } | |||||
| auto weight_shape = weight_tensor->shape; | |||||
| attr->kernelH = weight_shape[1]; | |||||
| attr->kernelW = weight_shape[2]; | |||||
| // calculate pad params | |||||
| std::vector<int64_t> params; | |||||
| int status = | |||||
| getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "get padding params failed"; | |||||
| return nullptr; | |||||
| } else if (status == RET_OK) { | |||||
| attr->padUp = params.at(0); | |||||
| attr->padDown = params.at(1); | |||||
| attr->padLeft = params.at(2); | |||||
| attr->padRight = params.at(3); | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteDepthwiseConv2DParser("DepthwiseConv2D", new TfliteDepthwiseConv2DParser()); | TfliteNodeRegister g_tfliteDepthwiseConv2DParser("DepthwiseConv2D", new TfliteDepthwiseConv2DParser()); | ||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| } // namespace mindspore::lite | |||||
| @@ -23,8 +23,7 @@ | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| namespace mindspore::lite { | |||||
| class TfliteDepthwiseConv2DParser : public TfliteNodeParser { | class TfliteDepthwiseConv2DParser : public TfliteNodeParser { | ||||
| public: | public: | ||||
| TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {} | TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {} | ||||
| @@ -32,8 +31,10 @@ class TfliteDepthwiseConv2DParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| } // namespace mindspore::lite | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H | #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H | ||||
| @@ -75,6 +75,45 @@ STATUS TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteDequantizeParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; | |||||
| if (in_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "input tensor is null"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]]; | |||||
| if (out_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "output tensor is null"; | |||||
| return nullptr; | |||||
| } | |||||
| if (GetTfliteDataType(in_tensor->type) != GetTfliteDataType(out_tensor->type) && | |||||
| (GetTfliteDataType(in_tensor->type) == kNumberTypeInt8 || | |||||
| GetTfliteDataType(in_tensor->type) == kNumberTypeUInt8)) { | |||||
| std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->srcT = GetTfliteDataType(in_tensor->type); | |||||
| attr->dstT = GetTfliteDataType(out_tensor->type); | |||||
| primitive->value.value = attr.release(); | |||||
| primitive->value.type = schema::PrimitiveType_QuantDTypeCast; | |||||
| } else { | |||||
| std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->srcT = GetTfliteDataType(in_tensor->type); | |||||
| attr->dstT = GetTfliteDataType(out_tensor->type); | |||||
| primitive->value.value = attr.release(); | |||||
| primitive->value.type = schema::PrimitiveType_Cast; | |||||
| } | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteDequantizeParser("DEQUANTIZE", new TfliteDequantizeParser()); | TfliteNodeRegister g_tfliteDequantizeParser("DEQUANTIZE", new TfliteDequantizeParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -31,6 +31,9 @@ class TfliteDequantizeParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -56,6 +56,30 @@ STATUS TfliteExpandDimsParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteExpandDimsParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| std::unique_ptr<schema::ExpandDimsT> attr = std::make_unique<schema::ExpandDimsT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| std::vector<int> dims; | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, dims)) { | |||||
| MS_LOG(ERROR) << "get expand_dims -> dim failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->dim = dims[0]; | |||||
| primitive->value.type = schema::PrimitiveType_ExpandDims; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteExpandDimsParser("ExpandDims", new TfliteExpandDimsParser()); | TfliteNodeRegister g_tfliteExpandDimsParser("ExpandDims", new TfliteExpandDimsParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -32,6 +32,8 @@ class TfliteExpandDimsParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -57,6 +57,32 @@ STATUS TfliteFillParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteFillParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| std::unique_ptr<schema::FillT> attr = std::make_unique<schema::FillT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| if (tflite_op->inputs.size() > 1) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->dims)) { | |||||
| MS_LOG(ERROR) << "get fill -> dims failed"; | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Fill; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteFillParser("Fill", new TfliteFillParser()); | TfliteNodeRegister g_tfliteFillParser("Fill", new TfliteFillParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -32,6 +32,9 @@ class TfliteFillParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -69,6 +69,37 @@ STATUS TfliteFullyConnectedParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteFullyConnectedParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| std::unique_ptr<schema::FullConnectionT> attr = std::make_unique<schema::FullConnectionT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsFullyConnectedOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op fully connect attr failed"; | |||||
| return nullptr; | |||||
| } | |||||
| bool hasBias = tflite_op->inputs.size() > 2 && tflite_op->inputs[2] != -1; | |||||
| attr->hasBias = hasBias; | |||||
| attr->axis = 1; | |||||
| attr->useAxis = false; | |||||
| attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); | |||||
| primitive->value.type = schema::PrimitiveType_FullConnection; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteFullyConnectedParser("FullyConnected", new TfliteFullyConnectedParser()); | TfliteNodeRegister g_tfliteFullyConnectedParser("FullyConnected", new TfliteFullyConnectedParser()); | ||||
| TfliteNodeRegister g_tfliteFakeQuantParser("FakeQuant", new TfliteFullyConnectedParser()); | TfliteNodeRegister g_tfliteFakeQuantParser("FakeQuant", new TfliteFullyConnectedParser()); | ||||
| @@ -32,7 +32,10 @@ class TfliteFullyConnectedParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -54,6 +54,26 @@ STATUS TfliteGatherNdParser::Parse(TfliteTensorsInfo *tensors_info, const std::u | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteGatherNdParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| std::unique_ptr<schema::GatherNdT> attr = std::make_unique<schema::GatherNdT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->batchDims = 0; | |||||
| primitive->value.type = schema::PrimitiveType_GatherNd; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteGatherNdParser("GatherND", new TfliteGatherNdParser()); | TfliteNodeRegister g_tfliteGatherNdParser("GatherND", new TfliteGatherNdParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -32,6 +32,8 @@ class TfliteGatherNdParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -60,6 +60,32 @@ STATUS TfliteGatherParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteGatherParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| std::unique_ptr<schema::GatherT> attr = std::make_unique<schema::GatherT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsGatherOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op gather attr failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->axis = tflite_attr->axis; | |||||
| attr->batchDims = 0; | |||||
| primitive->value.type = schema::PrimitiveType_Gather; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteGatherParser("Gather", new TfliteGatherParser()); | TfliteNodeRegister g_tfliteGatherParser("Gather", new TfliteGatherParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -32,6 +32,8 @@ class TfliteGatherParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -55,6 +55,24 @@ STATUS TfliteHashtableLookupParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteHashtableLookupParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| std::unique_ptr<schema::HashtableLookupT> attr = std::make_unique<schema::HashtableLookupT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_HashtableLookup; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteHashtableLookupParser("HashtableLookup", new TfliteHashtableLookupParser()); | TfliteNodeRegister g_tfliteHashtableLookupParser("HashtableLookup", new TfliteHashtableLookupParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -32,6 +32,8 @@ class TfliteHashtableLookupParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -55,6 +55,27 @@ STATUS TfliteL2NormParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteL2NormParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| std::unique_ptr<schema::L2NormT> attr = std::make_unique<schema::L2NormT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsL2NormOptions(); | |||||
| attr->axis = {-1}; | |||||
| attr->epsilon = 1e-6f; | |||||
| attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_L2Norm; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteL2NormParser("L2_NORMALIZATION", new TfliteL2NormParser()); | TfliteNodeRegister g_tfliteL2NormParser("L2_NORMALIZATION", new TfliteL2NormParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -32,6 +32,8 @@ class TfliteL2NormParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -78,6 +78,44 @@ STATUS TfliteLogicalParser::Parse(TfliteTensorsInfo *tensors_info, const std::un | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteLogicalParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; | |||||
| if (tflite_op_type == tflite::BuiltinOperator_LOGICAL_AND) { | |||||
| MS_LOG(DEBUG) << "parse TfliteLogicalAndParser"; | |||||
| std::unique_ptr<schema::LogicalAndT> attr = std::make_unique<schema::LogicalAndT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_LogicalAnd; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_LOGICAL_NOT) { | |||||
| MS_LOG(DEBUG) << "parse TfliteLogicalNotParser"; | |||||
| std::unique_ptr<schema::LogicalNotT> attr = std::make_unique<schema::LogicalNotT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_LogicalNot; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_LOGICAL_OR) { | |||||
| MS_LOG(DEBUG) << "parse TfliteLogicalOrParser"; | |||||
| std::unique_ptr<schema::LogicalOrT> attr = std::make_unique<schema::LogicalOrT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_LogicalOr; | |||||
| primitive->value.value = attr.release(); | |||||
| } | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteLogicalAndParser("LogicalAnd", new TfliteLogicalParser()); | TfliteNodeRegister g_tfliteLogicalAndParser("LogicalAnd", new TfliteLogicalParser()); | ||||
| TfliteNodeRegister g_tfliteLogicalNotParser("LogicalNot", new TfliteLogicalParser()); | TfliteNodeRegister g_tfliteLogicalNotParser("LogicalNot", new TfliteLogicalParser()); | ||||
| @@ -23,8 +23,7 @@ | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| namespace mindspore::lite { | |||||
| class TfliteLogicalParser : public TfliteNodeParser { | class TfliteLogicalParser : public TfliteNodeParser { | ||||
| public: | public: | ||||
| TfliteLogicalParser() : TfliteNodeParser("node_name") {} | TfliteLogicalParser() : TfliteNodeParser("node_name") {} | ||||
| @@ -32,8 +31,9 @@ class TfliteLogicalParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| } // namespace mindspore::lite | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LOGICAL_PARSER_H | #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LOGICAL_PARSER_H | ||||
| @@ -60,6 +60,34 @@ STATUS TfliteLRNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteLRNParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| std::unique_ptr<schema::LocalResponseNormalizationT> attr = std::make_unique<schema::LocalResponseNormalizationT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsLocalResponseNormalizationOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op LRN attr failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->depth_radius = tflite_attr->radius; | |||||
| attr->alpha = tflite_attr->alpha; | |||||
| attr->beta = tflite_attr->beta; | |||||
| attr->bias = tflite_attr->bias; | |||||
| primitive->value.type = schema::PrimitiveType_LocalResponseNormalization; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteLRNParser("LocalResponseNorm", new TfliteLRNParser()); | TfliteNodeRegister g_tfliteLRNParser("LocalResponseNorm", new TfliteLRNParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -32,6 +32,8 @@ class TfliteLRNParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -64,6 +64,35 @@ STATUS TfliteLshProjectionParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteLshProjectionParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| std::unique_ptr<schema::LshProjectionT> attr = std::make_unique<schema::LshProjectionT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsLSHProjectionOptions(); | |||||
| switch (tflite_attr->type) { | |||||
| case tflite::LSHProjectionType_SPARSE: | |||||
| attr->type = schema::LshProjectionType_SPARSE; | |||||
| break; | |||||
| case tflite::LSHProjectionType_DENSE: | |||||
| attr->type = schema::LshProjectionType_DENSE; | |||||
| break; | |||||
| default: | |||||
| attr->type = schema::LshProjectionType_UNKNOWN; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_LshProjection; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteLshProjectionParser("LshProjection", new TfliteLshProjectionParser()); | TfliteNodeRegister g_tfliteLshProjectionParser("LshProjection", new TfliteLshProjectionParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -32,6 +32,8 @@ class TfliteLshProjectionParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -13,79 +13,167 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "tools/converter/parser/tflite/tflite_model_parser.h" | #include "tools/converter/parser/tflite/tflite_model_parser.h" | ||||
| #include <utility> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | #include <vector> | ||||
| #include <set> | |||||
| #include "tools/common/graph_util.h" | |||||
| #include "tools/common/storage.h" | |||||
| #include "flatbuffers/flatbuffers.h" | |||||
| #include <memory> | |||||
| #include <algorithm> | |||||
| #include <utility> | |||||
| #include "src/param_value_lite.h" | |||||
| #include "src/common/file_utils.h" | #include "src/common/file_utils.h" | ||||
| #include "tools/common/node_util.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| TfliteModelParser::TfliteModelParser() = default; | |||||
| TfliteModelParser::~TfliteModelParser() { delete[](this->tfliteModelBuf); } | |||||
| namespace mindspore::lite { | |||||
| std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const char *model_path) { | std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const char *model_path) { | ||||
| size_t size = 0; | size_t size = 0; | ||||
| tfliteModelBuf = ReadFile(model_path, &size); | |||||
| if (tfliteModelBuf == nullptr) { | |||||
| tflite_model_buf_ = ReadFile(model_path, &size); | |||||
| if (tflite_model_buf_ == nullptr) { | |||||
| MS_LOG(ERROR) << "the file buffer is nullptr"; | MS_LOG(ERROR) << "the file buffer is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| flatbuffers::Verifier verify((const uint8_t *)tfliteModelBuf, size); | |||||
| flatbuffers::Verifier verify((const uint8_t *)tflite_model_buf_, size); | |||||
| if (!tflite::VerifyModelBuffer(verify)) { | if (!tflite::VerifyModelBuffer(verify)) { | ||||
| MS_LOG(ERROR) << "the buffer is invalid and fail to create graph"; | MS_LOG(ERROR) << "the buffer is invalid and fail to create graph"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return tflite::UnPackModel(tfliteModelBuf); | |||||
| return tflite::UnPackModel(tflite_model_buf_); | |||||
| } | } | ||||
| STATUS TfliteModelParser::CopyConstTensorData(const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| const tflite::TensorT *tflite_tensor, schema::TensorT *tensor) { | |||||
| MS_ASSERT(tensor != nullptr); | |||||
| MS_ASSERT(tflite_tensor != nullptr); | |||||
| auto buffer_idx = tflite_tensor->buffer; | |||||
| FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) { | |||||
| // load graph | |||||
| tflite_model_ = ReadTfliteModel(model_file.c_str()); | |||||
| if (tflite_model_ == nullptr) { | |||||
| MS_LOG(ERROR) << "read tflite model failed"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); | |||||
| return nullptr; | |||||
| } | |||||
| const auto &buf = tflite_model_buffer[buffer_idx]; | |||||
| if (buf == nullptr) { | |||||
| MS_LOG(ERROR) << "tensor is null"; | |||||
| return RET_NULL_PTR; | |||||
| if (tflite_model_->subgraphs.size() != 1) { | |||||
| MS_LOG(ERROR) << "read tflite model subgraphs failed"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); | |||||
| return nullptr; | |||||
| } | } | ||||
| func_graph_ = std::make_shared<FuncGraph>(); | |||||
| if (!buf->data.empty()) { | |||||
| auto data_size = buf->data.size(); | |||||
| tensor->data.resize(data_size); | |||||
| if (memcpy_s(tensor->data.data(), data_size, buf->data.data(), data_size) != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy tensor data failed"; | |||||
| return RET_MEMORY_FAILED; | |||||
| } | |||||
| } else { | |||||
| MS_LOG(ERROR) << "src tensor data is empty"; | |||||
| return RET_INPUT_TENSOR_ERROR; | |||||
| auto status = ConvertGraphInputs(); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Convert graph inputs failed."; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| return nullptr; | |||||
| } | } | ||||
| return RET_OK; | |||||
| status = ConvertOps(); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Convert ops failed."; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| return nullptr; | |||||
| } | |||||
| status = ConvertGraphOutputs(); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Convert graph outputs failed."; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| return nullptr; | |||||
| } | |||||
| return func_graph_; | |||||
| } | } | ||||
| void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::TensorT> &tflite_tensor, | |||||
| schema::TensorT *tensor) { | |||||
| MS_ASSERT(tensor != nullptr); | |||||
| tensor->quantParams.clear(); | |||||
| STATUS TfliteModelParser::ConvertOps() { | |||||
| const auto &tflite_subgraph = tflite_model_->subgraphs.front(); | |||||
| const auto &tflite_model_buffers = tflite_model_->buffers; | |||||
| NoSupportOp::GetInstance()->SetFmkType("TFLITE"); | |||||
| STATUS status = RET_OK; | |||||
| int op_idx = 0; | |||||
| for (auto &op : tflite_subgraph->operators) { | |||||
| auto tfliteOpType = (tflite_model_->operator_codes[op->opcode_index])->builtin_code; | |||||
| auto op_type = GetMSOpType(tfliteOpType); | |||||
| auto op_name = op_type + "-" + std::to_string(op_idx); | |||||
| op_idx++; | |||||
| // parse primitive | |||||
| auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(op_type); | |||||
| if (node_parser == nullptr) { | |||||
| NoSupportOp::GetInstance()->InsertOp(op_type); | |||||
| status = (status == RET_OK ? RET_NOT_FIND_OP : status); | |||||
| continue; | |||||
| } | |||||
| if (status != RET_OK) { | |||||
| continue; | |||||
| } | |||||
| auto primitiveC = node_parser->ParseLitePrimitive(op, tflite_model_); | |||||
| if (primitiveC == nullptr) { | |||||
| MS_LOG(ERROR) << "parse node " << op_type.c_str() << " parser failed"; | |||||
| continue; | |||||
| } | |||||
| status = ConvertOpQuantParams(op.get(), primitiveC); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "convert " << op_name << " quant param failed."; | |||||
| return status; | |||||
| } | |||||
| if (tflite_tensor->quantization == nullptr) { | |||||
| MS_LOG(ERROR) << "tflite_tensor->quantization is null"; | |||||
| return; | |||||
| std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<lite::PrimitiveC>(primitiveC))}; | |||||
| // parse inputs | |||||
| for (auto input_idx : op->inputs) { | |||||
| if (input_idx < 0) { | |||||
| input_idx += tflite_subgraph->tensors.size(); | |||||
| } | |||||
| const auto &input_tensor = tflite_subgraph->tensors[input_idx]; | |||||
| if (nodes_.find(input_idx) != nodes_.end()) { | |||||
| op_inputs.emplace_back(nodes_.at(input_idx)); | |||||
| continue; | |||||
| } | |||||
| // const tensor | |||||
| if (!tflite_model_buffers.at(input_tensor->buffer)->data.empty()) { | |||||
| auto parameter = func_graph_->add_parameter(); | |||||
| status = ConvertConstTensor(input_tensor.get(), parameter.get()); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed."; | |||||
| return status; | |||||
| } | |||||
| op_inputs.emplace_back(parameter); | |||||
| nodes_.insert(std::pair(input_idx, parameter)); | |||||
| continue; | |||||
| } | |||||
| MS_LOG(WARNING) << "tensor " << input_idx << " is neither a node output nor a weight tensor."; | |||||
| } | |||||
| auto new_cnode = func_graph_->NewCNode(op_inputs); | |||||
| new_cnode->set_fullname_with_scope(op_name); | |||||
| // parse outputs | |||||
| status = ConvertOutputTensor(op.get(), new_cnode); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Convert output tensors for " << new_cnode->fullname_with_scope() << " failed."; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| return status; | |||||
| } | |||||
| } | |||||
| return status; | |||||
| } | |||||
| STATUS TfliteModelParser::SetTensorQuantParam(const tflite::TensorT *tflite_tensor, | |||||
| std::vector<QuantParamT> *quant_params) { | |||||
| if (tflite_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "tflite_tensor is null, set tensor quant params failed."; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| quant_params->clear(); | |||||
| if (tflite_tensor->quantization == nullptr || | |||||
| (tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() && | |||||
| tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty())) { | |||||
| std::vector<schema::QuantParamT> notinited_quant_params(1); | |||||
| *quant_params = notinited_quant_params; | |||||
| return RET_OK; | |||||
| } | } | ||||
| for (size_t i = 0; i < tflite_tensor->quantization->scale.size(); i++) { | for (size_t i = 0; i < tflite_tensor->quantization->scale.size(); i++) { | ||||
| std::unique_ptr<schema::QuantParamT> quant_param = std::make_unique<QuantParamT>(); | std::unique_ptr<schema::QuantParamT> quant_param = std::make_unique<QuantParamT>(); | ||||
| if (quant_param == nullptr) { | if (quant_param == nullptr) { | ||||
| MS_LOG(ERROR) << "quant_param is null"; | MS_LOG(ERROR) << "quant_param is null"; | ||||
| return; | |||||
| return RET_NULL_PTR; | |||||
| } | } | ||||
| if (!tflite_tensor->quantization->scale.empty()) { | if (!tflite_tensor->quantization->scale.empty()) { | ||||
| @@ -104,364 +192,219 @@ void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::Tensor | |||||
| quant_param->max = tflite_tensor->quantization->max[i]; | quant_param->max = tflite_tensor->quantization->max[i]; | ||||
| } | } | ||||
| quant_param->inited = true; | quant_param->inited = true; | ||||
| tensor->quantParams.emplace_back(std::move(quant_param)); | |||||
| quant_params->emplace_back(*std::move(quant_param)); | |||||
| } | } | ||||
| return RET_OK; | |||||
| } | } | ||||
| STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||||
| const QuantType &quant_type, schema::MetaGraphT *sub_graph) { | |||||
| MS_ASSERT(tflite_model != nullptr); | |||||
| MS_ASSERT(tflite_subgraph != nullptr); | |||||
| MS_ASSERT(sub_graph != nullptr); | |||||
| STATUS TfliteModelParser::ConvertOpQuantParams(const tflite::OperatorT *op, lite::PrimitiveC *primitive_c) { | |||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "tflite op is null, get quant params failed."; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| int idx = 0; | |||||
| int status = RET_OK; | |||||
| NoSupportOp::GetInstance()->SetFmkType("TFLITE"); | |||||
| for (const auto &tflite_op : tflite_subgraph->operators) { | |||||
| const auto opcode_index = tflite_op->opcode_index; | |||||
| const auto &operator_code = tflite_model->operator_codes[opcode_index]; | |||||
| if (operator_code == nullptr) { | |||||
| MS_LOG(ERROR) << "operator_code is null"; | |||||
| return RET_ERROR; | |||||
| if (primitive_c == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive_c is null, get quant params failed."; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| const auto &tflite_subgraph = tflite_model_->subgraphs.front(); | |||||
| for (auto input_idx : op->inputs) { | |||||
| if (input_idx < 0) { | |||||
| input_idx += tflite_subgraph->tensors.size(); | |||||
| } | } | ||||
| auto op_type = GetMSOpType(operator_code->builtin_code); | |||||
| auto op = std::make_unique<schema::CNodeT>(); | |||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| const auto &input_tensor = tflite_subgraph->tensors[input_idx]; | |||||
| std::vector<schema::QuantParamT> quant_params; | |||||
| auto status = SetTensorQuantParam(input_tensor.get(), &quant_params); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "set input tensor quant param failed."; | |||||
| return status; | |||||
| } | } | ||||
| op->name = op_type + "-" + std::to_string(idx++); | |||||
| op->quantType = quant_type; | |||||
| MS_LOG(INFO) << "parse op: " << op->name.c_str(); | |||||
| auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(op_type); | |||||
| if (node_parser == nullptr) { | |||||
| NoSupportOp::GetInstance()->InsertOp(op_type); | |||||
| status = (status == RET_OK ? RET_NOT_FIND_OP : status); | |||||
| continue; | |||||
| primitive_c->AddInputQuantParam(quant_params); | |||||
| } | |||||
| for (auto output_idx : op->outputs) { | |||||
| if (output_idx < 0) { | |||||
| output_idx += tflite_subgraph->tensors.size(); | |||||
| } | } | ||||
| if (status == RET_OK || op_type == "Custom") { | |||||
| int status_node = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, tflite_subgraph, op.get()); | |||||
| status = (status == RET_OK ? status_node : status); | |||||
| if (status_node != RET_OK) { | |||||
| if (status_node == RET_NOT_FIND_OP) { | |||||
| op_type = (op_type != "Custom" ? op_type : operator_code->custom_code); | |||||
| NoSupportOp::GetInstance()->InsertOp(op_type); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed"; | |||||
| } | |||||
| continue; | |||||
| } | |||||
| if (status != RET_OK) { | |||||
| continue; | |||||
| } | |||||
| sub_graph->nodes.emplace_back(op.release()); | |||||
| opMap[sub_graph->nodes.back()->name] = sub_graph->nodes.back().get(); | |||||
| tfliteOpMap[tflite_op.get()] = sub_graph->nodes.back().get(); | |||||
| const auto &output_tensor = tflite_subgraph->tensors.at(output_idx); | |||||
| std::vector<schema::QuantParamT> quant_params; | |||||
| auto status = SetTensorQuantParam(output_tensor.get(), &quant_params); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "set output tensor quant param failed."; | |||||
| return status; | |||||
| } | } | ||||
| primitive_c->AddOutputQuantParam(quant_params); | |||||
| } | } | ||||
| return status; | |||||
| return RET_OK; | |||||
| } | } | ||||
| STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::MetaGraphT *sub_graph) { | |||||
| MS_ASSERT(tflite_subgraph != nullptr); | |||||
| MS_ASSERT(sub_graph != nullptr); | |||||
| std::set<int> output_index; | |||||
| for (const auto &tflite_op : tflite_subgraph->operators) { | |||||
| for (int idx : tflite_op->outputs) { | |||||
| if (idx < 0) { | |||||
| idx += tflite_subgraph->tensors.size(); | |||||
| } | |||||
| output_index.insert(idx); | |||||
| STATUS TfliteModelParser::ConvertGraphInputs() { | |||||
| const auto &tflite_subgraph = tflite_model_->subgraphs.front(); | |||||
| for (auto tflite_graph_input : tflite_subgraph->inputs) { | |||||
| if (tflite_graph_input < 0) { | |||||
| tflite_graph_input = tflite_graph_input + tflite_subgraph->tensors.size(); | |||||
| } | } | ||||
| auto parameter = func_graph_->add_parameter(); | |||||
| const auto &tensor = tflite_subgraph->tensors.at(tflite_graph_input); | |||||
| std::vector<int64_t> shape_vector; | |||||
| (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), | |||||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | |||||
| auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type)); | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||||
| parameter->set_abstract(abstract_tensor); | |||||
| parameter->set_name("graph_input_" + std::to_string(tflite_graph_input) + "_parameter"); | |||||
| nodes_.insert(std::pair(tflite_graph_input, parameter)); | |||||
| } | } | ||||
| for (size_t i = 0; i < tensorsInfo.tensorsId.size(); i++) { | |||||
| auto idx = tensorsInfo.tensorsId[i]; | |||||
| if (idx < 0) { | |||||
| idx += tflite_subgraph->tensors.size(); | |||||
| } | |||||
| const auto &tflite_tensor = tflite_subgraph->tensors[idx]; | |||||
| if (tflite_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "tflite_tensor is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::TensorT> tensor = std::make_unique<schema::TensorT>(); | |||||
| if (tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "tensor is null"; | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS TfliteModelParser::ConvertGraphOutputs() { | |||||
| const auto &tflite_subgraph = tflite_model_->subgraphs.front(); | |||||
| if (tflite_subgraph->outputs.size() > 1) { | |||||
| std::vector<AnfNodePtr> make_tuple_inputs; | |||||
| auto make_tuple_prim_ptr = GetMakeTuplePrim(); | |||||
| if (make_tuple_prim_ptr == nullptr) { | |||||
| MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr"; | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| tensor->format = tensorsInfo.tensorsFormat[i]; | |||||
| tensor->dataType = GetTfliteDataType(tflite_tensor->type); | |||||
| tensor->dims = tflite_tensor->shape; | |||||
| // if graph input tensor | |||||
| bool isInput = false; | |||||
| auto tflite_inputs = tflite_subgraph->inputs; | |||||
| for (int tflite_input : tflite_inputs) { | |||||
| if (idx == tflite_input) { | |||||
| isInput = true; | |||||
| break; | |||||
| auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr); | |||||
| make_tuple_inputs.emplace_back(make_tuple_prim); | |||||
| for (auto outputNode : tflite_subgraph->outputs) { | |||||
| auto cnode = nodes_.at(outputNode); | |||||
| if (nullptr == cnode) { | |||||
| MS_LOG(ERROR) << "Can't find input node."; | |||||
| return RET_NOT_FIND_OP; | |||||
| } | } | ||||
| make_tuple_inputs.emplace_back(cnode); | |||||
| } | } | ||||
| auto make_tuple_cnode = func_graph_->NewCNode(make_tuple_inputs); | |||||
| make_tuple_cnode->set_fullname_with_scope("return tuple"); | |||||
| // add data for const tensor | |||||
| auto &tensor_buffer = tflite_model_buffer.at(tflite_tensor->buffer); | |||||
| if (tensor_buffer == nullptr) { | |||||
| MS_LOG(ERROR) << "tensor_buffer is null"; | |||||
| std::vector<AnfNodePtr> op_inputs; | |||||
| auto return_prim_ptr = GetReturnPrim(); | |||||
| if (return_prim_ptr == nullptr) { | |||||
| MS_LOG(ERROR) << "GetReturnPrim return nullptr"; | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| auto isConst = (!tensor_buffer->data.empty()); | |||||
| if (isConst) { | |||||
| int status = CopyConstTensorData(tflite_model_buffer, tflite_tensor.get(), tensor.get()); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "obtain const tensor failed"; | |||||
| return status; | |||||
| } | |||||
| } | |||||
| // set tensor attr | |||||
| if (isInput || isConst) { | |||||
| tensor->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| } else { | |||||
| if (output_index.find(idx) == output_index.end() && tflite_tensor->shape[0] == 0) { | |||||
| tensor->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| } else { | |||||
| tensor->nodeType = schema::NodeType_Parameter; | |||||
| } | |||||
| auto value_node = NewValueNode(return_prim_ptr); | |||||
| op_inputs.emplace_back(value_node); | |||||
| op_inputs.emplace_back(make_tuple_cnode); | |||||
| auto cnode = func_graph_->NewCNode(op_inputs); | |||||
| cnode->set_fullname_with_scope("return"); | |||||
| func_graph_->set_return(cnode); | |||||
| } else { | |||||
| auto returnPrim = GetReturnPrim(); | |||||
| if (returnPrim == nullptr) { | |||||
| MS_LOG(ERROR) << "GetReturnPrim return nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | } | ||||
| // quant param | |||||
| if (tflite_tensor->quantization != nullptr && | |||||
| !(tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() && | |||||
| tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty())) { | |||||
| SetTensorQuantParam(tflite_tensor, tensor.get()); | |||||
| auto valueNode = NewValueNode(returnPrim); | |||||
| std::vector<AnfNodePtr> op_inputs{valueNode}; | |||||
| auto cnode = nodes_.at(tflite_subgraph->outputs.front()); | |||||
| if (nullptr == cnode) { | |||||
| MS_LOG(ERROR) << "Can't find input node."; | |||||
| return RET_NOT_FIND_OP; | |||||
| } | } | ||||
| tensors.push_back(tensor.release()); | |||||
| } | |||||
| for (auto iter : tensors) { | |||||
| std::unique_ptr<schema::TensorT> temp(iter); | |||||
| sub_graph->allTensors.emplace_back(move(temp)); | |||||
| op_inputs.emplace_back(cnode); | |||||
| auto returnCnode = func_graph_->NewCNode(op_inputs); | |||||
| returnCnode->set_fullname_with_scope("return"); | |||||
| func_graph_->set_return(returnCnode); | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS TfliteModelParser::GetGraphInfo(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||||
| schema::MetaGraphT *sub_graph) { | |||||
| MS_ASSERT(sub_graph != nullptr); | |||||
| MS_ASSERT(tflite_subgraph != nullptr); | |||||
| // graph input | |||||
| std::vector<int> graph_inputs; | |||||
| for (size_t i = 0; i < tflite_subgraph->inputs.size(); i++) { | |||||
| const int idx = tflite_subgraph->inputs[i]; | |||||
| int id = idx < 0 ? idx + tflite_subgraph->tensors.size() : idx; | |||||
| auto iter = tensorsInfo.tensorsIdMap.find(id); | |||||
| if (iter != tensorsInfo.tensorsIdMap.end()) { | |||||
| graph_inputs.push_back(iter->second); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "get graph input failed"; | |||||
| return RET_INPUT_TENSOR_ERROR; | |||||
| } | |||||
| STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Parameter *parameter) { | |||||
| if (tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "tensor is null, get const tensor failed."; | |||||
| return RET_NULL_PTR; | |||||
| } | } | ||||
| sub_graph->inputIndex.assign(graph_inputs.begin(), graph_inputs.end()); | |||||
| // graph output | |||||
| std::vector<int> graph_outputs; | |||||
| for (size_t i = 0; i < tflite_subgraph->outputs.size(); i++) { | |||||
| const int idx = tflite_subgraph->outputs[i]; | |||||
| int id = idx < 0 ? idx + tflite_subgraph->tensors.size() : idx; | |||||
| auto iter = tensorsInfo.tensorsIdMap.find(id); | |||||
| if (iter != tensorsInfo.tensorsIdMap.end()) { | |||||
| graph_outputs.push_back(iter->second); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "get graph output failed"; | |||||
| return RET_INPUT_TENSOR_ERROR; | |||||
| } | |||||
| if (parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "parameter is null, get const tensor failed."; | |||||
| return RET_NULL_PTR; | |||||
| } | } | ||||
| sub_graph->outputIndex.assign(graph_outputs.begin(), graph_outputs.end()); | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph) { | |||||
| MS_ASSERT(sub_graph != nullptr); | |||||
| for (auto &op : sub_graph->nodes) { | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (op->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { | |||||
| auto attr = op->primitive->value.AsDepthwiseConv2D(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "attr is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (attr->channelMultiplier > 1) { | |||||
| // get channel attr | |||||
| if (op->inputIndex.empty()) { | |||||
| MS_LOG(ERROR) << "the input of DepthwiseConv2D is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| const auto data_id = op->inputIndex[0]; | |||||
| if (sub_graph->allTensors.size() <= data_id) { | |||||
| MS_LOG(ERROR) << "the number of allTensors is less than " << data_id; | |||||
| return RET_ERROR; | |||||
| } | |||||
| const auto &data_tensor = sub_graph->allTensors.at(data_id); | |||||
| if (data_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "the data tensor is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto data_shape = data_tensor->dims; | |||||
| if (data_shape.empty()) { | |||||
| MS_LOG(DEBUG) << "the tensor's shape is dynamic, which obtain only when running"; | |||||
| return RET_NO_CHANGE; | |||||
| } | |||||
| std::unique_ptr<schema::Conv2DT> conv_attr = std::make_unique<schema::Conv2DT>(); | |||||
| if (conv_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "conv_attr is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (data_shape[3] == 1) { | |||||
| conv_attr->channelIn = data_shape[3]; | |||||
| conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier; | |||||
| // update attr | |||||
| conv_attr->group = 1; | |||||
| conv_attr->format = attr->format; | |||||
| conv_attr->kernelH = attr->kernelH; | |||||
| conv_attr->kernelW = attr->kernelW; | |||||
| conv_attr->strideH = attr->strideH; | |||||
| conv_attr->strideW = attr->strideW; | |||||
| conv_attr->padMode = attr->padMode; | |||||
| conv_attr->padUp = attr->padUp; | |||||
| conv_attr->padDown = attr->padDown; | |||||
| conv_attr->padLeft = attr->padLeft; | |||||
| conv_attr->padRight = attr->padRight; | |||||
| conv_attr->dilateH = attr->dilateH; | |||||
| conv_attr->dilateW = attr->dilateW; | |||||
| conv_attr->hasBias = attr->hasBias; | |||||
| conv_attr->activationType = attr->activationType; | |||||
| op->primitive->value.type = schema::PrimitiveType_Conv2D; | |||||
| op->primitive->value.value = conv_attr.release(); | |||||
| // update weight | |||||
| auto weight_id = op->inputIndex[1]; | |||||
| auto &weight_tensor = sub_graph->allTensors.at(weight_id); | |||||
| if (weight_tensor->dataType == TypeId::kNumberTypeUInt8) { | |||||
| auto status = TransFilterFormat<uint8_t>(weight_tensor.get(), kKHWC2CHWK); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Trans depthwiseConv Filter schema::Format failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } else if (weight_tensor->dataType == kNumberTypeInt8) { | |||||
| auto status = TransFilterFormat<int8_t>(weight_tensor.get(), kKHWC2CHWK); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Trans filter format failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } else if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) { | |||||
| auto status = TransFilterFormat<float>(weight_tensor.get(), kKHWC2CHWK); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Trans filter format failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } else { | |||||
| MS_LOG(ERROR) << "The dataType of weight tensor is unsupported."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| weight_tensor->format = schema::Format::Format_CHWK; | |||||
| } | |||||
| } | |||||
| const auto &tfliteModelBuffers = tflite_model_->buffers; | |||||
| auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type)); | |||||
| std::vector<int64_t> shape_vector; | |||||
| (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), | |||||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||||
| parameter->set_abstract(abstract_tensor); | |||||
| parameter->set_name("const_" + std::to_string(nodes_.size()) + "_parameter"); | |||||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||||
| MS_ASSERT(param_value != nullptr); | |||||
| param_value->set_tensor_shape(tensor->shape); | |||||
| param_value->set_tensor_type(GetTfliteDataType(tensor->type)); | |||||
| param_value->set_format(schema::Format::Format_NHWC); | |||||
| const auto &data = tfliteModelBuffers.at(tensor->buffer)->data; | |||||
| if (!data.empty()) { | |||||
| auto size = data.size(); | |||||
| char *tensor_data = new (std::nothrow) char[size]; | |||||
| if (tensor_data == nullptr) { | |||||
| MS_LOG(ERROR) << "new char[] failed"; | |||||
| return RET_MEMORY_FAILED; | |||||
| } | } | ||||
| std::memcpy(tensor_data, data.data(), size); | |||||
| param_value->set_tensor_addr(tensor_data); | |||||
| param_value->set_tensor_size(size); | |||||
| parameter->set_default_param(param_value); | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| std::unique_ptr<schema::MetaGraphT> TfliteModelParser::ConstructMainGraph( | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, const QuantType &quant_type) { | |||||
| MS_ASSERT(tflite_model != nullptr); | |||||
| if (tflite_model->subgraphs.empty()) { | |||||
| MS_LOG(ERROR) << "read tflite model main subgraphs failed"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); | |||||
| return nullptr; | |||||
| } | |||||
| const auto &tflite_subgraph = tflite_model->subgraphs[0]; | |||||
| auto meta_graph = std::make_unique<schema::MetaGraphT>(); | |||||
| if (meta_graph == nullptr) { | |||||
| MS_LOG(ERROR) << "new meta graph failed"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); | |||||
| return nullptr; | |||||
| } | |||||
| meta_graph->name = "MS_model converted by TF-Lite"; | |||||
| quantType = quant_type; | |||||
| // convert op | |||||
| int status = ConvertOp(tflite_model, tflite_subgraph, quant_type, meta_graph.get()); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "parse op failed."; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| return nullptr; | |||||
| } | |||||
| // convert tensor | |||||
| status = ConvertTensor(tflite_subgraph, tflite_model->buffers, meta_graph.get()); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "convert tensor failed"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| return nullptr; | |||||
| STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const CNodePtr &dst_cnode) { | |||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null, get output tensor failed."; | |||||
| return RET_NULL_PTR; | |||||
| } | } | ||||
| // set graph input/output | |||||
| status = GetGraphInfo(tflite_subgraph, meta_graph.get()); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "convert tensors failed"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| return nullptr; | |||||
| if (dst_cnode == nullptr) { | |||||
| MS_LOG(ERROR) << "parameter is null, get output tensor failed."; | |||||
| return RET_NULL_PTR; | |||||
| } | } | ||||
| // update for depthwiseConv | |||||
| status = ConvertGroupDepthwiseOp(meta_graph.get()); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "convert group depthwise conv failed"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| return nullptr; | |||||
| const auto &tflite_subgraph = tflite_model_->subgraphs.front(); | |||||
| if (op->outputs.size() == 1) { | |||||
| int output_idx = | |||||
| op->outputs.front() < 0 ? tflite_subgraph->tensors.size() + op->outputs.front() : op->outputs.front(); | |||||
| const auto &tensor = tflite_subgraph->tensors.at(output_idx); | |||||
| std::vector<int64_t> shape_vector; | |||||
| (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), | |||||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | |||||
| auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type)); | |||||
| dst_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); | |||||
| nodes_.insert(std::pair(op->outputs.front(), dst_cnode)); | |||||
| } else { | |||||
| AbstractBasePtrList abstract_list; | |||||
| int op_idx = 0; | |||||
| for (auto output_idx : op->outputs) { | |||||
| if (output_idx < 0) { | |||||
| output_idx = output_idx + tflite_subgraph->tensors.size(); | |||||
| } | |||||
| const auto &tensor = tflite_subgraph->tensors.at(output_idx); | |||||
| std::vector<int64_t> shape_vector; | |||||
| (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), | |||||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | |||||
| auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type)); | |||||
| abstract_list.emplace_back(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); | |||||
| auto tuple_get_item_prim_ptr = GetTupleGetItemPrim(); | |||||
| if (tuple_get_item_prim_ptr == nullptr) { | |||||
| MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr); | |||||
| auto get_item_value = NewValueNode(MakeValue<int>(op_idx)); | |||||
| std::vector<AnfNodePtr> inputs{tuple_get_item_prim, dst_cnode, get_item_value}; | |||||
| CNodePtr get_item_cnode = func_graph_->NewCNode(inputs); | |||||
| get_item_cnode->set_fullname_with_scope(dst_cnode->fullname_with_scope() + "_getitem_" + std::to_string(op_idx)); | |||||
| nodes_.insert(std::pair(output_idx, get_item_cnode)); | |||||
| op_idx++; | |||||
| } | |||||
| dst_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | |||||
| } | } | ||||
| return meta_graph; | |||||
| return RET_OK; | |||||
| } | } | ||||
| schema::MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) { | |||||
| if (model_file.empty()) { | |||||
| MS_LOG(ERROR) << "model_file is empty"; | |||||
| return nullptr; | |||||
| } | |||||
| // load graph | |||||
| auto tflite_model = ReadTfliteModel(model_file.c_str()); | |||||
| if (tflite_model == nullptr) { | |||||
| MS_LOG(ERROR) << "read tflite model failed"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); | |||||
| return nullptr; | |||||
| } | |||||
| // construct main_meta_graph | |||||
| auto main_meta_graph = ConstructMainGraph(tflite_model, quant_type); | |||||
| if (main_meta_graph == nullptr) { | |||||
| MS_LOG(ERROR) << "ConstructMainGraph failed"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); | |||||
| return nullptr; | |||||
| } | |||||
| return main_meta_graph.release(); | |||||
| MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) { | |||||
| return nullptr; | |||||
| } | } | ||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| } // namespace mindspore::lite | |||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -13,67 +13,42 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef LITE_TFLITE_MODEL_PARSER_H | |||||
| #define LITE_TFLITE_MODEL_PARSER_H | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H | |||||
| #include <fcntl.h> | |||||
| #include <unistd.h> | |||||
| #include <google/protobuf/io/coded_stream.h> | |||||
| #include <google/protobuf/io/zero_copy_stream_impl.h> | |||||
| #include <google/protobuf/text_format.h> | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include "securec/include/securec.h" | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/model_parser.h" | #include "tools/converter/model_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| #include "tools/common/tensor_util.h" | #include "tools/common/tensor_util.h" | ||||
| #include "schema/inner/model_generated.h" | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class TfliteModelParser : public ModelParser { | class TfliteModelParser : public ModelParser { | ||||
| public: | public: | ||||
| TfliteModelParser(); | |||||
| ~TfliteModelParser() override; | |||||
| schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quantTyp) override; | |||||
| protected: | |||||
| std::unique_ptr<tflite::ModelT> ReadTfliteModel(const char *model_path); | |||||
| static STATUS CopyConstTensorData(const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| const tflite::TensorT *tflite_tensor, schema::TensorT *tensor); | |||||
| static void SetTensorQuantParam(const std::unique_ptr<tflite::TensorT> &tflite_tensor, schema::TensorT *tensor); | |||||
| STATUS ConvertOp(const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, const QuantType &quant_type, | |||||
| schema::MetaGraphT *sub_graph); | |||||
| TfliteModelParser() = default; | |||||
| STATUS ConvertTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::MetaGraphT *sub_graph); | |||||
| ~TfliteModelParser() override = default; | |||||
| STATUS GetGraphInfo(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::MetaGraphT *sub_graph); | |||||
| static STATUS ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph); | |||||
| QuantType quantType = QuantType_QUANT_NONE; | |||||
| char *tfliteModelBuf = nullptr; | |||||
| std::unique_ptr<schema::MetaGraphT> ConstructMainGraph(const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const QuantType &quant_type); | |||||
| FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) override; | |||||
| MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) override; | |||||
| private: | private: | ||||
| TfliteTensorsInfo tensorsInfo; | |||||
| std::vector<schema::TensorT *> tensors; | |||||
| std::map<std::string, schema::CNodeT *> opMap; | |||||
| std::map<const tflite::OperatorT *, schema::CNodeT *> tfliteOpMap; | |||||
| std::unordered_map<int, AnfNodePtr> nodes_; | |||||
| std::unique_ptr<tflite::ModelT> tflite_model_; | |||||
| FuncGraphPtr func_graph_; | |||||
| char *tflite_model_buf_ = nullptr; | |||||
| std::unique_ptr<tflite::ModelT> ReadTfliteModel(const char *model_path); | |||||
| STATUS ConvertConstTensor(const tflite::TensorT *tensor, Parameter *parameter); | |||||
| STATUS ConvertOutputTensor(const tflite::OperatorT *op, const CNodePtr &dst_cnode); | |||||
| STATUS ConvertOpQuantParams(const tflite::OperatorT *op, lite::PrimitiveC *primitive_c); | |||||
| STATUS ConvertOps(); | |||||
| STATUS ConvertGraphInputs(); | |||||
| STATUS ConvertGraphOutputs(); | |||||
| STATUS SetTensorQuantParam(const tflite::TensorT *tflite_tensor, std::vector<QuantParamT> *quant_params); | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H | |||||
| #endif // LITE_TFLITE_MODEL_PARSER_H | |||||
| @@ -41,9 +41,9 @@ class TfliteNodeParser { | |||||
| virtual STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | virtual STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) = 0; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) = 0; | ||||
| virtual STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, PrimitiveC *primitiveC) { | |||||
| return RET_OK; | |||||
| virtual lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| return nullptr; | |||||
| } | } | ||||
| static void AddOpInput(schema::CNodeT *op, TfliteTensorsInfo *tensors_info, int idx, int total, | static void AddOpInput(schema::CNodeT *op, TfliteTensorsInfo *tensors_info, int idx, int total, | ||||
| @@ -64,6 +64,38 @@ STATUS TfliteOneHotParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteOneHotParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| std::unique_ptr<schema::OneHotT> attr = std::make_unique<schema::OneHotT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsOneHotOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op onehot attr failed"; | |||||
| return nullptr; | |||||
| } | |||||
| auto axis = tflite_attr->axis; | |||||
| const auto &tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; | |||||
| if (tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "tensor is null"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->axis = axis; | |||||
| primitive->value.type = schema::PrimitiveType_OneHot; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteOneHotParser("OneHot", new TfliteOneHotParser()); | TfliteNodeRegister g_tfliteOneHotParser("OneHot", new TfliteOneHotParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -32,6 +32,8 @@ class TfliteOneHotParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -90,6 +90,59 @@ STATUS TflitePadParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TflitePadParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| std::unique_ptr<schema::PadT> attr = std::make_unique<schema::PadT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; | |||||
| if (tflite_op_type == tflite::BuiltinOperator_PAD) { | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsPadOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op pad attr failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->paddingMode = schema::PaddingMode_CONSTANT; | |||||
| attr->constantValue = 0.0f; | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->paddings)) { | |||||
| MS_LOG(ERROR) << "get pad -> paddings failed"; | |||||
| return nullptr; | |||||
| } | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_MIRROR_PAD) { | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsMirrorPadOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op pad attr failed"; | |||||
| return nullptr; | |||||
| } | |||||
| switch (tflite_attr->mode) { | |||||
| case tflite::MirrorPadMode_REFLECT: | |||||
| attr->paddingMode = schema::PaddingMode_REFLECT; | |||||
| break; | |||||
| case tflite::MirrorPadMode_SYMMETRIC: | |||||
| attr->paddingMode = schema::PaddingMode_SYMMETRIC; | |||||
| break; | |||||
| default: | |||||
| MS_LOG(ERROR) << "paddingmode:" << tflite_attr->mode << " don't support"; | |||||
| return nullptr; | |||||
| } | |||||
| } else { | |||||
| MS_LOG(ERROR) << "this pad:" << tflite_op_type << " hasn't been supported"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Pad; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tflitePadParser("Pad", new TflitePadParser()); | TfliteNodeRegister g_tflitePadParser("Pad", new TflitePadParser()); | ||||
| TfliteNodeRegister g_tfliteMirorPadParser("MirrorPad", new TflitePadParser()); | TfliteNodeRegister g_tfliteMirorPadParser("MirrorPad", new TflitePadParser()); | ||||
| @@ -32,6 +32,8 @@ class TflitePadParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,8 +19,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| namespace mindspore::lite { | |||||
| STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | ||||
| @@ -43,17 +42,13 @@ STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::un | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| std::vector<std::string> node_name_str; | |||||
| Split(op->name, &node_name_str, "-"); | |||||
| const char *node_name = node_name_str.data()->c_str(); | |||||
| if (std::strcmp(node_name, "MeanPooling") == 0) { | |||||
| MS_LOG(DEBUG) << "parser TfliteMeanPoolingParser"; | |||||
| auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; | |||||
| if (tflite_op_type == tflite::BuiltinOperator_AVERAGE_POOL_2D) { | |||||
| attr->poolingMode = schema::PoolMode_MEAN_POOLING; | attr->poolingMode = schema::PoolMode_MEAN_POOLING; | ||||
| } else if (std::strcmp(node_name, "MaxPooling") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteMaxPoolingParser"; | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_MAX_POOL_2D) { | |||||
| attr->poolingMode = schema::PoolMode_MAX_POOLING; | attr->poolingMode = schema::PoolMode_MAX_POOLING; | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << node_name << " hasn't been supported"; | |||||
| MS_LOG(ERROR) << "pooling mode " << tflite_op_type << " hasn't been supported"; | |||||
| return RET_NOT_FIND_OP; | return RET_NOT_FIND_OP; | ||||
| } | } | ||||
| @@ -75,7 +70,7 @@ STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::un | |||||
| // calculate pad params | // calculate pad params | ||||
| auto data_index = tflite_op->inputs[0]; | auto data_index = tflite_op->inputs[0]; | ||||
| const auto &data_tensor = tflite_subgraph->tensors[data_index]; | const auto &data_tensor = tflite_subgraph->tensors[data_index]; | ||||
| std::vector<int> params; | |||||
| std::vector<int64_t> params; | |||||
| int status = | int status = | ||||
| getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW, ¶ms); | getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW, ¶ms); | ||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | if (status != RET_OK && status != RET_NO_CHANGE) { | ||||
| @@ -95,8 +90,58 @@ STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::un | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| lite::PrimitiveC *TflitePoolingParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| const auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||||
| std::unique_ptr<schema::PoolingT> attr = std::make_unique<schema::PoolingT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; | |||||
| if (tflite_op_type == tflite::BuiltinOperator_AVERAGE_POOL_2D) { | |||||
| attr->poolingMode = schema::PoolMode_MEAN_POOLING; | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_MAX_POOL_2D) { | |||||
| attr->poolingMode = schema::PoolMode_MAX_POOLING; | |||||
| } | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsPool2DOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op pooling attr failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->windowW = tflite_attr->filter_width; | |||||
| attr->windowH = tflite_attr->filter_height; | |||||
| attr->strideW = tflite_attr->stride_w; | |||||
| attr->strideH = tflite_attr->stride_h; | |||||
| attr->padMode = GetPadMode(tflite_attr->padding); | |||||
| attr->format = schema::Format::Format_NHWC; | |||||
| attr->global = false; | |||||
| attr->roundMode = schema::RoundMode_FLOOR; | |||||
| attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); | |||||
| // calculate pad params | |||||
| auto data_index = tflite_op->inputs[0]; | |||||
| const auto &data_tensor = tflite_subgraph->tensors[data_index]; | |||||
| std::vector<int64_t> params; | |||||
| int status = | |||||
| getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW, ¶ms); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "get padding params failed"; | |||||
| return nullptr; | |||||
| } else if (status == RET_OK) { | |||||
| attr->padUp = params.at(0); | |||||
| attr->padDown = params.at(1); | |||||
| attr->padLeft = params.at(2); | |||||
| attr->padRight = params.at(3); | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| primitive->value.type = schema::PrimitiveType_Pooling; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteMeanPoolingParser("MeanPooling", new TflitePoolingParser()); | TfliteNodeRegister g_tfliteMeanPoolingParser("MeanPooling", new TflitePoolingParser()); | ||||
| TfliteNodeRegister g_tfliteMaxPoolingParser("MaxPooling", new TflitePoolingParser()); | TfliteNodeRegister g_tfliteMaxPoolingParser("MaxPooling", new TflitePoolingParser()); | ||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| } // namespace mindspore::lite | |||||
| @@ -23,8 +23,7 @@ | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| namespace mindspore::lite { | |||||
| class TflitePoolingParser : public TfliteNodeParser { | class TflitePoolingParser : public TfliteNodeParser { | ||||
| public: | public: | ||||
| TflitePoolingParser() : TfliteNodeParser("node_name") {} | TflitePoolingParser() : TfliteNodeParser("node_name") {} | ||||
| @@ -32,8 +31,9 @@ class TflitePoolingParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| } // namespace mindspore::lite | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_POOLING_PARSER_H | #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_POOLING_PARSER_H | ||||
| @@ -52,6 +52,24 @@ STATUS TflitePReLUParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TflitePReLUParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| std::unique_ptr<schema::PReLUT> attr = std::make_unique<schema::PReLUT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->channelShared = true; | |||||
| primitive->value.type = schema::PrimitiveType_PReLU; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tflitePReLUParser("PRELU", new TflitePReLUParser()); | TfliteNodeRegister g_tflitePReLUParser("PRELU", new TflitePReLUParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -32,6 +32,8 @@ class TflitePReLUParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -74,6 +74,50 @@ STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::u | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteQuantizeParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; | |||||
| if (in_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "input tensor is null"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]]; | |||||
| if (out_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "output tensor is null"; | |||||
| return nullptr; | |||||
| } | |||||
| if (GetTfliteDataType(in_tensor->type) != GetTfliteDataType(out_tensor->type) && | |||||
| (GetTfliteDataType(out_tensor->type) == kNumberTypeInt8 || | |||||
| GetTfliteDataType(out_tensor->type) == kNumberTypeUInt8)) { | |||||
| std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->srcT = GetTfliteDataType(in_tensor->type); | |||||
| attr->dstT = GetTfliteDataType(out_tensor->type); | |||||
| primitive->value.type = schema::PrimitiveType_QuantDTypeCast; | |||||
| primitive->value.value = attr.release(); | |||||
| } else { | |||||
| std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->srcT = GetTfliteDataType(in_tensor->type); | |||||
| attr->dstT = GetTfliteDataType(out_tensor->type); | |||||
| primitive->value.type = schema::PrimitiveType_Cast; | |||||
| primitive->value.value = attr.release(); | |||||
| } | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteQuantizeParser("QUANTIZE", new TfliteQuantizeParser()); | TfliteNodeRegister g_tfliteQuantizeParser("QUANTIZE", new TfliteQuantizeParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -31,6 +31,8 @@ class TfliteQuantizeParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -71,6 +71,43 @@ STATUS TfliteRangeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteRangeParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| std::unique_ptr<schema::RangeT> attr = std::make_unique<schema::RangeT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->dType = 0; | |||||
| std::vector<int> limit; | |||||
| std::vector<int> delta; | |||||
| int status = GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, limit); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "range -> limit get failed"; | |||||
| return nullptr; | |||||
| } else if (status == RET_OK) { | |||||
| status = GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, delta); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "stridedSlice -> end get failed"; | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| if (status == RET_OK) { | |||||
| attr->limit = limit.front(); | |||||
| attr->delta = delta.front(); | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Range; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteRangeParser("Range", new TfliteRangeParser()); | TfliteNodeRegister g_tfliteRangeParser("Range", new TfliteRangeParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -32,6 +32,8 @@ class TfliteRangeParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -50,6 +50,24 @@ STATUS TfliteRankParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteRankParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| std::unique_ptr<schema::RankT> attr = std::make_unique<schema::RankT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Rank; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteRankParser("Rank", new TfliteRankParser()); | TfliteNodeRegister g_tfliteRankParser("Rank", new TfliteRankParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -32,6 +32,8 @@ class TfliteRankParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -85,11 +85,65 @@ STATUS TfliteReduceParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteReduceParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| std::unique_ptr<schema::ReduceT> attr = std::make_unique<schema::ReduceT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsReducerOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op reduce attr failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->keepDims = tflite_attr->keep_dims; | |||||
| auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; | |||||
| if (tflite_op_type == tflite::BuiltinOperator_REDUCE_MAX) { | |||||
| MS_LOG(DEBUG) << "parse TfliteReduceMaxParser"; | |||||
| attr->mode = schema::ReduceMode_ReduceMax; | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_REDUCE_MIN) { | |||||
| MS_LOG(DEBUG) << "parse TfliteReduceMinParser"; | |||||
| attr->mode = schema::ReduceMode_ReduceMin; | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_REDUCE_PROD) { | |||||
| MS_LOG(DEBUG) << "parse TfliteReduceProdParser"; | |||||
| attr->mode = schema::ReduceMode_ReduceProd; | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_SUM) { | |||||
| MS_LOG(DEBUG) << "parse TfliteSumParser"; | |||||
| attr->mode = schema::ReduceMode_ReduceSum; | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_MEAN) { | |||||
| MS_LOG(DEBUG) << "parse TfliteMeanParser"; | |||||
| attr->mode = schema::ReduceMode_ReduceMean; | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_REDUCE_ANY) { | |||||
| // attr->mode; | |||||
| MS_LOG(ERROR) << "ms-lite haven't supported REDUCE_ANY now"; | |||||
| return nullptr; | |||||
| } | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->axes)) { | |||||
| MS_LOG(ERROR) << "get reduce -> axes failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Reduce; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteSumParser("Sum", new TfliteReduceParser()); | |||||
| TfliteNodeRegister g_tfliteMeanParser("Mean", new TfliteReduceParser()); | |||||
| TfliteNodeRegister g_tfliteReduceMaxParser("ReduceMax", new TfliteReduceParser()); | |||||
| TfliteNodeRegister g_tfliteReduceMinParser("ReduceMin", new TfliteReduceParser()); | |||||
| TfliteNodeRegister g_tfliteReduceProdParser("ReduceProd", new TfliteReduceParser()); | |||||
| TfliteNodeRegister g_TfliteSumParser("Sum", new TfliteReduceParser()); | |||||
| TfliteNodeRegister g_TfliteMeanParser("Mean", new TfliteReduceParser()); | |||||
| TfliteNodeRegister g_TfliteReduceMaxParser("ReduceMax", new TfliteReduceParser()); | |||||
| TfliteNodeRegister g_TfliteReduceMinParser("ReduceMin", new TfliteReduceParser()); | |||||
| TfliteNodeRegister g_TfliteReduceProdParser("ReduceProd", new TfliteReduceParser()); | |||||
| TfliteNodeRegister g_TfliteReduceAnyParser("ReduceAny", new TfliteReduceParser()); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -23,8 +23,7 @@ | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| namespace mindspore::lite { | |||||
| class TfliteReduceParser : public TfliteNodeParser { | class TfliteReduceParser : public TfliteNodeParser { | ||||
| public: | public: | ||||
| TfliteReduceParser() : TfliteNodeParser("node_name") {} | TfliteReduceParser() : TfliteNodeParser("node_name") {} | ||||
| @@ -32,8 +31,9 @@ class TfliteReduceParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| } // namespace mindspore::lite | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_REDUCE_PARSER_H | #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_REDUCE_PARSER_H | ||||
| @@ -18,8 +18,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| namespace mindspore::lite { | |||||
| STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | ||||
| @@ -43,8 +42,8 @@ STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::un | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| const auto &tfliteAttr = tflite_op->builtin_options.AsReshapeOptions(); | |||||
| if (tfliteAttr == nullptr) { | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsReshapeOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| if (tflite_op->inputs.size() < 2) { | if (tflite_op->inputs.size() < 2) { | ||||
| MS_LOG(ERROR) << "expected two input tensors, but got: " << tflite_op->inputs.size(); | MS_LOG(ERROR) << "expected two input tensors, but got: " << tflite_op->inputs.size(); | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -68,9 +67,9 @@ STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::un | |||||
| } | } | ||||
| } else { | } else { | ||||
| attr->format = schema::Format::Format_NHWC; | attr->format = schema::Format::Format_NHWC; | ||||
| attr->shape.resize(tfliteAttr->new_shape.size()); | |||||
| for (size_t i = 0; i < tfliteAttr->new_shape.size(); ++i) { | |||||
| attr->shape[i] = tfliteAttr->new_shape[i]; | |||||
| attr->shape.resize(tflite_attr->new_shape.size()); | |||||
| for (size_t i = 0; i < tflite_attr->new_shape.size(); ++i) { | |||||
| attr->shape[i] = tflite_attr->new_shape[i]; | |||||
| } | } | ||||
| } | } | ||||
| @@ -83,7 +82,50 @@ STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::un | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| lite::PrimitiveC *TfliteReshapeParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| const auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||||
| std::unique_ptr<schema::ReshapeT> attr = std::make_unique<schema::ReshapeT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsReshapeOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| if (tflite_op->inputs.size() < 2) { | |||||
| MS_LOG(ERROR) << "expected two input tensors, but got: " << tflite_op->inputs.size(); | |||||
| return nullptr; | |||||
| } | |||||
| auto shape_tensor_index = tflite_op->inputs[1]; | |||||
| const auto &shape_tensor = tflite_subgraph->tensors[shape_tensor_index]; | |||||
| if (shape_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "shape_tensor is null"; | |||||
| return nullptr; | |||||
| } | |||||
| auto &buf_data = tflite_model->buffers[shape_tensor->buffer]; | |||||
| if (buf_data == nullptr) { | |||||
| MS_LOG(ERROR) << "buf_data is null"; | |||||
| return nullptr; | |||||
| } | |||||
| if (!buf_data->data.empty()) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->shape)) { | |||||
| MS_LOG(ERROR) << "get reshape -> shape failed"; | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| } else { | |||||
| attr->format = schema::Format::Format_NHWC; | |||||
| attr->shape.resize(tflite_attr->new_shape.size()); | |||||
| for (size_t i = 0; i < tflite_attr->new_shape.size(); ++i) { | |||||
| attr->shape[i] = tflite_attr->new_shape[i]; | |||||
| } | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| primitive->value.type = schema::PrimitiveType_Reshape; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteReshapeParser("Reshape", new TfliteReshapeParser()); | TfliteNodeRegister g_tfliteReshapeParser("Reshape", new TfliteReshapeParser()); | ||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| } // namespace mindspore::lite | |||||
| @@ -23,8 +23,7 @@ | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| namespace mindspore::lite { | |||||
| class TfliteReshapeParser : public TfliteNodeParser { | class TfliteReshapeParser : public TfliteNodeParser { | ||||
| public: | public: | ||||
| TfliteReshapeParser() : TfliteNodeParser("Reshape") {} | TfliteReshapeParser() : TfliteNodeParser("Reshape") {} | ||||
| @@ -32,8 +31,10 @@ class TfliteReshapeParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| } // namespace mindspore::lite | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RESHAPE_PARSER_H | #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RESHAPE_PARSER_H | ||||
| @@ -120,6 +120,89 @@ STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteResizeParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| std::unique_ptr<schema::ResizeT> attr = std::make_unique<schema::ResizeT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->coordinateTransformMode = schema::CoordinateTransformMode_COMMON; | |||||
| auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; | |||||
| if (tflite_op_type == tflite::BuiltinOperator_RESIZE_BILINEAR) { | |||||
| MS_LOG(DEBUG) << "parse TfliteResizeBilinearParser"; | |||||
| const auto &tfliteAttr = tflite_op->builtin_options.AsResizeBilinearOptions(); | |||||
| if (tfliteAttr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op ResizeBilinear attr failed"; | |||||
| return nullptr; | |||||
| } | |||||
| if (tfliteAttr->align_corners) { | |||||
| attr->alignCorners = tfliteAttr->align_corners; | |||||
| attr->coordinateTransformMode = schema::CoordinateTransformMode_ALIGN_CORNERS; | |||||
| } | |||||
| if (tfliteAttr->half_pixel_centers) { | |||||
| attr->coordinateTransformMode = (attr->coordinateTransformMode == schema::CoordinateTransformMode_COMMON | |||||
| ? schema::CoordinateTransformMode_TF_HALF_PIXEL | |||||
| : schema::CoordinateTransformMode_ALIGN_CORNERS_WITH_HALF_PIEXL); | |||||
| } | |||||
| attr->method = schema::ResizeMethod_LINEAR; | |||||
| } else if (tflite_op_type == tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR) { | |||||
| MS_LOG(DEBUG) << "parse TfliteResizeNearestNeighborParser"; | |||||
| const auto &tfliteAttr = tflite_op->builtin_options.AsResizeNearestNeighborOptions(); | |||||
| if (tfliteAttr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op ResizeNearestNeighbor attr failed"; | |||||
| return nullptr; | |||||
| } | |||||
| if (tfliteAttr->align_corners) { | |||||
| attr->alignCorners = tfliteAttr->align_corners; | |||||
| attr->coordinateTransformMode = schema::CoordinateTransformMode_ALIGN_CORNERS; | |||||
| } | |||||
| if (tfliteAttr->half_pixel_centers) { | |||||
| attr->coordinateTransformMode = (attr->coordinateTransformMode == schema::CoordinateTransformMode_COMMON | |||||
| ? schema::CoordinateTransformMode_TF_HALF_PIXEL | |||||
| : schema::CoordinateTransformMode_ALIGN_CORNERS_WITH_HALF_PIEXL); | |||||
| } | |||||
| attr->method = schema::ResizeMethod_NEAREST; | |||||
| attr->nearestMode = schema::NearestMode_NORMAL; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "wrong resize type"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->format = schema::Format::Format_NHWC; | |||||
| attr->preserveAspectRatio = false; | |||||
| auto tfliteResizeTensorIndex = tflite_op->inputs[1]; | |||||
| const auto &shape_tensor = tflite_subgraph->tensors[tfliteResizeTensorIndex]; | |||||
| if (shape_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "shape_tensor is null"; | |||||
| return nullptr; | |||||
| } | |||||
| auto resizeTensorBufferIndex = shape_tensor->buffer; | |||||
| const auto &buff = tflite_model->buffers.at(resizeTensorBufferIndex); | |||||
| if (buff == nullptr) { | |||||
| MS_LOG(ERROR) << "buff_data is null"; | |||||
| return nullptr; | |||||
| } | |||||
| auto buffData = reinterpret_cast<int32_t *>(buff->data.data()); | |||||
| if (buffData != nullptr) { | |||||
| auto height = buffData[0]; | |||||
| auto width = buffData[1]; | |||||
| attr->newWidth = width; | |||||
| attr->newHeight = height; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Resize; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteResizeBilinearParser("ResizeBilinear", new TfliteResizeParser()); | TfliteNodeRegister g_tfliteResizeBilinearParser("ResizeBilinear", new TfliteResizeParser()); | ||||
| TfliteNodeRegister g_tfliteResizeNearestNeighborParser("NearestNeighbor", new TfliteResizeParser()); | TfliteNodeRegister g_tfliteResizeNearestNeighborParser("NearestNeighbor", new TfliteResizeParser()); | ||||
| @@ -23,8 +23,7 @@ | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| namespace mindspore::lite { | |||||
| class TfliteResizeParser : public TfliteNodeParser { | class TfliteResizeParser : public TfliteNodeParser { | ||||
| public: | public: | ||||
| TfliteResizeParser() : TfliteNodeParser("node_name") {} | TfliteResizeParser() : TfliteNodeParser("node_name") {} | ||||
| @@ -32,8 +31,9 @@ class TfliteResizeParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| } // namespace mindspore::lite | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RESIZE_PARSER_H | #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RESIZE_PARSER_H | ||||
| @@ -55,6 +55,30 @@ STATUS TfliteReverseParser::Parse(TfliteTensorsInfo *tensors_info, const std::un | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteReverseParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| std::unique_ptr<schema::ReverseT> attr = std::make_unique<schema::ReverseT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->axis)) { | |||||
| MS_LOG(ERROR) << "get reverse -> axis failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Reverse; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteReverseParser("reverse", new TfliteReverseParser()); | TfliteNodeRegister g_tfliteReverseParser("reverse", new TfliteReverseParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -32,6 +32,8 @@ class TfliteReverseParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -62,6 +62,32 @@ STATUS TfliteReverseSequenceParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteReverseSequenceParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| std::unique_ptr<schema::ReverseSequenceT> attr = std::make_unique<schema::ReverseSequenceT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsReverseSequenceOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op reverse attr failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->seqAxis = tflite_attr->seq_dim; | |||||
| attr->batchAxis = tflite_attr->batch_dim; | |||||
| primitive->value.type = schema::PrimitiveType_ReverseSequence; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteReverseSequenceParser("ReverseSequence", new TfliteReverseSequenceParser()); | TfliteNodeRegister g_tfliteReverseSequenceParser("ReverseSequence", new TfliteReverseSequenceParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -32,6 +32,8 @@ class TfliteReverseSequenceParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -58,6 +58,29 @@ STATUS TfliteScatterNdParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteScatterNdParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| std::unique_ptr<schema::ScatterNDT> attr = std::make_unique<schema::ScatterNDT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsScatterNdOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op ScatterNd attr failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_ScatterND; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteScatterNdParser("ScatterNd", new TfliteScatterNdParser()); | TfliteNodeRegister g_tfliteScatterNdParser("ScatterNd", new TfliteScatterNdParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -32,6 +32,8 @@ class TfliteScatterNdParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -50,6 +50,24 @@ STATUS TfliteShapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteShapeParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| std::unique_ptr<schema::ShapeT> attr = std::make_unique<schema::ShapeT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Shape; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteShapeParser("Shape", new TfliteShapeParser()); | TfliteNodeRegister g_tfliteShapeParser("Shape", new TfliteShapeParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -32,6 +32,8 @@ class TfliteShapeParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -59,6 +59,33 @@ STATUS TfliteSkipGramParser::Parse(TfliteTensorsInfo *tensors_info, const std::u | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteSkipGramParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| std::unique_ptr<schema::SkipGramT> attr = std::make_unique<schema::SkipGramT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsSkipGramOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op attr failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->includeAllGrams = tflite_attr->include_all_ngrams; | |||||
| attr->maxSkipSize = tflite_attr->max_skip_size; | |||||
| attr->ngramSize = tflite_attr->ngram_size; | |||||
| primitive->value.type = schema::PrimitiveType_SkipGram; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteSkiGramParser("SKipGram", new TfliteSkipGramParser()); | TfliteNodeRegister g_tfliteSkiGramParser("SKipGram", new TfliteSkipGramParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -32,6 +32,8 @@ class TfliteSkipGramParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -66,6 +66,41 @@ STATUS TfliteSliceParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| PrimitiveC *TfliteSliceParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive is null"; | |||||
| return nullptr; | |||||
| } | |||||
| std::unique_ptr<schema::SliceT> attr = std::make_unique<schema::SliceT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| attr->format = schema::Format::Format_NHWC; | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->begin)) { | |||||
| MS_LOG(ERROR) << "get slice -> begin failed"; | |||||
| return nullptr; | |||||
| } | |||||
| if (GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->size)) { | |||||
| MS_LOG(ERROR) << "get slice -> size failed"; | |||||
| return nullptr; | |||||
| } | |||||
| std::vector<int> axes; | |||||
| axes.clear(); | |||||
| for (size_t i = 0; i < attr->begin.size(); ++i) { | |||||
| axes.push_back(i); | |||||
| } | |||||
| attr->axes = axes; | |||||
| primitive->value.type = schema::PrimitiveType_Slice; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | |||||
| TfliteNodeRegister g_tfliteSliceParser("Slice", new TfliteSliceParser()); | TfliteNodeRegister g_tfliteSliceParser("Slice", new TfliteSliceParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -32,6 +32,8 @@ class TfliteSliceParser : public TfliteNodeParser { | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | ||||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||