append ut refactor tflite parsers modify tflite parser, ut and model supplement caffe flatten parser fix the weight tensor format of deconv bug fix bug when idx=-1 fix the weight tensor format of depthConv bug.tags/v0.7.0-beta
| @@ -12,19 +12,19 @@ cp -r ${CUR_DIR}/ut/tools/converter/parser/tflite/test_data/* ./ | |||||
| TEST_DATA_DIR=${CUR_DIR}/../../../tests/ut/data/dataset/ | TEST_DATA_DIR=${CUR_DIR}/../../../tests/ut/data/dataset/ | ||||
| cp -fr $TEST_DATA_DIR/testPK ./data | cp -fr $TEST_DATA_DIR/testPK ./data | ||||
| #./lite-test --gtest_filter="*MindDataTestTensorDE*" | |||||
| #./lite-test --gtest_filter="*MindDataTestEager*" | |||||
| # | |||||
| #./lite-test --gtest_filter="TestTfliteParser*" | |||||
| # | |||||
| #./lite-test --gtest_filter="*TestHebing*" | |||||
| # | |||||
| #./lite-test --gtest_filter=TestFcFp32* | |||||
| #./lite-test --gtest_filter=TestConv1x1Fp32* | |||||
| #./lite-test --gtest_filter=TestStrassenFp32* | |||||
| #./lite-test --gtest_filter=TestDeConvolutionFp32* | |||||
| # | |||||
| #./lite-test --gtest_filter=TestPadInt8.* | |||||
| #./lite-test --gtest_filter=TestDeconvInt8.* | |||||
| ./lite-test --gtest_filter="*MindDataTestTensorDE*" | |||||
| ./lite-test --gtest_filter="*MindDataTestEager*" | |||||
| ./lite-test --gtest_filter="TestTfliteParser*" | |||||
| ./lite-test --gtest_filter="*TestHebing*" | |||||
| ./lite-test --gtest_filter=TestFcFp32* | |||||
| ./lite-test --gtest_filter=TestConv1x1Fp32* | |||||
| ./lite-test --gtest_filter=TestStrassenFp32* | |||||
| ./lite-test --gtest_filter=TestDeConvolutionFp32* | |||||
| ./lite-test --gtest_filter=TestPadInt8.* | |||||
| ./lite-test --gtest_filter=TestDeconvInt8.* | |||||
| ./lite-test --gtest_filter="TestTfliteParser*" | ./lite-test --gtest_filter="TestTfliteParser*" | ||||
| @@ -31,6 +31,12 @@ TEST_F(TestTfliteParserRelu, OpType) { | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type"; | ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type"; | ||||
| } | } | ||||
| TEST_F(TestTfliteParserRelu, AttrValue) { | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsActivation(); | |||||
| ASSERT_EQ(val->type, schema::ActivationType_RELU); | |||||
| } | |||||
| class TestTfliteParserRelu6 : public TestTfliteParser { | class TestTfliteParserRelu6 : public TestTfliteParser { | ||||
| public: | public: | ||||
| TestTfliteParserRelu6() = default; | TestTfliteParserRelu6() = default; | ||||
| @@ -43,6 +49,12 @@ TEST_F(TestTfliteParserRelu6, OpType) { | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type"; | ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type"; | ||||
| } | } | ||||
| TEST_F(TestTfliteParserRelu6, AttrValue) { | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsActivation(); | |||||
| ASSERT_EQ(val->type, schema::ActivationType_RELU6); | |||||
| } | |||||
| class TestTfliteParserTanh : public TestTfliteParser { | class TestTfliteParserTanh : public TestTfliteParser { | ||||
| public: | public: | ||||
| TestTfliteParserTanh() = default; | TestTfliteParserTanh() = default; | ||||
| @@ -55,7 +67,45 @@ TEST_F(TestTfliteParserTanh, OpType) { | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type"; | ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type"; | ||||
| } | } | ||||
| // logistic | |||||
| TEST_F(TestTfliteParserTanh, AttrValue) { | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsActivation(); | |||||
| ASSERT_EQ(val->type, schema::ActivationType_TANH); | |||||
| } | |||||
| class TestTfliteParserLogistic : public TestTfliteParser { | |||||
| public: | |||||
| TestTfliteParserLogistic() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./logistic.tflite", ""); } | |||||
| }; | |||||
| TEST_F(TestTfliteParserLogistic, OpType) { | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type"; | |||||
| } | |||||
| TEST_F(TestTfliteParserLogistic, AttrValue) { | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsActivation(); | |||||
| ASSERT_EQ(val->type, schema::ActivationType_SIGMOID); | |||||
| } | |||||
| class TestTfliteParserHardSwish : public TestTfliteParser { | |||||
| public: | |||||
| TestTfliteParserHardSwish() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./hardswish.tflite", ""); } | |||||
| }; | |||||
| TEST_F(TestTfliteParserHardSwish, OpType) { | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type"; | |||||
| } | |||||
| TEST_F(TestTfliteParserHardSwish, AttrValue) { | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsActivation(); | |||||
| ASSERT_EQ(val->type, schema::ActivationType_SIGMOID); | |||||
| } | |||||
| class TestTfliteParserPrelu : public TestTfliteParser { | class TestTfliteParserPrelu : public TestTfliteParser { | ||||
| public: | public: | ||||
| @@ -73,12 +123,11 @@ TEST_F(TestTfliteParserPrelu, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserPrelu, AttrValue) { | TEST_F(TestTfliteParserPrelu, AttrValue) { | ||||
| std::vector<float> slope(20, 0); | |||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPrelu(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPrelu(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsPrelu()->slope, slope); | |||||
| auto val = meta_graph->nodes.front()->primitive->value; | |||||
| std::vector<float> slope(20, 0); | |||||
| ASSERT_EQ(val.AsPrelu()->slope, slope); | |||||
| ASSERT_EQ(val.type, schema::PrimitiveType_Prelu); | |||||
| } | } | ||||
| class TestTfliteParserLeakyRelu : public TestTfliteParser { | class TestTfliteParserLeakyRelu : public TestTfliteParser { | ||||
| @@ -94,12 +143,10 @@ TEST_F(TestTfliteParserLeakyRelu, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserLeakyRelu, AttrValue) { | TEST_F(TestTfliteParserLeakyRelu, AttrValue) { | ||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsLeakyReLU(); | |||||
| ASSERT_NE(val, nullptr); | |||||
| ASSERT_EQ(val->negativeSlope, 0.20000000298023224); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsLeakyReLU(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value; | |||||
| ASSERT_EQ(val.AsLeakyReLU()->negativeSlope, 0.20000000298023224); | |||||
| ASSERT_EQ(val.type, schema::PrimitiveType_LeakyReLU); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -35,10 +35,8 @@ TEST_F(TestTfliteParserAddN, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserAddN, AttrValue) { | TEST_F(TestTfliteParserAddN, AttrValue) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsAddN(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsAddN(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsAddN()->N, 4); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsAddN(); | |||||
| ASSERT_EQ(val->N, 4); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,44 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" | |||||
| #include <iostream> | |||||
| #include "common/common_test.h" | |||||
| namespace mindspore { | |||||
| class TestTfliteParserArgmax : public TestTfliteParser { | |||||
| public: | |||||
| TestTfliteParserArgmax() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./argmax.tflite", ""); } | |||||
| }; | |||||
| TEST_F(TestTfliteParserArgmax, OpType) { | |||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_ArgMax) << "wrong Op Type"; | |||||
| } | |||||
| TEST_F(TestTfliteParserArgmax, AttrValue) { | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsArgMax(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsArgMax(); | |||||
| ASSERT_EQ(val->axis, 1); | |||||
| ASSERT_EQ(val->topK, 1); | |||||
| ASSERT_EQ(val->axisType, 1); | |||||
| ASSERT_EQ(val->keepDims, false); | |||||
| ASSERT_EQ(val->outMaxValue, false); | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -25,15 +25,14 @@ class TestTfliteParserArgmin : public TestTfliteParser { | |||||
| }; | }; | ||||
| TEST_F(TestTfliteParserArgmin, OpType) { | TEST_F(TestTfliteParserArgmin, OpType) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | ASSERT_GT(meta_graph->nodes.size(), 0); | ||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_ArgMin) << "wrong Op Type"; | ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_ArgMin) << "wrong Op Type"; | ||||
| } | } | ||||
| TEST_F(TestTfliteParserArgmin, AttrValue) { | TEST_F(TestTfliteParserArgmin, AttrValue) { | ||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsArgMin(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsArgMin(); | auto val = meta_graph->nodes.front()->primitive->value.AsArgMin(); | ||||
| ASSERT_EQ(val->axis, 1); | ASSERT_EQ(val->axis, 1); | ||||
| ASSERT_EQ(val->topK, 1); | ASSERT_EQ(val->topK, 1); | ||||
| @@ -19,234 +19,57 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| // doubleInputOp | // doubleInputOp | ||||
| class TestTfliteParserAdd1 : public TestTfliteParser { | |||||
| class TestTfliteParserAdd : public TestTfliteParser { | |||||
| public: | public: | ||||
| TestTfliteParserAdd1() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./add1.tflite", ""); } | |||||
| TestTfliteParserAdd() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./add.tflite", ""); } | |||||
| }; | }; | ||||
| TEST_F(TestTfliteParserAdd1, OpType) { | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type"; | |||||
| } | |||||
| TEST_F(TestTfliteParserAdd1, Tensor) { | |||||
| ASSERT_GT(meta_graph->allTensors.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); | |||||
| ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); | |||||
| } | |||||
| class TestTfliteParserAdd2 : public TestTfliteParser { | |||||
| public: | |||||
| TestTfliteParserAdd2() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./add2.tflite", ""); } | |||||
| }; | |||||
| TEST_F(TestTfliteParserAdd2, OpType) { | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type"; | |||||
| } | |||||
| TEST_F(TestTfliteParserAdd2, Tensor) { | |||||
| ASSERT_GT(meta_graph->allTensors.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); | |||||
| ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); | |||||
| } | |||||
| class TestTfliteParserAdd3 : public TestTfliteParser { | |||||
| public: | |||||
| TestTfliteParserAdd3() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./add3.tflite", ""); } | |||||
| }; | |||||
| TEST_F(TestTfliteParserAdd3, OpType) { | |||||
| TEST_F(TestTfliteParserAdd, OpType) { | |||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | ASSERT_GT(meta_graph->nodes.size(), 0); | ||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type"; | ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type"; | ||||
| } | } | ||||
| TEST_F(TestTfliteParserAdd3, Tensor) { | |||||
| ASSERT_GT(meta_graph->allTensors.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); | |||||
| } | |||||
| class TestTfliteParserSub1 : public TestTfliteParser { | |||||
| public: | |||||
| TestTfliteParserSub1() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./sub1.tflite", ""); } | |||||
| }; | |||||
| TEST_F(TestTfliteParserSub1, OpType) { | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type"; | |||||
| } | |||||
| TEST_F(TestTfliteParserSub1, Tensor) { | |||||
| ASSERT_GT(meta_graph->allTensors.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); | |||||
| ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); | |||||
| } | |||||
| class TestTfliteParserSub2 : public TestTfliteParser { | |||||
| public: | |||||
| TestTfliteParserSub2() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./sub2.tflite", ""); } | |||||
| }; | |||||
| TEST_F(TestTfliteParserSub2, OpType) { | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type"; | |||||
| } | |||||
| TEST_F(TestTfliteParserSub2, Tensor) { | |||||
| ASSERT_GT(meta_graph->allTensors.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); | |||||
| ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); | |||||
| } | |||||
| class TestTfliteParserSub3 : public TestTfliteParser { | |||||
| class TestTfliteParserSub : public TestTfliteParser { | |||||
| public: | public: | ||||
| TestTfliteParserSub3() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./sub3.tflite", ""); } | |||||
| TestTfliteParserSub() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./sub.tflite", ""); } | |||||
| }; | }; | ||||
| TEST_F(TestTfliteParserSub3, OpType) { | |||||
| TEST_F(TestTfliteParserSub, OpType) { | |||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | ASSERT_GT(meta_graph->nodes.size(), 0); | ||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type"; | ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type"; | ||||
| } | } | ||||
| TEST_F(TestTfliteParserSub3, Tensor) { | |||||
| ASSERT_GT(meta_graph->allTensors.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); | |||||
| } | |||||
| class TestTfliteParserMul1 : public TestTfliteParser { | |||||
| public: | |||||
| TestTfliteParserMul1() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./mul1.tflite", ""); } | |||||
| }; | |||||
| TEST_F(TestTfliteParserMul1, OpType) { | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type"; | |||||
| } | |||||
| TEST_F(TestTfliteParserMul1, Tensor) { | |||||
| ASSERT_GT(meta_graph->allTensors.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); | |||||
| ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); | |||||
| } | |||||
| class TestTfliteParserMul2 : public TestTfliteParser { | |||||
| class TestTfliteParserMul : public TestTfliteParser { | |||||
| public: | public: | ||||
| TestTfliteParserMul2() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./mul2.tflite", ""); } | |||||
| TestTfliteParserMul() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./mul.tflite", ""); } | |||||
| }; | }; | ||||
| TEST_F(TestTfliteParserMul2, OpType) { | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type"; | |||||
| } | |||||
| TEST_F(TestTfliteParserMul2, Tensor) { | |||||
| ASSERT_GT(meta_graph->allTensors.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); | |||||
| ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); | |||||
| } | |||||
| class TestTfliteParserMul3 : public TestTfliteParser { | |||||
| public: | |||||
| TestTfliteParserMul3() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./mul3.tflite", ""); } | |||||
| }; | |||||
| TEST_F(TestTfliteParserMul3, OpType) { | |||||
| TEST_F(TestTfliteParserMul, OpType) { | |||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | ASSERT_GT(meta_graph->nodes.size(), 0); | ||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type"; | ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type"; | ||||
| } | } | ||||
| TEST_F(TestTfliteParserMul3, Tensor) { | |||||
| ASSERT_GT(meta_graph->allTensors.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); | |||||
| } | |||||
| class TestTfliteParserDiv1 : public TestTfliteParser { | |||||
| public: | |||||
| TestTfliteParserDiv1() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./div1.tflite", ""); } | |||||
| }; | |||||
| TEST_F(TestTfliteParserDiv1, OpType) { | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type"; | |||||
| } | |||||
| TEST_F(TestTfliteParserDiv1, Tensor) { | |||||
| ASSERT_GT(meta_graph->allTensors.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); | |||||
| ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); | |||||
| } | |||||
| class TestTfliteParserDiv2 : public TestTfliteParser { | |||||
| class TestTfliteParserDiv : public TestTfliteParser { | |||||
| public: | public: | ||||
| TestTfliteParserDiv2() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./div2.tflite", ""); } | |||||
| TestTfliteParserDiv() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./div.tflite", ""); } | |||||
| }; | }; | ||||
| TEST_F(TestTfliteParserDiv2, OpType) { | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type"; | |||||
| } | |||||
| TEST_F(TestTfliteParserDiv2, Tensor) { | |||||
| ASSERT_GT(meta_graph->allTensors.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); | |||||
| ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); | |||||
| } | |||||
| class TestTfliteParserDiv3 : public TestTfliteParser { | |||||
| public: | |||||
| TestTfliteParserDiv3() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./div3.tflite", ""); } | |||||
| }; | |||||
| TEST_F(TestTfliteParserDiv3, OpType) { | |||||
| TEST_F(TestTfliteParserDiv, OpType) { | |||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | ASSERT_GT(meta_graph->nodes.size(), 0); | ||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type"; | ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type"; | ||||
| } | } | ||||
| TEST_F(TestTfliteParserDiv3, Tensor) { | |||||
| ASSERT_GT(meta_graph->allTensors.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0); | |||||
| ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); | |||||
| } | |||||
| class TestTfliteParserFloorDiv : public TestTfliteParser { | class TestTfliteParserFloorDiv : public TestTfliteParser { | ||||
| public: | public: | ||||
| TestTfliteParserFloorDiv() = default; | TestTfliteParserFloorDiv() = default; | ||||
| @@ -254,6 +77,7 @@ class TestTfliteParserFloorDiv : public TestTfliteParser { | |||||
| }; | }; | ||||
| TEST_F(TestTfliteParserFloorDiv, OpType) { | TEST_F(TestTfliteParserFloorDiv, OpType) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | ASSERT_GT(meta_graph->nodes.size(), 0); | ||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_FloorDiv) << "wrong Op Type"; | ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_FloorDiv) << "wrong Op Type"; | ||||
| @@ -266,12 +90,26 @@ class TestTfliteParserFloorMod : public TestTfliteParser { | |||||
| }; | }; | ||||
| TEST_F(TestTfliteParserFloorMod, OpType) { | TEST_F(TestTfliteParserFloorMod, OpType) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | ASSERT_GT(meta_graph->nodes.size(), 0); | ||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_FloorMod) << "wrong Op Type"; | ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_FloorMod) << "wrong Op Type"; | ||||
| } | } | ||||
| // realDiv | |||||
| class TestTfliteParserRealDiv : public TestTfliteParser { | |||||
| public: | |||||
| TestTfliteParserRealDiv() = default; | |||||
| void SetUp() override { | |||||
| meta_graph = LoadAndConvert("./realdiv.tflite"); | |||||
| } | |||||
| }; | |||||
| TEST_F(TestTfliteParserRealDiv, OpType) { | |||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type"; | |||||
| } | |||||
| class TestTfliteParserSquaredDifference : public TestTfliteParser { | class TestTfliteParserSquaredDifference : public TestTfliteParser { | ||||
| public: | public: | ||||
| @@ -296,17 +134,15 @@ class TestTfliteParserPow : public TestTfliteParser { | |||||
| }; | }; | ||||
| TEST_F(TestTfliteParserPow, OpType) { | TEST_F(TestTfliteParserPow, OpType) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | ASSERT_GT(meta_graph->nodes.size(), 0); | ||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Power) << "wrong Op Type"; | ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Power) << "wrong Op Type"; | ||||
| } | } | ||||
| TEST_F(TestTfliteParserPow, AttrValue) { | TEST_F(TestTfliteParserPow, AttrValue) { | ||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPower(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsPower(); | auto val = meta_graph->nodes.front()->primitive->value.AsPower(); | ||||
| ASSERT_EQ(val->scale, 1.0); | ASSERT_EQ(val->scale, 1.0); | ||||
| ASSERT_EQ(val->shift, 0.0); | ASSERT_EQ(val->shift, 0.0); | ||||
| ASSERT_EQ(val->power, 0.0); | ASSERT_EQ(val->power, 0.0); | ||||
| @@ -477,6 +313,7 @@ class TestTfliteParserFloor : public TestTfliteParser { | |||||
| }; | }; | ||||
| TEST_F(TestTfliteParserFloor, OpType) { | TEST_F(TestTfliteParserFloor, OpType) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | ASSERT_GT(meta_graph->nodes.size(), 0); | ||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Floor) << "wrong Op Type"; | ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Floor) << "wrong Op Type"; | ||||
| @@ -32,14 +32,12 @@ TEST_F(TestTfliteParserBatchToSpaceNd, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserBatchToSpaceNd, AttrValue) { | TEST_F(TestTfliteParserBatchToSpaceNd, AttrValue) { | ||||
| const std::vector<int> blockShape{2, 2}; | |||||
| const std::vector<int> crops{0, 0, 2, 0}; | |||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsBatchToSpace(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsBatchToSpace(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsBatchToSpace()->blockShape, blockShape); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsBatchToSpace()->crops, crops); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsBatchToSpace(); | |||||
| const std::vector<int> blockShape = {2, 2}; | |||||
| ASSERT_EQ(val->blockShape, blockShape); | |||||
| const std::vector<int> crops = {0, 0, 2, 0}; | |||||
| ASSERT_EQ(val->crops, crops); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -35,12 +35,9 @@ TEST_F(TestTfliteParserCast, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserCast, AttrValue) { | TEST_F(TestTfliteParserCast, AttrValue) { | ||||
| // float32 --> int32 | |||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsCast(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsCast(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsCast()->srcT, 43); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsCast()->dstT, 34); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsCast(); | |||||
| ASSERT_EQ(val->srcT, 43); | |||||
| ASSERT_EQ(val->dstT, 34); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,41 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" | |||||
| #include <iostream> | |||||
| #include "common/common_test.h" | |||||
| namespace mindspore { | |||||
| class TestTfliteParserConcat : public TestTfliteParser { | |||||
| public: | |||||
| TestTfliteParserConcat() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./concat.tflite", ""); } | |||||
| }; | |||||
| TEST_F(TestTfliteParserConcat, OpType) { | |||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Concat) << "wrong Op Type"; | |||||
| } | |||||
| TEST_F(TestTfliteParserConcat, AttrValue) { | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsConcat(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsConcat(); | |||||
| ASSERT_EQ(val->axis, 1); | |||||
| ASSERT_EQ(val->n, 2); | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,57 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" | |||||
| #include <iostream> | |||||
| #include "common/common_test.h" | |||||
| namespace mindspore { | |||||
| class TestTfliteParserConv : public TestTfliteParser { | |||||
| public: | |||||
| TestTfliteParserConv() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./conv.tflite", ""); } | |||||
| }; | |||||
| TEST_F(TestTfliteParserConv, OpType) { | |||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type"; | |||||
| ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type"; | |||||
| } | |||||
| TEST_F(TestTfliteParserConv, AttrValue) { | |||||
| ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsConv2D(), nullptr); | |||||
| auto val = meta_graph->nodes.at(1)->primitive->value.AsConv2D(); | |||||
| ASSERT_EQ(val->format, schema::Format_NHWC); | |||||
| ASSERT_EQ(val->group, 1); | |||||
| ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION); | |||||
| ASSERT_EQ(val->hasBias, true); | |||||
| ASSERT_EQ(val->channelIn, 1); | |||||
| ASSERT_EQ(val->channelOut, 4); | |||||
| ASSERT_EQ(val->kernelH, 3); | |||||
| ASSERT_EQ(val->kernelW, 3); | |||||
| ASSERT_EQ(val->strideH, 1); | |||||
| ASSERT_EQ(val->strideW, 1); | |||||
| ASSERT_EQ(val->dilateH, 1); | |||||
| ASSERT_EQ(val->dilateW, 1); | |||||
| ASSERT_EQ(val->padMode, schema::PadMode_SAME); | |||||
| ASSERT_EQ(val->padUp, 1); | |||||
| ASSERT_EQ(val->padDown, 1); | |||||
| ASSERT_EQ(val->padLeft, 1); | |||||
| ASSERT_EQ(val->padRight, 1); | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,58 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" | |||||
| #include <iostream> | |||||
| #include "common/common_test.h" | |||||
| namespace mindspore { | |||||
| class TestTfliteParserDeConv : public TestTfliteParser { | |||||
| public: | |||||
| TestTfliteParserDeConv() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./deconv.tflite", ""); } | |||||
| }; | |||||
| TEST_F(TestTfliteParserDeConv, OpType) { | |||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type"; | |||||
| ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_DeConv2D) << "wrong Op Type"; | |||||
| } | |||||
| TEST_F(TestTfliteParserDeConv, AttrValue) { | |||||
| ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsDeConv2D(), nullptr); | |||||
| auto val = meta_graph->nodes.at(1)->primitive->value.AsDeConv2D(); | |||||
| ASSERT_EQ(val->format, schema::Format_NHWC); | |||||
| ASSERT_EQ(val->group, 1); | |||||
| ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION); | |||||
| ASSERT_EQ(val->hasBias, true); | |||||
| ASSERT_EQ(val->channelIn, 1); | |||||
| ASSERT_EQ(val->channelOut, 4); | |||||
| ASSERT_EQ(val->kernelH, 3); | |||||
| ASSERT_EQ(val->kernelW, 3); | |||||
| ASSERT_EQ(val->strideH, 1); | |||||
| ASSERT_EQ(val->strideW, 1); | |||||
| ASSERT_EQ(val->dilateH, 1); | |||||
| ASSERT_EQ(val->dilateW, 1); | |||||
| ASSERT_EQ(val->padMode, schema::PadMode_SAME); | |||||
| ASSERT_EQ(val->padUp, 1); | |||||
| ASSERT_EQ(val->padDown, 1); | |||||
| ASSERT_EQ(val->padLeft, 1); | |||||
| ASSERT_EQ(val->padRight, 1); | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -35,11 +35,9 @@ TEST_F(TestTfliteParserDepthToSpace, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserDepthToSpace, AttrValue) { | TEST_F(TestTfliteParserDepthToSpace, AttrValue) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsDepthToSpace(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsDepthToSpace(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsDepthToSpace()->blockSize, 4); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsDepthToSpace()->format, schema::Format_NHWC); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsDepthToSpace(); | |||||
| ASSERT_EQ(val->blockSize, 4); | |||||
| ASSERT_EQ(val->format, schema::Format_NHWC); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,92 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" | |||||
| #include <iostream> | |||||
| #include "common/common_test.h" | |||||
| namespace mindspore { | |||||
| class TestTfliteParserDepthwiseConv1 : public TestTfliteParser { | |||||
| public: | |||||
| TestTfliteParserDepthwiseConv1() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./depthwise_conv1.tflite", ""); } | |||||
| }; | |||||
| TEST_F(TestTfliteParserDepthwiseConv1, OpType) { | |||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reshape) << "wrong Op Type"; | |||||
| ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type"; | |||||
| } | |||||
| TEST_F(TestTfliteParserDepthwiseConv1, AttrValue) { | |||||
| ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsConv2D(), nullptr); | |||||
| auto val = meta_graph->nodes.at(1)->primitive->value.AsConv2D(); | |||||
| ASSERT_EQ(val->format, schema::Format_NHWC); | |||||
| ASSERT_EQ(val->group, 0); | |||||
| ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION); | |||||
| ASSERT_EQ(val->hasBias, true); | |||||
| ASSERT_EQ(val->channelIn, 1); | |||||
| ASSERT_EQ(val->channelOut, 4); | |||||
| ASSERT_EQ(val->kernelH, 3); | |||||
| ASSERT_EQ(val->kernelW, 3); | |||||
| ASSERT_EQ(val->strideH, 1); | |||||
| ASSERT_EQ(val->strideW, 1); | |||||
| ASSERT_EQ(val->dilateH, 1); | |||||
| ASSERT_EQ(val->dilateW, 1); | |||||
| ASSERT_EQ(val->padMode, schema::PadMode_SAME); | |||||
| ASSERT_EQ(val->padUp, 1); | |||||
| ASSERT_EQ(val->padDown, 1); | |||||
| ASSERT_EQ(val->padLeft, 1); | |||||
| ASSERT_EQ(val->padRight, 1); | |||||
| } | |||||
| class TestTfliteParserDepthwiseConv2 : public TestTfliteParser { | |||||
| public: | |||||
| TestTfliteParserDepthwiseConv2() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./depthwise_conv2.tflite", ""); } | |||||
| }; | |||||
| TEST_F(TestTfliteParserDepthwiseConv2, OpType) { | |||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reshape) << "wrong Op Type"; | |||||
| ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_DepthwiseConv2D) << "wrong Op Type"; | |||||
| } | |||||
| TEST_F(TestTfliteParserDepthwiseConv2, AttrValue) { | |||||
| ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsDepthwiseConv2D(), nullptr); | |||||
| auto val = meta_graph->nodes.at(1)->primitive->value.AsDepthwiseConv2D(); | |||||
| ASSERT_EQ(val->format, schema::Format_NHWC); | |||||
| ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION); | |||||
| ASSERT_EQ(val->hasBias, true); | |||||
| ASSERT_EQ(val->channelIn, 2); | |||||
| ASSERT_EQ(val->channelMultiplier, 1); | |||||
| ASSERT_EQ(val->kernelH, 3); | |||||
| ASSERT_EQ(val->kernelW, 3); | |||||
| ASSERT_EQ(val->strideH, 1); | |||||
| ASSERT_EQ(val->strideW, 1); | |||||
| ASSERT_EQ(val->dilateH, 1); | |||||
| ASSERT_EQ(val->dilateW, 1); | |||||
| ASSERT_EQ(val->padMode, schema::PadMode_SAME); | |||||
| ASSERT_EQ(val->padUp, 1); | |||||
| ASSERT_EQ(val->padDown, 1); | |||||
| ASSERT_EQ(val->padLeft, 1); | |||||
| ASSERT_EQ(val->padRight, 1); | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -25,17 +25,15 @@ class TestTfliteParserFill : public TestTfliteParser { | |||||
| }; | }; | ||||
| TEST_F(TestTfliteParserFill, OpType) { | TEST_F(TestTfliteParserFill, OpType) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | ASSERT_GT(meta_graph->nodes.size(), 0); | ||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Fill) << "wrong Op Type"; | ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Fill) << "wrong Op Type"; | ||||
| } | } | ||||
| TEST_F(TestTfliteParserFill, AttrValue) { | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| TEST_F(TestTfliteParserFill, AttrValue) {; | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsFill(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsFill(); | auto val = meta_graph->nodes.front()->primitive->value.AsFill(); | ||||
| std::vector<int32_t> dims = {9}; | std::vector<int32_t> dims = {9}; | ||||
| ASSERT_EQ(val->dims, dims); | ASSERT_EQ(val->dims, dims); | ||||
| } | } | ||||
| @@ -25,15 +25,14 @@ class TestTfliteParserGatherNd : public TestTfliteParser { | |||||
| }; | }; | ||||
| TEST_F(TestTfliteParserGatherNd, OpType) { | TEST_F(TestTfliteParserGatherNd, OpType) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | ASSERT_GT(meta_graph->nodes.size(), 0); | ||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_GatherNd) << "wrong Op Type"; | ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_GatherNd) << "wrong Op Type"; | ||||
| } | } | ||||
| TEST_F(TestTfliteParserGatherNd, AttrValue) { | TEST_F(TestTfliteParserGatherNd, AttrValue) { | ||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsGatherNd(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsGatherNd(); | auto val = meta_graph->nodes.front()->primitive->value.AsGatherNd(); | ||||
| ASSERT_EQ(val->batchDims, 0); | ASSERT_EQ(val->batchDims, 0); | ||||
| } | } | ||||
| @@ -25,15 +25,14 @@ class TestTfliteParserGather : public TestTfliteParser { | |||||
| }; | }; | ||||
| TEST_F(TestTfliteParserGather, OpType) { | TEST_F(TestTfliteParserGather, OpType) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | ASSERT_GT(meta_graph->nodes.size(), 0); | ||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Gather) << "wrong Op Type"; | ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Gather) << "wrong Op Type"; | ||||
| } | } | ||||
| TEST_F(TestTfliteParserGather, AttrValue) { | TEST_F(TestTfliteParserGather, AttrValue) { | ||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsGather(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsGather(); | auto val = meta_graph->nodes.front()->primitive->value.AsGather(); | ||||
| ASSERT_EQ(val->axis, 0); | ASSERT_EQ(val->axis, 0); | ||||
| ASSERT_EQ(val->batchDims, 0); | ASSERT_EQ(val->batchDims, 0); | ||||
| @@ -25,6 +25,7 @@ class TestTfliteParserLRN : public TestTfliteParser { | |||||
| }; | }; | ||||
| TEST_F(TestTfliteParserLRN, OpType) { | TEST_F(TestTfliteParserLRN, OpType) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | ASSERT_GT(meta_graph->nodes.size(), 0); | ||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, | ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, | ||||
| @@ -32,9 +33,7 @@ TEST_F(TestTfliteParserLRN, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserLRN, AttrValue) { | TEST_F(TestTfliteParserLRN, AttrValue) { | ||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsLocalResponseNormalization(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsLocalResponseNormalization(); | auto val = meta_graph->nodes.front()->primitive->value.AsLocalResponseNormalization(); | ||||
| ASSERT_EQ(val->alpha, 1); | ASSERT_EQ(val->alpha, 1); | ||||
| ASSERT_EQ(val->beta, 0.5); | ASSERT_EQ(val->beta, 0.5); | ||||
| @@ -32,12 +32,9 @@ TEST_F(TestTfliteParserOneHot, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserOneHot, AttrValue) { | TEST_F(TestTfliteParserOneHot, AttrValue) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsOneHot(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsOneHot(), nullptr); | ||||
| // in OneHot parser axis = axis > 0 ? axis : axis + tensor_shape.size() | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsOneHot()->axis, 2); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsOneHot(); | |||||
| ASSERT_EQ(val->axis, 2); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -25,17 +25,15 @@ class TestTfliteParserPad : public TestTfliteParser { | |||||
| }; | }; | ||||
| TEST_F(TestTfliteParserPad, OpType) { | TEST_F(TestTfliteParserPad, OpType) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | ASSERT_GT(meta_graph->nodes.size(), 0); | ||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Pad) << "wrong Op Type"; | ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Pad) << "wrong Op Type"; | ||||
| } | } | ||||
| TEST_F(TestTfliteParserPad, AttrValue) { | TEST_F(TestTfliteParserPad, AttrValue) { | ||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPad(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsPad(); | auto val = meta_graph->nodes.front()->primitive->value.AsPad(); | ||||
| std::vector<int32_t> paddings = {1, 1, 2, 2, 3, 3, 4, 4}; | std::vector<int32_t> paddings = {1, 1, 2, 2, 3, 3, 4, 4}; | ||||
| ASSERT_EQ(val->paddings, paddings); | ASSERT_EQ(val->paddings, paddings); | ||||
| } | } | ||||
| @@ -35,12 +35,8 @@ TEST_F(TestTfliteParserMaxPooling, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserMaxPooling, AttrValue) { | TEST_F(TestTfliteParserMaxPooling, AttrValue) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPooling(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsPooling(); | auto val = meta_graph->nodes.front()->primitive->value.AsPooling(); | ||||
| ASSERT_NE(val, nullptr); | |||||
| ASSERT_EQ(val->format, schema::Format_NHWC); | ASSERT_EQ(val->format, schema::Format_NHWC); | ||||
| ASSERT_EQ(val->poolingMode, schema::PoolMode_MAX_POOLING); | ASSERT_EQ(val->poolingMode, schema::PoolMode_MAX_POOLING); | ||||
| ASSERT_EQ(val->global, false); | ASSERT_EQ(val->global, false); | ||||
| @@ -72,12 +68,8 @@ TEST_F(TestTfliteParserAvgPooling, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserAvgPooling, AttrValue) { | TEST_F(TestTfliteParserAvgPooling, AttrValue) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPooling(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsPooling(); | auto val = meta_graph->nodes.front()->primitive->value.AsPooling(); | ||||
| ASSERT_NE(val, nullptr); | |||||
| ASSERT_EQ(val->format, schema::Format_NHWC); | ASSERT_EQ(val->format, schema::Format_NHWC); | ||||
| ASSERT_EQ(val->poolingMode, schema::PoolMode_MEAN_POOLING); | ASSERT_EQ(val->poolingMode, schema::PoolMode_MEAN_POOLING); | ||||
| ASSERT_EQ(val->global, false); | ASSERT_EQ(val->global, false); | ||||
| @@ -32,13 +32,9 @@ TEST_F(TestTfliteParserReduceMax, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserReduceMax, AttrValue) { | TEST_F(TestTfliteParserReduceMax, AttrValue) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); | auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); | ||||
| ASSERT_NE(val, nullptr); | |||||
| ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMax) << "wrong reduce mode"; | |||||
| ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMax); | |||||
| ASSERT_EQ(val->keepDims, false); | ASSERT_EQ(val->keepDims, false); | ||||
| std::vector<int32_t> axes = {2}; | std::vector<int32_t> axes = {2}; | ||||
| ASSERT_EQ(val->axes, axes); | ASSERT_EQ(val->axes, axes); | ||||
| @@ -58,13 +54,9 @@ TEST_F(TestTfliteParserReduceMin, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserReduceMin, AttrValue) { | TEST_F(TestTfliteParserReduceMin, AttrValue) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); | auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); | ||||
| ASSERT_NE(val, nullptr); | |||||
| ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMin) << "wrong reduce mode"; | |||||
| ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMin); | |||||
| ASSERT_EQ(val->keepDims, false); | ASSERT_EQ(val->keepDims, false); | ||||
| std::vector<int32_t> axes = {2}; | std::vector<int32_t> axes = {2}; | ||||
| ASSERT_EQ(val->axes, axes); | ASSERT_EQ(val->axes, axes); | ||||
| @@ -84,13 +76,9 @@ TEST_F(TestTfliteParserReduceProd, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserReduceProd, AttrValue) { | TEST_F(TestTfliteParserReduceProd, AttrValue) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); | auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); | ||||
| ASSERT_NE(val, nullptr); | |||||
| ASSERT_EQ(val->mode, schema::ReduceMode_ReduceProd) << "wrong reduce mode"; | |||||
| ASSERT_EQ(val->mode, schema::ReduceMode_ReduceProd); | |||||
| ASSERT_EQ(val->keepDims, false); | ASSERT_EQ(val->keepDims, false); | ||||
| std::vector<int32_t> axes = {2}; | std::vector<int32_t> axes = {2}; | ||||
| ASSERT_EQ(val->axes, axes); | ASSERT_EQ(val->axes, axes); | ||||
| @@ -111,13 +99,9 @@ TEST_F(TestTfliteParserSum, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserSum, AttrValue) { | TEST_F(TestTfliteParserSum, AttrValue) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); | auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); | ||||
| ASSERT_NE(val, nullptr); | |||||
| ASSERT_EQ(val->mode, schema::ReduceMode_ReduceSum) << "wrong reduce mode"; | |||||
| ASSERT_EQ(val->mode, schema::ReduceMode_ReduceSum); | |||||
| ASSERT_EQ(val->keepDims, false); | ASSERT_EQ(val->keepDims, false); | ||||
| std::vector<int32_t> axes = {2}; | std::vector<int32_t> axes = {2}; | ||||
| ASSERT_EQ(val->axes, axes); | ASSERT_EQ(val->axes, axes); | ||||
| @@ -138,13 +122,9 @@ TEST_F(TestTfliteParserMean, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserMean, AttrValue) { | TEST_F(TestTfliteParserMean, AttrValue) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); | auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); | ||||
| ASSERT_NE(val, nullptr); | |||||
| ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMean) << "wrong reduce mode"; | |||||
| ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMean); | |||||
| ASSERT_EQ(val->keepDims, true); | ASSERT_EQ(val->keepDims, true); | ||||
| std::vector<int32_t> axes = {2, 3}; | std::vector<int32_t> axes = {2, 3}; | ||||
| ASSERT_EQ(val->axes, axes); | ASSERT_EQ(val->axes, axes); | ||||
| @@ -35,12 +35,9 @@ TEST_F(TestTfliteParserReshape, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserReshape, AttrValue) { | TEST_F(TestTfliteParserReshape, AttrValue) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReshape(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReshape(), nullptr); | ||||
| auto val = meta_graph->nodes.front()->primitive->value.AsReshape(); | |||||
| std::vector<int64_t> shape = {3, 5, 20}; | std::vector<int64_t> shape = {3, 5, 20}; | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReshape()->shape, shape); // int32 | |||||
| ASSERT_EQ(val->shape, shape); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -26,17 +26,15 @@ class TestTfliteParserResizeNN : public TestTfliteParser { | |||||
| }; | }; | ||||
| TEST_F(TestTfliteParserResizeNN, OpType) { | TEST_F(TestTfliteParserResizeNN, OpType) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | ASSERT_GT(meta_graph->nodes.size(), 0); | ||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Resize) << "wrong Op Type"; | ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Resize) << "wrong Op Type"; | ||||
| } | } | ||||
| TEST_F(TestTfliteParserResizeNN, AttrValue) { | TEST_F(TestTfliteParserResizeNN, AttrValue) { | ||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsResize(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsResize(); | auto val = meta_graph->nodes.front()->primitive->value.AsResize(); | ||||
| ASSERT_NE(val, nullptr); | |||||
| ASSERT_EQ(val->alignCorners, false); | ASSERT_EQ(val->alignCorners, false); | ||||
| ASSERT_EQ(val->newHeight, 3); | ASSERT_EQ(val->newHeight, 3); | ||||
| ASSERT_EQ(val->newWidth, 100); | ASSERT_EQ(val->newWidth, 100); | ||||
| @@ -52,17 +50,15 @@ class TestTfliteParserResizeBilinear : public TestTfliteParser { | |||||
| }; | }; | ||||
| TEST_F(TestTfliteParserResizeBilinear, OpType) { | TEST_F(TestTfliteParserResizeBilinear, OpType) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | ASSERT_GT(meta_graph->nodes.size(), 0); | ||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Resize) << "wrong Op Type"; | ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Resize) << "wrong Op Type"; | ||||
| } | } | ||||
| TEST_F(TestTfliteParserResizeBilinear, AttrValue) { | TEST_F(TestTfliteParserResizeBilinear, AttrValue) { | ||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsResize(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsResize(); | auto val = meta_graph->nodes.front()->primitive->value.AsResize(); | ||||
| ASSERT_NE(val, nullptr); | |||||
| ASSERT_EQ(val->alignCorners, false); | ASSERT_EQ(val->alignCorners, false); | ||||
| ASSERT_EQ(val->newHeight, 75); | ASSERT_EQ(val->newHeight, 75); | ||||
| ASSERT_EQ(val->newWidth, 4); | ASSERT_EQ(val->newWidth, 4); | ||||
| @@ -25,17 +25,15 @@ class TestTfliteParserReverse : public TestTfliteParser { | |||||
| }; | }; | ||||
| TEST_F(TestTfliteParserReverse, OpType) { | TEST_F(TestTfliteParserReverse, OpType) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | ASSERT_GT(meta_graph->nodes.size(), 0); | ||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reverse) << "wrong Op Type"; | ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reverse) << "wrong Op Type"; | ||||
| } | } | ||||
| TEST_F(TestTfliteParserReverse, AttrValue) { | TEST_F(TestTfliteParserReverse, AttrValue) { | ||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReverse(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsReverse(); | auto val = meta_graph->nodes.front()->primitive->value.AsReverse(); | ||||
| std::vector<int32_t> axis = {3}; | std::vector<int32_t> axis = {3}; | ||||
| ASSERT_EQ(val->axis, axis); | ASSERT_EQ(val->axis, axis); | ||||
| } | } | ||||
| @@ -35,13 +35,11 @@ TEST_F(TestTfliteParserReverseSequence, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserReverseSequence, AttrValue) { | TEST_F(TestTfliteParserReverseSequence, AttrValue) { | ||||
| std::vector<int> seq_length{7, 2, 3, 5}; | |||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReverseSequence(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReverseSequence(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqAxis, 1); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqAxis, 1); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqLengths, seq_length); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsReverseSequence(); | |||||
| ASSERT_EQ(val->seqAxis, 1); | |||||
| ASSERT_EQ(val->seqAxis, 1); | |||||
| std::vector<int> seq_length = {7, 2, 3, 5}; | |||||
| ASSERT_EQ(val->seqLengths, seq_length); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,45 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" | |||||
| #include <iostream> | |||||
| #include "common/common_test.h" | |||||
| namespace mindspore { | |||||
| class TestTfliteParserSlice : public TestTfliteParser { | |||||
| public: | |||||
| TestTfliteParserSlice() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./slice.tflite"); } | |||||
| }; | |||||
| TEST_F(TestTfliteParserSlice, OpType) { | |||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Slice) << "wrong Op Type"; | |||||
| } | |||||
| TEST_F(TestTfliteParserSlice, AttrValue) { | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSlice(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsSlice(); | |||||
| ASSERT_EQ(val->format, schema::Format_NHWC); | |||||
| std::vector<int32_t> begin = {1, 0, 0}; | |||||
| ASSERT_EQ(val->begin, begin); | |||||
| std::vector<int32_t> size = {1, 1, 3}; | |||||
| ASSERT_EQ(val->size, size); | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -35,11 +35,9 @@ TEST_F(TestTfliteParserSoftmax, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserSoftmax, AttrValue) { | TEST_F(TestTfliteParserSoftmax, AttrValue) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSoftMax(), nullptr); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSoftMax()->axis, -1); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSoftMax(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsSoftMax(); | |||||
| ASSERT_EQ(val->axis, -1); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -35,13 +35,11 @@ TEST_F(TestTfliteParserSpaceToBatchND, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserSpaceToBatchND, AttrValue) { | TEST_F(TestTfliteParserSpaceToBatchND, AttrValue) { | ||||
| std::vector<int> blockshape{2, 2}; | |||||
| std::vector<int> padding{0, 0, 2, 0}; | |||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND()->blockShape, blockshape); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND()->paddings, padding); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND(); | |||||
| std::vector<int> blockshape = {2, 2}; | |||||
| ASSERT_EQ(val->blockShape, blockshape); | |||||
| std::vector<int> padding = {0, 0, 2, 0}; | |||||
| ASSERT_EQ(val->paddings, padding); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -35,11 +35,9 @@ TEST_F(TestTfliteParserSpaceToDepth, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserSpaceToDepth, AttrValue) { | TEST_F(TestTfliteParserSpaceToDepth, AttrValue) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth()->blockSize, 2); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth()->format, schema::Format_NHWC); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsSpaceToDepth(); | |||||
| ASSERT_EQ(val->blockSize, 2); | |||||
| ASSERT_EQ(val->format, schema::Format_NHWC); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -35,16 +35,14 @@ TEST_F(TestTfliteParserSparseToDense, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserSparseToDense, AttrValue) { | TEST_F(TestTfliteParserSparseToDense, AttrValue) { | ||||
| std::vector<int> outputShape{5, 5}; | |||||
| std::vector<int> sparseValue{1}; | |||||
| std::vector<int> defaultValue{0}; | |||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSparseToDense(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSparseToDense(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->outputShape, outputShape); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->sparseValue, sparseValue); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->defaultValue, defaultValue); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->validateIndices, false); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsSparseToDense(); | |||||
| std::vector<int> outputShape = {5, 5}; | |||||
| ASSERT_EQ(val->outputShape, outputShape); | |||||
| std::vector<int> sparseValue = {1}; | |||||
| ASSERT_EQ(val->sparseValue, sparseValue); | |||||
| std::vector<int> defaultValue = {0}; | |||||
| ASSERT_EQ(val->defaultValue, defaultValue); | |||||
| ASSERT_EQ(val->validateIndices, false); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -33,14 +33,12 @@ TEST_F(TestTfliteParserSplit, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserSplit, AttrValue) { | TEST_F(TestTfliteParserSplit, AttrValue) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr); | ||||
| const std::vector<int> sizeSplits{2, 2}; | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->splitDim, 2); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->numberSplit, 2); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->sizeSplits, sizeSplits); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsSplit(); | |||||
| ASSERT_EQ(val->splitDim, 2); | |||||
| ASSERT_EQ(val->numberSplit, 2); | |||||
| const std::vector<int> sizeSplits = {2, 2}; | |||||
| ASSERT_EQ(val->sizeSplits, sizeSplits); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -33,14 +33,12 @@ TEST_F(TestTfliteParserSplitV, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserSplitV, AttrValue) { | TEST_F(TestTfliteParserSplitV, AttrValue) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr); | ||||
| const std::vector<int> sizeSplits{1, 3}; | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->splitDim, 0); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->numberSplit, 2); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->sizeSplits, sizeSplits); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsSplit(); | |||||
| ASSERT_EQ(val->splitDim, 0); | |||||
| ASSERT_EQ(val->numberSplit, 2); | |||||
| const std::vector<int> sizeSplits = {1, 3}; | |||||
| ASSERT_EQ(val->sizeSplits, sizeSplits); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,44 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" | |||||
| #include <iostream> | |||||
| #include "common/common_test.h" | |||||
| namespace mindspore { | |||||
| class TestTfliteParserStack : public TestTfliteParser { | |||||
| public: | |||||
| TestTfliteParserStack() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./stack.tflite"); } | |||||
| }; | |||||
| TEST_F(TestTfliteParserStack, OpType) { | |||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Stack) << "wrong Op Type"; | |||||
| } | |||||
| TEST_F(TestTfliteParserStack, AttrValue) { | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsStack(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsStack(); | |||||
| ASSERT_EQ(val->axis, 1); | |||||
| ASSERT_EQ(val->n, 2); | |||||
| const std::vector<int> isScale = {3, 2, 3}; | |||||
| ASSERT_EQ(val->isScale, isScale); | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -35,21 +35,19 @@ TEST_F(TestTfliteParserStridedSlice, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserStridedSlice, AttrValue) { | TEST_F(TestTfliteParserStridedSlice, AttrValue) { | ||||
| std::vector<int> begin{1, -1, 0}; | |||||
| std::vector<int> end{2, -3, 3}; | |||||
| std::vector<int> stride{1, -1, 1}; | |||||
| std::vector<int> isscale{3, 2, 3}; | |||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsStridedSlice(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsStridedSlice(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->endMask, 0); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->begin, begin); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->end, end); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->stride, stride); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->isScale, isscale); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsStridedSlice(); | |||||
| ASSERT_EQ(val->beginMask, 0); | |||||
| ASSERT_EQ(val->endMask, 0); | |||||
| ASSERT_EQ(val->beginMask, 0); | |||||
| ASSERT_EQ(val->beginMask, 0); | |||||
| std::vector<int> begin = {1, -1, 0}; | |||||
| ASSERT_EQ(val->begin, begin); | |||||
| std::vector<int> end = {2, -3, 3}; | |||||
| ASSERT_EQ(val->end, end); | |||||
| std::vector<int> stride = {1, -1, 1}; | |||||
| ASSERT_EQ(val->stride, stride); | |||||
| std::vector<int> isscale = {3, 2, 3}; | |||||
| ASSERT_EQ(val->isScale, isscale); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -35,11 +35,9 @@ TEST_F(TestTfliteParserTile, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserTile, AttrValue) { | TEST_F(TestTfliteParserTile, AttrValue) { | ||||
| std::vector<int> multiply{2, 3, 4}; | |||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTile(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTile(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTile()->multiples, multiply); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsTile(); | |||||
| std::vector<int> multiply = {2, 3, 4}; | |||||
| ASSERT_EQ(val->multiples, multiply); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -35,13 +35,10 @@ TEST_F(TestTfliteParserTopKV2, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserTopKV2, AttrValue) { | TEST_F(TestTfliteParserTopKV2, AttrValue) { | ||||
| // attr->sorted default is true | |||||
| std::vector<int> k{3}; | |||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTopKV2(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTopKV2(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTopKV2()->k, k); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTopKV2()->sorted, true); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsTopKV2(); | |||||
| std::vector<int> k = {3}; | |||||
| ASSERT_EQ(val->k, k); | |||||
| ASSERT_EQ(val->sorted, true); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,43 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" | |||||
| #include <iostream> | |||||
| #include "common/common_test.h" | |||||
| namespace mindspore { | |||||
| class TestTfliteParserTranspose : public TestTfliteParser { | |||||
| public: | |||||
| TestTfliteParserTranspose() = default; | |||||
| void SetUp() override { meta_graph = LoadAndConvert("./transpose.tflite"); } | |||||
| }; | |||||
| TEST_F(TestTfliteParserTranspose, OpType) { | |||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type"; | |||||
| } | |||||
| TEST_F(TestTfliteParserTranspose, AttrValue) { | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTranspose(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsTranspose(); | |||||
| ASSERT_EQ(val->conjugate, false); | |||||
| std::vector<int32_t> perm = {1, 0}; | |||||
| ASSERT_EQ(val->perm, perm); | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -35,10 +35,9 @@ TEST_F(TestTfliteParserUnique, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserUnique, AttrValue) { | TEST_F(TestTfliteParserUnique, AttrValue) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnique()->outType, 34); // int32 | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsUnique(); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr); | |||||
| ASSERT_EQ(val->outType, 34); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -35,11 +35,9 @@ TEST_F(TestTfliteParserUnstack, OpType) { | |||||
| } | } | ||||
| TEST_F(TestTfliteParserUnstack, AttrValue) { | TEST_F(TestTfliteParserUnstack, AttrValue) { | ||||
| ASSERT_NE(meta_graph, nullptr); | |||||
| ASSERT_GT(meta_graph->nodes.size(), 0); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnstack(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnstack(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnstack()->num, 5); | |||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnstack()->axis, 1); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsUnstack(); | |||||
| ASSERT_EQ(val->num, 5); | |||||
| ASSERT_EQ(val->axis, 1); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -353,7 +353,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { | |||||
| status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC); | status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC); | ||||
| } else if (weightTensor->format == schema::Format_KCHW) { | } else if (weightTensor->format == schema::Format_KCHW) { | ||||
| status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC); | status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC); | ||||
| } else if (weightTensor->format == schema::Format_CHWK) { | |||||
| } else if (weightTensor->format == schema::Format_CHWK) { // from tflite | |||||
| status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC); | status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC); | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; | MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; | ||||
| @@ -369,7 +369,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { | |||||
| } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC | } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC | ||||
| if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms | if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms | ||||
| status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC); | status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC); | ||||
| } else if (weightTensor->format == schema::Format_CHWK) { // from tf | |||||
| } else if (weightTensor->format == schema::Format_CHWK) { // from tflite | |||||
| status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC); | status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC); | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; | MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; | ||||
| @@ -21,11 +21,16 @@ namespace lite { | |||||
| STATUS CaffeFlattenParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | STATUS CaffeFlattenParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | ||||
| schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| // MS_LOGE("null pointer dereferencing."); | |||||
| // MS_LOG(ERROR) << "null pointer dereferencing."; | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| std::unique_ptr<schema::ReshapeT> attr(new schema::ReshapeT()); | |||||
| attr->format = schema::Format_NCHW; | |||||
| std::unique_ptr<schema::FullConnectionT> attr(new schema::FullConnectionT()); | |||||
| const caffe::FlattenParameter flattenParam = proto.flatten_param(); | |||||
| attr->axis = (int32_t)flattenParam.axis(); | |||||
| attr->useAxis = true; | |||||
| attr->hasBias = false; | |||||
| attr->activationType = schema::ActivationType_NO_ACTIVATION; | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | op->primitive = std::make_unique<schema::PrimitiveT>(); | ||||
| op->primitive->value.type = schema::PrimitiveType_Flatten; | op->primitive->value.type = schema::PrimitiveType_Flatten; | ||||
| @@ -14,18 +14,21 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "tools/converter/parser/tflite/tflite_activation_parser.h" | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include "tools/converter/parser/tflite/tflite_activation_parser.h" | |||||
| #include <map> | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -35,13 +38,11 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &t | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | MS_LOG(ERROR) << "op->primitive is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT()); | std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT()); | ||||
| std::vector<std::string> node_name_str; | std::vector<std::string> node_name_str; | ||||
| Split(op->name, &node_name_str, "-"); | Split(op->name, &node_name_str, "-"); | ||||
| const char *node_name = node_name_str.data()->c_str(); | const char *node_name = node_name_str.data()->c_str(); | ||||
| if (std::strcmp(node_name, "Relu") == 0) { | if (std::strcmp(node_name, "Relu") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteReluParser"; | MS_LOG(DEBUG) << "parse TfliteReluParser"; | ||||
| attr->type = schema::ActivationType_RELU; | attr->type = schema::ActivationType_RELU; | ||||
| @@ -54,29 +55,31 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &t | |||||
| } else if (std::strcmp(node_name, "Logistic") == 0) { | } else if (std::strcmp(node_name, "Logistic") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteLogisticParser"; | MS_LOG(DEBUG) << "parse TfliteLogisticParser"; | ||||
| attr->type = schema::ActivationType_SIGMOID; | attr->type = schema::ActivationType_SIGMOID; | ||||
| } else if (std::strcmp(node_name, "LeakyRelu") == 0) { | |||||
| const auto &option = tfliteOp->builtin_options.AsLeakyReluOptions(); | |||||
| if (option == nullptr) { | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| attr->type = schema::ActivationType_LEAKY_RELU; | |||||
| attr->alpha = option->alpha; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "wrong activation type"; | |||||
| return RET_ERROR; | |||||
| } else if (std::strcmp(node_name, "HardSwish") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteHardSwishParser"; | |||||
| attr->type = schema::ActivationType_SIGMOID; | |||||
| } | } | ||||
| attr->alpha = 0.2f; | |||||
| op->primitive->value.type = schema::PrimitiveType_Activation; | op->primitive->value.type = schema::PrimitiveType_Activation; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | ||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantized_model) { | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| MS_LOG(DEBUG) << "parse TflitePreluParser"; | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -86,23 +89,64 @@ STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | MS_LOG(ERROR) << "op->primitive is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "paser TflitePreluParser"; | |||||
| std::unique_ptr<schema::PreluT> attr(new schema::PreluT()); | std::unique_ptr<schema::PreluT> attr(new schema::PreluT()); | ||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->slope)) { | if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->slope)) { | ||||
| MS_LOG(ERROR) << "get pRelu -> slope failed"; | MS_LOG(ERROR) << "get pRelu -> slope failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Prelu; | op->primitive->value.type = schema::PrimitiveType_Prelu; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS TfliteLeakyReluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| MS_LOG(DEBUG) << "parse TfliteLeakyReluParser"; | |||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::LeakyReLUT> attr(new schema::LeakyReLUT()); | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsLeakyReluOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| attr->negativeSlope = tflite_attr->alpha; | |||||
| op->primitive->value.type = schema::PrimitiveType_LeakyReLU; | |||||
| op->primitive->value.value = attr.release(); | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser()); | TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser()); | ||||
| TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser()); | TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser()); | ||||
| TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser()); | TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser()); | ||||
| TfliteNodeRegister g_TfliteHardSwishParser("HardSwish", new TfliteHardSwishParser()); | |||||
| TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser()); | TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser()); | ||||
| TfliteNodeRegister g_tflitePreluParser("Prelu", new TflitePreluParser()); | TfliteNodeRegister g_tflitePreluParser("Prelu", new TflitePreluParser()); | ||||
| TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser()); | TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser()); | ||||
| @@ -14,13 +14,14 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef PREDICT_TFLITE_RELU_PARSER_H | |||||
| #define PREDICT_TFLITE_RELU_PARSER_H | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <map> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -29,11 +30,13 @@ class TfliteActivationParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteActivationParser() : TfliteNodeParser("node_name") {} | TfliteActivationParser() : TfliteNodeParser("node_name") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantizedModel) override; | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| }; | }; | ||||
| class TfliteReluParser : public TfliteActivationParser { | class TfliteReluParser : public TfliteActivationParser { | ||||
| @@ -56,9 +59,9 @@ class TfliteLogisticParser : public TfliteActivationParser { | |||||
| TfliteLogisticParser() : TfliteActivationParser() {} | TfliteLogisticParser() : TfliteActivationParser() {} | ||||
| }; | }; | ||||
| class TfliteLeakyReluParser : public TfliteActivationParser { | |||||
| class TfliteHardSwishParser : public TfliteActivationParser { | |||||
| public: | public: | ||||
| TfliteLeakyReluParser() : TfliteActivationParser() {} | |||||
| TfliteHardSwishParser() : TfliteActivationParser() {} | |||||
| }; | }; | ||||
| class TflitePreluParser : public TfliteNodeParser { | class TflitePreluParser : public TfliteNodeParser { | ||||
| @@ -68,12 +71,27 @@ class TflitePreluParser : public TfliteNodeParser { | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | ||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantized_model) override; | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| }; | |||||
| class TfliteLeakyReluParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteLeakyReluParser() : TfliteNodeParser("LeakyRelu") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // PREDICT_TFLITE_RELU_PARSER_H | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H | |||||
| @@ -18,14 +18,20 @@ | |||||
| #include "tools/converter/parser/tflite/tflite_addn_parser.h" | #include "tools/converter/parser/tflite/tflite_addn_parser.h" | ||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <map> | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| MS_LOG(DEBUG) << "parse TfliteAddNParser"; | |||||
| // set attr | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -36,13 +42,19 @@ STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "parse TfliteAddNParser"; | |||||
| std::unique_ptr<schema::AddNT> attr(new schema::AddNT()); | std::unique_ptr<schema::AddNT> attr(new schema::AddNT()); | ||||
| attr->N = tfliteTensors.size() - 1; | |||||
| attr->N = tflite_tensors.size() - 1; | |||||
| op->primitive->value.type = schema::PrimitiveType_AddN; | op->primitive->value.type = schema::PrimitiveType_AddN; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| // set input | |||||
| for (int i = 0; i < tflite_op->inputs.size(); i++) { | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| } | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -14,11 +14,12 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef LITE_TFLITE_ADDN_PARSER_H | |||||
| #define LITE_TFLITE_ADDN_PARSER_H | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADDN_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADDN_PARSER_H | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include <map> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| @@ -31,11 +32,12 @@ class TfliteAddNParser : public TfliteNodeParser { | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | ||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantized_model) override; | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // LITE_TFLITE_ADDN_PARSER_H | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADDN_PARSER_H | |||||
| @@ -17,16 +17,19 @@ | |||||
| #include "tools/converter/parser/tflite/tflite_argmax_parser.h" | #include "tools/converter/parser/tflite/tflite_argmax_parser.h" | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include <map> | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | schema::CNodeT *op, | ||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) { | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| MS_LOG(DEBUG) << "parse TfliteArgmaxParser"; | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -37,7 +40,6 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "parse TfliteArgmaxParser"; | |||||
| std::unique_ptr<schema::ArgMaxT> attr(new schema::ArgMaxT()); | std::unique_ptr<schema::ArgMaxT> attr(new schema::ArgMaxT()); | ||||
| attr->outMaxValue = false; | attr->outMaxValue = false; | ||||
| @@ -45,9 +47,10 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| attr->keepDims = false; | attr->keepDims = false; | ||||
| attr->axisType = 1; | attr->axisType = 1; | ||||
| auto axis_idx = tfliteOp->inputs[1]; | |||||
| std::for_each(tfliteTensors[axis_idx]->shape.begin(), tfliteTensors[axis_idx]->shape.end(), [&](int32_t sha){}); | |||||
| auto &buf_data = tfliteModelBuffer[tfliteTensors[axis_idx]->buffer]; | |||||
| // get axis attr | |||||
| auto axis_idx = tflite_op->inputs[1]; | |||||
| std::for_each(tflite_tensors[axis_idx]->shape.begin(), tflite_tensors[axis_idx]->shape.end(), [&](int32_t sha){}); | |||||
| auto &buf_data = tflite_model_buffer[tflite_tensors[axis_idx]->buffer]; | |||||
| if (buf_data == nullptr) { | if (buf_data == nullptr) { | ||||
| MS_LOG(ERROR) << "the buf data is null"; | MS_LOG(ERROR) << "the buf data is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -61,6 +64,11 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| op->primitive->value.type = schema::PrimitiveType_ArgMax; | op->primitive->value.type = schema::PrimitiveType_ArgMax; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -14,11 +14,12 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef PREDICT_TFLITE_ARGMAX_PARSER_H | |||||
| #define PREDICT_TFLITE_ARGMAX_PARSER_H | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMAX_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMAX_PARSER_H | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <map> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| @@ -28,14 +29,15 @@ class TfliteArgmaxParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteArgmaxParser() : TfliteNodeParser("Argmax") {} | TfliteArgmaxParser() : TfliteNodeParser("Argmax") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // PREDICT_TFLITE_ARGMAX_PARSER_H | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMAX_PARSER_H | |||||
| @@ -17,14 +17,19 @@ | |||||
| #include "tools/converter/parser/tflite/tflite_argmin_parser.h" | #include "tools/converter/parser/tflite/tflite_argmin_parser.h" | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include <map> | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| MS_LOG(DEBUG) << "parse TfliteArgminParser"; | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -35,7 +40,6 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "parse TfliteArgminParser"; | |||||
| std::unique_ptr<schema::ArgMinT> attr(new schema::ArgMinT()); | std::unique_ptr<schema::ArgMinT> attr(new schema::ArgMinT()); | ||||
| attr->outMaxValue = false; | attr->outMaxValue = false; | ||||
| @@ -43,9 +47,11 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| attr->keepDims = false; | attr->keepDims = false; | ||||
| attr->axisType = 1; | attr->axisType = 1; | ||||
| auto axis_idx = tfliteOp->inputs[1]; | |||||
| std::for_each(tfliteTensors[axis_idx]->shape.begin(), tfliteTensors[axis_idx]->shape.end(), [&](int32_t sha){}); | |||||
| auto &buf_data = tfliteModelBuffer[tfliteTensors[axis_idx]->buffer]; | |||||
| // get axis attr | |||||
| auto axis_idx = tflite_op->inputs[1]; | |||||
| std::for_each(tflite_tensors[axis_idx]->shape.begin(), | |||||
| tflite_tensors[axis_idx]->shape.end(), [&](int32_t sha){}); | |||||
| auto &buf_data = tflite_model_buffer[tflite_tensors[axis_idx]->buffer]; | |||||
| if (buf_data == nullptr) { | if (buf_data == nullptr) { | ||||
| MS_LOG(ERROR) << "the buf data is null"; | MS_LOG(ERROR) << "the buf data is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -59,6 +65,11 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| op->primitive->value.type = schema::PrimitiveType_ArgMin; | op->primitive->value.type = schema::PrimitiveType_ArgMin; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -14,11 +14,12 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef PREDICT_TFLITE_ARGMIN_PARSER_H | |||||
| #define PREDICT_TFLITE_ARGMIN_PARSER_H | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <map> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| @@ -28,14 +29,15 @@ class TfliteArgminParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteArgminParser() : TfliteNodeParser("Argmin") {} | TfliteArgminParser() : TfliteNodeParser("Argmin") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // PREDICT_TFLITE_ARGMIN_PARSER_H | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H | |||||
| @@ -18,14 +18,17 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <map> | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -37,124 +40,72 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> | |||||
| } | } | ||||
| std::vector<std::string> node_name_str; | std::vector<std::string> node_name_str; | ||||
| Split(op->name.data(), &node_name_str, "-"); | |||||
| Split(op->name, &node_name_str, "-"); | |||||
| const char *node_name = node_name_str.data()->c_str(); | const char *node_name = node_name_str.data()->c_str(); | ||||
| if (std::strcmp(node_name, "Add") == 0 | |||||
| || std::strcmp(node_name, "Sub") == 0 | |||||
| || std::strcmp(node_name, "Mul") == 0 | |||||
| || std::strcmp(node_name, "Div") == 0) { | |||||
| auto x_index = tfliteOp->inputs[0]; | |||||
| const auto &x_tensor = tfliteTensors[x_index]; | |||||
| if (x_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "the first input is null"; | |||||
| if (std::strcmp(node_name, "Add") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteAddParser"; | |||||
| std::unique_ptr<schema::AddT> attr(new schema::AddT()); | |||||
| const auto &tfliteAttr = tflite_op->builtin_options.AsAddOptions(); | |||||
| if (nullptr == tfliteAttr) { | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| auto &x_data = tfliteModelBuffer.at(x_tensor->buffer); | |||||
| if (x_data == nullptr) { | |||||
| MS_LOG(ERROR) << "the data of the first input is null"; | |||||
| attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); | |||||
| op->primitive->value.type = schema::PrimitiveType_Add; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } else if (std::strcmp(node_name, "Sub") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteSubParser"; | |||||
| std::unique_ptr<schema::SubT> attr(new schema::SubT()); | |||||
| const auto &tfliteAttr = tflite_op->builtin_options.AsSubOptions(); | |||||
| if (nullptr == tfliteAttr) { | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| if (!x_data->data.empty()) { | |||||
| std::vector<tflite::TensorT *> x_tensors{x_tensor.get()}; | |||||
| if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { | |||||
| MS_LOG(ERROR) << "parse the first tensor failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| auto y_index = tfliteOp->inputs[1]; | |||||
| const auto &y_tensor = tfliteTensors[y_index]; | |||||
| if (y_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "the second input is null"; | |||||
| attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); | |||||
| op->primitive->value.type = schema::PrimitiveType_Sub; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } else if (std::strcmp(node_name, "Mul") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteMulParser"; | |||||
| std::unique_ptr<schema::MulT> attr(new schema::MulT()); | |||||
| const auto &tfliteAttr = tflite_op->builtin_options.AsMulOptions(); | |||||
| if (nullptr == tfliteAttr) { | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| auto &y_data = tfliteModelBuffer.at(y_tensor->buffer); | |||||
| if (y_data == nullptr) { | |||||
| MS_LOG(ERROR) << "the data of the second input is null"; | |||||
| attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); | |||||
| op->primitive->value.type = schema::PrimitiveType_Mul; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } else if (std::strcmp(node_name, "Div") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteDivParser"; | |||||
| std::unique_ptr<schema::DivT> attr(new schema::DivT()); | |||||
| const auto &tfliteAttr = tflite_op->builtin_options.AsDivOptions(); | |||||
| if (nullptr == tfliteAttr) { | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| if (!y_data->data.empty()) { | |||||
| std::vector<tflite::TensorT *> y_tensors{y_tensor.get()}; | |||||
| if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { | |||||
| MS_LOG(ERROR) << "parse the second tensor failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| if (std::strcmp(node_name, "Add") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteAddParser"; | |||||
| std::unique_ptr<schema::AddT> attr(new schema::AddT()); | |||||
| const auto &tfliteAttr = tfliteOp->builtin_options.AsAddOptions(); | |||||
| if (nullptr == tfliteAttr) { | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); | |||||
| op->primitive->value.type = schema::PrimitiveType_Add; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Sub") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteSubParser"; | |||||
| std::unique_ptr<schema::SubT> attr(new schema::SubT()); | |||||
| const auto &tfliteAttr = tfliteOp->builtin_options.AsSubOptions(); | |||||
| if (nullptr == tfliteAttr) { | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); | |||||
| op->primitive->value.type = schema::PrimitiveType_Sub; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Mul") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteMulParser"; | |||||
| std::unique_ptr<schema::MulT> attr(new schema::MulT()); | |||||
| const auto &tfliteAttr = tfliteOp->builtin_options.AsMulOptions(); | |||||
| if (nullptr == tfliteAttr) { | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); | |||||
| op->primitive->value.type = schema::PrimitiveType_Mul; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Div") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteDivParser"; | |||||
| std::unique_ptr<schema::DivT> attr(new schema::DivT()); | |||||
| const auto &tfliteAttr = tfliteOp->builtin_options.AsDivOptions(); | |||||
| if (nullptr == tfliteAttr) { | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); | |||||
| op->primitive->value.type = schema::PrimitiveType_Div; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } | |||||
| attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); | |||||
| op->primitive->value.type = schema::PrimitiveType_Div; | |||||
| op->primitive->value.value = attr.release(); | |||||
| } else if (std::strcmp(node_name, "FloorDiv") == 0) { | } else if (std::strcmp(node_name, "FloorDiv") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteFloorDivParser"; | MS_LOG(DEBUG) << "parse TfliteFloorDivParser"; | ||||
| std::unique_ptr<schema::FloorDivT> attr(new schema::FloorDivT()); | std::unique_ptr<schema::FloorDivT> attr(new schema::FloorDivT()); | ||||
| op->primitive->value.type = schema::PrimitiveType_FloorDiv; | op->primitive->value.type = schema::PrimitiveType_FloorDiv; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "FloorMod") == 0) { | } else if (std::strcmp(node_name, "FloorMod") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteFloorModParser"; | MS_LOG(DEBUG) << "parse TfliteFloorModParser"; | ||||
| std::unique_ptr<schema::FloorModT> attr(new schema::FloorModT()); | std::unique_ptr<schema::FloorModT> attr(new schema::FloorModT()); | ||||
| op->primitive->value.type = schema::PrimitiveType_FloorMod; | op->primitive->value.type = schema::PrimitiveType_FloorMod; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "RealDiv") == 0) { | } else if (std::strcmp(node_name, "RealDiv") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteRealDivParser"; | MS_LOG(DEBUG) << "parse TfliteRealDivParser"; | ||||
| std::unique_ptr<schema::RealDivT> attr(new schema::RealDivT()); | std::unique_ptr<schema::RealDivT> attr(new schema::RealDivT()); | ||||
| op->primitive->value.type = schema::PrimitiveType_RealDiv; | |||||
| op->primitive->value.type = schema::PrimitiveType_Div; | |||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "SquaredDifference") == 0) { | } else if (std::strcmp(node_name, "SquaredDifference") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteSquaredDifferenceParser"; | MS_LOG(DEBUG) << "parse TfliteSquaredDifferenceParser"; | ||||
| std::unique_ptr<schema::SquaredDifferenceT> attr(new schema::SquaredDifferenceT()); | std::unique_ptr<schema::SquaredDifferenceT> attr(new schema::SquaredDifferenceT()); | ||||
| op->primitive->value.type = schema::PrimitiveType_SquaredDifference; | op->primitive->value.type = schema::PrimitiveType_SquaredDifference; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Pow") == 0) { | } else if (std::strcmp(node_name, "Pow") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TflitePowParser"; | MS_LOG(DEBUG) << "parse TflitePowParser"; | ||||
| std::unique_ptr<schema::PowerT> attr(new schema::PowerT()); | std::unique_ptr<schema::PowerT> attr(new schema::PowerT()); | ||||
| @@ -163,31 +114,35 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> | |||||
| attr->shift = 0.0f; | attr->shift = 0.0f; | ||||
| op->primitive->value.type = schema::PrimitiveType_Power; | op->primitive->value.type = schema::PrimitiveType_Power; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Maximum") == 0) { | } else if (std::strcmp(node_name, "Maximum") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteMaximumParser"; | MS_LOG(DEBUG) << "parse TfliteMaximumParser"; | ||||
| std::unique_ptr<schema::MaximumT> attr(new schema::MaximumT()); | std::unique_ptr<schema::MaximumT> attr(new schema::MaximumT()); | ||||
| op->primitive->value.type = schema::PrimitiveType_Maximum; | op->primitive->value.type = schema::PrimitiveType_Maximum; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Minimum") == 0) { | } else if (std::strcmp(node_name, "Minimum") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteMinimumParser"; | MS_LOG(DEBUG) << "parse TfliteMinimumParser"; | ||||
| std::unique_ptr<schema::MinimumT> attr(new schema::MinimumT()); | std::unique_ptr<schema::MinimumT> attr(new schema::MinimumT()); | ||||
| op->primitive->value.type = schema::PrimitiveType_Minimum; | op->primitive->value.type = schema::PrimitiveType_Minimum; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "wrong op type"; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| // set input | |||||
| for (int i = 0; i < tflite_op->inputs.size(); i++) { | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| } | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -199,85 +154,79 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> | |||||
| } | } | ||||
| std::vector<std::string> node_name_str; | std::vector<std::string> node_name_str; | ||||
| Split(op->name.data(), &node_name_str, "-"); | |||||
| Split(op->name, &node_name_str, "-"); | |||||
| const char *node_name = node_name_str.data()->c_str(); | const char *node_name = node_name_str.data()->c_str(); | ||||
| if (std::strcmp(node_name, "Abs") == 0) { | if (std::strcmp(node_name, "Abs") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteAbsParser"; | MS_LOG(DEBUG) << "parse TfliteAbsParser"; | ||||
| std::unique_ptr<schema::AbsT> attr(new schema::AbsT()); | std::unique_ptr<schema::AbsT> attr(new schema::AbsT()); | ||||
| op->primitive->value.type = schema::PrimitiveType_Abs; | op->primitive->value.type = schema::PrimitiveType_Abs; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Exp") == 0) { | } else if (std::strcmp(node_name, "Exp") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteExpParser"; | MS_LOG(DEBUG) << "parse TfliteExpParser"; | ||||
| std::unique_ptr<schema::ExpT> attr(new schema::ExpT()); | std::unique_ptr<schema::ExpT> attr(new schema::ExpT()); | ||||
| op->primitive->value.type = schema::PrimitiveType_Exp; | op->primitive->value.type = schema::PrimitiveType_Exp; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Sqrt") == 0) { | } else if (std::strcmp(node_name, "Sqrt") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteSqrtParser"; | MS_LOG(DEBUG) << "parse TfliteSqrtParser"; | ||||
| std::unique_ptr<schema::SqrtT> attr(new schema::SqrtT()); | std::unique_ptr<schema::SqrtT> attr(new schema::SqrtT()); | ||||
| op->primitive->value.type = schema::PrimitiveType_Sqrt; | op->primitive->value.type = schema::PrimitiveType_Sqrt; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Rsqrt") == 0) { | } else if (std::strcmp(node_name, "Rsqrt") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteRsqrtParser"; | MS_LOG(DEBUG) << "parse TfliteRsqrtParser"; | ||||
| std::unique_ptr<schema::RsqrtT> attr(new schema::RsqrtT()); | std::unique_ptr<schema::RsqrtT> attr(new schema::RsqrtT()); | ||||
| op->primitive->value.type = schema::PrimitiveType_Rsqrt; | op->primitive->value.type = schema::PrimitiveType_Rsqrt; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Square") == 0) { | } else if (std::strcmp(node_name, "Square") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteSquareParser"; | MS_LOG(DEBUG) << "parse TfliteSquareParser"; | ||||
| std::unique_ptr<schema::SquareT> attr(new schema::SquareT()); | std::unique_ptr<schema::SquareT> attr(new schema::SquareT()); | ||||
| op->primitive->value.type = schema::PrimitiveType_Square; | op->primitive->value.type = schema::PrimitiveType_Square; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Sin") == 0) { | } else if (std::strcmp(node_name, "Sin") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteSinParser"; | MS_LOG(DEBUG) << "parse TfliteSinParser"; | ||||
| std::unique_ptr<schema::SinT> attr(new schema::SinT()); | std::unique_ptr<schema::SinT> attr(new schema::SinT()); | ||||
| op->primitive->value.type = schema::PrimitiveType_Sin; | op->primitive->value.type = schema::PrimitiveType_Sin; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Cos") == 0) { | } else if (std::strcmp(node_name, "Cos") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteCosParser"; | MS_LOG(DEBUG) << "parse TfliteCosParser"; | ||||
| std::unique_ptr<schema::CosT> attr(new schema::CosT()); | std::unique_ptr<schema::CosT> attr(new schema::CosT()); | ||||
| op->primitive->value.type = schema::PrimitiveType_Cos; | op->primitive->value.type = schema::PrimitiveType_Cos; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Log") == 0) { | } else if (std::strcmp(node_name, "Log") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteLogParser"; | MS_LOG(DEBUG) << "parse TfliteLogParser"; | ||||
| std::unique_ptr<schema::LogT> attr(new schema::LogT()); | std::unique_ptr<schema::LogT> attr(new schema::LogT()); | ||||
| op->primitive->value.type = schema::PrimitiveType_Log; | op->primitive->value.type = schema::PrimitiveType_Log; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Round") == 0) { | } else if (std::strcmp(node_name, "Round") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteRoundParser"; | MS_LOG(DEBUG) << "parse TfliteRoundParser"; | ||||
| std::unique_ptr<schema::RoundT> attr(new schema::RoundT()); | std::unique_ptr<schema::RoundT> attr(new schema::RoundT()); | ||||
| op->primitive->value.type = schema::PrimitiveType_Round; | op->primitive->value.type = schema::PrimitiveType_Round; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Ceil") == 0) { | } else if (std::strcmp(node_name, "Ceil") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteCeilParser"; | MS_LOG(DEBUG) << "parse TfliteCeilParser"; | ||||
| std::unique_ptr<schema::CeilT> attr(new schema::CeilT()); | std::unique_ptr<schema::CeilT> attr(new schema::CeilT()); | ||||
| op->primitive->value.type = schema::PrimitiveType_Ceil; | op->primitive->value.type = schema::PrimitiveType_Ceil; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "flOOR") == 0) { | } else if (std::strcmp(node_name, "flOOR") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteFloorParser"; | MS_LOG(DEBUG) << "parse TfliteFloorParser"; | ||||
| std::unique_ptr<schema::FloorT> attr(new schema::FloorT()); | std::unique_ptr<schema::FloorT> attr(new schema::FloorT()); | ||||
| op->primitive->value.type = schema::PrimitiveType_Floor; | op->primitive->value.type = schema::PrimitiveType_Floor; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "wrong op type"; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| return RET_OK; | |||||
| } | } | ||||
| STATUS TfliteCompareOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| STATUS TfliteCompareOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -289,48 +238,47 @@ STATUS TfliteCompareOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf | |||||
| } | } | ||||
| std::vector<std::string> node_name_str; | std::vector<std::string> node_name_str; | ||||
| Split(op->name.data(), &node_name_str, "-"); | |||||
| Split(op->name, &node_name_str, "-"); | |||||
| const char *node_name = node_name_str.data()->c_str(); | const char *node_name = node_name_str.data()->c_str(); | ||||
| if (std::strcmp(node_name, "Equal") == 0) { | if (std::strcmp(node_name, "Equal") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteEqualParser"; | MS_LOG(DEBUG) << "parse TfliteEqualParser"; | ||||
| std::unique_ptr<schema::EqualT> attr(new schema::EqualT()); | std::unique_ptr<schema::EqualT> attr(new schema::EqualT()); | ||||
| op->primitive->value.type = schema::PrimitiveType_Equal; | op->primitive->value.type = schema::PrimitiveType_Equal; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "NotEqual") == 0) { | } else if (std::strcmp(node_name, "NotEqual") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteNotEqualParser"; | MS_LOG(DEBUG) << "parse TfliteNotEqualParser"; | ||||
| std::unique_ptr<schema::NotEqualT> attr(new schema::NotEqualT()); | std::unique_ptr<schema::NotEqualT> attr(new schema::NotEqualT()); | ||||
| op->primitive->value.type = schema::PrimitiveType_NotEqual; | op->primitive->value.type = schema::PrimitiveType_NotEqual; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Greater") == 0) { | } else if (std::strcmp(node_name, "Greater") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteGreaterParser"; | MS_LOG(DEBUG) << "parse TfliteGreaterParser"; | ||||
| std::unique_ptr<schema::GreaterT> attr(new schema::GreaterT()); | std::unique_ptr<schema::GreaterT> attr(new schema::GreaterT()); | ||||
| op->primitive->value.type = schema::PrimitiveType_Greater; | op->primitive->value.type = schema::PrimitiveType_Greater; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "GreaterEqual") == 0) { | } else if (std::strcmp(node_name, "GreaterEqual") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteGreaterEqualParser"; | MS_LOG(DEBUG) << "parse TfliteGreaterEqualParser"; | ||||
| std::unique_ptr<schema::GreaterEqualT> attr(new schema::GreaterEqualT()); | std::unique_ptr<schema::GreaterEqualT> attr(new schema::GreaterEqualT()); | ||||
| op->primitive->value.type = schema::PrimitiveType_GreaterEqual; | op->primitive->value.type = schema::PrimitiveType_GreaterEqual; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "Less") == 0) { | } else if (std::strcmp(node_name, "Less") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteLessParser"; | MS_LOG(DEBUG) << "parse TfliteLessParser"; | ||||
| std::unique_ptr<schema::LessT> attr(new schema::LessT()); | std::unique_ptr<schema::LessT> attr(new schema::LessT()); | ||||
| op->primitive->value.type = schema::PrimitiveType_Less; | op->primitive->value.type = schema::PrimitiveType_Less; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | |||||
| } else if (std::strcmp(node_name, "LessEqual") == 0) { | } else if (std::strcmp(node_name, "LessEqual") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteLessEqualParser"; | MS_LOG(DEBUG) << "parse TfliteLessEqualParser"; | ||||
| std::unique_ptr<schema::LessEqualT> attr(new schema::LessEqualT()); | std::unique_ptr<schema::LessEqualT> attr(new schema::LessEqualT()); | ||||
| op->primitive->value.type = schema::PrimitiveType_LessEqual; | op->primitive->value.type = schema::PrimitiveType_LessEqual; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "wrong op type"; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| for (int i = 0; i < tflite_op->inputs.size(); i++) { | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| } | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| return RET_OK; | |||||
| } | } | ||||
| TfliteNodeRegister g_tfliteAddParser("Add", new TfliteAddParser()); | TfliteNodeRegister g_tfliteAddParser("Add", new TfliteAddParser()); | ||||
| @@ -14,11 +14,12 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef PREDICT_TFLITE_MATH_PARSER_H | |||||
| #define PREDICT_TFLITE_MATH_PARSER_H | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARITHMETIC_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARITHMETIC_PARSER_H | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include <map> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| @@ -29,11 +30,13 @@ class TfliteDoubleInputOpParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {} | TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantizedModel) override; | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| }; | }; | ||||
| class TfliteAddParser : public TfliteDoubleInputOpParser { | class TfliteAddParser : public TfliteDoubleInputOpParser { | ||||
| @@ -96,11 +99,13 @@ class TfliteSingleInputOpParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {} | TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantizedModel) override; | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| }; | }; | ||||
| class TfliteAbsParser : public TfliteSingleInputOpParser { | class TfliteAbsParser : public TfliteSingleInputOpParser { | ||||
| @@ -163,11 +168,13 @@ class TfliteCompareOpParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteCompareOpParser() : TfliteNodeParser("node_name") {} | TfliteCompareOpParser() : TfliteNodeParser("node_name") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantizedModel) override; | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| }; | }; | ||||
| class TfliteEqualParser : public TfliteCompareOpParser { | class TfliteEqualParser : public TfliteCompareOpParser { | ||||
| @@ -203,5 +210,5 @@ class TfliteLessEqualParser : public TfliteCompareOpParser { | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // PREDICT_TFLITE_MATH_PARSER_H | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARITHMETIC_PARSER_H | |||||
| @@ -19,14 +19,17 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <map> | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -38,30 +41,32 @@ STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> | |||||
| } | } | ||||
| std::vector<std::string> node_name_str; | std::vector<std::string> node_name_str; | ||||
| Split(op->name.data(), &node_name_str, "-"); | |||||
| Split(op->name, &node_name_str, "-"); | |||||
| const char *node_name = node_name_str.data()->c_str(); | const char *node_name = node_name_str.data()->c_str(); | ||||
| if (std::strcmp(node_name, "BatchToSpace") == 0) { | if (std::strcmp(node_name, "BatchToSpace") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteBatchToSpaceParser"; | MS_LOG(DEBUG) << "parse TfliteBatchToSpaceParser"; | ||||
| } else if (std::strcmp(node_name, "BatchToSpaceND") == 0) { | } else if (std::strcmp(node_name, "BatchToSpaceND") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteBatchToSpaceNDParser"; | MS_LOG(DEBUG) << "parse TfliteBatchToSpaceNDParser"; | ||||
| // in tflite | |||||
| // blockShape should be a 1D tensor with dimension [spatial_dims_num] | |||||
| // crops should be a 2D tensor with dimension [spatial_dims_num, 2] | |||||
| } | } | ||||
| std::unique_ptr<schema::BatchToSpaceT> attr(new schema::BatchToSpaceT()); | std::unique_ptr<schema::BatchToSpaceT> attr(new schema::BatchToSpaceT()); | ||||
| if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->blockShape)) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->blockShape)) { | |||||
| MS_LOG(ERROR) << "get batchToSpace -> blockShape failed"; | MS_LOG(ERROR) << "get batchToSpace -> blockShape failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (GetTfliteData(tfliteOp->inputs[2], tfliteTensors, tfliteModelBuffer, attr->crops)) { | |||||
| if (GetTfliteData(tflite_op->inputs[2], tflite_tensors, tflite_model_buffer, attr->crops)) { | |||||
| MS_LOG(ERROR) << "get batchToSpace -> crops failed"; | MS_LOG(ERROR) << "get batchToSpace -> crops failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_BatchToSpace; | op->primitive->value.type = schema::PrimitiveType_BatchToSpace; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -14,11 +14,12 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef LITE_TFLITE_BATCH_TO_SPACE_PARSER_H | |||||
| #define LITE_TFLITE_BATCH_TO_SPACE_PARSER_H | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include <map> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| @@ -31,8 +32,10 @@ class TfliteBatchToSpaceParser : public TfliteNodeParser { | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | ||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantized_model) override; | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| }; | }; | ||||
| class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser { | class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser { | ||||
| @@ -43,4 +46,4 @@ class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser { | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // LITE_TFLITE_BATCH_TO_SPACE_PARSER_H | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H | |||||
| @@ -18,14 +18,19 @@ | |||||
| #include "tools/converter/parser/tflite/tflite_broadcast_to_parser.h" | #include "tools/converter/parser/tflite/tflite_broadcast_to_parser.h" | ||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <map> | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| MS_LOG(DEBUG) << "parse TfliteBroadcastToParser"; | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -36,16 +41,20 @@ STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> & | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "parse TfliteBroadcastToParser"; | |||||
| std::unique_ptr<schema::BroadcastToT> attr(new schema::BroadcastToT()); | std::unique_ptr<schema::BroadcastToT> attr(new schema::BroadcastToT()); | ||||
| if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->dst_shape)) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->dst_shape)) { | |||||
| MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed"; | MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_BroadcastTo; | op->primitive->value.type = schema::PrimitiveType_BroadcastTo; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -14,11 +14,12 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef LITE_TFLITE_BROADCAST_TO_PARSER_H | |||||
| #define LITE_TFLITE_BROADCAST_TO_PARSER_H | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include <map> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| @@ -31,11 +32,12 @@ class TfliteBroadcastToParser : public TfliteNodeParser { | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | ||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantized_model) override; | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // LITE_TFLITE_BROADCAST_TO_PARSER_H | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H | |||||
| @@ -14,18 +14,22 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "tools/converter/parser/tflite/tflite_cast_parser.h" | #include "tools/converter/parser/tflite/tflite_cast_parser.h" | ||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <map> | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| MS_LOG(DEBUG) << "parse TfliteCastParser"; | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -36,25 +40,28 @@ STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "parse TfliteCastParser"; | |||||
| std::unique_ptr<schema::CastT> attr(new schema::CastT()); | std::unique_ptr<schema::CastT> attr(new schema::CastT()); | ||||
| const auto &in_tensor = tfliteTensors[tfliteOp->inputs[0]]; | |||||
| const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]]; | |||||
| if (in_tensor == nullptr) { | if (in_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "tensor is null"; | MS_LOG(ERROR) << "tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| attr->srcT = dtype_map[in_tensor->type]; | |||||
| const auto &out_tensor = tfliteTensors[tfliteOp->outputs[0]]; | |||||
| attr->srcT = GetTfliteDataType(in_tensor->type); | |||||
| const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]]; | |||||
| if (out_tensor == nullptr) { | if (out_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "tensor is null"; | MS_LOG(ERROR) << "tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| attr->dstT = dtype_map[out_tensor->type]; | |||||
| attr->dstT = GetTfliteDataType(out_tensor->type); | |||||
| op->primitive->value.type = schema::PrimitiveType_Cast; | op->primitive->value.type = schema::PrimitiveType_Cast; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -14,11 +14,12 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef LITE_TFLITE_CAST_PARSER_ | |||||
| #define LITE_TFLITE_CAST_PARSER_H | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CAST_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CAST_PARSER_H | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include <map> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| @@ -31,11 +32,12 @@ class TfliteCastParser : public TfliteNodeParser { | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | ||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantized_model) override; | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // LITE_TFLITE_CAST_PARSER_H | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CAST_PARSER_H | |||||
| @@ -17,14 +17,20 @@ | |||||
| #include "tools/converter/parser/tflite/tflite_concat_parser.h" | #include "tools/converter/parser/tflite/tflite_concat_parser.h" | ||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <map> | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| MS_LOG(DEBUG) << "parse TfliteConcatParser"; | |||||
| // set attr | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -35,20 +41,25 @@ STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "parse TfliteConcatParser"; | |||||
| std::unique_ptr<schema::ConcatT> attr(new schema::ConcatT()); | std::unique_ptr<schema::ConcatT> attr(new schema::ConcatT()); | ||||
| const auto &tfliteAttr = tfliteOp->builtin_options.AsConcatenationOptions(); | |||||
| const auto &tfliteAttr = tflite_op->builtin_options.AsConcatenationOptions(); | |||||
| if (tfliteAttr == nullptr) { | if (tfliteAttr == nullptr) { | ||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| attr->axis = tfliteAttr->axis; | attr->axis = tfliteAttr->axis; | ||||
| attr->n = tfliteOp->inputs.size(); | |||||
| attr->n = tflite_op->inputs.size(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Concat; | op->primitive->value.type = schema::PrimitiveType_Concat; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| for (int i = 0; i < tflite_op->inputs.size(); i++) { | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| } | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -14,11 +14,12 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef PREDICT_TFLITE_CONCAT_PARSER_H | |||||
| #define PREDICT_TFLITE_CONCAT_PARSER_H | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONCAT_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONCAT_PARSER_H | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <map> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| @@ -28,15 +29,16 @@ class TfliteConcatParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteConcatParser() : TfliteNodeParser("Concat") {} | TfliteConcatParser() : TfliteNodeParser("Concat") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // PREDICT_TFLITE_CONCAT_PARSER_H | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONCAT_PARSER_H | |||||
| @@ -17,14 +17,19 @@ | |||||
| #include "tools/converter/parser/tflite/tflite_conv_parser.h" | #include "tools/converter/parser/tflite/tflite_conv_parser.h" | ||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <map> | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| MS_LOG(DEBUG) << "parse TfliteConvParser"; | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -35,60 +40,61 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "parse TfliteConvParser"; | |||||
| std::unique_ptr<schema::Conv2DT> attr(new schema::Conv2DT()); | std::unique_ptr<schema::Conv2DT> attr(new schema::Conv2DT()); | ||||
| const auto &tfliteAttr = tfliteOp->builtin_options.AsConv2DOptions(); | |||||
| if (tfliteAttr == nullptr) { | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsConv2DOptions(); | |||||
| if (tflite_attr == nullptr) { | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| attr->group = 1; | attr->group = 1; | ||||
| attr->strideW = tfliteAttr->stride_w; | |||||
| attr->strideH = tfliteAttr->stride_h; | |||||
| attr->dilateH = tfliteAttr->dilation_h_factor; | |||||
| attr->dilateW = tfliteAttr->dilation_w_factor; | |||||
| attr->padMode = GetPadMode(tfliteAttr->padding); | |||||
| attr->strideW = tflite_attr->stride_w; | |||||
| attr->strideH = tflite_attr->stride_h; | |||||
| attr->dilateH = tflite_attr->dilation_h_factor; | |||||
| attr->dilateW = tflite_attr->dilation_w_factor; | |||||
| attr->padMode = GetPadMode(tflite_attr->padding); | |||||
| attr->format = schema::Format_NHWC; | attr->format = schema::Format_NHWC; | ||||
| attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); | |||||
| attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); | |||||
| attr->hasBias = true; | |||||
| // get the conv op weight tensor | // get the conv op weight tensor | ||||
| auto weight_index = tfliteOp->inputs[1]; | |||||
| const auto &weight_tensor = tfliteTensors[weight_index]; | |||||
| auto weight_index = tflite_op->inputs[1]; | |||||
| const auto &weight_tensor = tflite_tensors[weight_index]; | |||||
| if (weight_tensor == nullptr) { | if (weight_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "weight_tensor is null"; | |||||
| MS_LOG(ERROR) << "the weight tensor is null"; | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | |||||
| if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { | |||||
| MS_LOG(ERROR) << "parse weight failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto weight_shape = weight_tensor->shape; | auto weight_shape = weight_tensor->shape; | ||||
| attr->channelIn = weight_shape[KHWC_C]; | |||||
| attr->channelOut = weight_shape[KHWC_K]; | |||||
| attr->kernelW = weight_shape[KHWC_W]; | |||||
| attr->kernelH = weight_shape[KHWC_H]; | |||||
| // get the conv op bias tensor | |||||
| if (tfliteOp->inputs.size() == 3) { | |||||
| attr->hasBias = true; | |||||
| auto bias_index = tfliteOp->inputs[2]; | |||||
| const auto &bias_tensor = tfliteTensors[bias_index]; | |||||
| if (bias_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "bias_tensor is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()}; | |||||
| if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { | |||||
| MS_LOG(ERROR) << "parse bias failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| attr->channelIn = weight_shape[3]; | |||||
| attr->channelOut = weight_shape[0]; | |||||
| attr->kernelH = weight_shape[1]; | |||||
| attr->kernelW = weight_shape[2]; | |||||
| // calculate pad params | // calculate pad params | ||||
| auto data_index = tflite_op->inputs[0]; | |||||
| const auto &data_tensor = tflite_tensors[data_index]; | |||||
| std::vector<int> params; | |||||
| if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, | |||||
| attr->strideW, attr->kernelH, attr->kernelW, ¶ms) != RET_OK) { | |||||
| MS_LOG(ERROR) << "get padding params failed"; | |||||
| return RET_ERROR; | |||||
| } else { | |||||
| attr->padUp = params.at(0); | |||||
| attr->padDown = params.at(1); | |||||
| attr->padLeft = params.at(2); | |||||
| attr->padRight = params.at(3); | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_Conv2D; | op->primitive->value.type = schema::PrimitiveType_Conv2D; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC); | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -14,11 +14,12 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef PREDICT_TFLITE_CONV_PARSER_H | |||||
| #define PREDICT_TFLITE_CONV_PARSER_H | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include <map> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| @@ -28,15 +29,16 @@ class TfliteConvParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteConvParser() : TfliteNodeParser("Conv2D") {} | TfliteConvParser() : TfliteNodeParser("Conv2D") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // PREDICT_TFLITE_CONV_PARSER_H | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H | |||||
| @@ -19,6 +19,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include <map> | |||||
| #include "tools/converter/converter.h" | #include "tools/converter/converter.h" | ||||
| #include "tools/converter/parser/tflite/tflite_model_parser.h" | #include "tools/converter/parser/tflite/tflite_model_parser.h" | ||||
| #include "tools/converter/graphdef_transform.h" | #include "tools/converter/graphdef_transform.h" | ||||
| @@ -17,14 +17,19 @@ | |||||
| #include "tools/converter/parser/tflite/tflite_deconv_parser.h" | #include "tools/converter/parser/tflite/tflite_deconv_parser.h" | ||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <map> | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser"; | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -35,11 +40,10 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser"; | |||||
| std::unique_ptr<schema::DeConv2DT> attr(new schema::DeConv2DT()); | std::unique_ptr<schema::DeConv2DT> attr(new schema::DeConv2DT()); | ||||
| const auto &tflite_attr = tfliteOp->builtin_options.AsTransposeConvOptions(); | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsTransposeConvOptions(); | |||||
| if (tflite_attr == nullptr) { | if (tflite_attr == nullptr) { | ||||
| MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str(); | |||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| @@ -50,26 +54,48 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| attr->dilateW = 1; | attr->dilateW = 1; | ||||
| attr->padMode = GetPadMode(tflite_attr->padding); | attr->padMode = GetPadMode(tflite_attr->padding); | ||||
| attr->format = schema::Format_NHWC; | attr->format = schema::Format_NHWC; | ||||
| attr->activationType = schema::ActivationType_NO_ACTIVATION; | |||||
| attr->hasBias = true; | |||||
| // get the conv op weight tensor | // get the conv op weight tensor | ||||
| auto weight_index = tfliteOp->inputs[1]; | |||||
| const auto &weight_tensor = tfliteTensors[weight_index]; | |||||
| auto weight_index = tflite_op->inputs[1]; | |||||
| const auto &weight_tensor = tflite_tensors[weight_index]; | |||||
| if (weight_tensor == nullptr) { | if (weight_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "weight_tensor is null"; | |||||
| MS_LOG(ERROR) << "the weight tensor is null"; | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | |||||
| if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { | |||||
| auto weight_shape = weight_tensor->shape; | |||||
| attr->channelIn = weight_shape[3]; | |||||
| attr->channelOut = weight_shape[0]; | |||||
| attr->kernelH = weight_shape[1]; | |||||
| attr->kernelW = weight_shape[2]; | |||||
| // calculate pad params | |||||
| auto data_index = tflite_op->inputs[2]; | |||||
| const auto &data_tensor = tflite_tensors[data_index]; | |||||
| std::vector<int> params; | |||||
| if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, | |||||
| attr->strideW, attr->kernelH, attr->kernelW, ¶ms) != RET_OK) { | |||||
| MS_LOG(ERROR) << "get padding params failed"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } else { | |||||
| attr->padUp = params.at(0); | |||||
| attr->padDown = params.at(1); | |||||
| attr->padLeft = params.at(2); | |||||
| attr->padRight = params.at(3); | |||||
| } | } | ||||
| auto weight_shape = weight_tensor->shape; | |||||
| attr->channelIn = weight_shape[CHWK_K]; | |||||
| attr->channelOut = weight_shape[CHWK_C]; | |||||
| attr->kernelW = weight_shape[CHWK_W]; | |||||
| attr->kernelH = weight_shape[CHWK_H]; | |||||
| op->primitive->value.type = schema::PrimitiveType_DeConv2D; | op->primitive->value.type = schema::PrimitiveType_DeConv2D; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC); | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -14,11 +14,12 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef PREDICT_TFLITE_DECONV_PARSER_H | |||||
| #define PREDICT_TFLITE_DECONV_PARSER_H | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include <map> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| @@ -31,11 +32,12 @@ class TfliteDeConvParser : public TfliteNodeParser { | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | ||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_op_set, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // PREDICT_TFLITE_DECONV_PARSER_H | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H | |||||
| @@ -18,14 +18,19 @@ | |||||
| #include "tools/converter/parser/tflite/tflite_depth_to_space_parser.h" | #include "tools/converter/parser/tflite/tflite_depth_to_space_parser.h" | ||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <map> | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser"; | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -36,20 +41,23 @@ STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser"; | |||||
| std::unique_ptr<schema::DepthToSpaceT> attr(new schema::DepthToSpaceT()); | std::unique_ptr<schema::DepthToSpaceT> attr(new schema::DepthToSpaceT()); | ||||
| const auto &tflite_attr = tfliteOp->builtin_options.AsDepthToSpaceOptions(); | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsDepthToSpaceOptions(); | |||||
| if (tflite_attr == nullptr) { | if (tflite_attr == nullptr) { | ||||
| MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str(); | MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str(); | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| attr->blockSize = tflite_attr->block_size; | attr->blockSize = tflite_attr->block_size; | ||||
| attr->format = schema::Format_NHWC; | attr->format = schema::Format_NHWC; | ||||
| op->primitive->value.type = schema::PrimitiveType_DepthToSpace; | op->primitive->value.type = schema::PrimitiveType_DepthToSpace; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -14,11 +14,12 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H | |||||
| #define LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTH_TO_SPACE_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTH_TO_SPACE_PARSER_H | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include <map> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| @@ -31,11 +32,12 @@ class TfliteDepthToSpaceParser : public TfliteNodeParser { | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | ||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | ||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantized_model) override; | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTH_TO_SPACE_PARSER_H | |||||
| @@ -17,65 +17,22 @@ | |||||
| #include "tools/converter/parser/tflite/tflite_depthwise_conv_parser.h" | #include "tools/converter/parser/tflite/tflite_depthwise_conv_parser.h" | ||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <map> | |||||
| #include "tools/common/node_util.h" | #include "tools/common/node_util.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteDepthwiseConv2DParser::ParseGroupDepthwiseConv(schema::CNodeT *op, | |||||
| const std::unique_ptr<schema::DepthwiseConv2DT> &attr, | |||||
| const std::unique_ptr<tflite::TensorT> &weightTensor, | |||||
| TensorCache *tensor_cache) { | |||||
| std::unique_ptr<schema::Conv2DT> convAttr(new schema::Conv2DT); | |||||
| convAttr->format = attr->format; | |||||
| convAttr->channelIn = attr->channelIn; | |||||
| convAttr->channelOut = attr->channelIn * attr->channelMultiplier; | |||||
| convAttr->kernelH = attr->kernelH; | |||||
| convAttr->kernelW = attr->kernelW; | |||||
| convAttr->strideH = attr->strideH; | |||||
| convAttr->strideW = attr->strideW; | |||||
| convAttr->padMode = attr->padMode; | |||||
| convAttr->padUp = attr->padUp; | |||||
| convAttr->padDown = attr->padDown; | |||||
| convAttr->padLeft = attr->padLeft; | |||||
| convAttr->padRight = attr->padRight; | |||||
| convAttr->dilateH = attr->dilateH; | |||||
| convAttr->dilateW = attr->dilateW; | |||||
| convAttr->hasBias = attr->hasBias; | |||||
| convAttr->activationType = attr->activationType; | |||||
| auto weightTensorIndex = tensor_cache->FindTensor(weightTensor->name); | |||||
| if (weightTensorIndex >= 0 && weightTensorIndex < tensor_cache->GetCachedTensor().size()) { | |||||
| auto liteWeightTensor = tensor_cache->GetCachedTensor()[weightTensorIndex]; | |||||
| if (liteWeightTensor->dataType == TypeId::kNumberTypeUInt8) { | |||||
| // convert weight format KHWC -> CHWK | |||||
| auto status = TransFilterFormat<uint8_t>(liteWeightTensor, kKHWC2CHWK); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| if (liteWeightTensor->dataType == kNumberTypeFloat32 || liteWeightTensor->dataType == kNumberTypeFloat) { | |||||
| // convert weight format KHWC -> CHWK | |||||
| auto status = TransFilterFormat<float>(liteWeightTensor, kKHWC2CHWK); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_Conv2D; | |||||
| op->primitive->value.value = convAttr.release(); | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | ||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser"; | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -86,7 +43,6 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser"; | |||||
| std::unique_ptr<schema::DepthwiseConv2DT> attr(new schema::DepthwiseConv2DT()); | std::unique_ptr<schema::DepthwiseConv2DT> attr(new schema::DepthwiseConv2DT()); | ||||
| const auto &tflite_attr = tflite_op->builtin_options.AsDepthwiseConv2DOptions(); | const auto &tflite_attr = tflite_op->builtin_options.AsDepthwiseConv2DOptions(); | ||||
| if (tflite_attr == nullptr) { | if (tflite_attr == nullptr) { | ||||
| @@ -100,15 +56,20 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator | |||||
| attr->padMode = GetPadMode(tflite_attr->padding); | attr->padMode = GetPadMode(tflite_attr->padding); | ||||
| attr->format = schema::Format_NHWC; | attr->format = schema::Format_NHWC; | ||||
| attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); | attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); | ||||
| // get the conv op weight tensor | |||||
| auto input_index = tflite_op->inputs[0]; | |||||
| const auto &input_tenosr = tflite_tensors[input_index]; | |||||
| if (input_tenosr == nullptr) { | |||||
| MS_LOG(ERROR) << "the first input is null"; | |||||
| attr->hasBias = true; | |||||
| attr->channelMultiplier = tflite_attr->depth_multiplier; | |||||
| // get the data tensor | |||||
| auto data_index = tflite_op->inputs[1]; | |||||
| const auto &data_tensor = tflite_tensors[data_index]; | |||||
| if (data_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "the data tensor is null"; | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| auto input_shape = input_tenosr->shape; | |||||
| auto data_shape = data_tensor->shape; | |||||
| attr->channelIn = data_shape[3]; | |||||
| // get the weight tensor | |||||
| auto weight_index = tflite_op->inputs[1]; | auto weight_index = tflite_op->inputs[1]; | ||||
| const auto &weight_tensor = tflite_tensors[weight_index]; | const auto &weight_tensor = tflite_tensors[weight_index]; | ||||
| if (weight_tensor == nullptr) { | if (weight_tensor == nullptr) { | ||||
| @@ -116,38 +77,33 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| auto weight_shape = weight_tensor->shape; | auto weight_shape = weight_tensor->shape; | ||||
| attr->channelIn = input_shape[KHWC_C]; | |||||
| attr->channelMultiplier = tflite_attr->depth_multiplier; | |||||
| attr->kernelH = weight_shape[KHWC_H]; | |||||
| attr->kernelW = weight_shape[KHWC_W]; | |||||
| std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | |||||
| if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { | |||||
| MS_LOG(ERROR) << "parse weight failed"; | |||||
| attr->kernelH = weight_shape[1]; | |||||
| attr->kernelW = weight_shape[2]; | |||||
| // calculate pad params | |||||
| std::vector<int> params; | |||||
| if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, | |||||
| attr->kernelH, attr->kernelW, ¶ms) != RET_OK) { | |||||
| MS_LOG(ERROR) << "get padding params failed"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | |||||
| if (tflite_op->inputs.size() == 3) { | |||||
| attr->hasBias = true; | |||||
| auto bias_index = tflite_op->inputs[2]; | |||||
| const auto &bias_tensor = tflite_tensors[bias_index]; | |||||
| std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()}; | |||||
| if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { | |||||
| MS_LOG(ERROR) << "parse bias failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| if (attr->channelMultiplier > 1) { | |||||
| if (RET_OK != ParseGroupDepthwiseConv(op, attr, weight_tensor, tensor_cache)) { | |||||
| MS_LOG(ERROR) << "Parse Group DepthwiseConv failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } else { | } else { | ||||
| op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||||
| op->primitive->value.value = attr.release(); | |||||
| attr->padUp = params.at(0); | |||||
| attr->padDown = params.at(1); | |||||
| attr->padLeft = params.at(2); | |||||
| attr->padRight = params.at(3); | |||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||||
| op->primitive->value.value = attr.release(); | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC); | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -14,11 +14,12 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef PREDICT_TFLITE_DEPTHWISE_CONV_PARSER_H | |||||
| #define PREDICT_TFLITE_DEPTHWISE_CONV_PARSER_H | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTHWISE_CONV_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTHWISE_CONV_PARSER_H | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <map> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| @@ -28,20 +29,16 @@ class TfliteDepthwiseConv2DParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {} | TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) override; | |||||
| private: | |||||
| STATUS ParseGroupDepthwiseConv(schema::CNodeT *op, | |||||
| const std::unique_ptr<schema::DepthwiseConv2DT> &attr, | |||||
| const std::unique_ptr<tflite::TensorT> &weightTensor, | |||||
| TensorCache *tensor_cache); | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // PREDICT_TFLITE_CONV_PARSER_H | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H | |||||
| @@ -16,15 +16,20 @@ | |||||
| #include "tools/converter/parser/tflite/tflite_dequantize_parser.h" | #include "tools/converter/parser/tflite/tflite_dequantize_parser.h" | ||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <map> | |||||
| #include "tools/common/node_util.h" | #include "tools/common/node_util.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| MS_LOG(DEBUG) << "parse TfliteDequantizeNParser"; | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -35,32 +40,30 @@ STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &t | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "parse TfliteDequantizeNParser"; | |||||
| std::unique_ptr<schema::CastT> attr(new schema::CastT); | std::unique_ptr<schema::CastT> attr(new schema::CastT); | ||||
| // get the dequantize input tensor | // get the dequantize input tensor | ||||
| const auto &in_tensor = tfliteTensors[tfliteOp->inputs[0]]; | |||||
| const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]]; | |||||
| if (in_tensor == nullptr) { | if (in_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "weight_tensor is null"; | |||||
| MS_LOG(ERROR) << "input tensor is null"; | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| attr->srcT = dtype_map[in_tensor->type]; | |||||
| const auto &out_tensor = tfliteTensors[tfliteOp->outputs[0]]; | |||||
| attr->srcT = GetTfliteDataType(in_tensor->type); | |||||
| const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]]; | |||||
| if (out_tensor == nullptr) { | if (out_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "tensor is null"; | |||||
| MS_LOG(ERROR) << "output tensor is null"; | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| attr->dstT = dtype_map[out_tensor->type]; | |||||
| std::vector<tflite::TensorT *> weight_tensors{in_tensor.get()}; | |||||
| if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { | |||||
| MS_LOG(ERROR) << "parse weight failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->dstT = GetTfliteDataType(out_tensor->type); | |||||
| op->primitive->value.type = schema::PrimitiveType_Fp16Cast; | op->primitive->value.type = schema::PrimitiveType_Fp16Cast; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return 0; | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| return RET_OK; | |||||
| } | } | ||||
| TfliteNodeRegister g_tfliteDequantizeParser("DEQUANTIZE", new TfliteDequantizeParser()); | TfliteNodeRegister g_tfliteDequantizeParser("DEQUANTIZE", new TfliteDequantizeParser()); | ||||
| @@ -13,11 +13,12 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef LITE_TFLITE_DEQUANTIZE_PARSER_H | |||||
| #define LITE_TFLITE_DEQUANTIZE_PARSER_H | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEQUANTIZE_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEQUANTIZE_PARSER_H | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <map> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| @@ -27,13 +28,15 @@ class TfliteDequantizeParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteDequantizeParser() : TfliteNodeParser("Dequantize") {} | TfliteDequantizeParser() : TfliteNodeParser("Dequantize") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantizedModel) override; | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // LITE_TFLITE_DEQUANTIZE_PARSER_H | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEQUANTIZE_PARSER_H | |||||
| @@ -17,16 +17,17 @@ | |||||
| #include "tools/converter/parser/tflite/tflite_expand_dims_parser.h" | #include "tools/converter/parser/tflite/tflite_expand_dims_parser.h" | ||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <map> | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) { | |||||
| STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -40,7 +41,7 @@ STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &t | |||||
| MS_LOG(DEBUG) << "parse TfliteExpandDimsParser"; | MS_LOG(DEBUG) << "parse TfliteExpandDimsParser"; | ||||
| std::unique_ptr<schema::ExpandDimsT> attr(new schema::ExpandDimsT()); | std::unique_ptr<schema::ExpandDimsT> attr(new schema::ExpandDimsT()); | ||||
| const auto &tflite_attr = tfliteOp->builtin_options.AsExpandDimsOptions(); | |||||
| const auto &tflite_attr = tflite_op->builtin_options.AsExpandDimsOptions(); | |||||
| if (tflite_attr == nullptr) { | if (tflite_attr == nullptr) { | ||||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -14,11 +14,12 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef PREDICT_TFLITE_EXPAND_DIMS_PARSER_H | |||||
| #define PREDICT_TFLITE_EXPAND_DIMS_PARSER_H | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_EXPAND_DIMS_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_EXPAND_DIMS_PARSER_H | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include <map> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| @@ -28,15 +29,16 @@ class TfliteExpandDimsParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteExpandDimsParser() : TfliteNodeParser("ExpandDims") {} | TfliteExpandDimsParser() : TfliteNodeParser("ExpandDims") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) override; | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // PREDICT_TFLITE_EXPAND_DIMS_PARSER_H | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_EXPAND_DIMS_PARSER_H | |||||
| @@ -1,75 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tflite/tflite_fakequant_parser.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser"; | |||||
| std::unique_ptr<schema::FullConnectionT> attr(new schema::FullConnectionT()); | |||||
| auto weight_index = tfliteOp->inputs[1]; | |||||
| const auto &weight_tensor = tfliteTensors[weight_index]; | |||||
| if (weight_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "weight_tensor is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | |||||
| if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { | |||||
| MS_LOG(ERROR) << "parse weight failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (tfliteOp->inputs.size() == 3) { | |||||
| attr->hasBias = true; | |||||
| auto bias_index = tfliteOp->inputs[2]; | |||||
| const auto &bias_tensor = tfliteTensors[bias_index]; | |||||
| if (bias_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "bias_tensor is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()}; | |||||
| if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { | |||||
| MS_LOG(ERROR) << "parse bias failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| attr->axis = 1; | |||||
| op->primitive->value.type = schema::PrimitiveType_FullConnection; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } | |||||
| TfliteNodeRegister g_tfliteFakeQuantParser("FakeQuant", new TfliteFakeQuantParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,39 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_TFLITE_FAKEQUANT_PARSER_H | |||||
| #define LITE_TFLITE_FAKEQUANT_PARSER_H | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteFakeQuantParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteFakeQuantParser() : TfliteNodeParser("FakeQuant") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_TFLITE_FAKEQUANT_PARSER_H | |||||
| @@ -14,19 +14,22 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "tools/converter/parser/tflite/tflite_fill_parser.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "tools/converter/parser/tflite/tflite_fill_parser.h" | |||||
| #include <map> | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, | |||||
| bool quantizedModel) { | |||||
| STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| MS_LOG(DEBUG) << "parse TfliteFillParser"; | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -37,18 +40,22 @@ STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "parse TfliteFillParser"; | |||||
| std::unique_ptr<schema::FillT> attr(new schema::FillT()); | std::unique_ptr<schema::FillT> attr(new schema::FillT()); | ||||
| if (tfliteOp->inputs.size() > 1) { | |||||
| if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->dims)) { | |||||
| MS_LOG(ERROR) << "get Fill -> dims failed"; | |||||
| if (tflite_op->inputs.size() > 1) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->dims)) { | |||||
| MS_LOG(ERROR) << "get fill -> dims failed"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Fill; | op->primitive->value.type = schema::PrimitiveType_Fill; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||