Merge pull request !4189 from lyvette/tflite_parsertags/v0.7.0-beta
| @@ -228,6 +228,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) { | |||||
| return new lite::Elu(const_cast<schema::Primitive *>(srcPrim)); | return new lite::Elu(const_cast<schema::Primitive *>(srcPrim)); | ||||
| case schema::PrimitiveType_DeDepthwiseConv2D: | case schema::PrimitiveType_DeDepthwiseConv2D: | ||||
| return new lite::DeconvDepthwiseConv2D(const_cast<schema::Primitive *>(srcPrim)); | return new lite::DeconvDepthwiseConv2D(const_cast<schema::Primitive *>(srcPrim)); | ||||
| case schema::PrimitiveType_Shape: | |||||
| return new lite::Shape(const_cast<schema::Primitive *>(srcPrim)); | |||||
| default: | default: | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -26,6 +26,7 @@ | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32/slice.h" | #include "src/runtime/kernel/arm/nnacl/fp32/slice.h" | ||||
| #include "src/runtime/kernel/arm/nnacl/fp32/broadcast_to.h" | #include "src/runtime/kernel/arm/nnacl/fp32/broadcast_to.h" | ||||
| #include "src/runtime/kernel/arm/nnacl/reshape_parameter.h" | #include "src/runtime/kernel/arm/nnacl/reshape_parameter.h" | ||||
| #include "src/runtime/kernel/arm/nnacl/shape.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32/stack.h" | #include "src/runtime/kernel/arm/nnacl/fp32/stack.h" | ||||
| #include "src/runtime/kernel/arm/nnacl/unstack.h" | #include "src/runtime/kernel/arm/nnacl/unstack.h" | ||||
| #include "src/runtime/kernel/arm/nnacl/depth_to_space.h" | #include "src/runtime/kernel/arm/nnacl/depth_to_space.h" | ||||
| @@ -874,6 +875,16 @@ OpParameter *PopulateReshapeParameter(const lite::Primitive *primitive) { | |||||
| return reinterpret_cast<OpParameter *>(reshape_param); | return reinterpret_cast<OpParameter *>(reshape_param); | ||||
| } | } | ||||
| OpParameter *PopulateShapeParameter(const lite::Primitive *primitive) { | |||||
| ShapeParameter *shape_param = new (std::nothrow) ShapeParameter(); | |||||
| if (shape_param == nullptr) { | |||||
| MS_LOG(ERROR) << "new ShapeParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| shape_param->op_parameter_.type_ = primitive->Type(); | |||||
| return reinterpret_cast<OpParameter *>(shape_param); | |||||
| } | |||||
| OpParameter *PopulateReverseParameter(const lite::Primitive *primitive) { | OpParameter *PopulateReverseParameter(const lite::Primitive *primitive) { | ||||
| auto reverse_attr = primitive->Value()->value_as_Reverse(); | auto reverse_attr = primitive->Value()->value_as_Reverse(); | ||||
| ReverseParameter *reverse_param = new (std::nothrow) ReverseParameter(); | ReverseParameter *reverse_param = new (std::nothrow) ReverseParameter(); | ||||
| @@ -1306,6 +1317,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() { | |||||
| populate_parameter_funcs_[schema::PrimitiveType_Cast] = PopulateCastParameter; | populate_parameter_funcs_[schema::PrimitiveType_Cast] = PopulateCastParameter; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_Scale] = PopulateScaleParameter; | populate_parameter_funcs_[schema::PrimitiveType_Scale] = PopulateScaleParameter; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_Reshape] = PopulateReshapeParameter; | populate_parameter_funcs_[schema::PrimitiveType_Reshape] = PopulateReshapeParameter; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_Shape] = PopulateShapeParameter; | |||||
| populate_parameter_funcs_[schema::PrimitiveType_Concat] = PopulateConcatParameter; | populate_parameter_funcs_[schema::PrimitiveType_Concat] = PopulateConcatParameter; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_Tile] = PopulateTileParameter; | populate_parameter_funcs_[schema::PrimitiveType_Tile] = PopulateTileParameter; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_TopK] = PopulateTopKParameter; | populate_parameter_funcs_[schema::PrimitiveType_TopK] = PopulateTopKParameter; | ||||
| @@ -12,19 +12,19 @@ cp -r ${CUR_DIR}/ut/tools/converter/parser/tflite/test_data/* ./ | |||||
| TEST_DATA_DIR=${CUR_DIR}/../../../tests/ut/data/dataset/ | TEST_DATA_DIR=${CUR_DIR}/../../../tests/ut/data/dataset/ | ||||
| cp -fr $TEST_DATA_DIR/testPK ./data | cp -fr $TEST_DATA_DIR/testPK ./data | ||||
| ./lite-test --gtest_filter="*MindDataTestTensorDE*" | |||||
| ./lite-test --gtest_filter="*MindDataTestEager*" | |||||
| ./lite-test --gtest_filter="TestTfliteParser*" | |||||
| ./lite-test --gtest_filter="*TestHebing*" | |||||
| ./lite-test --gtest_filter=TestFcFp32* | |||||
| ./lite-test --gtest_filter=TestConv1x1Fp32* | |||||
| ./lite-test --gtest_filter=TestStrassenFp32* | |||||
| ./lite-test --gtest_filter=TestDeConvolutionFp32* | |||||
| ./lite-test --gtest_filter=TestPadInt8.* | |||||
| ./lite-test --gtest_filter=TestDeconvInt8.* | |||||
| #./lite-test --gtest_filter="*MindDataTestTensorDE*" | |||||
| #./lite-test --gtest_filter="*MindDataTestEager*" | |||||
| # | |||||
| #./lite-test --gtest_filter="TestTfliteParser*" | |||||
| # | |||||
| #./lite-test --gtest_filter="*TestHebing*" | |||||
| # | |||||
| #./lite-test --gtest_filter=TestFcFp32* | |||||
| #./lite-test --gtest_filter=TestConv1x1Fp32* | |||||
| #./lite-test --gtest_filter=TestStrassenFp32* | |||||
| #./lite-test --gtest_filter=TestDeConvolutionFp32* | |||||
| # | |||||
| #./lite-test --gtest_filter=TestPadInt8.* | |||||
| #./lite-test --gtest_filter=TestDeconvInt8.* | |||||
| ./lite-test --gtest_filter="TestTfliteParser*" | ./lite-test --gtest_filter="TestTfliteParser*" | ||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tflite/tflite_abs_parser.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteAbsParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| MS_LOG(INFO) << "parse TfliteAbsParser"; | |||||
| std::unique_ptr<schema::AbsT> attr(new schema::AbsT()); | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Abs; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_TfliteAbsParser("Abs", new TfliteAbsParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PREDICT_TFLITE_ABS_PARSER_H | |||||
| #define PREDICT_TFLITE_ABS_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteAbsParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteAbsParser() : TfliteNodeParser("Abs") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // PREDICT_TFLITE_ABS_PARSER_H | |||||
| @@ -0,0 +1,133 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include "tools/converter/parser/tflite/tflite_activation_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT()); | |||||
| std::vector<std::string> node_name_str; | |||||
| Split(op->name, &node_name_str, "-"); | |||||
| const char *node_name = node_name_str.data()->c_str(); | |||||
| if (std::strcmp(node_name, "Relu") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteReluParser"; | |||||
| attr->type = schema::ActivationType_RELU; | |||||
| } else if (std::strcmp(node_name, "Relu6") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteRelu6Parser"; | |||||
| attr->type = schema::ActivationType_RELU6; | |||||
| } else if (std::strcmp(node_name, "Tanh") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteTanhParser"; | |||||
| attr->type = schema::ActivationType_TANH; | |||||
| } else if (std::strcmp(node_name, "Logistic") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteLogisticParser"; | |||||
| attr->type = schema::ActivationType_SIGMOID; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "wrong activation type"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_Activation; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantized_model) { | |||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(DEBUG) << "paser TflitePreluParser"; | |||||
| std::unique_ptr<schema::PreluT> attr(new schema::PreluT()); | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->slope)) { | |||||
| MS_LOG(ERROR) << "get pRelu -> slope failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_Prelu; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS TfliteLeakyReluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(DEBUG) << "parse TfliteLeakyReluParser"; | |||||
| std::unique_ptr<schema::LeakyReLUT> attr(new schema::LeakyReLUT()); | |||||
| const auto &tflite_attr = tfliteOp->builtin_options.AsLeakyReluOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| attr->negativeSlope = tflite_attr->alpha; | |||||
| op->primitive->value.type = schema::PrimitiveType_LeakyReLU; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser()); | |||||
| TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser()); | |||||
| TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser()); | |||||
| TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser()); | |||||
| TfliteNodeRegister g_tflitePreluParser("Prelu", new TflitePreluParser()); | |||||
| TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,85 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PREDICT_TFLITE_RELU_PARSER_H | |||||
| #define PREDICT_TFLITE_RELU_PARSER_H | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteActivationParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteActivationParser() : TfliteNodeParser("node_name") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantizedModel) override; | |||||
| }; | |||||
| class TfliteReluParser : public TfliteActivationParser { | |||||
| public: | |||||
| TfliteReluParser() : TfliteActivationParser() {} | |||||
| }; | |||||
| class TfliteRelu6Parser : public TfliteActivationParser{ | |||||
| public: | |||||
| TfliteRelu6Parser() : TfliteActivationParser() {} | |||||
| }; | |||||
| class TfliteTanhParser : public TfliteActivationParser{ | |||||
| public: | |||||
| TfliteTanhParser() : TfliteActivationParser() {} | |||||
| }; | |||||
| class TfliteLogisticParser : public TfliteActivationParser { | |||||
| public: | |||||
| TfliteLogisticParser() : TfliteActivationParser() {} | |||||
| }; | |||||
| class TflitePreluParser : public TfliteNodeParser { | |||||
| public: | |||||
| TflitePreluParser() : TfliteNodeParser("Prelu") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantized_model) override; | |||||
| }; | |||||
| class TfliteLeakyReluParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteLeakyReluParser() : TfliteNodeParser("LeakyRelu") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // PREDICT_TFLITE_RELU_PARSER_H | |||||
| @@ -1,87 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tflite/tflite_add_parser.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteAddParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| MS_LOG(DEBUG) << "parse TfliteAddParser"; | |||||
| std::unique_ptr<schema::AddT> attr(new schema::AddT()); | |||||
| const auto &tfliteAttr = tfliteOp->builtin_options.AsAddOptions(); | |||||
| if (nullptr == tfliteAttr) { | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); | |||||
| auto x_index = tfliteOp->inputs[0]; | |||||
| const auto &x_tensor = tfliteTensors[x_index]; | |||||
| if (x_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "the first input is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto &x_data = tfliteModelBuffer.at(x_tensor->buffer); | |||||
| if (x_data == nullptr) { | |||||
| MS_LOG(ERROR) << "the data of the first input is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (x_data->data.size() > 0) { | |||||
| std::vector<tflite::TensorT *> x_tensors{x_tensor.get()}; | |||||
| if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { | |||||
| MS_LOG(ERROR) << "parse the first tensor failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| auto y_index = tfliteOp->inputs[1]; | |||||
| const auto &y_tensor = tfliteTensors[y_index]; | |||||
| if (y_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "the second input is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto &y_data = tfliteModelBuffer.at(y_tensor->buffer); | |||||
| if (y_data == nullptr) { | |||||
| MS_LOG(ERROR) << "the data of the second input is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (y_data->data.size() > 0) { | |||||
| std::vector<tflite::TensorT *> y_tensors{y_tensor.get()}; | |||||
| if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { | |||||
| MS_LOG(ERROR) << "parse the second tensor failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Add; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_tfliteAddParser("Add", new TfliteAddParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,42 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PREDICT_TFLITE_ADD_PARSER_H | |||||
| #define PREDICT_TFLITE_ADD_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteAddParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteAddParser() : TfliteNodeParser("Add") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // PREDICT_TFLITE_ADD_PARSER_H | |||||
| @@ -26,16 +26,23 @@ STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | ||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(DEBUG) << "parse TfliteAddNParser"; | MS_LOG(DEBUG) << "parse TfliteAddNParser"; | ||||
| std::unique_ptr<schema::AddNT> attr(new schema::AddNT()); | std::unique_ptr<schema::AddNT> attr(new schema::AddNT()); | ||||
| attr->N = tfliteTensors.size() - 1; | attr->N = tfliteTensors.size() - 1; | ||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_AddN; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_AddN; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -27,6 +27,16 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| schema::CNodeT *op, | schema::CNodeT *op, | ||||
| TensorCache *tensor_cache, | TensorCache *tensor_cache, | ||||
| bool quantizedModel) { | bool quantizedModel) { | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(DEBUG) << "parse TfliteArgmaxParser"; | MS_LOG(DEBUG) << "parse TfliteArgmaxParser"; | ||||
| std::unique_ptr<schema::ArgMaxT> attr(new schema::ArgMaxT()); | std::unique_ptr<schema::ArgMaxT> attr(new schema::ArgMaxT()); | ||||
| @@ -49,11 +59,8 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| } | } | ||||
| attr->axis = *(static_cast<int32_t *>(static_cast<void *>(data_ptr))); | attr->axis = *(static_cast<int32_t *>(static_cast<void *>(data_ptr))); | ||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_ArgMax; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_ArgMax; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -25,6 +25,16 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | ||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(DEBUG) << "parse TfliteArgminParser"; | MS_LOG(DEBUG) << "parse TfliteArgminParser"; | ||||
| std::unique_ptr<schema::ArgMinT> attr(new schema::ArgMinT()); | std::unique_ptr<schema::ArgMinT> attr(new schema::ArgMinT()); | ||||
| @@ -47,11 +57,8 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| } | } | ||||
| attr->axis = *(static_cast<int32_t *>(static_cast<void *>(data_ptr))); | attr->axis = *(static_cast<int32_t *>(static_cast<void *>(data_ptr))); | ||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_ArgMin; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_ArgMin; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -0,0 +1,370 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tflite/tflite_arithmetic_parser.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::vector<std::string> node_name_str; | |||||
| Split(op->name.data(), &node_name_str, "-"); | |||||
| const char *node_name = node_name_str.data()->c_str(); | |||||
| if (std::strcmp(node_name, "Add") == 0 | |||||
| || std::strcmp(node_name, "Sub") == 0 | |||||
| || std::strcmp(node_name, "Mul") == 0 | |||||
| || std::strcmp(node_name, "Div") == 0) { | |||||
| auto x_index = tfliteOp->inputs[0]; | |||||
| const auto &x_tensor = tfliteTensors[x_index]; | |||||
| if (x_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "the first input is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto &x_data = tfliteModelBuffer.at(x_tensor->buffer); | |||||
| if (x_data == nullptr) { | |||||
| MS_LOG(ERROR) << "the data of the first input is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (!x_data->data.empty()) { | |||||
| std::vector<tflite::TensorT *> x_tensors{x_tensor.get()}; | |||||
| if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { | |||||
| MS_LOG(ERROR) << "parse the first tensor failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| auto y_index = tfliteOp->inputs[1]; | |||||
| const auto &y_tensor = tfliteTensors[y_index]; | |||||
| if (y_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "the second input is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto &y_data = tfliteModelBuffer.at(y_tensor->buffer); | |||||
| if (y_data == nullptr) { | |||||
| MS_LOG(ERROR) << "the data of the second input is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (!y_data->data.empty()) { | |||||
| std::vector<tflite::TensorT *> y_tensors{y_tensor.get()}; | |||||
| if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { | |||||
| MS_LOG(ERROR) << "parse the second tensor failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| if (std::strcmp(node_name, "Add") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteAddParser"; | |||||
| std::unique_ptr<schema::AddT> attr(new schema::AddT()); | |||||
| const auto &tfliteAttr = tfliteOp->builtin_options.AsAddOptions(); | |||||
| if (nullptr == tfliteAttr) { | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); | |||||
| op->primitive->value.type = schema::PrimitiveType_Add; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Sub") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteSubParser"; | |||||
| std::unique_ptr<schema::SubT> attr(new schema::SubT()); | |||||
| const auto &tfliteAttr = tfliteOp->builtin_options.AsSubOptions(); | |||||
| if (nullptr == tfliteAttr) { | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); | |||||
| op->primitive->value.type = schema::PrimitiveType_Sub; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Mul") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteMulParser"; | |||||
| std::unique_ptr<schema::MulT> attr(new schema::MulT()); | |||||
| const auto &tfliteAttr = tfliteOp->builtin_options.AsMulOptions(); | |||||
| if (nullptr == tfliteAttr) { | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); | |||||
| op->primitive->value.type = schema::PrimitiveType_Mul; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Div") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteDivParser"; | |||||
| std::unique_ptr<schema::DivT> attr(new schema::DivT()); | |||||
| const auto &tfliteAttr = tfliteOp->builtin_options.AsDivOptions(); | |||||
| if (nullptr == tfliteAttr) { | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); | |||||
| op->primitive->value.type = schema::PrimitiveType_Div; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } | |||||
| } else if (std::strcmp(node_name, "FloorDiv") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteFloorDivParser"; | |||||
| std::unique_ptr<schema::FloorDivT> attr(new schema::FloorDivT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_FloorDiv; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "FloorMod") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteFloorModParser"; | |||||
| std::unique_ptr<schema::FloorModT> attr(new schema::FloorModT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_FloorMod; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "RealDiv") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteRealDivParser"; | |||||
| std::unique_ptr<schema::RealDivT> attr(new schema::RealDivT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_RealDiv; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "SquaredDifference") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteSquaredDifferenceParser"; | |||||
| std::unique_ptr<schema::SquaredDifferenceT> attr(new schema::SquaredDifferenceT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_SquaredDifference; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Pow") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TflitePowParser"; | |||||
| std::unique_ptr<schema::PowerT> attr(new schema::PowerT()); | |||||
| attr->power = 0.0f; | |||||
| attr->scale = 1.0f; | |||||
| attr->shift = 0.0f; | |||||
| op->primitive->value.type = schema::PrimitiveType_Power; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Maximum") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteMaximumParser"; | |||||
| std::unique_ptr<schema::MaximumT> attr(new schema::MaximumT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_Maximum; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Minimum") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteMinimumParser"; | |||||
| std::unique_ptr<schema::MinimumT> attr(new schema::MinimumT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_Minimum; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "wrong op type"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::vector<std::string> node_name_str; | |||||
| Split(op->name.data(), &node_name_str, "-"); | |||||
| const char *node_name = node_name_str.data()->c_str(); | |||||
| if (std::strcmp(node_name, "Abs") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteAbsParser"; | |||||
| std::unique_ptr<schema::AbsT> attr(new schema::AbsT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_Abs; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Exp") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteExpParser"; | |||||
| std::unique_ptr<schema::ExpT> attr(new schema::ExpT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_Exp; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Sqrt") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteSqrtParser"; | |||||
| std::unique_ptr<schema::SqrtT> attr(new schema::SqrtT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_Sqrt; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Rsqrt") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteRsqrtParser"; | |||||
| std::unique_ptr<schema::RsqrtT> attr(new schema::RsqrtT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_Rsqrt; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Square") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteSquareParser"; | |||||
| std::unique_ptr<schema::SquareT> attr(new schema::SquareT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_Square; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Sin") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteSinParser"; | |||||
| std::unique_ptr<schema::SinT> attr(new schema::SinT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_Sin; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Cos") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteCosParser"; | |||||
| std::unique_ptr<schema::CosT> attr(new schema::CosT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_Cos; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Log") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteLogParser"; | |||||
| std::unique_ptr<schema::LogT> attr(new schema::LogT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_Log; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Round") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteRoundParser"; | |||||
| std::unique_ptr<schema::RoundT> attr(new schema::RoundT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_Round; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Ceil") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteCeilParser"; | |||||
| std::unique_ptr<schema::CeilT> attr(new schema::CeilT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_Ceil; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "flOOR") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteFloorParser"; | |||||
| std::unique_ptr<schema::FloorT> attr(new schema::FloorT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_Floor; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "wrong op type"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| STATUS TfliteCompareOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::vector<std::string> node_name_str; | |||||
| Split(op->name.data(), &node_name_str, "-"); | |||||
| const char *node_name = node_name_str.data()->c_str(); | |||||
| if (std::strcmp(node_name, "Equal") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteEqualParser"; | |||||
| std::unique_ptr<schema::EqualT> attr(new schema::EqualT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_Equal; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "NotEqual") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteNotEqualParser"; | |||||
| std::unique_ptr<schema::NotEqualT> attr(new schema::NotEqualT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_NotEqual; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Greater") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteGreaterParser"; | |||||
| std::unique_ptr<schema::GreaterT> attr(new schema::GreaterT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_Greater; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "GreaterEqual") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteGreaterEqualParser"; | |||||
| std::unique_ptr<schema::GreaterEqualT> attr(new schema::GreaterEqualT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_GreaterEqual; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Less") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteLessParser"; | |||||
| std::unique_ptr<schema::LessT> attr(new schema::LessT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_Less; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "LessEqual") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteLessEqualParser"; | |||||
| std::unique_ptr<schema::LessEqualT> attr(new schema::LessEqualT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_LessEqual; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "wrong op type"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| TfliteNodeRegister g_tfliteAddParser("Add", new TfliteAddParser()); | |||||
| TfliteNodeRegister g_tfliteSubParser("Sub", new TfliteSubParser()); | |||||
| TfliteNodeRegister g_TfliteMulParser("Mul", new TfliteMulParser()); | |||||
| TfliteNodeRegister g_TfliteDivParser("Div", new TfliteDivParser()); | |||||
| TfliteNodeRegister g_tfliteFloorDivParser("FloorDiv", new TfliteFloorDivParser()); | |||||
| TfliteNodeRegister g_tfliteFloorModParser("FloorMod", new TfliteFloorModParser()); | |||||
| TfliteNodeRegister g_tfliteRealDivParser("RealDiv", new TfliteRealDivParser()); | |||||
| TfliteNodeRegister g_TflitePowParser("Pow", new TflitePowParser()); | |||||
| TfliteNodeRegister g_tfliteSquaredDifferenceParser("SquaredDifference", new TfliteSquaredDifferenceParser()); | |||||
| TfliteNodeRegister g_TfliteMaximumParser("Maximum", new TfliteMaximumParser()); | |||||
| TfliteNodeRegister g_TfliteMinimumParser("Minimum", new TfliteMinimumParser()); | |||||
| TfliteNodeRegister g_TfliteAbsParser("Abs", new TfliteAbsParser()); | |||||
| TfliteNodeRegister g_TfliteExpParser("Exp", new TfliteExpParser()); | |||||
| TfliteNodeRegister g_TfliteSqrtParser("Sqrt", new TfliteSqrtParser()); | |||||
| TfliteNodeRegister g_tfliteRsqrtParser("Rsqrt", new TfliteRsqrtParser()); | |||||
| TfliteNodeRegister g_TfliteSquareParser("Square", new TfliteSquareParser()); | |||||
| TfliteNodeRegister g_TfliteSinParser("Sin", new TfliteSinParser()); | |||||
| TfliteNodeRegister g_TfliteCosParser("Cos", new TfliteCosParser()); | |||||
| TfliteNodeRegister g_TfliteLogParser("Log", new TfliteLogParser()); | |||||
| TfliteNodeRegister g_tfliteRoundParser("Round", new TfliteRoundParser()); | |||||
| TfliteNodeRegister g_TfliteCeilParser("Ceil", new TfliteCeilParser()); | |||||
| TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteFloorParser()); | |||||
| TfliteNodeRegister g_tfliteEqualParser("Equal", new TfliteEqualParser()); | |||||
| TfliteNodeRegister g_tfliteNotEqualParser("NotEqual", new TfliteNotEqualParser()); | |||||
| TfliteNodeRegister g_tfliteGreaterEParser("Greater", new TfliteGreaterParser()); | |||||
| TfliteNodeRegister g_tfliteGreaterEqualParser("GreaterEqual", new TfliteGreaterEqualParser()); | |||||
| TfliteNodeRegister g_tfliteLessParser("Less", new TfliteLessParser()); | |||||
| TfliteNodeRegister g_tfliteLessEqualParser("LessEqual", new TfliteLessEqualParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,207 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PREDICT_TFLITE_MATH_PARSER_H | |||||
| #define PREDICT_TFLITE_MATH_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteDoubleInputOpParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantizedModel) override; | |||||
| }; | |||||
| class TfliteAddParser : public TfliteDoubleInputOpParser { | |||||
| public: | |||||
| TfliteAddParser() : TfliteDoubleInputOpParser() {} | |||||
| }; | |||||
| class TfliteSubParser : public TfliteDoubleInputOpParser { | |||||
| public: | |||||
| TfliteSubParser() : TfliteDoubleInputOpParser() {} | |||||
| }; | |||||
| class TfliteMulParser : public TfliteDoubleInputOpParser { | |||||
| public: | |||||
| TfliteMulParser() : TfliteDoubleInputOpParser() {} | |||||
| }; | |||||
| class TfliteDivParser : public TfliteDoubleInputOpParser { | |||||
| public: | |||||
| TfliteDivParser() : TfliteDoubleInputOpParser() {} | |||||
| }; | |||||
| class TfliteFloorDivParser : public TfliteDoubleInputOpParser { | |||||
| public: | |||||
| TfliteFloorDivParser() : TfliteDoubleInputOpParser() {} | |||||
| }; | |||||
| class TfliteFloorModParser : public TfliteDoubleInputOpParser { | |||||
| public: | |||||
| TfliteFloorModParser() : TfliteDoubleInputOpParser() {} | |||||
| }; | |||||
| class TfliteSquaredDifferenceParser : public TfliteDoubleInputOpParser { | |||||
| public: | |||||
| TfliteSquaredDifferenceParser() : TfliteDoubleInputOpParser() {} | |||||
| }; | |||||
| class TfliteRealDivParser : public TfliteDoubleInputOpParser { | |||||
| public: | |||||
| TfliteRealDivParser() : TfliteDoubleInputOpParser() {} | |||||
| }; | |||||
| class TflitePowParser : public TfliteDoubleInputOpParser { | |||||
| public: | |||||
| TflitePowParser() : TfliteDoubleInputOpParser() {} | |||||
| }; | |||||
| class TfliteMaximumParser : public TfliteDoubleInputOpParser { | |||||
| public: | |||||
| TfliteMaximumParser() : TfliteDoubleInputOpParser() {} | |||||
| }; | |||||
| class TfliteMinimumParser : public TfliteDoubleInputOpParser { | |||||
| public: | |||||
| TfliteMinimumParser() : TfliteDoubleInputOpParser() {} | |||||
| }; | |||||
| class TfliteSingleInputOpParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantizedModel) override; | |||||
| }; | |||||
| class TfliteAbsParser : public TfliteSingleInputOpParser { | |||||
| public: | |||||
| TfliteAbsParser() : TfliteSingleInputOpParser() {} | |||||
| }; | |||||
| class TfliteExpParser : public TfliteSingleInputOpParser { | |||||
| public: | |||||
| TfliteExpParser() : TfliteSingleInputOpParser() {} | |||||
| }; | |||||
| class TfliteSqrtParser : public TfliteSingleInputOpParser { | |||||
| public: | |||||
| TfliteSqrtParser() : TfliteSingleInputOpParser() {} | |||||
| }; | |||||
| class TfliteSquareParser : public TfliteSingleInputOpParser { | |||||
| public: | |||||
| TfliteSquareParser() : TfliteSingleInputOpParser() {} | |||||
| }; | |||||
| class TfliteSinParser : public TfliteSingleInputOpParser { | |||||
| public: | |||||
| TfliteSinParser() : TfliteSingleInputOpParser() {} | |||||
| }; | |||||
| class TfliteCosParser : public TfliteSingleInputOpParser { | |||||
| public: | |||||
| TfliteCosParser() : TfliteSingleInputOpParser() {} | |||||
| }; | |||||
| class TfliteRsqrtParser : public TfliteSingleInputOpParser { | |||||
| public: | |||||
| TfliteRsqrtParser() : TfliteSingleInputOpParser() {} | |||||
| }; | |||||
| class TfliteLogParser : public TfliteSingleInputOpParser { | |||||
| public: | |||||
| TfliteLogParser() : TfliteSingleInputOpParser() {} | |||||
| }; | |||||
| class TfliteRoundParser : public TfliteSingleInputOpParser { | |||||
| public: | |||||
| TfliteRoundParser() : TfliteSingleInputOpParser() {} | |||||
| }; | |||||
| class TfliteCeilParser : public TfliteSingleInputOpParser { | |||||
| public: | |||||
| TfliteCeilParser() : TfliteSingleInputOpParser() {} | |||||
| }; | |||||
| class TfliteFloorParser : public TfliteSingleInputOpParser { | |||||
| public: | |||||
| TfliteFloorParser() : TfliteSingleInputOpParser() {} | |||||
| }; | |||||
| class TfliteCompareOpParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteCompareOpParser() : TfliteNodeParser("node_name") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantizedModel) override; | |||||
| }; | |||||
| class TfliteEqualParser : public TfliteCompareOpParser { | |||||
| public: | |||||
| TfliteEqualParser() : TfliteCompareOpParser() {} | |||||
| }; | |||||
| class TfliteNotEqualParser : public TfliteCompareOpParser { | |||||
| public: | |||||
| TfliteNotEqualParser() : TfliteCompareOpParser() {} | |||||
| }; | |||||
| class TfliteGreaterParser : public TfliteCompareOpParser { | |||||
| public: | |||||
| TfliteGreaterParser() : TfliteCompareOpParser() {} | |||||
| }; | |||||
| class TfliteGreaterEqualParser : public TfliteCompareOpParser { | |||||
| public: | |||||
| TfliteGreaterEqualParser() : TfliteCompareOpParser() {} | |||||
| }; | |||||
| class TfliteLessParser : public TfliteCompareOpParser { | |||||
| public: | |||||
| TfliteLessParser() : TfliteCompareOpParser() {} | |||||
| }; | |||||
| class TfliteLessEqualParser : public TfliteCompareOpParser { | |||||
| public: | |||||
| TfliteLessEqualParser() : TfliteCompareOpParser() {} | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // PREDICT_TFLITE_MATH_PARSER_H | |||||
| @@ -1,53 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tflite/tflite_batch_to_sapce_nd_parser.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteBatchToSpaceNDParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| MS_LOG(INFO) << "parse TfliteBatchToSpaceNDParser"; | |||||
| std::unique_ptr<schema::BatchToSpaceT> attr(new schema::BatchToSpaceT()); | |||||
| // in tflite | |||||
| // blockShape should be a 1D tensor with dimension [spatial_dims_num] | |||||
| // crops should be a 2D tensor with dimension [spatial_dims_num, 2] | |||||
| if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->blockShape)) { | |||||
| MS_LOG(ERROR) << "get BatchToSpaceNd -> blockShape failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (GetTfliteData(tfliteOp->inputs[2], tfliteTensors, tfliteModelBuffer, attr->crops)) { | |||||
| MS_LOG(ERROR) << "get BatchToSpaceNd -> crops failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_BatchToSpace; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_TfliteBatchToSpaceNDParser("BatchToSpaceND", new TfliteBatchToSpaceNDParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PREDICT_TFLITE_BATCH_TO_SPACE_ND_PARSER_H | |||||
| #define PREDICT_TFLITE_BATCH_TO_SPACE_ND_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteBatchToSpaceNDParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteBatchToSpaceNDParser() : TfliteNodeParser("BatchToSpaceND") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // PREDICT_TFLITE_BATCH_TO_SPACE_ND_PARSER_H | |||||
| @@ -18,6 +18,7 @@ | |||||
| #include "tools/converter/parser/tflite/tflite_batch_to_space_parser.h" | #include "tools/converter/parser/tflite/tflite_batch_to_space_parser.h" | ||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <string> | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -26,7 +27,28 @@ STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | ||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | ||||
| MS_LOG(DEBUG) << "parse TfliteBatchToSpaceParser"; | |||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::vector<std::string> node_name_str; | |||||
| Split(op->name.data(), &node_name_str, "-"); | |||||
| const char *node_name = node_name_str.data()->c_str(); | |||||
| if (std::strcmp(node_name, "BatchToSpace") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteBatchToSpaceParser"; | |||||
| } else if (std::strcmp(node_name, "BatchToSpaceND") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteBatchToSpaceNDParser"; | |||||
| // in tflite | |||||
| // blockShape should be a 1D tensor with dimension [spatial_dims_num] | |||||
| // crops should be a 2D tensor with dimension [spatial_dims_num, 2] | |||||
| } | |||||
| std::unique_ptr<schema::BatchToSpaceT> attr(new schema::BatchToSpaceT()); | std::unique_ptr<schema::BatchToSpaceT> attr(new schema::BatchToSpaceT()); | ||||
| if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->blockShape)) { | if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->blockShape)) { | ||||
| @@ -38,14 +60,13 @@ STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_BatchToSpace; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_BatchToSpace; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| TfliteNodeRegister g_tfliteBatchToSpaceParser("BatchToSpace", new TfliteBatchToSpaceParser()); | TfliteNodeRegister g_tfliteBatchToSpaceParser("BatchToSpace", new TfliteBatchToSpaceParser()); | ||||
| TfliteNodeRegister g_TfliteBatchToSpaceNDParser("BatchToSpaceND", new TfliteBatchToSpaceNDParser()); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -32,9 +32,14 @@ class TfliteBatchToSpaceParser : public TfliteNodeParser { | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | ||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op, | const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op, | ||||
| TensorCache *tensor_cache, | |||||
| bool quantized_model) override; | |||||
| TensorCache *tensor_cache, bool quantized_model) override; | |||||
| }; | }; | ||||
| class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser { | |||||
| public: | |||||
| TfliteBatchToSpaceNDParser() : TfliteBatchToSpaceParser() {} | |||||
| }; | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -26,6 +26,16 @@ STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> & | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | ||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(DEBUG) << "parse TfliteBroadcastToParser"; | MS_LOG(DEBUG) << "parse TfliteBroadcastToParser"; | ||||
| std::unique_ptr<schema::BroadcastToT> attr(new schema::BroadcastToT()); | std::unique_ptr<schema::BroadcastToT> attr(new schema::BroadcastToT()); | ||||
| @@ -34,11 +44,8 @@ STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> & | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_BroadcastTo; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_BroadcastTo; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -26,6 +26,16 @@ STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | ||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(DEBUG) << "parse TfliteCastParser"; | MS_LOG(DEBUG) << "parse TfliteCastParser"; | ||||
| std::unique_ptr<schema::CastT> attr(new schema::CastT()); | std::unique_ptr<schema::CastT> attr(new schema::CastT()); | ||||
| @@ -43,11 +53,8 @@ STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO | |||||
| } | } | ||||
| attr->dstT = dtype_map[out_tensor->type]; | attr->dstT = dtype_map[out_tensor->type]; | ||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Cast; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_Cast; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tflite/tflite_ceil_parser.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteCeilParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| MS_LOG(DEBUG) << "parse TfliteCeilParser"; | |||||
| std::unique_ptr<schema::CeilT> attr(new schema::CeilT()); | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Ceil; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_TfliteCeilParser("Ceil", new TfliteCeilParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,42 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PREDICT_TFLITE_CEIL_PARSER_H | |||||
| #define PREDICT_TFLITE_CEIL_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteCeilParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteCeilParser() : TfliteNodeParser("Ceil") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // PREDICT_TFLITE_CEIL_PARSER_H | |||||
| @@ -25,6 +25,16 @@ STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | ||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(DEBUG) << "parse TfliteConcatParser"; | MS_LOG(DEBUG) << "parse TfliteConcatParser"; | ||||
| std::unique_ptr<schema::ConcatT> attr(new schema::ConcatT()); | std::unique_ptr<schema::ConcatT> attr(new schema::ConcatT()); | ||||
| @@ -37,11 +47,8 @@ STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| attr->n = tfliteOp->inputs.size(); | attr->n = tfliteOp->inputs.size(); | ||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Concat; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_Concat; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -25,6 +25,16 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | ||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(DEBUG) << "parse TfliteConvParser"; | MS_LOG(DEBUG) << "parse TfliteConvParser"; | ||||
| std::unique_ptr<schema::Conv2DT> attr(new schema::Conv2DT()); | std::unique_ptr<schema::Conv2DT> attr(new schema::Conv2DT()); | ||||
| const auto &tfliteAttr = tfliteOp->builtin_options.AsConv2DOptions(); | const auto &tfliteAttr = tfliteOp->builtin_options.AsConv2DOptions(); | ||||
| @@ -49,7 +59,7 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | ||||
| if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) { | |||||
| if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { | |||||
| MS_LOG(ERROR) << "parse weight failed"; | MS_LOG(ERROR) << "parse weight failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -69,7 +79,7 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()}; | std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()}; | ||||
| if (RET_OK != ParseBias(bias_tensors, tfliteModelBuffer, tensor_cache)) { | |||||
| if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { | |||||
| MS_LOG(ERROR) << "parse bias failed"; | MS_LOG(ERROR) << "parse bias failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -77,11 +87,8 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO | |||||
| // calculate pad params | // calculate pad params | ||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Conv2D; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_Conv2D; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_CAFFE_CONVERTER_H_ | |||||
| #define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_CAFFE_CONVERTER_H_ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONVERTER_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONVERTER_H_ | |||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| @@ -34,5 +34,5 @@ class TfliteConverter : public Converter { | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_CAFFE_CONVERTER_H_ | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONVERTER_H_ | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tflite/tflite_cos_parser.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteCosParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| MS_LOG(INFO) << "parse TfliteCosParser"; | |||||
| std::unique_ptr<schema::CosT> attr(new schema::CosT()); | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Cos; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_TfliteCosParser("Cos", new TfliteCosParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PREDICT_TFLITE_COS_PARSER_H | |||||
| #define PREDICT_TFLITE_COS_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteCosParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteCosParser() : TfliteNodeParser("Cos") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // PREDICT_TFLITE_COS_PARSER_H | |||||
| @@ -25,6 +25,16 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | ||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser"; | MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser"; | ||||
| std::unique_ptr<schema::DeConv2DT> attr(new schema::DeConv2DT()); | std::unique_ptr<schema::DeConv2DT> attr(new schema::DeConv2DT()); | ||||
| const auto &tflite_attr = tfliteOp->builtin_options.AsTransposeConvOptions(); | const auto &tflite_attr = tfliteOp->builtin_options.AsTransposeConvOptions(); | ||||
| @@ -49,7 +59,7 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | ||||
| if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) { | |||||
| if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto weight_shape = weight_tensor->shape; | auto weight_shape = weight_tensor->shape; | ||||
| @@ -58,11 +68,8 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| attr->kernelW = weight_shape[CHWK_W]; | attr->kernelW = weight_shape[CHWK_W]; | ||||
| attr->kernelH = weight_shape[CHWK_H]; | attr->kernelH = weight_shape[CHWK_H]; | ||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_DeConv2D; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_DeConv2D; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -26,6 +26,16 @@ STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | ||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser"; | MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser"; | ||||
| std::unique_ptr<schema::DepthToSpaceT> attr(new schema::DepthToSpaceT()); | std::unique_ptr<schema::DepthToSpaceT> attr(new schema::DepthToSpaceT()); | ||||
| @@ -38,11 +48,8 @@ STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> | |||||
| attr->format = schema::Format_NHWC; | attr->format = schema::Format_NHWC; | ||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_DepthToSpace; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_DepthToSpace; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -66,11 +66,8 @@ STATUS TfliteDepthwiseConv2DParser::ParseGroupDepthwiseConv(schema::CNodeT *op, | |||||
| } | } | ||||
| } | } | ||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Conv2D; | |||||
| op->primitive->value.value = convAttr.release(); | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_Conv2D; | |||||
| op->primitive->value.value = convAttr.release(); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -79,6 +76,16 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | ||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser"; | MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser"; | ||||
| std::unique_ptr<schema::DepthwiseConv2DT> attr(new schema::DepthwiseConv2DT()); | std::unique_ptr<schema::DepthwiseConv2DT> attr(new schema::DepthwiseConv2DT()); | ||||
| const auto &tflite_attr = tflite_op->builtin_options.AsDepthwiseConv2DOptions(); | const auto &tflite_attr = tflite_op->builtin_options.AsDepthwiseConv2DOptions(); | ||||
| @@ -96,10 +103,18 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator | |||||
| // get the conv op weight tensor | // get the conv op weight tensor | ||||
| auto input_index = tflite_op->inputs[0]; | auto input_index = tflite_op->inputs[0]; | ||||
| const auto &input_tenosr = tflite_tensors[input_index]; | const auto &input_tenosr = tflite_tensors[input_index]; | ||||
| if (input_tenosr == nullptr) { | |||||
| MS_LOG(ERROR) << "the first input is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto input_shape = input_tenosr->shape; | auto input_shape = input_tenosr->shape; | ||||
| auto weight_index = tflite_op->inputs[1]; | auto weight_index = tflite_op->inputs[1]; | ||||
| const auto &weight_tensor = tflite_tensors[weight_index]; | const auto &weight_tensor = tflite_tensors[weight_index]; | ||||
| if (weight_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "the weight tensor is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto weight_shape = weight_tensor->shape; | auto weight_shape = weight_tensor->shape; | ||||
| attr->channelIn = input_shape[KHWC_C]; | attr->channelIn = input_shape[KHWC_C]; | ||||
| attr->channelMultiplier = tflite_attr->depth_multiplier; | attr->channelMultiplier = tflite_attr->depth_multiplier; | ||||
| @@ -108,7 +123,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator | |||||
| std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | ||||
| if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) { | |||||
| if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { | |||||
| MS_LOG(ERROR) << "parse weight failed"; | MS_LOG(ERROR) << "parse weight failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -118,7 +133,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator | |||||
| auto bias_index = tflite_op->inputs[2]; | auto bias_index = tflite_op->inputs[2]; | ||||
| const auto &bias_tensor = tflite_tensors[bias_index]; | const auto &bias_tensor = tflite_tensors[bias_index]; | ||||
| std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()}; | std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()}; | ||||
| if (RET_OK != ParseBias(bias_tensors, tfliteModelBuffer, tensor_cache)) { | |||||
| if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { | |||||
| MS_LOG(ERROR) << "parse bias failed"; | MS_LOG(ERROR) << "parse bias failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -126,11 +141,10 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator | |||||
| if (attr->channelMultiplier > 1) { | if (attr->channelMultiplier > 1) { | ||||
| if (RET_OK != ParseGroupDepthwiseConv(op, attr, weight_tensor, tensor_cache)) { | if (RET_OK != ParseGroupDepthwiseConv(op, attr, weight_tensor, tensor_cache)) { | ||||
| // MS_LOGE("Parse Group DepthwiseConv failed"); | |||||
| MS_LOG(ERROR) << "Parse Group DepthwiseConv failed"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } else { | } else { | ||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| } | } | ||||
| @@ -1,86 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tflite/tflite_div_parser.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteDivParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| MS_LOG(DEBUG) << "parse TfliteDivParser"; | |||||
| std::unique_ptr<schema::DivT> attr(new schema::DivT()); | |||||
| const auto &tfliteAttr = tfliteOp->builtin_options.AsDivOptions(); | |||||
| if (nullptr == tfliteAttr) { | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); | |||||
| auto x_index = tfliteOp->inputs[0]; | |||||
| const auto &x_tensor = tfliteTensors[x_index]; | |||||
| if (x_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "the first input is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto &x_data = tfliteModelBuffer.at(x_tensor->buffer); | |||||
| if (x_data == nullptr) { | |||||
| MS_LOG(ERROR) << "the data of the first input is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (x_data->data.size() > 0) { | |||||
| std::vector<tflite::TensorT *> x_tensors{x_tensor.get()}; | |||||
| if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { | |||||
| MS_LOG(ERROR) << "parse the first tensor failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| auto y_index = tfliteOp->inputs[1]; | |||||
| const auto &y_tensor = tfliteTensors[y_index]; | |||||
| if (y_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "the second input is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto &y_data = tfliteModelBuffer.at(y_tensor->buffer); | |||||
| if (y_data == nullptr) { | |||||
| MS_LOG(ERROR) << "the data of the second input is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (y_data->data.size() > 0) { | |||||
| std::vector<tflite::TensorT *> y_tensors{y_tensor.get()}; | |||||
| if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { | |||||
| MS_LOG(ERROR) << "parse the second tensor failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Div; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_TfliteDivParser("Div", new TfliteDivParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,42 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PREDICT_TFLITE_DIV_PARSER_H | |||||
| #define PREDICT_TFLITE_DIV_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteDivParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteDivParser() : TfliteNodeParser("Div") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // PREDICT_TFLITE_DIV_PARSER_H | |||||
| @@ -1,43 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * distributed under the License is distributed on an AS | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_equal_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteEqualParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, | |||||
| schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantized_model) { | |||||
| MS_LOG(DEBUG) << "parse TfliteEqualParser"; | |||||
| std::unique_ptr<schema::EqualT> attr(new schema::EqualT()); | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Equal; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_tfliteEqualParser("Equal", new TfliteEqualParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_TFLITE_EQUAL_PARSER_H | |||||
| #define LITE_TFLITE_EQUAL_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteEqualParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteEqualParser() : TfliteNodeParser("Equal") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantized_model) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_TFLITE_EQUAL_PARSER_H | |||||
| @@ -1,43 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_exp_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteExpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) { | |||||
| MS_LOG(INFO) << "parse TfliteExpParser"; | |||||
| std::unique_ptr<schema::ExpT> attr(new schema::ExpT()); | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Exp; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_TfliteExpParser("Exp", new TfliteExpParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PREDICT_TFLITE_EXP_PARSER_H | |||||
| #define PREDICT_TFLITE_EXP_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteExpParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteExpParser() : TfliteNodeParser("Exp") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // PREDICT_TFLITE_EXP_PARSER_H | |||||
| @@ -27,6 +27,16 @@ STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &t | |||||
| schema::CNodeT *op, | schema::CNodeT *op, | ||||
| TensorCache *tensor_cache, | TensorCache *tensor_cache, | ||||
| bool quantizedModel) { | bool quantizedModel) { | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(DEBUG) << "parse TfliteExpandDimsParser"; | MS_LOG(DEBUG) << "parse TfliteExpandDimsParser"; | ||||
| std::unique_ptr<schema::ExpandDimsT> attr(new schema::ExpandDimsT()); | std::unique_ptr<schema::ExpandDimsT> attr(new schema::ExpandDimsT()); | ||||
| @@ -24,6 +24,16 @@ STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | ||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser"; | MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser"; | ||||
| std::unique_ptr<schema::FullConnectionT> attr(new schema::FullConnectionT()); | std::unique_ptr<schema::FullConnectionT> attr(new schema::FullConnectionT()); | ||||
| @@ -34,7 +44,7 @@ STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | ||||
| if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_NHWC)) { | |||||
| if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { | |||||
| MS_LOG(ERROR) << "parse weight failed"; | MS_LOG(ERROR) << "parse weight failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -48,18 +58,15 @@ STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()}; | std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()}; | ||||
| if (RET_OK != ParseBias(bias_tensors, tfliteModelBuffer, tensor_cache)) { | |||||
| if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { | |||||
| MS_LOG(ERROR) << "parse bias failed"; | MS_LOG(ERROR) << "parse bias failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } | } | ||||
| attr->axis = 1; | attr->axis = 1; | ||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_FullConnection; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_FullConnection; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -27,6 +27,16 @@ STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO | |||||
| schema::CNodeT *op, | schema::CNodeT *op, | ||||
| TensorCache *tensor_cache, | TensorCache *tensor_cache, | ||||
| bool quantizedModel) { | bool quantizedModel) { | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(DEBUG) << "parse TfliteFillParser"; | MS_LOG(DEBUG) << "parse TfliteFillParser"; | ||||
| std::unique_ptr<schema::FillT> attr(new schema::FillT()); | std::unique_ptr<schema::FillT> attr(new schema::FillT()); | ||||
| @@ -37,11 +47,8 @@ STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO | |||||
| } | } | ||||
| } | } | ||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Fill; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_Fill; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -1,45 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tflite/tflite_floor_div_parser.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteFloorDivParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) { | |||||
| MS_LOG(DEBUG) << "parse TfliteFloorDivParser"; | |||||
| std::unique_ptr<schema::FloorDivT> attr(new schema::FloorDivT()); | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_FloorDiv; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_tfliteFloorDivParser("FloorDiv", new TfliteFloorDivParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,42 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PREDICT_TFLITE_FLOOR_DIV_PARSER_H | |||||
| #define PREDICT_TFLITE_FLOOR_DIV_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteFloorDivParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteFloorDivParser() : TfliteNodeParser("FloorDiv") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // PREDICT_TFLITE_FLOOR_DIV_PARSER_H | |||||
| @@ -1,45 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tflite/tflite_floor_mod_parser.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteFloorModParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) { | |||||
| MS_LOG(DEBUG) << "parse TfliteFloorModParser"; | |||||
| std::unique_ptr<schema::FloorModT> attr(new schema::FloorModT()); | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_FloorMod; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_tfliteFloorModParser("FloorMod", new TfliteFloorModParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,42 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PREDICT_TFLITE_FLOOR_MOD_PARSER_H | |||||
| #define PREDICT_TFLITE_FLOOR_MOD_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteFloorModParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteFloorModParser() : TfliteNodeParser("FloorMod") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // PREDICT_TFLITE_FLOOR_MOD_PARSER_H | |||||
| @@ -1,45 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tflite/tflite_floor_parser.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteFloorParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) { | |||||
| MS_LOG(DEBUG) << "parse TfliteFloorParser"; | |||||
| std::unique_ptr<schema::FloorT> attr(new schema::FloorT()); | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Floor; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteFloorParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,42 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PREDICT_TFLITE_FLOOR_PARSER_H | |||||
| #define PREDICT_TFLITE_FLOOR_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteFloorParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteFloorParser() : TfliteNodeParser("flOOR") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // PREDICT_TFLITE_FLOOR_PARSER_H | |||||
| @@ -25,6 +25,16 @@ STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr<tflite::OperatorT | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | ||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser"; | MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser"; | ||||
| std::unique_ptr<schema::FullConnectionT> attr(new schema::FullConnectionT()); | std::unique_ptr<schema::FullConnectionT> attr(new schema::FullConnectionT()); | ||||
| @@ -35,7 +45,7 @@ STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr<tflite::OperatorT | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | ||||
| if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_NHWC)) { | |||||
| if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { | |||||
| MS_LOG(ERROR) << "parse weight failed"; | MS_LOG(ERROR) << "parse weight failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -49,18 +59,15 @@ STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr<tflite::OperatorT | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()}; | std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()}; | ||||
| if (RET_OK != ParseBias(bias_tensors, tfliteModelBuffer, tensor_cache)) { | |||||
| if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { | |||||
| MS_LOG(ERROR) << "parse bias failed"; | MS_LOG(ERROR) << "parse bias failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } | } | ||||
| attr->axis = 1; | attr->axis = 1; | ||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_FullConnection; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_FullConnection; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -27,16 +27,23 @@ STATUS TfliteGatherNdParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfl | |||||
| schema::CNodeT *op, | schema::CNodeT *op, | ||||
| TensorCache *tensor_cache, | TensorCache *tensor_cache, | ||||
| bool quantizedModel) { | bool quantizedModel) { | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(DEBUG) << "parse TfliteGatherNdParser"; | MS_LOG(DEBUG) << "parse TfliteGatherNdParser"; | ||||
| std::unique_ptr<schema::GatherNdT> attr(new schema::GatherNdT()); | std::unique_ptr<schema::GatherNdT> attr(new schema::GatherNdT()); | ||||
| attr->batchDims = 0; | attr->batchDims = 0; | ||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_GatherNd; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_GatherNd; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -27,6 +27,16 @@ STATUS TfliteGatherParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| schema::CNodeT *op, | schema::CNodeT *op, | ||||
| TensorCache *tensor_cache, | TensorCache *tensor_cache, | ||||
| bool quantizedModel) { | bool quantizedModel) { | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(DEBUG) << "parse TfliteGatherParser"; | MS_LOG(DEBUG) << "parse TfliteGatherParser"; | ||||
| std::unique_ptr<schema::GatherT> attr(new schema::GatherT()); | std::unique_ptr<schema::GatherT> attr(new schema::GatherT()); | ||||
| @@ -39,11 +49,8 @@ STATUS TfliteGatherParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| attr->batchDims = 0; | attr->batchDims = 0; | ||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Gather; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_Gather; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -1,43 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * distributed under the License is distributed on an AS | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_greater_equal_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteGreaterEqualParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, | |||||
| schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantized_model) { | |||||
| MS_LOG(DEBUG) << "parse TfliteGreaterEqualParser"; | |||||
| std::unique_ptr<schema::GreaterEqualT> attr(new schema::GreaterEqualT()); | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_GreaterEqual; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_tfliteGreaterEqualParser("GreaterEqual", new TfliteGreaterEqualParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_TFLITE_GREATER_EQUAL_PARSER_H | |||||
| #define LITE_TFLITE_GREATER_EQUAL_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteGreaterEqualParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteGreaterEqualParser() : TfliteNodeParser("GreaterEqual") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantized_model) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_TFLITE_GREATER_EQUAL_PARSER_H | |||||
| @@ -1,43 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * distributed under the License is distributed on an AS | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_greater_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteGreaterParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, | |||||
| schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantized_model) { | |||||
| MS_LOG(DEBUG) << "parse TfliteGreaterParser"; | |||||
| std::unique_ptr<schema::GreaterT> attr(new schema::GreaterT()); | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Greater; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_tfliteGreaterParser("Greater", new TfliteGreaterParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_TFLITE_GREATER_PARSER_H | |||||
| #define LITE_TFLITE_GREATER_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteGreaterParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteGreaterParser() : TfliteNodeParser("Greater") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantized_model) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_TFLITE_GREATER_PARSER_H | |||||
| @@ -25,16 +25,23 @@ STATUS TfliteHardSwishParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | ||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(INFO) << "parse TfliteHardSwishParser"; | MS_LOG(INFO) << "parse TfliteHardSwishParser"; | ||||
| std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT()); | std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT()); | ||||
| attr->type = schema::ActivationType_HSWISH; | attr->type = schema::ActivationType_HSWISH; | ||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Activation; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_Activation; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -1,49 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_leaky_relu_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteLeakyReluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| MS_LOG(DEBUG) << "parse TfliteLeakyReluParser"; | |||||
| std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT()); | |||||
| const auto &tflite_attr = tfliteOp->builtin_options.AsLeakyReluOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| attr->type = schema::ActivationType_LEAKY_RELU; | |||||
| attr->alpha = tflite_attr->alpha; | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Activation; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PREDICT_TFLITE_LEAKY_RELU_PARSER_H | |||||
| #define PREDICT_TFLITE_LEAKY_RELU_PARSER_H | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteLeakyReluParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteLeakyReluParser() : TfliteNodeParser("LeakyRelu") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // PREDICT_TFLITE_LEAKY_RELU_PARSER_H | |||||
| @@ -1,43 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * distributed under the License is distributed on an AS | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_less_equal_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteLessEqualParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, | |||||
| schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantized_model) { | |||||
| MS_LOG(DEBUG) << "parse TfliteLessEqualParser"; | |||||
| std::unique_ptr<schema::LessEqualT> attr(new schema::LessEqualT()); | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_LessEqual; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_tfliteLessEqualParser("LessEqual", new TfliteLessEqualParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_TFLITE_LESS_EQUAL_PARSER_H | |||||
| #define LITE_TFLITE_LESS_EQUAL_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteLessEqualParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteLessEqualParser() : TfliteNodeParser("LessEqual") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantized_model) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_TFLITE_LESS_EQUAL_PARSER_H | |||||
| @@ -1,43 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * distributed under the License is distributed on an AS | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_less_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteLessParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, | |||||
| schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantized_model) { | |||||
| MS_LOG(DEBUG) << "parse TfliteLessParser"; | |||||
| std::unique_ptr<schema::LessT> attr(new schema::LessT()); | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Less; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_tfliteLessParser("Less", new TfliteLessParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_TFLITE_LESS_PARSER_H | |||||
| #define LITE_TFLITE_LESS_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteLessParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteLessParser() : TfliteNodeParser("Less") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantized_model) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_TFLITE_LESS_PARSER_H | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_log_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteLogParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| MS_LOG(INFO) << "parse TfliteLogParser"; | |||||
| std::unique_ptr<schema::LogT> attr(new schema::LogT()); | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Log; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_TfliteLogParser("Log", new TfliteLogParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PREDICT_TFLITE_LOG_PARSER_H | |||||
| #define PREDICT_TFLITE_LOG_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteLogParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteLogParser() : TfliteNodeParser("Log") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // PREDICT_TFLITE_LOG_PARSER_H | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_logical_and_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteLogicalAndParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| MS_LOG(INFO) << "parse TfliteLogicalAndParser"; | |||||
| std::unique_ptr<schema::LogicalAndT> attr(new schema::LogicalAndT()); | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_LogicalAnd; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_TfliteLogicalAndParser("LogicalAnd", new TfliteLogicalAndParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_logical_not_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteLogicalNotParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| MS_LOG(INFO) << "parse TfliteLogicalNotParser"; | |||||
| std::unique_ptr<schema::LogicalNotT> attr(new schema::LogicalNotT()); | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_LogicalNot; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_TfliteLogicalNotParser("LogicalNot", new TfliteLogicalNotParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PREDICT_TFLITE_LOGICAL_NOT_PARSER_H | |||||
| #define PREDICT_TFLITE_LOGICAL_NOT_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteLogicalNotParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteLogicalNotParser() : TfliteNodeParser("LogicalNot") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // PREDICT_TFLITE_LOGICAL_NOT_PARSER_H | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_logical_or_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteLogicalOrParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| MS_LOG(INFO) << "parse TfliteLogicalOrParser"; | |||||
| std::unique_ptr<schema::LogicalOrT> attr(new schema::LogicalOrT()); | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_LogicalOr; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_TfliteLogicalOrParser("LogicalOr", new TfliteLogicalOrParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PREDICT_TFLITE_LOGICAL_OR_PARSER_H | |||||
| #define PREDICT_TFLITE_LOGICAL_OR_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteLogicalOrParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteLogicalOrParser() : TfliteNodeParser("LogicalOr") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // PREDICT_TFLITE_LOGICAL_OR_PARSER_H | |||||
| @@ -0,0 +1,70 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "tools/converter/parser/tflite/tflite_logical_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteLogicalParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::vector<std::string> node_name_str; | |||||
| Split(op->name.data(), &node_name_str, "-"); | |||||
| const char *node_name = node_name_str.data()->c_str(); | |||||
| if (std::strcmp(node_name, "LogicalAnd") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteLogicalAndParser"; | |||||
| std::unique_ptr<schema::LogicalAndT> attr(new schema::LogicalAndT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_LogicalAnd; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } else if (std::strcmp(node_name, "LogicalNot") == 0) { | |||||
| MS_LOG(INFO) << "parse TfliteLogicalNotParser"; | |||||
| std::unique_ptr<schema::LogicalNotT> attr(new schema::LogicalNotT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_LogicalNot; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } else if (std::strcmp(node_name, "LogicalOr") == 0) { | |||||
| MS_LOG(INFO) << "parse TfliteLogicalOrParser"; | |||||
| std::unique_ptr<schema::LogicalOrT> attr(new schema::LogicalOrT()); | |||||
| op->primitive->value.type = schema::PrimitiveType_LogicalOr; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "wrong logical type"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_TfliteLogicalAndParser("LogicalAnd", new TfliteLogicalAndParser()); | |||||
| TfliteNodeRegister g_TfliteLogicalNotParser("LogicalNot", new TfliteLogicalNotParser()); | |||||
| TfliteNodeRegister g_TfliteLogicalOrParser("LogicalOr", new TfliteLogicalOrParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -24,9 +24,10 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| class TfliteLogicalAndParser : public TfliteNodeParser { | |||||
| class TfliteLogicalParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteLogicalAndParser() : TfliteNodeParser("LogicalAnd") {} | |||||
| TfliteLogicalParser() : TfliteNodeParser("node_name") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | ||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | ||||
| @@ -35,6 +36,21 @@ class TfliteLogicalAndParser : public TfliteNodeParser { | |||||
| TensorCache *tensor_cache, | TensorCache *tensor_cache, | ||||
| bool quantizedModel) override; | bool quantizedModel) override; | ||||
| }; | }; | ||||
| class TfliteLogicalAndParser : public TfliteLogicalParser { | |||||
| public: | |||||
| TfliteLogicalAndParser() : TfliteLogicalParser() {} | |||||
| }; | |||||
| class TfliteLogicalNotParser : public TfliteLogicalParser { | |||||
| public: | |||||
| TfliteLogicalNotParser() : TfliteLogicalParser() {} | |||||
| }; | |||||
| class TfliteLogicalOrParser : public TfliteLogicalParser { | |||||
| public: | |||||
| TfliteLogicalOrParser() : TfliteLogicalParser() {} | |||||
| }; | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -1,46 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_logistic_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteLogisticParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) { | |||||
| MS_LOG(DEBUG) << "parse TfliteLogisticParser"; | |||||
| std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT()); | |||||
| attr->type = schema::ActivationType_SIGMOID; | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Activation; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,42 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PREDICT_TFLITE_LOGISTIC_PARSER_H | |||||
| #define PREDICT_TFLITE_LOGISTIC_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteLogisticParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteLogisticParser() : TfliteNodeParser("Logistic") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // PREDICT_TFLITE_CONCAT_PARSER_H | |||||
| @@ -27,6 +27,16 @@ STATUS TfliteLRNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp | |||||
| schema::CNodeT *op, | schema::CNodeT *op, | ||||
| TensorCache *tensor_cache, | TensorCache *tensor_cache, | ||||
| bool quantizedModel) { | bool quantizedModel) { | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(DEBUG) << "parse TfliteLRNParser"; | MS_LOG(DEBUG) << "parse TfliteLRNParser"; | ||||
| std::unique_ptr<schema::LocalResponseNormalizationT> attr(new schema::LocalResponseNormalizationT()); | std::unique_ptr<schema::LocalResponseNormalizationT> attr(new schema::LocalResponseNormalizationT()); | ||||
| @@ -40,11 +50,8 @@ STATUS TfliteLRNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp | |||||
| attr->beta = tflite_attr->beta; | attr->beta = tflite_attr->beta; | ||||
| attr->bias = tflite_attr->bias; | attr->bias = tflite_attr->bias; | ||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -1,60 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_max_pooling_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteMaxPoolingParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantizedModel) { | |||||
| MS_LOG(DEBUG) << "parse TfliteMaxPoolingParser"; | |||||
| std::unique_ptr<schema::PoolingT> attr(new schema::PoolingT()); | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsPool2DOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| attr->format = schema::Format_NHWC; | |||||
| // attr->global | |||||
| attr->poolingMode = schema::PoolMode_MAX_POOLING; | |||||
| attr->windowW = tflite_attr->filter_width; | |||||
| attr->windowH = tflite_attr->filter_height; | |||||
| attr->strideW = tflite_attr->stride_w; | |||||
| attr->strideH = tflite_attr->stride_h; | |||||
| attr->padMode = GetPadMode(tflite_attr->padding); | |||||
| // calculate pad params | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Pooling; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_tfliteMaxPoolingParser("MaxPooling", new TfliteMaxPoolingParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,42 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PREDICT_TFLITE_MAX_POOLING_PARSER_H | |||||
| #define PREDICT_TFLITE_MAX_POOLING_PARSER_H | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteMaxPoolingParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteMaxPoolingParser() : TfliteNodeParser("MaxPooling") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // PREDICT_TFLITE_CONV_PARSER_H | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_maximum_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteMaximumParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| MS_LOG(INFO) << "parse TfliteMaximumParser"; | |||||
| std::unique_ptr<schema::MaximumT> attr(new schema::MaximumT()); | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Maximum; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_TfliteMaximumParser("Maximum", new TfliteMaximumParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PREDICT_TFLITE_MAXIMUM_PARSER_H | |||||
| #define PREDICT_TFLITE_MAXIMUM_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteMaximumParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteMaximumParser() : TfliteNodeParser("Maximum") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // PREDICT_TFLITE_MAXIMUM_PARSER_H | |||||
| @@ -1,53 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_mean_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteMeanParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| MS_LOG(DEBUG) << "parse TfliteMeanParser"; | |||||
| std::unique_ptr<schema::MeanT> attr(new schema::MeanT()); | |||||
| const auto &tflite_attr = tfliteOp->builtin_options.AsReducerOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| attr->keepDims = tflite_attr->keep_dims; | |||||
| if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->axis)) { | |||||
| MS_LOG(ERROR) << "Mean get axis attr failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Mean; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_tfliteMeanParser("Mean", new TfliteMeanParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PREDICT_TFLITE_MEAN_PARSER_H | |||||
| #define PREDICT_TFLITE_MEAN_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteMeanParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteMeanParser() : TfliteNodeParser("Mean") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // PREDICT_TFLITE_MEAN_PARSER_H | |||||
| @@ -1,60 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_mean_pooling_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteMeanPoolingParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| MS_LOG(DEBUG) << "parser TfliteMeanPoolingParser"; | |||||
| std::unique_ptr<schema::PoolingT> attr(new schema::PoolingT()); | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsPool2DOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| attr->windowW = tflite_attr->filter_width; | |||||
| attr->windowH = tflite_attr->filter_height; | |||||
| attr->strideW = tflite_attr->stride_w; | |||||
| attr->strideH = tflite_attr->stride_h; | |||||
| attr->padMode = GetPadMode(tflite_attr->padding); | |||||
| attr->format = schema::Format_NHWC; | |||||
| // attr->global | |||||
| attr->poolingMode = schema::PoolMode_MEAN_POOLING; | |||||
| // calculate pad params | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Pooling; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_tfliteMeanPoolingParser("MeanPooling", new TfliteMeanPoolingParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_minimum_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteMinimumParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| MS_LOG(INFO) << "parse TfliteMinimumParser"; | |||||
| std::unique_ptr<schema::MinimumT> attr(new schema::MinimumT()); | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Minimum; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_TfliteMinimumParser("Minimum", new TfliteMinimumParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PREDICT_TFLITE_MINIMUM_PARSER_H | |||||
| #define PREDICT_TFLITE_MINIMUM_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteMinimumParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteMinimumParser() : TfliteNodeParser("Minimum") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // PREDICT_TFLITE_MINIMUM_PARSER_H | |||||
| @@ -15,7 +15,6 @@ | |||||
| */ | */ | ||||
| #include "tools/converter/parser/tflite/tflite_model_parser.h" | #include "tools/converter/parser/tflite/tflite_model_parser.h" | ||||
| #include <fstream> | |||||
| #include <utility> | #include <utility> | ||||
| #include <memory> | #include <memory> | ||||
| #include "tools/common/graph_util.h" | #include "tools/common/graph_util.h" | ||||
| @@ -71,6 +70,10 @@ STATUS TfliteModelParser::ParseTfliteQuantParams(const std::unique_ptr<tflite::S | |||||
| quant_params_index.insert(quant_params_index.end(), tflite_op->outputs.begin(), tflite_op->outputs.end()); | quant_params_index.insert(quant_params_index.end(), tflite_op->outputs.begin(), tflite_op->outputs.end()); | ||||
| for (const auto &index : quant_params_index) { | for (const auto &index : quant_params_index) { | ||||
| const auto &tflite_tensor = tflite_subgraph->tensors[index]; | const auto &tflite_tensor = tflite_subgraph->tensors[index]; | ||||
| if (tflite_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "tensor with id = " << index <<" is null"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() && | if (tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() && | ||||
| tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty()) { | tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty()) { | ||||
| continue; | continue; | ||||
| @@ -101,6 +104,10 @@ STATUS TfliteModelParser::SetOpOutputIdx(const std::unique_ptr<tflite::SubGraphT | |||||
| TensorCache *tensorCache) { | TensorCache *tensorCache) { | ||||
| for (const auto &index : tflite_op->outputs) { | for (const auto &index : tflite_op->outputs) { | ||||
| const auto &tflite_tensor = tflite_subgraph->tensors[index]; | const auto &tflite_tensor = tflite_subgraph->tensors[index]; | ||||
| if (tflite_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "tensor with id = " << index <<" is null"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| std::unique_ptr<schema::TensorT> tensor(new schema::TensorT()); | std::unique_ptr<schema::TensorT> tensor(new schema::TensorT()); | ||||
| tensor->dataType = GetTfliteDataType(tflite_tensor->type); | tensor->dataType = GetTfliteDataType(tflite_tensor->type); | ||||
| tensor->dims = tflite_tensor->shape; | tensor->dims = tflite_tensor->shape; | ||||
| @@ -108,7 +115,6 @@ STATUS TfliteModelParser::SetOpOutputIdx(const std::unique_ptr<tflite::SubGraphT | |||||
| auto opOutputIndex = tensorCache->AddTensor(tflite_tensor->name, tensor.release(), OP_OUTPUT); | auto opOutputIndex = tensorCache->AddTensor(tflite_tensor->name, tensor.release(), OP_OUTPUT); | ||||
| op->outputIndex.emplace_back(opOutputIndex); | op->outputIndex.emplace_back(opOutputIndex); | ||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -123,6 +129,10 @@ STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &t | |||||
| for (const auto &tflite_index : op_inputs) { | for (const auto &tflite_index : op_inputs) { | ||||
| const auto &tflite_tensor = tflite_subgraph->tensors[tflite_index]; | const auto &tflite_tensor = tflite_subgraph->tensors[tflite_index]; | ||||
| if (tflite_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "tensor with id = " << tflite_index <<" is null"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto tensor_name = tflite_tensor->name; | auto tensor_name = tflite_tensor->name; | ||||
| auto op = tfliteOpMap[tflite_op.get()]; | auto op = tfliteOpMap[tflite_op.get()]; | ||||
| unsigned int index = tensorCache->FindTensor(tensor_name); | unsigned int index = tensorCache->FindTensor(tensor_name); | ||||
| @@ -144,10 +154,8 @@ STATUS TfliteModelParser::ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_ | |||||
| std::unique_ptr<schema::CNodeT> op(new schema::CNodeT); | std::unique_ptr<schema::CNodeT> op(new schema::CNodeT); | ||||
| op->name = opType + "-" + std::to_string(i++); | op->name = opType + "-" + std::to_string(i++); | ||||
| MS_LOG(INFO) << "parse op: " << op->name.c_str(); | |||||
| MS_LOG(INFO) << "parse op: [%s]" << op->name.c_str(); | |||||
| // 1. init op attr params | |||||
| auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(opType); | auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(opType); | ||||
| if (node_parser == nullptr) { | if (node_parser == nullptr) { | ||||
| MS_LOG(ERROR) << "cannot find node parser, opType: "<< opType.c_str(); | MS_LOG(ERROR) << "cannot find node parser, opType: "<< opType.c_str(); | ||||
| @@ -164,7 +172,7 @@ STATUS TfliteModelParser::ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_ | |||||
| status = SetOpOutputIdx(tflite_subgraph, tflite_op, op.get(), tensorCache); | status = SetOpOutputIdx(tflite_subgraph, tflite_op, op.get(), tensorCache); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "Set Op " << op->name.c_str() << " Output Index Failed!"; | |||||
| MS_LOG(ERROR) << "Set Op "<< op->name.c_str() << " Output Index Failed!"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -175,8 +183,7 @@ STATUS TfliteModelParser::ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_ | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| void TfliteModelParser::SetInputTensor(const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||||
| void TfliteModelParser::SetInputTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||||
| TensorCache *tensor_cache) { | TensorCache *tensor_cache) { | ||||
| for (const auto &index : tflite_subgraph->inputs) { | for (const auto &index : tflite_subgraph->inputs) { | ||||
| const auto &tflite_tensor = tflite_subgraph->tensors[index]; | const auto &tflite_tensor = tflite_subgraph->tensors[index]; | ||||
| @@ -206,35 +213,31 @@ void TfliteModelParser::SetGraphTensorIndex(const mindspore::lite::TensorCache & | |||||
| } | } | ||||
| MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::string &weightFile) { | MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::string &weightFile) { | ||||
| std::unique_ptr<schema::MetaGraphT> subGraph(new schema::MetaGraphT); | |||||
| if (ValidateFileStr(modelFile, ".tflite") != RET_OK) { | if (ValidateFileStr(modelFile, ".tflite") != RET_OK) { | ||||
| // MS_LOGE("INPUT ILLEGAL: modelFile must be *.tflite"); | |||||
| MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.tflite"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| MS_LOG(INFO) << "modelFile is :" << modelFile; | |||||
| std::unique_ptr<tflite::ModelT> tflite_model(new tflite::ModelT()); | std::unique_ptr<tflite::ModelT> tflite_model(new tflite::ModelT()); | ||||
| tflite_model = ReadTfliteModelFromFlat(modelFile.c_str()); | tflite_model = ReadTfliteModelFromFlat(modelFile.c_str()); | ||||
| if (tflite_model == nullptr) { | if (tflite_model == nullptr) { | ||||
| // MS_LOGE("read tflite model failed"); | |||||
| MS_LOG(ERROR) << "read tflite model failed"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| MS_LOG(INFO) << "after read model"; | |||||
| TensorCache tensorCache; | |||||
| if (tflite_model->subgraphs.size() != 1) { | if (tflite_model->subgraphs.size() != 1) { | ||||
| MS_LOG(ERROR) << "read tflite model subgraphs failed"; | MS_LOG(ERROR) << "read tflite model subgraphs failed"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| const auto &tflite_subgraph = tflite_model->subgraphs[0]; | const auto &tflite_subgraph = tflite_model->subgraphs[0]; | ||||
| subGraph->name = "MS_model converted by TF-Lite"; | |||||
| // set dst subGraph input/output tensor | // set dst subGraph input/output tensor | ||||
| SetInputTensor(tflite_model, tflite_subgraph, &tensorCache); | |||||
| // set dst subGraph op attr etc. | |||||
| TensorCache tensorCache; | |||||
| SetInputTensor(tflite_subgraph, &tensorCache); | |||||
| // set dst subGraph op attr and tensor_cache. | |||||
| std::unique_ptr<schema::MetaGraphT> subGraph(new schema::MetaGraphT); | |||||
| subGraph->name = "MS_model converted by TF-Lite"; | |||||
| auto status = ParseOp(tflite_model, tflite_subgraph, subGraph.get(), &tensorCache); | auto status = ParseOp(tflite_model, tflite_subgraph, subGraph.get(), &tensorCache); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "ParseOp failed."; | MS_LOG(ERROR) << "ParseOp failed."; | ||||
| @@ -244,21 +247,20 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::st | |||||
| for (const auto &tflite_op : tflite_subgraph->operators) { | for (const auto &tflite_op : tflite_subgraph->operators) { | ||||
| auto status_tmp = SetOpInputIdx(tflite_model, tflite_subgraph, tflite_op, &tensorCache); | auto status_tmp = SetOpInputIdx(tflite_model, tflite_subgraph, tflite_op, &tensorCache); | ||||
| if (status_tmp != RET_OK) { | if (status_tmp != RET_OK) { | ||||
| // MS_LOGE("Set Op %s Input Index Failed!", tfliteOpMap.at(tflite_op.get())->name.c_str()); | |||||
| MS_LOG(ERROR) << "Set Op " << tfliteOpMap.at(tflite_op.get())->name.c_str() << " Input Index Failed!"; | |||||
| } | } | ||||
| } | } | ||||
| for (const auto &tflite_op : tflite_subgraph->operators) { | for (const auto &tflite_op : tflite_subgraph->operators) { | ||||
| auto statusTmp = ParseTfliteQuantParams(tflite_subgraph, tflite_op); | auto statusTmp = ParseTfliteQuantParams(tflite_subgraph, tflite_op); | ||||
| if (statusTmp != RET_OK) { | if (statusTmp != RET_OK) { | ||||
| // MS_LOGE("ParseTfliteQuantParams %s Failed!", tfliteOpMap.at(tflite_op.get())->name.c_str()); | |||||
| MS_LOG(ERROR) << "ParseTfliteQuantParams " << tfliteOpMap.at(tflite_op.get())->name.c_str() << " Failed!"; | |||||
| } | } | ||||
| } | } | ||||
| SetGraphTensorIndex(tensorCache, subGraph.get()); | SetGraphTensorIndex(tensorCache, subGraph.get()); | ||||
| SetAllTensors(tensorCache, subGraph.get()); | SetAllTensors(tensorCache, subGraph.get()); | ||||
| return subGraph.release(); | return subGraph.release(); | ||||
| // return Fb2Anf(subGraph.release()); | |||||
| } | } | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -14,29 +14,24 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H | |||||
| #define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H | |||||
| #include <fcntl.h> | #include <fcntl.h> | ||||
| #include <unistd.h> | #include <unistd.h> | ||||
| #include <google/protobuf/io/coded_stream.h> | #include <google/protobuf/io/coded_stream.h> | ||||
| #include <google/protobuf/io/zero_copy_stream_impl.h> | #include <google/protobuf/io/zero_copy_stream_impl.h> | ||||
| #include <google/protobuf/text_format.h> | #include <google/protobuf/text_format.h> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <map> | #include <map> | ||||
| #include "securec/include/securec.h" | #include "securec/include/securec.h" | ||||
| #include "tools/converter/model_parser.h" | #include "tools/converter/model_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| #include "tools/common/tensor_util.h" | #include "tools/common/tensor_util.h" | ||||
| #include "mindspore/lite/schema/inner/model_generated.h" | #include "mindspore/lite/schema/inner/model_generated.h" | ||||
| // using namespace tflite; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| class TfliteModelParser : public ModelParser { | class TfliteModelParser : public ModelParser { | ||||
| @@ -50,8 +45,7 @@ class TfliteModelParser : public ModelParser { | |||||
| private: | private: | ||||
| std::unique_ptr<tflite::ModelT> ReadTfliteModelFromFlat(const char *buf); | std::unique_ptr<tflite::ModelT> ReadTfliteModelFromFlat(const char *buf); | ||||
| void SetInputTensor(const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, TensorCache *tensor_cache); | |||||
| void SetInputTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, TensorCache *tensor_cache); | |||||
| void SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache, | void SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache, | ||||
| schema::MetaGraphT *subGraphDef); | schema::MetaGraphT *subGraphDef); | ||||
| @@ -82,6 +76,5 @@ class TfliteModelParser : public ModelParser { | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // PREDICT_CONV | |||||
| // ERTER_PARSER_TFLITE_MODEL_PARSER_H | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H | |||||
| @@ -1,88 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_mul_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteMulParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) { | |||||
| MS_LOG(DEBUG) << "parse TfliteMulParser"; | |||||
| std::unique_ptr<schema::MulT> attr(new schema::MulT()); | |||||
| const auto &tfliteAttr = tfliteOp->builtin_options.AsMulOptions(); | |||||
| if (nullptr == tfliteAttr) { | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); | |||||
| auto x_index = tfliteOp->inputs[0]; | |||||
| const auto &x_tensor = tfliteTensors[x_index]; | |||||
| if (x_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "the first input is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto &x_data = tfliteModelBuffer.at(x_tensor->buffer); | |||||
| if (x_data == nullptr) { | |||||
| MS_LOG(ERROR) << "the data of the first input is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (x_data->data.size() > 0) { | |||||
| std::vector<tflite::TensorT *> x_tensors{x_tensor.get()}; | |||||
| if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { | |||||
| MS_LOG(ERROR) << "parse the first tensor failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| auto y_index = tfliteOp->inputs[1]; | |||||
| const auto &y_tensor = tfliteTensors[y_index]; | |||||
| if (y_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "the second input is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto &y_data = tfliteModelBuffer.at(y_tensor->buffer); | |||||
| if (y_data == nullptr) { | |||||
| MS_LOG(ERROR) << "the data of the second input is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (y_data->data.size() > 0) { | |||||
| std::vector<tflite::TensorT *> y_tensors{y_tensor.get()}; | |||||
| if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { | |||||
| MS_LOG(ERROR) << "parse the second tensor failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Mul; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_TfliteMulParser("Mul", new TfliteMulParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,42 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PREDICT_TFLITE_MUL_PARSER_H | |||||
| #define PREDICT_TFLITE_MUL_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteMulParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteMulParser() : TfliteNodeParser("Mul") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // PREDICT_TFLITE_MUL_PARSER_H | |||||
| @@ -16,80 +16,36 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <unordered_map> | |||||
| #include "securec/include/securec.h" | #include "securec/include/securec.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_util.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteNodeParser::CopyTfliteTensorData(const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | STATUS TfliteNodeParser::CopyTfliteTensorData(const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | ||||
| const tflite::TensorT *tflite_tensor, schema::TensorT *tensor) { | |||||
| const tflite::TensorT *tflite_tensor, | |||||
| schema::TensorT *tensor) { | |||||
| auto count = 1; | auto count = 1; | ||||
| std::for_each(tflite_tensor->shape.begin(), tflite_tensor->shape.end(), [&](int32_t sha) { count *= sha; }); | std::for_each(tflite_tensor->shape.begin(), tflite_tensor->shape.end(), [&](int32_t sha) { count *= sha; }); | ||||
| auto data_size = count * GetDataTypeSize(TypeId(tensor->dataType)); | auto data_size = count * GetDataTypeSize(TypeId(tensor->dataType)); | ||||
| auto buffer_idx = tflite_tensor->buffer; | auto buffer_idx = tflite_tensor->buffer; | ||||
| if (!tfliteModelBuffer[buffer_idx]->data.empty()) { | if (!tfliteModelBuffer[buffer_idx]->data.empty()) { | ||||
| tensor->data.resize(data_size); | tensor->data.resize(data_size); | ||||
| auto ret = memcpy_s(tensor->data.data(), data_size, tfliteModelBuffer[buffer_idx]->data.data(), data_size); | |||||
| if (ret) { | |||||
| MS_LOG(ERROR) << "memcpy tensor data failed, error code: %d" << ret; | |||||
| return ret; | |||||
| if (memcpy_s(tensor->data.data(), data_size, tfliteModelBuffer[buffer_idx]->data.data(), data_size)) { | |||||
| MS_LOG(ERROR) << "memcpy tensor data failed"; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "src tensor data is empty."; | |||||
| MS_LOG(ERROR) << "src tensor data is empty"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS TfliteNodeParser::ParseWeight(const std::vector<tflite::TensorT *> &weight_tenosrs, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| mindspore::lite::TensorCache *tensor_cache, schema::Format format) { | |||||
| for (const auto &weight_tensor : weight_tenosrs) { | |||||
| auto idx = tensor_cache->FindTensor(weight_tensor->name); | |||||
| if (idx < 0) { | |||||
| std::unique_ptr<schema::TensorT> tensor(new schema::TensorT); | |||||
| tensor->dataType = GetTfliteDataType(weight_tensor->type); | |||||
| tensor->dims = weight_tensor->shape; | |||||
| tensor->nodeType = schema::NodeType_ValueNode; | |||||
| // memcpy tensor data | |||||
| // buffer is 0 (which refers to an always existent empty buffer) | |||||
| if (weight_tensor->buffer > 0) { | |||||
| CopyTfliteTensorData(tfliteModelBuffer, weight_tensor, tensor.get()); | |||||
| } | |||||
| MS_LOG(DEBUG) << "add weight tensor name: %s", weight_tensor->name.c_str(); | |||||
| tensor_cache->AddTensor(weight_tensor->name, tensor.release(), TF_CONST); | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS TfliteNodeParser::ParseBias(const std::vector<tflite::TensorT *> &bias_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| TensorCache *tensor_cache) { | |||||
| for (const auto &bias_tensor : bias_tensors) { | |||||
| auto idx = tensor_cache->FindTensor(bias_tensor->name); | |||||
| if (idx < 0) { | |||||
| std::unique_ptr<schema::TensorT> tensor(new schema::TensorT); | |||||
| tensor->dataType = GetTfliteDataType(bias_tensor->type); | |||||
| tensor->dims = bias_tensor->shape; | |||||
| tensor->nodeType = schema::NodeType_ValueNode; | |||||
| // memcpy tensor data | |||||
| // buffer is 0 (which refers to an always existent empty buffer) | |||||
| if (bias_tensor->buffer > 0) { | |||||
| CopyTfliteTensorData(tfliteModelBuffer, bias_tensor, tensor.get()); | |||||
| } | |||||
| // MS_LOGD("add weight tensor name: %s", bias_tensor->name.c_str()); | |||||
| tensor_cache->AddTensor(bias_tensor->name, tensor.release(), TF_CONST); | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS TfliteNodeParser::ParseTensor(const std::vector<tflite::TensorT *> &ts, | STATUS TfliteNodeParser::ParseTensor(const std::vector<tflite::TensorT *> &ts, | ||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | ||||
| mindspore::lite::TensorCache *tensor_cache, int node_type, | |||||
| bool ifCopy) { | |||||
| mindspore::lite::TensorCache *tensor_cache, | |||||
| int node_type) { | |||||
| for (const auto &t : ts) { | for (const auto &t : ts) { | ||||
| auto idx = tensor_cache->FindTensor(t->name); | auto idx = tensor_cache->FindTensor(t->name); | ||||
| if (idx < 0) { | if (idx < 0) { | ||||
| @@ -97,29 +53,15 @@ STATUS TfliteNodeParser::ParseTensor(const std::vector<tflite::TensorT *> &ts, | |||||
| tensor->dataType = GetTfliteDataType(t->type); | tensor->dataType = GetTfliteDataType(t->type); | ||||
| tensor->dims = t->shape; | tensor->dims = t->shape; | ||||
| // memcpy tensor data, buffer is 0 (which refers to an always existent empty buffer) | |||||
| if (ifCopy && t->buffer > 0) { | |||||
| if (t->buffer > 0) { | |||||
| CopyTfliteTensorData(tfliteModelBuffer, t, tensor.get()); | CopyTfliteTensorData(tfliteModelBuffer, t, tensor.get()); | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "add weight tensor name: %s", t->name.c_str(); | |||||
| MS_LOG(DEBUG) << "add tensor name: " << t->name.c_str(); | |||||
| tensor_cache->AddTensor(t->name, tensor.release(), node_type); | tensor_cache->AddTensor(t->name, tensor.release(), node_type); | ||||
| } | } | ||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| TypeId TfliteNodeParser::GetTfliteDataType(const tflite::TensorType &tflite_data_type) { | |||||
| static std::unordered_map<int, TypeId> type_map = { | |||||
| {tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32}, {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, | |||||
| {tflite::TensorType_INT32, TypeId::kNumberTypeInt32}, {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8}, | |||||
| {tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, {tflite::TensorType_INT8, TypeId::kNumberTypeInt8}, | |||||
| }; | |||||
| auto iter = type_map.find(tflite_data_type); | |||||
| if (iter == type_map.end()) { | |||||
| return kTypeUnknown; | |||||
| } | |||||
| return iter->second; | |||||
| } | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef PREDICT_TFLITE_NODE_PARSER_H | |||||
| #define PREDICT_TFLITE_NODE_PARSER_H | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_NODE_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_NODE_PARSER_H | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| @@ -34,30 +34,24 @@ class TfliteNodeParser { | |||||
| public: | public: | ||||
| explicit TfliteNodeParser(const std::string &nodeName) : name(nodeName) {} | explicit TfliteNodeParser(const std::string &nodeName) : name(nodeName) {} | ||||
| virtual ~TfliteNodeParser() {} | |||||
| virtual ~TfliteNodeParser() = default; | |||||
| virtual STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | virtual STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | ||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | ||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantizedModel) = 0; | |||||
| STATUS ParseWeight(const std::vector<tflite::TensorT *> &weight_tenosr, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, TensorCache *tensor_cache, | |||||
| schema::Format format); | |||||
| STATUS ParseBias(const std::vector<tflite::TensorT *> &weight_tenosr, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, TensorCache *tensor_cache); | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) = 0; | |||||
| STATUS ParseTensor(const std::vector<tflite::TensorT *> &ts, | STATUS ParseTensor(const std::vector<tflite::TensorT *> &ts, | ||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | ||||
| mindspore::lite::TensorCache *tensor_cache, int node_type, | |||||
| bool ifCopy); | |||||
| mindspore::lite::TensorCache *tensor_cache, | |||||
| int node_type); | |||||
| STATUS CopyTfliteTensorData(const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | STATUS CopyTfliteTensorData(const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | ||||
| const tflite::TensorT *tflite_tensor, schema::TensorT *tensor); | |||||
| TypeId GetTfliteDataType(const tflite::TensorType &tflite_data_type); | |||||
| const tflite::TensorT *tflite_tensor, | |||||
| schema::TensorT *tensor); | |||||
| template <typename T> | template <typename T> | ||||
| STATUS GetTfliteData(const int32_t tensor_index, const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | STATUS GetTfliteData(const int32_t tensor_index, const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | ||||
| @@ -67,6 +61,10 @@ class TfliteNodeParser { | |||||
| std::for_each(tfliteTensors[tensor_index]->shape.begin(), tfliteTensors[tensor_index]->shape.end(), | std::for_each(tfliteTensors[tensor_index]->shape.begin(), tfliteTensors[tensor_index]->shape.end(), | ||||
| [&](int32_t sha) { count *= sha; }); | [&](int32_t sha) { count *= sha; }); | ||||
| auto &buf_data = tfliteModelBuffer[tfliteTensors[tensor_index]->buffer]; | auto &buf_data = tfliteModelBuffer[tfliteTensors[tensor_index]->buffer]; | ||||
| if (buf_data == nullptr) { | |||||
| MS_LOG(ERROR) << "buf_data is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto data_ptr = buf_data->data.data(); | auto data_ptr = buf_data->data.data(); | ||||
| switch (tfliteTensors[tensor_index]->type) { | switch (tfliteTensors[tensor_index]->type) { | ||||
| case tflite::TensorType_UINT8: { | case tflite::TensorType_UINT8: { | ||||
| @@ -117,18 +115,18 @@ class TfliteNodeParser { | |||||
| } | } | ||||
| break; | break; | ||||
| } | } | ||||
| default: { | |||||
| MS_LOG(ERROR) << "wrong tensor type"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| protected: | |||||
| bool isQuantizedModel(); | |||||
| protected: | protected: | ||||
| const std::string &name; | const std::string &name; | ||||
| bool quantizedModel; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // PREDICT_TFLITE_NODE_PARSER_H | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_NODE_PARSER_H | |||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_NODE_PARSER_REGISTRY_H | |||||
| #define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_NODE_PARSER_REGISTRY_H | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_NODE_PARSER_REGISTRY_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_NODE_PARSER_REGISTRY_H | |||||
| #include <string> | #include <string> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| @@ -46,5 +46,5 @@ class TfliteNodeRegister { | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_NODE_PARSER_REGISTRY_H | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_NODE_PARSER_REGISTRY_H | |||||
| @@ -1,43 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * distributed under the License is distributed on an AS | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_not_equal_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteNotEqualParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, | |||||
| schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantized_model) { | |||||
| MS_LOG(DEBUG) << "parse TfliteNotEqualParser"; | |||||
| std::unique_ptr<schema::NotEqualT> attr(new schema::NotEqualT()); | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_NotEqual; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_tfliteNotEqualParser("NotEqual", new TfliteNotEqualParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_TFLITE_NOT_EQUAL_PARSER_H | |||||
| #define LITE_TFLITE_NOT_EQUAL_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteNotEqualParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteNotEqualParser() : TfliteNodeParser("NotEqual") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantized_model) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_TFLITE_NOT_EQUAL_PARSER_H | |||||
| @@ -25,6 +25,16 @@ STATUS TfliteOneHotParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | ||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(INFO) << "parse TfliteOneHotParser"; | MS_LOG(INFO) << "parse TfliteOneHotParser"; | ||||
| std::unique_ptr<schema::OneHotT> attr(new schema::OneHotT()); | std::unique_ptr<schema::OneHotT> attr(new schema::OneHotT()); | ||||
| @@ -46,11 +56,8 @@ STATUS TfliteOneHotParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| } | } | ||||
| attr->axis = axis; | attr->axis = axis; | ||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_OneHot; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_OneHot; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -1,47 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * distributed under the License is distributed on an AS | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_p_relu_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantized_model) { | |||||
| MS_LOG(DEBUG) << "paser TflitePreluParser"; | |||||
| std::unique_ptr<schema::PreluT> attr(new schema::PreluT()); | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->slope)) { | |||||
| MS_LOG(ERROR) << "get pRelu -> slope failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Prelu; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_tflitePreluParser("Prelu", new TflitePreluParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,40 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_TFLITE_P_RELU_PARSER_H | |||||
| #define LITE_TFLITE_P_RELU_PARSER_H | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TflitePreluParser : public TfliteNodeParser { | |||||
| public: | |||||
| TflitePreluParser() : TfliteNodeParser("Prelu") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantized_model) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_TFLITE_P_RELU_PARSER_H | |||||
| @@ -25,6 +25,16 @@ STATUS TflitePadParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | ||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(DEBUG) << "parse TflitePadParser"; | MS_LOG(DEBUG) << "parse TflitePadParser"; | ||||
| std::unique_ptr<schema::PadT> attr(new schema::PadT()); | std::unique_ptr<schema::PadT> attr(new schema::PadT()); | ||||
| const auto &tflite_attr = tfliteOp->builtin_options.AsPadOptions(); | const auto &tflite_attr = tfliteOp->builtin_options.AsPadOptions(); | ||||
| @@ -40,11 +50,8 @@ STATUS TflitePadParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Pad; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_Pad; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -0,0 +1,80 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "tools/converter/parser/tflite/tflite_pooling_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TflitePoolingParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::PoolingT> attr(new schema::PoolingT()); | |||||
| std::vector<std::string> node_name_str; | |||||
| Split(op->name.data(), &node_name_str, "-"); | |||||
| const char *node_name = node_name_str.data()->c_str(); | |||||
| if (std::strcmp(node_name, "MeanPooling") == 0) { | |||||
| MS_LOG(DEBUG) << "parser TfliteMeanPoolingParser"; | |||||
| attr->poolingMode = schema::PoolMode_MEAN_POOLING; | |||||
| } else if (std::strcmp(node_name, "MaxPooling") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteMaxPoolingParser"; | |||||
| attr->poolingMode = schema::PoolMode_MAX_POOLING; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "wrong pooling type"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsPool2DOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| attr->windowW = tflite_attr->filter_width; | |||||
| attr->windowH = tflite_attr->filter_height; | |||||
| attr->strideW = tflite_attr->stride_w; | |||||
| attr->strideH = tflite_attr->stride_h; | |||||
| attr->padMode = GetPadMode(tflite_attr->padding); | |||||
| attr->format = schema::Format_NHWC; | |||||
| // attr->global | |||||
| // calculate pad params | |||||
| op->primitive->value.type = schema::PrimitiveType_Pooling; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_tfliteMeanPoolingParser("MeanPooling", new TfliteMeanPoolingParser()); | |||||
| TfliteNodeRegister g_tfliteMaxPoolingParser("MaxPooling", new TfliteMaxPoolingParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -24,9 +24,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| class TfliteMeanPoolingParser : public TfliteNodeParser { | |||||
| class TflitePoolingParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteMeanPoolingParser() : TfliteNodeParser("MeanPooling") {} | |||||
| TflitePoolingParser() : TfliteNodeParser("node_name") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | ||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | ||||
| @@ -35,6 +35,16 @@ class TfliteMeanPoolingParser : public TfliteNodeParser { | |||||
| TensorCache *tensor_cache, | TensorCache *tensor_cache, | ||||
| bool quantizedModel) override; | bool quantizedModel) override; | ||||
| }; | }; | ||||
| class TfliteMeanPoolingParser : public TflitePoolingParser { | |||||
| public: | |||||
| TfliteMeanPoolingParser() : TflitePoolingParser() {} | |||||
| }; | |||||
| class TfliteMaxPoolingParser : public TflitePoolingParser { | |||||
| public: | |||||
| TfliteMaxPoolingParser() : TflitePoolingParser() {} | |||||
| }; | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -1,47 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_pow_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TflitePowParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) { | |||||
| MS_LOG(DEBUG) << "parse TflitePowParser"; | |||||
| std::unique_ptr<schema::PowerT> attr(new schema::PowerT()); | |||||
| attr->power = 0.0f; | |||||
| attr->scale = 1.0f; | |||||
| attr->shift = 0.0f; | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Power; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_TflitePowParser("Pow", new TflitePowParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,42 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PREDICT_TFLITE_POW_PARSER_H | |||||
| #define PREDICT_TFLITE_POW_PARSER_H | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TflitePowParser : public TfliteNodeParser { | |||||
| public: | |||||
| TflitePowParser() : TfliteNodeParser("Pow") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // PREDICT_TFLITE_POW_PARSER_H | |||||
| @@ -27,16 +27,23 @@ STATUS TfliteRangeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite | |||||
| schema::CNodeT *op, | schema::CNodeT *op, | ||||
| TensorCache *tensor_cache, | TensorCache *tensor_cache, | ||||
| bool quantizedModel) { | bool quantizedModel) { | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(DEBUG) << "parse TfliteRangeParser"; | MS_LOG(DEBUG) << "parse TfliteRangeParser"; | ||||
| std::unique_ptr<schema::RangeT> attr(new schema::RangeT()); | std::unique_ptr<schema::RangeT> attr(new schema::RangeT()); | ||||
| attr->dType = 0; | attr->dType = 0; | ||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Range; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_Range; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -27,14 +27,21 @@ STATUS TfliteRankParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO | |||||
| schema::CNodeT *op, | schema::CNodeT *op, | ||||
| TensorCache *tensor_cache, | TensorCache *tensor_cache, | ||||
| bool quantizedModel) { | bool quantizedModel) { | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(DEBUG) << "parse TfliteRankParser"; | MS_LOG(DEBUG) << "parse TfliteRankParser"; | ||||
| std::unique_ptr<schema::RankT> attr(new schema::RankT()); | std::unique_ptr<schema::RankT> attr(new schema::RankT()); | ||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Rank; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_Rank; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -1,43 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * distributed under the License is distributed on an AS | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_real_div_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteRealDivParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, | |||||
| schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantized_model) { | |||||
| MS_LOG(DEBUG) << "parse TfliteRealDivParser"; | |||||
| std::unique_ptr<schema::RealDivT> attr(new schema::RealDivT()); | |||||
| if (op != nullptr) { | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| op->primitive->value.type = schema::PrimitiveType_RealDiv; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_tfliteRealDivParser("RealDiv", new TfliteRealDivParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||