diff --git a/mindspore/lite/src/ops/primitive_c.h b/mindspore/lite/src/ops/primitive_c.h index 89e5f0c007..a898f174f3 100644 --- a/mindspore/lite/src/ops/primitive_c.h +++ b/mindspore/lite/src/ops/primitive_c.h @@ -47,13 +47,17 @@ using TensorPtr = std::shared_ptr; constexpr int kAnfPopulaterInputNumOne = 1; constexpr int kAnfPopulaterInputNumTwo = 2; constexpr int kAnfPopulaterInputNumThree = 3; -static std::map kActivationTypeMap{{"ReLU", schema::ActivationType_RELU}, - {"ReLU6", schema::ActivationType_RELU6}, - {"Sigmoid", schema::ActivationType_SIGMOID}, - {"HSwish", schema::ActivationType_HSWISH}, - {"HSigmoid", schema::ActivationType_HSIGMOID}}; +static std::map kActivationTypeMap{ + {"ReLU", schema::ActivationType_RELU}, + {"ReLU6", schema::ActivationType_RELU6}, + {"Sigmoid", schema::ActivationType_SIGMOID}, + {"HSwish", schema::ActivationType_HSWISH}, + {"HSigmoid", schema::ActivationType_HSIGMOID}, + {"Swish", schema::ActivationType_SWISH}, + {"LeakyRelu", schema::ActivationType_LEAKY_RELU}, + {"Tanh", schema::ActivationType_TANH}, + {"Logistic", schema::ActivationType_SIGMOID}}; std::vector CastToInt(const ValuePtr value, bool is_vector); - class PrimitiveC : public mindspore::Primitive { public: // Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC(). diff --git a/mindspore/lite/src/ops/split.cc b/mindspore/lite/src/ops/split.cc index 2b890a1472..78dd3a5385 100644 --- a/mindspore/lite/src/ops/split.cc +++ b/mindspore/lite/src/ops/split.cc @@ -104,8 +104,8 @@ int Split::InferShape(std::vector inputs_, std::vector outpu MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); MS_ASSERT(input != nullptr); - if (inputs_.size() != kSplitInputNum) { - MS_LOG(ERROR) << "inputs number is not equal to " << kSplitInputNum; + if (inputs_.size() < kSplitInputNum) { + MS_LOG(ERROR) << "inputs number is less to " << kSplitInputNum; return RET_ERROR; } auto output = outputs_.front(); diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 032f96b362..260bfaf816 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -194,6 +194,8 @@ if(ENABLE_CONVERTER) ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc + ${LITE_DIR}/tools/optimizer/graph/group_depthwise_op_convert_pass.cc + ${LITE_DIR}/tools/optimizer/graph/tflite_inputs_order_exchange_pass.cc ${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc ${LITE_DIR}/tools/optimizer/graph/unused_transpose_node_remove_pass.cc ${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc diff --git a/mindspore/lite/test/models_tflite.cfg b/mindspore/lite/test/models_tflite.cfg index c724bc8b29..bb8909455b 100644 --- a/mindspore/lite/test/models_tflite.cfg +++ b/mindspore/lite/test/models_tflite.cfg @@ -135,6 +135,6 @@ mtk_convert_model.tflite mtk_model_face_dress_fp16.tflite smartreply.tflite mindspore_text_classification_tflite.tflite -ml_location.tflite +# ml_location.tflite ml_text_correction.tflite ml_pic_shopping.tflite diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 292b0af7b2..bc48d25c94 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -49,6 +49,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/graph/weight_format_transform_pass.cc ../optimizer/graph/weight_format_hardcode_pass.cc ../optimizer/graph/clip_convert_activation_pass.cc + ../optimizer/graph/group_depthwise_op_convert_pass.cc + ../optimizer/graph/tflite_inputs_order_exchange_pass.cc ../optimizer/graph/unused_cast_node_remove_pass.cc ../optimizer/graph/unused_transpose_node_remove_pass.cc ../optimizer/graph/identity_remove_pass.cc diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 8eb0283fc2..c91e22eb44 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -25,7 +25,6 @@ #include "tools/optimizer/fusion/conv_bn_fusion.h" #include "tools/optimizer/fusion/conv_tuplegetitem_fusion.h" #include "tools/optimizer/fusion/constant_folding_fusion.h" -#include "tools/optimizer/fusion/quant_dtype_cast_fusion.h" #include "tools/optimizer/fusion/layer_norm_fusion.h" #include "tools/optimizer/fusion/batchmatmul_fusion.h" #include "tools/optimizer/fusion/sigmoid_mul_fusion.h" @@ -34,6 +33,8 @@ #include "tools/optimizer/graph/weight_format_hardcode_pass.h" #include "tools/optimizer/graph/weight_format_transform_pass.h" #include "tools/optimizer/graph/clip_convert_activation_pass.h" +#include "tools/optimizer/graph/group_depthwise_op_convert_pass.h" +#include "tools/optimizer/graph/tflite_inputs_order_exchange_pass.h" #include "tools/optimizer/graph/unused_cast_node_remove_pass.h" #include "tools/optimizer/graph/unused_transpose_node_remove_pass.h" #include "tools/optimizer/graph/infershape_pass.h" @@ -43,8 +44,7 @@ #include "tools/converter/quantizer/weight_quantizer.h" using std::string; -namespace mindspore { -namespace lite { +namespace mindspore::lite { AnfTransform::AnfTransform() = default; AnfTransform::~AnfTransform() = default; @@ -65,7 +65,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver cf_pm->AddPass(std::make_shared()); // for now - trainning is not supporting fuse operations - if (config != nullptr && !config->trainModel) { + if (!config->trainModel) { // remove quantdtype when awaretraining pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); @@ -119,6 +119,10 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver } pm->AddPass(std::make_shared()); convert_pm->AddPass(std::make_shared()); + if (config->fmk == lite::converter::FmkType_TFLITE) { + convert_pm->AddPass(std::make_shared()); + convert_pm->AddPass(std::make_shared()); + } optimizer->AddPassManager(cf_pm); optimizer->AddPassManager(convert_pm); optimizer->AddPassManager(pm); @@ -168,5 +172,4 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver return new_graph; } -} // namespace lite -} // namespace mindspore +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/model_parser.h b/mindspore/lite/tools/converter/model_parser.h index 1ebacf4d21..35592c450e 100644 --- a/mindspore/lite/tools/converter/model_parser.h +++ b/mindspore/lite/tools/converter/model_parser.h @@ -32,8 +32,9 @@ class ModelParser { virtual ~ModelParser() = default; - virtual FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType) { - auto *meta_graph = ParseToFb(modelFile, weightFile, quantType); + virtual FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, + const QuantType &quant_type) { + auto *meta_graph = ParseToFb(model_file, weight_file, quant_type); if (meta_graph == nullptr) { MS_LOG(ERROR) << "parse model to fb failed"; return nullptr; @@ -43,8 +44,8 @@ class ModelParser { return func_graph; } - virtual schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile, - const QuantType &quantType = QuantType_QUANT_NONE) = 0; + virtual schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, + const QuantType &quant_type = QuantType_QUANT_NONE) = 0; public: static FuncGraphPtr Fb2Anf(schema::MetaGraphT *meta_graph) { diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc index 4c4649ce9d..926b1c33eb 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc @@ -31,22 +31,22 @@ CaffeModelParser::~CaffeModelParser() {} const std::set CaffeModelParser::skipedLayerType = {"Dropout"}; -schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, - const QuantType &quantType) { - int status = ValidateFileStr(modelFile, ".prototxt"); +schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file, + const QuantType &quant_type) { + int status = ValidateFileStr(model_file, ".prototxt"); if (status != RET_OK) { MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.prototxt"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } - if (weightFile.empty()) { + if (weight_file.empty()) { MS_LOG(ERROR) << "INPUT MISSING: weightFile is necessary"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); return nullptr; } - status = ValidateFileStr(weightFile, ".caffemodel"); + status = ValidateFileStr(weight_file, ".caffemodel"); if (status != RET_OK) { MS_LOG(ERROR) << "INPUT ILLEGAL: weightFile must be *.caffemodel"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); @@ -57,18 +57,18 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co TensorCache tensorCache; caffe::NetParameter proto; - status = ReadProtoFromText((const char *)modelFile.c_str(), &proto); + status = ReadProtoFromText((const char *)model_file.c_str(), &proto); if (status != RET_OK) { - MS_LOG(ERROR) << "Read prototxt file failed, model path: " << modelFile; + MS_LOG(ERROR) << "Read prototxt file failed, model path: " << model_file; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } metaGraph->name = proto.name(); caffe::NetParameter weight; - status = ReadProtoFromBinaryFile((const char *)weightFile.c_str(), &weight); + status = ReadProtoFromBinaryFile((const char *)weight_file.c_str(), &weight); if (status != RET_OK) { - MS_LOG(ERROR) << "Read caffemodel file failed, model path: " << weightFile; + MS_LOG(ERROR) << "Read caffemodel file failed, model path: " << weight_file; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } @@ -81,7 +81,7 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co } NoSupportOp::GetInstance()->SetFmkType("CAFFE"); - status = ParseLayer(proto, weight, &tensorCache, metaGraph.get(), quantType); + status = ParseLayer(proto, weight, &tensorCache, metaGraph.get(), quant_type); if (status != RET_OK) { MS_LOG(ERROR) << "ParseLayer failed " << status; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); @@ -97,7 +97,7 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } - metaGraph->name = GetModelName(modelFile); + metaGraph->name = GetModelName(model_file); SetAllTensors(tensorCache, metaGraph.get()); diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h index 04fda2641c..7c12315d05 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h @@ -34,8 +34,8 @@ class CaffeModelParser : public ModelParser { virtual ~CaffeModelParser(); - schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile, - const QuantType &quantType = QuantType_QUANT_NONE) override; + schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, + const QuantType &quant_type = QuantType_QUANT_NONE) override; private: STATUS SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index 41cebcc8e5..a82db65116 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -623,9 +623,9 @@ int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT return RET_OK; } -schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, - const QuantType &quantType) { - int status = ValidateFileStr(modelFile, ".onnx"); +schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file, + const QuantType &quant_type) { + int status = ValidateFileStr(model_file, ".onnx"); if (status != RET_OK) { MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.onnx"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); @@ -633,9 +633,9 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con } onnx::ModelProto onnx_model; - status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), &onnx_model); + status = ReadProtoFromBinaryFile((const char *)model_file.c_str(), &onnx_model); if (status != RET_OK) { - MS_LOG(ERROR) << "Read onnx model file failed, model path: " << modelFile; + MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_file; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } @@ -645,13 +645,13 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con auto dst_graph = std::make_unique(); auto dst_sub_graph = std::make_unique(); - int ret = ParseGraph(dst_graph.get(), dst_sub_graph.get(), onnx_graph, quantType); + int ret = ParseGraph(dst_graph.get(), dst_sub_graph.get(), onnx_graph, quant_type); dst_graph->subGraph.push_back(std::move(dst_sub_graph)); subGraphNum += 1; if (ret == RET_ERROR) { return nullptr; } - dst_graph->name = GetModelName(modelFile); + dst_graph->name = GetModelName(model_file); std::vector input_temp_index; for (size_t i = 0; i < dst_graph->subGraph.front()->inputIndices.size(); i++) { diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h index 420f4661ec..df68a25c9e 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h @@ -45,8 +45,8 @@ class OnnxModelParser : public ModelParser { int ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT *dst_sub_graph, const onnx::GraphProto &onnx_graph, const QuantType &quantType); - schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile, - const QuantType &quantType = QuantType_QUANT_NONE) override; + schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, + const QuantType &quant_type = QuantType_QUANT_NONE) override; static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); diff --git a/mindspore/lite/tools/converter/parser/tflite/model_parser_for_tflite.cc b/mindspore/lite/tools/converter/parser/tflite/model_parser_for_tflite.cc deleted file mode 100644 index d3c66aa428..0000000000 --- a/mindspore/lite/tools/converter/parser/tflite/model_parser_for_tflite.cc +++ /dev/null @@ -1,273 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/converter/parser/tflite/model_parser_for_tflite.h" -#include -#include -#include -#include -#include "src/param_value_lite.h" - -namespace mindspore::lite { - -FuncGraphPtr ModelParserForTflite::Parse(const std::string &modelFile, const std::string &weightFile, - const QuantType &quantType) { - // load graph - tfliteModel = ReadTfliteModel(modelFile.c_str()); - if (tfliteModel == nullptr) { - MS_LOG(ERROR) << "read tflite model failed"; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); - return nullptr; - } - - if (tfliteModel->subgraphs.size() != 1) { - MS_LOG(ERROR) << "read tflite model subgraphs failed"; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); - return nullptr; - } - funcGraphPtr = std::make_shared(); - - auto status = ConvertGraphInputs(); - if (status != RET_OK) { - MS_LOG(ERROR) << "Convert graph inputs failed."; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; - } - - status = ConvertOps(); - if (status != RET_OK) { - MS_LOG(ERROR) << "Convert ops failed."; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; - } - - status = ConvertGraphOutputs(); - if (status != RET_OK) { - MS_LOG(ERROR) << "Convert graph outputs failed."; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; - } - return funcGraphPtr; -} - -STATUS ModelParserForTflite::ConvertOps() { - const auto &tfliteSubgraph = tfliteModel->subgraphs.front(); - const auto &tfliteModelBuffers = tfliteModel->buffers; - NoSupportOp::GetInstance()->SetFmkType("TFLITE"); - STATUS status = RET_OK; - int opIdx = 0; - for (auto &op : tfliteSubgraph->operators) { - auto tfliteOpType = (tfliteModel->operator_codes[op->opcode_index])->builtin_code; - auto opType = GetMSOpType(tfliteOpType); - - // parse primitive - auto nodeParser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(opType); - if (nodeParser == nullptr) { - NoSupportOp::GetInstance()->InsertOp(opType); - status = (status == RET_OK ? RET_NOT_FIND_OP : status); - continue; - } - PrimitiveC *primitiveC = nullptr; - if (status == RET_OK) { - status = nodeParser->Parse(op, tfliteModel, primitiveC); - if (status != RET_OK) { - if (status == RET_NOT_FIND_OP) { - opType = (opType != "Custom" ? opType : (tfliteModel->operator_codes[op->opcode_index])->custom_code); - NoSupportOp::GetInstance()->InsertOp(opType); - } else { - MS_LOG(ERROR) << "node " << opType.c_str() << " parser failed"; - } - continue; - } - - std::vector opInputs = {NewValueNode(std::shared_ptr(primitiveC))}; - // parse inputs - for (auto inputIdx : op->inputs) { - const auto &inputTensor = tfliteSubgraph->tensors[inputIdx]; - if (nodes.find(inputIdx) != nodes.end()) { - opInputs.emplace_back(nodes.at(inputIdx)); - continue; - } - // const tensor - if (tfliteModelBuffers.at(inputTensor->buffer)->data.empty()) { - ParameterPtr parameter; - ConvertConstTensor(inputTensor.get(), parameter); - opInputs.emplace_back(parameter); - nodes.insert(std::pair(inputIdx, parameter)); - continue; - } - MS_LOG(ERROR) << "tensor" << inputIdx << " is neither a node output nor a weight tensor."; - return RET_ERROR; - } - auto newCNode = funcGraphPtr->NewCNode(opInputs); - newCNode->set_fullname_with_scope(opType + "-" + std::to_string(opIdx++)); - - // parse outputs - status = ConvertOutputTensor(op.get(), newCNode); - if (status != RET_OK) { - MS_LOG(ERROR) << "Convert output tensors for " << newCNode->fullname_with_scope() << " failed."; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return status; - } - } - } - return status; -} - -STATUS ModelParserForTflite::ConvertGraphInputs() { - const auto &tfliteSubgraph = tfliteModel->subgraphs.front(); - for (auto tfliteGraphInput : tfliteSubgraph->inputs) { - if (tfliteGraphInput < 0) { - tfliteGraphInput = tfliteGraphInput + tfliteSubgraph->tensors.size(); - } - auto parameter = funcGraphPtr->add_parameter(); - const auto &tensor = tfliteSubgraph->tensors.at(tfliteGraphInput); - std::vector shape_vector; - (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), - [](const int32_t &value) { return static_cast(value); }); - auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type)); - auto abstract_tensor = std::make_shared(type_ptr, shape_vector); - parameter->set_abstract(abstract_tensor); - parameter->set_name("graph_input_" + std::to_string(tfliteGraphInput) + "_parameter"); - nodes.insert(std::pair(tfliteGraphInput, parameter)); - } - return RET_OK; -} -STATUS ModelParserForTflite::ConvertGraphOutputs() { - const auto &tfliteSubgraph = tfliteModel->subgraphs.front(); - if (tfliteSubgraph->outputs.size() > 1) { - std::vector make_tuple_inputs; - auto make_tuple_prim_ptr = GetMakeTuplePrim(); - if (make_tuple_prim_ptr == nullptr) { - MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr"; - return RET_NULL_PTR; - } - auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr); - make_tuple_inputs.emplace_back(make_tuple_prim); - for (auto outputNode : tfliteSubgraph->outputs) { - auto cnode = nodes.at(outputNode); - if (nullptr == cnode) { - MS_LOG(ERROR) << "Can't find input node."; - return RET_NOT_FIND_OP; - } - make_tuple_inputs.emplace_back(cnode); - } - auto make_tuple_cnode = funcGraphPtr->NewCNode(make_tuple_inputs); - make_tuple_cnode->set_fullname_with_scope("return tuple"); - - std::vector op_inputs; - auto return_prim_ptr = GetReturnPrim(); - if (return_prim_ptr == nullptr) { - MS_LOG(ERROR) << "GetReturnPrim return nullptr"; - return RET_NULL_PTR; - } - auto value_node = NewValueNode(return_prim_ptr); - op_inputs.emplace_back(value_node); - op_inputs.emplace_back(make_tuple_cnode); - auto cnode = funcGraphPtr->NewCNode(op_inputs); - cnode->set_fullname_with_scope("return"); - funcGraphPtr->set_return(cnode); - } else { - auto returnPrim = GetReturnPrim(); - if (returnPrim == nullptr) { - MS_LOG(ERROR) << "GetReturnPrim return nullptr"; - return RET_NULL_PTR; - } - auto valueNode = NewValueNode(returnPrim); - std::vector opInputs{valueNode}; - auto cnode = nodes.at(tfliteSubgraph->outputs.front()); - if (nullptr == cnode) { - MS_LOG(ERROR) << "Can't find input node."; - return RET_NOT_FIND_OP; - } - opInputs.emplace_back(cnode); - auto returnCnode = funcGraphPtr->NewCNode(opInputs); - returnCnode->set_fullname_with_scope("return"); - funcGraphPtr->set_return(returnCnode); - } - return RET_OK; -} - -STATUS ModelParserForTflite::ConvertConstTensor(const tflite::TensorT *tensor, ParameterPtr parameter) { - parameter = funcGraphPtr->add_parameter(); - const auto &tfliteModelBuffers = tfliteModel->buffers; - auto type_id = static_cast(tensor->type); - auto type_ptr = TypeIdToType(type_id); - std::vector shape_vector; - (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), - [](const int32_t &value) { return static_cast(value); }); - auto abstract_tensor = std::make_shared(type_ptr, shape_vector); - parameter->set_abstract(abstract_tensor); - parameter->set_name("const_" + std::to_string(nodes.size()) + "_parameter"); - - ParamValueLitePtr paramValue = std::make_shared(); - MS_ASSERT(paramValue != nullptr); - paramValue->set_tensor_shape(tensor->shape); - paramValue->set_tensor_type(GetTfliteDataType(tensor->type)); - paramValue->set_format(schema::Format::Format_NHWC); - const auto &data = tfliteModelBuffers.at(tensor->buffer)->data; - if (!data.empty()) { - auto size = data.size(); - char *tensor_data = new (std::nothrow) char[size]; - if (tensor_data == nullptr) { - MS_LOG(ERROR) << "new char[] failed"; - return RET_MEMORY_FAILED; - } - std::memcpy(tensor_data, data.data(), size); - paramValue->set_tensor_addr(tensor_data); - paramValue->set_tensor_size(size); - parameter->set_default_param(paramValue); - } - return RET_OK; -} - -STATUS ModelParserForTflite::ConvertOutputTensor(const tflite::OperatorT *op, CNodePtr dstCNode) { - MS_ASSERT(op != nullptr); - MS_ASSERT(dstCNode != nullptr); - const auto &tfliteSubgraph = tfliteModel->subgraphs.front(); - if (op->outputs.size() == 1) { - const auto &tensor = tfliteSubgraph->tensors.at(op->outputs.front()); - std::vector shape_vector; - (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), - [](const int32_t &value) { return static_cast(value); }); - auto typePtr = TypeIdToType(GetTfliteDataType(tensor->type)); - dstCNode->set_abstract(std::make_shared(typePtr, shape_vector)); - nodes.insert(std::pair(op->outputs.front(), dstCNode)); - } else { - AbstractBasePtrList abstractList; - for (auto outputIdx : op->outputs) { - const auto &tensor = tfliteSubgraph->tensors.at(outputIdx); - std::vector shape_vector; - (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), - [](const int32_t &value) { return static_cast(value); }); - auto typePtr = TypeIdToType(GetTfliteDataType(tensor->type)); - abstractList.emplace_back(std::make_shared(typePtr, shape_vector)); - auto tupleGetItemPrimPtr = GetTupleGetItemPrim(); - if (tupleGetItemPrimPtr == nullptr) { - MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr"; - return RET_NULL_PTR; - } - auto tupleGetItemPrim = NewValueNode(tupleGetItemPrimPtr); - auto getItemValue = NewValueNode(MakeValue(outputIdx)); - std::vector inputs{tupleGetItemPrim, dstCNode, getItemValue}; - CNodePtr getItemCNode = funcGraphPtr->NewCNode(inputs); - getItemCNode->set_fullname_with_scope(dstCNode->fullname_with_scope() + "_getitem_" + std::to_string(outputIdx)); - nodes.insert(std::pair(outputIdx, getItemCNode)); - } - dstCNode->set_abstract(std::make_shared(abstractList)); - } - return RET_OK; -} -} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/tflite/model_parser_for_tflite.h b/mindspore/lite/tools/converter/parser/tflite/model_parser_for_tflite.h deleted file mode 100644 index a223b2b36a..0000000000 --- a/mindspore/lite/tools/converter/parser/tflite/model_parser_for_tflite.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef LITE_MODEL_PARSER_FOR_TFLITE_H -#define LITE_MODEL_PARSER_FOR_TFLITE_H - -#include -#include -#include -#include "tools/converter/parser/tflite/tflite_model_parser.h" - -namespace mindspore::lite { -class ModelParserForTflite : public TfliteModelParser { - public: - ModelParserForTflite() = default; - - ~ModelParserForTflite() override = default; - - FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType) override; - - private: - std::unordered_map nodes; - std::unique_ptr tfliteModel; - FuncGraphPtr funcGraphPtr; - STATUS ConvertConstTensor(const tflite::TensorT *tensor, ParameterPtr parameter); - STATUS ConvertOutputTensor(const tflite::OperatorT *op, CNodePtr dstCNode); - STATUS ConvertOps(); - STATUS ConvertGraphInputs(); - STATUS ConvertGraphOutputs(); -}; -} // namespace mindspore::lite -#endif // LITE_MODEL_PARSER_FOR_TFLITE_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc index 0b2b264fbd..337fda4652 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc @@ -18,9 +18,11 @@ #include #include #include +#include "src/ops/activation.h" +#include "src/ops/primitive_c.h" +#include "tools/converter/parser/tflite/tflite_util.h" -namespace mindspore { -namespace lite { +namespace mindspore::lite { STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, @@ -86,12 +88,40 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info, return RET_OK; } -TfliteNodeRegister g_tfliteReluParser("Relu", new TfliteActivationParser()); -TfliteNodeRegister g_tfliteRelu6Parser("Relu6", new TfliteActivationParser()); -TfliteNodeRegister g_tfliteTanhParser("Tanh", new TfliteActivationParser()); -TfliteNodeRegister g_tfliteSwishParser("Swish", new TfliteActivationParser()); -TfliteNodeRegister g_tfliteHardSwishParser("HardSwish", new TfliteActivationParser()); +lite::PrimitiveC *TfliteActivationParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; + auto ms_op_type = GetMSOpType(tflite_op_type); + if (kActivationTypeMap.find(ms_op_type) == kActivationTypeMap.end()) { + MS_LOG(ERROR) << ms_op_type << "is a not supported activation type"; + return nullptr; + } + attr->type = kActivationTypeMap.find(GetMSOpType(tflite_op_type))->second; + if (attr->type == schema::ActivationType_LEAKY_RELU) { + const auto &tflite_attr = tflite_op->builtin_options.AsLeakyReluOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << GetMSOpType(tflite_op_type) << " attr failed"; + return nullptr; + } + attr->alpha = tflite_attr->alpha; + } + auto primitive = std::make_unique(); + primitive->value.type = schema::PrimitiveType_Activation; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} + +TfliteNodeRegister g_TfliteReluParser("ReLU", new TfliteActivationParser()); +TfliteNodeRegister g_TfliteRelu6Parser("ReLU6", new TfliteActivationParser()); +TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteActivationParser()); +TfliteNodeRegister g_TfliteSwishParser("Swish", new TfliteActivationParser()); +TfliteNodeRegister g_TfliteHardSwishParser("HSwish", new TfliteActivationParser()); TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteActivationParser()); -TfliteNodeRegister g_tfliteLeakyReluParser("LeakyRelu", new TfliteActivationParser()); -} // namespace lite -} // namespace mindspore +TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteActivationParser()); +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h index 6418678c60..c39f1dbb0b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h @@ -23,8 +23,7 @@ #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" -namespace mindspore { -namespace lite { +namespace mindspore::lite { class TfliteActivationParser : public TfliteNodeParser { public: TfliteActivationParser() : TfliteNodeParser("node_name") {} @@ -32,9 +31,10 @@ class TfliteActivationParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; -}; -} // namespace lite -} // namespace mindspore + lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; +}; +} // namespace mindspore::lite #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc index 6c4353b9eb..dcf38dbc01 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc @@ -18,9 +18,10 @@ #include "tools/converter/parser/tflite/tflite_addn_parser.h" #include #include +#include +#include "src/ops/addn.h" -namespace mindspore { -namespace lite { +namespace mindspore::lite { STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { @@ -55,7 +56,18 @@ STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +lite::PrimitiveC *TfliteAddNParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + auto primitive = std::make_unique(); + primitive->value.type = schema::PrimitiveType_AddN; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteAddNParser("AddN", new TfliteAddNParser()); -} // namespace lite -} // namespace mindspore +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h index 4417fe9862..dd137de5aa 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h @@ -32,6 +32,9 @@ class TfliteAddNParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + + lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc index 568a096401..223cfc0805 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc @@ -76,6 +76,39 @@ STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteArgmaxParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + const auto &tflite_subgraph = tflite_model->subgraphs.front(); + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + attr->outMaxValue = false; + attr->topK = 1; + attr->keepDims = false; + attr->axisType = 1; + + // get axis attr + auto axis_idx = tflite_op->inputs[1]; + auto buffer_idx = tflite_subgraph->tensors[axis_idx]->buffer; + auto &buf_data = tflite_model->buffers[buffer_idx]; + if (buf_data == nullptr) { + MS_LOG(ERROR) << "the buf data is null"; + return nullptr; + } + auto data_ptr = buf_data->data.data(); + if (data_ptr == nullptr) { + MS_LOG(ERROR) << "the data is null"; + return nullptr; + } + attr->axis = *(static_cast(static_cast(data_ptr))); + auto primitive = std::make_unique(); + primitive->value.type = schema::PrimitiveType_ArgMax; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteArgmaxParser("Argmax", new TfliteArgmaxParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h index 60038edf3e..50a4ca9354 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h @@ -32,6 +32,9 @@ class TfliteArgmaxParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc index b1d2846aa5..8374232050 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc @@ -76,6 +76,39 @@ STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteArgminParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + const auto &tflite_subgraph = tflite_model->subgraphs.front(); + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + attr->outMaxValue = false; + attr->topK = 1; + attr->keepDims = false; + attr->axisType = 1; + + // get axis attr + auto axis_idx = tflite_op->inputs[1]; + auto buffer_idx = tflite_subgraph->tensors[axis_idx]->buffer; + auto &buf_data = tflite_model->buffers[buffer_idx]; + if (buf_data == nullptr) { + MS_LOG(ERROR) << "the buf data is null"; + return nullptr; + } + auto data_ptr = buf_data->data.data(); + if (data_ptr == nullptr) { + MS_LOG(ERROR) << "the data is null"; + return nullptr; + } + attr->axis = *(static_cast(static_cast(data_ptr))); + auto primitive = std::make_unique(); + primitive->value.type = schema::PrimitiveType_ArgMin; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteArgminParser("Argmin", new TfliteArgminParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h index 0422c0232c..be2757aaf0 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h @@ -32,6 +32,8 @@ class TfliteArgminParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc index 0b37afe175..2d138b6269 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc @@ -179,6 +179,133 @@ STATUS TfliteDoubleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteDoubleInputOpParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; + auto primitive = std::make_unique(); + if (tflite_op_type == tflite::BuiltinOperator_ADD) { + MS_LOG(DEBUG) << "parse TfliteAddParser"; + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + const auto &tfliteAttr = tflite_op->builtin_options.AsAddOptions(); + if (nullptr == tfliteAttr) { + MS_LOG(ERROR) << "get op: " << tflite_op_type << " attr failed"; + return nullptr; + } + attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); + primitive->value.type = schema::PrimitiveType_Add; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_SUB) { + MS_LOG(DEBUG) << "parse TfliteSubParser"; + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + const auto &tfliteAttr = tflite_op->builtin_options.AsSubOptions(); + if (nullptr == tfliteAttr) { + MS_LOG(ERROR) << "get op: " << tflite_op_type << " attr failed"; + return nullptr; + } + attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); + primitive->value.type = schema::PrimitiveType_Sub; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_MUL) { + MS_LOG(DEBUG) << "parse TfliteMulParser"; + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + const auto &tfliteAttr = tflite_op->builtin_options.AsMulOptions(); + if (nullptr == tfliteAttr) { + MS_LOG(ERROR) << "get op: " << tflite_op_type << " attr failed"; + return nullptr; + } + attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); + primitive->value.type = schema::PrimitiveType_Mul; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_DIV) { + MS_LOG(DEBUG) << "parse TfliteDivParser"; + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + const auto &tfliteAttr = tflite_op->builtin_options.AsDivOptions(); + if (nullptr == tfliteAttr) { + MS_LOG(ERROR) << "get op: " << tflite_op_type << " attr failed"; + return nullptr; + } + attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); + primitive->value.type = schema::PrimitiveType_Div; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_FLOOR_DIV) { + MS_LOG(DEBUG) << "parse TfliteFloorDivParser"; + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_FloorDiv; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_FLOOR_MOD) { + MS_LOG(DEBUG) << "parse TfliteFloorModParser"; + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_FloorMod; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_SQUARED_DIFFERENCE) { + MS_LOG(DEBUG) << "parse TfliteSquaredDifferenceParser"; + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_SquaredDifference; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_POW) { + MS_LOG(DEBUG) << "parse TflitePowParser"; + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + attr->power = 1.0f; + attr->scale = 1.0f; + attr->shift = 0.0f; + primitive->value.type = schema::PrimitiveType_Power; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_MAXIMUM) { + MS_LOG(DEBUG) << "parse TfliteMaximumParser"; + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Maximum; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_MINIMUM) { + MS_LOG(DEBUG) << "parse TfliteMinimumParser"; + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Minimum; + primitive->value.value = attr.release(); + } else { + MS_LOG(ERROR) << "op hasn't been supported"; + return nullptr; + } + return PrimitiveC::Create(primitive.release()); +} STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, @@ -320,6 +447,124 @@ STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteSingleInputOpParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; + auto primitive = std::make_unique(); + if (tflite_op_type == tflite::BuiltinOperator_ABS) { + MS_LOG(DEBUG) << "parse TfliteAbsParser"; + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Abs; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_EXP) { + MS_LOG(DEBUG) << "parse TfliteExpParser"; + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + attr->base = -1; // -1 represent base = e + attr->scale = 1; + attr->shift = 0; + primitive->value.type = schema::PrimitiveType_Exp; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_SQRT) { + MS_LOG(DEBUG) << "parse TfliteSqrtParser"; + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Sqrt; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_RSQRT) { + MS_LOG(DEBUG) << "parse TfliteRsqrtParser"; + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Rsqrt; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_SQUARE) { + MS_LOG(DEBUG) << "parse TfliteSquareParser"; + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Square; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_SIN) { + MS_LOG(DEBUG) << "parse TfliteSinParser"; + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Sin; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_COS) { + MS_LOG(DEBUG) << "parse TfliteCosParser"; + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Cos; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_LOG) { + MS_LOG(DEBUG) << "parse TfliteLogParser"; + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Log; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_ROUND) { + MS_LOG(DEBUG) << "parse TfliteRoundParser"; + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Round; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_CEIL) { + MS_LOG(DEBUG) << "parse TfliteCeilParser"; + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Ceil; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_FLOOR) { + MS_LOG(DEBUG) << "parse TfliteFloorParser"; + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Floor; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_NEG) { + MS_LOG(DEBUG) << "parse TfliteNegParser"; + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Neg; + primitive->value.value = attr.release(); + } + return PrimitiveC::Create(primitive.release()); +} STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, @@ -406,29 +651,91 @@ STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info, AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteCompareOpParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; + auto primitive = std::make_unique(); + + if (tflite_op_type == tflite::BuiltinOperator_EQUAL) { + MS_LOG(DEBUG) << "parse TfliteEqualParser"; + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Equal; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_NOT_EQUAL) { + MS_LOG(DEBUG) << "parse TfliteNotEqualParser"; + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_NotEqual; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_GREATER) { + MS_LOG(DEBUG) << "parse TfliteGreaterParser"; + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Greater; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_GREATER_EQUAL) { + MS_LOG(DEBUG) << "parse TfliteGreaterEqualParser"; + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_GreaterEqual; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_LESS) { + MS_LOG(DEBUG) << "parse TfliteLessParser"; + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Less; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_LESS_EQUAL) { + MS_LOG(DEBUG) << "parse TfliteLessEqualParser"; + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_LessEqual; + primitive->value.value = attr.release(); + } + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteAddParser("Add", new TfliteDoubleInputOpParser()); TfliteNodeRegister g_tfliteSubParser("Sub", new TfliteDoubleInputOpParser()); -TfliteNodeRegister g_tfliteMulParser("Mul", new TfliteDoubleInputOpParser()); -TfliteNodeRegister g_tfliteDivParser("Div", new TfliteDoubleInputOpParser()); +TfliteNodeRegister g_TfliteMulParser("Mul", new TfliteDoubleInputOpParser()); +TfliteNodeRegister g_TfliteDivParser("Div", new TfliteDoubleInputOpParser()); TfliteNodeRegister g_tfliteFloorDivParser("FloorDiv", new TfliteDoubleInputOpParser()); TfliteNodeRegister g_tfliteFloorModParser("FloorMod", new TfliteDoubleInputOpParser()); TfliteNodeRegister g_tfliteRealDivParser("RealDiv", new TfliteDoubleInputOpParser()); -TfliteNodeRegister g_tflitePowParser("Pow", new TfliteDoubleInputOpParser()); +TfliteNodeRegister g_TflitePowParser("Pow", new TfliteDoubleInputOpParser()); TfliteNodeRegister g_tfliteSquaredDifferenceParser("SquaredDifference", new TfliteDoubleInputOpParser()); -TfliteNodeRegister g_tfliteMaximumParser("Maximum", new TfliteDoubleInputOpParser()); -TfliteNodeRegister g_tfliteMinimumParser("Minimum", new TfliteDoubleInputOpParser()); +TfliteNodeRegister g_TfliteMaximumParser("Maximum", new TfliteDoubleInputOpParser()); +TfliteNodeRegister g_TfliteMinimumParser("Minimum", new TfliteDoubleInputOpParser()); -TfliteNodeRegister g_tfliteAbsParser("Abs", new TfliteSingleInputOpParser()); -TfliteNodeRegister g_tfliteExpParser("Exp", new TfliteSingleInputOpParser()); -TfliteNodeRegister g_tfliteSqrtParser("Sqrt", new TfliteSingleInputOpParser()); +TfliteNodeRegister g_TfliteAbsParser("Abs", new TfliteSingleInputOpParser()); +TfliteNodeRegister g_TfliteExpParser("Exp", new TfliteSingleInputOpParser()); +TfliteNodeRegister g_TfliteSqrtParser("Sqrt", new TfliteSingleInputOpParser()); TfliteNodeRegister g_tfliteRsqrtParser("Rsqrt", new TfliteSingleInputOpParser()); -TfliteNodeRegister g_tfliteSquareParser("Square", new TfliteSingleInputOpParser()); -TfliteNodeRegister g_tfliteSinParser("Sin", new TfliteSingleInputOpParser()); -TfliteNodeRegister g_tfliteCosParser("Cos", new TfliteSingleInputOpParser()); -TfliteNodeRegister g_tfliteLogParser("Log", new TfliteSingleInputOpParser()); +TfliteNodeRegister g_TfliteSquareParser("Square", new TfliteSingleInputOpParser()); +TfliteNodeRegister g_TfliteSinParser("Sin", new TfliteSingleInputOpParser()); +TfliteNodeRegister g_TfliteCosParser("Cos", new TfliteSingleInputOpParser()); +TfliteNodeRegister g_TfliteLogParser("Log", new TfliteSingleInputOpParser()); TfliteNodeRegister g_tfliteRoundParser("Round", new TfliteSingleInputOpParser()); -TfliteNodeRegister g_tfliteCeilParser("Ceil", new TfliteSingleInputOpParser()); +TfliteNodeRegister g_TfliteCeilParser("Ceil", new TfliteSingleInputOpParser()); TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteSingleInputOpParser()); TfliteNodeRegister g_tfliteNegParser("Neg", new TfliteSingleInputOpParser()); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h index b99b5fbeb9..65cbc7c91c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h @@ -32,6 +32,9 @@ class TfliteDoubleInputOpParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; class TfliteSingleInputOpParser : public TfliteNodeParser { @@ -41,6 +44,9 @@ class TfliteSingleInputOpParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; class TfliteCompareOpParser : public TfliteNodeParser { @@ -50,7 +56,11 @@ class TfliteCompareOpParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc index 69a01c54ad..1f2f1cc041 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc @@ -74,6 +74,29 @@ STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteBatchToSpaceParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + const auto &tflite_subgraph = tflite_model->subgraphs.front(); + auto primitive = std::make_unique(); + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->blockShape)) { + MS_LOG(ERROR) << "get batchToSpace -> blockShape failed"; + return nullptr; + } + if (GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->crops)) { + MS_LOG(ERROR) << "get batchToSpace -> crops failed"; + return nullptr; + } + + primitive->value.type = schema::PrimitiveType_BatchToSpace; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteBatchToSpaceParser("BatchToSpace", new TfliteBatchToSpaceParser()); TfliteNodeRegister g_tfliteBatchToSpaceNDParser("BatchToSpaceND", new TfliteBatchToSpaceParser()); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h index 71d32f2531..766d5798d4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h @@ -32,7 +32,10 @@ class TfliteBatchToSpaceParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc index 0c49c600a7..02b29d1eda 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc @@ -57,6 +57,28 @@ STATUS TfliteBroadcastToParser::Parse(TfliteTensorsInfo *tensors_info, AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteBroadcastToParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto &tflite_subgraph = tflite_model->subgraphs.front(); + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return nullptr; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->dst_shape)) { + MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_BroadcastTo; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteBroadcastToParser("BroadcastTo", new TfliteBroadcastToParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h index fe72058c9f..6652744337 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h @@ -23,8 +23,7 @@ #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" -namespace mindspore { -namespace lite { +namespace mindspore::lite { class TfliteBroadcastToParser : public TfliteNodeParser { public: TfliteBroadcastToParser() : TfliteNodeParser("BroadcastTo") {} @@ -32,8 +31,10 @@ class TfliteBroadcastToParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; -} // namespace lite -} // namespace mindspore +} // namespace mindspore::lite #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc index cd561c1e81..28d630ad94 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc @@ -63,6 +63,32 @@ STATUS TfliteCastParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteCastParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto &tflite_subgraph = tflite_model->subgraphs.front(); + auto primitive = std::make_unique(); + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; + if (in_tensor == nullptr) { + MS_LOG(ERROR) << "tensor is null"; + return nullptr; + } + attr->srcT = GetTfliteDataType(in_tensor->type); + const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]]; + if (out_tensor == nullptr) { + MS_LOG(ERROR) << "tensor is null"; + return nullptr; + } + attr->dstT = GetTfliteDataType(out_tensor->type); + primitive->value.type = schema::PrimitiveType_Cast; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteCastParser("Cast", new TfliteCastParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h index 8f9dd06906..ae71a4bb8d 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h @@ -32,6 +32,8 @@ class TfliteCastParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc index 8b9027478f..9a3b51ca49 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc @@ -60,6 +60,26 @@ STATUS TfliteConcatParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteConcatParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto primitive = std::make_unique(); + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &tfliteAttr = tflite_op->builtin_options.AsConcatenationOptions(); + if (tfliteAttr == nullptr) { + MS_LOG(ERROR) << "get op concat attr failed"; + return nullptr; + } + attr->axis = tfliteAttr->axis; + attr->n = tflite_op->inputs.size(); + primitive->value.type = schema::PrimitiveType_Concat; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteConcatParser("Concat", new TfliteConcatParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h index 647d0a3918..e99cb1ac87 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h @@ -32,6 +32,8 @@ class TfliteConcatParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc index 4345f02732..1fec27f723 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc @@ -18,8 +18,7 @@ #include #include -namespace mindspore { -namespace lite { +namespace mindspore::lite { STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { @@ -74,7 +73,7 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu // calculate pad params auto data_index = tflite_op->inputs[0]; const auto &data_tensor = tflite_subgraph->tensors[data_index]; - std::vector params; + std::vector params; int status = getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); if (status != RET_OK && status != RET_NO_CHANGE) { @@ -96,7 +95,63 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +lite::PrimitiveC *TfliteConvParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + const auto &tflite_subgraph = tflite_model->subgraphs.front(); + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsConv2DOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get conv attr failed"; + return nullptr; + } + attr->group = 1; + attr->strideW = tflite_attr->stride_w; + attr->strideH = tflite_attr->stride_h; + attr->dilateH = tflite_attr->dilation_h_factor; + attr->dilateW = tflite_attr->dilation_w_factor; + attr->padMode = GetPadMode(tflite_attr->padding); + attr->format = schema::Format::Format_NHWC; + attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); + attr->hasBias = true; + + // get the conv op weight tensor + auto weight_index = tflite_op->inputs[1]; + const auto &weight_tensor = tflite_subgraph->tensors[weight_index]; + if (weight_tensor == nullptr) { + MS_LOG(ERROR) << "the weight tensor is null"; + return nullptr; + } + auto weight_shape = weight_tensor->shape; + attr->channelIn = weight_shape[3]; + attr->channelOut = weight_shape[0]; + attr->kernelH = weight_shape[1]; + attr->kernelW = weight_shape[2]; + + // calculate pad params + auto data_index = tflite_op->inputs[0]; + const auto &data_tensor = tflite_subgraph->tensors[data_index]; + std::vector params; + int status = + getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "get padding params failed"; + return nullptr; + } else if (status == RET_OK) { + attr->padUp = params.at(0); + attr->padDown = params.at(1); + attr->padLeft = params.at(2); + attr->padRight = params.at(3); + } + auto primitive = std::make_unique(); + primitive->value.type = schema::PrimitiveType_Conv2D; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteConv2DParser("Conv2D", new TfliteConvParser()); -} // namespace lite -} // namespace mindspore +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h index 308250edeb..8c98656fc9 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h @@ -23,8 +23,7 @@ #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" -namespace mindspore { -namespace lite { +namespace mindspore::lite { class TfliteConvParser : public TfliteNodeParser { public: TfliteConvParser() : TfliteNodeParser("Conv2D") {} @@ -32,8 +31,9 @@ class TfliteConvParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; -} // namespace lite -} // namespace mindspore +} // namespace mindspore::lite #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_converter.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_converter.cc index 24b6ec0783..34f28d7306 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_converter.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_converter.cc @@ -15,9 +15,8 @@ */ #include "tools/converter/parser/tflite/tflite_converter.h" +#include "tools/converter/parser/tflite/tflite_model_parser.h" -namespace mindspore { -namespace lite { +namespace mindspore::lite { TfliteConverter::TfliteConverter() { modelParser = new TfliteModelParser(); } -} // namespace lite -} // namespace mindspore +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h b/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h index 9bc53fb955..eba2150e7c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h @@ -21,18 +21,15 @@ #include #include #include "tools/converter/converter.h" -#include "tools/converter/parser/tflite/tflite_model_parser.h" #include "tools/converter/graphdef_transform.h" -namespace mindspore { -namespace lite { +namespace mindspore::lite { class TfliteConverter : public Converter { public: TfliteConverter(); ~TfliteConverter() override = default; }; -} // namespace lite -} // namespace mindspore +} // namespace mindspore::lite #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONVERTER_H_ diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc index ba22e6b79f..60aed771a7 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc @@ -271,6 +271,48 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni } return status; } +PrimitiveC *TfliteCustomParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto &tflite_subgraph = tflite_model->subgraphs.front(); + auto op = new schema::CNodeT; + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return nullptr; + } + const auto &custom_attr = tflite_op->custom_options; + const auto &opcode_index = tflite_op->opcode_index; + const auto &custom_type = tflite_model->operator_codes[opcode_index]->custom_code; + int status = RET_OK; + if (custom_type == "TFLite_Detection_PostProcess") { + status = DetectPostProcess(custom_attr, op, tflite_op); + } else if (custom_type == "Predict") { + status = Predict(custom_attr, op, tflite_op); + } else if (custom_type == "Normalize") { + status = Normalize(custom_attr, op, tflite_op); + } else if (custom_type == "ExtractFeatures") { + status = ExtractFeatures(custom_attr, op, tflite_op); + } else if (custom_type == "AudioSpectrogram") { + status = AudioSpectrogram(custom_attr, op, tflite_op); + } else if (custom_type == "Mfcc") { + status = Mfcc(custom_attr, op, tflite_op); + } else if (custom_type == "FlexRFFT") { + status = Rfft(custom_attr, op, tflite_op, tflite_model, tflite_subgraph); + } else if (custom_type == "FlexReal") { + status = FftReal(custom_attr, op, tflite_op); + } else if (custom_type == "FlexImag") { + status = FftImag(custom_attr, op, tflite_op); + } else { + MS_LOG(ERROR) << "the custom op hasn't been supported now"; + status = RET_NOT_FIND_OP; + } + if (status != RET_OK) { + return nullptr; + } + auto primitive = op->primitive.release(); + delete op; + return PrimitiveC::Create(primitive); +} TfliteNodeRegister g_tfliteCustomParser("Custom", new TfliteCustomParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.h index b1ad78dee5..aa10234c15 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.h @@ -31,6 +31,8 @@ class TfliteCustomParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; static STATUS DetectPostProcess(const std::vector &custom_attr, schema::CNodeT *op, const std::unique_ptr &tflite_op); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc index 703be87fcc..66be277d76 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc @@ -75,7 +75,7 @@ STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni // calculate pad params auto data_index = tflite_op->inputs[2]; const auto &data_tensor = tflite_subgraph->tensors[data_index]; - std::vector params; + std::vector params; int status = getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); if (status != RET_OK && status != RET_NO_CHANGE) { @@ -96,6 +96,64 @@ STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteDeConvParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto primitive = std::make_unique(); + auto &tflite_subgraph = tflite_model->subgraphs.front(); + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsTransposeConvOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op deconv attr failed"; + return nullptr; + } + + attr->group = 1; + attr->strideW = tflite_attr->stride_w; + attr->strideH = tflite_attr->stride_h; + attr->dilateH = 1; + attr->dilateW = 1; + attr->padMode = GetPadMode(tflite_attr->padding); + attr->format = schema::Format::Format_NHWC; + attr->activationType = schema::ActivationType_NO_ACTIVATION; + attr->hasBias = true; + + // get the conv op weight tensor + auto weight_index = tflite_op->inputs[1]; + const auto &weight_tensor = tflite_subgraph->tensors[weight_index]; + if (weight_tensor == nullptr) { + MS_LOG(ERROR) << "the weight tensor is null"; + return nullptr; + } + auto weight_shape = weight_tensor->shape; + attr->channelIn = weight_shape[3]; + attr->channelOut = weight_shape[0]; + attr->kernelH = weight_shape[1]; + attr->kernelW = weight_shape[2]; + + // calculate pad params + auto data_index = tflite_op->inputs[2]; + const auto &data_tensor = tflite_subgraph->tensors[data_index]; + std::vector params; + int status = + getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "get padding params failed"; + return nullptr; + } else if (status == RET_OK) { + attr->padUp = params.at(0); + attr->padDown = params.at(1); + attr->padLeft = params.at(2); + attr->padRight = params.at(3); + } + primitive->value.type = schema::PrimitiveType_DeConv2D; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteDeConv2DParser("DeConv2D", new TfliteDeConvParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h index a52e89b7aa..5f94fb11e1 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h @@ -32,6 +32,8 @@ class TfliteDeConvParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc index cb08fede98..4767c62225 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc @@ -60,6 +60,26 @@ STATUS TfliteDepthToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteDepthToSpaceParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsDepthToSpaceOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op depthtospace attr failed"; + return nullptr; + } + attr->blockSize = tflite_attr->block_size; + attr->format = schema::Format::Format_NHWC; + auto primitive = std::make_unique(); + primitive->value.type = schema::PrimitiveType_Concat; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteDepthToSpaceParser("DepthToSpace", new TfliteDepthToSpaceParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h index ae303db657..58be13f27d 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h @@ -32,6 +32,8 @@ class TfliteDepthToSpaceParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc index 27ae6c6b7e..d733b8aeec 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc @@ -18,8 +18,7 @@ #include #include -namespace mindspore { -namespace lite { +namespace mindspore::lite { STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, @@ -82,7 +81,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info, attr->kernelW = weight_shape[2]; // calculate pad params - std::vector params; + std::vector params; int status = getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); if (status != RET_OK && status != RET_NO_CHANGE) { @@ -104,7 +103,71 @@ STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info, AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +lite::PrimitiveC *TfliteDepthwiseConv2DParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser"; + std::unique_ptr attr = std::make_unique(); + const auto &tflite_subgraph = tflite_model->subgraphs.front(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsDepthwiseConv2DOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op de attr failed"; + return nullptr; + } + attr->strideW = tflite_attr->stride_w; + attr->strideH = tflite_attr->stride_h; + attr->dilateH = tflite_attr->dilation_h_factor; + attr->dilateW = tflite_attr->dilation_w_factor; + attr->padMode = GetPadMode(tflite_attr->padding); + attr->format = schema::Format::Format_NHWC; + attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); + attr->hasBias = true; + attr->channelMultiplier = tflite_attr->depth_multiplier; + + // get the data tensor + auto data_index = tflite_op->inputs[1]; + const auto &data_tensor = tflite_subgraph->tensors[data_index]; + if (data_tensor == nullptr) { + MS_LOG(ERROR) << "the data tensor is null"; + return nullptr; + } + auto data_shape = data_tensor->shape; + attr->channelIn = data_shape[3]; + + // get the weight tensor + auto weight_index = tflite_op->inputs[1]; + const auto &weight_tensor = tflite_subgraph->tensors[weight_index]; + if (weight_tensor == nullptr) { + MS_LOG(ERROR) << "the weight tensor is null"; + return nullptr; + } + auto weight_shape = weight_tensor->shape; + attr->kernelH = weight_shape[1]; + attr->kernelW = weight_shape[2]; + + // calculate pad params + std::vector params; + int status = + getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "get padding params failed"; + return nullptr; + } else if (status == RET_OK) { + attr->padUp = params.at(0); + attr->padDown = params.at(1); + attr->padLeft = params.at(2); + attr->padRight = params.at(3); + } + + auto primitive = std::make_unique(); + primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteDepthwiseConv2DParser("DepthwiseConv2D", new TfliteDepthwiseConv2DParser()); -} // namespace lite -} // namespace mindspore +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h index 22885dc466..1357ed5b97 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h @@ -23,8 +23,7 @@ #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" -namespace mindspore { -namespace lite { +namespace mindspore::lite { class TfliteDepthwiseConv2DParser : public TfliteNodeParser { public: TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {} @@ -32,8 +31,10 @@ class TfliteDepthwiseConv2DParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + + lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; -} // namespace lite -} // namespace mindspore +} // namespace mindspore::lite #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc index f35b26f631..bdca34405f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc @@ -75,6 +75,45 @@ STATUS TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info, AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteDequantizeParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto &tflite_subgraph = tflite_model->subgraphs.front(); + auto primitive = std::make_unique(); + const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; + if (in_tensor == nullptr) { + MS_LOG(ERROR) << "input tensor is null"; + return nullptr; + } + const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]]; + if (out_tensor == nullptr) { + MS_LOG(ERROR) << "output tensor is null"; + return nullptr; + } + if (GetTfliteDataType(in_tensor->type) != GetTfliteDataType(out_tensor->type) && + (GetTfliteDataType(in_tensor->type) == kNumberTypeInt8 || + GetTfliteDataType(in_tensor->type) == kNumberTypeUInt8)) { + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + attr->srcT = GetTfliteDataType(in_tensor->type); + attr->dstT = GetTfliteDataType(out_tensor->type); + primitive->value.value = attr.release(); + primitive->value.type = schema::PrimitiveType_QuantDTypeCast; + } else { + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + attr->srcT = GetTfliteDataType(in_tensor->type); + attr->dstT = GetTfliteDataType(out_tensor->type); + primitive->value.value = attr.release(); + primitive->value.type = schema::PrimitiveType_Cast; + } + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteDequantizeParser("DEQUANTIZE", new TfliteDequantizeParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h index 61b3d0f25c..d9944928d6 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h @@ -31,6 +31,9 @@ class TfliteDequantizeParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc index c0304af882..adcab9769d 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc @@ -56,6 +56,30 @@ STATUS TfliteExpandDimsParser::Parse(TfliteTensorsInfo *tensors_info, AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteExpandDimsParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto &tflite_subgraph = tflite_model->subgraphs.front(); + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + std::vector dims; + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, dims)) { + MS_LOG(ERROR) << "get expand_dims -> dim failed"; + return nullptr; + } + attr->dim = dims[0]; + primitive->value.type = schema::PrimitiveType_ExpandDims; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteExpandDimsParser("ExpandDims", new TfliteExpandDimsParser()); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h index f4f2e6c551..7169bb8c6c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h @@ -32,6 +32,8 @@ class TfliteExpandDimsParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc index a27e3f63db..1aa64ca07f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc @@ -57,6 +57,32 @@ STATUS TfliteFillParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteFillParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto &tflite_subgraph = tflite_model->subgraphs.front(); + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + if (tflite_op->inputs.size() > 1) { + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->dims)) { + MS_LOG(ERROR) << "get fill -> dims failed"; + return nullptr; + } + } + + primitive->value.type = schema::PrimitiveType_Fill; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteFillParser("Fill", new TfliteFillParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.h index 9703db3959..ce4fb12fcd 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.h @@ -32,6 +32,9 @@ class TfliteFillParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc index 92543ac5c3..ebca042418 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc @@ -69,6 +69,37 @@ STATUS TfliteFullyConnectedParser::Parse(TfliteTensorsInfo *tensors_info, AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteFullyConnectedParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsFullyConnectedOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op fully connect attr failed"; + return nullptr; + } + + bool hasBias = tflite_op->inputs.size() > 2 && tflite_op->inputs[2] != -1; + + attr->hasBias = hasBias; + attr->axis = 1; + attr->useAxis = false; + attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); + + primitive->value.type = schema::PrimitiveType_FullConnection; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteFullyConnectedParser("FullyConnected", new TfliteFullyConnectedParser()); TfliteNodeRegister g_tfliteFakeQuantParser("FakeQuant", new TfliteFullyConnectedParser()); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h index d81fe6a4e8..af8caa994c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h @@ -32,7 +32,10 @@ class TfliteFullyConnectedParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc index b74a2205e7..ab1ada4aa7 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc @@ -54,6 +54,26 @@ STATUS TfliteGatherNdParser::Parse(TfliteTensorsInfo *tensors_info, const std::u AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteGatherNdParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + attr->batchDims = 0; + + primitive->value.type = schema::PrimitiveType_GatherNd; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteGatherNdParser("GatherND", new TfliteGatherNdParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.h index 6c3bb2a77e..095220900c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.h @@ -32,6 +32,8 @@ class TfliteGatherNdParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc index 5b00acb82d..a55bce231e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc @@ -60,6 +60,32 @@ STATUS TfliteGatherParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteGatherParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsGatherOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op gather attr failed"; + return nullptr; + } + attr->axis = tflite_attr->axis; + attr->batchDims = 0; + + primitive->value.type = schema::PrimitiveType_Gather; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteGatherParser("Gather", new TfliteGatherParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.h index 30e06be447..427f01c167 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.h @@ -32,6 +32,8 @@ class TfliteGatherParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.cc index c57e6cac58..a98a7dae88 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.cc @@ -55,6 +55,24 @@ STATUS TfliteHashtableLookupParser::Parse(TfliteTensorsInfo *tensors_info, } return RET_OK; } +PrimitiveC *TfliteHashtableLookupParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + primitive->value.type = schema::PrimitiveType_HashtableLookup; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteHashtableLookupParser("HashtableLookup", new TfliteHashtableLookupParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.h index d23157d7ad..1af5fd118c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.h @@ -32,6 +32,8 @@ class TfliteHashtableLookupParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc index 617b606423..b9ff61a67c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc @@ -55,6 +55,27 @@ STATUS TfliteL2NormParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteL2NormParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + const auto &tflite_attr = tflite_op->builtin_options.AsL2NormOptions(); + attr->axis = {-1}; + attr->epsilon = 1e-6f; + attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_L2Norm; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteL2NormParser("L2_NORMALIZATION", new TfliteL2NormParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.h index 4a929d163b..c1b897d5b1 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.h @@ -32,6 +32,8 @@ class TfliteL2NormParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.cc index 1d417f0020..3009a9d7fe 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.cc @@ -78,6 +78,44 @@ STATUS TfliteLogicalParser::Parse(TfliteTensorsInfo *tensors_info, const std::un AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteLogicalParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; + if (tflite_op_type == tflite::BuiltinOperator_LOGICAL_AND) { + MS_LOG(DEBUG) << "parse TfliteLogicalAndParser"; + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_LogicalAnd; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_LOGICAL_NOT) { + MS_LOG(DEBUG) << "parse TfliteLogicalNotParser"; + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_LogicalNot; + primitive->value.value = attr.release(); + } else if (tflite_op_type == tflite::BuiltinOperator_LOGICAL_OR) { + MS_LOG(DEBUG) << "parse TfliteLogicalOrParser"; + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_LogicalOr; + primitive->value.value = attr.release(); + } + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteLogicalAndParser("LogicalAnd", new TfliteLogicalParser()); TfliteNodeRegister g_tfliteLogicalNotParser("LogicalNot", new TfliteLogicalParser()); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.h index 6740a775b6..8b60184581 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.h @@ -23,8 +23,7 @@ #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" -namespace mindspore { -namespace lite { +namespace mindspore::lite { class TfliteLogicalParser : public TfliteNodeParser { public: TfliteLogicalParser() : TfliteNodeParser("node_name") {} @@ -32,8 +31,9 @@ class TfliteLogicalParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; -} // namespace lite -} // namespace mindspore +} // namespace mindspore::lite #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LOGICAL_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.cc index b0fe0909c7..82d63c30dd 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.cc @@ -60,6 +60,34 @@ STATUS TfliteLRNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteLRNParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsLocalResponseNormalizationOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op LRN attr failed"; + return nullptr; + } + attr->depth_radius = tflite_attr->radius; + attr->alpha = tflite_attr->alpha; + attr->beta = tflite_attr->beta; + attr->bias = tflite_attr->bias; + + primitive->value.type = schema::PrimitiveType_LocalResponseNormalization; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteLRNParser("LocalResponseNorm", new TfliteLRNParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.h index 575aaa1fca..869c955116 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.h @@ -32,6 +32,8 @@ class TfliteLRNParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.cc index 5d08a82d20..103d1a94f2 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.cc @@ -64,6 +64,35 @@ STATUS TfliteLshProjectionParser::Parse(TfliteTensorsInfo *tensors_info, AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteLshProjectionParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsLSHProjectionOptions(); + switch (tflite_attr->type) { + case tflite::LSHProjectionType_SPARSE: + attr->type = schema::LshProjectionType_SPARSE; + break; + case tflite::LSHProjectionType_DENSE: + attr->type = schema::LshProjectionType_DENSE; + break; + default: + attr->type = schema::LshProjectionType_UNKNOWN; + } + primitive->value.type = schema::PrimitiveType_LshProjection; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteLshProjectionParser("LshProjection", new TfliteLshProjectionParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.h index c452e94b8d..359ea23002 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.h @@ -32,6 +32,8 @@ class TfliteLshProjectionParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index 177701e84d..b9ed34bad0 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,79 +13,167 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #include "tools/converter/parser/tflite/tflite_model_parser.h" -#include -#include +#include #include -#include -#include "tools/common/graph_util.h" -#include "tools/common/storage.h" -#include "flatbuffers/flatbuffers.h" +#include +#include +#include +#include "src/param_value_lite.h" #include "src/common/file_utils.h" -#include "tools/common/node_util.h" - -namespace mindspore { -namespace lite { -TfliteModelParser::TfliteModelParser() = default; -TfliteModelParser::~TfliteModelParser() { delete[](this->tfliteModelBuf); } +namespace mindspore::lite { std::unique_ptr TfliteModelParser::ReadTfliteModel(const char *model_path) { size_t size = 0; - tfliteModelBuf = ReadFile(model_path, &size); - if (tfliteModelBuf == nullptr) { + tflite_model_buf_ = ReadFile(model_path, &size); + if (tflite_model_buf_ == nullptr) { MS_LOG(ERROR) << "the file buffer is nullptr"; return nullptr; } - flatbuffers::Verifier verify((const uint8_t *)tfliteModelBuf, size); + flatbuffers::Verifier verify((const uint8_t *)tflite_model_buf_, size); if (!tflite::VerifyModelBuffer(verify)) { MS_LOG(ERROR) << "the buffer is invalid and fail to create graph"; return nullptr; } - return tflite::UnPackModel(tfliteModelBuf); + return tflite::UnPackModel(tflite_model_buf_); } -STATUS TfliteModelParser::CopyConstTensorData(const std::vector> &tflite_model_buffer, - const tflite::TensorT *tflite_tensor, schema::TensorT *tensor) { - MS_ASSERT(tensor != nullptr); - MS_ASSERT(tflite_tensor != nullptr); - auto buffer_idx = tflite_tensor->buffer; +FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std::string &weight_file, + const QuantType &quant_type) { + // load graph + tflite_model_ = ReadTfliteModel(model_file.c_str()); + if (tflite_model_ == nullptr) { + MS_LOG(ERROR) << "read tflite model failed"; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); + return nullptr; + } - const auto &buf = tflite_model_buffer[buffer_idx]; - if (buf == nullptr) { - MS_LOG(ERROR) << "tensor is null"; - return RET_NULL_PTR; + if (tflite_model_->subgraphs.size() != 1) { + MS_LOG(ERROR) << "read tflite model subgraphs failed"; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); + return nullptr; } + func_graph_ = std::make_shared(); - if (!buf->data.empty()) { - auto data_size = buf->data.size(); - tensor->data.resize(data_size); - if (memcpy_s(tensor->data.data(), data_size, buf->data.data(), data_size) != EOK) { - MS_LOG(ERROR) << "memcpy tensor data failed"; - return RET_MEMORY_FAILED; - } - } else { - MS_LOG(ERROR) << "src tensor data is empty"; - return RET_INPUT_TENSOR_ERROR; + auto status = ConvertGraphInputs(); + if (status != RET_OK) { + MS_LOG(ERROR) << "Convert graph inputs failed."; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + return nullptr; } - return RET_OK; + + status = ConvertOps(); + if (status != RET_OK) { + MS_LOG(ERROR) << "Convert ops failed."; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + return nullptr; + } + + status = ConvertGraphOutputs(); + if (status != RET_OK) { + MS_LOG(ERROR) << "Convert graph outputs failed."; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + return nullptr; + } + return func_graph_; } -void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr &tflite_tensor, - schema::TensorT *tensor) { - MS_ASSERT(tensor != nullptr); - tensor->quantParams.clear(); +STATUS TfliteModelParser::ConvertOps() { + const auto &tflite_subgraph = tflite_model_->subgraphs.front(); + const auto &tflite_model_buffers = tflite_model_->buffers; + NoSupportOp::GetInstance()->SetFmkType("TFLITE"); + STATUS status = RET_OK; + int op_idx = 0; + for (auto &op : tflite_subgraph->operators) { + auto tfliteOpType = (tflite_model_->operator_codes[op->opcode_index])->builtin_code; + auto op_type = GetMSOpType(tfliteOpType); + auto op_name = op_type + "-" + std::to_string(op_idx); + op_idx++; + // parse primitive + auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(op_type); + if (node_parser == nullptr) { + NoSupportOp::GetInstance()->InsertOp(op_type); + status = (status == RET_OK ? RET_NOT_FIND_OP : status); + continue; + } + + if (status != RET_OK) { + continue; + } + + auto primitiveC = node_parser->ParseLitePrimitive(op, tflite_model_); + if (primitiveC == nullptr) { + MS_LOG(ERROR) << "parse node " << op_type.c_str() << " parser failed"; + continue; + } + + status = ConvertOpQuantParams(op.get(), primitiveC); + if (status != RET_OK) { + MS_LOG(ERROR) << "convert " << op_name << " quant param failed."; + return status; + } - if (tflite_tensor->quantization == nullptr) { - MS_LOG(ERROR) << "tflite_tensor->quantization is null"; - return; + std::vector op_inputs = {NewValueNode(std::shared_ptr(primitiveC))}; + // parse inputs + for (auto input_idx : op->inputs) { + if (input_idx < 0) { + input_idx += tflite_subgraph->tensors.size(); + } + const auto &input_tensor = tflite_subgraph->tensors[input_idx]; + if (nodes_.find(input_idx) != nodes_.end()) { + op_inputs.emplace_back(nodes_.at(input_idx)); + continue; + } + // const tensor + if (!tflite_model_buffers.at(input_tensor->buffer)->data.empty()) { + auto parameter = func_graph_->add_parameter(); + status = ConvertConstTensor(input_tensor.get(), parameter.get()); + if (status != RET_OK) { + MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed."; + return status; + } + op_inputs.emplace_back(parameter); + nodes_.insert(std::pair(input_idx, parameter)); + continue; + } + MS_LOG(WARNING) << "tensor " << input_idx << " is neither a node output nor a weight tensor."; + } + auto new_cnode = func_graph_->NewCNode(op_inputs); + new_cnode->set_fullname_with_scope(op_name); + + // parse outputs + status = ConvertOutputTensor(op.get(), new_cnode); + if (status != RET_OK) { + MS_LOG(ERROR) << "Convert output tensors for " << new_cnode->fullname_with_scope() << " failed."; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + return status; + } + } + return status; +} + +STATUS TfliteModelParser::SetTensorQuantParam(const tflite::TensorT *tflite_tensor, + std::vector *quant_params) { + if (tflite_tensor == nullptr) { + MS_LOG(ERROR) << "tflite_tensor is null, set tensor quant params failed."; + return RET_NULL_PTR; + } + quant_params->clear(); + + if (tflite_tensor->quantization == nullptr || + (tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() && + tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty())) { + std::vector notinited_quant_params(1); + *quant_params = notinited_quant_params; + return RET_OK; } + for (size_t i = 0; i < tflite_tensor->quantization->scale.size(); i++) { std::unique_ptr quant_param = std::make_unique(); if (quant_param == nullptr) { MS_LOG(ERROR) << "quant_param is null"; - return; + return RET_NULL_PTR; } if (!tflite_tensor->quantization->scale.empty()) { @@ -104,364 +192,219 @@ void TfliteModelParser::SetTensorQuantParam(const std::unique_ptrmax = tflite_tensor->quantization->max[i]; } quant_param->inited = true; - tensor->quantParams.emplace_back(std::move(quant_param)); + quant_params->emplace_back(*std::move(quant_param)); } + return RET_OK; } -STATUS TfliteModelParser::ConvertOp(const std::unique_ptr &tflite_model, - const std::unique_ptr &tflite_subgraph, - const QuantType &quant_type, schema::MetaGraphT *sub_graph) { - MS_ASSERT(tflite_model != nullptr); - MS_ASSERT(tflite_subgraph != nullptr); - MS_ASSERT(sub_graph != nullptr); +STATUS TfliteModelParser::ConvertOpQuantParams(const tflite::OperatorT *op, lite::PrimitiveC *primitive_c) { + if (op == nullptr) { + MS_LOG(ERROR) << "tflite op is null, get quant params failed."; + return RET_NULL_PTR; + } - int idx = 0; - int status = RET_OK; - NoSupportOp::GetInstance()->SetFmkType("TFLITE"); - for (const auto &tflite_op : tflite_subgraph->operators) { - const auto opcode_index = tflite_op->opcode_index; - const auto &operator_code = tflite_model->operator_codes[opcode_index]; - if (operator_code == nullptr) { - MS_LOG(ERROR) << "operator_code is null"; - return RET_ERROR; + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "primitive_c is null, get quant params failed."; + return RET_NULL_PTR; + } + const auto &tflite_subgraph = tflite_model_->subgraphs.front(); + for (auto input_idx : op->inputs) { + if (input_idx < 0) { + input_idx += tflite_subgraph->tensors.size(); } - auto op_type = GetMSOpType(operator_code->builtin_code); - - auto op = std::make_unique(); - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; + const auto &input_tensor = tflite_subgraph->tensors[input_idx]; + std::vector quant_params; + auto status = SetTensorQuantParam(input_tensor.get(), &quant_params); + if (status != RET_OK) { + MS_LOG(ERROR) << "set input tensor quant param failed."; + return status; } - op->name = op_type + "-" + std::to_string(idx++); - op->quantType = quant_type; - MS_LOG(INFO) << "parse op: " << op->name.c_str(); - - auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(op_type); - if (node_parser == nullptr) { - NoSupportOp::GetInstance()->InsertOp(op_type); - status = (status == RET_OK ? RET_NOT_FIND_OP : status); - continue; + primitive_c->AddInputQuantParam(quant_params); + } + for (auto output_idx : op->outputs) { + if (output_idx < 0) { + output_idx += tflite_subgraph->tensors.size(); } - if (status == RET_OK || op_type == "Custom") { - int status_node = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, tflite_subgraph, op.get()); - status = (status == RET_OK ? status_node : status); - if (status_node != RET_OK) { - if (status_node == RET_NOT_FIND_OP) { - op_type = (op_type != "Custom" ? op_type : operator_code->custom_code); - NoSupportOp::GetInstance()->InsertOp(op_type); - } else { - MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed"; - } - continue; - } - if (status != RET_OK) { - continue; - } - sub_graph->nodes.emplace_back(op.release()); - opMap[sub_graph->nodes.back()->name] = sub_graph->nodes.back().get(); - tfliteOpMap[tflite_op.get()] = sub_graph->nodes.back().get(); + const auto &output_tensor = tflite_subgraph->tensors.at(output_idx); + std::vector quant_params; + auto status = SetTensorQuantParam(output_tensor.get(), &quant_params); + if (status != RET_OK) { + MS_LOG(ERROR) << "set output tensor quant param failed."; + return status; } + primitive_c->AddOutputQuantParam(quant_params); } - return status; + return RET_OK; } -STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr &tflite_subgraph, - const std::vector> &tflite_model_buffer, - schema::MetaGraphT *sub_graph) { - MS_ASSERT(tflite_subgraph != nullptr); - MS_ASSERT(sub_graph != nullptr); - std::set output_index; - for (const auto &tflite_op : tflite_subgraph->operators) { - for (int idx : tflite_op->outputs) { - if (idx < 0) { - idx += tflite_subgraph->tensors.size(); - } - output_index.insert(idx); +STATUS TfliteModelParser::ConvertGraphInputs() { + const auto &tflite_subgraph = tflite_model_->subgraphs.front(); + for (auto tflite_graph_input : tflite_subgraph->inputs) { + if (tflite_graph_input < 0) { + tflite_graph_input = tflite_graph_input + tflite_subgraph->tensors.size(); } + auto parameter = func_graph_->add_parameter(); + const auto &tensor = tflite_subgraph->tensors.at(tflite_graph_input); + std::vector shape_vector; + (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), + [](const int32_t &value) { return static_cast(value); }); + auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type)); + auto abstract_tensor = std::make_shared(type_ptr, shape_vector); + parameter->set_abstract(abstract_tensor); + parameter->set_name("graph_input_" + std::to_string(tflite_graph_input) + "_parameter"); + nodes_.insert(std::pair(tflite_graph_input, parameter)); } - for (size_t i = 0; i < tensorsInfo.tensorsId.size(); i++) { - auto idx = tensorsInfo.tensorsId[i]; - if (idx < 0) { - idx += tflite_subgraph->tensors.size(); - } - const auto &tflite_tensor = tflite_subgraph->tensors[idx]; - if (tflite_tensor == nullptr) { - MS_LOG(ERROR) << "tflite_tensor is null"; - return RET_NULL_PTR; - } - - std::unique_ptr tensor = std::make_unique(); - if (tensor == nullptr) { - MS_LOG(ERROR) << "tensor is null"; + return RET_OK; +} +STATUS TfliteModelParser::ConvertGraphOutputs() { + const auto &tflite_subgraph = tflite_model_->subgraphs.front(); + if (tflite_subgraph->outputs.size() > 1) { + std::vector make_tuple_inputs; + auto make_tuple_prim_ptr = GetMakeTuplePrim(); + if (make_tuple_prim_ptr == nullptr) { + MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr"; return RET_NULL_PTR; } - - tensor->format = tensorsInfo.tensorsFormat[i]; - tensor->dataType = GetTfliteDataType(tflite_tensor->type); - tensor->dims = tflite_tensor->shape; - - // if graph input tensor - bool isInput = false; - auto tflite_inputs = tflite_subgraph->inputs; - for (int tflite_input : tflite_inputs) { - if (idx == tflite_input) { - isInput = true; - break; + auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr); + make_tuple_inputs.emplace_back(make_tuple_prim); + for (auto outputNode : tflite_subgraph->outputs) { + auto cnode = nodes_.at(outputNode); + if (nullptr == cnode) { + MS_LOG(ERROR) << "Can't find input node."; + return RET_NOT_FIND_OP; } + make_tuple_inputs.emplace_back(cnode); } + auto make_tuple_cnode = func_graph_->NewCNode(make_tuple_inputs); + make_tuple_cnode->set_fullname_with_scope("return tuple"); - // add data for const tensor - auto &tensor_buffer = tflite_model_buffer.at(tflite_tensor->buffer); - if (tensor_buffer == nullptr) { - MS_LOG(ERROR) << "tensor_buffer is null"; + std::vector op_inputs; + auto return_prim_ptr = GetReturnPrim(); + if (return_prim_ptr == nullptr) { + MS_LOG(ERROR) << "GetReturnPrim return nullptr"; return RET_NULL_PTR; } - auto isConst = (!tensor_buffer->data.empty()); - if (isConst) { - int status = CopyConstTensorData(tflite_model_buffer, tflite_tensor.get(), tensor.get()); - if (status != RET_OK) { - MS_LOG(ERROR) << "obtain const tensor failed"; - return status; - } - } - - // set tensor attr - if (isInput || isConst) { - tensor->nodeType = schema::NodeType::NodeType_ValueNode; - } else { - if (output_index.find(idx) == output_index.end() && tflite_tensor->shape[0] == 0) { - tensor->nodeType = schema::NodeType::NodeType_ValueNode; - } else { - tensor->nodeType = schema::NodeType_Parameter; - } + auto value_node = NewValueNode(return_prim_ptr); + op_inputs.emplace_back(value_node); + op_inputs.emplace_back(make_tuple_cnode); + auto cnode = func_graph_->NewCNode(op_inputs); + cnode->set_fullname_with_scope("return"); + func_graph_->set_return(cnode); + } else { + auto returnPrim = GetReturnPrim(); + if (returnPrim == nullptr) { + MS_LOG(ERROR) << "GetReturnPrim return nullptr"; + return RET_NULL_PTR; } - - // quant param - if (tflite_tensor->quantization != nullptr && - !(tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() && - tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty())) { - SetTensorQuantParam(tflite_tensor, tensor.get()); + auto valueNode = NewValueNode(returnPrim); + std::vector op_inputs{valueNode}; + auto cnode = nodes_.at(tflite_subgraph->outputs.front()); + if (nullptr == cnode) { + MS_LOG(ERROR) << "Can't find input node."; + return RET_NOT_FIND_OP; } - - tensors.push_back(tensor.release()); - } - - for (auto iter : tensors) { - std::unique_ptr temp(iter); - sub_graph->allTensors.emplace_back(move(temp)); + op_inputs.emplace_back(cnode); + auto returnCnode = func_graph_->NewCNode(op_inputs); + returnCnode->set_fullname_with_scope("return"); + func_graph_->set_return(returnCnode); } return RET_OK; } -STATUS TfliteModelParser::GetGraphInfo(const std::unique_ptr &tflite_subgraph, - schema::MetaGraphT *sub_graph) { - MS_ASSERT(sub_graph != nullptr); - MS_ASSERT(tflite_subgraph != nullptr); - // graph input - std::vector graph_inputs; - for (size_t i = 0; i < tflite_subgraph->inputs.size(); i++) { - const int idx = tflite_subgraph->inputs[i]; - int id = idx < 0 ? idx + tflite_subgraph->tensors.size() : idx; - auto iter = tensorsInfo.tensorsIdMap.find(id); - if (iter != tensorsInfo.tensorsIdMap.end()) { - graph_inputs.push_back(iter->second); - } else { - MS_LOG(ERROR) << "get graph input failed"; - return RET_INPUT_TENSOR_ERROR; - } +STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Parameter *parameter) { + if (tensor == nullptr) { + MS_LOG(ERROR) << "tensor is null, get const tensor failed."; + return RET_NULL_PTR; } - sub_graph->inputIndex.assign(graph_inputs.begin(), graph_inputs.end()); - // graph output - std::vector graph_outputs; - for (size_t i = 0; i < tflite_subgraph->outputs.size(); i++) { - const int idx = tflite_subgraph->outputs[i]; - int id = idx < 0 ? idx + tflite_subgraph->tensors.size() : idx; - auto iter = tensorsInfo.tensorsIdMap.find(id); - if (iter != tensorsInfo.tensorsIdMap.end()) { - graph_outputs.push_back(iter->second); - } else { - MS_LOG(ERROR) << "get graph output failed"; - return RET_INPUT_TENSOR_ERROR; - } + if (parameter == nullptr) { + MS_LOG(ERROR) << "parameter is null, get const tensor failed."; + return RET_NULL_PTR; } - sub_graph->outputIndex.assign(graph_outputs.begin(), graph_outputs.end()); - return RET_OK; -} - -STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph) { - MS_ASSERT(sub_graph != nullptr); - for (auto &op : sub_graph->nodes) { - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - if (op->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { - auto attr = op->primitive->value.AsDepthwiseConv2D(); - if (attr == nullptr) { - MS_LOG(ERROR) << "attr is null"; - return RET_NULL_PTR; - } - if (attr->channelMultiplier > 1) { - // get channel attr - if (op->inputIndex.empty()) { - MS_LOG(ERROR) << "the input of DepthwiseConv2D is null"; - return RET_NULL_PTR; - } - const auto data_id = op->inputIndex[0]; - if (sub_graph->allTensors.size() <= data_id) { - MS_LOG(ERROR) << "the number of allTensors is less than " << data_id; - return RET_ERROR; - } - const auto &data_tensor = sub_graph->allTensors.at(data_id); - if (data_tensor == nullptr) { - MS_LOG(ERROR) << "the data tensor is null"; - return RET_NULL_PTR; - } - auto data_shape = data_tensor->dims; - if (data_shape.empty()) { - MS_LOG(DEBUG) << "the tensor's shape is dynamic, which obtain only when running"; - return RET_NO_CHANGE; - } - std::unique_ptr conv_attr = std::make_unique(); - if (conv_attr == nullptr) { - MS_LOG(ERROR) << "conv_attr is null"; - return RET_NULL_PTR; - } - if (data_shape[3] == 1) { - conv_attr->channelIn = data_shape[3]; - conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier; - - // update attr - conv_attr->group = 1; - conv_attr->format = attr->format; - conv_attr->kernelH = attr->kernelH; - conv_attr->kernelW = attr->kernelW; - conv_attr->strideH = attr->strideH; - conv_attr->strideW = attr->strideW; - conv_attr->padMode = attr->padMode; - conv_attr->padUp = attr->padUp; - conv_attr->padDown = attr->padDown; - conv_attr->padLeft = attr->padLeft; - conv_attr->padRight = attr->padRight; - conv_attr->dilateH = attr->dilateH; - conv_attr->dilateW = attr->dilateW; - conv_attr->hasBias = attr->hasBias; - conv_attr->activationType = attr->activationType; - - op->primitive->value.type = schema::PrimitiveType_Conv2D; - op->primitive->value.value = conv_attr.release(); - - // update weight - auto weight_id = op->inputIndex[1]; - auto &weight_tensor = sub_graph->allTensors.at(weight_id); - if (weight_tensor->dataType == TypeId::kNumberTypeUInt8) { - auto status = TransFilterFormat(weight_tensor.get(), kKHWC2CHWK); - if (status != RET_OK) { - MS_LOG(ERROR) << "Trans depthwiseConv Filter schema::Format failed."; - return RET_ERROR; - } - } else if (weight_tensor->dataType == kNumberTypeInt8) { - auto status = TransFilterFormat(weight_tensor.get(), kKHWC2CHWK); - if (status != RET_OK) { - MS_LOG(ERROR) << "Trans filter format failed."; - return RET_ERROR; - } - } else if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) { - auto status = TransFilterFormat(weight_tensor.get(), kKHWC2CHWK); - if (status != RET_OK) { - MS_LOG(ERROR) << "Trans filter format failed."; - return RET_ERROR; - } - } else { - MS_LOG(ERROR) << "The dataType of weight tensor is unsupported."; - return RET_ERROR; - } - weight_tensor->format = schema::Format::Format_CHWK; - } - } + const auto &tfliteModelBuffers = tflite_model_->buffers; + auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type)); + std::vector shape_vector; + (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), + [](const int32_t &value) { return static_cast(value); }); + auto abstract_tensor = std::make_shared(type_ptr, shape_vector); + parameter->set_abstract(abstract_tensor); + parameter->set_name("const_" + std::to_string(nodes_.size()) + "_parameter"); + + ParamValueLitePtr param_value = std::make_shared(); + MS_ASSERT(param_value != nullptr); + param_value->set_tensor_shape(tensor->shape); + param_value->set_tensor_type(GetTfliteDataType(tensor->type)); + param_value->set_format(schema::Format::Format_NHWC); + const auto &data = tfliteModelBuffers.at(tensor->buffer)->data; + if (!data.empty()) { + auto size = data.size(); + char *tensor_data = new (std::nothrow) char[size]; + if (tensor_data == nullptr) { + MS_LOG(ERROR) << "new char[] failed"; + return RET_MEMORY_FAILED; } + std::memcpy(tensor_data, data.data(), size); + param_value->set_tensor_addr(tensor_data); + param_value->set_tensor_size(size); + parameter->set_default_param(param_value); } return RET_OK; } -std::unique_ptr TfliteModelParser::ConstructMainGraph( - const std::unique_ptr &tflite_model, const QuantType &quant_type) { - MS_ASSERT(tflite_model != nullptr); - if (tflite_model->subgraphs.empty()) { - MS_LOG(ERROR) << "read tflite model main subgraphs failed"; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); - return nullptr; - } - const auto &tflite_subgraph = tflite_model->subgraphs[0]; - - auto meta_graph = std::make_unique(); - if (meta_graph == nullptr) { - MS_LOG(ERROR) << "new meta graph failed"; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); - return nullptr; - } - meta_graph->name = "MS_model converted by TF-Lite"; - quantType = quant_type; - // convert op - int status = ConvertOp(tflite_model, tflite_subgraph, quant_type, meta_graph.get()); - if (status != RET_OK) { - MS_LOG(ERROR) << "parse op failed."; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; - } - - // convert tensor - status = ConvertTensor(tflite_subgraph, tflite_model->buffers, meta_graph.get()); - if (status != RET_OK) { - MS_LOG(ERROR) << "convert tensor failed"; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; +STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const CNodePtr &dst_cnode) { + if (op == nullptr) { + MS_LOG(ERROR) << "op is null, get output tensor failed."; + return RET_NULL_PTR; } - // set graph input/output - status = GetGraphInfo(tflite_subgraph, meta_graph.get()); - if (status != RET_OK) { - MS_LOG(ERROR) << "convert tensors failed"; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; + if (dst_cnode == nullptr) { + MS_LOG(ERROR) << "parameter is null, get output tensor failed."; + return RET_NULL_PTR; } - // update for depthwiseConv - status = ConvertGroupDepthwiseOp(meta_graph.get()); - if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "convert group depthwise conv failed"; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; + const auto &tflite_subgraph = tflite_model_->subgraphs.front(); + if (op->outputs.size() == 1) { + int output_idx = + op->outputs.front() < 0 ? tflite_subgraph->tensors.size() + op->outputs.front() : op->outputs.front(); + const auto &tensor = tflite_subgraph->tensors.at(output_idx); + std::vector shape_vector; + (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), + [](const int32_t &value) { return static_cast(value); }); + auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type)); + dst_cnode->set_abstract(std::make_shared(type_ptr, shape_vector)); + nodes_.insert(std::pair(op->outputs.front(), dst_cnode)); + } else { + AbstractBasePtrList abstract_list; + int op_idx = 0; + for (auto output_idx : op->outputs) { + if (output_idx < 0) { + output_idx = output_idx + tflite_subgraph->tensors.size(); + } + const auto &tensor = tflite_subgraph->tensors.at(output_idx); + std::vector shape_vector; + (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), + [](const int32_t &value) { return static_cast(value); }); + auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type)); + abstract_list.emplace_back(std::make_shared(type_ptr, shape_vector)); + auto tuple_get_item_prim_ptr = GetTupleGetItemPrim(); + if (tuple_get_item_prim_ptr == nullptr) { + MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr"; + return RET_NULL_PTR; + } + auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr); + auto get_item_value = NewValueNode(MakeValue(op_idx)); + std::vector inputs{tuple_get_item_prim, dst_cnode, get_item_value}; + CNodePtr get_item_cnode = func_graph_->NewCNode(inputs); + get_item_cnode->set_fullname_with_scope(dst_cnode->fullname_with_scope() + "_getitem_" + std::to_string(op_idx)); + nodes_.insert(std::pair(output_idx, get_item_cnode)); + op_idx++; + } + dst_cnode->set_abstract(std::make_shared(abstract_list)); } - - return meta_graph; + return RET_OK; } - -schema::MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type) { - if (model_file.empty()) { - MS_LOG(ERROR) << "model_file is empty"; - return nullptr; - } - - // load graph - auto tflite_model = ReadTfliteModel(model_file.c_str()); - if (tflite_model == nullptr) { - MS_LOG(ERROR) << "read tflite model failed"; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); - return nullptr; - } - - // construct main_meta_graph - auto main_meta_graph = ConstructMainGraph(tflite_model, quant_type); - if (main_meta_graph == nullptr) { - MS_LOG(ERROR) << "ConstructMainGraph failed"; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); - return nullptr; - } - - return main_meta_graph.release(); +MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file, + const QuantType &quant_type) { + return nullptr; } -} // namespace lite -} // namespace mindspore +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h index 0ac93eeeaf..86e4810e4f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,67 +13,42 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifndef LITE_TFLITE_MODEL_PARSER_H +#define LITE_TFLITE_MODEL_PARSER_H -#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H -#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H - -#include -#include -#include -#include -#include #include -#include -#include -#include #include -#include "securec/include/securec.h" +#include +#include #include "tools/converter/model_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include "tools/common/tensor_util.h" -#include "schema/inner/model_generated.h" namespace mindspore::lite { class TfliteModelParser : public ModelParser { public: - TfliteModelParser(); - - ~TfliteModelParser() override; - - schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, - const QuantType &quantTyp) override; - - protected: - std::unique_ptr ReadTfliteModel(const char *model_path); - - static STATUS CopyConstTensorData(const std::vector> &tflite_model_buffer, - const tflite::TensorT *tflite_tensor, schema::TensorT *tensor); - - static void SetTensorQuantParam(const std::unique_ptr &tflite_tensor, schema::TensorT *tensor); - - STATUS ConvertOp(const std::unique_ptr &tflite_model, - const std::unique_ptr &tflite_subgraph, const QuantType &quant_type, - schema::MetaGraphT *sub_graph); + TfliteModelParser() = default; - STATUS ConvertTensor(const std::unique_ptr &tflite_subgraph, - const std::vector> &tflite_model_buffer, - schema::MetaGraphT *sub_graph); + ~TfliteModelParser() override = default; - STATUS GetGraphInfo(const std::unique_ptr &tflite_subgraph, schema::MetaGraphT *sub_graph); - - static STATUS ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph); - - QuantType quantType = QuantType_QUANT_NONE; - char *tfliteModelBuf = nullptr; - std::unique_ptr ConstructMainGraph(const std::unique_ptr &tflite_model, - const QuantType &quant_type); + FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, + const QuantType &quant_type) override; + MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, + const QuantType &quant_type) override; private: - TfliteTensorsInfo tensorsInfo; - std::vector tensors; - - std::map opMap; - std::map tfliteOpMap; + std::unordered_map nodes_; + std::unique_ptr tflite_model_; + FuncGraphPtr func_graph_; + char *tflite_model_buf_ = nullptr; + std::unique_ptr ReadTfliteModel(const char *model_path); + STATUS ConvertConstTensor(const tflite::TensorT *tensor, Parameter *parameter); + STATUS ConvertOutputTensor(const tflite::OperatorT *op, const CNodePtr &dst_cnode); + STATUS ConvertOpQuantParams(const tflite::OperatorT *op, lite::PrimitiveC *primitive_c); + STATUS ConvertOps(); + STATUS ConvertGraphInputs(); + STATUS ConvertGraphOutputs(); + STATUS SetTensorQuantParam(const tflite::TensorT *tflite_tensor, std::vector *quant_params); }; } // namespace mindspore::lite -#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H +#endif // LITE_TFLITE_MODEL_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h index 0cc6e1bebb..4fa0698f36 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h @@ -41,9 +41,9 @@ class TfliteNodeParser { virtual STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) = 0; - virtual STATUS Parse(const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, PrimitiveC *primitiveC) { - return RET_OK; + virtual lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + return nullptr; } static void AddOpInput(schema::CNodeT *op, TfliteTensorsInfo *tensors_info, int idx, int total, diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc index eb9d2ac4f9..16b6cf7348 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc @@ -64,6 +64,38 @@ STATUS TfliteOneHotParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteOneHotParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto &tflite_subgraph = tflite_model->subgraphs.front(); + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsOneHotOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op onehot attr failed"; + return nullptr; + } + auto axis = tflite_attr->axis; + const auto &tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; + if (tensor == nullptr) { + MS_LOG(ERROR) << "tensor is null"; + return nullptr; + } + attr->axis = axis; + + primitive->value.type = schema::PrimitiveType_OneHot; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteOneHotParser("OneHot", new TfliteOneHotParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.h index 518ad70878..2fd07b1cc4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.h @@ -32,6 +32,8 @@ class TfliteOneHotParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc index 3424a58167..231843174f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc @@ -90,6 +90,59 @@ STATUS TflitePadParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TflitePadParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto &tflite_subgraph = tflite_model->subgraphs.front(); + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; + if (tflite_op_type == tflite::BuiltinOperator_PAD) { + const auto &tflite_attr = tflite_op->builtin_options.AsPadOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op pad attr failed"; + return nullptr; + } + attr->paddingMode = schema::PaddingMode_CONSTANT; + attr->constantValue = 0.0f; + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->paddings)) { + MS_LOG(ERROR) << "get pad -> paddings failed"; + return nullptr; + } + } else if (tflite_op_type == tflite::BuiltinOperator_MIRROR_PAD) { + const auto &tflite_attr = tflite_op->builtin_options.AsMirrorPadOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op pad attr failed"; + return nullptr; + } + switch (tflite_attr->mode) { + case tflite::MirrorPadMode_REFLECT: + attr->paddingMode = schema::PaddingMode_REFLECT; + break; + case tflite::MirrorPadMode_SYMMETRIC: + attr->paddingMode = schema::PaddingMode_SYMMETRIC; + break; + default: + MS_LOG(ERROR) << "paddingmode:" << tflite_attr->mode << " don't support"; + return nullptr; + } + } else { + MS_LOG(ERROR) << "this pad:" << tflite_op_type << " hasn't been supported"; + return nullptr; + } + + primitive->value.type = schema::PrimitiveType_Pad; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tflitePadParser("Pad", new TflitePadParser()); TfliteNodeRegister g_tfliteMirorPadParser("MirrorPad", new TflitePadParser()); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.h index da2b6d26f7..c166b3ecd1 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.h @@ -32,6 +32,8 @@ class TflitePadParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc index c7af3f6419..9dd2843655 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc @@ -19,8 +19,7 @@ #include #include -namespace mindspore { -namespace lite { +namespace mindspore::lite { STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { @@ -43,17 +42,13 @@ STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::un return RET_NULL_PTR; } - std::vector node_name_str; - Split(op->name, &node_name_str, "-"); - const char *node_name = node_name_str.data()->c_str(); - if (std::strcmp(node_name, "MeanPooling") == 0) { - MS_LOG(DEBUG) << "parser TfliteMeanPoolingParser"; + auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; + if (tflite_op_type == tflite::BuiltinOperator_AVERAGE_POOL_2D) { attr->poolingMode = schema::PoolMode_MEAN_POOLING; - } else if (std::strcmp(node_name, "MaxPooling") == 0) { - MS_LOG(DEBUG) << "parse TfliteMaxPoolingParser"; + } else if (tflite_op_type == tflite::BuiltinOperator_MAX_POOL_2D) { attr->poolingMode = schema::PoolMode_MAX_POOLING; } else { - MS_LOG(ERROR) << node_name << " hasn't been supported"; + MS_LOG(ERROR) << "pooling mode " << tflite_op_type << " hasn't been supported"; return RET_NOT_FIND_OP; } @@ -75,7 +70,7 @@ STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::un // calculate pad params auto data_index = tflite_op->inputs[0]; const auto &data_tensor = tflite_subgraph->tensors[data_index]; - std::vector params; + std::vector params; int status = getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW, ¶ms); if (status != RET_OK && status != RET_NO_CHANGE) { @@ -95,8 +90,58 @@ STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::un AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +lite::PrimitiveC *TflitePoolingParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + const auto &tflite_subgraph = tflite_model->subgraphs.front(); + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; + if (tflite_op_type == tflite::BuiltinOperator_AVERAGE_POOL_2D) { + attr->poolingMode = schema::PoolMode_MEAN_POOLING; + } else if (tflite_op_type == tflite::BuiltinOperator_MAX_POOL_2D) { + attr->poolingMode = schema::PoolMode_MAX_POOLING; + } + const auto &tflite_attr = tflite_op->builtin_options.AsPool2DOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op pooling attr failed"; + return nullptr; + } + attr->windowW = tflite_attr->filter_width; + attr->windowH = tflite_attr->filter_height; + attr->strideW = tflite_attr->stride_w; + attr->strideH = tflite_attr->stride_h; + attr->padMode = GetPadMode(tflite_attr->padding); + attr->format = schema::Format::Format_NHWC; + + attr->global = false; + attr->roundMode = schema::RoundMode_FLOOR; + attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); + + // calculate pad params + auto data_index = tflite_op->inputs[0]; + const auto &data_tensor = tflite_subgraph->tensors[data_index]; + std::vector params; + int status = + getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW, ¶ms); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "get padding params failed"; + return nullptr; + } else if (status == RET_OK) { + attr->padUp = params.at(0); + attr->padDown = params.at(1); + attr->padLeft = params.at(2); + attr->padRight = params.at(3); + } + auto primitive = std::make_unique(); + primitive->value.type = schema::PrimitiveType_Pooling; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteMeanPoolingParser("MeanPooling", new TflitePoolingParser()); TfliteNodeRegister g_tfliteMaxPoolingParser("MaxPooling", new TflitePoolingParser()); -} // namespace lite -} // namespace mindspore +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.h index 56b17d7635..48ec978ef3 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.h @@ -23,8 +23,7 @@ #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" -namespace mindspore { -namespace lite { +namespace mindspore::lite { class TflitePoolingParser : public TfliteNodeParser { public: TflitePoolingParser() : TfliteNodeParser("node_name") {} @@ -32,8 +31,9 @@ class TflitePoolingParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; -} // namespace lite -} // namespace mindspore +} // namespace mindspore::lite #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_POOLING_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_prelu_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_prelu_parser.cc index 1026e5ea19..42c49aa7e8 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_prelu_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_prelu_parser.cc @@ -52,6 +52,24 @@ STATUS TflitePReLUParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TflitePReLUParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + attr->channelShared = true; + primitive->value.type = schema::PrimitiveType_PReLU; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tflitePReLUParser("PRELU", new TflitePReLUParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_prelu_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_prelu_parser.h index ef4a9dc5c8..8d46d72be2 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_prelu_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_prelu_parser.h @@ -32,6 +32,8 @@ class TflitePReLUParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc index c8d7bb28e0..7688b55cfc 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc @@ -74,6 +74,50 @@ STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::u AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteQuantizeParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto &tflite_subgraph = tflite_model->subgraphs.front(); + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; + if (in_tensor == nullptr) { + MS_LOG(ERROR) << "input tensor is null"; + return nullptr; + } + const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]]; + if (out_tensor == nullptr) { + MS_LOG(ERROR) << "output tensor is null"; + return nullptr; + } + if (GetTfliteDataType(in_tensor->type) != GetTfliteDataType(out_tensor->type) && + (GetTfliteDataType(out_tensor->type) == kNumberTypeInt8 || + GetTfliteDataType(out_tensor->type) == kNumberTypeUInt8)) { + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + attr->srcT = GetTfliteDataType(in_tensor->type); + attr->dstT = GetTfliteDataType(out_tensor->type); + primitive->value.type = schema::PrimitiveType_QuantDTypeCast; + primitive->value.value = attr.release(); + } else { + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + attr->srcT = GetTfliteDataType(in_tensor->type); + attr->dstT = GetTfliteDataType(out_tensor->type); + primitive->value.type = schema::PrimitiveType_Cast; + primitive->value.value = attr.release(); + } + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteQuantizeParser("QUANTIZE", new TfliteQuantizeParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.h index 1b1cd3cc3c..de92daa604 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.h @@ -31,6 +31,8 @@ class TfliteQuantizeParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc index 867482a1db..9dda4059e2 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc @@ -71,6 +71,43 @@ STATUS TfliteRangeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteRangeParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto &tflite_subgraph = tflite_model->subgraphs.front(); + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + attr->dType = 0; + std::vector limit; + std::vector delta; + int status = GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, limit); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "range -> limit get failed"; + return nullptr; + } else if (status == RET_OK) { + status = GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, delta); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "stridedSlice -> end get failed"; + return nullptr; + } + } + if (status == RET_OK) { + attr->limit = limit.front(); + attr->delta = delta.front(); + } + primitive->value.type = schema::PrimitiveType_Range; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteRangeParser("Range", new TfliteRangeParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.h index eab8849449..cdebb1bc39 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.h @@ -32,6 +32,8 @@ class TfliteRangeParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.cc index fbc2f86f36..bb1cd0323e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.cc @@ -50,6 +50,24 @@ STATUS TfliteRankParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteRankParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + primitive->value.type = schema::PrimitiveType_Rank; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteRankParser("Rank", new TfliteRankParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.h index d105387714..629026be8d 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.h @@ -32,6 +32,8 @@ class TfliteRankParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.cc index c5f9183ba1..921b583491 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.cc @@ -85,11 +85,65 @@ STATUS TfliteReduceParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteReduceParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto &tflite_subgraph = tflite_model->subgraphs.front(); + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsReducerOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op reduce attr failed"; + return nullptr; + } + attr->keepDims = tflite_attr->keep_dims; + + auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; + if (tflite_op_type == tflite::BuiltinOperator_REDUCE_MAX) { + MS_LOG(DEBUG) << "parse TfliteReduceMaxParser"; + attr->mode = schema::ReduceMode_ReduceMax; + } else if (tflite_op_type == tflite::BuiltinOperator_REDUCE_MIN) { + MS_LOG(DEBUG) << "parse TfliteReduceMinParser"; + attr->mode = schema::ReduceMode_ReduceMin; + } else if (tflite_op_type == tflite::BuiltinOperator_REDUCE_PROD) { + MS_LOG(DEBUG) << "parse TfliteReduceProdParser"; + attr->mode = schema::ReduceMode_ReduceProd; + } else if (tflite_op_type == tflite::BuiltinOperator_SUM) { + MS_LOG(DEBUG) << "parse TfliteSumParser"; + attr->mode = schema::ReduceMode_ReduceSum; + } else if (tflite_op_type == tflite::BuiltinOperator_MEAN) { + MS_LOG(DEBUG) << "parse TfliteMeanParser"; + attr->mode = schema::ReduceMode_ReduceMean; + } else if (tflite_op_type == tflite::BuiltinOperator_REDUCE_ANY) { + // attr->mode; + MS_LOG(ERROR) << "ms-lite haven't supported REDUCE_ANY now"; + return nullptr; + } + + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->axes)) { + MS_LOG(ERROR) << "get reduce -> axes failed"; + return nullptr; + } + + primitive->value.type = schema::PrimitiveType_Reduce; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} -TfliteNodeRegister g_tfliteSumParser("Sum", new TfliteReduceParser()); -TfliteNodeRegister g_tfliteMeanParser("Mean", new TfliteReduceParser()); -TfliteNodeRegister g_tfliteReduceMaxParser("ReduceMax", new TfliteReduceParser()); -TfliteNodeRegister g_tfliteReduceMinParser("ReduceMin", new TfliteReduceParser()); -TfliteNodeRegister g_tfliteReduceProdParser("ReduceProd", new TfliteReduceParser()); +TfliteNodeRegister g_TfliteSumParser("Sum", new TfliteReduceParser()); +TfliteNodeRegister g_TfliteMeanParser("Mean", new TfliteReduceParser()); +TfliteNodeRegister g_TfliteReduceMaxParser("ReduceMax", new TfliteReduceParser()); +TfliteNodeRegister g_TfliteReduceMinParser("ReduceMin", new TfliteReduceParser()); +TfliteNodeRegister g_TfliteReduceProdParser("ReduceProd", new TfliteReduceParser()); +TfliteNodeRegister g_TfliteReduceAnyParser("ReduceAny", new TfliteReduceParser()); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.h index a9340b3e1d..4ecfa5ade5 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.h @@ -23,8 +23,7 @@ #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" -namespace mindspore { -namespace lite { +namespace mindspore::lite { class TfliteReduceParser : public TfliteNodeParser { public: TfliteReduceParser() : TfliteNodeParser("node_name") {} @@ -32,8 +31,9 @@ class TfliteReduceParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; -} // namespace lite -} // namespace mindspore +} // namespace mindspore::lite #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_REDUCE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc index baa5b72f21..92185346c3 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc @@ -18,8 +18,7 @@ #include #include -namespace mindspore { -namespace lite { +namespace mindspore::lite { STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { @@ -43,8 +42,8 @@ STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::un return RET_NULL_PTR; } - const auto &tfliteAttr = tflite_op->builtin_options.AsReshapeOptions(); - if (tfliteAttr == nullptr) { + const auto &tflite_attr = tflite_op->builtin_options.AsReshapeOptions(); + if (tflite_attr == nullptr) { if (tflite_op->inputs.size() < 2) { MS_LOG(ERROR) << "expected two input tensors, but got: " << tflite_op->inputs.size(); return RET_ERROR; @@ -68,9 +67,9 @@ STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::un } } else { attr->format = schema::Format::Format_NHWC; - attr->shape.resize(tfliteAttr->new_shape.size()); - for (size_t i = 0; i < tfliteAttr->new_shape.size(); ++i) { - attr->shape[i] = tfliteAttr->new_shape[i]; + attr->shape.resize(tflite_attr->new_shape.size()); + for (size_t i = 0; i < tflite_attr->new_shape.size(); ++i) { + attr->shape[i] = tflite_attr->new_shape[i]; } } @@ -83,7 +82,50 @@ STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::un AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +lite::PrimitiveC *TfliteReshapeParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + const auto &tflite_subgraph = tflite_model->subgraphs.front(); + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsReshapeOptions(); + if (tflite_attr == nullptr) { + if (tflite_op->inputs.size() < 2) { + MS_LOG(ERROR) << "expected two input tensors, but got: " << tflite_op->inputs.size(); + return nullptr; + } + auto shape_tensor_index = tflite_op->inputs[1]; + const auto &shape_tensor = tflite_subgraph->tensors[shape_tensor_index]; + if (shape_tensor == nullptr) { + MS_LOG(ERROR) << "shape_tensor is null"; + return nullptr; + } + auto &buf_data = tflite_model->buffers[shape_tensor->buffer]; + if (buf_data == nullptr) { + MS_LOG(ERROR) << "buf_data is null"; + return nullptr; + } + if (!buf_data->data.empty()) { + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->shape)) { + MS_LOG(ERROR) << "get reshape -> shape failed"; + return nullptr; + } + } + } else { + attr->format = schema::Format::Format_NHWC; + attr->shape.resize(tflite_attr->new_shape.size()); + for (size_t i = 0; i < tflite_attr->new_shape.size(); ++i) { + attr->shape[i] = tflite_attr->new_shape[i]; + } + } + auto primitive = std::make_unique(); + primitive->value.type = schema::PrimitiveType_Reshape; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteReshapeParser("Reshape", new TfliteReshapeParser()); -} // namespace lite -} // namespace mindspore +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h index e6eff70a36..9ac8908245 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h @@ -23,8 +23,7 @@ #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" -namespace mindspore { -namespace lite { +namespace mindspore::lite { class TfliteReshapeParser : public TfliteNodeParser { public: TfliteReshapeParser() : TfliteNodeParser("Reshape") {} @@ -32,8 +31,10 @@ class TfliteReshapeParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + + lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; -} // namespace lite -} // namespace mindspore +} // namespace mindspore::lite #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RESHAPE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc index b385e5e7f3..50db0619c8 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc @@ -120,6 +120,89 @@ STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteResizeParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto &tflite_subgraph = tflite_model->subgraphs.front(); + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + attr->coordinateTransformMode = schema::CoordinateTransformMode_COMMON; + auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; + if (tflite_op_type == tflite::BuiltinOperator_RESIZE_BILINEAR) { + MS_LOG(DEBUG) << "parse TfliteResizeBilinearParser"; + const auto &tfliteAttr = tflite_op->builtin_options.AsResizeBilinearOptions(); + if (tfliteAttr == nullptr) { + MS_LOG(ERROR) << "get op ResizeBilinear attr failed"; + return nullptr; + } + if (tfliteAttr->align_corners) { + attr->alignCorners = tfliteAttr->align_corners; + attr->coordinateTransformMode = schema::CoordinateTransformMode_ALIGN_CORNERS; + } + if (tfliteAttr->half_pixel_centers) { + attr->coordinateTransformMode = (attr->coordinateTransformMode == schema::CoordinateTransformMode_COMMON + ? schema::CoordinateTransformMode_TF_HALF_PIXEL + : schema::CoordinateTransformMode_ALIGN_CORNERS_WITH_HALF_PIEXL); + } + attr->method = schema::ResizeMethod_LINEAR; + } else if (tflite_op_type == tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR) { + MS_LOG(DEBUG) << "parse TfliteResizeNearestNeighborParser"; + const auto &tfliteAttr = tflite_op->builtin_options.AsResizeNearestNeighborOptions(); + if (tfliteAttr == nullptr) { + MS_LOG(ERROR) << "get op ResizeNearestNeighbor attr failed"; + return nullptr; + } + if (tfliteAttr->align_corners) { + attr->alignCorners = tfliteAttr->align_corners; + attr->coordinateTransformMode = schema::CoordinateTransformMode_ALIGN_CORNERS; + } + if (tfliteAttr->half_pixel_centers) { + attr->coordinateTransformMode = (attr->coordinateTransformMode == schema::CoordinateTransformMode_COMMON + ? schema::CoordinateTransformMode_TF_HALF_PIXEL + : schema::CoordinateTransformMode_ALIGN_CORNERS_WITH_HALF_PIEXL); + } + attr->method = schema::ResizeMethod_NEAREST; + attr->nearestMode = schema::NearestMode_NORMAL; + } else { + MS_LOG(ERROR) << "wrong resize type"; + return nullptr; + } + + attr->format = schema::Format::Format_NHWC; + attr->preserveAspectRatio = false; + + auto tfliteResizeTensorIndex = tflite_op->inputs[1]; + const auto &shape_tensor = tflite_subgraph->tensors[tfliteResizeTensorIndex]; + if (shape_tensor == nullptr) { + MS_LOG(ERROR) << "shape_tensor is null"; + return nullptr; + } + auto resizeTensorBufferIndex = shape_tensor->buffer; + const auto &buff = tflite_model->buffers.at(resizeTensorBufferIndex); + if (buff == nullptr) { + MS_LOG(ERROR) << "buff_data is null"; + return nullptr; + } + auto buffData = reinterpret_cast(buff->data.data()); + if (buffData != nullptr) { + auto height = buffData[0]; + auto width = buffData[1]; + attr->newWidth = width; + attr->newHeight = height; + } + + primitive->value.type = schema::PrimitiveType_Resize; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteResizeBilinearParser("ResizeBilinear", new TfliteResizeParser()); TfliteNodeRegister g_tfliteResizeNearestNeighborParser("NearestNeighbor", new TfliteResizeParser()); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.h index ead33d090a..896ef20bab 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.h @@ -23,8 +23,7 @@ #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" -namespace mindspore { -namespace lite { +namespace mindspore::lite { class TfliteResizeParser : public TfliteNodeParser { public: TfliteResizeParser() : TfliteNodeParser("node_name") {} @@ -32,8 +31,9 @@ class TfliteResizeParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; -} // namespace lite -} // namespace mindspore +} // namespace mindspore::lite #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RESIZE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.cc index d92777d560..7368eced67 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.cc @@ -55,6 +55,30 @@ STATUS TfliteReverseParser::Parse(TfliteTensorsInfo *tensors_info, const std::un AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteReverseParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto &tflite_subgraph = tflite_model->subgraphs.front(); + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->axis)) { + MS_LOG(ERROR) << "get reverse -> axis failed"; + return nullptr; + } + + primitive->value.type = schema::PrimitiveType_Reverse; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteReverseParser("reverse", new TfliteReverseParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.h index 278e3f7ece..0b1a270df4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.h @@ -32,6 +32,8 @@ class TfliteReverseParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc index d0f9db6f52..c5f8b7fa6a 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc @@ -62,6 +62,32 @@ STATUS TfliteReverseSequenceParser::Parse(TfliteTensorsInfo *tensors_info, AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteReverseSequenceParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsReverseSequenceOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op reverse attr failed"; + return nullptr; + } + attr->seqAxis = tflite_attr->seq_dim; + attr->batchAxis = tflite_attr->batch_dim; + + primitive->value.type = schema::PrimitiveType_ReverseSequence; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteReverseSequenceParser("ReverseSequence", new TfliteReverseSequenceParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.h index d183297ced..ab11a933d0 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.h @@ -32,6 +32,8 @@ class TfliteReverseSequenceParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc index 15eb9dc3dc..5408eee8e4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc @@ -58,6 +58,29 @@ STATUS TfliteScatterNdParser::Parse(TfliteTensorsInfo *tensors_info, AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteScatterNdParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsScatterNdOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op ScatterNd attr failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_ScatterND; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteScatterNdParser("ScatterNd", new TfliteScatterNdParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.h index 064a3d70c7..7ea31fadd1 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.h @@ -32,6 +32,8 @@ class TfliteScatterNdParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.cc index b28b67d369..8853fca811 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.cc @@ -50,6 +50,24 @@ STATUS TfliteShapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteShapeParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + primitive->value.type = schema::PrimitiveType_Shape; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteShapeParser("Shape", new TfliteShapeParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.h index 42013f29a8..5085c15889 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.h @@ -32,6 +32,8 @@ class TfliteShapeParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.cc index 89a1a0cac8..6378f71adf 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.cc @@ -59,6 +59,33 @@ STATUS TfliteSkipGramParser::Parse(TfliteTensorsInfo *tensors_info, const std::u AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteSkipGramParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsSkipGramOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op attr failed"; + return nullptr; + } + attr->includeAllGrams = tflite_attr->include_all_ngrams; + attr->maxSkipSize = tflite_attr->max_skip_size; + attr->ngramSize = tflite_attr->ngram_size; + + primitive->value.type = schema::PrimitiveType_SkipGram; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteSkiGramParser("SKipGram", new TfliteSkipGramParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.h index 56d80d05b8..369b8ed13f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.h @@ -32,6 +32,8 @@ class TfliteSkipGramParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc index e58441a40c..d5a62250e0 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc @@ -66,6 +66,41 @@ STATUS TfliteSliceParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteSliceParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto &tflite_subgraph = tflite_model->subgraphs.front(); + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + attr->format = schema::Format::Format_NHWC; + + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->begin)) { + MS_LOG(ERROR) << "get slice -> begin failed"; + return nullptr; + } + if (GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->size)) { + MS_LOG(ERROR) << "get slice -> size failed"; + return nullptr; + } + std::vector axes; + axes.clear(); + for (size_t i = 0; i < attr->begin.size(); ++i) { + axes.push_back(i); + } + attr->axes = axes; + primitive->value.type = schema::PrimitiveType_Slice; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteSliceParser("Slice", new TfliteSliceParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h index 48801162b6..f6ffd1ddb1 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h @@ -32,6 +32,8 @@ class TfliteSliceParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc index 53a36fc93c..27d515f924 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc @@ -17,7 +17,6 @@ #include "tools/converter/parser/tflite/tflite_softmax_parser.h" #include #include -#include "src/ops/softmax.h" namespace mindspore::lite { STATUS TfliteSoftmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, @@ -53,12 +52,19 @@ STATUS TfliteSoftmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::un return RET_OK; } -STATUS TfliteSoftmaxParser::Parse(const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, PrimitiveC *primitiveC) { - auto softmaxPrimitive = new SoftMax(); - softmaxPrimitive->SetAxis(-1); - primitiveC = softmaxPrimitive; - return RET_OK; +PrimitiveC *TfliteSoftmaxParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + attr->axis = -1; + auto primitive = std::make_unique(); + primitive->value.type = schema::PrimitiveType_SoftMax; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } TfliteNodeRegister g_tfliteSoftmaxParser("Softmax", new TfliteSoftmaxParser()); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h index f5da7a7ec8..9339feba60 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h @@ -31,8 +31,9 @@ class TfliteSoftmaxParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; - STATUS Parse(const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, - PrimitiveC *primitiveC) override; + + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc index 297c047912..fcebe459e8 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc @@ -62,6 +62,34 @@ STATUS TfliteSpaceToBatchNDParser::Parse(TfliteTensorsInfo *tensors_info, AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteSpaceToBatchNDParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto &tflite_subgraph = tflite_model->subgraphs.front(); + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->blockShape)) { + MS_LOG(ERROR) << "get spaceToBatchND -> blockShape failed"; + return nullptr; + } + if (GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->paddings)) { + MS_LOG(ERROR) << "get spaceToBatchND -> paddings failed"; + return nullptr; + } + + primitive->value.type = schema::PrimitiveType_SpaceToBatchND; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteSpaceToBatchNDParser("SpaceToBatchND", new TfliteSpaceToBatchNDParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h index cfd2794c07..fa55f1d681 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h @@ -32,6 +32,8 @@ class TfliteSpaceToBatchNDParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc index c139e432b6..863e0c5d2c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc @@ -60,6 +60,32 @@ STATUS TfliteSpaceToDepthParser::Parse(TfliteTensorsInfo *tensors_info, AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteSpaceToDepthParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsSpaceToDepthOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op space to depth attr failed"; + return nullptr; + } + attr->blockSize = tflite_attr->block_size; + attr->format = schema::Format::Format_NHWC; + + primitive->value.type = schema::PrimitiveType_SpaceToDepth; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteSpaceToDepthParser("SpaceToDepth", new TfliteSpaceToDepthParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.h index b3b6370817..f2c4c110ae 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.h @@ -32,6 +32,8 @@ class TfliteSpaceToDepthParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc index 789dd965ed..b1cd394429 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc @@ -57,6 +57,25 @@ STATUS TfliteSparseToDenseParser::Parse(TfliteTensorsInfo *tensors_info, AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteSparseToDenseParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + attr->validateIndices = false; + primitive->value.type = schema::PrimitiveType_SparseToDense; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteSparseToDenseParser("SparseToDense", new TfliteSparseToDenseParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h index 6f9a56f1ba..6cefe92a67 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h @@ -32,6 +32,8 @@ class TfliteSparseToDenseParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc index c38cf5def2..12bd11c550 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc @@ -96,6 +96,63 @@ STATUS TfliteSplitParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq } return RET_OK; } +PrimitiveC *TfliteSplitParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto primitive = std::make_unique(); + auto &tflite_subgraph = tflite_model->subgraphs.front(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsSplitOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op split attr failed"; + return nullptr; + } + auto num_splits = tflite_attr->num_splits; + + const auto &shape_tensor = tflite_subgraph->tensors[tflite_op->inputs[1]]; + if (shape_tensor == nullptr) { + MS_LOG(ERROR) << "shape_tensor is null"; + return nullptr; + } + const auto tensor_shape = shape_tensor->shape; + const auto &axis_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; + if (axis_tensor == nullptr) { + MS_LOG(ERROR) << "axis_tensor is null"; + return nullptr; + } + auto axis = *(reinterpret_cast(tflite_model->buffers[axis_tensor->buffer]->data.data())); + if (axis < 0) { + axis += tensor_shape.size(); + } + if (axis >= static_cast(tensor_shape.size())) { + MS_LOG(ERROR) << "axis value is too large"; + return nullptr; + } + attr->splitDim = axis; + if (tensor_shape[axis] % num_splits != 0 && tensor_shape[axis] / num_splits != 0) { + MS_LOG(ERROR) << "num_splits can't divide tensor's length at axis " << axis; + return nullptr; + } + attr->numberSplit = num_splits; + if (tensor_shape[axis] / num_splits != 0) { + for (int i = 0; i < num_splits; i++) { + attr->sizeSplits.push_back(tensor_shape[axis] / num_splits); + } + } + + primitive->value.type = schema::PrimitiveType_Split; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteSplitParser("Split", new TfliteSplitParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.h index a430eed7dd..b6ec9ed66c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.h @@ -32,6 +32,8 @@ class TfliteSplitParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.cc index b02ccd484b..962ab8992a 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.cc @@ -91,6 +91,58 @@ STATUS TfliteSplitVParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni } return RET_OK; } +PrimitiveC *TfliteSplitVParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto &tflite_subgraph = tflite_model->subgraphs.front(); + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsSplitVOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op splitv attr failed"; + return nullptr; + } + attr->numberSplit = tflite_attr->num_splits; + + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->sizeSplits)) { + MS_LOG(ERROR) << "get spliteV -> sizeSplits failed"; + return nullptr; + } + + const auto &tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; + if (tensor == nullptr) { + MS_LOG(ERROR) << "tensor_shape is null"; + return nullptr; + } + auto tensor_shape = tensor->shape; + const auto &axis_tensor = tflite_subgraph->tensors[tflite_op->inputs[2]]; + if (axis_tensor == nullptr) { + MS_LOG(ERROR) << "axis_tensor is null"; + return nullptr; + } + auto axis = *(reinterpret_cast(tflite_model->buffers[axis_tensor->buffer]->data.data())); + if (axis < 0) { + axis += tensor_shape.size(); + } + if (axis >= static_cast(tensor_shape.size())) { + MS_LOG(ERROR) << "axis value is too large"; + return nullptr; + } + attr->splitDim = axis; + + primitive->value.type = schema::PrimitiveType_Split; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteSplitVParser("SplitV", new TfliteSplitVParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.h index 002a8f7e9b..b167deabb9 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.h @@ -32,6 +32,8 @@ class TfliteSplitVParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc index 61de999fc9..58789ba7be 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc @@ -57,6 +57,31 @@ STATUS TfliteSqueezeParser::Parse(TfliteTensorsInfo *tensors_info, const std::un AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteSqueezeParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsSqueezeOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op squeeze attr failed"; + return nullptr; + } + attr->axis = tflite_attr->squeeze_dims; + + primitive->value.type = schema::PrimitiveType_Squeeze; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteSqueezeParser("Squeeze", new TfliteSqueezeParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.h index 538a4f4251..6173434a43 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.h @@ -32,6 +32,8 @@ class TfliteSqueezeParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc index f810b4f611..2e684cb0a8 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc @@ -62,6 +62,35 @@ STATUS TfliteStackParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteStackParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto &tflite_subgraph = tflite_model->subgraphs.front(); + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsPackOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op stack attr failed"; + return nullptr; + } + attr->axis = tflite_attr->axis; + attr->n = tflite_attr->values_count; + attr->isScale.assign(tflite_subgraph->tensors[tflite_op->inputs[0]]->shape.begin(), + tflite_subgraph->tensors[tflite_op->inputs[0]]->shape.end()); + + primitive->value.type = schema::PrimitiveType_Stack; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteStackParser("Stack", new TfliteStackParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h index e67d5f47b5..263e08b114 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h @@ -32,6 +32,8 @@ class TfliteStackParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc index b9dab37019..48a7a34887 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc @@ -84,6 +84,56 @@ STATUS TfliteStridedSliceParser::Parse(TfliteTensorsInfo *tensors_info, AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteStridedSliceParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto &tflite_subgraph = tflite_model->subgraphs.front(); + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsStridedSliceOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op strideslice attr failed"; + return nullptr; + } + attr->beginMask = tflite_attr->begin_mask; + attr->endMask = tflite_attr->end_mask; + attr->ellipsisMask = tflite_attr->ellipsis_mask; + attr->newAxisMask = tflite_attr->new_axis_mask; + attr->shrinkAxisMask = tflite_attr->shrink_axis_mask; + + int status = GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->begin); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "stridedSlice -> begin get failed"; + return nullptr; + } else if (status == RET_OK) { + status = GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->end); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "stridedSlice -> end get failed"; + return nullptr; + } else if (status == RET_OK) { + status = GetTfliteData(tflite_op->inputs[3], tflite_subgraph->tensors, tflite_model->buffers, attr->stride); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "stridedSlice -> stride get failed"; + return nullptr; + } + } + } + attr->isScale.assign(tflite_subgraph->tensors[tflite_op->inputs[0]]->shape.begin(), + tflite_subgraph->tensors[tflite_op->inputs[0]]->shape.end()); + + primitive->value.type = schema::PrimitiveType_StridedSlice; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteStridedSliceParser("StridedSlice", new TfliteStridedSliceParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.h index 36171e2061..ff538782ab 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.h @@ -32,6 +32,8 @@ class TfliteStridedSliceParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc index 74a2323d4b..975f2be160 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc @@ -60,6 +60,34 @@ STATUS TfliteTileParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteTileParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto &tflite_subgraph = tflite_model->subgraphs.front(); + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->multiples)) { + MS_LOG(ERROR) << "get tile -> multiples failed"; + return nullptr; + } + std::vector dims(attr->multiples.size(), 0); + for (size_t i = 0; i < dims.size(); ++i) { + dims[i] = i; + } + attr->dims = dims; + primitive->value.type = schema::PrimitiveType_Tile; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteTileParser("Tile", new TfliteTileParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.h index 11d703e276..e548d81645 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.h @@ -32,6 +32,8 @@ class TfliteTileParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc index 0e3ff760f6..d3db4ca03b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc @@ -62,6 +62,33 @@ STATUS TfliteTopKV2Parser::Parse(TfliteTensorsInfo *tensors_info, const std::uni } return RET_OK; } +PrimitiveC *TfliteTopKV2Parser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto &tflite_subgraph = tflite_model->subgraphs.front(); + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + attr->sorted = true; + std::vector k; + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, k)) { + MS_LOG(ERROR) << "get topKV2 -> k failed"; + return nullptr; + } + attr->k = k.front(); + + primitive->value.type = schema::PrimitiveType_TopK; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteTopKV2Parser("TopKV2", new TfliteTopKV2Parser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.h index f9f2c9b83b..e42e45c514 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.h @@ -32,6 +32,8 @@ class TfliteTopKV2Parser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc index 4792adb14c..7f928dd2e3 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc @@ -57,6 +57,31 @@ STATUS TfliteTransposeParser::Parse(TfliteTensorsInfo *tensors_info, AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteTransposeParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto &tflite_subgraph = tflite_model->subgraphs.front(); + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->perm)) { + MS_LOG(ERROR) << "get transpose -> perm failed"; + return nullptr; + } + + attr->conjugate = false; + primitive->value.type = schema::PrimitiveType_Transpose; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteTransposeParser("Transpose", new TfliteTransposeParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h index 6babc266e1..c76608943a 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h @@ -32,6 +32,8 @@ class TfliteTransposeParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.cc index f6951eb869..c2610ab3c5 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.cc @@ -60,6 +60,31 @@ STATUS TfliteUniqueParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni } return RET_OK; } +PrimitiveC *TfliteUniqueParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsUniqueOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op unique attr failed"; + return nullptr; + } + attr->outType = GetTfliteDataType(tflite_attr->idx_out_type); + + primitive->value.type = schema::PrimitiveType_Unique; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteUniqueParser("Unique", new TfliteUniqueParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.h index 4e7c98ac5b..ff2bb7df07 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.h @@ -32,6 +32,8 @@ class TfliteUniqueParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc index 76bf61d1c3..d495f54d34 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc @@ -61,6 +61,32 @@ STATUS TfliteUnstackParser::Parse(TfliteTensorsInfo *tensors_info, const std::un } return RET_OK; } +PrimitiveC *TfliteUnstackParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsUnpackOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op unstack attr failed"; + return nullptr; + } + attr->num = tflite_attr->num; + attr->axis = tflite_attr->axis; + + primitive->value.type = schema::PrimitiveType_Unstack; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteUnstackParser("Unstack", new TfliteUnstackParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.h index 873121b31b..9bac2395f7 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.h @@ -32,6 +32,8 @@ class TfliteUnstackParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc index 01c309a347..22a1c345ae 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc @@ -42,7 +42,7 @@ std::map tfMsOpTypeMap{ {tflite::BuiltinOperator_TRANSPOSE, "Transpose"}, {tflite::BuiltinOperator_PACK, "Stack"}, {tflite::BuiltinOperator_MEAN, "Mean"}, - {tflite::BuiltinOperator_RELU6, "Relu6"}, + {tflite::BuiltinOperator_RELU6, "ReLU6"}, {tflite::BuiltinOperator_TANH, "Tanh"}, {tflite::BuiltinOperator_RSQRT, "Rsqrt"}, {tflite::BuiltinOperator_ARG_MAX, "Argmax"}, @@ -51,7 +51,7 @@ std::map tfMsOpTypeMap{ {tflite::BuiltinOperator_TRANSPOSE_CONV, "DeConv2D"}, {tflite::BuiltinOperator_PAD, "Pad"}, {tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, "NearestNeighbor"}, - {tflite::BuiltinOperator_RELU, "Relu"}, + {tflite::BuiltinOperator_RELU, "ReLU"}, {tflite::BuiltinOperator_LEAKY_RELU, "LeakyRelu"}, {tflite::BuiltinOperator_SQUEEZE, "Squeeze"}, {tflite::BuiltinOperator_POW, "Pow"}, @@ -87,7 +87,7 @@ std::map tfMsOpTypeMap{ {tflite::BuiltinOperator_LOGICAL_NOT, "LogicalNot"}, {tflite::BuiltinOperator_LOGICAL_AND, "LogicalAnd"}, {tflite::BuiltinOperator_LOGICAL_OR, "LogicalOr"}, - {tflite::BuiltinOperator_HARD_SWISH, "HardSwish"}, + {tflite::BuiltinOperator_HARD_SWISH, "HSwish"}, {tflite::BuiltinOperator_SUM, "Sum"}, {tflite::BuiltinOperator_REDUCE_PROD, "ReduceProd"}, {tflite::BuiltinOperator_REDUCE_MAX, "ReduceMax"}, @@ -171,6 +171,16 @@ schema::PadMode GetPadMode(tflite::Padding tflite_padmode) { } } +std::string GetPadModeStr(tflite::Padding tflite_padmode) { + if (tflite_padmode == tflite::Padding_SAME) { + return "same"; + } else if (tflite_padmode == tflite::Padding_VALID) { + return "valid"; + } else { + return "pad"; + } +} + size_t GetDataTypeSize(const TypeId &data_type) { switch (data_type) { case TypeId::kNumberTypeFloat32: @@ -194,7 +204,7 @@ size_t GetDataTypeSize(const TypeId &data_type) { } STATUS getPaddingParam(const std::unique_ptr &tensor, schema::PadMode pad_mode, int strideH, - int strideW, int windowH, int windowW, std::vector *params) { + int strideW, int windowH, int windowW, std::vector *params) { if (tensor == nullptr) { MS_LOG(ERROR) << "the input tensor is null"; return RET_ERROR; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_util.h b/mindspore/lite/tools/converter/parser/tflite/tflite_util.h index f1f816f7fc..a2769e35a0 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_util.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_util.h @@ -32,6 +32,8 @@ namespace mindspore { namespace lite { schema::PadMode GetPadMode(tflite::Padding tflite_padmode); +std::string GetPadModeStr(tflite::Padding tflite_padmode); + size_t GetDataTypeSize(const TypeId &data_type); schema::ActivationType GetActivationFunctionType(tflite::ActivationFunctionType tfliteAFType); @@ -41,7 +43,7 @@ std::string GetMSOpType(tflite::BuiltinOperator tfliteOpType); TypeId GetTfliteDataType(const tflite::TensorType &tflite_data_type); STATUS getPaddingParam(const std::unique_ptr &tensor, schema::PadMode pad_mode, int strideH, - int strideW, int windowH, int windowW, std::vector *params); + int strideW, int windowH, int windowW, std::vector *params); void Split(const std::string &src_str, std::vector *dst_str, const std::string &chr); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.cc index d5985a7275..61246c0016 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.cc @@ -58,6 +58,30 @@ STATUS TfliteWhereParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteWhereParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto &tflite_subgraph = tflite_model->subgraphs.front(); + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + if (GetTfliteData(tflite_op->inputs[0], tflite_subgraph->tensors, tflite_model->buffers, attr->condition)) { + MS_LOG(ERROR) << "get where -> condition failed"; + return nullptr; + } + + primitive->value.type = schema::PrimitiveType_Where; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteWhereParser("Where", new TfliteWhereParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.h index dcb97bfedc..c28cb5cf96 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.h @@ -32,6 +32,8 @@ class TfliteWhereParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_while_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_while_parser.cc index 68aff392e6..b0792a3083 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_while_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_while_parser.cc @@ -63,6 +63,33 @@ STATUS TfliteWhileParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq } return RET_OK; } +PrimitiveC *TfliteWhileParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsWhileOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op while attr failed"; + return nullptr; + } + + attr->condSubgraphIndex = tflite_attr->cond_subgraph_index; + attr->bodySubgraphIndex = tflite_attr->body_subgraph_index; + + primitive->value.type = schema::PrimitiveType_While; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteWhileParser("While", new TfliteWhileParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_while_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_while_parser.h index 258c0caa28..e19cf8ce6a 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_while_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_while_parser.h @@ -32,6 +32,8 @@ class TfliteWhileParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc index 09c02fa032..19334623f3 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc @@ -52,6 +52,24 @@ STATUS TfliteZerosLikeParser::Parse(TfliteTensorsInfo *tensors_info, AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } +PrimitiveC *TfliteZerosLikeParser::ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is null"; + return nullptr; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + primitive->value.type = schema::PrimitiveType_ZerosLike; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} TfliteNodeRegister g_tfliteZerosLikeParser("ZerosLike", new TfliteZerosLikeParser()); } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.h index 9a412ee20d..f9e55cc0e0 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.h @@ -32,6 +32,8 @@ class TfliteZerosLikeParser : public TfliteNodeParser { STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; + PrimitiveC *ParseLitePrimitive(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc index baadfc4c28..4c47422e37 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc @@ -1182,6 +1182,29 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for return RET_ERROR; } } break; + case schema::Format::Format_CHWK: { + switch (src_format) { + case schema::Format::Format_KHWC: + if (data_type == kNumberTypeFloat32) { + status = TransFilterFormat(tensor, kKHWC2CHWK); + } else if (data_type == kNumberTypeUInt8) { + status = TransFilterFormat(tensor, kKHWC2CHWK); + } else if (data_type == kNumberTypeInt8) { + status = TransFilterFormat(tensor, kKHWC2CHWK); + } else if (data_type == kNumberTypeFloat16) { + status = TransFilterFormat(tensor, kKHWC2CHWK); + } else { + MS_LOG(ERROR) << "Unsupported data_type: " << data_type; + return RET_ERROR; + } + break; + case schema::Format::Format_CHWK: + return RET_OK; + default: + MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format); + return RET_ERROR; + } + } break; default: MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format); return RET_ERROR; diff --git a/mindspore/lite/tools/optimizer/graph/group_depthwise_op_convert_pass.cc b/mindspore/lite/tools/optimizer/graph/group_depthwise_op_convert_pass.cc new file mode 100644 index 0000000000..ff31b15eec --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/group_depthwise_op_convert_pass.cc @@ -0,0 +1,118 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/optimizer/graph/group_depthwise_op_convert_pass.h" +#include +#include +#include +#include "tools/optimizer/common/gllo_utils.h" +#include "src/ops/primitive_c.h" +#include "schema/inner/model_generated.h" +#include "src/tensor.h" +#include "tools/converter/quantizer/quant_cast.h" +#include "src/common/log_adapter.h" +#include "securec/include/securec.h" + +using mindspore::lite::PrimitiveC; +namespace mindspore::opt { +namespace { +constexpr size_t kConvWeightIndex = 2; +constexpr size_t kConvInputIndex = 1; +} // namespace +bool GroupDepthwiseOpConvertPass::Run(const FuncGraphPtr &graph) { + auto node_list = TopoSort(graph->get_return()); + for (auto &node : node_list) { + if (!utils::isa(node)) { + continue; + } + if (opt::GetCNodeType(node) != schema::PrimitiveType_DepthwiseConv2D) { + continue; + } + + auto depthwise_cnode = node->cast(); + auto depthwise_primitivec = GetValueNode>(depthwise_cnode->input(0)); + auto attr = depthwise_primitivec->primitiveT()->value.AsDepthwiseConv2D(); + if (attr == nullptr) { + MS_LOG(ERROR) << "the input of depthwiseConv2d is null"; + return false; + } + + auto data_node = depthwise_cnode->input(kConvInputIndex)->abstract(); + auto data_shape = utils::cast(data_node->GetShapeTrack())->shape(); + + auto conv_attr = std::make_unique(); + if (conv_attr == nullptr) { + MS_LOG(ERROR) << "conv_attr is null"; + return false; + } + + if (data_shape[3] == 1) { + conv_attr->channelIn = data_shape[3]; + conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier; + + // update attr + conv_attr->group = 1; + conv_attr->format = attr->format; + conv_attr->kernelH = attr->kernelH; + conv_attr->kernelW = attr->kernelW; + conv_attr->strideH = attr->strideH; + conv_attr->strideW = attr->strideW; + conv_attr->padMode = attr->padMode; + conv_attr->padUp = attr->padUp; + conv_attr->padDown = attr->padDown; + conv_attr->padLeft = attr->padLeft; + conv_attr->padRight = attr->padRight; + conv_attr->dilateH = attr->dilateH; + conv_attr->dilateW = attr->dilateW; + conv_attr->hasBias = attr->hasBias; + conv_attr->activationType = attr->activationType; + + depthwise_primitivec->primitiveT()->value.type = schema::PrimitiveType_Conv2D; + depthwise_primitivec->primitiveT()->value.value = conv_attr.release(); + + MS_ASSERT(depthwise_cnode->inputs().size() > kConvWeightIndex); + auto weight_node = depthwise_cnode->input(kConvWeightIndex); + MS_ASSERT(weight_node != nullptr); + auto weight_value = GetLiteParamValue(weight_node); + if (weight_value == nullptr) { + MS_LOG(ERROR) << "weight node must param value"; + return false; + } + MS_ASSERT(weight_value->tensor_type() == TypeId::kNumberTypeFloat32 || + weight_value->tensor_type() == TypeId::kNumberTypeInt8); + lite::STATUS status; + schema::Format weight_dst_format = schema::Format::Format_CHWK; + weight_value->set_format(schema::Format_KHWC); + status = TransFilterFormat(weight_value, weight_dst_format); + if (status == RET_OK) { + weight_value->set_format(weight_dst_format); + } else { + MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(schema::EnumValuesFormat()[weight_value->format()]) << "To" + << EnumNameFormat(weight_dst_format) << " failed, node : " << node->fullname_with_scope(); + return false; + } + auto type_id = static_cast(weight_value->tensor_type()); + auto type_ptr = TypeIdToType(type_id); + auto shape = weight_value->tensor_shape(); + std::vector shape_vector; + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), + [](const int32_t &value) { return static_cast(value); }); + auto abstract_tensor = std::make_shared(type_ptr, shape_vector); + weight_node->set_abstract(abstract_tensor); + } + } + return true; +} +} // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/graph/group_depthwise_op_convert_pass.h b/mindspore/lite/tools/optimizer/graph/group_depthwise_op_convert_pass.h new file mode 100644 index 0000000000..fd696c22e5 --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/group_depthwise_op_convert_pass.h @@ -0,0 +1,32 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef LITE_GROUP_DEPTHWISE_OP_CONVERT_PASS_H +#define LITE_GROUP_DEPTHWISE_OP_CONVERT_PASS_H +#include +#include "schema/inner/model_generated.h" +#include "tools/converter/converter_flags.h" +#include "backend/optimizer/common/pass.h" +#include "src/param_value_lite.h" + +namespace mindspore::opt { +class GroupDepthwiseOpConvertPass : public Pass { + public: + GroupDepthwiseOpConvertPass() : Pass("group_depthwise_op_convert_pass") {} + ~GroupDepthwiseOpConvertPass() override = default; + bool Run(const FuncGraphPtr &graph) override; +}; +} // namespace mindspore::opt +#endif // LITE_GROUP_DEPTHWISE_OP_CONVERT_PASS_H diff --git a/mindspore/lite/tools/optimizer/graph/tflite_inputs_order_exchange_pass.cc b/mindspore/lite/tools/optimizer/graph/tflite_inputs_order_exchange_pass.cc new file mode 100644 index 0000000000..2f8b08c1bf --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/tflite_inputs_order_exchange_pass.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/optimizer/graph/tflite_inputs_order_exchange_pass.h" +#include +#include +#include "tools/optimizer/common/gllo_utils.h" +#include "schema/inner/model_generated.h" +#include "tools/converter/quantizer/quant_cast.h" + +using mindspore::lite::PrimitiveC; +namespace mindspore::opt { +namespace { +constexpr size_t split_inputs_size = 3; +} // namespace +bool TfliteInputsOrderExchangePass::Run(const FuncGraphPtr &graph) { + auto node_list = TopoSort(graph->get_return()); + for (auto &node : node_list) { + if (!utils::isa(node)) { + continue; + } + auto cnode = node->cast(); + auto primitive_c = GetValueNode>(cnode->input(0)); + if (opt::GetCNodeType(node) == schema::PrimitiveType_DeConv2D) { + cnode->set_input(1, cnode->input(3)); + auto inputs = cnode->inputs(); + inputs.pop_back(); + cnode->set_inputs(inputs); + continue; + } + + if (opt::GetCNodeType(node) == schema::PrimitiveType_Split && cnode->inputs().size() == split_inputs_size) { + cnode->set_input(1, cnode->input(2)); + continue; + } + + if (opt::GetCNodeType(node) == schema::PrimitiveType_Reduce || + opt::GetCNodeType(node) == schema::PrimitiveType_StridedSlice || + opt::GetCNodeType(node) == schema::PrimitiveType_ArgMin || + opt::GetCNodeType(node) == schema::PrimitiveType_ArgMax || + opt::GetCNodeType(node) == schema::PrimitiveType_SpaceToBatch || + opt::GetCNodeType(node) == schema::PrimitiveType_BatchToSpace || + opt::GetCNodeType(node) == schema::PrimitiveType_SpaceToBatchND || + opt::GetCNodeType(node) == schema::PrimitiveType_BatchToSpaceND || + opt::GetCNodeType(node) == schema::PrimitiveType_SpaceToDepth || + (opt::GetCNodeType(node) == schema::PrimitiveType_Pad && + primitive_c->primitiveT()->value.AsPad()->paddingMode == schema::PaddingMode_CONSTANT) || + (opt::GetCNodeType(node) == schema::PrimitiveType_Resize && + primitive_c->primitiveT()->value.AsResize()->newHeight != 0 && + primitive_c->primitiveT()->value.AsResize()->newWidth != 0)) { + std::vector new_inputs; + new_inputs.emplace_back(cnode->input(0)); + new_inputs.emplace_back(cnode->input(1)); + cnode->set_inputs(new_inputs); + continue; + } + } + return true; +} +} // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/graph/tflite_inputs_order_exchange_pass.h b/mindspore/lite/tools/optimizer/graph/tflite_inputs_order_exchange_pass.h new file mode 100644 index 0000000000..566cec6090 --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/tflite_inputs_order_exchange_pass.h @@ -0,0 +1,33 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef LITE_TFLITE_INPUTS_ORDER_EXCHANGE_PASS_H +#define LITE_TFLITE_INPUTS_ORDER_EXCHANGE_PASS_H + +#include +#include "schema/inner/model_generated.h" +#include "tools/converter/converter_flags.h" +#include "backend/optimizer/common/pass.h" +#include "src/param_value_lite.h" + +namespace mindspore::opt { +class TfliteInputsOrderExchangePass : public Pass { + public: + TfliteInputsOrderExchangePass() : Pass("tflite_inputs_order_exchange_pass") {} + ~TfliteInputsOrderExchangePass() override = default; + bool Run(const FuncGraphPtr &graph) override; +}; +} // namespace mindspore::opt +#endif // LITE_TFLITE_INPUTS_ORDER_EXCHANGE_PASS_H