| @@ -23,71 +23,6 @@ | |||
| #include "tools/converter/parser/tflite/tflite_util.h" | |||
| namespace mindspore::lite { | |||
| STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::vector<std::string> node_name_str; | |||
| Split(op->name, &node_name_str, "-"); | |||
| const char *node_name = node_name_str.data()->c_str(); | |||
| if (std::strcmp(node_name, "Relu") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteReluParser"; | |||
| attr->type = schema::ActivationType_RELU; | |||
| } else if (std::strcmp(node_name, "Relu6") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteRelu6Parser"; | |||
| attr->type = schema::ActivationType_RELU6; | |||
| } else if (std::strcmp(node_name, "Tanh") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteTanhParser"; | |||
| attr->type = schema::ActivationType_TANH; | |||
| } else if (std::strcmp(node_name, "Logistic") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteLogisticParser"; | |||
| attr->type = schema::ActivationType_SIGMOID; | |||
| } else if (std::strcmp(node_name, "Swish") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteSwishParser"; | |||
| attr->type = schema::ActivationType_SWISH; | |||
| } else if (std::strcmp(node_name, "HardSwish") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteHardSwishParser"; | |||
| attr->type = schema::ActivationType_HSWISH; | |||
| } else if (std::strcmp(node_name, "LeakyRelu") == 0) { | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsLeakyReluOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->alpha = tflite_attr->alpha; | |||
| attr->type = schema::ActivationType_LEAKY_RELU; | |||
| } else { | |||
| MS_LOG(ERROR) << node_name << " hasn't been supported"; | |||
| return RET_NOT_FIND_OP; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Activation; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| lite::PrimitiveC *TfliteActivationParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>(); | |||
| @@ -117,11 +52,10 @@ lite::PrimitiveC *TfliteActivationParser::ParseLitePrimitive(const std::unique_p | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_TfliteReluParser("ReLU", new TfliteActivationParser()); | |||
| TfliteNodeRegister g_TfliteRelu6Parser("ReLU6", new TfliteActivationParser()); | |||
| TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteActivationParser()); | |||
| TfliteNodeRegister g_TfliteSwishParser("Swish", new TfliteActivationParser()); | |||
| TfliteNodeRegister g_TfliteHardSwishParser("HSwish", new TfliteActivationParser()); | |||
| TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteActivationParser()); | |||
| TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteActivationParser()); | |||
| TfliteNodeRegister g_TfliteReluParser(tflite::BuiltinOperator_RELU, new TfliteActivationParser()); | |||
| TfliteNodeRegister g_TfliteRelu6Parser(tflite::BuiltinOperator_RELU6, new TfliteActivationParser()); | |||
| TfliteNodeRegister g_TfliteTanhParser(tflite::BuiltinOperator_TANH, new TfliteActivationParser()); | |||
| TfliteNodeRegister g_TfliteSwishParser(tflite::BuiltinOperator_HARD_SWISH, new TfliteActivationParser()); | |||
| TfliteNodeRegister g_tfliteLogisticParser(tflite::BuiltinOperator_LOGISTIC, new TfliteActivationParser()); | |||
| TfliteNodeRegister g_TfliteLeakyReluParser(tflite::BuiltinOperator_LEAKY_RELU, new TfliteActivationParser()); | |||
| } // namespace mindspore::lite | |||
| @@ -28,10 +28,6 @@ class TfliteActivationParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteActivationParser() : TfliteNodeParser("node_name") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -22,40 +22,6 @@ | |||
| #include "src/ops/addn.h" | |||
| namespace mindspore::lite { | |||
| STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteAddNParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::AddNT> attr = std::make_unique<schema::AddNT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->N = tflite_subgraph->tensors.size() - 1; | |||
| op->primitive->value.type = schema::PrimitiveType_AddN; | |||
| op->primitive->value.value = attr.release(); | |||
| for (int input : tflite_op->inputs) { | |||
| AddOpInput(op, tensors_info, input, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| lite::PrimitiveC *TfliteAddNParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto attr = std::make_unique<schema::AddNT>(); | |||
| @@ -69,5 +35,5 @@ lite::PrimitiveC *TfliteAddNParser::ParseLitePrimitive(const std::unique_ptr<tfl | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteAddNParser("AddN", new TfliteAddNParser()); | |||
| TfliteNodeRegister g_tfliteAddNParser(tflite::BuiltinOperator_ADD_N, new TfliteAddNParser()); | |||
| } // namespace mindspore::lite | |||
| @@ -29,10 +29,6 @@ class TfliteAddNParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteAddNParser() : TfliteNodeParser("AddN") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -19,63 +19,7 @@ | |||
| #include <vector> | |||
| #include <map> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteArgmaxParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::ArgMaxT> attr = std::make_unique<schema::ArgMaxT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->outMaxValue = false; | |||
| attr->topK = 1; | |||
| attr->keepDims = false; | |||
| attr->axisType = 1; | |||
| // get axis attr | |||
| auto axis_idx = tflite_op->inputs[1]; | |||
| auto axis_tensor = tflite_subgraph->tensors[axis_idx].get(); | |||
| if (axis_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "axis_tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto buffer_idx = axis_tensor->buffer; | |||
| auto &buf_data = tflite_model->buffers[buffer_idx]; | |||
| if (buf_data == nullptr) { | |||
| MS_LOG(ERROR) << "the buf data is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto data_ptr = buf_data->data.data(); | |||
| if (data_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "the data is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->axis = *(static_cast<int32_t *>(static_cast<void *>(data_ptr))); | |||
| op->primitive->value.type = schema::PrimitiveType_ArgMax; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| namespace mindspore::lite { | |||
| PrimitiveC *TfliteArgmaxParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| const auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||
| @@ -110,6 +54,5 @@ PrimitiveC *TfliteArgmaxParser::ParseLitePrimitive(const std::unique_ptr<tflite: | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteArgmaxParser("Argmax", new TfliteArgmaxParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| TfliteNodeRegister g_tfliteArgmaxParser(tflite::BuiltinOperator_ARG_MAX, new TfliteArgmaxParser()); | |||
| } // namespace mindspore::lite | |||
| @@ -29,10 +29,6 @@ class TfliteArgmaxParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteArgmaxParser() : TfliteNodeParser("Argmax") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -19,63 +19,7 @@ | |||
| #include <vector> | |||
| #include <map> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteArgminParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::ArgMinT> attr = std::make_unique<schema::ArgMinT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->outMaxValue = false; | |||
| attr->topK = 1; | |||
| attr->keepDims = false; | |||
| attr->axisType = 1; | |||
| // get axis attr | |||
| auto axis_idx = tflite_op->inputs[1]; | |||
| auto axis_tensor = tflite_subgraph->tensors[axis_idx].get(); | |||
| if (axis_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "axis_tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto buffer_idx = axis_tensor->buffer; | |||
| auto &buf_data = tflite_model->buffers[buffer_idx]; | |||
| if (buf_data == nullptr) { | |||
| MS_LOG(ERROR) << "the buf data is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto data_ptr = buf_data->data.data(); | |||
| if (data_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "the data is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->axis = *(static_cast<int32_t *>(static_cast<void *>(data_ptr))); | |||
| op->primitive->value.type = schema::PrimitiveType_ArgMin; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| namespace mindspore::lite { | |||
| PrimitiveC *TfliteArgminParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| const auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||
| @@ -110,6 +54,5 @@ PrimitiveC *TfliteArgminParser::ParseLitePrimitive(const std::unique_ptr<tflite: | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteArgminParser("Argmin", new TfliteArgminParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| TfliteNodeRegister g_tfliteArgminParser(tflite::BuiltinOperator_ARG_MIN, new TfliteArgminParser()); | |||
| } // namespace mindspore::lite | |||
| @@ -23,19 +23,14 @@ | |||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| namespace mindspore::lite { | |||
| class TfliteArgminParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteArgminParser() : TfliteNodeParser("Argmin") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| } // namespace mindspore::lite | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H | |||
| @@ -19,166 +19,7 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteDoubleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::vector<std::string> node_name_str; | |||
| Split(op->name, &node_name_str, "-"); | |||
| const char *node_name = node_name_str.data()->c_str(); | |||
| if (std::strcmp(node_name, "Add") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteAddParser"; | |||
| auto attr = std::make_unique<schema::AddT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tfliteAttr = tflite_op->builtin_options.AsAddOptions(); | |||
| if (nullptr == tfliteAttr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); | |||
| op->primitive->value.type = schema::PrimitiveType_Add; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "Sub") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteSubParser"; | |||
| auto attr = std::make_unique<schema::SubT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tfliteAttr = tflite_op->builtin_options.AsSubOptions(); | |||
| if (nullptr == tfliteAttr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); | |||
| op->primitive->value.type = schema::PrimitiveType_Sub; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "Mul") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteMulParser"; | |||
| auto attr = std::make_unique<schema::MulT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tfliteAttr = tflite_op->builtin_options.AsMulOptions(); | |||
| if (nullptr == tfliteAttr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); | |||
| op->primitive->value.type = schema::PrimitiveType_Mul; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "Div") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteDivParser"; | |||
| auto attr = std::make_unique<schema::DivT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tfliteAttr = tflite_op->builtin_options.AsDivOptions(); | |||
| if (nullptr == tfliteAttr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); | |||
| op->primitive->value.type = schema::PrimitiveType_Div; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "FloorDiv") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteFloorDivParser"; | |||
| std::unique_ptr<schema::FloorDivT> attr = std::make_unique<schema::FloorDivT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_FloorDiv; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "FloorMod") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteFloorModParser"; | |||
| auto attr = std::make_unique<schema::FloorModT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_FloorMod; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "RealDiv") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteRealDivParser"; | |||
| std::unique_ptr<schema::RealDivT> attr = std::make_unique<schema::RealDivT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Div; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "SquaredDifference") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteSquaredDifferenceParser"; | |||
| auto attr = std::make_unique<schema::SquaredDifferenceT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_SquaredDifference; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "Pow") == 0) { | |||
| MS_LOG(DEBUG) << "parse TflitePowParser"; | |||
| auto attr = std::make_unique<schema::PowerT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->power = 1.0f; | |||
| attr->scale = 1.0f; | |||
| attr->shift = 0.0f; | |||
| op->primitive->value.type = schema::PrimitiveType_Power; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "Maximum") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteMaximumParser"; | |||
| auto attr = std::make_unique<schema::MaximumT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Maximum; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "Minimum") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteMinimumParser"; | |||
| auto attr = std::make_unique<schema::MinimumT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Minimum; | |||
| op->primitive->value.value = attr.release(); | |||
| } else { | |||
| MS_LOG(ERROR) << node_name << " hasn't been supported"; | |||
| return RET_NOT_FIND_OP; | |||
| } | |||
| // set input | |||
| for (int input : tflite_op->inputs) { | |||
| AddOpInput(op, tensors_info, input, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| namespace mindspore::lite { | |||
| PrimitiveC *TfliteDoubleInputOpParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; | |||
| @@ -307,146 +148,6 @@ PrimitiveC *TfliteDoubleInputOpParser::ParseLitePrimitive(const std::unique_ptr< | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::vector<std::string> node_name_str; | |||
| Split(op->name, &node_name_str, "-"); | |||
| const char *node_name = node_name_str.data()->c_str(); | |||
| if (std::strcmp(node_name, "Abs") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteAbsParser"; | |||
| auto attr = std::make_unique<schema::AbsT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Abs; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "Exp") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteExpParser"; | |||
| auto attr = std::make_unique<schema::ExpT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->base = -1; // -1 represent base = e | |||
| attr->scale = 1; | |||
| attr->shift = 0; | |||
| op->primitive->value.type = schema::PrimitiveType_Exp; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "Sqrt") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteSqrtParser"; | |||
| auto attr = std::make_unique<schema::SqrtT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Sqrt; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "Rsqrt") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteRsqrtParser"; | |||
| auto attr = std::make_unique<schema::RsqrtT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Rsqrt; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "Square") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteSquareParser"; | |||
| auto attr = std::make_unique<schema::SquareT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Square; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "Sin") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteSinParser"; | |||
| auto attr = std::make_unique<schema::SinT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Sin; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "Cos") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteCosParser"; | |||
| std::unique_ptr<schema::CosT> attr = std::make_unique<schema::CosT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Cos; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "Log") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteLogParser"; | |||
| auto attr = std::make_unique<schema::LogT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Log; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "Round") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteRoundParser"; | |||
| auto attr = std::make_unique<schema::RoundT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Round; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "Ceil") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteCeilParser"; | |||
| auto attr = std::make_unique<schema::CeilT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Ceil; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "flOOR") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteFloorParser"; | |||
| auto attr = std::make_unique<schema::FloorT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Floor; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "Neg") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteNegParser"; | |||
| auto attr = std::make_unique<schema::NegT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Neg; | |||
| op->primitive->value.value = attr.release(); | |||
| } else { | |||
| MS_LOG(ERROR) << node_name << " hasn't been supported"; | |||
| return RET_NOT_FIND_OP; | |||
| } | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteSingleInputOpParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; | |||
| @@ -566,91 +267,6 @@ PrimitiveC *TfliteSingleInputOpParser::ParseLitePrimitive(const std::unique_ptr< | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::vector<std::string> node_name_str; | |||
| Split(op->name, &node_name_str, "-"); | |||
| const char *node_name = node_name_str.data()->c_str(); | |||
| if (std::strcmp(node_name, "Equal") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteEqualParser"; | |||
| std::unique_ptr<schema::EqualT> attr = std::make_unique<schema::EqualT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Equal; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "NotEqual") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteNotEqualParser"; | |||
| std::unique_ptr<schema::NotEqualT> attr = std::make_unique<schema::NotEqualT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_NotEqual; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "Greater") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteGreaterParser"; | |||
| std::unique_ptr<schema::GreaterT> attr = std::make_unique<schema::GreaterT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Greater; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "GreaterEqual") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteGreaterEqualParser"; | |||
| std::unique_ptr<schema::GreaterEqualT> attr = std::make_unique<schema::GreaterEqualT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_GreaterEqual; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "Less") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteLessParser"; | |||
| std::unique_ptr<schema::LessT> attr = std::make_unique<schema::LessT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Less; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "LessEqual") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteLessEqualParser"; | |||
| std::unique_ptr<schema::LessEqualT> attr = std::make_unique<schema::LessEqualT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_LessEqual; | |||
| op->primitive->value.value = attr.release(); | |||
| } else { | |||
| MS_LOG(ERROR) << node_name << " hasn't been supported"; | |||
| return RET_NOT_FIND_OP; | |||
| } | |||
| for (int input : tflite_op->inputs) { | |||
| AddOpInput(op, tensors_info, input, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteCompareOpParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; | |||
| @@ -714,36 +330,35 @@ PrimitiveC *TfliteCompareOpParser::ParseLitePrimitive(const std::unique_ptr<tfli | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteAddParser("Add", new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteSubParser("Sub", new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteMulParser("Mul", new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteDivParser("Div", new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteFloorDivParser("FloorDiv", new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteFloorModParser("FloorMod", new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteRealDivParser("RealDiv", new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_TflitePowParser("Pow", new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteSquaredDifferenceParser("SquaredDifference", new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteMaximumParser("Maximum", new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteMinimumParser("Minimum", new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteAddParser(tflite::BuiltinOperator_ADD, new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteSubParser(tflite::BuiltinOperator_SUB, new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteMulParser(tflite::BuiltinOperator_MUL, new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteDivParser(tflite::BuiltinOperator_DIV, new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteFloorDivParser(tflite::BuiltinOperator_FLOOR_DIV, new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteFloorModParser(tflite::BuiltinOperator_FLOOR_MOD, new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_TflitePowParser(tflite::BuiltinOperator_POW, new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteSquaredDifferenceParser(tflite::BuiltinOperator_SQUARED_DIFFERENCE, | |||
| new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteMaximumParser(tflite::BuiltinOperator_MAXIMUM, new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteMinimumParser(tflite::BuiltinOperator_MINIMUM, new TfliteDoubleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteAbsParser("Abs", new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteExpParser("Exp", new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteSqrtParser("Sqrt", new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteRsqrtParser("Rsqrt", new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteSquareParser("Square", new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteSinParser("Sin", new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteCosParser("Cos", new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteLogParser("Log", new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteRoundParser("Round", new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteCeilParser("Ceil", new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteNegParser("Neg", new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteAbsParser(tflite::BuiltinOperator_ABS, new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteExpParser(tflite::BuiltinOperator_EXP, new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteSqrtParser(tflite::BuiltinOperator_SQRT, new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteRsqrtParser(tflite::BuiltinOperator_RSQRT, new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteSquareParser(tflite::BuiltinOperator_SQUARE, new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteSinParser(tflite::BuiltinOperator_SIN, new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteCosParser(tflite::BuiltinOperator_COS, new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteLogParser(tflite::BuiltinOperator_LOG, new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteRoundParser(tflite::BuiltinOperator_ROUND, new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_TfliteCeilParser(tflite::BuiltinOperator_CEIL, new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteFloorParser(tflite::BuiltinOperator_FLOOR, new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteNegParser(tflite::BuiltinOperator_NEG, new TfliteSingleInputOpParser()); | |||
| TfliteNodeRegister g_tfliteEqualParser("Equal", new TfliteCompareOpParser()); | |||
| TfliteNodeRegister g_tfliteNotEqualParser("NotEqual", new TfliteCompareOpParser()); | |||
| TfliteNodeRegister g_tfliteGreaterEParser("Greater", new TfliteCompareOpParser()); | |||
| TfliteNodeRegister g_tfliteGreaterEqualParser("GreaterEqual", new TfliteCompareOpParser()); | |||
| TfliteNodeRegister g_tfliteLessParser("Less", new TfliteCompareOpParser()); | |||
| TfliteNodeRegister g_tfliteLessEqualParser("LessEqual", new TfliteCompareOpParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| TfliteNodeRegister g_tfliteEqualParser(tflite::BuiltinOperator_EQUAL, new TfliteCompareOpParser()); | |||
| TfliteNodeRegister g_tfliteNotEqualParser(tflite::BuiltinOperator_NOT_EQUAL, new TfliteCompareOpParser()); | |||
| TfliteNodeRegister g_tfliteGreaterEParser(tflite::BuiltinOperator_GREATER, new TfliteCompareOpParser()); | |||
| TfliteNodeRegister g_tfliteGreaterEqualParser(tflite::BuiltinOperator_GREATER_EQUAL, new TfliteCompareOpParser()); | |||
| TfliteNodeRegister g_tfliteLessParser(tflite::BuiltinOperator_LESS, new TfliteCompareOpParser()); | |||
| TfliteNodeRegister g_tfliteLessEqualParser(tflite::BuiltinOperator_LESS_EQUAL, new TfliteCompareOpParser()); | |||
| } // namespace mindspore::lite | |||
| @@ -29,10 +29,6 @@ class TfliteDoubleInputOpParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -41,10 +37,6 @@ class TfliteSingleInputOpParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -53,10 +45,6 @@ class TfliteCompareOpParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteCompareOpParser() : TfliteNodeParser("node_name") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -21,59 +21,7 @@ | |||
| #include <string> | |||
| #include <map> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::vector<std::string> node_name_str; | |||
| Split(op->name, &node_name_str, "-"); | |||
| const char *node_name = node_name_str.data()->c_str(); | |||
| if (std::strcmp(node_name, "BatchToSpace") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteBatchToSpaceParser"; | |||
| } else if (std::strcmp(node_name, "BatchToSpaceND") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteBatchToSpaceNDParser"; | |||
| } else { | |||
| MS_LOG(ERROR) << node_name << " hasn't been supported"; | |||
| return RET_NOT_FIND_OP; | |||
| } | |||
| std::unique_ptr<schema::BatchToSpaceT> attr = std::make_unique<schema::BatchToSpaceT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->blockShape)) { | |||
| MS_LOG(ERROR) << "get batchToSpace -> blockShape failed"; | |||
| return RET_ERROR; | |||
| } | |||
| if (GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->crops)) { | |||
| MS_LOG(ERROR) << "get batchToSpace -> crops failed"; | |||
| return RET_ERROR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_BatchToSpace; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| namespace mindspore::lite { | |||
| PrimitiveC *TfliteBatchToSpaceParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| const auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||
| @@ -98,7 +46,6 @@ PrimitiveC *TfliteBatchToSpaceParser::ParseLitePrimitive(const std::unique_ptr<t | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteBatchToSpaceParser("BatchToSpace", new TfliteBatchToSpaceParser()); | |||
| TfliteNodeRegister g_tfliteBatchToSpaceNDParser("BatchToSpaceND", new TfliteBatchToSpaceParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| TfliteNodeRegister g_tfliteBatchToSpaceNDParser(tflite::BuiltinOperator_BATCH_TO_SPACE_ND, | |||
| new TfliteBatchToSpaceParser()); | |||
| } // namespace mindspore::lite | |||
| @@ -23,20 +23,15 @@ | |||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| namespace mindspore::lite { | |||
| class TfliteBatchToSpaceParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteBatchToSpaceParser() : TfliteNodeParser("BatchToSpace") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| } // namespace mindspore::lite | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H | |||
| @@ -19,44 +19,7 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteBroadcastToParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteBroadcastToParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::BroadcastToT> attr = std::make_unique<schema::BroadcastToT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->dst_shape)) { | |||
| MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed"; | |||
| return RET_ERROR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_BroadcastTo; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| namespace mindspore::lite { | |||
| PrimitiveC *TfliteBroadcastToParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||
| @@ -80,6 +43,4 @@ PrimitiveC *TfliteBroadcastToParser::ParseLitePrimitive(const std::unique_ptr<tf | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteBroadcastToParser("BroadcastTo", new TfliteBroadcastToParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| } // namespace mindspore::lite | |||
| @@ -28,10 +28,6 @@ class TfliteBroadcastToParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteBroadcastToParser() : TfliteNodeParser("BroadcastTo") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -18,51 +18,7 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteCastParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteCastParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; | |||
| if (in_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->srcT = GetTfliteDataType(in_tensor->type); | |||
| const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]]; | |||
| if (out_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->dstT = GetTfliteDataType(out_tensor->type); | |||
| op->primitive->value.type = schema::PrimitiveType_Cast; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| namespace mindspore::lite { | |||
| PrimitiveC *TfliteCastParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||
| @@ -90,6 +46,5 @@ PrimitiveC *TfliteCastParser::ParseLitePrimitive(const std::unique_ptr<tflite::O | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteCastParser("Cast", new TfliteCastParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| TfliteNodeRegister g_tfliteCastParser(tflite::BuiltinOperator_CAST, new TfliteCastParser()); | |||
| } // namespace mindspore::lite | |||
| @@ -29,9 +29,6 @@ class TfliteCastParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteCastParser() : TfliteNodeParser("Cast") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -18,48 +18,7 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteConcatParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteConcatParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::ConcatT> attr = std::make_unique<schema::ConcatT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tfliteAttr = tflite_op->builtin_options.AsConcatenationOptions(); | |||
| if (tfliteAttr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->axis = tfliteAttr->axis; | |||
| attr->n = tflite_op->inputs.size(); | |||
| op->primitive->value.type = schema::PrimitiveType_Concat; | |||
| op->primitive->value.value = attr.release(); | |||
| for (int input : tflite_op->inputs) { | |||
| AddOpInput(op, tensors_info, input, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| namespace mindspore::lite { | |||
| PrimitiveC *TfliteConcatParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| @@ -81,6 +40,5 @@ PrimitiveC *TfliteConcatParser::ParseLitePrimitive(const std::unique_ptr<tflite: | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteConcatParser("Concat", new TfliteConcatParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| TfliteNodeRegister g_tfliteConcatParser(tflite::BuiltinOperator_CONCATENATION, new TfliteConcatParser()); | |||
| } // namespace mindspore::lite | |||
| @@ -29,9 +29,6 @@ class TfliteConcatParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteConcatParser() : TfliteNodeParser("Concat") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -19,82 +19,6 @@ | |||
| #include <memory> | |||
| namespace mindspore::lite { | |||
| STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteConvParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::Conv2DT> attr = std::make_unique<schema::Conv2DT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsConv2DOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->group = 1; | |||
| attr->strideW = tflite_attr->stride_w; | |||
| attr->strideH = tflite_attr->stride_h; | |||
| attr->dilateH = tflite_attr->dilation_h_factor; | |||
| attr->dilateW = tflite_attr->dilation_w_factor; | |||
| attr->padMode = GetPadMode(tflite_attr->padding); | |||
| attr->format = schema::Format::Format_NHWC; | |||
| attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); | |||
| attr->hasBias = true; | |||
| // get the conv op weight tensor | |||
| auto weight_index = tflite_op->inputs[1]; | |||
| const auto &weight_tensor = tflite_subgraph->tensors[weight_index]; | |||
| if (weight_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "the weight tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto weight_shape = weight_tensor->shape; | |||
| attr->channelIn = weight_shape[3]; | |||
| attr->channelOut = weight_shape[0]; | |||
| attr->kernelH = weight_shape[1]; | |||
| attr->kernelW = weight_shape[2]; | |||
| // calculate pad params | |||
| auto data_index = tflite_op->inputs[0]; | |||
| const auto &data_tensor = tflite_subgraph->tensors[data_index]; | |||
| std::vector<int64_t> params; | |||
| int status = | |||
| getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "get padding params failed"; | |||
| return RET_ERROR; | |||
| } else if (status == RET_OK) { | |||
| attr->padUp = params.at(0); | |||
| attr->padDown = params.at(1); | |||
| attr->padLeft = params.at(2); | |||
| attr->padRight = params.at(3); | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Conv2D; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_KHWC); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| lite::PrimitiveC *TfliteConvParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| const auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||
| @@ -153,5 +77,5 @@ lite::PrimitiveC *TfliteConvParser::ParseLitePrimitive(const std::unique_ptr<tfl | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteConv2DParser("Conv2D", new TfliteConvParser()); | |||
| TfliteNodeRegister g_tfliteConv2DParser(tflite::BuiltinOperator_CONV_2D, new TfliteConvParser()); | |||
| } // namespace mindspore::lite | |||
| @@ -28,9 +28,6 @@ class TfliteConvParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteConvParser() : TfliteNodeParser("Conv2D") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -207,70 +207,6 @@ STATUS TfliteCustomParser::BatchMatMul(const std::vector<uint8_t> &custom_attr, | |||
| return RET_OK; | |||
| } | |||
| STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteCustomParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &custom_attr = tflite_op->custom_options; | |||
| const auto opcode_index = tflite_op->opcode_index; | |||
| const auto &operator_code = tflite_model->operator_codes[opcode_index]; | |||
| if (operator_code == nullptr) { | |||
| MS_LOG(ERROR) << "operator_code is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &custom_type = operator_code->custom_code; | |||
| int status = RET_OK; | |||
| if (custom_type == "TFLite_Detection_PostProcess") { | |||
| status = DetectPostProcess(custom_attr, op, tflite_op); | |||
| } else if (custom_type == "Predict") { | |||
| status = Predict(custom_attr, op, tflite_op); | |||
| } else if (custom_type == "Normalize") { | |||
| status = Normalize(custom_attr, op, tflite_op); | |||
| } else if (custom_type == "ExtractFeatures") { | |||
| status = ExtractFeatures(custom_attr, op, tflite_op); | |||
| } else if (custom_type == "AudioSpectrogram") { | |||
| status = AudioSpectrogram(custom_attr, op, tflite_op); | |||
| } else if (custom_type == "Mfcc") { | |||
| status = Mfcc(custom_attr, op, tflite_op); | |||
| } else if (custom_type == "FlexRFFT") { | |||
| status = Rfft(custom_attr, op, tflite_op, tflite_model, tflite_subgraph); | |||
| } else if (custom_type == "FlexReal") { | |||
| status = FftReal(custom_attr, op, tflite_op); | |||
| } else if (custom_type == "FlexImag") { | |||
| status = FftImag(custom_attr, op, tflite_op); | |||
| } else if (custom_type == "FlexIdentityN" || custom_type == "FlexIdentity") { | |||
| status = Identity(custom_attr, op, tflite_op); | |||
| } else if (custom_type == "FlexBatchMatMul") { | |||
| status = BatchMatMul(custom_attr, op, tflite_op); | |||
| } else { | |||
| MS_LOG(ERROR) << "the custom op hasn't been supported now"; | |||
| status = RET_NOT_FIND_OP; | |||
| } | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| for (int input : tflite_op->inputs) { | |||
| AddOpInput(op, tensors_info, input, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| for (int output : tflite_op->outputs) { | |||
| AddOpOutput(op, tensors_info, output, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| return status; | |||
| } | |||
| PrimitiveC *TfliteCustomParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||
| @@ -314,6 +250,6 @@ PrimitiveC *TfliteCustomParser::ParseLitePrimitive(const std::unique_ptr<tflite: | |||
| return PrimitiveC::Create(primitive); | |||
| } | |||
| TfliteNodeRegister g_tfliteCustomParser("Custom", new TfliteCustomParser()); | |||
| TfliteNodeRegister g_tfliteCustomParser(tflite::BuiltinOperator_CUSTOM, new TfliteCustomParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -28,9 +28,6 @@ class TfliteCustomParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteCustomParser() : TfliteNodeParser("Custom") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| @@ -18,84 +18,7 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::DeConv2DT> attr = std::make_unique<schema::DeConv2DT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsTransposeConvOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->group = 1; | |||
| attr->strideW = tflite_attr->stride_w; | |||
| attr->strideH = tflite_attr->stride_h; | |||
| attr->dilateH = 1; | |||
| attr->dilateW = 1; | |||
| attr->padMode = GetPadMode(tflite_attr->padding); | |||
| attr->format = schema::Format::Format_NHWC; | |||
| attr->activationType = schema::ActivationType_NO_ACTIVATION; | |||
| attr->hasBias = true; | |||
| // get the conv op weight tensor | |||
| auto weight_index = tflite_op->inputs[1]; | |||
| const auto &weight_tensor = tflite_subgraph->tensors[weight_index]; | |||
| if (weight_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "the weight tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto weight_shape = weight_tensor->shape; | |||
| attr->channelIn = weight_shape[3]; | |||
| attr->channelOut = weight_shape[0]; | |||
| attr->kernelH = weight_shape[1]; | |||
| attr->kernelW = weight_shape[2]; | |||
| // calculate pad params | |||
| auto data_index = tflite_op->inputs[2]; | |||
| const auto &data_tensor = tflite_subgraph->tensors[data_index]; | |||
| std::vector<int64_t> params; | |||
| int status = | |||
| getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "get padding params failed"; | |||
| return RET_ERROR; | |||
| } else if (status == RET_OK) { | |||
| attr->padUp = params.at(0); | |||
| attr->padDown = params.at(1); | |||
| attr->padLeft = params.at(2); | |||
| attr->padRight = params.at(3); | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_DeConv2D; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_KHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| namespace mindspore::lite { | |||
| PrimitiveC *TfliteDeConvParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| @@ -155,6 +78,5 @@ PrimitiveC *TfliteDeConvParser::ParseLitePrimitive(const std::unique_ptr<tflite: | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteDeConv2DParser("DeConv2D", new TfliteDeConvParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| TfliteNodeRegister g_tfliteDeConv2DParser(tflite::BuiltinOperator_TRANSPOSE_CONV, new TfliteDeConvParser()); | |||
| } // namespace mindspore::lite | |||
| @@ -23,19 +23,14 @@ | |||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| namespace mindspore::lite { | |||
| class TfliteDeConvParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteDeConvParser() : TfliteNodeParser("DeConv2D") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| } // namespace mindspore::lite | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H | |||
| @@ -21,45 +21,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteDepthToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::DepthToSpaceT> attr = std::make_unique<schema::DepthToSpaceT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsDepthToSpaceOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str(); | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->blockSize = tflite_attr->block_size; | |||
| attr->format = schema::Format::Format_NHWC; | |||
| op->primitive->value.type = schema::PrimitiveType_DepthToSpace; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteDepthToSpaceParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| std::unique_ptr<schema::DepthToSpaceT> attr = std::make_unique<schema::DepthToSpaceT>(); | |||
| @@ -81,6 +42,6 @@ PrimitiveC *TfliteDepthToSpaceParser::ParseLitePrimitive(const std::unique_ptr<t | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteDepthToSpaceParser("DepthToSpace", new TfliteDepthToSpaceParser()); | |||
| TfliteNodeRegister g_tfliteDepthToSpaceParser(tflite::BuiltinOperator_DEPTH_TO_SPACE, new TfliteDepthToSpaceParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TfliteDepthToSpaceParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteDepthToSpaceParser() : TfliteNodeParser("DepthToSpace") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -19,90 +19,6 @@ | |||
| #include <memory> | |||
| namespace mindspore::lite { | |||
| STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||
| schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::DepthwiseConv2DT> attr = std::make_unique<schema::DepthwiseConv2DT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsDepthwiseConv2DOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->strideW = tflite_attr->stride_w; | |||
| attr->strideH = tflite_attr->stride_h; | |||
| attr->dilateH = tflite_attr->dilation_h_factor; | |||
| attr->dilateW = tflite_attr->dilation_w_factor; | |||
| attr->padMode = GetPadMode(tflite_attr->padding); | |||
| attr->format = schema::Format::Format_NHWC; | |||
| attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); | |||
| attr->hasBias = true; | |||
| attr->channelMultiplier = tflite_attr->depth_multiplier; | |||
| // get the data tensor | |||
| auto data_index = tflite_op->inputs[1]; | |||
| const auto &data_tensor = tflite_subgraph->tensors[data_index]; | |||
| if (data_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "the data tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto data_shape = data_tensor->shape; | |||
| attr->channelIn = data_shape[3]; | |||
| // get the weight tensor | |||
| auto weight_index = tflite_op->inputs[1]; | |||
| const auto &weight_tensor = tflite_subgraph->tensors[weight_index]; | |||
| if (weight_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "the weight tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto weight_shape = weight_tensor->shape; | |||
| attr->kernelH = weight_shape[1]; | |||
| attr->kernelW = weight_shape[2]; | |||
| // calculate pad params | |||
| std::vector<int64_t> params; | |||
| int status = | |||
| getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "get padding params failed"; | |||
| return RET_ERROR; | |||
| } else if (status == RET_OK) { | |||
| attr->padUp = params.at(0); | |||
| attr->padDown = params.at(1); | |||
| attr->padLeft = params.at(2); | |||
| attr->padRight = params.at(3); | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_KHWC); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| lite::PrimitiveC *TfliteDepthwiseConv2DParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser"; | |||
| @@ -169,5 +85,6 @@ lite::PrimitiveC *TfliteDepthwiseConv2DParser::ParseLitePrimitive(const std::uni | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteDepthwiseConv2DParser("DepthwiseConv2D", new TfliteDepthwiseConv2DParser()); | |||
| TfliteNodeRegister g_tfliteDepthwiseConv2DParser(tflite::BuiltinOperator_DEPTHWISE_CONV_2D, | |||
| new TfliteDepthwiseConv2DParser()); | |||
| } // namespace mindspore::lite | |||
| @@ -28,10 +28,6 @@ class TfliteDepthwiseConv2DParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -19,62 +19,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteDequantizeNParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; | |||
| if (in_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "input tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]]; | |||
| if (out_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "output tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (GetTfliteDataType(in_tensor->type) != GetTfliteDataType(out_tensor->type) && | |||
| (GetTfliteDataType(in_tensor->type) == kNumberTypeInt8 || | |||
| GetTfliteDataType(in_tensor->type) == kNumberTypeUInt8)) { | |||
| std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->srcT = GetTfliteDataType(in_tensor->type); | |||
| attr->dstT = GetTfliteDataType(out_tensor->type); | |||
| op->primitive->value.value = attr.release(); | |||
| op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast; | |||
| } else { | |||
| std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->srcT = GetTfliteDataType(in_tensor->type); | |||
| attr->dstT = GetTfliteDataType(out_tensor->type); | |||
| op->primitive->value.value = attr.release(); | |||
| op->primitive->value.type = schema::PrimitiveType_Cast; | |||
| } | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteDequantizeParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||
| @@ -115,6 +59,6 @@ PrimitiveC *TfliteDequantizeParser::ParseLitePrimitive(const std::unique_ptr<tfl | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteDequantizeParser("DEQUANTIZE", new TfliteDequantizeParser()); | |||
| TfliteNodeRegister g_tfliteDequantizeParser(tflite::BuiltinOperator_DEQUANTIZE, new TfliteDequantizeParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -28,10 +28,6 @@ class TfliteDequantizeParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteDequantizeParser() : TfliteNodeParser("Dequantize") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -20,42 +20,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteExpandDimsParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteExpandDimsParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::ExpandDimsT> attr = std::make_unique<schema::ExpandDimsT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::vector<int> dims; | |||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, dims)) { | |||
| MS_LOG(ERROR) << "get expand_dims -> dim failed"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->dim = dims[0]; | |||
| op->primitive->value.type = schema::PrimitiveType_ExpandDims; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteExpandDimsParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||
| @@ -80,6 +44,6 @@ PrimitiveC *TfliteExpandDimsParser::ParseLitePrimitive(const std::unique_ptr<tfl | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteExpandDimsParser("ExpandDims", new TfliteExpandDimsParser()); | |||
| TfliteNodeRegister g_tfliteExpandDimsParser(tflite::BuiltinOperator_EXPAND_DIMS, new TfliteExpandDimsParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TfliteExpandDimsParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteExpandDimsParser() : TfliteNodeParser("ExpandDims") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -20,43 +20,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteFillParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteFillParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::FillT> attr = std::make_unique<schema::FillT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (tflite_op->inputs.size() > 1) { | |||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->dims)) { | |||
| MS_LOG(ERROR) << "get fill -> dims failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Fill; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteFillParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||
| @@ -84,6 +47,6 @@ PrimitiveC *TfliteFillParser::ParseLitePrimitive(const std::unique_ptr<tflite::O | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteFillParser("Fill", new TfliteFillParser()); | |||
| TfliteNodeRegister g_tfliteFillParser(tflite::BuiltinOperator_FILL, new TfliteFillParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,10 +29,6 @@ class TfliteFillParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteFillParser() : TfliteNodeParser("Fill") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -20,55 +20,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteFullyConnectedParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||
| schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::FullConnectionT> attr = std::make_unique<schema::FullConnectionT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsFullyConnectedOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| bool hasBias = tflite_op->inputs.size() > 2 && tflite_op->inputs[2] != -1; | |||
| attr->hasBias = hasBias; | |||
| attr->axis = 1; | |||
| attr->useAxis = false; | |||
| attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); | |||
| op->primitive->value.type = schema::PrimitiveType_FullConnection; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_KHWC); | |||
| if (hasBias) { | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteFullyConnectedParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| @@ -101,7 +52,8 @@ PrimitiveC *TfliteFullyConnectedParser::ParseLitePrimitive(const std::unique_ptr | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteFullyConnectedParser("FullyConnected", new TfliteFullyConnectedParser()); | |||
| TfliteNodeRegister g_tfliteFakeQuantParser("FakeQuant", new TfliteFullyConnectedParser()); | |||
| TfliteNodeRegister g_tfliteFullyConnectedParser(tflite::BuiltinOperator_FULLY_CONNECTED, | |||
| new TfliteFullyConnectedParser()); | |||
| TfliteNodeRegister g_tfliteFakeQuantParser(tflite::BuiltinOperator_FAKE_QUANT, new TfliteFullyConnectedParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TfliteFullyConnectedParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteFullyConnectedParser() : TfliteNodeParser("FullyConnected") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -20,40 +20,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteGatherNdParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteGatherNdParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::GatherNdT> attr = std::make_unique<schema::GatherNdT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->batchDims = 0; | |||
| op->primitive->value.type = schema::PrimitiveType_GatherNd; | |||
| op->primitive->value.value = attr.release(); | |||
| for (int input : tflite_op->inputs) { | |||
| AddOpInput(op, tensors_info, input, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteGatherNdParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| @@ -75,6 +41,6 @@ PrimitiveC *TfliteGatherNdParser::ParseLitePrimitive(const std::unique_ptr<tflit | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteGatherNdParser("GatherND", new TfliteGatherNdParser()); | |||
| TfliteNodeRegister g_tfliteGatherNdParser(tflite::BuiltinOperator_GATHER_ND, new TfliteGatherNdParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TfliteGatherNdParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteGatherNdParser() : TfliteNodeParser("GatherND") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -20,46 +20,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteGatherParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteGatherParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::GatherT> attr = std::make_unique<schema::GatherT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsGatherOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->axis = tflite_attr->axis; | |||
| attr->batchDims = 0; | |||
| op->primitive->value.type = schema::PrimitiveType_Gather; | |||
| op->primitive->value.value = attr.release(); | |||
| for (int input : tflite_op->inputs) { | |||
| AddOpInput(op, tensors_info, input, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteGatherParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| @@ -87,6 +47,6 @@ PrimitiveC *TfliteGatherParser::ParseLitePrimitive(const std::unique_ptr<tflite: | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteGatherParser("Gather", new TfliteGatherParser()); | |||
| TfliteNodeRegister g_tfliteGatherParser(tflite::BuiltinOperator_GATHER, new TfliteGatherParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TfliteGatherParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteGatherParser() : TfliteNodeParser("Gather") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -20,41 +20,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteHashtableLookupParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||
| schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteHashtableLookupParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::HashtableLookupT> attr = std::make_unique<schema::HashtableLookupT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_HashtableLookup; | |||
| op->primitive->value.value = attr.release(); | |||
| for (int input : tflite_op->inputs) { | |||
| AddOpInput(op, tensors_info, input, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| for (int output : tflite_op->outputs) { | |||
| AddOpOutput(op, tensors_info, output, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteHashtableLookupParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| @@ -74,6 +39,7 @@ PrimitiveC *TfliteHashtableLookupParser::ParseLitePrimitive(const std::unique_pt | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteHashtableLookupParser("HashtableLookup", new TfliteHashtableLookupParser()); | |||
| TfliteNodeRegister g_tfliteHashtableLookupParser(tflite::BuiltinOperator_HASHTABLE_LOOKUP, | |||
| new TfliteHashtableLookupParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TfliteHashtableLookupParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteHashtableLookupParser() : TfliteNodeParser("HashtableLookup") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -21,40 +21,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteL2NormParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteL2NormParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::L2NormT> attr = std::make_unique<schema::L2NormT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsL2NormOptions(); | |||
| attr->axis = {-1}; | |||
| attr->epsilon = 1e-6f; | |||
| attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); | |||
| op->primitive->value.type = schema::PrimitiveType_L2Norm; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteL2NormParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| std::unique_ptr<schema::L2NormT> attr = std::make_unique<schema::L2NormT>(); | |||
| @@ -77,6 +43,6 @@ PrimitiveC *TfliteL2NormParser::ParseLitePrimitive(const std::unique_ptr<tflite: | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteL2NormParser("L2_NORMALIZATION", new TfliteL2NormParser()); | |||
| TfliteNodeRegister g_tfliteL2NormParser(tflite::BuiltinOperator_L2_NORMALIZATION, new TfliteL2NormParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TfliteL2NormParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteL2NormParser() : TfliteNodeParser("L2_NORMALIZATION") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -21,63 +21,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteLogicalParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::vector<std::string> node_name_str; | |||
| Split(op->name, &node_name_str, "-"); | |||
| const char *node_name = node_name_str.data()->c_str(); | |||
| if (std::strcmp(node_name, "LogicalAnd") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteLogicalAndParser"; | |||
| std::unique_ptr<schema::LogicalAndT> attr = std::make_unique<schema::LogicalAndT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_LogicalAnd; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "LogicalNot") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteLogicalNotParser"; | |||
| std::unique_ptr<schema::LogicalNotT> attr = std::make_unique<schema::LogicalNotT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_LogicalNot; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "LogicalOr") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteLogicalOrParser"; | |||
| std::unique_ptr<schema::LogicalOrT> attr = std::make_unique<schema::LogicalOrT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_LogicalOr; | |||
| op->primitive->value.value = attr.release(); | |||
| } else { | |||
| MS_LOG(ERROR) << node_name << " hasn't been supported"; | |||
| return RET_NOT_FIND_OP; | |||
| } | |||
| for (int input : tflite_op->inputs) { | |||
| AddOpInput(op, tensors_info, input, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteLogicalParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| @@ -117,8 +60,8 @@ PrimitiveC *TfliteLogicalParser::ParseLitePrimitive(const std::unique_ptr<tflite | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteLogicalAndParser("LogicalAnd", new TfliteLogicalParser()); | |||
| TfliteNodeRegister g_tfliteLogicalNotParser("LogicalNot", new TfliteLogicalParser()); | |||
| TfliteNodeRegister g_tfliteLogicalOrParser("LogicalOr", new TfliteLogicalParser()); | |||
| TfliteNodeRegister g_tfliteLogicalAndParser(tflite::BuiltinOperator_LOGICAL_AND, new TfliteLogicalParser()); | |||
| TfliteNodeRegister g_tfliteLogicalNotParser(tflite::BuiltinOperator_LOGICAL_NOT, new TfliteLogicalParser()); | |||
| TfliteNodeRegister g_tfliteLogicalOrParser(tflite::BuiltinOperator_LOGICAL_OR, new TfliteLogicalParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -28,9 +28,6 @@ class TfliteLogicalParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteLogicalParser() : TfliteNodeParser("node_name") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -20,46 +20,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteLRNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteLRNParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::LocalResponseNormalizationT> attr = std::make_unique<schema::LocalResponseNormalizationT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsLocalResponseNormalizationOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->depth_radius = tflite_attr->radius; | |||
| attr->alpha = tflite_attr->alpha; | |||
| attr->beta = tflite_attr->beta; | |||
| attr->bias = tflite_attr->bias; | |||
| op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteLRNParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| @@ -89,6 +49,6 @@ PrimitiveC *TfliteLRNParser::ParseLitePrimitive(const std::unique_ptr<tflite::Op | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteLRNParser("LocalResponseNorm", new TfliteLRNParser()); | |||
| TfliteNodeRegister g_tfliteLRNParser(tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, new TfliteLRNParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TfliteLRNParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteLRNParser() : TfliteNodeParser("LocalResponseNorm") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -20,50 +20,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteLshProjectionParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteLshProjectionParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::LshProjectionT> attr = std::make_unique<schema::LshProjectionT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsLSHProjectionOptions(); | |||
| switch (tflite_attr->type) { | |||
| case tflite::LSHProjectionType_SPARSE: | |||
| attr->type = schema::LshProjectionType_SPARSE; | |||
| break; | |||
| case tflite::LSHProjectionType_DENSE: | |||
| attr->type = schema::LshProjectionType_DENSE; | |||
| break; | |||
| default: | |||
| attr->type = schema::LshProjectionType_UNKNOWN; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_LshProjection; | |||
| op->primitive->value.value = attr.release(); | |||
| for (int input : tflite_op->inputs) { | |||
| AddOpInput(op, tensors_info, input, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteLshProjectionParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| @@ -94,6 +50,6 @@ PrimitiveC *TfliteLshProjectionParser::ParseLitePrimitive(const std::unique_ptr< | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteLshProjectionParser("LshProjection", new TfliteLshProjectionParser()); | |||
| TfliteNodeRegister g_tfliteLshProjectionParser(tflite::BuiltinOperator_LSH_PROJECTION, new TfliteLshProjectionParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TfliteLshProjectionParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteLshProjectionParser() : TfliteNodeParser("LshProjection") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -19,41 +19,28 @@ | |||
| #include <memory> | |||
| #include <map> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteMatMulParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteMatMulParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| namespace mindspore::lite { | |||
| PrimitiveC *TfliteMatMulParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "primitive is null"; | |||
| return nullptr; | |||
| } | |||
| std::unique_ptr<schema::MatMulT> attr = std::make_unique<schema::MatMulT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsBatchMatMulOptions(); | |||
| attr->transposeA = tflite_attr->adj_x; | |||
| attr->transposeB = tflite_attr->adj_y; | |||
| attr->broadcast = false; | |||
| op->primitive->value.type = schema::PrimitiveType_MatMul; | |||
| op->primitive->value.value = attr.release(); | |||
| primitive->value.type = schema::PrimitiveType_MatMul; | |||
| primitive->value.value = attr.release(); | |||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteMatMulParser("MatMul", new TfliteMatMulParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| } // namespace mindspore::lite | |||
| @@ -23,17 +23,14 @@ | |||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| namespace mindspore::lite { | |||
| class TfliteMatMulParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteMatMulParser() : TfliteNodeParser("MatMul") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| } // namespace mindspore::lite | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SLICE_PARSER_H | |||
| @@ -86,12 +86,12 @@ STATUS TfliteModelParser::ConvertOps() { | |||
| STATUS status = RET_OK; | |||
| int op_idx = 0; | |||
| for (auto &op : tflite_subgraph->operators) { | |||
| auto tfliteOpType = (tflite_model_->operator_codes[op->opcode_index])->builtin_code; | |||
| auto op_type = GetMSOpType(tfliteOpType); | |||
| auto tflite_op_type = (tflite_model_->operator_codes[op->opcode_index])->builtin_code; | |||
| auto op_type = GetMSOpType(tflite_op_type); | |||
| auto op_name = op_type + "-" + std::to_string(op_idx); | |||
| op_idx++; | |||
| // parse primitive | |||
| auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(op_type); | |||
| auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(tflite_op_type); | |||
| if (node_parser == nullptr) { | |||
| NoSupportOp::GetInstance()->InsertOp(op_type); | |||
| status = (status == RET_OK ? RET_NOT_FIND_OP : status); | |||
| @@ -38,52 +38,11 @@ class TfliteNodeParser { | |||
| virtual ~TfliteNodeParser() = default; | |||
| virtual STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) = 0; | |||
| virtual lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| return nullptr; | |||
| } | |||
| static void AddOpInput(schema::CNodeT *op, TfliteTensorsInfo *tensors_info, int idx, int total, | |||
| schema::Format format) { | |||
| MS_ASSERT(op != nullptr); | |||
| MS_ASSERT(tensors_info != nullptr); | |||
| int new_idx = tensors_info->tensorsId.size(); | |||
| auto iter = tensors_info->tensorsIdMap.find(idx); | |||
| if (iter != tensors_info->tensorsIdMap.end()) { | |||
| op->inputIndex.emplace_back(iter->second); | |||
| } else { | |||
| if (idx < 0) { | |||
| idx += total; | |||
| } | |||
| tensors_info->tensorsId.emplace_back(idx); | |||
| tensors_info->tensorsFormat.emplace_back(format); | |||
| tensors_info->tensorsIdMap.insert(std::make_pair(idx, new_idx)); | |||
| op->inputIndex.emplace_back(new_idx); | |||
| } | |||
| } | |||
| static void AddOpOutput(schema::CNodeT *op, TfliteTensorsInfo *tensors_info, int idx, int total, | |||
| schema::Format format) { | |||
| MS_ASSERT(op != nullptr); | |||
| MS_ASSERT(tensors_info != nullptr); | |||
| int new_idx = tensors_info->tensorsId.size(); | |||
| auto iter = tensors_info->tensorsIdMap.find(idx); | |||
| if (iter != tensors_info->tensorsIdMap.end()) { | |||
| op->outputIndex.emplace_back(iter->second); | |||
| } else { | |||
| if (idx < 0) { | |||
| idx += total; | |||
| } | |||
| tensors_info->tensorsId.emplace_back(idx); | |||
| tensors_info->tensorsFormat.emplace_back(format); | |||
| tensors_info->tensorsIdMap.insert(std::make_pair(idx, new_idx)); | |||
| op->outputIndex.emplace_back(new_idx); | |||
| } | |||
| } | |||
| template <typename T> | |||
| STATUS GetTfliteData(const int32_t tensor_index, const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||
| @@ -34,8 +34,8 @@ TfliteNodeParserRegistry *TfliteNodeParserRegistry::GetInstance() { | |||
| return &instance; | |||
| } | |||
| TfliteNodeParser *TfliteNodeParserRegistry::GetNodeParser(const std::string &name) { | |||
| auto it = parsers.find(name); | |||
| TfliteNodeParser *TfliteNodeParserRegistry::GetNodeParser(const tflite::BuiltinOperator &type) { | |||
| auto it = parsers.find(type); | |||
| if (it != parsers.end()) { | |||
| return it->second; | |||
| } | |||
| @@ -31,15 +31,15 @@ class TfliteNodeParserRegistry { | |||
| static TfliteNodeParserRegistry *GetInstance(); | |||
| TfliteNodeParser *GetNodeParser(const std::string &name); | |||
| TfliteNodeParser *GetNodeParser(const tflite::BuiltinOperator &type); | |||
| std::unordered_map<std::string, TfliteNodeParser *> parsers; | |||
| std::unordered_map<tflite::BuiltinOperator, TfliteNodeParser *> parsers; | |||
| }; | |||
| class TfliteNodeRegister { | |||
| public: | |||
| TfliteNodeRegister(const std::string &name, TfliteNodeParser *parser) { | |||
| TfliteNodeParserRegistry::GetInstance()->parsers[name] = parser; | |||
| TfliteNodeRegister(const tflite::BuiltinOperator &type, TfliteNodeParser *parser) { | |||
| TfliteNodeParserRegistry::GetInstance()->parsers[type] = parser; | |||
| } | |||
| }; | |||
| } // namespace lite | |||
| @@ -20,50 +20,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteOneHotParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteOneHotParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::OneHotT> attr = std::make_unique<schema::OneHotT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsOneHotOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->axis = tflite_attr->axis; | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_OneHot; | |||
| op->primitive->value.value = attr.release(); | |||
| for (int input : tflite_op->inputs) { | |||
| AddOpInput(op, tensors_info, input, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteOneHotParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||
| @@ -97,6 +53,6 @@ PrimitiveC *TfliteOneHotParser::ParseLitePrimitive(const std::unique_ptr<tflite: | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteOneHotParser("OneHot", new TfliteOneHotParser()); | |||
| TfliteNodeRegister g_tfliteOneHotParser(tflite::BuiltinOperator_ONE_HOT, new TfliteOneHotParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TfliteOneHotParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteOneHotParser() : TfliteNodeParser("OneHot") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -21,75 +21,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TflitePadParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TflitePadParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::PadT> attr = std::make_unique<schema::PadT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::vector<std::string> node_name_str; | |||
| Split(op->name, &node_name_str, "-"); | |||
| const char *node_name = node_name_str.data()->c_str(); | |||
| if (std::strcmp(node_name, "Pad") == 0) { | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsPadOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->paddingMode = schema::PaddingMode_CONSTANT; | |||
| attr->constantValue = 0.0f; | |||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->paddings)) { | |||
| MS_LOG(ERROR) << "get pad -> paddings failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } else if (std::strcmp(node_name, "MirrorPad") == 0) { | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsMirrorPadOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| switch (tflite_attr->mode) { | |||
| case tflite::MirrorPadMode_REFLECT: | |||
| attr->paddingMode = schema::PaddingMode_REFLECT; | |||
| break; | |||
| case tflite::MirrorPadMode_SYMMETRIC: | |||
| attr->paddingMode = schema::PaddingMode_SYMMETRIC; | |||
| break; | |||
| default: | |||
| MS_LOG(ERROR) << "paddingmode:" << tflite_attr->mode << " don't support"; | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << node_name << " hasn't been supported"; | |||
| return RET_NOT_FIND_OP; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Pad; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| if (std::strcmp(node_name, "MirrorPad") == 0) { | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TflitePadParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||
| @@ -144,7 +75,7 @@ PrimitiveC *TflitePadParser::ParseLitePrimitive(const std::unique_ptr<tflite::Op | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tflitePadParser("Pad", new TflitePadParser()); | |||
| TfliteNodeRegister g_tfliteMirorPadParser("MirrorPad", new TflitePadParser()); | |||
| TfliteNodeRegister g_tflitePadParser(tflite::BuiltinOperator_PAD, new TflitePadParser()); | |||
| TfliteNodeRegister g_tfliteMirorPadParser(tflite::BuiltinOperator_MIRROR_PAD, new TflitePadParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TflitePadParser : public TfliteNodeParser { | |||
| public: | |||
| TflitePadParser() : TfliteNodeParser("Pad") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -20,76 +20,6 @@ | |||
| #include <string> | |||
| namespace mindspore::lite { | |||
| STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::PoolingT> attr = std::make_unique<schema::PoolingT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; | |||
| if (tflite_op_type == tflite::BuiltinOperator_AVERAGE_POOL_2D) { | |||
| attr->poolingMode = schema::PoolMode_MEAN_POOLING; | |||
| } else if (tflite_op_type == tflite::BuiltinOperator_MAX_POOL_2D) { | |||
| attr->poolingMode = schema::PoolMode_MAX_POOLING; | |||
| } else { | |||
| MS_LOG(ERROR) << "pooling mode " << tflite_op_type << " hasn't been supported"; | |||
| return RET_NOT_FIND_OP; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsPool2DOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->windowW = tflite_attr->filter_width; | |||
| attr->windowH = tflite_attr->filter_height; | |||
| attr->strideW = tflite_attr->stride_w; | |||
| attr->strideH = tflite_attr->stride_h; | |||
| attr->padMode = GetPadMode(tflite_attr->padding); | |||
| attr->format = schema::Format::Format_NHWC; | |||
| attr->global = false; | |||
| attr->roundMode = schema::RoundMode_FLOOR; | |||
| attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); | |||
| // calculate pad params | |||
| auto data_index = tflite_op->inputs[0]; | |||
| const auto &data_tensor = tflite_subgraph->tensors[data_index]; | |||
| std::vector<int64_t> params; | |||
| int status = | |||
| getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW, ¶ms); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "get padding params failed"; | |||
| return RET_ERROR; | |||
| } else if (status == RET_OK) { | |||
| attr->padUp = params.at(0); | |||
| attr->padDown = params.at(1); | |||
| attr->padLeft = params.at(2); | |||
| attr->padRight = params.at(3); | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Pooling; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| lite::PrimitiveC *TflitePoolingParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| const auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||
| @@ -142,6 +72,6 @@ lite::PrimitiveC *TflitePoolingParser::ParseLitePrimitive(const std::unique_ptr< | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteMeanPoolingParser("MeanPooling", new TflitePoolingParser()); | |||
| TfliteNodeRegister g_tfliteMaxPoolingParser("MaxPooling", new TflitePoolingParser()); | |||
| TfliteNodeRegister g_tfliteMeanPoolingParser(tflite::BuiltinOperator_AVERAGE_POOL_2D, new TflitePoolingParser()); | |||
| TfliteNodeRegister g_tfliteMaxPoolingParser(tflite::BuiltinOperator_MAX_POOL_2D, new TflitePoolingParser()); | |||
| } // namespace mindspore::lite | |||
| @@ -28,9 +28,6 @@ class TflitePoolingParser : public TfliteNodeParser { | |||
| public: | |||
| TflitePoolingParser() : TfliteNodeParser("node_name") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -21,37 +21,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TflitePReLUParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TflitePReLUParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::PReLUT> attr = std::make_unique<schema::PReLUT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->channelShared = true; | |||
| op->primitive->value.type = schema::PrimitiveType_PReLU; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TflitePReLUParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| @@ -71,6 +40,6 @@ PrimitiveC *TflitePReLUParser::ParseLitePrimitive(const std::unique_ptr<tflite:: | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tflitePReLUParser("PRELU", new TflitePReLUParser()); | |||
| TfliteNodeRegister g_tflitePReLUParser(tflite::BuiltinOperator_PRELU, new TflitePReLUParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TflitePReLUParser : public TfliteNodeParser { | |||
| public: | |||
| TflitePReLUParser() : TfliteNodeParser("PRELU") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -19,61 +19,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteQuantizeNParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; | |||
| if (in_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "input tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]]; | |||
| if (out_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "output tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (GetTfliteDataType(in_tensor->type) != GetTfliteDataType(out_tensor->type) && | |||
| (GetTfliteDataType(out_tensor->type) == kNumberTypeInt8 || | |||
| GetTfliteDataType(out_tensor->type) == kNumberTypeUInt8)) { | |||
| std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->srcT = GetTfliteDataType(in_tensor->type); | |||
| attr->dstT = GetTfliteDataType(out_tensor->type); | |||
| op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast; | |||
| op->primitive->value.value = attr.release(); | |||
| } else { | |||
| std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->srcT = GetTfliteDataType(in_tensor->type); | |||
| attr->dstT = GetTfliteDataType(out_tensor->type); | |||
| op->primitive->value.type = schema::PrimitiveType_Cast; | |||
| op->primitive->value.value = attr.release(); | |||
| } | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteQuantizeParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||
| @@ -119,6 +64,6 @@ PrimitiveC *TfliteQuantizeParser::ParseLitePrimitive(const std::unique_ptr<tflit | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteQuantizeParser("QUANTIZE", new TfliteQuantizeParser()); | |||
| TfliteNodeRegister g_tfliteQuantizeParser(tflite::BuiltinOperator_QUANTIZE, new TfliteQuantizeParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -28,9 +28,6 @@ class TfliteQuantizeParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteQuantizeParser() : TfliteNodeParser("Quantize") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -20,57 +20,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteRangeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteRangeParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::RangeT> attr = std::make_unique<schema::RangeT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->dType = 0; | |||
| std::vector<int> limit; | |||
| std::vector<int> delta; | |||
| int status = GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, limit); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "range -> limit get failed"; | |||
| return RET_ERROR; | |||
| } else if (status == RET_OK) { | |||
| status = GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, delta); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "stridedSlice -> end get failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| if (status == RET_OK) { | |||
| attr->limit = limit.front(); | |||
| attr->delta = delta.front(); | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Range; | |||
| op->primitive->value.value = attr.release(); | |||
| int input_num = status == RET_OK ? 1 : 3; | |||
| for (int i = 0; i < input_num; ++i) { | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteRangeParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||
| @@ -109,6 +58,6 @@ PrimitiveC *TfliteRangeParser::ParseLitePrimitive(const std::unique_ptr<tflite:: | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteRangeParser("Range", new TfliteRangeParser()); | |||
| TfliteNodeRegister g_tfliteRangeParser(tflite::BuiltinOperator_RANGE, new TfliteRangeParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TfliteRangeParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteRangeParser() : TfliteNodeParser("Range") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -20,36 +20,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteRankParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteRankParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::RankT> attr = std::make_unique<schema::RankT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Rank; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteRankParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| @@ -69,6 +39,6 @@ PrimitiveC *TfliteRankParser::ParseLitePrimitive(const std::unique_ptr<tflite::O | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteRankParser("Rank", new TfliteRankParser()); | |||
| TfliteNodeRegister g_tfliteRankParser(tflite::BuiltinOperator_RANK, new TfliteRankParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TfliteRankParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteRankParser() : TfliteNodeParser("Rank") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -21,70 +21,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteReduceParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::ReduceT> attr = std::make_unique<schema::ReduceT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsReducerOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->keepDims = tflite_attr->keep_dims; | |||
| std::vector<std::string> node_name_str; | |||
| Split(op->name, &node_name_str, "-"); | |||
| const char *node_name = node_name_str.data()->c_str(); | |||
| if (std::strcmp(node_name, "ReduceMax") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteReduceMaxParser"; | |||
| attr->mode = schema::ReduceMode_ReduceMax; | |||
| } else if (std::strcmp(node_name, "ReduceMin") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteReduceMinParser"; | |||
| attr->mode = schema::ReduceMode_ReduceMin; | |||
| } else if (std::strcmp(node_name, "ReduceProd") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteReduceProdParser"; | |||
| attr->mode = schema::ReduceMode_ReduceProd; | |||
| } else if (std::strcmp(node_name, "Sum") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteSumParser"; | |||
| attr->mode = schema::ReduceMode_ReduceSum; | |||
| } else if (std::strcmp(node_name, "Mean") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteMeanParser"; | |||
| attr->mode = schema::ReduceMode_ReduceMean; | |||
| } else { | |||
| MS_LOG(ERROR) << node_name << " hasn't been supported"; | |||
| return RET_NOT_FIND_OP; | |||
| } | |||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->axes)) { | |||
| MS_LOG(ERROR) << "get reduce -> axes failed"; | |||
| return RET_ERROR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Reduce; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteReduceParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||
| @@ -139,11 +75,11 @@ PrimitiveC *TfliteReduceParser::ParseLitePrimitive(const std::unique_ptr<tflite: | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_TfliteSumParser("Sum", new TfliteReduceParser()); | |||
| TfliteNodeRegister g_TfliteMeanParser("Mean", new TfliteReduceParser()); | |||
| TfliteNodeRegister g_TfliteReduceMaxParser("ReduceMax", new TfliteReduceParser()); | |||
| TfliteNodeRegister g_TfliteReduceMinParser("ReduceMin", new TfliteReduceParser()); | |||
| TfliteNodeRegister g_TfliteReduceProdParser("ReduceProd", new TfliteReduceParser()); | |||
| TfliteNodeRegister g_TfliteReduceAnyParser("ReduceAny", new TfliteReduceParser()); | |||
| TfliteNodeRegister g_TfliteSumParser(tflite::BuiltinOperator_SUM, new TfliteReduceParser()); | |||
| TfliteNodeRegister g_TfliteMeanParser(tflite::BuiltinOperator_MEAN, new TfliteReduceParser()); | |||
| TfliteNodeRegister g_TfliteReduceMaxParser(tflite::BuiltinOperator_REDUCE_MAX, new TfliteReduceParser()); | |||
| TfliteNodeRegister g_TfliteReduceMinParser(tflite::BuiltinOperator_REDUCE_MIN, new TfliteReduceParser()); | |||
| TfliteNodeRegister g_TfliteReduceProdParser(tflite::BuiltinOperator_REDUCE_PROD, new TfliteReduceParser()); | |||
| TfliteNodeRegister g_TfliteReduceAnyParser(tflite::BuiltinOperator_REDUCE_ANY, new TfliteReduceParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -28,9 +28,6 @@ class TfliteReduceParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteReduceParser() : TfliteNodeParser("node_name") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -19,69 +19,6 @@ | |||
| #include <memory> | |||
| namespace mindspore::lite { | |||
| STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| MS_LOG(DEBUG) << "parse TfliteReshapeParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::ReshapeT> attr = std::make_unique<schema::ReshapeT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsReshapeOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| if (tflite_op->inputs.size() < 2) { | |||
| MS_LOG(ERROR) << "expected two input tensors, but got: " << tflite_op->inputs.size(); | |||
| return RET_ERROR; | |||
| } | |||
| auto shape_tensor_index = tflite_op->inputs[1]; | |||
| const auto &shape_tensor = tflite_subgraph->tensors[shape_tensor_index]; | |||
| if (shape_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "shape_tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto &buf_data = tflite_model->buffers[shape_tensor->buffer]; | |||
| if (buf_data == nullptr) { | |||
| MS_LOG(ERROR) << "buf_data is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (!buf_data->data.empty()) { | |||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->shape)) { | |||
| MS_LOG(ERROR) << "get reshape -> shape failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } else { | |||
| attr->format = schema::Format::Format_NHWC; | |||
| attr->shape.resize(tflite_attr->new_shape.size()); | |||
| for (size_t i = 0; i < tflite_attr->new_shape.size(); ++i) { | |||
| attr->shape[i] = tflite_attr->new_shape[i]; | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Reshape; | |||
| op->primitive->value.value = attr.release(); | |||
| for (int input : tflite_op->inputs) { | |||
| AddOpInput(op, tensors_info, input, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| lite::PrimitiveC *TfliteReshapeParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| const auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||
| @@ -127,5 +64,5 @@ lite::PrimitiveC *TfliteReshapeParser::ParseLitePrimitive(const std::unique_ptr< | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteReshapeParser("Reshape", new TfliteReshapeParser()); | |||
| TfliteNodeRegister g_tfliteReshapeParser(tflite::BuiltinOperator_RESHAPE, new TfliteReshapeParser()); | |||
| } // namespace mindspore::lite | |||
| @@ -28,10 +28,6 @@ class TfliteReshapeParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteReshapeParser() : TfliteNodeParser("Reshape") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -22,104 +22,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::ResizeT> attr = std::make_unique<schema::ResizeT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->coordinateTransformMode = schema::CoordinateTransformMode_COMMON; | |||
| std::vector<std::string> node_name_str; | |||
| Split(op->name, &node_name_str, "-"); | |||
| const char *node_name = node_name_str.data()->c_str(); | |||
| if (std::strcmp(node_name, "ResizeBilinear") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteResizeBilinearParser"; | |||
| const auto &tfliteAttr = tflite_op->builtin_options.AsResizeBilinearOptions(); | |||
| if (tfliteAttr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (tfliteAttr->align_corners) { | |||
| attr->alignCorners = tfliteAttr->align_corners; | |||
| attr->coordinateTransformMode = schema::CoordinateTransformMode_ALIGN_CORNERS; | |||
| } | |||
| if (tfliteAttr->half_pixel_centers) { | |||
| attr->coordinateTransformMode = (attr->coordinateTransformMode == schema::CoordinateTransformMode_COMMON | |||
| ? schema::CoordinateTransformMode_TF_HALF_PIXEL | |||
| : schema::CoordinateTransformMode_ALIGN_CORNERS_WITH_HALF_PIEXL); | |||
| } | |||
| attr->method = schema::ResizeMethod_LINEAR; | |||
| } else if (std::strcmp(node_name, "NearestNeighbor") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteResizeNearestNeighborParser"; | |||
| const auto &tfliteAttr = tflite_op->builtin_options.AsResizeNearestNeighborOptions(); | |||
| if (tfliteAttr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (tfliteAttr->align_corners) { | |||
| attr->alignCorners = tfliteAttr->align_corners; | |||
| attr->coordinateTransformMode = schema::CoordinateTransformMode_ALIGN_CORNERS; | |||
| } | |||
| if (tfliteAttr->half_pixel_centers) { | |||
| attr->coordinateTransformMode = (attr->coordinateTransformMode == schema::CoordinateTransformMode_COMMON | |||
| ? schema::CoordinateTransformMode_TF_HALF_PIXEL | |||
| : schema::CoordinateTransformMode_ALIGN_CORNERS_WITH_HALF_PIEXL); | |||
| } | |||
| attr->method = schema::ResizeMethod_NEAREST; | |||
| attr->nearestMode = schema::NearestMode_NORMAL; | |||
| } else { | |||
| MS_LOG(ERROR) << "wrong resize type"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->format = schema::Format::Format_NHWC; | |||
| attr->preserveAspectRatio = false; | |||
| auto tfliteResizeTensorIndex = tflite_op->inputs[1]; | |||
| const auto &shape_tensor = tflite_subgraph->tensors[tfliteResizeTensorIndex]; | |||
| if (shape_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "shape_tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto resizeTensorBufferIndex = shape_tensor->buffer; | |||
| const auto &buff = tflite_model->buffers.at(resizeTensorBufferIndex); | |||
| if (buff == nullptr) { | |||
| MS_LOG(ERROR) << "buff_data is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto buffData = reinterpret_cast<int32_t *>(buff->data.data()); | |||
| if (buffData != nullptr) { | |||
| auto height = buffData[0]; | |||
| auto width = buffData[1]; | |||
| attr->newWidth = width; | |||
| attr->newHeight = height; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Resize; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| if (buffData == nullptr) { | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteResizeParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||
| @@ -204,7 +106,8 @@ PrimitiveC *TfliteResizeParser::ParseLitePrimitive(const std::unique_ptr<tflite: | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteResizeBilinearParser("ResizeBilinear", new TfliteResizeParser()); | |||
| TfliteNodeRegister g_tfliteResizeNearestNeighborParser("NearestNeighbor", new TfliteResizeParser()); | |||
| TfliteNodeRegister g_tfliteResizeBilinearParser(tflite::BuiltinOperator_RESIZE_BILINEAR, new TfliteResizeParser()); | |||
| TfliteNodeRegister g_tfliteResizeNearestNeighborParser(tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, | |||
| new TfliteResizeParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -28,9 +28,6 @@ class TfliteResizeParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteResizeParser() : TfliteNodeParser("node_name") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -20,41 +20,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteReverseParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteReverseParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::ReverseT> attr = std::make_unique<schema::ReverseT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->axis)) { | |||
| MS_LOG(ERROR) << "get reverse -> axis failed"; | |||
| return RET_ERROR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Reverse; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteReverseParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||
| @@ -80,6 +45,6 @@ PrimitiveC *TfliteReverseParser::ParseLitePrimitive(const std::unique_ptr<tflite | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteReverseParser("reverse", new TfliteReverseParser()); | |||
| TfliteNodeRegister g_tfliteReverseParser(tflite::BuiltinOperator_REVERSE_V2, new TfliteReverseParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TfliteReverseParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteReverseParser() : TfliteNodeParser("reverse") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -21,47 +21,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteReverseSequenceParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||
| schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteReverseSequenceParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::ReverseSequenceT> attr = std::make_unique<schema::ReverseSequenceT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsReverseSequenceOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->seqAxis = tflite_attr->seq_dim; | |||
| attr->batchAxis = tflite_attr->batch_dim; | |||
| op->primitive->value.type = schema::PrimitiveType_ReverseSequence; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteReverseSequenceParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| @@ -89,6 +48,7 @@ PrimitiveC *TfliteReverseSequenceParser::ParseLitePrimitive(const std::unique_pt | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteReverseSequenceParser("ReverseSequence", new TfliteReverseSequenceParser()); | |||
| TfliteNodeRegister g_tfliteReverseSequenceParser(tflite::BuiltinOperator_REVERSE_SEQUENCE, | |||
| new TfliteReverseSequenceParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TfliteReverseSequenceParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteReverseSequenceParser() : TfliteNodeParser("ReverseSequence") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -20,44 +20,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteScatterNdParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteScatterNdParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::ScatterNDT> attr = std::make_unique<schema::ScatterNDT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsScatterNdOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_ScatterND; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteScatterNdParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| @@ -82,6 +44,6 @@ PrimitiveC *TfliteScatterNdParser::ParseLitePrimitive(const std::unique_ptr<tfli | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteScatterNdParser("ScatterNd", new TfliteScatterNdParser()); | |||
| TfliteNodeRegister g_tfliteScatterNdParser(tflite::BuiltinOperator_SCATTER_ND, new TfliteScatterNdParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TfliteScatterNdParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteScatterNdParser() : TfliteNodeParser("ScatterNd") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -20,36 +20,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteShapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteShapeParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::ShapeT> attr = std::make_unique<schema::ShapeT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Shape; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteShapeParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| @@ -69,6 +39,6 @@ PrimitiveC *TfliteShapeParser::ParseLitePrimitive(const std::unique_ptr<tflite:: | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteShapeParser("Shape", new TfliteShapeParser()); | |||
| TfliteNodeRegister g_tfliteShapeParser(tflite::BuiltinOperator_SHAPE, new TfliteShapeParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TfliteShapeParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteShapeParser() : TfliteNodeParser("Shape") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -20,45 +20,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteSkipGramParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteSkipGramParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::SkipGramT> attr = std::make_unique<schema::SkipGramT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsSkipGramOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->includeAllGrams = tflite_attr->include_all_ngrams; | |||
| attr->maxSkipSize = tflite_attr->max_skip_size; | |||
| attr->ngramSize = tflite_attr->ngram_size; | |||
| op->primitive->value.type = schema::PrimitiveType_SkipGram; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteSkipGramParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| @@ -87,6 +48,6 @@ PrimitiveC *TfliteSkipGramParser::ParseLitePrimitive(const std::unique_ptr<tflit | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteSkiGramParser("SKipGram", new TfliteSkipGramParser()); | |||
| TfliteNodeRegister g_tfliteSkiGramParser(tflite::BuiltinOperator_SKIP_GRAM, new TfliteSkipGramParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TfliteSkipGramParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteSkipGramParser() : TfliteNodeParser("SkipGram") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -20,52 +20,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteSliceParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteSliceParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::SliceT> attr = std::make_unique<schema::SliceT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->format = schema::Format::Format_NHWC; | |||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->begin)) { | |||
| MS_LOG(ERROR) << "get slice -> begin failed"; | |||
| return RET_ERROR; | |||
| } | |||
| if (GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->size)) { | |||
| MS_LOG(ERROR) << "get slice -> size failed"; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<int> axes; | |||
| axes.clear(); | |||
| for (size_t i = 0; i < attr->begin.size(); ++i) { | |||
| axes.push_back(i); | |||
| } | |||
| attr->axes = axes; | |||
| op->primitive->value.type = schema::PrimitiveType_Slice; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteSliceParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||
| @@ -102,6 +56,6 @@ PrimitiveC *TfliteSliceParser::ParseLitePrimitive(const std::unique_ptr<tflite:: | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteSliceParser("Slice", new TfliteSliceParser()); | |||
| TfliteNodeRegister g_tfliteSliceParser(tflite::BuiltinOperator_SLICE, new TfliteSliceParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TfliteSliceParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteSliceParser() : TfliteNodeParser("Slice") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -19,38 +19,6 @@ | |||
| #include <memory> | |||
| namespace mindspore::lite { | |||
| STATUS TfliteSoftmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteSoftmaxParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::SoftMaxT> attr = std::make_unique<schema::SoftMaxT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->axis = -1; | |||
| op->primitive->value.type = schema::PrimitiveType_SoftMax; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteSoftmaxParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| @@ -67,5 +35,5 @@ PrimitiveC *TfliteSoftmaxParser::ParseLitePrimitive(const std::unique_ptr<tflite | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteSoftmaxParser("Softmax", new TfliteSoftmaxParser()); | |||
| TfliteNodeRegister g_tfliteSoftmaxParser(tflite::BuiltinOperator_SOFTMAX, new TfliteSoftmaxParser()); | |||
| } // namespace mindspore::lite | |||
| @@ -28,10 +28,6 @@ class TfliteSoftmaxParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteSoftmaxParser() : TfliteNodeParser("Softmax") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -21,47 +21,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteSpaceToBatchNDParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||
| schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteSpaceToBatchNDParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::SpaceToBatchNDT> attr = std::make_unique<schema::SpaceToBatchNDT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->blockShape)) { | |||
| MS_LOG(ERROR) << "get spaceToBatchND -> blockShape failed"; | |||
| return RET_ERROR; | |||
| } | |||
| if (GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->paddings)) { | |||
| MS_LOG(ERROR) << "get spaceToBatchND -> paddings failed"; | |||
| return RET_ERROR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_SpaceToBatchND; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteSpaceToBatchNDParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||
| @@ -91,6 +50,7 @@ PrimitiveC *TfliteSpaceToBatchNDParser::ParseLitePrimitive(const std::unique_ptr | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteSpaceToBatchNDParser("SpaceToBatchND", new TfliteSpaceToBatchNDParser()); | |||
| TfliteNodeRegister g_tfliteSpaceToBatchNDParser(tflite::BuiltinOperator_SPACE_TO_BATCH_ND, | |||
| new TfliteSpaceToBatchNDParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TfliteSpaceToBatchNDParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteSpaceToBatchNDParser() : TfliteNodeParser("SpaceToBatchND") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -21,45 +21,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteSpaceToDepthParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteSpaceToDepthParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::SpaceToDepthT> attr = std::make_unique<schema::SpaceToDepthT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsSpaceToDepthOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| MS_LOG(ERROR) << "get op:" << op->name.c_str() << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->blockSize = tflite_attr->block_size; | |||
| attr->format = schema::Format::Format_NHWC; | |||
| op->primitive->value.type = schema::PrimitiveType_SpaceToDepth; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteSpaceToDepthParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| @@ -87,6 +48,6 @@ PrimitiveC *TfliteSpaceToDepthParser::ParseLitePrimitive(const std::unique_ptr<t | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteSpaceToDepthParser("SpaceToDepth", new TfliteSpaceToDepthParser()); | |||
| TfliteNodeRegister g_tfliteSpaceToDepthParser(tflite::BuiltinOperator_SPACE_TO_DEPTH, new TfliteSpaceToDepthParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TfliteSpaceToDepthParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteSpaceToDepthParser() : TfliteNodeParser("SpaceToDepth") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -21,42 +21,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteSparseToDenseParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteSparseToDenseParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::SparseToDenseT> attr = std::make_unique<schema::SparseToDenseT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->validateIndices = false; | |||
| op->primitive->value.type = schema::PrimitiveType_SparseToDense; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[3], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteSparseToDenseParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| @@ -77,6 +41,7 @@ PrimitiveC *TfliteSparseToDenseParser::ParseLitePrimitive(const std::unique_ptr< | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteSparseToDenseParser("SparseToDense", new TfliteSparseToDenseParser()); | |||
| TfliteNodeRegister g_tfliteSparseToDenseParser(tflite::BuiltinOperator_SPARSE_TO_DENSE, | |||
| new TfliteSparseToDenseParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TfliteSparseToDenseParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteSparseToDenseParser() : TfliteNodeParser("SparseToDense") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -21,81 +21,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteSplitParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteSplitParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::SplitT> attr = std::make_unique<schema::SplitT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsSplitOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto num_splits = tflite_attr->num_splits; | |||
| const auto &shape_tensor = tflite_subgraph->tensors[tflite_op->inputs[1]]; | |||
| if (shape_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "shape_tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto tensor_shape = shape_tensor->shape; | |||
| const auto &axis_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; | |||
| if (axis_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "axis_tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto axis = *(reinterpret_cast<int32_t *>(tflite_model->buffers[axis_tensor->buffer]->data.data())); | |||
| if (axis < 0) { | |||
| axis += tensor_shape.size(); | |||
| } | |||
| if (axis >= static_cast<int>(tensor_shape.size())) { | |||
| MS_LOG(ERROR) << "axis value is too large"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->splitDim = axis; | |||
| if (num_splits == 0) { | |||
| MS_LOG(ERROR) << "Divide-by-zero error!"; | |||
| return RET_ERROR; | |||
| } | |||
| if (tensor_shape[axis] % num_splits != 0 && tensor_shape[axis] / num_splits != 0) { | |||
| MS_LOG(ERROR) << "num_splits can't divide tensor's length at axis " << axis; | |||
| return RET_ERROR; | |||
| } | |||
| attr->numberSplit = num_splits; | |||
| if (tensor_shape[axis] / num_splits != 0) { | |||
| for (int i = 0; i < num_splits; i++) { | |||
| attr->sizeSplits.push_back(tensor_shape[axis] / num_splits); | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Split; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| for (int output : tflite_op->outputs) { | |||
| AddOpOutput(op, tensors_info, output, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteSplitParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| @@ -154,6 +79,6 @@ PrimitiveC *TfliteSplitParser::ParseLitePrimitive(const std::unique_ptr<tflite:: | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteSplitParser("Split", new TfliteSplitParser()); | |||
| TfliteNodeRegister g_tfliteSplitParser(tflite::BuiltinOperator_SPLIT, new TfliteSplitParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TfliteSplitParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteSplitParser() : TfliteNodeParser("Split") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||
| @@ -21,76 +21,6 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteSplitVParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "parse TfliteSplitVParser"; | |||
| MS_ASSERT(tflite_op != nullptr); | |||
| MS_ASSERT(tflite_model != nullptr); | |||
| MS_ASSERT(tflite_subgraph != nullptr); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::SplitT> attr = std::make_unique<schema::SplitT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsSplitVOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->numberSplit = tflite_attr->num_splits; | |||
| if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->sizeSplits)) { | |||
| MS_LOG(ERROR) << "get spliteV -> sizeSplits failed"; | |||
| return RET_ERROR; | |||
| } | |||
| const auto &tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; | |||
| if (tensor == nullptr) { | |||
| MS_LOG(ERROR) << "tensor_shape is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto tensor_shape = tensor->shape; | |||
| const auto &axis_tensor = tflite_subgraph->tensors[tflite_op->inputs[2]]; | |||
| if (axis_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "axis_tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &axis_buf = tflite_model->buffers[axis_tensor->buffer]; | |||
| if (axis_buf == nullptr) { | |||
| MS_LOG(ERROR) << "axis_buf is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto axis = *(reinterpret_cast<int32_t *>(axis_buf->data.data())); | |||
| if (axis < 0) { | |||
| axis += tensor_shape.size(); | |||
| } | |||
| if (axis >= static_cast<int>(tensor_shape.size())) { | |||
| MS_LOG(ERROR) << "axis value is too large"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->splitDim = axis; | |||
| op->primitive->value.type = schema::PrimitiveType_Split; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| for (int output : tflite_op->outputs) { | |||
| AddOpOutput(op, tensors_info, output, tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *TfliteSplitVParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||
| auto &tflite_subgraph = tflite_model->subgraphs.front(); | |||
| @@ -144,6 +74,6 @@ PrimitiveC *TfliteSplitVParser::ParseLitePrimitive(const std::unique_ptr<tflite: | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| TfliteNodeRegister g_tfliteSplitVParser("SplitV", new TfliteSplitVParser()); | |||
| TfliteNodeRegister g_tfliteSplitVParser(tflite::BuiltinOperator_SPLIT_V, new TfliteSplitVParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -29,9 +29,6 @@ class TfliteSplitVParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteSplitVParser() : TfliteNodeParser("SplitV") {} | |||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override; | |||
| PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model) override; | |||
| }; | |||